Training Environment and Hyperparameter Configuration
From environment setup to hyperparameter tuning: everything you need to get a fine-tuning run going.
Training Environment Overview
graph TB
subgraph Hardware Options
GPU1[Cloud GPU: A100/H100]
GPU2[Consumer: RTX 4090]
GPU3[Free: Colab/Kaggle]
end
subgraph Training Frameworks
HF[HuggingFace Transformers]
PEFT[PEFT + bitsandbytes]
DS[DeepSpeed]
AX[Axolotl]
end
subgraph Model Options
M1[Llama 3.1]
M2[Qwen 2.5]
M3[Mistral]
M4[Gemma 2]
end
GPU1 --> HF
GPU2 --> PEFT
GPU3 --> PEFT
HF --> M1
HF --> M2
PEFT --> M3
AX --> M4
style GPU1 fill:#e3f2fd,stroke:#1976d2
style PEFT fill:#c8e6c9,stroke:#388e3c
GPU Selection Guide
"""
GPU 选型与成本估算
"""
class GPUSelector:
"""GPU 选型"""
GPUS = {
"NVIDIA A100 80GB": {
"显存": "80 GB",
"算力": "312 TFLOPS (FP16)",
"价格": "$1.5-2.5/h (云)",
"适用": "全量微调 7-70B",
"来源": "AWS/Azure/GCP/Lambda",
},
"NVIDIA H100 80GB": {
"显存": "80 GB",
"算力": "990 TFLOPS (FP16)",
"价格": "$2.5-4.0/h (云)",
"适用": "大规模训练,速度 2x A100",
"来源": "AWS/Azure/GCP/CoreWeave",
},
"NVIDIA RTX 4090 24GB": {
"显存": "24 GB",
"算力": "165 TFLOPS (FP16)",
"价格": "$0.3-0.7/h (云) | $1600 (购买)",
"适用": "QLoRA 微调 7-13B",
"来源": "Vast.ai/RunPod/自建",
},
"Google Colab T4": {
"显存": "16 GB",
"算力": "65 TFLOPS (FP16)",
"价格": "免费 / $10/月 (Pro)",
"适用": "QLoRA 微调 7B 小规模实验",
"来源": "Google Colab",
},
}
MODEL_GPU_REQUIREMENTS = {
"7B (QLoRA)": "1x RTX 4090 (6 GB)",
"7B (LoRA)": "1x A100 (16 GB)",
"7B (全量)": "2x A100 (84 GB)",
"13B (QLoRA)": "1x RTX 4090 (10 GB)",
"13B (LoRA)": "1x A100 (28 GB)",
"70B (QLoRA)": "1x A100 (40 GB)",
"70B (LoRA)": "4x A100 (160 GB)",
"70B (全量)": "8x A100 (480 GB)",
}
selector = GPUSelector()
print("=== GPU 选型 ===")
for name, info in selector.GPUS.items():
print(f"\n{name}:")
print(f" 显存: {info['显存']} | 价格: {info['价格']}")
print(f" 适用: {info['适用']}")
print("\n=== 模型-GPU 对照 ===")
for model, gpu in selector.MODEL_GPU_REQUIREMENTS.items():
print(f" {model}: {gpu}")
Environment Setup
"""
训练环境一键搭建
"""
SETUP_SCRIPT = """
# ===== 1. 创建环境 =====
conda create -n finetune python=3.11 -y
conda activate finetune
# ===== 2. 安装 PyTorch (CUDA 12.1) =====
pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu121
# ===== 3. 安装核心库 =====
pip install transformers==4.44.0
pip install datasets==2.20.0
pip install accelerate==0.33.0
pip install peft==0.12.0
pip install trl==0.10.0 # SFT/DPO 训练
pip install bitsandbytes==0.43.0 # 量化
# ===== 4. 安装监控工具 =====
pip install wandb # 实验追踪
pip install tensorboard # 训练可视化
# ===== 5. 验证安装 =====
python -c "
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.cuda.is_available()}')
print(f'GPU: {torch.cuda.get_device_name(0)}')
print(f'显存: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB')
"
"""
# 打印安装脚本
print("=== 环境搭建脚本 ===")
print(SETUP_SCRIPT)
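Beyond the torch sanity check at the end of the script, it can help to confirm that the rest of the fine-tuning stack imports cleanly and that your GPU supports bfloat16, which the QLoRA configuration later in this chapter relies on. A minimal sketch, assuming you run it inside the `finetune` environment created above:

# Quick sanity check for the fine-tuning stack (a sketch, not part of the
# setup script itself).
import torch
import transformers, datasets, accelerate, peft, trl, bitsandbytes

for lib in (transformers, datasets, accelerate, peft, trl, bitsandbytes):
    print(f"{lib.__name__:<14} {lib.__version__}")

# bfloat16 support matters for bnb_4bit_compute_dtype=torch.bfloat16 later;
# Ampere GPUs (A100, RTX 30/40 series) and newer report True here.
print("bf16 supported:", torch.cuda.is_bf16_supported())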
Hyperparameter Configuration in Detail
"""
超参数调优指南
"""
from dataclasses import dataclass
@dataclass
class HyperparameterGuide:
"""超参数指南"""
PARAMS = {
"learning_rate": {
"范围": "1e-5 ~ 5e-4",
"推荐": {
"全量微调": "1e-5 ~ 2e-5",
"LoRA": "1e-4 ~ 3e-4",
"QLoRA": "2e-4",
},
"说明": "太大会不稳定,太小学不到东西",
},
"num_train_epochs": {
"范围": "1 ~ 10",
"推荐": {
"数据 < 1000": "3-5 epochs",
"数据 1000-10000": "2-3 epochs",
"数据 > 10000": "1-2 epochs",
},
"说明": "epoch 过多会过拟合",
},
"per_device_train_batch_size": {
"范围": "1 ~ 32",
"推荐": {
"24GB GPU": "2-4",
"80GB GPU": "8-16",
},
"说明": "受限于显存,配合 gradient_accumulation 使用",
},
"gradient_accumulation_steps": {
"范围": "1 ~ 32",
"推荐": "4-8",
"说明": "等效增大 batch size,不增加显存",
},
"warmup_ratio": {
"范围": "0.0 ~ 0.1",
"推荐": "0.03 ~ 0.1",
"说明": "学习率预热,防止训练初期不稳定",
},
"weight_decay": {
"范围": "0.0 ~ 0.1",
"推荐": "0.01",
"说明": "正则化,防止过拟合",
},
"max_seq_length": {
"范围": "512 ~ 8192",
"推荐": "1024 ~ 2048",
"说明": "根据数据长度决定,越长越耗显存",
},
}
guide = HyperparameterGuide()
print("=== 超参数配置 ===")
for name, info in guide.PARAMS.items():
print(f"\n{name}:")
if isinstance(info["推荐"], dict):
for scenario, val in info["推荐"].items():
print(f" {scenario}: {val}")
else:
print(f" 推荐: {info['推荐']}")
print(f" 说明: {info['说明']}")
Complete Training Script
"""
LoRA 微调完整脚本(可直接运行)
"""
TRAINING_SCRIPT = '''
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
# ===== 配置 =====
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
DATA_FILE = "train_data.jsonl"
OUTPUT_DIR = "./output"
# ===== 1. 量化配置 (QLoRA) =====
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
# ===== 2. 加载模型 =====
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
)
model = prepare_model_for_kbit_training(model)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
# ===== 3. LoRA 配置 =====
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# ===== 4. 数据集 =====
dataset = load_dataset("json", data_files=DATA_FILE, split="train")
def format_chat(example):
"""格式化为聊天模板"""
messages = [
{"role": "system", "content": "你是一个专业的AI助手。"},
{"role": "user", "content": example["instruction"]},
{"role": "assistant", "content": example["output"]},
]
return {"text": tokenizer.apply_chat_template(
messages, tokenize=False
)}
dataset = dataset.map(format_chat)
# ===== 5. 训练配置 =====
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-4,
weight_decay=0.01,
warmup_ratio=0.1,
lr_scheduler_type="cosine",
fp16=True,
logging_steps=10,
save_strategy="epoch",
report_to="wandb",
optim="paged_adamw_8bit",
)
# ===== 6. 训练 =====
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
args=training_args,
dataset_text_field="text",
max_seq_length=2048,
packing=True,
)
trainer.train()
trainer.save_model(OUTPUT_DIR)
print("训练完成!")
'''
print("=== 完整训练脚本 ===")
print(TRAINING_SCRIPT[:500] + "...")
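After training, the saved output contains only the LoRA adapter weights, not a full model. A common next step is to merge the adapter back into the base model for evaluation or deployment. The following is a minimal sketch under assumptions of my own: it reuses MODEL_NAME and OUTPUT_DIR from the script above, reloads the base weights in bf16 (merging is not done on the 4-bit quantized copy), and the ./merged_model path and test prompt are illustrative.

# Post-training sketch: merge the LoRA adapter and run a quick generation.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
OUTPUT_DIR = "./output"

base = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(base, OUTPUT_DIR)  # attach the trained adapter
model = model.merge_and_unload()                     # fold LoRA weights into the base

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
messages = [{"role": "user", "content": "Introduce yourself in one sentence."}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
outputs = model.generate(inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))

# Optionally save the merged model for serving
model.save_pretrained("./merged_model")
tokenizer.save_pretrained("./merged_model")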
Chapter Summary
| Setting | Recommended | Notes |
|---|---|---|
| GPU | RTX 4090 / A100 | RTX 4090 for QLoRA, A100 for LoRA |
| learning_rate | 2e-4 (LoRA) | Use a smaller 1e-5 for full fine-tuning |
| epochs | 2-3 | Fewer epochs for larger datasets, more for smaller ones |
| batch_size | 4 + accumulation 4 | Effective batch = 16 |
| max_seq_length | 2048 | Adjust to your data length |
Next chapter: Advanced Fine-Tuning Techniques, covering instruction tuning, RLHF, and DPO.