
RAG Optimization Techniques

Improve the performance and effectiveness of a RAG system.

Optimization Overview

mindmap
  root((RAG Optimization))
    Data optimization
      Document cleaning
      Metadata extraction
      Chunking strategy
    Retrieval optimization
      Query rewriting
      Hybrid retrieval
      Reranking
      Compressed retrieval
    Generation optimization
      Prompt engineering
      Context management
      Answer validation
    Performance optimization
      Caching
      Batching
      Parallel retrieval

Data Optimization

1. Document Preprocessing

import re
from pathlib import Path
from typing import List

from langchain.schema import Document


class DocumentPreprocessor:
    """Document preprocessor."""

    def clean_text(self, text: str) -> str:
        """
        Clean text.

        Args:
            text: Raw text
        Returns:
            Cleaned text
        """
        # Collapse excess whitespace
        text = re.sub(r'\n{3,}', '\n\n', text)
        text = re.sub(r' +', ' ', text)
        # Remove special characters (keep word characters, CJK, and common punctuation)
        text = re.sub(r'[^\w\s\u4e00-\u9fff.,!?;:()"\'-]', '', text)
        return text.strip()

    def extract_metadata(self, text: str, filepath: str = None) -> dict:
        """
        Extract metadata.

        Args:
            text: Text content
            filepath: File path
        Returns:
            Metadata dict
        """
        metadata = {}
        # File path
        if filepath:
            metadata["source"] = filepath
            metadata["filename"] = Path(filepath).name
        # Word count
        metadata["word_count"] = len(text.split())
        # Extract the title (assume the first line is the title)
        lines = text.split('\n')
        if lines:
            metadata["title"] = lines[0].strip('#').strip()
        # Language detection
        try:
            from langdetect import detect
            metadata["language"] = detect(text)
        except Exception:
            metadata["language"] = "unknown"
        return metadata

    def process_documents(self, documents: List[Document]) -> List[Document]:
        """
        Process a list of documents.

        Args:
            documents: Raw documents
        Returns:
            Processed documents
        """
        processed = []
        for doc in documents:
            # Clean the text
            cleaned_text = self.clean_text(doc.page_content)
            # Extract / enrich metadata
            metadata = doc.metadata.copy()
            filepath = metadata.get('source')
            new_metadata = self.extract_metadata(cleaned_text, filepath)
            metadata.update(new_metadata)
            # Build the new document
            processed_doc = Document(
                page_content=cleaned_text,
                metadata=metadata
            )
            processed.append(processed_doc)
        return processed

# Usage
preprocessor = DocumentPreprocessor()
cleaned_docs = preprocessor.process_documents(raw_docs)
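One gotcha in the whitelist: clean_text keeps word characters, whitespace, the CJK ideograph range \u4e00-\u9fff, and a handful of ASCII punctuation marks, but full-width CJK punctuation (,。!?) matches none of those classes and gets stripped. If a later step splits sentences on 。!? (as split_by_semantic below does), widen the character class accordingly.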

2. Smart Chunking Strategies

import re
from typing import List

from langchain.schema import Document


class SmartSplitter:
    """Smart splitter."""

    def split_by_structure(self, documents: List[Document]) -> List[Document]:
        """
        Split by document structure (e.g. Markdown headings).

        Args:
            documents: Raw documents
        Returns:
            Split documents
        """
        chunks = []
        for doc in documents:
            text = doc.page_content
            metadata = doc.metadata
            # Split on Markdown headings
            sections = re.split(r'\n(?=#+\s)', text)
            for i, section in enumerate(sections):
                section = section.strip()
                if not section:
                    continue
                # Extract the section title
                lines = section.split('\n')
                title = lines[0].strip('#').strip() if lines else f"Section {i}"
                # Copy metadata per chunk
                chunk_metadata = metadata.copy()
                chunk_metadata["section_title"] = title
                chunk_metadata["section_id"] = i
                chunk = Document(
                    page_content=section,
                    metadata=chunk_metadata
                )
                chunks.append(chunk)
        return chunks

    def split_by_semantic(self, documents: List[Document]) -> List[Document]:
        """
        Split by semantic similarity.

        Args:
            documents: Raw documents
        Returns:
            Split documents
        """
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity

        chunks = []
        for doc in documents:
            text = doc.page_content
            sentences = re.split(r'(?<=[.!?。!?])\s+', text)
            if len(sentences) < 2:
                chunks.append(doc)
                continue
            # Compute similarity between adjacent sentences
            vectorizer = TfidfVectorizer()
            tfidf = vectorizer.fit_transform(sentences)
            similarities = cosine_similarity(tfidf[:-1], tfidf[1:])
            # Low-similarity points mark topic shifts, i.e. chunk boundaries
            boundaries = [0]
            for i, sim in enumerate(similarities.diagonal()):
                if sim < 0.3:  # similarity threshold
                    boundaries.append(i + 1)
            boundaries.append(len(sentences))
            # Split at the boundaries
            for i in range(len(boundaries) - 1):
                start = boundaries[i]
                end = boundaries[i + 1]
                chunk_text = ' '.join(sentences[start:end])
                chunk = Document(
                    page_content=chunk_text,
                    metadata=doc.metadata.copy()
                )
                chunks.append(chunk)
        return chunks

# Usage
smart_splitter = SmartSplitter()
# Structure-based split
struct_chunks = smart_splitter.split_by_structure(docs)
# Semantic split
semantic_chunks = smart_splitter.split_by_semantic(docs)
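Treat the 0.3 threshold as a starting assumption rather than a constant: TF-IDF vectors over single sentences are extremely sparse, so adjacent sentences with no shared vocabulary score near zero even when they are topically related, which over-segments. Substituting dense sentence embeddings for TfidfVectorizer makes the boundary detection noticeably more stable, at the cost of an embedding call per sentence.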

Retrieval Optimization

1. Query Rewriting

from typing import List

from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate


class QueryRewriter:
    """Query rewriter."""

    def __init__(self, llm: ChatOpenAI = None):
        """
        Initialize the query rewriter.

        Args:
            llm: LLM instance
        """
        self.llm = llm or ChatOpenAI(model="gpt-4o-mini")
        self.prompt = ChatPromptTemplate.from_template("""
You are a query-rewriting expert. Rewrite the user query into a form better suited for retrieval.

Original query: {query}

Rewriting requirements:
1. Make the core intent explicit
2. Fill in implied information
3. Use standard terminology
4. Stay concise

Rewritten query:
""")

    def rewrite(self, query: str) -> str:
        """
        Rewrite a query.

        Args:
            query: Original query
        Returns:
            Rewritten query
        """
        result = self.llm.invoke(
            self.prompt.format(query=query)
        )
        return result.content.strip()

    def expand(self, query: str, num_variations: int = 3) -> List[str]:
        """
        Expand a query (generate several variants).

        Args:
            query: Original query
            num_variations: Number of variants
        Returns:
            List of query variants
        """
        prompt = ChatPromptTemplate.from_template(f"""
Generate {num_variations} query variants, each expressing the same intent from a different angle.

Original query: {{query}}

Query variants (one per line):
""")
        result = self.llm.invoke(prompt.format(query=query))
        variations = [v.strip() for v in result.content.strip().split('\n') if v.strip()]
        # Include the original query
        return [query] + variations[:num_variations]

# Usage
rewriter = QueryRewriter()
original_query = "How do I handle dates in Python?"
rewritten = rewriter.rewrite(original_query)
print(f"Rewritten: {rewritten}")
expanded = rewriter.expand(original_query)
print(f"Expanded queries: {expanded}")

2. Hybrid Retrieval

from typing import List

from langchain.schema import Document
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever


class HybridRetriever:
    """Hybrid retriever."""

    def __init__(self, vector_store, top_k: int = 5):
        """
        Initialize the hybrid retriever.

        Args:
            vector_store: Vector store wrapper
            top_k: Number of results to return
        """
        self.top_k = top_k
        # Vector retrieval
        self.vector_retriever = vector_store.get_retriever(
            search_kwargs={"k": top_k}
        )
        # BM25 keyword retrieval (from_texts, since the underlying Chroma
        # collection returns raw strings, not Document objects)
        self.bm25_retriever = BM25Retriever.from_texts(
            vector_store.db._collection.get()['documents']
        )
        self.bm25_retriever.k = top_k
        # Ensemble retrieval (blend both methods)
        self.ensemble_retriever = EnsembleRetriever(
            retrievers=[self.vector_retriever, self.bm25_retriever],
            weights=[0.7, 0.3]  # 70% vector, 30% BM25
        )

    def search(self, query: str, method: str = "hybrid") -> List[Document]:
        """
        Search.

        Args:
            query: Query
            method: Retrieval method (vector, bm25, hybrid)
        Returns:
            Relevant documents
        """
        if method == "vector":
            return self.vector_retriever.invoke(query)
        elif method == "bm25":
            return self.bm25_retriever.invoke(query)
        elif method == "hybrid":
            return self.ensemble_retriever.invoke(query)
        else:
            raise ValueError(f"Unknown method: {method}")

# Usage
hybrid_retriever = HybridRetriever(vector_store)
results = hybrid_retriever.search("Python date handling", method="hybrid")
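On the weights: EnsembleRetriever fuses the two ranked lists with weighted Reciprocal Rank Fusion, so 0.7/0.3 biases toward semantic matches while still letting exact keyword hits (error codes, product names) surface. Treat the split as a tunable to validate on a held-out query set, not a fixed recipe.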

3. Reranking

from typing import List

from langchain.schema import Document
from langchain_openai import ChatOpenAI


class Reranker:
    """Reranker."""

    def __init__(self, llm: ChatOpenAI = None):
        """
        Initialize the reranker.

        Args:
            llm: LLM instance
        """
        self.llm = llm or ChatOpenAI(model="gpt-4o-mini")

    def rerank(
        self,
        documents: List[Document],
        query: str,
        top_k: int = 3
    ) -> List[Document]:
        """
        Rerank documents.

        Args:
            documents: Document list
            query: Query
            top_k: Number of documents to return
        Returns:
            Reranked documents
        """
        if len(documents) <= top_k:
            return documents
        # Build the prompt
        docs_text = "\n\n".join([
            f"[Document {i+1}]\n{doc.page_content}"
            for i, doc in enumerate(documents)
        ])
        prompt = f"""
Rerank the following documents by relevance to the query.

Query: {query}

Documents:
{docs_text}

Output the numbers of the {top_k} most relevant documents, highest first (comma-separated):
"""
        # Call the LLM
        result = self.llm.invoke(prompt)
        ranks = result.content.strip()
        # Parse the result
        try:
            indices = [int(x.strip()) - 1 for x in ranks.split(',')]
            reranked = [documents[i] for i in indices if 0 <= i < len(documents)]
            return reranked[:top_k]
        except ValueError:
            # Parsing failed: fall back to the original order
            return documents[:top_k]

# Usage
reranker = Reranker()
# Retrieve first
initial_results = vector_store.search(query, top_k=10)
# Then rerank
final_results = reranker.rerank(initial_results, query, top_k=3)
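Prompting an LLM to rerank is flexible but costs a full chat completion per query and can misparse. A dedicated cross-encoder is the usual cheaper, faster alternative; a sketch using the sentence-transformers library (the checkpoint name is one common public model, not something this post prescribes):

from sentence_transformers import CrossEncoder

# Score (query, document) pairs jointly; higher score = more relevant
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
pairs = [(query, doc.page_content) for doc in initial_results]
scores = model.predict(pairs)

# Keep the top 3 by score (key on the score only; Documents don't compare)
ranked = sorted(zip(scores, initial_results), key=lambda x: x[0], reverse=True)
final_results = [doc for _, doc in ranked[:3]]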

4. Compressed Retrieval (Context Compression)

from typing import List

from langchain.schema import Document
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_openai import ChatOpenAI


class CompressedRetriever:
    """Compressed retriever."""

    def __init__(self, base_retriever, llm: ChatOpenAI = None):
        """
        Initialize the compressed retriever.

        Args:
            base_retriever: Base retriever
            llm: LLM instance
        """
        self.llm = llm or ChatOpenAI(model="gpt-4o-mini")
        # Build the compressor
        self.compressor = LLMChainExtractor.from_llm(self.llm)
        # Wrap the base retriever with compression
        self.compressed_retriever = ContextualCompressionRetriever(
            base_compressor=self.compressor,
            base_retriever=base_retriever
        )

    def search(self, query: str) -> List[Document]:
        """
        Compressed retrieval. The number of results is governed by the
        base retriever's k, not by this method.

        Args:
            query: Query
        Returns:
            Compressed relevant documents
        """
        return self.compressed_retriever.invoke(query)

# Usage
base_retriever = vector_store.get_retriever(search_kwargs={"k": 5})
compressed_retriever = CompressedRetriever(base_retriever)
results = compressed_retriever.search("Python date handling")
# Each document keeps only the parts relevant to the query
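LLMChainExtractor issues one LLM call per retrieved document, which adds up at k=5 and beyond. When you only need to drop irrelevant chunks rather than trim inside them, LangChain's EmbeddingsFilter does a similar job with embeddings alone; a sketch (the 0.6 threshold is an assumed starting point to tune):

from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_openai import OpenAIEmbeddings

# Keep only chunks whose embedding similarity to the query clears a threshold;
# no LLM calls involved.
embeddings_filter = EmbeddingsFilter(
    embeddings=OpenAIEmbeddings(),
    similarity_threshold=0.6  # assumption: tune per corpus
)
cheap_compressed = ContextualCompressionRetriever(
    base_compressor=embeddings_filter,
    base_retriever=base_retriever
)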

Generation Optimization

1. Context Management

from typing import List

from langchain.schema import Document


class ContextManager:
    """Context manager."""

    def __init__(self, max_context_length: int = 3000):
        """
        Initialize the context manager.

        Args:
            max_context_length: Maximum context length (characters)
        """
        self.max_context_length = max_context_length

    def select_documents(
        self,
        documents: List[Document],
        query: str
    ) -> List[Document]:
        """
        Select the most relevant documents while capping context length.
        Assumes `documents` is already sorted by relevance.

        Args:
            documents: Document list
            query: Query
        Returns:
            Selected documents
        """
        selected = []
        total_length = 0
        for doc in documents:
            doc_length = len(doc.page_content)
            if total_length + doc_length > self.max_context_length:
                break
            selected.append(doc)
            total_length += doc_length
        return selected

    def format_context(
        self,
        documents: List[Document],
        max_chars_per_doc: int = 500
    ) -> str:
        """
        Format the context.

        Args:
            documents: Document list
            max_chars_per_doc: Maximum characters per document
        Returns:
            Formatted context
        """
        formatted = []
        for i, doc in enumerate(documents, 1):
            content = doc.page_content
            source = doc.metadata.get('source', 'unknown')
            # Truncate overly long documents
            if len(content) > max_chars_per_doc:
                content = content[:max_chars_per_doc] + "..."
            formatted.append(f"[Reference {i}, from {source}]\n{content}")
        return "\n\n".join(formatted)

# Usage
context_manager = ContextManager(max_context_length=3000)
selected_docs = context_manager.select_documents(retrieved_docs, query)
formatted_context = context_manager.format_context(selected_docs)
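Character counts are only a proxy for what actually matters: the model's token budget. A variant using tiktoken, assuming cl100k_base as a stand-in encoding for OpenAI chat models:

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

def select_by_tokens(documents, max_tokens: int = 2000):
    """Greedy selection against a token budget instead of characters."""
    selected, total = [], 0
    for doc in documents:
        n = len(enc.encode(doc.page_content))
        if total + n > max_tokens:
            break
        selected.append(doc)
        total += n
    return selected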

2. Answer Validation

import json

from langchain_openai import ChatOpenAI


class AnswerValidator:
    """Answer validator."""

    def __init__(self, llm: ChatOpenAI = None):
        """
        Initialize the validator.

        Args:
            llm: LLM instance
        """
        self.llm = llm or ChatOpenAI(model="gpt-4o-mini")

    def validate(
        self,
        answer: str,
        context: str,
        query: str
    ) -> dict:
        """
        Validate an answer.

        Args:
            answer: Generated answer
            context: Reference context
            query: Original query
        Returns:
            Validation result
        """
        prompt = f"""
Assess whether the following answer is accurate and useful.

Query: {query}

Reference information:
{context}

Answer:
{answer}

Evaluate:
1. Accuracy (is the answer grounded in the reference information?)
2. Completeness (does it fully answer the question?)
3. Usefulness (is the answer helpful?)

Output as JSON:
{{
    "accurate": true/false,
    "complete": true/false,
    "useful": true/false,
    "confidence": 0-1,
    "issues": ["issue 1", "issue 2"],
    "suggestion": "improvement suggestion"
}}
"""
        result = self.llm.invoke(prompt)
        try:
            return json.loads(result.content)
        except json.JSONDecodeError:
            # Fallback when the model strays from pure JSON; note that this
            # optimistically assumes the answer passed.
            return {
                "accurate": True,
                "complete": True,
                "useful": True,
                "confidence": 0.8,
                "issues": [],
                "suggestion": ""
            }

# Usage
validator = AnswerValidator()
validation = validator.validate(answer, context, query)
print(f"Validation result: {validation}")

Performance Optimization

1. Caching

import hashlib
import pickle
from typing import Any


class RAGCache:
    """RAG cache."""

    def __init__(self, cache_file: str = "./rag_cache.pkl"):
        """
        Initialize the cache.

        Args:
            cache_file: Cache file path
        """
        self.cache_file = cache_file
        self.cache = self._load_cache()

    def _load_cache(self) -> dict:
        """Load the cache from disk."""
        try:
            with open(self.cache_file, 'rb') as f:
                return pickle.load(f)
        except (FileNotFoundError, pickle.UnpicklingError):
            return {}

    def _save_cache(self):
        """Persist the cache to disk."""
        with open(self.cache_file, 'wb') as f:
            pickle.dump(self.cache, f)

    def get(self, key: str) -> Any:
        """Get a cached value."""
        return self.cache.get(key)

    def set(self, key: str, value: Any):
        """Set a cached value."""
        self.cache[key] = value
        self._save_cache()

    def _hash_query(self, query: str) -> str:
        """Hash a query."""
        return hashlib.md5(query.encode()).hexdigest()

    def get_cached_answer(self, query: str) -> Any:
        """Get a cached answer."""
        key = self._hash_query(query)
        return self.get(key)

    def cache_answer(self, query: str, answer: Any):
        """Cache an answer."""
        key = self._hash_query(query)
        self.set(key, answer)

# Usage
cache = RAGCache()
# Check the cache first
cached = cache.get_cached_answer("Python date handling")
if cached:
    print("Using cached answer")
    answer = cached
else:
    print("Generating a new answer")
    answer = rag.invoke("Python date handling")
    cache.cache_answer("Python date handling", answer)
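Hashing the raw string means trivially different phrasings miss the cache. Normalizing before hashing recovers some of those hits for free; a drop-in variant of _hash_query (semantic caches go further and key on query embeddings above a similarity threshold):

def _hash_query(self, query: str) -> str:
    # Normalize so trivial rephrasings ("  Python Date Handling ") share a key
    normalized = " ".join(query.lower().split())
    return hashlib.md5(normalized.encode()).hexdigest()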

2. Batching

from typing import List


def batch_query(
    rag_system,
    queries: List[str],
    batch_size: int = 5
) -> List[dict]:
    """
    Batch queries.

    Args:
        rag_system: RAG system
        queries: List of queries
        batch_size: Batch size
    Returns:
        List of answers
    """
    retriever = rag_system.vector_store.get_retriever()
    results = []
    for i in range(0, len(queries), batch_size):
        batch = queries[i:i + batch_size]
        # Batch retrieval
        all_docs = retriever.batch(batch)
        # Generate per query
        batch_results = []
        for query, docs in zip(batch, all_docs):
            answer = rag_system.llm.invoke(
                rag_system.prompt.format(
                    context="\n\n".join([d.page_content for d in docs]),
                    question=query
                )
            )
            batch_results.append({"query": query, "answer": answer.content})
        results.extend(batch_results)
    return results

# Usage
queries = [
    "Python date handling",
    "Python file operations",
    "Python network requests"
]
results = batch_query(rag_system, queries)
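The overview also lists parallel retrieval, which the loop above leaves on the table: generation still runs one query at a time. A hedged sketch that parallelizes whole query pipelines with a thread pool, assuming the rag.invoke-style interface from the caching example and a thread-safe client:

from concurrent.futures import ThreadPoolExecutor

def parallel_query(rag_system, queries, max_workers: int = 4):
    """Run independent queries concurrently; results keep input order."""
    def run_one(query):
        return {"query": query, "answer": rag_system.invoke(query)}

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(run_one, queries))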

Key Takeaways

✅ Data preprocessing improves retrieval quality
✅ Query rewriting and expansion raise recall
✅ Hybrid retrieval combines the strengths of multiple methods
✅ Reranking and compression sharpen result relevance
✅ Context management keeps input length under control
✅ Caching and batching boost performance


Next up: learn about Local LLM Deployment 💻