跳到内容

除非你需要钉枪,否则不要使用钉枪

如果 PydanticAI 代理 是锤子,而 多代理工作流程 是大锤,那么图就是钉枪

  • 当然,钉枪看起来比锤子更酷
  • 但是钉枪比锤子需要更多的设置
  • 而且钉枪不会让你成为更好的建造者,它们只会让你成为一个拥有钉枪的建造者
  • 最后,(并且冒着曲解这个比喻的风险),如果你是中世纪工具(如木槌和无类型 Python)的爱好者,你可能不会喜欢钉枪或我们处理图的方法。(但话又说回来,如果你不喜欢 Python 中的类型提示,你可能已经放弃使用 PydanticAI,转而使用玩具代理框架之一了——祝你好运,当你意识到你需要它时,随时借用我的大锤)

简而言之,图是一个强大的工具,但它们并非适用于所有工作的正确工具。在继续之前,请考虑其他 多代理方法

如果你不确定基于图的方法是否是个好主意,那么它可能是没有必要的。

图和有限状态机 (FSM) 是建模、执行、控制和可视化复杂工作流程的强大抽象。

与 PydanticAI 一起,我们开发了 pydantic-graph —— 一个用于 Python 的异步图和状态机库,其中节点和边使用类型提示定义。

虽然此库是作为 PydanticAI 的一部分开发的;但它不依赖于 pydantic-ai,可以被视为纯粹的基于图的状态机库。无论你是否使用 PydanticAI,甚至是否使用 GenAI 构建,你都可能会发现它很有用。

pydantic-graph 专为高级用户设计,并大量使用 Python 泛型和类型提示。它的设计目的不是像 PydanticAI 那样对初学者友好。

安装

pydantic-graphpydantic-ai 的必需依赖项,也是 pydantic-ai-slim 的可选依赖项,有关更多信息,请参阅 安装说明。你也可以直接安装它

pip install pydantic-graph
uv add pydantic-graph

图类型

pydantic-graph 由几个关键组件组成

GraphRunContext

GraphRunContext —— 图运行的上下文,类似于 PydanticAI 的 RunContext。它保存图的状态和依赖项,并在节点运行时传递给节点。

GraphRunContext 在其使用的图的状态类型中是泛型的,StateT

结束

End —— 返回值,指示图运行应结束。

End 在其使用的图的图返回值类型中是泛型的,RunEndT

节点

BaseNode 的子类定义了图中执行的节点。

节点,通常是 dataclasses,通常由以下部分组成

  • 包含调用节点时需要的任何必需/可选参数的字段
  • 执行节点的业务逻辑,在 run 方法中
  • run 方法的返回注释,pydantic-graph 读取这些注释以确定节点的传出边

节点是泛型的,在

  • 状态,其类型必须与它们包含在内的图的状态类型相同,StateT 的默认值为 None,因此如果你不使用状态,可以省略此泛型参数,有关更多信息,请参阅 有状态图
  • deps,其类型必须与它们包含在内的图的 deps 类型相同,DepsT 的默认值为 None,因此如果你不使用 deps,可以省略此泛型参数,有关更多信息,请参阅 依赖注入
  • 图返回值类型 —— 这仅在节点返回 End 时适用。RunEndT 的默认值为 Never,因此如果节点不返回 End,则可以省略此泛型参数,但如果返回 End,则必须包含此参数。

这是一个图中的开始节点或中间节点的示例 —— 它不能结束运行,因为它不返回 End

intermediate_node.py
from dataclasses import dataclass

from pydantic_graph import BaseNode, GraphRunContext


@dataclass
class MyNode(BaseNode[MyState]):  # (1)!
    foo: int  # (2)!

    async def run(
        self,
        ctx: GraphRunContext[MyState],  # (3)!
    ) -> AnotherNode:  # (4)!
        ...
        return AnotherNode()
  1. 此示例中的状态为 MyState(未显示),因此 BaseNode 使用 MyState 参数化。此节点无法结束运行,因此 RunEndT 泛型参数被省略,并默认为 Never
  2. MyNode 是一个 dataclass,并且具有单个字段 foo,一个 int
  3. run 方法接受 GraphRunContext 参数,同样使用状态 MyState 参数化。
  4. run 方法的返回类型为 AnotherNode(未显示),这用于确定节点的传出边。

