跳转到内容

除非你需要钉枪,否则别用钉枪

如果说 Pydantic AI 智能体是锤子,多智能体工作流是大锤,那么图就是钉枪

  • 当然,钉枪看起来比锤子酷
  • 但钉枪的设置比锤子复杂得多
  • 而且钉枪不会让你成为更好的建造者,只会让你成为一个有钉枪的建造者
  • 最后,(冒着滥用这个比喻的风险),如果你喜欢像木槌和无类型 Python 这样的中世纪工具,你可能不会喜欢钉枪或我们的图方法。(但话说回来,如果你不喜欢 Python 中的类型提示,你可能已经从 Pydantic AI 转向使用那些玩具般的智能体框架了——祝你好运,当你意识到需要它时,随时可以借用我的大锤)

简而言之,图是一个强大的工具,但并非适用于所有工作。在继续之前,请考虑其他多智能体方法

如果你不确定基于图的方法是个好主意,那么它可能是不必要的。

图和有限状态机 (FSM) 是一个强大的抽象,用于建模、执行、控制和可视化复杂的工作流。

除了 Pydantic AI,我们还开发了 pydantic-graph——一个用于 Python 的异步图和状态机库,其中的节点和边是使用类型提示定义的。

虽然这个库是作为 Pydantic AI 的一部分开发的,但它不依赖于 pydantic-ai,可以被视为一个纯粹的基于图的状态机库。无论你是否在使用 Pydantic AI,甚至是否在用 GenAI 构建应用,你都可能会发现它很有用。

pydantic-graph 是为高级用户设计的,大量使用了 Python 的泛型和类型提示。它的设计不像 Pydantic AI 那样对初学者友好。

安装

pydantic-graphpydantic-ai 的必需依赖项,也是 pydantic-ai-slim 的可选依赖项,更多信息请参见安装说明。你也可以直接安装它

pip install pydantic-graph
uv add pydantic-graph

图类型

pydantic-graph 由几个关键组件组成

GraphRunContext

GraphRunContext — 图运行的上下文,类似于 Pydantic AI 的 RunContext。它持有图的状态和依赖项,并在节点运行时传递给它们。

GraphRunContext 在其所使用的图的状态类型 StateT 上是泛型的。

结束

End — 表示图运行应结束的返回值。

End 在其所使用的图的图返回类型 RunEndT 上是泛型的。

节点

BaseNode 的子类定义了在图中执行的节点。

节点通常是dataclasses,通常由以下部分组成

  • 包含调用节点时所需/可选的任何参数的字段
  • run 方法中执行节点的业务逻辑
  • run 方法的返回注解,pydantic-graph 会读取这些注解来确定节点的出边

节点在以下方面是泛型的

  • state,它必须与它们所属图的状态类型相同,StateT 的默认值为 None,所以如果你不使用状态,可以省略这个泛型参数,更多信息请参见有状态图
  • deps,它必须与它们所属图的依赖项类型相同,DepsT 的默认值为 None,所以如果你不使用依赖项,可以省略这个泛型参数,更多信息请参见依赖注入
  • 图返回类型 — 这仅在节点返回 End 时适用。RunEndT 的默认值为 Never,因此如果节点不返回 End,可以省略这个泛型参数,但如果返回 End,则必须包含它。

这是一个图中开始或中间节点的示例——它不能结束运行,因为它不返回 End

intermediate_node.py
from dataclasses import dataclass

from pydantic_graph import BaseNode, GraphRunContext


@dataclass
class MyNode(BaseNode[MyState]):  # (1)!
    foo: int  # (2)!

    async def run(
        self,
        ctx: GraphRunContext[MyState],  # (3)!
    ) -> AnotherNode:  # (4)!
        ...
        return AnotherNode()
  1. 本例中的状态是 MyState(未显示),因此 BaseNode 使用 MyState 进行参数化。此节点无法结束运行,所以 RunEndT 泛型参数被省略,并默认为 Never
  2. MyNode 是一个数据类,只有一个字段 foo,类型为 int
  3. run 方法接受一个 GraphRunContext 参数,同样使用状态 MyState 进行参数化。
  4. run 方法的返回类型是 AnotherNode(未显示),这用于确定节点的出边。

我们可以扩展 MyNode,使其在 foo 能被 5 整除时可选地结束运行

intermediate_or_end_node.py
from dataclasses import dataclass

from pydantic_graph import BaseNode, End, GraphRunContext


