This commit is contained in:
Untriex Programming
2021-03-17 08:57:57 +01:00
parent 339be0ccd8
commit ed6afdb5c9
3074 changed files with 423348 additions and 0 deletions

View File

@@ -0,0 +1 @@
__version__ = "0.11.2"

View File

@@ -0,0 +1,4 @@
from .__about__ import __version__
from .config import Config
__all__ = ("__version__", "Config")

View File

@@ -0,0 +1,271 @@
import argparse
import ssl
import sys
import warnings
from typing import List, Optional
from .config import Config
from .run import run
sentinel = object()
def _load_config(config_path: Optional[str]) -> Config:
if config_path is None:
return Config()
elif config_path.startswith("python:"):
return Config.from_object(config_path[len("python:") :])
elif config_path.startswith("file:"):
return Config.from_pyfile(config_path[len("file:") :])
else:
return Config.from_toml(config_path)
def main(sys_args: Optional[List[str]] = None) -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"application", help="The application to dispatch to as path.to.module:instance.path"
)
parser.add_argument("--access-log", help="Deprecated, see access-logfile", default=sentinel)
parser.add_argument(
"--access-logfile",
help="The target location for the access log, use `-` for stdout",
default=sentinel,
)
parser.add_argument(
"--access-logformat",
help="The log format for the access log, see help docs",
default=sentinel,
)
parser.add_argument(
"--backlog", help="The maximum number of pending connections", type=int, default=sentinel
)
parser.add_argument(
"-b",
"--bind",
dest="binds",
help=""" The TCP host/address to bind to. Should be either host:port, host,
unix:path or fd://num, e.g. 127.0.0.1:5000, 127.0.0.1,
unix:/tmp/socket or fd://33 respectively. """,
default=[],
action="append",
)
parser.add_argument("--ca-certs", help="Path to the SSL CA certificate file", default=sentinel)
parser.add_argument("--certfile", help="Path to the SSL certificate file", default=sentinel)
parser.add_argument("--cert-reqs", help="See verify mode argument", type=int, default=sentinel)
parser.add_argument("--ciphers", help="Ciphers to use for the SSL setup", default=sentinel)
parser.add_argument(
"-c",
"--config",
help="Location of a TOML config file, or when prefixed with `file:` a Python file, or when prefixed with `python:` a Python module.", # noqa: E501
default=None,
)
parser.add_argument(
"--debug",
help="Enable debug mode, i.e. extra logging and checks",
action="store_true",
default=sentinel,
)
parser.add_argument("--error-log", help="Deprecated, see error-logfile", default=sentinel)
parser.add_argument(
"--error-logfile",
"--log-file",
dest="error_logfile",
help="The target location for the error log, use `-` for stderr",
default=sentinel,
)
parser.add_argument(
"--graceful-timeout",
help="""Time to wait after SIGTERM or Ctrl-C for any remaining requests (tasks)
to complete.""",
default=sentinel,
type=int,
)
parser.add_argument(
"-g", "--group", help="Group to own any unix sockets.", default=sentinel, type=int
)
parser.add_argument(
"-k",
"--worker-class",
dest="worker_class",
help="The type of worker to use. "
"Options include asyncio, uvloop (pip install hypercorn[uvloop]), "
"and trio (pip install hypercorn[trio]).",
default=sentinel,
)
parser.add_argument(
"--keep-alive",
help="Seconds to keep inactive connections alive for",
default=sentinel,
type=int,
)
parser.add_argument("--keyfile", help="Path to the SSL key file", default=sentinel)
parser.add_argument(
"--insecure-bind",
dest="insecure_binds",
help="""The TCP host/address to bind to. SSL options will not apply to these binds.
See *bind* for formatting options. Care must be taken! See HTTP -> HTTPS redirection docs.
""",
default=[],
action="append",
)
parser.add_argument(
"--log-config", help="A Python logging configuration file.", default=sentinel
)
parser.add_argument(
"--log-level", help="The (error) log level, defaults to info", default="INFO"
)
parser.add_argument(
"-p", "--pid", help="Location to write the PID (Program ID) to.", default=sentinel
)
parser.add_argument(
"--quic-bind",
dest="quic_binds",
help="""The UDP/QUIC host/address to bind to. See *bind* for formatting
options.
""",
default=[],
action="append",
)
parser.add_argument(
"--reload",
help="Enable automatic reloads on code changes",
action="store_true",
default=sentinel,
)
parser.add_argument(
"--root-path", help="The setting for the ASGI root_path variable", default=sentinel
)
parser.add_argument(
"--server-name",
dest="server_names",
help="""The hostnames that can be served, requests to different hosts
will be responded to with 404s.
""",
default=[],
action="append",
)
parser.add_argument(
"--statsd-host", help="The host:port of the statsd server", default=sentinel
)
parser.add_argument("--statsd-prefix", help="Prefix for all statsd messages", default="")
parser.add_argument(
"-m",
"--umask",
help="The permissions bit mask to use on any unix sockets.",
default=sentinel,
type=int,
)
parser.add_argument(
"-u", "--user", help="User to own any unix sockets.", default=sentinel, type=int
)
def _convert_verify_mode(value: str) -> ssl.VerifyMode:
try:
return ssl.VerifyMode[value]
except KeyError:
raise argparse.ArgumentTypeError(f"'{value}' is not a valid verify mode")
parser.add_argument(
"--verify-mode",
help="SSL verify mode for peer's certificate, see ssl.VerifyMode enum for possible values.",
type=_convert_verify_mode,
default=sentinel,
)
parser.add_argument(
"--websocket-ping-interval",
help="""If set this is the time in seconds between pings sent to the client.
This can be used to keep the websocket connection alive.""",
default=sentinel,
type=int,
)
parser.add_argument(
"-w",
"--workers",
dest="workers",
help="The number of workers to spawn and use",
default=sentinel,
type=int,
)
args = parser.parse_args(sys_args or sys.argv[1:])
config = _load_config(args.config)
config.application_path = args.application
config.loglevel = args.log_level
if args.access_logformat is not sentinel:
config.access_log_format = args.access_logformat
if args.access_log is not sentinel:
warnings.warn(
"The --access-log argument is deprecated, use `--access-logfile` instead",
DeprecationWarning,
)
config.accesslog = args.access_log
if args.access_logfile is not sentinel:
config.accesslog = args.access_logfile
if args.backlog is not sentinel:
config.backlog = args.backlog
if args.ca_certs is not sentinel:
config.ca_certs = args.ca_certs
if args.certfile is not sentinel:
config.certfile = args.certfile
if args.cert_reqs is not sentinel:
config.cert_reqs = args.cert_reqs
if args.ciphers is not sentinel:
config.ciphers = args.ciphers
if args.debug is not sentinel:
config.debug = args.debug
if args.error_log is not sentinel:
warnings.warn(
"The --error-log argument is deprecated, use `--error-logfile` instead",
DeprecationWarning,
)
config.errorlog = args.error_log
if args.error_logfile is not sentinel:
config.errorlog = args.error_logfile
if args.graceful_timeout is not sentinel:
config.graceful_timeout = args.graceful_timeout
if args.group is not sentinel:
config.group = args.group
if args.keep_alive is not sentinel:
config.keep_alive_timeout = args.keep_alive
if args.keyfile is not sentinel:
config.keyfile = args.keyfile
if args.log_config is not sentinel:
config.logconfig = args.log_config
if args.pid is not sentinel:
config.pid_path = args.pid
if args.root_path is not sentinel:
config.root_path = args.root_path
if args.reload is not sentinel:
config.use_reloader = args.reload
if args.statsd_host is not sentinel:
config.statsd_host = args.statsd_host
if args.statsd_prefix is not sentinel:
config.statsd_prefix = args.statsd_prefix
if args.umask is not sentinel:
config.umask = args.umask
if args.user is not sentinel:
config.user = args.user
if args.worker_class is not sentinel:
config.worker_class = args.worker_class
if args.verify_mode is not sentinel:
config.verify_mode = args.verify_mode
if args.websocket_ping_interval is not sentinel:
config.websocket_ping_interval = args.websocket_ping_interval
if args.workers is not sentinel:
config.workers = args.workers
if len(args.binds) > 0:
config.bind = args.binds
if len(args.insecure_binds) > 0:
config.insecure_bind = args.insecure_binds
if len(args.quic_binds) > 0:
config.quic_bind = args.quic_binds
if len(args.server_names) > 0:
config.server_names = args.server_names
run(config)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,39 @@
import warnings
from typing import Awaitable, Callable, Optional
from .run import worker_serve
from ..config import Config
from ..typing import ASGIFramework
async def serve(
app: ASGIFramework,
config: Config,
*,
shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
) -> None:
"""Serve an ASGI framework app given the config.
This allows for a programmatic way to serve an ASGI framework, it
can be used via,
.. code-block:: python
asyncio.run(serve(app, config))
It is assumed that the event-loop is configured before calling
this function, therefore configuration values that relate to loop
setup or process setup are ignored.
Arguments:
app: The ASGI application to serve.
config: A Hypercorn configuration object.
shutdown_trigger: This should return to trigger a graceful
shutdown.
"""
if config.debug:
warnings.warn("The config `debug` has no affect when using serve", Warning)
if config.workers != 1:
warnings.warn("The config `workers` has no affect when using serve", Warning)
await worker_serve(app, config, shutdown_trigger=shutdown_trigger)

View File

@@ -0,0 +1,74 @@
import asyncio
from typing import Any, Awaitable, Callable, Optional, Type, Union
from .task_group import TaskGroup
from ..config import Config
from ..typing import (
ASGIFramework,
ASGIReceiveCallable,
ASGIReceiveEvent,
ASGISendEvent,
Event,
Scope,
)
from ..utils import invoke_asgi
class EventWrapper:
def __init__(self) -> None:
self._event = asyncio.Event()
async def clear(self) -> None:
self._event.clear()
async def wait(self) -> None:
await self._event.wait()
async def set(self) -> None:
self._event.set()
async def _handle(
app: ASGIFramework,
config: Config,
scope: Scope,
receive: ASGIReceiveCallable,
send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
) -> None:
try:
await invoke_asgi(app, scope, receive, send)
except asyncio.CancelledError:
raise
except Exception:
await config.log.exception("Error in ASGI Framework")
finally:
await send(None)
class Context:
event_class: Type[Event] = EventWrapper
def __init__(self, task_group: TaskGroup) -> None:
self.task_group = task_group
async def spawn_app(
self,
app: ASGIFramework,
config: Config,
scope: Scope,
send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
) -> Callable[[ASGIReceiveEvent], Awaitable[None]]:
app_queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue(config.max_app_queue_size)
self.task_group.spawn(_handle(app, config, scope, app_queue.get, send))
return app_queue.put
def spawn(self, func: Callable, *args: Any) -> None:
self.task_group.spawn(func(*args))
@staticmethod
async def sleep(wait: Union[float, int]) -> None:
return await asyncio.sleep(wait)
@staticmethod
def time() -> float:
return asyncio.get_event_loop().time()

View File

@@ -0,0 +1,85 @@
import asyncio
from ..config import Config
from ..typing import ASGIFramework, ASGIReceiveEvent, ASGISendEvent, LifespanScope
from ..utils import invoke_asgi, LifespanFailure, LifespanTimeout
class UnexpectedMessage(Exception):
pass
class Lifespan:
def __init__(self, app: ASGIFramework, config: Config) -> None:
self.app = app
self.config = config
self.startup = asyncio.Event()
self.shutdown = asyncio.Event()
self.app_queue: asyncio.Queue = asyncio.Queue(config.max_app_queue_size)
self.supported = True
# This mimics the Trio nursery.start task_status and is
# required to ensure the support has been checked before
# waiting on timeouts.
self._started = asyncio.Event()
async def handle_lifespan(self) -> None:
self._started.set()
scope: LifespanScope = {"type": "lifespan", "asgi": {"spec_version": "2.0"}}
try:
await invoke_asgi(self.app, scope, self.asgi_receive, self.asgi_send)
except LifespanFailure:
# Lifespan failures should crash the server
raise
except Exception:
self.supported = False
if not self.startup.is_set():
message = "ASGI Framework Lifespan error, continuing without Lifespan support"
elif not self.shutdown.is_set():
message = "ASGI Framework Lifespan error, shutdown without Lifespan support"
else:
message = "ASGI Framework Lifespan errored after shutdown."
await self.config.log.exception(message)
finally:
self.startup.set()
self.shutdown.set()
async def wait_for_startup(self) -> None:
await self._started.wait()
if not self.supported:
return
await self.app_queue.put({"type": "lifespan.startup"})
try:
await asyncio.wait_for(self.startup.wait(), timeout=self.config.startup_timeout)
except asyncio.TimeoutError as error:
raise LifespanTimeout("startup") from error
async def wait_for_shutdown(self) -> None:
await self._started.wait()
if not self.supported:
return
await self.app_queue.put({"type": "lifespan.shutdown"})
try:
await asyncio.wait_for(self.shutdown.wait(), timeout=self.config.shutdown_timeout)
except asyncio.TimeoutError as error:
raise LifespanTimeout("shutdown") from error
async def asgi_receive(self) -> ASGIReceiveEvent:
return await self.app_queue.get()
async def asgi_send(self, message: ASGISendEvent) -> None:
if message["type"] == "lifespan.startup.complete":
self.startup.set()
elif message["type"] == "lifespan.shutdown.complete":
self.shutdown.set()
elif message["type"] == "lifespan.startup.failed":
self.startup.set()
raise LifespanFailure("startup", message["message"])
elif message["type"] == "lifespan.shutdown.failed":
self.shutdown.set()
raise LifespanFailure("shutdown", message["message"])
else:
raise UnexpectedMessage(message["type"])

View File

