Fix ws auth

pull/13/head
Inex Code 2022-08-22 23:49:14 +04:00
parent ab9e8d81e5
commit 28c6d983b9
6 changed files with 40 additions and 37 deletions

View File

@ -1,13 +1,12 @@
#!/usr/bin/env python3
"""SelfPrivacy server management API"""
import os
from fastapi import FastAPI, Depends, Request, WebSocket, BackgroundTasks
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from strawberry.fastapi import BaseContext, GraphQLRouter
from strawberry.fastapi import GraphQLRouter
import uvicorn
from selfprivacy_api.dependencies import get_api_version, get_graphql_context
from selfprivacy_api.dependencies import get_api_version
from selfprivacy_api.graphql.schema import schema
from selfprivacy_api.migrations import run_migrations
from selfprivacy_api.restic_controller.tasks import init_restic
@ -20,9 +19,9 @@ from selfprivacy_api.rest import (
)
app = FastAPI()
graphql_app = GraphQLRouter(
schema,
context_getter=get_graphql_context,
)
app.add_middleware(

View File

@ -1,6 +1,4 @@
from fastapi import Depends, FastAPI, HTTPException, status
from typing import Optional
from strawberry.fastapi import BaseContext
from fastapi import Depends, HTTPException, status
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
@ -27,29 +25,6 @@ async def get_token_header(
return TokenHeader(token=token)
class GraphQlContext(BaseContext):
def __init__(self, auth_token: Optional[str] = None):
self.auth_token = auth_token
self.is_authenticated = auth_token is not None
async def get_graphql_context(
token: str = Depends(
APIKeyHeader(
name="Authorization",
auto_error=False,
)
)
) -> GraphQlContext:
if token is None:
return GraphQlContext()
else:
token = token.replace("Bearer ", "")
if not is_token_valid(token):
return GraphQlContext()
return GraphQlContext(auth_token=token)
def get_api_version() -> str:
"""Get API version"""
return "2.0.0"

View File

@ -13,4 +13,8 @@ class IsAuthenticated(BasePermission):
message = "You must be authenticated to access this resource."
def has_permission(self, source: typing.Any, info: Info, **kwargs) -> bool:
return info.context.is_authenticated
return is_token_valid(
info.context["request"]
.headers.get("Authorization", "")
.replace("Bearer ", "")
)

View File

@ -116,7 +116,11 @@ class ApiMutations:
@strawberry.mutation(permission_classes=[IsAuthenticated])
def refresh_device_api_token(self, info: Info) -> DeviceApiTokenMutationReturn:
"""Refresh device api token"""
token = info.context.auth_token
token = (
info.context["request"]
.headers.get("Authorization", "")
.replace("Bearer ", "")
)
if token is None:
return DeviceApiTokenMutationReturn(
success=False,
@ -142,7 +146,11 @@ class ApiMutations:
@strawberry.mutation(permission_classes=[IsAuthenticated])
def delete_device_api_token(self, device: str, info: Info) -> GenericMutationReturn:
"""Delete device api token"""
self_token = info.context.auth_token
self_token = (
info.context["request"]
.headers.get("Authorization", "")
.replace("Bearer ", "")
)
try:
delete_api_token(self_token, device)
except NotFoundException:

View File

@ -85,7 +85,11 @@ class Api:
creation_date=device.date,
is_caller=device.is_caller,
)
for device in get_api_tokens_with_caller_flag(info.context.auth_token)
for device in get_api_tokens_with_caller_flag(
info.context["request"]
.headers.get("Authorization", "")
.replace("Bearer ", "")
)
]
recovery_key: ApiRecoveryKeyStatus = strawberry.field(

View File

@ -1,6 +1,8 @@
"""GraphQL API for SelfPrivacy."""
# pylint: disable=too-few-public-methods
import asyncio
from typing import AsyncGenerator
import strawberry
from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.graphql.mutations.api_mutations import ApiMutations
@ -69,7 +71,7 @@ class Mutation(
):
"""Root schema for mutations"""
@strawberry.mutation
@strawberry.mutation(permission_classes=[IsAuthenticated])
def test_mutation(self) -> GenericMutationReturn:
"""Test mutation"""
test_job()
@ -82,4 +84,15 @@ class Mutation(
pass
schema = strawberry.Schema(query=Query, mutation=Mutation)
@strawberry.type
class Subscription:
"""Root schema for subscriptions"""
@strawberry.subscription(permission_classes=[IsAuthenticated])
async def count(self, target: int = 100) -> AsyncGenerator[int, None]:
for i in range(target):
yield i
await asyncio.sleep(0.5)
schema = strawberry.Schema(query=Query, mutation=Mutation, subscription=Subscription)