refactor(backups): linting

pull/35/head
Inex Code 2023-07-20 18:24:26 +03:00
parent 2df448a4a9
commit 0245d629fd
12 changed files with 196 additions and 112 deletions

View File

@ -1,3 +1,6 @@
"""
This module contains the controller class for backups.
"""
from datetime import datetime, timedelta
from os import statvfs
from typing import List, Optional
@ -42,8 +45,12 @@ DEFAULT_JSON_PROVIDER = {
class NotDeadError(AssertionError):
"""
This error is raised when we try to back up a service that is not dead yet.
"""
def __init__(self, service: Service):
self.service_name = service.get_id()
super().__init__()
def __str__(self):
return f"""
@ -61,6 +68,9 @@ class Backups:
@staticmethod
def provider() -> AbstractBackupProvider:
"""
Returns the current backup storage provider.
"""
return Backups._lookup_provider()
@staticmethod
@ -71,6 +81,13 @@ class Backups:
location: str,
repo_id: str = "",
) -> None:
"""
Sets the new configuration of the backup storage provider.
In case of `BackupProviderEnum.BACKBLAZE`, the `login` is the key ID,
the `key` is the key itself, the `location` is the bucket name, and
the `repo_id` is the bucket ID.
"""
provider: AbstractBackupProvider = Backups._construct_provider(
kind,
login,
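For concreteness, a configuration call for Backblaze might look like the sketch below. The first three parameter names are not fully visible in this hunk and are assumed to be `kind`, `login`, and `key`; every credential value shown is a placeholder.

    from selfprivacy_api.backup import Backups
    from selfprivacy_api.graphql.queries.providers import (
        BackupProvider as BackupProviderEnum,
    )

    Backups.set_provider(
        kind=BackupProviderEnum.BACKBLAZE,
        login="0012345abcde0000000000001",     # key ID (placeholder)
        key="K001xxxxxxxxxxxxxxxxxxxxxxxxxx",  # the key itself (placeholder)
        location="my-backup-bucket",           # bucket name (placeholder)
        repo_id="b1a2c3d4e5f6",                # bucket ID (placeholder)
    )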
@ -82,6 +99,9 @@ class Backups:
@staticmethod
def reset(reset_json=True) -> None:
"""
Deletes all the data about the backup storage provider.
"""
Storage.reset()
if reset_json:
try:
@ -183,11 +203,19 @@ class Backups:
@staticmethod
def init_repo() -> None:
"""
Initializes the backup repository. This is required once per repo.
"""
Backups.provider().backupper.init()
Storage.mark_as_init()
@staticmethod
def is_initted() -> bool:
"""
Returns whether the backup repository is initialized or not.
If it is not initialized, we cannot back up and probably should
call `init_repo` first.
"""
if Storage.has_init_mark():
return True
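The intended call order follows from this docstring; a minimal sketch of a hypothetical caller:

    from selfprivacy_api.backup import Backups

    if not Backups.is_initted():
        Backups.init_repo()  # required once per repository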
@ -219,9 +247,9 @@ class Backups:
)
Backups._store_last_snapshot(tag, snapshot)
service.post_restore()
- except Exception as e:
+ except Exception as error:
Jobs.update(job, status=JobStatus.ERROR)
- raise e
+ raise error
Jobs.update(job, status=JobStatus.FINISHED)
return snapshot
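The pattern in this hunk — mark the job as ERROR and re-raise on failure, otherwise mark it FINISHED — can be distilled into a small helper. This is an illustrative sketch, not repository code:

    from typing import Callable, TypeVar

    T = TypeVar("T")

    def run_and_track(work: Callable[[], T], set_status: Callable[[str], None]) -> T:
        """Run `work`, mirror its outcome into a job status, and re-raise failures."""
        try:
            result = work()
        except Exception:
            set_status("ERROR")
            raise  # a bare raise keeps the original traceback intact
        set_status("FINISHED")
        return result

    print(run_and_track(lambda: 2 + 2, print))  # prints FINISHED, then 4

Note that `raise error`, as used in the diff, behaves essentially the same as a bare `raise` here; the rename from `e` to `error` is purely to satisfy the linter's naming rules.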
@ -252,16 +280,17 @@ class Backups:
snapshot.id,
verify=False,
)
- except Exception as e:
+ except Exception as error:
Backups._restore_service_from_snapshot(
service, failsafe_snapshot.id, verify=False
)
- raise e
+ raise error
@staticmethod
def restore_snapshot(
snapshot: Snapshot, strategy=RestoreStrategy.DOWNLOAD_VERIFY_OVERWRITE
) -> None:
"""Restores a snapshot to its original service using the given strategy"""
service = get_service_by_id(snapshot.service_name)
if service is None:
raise ValueError(
@ -283,9 +312,9 @@ class Backups:
service.post_restore()
- except Exception as e:
+ except Exception as error:
Jobs.update(job, status=JobStatus.ERROR)
- raise e
+ raise error
Jobs.update(job, status=JobStatus.FINISHED)
@ -338,6 +367,7 @@ class Backups:
@staticmethod
def get_snapshots(service: Service) -> List[Snapshot]:
"""Returns all snapshots for a given service"""
snapshots = Backups.get_all_snapshots()
service_id = service.get_id()
return list(
@ -349,8 +379,9 @@ class Backups:
@staticmethod
def get_all_snapshots() -> List[Snapshot]:
"""Returns all snapshots"""
cached_snapshots = Storage.get_cached_snapshots()
- if cached_snapshots != []:
+ if cached_snapshots:
return cached_snapshots
# TODO: the oldest snapshots will get expired faster than the new ones.
# How to detect that the end is missing?
@ -359,24 +390,32 @@ class Backups:
return Storage.get_cached_snapshots()
@staticmethod
- def get_snapshot_by_id(id: str) -> Optional[Snapshot]:
- snap = Storage.get_cached_snapshot_by_id(id)
+ def get_snapshot_by_id(snapshot_id: str) -> Optional[Snapshot]:
+ """Returns a backup snapshot by its id"""
+ snap = Storage.get_cached_snapshot_by_id(snapshot_id)
if snap is not None:
return snap
# Possibly our cache entry got invalidated, let's try one more time
Backups.force_snapshot_cache_reload()
- snap = Storage.get_cached_snapshot_by_id(id)
+ snap = Storage.get_cached_snapshot_by_id(snapshot_id)
return snap
@staticmethod
def forget_snapshot(snapshot: Snapshot) -> None:
"""Deletes a snapshot from the storage"""
Backups.provider().backupper.forget_snapshot(snapshot.id)
Storage.delete_cached_snapshot(snapshot)
@staticmethod
def force_snapshot_cache_reload() -> None:
"""
Forces a reload of the snapshot cache.
This may be an expensive operation, so use it wisely.
The user pays for the API calls.
"""
upstream_snapshots = Backups.provider().backupper.get_snapshots()
Storage.invalidate_snapshot_storage()
for snapshot in upstream_snapshots:
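`get_snapshot_by_id` and `force_snapshot_cache_reload` together form a cache-aside lookup: check the cache, and on a miss rebuild it once from upstream before concluding the item is gone. A generic sketch of the idea (hypothetical helper, not repository code):

    from typing import Callable, Optional, TypeVar

    T = TypeVar("T")

    def cached_lookup(
        get_cached: Callable[[str], Optional[T]],
        reload_cache: Callable[[], None],
        key: str,
    ) -> Optional[T]:
        """Cache-aside lookup: on a miss, rebuild the cache once and retry."""
        value = get_cached(key)
        if value is not None:
            return value
        reload_cache()          # expensive: hits the provider API
        return get_cached(key)  # None here means the item truly does not exist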
@ -384,6 +423,7 @@ class Backups:
@staticmethod
def snapshot_restored_size(snapshot_id: str) -> int:
"""Returns the size of the snapshot"""
return Backups.provider().backupper.restored_size(
snapshot_id,
)
@ -434,6 +474,7 @@ class Backups:
@staticmethod
def services_to_back_up(time: datetime) -> List[Service]:
"""Returns a list of services that should be backed up at a given time"""
return [
service
for service in get_all_services()
@ -447,6 +488,7 @@ class Backups:
@staticmethod
def is_time_to_backup_service(service: Service, time: datetime):
"""Returns True if it is time to back up a service"""
period = Backups.autobackup_period_minutes()
service_id = service.get_id()
if not service.can_be_backed_up():
@ -467,6 +509,10 @@ class Backups:
@staticmethod
def space_usable_for_service(service: Service) -> int:
"""
Returns the amount of space available on the volume the given
service is located on.
"""
folders = service.get_folders()
if folders == []:
raise ValueError("unallocated service", service.get_id())
@ -478,6 +524,8 @@ class Backups:
@staticmethod
def set_localfile_repo(file_path: str):
"""Used by tests to set a local folder as a backup repo"""
# pylint: disable-next=invalid-name
ProviderClass = get_provider(BackupProviderEnum.FILE)
provider = ProviderClass(
login="",
@ -490,10 +538,7 @@ class Backups:
@staticmethod
def assert_dead(service: Service):
"""
- If we backup the service that is failing to restore it to the previous snapshot,
- its status can be FAILED.
- And obviously restoring a failed service is the main route
+ Checks if a service is dead and can be safely restored from a snapshot.
"""
if service.get_status() not in [
ServiceStatus.INACTIVE,

View File

@ -5,19 +5,25 @@ from selfprivacy_api.models.backup.snapshot import Snapshot
class AbstractBackupper(ABC):
"""Abstract class for backuppers"""
# flake8: noqa: B027
def __init__(self) -> None:
pass
@abstractmethod
def is_initted(self) -> bool:
"""Returns true if the repository is initted"""
raise NotImplementedError
@abstractmethod
def set_creds(self, account: str, key: str, repo: str) -> None:
"""Set the credentials for the backupper"""
raise NotImplementedError
@abstractmethod
- def start_backup(self, folders: List[str], repo_name: str) -> Snapshot:
+ def start_backup(self, folders: List[str], tag: str) -> Snapshot:
"""Start a backup of the given folders"""
raise NotImplementedError
@abstractmethod
@ -27,6 +33,7 @@ class AbstractBackupper(ABC):
@abstractmethod
def init(self) -> None:
"""Initialize the repository"""
raise NotImplementedError
@abstractmethod
@ -41,8 +48,10 @@ class AbstractBackupper(ABC):
@abstractmethod
def restored_size(self, snapshot_id: str) -> int:
"""Get the size of the restored snapshot"""
raise NotImplementedError
@abstractmethod
def forget_snapshot(self, snapshot_id) -> None:
"""Forget a snapshot"""
raise NotImplementedError
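As a reminder of the ABC mechanics relied on here: `@abstractmethod` makes the base class uninstantiable until a subclass overrides every abstract method, and the `raise NotImplementedError` bodies are a belt-and-suspenders fallback. (The `# flake8: noqa: B027` comment suppresses flake8-bugbear's warning about the intentionally empty, non-abstract `__init__`.) A toy example, unrelated to the repository's classes:

    from abc import ABC, abstractmethod

    class Greeter(ABC):
        @abstractmethod
        def greet(self) -> str:
            raise NotImplementedError

    class EnglishGreeter(Greeter):
        def greet(self) -> str:
            return "hello"

    print(EnglishGreeter().greet())  # fine
    # Greeter()  # TypeError: can't instantiate abstract class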

View File

@ -5,13 +5,15 @@ from selfprivacy_api.backup.backuppers import AbstractBackupper
class NoneBackupper(AbstractBackupper):
"""A backupper that does nothing"""
def is_initted(self, repo_name: str = "") -> bool:
return False
def set_creds(self, account: str, key: str, repo: str):
pass
- def start_backup(self, folders: List[str], repo_name: str):
+ def start_backup(self, folders: List[str], tag: str):
raise NotImplementedError
def get_snapshots(self) -> List[Snapshot]:
@ -21,7 +23,7 @@ class NoneBackupper(AbstractBackupper):
def init(self):
raise NotImplementedError
- def restore_from_backup(self, snapshot_id: str, folders: List[str]):
+ def restore_from_backup(self, snapshot_id: str, folders: List[str], verify=True):
"""Restore a target folder using a snapshot"""
raise NotImplementedError

View File

@ -21,13 +21,14 @@ from selfprivacy_api.backup.local_secret import LocalBackupSecret
class ResticBackupper(AbstractBackupper):
- def __init__(self, login_flag: str, key_flag: str, type: str) -> None:
+ def __init__(self, login_flag: str, key_flag: str, storage_type: str) -> None:
self.login_flag = login_flag
self.key_flag = key_flag
- self.type = type
+ self.storage_type = storage_type
self.account = ""
self.key = ""
self.repo = ""
super().__init__()
def set_creds(self, account: str, key: str, repo: str) -> None:
self.account = account
@ -37,7 +38,7 @@ class ResticBackupper(AbstractBackupper):
def restic_repo(self) -> str:
# https://restic.readthedocs.io/en/latest/030_preparing_a_new_repo.html#other-services-via-rclone
# https://forum.rclone.org/t/can-rclone-be-run-solely-with-command-line-options-no-config-no-env-vars/6314/5
return f"rclone:{self.type}{self.repo}"
return f"rclone:{self.storage_type}{self.repo}"
def rclone_args(self):
return "rclone.args=serve restic --stdio " + self.backend_rclone_args()
@ -72,12 +73,12 @@ class ResticBackupper(AbstractBackupper):
tag,
]
)
- if args != []:
+ if args:
command.extend(ResticBackupper.__flatten_list(args))
return command
- def mount_repo(self, dir):
- mount_command = self.restic_command("mount", dir)
+ def mount_repo(self, mount_directory):
+ mount_command = self.restic_command("mount", mount_directory)
mount_command.insert(0, "nohup")
handle = subprocess.Popen(
mount_command,
@ -85,28 +86,28 @@ class ResticBackupper(AbstractBackupper):
shell=False,
)
sleep(2)
if "ids" not in listdir(dir):
raise IOError("failed to mount dir ", dir)
if "ids" not in listdir(mount_directory):
raise IOError("failed to mount dir ", mount_directory)
return handle
- def unmount_repo(self, dir):
- mount_command = ["umount", "-l", dir]
+ def unmount_repo(self, mount_directory):
+ mount_command = ["umount", "-l", mount_directory]
with subprocess.Popen(
mount_command, stdout=subprocess.PIPE, shell=False
) as handle:
output = handle.communicate()[0].decode("utf-8")
# TODO: check for exit code?
if "error" in output.lower():
return IOError("failed to unmount dir ", dir, ": ", output)
return IOError("failed to unmount dir ", mount_directory, ": ", output)
- if not listdir(dir) == []:
- return IOError("failed to unmount dir ", dir)
+ if not listdir(mount_directory) == []:
+ return IOError("failed to unmount dir ", mount_directory)
@staticmethod
- def __flatten_list(list):
+ def __flatten_list(list_to_flatten):
"""string-aware list flattener"""
result = []
- for item in list:
+ for item in list_to_flatten:
if isinstance(item, Iterable) and not isinstance(item, str):
result.extend(ResticBackupper.__flatten_list(item))
continue
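The guard in `__flatten_list` exists because Python strings are themselves iterable; without the `isinstance(item, str)` check, a naive flattener would explode every string into single characters. Restated as a standalone function:

    from collections.abc import Iterable
    from typing import Any

    def flatten(items: Iterable[Any]) -> list[Any]:
        """String-aware flattener: recurse into iterables, but keep strings whole."""
        result: list[Any] = []
        for item in items:
            if isinstance(item, Iterable) and not isinstance(item, str):
                result.extend(flatten(item))
            else:
                result.append(item)
        return result

    assert flatten(["a", ["b", ["c"]], "de"]) == ["a", "b", "c", "de"]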
@ -147,8 +148,8 @@ class ResticBackupper(AbstractBackupper):
messages,
tag,
)
- except ValueError as e:
- raise ValueError("Could not create a snapshot: ", messages) from e
+ except ValueError as error:
+ raise ValueError("Could not create a snapshot: ", messages) from error
@staticmethod
def _snapshot_from_backup_messages(messages, repo_name) -> Snapshot:
@ -231,8 +232,8 @@ class ResticBackupper(AbstractBackupper):
try:
parsed_output = ResticBackupper.parse_json_output(output)
return parsed_output["total_size"]
- except ValueError as e:
- raise ValueError("cannot restore a snapshot: " + output) from e
+ except ValueError as error:
+ raise ValueError("cannot restore a snapshot: " + output) from error
def restore_from_backup(
self,
@ -246,13 +247,13 @@ class ResticBackupper(AbstractBackupper):
if folders is None or folders == []:
raise ValueError("cannot restore without knowing where to!")
- with tempfile.TemporaryDirectory() as dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
if verify:
- self._raw_verified_restore(snapshot_id, target=dir)
- snapshot_root = dir
+ self._raw_verified_restore(snapshot_id, target=temp_dir)
+ snapshot_root = temp_dir
else: # attempting inplace restore via mount + sync
- self.mount_repo(dir)
- snapshot_root = join(dir, "ids", snapshot_id)
+ self.mount_repo(temp_dir)
+ snapshot_root = join(temp_dir, "ids", snapshot_id)
assert snapshot_root is not None
for folder in folders:
@ -263,7 +264,7 @@ class ResticBackupper(AbstractBackupper):
sync(src, dst)
if not verify:
- self.unmount_repo(dir)
+ self.unmount_repo(temp_dir)
def _raw_verified_restore(self, snapshot_id, target="/"):
"""barebones restic restore"""
@ -355,8 +356,8 @@ class ResticBackupper(AbstractBackupper):
raise ValueError("No repository! : " + output)
try:
return ResticBackupper.parse_json_output(output)
- except ValueError as e:
- raise ValueError("Cannot load snapshots: ") from e
+ except ValueError as error:
+ raise ValueError("Cannot load snapshots: ") from error
def get_snapshots(self) -> List[Snapshot]:
"""Get all snapshots from the repo"""
@ -383,10 +384,10 @@ class ResticBackupper(AbstractBackupper):
if len(json_messages) == 1:
try:
return json.loads(truncated_output)
- except JSONDecodeError as e:
+ except JSONDecodeError as error:
raise ValueError(
"There is no json in the restic output : " + output
- ) from e
+ ) from error
result_array = []
for message in json_messages:

View File

@ -1,4 +1,4 @@
- from .provider import AbstractBackupProvider
+ from selfprivacy_api.backup.providers.provider import AbstractBackupProvider
from selfprivacy_api.backup.backuppers.none_backupper import NoneBackupper
from selfprivacy_api.graphql.queries.providers import (
BackupProvider as BackupProviderEnum,

View File

@ -1,3 +1,6 @@
"""
Module for storing backup related data in redis.
"""
from typing import List, Optional
from datetime import datetime
@ -10,10 +13,6 @@ from selfprivacy_api.utils.redis_model_storage import (
hash_as_model,
)
from selfprivacy_api.services.service import Service
from selfprivacy_api.services import get_service_by_id
from selfprivacy_api.backup.providers.provider import AbstractBackupProvider
from selfprivacy_api.backup.providers import get_kind
@ -32,8 +31,10 @@ redis = RedisPool().get_connection()
class Storage:
"""Static class for storing backup related data in redis"""
@staticmethod
- def reset():
+ def reset() -> None:
"""Deletes all backup related data from redis"""
redis.delete(REDIS_PROVIDER_KEY)
redis.delete(REDIS_AUTOBACKUP_PERIOD_KEY)
@ -48,20 +49,22 @@ class Storage:
redis.delete(key)
@staticmethod
- def invalidate_snapshot_storage():
+ def invalidate_snapshot_storage() -> None:
"""Deletes all cached snapshots from redis"""
for key in redis.keys(REDIS_SNAPSHOTS_PREFIX + "*"):
redis.delete(key)
@staticmethod
- def __last_backup_key(service_id):
+ def __last_backup_key(service_id: str) -> str:
return REDIS_LAST_BACKUP_PREFIX + service_id
@staticmethod
- def __snapshot_key(snapshot: Snapshot):
+ def __snapshot_key(snapshot: Snapshot) -> str:
return REDIS_SNAPSHOTS_PREFIX + snapshot.id
@staticmethod
def get_last_backup_time(service_id: str) -> Optional[datetime]:
"""Returns last backup time for a service or None if it was never backed up"""
key = Storage.__last_backup_key(service_id)
if not redis.exists(key):
return None
@ -72,7 +75,8 @@ class Storage:
return snapshot.created_at
@staticmethod
- def store_last_timestamp(service_id: str, snapshot: Snapshot):
+ def store_last_timestamp(service_id: str, snapshot: Snapshot) -> None:
"""Stores last backup time for a service"""
store_model_as_hash(
redis,
Storage.__last_backup_key(service_id),
@ -80,18 +84,21 @@ class Storage:
)
@staticmethod
- def cache_snapshot(snapshot: Snapshot):
+ def cache_snapshot(snapshot: Snapshot) -> None:
"""Stores snapshot metadata in redis for caching purposes"""
snapshot_key = Storage.__snapshot_key(snapshot)
store_model_as_hash(redis, snapshot_key, snapshot)
redis.expire(snapshot_key, REDIS_SNAPSHOT_CACHE_EXPIRE_SECONDS)
@staticmethod
- def delete_cached_snapshot(snapshot: Snapshot):
+ def delete_cached_snapshot(snapshot: Snapshot) -> None:
"""Deletes snapshot metadata from redis"""
snapshot_key = Storage.__snapshot_key(snapshot)
redis.delete(snapshot_key)
@staticmethod
def get_cached_snapshot_by_id(snapshot_id: str) -> Optional[Snapshot]:
"""Returns cached snapshot by id or None if it doesn't exist"""
key = REDIS_SNAPSHOTS_PREFIX + snapshot_id
if not redis.exists(key):
return None
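`cache_snapshot` stores each model as a redis hash and then lets redis itself expire the entry. A minimal sketch of that pattern with redis-py, assuming a local redis server; the key name and TTL are illustrative:

    import redis

    r = redis.Redis(decode_responses=True)
    CACHE_TTL_SECONDS = 24 * 60 * 60  # illustrative value

    def cache_as_hash(key: str, model_dict: dict) -> None:
        r.hset(key, mapping={k: str(v) for k, v in model_dict.items()})
        r.expire(key, CACHE_TTL_SECONDS)  # redis evicts the entry for us

    cache_as_hash("backups:snapshots:abc123", {"id": "abc123", "service_name": "mail"})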
@ -99,12 +106,14 @@ class Storage:
@staticmethod
def get_cached_snapshots() -> List[Snapshot]:
- keys = redis.keys(REDIS_SNAPSHOTS_PREFIX + "*")
- result = []
+ """Returns all cached snapshots stored in redis"""
+ keys: list[str] = redis.keys(REDIS_SNAPSHOTS_PREFIX + "*")  # type: ignore
+ result: list[Snapshot] = []
for key in keys:
snapshot = hash_as_model(redis, key, Snapshot)
- result.append(snapshot)
+ if snapshot:
+ result.append(snapshot)
return result
@staticmethod
@ -112,18 +121,21 @@ class Storage:
"""None means autobackup is disabled"""
if not redis.exists(REDIS_AUTOBACKUP_PERIOD_KEY):
return None
- return int(redis.get(REDIS_AUTOBACKUP_PERIOD_KEY))
+ return int(redis.get(REDIS_AUTOBACKUP_PERIOD_KEY))  # type: ignore
@staticmethod
- def store_autobackup_period_minutes(minutes: int):
+ def store_autobackup_period_minutes(minutes: int) -> None:
"""Set the new autobackup period in minutes"""
redis.set(REDIS_AUTOBACKUP_PERIOD_KEY, minutes)
@staticmethod
- def delete_backup_period():
+ def delete_backup_period() -> None:
"""Set the autobackup period to none, effectively disabling autobackup"""
redis.delete(REDIS_AUTOBACKUP_PERIOD_KEY)
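Hypothetical usage of the period helpers, assuming the module path `selfprivacy_api.backup.storage`:

    from selfprivacy_api.backup.storage import Storage

    Storage.store_autobackup_period_minutes(24 * 60)  # back up daily
    assert Storage.get_autobackup_period_minutes() == 1440
    Storage.delete_backup_period()                    # autobackup disabled
    assert Storage.get_autobackup_period_minutes() is None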
@staticmethod
- def store_provider(provider: AbstractBackupProvider):
+ def store_provider(provider: AbstractBackupProvider) -> None:
+ """Stores backup storage provider auth data in redis"""
store_model_as_hash(
redis,
REDIS_PROVIDER_KEY,
@ -138,6 +150,7 @@ class Storage:
@staticmethod
def load_provider() -> Optional[BackupProviderModel]:
"""Loads backup storage provider auth data from redis"""
provider_model = hash_as_model(
redis,
REDIS_PROVIDER_KEY,
@ -147,10 +160,12 @@ class Storage:
@staticmethod
def has_init_mark() -> bool:
"""Returns True if the repository was initialized"""
if redis.exists(REDIS_INITTED_CACHE_PREFIX):
return True
return False
@staticmethod
def mark_as_init():
"""Marks the repository as initialized"""
redis.set(REDIS_INITTED_CACHE_PREFIX, 1)

View File

@ -1,21 +1,24 @@
"""
The tasks module contains the worker tasks that are used to back up and restore services.
"""
from datetime import datetime
from selfprivacy_api.graphql.common_types.backup import RestoreStrategy
from selfprivacy_api.models.backup.snapshot import Snapshot
from selfprivacy_api.utils.huey import huey
from selfprivacy_api.services import get_service_by_id
from selfprivacy_api.services.service import Service
from selfprivacy_api.backup import Backups
from selfprivacy_api.backup.jobs import add_backup_job, add_restore_job
- def validate_datetime(dt: datetime):
- # dt = datetime.now(timezone.utc)
+ def validate_datetime(dt: datetime) -> bool:
+ """
+ Validates that the datetime passed in is timezone-aware.
+ """
if dt.tzinfo is None:
raise ValueError(
"""
huey passed in the timezone-unaware time!
Post it in support chat or maybe try uncommenting a line above
"""
)
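For reference, the awareness check the validator wants is conventionally written against `tzinfo`, which is `None` exactly for naive datetimes:

    from datetime import datetime, timezone

    def is_timezone_aware(dt: datetime) -> bool:
        return dt.tzinfo is not None

    assert is_timezone_aware(datetime.now(timezone.utc))
    assert not is_timezone_aware(datetime.now())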
@ -25,6 +28,9 @@ def validate_datetime(dt: datetime):
# huey tasks need to return something
@huey.task()
def start_backup(service: Service) -> bool:
"""
The worker task that starts the backup process.
"""
Backups.back_up(service)
return True
@ -34,12 +40,18 @@ def restore_snapshot(
snapshot: Snapshot,
strategy: RestoreStrategy = RestoreStrategy.DOWNLOAD_VERIFY_OVERWRITE,
) -> bool:
"""
The worker task that starts the restore process.
"""
Backups.restore_snapshot(snapshot, strategy)
return True
@huey.periodic_task(validate_datetime=validate_datetime)
def automatic_backup():
"""
The worker periodic task that starts the automatic backup process.
"""
time = datetime.now()
for service in Backups.services_to_back_up(time):
start_backup(service)
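For context on the scheduling mechanism: huey's `periodic_task` calls its validator roughly once a minute with the current datetime and runs the task whenever the validator returns True; `crontab` is the stock validator. A standalone toy example, unrelated to this repository's huey instance:

    from huey import SqliteHuey, crontab

    huey = SqliteHuey(filename="/tmp/demo-huey.db")

    @huey.periodic_task(crontab(minute="0", hour="3"))
    def nightly_job():
        print("runs every night at 03:00")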

View File

@ -1,7 +1,5 @@
- import datetime
- import typing
import strawberry
from strawberry.types import Info
from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.graphql.mutations.mutation_interface import (
@ -16,7 +14,7 @@ from selfprivacy_api.graphql.common_types.jobs import job_to_api_job
from selfprivacy_api.graphql.common_types.backup import RestoreStrategy
from selfprivacy_api.backup import Backups
- from selfprivacy_api.services import get_all_services, get_service_by_id
+ from selfprivacy_api.services import get_service_by_id
from selfprivacy_api.backup.tasks import start_backup, restore_snapshot
from selfprivacy_api.backup.jobs import add_backup_job, add_restore_job
@ -142,11 +140,11 @@ class BackupMutations:
try:
job = add_restore_job(snap)
- except ValueError as e:
+ except ValueError as error:
return GenericJobMutationReturn(
success=False,
code=400,
- message=str(e),
+ message=str(error),
job=None,
)

View File

@ -64,6 +64,8 @@ class Backup:
status=ServiceStatusEnum.OFF,
url=None,
dns_records=None,
can_be_backed_up=False,
backup_description="",
)
else:
service = service_to_graphql_service(service)

View File

@ -125,57 +125,57 @@ class Jobs:
return False
@staticmethod
- def reset_logs():
+ def reset_logs() -> None:
redis = RedisPool().get_connection()
for key in redis.keys(STATUS_LOGS_PREFIX + "*"):
redis.delete(key)
@staticmethod
- def log_status_update(job: Job, status: JobStatus):
+ def log_status_update(job: Job, status: JobStatus) -> None:
redis = RedisPool().get_connection()
key = _status_log_key_from_uuid(job.uid)
redis.lpush(key, status.value)
redis.expire(key, 10)
@staticmethod
- def log_progress_update(job: Job, progress: int):
+ def log_progress_update(job: Job, progress: int) -> None:
redis = RedisPool().get_connection()
key = _progress_log_key_from_uuid(job.uid)
redis.lpush(key, progress)
redis.expire(key, 10)
@staticmethod
- def status_updates(job: Job) -> typing.List[JobStatus]:
- result = []
+ def status_updates(job: Job) -> list[JobStatus]:
+ result: list[JobStatus] = []
redis = RedisPool().get_connection()
key = _status_log_key_from_uuid(job.uid)
if not redis.exists(key):
return []
- status_strings = redis.lrange(key, 0, -1)
+ status_strings: list[str] = redis.lrange(key, 0, -1)  # type: ignore
for status in status_strings:
try:
result.append(JobStatus[status])
- except KeyError as e:
- raise ValueError("impossible job status: " + status) from e
+ except KeyError as error:
+ raise ValueError("impossible job status: " + status) from error
return result
@staticmethod
- def progress_updates(job: Job) -> typing.List[int]:
- result = []
+ def progress_updates(job: Job) -> list[int]:
+ result: list[int] = []
redis = RedisPool().get_connection()
key = _progress_log_key_from_uuid(job.uid)
if not redis.exists(key):
return []
- progress_strings = redis.lrange(key, 0, -1)
+ progress_strings: list[str] = redis.lrange(key, 0, -1)  # type: ignore
for progress in progress_strings:
try:
result.append(int(progress))
- except KeyError as e:
- raise ValueError("impossible job progress: " + progress) from e
+ except KeyError as error:
+ raise ValueError("impossible job progress: " + progress) from error
return result
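The status and progress logs above use a short-lived redis list: newest entries are pushed to the head, and the whole key expires a few seconds after the last write. The same pattern, sketched standalone with redis-py (assuming a local redis server):

    import redis

    r = redis.Redis(decode_responses=True)

    def log_update(key: str, value: str, ttl: int = 10) -> None:
        r.lpush(key, value)  # newest first
        r.expire(key, ttl)   # the whole log vanishes `ttl` seconds after the last push

    def read_log(key: str) -> list[str]:
        return r.lrange(key, 0, -1)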
@staticmethod
@ -257,19 +257,19 @@ class Jobs:
return False
- def _redis_key_from_uuid(uuid_string):
+ def _redis_key_from_uuid(uuid_string) -> str:
return "jobs:" + str(uuid_string)
- def _status_log_key_from_uuid(uuid_string):
+ def _status_log_key_from_uuid(uuid_string) -> str:
return STATUS_LOGS_PREFIX + str(uuid_string)
- def _progress_log_key_from_uuid(uuid_string):
+ def _progress_log_key_from_uuid(uuid_string) -> str:
return PROGRESS_LOGS_PREFIX + str(uuid_string)
- def _store_job_as_hash(redis, redis_key, model):
+ def _store_job_as_hash(redis, redis_key, model) -> None:
for key, value in model.dict().items():
if isinstance(value, uuid.UUID):
value = str(value)
@ -280,7 +280,7 @@ def _store_job_as_hash(redis, redis_key, model):
redis.hset(redis_key, key, str(value))
- def _job_from_hash(redis, redis_key):
+ def _job_from_hash(redis, redis_key) -> typing.Optional[Job]:
if redis.exists(redis_key):
job_dict = redis.hgetall(redis_key)
for date in [

View File

@ -1,7 +1,7 @@
"""
Token repository using Redis as backend.
"""
- from typing import Optional
+ from typing import Any, Optional
from datetime import datetime
from hashlib import md5
@ -29,15 +29,15 @@ class RedisTokensRepository(AbstractTokensRepository):
@staticmethod
def token_key_for_device(device_name: str):
- hash = md5()
- hash.update(bytes(device_name, "utf-8"))
- digest = hash.hexdigest()
+ md5_hash = md5()
+ md5_hash.update(bytes(device_name, "utf-8"))
+ digest = md5_hash.hexdigest()
return TOKENS_PREFIX + digest
def get_tokens(self) -> list[Token]:
"""Get the tokens"""
redis = self.connection
- token_keys = redis.keys(TOKENS_PREFIX + "*")
+ token_keys: list[str] = redis.keys(TOKENS_PREFIX + "*")  # type: ignore
tokens = []
for key in token_keys:
token = self._token_from_hash(key)
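The key derivation above hashes the device name so that arbitrary names become fixed-size key suffixes. Restated standalone; the prefix value is a placeholder, since the diff does not show the real `TOKENS_PREFIX`:

    from hashlib import md5

    TOKENS_PREFIX = "token_repo:tokens:"  # placeholder value

    def token_key_for_device(device_name: str) -> str:
        digest = md5(device_name.encode("utf-8")).hexdigest()
        return TOKENS_PREFIX + digest

    print(token_key_for_device("my laptop"))  # prefix plus 32 hex characters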
@ -45,10 +45,10 @@ class RedisTokensRepository(AbstractTokensRepository):
tokens.append(token)
return tokens
- def _discover_token_key(self, input_token: Token) -> str:
+ def _discover_token_key(self, input_token: Token) -> Optional[str]:
"""brute-force searching for tokens, for robust deletion"""
redis = self.connection
- token_keys = redis.keys(TOKENS_PREFIX + "*")
+ token_keys: list[str] = redis.keys(TOKENS_PREFIX + "*")  # type: ignore
for key in token_keys:
token = self._token_from_hash(key)
if token == input_token:
@ -120,26 +120,26 @@ class RedisTokensRepository(AbstractTokensRepository):
return self._new_device_key_from_hash(NEW_DEVICE_KEY_REDIS_KEY)
@staticmethod
- def _is_date_key(key: str):
+ def _is_date_key(key: str) -> bool:
return key in [
"created_at",
"expires_at",
]
@staticmethod
- def _prepare_model_dict(d: dict):
- date_keys = [key for key in d.keys() if RedisTokensRepository._is_date_key(key)]
+ def _prepare_model_dict(model_dict: dict[str, Any]) -> None:
+ date_keys = [key for key in model_dict.keys() if RedisTokensRepository._is_date_key(key)]
for date in date_keys:
if d[date] != "None":
d[date] = datetime.fromisoformat(d[date])
for key in d.keys():
if d[key] == "None":
d[key] = None
if model_dict[date] != "None":
model_dict[date] = datetime.fromisoformat(model_dict[date])
for key in model_dict.keys():
if model_dict[key] == "None":
model_dict[key] = None
- def _model_dict_from_hash(self, redis_key: str) -> Optional[dict]:
+ def _model_dict_from_hash(self, redis_key: str) -> Optional[dict[str, Any]]:
redis = self.connection
if redis.exists(redis_key):
- token_dict = redis.hgetall(redis_key)
+ token_dict: dict[str, Any] = redis.hgetall(redis_key)  # type: ignore
RedisTokensRepository._prepare_model_dict(token_dict)
return token_dict
return None
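`_prepare_model_dict` undoes redis's stringification: hashes store only strings, so `"None"` must become `None` again and ISO-formatted dates must be parsed back into datetimes. A self-contained sketch of the same round-trip:

    from datetime import datetime
    from typing import Any

    def rehydrate(model_dict: dict[str, Any]) -> None:
        for key, value in model_dict.items():
            if value == "None":
                model_dict[key] = None          # check this first
            elif key in ("created_at", "expires_at"):
                model_dict[key] = datetime.fromisoformat(value)

    d = {"device_name": "laptop", "created_at": "2023-07-20T18:24:26", "expires_at": "None"}
    rehydrate(d)  # created_at becomes a datetime, expires_at becomes None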

View File

@ -1,9 +1,9 @@
"""
Redis pool module for selfprivacy_api
"""
+ from os import environ
import redis
from selfprivacy_api.utils.singleton_metaclass import SingletonMetaclass
- from os import environ
REDIS_SOCKET = "/run/redis-sp-api/redis.sock"
@ -14,7 +14,7 @@ class RedisPool(metaclass=SingletonMetaclass):
"""
def __init__(self):
if "USE_REDIS_PORT" in environ.keys():
if "USE_REDIS_PORT" in environ:
self._pool = redis.ConnectionPool(
host="127.0.0.1",
port=int(environ["USE_REDIS_PORT"]),