Skip to content

Latest commit

 

History

History
103 lines (67 loc) · 6.54 KB

强化微调.md

File metadata and controls

103 lines (67 loc) · 6.54 KB

强化微调

强化微调是目前模型训练非常重要的功能之一,它本身的实现是多种多样的,SWIFT目前已经支持了强化微调所需要的原子能力,如采样、强化学习和微调。目前我们提供了拒绝采样微调的一个具体示例,可以查看这里

强化微调的概念

强化微调是从2022年开始(甚至更早)就被提出的概念。其方式一般有下列流程:

  1. 使用某个模型生成数据,或进行原始数据扩充
  2. 使用数据训练目标模型
  3. 如果有必要,重复上述过程

步骤1:

  • 如果生成数据的模型是更大的模型,如GPT、Qwen-Max、DeepSeek-V3/R1等,则该强化微调可以理解为蒸馏
  • 如果生成数据的模型是本模型,则可以理解为自我提升(self-improvement)微调
  • 如果采样过程是采样一个batch,然后通过KL散度和reward进行拟合训练并不断循环,则可以理解为PPO、GRPO等on-policy算法
  • 采样数据的算法包含蒙特卡洛采样、do_sample采样、group beam search、dvts等
  • 采样过程可以引入ORM(结果判断),PRM(过程打分),多样性过滤,语种过滤等

步骤2:

  • 如果使用SFT,则称为拒绝采样微调
  • 如果是强化学习,则称为强化学习微调

步骤3:

  • 如果使用更大的模型蒸馏,例如更大模型的蒙特卡洛采样蒸馏,一般不会有循环
  • 如果使用本模型进行采样,或者PPO等算法,则会有循环

泛泛来说,常见强化微调的方式有下面几种:

  1. 蒸馏:使用蒙特卡洛、do_sample等方式从超大模型中采样大量优质数据,训练小模型
  2. 自我提升:从本模型中采样部分优质数据,筛选后训练本模型,循环执行
  3. on-policy RL:使用PPO、GRPO等方式循环训练

采样过程一般很漫长,比训练过程漫长的多。如果使用GPT等模型蒸馏数据,则需要购买token。因此,强化微调的时间成本和花费成本比较高,所以一般作为微调的补充机制出现,当然也有特例,例如最近的DeepSeek-R1。

DeepSeek-R1使用了GRPO算法从零使base模型涌现CoT能力,该方法需要大规模集群支持,且模型需要足够大才能发生能力涌现,在本文中不详细讨论。如果需要了解该过程,请查看论文解析

有关强化微调的一些论文:

什么时候使用强化微调

在LLaMA3之后,我们发现一个非常明显但却是不常被提及的特点:使用某个含有CoT的train数据集训练Instruct模型,再通过对应的test集进行评测,会发现test集评测效果变差。例如,使用gsm8k训练集训练llama3.1-8b-instruct,对生成的ckpt使用test集进行评测,会发现掉点。

这个特性主要来源于模型的知识遗忘问题。在模型厂商的微调中,会加入非常多的CoT数据集,模型在解决数学任务的时候,用到的能力很有可能不是来自于math数据集,而是来自arc数据集,这个推论有一些工作可以证明。在继续训练通用任务后,知识遗忘破坏了模型原有能力,导致了掉点。

然而,优先使用微调方式训练模型总是正确的。微调可以使模型快速适应数据集的分布,并且微调的成本很低。当有如下条件之一时使用强化微调:

  1. 已经微调过模型,能力不满足需求
  2. 需要更强的CoT能力
  3. 对基模型训练通用能力,而原始数据集已经导致模型效果无法提升
  4. 对应query的输出结果可以相对准确地评估好坏,例如结果清晰(数学,代码),过程清晰(翻译,风格)等

强化微调非常依赖于reward评估是否准确。如果评估结果不准确,可能导致模型训练原地震荡,甚至越训越差。

SWIFT的实现

SWIFT支持sample命令,该命令就是用于模型采样。目前支持的采样方式有:

  • do_sample:sample方式对模型进行采样,该方式支持对开源模型进行采样,后续会支持模型蒸馏

    • sample方式后续会支持URL采样,用于大模型蒸馏
  • mcts:蒙特卡洛采样,该方式在PR中,后续会支持

  • dvts:调研中

目前我们给出了一个较为通用的RFT脚本。该脚本适用于自我提升方式的训练,且支持动态调整采样温度值、PRM阈值等超参数,并且训练方式灵活可变(微调、DPO等;或者每次迭代重新训练原模型或继续训练上个迭代的模型,甚至加载上个迭代的所有训练状态等)。开发者可以在该脚本中增加其他数据过滤(生成的数据集中,id相同的行来自同一个query),例如多样性判断、语种判断等。

实验结果

我们对该RFT脚本针对数学领域使用competition_math数据集进行了训练和评测,结果如下:

模型 MATH指标 训练方式 迭代次数 训练后MATH指标
LLaMA3.1_8b 12.0 SFT 3 25.2(LLaMA3.1_8b_sft)
LLaMA3.1_8b_sft 25.2 RFT 2 32.4
LLaMA3.1_8b_instruct 52.2 SFT 2 39.0
LLaMA3.1_8b_instruct 52.2 RFT 3 58
Qwen2.5_math_7b_instruct 79.6 RFT 2 83.2

可以看到,使用competition_math直接SFT后,instruct模型的掉点十分严重。而RFT后模型能力有提升,即使对Qwen2.5_math_7b_instruct这个SOTA的math模型也同样有一定提升空间。

特别地,针对Qwen2.5_math_7b_instruct我们测试了gsm8k的指标:

模型 gsm8k指标 RFT后gsm8k指标
Qwen2.5_math_7b_instruct 92.8 91.6

可以看到,RFT训练后gsm8k指标变化不大,并没有出现前述的掉点现象。

未来计划

  1. 更多的采样方式,如MCTS
  2. 超大模型蒸馏训练
  3. 以PPO为主的on-policy训练