@@ -0,0 +1,266 @@
import asyncio
import platform
import signal
import ssl
from functools import partial
from multiprocessing.synchronize import Event as EventType
from os import getpid
from socket import socket
from typing import Any, Awaitable, Callable, Optional
from .lifespan import Lifespan
from .statsd import StatsdLogger
from .tcp_server import TCPServer
from .udp_server import UDPServer
from ..config import Config, Sockets
from ..typing import ASGIFramework
from ..utils import (
check_multiprocess_shutdown_event,
load_application,
MustReloadException,
observe_changes,
raise_shutdown,
repr_socket_addr,
restart,
Shutdown,
)
try:
from socket import AF_UNIX
except ImportError:
AF_UNIX = None
async def _windows_signal_support() -> None:
# See https://bugs.python.org/issue23057, to catch signals on
# Windows it is necessary for an IO event to happen periodically.
while True:
await asyncio.sleep(1)
def _share_socket(sock: socket) -> socket:
# Windows requires the socket be explicitly shared across
# multiple workers (processes).
from socket import fromshare # type: ignore
sock_data = sock.share(getpid()) # type: ignore
return fromshare(sock_data)
async def worker_serve(
app: ASGIFramework,
config: Config,
*,
sockets: Optional[Sockets] = None,
shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
) -> None:
config.set_statsd_logger_class(StatsdLogger)
lifespan = Lifespan(app, config)
lifespan_task = asyncio.ensure_future(lifespan.handle_lifespan())
await lifespan.wait_for_startup()
if lifespan_task.done():
exception = lifespan_task.exception()
if exception is not None:
raise exception
if sockets is None:
sockets = config.create_sockets()
loop = asyncio.get_event_loop()
tasks = []
if platform.system() == "Windows":
tasks.append(loop.create_task(_windows_signal_support()))
if shutdown_trigger is None:
signal_event = asyncio.Event()
def _signal_handler(*_: Any) -> None: # noqa: N803
signal_event.set()
for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}:
if hasattr(signal, signal_name):
try:
loop.add_signal_handler(getattr(signal, signal_name), _signal_handler)
except NotImplementedError:
# Add signal handler may not be implemented on Windows
signal.signal(getattr(signal, signal_name), _signal_handler)
shutdown_trigger = signal_event.wait # type: ignore
tasks.append(loop.create_task(raise_shutdown(shutdown_trigger)))
if config.use_reloader:
tasks.append(loop.create_task(observe_changes(asyncio.sleep)))
ssl_handshake_timeout = None
if config.ssl_enabled:
ssl_context = config.create_ssl_context()
ssl_handshake_timeout = config.ssl_handshake_timeout
async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
await TCPServer(app, loop, config, reader, writer)
servers = []
for sock in sockets.secure_sockets:
if config.workers > 1 and platform.system() == "Windows":
sock = _share_socket(sock)
servers.append(
await asyncio.start_server(
_server_callback,
backlog=config.backlog,
loop=loop,
ssl=ssl_context,
sock=sock,
ssl_handshake_timeout=ssl_handshake_timeout,
)
)
bind = repr_socket_addr(sock.family, sock.getsockname())
await config.log.info(f"Running on https://{bind} (CTRL + C to quit)")
for sock in sockets.insecure_sockets:
if config.workers > 1 and platform.system() == "Windows":
sock = _share_socket(sock)
servers.append(
await asyncio.start_server(
_server_callback, backlog=config.backlog, loop=loop, sock=sock
)
)
bind = repr_socket_addr(sock.family, sock.getsockname())
await config.log.info(f"Running on http://{bind} (CTRL + C to quit)")
tasks.extend(server.serve_forever() for server in servers) # type: ignore
for sock in sockets.quic_sockets:
if config.workers > 1 and platform.system() == "Windows":
sock = _share_socket(sock)
await loop.create_datagram_endpoint(lambda: UDPServer(app, loop, config), sock=sock)
bind = repr_socket_addr(sock.family, sock.getsockname())
await config.log.info(f"Running on https://{bind} (QUIC) (CTRL + C to quit)")
reload_ = False
try:
gathered_tasks = asyncio.gather(*tasks)
await gathered_tasks
except MustReloadException:
reload_ = True
except (Shutdown, KeyboardInterrupt):
pass
finally:
for server in servers:
server.close()
await server.wait_closed()
try:
await asyncio.sleep(config.graceful_timeout)
except (Shutdown, KeyboardInterrupt):
pass
# Retrieve the Gathered Tasks Cancelled Exception, to
# prevent a warning that this hasn't been done.
gathered_tasks.exception()
await lifespan.wait_for_shutdown()
lifespan_task.cancel()
await lifespan_task
if reload_:
restart()
def asyncio_worker(
config: Config, sockets: Optional[Sockets] = None, shutdown_event: Optional[EventType] = None
) -> None:
app = load_application(config.application_path)
shutdown_trigger = None
if shutdown_event is not None:
shutdown_trigger = partial(check_multiprocess_shutdown_event, shutdown_event, asyncio.sleep)
if config.workers > 1 and platform.system() == "Windows":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # type: ignore
_run(
partial(worker_serve, app, config, sockets=sockets),
debug=config.debug,
shutdown_trigger=shutdown_trigger,
)
def uvloop_worker(
config: Config, sockets: Optional[Sockets] = None, shutdown_event: Optional[EventType] = None
) -> None:
try:
import uvloop
except ImportError as error:
raise Exception("uvloop is not installed") from error
else:
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
app = load_application(config.application_path)
shutdown_trigger = None
if shutdown_event is not None:
shutdown_trigger = partial(check_multiprocess_shutdown_event, shutdown_event, asyncio.sleep)
_run(
partial(worker_serve, app, config, sockets=sockets),
debug=config.debug,
shutdown_trigger=shutdown_trigger,
)
def _run(
main: Callable,
*,
debug: bool = False,
shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
) -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.set_debug(debug)
loop.set_exception_handler(_exception_handler)
try:
loop.run_until_complete(main(shutdown_trigger=shutdown_trigger))
except KeyboardInterrupt:
pass
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
asyncio.set_event_loop(None)
loop.close()
def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = [task for task in asyncio.all_tasks(loop) if not task.done()]
if not tasks:
return
for task in tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*tasks, loop=loop, return_exceptions=True))
for task in tasks:
if not task.cancelled() and task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during shutdown",
"exception": task.exception(),
"task": task,
}
)
def _exception_handler(loop: asyncio.AbstractEventLoop, context: dict) -> None:
exception = context.get("exception")
if isinstance(exception, ssl.SSLError):
pass # Handshake failure
else:
loop.default_exception_handler(context)

View File

@@ -0,0 +1,24 @@
import asyncio
from typing import Optional
from ..config import Config
from ..statsd import StatsdLogger as Base
class _DummyProto(asyncio.DatagramProtocol):
pass
class StatsdLogger(Base):
def __init__(self, config: Config) -> None:
super().__init__(config)
self.address = config.statsd_host.rsplit(":", 1)
self.transport: Optional[asyncio.BaseTransport] = None
async def _socket_send(self, message: bytes) -> None:
if self.transport is None:
self.transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
_DummyProto, remote_addr=(self.address[0], int(self.address[1]))
)
self.transport.sendto(message) # type: ignore

View File

@@ -0,0 +1,34 @@
import asyncio
import weakref
from types import TracebackType
from typing import Coroutine
class TaskGroup:
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._tasks: weakref.WeakSet = weakref.WeakSet()
def spawn(self, coro: Coroutine) -> None:
self._tasks.add(self._loop.create_task(coro))
async def __aenter__(self) -> "TaskGroup":
return self
async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
if exc_type is not None:
self._cancel_tasks()
try:
task = asyncio.gather(*self._tasks)
await task
finally:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
def _cancel_tasks(self) -> None:
for task in self._tasks:
task.cancel()

View File

@@ -0,0 +1,153 @@
import asyncio
from typing import Any, Callable, cast, Generator, Optional
from .context import Context
from .task_group import TaskGroup
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..protocol import ProtocolWrapper
from ..typing import ASGIFramework
from ..utils import parse_socket_addr
MAX_RECV = 2 ** 16
class EventWrapper:
def __init__(self) -> None:
self._event = asyncio.Event()
async def clear(self) -> None:
self._event.clear()
async def wait(self) -> None:
await self._event.wait()
async def set(self) -> None:
self._event.set()
class TCPServer:
def __init__(
self,
app: ASGIFramework,
loop: asyncio.AbstractEventLoop,
config: Config,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
) -> None:
self.app = app
self.config = config
self.loop = loop
self.protocol: ProtocolWrapper
self.reader = reader
self.writer = writer
self.send_lock = asyncio.Lock()
self.timeout_lock = asyncio.Lock()
self._keep_alive_timeout_handle: Optional[asyncio.Task] = None
def __await__(self) -> Generator[Any, None, None]:
return self.run().__await__()
async def run(self) -> None:
socket = self.writer.get_extra_info("socket")
try:
client = parse_socket_addr(socket.family, socket.getpeername())
server = parse_socket_addr(socket.family, socket.getsockname())
ssl_object = self.writer.get_extra_info("ssl_object")
if ssl_object is not None:
ssl = True
alpn_protocol = ssl_object.selected_alpn_protocol()
else:
ssl = False
alpn_protocol = "http/1.1"
async with TaskGroup(self.loop) as task_group:
context = Context(task_group)
self.protocol = ProtocolWrapper(
self.app,
self.config,
cast(Any, context),
ssl,
client,
server,
self.protocol_send,
alpn_protocol,
)
await self.protocol.initiate()
await self._update_keep_alive_timeout()
await self._read_data()
except OSError:
pass
finally:
await self._close()
async def protocol_send(self, event: Event) -> None:
if isinstance(event, RawData):
async with self.send_lock:
try:
self.writer.write(event.data)
await self.writer.drain()
except ConnectionError:
await self.protocol.handle(Closed())
elif isinstance(event, Closed):
await self._close()
await self.protocol.handle(Closed())
elif isinstance(event, Updated):
pass # Triggers the keep alive timeout update
await self._update_keep_alive_timeout()
async def _read_data(self) -> None:
while True:
try:
data = await self.reader.read(MAX_RECV)
except (
ConnectionError,
OSError,
asyncio.TimeoutError,
TimeoutError,
):
await self.protocol.handle(Closed())
break
else:
if data == b"":
await self._update_keep_alive_timeout()
break
await self.protocol.handle(RawData(data))
await self._update_keep_alive_timeout()
async def _close(self) -> None:
try:
self.writer.write_eof()
except (NotImplementedError, OSError, RuntimeError):
pass # Likely SSL connection
try:
self.writer.close()
await self.writer.wait_closed()
except (BrokenPipeError, ConnectionResetError):
pass # Already closed
async def _update_keep_alive_timeout(self) -> None:
async with self.timeout_lock:
if self._keep_alive_timeout_handle is not None:
self._keep_alive_timeout_handle.cancel()
try:
await self._keep_alive_timeout_handle
except asyncio.CancelledError:
pass
self._keep_alive_timeout_handle = None
if self.protocol.idle:
self._keep_alive_timeout_handle = self.loop.create_task(
_call_later(self.config.keep_alive_timeout, self._timeout)
)
async def _timeout(self) -> None:
await self.protocol.handle(Closed())
self.writer.close()
async def _call_later(timeout: float, callback: Callable) -> None:
await asyncio.sleep(timeout)
await asyncio.shield(callback())

View File

@@ -0,0 +1,55 @@
import asyncio
from typing import Any, cast, Optional, Tuple, TYPE_CHECKING
from .context import Context
from .task_group import TaskGroup
from ..config import Config
from ..events import Closed, Event, RawData
from ..typing import ASGIFramework
from ..utils import parse_socket_addr
if TYPE_CHECKING:
# h3/Quic is an optional part of Hypercorn
from ..protocol.quic import QuicProtocol # noqa: F401
class UDPServer(asyncio.DatagramProtocol):
def __init__(self, app: ASGIFramework, loop: asyncio.AbstractEventLoop, config: Config) -> None:
self.app = app
self.config = config
self.loop = loop
self.protocol: "QuicProtocol"
self.protocol_queue: asyncio.Queue = asyncio.Queue(10)
self.transport: Optional[asyncio.DatagramTransport] = None
self.loop.create_task(self._consume_events())
def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore
# h3/Quic is an optional part of Hypercorn
from ..protocol.quic import QuicProtocol # noqa: F811
self.transport = transport
socket = self.transport.get_extra_info("socket")
server = parse_socket_addr(socket.family, socket.getsockname())
task_group = TaskGroup(self.loop)
context = Context(task_group)
self.protocol = QuicProtocol(
self.app, self.config, cast(Any, context), server, self.protocol_send
)
def datagram_received(self, data: bytes, address: Tuple[bytes, str]) -> None: # type: ignore
try:
self.protocol_queue.put_nowait(RawData(data=data, address=address)) # type: ignore
except asyncio.QueueFull:
pass # Just throw the data away, is UDP
async def protocol_send(self, event: Event) -> None:
if isinstance(event, RawData):
self.transport.sendto(event.data, event.address)
async def _consume_events(self) -> None:
while True:
event = await self.protocol_queue.get()
await self.protocol.handle(event)
if isinstance(event, Closed):
break

View File

