强化微调是目前模型训练非常重要的功能之一,它本身的实现是多种多样的,SWIFT目前已经支持了强化微调所需要的原子能力,如采样、强化学习和微调。目前我们提供了拒绝采样微调的一个具体示例,可以查看这里。
强化微调是从2022年开始(甚至更早)就被提出的概念。其方式一般有下列流程:
- 使用某个模型生成数据,或进行原始数据扩充
- 使用数据训练目标模型
- 如果有必要,重复上述过程
步骤1:
- 如果生成数据的模型是更大的模型,如GPT、Qwen-Max、DeepSeek-V3/R1等,则该强化微调可以理解为蒸馏
- 如果生成数据的模型是本模型,则可以理解为自我提升(self-improvement)微调
- 如果采样过程是采样一个batch,然后通过KL散度和reward进行拟合训练并不断循环,则可以理解为PPO、GRPO等on-policy算法
- 采样数据的算法包含蒙特卡洛采样、do_sample采样、group beam search、dvts等
- 采样过程可以引入ORM(结果判断),PRM(过程打分),多样性过滤,语种过滤等
步骤2:
- 如果使用SFT,则称为拒绝采样微调
- 如果是强化学习,则称为强化学习微调
步骤3:
- 如果使用更大的模型蒸馏,例如更大模型的蒙特卡洛采样蒸馏,一般不会有循环
- 如果使用本模型进行采样,或者PPO等算法,则会有循环
泛泛来说,常见强化微调的方式有下面几种:
- 蒸馏:使用蒙特卡洛、do_sample等方式从超大模型中采样大量优质数据,训练小模型
- 自我提升:从本模型中采样部分优质数据,筛选后训练本模型,循环执行
- on-policy RL:使用PPO、GRPO等方式循环训练
采样过程一般很漫长,比训练过程漫长的多。如果使用GPT等模型蒸馏数据,则需要购买token。因此,强化微调的时间成本和花费成本比较高,所以一般作为微调的补充机制出现,当然也有特例,例如最近的DeepSeek-R1。
DeepSeek-R1使用了GRPO算法从零使base模型涌现CoT能力,该方法需要大规模集群支持,且模型需要足够大才能发生能力涌现,在本文中不详细讨论。如果需要了解该过程,请查看论文解析。
有关强化微调的一些论文:
- 拒绝采样微调:https://arxiv.org/pdf/2308.01825
- ReST:https://arxiv.org/pdf/2308.08998
- B-STAR:https://arxiv.org/pdf/2412.17256
- DeepSeekMath:https://arxiv.org/pdf/2402.03300
- Qwen-math-PRM:https://arxiv.org/pdf/2501.07301
- DeepSeek-R1:https://github.com/deepseek-ai/DeepSeek-R1/tree/main
在LLaMA3之后,我们发现一个非常明显但却是不常被提及的特点:使用某个含有CoT的train数据集训练Instruct模型,再通过对应的test集进行评测,会发现test集评测效果变差。例如,使用gsm8k训练集训练llama3.1-8b-instruct,对生成的ckpt使用test集进行评测,会发现掉点。
这个特性主要来源于模型的知识遗忘问题。在模型厂商的微调中,会加入非常多的CoT数据集,模型在解决数学任务的时候,用到的能力很有可能不是来自于math数据集,而是来自arc数据集,这个推论有一些工作可以证明。在继续训练通用任务后,知识遗忘破坏了模型原有能力,导致了掉点。
然而,优先使用微调方式训练模型总是正确的。微调可以使模型快速适应数据集的分布,并且微调的成本很低。当有如下条件之一时使用强化微调:
- 已经微调过模型,能力不满足需求
- 需要更强的CoT能力
- 对基模型训练通用能力,而原始数据集已经导致模型效果无法提升
- 对应query的输出结果可以相对准确地评估好坏,例如结果清晰(数学,代码),过程清晰(翻译,风格)等
强化微调非常依赖于reward评估是否准确。如果评估结果不准确,可能导致模型训练原地震荡,甚至越训越差。
SWIFT支持sample命令,该命令就是用于模型采样。目前支持的采样方式有:
-
do_sample:sample方式对模型进行采样,该方式支持对开源模型进行采样,后续会支持模型蒸馏
- sample方式后续会支持URL采样,用于大模型蒸馏
-
mcts:蒙特卡洛采样,该方式在PR中,后续会支持
-
dvts:调研中
目前我们给出了一个较为通用的RFT脚本。该脚本适用于自我提升方式的训练,且支持动态调整采样温度值、PRM阈值等超参数,并且训练方式灵活可变(微调、DPO等;或者每次迭代重新训练原模型或继续训练上个迭代的模型,甚至加载上个迭代的所有训练状态等)。开发者可以在该脚本中增加其他数据过滤(生成的数据集中,id相同的行来自同一个query),例如多样性判断、语种判断等。
我们对该RFT脚本针对数学领域使用competition_math数据集进行了训练和评测,结果如下:
模型 | MATH指标 | 训练方式 | 迭代次数 | 训练后MATH指标 |
---|---|---|---|---|
LLaMA3.1_8b | 12.0 | SFT | 3 | 25.2(LLaMA3.1_8b_sft) |
LLaMA3.1_8b_sft | 25.2 | RFT | 2 | 32.4 |
LLaMA3.1_8b_instruct | 52.2 | SFT | 2 | 39.0 |
LLaMA3.1_8b_instruct | 52.2 | RFT | 3 | 58 |
Qwen2.5_math_7b_instruct | 79.6 | RFT | 2 | 83.2 |
可以看到,使用competition_math直接SFT后,instruct模型的掉点十分严重。而RFT后模型能力有提升,即使对Qwen2.5_math_7b_instruct这个SOTA的math模型也同样有一定提升空间。
特别地,针对Qwen2.5_math_7b_instruct我们测试了gsm8k的指标:
模型 | gsm8k指标 | RFT后gsm8k指标 |
---|---|---|
Qwen2.5_math_7b_instruct | 92.8 | 91.6 |
可以看到,RFT训练后gsm8k指标变化不大,并没有出现前述的掉点现象。
- 更多的采样方式,如MCTS
- 超大模型蒸馏训练
- 以PPO为主的on-policy训练