如何在LLM訓(xùn)練過程中精妙設(shè)計(jì)SFT與RL步驟—— LLM訓(xùn)練框架推薦 原創(chuàng)
一種可以“自適應(yīng)切換SFT與RL”的訓(xùn)練框架分享。
大家應(yīng)該都還記得,DeepSeek-R1的“SFT->RL->增強(qiáng)SFT->增強(qiáng)RL”這種左腳踩右腳直接起飛的操作,這說明監(jiān)督微調(diào)(SFT)與強(qiáng)化學(xué)習(xí)(RL)交替訓(xùn)練的訓(xùn)練范式確實(shí)可以提高模型性能。
很多大佬也有自己做小規(guī)模實(shí)驗(yàn),在進(jìn)行新的訓(xùn)練范式探索:
- 預(yù)訓(xùn)練后做兩次SFT接一次RL
- 預(yù)訓(xùn)練后先RL再SFT
- ....
那么如何設(shè)計(jì)訓(xùn)練框架能實(shí)現(xiàn)效果最優(yōu)呢?
本篇分享一種可以“自適應(yīng)切換SFT與RL”的訓(xùn)練框架;這是念空科技聯(lián)合上海交通大學(xué)計(jì)算機(jī)學(xué)院投的新論文 《Step-wise Adaptive Integration of Supervised Fine-tuning and Reinforcement Learning for Task-Specific LLMs》。
下面是一個(gè)快捷目錄。
1. 待解決的問題
2. 論文方法
3. 實(shí)驗(yàn)結(jié)果
4. 其他可發(fā)散的點(diǎn)
一、待解決的問題
目前這種 “固定步驟的SFT和RL交替” 靜態(tài)混合訓(xùn)練方法可能會(huì)帶來一些問題,比如,一種訓(xùn)練范式直接切換到另一種時(shí),可能會(huì)導(dǎo)致模型下降;不同階段任務(wù)著重訓(xùn)練的知識(shí)不同,模型很可能災(zāi)難性遺忘或者陷入局部最優(yōu)等,最終影響訓(xùn)練的連續(xù)性和穩(wěn)定性。
這篇論文主要解決的就是如何設(shè)計(jì)訓(xùn)練步驟的問題:如何設(shè)計(jì)一個(gè)最優(yōu)的訓(xùn)練框架來保證LLM的訓(xùn)練穩(wěn)定性。
二、論文方法
論文提出了一個(gè)名為SASR(Step-wise Adaptive Integration of Supervised Fine-tuning and Reinforcement Learning)的逐步自適應(yīng)混合訓(xùn)練框架,通過理論統(tǒng)一監(jiān)督微調(diào)(SFT)和強(qiáng)化學(xué)習(xí)(RL),并動(dòng)態(tài)平衡兩者在整個(gè)優(yōu)化過程中的比例。
主要包含兩個(gè)階段:
第一階段:Warm-up Phase
首先使用小規(guī)模的(問題,鏈?zhǔn)剿伎迹?shù)據(jù)對進(jìn)行SFT,以建立模型的基本推理能力。這些數(shù)據(jù)對包括輸入問題的標(biāo)記序列和對應(yīng)的鏈?zhǔn)剿伎纪评砺窂剑瑤椭P蛯W(xué)習(xí)結(jié)構(gòu)化的問題解決策略。
在第一階段中通過最小化負(fù)對數(shù)似然(NLL)損失來最大化真實(shí)序列的似然,從而更新模型參數(shù)。
loss長這樣,at是思維鏈中的token第t個(gè)token標(biāo)記,st是步驟t中的上下文狀態(tài),包括之前所有生成的標(biāo)記。
第二階段:Hybrid Training Phase
在Warm-up之后,逐步開始自適應(yīng)混合訓(xùn)練,把SFT和GRPO結(jié)合起來。
GRPO通過組間比較擴(kuò)展策略優(yōu)化,通過采樣當(dāng)前和舊策略的輸出,并根據(jù)相對優(yōu)勢將它們分為高優(yōu)勢組和低優(yōu)勢組,然后結(jié)合優(yōu)勢最大化和KL正則化來更新策略。
另外此階段根據(jù)當(dāng)前模型的訓(xùn)練狀態(tài)來動(dòng)態(tài)調(diào)整SFT和GRPO的比例。具體來說,通過比較當(dāng)前梯度范數(shù)與Warm-up階段記錄的梯度范數(shù),動(dòng)態(tài)更新兩者的比例。
loss長這樣, πθold 是更新前的上一個(gè)策略,πref 表示參考策略(通常是初始 SFT 模型),ε控制策略更新的裁剪范圍,β調(diào)整 KL 正則化的強(qiáng)度。比率 πθ πθold 衡量每個(gè)step的新策略與舊策略的偏差程度。
那么如何進(jìn)行動(dòng)態(tài)比例的分配呢?主要通過監(jiān)測訓(xùn)練過程中的梯度范數(shù)和模型策略相對于原始數(shù)據(jù)分布的KL散度,當(dāng)模型與原始數(shù)據(jù)分布的偏差較大時(shí),增加SFT的權(quán)重;當(dāng)模型接近原始數(shù)據(jù)分布時(shí),增加GRPO的權(quán)重。
最終整體損失函數(shù) L(θ)如下
這里引入了 I(t) 作為狀態(tài)函數(shù),它根據(jù)當(dāng)前模型的訓(xùn)練狀態(tài) t 返回訓(xùn)練范式?jīng)Q策變量 I(t)。
與傳統(tǒng)的 Hybrid方法在一個(gè) epoch 內(nèi)使用固定的訓(xùn)練范式相比,SASR 采用更細(xì)粒度的訓(xùn)練步驟 s 作為訓(xùn)練單元,可實(shí)現(xiàn)更靈活的自適應(yīng)調(diào)整。
下面這段偽代碼可以輔助大家很快理解他的思路。
另外論文還進(jìn)行了理論分析與實(shí)驗(yàn)驗(yàn)證,建立了SFT損失的梯度范數(shù)與KL散度之間的關(guān)系,證明了SASR在避免SFT引起的過擬合、緩解RL導(dǎo)致的模型坍塌以及克服靜態(tài)混合訓(xùn)練的局限的優(yōu)勢。
三、實(shí)驗(yàn)結(jié)果
模型設(shè)計(jì)了三個(gè)實(shí)驗(yàn):
- GSM8K(小學(xué)水平數(shù)學(xué)算術(shù))+ DeepSeek-R1-Distill-Qwen-1.5B模型:模型的準(zhǔn)確率從63.8%提高到80.3%,接近GPT-4o的水平
- KK(邏輯推理)+ Qwen2.5-1.5B-Instruct模型:平均準(zhǔn)確率提升9%,超過了GPT-4o
- MATH(數(shù)學(xué)競賽、公式)+ Qwen2.5-0.5B-Instruct模型:平均準(zhǔn)確率提升了9%,超過了GPT-4o
四、其他可發(fā)散的點(diǎn)
這篇論文感覺還是有很多可以繼續(xù)去發(fā)散的,比如跟除了GPRO的其他強(qiáng)化學(xué)習(xí)算法結(jié)合,推廣到多模態(tài),改進(jìn)動(dòng)態(tài)調(diào)整策略等等。有想法的朋友們可以一起交流一下~
參考文獻(xiàn)
[1] ???https://arxiv.org/pdf/2505.13026??
本文轉(zhuǎn)載自??瓦力算法學(xué)研所??,作者:喜歡瓦力的卷卷
