from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, TypeVar, Union, overload

from .class_validators import gather_all_validators
from .error_wrappers import ValidationError
from .errors import DataclassTypeError
from .fields import Field, FieldInfo, Required, Undefined
from .main import create_model, validate_model
from .typing import resolve_annotations
from .utils import ClassAttribute

if TYPE_CHECKING:
    from .main import BaseConfig, BaseModel  # noqa: F401
    from .typing import CallableGenerator, NoArgAnyCallable
    DataclassT = TypeVar('DataclassT', bound='Dataclass')

    # Stub used only for static type checking; the runtime class is built by `_process_class`.
    class Dataclass:
        __pydantic_model__: Type[BaseModel]
        __initialised__: bool
        __post_init_original__: Optional[Callable[..., None]]
        __processed__: Optional[ClassAttribute]
        __has_field_info_default__: bool  # whether or not a `pydantic.Field` is used as a default value

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            pass

        @classmethod
        def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
            pass

        @classmethod
        def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
            pass

        def __call__(self: 'DataclassT', *args: Any, **kwargs: Any) -> 'DataclassT':
            pass

def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
    if isinstance(v, cls):
        return v
    elif isinstance(v, (list, tuple)):
        return cls(*v)
    elif isinstance(v, dict):
        return cls(**v)
    # In nested dataclasses, `v` can be a plain `dataclasses.dataclass` instance,
    # while `cls` (used to validate the fields) is the corresponding
    # `pydantic.dataclasses.dataclass`, which inherits directly from the class of `v`.
    elif is_builtin_dataclass(v) and cls.__bases__[0] is type(v):
        import dataclasses

        return cls(**dataclasses.asdict(v))
    else:
        raise DataclassTypeError(class_name=cls.__name__)
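# Illustrative sketch (hypothetical `User` class, not part of this module): for
#
#     @dataclass
#     class User:
#         id: int
#         name: str = 'John Doe'
#
# `_validate_dataclass(User, User(id=1))`, `_validate_dataclass(User, (1,))` and
# `_validate_dataclass(User, {'id': '1'})` should all return a validated `User` instance,
# while an unsupported input type raises `DataclassTypeError`.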

def _get_validators(cls: Type['Dataclass']) -> 'CallableGenerator':
    yield cls.__validate__

def setattr_validate_assignment(self: 'Dataclass', name: str, value: Any) -> None:
    if self.__initialised__:
        d = dict(self.__dict__)
        d.pop(name, None)
        known_field = self.__pydantic_model__.__fields__.get(name, None)
        if known_field:
            value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__)
            if error_:
                raise ValidationError([error_], self.__class__)

    object.__setattr__(self, name, value)
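# Note: this is only installed as `__setattr__` (see `_process_class`) when
# `Config.validate_assignment` is enabled and the dataclass is not frozen, so
# assignments made after `__init__` are re-validated against the pydantic model.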

def is_builtin_dataclass(_cls: Type[Any]) -> bool:
    """
    `dataclasses.is_dataclass` returns True if any of the class's parents is a dataclass.
    This is why we also add the `__processed__` class attribute (in `_process_class`),
    so that only 'direct' built-in dataclasses are considered here.
    """
    import dataclasses

    return not hasattr(_cls, '__processed__') and dataclasses.is_dataclass(_cls)
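# Consequently, a class already decorated with `pydantic.dataclasses.dataclass`
# (which carries `__processed__`) is not treated as a built-in dataclass here,
# even though `dataclasses.is_dataclass` would return True for it.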

def _generate_pydantic_post_init(
    post_init_original: Optional[Callable[..., None]], post_init_post_parse: Optional[Callable[..., None]]
) -> Callable[..., None]:
    def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
        if post_init_original is not None:
            post_init_original(self, *initvars)

        if getattr(self, '__has_field_info_default__', False):
            # We need to remove `FieldInfo` values since they are not valid as input;
            # it's safe to do so because they are necessarily the default values.
            input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
        else:
            input_data = self.__dict__
        d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
        if validation_error:
            raise validation_error
        object.__setattr__(self, '__dict__', d)
        object.__setattr__(self, '__initialised__', True)
        if post_init_post_parse is not None:
            post_init_post_parse(self, *initvars)

    return _pydantic_post_init
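# The generated `__post_init__` therefore runs in three steps: the original
# `__post_init__` (if any), then validation via `validate_model`, and finally
# `__post_init_post_parse__` (if any) on the already-validated values.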

def _process_class(
    _cls: Type[Any],
    init: bool,
    repr: bool,
    eq: bool,
    order: bool,
    unsafe_hash: bool,
    frozen: bool,
    config: Optional[Type[Any]],
) -> Type['Dataclass']:
    import dataclasses

    post_init_original = getattr(_cls, '__post_init__', None)
    if post_init_original and post_init_original.__name__ == '_pydantic_post_init':
        post_init_original = None
    if not post_init_original:
        post_init_original = getattr(_cls, '__post_init_original__', None)

    post_init_post_parse = getattr(_cls, '__post_init_post_parse__', None)

    _pydantic_post_init = _generate_pydantic_post_init(post_init_original, post_init_post_parse)

    # If the class is already a dataclass, `__post_init__` will not be called automatically,
    # so no validation will be added.
    # We therefore dynamically create a new dataclass:
    # ```
    # @dataclasses.dataclass
    # class NewClass(_cls):
    #     __post_init__ = _pydantic_post_init
    # ```
    # with the exact same fields as the base dataclass,
    # and register it at module level to address the pickle problem:
    # https://github.com/samuelcolvin/pydantic/issues/2111
    if is_builtin_dataclass(_cls):
        uniq_class_name = f'_Pydantic_{_cls.__name__}_{id(_cls)}'
        _cls = type(
            # for pretty output, the new class has the same name as the original
            _cls.__name__,
            (_cls,),
            {
                '__annotations__': resolve_annotations(_cls.__annotations__, _cls.__module__),
                '__post_init__': _pydantic_post_init,
                # attrs for pickle to find this class
                '__module__': __name__,
                '__qualname__': uniq_class_name,
            },
        )
        globals()[uniq_class_name] = _cls
    else:
        _cls.__post_init__ = _pydantic_post_init
    cls: Type['Dataclass'] = dataclasses.dataclass(  # type: ignore
        _cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
    )
    cls.__processed__ = ClassAttribute('__processed__', True)

    field_definitions: Dict[str, Any] = {}
    for field in dataclasses.fields(cls):
        default: Any = Undefined
        default_factory: Optional['NoArgAnyCallable'] = None
        field_info: FieldInfo

        if field.default is not dataclasses.MISSING:
            default = field.default
        # mypy issues 7020 and 708
        elif field.default_factory is not dataclasses.MISSING:  # type: ignore
            default_factory = field.default_factory  # type: ignore
        else:
            default = Required

        if isinstance(default, FieldInfo):
            field_info = default
            cls.__has_field_info_default__ = True
        else:
            field_info = Field(default=default, default_factory=default_factory, **field.metadata)

        field_definitions[field.name] = (field.type, field_info)

    validators = gather_all_validators(cls)
    cls.__pydantic_model__ = create_model(
        cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **field_definitions
    )

    cls.__initialised__ = False
    cls.__validate__ = classmethod(_validate_dataclass)  # type: ignore[assignment]
    cls.__get_validators__ = classmethod(_get_validators)  # type: ignore[assignment]
    if post_init_original:
        cls.__post_init_original__ = post_init_original

    if cls.__pydantic_model__.__config__.validate_assignment and not frozen:
        cls.__setattr__ = setattr_validate_assignment  # type: ignore[assignment]

    return cls
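# The resulting class is both a standard `dataclasses.dataclass` and carries a companion
# `__pydantic_model__` used for validation; for example (illustrative, hypothetical
# `MyDataclass`), `MyDataclass.__pydantic_model__.schema()` can be used to inspect
# the generated model.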

@overload
def dataclass(
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Type[Any] = None,
) -> Callable[[Type[Any]], Type['Dataclass']]:
    ...


@overload
def dataclass(
    _cls: Type[Any],
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Type[Any] = None,
) -> Type['Dataclass']:
    ...

def dataclass(
    _cls: Optional[Type[Any]] = None,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    config: Type[Any] = None,
) -> Union[Callable[[Type[Any]], Type['Dataclass']], Type['Dataclass']]:
    """
    Like the python standard lib dataclasses but with type validation.

    Arguments are the same as for standard dataclasses, with the addition of `config`,
    which is used as the `Config` of the generated pydantic model (and therefore controls
    settings such as `Config.validate_assignment`).
    """

    def wrap(cls: Type[Any]) -> Type['Dataclass']:
        return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, config)

    if _cls is None:
        return wrap

    return wrap(_cls)
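# Usage sketch (illustrative; `User` is a hypothetical class, not part of this module):
#
#     from pydantic.dataclasses import dataclass
#
#     @dataclass
#     class User:
#         id: int
#         name: str = 'John Doe'
#
#     user = User(id='42')   # coerced and validated, user.id == 42
#     User(id='not an int')  # raises pydantic.ValidationError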

def make_dataclass_validator(_cls: Type[Any], config: Type['BaseConfig']) -> 'CallableGenerator':
    """
    Create a pydantic.dataclasses.dataclass from a builtin dataclass to add type validation
    and yield its validators.
    It retrieves the parameters of the dataclass and forwards them to the newly created dataclass.
    """
    dataclass_params = _cls.__dataclass_params__
    stdlib_dataclass_parameters = {param: getattr(dataclass_params, param) for param in dataclass_params.__slots__}
    cls = dataclass(_cls, config=config, **stdlib_dataclass_parameters)
    yield from _get_validators(cls)
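# Typical use (hedged sketch): when a plain stdlib dataclass appears as a field type
# on a `BaseModel`, pydantic can call `make_dataclass_validator` to wrap it in a
# pydantic dataclass and validate incoming values with the yielded validators.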