
pydantic_graph.persistence

SnapshotStatus module-attribute

SnapshotStatus = Literal[
    "created", "pending", "running", "success", "error"
]

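The status of a snapshot.

  • 'created': The snapshot has been created but not yet run.
  • 'pending': The snapshot has been retrieved with load_next but not yet run.
  • 'running': The snapshot is currently running.
  • 'success': The snapshot has been run successfully.
  • 'error': The snapshot has been run but an error occurred.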
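
The status values above imply a small lifecycle. The sketch below encodes it as a transition table; `ALLOWED_TRANSITIONS` and `can_transition` are hypothetical names introduced for illustration and are not part of pydantic_graph. It assumes, based on the `record_run` contract documented further down, that a run may start from either 'created' or 'pending'.

```python
from typing import Literal

SnapshotStatus = Literal["created", "pending", "running", "success", "error"]

# Hypothetical transition table derived from the status descriptions;
# pydantic_graph itself does not expose anything like this.
ALLOWED_TRANSITIONS: dict[str, set[str]] = {
    "created": {"pending", "running"},  # load_next marks it pending; record_run may start it
    "pending": {"running"},
    "running": {"success", "error"},
    "success": set(),  # terminal
    "error": set(),  # terminal
}


def can_transition(current: SnapshotStatus, new: SnapshotStatus) -> bool:
    """Return True if moving from `current` to `new` fits the lifecycle."""
    return new in ALLOWED_TRANSITIONS[current]
```

For instance, `can_transition("created", "pending")` holds, while a finished snapshot cannot go back to 'running'.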

NodeSnapshot dataclass

Bases: Generic[StateT, RunEndT]

History step describing the execution of a node in a graph.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py
@dataclass
class NodeSnapshot(Generic[StateT, RunEndT]):
    """History step describing the execution of a node in a graph."""

    state: StateT
    """The state of the graph before the node is run."""
    node: Annotated[BaseNode[StateT, Any, RunEndT], _utils.CustomNodeSchema()]
    """The node to run next."""
    start_ts: datetime | None = None
    """The timestamp when the node started running, `None` until the run starts."""
    duration: float | None = None
    """The duration of the node run in seconds, if the node has been run."""
    status: SnapshotStatus = 'created'
    """The status of the snapshot."""
    kind: Literal['node'] = 'node'
    """The kind of history step, can be used as a discriminator when deserializing history."""

    id: str = UNSET_SNAPSHOT_ID
    """Unique ID of the snapshot."""

    def __post_init__(self) -> None:
        if self.id == UNSET_SNAPSHOT_ID:
            self.id = self.node.get_snapshot_id()

state instance-attribute

state: StateT

The state of the graph before the node is run.

node instance-attribute

node: Annotated[
    BaseNode[StateT, Any, RunEndT], CustomNodeSchema()
]

The node to run next.

start_ts class-attribute instance-attribute

start_ts: datetime | None = None

The timestamp when the node started running, None until the run starts.

duration class-attribute instance-attribute

duration: float | None = None

The duration of the node run in seconds, if the node has been run.

status class-attribute instance-attribute

status: SnapshotStatus = 'created'

The status of the snapshot.

kind class-attribute instance-attribute

kind: Literal['node'] = 'node'

The kind of history step, can be used as a discriminator when deserializing history.

id class-attribute instance-attribute

id: str = UNSET_SNAPSHOT_ID

Unique ID of the snapshot.
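
The `id` default and `__post_init__` in the listing above follow a sentinel-default pattern: an ID is derived from the node only when the caller did not pass one. A minimal sketch with stand-in classes follows; `DemoNode`, `DemoSnapshot`, `UNSET_ID` and the ID format are all hypothetical and only mimic the shape of `NodeSnapshot`.

```python
from dataclasses import dataclass
from itertools import count

UNSET_ID = "__unset__"  # hypothetical sentinel, standing in for UNSET_SNAPSHOT_ID
_counter = count(1)


@dataclass
class DemoNode:
    """Stand-in for a graph node that can mint snapshot IDs."""

    name: str

    def get_snapshot_id(self) -> str:
        # The ID format here is invented for the example.
        return f"{self.name}:{next(_counter)}"


@dataclass
class DemoSnapshot:
    node: DemoNode
    status: str = "created"
    id: str = UNSET_ID

    def __post_init__(self) -> None:
        # Derive an ID only when none was supplied, so an explicitly
        # passed ID (for example one read back from storage) is kept.
        if self.id == UNSET_ID:
            self.id = self.node.get_snapshot_id()
```

A snapshot built without an `id` receives a generated one, while `DemoSnapshot(node=..., id="Increment:abc")` keeps the given value.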

EndSnapshot dataclass

Bases: Generic[StateT, RunEndT]

History step describing the end of a graph run.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py
@dataclass
class EndSnapshot(Generic[StateT, RunEndT]):
    """History step describing the end of a graph run."""

    state: StateT
    """The state of the graph at the end of the run."""
    result: End[RunEndT]
    """The result of the graph run."""
    ts: datetime = field(default_factory=_utils.now_utc)
    """The timestamp when the graph run ended."""
    kind: Literal['end'] = 'end'
    """The kind of history step, can be used as a discriminator when deserializing history."""

    id: str = UNSET_SNAPSHOT_ID
    """Unique ID of the snapshot."""

    def __post_init__(self) -> None:
        if self.id == UNSET_SNAPSHOT_ID:
            self.id = self.node.get_snapshot_id()

    @property
    def node(self) -> End[RunEndT]:
        """Shim to get the [`result`][pydantic_graph.persistence.EndSnapshot.result].

        Useful to allow `[snapshot.node for snapshot in persistence.history]`.
        """
        return self.result

state instance-attribute

state: StateT

The state of the graph at the end of the run.

result instance-attribute

result: End[RunEndT]

The result of the graph run.

ts class-attribute instance-attribute

ts: datetime = field(default_factory=now_utc)

The timestamp when the graph run ended.

kind class-attribute instance-attribute

kind: Literal['end'] = 'end'

The kind of history step, can be used as a discriminator when deserializing history.

id class-attribute instance-attribute

id: str = UNSET_SNAPSHOT_ID

Unique ID of the snapshot.

node property

node: End[RunEndT]

Shim to get the result.

Useful to allow [snapshot.node for snapshot in persistence.history].

Snapshot module-attribute

Snapshot = Union[
    NodeSnapshot[StateT, RunEndT],
    EndSnapshot[StateT, RunEndT],
]

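A step in the history of a graph run.

Graph.run returns a list of these steps describing the execution of the graph, together with the run return value.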
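
Because both snapshot types carry a `kind` field and expose `node`, a history can be walked uniformly. The sketch below uses simplified stand-in dataclasses (`DemoNodeSnapshot`, `DemoEndSnapshot`, `describe` are hypothetical and hold plain strings instead of real nodes) to show the discriminator and the `node` shim working together.

```python
from dataclasses import dataclass
from typing import Literal, Union


@dataclass
class DemoNodeSnapshot:
    node: str
    kind: Literal["node"] = "node"


@dataclass
class DemoEndSnapshot:
    result: str
    kind: Literal["end"] = "end"

    @property
    def node(self) -> str:
        # Mirrors the shim on EndSnapshot: `node` is an alias for `result`.
        return self.result


DemoStep = Union[DemoNodeSnapshot, DemoEndSnapshot]


def describe(history: list[DemoStep]) -> list[str]:
    """Label each step using its `kind` discriminator."""
    lines = []
    for step in history:
        if step.kind == "node":
            lines.append(f"ran {step.node}")
        else:
            lines.append(f"ended with {step.node}")
    return lines
```

With the shim in place, `[step.node for step in history]` works without checking which snapshot type each element is.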

BaseStatePersistence

Bases: ABC, Generic[StateT, RunEndT]

Abstract base class for storing the state of a graph run.

Each instance of a BaseStatePersistence subclass should be used for a single graph run.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py
class BaseStatePersistence(ABC, Generic[StateT, RunEndT]):
    """Abstract base class for storing the state of a graph run.

    Each instance of a `BaseStatePersistence` subclass should be used for a single graph run.
    """

    @abstractmethod
    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        """Snapshot the state of a graph, when the next step is to run a node.

        This method should add a [`NodeSnapshot`][pydantic_graph.persistence.NodeSnapshot] to persistence.

        Args:
            state: The state of the graph.
            next_node: The next node to run.
        """
        raise NotImplementedError

    @abstractmethod
    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:
        """Snapshot the state of a graph if the snapshot ID doesn't already exist in persistence.

        This method will generally call [`snapshot_node`][pydantic_graph.persistence.BaseStatePersistence.snapshot_node]
        but should do so in an atomic way.

        Args:
            snapshot_id: The ID of the snapshot to check.
            state: The state of the graph.
            next_node: The next node to run.
        """
        raise NotImplementedError

    @abstractmethod
    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        """Snapshot the state of a graph when the graph has ended.

        This method should add an [`EndSnapshot`][pydantic_graph.persistence.EndSnapshot] to persistence.

        Args:
            state: The state of the graph.
            end: data from the end of the run.
        """
        raise NotImplementedError

    @abstractmethod
    def record_run(self, snapshot_id: str) -> AbstractAsyncContextManager[None]:
        """Record the run of the node, or error if the node is already running.

        Args:
            snapshot_id: The ID of the snapshot to record.

        Raises:
            GraphNodeStatusError: if the node status is not `'created'` or `'pending'`.
            LookupError: if the snapshot ID is not found in persistence.

        Returns:
            An async context manager that records the run of the node.

        In particular this should set:

        - [`NodeSnapshot.status`][pydantic_graph.persistence.NodeSnapshot.status] to `'running'` and
          [`NodeSnapshot.start_ts`][pydantic_graph.persistence.NodeSnapshot.start_ts] when the run starts.
        - [`NodeSnapshot.status`][pydantic_graph.persistence.NodeSnapshot.status] to `'success'` or `'error'` and
          [`NodeSnapshot.duration`][pydantic_graph.persistence.NodeSnapshot.duration] when the run finishes.
        """
        raise NotImplementedError

    @abstractmethod
    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        """Retrieve a node snapshot with status `'created'` and set its status to `'pending'`.

        This is used by [`Graph.iter_from_persistence`][pydantic_graph.graph.Graph.iter_from_persistence]
        to get the next node to run.

        Returns: The snapshot, or `None` if no snapshot with status `'created'` exists.
        """
        raise NotImplementedError

    @abstractmethod
    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        """Load the entire history of snapshots.

        `load_all` is not used by pydantic-graph itself, instead it's provided to make it convenient to
        get all [snapshots][pydantic_graph.persistence.Snapshot] from persistence.

        Returns: The list of snapshots.
        """
        raise NotImplementedError

    def set_graph_types(self, graph: Graph[StateT, Any, RunEndT]) -> None:
        """Set the types of the state and run end from a graph.

        You generally won't need to customise this method, instead implement
        [`set_types`][pydantic_graph.persistence.BaseStatePersistence.set_types] and
        [`should_set_types`][pydantic_graph.persistence.BaseStatePersistence.should_set_types].
        """
        if self.should_set_types():
            with _utils.set_nodes_type_context(graph.get_nodes()):
                self.set_types(*graph.inferred_types)

    def should_set_types(self) -> bool:
        """Whether types need to be set.

        Implementations should override this method to return `True` when types have not been set if they are needed.
        """
        return False

    def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:
        """Set the types of the state and run end.

        This can be used to create [type adapters][pydantic.TypeAdapter] for serializing and deserializing snapshots,
        e.g. with [`build_snapshot_list_type_adapter`][pydantic_graph.persistence.build_snapshot_list_type_adapter].

        Args:
            state_type: The state type.
            run_end_type: The run end type.
        """
        pass

snapshot_node abstractmethod async

snapshot_node(
    state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
) -> None

Snapshot the state of a graph, when the next step is to run a node.

This method should add a NodeSnapshot to persistence.

Parameters

Name       Type                            Description              Default
state      StateT                          The state of the graph.  required
next_node  BaseNode[StateT, Any, RunEndT]  The next node to run.    required

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py
@abstractmethod
async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
    """Snapshot the state of a graph, when the next step is to run a node.

    This method should add a [`NodeSnapshot`][pydantic_graph.persistence.NodeSnapshot] to persistence.

    Args:
        state: The state of the graph.
        next_node: The next node to run.
    """
    raise NotImplementedError

snapshot_node_if_new abstractmethod async

snapshot_node_if_new(
    snapshot_id: str,
    state: StateT,
    next_node: BaseNode[StateT, Any, RunEndT],
) -> None

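Snapshot the state of a graph if the snapshot ID doesn't already exist in persistence.

This method will generally call snapshot_node, but should do so in an atomic way.

Parameters

Name         Type                            Description                       Default
snapshot_id  str                             The ID of the snapshot to check.  required
state        StateT                          The state of the graph.           required
next_node    BaseNode[StateT, Any, RunEndT]  The next node to run.             required

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py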
@abstractmethod
async def snapshot_node_if_new(
    self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
) -> None:
    """Snapshot the state of a graph if the snapshot ID doesn't already exist in persistence.

    This method will generally call [`snapshot_node`][pydantic_graph.persistence.BaseStatePersistence.snapshot_node]
    but should do so in an atomic way.

    Args:
        snapshot_id: The ID of the snapshot to check.
        state: The state of the graph.
        next_node: The next node to run.
    """
    raise NotImplementedError
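
The requirement that `snapshot_node_if_new` be atomic means the existence check and the write must not be interleaved with another caller. One way to sketch that for a single-process, in-memory store is an `asyncio.Lock` held across both steps. `AtomicStore` and `demo_if_new` are hypothetical names, and the store is deliberately not a `BaseStatePersistence` subclass.

```python
import asyncio


class AtomicStore:
    """Hypothetical in-memory store keyed by snapshot ID."""

    def __init__(self) -> None:
        self.snapshots: dict[str, dict] = {}
        self._lock = asyncio.Lock()

    async def snapshot_node_if_new(self, snapshot_id: str, state: dict, next_node: str) -> None:
        # Hold the lock across the check and the write so two concurrent
        # callers cannot both see "missing" and both insert.
        async with self._lock:
            if snapshot_id in self.snapshots:
                return
            await asyncio.sleep(0)  # yield control, as real I/O would
            self.snapshots[snapshot_id] = {
                "state": dict(state),
                "node": next_node,
                "status": "created",
            }


async def demo_if_new() -> AtomicStore:
    store = AtomicStore()
    # Two concurrent attempts to record the same snapshot ID.
    await asyncio.gather(
        store.snapshot_node_if_new("n1", {"count": 1}, "Increment"),
        store.snapshot_node_if_new("n1", {"count": 2}, "Increment"),
    )
    return store
```

Running `asyncio.run(demo_if_new())` leaves a single entry for "n1"; the second call finds it and returns without writing. A store shared between processes would need a lock that works across processes instead.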

snapshot_end abstractmethod async

snapshot_end(state: StateT, end: End[RunEndT]) -> None

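Snapshot the state of a graph when the graph has ended.

This method should add an EndSnapshot to persistence.

Parameters

Name   Type          Description                    Default
state  StateT        The state of the graph.        required
end    End[RunEndT]  Data from the end of the run.  required

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py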
@abstractmethod
async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
    """Snapshot the state of a graph when the graph has ended.

    This method should add an [`EndSnapshot`][pydantic_graph.persistence.EndSnapshot] to persistence.

    Args:
        state: The state of the graph.
        end: data from the end of the run.
    """
    raise NotImplementedError

record_run abstractmethod

record_run(
    snapshot_id: str,
) -> AbstractAsyncContextManager[None]

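Record the run of the node, or raise an error if the node is already running.

Parameters

Name         Type  Description                        Default
snapshot_id  str   The ID of the snapshot to record.  required

Raises

Type                  Description
GraphNodeStatusError  If the node status is not 'created' or 'pending'.
LookupError           If the snapshot ID is not found in persistence.

Returns

Type                               Description
AbstractAsyncContextManager[None]  An async context manager that records the run of the node.

In particular this should set:

  • NodeSnapshot.status to 'running' and NodeSnapshot.start_ts when the run starts.
  • NodeSnapshot.status to 'success' or 'error' and NodeSnapshot.duration when the run finishes.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py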
@abstractmethod
def record_run(self, snapshot_id: str) -> AbstractAsyncContextManager[None]:
    """Record the run of the node, or error if the node is already running.

    Args:
        snapshot_id: The ID of the snapshot to record.

    Raises:
        GraphNodeStatusError: if the node status is not `'created'` or `'pending'`.
        LookupError: if the snapshot ID is not found in persistence.

    Returns:
        An async context manager that records the run of the node.

    In particular this should set:

    - [`NodeSnapshot.status`][pydantic_graph.persistence.NodeSnapshot.status] to `'running'` and
      [`NodeSnapshot.start_ts`][pydantic_graph.persistence.NodeSnapshot.start_ts] when the run starts.
    - [`NodeSnapshot.status`][pydantic_graph.persistence.NodeSnapshot.status] to `'success'` or `'error'` and
      [`NodeSnapshot.duration`][pydantic_graph.persistence.NodeSnapshot.duration] when the run finishes.
    """
    raise NotImplementedError
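
The `record_run` contract can be sketched with `contextlib.asynccontextmanager`. The classes below (`RunSnapshot`, `RunRecorder`) are simplified, hypothetical stand-ins; in particular `RuntimeError` stands in for the library's own status exception.

```python
import asyncio
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime, timezone
from time import perf_counter
from typing import AsyncIterator, Optional


@dataclass
class RunSnapshot:
    id: str
    status: str = "created"
    start_ts: Optional[datetime] = None
    duration: Optional[float] = None


class RunRecorder:
    def __init__(self) -> None:
        self.snapshots: dict[str, RunSnapshot] = {}

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        snapshot = self.snapshots.get(snapshot_id)
        if snapshot is None:
            raise LookupError(f"No snapshot found with id={snapshot_id!r}")
        if snapshot.status not in ("created", "pending"):
            # Stand-in for the library's status error.
            raise RuntimeError(f"Snapshot {snapshot_id!r} has status {snapshot.status!r}")
        # Run starts: mark as running and stamp the start time.
        snapshot.status = "running"
        snapshot.start_ts = datetime.now(tz=timezone.utc)
        start = perf_counter()
        try:
            yield
        except Exception:
            snapshot.duration = perf_counter() - start
            snapshot.status = "error"
            raise
        else:
            snapshot.duration = perf_counter() - start
            snapshot.status = "success"


async def demo_record_run() -> RunRecorder:
    recorder = RunRecorder()
    recorder.snapshots["ok"] = RunSnapshot("ok")
    recorder.snapshots["bad"] = RunSnapshot("bad")
    async with recorder.record_run("ok"):
        await asyncio.sleep(0)
    try:
        async with recorder.record_run("bad"):
            raise ValueError("node failed")
    except ValueError:
        pass  # the error is re-raised after the status is recorded
    return recorder
```

After `asyncio.run(demo_record_run())`, the "ok" snapshot ends as 'success' and the "bad" one as 'error', each with a duration set.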

load_next abstractmethod async

load_next() -> NodeSnapshot[StateT, RunEndT] | None

Retrieve a node snapshot with status 'created' and set its status to 'pending'.

This is used by Graph.iter_from_persistence to get the next node to run.

Returns: The snapshot, or None if no snapshot with status 'created' exists.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py
@abstractmethod
async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
    """Retrieve a node snapshot with status `'created'` and set its status to `'pending'`.

    This is used by [`Graph.iter_from_persistence`][pydantic_graph.graph.Graph.iter_from_persistence]
    to get the next node to run.

    Returns: The snapshot, or `None` if no snapshot with status `'created'` exists.
    """
    raise NotImplementedError
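
The "retrieve and mark as pending" behaviour of `load_next` amounts to claiming a snapshot so it is not handed out twice. A minimal sketch with hypothetical stand-ins (`QueuedSnapshot`, `DemoQueue`, `drain`):

```python
import asyncio
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class QueuedSnapshot:
    id: str
    status: str = "created"


@dataclass
class DemoQueue:
    history: list = field(default_factory=list)

    async def load_next(self) -> Optional[QueuedSnapshot]:
        # Take the first snapshot still marked 'created' and claim it by
        # flipping its status, so a later call will not return it again.
        for snapshot in self.history:
            if snapshot.status == "created":
                snapshot.status = "pending"
                return snapshot
        return None


async def drain(queue: DemoQueue) -> list:
    """Collect the IDs of every snapshot that load_next hands out."""
    ids = []
    while (snapshot := await queue.load_next()) is not None:
        ids.append(snapshot.id)
    return ids
```

Snapshots that already ran are skipped, and once every 'created' snapshot has been claimed `load_next` returns `None`.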

load_all abstractmethod async

load_all() -> list[Snapshot[StateT, RunEndT]]

Load the entire history of snapshots.

load_all is not used by pydantic-graph itself; instead it's provided to make it convenient to get all snapshots from persistence.

Returns: The list of snapshots.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py
@abstractmethod
async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
    """Load the entire history of snapshots.

    `load_all` is not used by pydantic-graph itself, instead it's provided to make it convenient to
    get all [snapshots][pydantic_graph.persistence.Snapshot] from persistence.

    Returns: The list of snapshots.
    """
    raise NotImplementedError

set_graph_types

set_graph_types(graph: Graph[StateT, Any, RunEndT]) -> None

Set the types of the state and run end from a graph.

You generally won't need to customise this method; instead implement set_types and should_set_types.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py
def set_graph_types(self, graph: Graph[StateT, Any, RunEndT]) -> None:
    """Set the types of the state and run end from a graph.

    You generally won't need to customise this method, instead implement
    [`set_types`][pydantic_graph.persistence.BaseStatePersistence.set_types] and
    [`should_set_types`][pydantic_graph.persistence.BaseStatePersistence.should_set_types].
    """
    if self.should_set_types():
        with _utils.set_nodes_type_context(graph.get_nodes()):
            self.set_types(*graph.inferred_types)

should_set_types

should_set_types() -> bool

Whether types need to be set.

Implementations should override this method to return True when types have not been set if they are needed.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py
def should_set_types(self) -> bool:
    """Whether types need to be set.

    Implementations should override this method to return `True` when types have not been set if they are needed.
    """
    return False

set_types

set_types(
    state_type: type[StateT], run_end_type: type[RunEndT]
) -> None

Set the types of the state and run end.

This can be used to create type adapters for serializing and deserializing snapshots, e.g. with build_snapshot_list_type_adapter.

Parameters

Name          Type           Description        Default
state_type    type[StateT]   The state type.    required
run_end_type  type[RunEndT]  The run end type.  required

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py
def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:
    """Set the types of the state and run end.

    This can be used to create [type adapters][pydantic.TypeAdapter] for serializing and deserializing snapshots,
    e.g. with [`build_snapshot_list_type_adapter`][pydantic_graph.persistence.build_snapshot_list_type_adapter].

    Args:
        state_type: The state type.
        run_end_type: The run end type.
    """
    pass
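
Taken together, `should_set_types`, `set_types` and `set_graph_types` form a lazy-initialisation handshake: types are requested only while they are still missing. The sketch below shows that interplay with a hypothetical `LazyTypedStore`; it takes the inferred types directly instead of a graph and omits the node-type context manager used by the real `set_graph_types`.

```python
from typing import Optional


class LazyTypedStore:
    """Hypothetical store that needs the state and run-end types before serializing."""

    def __init__(self) -> None:
        self._types: Optional[tuple] = None

    def should_set_types(self) -> bool:
        # Ask for types only while they are still missing.
        return self._types is None

    def set_types(self, state_type: type, run_end_type: type) -> None:
        # A real implementation would typically build a type adapter here.
        self._types = (state_type, run_end_type)

    def set_graph_types(self, inferred_types: tuple) -> None:
        # Same guard as in the base class: call set_types at most once.
        if self.should_set_types():
            self.set_types(*inferred_types)
```

The first `set_graph_types` call stores the types; later calls are no-ops because `should_set_types()` now returns `False`.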

build_snapshot_list_type_adapter

build_snapshot_list_type_adapter(
    state_t: type[StateT], run_end_t: type[RunEndT]
) -> TypeAdapter[list[Snapshot[StateT, RunEndT]]]

Build a type adapter for a list of snapshots.

This method should be called from within set_types, where context variables will be set such that Pydantic can create a schema for NodeSnapshot.node.

Source code in pydantic_graph/pydantic_graph/persistence/__init__.py
def build_snapshot_list_type_adapter(
    state_t: type[StateT], run_end_t: type[RunEndT]
) -> pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]]:
    """Build a type adapter for a list of snapshots.

    This method should be called from within
    [`set_types`][pydantic_graph.persistence.BaseStatePersistence.set_types]
    where context variables will be set such that Pydantic can create a schema for
    [`NodeSnapshot.node`][pydantic_graph.persistence.NodeSnapshot.node].
    """
    return pydantic.TypeAdapter(list[Annotated[Snapshot[state_t, run_end_t], pydantic.Discriminator('kind')]])
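
The adapter above relies on `kind` as a discriminator to decide which snapshot class each serialized item becomes. The sketch below illustrates that idea by hand using only the standard library; it does not use Pydantic, and `NodeStep`, `EndStep`, `STEP_TYPES`, `dump_history` and `load_history` are hypothetical stand-ins.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class NodeStep:
    node: str
    status: str = "created"
    kind: str = "node"


@dataclass
class EndStep:
    result: str
    kind: str = "end"


# Map each discriminator value to the class that should be built for it.
STEP_TYPES = {"node": NodeStep, "end": EndStep}


def dump_history(history: list) -> str:
    """Serialize a list of steps to a JSON string."""
    return json.dumps([asdict(step) for step in history])


def load_history(raw: str) -> list:
    """Rebuild steps, choosing the class from each item's `kind` field."""
    return [STEP_TYPES[item["kind"]](**item) for item in json.loads(raw)]
```

A round trip through `dump_history` and `load_history` yields objects equal to the originals, with each item restored to the class named by its `kind`.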

In memory state persistence.

This module provides simple in memory state persistence for graphs.

SimpleStatePersistence dataclass

Bases: BaseStatePersistence[StateT, RunEndT]

Simple in memory state persistence that just holds the latest snapshot.

If no state persistence implementation is provided when running a graph, this is used by default.

Source code in pydantic_graph/pydantic_graph/persistence/in_mem.py
@dataclass
class SimpleStatePersistence(BaseStatePersistence[StateT, RunEndT]):
    """Simple in memory state persistence that just holds the latest snapshot.

    If no state persistence implementation is provided when running a graph, this is used by default.
    """

    last_snapshot: Snapshot[StateT, RunEndT] | None = None
    """The last snapshot."""

    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        self.last_snapshot = NodeSnapshot(state=state, node=next_node)

    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:
        if self.last_snapshot and self.last_snapshot.id == snapshot_id:
            return
        else:
            await self.snapshot_node(state, next_node)

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        self.last_snapshot = EndSnapshot(state=state, result=end)

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        if self.last_snapshot is None or snapshot_id != self.last_snapshot.id:
            raise LookupError(f'No snapshot found with id={snapshot_id!r}')

        assert isinstance(self.last_snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
        exceptions.GraphNodeStatusError.check(self.last_snapshot.status)
        self.last_snapshot.status = 'running'
        self.last_snapshot.start_ts = _utils.now_utc()

        start = perf_counter()
        try:
            yield
        except Exception:
            self.last_snapshot.duration = perf_counter() - start
            self.last_snapshot.status = 'error'
            raise
        else:
            self.last_snapshot.duration = perf_counter() - start
            self.last_snapshot.status = 'success'

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        if isinstance(self.last_snapshot, NodeSnapshot) and self.last_snapshot.status == 'created':
            self.last_snapshot.status = 'pending'
            return self.last_snapshot

    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        raise NotImplementedError('load is not supported for SimpleStatePersistence')

last_snapshot class-attribute instance-attribute

last_snapshot: Snapshot[StateT, RunEndT] | None = None

The last snapshot.

FullStatePersistence dataclass

Bases: BaseStatePersistence[StateT, RunEndT]

In memory state persistence that holds a list of snapshots.

Source code in pydantic_graph/pydantic_graph/persistence/in_mem.py
@dataclass
class FullStatePersistence(BaseStatePersistence[StateT, RunEndT]):
    """In memory state persistence that holds a list of snapshots."""

    deep_copy: bool = True
    """Whether to deep copy the state and nodes when storing them.

    Defaults to `True` so even if nodes or state are modified after the snapshot is taken,
    the persistence history will record the value at the time of the snapshot.
    """
    history: list[Snapshot[StateT, RunEndT]] = field(default_factory=list)
    """List of snapshots taken during the graph run."""
    _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field(
        default=None, init=False, repr=False
    )

    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        snapshot = NodeSnapshot(
            state=self._prep_state(state),
            node=next_node.deep_copy() if self.deep_copy else next_node,
        )
        self.history.append(snapshot)

    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:
        if not any(s.id == snapshot_id for s in self.history):
            await self.snapshot_node(state, next_node)

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        snapshot = EndSnapshot(
            state=self._prep_state(state),
            result=end.deep_copy_data() if self.deep_copy else end,
        )
        self.history.append(snapshot)

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        try:
            snapshot = next(s for s in self.history if s.id == snapshot_id)
        except StopIteration as e:
            raise LookupError(f'No snapshot found with id={snapshot_id!r}') from e

        assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
        exceptions.GraphNodeStatusError.check(snapshot.status)
        snapshot.status = 'running'
        snapshot.start_ts = _utils.now_utc()
        start = perf_counter()
        try:
            yield
        except Exception:
            snapshot.duration = perf_counter() - start
            snapshot.status = 'error'
            raise
        else:
            snapshot.duration = perf_counter() - start
            snapshot.status = 'success'

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        if snapshot := next((s for s in self.history if isinstance(s, NodeSnapshot) and s.status == 'created'), None):
            snapshot.status = 'pending'
            return snapshot

    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        return self.history

    def should_set_types(self) -> bool:
        return self._snapshots_type_adapter is None

    def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:
        self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type)

    def dump_json(self, *, indent: int | None = None) -> bytes:
        """Dump the history to JSON bytes."""
        assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `dump_json`'
        return self._snapshots_type_adapter.dump_json(self.history, indent=indent)

    def load_json(self, json_data: str | bytes | bytearray) -> None:
        """Load the history from JSON."""
        assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `load_json`'
        self.history = self._snapshots_type_adapter.validate_json(json_data)

    def _prep_state(self, state: StateT) -> StateT:
        """Prepare state for snapshot, uses [`copy.deepcopy`][copy.deepcopy] by default."""
        if not self.deep_copy or state is None:
            return state
        else:
            return copy.deepcopy(state)

deep_copy class-attribute instance-attribute

deep_copy: bool = True

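Whether to deep copy the state and nodes when storing them.

Defaults to True so even if nodes or state are modified after the snapshot is taken, the persistence history will record the value at the time of the snapshot.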
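
The effect of `deep_copy` can be seen with plain `copy.deepcopy`. In this sketch `CounterState` and `take_snapshot` are hypothetical; `take_snapshot` only mirrors the copy-or-share choice that `_prep_state` makes in the listing above.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class CounterState:
    values: list = field(default_factory=list)


def take_snapshot(state: CounterState, deep_copy: bool = True) -> CounterState:
    """Return what would be stored in the history for `state`."""
    return copy.deepcopy(state) if deep_copy else state


state = CounterState(values=[1])
copied = take_snapshot(state, deep_copy=True)
shared = take_snapshot(state, deep_copy=False)

# Mutate the live state after both snapshots were taken.
state.values.append(2)
```

The deep-copied snapshot still holds the values from the time it was taken, while the shared one reflects the later mutation because it is the same object as the live state.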

history class-attribute instance-attribute

history: list[Snapshot[StateT, RunEndT]] = field(
    default_factory=list
)

List of snapshots taken during the graph run.

dump_json

dump_json(*, indent: int | None = None) -> bytes

Dump the history to JSON bytes.

Source code in pydantic_graph/pydantic_graph/persistence/in_mem.py
def dump_json(self, *, indent: int | None = None) -> bytes:
    """Dump the history to JSON bytes."""
    assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `dump_json`'
    return self._snapshots_type_adapter.dump_json(self.history, indent=indent)

load_json

load_json(json_data: str | bytes | bytearray) -> None

Load the history from JSON.

Source code in pydantic_graph/pydantic_graph/persistence/in_mem.py
def load_json(self, json_data: str | bytes | bytearray) -> None:
    """Load the history from JSON."""
    assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `load_json`'
    self.history = self._snapshots_type_adapter.validate_json(json_data)

FileStatePersistence dataclass

Bases: BaseStatePersistence[StateT, RunEndT]

File based state persistence that holds graph run state in a JSON file.

Source code in pydantic_graph/pydantic_graph/persistence/file.py
@dataclass
class FileStatePersistence(BaseStatePersistence[StateT, RunEndT]):
    """File based state persistence that holds graph run state in a JSON file."""

    json_file: Path
    """Path to the JSON file where the snapshots are stored.

    You should use a different file for each graph run, but a single file should be reused for multiple
    steps of the same run.

    For example if you have a run ID of the form `run_123abc`, you might create a `FileStatePersistence` thus:

    ```py
    from pathlib import Path

    from pydantic_graph.persistence.file import FileStatePersistence

    run_id = 'run_123abc'
    persistence = FileStatePersistence(Path('runs') / f'{run_id}.json')
    ```
    """
    _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field(
        default=None, init=False, repr=False
    )

    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        await self._append_save(NodeSnapshot(state=state, node=next_node))

    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:
        async with self._lock():
            snapshots = await self.load_all()
            if not any(s.id == snapshot_id for s in snapshots):
                await self._append_save(NodeSnapshot(state=state, node=next_node), lock=False)

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        await self._append_save(EndSnapshot(state=state, result=end))

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        async with self._lock():
            snapshots = await self.load_all()
            try:
                snapshot = next(s for s in snapshots if s.id == snapshot_id)
            except StopIteration as e:
                raise LookupError(f'No snapshot found with id={snapshot_id!r}') from e

            assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
            exceptions.GraphNodeStatusError.check(snapshot.status)
            snapshot.status = 'running'
            snapshot.start_ts = _utils.now_utc()
            await self._save(snapshots)

        start = perf_counter()
        try:
            yield
        except Exception:
            duration = perf_counter() - start
            async with self._lock():
                await _graph_utils.run_in_executor(self._after_run_sync, snapshot_id, duration, 'error')
            raise
        else:
            snapshot.duration = perf_counter() - start
            async with self._lock():
                await _graph_utils.run_in_executor(self._after_run_sync, snapshot_id, snapshot.duration, 'success')

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        async with self._lock():
            snapshots = await self.load_all()
            if snapshot := next((s for s in snapshots if isinstance(s, NodeSnapshot) and s.status == 'created'), None):
                snapshot.status = 'pending'
                await self._save(snapshots)
                return snapshot

    def should_set_types(self) -> bool:
        """Whether types need to be set."""
        return self._snapshots_type_adapter is None

    def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:
        self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type)

    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        return await _graph_utils.run_in_executor(self._load_sync)

    def _load_sync(self) -> list[Snapshot[StateT, RunEndT]]:
        assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set'
        try:
            content = self.json_file.read_bytes()
        except FileNotFoundError:
            return []
        else:
            return self._snapshots_type_adapter.validate_json(content)

    def _after_run_sync(self, snapshot_id: str, duration: float, status: SnapshotStatus) -> None:
        snapshots = self._load_sync()
        snapshot = next(s for s in snapshots if s.id == snapshot_id)
        assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
        snapshot.duration = duration
        snapshot.status = status
        self._save_sync(snapshots)

    async def _save(self, snapshots: list[Snapshot[StateT, RunEndT]]) -> None:
        await _graph_utils.run_in_executor(self._save_sync, snapshots)

    def _save_sync(self, snapshots: list[Snapshot[StateT, RunEndT]]) -> None:
        assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set'
        self.json_file.write_bytes(self._snapshots_type_adapter.dump_json(snapshots, indent=2))

    async def _append_save(self, snapshot: Snapshot[StateT, RunEndT], *, lock: bool = True) -> None:
        assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set'
        async with AsyncExitStack() as stack:
            if lock:
                await stack.enter_async_context(self._lock())
            snapshots = await self.load_all()
            snapshots.append(snapshot)
            await self._save(snapshots)

    @asynccontextmanager
    async def _lock(self, *, timeout: float = 1.0) -> AsyncIterator[None]:
        """Lock a file by checking and writing a `.pydantic-graph-persistence-lock` to it.

        Args:
            timeout: how long to wait for the lock

        Returns: an async context manager that holds the lock
        """
        lock_file = self.json_file.parent / f'{self.json_file.name}.pydantic-graph-persistence-lock'
        lock_id = secrets.token_urlsafe().encode()
        await asyncio.wait_for(_get_lock(lock_file, lock_id), timeout=timeout)
        try:
            yield
        finally:
            await _graph_utils.run_in_executor(lock_file.unlink, missing_ok=True)
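
The `_lock` method above delegates acquisition to a `_get_lock` helper that is not shown on this page. As an illustration of the general lock-file idea only, the sketch below uses exclusive file creation (`open(..., "x")`) with a timeout; this is an assumption and may differ from what the library does. `file_lock`, `_acquire`, `demo_lock` and the `.lock` suffix are all hypothetical.

```python
import asyncio
import tempfile
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncIterator


async def _acquire(lock_file: Path) -> None:
    # Mode 'x' raises FileExistsError if the file already exists, so
    # creating the file doubles as an atomic test-and-set.
    while True:
        try:
            with open(lock_file, "x"):
                return
        except FileExistsError:
            await asyncio.sleep(0.01)  # another holder; retry shortly


@asynccontextmanager
async def file_lock(target: Path, *, timeout: float = 1.0) -> AsyncIterator[None]:
    lock_file = target.parent / f"{target.name}.lock"  # hypothetical suffix
    await asyncio.wait_for(_acquire(lock_file), timeout=timeout)
    try:
        yield
    finally:
        # Release the lock even if the protected block raised.
        lock_file.unlink(missing_ok=True)


async def demo_lock() -> tuple:
    with tempfile.TemporaryDirectory() as tmp:
        target = Path(tmp) / "run_123abc.json"
        lock_file = target.parent / f"{target.name}.lock"
        async with file_lock(target):
            held = lock_file.exists()
            try:
                # A second acquisition attempt times out while the lock is held.
                await asyncio.wait_for(_acquire(lock_file), timeout=0.05)
                contended = False
            except asyncio.TimeoutError:
                contended = True
        released = not lock_file.exists()
    return held, contended, released
```

In `demo_lock`, the lock file exists while the block runs, a competing attempt times out, and the file is gone afterwards. A lock left behind by a crashed process would block later callers until removed, which this sketch does not handle.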

json_file instance-attribute

json_file: Path

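Path to the JSON file where the snapshots are stored.

You should use a different file for each graph run, but a single file should be reused for multiple steps of the same run.

For example, if you have a run ID of the form run_123abc, you might create a FileStatePersistence thus:

from pathlib import Path

from pydantic_graph.persistence.file import FileStatePersistence

run_id = 'run_123abc'
persistence = FileStatePersistence(Path('runs') / f'{run_id}.json')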

should_set_types

should_set_types() -> bool

Whether types need to be set.

Source code in pydantic_graph/pydantic_graph/persistence/file.py
def should_set_types(self) -> bool:
    """Whether types need to be set."""
    return self._snapshots_type_adapter is None