This commit is contained in:
Untriex Programming
2021-03-17 08:57:57 +01:00
parent 339be0ccd8
commit ed6afdb5c9
3074 changed files with 423348 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"}
STATUS_CODES_WITH_NO_BODY = {100, 101, 102, 103, 204, 304}
REF_PREFIX = "#/components/schemas/"

View File

@@ -0,0 +1,177 @@
import json
from typing import Any, Dict, Optional
from fastapi.encoders import jsonable_encoder
from starlette.responses import HTMLResponse
def get_swagger_ui_html(
*,
openapi_url: str,
title: str,
swagger_js_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js",
swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css",
swagger_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png",
oauth2_redirect_url: Optional[str] = None,
init_oauth: Optional[Dict[str, Any]] = None,
) -> HTMLResponse:
html = f"""
<!DOCTYPE html>
<html>
<head>
<link type="text/css" rel="stylesheet" href="{swagger_css_url}">
<link rel="shortcut icon" href="{swagger_favicon_url}">
<title>{title}</title>
</head>
<body>
<div id="swagger-ui">
</div>
<script src="{swagger_js_url}"></script>
<!-- `SwaggerUIBundle` is now available on the page -->
<script>
const ui = SwaggerUIBundle({{
url: '{openapi_url}',
"""
if oauth2_redirect_url:
html += f"oauth2RedirectUrl: window.location.origin + '{oauth2_redirect_url}',"
html += """
dom_id: '#swagger-ui',
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true
})"""
if init_oauth:
html += f"""
ui.initOAuth({json.dumps(jsonable_encoder(init_oauth))})
"""
html += """
</script>
</body>
</html>
"""
return HTMLResponse(html)
def get_redoc_html(
*,
openapi_url: str,
title: str,
redoc_js_url: str = "https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js",
redoc_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png",
with_google_fonts: bool = True,
) -> HTMLResponse:
html = f"""
<!DOCTYPE html>
<html>
<head>
<title>{title}</title>
<!-- needed for adaptive design -->
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1">
"""
if with_google_fonts:
html += """
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
"""
html += f"""
<link rel="shortcut icon" href="{redoc_favicon_url}">
<!--
ReDoc doesn't change outer page styles
-->
<style>
body {{
margin: 0;
padding: 0;
}}
</style>
</head>
<body>
<redoc spec-url="{openapi_url}"></redoc>
<script src="{redoc_js_url}"> </script>
</body>
</html>
"""
return HTMLResponse(html)
def get_swagger_ui_oauth2_redirect_html() -> HTMLResponse:
html = """
<!DOCTYPE html>
<html lang="en-US">
<body onload="run()">
</body>
</html>
<script>
'use strict';
function run () {
var oauth2 = window.opener.swaggerUIRedirectOauth2;
var sentState = oauth2.state;
var redirectUrl = oauth2.redirectUrl;
var isValid, qp, arr;
if (/code|token|error/.test(window.location.hash)) {
qp = window.location.hash.substring(1);
} else {
qp = location.search.substring(1);
}
arr = qp.split("&")
arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';})
qp = qp ? JSON.parse('{' + arr.join() + '}',
function (key, value) {
return key === "" ? value : decodeURIComponent(value)
}
) : {}
isValid = qp.state === sentState
if ((
oauth2.auth.schema.get("flow") === "accessCode"||
oauth2.auth.schema.get("flow") === "authorizationCode"
) && !oauth2.auth.code) {
if (!isValid) {
oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "warning",
message: "Authorization may be unsafe, passed state was changed in server Passed state wasn't returned from auth server"
});
}
if (qp.code) {
delete oauth2.state;
oauth2.auth.code = qp.code;
oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl});
} else {
let oauthErrorMsg
if (qp.error) {
oauthErrorMsg = "["+qp.error+"]: " +
(qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") +
(qp.error_uri ? "More info: "+qp.error_uri : "");
}
oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "error",
message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server"
});
}
} else {
oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl});
}
window.close();
}
</script>
"""
return HTMLResponse(content=html)

