Fix ws auth
parent
ab9e8d81e5
commit
28c6d983b9
|
@ -1,13 +1,12 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""SelfPrivacy server management API"""
|
"""SelfPrivacy server management API"""
|
||||||
import os
|
from fastapi import FastAPI
|
||||||
from fastapi import FastAPI, Depends, Request, WebSocket, BackgroundTasks
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from strawberry.fastapi import BaseContext, GraphQLRouter
|
from strawberry.fastapi import GraphQLRouter
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from selfprivacy_api.dependencies import get_api_version, get_graphql_context
|
from selfprivacy_api.dependencies import get_api_version
|
||||||
from selfprivacy_api.graphql.schema import schema
|
from selfprivacy_api.graphql.schema import schema
|
||||||
from selfprivacy_api.migrations import run_migrations
|
from selfprivacy_api.migrations import run_migrations
|
||||||
from selfprivacy_api.restic_controller.tasks import init_restic
|
from selfprivacy_api.restic_controller.tasks import init_restic
|
||||||
|
@ -20,9 +19,9 @@ from selfprivacy_api.rest import (
|
||||||
)
|
)
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
graphql_app = GraphQLRouter(
|
graphql_app = GraphQLRouter(
|
||||||
schema,
|
schema,
|
||||||
context_getter=get_graphql_context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
from typing import Optional
|
|
||||||
from strawberry.fastapi import BaseContext
|
|
||||||
from fastapi.security import APIKeyHeader
|
from fastapi.security import APIKeyHeader
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -27,29 +25,6 @@ async def get_token_header(
|
||||||
return TokenHeader(token=token)
|
return TokenHeader(token=token)
|
||||||
|
|
||||||
|
|
||||||
class GraphQlContext(BaseContext):
|
|
||||||
def __init__(self, auth_token: Optional[str] = None):
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self.is_authenticated = auth_token is not None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_graphql_context(
|
|
||||||
token: str = Depends(
|
|
||||||
APIKeyHeader(
|
|
||||||
name="Authorization",
|
|
||||||
auto_error=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
) -> GraphQlContext:
|
|
||||||
if token is None:
|
|
||||||
return GraphQlContext()
|
|
||||||
else:
|
|
||||||
token = token.replace("Bearer ", "")
|
|
||||||
if not is_token_valid(token):
|
|
||||||
return GraphQlContext()
|
|
||||||
return GraphQlContext(auth_token=token)
|
|
||||||
|
|
||||||
|
|
||||||
def get_api_version() -> str:
|
def get_api_version() -> str:
|
||||||
"""Get API version"""
|
"""Get API version"""
|
||||||
return "2.0.0"
|
return "2.0.0"
|
||||||
|
|
|
@ -13,4 +13,8 @@ class IsAuthenticated(BasePermission):
|
||||||
message = "You must be authenticated to access this resource."
|
message = "You must be authenticated to access this resource."
|
||||||
|
|
||||||
def has_permission(self, source: typing.Any, info: Info, **kwargs) -> bool:
|
def has_permission(self, source: typing.Any, info: Info, **kwargs) -> bool:
|
||||||
return info.context.is_authenticated
|
return is_token_valid(
|
||||||
|
info.context["request"]
|
||||||
|
.headers.get("Authorization", "")
|
||||||
|
.replace("Bearer ", "")
|
||||||
|
)
|
||||||
|
|
|
@ -116,7 +116,11 @@ class ApiMutations:
|
||||||
@strawberry.mutation(permission_classes=[IsAuthenticated])
|
@strawberry.mutation(permission_classes=[IsAuthenticated])
|
||||||
def refresh_device_api_token(self, info: Info) -> DeviceApiTokenMutationReturn:
|
def refresh_device_api_token(self, info: Info) -> DeviceApiTokenMutationReturn:
|
||||||
"""Refresh device api token"""
|
"""Refresh device api token"""
|
||||||
token = info.context.auth_token
|
token = (
|
||||||
|
info.context["request"]
|
||||||
|
.headers.get("Authorization", "")
|
||||||
|
.replace("Bearer ", "")
|
||||||
|
)
|
||||||
if token is None:
|
if token is None:
|
||||||
return DeviceApiTokenMutationReturn(
|
return DeviceApiTokenMutationReturn(
|
||||||
success=False,
|
success=False,
|
||||||
|
@ -142,7 +146,11 @@ class ApiMutations:
|
||||||
@strawberry.mutation(permission_classes=[IsAuthenticated])
|
@strawberry.mutation(permission_classes=[IsAuthenticated])
|
||||||
def delete_device_api_token(self, device: str, info: Info) -> GenericMutationReturn:
|
def delete_device_api_token(self, device: str, info: Info) -> GenericMutationReturn:
|
||||||
"""Delete device api token"""
|
"""Delete device api token"""
|
||||||
self_token = info.context.auth_token
|
self_token = (
|
||||||
|
info.context["request"]
|
||||||
|
.headers.get("Authorization", "")
|
||||||
|
.replace("Bearer ", "")
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
delete_api_token(self_token, device)
|
delete_api_token(self_token, device)
|
||||||
except NotFoundException:
|
except NotFoundException:
|
||||||
|
|
|
@ -85,7 +85,11 @@ class Api:
|
||||||
creation_date=device.date,
|
creation_date=device.date,
|
||||||
is_caller=device.is_caller,
|
is_caller=device.is_caller,
|
||||||
)
|
)
|
||||||
for device in get_api_tokens_with_caller_flag(info.context.auth_token)
|
for device in get_api_tokens_with_caller_flag(
|
||||||
|
info.context["request"]
|
||||||
|
.headers.get("Authorization", "")
|
||||||
|
.replace("Bearer ", "")
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
recovery_key: ApiRecoveryKeyStatus = strawberry.field(
|
recovery_key: ApiRecoveryKeyStatus = strawberry.field(
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
"""GraphQL API for SelfPrivacy."""
|
"""GraphQL API for SelfPrivacy."""
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import AsyncGenerator
|
||||||
import strawberry
|
import strawberry
|
||||||
from selfprivacy_api.graphql import IsAuthenticated
|
from selfprivacy_api.graphql import IsAuthenticated
|
||||||
from selfprivacy_api.graphql.mutations.api_mutations import ApiMutations
|
from selfprivacy_api.graphql.mutations.api_mutations import ApiMutations
|
||||||
|
@ -69,7 +71,7 @@ class Mutation(
|
||||||
):
|
):
|
||||||
"""Root schema for mutations"""
|
"""Root schema for mutations"""
|
||||||
|
|
||||||
@strawberry.mutation
|
@strawberry.mutation(permission_classes=[IsAuthenticated])
|
||||||
def test_mutation(self) -> GenericMutationReturn:
|
def test_mutation(self) -> GenericMutationReturn:
|
||||||
"""Test mutation"""
|
"""Test mutation"""
|
||||||
test_job()
|
test_job()
|
||||||
|
@ -82,4 +84,15 @@ class Mutation(
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
schema = strawberry.Schema(query=Query, mutation=Mutation)
|
@strawberry.type
|
||||||
|
class Subscription:
|
||||||
|
"""Root schema for subscriptions"""
|
||||||
|
|
||||||
|
@strawberry.subscription(permission_classes=[IsAuthenticated])
|
||||||
|
async def count(self, target: int = 100) -> AsyncGenerator[int, None]:
|
||||||
|
for i in range(target):
|
||||||
|
yield i
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
|
||||||
|
schema = strawberry.Schema(query=Query, mutation=Mutation, subscription=Subscription)
|
||||||
|
|
Loading…
Reference in New Issue