LLM 后训练实践
第2课:SFT进阶

2.2 SFT 超参数实践指南

基于 Pareja 等(2024)的系统实验结果,掌握 SFT 核心超参数的选择方法和常见问题诊断

概述

SFT 的效果高度依赖超参数的选择。本节基于 Pareja 等(2024) 的系统性实验结论——该论文在 3B-7B 模型上进行了大量 SFT 超参数消融实验,为小型到中型 LLM 的 SFT 提供了全面的实践指南。


核心超参数

学习率(Learning Rate)

学习率是 SFT 中最敏感的超参数。

方法推荐范围说明
全参数 SFT1e-5 ~ 5e-5偏小更安全,避免灾难性遗忘
LoRA SFT1e-5 ~ 5e-5与全参数类似
QLoRA SFT1e-4 ~ 3e-4通常需要更大的学习率

关键发现(Pareja 等):学习率对最终性能的影响超过其他任何单一超参数。过大的学习率会导致灾难性遗忘(模型忘记预训练知识),过小则会欠拟合

学习率调度器推荐使用 cosine 调度,配合 warmup:

ηt={ηmaxtTwarmupif t<Twarmupηmin+12(ηmaxηmin)(1+cos(tTwarmupTtotalTwarmupπ))otherwise\eta_t = \begin{cases} \eta_{\max} \cdot \frac{t}{T_{\text{warmup}}} & \text{if } t < T_{\text{warmup}} \\ \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})(1 + \cos(\frac{t - T_{\text{warmup}}}{T_{\text{total}} - T_{\text{warmup}}} \pi)) & \text{otherwise} \end{cases}

常见配置:

learning_rate = 2e-5           # 基础学习率
lr_scheduler_type = "cosine"   # cosine 调度
warmup_ratio = 0.1             # warmup 占总步数的 10%

批量大小(Batch Size)

有效批量大小 = per_device_batch_size × gradient_accumulation_steps × num_gpus

有效批量效果适用场景
4-8训练不稳定,梯度噪声大不推荐
16-32稳定,效果好推荐
64-128非常稳定,但可能需要调大学习率大规模数据
# 在单张 GPU 上实现有效批量 32
per_device_train_batch_size = 4
gradient_accumulation_steps = 8
# 有效批量 = 4 × 8 = 32

实用技巧:如果 GPU 显存有限,优先使用梯度累积来增大有效批量,而不是缩小批量。较大的有效批量带来更稳定的梯度估计,尤其在 SFT 数据多样性较高时。

训练轮数(Epochs)

数据规模推荐 Epochs说明
< 5K 条3-5数据少,需要多看几遍
5K-50K 条1-3最常见的配置
> 50K 条1数据充足,1 个 epoch 通常够

Pareja 等的核心发现:对于高质量数据,1-2 个 epoch 通常是最优的。超过 3 个 epoch 几乎总会导致过拟合,表现为:

  • 训练损失继续下降,但验证损失开始上升
  • 模型输出变得重复、风格单一
  • 在新任务上的表现下降

序列长度(Sequence Length)

序列长度直接影响显存训练效率

显存batch_size×seq_length2(注意力机制)\text{显存} \propto \text{batch\_size} \times \text{seq\_length}^2 \quad (\text{注意力机制})
序列长度显存影响适用场景
512短指令、单轮问答
1024-2048多轮对话、一般任务
4096长文档、复杂推理
8192+极高需要 Flash Attention

选择策略:

  1. 分析训练数据的 token 长度分布
  2. 选择能覆盖 90-95% 数据的序列长度
  3. 超过序列长度的样本会被截断
# 分析数据长度分布来确定序列长度
import numpy as np

lengths = [len(tokenizer.encode(sample["text"])) for sample in dataset]
p95 = np.percentile(lengths, 95)
print(f"95th percentile length: {p95:.0f}")
# 选择 max_seq_length 为 p95 的值,向上取整到 2 的幂次

LoRA 超参数

参数推荐值影响
rank (r)16-64表达能力。推荐起始值:32
alpha (α)2 × r缩放因子。alpha/r 控制更新幅度
target_modules所有线性层覆盖范围越广效果越好
dropout0.05-0.1防止过拟合。数据少时可调高

rank 与 alpha 的关系

实际更新幅度=αr×BA\text{实际更新幅度} = \frac{\alpha}{r} \times BA

α=2r\alpha = 2r 时,实际缩放为 2。如果增大 rr 但保持 α/r\alpha/r 不变,则需要相应增大 α\alpha

