测试与评估
在 PydanticAI 和 LLM 集成中,通常有两种不同的测试类型
- 单元测试 — 针对您的应用程序代码的测试,以及它是否行为正确
- 评估 — 针对 LLM 的测试,以及其响应的好坏程度
在大多数情况下,这两种测试类型具有相当独立的目标和考虑因素。
单元测试
针对 PydanticAI 代码的单元测试就像针对任何其他 Python 代码的单元测试一样。
因为在大多数情况下它们都不是什么新鲜事物,所以我们有非常完善的工具和模式来编写和运行这些类型的测试。
除非您非常确定自己有更好的方法,否则您可能需要大致遵循以下策略
- 使用
pytest
作为您的测试框架 - 如果您发现自己需要输入很长的断言,请使用 inline-snapshot
- 类似地,dirty-equals 对于比较大型数据结构可能很有用
- 使用
TestModel
或FunctionModel
代替您的实际模型,以避免真实 LLM 调用的使用量、延迟和可变性 - 使用
Agent.override
在您的应用程序逻辑中替换您的模型 - 全局设置
ALLOW_MODEL_REQUESTS=False
以阻止意外向非测试模型发出任何请求
使用 TestModel
进行单元测试
锻炼您的大部分应用程序代码最简单和最快的方法是使用 TestModel
,这将(默认情况下)调用代理中的所有工具,然后根据代理的返回类型返回纯文本或结构化响应。
TestModel
不是魔法
TestModel
的“巧妙”(但不太巧妙)之处在于,它将尝试基于已注册工具的模式为函数工具和结果类型生成有效的结构化数据。
TestModel
中没有 ML 或 AI,它只是普通的程序化 Python 代码,试图生成满足工具 JSON 模式的数据。
结果数据看起来不会漂亮或相关,但在大多数情况下它应该通过 Pydantic 的验证。如果您想要更复杂的东西,请使用 FunctionModel
并编写您自己的数据生成逻辑。
让我们为以下应用程序代码编写单元测试
import asyncio
from datetime import date
from pydantic_ai import Agent, RunContext
from fake_database import DatabaseConn # (1)!
from weather_service import WeatherService # (2)!
weather_agent = Agent(
'openai:gpt-4o',
deps_type=WeatherService,
system_prompt='Providing a weather forecast at the locations the user provides.',
)
@weather_agent.tool
def weather_forecast(
ctx: RunContext[WeatherService], location: str, forecast_date: date
) -> str:
if forecast_date < date.today(): # (3)!
return ctx.deps.get_historic_weather(location, forecast_date)
else:
return ctx.deps.get_forecast(location, forecast_date)
async def run_weather_forecast( # (4)!
user_prompts: list[tuple[str, int]], conn: DatabaseConn
):
"""Run weather forecast for a list of user prompts and save."""
async with WeatherService() as weather_service:
async def run_forecast(prompt: str, user_id: int):
result = await weather_agent.run(prompt, deps=weather_service)
await conn.store_forecast(user_id, result.data)
# run all prompts in parallel
await asyncio.gather(
*(run_forecast(prompt, user_id) for (prompt, user_id) in user_prompts)
)
DatabaseConn
是一个持有数据库连接的类WeatherService
具有获取天气预报和有关天气的历史数据的方法- 我们需要根据日期是过去还是将来调用不同的端点,您将在下面看到为什么这种细微差别很重要
- 此函数是我们想要测试的代码,以及它使用的代理
这里我们有一个函数,它接受 (user_prompt, user_id)
元组的列表,获取每个提示的天气预报,并将结果存储在数据库中。
我们想要测试此代码,而无需模拟某些对象或修改我们的代码,以便我们可以传入测试对象。
以下是我们如何使用 TestModel
编写测试
from datetime import timezone
import pytest
from dirty_equals import IsNow
from pydantic_ai import models, capture_run_messages
from pydantic_ai.models.test import TestModel
from pydantic_ai.messages import (
ModelResponse,
SystemPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
ModelRequest,
)
from fake_database import DatabaseConn
from weather_app import run_weather_forecast, weather_agent
pytestmark = pytest.mark.anyio # (1)!
models.ALLOW_MODEL_REQUESTS = False # (2)!
async def test_forecast():
conn = DatabaseConn()
user_id = 1
with capture_run_messages() as messages:
with weather_agent.override(model=TestModel()): # (3)!
prompt = 'What will the weather be like in London on 2024-11-28?'
await run_weather_forecast([(prompt, user_id)], conn) # (4)!
forecast = await conn.get_forecast(user_id)
assert forecast == '{"weather_forecast":"Sunny with a chance of rain"}' # (5)!
assert messages == [ # (6)!
ModelRequest(
parts=[
SystemPromptPart(
content='Providing a weather forecast at the locations the user provides.',
timestamp=IsNow(tz=timezone.utc),
),
UserPromptPart(
content='What will the weather be like in London on 2024-11-28?',
timestamp=IsNow(tz=timezone.utc), # (7)!
),
]
),
ModelResponse(
parts=[
ToolCallPart(
tool_name='weather_forecast',
args={
'location': 'a',
'forecast_date': '2024-01-01', # (8)!
},
tool_call_id=None,
)
],
model_name='test',
timestamp=IsNow(tz=timezone.utc),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='weather_forecast',
content='Sunny with a chance of rain',
tool_call_id=None,
timestamp=IsNow(tz=timezone.utc),
),
],
),
ModelResponse(
parts=[
TextPart(
content='{"weather_forecast":"Sunny with a chance of rain"}',
)
],
model_name='test',
timestamp=IsNow(tz=timezone.utc),
),
]
- 我们正在使用 anyio 运行异步测试。
- 这是一种安全措施,以确保我们在测试时不会意外地向 LLM 发出真实请求,有关更多详细信息,请参阅
ALLOW_MODEL_REQUESTS
。 - 我们正在使用
Agent.override
将代理的模型替换为TestModel
,关于override
的好处是,我们可以在代理内部替换模型,而无需访问代理run*
方法调用站点。 - 现在我们在
override
上下文管理器中调用我们想要测试的函数。 - 默认情况下,
TestModel
将返回一个 JSON 字符串,总结所做的工具调用以及返回的内容。如果您想自定义响应以使其更符合领域,您可以在定义TestModel
时添加custom_result_text='Sunny'
。 - 到目前为止,我们实际上并不知道调用了哪些工具以及使用了哪些值,我们可以使用
capture_run_messages
来检查最近一次运行的消息,并断言代理和模型之间的交换按预期发生。 IsNow
助手允许我们即使对于包含随时间变化的时间戳的数据也使用声明性断言。TestModel
没有做任何巧妙的事情来从提示中提取值,因此这些值是硬编码的。
使用 FunctionModel
进行单元测试
以上测试是一个很好的开始,但细心的读者会注意到,由于 TestModel
使用过去的日期调用 weather_forecast
,因此永远不会调用 WeatherService.get_forecast
。
为了充分锻炼 weather_forecast
,我们需要使用 FunctionModel
来自定义工具的调用方式。
这是一个使用 FunctionModel
测试具有自定义输入的 weather_forecast
工具的示例
import re
import pytest
from pydantic_ai import models
from pydantic_ai.messages import (
ModelMessage,
ModelResponse,
TextPart,
ToolCallPart,
)
from pydantic_ai.models.function import AgentInfo, FunctionModel
from fake_database import DatabaseConn
from weather_app import run_weather_forecast, weather_agent
pytestmark = pytest.mark.anyio
models.ALLOW_MODEL_REQUESTS = False
def call_weather_forecast( # (1)!
messages: list[ModelMessage], info: AgentInfo
) -> ModelResponse:
if len(messages) == 1:
# first call, call the weather forecast tool
user_prompt = messages[0].parts[-1]
m = re.search(r'\d{4}-\d{2}-\d{2}', user_prompt.content)
assert m is not None
args = {'location': 'London', 'forecast_date': m.group()} # (2)!
return ModelResponse(parts=[ToolCallPart('weather_forecast', args)])
else:
# second call, return the forecast
msg = messages[-1].parts[0]
assert msg.part_kind == 'tool-return'
return ModelResponse(parts=[TextPart(f'The forecast is: {msg.content}')])
async def test_forecast_future():
conn = DatabaseConn()
user_id = 1
with weather_agent.override(model=FunctionModel(call_weather_forecast)): # (3)!
prompt = 'What will the weather be like in London on 2032-01-01?'
await run_weather_forecast([(prompt, user_id)], conn)
forecast = await conn.get_forecast(user_id)
assert forecast == 'The forecast is: Rainy with a chance of sun'
- 我们定义了一个函数
call_weather_forecast
,它将由FunctionModel
代替 LLM 调用,此函数可以访问构成运行的ModelMessage
列表,以及包含有关代理和函数工具以及返回工具信息的AgentInfo
。 - 我们的函数稍微智能一些,因为它试图从提示中提取日期,但只是硬编码了位置。
- 我们使用
FunctionModel
将代理的模型替换为我们的自定义函数。
通过 pytest fixtures 覆盖模型
如果您正在编写大量都需要覆盖模型的测试,则可以使用 pytest fixtures 以可重用的方式使用 TestModel
或 FunctionModel
覆盖模型。
这是一个使用 TestModel
覆盖模型的 fixture 示例
import pytest
from weather_app import weather_agent
from pydantic_ai.models.test import TestModel
@pytest.fixture
def override_weather_agent():
with weather_agent.override(model=TestModel()):
yield
async def test_forecast(override_weather_agent: None):
...
# test code here
评估
“评估”指的是评估模型在特定应用中的性能。
警告
与单元测试不同,评估是一门新兴的艺术/科学;任何声称确切知道应该如何定义您的评估的人都可以放心地忽略。
评估通常更像基准测试而不是单元测试,它们永远不会“通过”,尽管它们会“失败”;您主要关心的是它们如何随时间变化。
由于评估需要针对真实模型运行,因此运行速度可能很慢且成本很高,您通常不希望在每次提交的 CI 中运行它们。
衡量性能
评估中最难的部分是衡量模型的性能如何。
在某些情况下(例如,用于生成 SQL 的代理),有一些简单易于运行的测试可以用来衡量性能(例如,SQL 是否有效?它是否返回正确的结果?它是否只返回正确的结果?)。
在其他情况下(例如,提供戒烟建议的代理),可能很难或不可能对性能进行定量衡量——在吸烟的情况下,您真的需要进行为期数月的双盲试验,然后等待 40 年并观察健康结果,以了解对您的提示的更改是否有所改进。
您可以使用几种不同的策略来衡量性能
- 端到端、自包含的测试 — 就像 SQL 示例一样,我们可以近乎即时地测试代理的最终结果
- 合成自包含的测试 — 编写单元测试风格的检查,以检查输出是否符合预期,例如
'chewing gum' in response
这样的检查,虽然这些检查可能看起来很简单,但它们可能很有用,一个好的特点是,当它们失败时,很容易判断出哪里出了问题 - LLM 评估 LLM — 使用另一个模型,甚至使用具有不同提示的相同模型来评估代理的性能(就像班级互相批改作业,因为老师宿醉一样),虽然这种方法的缺点和复杂性是显而易见的,但有些人认为在适当的情况下它可以成为一个有用的工具
- 生产环境中的评估 — 衡量代理在生产环境中的最终结果,然后创建性能的定量度量,这样您就可以在更改提示或使用的模型时轻松衡量随时间的变化,logfire 在这种情况下非常有用,因为您可以编写自定义查询来衡量代理的性能
系统提示定制
系统提示是开发人员控制代理行为的主要工具,因此能够自定义系统提示并查看性能如何变化通常很有用。当系统提示包含示例列表并且您想了解更改该列表如何影响模型的性能时,这一点尤其重要。
假设我们有以下应用程序,用于运行从用户提示生成的 SQL(为简洁起见,此示例省略了很多细节,有关更完整的代码,请参阅 SQL gen 示例)
import json
from pathlib import Path
from typing import Union
from pydantic_ai import Agent, RunContext
from fake_database import DatabaseConn
class SqlSystemPrompt: # (1)!
def __init__(
self, examples: Union[list[dict[str, str]], None] = None, db: str = 'PostgreSQL'
):
if examples is None:
# if examples aren't provided, load them from file, this is the default
with Path('examples.json').open('rb') as f:
self.examples = json.load(f)
else:
self.examples = examples
self.db = db
def build_prompt(self) -> str: # (2)!
return f"""\
Given the following {self.db} table of records, your job is to
write a SQL query that suits the user's request.
Database schema:
CREATE TABLE records (
...
);
{''.join(self.format_example(example) for example in self.examples)}
"""
@staticmethod
def format_example(example: dict[str, str]) -> str: # (3)!
return f"""\
<example>
<request>{example['request']}</request>
<sql>{example['sql']}</sql>
</example>
"""
sql_agent = Agent(
'google-gla:gemini-1.5-flash',
deps_type=SqlSystemPrompt,
)
@sql_agent.system_prompt
async def system_prompt(ctx: RunContext[SqlSystemPrompt]) -> str:
return ctx.deps.build_prompt()
async def user_search(user_prompt: str) -> list[dict[str, str]]:
"""Search the database based on the user's prompts."""
... # (4)!
result = await sql_agent.run(user_prompt, deps=SqlSystemPrompt())
conn = DatabaseConn()
return await conn.execute(result.data)
SqlSystemPrompt
类用于构建系统提示,可以使用示例列表和数据库类型对其进行自定义。我们将其实现为一个单独的类,作为 dep 传递给代理,以便我们可以在评估期间通过依赖注入覆盖输入和逻辑。build_prompt
方法从示例和数据库类型构建系统提示。- 有些人认为,如果示例格式为 XML,LLM 更可能生成良好的响应,因为它更容易识别字符串的结尾,请参阅 #93。
- 实际上,您会在这里有更多的逻辑,这使得独立于更广泛的应用程序运行代理变得不切实际。
examples.json
看起来像这样
request: show me error records with the tag "foobar"
response: SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags)
{
"examples": [
{
"request": "Show me all records",
"sql": "SELECT * FROM records;"
},
{
"request": "Show me all records from 2021",
"sql": "SELECT * FROM records WHERE date_trunc('year', date) = '2021-01-01';"
},
{
"request": "show me error records with the tag 'foobar'",
"sql": "SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags);"
},
...
]
}
现在我们需要一种量化 SQL 生成成功率的方法,以便我们可以判断对代理的更改如何影响其性能。
我们可以使用 Agent.override
将系统提示替换为使用示例子集的自定义提示,然后运行应用程序代码(在本例中为 user_search
)。我们还运行示例中的实际 SQL,并将示例 SQL 中的“正确”结果与代理生成的 SQL 进行比较。(我们比较运行 SQL 的结果而不是 SQL 本身,因为 SQL 在语义上可能是等效的,但编写方式不同)。
为了获得性能的定量度量,我们按如下方式为每次运行分配点数
- -100 分,如果生成的 SQL 无效
- -1 分,对于代理返回的每一行(因此不鼓励返回大量结果)
- +5 分,对于代理返回的与预期结果匹配的每一行
我们使用 5 折交叉验证来使用我们现有的示例集判断代理的性能。
import json
import statistics
from pathlib import Path
from itertools import chain
from fake_database import DatabaseConn, QueryError
from sql_app import sql_agent, SqlSystemPrompt, user_search
async def main():
with Path('examples.json').open('rb') as f:
examples = json.load(f)
# split examples into 5 folds
fold_size = len(examples) // 5
folds = [examples[i : i + fold_size] for i in range(0, len(examples), fold_size)]
conn = DatabaseConn()
scores = []
for i, fold in enumerate(folds):
fold_score = 0
# build all other folds into a list of examples
other_folds = list(chain(*(f for j, f in enumerate(folds) if j != i)))
# create a new system prompt with the other fold examples
system_prompt = SqlSystemPrompt(examples=other_folds)
# override the system prompt with the new one
with sql_agent.override(deps=system_prompt):
for case in fold:
try:
agent_results = await user_search(case['request'])
except QueryError as e:
print(f'Fold {i} {case}: {e}')
fold_score -= 100
else:
# get the expected results using the SQL from this case
expected_results = await conn.execute(case['sql'])
agent_ids = [r['id'] for r in agent_results]
# each returned value has a score of -1
fold_score -= len(agent_ids)
expected_ids = {r['id'] for r in expected_results}
# each return value that matches the expected value has a score of 3
fold_score += 5 * len(set(agent_ids) & expected_ids)
scores.append(fold_score)
overall_score = statistics.mean(scores)
print(f'Overall score: {overall_score:0.2f}')
#> Overall score: 12.00
然后我们可以更改提示、模型或示例,并查看分数如何随时间变化。