import functools
import re
from dataclasses import is_dataclass
from enum import Enum
from typing import Any, Dict, Optional, Set, Type, Union, cast

import fastapi
from fastapi.datastructures import DefaultPlaceholder, DefaultType
from fastapi.openapi.constants import REF_PREFIX
from pydantic import BaseConfig, BaseModel, create_model
from pydantic.class_validators import Validator
from pydantic.fields import FieldInfo, ModelField, UndefinedType
from pydantic.schema import model_process_schema
from pydantic.utils import lenient_issubclass


def get_model_definitions(
    *,
    flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
    definitions: Dict[str, Dict[str, Any]] = {}
    for model in flat_models:
        m_schema, m_definitions, m_nested_models = model_process_schema(
            model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
        )
        definitions.update(m_definitions)
        model_name = model_name_map[model]
        definitions[model_name] = m_schema
    return definitions


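# Illustrative usage sketch (comment only, not part of the original module):
# `flat_models` and `model_name_map` are normally produced by pydantic v1's
# schema helpers; `Item` and `User` are hypothetical models.
#
#     from pydantic.schema import get_flat_models_from_models, get_model_name_map
#
#     flat_models = get_flat_models_from_models([Item, User])
#     model_name_map = get_model_name_map(flat_models)
#     definitions = get_model_definitions(
#         flat_models=flat_models, model_name_map=model_name_map
#     )

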
def get_path_param_names(path: str) -> Set[str]:
    return set(re.findall("{(.*?)}", path))


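# Illustrative example (comment only): extracts the parameter names enclosed
# in curly braces in a route path.
#
#     get_path_param_names("/items/{item_id}/users/{user_id}")
#     # -> {"item_id", "user_id"}

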
def create_response_field(
    name: str,
    type_: Type[Any],
    class_validators: Optional[Dict[str, Validator]] = None,
    default: Optional[Any] = None,
    required: Union[bool, UndefinedType] = False,
    model_config: Type[BaseConfig] = BaseConfig,
    field_info: Optional[FieldInfo] = None,
    alias: Optional[str] = None,
) -> ModelField:
    """
    Create a new response field. Raises if type_ is invalid.
    """
    class_validators = class_validators or {}
    field_info = field_info or FieldInfo(None)

    response_field = functools.partial(
        ModelField,
        name=name,
        type_=type_,
        class_validators=class_validators,
        default=default,
        required=required,
        model_config=model_config,
        alias=alias,
    )

    try:
        return response_field(field_info=field_info)
    except RuntimeError:
        raise fastapi.exceptions.FastAPIError(
            f"Invalid args for response field! Hint: check that {type_} is a valid pydantic field type"
        )


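# Illustrative usage sketch (comment only): `Item` is a hypothetical pydantic
# model; passing a type pydantic cannot handle raises
# fastapi.exceptions.FastAPIError instead of pydantic's RuntimeError.
#
#     field = create_response_field(name="item", type_=Item)
#     value, errors = field.validate({"name": "Foo"}, {}, loc=("response",))

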
def create_cloned_field(
    field: ModelField,
    *,
    cloned_types: Optional[Dict[Type[BaseModel], Type[BaseModel]]] = None,
) -> ModelField:
    # cloned_types caches already cloned types, to support recursive models
    if cloned_types is None:
        cloned_types = dict()
    original_type = field.type_
    if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
        original_type = original_type.__pydantic_model__
    use_type = original_type
    if lenient_issubclass(original_type, BaseModel):
        original_type = cast(Type[BaseModel], original_type)
        use_type = cloned_types.get(original_type)
        if use_type is None:
            use_type = create_model(original_type.__name__, __base__=original_type)
            cloned_types[original_type] = use_type
            for f in original_type.__fields__.values():
                use_type.__fields__[f.name] = create_cloned_field(
                    f, cloned_types=cloned_types
                )
    new_field = create_response_field(name=field.name, type_=use_type)
    new_field.has_alias = field.has_alias
    new_field.alias = field.alias
    new_field.class_validators = field.class_validators
    new_field.default = field.default
    new_field.required = field.required
    new_field.model_config = field.model_config
    new_field.field_info = field.field_info
    new_field.allow_none = field.allow_none
    new_field.validate_always = field.validate_always
    if field.sub_fields:
        new_field.sub_fields = [
            create_cloned_field(sub_field, cloned_types=cloned_types)
            for sub_field in field.sub_fields
        ]
    if field.key_field:
        new_field.key_field = create_cloned_field(
            field.key_field, cloned_types=cloned_types
        )
    new_field.validators = field.validators
    new_field.pre_validators = field.pre_validators
    new_field.post_validators = field.post_validators
    new_field.parse_json = field.parse_json
    new_field.shape = field.shape
    new_field.populate_validators()
    return new_field


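# Note (editorial sketch of intent, not from the original source): the clone's
# type is a freshly created subclass of the original model, so returned
# instances of the original are not treated by pydantic as already-validated
# values and are instead re-validated and filtered against the declared
# response model, including for recursive and nested models.

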
def generate_operation_id_for_path(*, name: str, path: str, method: str) -> str:
    operation_id = name + path
    operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id)
    operation_id = operation_id + "_" + method.lower()
    return operation_id


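# Illustrative example (comment only): every character that is not
# alphanumeric or an underscore is replaced with "_", then the lowercased
# HTTP method is appended.
#
#     generate_operation_id_for_path(
#         name="read_item", path="/items/{item_id}", method="GET"
#     )
#     # -> "read_item_items__item_id__get"

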
def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
    for key in update_dict:
        if (
            key in main_dict
            and isinstance(main_dict[key], dict)
            and isinstance(update_dict[key], dict)
        ):
            deep_dict_update(main_dict[key], update_dict[key])
        else:
            main_dict[key] = update_dict[key]


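# Illustrative example (comment only): nested dicts are merged recursively in
# place; non-dict values in `update_dict` overwrite the corresponding keys.
#
#     base = {"responses": {"404": {"description": "Not found"}}}
#     deep_dict_update(base, {"responses": {"404": {"model": "Message"}}})
#     # base == {"responses": {"404": {"description": "Not found",
#     #                                "model": "Message"}}}

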
def get_value_or_default(
    first_item: Union[DefaultPlaceholder, DefaultType],
    *extra_items: Union[DefaultPlaceholder, DefaultType],
) -> Union[DefaultPlaceholder, DefaultType]:
    """
    Pass items or `DefaultPlaceholder`s by descending priority.

    The first one to _not_ be a `DefaultPlaceholder` will be returned.

    Otherwise, the first item (a `DefaultPlaceholder`) will be returned.
    """
    items = (first_item,) + extra_items
    for item in items:
        if not isinstance(item, DefaultPlaceholder):
            return item
    return first_item


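# Illustrative usage sketch (comment only): `Default` from
# fastapi.datastructures wraps a value in a `DefaultPlaceholder`;
# `CustomResponse` is hypothetical, `JSONResponse` is
# fastapi.responses.JSONResponse.
#
#     from fastapi.datastructures import Default
#     from fastapi.responses import JSONResponse
#
#     get_value_or_default(Default(JSONResponse), CustomResponse)
#     # -> CustomResponse (first non-placeholder item wins)
#     get_value_or_default(Default(JSONResponse), Default(CustomResponse))
#     # -> DefaultPlaceholder(JSONResponse) (all placeholders: first item returned)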