
pydantic_graph.persistence

SnapshotStatus module-attribute

SnapshotStatus = Literal[
    "created", "pending", "running", "success", "error"
]

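The status of a snapshot.

  • 'created': The snapshot has been created but not yet run.
  • 'pending': The snapshot has been retrieved with load_next but not yet run.
  • 'running': The snapshot is currently running.
  • 'success': The snapshot has been run successfully.
  • 'error': The snapshot has been run but an error occurred.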
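Read together, these statuses describe a lifecycle for a node snapshot. The sketch below encodes that lifecycle as a transition table so the allowed moves are explicit. The `TRANSITIONS` table and the `advance` helper are hypothetical illustrations inferred from the status descriptions and from the record_run rule (only a 'created' or 'pending' snapshot may start running); they are not part of pydantic_graph.

```python
from typing import Literal, get_args

SnapshotStatus = Literal["created", "pending", "running", "success", "error"]

# Hypothetical transition table inferred from the status descriptions:
# load_next moves created -> pending; record_run moves created/pending -> running,
# then running -> success or error when the node finishes.
TRANSITIONS: dict[str, set[str]] = {
    "created": {"pending", "running"},
    "pending": {"running"},
    "running": {"success", "error"},
    "success": set(),
    "error": set(),
}


def advance(current: str, new: str) -> str:
    """Return `new` if the transition is allowed, otherwise raise ValueError."""
    if new not in TRANSITIONS.get(current, set()):
        raise ValueError(f"cannot move snapshot from {current!r} to {new!r}")
    return new


# Every status in the Literal has an entry in the table.
assert set(TRANSITIONS) == set(get_args(SnapshotStatus))

status = "created"
status = advance(status, "pending")  # retrieved by load_next
status = advance(status, "running")  # record_run entered
status = advance(status, "success")  # node finished without raising
print(status)  # success
```

Terminal statuses ('success' and 'error') map to empty sets, so any attempt to rerun a finished snapshot is rejected, which matches the errors record_run is documented to raise.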

NodeSnapshot dataclass

Bases: Generic[StateT, RunEndT]

History step describing the execution of a node in a graph.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py
@dataclass(kw_only=True)
class NodeSnapshot(Generic[StateT, RunEndT]):
    """History step describing the execution of a node in a graph."""

    state: StateT
    """The state of the graph before the node is run."""
    node: Annotated[BaseNode[StateT, Any, RunEndT], _utils.CustomNodeSchema()]
    """The node to run next."""
    start_ts: datetime | None = None
    """The timestamp when the node started running, `None` until the run starts."""
    duration: float | None = None
    """The duration of the node run in seconds, if the node has been run."""
    status: SnapshotStatus = 'created'
    """The status of the snapshot."""
    kind: Literal['node'] = 'node'
    """The kind of history step, can be used as a discriminator when deserializing history."""

    id: str = UNSET_SNAPSHOT_ID
    """Unique ID of the snapshot."""

    def __post_init__(self) -> None:
        if self.id == UNSET_SNAPSHOT_ID:
            self.id = self.node.get_snapshot_id()

state instance-attribute

state: StateT

The state of the graph before the node is run.

node instance-attribute

node: Annotated[
    BaseNode[StateT, Any, RunEndT], CustomNodeSchema()
]

The node to run next.

start_ts class-attribute instance-attribute

start_ts: datetime | None = None

The timestamp when the node started running, None until the run starts.

duration class-attribute instance-attribute

duration: float | None = None

The duration of the node run in seconds, if the node has been run.

status class-attribute instance-attribute

status: SnapshotStatus = 'created'

The status of the snapshot.

kind class-attribute instance-attribute

kind: Literal['node'] = 'node'

The kind of history step; it can be used as a discriminator when deserializing history.

id class-attribute instance-attribute

id: str = UNSET_SNAPSHOT_ID

Unique ID of the snapshot.

EndSnapshot dataclass

Bases: Generic[StateT, RunEndT]

History step describing the end of a graph run.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py
@dataclass(kw_only=True)
class EndSnapshot(Generic[StateT, RunEndT]):
    """History step describing the end of a graph run."""

    state: StateT
    """The state of the graph at the end of the run."""
    result: End[RunEndT]
    """The result of the graph run."""
    ts: datetime = field(default_factory=_utils.now_utc)
    """The timestamp when the graph run ended."""
    kind: Literal['end'] = 'end'
    """The kind of history step, can be used as a discriminator when deserializing history."""

    id: str = UNSET_SNAPSHOT_ID
    """Unique ID of the snapshot."""

    def __post_init__(self) -> None:
        if self.id == UNSET_SNAPSHOT_ID:
            self.id = self.node.get_snapshot_id()

    @property
    def node(self) -> End[RunEndT]:
        """Shim to get the [`result`][pydantic_graph.persistence.EndSnapshot.result].

        Useful to allow `[snapshot.node for snapshot in persistence.history]`.
        """
        return self.result

state instance-attribute

state: StateT

The state of the graph at the end of the run.

result instance-attribute

result: End[RunEndT]

The result of the graph run.

ts class-attribute instance-attribute

ts: datetime = field(default_factory=now_utc)

The timestamp when the graph run ended.

kind class-attribute instance-attribute

kind: Literal['end'] = 'end'

The kind of history step; it can be used as a discriminator when deserializing history.

id class-attribute instance-attribute

id: str = UNSET_SNAPSHOT_ID

Unique ID of the snapshot.

node property

node: End[RunEndT]

Shim to get the result.

Useful to allow [snapshot.node for snapshot in persistence.history].

Snapshot module-attribute

Snapshot = (
    NodeSnapshot[StateT, RunEndT]
    | EndSnapshot[StateT, RunEndT]
)

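A step in the history of a graph run.

Graph.run returns a list of these steps describing the execution of the graph, together with the run return value.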
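Because every snapshot carries a literal kind field ('node' or 'end'), a stored history can be rebuilt by dispatching on that field. pydantic_graph does this with a Pydantic discriminated union (pydantic.Discriminator('kind') in build_snapshot_list_type_adapter); the stdlib sketch below only illustrates the idea. The `NodeStep` and `EndStep` classes and the JSON layout are hypothetical, and `EndStep.node` mirrors the EndSnapshot.node shim so every step can be read uniformly.

```python
import json
from dataclasses import dataclass
from typing import Any, Literal, Union


@dataclass
class NodeStep:
    state: Any
    node: str  # hypothetical: the node is stored by name only
    kind: Literal["node"] = "node"


@dataclass
class EndStep:
    state: Any
    result: Any
    kind: Literal["end"] = "end"

    @property
    def node(self) -> Any:
        # Mirrors EndSnapshot.node: lets callers treat every step uniformly.
        return self.result


Step = Union[NodeStep, EndStep]
BY_KIND = {"node": NodeStep, "end": EndStep}


def load_history(raw: str) -> list[Step]:
    """Rebuild steps by picking the class named by each item's `kind`."""
    steps: list[Step] = []
    for item in json.loads(raw):
        cls = BY_KIND[item.pop("kind")]
        steps.append(cls(**item))
    return steps


raw = json.dumps([
    {"kind": "node", "state": {"n": 1}, "node": "Increment"},
    {"kind": "end", "state": {"n": 2}, "result": 2},
])
history = load_history(raw)
print([step.node for step in history])  # ['Increment', 2]
```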

BaseStatePersistence

Bases: ABC, Generic[StateT, RunEndT]

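Abstract base class for storing the state of a graph run.

Each instance of a BaseStatePersistence subclass should be used for a single graph run.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py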
class BaseStatePersistence(ABC, Generic[StateT, RunEndT]):
    """Abstract base class for storing the state of a graph run.

    Each instance of a `BaseStatePersistence` subclass should be used for a single graph run.
    """

    @abstractmethod
    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        """Snapshot the state of a graph, when the next step is to run a node.

        This method should add a [`NodeSnapshot`][pydantic_graph.persistence.NodeSnapshot] to persistence.

        Args:
            state: The state of the graph.
            next_node: The next node to run.
        """
        raise NotImplementedError

    @abstractmethod
    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:
        """Snapshot the state of a graph if the snapshot ID doesn't already exist in persistence.

        This method will generally call [`snapshot_node`][pydantic_graph.persistence.BaseStatePersistence.snapshot_node]
        but should do so in an atomic way.

        Args:
            snapshot_id: The ID of the snapshot to check.
            state: The state of the graph.
            next_node: The next node to run.
        """
        raise NotImplementedError

    @abstractmethod
    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        """Snapshot the state of a graph when the graph has ended.

        This method should add an [`EndSnapshot`][pydantic_graph.persistence.EndSnapshot] to persistence.

        Args:
            state: The state of the graph.
            end: data from the end of the run.
        """
        raise NotImplementedError

    @abstractmethod
    def record_run(self, snapshot_id: str) -> AbstractAsyncContextManager[None]:
        """Record the run of the node, or error if the node is already running.

        Args:
            snapshot_id: The ID of the snapshot to record.

        Raises:
            GraphNodeRunningError: if the node status is not `'created'` or `'pending'`.
            LookupError: if the snapshot ID is not found in persistence.

        Returns:
            An async context manager that records the run of the node.

        In particular this should set:

        - [`NodeSnapshot.status`][pydantic_graph.persistence.NodeSnapshot.status] to `'running'` and
          [`NodeSnapshot.start_ts`][pydantic_graph.persistence.NodeSnapshot.start_ts] when the run starts.
        - [`NodeSnapshot.status`][pydantic_graph.persistence.NodeSnapshot.status] to `'success'` or `'error'` and
          [`NodeSnapshot.duration`][pydantic_graph.persistence.NodeSnapshot.duration] when the run finishes.
        """
        raise NotImplementedError

    @abstractmethod
    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        """Retrieve a node snapshot with status `'created'` and set its status to `'pending'`.

        This is used by [`Graph.iter_from_persistence`][pydantic_graph.graph.Graph.iter_from_persistence]
        to get the next node to run.

        Returns: The snapshot, or `None` if no snapshot with status `'created'` exists.
        """
        raise NotImplementedError

    @abstractmethod
    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        """Load the entire history of snapshots.

        `load_all` is not used by pydantic-graph itself, instead it's provided to make it convenient to
        get all [snapshots][pydantic_graph.persistence.Snapshot] from persistence.

        Returns: The list of snapshots.
        """
        raise NotImplementedError

    def set_graph_types(self, graph: Graph[StateT, Any, RunEndT]) -> None:
        """Set the types of the state and run end from a graph.

        You generally won't need to customise this method, instead implement
        [`set_types`][pydantic_graph.persistence.BaseStatePersistence.set_types] and
        [`should_set_types`][pydantic_graph.persistence.BaseStatePersistence.should_set_types].
        """
        if self.should_set_types():
            with _utils.set_nodes_type_context(graph.get_nodes()):
                self.set_types(*graph.inferred_types)

    def should_set_types(self) -> bool:
        """Whether types need to be set.

        Implementations should override this method to return `True` when types have not been set if they are needed.
        """
        return False

    def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:
        """Set the types of the state and run end.

        This can be used to create [type adapters][pydantic.TypeAdapter] for serializing and deserializing snapshots,
        e.g. with [`build_snapshot_list_type_adapter`][pydantic_graph.persistence.build_snapshot_list_type_adapter].

        Args:
            state_type: The state type.
            run_end_type: The run end type.
        """
        pass

snapshot_node abstractmethod async

snapshot_node(
    state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
) -> None

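Snapshot the state of a graph when the next step is to run a node.

This method should add a NodeSnapshot to persistence.

Parameters

Name       Type                            Description              Default
state      StateT                          The state of the graph.  required
next_node  BaseNode[StateT, Any, RunEndT]  The next node to run.    required

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py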
@abstractmethod
async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
    """Snapshot the state of a graph, when the next step is to run a node.

    This method should add a [`NodeSnapshot`][pydantic_graph.persistence.NodeSnapshot] to persistence.

    Args:
        state: The state of the graph.
        next_node: The next node to run.
    """
    raise NotImplementedError

snapshot_node_if_new abstractmethod async

snapshot_node_if_new(
    snapshot_id: str,
    state: StateT,
    next_node: BaseNode[StateT, Any, RunEndT],
) -> None

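Snapshot the state of a graph if the snapshot ID doesn't already exist in persistence.

This method will generally call snapshot_node, but should do so in an atomic way.

Parameters

Name         Type                            Description                       Default
snapshot_id  str                             The ID of the snapshot to check.  required
state        StateT                          The state of the graph.           required
next_node    BaseNode[StateT, Any, RunEndT]  The next node to run.             required

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py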
@abstractmethod
async def snapshot_node_if_new(
    self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
) -> None:
    """Snapshot the state of a graph if the snapshot ID doesn't already exist in persistence.

    This method will generally call [`snapshot_node`][pydantic_graph.persistence.BaseStatePersistence.snapshot_node]
    but should do so in an atomic way.

    Args:
        snapshot_id: The ID of the snapshot to check.
        state: The state of the graph.
        next_node: The next node to run.
    """
    raise NotImplementedError
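The "atomic" requirement matters because the existence check and the insert must happen as one step: otherwise two concurrent callers could both see the ID as missing and both append a snapshot. A minimal in-memory sketch using asyncio.Lock follows; the `InMemoryHistory` class and its dict-based snapshots are hypothetical, and FileStatePersistence further down this page guards the same check with a lock file instead.

```python
import asyncio
from typing import Any


class InMemoryHistory:
    """Hypothetical store illustrating an atomic check-then-append."""

    def __init__(self) -> None:
        self.snapshots: list[dict[str, Any]] = []
        self._lock = asyncio.Lock()

    async def snapshot_node_if_new(self, snapshot_id: str, state: Any, node: str) -> None:
        # The check and the append happen while holding the same lock.
        async with self._lock:
            if any(s["id"] == snapshot_id for s in self.snapshots):
                return
            await asyncio.sleep(0)  # a real backend would do I/O here
            self.snapshots.append(
                {"id": snapshot_id, "state": state, "node": node, "status": "created"}
            )


async def main() -> InMemoryHistory:
    store = InMemoryHistory()
    # Ten concurrent callers race to record the same snapshot ID.
    await asyncio.gather(*(store.snapshot_node_if_new("step-1", {}, "Start") for _ in range(10)))
    return store


store = asyncio.run(main())
print(len(store.snapshots))  # 1
```

Without the lock, the `await` between the check and the append would let every caller pass the check before any of them appended.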

snapshot_end abstractmethod async

snapshot_end(state: StateT, end: End[RunEndT]) -> None

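Snapshot the state of a graph when the graph has ended.

This method should add an EndSnapshot to persistence.

Parameters

Name   Type          Description                    Default
state  StateT        The state of the graph.        required
end    End[RunEndT]  Data from the end of the run.  required

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py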
@abstractmethod
async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
    """Snapshot the state of a graph when the graph has ended.

    This method should add an [`EndSnapshot`][pydantic_graph.persistence.EndSnapshot] to persistence.

    Args:
        state: The state of the graph.
        end: data from the end of the run.
    """
    raise NotImplementedError

record_run abstractmethod

record_run(
    snapshot_id: str,
) -> AbstractAsyncContextManager[None]

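Record the run of the node, or raise an error if the node is already running.

Parameters

Name         Type  Description                        Default
snapshot_id  str   The ID of the snapshot to record.  required

Raises

Type                   Description
GraphNodeRunningError  If the node status is not 'created' or 'pending'.
LookupError            If the snapshot ID is not found in persistence.

Returns

Type                               Description
AbstractAsyncContextManager[None]  An async context manager that records the run of the node.

In particular, this should set:

  • NodeSnapshot.status to 'running' and NodeSnapshot.start_ts when the run starts.
  • NodeSnapshot.status to 'success' or 'error' and NodeSnapshot.duration when the run finishes.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py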
@abstractmethod
def record_run(self, snapshot_id: str) -> AbstractAsyncContextManager[None]:
    """Record the run of the node, or error if the node is already running.

    Args:
        snapshot_id: The ID of the snapshot to record.

    Raises:
        GraphNodeRunningError: if the node status is not `'created'` or `'pending'`.
        LookupError: if the snapshot ID is not found in persistence.

    Returns:
        An async context manager that records the run of the node.

    In particular this should set:

    - [`NodeSnapshot.status`][pydantic_graph.persistence.NodeSnapshot.status] to `'running'` and
      [`NodeSnapshot.start_ts`][pydantic_graph.persistence.NodeSnapshot.start_ts] when the run starts.
    - [`NodeSnapshot.status`][pydantic_graph.persistence.NodeSnapshot.status] to `'success'` or `'error'` and
      [`NodeSnapshot.duration`][pydantic_graph.persistence.NodeSnapshot.duration] when the run finishes.
    """
    raise NotImplementedError

load_next abstractmethod async

load_next() -> NodeSnapshot[StateT, RunEndT] | None

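Retrieve a node snapshot with status 'created' and set its status to 'pending'.

This is used by Graph.iter_from_persistence to get the next node to run.

Returns: The snapshot, or None if no snapshot with status 'created' exists.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py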
@abstractmethod
async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
    """Retrieve a node snapshot with status `'created'` and set its status to `'pending'`.

    This is used by [`Graph.iter_from_persistence`][pydantic_graph.graph.Graph.iter_from_persistence]
    to get the next node to run.

    Returns: The snapshot, or `None` if no snapshot with status `'created'` exists.
    """
    raise NotImplementedError

load_all abstractmethod async

load_all() -> list[Snapshot[StateT, RunEndT]]

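Load the entire history of snapshots.

load_all is not used by pydantic-graph itself; it is provided to make it convenient to get all snapshots from persistence.

Returns: The list of snapshots.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py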
@abstractmethod
async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
    """Load the entire history of snapshots.

    `load_all` is not used by pydantic-graph itself, instead it's provided to make it convenient to
    get all [snapshots][pydantic_graph.persistence.Snapshot] from persistence.

    Returns: The list of snapshots.
    """
    raise NotImplementedError

set_graph_types

set_graph_types(graph: Graph[StateT, Any, RunEndT]) -> None

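Set the types of the state and run end from a graph.

You generally won't need to customise this method; instead, implement set_types and should_set_types.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py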
def set_graph_types(self, graph: Graph[StateT, Any, RunEndT]) -> None:
    """Set the types of the state and run end from a graph.

    You generally won't need to customise this method, instead implement
    [`set_types`][pydantic_graph.persistence.BaseStatePersistence.set_types] and
    [`should_set_types`][pydantic_graph.persistence.BaseStatePersistence.should_set_types].
    """
    if self.should_set_types():
        with _utils.set_nodes_type_context(graph.get_nodes()):
            self.set_types(*graph.inferred_types)

should_set_types

should_set_types() -> bool

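Whether types need to be set.

Implementations should override this method to return True when types are needed but have not been set yet.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py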
def should_set_types(self) -> bool:
    """Whether types need to be set.

    Implementations should override this method to return `True` when types have not been set if they are needed.
    """
    return False
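set_graph_types only calls set_types when should_set_types returns True, so an implementation that needs an expensive serializer can build it once, lazily. A stdlib sketch of that hook pattern follows. `CachingPersistence`, its `_codec` attribute and `builds` counter are hypothetical, and passing the two types directly (instead of reading them from a Graph's inferred_types) is a simplification.

```python
class PersistenceBase:
    def set_graph_types(self, state_type: type, run_end_type: type) -> None:
        # Simplified: the real method reads the types from a Graph.
        if self.should_set_types():
            self.set_types(state_type, run_end_type)

    def should_set_types(self) -> bool:
        return False  # default: this backend needs no types

    def set_types(self, state_type: type, run_end_type: type) -> None:
        pass


class CachingPersistence(PersistenceBase):
    def __init__(self) -> None:
        self._codec: tuple[type, type] | None = None
        self.builds = 0

    def should_set_types(self) -> bool:
        # Types are needed and have not been set yet.
        return self._codec is None

    def set_types(self, state_type: type, run_end_type: type) -> None:
        self.builds += 1
        self._codec = (state_type, run_end_type)  # e.g. build a type adapter here


p = CachingPersistence()
p.set_graph_types(dict, int)
p.set_graph_types(dict, int)  # second call is a no-op
print(p.builds, p._codec)  # 1 (<class 'dict'>, <class 'int'>)
```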

set_types

set_types(
    state_type: type[StateT], run_end_type: type[RunEndT]
) -> None

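Set the types of the state and run end.

This can be used to create type adapters for serializing and deserializing snapshots, e.g. with build_snapshot_list_type_adapter.

Parameters

Name          Type           Description        Default
state_type    type[StateT]   The state type.    required
run_end_type  type[RunEndT]  The run end type.  required

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py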
def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:
    """Set the types of the state and run end.

    This can be used to create [type adapters][pydantic.TypeAdapter] for serializing and deserializing snapshots,
    e.g. with [`build_snapshot_list_type_adapter`][pydantic_graph.persistence.build_snapshot_list_type_adapter].

    Args:
        state_type: The state type.
        run_end_type: The run end type.
    """
    pass

build_snapshot_list_type_adapter

build_snapshot_list_type_adapter(
    state_t: type[StateT], run_end_t: type[RunEndT]
) -> TypeAdapter[list[Snapshot[StateT, RunEndT]]]

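Build a type adapter for a list of snapshots.

This method should be called from within set_types, where context variables will be set such that Pydantic can create a schema for NodeSnapshot.node.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py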
def build_snapshot_list_type_adapter(
    state_t: type[StateT], run_end_t: type[RunEndT]
) -> pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]]:
    """Build a type adapter for a list of snapshots.

    This method should be called from within
    [`set_types`][pydantic_graph.persistence.BaseStatePersistence.set_types]
    where context variables will be set such that Pydantic can create a schema for
    [`NodeSnapshot.node`][pydantic_graph.persistence.NodeSnapshot.node].
    """
    return pydantic.TypeAdapter(list[Annotated[Snapshot[state_t, run_end_t], pydantic.Discriminator('kind')]])

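In-memory state persistence.

This module provides simple in-memory state persistence for graphs.

SimpleStatePersistence dataclass

Bases: BaseStatePersistence[StateT, RunEndT]

Simple in-memory state persistence that holds only the latest snapshot.

If no state persistence implementation is provided when running a graph, this is used by default.

Source code in pydantic_graph/pydantic_graph/persistence/in_mem.py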
@dataclass
class SimpleStatePersistence(BaseStatePersistence[StateT, RunEndT]):
    """Simple in memory state persistence that just holds the latest snapshot.

    If no state persistence implementation is provided when running a graph, this is used by default.
    """

    last_snapshot: Snapshot[StateT, RunEndT] | None = None
    """The last snapshot."""

    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        self.last_snapshot = NodeSnapshot(state=state, node=next_node)

    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:
        if self.last_snapshot and self.last_snapshot.id == snapshot_id:
            return  # pragma: no cover
        else:
            await self.snapshot_node(state, next_node)

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        self.last_snapshot = EndSnapshot(state=state, result=end)

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        if self.last_snapshot is None or snapshot_id != self.last_snapshot.id:
            raise LookupError(f'No snapshot found with id={snapshot_id!r}')

        assert isinstance(self.last_snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
        exceptions.GraphNodeStatusError.check(self.last_snapshot.status)
        self.last_snapshot.status = 'running'
        self.last_snapshot.start_ts = _utils.now_utc()

        start = perf_counter()
        try:
            yield
        except Exception:
            self.last_snapshot.duration = perf_counter() - start
            self.last_snapshot.status = 'error'
            raise
        else:
            self.last_snapshot.duration = perf_counter() - start
            self.last_snapshot.status = 'success'

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        if isinstance(self.last_snapshot, NodeSnapshot) and self.last_snapshot.status == 'created':
            self.last_snapshot.status = 'pending'
            return self.last_snapshot

    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        raise NotImplementedError('load is not supported for SimpleStatePersistence')

last_snapshot class-attribute instance-attribute

last_snapshot: Snapshot[StateT, RunEndT] | None = None

The last snapshot.

FullStatePersistence dataclass

Bases: BaseStatePersistence[StateT, RunEndT]

In-memory state persistence that holds a list of snapshots.

Source code in pydantic_graph/pydantic_graph/persistence/in_mem.py
@dataclass
class FullStatePersistence(BaseStatePersistence[StateT, RunEndT]):
    """In memory state persistence that holds a list of snapshots."""

    deep_copy: bool = True
    """Whether to deep copy the state and nodes when storing them.

    Defaults to `True` so even if nodes or state are modified after the snapshot is taken,
    the persistence history will record the value at the time of the snapshot.
    """
    history: list[Snapshot[StateT, RunEndT]] = field(default_factory=list)
    """List of snapshots taken during the graph run."""
    _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field(
        default=None, init=False, repr=False
    )

    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        snapshot = NodeSnapshot(
            state=self._prep_state(state),
            node=next_node.deep_copy() if self.deep_copy else next_node,
        )
        self.history.append(snapshot)

    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:
        if not any(s.id == snapshot_id for s in self.history):
            await self.snapshot_node(state, next_node)

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        snapshot = EndSnapshot(
            state=self._prep_state(state),
            result=end.deep_copy_data() if self.deep_copy else end,
        )
        self.history.append(snapshot)

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        try:
            snapshot = next(s for s in self.history if s.id == snapshot_id)
        except StopIteration as e:
            raise LookupError(f'No snapshot found with id={snapshot_id!r}') from e

        assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
        exceptions.GraphNodeStatusError.check(snapshot.status)
        snapshot.status = 'running'
        snapshot.start_ts = _utils.now_utc()
        start = perf_counter()
        try:
            yield
        except Exception:
            snapshot.duration = perf_counter() - start
            snapshot.status = 'error'
            raise
        else:
            snapshot.duration = perf_counter() - start
            snapshot.status = 'success'

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        if snapshot := next((s for s in self.history if isinstance(s, NodeSnapshot) and s.status == 'created'), None):
            snapshot.status = 'pending'
            return snapshot

    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        return self.history

    def should_set_types(self) -> bool:
        return self._snapshots_type_adapter is None

    def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:
        self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type)

    def dump_json(self, *, indent: int | None = None) -> bytes:
        """Dump the history to JSON bytes."""
        assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `dump_json`'
        return self._snapshots_type_adapter.dump_json(self.history, indent=indent)

    def load_json(self, json_data: str | bytes | bytearray) -> None:
        """Load the history from JSON."""
        assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `load_json`'
        self.history = self._snapshots_type_adapter.validate_json(json_data)

    def _prep_state(self, state: StateT) -> StateT:
        """Prepare state for snapshot, uses [`copy.deepcopy`][copy.deepcopy] by default."""
        if not self.deep_copy or state is None:
            return state
        else:
            return copy.deepcopy(state)

deep_copy class-attribute instance-attribute

deep_copy: bool = True

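Whether to deep copy the state and nodes when storing them.

Defaults to True so that, even if nodes or state are modified after the snapshot is taken, the persistence history records their values as they were when the snapshot was taken.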
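To see why deep_copy defaults to True: without a copy, the history holds a reference to the same mutable state object the graph keeps working on, so later mutations silently rewrite past snapshots. The stdlib illustration below uses copy.deepcopy, which is also what FullStatePersistence._prep_state uses for state; the `take_snapshot` helper is hypothetical and only covers state, not nodes.

```python
import copy
from typing import Any


def take_snapshot(history: list[Any], state: Any, deep_copy: bool) -> None:
    # Store either an independent copy or a reference to the live object.
    history.append(copy.deepcopy(state) if deep_copy else state)


state = {"count": 0}
copied: list[Any] = []
shared: list[Any] = []

take_snapshot(copied, state, deep_copy=True)
take_snapshot(shared, state, deep_copy=False)
state["count"] += 1  # the graph keeps mutating its state after the snapshot

print(copied[0], shared[0])  # {'count': 0} {'count': 1}
```

Turning deep_copy off trades this accuracy for lower memory and copying cost, which may be acceptable when state is immutable or never mutated in place.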

history class-attribute instance-attribute

history: list[Snapshot[StateT, RunEndT]] = field(
    default_factory=list
)

List of snapshots taken during the graph run.

dump_json

dump_json(*, indent: int | None = None) -> bytes

Dump the history to JSON bytes.

Source code in pydantic_graph/pydantic_graph/persistence/in_mem.py
def dump_json(self, *, indent: int | None = None) -> bytes:
    """Dump the history to JSON bytes."""
    assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `dump_json`'
    return self._snapshots_type_adapter.dump_json(self.history, indent=indent)

load_json

load_json(json_data: str | bytes | bytearray) -> None

Load the history from JSON.

Source code in pydantic_graph/pydantic_graph/persistence/in_mem.py
def load_json(self, json_data: str | bytes | bytearray) -> None:
    """Load the history from JSON."""
    assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `load_json`'
    self.history = self._snapshots_type_adapter.validate_json(json_data)

FileStatePersistence dataclass

Bases: BaseStatePersistence[StateT, RunEndT]

File-based state persistence that holds graph run state in a JSON file.

Source code in pydantic_graph/pydantic_graph/persistence/file.py
@dataclass
class FileStatePersistence(BaseStatePersistence[StateT, RunEndT]):
    """File based state persistence that holds graph run state in a JSON file."""

    json_file: Path
    """Path to the JSON file where the snapshots are stored.

    You should use a different file for each graph run, but a single file should be reused for multiple
    steps of the same run.

    For example if you have a run ID of the form `run_123abc`, you might create a `FileStatePersistence` thus:

    ```py
    from pathlib import Path

    from pydantic_graph.persistence.file import FileStatePersistence

    run_id = 'run_123abc'
    persistence = FileStatePersistence(Path('runs') / f'{run_id}.json')
    ```
    """
    _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field(
        default=None, init=False, repr=False
    )

    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        await self._append_save(NodeSnapshot(state=state, node=next_node))

    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:
        async with self._lock():
            snapshots = await self.load_all()
            if not any(s.id == snapshot_id for s in snapshots):  # pragma: no branch
                await self._append_save(NodeSnapshot(state=state, node=next_node), lock=False)

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        await self._append_save(EndSnapshot(state=state, result=end))

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        async with self._lock():
            snapshots = await self.load_all()
            try:
                snapshot = next(s for s in snapshots if s.id == snapshot_id)
            except StopIteration as e:
                raise LookupError(f'No snapshot found with id={snapshot_id!r}') from e

            assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
            exceptions.GraphNodeStatusError.check(snapshot.status)
            snapshot.status = 'running'
            snapshot.start_ts = _utils.now_utc()
            await self._save(snapshots)

        start = perf_counter()
        try:
            yield
        except Exception:
            duration = perf_counter() - start
            async with self._lock():
                await _graph_utils.run_in_executor(self._after_run_sync, snapshot_id, duration, 'error')
            raise
        else:
            snapshot.duration = perf_counter() - start
            async with self._lock():
                await _graph_utils.run_in_executor(self._after_run_sync, snapshot_id, snapshot.duration, 'success')

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        async with self._lock():
            snapshots = await self.load_all()
            if snapshot := next((s for s in snapshots if isinstance(s, NodeSnapshot) and s.status == 'created'), None):
                snapshot.status = 'pending'
                await self._save(snapshots)
                return snapshot

    def should_set_types(self) -> bool:
        """Whether types need to be set."""
        return self._snapshots_type_adapter is None

    def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:
        self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type)

    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        return await _graph_utils.run_in_executor(self._load_sync)

    def _load_sync(self) -> list[Snapshot[StateT, RunEndT]]:
        assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set'
        try:
            content = self.json_file.read_bytes()
        except FileNotFoundError:
            return []
        else:
            return self._snapshots_type_adapter.validate_json(content)

    def _after_run_sync(self, snapshot_id: str, duration: float, status: SnapshotStatus) -> None:
        snapshots = self._load_sync()
        snapshot = next(s for s in snapshots if s.id == snapshot_id)
        assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
        snapshot.duration = duration
        snapshot.status = status
        self._save_sync(snapshots)

    async def _save(self, snapshots: list[Snapshot[StateT, RunEndT]]) -> None:
        await _graph_utils.run_in_executor(self._save_sync, snapshots)

    def _save_sync(self, snapshots: list[Snapshot[StateT, RunEndT]]) -> None:
        assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set'
        self.json_file.write_bytes(self._snapshots_type_adapter.dump_json(snapshots, indent=2))

    async def _append_save(self, snapshot: Snapshot[StateT, RunEndT], *, lock: bool = True) -> None:
        assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set'
        async with AsyncExitStack() as stack:
            if lock:
                await stack.enter_async_context(self._lock())
            snapshots = await self.load_all()
            snapshots.append(snapshot)
            await self._save(snapshots)

    @asynccontextmanager
    async def _lock(self, *, timeout: float = 1.0) -> AsyncIterator[None]:
        """Lock a file by checking and writing a `.pydantic-graph-persistence-lock` to it.

        Args:
            timeout: how long to wait for the lock

        Returns: an async context manager that holds the lock
        """
        lock_file = self.json_file.parent / f'{self.json_file.name}.pydantic-graph-persistence-lock'
        lock_id = secrets.token_urlsafe().encode()

        with anyio.fail_after(timeout):
            while not await _file_append_check(lock_file, lock_id):
                await anyio.sleep(0.01)

        try:
            yield
        finally:
            await _graph_utils.run_in_executor(lock_file.unlink, missing_ok=True)
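The _lock method above serializes access to the JSON file through a sibling *.pydantic-graph-persistence-lock file, polling every 0.01 s until `_file_append_check` succeeds or `timeout` (1.0 s by default) expires, and deleting the lock file on exit. The body of `_file_append_check` is not shown on this page, so the sketch below swaps in a different, well-known technique: creating the lock file with os.O_CREAT | os.O_EXCL, which fails if the file already exists. The `file_lock` function and its parameters are hypothetical.

```python
import asyncio
import os
import tempfile
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncIterator


@asynccontextmanager
async def file_lock(json_file: Path, *, timeout: float = 1.0) -> AsyncIterator[None]:
    lock_file = json_file.parent / f"{json_file.name}.pydantic-graph-persistence-lock"
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while True:
        try:
            # O_EXCL makes creation fail if another holder already created the file.
            fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            os.close(fd)
            break
        except FileExistsError:
            if loop.time() >= deadline:
                raise TimeoutError(f"could not lock {json_file}")
            await asyncio.sleep(0.01)  # same polling interval as _lock
    try:
        yield
    finally:
        lock_file.unlink(missing_ok=True)


async def main(directory: Path) -> bool:
    target = directory / "run_123abc.json"
    async with file_lock(target):
        # A second acquirer gives up after its short timeout while we hold the lock.
        try:
            async with file_lock(target, timeout=0.05):
                pass
            return False
        except TimeoutError:
            return True


with tempfile.TemporaryDirectory() as tmp:
    contended = asyncio.run(main(Path(tmp)))
    leftover = list(Path(tmp).iterdir())
print(contended, leftover)  # True []
```

Like the original, this is an advisory lock: it only protects against writers that also go through the same lock file.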

json_file instance-attribute

json_file: Path

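Path to the JSON file where the snapshots are stored.

You should use a different file for each graph run, but reuse a single file for all the steps of the same run.

For example, if you have a run ID of the form run_123abc, you might create a FileStatePersistence like this:

from pathlib import Path

from pydantic_graph.persistence.file import FileStatePersistence

run_id = 'run_123abc'
persistence = FileStatePersistence(Path('runs') / f'{run_id}.json')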

should_set_types

should_set_types() -> bool

Whether types need to be set.

Source code in pydantic_graph/pydantic_graph/persistence/file.py
def should_set_types(self) -> bool:
    """Whether types need to be set."""
    return self._snapshots_type_adapter is None