分布式训练与监控
High Contrast
Dark Mode
Light Mode
Sepia
Forest
1 min read · 134 words

分布式训练与监控

单卡不够用?用 DeepSpeed 和多卡训练大模型,用 WandB 实时追踪。

分布式训练架构

graph TB
    subgraph 数据并行 DP
        D1[GPU 0: 完整模型 + Batch 1]
        D2[GPU 1: 完整模型 + Batch 2]
        D3[GPU 2: 完整模型 + Batch 3]
    end
    subgraph 模型并行 MP
        M1[GPU 0: Layer 1-10]
        M2[GPU 1: Layer 11-20]
        M3[GPU 2: Layer 21-32]
    end
    subgraph ZeRO 优化
        Z1[ZeRO-1: 分片优化器]
        Z2[ZeRO-2: 分片梯度]
        Z3[ZeRO-3: 分片参数]
    end
    D1 --> SYNC[梯度同步]
    D2 --> SYNC
    D3 --> SYNC
    M1 --> M2 --> M3
    Z1 --> Z2 --> Z3
    style Z3 fill:#c8e6c9,stroke:#388e3c,stroke-width:2px

DeepSpeed 配置

"""
DeepSpeed ZeRO configuration examples.

Generates ready-to-serialize ZeRO-2 and ZeRO-3 config dicts and prints a
short stage-selection guide.
"""
import json


class DeepSpeedConfig:
    """Factory for DeepSpeed ZeRO configuration dictionaries.

    Both generators leave batch size, gradient accumulation, and gradient
    clipping as "auto" so they are inherited from the training-framework
    arguments at launch time.
    """

    # Quick reference: what each ZeRO stage shards and when to reach for it.
    # (Accessed below via `config.STAGE_GUIDE`, so it lives on the class.)
    STAGE_GUIDE = {
        "ZeRO-1": {
            "分片": "优化器状态",
            "节省": "~4x",
            "适用": "7B 模型 2-4 卡",
        },
        "ZeRO-2": {
            "分片": "优化器 + 梯度",
            "节省": "~8x",
            "适用": "13B 模型 4-8 卡",
        },
        "ZeRO-3": {
            "分片": "优化器 + 梯度 + 参数",
            "节省": "线性扩展",
            "适用": "70B+ 模型",
        },
    }

    @staticmethod
    def zero2_config() -> dict:
        """Return a ZeRO Stage 2 config dict (the most common choice)."""
        return {
            "bf16": {"enabled": True},
            "zero_optimization": {
                "stage": 2,
                # Park optimizer state in (pinned) CPU RAM to free GPU memory.
                "offload_optimizer": {
                    "device": "cpu",
                    "pin_memory": True,
                },
                "allgather_partitions": True,
                "allgather_bucket_size": 2e8,
                # Overlap gradient communication with backward compute.
                "overlap_comm": True,
                "reduce_scatter": True,
                "reduce_bucket_size": 2e8,
                "contiguous_gradients": True,
            },
            "gradient_accumulation_steps": "auto",
            "gradient_clipping": "auto",
            "train_batch_size": "auto",
            "train_micro_batch_size_per_gpu": "auto",
        }

    @staticmethod
    def zero3_config() -> dict:
        """Return a ZeRO Stage 3 config dict (needed for very large models)."""
        return {
            "bf16": {"enabled": True},
            "zero_optimization": {
                "stage": 3,
                "offload_optimizer": {
                    "device": "cpu",
                    "pin_memory": True,
                },
                # Stage 3 can additionally offload the parameters themselves.
                "offload_param": {
                    "device": "cpu",
                    "pin_memory": True,
                },
                "overlap_comm": True,
                "contiguous_gradients": True,
                "sub_group_size": 1e9,
                "reduce_bucket_size": "auto",
                "stage3_prefetch_bucket_size": "auto",
                "stage3_param_persistence_threshold": "auto",
                "stage3_max_live_parameters": 1e9,
                "stage3_max_reuse_distance": 1e9,
                # Re-gather full bf16 weights when saving a checkpoint.
                "stage3_gather_16bit_weights_on_model_save": True,
            },
            "gradient_accumulation_steps": "auto",
            "gradient_clipping": "auto",
            "train_batch_size": "auto",
            "train_micro_batch_size_per_gpu": "auto",
        }


# Demo: print the stage-selection guide and an excerpt of the ZeRO-2 config.
config = DeepSpeedConfig()
print("=== ZeRO Stage 选择 ===")
for stage, info in config.STAGE_GUIDE.items():
    print(f"  {stage}: 分片{info['分片']} | 适用: {info['适用']}")

# Show a truncated JSON rendering of the generated config.
z2 = config.zero2_config()
print(f"\nZeRO-2 配置:")
print(json.dumps(z2, indent=2)[:300] + "...")

多卡训练命令

"""
Launch commands for single-GPU, multi-GPU, and multi-node training.
"""

