代码检索与辅助开发
代码库是一种特殊的知识库:它有语法结构、依赖关系、版本历史。Code RAG 帮助开发者在海量代码中快速找到相关实现、理解代码逻辑、生成新代码。
Code RAG 架构
graph TB
A[代码库] --> B[代码解析]
B --> C[函数/类粒度切分]
B --> D[文档字符串提取]
B --> E[依赖关系图]
C --> F[Code Embedding]
D --> F
G[开发者查询] --> H[意图识别]
H --> I{查询类型}
I -->|代码搜索| J[语义代码检索]
I -->|理解代码| K[上下文聚合 + 解释]
I -->|生成代码| L[检索示例 + 生成]
J --> M[结果 + 代码片段]
K --> M
L --> M
style G fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
style M fill:#c8e6c9,stroke:#388e3c,stroke-width:3px
代码解析与切分
"""
代码解析与智能切分
"""
import ast
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class CodeUnit:
"""代码单元"""
name: str
unit_type: str # function, class, method, module
code: str
docstring: str = ""
file_path: str = ""
line_start: int = 0
line_end: int = 0
dependencies: list[str] = field(default_factory=list)
signature: str = ""
class PythonCodeParser:
"""Python 代码解析器"""
def parse_file(self, file_path: str) -> list[CodeUnit]:
"""解析 Python 文件为代码单元"""
path = Path(file_path)
source = path.read_text(encoding="utf-8")
try:
tree = ast.parse(source)
except SyntaxError as e:
print(f" 解析失败 {file_path}: {e}")
return []
units = []
lines = source.split("\n")
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef):
unit = self._extract_function(node, lines, file_path)
units.append(unit)
elif isinstance(node, ast.ClassDef):
unit = self._extract_class(node, lines, file_path)
units.append(unit)
return units
def _extract_function(self, node: ast.FunctionDef, lines: list[str], file_path: str) -> CodeUnit:
"""提取函数"""
code = "\n".join(lines[node.lineno - 1: node.end_lineno])
docstring = ast.get_docstring(node) or ""
# 构建签名
args = [arg.arg for arg in node.args.args]
signature = f"def {node.name}({', '.join(args)})"
# 提取依赖(import 和调用)
deps = []
for child in ast.walk(node):
if isinstance(child, ast.Call) and isinstance(child.func, ast.Name):
deps.append(child.func.id)
return CodeUnit(
name=node.name,
unit_type="function",
code=code,
docstring=docstring,
file_path=file_path,
line_start=node.lineno,
line_end=node.end_lineno or node.lineno,
dependencies=list(set(deps)),
signature=signature,
)
def _extract_class(self, node: ast.ClassDef, lines: list[str], file_path: str) -> CodeUnit:
"""提取类"""
code = "\n".join(lines[node.lineno - 1: node.end_lineno])
docstring = ast.get_docstring(node) or ""
bases = [
getattr(base, "id", getattr(base, "attr", ""))
for base in node.bases
]
signature = f"class {node.name}({', '.join(bases)})" if bases else f"class {node.name}"
return CodeUnit(
name=node.name,
unit_type="class",
code=code,
docstring=docstring,
file_path=file_path,
line_start=node.lineno,
line_end=node.end_lineno or node.lineno,
signature=signature,
)
Code Embedding 策略
| 策略 | 方法 | 优势 | 适用场景 |
|---|---|---|---|
| 代码直接嵌入 | 对源代码做 Embedding | 保留语法信息 | 代码相似搜索 |
| 文档字符串嵌入 | 对 docstring 做 Embedding | 语义理解好 | 自然语言搜索代码 |
| 代码+文档混合 | 拼接代码和文档 | 平衡语法和语义 | 通用搜索 |
| 签名嵌入 | 对函数签名做 Embedding | 轻量快速 | API 检索 |
"""
代码 Embedding 管理
"""
from dataclasses import dataclass
@dataclass
class CodeEmbeddingConfig:
"""代码 Embedding 配置"""
include_code: bool = True
include_docstring: bool = True
include_signature: bool = True
max_code_length: int = 2000
class CodeEmbedder:
"""代码嵌入生成器"""
def __init__(self, embed_client, config: CodeEmbeddingConfig | None = None):
self.embedder = embed_client
self.config = config or CodeEmbeddingConfig()
def embed_unit(self, unit: CodeUnit) -> dict:
"""为代码单元生成嵌入"""
parts = []
if self.config.include_signature:
parts.append(f"签名:{unit.signature}")
if self.config.include_docstring and unit.docstring:
parts.append(f"文档:{unit.docstring}")
if self.config.include_code:
code = unit.code[:self.config.max_code_length]
parts.append(f"代码:{code}")
text = "\n".join(parts)
vector = self.embedder.embed(text)
return {
"id": f"{unit.file_path}::{unit.name}",
"vector": vector,
"text": text,
"metadata": {
"name": unit.name,
"type": unit.unit_type,
"file": unit.file_path,
"line_start": unit.line_start,
"line_end": unit.line_end,
},
}
RAG 辅助代码生成
"""
基于 RAG 的代码生成
"""
class CodeRAGGenerator:
"""RAG 辅助代码生成器"""
CODE_GEN_PROMPT = """根据以下代码库中的相关实现,生成满足需求的代码。
相关代码参考:
{reference_code}
代码风格和约定:
- 遵循项目现有的命名规范
- 使用相同的错误处理模式
- 保持一致的注释风格
需求:{requirement}
生成的代码:"""
def __init__(self, code_retriever, llm_client):
self.retriever = code_retriever
self.llm = llm_client
def generate(self, requirement: str) -> dict:
"""检索相关代码 + 生成新代码"""
# 检索相关代码单元
related_units = self.retriever.search(requirement, top_k=5)
# 构建参考代码上下文
reference_parts = []
for unit in related_units:
meta = unit.get("metadata", {})
reference_parts.append(
f"# {meta.get('file', '')} - {meta.get('name', '')}\n{unit.get('text', '')}"
)
reference_code = "\n\n".join(reference_parts)
# 生成代码
generated = self.llm.generate(
self.CODE_GEN_PROMPT.format(
reference_code=reference_code,
requirement=requirement,
)
)
return {
"code": generated,
"references": [u.get("metadata", {}) for u in related_units],
}
def explain_code(self, code_snippet: str) -> str:
"""基于代码库上下文解释代码"""
# 检索相关的类和函数定义
related = self.retriever.search(code_snippet, top_k=3)
context = "\n".join(
f"- {r['metadata']['name']}: {r.get('text', '')[:200]}"
for r in related
)
prompt = f"""解释以下代码片段。参考项目中的相关实现来理解上下文。
代码片段:
{code_snippet}
项目中的相关代码:
{context}
解释:"""
return self.llm.generate(prompt)
Code RAG 应用场景
graph LR
A[Code RAG] --> B[代码搜索]
A --> C[代码理解]
A --> D[代码生成]
A --> E[Bug 分析]
A --> F[代码评审]
B --> B1[自然语言搜索代码库]
C --> C1[解释复杂函数逻辑]
D --> D1[参考现有代码生成新功能]
E --> E1[检索历史 Bug 修复模式]
F --> F1[对比最佳实践建议改进]
style A fill:#e3f2fd,stroke:#1976d2,stroke-width:3px
| 应用 | 输入 | 输出 | 关键检索对象 |
|---|---|---|---|
| 代码搜索 | 自然语言描述 | 代码片段列表 | 函数、类 |
| 代码理解 | 代码片段 | 自然语言解释 | 上下文代码 |
| 代码生成 | 需求描述 | 新代码 | 相似实现 |
| Bug 分析 | 错误信息 | 修复建议 | 历史 commit |
| 代码评审 | PR diff | 改进建议 | 最佳实践 |
本章小结
| 主题 | 要点 |
|---|---|
| 代码解析 | AST 解析到函数/类粒度 |
| Embedding | 混合代码 + 文档字符串效果最佳 |
| 代码生成 | 检索相似实现作为参考 |
| 应用场景 | 搜索、理解、生成、Bug 分析、评审 |
| 关键挑战 | 代码库更新频繁,需增量索引 |
下一章:RAG 技术趋势与展望