@@ -0,0 +1,373 @@
import importlib
import importlib.util
import logging
import os
import socket
import ssl
import stat
import types
import warnings
from dataclasses import dataclass
from ssl import SSLContext, VerifyFlags, VerifyMode
from time import time
from typing import Any, AnyStr, Dict, List, Mapping, Optional, Tuple, Type, Union
from wsgiref.handlers import format_date_time
import toml
from .logging import Logger
BYTES = 1
OCTETS = 1
SECONDS = 1.0
FilePath = Union[AnyStr, os.PathLike]
SocketKind = Union[int, socket.SocketKind]
@dataclass
class Sockets:
secure_sockets: List[socket.socket]
insecure_sockets: List[socket.socket]
quic_sockets: List[socket.socket]
class SocketTypeError(Exception):
def __init__(self, expected: SocketKind, actual: SocketKind) -> None:
super().__init__(
f'Unexpected socket type, wanted "{socket.SocketKind(expected)}" got '
f'"{socket.SocketKind(actual)}"'
)
class Config:
_bind = ["127.0.0.1:8000"]
_insecure_bind: List[str] = []
_quic_bind: List[str] = []
_quic_addresses: List[Tuple] = []
_log: Optional[Logger] = None
access_log_format = '%(h)s %(l)s %(l)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'
accesslog: Union[logging.Logger, str, None] = None
alpn_protocols = ["h2", "http/1.1"]
alt_svc_headers: List[str] = []
application_path: str
backlog = 100
ca_certs: Optional[str] = None
certfile: Optional[str] = None
ciphers: str = "ECDHE+AESGCM"
debug = False
dogstatsd_tags = ""
errorlog: Union[logging.Logger, str, None] = "-"
graceful_timeout: float = 3 * SECONDS
group: Optional[int] = None
h11_max_incomplete_size = 16 * 1024 * BYTES
h2_max_concurrent_streams = 100
h2_max_header_list_size = 2 ** 16
h2_max_inbound_frame_size = 2 ** 14 * OCTETS
include_server_header = True
keep_alive_timeout = 5 * SECONDS
keyfile: Optional[str] = None
logconfig: Optional[str] = None
logconfig_dict: Optional[dict] = None
logger_class = Logger
loglevel: str = "INFO"
max_app_queue_size: int = 10
pid_path: Optional[str] = None
root_path = ""
server_names: List[str] = []
shutdown_timeout = 60 * SECONDS
ssl_handshake_timeout = 60 * SECONDS
startup_timeout = 60 * SECONDS
statsd_host: Optional[str] = None
statsd_prefix = ""
umask: Optional[int] = None
use_reloader = False
user: Optional[int] = None
verify_flags: Optional[VerifyFlags] = None
verify_mode: Optional[VerifyMode] = None
websocket_max_message_size = 16 * 1024 * 1024 * BYTES
websocket_ping_interval: Optional[int] = None
worker_class = "asyncio"
workers = 1
def set_cert_reqs(self, value: int) -> None:
warnings.warn("Please use verify_mode instead", Warning)
self.verify_mode = VerifyMode(value)
cert_reqs = property(None, set_cert_reqs)
@property
def log(self) -> Logger:
if self._log is None:
self._log = self.logger_class(self)
return self._log
@property
def bind(self) -> List[str]:
return self._bind
@bind.setter
def bind(self, value: Union[List[str], str]) -> None:
if isinstance(value, str):
self._bind = [value]
else:
self._bind = value
@property
def insecure_bind(self) -> List[str]:
return self._insecure_bind
@insecure_bind.setter
def insecure_bind(self, value: Union[List[str], str]) -> None:
if isinstance(value, str):
self._insecure_bind = [value]
else:
self._insecure_bind = value
@property
def quic_bind(self) -> List[str]:
return self._quic_bind
@quic_bind.setter
def quic_bind(self, value: Union[List[str], str]) -> None:
if isinstance(value, str):
self._quic_bind = [value]
else:
self._quic_bind = value
def create_ssl_context(self) -> Optional[SSLContext]:
if not self.ssl_enabled:
return None
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.set_ciphers(self.ciphers)
cipher_opts = 0
for attr in ["OP_NO_SSLv2", "OP_NO_SSLv3", "OP_NO_TLSv1", "OP_NO_TLSv1_1"]:
if hasattr(ssl, attr): # To be future proof
cipher_opts |= getattr(ssl, attr)
context.options |= cipher_opts # RFC 7540 Section 9.2: MUST be TLS >=1.2
context.options |= ssl.OP_NO_COMPRESSION # RFC 7540 Section 9.2.1: MUST disable compression
context.set_alpn_protocols(self.alpn_protocols)
if self.certfile is not None and self.keyfile is not None:
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
if self.ca_certs is not None:
context.load_verify_locations(self.ca_certs)
if self.verify_mode is not None:
context.verify_mode = self.verify_mode
if self.verify_flags is not None:
context.verify_flags = self.verify_flags
return context
@property
def ssl_enabled(self) -> bool:
return self.certfile is not None and self.keyfile is not None
def create_sockets(self) -> Sockets:
if self.ssl_enabled:
secure_sockets = self._create_sockets(self.bind)
insecure_sockets = self._create_sockets(self.insecure_bind)
quic_sockets = self._create_sockets(self.quic_bind, socket.SOCK_DGRAM)
self._set_quic_addresses(quic_sockets)
else:
secure_sockets = []
insecure_sockets = self._create_sockets(self.bind)
quic_sockets = []
return Sockets(secure_sockets, insecure_sockets, quic_sockets)
def _set_quic_addresses(self, sockets: List[socket.socket]) -> None:
self._quic_addresses = []
for sock in sockets:
name = sock.getsockname()
if type(name) is not str and len(name) >= 2:
self._quic_addresses.append(name)
else:
warnings.warn(
f'Cannot create a alt-svc header for the QUIC socket with address "{name}"',
Warning,
)
def _create_sockets(
self, binds: List[str], type_: int = socket.SOCK_STREAM
) -> List[socket.socket]:
sockets: List[socket.socket] = []
for bind in binds:
binding: Any = None
if bind.startswith("unix:"):
sock = socket.socket(socket.AF_UNIX, type_)
binding = bind[5:]
try:
if stat.S_ISSOCK(os.stat(binding).st_mode):
os.remove(binding)
except FileNotFoundError:
pass
elif bind.startswith("fd://"):
sock = socket.socket(fileno=int(bind[5:]))
actual_type = sock.getsockopt(socket.SOL_SOCKET, socket.SO_TYPE)
if actual_type != type_:
raise SocketTypeError(type_, actual_type)
else:
bind = bind.replace("[", "").replace("]", "")
try:
value = bind.rsplit(":", 1)
host, port = value[0], int(value[1])
except (ValueError, IndexError):
host, port = bind, 8000
sock = socket.socket(socket.AF_INET6 if ":" in host else socket.AF_INET, type_)
if self.workers > 1:
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except AttributeError:
pass
binding = (host, port)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if bind.startswith("unix:"):
if self.umask is not None:
current_umask = os.umask(self.umask)
sock.bind(binding)
if self.user is not None and self.group is not None:
os.chown(binding, self.user, self.group)
if self.umask is not None:
os.umask(current_umask)
elif bind.startswith("fd://"):
pass
else:
sock.bind(binding)
sock.setblocking(False)
try:
sock.set_inheritable(True)
except AttributeError:
pass
sockets.append(sock)
return sockets
def response_headers(self, protocol: str) -> List[Tuple[bytes, bytes]]:
headers = [(b"date", format_date_time(time()).encode("ascii"))]
if self.include_server_header:
headers.append((b"server", f"hypercorn-{protocol}".encode("ascii")))
for alt_svc_header in self.alt_svc_headers:
headers.append((b"alt-svc", alt_svc_header.encode()))
if len(self.alt_svc_headers) == 0 and self._quic_addresses:
from aioquic.h3.connection import H3_ALPN
for version in H3_ALPN:
for addr in self._quic_addresses:
port = addr[1]
headers.append((b"alt-svc", b'%s=":%d"; ma=3600' % (version.encode(), port)))
return headers
def set_statsd_logger_class(self, statsd_logger: Type[Logger]) -> None:
if self.logger_class == Logger and self.statsd_host is not None:
self.logger_class = statsd_logger
@classmethod
def from_mapping(
cls: Type["Config"], mapping: Optional[Mapping[str, Any]] = None, **kwargs: Any
) -> "Config":
"""Create a configuration from a mapping.
This allows either a mapping to be directly passed or as
keyword arguments, for example,
.. code-block:: python
config = {'keep_alive_timeout': 10}
Config.from_mapping(config)
Config.from_mapping(keep_alive_timeout=10)
Arguments:
mapping: Optionally a mapping object.
kwargs: Optionally a collection of keyword arguments to
form a mapping.
"""
mappings: Dict[str, Any] = {}
if mapping is not None:
mappings.update(mapping)
mappings.update(kwargs)
config = cls()
for key, value in mappings.items():
try:
setattr(config, key, value)
except AttributeError:
pass
return config
@classmethod
def from_pyfile(cls: Type["Config"], filename: FilePath) -> "Config":
"""Create a configuration from a Python file.
.. code-block:: python
Config.from_pyfile('hypercorn_config.py')
Arguments:
filename: The filename which gives the path to the file.
"""
file_path = os.fspath(filename)
spec = importlib.util.spec_from_file_location("module.name", file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return cls.from_object(module)
@classmethod
def from_toml(cls: Type["Config"], filename: FilePath) -> "Config":
"""Load the configuration values from a TOML formatted file.
This allows configuration to be loaded as so
.. code-block:: python
Config.from_toml('config.toml')
Arguments:
filename: The filename which gives the path to the file.
"""
file_path = os.fspath(filename)
with open(file_path) as file_:
data = toml.load(file_)
return cls.from_mapping(data)
@classmethod
def from_object(cls: Type["Config"], instance: Union[object, str]) -> "Config":
"""Create a configuration from a Python object.
This can be used to reference modules or objects within
modules for example,
.. code-block:: python
Config.from_object('module')
Config.from_object('module.instance')
from module import instance
Config.from_object(instance)
are valid.
Arguments:
instance: Either a str referencing a python object or the
object itself.
"""
if isinstance(instance, str):
try:
instance = importlib.import_module(instance)
except ImportError:
path, config = instance.rsplit(".", 1) # type: ignore
module = importlib.import_module(path)
instance = getattr(module, config)
mapping = {
key: getattr(instance, key)
for key in dir(instance)
if not isinstance(getattr(instance, key), types.ModuleType)
}
return cls.from_mapping(mapping)

View File

@@ -0,0 +1,25 @@
from abc import ABC
from dataclasses import dataclass
from typing import Optional, Tuple
class Event(ABC):
pass
@dataclass(frozen=True)
class RawData(Event):
data: bytes
address: Optional[Tuple[str, int]] = None
@dataclass(frozen=True)
class Closed(Event):
pass
@dataclass(frozen=True)
class Updated(Event):
# Indicate that the protocol has updated (although it has nothing
# for the server to do).
pass

View File

@@ -0,0 +1,182 @@
import logging
import os
import sys
import time
from http import HTTPStatus
from logging.config import dictConfig, fileConfig
from typing import Any, IO, Mapping, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
from .config import Config
from .typing import ResponseSummary, WWWScope
def _create_logger(
name: str,
target: Union[logging.Logger, str, None],
level: Optional[str],
sys_default: IO,
*,
propagate: bool = True,
) -> Optional[logging.Logger]:
if isinstance(target, logging.Logger):
return target
if target:
logger = logging.getLogger(name)
logger.handlers = [
logging.StreamHandler(sys_default) if target == "-" else logging.FileHandler(target)
]
logger.propagate = propagate
formatter = logging.Formatter(
"%(asctime)s [%(process)d] [%(levelname)s] %(message)s",
"[%Y-%m-%d %H:%M:%S %z]",
)
logger.handlers[0].setFormatter(formatter)
if level is not None:
logger.setLevel(logging.getLevelName(level.upper()))
return logger
else:
return None
class Logger:
def __init__(self, config: "Config") -> None:
self.access_log_format = config.access_log_format
self.access_logger = _create_logger(
"hypercorn.access",
config.accesslog,
config.loglevel,
sys.stdout,
propagate=False,
)
self.error_logger = _create_logger(
"hypercorn.error", config.errorlog, config.loglevel, sys.stderr
)
if config.logconfig is not None:
log_config = {
"__file__": config.logconfig,
"here": os.path.dirname(config.logconfig),
}
fileConfig(config.logconfig, defaults=log_config, disable_existing_loggers=False)
else:
if config.logconfig_dict is not None:
dictConfig(config.logconfig_dict)
async def access(
self, request: "WWWScope", response: "ResponseSummary", request_time: float
) -> None:
if self.access_logger is not None:
self.access_logger.info(
self.access_log_format, self.atoms(request, response, request_time)
)
async def critical(self, message: str, *args: Any, **kwargs: Any) -> None:
if self.error_logger is not None:
self.error_logger.critical(message, *args, **kwargs)
async def error(self, message: str, *args: Any, **kwargs: Any) -> None:
if self.error_logger is not None:
self.error_logger.error(message, *args, **kwargs)
async def warning(self, message: str, *args: Any, **kwargs: Any) -> None:
if self.error_logger is not None:
self.error_logger.warning(message, *args, **kwargs)
async def info(self, message: str, *args: Any, **kwargs: Any) -> None:
if self.error_logger is not None:
self.error_logger.info(message, *args, **kwargs)
async def debug(self, message: str, *args: Any, **kwargs: Any) -> None:
if self.error_logger is not None:
self.error_logger.debug(message, *args, **kwargs)
async def exception(self, message: str, *args: Any, **kwargs: Any) -> None:
if self.error_logger is not None:
self.error_logger.exception(message, *args, **kwargs)
async def log(self, level: int, message: str, *args: Any, **kwargs: Any) -> None:
if self.error_logger is not None:
self.error_logger.log(level, message, *args, **kwargs)
def atoms(
self, request: "WWWScope", response: "ResponseSummary", request_time: float
) -> Mapping[str, str]:
"""Create and return an access log atoms dictionary.
This can be overidden and customised if desired. It should
return a mapping between an access log format key and a value.
"""
return AccessLogAtoms(request, response, request_time)
def __getattr__(self, name: str) -> Any:
return getattr(self.error_logger, name)
class AccessLogAtoms(dict):
def __init__(
self, request: "WWWScope", response: "ResponseSummary", request_time: float
) -> None:
for name, value in request["headers"]:
self[f"{{{name.decode('latin1').lower()}}}i"] = value.decode("latin1")
for name, value in response.get("headers", []):
self[f"{{{name.decode('latin1').lower()}}}o"] = value.decode("latin1")
for name, value in os.environ.items():
self[f"{{{name.lower()}}}e"] = value
protocol = request.get("http_version", "ws")
client = request.get("client")
if client is None:
remote_addr = None
elif len(client) == 2:
remote_addr = f"{client[0]}:{client[1]}"
elif len(client) == 1:
remote_addr = client[0]
else: # make sure not to throw UnboundLocalError
remote_addr = f"<???{client}???>"
if request["type"] == "http":
method = request["method"]
else:
method = "GET"
query_string = request["query_string"].decode()
path_with_qs = request["path"] + ("?" + query_string if query_string else "")
status_code = response["status"]
try:
status_phrase = HTTPStatus(status_code).phrase
except ValueError:
status_phrase = f"<???{status_code}???>"
self.update(
{
"h": remote_addr,
"l": "-",
"t": time.strftime("[%d/%b/%Y:%H:%M:%S %z]"),
"r": f"{method} {request['path']} {protocol}",
"R": f"{method} {path_with_qs} {protocol}",
"s": response["status"],
"st": status_phrase,
"S": request["scheme"],
"m": method,
"U": request["path"],
"Uq": path_with_qs,
"q": query_string,
"H": protocol,
"b": self["{Content-Length}o"],
"B": self["{Content-Length}o"],
"f": self["{Referer}i"],
"a": self["{User-Agent}i"],
"T": int(request_time),
"D": int(request_time * 1_000_000),
"L": f"{request_time:.6f}",
"p": f"<{os.getpid()}>",
}
)
def __getitem__(self, key: str) -> str:
try:
if key.startswith("{"):
return super().__getitem__(key.lower())
else:
return super().__getitem__(key)
except KeyError:
return "-"

View File

@@ -0,0 +1,10 @@
from .dispatcher import DispatcherMiddleware
from .http_to_https import HTTPToHTTPSRedirectMiddleware
from .wsgi import AsyncioWSGIMiddleware, TrioWSGIMiddleware
__all__ = (
"AsyncioWSGIMiddleware",
"DispatcherMiddleware",
"HTTPToHTTPSRedirectMiddleware",
"TrioWSGIMiddleware",
)

View File

@@ -0,0 +1,107 @@
import asyncio
from functools import partial
from typing import Callable, Dict
from ..asyncio.task_group import TaskGroup
from ..typing import ASGIFramework, Scope
from ..utils import invoke_asgi
MAX_QUEUE_SIZE = 10
class _DispatcherMiddleware:
def __init__(self, mounts: Dict[str, ASGIFramework]) -> None:
self.mounts = mounts
async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None:
if scope["type"] == "lifespan":
await self._handle_lifespan(scope, receive, send)
else:
for path, app in self.mounts.items():
if scope["path"].startswith(path):
scope["path"] = scope["path"][len(path) :] or "/" # type: ignore
return await invoke_asgi(app, scope, receive, send)
await send(
{
"type": "http.response.start",
"status": 404,
"headers": [(b"content-length", b"0")],
}
)
await send({"type": "http.response.body"})
async def _handle_lifespan(self, scope: Scope, receive: Callable, send: Callable) -> None:
pass
class AsyncioDispatcherMiddleware(_DispatcherMiddleware):
async def _handle_lifespan(self, scope: Scope, receive: Callable, send: Callable) -> None:
self.app_queues: Dict[str, asyncio.Queue] = {
path: asyncio.Queue(MAX_QUEUE_SIZE) for path in self.mounts
}
self.startup_complete = {path: False for path in self.mounts}
self.shutdown_complete = {path: False for path in self.mounts}
async with TaskGroup(asyncio.get_event_loop()) as task_group:
for path, app in self.mounts.items():
task_group.spawn(
invoke_asgi(
app, scope, self.app_queues[path].get, partial(self.send, path, send)
)
)
while True:
message = await receive()
for queue in self.app_queues.values():
await queue.put(message)
if message["type"] == "lifespan.shutdown":
break
async def send(self, path: str, send: Callable, message: dict) -> None:
if message["type"] == "lifespan.startup.complete":
self.startup_complete[path] = True
if all(self.startup_complete.values()):
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown.complete":
self.shutdown_complete[path] = True
if all(self.shutdown_complete.values()):
await send({"type": "lifespan.shutdown.complete"})
class TrioDispatcherMiddleware(_DispatcherMiddleware):
async def _handle_lifespan(self, scope: Scope, receive: Callable, send: Callable) -> None:
import trio
self.app_queues = {path: trio.open_memory_channel(MAX_QUEUE_SIZE) for path in self.mounts}
self.startup_complete = {path: False for path in self.mounts}
self.shutdown_complete = {path: False for path in self.mounts}
async with trio.open_nursery() as nursery:
for path, app in self.mounts.items():
nursery.start_soon(
invoke_asgi,
app,
scope,
self.app_queues[path][1].receive,
partial(self.send, path, send),
)
while True:
message = await receive()
for channels in self.app_queues.values():
await channels[0].send(message)
if message["type"] == "lifespan.shutdown":
break
async def send(self, path: str, send: Callable, message: dict) -> None:
if message["type"] == "lifespan.startup.complete":
self.startup_complete[path] = True
if all(self.startup_complete.values()):
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown.complete":
self.shutdown_complete[path] = True
if all(self.shutdown_complete.values()):
await send({"type": "lifespan.shutdown.complete"})
DispatcherMiddleware = AsyncioDispatcherMiddleware # Remove with version 0.11

View File

@@ -0,0 +1,66 @@
from typing import Callable, Optional
from urllib.parse import urlunsplit
from ..typing import ASGIFramework, HTTPScope, Scope, WebsocketScope, WWWScope
from ..utils import invoke_asgi
class HTTPToHTTPSRedirectMiddleware:
def __init__(self, app: ASGIFramework, host: Optional[str]) -> None:
self.app = app
self.host = host
async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None:
if scope["type"] == "http" and scope["scheme"] == "http":
await self._send_http_redirect(scope, send)
elif scope["type"] == "websocket" and scope["scheme"] == "ws":
# If the server supports the WebSocket Denial Response
# extension we can send a redirection response, if not we
# can only deny the WebSocket connection.
if "websocket.http.response" in scope.get("extensions", {}):
await self._send_websocket_redirect(scope, send)
else:
await send({"type": "websocket.close"})
else:
return await invoke_asgi(self.app, scope, receive, send)
async def _send_http_redirect(self, scope: HTTPScope, send: Callable) -> None:
new_url = self._new_url("https", scope)
await send(
{
"type": "http.response.start",
"status": 307,
"headers": [(b"location", new_url.encode())],
}
)
await send({"type": "http.response.body"})
async def _send_websocket_redirect(self, scope: WebsocketScope, send: Callable) -> None:
# If the HTTP version is 2 we should redirect with a https
# scheme not wss.
scheme = "wss"
if scope.get("http_version", "1.1") == "2":
scheme = "https"
new_url = self._new_url(scheme, scope)
await send(
{
"type": "websocket.http.response.start",
"status": 307,
"headers": [(b"location", new_url.encode())],
}
)
await send({"type": "websocket.http.response.body"})
def _new_url(self, scheme: str, scope: WWWScope) -> str:
host = self.host
if host is None:
for key, value in scope["headers"]:
if key == b"host":
host = value.decode("latin-1")
break
if host is None:
raise ValueError("Host to redirect to cannot be determined")
path = scope.get("root_path", "") + scope["raw_path"].decode()
return urlunsplit((scheme, host, path, scope["query_string"].decode(), ""))

View File

@@ -0,0 +1,132 @@
import asyncio
from functools import partial
from io import BytesIO
from typing import Callable, Iterable, List, Optional, Tuple
from ..typing import HTTPScope, Scope
MAX_BODY_SIZE = 2 ** 16
WSGICallable = Callable[[dict, Callable], Iterable[bytes]]
class _WSGIMiddleware:
def __init__(self, wsgi_app: WSGICallable, max_body_size: int = MAX_BODY_SIZE) -> None:
self.wsgi_app = wsgi_app
self.max_body_size = max_body_size
async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None:
if scope["type"] == "http":
status_code, headers, body = await self._handle_http(scope, receive, send)
await send({"type": "http.response.start", "status": status_code, "headers": headers})
await send({"type": "http.response.body", "body": body})
elif scope["type"] == "websocket":
await send({"type": "websocket.close"})
elif scope["type"] == "lifespan":
return
else:
raise Exception(f"Unknown scope type, {scope['type']}")
async def _handle_http(
self, scope: HTTPScope, receive: Callable, send: Callable
) -> Tuple[int, list, bytes]:
pass
class AsyncioWSGIMiddleware(_WSGIMiddleware):
async def _handle_http(
self, scope: HTTPScope, receive: Callable, send: Callable
) -> Tuple[int, list, bytes]:
loop = asyncio.get_event_loop()
instance = _WSGIInstance(self.wsgi_app, self.max_body_size)
return await instance.handle_http(scope, receive, partial(loop.run_in_executor, None))
class TrioWSGIMiddleware(_WSGIMiddleware):
async def _handle_http(
self, scope: HTTPScope, receive: Callable, send: Callable
) -> Tuple[int, list, bytes]:
import trio
instance = _WSGIInstance(self.wsgi_app, self.max_body_size)
return await instance.handle_http(scope, receive, trio.to_thread.run_sync)
class _WSGIInstance:
def __init__(self, wsgi_app: WSGICallable, max_body_size: int = MAX_BODY_SIZE) -> None:
self.wsgi_app = wsgi_app
self.max_body_size = max_body_size
self.status_code = 500
self.headers: list = []
async def handle_http(
self, scope: HTTPScope, receive: Callable, spawn: Callable
) -> Tuple[int, list, bytes]:
self.scope = scope
body = bytearray()
while True:
message = await receive()
body.extend(message.get("body", b""))
if len(body) > self.max_body_size:
return 400, [], b""
if not message.get("more_body"):
break
return await spawn(self.run_wsgi_app, body)
def _start_response(
self,
status: str,
response_headers: List[Tuple[str, str]],
exc_info: Optional[Exception] = None,
) -> None:
raw, _ = status.split(" ", 1)
self.status_code = int(raw)
self.headers = [
(name.lower().encode("ascii"), value.encode("ascii"))
for name, value in response_headers
]
def run_wsgi_app(self, body: bytes) -> Tuple[int, list, bytes]:
environ = _build_environ(self.scope, body)
body = bytearray()
for output in self.wsgi_app(environ, self._start_response):
body.extend(output)
return self.status_code, self.headers, body
def _build_environ(scope: HTTPScope, body: bytes) -> dict:
server = scope.get("server") or ("localhost", 80)
environ = {
"REQUEST_METHOD": scope["method"],
"SCRIPT_NAME": scope.get("root_path", "").encode("utf8").decode("latin1"),
"PATH_INFO": scope["path"].encode("utf8").decode("latin1"),
"QUERY_STRING": scope["query_string"].decode("ascii"),
"SERVER_NAME": server[0],
"SERVER_PORT": server[1],
"SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"],
"wsgi.version": (1, 0),
"wsgi.url_scheme": scope.get("scheme", "http"),
"wsgi.input": BytesIO(body),
"wsgi.errors": BytesIO(),
"wsgi.multithread": True,
"wsgi.multiprocess": True,
"wsgi.run_once": False,
}
if "client" in scope:
environ["REMOTE_ADDR"] = scope["client"][0]
for raw_name, raw_value in scope.get("headers", []):
name = raw_name.decode("latin1")
if name == "content-length":
corrected_name = "CONTENT_LENGTH"
elif name == "content-type":
corrected_name = "CONTENT_TYPE"
else:
corrected_name = "HTTP_%s" % name.upper().replace("-", "_")
# HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in case
value = raw_value.decode("latin1")
if corrected_name in environ:
value = environ[corrected_name] + "," + value # type: ignore
environ[corrected_name] = value
return environ

