This commit is contained in:
Untriex Programming
2021-03-17 08:57:57 +01:00
parent 339be0ccd8
commit ed6afdb5c9
3074 changed files with 423348 additions and 0 deletions

View File

@@ -0,0 +1,24 @@
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
__version__ = "0.63.0"
from starlette import status as status
from .applications import FastAPI as FastAPI
from .background import BackgroundTasks as BackgroundTasks
from .datastructures import UploadFile as UploadFile
from .exceptions import HTTPException as HTTPException
from .param_functions import Body as Body
from .param_functions import Cookie as Cookie
from .param_functions import Depends as Depends
from .param_functions import File as File
from .param_functions import Form as Form
from .param_functions import Header as Header
from .param_functions import Path as Path
from .param_functions import Query as Query
from .param_functions import Security as Security
from .requests import Request as Request
from .responses import Response as Response
from .routing import APIRouter as APIRouter
from .websockets import WebSocket as WebSocket
from .websockets import WebSocketDisconnect as WebSocketDisconnect

View File

@@ -0,0 +1,739 @@
from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union
from fastapi import routing
from fastapi.concurrency import AsyncExitStack
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.encoders import DictIntStrAny, SetIntStr
from fastapi.exception_handlers import (
http_exception_handler,
request_validation_exception_handler,
)
from fastapi.exceptions import RequestValidationError
from fastapi.logger import logger
from fastapi.openapi.docs import (
get_redoc_html,
get_swagger_ui_html,
get_swagger_ui_oauth2_redirect_html,
)
from fastapi.openapi.utils import get_openapi
from fastapi.params import Depends
from fastapi.types import DecoratedCallable
from starlette.applications import Starlette
from starlette.datastructures import State
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, Response
from starlette.routing import BaseRoute
from starlette.types import ASGIApp, Receive, Scope, Send
class FastAPI(Starlette):
def __init__(
self,
*,
debug: bool = False,
routes: Optional[List[BaseRoute]] = None,
title: str = "FastAPI",
description: str = "",
version: str = "0.1.0",
openapi_url: Optional[str] = "/openapi.json",
openapi_tags: Optional[List[Dict[str, Any]]] = None,
servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
default_response_class: Type[Response] = Default(JSONResponse),
docs_url: Optional[str] = "/docs",
redoc_url: Optional[str] = "/redoc",
swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
swagger_ui_init_oauth: Optional[Dict[str, Any]] = None,
middleware: Optional[Sequence[Middleware]] = None,
exception_handlers: Optional[
Dict[
Union[int, Type[Exception]],
Callable[[Request, Any], Coroutine[Any, Any, Response]],
]
] = None,
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
openapi_prefix: str = "",
root_path: str = "",
root_path_in_servers: bool = True,
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
callbacks: Optional[List[BaseRoute]] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
**extra: Any,
) -> None:
self._debug: bool = debug
self.state: State = State()
self.router: routing.APIRouter = routing.APIRouter(
routes=routes,
dependency_overrides_provider=self,
on_startup=on_startup,
on_shutdown=on_shutdown,
default_response_class=default_response_class,
dependencies=dependencies,
callbacks=callbacks,
deprecated=deprecated,
include_in_schema=include_in_schema,
responses=responses,
)
self.exception_handlers: Dict[
Union[int, Type[Exception]],
Callable[[Request, Any], Coroutine[Any, Any, Response]],
] = (
{} if exception_handlers is None else dict(exception_handlers)
)
self.exception_handlers.setdefault(HTTPException, http_exception_handler)
self.exception_handlers.setdefault(
RequestValidationError, request_validation_exception_handler
)
self.user_middleware: List[Middleware] = (
[] if middleware is None else list(middleware)
)
self.middleware_stack: ASGIApp = self.build_middleware_stack()
self.title = title
self.description = description
self.version = version
self.servers = servers or []
self.openapi_url = openapi_url
self.openapi_tags = openapi_tags
# TODO: remove when discarding the openapi_prefix parameter
if openapi_prefix:
logger.warning(
'"openapi_prefix" has been deprecated in favor of "root_path", which '
"follows more closely the ASGI standard, is simpler, and more "
"automatic. Check the docs at "
"https://fastapi.tiangolo.com/advanced/sub-applications/"
)
self.root_path = root_path or openapi_prefix
self.root_path_in_servers = root_path_in_servers
self.docs_url = docs_url
self.redoc_url = redoc_url
self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url
self.swagger_ui_init_oauth = swagger_ui_init_oauth
self.extra = extra
self.dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {}
self.openapi_version = "3.0.2"
if self.openapi_url:
assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"
self.openapi_schema: Optional[Dict[str, Any]] = None
self.setup()
def openapi(self) -> Dict[str, Any]:
if not self.openapi_schema:
self.openapi_schema = get_openapi(
title=self.title,
version=self.version,
openapi_version=self.openapi_version,
description=self.description,
routes=self.routes,
tags=self.openapi_tags,
servers=self.servers,
)
return self.openapi_schema
def setup(self) -> None:
if self.openapi_url:
urls = (server_data.get("url") for server_data in self.servers)
server_urls = {url for url in urls if url}
async def openapi(req: Request) -> JSONResponse:
root_path = req.scope.get("root_path", "").rstrip("/")
if root_path not in server_urls:
if root_path and self.root_path_in_servers:
self.servers.insert(0, {"url": root_path})
server_urls.add(root_path)
return JSONResponse(self.openapi())
self.add_route(self.openapi_url, openapi, include_in_schema=False)
if self.openapi_url and self.docs_url:
async def swagger_ui_html(req: Request) -> HTMLResponse:
root_path = req.scope.get("root_path", "").rstrip("/")
openapi_url = root_path + self.openapi_url
oauth2_redirect_url = self.swagger_ui_oauth2_redirect_url
if oauth2_redirect_url:
oauth2_redirect_url = root_path + oauth2_redirect_url
return get_swagger_ui_html(
openapi_url=openapi_url,
title=self.title + " - Swagger UI",
oauth2_redirect_url=oauth2_redirect_url,
init_oauth=self.swagger_ui_init_oauth,
)
self.add_route(self.docs_url, swagger_ui_html, include_in_schema=False)
if self.swagger_ui_oauth2_redirect_url:
async def swagger_ui_redirect(req: Request) -> HTMLResponse:
return get_swagger_ui_oauth2_redirect_html()
self.add_route(
self.swagger_ui_oauth2_redirect_url,
swagger_ui_redirect,
include_in_schema=False,
)
if self.openapi_url and self.redoc_url:
async def redoc_html(req: Request) -> HTMLResponse:
root_path = req.scope.get("root_path", "").rstrip("/")
openapi_url = root_path + self.openapi_url
return get_redoc_html(
openapi_url=openapi_url, title=self.title + " - ReDoc"
)
self.add_route(self.redoc_url, redoc_html, include_in_schema=False)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.root_path:
scope["root_path"] = self.root_path
if AsyncExitStack:
async with AsyncExitStack() as stack:
scope["fastapi_astack"] = stack
await super().__call__(scope, receive, send)
else:
await super().__call__(scope, receive, send) # pragma: no cover
def add_api_route(
self,
path: str,
endpoint: Callable[..., Coroutine[Any, Any, Response]],
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
methods: Optional[List[str]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(
JSONResponse
),
name: Optional[str] = None,
) -> None:
self.router.add_api_route(
path,
endpoint=endpoint,
response_model=response_model,
status_code=status_code,
tags=tags,
dependencies=dependencies,
summary=summary,
description=description,
response_description=response_description,
responses=responses,
deprecated=deprecated,
methods=methods,
operation_id=operation_id,
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
)
def api_route(
self,
path: str,
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
methods: Optional[List[str]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.router.add_api_route(
path,
func,
response_model=response_model,
status_code=status_code,
tags=tags,
dependencies=dependencies,
summary=summary,
description=description,
response_description=response_description,
responses=responses,
deprecated=deprecated,
methods=methods,
operation_id=operation_id,
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
)
return func
return decorator
def add_api_websocket_route(
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
) -> None:
self.router.add_api_websocket_route(path, endpoint, name=name)
def websocket(
self, path: str, name: Optional[str] = None
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route(path, func, name=name)
return func
return decorator
def include_router(
self,
router: routing.APIRouter,
*,
prefix: str = "",
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
default_response_class: Type[Response] = Default(JSONResponse),
callbacks: Optional[List[BaseRoute]] = None,
) -> None:
self.router.include_router(
router,
prefix=prefix,
tags=tags,
dependencies=dependencies,
responses=responses,
deprecated=deprecated,
include_in_schema=include_in_schema,
default_response_class=default_response_class,
callbacks=callbacks,
)
def get(
self,
path: str,
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.get(
path,
response_model=response_model,
status_code=status_code,
tags=tags,
dependencies=dependencies,
summary=summary,
description=description,
response_description=response_description,
responses=responses,
deprecated=deprecated,
operation_id=operation_id,
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
callbacks=callbacks,
)
def put(
self,
path: str,
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.put(
path,
response_model=response_model,
status_code=status_code,
tags=tags,
dependencies=dependencies,
summary=summary,
description=description,
response_description=response_description,
responses=responses,
deprecated=deprecated,
operation_id=operation_id,
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
callbacks=callbacks,
)
def post(
self,
path: str,
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.post(
path,
response_model=response_model,
status_code=status_code,
tags=tags,
dependencies=dependencies,
summary=summary,
description=description,
response_description=response_description,
responses=responses,
deprecated=deprecated,
operation_id=operation_id,
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
callbacks=callbacks,
)
def delete(
self,
path: str,
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.delete(
path,
response_model=response_model,
status_code=status_code,
tags=tags,
dependencies=dependencies,
summary=summary,
description=description,
response_description=response_description,
responses=responses,
deprecated=deprecated,
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
operation_id=operation_id,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
callbacks=callbacks,
)
def options(
self,
path: str,
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.options(
path,
response_model=response_model,
status_code=status_code,
tags=tags,
dependencies=dependencies,
summary=summary,
description=description,
response_description=response_description,
responses=responses,
deprecated=deprecated,
operation_id=operation_id,
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
callbacks=callbacks,
)
def head(
self,
path: str,
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.head(
path,
response_model=response_model,
status_code=status_code,
tags=tags,
dependencies=dependencies,
summary=summary,
description=description,
response_description=response_description,
responses=responses,
deprecated=deprecated,
operation_id=operation_id,
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
callbacks=callbacks,
)
def patch(
self,
path: str,
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.patch(
path,
response_model=response_model,
status_code=status_code,
tags=tags,
dependencies=dependencies,
summary=summary,
description=description,
response_description=response_description,
responses=responses,
deprecated=deprecated,
operation_id=operation_id,
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
callbacks=callbacks,
)
def trace(
self,
path: str,
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.trace(
path,
response_model=response_model,
status_code=status_code,
tags=tags,
dependencies=dependencies,
summary=summary,
description=description,
response_description=response_description,
responses=responses,
deprecated=deprecated,
operation_id=operation_id,
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
callbacks=callbacks,
)

View File

@@ -0,0 +1 @@
from starlette.background import BackgroundTasks as BackgroundTasks # noqa

View File

@@ -0,0 +1,51 @@
from typing import Any, Callable
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
from starlette.concurrency import ( # noqa
run_until_first_complete as run_until_first_complete,
)
asynccontextmanager_error_message = """
FastAPI's contextmanager_in_threadpool require Python 3.7 or above,
or the backport for Python 3.6, installed with:
pip install async-generator
"""
def _fake_asynccontextmanager(func: Callable[..., Any]) -> Callable[..., Any]:
def raiser(*args: Any, **kwargs: Any) -> Any:
raise RuntimeError(asynccontextmanager_error_message)
return raiser
try:
from contextlib import asynccontextmanager as asynccontextmanager # type: ignore
except ImportError:
try:
from async_generator import ( # type: ignore # isort: skip
asynccontextmanager as asynccontextmanager,
)
except ImportError: # pragma: no cover
asynccontextmanager = _fake_asynccontextmanager
try:
from contextlib import AsyncExitStack as AsyncExitStack # type: ignore
except ImportError:
try:
from async_exit_stack import AsyncExitStack as AsyncExitStack # type: ignore
except ImportError: # pragma: no cover
AsyncExitStack = None # type: ignore
@asynccontextmanager # type: ignore
async def contextmanager_in_threadpool(cm: Any) -> Any:
try:
yield await run_in_threadpool(cm.__enter__)
except Exception as e:
ok = await run_in_threadpool(cm.__exit__, type(e), e, None)
if not ok:
raise e
else:
await run_in_threadpool(cm.__exit__, None, None, None)

View File

@@ -0,0 +1,47 @@
from typing import Any, Callable, Iterable, Type, TypeVar
from starlette.datastructures import State as State # noqa: F401
from starlette.datastructures import UploadFile as StarletteUploadFile
class UploadFile(StarletteUploadFile):
@classmethod
def __get_validators__(cls: Type["UploadFile"]) -> Iterable[Callable[..., Any]]:
yield cls.validate
@classmethod
def validate(cls: Type["UploadFile"], v: Any) -> Any:
if not isinstance(v, StarletteUploadFile):
raise ValueError(f"Expected UploadFile, received: {type(v)}")
return v
class DefaultPlaceholder:
"""
You shouldn't use this class directly.
It's used internally to recognize when a default value has been overwritten, even
if the overriden default value was truthy.
"""
def __init__(self, value: Any):
self.value = value
def __bool__(self) -> bool:
return bool(self.value)
def __eq__(self, o: object) -> bool:
return isinstance(o, DefaultPlaceholder) and o.value == self.value
DefaultType = TypeVar("DefaultType")
def Default(value: DefaultType) -> DefaultType:
"""
You shouldn't use this function directly.
It's used internally to recognize when a default value has been overwritten, even
if the overriden default value was truthy.
"""
return DefaultPlaceholder(value) # type: ignore

View File

@@ -0,0 +1,58 @@
from typing import Any, Callable, List, Optional, Sequence
from fastapi.security.base import SecurityBase
from pydantic.fields import ModelField
class SecurityRequirement:
def __init__(
self, security_scheme: SecurityBase, scopes: Optional[Sequence[str]] = None
):
self.security_scheme = security_scheme
self.scopes = scopes
class Dependant:
def __init__(
self,
*,
path_params: Optional[List[ModelField]] = None,
query_params: Optional[List[ModelField]] = None,
header_params: Optional[List[ModelField]] = None,
cookie_params: Optional[List[ModelField]] = None,
body_params: Optional[List[ModelField]] = None,
dependencies: Optional[List["Dependant"]] = None,
security_schemes: Optional[List[SecurityRequirement]] = None,
name: Optional[str] = None,
call: Optional[Callable[..., Any]] = None,
request_param_name: Optional[str] = None,
websocket_param_name: Optional[str] = None,
http_connection_param_name: Optional[str] = None,
response_param_name: Optional[str] = None,
background_tasks_param_name: Optional[str] = None,
security_scopes_param_name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
path: Optional[str] = None,
) -> None:
self.path_params = path_params or []
self.query_params = query_params or []
self.header_params = header_params or []
self.cookie_params = cookie_params or []
self.body_params = body_params or []
self.dependencies = dependencies or []
self.security_requirements = security_schemes or []
self.request_param_name = request_param_name
self.websocket_param_name = websocket_param_name
self.http_connection_param_name = http_connection_param_name
self.response_param_name = response_param_name
self.background_tasks_param_name = background_tasks_param_name
self.security_scopes = security_scopes
self.security_scopes_param_name = security_scopes_param_name
self.name = name
self.call = call
self.use_cache = use_cache
# Store the path to be able to re-generate a dependable from it in overrides
self.path = path
# Save the cache key at creation to optimize performance
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))

View File

@@ -0,0 +1,783 @@
import asyncio
import inspect
from contextlib import contextmanager
from copy import deepcopy
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from fastapi import params
from fastapi.concurrency import (
AsyncExitStack,
_fake_asynccontextmanager,
asynccontextmanager,
contextmanager_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.utils import create_response_field, get_path_param_names
from pydantic import BaseModel, create_model
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import (
SHAPE_LIST,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
SHAPE_TUPLE,
SHAPE_TUPLE_ELLIPSIS,
FieldInfo,
ModelField,
Required,
)
from pydantic.schema import get_annotation_from_field_info
from pydantic.typing import ForwardRef, evaluate_forwardref
from pydantic.utils import lenient_issubclass
from starlette.background import BackgroundTasks
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
from starlette.requests import HTTPConnection, Request
from starlette.responses import Response
from starlette.websockets import WebSocket
sequence_shapes = {
SHAPE_LIST,
SHAPE_SET,
SHAPE_TUPLE,
SHAPE_SEQUENCE,
SHAPE_TUPLE_ELLIPSIS,
}
sequence_types = (list, set, tuple)
sequence_shape_to_type = {
SHAPE_LIST: list,
SHAPE_SET: set,
SHAPE_TUPLE: tuple,
SHAPE_SEQUENCE: list,
SHAPE_TUPLE_ELLIPSIS: list,
}
multipart_not_installed_error = (
'Form data requires "python-multipart" to be installed. \n'
'You can install "python-multipart" with: \n\n'
"pip install python-multipart\n"
)
multipart_incorrect_install_error = (
'Form data requires "python-multipart" to be installed. '
'It seems you installed "multipart" instead. \n'
'You can remove "multipart" with: \n\n'
"pip uninstall multipart\n\n"
'And then install "python-multipart" with: \n\n'
"pip install python-multipart\n"
)
def check_file_field(field: ModelField) -> None:
field_info = field.field_info
if isinstance(field_info, params.Form):
try:
# __version__ is available in both multiparts, and can be mocked
from multipart import __version__ # type: ignore
assert __version__
try:
# parse_options_header is only available in the right multipart
from multipart.multipart import parse_options_header # type: ignore
assert parse_options_header
except ImportError:
logger.error(multipart_incorrect_install_error)
raise RuntimeError(multipart_incorrect_install_error)
except ImportError:
logger.error(multipart_not_installed_error)
raise RuntimeError(multipart_not_installed_error)
def get_param_sub_dependant(
*, param: inspect.Parameter, path: str, security_scopes: Optional[List[str]] = None
) -> Dependant:
depends: params.Depends = param.default
if depends.dependency:
dependency = depends.dependency
else:
dependency = param.annotation
return get_sub_dependant(
depends=depends,
dependency=dependency,
path=path,
name=param.name,
security_scopes=security_scopes,
)
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
assert callable(
depends.dependency
), "A parameter-less dependency must have a callable dependency"
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
def get_sub_dependant(
*,
depends: params.Depends,
dependency: Callable[..., Any],
path: str,
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
) -> Dependant:
security_requirement = None
security_scopes = security_scopes or []
if isinstance(depends, params.Security):
dependency_scopes = depends.scopes
security_scopes.extend(dependency_scopes)
if isinstance(dependency, SecurityBase):
use_scopes: List[str] = []
if isinstance(dependency, (OAuth2, OpenIdConnect)):
use_scopes = security_scopes
security_requirement = SecurityRequirement(
security_scheme=dependency, scopes=use_scopes
)
sub_dependant = get_dependant(
path=path,
call=dependency,
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
)
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
sub_dependant.security_scopes = security_scopes
return sub_dependant
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
def get_flat_dependant(
dependant: Dependant,
*,
skip_repeats: bool = False,
visited: Optional[List[CacheKey]] = None,
) -> Dependant:
if visited is None:
visited = []
visited.append(dependant.cache_key)
flat_dependant = Dependant(
path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(),
header_params=dependant.header_params.copy(),
cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(),
security_schemes=dependant.security_requirements.copy(),
use_cache=dependant.use_cache,
path=dependant.path,
)
for sub_dependant in dependant.dependencies:
if skip_repeats and sub_dependant.cache_key in visited:
continue
flat_sub = get_flat_dependant(
sub_dependant, skip_repeats=skip_repeats, visited=visited
)
flat_dependant.path_params.extend(flat_sub.path_params)
flat_dependant.query_params.extend(flat_sub.query_params)
flat_dependant.header_params.extend(flat_sub.header_params)
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
flat_dependant.body_params.extend(flat_sub.body_params)
flat_dependant.security_requirements.extend(flat_sub.security_requirements)
return flat_dependant
def get_flat_params(dependant: Dependant) -> List[ModelField]:
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
return (
flat_dependant.path_params
+ flat_dependant.query_params
+ flat_dependant.header_params
+ flat_dependant.cookie_params
)
def is_scalar_field(field: ModelField) -> bool:
field_info = field.field_info
if not (
field.shape == SHAPE_SINGLETON
and not lenient_issubclass(field.type_, BaseModel)
and not lenient_issubclass(field.type_, sequence_types + (dict,))
and not isinstance(field_info, params.Body)
):
return False
if field.sub_fields:
if not all(is_scalar_field(f) for f in field.sub_fields):
return False
return True
def is_scalar_sequence_field(field: ModelField) -> bool:
if (field.shape in sequence_shapes) and not lenient_issubclass(
field.type_, BaseModel
):
if field.sub_fields is not None:
for sub_field in field.sub_fields:
if not is_scalar_field(sub_field):
return False
return True
if lenient_issubclass(field.type_, sequence_types):
return True
return False
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(param, globalns),
)
for param in signature.parameters.values()
]
typed_signature = inspect.Signature(typed_params)
return typed_signature
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
annotation = param.annotation
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
return annotation
async_contextmanager_dependencies_error = """
FastAPI dependencies with yield require Python 3.7 or above,
or the backports for Python 3.6, installed with:
pip install async-exit-stack async-generator
"""
def check_dependency_contextmanagers() -> None:
if AsyncExitStack is None or asynccontextmanager == _fake_asynccontextmanager:
raise RuntimeError(async_contextmanager_dependencies_error) # pragma: no cover
def get_dependant(
*,
path: str,
call: Callable[..., Any],
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
) -> Dependant:
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters
if is_gen_callable(call) or is_async_gen_callable(call):
check_dependency_contextmanagers()
dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
for param_name, param in signature_params.items():
if isinstance(param.default, params.Depends):
sub_dependant = get_param_sub_dependant(
param=param, path=path, security_scopes=security_scopes
)
dependant.dependencies.append(sub_dependant)
continue
if add_non_field_param_to_dependency(param=param, dependant=dependant):
continue
param_field = get_param_field(
param=param, default_field_info=params.Query, param_name=param_name
)
if param_name in path_param_names:
assert is_scalar_field(
field=param_field
), "Path params must be of one of the supported types"
if isinstance(param.default, params.Path):
ignore_default = False
else:
ignore_default = True
param_field = get_param_field(
param=param,
param_name=param_name,
default_field_info=params.Path,
force_type=params.ParamTypes.path,
ignore_default=ignore_default,
)
add_param_to_fields(field=param_field, dependant=dependant)
elif is_scalar_field(field=param_field):
add_param_to_fields(field=param_field, dependant=dependant)
elif isinstance(
param.default, (params.Query, params.Header)
) and is_scalar_sequence_field(param_field):
add_param_to_fields(field=param_field, dependant=dependant)
else:
field_info = param_field.field_info
assert isinstance(
field_info, params.Body
), f"Param: {param_field.name} can only be a request body, using Body(...)"
dependant.body_params.append(param_field)
return dependant
def add_non_field_param_to_dependency(
*, param: inspect.Parameter, dependant: Dependant
) -> Optional[bool]:
if lenient_issubclass(param.annotation, Request):
dependant.request_param_name = param.name
return True
elif lenient_issubclass(param.annotation, WebSocket):
dependant.websocket_param_name = param.name
return True
elif lenient_issubclass(param.annotation, HTTPConnection):
dependant.http_connection_param_name = param.name
return True
elif lenient_issubclass(param.annotation, Response):
dependant.response_param_name = param.name
return True
elif lenient_issubclass(param.annotation, BackgroundTasks):
dependant.background_tasks_param_name = param.name
return True
elif lenient_issubclass(param.annotation, SecurityScopes):
dependant.security_scopes_param_name = param.name
return True
return None
def get_param_field(
*,
param: inspect.Parameter,
param_name: str,
default_field_info: Type[params.Param] = params.Param,
force_type: Optional[params.ParamTypes] = None,
ignore_default: bool = False,
) -> ModelField:
default_value = Required
had_schema = False
if not param.default == param.empty and ignore_default is False:
default_value = param.default
if isinstance(default_value, FieldInfo):
had_schema = True
field_info = default_value
default_value = field_info.default
if (
isinstance(field_info, params.Param)
and getattr(field_info, "in_", None) is None
):
field_info.in_ = default_field_info.in_
if force_type:
field_info.in_ = force_type # type: ignore
else:
field_info = default_field_info(default_value)
required = default_value == Required
annotation: Any = Any
if not param.annotation == param.empty:
annotation = param.annotation
annotation = get_annotation_from_field_info(annotation, field_info, param_name)
if not field_info.alias and getattr(field_info, "convert_underscores", None):
alias = param.name.replace("_", "-")
else:
alias = field_info.alias or param.name
field = create_response_field(
name=param.name,
type_=annotation,
default=None if required else default_value,
alias=alias,
required=required,
field_info=field_info,
)
field.required = required
if not had_schema and not is_scalar_field(field=field):
field.field_info = params.Body(field_info.default)
return field
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
field_info = cast(params.Param, field.field_info)
if field_info.in_ == params.ParamTypes.path:
dependant.path_params.append(field)
elif field_info.in_ == params.ParamTypes.query:
dependant.query_params.append(field)
elif field_info.in_ == params.ParamTypes.header:
dependant.header_params.append(field)
else:
assert (
field_info.in_ == params.ParamTypes.cookie
), f"non-body parameters must be in path, query, header or cookie: {field.name}"
dependant.cookie_params.append(field)
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
return False
call = getattr(call, "__call__", None)
return inspect.iscoroutinefunction(call)
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isasyncgenfunction(call):
return True
call = getattr(call, "__call__", None)
return inspect.isasyncgenfunction(call)
def is_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isgeneratorfunction(call):
return True
call = getattr(call, "__call__", None)
return inspect.isgeneratorfunction(call)
async def solve_generator(
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
if is_gen_callable(call):
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
elif is_async_gen_callable(call):
if not inspect.isasyncgenfunction(call):
# asynccontextmanager from the async_generator backfill pre python3.7
# does not support callables that are not functions or methods.
# See https://github.com/python-trio/async_generator/issues/32
#
# Expand the callable class into its __call__ method before decorating it.
# This approach will work on newer python versions as well.
call = getattr(call, "__call__", None)
cm = asynccontextmanager(call)(**sub_values)
return await stack.enter_async_context(cm)
async def solve_dependencies(
*,
request: Union[Request, WebSocket],
dependant: Dependant,
body: Optional[Union[Dict[str, Any], FormData]] = None,
background_tasks: Optional[BackgroundTasks] = None,
response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
) -> Tuple[
Dict[str, Any],
List[ErrorWrapper],
Optional[BackgroundTasks],
Response,
Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = []
response = response or Response(
content=None,
status_code=None, # type: ignore
headers=None, # type: ignore # in Starlette
media_type=None, # type: ignore # in Starlette
background=None, # type: ignore # in Starlette
)
dependency_cache = dependency_cache or {}
sub_dependant: Dependant
for sub_dependant in dependant.dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast(
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
)
call = sub_dependant.call
use_sub_dependant = sub_dependant
if (
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
):
original_call = sub_dependant.call
call = getattr(
dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call)
use_path: str = sub_dependant.path # type: ignore
use_sub_dependant = get_dependant(
path=use_path,
call=call,
name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes,
)
use_sub_dependant.security_scopes = sub_dependant.security_scopes
solved_result = await solve_dependencies(
request=request,
dependant=use_sub_dependant,
body=body,
background_tasks=background_tasks,
response=response,
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
)
(
sub_values,
sub_errors,
background_tasks,
_, # the subdependency returns the same response we have
sub_dependency_cache,
) = solved_result
dependency_cache.update(sub_dependency_cache)
if sub_errors:
errors.extend(sub_errors)
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call):
stack = request.scope.get("fastapi_astack")
if stack is None:
raise RuntimeError(
async_contextmanager_dependencies_error
) # pragma: no cover
solved = await solve_generator(
call=call, stack=stack, sub_values=sub_values
)
elif is_coroutine_callable(call):
solved = await call(**sub_values)
else:
solved = await run_in_threadpool(call, **sub_values)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:
dependency_cache[sub_dependant.cache_key] = solved
path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params
)
query_values, query_errors = request_params_to_args(
dependant.query_params, request.query_params
)
header_values, header_errors = request_params_to_args(
dependant.header_params, request.headers
)
cookie_values, cookie_errors = request_params_to_args(
dependant.cookie_params, request.cookies
)
values.update(path_values)
values.update(query_values)
values.update(header_values)
values.update(cookie_values)
errors += path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params:
(
body_values,
body_errors,
) = await request_body_to_args( # body_params checked above
required_params=dependant.body_params, received_body=body
)
values.update(body_values)
errors.extend(body_errors)
if dependant.http_connection_param_name:
values[dependant.http_connection_param_name] = request
if dependant.request_param_name and isinstance(request, Request):
values[dependant.request_param_name] = request
elif dependant.websocket_param_name and isinstance(request, WebSocket):
values[dependant.websocket_param_name] = request
if dependant.background_tasks_param_name:
if background_tasks is None:
background_tasks = BackgroundTasks()
values[dependant.background_tasks_param_name] = background_tasks
if dependant.response_param_name:
values[dependant.response_param_name] = response
if dependant.security_scopes_param_name:
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes
)
return values, errors, background_tasks, response, dependency_cache
def request_params_to_args(
required_params: Sequence[ModelField],
received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
values = {}
errors = []
for field in required_params:
if is_scalar_sequence_field(field) and isinstance(
received_params, (QueryParams, Headers)
):
value = received_params.getlist(field.alias) or field.default
else:
value = received_params.get(field.alias)
field_info = field.field_info
assert isinstance(
field_info, params.Param
), "Params must be subclasses of Param"
if value is None:
if field.required:
errors.append(
ErrorWrapper(
MissingError(), loc=(field_info.in_.value, field.alias)
)
)
else:
values[field.name] = deepcopy(field.default)
continue
v_, errors_ = field.validate(
value, values, loc=(field_info.in_.value, field.alias)
)
if isinstance(errors_, ErrorWrapper):
errors.append(errors_)
elif isinstance(errors_, list):
errors.extend(errors_)
else:
values[field.name] = v_
return values, errors
async def request_body_to_args(
required_params: List[ModelField],
received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
values = {}
errors = []
if required_params:
field = required_params[0]
field_info = field.field_info
embed = getattr(field_info, "embed", None)
field_alias_omitted = len(required_params) == 1 and not embed
if field_alias_omitted:
received_body = {field.alias: received_body}
for field in required_params:
loc: Tuple[str, ...]
if field_alias_omitted:
loc = ("body",)
else:
loc = ("body", field.alias)
value: Optional[Any] = None
if received_body is not None:
if (
field.shape in sequence_shapes or field.type_ in sequence_types
) and isinstance(received_body, FormData):
value = received_body.getlist(field.alias)
else:
try:
value = received_body.get(field.alias)
except AttributeError:
errors.append(get_missing_field_error(loc))
continue
if (
value is None
or (isinstance(field_info, params.Form) and value == "")
or (
isinstance(field_info, params.Form)
and field.shape in sequence_shapes
and len(value) == 0
)
):
if field.required:
errors.append(get_missing_field_error(loc))
else:
values[field.name] = deepcopy(field.default)
continue
if (
isinstance(field_info, params.File)
and lenient_issubclass(field.type_, bytes)
and isinstance(value, UploadFile)
):
value = await value.read()
elif (
field.shape in sequence_shapes
and isinstance(field_info, params.File)
and lenient_issubclass(field.type_, bytes)
and isinstance(value, sequence_types)
):
awaitables = [sub_value.read() for sub_value in value]
contents = await asyncio.gather(*awaitables)
value = sequence_shape_to_type[field.shape](contents)
v_, errors_ = field.validate(value, values, loc=loc)
if isinstance(errors_, ErrorWrapper):
errors.append(errors_)
elif isinstance(errors_, list):
errors.extend(errors_)
else:
values[field.name] = v_
return values, errors
def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorWrapper:
missing_field_error = ErrorWrapper(MissingError(), loc=loc)
return missing_field_error
def get_schema_compatible_field(*, field: ModelField) -> ModelField:
out_field = field
if lenient_issubclass(field.type_, UploadFile):
use_type: type = bytes
if field.shape in sequence_shapes:
use_type = List[bytes]
out_field = create_response_field(
name=field.name,
type_=use_type,
class_validators=field.class_validators,
model_config=field.model_config,
default=field.default,
required=field.required,
alias=field.alias,
field_info=field.field_info,
)
return out_field
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
flat_dependant = get_flat_dependant(dependant)
if not flat_dependant.body_params:
return None
first_param = flat_dependant.body_params[0]
field_info = first_param.field_info
embed = getattr(field_info, "embed", None)
body_param_names_set = {param.name for param in flat_dependant.body_params}
if len(body_param_names_set) == 1 and not embed:
final_field = get_schema_compatible_field(field=first_param)
check_file_field(final_field)
return final_field
# If one field requires to embed, all have to be embedded
# in case a sub-dependency is evaluated with a single unique body field
# That is combined (embedded) with other body fields
for param in flat_dependant.body_params:
setattr(param.field_info, "embed", True)
model_name = "Body_" + name
BodyModel = create_model(model_name)
for f in flat_dependant.body_params:
BodyModel.__fields__[f.name] = get_schema_compatible_field(field=f)
required = any(True for f in flat_dependant.body_params if f.required)
BodyFieldInfo_kwargs: Dict[str, Any] = dict(default=None)
if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
BodyFieldInfo: Type[params.Body] = params.File
elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
BodyFieldInfo = params.Form
else:
BodyFieldInfo = params.Body
body_param_media_types = [
getattr(f.field_info, "media_type")
for f in flat_dependant.body_params
if isinstance(f.field_info, params.Body)
]
if len(set(body_param_media_types)) == 1:
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
final_field = create_response_field(
name="body",
type_=BodyModel,
required=required,
alias="body",
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
)
check_file_field(final_field)
return final_field

View File

@@ -0,0 +1,150 @@
from collections import defaultdict
from enum import Enum
from pathlib import PurePath
from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from pydantic import BaseModel
from pydantic.json import ENCODERS_BY_TYPE
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
def generate_encoders_by_class_tuples(
type_encoder_map: Dict[Any, Callable[[Any], Any]]
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
tuple
)
for type_, encoder in type_encoder_map.items():
encoders_by_class_tuples[encoder] += (type_,)
return encoders_by_class_tuples
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
def jsonable_encoder(
obj: Any,
include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
custom_encoder: Dict[Any, Callable[[Any], Any]] = {},
sqlalchemy_safe: bool = True,
) -> Any:
if include is not None and not isinstance(include, set):
include = set(include)
if exclude is not None and not isinstance(exclude, set):
exclude = set(exclude)
if isinstance(obj, BaseModel):
encoder = getattr(obj.__config__, "json_encoders", {})
if custom_encoder:
encoder.update(custom_encoder)
obj_dict = obj.dict(
include=include, # type: ignore # in Pydantic
exclude=exclude, # type: ignore # in Pydantic
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"]
return jsonable_encoder(
obj_dict,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
custom_encoder=encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, PurePath):
return str(obj)
if isinstance(obj, (str, int, float, type(None))):
return obj
if isinstance(obj, dict):
encoded_dict = {}
for key, value in obj.items():
if (
(
not sqlalchemy_safe
or (not isinstance(key, str))
or (not key.startswith("_sa"))
)
and (value is not None or not exclude_none)
and ((include and key in include) or not exclude or key not in exclude)
):
encoded_key = jsonable_encoder(
key,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_value = jsonable_encoder(
value,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_dict[encoded_key] = encoded_value
return encoded_dict
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
encoded_list = []
for item in obj:
encoded_list.append(
jsonable_encoder(
item,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
)
return encoded_list
if custom_encoder:
if type(obj) in custom_encoder:
return custom_encoder[type(obj)](obj)
else:
for encoder_type, encoder in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder(obj)
if type(obj) in ENCODERS_BY_TYPE:
return ENCODERS_BY_TYPE[type(obj)](obj)
for encoder, classes_tuple in encoders_by_class_tuples.items():
if isinstance(obj, classes_tuple):
return encoder(obj)
errors: List[Exception] = []
try:
data = dict(obj)
except Exception as e:
errors.append(e)
try:
data = vars(obj)
except Exception as e:
errors.append(e)
raise ValueError(errors)
return jsonable_encoder(
data,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)

View File

@@ -0,0 +1,25 @@
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
headers = getattr(exc, "headers", None)
if headers:
return JSONResponse(
{"detail": exc.detail}, status_code=exc.status_code, headers=headers
)
else:
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
async def request_validation_exception_handler(
request: Request, exc: RequestValidationError
) -> JSONResponse:
return JSONResponse(
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
content={"detail": jsonable_encoder(exc.errors())},
)

View File

@@ -0,0 +1,37 @@
from typing import Any, Dict, Optional, Sequence
from pydantic import ValidationError, create_model
from pydantic.error_wrappers import ErrorList
from starlette.exceptions import HTTPException as StarletteHTTPException
class HTTPException(StarletteHTTPException):
def __init__(
self,
status_code: int,
detail: Any = None,
headers: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(status_code=status_code, detail=detail)
self.headers = headers
RequestErrorModel = create_model("Request")
WebSocketErrorModel = create_model("WebSocket")
class FastAPIError(RuntimeError):
"""
A generic, FastAPI-specific error.
"""
class RequestValidationError(ValidationError):
def __init__(self, errors: Sequence[ErrorList], *, body: Any = None) -> None:
self.body = body
super().__init__(errors, RequestErrorModel)
class WebSocketRequestValidationError(ValidationError):
def __init__(self, errors: Sequence[ErrorList]) -> None:
super().__init__(errors, WebSocketErrorModel)

View File

@@ -0,0 +1,3 @@
import logging
logger = logging.getLogger("fastapi")

View File

@@ -0,0 +1 @@
from starlette.middleware import Middleware as Middleware

View File

@@ -0,0 +1 @@
from starlette.middleware.cors import CORSMiddleware as CORSMiddleware # noqa

View File

@@ -0,0 +1 @@
from starlette.middleware.gzip import GZipMiddleware as GZipMiddleware # noqa

View File

@@ -0,0 +1,3 @@
from starlette.middleware.httpsredirect import ( # noqa
HTTPSRedirectMiddleware as HTTPSRedirectMiddleware,
)

View File

@@ -0,0 +1,3 @@
from starlette.middleware.trustedhost import ( # noqa
TrustedHostMiddleware as TrustedHostMiddleware,
)

View File

@@ -0,0 +1 @@
from starlette.middleware.wsgi import WSGIMiddleware as WSGIMiddleware # noqa

View File

@@ -0,0 +1,3 @@
METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"}
STATUS_CODES_WITH_NO_BODY = {100, 101, 102, 103, 204, 304}
REF_PREFIX = "#/components/schemas/"

View File

@@ -0,0 +1,177 @@
import json
from typing import Any, Dict, Optional
from fastapi.encoders import jsonable_encoder
from starlette.responses import HTMLResponse
def get_swagger_ui_html(
*,
openapi_url: str,
title: str,
swagger_js_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js",
swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css",
swagger_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png",
oauth2_redirect_url: Optional[str] = None,
init_oauth: Optional[Dict[str, Any]] = None,
) -> HTMLResponse:
html = f"""
<!DOCTYPE html>
<html>
<head>
<link type="text/css" rel="stylesheet" href="{swagger_css_url}">
<link rel="shortcut icon" href="{swagger_favicon_url}">
<title>{title}</title>
</head>
<body>
<div id="swagger-ui">
</div>
<script src="{swagger_js_url}"></script>
<!-- `SwaggerUIBundle` is now available on the page -->
<script>
const ui = SwaggerUIBundle({{
url: '{openapi_url}',
"""
if oauth2_redirect_url:
html += f"oauth2RedirectUrl: window.location.origin + '{oauth2_redirect_url}',"
html += """
dom_id: '#swagger-ui',
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true
})"""
if init_oauth:
html += f"""
ui.initOAuth({json.dumps(jsonable_encoder(init_oauth))})
"""
html += """
</script>
</body>
</html>
"""
return HTMLResponse(html)
def get_redoc_html(
*,
openapi_url: str,
title: str,
redoc_js_url: str = "https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js",
redoc_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png",
with_google_fonts: bool = True,
) -> HTMLResponse:
html = f"""
<!DOCTYPE html>
<html>
<head>
<title>{title}</title>
<!-- needed for adaptive design -->
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1">
"""
if with_google_fonts:
html += """
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
"""
html += f"""
<link rel="shortcut icon" href="{redoc_favicon_url}">
<!--
ReDoc doesn't change outer page styles
-->
<style>
body {{
margin: 0;
padding: 0;
}}
</style>
</head>
<body>
<redoc spec-url="{openapi_url}"></redoc>
<script src="{redoc_js_url}"> </script>
</body>
</html>
"""
return HTMLResponse(html)
def get_swagger_ui_oauth2_redirect_html() -> HTMLResponse:
html = """
<!DOCTYPE html>
<html lang="en-US">
<body onload="run()">
</body>
</html>
<script>
'use strict';
function run () {
var oauth2 = window.opener.swaggerUIRedirectOauth2;
var sentState = oauth2.state;
var redirectUrl = oauth2.redirectUrl;
var isValid, qp, arr;
if (/code|token|error/.test(window.location.hash)) {
qp = window.location.hash.substring(1);
} else {
qp = location.search.substring(1);
}
arr = qp.split("&")
arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';})
qp = qp ? JSON.parse('{' + arr.join() + '}',
function (key, value) {
return key === "" ? value : decodeURIComponent(value)
}
) : {}
isValid = qp.state === sentState
if ((
oauth2.auth.schema.get("flow") === "accessCode"||
oauth2.auth.schema.get("flow") === "authorizationCode"
) && !oauth2.auth.code) {
if (!isValid) {
oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "warning",
message: "Authorization may be unsafe, passed state was changed in server Passed state wasn't returned from auth server"
});
}
if (qp.code) {
delete oauth2.state;
oauth2.auth.code = qp.code;
oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl});
} else {
let oauthErrorMsg
if (qp.error) {
oauthErrorMsg = "["+qp.error+"]: " +
(qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") +
(qp.error_uri ? "More info: "+qp.error_uri : "");
}
oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "error",
message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server"
});
}
} else {
oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl});
}
window.close();
}
</script>
"""
return HTMLResponse(content=html)

View File

@@ -0,0 +1,351 @@
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from fastapi.logger import logger
from pydantic import AnyUrl, BaseModel, Field
try:
import email_validator # type: ignore
assert email_validator # make autoflake ignore the unused import
from pydantic import EmailStr
except ImportError: # pragma: no cover
class EmailStr(str): # type: ignore
@classmethod
def __get_validators__(cls) -> Iterable[Callable[..., Any]]:
yield cls.validate
@classmethod
def validate(cls, v: Any) -> str:
logger.warning(
"email-validator not installed, email fields will be treated as str.\n"
"To install, run: pip install email-validator"
)
return str(v)
class Contact(BaseModel):
name: Optional[str] = None
url: Optional[AnyUrl] = None
email: Optional[EmailStr] = None
class License(BaseModel):
name: str
url: Optional[AnyUrl] = None
class Info(BaseModel):
title: str
description: Optional[str] = None
termsOfService: Optional[str] = None
contact: Optional[Contact] = None
license: Optional[License] = None
version: str
class ServerVariable(BaseModel):
enum: Optional[List[str]] = None
default: str
description: Optional[str] = None
class Server(BaseModel):
url: Union[AnyUrl, str]
description: Optional[str] = None
variables: Optional[Dict[str, ServerVariable]] = None
class Reference(BaseModel):
ref: str = Field(..., alias="$ref")
class Discriminator(BaseModel):
propertyName: str
mapping: Optional[Dict[str, str]] = None
class XML(BaseModel):
name: Optional[str] = None
namespace: Optional[str] = None
prefix: Optional[str] = None
attribute: Optional[bool] = None
wrapped: Optional[bool] = None
class ExternalDocumentation(BaseModel):
description: Optional[str] = None
url: AnyUrl
class SchemaBase(BaseModel):
ref: Optional[str] = Field(None, alias="$ref")
title: Optional[str] = None
multipleOf: Optional[float] = None
maximum: Optional[float] = None
exclusiveMaximum: Optional[float] = None
minimum: Optional[float] = None
exclusiveMinimum: Optional[float] = None
maxLength: Optional[int] = Field(None, gte=0)
minLength: Optional[int] = Field(None, gte=0)
pattern: Optional[str] = None
maxItems: Optional[int] = Field(None, gte=0)
minItems: Optional[int] = Field(None, gte=0)
uniqueItems: Optional[bool] = None
maxProperties: Optional[int] = Field(None, gte=0)
minProperties: Optional[int] = Field(None, gte=0)
required: Optional[List[str]] = None
enum: Optional[List[Any]] = None
type: Optional[str] = None
allOf: Optional[List[Any]] = None
oneOf: Optional[List[Any]] = None
anyOf: Optional[List[Any]] = None
not_: Optional[Any] = Field(None, alias="not")
items: Optional[Any] = None
properties: Optional[Dict[str, Any]] = None
additionalProperties: Optional[Union[Dict[str, Any], bool]] = None
description: Optional[str] = None
format: Optional[str] = None
default: Optional[Any] = None
nullable: Optional[bool] = None
discriminator: Optional[Discriminator] = None
readOnly: Optional[bool] = None
writeOnly: Optional[bool] = None
xml: Optional[XML] = None
externalDocs: Optional[ExternalDocumentation] = None
example: Optional[Any] = None
deprecated: Optional[bool] = None
class Schema(SchemaBase):
allOf: Optional[List[SchemaBase]] = None
oneOf: Optional[List[SchemaBase]] = None
anyOf: Optional[List[SchemaBase]] = None
not_: Optional[SchemaBase] = Field(None, alias="not")
items: Optional[SchemaBase] = None
properties: Optional[Dict[str, SchemaBase]] = None
additionalProperties: Optional[Union[Dict[str, Any], bool]] = None
class Example(BaseModel):
summary: Optional[str] = None
description: Optional[str] = None
value: Optional[Any] = None
externalValue: Optional[AnyUrl] = None
class ParameterInType(Enum):
query = "query"
header = "header"
path = "path"
cookie = "cookie"
class Encoding(BaseModel):
contentType: Optional[str] = None
# Workaround OpenAPI recursive reference, using Any
headers: Optional[Dict[str, Union[Any, Reference]]] = None
style: Optional[str] = None
explode: Optional[bool] = None
allowReserved: Optional[bool] = None
class MediaType(BaseModel):
schema_: Optional[Union[Schema, Reference]] = Field(None, alias="schema")
example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
encoding: Optional[Dict[str, Encoding]] = None
class ParameterBase(BaseModel):
description: Optional[str] = None
required: Optional[bool] = None
deprecated: Optional[bool] = None
# Serialization rules for simple scenarios
style: Optional[str] = None
explode: Optional[bool] = None
allowReserved: Optional[bool] = None
schema_: Optional[Union[Schema, Reference]] = Field(None, alias="schema")
example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
# Serialization rules for more complex scenarios
content: Optional[Dict[str, MediaType]] = None
class Parameter(ParameterBase):
name: str
in_: ParameterInType = Field(..., alias="in")
class Header(ParameterBase):
pass
# Workaround OpenAPI recursive reference
class EncodingWithHeaders(Encoding):
headers: Optional[Dict[str, Union[Header, Reference]]] = None
class RequestBody(BaseModel):
description: Optional[str] = None
content: Dict[str, MediaType]
required: Optional[bool] = None
class Link(BaseModel):
operationRef: Optional[str] = None
operationId: Optional[str] = None
parameters: Optional[Dict[str, Union[Any, str]]] = None
requestBody: Optional[Union[Any, str]] = None
description: Optional[str] = None
server: Optional[Server] = None
class Response(BaseModel):
description: str
headers: Optional[Dict[str, Union[Header, Reference]]] = None
content: Optional[Dict[str, MediaType]] = None
links: Optional[Dict[str, Union[Link, Reference]]] = None
class Operation(BaseModel):
tags: Optional[List[str]] = None
summary: Optional[str] = None
description: Optional[str] = None
externalDocs: Optional[ExternalDocumentation] = None
operationId: Optional[str] = None
parameters: Optional[List[Union[Parameter, Reference]]] = None
requestBody: Optional[Union[RequestBody, Reference]] = None
responses: Dict[str, Response]
# Workaround OpenAPI recursive reference
callbacks: Optional[Dict[str, Union[Dict[str, Any], Reference]]] = None
deprecated: Optional[bool] = None
security: Optional[List[Dict[str, List[str]]]] = None
servers: Optional[List[Server]] = None
class PathItem(BaseModel):
ref: Optional[str] = Field(None, alias="$ref")
summary: Optional[str] = None
description: Optional[str] = None
get: Optional[Operation] = None
put: Optional[Operation] = None
post: Optional[Operation] = None
delete: Optional[Operation] = None
options: Optional[Operation] = None
head: Optional[Operation] = None
patch: Optional[Operation] = None
trace: Optional[Operation] = None
servers: Optional[List[Server]] = None
parameters: Optional[List[Union[Parameter, Reference]]] = None
# Workaround OpenAPI recursive reference
class OperationWithCallbacks(BaseModel):
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None
class SecuritySchemeType(Enum):
apiKey = "apiKey"
http = "http"
oauth2 = "oauth2"
openIdConnect = "openIdConnect"
class SecurityBase(BaseModel):
type_: SecuritySchemeType = Field(..., alias="type")
description: Optional[str] = None
class APIKeyIn(Enum):
query = "query"
header = "header"
cookie = "cookie"
class APIKey(SecurityBase):
type_ = Field(SecuritySchemeType.apiKey, alias="type")
in_: APIKeyIn = Field(..., alias="in")
name: str
class HTTPBase(SecurityBase):
type_ = Field(SecuritySchemeType.http, alias="type")
scheme: str
class HTTPBearer(HTTPBase):
scheme = "bearer"
bearerFormat: Optional[str] = None
class OAuthFlow(BaseModel):
refreshUrl: Optional[str] = None
scopes: Dict[str, str] = {}
class OAuthFlowImplicit(OAuthFlow):
authorizationUrl: str
class OAuthFlowPassword(OAuthFlow):
tokenUrl: str
class OAuthFlowClientCredentials(OAuthFlow):
tokenUrl: str
class OAuthFlowAuthorizationCode(OAuthFlow):
authorizationUrl: str
tokenUrl: str
class OAuthFlows(BaseModel):
implicit: Optional[OAuthFlowImplicit] = None
password: Optional[OAuthFlowPassword] = None
clientCredentials: Optional[OAuthFlowClientCredentials] = None
authorizationCode: Optional[OAuthFlowAuthorizationCode] = None
class OAuth2(SecurityBase):
type_ = Field(SecuritySchemeType.oauth2, alias="type")
flows: OAuthFlows
class OpenIdConnect(SecurityBase):
type_ = Field(SecuritySchemeType.openIdConnect, alias="type")
openIdConnectUrl: str
SecurityScheme = Union[APIKey, HTTPBase, OAuth2, OpenIdConnect, HTTPBearer]
class Components(BaseModel):
schemas: Optional[Dict[str, Union[Schema, Reference]]] = None
responses: Optional[Dict[str, Union[Response, Reference]]] = None
parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
requestBodies: Optional[Dict[str, Union[RequestBody, Reference]]] = None
headers: Optional[Dict[str, Union[Header, Reference]]] = None
securitySchemes: Optional[Dict[str, Union[SecurityScheme, Reference]]] = None
links: Optional[Dict[str, Union[Link, Reference]]] = None
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None
class Tag(BaseModel):
name: str
description: Optional[str] = None
externalDocs: Optional[ExternalDocumentation] = None
class OpenAPI(BaseModel):
openapi: str
info: Info
servers: Optional[List[Server]] = None
paths: Dict[str, PathItem]
components: Optional[Components] = None
security: Optional[List[Dict[str, List[str]]]] = None
tags: Optional[List[Tag]] = None
externalDocs: Optional[ExternalDocumentation] = None

View File

@@ -0,0 +1,377 @@
import http.client
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
from fastapi import routing
from fastapi.datastructures import DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import (
METHODS_WITH_BODY,
REF_PREFIX,
STATUS_CODES_WITH_NO_BODY,
)
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, Param
from fastapi.responses import Response
from fastapi.utils import (
deep_dict_update,
generate_operation_id_for_path,
get_model_definitions,
)
from pydantic import BaseModel
from pydantic.fields import ModelField
from pydantic.schema import (
field_schema,
get_flat_models_from_fields,
get_model_name_map,
)
from pydantic.utils import lenient_issubclass
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
validation_error_definition = {
"title": "ValidationError",
"type": "object",
"properties": {
"loc": {"title": "Location", "type": "array", "items": {"type": "string"}},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
"required": ["loc", "msg", "type"],
}
validation_error_response_definition = {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": REF_PREFIX + "ValidationError"},
}
},
}
status_code_ranges: Dict[str, str] = {
"1XX": "Information",
"2XX": "Success",
"3XX": "Redirection",
"4XX": "Client Error",
"5XX": "Server Error",
"DEFAULT": "Default Response",
}
def get_openapi_security_definitions(
flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
security_definitions = {}
operation_security = []
for security_requirement in flat_dependant.security_requirements:
security_definition = jsonable_encoder(
security_requirement.security_scheme.model,
by_alias=True,
exclude_none=True,
)
security_name = security_requirement.security_scheme.scheme_name
security_definitions[security_name] = security_definition
operation_security.append({security_name: security_requirement.scopes})
return security_definitions, operation_security
def get_openapi_operation_parameters(
*,
all_route_params: Sequence[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> List[Dict[str, Any]]:
parameters = []
for param in all_route_params:
field_info = param.field_info
field_info = cast(Param, field_info)
parameter = {
"name": param.alias,
"in": field_info.in_.value,
"required": param.required,
"schema": field_schema(
param, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)[0],
}
if field_info.description:
parameter["description"] = field_info.description
if field_info.deprecated:
parameter["deprecated"] = field_info.deprecated
parameters.append(parameter)
return parameters
def get_openapi_operation_request_body(
*,
body_field: Optional[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Optional[Dict[str, Any]]:
if not body_field:
return None
assert isinstance(body_field, ModelField)
body_schema, _, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
field_info = cast(Body, body_field.field_info)
request_media_type = field_info.media_type
required = body_field.required
request_body_oai: Dict[str, Any] = {}
if required:
request_body_oai["required"] = required
request_body_oai["content"] = {request_media_type: {"schema": body_schema}}
return request_body_oai
def generate_operation_id(*, route: routing.APIRoute, method: str) -> str:
if route.operation_id:
return route.operation_id
path: str = route.path_format
return generate_operation_id_for_path(name=route.name, path=path, method=method)
def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
if route.summary:
return route.summary
return route.name.replace("_", " ").title()
def get_openapi_operation_metadata(
*, route: routing.APIRoute, method: str
) -> Dict[str, Any]:
operation: Dict[str, Any] = {}
if route.tags:
operation["tags"] = route.tags
operation["summary"] = generate_operation_summary(route=route, method=method)
if route.description:
operation["description"] = route.description
operation["operationId"] = generate_operation_id(route=route, method=method)
if route.deprecated:
operation["deprecated"] = route.deprecated
return operation
def get_openapi_path(
*, route: routing.APIRoute, model_name_map: Dict[type, str]
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
path = {}
security_schemes: Dict[str, Any] = {}
definitions: Dict[str, Any] = {}
assert route.methods is not None, "Methods must be a list"
if isinstance(route.response_class, DefaultPlaceholder):
current_response_class: Type[Response] = route.response_class.value
else:
current_response_class = route.response_class
assert current_response_class, "A response class is needed to generate OpenAPI"
route_response_media_type: Optional[str] = current_response_class.media_type
if route.include_in_schema:
for method in route.methods:
operation = get_openapi_operation_metadata(route=route, method=method)
parameters: List[Dict[str, Any]] = []
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
security_definitions, operation_security = get_openapi_security_definitions(
flat_dependant=flat_dependant
)
if operation_security:
operation.setdefault("security", []).extend(operation_security)
if security_definitions:
security_schemes.update(security_definitions)
all_route_params = get_flat_params(route.dependant)
operation_parameters = get_openapi_operation_parameters(
all_route_params=all_route_params, model_name_map=model_name_map
)
parameters.extend(operation_parameters)
if parameters:
operation["parameters"] = list(
{param["name"]: param for param in parameters}.values()
)
if method in METHODS_WITH_BODY:
request_body_oai = get_openapi_operation_request_body(
body_field=route.body_field, model_name_map=model_name_map
)
if request_body_oai:
operation["requestBody"] = request_body_oai
if route.callbacks:
callbacks = {}
for callback in route.callbacks:
if isinstance(callback, routing.APIRoute):
(
cb_path,
cb_security_schemes,
cb_definitions,
) = get_openapi_path(
route=callback, model_name_map=model_name_map
)
callbacks[callback.name] = {callback.path: cb_path}
operation["callbacks"] = callbacks
status_code = str(route.status_code)
operation.setdefault("responses", {}).setdefault(status_code, {})[
"description"
] = route.response_description
if (
route_response_media_type
and route.status_code not in STATUS_CODES_WITH_NO_BODY
):
response_schema = {"type": "string"}
if lenient_issubclass(current_response_class, JSONResponse):
if route.response_field:
response_schema, _, _ = field_schema(
route.response_field,
model_name_map=model_name_map,
ref_prefix=REF_PREFIX,
)
else:
response_schema = {}
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {}).setdefault(route_response_media_type, {})[
"schema"
] = response_schema
if route.responses:
operation_responses = operation.setdefault("responses", {})
for (
additional_status_code,
additional_response,
) in route.responses.items():
process_response = additional_response.copy()
process_response.pop("model", None)
status_code_key = str(additional_status_code).upper()
if status_code_key == "DEFAULT":
status_code_key = "default"
openapi_response = operation_responses.setdefault(
status_code_key, {}
)
assert isinstance(
process_response, dict
), "An additional response must be a dict"
field = route.response_fields.get(additional_status_code)
additional_field_schema: Optional[Dict[str, Any]] = None
if field:
additional_field_schema, _, _ = field_schema(
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
media_type = route_response_media_type or "application/json"
additional_schema = (
process_response.setdefault("content", {})
.setdefault(media_type, {})
.setdefault("schema", {})
)
deep_dict_update(additional_schema, additional_field_schema)
status_text: Optional[str] = status_code_ranges.get(
str(additional_status_code).upper()
) or http.client.responses.get(int(additional_status_code))
description = (
process_response.get("description")
or openapi_response.get("description")
or status_text
or "Additional Response"
)
deep_dict_update(openapi_response, process_response)
openapi_response["description"] = description
http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
if (all_route_params or route.body_field) and not any(
[
status in operation["responses"]
for status in [http422, "4XX", "default"]
]
):
operation["responses"][http422] = {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
}
},
}
if "ValidationError" not in definitions:
definitions.update(
{
"ValidationError": validation_error_definition,
"HTTPValidationError": validation_error_response_definition,
}
)
path[method.lower()] = operation
return path, security_schemes, definitions
def get_flat_models_from_routes(
routes: Sequence[BaseRoute],
) -> Set[Union[Type[BaseModel], Type[Enum]]]:
body_fields_from_routes: List[ModelField] = []
responses_from_routes: List[ModelField] = []
request_fields_from_routes: List[ModelField] = []
callback_flat_models: Set[Union[Type[BaseModel], Type[Enum]]] = set()
for route in routes:
if getattr(route, "include_in_schema", None) and isinstance(
route, routing.APIRoute
):
if route.body_field:
assert isinstance(
route.body_field, ModelField
), "A request body must be a Pydantic Field"
body_fields_from_routes.append(route.body_field)
if route.response_field:
responses_from_routes.append(route.response_field)
if route.response_fields:
responses_from_routes.extend(route.response_fields.values())
if route.callbacks:
callback_flat_models |= get_flat_models_from_routes(route.callbacks)
params = get_flat_params(route.dependant)
request_fields_from_routes.extend(params)
flat_models = callback_flat_models | get_flat_models_from_fields(
body_fields_from_routes + responses_from_routes + request_fields_from_routes,
known_models=set(),
)
return flat_models
def get_openapi(
*,
title: str,
version: str,
openapi_version: str = "3.0.2",
description: Optional[str] = None,
routes: Sequence[BaseRoute],
tags: Optional[List[Dict[str, Any]]] = None,
servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
) -> Dict[str, Any]:
info = {"title": title, "version": version}
if description:
info["description"] = description
output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
if servers:
output["servers"] = servers
components: Dict[str, Dict[str, Any]] = {}
paths: Dict[str, Dict[str, Any]] = {}
flat_models = get_flat_models_from_routes(routes)
model_name_map = get_model_name_map(flat_models)
definitions = get_model_definitions(
flat_models=flat_models, model_name_map=model_name_map
)
for route in routes:
if isinstance(route, routing.APIRoute):
result = get_openapi_path(route=route, model_name_map=model_name_map)
if result:
path, security_schemes, path_definitions = result
if path:
paths.setdefault(route.path_format, {}).update(path)
if security_schemes:
components.setdefault("securitySchemes", {}).update(
security_schemes
)
if path_definitions:
definitions.update(path_definitions)
if definitions:
components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
if components:
output["components"] = components
output["paths"] = paths
if tags:
output["tags"] = tags
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore

View File

@@ -0,0 +1,253 @@
from typing import Any, Callable, Optional, Sequence
from fastapi import params
def Path( # noqa: N802
default: Any,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
deprecated: Optional[bool] = None,
**extra: Any,
) -> Any:
return params.Path(
default=default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
deprecated=deprecated,
**extra,
)
def Query( # noqa: N802
default: Any,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
deprecated: Optional[bool] = None,
**extra: Any,
) -> Any:
return params.Query(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
deprecated=deprecated,
**extra,
)
def Header( # noqa: N802
default: Any,
*,
alias: Optional[str] = None,
convert_underscores: bool = True,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
deprecated: Optional[bool] = None,
**extra: Any,
) -> Any:
return params.Header(
default,
alias=alias,
convert_underscores=convert_underscores,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
deprecated=deprecated,
**extra,
)
def Cookie( # noqa: N802
default: Any,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
deprecated: Optional[bool] = None,
**extra: Any,
) -> Any:
return params.Cookie(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
deprecated=deprecated,
**extra,
)
def Body( # noqa: N802
default: Any,
*,
embed: bool = False,
media_type: str = "application/json",
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
**extra: Any,
) -> Any:
return params.Body(
default,
embed=embed,
media_type=media_type,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
def Form( # noqa: N802
default: Any,
*,
media_type: str = "application/x-www-form-urlencoded",
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
**extra: Any,
) -> Any:
return params.Form(
default,
media_type=media_type,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
def File( # noqa: N802
default: Any,
*,
media_type: str = "multipart/form-data",
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
**extra: Any,
) -> Any:
return params.File(
default,
media_type=media_type,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
def Depends( # noqa: N802
dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
) -> Any:
return params.Depends(dependency=dependency, use_cache=use_cache)
def Security( # noqa: N802
dependency: Optional[Callable[..., Any]] = None,
*,
scopes: Optional[Sequence[str]] = None,
use_cache: bool = True,
) -> Any:
return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache)

View File

@@ -0,0 +1,338 @@
from enum import Enum
from typing import Any, Callable, Optional, Sequence
from pydantic.fields import FieldInfo
class ParamTypes(Enum):
query = "query"
header = "header"
path = "path"
cookie = "cookie"
class Param(FieldInfo):
in_: ParamTypes
def __init__(
self,
default: Any,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
deprecated: Optional[bool] = None,
**extra: Any,
):
self.deprecated = deprecated
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})"
class Path(Param):
in_ = ParamTypes.path
def __init__(
self,
default: Any,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
deprecated: Optional[bool] = None,
**extra: Any,
):
self.in_ = self.in_
super().__init__(
...,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
deprecated=deprecated,
**extra,
)
class Query(Param):
in_ = ParamTypes.query
def __init__(
self,
default: Any,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
deprecated: Optional[bool] = None,
**extra: Any,
):
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
deprecated=deprecated,
**extra,
)
class Header(Param):
in_ = ParamTypes.header
def __init__(
self,
default: Any,
*,
alias: Optional[str] = None,
convert_underscores: bool = True,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
deprecated: Optional[bool] = None,
**extra: Any,
):
self.convert_underscores = convert_underscores
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
deprecated=deprecated,
**extra,
)
class Cookie(Param):
in_ = ParamTypes.cookie
def __init__(
self,
default: Any,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
deprecated: Optional[bool] = None,
**extra: Any,
):
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
deprecated=deprecated,
**extra,
)
class Body(FieldInfo):
def __init__(
self,
default: Any,
*,
embed: bool = False,
media_type: str = "application/json",
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
**extra: Any,
):
self.embed = embed
self.media_type = media_type
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})"
class Form(Body):
def __init__(
self,
default: Any,
*,
media_type: str = "application/x-www-form-urlencoded",
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
**extra: Any,
):
super().__init__(
default,
embed=True,
media_type=media_type,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
class File(Form):
def __init__(
self,
default: Any,
*,
media_type: str = "multipart/form-data",
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
**extra: Any,
):
super().__init__(
default,
media_type=media_type,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
class Depends:
def __init__(
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
):
self.dependency = dependency
self.use_cache = use_cache
def __repr__(self) -> str:
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
cache = "" if self.use_cache else ", use_cache=False"
return f"{self.__class__.__name__}({attr}{cache})"
class Security(Depends):
def __init__(
self,
dependency: Optional[Callable[..., Any]] = None,
*,
scopes: Optional[Sequence[str]] = None,
use_cache: bool = True,
):
super().__init__(dependency=dependency, use_cache=use_cache)
self.scopes = scopes or []

View File

@@ -0,0 +1,2 @@
from starlette.requests import HTTPConnection as HTTPConnection # noqa: F401
from starlette.requests import Request as Request # noqa: F401

View File

@@ -0,0 +1,23 @@
from typing import Any
from starlette.responses import FileResponse as FileResponse # noqa
from starlette.responses import HTMLResponse as HTMLResponse # noqa
from starlette.responses import JSONResponse as JSONResponse # noqa
from starlette.responses import PlainTextResponse as PlainTextResponse # noqa
from starlette.responses import RedirectResponse as RedirectResponse # noqa
from starlette.responses import Response as Response # noqa
from starlette.responses import StreamingResponse as StreamingResponse # noqa
from starlette.responses import UJSONResponse as UJSONResponse # noqa
try:
import orjson
except ImportError: # pragma: nocover
orjson = None # type: ignore
class ORJSONResponse(JSONResponse):
media_type = "application/json"
def render(self, content: Any) -> bytes:
assert orjson is not None, "orjson must be installed to use ORJSONResponse"
return orjson.dumps(content)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,15 @@
from .api_key import APIKeyCookie as APIKeyCookie
from .api_key import APIKeyHeader as APIKeyHeader
from .api_key import APIKeyQuery as APIKeyQuery
from .http import HTTPAuthorizationCredentials as HTTPAuthorizationCredentials
from .http import HTTPBasic as HTTPBasic
from .http import HTTPBasicCredentials as HTTPBasicCredentials
from .http import HTTPBearer as HTTPBearer
from .http import HTTPDigest as HTTPDigest
from .oauth2 import OAuth2 as OAuth2
from .oauth2 import OAuth2AuthorizationCodeBearer as OAuth2AuthorizationCodeBearer
from .oauth2 import OAuth2PasswordBearer as OAuth2PasswordBearer
from .oauth2 import OAuth2PasswordRequestForm as OAuth2PasswordRequestForm
from .oauth2 import OAuth2PasswordRequestFormStrict as OAuth2PasswordRequestFormStrict
from .oauth2 import SecurityScopes as SecurityScopes
from .open_id_connect_url import OpenIdConnect as OpenIdConnect

View File

@@ -0,0 +1,71 @@
from typing import Optional
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN
class APIKeyBase(SecurityBase):
pass
class APIKeyQuery(APIKeyBase):
def __init__(
self, *, name: str, scheme_name: Optional[str] = None, auto_error: bool = True
):
self.model: APIKey = APIKey(**{"in": APIKeyIn.query}, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
api_key: str = request.query_params.get(self.model.name)
if not api_key:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return api_key
class APIKeyHeader(APIKeyBase):
def __init__(
self, *, name: str, scheme_name: Optional[str] = None, auto_error: bool = True
):
self.model: APIKey = APIKey(**{"in": APIKeyIn.header}, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
api_key: str = request.headers.get(self.model.name)
if not api_key:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return api_key
class APIKeyCookie(APIKeyBase):
def __init__(
self, *, name: str, scheme_name: Optional[str] = None, auto_error: bool = True
):
self.model: APIKey = APIKey(**{"in": APIKeyIn.cookie}, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
api_key = request.cookies.get(self.model.name)
if not api_key:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return api_key

View File

@@ -0,0 +1,6 @@
from fastapi.openapi.models import SecurityBase as SecurityBaseModel
class SecurityBase:
model: SecurityBaseModel
scheme_name: str

View File

@@ -0,0 +1,152 @@
import binascii
from base64 import b64decode
from typing import Optional
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
class HTTPBasicCredentials(BaseModel):
username: str
password: str
class HTTPAuthorizationCredentials(BaseModel):
scheme: str
credentials: str
class HTTPBase(SecurityBase):
def __init__(
self, *, scheme: str, scheme_name: Optional[str] = None, auto_error: bool = True
):
self.model = HTTPBaseModel(scheme=scheme)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
authorization: str = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
class HTTPBasic(HTTPBase):
def __init__(
self,
*,
scheme_name: Optional[str] = None,
realm: Optional[str] = None,
auto_error: bool = True,
):
self.model = HTTPBaseModel(scheme="basic")
self.scheme_name = scheme_name or self.__class__.__name__
self.realm = realm
self.auto_error = auto_error
async def __call__( # type: ignore
self, request: Request
) -> Optional[HTTPBasicCredentials]:
authorization: str = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if self.realm:
unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
else:
unauthorized_headers = {"WWW-Authenticate": "Basic"}
invalid_user_credentials_exc = HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers=unauthorized_headers,
)
if not authorization or scheme.lower() != "basic":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers=unauthorized_headers,
)
else:
return None
try:
data = b64decode(param).decode("ascii")
except (ValueError, UnicodeDecodeError, binascii.Error):
raise invalid_user_credentials_exc
username, separator, password = data.partition(":")
if not separator:
raise invalid_user_credentials_exc
return HTTPBasicCredentials(username=username, password=password)
class HTTPBearer(HTTPBase):
def __init__(
self,
*,
bearerFormat: Optional[str] = None,
scheme_name: Optional[str] = None,
auto_error: bool = True,
):
self.model = HTTPBearerModel(bearerFormat=bearerFormat)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
authorization: str = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
if scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Invalid authentication credentials",
)
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
class HTTPDigest(HTTPBase):
def __init__(self, *, scheme_name: Optional[str] = None, auto_error: bool = True):
self.model = HTTPBaseModel(scheme="digest")
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
authorization: str = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
if scheme.lower() != "digest":
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Invalid authentication credentials",
)
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)

View File

@@ -0,0 +1,207 @@
from typing import Any, Dict, List, Optional, Union
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import OAuth2 as OAuth2Model
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
class OAuth2PasswordRequestForm:
"""
This is a dependency class, use it like:
@app.post("/login")
def login(form_data: OAuth2PasswordRequestForm = Depends()):
data = form_data.parse()
print(data.username)
print(data.password)
for scope in data.scopes:
print(scope)
if data.client_id:
print(data.client_id)
if data.client_secret:
print(data.client_secret)
return data
It creates the following Form request parameters in your endpoint:
grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
Nevertheless, this dependency class is permissive and allows not passing it. If you want to enforce it,
use instead the OAuth2PasswordRequestFormStrict dependency.
username: username string. The OAuth2 spec requires the exact field name "username".
password: password string. The OAuth2 spec requires the exact field name "password".
scope: Optional string. Several scopes (each one a string) separated by spaces. E.g.
"items:read items:write users:read profile openid"
client_id: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
using HTTP Basic auth, as: client_id:client_secret
client_secret: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
using HTTP Basic auth, as: client_id:client_secret
"""
def __init__(
self,
grant_type: str = Form(None, regex="password"),
username: str = Form(...),
password: str = Form(...),
scope: str = Form(""),
client_id: Optional[str] = Form(None),
client_secret: Optional[str] = Form(None),
):
self.grant_type = grant_type
self.username = username
self.password = password
self.scopes = scope.split()
self.client_id = client_id
self.client_secret = client_secret
class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
"""
This is a dependency class, use it like:
@app.post("/login")
def login(form_data: OAuth2PasswordRequestFormStrict = Depends()):
data = form_data.parse()
print(data.username)
print(data.password)
for scope in data.scopes:
print(scope)
if data.client_id:
print(data.client_id)
if data.client_secret:
print(data.client_secret)
return data
It creates the following Form request parameters in your endpoint:
grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
This dependency is strict about it. If you want to be permissive, use instead the
OAuth2PasswordRequestForm dependency class.
username: username string. The OAuth2 spec requires the exact field name "username".
password: password string. The OAuth2 spec requires the exact field name "password".
scope: Optional string. Several scopes (each one a string) separated by spaces. E.g.
"items:read items:write users:read profile openid"
client_id: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
using HTTP Basic auth, as: client_id:client_secret
client_secret: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
using HTTP Basic auth, as: client_id:client_secret
"""
def __init__(
self,
grant_type: str = Form(..., regex="password"),
username: str = Form(...),
password: str = Form(...),
scope: str = Form(""),
client_id: Optional[str] = Form(None),
client_secret: Optional[str] = Form(None),
):
super().__init__(
grant_type=grant_type,
username=username,
password=password,
scope=scope,
client_id=client_id,
client_secret=client_secret,
)
class OAuth2(SecurityBase):
def __init__(
self,
*,
flows: Union[OAuthFlowsModel, Dict[str, Dict[str, Any]]] = OAuthFlowsModel(),
scheme_name: Optional[str] = None,
auto_error: Optional[bool] = True
):
self.model = OAuth2Model(flows=flows)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return authorization
class OAuth2PasswordBearer(OAuth2):
def __init__(
self,
tokenUrl: str,
scheme_name: Optional[str] = None,
scopes: Optional[Dict[str, str]] = None,
auto_error: bool = True,
):
if not scopes:
scopes = {}
flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes})
super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error)
async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
else:
return None
return param
class OAuth2AuthorizationCodeBearer(OAuth2):
def __init__(
self,
authorizationUrl: str,
tokenUrl: str,
refreshUrl: Optional[str] = None,
scheme_name: Optional[str] = None,
scopes: Optional[Dict[str, str]] = None,
auto_error: bool = True,
):
if not scopes:
scopes = {}
flows = OAuthFlowsModel(
authorizationCode={
"authorizationUrl": authorizationUrl,
"tokenUrl": tokenUrl,
"refreshUrl": refreshUrl,
"scopes": scopes,
}
)
super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error)
async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
else:
return None # pragma: nocover
return param
class SecurityScopes:
def __init__(self, scopes: Optional[List[str]] = None):
self.scopes = scopes or []
self.scope_str = " ".join(self.scopes)

View File

@@ -0,0 +1,31 @@
from typing import Optional
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN
class OpenIdConnect(SecurityBase):
def __init__(
self,
*,
openIdConnectUrl: str,
scheme_name: Optional[str] = None,
auto_error: bool = True
):
self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return authorization

View File

@@ -0,0 +1,8 @@
from typing import Tuple
def get_authorization_scheme_param(authorization_header_value: str) -> Tuple[str, str]:
if not authorization_header_value:
return "", ""
scheme, _, param = authorization_header_value.partition(" ")
return scheme, param

View File

@@ -0,0 +1 @@
from starlette.staticfiles import StaticFiles as StaticFiles # noqa

View File

@@ -0,0 +1 @@
from starlette.templating import Jinja2Templates as Jinja2Templates # noqa

View File

@@ -0,0 +1 @@
from starlette.testclient import TestClient as TestClient # noqa

View File

@@ -0,0 +1,3 @@
from typing import Any, Callable, TypeVar
DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])

View File

@@ -0,0 +1,156 @@
import functools
import re
from dataclasses import is_dataclass
from enum import Enum
from typing import Any, Dict, Optional, Set, Type, Union, cast
import fastapi
from fastapi.datastructures import DefaultPlaceholder, DefaultType
from fastapi.openapi.constants import REF_PREFIX
from pydantic import BaseConfig, BaseModel, create_model
from pydantic.class_validators import Validator
from pydantic.fields import FieldInfo, ModelField, UndefinedType
from pydantic.schema import model_process_schema
from pydantic.utils import lenient_issubclass
def get_model_definitions(
*,
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
definitions: Dict[str, Dict[str, Any]] = {}
for model in flat_models:
m_schema, m_definitions, m_nested_models = model_process_schema(
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
definitions.update(m_definitions)
model_name = model_name_map[model]
definitions[model_name] = m_schema
return definitions
def get_path_param_names(path: str) -> Set[str]:
return set(re.findall("{(.*?)}", path))
def create_response_field(
name: str,
type_: Type[Any],
class_validators: Optional[Dict[str, Validator]] = None,
default: Optional[Any] = None,
required: Union[bool, UndefinedType] = False,
model_config: Type[BaseConfig] = BaseConfig,
field_info: Optional[FieldInfo] = None,
alias: Optional[str] = None,
) -> ModelField:
"""
Create a new response field. Raises if type_ is invalid.
"""
class_validators = class_validators or {}
field_info = field_info or FieldInfo(None)
response_field = functools.partial(
ModelField,
name=name,
type_=type_,
class_validators=class_validators,
default=default,
required=required,
model_config=model_config,
alias=alias,
)
try:
return response_field(field_info=field_info)
except RuntimeError:
raise fastapi.exceptions.FastAPIError(
f"Invalid args for response field! Hint: check that {type_} is a valid pydantic field type"
)
def create_cloned_field(
field: ModelField,
*,
cloned_types: Optional[Dict[Type[BaseModel], Type[BaseModel]]] = None,
) -> ModelField:
# _cloned_types has already cloned types, to support recursive models
if cloned_types is None:
cloned_types = dict()
original_type = field.type_
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
original_type = original_type.__pydantic_model__
use_type = original_type
if lenient_issubclass(original_type, BaseModel):
original_type = cast(Type[BaseModel], original_type)
use_type = cloned_types.get(original_type)
if use_type is None:
use_type = create_model(original_type.__name__, __base__=original_type)
cloned_types[original_type] = use_type
for f in original_type.__fields__.values():
use_type.__fields__[f.name] = create_cloned_field(
f, cloned_types=cloned_types
)
new_field = create_response_field(name=field.name, type_=use_type)
new_field.has_alias = field.has_alias
new_field.alias = field.alias
new_field.class_validators = field.class_validators
new_field.default = field.default
new_field.required = field.required
new_field.model_config = field.model_config
new_field.field_info = field.field_info
new_field.allow_none = field.allow_none
new_field.validate_always = field.validate_always
if field.sub_fields:
new_field.sub_fields = [
create_cloned_field(sub_field, cloned_types=cloned_types)
for sub_field in field.sub_fields
]
if field.key_field:
new_field.key_field = create_cloned_field(
field.key_field, cloned_types=cloned_types
)
new_field.validators = field.validators
new_field.pre_validators = field.pre_validators
new_field.post_validators = field.post_validators
new_field.parse_json = field.parse_json
new_field.shape = field.shape
new_field.populate_validators()
return new_field
def generate_operation_id_for_path(*, name: str, path: str, method: str) -> str:
operation_id = name + path
operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id)
operation_id = operation_id + "_" + method.lower()
return operation_id
def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
for key in update_dict:
if (
key in main_dict
and isinstance(main_dict[key], dict)
and isinstance(update_dict[key], dict)
):
deep_dict_update(main_dict[key], update_dict[key])
else:
main_dict[key] = update_dict[key]
def get_value_or_default(
first_item: Union[DefaultPlaceholder, DefaultType],
*extra_items: Union[DefaultPlaceholder, DefaultType],
) -> Union[DefaultPlaceholder, DefaultType]:
"""
Pass items or `DefaultPlaceholder`s by descending priority.
The first one to _not_ be a `DefaultPlaceholder` will be returned.
Otherwise, the first item (a `DefaultPlaceholder`) will be returned.
"""
items = (first_item,) + extra_items
for item in items:
if not isinstance(item, DefaultPlaceholder):
return item
return first_item

View File

@@ -0,0 +1,2 @@
from starlette.websockets import WebSocket as WebSocket # noqa
from starlette.websockets import WebSocketDisconnect as WebSocketDisconnect # noqa