feat(tokens-repo): make device names unique before storage

redis/token-repo
Houkime 2022-12-26 15:51:12 +00:00
parent 8235c3595c
commit 450ff41ebd
2 changed files with 28 additions and 1 deletions

View File

@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional
from mnemonic import Mnemonic
from secrets import randbelow
import re
from selfprivacy_api.models.tokens.token import Token
from selfprivacy_api.repositories.tokens.exceptions import (
@ -39,7 +41,8 @@ class AbstractTokensRepository(ABC):
def create_token(self, device_name: str) -> Token:
"""Create new token"""
new_token = Token.generate(device_name)
unique_name = self._make_unique_device_name(device_name)
new_token = Token.generate(unique_name)
self._store_token(new_token)
@ -160,6 +163,19 @@ class AbstractTokensRepository(ABC):
def _get_stored_new_device_key(self) -> Optional[NewDeviceKey]:
"""Retrieves new device key that is already stored."""
def _make_unique_device_name(self, name: str) -> str:
"""Token name must be an alphanumeric string and not empty.
Replace invalid characters with '_'
If name exists, add a random number to the end of the name until it is unique.
"""
if not re.match("^[a-zA-Z0-9]*$", name):
name = re.sub("[^a-zA-Z0-9]", "_", name)
if name == "":
name = "Unknown device"
while self.is_token_name_exists(name):
name += str(randbelow(10))
return name
# TODO: find a proper place for it
def _assert_mnemonic(self, hex_key: str, mnemonic_phrase: str):
"""Return true if hex string matches the phrase, false otherwise

View File

@ -257,6 +257,17 @@ def test_create_token(empty_repo, mock_token_generate):
]
def test_create_token_existing(some_tokens_repo):
repo = some_tokens_repo
old_token = repo.get_tokens()[0]
new_token = repo.create_token(device_name=old_token.device_name)
assert new_token.device_name != old_token.device_name
assert old_token in repo.get_tokens()
assert new_token in repo.get_tokens()
def test_delete_token(some_tokens_repo):
repo = some_tokens_repo
original_tokens = repo.get_tokens()