@dataclass
class MyNode(BaseNode[MyState, None, int]):  # (1)!
    foo: int

    async def run(
        self,
        ctx: GraphRunContext[MyState],
    ) -> AnotherNode | End[int]:  # (2)!
        if self.foo % 5 == 0:
            return End(self.foo)
        else:
            return AnotherNode()
  1. 我们用返回类型(本例中为 int)和状态来参数化节点。因为泛型参数是仅位置的,我们必须包含 None 作为代表依赖项的第二个参数。
  2. run 方法的返回类型现在是 AnotherNodeEnd[int] 的联合类型,这使得节点在 foo 能被 5 整除时可以结束运行。

Graph — 这是执行图本身,由一组节点类(即 BaseNode 子类)组成。

Graph 在以下方面是泛型的

  • state 图的状态类型,StateT
  • deps 图的依赖项类型,DepsT
  • 图返回类型 图运行的返回类型,RunEndT

这是一个简单图的示例

graph_example.py
from __future__ import annotations

from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class DivisibleBy5(BaseNode[None, None, int]):  # (1)!
    foo: int

    async def run(
        self,
        ctx: GraphRunContext,
    ) -> Increment | End[int]:
        if self.foo % 5 == 0:
            return End(self.foo)
        else:
            return Increment(self.foo)


@dataclass
class Increment(BaseNode):  # (2)!
    foo: int

    async def run(self, ctx: GraphRunContext) -> DivisibleBy5:
        return DivisibleBy5(self.foo + 1)


fives_graph = Graph(nodes=[DivisibleBy5, Increment])  # (3)!
result = fives_graph.run_sync(DivisibleBy5(4))  # (4)!
print(result.output)
#> 5
  1. DivisibleBy5 节点用 None 作为状态参数,用 None 作为依赖项参数,因为这个图不使用状态或依赖项,用 int 作为返回类型,因为它可以结束运行。
  2. Increment 节点不返回 End,所以 RunEndT 泛型参数被省略,状态也可以被省略,因为该图不使用状态。
  3. 图是使用一系列节点创建的。
  4. 图使用 run_sync 同步运行。初始节点是 DivisibleBy5(4)。因为图不使用外部状态或依赖项,我们不传递 statedeps

(此示例是完整的,可以在 Python 3.10+ 环境下“按原样”运行)

可以使用以下代码为此图生成一个 mermaid 图

graph_example_diagram.py
from graph_example import DivisibleBy5, fives_graph

fives_graph.mermaid_code(start_node=DivisibleBy5)
---
title: fives_graph
---
stateDiagram-v2
  [*] --> DivisibleBy5
  DivisibleBy5 --> Increment
  DivisibleBy5 --> [*]
  Increment --> DivisibleBy5

为了在 jupyter-notebook 中可视化一个图,需要使用 IPython.display

jupyter_display_mermaid.py
from graph_example import DivisibleBy5, fives_graph
from IPython.display import Image, display

display(Image(fives_graph.mermaid_image(start_node=DivisibleBy5)))

有状态图

pydantic-graph 中的“状态”概念提供了一种可选的方式,用于在图中节点运行时访问和修改一个对象(通常是 dataclass 或 Pydantic 模型)。如果你把图想象成一条生产线,那么你的状态就是沿着生产线传递并由每个节点在图运行时构建起来的引擎。

pydantic-graph 提供状态持久化,在每个节点运行后记录状态。(参见 状态持久化。)

这是一个表示自动售货机的图的示例,用户可以投入硬币并选择要购买的产品。

vending_machine.py
from __future__ import annotations

from dataclasses import dataclass

from rich.prompt import Prompt

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class MachineState:  # (1)!
    user_balance: float = 0.0
    product: str | None = None


@dataclass
class InsertCoin(BaseNode[MachineState]):  # (3)!
    async def run(self, ctx: GraphRunContext[MachineState]) -> CoinsInserted:  # (16)!
        return CoinsInserted(float(Prompt.ask('Insert coins')))  # (4)!


@dataclass
class CoinsInserted(BaseNode[MachineState]):
    amount: float  # (5)!

    async def run(
        self, ctx: GraphRunContext[MachineState]
    ) -> SelectProduct | Purchase:  # (17)!
        ctx.state.user_balance += self.amount  # (6)!
        if ctx.state.product is not None:  # (7)!
            return Purchase(ctx.state.product)
        else:
            return SelectProduct()


@dataclass
class SelectProduct(BaseNode[MachineState]):
    async def run(self, ctx: GraphRunContext[MachineState]) -> Purchase:
        return Purchase(Prompt.ask('Select product'))