View File

@@ -0,0 +1,351 @@
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from fastapi.logger import logger
from pydantic import AnyUrl, BaseModel, Field
try:
import email_validator # type: ignore
assert email_validator # make autoflake ignore the unused import
from pydantic import EmailStr
except ImportError: # pragma: no cover
class EmailStr(str): # type: ignore
@classmethod
def __get_validators__(cls) -> Iterable[Callable[..., Any]]:
yield cls.validate
@classmethod
def validate(cls, v: Any) -> str:
logger.warning(
"email-validator not installed, email fields will be treated as str.\n"
"To install, run: pip install email-validator"
)
return str(v)
class Contact(BaseModel):
name: Optional[str] = None
url: Optional[AnyUrl] = None
email: Optional[EmailStr] = None
class License(BaseModel):
name: str
url: Optional[AnyUrl] = None
class Info(BaseModel):
title: str
description: Optional[str] = None
termsOfService: Optional[str] = None
contact: Optional[Contact] = None
license: Optional[License] = None
version: str
class ServerVariable(BaseModel):
enum: Optional[List[str]] = None
default: str
description: Optional[str] = None
class Server(BaseModel):
url: Union[AnyUrl, str]
description: Optional[str] = None
variables: Optional[Dict[str, ServerVariable]] = None
class Reference(BaseModel):
ref: str = Field(..., alias="$ref")
class Discriminator(BaseModel):
propertyName: str
mapping: Optional[Dict[str, str]] = None
class XML(BaseModel):
name: Optional[str] = None
namespace: Optional[str] = None
prefix: Optional[str] = None
attribute: Optional[bool] = None
wrapped: Optional[bool] = None
class ExternalDocumentation(BaseModel):
description: Optional[str] = None
url: AnyUrl
class SchemaBase(BaseModel):
ref: Optional[str] = Field(None, alias="$ref")
title: Optional[str] = None
multipleOf: Optional[float] = None
maximum: Optional[float] = None
exclusiveMaximum: Optional[float] = None
minimum: Optional[float] = None
exclusiveMinimum: Optional[float] = None
maxLength: Optional[int] = Field(None, gte=0)
minLength: Optional[int] = Field(None, gte=0)
pattern: Optional[str] = None
maxItems: Optional[int] = Field(None, gte=0)
minItems: Optional[int] = Field(None, gte=0)
uniqueItems: Optional[bool] = None
maxProperties: Optional[int] = Field(None, gte=0)
minProperties: Optional[int] = Field(None, gte=0)
required: Optional[List[str]] = None
enum: Optional[List[Any]] = None
type: Optional[str] = None
allOf: Optional[List[Any]] = None
oneOf: Optional[List[Any]] = None
anyOf: Optional[List[Any]] = None
not_: Optional[Any] = Field(None, alias="not")
items: Optional[Any] = None
properties: Optional[Dict[str, Any]] = None
additionalProperties: Optional[Union[Dict[str, Any], bool]] = None
description: Optional[str] = None
format: Optional[str] = None
default: Optional[Any] = None
nullable: Optional[bool] = None
discriminator: Optional[Discriminator] = None
readOnly: Optional[bool] = None
writeOnly: Optional[bool] = None
xml: Optional[XML] = None
externalDocs: Optional[ExternalDocumentation] = None
example: Optional[Any] = None
deprecated: Optional[bool] = None
class Schema(SchemaBase):
allOf: Optional[List[SchemaBase]] = None
oneOf: Optional[List[SchemaBase]] = None
anyOf: Optional[List[SchemaBase]] = None
not_: Optional[SchemaBase] = Field(None, alias="not")
items: Optional[SchemaBase] = None
properties: Optional[Dict[str, SchemaBase]] = None
additionalProperties: Optional[Union[Dict[str, Any], bool]] = None
class Example(BaseModel):
summary: Optional[str] = None
description: Optional[str] = None
value: Optional[Any] = None
externalValue: Optional[AnyUrl] = None
class ParameterInType(Enum):
query = "query"
header = "header"
path = "path"
cookie = "cookie"
class Encoding(BaseModel):
contentType: Optional[str] = None
# Workaround OpenAPI recursive reference, using Any
headers: Optional[Dict[str, Union[Any, Reference]]] = None
style: Optional[str] = None
explode: Optional[bool] = None
allowReserved: Optional[bool] = None
class MediaType(BaseModel):
schema_: Optional[Union[Schema, Reference]] = Field(None, alias="schema")
example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
encoding: Optional[Dict[str, Encoding]] = None
class ParameterBase(BaseModel):
description: Optional[str] = None
required: Optional[bool] = None
deprecated: Optional[bool] = None
# Serialization rules for simple scenarios
style: Optional[str] = None
explode: Optional[bool] = None
allowReserved: Optional[bool] = None
schema_: Optional[Union[Schema, Reference]] = Field(None, alias="schema")
example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
# Serialization rules for more complex scenarios
content: Optional[Dict[str, MediaType]] = None
class Parameter(ParameterBase):
name: str
in_: ParameterInType = Field(..., alias="in")
class Header(ParameterBase):
pass
# Workaround OpenAPI recursive reference
class EncodingWithHeaders(Encoding):
headers: Optional[Dict[str, Union[Header, Reference]]] = None
class RequestBody(BaseModel):
description: Optional[str] = None
content: Dict[str, MediaType]
required: Optional[bool] = None
class Link(BaseModel):
operationRef: Optional[str] = None
operationId: Optional[str] = None
parameters: Optional[Dict[str, Union[Any, str]]] = None
requestBody: Optional[Union[Any, str]] = None
description: Optional[str] = None
server: Optional[Server] = None
class Response(BaseModel):
description: str
headers: Optional[Dict[str, Union[Header, Reference]]] = None
content: Optional[Dict[str, MediaType]] = None
links: Optional[Dict[str, Union[Link, Reference]]] = None
class Operation(BaseModel):
tags: Optional[List[str]] = None
summary: Optional[str] = None
description: Optional[str] = None
externalDocs: Optional[ExternalDocumentation] = None
operationId: Optional[str] = None
parameters: Optional[List[Union[Parameter, Reference]]] = None
requestBody: Optional[Union[RequestBody, Reference]] = None
responses: Dict[str, Response]
# Workaround OpenAPI recursive reference
callbacks: Optional[Dict[str, Union[Dict[str, Any], Reference]]] = None
deprecated: Optional[bool] = None
security: Optional[List[Dict[str, List[str]]]] = None
servers: Optional[List[Server]] = None
class PathItem(BaseModel):
ref: Optional[str] = Field(None, alias="$ref")
summary: Optional[str] = None
description: Optional[str] = None
get: Optional[Operation] = None
put: Optional[Operation] = None
post: Optional[Operation] = None
delete: Optional[Operation] = None
options: Optional[Operation] = None
head: Optional[Operation] = None
patch: Optional[Operation] = None
trace: Optional[Operation] = None
servers: Optional[List[Server]] = None
parameters: Optional[List[Union[Parameter, Reference]]] = None
# Workaround OpenAPI recursive reference
class OperationWithCallbacks(BaseModel):
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None
class SecuritySchemeType(Enum):
apiKey = "apiKey"
http = "http"
oauth2 = "oauth2"
openIdConnect = "openIdConnect"
class SecurityBase(BaseModel):
type_: SecuritySchemeType = Field(..., alias="type")
description: Optional[str] = None
class APIKeyIn(Enum):
query = "query"
header = "header"
cookie = "cookie"
class APIKey(SecurityBase):
type_ = Field(SecuritySchemeType.apiKey, alias="type")
in_: APIKeyIn = Field(..., alias="in")
name: str
class HTTPBase(SecurityBase):
type_ = Field(SecuritySchemeType.http, alias="type")
scheme: str
class HTTPBearer(HTTPBase):
scheme = "bearer"
bearerFormat: Optional[str] = None
class OAuthFlow(BaseModel):
refreshUrl: Optional[str] = None
scopes: Dict[str, str] = {}
class OAuthFlowImplicit(OAuthFlow):
authorizationUrl: str
class OAuthFlowPassword(OAuthFlow):
tokenUrl: str
class OAuthFlowClientCredentials(OAuthFlow):
tokenUrl: str
class OAuthFlowAuthorizationCode(OAuthFlow):
authorizationUrl: str
tokenUrl: str
class OAuthFlows(BaseModel):
implicit: Optional[OAuthFlowImplicit] = None
password: Optional[OAuthFlowPassword] = None
clientCredentials: Optional[OAuthFlowClientCredentials] = None
authorizationCode: Optional[OAuthFlowAuthorizationCode] = None
class OAuth2(SecurityBase):
type_ = Field(SecuritySchemeType.oauth2, alias="type")
flows: OAuthFlows
class OpenIdConnect(SecurityBase):
type_ = Field(SecuritySchemeType.openIdConnect, alias="type")
openIdConnectUrl: str
SecurityScheme = Union[APIKey, HTTPBase, OAuth2, OpenIdConnect, HTTPBearer]
class Components(BaseModel):
schemas: Optional[Dict[str, Union[Schema, Reference]]] = None
responses: Optional[Dict[str, Union[Response, Reference]]] = None
parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
requestBodies: Optional[Dict[str, Union[RequestBody, Reference]]] = None
headers: Optional[Dict[str, Union[Header, Reference]]] = None
securitySchemes: Optional[Dict[str, Union[SecurityScheme, Reference]]] = None
links: Optional[Dict[str, Union[Link, Reference]]] = None
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None
class Tag(BaseModel):
name: str
description: Optional[str] = None
externalDocs: Optional[ExternalDocumentation] = None
class OpenAPI(BaseModel):
openapi: str
info: Info
servers: Optional[List[Server]] = None
paths: Dict[str, PathItem]
components: Optional[Components] = None
security: Optional[List[Dict[str, List[str]]]] = None
tags: Optional[List[Tag]] = None
externalDocs: Optional[ExternalDocumentation] = None

