RAG
RAG 搜索示例。此演示允许您询问关于 logfire 文档的问题。
本示例演示了如何通过向代理注册搜索工具来实现 RAG(检索增强生成)。
这是通过创建一个包含 markdown 文档每个部分的数据库,然后向 PydanticAI 代理注册搜索工具来完成的。
用于从 markdown 文件和包含该数据的 JSON 文件中提取部分的逻辑,可在 此 gist 中找到。
PostgreSQL with pgvector 被用作搜索数据库,下载和运行 pgvector 最简单的方法是使用 Docker
mkdir postgres-data
docker run --rm \
-e POSTGRES_PASSWORD=postgres \
-p 54320:5432 \
-v `pwd`/postgres-data:/var/lib/postgresql/data \
pgvector/pgvector:pg17
与 SQL gen 示例一样,我们在端口 54320
上运行 postgres,以避免与您可能正在运行的任何其他 postgres 实例冲突。我们还在本地挂载 PostgreSQL data
目录,以便在您需要停止并重新启动容器时持久化数据。
在安装依赖项并设置环境变量后,我们可以使用以下命令构建搜索数据库(警告:这需要 OPENAI_API_KEY
环境变量,并将调用 OpenAI embedding API 约 300 次,以为文档的每个部分生成嵌入):
python -m pydantic_ai_examples.rag build
uv run -m pydantic_ai_examples.rag build
(注意:构建数据库目前不使用 PydanticAI,而是直接使用 OpenAI SDK。)
然后您可以使用以下命令向代理提问:
python -m pydantic_ai_examples.rag search "How do I configure logfire to work with FastAPI?"
uv run -m pydantic_ai_examples.rag search "How do I configure logfire to work with FastAPI?"
示例代码
rag.py
from __future__ import annotations as _annotations
import asyncio
import re
import sys
import unicodedata
from contextlib import asynccontextmanager
from dataclasses import dataclass
import asyncpg
import httpx
import logfire
import pydantic_core
from openai import AsyncOpenAI
from pydantic import TypeAdapter
from typing_extensions import AsyncGenerator
from pydantic_ai import RunContext
from pydantic_ai.agent import Agent
# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')
# Trace asyncpg calls so the DB queries below show up as spans in Logfire.
logfire.instrument_asyncpg()
@dataclass
class Deps:
    """Dependencies injected into the agent's tools via `RunContext[Deps]`."""

    openai: AsyncOpenAI  # client used to create embeddings for search queries
    pool: asyncpg.Pool  # connection pool to the pgvector search database
agent = Agent('openai:gpt-4o', deps_type=Deps, instrument=True)
@agent.tool
async def retrieve(context: RunContext[Deps], search_query: str) -> str:
    """Retrieve documentation sections based on a search query.

    Args:
        context: The call context.
        search_query: The search query.

    Returns:
        The top matching sections formatted as markdown, separated by blank lines.
    """
    with logfire.span(
        'create embedding for {search_query=}', search_query=search_query
    ):
        response = await context.deps.openai.embeddings.create(
            input=search_query,
            model='text-embedding-3-small',
        )

    assert len(response.data) == 1, (
        f'Expected 1 embedding, got {len(response.data)}, doc query: {search_query!r}'
    )
    # pgvector accepts the vector as a JSON array literal.
    vector_json = pydantic_core.to_json(response.data[0].embedding).decode()
    # Nearest-neighbour search by L2 distance (`<->`) on the embedding column.
    rows = await context.deps.pool.fetch(
        'SELECT url, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8',
        vector_json,
    )
    sections = [
        f'# {row["title"]}\nDocumentation URL:{row["url"]}\n\n{row["content"]}\n'
        for row in rows
    ]
    return '\n\n'.join(sections)
async def run_agent(question: str):
    """Entry point to run the agent and perform RAG based question answering."""
    client = AsyncOpenAI()
    logfire.instrument_openai(client)

    logfire.info('Asking "{question}"', question=question)

    async with database_connect(False) as pool:
        result = await agent.run(question, deps=Deps(openai=client, pool=pool))
        print(result.data)
#######################################################
# The rest of this file is dedicated to preparing the #
# search database, and some utilities. #
#######################################################
# JSON document from
# https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992
# Pinned to a specific gist revision so the build is reproducible.
DOCS_JSON = (
    'https://gist.githubusercontent.com/'
    'samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992/raw/'
    '80c5925c42f1442c24963aaf5eb1a324d47afe95/logfire_docs.json'
)
async def build_search_db():
    """Build the search database.

    Downloads the docs sections JSON, (re)creates the schema, then embeds and
    inserts each section concurrently (at most 10 embedding calls in flight).
    """
    async with httpx.AsyncClient() as http:
        resp = await http.get(DOCS_JSON)
        resp.raise_for_status()
        sections = sessions_ta.validate_json(resp.content)

    ai_client = AsyncOpenAI()
    logfire.instrument_openai(ai_client)

    async with database_connect(True) as pool:
        with logfire.span('create schema'):
            async with pool.acquire() as conn:
                async with conn.transaction():
                    await conn.execute(DB_SCHEMA)

        # Cap concurrent OpenAI embedding requests at 10.
        limiter = asyncio.Semaphore(10)
        async with asyncio.TaskGroup() as tg:
            for section in sections:
                tg.create_task(insert_doc_section(limiter, ai_client, pool, section))