PRODUCT_PRICES = {  # (2)!
    'water': 1.25,
    'soda': 1.50,
    'crisps': 1.75,
    'chocolate': 2.00,
}


@dataclass
class Purchase(BaseNode[MachineState, None, None]):  # (18)!
    product: str

    async def run(
        self, ctx: GraphRunContext[MachineState]
    ) -> End | InsertCoin | SelectProduct:
        if price := PRODUCT_PRICES.get(self.product):  # (8)!
            ctx.state.product = self.product  # (9)!
            if ctx.state.user_balance >= price:  # (10)!
                ctx.state.user_balance -= price
                return End(None)
            else:
                diff = price - ctx.state.user_balance
                print(f'Not enough money for {self.product}, need {diff:0.2f} more')
                #> Not enough money for crisps, need 0.75 more
                return InsertCoin()  # (11)!
        else:
            print(f'No such product: {self.product}, try again')
            return SelectProduct()  # (12)!


vending_machine_graph = Graph(  # (13)!
    nodes=[InsertCoin, CoinsInserted, SelectProduct, Purchase]
)


async def main():
    state = MachineState()  # (14)!
    await vending_machine_graph.run(InsertCoin(), state=state)  # (15)!
    print(f'purchase successful item={state.product} change={state.user_balance:0.2f}')
    #> purchase successful item=crisps change=0.25
  1. 自动售货机的状态被定义为一个数据类,包含用户的余额和他们已选择的产品(如果有的话)。
  2. 一个将产品映射到价格的字典。
  3. InsertCoin 节点中,BaseNode 使用 MachineState 进行参数化,因为这是该图中使用的状态。
  4. InsertCoin 节点提示用户投入硬币。我们通过只输入一个浮点数金额来保持简单。在你开始认为这也是个玩具之前,因为它在节点内使用了 rich 的 Prompt.ask,请参见下文,了解当节点需要外部输入时如何管理控制流。
  5. CoinsInserted 节点;同样,这是一个只有一个字段 amountdataclass
  6. 用投入的金额更新用户余额。
  7. 如果用户已经选择了产品,则转到 Purchase,否则转到 SelectProduct
  8. Purchase 节点中,如果用户输入了有效产品,则查找该产品的价格。
  9. 如果用户确实输入了有效产品,则在状态中设置产品,这样我们就不会重新访问 SelectProduct
  10. 如果余额足以购买产品,则调整余额以反映购买,并返回 End 以结束图。我们不使用运行返回类型,所以我们用 None 调用 End
  11. 如果余额不足,则转到 InsertCoin 提示用户投入更多硬币。
  12. 如果产品无效,则转到 SelectProduct 提示用户再次选择产品。
  13. 图是通过将节点列表传递给 Graph 来创建的。节点的顺序不重要,但它会影响图表的显示方式。
  14. 初始化状态。这将被传递给图运行,并在图运行时被修改。
  15. 使用初始状态运行图。由于图可以从任何节点开始运行,我们必须传递起始节点——在本例中是 InsertCoinGraph.run 返回一个 GraphRunResult,它提供最终数据和运行历史。
  16. 节点的 run 方法的返回类型很重要,因为它用于确定节点的出边。这些信息反过来又用于渲染Mermaid 图,并在运行时强制执行,以便尽早发现不当行为。
  17. CoinsInsertedrun 方法的返回类型是一个联合类型,意味着可能有多个出边。
  18. 与其他节点不同,Purchase 可以结束运行,因此必须设置 RunEndT 泛型参数。在这种情况下,它是 None,因为图运行的返回类型是 None