View File

@@ -0,0 +1,86 @@
from typing import Awaitable, Callable, Optional, Tuple, Union
from .h2 import H2Protocol
from .h11 import H2CProtocolRequired, H2ProtocolAssumed, H11Protocol
from ..config import Config
from ..events import Event, RawData
from ..typing import ASGIFramework, Context
class ProtocolWrapper:
def __init__(
self,
app: ASGIFramework,
config: Config,
context: Context,
ssl: bool,
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
send: Callable[[Event], Awaitable[None]],
alpn_protocol: Optional[str] = None,
) -> None:
self.app = app
self.config = config
self.context = context
self.ssl = ssl
self.client = client
self.server = server
self.send = send
self.protocol: Union[H11Protocol, H2Protocol]
if alpn_protocol == "h2":
self.protocol = H2Protocol(
self.app,
self.config,
self.context,
self.ssl,
self.client,
self.server,
self.send,
)
else:
self.protocol = H11Protocol(
self.app,
self.config,
self.context,
self.ssl,
self.client,
self.server,
self.send,
)
@property
def idle(self) -> bool:
return self.protocol.idle
async def initiate(self) -> None:
return await self.protocol.initiate()
async def handle(self, event: Event) -> None:
try:
return await self.protocol.handle(event)
except H2ProtocolAssumed as error:
self.protocol = H2Protocol(
self.app,
self.config,
self.context,
self.ssl,
self.client,
self.server,
self.send,
)
await self.protocol.initiate()
if error.data != b"":
return await self.protocol.handle(RawData(data=error.data))
except H2CProtocolRequired as error:
self.protocol = H2Protocol(
self.app,
self.config,
self.context,
self.ssl,
self.client,
self.server,
self.send,
)
await self.protocol.initiate(error.headers, error.settings)
if error.data != b"":
return await self.protocol.handle(RawData(data=error.data))

View File

@@ -0,0 +1,46 @@
from dataclasses import dataclass
from typing import List, Tuple
@dataclass(frozen=True)
class Event:
stream_id: int
@dataclass(frozen=True)
class Request(Event):
headers: List[Tuple[bytes, bytes]]
http_version: str
method: str
raw_path: bytes
@dataclass(frozen=True)
class Body(Event):
data: bytes
@dataclass(frozen=True)
class EndBody(Event):
pass
@dataclass(frozen=True)
class Data(Event):
data: bytes
@dataclass(frozen=True)
class EndData(Event):
pass
@dataclass(frozen=True)
class Response(Event):
headers: List[Tuple[bytes, bytes]]
status_code: int
@dataclass(frozen=True)
class StreamClosed(Event):
pass

View File

@@ -0,0 +1,289 @@
from itertools import chain
from typing import Awaitable, Callable, Optional, Tuple, Union
import h11
from .events import (
Body,
Data,
EndBody,
EndData,
Event as StreamEvent,
Request,
Response,
StreamClosed,
)
from .http_stream import HTTPStream
from .ws_stream import WSStream
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..typing import ASGIFramework, Context, H11SendableEvent
STREAM_ID = 1
class H2CProtocolRequired(Exception):
def __init__(self, data: bytes, request: h11.Request) -> None:
settings = ""
headers = [(b":method", request.method), (b":path", request.target)]
for name, value in request.headers:
if name.lower() == b"http2-settings":
settings = value.decode()
elif name.lower() == b"host":
headers.append((b":authority", value))
headers.append((name, value))
self.data = data
self.headers = headers
self.settings = settings
class H2ProtocolAssumed(Exception):
def __init__(self, data: bytes) -> None:
self.data = data
class H11WSConnection:
# This class matches the h11 interface, and either passes data
# through without altering it (for Data, EndData) or sends h11
# events (Response, Body, EndBody).
our_state = None # Prevents recycling the connection
they_are_waiting_for_100_continue = False
def __init__(self, h11_connection: h11.Connection) -> None:
self.buffer = bytearray(h11_connection.trailing_data[0])
self.h11_connection = h11_connection
def receive_data(self, data: bytes) -> None:
self.buffer.extend(data)
def next_event(self) -> Data:
if self.buffer:
event = Data(stream_id=STREAM_ID, data=bytes(self.buffer))
self.buffer = bytearray()
return event
else:
return h11.NEED_DATA
def send(self, event: H11SendableEvent) -> bytes:
return self.h11_connection.send(event)
class H11Protocol:
def __init__(
self,
app: ASGIFramework,
config: Config,
context: Context,
ssl: bool,
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
send: Callable[[Event], Awaitable[None]],
) -> None:
self.app = app
self.can_read = context.event_class()
self.client = client
self.config = config
self.connection = h11.Connection(
h11.SERVER, max_incomplete_event_size=self.config.h11_max_incomplete_size
)
self.context = context
self.send = send
self.server = server
self.ssl = ssl
self.stream: Optional[Union[HTTPStream, WSStream]] = None
@property
def idle(self) -> bool:
return self.stream is None or self.stream.idle
async def initiate(self) -> None:
pass
async def handle(self, event: Event) -> None:
if isinstance(event, RawData):
self.connection.receive_data(event.data)
await self._handle_events()
elif isinstance(event, Closed):
if self.stream is not None:
await self._close_stream()
async def stream_send(self, event: StreamEvent) -> None:
if isinstance(event, Response):
if event.status_code >= 200:
await self._send_h11_event(
h11.Response(
headers=chain(event.headers, self.config.response_headers("h11")),
status_code=event.status_code,
)
)
else:
await self._send_h11_event(
h11.InformationalResponse(
headers=chain(event.headers, self.config.response_headers("h11")),
status_code=event.status_code,
)
)
elif isinstance(event, Body):
await self._send_h11_event(h11.Data(data=event.data))
elif isinstance(event, EndBody):
await self._send_h11_event(h11.EndOfMessage())
elif isinstance(event, Data):
await self.send(RawData(data=event.data))
elif isinstance(event, EndData):
pass
elif isinstance(event, StreamClosed):
await self._maybe_recycle()
async def _handle_events(self) -> None:
while True:
if self.connection.they_are_waiting_for_100_continue:
await self._send_h11_event(
h11.InformationalResponse(
status_code=100, headers=self.config.response_headers("h11")
)
)
try:
event = self.connection.next_event()
except h11.RemoteProtocolError:
if self.connection.our_state in {h11.IDLE, h11.SEND_RESPONSE}:
await self._send_error_response(400)
await self.send(Closed())
break
else:
if isinstance(event, h11.Request):
await self._check_protocol(event)
await self._create_stream(event)
elif event is h11.PAUSED:
await self.can_read.clear()
await self.send(Updated())
await self.can_read.wait()
elif isinstance(event, h11.ConnectionClosed) or event is h11.NEED_DATA:
break
elif self.stream is None:
break
elif isinstance(event, h11.Data):
await self.stream.handle(Body(stream_id=STREAM_ID, data=event.data))
elif isinstance(event, h11.EndOfMessage):
await self.stream.handle(EndBody(stream_id=STREAM_ID))
elif isinstance(event, Data):
# WebSocket pass through
await self.stream.handle(event)
async def _create_stream(self, request: h11.Request) -> None:
upgrade_value = ""
connection_value = ""
for name, value in request.headers:
sanitised_name = name.decode("latin1").strip().lower()
if sanitised_name == "upgrade":
upgrade_value = value.decode("latin1").strip()
elif sanitised_name == "connection":
connection_value = value.decode("latin1").strip()
connection_tokens = connection_value.lower().split(",")
if (
any(token.strip() == "upgrade" for token in connection_tokens)
and upgrade_value.lower() == "websocket"
and request.method.decode("ascii").upper() == "GET"
):
self.stream = WSStream(
self.app,
self.config,
self.context,
self.ssl,
self.client,
self.server,
self.stream_send,
STREAM_ID,
)
self.connection = H11WSConnection(self.connection)
else:
self.stream = HTTPStream(
self.app,
self.config,
self.context,
self.ssl,
self.client,
self.server,
self.stream_send,
STREAM_ID,
)
await self.stream.handle(
Request(
stream_id=STREAM_ID,
headers=request.headers,
http_version=request.http_version.decode(),
method=request.method.decode("ascii").upper(),
raw_path=request.target,
)
)
async def _send_h11_event(self, event: H11SendableEvent) -> None:
try:
data = self.connection.send(event)
except h11.LocalProtocolError:
if self.connection.their_state != h11.ERROR:
raise
else:
await self.send(RawData(data=data))
async def _send_error_response(self, status_code: int) -> None:
await self._send_h11_event(
h11.Response(
status_code=status_code,
headers=chain(
[(b"content-length", b"0"), (b"connection", b"close")],
self.config.response_headers("h11"),
),
)
)
await self._send_h11_event(h11.EndOfMessage())
async def _maybe_recycle(self) -> None:
await self._close_stream()
if self.connection.our_state is h11.DONE:
try:
self.connection.start_next_cycle()
except h11.LocalProtocolError:
await self.send(Closed())
else:
self.response = None
self.scope = None
await self.can_read.set()
await self.send(Updated())
else:
await self.can_read.set()
await self.send(Closed())
async def _close_stream(self) -> None:
if self.stream is not None:
await self.stream.handle(StreamClosed(stream_id=STREAM_ID))
self.stream = None
async def _check_protocol(self, event: h11.Request) -> None:
upgrade_value = ""
has_body = False
for name, value in event.headers:
sanitised_name = name.decode("latin1").strip().lower()
if sanitised_name == "upgrade":
upgrade_value = value.decode("latin1").strip()
elif sanitised_name in {"content-length", "transfer-encoding"}:
has_body = True
# h2c Upgrade requests with a body are a pain as the body must
# be fully recieved in HTTP/1.1 before the upgrade response
# and HTTP/2 takes over, so Hypercorn ignores the upgrade and
# responds in HTTP/1.1. Use a preflight OPTIONS request to
# initiate the upgrade if really required (or just use h2).
if upgrade_value.lower() == "h2c" and not has_body:
await self._send_h11_event(
h11.InformationalResponse(
status_code=101,
headers=self.config.response_headers("h11")
+ [(b"connection", b"upgrade"), (b"upgrade", b"h2c")],
)
)
raise H2CProtocolRequired(self.connection.trailing_data[0], event)
elif event.method == b"PRI" and event.target == b"*" and event.http_version == b"2.0":
raise H2ProtocolAssumed(b"PRI * HTTP/2.0\r\n\r\n" + self.connection.trailing_data[0])

View File