async def insert_doc_section(
    sem: asyncio.Semaphore,
    openai: AsyncOpenAI,
    pool: asyncpg.Pool,
    section: DocsSection,
) -> None:
    """Embed one docs section and insert it into `doc_sections`.

    Skips sections whose URL is already present, so re-running the build is
    idempotent and doesn't repeat embedding calls.
    """
    async with sem:
        url = section.url()
        if await pool.fetchval('SELECT 1 FROM doc_sections WHERE url = $1', url):
            logfire.info('Skipping {url=}', url=url)
            return

        with logfire.span('create embedding for {url=}', url=url):
            response = await openai.embeddings.create(
                input=section.embedding_content(),
                model='text-embedding-3-small',
            )
        assert len(response.data) == 1, (
            f'Expected 1 embedding, got {len(response.data)}, doc section: {section}'
        )
        vector_json = pydantic_core.to_json(response.data[0].embedding).decode()
        await pool.execute(
            'INSERT INTO doc_sections (url, title, content, embedding) VALUES ($1, $2, $3, $4)',
            url,
            section.title,
            section.content,
            vector_json,
        )
@dataclass
class DocsSection:
    """One section of the markdown docs, as stored in the search database."""

    id: int
    parent: int | None
    path: str
    level: int
    title: str
    content: str

    def url(self) -> str:
        """Return the docs URL for this section, including its heading anchor."""
        stem = re.sub(r'\.md$', '', self.path)
        anchor = slugify(self.title, '-')
        return f'https://logfire.pydantic.dev/docs/{stem}/#{anchor}'

    def embedding_content(self) -> str:
        """Return the text sent to the embedding model: path, title, then body."""
        parts = (f'path: {self.path}', f'title: {self.title}', self.content)
        return '\n\n'.join(parts)
sessions_ta = TypeAdapter(list[DocsSection])
# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
@asynccontextmanager
async def database_connect(
    create_db: bool = False,
) -> AsyncGenerator[asyncpg.Pool, None]:
    """Yield a connection pool to the RAG database.

    Args:
        create_db: If True, create the `pydantic_ai_rag` database first when it
            doesn't already exist.
    """
    server_dsn = 'postgresql://postgres:postgres@localhost:54320'
    database = 'pydantic_ai_rag'
    if create_db:
        with logfire.span('check and create DB'):
            admin_conn = await asyncpg.connect(server_dsn)
            try:
                exists = await admin_conn.fetchval(
                    'SELECT 1 FROM pg_database WHERE datname = $1', database
                )
                if not exists:
                    # Safe to interpolate: `database` is a hard-coded constant.
                    await admin_conn.execute(f'CREATE DATABASE {database}')
            finally:
                await admin_conn.close()

    pool = await asyncpg.create_pool(f'{server_dsn}/{database}')
    try:
        yield pool
    finally:
        await pool.close()
DB_SCHEMA = """
CREATE EXTENSION IF NOT EXISTS vector;
CREATE TABLE IF NOT EXISTS doc_sections (
id serial PRIMARY KEY,
url text NOT NULL UNIQUE,
title text NOT NULL,
content text NOT NULL,
-- text-embedding-3-small returns a vector of 1536 floats
embedding vector(1536) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_doc_sections_embedding ON doc_sections USING hnsw (embedding vector_l2_ops);
"""
def slugify(value: str, separator: str, unicode: bool = False) -> str:
    """Slugify a string, to make it URL friendly.

    Args:
        value: The string to slugify.
        separator: String used to join words in the resulting slug.
        unicode: If False (the default), transliterate extended Latin
            characters to ASCII before slugifying.

    Adapted from https://github.com/Python-Markdown/markdown/blob/3.7/markdown/extensions/toc.py#L38
    with the separator escaped before interpolation into the regex.
    """
    if not unicode:
        # Replace Extended Latin characters with ASCII, i.e. `žlutý` => `zluty`
        value = unicodedata.normalize('NFKD', value)
        value = value.encode('ascii', 'ignore').decode('ascii')
    value = re.sub(r'[^\w\s-]', '', value).strip().lower()
    # re.escape guards against separators with regex meaning inside a character
    # class — e.g. an unescaped '^' would negate the class and mangle the slug.
    return re.sub(rf'[{re.escape(separator)}\s]+', separator, value)
if __name__ == '__main__':
action = sys.argv[1] if len(sys.argv) > 1 else None
if action == 'build':
asyncio.run(build_search_db())
elif action == 'search':
if len(sys.argv) == 3:
q = sys.argv[2]
else:
q = 'How do I configure logfire to work with FastAPI?'
asyncio.run(run_agent(q))
else:
print(
'uv run --extra examples -m pydantic_ai_examples.rag build|search',
file=sys.stderr,
)
sys.exit(1)