new
This commit is contained in:
1
.venv/lib/python3.9/site-packages/starlette/__init__.py
Normal file
1
.venv/lib/python3.9/site-packages/starlette/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "0.13.6"
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
200
.venv/lib/python3.9/site-packages/starlette/applications.py
Normal file
200
.venv/lib/python3.9/site-packages/starlette/applications.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import typing
|
||||
|
||||
from starlette.datastructures import State, URLPath
|
||||
from starlette.exceptions import ExceptionMiddleware
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.errors import ServerErrorMiddleware
|
||||
from starlette.routing import BaseRoute, Router
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
class Starlette:
|
||||
"""
|
||||
Creates an application instance.
|
||||
|
||||
**Parameters:**
|
||||
|
||||
* **debug** - Boolean indicating if debug tracebacks should be returned on errors.
|
||||
* **routes** - A list of routes to serve incoming HTTP and WebSocket requests.
|
||||
* **middleware** - A list of middleware to run for every request. A starlette
|
||||
application will always automatically include two middleware classes.
|
||||
`ServerErrorMiddleware` is added as the very outermost middleware, to handle
|
||||
any uncaught errors occuring anywhere in the entire stack.
|
||||
`ExceptionMiddleware` is added as the very innermost middleware, to deal
|
||||
with handled exception cases occuring in the routing or endpoints.
|
||||
* **exception_handlers** - A dictionary mapping either integer status codes,
|
||||
or exception class types onto callables which handle the exceptions.
|
||||
Exception handler callables should be of the form `handler(request, exc) -> response`
|
||||
and may be be either standard functions, or async functions.
|
||||
* **on_startup** - A list of callables to run on application startup.
|
||||
Startup handler callables do not take any arguments, and may be be either
|
||||
standard functions, or async functions.
|
||||
* **on_shutdown** - A list of callables to run on application shutdown.
|
||||
Shutdown handler callables do not take any arguments, and may be be either
|
||||
standard functions, or async functions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
debug: bool = False,
|
||||
routes: typing.Sequence[BaseRoute] = None,
|
||||
middleware: typing.Sequence[Middleware] = None,
|
||||
exception_handlers: typing.Dict[
|
||||
typing.Union[int, typing.Type[Exception]], typing.Callable
|
||||
] = None,
|
||||
on_startup: typing.Sequence[typing.Callable] = None,
|
||||
on_shutdown: typing.Sequence[typing.Callable] = None,
|
||||
lifespan: typing.Callable[["Starlette"], typing.AsyncGenerator] = None,
|
||||
) -> None:
|
||||
# The lifespan context function is a newer style that replaces
|
||||
# on_startup / on_shutdown handlers. Use one or the other, not both.
|
||||
assert lifespan is None or (
|
||||
on_startup is None and on_shutdown is None
|
||||
), "Use either 'lifespan' or 'on_startup'/'on_shutdown', not both."
|
||||
|
||||
self._debug = debug
|
||||
self.state = State()
|
||||
self.router = Router(
|
||||
routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan
|
||||
)
|
||||
self.exception_handlers = (
|
||||
{} if exception_handlers is None else dict(exception_handlers)
|
||||
)
|
||||
self.user_middleware = [] if middleware is None else list(middleware)
|
||||
self.middleware_stack = self.build_middleware_stack()
|
||||
|
||||
def build_middleware_stack(self) -> ASGIApp:
|
||||
debug = self.debug
|
||||
error_handler = None
|
||||
exception_handlers = {}
|
||||
|
||||
for key, value in self.exception_handlers.items():
|
||||
if key in (500, Exception):
|
||||
error_handler = value
|
||||
else:
|
||||
exception_handlers[key] = value
|
||||
|
||||
middleware = (
|
||||
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug,)]
|
||||
+ self.user_middleware
|
||||
+ [
|
||||
Middleware(
|
||||
ExceptionMiddleware, handlers=exception_handlers, debug=debug,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
app = self.router
|
||||
for cls, options in reversed(middleware):
|
||||
app = cls(app=app, **options)
|
||||
return app
|
||||
|
||||
@property
|
||||
def routes(self) -> typing.List[BaseRoute]:
|
||||
return self.router.routes
|
||||
|
||||
@property
|
||||
def debug(self) -> bool:
|
||||
return self._debug
|
||||
|
||||
@debug.setter
|
||||
def debug(self, value: bool) -> None:
|
||||
self._debug = value
|
||||
self.middleware_stack = self.build_middleware_stack()
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
return self.router.url_path_for(name, **path_params)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
scope["app"] = self
|
||||
await self.middleware_stack(scope, receive, send)
|
||||
|
||||
# The following usages are now discouraged in favour of configuration
|
||||
# during Starlette.__init__(...)
|
||||
def on_event(self, event_type: str) -> typing.Callable:
|
||||
return self.router.on_event(event_type)
|
||||
|
||||
def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
|
||||
self.router.mount(path, app=app, name=name)
|
||||
|
||||
def host(self, host: str, app: ASGIApp, name: str = None) -> None:
|
||||
self.router.host(host, app=app, name=name)
|
||||
|
||||
def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
|
||||
self.user_middleware.insert(0, Middleware(middleware_class, **options))
|
||||
self.middleware_stack = self.build_middleware_stack()
|
||||
|
||||
def add_exception_handler(
|
||||
self,
|
||||
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
|
||||
handler: typing.Callable,
|
||||
) -> None:
|
||||
self.exception_handlers[exc_class_or_status_code] = handler
|
||||
self.middleware_stack = self.build_middleware_stack()
|
||||
|
||||
def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
|
||||
self.router.add_event_handler(event_type, func)
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
path: str,
|
||||
route: typing.Callable,
|
||||
methods: typing.List[str] = None,
|
||||
name: str = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> None:
|
||||
self.router.add_route(
|
||||
path, route, methods=methods, name=name, include_in_schema=include_in_schema
|
||||
)
|
||||
|
||||
def add_websocket_route(
|
||||
self, path: str, route: typing.Callable, name: str = None
|
||||
) -> None:
|
||||
self.router.add_websocket_route(path, route, name=name)
|
||||
|
||||
def exception_handler(
|
||||
self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]]
|
||||
) -> typing.Callable:
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
self.add_exception_handler(exc_class_or_status_code, func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def route(
|
||||
self,
|
||||
path: str,
|
||||
methods: typing.List[str] = None,
|
||||
name: str = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> typing.Callable:
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
self.router.add_route(
|
||||
path,
|
||||
func,
|
||||
methods=methods,
|
||||
name=name,
|
||||
include_in_schema=include_in_schema,
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def websocket_route(self, path: str, name: str = None) -> typing.Callable:
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
self.router.add_websocket_route(path, func, name=name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def middleware(self, middleware_type: str) -> typing.Callable:
|
||||
assert (
|
||||
middleware_type == "http"
|
||||
), 'Currently only middleware("http") is supported.'
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
self.add_middleware(BaseHTTPMiddleware, dispatch=func)
|
||||
return func
|
||||
|
||||
return decorator
|
143
.venv/lib/python3.9/site-packages/starlette/authentication.py
Normal file
143
.venv/lib/python3.9/site-packages/starlette/authentication.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import typing
|
||||
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.responses import RedirectResponse, Response
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
||||
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
|
||||
for scope in scopes:
|
||||
if scope not in conn.auth.scopes:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def requires(
|
||||
scopes: typing.Union[str, typing.Sequence[str]],
|
||||
status_code: int = 403,
|
||||
redirect: str = None,
|
||||
) -> typing.Callable:
|
||||
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
type = None
|
||||
sig = inspect.signature(func)
|
||||
for idx, parameter in enumerate(sig.parameters.values()):
|
||||
if parameter.name == "request" or parameter.name == "websocket":
|
||||
type = parameter.name
|
||||
break
|
||||
else:
|
||||
raise Exception(
|
||||
f'No "request" or "websocket" argument on function "{func}"'
|
||||
)
|
||||
|
||||
if type == "websocket":
|
||||
# Handle websocket functions. (Always async)
|
||||
@functools.wraps(func)
|
||||
async def websocket_wrapper(
|
||||
*args: typing.Any, **kwargs: typing.Any
|
||||
) -> None:
|
||||
websocket = kwargs.get("websocket", args[idx])
|
||||
assert isinstance(websocket, WebSocket)
|
||||
|
||||
if not has_required_scope(websocket, scopes_list):
|
||||
await websocket.close()
|
||||
else:
|
||||
await func(*args, **kwargs)
|
||||
|
||||
return websocket_wrapper
|
||||
|
||||
elif asyncio.iscoroutinefunction(func):
|
||||
# Handle async request/response functions.
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(
|
||||
*args: typing.Any, **kwargs: typing.Any
|
||||
) -> Response:
|
||||
request = kwargs.get("request", args[idx])
|
||||
assert isinstance(request, Request)
|
||||
|
||||
if not has_required_scope(request, scopes_list):
|
||||
if redirect is not None:
|
||||
return RedirectResponse(
|
||||
url=request.url_for(redirect), status_code=303
|
||||
)
|
||||
raise HTTPException(status_code=status_code)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
# Handle sync request/response functions.
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
|
||||
request = kwargs.get("request", args[idx])
|
||||
assert isinstance(request, Request)
|
||||
|
||||
if not has_required_scope(request, scopes_list):
|
||||
if redirect is not None:
|
||||
return RedirectResponse(
|
||||
url=request.url_for(redirect), status_code=303
|
||||
)
|
||||
raise HTTPException(status_code=status_code)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationBackend:
|
||||
async def authenticate(
|
||||
self, conn: HTTPConnection
|
||||
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class AuthCredentials:
|
||||
def __init__(self, scopes: typing.Sequence[str] = None):
|
||||
self.scopes = [] if scopes is None else list(scopes)
|
||||
|
||||
|
||||
class BaseUser:
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
@property
|
||||
def identity(self) -> str:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class SimpleUser(BaseUser):
|
||||
def __init__(self, username: str) -> None:
|
||||
self.username = username
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self.username
|
||||
|
||||
|
||||
class UnauthenticatedUser(BaseUser):
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return ""
|
35
.venv/lib/python3.9/site-packages/starlette/background.py
Normal file
35
.venv/lib/python3.9/site-packages/starlette/background.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
|
||||
class BackgroundTask:
|
||||
def __init__(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> None:
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.is_async = asyncio.iscoroutinefunction(func)
|
||||
|
||||
async def __call__(self) -> None:
|
||||
if self.is_async:
|
||||
await self.func(*self.args, **self.kwargs)
|
||||
else:
|
||||
await run_in_threadpool(self.func, *self.args, **self.kwargs)
|
||||
|
||||
|
||||
class BackgroundTasks(BackgroundTask):
|
||||
def __init__(self, tasks: typing.Sequence[BackgroundTask] = []):
|
||||
self.tasks = list(tasks)
|
||||
|
||||
def add_task(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> None:
|
||||
task = BackgroundTask(func, *args, **kwargs)
|
||||
self.tasks.append(task)
|
||||
|
||||
async def __call__(self) -> None:
|
||||
for task in self.tasks:
|
||||
await task()
|
56
.venv/lib/python3.9/site-packages/starlette/concurrency.py
Normal file
56
.venv/lib/python3.9/site-packages/starlette/concurrency.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import typing
|
||||
from typing import Any, AsyncGenerator, Iterator
|
||||
|
||||
try:
|
||||
import contextvars # Python 3.7+ only.
|
||||
except ImportError: # pragma: no cover
|
||||
contextvars = None # type: ignore
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
|
||||
|
||||
async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
|
||||
tasks = [handler(**kwargs) for handler, kwargs in args]
|
||||
(done, pending) = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
[task.cancel() for task in pending]
|
||||
[task.result() for task in done]
|
||||
|
||||
|
||||
async def run_in_threadpool(
|
||||
func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
|
||||
) -> T:
|
||||
loop = asyncio.get_event_loop()
|
||||
if contextvars is not None: # pragma: no cover
|
||||
# Ensure we run in the same context
|
||||
child = functools.partial(func, *args, **kwargs)
|
||||
context = contextvars.copy_context()
|
||||
func = context.run
|
||||
args = (child,)
|
||||
elif kwargs: # pragma: no cover
|
||||
# loop.run_in_executor doesn't accept 'kwargs', so bind them in here
|
||||
func = functools.partial(func, **kwargs)
|
||||
return await loop.run_in_executor(None, func, *args)
|
||||
|
||||
|
||||
class _StopIteration(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _next(iterator: Iterator) -> Any:
|
||||
# We can't raise `StopIteration` from within the threadpool iterator
|
||||
# and catch it outside that context, so we coerce them into a different
|
||||
# exception type.
|
||||
try:
|
||||
return next(iterator)
|
||||
except StopIteration:
|
||||
raise _StopIteration
|
||||
|
||||
|
||||
async def iterate_in_threadpool(iterator: Iterator) -> AsyncGenerator:
|
||||
while True:
|
||||
try:
|
||||
yield await run_in_threadpool(_next, iterator)
|
||||
except _StopIteration:
|
||||
break
|
107
.venv/lib/python3.9/site-packages/starlette/config.py
Normal file
107
.venv/lib/python3.9/site-packages/starlette/config.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import os
|
||||
import typing
|
||||
from collections.abc import MutableMapping
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class undefined:
|
||||
pass
|
||||
|
||||
|
||||
class EnvironError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Environ(MutableMapping):
|
||||
def __init__(self, environ: typing.MutableMapping = os.environ):
|
||||
self._environ = environ
|
||||
self._has_been_read = set() # type: typing.Set[typing.Any]
|
||||
|
||||
def __getitem__(self, key: typing.Any) -> typing.Any:
|
||||
self._has_been_read.add(key)
|
||||
return self._environ.__getitem__(key)
|
||||
|
||||
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
|
||||
if key in self._has_been_read:
|
||||
raise EnvironError(
|
||||
f"Attempting to set environ['{key}'], but the value has already been read."
|
||||
)
|
||||
self._environ.__setitem__(key, value)
|
||||
|
||||
def __delitem__(self, key: typing.Any) -> None:
|
||||
if key in self._has_been_read:
|
||||
raise EnvironError(
|
||||
f"Attempting to delete environ['{key}'], but the value has already been read."
|
||||
)
|
||||
self._environ.__delitem__(key)
|
||||
|
||||
def __iter__(self) -> typing.Iterator:
|
||||
return iter(self._environ)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._environ)
|
||||
|
||||
|
||||
environ = Environ()
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(
|
||||
self,
|
||||
env_file: typing.Union[str, Path] = None,
|
||||
environ: typing.Mapping[str, str] = environ,
|
||||
) -> None:
|
||||
self.environ = environ
|
||||
self.file_values = {} # type: typing.Dict[str, str]
|
||||
if env_file is not None and os.path.isfile(env_file):
|
||||
self.file_values = self._read_file(env_file)
|
||||
|
||||
def __call__(
|
||||
self, key: str, cast: typing.Callable = None, default: typing.Any = undefined,
|
||||
) -> typing.Any:
|
||||
return self.get(key, cast, default)
|
||||
|
||||
def get(
|
||||
self, key: str, cast: typing.Callable = None, default: typing.Any = undefined,
|
||||
) -> typing.Any:
|
||||
if key in self.environ:
|
||||
value = self.environ[key]
|
||||
return self._perform_cast(key, value, cast)
|
||||
if key in self.file_values:
|
||||
value = self.file_values[key]
|
||||
return self._perform_cast(key, value, cast)
|
||||
if default is not undefined:
|
||||
return self._perform_cast(key, default, cast)
|
||||
raise KeyError(f"Config '{key}' is missing, and has no default.")
|
||||
|
||||
def _read_file(self, file_name: typing.Union[str, Path]) -> typing.Dict[str, str]:
|
||||
file_values = {} # type: typing.Dict[str, str]
|
||||
with open(file_name) as input_file:
|
||||
for line in input_file.readlines():
|
||||
line = line.strip()
|
||||
if "=" in line and not line.startswith("#"):
|
||||
key, value = line.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip().strip("\"'")
|
||||
file_values[key] = value
|
||||
return file_values
|
||||
|
||||
def _perform_cast(
|
||||
self, key: str, value: typing.Any, cast: typing.Callable = None,
|
||||
) -> typing.Any:
|
||||
if cast is None or value is None:
|
||||
return value
|
||||
elif cast is bool and isinstance(value, str):
|
||||
mapping = {"true": True, "1": True, "false": False, "0": False}
|
||||
value = value.lower()
|
||||
if value not in mapping:
|
||||
raise ValueError(
|
||||
f"Config '{key}' has value '{value}'. Not a valid bool."
|
||||
)
|
||||
return mapping[value]
|
||||
try:
|
||||
return cast(value)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(
|
||||
f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}."
|
||||
)
|
81
.venv/lib/python3.9/site-packages/starlette/convertors.py
Normal file
81
.venv/lib/python3.9/site-packages/starlette/convertors.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import math
|
||||
import typing
|
||||
import uuid
|
||||
|
||||
|
||||
class Convertor:
|
||||
regex = ""
|
||||
|
||||
def convert(self, value: str) -> typing.Any:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def to_string(self, value: typing.Any) -> str:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class StringConvertor(Convertor):
|
||||
regex = "[^/]+"
|
||||
|
||||
def convert(self, value: str) -> typing.Any:
|
||||
return value
|
||||
|
||||
def to_string(self, value: typing.Any) -> str:
|
||||
value = str(value)
|
||||
assert "/" not in value, "May not contain path separators"
|
||||
assert value, "Must not be empty"
|
||||
return value
|
||||
|
||||
|
||||
class PathConvertor(Convertor):
|
||||
regex = ".*"
|
||||
|
||||
def convert(self, value: str) -> typing.Any:
|
||||
return str(value)
|
||||
|
||||
def to_string(self, value: typing.Any) -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
class IntegerConvertor(Convertor):
|
||||
regex = "[0-9]+"
|
||||
|
||||
def convert(self, value: str) -> typing.Any:
|
||||
return int(value)
|
||||
|
||||
def to_string(self, value: typing.Any) -> str:
|
||||
value = int(value)
|
||||
assert value >= 0, "Negative integers are not supported"
|
||||
return str(value)
|
||||
|
||||
|
||||
class FloatConvertor(Convertor):
|
||||
regex = "[0-9]+(.[0-9]+)?"
|
||||
|
||||
def convert(self, value: str) -> typing.Any:
|
||||
return float(value)
|
||||
|
||||
def to_string(self, value: typing.Any) -> str:
|
||||
value = float(value)
|
||||
assert value >= 0.0, "Negative floats are not supported"
|
||||
assert not math.isnan(value), "NaN values are not supported"
|
||||
assert not math.isinf(value), "Infinite values are not supported"
|
||||
return ("%0.20f" % value).rstrip("0").rstrip(".")
|
||||
|
||||
|
||||
class UUIDConvertor(Convertor):
|
||||
regex = "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||
|
||||
def convert(self, value: str) -> typing.Any:
|
||||
return uuid.UUID(value)
|
||||
|
||||
def to_string(self, value: typing.Any) -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
CONVERTOR_TYPES = {
|
||||
"str": StringConvertor(),
|
||||
"path": PathConvertor(),
|
||||
"int": IntegerConvertor(),
|
||||
"float": FloatConvertor(),
|
||||
"uuid": UUIDConvertor(),
|
||||
}
|
675
.venv/lib/python3.9/site-packages/starlette/datastructures.py
Normal file
675
.venv/lib/python3.9/site-packages/starlette/datastructures.py
Normal file
@@ -0,0 +1,675 @@
|
||||
import tempfile
|
||||
import typing
|
||||
from collections import namedtuple
|
||||
from collections.abc import Sequence
|
||||
from shlex import shlex
|
||||
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
|
||||
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.types import Scope
|
||||
|
||||
Address = namedtuple("Address", ["host", "port"])
|
||||
|
||||
|
||||
class URL:
|
||||
def __init__(
|
||||
self, url: str = "", scope: Scope = None, **components: typing.Any
|
||||
) -> None:
|
||||
if scope is not None:
|
||||
assert not url, 'Cannot set both "url" and "scope".'
|
||||
assert not components, 'Cannot set both "scope" and "**components".'
|
||||
scheme = scope.get("scheme", "http")
|
||||
server = scope.get("server", None)
|
||||
path = scope.get("root_path", "") + scope["path"]
|
||||
query_string = scope.get("query_string", b"")
|
||||
|
||||
host_header = None
|
||||
for key, value in scope["headers"]:
|
||||
if key == b"host":
|
||||
host_header = value.decode("latin-1")
|
||||
break
|
||||
|
||||
if host_header is not None:
|
||||
url = f"{scheme}://{host_header}{path}"
|
||||
elif server is None:
|
||||
url = path
|
||||
else:
|
||||
host, port = server
|
||||
default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
|
||||
if port == default_port:
|
||||
url = f"{scheme}://{host}{path}"
|
||||
else:
|
||||
url = f"{scheme}://{host}:{port}{path}"
|
||||
|
||||
if query_string:
|
||||
url += "?" + query_string.decode()
|
||||
elif components:
|
||||
assert not url, 'Cannot set both "url" and "**components".'
|
||||
url = URL("").replace(**components).components.geturl()
|
||||
|
||||
self._url = url
|
||||
|
||||
@property
|
||||
def components(self) -> SplitResult:
|
||||
if not hasattr(self, "_components"):
|
||||
self._components = urlsplit(self._url)
|
||||
return self._components
|
||||
|
||||
@property
|
||||
def scheme(self) -> str:
|
||||
return self.components.scheme
|
||||
|
||||
@property
|
||||
def netloc(self) -> str:
|
||||
return self.components.netloc
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
return self.components.path
|
||||
|
||||
@property
|
||||
def query(self) -> str:
|
||||
return self.components.query
|
||||
|
||||
@property
|
||||
def fragment(self) -> str:
|
||||
return self.components.fragment
|
||||
|
||||
@property
|
||||
def username(self) -> typing.Union[None, str]:
|
||||
return self.components.username
|
||||
|
||||
@property
|
||||
def password(self) -> typing.Union[None, str]:
|
||||
return self.components.password
|
||||
|
||||
@property
|
||||
def hostname(self) -> typing.Union[None, str]:
|
||||
return self.components.hostname
|
||||
|
||||
@property
|
||||
def port(self) -> typing.Optional[int]:
|
||||
return self.components.port
|
||||
|
||||
@property
|
||||
def is_secure(self) -> bool:
|
||||
return self.scheme in ("https", "wss")
|
||||
|
||||
def replace(self, **kwargs: typing.Any) -> "URL":
|
||||
if (
|
||||
"username" in kwargs
|
||||
or "password" in kwargs
|
||||
or "hostname" in kwargs
|
||||
or "port" in kwargs
|
||||
):
|
||||
hostname = kwargs.pop("hostname", self.hostname)
|
||||
port = kwargs.pop("port", self.port)
|
||||
username = kwargs.pop("username", self.username)
|
||||
password = kwargs.pop("password", self.password)
|
||||
|
||||
netloc = hostname
|
||||
if port is not None:
|
||||
netloc += f":{port}"
|
||||
if username is not None:
|
||||
userpass = username
|
||||
if password is not None:
|
||||
userpass += f":{password}"
|
||||
netloc = f"{userpass}@{netloc}"
|
||||
|
||||
kwargs["netloc"] = netloc
|
||||
|
||||
components = self.components._replace(**kwargs)
|
||||
return self.__class__(components.geturl())
|
||||
|
||||
def include_query_params(self, **kwargs: typing.Any) -> "URL":
|
||||
params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
|
||||
params.update({str(key): str(value) for key, value in kwargs.items()})
|
||||
query = urlencode(params.multi_items())
|
||||
return self.replace(query=query)
|
||||
|
||||
def replace_query_params(self, **kwargs: typing.Any) -> "URL":
|
||||
query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
|
||||
return self.replace(query=query)
|
||||
|
||||
def remove_query_params(
|
||||
self, keys: typing.Union[str, typing.Sequence[str]]
|
||||
) -> "URL":
|
||||
if isinstance(keys, str):
|
||||
keys = [keys]
|
||||
params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
|
||||
for key in keys:
|
||||
params.pop(key, None)
|
||||
query = urlencode(params.multi_items())
|
||||
return self.replace(query=query)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return str(self) == str(other)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._url
|
||||
|
||||
def __repr__(self) -> str:
|
||||
url = str(self)
|
||||
if self.password:
|
||||
url = str(self.replace(password="********"))
|
||||
return f"{self.__class__.__name__}({repr(url)})"
|
||||
|
||||
|
||||
class URLPath(str):
|
||||
"""
|
||||
A URL path string that may also hold an associated protocol and/or host.
|
||||
Used by the routing to return `url_path_for` matches.
|
||||
"""
|
||||
|
||||
def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
|
||||
assert protocol in ("http", "websocket", "")
|
||||
return str.__new__(cls, path) # type: ignore
|
||||
|
||||
def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
|
||||
self.protocol = protocol
|
||||
self.host = host
|
||||
|
||||
def make_absolute_url(self, base_url: typing.Union[str, URL]) -> str:
|
||||
if isinstance(base_url, str):
|
||||
base_url = URL(base_url)
|
||||
if self.protocol:
|
||||
scheme = {
|
||||
"http": {True: "https", False: "http"},
|
||||
"websocket": {True: "wss", False: "ws"},
|
||||
}[self.protocol][base_url.is_secure]
|
||||
else:
|
||||
scheme = base_url.scheme
|
||||
|
||||
if self.host:
|
||||
netloc = self.host
|
||||
else:
|
||||
netloc = base_url.netloc
|
||||
|
||||
path = base_url.path.rstrip("/") + str(self)
|
||||
return str(URL(scheme=scheme, netloc=netloc, path=path))
|
||||
|
||||
|
||||
class Secret:
|
||||
"""
|
||||
Holds a string value that should not be revealed in tracebacks etc.
|
||||
You should cast the value to `str` at the point it is required.
|
||||
"""
|
||||
|
||||
def __init__(self, value: str):
|
||||
self._value = value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return f"{class_name}('**********')"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._value
|
||||
|
||||
|
||||
class CommaSeparatedStrings(Sequence):
|
||||
def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
|
||||
if isinstance(value, str):
|
||||
splitter = shlex(value, posix=True)
|
||||
splitter.whitespace = ","
|
||||
splitter.whitespace_split = True
|
||||
self._items = [item.strip() for item in splitter]
|
||||
else:
|
||||
self._items = list(value)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._items)
|
||||
|
||||
def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any:
|
||||
return self._items[index]
|
||||
|
||||
def __iter__(self) -> typing.Iterator[str]:
|
||||
return iter(self._items)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
items = [item for item in self]
|
||||
return f"{class_name}({items!r})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return ", ".join([repr(item) for item in self])
|
||||
|
||||
|
||||
class ImmutableMultiDict(typing.Mapping):
|
||||
def __init__(
|
||||
self,
|
||||
*args: typing.Union[
|
||||
"ImmutableMultiDict",
|
||||
typing.Mapping,
|
||||
typing.List[typing.Tuple[typing.Any, typing.Any]],
|
||||
],
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
assert len(args) < 2, "Too many arguments."
|
||||
|
||||
if args:
|
||||
value = args[0]
|
||||
else:
|
||||
value = []
|
||||
|
||||
if kwargs:
|
||||
value = (
|
||||
ImmutableMultiDict(value).multi_items()
|
||||
+ ImmutableMultiDict(kwargs).multi_items()
|
||||
)
|
||||
|
||||
if not value:
|
||||
_items = [] # type: typing.List[typing.Tuple[typing.Any, typing.Any]]
|
||||
elif hasattr(value, "multi_items"):
|
||||
value = typing.cast(ImmutableMultiDict, value)
|
||||
_items = list(value.multi_items())
|
||||
elif hasattr(value, "items"):
|
||||
value = typing.cast(typing.Mapping, value)
|
||||
_items = list(value.items())
|
||||
else:
|
||||
value = typing.cast(
|
||||
typing.List[typing.Tuple[typing.Any, typing.Any]], value
|
||||
)
|
||||
_items = list(value)
|
||||
|
||||
self._dict = {k: v for k, v in _items}
|
||||
self._list = _items
|
||||
|
||||
def getlist(self, key: typing.Any) -> typing.List[str]:
|
||||
return [item_value for item_key, item_value in self._list if item_key == key]
|
||||
|
||||
def keys(self) -> typing.KeysView:
|
||||
return self._dict.keys()
|
||||
|
||||
def values(self) -> typing.ValuesView:
|
||||
return self._dict.values()
|
||||
|
||||
def items(self) -> typing.ItemsView:
|
||||
return self._dict.items()
|
||||
|
||||
def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
|
||||
return list(self._list)
|
||||
|
||||
def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
|
||||
if key in self._dict:
|
||||
return self._dict[key]
|
||||
return default
|
||||
|
||||
def __getitem__(self, key: typing.Any) -> str:
|
||||
return self._dict[key]
|
||||
|
||||
def __contains__(self, key: typing.Any) -> bool:
|
||||
return key in self._dict
|
||||
|
||||
def __iter__(self) -> typing.Iterator[typing.Any]:
|
||||
return iter(self.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._dict)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
return sorted(self._list) == sorted(other._list)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
items = self.multi_items()
|
||||
return f"{class_name}({items!r})"
|
||||
|
||||
|
||||
class MultiDict(ImmutableMultiDict):
|
||||
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
|
||||
self.setlist(key, [value])
|
||||
|
||||
def __delitem__(self, key: typing.Any) -> None:
|
||||
self._list = [(k, v) for k, v in self._list if k != key]
|
||||
del self._dict[key]
|
||||
|
||||
def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
|
||||
self._list = [(k, v) for k, v in self._list if k != key]
|
||||
return self._dict.pop(key, default)
|
||||
|
||||
def popitem(self) -> typing.Tuple:
|
||||
key, value = self._dict.popitem()
|
||||
self._list = [(k, v) for k, v in self._list if k != key]
|
||||
return key, value
|
||||
|
||||
def poplist(self, key: typing.Any) -> typing.List:
|
||||
values = [v for k, v in self._list if k == key]
|
||||
self.pop(key)
|
||||
return values
|
||||
|
||||
def clear(self) -> None:
|
||||
self._dict.clear()
|
||||
self._list.clear()
|
||||
|
||||
def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
|
||||
if key not in self:
|
||||
self._dict[key] = default
|
||||
self._list.append((key, default))
|
||||
|
||||
return self[key]
|
||||
|
||||
def setlist(self, key: typing.Any, values: typing.List) -> None:
|
||||
if not values:
|
||||
self.pop(key, None)
|
||||
else:
|
||||
existing_items = [(k, v) for (k, v) in self._list if k != key]
|
||||
self._list = existing_items + [(key, value) for value in values]
|
||||
self._dict[key] = values[-1]
|
||||
|
||||
def append(self, key: typing.Any, value: typing.Any) -> None:
|
||||
self._list.append((key, value))
|
||||
self._dict[key] = value
|
||||
|
||||
def update(
|
||||
self,
|
||||
*args: typing.Union[
|
||||
"MultiDict",
|
||||
typing.Mapping,
|
||||
typing.List[typing.Tuple[typing.Any, typing.Any]],
|
||||
],
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
value = MultiDict(*args, **kwargs)
|
||||
existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
|
||||
self._list = existing_items + value.multi_items()
|
||||
self._dict.update(value)
|
||||
|
||||
|
||||
class QueryParams(ImmutableMultiDict):
|
||||
"""
|
||||
An immutable multidict.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: typing.Union[
|
||||
"ImmutableMultiDict",
|
||||
typing.Mapping,
|
||||
typing.List[typing.Tuple[typing.Any, typing.Any]],
|
||||
str,
|
||||
bytes,
|
||||
],
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
assert len(args) < 2, "Too many arguments."
|
||||
|
||||
value = args[0] if args else []
|
||||
|
||||
if isinstance(value, str):
|
||||
super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
|
||||
elif isinstance(value, bytes):
|
||||
super().__init__(
|
||||
parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs
|
||||
)
|
||||
else:
|
||||
super().__init__(*args, **kwargs) # type: ignore
|
||||
self._list = [(str(k), str(v)) for k, v in self._list]
|
||||
self._dict = {str(k): str(v) for k, v in self._dict.items()}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return urlencode(self._list)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
query_string = str(self)
|
||||
return f"{class_name}({query_string!r})"
|
||||
|
||||
|
||||
class UploadFile:
|
||||
"""
|
||||
An uploaded file included as part of the request data.
|
||||
"""
|
||||
|
||||
spool_max_size = 1024 * 1024
|
||||
|
||||
def __init__(
|
||||
self, filename: str, file: typing.IO = None, content_type: str = ""
|
||||
) -> None:
|
||||
self.filename = filename
|
||||
self.content_type = content_type
|
||||
if file is None:
|
||||
file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size)
|
||||
self.file = file
|
||||
|
||||
@property
|
||||
def _in_memory(self) -> bool:
|
||||
rolled_to_disk = getattr(self.file, "_rolled", True)
|
||||
return not rolled_to_disk
|
||||
|
||||
async def write(self, data: typing.Union[bytes, str]) -> None:
|
||||
if self._in_memory:
|
||||
self.file.write(data) # type: ignore
|
||||
else:
|
||||
await run_in_threadpool(self.file.write, data)
|
||||
|
||||
async def read(self, size: int = -1) -> typing.Union[bytes, str]:
|
||||
if self._in_memory:
|
||||
return self.file.read(size)
|
||||
return await run_in_threadpool(self.file.read, size)
|
||||
|
||||
async def seek(self, offset: int) -> None:
|
||||
if self._in_memory:
|
||||
self.file.seek(offset)
|
||||
else:
|
||||
await run_in_threadpool(self.file.seek, offset)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._in_memory:
|
||||
self.file.close()
|
||||
else:
|
||||
await run_in_threadpool(self.file.close)
|
||||
|
||||
|
||||
class FormData(ImmutableMultiDict):
|
||||
"""
|
||||
An immutable multidict, containing both file uploads and text input.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: typing.Union[
|
||||
"FormData",
|
||||
typing.Mapping[str, typing.Union[str, UploadFile]],
|
||||
typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]],
|
||||
],
|
||||
**kwargs: typing.Union[str, UploadFile],
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def close(self) -> None:
|
||||
for key, value in self.multi_items():
|
||||
if isinstance(value, UploadFile):
|
||||
await value.close()
|
||||
|
||||
|
||||
class Headers(typing.Mapping[str, str]):
|
||||
"""
|
||||
An immutable, case-insensitive multidict.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
headers: typing.Mapping[str, str] = None,
|
||||
raw: typing.List[typing.Tuple[bytes, bytes]] = None,
|
||||
scope: Scope = None,
|
||||
) -> None:
|
||||
self._list = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
if headers is not None:
|
||||
assert raw is None, 'Cannot set both "headers" and "raw".'
|
||||
assert scope is None, 'Cannot set both "headers" and "scope".'
|
||||
self._list = [
|
||||
(key.lower().encode("latin-1"), value.encode("latin-1"))
|
||||
for key, value in headers.items()
|
||||
]
|
||||
elif raw is not None:
|
||||
assert scope is None, 'Cannot set both "raw" and "scope".'
|
||||
self._list = raw
|
||||
elif scope is not None:
|
||||
self._list = scope["headers"]
|
||||
|
||||
@property
|
||||
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
|
||||
return list(self._list)
|
||||
|
||||
def keys(self) -> typing.List[str]: # type: ignore
|
||||
return [key.decode("latin-1") for key, value in self._list]
|
||||
|
||||
def values(self) -> typing.List[str]: # type: ignore
|
||||
return [value.decode("latin-1") for key, value in self._list]
|
||||
|
||||
def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore
|
||||
return [
|
||||
(key.decode("latin-1"), value.decode("latin-1"))
|
||||
for key, value in self._list
|
||||
]
|
||||
|
||||
def get(self, key: str, default: typing.Any = None) -> typing.Any:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def getlist(self, key: str) -> typing.List[str]:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
return [
|
||||
item_value.decode("latin-1")
|
||||
for item_key, item_value in self._list
|
||||
if item_key == get_header_key
|
||||
]
|
||||
|
||||
def mutablecopy(self) -> "MutableHeaders":
|
||||
return MutableHeaders(raw=self._list[:])
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
for header_key, header_value in self._list:
|
||||
if header_key == get_header_key:
|
||||
return header_value.decode("latin-1")
|
||||
raise KeyError(key)
|
||||
|
||||
def __contains__(self, key: typing.Any) -> bool:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
for header_key, header_value in self._list:
|
||||
if header_key == get_header_key:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __iter__(self) -> typing.Iterator[typing.Any]:
|
||||
return iter(self.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._list)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
if not isinstance(other, Headers):
|
||||
return False
|
||||
return sorted(self._list) == sorted(other._list)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
as_dict = dict(self.items())
|
||||
if len(as_dict) == len(self):
|
||||
return f"{class_name}({as_dict!r})"
|
||||
return f"{class_name}(raw={self.raw!r})"
|
||||
|
||||
|
||||
class MutableHeaders(Headers):
|
||||
def __setitem__(self, key: str, value: str) -> None:
|
||||
"""
|
||||
Set the header `key` to `value`, removing any duplicate entries.
|
||||
Retains insertion order.
|
||||
"""
|
||||
set_key = key.lower().encode("latin-1")
|
||||
set_value = value.encode("latin-1")
|
||||
|
||||
found_indexes = []
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == set_key:
|
||||
found_indexes.append(idx)
|
||||
|
||||
for idx in reversed(found_indexes[1:]):
|
||||
del self._list[idx]
|
||||
|
||||
if found_indexes:
|
||||
idx = found_indexes[0]
|
||||
self._list[idx] = (set_key, set_value)
|
||||
else:
|
||||
self._list.append((set_key, set_value))
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""
|
||||
Remove the header `key`.
|
||||
"""
|
||||
del_key = key.lower().encode("latin-1")
|
||||
|
||||
pop_indexes = []
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == del_key:
|
||||
pop_indexes.append(idx)
|
||||
|
||||
for idx in reversed(pop_indexes):
|
||||
del self._list[idx]
|
||||
|
||||
@property
|
||||
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
|
||||
return self._list
|
||||
|
||||
def setdefault(self, key: str, value: str) -> str:
|
||||
"""
|
||||
If the header `key` does not exist, then set it to `value`.
|
||||
Returns the header value.
|
||||
"""
|
||||
set_key = key.lower().encode("latin-1")
|
||||
set_value = value.encode("latin-1")
|
||||
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == set_key:
|
||||
return item_value.decode("latin-1")
|
||||
self._list.append((set_key, set_value))
|
||||
return value
|
||||
|
||||
def update(self, other: dict) -> None:
|
||||
for key, val in other.items():
|
||||
self[key] = val
|
||||
|
||||
def append(self, key: str, value: str) -> None:
|
||||
"""
|
||||
Append a header, preserving any duplicate entries.
|
||||
"""
|
||||
append_key = key.lower().encode("latin-1")
|
||||
append_value = value.encode("latin-1")
|
||||
self._list.append((append_key, append_value))
|
||||
|
||||
def add_vary_header(self, vary: str) -> None:
|
||||
existing = self.get("vary")
|
||||
if existing is not None:
|
||||
vary = ", ".join([existing, vary])
|
||||
self["vary"] = vary
|
||||
|
||||
|
||||
class State(object):
|
||||
"""
|
||||
An object that can be used to store arbitrary state.
|
||||
|
||||
Used for `request.state` and `app.state`.
|
||||
"""
|
||||
|
||||
def __init__(self, state: typing.Dict = None):
|
||||
if state is None:
|
||||
state = {}
|
||||
super(State, self).__setattr__("_state", state)
|
||||
|
||||
def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
|
||||
self._state[key] = value
|
||||
|
||||
def __getattr__(self, key: typing.Any) -> typing.Any:
|
||||
try:
|
||||
return self._state[key]
|
||||
except KeyError:
|
||||
message = "'{}' object has no attribute '{}'"
|
||||
raise AttributeError(message.format(self.__class__.__name__, key))
|
||||
|
||||
def __delattr__(self, key: typing.Any) -> None:
|
||||
del self._state[key]
|
117
.venv/lib/python3.9/site-packages/starlette/endpoints.py
Normal file
117
.venv/lib/python3.9/site-packages/starlette/endpoints.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import asyncio
|
||||
import json
|
||||
import typing
|
||||
|
||||
from starlette import status
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
||||
class HTTPEndpoint:
|
||||
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
assert scope["type"] == "http"
|
||||
self.scope = scope
|
||||
self.receive = receive
|
||||
self.send = send
|
||||
|
||||
def __await__(self) -> typing.Generator:
|
||||
return self.dispatch().__await__()
|
||||
|
||||
async def dispatch(self) -> None:
|
||||
request = Request(self.scope, receive=self.receive)
|
||||
handler_name = "get" if request.method == "HEAD" else request.method.lower()
|
||||
handler = getattr(self, handler_name, self.method_not_allowed)
|
||||
is_async = asyncio.iscoroutinefunction(handler)
|
||||
if is_async:
|
||||
response = await handler(request)
|
||||
else:
|
||||
response = await run_in_threadpool(handler, request)
|
||||
await response(self.scope, self.receive, self.send)
|
||||
|
||||
async def method_not_allowed(self, request: Request) -> Response:
|
||||
# If we're running inside a starlette application then raise an
|
||||
# exception, so that the configurable exception handler can deal with
|
||||
# returning the response. For plain ASGI apps, just return the response.
|
||||
if "app" in self.scope:
|
||||
raise HTTPException(status_code=405)
|
||||
return PlainTextResponse("Method Not Allowed", status_code=405)
|
||||
|
||||
|
||||
class WebSocketEndpoint:
|
||||
|
||||
encoding = None # May be "text", "bytes", or "json".
|
||||
|
||||
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
assert scope["type"] == "websocket"
|
||||
self.scope = scope
|
||||
self.receive = receive
|
||||
self.send = send
|
||||
|
||||
def __await__(self) -> typing.Generator:
|
||||
return self.dispatch().__await__()
|
||||
|
||||
async def dispatch(self) -> None:
|
||||
websocket = WebSocket(self.scope, receive=self.receive, send=self.send)
|
||||
await self.on_connect(websocket)
|
||||
|
||||
close_code = status.WS_1000_NORMAL_CLOSURE
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive()
|
||||
if message["type"] == "websocket.receive":
|
||||
data = await self.decode(websocket, message)
|
||||
await self.on_receive(websocket, data)
|
||||
elif message["type"] == "websocket.disconnect":
|
||||
close_code = int(message.get("code", status.WS_1000_NORMAL_CLOSURE))
|
||||
break
|
||||
except Exception as exc:
|
||||
close_code = status.WS_1011_INTERNAL_ERROR
|
||||
raise exc from None
|
||||
finally:
|
||||
await self.on_disconnect(websocket, close_code)
|
||||
|
||||
async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
|
||||
|
||||
if self.encoding == "text":
|
||||
if "text" not in message:
|
||||
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||
raise RuntimeError("Expected text websocket messages, but got bytes")
|
||||
return message["text"]
|
||||
|
||||
elif self.encoding == "bytes":
|
||||
if "bytes" not in message:
|
||||
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||
raise RuntimeError("Expected bytes websocket messages, but got text")
|
||||
return message["bytes"]
|
||||
|
||||
elif self.encoding == "json":
|
||||
if message.get("text") is not None:
|
||||
text = message["text"]
|
||||
else:
|
||||
text = message["bytes"].decode("utf-8")
|
||||
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.decoder.JSONDecodeError:
|
||||
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||
raise RuntimeError("Malformed JSON data received.")
|
||||
|
||||
assert (
|
||||
self.encoding is None
|
||||
), f"Unsupported 'encoding' attribute {self.encoding}"
|
||||
return message["text"] if message.get("text") else message["bytes"]
|
||||
|
||||
async def on_connect(self, websocket: WebSocket) -> None:
|
||||
"""Override to handle an incoming websocket connection"""
|
||||
await websocket.accept()
|
||||
|
||||
async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
|
||||
"""Override to handle an incoming websocket message"""
|
||||
|
||||
async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
|
||||
"""Override to handle a disconnecting websocket"""
|
98
.venv/lib/python3.9/site-packages/starlette/exceptions.py
Normal file
98
.venv/lib/python3.9/site-packages/starlette/exceptions.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import asyncio
|
||||
import http
|
||||
import typing
|
||||
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
|
||||
class HTTPException(Exception):
|
||||
def __init__(self, status_code: int, detail: str = None) -> None:
|
||||
if detail is None:
|
||||
detail = http.HTTPStatus(status_code).phrase
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"
|
||||
|
||||
|
||||
class ExceptionMiddleware:
|
||||
def __init__(
|
||||
self, app: ASGIApp, handlers: dict = None, debug: bool = False
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
|
||||
self._status_handlers = {} # type: typing.Dict[int, typing.Callable]
|
||||
self._exception_handlers = {
|
||||
HTTPException: self.http_exception
|
||||
} # type: typing.Dict[typing.Type[Exception], typing.Callable]
|
||||
if handlers is not None:
|
||||
for key, value in handlers.items():
|
||||
self.add_exception_handler(key, value)
|
||||
|
||||
def add_exception_handler(
|
||||
self,
|
||||
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
|
||||
handler: typing.Callable,
|
||||
) -> None:
|
||||
if isinstance(exc_class_or_status_code, int):
|
||||
self._status_handlers[exc_class_or_status_code] = handler
|
||||
else:
|
||||
assert issubclass(exc_class_or_status_code, Exception)
|
||||
self._exception_handlers[exc_class_or_status_code] = handler
|
||||
|
||||
def _lookup_exception_handler(
|
||||
self, exc: Exception
|
||||
) -> typing.Optional[typing.Callable]:
|
||||
for cls in type(exc).__mro__:
|
||||
if cls in self._exception_handlers:
|
||||
return self._exception_handlers[cls]
|
||||
return None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
response_started = False
|
||||
|
||||
async def sender(message: Message) -> None:
|
||||
nonlocal response_started
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_started = True
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, sender)
|
||||
except Exception as exc:
|
||||
handler = None
|
||||
|
||||
if isinstance(exc, HTTPException):
|
||||
handler = self._status_handlers.get(exc.status_code)
|
||||
|
||||
if handler is None:
|
||||
handler = self._lookup_exception_handler(exc)
|
||||
|
||||
if handler is None:
|
||||
raise exc from None
|
||||
|
||||
if response_started:
|
||||
msg = "Caught handled exception, but response already started."
|
||||
raise RuntimeError(msg) from exc
|
||||
|
||||
request = Request(scope, receive=receive)
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
response = await handler(request, exc)
|
||||
else:
|
||||
response = await run_in_threadpool(handler, request, exc)
|
||||
await response(scope, receive, sender)
|
||||
|
||||
def http_exception(self, request: Request, exc: HTTPException) -> Response:
|
||||
if exc.status_code in {204, 304}:
|
||||
return Response(b"", status_code=exc.status_code)
|
||||
return PlainTextResponse(exc.detail, status_code=exc.status_code)
|
242
.venv/lib/python3.9/site-packages/starlette/formparsers.py
Normal file
242
.venv/lib/python3.9/site-packages/starlette/formparsers.py
Normal file
@@ -0,0 +1,242 @@
|
||||
import typing
|
||||
from enum import Enum
|
||||
from urllib.parse import unquote_plus
|
||||
|
||||
from starlette.datastructures import FormData, Headers, UploadFile
|
||||
|
||||
try:
|
||||
import multipart
|
||||
from multipart.multipart import parse_options_header
|
||||
except ImportError: # pragma: nocover
|
||||
parse_options_header = None
|
||||
multipart = None
|
||||
|
||||
|
||||
class FormMessage(Enum):
|
||||
FIELD_START = 1
|
||||
FIELD_NAME = 2
|
||||
FIELD_DATA = 3
|
||||
FIELD_END = 4
|
||||
END = 5
|
||||
|
||||
|
||||
class MultiPartMessage(Enum):
|
||||
PART_BEGIN = 1
|
||||
PART_DATA = 2
|
||||
PART_END = 3
|
||||
HEADER_FIELD = 4
|
||||
HEADER_VALUE = 5
|
||||
HEADER_END = 6
|
||||
HEADERS_FINISHED = 7
|
||||
END = 8
|
||||
|
||||
|
||||
def _user_safe_decode(src: bytes, codec: str) -> str:
|
||||
try:
|
||||
return src.decode(codec)
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
return src.decode("latin-1")
|
||||
|
||||
|
||||
class FormParser:
|
||||
def __init__(
|
||||
self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
|
||||
) -> None:
|
||||
assert (
|
||||
multipart is not None
|
||||
), "The `python-multipart` library must be installed to use form parsing."
|
||||
self.headers = headers
|
||||
self.stream = stream
|
||||
self.messages = [] # type: typing.List[typing.Tuple[FormMessage, bytes]]
|
||||
|
||||
def on_field_start(self) -> None:
|
||||
message = (FormMessage.FIELD_START, b"")
|
||||
self.messages.append(message)
|
||||
|
||||
def on_field_name(self, data: bytes, start: int, end: int) -> None:
|
||||
message = (FormMessage.FIELD_NAME, data[start:end])
|
||||
self.messages.append(message)
|
||||
|
||||
def on_field_data(self, data: bytes, start: int, end: int) -> None:
|
||||
message = (FormMessage.FIELD_DATA, data[start:end])
|
||||
self.messages.append(message)
|
||||
|
||||
def on_field_end(self) -> None:
|
||||
message = (FormMessage.FIELD_END, b"")
|
||||
self.messages.append(message)
|
||||
|
||||
def on_end(self) -> None:
|
||||
message = (FormMessage.END, b"")
|
||||
self.messages.append(message)
|
||||
|
||||
async def parse(self) -> FormData:
|
||||
# Callbacks dictionary.
|
||||
callbacks = {
|
||||
"on_field_start": self.on_field_start,
|
||||
"on_field_name": self.on_field_name,
|
||||
"on_field_data": self.on_field_data,
|
||||
"on_field_end": self.on_field_end,
|
||||
"on_end": self.on_end,
|
||||
}
|
||||
|
||||
# Create the parser.
|
||||
parser = multipart.QuerystringParser(callbacks)
|
||||
field_name = b""
|
||||
field_value = b""
|
||||
|
||||
items = (
|
||||
[]
|
||||
) # type: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]]
|
||||
|
||||
# Feed the parser with data from the request.
|
||||
async for chunk in self.stream:
|
||||
if chunk:
|
||||
parser.write(chunk)
|
||||
else:
|
||||
parser.finalize()
|
||||
messages = list(self.messages)
|
||||
self.messages.clear()
|
||||
for message_type, message_bytes in messages:
|
||||
if message_type == FormMessage.FIELD_START:
|
||||
field_name = b""
|
||||
field_value = b""
|
||||
elif message_type == FormMessage.FIELD_NAME:
|
||||
field_name += message_bytes
|
||||
elif message_type == FormMessage.FIELD_DATA:
|
||||
field_value += message_bytes
|
||||
elif message_type == FormMessage.FIELD_END:
|
||||
name = unquote_plus(field_name.decode("latin-1"))
|
||||
value = unquote_plus(field_value.decode("latin-1"))
|
||||
items.append((name, value))
|
||||
elif message_type == FormMessage.END:
|
||||
pass
|
||||
|
||||
return FormData(items)
|
||||
|
||||
|
||||
class MultiPartParser:
|
||||
def __init__(
|
||||
self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
|
||||
) -> None:
|
||||
assert (
|
||||
multipart is not None
|
||||
), "The `python-multipart` library must be installed to use form parsing."
|
||||
self.headers = headers
|
||||
self.stream = stream
|
||||
self.messages = [] # type: typing.List[typing.Tuple[MultiPartMessage, bytes]]
|
||||
|
||||
def on_part_begin(self) -> None:
|
||||
message = (MultiPartMessage.PART_BEGIN, b"")
|
||||
self.messages.append(message)
|
||||
|
||||
def on_part_data(self, data: bytes, start: int, end: int) -> None:
|
||||
message = (MultiPartMessage.PART_DATA, data[start:end])
|
||||
self.messages.append(message)
|
||||
|
||||
def on_part_end(self) -> None:
|
||||
message = (MultiPartMessage.PART_END, b"")
|
||||
self.messages.append(message)
|
||||
|
||||
def on_header_field(self, data: bytes, start: int, end: int) -> None:
|
||||
message = (MultiPartMessage.HEADER_FIELD, data[start:end])
|
||||
self.messages.append(message)
|
||||
|
||||
def on_header_value(self, data: bytes, start: int, end: int) -> None:
|
||||
message = (MultiPartMessage.HEADER_VALUE, data[start:end])
|
||||
self.messages.append(message)
|
||||
|
||||
def on_header_end(self) -> None:
|
||||
message = (MultiPartMessage.HEADER_END, b"")
|
||||
self.messages.append(message)
|
||||
|
||||
def on_headers_finished(self) -> None:
|
||||
message = (MultiPartMessage.HEADERS_FINISHED, b"")
|
||||
self.messages.append(message)
|
||||
|
||||
def on_end(self) -> None:
|
||||
message = (MultiPartMessage.END, b"")
|
||||
self.messages.append(message)
|
||||
|
||||
async def parse(self) -> FormData:
|
||||
# Parse the Content-Type header to get the multipart boundary.
|
||||
content_type, params = parse_options_header(self.headers["Content-Type"])
|
||||
charset = params.get(b"charset", "utf-8")
|
||||
if type(charset) == bytes:
|
||||
charset = charset.decode("latin-1")
|
||||
boundary = params.get(b"boundary")
|
||||
|
||||
# Callbacks dictionary.
|
||||
callbacks = {
|
||||
"on_part_begin": self.on_part_begin,
|
||||
"on_part_data": self.on_part_data,
|
||||
"on_part_end": self.on_part_end,
|
||||
"on_header_field": self.on_header_field,
|
||||
"on_header_value": self.on_header_value,
|
||||
"on_header_end": self.on_header_end,
|
||||
"on_headers_finished": self.on_headers_finished,
|
||||
"on_end": self.on_end,
|
||||
}
|
||||
|
||||
# Create the parser.
|
||||
parser = multipart.MultipartParser(boundary, callbacks)
|
||||
header_field = b""
|
||||
header_value = b""
|
||||
content_disposition = None
|
||||
content_type = b""
|
||||
field_name = ""
|
||||
data = b""
|
||||
file = None # type: typing.Optional[UploadFile]
|
||||
|
||||
items = (
|
||||
[]
|
||||
) # type: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]]
|
||||
|
||||
# Feed the parser with data from the request.
|
||||
async for chunk in self.stream:
|
||||
parser.write(chunk)
|
||||
messages = list(self.messages)
|
||||
self.messages.clear()
|
||||
for message_type, message_bytes in messages:
|
||||
if message_type == MultiPartMessage.PART_BEGIN:
|
||||
content_disposition = None
|
||||
content_type = b""
|
||||
data = b""
|
||||
elif message_type == MultiPartMessage.HEADER_FIELD:
|
||||
header_field += message_bytes
|
||||
elif message_type == MultiPartMessage.HEADER_VALUE:
|
||||
header_value += message_bytes
|
||||
elif message_type == MultiPartMessage.HEADER_END:
|
||||
field = header_field.lower()
|
||||
if field == b"content-disposition":
|
||||
content_disposition = header_value
|
||||
elif field == b"content-type":
|
||||
content_type = header_value
|
||||
header_field = b""
|
||||
header_value = b""
|
||||
elif message_type == MultiPartMessage.HEADERS_FINISHED:
|
||||
disposition, options = parse_options_header(content_disposition)
|
||||
field_name = _user_safe_decode(options[b"name"], charset)
|
||||
if b"filename" in options:
|
||||
filename = _user_safe_decode(options[b"filename"], charset)
|
||||
file = UploadFile(
|
||||
filename=filename,
|
||||
content_type=content_type.decode("latin-1"),
|
||||
)
|
||||
else:
|
||||
file = None
|
||||
elif message_type == MultiPartMessage.PART_DATA:
|
||||
if file is None:
|
||||
data += message_bytes
|
||||
else:
|
||||
await file.write(message_bytes)
|
||||
elif message_type == MultiPartMessage.PART_END:
|
||||
if file is None:
|
||||
items.append((field_name, _user_safe_decode(data, charset)))
|
||||
else:
|
||||
await file.seek(0)
|
||||
items.append((field_name, file))
|
||||
elif message_type == MultiPartMessage.END:
|
||||
pass
|
||||
|
||||
parser.finalize()
|
||||
return FormData(items)
|
278
.venv/lib/python3.9/site-packages/starlette/graphql.py
Normal file
278
.venv/lib/python3.9/site-packages/starlette/graphql.py
Normal file
@@ -0,0 +1,278 @@
|
||||
import json
|
||||
import typing
|
||||
|
||||
from starlette import status
|
||||
from starlette.background import BackgroundTasks
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse, Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
try:
|
||||
import graphene
|
||||
from graphql.error import GraphQLError, format_error as format_graphql_error
|
||||
from graphql.execution.executors.asyncio import AsyncioExecutor
|
||||
except ImportError: # pragma: nocover
|
||||
graphene = None
|
||||
AsyncioExecutor = None # type: ignore
|
||||
format_graphql_error = None # type: ignore
|
||||
GraphQLError = None # type: ignore
|
||||
|
||||
|
||||
class GraphQLApp:
|
||||
def __init__(
|
||||
self,
|
||||
schema: "graphene.Schema",
|
||||
executor: typing.Any = None,
|
||||
executor_class: type = None,
|
||||
graphiql: bool = True,
|
||||
) -> None:
|
||||
self.schema = schema
|
||||
self.graphiql = graphiql
|
||||
if executor is None:
|
||||
# New style in 0.10.0. Use 'executor_class'.
|
||||
# See issue https://github.com/encode/starlette/issues/242
|
||||
self.executor = executor
|
||||
self.executor_class = executor_class
|
||||
self.is_async = executor_class is not None and issubclass(
|
||||
executor_class, AsyncioExecutor
|
||||
)
|
||||
else:
|
||||
# Old style. Use 'executor'.
|
||||
# We should remove this in the next median/major version bump.
|
||||
self.executor = executor
|
||||
self.executor_class = None
|
||||
self.is_async = isinstance(executor, AsyncioExecutor)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.executor is None and self.executor_class is not None:
|
||||
self.executor = self.executor_class()
|
||||
|
||||
request = Request(scope, receive=receive)
|
||||
response = await self.handle_graphql(request)
|
||||
await response(scope, receive, send)
|
||||
|
||||
async def handle_graphql(self, request: Request) -> Response:
|
||||
if request.method in ("GET", "HEAD"):
|
||||
if "text/html" in request.headers.get("Accept", ""):
|
||||
if not self.graphiql:
|
||||
return PlainTextResponse(
|
||||
"Not Found", status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
return await self.handle_graphiql(request)
|
||||
|
||||
data = request.query_params # type: typing.Mapping[str, typing.Any]
|
||||
|
||||
elif request.method == "POST":
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
if "application/json" in content_type:
|
||||
data = await request.json()
|
||||
elif "application/graphql" in content_type:
|
||||
body = await request.body()
|
||||
text = body.decode()
|
||||
data = {"query": text}
|
||||
elif "query" in request.query_params:
|
||||
data = request.query_params
|
||||
else:
|
||||
return PlainTextResponse(
|
||||
"Unsupported Media Type",
|
||||
status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
|
||||
)
|
||||
|
||||
else:
|
||||
return PlainTextResponse(
|
||||
"Method Not Allowed", status_code=status.HTTP_405_METHOD_NOT_ALLOWED
|
||||
)
|
||||
|
||||
try:
|
||||
query = data["query"]
|
||||
variables = data.get("variables")
|
||||
operation_name = data.get("operationName")
|
||||
except KeyError:
|
||||
return PlainTextResponse(
|
||||
"No GraphQL query found in the request",
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
background = BackgroundTasks()
|
||||
context = {"request": request, "background": background}
|
||||
|
||||
result = await self.execute(
|
||||
query, variables=variables, context=context, operation_name=operation_name
|
||||
)
|
||||
error_data = (
|
||||
[format_graphql_error(err) for err in result.errors]
|
||||
if result.errors
|
||||
else None
|
||||
)
|
||||
response_data = {"data": result.data}
|
||||
if error_data:
|
||||
response_data["errors"] = error_data
|
||||
status_code = (
|
||||
status.HTTP_400_BAD_REQUEST if result.errors else status.HTTP_200_OK
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
response_data, status_code=status_code, background=background
|
||||
)
|
||||
|
||||
async def execute( # type: ignore
|
||||
self, query, variables=None, context=None, operation_name=None
|
||||
):
|
||||
if self.is_async:
|
||||
return await self.schema.execute(
|
||||
query,
|
||||
variables=variables,
|
||||
operation_name=operation_name,
|
||||
executor=self.executor,
|
||||
return_promise=True,
|
||||
context=context,
|
||||
)
|
||||
else:
|
||||
return await run_in_threadpool(
|
||||
self.schema.execute,
|
||||
query,
|
||||
variables=variables,
|
||||
operation_name=operation_name,
|
||||
context=context,
|
||||
)
|
||||
|
||||
async def handle_graphiql(self, request: Request) -> Response:
|
||||
text = GRAPHIQL.replace("{{REQUEST_PATH}}", json.dumps(request.url.path))
|
||||
return HTMLResponse(text)
|
||||
|
||||
|
||||
GRAPHIQL = """
|
||||
<!--
|
||||
* Copyright (c) Facebook, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
-->
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
body {
|
||||
height: 100%;
|
||||
margin: 0;
|
||||
width: 100%;
|
||||
overflow: hidden;
|
||||
}
|
||||
#graphiql {
|
||||
height: 100vh;
|
||||
}
|
||||
</style>
|
||||
<!--
|
||||
This GraphiQL example depends on Promise and fetch, which are available in
|
||||
modern browsers, but can be "polyfilled" for older browsers.
|
||||
GraphiQL itself depends on React DOM.
|
||||
If you do not want to rely on a CDN, you can host these files locally or
|
||||
include them directly in your favored resource bunder.
|
||||
-->
|
||||
<link href="//cdn.jsdelivr.net/npm/graphiql@0.12.0/graphiql.css" rel="stylesheet"/>
|
||||
<script src="//cdn.jsdelivr.net/npm/whatwg-fetch@2.0.3/fetch.min.js"></script>
|
||||
<script src="//cdn.jsdelivr.net/npm/react@16.2.0/umd/react.production.min.js"></script>
|
||||
<script src="//cdn.jsdelivr.net/npm/react-dom@16.2.0/umd/react-dom.production.min.js"></script>
|
||||
<script src="//cdn.jsdelivr.net/npm/graphiql@0.12.0/graphiql.min.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
<div id="graphiql">Loading...</div>
|
||||
<script>
|
||||
/**
|
||||
* This GraphiQL example illustrates how to use some of GraphiQL's props
|
||||
* in order to enable reading and updating the URL parameters, making
|
||||
* link sharing of queries a little bit easier.
|
||||
*
|
||||
* This is only one example of this kind of feature, GraphiQL exposes
|
||||
* various React params to enable interesting integrations.
|
||||
*/
|
||||
// Parse the search string to get url parameters.
|
||||
var search = window.location.search;
|
||||
var parameters = {};
|
||||
search.substr(1).split('&').forEach(function (entry) {
|
||||
var eq = entry.indexOf('=');
|
||||
if (eq >= 0) {
|
||||
parameters[decodeURIComponent(entry.slice(0, eq))] =
|
||||
decodeURIComponent(entry.slice(eq + 1));
|
||||
}
|
||||
});
|
||||
// if variables was provided, try to format it.
|
||||
if (parameters.variables) {
|
||||
try {
|
||||
parameters.variables =
|
||||
JSON.stringify(JSON.parse(parameters.variables), null, 2);
|
||||
} catch (e) {
|
||||
// Do nothing, we want to display the invalid JSON as a string, rather
|
||||
// than present an error.
|
||||
}
|
||||
}
|
||||
// When the query and variables string is edited, update the URL bar so
|
||||
// that it can be easily shared
|
||||
function onEditQuery(newQuery) {
|
||||
parameters.query = newQuery;
|
||||
updateURL();
|
||||
}
|
||||
function onEditVariables(newVariables) {
|
||||
parameters.variables = newVariables;
|
||||
updateURL();
|
||||
}
|
||||
function onEditOperationName(newOperationName) {
|
||||
parameters.operationName = newOperationName;
|
||||
updateURL();
|
||||
}
|
||||
function updateURL() {
|
||||
var newSearch = '?' + Object.keys(parameters).filter(function (key) {
|
||||
return Boolean(parameters[key]);
|
||||
}).map(function (key) {
|
||||
return encodeURIComponent(key) + '=' +
|
||||
encodeURIComponent(parameters[key]);
|
||||
}).join('&');
|
||||
history.replaceState(null, null, newSearch);
|
||||
}
|
||||
// Defines a GraphQL fetcher using the fetch API. You're not required to
|
||||
// use fetch, and could instead implement graphQLFetcher however you like,
|
||||
// as long as it returns a Promise or Observable.
|
||||
function graphQLFetcher(graphQLParams) {
|
||||
// This example expects a GraphQL server at the path /graphql.
|
||||
// Change this to point wherever you host your GraphQL server.
|
||||
return fetch({{REQUEST_PATH}}, {
|
||||
method: 'post',
|
||||
headers: {
|
||||
'Accept': 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(graphQLParams),
|
||||
credentials: 'include',
|
||||
}).then(function (response) {
|
||||
return response.text();
|
||||
}).then(function (responseBody) {
|
||||
try {
|
||||
return JSON.parse(responseBody);
|
||||
} catch (error) {
|
||||
return responseBody;
|
||||
}
|
||||
});
|
||||
}
|
||||
// Render <GraphiQL /> into the body.
|
||||
// See the README in the top level of this module to learn more about
|
||||
// how you can customize GraphiQL by providing different values or
|
||||
// additional child elements.
|
||||
ReactDOM.render(
|
||||
React.createElement(GraphiQL, {
|
||||
fetcher: graphQLFetcher,
|
||||
query: parameters.query,
|
||||
variables: parameters.variables,
|
||||
operationName: parameters.operationName,
|
||||
onEditQuery: onEditQuery,
|
||||
onEditVariables: onEditVariables,
|
||||
onEditOperationName: onEditOperationName
|
||||
}),
|
||||
document.getElementById('graphiql')
|
||||
);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
@@ -0,0 +1,17 @@
|
||||
import typing
|
||||
|
||||
|
||||
class Middleware:
|
||||
def __init__(self, cls: type, **options: typing.Any) -> None:
|
||||
self.cls = cls
|
||||
self.options = options
|
||||
|
||||
def __iter__(self) -> typing.Iterator:
|
||||
as_tuple = (self.cls, self.options)
|
||||
return iter(as_tuple)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
option_strings = [f"{key}={value!r}" for key, value in self.options.items()]
|
||||
args_repr = ", ".join([self.cls.__name__] + option_strings)
|
||||
return f"{class_name}({args_repr})"
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,52 @@
|
||||
import typing
|
||||
|
||||
from starlette.authentication import (
|
||||
AuthCredentials,
|
||||
AuthenticationBackend,
|
||||
AuthenticationError,
|
||||
UnauthenticatedUser,
|
||||
)
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
backend: AuthenticationBackend,
|
||||
on_error: typing.Callable[
|
||||
[HTTPConnection, AuthenticationError], Response
|
||||
] = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.backend = backend
|
||||
self.on_error = (
|
||||
on_error if on_error is not None else self.default_on_error
|
||||
) # type: typing.Callable[[HTTPConnection, AuthenticationError], Response]
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ["http", "websocket"]:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
conn = HTTPConnection(scope)
|
||||
try:
|
||||
auth_result = await self.backend.authenticate(conn)
|
||||
except AuthenticationError as exc:
|
||||
response = self.on_error(conn, exc)
|
||||
if scope["type"] == "websocket":
|
||||
await send({"type": "websocket.close", "code": 1000})
|
||||
else:
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
if auth_result is None:
|
||||
auth_result = AuthCredentials(), UnauthenticatedUser()
|
||||
scope["auth"], scope["user"] = auth_result
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
@staticmethod
|
||||
def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
|
||||
return PlainTextResponse(str(exc), status_code=400)
|
@@ -0,0 +1,67 @@
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
|
||||
DispatchFunction = typing.Callable[
|
||||
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
|
||||
]
|
||||
|
||||
|
||||
class BaseHTTPMiddleware:
|
||||
def __init__(self, app: ASGIApp, dispatch: DispatchFunction = None) -> None:
|
||||
self.app = app
|
||||
self.dispatch_func = self.dispatch if dispatch is None else dispatch
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = Request(scope, receive=receive)
|
||||
response = await self.dispatch_func(request, self.call_next)
|
||||
await response(scope, receive, send)
|
||||
|
||||
async def call_next(self, request: Request) -> Response:
|
||||
loop = asyncio.get_event_loop()
|
||||
queue = asyncio.Queue() # type: asyncio.Queue
|
||||
|
||||
scope = request.scope
|
||||
receive = request.receive
|
||||
send = queue.put
|
||||
|
||||
async def coro() -> None:
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
finally:
|
||||
await queue.put(None)
|
||||
|
||||
task = loop.create_task(coro())
|
||||
message = await queue.get()
|
||||
if message is None:
|
||||
task.result()
|
||||
raise RuntimeError("No response returned.")
|
||||
assert message["type"] == "http.response.start"
|
||||
|
||||
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
|
||||
while True:
|
||||
message = await queue.get()
|
||||
if message is None:
|
||||
break
|
||||
assert message["type"] == "http.response.body"
|
||||
yield message.get("body", b"")
|
||||
task.result()
|
||||
|
||||
response = StreamingResponse(
|
||||
status_code=message["status"], content=body_stream()
|
||||
)
|
||||
response.raw_headers = message["headers"]
|
||||
return response
|
||||
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
raise NotImplementedError() # pragma: no cover
|
167
.venv/lib/python3.9/site-packages/starlette/middleware/cors.py
Normal file
167
.venv/lib/python3.9/site-packages/starlette/middleware/cors.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import functools
|
||||
import re
|
||||
import typing
|
||||
|
||||
from starlette.datastructures import Headers, MutableHeaders
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
ALL_METHODS = ("DELETE", "GET", "OPTIONS", "PATCH", "POST", "PUT")
|
||||
SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"}
|
||||
|
||||
|
||||
class CORSMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allow_origins: typing.Sequence[str] = (),
|
||||
allow_methods: typing.Sequence[str] = ("GET",),
|
||||
allow_headers: typing.Sequence[str] = (),
|
||||
allow_credentials: bool = False,
|
||||
allow_origin_regex: str = None,
|
||||
expose_headers: typing.Sequence[str] = (),
|
||||
max_age: int = 600,
|
||||
) -> None:
|
||||
|
||||
if "*" in allow_methods:
|
||||
allow_methods = ALL_METHODS
|
||||
|
||||
compiled_allow_origin_regex = None
|
||||
if allow_origin_regex is not None:
|
||||
compiled_allow_origin_regex = re.compile(allow_origin_regex)
|
||||
|
||||
simple_headers = {}
|
||||
if "*" in allow_origins:
|
||||
simple_headers["Access-Control-Allow-Origin"] = "*"
|
||||
if allow_credentials:
|
||||
simple_headers["Access-Control-Allow-Credentials"] = "true"
|
||||
if expose_headers:
|
||||
simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)
|
||||
|
||||
preflight_headers = {}
|
||||
if "*" in allow_origins:
|
||||
preflight_headers["Access-Control-Allow-Origin"] = "*"
|
||||
else:
|
||||
preflight_headers["Vary"] = "Origin"
|
||||
preflight_headers.update(
|
||||
{
|
||||
"Access-Control-Allow-Methods": ", ".join(allow_methods),
|
||||
"Access-Control-Max-Age": str(max_age),
|
||||
}
|
||||
)
|
||||
allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
|
||||
if allow_headers and "*" not in allow_headers:
|
||||
preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
|
||||
if allow_credentials:
|
||||
preflight_headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
||||
self.app = app
|
||||
self.allow_origins = allow_origins
|
||||
self.allow_methods = allow_methods
|
||||
self.allow_headers = [h.lower() for h in allow_headers]
|
||||
self.allow_all_origins = "*" in allow_origins
|
||||
self.allow_all_headers = "*" in allow_headers
|
||||
self.allow_origin_regex = compiled_allow_origin_regex
|
||||
self.simple_headers = simple_headers
|
||||
self.preflight_headers = preflight_headers
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http": # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
method = scope["method"]
|
||||
headers = Headers(scope=scope)
|
||||
origin = headers.get("origin")
|
||||
|
||||
if origin is None:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
if method == "OPTIONS" and "access-control-request-method" in headers:
|
||||
response = self.preflight_response(request_headers=headers)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
await self.simple_response(scope, receive, send, request_headers=headers)
|
||||
|
||||
def is_allowed_origin(self, origin: str) -> bool:
|
||||
if self.allow_all_origins:
|
||||
return True
|
||||
|
||||
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
|
||||
origin
|
||||
):
|
||||
return True
|
||||
|
||||
return origin in self.allow_origins
|
||||
|
||||
def preflight_response(self, request_headers: Headers) -> Response:
|
||||
requested_origin = request_headers["origin"]
|
||||
requested_method = request_headers["access-control-request-method"]
|
||||
requested_headers = request_headers.get("access-control-request-headers")
|
||||
|
||||
headers = dict(self.preflight_headers)
|
||||
failures = []
|
||||
|
||||
if self.is_allowed_origin(origin=requested_origin):
|
||||
if not self.allow_all_origins:
|
||||
# If self.allow_all_origins is True, then the "Access-Control-Allow-Origin"
|
||||
# header is already set to "*".
|
||||
# If we only allow specific origins, then we have to mirror back
|
||||
# the Origin header in the response.
|
||||
headers["Access-Control-Allow-Origin"] = requested_origin
|
||||
else:
|
||||
failures.append("origin")
|
||||
|
||||
if requested_method not in self.allow_methods:
|
||||
failures.append("method")
|
||||
|
||||
# If we allow all headers, then we have to mirror back any requested
|
||||
# headers in the response.
|
||||
if self.allow_all_headers and requested_headers is not None:
|
||||
headers["Access-Control-Allow-Headers"] = requested_headers
|
||||
elif requested_headers is not None:
|
||||
for header in [h.lower() for h in requested_headers.split(",")]:
|
||||
if header.strip() not in self.allow_headers:
|
||||
failures.append("headers")
|
||||
|
||||
# We don't strictly need to use 400 responses here, since its up to
|
||||
# the browser to enforce the CORS policy, but its more informative
|
||||
# if we do.
|
||||
if failures:
|
||||
failure_text = "Disallowed CORS " + ", ".join(failures)
|
||||
return PlainTextResponse(failure_text, status_code=400, headers=headers)
|
||||
|
||||
return PlainTextResponse("OK", status_code=200, headers=headers)
|
||||
|
||||
async def simple_response(
|
||||
self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
send = functools.partial(self.send, send=send, request_headers=request_headers)
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def send(
|
||||
self, message: Message, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
if message["type"] != "http.response.start":
|
||||
await send(message)
|
||||
return
|
||||
|
||||
message.setdefault("headers", [])
|
||||
headers = MutableHeaders(scope=message)
|
||||
headers.update(self.simple_headers)
|
||||
origin = request_headers["Origin"]
|
||||
has_cookie = "cookie" in request_headers
|
||||
|
||||
# If request includes any cookie headers, then we must respond
|
||||
# with the specific origin instead of '*'.
|
||||
if self.allow_all_origins and has_cookie:
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
|
||||
# If we only allow specific origins, then we have to mirror back
|
||||
# the Origin header in the response.
|
||||
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
headers.add_vary_header("Origin")
|
||||
await send(message)
|
246
.venv/lib/python3.9/site-packages/starlette/middleware/errors.py
Normal file
246
.venv/lib/python3.9/site-packages/starlette/middleware/errors.py
Normal file
@@ -0,0 +1,246 @@
|
||||
import asyncio
|
||||
import html
|
||||
import inspect
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
STYLES = """
|
||||
p {
|
||||
color: #211c1c;
|
||||
}
|
||||
.traceback-container {
|
||||
border: 1px solid #038BB8;
|
||||
}
|
||||
.traceback-title {
|
||||
background-color: #038BB8;
|
||||
color: lemonchiffon;
|
||||
padding: 12px;
|
||||
font-size: 20px;
|
||||
margin-top: 0px;
|
||||
}
|
||||
.frame-line {
|
||||
padding-left: 10px;
|
||||
font-family: monospace;
|
||||
}
|
||||
.frame-filename {
|
||||
font-family: monospace;
|
||||
}
|
||||
.center-line {
|
||||
background-color: #038BB8;
|
||||
color: #f9f6e1;
|
||||
padding: 5px 0px 5px 5px;
|
||||
}
|
||||
.lineno {
|
||||
margin-right: 5px;
|
||||
}
|
||||
.frame-title {
|
||||
font-weight: unset;
|
||||
padding: 10px 10px 10px 10px;
|
||||
background-color: #E4F4FD;
|
||||
margin-right: 10px;
|
||||
color: #191f21;
|
||||
font-size: 17px;
|
||||
border: 1px solid #c7dce8;
|
||||
}
|
||||
.collapse-btn {
|
||||
float: right;
|
||||
padding: 0px 5px 1px 5px;
|
||||
border: solid 1px #96aebb;
|
||||
cursor: pointer;
|
||||
}
|
||||
.collapsed {
|
||||
display: none;
|
||||
}
|
||||
.source-code {
|
||||
font-family: courier;
|
||||
font-size: small;
|
||||
padding-bottom: 10px;
|
||||
}
|
||||
"""
|
||||
|
||||
JS = """
|
||||
<script type="text/javascript">
|
||||
function collapse(element){
|
||||
const frameId = element.getAttribute("data-frame-id");
|
||||
const frame = document.getElementById(frameId);
|
||||
|
||||
if (frame.classList.contains("collapsed")){
|
||||
element.innerHTML = "‒";
|
||||
frame.classList.remove("collapsed");
|
||||
} else {
|
||||
element.innerHTML = "+";
|
||||
frame.classList.add("collapsed");
|
||||
}
|
||||
}
|
||||
</script>
|
||||
"""
|
||||
|
||||
TEMPLATE = """
|
||||
<html>
|
||||
<head>
|
||||
<style type='text/css'>
|
||||
{styles}
|
||||
</style>
|
||||
<title>Starlette Debugger</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>500 Server Error</h1>
|
||||
<h2>{error}</h2>
|
||||
<div class="traceback-container">
|
||||
<p class="traceback-title">Traceback</p>
|
||||
<div>{exc_html}</div>
|
||||
</div>
|
||||
{js}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
FRAME_TEMPLATE = """
|
||||
<div>
|
||||
<p class="frame-title">File <span class="frame-filename">{frame_filename}</span>,
|
||||
line <i>{frame_lineno}</i>,
|
||||
in <b>{frame_name}</b>
|
||||
<span class="collapse-btn" data-frame-id="{frame_filename}-{frame_lineno}" onclick="collapse(this)">{collapse_button}</span>
|
||||
</p>
|
||||
<div id="{frame_filename}-{frame_lineno}" class="source-code {collapsed}">{code_context}</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
LINE = """
|
||||
<p><span class="frame-line">
|
||||
<span class="lineno">{lineno}.</span> {line}</span></p>
|
||||
"""
|
||||
|
||||
CENTER_LINE = """
|
||||
<p class="center-line"><span class="frame-line center-line">
|
||||
<span class="lineno">{lineno}.</span> {line}</span></p>
|
||||
"""
|
||||
|
||||
|
||||
class ServerErrorMiddleware:
|
||||
"""
|
||||
Handles returning 500 responses when a server error occurs.
|
||||
|
||||
If 'debug' is set, then traceback responses will be returned,
|
||||
otherwise the designated 'handler' will be called.
|
||||
|
||||
This middleware class should generally be used to wrap *everything*
|
||||
else up, so that unhandled exceptions anywhere in the stack
|
||||
always result in an appropriate 500 response.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, app: ASGIApp, handler: typing.Callable = None, debug: bool = False
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.handler = handler
|
||||
self.debug = debug
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
response_started = False
|
||||
|
||||
async def _send(message: Message) -> None:
|
||||
nonlocal response_started, send
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_started = True
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, _send)
|
||||
except Exception as exc:
|
||||
if not response_started:
|
||||
request = Request(scope)
|
||||
if self.debug:
|
||||
# In debug mode, return traceback responses.
|
||||
response = self.debug_response(request, exc)
|
||||
elif self.handler is None:
|
||||
# Use our default 500 error handler.
|
||||
response = self.error_response(request, exc)
|
||||
else:
|
||||
# Use an installed 500 error handler.
|
||||
if asyncio.iscoroutinefunction(self.handler):
|
||||
response = await self.handler(request, exc)
|
||||
else:
|
||||
response = await run_in_threadpool(self.handler, request, exc)
|
||||
|
||||
await response(scope, receive, send)
|
||||
|
||||
# We always continue to raise the exception.
|
||||
# This allows servers to log the error, or allows test clients
|
||||
# to optionally raise the error within the test case.
|
||||
raise exc from None
|
||||
|
||||
def format_line(
|
||||
self, index: int, line: str, frame_lineno: int, frame_index: int
|
||||
) -> str:
|
||||
values = {
|
||||
# HTML escape - line could contain < or >
|
||||
"line": html.escape(line).replace(" ", " "),
|
||||
"lineno": (frame_lineno - frame_index) + index,
|
||||
}
|
||||
|
||||
if index != frame_index:
|
||||
return LINE.format(**values)
|
||||
return CENTER_LINE.format(**values)
|
||||
|
||||
def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str:
|
||||
code_context = "".join(
|
||||
self.format_line(index, line, frame.lineno, frame.index) # type: ignore
|
||||
for index, line in enumerate(frame.code_context or [])
|
||||
)
|
||||
|
||||
values = {
|
||||
# HTML escape - filename could contain < or >, especially if it's a virtual file e.g. <stdin> in the REPL
|
||||
"frame_filename": html.escape(frame.filename),
|
||||
"frame_lineno": frame.lineno,
|
||||
# HTML escape - if you try very hard it's possible to name a function with < or >
|
||||
"frame_name": html.escape(frame.function),
|
||||
"code_context": code_context,
|
||||
"collapsed": "collapsed" if is_collapsed else "",
|
||||
"collapse_button": "+" if is_collapsed else "‒",
|
||||
}
|
||||
return FRAME_TEMPLATE.format(**values)
|
||||
|
||||
def generate_html(self, exc: Exception, limit: int = 7) -> str:
|
||||
traceback_obj = traceback.TracebackException.from_exception(
|
||||
exc, capture_locals=True
|
||||
)
|
||||
frames = inspect.getinnerframes(
|
||||
traceback_obj.exc_traceback, limit # type: ignore
|
||||
)
|
||||
|
||||
exc_html = ""
|
||||
is_collapsed = False
|
||||
for frame in reversed(frames):
|
||||
exc_html += self.generate_frame_html(frame, is_collapsed)
|
||||
is_collapsed = True
|
||||
|
||||
# escape error class and text
|
||||
error = f"{html.escape(traceback_obj.exc_type.__name__)}: {html.escape(str(traceback_obj))}"
|
||||
|
||||
return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)
|
||||
|
||||
def generate_plain_text(self, exc: Exception) -> str:
|
||||
return "".join(traceback.format_tb(exc.__traceback__))
|
||||
|
||||
def debug_response(self, request: Request, exc: Exception) -> Response:
|
||||
accept = request.headers.get("accept", "")
|
||||
|
||||
if "text/html" in accept:
|
||||
content = self.generate_html(exc)
|
||||
return HTMLResponse(content, status_code=500)
|
||||
content = self.generate_plain_text(exc)
|
||||
return PlainTextResponse(content, status_code=500)
|
||||
|
||||
def error_response(self, request: Request, exc: Exception) -> Response:
|
||||
return PlainTextResponse("Internal Server Error", status_code=500)
|
@@ -0,0 +1,97 @@
|
||||
import gzip
|
||||
import io
|
||||
|
||||
from starlette.datastructures import Headers, MutableHeaders
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
|
||||
class GZipMiddleware:
|
||||
def __init__(self, app: ASGIApp, minimum_size: int = 500) -> None:
|
||||
self.app = app
|
||||
self.minimum_size = minimum_size
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] == "http":
|
||||
headers = Headers(scope=scope)
|
||||
if "gzip" in headers.get("Accept-Encoding", ""):
|
||||
responder = GZipResponder(self.app, self.minimum_size)
|
||||
await responder(scope, receive, send)
|
||||
return
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
|
||||
class GZipResponder:
|
||||
def __init__(self, app: ASGIApp, minimum_size: int) -> None:
|
||||
self.app = app
|
||||
self.minimum_size = minimum_size
|
||||
self.send = unattached_send # type: Send
|
||||
self.initial_message = {} # type: Message
|
||||
self.started = False
|
||||
self.gzip_buffer = io.BytesIO()
|
||||
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
self.send = send
|
||||
await self.app(scope, receive, self.send_with_gzip)
|
||||
|
||||
async def send_with_gzip(self, message: Message) -> None:
|
||||
message_type = message["type"]
|
||||
if message_type == "http.response.start":
|
||||
# Don't send the initial message until we've determined how to
|
||||
# modify the ougoging headers correctly.
|
||||
self.initial_message = message
|
||||
elif message_type == "http.response.body" and not self.started:
|
||||
self.started = True
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if len(body) < self.minimum_size and not more_body:
|
||||
# Don't apply GZip to small outgoing responses.
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
elif not more_body:
|
||||
# Standard GZip response.
|
||||
self.gzip_file.write(body)
|
||||
self.gzip_file.close()
|
||||
body = self.gzip_buffer.getvalue()
|
||||
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers["Content-Encoding"] = "gzip"
|
||||
headers["Content-Length"] = str(len(body))
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
message["body"] = body
|
||||
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
else:
|
||||
# Initial body in streaming GZip response.
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers["Content-Encoding"] = "gzip"
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
del headers["Content-Length"]
|
||||
|
||||
self.gzip_file.write(body)
|
||||
message["body"] = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
|
||||
elif message_type == "http.response.body":
|
||||
# Remaining body in streaming GZip response.
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
self.gzip_file.write(body)
|
||||
if not more_body:
|
||||
self.gzip_file.close()
|
||||
|
||||
message["body"] = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
|
||||
await self.send(message)
|
||||
|
||||
|
||||
async def unattached_send(message: Message) -> None:
|
||||
raise RuntimeError("send awaitable not set") # pragma: no cover
|
@@ -0,0 +1,19 @@
|
||||
from starlette.datastructures import URL
|
||||
from starlette.responses import RedirectResponse
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
class HTTPSRedirectMiddleware:
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] in ("http", "websocket") and scope["scheme"] in ("http", "ws"):
|
||||
url = URL(scope=scope)
|
||||
redirect_scheme = {"http": "https", "ws": "wss"}[url.scheme]
|
||||
netloc = url.hostname if url.port in (80, 443) else url.netloc
|
||||
url = url.replace(scheme=redirect_scheme, netloc=netloc)
|
||||
response = RedirectResponse(url, status_code=307)
|
||||
await response(scope, receive, send)
|
||||
else:
|
||||
await self.app(scope, receive, send)
|
@@ -0,0 +1,75 @@
|
||||
import json
|
||||
import typing
|
||||
from base64 import b64decode, b64encode
|
||||
|
||||
import itsdangerous
|
||||
from itsdangerous.exc import BadTimeSignature, SignatureExpired
|
||||
|
||||
from starlette.datastructures import MutableHeaders, Secret
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
|
||||
class SessionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
secret_key: typing.Union[str, Secret],
|
||||
session_cookie: str = "session",
|
||||
max_age: int = 14 * 24 * 60 * 60, # 14 days, in seconds
|
||||
same_site: str = "lax",
|
||||
https_only: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.signer = itsdangerous.TimestampSigner(str(secret_key))
|
||||
self.session_cookie = session_cookie
|
||||
self.max_age = max_age
|
||||
self.security_flags = "httponly; samesite=" + same_site
|
||||
if https_only: # Secure flag can be used with HTTPS only
|
||||
self.security_flags += "; secure"
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"): # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
connection = HTTPConnection(scope)
|
||||
initial_session_was_empty = True
|
||||
|
||||
if self.session_cookie in connection.cookies:
|
||||
data = connection.cookies[self.session_cookie].encode("utf-8")
|
||||
try:
|
||||
data = self.signer.unsign(data, max_age=self.max_age)
|
||||
scope["session"] = json.loads(b64decode(data))
|
||||
initial_session_was_empty = False
|
||||
except (BadTimeSignature, SignatureExpired):
|
||||
scope["session"] = {}
|
||||
else:
|
||||
scope["session"] = {}
|
||||
|
||||
async def send_wrapper(message: Message) -> None:
|
||||
if message["type"] == "http.response.start":
|
||||
if scope["session"]:
|
||||
# We have session data to persist.
|
||||
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
|
||||
data = self.signer.sign(data)
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "%s=%s; path=/; Max-Age=%d; %s" % (
|
||||
self.session_cookie,
|
||||
data.decode("utf-8"),
|
||||
self.max_age,
|
||||
self.security_flags,
|
||||
)
|
||||
headers.append("Set-Cookie", header_value)
|
||||
elif not initial_session_was_empty:
|
||||
# The session has been cleared.
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "%s=%s; %s" % (
|
||||
self.session_cookie,
|
||||
"null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT;",
|
||||
self.security_flags,
|
||||
)
|
||||
headers.append("Set-Cookie", header_value)
|
||||
await send(message)
|
||||
|
||||
await self.app(scope, receive, send_wrapper)
|
@@ -0,0 +1,59 @@
|
||||
import typing
|
||||
|
||||
from starlette.datastructures import URL, Headers
|
||||
from starlette.responses import PlainTextResponse, RedirectResponse, Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
|
||||
|
||||
|
||||
class TrustedHostMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allowed_hosts: typing.Sequence[str] = None,
|
||||
www_redirect: bool = True,
|
||||
) -> None:
|
||||
if allowed_hosts is None:
|
||||
allowed_hosts = ["*"]
|
||||
|
||||
for pattern in allowed_hosts:
|
||||
assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
|
||||
if pattern.startswith("*") and pattern != "*":
|
||||
assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
|
||||
self.app = app
|
||||
self.allowed_hosts = list(allowed_hosts)
|
||||
self.allow_any = "*" in allowed_hosts
|
||||
self.www_redirect = www_redirect
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.allow_any or scope["type"] not in (
|
||||
"http",
|
||||
"websocket",
|
||||
): # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
headers = Headers(scope=scope)
|
||||
host = headers.get("host", "").split(":")[0]
|
||||
is_valid_host = False
|
||||
found_www_redirect = False
|
||||
for pattern in self.allowed_hosts:
|
||||
if host == pattern or (
|
||||
pattern.startswith("*") and host.endswith(pattern[1:])
|
||||
):
|
||||
is_valid_host = True
|
||||
break
|
||||
elif "www." + host == pattern:
|
||||
found_www_redirect = True
|
||||
|
||||
if is_valid_host:
|
||||
await self.app(scope, receive, send)
|
||||
else:
|
||||
if found_www_redirect and self.www_redirect:
|
||||
url = URL(scope=scope)
|
||||
redirect_url = url.replace(netloc="www." + url.netloc)
|
||||
response = RedirectResponse(url=str(redirect_url)) # type: Response
|
||||
else:
|
||||
response = PlainTextResponse("Invalid host header", status_code=400)
|
||||
await response(scope, receive, send)
|
143
.venv/lib/python3.9/site-packages/starlette/middleware/wsgi.py
Normal file
143
.venv/lib/python3.9/site-packages/starlette/middleware/wsgi.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import asyncio
|
||||
import io
|
||||
import sys
|
||||
import typing
|
||||
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
|
||||
|
||||
def build_environ(scope: Scope, body: bytes) -> dict:
|
||||
"""
|
||||
Builds a scope and request body into a WSGI environ object.
|
||||
"""
|
||||
environ = {
|
||||
"REQUEST_METHOD": scope["method"],
|
||||
"SCRIPT_NAME": scope.get("root_path", "").encode("utf8").decode("latin1"),
|
||||
"PATH_INFO": scope["path"].encode("utf8").decode("latin1"),
|
||||
"QUERY_STRING": scope["query_string"].decode("ascii"),
|
||||
"SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
|
||||
"wsgi.version": (1, 0),
|
||||
"wsgi.url_scheme": scope.get("scheme", "http"),
|
||||
"wsgi.input": io.BytesIO(body),
|
||||
"wsgi.errors": sys.stdout,
|
||||
"wsgi.multithread": True,
|
||||
"wsgi.multiprocess": True,
|
||||
"wsgi.run_once": False,
|
||||
}
|
||||
|
||||
# Get server name and port - required in WSGI, not in ASGI
|
||||
server = scope.get("server") or ("localhost", 80)
|
||||
environ["SERVER_NAME"] = server[0]
|
||||
environ["SERVER_PORT"] = server[1]
|
||||
|
||||
# Get client IP address
|
||||
if scope.get("client"):
|
||||
environ["REMOTE_ADDR"] = scope["client"][0]
|
||||
|
||||
# Go through headers and make them into environ entries
|
||||
for name, value in scope.get("headers", []):
|
||||
name = name.decode("latin1")
|
||||
if name == "content-length":
|
||||
corrected_name = "CONTENT_LENGTH"
|
||||
elif name == "content-type":
|
||||
corrected_name = "CONTENT_TYPE"
|
||||
else:
|
||||
corrected_name = f"HTTP_{name}".upper().replace("-", "_")
|
||||
# HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in case
|
||||
value = value.decode("latin1")
|
||||
if corrected_name in environ:
|
||||
value = environ[corrected_name] + "," + value
|
||||
environ[corrected_name] = value
|
||||
return environ
|
||||
|
||||
|
||||
class WSGIMiddleware:
|
||||
def __init__(self, app: typing.Callable, workers: int = 10) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
assert scope["type"] == "http"
|
||||
responder = WSGIResponder(self.app, scope)
|
||||
await responder(receive, send)
|
||||
|
||||
|
||||
class WSGIResponder:
|
||||
def __init__(self, app: typing.Callable, scope: Scope) -> None:
|
||||
self.app = app
|
||||
self.scope = scope
|
||||
self.status = None
|
||||
self.response_headers = None
|
||||
self.send_event = asyncio.Event()
|
||||
self.send_queue = [] # type: typing.List[typing.Optional[Message]]
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.response_started = False
|
||||
self.exc_info = None # type: typing.Any
|
||||
|
||||
async def __call__(self, receive: Receive, send: Send) -> None:
|
||||
body = b""
|
||||
more_body = True
|
||||
while more_body:
|
||||
message = await receive()
|
||||
body += message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
environ = build_environ(self.scope, body)
|
||||
sender = None
|
||||
try:
|
||||
sender = self.loop.create_task(self.sender(send))
|
||||
await run_in_threadpool(self.wsgi, environ, self.start_response)
|
||||
self.send_queue.append(None)
|
||||
self.send_event.set()
|
||||
await asyncio.wait_for(sender, None)
|
||||
if self.exc_info is not None:
|
||||
raise self.exc_info[0].with_traceback(
|
||||
self.exc_info[1], self.exc_info[2]
|
||||
)
|
||||
finally:
|
||||
if sender and not sender.done():
|
||||
sender.cancel() # pragma: no cover
|
||||
|
||||
async def sender(self, send: Send) -> None:
|
||||
while True:
|
||||
if self.send_queue:
|
||||
message = self.send_queue.pop(0)
|
||||
if message is None:
|
||||
return
|
||||
await send(message)
|
||||
else:
|
||||
await self.send_event.wait()
|
||||
self.send_event.clear()
|
||||
|
||||
def start_response(
|
||||
self,
|
||||
status: str,
|
||||
response_headers: typing.List[typing.Tuple[str, str]],
|
||||
exc_info: typing.Any = None,
|
||||
) -> None:
|
||||
self.exc_info = exc_info
|
||||
if not self.response_started:
|
||||
self.response_started = True
|
||||
status_code_string, _ = status.split(" ", 1)
|
||||
status_code = int(status_code_string)
|
||||
headers = [
|
||||
(name.strip().encode("ascii").lower(), value.strip().encode("ascii"))
|
||||
for name, value in response_headers
|
||||
]
|
||||
self.send_queue.append(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": status_code,
|
||||
"headers": headers,
|
||||
}
|
||||
)
|
||||
self.loop.call_soon_threadsafe(self.send_event.set)
|
||||
|
||||
def wsgi(self, environ: dict, start_response: typing.Callable) -> None:
|
||||
for chunk in self.app(environ, start_response):
|
||||
self.send_queue.append(
|
||||
{"type": "http.response.body", "body": chunk, "more_body": True}
|
||||
)
|
||||
self.loop.call_soon_threadsafe(self.send_event.set)
|
||||
|
||||
self.send_queue.append({"type": "http.response.body", "body": b""})
|
||||
self.loop.call_soon_threadsafe(self.send_event.set)
|
273
.venv/lib/python3.9/site-packages/starlette/requests.py
Normal file
273
.venv/lib/python3.9/site-packages/starlette/requests.py
Normal file
@@ -0,0 +1,273 @@
|
||||
import asyncio
|
||||
import json
|
||||
import typing
|
||||
from collections.abc import Mapping
|
||||
from http import cookies as http_cookies
|
||||
|
||||
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
|
||||
from starlette.formparsers import FormParser, MultiPartParser
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
|
||||
try:
|
||||
from multipart.multipart import parse_options_header
|
||||
except ImportError: # pragma: nocover
|
||||
parse_options_header = None
|
||||
|
||||
|
||||
SERVER_PUSH_HEADERS_TO_COPY = {
|
||||
"accept",
|
||||
"accept-encoding",
|
||||
"accept-language",
|
||||
"cache-control",
|
||||
"user-agent",
|
||||
}
|
||||
|
||||
|
||||
def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
|
||||
"""
|
||||
This function parses a ``Cookie`` HTTP header into a dict of key/value pairs.
|
||||
|
||||
It attempts to mimic browser cookie parsing behavior: browsers and web servers
|
||||
frequently disregard the spec (RFC 6265) when setting and reading cookies,
|
||||
so we attempt to suit the common scenarios here.
|
||||
|
||||
This function has been adapted from Django 3.1.0.
|
||||
Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
|
||||
on an outdated spec and will fail on lots of input we want to support
|
||||
"""
|
||||
cookie_dict: typing.Dict[str, str] = {}
|
||||
for chunk in cookie_string.split(";"):
|
||||
if "=" in chunk:
|
||||
key, val = chunk.split("=", 1)
|
||||
else:
|
||||
# Assume an empty name per
|
||||
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
|
||||
key, val = "", chunk
|
||||
key, val = key.strip(), val.strip()
|
||||
if key or val:
|
||||
# unquote using Python's algorithm.
|
||||
cookie_dict[key] = http_cookies._unquote(val) # type: ignore
|
||||
return cookie_dict
|
||||
|
||||
|
||||
class ClientDisconnect(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class HTTPConnection(Mapping):
|
||||
"""
|
||||
A base class for incoming HTTP connections, that is used to provide
|
||||
any functionality that is common to both `Request` and `WebSocket`.
|
||||
"""
|
||||
|
||||
def __init__(self, scope: Scope, receive: Receive = None) -> None:
|
||||
assert scope["type"] in ("http", "websocket")
|
||||
self.scope = scope
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
return self.scope[key]
|
||||
|
||||
def __iter__(self) -> typing.Iterator[str]:
|
||||
return iter(self.scope)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.scope)
|
||||
|
||||
@property
|
||||
def app(self) -> typing.Any:
|
||||
return self.scope["app"]
|
||||
|
||||
@property
|
||||
def url(self) -> URL:
|
||||
if not hasattr(self, "_url"):
|
||||
self._url = URL(scope=self.scope)
|
||||
return self._url
|
||||
|
||||
@property
|
||||
def base_url(self) -> URL:
|
||||
if not hasattr(self, "_base_url"):
|
||||
base_url_scope = dict(self.scope)
|
||||
base_url_scope["path"] = "/"
|
||||
base_url_scope["query_string"] = b""
|
||||
base_url_scope["root_path"] = base_url_scope.get(
|
||||
"app_root_path", base_url_scope.get("root_path", "")
|
||||
)
|
||||
self._base_url = URL(scope=base_url_scope)
|
||||
return self._base_url
|
||||
|
||||
@property
|
||||
def headers(self) -> Headers:
|
||||
if not hasattr(self, "_headers"):
|
||||
self._headers = Headers(scope=self.scope)
|
||||
return self._headers
|
||||
|
||||
@property
|
||||
def query_params(self) -> QueryParams:
|
||||
if not hasattr(self, "_query_params"):
|
||||
self._query_params = QueryParams(self.scope["query_string"])
|
||||
return self._query_params
|
||||
|
||||
@property
|
||||
def path_params(self) -> dict:
|
||||
return self.scope.get("path_params", {})
|
||||
|
||||
@property
|
||||
def cookies(self) -> typing.Dict[str, str]:
|
||||
if not hasattr(self, "_cookies"):
|
||||
cookies: typing.Dict[str, str] = {}
|
||||
cookie_header = self.headers.get("cookie")
|
||||
|
||||
if cookie_header:
|
||||
cookies = cookie_parser(cookie_header)
|
||||
self._cookies = cookies
|
||||
return self._cookies
|
||||
|
||||
@property
|
||||
def client(self) -> Address:
|
||||
host, port = self.scope.get("client") or (None, None)
|
||||
return Address(host=host, port=port)
|
||||
|
||||
@property
|
||||
def session(self) -> dict:
|
||||
assert (
|
||||
"session" in self.scope
|
||||
), "SessionMiddleware must be installed to access request.session"
|
||||
return self.scope["session"]
|
||||
|
||||
@property
|
||||
def auth(self) -> typing.Any:
|
||||
assert (
|
||||
"auth" in self.scope
|
||||
), "AuthenticationMiddleware must be installed to access request.auth"
|
||||
return self.scope["auth"]
|
||||
|
||||
@property
|
||||
def user(self) -> typing.Any:
|
||||
assert (
|
||||
"user" in self.scope
|
||||
), "AuthenticationMiddleware must be installed to access request.user"
|
||||
return self.scope["user"]
|
||||
|
||||
@property
|
||||
def state(self) -> State:
|
||||
if not hasattr(self, "_state"):
|
||||
# Ensure 'state' has an empty dict if it's not already populated.
|
||||
self.scope.setdefault("state", {})
|
||||
# Create a state instance with a reference to the dict in which it should store info
|
||||
self._state = State(self.scope["state"])
|
||||
return self._state
|
||||
|
||||
def url_for(self, name: str, **path_params: typing.Any) -> str:
|
||||
router = self.scope["router"]
|
||||
url_path = router.url_path_for(name, **path_params)
|
||||
return url_path.make_absolute_url(base_url=self.base_url)
|
||||
|
||||
|
||||
async def empty_receive() -> Message:
|
||||
raise RuntimeError("Receive channel has not been made available")
|
||||
|
||||
|
||||
async def empty_send(message: Message) -> None:
|
||||
raise RuntimeError("Send channel has not been made available")
|
||||
|
||||
|
||||
class Request(HTTPConnection):
|
||||
def __init__(
|
||||
self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
|
||||
):
|
||||
super().__init__(scope)
|
||||
assert scope["type"] == "http"
|
||||
self._receive = receive
|
||||
self._send = send
|
||||
self._stream_consumed = False
|
||||
self._is_disconnected = False
|
||||
|
||||
@property
|
||||
def method(self) -> str:
|
||||
return self.scope["method"]
|
||||
|
||||
@property
|
||||
def receive(self) -> Receive:
|
||||
return self._receive
|
||||
|
||||
async def stream(self) -> typing.AsyncGenerator[bytes, None]:
|
||||
if hasattr(self, "_body"):
|
||||
yield self._body
|
||||
yield b""
|
||||
return
|
||||
|
||||
if self._stream_consumed:
|
||||
raise RuntimeError("Stream consumed")
|
||||
|
||||
self._stream_consumed = True
|
||||
while True:
|
||||
message = await self._receive()
|
||||
if message["type"] == "http.request":
|
||||
body = message.get("body", b"")
|
||||
if body:
|
||||
yield body
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
elif message["type"] == "http.disconnect":
|
||||
self._is_disconnected = True
|
||||
raise ClientDisconnect()
|
||||
yield b""
|
||||
|
||||
async def body(self) -> bytes:
|
||||
if not hasattr(self, "_body"):
|
||||
chunks = []
|
||||
async for chunk in self.stream():
|
||||
chunks.append(chunk)
|
||||
self._body = b"".join(chunks)
|
||||
return self._body
|
||||
|
||||
async def json(self) -> typing.Any:
|
||||
if not hasattr(self, "_json"):
|
||||
body = await self.body()
|
||||
self._json = json.loads(body)
|
||||
return self._json
|
||||
|
||||
async def form(self) -> FormData:
|
||||
if not hasattr(self, "_form"):
|
||||
assert (
|
||||
parse_options_header is not None
|
||||
), "The `python-multipart` library must be installed to use form parsing."
|
||||
content_type_header = self.headers.get("Content-Type")
|
||||
content_type, options = parse_options_header(content_type_header)
|
||||
if content_type == b"multipart/form-data":
|
||||
multipart_parser = MultiPartParser(self.headers, self.stream())
|
||||
self._form = await multipart_parser.parse()
|
||||
elif content_type == b"application/x-www-form-urlencoded":
|
||||
form_parser = FormParser(self.headers, self.stream())
|
||||
self._form = await form_parser.parse()
|
||||
else:
|
||||
self._form = FormData()
|
||||
return self._form
|
||||
|
||||
async def close(self) -> None:
|
||||
if hasattr(self, "_form"):
|
||||
await self._form.close()
|
||||
|
||||
async def is_disconnected(self) -> bool:
|
||||
if not self._is_disconnected:
|
||||
try:
|
||||
message = await asyncio.wait_for(self._receive(), timeout=0.0000001)
|
||||
except asyncio.TimeoutError:
|
||||
message = {}
|
||||
|
||||
if message.get("type") == "http.disconnect":
|
||||
self._is_disconnected = True
|
||||
|
||||
return self._is_disconnected
|
||||
|
||||
async def send_push_promise(self, path: str) -> None:
|
||||
if "http.response.push" in self.scope.get("extensions", {}):
|
||||
raw_headers = []
|
||||
for name in SERVER_PUSH_HEADERS_TO_COPY:
|
||||
for value in self.headers.getlist(name):
|
||||
raw_headers.append(
|
||||
(name.encode("latin-1"), value.encode("latin-1"))
|
||||
)
|
||||
await self._send(
|
||||
{"type": "http.response.push", "path": path, "headers": raw_headers}
|
||||
)
|
318
.venv/lib/python3.9/site-packages/starlette/responses.py
Normal file
318
.venv/lib/python3.9/site-packages/starlette/responses.py
Normal file
@@ -0,0 +1,318 @@
|
||||
import hashlib
|
||||
import http.cookies
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import stat
|
||||
import typing
|
||||
from email.utils import formatdate
|
||||
from mimetypes import guess_type
|
||||
from urllib.parse import quote, quote_plus
|
||||
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.concurrency import iterate_in_threadpool, run_until_first_complete
|
||||
from starlette.datastructures import URL, MutableHeaders
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
# Workaround for adding samesite support to pre 3.8 python
|
||||
http.cookies.Morsel._reserved["samesite"] = "SameSite" # type: ignore
|
||||
|
||||
try:
|
||||
import aiofiles
|
||||
from aiofiles.os import stat as aio_stat
|
||||
except ImportError: # pragma: nocover
|
||||
aiofiles = None
|
||||
aio_stat = None
|
||||
|
||||
try:
|
||||
import ujson
|
||||
except ImportError: # pragma: nocover
|
||||
ujson = None # type: ignore
|
||||
|
||||
|
||||
class Response:
|
||||
media_type = None
|
||||
charset = "utf-8"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: typing.Any = None,
|
||||
status_code: int = 200,
|
||||
headers: dict = None,
|
||||
media_type: str = None,
|
||||
background: BackgroundTask = None,
|
||||
) -> None:
|
||||
self.status_code = status_code
|
||||
if media_type is not None:
|
||||
self.media_type = media_type
|
||||
self.background = background
|
||||
self.body = self.render(content)
|
||||
self.init_headers(headers)
|
||||
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
if content is None:
|
||||
return b""
|
||||
if isinstance(content, bytes):
|
||||
return content
|
||||
return content.encode(self.charset)
|
||||
|
||||
def init_headers(self, headers: typing.Mapping[str, str] = None) -> None:
|
||||
if headers is None:
|
||||
raw_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
populate_content_length = True
|
||||
populate_content_type = True
|
||||
else:
|
||||
raw_headers = [
|
||||
(k.lower().encode("latin-1"), v.encode("latin-1"))
|
||||
for k, v in headers.items()
|
||||
]
|
||||
keys = [h[0] for h in raw_headers]
|
||||
populate_content_length = b"content-length" not in keys
|
||||
populate_content_type = b"content-type" not in keys
|
||||
|
||||
body = getattr(self, "body", b"")
|
||||
if body and populate_content_length:
|
||||
content_length = str(len(body))
|
||||
raw_headers.append((b"content-length", content_length.encode("latin-1")))
|
||||
|
||||
content_type = self.media_type
|
||||
if content_type is not None and populate_content_type:
|
||||
if content_type.startswith("text/"):
|
||||
content_type += "; charset=" + self.charset
|
||||
raw_headers.append((b"content-type", content_type.encode("latin-1")))
|
||||
|
||||
self.raw_headers = raw_headers
|
||||
|
||||
@property
|
||||
def headers(self) -> MutableHeaders:
|
||||
if not hasattr(self, "_headers"):
|
||||
self._headers = MutableHeaders(raw=self.raw_headers)
|
||||
return self._headers
|
||||
|
||||
def set_cookie(
|
||||
self,
|
||||
key: str,
|
||||
value: str = "",
|
||||
max_age: int = None,
|
||||
expires: int = None,
|
||||
path: str = "/",
|
||||
domain: str = None,
|
||||
secure: bool = False,
|
||||
httponly: bool = False,
|
||||
samesite: str = "lax",
|
||||
) -> None:
|
||||
cookie = http.cookies.SimpleCookie() # type: http.cookies.BaseCookie
|
||||
cookie[key] = value
|
||||
if max_age is not None:
|
||||
cookie[key]["max-age"] = max_age
|
||||
if expires is not None:
|
||||
cookie[key]["expires"] = expires
|
||||
if path is not None:
|
||||
cookie[key]["path"] = path
|
||||
if domain is not None:
|
||||
cookie[key]["domain"] = domain
|
||||
if secure:
|
||||
cookie[key]["secure"] = True
|
||||
if httponly:
|
||||
cookie[key]["httponly"] = True
|
||||
if samesite is not None:
|
||||
assert samesite.lower() in [
|
||||
"strict",
|
||||
"lax",
|
||||
"none",
|
||||
], "samesite must be either 'strict', 'lax' or 'none'"
|
||||
cookie[key]["samesite"] = samesite
|
||||
cookie_val = cookie.output(header="").strip()
|
||||
self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))
|
||||
|
||||
def delete_cookie(self, key: str, path: str = "/", domain: str = None) -> None:
|
||||
self.set_cookie(key, expires=0, max_age=0, path=path, domain=domain)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": self.body})
|
||||
|
||||
if self.background is not None:
|
||||
await self.background()
|
||||
|
||||
|
||||
class HTMLResponse(Response):
|
||||
media_type = "text/html"
|
||||
|
||||
|
||||
class PlainTextResponse(Response):
|
||||
media_type = "text/plain"
|
||||
|
||||
|
||||
class JSONResponse(Response):
|
||||
media_type = "application/json"
|
||||
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
return json.dumps(
|
||||
content,
|
||||
ensure_ascii=False,
|
||||
allow_nan=False,
|
||||
indent=None,
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8")
|
||||
|
||||
|
||||
class UJSONResponse(JSONResponse):
|
||||
media_type = "application/json"
|
||||
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
assert ujson is not None, "ujson must be installed to use UJSONResponse"
|
||||
return ujson.dumps(content, ensure_ascii=False).encode("utf-8")
|
||||
|
||||
|
||||
class RedirectResponse(Response):
|
||||
def __init__(
|
||||
self,
|
||||
url: typing.Union[str, URL],
|
||||
status_code: int = 307,
|
||||
headers: dict = None,
|
||||
background: BackgroundTask = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
content=b"", status_code=status_code, headers=headers, background=background
|
||||
)
|
||||
self.headers["location"] = quote_plus(str(url), safe=":/%#?&=@[]!$&'()*+,;")
|
||||
|
||||
|
||||
class StreamingResponse(Response):
|
||||
def __init__(
|
||||
self,
|
||||
content: typing.Any,
|
||||
status_code: int = 200,
|
||||
headers: dict = None,
|
||||
media_type: str = None,
|
||||
background: BackgroundTask = None,
|
||||
) -> None:
|
||||
if inspect.isasyncgen(content):
|
||||
self.body_iterator = content
|
||||
else:
|
||||
self.body_iterator = iterate_in_threadpool(content)
|
||||
self.status_code = status_code
|
||||
self.media_type = self.media_type if media_type is None else media_type
|
||||
self.background = background
|
||||
self.init_headers(headers)
|
||||
|
||||
async def listen_for_disconnect(self, receive: Receive) -> None:
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "http.disconnect":
|
||||
break
|
||||
|
||||
async def stream_response(self, send: Send) -> None:
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
async for chunk in self.body_iterator:
|
||||
if not isinstance(chunk, bytes):
|
||||
chunk = chunk.encode(self.charset)
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await run_until_first_complete(
|
||||
(self.stream_response, {"send": send}),
|
||||
(self.listen_for_disconnect, {"receive": receive}),
|
||||
)
|
||||
|
||||
if self.background is not None:
|
||||
await self.background()
|
||||
|
||||
|
||||
class FileResponse(Response):
|
||||
chunk_size = 4096
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
status_code: int = 200,
|
||||
headers: dict = None,
|
||||
media_type: str = None,
|
||||
background: BackgroundTask = None,
|
||||
filename: str = None,
|
||||
stat_result: os.stat_result = None,
|
||||
method: str = None,
|
||||
) -> None:
|
||||
assert aiofiles is not None, "'aiofiles' must be installed to use FileResponse"
|
||||
self.path = path
|
||||
self.status_code = status_code
|
||||
self.filename = filename
|
||||
self.send_header_only = method is not None and method.upper() == "HEAD"
|
||||
if media_type is None:
|
||||
media_type = guess_type(filename or path)[0] or "text/plain"
|
||||
self.media_type = media_type
|
||||
self.background = background
|
||||
self.init_headers(headers)
|
||||
if self.filename is not None:
|
||||
content_disposition_filename = quote(self.filename)
|
||||
if content_disposition_filename != self.filename:
|
||||
content_disposition = "attachment; filename*=utf-8''{}".format(
|
||||
content_disposition_filename
|
||||
)
|
||||
else:
|
||||
content_disposition = 'attachment; filename="{}"'.format(self.filename)
|
||||
self.headers.setdefault("content-disposition", content_disposition)
|
||||
self.stat_result = stat_result
|
||||
if stat_result is not None:
|
||||
self.set_stat_headers(stat_result)
|
||||
|
||||
def set_stat_headers(self, stat_result: os.stat_result) -> None:
|
||||
content_length = str(stat_result.st_size)
|
||||
last_modified = formatdate(stat_result.st_mtime, usegmt=True)
|
||||
etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
|
||||
etag = hashlib.md5(etag_base.encode()).hexdigest()
|
||||
|
||||
self.headers.setdefault("content-length", content_length)
|
||||
self.headers.setdefault("last-modified", last_modified)
|
||||
self.headers.setdefault("etag", etag)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.stat_result is None:
|
||||
try:
|
||||
stat_result = await aio_stat(self.path)
|
||||
self.set_stat_headers(stat_result)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(f"File at path {self.path} does not exist.")
|
||||
else:
|
||||
mode = stat_result.st_mode
|
||||
if not stat.S_ISREG(mode):
|
||||
raise RuntimeError(f"File at path {self.path} is not a file.")
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
if self.send_header_only:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
else:
|
||||
async with aiofiles.open(self.path, mode="rb") as file:
|
||||
more_body = True
|
||||
while more_body:
|
||||
chunk = await file.read(self.chunk_size)
|
||||
more_body = len(chunk) == self.chunk_size
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": chunk,
|
||||
"more_body": more_body,
|
||||
}
|
||||
)
|
||||
if self.background is not None:
|
||||
await self.background()
|
672
.venv/lib/python3.9/site-packages/starlette/routing.py
Normal file
672
.venv/lib/python3.9/site-packages/starlette/routing.py
Normal file
@@ -0,0 +1,672 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import re
|
||||
import traceback
|
||||
import typing
|
||||
from enum import Enum
|
||||
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.convertors import CONVERTOR_TYPES, Convertor
|
||||
from starlette.datastructures import URL, Headers, URLPath
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, RedirectResponse
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket, WebSocketClose
|
||||
|
||||
|
||||
class NoMatchFound(Exception):
|
||||
"""
|
||||
Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)`
|
||||
if no matching route exists.
|
||||
"""
|
||||
|
||||
|
||||
class Match(Enum):
|
||||
NONE = 0
|
||||
PARTIAL = 1
|
||||
FULL = 2
|
||||
|
||||
|
||||
def request_response(func: typing.Callable) -> ASGIApp:
|
||||
"""
|
||||
Takes a function or coroutine `func(request) -> response`,
|
||||
and returns an ASGI application.
|
||||
"""
|
||||
is_coroutine = asyncio.iscoroutinefunction(func)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
request = Request(scope, receive=receive, send=send)
|
||||
if is_coroutine:
|
||||
response = await func(request)
|
||||
else:
|
||||
response = await run_in_threadpool(func, request)
|
||||
await response(scope, receive, send)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def websocket_session(func: typing.Callable) -> ASGIApp:
|
||||
"""
|
||||
Takes a coroutine `func(session)`, and returns an ASGI application.
|
||||
"""
|
||||
# assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async"
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
session = WebSocket(scope, receive=receive, send=send)
|
||||
await func(session)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def get_name(endpoint: typing.Callable) -> str:
|
||||
if inspect.isfunction(endpoint) or inspect.isclass(endpoint):
|
||||
return endpoint.__name__
|
||||
return endpoint.__class__.__name__
|
||||
|
||||
|
||||
def replace_params(
|
||||
path: str,
|
||||
param_convertors: typing.Dict[str, Convertor],
|
||||
path_params: typing.Dict[str, str],
|
||||
) -> typing.Tuple[str, dict]:
|
||||
for key, value in list(path_params.items()):
|
||||
if "{" + key + "}" in path:
|
||||
convertor = param_convertors[key]
|
||||
value = convertor.to_string(value)
|
||||
path = path.replace("{" + key + "}", value)
|
||||
path_params.pop(key)
|
||||
return path, path_params
|
||||
|
||||
|
||||
# Match parameters in URL paths, eg. '{param}', and '{param:int}'
|
||||
PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
|
||||
|
||||
|
||||
def compile_path(
|
||||
path: str,
|
||||
) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]:
|
||||
"""
|
||||
Given a path string, like: "/{username:str}", return a three-tuple
|
||||
of (regex, format, {param_name:convertor}).
|
||||
|
||||
regex: "/(?P<username>[^/]+)"
|
||||
format: "/{username}"
|
||||
convertors: {"username": StringConvertor()}
|
||||
"""
|
||||
path_regex = "^"
|
||||
path_format = ""
|
||||
|
||||
idx = 0
|
||||
param_convertors = {}
|
||||
for match in PARAM_REGEX.finditer(path):
|
||||
param_name, convertor_type = match.groups("str")
|
||||
convertor_type = convertor_type.lstrip(":")
|
||||
assert (
|
||||
convertor_type in CONVERTOR_TYPES
|
||||
), f"Unknown path convertor '{convertor_type}'"
|
||||
convertor = CONVERTOR_TYPES[convertor_type]
|
||||
|
||||
path_regex += re.escape(path[idx : match.start()])
|
||||
path_regex += f"(?P<{param_name}>{convertor.regex})"
|
||||
|
||||
path_format += path[idx : match.start()]
|
||||
path_format += "{%s}" % param_name
|
||||
|
||||
param_convertors[param_name] = convertor
|
||||
|
||||
idx = match.end()
|
||||
|
||||
path_regex += re.escape(path[idx:]) + "$"
|
||||
path_format += path[idx:]
|
||||
|
||||
return re.compile(path_regex), path_format, param_convertors
|
||||
|
||||
|
||||
class BaseRoute:
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""
|
||||
A route may be used in isolation as a stand-alone ASGI app.
|
||||
This is a somewhat contrived case, as they'll almost always be used
|
||||
within a Router, but could be useful for some tooling and minimal apps.
|
||||
"""
|
||||
match, child_scope = self.matches(scope)
|
||||
if match == Match.NONE:
|
||||
if scope["type"] == "http":
|
||||
response = PlainTextResponse("Not Found", status_code=404)
|
||||
await response(scope, receive, send)
|
||||
elif scope["type"] == "websocket":
|
||||
websocket_close = WebSocketClose()
|
||||
await websocket_close(scope, receive, send)
|
||||
return
|
||||
|
||||
scope.update(child_scope)
|
||||
await self.handle(scope, receive, send)
|
||||
|
||||
|
||||
class Route(BaseRoute):
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: typing.Callable,
|
||||
*,
|
||||
methods: typing.List[str] = None,
|
||||
name: str = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> None:
|
||||
assert path.startswith("/"), "Routed paths must start with '/'"
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
self.name = get_name(endpoint) if name is None else name
|
||||
self.include_in_schema = include_in_schema
|
||||
|
||||
if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
|
||||
# Endpoint is function or method. Treat it as `func(request) -> response`.
|
||||
self.app = request_response(endpoint)
|
||||
if methods is None:
|
||||
methods = ["GET"]
|
||||
else:
|
||||
# Endpoint is a class. Treat it as ASGI.
|
||||
self.app = endpoint
|
||||
|
||||
if methods is None:
|
||||
self.methods = None
|
||||
else:
|
||||
self.methods = set(method.upper() for method in methods)
|
||||
if "GET" in self.methods:
|
||||
self.methods.add("HEAD")
|
||||
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
if scope["type"] == "http":
|
||||
match = self.path_regex.match(scope["path"])
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
child_scope = {"endpoint": self.endpoint, "path_params": path_params}
|
||||
if self.methods and scope["method"] not in self.methods:
|
||||
return Match.PARTIAL, child_scope
|
||||
else:
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
seen_params = set(path_params.keys())
|
||||
expected_params = set(self.param_convertors.keys())
|
||||
|
||||
if name != self.name or seen_params != expected_params:
|
||||
raise NoMatchFound()
|
||||
|
||||
path, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
assert not remaining_params
|
||||
return URLPath(path=path, protocol="http")
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.methods and scope["method"] not in self.methods:
|
||||
if "app" in scope:
|
||||
raise HTTPException(status_code=405)
|
||||
else:
|
||||
response = PlainTextResponse("Method Not Allowed", status_code=405)
|
||||
await response(scope, receive, send)
|
||||
else:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, Route)
|
||||
and self.path == other.path
|
||||
and self.endpoint == other.endpoint
|
||||
and self.methods == other.methods
|
||||
)
|
||||
|
||||
|
||||
class WebSocketRoute(BaseRoute):
|
||||
def __init__(
|
||||
self, path: str, endpoint: typing.Callable, *, name: str = None
|
||||
) -> None:
|
||||
assert path.startswith("/"), "Routed paths must start with '/'"
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
self.name = get_name(endpoint) if name is None else name
|
||||
|
||||
if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
|
||||
# Endpoint is function or method. Treat it as `func(websocket)`.
|
||||
self.app = websocket_session(endpoint)
|
||||
else:
|
||||
# Endpoint is a class. Treat it as ASGI.
|
||||
self.app = endpoint
|
||||
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
if scope["type"] == "websocket":
|
||||
match = self.path_regex.match(scope["path"])
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
child_scope = {"endpoint": self.endpoint, "path_params": path_params}
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
seen_params = set(path_params.keys())
|
||||
expected_params = set(self.param_convertors.keys())
|
||||
|
||||
if name != self.name or seen_params != expected_params:
|
||||
raise NoMatchFound()
|
||||
|
||||
path, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
assert not remaining_params
|
||||
return URLPath(path=path, protocol="websocket")
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, WebSocketRoute)
|
||||
and self.path == other.path
|
||||
and self.endpoint == other.endpoint
|
||||
)
|
||||
|
||||
|
||||
class Mount(BaseRoute):
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
app: ASGIApp = None,
|
||||
routes: typing.Sequence[BaseRoute] = None,
|
||||
name: str = None,
|
||||
) -> None:
|
||||
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
|
||||
assert (
|
||||
app is not None or routes is not None
|
||||
), "Either 'app=...', or 'routes=' must be specified"
|
||||
self.path = path.rstrip("/")
|
||||
if app is not None:
|
||||
self.app = app # type: ASGIApp
|
||||
else:
|
||||
self.app = Router(routes=routes)
|
||||
self.name = name
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(
|
||||
self.path + "/{path:path}"
|
||||
)
|
||||
|
||||
@property
|
||||
def routes(self) -> typing.List[BaseRoute]:
|
||||
return getattr(self.app, "routes", None)
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
if scope["type"] in ("http", "websocket"):
|
||||
path = scope["path"]
|
||||
match = self.path_regex.match(path)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
remaining_path = "/" + matched_params.pop("path")
|
||||
matched_path = path[: -len(remaining_path)]
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
root_path = scope.get("root_path", "")
|
||||
child_scope = {
|
||||
"path_params": path_params,
|
||||
"app_root_path": scope.get("app_root_path", root_path),
|
||||
"root_path": root_path + matched_path,
|
||||
"path": remaining_path,
|
||||
"endpoint": self.app,
|
||||
}
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
if self.name is not None and name == self.name and "path" in path_params:
|
||||
# 'name' matches "<mount_name>".
|
||||
path_params["path"] = path_params["path"].lstrip("/")
|
||||
path, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
if not remaining_params:
|
||||
return URLPath(path=path)
|
||||
elif self.name is None or name.startswith(self.name + ":"):
|
||||
if self.name is None:
|
||||
# No mount name.
|
||||
remaining_name = name
|
||||
else:
|
||||
# 'name' matches "<mount_name>:<child_name>".
|
||||
remaining_name = name[len(self.name) + 1 :]
|
||||
path_kwarg = path_params.get("path")
|
||||
path_params["path"] = ""
|
||||
path_prefix, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
if path_kwarg is not None:
|
||||
remaining_params["path"] = path_kwarg
|
||||
for route in self.routes or []:
|
||||
try:
|
||||
url = route.url_path_for(remaining_name, **remaining_params)
|
||||
return URLPath(
|
||||
path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
|
||||
)
|
||||
except NoMatchFound:
|
||||
pass
|
||||
raise NoMatchFound()
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, Mount)
|
||||
and self.path == other.path
|
||||
and self.app == other.app
|
||||
)
|
||||
|
||||
|
||||
class Host(BaseRoute):
|
||||
def __init__(self, host: str, app: ASGIApp, name: str = None) -> None:
|
||||
self.host = host
|
||||
self.app = app
|
||||
self.name = name
|
||||
self.host_regex, self.host_format, self.param_convertors = compile_path(host)
|
||||
|
||||
@property
|
||||
def routes(self) -> typing.List[BaseRoute]:
|
||||
return getattr(self.app, "routes", None)
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
if scope["type"] in ("http", "websocket"):
|
||||
headers = Headers(scope=scope)
|
||||
host = headers.get("host", "").split(":")[0]
|
||||
match = self.host_regex.match(host)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
child_scope = {"path_params": path_params, "endpoint": self.app}
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
if self.name is not None and name == self.name and "path" in path_params:
|
||||
# 'name' matches "<mount_name>".
|
||||
path = path_params.pop("path")
|
||||
host, remaining_params = replace_params(
|
||||
self.host_format, self.param_convertors, path_params
|
||||
)
|
||||
if not remaining_params:
|
||||
return URLPath(path=path, host=host)
|
||||
elif self.name is None or name.startswith(self.name + ":"):
|
||||
if self.name is None:
|
||||
# No mount name.
|
||||
remaining_name = name
|
||||
else:
|
||||
# 'name' matches "<mount_name>:<child_name>".
|
||||
remaining_name = name[len(self.name) + 1 :]
|
||||
host, remaining_params = replace_params(
|
||||
self.host_format, self.param_convertors, path_params
|
||||
)
|
||||
for route in self.routes or []:
|
||||
try:
|
||||
url = route.url_path_for(remaining_name, **remaining_params)
|
||||
return URLPath(path=str(url), protocol=url.protocol, host=host)
|
||||
except NoMatchFound:
|
||||
pass
|
||||
raise NoMatchFound()
|
||||
|
||||
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, Host)
|
||||
and self.host == other.host
|
||||
and self.app == other.app
|
||||
)
|
||||
|
||||
|
||||
class Router:
|
||||
def __init__(
|
||||
self,
|
||||
routes: typing.Sequence[BaseRoute] = None,
|
||||
redirect_slashes: bool = True,
|
||||
default: ASGIApp = None,
|
||||
on_startup: typing.Sequence[typing.Callable] = None,
|
||||
on_shutdown: typing.Sequence[typing.Callable] = None,
|
||||
lifespan: typing.Callable[[typing.Any], typing.AsyncGenerator] = None,
|
||||
) -> None:
|
||||
self.routes = [] if routes is None else list(routes)
|
||||
self.redirect_slashes = redirect_slashes
|
||||
self.default = self.not_found if default is None else default
|
||||
self.on_startup = [] if on_startup is None else list(on_startup)
|
||||
self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)
|
||||
|
||||
async def default_lifespan(app: typing.Any) -> typing.AsyncGenerator:
|
||||
await self.startup()
|
||||
yield
|
||||
await self.shutdown()
|
||||
|
||||
self.lifespan_context = default_lifespan if lifespan is None else lifespan
|
||||
|
||||
async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] == "websocket":
|
||||
websocket_close = WebSocketClose()
|
||||
await websocket_close(scope, receive, send)
|
||||
return
|
||||
|
||||
# If we're running inside a starlette application then raise an
|
||||
# exception, so that the configurable exception handler can deal with
|
||||
# returning the response. For plain ASGI apps, just return the response.
|
||||
if "app" in scope:
|
||||
raise HTTPException(status_code=404)
|
||||
else:
|
||||
response = PlainTextResponse("Not Found", status_code=404)
|
||||
await response(scope, receive, send)
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
for route in self.routes:
|
||||
try:
|
||||
return route.url_path_for(name, **path_params)
|
||||
except NoMatchFound:
|
||||
pass
|
||||
raise NoMatchFound()
|
||||
|
||||
async def startup(self) -> None:
|
||||
"""
|
||||
Run any `.on_startup` event handlers.
|
||||
"""
|
||||
for handler in self.on_startup:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler()
|
||||
else:
|
||||
handler()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Run any `.on_shutdown` event handlers.
|
||||
"""
|
||||
for handler in self.on_shutdown:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler()
|
||||
else:
|
||||
handler()
|
||||
|
||||
async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""
|
||||
Handle ASGI lifespan messages, which allows us to manage application
|
||||
startup and shutdown events.
|
||||
"""
|
||||
first = True
|
||||
app = scope.get("app")
|
||||
message = await receive()
|
||||
try:
|
||||
if inspect.isasyncgenfunction(self.lifespan_context):
|
||||
async for item in self.lifespan_context(app):
|
||||
assert first, "Lifespan context yielded multiple times."
|
||||
first = False
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
message = await receive()
|
||||
else:
|
||||
for item in self.lifespan_context(app): # type: ignore
|
||||
assert first, "Lifespan context yielded multiple times."
|
||||
first = False
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
message = await receive()
|
||||
except BaseException:
|
||||
if first:
|
||||
exc_text = traceback.format_exc()
|
||||
await send({"type": "lifespan.startup.failed", "message": exc_text})
|
||||
raise
|
||||
else:
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""
|
||||
The main entry point to the Router class.
|
||||
"""
|
||||
assert scope["type"] in ("http", "websocket", "lifespan")
|
||||
|
||||
if "router" not in scope:
|
||||
scope["router"] = self
|
||||
|
||||
if scope["type"] == "lifespan":
|
||||
await self.lifespan(scope, receive, send)
|
||||
return
|
||||
|
||||
partial = None
|
||||
|
||||
for route in self.routes:
|
||||
# Determine if any route matches the incoming scope,
|
||||
# and hand over to the matching route if found.
|
||||
match, child_scope = route.matches(scope)
|
||||
if match == Match.FULL:
|
||||
scope.update(child_scope)
|
||||
await route.handle(scope, receive, send)
|
||||
return
|
||||
elif match == Match.PARTIAL and partial is None:
|
||||
partial = route
|
||||
partial_scope = child_scope
|
||||
|
||||
if partial is not None:
|
||||
# Handle partial matches. These are cases where an endpoint is
|
||||
# able to handle the request, but is not a preferred option.
|
||||
# We use this in particular to deal with "405 Method Not Allowed".
|
||||
scope.update(partial_scope)
|
||||
await partial.handle(scope, receive, send)
|
||||
return
|
||||
|
||||
if scope["type"] == "http" and self.redirect_slashes and scope["path"] != "/":
|
||||
redirect_scope = dict(scope)
|
||||
if scope["path"].endswith("/"):
|
||||
redirect_scope["path"] = redirect_scope["path"].rstrip("/")
|
||||
else:
|
||||
redirect_scope["path"] = redirect_scope["path"] + "/"
|
||||
|
||||
for route in self.routes:
|
||||
match, child_scope = route.matches(redirect_scope)
|
||||
if match != Match.NONE:
|
||||
redirect_url = URL(scope=redirect_scope)
|
||||
response = RedirectResponse(url=str(redirect_url))
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
await self.default(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return isinstance(other, Router) and self.routes == other.routes
|
||||
|
||||
# The following usages are now discouraged in favour of configuration
|
||||
# during Router.__init__(...)
|
||||
def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
|
||||
route = Mount(path, app=app, name=name)
|
||||
self.routes.append(route)
|
||||
|
||||
def host(self, host: str, app: ASGIApp, name: str = None) -> None:
|
||||
route = Host(host, app=app, name=name)
|
||||
self.routes.append(route)
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: typing.Callable,
|
||||
methods: typing.List[str] = None,
|
||||
name: str = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> None:
|
||||
route = Route(
|
||||
path,
|
||||
endpoint=endpoint,
|
||||
methods=methods,
|
||||
name=name,
|
||||
include_in_schema=include_in_schema,
|
||||
)
|
||||
self.routes.append(route)
|
||||
|
||||
def add_websocket_route(
|
||||
self, path: str, endpoint: typing.Callable, name: str = None
|
||||
) -> None:
|
||||
route = WebSocketRoute(path, endpoint=endpoint, name=name)
|
||||
self.routes.append(route)
|
||||
|
||||
def route(
|
||||
self,
|
||||
path: str,
|
||||
methods: typing.List[str] = None,
|
||||
name: str = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> typing.Callable:
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
self.add_route(
|
||||
path,
|
||||
func,
|
||||
methods=methods,
|
||||
name=name,
|
||||
include_in_schema=include_in_schema,
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def websocket_route(self, path: str, name: str = None) -> typing.Callable:
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
self.add_websocket_route(path, func, name=name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
|
||||
assert event_type in ("startup", "shutdown")
|
||||
|
||||
if event_type == "startup":
|
||||
self.on_startup.append(func)
|
||||
else:
|
||||
self.on_shutdown.append(func)
|
||||
|
||||
def on_event(self, event_type: str) -> typing.Callable:
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
self.add_event_handler(event_type, func)
|
||||
return func
|
||||
|
||||
return decorator
|
135
.venv/lib/python3.9/site-packages/starlette/schemas.py
Normal file
135
.venv/lib/python3.9/site-packages/starlette/schemas.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import inspect
|
||||
import typing
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import BaseRoute, Mount, Route
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ImportError: # pragma: nocover
|
||||
yaml = None # type: ignore
|
||||
|
||||
|
||||
class OpenAPIResponse(Response):
|
||||
media_type = "application/vnd.oai.openapi"
|
||||
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
|
||||
assert isinstance(
|
||||
content, dict
|
||||
), "The schema passed to OpenAPIResponse should be a dictionary."
|
||||
return yaml.dump(content, default_flow_style=False).encode("utf-8")
|
||||
|
||||
|
||||
class EndpointInfo(typing.NamedTuple):
|
||||
path: str
|
||||
http_method: str
|
||||
func: typing.Callable
|
||||
|
||||
|
||||
class BaseSchemaGenerator:
|
||||
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def get_endpoints(
|
||||
self, routes: typing.List[BaseRoute]
|
||||
) -> typing.List[EndpointInfo]:
|
||||
"""
|
||||
Given the routes, yields the following information:
|
||||
|
||||
- path
|
||||
eg: /users/
|
||||
- http_method
|
||||
one of 'get', 'post', 'put', 'patch', 'delete', 'options'
|
||||
- func
|
||||
method ready to extract the docstring
|
||||
"""
|
||||
endpoints_info: list = []
|
||||
|
||||
for route in routes:
|
||||
if isinstance(route, Mount):
|
||||
routes = route.routes or []
|
||||
sub_endpoints = [
|
||||
EndpointInfo(
|
||||
path="".join((route.path, sub_endpoint.path)),
|
||||
http_method=sub_endpoint.http_method,
|
||||
func=sub_endpoint.func,
|
||||
)
|
||||
for sub_endpoint in self.get_endpoints(routes)
|
||||
]
|
||||
endpoints_info.extend(sub_endpoints)
|
||||
|
||||
elif not isinstance(route, Route) or not route.include_in_schema:
|
||||
continue
|
||||
|
||||
elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
|
||||
for method in route.methods or ["GET"]:
|
||||
if method == "HEAD":
|
||||
continue
|
||||
endpoints_info.append(
|
||||
EndpointInfo(route.path, method.lower(), route.endpoint)
|
||||
)
|
||||
else:
|
||||
for method in ["get", "post", "put", "patch", "delete", "options"]:
|
||||
if not hasattr(route.endpoint, method):
|
||||
continue
|
||||
func = getattr(route.endpoint, method)
|
||||
endpoints_info.append(
|
||||
EndpointInfo(route.path, method.lower(), func)
|
||||
)
|
||||
|
||||
return endpoints_info
|
||||
|
||||
def parse_docstring(self, func_or_method: typing.Callable) -> dict:
|
||||
"""
|
||||
Given a function, parse the docstring as YAML and return a dictionary of info.
|
||||
"""
|
||||
docstring = func_or_method.__doc__
|
||||
if not docstring:
|
||||
return {}
|
||||
|
||||
assert yaml is not None, "`pyyaml` must be installed to use parse_docstring."
|
||||
|
||||
# We support having regular docstrings before the schema
|
||||
# definition. Here we return just the schema part from
|
||||
# the docstring.
|
||||
docstring = docstring.split("---")[-1]
|
||||
|
||||
parsed = yaml.safe_load(docstring)
|
||||
|
||||
if not isinstance(parsed, dict):
|
||||
# A regular docstring (not yaml formatted) can return
|
||||
# a simple string here, which wouldn't follow the schema.
|
||||
return {}
|
||||
|
||||
return parsed
|
||||
|
||||
def OpenAPIResponse(self, request: Request) -> Response:
|
||||
routes = request.app.routes
|
||||
schema = self.get_schema(routes=routes)
|
||||
return OpenAPIResponse(schema)
|
||||
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
def __init__(self, base_schema: dict) -> None:
|
||||
self.base_schema = base_schema
|
||||
|
||||
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
|
||||
schema = dict(self.base_schema)
|
||||
schema.setdefault("paths", {})
|
||||
endpoints_info = self.get_endpoints(routes)
|
||||
|
||||
for endpoint in endpoints_info:
|
||||
|
||||
parsed = self.parse_docstring(endpoint.func)
|
||||
|
||||
if not parsed:
|
||||
continue
|
||||
|
||||
if endpoint.path not in schema["paths"]:
|
||||
schema["paths"][endpoint.path] = {}
|
||||
|
||||
schema["paths"][endpoint.path][endpoint.http_method] = parsed
|
||||
|
||||
return schema
|
219
.venv/lib/python3.9/site-packages/starlette/staticfiles.py
Normal file
219
.venv/lib/python3.9/site-packages/starlette/staticfiles.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import importlib.util
|
||||
import os
|
||||
import stat
|
||||
import typing
|
||||
from email.utils import parsedate
|
||||
|
||||
from aiofiles.os import stat as aio_stat
|
||||
|
||||
from starlette.datastructures import URL, Headers
|
||||
from starlette.responses import (
|
||||
FileResponse,
|
||||
PlainTextResponse,
|
||||
RedirectResponse,
|
||||
Response,
|
||||
)
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
|
||||
class NotModifiedResponse(Response):
|
||||
NOT_MODIFIED_HEADERS = (
|
||||
"cache-control",
|
||||
"content-location",
|
||||
"date",
|
||||
"etag",
|
||||
"expires",
|
||||
"vary",
|
||||
)
|
||||
|
||||
def __init__(self, headers: Headers):
|
||||
super().__init__(
|
||||
status_code=304,
|
||||
headers={
|
||||
name: value
|
||||
for name, value in headers.items()
|
||||
if name in self.NOT_MODIFIED_HEADERS
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class StaticFiles:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
directory: str = None,
|
||||
packages: typing.List[str] = None,
|
||||
html: bool = False,
|
||||
check_dir: bool = True,
|
||||
) -> None:
|
||||
self.directory = directory
|
||||
self.packages = packages
|
||||
self.all_directories = self.get_directories(directory, packages)
|
||||
self.html = html
|
||||
self.config_checked = False
|
||||
if check_dir and directory is not None and not os.path.isdir(directory):
|
||||
raise RuntimeError(f"Directory '{directory}' does not exist")
|
||||
|
||||
def get_directories(
|
||||
self, directory: str = None, packages: typing.List[str] = None
|
||||
) -> typing.List[str]:
|
||||
"""
|
||||
Given `directory` and `packages` arguments, return a list of all the
|
||||
directories that should be used for serving static files from.
|
||||
"""
|
||||
directories = []
|
||||
if directory is not None:
|
||||
directories.append(directory)
|
||||
|
||||
for package in packages or []:
|
||||
spec = importlib.util.find_spec(package)
|
||||
assert spec is not None, f"Package {package!r} could not be found."
|
||||
assert (
|
||||
spec.origin is not None
|
||||
), f"Directory 'statics' in package {package!r} could not be found."
|
||||
directory = os.path.normpath(os.path.join(spec.origin, "..", "statics"))
|
||||
assert os.path.isdir(
|
||||
directory
|
||||
), f"Directory 'statics' in package {package!r} could not be found."
|
||||
directories.append(directory)
|
||||
|
||||
return directories
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""
|
||||
The ASGI entry point.
|
||||
"""
|
||||
assert scope["type"] == "http"
|
||||
|
||||
if not self.config_checked:
|
||||
await self.check_config()
|
||||
self.config_checked = True
|
||||
|
||||
path = self.get_path(scope)
|
||||
response = await self.get_response(path, scope)
|
||||
await response(scope, receive, send)
|
||||
|
||||
def get_path(self, scope: Scope) -> str:
|
||||
"""
|
||||
Given the ASGI scope, return the `path` string to serve up,
|
||||
with OS specific path seperators, and any '..', '.' components removed.
|
||||
"""
|
||||
return os.path.normpath(os.path.join(*scope["path"].split("/")))
|
||||
|
||||
async def get_response(self, path: str, scope: Scope) -> Response:
|
||||
"""
|
||||
Returns an HTTP response, given the incoming path, method and request headers.
|
||||
"""
|
||||
if scope["method"] not in ("GET", "HEAD"):
|
||||
return PlainTextResponse("Method Not Allowed", status_code=405)
|
||||
|
||||
full_path, stat_result = await self.lookup_path(path)
|
||||
|
||||
if stat_result and stat.S_ISREG(stat_result.st_mode):
|
||||
# We have a static file to serve.
|
||||
return self.file_response(full_path, stat_result, scope)
|
||||
|
||||
elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
|
||||
# We're in HTML mode, and have got a directory URL.
|
||||
# Check if we have 'index.html' file to serve.
|
||||
index_path = os.path.join(path, "index.html")
|
||||
full_path, stat_result = await self.lookup_path(index_path)
|
||||
if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
|
||||
if not scope["path"].endswith("/"):
|
||||
# Directory URLs should redirect to always end in "/".
|
||||
url = URL(scope=scope)
|
||||
url = url.replace(path=url.path + "/")
|
||||
return RedirectResponse(url=url)
|
||||
return self.file_response(full_path, stat_result, scope)
|
||||
|
||||
if self.html:
|
||||
# Check for '404.html' if we're in HTML mode.
|
||||
full_path, stat_result = await self.lookup_path("404.html")
|
||||
if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
|
||||
return self.file_response(
|
||||
full_path, stat_result, scope, status_code=404
|
||||
)
|
||||
|
||||
return PlainTextResponse("Not Found", status_code=404)
|
||||
|
||||
async def lookup_path(
|
||||
self, path: str
|
||||
) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
|
||||
for directory in self.all_directories:
|
||||
full_path = os.path.realpath(os.path.join(directory, path))
|
||||
directory = os.path.realpath(directory)
|
||||
if os.path.commonprefix([full_path, directory]) != directory:
|
||||
# Don't allow misbehaving clients to break out of the static files directory.
|
||||
continue
|
||||
try:
|
||||
stat_result = await aio_stat(full_path)
|
||||
return (full_path, stat_result)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return ("", None)
|
||||
|
||||
def file_response(
|
||||
self,
|
||||
full_path: str,
|
||||
stat_result: os.stat_result,
|
||||
scope: Scope,
|
||||
status_code: int = 200,
|
||||
) -> Response:
|
||||
method = scope["method"]
|
||||
request_headers = Headers(scope=scope)
|
||||
|
||||
response = FileResponse(
|
||||
full_path, status_code=status_code, stat_result=stat_result, method=method
|
||||
)
|
||||
if self.is_not_modified(response.headers, request_headers):
|
||||
return NotModifiedResponse(response.headers)
|
||||
return response
|
||||
|
||||
async def check_config(self) -> None:
|
||||
"""
|
||||
Perform a one-off configuration check that StaticFiles is actually
|
||||
pointed at a directory, so that we can raise loud errors rather than
|
||||
just returning 404 responses.
|
||||
"""
|
||||
if self.directory is None:
|
||||
return
|
||||
|
||||
try:
|
||||
stat_result = await aio_stat(self.directory)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(
|
||||
f"StaticFiles directory '{self.directory}' does not exist."
|
||||
)
|
||||
if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
|
||||
raise RuntimeError(
|
||||
f"StaticFiles path '{self.directory}' is not a directory."
|
||||
)
|
||||
|
||||
def is_not_modified(
|
||||
self, response_headers: Headers, request_headers: Headers
|
||||
) -> bool:
|
||||
"""
|
||||
Given the request and response headers, return `True` if an HTTP
|
||||
"Not Modified" response could be returned instead.
|
||||
"""
|
||||
try:
|
||||
if_none_match = request_headers["if-none-match"]
|
||||
etag = response_headers["etag"]
|
||||
if if_none_match == etag:
|
||||
return True
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
if_modified_since = parsedate(request_headers["if-modified-since"])
|
||||
last_modified = parsedate(response_headers["last-modified"])
|
||||
if (
|
||||
if_modified_since is not None
|
||||
and last_modified is not None
|
||||
and if_modified_since >= last_modified
|
||||
):
|
||||
return True
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return False
|
81
.venv/lib/python3.9/site-packages/starlette/status.py
Normal file
81
.venv/lib/python3.9/site-packages/starlette/status.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
HTTP codes
|
||||
See RFC 2616 - https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html
|
||||
And RFC 6585 - https://tools.ietf.org/html/rfc6585
|
||||
And RFC 4918 - https://tools.ietf.org/html/rfc4918
|
||||
And RFC 8470 - https://tools.ietf.org/html/rfc8470
|
||||
"""
|
||||
HTTP_100_CONTINUE = 100
|
||||
HTTP_101_SWITCHING_PROTOCOLS = 101
|
||||
HTTP_200_OK = 200
|
||||
HTTP_201_CREATED = 201
|
||||
HTTP_202_ACCEPTED = 202
|
||||
HTTP_203_NON_AUTHORITATIVE_INFORMATION = 203
|
||||
HTTP_204_NO_CONTENT = 204
|
||||
HTTP_205_RESET_CONTENT = 205
|
||||
HTTP_206_PARTIAL_CONTENT = 206
|
||||
HTTP_207_MULTI_STATUS = 207
|
||||
HTTP_300_MULTIPLE_CHOICES = 300
|
||||
HTTP_301_MOVED_PERMANENTLY = 301
|
||||
HTTP_302_FOUND = 302
|
||||
HTTP_303_SEE_OTHER = 303
|
||||
HTTP_304_NOT_MODIFIED = 304
|
||||
HTTP_305_USE_PROXY = 305
|
||||
HTTP_306_RESERVED = 306
|
||||
HTTP_307_TEMPORARY_REDIRECT = 307
|
||||
HTTP_400_BAD_REQUEST = 400
|
||||
HTTP_401_UNAUTHORIZED = 401
|
||||
HTTP_402_PAYMENT_REQUIRED = 402
|
||||
HTTP_403_FORBIDDEN = 403
|
||||
HTTP_404_NOT_FOUND = 404
|
||||
HTTP_405_METHOD_NOT_ALLOWED = 405
|
||||
HTTP_406_NOT_ACCEPTABLE = 406
|
||||
HTTP_407_PROXY_AUTHENTICATION_REQUIRED = 407
|
||||
HTTP_408_REQUEST_TIMEOUT = 408
|
||||
HTTP_409_CONFLICT = 409
|
||||
HTTP_410_GONE = 410
|
||||
HTTP_411_LENGTH_REQUIRED = 411
|
||||
HTTP_412_PRECONDITION_FAILED = 412
|
||||
HTTP_413_REQUEST_ENTITY_TOO_LARGE = 413
|
||||
HTTP_414_REQUEST_URI_TOO_LONG = 414
|
||||
HTTP_415_UNSUPPORTED_MEDIA_TYPE = 415
|
||||
HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE = 416
|
||||
HTTP_417_EXPECTATION_FAILED = 417
|
||||
HTTP_422_UNPROCESSABLE_ENTITY = 422
|
||||
HTTP_423_LOCKED = 423
|
||||
HTTP_424_FAILED_DEPENDENCY = 424
|
||||
HTTP_425_TOO_EARLY = 425
|
||||
HTTP_428_PRECONDITION_REQUIRED = 428
|
||||
HTTP_429_TOO_MANY_REQUESTS = 429
|
||||
HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE = 431
|
||||
HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS = 451
|
||||
HTTP_500_INTERNAL_SERVER_ERROR = 500
|
||||
HTTP_501_NOT_IMPLEMENTED = 501
|
||||
HTTP_502_BAD_GATEWAY = 502
|
||||
HTTP_503_SERVICE_UNAVAILABLE = 503
|
||||
HTTP_504_GATEWAY_TIMEOUT = 504
|
||||
HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505
|
||||
HTTP_507_INSUFFICIENT_STORAGE = 507
|
||||
HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511
|
||||
|
||||
|
||||
"""
|
||||
WebSocket codes
|
||||
https://www.iana.org/assignments/websocket/websocket.xml#close-code-number
|
||||
https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent
|
||||
"""
|
||||
WS_1000_NORMAL_CLOSURE = 1000
|
||||
WS_1001_GOING_AWAY = 1001
|
||||
WS_1002_PROTOCOL_ERROR = 1002
|
||||
WS_1003_UNSUPPORTED_DATA = 1003
|
||||
WS_1004_NO_STATUS_RCVD = 1004
|
||||
WS_1005_ABNORMAL_CLOSURE = 1005
|
||||
WS_1007_INVALID_FRAME_PAYLOAD_DATA = 1007
|
||||
WS_1008_POLICY_VIOLATION = 1008
|
||||
WS_1009_MESSAGE_TOO_BIG = 1009
|
||||
WS_1010_MANDATORY_EXT = 1010
|
||||
WS_1011_INTERNAL_ERROR = 1011
|
||||
WS_1012_SERVICE_RESTART = 1012
|
||||
WS_1013_TRY_AGAIN_LATER = 1013
|
||||
WS_1014_BAD_GATEWAY = 1014
|
||||
WS_1015_TLS_HANDSHAKE = 1015
|
88
.venv/lib/python3.9/site-packages/starlette/templating.py
Normal file
88
.venv/lib/python3.9/site-packages/starlette/templating.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import typing
|
||||
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
try:
|
||||
import jinja2
|
||||
except ImportError: # pragma: nocover
|
||||
jinja2 = None # type: ignore
|
||||
|
||||
|
||||
class _TemplateResponse(Response):
|
||||
media_type = "text/html"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
template: typing.Any,
|
||||
context: dict,
|
||||
status_code: int = 200,
|
||||
headers: dict = None,
|
||||
media_type: str = None,
|
||||
background: BackgroundTask = None,
|
||||
):
|
||||
self.template = template
|
||||
self.context = context
|
||||
content = template.render(context)
|
||||
super().__init__(content, status_code, headers, media_type, background)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
request = self.context.get("request", {})
|
||||
extensions = request.get("extensions", {})
|
||||
if "http.response.template" in extensions:
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.template",
|
||||
"template": self.template,
|
||||
"context": self.context,
|
||||
}
|
||||
)
|
||||
await super().__call__(scope, receive, send)
|
||||
|
||||
|
||||
class Jinja2Templates:
|
||||
"""
|
||||
templates = Jinja2Templates("templates")
|
||||
|
||||
return templates.TemplateResponse("index.html", {"request": request})
|
||||
"""
|
||||
|
||||
def __init__(self, directory: str) -> None:
|
||||
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
|
||||
self.env = self.get_env(directory)
|
||||
|
||||
def get_env(self, directory: str) -> "jinja2.Environment":
|
||||
@jinja2.contextfunction
|
||||
def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
|
||||
request = context["request"]
|
||||
return request.url_for(name, **path_params)
|
||||
|
||||
loader = jinja2.FileSystemLoader(directory)
|
||||
env = jinja2.Environment(loader=loader, autoescape=True)
|
||||
env.globals["url_for"] = url_for
|
||||
return env
|
||||
|
||||
def get_template(self, name: str) -> "jinja2.Template":
|
||||
return self.env.get_template(name)
|
||||
|
||||
def TemplateResponse(
|
||||
self,
|
||||
name: str,
|
||||
context: dict,
|
||||
status_code: int = 200,
|
||||
headers: dict = None,
|
||||
media_type: str = None,
|
||||
background: BackgroundTask = None,
|
||||
) -> _TemplateResponse:
|
||||
if "request" not in context:
|
||||
raise ValueError('context must include a "request" key')
|
||||
template = self.get_template(name)
|
||||
return _TemplateResponse(
|
||||
template,
|
||||
context,
|
||||
status_code=status_code,
|
||||
headers=headers,
|
||||
media_type=media_type,
|
||||
background=background,
|
||||
)
|
491
.venv/lib/python3.9/site-packages/starlette/testclient.py
Normal file
491
.venv/lib/python3.9/site-packages/starlette/testclient.py
Normal file
@@ -0,0 +1,491 @@
|
||||
import asyncio
|
||||
import http
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import queue
|
||||
import threading
|
||||
import types
|
||||
import typing
|
||||
from urllib.parse import unquote, urljoin, urlsplit
|
||||
|
||||
import requests
|
||||
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
# Annotations for `Session.request()`
|
||||
Cookies = typing.Union[
|
||||
typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
|
||||
]
|
||||
Params = typing.Union[bytes, typing.MutableMapping[str, str]]
|
||||
DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO]
|
||||
TimeOut = typing.Union[float, typing.Tuple[float, float]]
|
||||
FileType = typing.MutableMapping[str, typing.IO]
|
||||
AuthType = typing.Union[
|
||||
typing.Tuple[str, str],
|
||||
requests.auth.AuthBase,
|
||||
typing.Callable[[requests.Request], requests.Request],
|
||||
]
|
||||
|
||||
|
||||
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
|
||||
ASGI2App = typing.Callable[[Scope], ASGIInstance]
|
||||
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
|
||||
|
||||
|
||||
class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
|
||||
def get_all(self, key: str, default: str) -> str:
|
||||
return self.getheaders(key)
|
||||
|
||||
|
||||
class _MockOriginalResponse:
|
||||
"""
|
||||
We have to jump through some hoops to present the response as if
|
||||
it was made using urllib3.
|
||||
"""
|
||||
|
||||
def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
|
||||
self.msg = _HeaderDict(headers)
|
||||
self.closed = False
|
||||
|
||||
def isclosed(self) -> bool:
|
||||
return self.closed
|
||||
|
||||
|
||||
class _Upgrade(Exception):
|
||||
def __init__(self, session: "WebSocketTestSession") -> None:
|
||||
self.session = session
|
||||
|
||||
|
||||
def _get_reason_phrase(status_code: int) -> str:
|
||||
try:
|
||||
return http.HTTPStatus(status_code).phrase
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
|
||||
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
|
||||
if inspect.isclass(app):
|
||||
return hasattr(app, "__await__")
|
||||
elif inspect.isfunction(app):
|
||||
return asyncio.iscoroutinefunction(app)
|
||||
call = getattr(app, "__call__", None)
|
||||
return asyncio.iscoroutinefunction(call)
|
||||
|
||||
|
||||
class _WrapASGI2:
|
||||
"""
|
||||
Provide an ASGI3 interface onto an ASGI2 app.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGI2App) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
instance = self.app(scope)
|
||||
await instance(receive, send)
|
||||
|
||||
|
||||
class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
||||
def __init__(
|
||||
self, app: ASGI3App, raise_server_exceptions: bool = True, root_path: str = ""
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.raise_server_exceptions = raise_server_exceptions
|
||||
self.root_path = root_path
|
||||
|
||||
def send(
|
||||
self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> requests.Response:
|
||||
scheme, netloc, path, query, fragment = (
|
||||
str(item) for item in urlsplit(request.url)
|
||||
)
|
||||
|
||||
default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
|
||||
|
||||
if ":" in netloc:
|
||||
host, port_string = netloc.split(":", 1)
|
||||
port = int(port_string)
|
||||
else:
|
||||
host = netloc
|
||||
port = default_port
|
||||
|
||||
# Include the 'host' header.
|
||||
if "host" in request.headers:
|
||||
headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
elif port == default_port:
|
||||
headers = [(b"host", host.encode())]
|
||||
else:
|
||||
headers = [(b"host", (f"{host}:{port}").encode())]
|
||||
|
||||
# Include other request headers.
|
||||
headers += [
|
||||
(key.lower().encode(), value.encode())
|
||||
for key, value in request.headers.items()
|
||||
]
|
||||
|
||||
if scheme in {"ws", "wss"}:
|
||||
subprotocol = request.headers.get("sec-websocket-protocol", None)
|
||||
if subprotocol is None:
|
||||
subprotocols = [] # type: typing.Sequence[str]
|
||||
else:
|
||||
subprotocols = [value.strip() for value in subprotocol.split(",")]
|
||||
scope = {
|
||||
"type": "websocket",
|
||||
"path": unquote(path),
|
||||
"root_path": self.root_path,
|
||||
"scheme": scheme,
|
||||
"query_string": query.encode(),
|
||||
"headers": headers,
|
||||
"client": ["testclient", 50000],
|
||||
"server": [host, port],
|
||||
"subprotocols": subprotocols,
|
||||
}
|
||||
session = WebSocketTestSession(self.app, scope)
|
||||
raise _Upgrade(session)
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"http_version": "1.1",
|
||||
"method": request.method,
|
||||
"path": unquote(path),
|
||||
"root_path": self.root_path,
|
||||
"scheme": scheme,
|
||||
"query_string": query.encode(),
|
||||
"headers": headers,
|
||||
"client": ["testclient", 50000],
|
||||
"server": [host, port],
|
||||
"extensions": {"http.response.template": {}},
|
||||
}
|
||||
|
||||
request_complete = False
|
||||
response_started = False
|
||||
response_complete = False
|
||||
raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any]
|
||||
template = None
|
||||
context = None
|
||||
|
||||
async def receive() -> Message:
|
||||
nonlocal request_complete, response_complete
|
||||
|
||||
if request_complete:
|
||||
while not response_complete:
|
||||
await asyncio.sleep(0.0001)
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
body = request.body
|
||||
if isinstance(body, str):
|
||||
body_bytes = body.encode("utf-8") # type: bytes
|
||||
elif body is None:
|
||||
body_bytes = b""
|
||||
elif isinstance(body, types.GeneratorType):
|
||||
try:
|
||||
chunk = body.send(None)
|
||||
if isinstance(chunk, str):
|
||||
chunk = chunk.encode("utf-8")
|
||||
return {"type": "http.request", "body": chunk, "more_body": True}
|
||||
except StopIteration:
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": b""}
|
||||
else:
|
||||
body_bytes = body
|
||||
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": body_bytes}
|
||||
|
||||
async def send(message: Message) -> None:
|
||||
nonlocal raw_kwargs, response_started, response_complete, template, context
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
assert (
|
||||
not response_started
|
||||
), 'Received multiple "http.response.start" messages.'
|
||||
raw_kwargs["version"] = 11
|
||||
raw_kwargs["status"] = message["status"]
|
||||
raw_kwargs["reason"] = _get_reason_phrase(message["status"])
|
||||
raw_kwargs["headers"] = [
|
||||
(key.decode(), value.decode()) for key, value in message["headers"]
|
||||
]
|
||||
raw_kwargs["preload_content"] = False
|
||||
raw_kwargs["original_response"] = _MockOriginalResponse(
|
||||
raw_kwargs["headers"]
|
||||
)
|
||||
response_started = True
|
||||
elif message["type"] == "http.response.body":
|
||||
assert (
|
||||
response_started
|
||||
), 'Received "http.response.body" without "http.response.start".'
|
||||
assert (
|
||||
not response_complete
|
||||
), 'Received "http.response.body" after response completed.'
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if request.method != "HEAD":
|
||||
raw_kwargs["body"].write(body)
|
||||
if not more_body:
|
||||
raw_kwargs["body"].seek(0)
|
||||
response_complete = True
|
||||
elif message["type"] == "http.response.template":
|
||||
template = message["template"]
|
||||
context = message["context"]
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(self.app(scope, receive, send))
|
||||
except BaseException as exc:
|
||||
if self.raise_server_exceptions:
|
||||
raise exc from None
|
||||
|
||||
if self.raise_server_exceptions:
|
||||
assert response_started, "TestClient did not receive any response."
|
||||
elif not response_started:
|
||||
raw_kwargs = {
|
||||
"version": 11,
|
||||
"status": 500,
|
||||
"reason": "Internal Server Error",
|
||||
"headers": [],
|
||||
"preload_content": False,
|
||||
"original_response": _MockOriginalResponse([]),
|
||||
"body": io.BytesIO(),
|
||||
}
|
||||
|
||||
raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
|
||||
response = self.build_response(request, raw)
|
||||
if template is not None:
|
||||
response.template = template
|
||||
response.context = context
|
||||
return response
|
||||
|
||||
|
||||
class WebSocketTestSession:
|
||||
def __init__(self, app: ASGI3App, scope: Scope) -> None:
|
||||
self.app = app
|
||||
self.scope = scope
|
||||
self.accepted_subprotocol = None
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._receive_queue = queue.Queue() # type: queue.Queue
|
||||
self._send_queue = queue.Queue() # type: queue.Queue
|
||||
self._thread = threading.Thread(target=self._run)
|
||||
self.send({"type": "websocket.connect"})
|
||||
self._thread.start()
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
self.accepted_subprotocol = message.get("subprotocol", None)
|
||||
|
||||
def __enter__(self) -> "WebSocketTestSession":
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: typing.Any) -> None:
|
||||
self.close(1000)
|
||||
self._thread.join()
|
||||
while not self._send_queue.empty():
|
||||
message = self._send_queue.get()
|
||||
if isinstance(message, BaseException):
|
||||
raise message
|
||||
|
||||
def _run(self) -> None:
|
||||
"""
|
||||
The sub-thread in which the websocket session runs.
|
||||
"""
|
||||
scope = self.scope
|
||||
receive = self._asgi_receive
|
||||
send = self._asgi_send
|
||||
try:
|
||||
self._loop.run_until_complete(self.app(scope, receive, send))
|
||||
except BaseException as exc:
|
||||
self._send_queue.put(exc)
|
||||
|
||||
async def _asgi_receive(self) -> Message:
|
||||
while self._receive_queue.empty():
|
||||
await asyncio.sleep(0)
|
||||
return self._receive_queue.get()
|
||||
|
||||
async def _asgi_send(self, message: Message) -> None:
|
||||
self._send_queue.put(message)
|
||||
|
||||
def _raise_on_close(self, message: Message) -> None:
|
||||
if message["type"] == "websocket.close":
|
||||
raise WebSocketDisconnect(message.get("code", 1000))
|
||||
|
||||
def send(self, message: Message) -> None:
|
||||
self._receive_queue.put(message)
|
||||
|
||||
def send_text(self, data: str) -> None:
|
||||
self.send({"type": "websocket.receive", "text": data})
|
||||
|
||||
def send_bytes(self, data: bytes) -> None:
|
||||
self.send({"type": "websocket.receive", "bytes": data})
|
||||
|
||||
def send_json(self, data: typing.Any, mode: str = "text") -> None:
|
||||
assert mode in ["text", "binary"]
|
||||
text = json.dumps(data)
|
||||
if mode == "text":
|
||||
self.send({"type": "websocket.receive", "text": text})
|
||||
else:
|
||||
self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
|
||||
|
||||
def close(self, code: int = 1000) -> None:
|
||||
self.send({"type": "websocket.disconnect", "code": code})
|
||||
|
||||
def receive(self) -> Message:
|
||||
message = self._send_queue.get()
|
||||
if isinstance(message, BaseException):
|
||||
raise message
|
||||
return message
|
||||
|
||||
def receive_text(self) -> str:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
return message["text"]
|
||||
|
||||
def receive_bytes(self) -> bytes:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
return message["bytes"]
|
||||
|
||||
def receive_json(self, mode: str = "text") -> typing.Any:
|
||||
assert mode in ["text", "binary"]
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
if mode == "text":
|
||||
text = message["text"]
|
||||
else:
|
||||
text = message["bytes"].decode("utf-8")
|
||||
return json.loads(text)
|
||||
|
||||
|
||||
class TestClient(requests.Session):
|
||||
__test__ = False # For pytest to not discover this up.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: typing.Union[ASGI2App, ASGI3App],
|
||||
base_url: str = "http://testserver",
|
||||
raise_server_exceptions: bool = True,
|
||||
root_path: str = "",
|
||||
) -> None:
|
||||
super(TestClient, self).__init__()
|
||||
if _is_asgi3(app):
|
||||
app = typing.cast(ASGI3App, app)
|
||||
asgi_app = app
|
||||
else:
|
||||
app = typing.cast(ASGI2App, app)
|
||||
asgi_app = _WrapASGI2(app) # type: ignore
|
||||
adapter = _ASGIAdapter(
|
||||
asgi_app,
|
||||
raise_server_exceptions=raise_server_exceptions,
|
||||
root_path=root_path,
|
||||
)
|
||||
self.mount("http://", adapter)
|
||||
self.mount("https://", adapter)
|
||||
self.mount("ws://", adapter)
|
||||
self.mount("wss://", adapter)
|
||||
self.headers.update({"user-agent": "testclient"})
|
||||
self.app = asgi_app
|
||||
self.base_url = base_url
|
||||
|
||||
def request( # type: ignore
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
params: Params = None,
|
||||
data: DataType = None,
|
||||
headers: typing.MutableMapping[str, str] = None,
|
||||
cookies: Cookies = None,
|
||||
files: FileType = None,
|
||||
auth: AuthType = None,
|
||||
timeout: TimeOut = None,
|
||||
allow_redirects: bool = None,
|
||||
proxies: typing.MutableMapping[str, str] = None,
|
||||
hooks: typing.Any = None,
|
||||
stream: bool = None,
|
||||
verify: typing.Union[bool, str] = None,
|
||||
cert: typing.Union[str, typing.Tuple[str, str]] = None,
|
||||
json: typing.Any = None,
|
||||
) -> requests.Response:
|
||||
url = urljoin(self.base_url, url)
|
||||
return super().request(
|
||||
method,
|
||||
url,
|
||||
params=params,
|
||||
data=data,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
files=files,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
proxies=proxies,
|
||||
hooks=hooks,
|
||||
stream=stream,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
json=json,
|
||||
)
|
||||
|
||||
def websocket_connect(
|
||||
self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
url = urljoin("ws://testserver", url)
|
||||
headers = kwargs.get("headers", {})
|
||||
headers.setdefault("connection", "upgrade")
|
||||
headers.setdefault("sec-websocket-key", "testserver==")
|
||||
headers.setdefault("sec-websocket-version", "13")
|
||||
if subprotocols is not None:
|
||||
headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
|
||||
kwargs["headers"] = headers
|
||||
try:
|
||||
super().request("GET", url, **kwargs)
|
||||
except _Upgrade as exc:
|
||||
session = exc.session
|
||||
else:
|
||||
raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover
|
||||
|
||||
return session
|
||||
|
||||
def __enter__(self) -> requests.Session:
|
||||
loop = asyncio.get_event_loop()
|
||||
self.send_queue = asyncio.Queue() # type: asyncio.Queue
|
||||
self.receive_queue = asyncio.Queue() # type: asyncio.Queue
|
||||
self.task = loop.create_task(self.lifespan())
|
||||
loop.run_until_complete(self.wait_startup())
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: typing.Any) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(self.wait_shutdown())
|
||||
|
||||
async def lifespan(self) -> None:
|
||||
scope = {"type": "lifespan"}
|
||||
try:
|
||||
await self.app(scope, self.receive_queue.get, self.send_queue.put)
|
||||
finally:
|
||||
await self.send_queue.put(None)
|
||||
|
||||
async def wait_startup(self) -> None:
|
||||
await self.receive_queue.put({"type": "lifespan.startup"})
|
||||
message = await self.send_queue.get()
|
||||
if message is None:
|
||||
self.task.result()
|
||||
assert message["type"] in (
|
||||
"lifespan.startup.complete",
|
||||
"lifespan.startup.failed",
|
||||
)
|
||||
if message["type"] == "lifespan.startup.failed":
|
||||
message = await self.send_queue.get()
|
||||
if message is None:
|
||||
self.task.result()
|
||||
|
||||
async def wait_shutdown(self) -> None:
|
||||
await self.receive_queue.put({"type": "lifespan.shutdown"})
|
||||
message = await self.send_queue.get()
|
||||
if message is None:
|
||||
self.task.result()
|
||||
assert message["type"] == "lifespan.shutdown.complete"
|
||||
await self.task
|
9
.venv/lib/python3.9/site-packages/starlette/types.py
Normal file
9
.venv/lib/python3.9/site-packages/starlette/types.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import typing
|
||||
|
||||
Scope = typing.MutableMapping[str, typing.Any]
|
||||
Message = typing.MutableMapping[str, typing.Any]
|
||||
|
||||
Receive = typing.Callable[[], typing.Awaitable[Message]]
|
||||
Send = typing.Callable[[Message], typing.Awaitable[None]]
|
||||
|
||||
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
|
150
.venv/lib/python3.9/site-packages/starlette/websockets.py
Normal file
150
.venv/lib/python3.9/site-packages/starlette/websockets.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import enum
|
||||
import json
|
||||
import typing
|
||||
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
|
||||
|
||||
class WebSocketState(enum.Enum):
|
||||
CONNECTING = 0
|
||||
CONNECTED = 1
|
||||
DISCONNECTED = 2
|
||||
|
||||
|
||||
class WebSocketDisconnect(Exception):
|
||||
def __init__(self, code: int = 1000) -> None:
|
||||
self.code = code
|
||||
|
||||
|
||||
class WebSocket(HTTPConnection):
|
||||
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
super().__init__(scope)
|
||||
assert scope["type"] == "websocket"
|
||||
self._receive = receive
|
||||
self._send = send
|
||||
self.client_state = WebSocketState.CONNECTING
|
||||
self.application_state = WebSocketState.CONNECTING
|
||||
|
||||
async def receive(self) -> Message:
|
||||
"""
|
||||
Receive ASGI websocket messages, ensuring valid state transitions.
|
||||
"""
|
||||
if self.client_state == WebSocketState.CONNECTING:
|
||||
message = await self._receive()
|
||||
message_type = message["type"]
|
||||
assert message_type == "websocket.connect"
|
||||
self.client_state = WebSocketState.CONNECTED
|
||||
return message
|
||||
elif self.client_state == WebSocketState.CONNECTED:
|
||||
message = await self._receive()
|
||||
message_type = message["type"]
|
||||
assert message_type in {"websocket.receive", "websocket.disconnect"}
|
||||
if message_type == "websocket.disconnect":
|
||||
self.client_state = WebSocketState.DISCONNECTED
|
||||
return message
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Cannot call "receive" once a disconnect message has been received.'
|
||||
)
|
||||
|
||||
async def send(self, message: Message) -> None:
|
||||
"""
|
||||
Send ASGI websocket messages, ensuring valid state transitions.
|
||||
"""
|
||||
if self.application_state == WebSocketState.CONNECTING:
|
||||
message_type = message["type"]
|
||||
assert message_type in {"websocket.accept", "websocket.close"}
|
||||
if message_type == "websocket.close":
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
else:
|
||||
self.application_state = WebSocketState.CONNECTED
|
||||
await self._send(message)
|
||||
elif self.application_state == WebSocketState.CONNECTED:
|
||||
message_type = message["type"]
|
||||
assert message_type in {"websocket.send", "websocket.close"}
|
||||
if message_type == "websocket.close":
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
await self._send(message)
|
||||
else:
|
||||
raise RuntimeError('Cannot call "send" once a close message has been sent.')
|
||||
|
||||
async def accept(self, subprotocol: str = None) -> None:
|
||||
if self.client_state == WebSocketState.CONNECTING:
|
||||
# If we haven't yet seen the 'connect' message, then wait for it first.
|
||||
await self.receive()
|
||||
await self.send({"type": "websocket.accept", "subprotocol": subprotocol})
|
||||
|
||||
def _raise_on_disconnect(self, message: Message) -> None:
|
||||
if message["type"] == "websocket.disconnect":
|
||||
raise WebSocketDisconnect(message["code"])
|
||||
|
||||
async def receive_text(self) -> str:
|
||||
assert self.application_state == WebSocketState.CONNECTED
|
||||
message = await self.receive()
|
||||
self._raise_on_disconnect(message)
|
||||
return message["text"]
|
||||
|
||||
async def receive_bytes(self) -> bytes:
|
||||
assert self.application_state == WebSocketState.CONNECTED
|
||||
message = await self.receive()
|
||||
self._raise_on_disconnect(message)
|
||||
return message["bytes"]
|
||||
|
||||
async def receive_json(self, mode: str = "text") -> typing.Any:
|
||||
assert mode in ["text", "binary"]
|
||||
assert self.application_state == WebSocketState.CONNECTED
|
||||
message = await self.receive()
|
||||
self._raise_on_disconnect(message)
|
||||
|
||||
if mode == "text":
|
||||
text = message["text"]
|
||||
else:
|
||||
text = message["bytes"].decode("utf-8")
|
||||
return json.loads(text)
|
||||
|
||||
async def iter_text(self) -> typing.AsyncIterator[str]:
|
||||
try:
|
||||
while True:
|
||||
yield await self.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
|
||||
try:
|
||||
while True:
|
||||
yield await self.receive_bytes()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
|
||||
try:
|
||||
while True:
|
||||
yield await self.receive_json()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
async def send_text(self, data: str) -> None:
|
||||
await self.send({"type": "websocket.send", "text": data})
|
||||
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
await self.send({"type": "websocket.send", "bytes": data})
|
||||
|
||||
async def send_json(self, data: typing.Any, mode: str = "text") -> None:
|
||||
assert mode in ["text", "binary"]
|
||||
text = json.dumps(data)
|
||||
if mode == "text":
|
||||
await self.send({"type": "websocket.send", "text": text})
|
||||
else:
|
||||
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
|
||||
|
||||
async def close(self, code: int = 1000) -> None:
|
||||
await self.send({"type": "websocket.close", "code": code})
|
||||
|
||||
|
||||
class WebSocketClose:
|
||||
def __init__(self, code: int = 1000) -> None:
|
||||
self.code = code
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await send({"type": "websocket.close", "code": self.code})
|
Reference in New Issue
Block a user