refactor(backups): linting

pull/35/head
Inex Code 2023-07-20 18:24:26 +03:00
parent 2df448a4a9
commit 0245d629fd
12 changed files with 196 additions and 112 deletions

View File

@ -1,3 +1,6 @@
"""
This module contains the controller class for backups.
"""
from datetime import datetime, timedelta from datetime import datetime, timedelta
from os import statvfs from os import statvfs
from typing import List, Optional from typing import List, Optional
@ -42,8 +45,12 @@ DEFAULT_JSON_PROVIDER = {
class NotDeadError(AssertionError): class NotDeadError(AssertionError):
"""
This error is raised when we try to back up a service that is not dead yet.
"""
def __init__(self, service: Service): def __init__(self, service: Service):
self.service_name = service.get_id() self.service_name = service.get_id()
super().__init__()
def __str__(self): def __str__(self):
return f""" return f"""
@ -61,6 +68,9 @@ class Backups:
@staticmethod @staticmethod
def provider() -> AbstractBackupProvider: def provider() -> AbstractBackupProvider:
"""
Returns the current backup storage provider.
"""
return Backups._lookup_provider() return Backups._lookup_provider()
@staticmethod @staticmethod
@ -71,6 +81,13 @@ class Backups:
location: str, location: str,
repo_id: str = "", repo_id: str = "",
) -> None: ) -> None:
"""
Sets the new configuration of the backup storage provider.
In case of `BackupProviderEnum.BACKBLAZE`, the `login` is the key ID,
the `key` is the key itself, and the `location` is the bucket name and
the `repo_id` is the bucket ID.
"""
provider: AbstractBackupProvider = Backups._construct_provider( provider: AbstractBackupProvider = Backups._construct_provider(
kind, kind,
login, login,
@ -82,6 +99,9 @@ class Backups:
@staticmethod @staticmethod
def reset(reset_json=True) -> None: def reset(reset_json=True) -> None:
"""
Deletes all the data about the backup storage provider.
"""
Storage.reset() Storage.reset()
if reset_json: if reset_json:
try: try:
@ -183,11 +203,19 @@ class Backups:
@staticmethod @staticmethod
def init_repo() -> None: def init_repo() -> None:
"""
Initializes the backup repository. This is required once per repo.
"""
Backups.provider().backupper.init() Backups.provider().backupper.init()
Storage.mark_as_init() Storage.mark_as_init()
@staticmethod @staticmethod
def is_initted() -> bool: def is_initted() -> bool:
"""
Returns whether the backup repository is initialized or not.
If it is not initialized, we cannot back up and probably should
call `init_repo` first.
"""
if Storage.has_init_mark(): if Storage.has_init_mark():
return True return True
@ -219,9 +247,9 @@ class Backups:
) )
Backups._store_last_snapshot(tag, snapshot) Backups._store_last_snapshot(tag, snapshot)
service.post_restore() service.post_restore()
except Exception as e: except Exception as error:
Jobs.update(job, status=JobStatus.ERROR) Jobs.update(job, status=JobStatus.ERROR)
raise e raise error
Jobs.update(job, status=JobStatus.FINISHED) Jobs.update(job, status=JobStatus.FINISHED)
return snapshot return snapshot
@ -252,16 +280,17 @@ class Backups:
snapshot.id, snapshot.id,
verify=False, verify=False,
) )
except Exception as e: except Exception as error:
Backups._restore_service_from_snapshot( Backups._restore_service_from_snapshot(
service, failsafe_snapshot.id, verify=False service, failsafe_snapshot.id, verify=False
) )
raise e raise error
@staticmethod @staticmethod
def restore_snapshot( def restore_snapshot(
snapshot: Snapshot, strategy=RestoreStrategy.DOWNLOAD_VERIFY_OVERWRITE snapshot: Snapshot, strategy=RestoreStrategy.DOWNLOAD_VERIFY_OVERWRITE
) -> None: ) -> None:
"""Restores a snapshot to its original service using the given strategy"""
service = get_service_by_id(snapshot.service_name) service = get_service_by_id(snapshot.service_name)
if service is None: if service is None:
raise ValueError( raise ValueError(
@ -283,9 +312,9 @@ class Backups:
service.post_restore() service.post_restore()
except Exception as e: except Exception as error:
Jobs.update(job, status=JobStatus.ERROR) Jobs.update(job, status=JobStatus.ERROR)
raise e raise error
Jobs.update(job, status=JobStatus.FINISHED) Jobs.update(job, status=JobStatus.FINISHED)
@ -338,6 +367,7 @@ class Backups:
@staticmethod @staticmethod
def get_snapshots(service: Service) -> List[Snapshot]: def get_snapshots(service: Service) -> List[Snapshot]:
"""Returns all snapshots for a given service"""
snapshots = Backups.get_all_snapshots() snapshots = Backups.get_all_snapshots()
service_id = service.get_id() service_id = service.get_id()
return list( return list(
@ -349,8 +379,9 @@ class Backups:
@staticmethod @staticmethod
def get_all_snapshots() -> List[Snapshot]: def get_all_snapshots() -> List[Snapshot]:
"""Returns all snapshots"""
cached_snapshots = Storage.get_cached_snapshots() cached_snapshots = Storage.get_cached_snapshots()
if cached_snapshots != []: if cached_snapshots:
return cached_snapshots return cached_snapshots
# TODO: the oldest snapshots will get expired faster than the new ones. # TODO: the oldest snapshots will get expired faster than the new ones.
# How to detect that the end is missing? # How to detect that the end is missing?
@ -359,24 +390,32 @@ class Backups:
return Storage.get_cached_snapshots() return Storage.get_cached_snapshots()
@staticmethod @staticmethod
def get_snapshot_by_id(id: str) -> Optional[Snapshot]: def get_snapshot_by_id(snapshot_id: str) -> Optional[Snapshot]:
snap = Storage.get_cached_snapshot_by_id(id) """Returns a backup snapshot by its id"""
snap = Storage.get_cached_snapshot_by_id(snapshot_id)
if snap is not None: if snap is not None:
return snap return snap
# Possibly our cache entry got invalidated, let's try one more time # Possibly our cache entry got invalidated, let's try one more time
Backups.force_snapshot_cache_reload() Backups.force_snapshot_cache_reload()
snap = Storage.get_cached_snapshot_by_id(id) snap = Storage.get_cached_snapshot_by_id(snapshot_id)
return snap return snap
@staticmethod @staticmethod
def forget_snapshot(snapshot: Snapshot) -> None: def forget_snapshot(snapshot: Snapshot) -> None:
"""Deletes a snapshot from the storage"""
Backups.provider().backupper.forget_snapshot(snapshot.id) Backups.provider().backupper.forget_snapshot(snapshot.id)
Storage.delete_cached_snapshot(snapshot) Storage.delete_cached_snapshot(snapshot)
@staticmethod @staticmethod
def force_snapshot_cache_reload() -> None: def force_snapshot_cache_reload() -> None:
"""
Forces a reload of the snapshot cache.
This may be an expensive operation, so use it wisely.
User pays for the API calls.
"""
upstream_snapshots = Backups.provider().backupper.get_snapshots() upstream_snapshots = Backups.provider().backupper.get_snapshots()
Storage.invalidate_snapshot_storage() Storage.invalidate_snapshot_storage()
for snapshot in upstream_snapshots: for snapshot in upstream_snapshots:
@ -384,6 +423,7 @@ class Backups:
@staticmethod @staticmethod
def snapshot_restored_size(snapshot_id: str) -> int: def snapshot_restored_size(snapshot_id: str) -> int:
"""Returns the size of the snapshot"""
return Backups.provider().backupper.restored_size( return Backups.provider().backupper.restored_size(
snapshot_id, snapshot_id,
) )
@ -434,6 +474,7 @@ class Backups:
@staticmethod @staticmethod
def services_to_back_up(time: datetime) -> List[Service]: def services_to_back_up(time: datetime) -> List[Service]:
"""Returns a list of services that should be backed up at a given time"""
return [ return [
service service
for service in get_all_services() for service in get_all_services()
@ -447,6 +488,7 @@ class Backups:
@staticmethod @staticmethod
def is_time_to_backup_service(service: Service, time: datetime): def is_time_to_backup_service(service: Service, time: datetime):
"""Returns True if it is time to back up a service"""
period = Backups.autobackup_period_minutes() period = Backups.autobackup_period_minutes()
service_id = service.get_id() service_id = service.get_id()
if not service.can_be_backed_up(): if not service.can_be_backed_up():
@ -467,6 +509,10 @@ class Backups:
@staticmethod @staticmethod
def space_usable_for_service(service: Service) -> int: def space_usable_for_service(service: Service) -> int:
"""
Returns the amount of space available on the volume the given
service is located on.
"""
folders = service.get_folders() folders = service.get_folders()
if folders == []: if folders == []:
raise ValueError("unallocated service", service.get_id()) raise ValueError("unallocated service", service.get_id())
@ -478,6 +524,8 @@ class Backups:
@staticmethod @staticmethod
def set_localfile_repo(file_path: str): def set_localfile_repo(file_path: str):
"""Used by tests to set a local folder as a backup repo"""
# pylint: disable-next=invalid-name
ProviderClass = get_provider(BackupProviderEnum.FILE) ProviderClass = get_provider(BackupProviderEnum.FILE)
provider = ProviderClass( provider = ProviderClass(
login="", login="",
@ -490,10 +538,7 @@ class Backups:
@staticmethod @staticmethod
def assert_dead(service: Service): def assert_dead(service: Service):
""" """
Checks if a service is dead and can be safely restored from a snapshot.
If we backup the service that is failing to restore it to the previous snapshot,
its status can be FAILED.
And obviously restoring a failed service is the main route
""" """
if service.get_status() not in [ if service.get_status() not in [
ServiceStatus.INACTIVE, ServiceStatus.INACTIVE,

View File

@ -5,19 +5,25 @@ from selfprivacy_api.models.backup.snapshot import Snapshot
class AbstractBackupper(ABC): class AbstractBackupper(ABC):
"""Abstract class for backuppers"""
# flake8: noqa: B027
def __init__(self) -> None: def __init__(self) -> None:
pass pass
@abstractmethod @abstractmethod
def is_initted(self) -> bool: def is_initted(self) -> bool:
"""Returns true if the repository is initted"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def set_creds(self, account: str, key: str, repo: str) -> None: def set_creds(self, account: str, key: str, repo: str) -> None:
"""Set the credentials for the backupper"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def start_backup(self, folders: List[str], repo_name: str) -> Snapshot: def start_backup(self, folders: List[str], tag: str) -> Snapshot:
"""Start a backup of the given folders"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
@ -27,6 +33,7 @@ class AbstractBackupper(ABC):
@abstractmethod @abstractmethod
def init(self) -> None: def init(self) -> None:
"""Initialize the repository"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
@ -41,8 +48,10 @@ class AbstractBackupper(ABC):
@abstractmethod @abstractmethod
def restored_size(self, snapshot_id: str) -> int: def restored_size(self, snapshot_id: str) -> int:
"""Get the size of the restored snapshot"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def forget_snapshot(self, snapshot_id) -> None: def forget_snapshot(self, snapshot_id) -> None:
"""Forget a snapshot"""
raise NotImplementedError raise NotImplementedError

View File

@ -5,13 +5,15 @@ from selfprivacy_api.backup.backuppers import AbstractBackupper
class NoneBackupper(AbstractBackupper): class NoneBackupper(AbstractBackupper):
"""A backupper that does nothing"""
def is_initted(self, repo_name: str = "") -> bool: def is_initted(self, repo_name: str = "") -> bool:
return False return False
def set_creds(self, account: str, key: str, repo: str): def set_creds(self, account: str, key: str, repo: str):
pass pass
def start_backup(self, folders: List[str], repo_name: str): def start_backup(self, folders: List[str], tag: str):
raise NotImplementedError raise NotImplementedError
def get_snapshots(self) -> List[Snapshot]: def get_snapshots(self) -> List[Snapshot]:
@ -21,7 +23,7 @@ class NoneBackupper(AbstractBackupper):
def init(self): def init(self):
raise NotImplementedError raise NotImplementedError
def restore_from_backup(self, snapshot_id: str, folders: List[str]): def restore_from_backup(self, snapshot_id: str, folders: List[str], verify=True):
"""Restore a target folder using a snapshot""" """Restore a target folder using a snapshot"""
raise NotImplementedError raise NotImplementedError

View File

@ -21,13 +21,14 @@ from selfprivacy_api.backup.local_secret import LocalBackupSecret
class ResticBackupper(AbstractBackupper): class ResticBackupper(AbstractBackupper):
def __init__(self, login_flag: str, key_flag: str, type: str) -> None: def __init__(self, login_flag: str, key_flag: str, storage_type: str) -> None:
self.login_flag = login_flag self.login_flag = login_flag
self.key_flag = key_flag self.key_flag = key_flag
self.type = type self.storage_type = storage_type
self.account = "" self.account = ""
self.key = "" self.key = ""
self.repo = "" self.repo = ""
super().__init__()
def set_creds(self, account: str, key: str, repo: str) -> None: def set_creds(self, account: str, key: str, repo: str) -> None:
self.account = account self.account = account
@ -37,7 +38,7 @@ class ResticBackupper(AbstractBackupper):
def restic_repo(self) -> str: def restic_repo(self) -> str:
# https://restic.readthedocs.io/en/latest/030_preparing_a_new_repo.html#other-services-via-rclone # https://restic.readthedocs.io/en/latest/030_preparing_a_new_repo.html#other-services-via-rclone
# https://forum.rclone.org/t/can-rclone-be-run-solely-with-command-line-options-no-config-no-env-vars/6314/5 # https://forum.rclone.org/t/can-rclone-be-run-solely-with-command-line-options-no-config-no-env-vars/6314/5
return f"rclone:{self.type}{self.repo}" return f"rclone:{self.storage_type}{self.repo}"
def rclone_args(self): def rclone_args(self):
return "rclone.args=serve restic --stdio " + self.backend_rclone_args() return "rclone.args=serve restic --stdio " + self.backend_rclone_args()
@ -72,12 +73,12 @@ class ResticBackupper(AbstractBackupper):
tag, tag,
] ]
) )
if args != []: if args:
command.extend(ResticBackupper.__flatten_list(args)) command.extend(ResticBackupper.__flatten_list(args))
return command return command
def mount_repo(self, dir): def mount_repo(self, mount_directory):
mount_command = self.restic_command("mount", dir) mount_command = self.restic_command("mount", mount_directory)
mount_command.insert(0, "nohup") mount_command.insert(0, "nohup")
handle = subprocess.Popen( handle = subprocess.Popen(
mount_command, mount_command,
@ -85,28 +86,28 @@ class ResticBackupper(AbstractBackupper):
shell=False, shell=False,
) )
sleep(2) sleep(2)
if "ids" not in listdir(dir): if "ids" not in listdir(mount_directory):
raise IOError("failed to mount dir ", dir) raise IOError("failed to mount dir ", mount_directory)
return handle return handle
def unmount_repo(self, dir): def unmount_repo(self, mount_directory):
mount_command = ["umount", "-l", dir] mount_command = ["umount", "-l", mount_directory]
with subprocess.Popen( with subprocess.Popen(
mount_command, stdout=subprocess.PIPE, shell=False mount_command, stdout=subprocess.PIPE, shell=False
) as handle: ) as handle:
output = handle.communicate()[0].decode("utf-8") output = handle.communicate()[0].decode("utf-8")
# TODO: check for exit code? # TODO: check for exit code?
if "error" in output.lower(): if "error" in output.lower():
return IOError("failed to unmount dir ", dir, ": ", output) return IOError("failed to unmount dir ", mount_directory, ": ", output)
if not listdir(dir) == []: if not listdir(mount_directory) == []:
return IOError("failed to unmount dir ", dir) return IOError("failed to unmount dir ", mount_directory)
@staticmethod @staticmethod
def __flatten_list(list): def __flatten_list(list_to_flatten):
"""string-aware list flattener""" """string-aware list flattener"""
result = [] result = []
for item in list: for item in list_to_flatten:
if isinstance(item, Iterable) and not isinstance(item, str): if isinstance(item, Iterable) and not isinstance(item, str):
result.extend(ResticBackupper.__flatten_list(item)) result.extend(ResticBackupper.__flatten_list(item))
continue continue
@ -147,8 +148,8 @@ class ResticBackupper(AbstractBackupper):
messages, messages,
tag, tag,
) )
except ValueError as e: except ValueError as error:
raise ValueError("Could not create a snapshot: ", messages) from e raise ValueError("Could not create a snapshot: ", messages) from error
@staticmethod @staticmethod
def _snapshot_from_backup_messages(messages, repo_name) -> Snapshot: def _snapshot_from_backup_messages(messages, repo_name) -> Snapshot:
@ -231,8 +232,8 @@ class ResticBackupper(AbstractBackupper):
try: try:
parsed_output = ResticBackupper.parse_json_output(output) parsed_output = ResticBackupper.parse_json_output(output)
return parsed_output["total_size"] return parsed_output["total_size"]
except ValueError as e: except ValueError as error:
raise ValueError("cannot restore a snapshot: " + output) from e raise ValueError("cannot restore a snapshot: " + output) from error
def restore_from_backup( def restore_from_backup(
self, self,
@ -246,13 +247,13 @@ class ResticBackupper(AbstractBackupper):
if folders is None or folders == []: if folders is None or folders == []:
raise ValueError("cannot restore without knowing where to!") raise ValueError("cannot restore without knowing where to!")
with tempfile.TemporaryDirectory() as dir: with tempfile.TemporaryDirectory() as temp_dir:
if verify: if verify:
self._raw_verified_restore(snapshot_id, target=dir) self._raw_verified_restore(snapshot_id, target=temp_dir)
snapshot_root = dir snapshot_root = temp_dir
else: # attempting inplace restore via mount + sync else: # attempting inplace restore via mount + sync
self.mount_repo(dir) self.mount_repo(temp_dir)
snapshot_root = join(dir, "ids", snapshot_id) snapshot_root = join(temp_dir, "ids", snapshot_id)
assert snapshot_root is not None assert snapshot_root is not None
for folder in folders: for folder in folders:
@ -263,7 +264,7 @@ class ResticBackupper(AbstractBackupper):
sync(src, dst) sync(src, dst)
if not verify: if not verify:
self.unmount_repo(dir) self.unmount_repo(temp_dir)
def _raw_verified_restore(self, snapshot_id, target="/"): def _raw_verified_restore(self, snapshot_id, target="/"):
"""barebones restic restore""" """barebones restic restore"""
@ -355,8 +356,8 @@ class ResticBackupper(AbstractBackupper):
raise ValueError("No repository! : " + output) raise ValueError("No repository! : " + output)
try: try:
return ResticBackupper.parse_json_output(output) return ResticBackupper.parse_json_output(output)
except ValueError as e: except ValueError as error:
raise ValueError("Cannot load snapshots: ") from e raise ValueError("Cannot load snapshots: ") from error
def get_snapshots(self) -> List[Snapshot]: def get_snapshots(self) -> List[Snapshot]:
"""Get all snapshots from the repo""" """Get all snapshots from the repo"""
@ -383,10 +384,10 @@ class ResticBackupper(AbstractBackupper):
if len(json_messages) == 1: if len(json_messages) == 1:
try: try:
return json.loads(truncated_output) return json.loads(truncated_output)
except JSONDecodeError as e: except JSONDecodeError as error:
raise ValueError( raise ValueError(
"There is no json in the restic output : " + output "There is no json in the restic output : " + output
) from e ) from error
result_array = [] result_array = []
for message in json_messages: for message in json_messages:

View File

@ -1,4 +1,4 @@
from .provider import AbstractBackupProvider from selfprivacy_api.backup.providers.provider import AbstractBackupProvider
from selfprivacy_api.backup.backuppers.none_backupper import NoneBackupper from selfprivacy_api.backup.backuppers.none_backupper import NoneBackupper
from selfprivacy_api.graphql.queries.providers import ( from selfprivacy_api.graphql.queries.providers import (
BackupProvider as BackupProviderEnum, BackupProvider as BackupProviderEnum,

View File

@ -1,3 +1,6 @@
"""
Module for storing backup related data in redis.
"""
from typing import List, Optional from typing import List, Optional
from datetime import datetime from datetime import datetime
@ -10,10 +13,6 @@ from selfprivacy_api.utils.redis_model_storage import (
hash_as_model, hash_as_model,
) )
from selfprivacy_api.services.service import Service
from selfprivacy_api.services import get_service_by_id
from selfprivacy_api.backup.providers.provider import AbstractBackupProvider from selfprivacy_api.backup.providers.provider import AbstractBackupProvider
from selfprivacy_api.backup.providers import get_kind from selfprivacy_api.backup.providers import get_kind
@ -32,8 +31,10 @@ redis = RedisPool().get_connection()
class Storage: class Storage:
"""Static class for storing backup related data in redis"""
@staticmethod @staticmethod
def reset(): def reset() -> None:
"""Deletes all backup related data from redis"""
redis.delete(REDIS_PROVIDER_KEY) redis.delete(REDIS_PROVIDER_KEY)
redis.delete(REDIS_AUTOBACKUP_PERIOD_KEY) redis.delete(REDIS_AUTOBACKUP_PERIOD_KEY)
@ -48,20 +49,22 @@ class Storage:
redis.delete(key) redis.delete(key)
@staticmethod @staticmethod
def invalidate_snapshot_storage(): def invalidate_snapshot_storage() -> None:
"""Deletes all cached snapshots from redis"""
for key in redis.keys(REDIS_SNAPSHOTS_PREFIX + "*"): for key in redis.keys(REDIS_SNAPSHOTS_PREFIX + "*"):
redis.delete(key) redis.delete(key)
@staticmethod @staticmethod
def __last_backup_key(service_id): def __last_backup_key(service_id: str) -> str:
return REDIS_LAST_BACKUP_PREFIX + service_id return REDIS_LAST_BACKUP_PREFIX + service_id
@staticmethod @staticmethod
def __snapshot_key(snapshot: Snapshot): def __snapshot_key(snapshot: Snapshot) -> str:
return REDIS_SNAPSHOTS_PREFIX + snapshot.id return REDIS_SNAPSHOTS_PREFIX + snapshot.id
@staticmethod @staticmethod
def get_last_backup_time(service_id: str) -> Optional[datetime]: def get_last_backup_time(service_id: str) -> Optional[datetime]:
"""Returns last backup time for a service or None if it was never backed up"""
key = Storage.__last_backup_key(service_id) key = Storage.__last_backup_key(service_id)
if not redis.exists(key): if not redis.exists(key):
return None return None
@ -72,7 +75,8 @@ class Storage:
return snapshot.created_at return snapshot.created_at
@staticmethod @staticmethod
def store_last_timestamp(service_id: str, snapshot: Snapshot): def store_last_timestamp(service_id: str, snapshot: Snapshot) -> None:
"""Stores last backup time for a service"""
store_model_as_hash( store_model_as_hash(
redis, redis,
Storage.__last_backup_key(service_id), Storage.__last_backup_key(service_id),
@ -80,18 +84,21 @@ class Storage:
) )
@staticmethod @staticmethod
def cache_snapshot(snapshot: Snapshot): def cache_snapshot(snapshot: Snapshot) -> None:
"""Stores snapshot metadata in redis for caching purposes"""
snapshot_key = Storage.__snapshot_key(snapshot) snapshot_key = Storage.__snapshot_key(snapshot)
store_model_as_hash(redis, snapshot_key, snapshot) store_model_as_hash(redis, snapshot_key, snapshot)
redis.expire(snapshot_key, REDIS_SNAPSHOT_CACHE_EXPIRE_SECONDS) redis.expire(snapshot_key, REDIS_SNAPSHOT_CACHE_EXPIRE_SECONDS)
@staticmethod @staticmethod
def delete_cached_snapshot(snapshot: Snapshot): def delete_cached_snapshot(snapshot: Snapshot) -> None:
"""Deletes snapshot metadata from redis"""
snapshot_key = Storage.__snapshot_key(snapshot) snapshot_key = Storage.__snapshot_key(snapshot)
redis.delete(snapshot_key) redis.delete(snapshot_key)
@staticmethod @staticmethod
def get_cached_snapshot_by_id(snapshot_id: str) -> Optional[Snapshot]: def get_cached_snapshot_by_id(snapshot_id: str) -> Optional[Snapshot]:
"""Returns cached snapshot by id or None if it doesn't exist"""
key = REDIS_SNAPSHOTS_PREFIX + snapshot_id key = REDIS_SNAPSHOTS_PREFIX + snapshot_id
if not redis.exists(key): if not redis.exists(key):
return None return None
@ -99,12 +106,14 @@ class Storage:
@staticmethod @staticmethod
def get_cached_snapshots() -> List[Snapshot]: def get_cached_snapshots() -> List[Snapshot]:
keys = redis.keys(REDIS_SNAPSHOTS_PREFIX + "*") """Returns all cached snapshots stored in redis"""
result = [] keys: list[str] = redis.keys(REDIS_SNAPSHOTS_PREFIX + "*") # type: ignore
result: list[Snapshot] = []
for key in keys: for key in keys:
snapshot = hash_as_model(redis, key, Snapshot) snapshot = hash_as_model(redis, key, Snapshot)
result.append(snapshot) if snapshot:
result.append(snapshot)
return result return result
@staticmethod @staticmethod
@ -112,18 +121,21 @@ class Storage:
"""None means autobackup is disabled""" """None means autobackup is disabled"""
if not redis.exists(REDIS_AUTOBACKUP_PERIOD_KEY): if not redis.exists(REDIS_AUTOBACKUP_PERIOD_KEY):
return None return None
return int(redis.get(REDIS_AUTOBACKUP_PERIOD_KEY)) return int(redis.get(REDIS_AUTOBACKUP_PERIOD_KEY)) # type: ignore
@staticmethod @staticmethod
def store_autobackup_period_minutes(minutes: int): def store_autobackup_period_minutes(minutes: int) -> None:
"""Set the new autobackup period in minutes"""
redis.set(REDIS_AUTOBACKUP_PERIOD_KEY, minutes) redis.set(REDIS_AUTOBACKUP_PERIOD_KEY, minutes)
@staticmethod @staticmethod
def delete_backup_period(): def delete_backup_period() -> None:
"""Set the autobackup period to none, effectively disabling autobackup"""
redis.delete(REDIS_AUTOBACKUP_PERIOD_KEY) redis.delete(REDIS_AUTOBACKUP_PERIOD_KEY)
@staticmethod @staticmethod
def store_provider(provider: AbstractBackupProvider): def store_provider(provider: AbstractBackupProvider) -> None:
"""Stores backup stroage provider auth data in redis"""
store_model_as_hash( store_model_as_hash(
redis, redis,
REDIS_PROVIDER_KEY, REDIS_PROVIDER_KEY,
@ -138,6 +150,7 @@ class Storage:
@staticmethod @staticmethod
def load_provider() -> Optional[BackupProviderModel]: def load_provider() -> Optional[BackupProviderModel]:
"""Loads backup storage provider auth data from redis"""
provider_model = hash_as_model( provider_model = hash_as_model(
redis, redis,
REDIS_PROVIDER_KEY, REDIS_PROVIDER_KEY,
@ -147,10 +160,12 @@ class Storage:
@staticmethod @staticmethod
def has_init_mark() -> bool: def has_init_mark() -> bool:
"""Returns True if the repository was initialized"""
if redis.exists(REDIS_INITTED_CACHE_PREFIX): if redis.exists(REDIS_INITTED_CACHE_PREFIX):
return True return True
return False return False
@staticmethod @staticmethod
def mark_as_init(): def mark_as_init():
"""Marks the repository as initialized"""
redis.set(REDIS_INITTED_CACHE_PREFIX, 1) redis.set(REDIS_INITTED_CACHE_PREFIX, 1)

View File

@ -1,21 +1,24 @@
"""
The tasks module contains the worker tasks that are used to back up and restore
"""
from datetime import datetime from datetime import datetime
from selfprivacy_api.graphql.common_types.backup import RestoreStrategy from selfprivacy_api.graphql.common_types.backup import RestoreStrategy
from selfprivacy_api.models.backup.snapshot import Snapshot from selfprivacy_api.models.backup.snapshot import Snapshot
from selfprivacy_api.utils.huey import huey from selfprivacy_api.utils.huey import huey
from selfprivacy_api.services import get_service_by_id
from selfprivacy_api.services.service import Service from selfprivacy_api.services.service import Service
from selfprivacy_api.backup import Backups from selfprivacy_api.backup import Backups
from selfprivacy_api.backup.jobs import add_backup_job, add_restore_job
def validate_datetime(dt: datetime): def validate_datetime(dt: datetime) -> bool:
# dt = datetime.now(timezone.utc) """
Validates that the datetime passed in is timezone-aware.
"""
if dt.timetz is None: if dt.timetz is None:
raise ValueError( raise ValueError(
""" """
huey passed in the timezone-unaware time! huey passed in the timezone-unaware time!
Post it in support chat or maybe try uncommenting a line above Post it in support chat or maybe try uncommenting a line above
""" """
) )
@ -25,6 +28,9 @@ def validate_datetime(dt: datetime):
# huey tasks need to return something # huey tasks need to return something
@huey.task() @huey.task()
def start_backup(service: Service) -> bool: def start_backup(service: Service) -> bool:
"""
The worker task that starts the backup process.
"""
Backups.back_up(service) Backups.back_up(service)
return True return True
@ -34,12 +40,18 @@ def restore_snapshot(
snapshot: Snapshot, snapshot: Snapshot,
strategy: RestoreStrategy = RestoreStrategy.DOWNLOAD_VERIFY_OVERWRITE, strategy: RestoreStrategy = RestoreStrategy.DOWNLOAD_VERIFY_OVERWRITE,
) -> bool: ) -> bool:
"""
The worker task that starts the restore process.
"""
Backups.restore_snapshot(snapshot, strategy) Backups.restore_snapshot(snapshot, strategy)
return True return True
@huey.periodic_task(validate_datetime=validate_datetime) @huey.periodic_task(validate_datetime=validate_datetime)
def automatic_backup(): def automatic_backup():
"""
The worker periodic task that starts the automatic backup process.
"""
time = datetime.now() time = datetime.now()
for service in Backups.services_to_back_up(time): for service in Backups.services_to_back_up(time):
start_backup(service) start_backup(service)

View File

@ -1,7 +1,5 @@
import datetime
import typing import typing
import strawberry import strawberry
from strawberry.types import Info
from selfprivacy_api.graphql import IsAuthenticated from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.graphql.mutations.mutation_interface import ( from selfprivacy_api.graphql.mutations.mutation_interface import (
@ -16,7 +14,7 @@ from selfprivacy_api.graphql.common_types.jobs import job_to_api_job
from selfprivacy_api.graphql.common_types.backup import RestoreStrategy from selfprivacy_api.graphql.common_types.backup import RestoreStrategy
from selfprivacy_api.backup import Backups from selfprivacy_api.backup import Backups
from selfprivacy_api.services import get_all_services, get_service_by_id from selfprivacy_api.services import get_service_by_id
from selfprivacy_api.backup.tasks import start_backup, restore_snapshot from selfprivacy_api.backup.tasks import start_backup, restore_snapshot
from selfprivacy_api.backup.jobs import add_backup_job, add_restore_job from selfprivacy_api.backup.jobs import add_backup_job, add_restore_job
@ -142,11 +140,11 @@ class BackupMutations:
try: try:
job = add_restore_job(snap) job = add_restore_job(snap)
except ValueError as e: except ValueError as error:
return GenericJobMutationReturn( return GenericJobMutationReturn(
success=False, success=False,
code=400, code=400,
message=str(e), message=str(error),
job=None, job=None,
) )

View File

@ -64,6 +64,8 @@ class Backup:
status=ServiceStatusEnum.OFF, status=ServiceStatusEnum.OFF,
url=None, url=None,
dns_records=None, dns_records=None,
can_be_backed_up=False,
backup_description="",
) )
else: else:
service = service_to_graphql_service(service) service = service_to_graphql_service(service)

View File

@ -125,57 +125,57 @@ class Jobs:
return False return False
@staticmethod @staticmethod
def reset_logs(): def reset_logs() -> None:
redis = RedisPool().get_connection() redis = RedisPool().get_connection()
for key in redis.keys(STATUS_LOGS_PREFIX + "*"): for key in redis.keys(STATUS_LOGS_PREFIX + "*"):
redis.delete(key) redis.delete(key)
@staticmethod @staticmethod
def log_status_update(job: Job, status: JobStatus): def log_status_update(job: Job, status: JobStatus) -> None:
redis = RedisPool().get_connection() redis = RedisPool().get_connection()
key = _status_log_key_from_uuid(job.uid) key = _status_log_key_from_uuid(job.uid)
redis.lpush(key, status.value) redis.lpush(key, status.value)
redis.expire(key, 10) redis.expire(key, 10)
@staticmethod @staticmethod
def log_progress_update(job: Job, progress: int): def log_progress_update(job: Job, progress: int) -> None:
redis = RedisPool().get_connection() redis = RedisPool().get_connection()
key = _progress_log_key_from_uuid(job.uid) key = _progress_log_key_from_uuid(job.uid)
redis.lpush(key, progress) redis.lpush(key, progress)
redis.expire(key, 10) redis.expire(key, 10)
@staticmethod @staticmethod
def status_updates(job: Job) -> typing.List[JobStatus]: def status_updates(job: Job) -> list[JobStatus]:
result = [] result: list[JobStatus] = []
redis = RedisPool().get_connection() redis = RedisPool().get_connection()
key = _status_log_key_from_uuid(job.uid) key = _status_log_key_from_uuid(job.uid)
if not redis.exists(key): if not redis.exists(key):
return [] return []
status_strings = redis.lrange(key, 0, -1) status_strings: list[str] = redis.lrange(key, 0, -1) # type: ignore
for status in status_strings: for status in status_strings:
try: try:
result.append(JobStatus[status]) result.append(JobStatus[status])
except KeyError as e: except KeyError as error:
raise ValueError("impossible job status: " + status) from e raise ValueError("impossible job status: " + status) from error
return result return result
@staticmethod @staticmethod
def progress_updates(job: Job) -> typing.List[int]: def progress_updates(job: Job) -> list[int]:
result = [] result: list[int] = []
redis = RedisPool().get_connection() redis = RedisPool().get_connection()
key = _progress_log_key_from_uuid(job.uid) key = _progress_log_key_from_uuid(job.uid)
if not redis.exists(key): if not redis.exists(key):
return [] return []
progress_strings = redis.lrange(key, 0, -1) progress_strings: list[str] = redis.lrange(key, 0, -1) # type: ignore
for progress in progress_strings: for progress in progress_strings:
try: try:
result.append(int(progress)) result.append(int(progress))
except KeyError as e: except KeyError as error:
raise ValueError("impossible job progress: " + progress) from e raise ValueError("impossible job progress: " + progress) from error
return result return result
@staticmethod @staticmethod
@ -257,19 +257,19 @@ class Jobs:
return False return False
def _redis_key_from_uuid(uuid_string): def _redis_key_from_uuid(uuid_string) -> str:
return "jobs:" + str(uuid_string) return "jobs:" + str(uuid_string)
def _status_log_key_from_uuid(uuid_string): def _status_log_key_from_uuid(uuid_string) -> str:
return STATUS_LOGS_PREFIX + str(uuid_string) return STATUS_LOGS_PREFIX + str(uuid_string)
def _progress_log_key_from_uuid(uuid_string): def _progress_log_key_from_uuid(uuid_string) -> str:
return PROGRESS_LOGS_PREFIX + str(uuid_string) return PROGRESS_LOGS_PREFIX + str(uuid_string)
def _store_job_as_hash(redis, redis_key, model): def _store_job_as_hash(redis, redis_key, model) -> None:
for key, value in model.dict().items(): for key, value in model.dict().items():
if isinstance(value, uuid.UUID): if isinstance(value, uuid.UUID):
value = str(value) value = str(value)
@ -280,7 +280,7 @@ def _store_job_as_hash(redis, redis_key, model):
redis.hset(redis_key, key, str(value)) redis.hset(redis_key, key, str(value))
def _job_from_hash(redis, redis_key): def _job_from_hash(redis, redis_key) -> typing.Optional[Job]:
if redis.exists(redis_key): if redis.exists(redis_key):
job_dict = redis.hgetall(redis_key) job_dict = redis.hgetall(redis_key)
for date in [ for date in [

View File

@ -1,7 +1,7 @@
""" """
Token repository using Redis as backend. Token repository using Redis as backend.
""" """
from typing import Optional from typing import Any, Optional
from datetime import datetime from datetime import datetime
from hashlib import md5 from hashlib import md5
@ -29,15 +29,15 @@ class RedisTokensRepository(AbstractTokensRepository):
@staticmethod @staticmethod
def token_key_for_device(device_name: str): def token_key_for_device(device_name: str):
hash = md5() md5_hash = md5()
hash.update(bytes(device_name, "utf-8")) md5_hash.update(bytes(device_name, "utf-8"))
digest = hash.hexdigest() digest = md5_hash.hexdigest()
return TOKENS_PREFIX + digest return TOKENS_PREFIX + digest
def get_tokens(self) -> list[Token]: def get_tokens(self) -> list[Token]:
"""Get the tokens""" """Get the tokens"""
redis = self.connection redis = self.connection
token_keys = redis.keys(TOKENS_PREFIX + "*") token_keys: list[str] = redis.keys(TOKENS_PREFIX + "*") # type: ignore
tokens = [] tokens = []
for key in token_keys: for key in token_keys:
token = self._token_from_hash(key) token = self._token_from_hash(key)
@ -45,10 +45,10 @@ class RedisTokensRepository(AbstractTokensRepository):
tokens.append(token) tokens.append(token)
return tokens return tokens
def _discover_token_key(self, input_token: Token) -> str: def _discover_token_key(self, input_token: Token) -> Optional[str]:
"""brute-force searching for tokens, for robust deletion""" """brute-force searching for tokens, for robust deletion"""
redis = self.connection redis = self.connection
token_keys = redis.keys(TOKENS_PREFIX + "*") token_keys: list[str] = redis.keys(TOKENS_PREFIX + "*") # type: ignore
for key in token_keys: for key in token_keys:
token = self._token_from_hash(key) token = self._token_from_hash(key)
if token == input_token: if token == input_token:
@ -120,26 +120,26 @@ class RedisTokensRepository(AbstractTokensRepository):
return self._new_device_key_from_hash(NEW_DEVICE_KEY_REDIS_KEY) return self._new_device_key_from_hash(NEW_DEVICE_KEY_REDIS_KEY)
@staticmethod @staticmethod
def _is_date_key(key: str): def _is_date_key(key: str) -> bool:
return key in [ return key in [
"created_at", "created_at",
"expires_at", "expires_at",
] ]
@staticmethod @staticmethod
def _prepare_model_dict(d: dict): def _prepare_model_dict(model_dict: dict[str, Any]) -> None:
date_keys = [key for key in d.keys() if RedisTokensRepository._is_date_key(key)] date_keys = [key for key in model_dict.keys() if RedisTokensRepository._is_date_key(key)]
for date in date_keys: for date in date_keys:
if d[date] != "None": if model_dict[date] != "None":
d[date] = datetime.fromisoformat(d[date]) model_dict[date] = datetime.fromisoformat(model_dict[date])
for key in d.keys(): for key in model_dict.keys():
if d[key] == "None": if model_dict[key] == "None":
d[key] = None model_dict[key] = None
def _model_dict_from_hash(self, redis_key: str) -> Optional[dict]: def _model_dict_from_hash(self, redis_key: str) -> Optional[dict[str, Any]]:
redis = self.connection redis = self.connection
if redis.exists(redis_key): if redis.exists(redis_key):
token_dict = redis.hgetall(redis_key) token_dict: dict[str, Any] = redis.hgetall(redis_key) # type: ignore
RedisTokensRepository._prepare_model_dict(token_dict) RedisTokensRepository._prepare_model_dict(token_dict)
return token_dict return token_dict
return None return None

View File

@ -1,9 +1,9 @@
""" """
Redis pool module for selfprivacy_api Redis pool module for selfprivacy_api
""" """
from os import environ
import redis import redis
from selfprivacy_api.utils.singleton_metaclass import SingletonMetaclass from selfprivacy_api.utils.singleton_metaclass import SingletonMetaclass
from os import environ
REDIS_SOCKET = "/run/redis-sp-api/redis.sock" REDIS_SOCKET = "/run/redis-sp-api/redis.sock"
@ -14,7 +14,7 @@ class RedisPool(metaclass=SingletonMetaclass):
""" """
def __init__(self): def __init__(self):
if "USE_REDIS_PORT" in environ.keys(): if "USE_REDIS_PORT" in environ:
self._pool = redis.ConnectionPool( self._pool = redis.ConnectionPool(
host="127.0.0.1", host="127.0.0.1",
port=int(environ["USE_REDIS_PORT"]), port=int(environ["USE_REDIS_PORT"]),