「古董」GPU也能跑DeepSeek同款GRPO!顯存只需1/10,上下文爆漲10倍
開(kāi)源微調(diào)神器Unsloth帶著黑科技又來(lái)了:上次更新把GRPO需要的內(nèi)存見(jiàn)到了7GB,這次只需要5GB的VRAM,就能訓(xùn)練自己的推理模型Qwen2.5(1.5B),比上次要少2GB。
這次徹底把推理模型訓(xùn)練顯存打下來(lái)了!
這次把GRPO訓(xùn)練推理模型的上下文變長(zhǎng)10倍,同時(shí)需要的顯存少了90%。
使用最新的Unsloth,只要5GB顯存就能訓(xùn)練自己的推理模型,而且Qwen2.5-1.5B不會(huì)損失準(zhǔn)確率。
5GB顯存什么概念呢?
16年開(kāi)始發(fā)售的GPU比如GTX 1060的顯存都有8GB。16年GTX 1060放到現(xiàn)在,堪稱(chēng)電子古董!
目前,實(shí)現(xiàn)更長(zhǎng)的上下文是GRPO面臨的最大挑戰(zhàn)之一。
與其他GRPO LoRA/QLoRA實(shí)現(xiàn)相比,即使是基于Flash Attention 2(FA2)的實(shí)現(xiàn),Unsloth新推出的高效GRPO算法上下文長(zhǎng)度增加了10倍,同時(shí)使用的VRAM只要10%。
在配備TRL+FA2的GRPO設(shè)置中,Llama 3.1(8B)在20K上下文長(zhǎng)度下,訓(xùn)練需要510.8GB的VRAM。
而Unsloth將VRAM減少了90%,降至僅54.3GB。
減少長(zhǎng)上下文90%VRAM
和使用Flash Attention 2的標(biāo)準(zhǔn)實(shí)現(xiàn)相比,Unsloth使用多種技巧,巧妙地把GRPO的VRAM使用量減少了90%多!
在20K的上下文長(zhǎng)度下,每個(gè)提示生成8次,Unsloth在Llama-3.1-8B模型上僅使用54.3GB的VRAM,而標(biāo)準(zhǔn)實(shí)現(xiàn)需要510.8GB(Unsloth減少了90%)。這一切得益于下列3項(xiàng)突破:
- 全新設(shè)計(jì)的內(nèi)存高效線性算法:將GRPO的內(nèi)存使用量削減了8倍以上,節(jié)省了68.5GB的內(nèi)存。借助torch.compile,在num_generatinotallow=8和20K上下文長(zhǎng)度下,實(shí)際上還更快。
- 利用了Unsloth已發(fā)布的智能梯度checkpoint算法:將中間激活值異步卸載到系統(tǒng)RAM中,速度僅慢了1%。由于需要num_generatinotallow=8,這節(jié)省了高達(dá)372GB的VRAM。通過(guò)中間梯度累積,甚至可以進(jìn)一步減少內(nèi)存使用。
- 與底層推理引擎(vLLM)共享相同的GPU/CUDA內(nèi)存空間,不像其他包中的實(shí)現(xiàn)那樣。這又節(jié)省了16GB的VRAM。
Unsloth和基于Flash Attention 2(FA2)的標(biāo)準(zhǔn)實(shí)現(xiàn)內(nèi)存比較
在典型的GRPO標(biāo)準(zhǔn)實(shí)現(xiàn)中,需要?jiǎng)?chuàng)建兩個(gè)大小為(8,20K)的logits來(lái)計(jì)算GRPO損失。這需要2*2字節(jié)*8(生成次數(shù))*20K(上下文長(zhǎng)度)*128256(詞匯表大小)=78.3GB的VRAM。
Unsloth將長(zhǎng)上下文GRPO的內(nèi)存使用量削減了8倍,因此對(duì)于20K的上下文長(zhǎng)度,只需要額外的9.8GBVRAM!
還需要以16位格式存儲(chǔ)KV緩存。Llama3.18B有32層,K和V的大小均為1024。因此,對(duì)于20K的上下文長(zhǎng)度,內(nèi)存使用量=2*2字節(jié)*32層*20K上下文長(zhǎng)度*1024=每個(gè)批次2.5GB。
可以將vLLM的批次大小設(shè)置為8,但為了節(jié)省VRAM,在計(jì)算中將其保持為1。否則,需要20GB來(lái)存儲(chǔ)KV緩存。
數(shù)學(xué)原理
分組相對(duì)策略?xún)?yōu)化(Group Relative Policy Optimization,GRPO),出自DeepSeek去年發(fā)表的論文。
如果一生只能讀一篇DeepSeek的論文,網(wǎng)友建議選擇首次提出GRPO的DeepSeekMath論文。
論文鏈接:https://arxiv.org/abs/2402.03300
隨后在DeepSeek的論文中,利用GRPO算法創(chuàng)建了DeepSeek-R1。
發(fā)現(xiàn)的問(wèn)題
在這里利用了Hugging Face的TRL GRPO實(shí)現(xiàn)。
注意到,TRL實(shí)現(xiàn)的公式如下:
其中使用的是反向KL散度(而不是正向KL散度)。β是一個(gè)設(shè)為0.04的縮放因子,A是考慮所有獎(jiǎng)勵(lì)函數(shù)后得到的優(yōu)勢(shì)值。q是新訓(xùn)練的模型,P是原始參考模型。
然后注意到,該實(shí)現(xiàn)將反向KL散度計(jì)算為:
但這真的是正確的嗎?
首先嘗試推導(dǎo)并整理類(lèi)似項(xiàng):
這意味著什么?實(shí)現(xiàn)中可能缺少一個(gè)與q(新分布項(xiàng))的乘法嗎?
但這似乎是正確的,和DeepSeek-Math論文第14頁(yè)首次引入GRPO時(shí)一樣。
DeepSeek-Math論文第14頁(yè):在損失函數(shù)中添加KL散度,正則化GRPO算法
同樣,John Schulman的博客也提到,反向KL項(xiàng)的無(wú)偏估計(jì),實(shí)際上并不需要額外的q項(xiàng)。
鏈接地址:http://joschu.net/blog/kl-approx.html
在博客中看到:
還發(fā)現(xiàn)了一個(gè)有趣的現(xiàn)象:
torch.exp(q-q.detach()) * advantages.unsqueeze(1)
這應(yīng)該等于1,對(duì)嗎?
Hugging Face的TRL GRPO實(shí)現(xiàn)
實(shí)際上,發(fā)現(xiàn)這是必要的——似乎自動(dòng)梯度autograd引擎可能無(wú)法正確傳播梯度。
因此,進(jìn)行了4個(gè)實(shí)驗(yàn):
- 使用參考實(shí)現(xiàn)的常規(guī)GRPO(紅線)
- 移除detach代碼(藍(lán)線)
- 按照之前討論的完整反向KL,添加額外項(xiàng)(黃線)
- 使用正向KL散度代替(綠線)
總體來(lái)說(shuō),移除detach顯然會(huì)破壞訓(xùn)練,所以必須保留它——這很可能需要進(jìn)一步調(diào)查。其他實(shí)現(xiàn)似乎也類(lèi)似?可能需要運(yùn)行模型更長(zhǎng)時(shí)間,以觀察不同的效果。
在所有實(shí)現(xiàn)中,還利用了logsumexp技巧:
Unsloth高效GRPO算法
但沒(méi)想到華人工程師Horace He的線性交叉熵實(shí)現(xiàn),帶給unsloth靈感并成功應(yīng)用于GRPO!
Horace He,在Meta從事PyTorch相關(guān)工作
實(shí)際上,unsloth發(fā)現(xiàn)了一些令人驚訝的要點(diǎn):
1 GRPO參考實(shí)現(xiàn)使用的是反向KL散度,而不是正向KL散度。
2 如果不正確處理,在float16混合精度(以及float8)上直接實(shí)現(xiàn)線性交叉熵,并使用自動(dòng)混合精度縮放機(jī)制,會(huì)導(dǎo)致崩潰。
3 發(fā)現(xiàn)了GRPO損失實(shí)現(xiàn)中的其他一些奇怪之處,主要是在反向KL散度的公式表述方面。
線性交叉商鏈接:https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899
其他功能
GRPO的完整日志記錄
之前,unsloth只顯示總聚合獎(jiǎng)勵(lì)函數(shù)本身,新版本為所有獎(jiǎng)勵(lì)函數(shù)提供完整的日志記錄詳情!
也不再需要調(diào)用函數(shù)來(lái)給GRPO打補(bǔ)丁了!也就是說(shuō),新版本會(huì)自動(dòng)處理,可以刪除下列代碼:
from unsloth import PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)
vLLM推理選項(xiàng)
現(xiàn)在在vLLM中還能使用FP8 KV緩存,這可以在較新的GPU(RTX 3090、A100及更新型號(hào))上將KV緩存空間使用量減少2倍。
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
max_seq_length = max_seq_length,
load_in_4bit = True,
fast_inference = True,
max_lora_rank = lora_rank,
gpu_memory_utilization = 0.6,
float8_kv_cache = True,
)
如果想在vLLM中使用min_p=0.1或其他采樣參數(shù),也支持傳遞vLLM的SamplingParams參數(shù)中的任何內(nèi)容!
max_prompt_length = 256
from trl import GRPOConfig, GRPOTrainer
from unsloth import vLLMSamplingParams
vllm_sampling_params = vLLMSamplingParams(
min_p = 0.1,
seed = 3407,
...
)
training_args = GRPOConfig(
...
vllm_sampling_params = vllm_sampling_params,
temperature = 1.5,
)