Building a RAG System
This chapter walks you through building a complete, production-grade RAG system.
System Architecture
graph TB
A[User Interface] --> B[API Layer]
B --> C[Business Logic Layer]
subgraph "RAG Core"
C --> D[Query Understanding]
D --> E[Retrieval Module]
C --> F[Document Management]
F --> G[Vector Database]
E --> G
G --> H[Reranking]
H --> I[LLM Generation]
end
I --> J[Post-processing]
J --> K[Return Results]
L[Document Sources] --> F
style A fill:#e1f5ff
style K fill:#c8e6c9
Project Structure
advanced-rag/
├── .env                    # Environment variables
├── requirements.txt        # Dependencies
├── config.py               # Configuration
├── main.py                 # Main entry point
├── src/
│   ├── document/
│   │   ├── loader.py       # Document loading
│   │   ├── splitter.py     # Text splitting
│   │   └── processor.py    # Document preprocessing
│   ├── vector/
│   │   ├── embeddings.py   # Embedding model
│   │   └── store.py        # Vector store
│   ├── retrieval/
│   │   ├── retriever.py    # Retriever
│   │   └── reranker.py     # Reranking
│   ├── generation/
│   │   ├── prompts.py      # Prompt templates
│   │   ├── rag_chain.py    # RAG chain
│   │   └── llm.py          # LLM integration
│   └── api/
│       └── server.py       # API server
└── data/
    ├── documents/          # Raw documents
    └── chroma_db/          # Vector database
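The tree references requirements.txt and a .env file, but the chapter never lists their contents. Here is a plausible requirements.txt covering the code in this chapter; the package names are the standard PyPI ones, and versions are left unpinned deliberately (pin them for production):

```text
langchain
langchain-core
langchain-community
langchain-openai
chromadb
tiktoken
pypdf
docx2txt
unstructured
markdown
beautifulsoup4
fastapi
uvicorn
pydantic
python-dotenv
sentence-transformers  # only needed for the reranker sketched later
```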
Configuration File
Create config.py:
import os
from dataclasses import dataclass
from typing import Optional

from dotenv import load_dotenv

# Load variables from .env (see the project tree) into the environment
# before the defaults below are evaluated
load_dotenv()


@dataclass
class Config:
    """RAG system configuration"""
    # OpenAI settings
    openai_api_key: str = os.getenv("OPENAI_API_KEY", "")
    openai_base_url: Optional[str] = None
    # Embedding model settings
    embedding_model: str = "text-embedding-3-small"
    embedding_dim: int = 1536
    # LLM settings
    llm_model: str = "gpt-4o-mini"
    temperature: float = 0.7
    max_tokens: int = 1000
    # Vector database settings
    vector_db_type: str = "chroma"  # chroma, faiss, pinecone
    chroma_persist_dir: str = "./data/chroma_db"
    collection_name: str = "rag_collection"
    # Retrieval settings
    top_k: int = 5
    search_kwargs: Optional[dict] = None
    # Document processing settings
    chunk_size: int = 1000
    chunk_overlap: int = 200
    # API settings
    api_host: str = "0.0.0.0"
    api_port: int = 8000

    def __post_init__(self):
        if self.search_kwargs is None:
            self.search_kwargs = {"k": self.top_k}


# Global configuration instance
config = Config()
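A matching .env file; the key name is the one config.py reads, and the value is a placeholder:

```bash
# .env — read by load_dotenv() at the top of config.py
OPENAI_API_KEY=sk-...
```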
Document Loading Module
Create src/document/loader.py:
from pathlib import Path
from typing import List, Optional

from langchain_community.document_loaders import (
    DirectoryLoader,
    Docx2txtLoader,
    PyPDFLoader,
    TextLoader,
    UnstructuredMarkdownLoader,
)
from langchain_core.documents import Document


class DocumentLoader:
    """Document loader"""

    def __init__(self):
        self.loader_map = {
            ".txt": TextLoader,
            ".md": UnstructuredMarkdownLoader,
            ".pdf": PyPDFLoader,
            ".docx": Docx2txtLoader,
        }

    def load_file(self, filepath: str) -> List[Document]:
        """
        Load a single file.

        Args:
            filepath: Path to the file

        Returns:
            List of documents
        """
        ext = Path(filepath).suffix.lower()
        if ext not in self.loader_map:
            raise ValueError(f"Unsupported file type: {ext}")
        loader = self.loader_map[ext](filepath)
        return loader.load()

    def load_directory(
        self,
        dirpath: str,
        glob: str = "**/*.*",
        exclude: Optional[List[str]] = None
    ) -> List[Document]:
        """
        Load a directory.

        Args:
            dirpath: Path to the directory
            glob: File matching pattern
            exclude: Directories to exclude

        Returns:
            List of documents
        """
        if exclude is None:
            exclude = [".git", "__pycache__", ".venv"]
        # DirectoryLoader's exclude takes glob patterns, so turn bare
        # directory names into patterns matching anywhere in the tree
        exclude = [f"**/{e}/**" for e in exclude]
        loader = DirectoryLoader(
            dirpath,
            glob=glob,
            exclude=exclude,
            use_multithreading=True
        )
        return loader.load()

    def load_from_urls(self, urls: List[str]) -> List[Document]:
        """
        Load documents from URLs.

        Args:
            urls: List of URLs

        Returns:
            List of documents
        """
        from langchain_community.document_loaders import WebBaseLoader

        loader = WebBaseLoader(urls)
        return loader.load()


# Usage example
if __name__ == "__main__":
    loader = DocumentLoader()

    # Load a single file
    docs = loader.load_file("example.txt")
    print(f"Loaded {len(docs)} documents")

    # Load a directory
    docs = loader.load_directory("./data/documents/")
    print(f"Loaded {len(docs)} documents from the directory")
Text Splitting Module
Create src/document/splitter.py:
from typing import List, Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document


class TextSplitter:
    """Text splitter"""

    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        separators: Optional[List[str]] = None
    ):
        """
        Initialize the splitter.

        Args:
            chunk_size: Chunk size
            chunk_overlap: Overlap size
            separators: List of separators
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        if separators is None:
            # Default separators (Chinese and English)
            separators = [
                "\n\n", "\n",
                "。", "!", "?", ";",
                ".", "!", "?", ";",
                ",", ",",
                " ", ""
            ]
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=separators
        )

    def split(self, documents: List[Document]) -> List[Document]:
        """
        Split documents.

        Args:
            documents: List of source documents

        Returns:
            List of split chunks
        """
        chunks = self.splitter.split_documents(documents)
        # Attach metadata
        for i, chunk in enumerate(chunks):
            chunk.metadata["chunk_id"] = i
            chunk.metadata["chunk_size"] = len(chunk.page_content)
        return chunks

    def split_by_tokens(
        self,
        documents: List[Document],
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        encoding_name: str = "cl100k_base"
    ) -> List[Document]:
        """
        Split by token count.

        Args:
            documents: Source documents
            chunk_size: Chunk size (tokens)
            chunk_overlap: Overlap size (tokens)
            encoding_name: Name of the tiktoken encoding

        Returns:
            List of split chunks
        """
        from langchain.text_splitter import TokenTextSplitter

        token_splitter = TokenTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            encoding_name=encoding_name
        )
        return token_splitter.split_documents(documents)


# Usage example
if __name__ == "__main__":
    from src.document.loader import DocumentLoader

    loader = DocumentLoader()
    docs = loader.load_directory("./data/documents/")

    splitter = TextSplitter(chunk_size=800, chunk_overlap=100)
    chunks = splitter.split(docs)
    print(f"Split into {len(chunks)} chunks")
Vector Database Module
Create src/vector/store.py:
import shutil
from pathlib import Path
from typing import List, Optional

from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings

from config import config


class VectorStore:
    """Vector store"""

    def __init__(self, persist_directory: Optional[str] = None):
        """
        Initialize the vector store.

        Args:
            persist_directory: Persistence directory
        """
        self.embeddings = OpenAIEmbeddings(
            model=config.embedding_model,
            openai_api_key=config.openai_api_key
        )
        self.persist_directory = persist_directory or config.chroma_persist_dir
        # Chroma loads the existing collection if one is persisted here,
        # otherwise it creates an empty one; with a persist_directory set,
        # persistence is handled automatically
        self.db = Chroma(
            persist_directory=self.persist_directory,
            embedding_function=self.embeddings,
            collection_name=config.collection_name
        )

    def add_documents(self, documents: List[Document]):
        """
        Add documents.

        Args:
            documents: List of documents
        """
        if self.db is None:
            # Recreate the collection after clear()
            self.db = Chroma(
                persist_directory=self.persist_directory,
                embedding_function=self.embeddings,
                collection_name=config.collection_name
            )
        self.db.add_documents(documents)
        print(f"✅ Added {len(documents)} documents")

    def search(
        self,
        query: str,
        top_k: int = 5,
        filter: Optional[dict] = None
    ) -> List[Document]:
        """
        Similarity search.

        Args:
            query: Query text
            top_k: Number of results to return
            filter: Metadata filter

        Returns:
            List of relevant documents
        """
        if self.db is None:
            raise ValueError("Vector store is empty; add documents first")
        return self.db.similarity_search(
            query=query,
            k=top_k,
            filter=filter
        )

    def search_with_score(
        self,
        query: str,
        top_k: int = 5
    ) -> List[tuple]:
        """
        Similarity search with scores.

        Args:
            query: Query text
            top_k: Number of results to return

        Returns:
            List of (document, score) tuples
        """
        if self.db is None:
            raise ValueError("Vector store is empty; add documents first")
        return self.db.similarity_search_with_score(query, k=top_k)

    def delete(self, ids: Optional[List[str]] = None):
        """
        Delete documents.

        Args:
            ids: List of document IDs
        """
        if self.db is None or not ids:
            return
        self.db.delete(ids=ids)
        print(f"✅ Deleted {len(ids)} documents")

    def clear(self):
        """Clear the database"""
        if self.db is None:
            return
        # Remove the persistence directory
        if Path(self.persist_directory).exists():
            shutil.rmtree(self.persist_directory)
        self.db = None
        print("✅ Vector store cleared")

    def get_retriever(self, search_type: str = "similarity", search_kwargs: Optional[dict] = None):
        """
        Get a retriever.

        Args:
            search_type: Search type
            search_kwargs: Search parameters, e.g. {"k": 5}

        Returns:
            A retriever
        """
        if self.db is None:
            raise ValueError("Vector store is empty; add documents first")
        return self.db.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs or {"k": config.top_k}
        )


# Usage example
if __name__ == "__main__":
    from src.document.loader import DocumentLoader
    from src.document.splitter import TextSplitter

    # Load documents
    loader = DocumentLoader()
    docs = loader.load_directory("./data/documents/")

    # Split documents
    splitter = TextSplitter()
    chunks = splitter.split(docs)

    # Build the vector store
    vector_store = VectorStore()
    vector_store.add_documents(chunks)

    # Search
    results = vector_store.search("What is RAG?", top_k=3)
    for i, doc in enumerate(results, 1):
        print(f"\nResult {i}:")
        print(f"Content: {doc.page_content[:100]}...")
        print(f"Metadata: {doc.metadata}")
RAG Chain Module
Create src/generation/rag_chain.py:
from typing import List, Optional

from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import ChatOpenAI

from config import config


class RAGChain:
    """RAG chain"""

    def __init__(
        self,
        retriever,
        llm: Optional[ChatOpenAI] = None,
        prompt_template: Optional[str] = None
    ):
        """
        Initialize the RAG chain.

        Args:
            retriever: Retriever
            llm: LLM instance
            prompt_template: Prompt template
        """
        self.retriever = retriever

        # Initialize the LLM
        if llm is None:
            self.llm = ChatOpenAI(
                model=config.llm_model,
                temperature=config.temperature,
                max_tokens=config.max_tokens,
                openai_api_key=config.openai_api_key
            )
        else:
            self.llm = llm

        # Default prompt template
        if prompt_template is None:
            self.prompt_template = """
You are an intelligent question-answering assistant. Answer the user's question based on the reference information below.

Reference information:
{context}

User question: {question}

Answer requirements:
1. Be accurate, concise, and helpful
2. If the reference information is insufficient, say so explicitly
3. Cite the relevant reference information

Answer:
"""
        else:
            self.prompt_template = prompt_template

        # Build the prompt
        self.prompt = ChatPromptTemplate.from_template(self.prompt_template)

        # Build the chain
        self.chain = self._build_chain()

    def _format_docs(self, docs: List[Document]) -> str:
        """Format documents"""
        formatted = []
        for i, doc in enumerate(docs, 1):
            content = doc.page_content
            source = doc.metadata.get('source', 'unknown')
            formatted.append(f"[Reference {i}, from {source}]\n{content}")
        return "\n\n".join(formatted)

    def _build_chain(self):
        """Build the RAG chain"""
        # The retriever outputs a list of documents; _format_docs
        # turns that list into a single context string
        chain = (
            {
                "context": self.retriever | RunnableLambda(self._format_docs),
                "question": RunnablePassthrough()
            }
            | self.prompt
            | self.llm
        )
        return chain

    def invoke(self, query: str, **kwargs) -> str:
        """
        Invoke the RAG chain.

        Args:
            query: User query
            **kwargs: Extra arguments

        Returns:
            The generated answer
        """
        result = self.chain.invoke(query, **kwargs)
        return result.content

    async def ainvoke(self, query: str, **kwargs) -> str:
        """
        Invoke the RAG chain asynchronously.

        Args:
            query: User query
            **kwargs: Extra arguments

        Returns:
            The generated answer
        """
        result = await self.chain.ainvoke(query, **kwargs)
        return result.content

    def stream(self, query: str):
        """
        Stream the output.

        Args:
            query: User query

        Yields:
            Chunks of generated text
        """
        for chunk in self.chain.stream(query):
            if hasattr(chunk, 'content'):
                yield chunk.content

    def get_sources(self, query: str, top_k: int = 3) -> List[Document]:
        """
        Get the reference sources.

        Args:
            query: User query
            top_k: Number of results to return

        Returns:
            List of reference documents
        """
        docs = self.retriever.invoke(query)
        return docs[:top_k]


# Usage example
if __name__ == "__main__":
    from src.vector.store import VectorStore

    # Create the vector store
    vector_store = VectorStore()

    # Create a retriever
    retriever = vector_store.get_retriever(search_kwargs={"k": 3})

    # Create the RAG chain
    rag = RAGChain(retriever)

    # Query
    query = "What are the advantages of a RAG system?"
    answer = rag.invoke(query)
    print(f"Question: {query}")
    print(f"Answer: {answer}")

    # Get the sources
    sources = rag.get_sources(query)
    print(f"\nReference sources ({len(sources)}):")
    for i, doc in enumerate(sources, 1):
        print(f"{i}. {doc.page_content[:80]}...")
Main Program
Create main.py:
import argparse
import sys
from pathlib import Path

# Add the project root to sys.path
sys.path.insert(0, str(Path(__file__).parent))

from config import config
from src.document.loader import DocumentLoader
from src.document.splitter import TextSplitter
from src.generation.rag_chain import RAGChain
from src.vector.store import VectorStore


class RAGSystem:
    """RAG system"""

    def __init__(self):
        self.loader = DocumentLoader()
        self.splitter = TextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap
        )
        self.vector_store = VectorStore()

    def build_index(self, docs_path: str):
        """
        Build the index.

        Args:
            docs_path: Path to the documents
        """
        print(f"📂 Loading documents from {docs_path}...")
        # Load documents
        docs = self.loader.load_directory(docs_path)
        print(f"✅ Loaded {len(docs)} documents")

        # Split documents
        print("✂️ Splitting documents...")
        chunks = self.splitter.split(docs)
        print(f"✅ Split into {len(chunks)} chunks")

        # Build the vector index
        print("🔍 Building the vector index...")
        self.vector_store.add_documents(chunks)
        print("✅ Index built!")

    def query(self, question: str) -> dict:
        """
        Query.

        Args:
            question: The question

        Returns:
            A dict containing the answer and sources
        """
        # Create the retriever and chain
        retriever = self.vector_store.get_retriever(
            search_kwargs={"k": config.top_k}
        )
        rag_chain = RAGChain(retriever)

        # Get the answer
        print(f"💭 Thinking about: {question}")
        answer = rag_chain.invoke(question)

        # Get the sources
        sources = rag_chain.get_sources(question)

        return {
            "answer": answer,
            "sources": sources
        }


def main():
    parser = argparse.ArgumentParser(description="Advanced RAG system")
    parser.add_argument("--build", type=str, help="Build the index from the given documents path")
    parser.add_argument("--query", type=str, help="Question to ask")
    args = parser.parse_args()

    rag = RAGSystem()

    if args.build:
        # Build the index
        rag.build_index(args.build)
    elif args.query:
        # Single query
        result = rag.query(args.query)
        print("\n" + "=" * 50)
        print("📝 Answer:")
        print(result["answer"])
        print("\n📚 Sources:")
        for i, doc in enumerate(result["sources"], 1):
            print(f"\n{i}. {doc.metadata.get('source', 'unknown')}")
            print(f"   {doc.page_content[:100]}...")
        print("=" * 50)
    else:
        # Interactive mode
        print("🤖 Advanced RAG system (type 'quit' to exit)\n")
        while True:
            question = input("❓ Your question: ").strip()
            if question.lower() in ['quit', 'exit', 'q']:
                break
            if not question:
                continue
            result = rag.query(question)
            print(f"\n📝 Answer: {result['answer']}")
            print(f"\n📚 Sources: {len(result['sources'])} documents\n")


if __name__ == "__main__":
    main()
Usage
Building the index
# Build the index
python main.py --build ./data/documents/
Querying
# Single query
python main.py --query "What is RAG?"
# Interactive mode
python main.py
API Service
Create src/api/server.py:
import sys
from pathlib import Path
from typing import List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from config import config
from src.generation.rag_chain import RAGChain
from src.vector.store import VectorStore

app = FastAPI(title="RAG API")

# Initialization
vector_store = VectorStore()
retriever = vector_store.get_retriever(search_kwargs={"k": config.top_k})
rag_chain = RAGChain(retriever)


class QueryRequest(BaseModel):
    query: str


class QueryResponse(BaseModel):
    answer: str
    sources: List[dict]


@app.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
    """Query endpoint"""
    try:
        answer = rag_chain.invoke(request.query)
        sources = rag_chain.get_sources(request.query)
        return QueryResponse(
            answer=answer,
            sources=[
                {
                    "content": doc.page_content[:200],
                    "metadata": doc.metadata
                }
                for doc in sources
            ]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health():
    """Health check"""
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host=config.api_host,
        port=config.api_port
    )
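The server above only exposes /query and /health; the document-management path in the architecture diagram suggests an ingestion endpoint as well. A sketch of one possible addition (IngestRequest and the endpoint are my assumptions, not part of the chapter's API):

```python
from langchain_core.documents import Document

from src.document.splitter import TextSplitter


class IngestRequest(BaseModel):
    texts: List[str]


@app.post("/ingest")
async def ingest(request: IngestRequest):
    """Add raw texts to the vector store (hypothetical endpoint)."""
    docs = [Document(page_content=t) for t in request.texts]
    chunks = TextSplitter(config.chunk_size, config.chunk_overlap).split(docs)
    vector_store.add_documents(chunks)
    return {"ingested_chunks": len(chunks)}
```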
Start the API server:
uvicorn src.api.server:app --reload
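Then exercise the /query endpoint, for example with curl (port 8000 matches the default api_port in config.py):

```bash
curl -X POST http://localhost:8000/query \
  -H "Content-Type: application/json" \
  -d '{"query": "What is RAG?"}'
```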
Key Takeaways
✅ Modular design that is easy to maintain and extend
✅ Support for multiple document formats and data sources
✅ Flexible text-splitting strategies
✅ An extensible vector-store interface
✅ Synchronous and asynchronous invocation
✅ A RESTful API interface

Next: learn RAG optimization techniques ⚡