不是RNN的鍋!清華團(tuán)隊(duì)深入分析長上下文建模中的狀態(tài)崩潰,Mamba作者點(diǎn)贊
與Transformer相比,RNN模型的一大優(yōu)勢是應(yīng)對長序列的能力。
比如Mamba,內(nèi)部狀態(tài)大小始終保持不變,計(jì)算隨序列長度線性增長,吃得多,消化快。
理論雖如此,但實(shí)際情況卻是,目前的這些RNN模型在長上下文中的有效性并不能令人滿意。
為啥會(huì)這樣?空有效率但實(shí)際上能力不行?
近日,來自清華的研究團(tuán)隊(duì)對此進(jìn)行了深入的實(shí)驗(yàn)研究:
論文地址:https://arxiv.org/pdf/2410.07145v1
文章表明,Mamba這類RNN模型在長上下文中主要面臨兩個(gè)問題:
一是無法推斷比訓(xùn)練長度更長的輸入,原因是較短的訓(xùn)練數(shù)據(jù)導(dǎo)致了循環(huán)狀態(tài)過擬合;
二是內(nèi)存容量的上限,由于模型無法有效遺忘很久以前的信息,導(dǎo)致新的信息存不進(jìn)來了。
——這倆問題明顯不是RNN的鍋。
而經(jīng)過研究人員的對癥下藥,Mamba-2(370M)在256K上下文長度上達(dá)到了近乎完美的密鑰檢索精度。
所以結(jié)論就是,Mamba yes!「RNN神教」前景一片光明!
對此,Mamba的作者Albert Gu點(diǎn)贊轉(zhuǎn)發(fā),并發(fā)表了相當(dāng)詳細(xì)的見解:
「這是一篇很棒的論文(名字也很棒)—— 關(guān)于狀態(tài)空間模型(SSM)的狀態(tài)容量和長上下文能力的巧妙實(shí)驗(yàn)。」
令人驚訝的是,對于每個(gè)狀態(tài)大小 M,當(dāng)訓(xùn)練上下文長度達(dá)到或超過某個(gè)臨界值 K 時(shí),都會(huì)出現(xiàn)一個(gè)轉(zhuǎn)折點(diǎn),在這個(gè)點(diǎn)上 SSM 就能夠穩(wěn)健地實(shí)現(xiàn)長度泛化。
這是因?yàn)楫?dāng)上下文長度小于 K 時(shí),循環(huán)狀態(tài)沒有被充分利用,導(dǎo)致模型在訓(xùn)練期間會(huì)「過擬合」。但一旦通過足夠長序列的訓(xùn)練使模型的狀態(tài)容量得到充分利用,它就會(huì)自動(dòng)獲得泛化能力。
值得注意的是,K 與 M 竟然呈線性關(guān)系!—— 這表明每個(gè) token 可能存在某種固有的信息含量(即存在一個(gè)值 B,使得上下文中的每個(gè) token 對應(yīng) B 字節(jié)的循環(huán)狀態(tài))。這個(gè) B 值可能是由模型架構(gòu)決定的?
「反過來說,過分擔(dān)心循環(huán)模型的長度泛化問題可能是一個(gè)誤區(qū)。我們無需設(shè)計(jì)新機(jī)制或特殊的緩解措施:只需要在更長的序列上訓(xùn)練(因?yàn)槭蔷€性時(shí)間復(fù)雜度,所以不會(huì)增加計(jì)算開銷!),就能獲得更好的泛化效果。」
最后,Albert Gu用一句話總結(jié):要讓你的Mamba吃得飽飽的,它就能發(fā)揮出最佳狀態(tài)!
喂飽你的Mamba
先來復(fù)習(xí)一下基礎(chǔ)知識。
本文以Mamba2作為主要研究對象,內(nèi)部的計(jì)算表示為下圖中的并行結(jié)構(gòu):
整體的輸入輸出遵循SSM(也即RNN)的形式:
而把上圖中模塊內(nèi)部所有的計(jì)算寫出來,就是下面這一坨公式:
之前提到的兩個(gè)問題,核心在于模型的內(nèi)部狀態(tài),也就是ht的表現(xiàn)。
所以下面在探索問題和解決方案時(shí),咱們可以重點(diǎn)關(guān)注這些公式中,與ht計(jì)算相關(guān)的參數(shù)。
之前有研究表明,當(dāng)上下文長度超過其訓(xùn)練長度時(shí),Mamba-1和RWKV-4的性能會(huì)嚴(yán)重下降。
順著這個(gè)思路,研究人員在兩個(gè)方向上進(jìn)行了實(shí)驗(yàn)分析:狀態(tài)崩潰(STATE COLLAPSE)和容量上限(STATE CAPACITY)。
狀態(tài)崩潰
狀態(tài)崩潰(SC)指的是,RNN模型在輸入上表現(xiàn)出異常行為的時(shí)間比訓(xùn)練期間看到的時(shí)間更長的現(xiàn)象。
上圖展示了Mamba-2和RWKV-6在訓(xùn)練長度之外的語言建模損失。為了可控性和合成任意長度的提示,這個(gè)損失是在僅由「\n」字符組成的提示上計(jì)算的(稱為「newlines」提示)。
結(jié)果表明,當(dāng)上下文長度遠(yuǎn)大于其訓(xùn)練長度時(shí),兩個(gè)RNN的性能都會(huì)嚴(yán)重下降,最后就跟瞎猜差不多了。
語言建模可能無法反映下游能力,上圖給出了Mamba-2(在8K上下文窗口上訓(xùn)練)在密鑰檢索任務(wù)上的評估結(jié)果。
我們可以發(fā)現(xiàn),Mamba-2在8K上下文中具有近乎完美的檢索準(zhǔn)確性,但在序列長度超過16K后就沒法看了,無論模型參數(shù)量大小。
從上面的公式來看,這種結(jié)果可能出人意料,因?yàn)閮?nèi)部狀態(tài)ht的更新應(yīng)該具有穩(wěn)定的指數(shù)內(nèi)存衰減,即對于最后k個(gè)token具有良好的檢索準(zhǔn)確性。
問題出在哪里?
由于遞歸狀態(tài)的維度不會(huì)隨時(shí)間而變化,因此狀態(tài)崩潰期間行為的急劇變化一定是狀態(tài)值變化的結(jié)果。
作者對Mamba-2 370M中每一層的遞歸狀態(tài)進(jìn)行了統(tǒng)計(jì),發(fā)現(xiàn)當(dāng)上下文長度超過訓(xùn)練長度時(shí),一些頭部的平均值和方差會(huì)急劇變化:
圖5顯示了模型第38層第2個(gè)頭的狀態(tài),在t=20K時(shí)方差爆炸。從中可以發(fā)現(xiàn)這種方差爆炸在很大程度上可以歸因于少數(shù)異常通道,其余大多數(shù)通道則相對穩(wěn)定。
分析一下公式,與ht計(jì)算有關(guān)的?t、Bt和xt:
如上圖所示,雖然三者都是輸入的函數(shù),但xt相對穩(wěn)定,而Bt比?t更早發(fā)生爆炸,進(jìn)一步探索還能發(fā)現(xiàn)生成?t和Bt的卷積權(quán)重明顯更大。
作者認(rèn)為,產(chǎn)生SC的原因是,對于訓(xùn)練長度來說,狀態(tài)容量過大,模型能夠?qū)崿F(xiàn)強(qiáng)大的語言建模性能,而無需學(xué)習(xí)如何忘記。
上圖顯示了第一個(gè)token在不同時(shí)間步的內(nèi)存強(qiáng)度,作者發(fā)現(xiàn)爆炸的頭(第38層的第2、4、7個(gè)頭)強(qiáng)烈傾向于在訓(xùn)練長度內(nèi)保留所有信息,在t=8K時(shí)內(nèi)存強(qiáng)度超過0.8。
解決方案
為了緩解SC,使模型沿序列長度更好地泛化,作者提出了3種解決方案,總的思想是修改狀態(tài)的update規(guī)則來避免其溢出。
Method 1: Forget More and Remember Less
通過增加狀態(tài)衰減量(忘記更多)或減少輸入信息的數(shù)量(記住更少)來減少SC,作者選擇干預(yù)Bt和αt(分別控制輸入強(qiáng)度和內(nèi)存衰減強(qiáng)度)。
Method 2: State Normalization
在每次更新后對狀態(tài)進(jìn)行歸一化,以確保狀態(tài)的范數(shù)始終低于閾值:
PS:這種方式會(huì)將模型轉(zhuǎn)換為非線性RNN,無法以與原始模型相同的方式并行化,預(yù)填充速度要慢得多。
Method 3: Sliding Window by State Difference
利用狀態(tài)ht可以寫為加權(quán)和的形式,來模擬滑動(dòng)窗口機(jī)制,無需在每一步都從窗口的開頭重新處理。
此方法適用于所有可以寫成加權(quán)和的RNN,包括RWKV 5和6、RetNet、GLA等。盡管會(huì)使生成的計(jì)算和內(nèi)存成本翻倍,但仍然是一個(gè)可以接受的權(quán)衡,因?yàn)镽NN的生成成本比Transformer低很多。
以上3個(gè)是不需要訓(xùn)練的方案,而基于SC是由狀態(tài)參數(shù)過擬合引起的假設(shè),我們也可以嘗試使用超過狀態(tài)容量的序列長度來訓(xùn)練模型。
容量上限
根據(jù)以上的討論,當(dāng)且僅當(dāng)訓(xùn)練長度包含的信息少于狀態(tài)容量時(shí),才會(huì)發(fā)生SC,所以我們可以通過實(shí)驗(yàn)間接估計(jì)模型的狀態(tài)容量。
研究人員訓(xùn)練了多個(gè)具有不同狀態(tài)大小和訓(xùn)練長度的Mamba-2,并將SC未發(fā)生的最小訓(xùn)練長度視為狀態(tài)容量。
實(shí)驗(yàn)數(shù)據(jù)選擇RedPajama-V2,一個(gè)從CommonCrawl中提取的30T token的開放數(shù)據(jù)集,進(jìn)行去重以確保數(shù)據(jù)質(zhì)量。
在評估過程中,對長度超過16K token的文檔進(jìn)行抽樣,如果不夠長,則對其進(jìn)行拼接。
研究人員試驗(yàn)了具有不同狀態(tài)大小的模型配置,包括來自Mamba-2官方checkpoint的三個(gè)預(yù)訓(xùn)練模型,大小分別為130M、370M和780M,另外3個(gè)模型(36M、47M、85M)則從頭開始訓(xùn)練。
實(shí)驗(yàn)結(jié)果
上圖展示了在Mamba-2 780M上無訓(xùn)練長度泛化方法的結(jié)果。我們可以看到,雖然LongMamba大大提高了模型的長度泛化性(3倍以上),但它在較短的序列上會(huì)導(dǎo)致明顯更大的困惑度,并且仍然不可避免地表現(xiàn)出SC。
相比之下,本文的所有的方法都成功地抑制了SC,使模型能夠泛化到超過64K個(gè)token。
三種方案中,狀態(tài)歸一化在較短序列上的性能大大低于其他方法,這可能是因?yàn)闅w一化折疊狀態(tài)會(huì)改變heads之間的規(guī)范比率,破壞了學(xué)習(xí)機(jī)制。
上圖顯示了Mamba-2在語言建模和密鑰檢索方面的狀態(tài)容量。兩個(gè)圖中最右邊的數(shù)據(jù)點(diǎn)對應(yīng)于Mamba-2 370M。
左邊的圖可以擬合出一個(gè)線性關(guān)系,而右邊的圖則表明Mamba-2在密鑰檢索方面的容量與狀態(tài)大小呈指數(shù)級關(guān)系。
這是因?yàn)樯舷挛闹械男畔⒘坎粫?huì)隨著其長度的增加而增加。換句話說,模型存儲了恒定數(shù)量的信息,而狀態(tài)的組合數(shù)量隨著元素?cái)?shù)量呈指數(shù)增長。