图
除非你需要钉枪,否则别用钉枪
如果说 Pydantic AI 代理是锤子,多代理工作流是长柄大锤,那么图就是钉枪。
- 当然,钉枪看起来比锤子酷。
- 但是钉枪比锤子需要更多的设置工作。
- 而且钉枪不会让你成为一个更好的建造者,只会让你成为一个拥有钉枪的建造者。
- 最后(冒着过度使用这个比喻的风险),如果你是像木槌和无类型 Python 这样的中世纪工具的粉丝,你可能不会喜欢钉枪或我们的图方法。(但话说回来,如果你不喜欢 Python 中的类型提示,你可能已经从 Pydantic AI 转而去使用那些玩具代理框架了——祝你好运,当你意识到你需要长柄大锤时,随时可以来借我的。)
简而言之,图是一个强大的工具,但它并非适用于所有工作。在继续之前,请考虑其他多代理方法。
如果你不确定基于图的方法是个好主意,那么它可能是不必要的。
图和有限状态机(FSM)是用于建模、执行、控制和可视化复杂工作流的强大抽象。
与 Pydantic AI 一起,我们开发了 pydantic-graph
—— 一个用于 Python 的异步图和状态机库,其中的节点和边是使用类型提示定义的。
虽然这个库是作为 Pydantic AI 的一部分开发的,但它对 pydantic-ai
没有任何依赖,可以被视为一个纯粹的基于图的状态机库。无论你是否在使用 Pydantic AI,甚至是否在用生成式 AI 构建应用,你都可能会发现它很有用。
pydantic-graph
是为高级用户设计的,大量使用了 Python 的泛型和类型提示。它不像 Pydantic AI 那样对初学者友好。
安装
pydantic-graph
是 pydantic-ai
的必需依赖项,是 pydantic-ai-slim
的可选依赖项,更多信息请参见安装说明。你也可以直接安装它。
pip install pydantic-graph
uv add pydantic-graph
图类型
pydantic-graph
由几个关键组件组成
GraphRunContext
GraphRunContext
— 图运行的上下文,类似于 Pydantic AI 的 RunContext
。它持有图的状态和依赖项,并在节点运行时传递给节点。
GraphRunContext
在其使用的图的状态类型 StateT
上是泛型的。
End
End
— 返回值,用于指示图的运行应该结束。
End
在其使用的图的图返回类型 RunEndT
上是泛型的。
节点
BaseNode
的子类定义了在图中执行的节点。
节点通常是dataclass
,通常由以下部分组成:
节点在以下方面是泛型的:
- state(状态),它必须与包含它们的图的状态具有相同的类型,
StateT
的默认值为None
,所以如果你不使用状态,可以省略这个泛型参数,更多信息请参见有状态图 - deps(依赖),它必须与包含它们的图的依赖具有相同的类型,
DepsT
的默认值为None
,所以如果你不使用依赖,可以省略这个泛型参数,更多信息请参见依赖注入 - graph return type(图返回类型)— 这仅在节点返回
End
时适用。RunEndT
的默认值为 Never,因此如果节点不返回End
,可以省略这个泛型参数,但如果返回,则必须包含。
这是一个图中开始节点或中间节点的示例——它不能结束运行,因为它不返回 End
from dataclasses import dataclass
from pydantic_graph import BaseNode, GraphRunContext
@dataclass
class MyNode(BaseNode[MyState]): # (1)!
foo: int # (2)!
async def run(
self,
ctx: GraphRunContext[MyState], # (3)!
) -> AnotherNode: # (4)!
...
return AnotherNode()
- 本例中的状态是
MyState
(未显示),因此BaseNode
使用MyState
进行参数化。此节点无法结束运行,因此省略了RunEndT
泛型参数,其默认值为Never
。 MyNode
是一个 dataclass,它有一个字段foo
,类型为int
。run
方法接受一个GraphRunContext
参数,同样使用状态MyState
进行参数化。run
方法的返回类型是AnotherNode
(未显示),这用于确定节点的出边。
我们可以扩展 MyNode
,使其在 foo
能被 5 整除时可以选择性地结束运行。
from dataclasses import dataclass
from pydantic_graph import BaseNode, End, GraphRunContext
@dataclass
class MyNode(BaseNode[MyState, None, int]): # (1)!
foo: int
async def run(
self,
ctx: GraphRunContext[MyState],
) -> AnotherNode | End[int]: # (2)!
if self.foo % 5 == 0:
return End(self.foo)
else:
return AnotherNode()
- 我们使用返回类型(本例中为
int
)以及状态来参数化节点。因为泛型参数是仅位置参数,我们必须包含None
作为代表 deps 的第二个参数。 run
方法的返回类型现在是AnotherNode
和End[int]
的联合类型,这允许节点在foo
能被 5 整除时结束运行。
图
Graph
— 这是执行图本身,由一组节点类(即 BaseNode
的子类)组成。
Graph
在以下方面是泛型的:
这是一个简单图的示例
from __future__ import annotations
from dataclasses import dataclass
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
@dataclass
class DivisibleBy5(BaseNode[None, None, int]): # (1)!
foo: int
async def run(
self,
ctx: GraphRunContext,
) -> Increment | End[int]:
if self.foo % 5 == 0:
return End(self.foo)
else:
return Increment(self.foo)
@dataclass
class Increment(BaseNode): # (2)!
foo: int
async def run(self, ctx: GraphRunContext) -> DivisibleBy5:
return DivisibleBy5(self.foo + 1)
fives_graph = Graph(nodes=[DivisibleBy5, Increment]) # (3)!
result = fives_graph.run_sync(DivisibleBy5(4)) # (4)!
print(result.output)
#> 5
DivisibleBy5
节点的状态参数和依赖参数都用None
进行了参数化,因为此图不使用状态或依赖;而因为它可能结束运行,所以返回类型为int
。Increment
节点不返回End
,因此省略了RunEndT
泛型参数;由于图不使用状态,状态参数也可以省略。- 图是通过一个节点序列创建的。
- 图使用
run_sync
同步运行。初始节点是DivisibleBy5(4)
。因为图不使用外部状态或依赖,我们不传递state
或deps
。
(这个例子是完整的,可以“按原样”运行)
此图的 Mermaid 图表可以使用以下代码生成
from graph_example import DivisibleBy5, fives_graph
fives_graph.mermaid_code(start_node=DivisibleBy5)
---
title: fives_graph
---
stateDiagram-v2
[*] --> DivisibleBy5
DivisibleBy5 --> Increment
DivisibleBy5 --> [*]
Increment --> DivisibleBy5
为了在 jupyter-notebook
中可视化图,需要使用 IPython.display
。
from graph_example import DivisibleBy5, fives_graph
from IPython.display import Image, display
display(Image(fives_graph.mermaid_image(start_node=DivisibleBy5)))
有状态图
pydantic-graph
中的“状态”概念提供了一种可选的方式,在图中的节点运行时访问和修改一个对象(通常是 dataclass
或 Pydantic 模型)。如果你将图想象成一条生产线,那么你的状态就是沿着生产线传递并由每个节点在图运行时逐步构建的引擎。
pydantic-graph
提供状态持久化,在每个节点运行后记录状态。(请参见状态持久化。)
这是一个表示自动售货机的图的示例,用户可以投入硬币并选择要购买的商品。
from __future__ import annotations
from dataclasses import dataclass
from rich.prompt import Prompt
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
@dataclass
class MachineState: # (1)!
user_balance: float = 0.0
product: str | None = None
@dataclass
class InsertCoin(BaseNode[MachineState]): # (3)!
async def run(self, ctx: GraphRunContext[MachineState]) -> CoinsInserted: # (16)!
return CoinsInserted(float(Prompt.ask('Insert coins'))) # (4)!
@dataclass
class CoinsInserted(BaseNode[MachineState]):
amount: float # (5)!
async def run(
self, ctx: GraphRunContext[MachineState]
) -> SelectProduct | Purchase: # (17)!
ctx.state.user_balance += self.amount # (6)!
if ctx.state.product is not None: # (7)!
return Purchase(ctx.state.product)
else:
return SelectProduct()
@dataclass
class SelectProduct(BaseNode[MachineState]):
async def run(self, ctx: GraphRunContext[MachineState]) -> Purchase:
return Purchase(Prompt.ask('Select product'))
PRODUCT_PRICES = { # (2)!
'water': 1.25,
'soda': 1.50,
'crisps': 1.75,
'chocolate': 2.00,
}
@dataclass
class Purchase(BaseNode[MachineState, None, None]): # (18)!
product: str
async def run(
self, ctx: GraphRunContext[MachineState]
) -> End | InsertCoin | SelectProduct:
if price := PRODUCT_PRICES.get(self.product): # (8)!
ctx.state.product = self.product # (9)!
if ctx.state.user_balance >= price: # (10)!
ctx.state.user_balance -= price
return End(None)
else:
diff = price - ctx.state.user_balance
print(f'Not enough money for {self.product}, need {diff:0.2f} more')
#> Not enough money for crisps, need 0.75 more
return InsertCoin() # (11)!
else:
print(f'No such product: {self.product}, try again')
return SelectProduct() # (12)!
vending_machine_graph = Graph( # (13)!
nodes=[InsertCoin, CoinsInserted, SelectProduct, Purchase]
)
async def main():
state = MachineState() # (14)!
await vending_machine_graph.run(InsertCoin(), state=state) # (15)!
print(f'purchase successful item={state.product} change={state.user_balance:0.2f}')
#> purchase successful item=crisps change=0.25
- 自动售货机的状态被定义为一个 dataclass,包含用户的余额和他们选择的商品(如果有的话)。
- 一个将产品映射到价格的字典。
InsertCoin
节点,BaseNode
使用MachineState
进行参数化,因为这是此图中使用的状态。InsertCoin
节点提示用户投入硬币。我们通过仅输入一个浮点数金额来简化操作。在你开始认为这也是一个玩具,因为它在节点内使用了 rich 的Prompt.ask
之前,请看下面关于当节点需要外部输入时如何管理控制流的内容。CoinsInserted
节点;同样,这是一个带有一个字段amount
的dataclass
。- 用投入的金额更新用户的余额。
- 如果用户已经选择了商品,则转到
Purchase
,否则转到SelectProduct
。 - 在
Purchase
节点中,如果用户输入了有效的产品,则查找该产品的价格。 - 如果用户确实输入了有效的产品,则在状态中设置该产品,这样我们就不会重新访问
SelectProduct
。 - 如果余额足以购买该产品,则调整余额以反映购买情况,并返回
End
来结束图的运行。我们没有使用运行的返回类型,所以我们用None
调用End
。 - 如果余额不足,则转到
InsertCoin
提示用户投入更多硬币。 - 如果产品无效,则转到
SelectProduct
提示用户重新选择产品。 - 图是通过向
Graph
传递一个节点列表来创建的。节点的顺序不重要,但它可能会影响图表的显示方式。 - 初始化状态。这将被传递给图的运行,并在图运行时被修改。
- 使用初始状态运行图。由于图可以从任何节点开始运行,我们必须传递起始节点——在本例中是
InsertCoin
。Graph.run
返回一个GraphRunResult
,它提供了最终数据和运行的历史记录。 - 节点的
run
方法的返回类型很重要,因为它用于确定节点的出边。这些信息反过来又用于渲染 Mermaid 图表,并在运行时强制执行,以尽快检测不当行为。 CoinsInserted
的run
方法的返回类型是一个联合类型,意味着可能有多条出边。- 与其他节点不同,
Purchase
可以结束运行,因此必须设置RunEndT
泛型参数。在这种情况下,它是None
,因为图运行的返回类型是None
。
(此示例是完整的,可以“按原样”运行——您需要添加 asyncio.run(main())
来运行 main
)
此图的 Mermaid 图表可以使用以下代码生成
from vending_machine import InsertCoin, vending_machine_graph
vending_machine_graph.mermaid_code(start_node=InsertCoin)
由上述代码生成的图表是
---
title: vending_machine_graph
---
stateDiagram-v2
[*] --> InsertCoin
InsertCoin --> CoinsInserted
CoinsInserted --> SelectProduct
CoinsInserted --> Purchase
SelectProduct --> Purchase
Purchase --> InsertCoin
Purchase --> SelectProduct
Purchase --> [*]
有关生成图表的更多信息,请参见下文。
生成式 AI 示例
到目前为止,我们还没有展示一个实际使用 Pydantic AI 或生成式 AI 的图的示例。
在这个例子中,一个代理生成一封给用户的欢迎邮件,另一个代理对该邮件提供反馈。
这个图的结构非常简单
---
title: feedback_graph
---
stateDiagram-v2
[*] --> WriteEmail
WriteEmail --> Feedback
Feedback --> WriteEmail
Feedback --> [*]
from __future__ import annotations as _annotations
from dataclasses import dataclass, field
from pydantic import BaseModel, EmailStr
from pydantic_ai import Agent, format_as_xml
from pydantic_ai.messages import ModelMessage
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
@dataclass
class User:
name: str
email: EmailStr
interests: list[str]
@dataclass
class Email:
subject: str
body: str
@dataclass
class State:
user: User
write_agent_messages: list[ModelMessage] = field(default_factory=list)
email_writer_agent = Agent(
'google-gla:gemini-1.5-pro',
output_type=Email,
system_prompt='Write a welcome email to our tech blog.',
)
@dataclass
class WriteEmail(BaseNode[State]):
email_feedback: str | None = None
async def run(self, ctx: GraphRunContext[State]) -> Feedback:
if self.email_feedback:
prompt = (
f'Rewrite the email for the user:\n'
f'{format_as_xml(ctx.state.user)}\n'
f'Feedback: {self.email_feedback}'
)
else:
prompt = (
f'Write a welcome email for the user:\n'
f'{format_as_xml(ctx.state.user)}'
)
result = await email_writer_agent.run(
prompt,
message_history=ctx.state.write_agent_messages,
)
ctx.state.write_agent_messages += result.new_messages()
return Feedback(result.output)
class EmailRequiresWrite(BaseModel):
feedback: str
class EmailOk(BaseModel):
pass
feedback_agent = Agent[None, EmailRequiresWrite | EmailOk](
'openai:gpt-4o',
output_type=EmailRequiresWrite | EmailOk, # type: ignore
system_prompt=(
'Review the email and provide feedback, email must reference the users specific interests.'
),
)
@dataclass
class Feedback(BaseNode[State, None, Email]):
email: Email
async def run(
self,
ctx: GraphRunContext[State],
) -> WriteEmail | End[Email]:
prompt = format_as_xml({'user': ctx.state.user, 'email': self.email})
result = await feedback_agent.run(prompt)
if isinstance(result.output, EmailRequiresWrite):
return WriteEmail(email_feedback=result.output.feedback)
else:
return End(self.email)
async def main():
user = User(
name='John Doe',
email='john.joe@example.com',
interests=['Haskel', 'Lisp', 'Fortran'],
)
state = State(user)
feedback_graph = Graph(nodes=(WriteEmail, Feedback))
result = await feedback_graph.run(WriteEmail(), state=state)
print(result.output)
"""
Email(
subject='Welcome to our tech blog!',
body='Hello John, Welcome to our tech blog! ...',
)
"""
(此示例是完整的,可以“按原样”运行——您需要添加 asyncio.run(main())
来运行 main
)
迭代图
使用 Graph.iter
进行 async for
迭代
有时你希望在图执行时直接控制或洞察每个节点。最简单的方法是使用 Graph.iter
方法,它返回一个**上下文管理器**,该管理器会产生一个 GraphRun
对象。GraphRun
是一个针对图中节点的可异步迭代对象,允许你在它们执行时记录或修改它们。
这是一个例子
from __future__ import annotations as _annotations
from dataclasses import dataclass
from pydantic_graph import Graph, BaseNode, End, GraphRunContext
@dataclass
class CountDownState:
counter: int
@dataclass
class CountDown(BaseNode[CountDownState, None, int]):
async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]:
if ctx.state.counter <= 0:
return End(ctx.state.counter)
ctx.state.counter -= 1
return CountDown()
count_down_graph = Graph(nodes=[CountDown])
async def main():
state = CountDownState(counter=3)
async with count_down_graph.iter(CountDown(), state=state) as run: # (1)!
async for node in run: # (2)!
print('Node:', node)
#> Node: CountDown()
#> Node: CountDown()
#> Node: CountDown()
#> Node: CountDown()
#> Node: End(data=0)
print('Final output:', run.result.output) # (3)!
#> Final output: 0
Graph.iter(...)
返回一个GraphRun
。- 在这里,我们逐步执行每个节点。
- 一旦图返回一个
End
,循环就会结束,run.result
会变成一个GraphRunResult
,其中包含最终结果(这里是0
)。
手动使用 GraphRun.next(node)
或者,你可以使用 GraphRun.next
方法手动驱动迭代,这允许你传入任何你想要接下来运行的节点。通过这种方式,你可以修改或选择性地跳过节点。
下面是一个刻意设计的例子,它在计数器为 2 时停止,忽略了之后的所有节点运行
from pydantic_graph import End, FullStatePersistence
from count_down import CountDown, CountDownState, count_down_graph
async def main():
state = CountDownState(counter=5)
persistence = FullStatePersistence() # (7)!
async with count_down_graph.iter(
CountDown(), state=state, persistence=persistence
) as run:
node = run.next_node # (1)!
while not isinstance(node, End): # (2)!
print('Node:', node)
#> Node: CountDown()
#> Node: CountDown()
#> Node: CountDown()
#> Node: CountDown()
if state.counter == 2:
break # (3)!
node = await run.next(node) # (4)!
print(run.result) # (5)!
#> None
for step in persistence.history: # (6)!
print('History Step:', step.state, step.state)
#> History Step: CountDownState(counter=5) CountDownState(counter=5)
#> History Step: CountDownState(counter=4) CountDownState(counter=4)
#> History Step: CountDownState(counter=3) CountDownState(counter=3)
#> History Step: CountDownState(counter=2) CountDownState(counter=2)
- 我们首先获取代理图中将要运行的第一个节点。
- 一旦产生了
End
节点,代理运行就完成了;End
的实例不能传递给next
。 - 如果用户决定提前停止,我们就跳出循环。在这种情况下,图的运行将没有真正的最终结果(
run.result
保持为None
)。 - 在每一步,我们调用
await run.next(node)
来运行它并获取下一个节点(或一个End
)。 - 因为我们没有继续运行直到它完成,所以
result
没有被设置。 - 运行的历史记录仍然包含了我们已经执行的步骤。
- 使用
FullStatePersistence
,这样我们就可以显示运行的历史记录,更多信息请参见下面的状态持久化。
状态持久化
有限状态机(FSM)图最大的好处之一是它们如何简化对中断执行的处理。这可能由于多种原因发生:
- 状态机逻辑可能根本上需要暂停——例如,电子商务订单的退货工作流需要等待物品被寄到退货中心,或者因为下一个节点的执行需要用户的输入,所以需要等待一个新的 http 请求。
- 执行时间太长,以至于整个图无法可靠地在一次连续运行中执行完毕——例如,一个可能需要数小时运行的深度研究代理。
- 你希望在不同的进程/硬件实例中并行运行多个图节点(注意:
pydantic-graph
目前尚不支持并行节点执行,请参见 #704)。
试图使传统的控制流(即布尔逻辑和嵌套函数调用)实现与这些使用场景兼容,通常会导致代码脆弱且过于复杂的“意大利面条式代码”,其中用于中断和恢复执行的逻辑主导了整个实现。
为了允许图的运行可以被中断和恢复,pydantic-graph
提供了状态持久化——一个在每个节点运行前后对图运行状态进行快照的系统,从而允许图的运行可以从图中的任何一点恢复。
pydantic-graph
包含了三种状态持久化实现
SimpleStatePersistence
— 简单的内存状态持久化,只保存最新的快照。如果在运行图时没有提供状态持久化实现,则默认使用此实现。FullStatePersistence
— 内存状态持久化,保存一个快照列表。FileStatePersistence
— 基于文件的状态持久化,将快照保存到 JSON 文件中。
在生产应用中,开发者应该通过子类化 BaseStatePersistence
抽象基类来实现自己的状态持久化,例如将运行状态持久化到像 PostgresQL 这样的关系型数据库中。
从高层次来看,StatePersistence
实现的角色是存储和检索 NodeSnapshot
和 EndSnapshot
对象。
graph.iter_from_persistence()
可以用来基于持久化中存储的状态来运行图。
我们可以使用 graph.iter_from_persistence()
和 FileStatePersistence
来运行上面的 count_down_graph
。
正如你在这段代码中看到的,run_node
的运行不需要任何外部应用状态(除了状态持久化),这意味着图可以很容易地由分布式执行和排队系统来执行。
from pathlib import Path
from pydantic_graph import End
from pydantic_graph.persistence.file import FileStatePersistence
from count_down import CountDown, CountDownState, count_down_graph
async def main():
run_id = 'run_abc123'
persistence = FileStatePersistence(Path(f'count_down_{run_id}.json')) # (1)!
state = CountDownState(counter=5)
await count_down_graph.initialize( # (2)!
CountDown(), state=state, persistence=persistence
)
done = False
while not done:
done = await run_node(run_id)
async def run_node(run_id: str) -> bool: # (3)!
persistence = FileStatePersistence(Path(f'count_down_{run_id}.json'))
async with count_down_graph.iter_from_persistence(persistence) as run: # (4)!
node_or_end = await run.next() # (5)!
print('Node:', node_or_end)
#> Node: CountDown()
#> Node: CountDown()
#> Node: CountDown()
#> Node: CountDown()
#> Node: CountDown()
#> Node: End(data=0)
return isinstance(node_or_end, End) # (6)!
- 创建一个
FileStatePersistence
用于启动图。 - 调用
graph.initialize()
在持久化对象中设置初始图状态。 run_node
是一个纯函数,除了运行的 ID 外,不需要访问任何其他进程状态来运行图的下一个节点。- 调用
graph.iter_from_persistence()
创建一个GraphRun
对象,它将根据持久化中存储的状态运行图的下一个节点。这将返回一个节点或一个End
对象。 graph.run()
将返回一个 节点 或一个End
对象。- 检查节点是否为
End
对象,如果是,则图的运行完成。
(此示例是完整的,可以“按原样”运行——您需要添加 asyncio.run(main())
来运行 main
)
示例:人在回路。
如上所述,状态持久化允许图被中断和恢复。其中一个用例是允许用户输入以继续进行。
在这个例子中,一个 AI 向用户提问,用户提供答案,AI 评估答案,如果用户答对了就结束,如果答错了就再问一个问题。
我们不是在单个进程调用中运行整个图,而是通过重复运行进程来运行图,并可选择地通过命令行参数提供问题的答案。
ai_q_and_a_graph.py
— question_graph
定义
from __future__ import annotations as _annotations
from typing import Annotated
from pydantic_graph import Edge
from dataclasses import dataclass, field
from pydantic import BaseModel
from pydantic_graph import (
BaseNode,
End,
Graph,
GraphRunContext,
)
from pydantic_ai import Agent, format_as_xml
from pydantic_ai.messages import ModelMessage
ask_agent = Agent('openai:gpt-4o', output_type=str, instrument=True)
@dataclass
class QuestionState:
question: str | None = None
ask_agent_messages: list[ModelMessage] = field(default_factory=list)
evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)
@dataclass
class Ask(BaseNode[QuestionState]):
"""Generate question using GPT-4o."""
docstring_notes = True
async def run(
self, ctx: GraphRunContext[QuestionState]
) -> Annotated[Answer, Edge(label='Ask the question')]:
result = await ask_agent.run(
'Ask a simple question with a single correct answer.',
message_history=ctx.state.ask_agent_messages,
)
ctx.state.ask_agent_messages += result.new_messages()
ctx.state.question = result.output
return Answer(result.output)
@dataclass
class Answer(BaseNode[QuestionState]):
question: str
async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
answer = input(f'{self.question}: ')
return Evaluate(answer)
class EvaluationResult(BaseModel, use_attribute_docstrings=True):
correct: bool
"""Whether the answer is correct."""
comment: str
"""Comment on the answer, reprimand the user if the answer is wrong."""
evaluate_agent = Agent(
'openai:gpt-4o',
output_type=EvaluationResult,
system_prompt='Given a question and answer, evaluate if the answer is correct.',
)
@dataclass
class Evaluate(BaseNode[QuestionState, None, str]):
answer: str
async def run(
self,
ctx: GraphRunContext[QuestionState],
) -> Annotated[End[str], Edge(label='success')] | Reprimand:
assert ctx.state.question is not None
result = await evaluate_agent.run(
format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
message_history=ctx.state.evaluate_agent_messages,
)
ctx.state.evaluate_agent_messages += result.new_messages()
if result.output.correct:
return End(result.output.comment)
else:
return Reprimand(result.output.comment)
@dataclass
class Reprimand(BaseNode[QuestionState]):
comment: str
async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
print(f'Comment: {self.comment}')
ctx.state.question = None
return Ask()
question_graph = Graph(
nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
)
(这个例子是完整的,可以“按原样”运行)
import sys
from pathlib import Path
from pydantic_graph import End
from pydantic_graph.persistence.file import FileStatePersistence
from pydantic_ai.messages import ModelMessage # noqa: F401
from ai_q_and_a_graph import Ask, question_graph, Evaluate, QuestionState, Answer
async def main():
answer: str | None = sys.argv[1] if len(sys.argv) > 1 else None # (1)!
persistence = FileStatePersistence(Path('question_graph.json')) # (2)!
persistence.set_graph_types(question_graph) # (3)!
if snapshot := await persistence.load_next(): # (4)!
state = snapshot.state
assert answer is not None
node = Evaluate(answer)
else:
state = QuestionState()
node = Ask() # (5)!
async with question_graph.iter(node, state=state, persistence=persistence) as run:
while True:
node = await run.next() # (6)!
if isinstance(node, End): # (7)!
print('END:', node.data)
history = await persistence.load_all() # (8)!
print([e.node for e in history])
break
elif isinstance(node, Answer): # (9)!
print(node.question)
#> What is the capital of France?
break
# otherwise just continue
- 从命令行获取用户的答案(如果提供)。完整的示例请参见问题图示例。
- 创建一个状态持久化实例,
'question_graph.json'
文件可能存在也可能不存在。 - 由于我们在图外部使用持久化接口,我们需要调用
set_graph_types
来为持久化实例设置图的泛型类型StateT
和RunEndT
。这是必需的,以便持久化实例知道如何序列化和反序列化图节点。 - 如果我们之前运行过这个图,
load_next
将返回下一个要运行的节点的快照,这里我们使用该快照中的state
,并用命令行提供的答案创建一个新的Evaluate
节点。 - 如果图之前没有运行过,我们创建一个新的
QuestionState
并从Ask
节点开始。 - 调用
GraphRun.next()
来运行节点。这将返回一个节点或一个End
对象。 - 如果节点是
End
对象,则图的运行完成。End
对象的data
字段包含了evaluate_agent
返回的关于正确答案的评论。 - 为了演示状态持久化,我们调用
load_all
从持久化实例中获取所有快照。这将返回一个Snapshot
对象列表。 - 如果节点是
Answer
对象,我们打印问题并跳出循环以结束进程并等待用户输入。
(此示例是完整的,可以“按原样”运行——您需要添加 asyncio.run(main())
来运行 main
)
关于此图的完整示例,请参见问题图示例。
依赖注入
与 Pydantic AI 一样,pydantic-graph
通过 Graph
和 BaseNode
上的泛型参数以及 GraphRunContext.deps
字段支持依赖注入。
作为依赖注入的一个例子,让我们修改上面的 DivisibleBy5
例子,使用一个 ProcessPoolExecutor
在一个单独的进程中运行计算负载(这是一个刻意设计的例子,ProcessPoolExecutor
在这个例子中实际上不会提高性能)。
from __future__ import annotations
import asyncio
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from pydantic_graph import BaseNode, End, FullStatePersistence, Graph, GraphRunContext
@dataclass
class GraphDeps:
executor: ProcessPoolExecutor
@dataclass
class DivisibleBy5(BaseNode[None, GraphDeps, int]):
foo: int
async def run(
self,
ctx: GraphRunContext[None, GraphDeps],
) -> Increment | End[int]:
if self.foo % 5 == 0:
return End(self.foo)
else:
return Increment(self.foo)
@dataclass
class Increment(BaseNode[None, GraphDeps]):
foo: int
async def run(self, ctx: GraphRunContext[None, GraphDeps]) -> DivisibleBy5:
loop = asyncio.get_running_loop()
compute_result = await loop.run_in_executor(
ctx.deps.executor,
self.compute,
)
return DivisibleBy5(compute_result)
def compute(self) -> int:
return self.foo + 1
fives_graph = Graph(nodes=[DivisibleBy5, Increment])
async def main():
with ProcessPoolExecutor() as executor:
deps = GraphDeps(executor)
result = await fives_graph.run(DivisibleBy5(3), deps=deps, persistence=FullStatePersistence())
print(result.output)
#> 5
# the full history is quite verbose (see below), so we'll just print the summary
print([item.node for item in result.persistence.history])
"""
[
DivisibleBy5(foo=3),
Increment(foo=3),
DivisibleBy5(foo=4),
Increment(foo=4),
DivisibleBy5(foo=5),
End(data=5),
]
"""
(此示例是完整的,可以“按原样”运行——您需要添加 asyncio.run(main())
来运行 main
)
Mermaid 图表
Pydantic Graph 可以为图生成 Mermaid 的 stateDiagram-v2
图表,如上所示。
这些图表可以通过以下方式生成:
Graph.mermaid_code
用于生成图的 Mermaid 代码Graph.mermaid_image
用于使用 mermaid.ink 生成图的图像Graph.mermaid_save
用于使用 mermaid.ink 生成图的图像并将其保存到文件
除了上面展示的图表,你还可以使用以下选项自定义 Mermaid 图表:
Edge
允许你为边应用一个标签BaseNode.docstring_notes
和BaseNode.get_note
允许你为节点添加注释highlighted_nodes
参数允许你在图表中高亮显示特定的节点
综合起来,我们可以编辑上一个 ai_q_and_a_graph.py
示例来:
- 为一些边添加标签
- 为
Ask
节点添加一个注释 - 高亮显示
Answer
节点 - 将图表作为
PNG
图像保存到文件
from typing import Annotated
from pydantic_graph import BaseNode, End, Graph, GraphRunContext, Edge
ask_agent = Agent('openai:gpt-4o', output_type=str, instrument=True)
@dataclass
class QuestionState:
question: str | None = None
ask_agent_messages: list[ModelMessage] = field(default_factory=list)
evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)
@dataclass
class Ask(BaseNode[QuestionState]):
"""Generate question using GPT-4o."""
docstring_notes = True
async def run(
self, ctx: GraphRunContext[QuestionState]
) -> Annotated[Answer, Edge(label='Ask the question')]:
result = await ask_agent.run(
'Ask a simple question with a single correct answer.',
message_history=ctx.state.ask_agent_messages,
)
ctx.state.ask_agent_messages += result.new_messages()
ctx.state.question = result.output
return Answer(result.output)
@dataclass
class Answer(BaseNode[QuestionState]):
question: str
async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
answer = input(f'{self.question}: ')
return Evaluate(answer)
class EvaluationResult(BaseModel, use_attribute_docstrings=True):
correct: bool
"""Whether the answer is correct."""
comment: str
"""Comment on the answer, reprimand the user if the answer is wrong."""
evaluate_agent = Agent(
'openai:gpt-4o',
output_type=EvaluationResult,
system_prompt='Given a question and answer, evaluate if the answer is correct.',
)
@dataclass
class Evaluate(BaseNode[QuestionState, None, str]):
answer: str
async def run(
self,
ctx: GraphRunContext[QuestionState],
) -> Annotated[End[str], Edge(label='success')] | Reprimand:
assert ctx.state.question is not None
result = await evaluate_agent.run(
format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
message_history=ctx.state.evaluate_agent_messages,
)
ctx.state.evaluate_agent_messages += result.new_messages()
if result.output.correct:
return End(result.output.comment)
else:
return Reprimand(result.output.comment)
@dataclass
class Reprimand(BaseNode[QuestionState]):
comment: str
async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
print(f'Comment: {self.comment}')
ctx.state.question = None
return Ask()
question_graph = Graph(
nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
)
(此示例不完整,不能直接运行)
这将生成一个看起来像这样的图像
---
title: question_graph
---
stateDiagram-v2
Ask --> Answer: Ask the question
note right of Ask
Judge the answer.
Decide on next step.
end note
Answer --> Evaluate
Evaluate --> Reprimand
Evaluate --> [*]: success
Reprimand --> Ask
classDef highlighted fill:#fdff32
class Answer highlighted
设置状态图的方向
你可以使用以下值之一来指定状态图的方向:
'TB'
: 从上到下,图表从上到下垂直流动。'LR'
: 从左到右,图表从左到右水平流动。'RL'
: 从右到左,图表从右到左水平流动。'BT'
: 从下到上,图表从下到上垂直流动。
下面是一个如何使用‘从左到右’(LR)而不是默认的‘从上到下’(TB)的示例:
from vending_machine import InsertCoin, vending_machine_graph
vending_machine_graph.mermaid_code(start_node=InsertCoin, direction='LR')
---
title: vending_machine_graph
---
stateDiagram-v2
direction LR
[*] --> InsertCoin
InsertCoin --> CoinsInserted
CoinsInserted --> SelectProduct
CoinsInserted --> Purchase
SelectProduct --> Purchase
Purchase --> InsertCoin
Purchase --> SelectProduct
Purchase --> [*]