From f53cc22cbc793b8f9075cd74010a83af4faaadcd Mon Sep 17 00:00:00 2001 From: def Date: Sun, 24 Jul 2022 16:26:13 +0200 Subject: [PATCH] fix import --- .../graphql/common_types/__init__.py | 0 selfprivacy_api/graphql/common_types/user.py | 35 +++++++++++++++++++ selfprivacy_api/utils/__init__.py | 34 ------------------ tests/common.py | 4 +++ tests/test_graphql/test_users.py | 6 ++-- 5 files changed, 42 insertions(+), 37 deletions(-) create mode 100644 selfprivacy_api/graphql/common_types/__init__.py diff --git a/selfprivacy_api/graphql/common_types/__init__.py b/selfprivacy_api/graphql/common_types/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/selfprivacy_api/graphql/common_types/user.py b/selfprivacy_api/graphql/common_types/user.py index 7e1852e..4804c75 100644 --- a/selfprivacy_api/graphql/common_types/user.py +++ b/selfprivacy_api/graphql/common_types/user.py @@ -5,6 +5,8 @@ from selfprivacy_api.graphql.mutations.mutation_interface import ( ) from enum import Enum +from selfprivacy_api.utils import ReadUserData + @strawberry.enum class UserType(Enum): @@ -28,3 +30,36 @@ class UserMutationReturn(MutationReturnInterface): """Return type for user mutation""" user: typing.Optional[User] + +def get_user_by_username(username): + with ReadUserData() as data: + + if username == "root": + if data["ssh"]["rootKeys"] not in data: + data["ssh"]["rootKeys"] = [] + + return User( + user_type=UserType.ROOT, + username="root", + ssh_keys=data["ssh"]["rootKeys"], + ) + elif username == data["username"]: + if "sshKeys" not in data: + data["sshKeys"] = [] + + return User( + user_type=UserType.PRIMARY, + username=username, + ssh_keys=data["sshKeys"], + ) + else: + for user in data["users"]: + if user["username"] == username: + if "sshKeys" not in user: + user["sshKeys"] = [] + + return User( + user_type=UserType.NORMAL, + username=username, + ssh_keys=user["sshKeys"], + ) diff --git a/selfprivacy_api/utils/__init__.py b/selfprivacy_api/utils/__init__.py index 893dd52..054e52c 100644 --- a/selfprivacy_api/utils/__init__.py +++ b/selfprivacy_api/utils/__init__.py @@ -7,8 +7,6 @@ import os import subprocess import portalocker -from selfprivacy_api.graphql.common_types.user import User, UserType - USERDATA_FILE = "/etc/nixos/userdata/userdata.json" TOKENS_FILE = "/etc/nixos/userdata/tokens.json" @@ -177,35 +175,3 @@ def hash_password(password): return hashed_password -def get_user_by_username(username): - with ReadUserData() as data: - - if username == "root": - if data["ssh"]["rootKeys"] not in data: - data["ssh"]["rootKeys"] = [] - - return User( - user_type=UserType.ROOT, - username="root", - ssh_keys=data["ssh"]["rootKeys"], - ) - elif username == data["username"]: - if "sshKeys" not in data: - data["sshKeys"] = [] - - return User( - user_type=UserType.PRIMARY, - username=username, - ssh_keys=data["sshKeys"], - ) - else: - for user in data["users"]: - if user["username"] == username: - if "sshKeys" not in user: - user["sshKeys"] = [] - - return User( - user_type=UserType.NORMAL, - username=username, - ssh_keys=user["sshKeys"], - ) diff --git a/tests/common.py b/tests/common.py index 01975e8..18e065c 100644 --- a/tests/common.py +++ b/tests/common.py @@ -20,5 +20,9 @@ def generate_system_query(query_array): return "query TestSystem {\n system {" + "\n".join(query_array) + "}\n}" +def generate_users_query(query_array): + return "query TestUsers {\n users {" + "\n".join(query_array) + "}\n}" + + def mnemonic_to_hex(mnemonic): return Mnemonic(language="english").to_entropy(mnemonic).hex() diff --git a/tests/test_graphql/test_users.py b/tests/test_graphql/test_users.py index 081f994..76dd961 100644 --- a/tests/test_graphql/test_users.py +++ b/tests/test_graphql/test_users.py @@ -3,7 +3,7 @@ import json import pytest -from tests.common import generate_system_query, read_json, write_json +from tests.common import generate_system_query, generate_users_query, read_json, write_json def read_json(file_path): @@ -129,7 +129,7 @@ def test_graphql_get_users_unauthorized(client, some_users, mock_subprocess_pope response = client.get( "/graphql", json={ - "query": generate_system_query([API_USERS_INFO]), + "query": generate_users_query([API_USERS_INFO]), }, ) assert response.status_code == 200 @@ -140,7 +140,7 @@ def test_graphql_get_some_users(authorized_client, some_users, mock_subprocess_p response = authorized_client.get( "/graphql", json={ - "query": generate_system_query([API_USERS_INFO]), + "query": generate_users_query([API_USERS_INFO]), }, ) assert response.status_code == 200