new
.venv/lib/python3.9/site-packages/hypercorn/__about__.py (new file, 1 line)
@@ -0,0 +1 @@
__version__ = "0.11.2"
.venv/lib/python3.9/site-packages/hypercorn/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .__about__ import __version__
from .config import Config

__all__ = ("__version__", "Config")
.venv/lib/python3.9/site-packages/hypercorn/__main__.py (new file, 271 lines)
@@ -0,0 +1,271 @@
import argparse
import ssl
import sys
import warnings
from typing import List, Optional

from .config import Config
from .run import run

sentinel = object()


def _load_config(config_path: Optional[str]) -> Config:
    if config_path is None:
        return Config()
    elif config_path.startswith("python:"):
        return Config.from_object(config_path[len("python:") :])
    elif config_path.startswith("file:"):
        return Config.from_pyfile(config_path[len("file:") :])
    else:
        return Config.from_toml(config_path)


def main(sys_args: Optional[List[str]] = None) -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "application", help="The application to dispatch to as path.to.module:instance.path"
    )
    parser.add_argument("--access-log", help="Deprecated, see access-logfile", default=sentinel)
    parser.add_argument(
        "--access-logfile",
        help="The target location for the access log, use `-` for stdout",
        default=sentinel,
    )
    parser.add_argument(
        "--access-logformat",
        help="The log format for the access log, see help docs",
        default=sentinel,
    )
    parser.add_argument(
        "--backlog", help="The maximum number of pending connections", type=int, default=sentinel
    )
    parser.add_argument(
        "-b",
        "--bind",
        dest="binds",
        help=""" The TCP host/address to bind to. Should be either host:port, host,
        unix:path or fd://num, e.g. 127.0.0.1:5000, 127.0.0.1,
        unix:/tmp/socket or fd://33 respectively. """,
        default=[],
        action="append",
    )
    parser.add_argument("--ca-certs", help="Path to the SSL CA certificate file", default=sentinel)
    parser.add_argument("--certfile", help="Path to the SSL certificate file", default=sentinel)
    parser.add_argument("--cert-reqs", help="See verify mode argument", type=int, default=sentinel)
    parser.add_argument("--ciphers", help="Ciphers to use for the SSL setup", default=sentinel)
    parser.add_argument(
        "-c",
        "--config",
        help="Location of a TOML config file, or when prefixed with `file:` a Python file, or when prefixed with `python:` a Python module.",  # noqa: E501
        default=None,
    )
    parser.add_argument(
        "--debug",
        help="Enable debug mode, i.e. extra logging and checks",
        action="store_true",
        default=sentinel,
    )
    parser.add_argument("--error-log", help="Deprecated, see error-logfile", default=sentinel)
    parser.add_argument(
        "--error-logfile",
        "--log-file",
        dest="error_logfile",
        help="The target location for the error log, use `-` for stderr",
        default=sentinel,
    )
    parser.add_argument(
        "--graceful-timeout",
        help="""Time to wait after SIGTERM or Ctrl-C for any remaining requests (tasks)
        to complete.""",
        default=sentinel,
        type=int,
    )
    parser.add_argument(
        "-g", "--group", help="Group to own any unix sockets.", default=sentinel, type=int
    )
    parser.add_argument(
        "-k",
        "--worker-class",
        dest="worker_class",
        help="The type of worker to use. "
        "Options include asyncio, uvloop (pip install hypercorn[uvloop]), "
        "and trio (pip install hypercorn[trio]).",
        default=sentinel,
    )
    parser.add_argument(
        "--keep-alive",
        help="Seconds to keep inactive connections alive for",
        default=sentinel,
        type=int,
    )
    parser.add_argument("--keyfile", help="Path to the SSL key file", default=sentinel)
    parser.add_argument(
        "--insecure-bind",
        dest="insecure_binds",
        help="""The TCP host/address to bind to. SSL options will not apply to these binds.
        See *bind* for formatting options. Care must be taken! See HTTP -> HTTPS redirection docs.
        """,
        default=[],
        action="append",
    )
    parser.add_argument(
        "--log-config", help="A Python logging configuration file.", default=sentinel
    )
    parser.add_argument(
        "--log-level", help="The (error) log level, defaults to info", default="INFO"
    )
    parser.add_argument(
        "-p", "--pid", help="Location to write the PID (Program ID) to.", default=sentinel
    )
    parser.add_argument(
        "--quic-bind",
        dest="quic_binds",
        help="""The UDP/QUIC host/address to bind to. See *bind* for formatting
        options.
        """,
        default=[],
        action="append",
    )
    parser.add_argument(
        "--reload",
        help="Enable automatic reloads on code changes",
        action="store_true",
        default=sentinel,
    )
    parser.add_argument(
        "--root-path", help="The setting for the ASGI root_path variable", default=sentinel
    )
    parser.add_argument(
        "--server-name",
        dest="server_names",
        help="""The hostnames that can be served, requests to different hosts
        will be responded to with 404s.
        """,
        default=[],
        action="append",
    )
    parser.add_argument(
        "--statsd-host", help="The host:port of the statsd server", default=sentinel
    )
    parser.add_argument("--statsd-prefix", help="Prefix for all statsd messages", default="")
    parser.add_argument(
        "-m",
        "--umask",
        help="The permissions bit mask to use on any unix sockets.",
        default=sentinel,
        type=int,
    )
    parser.add_argument(
        "-u", "--user", help="User to own any unix sockets.", default=sentinel, type=int
    )

    def _convert_verify_mode(value: str) -> ssl.VerifyMode:
        try:
            return ssl.VerifyMode[value]
        except KeyError:
            raise argparse.ArgumentTypeError(f"'{value}' is not a valid verify mode")

    parser.add_argument(
        "--verify-mode",
        help="SSL verify mode for peer's certificate, see ssl.VerifyMode enum for possible values.",
        type=_convert_verify_mode,
        default=sentinel,
    )
    parser.add_argument(
        "--websocket-ping-interval",
        help="""If set this is the time in seconds between pings sent to the client.
        This can be used to keep the websocket connection alive.""",
        default=sentinel,
        type=int,
    )
    parser.add_argument(
        "-w",
        "--workers",
        dest="workers",
        help="The number of workers to spawn and use",
        default=sentinel,
        type=int,
    )
    args = parser.parse_args(sys_args or sys.argv[1:])
    config = _load_config(args.config)
    config.application_path = args.application
    config.loglevel = args.log_level

    if args.access_logformat is not sentinel:
        config.access_log_format = args.access_logformat
    if args.access_log is not sentinel:
        warnings.warn(
            "The --access-log argument is deprecated, use `--access-logfile` instead",
            DeprecationWarning,
        )
        config.accesslog = args.access_log
    if args.access_logfile is not sentinel:
        config.accesslog = args.access_logfile
    if args.backlog is not sentinel:
        config.backlog = args.backlog
    if args.ca_certs is not sentinel:
        config.ca_certs = args.ca_certs
    if args.certfile is not sentinel:
        config.certfile = args.certfile
    if args.cert_reqs is not sentinel:
        config.cert_reqs = args.cert_reqs
    if args.ciphers is not sentinel:
        config.ciphers = args.ciphers
    if args.debug is not sentinel:
        config.debug = args.debug
    if args.error_log is not sentinel:
        warnings.warn(
            "The --error-log argument is deprecated, use `--error-logfile` instead",
            DeprecationWarning,
        )
        config.errorlog = args.error_log
    if args.error_logfile is not sentinel:
        config.errorlog = args.error_logfile
    if args.graceful_timeout is not sentinel:
        config.graceful_timeout = args.graceful_timeout
    if args.group is not sentinel:
        config.group = args.group
    if args.keep_alive is not sentinel:
        config.keep_alive_timeout = args.keep_alive
    if args.keyfile is not sentinel:
        config.keyfile = args.keyfile
    if args.log_config is not sentinel:
        config.logconfig = args.log_config
    if args.pid is not sentinel:
        config.pid_path = args.pid
    if args.root_path is not sentinel:
        config.root_path = args.root_path
    if args.reload is not sentinel:
        config.use_reloader = args.reload
    if args.statsd_host is not sentinel:
        config.statsd_host = args.statsd_host
    if args.statsd_prefix is not sentinel:
        config.statsd_prefix = args.statsd_prefix
    if args.umask is not sentinel:
        config.umask = args.umask
    if args.user is not sentinel:
        config.user = args.user
    if args.worker_class is not sentinel:
        config.worker_class = args.worker_class
    if args.verify_mode is not sentinel:
        config.verify_mode = args.verify_mode
    if args.websocket_ping_interval is not sentinel:
        config.websocket_ping_interval = args.websocket_ping_interval
    if args.workers is not sentinel:
        config.workers = args.workers

    if len(args.binds) > 0:
        config.bind = args.binds
    if len(args.insecure_binds) > 0:
        config.insecure_bind = args.insecure_binds
    if len(args.quic_binds) > 0:
        config.quic_bind = args.quic_binds
    if len(args.server_names) > 0:
        config.server_names = args.server_names

    run(config)


if __name__ == "__main__":
    main()
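The `__main__.py` above maps each CLI flag onto a `Config` attribute and then calls `run(config)`. A minimal sketch of driving that entry point programmatically; the application path `myproject.app:app` is a hypothetical placeholder:

from hypercorn.__main__ import main

# Equivalent to: python -m hypercorn myproject.app:app --bind 127.0.0.1:8000 --workers 2
main(["myproject.app:app", "--bind", "127.0.0.1:8000", "--workers", "2"])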
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,39 @@
import warnings
from typing import Awaitable, Callable, Optional

from .run import worker_serve
from ..config import Config
from ..typing import ASGIFramework


async def serve(
    app: ASGIFramework,
    config: Config,
    *,
    shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
) -> None:
    """Serve an ASGI framework app given the config.

    This allows for a programmatic way to serve an ASGI framework, it
    can be used via,

    .. code-block:: python

        asyncio.run(serve(app, config))

    It is assumed that the event-loop is configured before calling
    this function, therefore configuration values that relate to loop
    setup or process setup are ignored.

    Arguments:
        app: The ASGI application to serve.
        config: A Hypercorn configuration object.
        shutdown_trigger: This should return to trigger a graceful
            shutdown.
    """
    if config.debug:
        warnings.warn("The config `debug` has no affect when using serve", Warning)
    if config.workers != 1:
        warnings.warn("The config `workers` has no affect when using serve", Warning)

    await worker_serve(app, config, shutdown_trigger=shutdown_trigger)
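The `serve` coroutine above is the programmatic counterpart to the CLI for the asyncio worker. A minimal sketch of using it, with a hypothetical inline ASGI app:

import asyncio

from hypercorn.asyncio import serve
from hypercorn.config import Config


async def app(scope, receive, send):
    # Hypothetical ASGI app for illustration; lifespan events are ignored.
    if scope["type"] == "http":
        await send({"type": "http.response.start", "status": 200,
                    "headers": [(b"content-length", b"2")]})
        await send({"type": "http.response.body", "body": b"ok"})


config = Config()
config.bind = ["127.0.0.1:8000"]
asyncio.run(serve(app, config))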
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,74 @@
import asyncio
from typing import Any, Awaitable, Callable, Optional, Type, Union

from .task_group import TaskGroup
from ..config import Config
from ..typing import (
    ASGIFramework,
    ASGIReceiveCallable,
    ASGIReceiveEvent,
    ASGISendEvent,
    Event,
    Scope,
)
from ..utils import invoke_asgi


class EventWrapper:
    def __init__(self) -> None:
        self._event = asyncio.Event()

    async def clear(self) -> None:
        self._event.clear()

    async def wait(self) -> None:
        await self._event.wait()

    async def set(self) -> None:
        self._event.set()


async def _handle(
    app: ASGIFramework,
    config: Config,
    scope: Scope,
    receive: ASGIReceiveCallable,
    send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
) -> None:
    try:
        await invoke_asgi(app, scope, receive, send)
    except asyncio.CancelledError:
        raise
    except Exception:
        await config.log.exception("Error in ASGI Framework")
    finally:
        await send(None)


class Context:
    event_class: Type[Event] = EventWrapper

    def __init__(self, task_group: TaskGroup) -> None:
        self.task_group = task_group

    async def spawn_app(
        self,
        app: ASGIFramework,
        config: Config,
        scope: Scope,
        send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
    ) -> Callable[[ASGIReceiveEvent], Awaitable[None]]:
        app_queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue(config.max_app_queue_size)
        self.task_group.spawn(_handle(app, config, scope, app_queue.get, send))
        return app_queue.put

    def spawn(self, func: Callable, *args: Any) -> None:
        self.task_group.spawn(func(*args))

    @staticmethod
    async def sleep(wait: Union[float, int]) -> None:
        return await asyncio.sleep(wait)

    @staticmethod
    def time() -> float:
        return asyncio.get_event_loop().time()
@@ -0,0 +1,85 @@
import asyncio

from ..config import Config
from ..typing import ASGIFramework, ASGIReceiveEvent, ASGISendEvent, LifespanScope
from ..utils import invoke_asgi, LifespanFailure, LifespanTimeout


class UnexpectedMessage(Exception):
    pass


class Lifespan:
    def __init__(self, app: ASGIFramework, config: Config) -> None:
        self.app = app
        self.config = config
        self.startup = asyncio.Event()
        self.shutdown = asyncio.Event()
        self.app_queue: asyncio.Queue = asyncio.Queue(config.max_app_queue_size)
        self.supported = True

        # This mimics the Trio nursery.start task_status and is
        # required to ensure the support has been checked before
        # waiting on timeouts.
        self._started = asyncio.Event()

    async def handle_lifespan(self) -> None:
        self._started.set()
        scope: LifespanScope = {"type": "lifespan", "asgi": {"spec_version": "2.0"}}
        try:
            await invoke_asgi(self.app, scope, self.asgi_receive, self.asgi_send)
        except LifespanFailure:
            # Lifespan failures should crash the server
            raise
        except Exception:
            self.supported = False
            if not self.startup.is_set():
                message = "ASGI Framework Lifespan error, continuing without Lifespan support"
            elif not self.shutdown.is_set():
                message = "ASGI Framework Lifespan error, shutdown without Lifespan support"
            else:
                message = "ASGI Framework Lifespan errored after shutdown."

            await self.config.log.exception(message)
        finally:
            self.startup.set()
            self.shutdown.set()

    async def wait_for_startup(self) -> None:
        await self._started.wait()
        if not self.supported:
            return

        await self.app_queue.put({"type": "lifespan.startup"})
        try:
            await asyncio.wait_for(self.startup.wait(), timeout=self.config.startup_timeout)
        except asyncio.TimeoutError as error:
            raise LifespanTimeout("startup") from error

    async def wait_for_shutdown(self) -> None:
        await self._started.wait()
        if not self.supported:
            return

        await self.app_queue.put({"type": "lifespan.shutdown"})
        try:
            await asyncio.wait_for(self.shutdown.wait(), timeout=self.config.shutdown_timeout)
        except asyncio.TimeoutError as error:
            raise LifespanTimeout("shutdown") from error

    async def asgi_receive(self) -> ASGIReceiveEvent:
        return await self.app_queue.get()

    async def asgi_send(self, message: ASGISendEvent) -> None:
        if message["type"] == "lifespan.startup.complete":
            self.startup.set()
        elif message["type"] == "lifespan.shutdown.complete":
            self.shutdown.set()
        elif message["type"] == "lifespan.startup.failed":
            self.startup.set()
            raise LifespanFailure("startup", message["message"])
        elif message["type"] == "lifespan.shutdown.failed":
            self.shutdown.set()
            raise LifespanFailure("shutdown", message["message"])
        else:
            raise UnexpectedMessage(message["type"])
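The `Lifespan` class above sends `lifespan.startup` and `lifespan.shutdown` events to the application and waits, with the configured timeouts, for the matching `*.complete` replies. A sketch of the application side of that exchange; the startup and shutdown work is hypothetical:

async def app(scope, receive, send):
    if scope["type"] != "lifespan":
        return  # request handling elided
    while True:
        message = await receive()
        if message["type"] == "lifespan.startup":
            # hypothetical startup work (open connection pools, warm caches) goes here
            await send({"type": "lifespan.startup.complete"})
        elif message["type"] == "lifespan.shutdown":
            await send({"type": "lifespan.shutdown.complete"})
            return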
.venv/lib/python3.9/site-packages/hypercorn/asyncio/run.py (new file, 266 lines)
@@ -0,0 +1,266 @@
import asyncio
import platform
import signal
import ssl
from functools import partial
from multiprocessing.synchronize import Event as EventType
from os import getpid
from socket import socket
from typing import Any, Awaitable, Callable, Optional

from .lifespan import Lifespan
from .statsd import StatsdLogger
from .tcp_server import TCPServer
from .udp_server import UDPServer
from ..config import Config, Sockets
from ..typing import ASGIFramework
from ..utils import (
    check_multiprocess_shutdown_event,
    load_application,
    MustReloadException,
    observe_changes,
    raise_shutdown,
    repr_socket_addr,
    restart,
    Shutdown,
)

try:
    from socket import AF_UNIX
except ImportError:
    AF_UNIX = None


async def _windows_signal_support() -> None:
    # See https://bugs.python.org/issue23057, to catch signals on
    # Windows it is necessary for an IO event to happen periodically.
    while True:
        await asyncio.sleep(1)


def _share_socket(sock: socket) -> socket:
    # Windows requires the socket be explicitly shared across
    # multiple workers (processes).
    from socket import fromshare  # type: ignore

    sock_data = sock.share(getpid())  # type: ignore
    return fromshare(sock_data)


async def worker_serve(
    app: ASGIFramework,
    config: Config,
    *,
    sockets: Optional[Sockets] = None,
    shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
) -> None:
    config.set_statsd_logger_class(StatsdLogger)

    lifespan = Lifespan(app, config)
    lifespan_task = asyncio.ensure_future(lifespan.handle_lifespan())

    await lifespan.wait_for_startup()
    if lifespan_task.done():
        exception = lifespan_task.exception()
        if exception is not None:
            raise exception

    if sockets is None:
        sockets = config.create_sockets()

    loop = asyncio.get_event_loop()
    tasks = []
    if platform.system() == "Windows":
        tasks.append(loop.create_task(_windows_signal_support()))

    if shutdown_trigger is None:
        signal_event = asyncio.Event()

        def _signal_handler(*_: Any) -> None:  # noqa: N803
            signal_event.set()

        for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}:
            if hasattr(signal, signal_name):
                try:
                    loop.add_signal_handler(getattr(signal, signal_name), _signal_handler)
                except NotImplementedError:
                    # Add signal handler may not be implemented on Windows
                    signal.signal(getattr(signal, signal_name), _signal_handler)

        shutdown_trigger = signal_event.wait  # type: ignore

    tasks.append(loop.create_task(raise_shutdown(shutdown_trigger)))

    if config.use_reloader:
        tasks.append(loop.create_task(observe_changes(asyncio.sleep)))

    ssl_handshake_timeout = None
    if config.ssl_enabled:
        ssl_context = config.create_ssl_context()
        ssl_handshake_timeout = config.ssl_handshake_timeout

    async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        await TCPServer(app, loop, config, reader, writer)

    servers = []
    for sock in sockets.secure_sockets:
        if config.workers > 1 and platform.system() == "Windows":
            sock = _share_socket(sock)

        servers.append(
            await asyncio.start_server(
                _server_callback,
                backlog=config.backlog,
                loop=loop,
                ssl=ssl_context,
                sock=sock,
                ssl_handshake_timeout=ssl_handshake_timeout,
            )
        )
        bind = repr_socket_addr(sock.family, sock.getsockname())
        await config.log.info(f"Running on https://{bind} (CTRL + C to quit)")

    for sock in sockets.insecure_sockets:
        if config.workers > 1 and platform.system() == "Windows":
            sock = _share_socket(sock)

        servers.append(
            await asyncio.start_server(
                _server_callback, backlog=config.backlog, loop=loop, sock=sock
            )
        )
        bind = repr_socket_addr(sock.family, sock.getsockname())
        await config.log.info(f"Running on http://{bind} (CTRL + C to quit)")

    tasks.extend(server.serve_forever() for server in servers)  # type: ignore

    for sock in sockets.quic_sockets:
        if config.workers > 1 and platform.system() == "Windows":
            sock = _share_socket(sock)

        await loop.create_datagram_endpoint(lambda: UDPServer(app, loop, config), sock=sock)
        bind = repr_socket_addr(sock.family, sock.getsockname())
        await config.log.info(f"Running on https://{bind} (QUIC) (CTRL + C to quit)")

    reload_ = False
    try:
        gathered_tasks = asyncio.gather(*tasks)
        await gathered_tasks
    except MustReloadException:
        reload_ = True
    except (Shutdown, KeyboardInterrupt):
        pass
    finally:
        for server in servers:
            server.close()
            await server.wait_closed()

        try:
            await asyncio.sleep(config.graceful_timeout)
        except (Shutdown, KeyboardInterrupt):
            pass

        # Retrieve the Gathered Tasks Cancelled Exception, to
        # prevent a warning that this hasn't been done.
        gathered_tasks.exception()

        await lifespan.wait_for_shutdown()
        lifespan_task.cancel()
        await lifespan_task

    if reload_:
        restart()


def asyncio_worker(
    config: Config, sockets: Optional[Sockets] = None, shutdown_event: Optional[EventType] = None
) -> None:
    app = load_application(config.application_path)

    shutdown_trigger = None
    if shutdown_event is not None:
        shutdown_trigger = partial(check_multiprocess_shutdown_event, shutdown_event, asyncio.sleep)

    if config.workers > 1 and platform.system() == "Windows":
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())  # type: ignore

    _run(
        partial(worker_serve, app, config, sockets=sockets),
        debug=config.debug,
        shutdown_trigger=shutdown_trigger,
    )


def uvloop_worker(
    config: Config, sockets: Optional[Sockets] = None, shutdown_event: Optional[EventType] = None
) -> None:
    try:
        import uvloop
    except ImportError as error:
        raise Exception("uvloop is not installed") from error
    else:
        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

    app = load_application(config.application_path)

    shutdown_trigger = None
    if shutdown_event is not None:
        shutdown_trigger = partial(check_multiprocess_shutdown_event, shutdown_event, asyncio.sleep)

    _run(
        partial(worker_serve, app, config, sockets=sockets),
        debug=config.debug,
        shutdown_trigger=shutdown_trigger,
    )


def _run(
    main: Callable,
    *,
    debug: bool = False,
    shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
) -> None:
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.set_debug(debug)
    loop.set_exception_handler(_exception_handler)

    try:
        loop.run_until_complete(main(shutdown_trigger=shutdown_trigger))
    except KeyboardInterrupt:
        pass
    finally:
        try:
            _cancel_all_tasks(loop)
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            loop.close()


def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
    tasks = [task for task in asyncio.all_tasks(loop) if not task.done()]
    if not tasks:
        return

    for task in tasks:
        task.cancel()
    loop.run_until_complete(asyncio.gather(*tasks, loop=loop, return_exceptions=True))

    for task in tasks:
        if not task.cancelled() and task.exception() is not None:
            loop.call_exception_handler(
                {
                    "message": "unhandled exception during shutdown",
                    "exception": task.exception(),
                    "task": task,
                }
            )


def _exception_handler(loop: asyncio.AbstractEventLoop, context: dict) -> None:
    exception = context.get("exception")
    if isinstance(exception, ssl.SSLError):
        pass  # Handshake failure
    else:
        loop.default_exception_handler(context)
@@ -0,0 +1,24 @@
import asyncio
from typing import Optional

from ..config import Config
from ..statsd import StatsdLogger as Base


class _DummyProto(asyncio.DatagramProtocol):
    pass


class StatsdLogger(Base):
    def __init__(self, config: Config) -> None:
        super().__init__(config)
        self.address = config.statsd_host.rsplit(":", 1)
        self.transport: Optional[asyncio.BaseTransport] = None

    async def _socket_send(self, message: bytes) -> None:
        if self.transport is None:
            self.transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
                _DummyProto, remote_addr=(self.address[0], int(self.address[1]))
            )

        self.transport.sendto(message)  # type: ignore
@@ -0,0 +1,34 @@
import asyncio
import weakref
from types import TracebackType
from typing import Coroutine


class TaskGroup:
    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._tasks: weakref.WeakSet = weakref.WeakSet()

    def spawn(self, coro: Coroutine) -> None:
        self._tasks.add(self._loop.create_task(coro))

    async def __aenter__(self) -> "TaskGroup":
        return self

    async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
        if exc_type is not None:
            self._cancel_tasks()

        try:
            task = asyncio.gather(*self._tasks)
            await task
        finally:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    def _cancel_tasks(self) -> None:
        for task in self._tasks:
            task.cancel()
@@ -0,0 +1,153 @@
import asyncio
from typing import Any, Callable, cast, Generator, Optional

from .context import Context
from .task_group import TaskGroup
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..protocol import ProtocolWrapper
from ..typing import ASGIFramework
from ..utils import parse_socket_addr

MAX_RECV = 2 ** 16


class EventWrapper:
    def __init__(self) -> None:
        self._event = asyncio.Event()

    async def clear(self) -> None:
        self._event.clear()

    async def wait(self) -> None:
        await self._event.wait()

    async def set(self) -> None:
        self._event.set()


class TCPServer:
    def __init__(
        self,
        app: ASGIFramework,
        loop: asyncio.AbstractEventLoop,
        config: Config,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        self.app = app
        self.config = config
        self.loop = loop
        self.protocol: ProtocolWrapper
        self.reader = reader
        self.writer = writer
        self.send_lock = asyncio.Lock()
        self.timeout_lock = asyncio.Lock()

        self._keep_alive_timeout_handle: Optional[asyncio.Task] = None

    def __await__(self) -> Generator[Any, None, None]:
        return self.run().__await__()

    async def run(self) -> None:
        socket = self.writer.get_extra_info("socket")
        try:
            client = parse_socket_addr(socket.family, socket.getpeername())
            server = parse_socket_addr(socket.family, socket.getsockname())
            ssl_object = self.writer.get_extra_info("ssl_object")
            if ssl_object is not None:
                ssl = True
                alpn_protocol = ssl_object.selected_alpn_protocol()
            else:
                ssl = False
                alpn_protocol = "http/1.1"

            async with TaskGroup(self.loop) as task_group:
                context = Context(task_group)
                self.protocol = ProtocolWrapper(
                    self.app,
                    self.config,
                    cast(Any, context),
                    ssl,
                    client,
                    server,
                    self.protocol_send,
                    alpn_protocol,
                )
                await self.protocol.initiate()
                await self._update_keep_alive_timeout()
                await self._read_data()
        except OSError:
            pass
        finally:
            await self._close()

    async def protocol_send(self, event: Event) -> None:
        if isinstance(event, RawData):
            async with self.send_lock:
                try:
                    self.writer.write(event.data)
                    await self.writer.drain()
                except ConnectionError:
                    await self.protocol.handle(Closed())
        elif isinstance(event, Closed):
            await self._close()
            await self.protocol.handle(Closed())
        elif isinstance(event, Updated):
            pass  # Triggers the keep alive timeout update
        await self._update_keep_alive_timeout()

    async def _read_data(self) -> None:
        while True:
            try:
                data = await self.reader.read(MAX_RECV)
            except (
                ConnectionError,
                OSError,
                asyncio.TimeoutError,
                TimeoutError,
            ):
                await self.protocol.handle(Closed())
                break
            else:
                if data == b"":
                    await self._update_keep_alive_timeout()
                    break
                await self.protocol.handle(RawData(data))
                await self._update_keep_alive_timeout()

    async def _close(self) -> None:
        try:
            self.writer.write_eof()
        except (NotImplementedError, OSError, RuntimeError):
            pass  # Likely SSL connection

        try:
            self.writer.close()
            await self.writer.wait_closed()
        except (BrokenPipeError, ConnectionResetError):
            pass  # Already closed

    async def _update_keep_alive_timeout(self) -> None:
        async with self.timeout_lock:
            if self._keep_alive_timeout_handle is not None:
                self._keep_alive_timeout_handle.cancel()
                try:
                    await self._keep_alive_timeout_handle
                except asyncio.CancelledError:
                    pass

            self._keep_alive_timeout_handle = None
            if self.protocol.idle:
                self._keep_alive_timeout_handle = self.loop.create_task(
                    _call_later(self.config.keep_alive_timeout, self._timeout)
                )

    async def _timeout(self) -> None:
        await self.protocol.handle(Closed())
        self.writer.close()


async def _call_later(timeout: float, callback: Callable) -> None:
    await asyncio.sleep(timeout)
    await asyncio.shield(callback())
@@ -0,0 +1,55 @@
import asyncio
from typing import Any, cast, Optional, Tuple, TYPE_CHECKING

from .context import Context
from .task_group import TaskGroup
from ..config import Config
from ..events import Closed, Event, RawData
from ..typing import ASGIFramework
from ..utils import parse_socket_addr

if TYPE_CHECKING:
    # h3/Quic is an optional part of Hypercorn
    from ..protocol.quic import QuicProtocol  # noqa: F401


class UDPServer(asyncio.DatagramProtocol):
    def __init__(self, app: ASGIFramework, loop: asyncio.AbstractEventLoop, config: Config) -> None:
        self.app = app
        self.config = config
        self.loop = loop
        self.protocol: "QuicProtocol"
        self.protocol_queue: asyncio.Queue = asyncio.Queue(10)
        self.transport: Optional[asyncio.DatagramTransport] = None

        self.loop.create_task(self._consume_events())

    def connection_made(self, transport: asyncio.DatagramTransport) -> None:  # type: ignore
        # h3/Quic is an optional part of Hypercorn
        from ..protocol.quic import QuicProtocol  # noqa: F811

        self.transport = transport
        socket = self.transport.get_extra_info("socket")
        server = parse_socket_addr(socket.family, socket.getsockname())
        task_group = TaskGroup(self.loop)
        context = Context(task_group)
        self.protocol = QuicProtocol(
            self.app, self.config, cast(Any, context), server, self.protocol_send
        )

    def datagram_received(self, data: bytes, address: Tuple[bytes, str]) -> None:  # type: ignore
        try:
            self.protocol_queue.put_nowait(RawData(data=data, address=address))  # type: ignore
        except asyncio.QueueFull:
            pass  # Just throw the data away, is UDP

    async def protocol_send(self, event: Event) -> None:
        if isinstance(event, RawData):
            self.transport.sendto(event.data, event.address)

    async def _consume_events(self) -> None:
        while True:
            event = await self.protocol_queue.get()
            await self.protocol.handle(event)
            if isinstance(event, Closed):
                break
.venv/lib/python3.9/site-packages/hypercorn/config.py (new file, 373 lines)
@@ -0,0 +1,373 @@
import importlib
import importlib.util
import logging
import os
import socket
import ssl
import stat
import types
import warnings
from dataclasses import dataclass
from ssl import SSLContext, VerifyFlags, VerifyMode
from time import time
from typing import Any, AnyStr, Dict, List, Mapping, Optional, Tuple, Type, Union
from wsgiref.handlers import format_date_time

import toml

from .logging import Logger

BYTES = 1
OCTETS = 1
SECONDS = 1.0

FilePath = Union[AnyStr, os.PathLike]
SocketKind = Union[int, socket.SocketKind]


@dataclass
class Sockets:
    secure_sockets: List[socket.socket]
    insecure_sockets: List[socket.socket]
    quic_sockets: List[socket.socket]


class SocketTypeError(Exception):
    def __init__(self, expected: SocketKind, actual: SocketKind) -> None:
        super().__init__(
            f'Unexpected socket type, wanted "{socket.SocketKind(expected)}" got '
            f'"{socket.SocketKind(actual)}"'
        )


class Config:
    _bind = ["127.0.0.1:8000"]
    _insecure_bind: List[str] = []
    _quic_bind: List[str] = []
    _quic_addresses: List[Tuple] = []
    _log: Optional[Logger] = None

    access_log_format = '%(h)s %(l)s %(l)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'
    accesslog: Union[logging.Logger, str, None] = None
    alpn_protocols = ["h2", "http/1.1"]
    alt_svc_headers: List[str] = []
    application_path: str
    backlog = 100
    ca_certs: Optional[str] = None
    certfile: Optional[str] = None
    ciphers: str = "ECDHE+AESGCM"
    debug = False
    dogstatsd_tags = ""
    errorlog: Union[logging.Logger, str, None] = "-"
    graceful_timeout: float = 3 * SECONDS
    group: Optional[int] = None
    h11_max_incomplete_size = 16 * 1024 * BYTES
    h2_max_concurrent_streams = 100
    h2_max_header_list_size = 2 ** 16
    h2_max_inbound_frame_size = 2 ** 14 * OCTETS
    include_server_header = True
    keep_alive_timeout = 5 * SECONDS
    keyfile: Optional[str] = None
    logconfig: Optional[str] = None
    logconfig_dict: Optional[dict] = None
    logger_class = Logger
    loglevel: str = "INFO"
    max_app_queue_size: int = 10
    pid_path: Optional[str] = None
    root_path = ""
    server_names: List[str] = []
    shutdown_timeout = 60 * SECONDS
    ssl_handshake_timeout = 60 * SECONDS
    startup_timeout = 60 * SECONDS
    statsd_host: Optional[str] = None
    statsd_prefix = ""
    umask: Optional[int] = None
    use_reloader = False
    user: Optional[int] = None
    verify_flags: Optional[VerifyFlags] = None
    verify_mode: Optional[VerifyMode] = None
    websocket_max_message_size = 16 * 1024 * 1024 * BYTES
    websocket_ping_interval: Optional[int] = None
    worker_class = "asyncio"
    workers = 1

    def set_cert_reqs(self, value: int) -> None:
        warnings.warn("Please use verify_mode instead", Warning)
        self.verify_mode = VerifyMode(value)

    cert_reqs = property(None, set_cert_reqs)

    @property
    def log(self) -> Logger:
        if self._log is None:
            self._log = self.logger_class(self)
        return self._log

    @property
    def bind(self) -> List[str]:
        return self._bind

    @bind.setter
    def bind(self, value: Union[List[str], str]) -> None:
        if isinstance(value, str):
            self._bind = [value]
        else:
            self._bind = value

    @property
    def insecure_bind(self) -> List[str]:
        return self._insecure_bind

    @insecure_bind.setter
    def insecure_bind(self, value: Union[List[str], str]) -> None:
        if isinstance(value, str):
            self._insecure_bind = [value]
        else:
            self._insecure_bind = value

    @property
    def quic_bind(self) -> List[str]:
        return self._quic_bind

    @quic_bind.setter
    def quic_bind(self, value: Union[List[str], str]) -> None:
        if isinstance(value, str):
            self._quic_bind = [value]
        else:
            self._quic_bind = value

    def create_ssl_context(self) -> Optional[SSLContext]:
        if not self.ssl_enabled:
            return None

        context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        context.set_ciphers(self.ciphers)
        cipher_opts = 0
        for attr in ["OP_NO_SSLv2", "OP_NO_SSLv3", "OP_NO_TLSv1", "OP_NO_TLSv1_1"]:
            if hasattr(ssl, attr):  # To be future proof
                cipher_opts |= getattr(ssl, attr)
        context.options |= cipher_opts  # RFC 7540 Section 9.2: MUST be TLS >=1.2
        context.options |= ssl.OP_NO_COMPRESSION  # RFC 7540 Section 9.2.1: MUST disable compression
        context.set_alpn_protocols(self.alpn_protocols)

        if self.certfile is not None and self.keyfile is not None:
            context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)

        if self.ca_certs is not None:
            context.load_verify_locations(self.ca_certs)
        if self.verify_mode is not None:
            context.verify_mode = self.verify_mode
        if self.verify_flags is not None:
            context.verify_flags = self.verify_flags

        return context

    @property
    def ssl_enabled(self) -> bool:
        return self.certfile is not None and self.keyfile is not None

    def create_sockets(self) -> Sockets:
        if self.ssl_enabled:
            secure_sockets = self._create_sockets(self.bind)
            insecure_sockets = self._create_sockets(self.insecure_bind)
            quic_sockets = self._create_sockets(self.quic_bind, socket.SOCK_DGRAM)
            self._set_quic_addresses(quic_sockets)
        else:
            secure_sockets = []
            insecure_sockets = self._create_sockets(self.bind)
            quic_sockets = []
        return Sockets(secure_sockets, insecure_sockets, quic_sockets)

    def _set_quic_addresses(self, sockets: List[socket.socket]) -> None:
        self._quic_addresses = []
        for sock in sockets:
            name = sock.getsockname()
            if type(name) is not str and len(name) >= 2:
                self._quic_addresses.append(name)
            else:
                warnings.warn(
                    f'Cannot create a alt-svc header for the QUIC socket with address "{name}"',
                    Warning,
                )

    def _create_sockets(
        self, binds: List[str], type_: int = socket.SOCK_STREAM
    ) -> List[socket.socket]:
        sockets: List[socket.socket] = []
        for bind in binds:
            binding: Any = None
            if bind.startswith("unix:"):
                sock = socket.socket(socket.AF_UNIX, type_)
                binding = bind[5:]
                try:
                    if stat.S_ISSOCK(os.stat(binding).st_mode):
                        os.remove(binding)
                except FileNotFoundError:
                    pass
            elif bind.startswith("fd://"):
                sock = socket.socket(fileno=int(bind[5:]))
                actual_type = sock.getsockopt(socket.SOL_SOCKET, socket.SO_TYPE)
                if actual_type != type_:
                    raise SocketTypeError(type_, actual_type)
            else:
                bind = bind.replace("[", "").replace("]", "")
                try:
                    value = bind.rsplit(":", 1)
                    host, port = value[0], int(value[1])
                except (ValueError, IndexError):
                    host, port = bind, 8000
                sock = socket.socket(socket.AF_INET6 if ":" in host else socket.AF_INET, type_)
                if self.workers > 1:
                    try:
                        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
                    except AttributeError:
                        pass
                binding = (host, port)

            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

            if bind.startswith("unix:"):
                if self.umask is not None:
                    current_umask = os.umask(self.umask)
                sock.bind(binding)
                if self.user is not None and self.group is not None:
                    os.chown(binding, self.user, self.group)
                if self.umask is not None:
                    os.umask(current_umask)
            elif bind.startswith("fd://"):
                pass
            else:
                sock.bind(binding)

            sock.setblocking(False)
            try:
                sock.set_inheritable(True)
            except AttributeError:
                pass
            sockets.append(sock)
        return sockets

    def response_headers(self, protocol: str) -> List[Tuple[bytes, bytes]]:
        headers = [(b"date", format_date_time(time()).encode("ascii"))]
        if self.include_server_header:
            headers.append((b"server", f"hypercorn-{protocol}".encode("ascii")))

        for alt_svc_header in self.alt_svc_headers:
            headers.append((b"alt-svc", alt_svc_header.encode()))
        if len(self.alt_svc_headers) == 0 and self._quic_addresses:
            from aioquic.h3.connection import H3_ALPN

            for version in H3_ALPN:
                for addr in self._quic_addresses:
                    port = addr[1]
                    headers.append((b"alt-svc", b'%s=":%d"; ma=3600' % (version.encode(), port)))

        return headers

    def set_statsd_logger_class(self, statsd_logger: Type[Logger]) -> None:
        if self.logger_class == Logger and self.statsd_host is not None:
            self.logger_class = statsd_logger

    @classmethod
    def from_mapping(
        cls: Type["Config"], mapping: Optional[Mapping[str, Any]] = None, **kwargs: Any
    ) -> "Config":
        """Create a configuration from a mapping.

        This allows either a mapping to be directly passed or as
        keyword arguments, for example,

        .. code-block:: python

            config = {'keep_alive_timeout': 10}
            Config.from_mapping(config)
            Config.from_mapping(keep_alive_timeout=10)

        Arguments:
            mapping: Optionally a mapping object.
            kwargs: Optionally a collection of keyword arguments to
                form a mapping.
        """
        mappings: Dict[str, Any] = {}
        if mapping is not None:
            mappings.update(mapping)
        mappings.update(kwargs)
        config = cls()
        for key, value in mappings.items():
            try:
                setattr(config, key, value)
            except AttributeError:
                pass

        return config

    @classmethod
    def from_pyfile(cls: Type["Config"], filename: FilePath) -> "Config":
        """Create a configuration from a Python file.

        .. code-block:: python

            Config.from_pyfile('hypercorn_config.py')

        Arguments:
            filename: The filename which gives the path to the file.
        """
        file_path = os.fspath(filename)
        spec = importlib.util.spec_from_file_location("module.name", file_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)  # type: ignore
        return cls.from_object(module)

    @classmethod
    def from_toml(cls: Type["Config"], filename: FilePath) -> "Config":
        """Load the configuration values from a TOML formatted file.

        This allows configuration to be loaded as so

        .. code-block:: python

            Config.from_toml('config.toml')

        Arguments:
            filename: The filename which gives the path to the file.
        """
        file_path = os.fspath(filename)
        with open(file_path) as file_:
            data = toml.load(file_)
        return cls.from_mapping(data)

    @classmethod
    def from_object(cls: Type["Config"], instance: Union[object, str]) -> "Config":
        """Create a configuration from a Python object.

        This can be used to reference modules or objects within
        modules for example,

        .. code-block:: python

            Config.from_object('module')
            Config.from_object('module.instance')
            from module import instance
            Config.from_object(instance)

        are valid.

        Arguments:
            instance: Either a str referencing a python object or the
                object itself.

        """
        if isinstance(instance, str):
            try:
                instance = importlib.import_module(instance)
            except ImportError:
                path, config = instance.rsplit(".", 1)  # type: ignore
                module = importlib.import_module(path)
                instance = getattr(module, config)

        mapping = {
            key: getattr(instance, key)
            for key in dir(instance)
            if not isinstance(getattr(instance, key), types.ModuleType)
        }
        return cls.from_mapping(mapping)
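`Config` above can be populated from attributes, mappings, TOML files, Python files, or Python objects. A brief sketch of the mapping and TOML routes; the file name `hypercorn.toml` is illustrative:

from hypercorn.config import Config

config = Config.from_mapping(bind=["0.0.0.0:8000"], workers=2, keep_alive_timeout=10)

# The same settings in a TOML file, loaded with Config.from_toml("hypercorn.toml"):
#   bind = ["0.0.0.0:8000"]
#   workers = 2
#   keep_alive_timeout = 10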
.venv/lib/python3.9/site-packages/hypercorn/events.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from abc import ABC
from dataclasses import dataclass
from typing import Optional, Tuple


class Event(ABC):
    pass


@dataclass(frozen=True)
class RawData(Event):
    data: bytes
    address: Optional[Tuple[str, int]] = None


@dataclass(frozen=True)
class Closed(Event):
    pass


@dataclass(frozen=True)
class Updated(Event):
    # Indicate that the protocol has updated (although it has nothing
    # for the server to do).
    pass
.venv/lib/python3.9/site-packages/hypercorn/logging.py (new file, 182 lines)
@@ -0,0 +1,182 @@
import logging
import os
import sys
import time
from http import HTTPStatus
from logging.config import dictConfig, fileConfig
from typing import Any, IO, Mapping, Optional, TYPE_CHECKING, Union

if TYPE_CHECKING:
    from .config import Config
    from .typing import ResponseSummary, WWWScope


def _create_logger(
    name: str,
    target: Union[logging.Logger, str, None],
    level: Optional[str],
    sys_default: IO,
    *,
    propagate: bool = True,
) -> Optional[logging.Logger]:
    if isinstance(target, logging.Logger):
        return target

    if target:
        logger = logging.getLogger(name)
        logger.handlers = [
            logging.StreamHandler(sys_default) if target == "-" else logging.FileHandler(target)
        ]
        logger.propagate = propagate
        formatter = logging.Formatter(
            "%(asctime)s [%(process)d] [%(levelname)s] %(message)s",
            "[%Y-%m-%d %H:%M:%S %z]",
        )
        logger.handlers[0].setFormatter(formatter)
        if level is not None:
            logger.setLevel(logging.getLevelName(level.upper()))
        return logger
    else:
        return None


class Logger:
    def __init__(self, config: "Config") -> None:
        self.access_log_format = config.access_log_format

        self.access_logger = _create_logger(
            "hypercorn.access",
            config.accesslog,
            config.loglevel,
            sys.stdout,
            propagate=False,
        )
        self.error_logger = _create_logger(
            "hypercorn.error", config.errorlog, config.loglevel, sys.stderr
        )

        if config.logconfig is not None:
            log_config = {
                "__file__": config.logconfig,
                "here": os.path.dirname(config.logconfig),
            }
            fileConfig(config.logconfig, defaults=log_config, disable_existing_loggers=False)
        else:
            if config.logconfig_dict is not None:
                dictConfig(config.logconfig_dict)

    async def access(
        self, request: "WWWScope", response: "ResponseSummary", request_time: float
    ) -> None:
        if self.access_logger is not None:
            self.access_logger.info(
                self.access_log_format, self.atoms(request, response, request_time)
            )

    async def critical(self, message: str, *args: Any, **kwargs: Any) -> None:
        if self.error_logger is not None:
            self.error_logger.critical(message, *args, **kwargs)

    async def error(self, message: str, *args: Any, **kwargs: Any) -> None:
        if self.error_logger is not None:
            self.error_logger.error(message, *args, **kwargs)

    async def warning(self, message: str, *args: Any, **kwargs: Any) -> None:
        if self.error_logger is not None:
            self.error_logger.warning(message, *args, **kwargs)

    async def info(self, message: str, *args: Any, **kwargs: Any) -> None:
        if self.error_logger is not None:
            self.error_logger.info(message, *args, **kwargs)

    async def debug(self, message: str, *args: Any, **kwargs: Any) -> None:
        if self.error_logger is not None:
            self.error_logger.debug(message, *args, **kwargs)

    async def exception(self, message: str, *args: Any, **kwargs: Any) -> None:
        if self.error_logger is not None:
            self.error_logger.exception(message, *args, **kwargs)

    async def log(self, level: int, message: str, *args: Any, **kwargs: Any) -> None:
        if self.error_logger is not None:
            self.error_logger.log(level, message, *args, **kwargs)

    def atoms(
        self, request: "WWWScope", response: "ResponseSummary", request_time: float
    ) -> Mapping[str, str]:
        """Create and return an access log atoms dictionary.

        This can be overidden and customised if desired. It should
        return a mapping between an access log format key and a value.
        """
        return AccessLogAtoms(request, response, request_time)

    def __getattr__(self, name: str) -> Any:
        return getattr(self.error_logger, name)


class AccessLogAtoms(dict):
    def __init__(
        self, request: "WWWScope", response: "ResponseSummary", request_time: float
    ) -> None:
        for name, value in request["headers"]:
            self[f"{{{name.decode('latin1').lower()}}}i"] = value.decode("latin1")
        for name, value in response.get("headers", []):
            self[f"{{{name.decode('latin1').lower()}}}o"] = value.decode("latin1")
        for name, value in os.environ.items():
            self[f"{{{name.lower()}}}e"] = value
        protocol = request.get("http_version", "ws")
        client = request.get("client")
        if client is None:
            remote_addr = None
        elif len(client) == 2:
            remote_addr = f"{client[0]}:{client[1]}"
        elif len(client) == 1:
            remote_addr = client[0]
        else:  # make sure not to throw UnboundLocalError
            remote_addr = f"<???{client}???>"
        if request["type"] == "http":
            method = request["method"]
        else:
            method = "GET"
        query_string = request["query_string"].decode()
        path_with_qs = request["path"] + ("?" + query_string if query_string else "")
        status_code = response["status"]
        try:
            status_phrase = HTTPStatus(status_code).phrase
        except ValueError:
            status_phrase = f"<???{status_code}???>"
        self.update(
            {
                "h": remote_addr,
                "l": "-",
                "t": time.strftime("[%d/%b/%Y:%H:%M:%S %z]"),
                "r": f"{method} {request['path']} {protocol}",
                "R": f"{method} {path_with_qs} {protocol}",
                "s": response["status"],
                "st": status_phrase,
                "S": request["scheme"],
                "m": method,
                "U": request["path"],
                "Uq": path_with_qs,
                "q": query_string,
                "H": protocol,
                "b": self["{Content-Length}o"],
                "B": self["{Content-Length}o"],
                "f": self["{Referer}i"],
                "a": self["{User-Agent}i"],
                "T": int(request_time),
                "D": int(request_time * 1_000_000),
                "L": f"{request_time:.6f}",
                "p": f"<{os.getpid()}>",
            }
        )

    def __getitem__(self, key: str) -> str:
        try:
            if key.startswith("{"):
                return super().__getitem__(key.lower())
            else:
                return super().__getitem__(key)
        except KeyError:
            return "-"
@@ -0,0 +1,10 @@
from .dispatcher import DispatcherMiddleware
from .http_to_https import HTTPToHTTPSRedirectMiddleware
from .wsgi import AsyncioWSGIMiddleware, TrioWSGIMiddleware

__all__ = (
    "AsyncioWSGIMiddleware",
    "DispatcherMiddleware",
    "HTTPToHTTPSRedirectMiddleware",
    "TrioWSGIMiddleware",
)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,107 @@
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from typing import Callable, Dict
|
||||
|
||||
from ..asyncio.task_group import TaskGroup
|
||||
from ..typing import ASGIFramework, Scope
|
||||
from ..utils import invoke_asgi
|
||||
|
||||
MAX_QUEUE_SIZE = 10
|
||||
|
||||
|
||||
class _DispatcherMiddleware:
|
||||
def __init__(self, mounts: Dict[str, ASGIFramework]) -> None:
|
||||
self.mounts = mounts
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None:
|
||||
if scope["type"] == "lifespan":
|
||||
await self._handle_lifespan(scope, receive, send)
|
||||
else:
|
||||
for path, app in self.mounts.items():
|
||||
if scope["path"].startswith(path):
|
||||
scope["path"] = scope["path"][len(path) :] or "/" # type: ignore
|
||||
return await invoke_asgi(app, scope, receive, send)
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 404,
|
||||
"headers": [(b"content-length", b"0")],
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body"})
|
||||
|
||||
async def _handle_lifespan(self, scope: Scope, receive: Callable, send: Callable) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncioDispatcherMiddleware(_DispatcherMiddleware):
|
||||
async def _handle_lifespan(self, scope: Scope, receive: Callable, send: Callable) -> None:
|
||||
self.app_queues: Dict[str, asyncio.Queue] = {
|
||||
path: asyncio.Queue(MAX_QUEUE_SIZE) for path in self.mounts
|
||||
}
|
||||
self.startup_complete = {path: False for path in self.mounts}
|
||||
self.shutdown_complete = {path: False for path in self.mounts}
|
||||
|
||||
async with TaskGroup(asyncio.get_event_loop()) as task_group:
|
||||
for path, app in self.mounts.items():
|
||||
task_group.spawn(
|
||||
invoke_asgi(
|
||||
app, scope, self.app_queues[path].get, partial(self.send, path, send)
|
||||
)
|
||||
)
|
||||
|
||||
while True:
|
||||
message = await receive()
|
||||
for queue in self.app_queues.values():
|
||||
await queue.put(message)
|
||||
if message["type"] == "lifespan.shutdown":
|
||||
break
|
||||
|
||||
async def send(self, path: str, send: Callable, message: dict) -> None:
|
||||
if message["type"] == "lifespan.startup.complete":
|
||||
self.startup_complete[path] = True
|
||||
if all(self.startup_complete.values()):
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
elif message["type"] == "lifespan.shutdown.complete":
|
||||
self.shutdown_complete[path] = True
|
||||
if all(self.shutdown_complete.values()):
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
|
||||
|
||||
class TrioDispatcherMiddleware(_DispatcherMiddleware):
|
||||
async def _handle_lifespan(self, scope: Scope, receive: Callable, send: Callable) -> None:
|
||||
import trio
|
||||
|
||||
self.app_queues = {path: trio.open_memory_channel(MAX_QUEUE_SIZE) for path in self.mounts}
|
||||
self.startup_complete = {path: False for path in self.mounts}
|
||||
self.shutdown_complete = {path: False for path in self.mounts}
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for path, app in self.mounts.items():
|
||||
nursery.start_soon(
|
||||
invoke_asgi,
|
||||
app,
|
||||
scope,
|
||||
self.app_queues[path][1].receive,
|
||||
partial(self.send, path, send),
|
||||
)
|
||||
|
||||
while True:
|
||||
message = await receive()
|
||||
for channels in self.app_queues.values():
|
||||
await channels[0].send(message)
|
||||
if message["type"] == "lifespan.shutdown":
|
||||
break
|
||||
|
||||
async def send(self, path: str, send: Callable, message: dict) -> None:
|
||||
if message["type"] == "lifespan.startup.complete":
|
||||
self.startup_complete[path] = True
|
||||
if all(self.startup_complete.values()):
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
elif message["type"] == "lifespan.shutdown.complete":
|
||||
self.shutdown_complete[path] = True
|
||||
if all(self.shutdown_complete.values()):
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
|
||||
|
||||
DispatcherMiddleware = AsyncioDispatcherMiddleware # Remove with version 0.11
|
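A usage sketch for the dispatcher defined above; api_app and fallback_app are placeholder ASGI callables, and the empty-prefix catch-all is just one way to mount a default application:

from hypercorn.middleware import DispatcherMiddleware


async def api_app(scope, receive, send):
    # Placeholder ASGI app mounted under /api.
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"api"})


async def fallback_app(scope, receive, send):
    # Placeholder ASGI app that receives everything else.
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"root"})


# Mounts are matched in insertion order, so list more specific prefixes
# first; without a catch-all, unmatched requests get the 404 response above.
dispatched_app = DispatcherMiddleware({"/api": api_app, "": fallback_app})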
@@ -0,0 +1,66 @@
|
||||
from typing import Callable, Optional
|
||||
from urllib.parse import urlunsplit
|
||||
|
||||
from ..typing import ASGIFramework, HTTPScope, Scope, WebsocketScope, WWWScope
|
||||
from ..utils import invoke_asgi
|
||||
|
||||
|
||||
class HTTPToHTTPSRedirectMiddleware:
|
||||
def __init__(self, app: ASGIFramework, host: Optional[str]) -> None:
|
||||
self.app = app
|
||||
self.host = host
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None:
|
||||
if scope["type"] == "http" and scope["scheme"] == "http":
|
||||
await self._send_http_redirect(scope, send)
|
||||
elif scope["type"] == "websocket" and scope["scheme"] == "ws":
|
||||
# If the server supports the WebSocket Denial Response
|
||||
# extension we can send a redirection response, if not we
|
||||
# can only deny the WebSocket connection.
|
||||
if "websocket.http.response" in scope.get("extensions", {}):
|
||||
await self._send_websocket_redirect(scope, send)
|
||||
else:
|
||||
await send({"type": "websocket.close"})
|
||||
else:
|
||||
return await invoke_asgi(self.app, scope, receive, send)
|
||||
|
||||
async def _send_http_redirect(self, scope: HTTPScope, send: Callable) -> None:
|
||||
new_url = self._new_url("https", scope)
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 307,
|
||||
"headers": [(b"location", new_url.encode())],
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body"})
|
||||
|
||||
async def _send_websocket_redirect(self, scope: WebsocketScope, send: Callable) -> None:
|
||||
# If the HTTP version is 2 we should redirect with a https
|
||||
# scheme not wss.
|
||||
scheme = "wss"
|
||||
if scope.get("http_version", "1.1") == "2":
|
||||
scheme = "https"
|
||||
|
||||
new_url = self._new_url(scheme, scope)
|
||||
await send(
|
||||
{
|
||||
"type": "websocket.http.response.start",
|
||||
"status": 307,
|
||||
"headers": [(b"location", new_url.encode())],
|
||||
}
|
||||
)
|
||||
await send({"type": "websocket.http.response.body"})
|
||||
|
||||
def _new_url(self, scheme: str, scope: WWWScope) -> str:
|
||||
host = self.host
|
||||
if host is None:
|
||||
for key, value in scope["headers"]:
|
||||
if key == b"host":
|
||||
host = value.decode("latin-1")
|
||||
break
|
||||
if host is None:
|
||||
raise ValueError("Host to redirect to cannot be determined")
|
||||
|
||||
path = scope.get("root_path", "") + scope["raw_path"].decode()
|
||||
return urlunsplit((scheme, host, path, scope["query_string"].decode(), ""))
|
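A usage sketch for the redirect middleware above; inner_app is a placeholder ASGI application and the host value is illustrative:

from hypercorn.middleware import HTTPToHTTPSRedirectMiddleware


async def inner_app(scope, receive, send):
    # Placeholder ASGI app served only over the secure scheme.
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"secure"})


# Plain http (and ws) requests are answered with a 307 redirect to the
# https/wss equivalent on the given host; secure traffic passes through.
secure_app = HTTPToHTTPSRedirectMiddleware(inner_app, host="example.com")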
132
.venv/lib/python3.9/site-packages/hypercorn/middleware/wsgi.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from typing import Callable, Iterable, List, Optional, Tuple
|
||||
|
||||
from ..typing import HTTPScope, Scope
|
||||
|
||||
MAX_BODY_SIZE = 2 ** 16
|
||||
|
||||
WSGICallable = Callable[[dict, Callable], Iterable[bytes]]
|
||||
|
||||
|
||||
class _WSGIMiddleware:
|
||||
def __init__(self, wsgi_app: WSGICallable, max_body_size: int = MAX_BODY_SIZE) -> None:
|
||||
self.wsgi_app = wsgi_app
|
||||
self.max_body_size = max_body_size
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None:
|
||||
if scope["type"] == "http":
|
||||
status_code, headers, body = await self._handle_http(scope, receive, send)
|
||||
await send({"type": "http.response.start", "status": status_code, "headers": headers})
|
||||
await send({"type": "http.response.body", "body": body})
|
||||
elif scope["type"] == "websocket":
|
||||
await send({"type": "websocket.close"})
|
||||
elif scope["type"] == "lifespan":
|
||||
return
|
||||
else:
|
||||
raise Exception(f"Unknown scope type, {scope['type']}")
|
||||
|
||||
async def _handle_http(
|
||||
self, scope: HTTPScope, receive: Callable, send: Callable
|
||||
) -> Tuple[int, list, bytes]:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncioWSGIMiddleware(_WSGIMiddleware):
|
||||
async def _handle_http(
|
||||
self, scope: HTTPScope, receive: Callable, send: Callable
|
||||
) -> Tuple[int, list, bytes]:
|
||||
loop = asyncio.get_event_loop()
|
||||
instance = _WSGIInstance(self.wsgi_app, self.max_body_size)
|
||||
return await instance.handle_http(scope, receive, partial(loop.run_in_executor, None))
|
||||
|
||||
|
||||
class TrioWSGIMiddleware(_WSGIMiddleware):
|
||||
async def _handle_http(
|
||||
self, scope: HTTPScope, receive: Callable, send: Callable
|
||||
) -> Tuple[int, list, bytes]:
|
||||
import trio
|
||||
|
||||
instance = _WSGIInstance(self.wsgi_app, self.max_body_size)
|
||||
return await instance.handle_http(scope, receive, trio.to_thread.run_sync)
|
||||
|
||||
|
||||
class _WSGIInstance:
|
||||
def __init__(self, wsgi_app: WSGICallable, max_body_size: int = MAX_BODY_SIZE) -> None:
|
||||
self.wsgi_app = wsgi_app
|
||||
self.max_body_size = max_body_size
|
||||
self.status_code = 500
|
||||
self.headers: list = []
|
||||
|
||||
async def handle_http(
|
||||
self, scope: HTTPScope, receive: Callable, spawn: Callable
|
||||
) -> Tuple[int, list, bytes]:
|
||||
self.scope = scope
|
||||
body = bytearray()
|
||||
while True:
|
||||
message = await receive()
|
||||
body.extend(message.get("body", b""))
|
||||
if len(body) > self.max_body_size:
|
||||
return 400, [], b""
|
||||
if not message.get("more_body"):
|
||||
break
|
||||
return await spawn(self.run_wsgi_app, body)
|
||||
|
||||
def _start_response(
|
||||
self,
|
||||
status: str,
|
||||
response_headers: List[Tuple[str, str]],
|
||||
exc_info: Optional[Exception] = None,
|
||||
) -> None:
|
||||
raw, _ = status.split(" ", 1)
|
||||
self.status_code = int(raw)
|
||||
self.headers = [
|
||||
(name.lower().encode("ascii"), value.encode("ascii"))
|
||||
for name, value in response_headers
|
||||
]
|
||||
|
||||
def run_wsgi_app(self, body: bytes) -> Tuple[int, list, bytes]:
|
||||
environ = _build_environ(self.scope, body)
|
||||
body = bytearray()
|
||||
for output in self.wsgi_app(environ, self._start_response):
|
||||
body.extend(output)
|
||||
return self.status_code, self.headers, body
|
||||
|
||||
|
||||
def _build_environ(scope: HTTPScope, body: bytes) -> dict:
|
||||
server = scope.get("server") or ("localhost", 80)
|
||||
environ = {
|
||||
"REQUEST_METHOD": scope["method"],
|
||||
"SCRIPT_NAME": scope.get("root_path", "").encode("utf8").decode("latin1"),
|
||||
"PATH_INFO": scope["path"].encode("utf8").decode("latin1"),
|
||||
"QUERY_STRING": scope["query_string"].decode("ascii"),
|
||||
"SERVER_NAME": server[0],
|
||||
"SERVER_PORT": server[1],
|
||||
"SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"],
|
||||
"wsgi.version": (1, 0),
|
||||
"wsgi.url_scheme": scope.get("scheme", "http"),
|
||||
"wsgi.input": BytesIO(body),
|
||||
"wsgi.errors": BytesIO(),
|
||||
"wsgi.multithread": True,
|
||||
"wsgi.multiprocess": True,
|
||||
"wsgi.run_once": False,
|
||||
}
|
||||
|
||||
if "client" in scope:
|
||||
environ["REMOTE_ADDR"] = scope["client"][0]
|
||||
|
||||
for raw_name, raw_value in scope.get("headers", []):
|
||||
name = raw_name.decode("latin1")
|
||||
if name == "content-length":
|
||||
corrected_name = "CONTENT_LENGTH"
|
||||
elif name == "content-type":
|
||||
corrected_name = "CONTENT_TYPE"
|
||||
else:
|
||||
corrected_name = "HTTP_%s" % name.upper().replace("-", "_")
|
||||
# HTTPbis says only ASCII chars are allowed in headers, but decode as latin1 just in case
|
||||
value = raw_value.decode("latin1")
|
||||
if corrected_name in environ:
|
||||
value = environ[corrected_name] + "," + value # type: ignore
|
||||
environ[corrected_name] = value
|
||||
return environ
|
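A sketch of wrapping a plain WSGI callable with the asyncio variant above (the Trio variant is used the same way); the WSGI app here is a placeholder:

from hypercorn.middleware import AsyncioWSGIMiddleware


def wsgi_app(environ, start_response):
    # Placeholder WSGI app; `environ` is produced by _build_environ above.
    start_response("200 OK", [("Content-Type", "text/plain")])
    return [b"hello from WSGI"]


# Request bodies larger than max_body_size are rejected with a 400 response.
asgi_app = AsyncioWSGIMiddleware(wsgi_app, max_body_size=2 ** 16)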
@@ -0,0 +1,86 @@
|
||||
from typing import Awaitable, Callable, Optional, Tuple, Union
|
||||
|
||||
from .h2 import H2Protocol
|
||||
from .h11 import H2CProtocolRequired, H2ProtocolAssumed, H11Protocol
|
||||
from ..config import Config
|
||||
from ..events import Event, RawData
|
||||
from ..typing import ASGIFramework, Context
|
||||
|
||||
|
||||
class ProtocolWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIFramework,
|
||||
config: Config,
|
||||
context: Context,
|
||||
ssl: bool,
|
||||
client: Optional[Tuple[str, int]],
|
||||
server: Optional[Tuple[str, int]],
|
||||
send: Callable[[Event], Awaitable[None]],
|
||||
alpn_protocol: Optional[str] = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.config = config
|
||||
self.context = context
|
||||
self.ssl = ssl
|
||||
self.client = client
|
||||
self.server = server
|
||||
self.send = send
|
||||
self.protocol: Union[H11Protocol, H2Protocol]
|
||||
if alpn_protocol == "h2":
|
||||
self.protocol = H2Protocol(
|
||||
self.app,
|
||||
self.config,
|
||||
self.context,
|
||||
self.ssl,
|
||||
self.client,
|
||||
self.server,
|
||||
self.send,
|
||||
)
|
||||
else:
|
||||
self.protocol = H11Protocol(
|
||||
self.app,
|
||||
self.config,
|
||||
self.context,
|
||||
self.ssl,
|
||||
self.client,
|
||||
self.server,
|
||||
self.send,
|
||||
)
|
||||
|
||||
@property
|
||||
def idle(self) -> bool:
|
||||
return self.protocol.idle
|
||||
|
||||
async def initiate(self) -> None:
|
||||
return await self.protocol.initiate()
|
||||
|
||||
async def handle(self, event: Event) -> None:
|
||||
try:
|
||||
return await self.protocol.handle(event)
|
||||
except H2ProtocolAssumed as error:
|
||||
self.protocol = H2Protocol(
|
||||
self.app,
|
||||
self.config,
|
||||
self.context,
|
||||
self.ssl,
|
||||
self.client,
|
||||
self.server,
|
||||
self.send,
|
||||
)
|
||||
await self.protocol.initiate()
|
||||
if error.data != b"":
|
||||
return await self.protocol.handle(RawData(data=error.data))
|
||||
except H2CProtocolRequired as error:
|
||||
self.protocol = H2Protocol(
|
||||
self.app,
|
||||
self.config,
|
||||
self.context,
|
||||
self.ssl,
|
||||
self.client,
|
||||
self.server,
|
||||
self.send,
|
||||
)
|
||||
await self.protocol.initiate(error.headers, error.settings)
|
||||
if error.data != b"":
|
||||
return await self.protocol.handle(RawData(data=error.data))
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,46 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Event:
|
||||
stream_id: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Request(Event):
|
||||
headers: List[Tuple[bytes, bytes]]
|
||||
http_version: str
|
||||
method: str
|
||||
raw_path: bytes
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Body(Event):
|
||||
data: bytes
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EndBody(Event):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Data(Event):
|
||||
data: bytes
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EndData(Event):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Response(Event):
|
||||
headers: List[Tuple[bytes, bytes]]
|
||||
status_code: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StreamClosed(Event):
|
||||
pass
|
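For illustration, the frozen dataclasses above are constructed with keyword arguments; a body-less GET on stream 1 might look like:

from hypercorn.protocol.events import EndBody, Request

request = Request(
    stream_id=1,
    headers=[(b"host", b"example.com")],
    http_version="1.1",
    method="GET",
    raw_path=b"/item?id=2",
)
end_of_body = EndBody(stream_id=1)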
289
.venv/lib/python3.9/site-packages/hypercorn/protocol/h11.py
Normal file
@@ -0,0 +1,289 @@
|
||||
from itertools import chain
|
||||
from typing import Awaitable, Callable, Optional, Tuple, Union
|
||||
|
||||
import h11
|
||||
|
||||
from .events import (
|
||||
Body,
|
||||
Data,
|
||||
EndBody,
|
||||
EndData,
|
||||
Event as StreamEvent,
|
||||
Request,
|
||||
Response,
|
||||
StreamClosed,
|
||||
)
|
||||
from .http_stream import HTTPStream
|
||||
from .ws_stream import WSStream
|
||||
from ..config import Config
|
||||
from ..events import Closed, Event, RawData, Updated
|
||||
from ..typing import ASGIFramework, Context, H11SendableEvent
|
||||
|
||||
STREAM_ID = 1
|
||||
|
||||
|
||||
class H2CProtocolRequired(Exception):
|
||||
def __init__(self, data: bytes, request: h11.Request) -> None:
|
||||
settings = ""
|
||||
headers = [(b":method", request.method), (b":path", request.target)]
|
||||
for name, value in request.headers:
|
||||
if name.lower() == b"http2-settings":
|
||||
settings = value.decode()
|
||||
elif name.lower() == b"host":
|
||||
headers.append((b":authority", value))
|
||||
headers.append((name, value))
|
||||
|
||||
self.data = data
|
||||
self.headers = headers
|
||||
self.settings = settings
|
||||
|
||||
|
||||
class H2ProtocolAssumed(Exception):
|
||||
def __init__(self, data: bytes) -> None:
|
||||
self.data = data
|
||||
|
||||
|
||||
class H11WSConnection:
|
||||
# This class matches the h11 interface, and either passes data
|
||||
# through without altering it (for Data, EndData) or sends h11
|
||||
# events (Response, Body, EndBody).
|
||||
our_state = None # Prevents recycling the connection
|
||||
they_are_waiting_for_100_continue = False
|
||||
|
||||
def __init__(self, h11_connection: h11.Connection) -> None:
|
||||
self.buffer = bytearray(h11_connection.trailing_data[0])
|
||||
self.h11_connection = h11_connection
|
||||
|
||||
def receive_data(self, data: bytes) -> None:
|
||||
self.buffer.extend(data)
|
||||
|
||||
def next_event(self) -> Data:
|
||||
if self.buffer:
|
||||
event = Data(stream_id=STREAM_ID, data=bytes(self.buffer))
|
||||
self.buffer = bytearray()
|
||||
return event
|
||||
else:
|
||||
return h11.NEED_DATA
|
||||
|
||||
def send(self, event: H11SendableEvent) -> bytes:
|
||||
return self.h11_connection.send(event)
|
||||
|
||||
|
||||
class H11Protocol:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIFramework,
|
||||
config: Config,
|
||||
context: Context,
|
||||
ssl: bool,
|
||||
client: Optional[Tuple[str, int]],
|
||||
server: Optional[Tuple[str, int]],
|
||||
send: Callable[[Event], Awaitable[None]],
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.can_read = context.event_class()
|
||||
self.client = client
|
||||
self.config = config
|
||||
self.connection = h11.Connection(
|
||||
h11.SERVER, max_incomplete_event_size=self.config.h11_max_incomplete_size
|
||||
)
|
||||
self.context = context
|
||||
self.send = send
|
||||
self.server = server
|
||||
self.ssl = ssl
|
||||
self.stream: Optional[Union[HTTPStream, WSStream]] = None
|
||||
|
||||
@property
|
||||
def idle(self) -> bool:
|
||||
return self.stream is None or self.stream.idle
|
||||
|
||||
async def initiate(self) -> None:
|
||||
pass
|
||||
|
||||
async def handle(self, event: Event) -> None:
|
||||
if isinstance(event, RawData):
|
||||
self.connection.receive_data(event.data)
|
||||
await self._handle_events()
|
||||
elif isinstance(event, Closed):
|
||||
if self.stream is not None:
|
||||
await self._close_stream()
|
||||
|
||||
async def stream_send(self, event: StreamEvent) -> None:
|
||||
if isinstance(event, Response):
|
||||
if event.status_code >= 200:
|
||||
await self._send_h11_event(
|
||||
h11.Response(
|
||||
headers=chain(event.headers, self.config.response_headers("h11")),
|
||||
status_code=event.status_code,
|
||||
)
|
||||
)
|
||||
else:
|
||||
await self._send_h11_event(
|
||||
h11.InformationalResponse(
|
||||
headers=chain(event.headers, self.config.response_headers("h11")),
|
||||
status_code=event.status_code,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, Body):
|
||||
await self._send_h11_event(h11.Data(data=event.data))
|
||||
elif isinstance(event, EndBody):
|
||||
await self._send_h11_event(h11.EndOfMessage())
|
||||
elif isinstance(event, Data):
|
||||
await self.send(RawData(data=event.data))
|
||||
elif isinstance(event, EndData):
|
||||
pass
|
||||
elif isinstance(event, StreamClosed):
|
||||
await self._maybe_recycle()
|
||||
|
||||
async def _handle_events(self) -> None:
|
||||
while True:
|
||||
if self.connection.they_are_waiting_for_100_continue:
|
||||
await self._send_h11_event(
|
||||
h11.InformationalResponse(
|
||||
status_code=100, headers=self.config.response_headers("h11")
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
event = self.connection.next_event()
|
||||
except h11.RemoteProtocolError:
|
||||
if self.connection.our_state in {h11.IDLE, h11.SEND_RESPONSE}:
|
||||
await self._send_error_response(400)
|
||||
await self.send(Closed())
|
||||
break
|
||||
else:
|
||||
if isinstance(event, h11.Request):
|
||||
await self._check_protocol(event)
|
||||
await self._create_stream(event)
|
||||
elif event is h11.PAUSED:
|
||||
await self.can_read.clear()
|
||||
await self.send(Updated())
|
||||
await self.can_read.wait()
|
||||
elif isinstance(event, h11.ConnectionClosed) or event is h11.NEED_DATA:
|
||||
break
|
||||
elif self.stream is None:
|
||||
break
|
||||
elif isinstance(event, h11.Data):
|
||||
await self.stream.handle(Body(stream_id=STREAM_ID, data=event.data))
|
||||
elif isinstance(event, h11.EndOfMessage):
|
||||
await self.stream.handle(EndBody(stream_id=STREAM_ID))
|
||||
elif isinstance(event, Data):
|
||||
# WebSocket pass through
|
||||
await self.stream.handle(event)
|
||||
|
||||
async def _create_stream(self, request: h11.Request) -> None:
|
||||
upgrade_value = ""
|
||||
connection_value = ""
|
||||
for name, value in request.headers:
|
||||
sanitised_name = name.decode("latin1").strip().lower()
|
||||
if sanitised_name == "upgrade":
|
||||
upgrade_value = value.decode("latin1").strip()
|
||||
elif sanitised_name == "connection":
|
||||
connection_value = value.decode("latin1").strip()
|
||||
|
||||
connection_tokens = connection_value.lower().split(",")
|
||||
if (
|
||||
any(token.strip() == "upgrade" for token in connection_tokens)
|
||||
and upgrade_value.lower() == "websocket"
|
||||
and request.method.decode("ascii").upper() == "GET"
|
||||
):
|
||||
self.stream = WSStream(
|
||||
self.app,
|
||||
self.config,
|
||||
self.context,
|
||||
self.ssl,
|
||||
self.client,
|
||||
self.server,
|
||||
self.stream_send,
|
||||
STREAM_ID,
|
||||
)
|
||||
self.connection = H11WSConnection(self.connection)
|
||||
else:
|
||||
self.stream = HTTPStream(
|
||||
self.app,
|
||||
self.config,
|
||||
self.context,
|
||||
self.ssl,
|
||||
self.client,
|
||||
self.server,
|
||||
self.stream_send,
|
||||
STREAM_ID,
|
||||
)
|
||||
await self.stream.handle(
|
||||
Request(
|
||||
stream_id=STREAM_ID,
|
||||
headers=request.headers,
|
||||
http_version=request.http_version.decode(),
|
||||
method=request.method.decode("ascii").upper(),
|
||||
raw_path=request.target,
|
||||
)
|
||||
)
|
||||
|
||||
async def _send_h11_event(self, event: H11SendableEvent) -> None:
|
||||
try:
|
||||
data = self.connection.send(event)
|
||||
except h11.LocalProtocolError:
|
||||
if self.connection.their_state != h11.ERROR:
|
||||
raise
|
||||
else:
|
||||
await self.send(RawData(data=data))
|
||||
|
||||
async def _send_error_response(self, status_code: int) -> None:
|
||||
await self._send_h11_event(
|
||||
h11.Response(
|
||||
status_code=status_code,
|
||||
headers=chain(
|
||||
[(b"content-length", b"0"), (b"connection", b"close")],
|
||||
self.config.response_headers("h11"),
|
||||
),
|
||||
)
|
||||
)
|
||||
await self._send_h11_event(h11.EndOfMessage())
|
||||
|
||||
async def _maybe_recycle(self) -> None:
|
||||
await self._close_stream()
|
||||
if self.connection.our_state is h11.DONE:
|
||||
try:
|
||||
self.connection.start_next_cycle()
|
||||
except h11.LocalProtocolError:
|
||||
await self.send(Closed())
|
||||
else:
|
||||
self.response = None
|
||||
self.scope = None
|
||||
await self.can_read.set()
|
||||
await self.send(Updated())
|
||||
else:
|
||||
await self.can_read.set()
|
||||
await self.send(Closed())
|
||||
|
||||
async def _close_stream(self) -> None:
|
||||
if self.stream is not None:
|
||||
await self.stream.handle(StreamClosed(stream_id=STREAM_ID))
|
||||
self.stream = None
|
||||
|
||||
async def _check_protocol(self, event: h11.Request) -> None:
|
||||
upgrade_value = ""
|
||||
has_body = False
|
||||
for name, value in event.headers:
|
||||
sanitised_name = name.decode("latin1").strip().lower()
|
||||
if sanitised_name == "upgrade":
|
||||
upgrade_value = value.decode("latin1").strip()
|
||||
elif sanitised_name in {"content-length", "transfer-encoding"}:
|
||||
has_body = True
|
||||
|
||||
# h2c Upgrade requests with a body are a pain as the body must
|
||||
# be fully received in HTTP/1.1 before the upgrade response
|
||||
# and HTTP/2 takes over, so Hypercorn ignores the upgrade and
|
||||
# responds in HTTP/1.1. Use a preflight OPTIONS request to
|
||||
# initiate the upgrade if really required (or just use h2).
|
||||
if upgrade_value.lower() == "h2c" and not has_body:
|
||||
await self._send_h11_event(
|
||||
h11.InformationalResponse(
|
||||
status_code=101,
|
||||
headers=self.config.response_headers("h11")
|
||||
+ [(b"connection", b"upgrade"), (b"upgrade", b"h2c")],
|
||||
)
|
||||
)
|
||||
raise H2CProtocolRequired(self.connection.trailing_data[0], event)
|
||||
elif event.method == b"PRI" and event.target == b"*" and event.http_version == b"2.0":
|
||||
raise H2ProtocolAssumed(b"PRI * HTTP/2.0\r\n\r\n" + self.connection.trailing_data[0])
|
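For reference, a body-less HTTP/1.1 request shaped like the following takes the h2c upgrade branch above, while a client that sends the HTTP/2 connection preface directly hits the PRI * HTTP/2.0 branch; this is a sketch and the HTTP2-Settings value is a placeholder, not a real payload:

# Body-less h2c upgrade request (sketch; settings payload is a placeholder).
h2c_upgrade_request = (
    b"GET / HTTP/1.1\r\n"
    b"Host: example.com\r\n"
    b"Connection: Upgrade, HTTP2-Settings\r\n"
    b"Upgrade: h2c\r\n"
    b"HTTP2-Settings: <base64url-encoded SETTINGS payload>\r\n"
    b"\r\n"
)

# Direct HTTP/2 connection preface, which raises H2ProtocolAssumed above.
h2_preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"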
362
.venv/lib/python3.9/site-packages/hypercorn/protocol/h2.py
Normal file
@@ -0,0 +1,362 @@
|
||||
from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import h2
|
||||
import h2.connection
|
||||
import h2.events
|
||||
import h2.exceptions
|
||||
import priority
|
||||
|
||||
from .events import (
|
||||
Body,
|
||||
Data,
|
||||
EndBody,
|
||||
EndData,
|
||||
Event as StreamEvent,
|
||||
Request,
|
||||
Response,
|
||||
StreamClosed,
|
||||
)
|
||||
from .http_stream import HTTPStream
|
||||
from .ws_stream import WSStream
|
||||
from ..config import Config
|
||||
from ..events import Closed, Event, RawData, Updated
|
||||
from ..typing import ASGIFramework, Context, Event as IOEvent
|
||||
from ..utils import filter_pseudo_headers
|
||||
|
||||
BUFFER_HIGH_WATER = 2 * 2 ** 14 # Twice the default max frame size (two frames worth)
|
||||
BUFFER_LOW_WATER = BUFFER_HIGH_WATER / 2
|
||||
|
||||
|
||||
class BufferCompleteError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class StreamBuffer:
|
||||
def __init__(self, event_class: Type[IOEvent]) -> None:
|
||||
self.buffer = bytearray()
|
||||
self._complete = False
|
||||
self._is_empty = event_class()
|
||||
self._paused = event_class()
|
||||
|
||||
async def drain(self) -> None:
|
||||
await self._is_empty.wait()
|
||||
|
||||
def set_complete(self) -> None:
|
||||
self._complete = True
|
||||
|
||||
async def close(self) -> None:
|
||||
self._complete = True
|
||||
self.buffer = bytearray()
|
||||
await self._is_empty.set()
|
||||
await self._paused.set()
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
return self._complete and len(self.buffer) == 0
|
||||
|
||||
async def push(self, data: bytes) -> None:
|
||||
if self._complete:
|
||||
raise BufferCompleteError()
|
||||
self.buffer.extend(data)
|
||||
await self._is_empty.clear()
|
||||
if len(self.buffer) >= BUFFER_HIGH_WATER:
|
||||
await self._paused.wait()
|
||||
await self._paused.clear()
|
||||
|
||||
async def pop(self, max_length: int) -> bytes:
|
||||
length = min(len(self.buffer), max_length)
|
||||
data = bytes(self.buffer[:length])
|
||||
del self.buffer[:length]
|
||||
if len(data) < BUFFER_LOW_WATER:
|
||||
await self._paused.set()
|
||||
if len(self.buffer) == 0:
|
||||
await self._is_empty.set()
|
||||
return data
|
||||
|
||||
|
||||
class H2Protocol:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIFramework,
|
||||
config: Config,
|
||||
context: Context,
|
||||
ssl: bool,
|
||||
client: Optional[Tuple[str, int]],
|
||||
server: Optional[Tuple[str, int]],
|
||||
send: Callable[[Event], Awaitable[None]],
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.client = client
|
||||
self.closed = False
|
||||
self.config = config
|
||||
self.context = context
|
||||
|
||||
self.connection = h2.connection.H2Connection(
|
||||
config=h2.config.H2Configuration(client_side=False, header_encoding=None)
|
||||
)
|
||||
self.connection.DEFAULT_MAX_INBOUND_FRAME_SIZE = config.h2_max_inbound_frame_size
|
||||
self.connection.local_settings = h2.settings.Settings(
|
||||
client=False,
|
||||
initial_values={
|
||||
h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: config.h2_max_concurrent_streams,
|
||||
h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: config.h2_max_header_list_size,
|
||||
h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL: 1,
|
||||
},
|
||||
)
|
||||
|
||||
self.send = send
|
||||
self.server = server
|
||||
self.ssl = ssl
|
||||
self.streams: Dict[int, Union[HTTPStream, WSStream]] = {}
|
||||
# The below are used by the sending task
|
||||
self.has_data = self.context.event_class()
|
||||
self.priority = priority.PriorityTree()
|
||||
self.stream_buffers: Dict[int, StreamBuffer] = {}
|
||||
|
||||
@property
|
||||
def idle(self) -> bool:
|
||||
return len(self.streams) == 0 or all(stream.idle for stream in self.streams.values())
|
||||
|
||||
async def initiate(
|
||||
self, headers: Optional[List[Tuple[bytes, bytes]]] = None, settings: Optional[str] = None
|
||||
) -> None:
|
||||
if settings is not None:
|
||||
self.connection.initiate_upgrade_connection(settings)
|
||||
else:
|
||||
self.connection.initiate_connection()
|
||||
await self._flush()
|
||||
if headers is not None:
|
||||
event = h2.events.RequestReceived()
|
||||
event.stream_id = 1
|
||||
event.headers = headers
|
||||
await self._create_stream(event)
|
||||
await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))
|
||||
self.context.spawn(self.send_task)
|
||||
|
||||
async def send_task(self) -> None:
|
||||
# This should be run in a separate task to the rest of this
|
||||
# class. This allows it to separately choose when to send,
|
||||
# crucially in what order.
|
||||
while not self.closed:
|
||||
try:
|
||||
stream_id = next(self.priority)
|
||||
except priority.DeadlockError:
|
||||
await self.has_data.wait()
|
||||
await self.has_data.clear()
|
||||
else:
|
||||
await self._send_data(stream_id)
|
||||
|
||||
async def _send_data(self, stream_id: int) -> None:
|
||||
try:
|
||||
chunk_size = min(
|
||||
self.connection.local_flow_control_window(stream_id),
|
||||
self.connection.max_outbound_frame_size,
|
||||
)
|
||||
chunk_size = max(0, chunk_size)
|
||||
data = await self.stream_buffers[stream_id].pop(chunk_size)
|
||||
if data:
|
||||
self.connection.send_data(stream_id, data)
|
||||
await self._flush()
|
||||
else:
|
||||
self.priority.block(stream_id)
|
||||
|
||||
if self.stream_buffers[stream_id].complete:
|
||||
self.connection.end_stream(stream_id)
|
||||
await self._flush()
|
||||
del self.stream_buffers[stream_id]
|
||||
self.priority.remove_stream(stream_id)
|
||||
except (h2.exceptions.StreamClosedError, KeyError, h2.exceptions.ProtocolError):
|
||||
# Stream or connection has closed whilst waiting to send
|
||||
# data, not a problem - just force close it.
|
||||
await self.stream_buffers[stream_id].close()
|
||||
del self.stream_buffers[stream_id]
|
||||
self.priority.remove_stream(stream_id)
|
||||
|
||||
async def handle(self, event: Event) -> None:
|
||||
if isinstance(event, RawData):
|
||||
try:
|
||||
events = self.connection.receive_data(event.data)
|
||||
except h2.exceptions.ProtocolError:
|
||||
await self._flush()
|
||||
await self.send(Closed())
|
||||
else:
|
||||
await self._handle_events(events)
|
||||
elif isinstance(event, Closed):
|
||||
self.closed = True
|
||||
stream_ids = list(self.streams.keys())
|
||||
for stream_id in stream_ids:
|
||||
await self._close_stream(stream_id)
|
||||
await self.has_data.set()
|
||||
|
||||
async def stream_send(self, event: StreamEvent) -> None:
|
||||
try:
|
||||
if isinstance(event, Response):
|
||||
self.connection.send_headers(
|
||||
event.stream_id,
|
||||
[(b":status", b"%d" % event.status_code)]
|
||||
+ event.headers
|
||||
+ self.config.response_headers("h2"),
|
||||
)
|
||||
await self._flush()
|
||||
elif isinstance(event, (Body, Data)):
|
||||
self.priority.unblock(event.stream_id)
|
||||
await self.has_data.set()
|
||||
await self.stream_buffers[event.stream_id].push(event.data)
|
||||
elif isinstance(event, (EndBody, EndData)):
|
||||
self.stream_buffers[event.stream_id].set_complete()
|
||||
self.priority.unblock(event.stream_id)
|
||||
await self.has_data.set()
|
||||
await self.stream_buffers[event.stream_id].drain()
|
||||
elif isinstance(event, StreamClosed):
|
||||
await self._close_stream(event.stream_id)
|
||||
await self.send(Updated())
|
||||
elif isinstance(event, Request):
|
||||
await self._create_server_push(event.stream_id, event.raw_path, event.headers)
|
||||
except (
|
||||
BufferCompleteError,
|
||||
KeyError,
|
||||
priority.MissingStreamError,
|
||||
h2.exceptions.ProtocolError,
|
||||
):
|
||||
# Connection has closed whilst blocked on flow control or
|
||||
# connection has advanced ahead of the last emitted event.
|
||||
return
|
||||
|
||||
async def _handle_events(self, events: List[h2.events.Event]) -> None:
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.RequestReceived):
|
||||
await self._create_stream(event)
|
||||
elif isinstance(event, h2.events.DataReceived):
|
||||
await self.streams[event.stream_id].handle(
|
||||
Body(stream_id=event.stream_id, data=event.data)
|
||||
)
|
||||
self.connection.acknowledge_received_data(
|
||||
event.flow_controlled_length, event.stream_id
|
||||
)
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))
|
||||
elif isinstance(event, h2.events.StreamReset):
|
||||
await self._close_stream(event.stream_id)
|
||||
await self._window_updated(event.stream_id)
|
||||
elif isinstance(event, h2.events.WindowUpdated):
|
||||
await self._window_updated(event.stream_id)
|
||||
elif isinstance(event, h2.events.PriorityUpdated):
|
||||
await self._priority_updated(event)
|
||||
elif isinstance(event, h2.events.RemoteSettingsChanged):
|
||||
if h2.settings.SettingCodes.INITIAL_WINDOW_SIZE in event.changed_settings:
|
||||
await self._window_updated(None)
|
||||
elif isinstance(event, h2.events.ConnectionTerminated):
|
||||
await self.send(Closed())
|
||||
await self._flush()
|
||||
|
||||
async def _flush(self) -> None:
|
||||
data = self.connection.data_to_send()
|
||||
if data != b"":
|
||||
await self.send(RawData(data=data))
|
||||
|
||||
async def _window_updated(self, stream_id: Optional[int]) -> None:
|
||||
if stream_id is None or stream_id == 0:
|
||||
# Unblock all streams
|
||||
for stream_id in list(self.stream_buffers.keys()):
|
||||
self.priority.unblock(stream_id)
|
||||
elif stream_id is not None and stream_id in self.stream_buffers:
|
||||
self.priority.unblock(stream_id)
|
||||
await self.has_data.set()
|
||||
|
||||
async def _priority_updated(self, event: h2.events.PriorityUpdated) -> None:
|
||||
try:
|
||||
self.priority.reprioritize(
|
||||
stream_id=event.stream_id,
|
||||
depends_on=event.depends_on or None,
|
||||
weight=event.weight,
|
||||
exclusive=event.exclusive,
|
||||
)
|
||||
except priority.MissingStreamError:
|
||||
# Received PRIORITY frame before HEADERS frame
|
||||
self.priority.insert_stream(
|
||||
stream_id=event.stream_id,
|
||||
depends_on=event.depends_on or None,
|
||||
weight=event.weight,
|
||||
exclusive=event.exclusive,
|
||||
)
|
||||
self.priority.block(event.stream_id)
|
||||
await self.has_data.set()
|
||||
|
||||
async def _create_stream(self, request: h2.events.RequestReceived) -> None:
|
||||
for name, value in request.headers:
|
||||
if name == b":method":
|
||||
method = value.decode("ascii").upper()
|
||||
elif name == b":path":
|
||||
raw_path = value
|
||||
|
||||
if method == "CONNECT":
|
||||
self.streams[request.stream_id] = WSStream(
|
||||
self.app,
|
||||
self.config,
|
||||
self.context,
|
||||
self.ssl,
|
||||
self.client,
|
||||
self.server,
|
||||
self.stream_send,
|
||||
request.stream_id,
|
||||
)
|
||||
else:
|
||||
self.streams[request.stream_id] = HTTPStream(
|
||||
self.app,
|
||||
self.config,
|
||||
self.context,
|
||||
self.ssl,
|
||||
self.client,
|
||||
self.server,
|
||||
self.stream_send,
|
||||
request.stream_id,
|
||||
)
|
||||
self.stream_buffers[request.stream_id] = StreamBuffer(self.context.event_class)
|
||||
try:
|
||||
self.priority.insert_stream(request.stream_id)
|
||||
except priority.DuplicateStreamError:
|
||||
# Received PRIORITY frame before HEADERS frame
|
||||
pass
|
||||
else:
|
||||
self.priority.block(request.stream_id)
|
||||
|
||||
await self.streams[request.stream_id].handle(
|
||||
Request(
|
||||
stream_id=request.stream_id,
|
||||
headers=filter_pseudo_headers(request.headers),
|
||||
http_version="2",
|
||||
method=method,
|
||||
raw_path=raw_path,
|
||||
)
|
||||
)
|
||||
|
||||
async def _create_server_push(
|
||||
self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]]
|
||||
) -> None:
|
||||
push_stream_id = self.connection.get_next_available_stream_id()
|
||||
request_headers = [(b":method", b"GET"), (b":path", path)]
|
||||
request_headers.extend(headers)
|
||||
request_headers.extend(self.config.response_headers("h2"))
|
||||
try:
|
||||
self.connection.push_stream(
|
||||
stream_id=stream_id,
|
||||
promised_stream_id=push_stream_id,
|
||||
request_headers=request_headers,
|
||||
)
|
||||
await self._flush()
|
||||
except h2.exceptions.ProtocolError:
|
||||
# Client does not accept push promises or we are trying to
|
||||
# push on a push promise request.
|
||||
pass
|
||||
else:
|
||||
event = h2.events.RequestReceived()
|
||||
event.stream_id = push_stream_id
|
||||
event.headers = request_headers
|
||||
await self._create_stream(event)
|
||||
await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))
|
||||
|
||||
async def _close_stream(self, stream_id: int) -> None:
|
||||
if stream_id in self.streams:
|
||||
stream = self.streams.pop(stream_id)
|
||||
await stream.handle(StreamClosed(stream_id=stream_id))
|
||||
await self.has_data.set()
|
138
.venv/lib/python3.9/site-packages/hypercorn/protocol/h3.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from aioquic.h3.connection import H3Connection
|
||||
from aioquic.h3.events import DataReceived, HeadersReceived
|
||||
from aioquic.h3.exceptions import NoAvailablePushIDError
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.events import QuicEvent
|
||||
|
||||
from .events import (
|
||||
Body,
|
||||
Data,
|
||||
EndBody,
|
||||
EndData,
|
||||
Event as StreamEvent,
|
||||
Request,
|
||||
Response,
|
||||
StreamClosed,
|
||||
)
|
||||
from .http_stream import HTTPStream
|
||||
from .ws_stream import WSStream
|
||||
from ..config import Config
|
||||
from ..typing import ASGIFramework, Context
|
||||
from ..utils import filter_pseudo_headers
|
||||
|
||||
|
||||
class H3Protocol:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIFramework,
|
||||
config: Config,
|
||||
context: Context,
|
||||
client: Optional[Tuple[str, int]],
|
||||
server: Optional[Tuple[str, int]],
|
||||
quic: QuicConnection,
|
||||
send: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.client = client
|
||||
self.config = config
|
||||
self.context = context
|
||||
self.connection = H3Connection(quic)
|
||||
self.send = send
|
||||
self.server = server
|
||||
self.streams: Dict[int, Union[HTTPStream, WSStream]] = {}
|
||||
|
||||
async def handle(self, quic_event: QuicEvent) -> None:
|
||||
for event in self.connection.handle_event(quic_event):
|
||||
if isinstance(event, HeadersReceived):
|
||||
await self._create_stream(event)
|
||||
if event.stream_ended:
|
||||
await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))
|
||||
elif isinstance(event, DataReceived):
|
||||
await self.streams[event.stream_id].handle(
|
||||
Body(stream_id=event.stream_id, data=event.data)
|
||||
)
|
||||
if event.stream_ended:
|
||||
await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))
|
||||
|
||||
async def stream_send(self, event: StreamEvent) -> None:
|
||||
if isinstance(event, Response):
|
||||
self.connection.send_headers(
|
||||
event.stream_id,
|
||||
[(b":status", b"%d" % event.status_code)]
|
||||
+ event.headers
|
||||
+ self.config.response_headers("h3"),
|
||||
)
|
||||
await self.send()
|
||||
elif isinstance(event, (Body, Data)):
|
||||
self.connection.send_data(event.stream_id, event.data, False)
|
||||
await self.send()
|
||||
elif isinstance(event, (EndBody, EndData)):
|
||||
self.connection.send_data(event.stream_id, b"", True)
|
||||
await self.send()
|
||||
elif isinstance(event, StreamClosed):
|
||||
pass # ??
|
||||
elif isinstance(event, Request):
|
||||
await self._create_server_push(event.stream_id, event.raw_path, event.headers)
|
||||
|
||||
async def _create_stream(self, request: HeadersReceived) -> None:
|
||||
for name, value in request.headers:
|
||||
if name == b":method":
|
||||
method = value.decode("ascii").upper()
|
||||
elif name == b":path":
|
||||
raw_path = value
|
||||
|
||||
if method == "CONNECT":
|
||||
self.streams[request.stream_id] = WSStream(
|
||||
self.app,
|
||||
self.config,
|
||||
self.context,
|
||||
True,
|
||||
self.client,
|
||||
self.server,
|
||||
self.stream_send,
|
||||
request.stream_id,
|
||||
)
|
||||
else:
|
||||
self.streams[request.stream_id] = HTTPStream(
|
||||
self.app,
|
||||
self.config,
|
||||
self.context,
|
||||
True,
|
||||
self.client,
|
||||
self.server,
|
||||
self.stream_send,
|
||||
request.stream_id,
|
||||
)
|
||||
|
||||
await self.streams[request.stream_id].handle(
|
||||
Request(
|
||||
stream_id=request.stream_id,
|
||||
headers=filter_pseudo_headers(request.headers),
|
||||
http_version="3",
|
||||
method=method,
|
||||
raw_path=raw_path,
|
||||
)
|
||||
)
|
||||
|
||||
async def _create_server_push(
|
||||
self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]]
|
||||
) -> None:
|
||||
request_headers = [(b":method", b"GET"), (b":path", path)]
|
||||
request_headers.extend(headers)
|
||||
request_headers.extend(self.config.response_headers("h3"))
|
||||
try:
|
||||
push_stream_id = self.connection.send_push_promise(
|
||||
stream_id=stream_id, headers=request_headers
|
||||
)
|
||||
except NoAvailablePushIDError:
|
||||
# Client does not accept push promises or we are trying to
|
||||
# push on a push promise request.
|
||||
pass
|
||||
else:
|
||||
event = HeadersReceived(
|
||||
stream_id=push_stream_id, stream_ended=True, headers=request_headers
|
||||
)
|
||||
await self._create_stream(event)
|
||||
await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))
|
@@ -0,0 +1,175 @@
|
||||
from enum import auto, Enum
|
||||
from time import time
|
||||
from typing import Awaitable, Callable, Optional, Tuple
|
||||
from urllib.parse import unquote
|
||||
|
||||
from .events import Body, EndBody, Event, Request, Response, StreamClosed
|
||||
from ..config import Config
|
||||
from ..typing import ASGIFramework, ASGISendEvent, Context, HTTPResponseStartEvent, HTTPScope
|
||||
from ..utils import build_and_validate_headers, suppress_body, UnexpectedMessage, valid_server_name
|
||||
|
||||
PUSH_VERSIONS = {"2", "3"}
|
||||
|
||||
|
||||
class ASGIHTTPState(Enum):
|
||||
# The ASGI spec is clear that a response should not start until the
|
||||
# framework has sent at least one body message, which is why this
|
||||
# state tracking is required.
|
||||
REQUEST = auto()
|
||||
RESPONSE = auto()
|
||||
CLOSED = auto()
|
||||
|
||||
|
||||
class HTTPStream:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIFramework,
|
||||
config: Config,
|
||||
context: Context,
|
||||
ssl: bool,
|
||||
client: Optional[Tuple[str, int]],
|
||||
server: Optional[Tuple[str, int]],
|
||||
send: Callable[[Event], Awaitable[None]],
|
||||
stream_id: int,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.client = client
|
||||
self.closed = False
|
||||
self.config = config
|
||||
self.context = context
|
||||
self.response: HTTPResponseStartEvent
|
||||
self.scope: HTTPScope
|
||||
self.send = send
|
||||
self.scheme = "https" if ssl else "http"
|
||||
self.server = server
|
||||
self.start_time: float
|
||||
self.state = ASGIHTTPState.REQUEST
|
||||
self.stream_id = stream_id
|
||||
|
||||
@property
|
||||
def idle(self) -> bool:
|
||||
return False
|
||||
|
||||
async def handle(self, event: Event) -> None:
|
||||
if self.closed:
|
||||
return
|
||||
elif isinstance(event, Request):
|
||||
self.start_time = time()
|
||||
path, _, query_string = event.raw_path.partition(b"?")
|
||||
self.scope = {
|
||||
"type": "http",
|
||||
"http_version": event.http_version,
|
||||
"asgi": {"spec_version": "2.1"},
|
||||
"method": event.method,
|
||||
"scheme": self.scheme,
|
||||
"path": unquote(path.decode("ascii")),
|
||||
"raw_path": path,
|
||||
"query_string": query_string,
|
||||
"root_path": self.config.root_path,
|
||||
"headers": event.headers,
|
||||
"client": self.client,
|
||||
"server": self.server,
|
||||
"extensions": {},
|
||||
}
|
||||
if event.http_version in PUSH_VERSIONS:
|
||||
self.scope["extensions"]["http.response.push"] = {}
|
||||
|
||||
if valid_server_name(self.config, event):
|
||||
self.app_put = await self.context.spawn_app(
|
||||
self.app, self.config, self.scope, self.app_send
|
||||
)
|
||||
else:
|
||||
await self._send_error_response(404)
|
||||
self.closed = True
|
||||
|
||||
elif isinstance(event, Body):
|
||||
await self.app_put(
|
||||
{"type": "http.request", "body": bytes(event.data), "more_body": True}
|
||||
)
|
||||
elif isinstance(event, EndBody):
|
||||
await self.app_put({"type": "http.request", "body": b"", "more_body": False})
|
||||
elif isinstance(event, StreamClosed):
|
||||
self.closed = True
|
||||
if self.app_put is not None:
|
||||
await self.app_put({"type": "http.disconnect"}) # type: ignore
|
||||
|
||||
async def app_send(self, message: Optional[ASGISendEvent]) -> None:
|
||||
if self.closed:
|
||||
# Allow app to finish after close
|
||||
return
|
||||
|
||||
if message is None: # ASGI App has finished sending messages
|
||||
# Cleanup if required
|
||||
if self.state == ASGIHTTPState.REQUEST:
|
||||
await self._send_error_response(500)
|
||||
await self.send(StreamClosed(stream_id=self.stream_id))
|
||||
else:
|
||||
if message["type"] == "http.response.start" and self.state == ASGIHTTPState.REQUEST:
|
||||
self.response = message
|
||||
elif (
|
||||
message["type"] == "http.response.push"
|
||||
and self.scope["http_version"] in PUSH_VERSIONS
|
||||
):
|
||||
if not isinstance(message["path"], str):
|
||||
raise TypeError(f"{message['path']} should be a str")
|
||||
headers = [(b":scheme", self.scope["scheme"].encode())]
|
||||
for name, value in self.scope["headers"]:
|
||||
if name == b"host":
|
||||
headers.append((b":authority", value))
|
||||
headers.extend(build_and_validate_headers(message["headers"]))
|
||||
await self.send(
|
||||
Request(
|
||||
stream_id=self.stream_id,
|
||||
headers=headers,
|
||||
http_version=self.scope["http_version"],
|
||||
method="GET",
|
||||
raw_path=message["path"].encode(),
|
||||
)
|
||||
)
|
||||
elif message["type"] == "http.response.body" and self.state in {
|
||||
ASGIHTTPState.REQUEST,
|
||||
ASGIHTTPState.RESPONSE,
|
||||
}:
|
||||
if self.state == ASGIHTTPState.REQUEST:
|
||||
headers = build_and_validate_headers(self.response.get("headers", []))
|
||||
await self.send(
|
||||
Response(
|
||||
stream_id=self.stream_id,
|
||||
headers=headers,
|
||||
status_code=int(self.response["status"]),
|
||||
)
|
||||
)
|
||||
self.state = ASGIHTTPState.RESPONSE
|
||||
|
||||
if (
|
||||
not suppress_body(self.scope["method"], int(self.response["status"]))
|
||||
and message.get("body", b"") != b""
|
||||
):
|
||||
await self.send(
|
||||
Body(stream_id=self.stream_id, data=bytes(message.get("body", b"")))
|
||||
)
|
||||
|
||||
if not message.get("more_body", False):
|
||||
if self.state != ASGIHTTPState.CLOSED:
|
||||
self.state = ASGIHTTPState.CLOSED
|
||||
await self.config.log.access(
|
||||
self.scope, self.response, time() - self.start_time
|
||||
)
|
||||
await self.send(EndBody(stream_id=self.stream_id))
|
||||
await self.send(StreamClosed(stream_id=self.stream_id))
|
||||
else:
|
||||
raise UnexpectedMessage(self.state, message["type"])
|
||||
|
||||
async def _send_error_response(self, status_code: int) -> None:
|
||||
await self.send(
|
||||
Response(
|
||||
stream_id=self.stream_id,
|
||||
headers=[(b"content-length", b"0"), (b"connection", b"close")],
|
||||
status_code=status_code,
|
||||
)
|
||||
)
|
||||
await self.send(EndBody(stream_id=self.stream_id))
|
||||
self.state = ASGIHTTPState.CLOSED
|
||||
await self.config.log.access(
|
||||
self.scope, {"status": status_code, "headers": []}, time() - self.start_time
|
||||
)
|
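The stream above drives a standard ASGI HTTP message exchange; a minimal sketch of an application on the other side, including the optional http.response.push extension advertised above on HTTP/2 and HTTP/3:

async def app(scope, receive, send):
    # Sketch of the messages HTTPStream.app_send accepts.
    if scope["type"] != "http":
        return
    if "http.response.push" in scope.get("extensions", {}):
        # Ask the server to push a related resource (HTTP/2 and HTTP/3 only).
        await send({"type": "http.response.push", "path": "/style.css", "headers": []})
    await send(
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [(b"content-length", b"5")],
        }
    )
    await send({"type": "http.response.body", "body": b"hello", "more_body": False})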
125
.venv/lib/python3.9/site-packages/hypercorn/protocol/quic.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from functools import partial
|
||||
from typing import Awaitable, Callable, Dict, Optional, Tuple
|
||||
|
||||
from aioquic.buffer import Buffer
|
||||
from aioquic.h3.connection import H3_ALPN
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.events import (
|
||||
ConnectionIdIssued,
|
||||
ConnectionIdRetired,
|
||||
ConnectionTerminated,
|
||||
ProtocolNegotiated,
|
||||
)
|
||||
from aioquic.quic.packet import (
|
||||
encode_quic_version_negotiation,
|
||||
PACKET_TYPE_INITIAL,
|
||||
pull_quic_header,
|
||||
)
|
||||
|
||||
from .h3 import H3Protocol
|
||||
from ..config import Config
|
||||
from ..events import Closed, Event, RawData
|
||||
from ..typing import ASGIFramework, Context
|
||||
|
||||
|
||||
class QuicProtocol:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIFramework,
|
||||
config: Config,
|
||||
context: Context,
|
||||
server: Optional[Tuple[str, int]],
|
||||
send: Callable[[Event], Awaitable[None]],
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.config = config
|
||||
self.context = context
|
||||
self.connections: Dict[bytes, QuicConnection] = {}
|
||||
self.http_connections: Dict[QuicConnection, H3Protocol] = {}
|
||||
self.send = send
|
||||
self.server = server
|
||||
|
||||
self.quic_config = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=False)
|
||||
self.quic_config.load_cert_chain(certfile=config.certfile, keyfile=config.keyfile)
|
||||
|
||||
async def handle(self, event: Event) -> None:
|
||||
if isinstance(event, RawData):
|
||||
try:
|
||||
header = pull_quic_header(Buffer(data=event.data), host_cid_length=8)
|
||||
except ValueError:
|
||||
return
|
||||
if (
|
||||
header.version is not None
|
||||
and header.version not in self.quic_config.supported_versions
|
||||
):
|
||||
data = encode_quic_version_negotiation(
|
||||
source_cid=header.destination_cid,
|
||||
destination_cid=header.source_cid,
|
||||
supported_versions=self.quic_config.supported_versions,
|
||||
)
|
||||
await self.send(RawData(data=data, address=event.address))
|
||||
return
|
||||
|
||||
connection = self.connections.get(header.destination_cid)
|
||||
if (
|
||||
connection is None
|
||||
and len(event.data) >= 1200
|
||||
and header.packet_type == PACKET_TYPE_INITIAL
|
||||
):
|
||||
connection = QuicConnection(
|
||||
configuration=self.quic_config,
|
||||
original_destination_connection_id=header.destination_cid,
|
||||
)
|
||||
self.connections[header.destination_cid] = connection
|
||||
self.connections[connection.host_cid] = connection
|
||||
|
||||
if connection is not None:
|
||||
connection.receive_datagram(event.data, event.address, now=self.context.time())
|
||||
await self._handle_events(connection, event.address)
|
||||
elif isinstance(event, Closed):
|
||||
pass
|
||||
|
||||
async def send_all(self, connection: QuicConnection) -> None:
|
||||
for data, address in connection.datagrams_to_send(now=self.context.time()):
|
||||
await self.send(RawData(data=data, address=address))
|
||||
|
||||
async def _handle_events(
|
||||
self, connection: QuicConnection, client: Optional[Tuple[str, int]] = None
|
||||
) -> None:
|
||||
event = connection.next_event()
|
||||
while event is not None:
|
||||
if isinstance(event, ConnectionTerminated):
|
||||
pass
|
||||
elif isinstance(event, ProtocolNegotiated):
|
||||
self.http_connections[connection] = H3Protocol(
|
||||
self.app,
|
||||
self.config,
|
||||
self.context,
|
||||
client,
|
||||
self.server,
|
||||
connection,
|
||||
partial(self.send_all, connection),
|
||||
)
|
||||
elif isinstance(event, ConnectionIdIssued):
|
||||
self.connections[event.connection_id] = connection
|
||||
elif isinstance(event, ConnectionIdRetired):
|
||||
del self.connections[event.connection_id]
|
||||
|
||||
if connection in self.http_connections:
|
||||
await self.http_connections[connection].handle(event)
|
||||
|
||||
event = connection.next_event()
|
||||
|
||||
await self.send_all(connection)
|
||||
|
||||
timer = connection.get_timer()
|
||||
if timer is not None:
|
||||
self.context.spawn(self._handle_timer, timer, connection)
|
||||
|
||||
async def _handle_timer(self, timer: float, connection: QuicConnection) -> None:
|
||||
wait = max(0, timer - self.context.time())
|
||||
await self.context.sleep(wait)
|
||||
if connection._close_at is not None:
|
||||
connection.handle_timer(now=self.context.time())
|
||||
await self._handle_events(connection, None)
|
@@ -0,0 +1,349 @@
|
||||
from enum import auto, Enum
|
||||
from time import time
|
||||
from typing import Awaitable, Callable, List, Optional, Tuple, Union
|
||||
from urllib.parse import unquote
|
||||
|
||||
from wsproto.connection import Connection, ConnectionState, ConnectionType
|
||||
from wsproto.events import (
|
||||
BytesMessage,
|
||||
CloseConnection,
|
||||
Event as WSProtoEvent,
|
||||
Message,
|
||||
Ping,
|
||||
TextMessage,
|
||||
)
|
||||
from wsproto.extensions import Extension, PerMessageDeflate
from wsproto.frame_protocol import CloseReason
from wsproto.handshake import server_extensions_handshake, WEBSOCKET_VERSION
from wsproto.utilities import generate_accept_token, split_comma_header

from .events import Body, Data, EndBody, EndData, Event, Request, Response, StreamClosed
from ..config import Config
from ..typing import (
    ASGIFramework,
    ASGISendEvent,
    Context,
    WebsocketAcceptEvent,
    WebsocketResponseBodyEvent,
    WebsocketResponseStartEvent,
    WebsocketScope,
)
from ..utils import build_and_validate_headers, suppress_body, UnexpectedMessage, valid_server_name


class ASGIWebsocketState(Enum):
    # Hypercorn supports the ASGI websocket HTTP response extension,
    # which allows HTTP responses rather than acceptance.
    HANDSHAKE = auto()
    CONNECTED = auto()
    RESPONSE = auto()
    CLOSED = auto()
    HTTPCLOSED = auto()


class FrameTooLarge(Exception):
    pass


class Handshake:
    def __init__(self, headers: List[Tuple[bytes, bytes]], http_version: str) -> None:
        self.http_version = http_version
        self.connection_tokens: Optional[List[str]] = None
        self.extensions: Optional[List[str]] = None
        self.key: Optional[bytes] = None
        self.subprotocols: Optional[List[str]] = None
        self.upgrade: Optional[bytes] = None
        self.version: Optional[bytes] = None
        for name, value in headers:
            name = name.lower()
            if name == b"connection":
                self.connection_tokens = split_comma_header(value)
            elif name == b"sec-websocket-extensions":
                self.extensions = split_comma_header(value)
            elif name == b"sec-websocket-key":
                self.key = value
            elif name == b"sec-websocket-protocol":
                self.subprotocols = split_comma_header(value)
            elif name == b"sec-websocket-version":
                self.version = value
            elif name == b"upgrade":
                self.upgrade = value

    def is_valid(self) -> bool:
        if self.http_version < "1.1":
            return False
        elif self.http_version == "1.1":
            if self.key is None:
                return False
            if self.connection_tokens is None or not any(
                token.lower() == "upgrade" for token in self.connection_tokens
            ):
                return False
            if self.upgrade.lower() != b"websocket":
                return False

        if self.version != WEBSOCKET_VERSION:
            return False
        return True

    def accept(
        self, subprotocol: Optional[str]
    ) -> Tuple[int, List[Tuple[bytes, bytes]], Connection]:
        headers = []
        if subprotocol is not None:
            if subprotocol not in self.subprotocols:
                raise Exception("Invalid Subprotocol")
            else:
                headers.append((b"sec-websocket-protocol", subprotocol.encode()))

        extensions: List[Extension] = [PerMessageDeflate()]
        accepts = None
        if False and self.extensions is not None:
            accepts = server_extensions_handshake(self.extensions, extensions)

        if accepts:
            headers.append((b"sec-websocket-extensions", accepts))

        if self.key is not None:
            headers.append((b"sec-websocket-accept", generate_accept_token(self.key)))

        status_code = 200
        if self.http_version == "1.1":
            headers.extend([(b"upgrade", b"WebSocket"), (b"connection", b"Upgrade")])
            status_code = 101

        return status_code, headers, Connection(ConnectionType.SERVER, extensions)


class WebsocketBuffer:
    def __init__(self, max_length: int) -> None:
        self.value: Optional[Union[bytes, str]] = None
        self.max_length = max_length

    def extend(self, event: Message) -> None:
        if self.value is None:
            if isinstance(event, TextMessage):
                self.value = ""
            else:
                self.value = b""
        self.value += event.data
        if len(self.value) > self.max_length:
            raise FrameTooLarge()

    def clear(self) -> None:
        self.value = None

    def to_message(self) -> dict:
        return {
            "type": "websocket.receive",
            "bytes": self.value if isinstance(self.value, bytes) else None,
            "text": self.value if isinstance(self.value, str) else None,
        }


class WSStream:
    def __init__(
        self,
        app: ASGIFramework,
        config: Config,
        context: Context,
        ssl: bool,
        client: Optional[Tuple[str, int]],
        server: Optional[Tuple[str, int]],
        send: Callable[[Event], Awaitable[None]],
        stream_id: int,
    ) -> None:
        self.app = app
        self.app_put: Optional[Callable] = None
        self.buffer = WebsocketBuffer(config.websocket_max_message_size)
        self.client = client
        self.closed = False
        self.config = config
        self.context = context
        self.response: WebsocketResponseStartEvent
        self.scope: WebsocketScope
        self.send = send
        # RFC 8441 for HTTP/2 says use http or https, ASGI says ws or wss
        self.scheme = "wss" if ssl else "ws"
        self.server = server
        self.start_time: float
        self.state = ASGIWebsocketState.HANDSHAKE
        self.stream_id = stream_id

        self.connection: Connection
        self.handshake: Handshake

    @property
    def idle(self) -> bool:
        return self.state in {ASGIWebsocketState.CLOSED, ASGIWebsocketState.HTTPCLOSED}

    async def handle(self, event: Event) -> None:
        if self.closed:
            return
        elif isinstance(event, Request):
            self.start_time = time()
            self.handshake = Handshake(event.headers, event.http_version)
            path, _, query_string = event.raw_path.partition(b"?")
            self.scope = {
                "type": "websocket",
                "asgi": {"spec_version": "2.1"},
                "scheme": self.scheme,
                "http_version": event.http_version,
                "path": unquote(path.decode("ascii")),
                "raw_path": path,
                "query_string": query_string,
                "root_path": self.config.root_path,
                "headers": event.headers,
                "client": self.client,
                "server": self.server,
                "subprotocols": self.handshake.subprotocols or [],
                "extensions": {"websocket.http.response": {}},
            }

            if not valid_server_name(self.config, event):
                await self._send_error_response(404)
                self.closed = True
            elif not self.handshake.is_valid():
                await self._send_error_response(400)
                self.closed = True
            else:
                self.app_put = await self.context.spawn_app(
                    self.app, self.config, self.scope, self.app_send
                )
                await self.app_put({"type": "websocket.connect"})  # type: ignore
        elif isinstance(event, (Body, Data)):
            self.connection.receive_data(event.data)
            await self._handle_events()
        elif isinstance(event, StreamClosed):
            self.closed = True
            if self.app_put is not None:
                if self.state in {ASGIWebsocketState.HTTPCLOSED, ASGIWebsocketState.CLOSED}:
                    code = CloseReason.NORMAL_CLOSURE.value
                else:
                    code = CloseReason.ABNORMAL_CLOSURE.value
                await self.app_put({"type": "websocket.disconnect", "code": code})

    async def app_send(self, message: Optional[ASGISendEvent]) -> None:
        if self.closed:
            # Allow app to finish after close
            return

        if message is None:  # ASGI App has finished sending messages
            # Cleanup if required
            if self.state == ASGIWebsocketState.HANDSHAKE:
                await self._send_error_response(500)
                await self.config.log.access(
                    self.scope, {"status": 500, "headers": []}, time() - self.start_time
                )
            elif self.state == ASGIWebsocketState.CONNECTED:
                await self._send_wsproto_event(CloseConnection(code=CloseReason.ABNORMAL_CLOSURE))
            await self.send(StreamClosed(stream_id=self.stream_id))
        else:
            if message["type"] == "websocket.accept" and self.state == ASGIWebsocketState.HANDSHAKE:
                await self._accept(message)
            elif (
                message["type"] == "websocket.http.response.start"
                and self.state == ASGIWebsocketState.HANDSHAKE
            ):
                self.response = message
            elif message["type"] == "websocket.http.response.body" and self.state in {
                ASGIWebsocketState.HANDSHAKE,
                ASGIWebsocketState.RESPONSE,
            }:
                await self._send_rejection(message)
            elif message["type"] == "websocket.send" and self.state == ASGIWebsocketState.CONNECTED:
                event: WSProtoEvent
                if message.get("bytes") is not None:
                    event = BytesMessage(data=bytes(message["bytes"]))
                elif not isinstance(message["text"], str):
                    raise TypeError(f"{message['text']} should be a str")
                else:
                    event = TextMessage(data=message["text"])
                await self._send_wsproto_event(event)
            elif (
                message["type"] == "websocket.close" and self.state == ASGIWebsocketState.HANDSHAKE
            ):
                self.state = ASGIWebsocketState.HTTPCLOSED
                await self._send_error_response(403)
            elif message["type"] == "websocket.close":
                self.state = ASGIWebsocketState.CLOSED
                await self._send_wsproto_event(
                    CloseConnection(code=int(message.get("code", CloseReason.NORMAL_CLOSURE)))
                )
                await self.send(EndData(stream_id=self.stream_id))
            else:
                raise UnexpectedMessage(self.state, message["type"])

    async def _handle_events(self) -> None:
        for event in self.connection.events():
            if isinstance(event, Message):
                try:
                    self.buffer.extend(event)
                except FrameTooLarge:
                    await self._send_wsproto_event(
                        CloseConnection(code=CloseReason.MESSAGE_TOO_BIG)
                    )
                    break

                if event.message_finished:
                    await self.app_put(self.buffer.to_message())
                    self.buffer.clear()
            elif isinstance(event, Ping):
                await self._send_wsproto_event(event.response())
            elif isinstance(event, CloseConnection):
                if self.connection.state == ConnectionState.REMOTE_CLOSING:
                    await self._send_wsproto_event(event.response())
                await self.send(StreamClosed(stream_id=self.stream_id))

    async def _send_error_response(self, status_code: int) -> None:
        await self.send(
            Response(
                stream_id=self.stream_id,
                status_code=status_code,
                headers=[(b"content-length", b"0"), (b"connection", b"close")],
            )
        )
        await self.send(EndBody(stream_id=self.stream_id))
        await self.config.log.access(
            self.scope, {"status": status_code, "headers": []}, time() - self.start_time
        )

    async def _send_wsproto_event(self, event: WSProtoEvent) -> None:
        data = self.connection.send(event)
        await self.send(Data(stream_id=self.stream_id, data=data))

    async def _accept(self, message: WebsocketAcceptEvent) -> None:
        self.state = ASGIWebsocketState.CONNECTED
        status_code, headers, self.connection = self.handshake.accept(message.get("subprotocol"))
        await self.send(
            Response(stream_id=self.stream_id, status_code=status_code, headers=headers)
        )
        await self.config.log.access(
            self.scope, {"status": status_code, "headers": []}, time() - self.start_time
        )
        if self.config.websocket_ping_interval is not None:
            self.context.spawn(self._send_pings)

    async def _send_rejection(self, message: WebsocketResponseBodyEvent) -> None:
        body_suppressed = suppress_body("GET", self.response["status"])
        if self.state == ASGIWebsocketState.HANDSHAKE:
            headers = build_and_validate_headers(self.response["headers"])
            await self.send(
                Response(
                    stream_id=self.stream_id,
                    status_code=int(self.response["status"]),
                    headers=headers,
                )
            )
            self.state = ASGIWebsocketState.RESPONSE
        if not body_suppressed:
            await self.send(Body(stream_id=self.stream_id, data=bytes(message.get("body", b""))))
        if not message.get("more_body", False):
            self.state = ASGIWebsocketState.HTTPCLOSED
            await self.send(EndBody(stream_id=self.stream_id))
            await self.config.log.access(self.scope, self.response, time() - self.start_time)

    async def _send_pings(self) -> None:
        while not self.closed:
            await self._send_wsproto_event(Ping())
            await self.context.sleep(self.config.websocket_ping_interval)

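The ASGIWebsocketState comment above refers to the ASGI `websocket.http.response` extension that this stream advertises in its scope, which lets an application answer the handshake with a plain HTTP response instead of accepting the websocket. As an illustrative sketch only (not part of this commit; the `app` name is hypothetical), the application side looks roughly like this:

# Sketch only: a minimal ASGI websocket app using the extension advertised above.
async def app(scope, receive, send):
    assert scope["type"] == "websocket"
    event = await receive()
    assert event["type"] == "websocket.connect"
    if "websocket.http.response" in scope.get("extensions", {}):
        # Reject the handshake with a normal HTTP response (drives _send_rejection above).
        await send({
            "type": "websocket.http.response.start",
            "status": 403,
            "headers": [(b"content-length", b"0")],
        })
        await send({"type": "websocket.http.response.body", "body": b"", "more_body": False})
    else:
        # Otherwise accept; hypercorn then completes the wsproto handshake in _accept().
        await send({"type": "websocket.accept", "subprotocol": None, "headers": []})
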
1
.venv/lib/python3.9/site-packages/hypercorn/py.typed
Normal file
@@ -0,0 +1 @@
Marker

80
.venv/lib/python3.9/site-packages/hypercorn/run.py
Normal file
@@ -0,0 +1,80 @@
import platform
import random
import signal
import time
from multiprocessing import Event, Process
from typing import Any

from .config import Config
from .typing import WorkerFunc
from .utils import write_pid_file


def run(config: Config) -> None:
    if config.pid_path is not None:
        write_pid_file(config.pid_path)

    worker_func: WorkerFunc
    if config.worker_class == "asyncio":
        from .asyncio.run import asyncio_worker

        worker_func = asyncio_worker
    elif config.worker_class == "uvloop":
        from .asyncio.run import uvloop_worker

        worker_func = uvloop_worker
    elif config.worker_class == "trio":
        from .trio.run import trio_worker

        worker_func = trio_worker
    else:
        raise ValueError(f"No worker of class {config.worker_class} exists")

    if config.workers == 1:
        worker_func(config)
    else:
        run_multiple(config, worker_func)


def run_multiple(config: Config, worker_func: WorkerFunc) -> None:
    if config.use_reloader:
        raise RuntimeError("Reloader can only be used with a single worker")

    sockets = config.create_sockets()

    processes = []

    # Ignore SIGINT before creating the processes, so that they
    # inherit the signal handling. This means that the shutdown
    # function controls the shutdown.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    shutdown_event = Event()

    for _ in range(config.workers):
        process = Process(
            target=worker_func,
            kwargs={"config": config, "shutdown_event": shutdown_event, "sockets": sockets},
        )
        process.daemon = True
        process.start()
        processes.append(process)
        if platform.system() == "Windows":
            time.sleep(0.1 * random.random())

    def shutdown(*args: Any) -> None:
        shutdown_event.set()

    for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}:
        if hasattr(signal, signal_name):
            signal.signal(getattr(signal, signal_name), shutdown)

    for process in processes:
        process.join()
    for process in processes:
        process.terminate()

    for sock in sockets.secure_sockets:
        sock.close()
    for sock in sockets.insecure_sockets:
        sock.close()

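run() above dispatches on config.worker_class and either runs the worker inline (one worker) or forks via run_multiple(). A hedged sketch of driving it programmatically, assuming the Config attributes used elsewhere in this commit (application_path, worker_class, workers, bind); the module path is a hypothetical placeholder:

# Sketch only: programmatic equivalent of `hypercorn module:app -k trio -w 2`.
from hypercorn.config import Config
from hypercorn.run import run

config = Config()
config.application_path = "module:app"  # hypothetical module:attribute path
config.worker_class = "trio"            # "asyncio", "uvloop" or "trio"
config.workers = 2                      # values > 1 route through run_multiple()
config.bind = ["127.0.0.1:8000"]

run(config)
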
93
.venv/lib/python3.9/site-packages/hypercorn/statsd.py
Normal file
@@ -0,0 +1,93 @@
from typing import Any, TYPE_CHECKING

from .logging import Logger

if TYPE_CHECKING:
    from .config import Config
    from .typing import ResponseSummary, WWWScope

METRIC_VAR = "metric"
VALUE_VAR = "value"
MTYPE_VAR = "mtype"
GAUGE_TYPE = "gauge"
COUNTER_TYPE = "counter"
HISTOGRAM_TYPE = "histogram"


class StatsdLogger(Logger):
    def __init__(self, config: "Config") -> None:
        super().__init__(config)
        self.dogstatsd_tags = config.dogstatsd_tags
        self.prefix = config.statsd_prefix
        if len(self.prefix) and self.prefix[-1] != ".":
            self.prefix += "."

    async def critical(self, message: str, *args: Any, **kwargs: Any) -> None:
        await super().critical(message, *args, **kwargs)
        await self.increment("hypercorn.log.critical", 1)

    async def error(self, message: str, *args: Any, **kwargs: Any) -> None:
        await super().error(message, *args, **kwargs)
        await self.increment("hypercorn.log.error", 1)

    async def warning(self, message: str, *args: Any, **kwargs: Any) -> None:
        await super().warning(message, *args, **kwargs)
        await self.increment("hypercorn.log.warning", 1)

    async def info(self, message: str, *args: Any, **kwargs: Any) -> None:
        await super().info(message, *args, **kwargs)

    async def debug(self, message: str, *args: Any, **kwargs: Any) -> None:
        await super().debug(message, *args, **kwargs)

    async def exception(self, message: str, *args: Any, **kwargs: Any) -> None:
        await super().exception(message, *args, **kwargs)
        await self.increment("hypercorn.log.exception", 1)

    async def log(self, level: int, message: str, *args: Any, **kwargs: Any) -> None:
        try:
            extra = kwargs.get("extra", None)
            if extra is not None:
                metric = extra.get(METRIC_VAR, None)
                value = extra.get(VALUE_VAR, None)
                type_ = extra.get(MTYPE_VAR, None)
                if metric and value and type_:
                    if type_ == GAUGE_TYPE:
                        await self.gauge(metric, value)
                    elif type_ == COUNTER_TYPE:
                        await self.increment(metric, value)
                    elif type_ == HISTOGRAM_TYPE:
                        await self.histogram(metric, value)

            if message:
                await super().log(level, message, *args, **kwargs)
        except Exception:
            await super().warning("Failed to log to statsd", exc_info=True)

    async def access(
        self, request: "WWWScope", response: "ResponseSummary", request_time: float
    ) -> None:
        await super().access(request, response, request_time)
        await self.histogram("hypercorn.request.duration", request_time * 1_000)
        await self.increment("hypercorn.requests", 1)
        await self.increment(f"hypercorn.request.status.{response['status']}", 1)

    async def gauge(self, name: str, value: int) -> None:
        await self._send(f"{self.prefix}{name}:{value}|g")

    async def increment(self, name: str, value: int, sampling_rate: float = 1.0) -> None:
        await self._send(f"{self.prefix}{name}:{value}|c|@{sampling_rate}")

    async def decrement(self, name: str, value: int, sampling_rate: float = 1.0) -> None:
        await self._send(f"{self.prefix}{name}:-{value}|c|@{sampling_rate}")

    async def histogram(self, name: str, value: float) -> None:
        await self._send(f"{self.prefix}{name}:{value}|ms")

    async def _send(self, message: str) -> None:
        if self.dogstatsd_tags:
            message = f"{message}|#{self.dogstatsd_tags}"
        await self._socket_send(message.encode("ascii"))

    async def _socket_send(self, message: bytes) -> None:
        raise NotImplementedError()

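The formatting methods above emit plain StatsD datagrams, with the configured prefix and optional Datadog tags appended by _send(); subclasses supply the actual transport via _socket_send(). A rough sketch of what a capturing subclass would see (metric names and values below are illustrative, not taken from this commit):

# Sketch only: capture what StatsdLogger would put on the wire.
class _CapturingStatsd(StatsdLogger):
    def __init__(self, config: "Config") -> None:
        super().__init__(config)
        self.sent: list = []

    async def _socket_send(self, message: bytes) -> None:
        self.sent.append(message)

# With statsd_prefix="hypercorn" the methods above produce datagrams such as:
#   b"hypercorn.requests:1|c|@1.0"          (increment -> counter with sample rate)
#   b"hypercorn.request.duration:12.5|ms"   (histogram -> timer)
#   b"hypercorn.queue.depth:3|g"            (gauge; hypothetical metric name)
# and, when dogstatsd_tags is set, "|#tag1,tag2" is appended by _send().
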
42
.venv/lib/python3.9/site-packages/hypercorn/trio/__init__.py
Normal file
@@ -0,0 +1,42 @@
import warnings
from typing import Awaitable, Callable, Optional

import trio

from .run import worker_serve
from ..config import Config
from ..typing import ASGIFramework


async def serve(
    app: ASGIFramework,
    config: Config,
    *,
    shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
    task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
) -> None:
    """Serve an ASGI framework app given the config.

    This allows for a programmatic way to serve an ASGI framework, it
    can be used via,

    .. code-block:: python

        trio.run(serve, app, config)

    It is assumed that the event-loop is configured before calling
    this function, therefore configuration values that relate to loop
    setup or process setup are ignored.

    Arguments:
        app: The ASGI application to serve.
        config: A Hypercorn configuration object.
        shutdown_trigger: This should return to trigger a graceful
            shutdown.
    """
    if config.debug:
        warnings.warn("The config `debug` has no affect when using serve", Warning)
    if config.workers != 1:
        warnings.warn("The config `workers` has no affect when using serve", Warning)

    await worker_serve(app, config, shutdown_trigger=shutdown_trigger, task_status=task_status)

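The docstring above already shows trio.run(serve, app, config); a slightly fuller hedged sketch follows, with a toy app and a shutdown trigger (the app, main and event names are illustrative, and config.bind is assumed to accept a list as used elsewhere in hypercorn):

# Sketch only: serve a trivial ASGI app on trio with a graceful-shutdown trigger.
import trio

from hypercorn.config import Config
from hypercorn.trio import serve


async def app(scope, receive, send):
    # Lifespan is not handled here; hypercorn logs the error and continues without it.
    assert scope["type"] == "http"
    await send({"type": "http.response.start", "status": 200,
                "headers": [(b"content-length", b"2")]})
    await send({"type": "http.response.body", "body": b"ok", "more_body": False})


async def main() -> None:
    config = Config()
    config.bind = ["127.0.0.1:8000"]
    shutdown = trio.Event()  # set this from another task to stop serving gracefully
    await serve(app, config, shutdown_trigger=shutdown.wait)


trio.run(main)
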
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
83
.venv/lib/python3.9/site-packages/hypercorn/trio/context.py
Normal file
@@ -0,0 +1,83 @@
from typing import Any, Awaitable, Callable, Optional, Type, Union

import trio

from ..config import Config
from ..typing import (
    ASGIFramework,
    ASGIReceiveCallable,
    ASGIReceiveEvent,
    ASGISendEvent,
    Event,
    Scope,
)
from ..utils import invoke_asgi


class EventWrapper:
    def __init__(self) -> None:
        self._event = trio.Event()

    async def clear(self) -> None:
        self._event = trio.Event()

    async def wait(self) -> None:
        await self._event.wait()

    async def set(self) -> None:
        self._event.set()


async def _handle(
    app: ASGIFramework,
    config: Config,
    scope: Scope,
    receive: ASGIReceiveCallable,
    send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
) -> None:
    try:
        await invoke_asgi(app, scope, receive, send)
    except trio.Cancelled:
        raise
    except trio.MultiError as error:
        errors = trio.MultiError.filter(
            lambda exc: None if isinstance(exc, trio.Cancelled) else exc, root_exc=error
        )
        if errors is not None:
            await config.log.exception("Error in ASGI Framework")
            await send(None)
        else:
            raise
    except Exception:
        await config.log.exception("Error in ASGI Framework")
    finally:
        await send(None)


class Context:
    event_class: Type[Event] = EventWrapper

    def __init__(self, nursery: trio._core._run.Nursery) -> None:
        self.nursery = nursery

    async def spawn_app(
        self,
        app: ASGIFramework,
        config: Config,
        scope: Scope,
        send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
    ) -> Callable[[ASGIReceiveEvent], Awaitable[None]]:
        app_send_channel, app_receive_channel = trio.open_memory_channel(config.max_app_queue_size)
        self.nursery.start_soon(_handle, app, config, scope, app_receive_channel.receive, send)
        return app_send_channel.send

    def spawn(self, func: Callable, *args: Any) -> None:
        self.nursery.start_soon(func, *args)

    @staticmethod
    async def sleep(wait: Union[float, int]) -> None:
        return await trio.sleep(wait)

    @staticmethod
    def time() -> float:
        return trio.current_time()

79
.venv/lib/python3.9/site-packages/hypercorn/trio/lifespan.py
Normal file
@@ -0,0 +1,79 @@
import trio

from ..config import Config
from ..typing import ASGIFramework, ASGIReceiveEvent, ASGISendEvent, LifespanScope
from ..utils import invoke_asgi, LifespanFailure, LifespanTimeout


class UnexpectedMessage(Exception):
    pass


class Lifespan:
    def __init__(self, app: ASGIFramework, config: Config) -> None:
        self.app = app
        self.config = config
        self.startup = trio.Event()
        self.shutdown = trio.Event()
        self.app_send_channel, self.app_receive_channel = trio.open_memory_channel(
            config.max_app_queue_size
        )
        self.supported = True

    async def handle_lifespan(
        self, *, task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED
    ) -> None:
        task_status.started()
        scope: LifespanScope = {"type": "lifespan", "asgi": {"spec_version": "2.0"}}
        try:
            await invoke_asgi(self.app, scope, self.asgi_receive, self.asgi_send)
        except LifespanFailure:
            # Lifespan failures should crash the server
            raise
        except Exception:
            self.supported = False
            await self.config.log.exception(
                "ASGI Framework Lifespan error, continuing without Lifespan support"
            )
        finally:
            self.startup.set()
            self.shutdown.set()
            await self.app_send_channel.aclose()
            await self.app_receive_channel.aclose()

    async def wait_for_startup(self) -> None:
        if not self.supported:
            return

        await self.app_send_channel.send({"type": "lifespan.startup"})
        try:
            with trio.fail_after(self.config.startup_timeout):
                await self.startup.wait()
        except trio.TooSlowError as error:
            raise LifespanTimeout("startup") from error

    async def wait_for_shutdown(self) -> None:
        if not self.supported:
            return

        await self.app_send_channel.send({"type": "lifespan.shutdown"})
        try:
            with trio.fail_after(self.config.shutdown_timeout):
                await self.shutdown.wait()
        except trio.TooSlowError as error:
            raise LifespanTimeout("shutdown") from error

    async def asgi_receive(self) -> ASGIReceiveEvent:
        return await self.app_receive_channel.receive()

    async def asgi_send(self, message: ASGISendEvent) -> None:
        if message["type"] == "lifespan.startup.complete":
            self.startup.set()
        elif message["type"] == "lifespan.shutdown.complete":
            self.shutdown.set()
        elif message["type"] == "lifespan.startup.failed":
            raise LifespanFailure("startup", message["message"])
        elif message["type"] == "lifespan.shutdown.failed":
            raise LifespanFailure("shutdown", message["message"])
        else:
            raise UnexpectedMessage(message["type"])

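Lifespan above drives the standard ASGI lifespan exchange: it sends lifespan.startup / lifespan.shutdown into the app and waits for the matching complete (or failed) messages. For reference, the application-side counterpart looks roughly like this (sketch only, not part of the commit):

# Sketch only: the app-side handler that the Lifespan class above talks to.
async def app(scope, receive, send):
    if scope["type"] != "lifespan":
        return
    while True:
        message = await receive()
        if message["type"] == "lifespan.startup":
            # open pools, start background tasks, etc.
            await send({"type": "lifespan.startup.complete"})
        elif message["type"] == "lifespan.shutdown":
            # tidy up resources
            await send({"type": "lifespan.shutdown.complete"})
            return
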
119
.venv/lib/python3.9/site-packages/hypercorn/trio/run.py
Normal file
@@ -0,0 +1,119 @@
from functools import partial
from multiprocessing.synchronize import Event as EventType
from typing import Awaitable, Callable, Optional

import trio

from .lifespan import Lifespan
from .statsd import StatsdLogger
from .tcp_server import TCPServer
from .udp_server import UDPServer
from ..config import Config, Sockets
from ..typing import ASGIFramework
from ..utils import (
    check_multiprocess_shutdown_event,
    load_application,
    MustReloadException,
    observe_changes,
    raise_shutdown,
    repr_socket_addr,
    restart,
    Shutdown,
)


async def worker_serve(
    app: ASGIFramework,
    config: Config,
    *,
    sockets: Optional[Sockets] = None,
    shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
    task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
) -> None:
    config.set_statsd_logger_class(StatsdLogger)

    lifespan = Lifespan(app, config)
    reload_ = False

    async with trio.open_nursery() as lifespan_nursery:
        await lifespan_nursery.start(lifespan.handle_lifespan)
        await lifespan.wait_for_startup()

        try:
            async with trio.open_nursery() as nursery:
                if config.use_reloader:
                    nursery.start_soon(observe_changes, trio.sleep)

                if shutdown_trigger is not None:
                    nursery.start_soon(raise_shutdown, shutdown_trigger)

                if sockets is None:
                    sockets = config.create_sockets()
                    for sock in sockets.secure_sockets:
                        sock.listen(config.backlog)
                    for sock in sockets.insecure_sockets:
                        sock.listen(config.backlog)

                ssl_context = config.create_ssl_context()
                listeners = []
                binds = []
                for sock in sockets.secure_sockets:
                    listeners.append(
                        trio.SSLListener(
                            trio.SocketListener(trio.socket.from_stdlib_socket(sock)),
                            ssl_context,
                            https_compatible=True,
                        )
                    )
                    bind = repr_socket_addr(sock.family, sock.getsockname())
                    binds.append(f"https://{bind}")
                    await config.log.info(f"Running on https://{bind} (CTRL + C to quit)")

                for sock in sockets.insecure_sockets:
                    listeners.append(trio.SocketListener(trio.socket.from_stdlib_socket(sock)))
                    bind = repr_socket_addr(sock.family, sock.getsockname())
                    binds.append(f"http://{bind}")
                    await config.log.info(f"Running on http://{bind} (CTRL + C to quit)")

                for sock in sockets.quic_sockets:
                    await nursery.start(UDPServer(app, config, sock, nursery).run)
                    bind = repr_socket_addr(sock.family, sock.getsockname())
                    await config.log.info(f"Running on https://{bind} (QUIC) (CTRL + C to quit)")

                task_status.started(binds)
                await trio.serve_listeners(
                    partial(TCPServer, app, config), listeners, handler_nursery=lifespan_nursery
                )

        except MustReloadException:
            reload_ = True
        except (Shutdown, KeyboardInterrupt):
            pass
        finally:
            try:
                await trio.sleep(config.graceful_timeout)
            except (Shutdown, KeyboardInterrupt):
                pass

            await lifespan.wait_for_shutdown()
            lifespan_nursery.cancel_scope.cancel()

    if reload_:
        restart()


def trio_worker(
    config: Config, sockets: Optional[Sockets] = None, shutdown_event: Optional[EventType] = None
) -> None:
    if sockets is not None:
        for sock in sockets.secure_sockets:
            sock.listen(config.backlog)
        for sock in sockets.insecure_sockets:
            sock.listen(config.backlog)
    app = load_application(config.application_path)

    shutdown_trigger = None
    if shutdown_event is not None:
        shutdown_trigger = partial(check_multiprocess_shutdown_event, shutdown_event, trio.sleep)

    trio.run(partial(worker_serve, app, config, sockets=sockets, shutdown_trigger=shutdown_trigger))

14
.venv/lib/python3.9/site-packages/hypercorn/trio/statsd.py
Normal file
@@ -0,0 +1,14 @@
import trio

from ..config import Config
from ..statsd import StatsdLogger as Base


class StatsdLogger(Base):
    def __init__(self, config: Config) -> None:
        super().__init__(config)
        self.address = config.statsd_host.rsplit(":", 1)
        self.socket = trio.socket.socket(trio.socket.AF_INET, trio.socket.SOCK_DGRAM)

    async def _socket_send(self, message: bytes) -> None:
        await self.socket.sendto(message, self.address)

152
.venv/lib/python3.9/site-packages/hypercorn/trio/tcp_server.py
Normal file
@@ -0,0 +1,152 @@
from typing import Any, Callable, Generator, Optional

import trio

from .context import Context
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..protocol import ProtocolWrapper
from ..typing import ASGIFramework
from ..utils import parse_socket_addr

MAX_RECV = 2 ** 16


class EventWrapper:
    def __init__(self) -> None:
        self._event = trio.Event()

    async def clear(self) -> None:
        self._event = trio.Event()

    async def wait(self) -> None:
        await self._event.wait()

    async def set(self) -> None:
        self._event.set()


class TCPServer:
    def __init__(self, app: ASGIFramework, config: Config, stream: trio.abc.Stream) -> None:
        self.app = app
        self.config = config
        self.protocol: ProtocolWrapper
        self.send_lock = trio.Lock()
        self.timeout_lock = trio.Lock()
        self.stream = stream

        self._keep_alive_timeout_handle: Optional[trio.CancelScope] = None

    def __await__(self) -> Generator[Any, None, None]:
        return self.run().__await__()

    async def run(self) -> None:
        try:
            try:
                with trio.fail_after(self.config.ssl_handshake_timeout):
                    await self.stream.do_handshake()
            except (trio.BrokenResourceError, trio.TooSlowError):
                return  # Handshake failed
            alpn_protocol = self.stream.selected_alpn_protocol()
            socket = self.stream.transport_stream.socket
            ssl = True
        except AttributeError:  # Not SSL
            alpn_protocol = "http/1.1"
            socket = self.stream.socket
            ssl = False

        try:
            client = parse_socket_addr(socket.family, socket.getpeername())
            server = parse_socket_addr(socket.family, socket.getsockname())

            async with trio.open_nursery() as nursery:
                self.nursery = nursery
                context = Context(nursery)
                self.protocol = ProtocolWrapper(
                    self.app,
                    self.config,
                    context,
                    ssl,
                    client,
                    server,
                    self.protocol_send,
                    alpn_protocol,
                )
                await self.protocol.initiate()
                await self._update_keep_alive_timeout()
                await self._read_data()
        except (trio.MultiError, OSError):
            pass
        finally:
            await self._close()

    async def protocol_send(self, event: Event) -> None:
        if isinstance(event, RawData):
            async with self.send_lock:
                try:
                    with trio.CancelScope() as cancel_scope:
                        cancel_scope.shield = True
                        await self.stream.send_all(event.data)
                except (trio.BrokenResourceError, trio.ClosedResourceError):
                    await self.protocol.handle(Closed())
        elif isinstance(event, Closed):
            await self._close()
            await self.protocol.handle(Closed())
        elif isinstance(event, Updated):
            pass  # Triggers the keep alive timeout update
        await self._update_keep_alive_timeout()

    async def _read_data(self) -> None:
        while True:
            try:
                data = await self.stream.receive_some(MAX_RECV)
            except (trio.ClosedResourceError, trio.BrokenResourceError):
                await self.protocol.handle(Closed())
                break
            else:
                if data == b"":
                    await self._update_keep_alive_timeout()
                    break
                await self.protocol.handle(RawData(data))
                await self._update_keep_alive_timeout()

    async def _close(self) -> None:
        try:
            await self.stream.send_eof()
        except (
            trio.BrokenResourceError,
            AttributeError,
            trio.BusyResourceError,
            trio.ClosedResourceError,
        ):
            # They're already gone, nothing to do
            # Or it is a SSL stream
            pass
        await self.stream.aclose()

    async def _update_keep_alive_timeout(self) -> None:
        async with self.timeout_lock:
            if self._keep_alive_timeout_handle is not None:
                self._keep_alive_timeout_handle.cancel()
            self._keep_alive_timeout_handle = None
            if self.protocol.idle:
                self._keep_alive_timeout_handle = await self.nursery.start(
                    _call_later, self.config.keep_alive_timeout, self._timeout
                )

    async def _timeout(self) -> None:
        await self.protocol.handle(Closed())
        await self.stream.aclose()


async def _call_later(
    timeout: float,
    callback: Callable,
    task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
) -> None:
    cancel_scope = trio.CancelScope()
    task_status.started(cancel_scope)
    with cancel_scope:
        await trio.sleep(timeout)
        cancel_scope.shield = True
        await callback()

40
.venv/lib/python3.9/site-packages/hypercorn/trio/udp_server.py
Normal file
@@ -0,0 +1,40 @@
import trio

from .context import Context
from ..config import Config
from ..events import Event, RawData
from ..typing import ASGIFramework
from ..utils import parse_socket_addr

MAX_RECV = 2 ** 16


class UDPServer:
    def __init__(
        self,
        app: ASGIFramework,
        config: Config,
        socket: trio.socket.socket,
        nursery: trio._core._run.Nursery,
    ) -> None:
        from ..protocol.quic import QuicProtocol  # h3/Quic is an optional part of Hypercorn

        self.app = app
        self.config = config
        self.nursery = nursery
        self.socket = trio.socket.from_stdlib_socket(socket)
        context = Context(nursery)
        server = parse_socket_addr(socket.family, socket.getsockname())
        self.protocol = QuicProtocol(self.app, self.config, context, server, self.protocol_send)

    async def run(
        self, task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED
    ) -> None:
        task_status.started()
        while True:
            data, address = await self.socket.recvfrom(MAX_RECV)
            await self.protocol.handle(RawData(data=data, address=address))

    async def protocol_send(self, event: Event) -> None:
        if isinstance(event, RawData):
            await self.socket.sendto(event.data, event.address)

308
.venv/lib/python3.9/site-packages/hypercorn/typing.py
Normal file
@@ -0,0 +1,308 @@
from multiprocessing.synchronize import Event as EventType
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional, Tuple, Type, Union

import h2.events
import h11

# Till PEP 544 is accepted
try:
    from typing import Literal, Protocol, TypedDict
except ImportError:
    from typing_extensions import Literal, Protocol, TypedDict  # type: ignore

from .config import Config, Sockets

H11SendableEvent = Union[h11.Data, h11.EndOfMessage, h11.InformationalResponse, h11.Response]

WorkerFunc = Callable[[Config, Optional[Sockets], Optional[EventType]], None]


class ASGIVersions(TypedDict, total=False):
    spec_version: str
    version: Union[Literal["2.0"], Literal["3.0"]]


class HTTPScope(TypedDict):
    type: Literal["http"]
    asgi: ASGIVersions
    http_version: str
    method: str
    scheme: str
    path: str
    raw_path: bytes
    query_string: bytes
    root_path: str
    headers: Iterable[Tuple[bytes, bytes]]
    client: Optional[Tuple[str, int]]
    server: Optional[Tuple[str, Optional[int]]]
    extensions: Dict[str, dict]


class WebsocketScope(TypedDict):
    type: Literal["websocket"]
    asgi: ASGIVersions
    http_version: str
    scheme: str
    path: str
    raw_path: bytes
    query_string: bytes
    root_path: str
    headers: Iterable[Tuple[bytes, bytes]]
    client: Optional[Tuple[str, int]]
    server: Optional[Tuple[str, Optional[int]]]
    subprotocols: Iterable[str]
    extensions: Dict[str, dict]


class LifespanScope(TypedDict):
    type: Literal["lifespan"]
    asgi: ASGIVersions


WWWScope = Union[HTTPScope, WebsocketScope]
Scope = Union[HTTPScope, WebsocketScope, LifespanScope]


class HTTPRequestEvent(TypedDict):
    type: Literal["http.request"]
    body: bytes
    more_body: bool


class HTTPResponseStartEvent(TypedDict):
    type: Literal["http.response.start"]
    status: int
    headers: Iterable[Tuple[bytes, bytes]]


class HTTPResponseBodyEvent(TypedDict):
    type: Literal["http.response.body"]
    body: bytes
    more_body: bool


class HTTPServerPushEvent(TypedDict):
    type: Literal["http.response.push"]
    path: str
    headers: Iterable[Tuple[bytes, bytes]]


class HTTPDisconnectEvent(TypedDict):
    type: Literal["http.disconnect"]


class WebsocketConnectEvent(TypedDict):
    type: Literal["websocket.connect"]


class WebsocketAcceptEvent(TypedDict):
    type: Literal["websocket.accept"]
    subprotocol: Optional[str]
    headers: Iterable[Tuple[bytes, bytes]]


class WebsocketReceiveEvent(TypedDict):
    type: Literal["websocket.receive"]
    bytes: Optional[bytes]
    text: Optional[str]


class WebsocketSendEvent(TypedDict):
    type: Literal["websocket.send"]
    bytes: Optional[bytes]
    text: Optional[str]


class WebsocketResponseStartEvent(TypedDict):
    type: Literal["websocket.http.response.start"]
    status: int
    headers: Iterable[Tuple[bytes, bytes]]


class WebsocketResponseBodyEvent(TypedDict):
    type: Literal["websocket.http.response.body"]
    body: bytes
    more_body: bool


class WebsocketDisconnectEvent(TypedDict):
    type: Literal["websocket.disconnect"]
    code: int


class WebsocketCloseEvent(TypedDict):
    type: Literal["websocket.close"]
    code: int


class LifespanStartupEvent(TypedDict):
    type: Literal["lifespan.startup"]


class LifespanShutdownEvent(TypedDict):
    type: Literal["lifespan.shutdown"]


class LifespanStartupCompleteEvent(TypedDict):
    type: Literal["lifespan.startup.complete"]


class LifespanStartupFailedEvent(TypedDict):
    type: Literal["lifespan.startup.failed"]
    message: str


class LifespanShutdownCompleteEvent(TypedDict):
    type: Literal["lifespan.shutdown.complete"]


class LifespanShutdownFailedEvent(TypedDict):
    type: Literal["lifespan.shutdown.failed"]
    message: str


ASGIReceiveEvent = Union[
    HTTPRequestEvent,
    HTTPDisconnectEvent,
    WebsocketConnectEvent,
    WebsocketReceiveEvent,
    WebsocketDisconnectEvent,
    LifespanStartupEvent,
    LifespanShutdownEvent,
]


ASGISendEvent = Union[
    HTTPResponseStartEvent,
    HTTPResponseBodyEvent,
    HTTPServerPushEvent,
    HTTPDisconnectEvent,
    WebsocketAcceptEvent,
    WebsocketSendEvent,
    WebsocketResponseStartEvent,
    WebsocketResponseBodyEvent,
    WebsocketCloseEvent,
    LifespanStartupCompleteEvent,
    LifespanStartupFailedEvent,
    LifespanShutdownCompleteEvent,
    LifespanShutdownFailedEvent,
]


ASGIReceiveCallable = Callable[[], Awaitable[ASGIReceiveEvent]]
ASGISendCallable = Callable[[ASGISendEvent], Awaitable[None]]


class ASGI2Protocol(Protocol):
    # Should replace with a Protocol when PEP 544 is accepted.

    def __init__(self, scope: Scope) -> None:
        ...

    async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
        ...


ASGI2Framework = Type[ASGI2Protocol]
ASGI3Framework = Callable[
    [
        Scope,
        ASGIReceiveCallable,
        ASGISendCallable,
    ],
    Awaitable[None],
]
ASGIFramework = Union[ASGI2Framework, ASGI3Framework]


class H2SyncStream(Protocol):
    scope: dict

    def data_received(self, data: bytes) -> None:
        ...

    def ended(self) -> None:
        ...

    def reset(self) -> None:
        ...

    def close(self) -> None:
        ...

    async def handle_request(
        self,
        event: h2.events.RequestReceived,
        scheme: str,
        client: Tuple[str, int],
        server: Tuple[str, int],
    ) -> None:
        ...


class H2AsyncStream(Protocol):
    scope: dict

    async def data_received(self, data: bytes) -> None:
        ...

    async def ended(self) -> None:
        ...

    async def reset(self) -> None:
        ...

    async def close(self) -> None:
        ...

    async def handle_request(
        self,
        event: h2.events.RequestReceived,
        scheme: str,
        client: Tuple[str, int],
        server: Tuple[str, int],
    ) -> None:
        ...


class Event(Protocol):
    def __init__(self) -> None:
        ...

    async def clear(self) -> None:
        ...

    async def set(self) -> None:
        ...

    async def wait(self) -> None:
        ...


class Context(Protocol):
    event_class: Type[Event]

    async def spawn_app(
        self,
        app: ASGIFramework,
        config: Config,
        scope: Scope,
        send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
    ) -> Callable[[ASGIReceiveEvent], Awaitable[None]]:
        ...

    def spawn(self, func: Callable, *args: Any) -> None:
        ...

    @staticmethod
    async def sleep(wait: Union[float, int]) -> None:
        ...

    @staticmethod
    def time() -> float:
        ...


class ResponseSummary(TypedDict):
    status: int
    headers: Iterable[Tuple[bytes, bytes]]

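ASGIFramework above covers both calling conventions, and _is_asgi_2() in utils.py (below) picks between them at call time. For illustration only (sketch, not part of the commit), the same trivial app in both shapes:

# Sketch only: the two calling conventions covered by ASGIFramework.

class ASGI2App:
    # ASGI 2: a class instantiated once per scope, whose instance is then awaited.
    def __init__(self, scope):
        self.scope = scope

    async def __call__(self, receive, send):
        await send({"type": "http.response.start", "status": 204, "headers": []})
        await send({"type": "http.response.body", "body": b"", "more_body": False})


async def asgi3_app(scope, receive, send):
    # ASGI 3: a single coroutine function taking scope, receive and send.
    await send({"type": "http.response.start", "status": 204, "headers": []})
    await send({"type": "http.response.body", "body": b"", "more_body": False})
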
261
.venv/lib/python3.9/site-packages/hypercorn/utils.py
Normal file
@@ -0,0 +1,261 @@
import inspect
import os
import platform
import socket
import sys
from enum import Enum
from importlib import import_module
from multiprocessing.synchronize import Event as EventType
from pathlib import Path
from typing import (
    Any,
    Awaitable,
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    TYPE_CHECKING,
)

from .config import Config
from .typing import (
    ASGI2Framework,
    ASGI3Framework,
    ASGIFramework,
    ASGIReceiveCallable,
    ASGISendCallable,
    Scope,
)

if TYPE_CHECKING:
    from .protocol.events import Request


class Shutdown(Exception):
    pass


class MustReloadException(Exception):
    pass


class NoAppException(Exception):
    pass


class LifespanTimeout(Exception):
    def __init__(self, stage: str) -> None:
        super().__init__(
            f"Timeout whilst awaiting {stage}. Your application may not support the ASGI Lifespan "
            f"protocol correctly, alternatively the {stage}_timeout configuration is incorrect."
        )


class LifespanFailure(Exception):
    def __init__(self, stage: str, message: str) -> None:
        super().__init__(f"Lifespan failure in {stage}. '{message}'")


class UnexpectedMessage(Exception):
    def __init__(self, state: Enum, message_type: str) -> None:
        super().__init__(f"Unexpected message type, {message_type} given the state {state}")


class FrameTooLarge(Exception):
    pass


def suppress_body(method: str, status_code: int) -> bool:
    return method == "HEAD" or 100 <= status_code < 200 or status_code in {204, 304, 412}


def build_and_validate_headers(headers: Iterable[Tuple[bytes, bytes]]) -> List[Tuple[bytes, bytes]]:
    # Validates that the header name and value are bytes
    validated_headers: List[Tuple[bytes, bytes]] = []
    for name, value in headers:
        if name[0] == b":"[0]:
            raise ValueError("Pseudo headers are not valid")
        validated_headers.append((bytes(name).lower().strip(), bytes(value).strip()))
    return validated_headers


def filter_pseudo_headers(headers: List[Tuple[bytes, bytes]]) -> List[Tuple[bytes, bytes]]:
    filtered_headers: List[Tuple[bytes, bytes]] = [(b"host", b"")]  # Placeholder
    for name, value in headers:
        if name == b":authority":  # h2 & h3 libraries validate this is present
            filtered_headers[0] = (b"host", value)
        elif name[0] != b":"[0]:
            filtered_headers.append((name, value))
    return filtered_headers


def load_application(path: str) -> ASGIFramework:
    try:
        module_name, app_name = path.split(":", 1)
    except ValueError:
        module_name, app_name = path, "app"
    except AttributeError:
        raise NoAppException()

    module_path = Path(module_name).resolve()
    sys.path.insert(0, str(module_path.parent))
    if module_path.is_file():
        import_name = module_path.with_suffix("").name
    else:
        import_name = module_path.name
    try:
        module = import_module(import_name)
    except ModuleNotFoundError as error:
        if error.name == import_name:
            raise NoAppException()
        else:
            raise

    try:
        return eval(app_name, vars(module))
    except NameError:
        raise NoAppException()


async def observe_changes(sleep: Callable[[float], Awaitable[Any]]) -> None:
    last_updates: Dict[Path, float] = {}
    for module in list(sys.modules.values()):
        filename = getattr(module, "__file__", None)
        if filename is None:
            continue
        path = Path(filename)
        try:
            last_updates[Path(filename)] = path.stat().st_mtime
        except (FileNotFoundError, NotADirectoryError):
            pass

    while True:
        await sleep(1)

        for index, (path, last_mtime) in enumerate(last_updates.items()):
            if index % 10 == 0:
                # Yield to the event loop
                await sleep(0)

            try:
                mtime = path.stat().st_mtime
            except FileNotFoundError:
                # File deleted
                raise MustReloadException()
            else:
                if mtime > last_mtime:
                    raise MustReloadException()
                else:
                    last_updates[path] = mtime


def restart() -> None:
    # Restart this process (only safe for dev/debug)
    executable = sys.executable
    script_path = Path(sys.argv[0]).resolve()
    args = sys.argv[1:]
    main_package = sys.modules["__main__"].__package__

    if main_package is None:
        # Executed by filename
        if platform.system() == "Windows":
            if not script_path.exists() and script_path.with_suffix(".exe").exists():
                # quart run
                executable = str(script_path.with_suffix(".exe"))
            else:
                # python run.py
                args.append(str(script_path))
        else:
            if script_path.is_file() and os.access(script_path, os.X_OK):
                # hypercorn run:app --reload
                executable = str(script_path)
            else:
                # python run.py
                args.append(str(script_path))
    else:
        # Executed as a module e.g. python -m run
        module = script_path.stem
        import_name = main_package
        if module != "__main__":
            import_name = f"{main_package}.{module}"
        args[:0] = ["-m", import_name.lstrip(".")]

    os.execv(executable, [executable] + args)


async def raise_shutdown(shutdown_event: Callable[..., Awaitable[None]]) -> None:
    await shutdown_event()
    raise Shutdown()


async def check_multiprocess_shutdown_event(
    shutdown_event: EventType, sleep: Callable[[float], Awaitable[Any]]
) -> None:
    while True:
        if shutdown_event.is_set():
            return
        await sleep(0.1)


def write_pid_file(pid_path: str) -> None:
    with open(pid_path, "w") as file_:
        file_.write(f"{os.getpid()}")


def parse_socket_addr(family: int, address: tuple) -> Optional[Tuple[str, int]]:
    if family == socket.AF_INET:
        return address  # type: ignore
    elif family == socket.AF_INET6:
        return (address[0], address[1])
    else:
        return None


def repr_socket_addr(family: int, address: tuple) -> str:
    if family == socket.AF_INET:
        return f"{address[0]}:{address[1]}"
    elif family == socket.AF_INET6:
        return f"[{address[0]}]:{address[1]}"
    elif family == socket.AF_UNIX:
        return f"unix:{address}"
    else:
        return f"{address}"


async def invoke_asgi(
    app: ASGIFramework, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
    if _is_asgi_2(app):
        scope["asgi"]["version"] = "2.0"
        app = cast(ASGI2Framework, app)
        asgi_instance = app(scope)
        await asgi_instance(receive, send)
    else:
        scope["asgi"]["version"] = "3.0"
        app = cast(ASGI3Framework, app)
        await app(scope, receive, send)


def _is_asgi_2(app: ASGIFramework) -> bool:
    if inspect.isclass(app):
        return True

    if hasattr(app, "__call__") and inspect.iscoroutinefunction(app.__call__):  # type: ignore
        return False

    return not inspect.iscoroutinefunction(app)


def valid_server_name(config: Config, request: "Request") -> bool:
    if len(config.server_names) == 0:
        return True

    host = ""
    for name, value in request.headers:
        if name.lower() == b"host":
            host = value.decode()
            break
    return host in config.server_names