Fix ws auth

pull/13/head
Inex Code 2022-08-22 23:49:14 +04:00
parent ab9e8d81e5
commit 28c6d983b9
6 changed files with 40 additions and 37 deletions

View File

@ -1,13 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""SelfPrivacy server management API""" """SelfPrivacy server management API"""
import os from fastapi import FastAPI
from fastapi import FastAPI, Depends, Request, WebSocket, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from strawberry.fastapi import BaseContext, GraphQLRouter from strawberry.fastapi import GraphQLRouter
import uvicorn import uvicorn
from selfprivacy_api.dependencies import get_api_version, get_graphql_context from selfprivacy_api.dependencies import get_api_version
from selfprivacy_api.graphql.schema import schema from selfprivacy_api.graphql.schema import schema
from selfprivacy_api.migrations import run_migrations from selfprivacy_api.migrations import run_migrations
from selfprivacy_api.restic_controller.tasks import init_restic from selfprivacy_api.restic_controller.tasks import init_restic
@ -20,9 +19,9 @@ from selfprivacy_api.rest import (
) )
app = FastAPI() app = FastAPI()
graphql_app = GraphQLRouter( graphql_app = GraphQLRouter(
schema, schema,
context_getter=get_graphql_context,
) )
app.add_middleware( app.add_middleware(

View File

@ -1,6 +1,4 @@
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, HTTPException, status
from typing import Optional
from strawberry.fastapi import BaseContext
from fastapi.security import APIKeyHeader from fastapi.security import APIKeyHeader
from pydantic import BaseModel from pydantic import BaseModel
@ -27,29 +25,6 @@ async def get_token_header(
return TokenHeader(token=token) return TokenHeader(token=token)
class GraphQlContext(BaseContext):
def __init__(self, auth_token: Optional[str] = None):
self.auth_token = auth_token
self.is_authenticated = auth_token is not None
async def get_graphql_context(
token: str = Depends(
APIKeyHeader(
name="Authorization",
auto_error=False,
)
)
) -> GraphQlContext:
if token is None:
return GraphQlContext()
else:
token = token.replace("Bearer ", "")
if not is_token_valid(token):
return GraphQlContext()
return GraphQlContext(auth_token=token)
def get_api_version() -> str: def get_api_version() -> str:
"""Get API version""" """Get API version"""
return "2.0.0" return "2.0.0"

View File

@ -13,4 +13,8 @@ class IsAuthenticated(BasePermission):
message = "You must be authenticated to access this resource." message = "You must be authenticated to access this resource."
def has_permission(self, source: typing.Any, info: Info, **kwargs) -> bool: def has_permission(self, source: typing.Any, info: Info, **kwargs) -> bool:
return info.context.is_authenticated return is_token_valid(
info.context["request"]
.headers.get("Authorization", "")
.replace("Bearer ", "")
)

View File

@ -116,7 +116,11 @@ class ApiMutations:
@strawberry.mutation(permission_classes=[IsAuthenticated]) @strawberry.mutation(permission_classes=[IsAuthenticated])
def refresh_device_api_token(self, info: Info) -> DeviceApiTokenMutationReturn: def refresh_device_api_token(self, info: Info) -> DeviceApiTokenMutationReturn:
"""Refresh device api token""" """Refresh device api token"""
token = info.context.auth_token token = (
info.context["request"]
.headers.get("Authorization", "")
.replace("Bearer ", "")
)
if token is None: if token is None:
return DeviceApiTokenMutationReturn( return DeviceApiTokenMutationReturn(
success=False, success=False,
@ -142,7 +146,11 @@ class ApiMutations:
@strawberry.mutation(permission_classes=[IsAuthenticated]) @strawberry.mutation(permission_classes=[IsAuthenticated])
def delete_device_api_token(self, device: str, info: Info) -> GenericMutationReturn: def delete_device_api_token(self, device: str, info: Info) -> GenericMutationReturn:
"""Delete device api token""" """Delete device api token"""
self_token = info.context.auth_token self_token = (
info.context["request"]
.headers.get("Authorization", "")
.replace("Bearer ", "")
)
try: try:
delete_api_token(self_token, device) delete_api_token(self_token, device)
except NotFoundException: except NotFoundException:

View File

@ -85,7 +85,11 @@ class Api:
creation_date=device.date, creation_date=device.date,
is_caller=device.is_caller, is_caller=device.is_caller,
) )
for device in get_api_tokens_with_caller_flag(info.context.auth_token) for device in get_api_tokens_with_caller_flag(
info.context["request"]
.headers.get("Authorization", "")
.replace("Bearer ", "")
)
] ]
recovery_key: ApiRecoveryKeyStatus = strawberry.field( recovery_key: ApiRecoveryKeyStatus = strawberry.field(

View File

@ -1,6 +1,8 @@
"""GraphQL API for SelfPrivacy.""" """GraphQL API for SelfPrivacy."""
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
import asyncio
from typing import AsyncGenerator
import strawberry import strawberry
from selfprivacy_api.graphql import IsAuthenticated from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.graphql.mutations.api_mutations import ApiMutations from selfprivacy_api.graphql.mutations.api_mutations import ApiMutations
@ -69,7 +71,7 @@ class Mutation(
): ):
"""Root schema for mutations""" """Root schema for mutations"""
@strawberry.mutation @strawberry.mutation(permission_classes=[IsAuthenticated])
def test_mutation(self) -> GenericMutationReturn: def test_mutation(self) -> GenericMutationReturn:
"""Test mutation""" """Test mutation"""
test_job() test_job()
@ -82,4 +84,15 @@ class Mutation(
pass pass
schema = strawberry.Schema(query=Query, mutation=Mutation) @strawberry.type
class Subscription:
"""Root schema for subscriptions"""
@strawberry.subscription(permission_classes=[IsAuthenticated])
async def count(self, target: int = 100) -> AsyncGenerator[int, None]:
for i in range(target):
yield i
await asyncio.sleep(0.5)
schema = strawberry.Schema(query=Query, mutation=Mutation, subscription=Subscription)