Prompt 管理与版本控制
Prompt 是 LLM 应用的"源代码"。在生产环境中,Prompt 需要和代码一样进行版本管理、测试和灰度发布。
Prompt 管理架构
graph TB
A[Prompt 管理] --> B[版本控制]
A --> C[测试验证]
A --> D[灰度发布]
A --> E[监控回滚]
B --> B1[Git 管理 Prompt 模板]
B --> B2[元数据:作者/日期/变更]
C --> C1[回归测试集]
C --> C2[自动评估]
D --> D1[A/B 测试]
D --> D2[流量分流]
E --> E1[效果追踪]
E --> E2[一键回滚]
style A fill:#e3f2fd,stroke:#1976d2,stroke-width:3px
style C fill:#fff3e0,stroke:#f57c00,stroke-width:2px
Prompt 版本注册表
"""
Prompt 版本管理系统
"""
import hashlib
import json
import re
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
@dataclass
class PromptTemplate:
    """A prompt template with named `{variable}` slots and bound model parameters.

    `render` fills the slots from keyword arguments; variables missing from
    the call are rendered as the empty string (original behavior, preserved).
    """

    name: str
    template: str
    variables: list[str]  # names of the `{var}` placeholders in `template`
    model: str = ""
    temperature: float = 0.7
    max_tokens: int = 2048

    def render(self, **kwargs) -> str:
        """Render the template by substituting each `{var}` placeholder.

        All placeholders are replaced in a single pass, so a substituted
        value that happens to contain another placeholder (e.g. a user
        string like "{other_var}") is NOT expanded again — the previous
        sequential-replace implementation would re-expand it, which is an
        injection hazard for templated prompts.

        Args:
            **kwargs: values for the declared variables; missing variables
                become "".

        Returns:
            The rendered prompt string.
        """
        mapping = {"{" + var + "}": str(kwargs.get(var, "")) for var in self.variables}
        if not mapping:
            return self.template
        pattern = re.compile("|".join(re.escape(p) for p in mapping))
        return pattern.sub(lambda m: mapping[m.group(0)], self.template)
@dataclass
class PromptVersion:
    """One immutable snapshot of a prompt template plus its release metadata."""

    version: str  # e.g. "v3"; assigned sequentially by the registry
    template: PromptTemplate
    author: str
    created_at: datetime
    changelog: str
    test_score: float | None = None  # filled in after the test suite runs
    is_active: bool = False  # at most one version per prompt name is active

    @property
    def content_hash(self) -> str:
        """Short SHA-256 digest of the template text, for change tracking."""
        digest = hashlib.sha256(self.template.template.encode())
        return digest.hexdigest()[:12]
class PromptRegistry:
    """In-memory registry of named prompts, each with an ordered version history.

    Tracks which version of each prompt is currently active and supports
    activation and one-step rollback.  `storage_path` is accepted but never
    read or written here — NOTE(review): confirm whether on-disk persistence
    was intended.
    """

    def __init__(self, storage_path: Path | None = None):
        # prompt name -> versions in registration order (v1, v2, ...)
        self.registry: dict[str, list[PromptVersion]] = {}
        # prompt name -> currently active version
        self._active: dict[str, PromptVersion] = {}
        self.storage_path = storage_path

    def register(
        self,
        name: str,
        template: PromptTemplate,
        author: str,
        changelog: str,
    ) -> PromptVersion:
        """Append a new version for `name` and return it (not auto-activated)."""
        history = self.registry.setdefault(name, [])
        new_version = PromptVersion(
            version=f"v{len(history) + 1}",
            template=template,
            author=author,
            created_at=datetime.now(),
            changelog=changelog,
        )
        history.append(new_version)
        return new_version

    def activate(self, name: str, version: str) -> bool:
        """Make `version` the active one for `name`; False if not found."""
        candidates = self.registry.get(name, [])
        target = next((v for v in candidates if v.version == version), None)
        if target is None:
            return False
        previous = self._active.get(name)
        if previous is not None:
            previous.is_active = False
        target.is_active = True
        self._active[name] = target
        return True

    def get_active(self, name: str) -> PromptVersion | None:
        """Return the currently active version for `name`, if any."""
        return self._active.get(name)

    def rollback(self, name: str) -> bool:
        """Re-activate the version registered immediately before the active one.

        Returns False when there is no active version, fewer than two
        versions exist, or the active version is already the oldest.
        """
        history = self.registry.get(name, [])
        active = self._active.get(name)
        if active is None or len(history) < 2:
            return False
        for pos, candidate in enumerate(history):
            if candidate.version == active.version:
                if pos == 0:
                    return False  # already at the oldest version
                return self.activate(name, history[pos - 1].version)
        return False

    def list_versions(self, name: str) -> list[dict]:
        """Summarize every registered version of `name` as plain dicts."""
        summaries = []
        for v in self.registry.get(name, []):
            summaries.append(
                {
                    "version": v.version,
                    "hash": v.content_hash,
                    "author": v.author,
                    "created": v.created_at.isoformat(),
                    "active": v.is_active,
                    "score": v.test_score,
                }
            )
        return summaries
