big refactoring

pull/12/head
def 2022-07-30 23:02:19 +02:00
parent 973af08523
commit 4de85acf60
8 changed files with 235 additions and 296 deletions

View File

@ -5,7 +5,7 @@ from selfprivacy_api.graphql.mutations.mutation_interface import (
) )
from enum import Enum from enum import Enum
from selfprivacy_api.utils import ReadUserData from selfprivacy_api.utils import ReadUserData, ensure_ssh_and_users_fields_exist
@strawberry.enum @strawberry.enum
@ -34,40 +34,29 @@ class UserMutationReturn(MutationReturnInterface):
def get_user_by_username(username: str) -> typing.Optional[User]: def get_user_by_username(username: str) -> typing.Optional[User]:
with ReadUserData() as data: with ReadUserData() as data:
ensure_ssh_and_users_fields_exist(data)
if username == "root": if username == "root":
if "ssh" not in data:
data["ssh"] = []
elif data["ssh"].get("rootKeys") is None:
data["ssh"]["rootKeys"] = []
return User( return User(
user_type=UserType.ROOT, user_type=UserType.ROOT,
username="root", username="root",
ssh_keys=data["ssh"]["rootKeys"], ssh_keys=data["ssh"]["rootKeys"],
) )
elif username == data["username"]:
if "sshKeys" not in data:
data["sshKeys"] = []
if username == data["username"]:
return User( return User(
user_type=UserType.PRIMARY, user_type=UserType.PRIMARY,
username=username, username=username,
ssh_keys=data["sshKeys"], ssh_keys=data["sshKeys"],
) )
else:
if "users" not in data:
data["users"] = []
for user in data["users"]: for user in data["users"]:
if user["username"] == username: if user["username"] == username:
if "sshKeys" not in user:
user["sshKeys"] = [] return User(
user_type=UserType.NORMAL,
username=username,
ssh_keys=user["sshKeys"],
)
return User(
user_type=UserType.NORMAL,
username=username,
ssh_keys=user["sshKeys"],
)
return None return None

View File

