跳转到内容

除非你需要钉枪,否则别用钉枪

如果说 Pydantic AI 代理是锤子,多代理工作流是长柄大锤,那么图就是钉枪。

  • 当然,钉枪看起来比锤子酷。
  • 但是钉枪比锤子需要更多的设置工作。
  • 而且钉枪不会让你成为一个更好的建造者,只会让你成为一个拥有钉枪的建造者。
  • 最后(冒着过度使用这个比喻的风险),如果你是像木槌和无类型 Python 这样的中世纪工具的粉丝,你可能不会喜欢钉枪或我们的图方法。(但话说回来,如果你不喜欢 Python 中的类型提示,你可能已经从 Pydantic AI 转而去使用那些玩具代理框架了——祝你好运,当你意识到你需要长柄大锤时,随时可以来借我的。)

简而言之,图是一个强大的工具,但它并非适用于所有工作。在继续之前,请考虑其他多代理方法

如果你不确定基于图的方法是个好主意,那么它可能是不必要的。

图和有限状态机(FSM)是用于建模、执行、控制和可视化复杂工作流的强大抽象。

与 Pydantic AI 一起,我们开发了 pydantic-graph —— 一个用于 Python 的异步图和状态机库,其中的节点和边是使用类型提示定义的。

虽然这个库是作为 Pydantic AI 的一部分开发的,但它对 pydantic-ai 没有任何依赖,可以被视为一个纯粹的基于图的状态机库。无论你是否在使用 Pydantic AI,甚至是否在用生成式 AI 构建应用,你都可能会发现它很有用。

pydantic-graph 是为高级用户设计的,大量使用了 Python 的泛型和类型提示。它不像 Pydantic AI 那样对初学者友好。

安装

pydantic-graphpydantic-ai 的必需依赖项,是 pydantic-ai-slim 的可选依赖项,更多信息请参见安装说明。你也可以直接安装它。

pip install pydantic-graph
uv add pydantic-graph

图类型

pydantic-graph 由几个关键组件组成

GraphRunContext

GraphRunContext — 图运行的上下文,类似于 Pydantic AI 的 RunContext。它持有图的状态和依赖项,并在节点运行时传递给节点。

GraphRunContext 在其使用的图的状态类型 StateT 上是泛型的。

End

End — 返回值,用于指示图的运行应该结束。

End 在其使用的图的图返回类型 RunEndT 上是泛型的。

节点

BaseNode 的子类定义了在图中执行的节点。

节点通常是dataclass,通常由以下部分组成:

  • 包含调用节点时必需/可选的任何参数的字段
  • run 方法中执行节点的业务逻辑
  • run 方法的返回注解,pydantic-graph 会读取这些注解来确定节点的出边

节点在以下方面是泛型的:

  • state(状态),它必须与包含它们的图的状态具有相同的类型,StateT 的默认值为 None,所以如果你不使用状态,可以省略这个泛型参数,更多信息请参见有状态图
  • deps(依赖),它必须与包含它们的图的依赖具有相同的类型,DepsT 的默认值为 None,所以如果你不使用依赖,可以省略这个泛型参数,更多信息请参见依赖注入
  • graph return type(图返回类型)— 这仅在节点返回 End 时适用。RunEndT 的默认值为 Never,因此如果节点不返回 End,可以省略这个泛型参数,但如果返回,则必须包含。

这是一个图中开始节点或中间节点的示例——它不能结束运行,因为它不返回 End

intermediate_node.py
from dataclasses import dataclass

from pydantic_graph import BaseNode, GraphRunContext


@dataclass
class MyNode(BaseNode[MyState]):  # (1)!
    foo: int  # (2)!

    async def run(
        self,
        ctx: GraphRunContext[MyState],  # (3)!
    ) -> AnotherNode:  # (4)!
        ...
        return AnotherNode()
  1. 本例中的状态是 MyState(未显示),因此 BaseNode 使用 MyState 进行参数化。此节点无法结束运行,因此省略了 RunEndT 泛型参数,其默认值为 Never
  2. MyNode 是一个 dataclass,它有一个字段 foo,类型为 int
  3. run 方法接受一个 GraphRunContext 参数,同样使用状态 MyState 进行参数化。
  4. run 方法的返回类型是 AnotherNode(未显示),这用于确定节点的出边。

我们可以扩展 MyNode,使其在 foo 能被 5 整除时可以选择性地结束运行。

intermediate_or_end_node.py
from dataclasses import dataclass

from pydantic_graph import BaseNode, End, GraphRunContext