@@ -0,0 +1,362 @@
from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
import h2
import h2.connection
import h2.events
import h2.exceptions
import priority
from .events import (
Body,
Data,
EndBody,
EndData,
Event as StreamEvent,
Request,
Response,
StreamClosed,
)
from .http_stream import HTTPStream
from .ws_stream import WSStream
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..typing import ASGIFramework, Context, Event as IOEvent
from ..utils import filter_pseudo_headers
BUFFER_HIGH_WATER = 2 * 2 ** 14 # Twice the default max frame size (two frames worth)
BUFFER_LOW_WATER = BUFFER_HIGH_WATER / 2
class BufferCompleteError(Exception):
pass
class StreamBuffer:
def __init__(self, event_class: Type[IOEvent]) -> None:
self.buffer = bytearray()
self._complete = False
self._is_empty = event_class()
self._paused = event_class()
async def drain(self) -> None:
await self._is_empty.wait()
def set_complete(self) -> None:
self._complete = True
async def close(self) -> None:
self._complete = True
self.buffer = bytearray()
await self._is_empty.set()
await self._paused.set()
@property
def complete(self) -> bool:
return self._complete and len(self.buffer) == 0
async def push(self, data: bytes) -> None:
if self._complete:
raise BufferCompleteError()
self.buffer.extend(data)
await self._is_empty.clear()
if len(self.buffer) >= BUFFER_HIGH_WATER:
await self._paused.wait()
await self._paused.clear()
async def pop(self, max_length: int) -> bytes:
length = min(len(self.buffer), max_length)
data = bytes(self.buffer[:length])
del self.buffer[:length]
if len(data) < BUFFER_LOW_WATER:
await self._paused.set()
if len(self.buffer) == 0:
await self._is_empty.set()
return data
class H2Protocol:
def __init__(
self,
app: ASGIFramework,
config: Config,
context: Context,
ssl: bool,
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
send: Callable[[Event], Awaitable[None]],
) -> None:
self.app = app
self.client = client
self.closed = False
self.config = config
self.context = context
self.connection = h2.connection.H2Connection(
config=h2.config.H2Configuration(client_side=False, header_encoding=None)
)
self.connection.DEFAULT_MAX_INBOUND_FRAME_SIZE = config.h2_max_inbound_frame_size
self.connection.local_settings = h2.settings.Settings(
client=False,
initial_values={
h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: config.h2_max_concurrent_streams,
h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: config.h2_max_header_list_size,
h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL: 1,
},
)
self.send = send
self.server = server
self.ssl = ssl
self.streams: Dict[int, Union[HTTPStream, WSStream]] = {}
# The below are used by the sending task
self.has_data = self.context.event_class()
self.priority = priority.PriorityTree()
self.stream_buffers: Dict[int, StreamBuffer] = {}
@property
def idle(self) -> bool:
return len(self.streams) == 0 or all(stream.idle for stream in self.streams.values())
async def initiate(
self, headers: Optional[List[Tuple[bytes, bytes]]] = None, settings: Optional[str] = None
) -> None:
if settings is not None:
self.connection.initiate_upgrade_connection(settings)
else:
self.connection.initiate_connection()
await self._flush()
if headers is not None:
event = h2.events.RequestReceived()
event.stream_id = 1
event.headers = headers
await self._create_stream(event)
await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))
self.context.spawn(self.send_task)
async def send_task(self) -> None:
# This should be run in a seperate task to the rest of this
# class. This allows it seperately choose when to send,
# crucially in what order.
while not self.closed:
try:
stream_id = next(self.priority)
except priority.DeadlockError:
await self.has_data.wait()
await self.has_data.clear()
else:
await self._send_data(stream_id)
async def _send_data(self, stream_id: int) -> None:
try:
chunk_size = min(
self.connection.local_flow_control_window(stream_id),
self.connection.max_outbound_frame_size,
)
chunk_size = max(0, chunk_size)
data = await self.stream_buffers[stream_id].pop(chunk_size)
if data:
self.connection.send_data(stream_id, data)
await self._flush()
else:
self.priority.block(stream_id)
if self.stream_buffers[stream_id].complete:
self.connection.end_stream(stream_id)
await self._flush()
del self.stream_buffers[stream_id]
self.priority.remove_stream(stream_id)
except (h2.exceptions.StreamClosedError, KeyError, h2.exceptions.ProtocolError):
# Stream or connection has closed whilst waiting to send
# data, not a problem - just force close it.
await self.stream_buffers[stream_id].close()
del self.stream_buffers[stream_id]
self.priority.remove_stream(stream_id)
async def handle(self, event: Event) -> None:
if isinstance(event, RawData):
try:
events = self.connection.receive_data(event.data)
except h2.exceptions.ProtocolError:
await self._flush()
await self.send(Closed())
else:
await self._handle_events(events)
elif isinstance(event, Closed):
self.closed = True
stream_ids = list(self.streams.keys())
for stream_id in stream_ids:
await self._close_stream(stream_id)
await self.has_data.set()
async def stream_send(self, event: StreamEvent) -> None:
try:
if isinstance(event, Response):
self.connection.send_headers(
event.stream_id,
[(b":status", b"%d" % event.status_code)]
+ event.headers
+ self.config.response_headers("h2"),
)
await self._flush()
elif isinstance(event, (Body, Data)):
self.priority.unblock(event.stream_id)
await self.has_data.set()
await self.stream_buffers[event.stream_id].push(event.data)
elif isinstance(event, (EndBody, EndData)):
self.stream_buffers[event.stream_id].set_complete()
self.priority.unblock(event.stream_id)
await self.has_data.set()
await self.stream_buffers[event.stream_id].drain()
elif isinstance(event, StreamClosed):
await self._close_stream(event.stream_id)
await self.send(Updated())
elif isinstance(event, Request):
await self._create_server_push(event.stream_id, event.raw_path, event.headers)
except (
BufferCompleteError,
KeyError,
priority.MissingStreamError,
h2.exceptions.ProtocolError,
):
# Connection has closed whilst blocked on flow control or
# connection has advanced ahead of the last emitted event.
return
async def _handle_events(self, events: List[h2.events.Event]) -> None:
for event in events:
if isinstance(event, h2.events.RequestReceived):
await self._create_stream(event)
elif isinstance(event, h2.events.DataReceived):
await self.streams[event.stream_id].handle(
Body(stream_id=event.stream_id, data=event.data)
)
self.connection.acknowledge_received_data(
event.flow_controlled_length, event.stream_id
)
elif isinstance(event, h2.events.StreamEnded):
await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))
elif isinstance(event, h2.events.StreamReset):
await self._close_stream(event.stream_id)
await self._window_updated(event.stream_id)
elif isinstance(event, h2.events.WindowUpdated):
await self._window_updated(event.stream_id)
elif isinstance(event, h2.events.PriorityUpdated):
await self._priority_updated(event)
elif isinstance(event, h2.events.RemoteSettingsChanged):
if h2.settings.SettingCodes.INITIAL_WINDOW_SIZE in event.changed_settings:
await self._window_updated(None)
elif isinstance(event, h2.events.ConnectionTerminated):
await self.send(Closed())
await self._flush()
async def _flush(self) -> None:
data = self.connection.data_to_send()
if data != b"":
await self.send(RawData(data=data))
async def _window_updated(self, stream_id: Optional[int]) -> None:
if stream_id is None or stream_id == 0:
# Unblock all streams
for stream_id in list(self.stream_buffers.keys()):
self.priority.unblock(stream_id)
elif stream_id is not None and stream_id in self.stream_buffers:
self.priority.unblock(stream_id)
await self.has_data.set()
async def _priority_updated(self, event: h2.events.PriorityUpdated) -> None:
try:
self.priority.reprioritize(
stream_id=event.stream_id,
depends_on=event.depends_on or None,
weight=event.weight,
exclusive=event.exclusive,
)
except priority.MissingStreamError:
# Received PRIORITY frame before HEADERS frame
self.priority.insert_stream(
stream_id=event.stream_id,
depends_on=event.depends_on or None,
weight=event.weight,
exclusive=event.exclusive,
)
self.priority.block(event.stream_id)
await self.has_data.set()
async def _create_stream(self, request: h2.events.RequestReceived) -> None:
for name, value in request.headers:
if name == b":method":
method = value.decode("ascii").upper()
elif name == b":path":
raw_path = value
if method == "CONNECT":
self.streams[request.stream_id] = WSStream(
self.app,
self.config,
self.context,
self.ssl,
self.client,
self.server,
self.stream_send,
request.stream_id,
)
else:
self.streams[request.stream_id] = HTTPStream(
self.app,
self.config,
self.context,
self.ssl,
self.client,
self.server,
self.stream_send,
request.stream_id,
)
self.stream_buffers[request.stream_id] = StreamBuffer(self.context.event_class)
try:
self.priority.insert_stream(request.stream_id)
except priority.DuplicateStreamError:
# Recieved PRIORITY frame before HEADERS frame
pass
else:
self.priority.block(request.stream_id)
await self.streams[request.stream_id].handle(
Request(
stream_id=request.stream_id,
headers=filter_pseudo_headers(request.headers),
http_version="2",
method=method,
raw_path=raw_path,
)
)
async def _create_server_push(
self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]]
) -> None:
push_stream_id = self.connection.get_next_available_stream_id()
request_headers = [(b":method", b"GET"), (b":path", path)]
request_headers.extend(headers)
request_headers.extend(self.config.response_headers("h2"))
try:
self.connection.push_stream(
stream_id=stream_id,
promised_stream_id=push_stream_id,
request_headers=request_headers,
)
await self._flush()
except h2.exceptions.ProtocolError:
# Client does not accept push promises or we are trying to
# push on a push promises request.
pass
else:
event = h2.events.RequestReceived()
event.stream_id = push_stream_id
event.headers = request_headers
await self._create_stream(event)
await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))
async def _close_stream(self, stream_id: int) -> None:
if stream_id in self.streams:
stream = self.streams.pop(stream_id)
await stream.handle(StreamClosed(stream_id=stream_id))
await self.has_data.set()

View File

@@ -0,0 +1,138 @@
from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Union
from aioquic.h3.connection import H3Connection
from aioquic.h3.events import DataReceived, HeadersReceived
from aioquic.h3.exceptions import NoAvailablePushIDError
from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import QuicEvent
from .events import (
Body,
Data,
EndBody,
EndData,
Event as StreamEvent,
Request,
Response,
StreamClosed,
)
from .http_stream import HTTPStream
from .ws_stream import WSStream
from ..config import Config
from ..typing import ASGIFramework, Context
from ..utils import filter_pseudo_headers
class H3Protocol:
def __init__(
self,
app: ASGIFramework,
config: Config,
context: Context,
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
quic: QuicConnection,
send: Callable[[], Awaitable[None]],
) -> None:
self.app = app
self.client = client
self.config = config
self.context = context
self.connection = H3Connection(quic)
self.send = send
self.server = server
self.streams: Dict[int, Union[HTTPStream, WSStream]] = {}
async def handle(self, quic_event: QuicEvent) -> None:
for event in self.connection.handle_event(quic_event):
if isinstance(event, HeadersReceived):
await self._create_stream(event)
if event.stream_ended:
await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))
elif isinstance(event, DataReceived):
await self.streams[event.stream_id].handle(
Body(stream_id=event.stream_id, data=event.data)
)
if event.stream_ended:
await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))
async def stream_send(self, event: StreamEvent) -> None:
if isinstance(event, Response):
self.connection.send_headers(
event.stream_id,
[(b":status", b"%d" % event.status_code)]
+ event.headers
+ self.config.response_headers("h3"),
)
await self.send()
elif isinstance(event, (Body, Data)):
self.connection.send_data(event.stream_id, event.data, False)
await self.send()
elif isinstance(event, (EndBody, EndData)):
self.connection.send_data(event.stream_id, b"", True)
await self.send()
elif isinstance(event, StreamClosed):
pass # ??
elif isinstance(event, Request):
await self._create_server_push(event.stream_id, event.raw_path, event.headers)
async def _create_stream(self, request: HeadersReceived) -> None:
for name, value in request.headers:
if name == b":method":
method = value.decode("ascii").upper()
elif name == b":path":
raw_path = value
if method == "CONNECT":
self.streams[request.stream_id] = WSStream(
self.app,
self.config,
self.context,
True,
self.client,
self.server,
self.stream_send,
request.stream_id,
)
else:
self.streams[request.stream_id] = HTTPStream(
self.app,
self.config,
self.context,
True,
self.client,
self.server,
self.stream_send,
request.stream_id,
)
await self.streams[request.stream_id].handle(
Request(
stream_id=request.stream_id,
headers=filter_pseudo_headers(request.headers),
http_version="3",
method=method,
raw_path=raw_path,
)
)
async def _create_server_push(
self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]]
) -> None:
request_headers = [(b":method", b"GET"), (b":path", path)]
request_headers.extend(headers)
request_headers.extend(self.config.response_headers("h3"))
try:
push_stream_id = self.connection.send_push_promise(
stream_id=stream_id, headers=request_headers
)
except NoAvailablePushIDError:
# Client does not accept push promises or we are trying to
# push on a push promises request.
pass
else:
event = HeadersReceived(
stream_id=push_stream_id, stream_ended=True, headers=request_headers
)
await self._create_stream(event)
await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id))

View File

@@ -0,0 +1,175 @@
from enum import auto, Enum
from time import time
from typing import Awaitable, Callable, Optional, Tuple
from urllib.parse import unquote
from .events import Body, EndBody, Event, Request, Response, StreamClosed
from ..config import Config
from ..typing import ASGIFramework, ASGISendEvent, Context, HTTPResponseStartEvent, HTTPScope
from ..utils import build_and_validate_headers, suppress_body, UnexpectedMessage, valid_server_name
PUSH_VERSIONS = {"2", "3"}
class ASGIHTTPState(Enum):
# The ASGI Spec is clear that a response should not start till the
# framework has sent at least one body message hence why this
# state tracking is required.
REQUEST = auto()
RESPONSE = auto()
CLOSED = auto()
class HTTPStream:
def __init__(
self,
app: ASGIFramework,
config: Config,
context: Context,
ssl: bool,
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
send: Callable[[Event], Awaitable[None]],
stream_id: int,
) -> None:
self.app = app
self.client = client
self.closed = False
self.config = config
self.context = context
self.response: HTTPResponseStartEvent
self.scope: HTTPScope
self.send = send
self.scheme = "https" if ssl else "http"
self.server = server
self.start_time: float
self.state = ASGIHTTPState.REQUEST
self.stream_id = stream_id
@property
def idle(self) -> bool:
return False
async def handle(self, event: Event) -> None:
if self.closed:
return
elif isinstance(event, Request):
self.start_time = time()
path, _, query_string = event.raw_path.partition(b"?")
self.scope = {
"type": "http",
"http_version": event.http_version,
"asgi": {"spec_version": "2.1"},
"method": event.method,
"scheme": self.scheme,
"path": unquote(path.decode("ascii")),
"raw_path": path,
"query_string": query_string,
"root_path": self.config.root_path,
"headers": event.headers,
"client": self.client,
"server": self.server,
"extensions": {},
}
if event.http_version in PUSH_VERSIONS:
self.scope["extensions"]["http.response.push"] = {}
if valid_server_name(self.config, event):
self.app_put = await self.context.spawn_app(
self.app, self.config, self.scope, self.app_send
)
else:
await self._send_error_response(404)
self.closed = True
elif isinstance(event, Body):
await self.app_put(
{"type": "http.request", "body": bytes(event.data), "more_body": True}
)
elif isinstance(event, EndBody):
await self.app_put({"type": "http.request", "body": b"", "more_body": False})
elif isinstance(event, StreamClosed):
self.closed = True
if self.app_put is not None:
await self.app_put({"type": "http.disconnect"}) # type: ignore
async def app_send(self, message: Optional[ASGISendEvent]) -> None:
if self.closed:
# Allow app to finish after close
return
if message is None: # ASGI App has finished sending messages
# Cleanup if required
if self.state == ASGIHTTPState.REQUEST:
await self._send_error_response(500)
await self.send(StreamClosed(stream_id=self.stream_id))
else:
if message["type"] == "http.response.start" and self.state == ASGIHTTPState.REQUEST:
self.response = message
elif (
message["type"] == "http.response.push"
and self.scope["http_version"] in PUSH_VERSIONS
):
if not isinstance(message["path"], str):
raise TypeError(f"{message['path']} should be a str")
headers = [(b":scheme", self.scope["scheme"].encode())]
for name, value in self.scope["headers"]:
if name == b"host":
headers.append((b":authority", value))
headers.extend(build_and_validate_headers(message["headers"]))
await self.send(
Request(
stream_id=self.stream_id,
headers=headers,
http_version=self.scope["http_version"],
method="GET",
raw_path=message["path"].encode(),
)
)
elif message["type"] == "http.response.body" and self.state in {
ASGIHTTPState.REQUEST,
ASGIHTTPState.RESPONSE,
}:
if self.state == ASGIHTTPState.REQUEST:
headers = build_and_validate_headers(self.response.get("headers", []))
await self.send(
Response(
stream_id=self.stream_id,
headers=headers,
status_code=int(self.response["status"]),
)
)
self.state = ASGIHTTPState.RESPONSE
if (
not suppress_body(self.scope["method"], int(self.response["status"]))
and message.get("body", b"") != b""
):
await self.send(
Body(stream_id=self.stream_id, data=bytes(message.get("body", b"")))
)
if not message.get("more_body", False):
if self.state != ASGIHTTPState.CLOSED:
self.state = ASGIHTTPState.CLOSED
await self.config.log.access(
self.scope, self.response, time() - self.start_time
)
await self.send(EndBody(stream_id=self.stream_id))
await self.send(StreamClosed(stream_id=self.stream_id))
else:
raise UnexpectedMessage(self.state, message["type"])
async def _send_error_response(self, status_code: int) -> None:
await self.send(
Response(
stream_id=self.stream_id,
headers=[(b"content-length", b"0"), (b"connection", b"close")],
status_code=status_code,
)
)
await self.send(EndBody(stream_id=self.stream_id))
self.state = ASGIHTTPState.CLOSED
await self.config.log.access(
self.scope, {"status": status_code, "headers": []}, time() - self.start_time
)

