比標準Attention提速5-9倍,大模型都在用的FlashAttention v2來了
近來,幾種長上下文語言模型陸續問世,包括 GPT-4(上下文長度為 32k)、MosaicML 的 MPT(上下文長度為 65k)Anthropic 的 Claude(上下文長度為 100k)。長文檔查詢和故事寫作等新興用例已經表明擴展語言模型上下文窗口是非常必要的。
然而,擴大 Transformer 的上下文長度是一個挑戰,因為其核心的注意力層在時間復雜度和空間復雜度與輸入序列長度的平方成正比。
一年前,來自斯坦福大學、紐約州立大學布法羅分校的研究者共同提出一種快速、內存高效的注意力算法 ——FlashAttention。該算法無需任何近似即可加速注意力并減少內存占用。現在,已經有許多機構和研究實驗室采用 FlashAttention 來加速訓練和推理。
FlashAttention 示意圖。
盡管 FlashAttention 的速度已經是優化基線的 2-4 倍,但它仍然有相當大的改進空間。FlashAttention 仍然不如優化過的矩陣乘法 (GEMM) 運算快,僅達到理論最大 FLOPs/s 的 25-40%。
現在,研究團隊宣布推出 FlashAttention-2。FlashAttention-2 完全從頭開始重寫,使用 Nvidia 的 CUTLASS 3.x 及其核心庫 CuTe 的原語(primitive)。
圖片
FlashAttention-2 開發者 Tri Dao。他是斯坦福大學博士生,還是 Together.AI 首席科學家,并將于 2024 年 9 月開始任職普林斯頓大學計算機科學助理教授。
FlashAttention-2 的速度是 FlashAttention 的 2 倍,在 A100 GPU 上達到 230 TFLOPs/s。在端到端訓練 GPT 類語言模型時,FlashAttention-2 可讓訓練速度高達 225 TFLOPs/s(模型 FLOP 利用率為 72%)。
FlashAttention-2 將加速現有模型的訓練、微調和推理。這意味著我們可以用相同成本訓練 2 倍上下文長度的語言模型。這將有助于語言模型理解長篇書籍和報告、高分辨率圖像、音頻和視頻。
圖片
- 項目地址:https://github.com/Dao-AILab/flash-attention
- 技術報告:https://tridao.me/publications/flash2/flash2.pdf
FlashAttention 是什么?
FlashAttention 是一種重新排序注意力計算的算法,它利用平鋪、重計算等經典技術來顯著提升計算速度,并將序列長度中的內存使用實現從二次到線性減少。其中平鋪意味著將輸入塊從 HBM(GPU 內存)加載到 SRAM(快速緩存),并對該塊執行注意力操作,更新 HBM 中的輸出。
此外通過不將大型中間注意力矩陣寫入 HBM,內存讀寫量減少,帶來了 2-4 倍的時鐘時間加速。
下圖為 FlashAttention 的前向傳遞圖:通過平鋪和 softmax 重新縮放,研究者按塊進行操作,避免從 HBM 中讀取 / 寫入,同時獲得正確的輸出,無需近似操作。
圖片
然而,FlashAttention 仍然存在一些低效率問題,原因在于不同線程塊之間的工作分區不理想以及 GPU 上的 warp。這些導致低占用率或不必要的共享內存讀寫。
FlashAttention-2
更好的算法、并行化和工作分區
更少的非矩陣乘法 Flops
研究者調整了 FlashAttention 的算法,從而減少了非矩陣乘法(non-matmul)的 Flops 數量。這點很重要,因為現代 GPU 具有專門的計算單元(例如 Nvidia GPU 上的張量核心),使得矩陣乘法速度更快。
舉例而言,A100 GPU 的 FP16/BF16 矩陣乘法的最大理論吞吐量為 312 TFLOPs/s,但非矩陣乘法 FP32 的理論吞吐量僅為 19.5 TFLOPs/s。
換一種思考方式,每個非矩陣乘法 FLOP 比矩陣乘法 FLOP 的代價高 16 倍。為了保持高吞吐量,研究者希望在矩陣乘法 FLOP 上花費盡可能多的時間。因此他們重寫了 FlashAttention 中使用的在線 softmax 技巧,以減少重新縮放操作、邊界檢查和因果掩碼操作的數量,而無需更改輸出。
更好的并行化
FlashAttention v1 在批大小和頭(head)數量上進行并行化。研究者使用 1 個線程塊來處理一個注意力頭,總共有(批大小 * 頭數量)個線程塊。每個線程塊都計劃在流式多處理器(SM)上運行,例如 A100 GPU 上有 108 個這樣的 SM。當這個數字非常大(如 >= 80)時,這種調度是有效的,這時可以高效地使用 GPU 上幾乎所有計算資源。
在長序列的情況下(通常意味著小批量或少量頭),為了更好地利用 GPU 上的多處理器,現在研究者在序列長度維數上額外地進行并行化,使該機制顯著加速。
更好的工作分區
即使在每個線程塊內,研究者也必須決定如何在不同的 warp 之間劃分工作(一組 32 個線程一起工作)。通常情況下,每個線程塊使用 4 或 8 個 warp,分區方案如下圖所述。
研究者改進了 FlashAttention-2 中的這種分區,減少不同 warp 之間的同步和通信量,進而減少共享內存讀寫。
圖片
對于每個塊,FlashAttention 將 K 和 V 分割到 4 個 warp 上,同時保持 Q 可被所有 warp 訪問。這被稱為「sliced-K」方案。不過,這種方案是低效的,原因在于所有 warp 都需要將它們的中間結果寫入共享內存,并同步,然后將中間結果相加。這些共享內存讀寫會減慢 FlashAttention 中的前向傳遞速度。
在 FlashAttention-2 中,研究者將 Q 分割在 4 個 warp 上,同時保持 K 和 V 可被所有的 warp 訪問。每個 warp 執行矩陣乘法以獲得 Q K^T 的切片,然后只需與 V 的共享切片相乘就能獲得相應的輸出切片。warp 之間不需要通信。共享內存讀寫的減少也可以提升速度。
新特性:頭維數高達 256、多查詢注意力
我們知道,FlashAttention 僅支持最高 128 的頭維數,這適用于大多數模型,但有一些模型被遺漏了。
因此,FlashAttention-2 支持了高達 256 的頭維數,這意味著 GPT-J、CodeGen 和 CodeGen2、StableDiffusion 1.x 等模型可以使用 FlashAttention-2 來獲得加速和節省內存。
此外,FlashAttention-2 還支持了多查詢注意力(multi-query attention, MQA)以及分組查詢注意力(grouped-query attention, GQA)。它們是注意力的變體,其中多個查詢頭關注相同的鍵和值頭,以減少推理過程中 KV 緩存的大小,并可以顯著提高推理吞吐量。
注意力基準結果
研究者在 A100 80GB SXM4 GPU 上,測量不同設置(無 / 有因果掩碼、頭維數 64 或 128)下不同注意力方法的運行時。
結果發現, FlashAttention-2 的速度是 FlashAttention(以及 xformers 庫和 Triton 中的其他實現)的 2 倍。與 PyTorch 中的標準注意力實現相比,FlashAttention-2 的速度最高是它們的 9 倍。
A100 GPU 上的注意力前向 + 后向速度。
此外只需要在 H100 GPU 上 運行相同的實現(不使用特殊指令來利用 TMA 和第四代 Tensor Core 等新硬件功能),研究者最高獲得了 335 TFLOPs/s。
H100 GPU 上的注意力前向 + 后向速度。
當用于端到端 GPT 類模型訓練時,FlashAttention-2 有助于在 A100 GPU 上實現最高 225 TFLOPs/s(模型 FLOPs 利用率為 72%)。與優化良好的 FlashAttention 模型相比,端到端實現 1.3 倍加速。
這里的基線是不使用 FlashAttention 的 Megatron-LM,它現在也可以選擇使用 FlashAttention 了。不久的將來,FlashAttention-2 也將集成到 Megatron-LM 中。
研究團隊表示:下一步將針對 H100 GPU 優化 FlashAttention-2,以使用新的硬件功能。