Prompt 测试框架
"""
Prompt 回归测试
"""
from dataclasses import dataclass, field
@dataclass
class TestCase:
"""测试用例"""
name: str
input_vars: dict
expected_contains: list[str] = field(default_factory=list)
expected_not_contains: list[str] = field(default_factory=list)
max_tokens: int | None = None
@dataclass
class TestResult:
    """Outcome of running one TestCase against the LLM output."""

    case_name: str
    passed: bool  # True iff `errors` is empty
    output: str  # raw LLM response that was checked
    errors: list[str] = field(default_factory=list)  # human-readable assertion failures
class PromptTestRunner:
    """Runs a suite of TestCases against a prompt template via an LLM client.

    The client object must expose `generate(prompt: str) -> str`.
    """

    def __init__(self, llm_client):
        self.llm = llm_client

    def run_tests(
        self, template: PromptTemplate, cases: list[TestCase]
    ) -> list[TestResult]:
        """Render the template for each case, query the LLM, and collect results."""
        outcomes: list[TestResult] = []
        for case in cases:
            rendered = template.render(**case.input_vars)
            response = self.llm.generate(rendered)
            failures = self._check_assertions(response, case)
            outcomes.append(
                TestResult(
                    case_name=case.name,
                    passed=not failures,
                    output=response,
                    errors=failures,
                )
            )
        return outcomes

    def _check_assertions(self, output: str, case: TestCase) -> list[str]:
        """Return the list of assertion failures for `output` (empty == pass)."""
        problems: list[str] = []
        haystack = output.lower()  # contains/not-contains checks are case-insensitive
        problems.extend(
            f"缺少预期内容: '{expected}'"
            for expected in case.expected_contains
            if expected.lower() not in haystack
        )
        problems.extend(
            f"包含禁止内容: '{forbidden}'"
            for forbidden in case.expected_not_contains
            if forbidden.lower() in haystack
        )
        # NOTE(review): "tokens" is approximated by whitespace word count here,
        # not a real tokenizer — confirm that is acceptable for the limit check.
        if case.max_tokens and len(output.split()) > case.max_tokens:
            problems.append(f"输出超过 {case.max_tokens} tokens")
        return problems

    def score(self, results: list[TestResult]) -> float:
        """Fraction of cases that passed; 0.0 for an empty result set."""
        if not results:
            return 0.0
        return sum(r.passed for r in results) / len(results)
Prompt 管理最佳实践
| 实践 | 说明 | 工具推荐 |
|---|---|---|
| 模板化 | 使用变量槽位,不硬编码 | Jinja2 / Handlebars |
| 版本化 | 每次变更记录版本号和变更说明 | Git / Prompt Registry |
| 测试 | 维护回归测试集,自动评估 | pytest + LLM Judge |
| 灰度 | 新版本先小流量验证 | A/B 框架 |
| 监控 | 追踪效果指标,异常告警 | LangFuse / Phoenix |
| 文档化 | 记录设计意图和约束 | 与 Prompt 同仓库 |
本章小结
| 主题 | 要点 |
|---|---|
| Prompt 模板 | 变量化 + 可渲染 + 绑定模型参数 |
| 版本管理 | 注册 → 激活 → 回滚,哈希追踪内容变更 |
| 回归测试 | 包含/排除断言 + Token 限制 + 自动评分 |
| 发布流程 | 测试通过 → 灰度发布 → 全量激活 |
下一章:部署方案与平台选型