我们可以扩展 MyNode 以在 foo 可被 5 整除时选择性地结束运行

intermediate_or_end_node.py
from dataclasses import dataclass

from pydantic_graph import BaseNode, End, GraphRunContext


@dataclass
class MyNode(BaseNode[MyState, None, int]):  # (1)!
    foo: int

    async def run(
        self,
        ctx: GraphRunContext[MyState],
    ) -> AnotherNode | End[int]:  # (2)!
        if self.foo % 5 == 0:
            return End(self.foo)
        else:
            return AnotherNode()
  1. 我们使用返回类型(在本例中为 int)以及状态来参数化节点。由于泛型参数是仅限位置的,因此我们必须包含 None 作为表示 deps 的第二个参数。
  2. run 方法的返回类型现在是 AnotherNodeEnd[int] 的联合,这允许节点在 foo 可被 5 整除时结束运行。

Graph —— 这是执行图本身,由一组 节点类(即 BaseNode 子类)组成。

Graph 是泛型的,在

  • 状态 图的状态类型,StateT
  • deps 图的 deps 类型,DepsT
  • 图返回值类型 图运行的返回值类型,RunEndT

这是一个简单图的示例

graph_example.py
from __future__ import annotations

from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class DivisibleBy5(BaseNode[None, None, int]):  # (1)!
    foo: int

    async def run(
        self,
        ctx: GraphRunContext,
    ) -> Increment | End[int]:
        if self.foo % 5 == 0:
            return End(self.foo)
        else:
            return Increment(self.foo)


@dataclass
class Increment(BaseNode):  # (2)!
    foo: int

    async def run(self, ctx: GraphRunContext) -> DivisibleBy5:
        return DivisibleBy5(self.foo + 1)


fives_graph = Graph(nodes=[DivisibleBy5, Increment])  # (3)!
result = fives_graph.run_sync(DivisibleBy5(4))  # (4)!
print(result.output)
#> 5
  1. DivisibleBy5 节点使用 None 参数化状态参数,使用 None 参数化 deps 参数,因为此图不使用状态或 deps,并使用 int 参数化,因为它可能结束运行。
  2. Increment 节点不返回 End,因此省略了 RunEndT 泛型参数,状态也可以省略,因为图不使用状态。
  3. 图是使用节点序列创建的。
  4. 图使用 run_sync 同步运行。初始节点为 DivisibleBy5(4)。由于图不使用外部状态或 deps,因此我们不传递 statedeps

(此示例是完整的,可以使用 Python 3.10+ “按原样” 运行)

可以使用以下代码为该图生成 mermaid 图

graph_example_diagram.py
from graph_example import DivisibleBy5, fives_graph

fives_graph.mermaid_code(start_node=DivisibleBy5)
---
title: fives_graph
---
stateDiagram-v2
  [*] --> DivisibleBy5
  DivisibleBy5 --> Increment
  DivisibleBy5 --> [*]
  Increment --> DivisibleBy5

为了在 jupyter-notebook 中可视化图,需要使用 IPython.display

jupyter_display_mermaid.py
from graph_example import DivisibleBy5, fives_graph
from IPython.display import Image, display

display(Image(fives_graph.mermaid_image(start_node=DivisibleBy5)))

有状态图

pydantic-graph 中的 “状态” 概念提供了一种可选方式,用于在图中节点运行时访问和更改对象(通常是 dataclass 或 Pydantic 模型)。如果你将图视为生产线,那么你的状态就是沿线传递并由每个节点在图运行时构建的引擎。

未来,我们计划扩展 pydantic-graph 以提供状态持久化,并在每个节点运行后记录状态,请参阅 #695

这是一个表示自动售货机的图的示例,用户可以在其中投入硬币并选择要购买的产品。

vending_machine.py
from __future__ import annotations

from dataclasses import dataclass

from rich.prompt import Prompt

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class MachineState:  # (1)!
    user_balance: float = 0.0
    product: str | None = None


@dataclass
class InsertCoin(BaseNode[MachineState]):  # (3)!
    async def run(self, ctx: GraphRunContext[MachineState]) -> CoinsInserted:  # (16)!
        return CoinsInserted(float(Prompt.ask('Insert coins')))  # (4)!