View File

@@ -0,0 +1,125 @@
from functools import partial
from typing import Awaitable, Callable, Dict, Optional, Tuple
from aioquic.buffer import Buffer
from aioquic.h3.connection import H3_ALPN
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import (
ConnectionIdIssued,
ConnectionIdRetired,
ConnectionTerminated,
ProtocolNegotiated,
)
from aioquic.quic.packet import (
encode_quic_version_negotiation,
PACKET_TYPE_INITIAL,
pull_quic_header,
)
from .h3 import H3Protocol
from ..config import Config
from ..events import Closed, Event, RawData
from ..typing import ASGIFramework, Context
class QuicProtocol:
def __init__(
self,
app: ASGIFramework,
config: Config,
context: Context,
server: Optional[Tuple[str, int]],
send: Callable[[Event], Awaitable[None]],
) -> None:
self.app = app
self.config = config
self.context = context
self.connections: Dict[bytes, QuicConnection] = {}
self.http_connections: Dict[QuicConnection, H3Protocol] = {}
self.send = send
self.server = server
self.quic_config = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=False)
self.quic_config.load_cert_chain(certfile=config.certfile, keyfile=config.keyfile)
async def handle(self, event: Event) -> None:
if isinstance(event, RawData):
try:
header = pull_quic_header(Buffer(data=event.data), host_cid_length=8)
except ValueError:
return
if (
header.version is not None
and header.version not in self.quic_config.supported_versions
):
data = encode_quic_version_negotiation(
source_cid=header.destination_cid,
destination_cid=header.source_cid,
supported_versions=self.quic_config.supported_versions,
)
await self.send(RawData(data=data, address=event.address))
return
connection = self.connections.get(header.destination_cid)
if (
connection is None
and len(event.data) >= 1200
and header.packet_type == PACKET_TYPE_INITIAL
):
connection = QuicConnection(
configuration=self.quic_config,
original_destination_connection_id=header.destination_cid,
)
self.connections[header.destination_cid] = connection
self.connections[connection.host_cid] = connection
if connection is not None:
connection.receive_datagram(event.data, event.address, now=self.context.time())
await self._handle_events(connection, event.address)
elif isinstance(event, Closed):
pass
async def send_all(self, connection: QuicConnection) -> None:
for data, address in connection.datagrams_to_send(now=self.context.time()):
await self.send(RawData(data=data, address=address))
async def _handle_events(
self, connection: QuicConnection, client: Optional[Tuple[str, int]] = None
) -> None:
event = connection.next_event()
while event is not None:
if isinstance(event, ConnectionTerminated):
pass
elif isinstance(event, ProtocolNegotiated):
self.http_connections[connection] = H3Protocol(
self.app,
self.config,
self.context,
client,
self.server,
connection,
partial(self.send_all, connection),
)
elif isinstance(event, ConnectionIdIssued):
self.connections[event.connection_id] = connection
elif isinstance(event, ConnectionIdRetired):
del self.connections[event.connection_id]
if connection in self.http_connections:
await self.http_connections[connection].handle(event)
event = connection.next_event()
await self.send_all(connection)
timer = connection.get_timer()
if timer is not None:
self.context.spawn(self._handle_timer, timer, connection)
async def _handle_timer(self, timer: float, connection: QuicConnection) -> None:
wait = max(0, timer - self.context.time())
await self.context.sleep(wait)
if connection._close_at is not None:
connection.handle_timer(now=self.context.time())
await self._handle_events(connection, None)

View File

@@ -0,0 +1,349 @@
from enum import auto, Enum
from time import time
from typing import Awaitable, Callable, List, Optional, Tuple, Union
from urllib.parse import unquote
from wsproto.connection import Connection, ConnectionState, ConnectionType
from wsproto.events import (
BytesMessage,
CloseConnection,
Event as WSProtoEvent,
Message,
Ping,
TextMessage,
)
from wsproto.extensions import Extension, PerMessageDeflate
from wsproto.frame_protocol import CloseReason
from wsproto.handshake import server_extensions_handshake, WEBSOCKET_VERSION
from wsproto.utilities import generate_accept_token, split_comma_header
from .events import Body, Data, EndBody, EndData, Event, Request, Response, StreamClosed
from ..config import Config
from ..typing import (
ASGIFramework,
ASGISendEvent,
Context,
WebsocketAcceptEvent,
WebsocketResponseBodyEvent,
WebsocketResponseStartEvent,
WebsocketScope,
)
from ..utils import build_and_validate_headers, suppress_body, UnexpectedMessage, valid_server_name
class ASGIWebsocketState(Enum):
# Hypercorn supports the ASGI websocket HTTP response extension,
# which allows HTTP responses rather than acceptance.
HANDSHAKE = auto()
CONNECTED = auto()
RESPONSE = auto()
CLOSED = auto()
HTTPCLOSED = auto()
class FrameTooLarge(Exception):
pass
class Handshake:
def __init__(self, headers: List[Tuple[bytes, bytes]], http_version: str) -> None:
self.http_version = http_version
self.connection_tokens: Optional[List[str]] = None
self.extensions: Optional[List[str]] = None
self.key: Optional[bytes] = None
self.subprotocols: Optional[List[str]] = None
self.upgrade: Optional[bytes] = None
self.version: Optional[bytes] = None
for name, value in headers:
name = name.lower()
if name == b"connection":
self.connection_tokens = split_comma_header(value)
elif name == b"sec-websocket-extensions":
self.extensions = split_comma_header(value)
elif name == b"sec-websocket-key":
self.key = value
elif name == b"sec-websocket-protocol":
self.subprotocols = split_comma_header(value)
elif name == b"sec-websocket-version":
self.version = value
elif name == b"upgrade":
self.upgrade = value
def is_valid(self) -> bool:
if self.http_version < "1.1":
return False
elif self.http_version == "1.1":
if self.key is None:
return False
if self.connection_tokens is None or not any(
token.lower() == "upgrade" for token in self.connection_tokens
):
return False
if self.upgrade.lower() != b"websocket":
return False
if self.version != WEBSOCKET_VERSION:
return False
return True
def accept(
self, subprotocol: Optional[str]
) -> Tuple[int, List[Tuple[bytes, bytes]], Connection]:
headers = []
if subprotocol is not None:
if subprotocol not in self.subprotocols:
raise Exception("Invalid Subprotocol")
else:
headers.append((b"sec-websocket-protocol", subprotocol.encode()))
extensions: List[Extension] = [PerMessageDeflate()]
accepts = None
if False and self.extensions is not None:
accepts = server_extensions_handshake(self.extensions, extensions)
if accepts:
headers.append((b"sec-websocket-extensions", accepts))
if self.key is not None:
headers.append((b"sec-websocket-accept", generate_accept_token(self.key)))
status_code = 200
if self.http_version == "1.1":
headers.extend([(b"upgrade", b"WebSocket"), (b"connection", b"Upgrade")])
status_code = 101
return status_code, headers, Connection(ConnectionType.SERVER, extensions)
class WebsocketBuffer:
def __init__(self, max_length: int) -> None:
self.value: Optional[Union[bytes, str]] = None
self.max_length = max_length
def extend(self, event: Message) -> None:
if self.value is None:
if isinstance(event, TextMessage):
self.value = ""
else:
self.value = b""
self.value += event.data
if len(self.value) > self.max_length:
raise FrameTooLarge()
def clear(self) -> None:
self.value = None
def to_message(self) -> dict:
return {
"type": "websocket.receive",
"bytes": self.value if isinstance(self.value, bytes) else None,
"text": self.value if isinstance(self.value, str) else None,
}
class WSStream:
def __init__(
self,
app: ASGIFramework,
config: Config,
context: Context,
ssl: bool,
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
send: Callable[[Event], Awaitable[None]],
stream_id: int,
) -> None:
self.app = app
self.app_put: Optional[Callable] = None
self.buffer = WebsocketBuffer(config.websocket_max_message_size)
self.client = client
self.closed = False
self.config = config
self.context = context
self.response: WebsocketResponseStartEvent
self.scope: WebsocketScope
self.send = send
# RFC 8441 for HTTP/2 says use http or https, ASGI says ws or wss
self.scheme = "wss" if ssl else "ws"
self.server = server
self.start_time: float
self.state = ASGIWebsocketState.HANDSHAKE
self.stream_id = stream_id
self.connection: Connection
self.handshake: Handshake
@property
def idle(self) -> bool:
return self.state in {ASGIWebsocketState.CLOSED, ASGIWebsocketState.HTTPCLOSED}
async def handle(self, event: Event) -> None:
if self.closed:
return
elif isinstance(event, Request):
self.start_time = time()
self.handshake = Handshake(event.headers, event.http_version)
path, _, query_string = event.raw_path.partition(b"?")
self.scope = {
"type": "websocket",
"asgi": {"spec_version": "2.1"},
"scheme": self.scheme,
"http_version": event.http_version,
"path": unquote(path.decode("ascii")),
"raw_path": path,
"query_string": query_string,
"root_path": self.config.root_path,
"headers": event.headers,
"client": self.client,
"server": self.server,
"subprotocols": self.handshake.subprotocols or [],
"extensions": {"websocket.http.response": {}},
}
if not valid_server_name(self.config, event):
await self._send_error_response(404)
self.closed = True
elif not self.handshake.is_valid():
await self._send_error_response(400)
self.closed = True
else:
self.app_put = await self.context.spawn_app(
self.app, self.config, self.scope, self.app_send
)
await self.app_put({"type": "websocket.connect"}) # type: ignore
elif isinstance(event, (Body, Data)):
self.connection.receive_data(event.data)
await self._handle_events()
elif isinstance(event, StreamClosed):
self.closed = True
if self.app_put is not None:
if self.state in {ASGIWebsocketState.HTTPCLOSED, ASGIWebsocketState.CLOSED}:
code = CloseReason.NORMAL_CLOSURE.value
else:
code = CloseReason.ABNORMAL_CLOSURE.value
await self.app_put({"type": "websocket.disconnect", "code": code})
async def app_send(self, message: Optional[ASGISendEvent]) -> None:
if self.closed:
# Allow app to finish after close
return
if message is None: # ASGI App has finished sending messages
# Cleanup if required
if self.state == ASGIWebsocketState.HANDSHAKE:
await self._send_error_response(500)
await self.config.log.access(
self.scope, {"status": 500, "headers": []}, time() - self.start_time
)
elif self.state == ASGIWebsocketState.CONNECTED:
await self._send_wsproto_event(CloseConnection(code=CloseReason.ABNORMAL_CLOSURE))
await self.send(StreamClosed(stream_id=self.stream_id))
else:
if message["type"] == "websocket.accept" and self.state == ASGIWebsocketState.HANDSHAKE:
await self._accept(message)
elif (
message["type"] == "websocket.http.response.start"
and self.state == ASGIWebsocketState.HANDSHAKE
):
self.response = message
elif message["type"] == "websocket.http.response.body" and self.state in {
ASGIWebsocketState.HANDSHAKE,
ASGIWebsocketState.RESPONSE,
}:
await self._send_rejection(message)
elif message["type"] == "websocket.send" and self.state == ASGIWebsocketState.CONNECTED:
event: WSProtoEvent
if message.get("bytes") is not None:
event = BytesMessage(data=bytes(message["bytes"]))
elif not isinstance(message["text"], str):
raise TypeError(f"{message['text']} should be a str")
else:
event = TextMessage(data=message["text"])
await self._send_wsproto_event(event)
elif (
message["type"] == "websocket.close" and self.state == ASGIWebsocketState.HANDSHAKE
):
self.state = ASGIWebsocketState.HTTPCLOSED
await self._send_error_response(403)
elif message["type"] == "websocket.close":
self.state = ASGIWebsocketState.CLOSED
await self._send_wsproto_event(
CloseConnection(code=int(message.get("code", CloseReason.NORMAL_CLOSURE)))
)
await self.send(EndData(stream_id=self.stream_id))
else:
raise UnexpectedMessage(self.state, message["type"])
async def _handle_events(self) -> None:
for event in self.connection.events():
if isinstance(event, Message):
try:
self.buffer.extend(event)
except FrameTooLarge:
await self._send_wsproto_event(
CloseConnection(code=CloseReason.MESSAGE_TOO_BIG)
)
break
if event.message_finished:
await self.app_put(self.buffer.to_message())
self.buffer.clear()
elif isinstance(event, Ping):
await self._send_wsproto_event(event.response())
elif isinstance(event, CloseConnection):
if self.connection.state == ConnectionState.REMOTE_CLOSING:
await self._send_wsproto_event(event.response())
await self.send(StreamClosed(stream_id=self.stream_id))
async def _send_error_response(self, status_code: int) -> None:
await self.send(
Response(
stream_id=self.stream_id,
status_code=status_code,
headers=[(b"content-length", b"0"), (b"connection", b"close")],
)
)
await self.send(EndBody(stream_id=self.stream_id))
await self.config.log.access(
self.scope, {"status": status_code, "headers": []}, time() - self.start_time
)
async def _send_wsproto_event(self, event: WSProtoEvent) -> None:
data = self.connection.send(event)
await self.send(Data(stream_id=self.stream_id, data=data))
async def _accept(self, message: WebsocketAcceptEvent) -> None:
self.state = ASGIWebsocketState.CONNECTED
status_code, headers, self.connection = self.handshake.accept(message.get("subprotocol"))
await self.send(
Response(stream_id=self.stream_id, status_code=status_code, headers=headers)
)
await self.config.log.access(
self.scope, {"status": status_code, "headers": []}, time() - self.start_time
)
if self.config.websocket_ping_interval is not None:
self.context.spawn(self._send_pings)
async def _send_rejection(self, message: WebsocketResponseBodyEvent) -> None:
body_suppressed = suppress_body("GET", self.response["status"])
if self.state == ASGIWebsocketState.HANDSHAKE:
headers = build_and_validate_headers(self.response["headers"])
await self.send(
Response(
stream_id=self.stream_id,
status_code=int(self.response["status"]),
headers=headers,
)
)
self.state = ASGIWebsocketState.RESPONSE
if not body_suppressed:
await self.send(Body(stream_id=self.stream_id, data=bytes(message.get("body", b""))))
if not message.get("more_body", False):
self.state = ASGIWebsocketState.HTTPCLOSED
await self.send(EndBody(stream_id=self.stream_id))
await self.config.log.access(self.scope, self.response, time() - self.start_time)
async def _send_pings(self) -> None:
while not self.closed:
await self._send_wsproto_event(Ping())
await self.context.sleep(self.config.websocket_ping_interval)

View File

@@ -0,0 +1 @@
Marker

View File

