import asyncio
import functools
import inspect
import typing

from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse, Response
from starlette.websockets import WebSocket


def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
    """Return True when every entry of *scopes* is in ``conn.auth.scopes``.

    An empty *scopes* sequence is trivially satisfied.
    """
    return all(scope in conn.auth.scopes for scope in scopes)


def requires(
    scopes: typing.Union[str, typing.Sequence[str]],
    status_code: int = 403,
    redirect: typing.Optional[str] = None,
) -> typing.Callable:
    """Build a decorator that guards an endpoint behind authentication scopes.

    The decorated function must declare a parameter named ``request`` or
    ``websocket``; the connection found under that name is checked with
    ``has_required_scope`` before the function is called.

    Args:
        scopes: A single scope name, or a sequence of scope names that must
            all be present.
        status_code: Status used for the ``HTTPException`` raised when a
            request lacks the scopes and no ``redirect`` is configured.
        redirect: Optional name passed to ``request.url_for``; when set, a
            request lacking the scopes receives a 303 ``RedirectResponse``
            to that URL instead of an ``HTTPException``.

    Returns:
        A decorator producing a wrapper of the same flavour as the decorated
        function (websocket, async request, or sync request).

    Raises:
        Exception: At decoration time, when the function has neither a
            ``request`` nor a ``websocket`` parameter.
    """
    scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)

    def decorator(func: typing.Callable) -> typing.Callable:
        sig = inspect.signature(func)
        for idx, parameter in enumerate(sig.parameters.values()):
            if parameter.name in ("request", "websocket"):
                conn_kind = parameter.name
                break
        else:
            raise Exception(
                f'No "request" or "websocket" argument on function "{func}"'
            )

        def extract_connection(
            args: typing.Tuple[typing.Any, ...],
            kwargs: typing.Dict[str, typing.Any],
        ) -> typing.Any:
            # Prefer the keyword form. Only index into ``args`` when the
            # position actually exists: a plain ``kwargs.get(name, args[idx])``
            # evaluates ``args[idx]`` eagerly and raises IndexError whenever
            # the connection is passed by keyword.
            if conn_kind in kwargs:
                return kwargs[conn_kind]
            return args[idx] if idx < len(args) else None

        def deny(request: Request) -> Response:
            # Shared rejection path for the sync and async request wrappers.
            if redirect is not None:
                return RedirectResponse(
                    url=request.url_for(redirect), status_code=303
                )
            raise HTTPException(status_code=status_code)

        if conn_kind == "websocket":
            # Handle websocket functions. (Always async)
            @functools.wraps(func)
            async def websocket_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> None:
                websocket = extract_connection(args, kwargs)
                assert isinstance(websocket, WebSocket)

                if not has_required_scope(websocket, scopes_list):
                    await websocket.close()
                else:
                    await func(*args, **kwargs)

            return websocket_wrapper

        elif asyncio.iscoroutinefunction(func):
            # Handle async request/response functions.
            @functools.wraps(func)
            async def async_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> Response:
                request = extract_connection(args, kwargs)
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    return deny(request)
                return await func(*args, **kwargs)

            return async_wrapper

        else:
            # Handle sync request/response functions.
            @functools.wraps(func)
            def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
                request = extract_connection(args, kwargs)
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    return deny(request)
                return func(*args, **kwargs)

            return sync_wrapper

    return decorator


class AuthenticationError(Exception):
    """Exception type for authentication failures."""


class AuthenticationBackend:
    """Base class for authentication backends; ``authenticate`` is abstract."""

    async def authenticate(
        self, conn: HTTPConnection
    ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
        """Return a ``(credentials, user)`` pair or ``None`` for *conn*.

        The base implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError()  # pragma: no cover


class AuthCredentials:
    """Holds the list of scopes granted to a connection."""

    def __init__(self, scopes: typing.Optional[typing.Sequence[str]] = None):
        # Copy into a fresh list so the caller's sequence is never shared.
        self.scopes = [] if scopes is None else list(scopes)


class BaseUser:
    """Interface for user objects; every property raises until overridden."""

    @property
    def is_authenticated(self) -> bool:
        raise NotImplementedError()  # pragma: no cover

    @property
    def display_name(self) -> str:
        raise NotImplementedError()  # pragma: no cover

    @property
    def identity(self) -> str:
        raise NotImplementedError()  # pragma: no cover


class SimpleUser(BaseUser):
    """Authenticated user identified only by a username.

    ``identity`` is not overridden here, so reading it raises
    ``NotImplementedError`` via ``BaseUser``.
    """

    def __init__(self, username: str) -> None:
        self.username = username

    @property
    def is_authenticated(self) -> bool:
        return True

    @property
    def display_name(self) -> str:
        return self.username


class UnauthenticatedUser(BaseUser):
    """User object reporting itself as unauthenticated, with an empty name.

    ``identity`` is not overridden here, so reading it raises
    ``NotImplementedError`` via ``BaseUser``.
    """

    @property
    def is_authenticated(self) -> bool:
        return False

    @property
    def display_name(self) -> str:
        return ""