图
除非你需要钉枪,否则不要使用钉枪
如果 PydanticAI 代理 是锤子,而 多代理工作流程 是大锤,那么图就是钉枪
- 当然,钉枪看起来比锤子更酷
- 但是钉枪比锤子需要更多的设置
- 而且钉枪不会让你成为更好的建造者,它们只会让你成为一个拥有钉枪的建造者
- 最后,(并且冒着曲解这个比喻的风险),如果你是中世纪工具(如木槌和无类型 Python)的爱好者,你可能不会喜欢钉枪或我们处理图的方法。(但话又说回来,如果你不喜欢 Python 中的类型提示,你可能已经放弃使用 PydanticAI,转而使用玩具代理框架之一了——祝你好运,当你意识到你需要它时,随时借用我的大锤)
简而言之,图是一个强大的工具,但它们并非适用于所有工作的正确工具。在继续之前,请考虑其他 多代理方法。
如果你不确定基于图的方法是否是个好主意,那么它可能是没有必要的。
图和有限状态机 (FSM) 是建模、执行、控制和可视化复杂工作流程的强大抽象。
与 PydanticAI 一起,我们开发了 pydantic-graph
—— 一个用于 Python 的异步图和状态机库,其中节点和边使用类型提示定义。
虽然此库是作为 PydanticAI 的一部分开发的;但它不依赖于 pydantic-ai
,可以被视为纯粹的基于图的状态机库。无论你是否使用 PydanticAI,甚至是否使用 GenAI 构建,你都可能会发现它很有用。
pydantic-graph
专为高级用户设计,并大量使用 Python 泛型和类型提示。它的设计目的不是像 PydanticAI 那样对初学者友好。
安装
pydantic-graph
是 pydantic-ai
的必需依赖项,也是 pydantic-ai-slim
的可选依赖项,有关更多信息,请参阅 安装说明。你也可以直接安装它
pip install pydantic-graph
uv add pydantic-graph
图类型
pydantic-graph
由几个关键组件组成
GraphRunContext
GraphRunContext
—— 图运行的上下文,类似于 PydanticAI 的 RunContext
。它保存图的状态和依赖项,并在节点运行时传递给节点。
GraphRunContext
在其使用的图的状态类型中是泛型的,StateT
。
结束
End
—— 返回值,指示图运行应结束。
End
在其使用的图的图返回值类型中是泛型的,RunEndT
。
节点
BaseNode
的子类定义了图中执行的节点。
节点,通常是 dataclass
es,通常由以下部分组成
节点是泛型的,在
- 状态,其类型必须与它们包含在内的图的状态类型相同,
StateT
的默认值为None
,因此如果你不使用状态,可以省略此泛型参数,有关更多信息,请参阅 有状态图 - deps,其类型必须与它们包含在内的图的 deps 类型相同,
DepsT
的默认值为None
,因此如果你不使用 deps,可以省略此泛型参数,有关更多信息,请参阅 依赖注入 - 图返回值类型 —— 这仅在节点返回
End
时适用。RunEndT
的默认值为 Never,因此如果节点不返回End
,则可以省略此泛型参数,但如果返回End
,则必须包含此参数。
这是一个图中的开始节点或中间节点的示例 —— 它不能结束运行,因为它不返回 End
from dataclasses import dataclass
from pydantic_graph import BaseNode, GraphRunContext
@dataclass
class MyNode(BaseNode[MyState]): # (1)!
foo: int # (2)!
async def run(
self,
ctx: GraphRunContext[MyState], # (3)!
) -> AnotherNode: # (4)!
...
return AnotherNode()
- 此示例中的状态为
MyState
(未显示),因此BaseNode
使用MyState
参数化。此节点无法结束运行,因此RunEndT
泛型参数被省略,并默认为Never
。 MyNode
是一个 dataclass,并且具有单个字段foo
,一个int
。run
方法接受GraphRunContext
参数,同样使用状态MyState
参数化。run
方法的返回类型为AnotherNode
(未显示),这用于确定节点的传出边。
我们可以扩展 MyNode
以在 foo
可被 5 整除时选择性地结束运行
from dataclasses import dataclass
from pydantic_graph import BaseNode, End, GraphRunContext
@dataclass
class MyNode(BaseNode[MyState, None, int]): # (1)!
foo: int
async def run(
self,
ctx: GraphRunContext[MyState],
) -> AnotherNode | End[int]: # (2)!
if self.foo % 5 == 0:
return End(self.foo)
else:
return AnotherNode()
- 我们使用返回类型(在本例中为
int
)以及状态来参数化节点。由于泛型参数是仅限位置的,因此我们必须包含None
作为表示 deps 的第二个参数。 run
方法的返回类型现在是AnotherNode
和End[int]
的联合,这允许节点在foo
可被 5 整除时结束运行。
图
Graph
—— 这是执行图本身,由一组 节点类(即 BaseNode
子类)组成。
Graph
是泛型的,在
这是一个简单图的示例
from __future__ import annotations
from dataclasses import dataclass
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
@dataclass
class DivisibleBy5(BaseNode[None, None, int]): # (1)!
foo: int
async def run(
self,
ctx: GraphRunContext,
) -> Increment | End[int]:
if self.foo % 5 == 0:
return End(self.foo)
else:
return Increment(self.foo)
@dataclass
class Increment(BaseNode): # (2)!
foo: int
async def run(self, ctx: GraphRunContext) -> DivisibleBy5:
return DivisibleBy5(self.foo + 1)
fives_graph = Graph(nodes=[DivisibleBy5, Increment]) # (3)!
result = fives_graph.run_sync(DivisibleBy5(4)) # (4)!
print(result.output)
#> 5
DivisibleBy5
节点使用None
参数化状态参数,使用None
参数化 deps 参数,因为此图不使用状态或 deps,并使用int
参数化,因为它可能结束运行。Increment
节点不返回End
,因此省略了RunEndT
泛型参数,状态也可以省略,因为图不使用状态。- 图是使用节点序列创建的。
- 图使用
run_sync
同步运行。初始节点为DivisibleBy5(4)
。由于图不使用外部状态或 deps,因此我们不传递state
或deps
。
(此示例是完整的,可以使用 Python 3.10+ “按原样” 运行)
可以使用以下代码为该图生成 mermaid 图
from graph_example import DivisibleBy5, fives_graph
fives_graph.mermaid_code(start_node=DivisibleBy5)
---
title: fives_graph
---
stateDiagram-v2
[*] --> DivisibleBy5
DivisibleBy5 --> Increment
DivisibleBy5 --> [*]
Increment --> DivisibleBy5
为了在 jupyter-notebook
中可视化图,需要使用 IPython.display
from graph_example import DivisibleBy5, fives_graph
from IPython.display import Image, display
display(Image(fives_graph.mermaid_image(start_node=DivisibleBy5)))
有状态图
pydantic-graph
中的 “状态” 概念提供了一种可选方式,用于在图中节点运行时访问和更改对象(通常是 dataclass
或 Pydantic 模型)。如果你将图视为生产线,那么你的状态就是沿线传递并由每个节点在图运行时构建的引擎。
未来,我们计划扩展 pydantic-graph
以提供状态持久化,并在每个节点运行后记录状态,请参阅 #695。
这是一个表示自动售货机的图的示例,用户可以在其中投入硬币并选择要购买的产品。
from __future__ import annotations
from dataclasses import dataclass
from rich.prompt import Prompt
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
@dataclass
class MachineState: # (1)!
user_balance: float = 0.0
product: str | None = None
@dataclass
class InsertCoin(BaseNode[MachineState]): # (3)!
async def run(self, ctx: GraphRunContext[MachineState]) -> CoinsInserted: # (16)!
return CoinsInserted(float(Prompt.ask('Insert coins'))) # (4)!
@dataclass
class CoinsInserted(BaseNode[MachineState]):
amount: float # (5)!
async def run(
self, ctx: GraphRunContext[MachineState]
) -> SelectProduct | Purchase: # (17)!
ctx.state.user_balance += self.amount # (6)!
if ctx.state.product is not None: # (7)!
return Purchase(ctx.state.product)
else:
return SelectProduct()
@dataclass
class SelectProduct(BaseNode[MachineState]):
async def run(self, ctx: GraphRunContext[MachineState]) -> Purchase:
return Purchase(Prompt.ask('Select product'))
PRODUCT_PRICES = { # (2)!
'water': 1.25,
'soda': 1.50,
'crisps': 1.75,
'chocolate': 2.00,
}
@dataclass
class Purchase(BaseNode[MachineState, None, None]): # (18)!
product: str
async def run(
self, ctx: GraphRunContext[MachineState]
) -> End | InsertCoin | SelectProduct:
if price := PRODUCT_PRICES.get(self.product): # (8)!
ctx.state.product = self.product # (9)!
if ctx.state.user_balance >= price: # (10)!
ctx.state.user_balance -= price
return End(None)
else:
diff = price - ctx.state.user_balance
print(f'Not enough money for {self.product}, need {diff:0.2f} more')
#> Not enough money for crisps, need 0.75 more
return InsertCoin() # (11)!
else:
print(f'No such product: {self.product}, try again')
return SelectProduct() # (12)!
vending_machine_graph = Graph( # (13)!
nodes=[InsertCoin, CoinsInserted, SelectProduct, Purchase]
)
async def main():
state = MachineState() # (14)!
await vending_machine_graph.run(InsertCoin(), state=state) # (15)!
print(f'purchase successful item={state.product} change={state.user_balance:0.2f}')
#> purchase successful item=crisps change=0.25
- 自动售货机的状态定义为一个 dataclass,其中包含用户的余额以及他们已选择的产品(如果有)。
- 产品到价格的字典映射。
InsertCoin
节点,BaseNode
使用MachineState
参数化,因为这是此图中使用的状态。InsertCoin
节点提示用户投入硬币。我们通过仅输入货币金额(浮点数)来保持简单。在你开始认为这也是一个玩具,因为它在节点内使用 rich 的Prompt.ask
之前,请参阅 下方,了解当节点需要外部输入时如何管理控制流。CoinsInserted
节点;同样,这是一个dataclass
,带有一个字段amount
。- 使用投入的金额更新用户的余额。
- 如果用户已经选择了产品,则转到
Purchase
,否则转到SelectProduct
。 - 在
Purchase
节点中,如果用户输入了有效产品,则查找产品的价格。 - 如果用户确实输入了有效产品,则在状态中设置产品,以便我们不再访问
SelectProduct
。 - 如果余额足以购买产品,则调整余额以反映购买情况,并返回
End
以结束图。我们未使用运行返回类型,因此我们使用None
调用End
。 - 如果余额不足,则转到
InsertCoin
以提示用户投入更多硬币。 - 如果产品无效,则转到
SelectProduct
以提示用户再次选择产品。 - 通过将节点列表传递给
Graph
来创建图。节点的顺序并不重要,但它可能会影响 图 的显示方式。 - 初始化状态。这将传递给图运行并在图运行时发生变化。
- 使用初始状态运行图。由于图可以从任何节点运行,因此我们必须传递起始节点 —— 在本例中为
InsertCoin
。Graph.run
返回一个GraphRunResult
,该结果提供最终数据和运行历史记录。 - 节点的
run
方法的返回类型很重要,因为它用于确定节点的传出边。此信息反过来用于渲染 mermaid 图,并在运行时强制执行以尽快检测到错误行为。 CoinsInserted
的run
方法的返回类型是联合,这意味着可能有多个传出边。- 与其他节点不同,
Purchase
可以结束运行,因此必须设置RunEndT
泛型参数。在本例中,它是None
,因为图运行返回类型为None
。
(此示例是完整的,可以使用 Python 3.10+ “按原样” 运行 —— 你需要添加 asyncio.run(main())
来运行 main
)
可以使用以下代码为该图生成 mermaid 图
from vending_machine import InsertCoin, vending_machine_graph
vending_machine_graph.mermaid_code(start_node=InsertCoin)
上述代码生成的图是
---
title: vending_machine_graph
---
stateDiagram-v2
[*] --> InsertCoin
InsertCoin --> CoinsInserted
CoinsInserted --> SelectProduct
CoinsInserted --> Purchase
SelectProduct --> Purchase
Purchase --> InsertCoin
Purchase --> SelectProduct
Purchase --> [*]
有关生成图的更多信息,请参阅 下方。
GenAI 示例
到目前为止,我们还没有展示一个实际使用 PydanticAI 或 GenAI 的图的示例。
在此示例中,一个代理生成一封欢迎电子邮件给用户,另一个代理提供有关该电子邮件的反馈。
此图具有非常简单的结构
---
title: feedback_graph
---
stateDiagram-v2
[*] --> WriteEmail
WriteEmail --> Feedback
Feedback --> WriteEmail
Feedback --> [*]
from __future__ import annotations as _annotations
from dataclasses import dataclass, field
from pydantic import BaseModel, EmailStr
from pydantic_ai import Agent
from pydantic_ai.format_as_xml import format_as_xml
from pydantic_ai.messages import ModelMessage
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
@dataclass
class User:
name: str
email: EmailStr
interests: list[str]
@dataclass
class Email:
subject: str
body: str
@dataclass
class State:
user: User
write_agent_messages: list[ModelMessage] = field(default_factory=list)
email_writer_agent = Agent(
'google-vertex:gemini-1.5-pro',
result_type=Email,
system_prompt='Write a welcome email to our tech blog.',
)
@dataclass
class WriteEmail(BaseNode[State]):
email_feedback: str | None = None
async def run(self, ctx: GraphRunContext[State]) -> Feedback:
if self.email_feedback:
prompt = (
f'Rewrite the email for the user:\n'
f'{format_as_xml(ctx.state.user)}\n'
f'Feedback: {self.email_feedback}'
)
else:
prompt = (
f'Write a welcome email for the user:\n'
f'{format_as_xml(ctx.state.user)}'
)
result = await email_writer_agent.run(
prompt,
message_history=ctx.state.write_agent_messages,
)
ctx.state.write_agent_messages += result.all_messages()
return Feedback(result.data)
class EmailRequiresWrite(BaseModel):
feedback: str
class EmailOk(BaseModel):
pass
feedback_agent = Agent[None, EmailRequiresWrite | EmailOk](
'openai:gpt-4o',
result_type=EmailRequiresWrite | EmailOk, # type: ignore
system_prompt=(
'Review the email and provide feedback, email must reference the users specific interests.'
),
)
@dataclass
class Feedback(BaseNode[State, None, Email]):
email: Email
async def run(
self,
ctx: GraphRunContext[State],
) -> WriteEmail | End[Email]:
prompt = format_as_xml({'user': ctx.state.user, 'email': self.email})
result = await feedback_agent.run(prompt)
if isinstance(result.data, EmailRequiresWrite):
return WriteEmail(email_feedback=result.data.feedback)
else:
return End(self.email)
async def main():
user = User(
name='John Doe',
email='john.joe@example.com',
interests=['Haskel', 'Lisp', 'Fortran'],
)
state = State(user)
feedback_graph = Graph(nodes=(WriteEmail, Feedback))
result = await feedback_graph.run(WriteEmail(), state=state)
print(result.output)
"""
Email(
subject='Welcome to our tech blog!',
body='Hello John, Welcome to our tech blog! ...',
)
"""
(此示例是完整的,可以使用 Python 3.10+ “按原样” 运行 —— 你需要添加 asyncio.run(main())
来运行 main
)
迭代图
使用 Graph.iter
进行 async for
迭代
有时你希望在图执行时直接控制或深入了解每个节点。最简单的方法是使用 Graph.iter
方法,该方法返回一个上下文管理器,该管理器产生一个 GraphRun
对象。GraphRun
是图中节点的异步可迭代对象,允许你在节点执行时记录或修改它们。
这是一个例子
from __future__ import annotations as _annotations
from dataclasses import dataclass
from pydantic_graph import Graph, BaseNode, End, GraphRunContext
@dataclass
class CountDownState:
counter: int
@dataclass
class CountDown(BaseNode[CountDownState, None, int]):
async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]:
if ctx.state.counter <= 0:
return End(ctx.state.counter)
ctx.state.counter -= 1
return CountDown()
count_down_graph = Graph(nodes=[CountDown])
async def main():
state = CountDownState(counter=3)
async with count_down_graph.iter(CountDown(), state=state) as run: # (1)!
async for node in run: # (2)!
print('Node:', node)
#> Node: CountDown()
#> Node: CountDown()
#> Node: CountDown()
#> Node: End(data=0)
print('Final result:', run.result.output) # (3)!
#> Final result: 0
Graph.iter(...)
返回一个GraphRun
。- 在这里,我们逐步遍历每个节点,因为它是执行的。
- 一旦图返回一个
End
,循环结束,并且run.final_result
变为一个GraphRunResult
,其中包含最终结果(此处为0
)。
手动使用 GraphRun.next(node)
或者,你可以使用 GraphRun.next
方法手动驱动迭代,该方法允许你传入你想要接下来运行的任何节点。你可以通过这种方式修改或选择性地跳过节点。
下面是一个人为的示例,它在计数器为 2 时停止,忽略该值之外的任何节点运行
from pydantic_graph import End, FullStatePersistence
from count_down import CountDown, CountDownState, count_down_graph
async def main():
state = CountDownState(counter=5)
persistence = FullStatePersistence() # (7)!
async with count_down_graph.iter(
CountDown(), state=state, persistence=persistence
) as run:
node = run.next_node # (1)!
while not isinstance(node, End): # (2)!
print('Node:', node)
#> Node: CountDown()
#> Node: CountDown()
#> Node: CountDown()
#> Node: CountDown()
if state.counter == 2:
break # (3)!
node = await run.next(node) # (4)!
print(run.result) # (5)!
#> None
for step in persistence.history: # (6)!
print('History Step:', step.state, step.state)
#> History Step: CountDownState(counter=5) CountDownState(counter=5)
#> History Step: CountDownState(counter=4) CountDownState(counter=4)
#> History Step: CountDownState(counter=3) CountDownState(counter=3)
#> History Step: CountDownState(counter=2) CountDownState(counter=2)
- 我们首先获取将在代理图中运行的第一个节点。
- 一旦生成
End
节点,代理运行就完成;End
的实例无法传递给next
。 - 如果用户决定提前停止,我们将跳出循环。在这种情况下,图运行将没有真正的最终结果(
run.final_result
仍然为None
)。 - 在每个步骤中,我们调用
await run.next(node)
来运行它并获取下一个节点(或End
)。 - 因为我们没有继续运行直到完成,所以
result
未设置。 - 运行的历史记录仍然填充了我们到目前为止执行的步骤。
- 使用
FullStatePersistence
,以便我们可以显示运行的历史记录,有关更多信息,请参阅下面的 状态持久化。
状态持久化
有限状态机 (FSM) 图的最大好处之一是它们如何简化中断执行的处理。这可能是由于多种原因造成的
- 状态机逻辑可能从根本上需要暂停 —— 例如,电子商务订单的退货工作流程需要等待物品被寄到退货中心,或者因为下一个节点的执行需要来自用户的输入,因此需要等待新的 http 请求,
- 执行时间太长,以至于整个图无法在单个连续运行中可靠地执行 —— 例如,一个深度研究代理可能需要数小时才能运行,
- 你希望在不同的进程/硬件实例中并行运行多个图节点(注意:
pydantic-graph
尚不支持并行节点执行,请参阅 #704)。
尝试使传统的控制流(即,布尔逻辑和嵌套函数调用)实现与这些使用场景兼容通常会导致脆弱且过于复杂的意大利面条式代码,其中中断和恢复执行所需的逻辑主导了实现。
为了允许图运行被中断和恢复,pydantic-graph
提供了状态持久化 —— 一种在每个节点运行之前和之后快照图运行状态的系统,允许从图中的任何点恢复图运行。
pydantic-graph
包括三种状态持久化实现
SimpleStatePersistence
—— 简单的内存状态持久化,仅保存最新的快照。如果在运行图时未提供状态持久化实现,则默认使用此实现。FullStatePersistence
—— 内存状态持久化,保存快照列表。FileStatePersistence
—— 基于文件的状态持久化,将快照保存到 JSON 文件。
在生产应用程序中,开发人员应通过子类化 BaseStatePersistence
抽象基类来实现自己的状态持久化,这可能会将运行持久化到关系数据库(如 PostgresQL)中。
在高层次上,StatePersistence
实现的作用是存储和检索 NodeSnapshot
和 EndSnapshot
对象。
graph.iter_from_persistence()
可用于基于持久化中存储的状态运行图。
我们可以使用 graph.iter_from_persistence()
和 FileStatePersistence
从 上方 运行 count_down_graph
。
正如你在此代码中看到的,run_node
运行不需要外部应用程序状态(除了状态持久化),这意味着图可以轻松地通过分布式执行和排队系统执行。
from pathlib import Path
from pydantic_graph import End
from pydantic_graph.persistence.file import FileStatePersistence
from count_down import CountDown, CountDownState, count_down_graph
async def main():
run_id = 'run_abc123'
persistence = FileStatePersistence(Path(f'count_down_{run_id}.json')) # (1)!
state = CountDownState(counter=5)
await count_down_graph.initialize( # (2)!
CountDown(), state=state, persistence=persistence
)
done = False
while not done:
done = await run_node(run_id)
async def run_node(run_id: str) -> bool: # (3)!
persistence = FileStatePersistence(Path(f'count_down_{run_id}.json'))
async with count_down_graph.iter_from_persistence(persistence) as run: # (4)!
node_or_end = await run.next() # (5)!
print('Node:', node_or_end)
#> Node: CountDown()
#> Node: CountDown()
#> Node: CountDown()
#> Node: CountDown()
#> Node: CountDown()
#> Node: End(data=0)
return isinstance(node_or_end, End) # (6)!
- 创建一个
FileStatePersistence
以用于启动图。 - 调用
graph.initialize()
以在持久化对象中设置初始图状态。 run_node
是一个纯函数,除了运行 ID 之外,不需要访问任何其他进程状态即可运行图的下一个节点。- 调用
graph.iter_from_persistence()
创建一个GraphRun
对象,该对象将从持久化中存储的状态运行图的下一个节点。这将返回一个节点或一个End
对象。 graph.run()
将返回一个 node 或一个End
对象。- 检查节点是否为
End
对象,如果是,则图运行完成。
(此示例是完整的,可以使用 Python 3.10+ “按原样” 运行 —— 你需要添加 asyncio.run(main())
来运行 main
)
示例:人机环路。
如上所述,状态持久化允许图被中断和恢复。其中一个用例是允许用户输入继续。
在此示例中,AI 向用户提出一个问题,用户提供答案,AI 评估答案,如果用户回答正确则结束,如果回答错误则提出另一个问题。
我们不是在单个进程调用中运行整个图,而是通过重复运行进程来运行图,可以选择性地将问题的答案作为命令行参数提供。
ai_q_and_a_graph.py
—— question_graph
定义
from __future__ import annotations as _annotations
from dataclasses import dataclass, field
from groq import BaseModel
from pydantic_graph import (
BaseNode,
End,
Graph,
GraphRunContext,
)
from pydantic_ai import Agent
from pydantic_ai.format_as_xml import format_as_xml
from pydantic_ai.messages import ModelMessage
ask_agent = Agent('openai:gpt-4o', result_type=str, instrument=True)
@dataclass
class QuestionState:
question: str | None = None
ask_agent_messages: list[ModelMessage] = field(default_factory=list)
evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)
@dataclass
class Ask(BaseNode[QuestionState]):
async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer:
result = await ask_agent.run(
'Ask a simple question with a single correct answer.',
message_history=ctx.state.ask_agent_messages,
)
ctx.state.ask_agent_messages += result.all_messages()
ctx.state.question = result.data
return Answer(result.data)
@dataclass
class Answer(BaseNode[QuestionState]):
question: str
async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
answer = input(f'{self.question}: ')
return Evaluate(answer)
class EvaluationResult(BaseModel, use_attribute_docstrings=True):
correct: bool
"""Whether the answer is correct."""
comment: str
"""Comment on the answer, reprimand the user if the answer is wrong."""
evaluate_agent = Agent(
'openai:gpt-4o',
result_type=EvaluationResult,
system_prompt='Given a question and answer, evaluate if the answer is correct.',
)
@dataclass
class Evaluate(BaseNode[QuestionState, None, str]):
answer: str
async def run(
self,
ctx: GraphRunContext[QuestionState],
) -> End[str] | Reprimand:
assert ctx.state.question is not None
result = await evaluate_agent.run(
format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
message_history=ctx.state.evaluate_agent_messages,
)
ctx.state.evaluate_agent_messages += result.all_messages()
if result.data.correct:
return End(result.data.comment)
else:
return Reprimand(result.data.comment)
@dataclass
class Reprimand(BaseNode[QuestionState]):
comment: str
async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
print(f'Comment: {self.comment}')
ctx.state.question = None
return Ask()
question_graph = Graph(
nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
)
(此示例是完整的,可以使用 Python 3.10+ “按原样” 运行)
import sys
from pathlib import Path
from pydantic_graph import End
from pydantic_graph.persistence.file import FileStatePersistence
from pydantic_ai.messages import ModelMessage # noqa: F401
from ai_q_and_a_graph import Ask, question_graph, Evaluate, QuestionState, Answer
async def main():
answer: str | None = sys.argv[2] if len(sys.argv) > 2 else None # (1)!
persistence = FileStatePersistence(Path('question_graph.json')) # (2)!
persistence.set_graph_types(question_graph) # (3)!
if snapshot := await persistence.load_next(): # (4)!
state = snapshot.state
assert answer is not None
node = Evaluate(answer)
else:
state = QuestionState()
node = Ask() # (5)!
async with question_graph.iter(node, state=state, persistence=persistence) as run:
while True:
node = await run.next() # (6)!
if isinstance(node, End): # (7)!
print('END:', node.data)
history = await persistence.load_all() # (8)!
print([e.node for e in history])
break
elif isinstance(node, Answer): # (9)!
print(node.question)
#> What is the capital of France?
break
# otherwise just continue
- 从命令行获取用户的答案(如果提供)。有关完整示例,请参阅 问题图示例。
- 创建一个状态持久化实例,
'question_graph.json'
文件可能已经存在,也可能不存在。 - 由于我们在图外部使用 持久化接口,因此我们需要调用
set_graph_types
以设置持久化实例的图泛型类型StateT
和RunEndT
。这对于允许持久化实例知道如何序列化和反序列化图节点是必要的。 - 如果我们之前运行过图,
load_next
将返回要运行的下一个节点的快照,此处我们使用该快照中的state
,并使用命令行上提供的答案创建一个新的Evaluate
节点。 - 如果图之前未运行过,我们创建一个新的
QuestionState
并从Ask
节点开始。 - 调用
GraphRun.next()
以运行节点。这将返回一个节点或一个End
对象。 - 如果节点是
End
对象,则图运行完成。End
对象的data
字段包含evaluate_agent
返回的关于正确答案的注释。 - 为了演示状态持久化,我们调用
load_all
以从持久化实例获取所有快照。这将返回Snapshot
对象列表。 - 如果节点是
Answer
对象,我们打印问题并跳出循环以结束进程并等待用户输入。
(此示例是完整的,可以使用 Python 3.10+ “按原样” 运行 —— 你需要添加 asyncio.run(main(answer))
来运行 main
)
有关此图的完整示例,请参阅 问题图示例。
依赖注入
与 PydanticAI 一样,pydantic-graph
通过 Graph
和 BaseNode
上的泛型参数以及 GraphRunContext.deps
字段支持依赖注入。
作为依赖注入的示例,让我们修改 上面 的 DivisibleBy5
示例,以使用 ProcessPoolExecutor
在单独的进程中运行计算负载(这是一个人为的示例,ProcessPoolExecutor
实际上不会在此示例中提高性能)
from __future__ import annotations
import asyncio
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
@dataclass
class GraphDeps:
executor: ProcessPoolExecutor
@dataclass
class DivisibleBy5(BaseNode[None, GraphDeps, int]):
foo: int
async def run(
self,
ctx: GraphRunContext[None, GraphDeps],
) -> Increment | End[int]:
if self.foo % 5 == 0:
return End(self.foo)
else:
return Increment(self.foo)
@dataclass
class Increment(BaseNode[None, GraphDeps]):
foo: int
async def run(self, ctx: GraphRunContext[None, GraphDeps]) -> DivisibleBy5:
loop = asyncio.get_running_loop()
compute_result = await loop.run_in_executor(
ctx.deps.executor,
self.compute,
)
return DivisibleBy5(compute_result)
def compute(self) -> int:
return self.foo + 1
fives_graph = Graph(nodes=[DivisibleBy5, Increment])
async def main():
with ProcessPoolExecutor() as executor:
deps = GraphDeps(executor)
result = await fives_graph.run(DivisibleBy5(3), deps=deps)
print(result.output)
#> 5
# the full history is quite verbose (see below), so we'll just print the summary
print([item.data_snapshot() for item in result.history])
"""
[
DivisibleBy5(foo=3),
Increment(foo=3),
DivisibleBy5(foo=4),
Increment(foo=4),
DivisibleBy5(foo=5),
End(data=5),
]
"""
(此示例是完整的,可以使用 Python 3.10+ “按原样” 运行 —— 你需要添加 asyncio.run(main())
来运行 main
)
Mermaid 图
Pydantic Graph 可以为图生成 mermaid stateDiagram-v2
图,如上所示。
这些图可以使用以下方法生成
Graph.mermaid_code
以生成图的 mermaid 代码Graph.mermaid_image
以使用 mermaid.ink 生成图的图像Graph.mermaid_save
以使用 mermaid.ink 生成图的图像并将其保存到文件
除了上面显示的图之外,你还可以使用以下选项自定义 mermaid 图
Edge
允许你将标签应用于边BaseNode.docstring_notes
和BaseNode.get_note
允许你向节点添加注释highlighted_nodes
参数允许你在图中突出显示特定节点
将它们放在一起,我们可以编辑最后一个 ai_q_and_a_graph.py
示例以
- 向某些边添加标签
- 向
Ask
节点添加注释 - 突出显示
Answer
节点 - 将图另存为
PNG
图像文件
...
from typing import Annotated
from pydantic_graph import BaseNode, End, Graph, GraphRunContext, Edge
...
@dataclass
class Ask(BaseNode[QuestionState]):
"""Generate question using GPT-4o."""
docstring_notes = True
async def run(
self, ctx: GraphRunContext[QuestionState]
) -> Annotated[Answer, Edge(label='Ask the question')]:
...
...
@dataclass
class Evaluate(BaseNode[QuestionState]):
answer: str
async def run(
self,
ctx: GraphRunContext[QuestionState],
) -> Annotated[End[str], Edge(label='success')] | Reprimand:
...
...
question_graph.mermaid_save('image.png', highlighted_nodes=[Answer])
(此示例不完整,无法直接运行)
这将生成如下所示的图像
---
title: question_graph
---
stateDiagram-v2
Ask --> Answer: Ask the question
note right of Ask
Judge the answer.
Decide on next step.
end note
Answer --> Evaluate
Evaluate --> Reprimand
Evaluate --> [*]: success
Reprimand --> Ask
classDef highlighted fill:#fdff32
class Answer highlighted
设置状态图的方向
你可以使用以下值之一指定状态图的方向
'TB'
:从上到下,图垂直地从上到下流动。'LR'
:从左到右,图水平地从左到右流动。'RL'
:从右到左,图水平地从右到左流动。'BT'
:从下到上,图垂直地从下到上流动。
这是一个如何使用 “从左到右”(LR) 而不是默认的 “从上到下”(TB) 的示例
from vending_machine import InsertCoin, vending_machine_graph
vending_machine_graph.mermaid_code(start_node=InsertCoin, direction='LR')
---
title: vending_machine_graph
---
stateDiagram-v2
direction LR
[*] --> InsertCoin
InsertCoin --> CoinsInserted
CoinsInserted --> SelectProduct
CoinsInserted --> Purchase
SelectProduct --> Purchase
Purchase --> InsertCoin
Purchase --> SelectProduct
Purchase --> [*]