
Data Analyst

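Sometimes in an agent workflow, the agent does not need to know the exact output of a tool but still needs to process that output in some way. This is especially common in data analytics: the agent needs to know that the result of a query tool is a DataFrame with certain named columns, but not necessarily the content of every row.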
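
To make the idea concrete, the sketch below describes a DataFrame to a model by its row count and column names instead of its rows. The helper `describe_frame` and the toy data are hypothetical illustrations, not part of Pydantic AI or of the example on this page.

```python
import pandas as pd


def describe_frame(ref: str, frame: pd.DataFrame) -> str:
    """Hypothetical helper: summarize a DataFrame without exposing its rows."""
    columns = ', '.join(f'{name} ({dtype})' for name, dtype in frame.dtypes.items())
    return f'{ref}: {len(frame)} rows; columns: {columns}'


# Toy data standing in for a much larger query result.
frame = pd.DataFrame({'text': ['great film', 'dull'], 'label': [1, 0]})
summary = describe_frame('Out[1]', frame)
print(summary)
```

A summary like this stays the same size however many rows the DataFrame holds, which is what keeps the model's context small.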

With Pydantic AI, you can use a dependencies object to store the result from one tool and use it in another tool.
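
The pattern can be sketched without any agent framework. The `ResultStore` class below is a hypothetical stand-in for a dependencies object, using only the standard library: one step stores a result and passes on a short reference, and a later step resolves that reference back to the data.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ResultStore:
    """Hypothetical stand-in for a dependencies object that holds tool results."""

    output: dict[str, Any] = field(default_factory=dict)

    def store(self, value: Any) -> str:
        # Hand back a short reference instead of the value itself.
        ref = f'Out[{len(self.output) + 1}]'
        self.output[ref] = value
        return ref

    def get(self, ref: str) -> Any:
        # A KeyError here means the reference was never issued.
        return self.output[ref]


store = ResultStore()
rows = [{'text': 'great film', 'label': 1}, {'text': 'dull', 'label': 0}]

# "Tool one" stores a result and reports only the reference.
ref = store.store(rows)

# "Tool two" receives the reference and resolves it to the real data.
negatives = [row for row in store.get(ref) if row['label'] == 0]
summary_ref = store.store(negatives)
print(ref, summary_ref, len(store.get(summary_ref)))  # prints: Out[1] Out[2] 1
```

Only the reference strings would travel through the model's messages in such a design; the data itself stays in the store.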

In this example, we'll build an agent that analyzes the Rotten Tomatoes movie review dataset from Cornell.

Demonstrates: agent dependencies.

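Running the Example

With dependencies installed and environment variables set, run one of the following (with plain Python or with uv):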

python -m pydantic_ai_examples.data_analyst
uv run -m pydantic_ai_examples.data_analyst

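Output (debug):

Based on my analysis of the Cornell Movie Review dataset (rotten_tomatoes), there are 4,265 negative comments in the training split. These are the reviews labeled as 'neg' (represented by 0 in the dataset).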

Example Code

data_analyst.py
from dataclasses import dataclass, field

import datasets
import duckdb
import pandas as pd

from pydantic_ai import Agent, ModelRetry, RunContext


@dataclass
class AnalystAgentDeps:
    output: dict[str, pd.DataFrame] = field(default_factory=dict)

    def store(self, value: pd.DataFrame) -> str:
        """Store the output in deps and return the reference such as Out[1] to be used by the LLM."""
        ref = f'Out[{len(self.output) + 1}]'
        self.output[ref] = value
        return ref

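    def get(self, ref: str) -> pd.DataFrame:
        """Return the DataFrame stored under `ref`, or ask the model to retry if the reference is unknown."""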
        if ref not in self.output:
            raise ModelRetry(
                f'Error: {ref} is not a valid variable reference. Check the previous messages and try again.'
            )
        return self.output[ref]


analyst_agent = Agent(
    'openai:gpt-4o',
    deps_type=AnalystAgentDeps,
    instructions='You are a data analyst and your job is to analyze the data according to the user request.',
)


@analyst_agent.tool
def load_dataset(
    ctx: RunContext[AnalystAgentDeps],
    path: str,
    split: str = 'train',
) -> str:
    """Load the `split` of dataset `path` from Hugging Face.

    Args:
        ctx: Pydantic AI agent RunContext
        path: name of the dataset in the form of `<user_name>/<dataset_name>`
        split: load the split of the dataset (default: "train")
    """
    # begin load data from hf
    builder = datasets.load_dataset_builder(path)  # pyright: ignore[reportUnknownMemberType]
    splits: dict[str, datasets.SplitInfo] = builder.info.splits or {}  # pyright: ignore[reportUnknownMemberType]
    if split not in splits:
        raise ModelRetry(
            f'{split} is not valid for dataset {path}. Valid splits are {",".join(splits.keys())}'
        )

    builder.download_and_prepare()  # pyright: ignore[reportUnknownMemberType]
    dataset = builder.as_dataset(split=split)
    assert isinstance(dataset, datasets.Dataset)
    dataframe = dataset.to_pandas()
    assert isinstance(dataframe, pd.DataFrame)
    # end load data from hf

    # store the dataframe in the deps and get a ref like "Out[1]"
    ref = ctx.deps.store(dataframe)
    # construct a summary of the loaded dataset
    output = [
        f'Loaded the dataset as `{ref}`.',
        f'Description: {dataset.info.description}'
        if dataset.info.description
        else None,
        f'Features: {dataset.info.features!r}' if dataset.info.features else None,
    ]
    return '\n'.join(filter(None, output))


@analyst_agent.tool
def run_duckdb(ctx: RunContext[AnalystAgentDeps], dataset: str, sql: str) -> str:
    """Run DuckDB SQL query on the DataFrame.

    Note that the virtual table name used in DuckDB SQL must be `dataset`.

    Args:
        ctx: Pydantic AI agent RunContext
        dataset: reference string to the DataFrame
        sql: the query to be executed using DuckDB
    """
    data = ctx.deps.get(dataset)
    result = duckdb.query_df(df=data, virtual_table_name='dataset', sql_query=sql)
    # pass the result as ref (because DuckDB SQL can select many rows, creating another huge dataframe)
    ref = ctx.deps.store(result.df())  # pyright: ignore[reportUnknownMemberType]
    return f'Executed SQL, result is `{ref}`'


@analyst_agent.tool
def display(ctx: RunContext[AnalystAgentDeps], name: str) -> str:
    """Display at most 5 rows of the dataframe."""
    dataset = ctx.deps.get(name)
    return dataset.head().to_string()  # pyright: ignore[reportUnknownMemberType]


if __name__ == '__main__':
    deps = AnalystAgentDeps()
    result = analyst_agent.run_sync(
        user_prompt='Count how many negative comments are there in the dataset `cornell-movie-review-data/rotten_tomatoes`',
        deps=deps,
    )
    print(result.output)

Appendix

Choosing a Model

This example requires a model that understands DuckDB SQL. You can check this with clai:

> clai -m bedrock:us.anthropic.claude-3-7-sonnet-20250219-v1:0
clai - Pydantic AI CLI v0.0.1.dev920+41dd069 with bedrock:us.anthropic.claude-3-7-sonnet-20250219-v1:0
clai  do you understand duckdb sql?
# DuckDB SQL

Yes, I understand DuckDB SQL. DuckDB is an in-process analytical SQL database
that uses syntax similar to PostgreSQL. It specializes in analytical queries
and is designed for high-performance analysis of structured data.

Some key features of DuckDB SQL include:

  OLAP (Online Analytical Processing) optimized
  Columnar-vectorized query execution
  Standard SQL support with PostgreSQL compatibility
  Support for complex analytical queries
  Efficient handling of CSV/Parquet/JSON files

I can help you with DuckDB SQL queries, schema design, optimization, or other
DuckDB-related questions.