@@ -0,0 +1,80 @@
import platform
import random
import signal
import time
from multiprocessing import Event, Process
from typing import Any
from .config import Config
from .typing import WorkerFunc
from .utils import write_pid_file
def run(config: Config) -> None:
if config.pid_path is not None:
write_pid_file(config.pid_path)
worker_func: WorkerFunc
if config.worker_class == "asyncio":
from .asyncio.run import asyncio_worker
worker_func = asyncio_worker
elif config.worker_class == "uvloop":
from .asyncio.run import uvloop_worker
worker_func = uvloop_worker
elif config.worker_class == "trio":
from .trio.run import trio_worker
worker_func = trio_worker
else:
raise ValueError(f"No worker of class {config.worker_class} exists")
if config.workers == 1:
worker_func(config)
else:
run_multiple(config, worker_func)
def run_multiple(config: Config, worker_func: WorkerFunc) -> None:
if config.use_reloader:
raise RuntimeError("Reloader can only be used with a single worker")
sockets = config.create_sockets()
processes = []
# Ignore SIGINT before creating the processes, so that they
# inherit the signal handling. This means that the shutdown
# function controls the shutdown.
signal.signal(signal.SIGINT, signal.SIG_IGN)
shutdown_event = Event()
for _ in range(config.workers):
process = Process(
target=worker_func,
kwargs={"config": config, "shutdown_event": shutdown_event, "sockets": sockets},
)
process.daemon = True
process.start()
processes.append(process)
if platform.system() == "Windows":
time.sleep(0.1 * random.random())
def shutdown(*args: Any) -> None:
shutdown_event.set()
for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}:
if hasattr(signal, signal_name):
signal.signal(getattr(signal, signal_name), shutdown)
for process in processes:
process.join()
for process in processes:
process.terminate()
for sock in sockets.secure_sockets:
sock.close()
for sock in sockets.insecure_sockets:
sock.close()

View File

@@ -0,0 +1,93 @@
from typing import Any, TYPE_CHECKING
from .logging import Logger
if TYPE_CHECKING:
from .config import Config
from .typing import ResponseSummary, WWWScope
METRIC_VAR = "metric"
VALUE_VAR = "value"
MTYPE_VAR = "mtype"
GAUGE_TYPE = "gauge"
COUNTER_TYPE = "counter"
HISTOGRAM_TYPE = "histogram"
class StatsdLogger(Logger):
def __init__(self, config: "Config") -> None:
super().__init__(config)
self.dogstatsd_tags = config.dogstatsd_tags
self.prefix = config.statsd_prefix
if len(self.prefix) and self.prefix[-1] != ".":
self.prefix += "."
async def critical(self, message: str, *args: Any, **kwargs: Any) -> None:
await super().critical(message, *args, **kwargs)
await self.increment("hypercorn.log.critical", 1)
async def error(self, message: str, *args: Any, **kwargs: Any) -> None:
await super().error(message, *args, **kwargs)
self.increment("hypercorn.log.error", 1)
async def warning(self, message: str, *args: Any, **kwargs: Any) -> None:
await super().warning(message, *args, **kwargs)
self.increment("hypercorn.log.warning", 1)
async def info(self, message: str, *args: Any, **kwargs: Any) -> None:
await super().info(message, *args, **kwargs)
async def debug(self, message: str, *args: Any, **kwargs: Any) -> None:
await super().debug(message, *args, **kwargs)
async def exception(self, message: str, *args: Any, **kwargs: Any) -> None:
await super().exception(message, *args, **kwargs)
await self.increment("hypercorn.log.exception", 1)
async def log(self, level: int, message: str, *args: Any, **kwargs: Any) -> None:
try:
extra = kwargs.get("extra", None)
if extra is not None:
metric = extra.get(METRIC_VAR, None)
value = extra.get(VALUE_VAR, None)
type_ = extra.get(MTYPE_VAR, None)
if metric and value and type_:
if type_ == GAUGE_TYPE:
await self.gauge(metric, value)
elif type_ == COUNTER_TYPE:
await self.increment(metric, value)
elif type_ == HISTOGRAM_TYPE:
await self.histogram(metric, value)
if message:
await super().log(level, message, *args, **kwargs)
except Exception:
await super().warning("Failed to log to statsd", exc_info=True)
async def access(
self, request: "WWWScope", response: "ResponseSummary", request_time: float
) -> None:
await super().access(request, response, request_time)
await self.histogram("hypercorn.request.duration", request_time * 1_000)
await self.increment("hypercorn.requests", 1)
await self.increment(f"hypercorn.request.status.{response['status']}", 1)
async def gauge(self, name: str, value: int) -> None:
await self._send(f"{self.prefix}{name}:{value}|g")
async def increment(self, name: str, value: int, sampling_rate: float = 1.0) -> None:
await self._send(f"{self.prefix}{name}:{value}|c|@{sampling_rate}")
async def decrement(self, name: str, value: int, sampling_rate: float = 1.0) -> None:
await self._send(f"{self.prefix}{name}:-{value}|c|@{sampling_rate}")
async def histogram(self, name: str, value: float) -> None:
await self._send(f"{self.prefix}{name}:{value}|ms")
async def _send(self, message: str) -> None:
if self.dogstatsd_tags:
message = f"{message}|#{self.dogstatsd_tags}"
await self._socket_send(message.encode("ascii"))
async def _socket_send(self, message: bytes) -> None:
raise NotImplementedError()

View File

@@ -0,0 +1,42 @@
import warnings
from typing import Awaitable, Callable, Optional
import trio
from .run import worker_serve
from ..config import Config
from ..typing import ASGIFramework
async def serve(
app: ASGIFramework,
config: Config,
*,
shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
) -> None:
"""Serve an ASGI framework app given the config.
This allows for a programmatic way to serve an ASGI framework, it
can be used via,
.. code-block:: python
trio.run(serve, app, config)
It is assumed that the event-loop is configured before calling
this function, therefore configuration values that relate to loop
setup or process setup are ignored.
Arguments:
app: The ASGI application to serve.
config: A Hypercorn configuration object.
shutdown_trigger: This should return to trigger a graceful
shutdown.
"""
if config.debug:
warnings.warn("The config `debug` has no affect when using serve", Warning)
if config.workers != 1:
warnings.warn("The config `workers` has no affect when using serve", Warning)
await worker_serve(app, config, shutdown_trigger=shutdown_trigger, task_status=task_status)

View File

@@ -0,0 +1,83 @@
from typing import Any, Awaitable, Callable, Optional, Type, Union
import trio
from ..config import Config
from ..typing import (
ASGIFramework,
ASGIReceiveCallable,
ASGIReceiveEvent,
ASGISendEvent,
Event,
Scope,
)
from ..utils import invoke_asgi
class EventWrapper:
def __init__(self) -> None:
self._event = trio.Event()
async def clear(self) -> None:
self._event = trio.Event()
async def wait(self) -> None:
await self._event.wait()
async def set(self) -> None:
self._event.set()
async def _handle(
app: ASGIFramework,
config: Config,
scope: Scope,
receive: ASGIReceiveCallable,
send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
) -> None:
try:
await invoke_asgi(app, scope, receive, send)
except trio.Cancelled:
raise
except trio.MultiError as error:
errors = trio.MultiError.filter(
lambda exc: None if isinstance(exc, trio.Cancelled) else exc, root_exc=error
)
if errors is not None:
await config.log.exception("Error in ASGI Framework")
await send(None)
else:
raise
except Exception:
await config.log.exception("Error in ASGI Framework")
finally:
await send(None)
class Context:
event_class: Type[Event] = EventWrapper
def __init__(self, nursery: trio._core._run.Nursery) -> None:
self.nursery = nursery
async def spawn_app(
self,
app: ASGIFramework,
config: Config,
scope: Scope,
send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
) -> Callable[[ASGIReceiveEvent], Awaitable[None]]:
app_send_channel, app_receive_channel = trio.open_memory_channel(config.max_app_queue_size)
self.nursery.start_soon(_handle, app, config, scope, app_receive_channel.receive, send)
return app_send_channel.send
def spawn(self, func: Callable, *args: Any) -> None:
self.nursery.start_soon(func, *args)
@staticmethod
async def sleep(wait: Union[float, int]) -> None:
return await trio.sleep(wait)
@staticmethod
def time() -> float:
return trio.current_time()

View File

@@ -0,0 +1,79 @@
import trio
from ..config import Config
from ..typing import ASGIFramework, ASGIReceiveEvent, ASGISendEvent, LifespanScope
from ..utils import invoke_asgi, LifespanFailure, LifespanTimeout
class UnexpectedMessage(Exception):
pass
class Lifespan:
def __init__(self, app: ASGIFramework, config: Config) -> None:
self.app = app
self.config = config
self.startup = trio.Event()
self.shutdown = trio.Event()
self.app_send_channel, self.app_receive_channel = trio.open_memory_channel(
config.max_app_queue_size
)
self.supported = True
async def handle_lifespan(
self, *, task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED
) -> None:
task_status.started()
scope: LifespanScope = {"type": "lifespan", "asgi": {"spec_version": "2.0"}}
try:
await invoke_asgi(self.app, scope, self.asgi_receive, self.asgi_send)
except LifespanFailure:
# Lifespan failures should crash the server
raise
except Exception:
self.supported = False
await self.config.log.exception(
"ASGI Framework Lifespan error, continuing without Lifespan support"
)
finally:
self.startup.set()
self.shutdown.set()
await self.app_send_channel.aclose()
await self.app_receive_channel.aclose()
async def wait_for_startup(self) -> None:
if not self.supported:
return
await self.app_send_channel.send({"type": "lifespan.startup"})
try:
with trio.fail_after(self.config.startup_timeout):
await self.startup.wait()
except trio.TooSlowError as error:
raise LifespanTimeout("startup") from error
async def wait_for_shutdown(self) -> None:
if not self.supported:
return
await self.app_send_channel.send({"type": "lifespan.shutdown"})
try:
with trio.fail_after(self.config.shutdown_timeout):
await self.shutdown.wait()
except trio.TooSlowError as error:
raise LifespanTimeout("startup") from error
async def asgi_receive(self) -> ASGIReceiveEvent:
return await self.app_receive_channel.receive()
async def asgi_send(self, message: ASGISendEvent) -> None:
if message["type"] == "lifespan.startup.complete":
self.startup.set()
elif message["type"] == "lifespan.shutdown.complete":
self.shutdown.set()
elif message["type"] == "lifespan.startup.failed":
raise LifespanFailure("startup", message["message"])
elif message["type"] == "lifespan.shutdown.failed":
raise LifespanFailure("shutdown", message["message"])
else:
raise UnexpectedMessage(message["type"])

View File

@@ -0,0 +1,119 @@
from functools import partial
from multiprocessing.synchronize import Event as EventType
from typing import Awaitable, Callable, Optional
import trio
from .lifespan import Lifespan
from .statsd import StatsdLogger
from .tcp_server import TCPServer
from .udp_server import UDPServer
from ..config import Config, Sockets
from ..typing import ASGIFramework
from ..utils import (
check_multiprocess_shutdown_event,
load_application,
MustReloadException,
observe_changes,
raise_shutdown,
repr_socket_addr,
restart,
Shutdown,
)
async def worker_serve(
app: ASGIFramework,
config: Config,
*,
sockets: Optional[Sockets] = None,
shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
) -> None:
config.set_statsd_logger_class(StatsdLogger)
lifespan = Lifespan(app, config)
reload_ = False
async with trio.open_nursery() as lifespan_nursery:
await lifespan_nursery.start(lifespan.handle_lifespan)
await lifespan.wait_for_startup()
try:
async with trio.open_nursery() as nursery:
if config.use_reloader:
nursery.start_soon(observe_changes, trio.sleep)
if shutdown_trigger is not None:
nursery.start_soon(raise_shutdown, shutdown_trigger)
if sockets is None:
sockets = config.create_sockets()
for sock in sockets.secure_sockets:
sock.listen(config.backlog)
for sock in sockets.insecure_sockets:
sock.listen(config.backlog)
ssl_context = config.create_ssl_context()
listeners = []
binds = []
for sock in sockets.secure_sockets:
listeners.append(
trio.SSLListener(
trio.SocketListener(trio.socket.from_stdlib_socket(sock)),
ssl_context,
https_compatible=True,
)
)
bind = repr_socket_addr(sock.family, sock.getsockname())
binds.append(f"https://{bind}")
await config.log.info(f"Running on https://{bind} (CTRL + C to quit)")
for sock in sockets.insecure_sockets:
listeners.append(trio.SocketListener(trio.socket.from_stdlib_socket(sock)))
bind = repr_socket_addr(sock.family, sock.getsockname())
binds.append(f"http://{bind}")
await config.log.info(f"Running on http://{bind} (CTRL + C to quit)")
for sock in sockets.quic_sockets:
await nursery.start(UDPServer(app, config, sock, nursery).run)
bind = repr_socket_addr(sock.family, sock.getsockname())
await config.log.info(f"Running on https://{bind} (QUIC) (CTRL + C to quit)")
task_status.started(binds)
await trio.serve_listeners(
partial(TCPServer, app, config), listeners, handler_nursery=lifespan_nursery
)
except MustReloadException:
reload_ = True
except (Shutdown, KeyboardInterrupt):
pass
finally:
try:
await trio.sleep(config.graceful_timeout)
except (Shutdown, KeyboardInterrupt):
pass
await lifespan.wait_for_shutdown()
lifespan_nursery.cancel_scope.cancel()
if reload_:
restart()
def trio_worker(
config: Config, sockets: Optional[Sockets] = None, shutdown_event: Optional[EventType] = None
) -> None:
if sockets is not None:
for sock in sockets.secure_sockets:
sock.listen(config.backlog)
for sock in sockets.insecure_sockets:
sock.listen(config.backlog)
app = load_application(config.application_path)
shutdown_trigger = None
if shutdown_event is not None:
shutdown_trigger = partial(check_multiprocess_shutdown_event, shutdown_event, trio.sleep)
trio.run(partial(worker_serve, app, config, sockets=sockets, shutdown_trigger=shutdown_trigger))

View File

@@ -0,0 +1,14 @@
import trio
from ..config import Config
from ..statsd import StatsdLogger as Base
class StatsdLogger(Base):
def __init__(self, config: Config) -> None:
super().__init__(config)
self.address = config.statsd_host.rsplit(":", 1)
self.socket = trio.socket.socket(trio.socket.AF_INET, trio.socket.SOCK_DGRAM)
async def _socket_send(self, message: bytes) -> None:
await self.socket.sendto(message, self.address)

View File

@@ -0,0 +1,152 @@
from typing import Any, Callable, Generator, Optional
import trio
from .context import Context
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..protocol import ProtocolWrapper
from ..typing import ASGIFramework
from ..utils import parse_socket_addr
MAX_RECV = 2 ** 16
class EventWrapper:
def __init__(self) -> None:
self._event = trio.Event()
async def clear(self) -> None:
self._event = trio.Event()
async def wait(self) -> None:
await self._event.wait()
async def set(self) -> None:
self._event.set()
class TCPServer:
def __init__(self, app: ASGIFramework, config: Config, stream: trio.abc.Stream) -> None:
self.app = app
self.config = config
self.protocol: ProtocolWrapper
self.send_lock = trio.Lock()
self.timeout_lock = trio.Lock()
self.stream = stream
self._keep_alive_timeout_handle: Optional[trio.CancelScope] = None
def __await__(self) -> Generator[Any, None, None]:
return self.run().__await__()
async def run(self) -> None:
try:
try:
with trio.fail_after(self.config.ssl_handshake_timeout):
await self.stream.do_handshake()
except (trio.BrokenResourceError, trio.TooSlowError):
return # Handshake failed
alpn_protocol = self.stream.selected_alpn_protocol()
socket = self.stream.transport_stream.socket
ssl = True
except AttributeError: # Not SSL
alpn_protocol = "http/1.1"
socket = self.stream.socket
ssl = False
try:
client = parse_socket_addr(socket.family, socket.getpeername())
server = parse_socket_addr(socket.family, socket.getsockname())
async with trio.open_nursery() as nursery:
self.nursery = nursery
context = Context(nursery)
self.protocol = ProtocolWrapper(
self.app,
self.config,
context,
ssl,
client,
server,
self.protocol_send,
alpn_protocol,
)
await self.protocol.initiate()
await self._update_keep_alive_timeout()
await self._read_data()
except (trio.MultiError, OSError):
pass
finally:
await self._close()
async def protocol_send(self, event: Event) -> None:
if isinstance(event, RawData):
async with self.send_lock:
try:
with trio.CancelScope() as cancel_scope:
cancel_scope.shield = True
await self.stream.send_all(event.data)
except (trio.BrokenResourceError, trio.ClosedResourceError):
await self.protocol.handle(Closed())
elif isinstance(event, Closed):
await self._close()
await self.protocol.handle(Closed())
elif isinstance(event, Updated):
pass # Triggers the keep alive timeout update
await self._update_keep_alive_timeout()
async def _read_data(self) -> None:
while True:
try:
data = await self.stream.receive_some(MAX_RECV)
except (trio.ClosedResourceError, trio.BrokenResourceError):
await self.protocol.handle(Closed())
break
else:
if data == b"":
await self._update_keep_alive_timeout()
break
await self.protocol.handle(RawData(data))
await self._update_keep_alive_timeout()
async def _close(self) -> None:
try:
await self.stream.send_eof()
except (
trio.BrokenResourceError,
AttributeError,
trio.BusyResourceError,
trio.ClosedResourceError,
):
# They're already gone, nothing to do
# Or it is a SSL stream
pass
await self.stream.aclose()
async def _update_keep_alive_timeout(self) -> None:
async with self.timeout_lock:
if self._keep_alive_timeout_handle is not None:
self._keep_alive_timeout_handle.cancel()
self._keep_alive_timeout_handle = None
if self.protocol.idle:
self._keep_alive_timeout_handle = await self.nursery.start(
_call_later, self.config.keep_alive_timeout, self._timeout
)
async def _timeout(self) -> None:
await self.protocol.handle(Closed())
await self.stream.aclose()
async def _call_later(
timeout: float,
callback: Callable,
task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
) -> None:
cancel_scope = trio.CancelScope()
task_status.started(cancel_scope)
with cancel_scope:
await trio.sleep(timeout)
cancel_scope.shield = True
await callback()

