模型蒸馏与持续学习
High Contrast
Dark Mode
Light Mode
Sepia
Forest
1 min read179 words

模型蒸馏与持续学习

让 70B 的能力装进 7B 的身体——蒸馏。让模型学新知识不忘旧本领——持续学习。

知识蒸馏架构

graph LR T[教师模型 70B] -->|生成回答| DATA[蒸馏数据集] DATA -->|SFT 训练| S[学生模型 7B] T2[GPT-4o] -->|API 生成| DATA2[高质量数据] DATA2 -->|微调| S2[开源 7B 模型] S -->|推理成本 1/10| DEPLOY[部署] S2 -->|私有化部署| DEPLOY style T fill:#fff3e0,stroke:#f57c00,stroke-width:2px style S fill:#c8e6c9,stroke:#388e3c,stroke-width:2px

知识蒸馏方法

"""
模型蒸馏:大模型知识 → 小模型
"""
class KnowledgeDistillation:
"""知识蒸馏"""
METHODS = {
"输出蒸馏 (最常用)": {
"方法": "教师模型生成回答 → 学生模型 SFT",
"特点": "简单有效,最主流的方法",
"步骤": [
"1. 准备任务 Prompt 集合",
"2. 用教师模型(GPT-4)批量生成回答",
"3. 过滤低质量输出",
"4. 用这些数据 SFT 训练学生模型",
],
},
"Logits 蒸馏": {
"方法": "学生学习教师的概率分布",
"特点": "需要教师模型 logits,开源模型可用",
"步骤": [
"1. 前向传播获取教师模型 logits",
"2. 用 KL 散度对齐学生和教师的分布",
"3. 混合 CE Loss + KL Loss",
],
},
"思维链蒸馏": {
"方法": "蒸馏推理过程而非只有答案",
"特点": "学生模型也能推理",
"步骤": [
"1. 教师用 CoT 生成推理过程",
"2. 学生同时学习推理和答案",
],
},
}
CODE = """
# 输出蒸馏实战
from openai import OpenAI
client = OpenAI()
def distill_from_gpt4(prompts: list[str]) -> list[dict]:
\"\"\"用 GPT-4 生成蒸馏数据\"\"\"
results = []
for prompt in prompts:
response = client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": "你是专家助手,回答详细准确"},
{"role": "user", "content": prompt},
],
temperature=0.7,
)
results.append({
"instruction": prompt,
"output": response.choices[0].message.content,
})
return results
# 然后用这些数据微调 7B 开源模型
"""
dist = KnowledgeDistillation()
print("=== 蒸馏方法 ===")
for name, info in dist.METHODS.items():
print(f"\n{name}:")
print(f"  方法: {info['方法']}")
for step in info["步骤"]:
print(f"    {step}")

持续学习

"""
持续学习:学新知识不忘旧本领
"""
class ContinualLearning:
"""持续学习"""
# 灾难性遗忘
CATASTROPHIC_FORGETTING = {
"现象": "微调新任务后,旧任务性能大幅下降",
"原因": "新任务的梯度覆盖了旧知识的参数",
"示例": (
"用客服数据微调后,模型的代码能力下降了 40%"
),
}
# 缓解策略
STRATEGIES = {
"数据混合": {
"方法": "训练数据混入部分通用数据",
"比例": "新任务 70% + 通用 30%",
"效果": "⭐⭐⭐⭐",
"实现": "简单,最推荐",
},
"LoRA 隔离": {
"方法": "每个任务训练独立 LoRA 适配器",
"比例": "不同任务加载不同 LoRA",
"效果": "⭐⭐⭐⭐⭐",
"实现": "中等,需要适配器管理",
},
"EWC 正则化": {
"方法": "弹性权重巩固,保护重要参数",
"比例": "自动权衡新旧任务",
"效果": "⭐⭐⭐",
"实现": "较复杂",
},
"Replay Buffer": {
"方法": "存储旧任务样本,训练时混入",
"比例": "每 batch 含 10-20% 旧数据",
"效果": "⭐⭐⭐⭐",
"实现": "简单",
},
"渐进式微调": {
"方法": "先冻结低层,只微调高层",
"比例": "逐层解冻",
"效果": "⭐⭐⭐",
"实现": "中等",
},
}
cl = ContinualLearning()
print("=== 灾难性遗忘 ===")
for k, v in cl.CATASTROPHIC_FORGETTING.items():
print(f"  {k}: {v}")
print("\n=== 缓解策略 ===")
for name, info in cl.STRATEGIES.items():
print(f"\n  {name} (效果: {info['效果']}):")
print(f"    方法: {info['方法']}")
print(f"    实现: {info['实现']}")

多任务学习

"""
多任务微调:一次训练,多种能力
"""
class MultitaskFinetuning:
"""多任务微调"""
APPROACH = {
"数据混合策略": {
"均匀混合": "各任务数据量相同",
"加权混合": "按重要性设置采样权重",
"课程学习": "先简单任务后复杂任务",
},
"示例配置": {
"客服对话": {"比例": 0.3, "数据量": 3000},
"文本摘要": {"比例": 0.2, "数据量": 2000},
"代码生成": {"比例": 0.2, "数据量": 2000},
"翻译": {"比例": 0.15, "数据量": 1500},
"通用对话": {"比例": 0.15, "数据量": 1500},
},
}
CODE = """
from datasets import concatenate_datasets, load_dataset
# 加载各任务数据
customer_service = load_dataset("json", data_files="cs.jsonl")["train"]
summarization = load_dataset("json", data_files="summary.jsonl")["train"]
coding = load_dataset("json", data_files="code.jsonl")["train"]
# 按比例采样
cs_sampled = customer_service.select(range(3000))
sum_sampled = summarization.select(range(2000))
code_sampled = coding.select(range(2000))
# 混合
dataset = concatenate_datasets([cs_sampled, sum_sampled, code_sampled])
dataset = dataset.shuffle(seed=42)
# 训练
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
# ...
)
"""
mt = MultitaskFinetuning()
print("=== 多任务混合配置 ===")
for task, config in mt.APPROACH["示例配置"].items():
print(f"  {task}: {config['比例']*100}% ({config['数据量']} 条)")

本章小结

技术 目的 复杂度 适用场景
输出蒸馏 大模型→小模型 降成本部署
Logits 蒸馏 精细知识迁移 ⭐⭐⭐ 开源模型间蒸馏
数据混合 防止遗忘 所有微调必做
LoRA 隔离 多任务切换 ⭐⭐ 多场景服务
多任务微调 一次训练多能力 ⭐⭐ 通用助手

下一章:评估与测试——如何科学衡量微调效果。