View File

@@ -0,0 +1,377 @@
import http.client
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
from fastapi import routing
from fastapi.datastructures import DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import (
METHODS_WITH_BODY,
REF_PREFIX,
STATUS_CODES_WITH_NO_BODY,
)
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, Param
from fastapi.responses import Response
from fastapi.utils import (
deep_dict_update,
generate_operation_id_for_path,
get_model_definitions,
)
from pydantic import BaseModel
from pydantic.fields import ModelField
from pydantic.schema import (
field_schema,
get_flat_models_from_fields,
get_model_name_map,
)
from pydantic.utils import lenient_issubclass
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
validation_error_definition = {
"title": "ValidationError",
"type": "object",
"properties": {
"loc": {"title": "Location", "type": "array", "items": {"type": "string"}},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
"required": ["loc", "msg", "type"],
}
validation_error_response_definition = {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": REF_PREFIX + "ValidationError"},
}
},
}
status_code_ranges: Dict[str, str] = {
"1XX": "Information",
"2XX": "Success",
"3XX": "Redirection",
"4XX": "Client Error",
"5XX": "Server Error",
"DEFAULT": "Default Response",
}
def get_openapi_security_definitions(
flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
security_definitions = {}
operation_security = []
for security_requirement in flat_dependant.security_requirements:
security_definition = jsonable_encoder(
security_requirement.security_scheme.model,
by_alias=True,
exclude_none=True,
)
security_name = security_requirement.security_scheme.scheme_name
security_definitions[security_name] = security_definition
operation_security.append({security_name: security_requirement.scopes})
return security_definitions, operation_security
def get_openapi_operation_parameters(
*,
all_route_params: Sequence[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> List[Dict[str, Any]]:
parameters = []
for param in all_route_params:
field_info = param.field_info
field_info = cast(Param, field_info)
parameter = {
"name": param.alias,
"in": field_info.in_.value,
"required": param.required,
"schema": field_schema(
param, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)[0],
}
if field_info.description:
parameter["description"] = field_info.description
if field_info.deprecated:
parameter["deprecated"] = field_info.deprecated
parameters.append(parameter)
return parameters
def get_openapi_operation_request_body(
*,
body_field: Optional[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Optional[Dict[str, Any]]:
if not body_field:
return None
assert isinstance(body_field, ModelField)
body_schema, _, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
field_info = cast(Body, body_field.field_info)
request_media_type = field_info.media_type
required = body_field.required
request_body_oai: Dict[str, Any] = {}
if required:
request_body_oai["required"] = required
request_body_oai["content"] = {request_media_type: {"schema": body_schema}}
return request_body_oai
def generate_operation_id(*, route: routing.APIRoute, method: str) -> str:
if route.operation_id:
return route.operation_id
path: str = route.path_format
return generate_operation_id_for_path(name=route.name, path=path, method=method)
def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
if route.summary:
return route.summary
return route.name.replace("_", " ").title()
def get_openapi_operation_metadata(
*, route: routing.APIRoute, method: str
) -> Dict[str, Any]:
operation: Dict[str, Any] = {}
if route.tags:
operation["tags"] = route.tags
operation["summary"] = generate_operation_summary(route=route, method=method)
if route.description:
operation["description"] = route.description
operation["operationId"] = generate_operation_id(route=route, method=method)
if route.deprecated:
operation["deprecated"] = route.deprecated
return operation
def get_openapi_path(
*, route: routing.APIRoute, model_name_map: Dict[type, str]
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
path = {}
security_schemes: Dict[str, Any] = {}
definitions: Dict[str, Any] = {}
assert route.methods is not None, "Methods must be a list"
if isinstance(route.response_class, DefaultPlaceholder):
current_response_class: Type[Response] = route.response_class.value
else:
current_response_class = route.response_class
assert current_response_class, "A response class is needed to generate OpenAPI"
route_response_media_type: Optional[str] = current_response_class.media_type
if route.include_in_schema:
for method in route.methods:
operation = get_openapi_operation_metadata(route=route, method=method)
parameters: List[Dict[str, Any]] = []
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
security_definitions, operation_security = get_openapi_security_definitions(
flat_dependant=flat_dependant
)
if operation_security:
operation.setdefault("security", []).extend(operation_security)
if security_definitions:
security_schemes.update(security_definitions)
all_route_params = get_flat_params(route.dependant)
operation_parameters = get_openapi_operation_parameters(
all_route_params=all_route_params, model_name_map=model_name_map
)
parameters.extend(operation_parameters)
if parameters:
operation["parameters"] = list(
{param["name"]: param for param in parameters}.values()
)
if method in METHODS_WITH_BODY:
request_body_oai = get_openapi_operation_request_body(
body_field=route.body_field, model_name_map=model_name_map
)
if request_body_oai:
operation["requestBody"] = request_body_oai
if route.callbacks:
callbacks = {}
for callback in route.callbacks:
if isinstance(callback, routing.APIRoute):
(
cb_path,
cb_security_schemes,
cb_definitions,
) = get_openapi_path(
route=callback, model_name_map=model_name_map
)
callbacks[callback.name] = {callback.path: cb_path}
operation["callbacks"] = callbacks
status_code = str(route.status_code)
operation.setdefault("responses", {}).setdefault(status_code, {})[
"description"
] = route.response_description
if (
route_response_media_type
and route.status_code not in STATUS_CODES_WITH_NO_BODY
):
response_schema = {"type": "string"}
if lenient_issubclass(current_response_class, JSONResponse):
if route.response_field:
response_schema, _, _ = field_schema(
route.response_field,
model_name_map=model_name_map,
ref_prefix=REF_PREFIX,
)
else:
response_schema = {}
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {}).setdefault(route_response_media_type, {})[
"schema"
] = response_schema
if route.responses:
operation_responses = operation.setdefault("responses", {})
for (
additional_status_code,
additional_response,
) in route.responses.items():
process_response = additional_response.copy()
process_response.pop("model", None)
status_code_key = str(additional_status_code).upper()
if status_code_key == "DEFAULT":
status_code_key = "default"
openapi_response = operation_responses.setdefault(
status_code_key, {}
)
assert isinstance(
process_response, dict
), "An additional response must be a dict"
field = route.response_fields.get(additional_status_code)
additional_field_schema: Optional[Dict[str, Any]] = None
if field:
additional_field_schema, _, _ = field_schema(
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
media_type = route_response_media_type or "application/json"
additional_schema = (
process_response.setdefault("content", {})
.setdefault(media_type, {})
.setdefault("schema", {})
)
deep_dict_update(additional_schema, additional_field_schema)
status_text: Optional[str] = status_code_ranges.get(
str(additional_status_code).upper()
) or http.client.responses.get(int(additional_status_code))
description = (
process_response.get("description")
or openapi_response.get("description")
or status_text
or "Additional Response"
)
deep_dict_update(openapi_response, process_response)
openapi_response["description"] = description
http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
if (all_route_params or route.body_field) and not any(
[
status in operation["responses"]
for status in [http422, "4XX", "default"]
]
):
operation["responses"][http422] = {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
}
},
}
if "ValidationError" not in definitions:
definitions.update(
{
"ValidationError": validation_error_definition,
"HTTPValidationError": validation_error_response_definition,
}
)
path[method.lower()] = operation
return path, security_schemes, definitions
def get_flat_models_from_routes(
routes: Sequence[BaseRoute],
) -> Set[Union[Type[BaseModel], Type[Enum]]]:
body_fields_from_routes: List[ModelField] = []
responses_from_routes: List[ModelField] = []
request_fields_from_routes: List[ModelField] = []
callback_flat_models: Set[Union[Type[BaseModel], Type[Enum]]] = set()
for route in routes:
if getattr(route, "include_in_schema", None) and isinstance(
route, routing.APIRoute
):
if route.body_field:
assert isinstance(
route.body_field, ModelField
), "A request body must be a Pydantic Field"
body_fields_from_routes.append(route.body_field)
if route.response_field:
responses_from_routes.append(route.response_field)
if route.response_fields:
responses_from_routes.extend(route.response_fields.values())
if route.callbacks:
callback_flat_models |= get_flat_models_from_routes(route.callbacks)
params = get_flat_params(route.dependant)
request_fields_from_routes.extend(params)
flat_models = callback_flat_models | get_flat_models_from_fields(
body_fields_from_routes + responses_from_routes + request_fields_from_routes,
known_models=set(),
)
return flat_models
def get_openapi(
*,
title: str,
version: str,
openapi_version: str = "3.0.2",
description: Optional[str] = None,
routes: Sequence[BaseRoute],
tags: Optional[List[Dict[str, Any]]] = None,
servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
) -> Dict[str, Any]:
info = {"title": title, "version": version}
if description:
info["description"] = description
output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
if servers:
output["servers"] = servers
components: Dict[str, Dict[str, Any]] = {}
paths: Dict[str, Dict[str, Any]] = {}
flat_models = get_flat_models_from_routes(routes)
model_name_map = get_model_name_map(flat_models)
definitions = get_model_definitions(
flat_models=flat_models, model_name_map=model_name_map
)
for route in routes:
if isinstance(route, routing.APIRoute):
result = get_openapi_path(route=route, model_name_map=model_name_map)
if result:
path, security_schemes, path_definitions = result
if path:
paths.setdefault(route.path_format, {}).update(path)
if security_schemes:
components.setdefault("securitySchemes", {}).update(
security_schemes
)
if path_definitions:
definitions.update(path_definitions)
if definitions:
components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
if components:
output["components"] = components
output["paths"] = paths
if tags:
output["tags"] = tags
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore