diff --git a/selfprivacy_api/services/service.py b/selfprivacy_api/services/service.py index 286fab7..e2c7c01 100644 --- a/selfprivacy_api/services/service.py +++ b/selfprivacy_api/services/service.py @@ -10,6 +10,7 @@ from selfprivacy_api.utils.block_devices import BlockDevice from selfprivacy_api.services.generic_size_counter import get_storage_usage from selfprivacy_api.services.owned_path import OwnedPath +from selfprivacy_api.utils.waitloop import wait_until_true class ServiceStatus(Enum): @@ -245,3 +246,32 @@ class Service(ABC): def post_restore(self): pass + + +class StoppedService: + """ + A context manager that stops the service if needed and reactivates it + after you are done if it was active + + Example: + ``` + assert service.get_status() == ServiceStatus.ACTIVE + with StoppedService(service) [as stopped_service]: + assert service.get_status() == ServiceStatus.INACTIVE + ``` + """ + def __init__(self, service: Service): + self.service = service + self.original_status = service.get_status() + + def __enter__(self) -> Service: + self.original_status = self.service.get_status() + if self.original_status != ServiceStatus.INACTIVE: + self.service.stop() + wait_until_true(lambda: self.service.get_status() == ServiceStatus.INACTIVE) + return self.service + + def __exit__(self, type, value, traceback): + if self.original_status in [ServiceStatus.ACTIVATING, ServiceStatus.ACTIVE]: + self.service.start() + wait_until_true(lambda: self.service.get_status() == ServiceStatus.ACTIVE) diff --git a/tests/test_services.py b/tests/test_services.py index 12889c3..b83a7f2 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -9,7 +9,7 @@ from selfprivacy_api.services.owned_path import OwnedPath from selfprivacy_api.services.generic_service_mover import FolderMoveNames from selfprivacy_api.services.test_service import DummyService -from selfprivacy_api.services.service import Service, ServiceStatus +from selfprivacy_api.services.service import Service, ServiceStatus, StoppedService from selfprivacy_api.utils.waitloop import wait_until_true from tests.test_graphql.test_backup import raw_dummy_service @@ -28,6 +28,19 @@ def test_unimplemented_folders_raises(): assert owned_folders is not None +def test_service_stopper(raw_dummy_service): + dummy: Service = raw_dummy_service + dummy.set_delay(0.3) + + assert dummy.get_status() == ServiceStatus.ACTIVE + + with StoppedService(dummy) as stopped_dummy: + assert stopped_dummy.get_status() == ServiceStatus.INACTIVE + assert dummy.get_status() == ServiceStatus.INACTIVE + + assert dummy.get_status() == ServiceStatus.ACTIVE + + def test_delayed_start_stop(raw_dummy_service): dummy = raw_dummy_service dummy.set_delay(0.3)