Merge pull request 'Migrate to AbstractTokenRepository API' (#28) from redis/token-repo into redis/connection-pool

Reviewed-on: #28
pull/30/head
Inex Code 2022-12-30 20:06:43 +02:00
commit 45c6133881
16 changed files with 394 additions and 559 deletions

View File

@ -5,12 +5,16 @@ name: default
steps:
- name: Run Tests and Generate Coverage Report
commands:
- kill $(ps aux | grep '[r]edis-server 127.0.0.1:6389' | awk '{print $2}')
- redis-server --bind 127.0.0.1 --port 6389 >/dev/null &
- coverage run -m pytest -q
- coverage xml
- sonar-scanner -Dsonar.projectKey=SelfPrivacy-REST-API -Dsonar.sources=. -Dsonar.host.url=http://analyzer.lan:9000 -Dsonar.login="$SONARQUBE_TOKEN"
environment:
SONARQUBE_TOKEN:
from_secret: SONARQUBE_TOKEN
USE_REDIS_PORT: 6389
- name: Run Bandit Checks
commands:

View File

@ -2,20 +2,19 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
from mnemonic import Mnemonic
from selfprivacy_api.utils.auth import (
delete_token,
generate_recovery_token,
get_recovery_token_status,
get_tokens_info,
is_recovery_token_exists,
is_recovery_token_valid,
is_token_name_exists,
is_token_name_pair_valid,
refresh_token,
get_token_name,
from selfprivacy_api.repositories.tokens.json_tokens_repository import (
JsonTokensRepository,
)
from selfprivacy_api.repositories.tokens.exceptions import (
TokenNotFound,
RecoveryKeyNotFound,
InvalidMnemonic,
NewDeviceKeyNotFound,
)
TOKEN_REPO = JsonTokensRepository()
class TokenInfoWithIsCaller(BaseModel):
@ -28,18 +27,23 @@ class TokenInfoWithIsCaller(BaseModel):
def get_api_tokens_with_caller_flag(caller_token: str) -> list[TokenInfoWithIsCaller]:
"""Get the tokens info"""
caller_name = get_token_name(caller_token)
tokens = get_tokens_info()
caller_name = TOKEN_REPO.get_token_by_token_string(caller_token).device_name
tokens = TOKEN_REPO.get_tokens()
return [
TokenInfoWithIsCaller(
name=token.name,
date=token.date,
is_caller=token.name == caller_name,
name=token.device_name,
date=token.created_at,
is_caller=token.device_name == caller_name,
)
for token in tokens
]
def is_token_valid(token) -> bool:
"""Check if token is valid"""
return TOKEN_REPO.is_token_valid(token)
class NotFoundException(Exception):
"""Not found exception"""
@ -50,19 +54,22 @@ class CannotDeleteCallerException(Exception):
def delete_api_token(caller_token: str, token_name: str) -> None:
"""Delete the token"""
if is_token_name_pair_valid(token_name, caller_token):
if TOKEN_REPO.is_token_name_pair_valid(token_name, caller_token):
raise CannotDeleteCallerException("Cannot delete caller's token")
if not is_token_name_exists(token_name):
if not TOKEN_REPO.is_token_name_exists(token_name):
raise NotFoundException("Token not found")
delete_token(token_name)
token = TOKEN_REPO.get_token_by_name(token_name)
TOKEN_REPO.delete_token(token)
def refresh_api_token(caller_token: str) -> str:
"""Refresh the token"""
new_token = refresh_token(caller_token)
if new_token is None:
try:
old_token = TOKEN_REPO.get_token_by_token_string(caller_token)
new_token = TOKEN_REPO.refresh_token(old_token)
except TokenNotFound:
raise NotFoundException("Token not found")
return new_token
return new_token.token
class RecoveryTokenStatus(BaseModel):
@ -77,18 +84,16 @@ class RecoveryTokenStatus(BaseModel):
def get_api_recovery_token_status() -> RecoveryTokenStatus:
"""Get the recovery token status"""
if not is_recovery_token_exists():
token = TOKEN_REPO.get_recovery_key()
if token is None:
return RecoveryTokenStatus(exists=False, valid=False)
status = get_recovery_token_status()
if status is None:
return RecoveryTokenStatus(exists=False, valid=False)
is_valid = is_recovery_token_valid()
is_valid = TOKEN_REPO.is_recovery_key_valid()
return RecoveryTokenStatus(
exists=True,
valid=is_valid,
date=status["date"],
expiration=status["expiration"],
uses_left=status["uses_left"],
date=token.created_at,
expiration=token.expires_at,
uses_left=token.uses_left,
)
@ -112,5 +117,46 @@ def get_new_api_recovery_key(
if uses_left <= 0:
raise InvalidUsesLeft("Uses must be greater than 0")
key = generate_recovery_token(expiration_date, uses_left)
return key
key = TOKEN_REPO.create_recovery_key(expiration_date, uses_left)
mnemonic_phrase = Mnemonic(language="english").to_mnemonic(bytes.fromhex(key.key))
return mnemonic_phrase
def use_mnemonic_recovery_token(mnemonic_phrase, name):
"""Use the recovery token by converting the mnemonic word list to a byte array.
If the recovery token if invalid itself, return None
If the binary representation of phrase not matches
the byte array of the recovery token, return None.
If the mnemonic phrase is valid then generate a device token and return it.
Substract 1 from uses_left if it exists.
mnemonic_phrase is a string representation of the mnemonic word list.
"""
try:
token = TOKEN_REPO.use_mnemonic_recovery_key(mnemonic_phrase, name)
return token.token
except (RecoveryKeyNotFound, InvalidMnemonic):
return None
def delete_new_device_auth_token() -> None:
TOKEN_REPO.delete_new_device_key()
def get_new_device_auth_token() -> str:
"""Generate and store a new device auth token which is valid for 10 minutes
and return a mnemonic phrase representation
"""
key = TOKEN_REPO.get_new_device_key()
return Mnemonic(language="english").to_mnemonic(bytes.fromhex(key.key))
def use_new_device_auth_token(mnemonic_phrase, name) -> Optional[str]:
"""Use the new device auth token by converting the mnemonic string to a byte array.
If the mnemonic phrase is valid then generate a device token and return it.
New device auth token must be deleted.
"""
try:
token = TOKEN_REPO.use_mnemonic_new_device_key(mnemonic_phrase, name)
return token.token
except (NewDeviceKeyNotFound, InvalidMnemonic):
return None

View File

@ -2,7 +2,7 @@ from fastapi import Depends, HTTPException, status
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from selfprivacy_api.utils.auth import is_token_valid
from selfprivacy_api.actions.api_tokens import is_token_valid
class TokenHeader(BaseModel):

View File

@ -4,7 +4,7 @@ import typing
from strawberry.permission import BasePermission
from strawberry.types import Info
from selfprivacy_api.utils.auth import is_token_valid
from selfprivacy_api.actions.api_tokens import is_token_valid
class IsAuthenticated(BasePermission):

View File

@ -11,6 +11,11 @@ from selfprivacy_api.actions.api_tokens import (
NotFoundException,
delete_api_token,
get_new_api_recovery_key,
use_mnemonic_recovery_token,
refresh_api_token,
delete_new_device_auth_token,
get_new_device_auth_token,
use_new_device_auth_token,
)
from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.graphql.mutations.mutation_interface import (
@ -18,14 +23,6 @@ from selfprivacy_api.graphql.mutations.mutation_interface import (
MutationReturnInterface,
)
from selfprivacy_api.utils.auth import (
delete_new_device_auth_token,
get_new_device_auth_token,
refresh_token,
use_mnemonic_recoverery_token,
use_new_device_auth_token,
)
@strawberry.type
class ApiKeyMutationReturn(MutationReturnInterface):
@ -98,50 +95,53 @@ class ApiMutations:
self, input: UseRecoveryKeyInput
) -> DeviceApiTokenMutationReturn:
"""Use recovery key"""
token = use_mnemonic_recoverery_token(input.key, input.deviceName)
if token is None:
token = use_mnemonic_recovery_token(input.key, input.deviceName)
if token is not None:
return DeviceApiTokenMutationReturn(
success=True,
message="Recovery key used",
code=200,
token=token,
)
else:
return DeviceApiTokenMutationReturn(
success=False,
message="Recovery key not found",
code=404,
token=None,
)
return DeviceApiTokenMutationReturn(
success=True,
message="Recovery key used",
code=200,
token=token,
)
@strawberry.mutation(permission_classes=[IsAuthenticated])
def refresh_device_api_token(self, info: Info) -> DeviceApiTokenMutationReturn:
"""Refresh device api token"""
token = (
token_string = (
info.context["request"]
.headers.get("Authorization", "")
.replace("Bearer ", "")
)
if token is None:
if token_string is None:
return DeviceApiTokenMutationReturn(
success=False,
message="Token not found",
code=404,
token=None,
)
new_token = refresh_token(token)
if new_token is None:
try:
new_token = refresh_api_token(token_string)
return DeviceApiTokenMutationReturn(
success=True,
message="Token refreshed",
code=200,
token=new_token,
)
except NotFoundException:
return DeviceApiTokenMutationReturn(
success=False,
message="Token not found",
code=404,
token=None,
)
return DeviceApiTokenMutationReturn(
success=True,
message="Token refreshed",
code=200,
token=new_token,
)
@strawberry.mutation(permission_classes=[IsAuthenticated])
def delete_device_api_token(self, device: str, info: Info) -> GenericMutationReturn:

View File

@ -4,16 +4,12 @@ import datetime
import typing
import strawberry
from strawberry.types import Info
from selfprivacy_api.actions.api_tokens import get_api_tokens_with_caller_flag
from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.utils import parse_date
from selfprivacy_api.dependencies import get_api_version as get_api_version_dependency
from selfprivacy_api.utils.auth import (
get_recovery_token_status,
is_recovery_token_exists,
is_recovery_token_valid,
from selfprivacy_api.actions.api_tokens import (
get_api_tokens_with_caller_flag,
get_api_recovery_token_status,
)
from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.dependencies import get_api_version as get_api_version_dependency
def get_api_version() -> str:
@ -43,16 +39,8 @@ class ApiRecoveryKeyStatus:
def get_recovery_key_status() -> ApiRecoveryKeyStatus:
"""Get recovery key status"""
if not is_recovery_token_exists():
return ApiRecoveryKeyStatus(
exists=False,
valid=False,
creation_date=None,
expiration_date=None,
uses_left=None,
)
status = get_recovery_token_status()
if status is None:
status = get_api_recovery_token_status()
if status is None or not status.exists:
return ApiRecoveryKeyStatus(
exists=False,
valid=False,
@ -62,12 +50,10 @@ def get_recovery_key_status() -> ApiRecoveryKeyStatus:
)
return ApiRecoveryKeyStatus(
exists=True,
valid=is_recovery_token_valid(),
creation_date=parse_date(status["date"]),
expiration_date=parse_date(status["expiration"])
if status["expiration"] is not None
else None,
uses_left=status["uses_left"] if status["uses_left"] is not None else None,
valid=status.valid,
creation_date=status.date,
expiration_date=status.expiration,
uses_left=status.uses_left,
)

View File

@ -97,8 +97,8 @@ class Jobs:
error=None,
result=None,
)
r = RedisPool().get_connection()
_store_job_as_hash(r, _redis_key_from_uuid(job.uid), job)
redis = RedisPool().get_connection()
_store_job_as_hash(redis, _redis_key_from_uuid(job.uid), job)
return job
@staticmethod
@ -113,10 +113,10 @@ class Jobs:
"""
Remove a job from the jobs list.
"""
r = RedisPool().get_connection()
redis = RedisPool().get_connection()
key = _redis_key_from_uuid(job_uuid)
if (r.exists(key)):
r.delete(key)
if redis.exists(key):
redis.delete(key)
return True
return False
@ -149,12 +149,12 @@ class Jobs:
if status in (JobStatus.FINISHED, JobStatus.ERROR):
job.finished_at = datetime.datetime.now()
r = RedisPool().get_connection()
redis = RedisPool().get_connection()
key = _redis_key_from_uuid(job.uid)
if r.exists(key):
_store_job_as_hash(r, key, job)
if redis.exists(key):
_store_job_as_hash(redis, key, job)
if status in (JobStatus.FINISHED, JobStatus.ERROR):
r.expire(key, JOB_EXPIRATION_SECONDS)
redis.expire(key, JOB_EXPIRATION_SECONDS)
return job
@ -163,10 +163,10 @@ class Jobs:
"""
Get a job from the jobs list.
"""
r = RedisPool().get_connection()
redis = RedisPool().get_connection()
key = _redis_key_from_uuid(uid)
if r.exists(key):
return _job_from_hash(r, key)
if redis.exists(key):
return _job_from_hash(redis, key)
return None
@staticmethod
@ -174,9 +174,14 @@ class Jobs:
"""
Get the jobs list.
"""
r = RedisPool().get_connection()
jobs = r.keys("jobs:*")
return [_job_from_hash(r, job_key) for job_key in jobs]
redis = RedisPool().get_connection()
job_keys = redis.keys("jobs:*")
jobs = []
for job_key in job_keys:
job = _job_from_hash(redis, job_key)
if job is not None:
jobs.append(job)
return jobs
@staticmethod
def is_busy() -> bool:
@ -189,11 +194,11 @@ class Jobs:
return False
def _redis_key_from_uuid(uuid):
return "jobs:" + str(uuid)
def _redis_key_from_uuid(uuid_string):
return "jobs:" + str(uuid_string)
def _store_job_as_hash(r, redis_key, model):
def _store_job_as_hash(redis, redis_key, model):
for key, value in model.dict().items():
if isinstance(value, uuid.UUID):
value = str(value)
@ -201,12 +206,12 @@ def _store_job_as_hash(r, redis_key, model):
value = value.isoformat()
if isinstance(value, JobStatus):
value = value.value
r.hset(redis_key, key, str(value))
redis.hset(redis_key, key, str(value))
def _job_from_hash(r, redis_key):
if r.exists(redis_key):
job_dict = r.hgetall(redis_key)
def _job_from_hash(redis, redis_key):
if redis.exists(redis_key):
job_dict = redis.hgetall(redis_key)
for date in [
"created_at",
"updated_at",

View File

@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional
from mnemonic import Mnemonic
from secrets import randbelow
import re
from selfprivacy_api.models.tokens.token import Token
from selfprivacy_api.repositories.tokens.exceptions import (
@ -15,7 +17,7 @@ from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey
class AbstractTokensRepository(ABC):
def get_token_by_token_string(self, token_string: str) -> Optional[Token]:
def get_token_by_token_string(self, token_string: str) -> Token:
"""Get the token by token"""
tokens = self.get_tokens()
for token in tokens:
@ -24,7 +26,7 @@ class AbstractTokensRepository(ABC):
raise TokenNotFound("Token not found!")
def get_token_by_name(self, token_name: str) -> Optional[Token]:
def get_token_by_name(self, token_name: str) -> Token:
"""Get the token by name"""
tokens = self.get_tokens()
for token in tokens:
@ -39,7 +41,8 @@ class AbstractTokensRepository(ABC):
def create_token(self, device_name: str) -> Token:
"""Create new token"""
new_token = Token.generate(device_name)
unique_name = self._make_unique_device_name(device_name)
new_token = Token.generate(unique_name)
self._store_token(new_token)
@ -52,6 +55,7 @@ class AbstractTokensRepository(ABC):
def refresh_token(self, input_token: Token) -> Token:
"""Change the token field of the existing token"""
new_token = Token.generate(device_name=input_token.device_name)
new_token.created_at = input_token.created_at
if input_token in self.get_tokens():
self.delete_token(input_token)
@ -62,22 +66,19 @@ class AbstractTokensRepository(ABC):
def is_token_valid(self, token_string: str) -> bool:
"""Check if the token is valid"""
token = self.get_token_by_token_string(token_string)
if token is None:
return False
return True
return token_string in [token.token for token in self.get_tokens()]
def is_token_name_exists(self, token_name: str) -> bool:
"""Check if the token name exists"""
token = self.get_token_by_name(token_name)
if token is None:
return False
return True
return token_name in [token.device_name for token in self.get_tokens()]
def is_token_name_pair_valid(self, token_name: str, token_string: str) -> bool:
"""Check if the token name and token are valid"""
token = self.get_token_by_name(token_name)
if token is None:
try:
token = self.get_token_by_name(token_name)
if token is None:
return False
except TokenNotFound:
return False
return token.token == token_string
@ -100,7 +101,12 @@ class AbstractTokensRepository(ABC):
if not self.is_recovery_key_valid():
raise RecoveryKeyNotFound("Recovery key not found")
recovery_hex_key = self.get_recovery_key().key
recovery_key = self.get_recovery_key()
if recovery_key is None:
raise RecoveryKeyNotFound("Recovery key not found")
recovery_hex_key = recovery_key.key
if not self._assert_mnemonic(recovery_hex_key, mnemonic_phrase):
raise RecoveryKeyNotFound("Recovery key not found")
@ -117,9 +123,15 @@ class AbstractTokensRepository(ABC):
return False
return recovery_key.is_valid()
@abstractmethod
def get_new_device_key(self) -> NewDeviceKey:
"""Creates and returns the new device key"""
new_device_key = NewDeviceKey.generate()
self._store_new_device_key(new_device_key)
return new_device_key
def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None:
"""Store new device key directly"""
@abstractmethod
def delete_new_device_key(self) -> None:
@ -133,6 +145,9 @@ class AbstractTokensRepository(ABC):
if not new_device_key:
raise NewDeviceKeyNotFound
if not new_device_key.is_valid():
raise NewDeviceKeyNotFound
if not self._assert_mnemonic(new_device_key.key, mnemonic_phrase):
raise NewDeviceKeyNotFound("Phrase is not token!")
@ -153,6 +168,19 @@ class AbstractTokensRepository(ABC):
def _get_stored_new_device_key(self) -> Optional[NewDeviceKey]:
"""Retrieves new device key that is already stored."""
def _make_unique_device_name(self, name: str) -> str:
"""Token name must be an alphanumeric string and not empty.
Replace invalid characters with '_'
If name exists, add a random number to the end of the name until it is unique.
"""
if not re.match("^[a-zA-Z0-9]*$", name):
name = re.sub("[^a-zA-Z0-9]", "_", name)
if name == "":
name = "Unknown device"
while self.is_token_name_exists(name):
name += str(randbelow(10))
return name
# TODO: find a proper place for it
def _assert_mnemonic(self, hex_key: str, mnemonic_phrase: str):
"""Return true if hex string matches the phrase, false otherwise

View File

@ -69,7 +69,7 @@ class JsonTokensRepository(AbstractTokensRepository):
recovery_key = RecoveryKey(
key=tokens_file["recovery_token"].get("token"),
created_at=tokens_file["recovery_token"].get("date"),
expires_at=tokens_file["recovery_token"].get("expitation"),
expires_at=tokens_file["recovery_token"].get("expiration"),
uses_left=tokens_file["recovery_token"].get("uses_left"),
)
@ -85,10 +85,13 @@ class JsonTokensRepository(AbstractTokensRepository):
recovery_key = RecoveryKey.generate(expiration, uses_left)
with WriteUserData(UserDataFiles.TOKENS) as tokens_file:
key_expiration: Optional[str] = None
if recovery_key.expires_at is not None:
key_expiration = recovery_key.expires_at.strftime(DATETIME_FORMAT)
tokens_file["recovery_token"] = {
"token": recovery_key.key,
"date": recovery_key.created_at.strftime(DATETIME_FORMAT),
"expiration": recovery_key.expires_at,
"expiration": key_expiration,
"uses_left": recovery_key.uses_left,
}
@ -98,12 +101,10 @@ class JsonTokensRepository(AbstractTokensRepository):
"""Decrement recovery key use count by one"""
if self.is_recovery_key_valid():
with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["recovery_token"]["uses_left"] -= 1
def get_new_device_key(self) -> NewDeviceKey:
"""Creates and returns the new device key"""
new_device_key = NewDeviceKey.generate()
if tokens["recovery_token"]["uses_left"] is not None:
tokens["recovery_token"]["uses_left"] -= 1
def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None:
with WriteUserData(UserDataFiles.TOKENS) as tokens_file:
tokens_file["new_device"] = {
"token": new_device_key.key,
@ -111,8 +112,6 @@ class JsonTokensRepository(AbstractTokensRepository):
"expiration": new_device_key.expires_at.strftime(DATETIME_FORMAT),
}
return new_device_key
def delete_new_device_key(self) -> None:
"""Delete the new device key"""
with WriteUserData(UserDataFiles.TOKENS) as tokens_file:

View File

@ -32,29 +32,34 @@ class RedisTokensRepository(AbstractTokensRepository):
def get_tokens(self) -> list[Token]:
"""Get the tokens"""
r = self.connection
token_keys = r.keys(TOKENS_PREFIX + "*")
return [self._token_from_hash(key) for key in token_keys]
redis = self.connection
token_keys = redis.keys(TOKENS_PREFIX + "*")
tokens = []
for key in token_keys:
token = self._token_from_hash(key)
if token is not None:
tokens.append(token)
return tokens
def delete_token(self, input_token: Token) -> None:
"""Delete the token"""
r = self.connection
redis = self.connection
key = RedisTokensRepository._token_redis_key(input_token)
if input_token not in self.get_tokens():
raise TokenNotFound
r.delete(key)
redis.delete(key)
def reset(self):
for token in self.get_tokens():
self.delete_token(token)
self.delete_new_device_key()
r = self.connection
r.delete(RECOVERY_KEY_REDIS_KEY)
redis = self.connection
redis.delete(RECOVERY_KEY_REDIS_KEY)
def get_recovery_key(self) -> Optional[RecoveryKey]:
"""Get the recovery key"""
r = self.connection
if r.exists(RECOVERY_KEY_REDIS_KEY):
redis = self.connection
if redis.exists(RECOVERY_KEY_REDIS_KEY):
return self._recovery_key_from_hash(RECOVERY_KEY_REDIS_KEY)
return None
@ -68,16 +73,14 @@ class RedisTokensRepository(AbstractTokensRepository):
self._store_model_as_hash(RECOVERY_KEY_REDIS_KEY, recovery_key)
return recovery_key
def get_new_device_key(self) -> NewDeviceKey:
"""Creates and returns the new device key"""
new_device_key = NewDeviceKey.generate()
def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None:
"""Store new device key directly"""
self._store_model_as_hash(NEW_DEVICE_KEY_REDIS_KEY, new_device_key)
return new_device_key
def delete_new_device_key(self) -> None:
"""Delete the new device key"""
r = self.connection
r.delete(NEW_DEVICE_KEY_REDIS_KEY)
redis = self.connection
redis.delete(NEW_DEVICE_KEY_REDIS_KEY)
@staticmethod
def _token_redis_key(token: Token) -> str:
@ -91,9 +94,13 @@ class RedisTokensRepository(AbstractTokensRepository):
def _decrement_recovery_token(self):
"""Decrement recovery key use count by one"""
if self.is_recovery_key_valid():
uses_left = self.get_recovery_key().uses_left
r = self.connection
r.hset(RECOVERY_KEY_REDIS_KEY, "uses_left", uses_left - 1)
recovery_key = self.get_recovery_key()
if recovery_key is None:
return
uses_left = recovery_key.uses_left
if uses_left is not None:
redis = self.connection
redis.hset(RECOVERY_KEY_REDIS_KEY, "uses_left", uses_left - 1)
def _get_stored_new_device_key(self) -> Optional[NewDeviceKey]:
"""Retrieves new device key that is already stored."""
@ -117,9 +124,9 @@ class RedisTokensRepository(AbstractTokensRepository):
d[key] = None
def _model_dict_from_hash(self, redis_key: str) -> Optional[dict]:
r = self.connection
if r.exists(redis_key):
token_dict = r.hgetall(redis_key)
redis = self.connection
if redis.exists(redis_key):
token_dict = redis.hgetall(redis_key)
RedisTokensRepository._prepare_model_dict(token_dict)
return token_dict
return None
@ -140,8 +147,8 @@ class RedisTokensRepository(AbstractTokensRepository):
return self._hash_as_model(redis_key, NewDeviceKey)
def _store_model_as_hash(self, redis_key, model):
r = self.connection
redis = self.connection
for key, value in model.dict().items():
if isinstance(value, datetime):
value = value.isoformat()
r.hset(redis_key, key, str(value))
redis.hset(redis_key, key, str(value))

View File

@ -8,20 +8,18 @@ from selfprivacy_api.actions.api_tokens import (
InvalidUsesLeft,
NotFoundException,
delete_api_token,
refresh_api_token,
get_api_recovery_token_status,
get_api_tokens_with_caller_flag,
get_new_api_recovery_key,
refresh_api_token,
use_mnemonic_recovery_token,
delete_new_device_auth_token,
get_new_device_auth_token,
use_new_device_auth_token,
)
from selfprivacy_api.dependencies import TokenHeader, get_token_header
from selfprivacy_api.utils.auth import (
delete_new_device_auth_token,
get_new_device_auth_token,
use_mnemonic_recoverery_token,
use_new_device_auth_token,
)
router = APIRouter(
prefix="/auth",
@ -99,7 +97,7 @@ class UseTokenInput(BaseModel):
@router.post("/recovery_token/use")
async def rest_use_recovery_token(input: UseTokenInput):
token = use_mnemonic_recoverery_token(input.token, input.device)
token = use_mnemonic_recovery_token(input.token, input.device)
if token is None:
raise HTTPException(status_code=404, detail="Token not found")
return {"token": token}

View File

@ -1,329 +0,0 @@
#!/usr/bin/env python3
"""Token management utils"""
import secrets
from datetime import datetime, timedelta
import re
import typing
from pydantic import BaseModel
from mnemonic import Mnemonic
from . import ReadUserData, UserDataFiles, WriteUserData, parse_date
"""
Token are stored in the tokens.json file.
File contains device tokens, recovery token and new device auth token.
File structure:
{
"tokens": [
{
"token": "device token",
"name": "device name",
"date": "date of creation",
}
],
"recovery_token": {
"token": "recovery token",
"date": "date of creation",
"expiration": "date of expiration",
"uses_left": "number of uses left"
},
"new_device": {
"token": "new device auth token",
"date": "date of creation",
"expiration": "date of expiration",
}
}
Recovery token may or may not have expiration date and uses_left.
There may be no recovery token at all.
Device tokens must be unique.
"""
def _get_tokens():
"""Get all tokens as list of tokens of every device"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
return [token["token"] for token in tokens["tokens"]]
def _get_token_names():
"""Get all token names"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
return [t["name"] for t in tokens["tokens"]]
def _validate_token_name(name):
"""Token name must be an alphanumeric string and not empty.
Replace invalid characters with '_'
If token name exists, add a random number to the end of the name until it is unique.
"""
if not re.match("^[a-zA-Z0-9]*$", name):
name = re.sub("[^a-zA-Z0-9]", "_", name)
if name == "":
name = "Unknown device"
while name in _get_token_names():
name += str(secrets.randbelow(10))
return name
def is_token_valid(token):
"""Check if token is valid"""
if token in _get_tokens():
return True
return False
def is_token_name_exists(token_name):
"""Check if token name exists"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
return token_name in [t["name"] for t in tokens["tokens"]]
def is_token_name_pair_valid(token_name, token):
"""Check if token name and token pair exists"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
for t in tokens["tokens"]:
if t["name"] == token_name and t["token"] == token:
return True
return False
def get_token_name(token: str) -> typing.Optional[str]:
"""Return the name of the token provided"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
for t in tokens["tokens"]:
if t["token"] == token:
return t["name"]
return None
class BasicTokenInfo(BaseModel):
"""Token info"""
name: str
date: datetime
def get_tokens_info():
"""Get all tokens info without tokens themselves"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
return [
BasicTokenInfo(
name=t["name"],
date=parse_date(t["date"]),
)
for t in tokens["tokens"]
]
def _generate_token():
"""Generates new token and makes sure it is unique"""
token = secrets.token_urlsafe(32)
while token in _get_tokens():
token = secrets.token_urlsafe(32)
return token
def create_token(name):
"""Create new token"""
token = _generate_token()
name = _validate_token_name(name)
with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["tokens"].append(
{
"token": token,
"name": name,
"date": str(datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")),
}
)
return token
def delete_token(token_name):
"""Delete token"""
with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["tokens"] = [t for t in tokens["tokens"] if t["name"] != token_name]
def refresh_token(token: str) -> typing.Optional[str]:
"""Change the token field of the existing token"""
new_token = _generate_token()
with WriteUserData(UserDataFiles.TOKENS) as tokens:
for t in tokens["tokens"]:
if t["token"] == token:
t["token"] = new_token
return new_token
return None
def is_recovery_token_exists():
"""Check if recovery token exists"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
return "recovery_token" in tokens
def is_recovery_token_valid():
"""Check if recovery token is valid"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
if "recovery_token" not in tokens:
return False
recovery_token = tokens["recovery_token"]
if "uses_left" in recovery_token and recovery_token["uses_left"] is not None:
if recovery_token["uses_left"] <= 0:
return False
if "expiration" not in recovery_token or recovery_token["expiration"] is None:
return True
return datetime.now() < parse_date(recovery_token["expiration"])
def get_recovery_token_status():
"""Get recovery token date of creation, expiration and uses left"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
if "recovery_token" not in tokens:
return None
recovery_token = tokens["recovery_token"]
return {
"date": recovery_token["date"],
"expiration": recovery_token["expiration"]
if "expiration" in recovery_token
else None,
"uses_left": recovery_token["uses_left"]
if "uses_left" in recovery_token
else None,
}
def _get_recovery_token():
"""Get recovery token"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
if "recovery_token" not in tokens:
return None
return tokens["recovery_token"]["token"]
def generate_recovery_token(
expiration: typing.Optional[datetime], uses_left: typing.Optional[int]
) -> str:
"""Generate a 24 bytes recovery token and return a mneomnic word list.
Write a string representation of the recovery token to the tokens.json file.
"""
# expires must be a date or None
# uses_left must be an integer or None
if expiration is not None:
if not isinstance(expiration, datetime):
raise TypeError("expires must be a datetime object")
if uses_left is not None:
if not isinstance(uses_left, int):
raise TypeError("uses_left must be an integer")
if uses_left <= 0:
raise ValueError("uses_left must be greater than 0")
recovery_token = secrets.token_bytes(24)
recovery_token_str = recovery_token.hex()
with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["recovery_token"] = {
"token": recovery_token_str,
"date": str(datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")),
"expiration": expiration.strftime("%Y-%m-%dT%H:%M:%S.%f")
if expiration is not None
else None,
"uses_left": uses_left if uses_left is not None else None,
}
return Mnemonic(language="english").to_mnemonic(recovery_token)
def use_mnemonic_recoverery_token(mnemonic_phrase, name):
"""Use the recovery token by converting the mnemonic word list to a byte array.
If the recovery token if invalid itself, return None
If the binary representation of phrase not matches
the byte array of the recovery token, return None.
If the mnemonic phrase is valid then generate a device token and return it.
Substract 1 from uses_left if it exists.
mnemonic_phrase is a string representation of the mnemonic word list.
"""
if not is_recovery_token_valid():
return None
recovery_token_str = _get_recovery_token()
if recovery_token_str is None:
return None
recovery_token = bytes.fromhex(recovery_token_str)
if not Mnemonic(language="english").check(mnemonic_phrase):
return None
phrase_bytes = Mnemonic(language="english").to_entropy(mnemonic_phrase)
if phrase_bytes != recovery_token:
return None
token = _generate_token()
name = _validate_token_name(name)
with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["tokens"].append(
{
"token": token,
"name": name,
"date": str(datetime.now()),
}
)
if "recovery_token" in tokens:
if (
"uses_left" in tokens["recovery_token"]
and tokens["recovery_token"]["uses_left"] is not None
):
tokens["recovery_token"]["uses_left"] -= 1
return token
def get_new_device_auth_token() -> str:
"""Generate a new device auth token which is valid for 10 minutes
and return a mnemonic phrase representation
Write token to the new_device of the tokens.json file.
"""
token = secrets.token_bytes(16)
token_str = token.hex()
with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["new_device"] = {
"token": token_str,
"date": str(datetime.now()),
"expiration": str(datetime.now() + timedelta(minutes=10)),
}
return Mnemonic(language="english").to_mnemonic(token)
def _get_new_device_auth_token():
"""Get new device auth token. If it is expired, return None"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
if "new_device" not in tokens:
return None
new_device = tokens["new_device"]
if "expiration" not in new_device:
return None
expiration = parse_date(new_device["expiration"])
if datetime.now() > expiration:
return None
return new_device["token"]
def delete_new_device_auth_token():
"""Delete new device auth token"""
with WriteUserData(UserDataFiles.TOKENS) as tokens:
if "new_device" in tokens:
del tokens["new_device"]
def use_new_device_auth_token(mnemonic_phrase, name):
"""Use the new device auth token by converting the mnemonic string to a byte array.
If the mnemonic phrase is valid then generate a device token and return it.
New device auth token must be deleted.
"""
token_str = _get_new_device_auth_token()
if token_str is None:
return None
token = bytes.fromhex(token_str)
if not Mnemonic(language="english").check(mnemonic_phrase):
return None
phrase_bytes = Mnemonic(language="english").to_entropy(mnemonic_phrase)
if phrase_bytes != token:
return None
token = create_token(name)
with WriteUserData(UserDataFiles.TOKENS) as tokens:
if "new_device" in tokens:
del tokens["new_device"]
return token

View File

@ -2,8 +2,14 @@
# pylint: disable=unused-argument
# pylint: disable=missing-function-docstring
import datetime
import pytest
from mnemonic import Mnemonic
from selfprivacy_api.repositories.tokens.json_tokens_repository import (
JsonTokensRepository,
)
from selfprivacy_api.models.tokens.token import Token
from tests.common import generate_api_query, read_json, write_json
TOKENS_FILE_CONTETS = {
@ -30,6 +36,11 @@ devices {
"""
@pytest.fixture
def token_repo():
return JsonTokensRepository()
def test_graphql_tokens_info(authorized_client, tokens_file):
response = authorized_client.post(
"/graphql",
@ -170,7 +181,7 @@ def test_graphql_refresh_token_unauthorized(client, tokens_file):
assert response.json()["data"] is None
def test_graphql_refresh_token(authorized_client, tokens_file):
def test_graphql_refresh_token(authorized_client, tokens_file, token_repo):
response = authorized_client.post(
"/graphql",
json={"query": REFRESH_TOKEN_MUTATION},
@ -180,11 +191,12 @@ def test_graphql_refresh_token(authorized_client, tokens_file):
assert response.json()["data"]["refreshDeviceApiToken"]["success"] is True
assert response.json()["data"]["refreshDeviceApiToken"]["message"] is not None
assert response.json()["data"]["refreshDeviceApiToken"]["code"] == 200
assert read_json(tokens_file)["tokens"][0] == {
"token": response.json()["data"]["refreshDeviceApiToken"]["token"],
"name": "test_token",
"date": "2022-01-14 08:31:10.789314",
}
token = token_repo.get_token_by_name("test_token")
assert token == Token(
token=response.json()["data"]["refreshDeviceApiToken"]["token"],
device_name="test_token",
created_at=datetime.datetime(2022, 1, 14, 8, 31, 10, 789314),
)
NEW_DEVICE_KEY_MUTATION = """

View File

@ -2,7 +2,8 @@
# pylint: disable=unused-argument
# pylint: disable=missing-function-docstring
from datetime import datetime
from datetime import datetime, timedelta
from mnemonic import Mnemonic
import pytest
@ -32,6 +33,10 @@ ORIGINAL_DEVICE_NAMES = [
]
def mnemonic_from_hex(hexkey):
return Mnemonic(language="english").to_mnemonic(bytes.fromhex(hexkey))
@pytest.fixture
def empty_keys(mocker, datadir):
mocker.patch("selfprivacy_api.utils.TOKENS_FILE", new=datadir / "empty_keys.json")
@ -132,21 +137,6 @@ def mock_recovery_key_generate(mocker):
return mock
@pytest.fixture
def mock_recovery_key_generate_for_mnemonic(mocker):
mock = mocker.patch(
"selfprivacy_api.models.tokens.recovery_key.RecoveryKey.generate",
autospec=True,
return_value=RecoveryKey(
key="ed653e4b8b042b841d285fa7a682fa09e925ddb2d8906f54",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
expires_at=None,
uses_left=1,
),
)
return mock
@pytest.fixture
def empty_json_repo(empty_keys):
repo = JsonTokensRepository()
@ -221,6 +211,28 @@ def test_get_token_by_non_existent_name(some_tokens_repo):
assert repo.get_token_by_name(token_name="badname") is None
def test_is_token_valid(some_tokens_repo):
repo = some_tokens_repo
token = repo.get_tokens()[0]
assert repo.is_token_valid(token.token)
assert not repo.is_token_valid("gibberish")
def test_is_token_name_pair_valid(some_tokens_repo):
repo = some_tokens_repo
token = repo.get_tokens()[0]
assert repo.is_token_name_pair_valid(token.device_name, token.token)
assert not repo.is_token_name_pair_valid(token.device_name, "gibberish")
assert not repo.is_token_name_pair_valid("gibberish", token.token)
def test_is_token_name_exists(some_tokens_repo):
repo = some_tokens_repo
token = repo.get_tokens()[0]
assert repo.is_token_name_exists(token.device_name)
assert not repo.is_token_name_exists("gibberish")
def test_get_tokens(some_tokens_repo):
repo = some_tokens_repo
tokenstrings = []
@ -249,6 +261,17 @@ def test_create_token(empty_repo, mock_token_generate):
]
def test_create_token_existing(some_tokens_repo):
repo = some_tokens_repo
old_token = repo.get_tokens()[0]
new_token = repo.create_token(device_name=old_token.device_name)
assert new_token.device_name != old_token.device_name
assert old_token in repo.get_tokens()
assert new_token in repo.get_tokens()
def test_delete_token(some_tokens_repo):
repo = some_tokens_repo
original_tokens = repo.get_tokens()
@ -280,15 +303,17 @@ def test_delete_not_found_token(some_tokens_repo):
assert token in new_tokens
def test_refresh_token(some_tokens_repo, mock_token_generate):
def test_refresh_token(some_tokens_repo):
repo = some_tokens_repo
input_token = some_tokens_repo.get_tokens()[0]
assert repo.refresh_token(input_token) == Token(
token="ZuLNKtnxDeq6w2dpOJhbB3iat_sJLPTPl_rN5uc5MvM",
device_name="IamNewDevice",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
)
output_token = repo.refresh_token(input_token)
assert output_token.token != input_token.token
assert output_token.device_name == input_token.device_name
assert output_token.created_at == input_token.created_at
assert output_token in repo.get_tokens()
def test_refresh_not_found_token(some_tokens_repo, mock_token_generate):
@ -355,6 +380,23 @@ def test_use_mnemonic_not_valid_recovery_key(
)
def test_use_mnemonic_expired_recovery_key(
some_tokens_repo,
):
repo = some_tokens_repo
expiration = datetime.now() - timedelta(minutes=5)
assert repo.create_recovery_key(uses_left=2, expiration=expiration) is not None
recovery_key = repo.get_recovery_key()
assert recovery_key.expires_at == expiration
assert not repo.is_recovery_key_valid()
with pytest.raises(RecoveryKeyNotFound):
token = repo.use_mnemonic_recovery_key(
mnemonic_phrase=mnemonic_from_hex(recovery_key.key),
device_name="newdevice",
)
def test_use_mnemonic_not_mnemonic_recovery_key(some_tokens_repo):
repo = some_tokens_repo
assert repo.create_recovery_key(uses_left=1, expiration=None) is not None
@ -397,46 +439,38 @@ def test_use_not_found_mnemonic_recovery_key(some_tokens_repo):
)
def test_use_mnemonic_recovery_key_when_empty(empty_repo):
repo = empty_repo
with pytest.raises(RecoveryKeyNotFound):
assert (
repo.use_mnemonic_recovery_key(
mnemonic_phrase="captain ribbon toddler settle symbol minute step broccoli bless universe divide bulb",
device_name="primary_token",
)
is None
)
@pytest.fixture(params=["recovery_uses_1", "recovery_eternal"])
def recovery_key_uses_left(request):
if request.param == "recovery_uses_1":
return 1
if request.param == "recovery_eternal":
return None
# agnostic test mixed with an implementation test
def test_use_mnemonic_recovery_key(
some_tokens_repo, mock_recovery_key_generate_for_mnemonic, mock_generate_token
):
def test_use_mnemonic_recovery_key(some_tokens_repo, recovery_key_uses_left):
repo = some_tokens_repo
assert repo.create_recovery_key(uses_left=1, expiration=None) is not None
test_token = Token(
token="ur71mC4aiI6FIYAN--cTL-38rPHS5D6NuB1bgN_qKF4",
device_name="newdevice",
created_at=datetime(2022, 11, 14, 6, 6, 32, 777123),
)
assert (
repo.use_mnemonic_recovery_key(
mnemonic_phrase="uniform clarify napkin bid dress search input armor police cross salon because myself uphold slice bamboo hungry park",
device_name="newdevice",
)
== test_token
repo.create_recovery_key(uses_left=recovery_key_uses_left, expiration=None)
is not None
)
assert repo.is_recovery_key_valid()
recovery_key = repo.get_recovery_key()
token = repo.use_mnemonic_recovery_key(
mnemonic_phrase=mnemonic_from_hex(recovery_key.key),
device_name="newdevice",
)
assert test_token in repo.get_tokens()
assert token.device_name == "newdevice"
assert token in repo.get_tokens()
new_uses = None
if recovery_key_uses_left is not None:
new_uses = recovery_key_uses_left - 1
assert repo.get_recovery_key() == RecoveryKey(
key="ed653e4b8b042b841d285fa7a682fa09e925ddb2d8906f54",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
key=recovery_key.key,
created_at=recovery_key.created_at,
expires_at=None,
uses_left=0,
uses_left=new_uses,
)
@ -497,15 +531,16 @@ def test_use_not_exists_mnemonic_new_device_key(
)
def test_use_mnemonic_new_device_key(
empty_repo, mock_new_device_key_generate_for_mnemonic
):
def test_use_mnemonic_new_device_key(empty_repo):
repo = empty_repo
assert repo.get_new_device_key() is not None
key = repo.get_new_device_key()
assert key is not None
mnemonic_phrase = mnemonic_from_hex(key.key)
new_token = repo.use_mnemonic_new_device_key(
device_name="imnew",
mnemonic_phrase="captain ribbon toddler settle symbol minute step broccoli bless universe divide bulb",
mnemonic_phrase=mnemonic_phrase,
)
assert new_token.device_name == "imnew"
@ -516,12 +551,32 @@ def test_use_mnemonic_new_device_key(
assert (
repo.use_mnemonic_new_device_key(
device_name="imnew",
mnemonic_phrase="captain ribbon toddler settle symbol minute step broccoli bless universe divide bulb",
mnemonic_phrase=mnemonic_phrase,
)
is None
)
def test_use_mnemonic_expired_new_device_key(
some_tokens_repo,
):
repo = some_tokens_repo
expiration = datetime.now() - timedelta(minutes=5)
key = repo.get_new_device_key()
assert key is not None
assert key.expires_at is not None
key.expires_at = expiration
assert not key.is_valid()
repo._store_new_device_key(key)
with pytest.raises(NewDeviceKeyNotFound):
token = repo.use_mnemonic_new_device_key(
mnemonic_phrase=mnemonic_from_hex(key.key),
device_name="imnew",
)
def test_use_mnemonic_new_device_key_when_empty(empty_repo):
repo = empty_repo

18
tests/test_models.py Normal file
View File

@ -0,0 +1,18 @@
import pytest
from datetime import datetime, timedelta
from selfprivacy_api.models.tokens.recovery_key import RecoveryKey
from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey
def test_recovery_key_expired():
expiration = datetime.now() - timedelta(minutes=5)
key = RecoveryKey.generate(expiration=expiration, uses_left=2)
assert not key.is_valid()
def test_new_device_key_expired():
expiration = datetime.now() - timedelta(minutes=5)
key = NewDeviceKey.generate()
key.expires_at = expiration
assert not key.is_valid()

View File

@ -5,6 +5,12 @@ import datetime
import pytest
from mnemonic import Mnemonic
from selfprivacy_api.repositories.tokens.json_tokens_repository import (
JsonTokensRepository,
)
TOKEN_REPO = JsonTokensRepository()
from tests.common import read_json, write_json
@ -97,7 +103,7 @@ def test_refresh_token(authorized_client, tokens_file):
response = authorized_client.post("/auth/tokens")
assert response.status_code == 200
new_token = response.json()["token"]
assert read_json(tokens_file)["tokens"][0]["token"] == new_token
assert TOKEN_REPO.get_token_by_token_string(new_token) is not None
# new device