diff --git a/.drone.yml b/.drone.yml index 0f5f93a..24ab5da 100644 --- a/.drone.yml +++ b/.drone.yml @@ -5,12 +5,16 @@ name: default steps: - name: Run Tests and Generate Coverage Report commands: + - kill $(ps aux | grep '[r]edis-server 127.0.0.1:6389' | awk '{print $2}') + - redis-server --bind 127.0.0.1 --port 6389 >/dev/null & - coverage run -m pytest -q - coverage xml - sonar-scanner -Dsonar.projectKey=SelfPrivacy-REST-API -Dsonar.sources=. -Dsonar.host.url=http://analyzer.lan:9000 -Dsonar.login="$SONARQUBE_TOKEN" environment: SONARQUBE_TOKEN: from_secret: SONARQUBE_TOKEN + USE_REDIS_PORT: 6389 + - name: Run Bandit Checks commands: diff --git a/selfprivacy_api/actions/api_tokens.py b/selfprivacy_api/actions/api_tokens.py index 61c695d..38133fd 100644 --- a/selfprivacy_api/actions/api_tokens.py +++ b/selfprivacy_api/actions/api_tokens.py @@ -2,20 +2,19 @@ from datetime import datetime from typing import Optional from pydantic import BaseModel +from mnemonic import Mnemonic - -from selfprivacy_api.utils.auth import ( - delete_token, - generate_recovery_token, - get_recovery_token_status, - get_tokens_info, - is_recovery_token_exists, - is_recovery_token_valid, - is_token_name_exists, - is_token_name_pair_valid, - refresh_token, - get_token_name, +from selfprivacy_api.repositories.tokens.json_tokens_repository import ( + JsonTokensRepository, ) +from selfprivacy_api.repositories.tokens.exceptions import ( + TokenNotFound, + RecoveryKeyNotFound, + InvalidMnemonic, + NewDeviceKeyNotFound, +) + +TOKEN_REPO = JsonTokensRepository() class TokenInfoWithIsCaller(BaseModel): @@ -28,18 +27,23 @@ class TokenInfoWithIsCaller(BaseModel): def get_api_tokens_with_caller_flag(caller_token: str) -> list[TokenInfoWithIsCaller]: """Get the tokens info""" - caller_name = get_token_name(caller_token) - tokens = get_tokens_info() + caller_name = TOKEN_REPO.get_token_by_token_string(caller_token).device_name + tokens = TOKEN_REPO.get_tokens() return [ TokenInfoWithIsCaller( - name=token.name, - date=token.date, - is_caller=token.name == caller_name, + name=token.device_name, + date=token.created_at, + is_caller=token.device_name == caller_name, ) for token in tokens ] +def is_token_valid(token) -> bool: + """Check if token is valid""" + return TOKEN_REPO.is_token_valid(token) + + class NotFoundException(Exception): """Not found exception""" @@ -50,19 +54,22 @@ class CannotDeleteCallerException(Exception): def delete_api_token(caller_token: str, token_name: str) -> None: """Delete the token""" - if is_token_name_pair_valid(token_name, caller_token): + if TOKEN_REPO.is_token_name_pair_valid(token_name, caller_token): raise CannotDeleteCallerException("Cannot delete caller's token") - if not is_token_name_exists(token_name): + if not TOKEN_REPO.is_token_name_exists(token_name): raise NotFoundException("Token not found") - delete_token(token_name) + token = TOKEN_REPO.get_token_by_name(token_name) + TOKEN_REPO.delete_token(token) def refresh_api_token(caller_token: str) -> str: """Refresh the token""" - new_token = refresh_token(caller_token) - if new_token is None: + try: + old_token = TOKEN_REPO.get_token_by_token_string(caller_token) + new_token = TOKEN_REPO.refresh_token(old_token) + except TokenNotFound: raise NotFoundException("Token not found") - return new_token + return new_token.token class RecoveryTokenStatus(BaseModel): @@ -77,18 +84,16 @@ class RecoveryTokenStatus(BaseModel): def get_api_recovery_token_status() -> RecoveryTokenStatus: """Get the recovery token status""" - if not is_recovery_token_exists(): + token = TOKEN_REPO.get_recovery_key() + if token is None: return RecoveryTokenStatus(exists=False, valid=False) - status = get_recovery_token_status() - if status is None: - return RecoveryTokenStatus(exists=False, valid=False) - is_valid = is_recovery_token_valid() + is_valid = TOKEN_REPO.is_recovery_key_valid() return RecoveryTokenStatus( exists=True, valid=is_valid, - date=status["date"], - expiration=status["expiration"], - uses_left=status["uses_left"], + date=token.created_at, + expiration=token.expires_at, + uses_left=token.uses_left, ) @@ -112,5 +117,46 @@ def get_new_api_recovery_key( if uses_left <= 0: raise InvalidUsesLeft("Uses must be greater than 0") - key = generate_recovery_token(expiration_date, uses_left) - return key + key = TOKEN_REPO.create_recovery_key(expiration_date, uses_left) + mnemonic_phrase = Mnemonic(language="english").to_mnemonic(bytes.fromhex(key.key)) + return mnemonic_phrase + + +def use_mnemonic_recovery_token(mnemonic_phrase, name): + """Use the recovery token by converting the mnemonic word list to a byte array. + If the recovery token if invalid itself, return None + If the binary representation of phrase not matches + the byte array of the recovery token, return None. + If the mnemonic phrase is valid then generate a device token and return it. + Substract 1 from uses_left if it exists. + mnemonic_phrase is a string representation of the mnemonic word list. + """ + try: + token = TOKEN_REPO.use_mnemonic_recovery_key(mnemonic_phrase, name) + return token.token + except (RecoveryKeyNotFound, InvalidMnemonic): + return None + + +def delete_new_device_auth_token() -> None: + TOKEN_REPO.delete_new_device_key() + + +def get_new_device_auth_token() -> str: + """Generate and store a new device auth token which is valid for 10 minutes + and return a mnemonic phrase representation + """ + key = TOKEN_REPO.get_new_device_key() + return Mnemonic(language="english").to_mnemonic(bytes.fromhex(key.key)) + + +def use_new_device_auth_token(mnemonic_phrase, name) -> Optional[str]: + """Use the new device auth token by converting the mnemonic string to a byte array. + If the mnemonic phrase is valid then generate a device token and return it. + New device auth token must be deleted. + """ + try: + token = TOKEN_REPO.use_mnemonic_new_device_key(mnemonic_phrase, name) + return token.token + except (NewDeviceKeyNotFound, InvalidMnemonic): + return None diff --git a/selfprivacy_api/dependencies.py b/selfprivacy_api/dependencies.py index 9568a40..1348f65 100644 --- a/selfprivacy_api/dependencies.py +++ b/selfprivacy_api/dependencies.py @@ -2,7 +2,7 @@ from fastapi import Depends, HTTPException, status from fastapi.security import APIKeyHeader from pydantic import BaseModel -from selfprivacy_api.utils.auth import is_token_valid +from selfprivacy_api.actions.api_tokens import is_token_valid class TokenHeader(BaseModel): diff --git a/selfprivacy_api/graphql/__init__.py b/selfprivacy_api/graphql/__init__.py index 7372197..6124a1a 100644 --- a/selfprivacy_api/graphql/__init__.py +++ b/selfprivacy_api/graphql/__init__.py @@ -4,7 +4,7 @@ import typing from strawberry.permission import BasePermission from strawberry.types import Info -from selfprivacy_api.utils.auth import is_token_valid +from selfprivacy_api.actions.api_tokens import is_token_valid class IsAuthenticated(BasePermission): diff --git a/selfprivacy_api/graphql/mutations/api_mutations.py b/selfprivacy_api/graphql/mutations/api_mutations.py index c6727db..49c49ad 100644 --- a/selfprivacy_api/graphql/mutations/api_mutations.py +++ b/selfprivacy_api/graphql/mutations/api_mutations.py @@ -11,6 +11,11 @@ from selfprivacy_api.actions.api_tokens import ( NotFoundException, delete_api_token, get_new_api_recovery_key, + use_mnemonic_recovery_token, + refresh_api_token, + delete_new_device_auth_token, + get_new_device_auth_token, + use_new_device_auth_token, ) from selfprivacy_api.graphql import IsAuthenticated from selfprivacy_api.graphql.mutations.mutation_interface import ( @@ -18,14 +23,6 @@ from selfprivacy_api.graphql.mutations.mutation_interface import ( MutationReturnInterface, ) -from selfprivacy_api.utils.auth import ( - delete_new_device_auth_token, - get_new_device_auth_token, - refresh_token, - use_mnemonic_recoverery_token, - use_new_device_auth_token, -) - @strawberry.type class ApiKeyMutationReturn(MutationReturnInterface): @@ -98,50 +95,53 @@ class ApiMutations: self, input: UseRecoveryKeyInput ) -> DeviceApiTokenMutationReturn: """Use recovery key""" - token = use_mnemonic_recoverery_token(input.key, input.deviceName) - if token is None: + token = use_mnemonic_recovery_token(input.key, input.deviceName) + if token is not None: + return DeviceApiTokenMutationReturn( + success=True, + message="Recovery key used", + code=200, + token=token, + ) + else: return DeviceApiTokenMutationReturn( success=False, message="Recovery key not found", code=404, token=None, ) - return DeviceApiTokenMutationReturn( - success=True, - message="Recovery key used", - code=200, - token=token, - ) @strawberry.mutation(permission_classes=[IsAuthenticated]) def refresh_device_api_token(self, info: Info) -> DeviceApiTokenMutationReturn: """Refresh device api token""" - token = ( + token_string = ( info.context["request"] .headers.get("Authorization", "") .replace("Bearer ", "") ) - if token is None: + if token_string is None: return DeviceApiTokenMutationReturn( success=False, message="Token not found", code=404, token=None, ) - new_token = refresh_token(token) - if new_token is None: + + try: + new_token = refresh_api_token(token_string) + return DeviceApiTokenMutationReturn( + success=True, + message="Token refreshed", + code=200, + token=new_token, + ) + except NotFoundException: return DeviceApiTokenMutationReturn( success=False, message="Token not found", code=404, token=None, ) - return DeviceApiTokenMutationReturn( - success=True, - message="Token refreshed", - code=200, - token=new_token, - ) @strawberry.mutation(permission_classes=[IsAuthenticated]) def delete_device_api_token(self, device: str, info: Info) -> GenericMutationReturn: diff --git a/selfprivacy_api/graphql/queries/api_queries.py b/selfprivacy_api/graphql/queries/api_queries.py index 7994a8f..cf56231 100644 --- a/selfprivacy_api/graphql/queries/api_queries.py +++ b/selfprivacy_api/graphql/queries/api_queries.py @@ -4,16 +4,12 @@ import datetime import typing import strawberry from strawberry.types import Info -from selfprivacy_api.actions.api_tokens import get_api_tokens_with_caller_flag -from selfprivacy_api.graphql import IsAuthenticated -from selfprivacy_api.utils import parse_date -from selfprivacy_api.dependencies import get_api_version as get_api_version_dependency - -from selfprivacy_api.utils.auth import ( - get_recovery_token_status, - is_recovery_token_exists, - is_recovery_token_valid, +from selfprivacy_api.actions.api_tokens import ( + get_api_tokens_with_caller_flag, + get_api_recovery_token_status, ) +from selfprivacy_api.graphql import IsAuthenticated +from selfprivacy_api.dependencies import get_api_version as get_api_version_dependency def get_api_version() -> str: @@ -43,16 +39,8 @@ class ApiRecoveryKeyStatus: def get_recovery_key_status() -> ApiRecoveryKeyStatus: """Get recovery key status""" - if not is_recovery_token_exists(): - return ApiRecoveryKeyStatus( - exists=False, - valid=False, - creation_date=None, - expiration_date=None, - uses_left=None, - ) - status = get_recovery_token_status() - if status is None: + status = get_api_recovery_token_status() + if status is None or not status.exists: return ApiRecoveryKeyStatus( exists=False, valid=False, @@ -62,12 +50,10 @@ def get_recovery_key_status() -> ApiRecoveryKeyStatus: ) return ApiRecoveryKeyStatus( exists=True, - valid=is_recovery_token_valid(), - creation_date=parse_date(status["date"]), - expiration_date=parse_date(status["expiration"]) - if status["expiration"] is not None - else None, - uses_left=status["uses_left"] if status["uses_left"] is not None else None, + valid=status.valid, + creation_date=status.date, + expiration_date=status.expiration, + uses_left=status.uses_left, ) diff --git a/selfprivacy_api/jobs/__init__.py b/selfprivacy_api/jobs/__init__.py index 1547b84..fe4a053 100644 --- a/selfprivacy_api/jobs/__init__.py +++ b/selfprivacy_api/jobs/__init__.py @@ -97,8 +97,8 @@ class Jobs: error=None, result=None, ) - r = RedisPool().get_connection() - _store_job_as_hash(r, _redis_key_from_uuid(job.uid), job) + redis = RedisPool().get_connection() + _store_job_as_hash(redis, _redis_key_from_uuid(job.uid), job) return job @staticmethod @@ -113,10 +113,10 @@ class Jobs: """ Remove a job from the jobs list. """ - r = RedisPool().get_connection() + redis = RedisPool().get_connection() key = _redis_key_from_uuid(job_uuid) - if (r.exists(key)): - r.delete(key) + if redis.exists(key): + redis.delete(key) return True return False @@ -149,12 +149,12 @@ class Jobs: if status in (JobStatus.FINISHED, JobStatus.ERROR): job.finished_at = datetime.datetime.now() - r = RedisPool().get_connection() + redis = RedisPool().get_connection() key = _redis_key_from_uuid(job.uid) - if r.exists(key): - _store_job_as_hash(r, key, job) + if redis.exists(key): + _store_job_as_hash(redis, key, job) if status in (JobStatus.FINISHED, JobStatus.ERROR): - r.expire(key, JOB_EXPIRATION_SECONDS) + redis.expire(key, JOB_EXPIRATION_SECONDS) return job @@ -163,10 +163,10 @@ class Jobs: """ Get a job from the jobs list. """ - r = RedisPool().get_connection() + redis = RedisPool().get_connection() key = _redis_key_from_uuid(uid) - if r.exists(key): - return _job_from_hash(r, key) + if redis.exists(key): + return _job_from_hash(redis, key) return None @staticmethod @@ -174,9 +174,14 @@ class Jobs: """ Get the jobs list. """ - r = RedisPool().get_connection() - jobs = r.keys("jobs:*") - return [_job_from_hash(r, job_key) for job_key in jobs] + redis = RedisPool().get_connection() + job_keys = redis.keys("jobs:*") + jobs = [] + for job_key in job_keys: + job = _job_from_hash(redis, job_key) + if job is not None: + jobs.append(job) + return jobs @staticmethod def is_busy() -> bool: @@ -189,11 +194,11 @@ class Jobs: return False -def _redis_key_from_uuid(uuid): - return "jobs:" + str(uuid) +def _redis_key_from_uuid(uuid_string): + return "jobs:" + str(uuid_string) -def _store_job_as_hash(r, redis_key, model): +def _store_job_as_hash(redis, redis_key, model): for key, value in model.dict().items(): if isinstance(value, uuid.UUID): value = str(value) @@ -201,12 +206,12 @@ def _store_job_as_hash(r, redis_key, model): value = value.isoformat() if isinstance(value, JobStatus): value = value.value - r.hset(redis_key, key, str(value)) + redis.hset(redis_key, key, str(value)) -def _job_from_hash(r, redis_key): - if r.exists(redis_key): - job_dict = r.hgetall(redis_key) +def _job_from_hash(redis, redis_key): + if redis.exists(redis_key): + job_dict = redis.hgetall(redis_key) for date in [ "created_at", "updated_at", diff --git a/selfprivacy_api/repositories/tokens/abstract_tokens_repository.py b/selfprivacy_api/repositories/tokens/abstract_tokens_repository.py index a67d62d..3a20ede 100644 --- a/selfprivacy_api/repositories/tokens/abstract_tokens_repository.py +++ b/selfprivacy_api/repositories/tokens/abstract_tokens_repository.py @@ -2,6 +2,8 @@ from abc import ABC, abstractmethod from datetime import datetime from typing import Optional from mnemonic import Mnemonic +from secrets import randbelow +import re from selfprivacy_api.models.tokens.token import Token from selfprivacy_api.repositories.tokens.exceptions import ( @@ -15,7 +17,7 @@ from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey class AbstractTokensRepository(ABC): - def get_token_by_token_string(self, token_string: str) -> Optional[Token]: + def get_token_by_token_string(self, token_string: str) -> Token: """Get the token by token""" tokens = self.get_tokens() for token in tokens: @@ -24,7 +26,7 @@ class AbstractTokensRepository(ABC): raise TokenNotFound("Token not found!") - def get_token_by_name(self, token_name: str) -> Optional[Token]: + def get_token_by_name(self, token_name: str) -> Token: """Get the token by name""" tokens = self.get_tokens() for token in tokens: @@ -39,7 +41,8 @@ class AbstractTokensRepository(ABC): def create_token(self, device_name: str) -> Token: """Create new token""" - new_token = Token.generate(device_name) + unique_name = self._make_unique_device_name(device_name) + new_token = Token.generate(unique_name) self._store_token(new_token) @@ -52,6 +55,7 @@ class AbstractTokensRepository(ABC): def refresh_token(self, input_token: Token) -> Token: """Change the token field of the existing token""" new_token = Token.generate(device_name=input_token.device_name) + new_token.created_at = input_token.created_at if input_token in self.get_tokens(): self.delete_token(input_token) @@ -62,22 +66,19 @@ class AbstractTokensRepository(ABC): def is_token_valid(self, token_string: str) -> bool: """Check if the token is valid""" - token = self.get_token_by_token_string(token_string) - if token is None: - return False - return True + return token_string in [token.token for token in self.get_tokens()] def is_token_name_exists(self, token_name: str) -> bool: """Check if the token name exists""" - token = self.get_token_by_name(token_name) - if token is None: - return False - return True + return token_name in [token.device_name for token in self.get_tokens()] def is_token_name_pair_valid(self, token_name: str, token_string: str) -> bool: """Check if the token name and token are valid""" - token = self.get_token_by_name(token_name) - if token is None: + try: + token = self.get_token_by_name(token_name) + if token is None: + return False + except TokenNotFound: return False return token.token == token_string @@ -100,7 +101,12 @@ class AbstractTokensRepository(ABC): if not self.is_recovery_key_valid(): raise RecoveryKeyNotFound("Recovery key not found") - recovery_hex_key = self.get_recovery_key().key + recovery_key = self.get_recovery_key() + + if recovery_key is None: + raise RecoveryKeyNotFound("Recovery key not found") + + recovery_hex_key = recovery_key.key if not self._assert_mnemonic(recovery_hex_key, mnemonic_phrase): raise RecoveryKeyNotFound("Recovery key not found") @@ -117,9 +123,15 @@ class AbstractTokensRepository(ABC): return False return recovery_key.is_valid() - @abstractmethod def get_new_device_key(self) -> NewDeviceKey: """Creates and returns the new device key""" + new_device_key = NewDeviceKey.generate() + self._store_new_device_key(new_device_key) + + return new_device_key + + def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None: + """Store new device key directly""" @abstractmethod def delete_new_device_key(self) -> None: @@ -133,6 +145,9 @@ class AbstractTokensRepository(ABC): if not new_device_key: raise NewDeviceKeyNotFound + if not new_device_key.is_valid(): + raise NewDeviceKeyNotFound + if not self._assert_mnemonic(new_device_key.key, mnemonic_phrase): raise NewDeviceKeyNotFound("Phrase is not token!") @@ -153,6 +168,19 @@ class AbstractTokensRepository(ABC): def _get_stored_new_device_key(self) -> Optional[NewDeviceKey]: """Retrieves new device key that is already stored.""" + def _make_unique_device_name(self, name: str) -> str: + """Token name must be an alphanumeric string and not empty. + Replace invalid characters with '_' + If name exists, add a random number to the end of the name until it is unique. + """ + if not re.match("^[a-zA-Z0-9]*$", name): + name = re.sub("[^a-zA-Z0-9]", "_", name) + if name == "": + name = "Unknown device" + while self.is_token_name_exists(name): + name += str(randbelow(10)) + return name + # TODO: find a proper place for it def _assert_mnemonic(self, hex_key: str, mnemonic_phrase: str): """Return true if hex string matches the phrase, false otherwise diff --git a/selfprivacy_api/repositories/tokens/json_tokens_repository.py b/selfprivacy_api/repositories/tokens/json_tokens_repository.py index b4c0ab2..77e1311 100644 --- a/selfprivacy_api/repositories/tokens/json_tokens_repository.py +++ b/selfprivacy_api/repositories/tokens/json_tokens_repository.py @@ -69,7 +69,7 @@ class JsonTokensRepository(AbstractTokensRepository): recovery_key = RecoveryKey( key=tokens_file["recovery_token"].get("token"), created_at=tokens_file["recovery_token"].get("date"), - expires_at=tokens_file["recovery_token"].get("expitation"), + expires_at=tokens_file["recovery_token"].get("expiration"), uses_left=tokens_file["recovery_token"].get("uses_left"), ) @@ -85,10 +85,13 @@ class JsonTokensRepository(AbstractTokensRepository): recovery_key = RecoveryKey.generate(expiration, uses_left) with WriteUserData(UserDataFiles.TOKENS) as tokens_file: + key_expiration: Optional[str] = None + if recovery_key.expires_at is not None: + key_expiration = recovery_key.expires_at.strftime(DATETIME_FORMAT) tokens_file["recovery_token"] = { "token": recovery_key.key, "date": recovery_key.created_at.strftime(DATETIME_FORMAT), - "expiration": recovery_key.expires_at, + "expiration": key_expiration, "uses_left": recovery_key.uses_left, } @@ -98,12 +101,10 @@ class JsonTokensRepository(AbstractTokensRepository): """Decrement recovery key use count by one""" if self.is_recovery_key_valid(): with WriteUserData(UserDataFiles.TOKENS) as tokens: - tokens["recovery_token"]["uses_left"] -= 1 - - def get_new_device_key(self) -> NewDeviceKey: - """Creates and returns the new device key""" - new_device_key = NewDeviceKey.generate() + if tokens["recovery_token"]["uses_left"] is not None: + tokens["recovery_token"]["uses_left"] -= 1 + def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None: with WriteUserData(UserDataFiles.TOKENS) as tokens_file: tokens_file["new_device"] = { "token": new_device_key.key, @@ -111,8 +112,6 @@ class JsonTokensRepository(AbstractTokensRepository): "expiration": new_device_key.expires_at.strftime(DATETIME_FORMAT), } - return new_device_key - def delete_new_device_key(self) -> None: """Delete the new device key""" with WriteUserData(UserDataFiles.TOKENS) as tokens_file: diff --git a/selfprivacy_api/repositories/tokens/redis_tokens_repository.py b/selfprivacy_api/repositories/tokens/redis_tokens_repository.py index b1fb4b0..c72e231 100644 --- a/selfprivacy_api/repositories/tokens/redis_tokens_repository.py +++ b/selfprivacy_api/repositories/tokens/redis_tokens_repository.py @@ -32,29 +32,34 @@ class RedisTokensRepository(AbstractTokensRepository): def get_tokens(self) -> list[Token]: """Get the tokens""" - r = self.connection - token_keys = r.keys(TOKENS_PREFIX + "*") - return [self._token_from_hash(key) for key in token_keys] + redis = self.connection + token_keys = redis.keys(TOKENS_PREFIX + "*") + tokens = [] + for key in token_keys: + token = self._token_from_hash(key) + if token is not None: + tokens.append(token) + return tokens def delete_token(self, input_token: Token) -> None: """Delete the token""" - r = self.connection + redis = self.connection key = RedisTokensRepository._token_redis_key(input_token) if input_token not in self.get_tokens(): raise TokenNotFound - r.delete(key) + redis.delete(key) def reset(self): for token in self.get_tokens(): self.delete_token(token) self.delete_new_device_key() - r = self.connection - r.delete(RECOVERY_KEY_REDIS_KEY) + redis = self.connection + redis.delete(RECOVERY_KEY_REDIS_KEY) def get_recovery_key(self) -> Optional[RecoveryKey]: """Get the recovery key""" - r = self.connection - if r.exists(RECOVERY_KEY_REDIS_KEY): + redis = self.connection + if redis.exists(RECOVERY_KEY_REDIS_KEY): return self._recovery_key_from_hash(RECOVERY_KEY_REDIS_KEY) return None @@ -68,16 +73,14 @@ class RedisTokensRepository(AbstractTokensRepository): self._store_model_as_hash(RECOVERY_KEY_REDIS_KEY, recovery_key) return recovery_key - def get_new_device_key(self) -> NewDeviceKey: - """Creates and returns the new device key""" - new_device_key = NewDeviceKey.generate() + def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None: + """Store new device key directly""" self._store_model_as_hash(NEW_DEVICE_KEY_REDIS_KEY, new_device_key) - return new_device_key def delete_new_device_key(self) -> None: """Delete the new device key""" - r = self.connection - r.delete(NEW_DEVICE_KEY_REDIS_KEY) + redis = self.connection + redis.delete(NEW_DEVICE_KEY_REDIS_KEY) @staticmethod def _token_redis_key(token: Token) -> str: @@ -91,9 +94,13 @@ class RedisTokensRepository(AbstractTokensRepository): def _decrement_recovery_token(self): """Decrement recovery key use count by one""" if self.is_recovery_key_valid(): - uses_left = self.get_recovery_key().uses_left - r = self.connection - r.hset(RECOVERY_KEY_REDIS_KEY, "uses_left", uses_left - 1) + recovery_key = self.get_recovery_key() + if recovery_key is None: + return + uses_left = recovery_key.uses_left + if uses_left is not None: + redis = self.connection + redis.hset(RECOVERY_KEY_REDIS_KEY, "uses_left", uses_left - 1) def _get_stored_new_device_key(self) -> Optional[NewDeviceKey]: """Retrieves new device key that is already stored.""" @@ -117,9 +124,9 @@ class RedisTokensRepository(AbstractTokensRepository): d[key] = None def _model_dict_from_hash(self, redis_key: str) -> Optional[dict]: - r = self.connection - if r.exists(redis_key): - token_dict = r.hgetall(redis_key) + redis = self.connection + if redis.exists(redis_key): + token_dict = redis.hgetall(redis_key) RedisTokensRepository._prepare_model_dict(token_dict) return token_dict return None @@ -140,8 +147,8 @@ class RedisTokensRepository(AbstractTokensRepository): return self._hash_as_model(redis_key, NewDeviceKey) def _store_model_as_hash(self, redis_key, model): - r = self.connection + redis = self.connection for key, value in model.dict().items(): if isinstance(value, datetime): value = value.isoformat() - r.hset(redis_key, key, str(value)) + redis.hset(redis_key, key, str(value)) diff --git a/selfprivacy_api/rest/api_auth.py b/selfprivacy_api/rest/api_auth.py index f73056c..275dac3 100644 --- a/selfprivacy_api/rest/api_auth.py +++ b/selfprivacy_api/rest/api_auth.py @@ -8,20 +8,18 @@ from selfprivacy_api.actions.api_tokens import ( InvalidUsesLeft, NotFoundException, delete_api_token, + refresh_api_token, get_api_recovery_token_status, get_api_tokens_with_caller_flag, get_new_api_recovery_key, - refresh_api_token, + use_mnemonic_recovery_token, + delete_new_device_auth_token, + get_new_device_auth_token, + use_new_device_auth_token, ) from selfprivacy_api.dependencies import TokenHeader, get_token_header -from selfprivacy_api.utils.auth import ( - delete_new_device_auth_token, - get_new_device_auth_token, - use_mnemonic_recoverery_token, - use_new_device_auth_token, -) router = APIRouter( prefix="/auth", @@ -99,7 +97,7 @@ class UseTokenInput(BaseModel): @router.post("/recovery_token/use") async def rest_use_recovery_token(input: UseTokenInput): - token = use_mnemonic_recoverery_token(input.token, input.device) + token = use_mnemonic_recovery_token(input.token, input.device) if token is None: raise HTTPException(status_code=404, detail="Token not found") return {"token": token} diff --git a/selfprivacy_api/utils/auth.py b/selfprivacy_api/utils/auth.py deleted file mode 100644 index ecaf9af..0000000 --- a/selfprivacy_api/utils/auth.py +++ /dev/null @@ -1,329 +0,0 @@ -#!/usr/bin/env python3 -"""Token management utils""" -import secrets -from datetime import datetime, timedelta -import re -import typing - -from pydantic import BaseModel -from mnemonic import Mnemonic - -from . import ReadUserData, UserDataFiles, WriteUserData, parse_date - -""" -Token are stored in the tokens.json file. -File contains device tokens, recovery token and new device auth token. -File structure: -{ - "tokens": [ - { - "token": "device token", - "name": "device name", - "date": "date of creation", - } - ], - "recovery_token": { - "token": "recovery token", - "date": "date of creation", - "expiration": "date of expiration", - "uses_left": "number of uses left" - }, - "new_device": { - "token": "new device auth token", - "date": "date of creation", - "expiration": "date of expiration", - } -} -Recovery token may or may not have expiration date and uses_left. -There may be no recovery token at all. -Device tokens must be unique. -""" - - -def _get_tokens(): - """Get all tokens as list of tokens of every device""" - with ReadUserData(UserDataFiles.TOKENS) as tokens: - return [token["token"] for token in tokens["tokens"]] - - -def _get_token_names(): - """Get all token names""" - with ReadUserData(UserDataFiles.TOKENS) as tokens: - return [t["name"] for t in tokens["tokens"]] - - -def _validate_token_name(name): - """Token name must be an alphanumeric string and not empty. - Replace invalid characters with '_' - If token name exists, add a random number to the end of the name until it is unique. - """ - if not re.match("^[a-zA-Z0-9]*$", name): - name = re.sub("[^a-zA-Z0-9]", "_", name) - if name == "": - name = "Unknown device" - while name in _get_token_names(): - name += str(secrets.randbelow(10)) - return name - - -def is_token_valid(token): - """Check if token is valid""" - if token in _get_tokens(): - return True - return False - - -def is_token_name_exists(token_name): - """Check if token name exists""" - with ReadUserData(UserDataFiles.TOKENS) as tokens: - return token_name in [t["name"] for t in tokens["tokens"]] - - -def is_token_name_pair_valid(token_name, token): - """Check if token name and token pair exists""" - with ReadUserData(UserDataFiles.TOKENS) as tokens: - for t in tokens["tokens"]: - if t["name"] == token_name and t["token"] == token: - return True - return False - - -def get_token_name(token: str) -> typing.Optional[str]: - """Return the name of the token provided""" - with ReadUserData(UserDataFiles.TOKENS) as tokens: - for t in tokens["tokens"]: - if t["token"] == token: - return t["name"] - return None - - -class BasicTokenInfo(BaseModel): - """Token info""" - - name: str - date: datetime - - -def get_tokens_info(): - """Get all tokens info without tokens themselves""" - with ReadUserData(UserDataFiles.TOKENS) as tokens: - return [ - BasicTokenInfo( - name=t["name"], - date=parse_date(t["date"]), - ) - for t in tokens["tokens"] - ] - - -def _generate_token(): - """Generates new token and makes sure it is unique""" - token = secrets.token_urlsafe(32) - while token in _get_tokens(): - token = secrets.token_urlsafe(32) - return token - - -def create_token(name): - """Create new token""" - token = _generate_token() - name = _validate_token_name(name) - with WriteUserData(UserDataFiles.TOKENS) as tokens: - tokens["tokens"].append( - { - "token": token, - "name": name, - "date": str(datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")), - } - ) - return token - - -def delete_token(token_name): - """Delete token""" - with WriteUserData(UserDataFiles.TOKENS) as tokens: - tokens["tokens"] = [t for t in tokens["tokens"] if t["name"] != token_name] - - -def refresh_token(token: str) -> typing.Optional[str]: - """Change the token field of the existing token""" - new_token = _generate_token() - with WriteUserData(UserDataFiles.TOKENS) as tokens: - for t in tokens["tokens"]: - if t["token"] == token: - t["token"] = new_token - return new_token - return None - - -def is_recovery_token_exists(): - """Check if recovery token exists""" - with ReadUserData(UserDataFiles.TOKENS) as tokens: - return "recovery_token" in tokens - - -def is_recovery_token_valid(): - """Check if recovery token is valid""" - with ReadUserData(UserDataFiles.TOKENS) as tokens: - if "recovery_token" not in tokens: - return False - recovery_token = tokens["recovery_token"] - if "uses_left" in recovery_token and recovery_token["uses_left"] is not None: - if recovery_token["uses_left"] <= 0: - return False - if "expiration" not in recovery_token or recovery_token["expiration"] is None: - return True - return datetime.now() < parse_date(recovery_token["expiration"]) - - -def get_recovery_token_status(): - """Get recovery token date of creation, expiration and uses left""" - with ReadUserData(UserDataFiles.TOKENS) as tokens: - if "recovery_token" not in tokens: - return None - recovery_token = tokens["recovery_token"] - return { - "date": recovery_token["date"], - "expiration": recovery_token["expiration"] - if "expiration" in recovery_token - else None, - "uses_left": recovery_token["uses_left"] - if "uses_left" in recovery_token - else None, - } - - -def _get_recovery_token(): - """Get recovery token""" - with ReadUserData(UserDataFiles.TOKENS) as tokens: - if "recovery_token" not in tokens: - return None - return tokens["recovery_token"]["token"] - - -def generate_recovery_token( - expiration: typing.Optional[datetime], uses_left: typing.Optional[int] -) -> str: - """Generate a 24 bytes recovery token and return a mneomnic word list. - Write a string representation of the recovery token to the tokens.json file. - """ - # expires must be a date or None - # uses_left must be an integer or None - if expiration is not None: - if not isinstance(expiration, datetime): - raise TypeError("expires must be a datetime object") - if uses_left is not None: - if not isinstance(uses_left, int): - raise TypeError("uses_left must be an integer") - if uses_left <= 0: - raise ValueError("uses_left must be greater than 0") - - recovery_token = secrets.token_bytes(24) - recovery_token_str = recovery_token.hex() - with WriteUserData(UserDataFiles.TOKENS) as tokens: - tokens["recovery_token"] = { - "token": recovery_token_str, - "date": str(datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")), - "expiration": expiration.strftime("%Y-%m-%dT%H:%M:%S.%f") - if expiration is not None - else None, - "uses_left": uses_left if uses_left is not None else None, - } - return Mnemonic(language="english").to_mnemonic(recovery_token) - - -def use_mnemonic_recoverery_token(mnemonic_phrase, name): - """Use the recovery token by converting the mnemonic word list to a byte array. - If the recovery token if invalid itself, return None - If the binary representation of phrase not matches - the byte array of the recovery token, return None. - If the mnemonic phrase is valid then generate a device token and return it. - Substract 1 from uses_left if it exists. - mnemonic_phrase is a string representation of the mnemonic word list. - """ - if not is_recovery_token_valid(): - return None - recovery_token_str = _get_recovery_token() - if recovery_token_str is None: - return None - recovery_token = bytes.fromhex(recovery_token_str) - if not Mnemonic(language="english").check(mnemonic_phrase): - return None - phrase_bytes = Mnemonic(language="english").to_entropy(mnemonic_phrase) - if phrase_bytes != recovery_token: - return None - token = _generate_token() - name = _validate_token_name(name) - with WriteUserData(UserDataFiles.TOKENS) as tokens: - tokens["tokens"].append( - { - "token": token, - "name": name, - "date": str(datetime.now()), - } - ) - if "recovery_token" in tokens: - if ( - "uses_left" in tokens["recovery_token"] - and tokens["recovery_token"]["uses_left"] is not None - ): - tokens["recovery_token"]["uses_left"] -= 1 - return token - - -def get_new_device_auth_token() -> str: - """Generate a new device auth token which is valid for 10 minutes - and return a mnemonic phrase representation - Write token to the new_device of the tokens.json file. - """ - token = secrets.token_bytes(16) - token_str = token.hex() - with WriteUserData(UserDataFiles.TOKENS) as tokens: - tokens["new_device"] = { - "token": token_str, - "date": str(datetime.now()), - "expiration": str(datetime.now() + timedelta(minutes=10)), - } - return Mnemonic(language="english").to_mnemonic(token) - - -def _get_new_device_auth_token(): - """Get new device auth token. If it is expired, return None""" - with ReadUserData(UserDataFiles.TOKENS) as tokens: - if "new_device" not in tokens: - return None - new_device = tokens["new_device"] - if "expiration" not in new_device: - return None - expiration = parse_date(new_device["expiration"]) - if datetime.now() > expiration: - return None - return new_device["token"] - - -def delete_new_device_auth_token(): - """Delete new device auth token""" - with WriteUserData(UserDataFiles.TOKENS) as tokens: - if "new_device" in tokens: - del tokens["new_device"] - - -def use_new_device_auth_token(mnemonic_phrase, name): - """Use the new device auth token by converting the mnemonic string to a byte array. - If the mnemonic phrase is valid then generate a device token and return it. - New device auth token must be deleted. - """ - token_str = _get_new_device_auth_token() - if token_str is None: - return None - token = bytes.fromhex(token_str) - if not Mnemonic(language="english").check(mnemonic_phrase): - return None - phrase_bytes = Mnemonic(language="english").to_entropy(mnemonic_phrase) - if phrase_bytes != token: - return None - token = create_token(name) - with WriteUserData(UserDataFiles.TOKENS) as tokens: - if "new_device" in tokens: - del tokens["new_device"] - return token diff --git a/tests/test_graphql/test_api_devices.py b/tests/test_graphql/test_api_devices.py index d8dc974..07cf42a 100644 --- a/tests/test_graphql/test_api_devices.py +++ b/tests/test_graphql/test_api_devices.py @@ -2,8 +2,14 @@ # pylint: disable=unused-argument # pylint: disable=missing-function-docstring import datetime +import pytest from mnemonic import Mnemonic +from selfprivacy_api.repositories.tokens.json_tokens_repository import ( + JsonTokensRepository, +) +from selfprivacy_api.models.tokens.token import Token + from tests.common import generate_api_query, read_json, write_json TOKENS_FILE_CONTETS = { @@ -30,6 +36,11 @@ devices { """ +@pytest.fixture +def token_repo(): + return JsonTokensRepository() + + def test_graphql_tokens_info(authorized_client, tokens_file): response = authorized_client.post( "/graphql", @@ -170,7 +181,7 @@ def test_graphql_refresh_token_unauthorized(client, tokens_file): assert response.json()["data"] is None -def test_graphql_refresh_token(authorized_client, tokens_file): +def test_graphql_refresh_token(authorized_client, tokens_file, token_repo): response = authorized_client.post( "/graphql", json={"query": REFRESH_TOKEN_MUTATION}, @@ -180,11 +191,12 @@ def test_graphql_refresh_token(authorized_client, tokens_file): assert response.json()["data"]["refreshDeviceApiToken"]["success"] is True assert response.json()["data"]["refreshDeviceApiToken"]["message"] is not None assert response.json()["data"]["refreshDeviceApiToken"]["code"] == 200 - assert read_json(tokens_file)["tokens"][0] == { - "token": response.json()["data"]["refreshDeviceApiToken"]["token"], - "name": "test_token", - "date": "2022-01-14 08:31:10.789314", - } + token = token_repo.get_token_by_name("test_token") + assert token == Token( + token=response.json()["data"]["refreshDeviceApiToken"]["token"], + device_name="test_token", + created_at=datetime.datetime(2022, 1, 14, 8, 31, 10, 789314), + ) NEW_DEVICE_KEY_MUTATION = """ diff --git a/tests/test_graphql/test_repository/test_tokens_repository.py b/tests/test_graphql/test_repository/test_tokens_repository.py index 43f7626..020a868 100644 --- a/tests/test_graphql/test_repository/test_tokens_repository.py +++ b/tests/test_graphql/test_repository/test_tokens_repository.py @@ -2,7 +2,8 @@ # pylint: disable=unused-argument # pylint: disable=missing-function-docstring -from datetime import datetime +from datetime import datetime, timedelta +from mnemonic import Mnemonic import pytest @@ -32,6 +33,10 @@ ORIGINAL_DEVICE_NAMES = [ ] +def mnemonic_from_hex(hexkey): + return Mnemonic(language="english").to_mnemonic(bytes.fromhex(hexkey)) + + @pytest.fixture def empty_keys(mocker, datadir): mocker.patch("selfprivacy_api.utils.TOKENS_FILE", new=datadir / "empty_keys.json") @@ -132,21 +137,6 @@ def mock_recovery_key_generate(mocker): return mock -@pytest.fixture -def mock_recovery_key_generate_for_mnemonic(mocker): - mock = mocker.patch( - "selfprivacy_api.models.tokens.recovery_key.RecoveryKey.generate", - autospec=True, - return_value=RecoveryKey( - key="ed653e4b8b042b841d285fa7a682fa09e925ddb2d8906f54", - created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), - expires_at=None, - uses_left=1, - ), - ) - return mock - - @pytest.fixture def empty_json_repo(empty_keys): repo = JsonTokensRepository() @@ -221,6 +211,28 @@ def test_get_token_by_non_existent_name(some_tokens_repo): assert repo.get_token_by_name(token_name="badname") is None +def test_is_token_valid(some_tokens_repo): + repo = some_tokens_repo + token = repo.get_tokens()[0] + assert repo.is_token_valid(token.token) + assert not repo.is_token_valid("gibberish") + + +def test_is_token_name_pair_valid(some_tokens_repo): + repo = some_tokens_repo + token = repo.get_tokens()[0] + assert repo.is_token_name_pair_valid(token.device_name, token.token) + assert not repo.is_token_name_pair_valid(token.device_name, "gibberish") + assert not repo.is_token_name_pair_valid("gibberish", token.token) + + +def test_is_token_name_exists(some_tokens_repo): + repo = some_tokens_repo + token = repo.get_tokens()[0] + assert repo.is_token_name_exists(token.device_name) + assert not repo.is_token_name_exists("gibberish") + + def test_get_tokens(some_tokens_repo): repo = some_tokens_repo tokenstrings = [] @@ -249,6 +261,17 @@ def test_create_token(empty_repo, mock_token_generate): ] +def test_create_token_existing(some_tokens_repo): + repo = some_tokens_repo + old_token = repo.get_tokens()[0] + + new_token = repo.create_token(device_name=old_token.device_name) + assert new_token.device_name != old_token.device_name + + assert old_token in repo.get_tokens() + assert new_token in repo.get_tokens() + + def test_delete_token(some_tokens_repo): repo = some_tokens_repo original_tokens = repo.get_tokens() @@ -280,15 +303,17 @@ def test_delete_not_found_token(some_tokens_repo): assert token in new_tokens -def test_refresh_token(some_tokens_repo, mock_token_generate): +def test_refresh_token(some_tokens_repo): repo = some_tokens_repo input_token = some_tokens_repo.get_tokens()[0] - assert repo.refresh_token(input_token) == Token( - token="ZuLNKtnxDeq6w2dpOJhbB3iat_sJLPTPl_rN5uc5MvM", - device_name="IamNewDevice", - created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), - ) + output_token = repo.refresh_token(input_token) + + assert output_token.token != input_token.token + assert output_token.device_name == input_token.device_name + assert output_token.created_at == input_token.created_at + + assert output_token in repo.get_tokens() def test_refresh_not_found_token(some_tokens_repo, mock_token_generate): @@ -355,6 +380,23 @@ def test_use_mnemonic_not_valid_recovery_key( ) +def test_use_mnemonic_expired_recovery_key( + some_tokens_repo, +): + repo = some_tokens_repo + expiration = datetime.now() - timedelta(minutes=5) + assert repo.create_recovery_key(uses_left=2, expiration=expiration) is not None + recovery_key = repo.get_recovery_key() + assert recovery_key.expires_at == expiration + assert not repo.is_recovery_key_valid() + + with pytest.raises(RecoveryKeyNotFound): + token = repo.use_mnemonic_recovery_key( + mnemonic_phrase=mnemonic_from_hex(recovery_key.key), + device_name="newdevice", + ) + + def test_use_mnemonic_not_mnemonic_recovery_key(some_tokens_repo): repo = some_tokens_repo assert repo.create_recovery_key(uses_left=1, expiration=None) is not None @@ -397,46 +439,38 @@ def test_use_not_found_mnemonic_recovery_key(some_tokens_repo): ) -def test_use_mnemonic_recovery_key_when_empty(empty_repo): - repo = empty_repo - - with pytest.raises(RecoveryKeyNotFound): - assert ( - repo.use_mnemonic_recovery_key( - mnemonic_phrase="captain ribbon toddler settle symbol minute step broccoli bless universe divide bulb", - device_name="primary_token", - ) - is None - ) +@pytest.fixture(params=["recovery_uses_1", "recovery_eternal"]) +def recovery_key_uses_left(request): + if request.param == "recovery_uses_1": + return 1 + if request.param == "recovery_eternal": + return None -# agnostic test mixed with an implementation test -def test_use_mnemonic_recovery_key( - some_tokens_repo, mock_recovery_key_generate_for_mnemonic, mock_generate_token -): +def test_use_mnemonic_recovery_key(some_tokens_repo, recovery_key_uses_left): repo = some_tokens_repo - assert repo.create_recovery_key(uses_left=1, expiration=None) is not None - - test_token = Token( - token="ur71mC4aiI6FIYAN--cTL-38rPHS5D6NuB1bgN_qKF4", - device_name="newdevice", - created_at=datetime(2022, 11, 14, 6, 6, 32, 777123), - ) - assert ( - repo.use_mnemonic_recovery_key( - mnemonic_phrase="uniform clarify napkin bid dress search input armor police cross salon because myself uphold slice bamboo hungry park", - device_name="newdevice", - ) - == test_token + repo.create_recovery_key(uses_left=recovery_key_uses_left, expiration=None) + is not None + ) + assert repo.is_recovery_key_valid() + recovery_key = repo.get_recovery_key() + + token = repo.use_mnemonic_recovery_key( + mnemonic_phrase=mnemonic_from_hex(recovery_key.key), + device_name="newdevice", ) - assert test_token in repo.get_tokens() + assert token.device_name == "newdevice" + assert token in repo.get_tokens() + new_uses = None + if recovery_key_uses_left is not None: + new_uses = recovery_key_uses_left - 1 assert repo.get_recovery_key() == RecoveryKey( - key="ed653e4b8b042b841d285fa7a682fa09e925ddb2d8906f54", - created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), + key=recovery_key.key, + created_at=recovery_key.created_at, expires_at=None, - uses_left=0, + uses_left=new_uses, ) @@ -497,15 +531,16 @@ def test_use_not_exists_mnemonic_new_device_key( ) -def test_use_mnemonic_new_device_key( - empty_repo, mock_new_device_key_generate_for_mnemonic -): +def test_use_mnemonic_new_device_key(empty_repo): repo = empty_repo - assert repo.get_new_device_key() is not None + key = repo.get_new_device_key() + assert key is not None + + mnemonic_phrase = mnemonic_from_hex(key.key) new_token = repo.use_mnemonic_new_device_key( device_name="imnew", - mnemonic_phrase="captain ribbon toddler settle symbol minute step broccoli bless universe divide bulb", + mnemonic_phrase=mnemonic_phrase, ) assert new_token.device_name == "imnew" @@ -516,12 +551,32 @@ def test_use_mnemonic_new_device_key( assert ( repo.use_mnemonic_new_device_key( device_name="imnew", - mnemonic_phrase="captain ribbon toddler settle symbol minute step broccoli bless universe divide bulb", + mnemonic_phrase=mnemonic_phrase, ) is None ) +def test_use_mnemonic_expired_new_device_key( + some_tokens_repo, +): + repo = some_tokens_repo + expiration = datetime.now() - timedelta(minutes=5) + + key = repo.get_new_device_key() + assert key is not None + assert key.expires_at is not None + key.expires_at = expiration + assert not key.is_valid() + repo._store_new_device_key(key) + + with pytest.raises(NewDeviceKeyNotFound): + token = repo.use_mnemonic_new_device_key( + mnemonic_phrase=mnemonic_from_hex(key.key), + device_name="imnew", + ) + + def test_use_mnemonic_new_device_key_when_empty(empty_repo): repo = empty_repo diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..2263e82 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,18 @@ +import pytest +from datetime import datetime, timedelta + +from selfprivacy_api.models.tokens.recovery_key import RecoveryKey +from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey + + +def test_recovery_key_expired(): + expiration = datetime.now() - timedelta(minutes=5) + key = RecoveryKey.generate(expiration=expiration, uses_left=2) + assert not key.is_valid() + + +def test_new_device_key_expired(): + expiration = datetime.now() - timedelta(minutes=5) + key = NewDeviceKey.generate() + key.expires_at = expiration + assert not key.is_valid() diff --git a/tests/test_rest_endpoints/test_auth.py b/tests/test_rest_endpoints/test_auth.py index 1083be5..12de0cf 100644 --- a/tests/test_rest_endpoints/test_auth.py +++ b/tests/test_rest_endpoints/test_auth.py @@ -5,6 +5,12 @@ import datetime import pytest from mnemonic import Mnemonic +from selfprivacy_api.repositories.tokens.json_tokens_repository import ( + JsonTokensRepository, +) + +TOKEN_REPO = JsonTokensRepository() + from tests.common import read_json, write_json @@ -97,7 +103,7 @@ def test_refresh_token(authorized_client, tokens_file): response = authorized_client.post("/auth/tokens") assert response.status_code == 200 new_token = response.json()["token"] - assert read_json(tokens_file)["tokens"][0]["token"] == new_token + assert TOKEN_REPO.get_token_by_token_string(new_token) is not None # new device