@dataclass
class MyNode(BaseNode[MyState, None, int]):  # (1)!
    foo: int

    async def run(
        self,
        ctx: GraphRunContext[MyState],
    ) -> AnotherNode | End[int]:  # (2)!
        if self.foo % 5 == 0:
            return End(self.foo)
        else:
            return AnotherNode()
  1. 我们使用返回类型(本例中为 int)以及状态来参数化节点。因为泛型参数是仅位置参数,我们必须包含 None 作为代表 deps 的第二个参数。
  2. run 方法的返回类型现在是 AnotherNodeEnd[int] 的联合类型,这允许节点在 foo 能被 5 整除时结束运行。

Graph — 这是执行图本身,由一组节点类(即 BaseNode 的子类)组成。

Graph 在以下方面是泛型的:

  • state(状态)图的状态类型,StateT
  • deps(依赖)图的依赖类型,DepsT
  • graph return type(图返回类型)图运行的返回类型,RunEndT

这是一个简单图的示例

graph_example.py
from __future__ import annotations

from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class DivisibleBy5(BaseNode[None, None, int]):  # (1)!
    foo: int

    async def run(
        self,
        ctx: GraphRunContext,
    ) -> Increment | End[int]:
        if self.foo % 5 == 0:
            return End(self.foo)
        else:
            return Increment(self.foo)


@dataclass
class Increment(BaseNode):  # (2)!
    foo: int

    async def run(self, ctx: GraphRunContext) -> DivisibleBy5:
        return DivisibleBy5(self.foo + 1)


fives_graph = Graph(nodes=[DivisibleBy5, Increment])  # (3)!
result = fives_graph.run_sync(DivisibleBy5(4))  # (4)!
print(result.output)
#> 5
  1. DivisibleBy5 节点的状态参数和依赖参数都用 None 进行了参数化,因为此图不使用状态或依赖;而因为它可能结束运行,所以返回类型为 int
  2. Increment 节点不返回 End,因此省略了 RunEndT 泛型参数;由于图不使用状态,状态参数也可以省略。
  3. 图是通过一个节点序列创建的。
  4. 图使用 run_sync 同步运行。初始节点是 DivisibleBy5(4)。因为图不使用外部状态或依赖,我们不传递 statedeps

(这个例子是完整的,可以“按原样”运行)

此图的 Mermaid 图表可以使用以下代码生成

graph_example_diagram.py
from graph_example import DivisibleBy5, fives_graph

fives_graph.mermaid_code(start_node=DivisibleBy5)
---
title: fives_graph
---
stateDiagram-v2
  [*] --> DivisibleBy5
  DivisibleBy5 --> Increment
  DivisibleBy5 --> [*]
  Increment --> DivisibleBy5

为了在 jupyter-notebook 中可视化图,需要使用 IPython.display

jupyter_display_mermaid.py
from graph_example import DivisibleBy5, fives_graph
from IPython.display import Image, display

display(Image(fives_graph.mermaid_image(start_node=DivisibleBy5)))

有状态图

pydantic-graph 中的“状态”概念提供了一种可选的方式,在图中的节点运行时访问和修改一个对象(通常是 dataclass 或 Pydantic 模型)。如果你将图想象成一条生产线,那么你的状态就是沿着生产线传递并由每个节点在图运行时逐步构建的引擎。

pydantic-graph 提供状态持久化,在每个节点运行后记录状态。(请参见状态持久化。)

这是一个表示自动售货机的图的示例,用户可以投入硬币并选择要购买的商品。

vending_machine.py
from __future__ import annotations

from dataclasses import dataclass

from rich.prompt import Prompt

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class MachineState:  # (1)!
    user_balance: float = 0.0
    product: str | None = None


@dataclass
class InsertCoin(BaseNode[MachineState]):  # (3)!
    async def run(self, ctx: GraphRunContext[MachineState]) -> CoinsInserted:  # (16)!
        return CoinsInserted(float(Prompt.ask('Insert coins')))  # (4)!


@dataclass
class CoinsInserted(BaseNode[MachineState]):
    amount: float  # (5)!

    async def run(
        self, ctx: GraphRunContext[MachineState]
    ) -> SelectProduct | Purchase:  # (17)!
        ctx.state.user_balance += self.amount  # (6)!
        if ctx.state.product is not None:  # (7)!
            return Purchase(ctx.state.product)
        else:
            return SelectProduct()


@dataclass
class SelectProduct(BaseNode[MachineState]):
    async def run(self, ctx: GraphRunContext[MachineState]) -> Purchase:
        return Purchase(Prompt.ask('Select product'))


