import asyncio
import typing

from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
    [Request, RequestResponseEndpoint], typing.Awaitable[Response]
]


class BaseHTTPMiddleware:
    """ASGI middleware base class exposing a request/response ``dispatch`` hook.

    Either subclass and override :meth:`dispatch`, or pass a ``dispatch``
    callable to the constructor. The hook receives the incoming
    :class:`Request` and a ``call_next`` coroutine that runs the wrapped app
    and returns its response.
    """

    def __init__(
        self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
    ) -> None:
        """Wrap *app*.

        :param app: the downstream ASGI application.
        :param dispatch: optional dispatch callable; when omitted, this
            instance's own :meth:`dispatch` method is used instead.
        """
        self.app = app
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point.

        Scopes whose ``type`` is not ``"http"`` are forwarded to the wrapped
        app untouched. HTTP requests go through the dispatch hook, and the
        response it returns is sent using the original ``send``.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope, receive=receive)
        response = await self.dispatch_func(request, self.call_next)
        await response(scope, receive, send)

    async def call_next(self, request: Request) -> Response:
        """Run the wrapped app for *request* and return its response.

        The app runs in its own task. Every message it sends is captured in a
        queue instead of going to the client. The first message must be
        ``http.response.start``; it supplies the status and headers of the
        returned :class:`StreamingResponse`. The response body is read lazily
        from the remaining ``http.response.body`` messages.

        :raises RuntimeError: if the app finishes without sending any message.

        An exception raised by the app is re-raised here if it happens before
        the response starts. If it happens later, it is re-raised while the
        body is being streamed.
        """
        loop = asyncio.get_event_loop()
        queue = asyncio.Queue()  # type: asyncio.Queue
        scope = request.scope
        receive = request.receive
        send = queue.put

        async def coro() -> None:
            try:
                await self.app(scope, receive, send)
            finally:
                # Sentinel meaning "the app is done", whether it returned or raised.
                await queue.put(None)

        task = loop.create_task(coro())
        start_message = await queue.get()
        if start_message is None:
            # The app ended without responding. Awaiting the task re-raises its
            # exception if it failed; otherwise report the missing response.
            # (``await`` rather than ``.result()`` so we never depend on the
            # task having already been marked done when the sentinel arrives.)
            await task
            raise RuntimeError("No response returned.")
        assert start_message["type"] == "http.response.start"

        async def body_stream() -> typing.AsyncGenerator[bytes, None]:
            while True:
                body_message = await queue.get()
                if body_message is None:
                    break
                assert body_message["type"] == "http.response.body"
                yield body_message.get("body", b"")
            # Surface any exception the app raised after the response started.
            await task

        response = StreamingResponse(
            status_code=start_message["status"], content=body_stream()
        )
        # The ASGI HTTP spec treats "headers" as optional on
        # http.response.start; default to no headers instead of a KeyError.
        response.raw_headers = start_message.get("headers", [])
        return response

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Handle *request* and return a response.

        Subclasses that do not pass a ``dispatch`` callable to the constructor
        must override this. Use ``await call_next(request)`` to invoke the
        wrapped app.
        """
        raise NotImplementedError()  # pragma: no cover