refactor(tokens-repository): move use_mnemonic_recovery_key() to abstract class

pull/26/head
Houkime 2022-12-12 14:22:36 +00:00
parent 671203e990
commit 772c0dfc64
2 changed files with 34 additions and 28 deletions

View File

@ -1,9 +1,14 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from mnemonic import Mnemonic
from selfprivacy_api.models.tokens.token import Token from selfprivacy_api.models.tokens.token import Token
from selfprivacy_api.repositories.tokens.exceptions import TokenNotFound from selfprivacy_api.repositories.tokens.exceptions import (
TokenNotFound,
InvalidMnemonic,
RecoveryKeyNotFound,
)
from selfprivacy_api.models.tokens.recovery_key import RecoveryKey from selfprivacy_api.models.tokens.recovery_key import RecoveryKey
from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey
@ -87,11 +92,22 @@ class AbstractTokensRepository(ABC):
) -> RecoveryKey: ) -> RecoveryKey:
"""Create the recovery key""" """Create the recovery key"""
@abstractmethod
def use_mnemonic_recovery_key( def use_mnemonic_recovery_key(
self, mnemonic_phrase: str, device_name: str self, mnemonic_phrase: str, device_name: str
) -> Token: ) -> Token:
"""Use the mnemonic recovery key and create a new token with the given name""" """Use the mnemonic recovery key and create a new token with the given name"""
if not self.is_recovery_key_valid():
raise RecoveryKeyNotFound("Recovery key not found")
recovery_hex_key = self.get_recovery_key().key
if not self._assert_mnemonic(recovery_hex_key, mnemonic_phrase):
raise RecoveryKeyNotFound("Recovery key not found")
new_token = self.create_token(device_name=device_name)
self._decrement_recovery_token()
return new_token
def is_recovery_key_valid(self) -> bool: def is_recovery_key_valid(self) -> bool:
"""Check if the recovery key is valid""" """Check if the recovery key is valid"""
@ -117,3 +133,18 @@ class AbstractTokensRepository(ABC):
@abstractmethod @abstractmethod
def _store_token(self, new_token: Token): def _store_token(self, new_token: Token):
"""Store a token directly""" """Store a token directly"""
@abstractmethod
def _decrement_recovery_token(self):
"""Decrement recovery key use count by one"""
# TODO: find a proper place for it
def _assert_mnemonic(self, hex_key: str, mnemonic_phrase: str):
"""Return true if hex string matches the phrase, false otherwise
Raise an InvalidMnemonic error if not mnemonic"""
recovery_token = bytes.fromhex(hex_key)
if not Mnemonic(language="english").check(mnemonic_phrase):
raise InvalidMnemonic("Phrase is not mnemonic!")
phrase_bytes = Mnemonic(language="english").to_entropy(mnemonic_phrase)
return phrase_bytes == recovery_token

View File

@ -11,7 +11,6 @@ from selfprivacy_api.models.tokens.recovery_key import RecoveryKey
from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey
from selfprivacy_api.repositories.tokens.exceptions import ( from selfprivacy_api.repositories.tokens.exceptions import (
TokenNotFound, TokenNotFound,
RecoveryKeyNotFound,
InvalidMnemonic, InvalidMnemonic,
NewDeviceKeyNotFound, NewDeviceKeyNotFound,
) )
@ -98,36 +97,12 @@ class JsonTokensRepository(AbstractTokensRepository):
return recovery_key return recovery_key
def use_mnemonic_recovery_key(
self, mnemonic_phrase: str, device_name: str
) -> Token:
"""Use the mnemonic recovery key and create a new token with the given name"""
if not self.is_recovery_key_valid():
raise RecoveryKeyNotFound("Recovery key not found")
recovery_hex_key = self.get_recovery_key().key
if not self._assert_mnemonic(recovery_hex_key, mnemonic_phrase):
raise RecoveryKeyNotFound("Recovery key not found")
new_token = self.create_token(device_name=device_name)
self._decrement_recovery_token()
return new_token
def _decrement_recovery_token(self): def _decrement_recovery_token(self):
"""Decrement recovery key use count by one"""
if self.is_recovery_key_valid(): if self.is_recovery_key_valid():
with WriteUserData(UserDataFiles.TOKENS) as tokens: with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["recovery_token"]["uses_left"] -= 1 tokens["recovery_token"]["uses_left"] -= 1
def _assert_mnemonic(self, hex_key: str, mnemonic_phrase: str):
recovery_token = bytes.fromhex(hex_key)
if not Mnemonic(language="english").check(mnemonic_phrase):
raise InvalidMnemonic("Phrase is not mnemonic!")
phrase_bytes = Mnemonic(language="english").to_entropy(mnemonic_phrase)
return phrase_bytes == recovery_token
def get_new_device_key(self) -> NewDeviceKey: def get_new_device_key(self) -> NewDeviceKey:
"""Creates and returns the new device key""" """Creates and returns the new device key"""
new_device_key = NewDeviceKey.generate() new_device_key = NewDeviceKey.generate()