# 适合:数据少(<5K)、模型小(<3B)、初次尝试
LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
)
# 学习率: 1e-5, epochs: 3
# 适合:中等数据(5K-50K)、中等模型(3B-8B)
LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
)
# 学习率: 2e-5, epochs: 1-2
# 适合:大量数据(>50K)、充足计算资源
LoraConfig(
    r=64,
    lora_alpha=128,
    lora_dropout=0.0,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
)
# 学习率: 5e-5, epochs: 1

常见问题诊断

问题诊断表

症状可能原因解决方法
训练损失不下降学习率太小增大学习率(如 2e-5 → 5e-5)
数据格式错误检查 chat template 和掩码
梯度消失检查量化配置、使用 BF16
训练损失下降后又上升学习率太大减小学习率(如 5e-5 → 1e-5)
训练轮数过多减少 epochs 或使用 early stopping
模型输出大量重复过拟合减少 epochs、增大 dropout
temperature 太低推理时增大 temperature 到 0.7+
数据多样性不足增加数据来源或样本数
模型"忘记"预训练知识灾难性遗忘减小学习率、减少 epochs
LoRA rank 太大减小 rank(64 → 32 → 16)
验证损失持续上升过拟合经典过拟合,减少训练量
验证集分布不同检查数据划分
回复风格不自然数据质量差清洗数据、增加高质量样本
格式模板错误检查 chat template 是否正确

损失曲线解读

Loss
3.0 │╲
    │ ╲
2.0 │  ╲ ──── Train Loss
    │   ╲
1.5 │    ╲── ── Eval Loss
    │     ╲_________
1.0 │      ──────────
    └──────────────────→ Steps

特征:训练损失平滑下降,验证损失跟随下降,两者差距不大。

Loss
3.0 │╲
    │ ╲
2.0 │  ╲       ╱── Eval Loss(上升!)
    │   ╲   ╱──
1.5 │    ╲╱
    │     ╲────── Train Loss(继续下降)
0.5 │      ──────
    └──────────────────→ Steps

特征:训练损失持续下降,但验证损失在某一点后开始上升。应在验证损失最低点停止。

Loss
3.0 │╲ ╱╲ ╱╲
    │ ╳  ╳  ╲╱╲
2.0 │         ╲  ╱╲
    │          ╳   ╲
1.5 │              ───
    └──────────────────→ Steps

特征:损失剧烈震荡,不稳定。需要减小学习率。

Loss
3.0 │
    │────────────────── Loss 几乎不动
2.8 │
    │──────────────────
2.6 │
    └──────────────────→ Steps

特征:损失几乎不下降。通常是数据格式错误——模型在非 assistant token 上也计算损失,或者 chat template 不匹配。


超参数搜索策略

推荐的搜索顺序

第一步:固定其他参数,搜索学习率

0.00005 中搜索(QLoRA 在 0.0003)。选择验证损失最低的。

第二步:确定训练轮数

用最优学习率训练 3 个 epoch,观察验证损失何时开始上升,确定最佳 epoch 数。

第三步:调整 LoRA rank

64 中尝试,观察性能与效率的权衡。

第四步:微调批量大小

64 中尝试(通过梯度累积实现)。

快速验证技巧

# 在开始正式训练前,做一个快速的"烟雾测试"
# 仅训练 100 步,检查损失是否正常下降

sft_config = SFTConfig(
    output_dir="./smoke_test",
    max_steps=100,              # 仅 100 步
    logging_steps=10,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    # ... 其他参数
)

# 如果 100 步后损失没有明显下降 → 检查数据格式和配置
# 如果损失正常下降 → 继续完整训练

完整超参数配置模板

from trl import SFTConfig

sft_config = SFTConfig(
    # === 输出 ===
    output_dir="./sft_output",

    # === 训练量 ===
    num_train_epochs=1,                   # 高质量数据通常 1-2 epoch
    # max_steps=-1,                       # 也可以按步数控制

    # === 批量大小 ===
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,         # 有效批量 = 4 × 8 = 32

    # === 学习率 ===
    learning_rate=2e-5,                   # LoRA; QLoRA 用 2e-4
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,

    # === 精度与效率 ===
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},

    # === 序列配置 ===
    max_seq_length=2048,
    dataset_text_field="text",
    packing=False,                        # True 可提升吞吐量但需谨慎

    # === 日志与保存 ===
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",

    # === 其他 ===
    report_to="wandb",                    # 或 "none"
    seed=42,
)

本节小结

超参数推荐值影响
学习率2e-5(LoRA) / 2e-4(QLoRA)最敏感,过大遗忘,过小欠拟合
有效批量16-32影响训练稳定性
Epochs1-2过多导致过拟合
序列长度覆盖 95% 数据影响显存和效率
LoRA rank32表达能力与效率的权衡
LoRA alpha2 × rank控制更新幅度