分布式训练与监控
单卡不够用?用 DeepSpeed 和多卡训练大模型,用 WandB 实时追踪。
分布式训练架构
graph TB
subgraph 数据并行 DP
D1[GPU 0: 完整模型 + Batch 1]
D2[GPU 1: 完整模型 + Batch 2]
D3[GPU 2: 完整模型 + Batch 3]
end
subgraph 模型并行 MP
M1[GPU 0: Layer 1-10]
M2[GPU 1: Layer 11-20]
M3[GPU 2: Layer 21-32]
end
subgraph ZeRO 优化
Z1[ZeRO-1: 分片优化器]
Z2[ZeRO-2: 分片梯度]
Z3[ZeRO-3: 分片参数]
end
D1 --> SYNC[梯度同步]
D2 --> SYNC
D3 --> SYNC
M1 --> M2 --> M3
Z1 --> Z2 --> Z3
style Z3 fill:#c8e6c9,stroke:#388e3c,stroke-width:2px
DeepSpeed 配置
"""
DeepSpeed Zero 配置详解
"""
import json
class DeepSpeedConfig:
    """Factory for DeepSpeed ZeRO configuration dictionaries.

    Both factories return plain dicts suitable for ``json.dumps`` and for
    passing to the HF Trainer via ``--deepspeed``. String values of
    ``"auto"`` are placeholders the Trainer resolves from its own arguments.
    """

    @staticmethod
    def _base(zero_optimization: dict) -> dict:
        # Shared scaffold used by every stage: bf16 mixed precision plus
        # batch/clipping settings deferred to the Trainer ("auto").
        # Key order matters for stable JSON output, so it mirrors the
        # canonical DeepSpeed examples: bf16 first, then zero_optimization.
        return {
            "bf16": {"enabled": True},
            "zero_optimization": zero_optimization,
            "gradient_accumulation_steps": "auto",
            "gradient_clipping": "auto",
            "train_batch_size": "auto",
            "train_micro_batch_size_per_gpu": "auto",
        }

    @staticmethod
    def zero2_config() -> dict:
        """ZeRO Stage 2 (the most commonly used stage)."""
        return DeepSpeedConfig._base({
            "stage": 2,
            # Move optimizer state to CPU RAM to free GPU memory.
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True,
            },
            "allgather_partitions": True,
            "allgather_bucket_size": 2e8,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 2e8,
            "contiguous_gradients": True,
        })

    @staticmethod
    def zero3_config() -> dict:
        """ZeRO Stage 3 (required for very large models)."""
        return DeepSpeedConfig._base({
            "stage": 3,
            # Offload both optimizer state and parameters to CPU RAM.
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True,
            },
            "offload_param": {
                "device": "cpu",
                "pin_memory": True,
            },
            "overlap_comm": True,
            "contiguous_gradients": True,
            "sub_group_size": 1e9,
            "reduce_bucket_size": "auto",
            "stage3_prefetch_bucket_size": "auto",
            "stage3_param_persistence_threshold": "auto",
            "stage3_max_live_parameters": 1e9,
            "stage3_max_reuse_distance": 1e9,
            # Gather full 16-bit weights on save so checkpoints are usable
            # without a DeepSpeed-aware loader.
            "stage3_gather_16bit_weights_on_model_save": True,
        })

    # Quick reference for choosing a ZeRO stage. Values are display-only
    # Chinese strings (分片 = what is sharded, 节省 = memory saving,
    # 适用 = typical model size / GPU count).
    STAGE_GUIDE = {
        "ZeRO-1": {
            "分片": "优化器状态",
            "节省": "~4x",
            "适用": "7B 模型 2-4 卡",
        },
        "ZeRO-2": {
            "分片": "优化器 + 梯度",
            "节省": "~8x",
            "适用": "13B 模型 4-8 卡",
        },
        "ZeRO-3": {
            "分片": "优化器 + 梯度 + 参数",
            "节省": "线性扩展",
            "适用": "70B+ 模型",
        },
    }
# Demo: print the stage-selection guide and a truncated ZeRO-2 JSON preview.
config = DeepSpeedConfig()
print("=== ZeRO Stage 选择 ===")
for stage_name, guide in config.STAGE_GUIDE.items():
    print(f" {stage_name}: 分片{guide['分片']} | 适用: {guide['适用']}")