PRODUCT_PRICES = {  # (2)!
    'water': 1.25,
    'soda': 1.50,
    'crisps': 1.75,
    'chocolate': 2.00,
}


@dataclass
class Purchase(BaseNode[MachineState, None, None]):  # (18)!
    product: str

    async def run(
        self, ctx: GraphRunContext[MachineState]
    ) -> End | InsertCoin | SelectProduct:
        if price := PRODUCT_PRICES.get(self.product):  # (8)!
            ctx.state.product = self.product  # (9)!
            if ctx.state.user_balance >= price:  # (10)!
                ctx.state.user_balance -= price
                return End(None)
            else:
                diff = price - ctx.state.user_balance
                print(f'Not enough money for {self.product}, need {diff:0.2f} more')
                #> Not enough money for crisps, need 0.75 more
                return InsertCoin()  # (11)!
        else:
            print(f'No such product: {self.product}, try again')
            return SelectProduct()  # (12)!


vending_machine_graph = Graph(  # (13)!
    nodes=[InsertCoin, CoinsInserted, SelectProduct, Purchase]
)


async def main():
    state = MachineState()  # (14)!
    await vending_machine_graph.run(InsertCoin(), state=state)  # (15)!
    print(f'purchase successful item={state.product} change={state.user_balance:0.2f}')
    #> purchase successful item=crisps change=0.25
  1. 自动售货机的状态被定义为一个 dataclass,包含用户的余额和他们选择的商品(如果有的话)。
  2. 一个将产品映射到价格的字典。
  3. InsertCoin 节点,BaseNode 使用 MachineState 进行参数化,因为这是此图中使用的状态。
  4. InsertCoin 节点提示用户投入硬币。我们通过仅输入一个浮点数金额来简化操作。在你开始认为这也是一个玩具,因为它在节点内使用了 rich 的 Prompt.ask 之前,请看下面关于当节点需要外部输入时如何管理控制流的内容。
  5. CoinsInserted 节点;同样,这是一个带有一个字段 amountdataclass
  6. 用投入的金额更新用户的余额。
  7. 如果用户已经选择了商品,则转到 Purchase,否则转到 SelectProduct
  8. Purchase 节点中,如果用户输入了有效的产品,则查找该产品的价格。
  9. 如果用户确实输入了有效的产品,则在状态中设置该产品,这样我们就不会重新访问 SelectProduct
  10. 如果余额足以购买该产品,则调整余额以反映购买情况,并返回 End 来结束图的运行。我们没有使用运行的返回类型,所以我们用 None 调用 End
  11. 如果余额不足,则转到 InsertCoin 提示用户投入更多硬币。
  12. 如果产品无效,则转到 SelectProduct 提示用户重新选择产品。
  13. 图是通过向 Graph 传递一个节点列表来创建的。节点的顺序不重要,但它可能会影响图表的显示方式。
  14. 初始化状态。这将被传递给图的运行,并在图运行时被修改。
  15. 使用初始状态运行图。由于图可以从任何节点开始运行,我们必须传递起始节点——在本例中是 InsertCoinGraph.run 返回一个 GraphRunResult,它提供了最终数据和运行的历史记录。
  16. 节点的 run 方法的返回类型很重要,因为它用于确定节点的出边。这些信息反过来又用于渲染 Mermaid 图表,并在运行时强制执行,以尽快检测不当行为。
  17. CoinsInsertedrun 方法的返回类型是一个联合类型,意味着可能有多条出边。
  18. 与其他节点不同,Purchase 可以结束运行,因此必须设置 RunEndT 泛型参数。在这种情况下,它是 None,因为图运行的返回类型是 None

