new
This commit is contained in:
@@ -0,0 +1,67 @@
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
|
||||
DispatchFunction = typing.Callable[
|
||||
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
|
||||
]
|
||||
|
||||
|
||||
class BaseHTTPMiddleware:
|
||||
def __init__(self, app: ASGIApp, dispatch: DispatchFunction = None) -> None:
|
||||
self.app = app
|
||||
self.dispatch_func = self.dispatch if dispatch is None else dispatch
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = Request(scope, receive=receive)
|
||||
response = await self.dispatch_func(request, self.call_next)
|
||||
await response(scope, receive, send)
|
||||
|
||||
async def call_next(self, request: Request) -> Response:
|
||||
loop = asyncio.get_event_loop()
|
||||
queue = asyncio.Queue() # type: asyncio.Queue
|
||||
|
||||
scope = request.scope
|
||||
receive = request.receive
|
||||
send = queue.put
|
||||
|
||||
async def coro() -> None:
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
finally:
|
||||
await queue.put(None)
|
||||
|
||||
task = loop.create_task(coro())
|
||||
message = await queue.get()
|
||||
if message is None:
|
||||
task.result()
|
||||
raise RuntimeError("No response returned.")
|
||||
assert message["type"] == "http.response.start"
|
||||
|
||||
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
|
||||
while True:
|
||||
message = await queue.get()
|
||||
if message is None:
|
||||
break
|
||||
assert message["type"] == "http.response.body"
|
||||
yield message.get("body", b"")
|
||||
task.result()
|
||||
|
||||
response = StreamingResponse(
|
||||
status_code=message["status"], content=body_stream()
|
||||
)
|
||||
response.raw_headers = message["headers"]
|
||||
return response
|
||||
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
Reference in New Issue
Block a user