跳转到内容

单元测试

为 Pydantic AI 代码编写单元测试就像为任何其他 Python 代码编写单元测试一样。

因为在很大程度上它们并非新生事物,我们已经有了非常成熟的工具和模式来编写和运行这类测试。

除非你真的确信自己有更好的方法,否则你可能需要大致遵循以下策略:

  • 使用 pytest作为你的测试工具
  • 如果你发现自己正在输入冗长的断言,请使用 inline-snapshot
  • 同样,dirty-equals 对于比较大型数据结构也很有用
  • 使用 TestModelFunctionModel 来代替你的实际模型,以避免实际 LLM 调用带来的使用成本、延迟和可变性
  • 使用 Agent.override 在你的应用程序逻辑内部替换代理(agent)的模型、依赖项或工具集
  • 全局设置 ALLOW_MODEL_REQUESTS=False,以阻止任何请求意外地发送到非测试模型

使用 TestModel 进行单元测试

运行大部分应用程序代码最简单、最快的方法是使用 TestModel,它会(默认情况下)调用代理中的所有工具,然后根据代理的返回类型返回纯文本或结构化响应。

TestModel 并非魔法

TestModel 的“巧妙”(但不太聪明)之处在于,它会尝试根据已注册工具的模式(schema),为函数工具输出类型生成有效的结构化数据。

TestModel 中没有机器学习或人工智能,它只是普通的程序化 Python 代码,试图生成满足工具的 JSON schema 的数据。

生成的数据看起来不会很美观或相关,但在大多数情况下应该能通过 Pydantic 的验证。如果你想要更复杂的东西,请使用 FunctionModel 并编写自己的数据生成逻辑。

让我们为以下应用程序代码编写单元测试

weather_app.py
import asyncio
from datetime import date

from pydantic_ai import Agent, RunContext

from fake_database import DatabaseConn  # (1)!
from weather_service import WeatherService  # (2)!

weather_agent = Agent(
    'openai:gpt-4o',
    deps_type=WeatherService,
    system_prompt='Providing a weather forecast at the locations the user provides.',
)


@weather_agent.tool
def weather_forecast(
    ctx: RunContext[WeatherService], location: str, forecast_date: date
) -> str:
    if forecast_date < date.today():  # (3)!
        return ctx.deps.get_historic_weather(location, forecast_date)
    else:
        return ctx.deps.get_forecast(location, forecast_date)


async def run_weather_forecast(  # (4)!
    user_prompts: list[tuple[str, int]], conn: DatabaseConn
):
    """Run weather forecast for a list of user prompts and save."""
    async with WeatherService() as weather_service:

        async def run_forecast(prompt: str, user_id: int):
            result = await weather_agent.run(prompt, deps=weather_service)
            await conn.store_forecast(user_id, result.output)

        # run all prompts in parallel
        await asyncio.gather(
            *(run_forecast(prompt, user_id) for (prompt, user_id) in user_prompts)
        )
  1. DatabaseConn 是一个持有数据库连接的类
  2. WeatherService 拥有获取天气预报和天气历史数据的方法
  3. 我们需要根据日期是过去还是未来来调用不同的端点,下面你将看到为什么这个细微差别很重要
  4. 这个函数是我们要测试的代码,连同它使用的代理一起

这里我们有一个函数,它接受一个 (user_prompt, user_id) 元组的列表,为每个提示获取天气预报,并将结果存储在数据库中。

我们希望在不模拟某些对象或修改代码以便传入测试对象的情况下测试这段代码。

以下是我们如何使用 TestModel 编写测试

test_weather_app.py
from datetime import timezone
import pytest

from dirty_equals import IsNow, IsStr

from pydantic_ai import models, capture_run_messages, RequestUsage
from pydantic_ai.models.test import TestModel
from pydantic_ai.messages import (
    ModelResponse,
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
    ModelRequest,
)

from fake_database import DatabaseConn
from weather_app import run_weather_forecast, weather_agent

pytestmark = pytest.mark.anyio  # (1)!
models.ALLOW_MODEL_REQUESTS = False  # (2)!