(此示例是完整的,可以在 Python 3.10+ 环境下“按原样”运行——你需要添加 asyncio.run(main()) 来运行 main

可以使用以下代码为此图生成一个 mermaid 图

vending_machine_diagram.py
from vending_machine import InsertCoin, vending_machine_graph

vending_machine_graph.mermaid_code(start_node=InsertCoin)

上述代码生成的图表是

---
title: vending_machine_graph
---
stateDiagram-v2
  [*] --> InsertCoin
  InsertCoin --> CoinsInserted
  CoinsInserted --> SelectProduct
  CoinsInserted --> Purchase
  SelectProduct --> Purchase
  Purchase --> InsertCoin
  Purchase --> SelectProduct
  Purchase --> [*]

有关生成图表的更多信息,请参见下文

GenAI 示例

到目前为止,我们还没有展示一个实际使用 Pydantic AI 或 GenAI 的图的示例。

在这个例子中,一个智能体为用户生成欢迎邮件,另一个智能体对邮件提供反馈。

这个图的结构非常简单

---
title: feedback_graph
---
stateDiagram-v2
  [*] --> WriteEmail
  WriteEmail --> Feedback
  Feedback --> WriteEmail
  Feedback --> [*]
genai_email_feedback.py
from __future__ import annotations as _annotations

from dataclasses import dataclass, field

from pydantic import BaseModel, EmailStr

from pydantic_ai import Agent, format_as_xml
from pydantic_ai.messages import ModelMessage
from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class User:
    name: str
    email: EmailStr
    interests: list[str]


@dataclass
class Email:
    subject: str
    body: str


@dataclass
class State:
    user: User
    write_agent_messages: list[ModelMessage] = field(default_factory=list)


email_writer_agent = Agent(
    'google-vertex:gemini-1.5-pro',
    output_type=Email,
    system_prompt='Write a welcome email to our tech blog.',
)


@dataclass
class WriteEmail(BaseNode[State]):
    email_feedback: str | None = None

    async def run(self, ctx: GraphRunContext[State]) -> Feedback:
        if self.email_feedback:
            prompt = (
                f'Rewrite the email for the user:\n'
                f'{format_as_xml(ctx.state.user)}\n'
                f'Feedback: {self.email_feedback}'
            )
        else:
            prompt = (
                f'Write a welcome email for the user:\n'
                f'{format_as_xml(ctx.state.user)}'
            )

        result = await email_writer_agent.run(
            prompt,
            message_history=ctx.state.write_agent_messages,
        )
        ctx.state.write_agent_messages += result.new_messages()
        return Feedback(result.output)


class EmailRequiresWrite(BaseModel):
    feedback: str


class EmailOk(BaseModel):
    pass


feedback_agent = Agent[None, EmailRequiresWrite | EmailOk](
    'openai:gpt-4o',
    output_type=EmailRequiresWrite | EmailOk,  # type: ignore
    system_prompt=(
        'Review the email and provide feedback, email must reference the users specific interests.'
    ),
)


@dataclass
class Feedback(BaseNode[State, None, Email]):
    email: Email

    async def run(
        self,
        ctx: GraphRunContext[State],
    ) -> WriteEmail | End[Email]:
        prompt = format_as_xml({'user': ctx.state.user, 'email': self.email})
        result = await feedback_agent.run(prompt)
        if isinstance(result.output, EmailRequiresWrite):
            return WriteEmail(email_feedback=result.output.feedback)
        else:
            return End(self.email)


async def main():
    user = User(
        name='John Doe',
        email='john.joe@example.com',
        interests=['Haskel', 'Lisp', 'Fortran'],
    )
    state = State(user)
    feedback_graph = Graph(nodes=(WriteEmail, Feedback))
    result = await feedback_graph.run(WriteEmail(), state=state)
    print(result.output)
    """
    Email(
        subject='Welcome to our tech blog!',
        body='Hello John, Welcome to our tech blog! ...',
    )
    """

(此示例是完整的,可以在 Python 3.10+ 环境下“按原样”运行——你需要添加 asyncio.run(main()) 来运行 main

遍历图

使用 Graph.iter 进行 async for 迭代

有时您希望在图执行时直接控制或了解每个节点。最简单的方法是使用 Graph.iter 方法,它返回一个**上下文管理器**,该管理器产生一个 GraphRun 对象。GraphRun 是一个针对图节点的异步可迭代对象,允许您在它们执行时记录或修改它们。

这是一个例子

count_down.py
from __future__ import annotations as _annotations

from dataclasses import dataclass
from pydantic_graph import Graph, BaseNode, End, GraphRunContext


@dataclass
class CountDownState:
    counter: int


@dataclass
class CountDown(BaseNode[CountDownState, None, int]):
    async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]:
        if ctx.state.counter <= 0:
            return End(ctx.state.counter)
        ctx.state.counter -= 1
        return CountDown()


count_down_graph = Graph(nodes=[CountDown])


async def main():
    state = CountDownState(counter=3)
    async with count_down_graph.iter(CountDown(), state=state) as run:  # (1)!
        async for node in run:  # (2)!
            print('Node:', node)
            #> Node: CountDown()
            #> Node: CountDown()
            #> Node: CountDown()
            #> Node: CountDown()
            #> Node: End(data=0)
    print('Final output:', run.result.output)  # (3)!
    #> Final output: 0
  1. Graph.iter(...) 返回一个 GraphRun
  2. 在这里,我们在每个节点执行时逐步遍历。
  3. 一旦图返回一个 End,循环结束,run.result 变成一个 GraphRunResult,其中包含最终结果(此处为 0)。

手动使用 GraphRun.next(node)

或者,您可以使用 GraphRun.next 方法手动驱动迭代,这允许您传入任何您想下一步运行的节点。您可以通过这种方式修改或选择性地跳过节点。

下面是一个精心设计的示例,当计数器为 2 时停止,忽略此后的任何节点运行

count_down_next.py
from pydantic_graph import End, FullStatePersistence
from count_down import CountDown, CountDownState, count_down_graph


async def main():
    state = CountDownState(counter=5)
    persistence = FullStatePersistence()  # (7)!
    async with count_down_graph.iter(
        CountDown(), state=state, persistence=persistence
    ) as run:
        node = run.next_node  # (1)!
        while not isinstance(node, End):  # (2)!
            print('Node:', node)
            #> Node: CountDown()
            #> Node: CountDown()
            #> Node: CountDown()
            #> Node: CountDown()
            if state.counter == 2:
                break  # (3)!
            node = await run.next(node)  # (4)!

        print(run.result)  # (5)!
        #> None

        for step in persistence.history:  # (6)!
            print('History Step:', step.state, step.state)
            #> History Step: CountDownState(counter=5) CountDownState(counter=5)
            #> History Step: CountDownState(counter=4) CountDownState(counter=4)
            #> History Step: CountDownState(counter=3) CountDownState(counter=3)
            #> History Step: CountDownState(counter=2) CountDownState(counter=2)
  1. 我们首先获取智能体图中将要运行的第一个节点。
  2. 一旦生成了 `End` 节点,智能体运行就完成了;`End` 的实例不能传递给 `next`。
  3. 如果用户决定提前停止,我们跳出循环。在这种情况下,图运行将没有真正的最终结果(`run.result` 保持为 `None`)。
  4. 在每一步中,我们调用 `await run.next(node)` 来运行它并获取下一个节点(或一个 `End`)。
  5. 因为我们没有继续运行直到完成,所以 `result` 没有被设置。
  6. 运行的历史记录仍然填充了我们到目前为止执行的步骤。
  7. 使用 FullStatePersistence 以便我们能显示运行的历史记录,更多信息请参见下文的状态持久化

状态持久化

有限状态机 (FSM) 图的最大好处之一是它们如何简化中断执行的处理。这可能由于多种原因发生

  • 状态机逻辑可能根本上需要暂停——例如,电子商务订单的退货工作流需要等待物品被邮寄到退货中心,或者因为下一个节点的执行需要用户的输入,因此需要等待新的 http 请求,
  • 执行时间太长,以至于整个图无法可靠地在一次连续运行中执行——例如,一个可能需要数小时运行的深度研究智能体,
  • 您希望在不同的进程/硬件实例中并行运行多个图节点(注意:`pydantic-graph` 尚不支持并行节点执行,请参见 #704)。

试图使传统的控制流(即布尔逻辑和嵌套函数调用)实现与这些使用场景兼容,通常会导致脆弱且过于复杂的意大利面条式代码,其中中断和恢复执行所需的逻辑主导了实现。

为了允许图运行被中断和恢复,pydantic-graph 提供了状态持久化——一个在每个节点运行前后对图运行状态进行快照的系统,允许从图中的任何一点恢复图运行。

pydantic-graph 包括三种状态持久化实现

  • SimpleStatePersistence — 简单的内存中状态持久化,仅保存最新的快照。如果在运行图时未提供状态持久化实现,则默认使用此实现。
  • FullStatePersistence — 内存中状态持久化,保存一个快照列表。
  • FileStatePersistence — 基于文件的状态持久化,将快照保存到 JSON 文件。

在生产应用程序中,开发人员应通过子类化 BaseStatePersistence 抽象基类来实现自己的状态持久化,这可能会将运行持久化到像 PostgresQL 这样的关系数据库中。

从高层次来看,StatePersistence 实现的角色是存储和检索 NodeSnapshotEndSnapshot 对象。

graph.iter_from_persistence() 可用于基于持久化中存储的状态来运行图。

我们可以运行上面count_down_graph,使用 graph.iter_from_persistence()FileStatePersistence

正如您在这段代码中看到的,run_node 的运行不需要任何外部应用程序状态(除了状态持久化),这意味着图可以轻松地由分布式执行和排队系统执行。

count_down_from_persistence.py
from pathlib import Path

from pydantic_graph import End
from pydantic_graph.persistence.file import FileStatePersistence

from count_down import CountDown, CountDownState, count_down_graph


async def main():
    run_id = 'run_abc123'
    persistence = FileStatePersistence(Path(f'count_down_{run_id}.json'))  # (1)!
    state = CountDownState(counter=5)
    await count_down_graph.initialize(  # (2)!
        CountDown(), state=state, persistence=persistence
    )

    done = False
    while not done:
        done = await run_node(run_id)


async def run_node(run_id: str) -> bool:  # (3)!
    persistence = FileStatePersistence(Path(f'count_down_{run_id}.json'))
    async with count_down_graph.iter_from_persistence(persistence) as run:  # (4)!
        node_or_end = await run.next()  # (5)!

    print('Node:', node_or_end)
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: End(data=0)
    return isinstance(node_or_end, End)  # (6)!
  1. 创建一个 FileStatePersistence 来启动图。
  2. 调用 graph.initialize() 在持久化对象中设置初始图状态。
  3. run_node 是一个纯函数,除了运行的 ID 外,不需要访问任何其他进程状态来运行图的下一个节点。
  4. 调用 graph.iter_from_persistence() 创建一个 GraphRun 对象,它将从持久化中存储的状态运行图的下一个节点。这将返回一个节点或一个 End 对象。
  5. graph.run() 将返回一个 节点 或一个 End 对象。
  6. 检查节点是否是 End 对象,如果是,则图运行完成。

(此示例是完整的,可以在 Python 3.10+ 环境下“按原样”运行——你需要添加 asyncio.run(main()) 来运行 main

示例:人机交互。

如上所述,状态持久化允许图被中断和恢复。这种用法之一是允许用户输入以继续。

在这个例子中,一个 AI 向用户提问,用户提供答案,AI 评估答案,如果用户答对了就结束,如果答错了就再问一个问题。

我们不是在单个进程调用中运行整个图,而是通过重复运行进程来运行图,可选地通过命令行参数提供问题的答案。

ai_q_and_a_graph.pyquestion_graph 定义
ai_q_and_a_graph.py
from __future__ import annotations as _annotations

from typing import Annotated
from pydantic_graph import Edge
from dataclasses import dataclass, field
from pydantic import BaseModel
from pydantic_graph import (
    BaseNode,
    End,
    Graph,
    GraphRunContext,
)
from pydantic_ai import Agent, format_as_xml
from pydantic_ai.messages import ModelMessage

ask_agent = Agent('openai:gpt-4o', output_type=str, instrument=True)


@dataclass
class QuestionState:
    question: str | None = None
    ask_agent_messages: list[ModelMessage] = field(default_factory=list)
    evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)


@dataclass
class Ask(BaseNode[QuestionState]):
    """Generate question using GPT-4o."""
    docstring_notes = True
    async def run(
        self, ctx: GraphRunContext[QuestionState]
    ) -> Annotated[Answer, Edge(label='Ask the question')]:
        result = await ask_agent.run(
            'Ask a simple question with a single correct answer.',
            message_history=ctx.state.ask_agent_messages,
        )
        ctx.state.ask_agent_messages += result.new_messages()
        ctx.state.question = result.output
        return Answer(result.output)


@dataclass
class Answer(BaseNode[QuestionState]):
    question: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
        answer = input(f'{self.question}: ')
        return Evaluate(answer)


class EvaluationResult(BaseModel, use_attribute_docstrings=True):
    correct: bool
    """Whether the answer is correct."""
    comment: str
    """Comment on the answer, reprimand the user if the answer is wrong."""


evaluate_agent = Agent(
    'openai:gpt-4o',
    output_type=EvaluationResult,
    system_prompt='Given a question and answer, evaluate if the answer is correct.',
)


@dataclass
class Evaluate(BaseNode[QuestionState, None, str]):
    answer: str

    async def run(
        self,
        ctx: GraphRunContext[QuestionState],
    ) -> Annotated[End[str], Edge(label='success')] | Reprimand:
        assert ctx.state.question is not None
        result = await evaluate_agent.run(
            format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
            message_history=ctx.state.evaluate_agent_messages,
        )
        ctx.state.evaluate_agent_messages += result.new_messages()
        if result.output.correct:
            return End(result.output.comment)
        else:
            return Reprimand(result.output.comment)


@dataclass
class Reprimand(BaseNode[QuestionState]):
    comment: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
        print(f'Comment: {self.comment}')
        ctx.state.question = None
        return Ask()


question_graph = Graph(
    nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
)

(此示例是完整的,可以在 Python 3.10+ 环境下“按原样”运行)

ai_q_and_a_run.py
import sys
from pathlib import Path

from pydantic_graph import End
from pydantic_graph.persistence.file import FileStatePersistence
from pydantic_ai.messages import ModelMessage  # noqa: F401

from ai_q_and_a_graph import Ask, question_graph, Evaluate, QuestionState, Answer


async def main():
    answer: str | None = sys.argv[1] if len(sys.argv) > 1 else None  # (1)!
    persistence = FileStatePersistence(Path('question_graph.json'))  # (2)!
    persistence.set_graph_types(question_graph)  # (3)!

    if snapshot := await persistence.load_next():  # (4)!
        state = snapshot.state
        assert answer is not None
        node = Evaluate(answer)
    else:
        state = QuestionState()
        node = Ask()  # (5)!

    async with question_graph.iter(node, state=state, persistence=persistence) as run:
        while True:
            node = await run.next()  # (6)!
            if isinstance(node, End):  # (7)!
                print('END:', node.data)
                history = await persistence.load_all()  # (8)!
                print([e.node for e in history])
                break
            elif isinstance(node, Answer):  # (9)!
                print(node.question)
                #> What is the capital of France?
                break
            # otherwise just continue
  1. 如果提供,从命令行获取用户的答案。有关完整示例,请参见问题图示例
  2. 创建一个状态持久化实例,'question_graph.json' 文件可能已经存在,也可能不存在。
  3. 由于我们在图外部使用持久化接口,我们需要调用 set_graph_types 来为持久化实例设置图的泛型类型 StateTRunEndT。这是必要的,以便持久化实例知道如何序列化和反序列化图节点。
  4. 如果我们之前运行过图,load_next 将返回下一个要运行的节点的快照,这里我们使用该快照中的 state,并用命令行提供的答案创建一个新的 Evaluate 节点。
  5. 如果图之前没有运行过,我们创建一个新的 QuestionState 并从 Ask 节点开始。
  6. 调用 GraphRun.next() 来运行节点。这将返回一个节点或一个 End 对象。
  7. 如果节点是 `End` 对象,则图运行完成。`End` 对象的 `data` 字段包含 `evaluate_agent` 返回的关于正确答案的评论。
  8. 为了演示状态持久化,我们调用 load_all 从持久化实例中获取所有快照。这将返回一个 Snapshot 对象列表。
  9. 如果节点是 `Answer` 对象,我们打印问题并跳出循环以结束进程并等待用户输入。

(此示例是完整的,可以在 Python 3.10+ 环境下“按原样”运行——你需要添加 asyncio.run(main()) 来运行 main

有关此图的完整示例,请参见问题图示例

依赖注入

与 Pydantic AI 一样,pydantic-graph 通过 GraphBaseNode 上的泛型参数以及 GraphRunContext.deps 字段支持依赖注入。

作为依赖注入的示例,让我们修改上面DivisibleBy5 示例,使用 ProcessPoolExecutor 在单独的进程中运行计算负载(这是一个刻意设计的示例,ProcessPoolExecutor 在这个示例中实际上不会提高性能)

deps_example.py
from __future__ import annotations

import asyncio
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass

from pydantic_graph import BaseNode, End, FullStatePersistence, Graph, GraphRunContext


@dataclass
class GraphDeps:
    executor: ProcessPoolExecutor


@dataclass
class DivisibleBy5(BaseNode[None, GraphDeps, int]):
    foo: int

    async def run(
        self,
        ctx: GraphRunContext[None, GraphDeps],
    ) -> Increment | End[int]:
        if self.foo % 5 == 0:
            return End(self.foo)
        else:
            return Increment(self.foo)


@dataclass
class Increment(BaseNode[None, GraphDeps]):
    foo: int

    async def run(self, ctx: GraphRunContext[None, GraphDeps]) -> DivisibleBy5:
        loop = asyncio.get_running_loop()
        compute_result = await loop.run_in_executor(
            ctx.deps.executor,
            self.compute,
        )
        return DivisibleBy5(compute_result)

    def compute(self) -> int:
        return self.foo + 1


fives_graph = Graph(nodes=[DivisibleBy5, Increment])


async def main():
    with ProcessPoolExecutor() as executor:
        deps = GraphDeps(executor)
        result = await fives_graph.run(DivisibleBy5(3), deps=deps, persistence=FullStatePersistence())
    print(result.output)
    #> 5
    # the full history is quite verbose (see below), so we'll just print the summary
    print([item.node for item in result.persistence.history])
    """
    [
        DivisibleBy5(foo=3),
        Increment(foo=3),
        DivisibleBy5(foo=4),
        Increment(foo=4),
        DivisibleBy5(foo=5),
        End(data=5),
    ]
    """

(此示例是完整的,可以在 Python 3.10+ 环境下“按原样”运行——你需要添加 asyncio.run(main()) 来运行 main

Mermaid 图

Pydantic Graph 可以为图生成 mermaid stateDiagram-v2 图,如上所示。

这些图表可以通过以下方式生成

除了上面显示的图表,您还可以使用以下选项自定义 mermaid 图表

综上所述,我们可以编辑最后一个ai_q_and_a_graph.py示例来

  • 为一些边添加标签
  • Ask 节点添加注释
  • 高亮显示 Answer 节点
  • 将图表另存为 PNG 图像文件
ai_q_and_a_graph_extra.py
from typing import Annotated

from pydantic_graph import BaseNode, End, Graph, GraphRunContext, Edge

ask_agent = Agent('openai:gpt-4o', output_type=str, instrument=True)


@dataclass
class QuestionState:
    question: str | None = None
    ask_agent_messages: list[ModelMessage] = field(default_factory=list)
    evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)


@dataclass
class Ask(BaseNode[QuestionState]):
    """Generate question using GPT-4o."""
    docstring_notes = True
    async def run(
        self, ctx: GraphRunContext[QuestionState]
    ) -> Annotated[Answer, Edge(label='Ask the question')]:
        result = await ask_agent.run(
            'Ask a simple question with a single correct answer.',
            message_history=ctx.state.ask_agent_messages,
        )
        ctx.state.ask_agent_messages += result.new_messages()
        ctx.state.question = result.output
        return Answer(result.output)


@dataclass
class Answer(BaseNode[QuestionState]):
    question: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
        answer = input(f'{self.question}: ')
        return Evaluate(answer)


class EvaluationResult(BaseModel, use_attribute_docstrings=True):
    correct: bool
    """Whether the answer is correct."""
    comment: str
    """Comment on the answer, reprimand the user if the answer is wrong."""


evaluate_agent = Agent(
    'openai:gpt-4o',
    output_type=EvaluationResult,
    system_prompt='Given a question and answer, evaluate if the answer is correct.',
)


@dataclass
class Evaluate(BaseNode[QuestionState, None, str]):
    answer: str

    async def run(
        self,
        ctx: GraphRunContext[QuestionState],
    ) -> Annotated[End[str], Edge(label='success')] | Reprimand:
        assert ctx.state.question is not None
        result = await evaluate_agent.run(
            format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
            message_history=ctx.state.evaluate_agent_messages,
        )
        ctx.state.evaluate_agent_messages += result.new_messages()
        if result.output.correct:
            return End(result.output.comment)
        else:
            return Reprimand(result.output.comment)


@dataclass
class Reprimand(BaseNode[QuestionState]):
    comment: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
        print(f'Comment: {self.comment}')
        ctx.state.question = None
        return Ask()


question_graph = Graph(
    nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
)

(此示例不完整,不能直接运行)

这将生成一个如下所示的图像

---
title: question_graph
---
stateDiagram-v2
  Ask --> Answer: Ask the question
  note right of Ask
    Judge the answer.
    Decide on next step.
  end note
  Answer --> Evaluate
  Evaluate --> Reprimand
  Evaluate --> [*]: success
  Reprimand --> Ask

classDef highlighted fill:#fdff32
class Answer highlighted

设置状态图的方向

您可以使用以下值之一来指定状态图的方向

  • 'TB':从上到下,图表从上到下垂直流动。
  • 'LR':从左到右,图表从左到右水平流动。
  • 'RL':从右到左,图表从右到左水平流动。
  • 'BT':从下到上,图表从下到上垂直流动。

以下是如何使用“从左到右”(LR)而不是默认的“从上到下”(TB)的示例

vending_machine_diagram.py
from vending_machine import InsertCoin, vending_machine_graph

vending_machine_graph.mermaid_code(start_node=InsertCoin, direction='LR')
---
title: vending_machine_graph
---
stateDiagram-v2
  direction LR
  [*] --> InsertCoin
  InsertCoin --> CoinsInserted
  CoinsInserted --> SelectProduct
  CoinsInserted --> Purchase
  SelectProduct --> Purchase
  Purchase --> InsertCoin
  Purchase --> SelectProduct
  Purchase --> [*]