Mabasej_Team/.venv/lib/python3.9/site-packages/pydantic/dataclasses.py
Untriex Programming ed6afdb5c9 new
2021-03-17 08:57:57 +01:00

268 lines
9.3 KiB
Python

from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, TypeVar, Union, overload
from .class_validators import gather_all_validators
from .error_wrappers import ValidationError
from .errors import DataclassTypeError
from .fields import Field, FieldInfo, Required, Undefined
from .main import create_model, validate_model
from .typing import resolve_annotations
from .utils import ClassAttribute
if TYPE_CHECKING:
    # Everything in this branch exists only for static type checkers; none of it runs at import time.
    from .main import BaseConfig, BaseModel  # noqa: F401
    from .typing import CallableGenerator, NoArgAnyCallable

    DataclassT = TypeVar('DataclassT', bound='Dataclass')

    class Dataclass:
        """
        Static-typing description of a class produced by `pydantic.dataclasses.dataclass`.

        Lists the extra class attributes and hooks that the decorator attaches, so that
        type checkers know about them. Method bodies are placeholders and never executed.
        """

        __pydantic_model__: Type[BaseModel]
        __initialised__: bool
        __post_init_original__: Optional[Callable[..., None]]
        __processed__: Optional[ClassAttribute]
        # whether or not a `pydantic.Field` is used as default value
        __has_field_info_default__: bool

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            ...

        @classmethod
        def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
            ...

        @classmethod
        def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
            ...

        def __call__(self: 'DataclassT', *args: Any, **kwargs: Any) -> 'DataclassT':
            ...
def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
if isinstance(v, cls):
return v
elif isinstance(v, (list, tuple)):
return cls(*v)
elif isinstance(v, dict):
return cls(**v)
# In nested dataclasses, v can be of type `dataclasses.dataclass`.
# But to validate fields `cls` will be in fact a `pydantic.dataclasses.dataclass`,
# which inherits directly from the class of `v`.
elif is_builtin_dataclass(v) and cls.__bases__[0] is type(v):
import dataclasses
return cls(**dataclasses.asdict(v))
else:
raise DataclassTypeError(class_name=cls.__name__)
def _get_validators(cls: Type['Dataclass']) -> 'CallableGenerator':
yield cls.__validate__
def setattr_validate_assignment(self: 'Dataclass', name: str, value: Any) -> None:
    """
    ``__setattr__`` replacement that validates assignments to known model fields.

    While ``self.__initialised__`` is falsy, the value is stored as-is. Afterwards, if ``name`` is a
    field of ``self.__pydantic_model__``, the value is validated by that field first. The instance's
    other current attribute values are passed as context, and the validated value is stored.
    Unknown names are stored without validation.

    Raises ValidationError when the field rejects the value.
    """
    if not self.__initialised__:
        object.__setattr__(self, name, value)
        return

    model_field = self.__pydantic_model__.__fields__.get(name, None)
    if model_field:
        # every current attribute except the one being assigned
        other_values = {key: val for key, val in self.__dict__.items() if key != name}
        value, error = model_field.validate(value, other_values, loc=name, cls=self.__class__)
        if error:
            raise ValidationError([error], self.__class__)
    object.__setattr__(self, name, value)
def is_builtin_dataclass(_cls: Type[Any]) -> bool:
    """
    Tell whether ``_cls`` is a stdlib dataclass that pydantic has not processed.

    ``dataclasses.is_dataclass`` alone also answers True when only a parent class is a dataclass.
    Pydantic-processed classes carry a ``__processed__`` class attribute, so anything exposing that
    attribute is rejected before ``is_dataclass`` is consulted.
    """
    import dataclasses

    if hasattr(_cls, '__processed__'):
        return False
    return dataclasses.is_dataclass(_cls)
