2.5%KV緩存保持大模型90%性能,大模型金字塔式信息匯聚模式探秘
用KV緩存加速大模型的顯存瓶頸,終于迎來突破。
北大、威斯康辛-麥迪遜、微軟等聯合團隊提出了全新的緩存分配方案,只用2.5%的KV cache,就能保持大模型90%的性能。
這下再也不用擔心KV占用的顯存容量過高,導致顯卡不夠用了。
圖片
該方法名為PyramidKV,顧名思義,在KV緩存壓縮的過程中融入了金字塔型的信息匯聚方式。
在內存受限的情況下,PyramidKV表現非常出色,既保留了長上下文理解能力,又顯著減少了內存使用。
目前,PyramidKV相關代碼已經在GitHub開源。
引入金字塔信息匯聚方式
隨著模型尺寸的增大,推理需要的時間越來越多。KV cache作為推理加速的關鍵技術,通過緩存之前的解碼步驟中計算出的Transformer的K和V矩陣減少后續解碼時間。
但是,隨著序列長度增大,需要緩存的KV cache會快速增長,占用大量顯存。針對這一問題,之前的工作設計策略是對KV cache進行壓縮。
實際上,長文本的推理加速和顯存節省作為一個重要的話題,這涉及到廣泛的大模型下游應用,比如檢索增強生成(Retrieval-Augmented Generation)、上下文學習(In-Context Learning)受到廣泛關注。
KV cache及KV cache的壓縮能否有效幫助長文本實現推理加速成為廣受關注的研究方向。
采用均一壓縮策略,是最佳方案嗎?
傳統壓縮方法的一個共同特點是,均對每個Transformer層使用同樣的KV cache壓縮設置,使用同樣的方法壓縮到同樣的長度。
圖片
但PyramidKV團隊發現,對KV cache進行極致壓縮情況下上述方法的表現,發現當超長文本壓縮到極致小的KV大小時(從32k 長度壓縮到64,即保留0.2%的KV cache長度)時,會面臨嚴重的性能減弱。
于是作者提出了疑問:對每個Transformer層將KV cache壓縮到同樣的大小是否為最優方案?
為了回答上述問題,研究團隊對大模型進行檢索增強生成的機制進行深入分析。
作者研究了Llama模型進行多文檔問答的逐層注意力圖,發現了注意力層中的金字塔形信息匯聚模式(Pyramidal Information Funneling)的存在:
- 在模型的低層(例如第0層)中,注意力得分呈現近似均勻分布,這表明模型在較低層時從所有可用內容中全局聚合信息,而不會優先關注特定的段落。
- 當編碼信息進行到中間層(6-18)時,逐漸轉變為聚焦在段落內部的注意力模式 (Localized Attention)。在這個階段,注意力主要集中在同一文檔內的Token上,表明模型在單個段落內進行了段落內部的信息聚合。
- 這種趨勢在上層(24-30)繼續并加強,本文觀察到了“Attention Sink”和“Massive Activation”現象。
在這些層中,注意力機制極大地集中在少數幾個關鍵Token上,因此只需要保留這些關鍵Token就能讓輸出保持一致并且減少顯存占用。
圖片
這種注意力分配模式,即極高的注意力得分,表明模型已將信息聚合到這些關鍵標記中。
這種注意力現象顯示了大模型對大量復雜的信息的進行編碼的機制,最終得到生成準確答案所需的最關鍵信息。
根據以上的發現,作者認為之前的工作對所有Transformer層統一處理是低效的,因此不同Transformer層的注意力稀疏程度并不相同。在低層能觀察到特別稠密的注意力,而在較高層則可以觀察到非常稀疏的注意力。
因此,在不同層之間使用固定的 KV 緩存大小可能會導致性能不佳。這些方法可能在較高層的稀疏注意力中保留許多不重要的 tokens,而忽略了較低層密集注意力中的許多重要的 tokens。
每層注意力特點不同,分層施策才是正解
于是,作者選擇了通過基于注意力模式動態分配緩存預算來提高壓縮效率。
具體而言,PyramidKV在信息更加分散的較低層分配更多的KV cache緩存,而在信息集中于少數關鍵tokens的較高層減少KV cache緩存。
一旦為每一層確定了KV緩存預算,PyramidKV在每一個Transformer層中選擇根據注意力選擇要緩存的KV。
最后的部分Token的KV緩存,即Instruction Token,會在所有Transformer層中保留。
根據UIUC、普林斯頓等提出的SnapKV方法,剩余的KV的選擇由從這些Instruction Token中獲得的對其他的Token注意力分數來指導——
接收到更高注意力分數的Token被認為與生成過程更相關,因此其KV狀態優先保存在GPU緩存中。
圖片
2.5%的KV cache,保持90%模型性能
為了評估PyramidKV的表現,作者使用最新的開源大模型Llama-3-8B-Instruct和Mistral-7B-Instruct,來對PyramidKV和其他方法進行對比。
測試示例以生成格式進行評估,所有任務的答案均通過貪婪解碼生成,并使用 LongBench來評估PyramidKV在處理長上下文輸入任務中的表現。
LongBench是一個精心設計的基準測試套件,用于測試語言模型處理長文檔和復雜信息序列的能力。
該基準測試旨在對長上下文輸入進行多任務評估,包括17個數據集,涵蓋單文檔問答、多文檔問答、摘要生成、少樣本學習、合成數據和代碼生成等任務。
數據集的平均輸入長度從1235個到18409個tokens不等,需要大量的內存來管理KV緩存。
對于所有這些任務,作者都遵循 LongBench推薦的標準指標。
結果,在64、96、128、256和512個KV cache緩存大小的設定下,PyramidKV在LongBench中均取得了優于baseline的效果。
圖片
在此基礎上,作者還研究了兩種不同的操作場景——節省內存場景(Memory-Efficient Scenario)和保持性能場景(Performance-Preserving Scenario),分別用于在內存和模型性能之間進行權衡。
PyramidKV在Longbench的多個任務和平均得分上均取得了優于baseline的效果。
值得注意的是,PyramidKV在size為128的設定下,在TREC任務(上下文學習問答挑戰)中表現出顯著優越的性能,相較于baseline,提高了20.的ACC結果。
圖片
總體而言,PyramidKV僅用12%的KV緩存就能保持完整的性能,并且在各種KV緩存大小的設定下和不同主干模型中始終優于其他方法,特別是在僅保留約128(0.7%)KV cache緩存的節省內存場景中,其性能優勢尤為明顯。
在具體任務的檢查中,PyramidKV在TREC任務(上下文學習問答挑戰)中表現出顯著優越的性能,僅僅使用64的KV cache緩存大小(原始輸入是5k長度)就能達到90%的性能。
這表明模型有效地聚合了樣本中的任務信息,突出了在上下文學習任務上進一步研究的潛力。
下面的表則展示了PyramidKV使KV緩存的占用減少的情況。作者評估了Llama-3-8B-Instruct的內存消耗。
具體來說,作者發現在固定批量大小為1、輸入長度為8192、模型權重為fp16格式的情況下,PyramidKV在不同緩存大小下顯著減少了KV緩存的內存,還一定程度上保留了任務性能。
圖片
為了進一步理解PyramidKV在LongBench上的性能,作者還進行了“大海撈針”實驗,將PyramidKV與SnapKV進行比較,并且對比128大小的KV緩存和完整的KV緩存。
在輸入序列長度在2000到4000之間的中等上下文情況下,SnapKV在“大海撈針”測試中產生了越來越多的錯誤案例。
在輸入序列長度超過6000的長上下文情況下,SnapKV顯著降低了LLMs在評估中的性能。
相比之下,PyramidKV在大多數情況下減輕了這種弱化效應。下圖展示了定量結果。分數越高、顏色越淺,表示著檢索能力越強。
在該任務的平均得分中,完整KV得分為65.0,PyramidKV得分為62.6,而SnapKV得分為57.3。
圖片
此外,作者的實驗表明,PyramidKV在上下文學習(In-Context Learning)的少樣本學習任務中顯著優于其他方法。
這表明KV cache緩存壓縮在上下文學習中的應用前景廣闊,這種方法有可能在受限的內存條件下實現更多樣本的引入。
論文地址:https://arxiv.org/abs/2406.02069項目主頁:
https://zefan-cai.github.io/PyramidKV.github.io/