diff --git a/selfprivacy_api/app.py b/selfprivacy_api/app.py index e7c8f924..ab412c97 100644 --- a/selfprivacy_api/app.py +++ b/selfprivacy_api/app.py @@ -22,6 +22,8 @@ def create_app(): api = Api(app) app.config["AUTH_TOKEN"] = os.environ.get("AUTH_TOKEN") + if app.config["AUTH_TOKEN"] is None: + raise ValueError("AUTH_TOKEN is not set") app.config["ENABLE_SWAGGER"] = os.environ.get("ENABLE_SWAGGER", "0") # Check bearer token @@ -49,7 +51,7 @@ def create_app(): def spec(): if app.config["ENABLE_SWAGGER"] == "1": swag = swagger(app) - swag["info"]["version"] = "1.0.0" + swag["info"]["version"] = "1.1.0" swag["info"]["title"] = "SelfPrivacy API" swag["info"]["description"] = "SelfPrivacy API" swag["securityDefinitions"] = { diff --git a/selfprivacy_api/resources/common.py b/selfprivacy_api/resources/common.py index fa96a394..a9663aa9 100644 --- a/selfprivacy_api/resources/common.py +++ b/selfprivacy_api/resources/common.py @@ -24,7 +24,7 @@ class ApiVersion(Resource): 401: description: Unauthorized """ - return {"version": "1.0.0"} + return {"version": "1.1.0"} class DecryptDisk(Resource): diff --git a/selfprivacy_api/resources/services/bitwarden.py b/selfprivacy_api/resources/services/bitwarden.py index 5c037c92..412ba8ab 100644 --- a/selfprivacy_api/resources/services/bitwarden.py +++ b/selfprivacy_api/resources/services/bitwarden.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 """Bitwarden management module""" -import json -import portalocker from flask_restful import Resource from selfprivacy_api.resources.services import api +from selfprivacy_api.utils import WriteUserData class EnableBitwarden(Resource): @@ -24,20 +23,10 @@ class EnableBitwarden(Resource): 401: description: Unauthorized """ - with open( - "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_EX) - try: - data = json.load(userdata_file) - if "bitwarden" not in data: - data["bitwarden"] = {} - data["bitwarden"]["enable"] = True - userdata_file.seek(0) - json.dump(data, userdata_file, indent=4) - userdata_file.truncate() - finally: - portalocker.unlock(userdata_file) + with WriteUserData() as data: + if "bitwarden" not in data: + data["bitwarden"] = {} + data["bitwarden"]["enable"] = True return { "status": 0, @@ -62,20 +51,10 @@ class DisableBitwarden(Resource): 401: description: Unauthorized """ - with open( - "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_EX) - try: - data = json.load(userdata_file) - if "bitwarden" not in data: - data["bitwarden"] = {} - data["bitwarden"]["enable"] = False - userdata_file.seek(0) - json.dump(data, userdata_file, indent=4) - userdata_file.truncate() - finally: - portalocker.unlock(userdata_file) + with WriteUserData() as data: + if "bitwarden" not in data: + data["bitwarden"] = {} + data["bitwarden"]["enable"] = False return { "status": 0, diff --git a/selfprivacy_api/resources/services/gitea.py b/selfprivacy_api/resources/services/gitea.py index 4ae0b6a3..bd4b8dec 100644 --- a/selfprivacy_api/resources/services/gitea.py +++ b/selfprivacy_api/resources/services/gitea.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 """Gitea management module""" -import json -import portalocker from flask_restful import Resource from selfprivacy_api.resources.services import api +from selfprivacy_api.utils import WriteUserData class EnableGitea(Resource): @@ -24,20 +23,10 @@ class EnableGitea(Resource): 401: description: Unauthorized """ - with open( - "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_EX) - try: - data = json.load(userdata_file) - if "gitea" not in data: - data["gitea"] = {} - data["gitea"]["enable"] = True - userdata_file.seek(0) - json.dump(data, userdata_file, indent=4) - userdata_file.truncate() - finally: - portalocker.unlock(userdata_file) + with WriteUserData() as data: + if "gitea" not in data: + data["gitea"] = {} + data["gitea"]["enable"] = True return { "status": 0, @@ -62,20 +51,10 @@ class DisableGitea(Resource): 401: description: Unauthorized """ - with open( - "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_EX) - try: - data = json.load(userdata_file) - if "gitea" not in data: - data["gitea"] = {} - data["gitea"]["enable"] = False - userdata_file.seek(0) - json.dump(data, userdata_file, indent=4) - userdata_file.truncate() - finally: - portalocker.unlock(userdata_file) + with WriteUserData() as data: + if "gitea" not in data: + data["gitea"] = {} + data["gitea"]["enable"] = False return { "status": 0, diff --git a/selfprivacy_api/resources/services/nextcloud.py b/selfprivacy_api/resources/services/nextcloud.py index fc0bdbe6..3aa9d06b 100644 --- a/selfprivacy_api/resources/services/nextcloud.py +++ b/selfprivacy_api/resources/services/nextcloud.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 """Nextcloud management module""" -import json -import portalocker from flask_restful import Resource from selfprivacy_api.resources.services import api +from selfprivacy_api.utils import WriteUserData class EnableNextcloud(Resource): @@ -24,20 +23,10 @@ class EnableNextcloud(Resource): 401: description: Unauthorized """ - with open( - "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_EX) - try: - data = json.load(userdata_file) - if "nextcloud" not in data: - data["nextcloud"] = {} - data["nextcloud"]["enable"] = True - userdata_file.seek(0) - json.dump(data, userdata_file, indent=4) - userdata_file.truncate() - finally: - portalocker.unlock(userdata_file) + with WriteUserData() as data: + if "nextcloud" not in data: + data["nextcloud"] = {} + data["nextcloud"]["enable"] = True return { "status": 0, @@ -62,20 +51,10 @@ class DisableNextcloud(Resource): 401: description: Unauthorized """ - with open( - "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_EX) - try: - data = json.load(userdata_file) - if "nextcloud" not in data: - data["nextcloud"] = {} - data["nextcloud"]["enable"] = False - userdata_file.seek(0) - json.dump(data, userdata_file, indent=4) - userdata_file.truncate() - finally: - portalocker.unlock(userdata_file) + with WriteUserData() as data: + if "nextcloud" not in data: + data["nextcloud"] = {} + data["nextcloud"]["enable"] = False return { "status": 0, diff --git a/selfprivacy_api/resources/services/ocserv.py b/selfprivacy_api/resources/services/ocserv.py index 6ef56677..4dc83da4 100644 --- a/selfprivacy_api/resources/services/ocserv.py +++ b/selfprivacy_api/resources/services/ocserv.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 """OpenConnect VPN server management module""" -import json -import portalocker from flask_restful import Resource from selfprivacy_api.resources.services import api +from selfprivacy_api.utils import WriteUserData class EnableOcserv(Resource): @@ -24,20 +23,10 @@ class EnableOcserv(Resource): 401: description: Unauthorized """ - with open( - "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_EX) - try: - data = json.load(userdata_file) - if "ocserv" not in data: - data["ocserv"] = {} - data["ocserv"]["enable"] = True - userdata_file.seek(0) - json.dump(data, userdata_file, indent=4) - userdata_file.truncate() - finally: - portalocker.unlock(userdata_file) + with WriteUserData() as data: + if "ocserv" not in data: + data["ocserv"] = {} + data["ocserv"]["enable"] = True return { "status": 0, @@ -62,20 +51,10 @@ class DisableOcserv(Resource): 401: description: Unauthorized """ - with open( - "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_EX) - try: - data = json.load(userdata_file) - if "ocserv" not in data: - data["ocserv"] = {} - data["ocserv"]["enable"] = False - userdata_file.seek(0) - json.dump(data, userdata_file, indent=4) - userdata_file.truncate() - finally: - portalocker.unlock(userdata_file) + with WriteUserData() as data: + if "ocserv" not in data: + data["ocserv"] = {} + data["ocserv"]["enable"] = False return { "status": 0, diff --git a/selfprivacy_api/resources/services/pleroma.py b/selfprivacy_api/resources/services/pleroma.py index 201a5a65..aaf08f03 100644 --- a/selfprivacy_api/resources/services/pleroma.py +++ b/selfprivacy_api/resources/services/pleroma.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 """Pleroma management module""" -import json -import portalocker from flask_restful import Resource from selfprivacy_api.resources.services import api +from selfprivacy_api.utils import WriteUserData class EnablePleroma(Resource): @@ -24,20 +23,10 @@ class EnablePleroma(Resource): 401: description: Unauthorized """ - with open( - "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_EX) - try: - data = json.load(userdata_file) - if "pleroma" not in data: - data["pleroma"] = {} - data["pleroma"]["enable"] = True - userdata_file.seek(0) - json.dump(data, userdata_file, indent=4) - userdata_file.truncate() - finally: - portalocker.unlock(userdata_file) + with WriteUserData() as data: + if "pleroma" not in data: + data["pleroma"] = {} + data["pleroma"]["enable"] = True return { "status": 0, @@ -62,20 +51,10 @@ class DisablePleroma(Resource): 401: description: Unauthorized """ - with open( - "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_EX) - try: - data = json.load(userdata_file) - if "pleroma" not in data: - data["pleroma"] = {} - data["pleroma"]["enable"] = False - userdata_file.seek(0) - json.dump(data, userdata_file, indent=4) - userdata_file.truncate() - finally: - portalocker.unlock(userdata_file) + with WriteUserData() as data: + if "pleroma" not in data: + data["pleroma"] = {} + data["pleroma"]["enable"] = False return { "status": 0, diff --git a/selfprivacy_api/resources/services/ssh.py b/selfprivacy_api/resources/services/ssh.py index 86ecc90a..d924660e 100644 --- a/selfprivacy_api/resources/services/ssh.py +++ b/selfprivacy_api/resources/services/ssh.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 """SSH management module""" -import json -import portalocker from flask_restful import Resource, reqparse from selfprivacy_api.resources.services import api +from selfprivacy_api.utils import WriteUserData, ReadUserData class EnableSSH(Resource): @@ -24,20 +23,10 @@ class EnableSSH(Resource): 401: description: Unauthorized """ - with open( - "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_EX) - try: - data = json.load(userdata_file) - if "ssh" not in data: - data["ssh"] = {} - data["ssh"]["enable"] = True - userdata_file.seek(0) - json.dump(data, userdata_file, indent=4) - userdata_file.truncate() - finally: - portalocker.unlock(userdata_file) + with WriteUserData() as data: + if "ssh" not in data: + data["ssh"] = {} + data["ssh"]["enable"] = True return { "status": 0, @@ -45,6 +34,82 @@ class EnableSSH(Resource): } +class SSHSettings(Resource): + """Enable/disable SSH""" + + def get(self): + """ + Get current SSH settings + --- + tags: + - SSH + security: + - bearerAuth: [] + responses: + 200: + description: SSH settings + 400: + description: Bad request + """ + with ReadUserData() as data: + if "ssh" not in data: + return {"enable": True, "passwordAuthentication": True} + if "enable" not in data["ssh"]: + data["ssh"]["enable"] = True + if "passwordAuthentication" not in data["ssh"]: + data["ssh"]["passwordAuthentication"] = True + return { + "enable": data["ssh"]["enable"], + "passwordAuthentication": data["ssh"]["passwordAuthentication"], + } + + def put(self): + """ + Change SSH settings + --- + tags: + - SSH + security: + - bearerAuth: [] + parameters: + - name: sshSettings + in: body + required: true + description: SSH settings + schema: + type: object + required: + - enable + - passwordAuthentication + properties: + enable: + type: boolean + passwordAuthentication: + type: boolean + responses: + 200: + description: New settings saved + 400: + description: Bad request + """ + parser = reqparse.RequestParser() + parser.add_argument("enable", type=bool, required=False) + parser.add_argument("passwordAuthentication", type=bool, required=False) + args = parser.parse_args() + enable = args["enable"] + password_authentication = args["passwordAuthentication"] + + with WriteUserData() as data: + if "ssh" not in data: + data["ssh"] = {} + if enable is not None: + data["ssh"]["enable"] = enable + if password_authentication is not None: + data["ssh"]["passwordAuthentication"] = password_authentication + + return "SSH settings changed" + + class WriteSSHKey(Resource): """Write new SSH key""" @@ -89,28 +154,26 @@ class WriteSSHKey(Resource): public_key = args["public_key"] - with open( - "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_EX) - try: - data = json.load(userdata_file) - if "ssh" not in data: - data["ssh"] = {} - if "rootKeys" not in data["ssh"]: - data["ssh"]["rootKeys"] = [] - # Return 409 if key already in array - for key in data["ssh"]["rootKeys"]: - if key == public_key: - return { - "error": "Key already exists", - }, 409 - data["ssh"]["rootKeys"].append(public_key) - userdata_file.seek(0) - json.dump(data, userdata_file, indent=4) - userdata_file.truncate() - finally: - portalocker.unlock(userdata_file) + # Validate SSH public key + # It may be ssh-ed25519 or ssh-rsa + if not public_key.startswith("ssh-ed25519"): + if not public_key.startswith("ssh-rsa"): + return { + "error": "Invalid key type. Only ssh-ed25519 and ssh-rsa are supported.", + }, 400 + + with WriteUserData() as data: + if "ssh" not in data: + data["ssh"] = {} + if "rootKeys" not in data["ssh"]: + data["ssh"]["rootKeys"] = [] + # Return 409 if key already in array + for key in data["ssh"]["rootKeys"]: + if key == public_key: + return { + "error": "Key already exists", + }, 409 + data["ssh"]["rootKeys"].append(public_key) return { "status": 0, @@ -118,5 +181,225 @@ class WriteSSHKey(Resource): }, 201 +class SSHKeys(Resource): + """List SSH keys""" + + def get(self, username): + """ + List SSH keys + --- + tags: + - SSH + security: + - bearerAuth: [] + parameters: + - in: path + name: username + type: string + required: true + description: User to list keys for + responses: + 200: + description: SSH keys + 401: + description: Unauthorized + """ + with ReadUserData() as data: + if username == "root": + if "ssh" not in data: + data["ssh"] = {} + if "rootKeys" not in data["ssh"]: + data["ssh"]["rootKeys"] = [] + return data["ssh"]["rootKeys"] + if username == data["username"]: + if "sshKeys" not in data: + data["sshKeys"] = [] + return data["sshKeys"] + else: + if "users" not in data: + data["users"] = [] + for user in data["users"]: + if user["name"] == username: + if "sshKeys" not in user: + user["sshKeys"] = [] + return user["ssh"]["sshKeys"] + return { + "error": "User not found", + }, 404 + + def post(self, username): + """ + Add SSH key to the user + --- + tags: + - SSH + security: + - bearerAuth: [] + parameters: + - in: body + required: true + name: public_key + schema: + type: object + required: + - public_key + properties: + public_key: + type: string + - in: path + name: username + type: string + required: true + description: User to add keys for + responses: + 201: + description: SSH key added + 401: + description: Unauthorized + 404: + description: User not found + 409: + description: Key already exists + """ + parser = reqparse.RequestParser() + parser.add_argument( + "public_key", type=str, required=True, help="Key cannot be blank!" + ) + args = parser.parse_args() + + if username == "root": + return { + "error": "Use /ssh/key/send to add root keys", + }, 400 + + # Validate SSH public key + # It may be ssh-ed25519 or ssh-rsa + if not args["public_key"].startswith("ssh-ed25519"): + if not args["public_key"].startswith("ssh-rsa"): + return { + "error": "Invalid key type. Only ssh-ed25519 and ssh-rsa are supported.", + }, 400 + + with WriteUserData() as data: + if username == data["username"]: + if "sshKeys" not in data: + data["sshKeys"] = [] + data["sshKeys"].append(args["public_key"]) + return { + "message": "New SSH key successfully written", + }, 201 + + if "users" not in data: + data["users"] = [] + for user in data["users"]: + if user["username"] == username: + if "sshKeys" not in user: + user["sshKeys"] = [] + # Return 409 if key already in array + for key in user["sshKeys"]: + if key == args["public_key"]: + return { + "error": "Key already exists", + }, 409 + user["sshKeys"].append(args["public_key"]) + return { + "message": "New SSH key successfully written", + }, 201 + return { + "error": "User not found", + }, 404 + + def delete(self, username): + """ + Delete SSH key + --- + tags: + - SSH + security: + - bearerAuth: [] + parameters: + - in: body + name: public_key + required: true + description: Key to delete + schema: + type: object + required: + - public_key + properties: + public_key: + type: string + - in: path + name: username + type: string + required: true + description: User to delete keys for + responses: + 200: + description: SSH key deleted + 401: + description: Unauthorized + 404: + description: Key not found + """ + parser = reqparse.RequestParser() + parser.add_argument( + "public_key", type=str, required=True, help="Key cannot be blank!" + ) + args = parser.parse_args() + + with WriteUserData() as data: + if username == "root": + if "ssh" not in data: + data["ssh"] = {} + if "rootKeys" not in data["ssh"]: + data["ssh"]["rootKeys"] = [] + # Return 404 if key not in array + for key in data["ssh"]["rootKeys"]: + if key == args["public_key"]: + data["ssh"]["rootKeys"].remove(key) + return { + "message": "SSH key deleted", + }, 200 + return { + "error": "Key not found", + }, 404 + if username == data["username"]: + if "sshKeys" not in data: + data["sshKeys"] = [] + # Return 404 if key not in array + for key in data["sshKeys"]: + if key == args["public_key"]: + data["sshKeys"].remove(key) + return { + "message": "SSH key deleted", + }, 200 + return { + "error": "Key not found", + }, 404 + if "users" not in data: + data["users"] = [] + for user in data["users"]: + if user["username"] == username: + if "sshKeys" not in user: + user["sshKeys"] = [] + # Return 404 if key not in array + for key in user["sshKeys"]: + if key == args["public_key"]: + user["sshKeys"].remove(key) + return { + "message": "SSH key successfully deleted", + }, 200 + return { + "error": "Key not found", + }, 404 + return { + "error": "User not found", + }, 404 + + api.add_resource(EnableSSH, "/ssh/enable") +api.add_resource(SSHSettings, "/ssh") + api.add_resource(WriteSSHKey, "/ssh/key/send") +api.add_resource(SSHKeys, "/ssh/keys/") diff --git a/selfprivacy_api/resources/system.py b/selfprivacy_api/resources/system.py index 7c3ad778..ddce3be9 100644 --- a/selfprivacy_api/resources/system.py +++ b/selfprivacy_api/resources/system.py @@ -1,13 +1,149 @@ #!/usr/bin/env python3 """System management module""" import subprocess +import pytz from flask import Blueprint -from flask_restful import Resource, Api +from flask_restful import Resource, Api, reqparse + +from selfprivacy_api.utils import WriteUserData, ReadUserData api_system = Blueprint("system", __name__, url_prefix="/system") api = Api(api_system) +class Timezone(Resource): + """Change timezone of NixOS""" + + def get(self): + """ + Get current system timezone + --- + tags: + - System + security: + - bearerAuth: [] + responses: + 200: + description: Timezone + 400: + description: Bad request + """ + with ReadUserData() as data: + if "timezone" not in data: + return "Europe/Uzhgorod" + return data["timezone"] + + def put(self): + """ + Change system timezone + --- + tags: + - System + security: + - bearerAuth: [] + parameters: + - name: timezone + in: body + required: true + description: Timezone to set + schema: + type: object + required: + - timezone + properties: + timezone: + type: string + responses: + 200: + description: Timezone changed + 400: + description: Bad request + """ + parser = reqparse.RequestParser() + parser.add_argument("timezone", type=str, required=True) + timezone = parser.parse_args()["timezone"] + + # Check if timezone is a valid tzdata string + if timezone not in pytz.all_timezones: + return {"error": "Invalid timezone"}, 400 + + with WriteUserData() as data: + data["timezone"] = timezone + return "Timezone changed" + + +class AutoUpgrade(Resource): + """Enable/disable automatic upgrades and reboots""" + + def get(self): + """ + Get current system autoupgrade settings + --- + tags: + - System + security: + - bearerAuth: [] + responses: + 200: + description: Auto-upgrade settings + 400: + description: Bad request + """ + with ReadUserData() as data: + if "autoUpgrade" not in data: + return {"enable": True, "allowReboot": False} + if "enable" not in data["autoUpgrade"]: + data["autoUpgrade"]["enable"] = True + if "allowReboot" not in data["autoUpgrade"]: + data["autoUpgrade"]["allowReboot"] = False + return data["autoUpgrade"] + + def put(self): + """ + Change system auto upgrade settings + --- + tags: + - System + security: + - bearerAuth: [] + parameters: + - name: autoUpgrade + in: body + required: true + description: Auto upgrade settings + schema: + type: object + required: + - enable + - allowReboot + properties: + enable: + type: boolean + allowReboot: + type: boolean + responses: + 200: + description: New settings saved + 400: + description: Bad request + """ + parser = reqparse.RequestParser() + parser.add_argument("enable", type=bool, required=False) + parser.add_argument("allowReboot", type=bool, required=False) + args = parser.parse_args() + enable = args["enable"] + allow_reboot = args["allowReboot"] + + with WriteUserData() as data: + if "autoUpgrade" not in data: + data["autoUpgrade"] = {} + if enable is not None: + data["autoUpgrade"]["enable"] = enable + if allow_reboot is not None: + data["autoUpgrade"]["allowReboot"] = allow_reboot + return "Auto-upgrade settings changed" + + class RebuildSystem(Resource): """Rebuild NixOS""" @@ -145,6 +281,8 @@ class PythonVersion(Resource): return subprocess.check_output(["python", "-V"]).decode("utf-8").strip() +api.add_resource(Timezone, "/configuration/timezone") +api.add_resource(AutoUpgrade, "/configuration/autoUpgrade") api.add_resource(RebuildSystem, "/configuration/apply") api.add_resource(RollbackSystem, "/configuration/rollback") api.add_resource(UpgradeSystem, "/configuration/upgrade") diff --git a/selfprivacy_api/resources/users.py b/selfprivacy_api/resources/users.py index 2f373bce..057a5e30 100644 --- a/selfprivacy_api/resources/users.py +++ b/selfprivacy_api/resources/users.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 """Users management module""" import subprocess -import json import re -import portalocker from flask_restful import Resource, reqparse +from selfprivacy_api.utils import WriteUserData, ReadUserData + class Users(Resource): """Users management""" @@ -24,17 +24,10 @@ class Users(Resource): 401: description: Unauthorized """ - with open( - "/etc/nixos/userdata/userdata.json", "r", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_SH) - try: - data = json.load(userdata_file) - users = [] - for user in data["users"]: - users.append(user["username"]) - finally: - portalocker.unlock(userdata_file) + with ReadUserData() as data: + users = [] + for user in data["users"]: + users.append(user["username"]) return users def post(self): @@ -97,32 +90,21 @@ class Users(Resource): if len(args["username"]) > 32: return {"error": "username must be less than 32 characters"}, 400 - with open( - "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_EX) - try: - data = json.load(userdata_file) + with WriteUserData() as data: + if "users" not in data: + data["users"] = [] - if "users" not in data: - data["users"] = [] + # Return 400 if user already exists + for user in data["users"]: + if user["username"] == args["username"]: + return {"error": "User already exists"}, 409 - # Return 400 if user already exists - for user in data["users"]: - if user["username"] == args["username"]: - return {"error": "User already exists"}, 409 - - data["users"].append( - { - "username": args["username"], - "hashedPassword": hashed_password, - } - ) - userdata_file.seek(0) - json.dump(data, userdata_file, indent=4) - userdata_file.truncate() - finally: - portalocker.unlock(userdata_file) + data["users"].append( + { + "username": args["username"], + "hashedPassword": hashed_password, + } + ) return {"result": 0, "username": args["username"]}, 201 @@ -154,29 +136,18 @@ class User(Resource): 404: description: User not found """ - with open( - "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" - ) as userdata_file: - portalocker.lock(userdata_file, portalocker.LOCK_EX) - try: - data = json.load(userdata_file) - # Return 400 if username is not provided - if username is None: - return {"error": "username is required"}, 400 - if username == data["username"]: - return {"error": "Cannot delete root user"}, 400 - # Return 400 if user does not exist - for user in data["users"]: - if user["username"] == username: - data["users"].remove(user) - break - else: - return {"error": "User does not exist"}, 404 - - userdata_file.seek(0) - json.dump(data, userdata_file, indent=4) - userdata_file.truncate() - finally: - portalocker.unlock(userdata_file) + with WriteUserData() as data: + # Return 400 if username is not provided + if username is None: + return {"error": "username is required"}, 400 + if username == data["username"]: + return {"error": "Cannot delete root user"}, 400 + # Return 400 if user does not exist + for user in data["users"]: + if user["username"] == username: + data["users"].remove(user) + break + else: + return {"error": "User does not exist"}, 404 return {"result": 0, "username": username} diff --git a/selfprivacy_api/utils.py b/selfprivacy_api/utils.py index b7ef2a82..8a8006c0 100644 --- a/selfprivacy_api/utils.py +++ b/selfprivacy_api/utils.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 """Various utility functions""" +import json +import portalocker def get_domain(): @@ -7,3 +9,43 @@ def get_domain(): with open("/var/domain", "r", encoding="utf-8") as domain_file: domain = domain_file.readline().rstrip() return domain + + +class WriteUserData(object): + """Write userdata.json with lock""" + + def __init__(self): + self.userdata_file = open( + "/etc/nixos/userdata/userdata.json", "r+", encoding="utf-8" + ) + portalocker.lock(self.userdata_file, portalocker.LOCK_EX) + self.data = json.load(self.userdata_file) + + def __enter__(self): + return self.data + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + self.userdata_file.seek(0) + json.dump(self.data, self.userdata_file, indent=4) + self.userdata_file.truncate() + portalocker.unlock(self.userdata_file) + self.userdata_file.close() + + +class ReadUserData(object): + """Read userdata.json with lock""" + + def __init__(self): + self.userdata_file = open( + "/etc/nixos/userdata/userdata.json", "r", encoding="utf-8" + ) + portalocker.lock(self.userdata_file, portalocker.LOCK_SH) + self.data = json.load(self.userdata_file) + + def __enter__(self): + return self.data + + def __exit__(self, *args): + portalocker.unlock(self.userdata_file) + self.userdata_file.close()