線(xiàn)性化注意力綜述:突破Softmax二次復(fù)雜度瓶頸的高效計(jì)算方案
大型語(yǔ)言模型在各個(gè)領(lǐng)域都展現(xiàn)出了卓越的性能,但其核心組件之一——softmax注意力機(jī)制在計(jì)算資源消耗方面存在顯著局限性。本文將深入探討如何通過(guò)替代方案實(shí)現(xiàn)線(xiàn)性時(shí)間復(fù)雜度,從而突破這一計(jì)算瓶頸。
注意力機(jī)制基礎(chǔ)理論
本文假設(shè)讀者已經(jīng)熟悉ChatGPT、Claude等模型及其底層的transformer架構(gòu)原理。注意力機(jī)制是這類(lèi)模型的核心組件。與傳統(tǒng)循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)將歷史信息壓縮存儲(chǔ)在固定維度的隱藏狀態(tài)中不同,注意力機(jī)制能夠直接訪(fǎng)問(wèn)和選擇性利用歷史信息。這種機(jī)制本質(zhì)上是在每次預(yù)測(cè)時(shí),根據(jù)當(dāng)前查詢(xún)動(dòng)態(tài)檢索最相關(guān)的歷史信息。
transformer架構(gòu)中的注意力機(jī)制通過(guò)鍵(key)、查詢(xún)(query)和值(value)三個(gè)嵌入向量實(shí)現(xiàn)信息的動(dòng)態(tài)檢索。具體而言transformer的注意力機(jī)制通過(guò)計(jì)算查詢(xún)向量與所有鍵向量的相似度,獲得注意力權(quán)重,再用這些權(quán)重對(duì)相應(yīng)的值向量進(jìn)行加權(quán)組合。這一計(jì)算過(guò)程可以形式化表示為:
這種機(jī)制使模型能夠在生成預(yù)測(cè)時(shí)有選擇地利用整個(gè)上下文中的相關(guān)信息。在此過(guò)程中使用softmax函數(shù)的目的是將原始相似度分?jǐn)?shù)轉(zhuǎn)換為概率分布,這在本質(zhì)上類(lèi)似于k近鄰機(jī)制,即相關(guān)性更高的鍵值對(duì)獲得更大的權(quán)重。
下面我們分析單個(gè)注意力層的計(jì)算復(fù)雜度:
Softmax機(jī)制的計(jì)算瓶頸
通過(guò)上述分析可以看出,標(biāo)準(zhǔn)注意力機(jī)制需要對(duì)NxN維度的矩陣執(zhí)行softmax運(yùn)算,這導(dǎo)致計(jì)算復(fù)雜度隨序列長(zhǎng)度呈二次方增長(zhǎng)。雖然這種計(jì)算復(fù)雜度對(duì)于較短序列是可接受的,但在處理長(zhǎng)度達(dá)到100k以上的序列時(shí),計(jì)算效率會(huì)顯著降低。
這一計(jì)算瓶頸促使研究者們思考:是否存在能夠降低計(jì)算復(fù)雜度的替代方案?這就引出了線(xiàn)性注意力機(jī)制的研究。
線(xiàn)性注意力機(jī)制
Katharopoulos等人提出了一種創(chuàng)新性的解決方案,他們通過(guò)將softmax指數(shù)函數(shù)重寫(xiě)為特征映射φ(x)的點(diǎn)積形式的核函數(shù),并利用矩陣乘法的結(jié)合律,成功將注意力計(jì)算重構(gòu)為線(xiàn)性形式。這一轉(zhuǎn)換過(guò)程如下圖所示:
在該方法中Katharopoulos等人采用elu(x) + 1作為特征映射函數(shù)φ(x)。任何能夠有效近似指數(shù)相似度的核特征映射都可以作為候選函數(shù)。這種方法的計(jì)算復(fù)雜度可以表示為:
這種重構(gòu)方法消除了計(jì)算完整N×N注意力矩陣的需求,將復(fù)雜度降低至O(Nd2),其中d表示嵌入維度。在大型語(yǔ)言模型中,通常序列長(zhǎng)度N遠(yuǎn)大于嵌入維度d,因此這種方法實(shí)際上實(shí)現(xiàn)了線(xiàn)性時(shí)間復(fù)雜度。
從循環(huán)的角度來(lái)看線(xiàn)性注意力機(jī)制:
為什么這種轉(zhuǎn)換在線(xiàn)性注意力中可行而在softmax中不可行呢?這是因?yàn)閟oftmax函數(shù)本質(zhì)上不可分離,無(wú)法分解為獨(dú)立項(xiàng)的乘積。在解碼階段,由于只需要維護(hù)d × d維度的狀態(tài)矩陣S_(n-1),每個(gè)token的生成復(fù)雜度僅為O(d2)。
但是這種計(jì)算效率的提升也帶來(lái)了一個(gè)重要的局限性。由于狀態(tài)矩陣S_(n-1)的維度限制為d × d,其信息存儲(chǔ)容量存在上限。比如:如果原始上下文需要存儲(chǔ)20d2的信息量,在壓縮過(guò)程中將不可避免地?fù)p失19d2的信息。這揭示了線(xiàn)性注意力機(jī)制中計(jì)算效率與內(nèi)存容量之間的根本性權(quán)衡:通過(guò)維持固定維度的狀態(tài)矩陣獲得計(jì)算效率的同時(shí),也限制了上下文信息的保存能力。這一矛盾促使研究者們引入門(mén)控機(jī)制來(lái)優(yōu)化這一權(quán)衡。
門(mén)控線(xiàn)性注意力機(jī)制
前文分析表明,在使用固定維度狀態(tài)矩陣優(yōu)化計(jì)算效率的過(guò)程中,信息損失是不可避免的。這引發(fā)了一個(gè)關(guān)鍵問(wèn)題:是否可以通過(guò)某種機(jī)制來(lái)優(yōu)化信息保留策略?門(mén)控機(jī)制正是為解決這一問(wèn)題而提出的。研究者們將其作為一種選擇性信息過(guò)濾機(jī)制,通過(guò)智能地選擇需要保留的信息來(lái)最小化信息損失的影響。門(mén)控并非新概念,在LSTM等架構(gòu)中已經(jīng)得到了廣泛應(yīng)用和驗(yàn)證。
門(mén)控線(xiàn)性注意力對(duì)狀態(tài)矩陣Sn的構(gòu)建方式進(jìn)行了改進(jìn):
門(mén)控函數(shù)G有多種可能的實(shí)現(xiàn)方式,不同的選擇會(huì)導(dǎo)致不同的模型特性:
這種架構(gòu)的一個(gè)顯著優(yōu)勢(shì)在于:門(mén)控函數(shù)僅依賴(lài)于當(dāng)前token x和可學(xué)習(xí)參數(shù),而不需要考慮完整的序列歷史。由于各個(gè)token的門(mén)控計(jì)算相互獨(dú)立,這種設(shè)計(jì)實(shí)現(xiàn)了訓(xùn)練過(guò)程的高效并行化,使得序列中所有token的門(mén)控運(yùn)算能夠同時(shí)進(jìn)行。
狀態(tài)空間模型
在處理序列數(shù)據(jù)(如文本或時(shí)間序列)時(shí),傳統(tǒng)方法通常依賴(lài)注意力機(jī)制或RNN。狀態(tài)空間模型(SSMs)提供了一種全新的視角:它將序列處理問(wèn)題轉(zhuǎn)化為類(lèi)似于CNN處理圖像的方式,通過(guò)卷積操作來(lái)捕獲序列信息。
狀態(tài)空間模型通過(guò)離散線(xiàn)性時(shí)不變系統(tǒng)來(lái)形式化這一思想:
這種方法與卷積運(yùn)算的關(guān)系可以表示為:
其中F表示從參數(shù)(A, B, c)學(xué)習(xí)得到的卷積核,*代表卷積運(yùn)算。
H3模型通過(guò)設(shè)計(jì)包含兩個(gè)互補(bǔ)SSM層的結(jié)構(gòu)化架構(gòu)來(lái)實(shí)現(xiàn)這一理論框架:
H3將輸入分解為三個(gè)通道以模擬K、Q、V結(jié)構(gòu),并通過(guò)組合兩個(gè)SSM層和兩個(gè)門(mén)控機(jī)制來(lái)模擬線(xiàn)性注意力的功能。實(shí)驗(yàn)結(jié)果表明,這種架構(gòu)設(shè)計(jì)在實(shí)際應(yīng)用中展現(xiàn)出了優(yōu)異的性能。
選擇性狀態(tài)空間模型
前文討論的門(mén)控線(xiàn)性注意力通過(guò)引入數(shù)據(jù)依賴(lài)的信息保留機(jī)制改進(jìn)了標(biāo)準(zhǔn)線(xiàn)性注意力。狀態(tài)空間模型同樣面臨類(lèi)似的局限性:控制狀態(tài)轉(zhuǎn)換和輸出的參數(shù)A、B和c都是固定且數(shù)據(jù)無(wú)關(guān)的。這意味著所有輸入都要經(jīng)過(guò)相同的靜態(tài)系統(tǒng)處理,而不考慮輸入的重要性或上下文信息。
為解決這一問(wèn)題,研究者們提出了通過(guò)時(shí)變動(dòng)力系統(tǒng)來(lái)擴(kuò)展SSMs:
這種擴(kuò)展的核心問(wèn)題在于如何將c_t、b_t和A_t參數(shù)化為輸入的函數(shù)。不同的參數(shù)化方案可能導(dǎo)致模型趨近于線(xiàn)性注意力或門(mén)控注意力機(jī)制。
Mamba模型通過(guò)選擇性SSM塊實(shí)現(xiàn)了這種時(shí)變狀態(tài)空間框架:
Mamba的創(chuàng)新之處在于用選擇性SSM取代了標(biāo)準(zhǔn)SSM,并結(jié)合輸出門(mén)控和額外的卷積操作來(lái)提升性能。這種架構(gòu)設(shè)計(jì)展示了如何將多個(gè)關(guān)鍵組件有機(jī)地整合為一個(gè)高效的序列建模系統(tǒng)。
總結(jié)
本文系統(tǒng)性地探討了高效序列建模架構(gòu)的演進(jìn)歷程。從傳統(tǒng)softmax注意力機(jī)制的二次計(jì)算復(fù)雜度限制出發(fā),研究者們發(fā)展出了線(xiàn)性注意力機(jī)制。通過(guò)核函數(shù)的重構(gòu),線(xiàn)性注意力實(shí)現(xiàn)了O(Nd2)的計(jì)算復(fù)雜度,但同時(shí)也面臨著固定維度狀態(tài)矩陣帶來(lái)的內(nèi)存限制。
這一限制促使了門(mén)控線(xiàn)性注意力的提出,通過(guò)引入門(mén)控機(jī)制實(shí)現(xiàn)選擇性信息保留。隨后,狀態(tài)空間模型提供了一個(gè)全新的視角,通過(guò)類(lèi)卷積操作處理序列數(shù)據(jù)。從基礎(chǔ)SSMs到時(shí)變系統(tǒng),再到選擇性SSMs的發(fā)展過(guò)程,與線(xiàn)性注意力到門(mén)控注意力的演進(jìn)具有相似性——在這兩個(gè)方向上,增強(qiáng)模型對(duì)輸入數(shù)據(jù)的適應(yīng)性都是提升性能的關(guān)鍵。
這些發(fā)展揭示了一個(gè)核心主題:計(jì)算效率與內(nèi)存容量之間的基本權(quán)衡。softmax注意力通過(guò)維持完整序列的注意力權(quán)重實(shí)現(xiàn)了出色的上下文學(xué)習(xí)能力,但付出了二次計(jì)算復(fù)雜度的代價(jià)。線(xiàn)性變體(包括SSMs)通過(guò)固定維度的狀態(tài)表示降低了計(jì)算復(fù)雜度,但也限制了保持詳細(xì)上下文信息的能力。這種權(quán)衡仍然是序列建模領(lǐng)域的核心挑戰(zhàn),繼續(xù)推動(dòng)著研究者們探索能夠更好平衡這些競(jìng)爭(zhēng)需求的架構(gòu)設(shè)計(jì)。