def _generate_pydantic_post_init(
post_init_original: Optional[Callable[..., None]], post_init_post_parse: Optional[Callable[..., None]]
) -> Callable[..., None]:
def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
if post_init_original is not None:
post_init_original(self, *initvars)
if getattr(self, '__has_field_info_default__', False):
# We need to remove `FieldInfo` values since they are not valid as input
# It's ok to do that because they are obviously the default values!
input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
else:
input_data = self.__dict__
d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
if validation_error:
raise validation_error
object.__setattr__(self, '__dict__', d)
object.__setattr__(self, '__initialised__', True)
if post_init_post_parse is not None:
post_init_post_parse(self, *initvars)
return _pydantic_post_init
def _process_class(
    _cls: Type[Any],
    init: bool,
    repr: bool,
    eq: bool,
    order: bool,
    unsafe_hash: bool,
    frozen: bool,
    config: Optional[Type[Any]],
) -> Type['Dataclass']:
    """
    Turn ``_cls`` into a validating pydantic dataclass and return the resulting class.

    ``init``, ``repr``, ``eq``, ``order``, ``unsafe_hash`` and ``frozen`` are forwarded unchanged to
    ``dataclasses.dataclass``. ``config`` is passed as ``__config__`` to the pydantic model built from
    the fields. Its ``validate_assignment`` flag, when ``frozen`` is False, installs
    ``setattr_validate_assignment`` as ``__setattr__``.

    Side effects:
    - If ``_cls`` is an unprocessed stdlib dataclass, a same-named subclass is created and registered
      in this module's globals, instead of modifying ``_cls``.
    - Otherwise ``__post_init__`` is assigned directly on ``_cls``, so the passed-in class is mutated.
    """
    import dataclasses

    # Find the user's own __post_init__ to chain. A `_pydantic_post_init` generated by an earlier
    # pass (e.g. inherited from an already processed parent) is not the user's hook. In that case,
    # or when there is no __post_init__ at all, fall back to `__post_init_original__`.
    post_init_original = getattr(_cls, '__post_init__', None)
    if post_init_original and post_init_original.__name__ == '_pydantic_post_init':
        post_init_original = None
    if not post_init_original:
        post_init_original = getattr(_cls, '__post_init_original__', None)

    # optional hook run after validation
    post_init_post_parse = getattr(_cls, '__post_init_post_parse__', None)

    _pydantic_post_init = _generate_pydantic_post_init(post_init_original, post_init_post_parse)

    # If the class is already a dataclass, __post_init__ will not be called automatically
    # so no validation will be added.
    # We hence create dynamically a new dataclass:
    # ```
    # @dataclasses.dataclass
    # class NewClass(_cls):
    #     __post_init__ = _pydantic_post_init
    # ```
    # with the exact same fields as the base dataclass
    # and register it on module level to address pickle problem:
    # https://github.com/samuelcolvin/pydantic/issues/2111
    if is_builtin_dataclass(_cls):
        # id(_cls) makes the registered name unique per original class object
        uniq_class_name = f'_Pydantic_{_cls.__name__}_{id(_cls)}'
        _cls = type(
            # for pretty output new class will have the name as original
            _cls.__name__,
            (_cls,),
            {
                '__annotations__': resolve_annotations(_cls.__annotations__, _cls.__module__),
                '__post_init__': _pydantic_post_init,
                # attrs for pickle to find this class
                '__module__': __name__,
                '__qualname__': uniq_class_name,
            },
        )
        # NOTE(review): entries are added once per processed builtin dataclass and this function
        # never removes them.
        globals()[uniq_class_name] = _cls
    else:
        _cls.__post_init__ = _pydantic_post_init
    cls: Type['Dataclass'] = dataclasses.dataclass(  # type: ignore
        _cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
    )
    # marker read by is_builtin_dataclass so this class is not treated as a plain stdlib dataclass again
    cls.__processed__ = ClassAttribute('__processed__', True)

    # Translate each dataclass field into a (type, FieldInfo) pair for create_model.
    field_definitions: Dict[str, Any] = {}
    for field in dataclasses.fields(cls):
        default: Any = Undefined
        default_factory: Optional['NoArgAnyCallable'] = None
        field_info: FieldInfo

        if field.default is not dataclasses.MISSING:
            default = field.default
        # mypy issue 7020 and 708
        elif field.default_factory is not dataclasses.MISSING:  # type: ignore
            default_factory = field.default_factory  # type: ignore
        else:
            # neither a default nor a default_factory: the field must be supplied
            default = Required

        if isinstance(default, FieldInfo):
            # a `pydantic.Field(...)` used as the default; the flag tells `_pydantic_post_init`
            # to strip such FieldInfo values from the input before validation
            field_info = default
            cls.__has_field_info_default__ = True
        else:
            # dataclass field metadata is forwarded as extra `Field` keyword arguments
            field_info = Field(default=default, default_factory=default_factory, **field.metadata)

        field_definitions[field.name] = (field.type, field_info)

    validators = gather_all_validators(cls)
    # NOTE(review): in the builtin-dataclass branch `_cls` was rebound to the generated subclass,
    # whose `__module__` is this module rather than the user's module -- confirm this is intended.
    cls.__pydantic_model__ = create_model(
        cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **field_definitions
    )

    # False until `_pydantic_post_init` has validated an instance
    cls.__initialised__ = False
    # hooks that let this class be used as a field type in other pydantic models/dataclasses
    cls.__validate__ = classmethod(_validate_dataclass)  # type: ignore[assignment]
    cls.__get_validators__ = classmethod(_get_validators)  # type: ignore[assignment]
    if post_init_original:
        # remembered so that a later re-processing (e.g. of a subclass) can still chain the user hook
        cls.__post_init_original__ = post_init_original

    if cls.__pydantic_model__.__config__.validate_assignment and not frozen:
        cls.__setattr__ = setattr_validate_assignment  # type: ignore[assignment]

    return cls
