import functools import re from dataclasses import is_dataclass from enum import Enum from typing import Any, Dict, Optional, Set, Type, Union, cast import fastapi from fastapi.datastructures import DefaultPlaceholder, DefaultType from fastapi.openapi.constants import REF_PREFIX from pydantic import BaseConfig, BaseModel, create_model from pydantic.class_validators import Validator from pydantic.fields import FieldInfo, ModelField, UndefinedType from pydantic.schema import model_process_schema from pydantic.utils import lenient_issubclass def get_model_definitions( *, flat_models: Set[Union[Type[BaseModel], Type[Enum]]], model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str], ) -> Dict[str, Any]: definitions: Dict[str, Dict[str, Any]] = {} for model in flat_models: m_schema, m_definitions, m_nested_models = model_process_schema( model, model_name_map=model_name_map, ref_prefix=REF_PREFIX ) definitions.update(m_definitions) model_name = model_name_map[model] definitions[model_name] = m_schema return definitions def get_path_param_names(path: str) -> Set[str]: return set(re.findall("{(.*?)}", path)) def create_response_field( name: str, type_: Type[Any], class_validators: Optional[Dict[str, Validator]] = None, default: Optional[Any] = None, required: Union[bool, UndefinedType] = False, model_config: Type[BaseConfig] = BaseConfig, field_info: Optional[FieldInfo] = None, alias: Optional[str] = None, ) -> ModelField: """ Create a new response field. Raises if type_ is invalid. """ class_validators = class_validators or {} field_info = field_info or FieldInfo(None) response_field = functools.partial( ModelField, name=name, type_=type_, class_validators=class_validators, default=default, required=required, model_config=model_config, alias=alias, ) try: return response_field(field_info=field_info) except RuntimeError: raise fastapi.exceptions.FastAPIError( f"Invalid args for response field! Hint: check that {type_} is a valid pydantic field type" ) def create_cloned_field( field: ModelField, *, cloned_types: Optional[Dict[Type[BaseModel], Type[BaseModel]]] = None, ) -> ModelField: # _cloned_types has already cloned types, to support recursive models if cloned_types is None: cloned_types = dict() original_type = field.type_ if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"): original_type = original_type.__pydantic_model__ use_type = original_type if lenient_issubclass(original_type, BaseModel): original_type = cast(Type[BaseModel], original_type) use_type = cloned_types.get(original_type) if use_type is None: use_type = create_model(original_type.__name__, __base__=original_type) cloned_types[original_type] = use_type for f in original_type.__fields__.values(): use_type.__fields__[f.name] = create_cloned_field( f, cloned_types=cloned_types ) new_field = create_response_field(name=field.name, type_=use_type) new_field.has_alias = field.has_alias new_field.alias = field.alias new_field.class_validators = field.class_validators new_field.default = field.default new_field.required = field.required new_field.model_config = field.model_config new_field.field_info = field.field_info new_field.allow_none = field.allow_none new_field.validate_always = field.validate_always if field.sub_fields: new_field.sub_fields = [ create_cloned_field(sub_field, cloned_types=cloned_types) for sub_field in field.sub_fields ] if field.key_field: new_field.key_field = create_cloned_field( field.key_field, cloned_types=cloned_types ) new_field.validators = field.validators new_field.pre_validators = field.pre_validators new_field.post_validators = field.post_validators new_field.parse_json = field.parse_json new_field.shape = field.shape new_field.populate_validators() return new_field def generate_operation_id_for_path(*, name: str, path: str, method: str) -> str: operation_id = name + path operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id) operation_id = operation_id + "_" + method.lower() return operation_id def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None: for key in update_dict: if ( key in main_dict and isinstance(main_dict[key], dict) and isinstance(update_dict[key], dict) ): deep_dict_update(main_dict[key], update_dict[key]) else: main_dict[key] = update_dict[key] def get_value_or_default( first_item: Union[DefaultPlaceholder, DefaultType], *extra_items: Union[DefaultPlaceholder, DefaultType], ) -> Union[DefaultPlaceholder, DefaultType]: """ Pass items or `DefaultPlaceholder`s by descending priority. The first one to _not_ be a `DefaultPlaceholder` will be returned. Otherwise, the first item (a `DefaultPlaceholder`) will be returned. """ items = (first_item,) + extra_items for item in items: if not isinstance(item, DefaultPlaceholder): return item return first_item