Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pydantic v2 support with backwards compatibility #86

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion itemadapter/_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
attr = None # type: ignore [assignment]

try:
import pydantic.v1 as pydantic_v1 # pylint: disable=W0611 (unused-import)
import pydantic # pylint: disable=W0611 (unused-import)
except ImportError:
pydantic = None # type: ignore [assignment]
try:
import pydantic as pydantic_v1 # pylint: disable=W0611 (unused-import)
pydantic = None # type: ignore [assignment]
except ImportError:
# Handle the case where neither pydantic.v1 nor pydantic is available
pydantic_v1 = None
64 changes: 60 additions & 4 deletions itemadapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from itemadapter._imports import _scrapy_item_classes, attr
from itemadapter.utils import (
_get_pydantic_model_metadata,
_get_pydantic_v1_model_metadata,
_is_attrs_class,
_is_pydantic_model,
_is_pydantic_v1_model
)

__all__ = [
Expand All @@ -18,6 +20,7 @@
"DataclassAdapter",
"DictAdapter",
"ItemAdapter",
"PydanticV1Adapter",
"PydanticAdapter",
"ScrapyItemAdapter",
]
Expand Down Expand Up @@ -162,17 +165,17 @@ def get_field_names_from_class(cls, item_class: type) -> Optional[List[str]]:
return [a.name for a in dataclasses.fields(item_class)]


class PydanticAdapter(AdapterInterface):
class PydanticV1Adapter(AdapterInterface):
item: Any

@classmethod
def is_item_class(cls, item_class: type) -> bool:
return _is_pydantic_model(item_class)
return _is_pydantic_v1_model(item_class)

@classmethod
def get_field_meta_from_class(cls, item_class: type, field_name: str) -> MappingProxyType:
try:
return _get_pydantic_model_metadata(item_class, field_name)
return _get_pydantic_v1_model_metadata(item_class, field_name)
except KeyError:
raise KeyError(f"{item_class.__name__} does not support field: {field_name}")

Expand Down Expand Up @@ -213,6 +216,58 @@ def __len__(self) -> int:
return len(list(iter(self)))


class PydanticAdapter(AdapterInterface):
item: Any
import pydantic

@classmethod
def is_item_class(cls, item_class: type) -> bool:
return _is_pydantic_model(item_class)

@classmethod
def get_field_meta_from_class(cls, item_class: type, field_name: str) -> MappingProxyType:
try:
return _get_pydantic_model_metadata(item_class, field_name)
except KeyError:
raise KeyError(f"{item_class.__name__} does not support field: {field_name}")

@classmethod
def get_field_names_from_class(cls, item_class: pydantic.BaseModel) -> Optional[List[str]]:
return list(item_class.model_fields.keys()) # type: ignore[attr-defined]

def field_names(self) -> KeysView:
return KeysView(self.item.model_fields)

def __getitem__(self, field_name: str) -> Any:
if field_name in self.item.model_fields:
return getattr(self.item, field_name)
raise KeyError(field_name)

def __setitem__(self, field_name: str, value: Any) -> None:
if field_name in self.item.model_fields:
setattr(self.item, field_name, value)
else:
raise KeyError(f"{self.item.__class__.__name__} does not support field: {field_name}")

def __delitem__(self, field_name: str) -> None:
if field_name in self.item.model_fields:
try:
if hasattr(self.item, field_name):
delattr(self.item, field_name)
else:
raise AttributeError
except AttributeError:
raise KeyError(field_name)
else:
raise KeyError(f"{self.item.__class__.__name__} does not support field: {field_name}")

def __iter__(self) -> Iterator:
return iter(attr for attr in self.item.model_fields if hasattr(self.item, attr))

def __len__(self) -> int:
return len(list(iter(self)))


class _MixinDictScrapyItemAdapter:
_fields_dict: dict
item: Any
Expand Down Expand Up @@ -278,7 +333,8 @@ class ItemAdapter(MutableMapping):
DictAdapter,
DataclassAdapter,
AttrsAdapter,
PydanticAdapter,
PydanticV1Adapter,
PydanticAdapter
]
)

Expand Down
57 changes: 56 additions & 1 deletion itemadapter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from types import MappingProxyType
from typing import Any

from itemadapter._imports import attr, pydantic
from itemadapter._imports import attr, pydantic, pydantic_v1

__all__ = ["is_item", "get_field_meta_from_class"]

Expand All @@ -19,7 +19,62 @@ def _is_pydantic_model(obj: Any) -> bool:
return issubclass(obj, pydantic.BaseModel)


def _is_pydantic_v1_model(obj: Any) -> bool:
if pydantic_v1 is None:
return False
return issubclass(obj, pydantic_v1.BaseModel)


def _get_pydantic_model_metadata(item_model: Any, field_name: str) -> MappingProxyType:
metadata = {}
field = item_model.model_fields[field_name]

for attribute in [
"default",
"default_factory",
"alias",
"alias_priority",
"validation_alias",
"serialization_alias",
"title",
"field_title_generator",
"description",
"examples",
"exclude",
"discriminator",
"deprecated",
"json_schema_extra",
"frozen",
"validate_default",
"repr",
"init",
"init_var",
"kw_only",
"pattern",
"strict",
"coerce_numbers_to_str",
"gt",
"ge",
"lt",
"le",
"multiple_of",
"allow_inf_nan",
"max_digits",
"decimal_places",
"min_length",
"max_length",
"union_mode",
"fail_fast",
]:
if hasattr(field, attribute) and (value := getattr(field, attribute)) is not None:
metadata[attribute] = value
# if field.json_schema_extra is not None:
# metadata.update(field.json_schema_extra)

return MappingProxyType(metadata)


def _get_pydantic_v1_model_metadata(item_model: Any, field_name: str) -> MappingProxyType:
metadata = {}
field = item_model.__fields__[field_name].field_info

Expand Down