(此示例是完整的,可以“按原样”运行——您需要添加 asyncio.run(main()) 来运行 main

此图的 Mermaid 图表可以使用以下代码生成

vending_machine_diagram.py
from vending_machine import InsertCoin, vending_machine_graph

vending_machine_graph.mermaid_code(start_node=InsertCoin)

由上述代码生成的图表是

---
title: vending_machine_graph
---
stateDiagram-v2
  [*] --> InsertCoin
  InsertCoin --> CoinsInserted
  CoinsInserted --> SelectProduct
  CoinsInserted --> Purchase
  SelectProduct --> Purchase
  Purchase --> InsertCoin
  Purchase --> SelectProduct
  Purchase --> [*]

有关生成图表的更多信息,请参见下文

生成式 AI 示例

到目前为止,我们还没有展示一个实际使用 Pydantic AI 或生成式 AI 的图的示例。

在这个例子中,一个代理生成一封给用户的欢迎邮件,另一个代理对该邮件提供反馈。

这个图的结构非常简单

---
title: feedback_graph
---
stateDiagram-v2
  [*] --> WriteEmail
  WriteEmail --> Feedback
  Feedback --> WriteEmail
  Feedback --> [*]
genai_email_feedback.py
from __future__ import annotations as _annotations

from dataclasses import dataclass, field

from pydantic import BaseModel, EmailStr

from pydantic_ai import Agent, format_as_xml
from pydantic_ai.messages import ModelMessage
from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class User:
    name: str
    email: EmailStr
    interests: list[str]


@dataclass
class Email:
    subject: str
    body: str


@dataclass
class State:
    user: User
    write_agent_messages: list[ModelMessage] = field(default_factory=list)


email_writer_agent = Agent(
    'google-gla:gemini-1.5-pro',
    output_type=Email,
    system_prompt='Write a welcome email to our tech blog.',
)


@dataclass
class WriteEmail(BaseNode[State]):
    email_feedback: str | None = None

    async def run(self, ctx: GraphRunContext[State]) -> Feedback:
        if self.email_feedback:
            prompt = (
                f'Rewrite the email for the user:\n'
                f'{format_as_xml(ctx.state.user)}\n'
                f'Feedback: {self.email_feedback}'
            )
        else:
            prompt = (
                f'Write a welcome email for the user:\n'
                f'{format_as_xml(ctx.state.user)}'
            )

        result = await email_writer_agent.run(
            prompt,
            message_history=ctx.state.write_agent_messages,
        )
        ctx.state.write_agent_messages += result.new_messages()
        return Feedback(result.output)


class EmailRequiresWrite(BaseModel):
    feedback: str


class EmailOk(BaseModel):
    pass


feedback_agent = Agent[None, EmailRequiresWrite | EmailOk](
    'openai:gpt-4o',
    output_type=EmailRequiresWrite | EmailOk,  # type: ignore
    system_prompt=(
        'Review the email and provide feedback, email must reference the users specific interests.'
    ),
)


@dataclass
class Feedback(BaseNode[State, None, Email]):
    email: Email

    async def run(
        self,
        ctx: GraphRunContext[State],
    ) -> WriteEmail | End[Email]:
        prompt = format_as_xml({'user': ctx.state.user, 'email': self.email})
        result = await feedback_agent.run(prompt)
        if isinstance(result.output, EmailRequiresWrite):
            return WriteEmail(email_feedback=result.output.feedback)
        else:
            return End(self.email)


async def main():
    user = User(
        name='John Doe',
        email='john.joe@example.com',
        interests=['Haskel', 'Lisp', 'Fortran'],
    )
    state = State(user)
    feedback_graph = Graph(nodes=(WriteEmail, Feedback))
    result = await feedback_graph.run(WriteEmail(), state=state)
    print(result.output)
    """
    Email(
        subject='Welcome to our tech blog!',
        body='Hello John, Welcome to our tech blog! ...',
    )
    """

(此示例是完整的,可以“按原样”运行——您需要添加 asyncio.run(main()) 来运行 main

迭代图

使用 Graph.iter 进行 async for 迭代

有时你希望在图执行时直接控制或洞察每个节点。最简单的方法是使用 Graph.iter 方法,它返回一个**上下文管理器**,该管理器会产生一个 GraphRun 对象。GraphRun 是一个针对图中节点的可异步迭代对象,允许你在它们执行时记录或修改它们。

这是一个例子

count_down.py
from __future__ import annotations as _annotations

from dataclasses import dataclass
from pydantic_graph import Graph, BaseNode, End, GraphRunContext


@dataclass
class CountDownState:
    counter: int


@dataclass
class CountDown(BaseNode[CountDownState, None, int]):
    async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]:
        if ctx.state.counter <= 0:
            return End(ctx.state.counter)
        ctx.state.counter -= 1
        return CountDown()


count_down_graph = Graph(nodes=[CountDown])


async def main():
    state = CountDownState(counter=3)
    async with count_down_graph.iter(CountDown(), state=state) as run:  # (1)!
        async for node in run:  # (2)!
            print('Node:', node)
            #> Node: CountDown()
            #> Node: CountDown()
            #> Node: CountDown()
            #> Node: CountDown()
            #> Node: End(data=0)
    print('Final output:', run.result.output)  # (3)!
    #> Final output: 0
  1. Graph.iter(...) 返回一个 GraphRun
  2. 在这里,我们逐步执行每个节点。
  3. 一旦图返回一个 End,循环就会结束,run.result 会变成一个 GraphRunResult,其中包含最终结果(这里是 0)。

手动使用 GraphRun.next(node)

或者,你可以使用 GraphRun.next 方法手动驱动迭代,这允许你传入任何你想要接下来运行的节点。通过这种方式,你可以修改或选择性地跳过节点。

下面是一个刻意设计的例子,它在计数器为 2 时停止,忽略了之后的所有节点运行

count_down_next.py
from pydantic_graph import End, FullStatePersistence
from count_down import CountDown, CountDownState, count_down_graph


async def main():
    state = CountDownState(counter=5)
    persistence = FullStatePersistence()  # (7)!
    async with count_down_graph.iter(
        CountDown(), state=state, persistence=persistence
    ) as run:
        node = run.next_node  # (1)!
        while not isinstance(node, End):  # (2)!
            print('Node:', node)
            #> Node: CountDown()
            #> Node: CountDown()
            #> Node: CountDown()
            #> Node: CountDown()
            if state.counter == 2:
                break  # (3)!
            node = await run.next(node)  # (4)!

        print(run.result)  # (5)!
        #> None

        for step in persistence.history:  # (6)!
            print('History Step:', step.state, step.state)
            #> History Step: CountDownState(counter=5) CountDownState(counter=5)
            #> History Step: CountDownState(counter=4) CountDownState(counter=4)
            #> History Step: CountDownState(counter=3) CountDownState(counter=3)
            #> History Step: CountDownState(counter=2) CountDownState(counter=2)
  1. 我们首先获取代理图中将要运行的第一个节点。
  2. 一旦产生了 End 节点,代理运行就完成了;End 的实例不能传递给 next
  3. 如果用户决定提前停止,我们就跳出循环。在这种情况下,图的运行将没有真正的最终结果(run.result 保持为 None)。
  4. 在每一步,我们调用 await run.next(node) 来运行它并获取下一个节点(或一个 End)。
  5. 因为我们没有继续运行直到它完成,所以 result 没有被设置。
  6. 运行的历史记录仍然包含了我们已经执行的步骤。
  7. 使用 FullStatePersistence,这样我们就可以显示运行的历史记录,更多信息请参见下面的状态持久化

状态持久化

有限状态机(FSM)图最大的好处之一是它们如何简化对中断执行的处理。这可能由于多种原因发生:

  • 状态机逻辑可能根本上需要暂停——例如,电子商务订单的退货工作流需要等待物品被寄到退货中心,或者因为下一个节点的执行需要用户的输入,所以需要等待一个新的 http 请求。
  • 执行时间太长,以至于整个图无法可靠地在一次连续运行中执行完毕——例如,一个可能需要数小时运行的深度研究代理。
  • 你希望在不同的进程/硬件实例中并行运行多个图节点(注意:pydantic-graph 目前尚不支持并行节点执行,请参见 #704)。

试图使传统的控制流(即布尔逻辑和嵌套函数调用)实现与这些使用场景兼容,通常会导致代码脆弱且过于复杂的“意大利面条式代码”,其中用于中断和恢复执行的逻辑主导了整个实现。

为了允许图的运行可以被中断和恢复,pydantic-graph 提供了状态持久化——一个在每个节点运行前后对图运行状态进行快照的系统,从而允许图的运行可以从图中的任何一点恢复。

pydantic-graph 包含了三种状态持久化实现

  • SimpleStatePersistence — 简单的内存状态持久化,只保存最新的快照。如果在运行图时没有提供状态持久化实现,则默认使用此实现。
  • FullStatePersistence — 内存状态持久化,保存一个快照列表。
  • FileStatePersistence — 基于文件的状态持久化,将快照保存到 JSON 文件中。

在生产应用中,开发者应该通过子类化 BaseStatePersistence 抽象基类来实现自己的状态持久化,例如将运行状态持久化到像 PostgresQL 这样的关系型数据库中。

从高层次来看,StatePersistence 实现的角色是存储和检索 NodeSnapshotEndSnapshot 对象。

graph.iter_from_persistence() 可以用来基于持久化中存储的状态来运行图。

我们可以使用 graph.iter_from_persistence()FileStatePersistence 来运行上面count_down_graph

正如你在这段代码中看到的,run_node 的运行不需要任何外部应用状态(除了状态持久化),这意味着图可以很容易地由分布式执行和排队系统来执行。

count_down_from_persistence.py
from pathlib import Path

from pydantic_graph import End
from pydantic_graph.persistence.file import FileStatePersistence

from count_down import CountDown, CountDownState, count_down_graph


async def main():
    run_id = 'run_abc123'
    persistence = FileStatePersistence(Path(f'count_down_{run_id}.json'))  # (1)!
    state = CountDownState(counter=5)
    await count_down_graph.initialize(  # (2)!
        CountDown(), state=state, persistence=persistence
    )

    done = False
    while not done:
        done = await run_node(run_id)


async def run_node(run_id: str) -> bool:  # (3)!
    persistence = FileStatePersistence(Path(f'count_down_{run_id}.json'))
    async with count_down_graph.iter_from_persistence(persistence) as run:  # (4)!
        node_or_end = await run.next()  # (5)!

    print('Node:', node_or_end)
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: End(data=0)
    return isinstance(node_or_end, End)  # (6)!
  1. 创建一个 FileStatePersistence 用于启动图。
  2. 调用 graph.initialize() 在持久化对象中设置初始图状态。
  3. run_node 是一个纯函数,除了运行的 ID 外,不需要访问任何其他进程状态来运行图的下一个节点。
  4. 调用 graph.iter_from_persistence() 创建一个 GraphRun 对象,它将根据持久化中存储的状态运行图的下一个节点。这将返回一个节点或一个 End 对象。
  5. graph.run() 将返回一个 节点 或一个 End 对象。
  6. 检查节点是否为 End 对象,如果是,则图的运行完成。

(此示例是完整的,可以“按原样”运行——您需要添加 asyncio.run(main()) 来运行 main

示例:人在回路。

如上所述,状态持久化允许图被中断和恢复。其中一个用例是允许用户输入以继续进行。

在这个例子中,一个 AI 向用户提问,用户提供答案,AI 评估答案,如果用户答对了就结束,如果答错了就再问一个问题。

我们不是在单个进程调用中运行整个图,而是通过重复运行进程来运行图,并可选择地通过命令行参数提供问题的答案。

ai_q_and_a_graph.pyquestion_graph 定义
ai_q_and_a_graph.py
from __future__ import annotations as _annotations

from typing import Annotated
from pydantic_graph import Edge
from dataclasses import dataclass, field
from pydantic import BaseModel
from pydantic_graph import (
    BaseNode,
    End,
    Graph,
    GraphRunContext,
)
from pydantic_ai import Agent, format_as_xml
from pydantic_ai.messages import ModelMessage

ask_agent = Agent('openai:gpt-4o', output_type=str, instrument=True)


@dataclass
class QuestionState:
    question: str | None = None
    ask_agent_messages: list[ModelMessage] = field(default_factory=list)
    evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)


@dataclass
class Ask(BaseNode[QuestionState]):
    """Generate question using GPT-4o."""
    docstring_notes = True
    async def run(
        self, ctx: GraphRunContext[QuestionState]
    ) -> Annotated[Answer, Edge(label='Ask the question')]:
        result = await ask_agent.run(
            'Ask a simple question with a single correct answer.',
            message_history=ctx.state.ask_agent_messages,
        )
        ctx.state.ask_agent_messages += result.new_messages()
        ctx.state.question = result.output
        return Answer(result.output)


@dataclass
class Answer(BaseNode[QuestionState]):
    question: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
        answer = input(f'{self.question}: ')
        return Evaluate(answer)


class EvaluationResult(BaseModel, use_attribute_docstrings=True):
    correct: bool
    """Whether the answer is correct."""
    comment: str
    """Comment on the answer, reprimand the user if the answer is wrong."""


evaluate_agent = Agent(
    'openai:gpt-4o',
    output_type=EvaluationResult,
    system_prompt='Given a question and answer, evaluate if the answer is correct.',
)


@dataclass
class Evaluate(BaseNode[QuestionState, None, str]):
    answer: str

    async def run(
        self,
        ctx: GraphRunContext[QuestionState],
    ) -> Annotated[End[str], Edge(label='success')] | Reprimand:
        assert ctx.state.question is not None
        result = await evaluate_agent.run(
            format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
            message_history=ctx.state.evaluate_agent_messages,
        )
        ctx.state.evaluate_agent_messages += result.new_messages()
        if result.output.correct:
            return End(result.output.comment)
        else:
            return Reprimand(result.output.comment)


@dataclass
class Reprimand(BaseNode[QuestionState]):
    comment: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
        print(f'Comment: {self.comment}')
        ctx.state.question = None
        return Ask()


question_graph = Graph(
    nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
)

(这个例子是完整的,可以“按原样”运行)

ai_q_and_a_run.py
import sys
from pathlib import Path

from pydantic_graph import End
from pydantic_graph.persistence.file import FileStatePersistence
from pydantic_ai.messages import ModelMessage  # noqa: F401

from ai_q_and_a_graph import Ask, question_graph, Evaluate, QuestionState, Answer


async def main():
    answer: str | None = sys.argv[1] if len(sys.argv) > 1 else None  # (1)!
    persistence = FileStatePersistence(Path('question_graph.json'))  # (2)!
    persistence.set_graph_types(question_graph)  # (3)!

    if snapshot := await persistence.load_next():  # (4)!
        state = snapshot.state
        assert answer is not None
        node = Evaluate(answer)
    else:
        state = QuestionState()
        node = Ask()  # (5)!

    async with question_graph.iter(node, state=state, persistence=persistence) as run:
        while True:
            node = await run.next()  # (6)!
            if isinstance(node, End):  # (7)!
                print('END:', node.data)
                history = await persistence.load_all()  # (8)!
                print([e.node for e in history])
                break
            elif isinstance(node, Answer):  # (9)!
                print(node.question)
                #> What is the capital of France?
                break
            # otherwise just continue
  1. 从命令行获取用户的答案(如果提供)。完整的示例请参见问题图示例
  2. 创建一个状态持久化实例,'question_graph.json' 文件可能存在也可能不存在。
  3. 由于我们在图外部使用持久化接口,我们需要调用 set_graph_types 来为持久化实例设置图的泛型类型 StateTRunEndT。这是必需的,以便持久化实例知道如何序列化和反序列化图节点。
  4. 如果我们之前运行过这个图,load_next 将返回下一个要运行的节点的快照,这里我们使用该快照中的 state,并用命令行提供的答案创建一个新的 Evaluate 节点。
  5. 如果图之前没有运行过,我们创建一个新的 QuestionState 并从 Ask 节点开始。
  6. 调用 GraphRun.next() 来运行节点。这将返回一个节点或一个 End 对象。
  7. 如果节点是 End 对象,则图的运行完成。End 对象的 data 字段包含了 evaluate_agent 返回的关于正确答案的评论。
  8. 为了演示状态持久化,我们调用 load_all 从持久化实例中获取所有快照。这将返回一个 Snapshot 对象列表。
  9. 如果节点是 Answer 对象,我们打印问题并跳出循环以结束进程并等待用户输入。

(此示例是完整的,可以“按原样”运行——您需要添加 asyncio.run(main()) 来运行 main

关于此图的完整示例,请参见问题图示例

依赖注入

与 Pydantic AI 一样,pydantic-graph 通过 GraphBaseNode 上的泛型参数以及 GraphRunContext.deps 字段支持依赖注入。

作为依赖注入的一个例子,让我们修改上面DivisibleBy5 例子,使用一个 ProcessPoolExecutor 在一个单独的进程中运行计算负载(这是一个刻意设计的例子,ProcessPoolExecutor 在这个例子中实际上不会提高性能)。

deps_example.py
from __future__ import annotations

import asyncio
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass

from pydantic_graph import BaseNode, End, FullStatePersistence, Graph, GraphRunContext


@dataclass
class GraphDeps:
    executor: ProcessPoolExecutor


@dataclass
class DivisibleBy5(BaseNode[None, GraphDeps, int]):
    foo: int

    async def run(
        self,
        ctx: GraphRunContext[None, GraphDeps],
    ) -> Increment | End[int]:
        if self.foo % 5 == 0:
            return End(self.foo)
        else:
            return Increment(self.foo)


@dataclass
class Increment(BaseNode[None, GraphDeps]):
    foo: int

    async def run(self, ctx: GraphRunContext[None, GraphDeps]) -> DivisibleBy5:
        loop = asyncio.get_running_loop()
        compute_result = await loop.run_in_executor(
            ctx.deps.executor,
            self.compute,
        )
        return DivisibleBy5(compute_result)

    def compute(self) -> int:
        return self.foo + 1


fives_graph = Graph(nodes=[DivisibleBy5, Increment])


async def main():
    with ProcessPoolExecutor() as executor:
        deps = GraphDeps(executor)
        result = await fives_graph.run(DivisibleBy5(3), deps=deps, persistence=FullStatePersistence())
    print(result.output)
    #> 5
    # the full history is quite verbose (see below), so we'll just print the summary
    print([item.node for item in result.persistence.history])
    """
    [
        DivisibleBy5(foo=3),
        Increment(foo=3),
        DivisibleBy5(foo=4),
        Increment(foo=4),
        DivisibleBy5(foo=5),
        End(data=5),
    ]
    """

(此示例是完整的,可以“按原样”运行——您需要添加 asyncio.run(main()) 来运行 main

Mermaid 图表

Pydantic Graph 可以为图生成 MermaidstateDiagram-v2 图表,如上所示。

这些图表可以通过以下方式生成:

除了上面展示的图表,你还可以使用以下选项自定义 Mermaid 图表:

综合起来,我们可以编辑上一个 ai_q_and_a_graph.py 示例来:

  • 为一些边添加标签
  • Ask 节点添加一个注释
  • 高亮显示 Answer 节点
  • 将图表作为 PNG 图像保存到文件
ai_q_and_a_graph_extra.py
from typing import Annotated

from pydantic_graph import BaseNode, End, Graph, GraphRunContext, Edge

ask_agent = Agent('openai:gpt-4o', output_type=str, instrument=True)


@dataclass
class QuestionState:
    question: str | None = None
    ask_agent_messages: list[ModelMessage] = field(default_factory=list)
    evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)


@dataclass
class Ask(BaseNode[QuestionState]):
    """Generate question using GPT-4o."""
    docstring_notes = True
    async def run(
        self, ctx: GraphRunContext[QuestionState]
    ) -> Annotated[Answer, Edge(label='Ask the question')]:
        result = await ask_agent.run(
            'Ask a simple question with a single correct answer.',
            message_history=ctx.state.ask_agent_messages,
        )
        ctx.state.ask_agent_messages += result.new_messages()
        ctx.state.question = result.output
        return Answer(result.output)


@dataclass
class Answer(BaseNode[QuestionState]):
    question: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
        answer = input(f'{self.question}: ')
        return Evaluate(answer)


class EvaluationResult(BaseModel, use_attribute_docstrings=True):
    correct: bool
    """Whether the answer is correct."""
    comment: str
    """Comment on the answer, reprimand the user if the answer is wrong."""


evaluate_agent = Agent(
    'openai:gpt-4o',
    output_type=EvaluationResult,
    system_prompt='Given a question and answer, evaluate if the answer is correct.',
)


@dataclass
class Evaluate(BaseNode[QuestionState, None, str]):
    answer: str

    async def run(
        self,
        ctx: GraphRunContext[QuestionState],
    ) -> Annotated[End[str], Edge(label='success')] | Reprimand:
        assert ctx.state.question is not None
        result = await evaluate_agent.run(
            format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
            message_history=ctx.state.evaluate_agent_messages,
        )
        ctx.state.evaluate_agent_messages += result.new_messages()
        if result.output.correct:
            return End(result.output.comment)
        else:
            return Reprimand(result.output.comment)


@dataclass
class Reprimand(BaseNode[QuestionState]):
    comment: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
        print(f'Comment: {self.comment}')
        ctx.state.question = None
        return Ask()


question_graph = Graph(
    nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
)

(此示例不完整,不能直接运行)

这将生成一个看起来像这样的图像

---
title: question_graph
---
stateDiagram-v2
  Ask --> Answer: Ask the question
  note right of Ask
    Judge the answer.
    Decide on next step.
  end note
  Answer --> Evaluate
  Evaluate --> Reprimand
  Evaluate --> [*]: success
  Reprimand --> Ask

classDef highlighted fill:#fdff32
class Answer highlighted

设置状态图的方向

你可以使用以下值之一来指定状态图的方向:

  • 'TB': 从上到下,图表从上到下垂直流动。
  • 'LR': 从左到右,图表从左到右水平流动。
  • 'RL': 从右到左,图表从右到左水平流动。
  • 'BT': 从下到上,图表从下到上垂直流动。

下面是一个如何使用‘从左到右’(LR)而不是默认的‘从上到下’(TB)的示例:

vending_machine_diagram.py
from vending_machine import InsertCoin, vending_machine_graph

vending_machine_graph.mermaid_code(start_node=InsertCoin, direction='LR')
---
title: vending_machine_graph
---
stateDiagram-v2
  direction LR
  [*] --> InsertCoin
  InsertCoin --> CoinsInserted
  CoinsInserted --> SelectProduct
  CoinsInserted --> Purchase
  SelectProduct --> Purchase
  Purchase --> InsertCoin
  Purchase --> SelectProduct
  Purchase --> [*]