From 158c1f13a6425d726bfa0810d4f4bb58e6b7dc6a Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Wed, 11 Jan 2023 17:02:01 +0000 Subject: [PATCH] refactor(tokens-repo): switch token backend to redis And use timezone-aware comparisons for expiry checks --- selfprivacy_api/actions/api_tokens.py | 20 +++++-- .../models/tokens/new_device_key.py | 10 ++-- selfprivacy_api/models/tokens/recovery_key.py | 10 +++- selfprivacy_api/models/tokens/time.py | 13 +++++ .../tokens/json_tokens_repository.py | 25 ++++++++- .../tokens/redis_tokens_repository.py | 6 +- tests/common.py | 22 +++++--- tests/conftest.py | 22 +++++--- tests/test_graphql/test_api_recovery.py | 12 ++-- .../test_repository/test_tokens_repository.py | 56 +++++++++---------- tests/test_rest_endpoints/test_auth.py | 10 ++-- 11 files changed, 136 insertions(+), 70 deletions(-) create mode 100644 selfprivacy_api/models/tokens/time.py diff --git a/selfprivacy_api/actions/api_tokens.py b/selfprivacy_api/actions/api_tokens.py index 38133fd..2337224 100644 --- a/selfprivacy_api/actions/api_tokens.py +++ b/selfprivacy_api/actions/api_tokens.py @@ -1,11 +1,11 @@ """App tokens actions""" -from datetime import datetime +from datetime import datetime, timezone from typing import Optional from pydantic import BaseModel from mnemonic import Mnemonic -from selfprivacy_api.repositories.tokens.json_tokens_repository import ( - JsonTokensRepository, +from selfprivacy_api.repositories.tokens.redis_tokens_repository import ( + RedisTokensRepository, ) from selfprivacy_api.repositories.tokens.exceptions import ( TokenNotFound, @@ -14,7 +14,7 @@ from selfprivacy_api.repositories.tokens.exceptions import ( NewDeviceKeyNotFound, ) -TOKEN_REPO = JsonTokensRepository() +TOKEN_REPO = RedisTokensRepository() class TokenInfoWithIsCaller(BaseModel): @@ -82,6 +82,14 @@ class RecoveryTokenStatus(BaseModel): uses_left: Optional[int] = None +def naive(date_time: datetime) -> datetime: + if date_time is None: + return None + if date_time.tzinfo is not None: + date_time.astimezone(timezone.utc) + return date_time.replace(tzinfo=None) + + def get_api_recovery_token_status() -> RecoveryTokenStatus: """Get the recovery token status""" token = TOKEN_REPO.get_recovery_key() @@ -91,8 +99,8 @@ def get_api_recovery_token_status() -> RecoveryTokenStatus: return RecoveryTokenStatus( exists=True, valid=is_valid, - date=token.created_at, - expiration=token.expires_at, + date=naive(token.created_at), + expiration=naive(token.expires_at), uses_left=token.uses_left, ) diff --git a/selfprivacy_api/models/tokens/new_device_key.py b/selfprivacy_api/models/tokens/new_device_key.py index dda926c..9fbd23b 100644 --- a/selfprivacy_api/models/tokens/new_device_key.py +++ b/selfprivacy_api/models/tokens/new_device_key.py @@ -1,11 +1,13 @@ """ New device key used to obtain access token. """ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import secrets from pydantic import BaseModel from mnemonic import Mnemonic +from selfprivacy_api.models.tokens.time import is_past + class NewDeviceKey(BaseModel): """ @@ -22,7 +24,7 @@ class NewDeviceKey(BaseModel): """ Check if the recovery key is valid. """ - if self.expires_at < datetime.now(): + if is_past(self.expires_at): return False return True @@ -37,10 +39,10 @@ class NewDeviceKey(BaseModel): """ Factory to generate a random token. """ - creation_date = datetime.now() + creation_date = datetime.now(timezone.utc) key = secrets.token_bytes(16).hex() return NewDeviceKey( key=key, created_at=creation_date, - expires_at=datetime.now() + timedelta(minutes=10), + expires_at=creation_date + timedelta(minutes=10), ) diff --git a/selfprivacy_api/models/tokens/recovery_key.py b/selfprivacy_api/models/tokens/recovery_key.py index 098aceb..3b81398 100644 --- a/selfprivacy_api/models/tokens/recovery_key.py +++ b/selfprivacy_api/models/tokens/recovery_key.py @@ -3,12 +3,14 @@ Recovery key used to obtain access token. Recovery key has a token string, date of creation, optional date of expiration and optional count of uses left. """ -from datetime import datetime +from datetime import datetime, timezone import secrets from typing import Optional from pydantic import BaseModel from mnemonic import Mnemonic +from selfprivacy_api.models.tokens.time import is_past, ensure_timezone + class RecoveryKey(BaseModel): """ @@ -26,7 +28,7 @@ class RecoveryKey(BaseModel): """ Check if the recovery key is valid. """ - if self.expires_at is not None and self.expires_at < datetime.now(): + if self.expires_at is not None and is_past(self.expires_at): return False if self.uses_left is not None and self.uses_left <= 0: return False @@ -46,7 +48,9 @@ class RecoveryKey(BaseModel): """ Factory to generate a random token. """ - creation_date = datetime.now() + creation_date = datetime.now(timezone.utc) + if expiration is not None: + expiration = ensure_timezone(expiration) key = secrets.token_bytes(24).hex() return RecoveryKey( key=key, diff --git a/selfprivacy_api/models/tokens/time.py b/selfprivacy_api/models/tokens/time.py new file mode 100644 index 0000000..35fd992 --- /dev/null +++ b/selfprivacy_api/models/tokens/time.py @@ -0,0 +1,13 @@ +from datetime import datetime, timezone + +def is_past(dt: datetime) -> bool: + # we cannot compare a naive now() + # to dt which might be tz-aware or unaware + dt = ensure_timezone(dt) + return dt < datetime.now(timezone.utc) + +def ensure_timezone(dt:datetime) -> datetime: + if dt.tzinfo is None or dt.tzinfo.utcoffset(None) is None: + dt = dt.replace(tzinfo= timezone.utc) + return dt + diff --git a/selfprivacy_api/repositories/tokens/json_tokens_repository.py b/selfprivacy_api/repositories/tokens/json_tokens_repository.py index 77e1311..09204a8 100644 --- a/selfprivacy_api/repositories/tokens/json_tokens_repository.py +++ b/selfprivacy_api/repositories/tokens/json_tokens_repository.py @@ -2,7 +2,7 @@ temporary legacy """ from typing import Optional -from datetime import datetime +from datetime import datetime, timezone from selfprivacy_api.utils import UserDataFiles, WriteUserData, ReadUserData from selfprivacy_api.models.tokens.token import Token @@ -15,6 +15,7 @@ from selfprivacy_api.repositories.tokens.abstract_tokens_repository import ( AbstractTokensRepository, ) + DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f" @@ -56,6 +57,20 @@ class JsonTokensRepository(AbstractTokensRepository): raise TokenNotFound("Token not found!") + def __key_date_from_str(self, date_string: str) -> datetime: + if date_string is None or date_string == "": + return None + # we assume that we store dates in json as naive utc + utc_no_tz = datetime.fromisoformat(date_string) + utc_with_tz = utc_no_tz.replace(tzinfo=timezone.utc) + return utc_with_tz + + def __date_from_tokens_file( + self, tokens_file: object, tokenfield: str, datefield: str + ): + date_string = tokens_file[tokenfield].get(datefield) + return self.__key_date_from_str(date_string) + def get_recovery_key(self) -> Optional[RecoveryKey]: """Get the recovery key""" with ReadUserData(UserDataFiles.TOKENS) as tokens_file: @@ -68,8 +83,12 @@ class JsonTokensRepository(AbstractTokensRepository): recovery_key = RecoveryKey( key=tokens_file["recovery_token"].get("token"), - created_at=tokens_file["recovery_token"].get("date"), - expires_at=tokens_file["recovery_token"].get("expiration"), + created_at=self.__date_from_tokens_file( + tokens_file, "recovery_token", "date" + ), + expires_at=self.__date_from_tokens_file( + tokens_file, "recovery_token", "expiration" + ), uses_left=tokens_file["recovery_token"].get("uses_left"), ) diff --git a/selfprivacy_api/repositories/tokens/redis_tokens_repository.py b/selfprivacy_api/repositories/tokens/redis_tokens_repository.py index c72e231..a16b79d 100644 --- a/selfprivacy_api/repositories/tokens/redis_tokens_repository.py +++ b/selfprivacy_api/repositories/tokens/redis_tokens_repository.py @@ -2,7 +2,7 @@ Token repository using Redis as backend. """ from typing import Optional -from datetime import datetime +from datetime import datetime, timezone from selfprivacy_api.repositories.tokens.abstract_tokens_repository import ( AbstractTokensRepository, @@ -38,6 +38,8 @@ class RedisTokensRepository(AbstractTokensRepository): for key in token_keys: token = self._token_from_hash(key) if token is not None: + # token creation dates are temporarily not tz-aware + token.created_at = token.created_at.replace(tzinfo=None) tokens.append(token) return tokens @@ -150,5 +152,7 @@ class RedisTokensRepository(AbstractTokensRepository): redis = self.connection for key, value in model.dict().items(): if isinstance(value, datetime): + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) value = value.isoformat() redis.hset(redis_key, key, str(value)) diff --git a/tests/common.py b/tests/common.py index a49885a..08ddc66 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,16 +1,21 @@ import json -import datetime +from datetime import datetime, timezone, timedelta from mnemonic import Mnemonic # for expiration tests. If headache, consider freezegun -RECOVERY_KEY_VALIDATION_DATETIME = "selfprivacy_api.models.tokens.recovery_key.datetime" -DEVICE_KEY_VALIDATION_DATETIME = "selfprivacy_api.models.tokens.new_device_key.datetime" +RECOVERY_KEY_VALIDATION_DATETIME = "selfprivacy_api.models.tokens.time.datetime" +DEVICE_KEY_VALIDATION_DATETIME = RECOVERY_KEY_VALIDATION_DATETIME + +FIVE_MINUTES_INTO_FUTURE_NAIVE = datetime.now() + timedelta(minutes=5) +FIVE_MINUTES_INTO_FUTURE = datetime.now(timezone.utc) + timedelta(minutes=5) +FIVE_MINUTES_INTO_PAST_NAIVE = datetime.now() - timedelta(minutes=5) +FIVE_MINUTES_INTO_PAST = datetime.now(timezone.utc) - timedelta(minutes=5) -class NearFuture(datetime.datetime): +class NearFuture(datetime): @classmethod - def now(cls): - return datetime.datetime.now() + datetime.timedelta(minutes=13) + def now(cls, tz=None): + return datetime.now(tz) + timedelta(minutes=13) def read_json(file_path): @@ -41,7 +46,6 @@ def mnemonic_to_hex(mnemonic): def assert_recovery_recent(time_generated): assert ( - datetime.datetime.strptime(time_generated, "%Y-%m-%dT%H:%M:%S.%f") - - datetime.timedelta(seconds=5) - < datetime.datetime.now() + datetime.strptime(time_generated, "%Y-%m-%dT%H:%M:%S.%f") - timedelta(seconds=5) + < datetime.now() ) diff --git a/tests/conftest.py b/tests/conftest.py index 212b6da..52ded90 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,9 @@ from selfprivacy_api.models.tokens.token import Token from selfprivacy_api.repositories.tokens.json_tokens_repository import ( JsonTokensRepository, ) +from selfprivacy_api.repositories.tokens.redis_tokens_repository import ( + RedisTokensRepository, +) from tests.common import read_json @@ -63,21 +66,26 @@ def empty_json_repo(empty_tokens): @pytest.fixture -def tokens_file(empty_json_repo, tmpdir): +def empty_redis_repo(): + repo = RedisTokensRepository() + repo.reset() + assert repo.get_tokens() == [] + return repo + + +@pytest.fixture +def tokens_file(empty_redis_repo, tmpdir): """A state with tokens""" + repo = empty_redis_repo for token in TOKENS_FILE_CONTENTS["tokens"]: - empty_json_repo._store_token( + repo._store_token( Token( token=token["token"], device_name=token["name"], created_at=token["date"], ) ) - # temporary return for compatibility with older tests - - tokenfile = tmpdir / "empty_tokens.json" - assert path.exists(tokenfile) - return tokenfile + return repo @pytest.fixture diff --git a/tests/test_graphql/test_api_recovery.py b/tests/test_graphql/test_api_recovery.py index 9d6e671..a19eae2 100644 --- a/tests/test_graphql/test_api_recovery.py +++ b/tests/test_graphql/test_api_recovery.py @@ -1,7 +1,6 @@ # pylint: disable=redefined-outer-name # pylint: disable=unused-argument # pylint: disable=missing-function-docstring -import datetime from tests.common import ( generate_api_query, @@ -9,6 +8,11 @@ from tests.common import ( NearFuture, RECOVERY_KEY_VALIDATION_DATETIME, ) + +# Graphql API's output should be timezone-naive +from tests.common import FIVE_MINUTES_INTO_FUTURE_NAIVE as FIVE_MINUTES_INTO_FUTURE +from tests.common import FIVE_MINUTES_INTO_PAST_NAIVE as FIVE_MINUTES_INTO_PAST + from tests.test_graphql.common import ( assert_empty, assert_data, @@ -153,7 +157,7 @@ def test_graphql_generate_recovery_key(client, authorized_client, tokens_file): def test_graphql_generate_recovery_key_with_expiration_date( client, authorized_client, tokens_file ): - expiration_date = datetime.datetime.now() + datetime.timedelta(minutes=5) + expiration_date = FIVE_MINUTES_INTO_FUTURE key = graphql_make_new_recovery_key(authorized_client, expires_at=expiration_date) status = graphql_recovery_status(authorized_client) @@ -171,7 +175,7 @@ def test_graphql_generate_recovery_key_with_expiration_date( def test_graphql_use_recovery_key_after_expiration( client, authorized_client, tokens_file, mocker ): - expiration_date = datetime.datetime.now() + datetime.timedelta(minutes=5) + expiration_date = FIVE_MINUTES_INTO_FUTURE key = graphql_make_new_recovery_key(authorized_client, expires_at=expiration_date) # Timewarp to after it expires @@ -193,7 +197,7 @@ def test_graphql_use_recovery_key_after_expiration( def test_graphql_generate_recovery_key_with_expiration_in_the_past( authorized_client, tokens_file ): - expiration_date = datetime.datetime.now() - datetime.timedelta(minutes=5) + expiration_date = FIVE_MINUTES_INTO_PAST response = request_make_new_recovery_key( authorized_client, expires_at=expiration_date ) diff --git a/tests/test_graphql/test_repository/test_tokens_repository.py b/tests/test_graphql/test_repository/test_tokens_repository.py index a2dbb7a..7eede6a 100644 --- a/tests/test_graphql/test_repository/test_tokens_repository.py +++ b/tests/test_graphql/test_repository/test_tokens_repository.py @@ -2,7 +2,7 @@ # pylint: disable=unused-argument # pylint: disable=missing-function-docstring -from datetime import datetime, timedelta +from datetime import datetime, timezone from mnemonic import Mnemonic import pytest @@ -16,9 +16,8 @@ from selfprivacy_api.repositories.tokens.exceptions import ( TokenNotFound, NewDeviceKeyNotFound, ) -from selfprivacy_api.repositories.tokens.redis_tokens_repository import ( - RedisTokensRepository, -) + +from tests.common import FIVE_MINUTES_INTO_PAST ORIGINAL_DEVICE_NAMES = [ @@ -28,6 +27,10 @@ ORIGINAL_DEVICE_NAMES = [ "forth_token", ] +TEST_DATE = datetime(2022, 7, 15, 17, 41, 31, 675698, timezone.utc) +# tokens are not tz-aware +TOKEN_TEST_DATE = datetime(2022, 7, 15, 17, 41, 31, 675698) + def mnemonic_from_hex(hexkey): return Mnemonic(language="english").to_mnemonic(bytes.fromhex(hexkey)) @@ -40,8 +43,8 @@ def mock_new_device_key_generate(mocker): autospec=True, return_value=NewDeviceKey( key="43478d05b35e4781598acd76e33832bb", - created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), - expires_at=datetime(2022, 7, 15, 17, 41, 31, 675698), + created_at=TEST_DATE, + expires_at=TEST_DATE, ), ) return mock @@ -55,8 +58,8 @@ def mock_new_device_key_generate_for_mnemonic(mocker): autospec=True, return_value=NewDeviceKey( key="2237238de23dc71ab558e317bdb8ff8e", - created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), - expires_at=datetime(2022, 7, 15, 17, 41, 31, 675698), + created_at=TEST_DATE, + expires_at=TEST_DATE, ), ) return mock @@ -83,7 +86,7 @@ def mock_recovery_key_generate_invalid(mocker): autospec=True, return_value=RecoveryKey( key="889bf49c1d3199d71a2e704718772bd53a422020334db051", - created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), + created_at=TEST_DATE, expires_at=None, uses_left=0, ), @@ -99,7 +102,7 @@ def mock_token_generate(mocker): return_value=Token( token="ZuLNKtnxDeq6w2dpOJhbB3iat_sJLPTPl_rN5uc5MvM", device_name="IamNewDevice", - created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), + created_at=TOKEN_TEST_DATE, ), ) return mock @@ -112,7 +115,7 @@ def mock_recovery_key_generate(mocker): autospec=True, return_value=RecoveryKey( key="889bf49c1d3199d71a2e704718772bd53a422020334db051", - created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), + created_at=TEST_DATE, expires_at=None, uses_left=1, ), @@ -120,14 +123,6 @@ def mock_recovery_key_generate(mocker): return mock -@pytest.fixture -def empty_redis_repo(): - repo = RedisTokensRepository() - repo.reset() - assert repo.get_tokens() == [] - return repo - - @pytest.fixture(params=["json", "redis"]) def empty_repo(request, empty_json_repo, empty_redis_repo): if request.param == "json": @@ -224,13 +219,13 @@ def test_create_token(empty_repo, mock_token_generate): assert repo.create_token(device_name="IamNewDevice") == Token( token="ZuLNKtnxDeq6w2dpOJhbB3iat_sJLPTPl_rN5uc5MvM", device_name="IamNewDevice", - created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), + created_at=TOKEN_TEST_DATE, ) assert repo.get_tokens() == [ Token( token="ZuLNKtnxDeq6w2dpOJhbB3iat_sJLPTPl_rN5uc5MvM", device_name="IamNewDevice", - created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), + created_at=TOKEN_TEST_DATE, ) ] @@ -266,7 +261,7 @@ def test_delete_not_found_token(some_tokens_repo): input_token = Token( token="imbadtoken", device_name="primary_token", - created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), + created_at=TEST_DATE, ) with pytest.raises(TokenNotFound): assert repo.delete_token(input_token) is None @@ -295,7 +290,7 @@ def test_refresh_not_found_token(some_tokens_repo, mock_token_generate): input_token = Token( token="idontknowwhoiam", device_name="tellmewhoiam?", - created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), + created_at=TEST_DATE, ) with pytest.raises(TokenNotFound): @@ -319,7 +314,7 @@ def test_create_get_recovery_key(some_tokens_repo, mock_recovery_key_generate): assert repo.create_recovery_key(uses_left=1, expiration=None) is not None assert repo.get_recovery_key() == RecoveryKey( key="889bf49c1d3199d71a2e704718772bd53a422020334db051", - created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), + created_at=TEST_DATE, expires_at=None, uses_left=1, ) @@ -358,10 +353,13 @@ def test_use_mnemonic_expired_recovery_key( some_tokens_repo, ): repo = some_tokens_repo - expiration = datetime.now() - timedelta(minutes=5) + expiration = FIVE_MINUTES_INTO_PAST assert repo.create_recovery_key(uses_left=2, expiration=expiration) is not None recovery_key = repo.get_recovery_key() - assert recovery_key.expires_at == expiration + # TODO: do not ignore timezone once json backend is deleted + assert recovery_key.expires_at.replace(tzinfo=None) == expiration.replace( + tzinfo=None + ) assert not repo.is_recovery_key_valid() with pytest.raises(RecoveryKeyNotFound): @@ -458,8 +456,8 @@ def test_get_new_device_key(some_tokens_repo, mock_new_device_key_generate): assert repo.get_new_device_key() == NewDeviceKey( key="43478d05b35e4781598acd76e33832bb", - created_at=datetime(2022, 7, 15, 17, 41, 31, 675698), - expires_at=datetime(2022, 7, 15, 17, 41, 31, 675698), + created_at=TEST_DATE, + expires_at=TEST_DATE, ) @@ -535,7 +533,7 @@ def test_use_mnemonic_expired_new_device_key( some_tokens_repo, ): repo = some_tokens_repo - expiration = datetime.now() - timedelta(minutes=5) + expiration = FIVE_MINUTES_INTO_PAST key = repo.get_new_device_key() assert key is not None diff --git a/tests/test_rest_endpoints/test_auth.py b/tests/test_rest_endpoints/test_auth.py index ff161fb..ba54745 100644 --- a/tests/test_rest_endpoints/test_auth.py +++ b/tests/test_rest_endpoints/test_auth.py @@ -11,6 +11,8 @@ from tests.common import ( NearFuture, assert_recovery_recent, ) +from tests.common import FIVE_MINUTES_INTO_FUTURE_NAIVE as FIVE_MINUTES_INTO_FUTURE +from tests.common import FIVE_MINUTES_INTO_PAST_NAIVE as FIVE_MINUTES_INTO_PAST DATE_FORMATS = [ "%Y-%m-%dT%H:%M:%S.%fZ", @@ -110,7 +112,7 @@ def rest_recover_with_mnemonic(client, mnemonic_token, device_name): def test_get_tokens_info(authorized_client, tokens_file): - assert rest_get_tokens_info(authorized_client) == [ + assert sorted(rest_get_tokens_info(authorized_client), key=lambda x: x["name"]) == [ {"name": "test_token", "date": "2022-01-14T08:31:10.789314", "is_caller": True}, { "name": "test_token2", @@ -321,7 +323,7 @@ def test_generate_recovery_token_with_expiration_date( ): # Generate token with expiration date # Generate expiration date in the future - expiration_date = datetime.datetime.now() + datetime.timedelta(minutes=5) + expiration_date = FIVE_MINUTES_INTO_FUTURE mnemonic_token = rest_make_recovery_token( authorized_client, expires_at=expiration_date, timeformat=timeformat ) @@ -333,7 +335,7 @@ def test_generate_recovery_token_with_expiration_date( "exists": True, "valid": True, "date": time_generated, - "expiration": expiration_date.strftime("%Y-%m-%dT%H:%M:%S.%f"), + "expiration": expiration_date.isoformat(), "uses_left": None, } @@ -360,7 +362,7 @@ def test_generate_recovery_token_with_expiration_in_the_past( authorized_client, tokens_file, timeformat ): # Server must return 400 if expiration date is in the past - expiration_date = datetime.datetime.utcnow() - datetime.timedelta(minutes=5) + expiration_date = FIVE_MINUTES_INTO_PAST expiration_date_str = expiration_date.strftime(timeformat) response = authorized_client.post( "/auth/recovery_token",