@dataclass
class CoinsInserted(BaseNode[MachineState]):
    amount: float  # (5)!

    async def run(
        self, ctx: GraphRunContext[MachineState]
    ) -> SelectProduct | Purchase:  # (17)!
        ctx.state.user_balance += self.amount  # (6)!
        if ctx.state.product is not None:  # (7)!
            return Purchase(ctx.state.product)
        else:
            return SelectProduct()


@dataclass
class SelectProduct(BaseNode[MachineState]):
    async def run(self, ctx: GraphRunContext[MachineState]) -> Purchase:
        return Purchase(Prompt.ask('Select product'))


PRODUCT_PRICES = {  # (2)!
    'water': 1.25,
    'soda': 1.50,
    'crisps': 1.75,
    'chocolate': 2.00,
}


@dataclass
class Purchase(BaseNode[MachineState, None, None]):  # (18)!
    product: str

    async def run(
        self, ctx: GraphRunContext[MachineState]
    ) -> End | InsertCoin | SelectProduct:
        if price := PRODUCT_PRICES.get(self.product):  # (8)!
            ctx.state.product = self.product  # (9)!
            if ctx.state.user_balance >= price:  # (10)!
                ctx.state.user_balance -= price
                return End(None)
            else:
                diff = price - ctx.state.user_balance
                print(f'Not enough money for {self.product}, need {diff:0.2f} more')
                #> Not enough money for crisps, need 0.75 more
                return InsertCoin()  # (11)!
        else:
            print(f'No such product: {self.product}, try again')
            return SelectProduct()  # (12)!


vending_machine_graph = Graph(  # (13)!
    nodes=[InsertCoin, CoinsInserted, SelectProduct, Purchase]
)


async def main():
    state = MachineState()  # (14)!
    await vending_machine_graph.run(InsertCoin(), state=state)  # (15)!
    print(f'purchase successful item={state.product} change={state.user_balance:0.2f}')
    #> purchase successful item=crisps change=0.25
  1. 自动售货机的状态定义为一个 dataclass,其中包含用户的余额以及他们已选择的产品(如果有)。
  2. 产品到价格的字典映射。
  3. InsertCoin 节点,BaseNode 使用 MachineState 参数化,因为这是此图中使用的状态。
  4. InsertCoin 节点提示用户投入硬币。我们通过仅输入货币金额(浮点数)来保持简单。在你开始认为这也是一个玩具,因为它在节点内使用 rich 的 Prompt.ask 之前,请参阅 下方,了解当节点需要外部输入时如何管理控制流。
  5. CoinsInserted 节点;同样,这是一个 dataclass,带有一个字段 amount
  6. 使用投入的金额更新用户的余额。
  7. 如果用户已经选择了产品,则转到 Purchase,否则转到 SelectProduct
  8. Purchase 节点中,如果用户输入了有效产品,则查找产品的价格。
  9. 如果用户确实输入了有效产品,则在状态中设置产品,以便我们不再访问 SelectProduct
  10. 如果余额足以购买产品,则调整余额以反映购买情况,并返回 End 以结束图。我们未使用运行返回类型,因此我们使用 None 调用 End
  11. 如果余额不足,则转到 InsertCoin 以提示用户投入更多硬币。
  12. 如果产品无效,则转到 SelectProduct 以提示用户再次选择产品。
  13. 通过将节点列表传递给 Graph 来创建图。节点的顺序并不重要,但它可能会影响 的显示方式。
  14. 初始化状态。这将传递给图运行并在图运行时发生变化。
  15. 使用初始状态运行图。由于图可以从任何节点运行,因此我们必须传递起始节点 —— 在本例中为 InsertCoinGraph.run 返回一个 GraphRunResult,该结果提供最终数据和运行历史记录。
  16. 节点的 run 方法的返回类型很重要,因为它用于确定节点的传出边。此信息反过来用于渲染 mermaid 图,并在运行时强制执行以尽快检测到错误行为。
  17. CoinsInsertedrun 方法的返回类型是联合,这意味着可能有多个传出边。
  18. 与其他节点不同,Purchase 可以结束运行,因此必须设置 RunEndT 泛型参数。在本例中,它是 None,因为图运行返回类型为 None