@overload
def dataclass(
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Optional[Type[Any]] = None,
) -> Callable[[Type[Any]], Type['Dataclass']]:
    ...


@overload
def dataclass(
    _cls: Type[Any],
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Optional[Type[Any]] = None,
) -> Type['Dataclass']:
    ...


def dataclass(
    _cls: Optional[Type[Any]] = None,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Optional[Type[Any]] = None,
) -> Union[Callable[[Type[Any]], Type['Dataclass']], Type['Dataclass']]:
    """
    Like the python standard lib dataclasses but with type validation.

    Usable bare (``@dataclass``) or with arguments (``@dataclass(frozen=True)``).

    ``init``, ``repr``, ``eq``, ``order``, ``unsafe_hash`` and ``frozen`` have the same meaning as for
    ``dataclasses.dataclass`` and are forwarded to it. ``config`` is an optional pydantic config class
    used for the model that validates the fields. If its ``validate_assignment`` option is set (and
    ``frozen`` is False), attribute assignments made after initialisation are validated too.

    Returns the processed class when ``_cls`` is given, otherwise a decorator that processes the class
    it is applied to.
    """

    def wrap(cls: Type[Any]) -> Type['Dataclass']:
        return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, config)

    if _cls is None:
        return wrap

    return wrap(_cls)
def make_dataclass_validator(_cls: Type[Any], config: Type['BaseConfig']) -> 'CallableGenerator':
    """
    Create a pydantic.dataclass from a builtin dataclass to add type validation
    and yield the validators.

    It retrieves the options the stdlib dataclass was created with (from ``__dataclass_params__``)
    and forwards them to the newly created dataclass. Only the options that this module's
    ``dataclass`` accepts are forwarded. Newer Python versions may record additional options
    there (such as ``match_args``, ``kw_only`` or ``slots``), and passing those through would
    make the call fail with ``TypeError``.
    """
    # the keyword options `dataclass` (and `_process_class`) understand
    forwardable_options = ('init', 'repr', 'eq', 'order', 'unsafe_hash', 'frozen')
    dataclass_params = _cls.__dataclass_params__
    stdlib_dataclass_parameters = {
        param: getattr(dataclass_params, param)
        for param in forwardable_options
        if hasattr(dataclass_params, param)
    }
    cls = dataclass(_cls, config=config, **stdlib_dataclass_parameters)
    yield from _get_validators(cls)