# Serialize the ZeRO-2 config and show only the first 300 characters.
z2 = config.zero2_config()
print("\nZeRO-2 配置:")
print(json.dumps(z2, indent=2)[:300] + "...")
多卡训练命令
"""
多卡训练启动方式
"""
LAUNCH_COMMANDS = {
"单卡训练": {
"命令": "python train.py",
"适用": "QLoRA 7B",
},
"多卡数据并行 (Accelerate)": {
"命令": "accelerate launch --num_processes 4 train.py",
"适用": "LoRA 多卡",
},
"DeepSpeed ZeRO-2": {
"命令": (
"deepspeed --num_gpus 4 train.py "
"--deepspeed ds_config_z2.json"
),
"适用": "全量微调 13B",
},
"DeepSpeed ZeRO-3": {
"命令": (
"deepspeed --num_gpus 8 train.py "
"--deepspeed ds_config_z3.json"
),
"适用": "全量微调 70B",
},
"多节点训练": {
"命令": (
"deepspeed --hostfile hostfile.txt "
"--num_gpus 8 train.py "
"--deepspeed ds_config_z3.json"
),
"适用": "超大规模训练",
},
}
print("=== 训练启动命令 ===")
for name, info in LAUNCH_COMMANDS.items():
print(f"\n{name}:")
print(f" 命令: {info['命令']}")
print(f" 适用: {info['适用']}")
训练监控与 WandB
"""
WandB 训练监控
"""
class TrainingMonitor:
    """Static reference tables for monitoring fine-tuning runs.

    Holds three display-only class attributes: key metrics to watch
    (KEY_METRICS), a WandB setup snippet (WANDB_SETUP), and a
    troubleshooting guide (TROUBLESHOOTING). All strings are Chinese
    user-facing text and are printed verbatim by the demo script.
    """

    # Key metrics to watch during training. Each entry maps a logged
    # metric name to three display strings: "正常范围" (normal range),
    # "异常信号" (warning sign), "处理" (remedy). Values are prose, not
    # machine-checkable thresholds.
    KEY_METRICS = {
        "train/loss": {
            "正常范围": "持续下降,最终 0.5-2.0",
            "异常信号": "不下降或剧烈震荡",
            "处理": "降低学习率或检查数据",
        },
        "eval/loss": {
            "正常范围": "与 train_loss 同步下降",
            "异常信号": "上升(过拟合)",
            "处理": "提前停止或增加正则化",
        },
        "train/learning_rate": {
            "正常范围": "warmup → peak → 衰减",
            "异常信号": "学习率过大导致 loss 震荡",
            "处理": "降低 max_lr",
        },
        "gpu_memory": {
            "正常范围": "< 95% 显存",
            "异常信号": "OOM 错误",
            "处理": "减小 batch_size 或用 gradient_accumulation",
        },
    }

    # Verbatim setup walkthrough (shell + Python) showing how to wire
    # WandB into an HF TrainingArguments-based run. Display-only text;
    # it is never executed by this module.
    WANDB_SETUP = """
# 1. 安装
pip install wandb
# 2. 登录
wandb login  # 输入 API key
# 3. 在训练代码中
training_args = TrainingArguments(
report_to="wandb",
run_name="llama3-lora-customer-service",
# ...
)
# 4. WandB 会自动记录:
# - loss 曲线
# - 学习率变化
# - GPU 使用率
# - 训练速度 (samples/s)
# - 超参数配置
"""

    # Common failure modes mapped to an ordered list of remedies to try,
    # cheapest/most likely first.
    TROUBLESHOOTING = {
        "Loss 不下降": [
            "检查学习率是否太小",
            "检查数据格式是否正确",
            "检查 tokenizer 的 chat_template",
        ],
        "Loss 剧烈震荡": [
            "降低学习率",
            "增大 batch_size (用gradient_accumulation)",
            "增加 warmup_ratio",
        ],
        "OOM 显存不足": [
            "减小 batch_size 到 1",
            "减小 max_seq_length",
            "使用 gradient_checkpointing",
            "切换到 QLoRA",
        ],
        "训练太慢": [
            "使用 bf16 而非 fp16",
            "启用 packing=True",
            "使用 Flash Attention",
            "增大 batch_size",
        ],
    }
# Demo: dump the monitoring cheat-sheet and troubleshooting guide to stdout.
monitor = TrainingMonitor()

print("=== 关键监控指标 ===")
for metric_name, detail in monitor.KEY_METRICS.items():
    print(f"\n{metric_name}:")
    print(f" 正常: {detail['正常范围']}")
    print(f" 异常: {detail['异常信号']}")

print("\n=== 常见问题排查 ===")
for issue, fixes in monitor.TROUBLESHOOTING.items():
    print(f"\n{issue}:")
    for fix in fixes:
        print(f" - {fix}")
本章小结
| 方案 | 场景 | 关键配置 |
|---|---|---|
| 单卡 QLoRA | 7-13B 模型 | bitsandbytes 4-bit |
| 多卡 LoRA | 7-13B 性能优先 | accelerate |
| DeepSpeed ZeRO-2 | 13B 全量微调 | 梯度+优化器分片 |
| DeepSpeed ZeRO-3 | 70B+ 全量微调 | 全分片+CPU offload |
下一章:高级技术——指令微调、RLHF 与 DPO 对齐。