async def test_forecast():
    conn = DatabaseConn()
    user_id = 1
    with capture_run_messages() as messages:
        with weather_agent.override(model=TestModel()):  # (3)!
            prompt = 'What will the weather be like in London on 2024-11-28?'
            await run_weather_forecast([(prompt, user_id)], conn)  # (4)!

    forecast = await conn.get_forecast(user_id)
    assert forecast == '{"weather_forecast":"Sunny with a chance of rain"}'  # (5)!

    assert messages == [  # (6)!
        ModelRequest(
            parts=[
                SystemPromptPart(
                    content='Providing a weather forecast at the locations the user provides.',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                UserPromptPart(
                    content='What will the weather be like in London on 2024-11-28?',
                    timestamp=IsNow(tz=timezone.utc),  # (7)!
                ),
            ]
        ),
        ModelResponse(
            parts=[
                ToolCallPart(
                    tool_name='weather_forecast',
                    args={
                        'location': 'a',
                        'forecast_date': '2024-01-01',  # (8)!
                    },
                    tool_call_id=IsStr(),
                )
            ],
            usage=RequestUsage(
                input_tokens=71,
                output_tokens=7,
            ),
            model_name='test',
            timestamp=IsNow(tz=timezone.utc),
        ),
        ModelRequest(
            parts=[
                ToolReturnPart(
                    tool_name='weather_forecast',
                    content='Sunny with a chance of rain',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
                ),
            ],
        ),
        ModelResponse(
            parts=[
                TextPart(
                    content='{"weather_forecast":"Sunny with a chance of rain"}',
                )
            ],
            usage=RequestUsage(
                input_tokens=77,
                output_tokens=16,
            ),
            model_name='test',
            timestamp=IsNow(tz=timezone.utc),
        ),
    ]
  1. 我们正在使用 anyio 来运行异步测试。
  2. 这是一项安全措施,确保我们在测试时不会意外地向 LLM 发出真实请求,更多详情请参见 ALLOW_MODEL_REQUESTS
  3. 我们正在使用 Agent.override 将代理的模型替换为 TestModeloverride 的好处是,我们可以在代理内部替换模型,而无需访问代理 run* 方法的调用点。
  4. 现在我们在 override 上下文管理器中调用我们想要测试的函数。
  5. 默认情况下,TestModel 会返回一个 JSON 字符串,总结所做的工具调用以及返回的内容。如果你想将响应自定义为更贴近领域的内容,可以在定义 TestModel 时添加 custom_output_text='Sunny'
  6. 到目前为止,我们实际上并不知道调用了哪些工具以及使用了哪些值,我们可以使用 capture_run_messages 来检查最近一次运行的消息,并断言代理和模型之间的交互按预期发生。
  7. [IsNow][dirty_equals.IsNow] 辅助工具允许我们即使在数据包含随时间变化的时间戳时,也能使用声明式断言。
  8. TestModel 并没有做任何聪明的事情来从提示中提取值,所以这些值是硬编码的。

使用 FunctionModel 进行单元测试

上述测试是一个很好的开始,但细心的读者会注意到 WeatherService.get_forecast 从未被调用,因为 TestModel 使用过去的日期调用了 weather_forecast

为了全面测试 weather_forecast,我们需要使用 FunctionModel 来自定义工具的调用方式。

这是一个使用 FunctionModel 通过自定义输入来测试 weather_forecast 工具的示例

test_weather_app2.py
import re

import pytest

from pydantic_ai import models
from pydantic_ai.messages import (
    ModelMessage,
    ModelResponse,
    TextPart,
    ToolCallPart,
)
from pydantic_ai.models.function import AgentInfo, FunctionModel

from fake_database import DatabaseConn
from weather_app import run_weather_forecast, weather_agent

pytestmark = pytest.mark.anyio
models.ALLOW_MODEL_REQUESTS = False


def call_weather_forecast(  # (1)!
    messages: list[ModelMessage], info: AgentInfo
) -> ModelResponse:
    if len(messages) == 1:
        # first call, call the weather forecast tool
        user_prompt = messages[0].parts[-1]
        m = re.search(r'\d{4}-\d{2}-\d{2}', user_prompt.content)
        assert m is not None
        args = {'location': 'London', 'forecast_date': m.group()}  # (2)!
        return ModelResponse(parts=[ToolCallPart('weather_forecast', args)])
    else:
        # second call, return the forecast
        msg = messages[-1].parts[0]
        assert msg.part_kind == 'tool-return'
        return ModelResponse(parts=[TextPart(f'The forecast is: {msg.content}')])


async def test_forecast_future():
    conn = DatabaseConn()
    user_id = 1
    with weather_agent.override(model=FunctionModel(call_weather_forecast)):  # (3)!
        prompt = 'What will the weather be like in London on 2032-01-01?'
        await run_weather_forecast([(prompt, user_id)], conn)

    forecast = await conn.get_forecast(user_id)
    assert forecast == 'The forecast is: Rainy with a chance of sun'
  1. 我们定义了一个函数 call_weather_forecastFunctionModel 将调用它来代替 LLM,该函数可以访问构成运行的 ModelMessage 列表,以及包含有关代理、函数工具和返回工具信息的 AgentInfo
  2. 我们的函数稍微智能一些,它会尝试从提示中提取日期,但只是硬编码了位置。
  3. 我们使用 FunctionModel 将代理的模型替换为我们的自定义函数。

通过 pytest fixtures 覆盖模型

如果你正在编写大量都需要覆盖模型的测试,你可以使用 pytest fixtures,以可重用的方式用 TestModelFunctionModel 来覆盖模型。

这是一个使用 TestModel 覆盖模型的 fixture 示例

test_agent.py
import pytest

from pydantic_ai.models.test import TestModel

from weather_app import weather_agent


@pytest.fixture
def override_weather_agent():
    with weather_agent.override(model=TestModel()):
        yield


async def test_forecast(override_weather_agent: None):
    ...
    # test code here