# Mapping: scenario name -> {"命令": launch command, "适用": typical use case}.
LAUNCH_COMMANDS = {
    "单卡训练": {
        "命令": "python train.py",
        "适用": "QLoRA 7B",
    },
    "多卡数据并行 (Accelerate)": {
        "命令": "accelerate launch --num_processes 4 train.py",
        "适用": "LoRA 多卡",
    },
    "DeepSpeed ZeRO-2": {
        "命令": (
            "deepspeed --num_gpus 4 train.py "
            "--deepspeed ds_config_z2.json"
        ),
        "适用": "全量微调 13B",
    },
    "DeepSpeed ZeRO-3": {
        "命令": (
            "deepspeed --num_gpus 8 train.py "
            "--deepspeed ds_config_z3.json"
        ),
        "适用": "全量微调 70B",
    },
    # Multi-node launch reads the node list from a hostfile.
    "多节点训练": {
        "命令": (
            "deepspeed --hostfile hostfile.txt "
            "--num_gpus 8 train.py "
            "--deepspeed ds_config_z3.json"
        ),
        "适用": "超大规模训练",
    },
}

# Demo: print each scenario with its launch command and use case.
print("=== 训练启动命令 ===")
for name, info in LAUNCH_COMMANDS.items():
    print(f"\n{name}:")
    print(f"  命令: {info['命令']}")
    print(f"  适用: {info['适用']}")

训练监控与 WandB

"""
WandB-based training monitoring: key metrics and troubleshooting notes.
"""


class TrainingMonitor:
    """Reference tables for monitoring a fine-tuning run.

    Pure data holder: metric expectations, WandB setup instructions, and a
    troubleshooting checklist for the most common failure modes.
    """

    # Key metrics: expected behavior, warning signs, and the usual remedy.
    KEY_METRICS = {
        "train/loss": {
            "正常范围": "持续下降,最终 0.5-2.0",
            "异常信号": "不下降或剧烈震荡",
            "处理": "降低学习率或检查数据",
        },
        "eval/loss": {
            "正常范围": "与 train_loss 同步下降",
            "异常信号": "上升(过拟合)",
            "处理": "提前停止或增加正则化",
        },
        "train/learning_rate": {
            "正常范围": "warmup → peak → 衰减",
            "异常信号": "学习率过大导致 loss 震荡",
            "处理": "降低 max_lr",
        },
        "gpu_memory": {
            "正常范围": "< 95% 显存",
            "异常信号": "OOM 错误",
            "处理": "减小 batch_size 或用 gradient_accumulation",
        },
    }

    # Step-by-step WandB setup (kept as a verbatim snippet for display).
    WANDB_SETUP = """
# 1. 安装
pip install wandb
# 2. 登录
wandb login  # 输入 API key
# 3. 在训练代码中
training_args = TrainingArguments(
report_to="wandb",
run_name="llama3-lora-customer-service",
# ...
)
# 4. WandB 会自动记录:
# - loss 曲线
# - 学习率变化
# - GPU 使用率
# - 训练速度 (samples/s)
# - 超参数配置
"""

    # Common problems and the checks/fixes to try, in order.
    TROUBLESHOOTING = {
        "Loss 不下降": [
            "检查学习率是否太小",
            "检查数据格式是否正确",
            "检查 tokenizer 的 chat_template",
        ],
        "Loss 剧烈震荡": [
            "降低学习率",
            "增大 batch_size (用gradient_accumulation)",
            "增加 warmup_ratio",
        ],
        "OOM 显存不足": [
            "减小 batch_size 到 1",
            "减小 max_seq_length",
            "使用 gradient_checkpointing",
            "切换到 QLoRA",
        ],
        "训练太慢": [
            "使用 bf16 而非 fp16",
            "启用 packing=True",
            "使用 Flash Attention",
            "增大 batch_size",
        ],
    }


# Demo: print the metric expectations and the troubleshooting checklist.
monitor = TrainingMonitor()
print("=== 关键监控指标 ===")
for metric, info in monitor.KEY_METRICS.items():
    print(f"\n{metric}:")
    print(f"  正常: {info['正常范围']}")
    print(f"  异常: {info['异常信号']}")

print("\n=== 常见问题排查 ===")
for problem, solutions in monitor.TROUBLESHOOTING.items():
    print(f"\n{problem}:")
    for s in solutions:
        print(f"  - {s}")

本章小结

| 方案 | 场景 | 关键配置 |
| --- | --- | --- |
| 单卡 QLoRA | 7-13B 模型 | bitsandbytes 4-bit |
| 多卡 LoRA | 7-13B 性能优先 | accelerate |
| DeepSpeed ZeRO-2 | 13B 全量微调 | 梯度+优化器分片 |
| DeepSpeed ZeRO-3 | 70B+ 全量微调 | 全分片+CPU offload |

下一章:高级技术——指令微调、RLHF 与 DPO 对齐。