最多400萬token上下文、推理提速22倍,StreamingLLM火了,已獲GitHub 2.5K星
如果你體驗過與任何一款對話式 AI 機器人的交流,你一定能想起某些極具「挫敗感」的時刻。比如,你在前一天的對話中講述過的要點,被 AI 忘得干干凈凈……
這是因為當前的多數 LLM 只能記住有限的上下文,就像為考試而臨時抱佛腳的學生,稍加盤問就會「露出馬腳」。
想象一下,如果 AI 助手在聊天中能夠根據上下文參考幾周或幾個月前的對話,或者,你可以要求 AI 助手總結長達數千頁的報告,這樣的能力是不是令人羨慕?
為了讓 LLM 記住更多、記得更好,研究者們正在不斷努力。最近,來自 MIT、Meta AI、CMU 的研究者提出了一種名為「StreamingLLM」的方法,使語言模型能夠流暢地處理無窮無盡的文本。
- 論文地址:https://arxiv.org/pdf/2309.17453.pdf
- 項目地址:https://github.com/mit-han-lab/streaming-llm
StreamingLLM 的工作原理是識別并保存模型固有的「注意力池」(attention sinks)錨定其推理的初始 token。結合最近 token 的滾動緩存,StreamingLLM 的推理速度提高了 22 倍,而不需要犧牲任何的準確性。短短幾天,該項目在 GitHub 平臺已斬獲 2.5K 星:
具體地說,StreamingLLM 使得語言模型能準確無誤地記住上一場比賽的觸地得分、新生兒的名字、冗長的合同或辯論內容,就像升級了 AI 助理的內存,可以完美地處理更繁重的工作。
接下來讓我們看看技術細節。
方法創新
通常,LLM 在預訓練時受到注意力窗口的限制。盡管為擴大這一窗口大小、提高訓練和推理效率,此前已有很多工作,但 LLM 可接受的序列長度仍然是有限的,這對于持久部署來說并不友好。
在這篇論文中,研究者首先介紹了 LLM 流應用的概念,并提出了一個問題:「能否在不犧牲效率和性能的情況下以無限長輸入部署 LLM?」
將 LLM 應用于無限長輸入流時,會面臨兩個主要挑戰:
1、在解碼階段,基于 transformer 的 LLM 會緩存所有先前 token 的 Key 和 Value 狀態(KV),如圖 1 (a) 所示,這可能會導致內存使用過多,并增加解碼延遲;
2、現有模型的長度外推能力有限,即當序列長度超過預訓練時設定的注意力窗口大小時,其性能就會下降。
一種直觀的方法被稱為窗口注意力(Window Attention)(如圖 1 b),這種方法只在最近 token 的 KV 狀態上保持一個固定大小的滑動窗口,雖然能確保在緩存填滿后仍能保持穩定的內存使用率和解碼速度,但一旦序列長度超過緩存大小,甚至只是驅逐第一個 token 的 KV,模型就會崩潰。另一種方法是重新計算滑動窗口(如圖 1 c 所示),這種方法會為每個生成的 token 重建最近 token 的 KV 狀態,雖然性能強大,但需要在窗口內計算二次注意力,因此速度明顯更慢,在實際的流應用中并不理想。
在理解窗口注意力失效的過程中,研究者發現了自回歸 LLM 的一個有趣現象:如圖 2 所示,大量注意力分數被分配給了初始 token,而不管這些 token 與語言建模任務是否相關。
研究者將這些 token 稱為「注意力池」:盡管它們缺乏語義上的意義,但卻占據了大量的注意力分數。研究者將這一現象歸因于于 Softmax(要求所有上下文 token 的注意力分數總和為 1),即使當前查詢在許多以前的 token 中沒有很強的匹配,模型仍然需要將這些不需要的注意力值分配到某處,從而使其總和為 1。初始 token 成為「池」的原因很直觀:由于自回歸語言建模的特性,初始 token 對幾乎所有后續 token 都是可見的,這使得它們更容易被訓練成注意力池。
基于上述洞察,研究者提出了 StreamingLLM,這是一個簡單而高效的框架,它可以讓使用有限注意力窗口訓練的注意力模型在不進行微調的情況下處理無限長的文本。
StreamingLLM 利用了注意力池具有高注意力值這一事實,保留這些注意力池可以使注意力分數分布接近正態分布。因此,StreamingLLM 只需保留注意力池 token 的 KV 值(只需 4 個初始 token 即可)和滑動窗口的 KV 值,就能錨定注意力計算并穩定模型的性能。
使用 StreamingLLM,包括 Llama-2-[7,13,70] B、MPT-[7,30] B、Falcon-[7,40] B 和 Pythia [2.9,6.9,12] B 在內的模型可以可靠地模擬 400 萬個 token,甚至更多。
與唯一可行的 baseline—— 重新計算滑動窗口相比,StreamingLLM 的速度提高了 22.2 倍,而沒有損耗性能。
測評
在實驗環節,如圖 3 所示,在跨度為 20K token 的文本上,StreamingLLM 的困惑度可以與 Oracle 基線(重新計算滑動窗口)相媲美。同時,當輸入長度超過預訓練窗口時,密集注意力就會失效,而當輸入長度超過緩存大小時,窗口注意力就會陷入困境,導致初始 token 被剔除。
圖 5 進一步證實了 StreamingLLM 可以可靠地處理非常規規模的文本,包括 400 多萬個 token,涵蓋了各種模型系列和規模。這包括 Llama-2-[7,13,70] B、Falcon-[7,40] B、Pythia-[2.8,6.9,12] B 和 MPT-[7,30] B。
隨后,研究者證實了「注意力池」的假設,并證明語言模型可以通過預訓練,在流式部署時只需要一個注意力池 token。具體來說,他們建議在所有訓練樣本的開頭多加一個可學習的 token,作為指定的注意力池。通過從頭開始預訓練 1.6 億個參數的語言模型,研究者證明了本文方法可以保持模型的性能。這與當前的語言模型形成了鮮明對比,后者需要重新引入多個初始 token 作為注意力池才能達到相同的性能水平。
最后,研究者將 StreamingLLM 的解碼延遲和內存使用率與重新計算滑動窗口進行了比較,并使用 Llama-2-7B 和 Llama-2-13B 模型在單個英偉達 A6000 GPU 上進行了測試。如圖 10 所示,隨著緩存大小的增加,StreamingLLM 的解碼速度呈線性增長。后者解碼延遲則呈二次曲線上升。實驗證明,StreamingLLM 實現了令人印象深刻的提速,每個 token 速度的提升高達 22.2 倍。
更多研究細節,可參考原論文。