144 lines
4.3 KiB
Python
144 lines
4.3 KiB
Python
|
import asyncio
|
||
|
import functools
|
||
|
import inspect
|
||
|
import typing
|
||
|
|
||
|
from starlette.exceptions import HTTPException
|
||
|
from starlette.requests import HTTPConnection, Request
|
||
|
from starlette.responses import RedirectResponse, Response
|
||
|
from starlette.websockets import WebSocket
|
||
|
|
||
|
|
||
|
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
    """Report whether every scope in ``scopes`` appears in ``conn.auth.scopes``.

    An empty ``scopes`` sequence is trivially satisfied (and ``conn.auth`` is
    then never touched). Checking stops at the first missing scope.
    """
    # Generator keeps evaluation lazy: ``conn.auth.scopes`` is read once per
    # checked scope, exactly as an explicit loop would.
    return all(scope in conn.auth.scopes for scope in scopes)
|
||
|
|
||
|
|
||
|
def requires(
    scopes: typing.Union[str, typing.Sequence[str]],
    status_code: int = 403,
    redirect: typing.Optional[str] = None,
) -> typing.Callable:
    """Build a decorator that guards an endpoint behind the given scopes.

    The decorated function must have a parameter named ``request`` or
    ``websocket``. On every call the connection is read from that argument
    (by keyword if present, otherwise by position) and checked with
    ``has_required_scope``.

    Args:
        scopes: One scope name, or a sequence of scope names, all of which
            must be present in ``conn.auth.scopes``.
        status_code: Status code of the ``HTTPException`` raised when an HTTP
            request is missing a scope and ``redirect`` is ``None``.
        redirect: Optional route name. When given, an HTTP request missing a
            scope receives a 303 ``RedirectResponse`` to
            ``request.url_for(redirect)`` instead of an exception.

    Returns:
        A decorator. Websocket endpoints get an async wrapper that closes the
        socket when a scope is missing. Coroutine functions get an async
        wrapper, and anything else gets a sync wrapper.

    Raises:
        Exception: From the returned decorator, if the function has neither
            a ``request`` nor a ``websocket`` parameter.
    """
    # A bare string is one scope, not an iterable of characters.
    scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)

    def decorator(func: typing.Callable) -> typing.Callable:
        sig = inspect.signature(func)
        for idx, parameter in enumerate(sig.parameters.values()):
            if parameter.name in ("request", "websocket"):
                type_ = parameter.name
                break
        else:
            raise Exception(
                f'No "request" or "websocket" argument on function "{func}"'
            )

        def _connection(
            args: typing.Tuple[typing.Any, ...], kwargs: typing.Dict[str, typing.Any]
        ) -> typing.Any:
            """Fetch the request/websocket argument by keyword, else by position."""
            # The default must not be evaluated eagerly: the former
            # ``kwargs.get(name, args[idx])`` raised IndexError whenever the
            # connection was passed by keyword with fewer than idx + 1
            # positional arguments. A missing connection yields None, which
            # the isinstance assertion in each wrapper then rejects.
            if type_ in kwargs:
                return kwargs[type_]
            return args[idx] if idx < len(args) else None

        if type_ == "websocket":
            # Websocket endpoints are always async. An unauthorized socket is
            # closed instead of being handed to the endpoint.
            @functools.wraps(func)
            async def websocket_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> None:
                websocket = _connection(args, kwargs)
                assert isinstance(websocket, WebSocket)

                if not has_required_scope(websocket, scopes_list):
                    await websocket.close()
                else:
                    await func(*args, **kwargs)

            return websocket_wrapper

        elif asyncio.iscoroutinefunction(func):
            # Async request/response endpoints.
            @functools.wraps(func)
            async def async_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> Response:
                request = _connection(args, kwargs)
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    if redirect is not None:
                        return RedirectResponse(
                            url=request.url_for(redirect), status_code=303
                        )
                    raise HTTPException(status_code=status_code)
                return await func(*args, **kwargs)

            return async_wrapper

        else:
            # Sync request/response endpoints.
            @functools.wraps(func)
            def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
                request = _connection(args, kwargs)
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    if redirect is not None:
                        return RedirectResponse(
                            url=request.url_for(redirect), status_code=303
                        )
                    raise HTTPException(status_code=status_code)
                return func(*args, **kwargs)

            return sync_wrapper

    return decorator
|
||
|
|
||
|
|
||
|
class AuthenticationError(Exception):
    """Exception type for authentication failures; adds no state of its own.

    NOTE(review): presumably raised by ``AuthenticationBackend`` subclasses;
    nothing in this module raises it, so confirm against callers.
    """
|
||
|
|
||
|
|
||
|
class AuthenticationBackend:
    """Base class for authentication backends; subclasses implement ``authenticate``."""

    async def authenticate(
        self, conn: HTTPConnection
    ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
        """Inspect ``conn`` and return ``(credentials, user)`` or ``None``.

        The base implementation is abstract and always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError  # pragma: no cover
|
||
|
|
||
|
|
||
|
class AuthCredentials:
    """Holds the scopes attached to a connection's credentials.

    ``scopes`` is always a new ``list`` owned by this instance, so later
    mutation of the caller's sequence does not affect it.
    """

    def __init__(
        self,
        scopes: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
    ) -> None:
        """Store ``scopes`` as a list.

        Args:
            scopes: ``None`` for no scopes, a single scope name, or a
                sequence of scope names.
        """
        self.scopes: typing.List[str]
        if scopes is None:
            self.scopes = []
        elif isinstance(scopes, str):
            # A bare string is a single scope (as in ``requires``);
            # ``list("admin")`` would otherwise split it into characters.
            self.scopes = [scopes]
        else:
            self.scopes = list(scopes)
|
||
|
|
||
|
|
||
|
class BaseUser:
    """Abstract user interface.

    Every property below raises ``NotImplementedError``; concrete user
    classes override the ones they support.
    """

    @property
    def is_authenticated(self) -> bool:
        """Whether this user should be treated as authenticated."""
        raise NotImplementedError  # pragma: no cover

    @property
    def display_name(self) -> str:
        """Human-readable name for this user."""
        raise NotImplementedError  # pragma: no cover

    @property
    def identity(self) -> str:
        """Identifier string for this user."""
        raise NotImplementedError  # pragma: no cover
|
||
|
|
||
|
|
||
|
class SimpleUser(BaseUser):
    """An authenticated user identified only by a username.

    NOTE(review): ``identity`` is not overridden, so accessing it falls
    through to ``BaseUser.identity`` and raises ``NotImplementedError``.
    """

    def __init__(self, username: str) -> None:
        self.username = username

    @property
    def display_name(self) -> str:
        """The username given at construction."""
        return self.username

    @property
    def is_authenticated(self) -> bool:
        """Always ``True``."""
        return True
|
||
|
|
||
|
|
||
|
class UnauthenticatedUser(BaseUser):
    """Placeholder user for connections that were not authenticated.

    NOTE(review): ``identity`` is not overridden and raises
    ``NotImplementedError`` via ``BaseUser``.
    """

    @property
    def display_name(self) -> str:
        """Always the empty string."""
        return ""

    @property
    def is_authenticated(self) -> bool:
        """Always ``False``."""
        return False
|