View File

@@ -0,0 +1,40 @@
import trio
from .context import Context
from ..config import Config
from ..events import Event, RawData
from ..typing import ASGIFramework
from ..utils import parse_socket_addr
MAX_RECV = 2 ** 16
class UDPServer:
def __init__(
self,
app: ASGIFramework,
config: Config,
socket: trio.socket.socket,
nursery: trio._core._run.Nursery,
) -> None:
from ..protocol.quic import QuicProtocol # h3/Quic is an optional part of Hypercorn
self.app = app
self.config = config
self.nursery = nursery
self.socket = trio.socket.from_stdlib_socket(socket)
context = Context(nursery)
server = parse_socket_addr(socket.family, socket.getsockname())
self.protocol = QuicProtocol(self.app, self.config, context, server, self.protocol_send)
async def run(
self, task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED
) -> None:
task_status.started()
while True:
data, address = await self.socket.recvfrom(MAX_RECV)
await self.protocol.handle(RawData(data=data, address=address))
async def protocol_send(self, event: Event) -> None:
if isinstance(event, RawData):
await self.socket.sendto(event.data, event.address)

View File

@@ -0,0 +1,308 @@
from multiprocessing.synchronize import Event as EventType
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional, Tuple, Type, Union
import h2.events
import h11
# Till PEP 544 is accepted
try:
from typing import Literal, Protocol, TypedDict
except ImportError:
from typing_extensions import Literal, Protocol, TypedDict # type: ignore
from .config import Config, Sockets
H11SendableEvent = Union[h11.Data, h11.EndOfMessage, h11.InformationalResponse, h11.Response]
WorkerFunc = Callable[[Config, Optional[Sockets], Optional[EventType]], None]
class ASGIVersions(TypedDict, total=False):
spec_version: str
version: Union[Literal["2.0"], Literal["3.0"]]
class HTTPScope(TypedDict):
type: Literal["http"]
asgi: ASGIVersions
http_version: str
method: str
scheme: str
path: str
raw_path: bytes
query_string: bytes
root_path: str
headers: Iterable[Tuple[bytes, bytes]]
client: Optional[Tuple[str, int]]
server: Optional[Tuple[str, Optional[int]]]
extensions: Dict[str, dict]
class WebsocketScope(TypedDict):
type: Literal["websocket"]
asgi: ASGIVersions
http_version: str
scheme: str
path: str
raw_path: bytes
query_string: bytes
root_path: str
headers: Iterable[Tuple[bytes, bytes]]
client: Optional[Tuple[str, int]]
server: Optional[Tuple[str, Optional[int]]]
subprotocols: Iterable[str]
extensions: Dict[str, dict]
class LifespanScope(TypedDict):
type: Literal["lifespan"]
asgi: ASGIVersions
WWWScope = Union[HTTPScope, WebsocketScope]
Scope = Union[HTTPScope, WebsocketScope, LifespanScope]
class HTTPRequestEvent(TypedDict):
type: Literal["http.request"]
body: bytes
more_body: bool
class HTTPResponseStartEvent(TypedDict):
type: Literal["http.response.start"]
status: int
headers: Iterable[Tuple[bytes, bytes]]
class HTTPResponseBodyEvent(TypedDict):
type: Literal["http.response.body"]
body: bytes
more_body: bool
class HTTPServerPushEvent(TypedDict):
type: Literal["http.response.push"]
path: str
headers: Iterable[Tuple[bytes, bytes]]
class HTTPDisconnectEvent(TypedDict):
type: Literal["http.disconnect"]
class WebsocketConnectEvent(TypedDict):
type: Literal["websocket.connect"]
class WebsocketAcceptEvent(TypedDict):
type: Literal["websocket.accept"]
subprotocol: Optional[str]
headers: Iterable[Tuple[bytes, bytes]]
class WebsocketReceiveEvent(TypedDict):
type: Literal["websocket.receive"]
bytes: Optional[bytes]
text: Optional[str]
class WebsocketSendEvent(TypedDict):
type: Literal["websocket.send"]
bytes: Optional[bytes]
text: Optional[str]
class WebsocketResponseStartEvent(TypedDict):
type: Literal["websocket.http.response.start"]
status: int
headers: Iterable[Tuple[bytes, bytes]]
class WebsocketResponseBodyEvent(TypedDict):
type: Literal["websocket.http.response.body"]
body: bytes
more_body: bool
class WebsocketDisconnectEvent(TypedDict):
type: Literal["websocket.disconnect"]
code: int
class WebsocketCloseEvent(TypedDict):
type: Literal["websocket.close"]
code: int
class LifespanStartupEvent(TypedDict):
type: Literal["lifespan.startup"]
class LifespanShutdownEvent(TypedDict):
type: Literal["lifespan.shutdown"]
class LifespanStartupCompleteEvent(TypedDict):
type: Literal["lifespan.startup.complete"]
class LifespanStartupFailedEvent(TypedDict):
type: Literal["lifespan.startup.failed"]
message: str
class LifespanShutdownCompleteEvent(TypedDict):
type: Literal["lifespan.shutdown.complete"]
class LifespanShutdownFailedEvent(TypedDict):
type: Literal["lifespan.shutdown.failed"]
message: str
ASGIReceiveEvent = Union[
HTTPRequestEvent,
HTTPDisconnectEvent,
WebsocketConnectEvent,
WebsocketReceiveEvent,
WebsocketDisconnectEvent,
LifespanStartupEvent,
LifespanShutdownEvent,
]
ASGISendEvent = Union[
HTTPResponseStartEvent,
HTTPResponseBodyEvent,
HTTPServerPushEvent,
HTTPDisconnectEvent,
WebsocketAcceptEvent,
WebsocketSendEvent,
WebsocketResponseStartEvent,
WebsocketResponseBodyEvent,
WebsocketCloseEvent,
LifespanStartupCompleteEvent,
LifespanStartupFailedEvent,
LifespanShutdownCompleteEvent,
LifespanShutdownFailedEvent,
]
ASGIReceiveCallable = Callable[[], Awaitable[ASGIReceiveEvent]]
ASGISendCallable = Callable[[ASGISendEvent], Awaitable[None]]
class ASGI2Protocol(Protocol):
# Should replace with a Protocol when PEP 544 is accepted.
def __init__(self, scope: Scope) -> None:
...
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
...
ASGI2Framework = Type[ASGI2Protocol]
ASGI3Framework = Callable[
[
Scope,
ASGIReceiveCallable,
ASGISendCallable,
],
Awaitable[None],
]
ASGIFramework = Union[ASGI2Framework, ASGI3Framework]
class H2SyncStream(Protocol):
scope: dict
def data_received(self, data: bytes) -> None:
...
def ended(self) -> None:
...
def reset(self) -> None:
...
def close(self) -> None:
...
async def handle_request(
self,
event: h2.events.RequestReceived,
scheme: str,
client: Tuple[str, int],
server: Tuple[str, int],
) -> None:
...
class H2AsyncStream(Protocol):
scope: dict
async def data_received(self, data: bytes) -> None:
...
async def ended(self) -> None:
...
async def reset(self) -> None:
...
async def close(self) -> None:
...
async def handle_request(
self,
event: h2.events.RequestReceived,
scheme: str,
client: Tuple[str, int],
server: Tuple[str, int],
) -> None:
...
class Event(Protocol):
def __init__(self) -> None:
...
async def clear(self) -> None:
...
async def set(self) -> None:
...
async def wait(self) -> None:
...
class Context(Protocol):
event_class: Type[Event]
async def spawn_app(
self,
app: ASGIFramework,
config: Config,
scope: Scope,
send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
) -> Callable[[ASGIReceiveEvent], Awaitable[None]]:
...
def spawn(self, func: Callable, *args: Any) -> None:
...
@staticmethod
async def sleep(wait: Union[float, int]) -> None:
...
@staticmethod
def time() -> float:
...
class ResponseSummary(TypedDict):
status: int
headers: Iterable[Tuple[bytes, bytes]]

View File

@@ -0,0 +1,261 @@
import inspect
import os
import platform
import socket
import sys
from enum import Enum
from importlib import import_module
from multiprocessing.synchronize import Event as EventType
from pathlib import Path
from typing import (
Any,
Awaitable,
Callable,
cast,
Dict,
Iterable,
List,
Optional,
Tuple,
TYPE_CHECKING,
)
from .config import Config
from .typing import (
ASGI2Framework,
ASGI3Framework,
ASGIFramework,
ASGIReceiveCallable,
ASGISendCallable,
Scope,
)
if TYPE_CHECKING:
from .protocol.events import Request
class Shutdown(Exception):
pass
class MustReloadException(Exception):
pass
class NoAppException(Exception):
pass
class LifespanTimeout(Exception):
def __init__(self, stage: str) -> None:
super().__init__(
f"Timeout whilst awaiting {stage}. Your application may not support the ASGI Lifespan "
f"protocol correctly, alternatively the {stage}_timeout configuration is incorrect."
)
class LifespanFailure(Exception):
def __init__(self, stage: str, message: str) -> None:
super().__init__(f"Lifespan failure in {stage}. '{message}'")
class UnexpectedMessage(Exception):
def __init__(self, state: Enum, message_type: str) -> None:
super().__init__(f"Unexpected message type, {message_type} given the state {state}")
class FrameTooLarge(Exception):
pass
def suppress_body(method: str, status_code: int) -> bool:
return method == "HEAD" or 100 <= status_code < 200 or status_code in {204, 304, 412}
def build_and_validate_headers(headers: Iterable[Tuple[bytes, bytes]]) -> List[Tuple[bytes, bytes]]:
# Validates that the header name and value are bytes
validated_headers: List[Tuple[bytes, bytes]] = []
for name, value in headers:
if name[0] == b":"[0]:
raise ValueError("Pseudo headers are not valid")
validated_headers.append((bytes(name).lower().strip(), bytes(value).strip()))
return validated_headers
def filter_pseudo_headers(headers: List[Tuple[bytes, bytes]]) -> List[Tuple[bytes, bytes]]:
filtered_headers: List[Tuple[bytes, bytes]] = [(b"host", b"")] # Placeholder
for name, value in headers:
if name == b":authority": # h2 & h3 libraries validate this is present
filtered_headers[0] = (b"host", value)
elif name[0] != b":"[0]:
filtered_headers.append((name, value))
return filtered_headers
def load_application(path: str) -> ASGIFramework:
try:
module_name, app_name = path.split(":", 1)
except ValueError:
module_name, app_name = path, "app"
except AttributeError:
raise NoAppException()
module_path = Path(module_name).resolve()
sys.path.insert(0, str(module_path.parent))
if module_path.is_file():
import_name = module_path.with_suffix("").name
else:
import_name = module_path.name
try:
module = import_module(import_name)
except ModuleNotFoundError as error:
if error.name == import_name:
raise NoAppException()
else:
raise
try:
return eval(app_name, vars(module))
except NameError:
raise NoAppException()
async def observe_changes(sleep: Callable[[float], Awaitable[Any]]) -> None:
last_updates: Dict[Path, float] = {}
for module in list(sys.modules.values()):
filename = getattr(module, "__file__", None)
if filename is None:
continue
path = Path(filename)
try:
last_updates[Path(filename)] = path.stat().st_mtime
except (FileNotFoundError, NotADirectoryError):
pass
while True:
await sleep(1)
for index, (path, last_mtime) in enumerate(last_updates.items()):
if index % 10 == 0:
# Yield to the event loop
await sleep(0)
try:
mtime = path.stat().st_mtime
except FileNotFoundError:
# File deleted
raise MustReloadException()
else:
if mtime > last_mtime:
raise MustReloadException()
else:
last_updates[path] = mtime
def restart() -> None:
# Restart this process (only safe for dev/debug)
executable = sys.executable
script_path = Path(sys.argv[0]).resolve()
args = sys.argv[1:]
main_package = sys.modules["__main__"].__package__
if main_package is None:
# Executed by filename
if platform.system() == "Windows":
if not script_path.exists() and script_path.with_suffix(".exe").exists():
# quart run
executable = str(script_path.with_suffix(".exe"))
else:
# python run.py
args.append(str(script_path))
else:
if script_path.is_file() and os.access(script_path, os.X_OK):
# hypercorn run:app --reload
executable = str(script_path)
else:
# python run.py
args.append(str(script_path))
else:
# Executed as a module e.g. python -m run
module = script_path.stem
import_name = main_package
if module != "__main__":
import_name = f"{main_package}.{module}"
args[:0] = ["-m", import_name.lstrip(".")]
os.execv(executable, [executable] + args)
async def raise_shutdown(shutdown_event: Callable[..., Awaitable[None]]) -> None:
await shutdown_event()
raise Shutdown()
async def check_multiprocess_shutdown_event(
shutdown_event: EventType, sleep: Callable[[float], Awaitable[Any]]
) -> None:
while True:
if shutdown_event.is_set():
return
await sleep(0.1)
def write_pid_file(pid_path: str) -> None:
with open(pid_path, "w") as file_:
file_.write(f"{os.getpid()}")
def parse_socket_addr(family: int, address: tuple) -> Optional[Tuple[str, int]]:
if family == socket.AF_INET:
return address # type: ignore
elif family == socket.AF_INET6:
return (address[0], address[1])
else:
return None
def repr_socket_addr(family: int, address: tuple) -> str:
if family == socket.AF_INET:
return f"{address[0]}:{address[1]}"
elif family == socket.AF_INET6:
return f"[{address[0]}]:{address[1]}"
elif family == socket.AF_UNIX:
return f"unix:{address}"
else:
return f"{address}"
async def invoke_asgi(
app: ASGIFramework, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
if _is_asgi_2(app):
scope["asgi"]["version"] = "2.0"
app = cast(ASGI2Framework, app)
asgi_instance = app(scope)
await asgi_instance(receive, send)
else:
scope["asgi"]["version"] = "3.0"
app = cast(ASGI3Framework, app)
await app(scope, receive, send)
def _is_asgi_2(app: ASGIFramework) -> bool:
if inspect.isclass(app):
return True
if hasattr(app, "__call__") and inspect.iscoroutinefunction(app.__call__): # type: ignore
return False
return not inspect.iscoroutinefunction(app)
def valid_server_name(config: Config, request: "Request") -> bool:
if len(config.server_names) == 0:
return True
host = ""
for name, value in request.headers:
if name.lower() == b"host":
host = value.decode()
break
return host in config.server_names