(此示例是完整的,可以使用 Python 3.10+ “按原样” 运行 —— 你需要添加 asyncio.run(main()) 来运行 main

可以使用以下代码为该图生成 mermaid 图

vending_machine_diagram.py
from vending_machine import InsertCoin, vending_machine_graph

vending_machine_graph.mermaid_code(start_node=InsertCoin)

上述代码生成的图是

---
title: vending_machine_graph
---
stateDiagram-v2
  [*] --> InsertCoin
  InsertCoin --> CoinsInserted
  CoinsInserted --> SelectProduct
  CoinsInserted --> Purchase
  SelectProduct --> Purchase
  Purchase --> InsertCoin
  Purchase --> SelectProduct
  Purchase --> [*]

有关生成图的更多信息,请参阅 下方

GenAI 示例

到目前为止,我们还没有展示一个实际使用 PydanticAI 或 GenAI 的图的示例。

在此示例中,一个代理生成一封欢迎电子邮件给用户,另一个代理提供有关该电子邮件的反馈。

此图具有非常简单的结构

---
title: feedback_graph
---
stateDiagram-v2
  [*] --> WriteEmail
  WriteEmail --> Feedback
  Feedback --> WriteEmail
  Feedback --> [*]
genai_email_feedback.py
from __future__ import annotations as _annotations

from dataclasses import dataclass, field

from pydantic import BaseModel, EmailStr

from pydantic_ai import Agent
from pydantic_ai.format_as_xml import format_as_xml
from pydantic_ai.messages import ModelMessage
from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class User:
    name: str
    email: EmailStr
    interests: list[str]


@dataclass
class Email:
    subject: str
    body: str


@dataclass
class State:
    user: User
    write_agent_messages: list[ModelMessage] = field(default_factory=list)


email_writer_agent = Agent(
    'google-vertex:gemini-1.5-pro',
    result_type=Email,
    system_prompt='Write a welcome email to our tech blog.',
)


@dataclass
class WriteEmail(BaseNode[State]):
    email_feedback: str | None = None

    async def run(self, ctx: GraphRunContext[State]) -> Feedback:
        if self.email_feedback:
            prompt = (
                f'Rewrite the email for the user:\n'
                f'{format_as_xml(ctx.state.user)}\n'
                f'Feedback: {self.email_feedback}'
            )
        else:
            prompt = (
                f'Write a welcome email for the user:\n'
                f'{format_as_xml(ctx.state.user)}'
            )

        result = await email_writer_agent.run(
            prompt,
            message_history=ctx.state.write_agent_messages,
        )
        ctx.state.write_agent_messages += result.all_messages()
        return Feedback(result.data)


class EmailRequiresWrite(BaseModel):
    feedback: str


class EmailOk(BaseModel):
    pass


feedback_agent = Agent[None, EmailRequiresWrite | EmailOk](
    'openai:gpt-4o',
    result_type=EmailRequiresWrite | EmailOk,  # type: ignore
    system_prompt=(
        'Review the email and provide feedback, email must reference the users specific interests.'
    ),
)


@dataclass
class Feedback(BaseNode[State, None, Email]):
    email: Email

    async def run(
        self,
        ctx: GraphRunContext[State],
    ) -> WriteEmail | End[Email]:
        prompt = format_as_xml({'user': ctx.state.user, 'email': self.email})
        result = await feedback_agent.run(prompt)
        if isinstance(result.data, EmailRequiresWrite):
            return WriteEmail(email_feedback=result.data.feedback)
        else:
            return End(self.email)


async def main():
    user = User(
        name='John Doe',
        email='john.joe@example.com',
        interests=['Haskel', 'Lisp', 'Fortran'],
    )
    state = State(user)
    feedback_graph = Graph(nodes=(WriteEmail, Feedback))
    result = await feedback_graph.run(WriteEmail(), state=state)
    print(result.output)
    """
    Email(
        subject='Welcome to our tech blog!',
        body='Hello John, Welcome to our tech blog! ...',
    )
    """

(此示例是完整的,可以使用 Python 3.10+ “按原样” 运行 —— 你需要添加 asyncio.run(main()) 来运行 main

迭代图

使用 Graph.iter 进行 async for 迭代

有时你希望在图执行时直接控制或深入了解每个节点。最简单的方法是使用 Graph.iter 方法,该方法返回一个上下文管理器,该管理器产生一个 GraphRun 对象。GraphRun 是图中节点的异步可迭代对象,允许你在节点执行时记录或修改它们。

这是一个例子

count_down.py
from __future__ import annotations as _annotations

from dataclasses import dataclass
from pydantic_graph import Graph, BaseNode, End, GraphRunContext


@dataclass
class CountDownState:
    counter: int


@dataclass
class CountDown(BaseNode[CountDownState, None, int]):
    async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]:
        if ctx.state.counter <= 0:
            return End(ctx.state.counter)
        ctx.state.counter -= 1
        return CountDown()


count_down_graph = Graph(nodes=[CountDown])


async def main():
    state = CountDownState(counter=3)
    async with count_down_graph.iter(CountDown(), state=state) as run:  # (1)!
        async for node in run:  # (2)!
            print('Node:', node)
            #> Node: CountDown()
            #> Node: CountDown()
            #> Node: CountDown()
            #> Node: End(data=0)
    print('Final result:', run.result.output)  # (3)!
    #> Final result: 0
  1. Graph.iter(...) 返回一个 GraphRun
  2. 在这里,我们逐步遍历每个节点,因为它是执行的。
  3. 一旦图返回一个 End,循环结束,并且 run.final_result 变为一个 GraphRunResult,其中包含最终结果(此处为 0)。

手动使用 GraphRun.next(node)

或者,你可以使用 GraphRun.next 方法手动驱动迭代,该方法允许你传入你想要接下来运行的任何节点。你可以通过这种方式修改或选择性地跳过节点。

下面是一个人为的示例,它在计数器为 2 时停止,忽略该值之外的任何节点运行

count_down_next.py
from pydantic_graph import End, FullStatePersistence
from count_down import CountDown, CountDownState, count_down_graph


async def main():
    state = CountDownState(counter=5)
    persistence = FullStatePersistence()  # (7)!
    async with count_down_graph.iter(
        CountDown(), state=state, persistence=persistence
    ) as run:
        node = run.next_node  # (1)!
        while not isinstance(node, End):  # (2)!
            print('Node:', node)
            #> Node: CountDown()
            #> Node: CountDown()
            #> Node: CountDown()
            #> Node: CountDown()
            if state.counter == 2:
                break  # (3)!
            node = await run.next(node)  # (4)!

        print(run.result)  # (5)!
        #> None

        for step in persistence.history:  # (6)!
            print('History Step:', step.state, step.state)
            #> History Step: CountDownState(counter=5) CountDownState(counter=5)
            #> History Step: CountDownState(counter=4) CountDownState(counter=4)
            #> History Step: CountDownState(counter=3) CountDownState(counter=3)
            #> History Step: CountDownState(counter=2) CountDownState(counter=2)
  1. 我们首先获取将在代理图中运行的第一个节点。
  2. 一旦生成 End 节点,代理运行就完成;End 的实例无法传递给 next
  3. 如果用户决定提前停止,我们将跳出循环。在这种情况下,图运行将没有真正的最终结果(run.final_result 仍然为 None)。
  4. 在每个步骤中,我们调用 await run.next(node) 来运行它并获取下一个节点(或 End)。
  5. 因为我们没有继续运行直到完成,所以 result 未设置。
  6. 运行的历史记录仍然填充了我们到目前为止执行的步骤。
  7. 使用 FullStatePersistence,以便我们可以显示运行的历史记录,有关更多信息,请参阅下面的 状态持久化

状态持久化

有限状态机 (FSM) 图的最大好处之一是它们如何简化中断执行的处理。这可能是由于多种原因造成的

  • 状态机逻辑可能从根本上需要暂停 —— 例如,电子商务订单的退货工作流程需要等待物品被寄到退货中心,或者因为下一个节点的执行需要来自用户的输入,因此需要等待新的 http 请求,
  • 执行时间太长,以至于整个图无法在单个连续运行中可靠地执行 —— 例如,一个深度研究代理可能需要数小时才能运行,
  • 你希望在不同的进程/硬件实例中并行运行多个图节点(注意:pydantic-graph 尚不支持并行节点执行,请参阅 #704)。

尝试使传统的控制流(即,布尔逻辑和嵌套函数调用)实现与这些使用场景兼容通常会导致脆弱且过于复杂的意大利面条式代码,其中中断和恢复执行所需的逻辑主导了实现。

为了允许图运行被中断和恢复,pydantic-graph 提供了状态持久化 —— 一种在每个节点运行之前和之后快照图运行状态的系统,允许从图中的任何点恢复图运行。

pydantic-graph 包括三种状态持久化实现

  • SimpleStatePersistence —— 简单的内存状态持久化,仅保存最新的快照。如果在运行图时未提供状态持久化实现,则默认使用此实现。
  • FullStatePersistence —— 内存状态持久化,保存快照列表。
  • FileStatePersistence —— 基于文件的状态持久化,将快照保存到 JSON 文件。

在生产应用程序中,开发人员应通过子类化 BaseStatePersistence 抽象基类来实现自己的状态持久化,这可能会将运行持久化到关系数据库(如 PostgresQL)中。

在高层次上,StatePersistence 实现的作用是存储和检索 NodeSnapshotEndSnapshot 对象。

graph.iter_from_persistence() 可用于基于持久化中存储的状态运行图。

我们可以使用 graph.iter_from_persistence()FileStatePersistence上方 运行 count_down_graph

正如你在此代码中看到的,run_node 运行不需要外部应用程序状态(除了状态持久化),这意味着图可以轻松地通过分布式执行和排队系统执行。

count_down_from_persistence.py
from pathlib import Path

from pydantic_graph import End
from pydantic_graph.persistence.file import FileStatePersistence

from count_down import CountDown, CountDownState, count_down_graph


async def main():
    run_id = 'run_abc123'
    persistence = FileStatePersistence(Path(f'count_down_{run_id}.json'))  # (1)!
    state = CountDownState(counter=5)
    await count_down_graph.initialize(  # (2)!
        CountDown(), state=state, persistence=persistence
    )

    done = False
    while not done:
        done = await run_node(run_id)


async def run_node(run_id: str) -> bool:  # (3)!
    persistence = FileStatePersistence(Path(f'count_down_{run_id}.json'))
    async with count_down_graph.iter_from_persistence(persistence) as run:  # (4)!
        node_or_end = await run.next()  # (5)!

    print('Node:', node_or_end)
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: End(data=0)
    return isinstance(node_or_end, End)  # (6)!
  1. 创建一个 FileStatePersistence 以用于启动图。
  2. 调用 graph.initialize() 以在持久化对象中设置初始图状态。
  3. run_node 是一个纯函数,除了运行 ID 之外,不需要访问任何其他进程状态即可运行图的下一个节点。
  4. 调用 graph.iter_from_persistence() 创建一个 GraphRun 对象,该对象将从持久化中存储的状态运行图的下一个节点。这将返回一个节点或一个 End 对象。
  5. graph.run() 将返回一个 node 或一个 End 对象。
  6. 检查节点是否为 End 对象,如果是,则图运行完成。

(此示例是完整的,可以使用 Python 3.10+ “按原样” 运行 —— 你需要添加 asyncio.run(main()) 来运行 main

示例:人机环路。

如上所述,状态持久化允许图被中断和恢复。其中一个用例是允许用户输入继续。

在此示例中,AI 向用户提出一个问题,用户提供答案,AI 评估答案,如果用户回答正确则结束,如果回答错误则提出另一个问题。

我们不是在单个进程调用中运行整个图,而是通过重复运行进程来运行图,可以选择性地将问题的答案作为命令行参数提供。

ai_q_and_a_graph.py —— question_graph 定义
ai_q_and_a_graph.py
from __future__ import annotations as _annotations

from dataclasses import dataclass, field

from groq import BaseModel
from pydantic_graph import (
    BaseNode,
    End,
    Graph,
    GraphRunContext,
)

from pydantic_ai import Agent
from pydantic_ai.format_as_xml import format_as_xml
from pydantic_ai.messages import ModelMessage

ask_agent = Agent('openai:gpt-4o', result_type=str, instrument=True)


@dataclass
class QuestionState:
    question: str | None = None
    ask_agent_messages: list[ModelMessage] = field(default_factory=list)
    evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)


@dataclass
class Ask(BaseNode[QuestionState]):
    async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer:
        result = await ask_agent.run(
            'Ask a simple question with a single correct answer.',
            message_history=ctx.state.ask_agent_messages,
        )
        ctx.state.ask_agent_messages += result.all_messages()
        ctx.state.question = result.data
        return Answer(result.data)


@dataclass
class Answer(BaseNode[QuestionState]):
    question: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
        answer = input(f'{self.question}: ')
        return Evaluate(answer)


class EvaluationResult(BaseModel, use_attribute_docstrings=True):
    correct: bool
    """Whether the answer is correct."""
    comment: str
    """Comment on the answer, reprimand the user if the answer is wrong."""


evaluate_agent = Agent(
    'openai:gpt-4o',
    result_type=EvaluationResult,
    system_prompt='Given a question and answer, evaluate if the answer is correct.',
)


@dataclass
class Evaluate(BaseNode[QuestionState, None, str]):
    answer: str

    async def run(
        self,
        ctx: GraphRunContext[QuestionState],
    ) -> End[str] | Reprimand:
        assert ctx.state.question is not None
        result = await evaluate_agent.run(
            format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
            message_history=ctx.state.evaluate_agent_messages,
        )
        ctx.state.evaluate_agent_messages += result.all_messages()
        if result.data.correct:
            return End(result.data.comment)
        else:
            return Reprimand(result.data.comment)


@dataclass
class Reprimand(BaseNode[QuestionState]):
    comment: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
        print(f'Comment: {self.comment}')
        ctx.state.question = None
        return Ask()


question_graph = Graph(
    nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
)

(此示例是完整的,可以使用 Python 3.10+ “按原样” 运行)

ai_q_and_a_run.py
import sys
from pathlib import Path

from pydantic_graph import End
from pydantic_graph.persistence.file import FileStatePersistence
from pydantic_ai.messages import ModelMessage  # noqa: F401

from ai_q_and_a_graph import Ask, question_graph, Evaluate, QuestionState, Answer


async def main():
    answer: str | None = sys.argv[2] if len(sys.argv) > 2 else None  # (1)!
    persistence = FileStatePersistence(Path('question_graph.json'))  # (2)!
    persistence.set_graph_types(question_graph)  # (3)!

    if snapshot := await persistence.load_next():  # (4)!
        state = snapshot.state
        assert answer is not None
        node = Evaluate(answer)
    else:
        state = QuestionState()
        node = Ask()  # (5)!

    async with question_graph.iter(node, state=state, persistence=persistence) as run:
        while True:
            node = await run.next()  # (6)!
            if isinstance(node, End):  # (7)!
                print('END:', node.data)
                history = await persistence.load_all()  # (8)!
                print([e.node for e in history])
                break
            elif isinstance(node, Answer):  # (9)!
                print(node.question)
                #> What is the capital of France?
                break
            # otherwise just continue
  1. 从命令行获取用户的答案(如果提供)。有关完整示例,请参阅 问题图示例
  2. 创建一个状态持久化实例,'question_graph.json' 文件可能已经存在,也可能不存在。
  3. 由于我们在图外部使用 持久化接口,因此我们需要调用 set_graph_types 以设置持久化实例的图泛型类型 StateTRunEndT。这对于允许持久化实例知道如何序列化和反序列化图节点是必要的。
  4. 如果我们之前运行过图,load_next 将返回要运行的下一个节点的快照,此处我们使用该快照中的 state,并使用命令行上提供的答案创建一个新的 Evaluate 节点。
  5. 如果图之前未运行过,我们创建一个新的 QuestionState 并从 Ask 节点开始。
  6. 调用 GraphRun.next() 以运行节点。这将返回一个节点或一个 End 对象。
  7. 如果节点是 End 对象,则图运行完成。End 对象的 data 字段包含 evaluate_agent 返回的关于正确答案的注释。
  8. 为了演示状态持久化,我们调用 load_all 以从持久化实例获取所有快照。这将返回 Snapshot 对象列表。
  9. 如果节点是 Answer 对象,我们打印问题并跳出循环以结束进程并等待用户输入。

(此示例是完整的,可以使用 Python 3.10+ “按原样” 运行 —— 你需要添加 asyncio.run(main(answer)) 来运行 main

有关此图的完整示例,请参阅 问题图示例

依赖注入

与 PydanticAI 一样,pydantic-graph 通过 GraphBaseNode 上的泛型参数以及 GraphRunContext.deps 字段支持依赖注入。

作为依赖注入的示例,让我们修改 上面DivisibleBy5 示例,以使用 ProcessPoolExecutor 在单独的进程中运行计算负载(这是一个人为的示例,ProcessPoolExecutor 实际上不会在此示例中提高性能)

deps_example.py
from __future__ import annotations

import asyncio
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class GraphDeps:
    executor: ProcessPoolExecutor


@dataclass
class DivisibleBy5(BaseNode[None, GraphDeps, int]):
    foo: int

    async def run(
        self,
        ctx: GraphRunContext[None, GraphDeps],
    ) -> Increment | End[int]:
        if self.foo % 5 == 0:
            return End(self.foo)
        else:
            return Increment(self.foo)


@dataclass
class Increment(BaseNode[None, GraphDeps]):
    foo: int

    async def run(self, ctx: GraphRunContext[None, GraphDeps]) -> DivisibleBy5:
        loop = asyncio.get_running_loop()
        compute_result = await loop.run_in_executor(
            ctx.deps.executor,
            self.compute,
        )
        return DivisibleBy5(compute_result)

    def compute(self) -> int:
        return self.foo + 1


fives_graph = Graph(nodes=[DivisibleBy5, Increment])


async def main():
    with ProcessPoolExecutor() as executor:
        deps = GraphDeps(executor)
        result = await fives_graph.run(DivisibleBy5(3), deps=deps)
    print(result.output)
    #> 5
    # the full history is quite verbose (see below), so we'll just print the summary
    print([item.data_snapshot() for item in result.history])
    """
    [
        DivisibleBy5(foo=3),
        Increment(foo=3),
        DivisibleBy5(foo=4),
        Increment(foo=4),
        DivisibleBy5(foo=5),
        End(data=5),
    ]
    """

(此示例是完整的,可以使用 Python 3.10+ “按原样” 运行 —— 你需要添加 asyncio.run(main()) 来运行 main

Mermaid 图

Pydantic Graph 可以为图生成 mermaid stateDiagram-v2 图,如上所示。

这些图可以使用以下方法生成

除了上面显示的图之外,你还可以使用以下选项自定义 mermaid 图

将它们放在一起,我们可以编辑最后一个 ai_q_and_a_graph.py 示例以

  • 向某些边添加标签
  • Ask 节点添加注释
  • 突出显示 Answer 节点
  • 将图另存为 PNG 图像文件
ai_q_and_a_graph_extra.py
...
from typing import Annotated

from pydantic_graph import BaseNode, End, Graph, GraphRunContext, Edge

...

@dataclass
class Ask(BaseNode[QuestionState]):
    """Generate question using GPT-4o."""
    docstring_notes = True
    async def run(
        self, ctx: GraphRunContext[QuestionState]
    ) -> Annotated[Answer, Edge(label='Ask the question')]:
        ...

...

@dataclass
class Evaluate(BaseNode[QuestionState]):
    answer: str

    async def run(
            self,
            ctx: GraphRunContext[QuestionState],
    ) -> Annotated[End[str], Edge(label='success')] | Reprimand:
        ...

...

question_graph.mermaid_save('image.png', highlighted_nodes=[Answer])

(此示例不完整,无法直接运行)

这将生成如下所示的图像

---
title: question_graph
---
stateDiagram-v2
  Ask --> Answer: Ask the question
  note right of Ask
    Judge the answer.
    Decide on next step.
  end note
  Answer --> Evaluate
  Evaluate --> Reprimand
  Evaluate --> [*]: success
  Reprimand --> Ask

classDef highlighted fill:#fdff32
class Answer highlighted

设置状态图的方向

你可以使用以下值之一指定状态图的方向

  • 'TB':从上到下,图垂直地从上到下流动。
  • 'LR':从左到右,图水平地从左到右流动。
  • 'RL':从右到左,图水平地从右到左流动。
  • 'BT':从下到上,图垂直地从下到上流动。

这是一个如何使用 “从左到右”(LR) 而不是默认的 “从上到下”(TB) 的示例

vending_machine_diagram.py
from vending_machine import InsertCoin, vending_machine_graph

vending_machine_graph.mermaid_code(start_node=InsertCoin, direction='LR')

---
title: vending_machine_graph
---
stateDiagram-v2
  direction LR
  [*] --> InsertCoin
  InsertCoin --> CoinsInserted
  CoinsInserted --> SelectProduct
  CoinsInserted --> Purchase
  SelectProduct --> Purchase
  Purchase --> InsertCoin
  Purchase --> SelectProduct
  Purchase --> [*]