?新一代注意力機(jī)制Lightning Attention-2:無(wú)限序列長(zhǎng)度、恒定算力開(kāi)銷、更高建模精度
大語(yǔ)言模型序列長(zhǎng)度的限制,極大地制約了其在人工智能領(lǐng)域的應(yīng)用,比如多輪對(duì)話、長(zhǎng)文本理解、多模態(tài)數(shù)據(jù)的處理與生成等。造成這一限制的根本原因在于當(dāng)前大語(yǔ)言模型均采用的 Transformer 架構(gòu)有著相對(duì)于序列長(zhǎng)度的二次計(jì)算復(fù)雜度。這意味著隨著序列長(zhǎng)度的增加,需要的計(jì)算資源成幾何倍數(shù)提升。如何高效地處理長(zhǎng)序列一直是大語(yǔ)言模型的挑戰(zhàn)之一。
之前的方法往往集中在如何讓大語(yǔ)言模型在推理階段適應(yīng)更長(zhǎng)的序列。比如采用 Alibi 或者類似的相對(duì)位置編碼的方式來(lái)讓模型自適應(yīng)不同的輸入序列長(zhǎng)度,亦或采用對(duì) RoPE 等類似的相對(duì)位置編碼進(jìn)行差值的方式,在已經(jīng)完成訓(xùn)練的模型上再進(jìn)行進(jìn)一步的短暫精調(diào)來(lái)達(dá)到擴(kuò)增序列長(zhǎng)度的目的。這些方法只是讓大模型具有了一定的長(zhǎng)序列建模能力,但實(shí)際訓(xùn)練和推理的開(kāi)銷并沒(méi)有減少。
OpenNLPLab 團(tuán)隊(duì)嘗試一勞永逸地解決大語(yǔ)言模型長(zhǎng)序列問(wèn)題。他們提出并開(kāi)源了 Lightning Attention-2—— 一種新型的線性注意力機(jī)制,讓長(zhǎng)序列的訓(xùn)練和推理成本與 1K 序列長(zhǎng)度的一致。在遇到顯存瓶頸之前,無(wú)限地增大序列長(zhǎng)度并不會(huì)對(duì)于模型訓(xùn)練速度產(chǎn)生負(fù)面影響。這讓無(wú)限長(zhǎng)度預(yù)訓(xùn)練成為了可能。同時(shí),超長(zhǎng)文本的推理成本也與 1K Tokens 的成本一致甚至更少,這將極大地減少當(dāng)前大語(yǔ)言模型的推理成本。如下圖所示,在 400M、1B、3B 的模型大小下,隨著序列長(zhǎng)度的增加,F(xiàn)lashAttention2 加持的 LLaMA 的訓(xùn)練速度開(kāi)始快速下降,然而 Lightning Attention-2 加持的 TansNormerLLM 的速度幾無(wú)變化。
圖 1
- 論文:Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models
- 論文地址:https://arxiv.org/pdf/2401.04658.pdf
- 開(kāi)源地址:https://github.com/OpenNLPLab/lightning-attention
Lightning Attention-2 簡(jiǎn)介
讓大模型的預(yù)訓(xùn)練速度在不同序列長(zhǎng)度下保持一致,這聽(tīng)起來(lái)是一個(gè)不可能的任務(wù)。事實(shí)上,如果一個(gè)注意力機(jī)制的計(jì)算復(fù)雜度相對(duì)于序列長(zhǎng)度保持線性關(guān)系的話,就可以實(shí)現(xiàn)這一點(diǎn)。自 2020 年線性注意力【https://arxiv.org/abs/2006.16236】橫空出世以來(lái),研究人員一直在為了線性注意力的實(shí)際效率符合它的理論線性計(jì)算復(fù)雜度而努力。在 2023 年之前,大多數(shù)的關(guān)于線性注意力的工作均集中在對(duì)齊它們與 Transformer 的精度上。終于在 2023 年中期,改進(jìn)的線性注意力機(jī)制【https://arxiv.org/abs/2307.14995】在精度上可以與最先進(jìn)的 Transformer 架構(gòu)對(duì)齊。然而,線性注意力中將計(jì)算復(fù)雜度變成線性的最關(guān)鍵的 “左乘變右乘” 的計(jì)算 Trick (如下圖所示),在實(shí)際實(shí)現(xiàn)中遠(yuǎn)慢于直接左乘的算法。其原因在于右乘的實(shí)現(xiàn)需要用到包含大量循環(huán)操作的累積求和(cumsum),大量的 IO 操作使得右乘的效率遠(yuǎn)低于左乘。
圖 2
為了更好的理解 Lightning Attention-2 的思路,讓我們先回顧下傳統(tǒng) softmax attention 的計(jì)算公式:O=softmax ((QK^T)⊙M_) V,其中 Q, K, V, M, O 分別為 query, key, value, mask 和輸出矩陣,這里的 M 在單向任務(wù)(如 GPT)中是一個(gè)下三角的全 1 矩陣,在雙向任務(wù)(如 Bert)中則可以忽略,即雙向任務(wù)沒(méi)有 mask 矩陣。
作者將 Lightning Attention-2 的整體思路總結(jié)為以下三點(diǎn)進(jìn)行解釋:
1. Linear Attention 的核心思想之一就是去除了計(jì)算成本高昂的 softmax 算子,使 Attention 的計(jì)算公式可以寫為 O=((QK^T)⊙M_) V。但由于單向任務(wù)中 mask 矩陣 M 的存在,使得該形式依然只能進(jìn)行左乘計(jì)算,從而不能獲得 O (N) 的復(fù)雜度。但對(duì)于雙向任務(wù),由于沒(méi)有沒(méi)有 mask 矩陣,Linear Attention 的計(jì)算公式可以進(jìn)一步簡(jiǎn)化為 O=(QK^T) V。Linear Attention 的精妙之處在于,僅僅利用簡(jiǎn)單的矩陣乘法結(jié)合律,其計(jì)算公式就可以進(jìn)一步轉(zhuǎn)化為:O=Q (K^T V),這種計(jì)算形式被稱為右乘,相對(duì)應(yīng)的前者為左乘。通過(guò)圖 2 可以直觀地理解到 Linear Attention 在雙向任務(wù)中可以達(dá)到誘人的 O (N) 復(fù)雜度!
2. 但是隨著 decoder-only 的 GPT 形式的模型逐漸成為 LLM 的事實(shí)標(biāo)準(zhǔn),如何利用 Linear Attention 的右乘特性加速單向任務(wù)成為了亟待解決的難題。為了解決這個(gè)問(wèn)題,本文作者提出了利用 “分而治之” 的思想,將注意力矩陣的計(jì)算分為對(duì)角陣和非對(duì)角陣兩種形式,并采用不同的方式對(duì)他們進(jìn)行計(jì)算。如圖 3 所示,Linear Attention-2 利用計(jì)算機(jī)領(lǐng)域常用的 Tiling 思想,將 Q, K, V 矩陣分別切分為了相同數(shù)量的塊 (blocks)。其中 block 自身(intra-block)的計(jì)算由于 mask 矩陣的存在,依然保留左乘計(jì)算的方式,具有 O (N^2) 的復(fù)雜度;而 block 之間(inter-block)的計(jì)算由于沒(méi)有 mask 矩陣的存在,可以采用右乘計(jì)算方式,從而享受到 O (N) 的復(fù)雜度。兩者分別計(jì)算完成后,可以直接相加得到對(duì)應(yīng)第 i 塊的 Linear Attention 輸出 Oi。同時(shí),通過(guò) cumsum 對(duì) KV 的狀態(tài)進(jìn)行累積以在下一個(gè) block 的計(jì)算中使用。這樣就得到了整個(gè) Lightning Attention-2 的算法復(fù)雜度為 intra-block 的 O (N^2) 和 inter-block 的 O (N) 的 Trade-off。怎么取得更好的 Trade-off 則是由 Tiling 的 block size 決定的。
3. 細(xì)心的讀者會(huì)發(fā)現(xiàn),以上的過(guò)程只是 Lightning Attention-2 的算法部分,之所以取名 Lightning 是因?yàn)樽髡叱浞挚紤]了該算法過(guò)程在 GPU 硬件執(zhí)行過(guò)程中的效率問(wèn)題。受到 FlashAttention 系列工作的啟發(fā),實(shí)際在 GPU 上進(jìn)行計(jì)算的時(shí)候,作者將切分后的 Q_i, K_i, V_i 張量從 GPU 內(nèi)部速度更慢容量更大的 HBM 搬運(yùn)到速度更快容量更小的 SRAM 上進(jìn)行計(jì)算,從而減少大量的 memory IO 開(kāi)銷。當(dāng)該 block 完成 Linear Attention 的計(jì)算之后,其輸出結(jié)果 O_i 又會(huì)被搬回至 HBM。重復(fù)這個(gè)過(guò)程直到所有 block 被處理完畢即可。
想要了解更多細(xì)節(jié)的讀者可以仔細(xì)閱讀本文中的 Algorithm 1 和 Algorithm 2,以及論文中的詳細(xì)推導(dǎo)過(guò)程。Algorithm 以及推導(dǎo)過(guò)程都對(duì) Lightning Attention-2 的前向和反向過(guò)程進(jìn)行了區(qū)分,可以幫助讀者有更深入的理解。
圖 3
Lightning Attention-2 精度對(duì)比
研究人員首先在小規(guī)模(400M)參數(shù)模型上對(duì)比了 Lightning Attention-2 與 Lightning Attention-1 的精度區(qū)別,如下圖所示,二者幾無(wú)差別。
隨后研究人員在 1B、3B 上將 Lightning Attention-2 加持的 TransNormerLLM(TNL-LA2)與其它先進(jìn)的非 Transformer 架構(gòu)的網(wǎng)絡(luò)以及 FlashAttention2 加持的 LLaMA 在相同的語(yǔ)料下做了對(duì)比。如下圖所示,TNL-LA2 與 LLaMA 保持了相似的趨勢(shì),并且 loss 的表現(xiàn)更優(yōu)。這個(gè)實(shí)驗(yàn)表明,Lightning Attention-2 在語(yǔ)言建模方面有著不遜于最先進(jìn)的 Transformer 架構(gòu)的精度表現(xiàn)。
在大語(yǔ)言模型任務(wù)中,研究人員對(duì)比了 TNL-LA2 15B 與 Pythia 在類似大小下的大模型常見(jiàn) Benchmark 的結(jié)果。如下表所示,在吃掉了相同 tokens 的條件下,TNL-LA2 在常識(shí)推理和多項(xiàng)選擇綜合能力上均略高于基于 Softmax 的注意力的 Pythia 模型。
Lightning Attention-2 速度對(duì)比
研究人員對(duì) Lightning Attention-2 與 FlashAttention2 進(jìn)行了單模塊速度與顯存占用對(duì)比。如下圖所示,相比于 Lightning Attention-1 和 FlashAttention2,在速度上,Lightning Attention-2 表現(xiàn)出了相比于序列長(zhǎng)度的嚴(yán)格線性增長(zhǎng)。在顯存占用上,三者均顯示出了類似的趨勢(shì),但 Lightning Attention-2 的顯存占用更小。這個(gè)的原因是 FlashAttention2 和 Lightning Attention-1 的顯存占用也是近似線性的。
筆者注意到,這篇文章主要關(guān)注點(diǎn)在解決線性注意力網(wǎng)絡(luò)的訓(xùn)練速度上,并實(shí)現(xiàn)了任意長(zhǎng)度的長(zhǎng)序列與 1K 序列相似的訓(xùn)練速度。在推理速度上,并沒(méi)有過(guò)多的介紹。這是因?yàn)榫€性注意力在推理的時(shí)候可以無(wú)損地轉(zhuǎn)化為 RNN 模式,從而達(dá)到類似的效果,即推理單 token 的速度恒定。對(duì)于 Transformer 來(lái)說(shuō),當(dāng)前 token 的推理速度與它之前的 token 數(shù)量相關(guān)。
筆者測(cè)試了 Lightning Attention-1 加持的 TransNormerLLM-7B 與常見(jiàn)的 7B 模型在推理速度上的對(duì)比。如下圖所示,在近似參數(shù)大小下,Lightning Attention-1 的吞吐速度是百川的 4 倍,ChatGLM 的 3.5 倍以上,顯示出了優(yōu)異的推理速度優(yōu)勢(shì)。
小結(jié)
Lightning Attention-2 代表了線性注意力機(jī)制的重大進(jìn)步,使其無(wú)論在精度還是速度上均可以完美的替換傳統(tǒng)的 Softmax 注意力,為今后越來(lái)越大的模型提供了可持續(xù)擴(kuò)展的能力,并提供了一條以更高效率處理無(wú)限長(zhǎng)序列的途徑。OpenNLPLab 團(tuán)隊(duì)在未來(lái)將研究基于線性注意力機(jī)制的序列并行算法,以解決當(dāng)前遇到的顯存屏障問(wèn)題。