change pydntic blog from P1 to Alpha
BuxianChen committed Feb 20, 2024
1 parent d928f24 commit b1f8728
Showing 1 changed file with 173 additions and 4 deletions.
177 changes: 173 additions & 4 deletions _drafts/2024-02-18-pydantic.md
@@ -1,6 +1,6 @@
---
layout: post
title: "(P1) pydantic tutorial"
title: "(Alpha) pydantic tutorial"
date: 2024-02-18 11:10:04 +0800
labels: [python]
---
@@ -19,16 +19,185 @@ Questions:

### V1

#### type hint
#### Example

#### validator
This section walks through a single example that covers most everyday pydantic usage.

```python
from pydantic.v1 import BaseModel, Field, Extra, validator, root_validator
from typing import Annotated, Optional, List


class Request(BaseModel):
    # query, temperature, and other_notes show the different ways to write a type hint
    query: str
    temperature: float = Field(description="the temperature", ge=0.0, lt=2.0)  # pydantic enforces the constraints declared in Field
    other_notes: Annotated[str, Field(description="tools", examples=["calculator", "python"])]
    stop_words: Optional[List[str]] = None

    # some of pydantic's built-in checks
    class Config:
        max_anystr_length = 10  # no string-typed field may be longer than 10 characters
        extra = Extra.forbid  # reject unknown fields

    # pre=True makes this run before the validate_stop_word_length check below
    @validator("stop_words", pre=True)
    def split_stop_words(cls, v):
        if isinstance(v, str):
            return v.split("|")
        return v

    @validator("stop_words")
    def validate_stop_word_length(cls, v):
        # at most 4 stop words; guard against an explicit stop_words=None
        if v is not None and len(v) > 4:
            raise ValueError("stop words more than 4")
        return v  # note: a validator must return the value

    # the same check can be applied to several fields
    @validator("query", "other_notes")
    def validate_min_length(cls, v):
        if len(v) == 0:
            raise ValueError("empty string")
        return v

    # check the data structure as a whole
    # skip_on_failure=True: do not run if a field already failed (its key would be missing from values)
    @root_validator(skip_on_failure=True)
    def validate_context_length(cls, values):
        query = values.get("query")
        other_notes = values.get("other_notes")
        if len(query) + len(other_notes) > 15:
            raise ValueError("context length more than 15")
        return values


req = Request(temperature=1.0, other_notes="note note", query="2+3", stop_words=["2", "3", "4"])
req = Request(temperature=1.0, other_notes="calculate", query="1+1", stop_words="2|3|4")
# err = Request(temperature=1.0, other_notes="calculate", query="1+1", stop_words="2|3|4", xx=2)  # Error!
print(req.dict())  # convert to a dict; in V2 use model_dump
print(Request.schema())  # print the JSON schema; in V2 use model_json_schema
```

Output:

```
# req.dict()
{'query': '1+1',
 'temperature': 1.0,
 'other_notes': 'calculate',
 'stop_words': ['2', '3', '4']}
# Request.schema()
{'title': 'Request',
 'type': 'object',
 'properties': {'query': {'title': 'Query', 'type': 'string'},
  'temperature': {'title': 'Temperature',
   'description': 'the temperature',
   'exclusiveMaximum': 2.0,
   'minimum': 0.0,
   'type': 'number'},
  'other_notes': {'title': 'Other Notes',
   'description': 'tools',
   'examples': ['calculator', 'python'],
   'type': 'string'},
  'stop_words': {'title': 'Stop Words',
   'type': 'array',
   'items': {'type': 'string'}}},
 'required': ['query', 'temperature', 'other_notes'],
 'additionalProperties': False}
```

The same example written for pydantic V2 is shown below. Overall there are quite a few differences, mostly renamed methods and fields.

```python
from pydantic import BaseModel, Field, model_validator, field_validator, ConfigDict
from typing import Annotated, Optional, List


class Request(BaseModel):
    query: str
    temperature: float = Field(description="the temperature", ge=0.0, lt=2.0)  # pydantic enforces the constraints declared in Field
    other_notes: Annotated[str, Field(description="tools", examples=["calculator", "python"])]
    stop_words: Optional[List[str]] = None
    # the Config class became a class attribute: model_config
    # Extra.forbid became the string "forbid"
    model_config = ConfigDict(str_max_length=10, extra="forbid")

    # note: pre=True/False became mode="before"/"after"
    # validator (V1) -> field_validator (V2)
    @field_validator("stop_words", mode="before")
    @classmethod  # add @classmethod, placed below @field_validator
    def split_stop_words(cls, v):
        if isinstance(v, str):
            return v.split("|")
        return v

    @field_validator("stop_words")
    @classmethod
    def validate_stop_word_length(cls, v):
        # guard against an explicit stop_words=None, which also reaches this validator
        if v is not None and len(v) > 4:
            raise ValueError("stop words more than 4")
        return v

    @field_validator("query", "other_notes")
    @classmethod
    def validate_min_length(cls, v):
        if len(v) == 0:
            raise ValueError("empty string")
        return v

    # root_validator -> model_validator
    # mode="after" validators are instance methods (no @classmethod): they receive the validated model
    @model_validator(mode="after")
    def validate_context_length(self):
        query = self.query  # note: V2 uses attribute access; self is a Request
        other_notes = self.other_notes
        if len(query) + len(other_notes) > 15:
            raise ValueError("context length more than 15")
        return self  # an "after" validator must return the model instance


req = Request(temperature=1.0, other_notes="note note", query="2+3", stop_words=["2", "3", "4"])
req = Request(temperature=1.0, other_notes="calculate", query="1+1", stop_words="2|3|4")
# err = Request(temperature=1.0, other_notes="calculate", query="1+1", stop_words="2|3|4", xx=2)  # Error!
print(req.model_dump())  # dict -> model_dump
print(Request.model_json_schema())  # schema -> model_json_schema
```
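To see how the two V2 validator modes behave when something goes wrong, here is a smaller sketch, assuming pydantic 2.x is installed. `Prompt`, `reject_stop_word_in_query`, and `error_summary` are made-up names for illustration; only `field_validator`, `model_validator`, and `ValidationError` come from pydantic.

```python
from typing import List, Optional

from pydantic import BaseModel, ValidationError, field_validator, model_validator


class Prompt(BaseModel):  # hypothetical model, much smaller than Request
    query: str
    stop_words: Optional[List[str]] = None

    @field_validator("stop_words", mode="before")
    @classmethod
    def split_stop_words(cls, v):
        # mode="before": sees the raw input, before the declared type is checked
        return v.split("|") if isinstance(v, str) else v

    @model_validator(mode="after")
    def reject_stop_word_in_query(self):
        # mode="after": sees the fully validated instance, so fields can be compared
        if any(word in self.query for word in self.stop_words or []):
            raise ValueError("query contains a stop word")
        return self


def error_summary(**data):
    """Return (loc, type) for every validation error, or [] if the data is valid."""
    try:
        Prompt(**data)
    except ValidationError as exc:
        return [(err["loc"], err["type"]) for err in exc.errors()]
    return []


print(Prompt(query="1+1", stop_words="a|b").stop_words)  # ['a', 'b']
print(error_summary(query="1+1", stop_words=3))      # field-level error on stop_words
print(error_summary(query="1+1", stop_words="+|="))  # model-level error, empty loc
```

The field-level error is reported under `('stop_words',)`, while an error raised by the model validator has an empty `loc` because it belongs to the model as a whole.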

#### Type Hint & Field & Annotated

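Every attribute of a class that inherits from `BaseModel` must have a type hint. There are three ways to write one:

- A plain type hint only: pydantic validates that the value satisfies the type constraint.
- A plain type hint plus a `Field`: pydantic validates the type constraint and also checks the constraints described in the `Field`.
- `typing.Annotated`: essentially the same as the second approach.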
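The three styles can be put side by side in a small sketch. It only uses names that exist under the same import path in both pydantic 1.x and 2.x (`BaseModel`, `Field`, `ValidationError`); the `Sampling` model, its `top_p` field, and the `failing_fields` helper are made up for illustration.

```python
from typing import Annotated

from pydantic import BaseModel, Field, ValidationError


class Sampling(BaseModel):  # hypothetical model
    query: str                                                   # 1. plain type hint
    temperature: float = Field(description="the temperature", ge=0.0, lt=2.0)  # 2. hint + Field
    top_p: Annotated[float, Field(gt=0.0, le=1.0)]               # 3. Annotated


def failing_fields(**data):
    """Return the sorted names of the fields that failed validation."""
    try:
        Sampling(**data)
    except ValidationError as exc:
        return sorted(str(err["loc"][0]) for err in exc.errors())
    return []


print(failing_fields(query="1+1", temperature=1.0, top_p=0.9))    # []
print(failing_fields(query="1+1", temperature="hot", top_p=0.9))  # ['temperature']
print(failing_fields(query="1+1", temperature=2.0, top_p=1.5))    # ['temperature', 'top_p']
```

The second call fails the type check, and the third fails the `lt=2.0` and `le=1.0` constraints declared through `Field`.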

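Note:

`data: typing.Annotated[T, x]` is an enhancement of a plain type hint, where `T` is the type and `x` is arbitrary data that serves as metadata. At runtime, Python itself validates `data` against neither the type hint nor the metadata; pydantic, however, uses this information to validate data.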

```python
from typing import Annotated

x: Annotated[str, "desc"] = "123"
Annotated[str, "desc"].__metadata__  # ("desc",)
```
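To make the point about metadata concrete, here is a toy, standard-library-only sketch of how a library could read `Annotated` metadata and validate with it. This is not how pydantic is implemented; `MaxLen`, `Note`, and `check` are hypothetical names.

```python
from typing import Annotated, get_type_hints


class MaxLen:
    """Hypothetical metadata marker; not part of typing or pydantic."""

    def __init__(self, limit: int) -> None:
        self.limit = limit


class Note:
    text: Annotated[str, MaxLen(5)]

    def __init__(self, text):
        self.text = text  # Python stores the value without looking at the hint


def check(obj) -> list:
    """Return the violations found by reading the Annotated hints of obj's class."""
    problems = []
    hints = get_type_hints(type(obj), include_extras=True)  # keep the metadata
    for name, hint in hints.items():
        value = getattr(obj, name)
        base = getattr(hint, "__origin__", hint)  # str for Annotated[str, ...]
        if not isinstance(value, base):
            problems.append(f"{name}: expected {base.__name__}")
            continue
        for meta in getattr(hint, "__metadata__", ()):
            if isinstance(meta, MaxLen) and len(value) > meta.limit:
                problems.append(f"{name}: longer than {meta.limit}")
    return problems


note = Note("far too long")  # no error: the interpreter ignores the annotation
print(check(note))           # ['text: longer than 5']
print(check(Note(123)))      # ['text: expected str']
```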

`Field` is actually a function; it returns a `FieldInfo` instance:

```python
from pydantic.v1 import Field
Field(description="the temperature", ge=0.0, lt=2.0)
# FieldInfo(default=PydanticUndefined, description='the temperature', ge=0.0, lt=2.0, extra={})
```

#### Validator

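For the order in which fields are validated, see [https://docs.pydantic.dev/1.10/usage/models/#field-ordering](https://docs.pydantic.dev/1.10/usage/models/#field-ordering). In short, it depends on the order in which the fields are defined, and on the `pre` argument of `validator(pre=True)`.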
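The ordering can be observed directly with a small V1 sketch that records every validator call. `Ordered` and `seen` are made-up names; the import falls back to plain `pydantic` on the assumption that a 1.x install without the `pydantic.v1` namespace might be in use.

```python
try:
    from pydantic.v1 import BaseModel, validator  # V1 API shipped inside pydantic 2.x
except ImportError:
    from pydantic import BaseModel, validator  # plain pydantic 1.x install

seen = []  # (validator, value it received, fields already validated)


class Ordered(BaseModel):  # hypothetical model
    first: int
    second: int

    @validator("first")
    def record_first(cls, v, values):
        seen.append(("first", v, sorted(values)))
        return v

    @validator("second", pre=True)
    def record_second_raw(cls, v, values):
        # pre=True: runs before the int conversion, so v is still the raw input
        seen.append(("second:pre", v, sorted(values)))
        return v

    @validator("second")
    def record_second(cls, v, values):
        seen.append(("second", v, sorted(values)))
        return v


# Keyword order does not matter; the definition order of the fields does.
Ordered(second="2", first="1")
for entry in seen:
    print(entry)
# ('first', 1, [])
# ('second:pre', '2', ['first'])
# ('second', 2, ['first'])
```

A validator therefore only sees, through `values`, the fields defined above its own field.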

#### Config

`Config` exposes pydantic's built-in validation options, whereas a validator is custom validation logic.
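A minimal V1 sketch of `Config` doing the checking with no custom validator at all, reusing the two options from the example above. `Tag` and `error_locations` are made-up names, and the import fallback assumes a 1.x install might lack the `pydantic.v1` namespace.

```python
try:
    from pydantic.v1 import BaseModel, Extra, ValidationError  # V1 API inside pydantic 2.x
except ImportError:
    from pydantic import BaseModel, Extra, ValidationError  # plain pydantic 1.x install


class Tag(BaseModel):  # hypothetical model
    name: str

    class Config:
        max_anystr_length = 10  # built-in length limit for string fields
        extra = Extra.forbid    # built-in rejection of unknown fields


def error_locations(**data):
    """Return the loc of every validation error, or [] if the data is valid."""
    try:
        Tag(**data)
    except ValidationError as exc:
        return [err["loc"] for err in exc.errors()]
    return []


print(error_locations(name="short"))              # []
print(error_locations(name="x" * 11))             # [('name',)]
print(error_locations(name="ok", colour="red"))   # [('colour',)]
```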


### V1 to V2

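The API seems to have changed a lot; I don't understand why one should upgrade from V1 to V2
The API seems to have changed a lot; I don't understand why one should upgrade from V1 to V2 (TODO)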

- `llama_index` ([v0.9.31](https://github.com/run-llama/llama_index/blob/d7839442ab080347291bff0946c1e1ea2a7486ab/llama_index/bridge/pydantic.py), released 2024/1/16) uses V1
- `langchain` (v0.1.0, released 2024/01/06): it seems to be trying to support both V1 and V2, but does it actually use pydantic.v1 everywhere?
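One way a library can stay on the V1 API whichever major version is installed is a small import shim. The sketch below illustrates the idea only; it is not copied from either project, and `pydantic_v1` and `Ping` are made-up names.

```python
try:
    import pydantic.v1 as pydantic_v1  # pydantic 2.x keeps the old API under pydantic.v1
except ImportError:
    import pydantic as pydantic_v1  # assumption: a 1.x install without that namespace


class Ping(pydantic_v1.BaseModel):  # hypothetical model written against the V1 API
    message: str


print(Ping(message="hello").dict())  # {'message': 'hello'}
```

Code written against `pydantic_v1` keeps using V1 names such as `.dict()` and `validator`, which is also how the V1 example at the top of this post imports them.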