@ -4,17 +4,16 @@
import strawberry import strawberry
from selfprivacy_api.graphql.mutations.ssh_utils import (
create_ssh_key,
delete_ssh_key,
)
from selfprivacy_api.graphql import IsAuthenticated from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.graphql.common_types.user import ( from selfprivacy_api.graphql.common_types.user import (
UserMutationReturn, UserMutationReturn,
get_user_by_username, get_user_by_username,
) )
from selfprivacy_api.utils import (
WriteUserData,
validate_ssh_public_key,
)
@strawberry.input @strawberry.input
class SshMutationInput: class SshMutationInput:
@ -32,145 +31,24 @@ class SshMutations:
def create_ssh(self, ssh_input: SshMutationInput) -> UserMutationReturn: def create_ssh(self, ssh_input: SshMutationInput) -> UserMutationReturn:
"""Create a new ssh""" """Create a new ssh"""
if not validate_ssh_public_key(ssh_input.ssh_key): success, message, code = create_ssh_key(ssh_input.username, ssh_input.ssh_key)
return UserMutationReturn(
success=False,
message="Invalid key type. Only ssh-ed25519 and ssh-rsa are supported",
code=400,
user=get_user_by_username(ssh_input.username),
)
with WriteUserData() as data: return UserMutationReturn(
success=success,
if ssh_input.username == data["username"]: message=message,
if "sshKeys" not in data: code=code,
data["sshKeys"] = [] user=get_user_by_username(ssh_input.username),
# Return 409 if key already in array )
for key in data["sshKeys"]:
if key == ssh_input.ssh_key:
return UserMutationReturn(
success=False,
message="Key already exists",
code=409,
user=get_user_by_username(ssh_input.username),
)
data["sshKeys"].append(ssh_input.ssh_key)
return UserMutationReturn(
success=True,
message="New SSH key successfully written",
code=201,
user=get_user_by_username(ssh_input.username),
)
if "users" not in data:
data["users"] = []
for user in data["users"]:
if user["username"] == ssh_input.username:
if "sshKeys" not in user:
user["sshKeys"] = []
# Return 409 if key already in array
for key in user["sshKeys"]:
if key == ssh_input.ssh_key:
return UserMutationReturn(
success=False,
message="Key already exists",
code=409,
user=get_user_by_username(ssh_input.username),
)
user["sshKeys"].append(ssh_input.ssh_key)
return UserMutationReturn(
success=True,
message="New SSH key successfully written",
code=201,
user=get_user_by_username(ssh_input.username),
)
return UserMutationReturn(
success=False,
message="User not found",
code=404,
user=None,
)
@strawberry.mutation(permission_classes=[IsAuthenticated]) @strawberry.mutation(permission_classes=[IsAuthenticated])
def delete_ssh(self, ssh_input: SshMutationInput) -> UserMutationReturn: def delete_ssh(self, ssh_input: SshMutationInput) -> UserMutationReturn:
"""Delete ssh key from user""" """Delete ssh key from user"""
with WriteUserData() as data: success, message, code = delete_ssh_key(ssh_input.username, ssh_input.ssh_key)
if ssh_input.username == "root":
if "ssh" not in data:
data["ssh"] = {}
if "rootKeys" not in data["ssh"]:
data["ssh"]["rootKeys"] = []
# Return 404 if key not in array
for key in data["ssh"]["rootKeys"]:
if key == ssh_input.ssh_key:
data["ssh"]["rootKeys"].remove(key)
return UserMutationReturn(
success=True,
message="SSH key deleted",
code=200,
user=get_user_by_username(ssh_input.username),
)
return UserMutationReturn(
success=False,
message="Key not found",
code=404,
user=get_user_by_username(ssh_input.username),
)
if ssh_input.username == data["username"]:
if "sshKeys" not in data:
data["sshKeys"] = []
# Return 404 if key not in array
for key in data["sshKeys"]:
if key == ssh_input.ssh_key:
data["sshKeys"].remove(key)
return UserMutationReturn(
success=True,
message="SSH key deleted",
code=200,
user=get_user_by_username(ssh_input.username),
)
return UserMutationReturn(
success=False,
message="Key not found",
code=404,
user=get_user_by_username(ssh_input.username),
)
if "users" not in data:
data["users"] = []
for user in data["users"]:
if user["username"] == ssh_input.username:
if "sshKeys" not in user:
user["sshKeys"] = []
# Return 404 if key not in array
for key in user["sshKeys"]:
if key == ssh_input.ssh_key:
user["sshKeys"].remove(key)
return UserMutationReturn(
success=True,
message="SSH key deleted",
code=200,
user=get_user_by_username(ssh_input.username),
)
return UserMutationReturn(
success=False,
message="Key not found",
code=404,
user=get_user_by_username(ssh_input.username),
)
return UserMutationReturn( return UserMutationReturn(
success=False, success=success,
message="User not found", message=message,
code=404, code=code,
user=None, user=get_user_by_username(ssh_input.username),
) )

View File

@ -0,0 +1,74 @@
from selfprivacy_api.utils import (
WriteUserData,
ensure_ssh_and_users_fields_exist,
validate_ssh_public_key,
)
def create_ssh_key(username, ssh_key):
"""Create a new ssh key"""
if not validate_ssh_public_key(ssh_key):
return (
False,
"Invalid key type. Only ssh-ed25519 and ssh-rsa are supported",
400,
)
with WriteUserData() as data:
ensure_ssh_and_users_fields_exist(data)
if username == data["username"]:
if ssh_key in data["sshKeys"]:
return False, "Key already exists", 409
data["sshKeys"].append(ssh_key)
return True, "New SSH key successfully written", 201
if username == "root":
if ssh_key in data["ssh"]["rootKeys"]:
return False, "Key already exists", 409
data["ssh"]["rootKeys"].append(ssh_key)
return True, "New SSH key successfully written", 201
for user in data["users"]:
if user["username"] == username:
if ssh_key in user["sshKeys"]:
return False, "Key already exists", 409
user["sshKeys"].append(ssh_key)
return True, "New SSH key successfully written", 201
return False, "User not found", 404
def delete_ssh_key(username, ssh_key):
"""Delete a ssh key"""
with WriteUserData() as data:
ensure_ssh_and_users_fields_exist(data)
if username == "root":
if ssh_key in data["ssh"]["rootKeys"]:
data["ssh"]["rootKeys"].remove(ssh_key)
return True, "SSH key deleted", 200
return False, "Key not found", 404
if username == data["username"]:
if ssh_key in data["sshKeys"]:
data["sshKeys"].remove(ssh_key)
return True, "SSH key deleted", 200
return False, "Key not found", 404
for user in data["users"]:
if user["username"] == username:
if ssh_key in user["sshKeys"]:
user["sshKeys"].remove(ssh_key)
return True, "SSH key deleted", 200
return False, "Key not found", 404
return False, "User not found", 404

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Users management module""" """Users management module"""
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
import re
import strawberry import strawberry
from selfprivacy_api.graphql import IsAuthenticated from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.graphql.common_types.user import ( from selfprivacy_api.graphql.common_types.user import (
@ -11,12 +10,11 @@ from selfprivacy_api.graphql.common_types.user import (
from selfprivacy_api.graphql.mutations.mutation_interface import ( from selfprivacy_api.graphql.mutations.mutation_interface import (
GenericMutationReturn, GenericMutationReturn,
) )
from selfprivacy_api.utils import ( from selfprivacy_api.graphql.mutations.users_utils import (
WriteUserData, create_user_util,
ReadUserData, delete_user_util,
is_username_forbidden, update_user_util,
) )
from selfprivacy_api.utils import hash_password
@strawberry.input @strawberry.input
@ -33,140 +31,35 @@ class UserMutations:
@strawberry.mutation(permission_classes=[IsAuthenticated]) @strawberry.mutation(permission_classes=[IsAuthenticated])
def create_user(self, user: UserMutationInput) -> UserMutationReturn: def create_user(self, user: UserMutationInput) -> UserMutationReturn:
"""Create a new user"""
# Check if password is null or none success, message, code = create_user_util(user.username, user.password)
if user.password is None or user.password == "":
return UserMutationReturn(
success=False,
message="Password is none or null",
code=400,
user=None,
)
hashed_password = hash_password(user.password)
# Check if username is forbidden
if is_username_forbidden(user.username):
return UserMutationReturn(
success=False,
message="Username is forbidden",
code=409,
user=None,
)
# Check is username passes regex
if not re.match(r"^[a-z_][a-z0-9_]+$", user.username):
return UserMutationReturn(
success=False,
message="Username must be alphanumeric",
code=400,
user=None,
)
# Check if username less than 32 characters
if len(user.username) >= 32:
return UserMutationReturn(
success=False,
message="Username must be less than 32 characters",
code=400,
user=None,
)
with ReadUserData() as data:
if "users" not in data:
data["users"] = []
# Return 409 if user already exists
if data["username"] == user.username:
return UserMutationReturn(
success=False,
message="User already exists",
code=409,
user=None,
)
for data_user in data["users"]:
if data_user["username"] == user.username:
return UserMutationReturn(
success=False,
message="User already exists",
code=409,
user=None,
)
with WriteUserData() as data:
if "users" not in data:
data["users"] = []
data["users"].append(
{
"username": user.username,
"hashedPassword": hashed_password,
}
)
return UserMutationReturn( return UserMutationReturn(
success=True, success=success,
message="User was successfully created!", message=message,
code=201, code=code,
user=get_user_by_username(user.username), user=get_user_by_username(user.username),
) )
@strawberry.mutation(permission_classes=[IsAuthenticated]) @strawberry.mutation(permission_classes=[IsAuthenticated])
def delete_user(self, username: str) -> GenericMutationReturn: def delete_user(self, username: str) -> GenericMutationReturn:
with WriteUserData() as data: success, message, code = delete_user_util(username)
if username == data["username"] or username == "root":
return GenericMutationReturn(
success=False,
message="Cannot delete main or root user",
code=400,
)
# Return 404 if user does not exist
for data_user in data["users"]:
if data_user["username"] == username:
data["users"].remove(data_user)
break
else:
return GenericMutationReturn(
success=False,
message="User does not exist",
code=404,
)
return GenericMutationReturn( return GenericMutationReturn(
success=True, success=success,
message="User was deleted", message=message,
code=200, code=code,
) )
@strawberry.mutation(permission_classes=[IsAuthenticated]) @strawberry.mutation(permission_classes=[IsAuthenticated])
def update_user(self, user: UserMutationInput) -> UserMutationReturn: def update_user(self, user: UserMutationInput) -> UserMutationReturn:
"""Update user mutation""" """Update user mutation"""
hashed_password = hash_password(user.password)
with WriteUserData() as data: success, message, code = update_user_util(user.username, user.password)
if user.username == data["username"]:
data["hashedMasterPassword"] = hashed_password
# Return 404 if user does not exist
else:
for data_user in data["users"]:
if data_user["username"] == user.username:
data_user["hashedPassword"] = hashed_password
break
else:
return UserMutationReturn(
success=False,
message="User does not exist",
code=404,
user=None,
)
return UserMutationReturn( return UserMutationReturn(
success=True, success=success,
message="User was successfully updated", message=message,
code=200, code=code,
user=get_user_by_username(user.username), user=get_user_by_username(user.username),
) )

View File

@ -0,0 +1,93 @@
import re
from selfprivacy_api.utils import (
WriteUserData,
ReadUserData,
ensure_ssh_and_users_fields_exist,
is_username_forbidden,
)
from selfprivacy_api.utils import hash_password
def create_user_util(username, password):
"""Create a new user"""
# Check if password is null or none
if password == "":
return False, "Password is null", 400
# Check if username is forbidden
if is_username_forbidden(username):
return False, "Username is forbidden", 409
# Check is username passes regex
if not re.match(r"^[a-z_][a-z0-9_]+$", username):
return False, "Username must be alphanumeric", 400
# Check if username less than 32 characters
if len(username) >= 32:
return False, "Username must be less than 32 characters", 400
with ReadUserData() as data:
ensure_ssh_and_users_fields_exist(data)
# Return 409 if user already exists
if data["username"] == username:
return False, "User already exists", 409
for data_user in data["users"]:
if data_user["username"] == username:
return False, "User already exists", 409
hashed_password = hash_password(password)
with WriteUserData() as data:
ensure_ssh_and_users_fields_exist(data)
data["users"].append(
{
"username": username,
"hashedPassword": hashed_password,
}
)
return True, "User was successfully created!", 201
def delete_user_util(username):
with WriteUserData() as data:
if username == data["username"] or username == "root":
return False, "Cannot delete main or root user", 400
# Return 404 if user does not exist
for data_user in data["users"]:
if data_user["username"] == username:
data["users"].remove(data_user)
break
else:
return False, "User does not exist", 404
return True, "User was deleted", 200
def update_user_util(username, password):
# Check if password is null or none
if password == "":
return False, "Password is null", 400
hashed_password = hash_password(password)
with WriteUserData() as data:
if username == data["username"]:
data["hashedMasterPassword"] = hashed_password
# Return 404 if user does not exist
else:
for data_user in data["users"]:
if data_user["username"] == username:
data_user["hashedPassword"] = hashed_password
break
else:
return False, "User does not exist", 404
return True, "User was successfully updated", 200

View File

@ -173,3 +173,17 @@ def hash_password(password):
hashed_password = hashed_password.decode("ascii") hashed_password = hashed_password.decode("ascii")
hashed_password = hashed_password.rstrip() hashed_password = hashed_password.rstrip()
return hashed_password return hashed_password
def ensure_ssh_and_users_fields_exist(data):
if "ssh" not in data:
data["ssh"] = []
elif data["ssh"].get("rootKeys") is None:
data["ssh"]["rootKeys"] = []
if "sshKeys" not in data:
data["sshKeys"] = []
if "users" not in data:
data["users"] = []

View File

@ -92,13 +92,13 @@ def test_graphql_add_ssh(authorized_client, some_users, mock_subprocess_popen):
assert response.status_code == 200 assert response.status_code == 200
assert response.json.get("data") is not None assert response.json.get("data") is not None
assert response.json["data"]["createSsh"]["code"] == 200 assert response.json["data"]["createSsh"]["code"] == 201
assert response.json["data"]["createSsh"]["message"] is not None assert response.json["data"]["createSsh"]["message"] is not None
assert response.json["data"]["createSsh"]["success"] is True assert response.json["data"]["createSsh"]["success"] is True
assert response.json["data"]["createSsh"]["user"]["username"] == "user1" assert response.json["data"]["createSsh"]["user"]["username"] == "user1"
assert response.json["data"]["createSsh"]["user"]["sshKeys"] == [ assert response.json["data"]["createSsh"]["user"]["sshKeys"] == [
"ssh-rsa KEY tester@pc", "ssh-rsa KEY user1@pc",
"ssh-rsa KEY test_key@pc", "ssh-rsa KEY test_key@pc",
] ]
@ -119,7 +119,7 @@ def test_graphql_add_root_ssh(authorized_client, some_users, mock_subprocess_pop
assert response.status_code == 200 assert response.status_code == 200
assert response.json.get("data") is not None assert response.json.get("data") is not None
assert response.json["data"]["createSsh"]["code"] == 200 assert response.json["data"]["createSsh"]["code"] == 201
assert response.json["data"]["createSsh"]["message"] is not None assert response.json["data"]["createSsh"]["message"] is not None
assert response.json["data"]["createSsh"]["success"] is True assert response.json["data"]["createSsh"]["success"] is True
@ -146,13 +146,13 @@ def test_graphql_add_main_ssh(authorized_client, some_users, mock_subprocess_pop
assert response.status_code == 200 assert response.status_code == 200
assert response.json.get("data") is not None assert response.json.get("data") is not None
assert response.json["data"]["createSsh"]["code"] == 200 assert response.json["data"]["createSsh"]["code"] == 201
assert response.json["data"]["createSsh"]["message"] is not None assert response.json["data"]["createSsh"]["message"] is not None
assert response.json["data"]["createSsh"]["success"] is True assert response.json["data"]["createSsh"]["success"] is True
assert response.json["data"]["createSsh"]["user"]["username"] == "tester" assert response.json["data"]["createSsh"]["user"]["username"] == "tester"
assert response.json["data"]["createSsh"]["user"]["sshKeys"] == [ assert response.json["data"]["createSsh"]["user"]["sshKeys"] == [
"ssh-rsa KEY tester@pc", "ssh-rsa KEY test@pc",
"ssh-rsa KEY test_key@pc", "ssh-rsa KEY test_key@pc",
] ]
@ -252,9 +252,7 @@ def test_graphql_dell_ssh(authorized_client, some_users, mock_subprocess_popen):
assert response.json["data"]["deleteSsh"]["success"] is True assert response.json["data"]["deleteSsh"]["success"] is True
assert response.json["data"]["deleteSsh"]["user"]["username"] == "user1" assert response.json["data"]["deleteSsh"]["user"]["username"] == "user1"
assert response.json["data"]["deleteSsh"]["user"]["sshKeys"] == [ assert response.json["data"]["deleteSsh"]["user"]["sshKeys"] == []
"ssh-rsa KEY user1@pc"
]
def test_graphql_dell_root_ssh(authorized_client, some_users, mock_subprocess_popen): def test_graphql_dell_root_ssh(authorized_client, some_users, mock_subprocess_popen):
@ -289,7 +287,7 @@ def test_graphql_dell_main_ssh(authorized_client, some_users, mock_subprocess_po
"variables": { "variables": {
"sshInput": { "sshInput": {
"username": "tester", "username": "tester",
"sshKey": "ssh-rsa KEY tester@pc", "sshKey": "ssh-rsa KEY test@pc",
}, },
}, },
}, },

View File

@ -8,7 +8,6 @@ from tests.common import (
) )
invalid_usernames = [ invalid_usernames = [
"root",
"messagebus", "messagebus",
"postfix", "postfix",
"polkituser", "polkituser",
@ -468,7 +467,8 @@ def test_graphql_add_existing_user(authorized_client, one_user, mock_subprocess_
assert response.json["data"]["createUser"]["code"] == 409 assert response.json["data"]["createUser"]["code"] == 409
assert response.json["data"]["createUser"]["success"] is False assert response.json["data"]["createUser"]["success"] is False
assert response.json["data"]["createUser"]["user"] is None assert response.json["data"]["createUser"]["user"]["username"] == "user1"
assert response.json["data"]["createUser"]["user"]["sshKeys"][0] == "ssh-rsa KEY user1@pc"
def test_graphql_add_main_user(authorized_client, one_user, mock_subprocess_popen): def test_graphql_add_main_user(authorized_client, one_user, mock_subprocess_popen):
@ -491,8 +491,8 @@ def test_graphql_add_main_user(authorized_client, one_user, mock_subprocess_pope
assert response.json["data"]["createUser"]["code"] == 409 assert response.json["data"]["createUser"]["code"] == 409
assert response.json["data"]["createUser"]["success"] is False assert response.json["data"]["createUser"]["success"] is False
assert response.json["data"]["createUser"]["user"] is None assert response.json["data"]["createUser"]["user"]["username"] == "tester"
assert response.json["data"]["createUser"]["user"]["sshKeys"][0] == "ssh-rsa KEY test@pc"
def test_graphql_add_long_username(authorized_client, one_user, mock_subprocess_popen): def test_graphql_add_long_username(authorized_client, one_user, mock_subprocess_popen):
response = authorized_client.post( response = authorized_client.post(