高性能 LLM 推理框架的設計與實現(xiàn)
一、大語言模型推理概要介紹
與傳統(tǒng)的 CNN 模型推理不同,大語言模型的推理通常會分成 prefill 和 decoding 兩個階段。每一個請求發(fā)起后產(chǎn)生的推理過程都會先經(jīng)歷一個 Prefill 過程,prefill 過程會計算用戶所有的輸入,并生成對應的 KV 緩存,再經(jīng)歷若干個 decoding 過程,每一個 decoding 過程,服務器都會生成一個字符,并將其放入到 KV 緩存當中,之后依次迭代。
由于 decoding 過程是逐個字符生成的,每一段答案的生成都需要很長時間,會生成很多字符,所以 decoding 階段的數(shù)量非常多,占到整個推理過程的 90% 以上。
在 Prefill 過程中,雖然計算量很大,因為要一次性完成用戶輸入的所有詞的計算,但它只是一次性的過程,所以在整個推理中只占不到 10% 的時間。
在大語言模型推理中常會用到四個指標:Throughput(吞吐量)、First Token Latency(首字延遲)、Latency(延遲)和QPS(每秒請求數(shù))。這四個性能指標會從四個不同的方面來衡量一個系統(tǒng)的服務提供能力。
首先來介紹 Throughput(吞吐量)。從模型推理層面上看,最先關注的就是吞吐量。吞吐量是指當系統(tǒng)的負載達到最大的時候,在單位時間內(nèi),能夠執(zhí)行多少個 decoding,即生成多少個字符。測試吞吐量的方法是,假設所有用戶都會在同一時刻到來,并且這些用戶問的都是一樣的問題,這些用戶可以同時啟動和結束,且他們生成的文本的長度和輸入的文本長度都是一樣的。通過使用完全相同的輸入,組成一個完整的 batch。在這種情況下,系統(tǒng)的吞吐量會達到最高。但這種情況是不合實際的,所以這是一個理論的最大值。我們會測量在一秒鐘之內(nèi),系統(tǒng)能夠執(zhí)行多少個獨立的 decoding 階段。
第二個指標是 First Token Latency(首字延遲)。指的是當一批用戶進入到推理系統(tǒng)之后,用戶完成 Prefill 階段的過程需要花多長時間。這也是系統(tǒng)生成第一個字符所需的響應時間。很多需求關注這一指標,希望用戶在系統(tǒng)上輸入問題后得到回答的時間小于 2~3 秒。
第三個指標是 Latency(延遲)。指的是每一個 decoding 所需要的時長。它反映的是大語言模型系統(tǒng)在線上處理的過程中,每生成一個字符的間隔是多長時間,也就是生成的過程有多么流暢。大部分情況下,我們希望生成的延遲小于 50 毫秒,也就是一秒鐘生成 20 個字符。這樣大語言模型的生成是比較流暢的。
最后一個指標是 QPS(每秒請求數(shù))。反映了在線上系統(tǒng)的服務當中,一秒鐘能夠處理多少個用戶的請求。這一指標的測量方式比較復雜,后面會展開介紹。
對于 First Token Latency 和 Latency 這兩個指標,我們都進行了相對完善的測試。這兩個指標會因為用戶輸入的長度不同、batch size 的不同而發(fā)生非常大的變化。
在上表中可以看到,對于同樣的 7B 模型,如果用戶的輸入長度從 8 變成 2048,Prefill 的時間將從 6.78 毫秒,直到變成 2078 毫秒,即 2 秒的時間。如果有 80 個用戶,每一個用戶都輸入 1,024 個詞,那么 Prefill 在服務端就要跑 2 秒左右,這個時間已經(jīng)超出了可以接受的范圍。但如果用戶輸入長度都很短,比如每次訪問只輸入 8 個詞,哪怕 768 個用戶同時到來,首字延遲也只有 165 毫秒左右。
與首字延遲最相關的就是用戶的輸入長度,用戶輸入的長度越長,首字延遲也會越高。用戶輸入長度如果很短,那么首字延遲在整個大語言模型推理過程中都不會成為瓶頸。
而后面的 decoding 延遲,通常只要不是千億級別的模型,decoding 的延遲都會控制在 50 毫秒以內(nèi)。它主要受到 batch size 的影響,batch size 越大,推理延遲也會越大,但基本上增加的幅度不會很高。
吞吐量其實也會受到這兩個因素的影響。如果用戶輸入的長度和生成的長度很長,那么系統(tǒng)吞吐量也不會很高。如果用戶輸入長度和生成長度都不是很長,那么系統(tǒng)吞吐量可能會達到一個非常離譜的程度。
再來看 QPS。QPS 是一個非常具體的指標,它表示系統(tǒng)中每秒可以處理多少個請求,在進行這個測試的時候,我們會使用實際的數(shù)據(jù)。(關于這份數(shù)據(jù),我們已經(jīng)做好了采樣,并且放在了 github 上。)
QPS 的測量跟吞吐量不太一樣,因為在實際使用大語言模型系統(tǒng)的時候,每一個用戶到來的時間是不確定的。有的用戶可能早來,有的用戶可能晚來,并且每一個用戶做完 Prefill 之后的生成長度也是不確定的。有的用戶可能生成 4 個詞就退出,有的用戶可能要生成 20 多個詞。
在 Prefill 階段,在實際線上推理當中,因為用戶實際生成長度不一樣,所以會遇到一個問題:有些用戶會提前生成完,而有些用戶要生成很多長度之后才會結束。在這樣的生成過程中,有很多地方的 GPU 會空閑。因此在實際的推理過程中,我們的 QPS 并不能夠發(fā)揮完全的吞吐量優(yōu)勢。我們的吞吐量可能很大,但實際的處理能力可能會很差,因為在這個處理過程當中充滿了無法使用顯卡的空洞。所以在 QPS 指標上,我們會有非常多的具體的優(yōu)化方案,避免計算的空洞或者無法有效利用顯卡的現(xiàn)象存在,從而使得吞吐量能夠完全服務到用戶上。
二、大語言模型推理性能優(yōu)化
接下來進入到大語言模型的推理流程當中,看看我們究竟做了哪些優(yōu)化,使得系統(tǒng)在 QPS 以及吞吐量等指標上都達到比較優(yōu)秀的情況。
1. LLM 推理過程
首先來詳細介紹一下大語言模型的推理過程,前文中提到了每個請求都要經(jīng)歷 prefill 和 decoding 兩個階段,在 prefill 階段,至少要做四件事情:
第一件事情是把用戶的輸入進行向量化,tokenize 的過程指的是將用戶輸入的文本轉換為向量,相對于 prefill 整個階段來說,大概要占掉 10% 的時間,這是有代價的。
之后就會進行真正的 prefill 計算,這一過程會占掉大概 80% 的時間。
計算之后會進行 sampling,這個過程在 Pytorch 里面一般會用 sample、top p。在大語言模型推理當中會用 argmax。總而言之,是根據(jù)模型的結果,生成最后詞的一個過程。這個過程會占掉 10% 的時間。
最后將 refill 的結果返回給客戶,這需要的時間會比較短,大概占 2% 到 5% 的時間。
Decoding 階段不需要 tokenize,每一次做 decoding 都會直接從計算開始,整個decoding 過程會占掉 80% 的時間,而后面的 sampling,也就是采樣生成詞的過程,也要占掉 10% 的時間。但它會有一個 detokenize 的時間,detokenize 是指生成了一個詞之后,這個生成的詞是個向量,需要把它解碼回文本,這一操作大概會占掉 5% 的時間,最后將這個生成的詞返回給用戶。
新的請求進來,在進行完 prefill 之后,會不斷迭代進行 decoding,每一個 decoding 階段結束之后,都會將結果當場返回給客戶。這樣的生成過程在大語言模型里面是很常見的,我們稱這樣的方式為流式傳輸。
2. 優(yōu)化:流水線前后處理與高性能采樣
這里要介紹的第一個優(yōu)化是流水線優(yōu)化,其目的是盡可能讓顯卡利用率占滿。
在大語言模型推理過程中,tokenize、fast sample 和 detokenize 這些過程都與模型的計算無關。我們可以把整個大語言模型的推理想象成這樣一個過程,在執(zhí)行 prefill 的過程中,當我拿到了 fast sample 的詞向量之后,就可以立刻開始下一個階段 decoding,不用等到結果返回,因為結果已經(jīng)在 GPU 上了。而當完成了一次 decoding 后,不用等待 detokenize 的完成,可以立刻開始下一次的 decoding。因為 detokenize 是個 CPU 過程,后面這兩個過程,只涉及到用戶的結果返回,不涉及任何 GPU 的運算。并且在執(zhí)行完采樣過程之后,就已經(jīng)知道下一個生成的詞是什么了,我們已經(jīng)拿到了所需的所有數(shù)據(jù),可以立刻開始下一次運算,不需要再等待后面兩個過程的完成。
在 PPL.LLM 的實現(xiàn)當中使用了三個線程池:
第一個線程池負責執(zhí)行 tokenize 過程;
第三個線程池負責執(zhí)行后面的 fast sample 以及返回結果的過程和 detokenize;
中間的線程池用來執(zhí)行 computing 的過程。
這三個線程池互相異步地把這三部分的延遲相互隔離,從而盡可能地將這三部分的延遲掩蔽掉。這將給系統(tǒng)帶來 10% 到 20% 的 QPS 提升,這就是我們所做的第一項優(yōu)化。
3. 優(yōu)化:動態(tài)批處理
在這之后,PPL.LLM 還可以執(zhí)行一項更有意思的優(yōu)化,叫做動態(tài)批處理。
前文中提到,在實際的推理過程當中,用戶的生成長度不同,并且用戶到達的時間也并不一樣。因此會存在這樣一種情況,如果當前的 GPU 在推理過程當中,已經(jīng)有一個請求在線上進行推理,在推理進行到一半時,第二個請求插入進來,這時第二個請求的生成過程會跟第一個請求的生成過程相沖突。因為我們只有一個 GPU,這個 GPU 上只能夠串形地跑任務,所以不能簡單地把它們在 GPU 上做并行。
我們的做法是,在第二個請求進入的時間點,把它的 prefill 階段和第一個請求對應的decoding 階段進行混合,生成一個新的階段稱為 Merge Step。在這個 Merge Step 中,不僅會進行第一個請求的 decoding,同時會進行第二個請求的 Prefill。這項功能在許多大語言模型推理系統(tǒng)中都會存在,它的實現(xiàn)使得大語言模型的 QPS 提升達到了 100%。
具體過程為,第一個請求生成過程進行了一半,意味著它在進行 decoding 時會有一個長度為 1 的輸入,而第二個請求是新進入的,在進行 Prefill 的過程當中,會有一個長度為 48 的輸入。將這兩個輸入沿著第一個維度相互拼接,拼接完的輸入長度為 49,并且 hidden dimension 是 4096 的輸入。在這長度為 49 的輸入當中,第一個詞是第一個請求的,剩下的 48 個詞是第二個請求的。
由于在大模型推理當中,所需要經(jīng)歷的算子,比如 RMSNorm、矩陣乘和 attention 等算子,不論是做 decoding 還是做 prefill,它們的結構都是不變的。因此拼接完的輸入可以直接放入到整個網(wǎng)絡中去跑。我們只需要在一個地方加以區(qū)分,那就是 attention。在 attention 的過程當中或者在執(zhí)行 self attention 算子的過程當中,我們會做一次數(shù)據(jù)分流,將所有做 decoding 的請求分流成一波,把所有做 prefill 的請求分流到另外一波,執(zhí)行兩個不同的運算。所有做 prefill 的請求,將會執(zhí)行 Flash Attention;所有做 decoding 的用戶,將會執(zhí)行一個非常特殊的算子,叫做 Decoding Attention。在分流執(zhí)行完 attention 算子之后,這些用戶的輸入會再一次被拼接到一起,完成其它算子的計算。
對于 Merge Step,實際上當每個請求到來的時候,我們都會把這個請求跟系統(tǒng)上現(xiàn)在已有的所有請求的輸入拼接在一起,完成這次計算,然后繼續(xù)往下不停地做 decoding,這是動態(tài)批處理在大語言模型中的實現(xiàn)。
4. 優(yōu)化:Decoding Attention
Decoding Attention 算子,不像 Flash Attention 算子那樣出名,但其實在處理 decoding 任務上比 Flash Attention 要快得多。
這是一種專門為 decoding 任務所設計的算子,完全依賴 Cuda Core,不會依賴Tensor Core 完成計算。它非常靈活并且容易修改,但它有一個限制,因為其特點是在 decoding 的 tensor 的運算當中,所以會要求輸入的 q 的長度必須是 1,但 k 和 v 的長度是可變的。這是 Decoding Attention 的限制,在這種限制下,我們可以做一些特定的優(yōu)化。
這種特定的優(yōu)化使得在 decoding 階段的 attention 算子的實現(xiàn),會比 Flash Attention 更快。這個實現(xiàn)目前也已經(jīng)開源,大家可以到上圖中的網(wǎng)址進行訪問。
5. 優(yōu)化:VM Allocator
另一項優(yōu)化是 Virtual Memory Allocator,對應 Page Attention 優(yōu)化。當請求來到之后,要進行 prefill 階段,又要進行 decoding 階段,它所有輸入的 token 會生成一個 KV 緩存,這個KV 緩存記錄了這個請求所有的歷史信息。那么要給這樣一個請求分配多長的 KV 緩存空間,才能滿足它完成此次生成任務呢?如果分的太多,顯存會有浪費,如果分的太少,在 decoding 階段,碰到了 KV 緩存的截止位置,就沒有辦法繼續(xù)往下生成。
為了解決這一問題,有 3 種方案。
Pytorch 的顯存管理方式是為每一個請求預留一片足夠長的空間,通常是 2048 或者 4096,能夠保證完成 4096 個詞的生成。但大部分用戶實際的生成長度不會有那么長,所以會有大量的內(nèi)存空間被浪費掉。
Page Attention 采用的是另外一種顯存管理方式。允許生成過程中不斷為用戶追加顯存。類似于操作系統(tǒng)中的頁式存儲或者內(nèi)存分頁。當一個請求來到之后,系統(tǒng)會為這個請求分配一小塊顯存,這一小塊顯存通常只夠生成 8 個字符,當請求生成了 8 個字符之后,系統(tǒng)會追加一塊顯存,可以把結果再寫到這塊顯存里面,同時系統(tǒng)會維護一個顯存塊和顯存塊之間的鏈表,從而使得算子可以正常地進行輸出。當生成的長度不斷變長時,會不斷地給用戶追加顯存塊的分配,并且可以動態(tài)維護顯存塊分配的列表,使系統(tǒng)不會存在大量浪費的資源,不需要為這個請求保留太多的顯存空間。
PPL.LLM 使用的是 Virtual Memory 的管理機制,為每一個請求預測一個它所需的生成長度。每個請求進來之后,都會直接為其分配一個連續(xù)的空間,這個連續(xù)空間的長度是預測出來的。但理論上看可能難以實現(xiàn),尤其到了線上推理階段,不太可能清楚地知道每個請求究竟要生成多長的內(nèi)容。因此我們推薦訓練一個模型去做這件事情。因為即使我們采用了 Page Attention 這樣的模式,依然會遇到問題。Page Attention 在運行的過程中,具體到一個特定的時間點,比如當前系統(tǒng)上已經(jīng)有了四個請求,系統(tǒng)里面還剩余 6 塊顯存沒有被分配。這時我們無法知道是否會有新的請求進來,能否為其繼續(xù)提供服務,因為當前的四個請求還沒有結束,可能未來還要繼續(xù)為它們追加新的顯存塊。所以即使是 Page Attention 機制,還是需要預測每一個用戶實際的生成長度。這樣才知道在具體的一個時間點上能不能接受一個新的用戶的輸入。
這是我們目前所有的推理系統(tǒng)都沒有做到的事情,包括 PPL 目前也沒有實現(xiàn)。但 Virtual Memory 的管理機制,還是讓我們很大程度上避免了顯存的浪費,從而使系統(tǒng)整體的 QPS 提升達到 200% 左右。
6. 優(yōu)化:KV Cache 量化
PPL.LLM 在做的另外一項優(yōu)化,就是 KV 緩存的量化,在服務端推理的過程當中,KV 緩存會占據(jù)絕大部分的顯存空間,這會嚴重限制系統(tǒng)的并發(fā)請求數(shù)量。
可以看到,在服務端,特別是 A100、H100 這樣的大顯存的服務器上運行如 7B 模型這樣的大語言模型時,它的 KV 緩存將占到 84% 的顯存空間,而對于如 176B 這樣的千億級大模型,它的 KV 緩存也將占用 50% 以上的緩存空間。這會嚴重地限制模型的并發(fā)數(shù)量,每一個請求到來后,都需要給它分配很大的顯存。這樣請求數(shù)量就無法提升上去,繼而使得 QPS 以及吞吐量都無法提升。
PPL.LLM 使用了一種非常特殊的量化方式,分組量化對 KV 緩存的數(shù)據(jù)進行壓縮。也就是對原來 FP16 的數(shù)據(jù),會嘗試把它量化到 INT8。這樣會使 KV 緩存的體積縮小 50%,并使得服務端能夠容納的請求數(shù)量增加 100%。
之所以相比 Faster Transformer 能夠提升大約 50% 的吞吐量,正是得益于 KV 緩存量化所帶來的 batch size 的提升。
7. 優(yōu)化:矩陣乘法量化
在 KV 緩存量化之后,我們進行了更細力度的矩陣乘法的量化。在整個服務端推理的過程當中,矩陣乘法占到整個推理時間的 70% 以上,PPL.LLM 使用了一種動態(tài)的 per-channel/per-token 交替的混合量化方式來加速矩陣乘法。這些量化同樣是精度極高的,并且能夠提升接近 100% 的性能。
具體做法是,在 RMSNorm 算子的基礎之上,融合一個量化算子,這個量化算子會在 RMSNorm 算子的功能基礎之上統(tǒng)計其 Token 信息,統(tǒng)計每一個 token 的最大最小值,并且沿著 token 的維度,把這個數(shù)據(jù)進行量化。也就是說經(jīng)過了RMSNorm 之后的數(shù)據(jù)將會從 FP16 轉成 INT8,并且這一次量化是全動態(tài)的,不需要做 calibration。而在后面的 QKV 矩陣乘當中,這三個矩陣乘都將進行 per-channel 量化。它們接收的數(shù)據(jù)是 INT8 的,同樣它們的權重也是 INT8 的,所以這些矩陣乘可以完整地執(zhí)行 INT8 的矩陣乘法。它們的輸出將會被 Soft Attention 接受,但在接受之前會執(zhí)行一次解量化過程,這次解量化過程將和 soft attention 算子融合。
而后面的 O 矩陣乘法是不做量化的,Soft Attention 本身的計算過程也不做任何量化。在后續(xù)的 FeedForward 過程當中,這兩個矩陣同樣采用一樣的方式進行量化,和上面的 RMSNorm 進行融合,或者與上面的 Silu 和 Mul 這樣的激活函數(shù)進行融合。它們的解量化算子將和其下游算子進行融合。
8. 優(yōu)化:INT8 vs INT4
目前學術界對于大語言模型的量化關注點可能主要集中在 INT4 上,但是在服務端推理的過程中,其實更適合使用 INT8 的量化。
INT4 的量化也叫 Weight Only 的量化,這種量化方式出現(xiàn)的意義在于,當大語言模型推理過程中 batch 比較小時,在矩陣乘法的計算過程中,90% 的時間都會用來加載權重。因為權重的體積非常大,而加載輸入的時間很短,它們的輸入,即 activation 也非常短,計算的時間也不會很長,寫回結果的時間同樣不會很長,這意味著這個算子是一個訪存密集型的算子。在這種情況下,我們會選用 INT4 的量化,前提是 batch 足夠的小,使用 INT4 的量化每一次加載權重之后,會緊接著進行一個解量化的過程。這次解量化會把權重從 INT4 解量化成 FP16,經(jīng)歷解量化過程之后,后面的計算和 FP16 是完全一樣的,也就是說 INT4 Weight Only 的量化適用于訪存密集性的矩陣乘法,其計算過程還是由 FP16 的運算器件去完成的。
當 batch 足夠大,比如 64 或者 128 時,INT4 的 Weight Only 量化將不會帶來任何性能提升。因為如果 batch 足夠大,那計算時間會被拉得很長。并且 INT4 Weight Only 量化有一個非常不好的點,它的解量化過程所需要的計算量是會隨著 batch 的(GEMM Batch)提升而提升的,隨著輸入 batch 的提升,解量化的時間也會越來越長。當 batch 達到 128 的時候,解量化所帶來的時間損耗和加載權重帶來的性能優(yōu)勢,就已經(jīng)相互抵消了。也就是說當 batch 達到 128 之后,INT4 的矩陣量化不會比 FP16 矩陣量化快,性能優(yōu)勢極小。大概在 batch等于 64 的時候,INT4 的 Weight Only 量化只會比 FP16 的快 30%,等到 128 的時候,大約只會快 20% 甚至更小。
但對于 INT8 來說,INT8 的量化與 INT4 量化最不同的一點,是它不需要任何解量化的過程,并且它的計算是可以壓縮一倍時間的。在 batch 等于 128 時,從 FP16量化到 INT8,加載權重的時間將會減半,計算的時間也會減半,這會帶來百分之百的加速。
在服務端場景下,特別是因為會有不斷的請求涌入,大部分的矩陣乘,都會是計算密集型的。在這種情況下,如果為了追求極限的吞吐量,INT8 的效率其實是高于 INT4 的。這也是為什么我們目前已經(jīng)完成的實現(xiàn)里面,在服務端上主推 INT8 的一個原因。
9. 優(yōu)化:FP8 vx INT8
在 H100、H800、4090 上面,我們可能會執(zhí)行 FP8 的量化。FP8 這樣的數(shù)據(jù)格式,在 Nvidia 最新一代的顯卡當中被引入。INT8 的精度從理論上是要高于 FP8 的,但是 FP8 會更好用,性能會更高一些。我們在后續(xù)服務端的推理過程的更新當中也會推進 FP8 的落地。上圖中可以看到,F(xiàn)P8 的誤差相比 INT8 會大 10 倍左右。INT8 會有一個量化的尺寸因子,可以通過調(diào)整尺寸因子,降低 INT8 的量化誤差。而 FP8 的量化誤差跟尺寸因子基本上是無關的,它不受尺寸因子的影響,這使得我們基本上不需要對它做任何的 calibration。但是它的誤差總體來講是要高于 INT8 的。
10. 優(yōu)化:INT4 vs 非線性量化
PPL.LLM 在后續(xù)的更新中,也會更新 INT4 的矩陣量化。這種 Weight Only 的矩陣量化主要是為端側服務的,為了手機端移動端等 batch 固定為 1 的設備。在后續(xù)的更新當中會從 INT4 逐漸轉變?yōu)榉蔷€性量化。因為在 Weight Only 的計算過程當中,會存在一個解量化的過程,這個解量化過程實際是可定制的,未必是一個線性的解量化過程,其使用其它解量化過程以及量化過程,會使得這一次計算的精度更高。
一個比較典型的例子,就是在一篇論文當中所提到的 NF4 的量化,這種量化實際上會通過一種打表的方式進行量化及解量化,是一種非線性的量化。PPL.LLM 的后續(xù)更新當中會嘗試使用這樣的量化來完成端側推理的優(yōu)化。
三、大語言模型推理的硬件
最后,介紹一下大語言模型處理的硬件。
模型結構一旦確定,我們就會知道它具體的計算量,具體需要多少訪存,需要多少計算量。同時還會知道每張顯卡的帶寬、算力、價格等。在確定了模型的結構以及確定了硬件指標之后,我們就可以通過這些指標去計算出在這張顯卡上推理大模型的最大吞吐量會是多少、計算延遲是多少、訪存訪問時間需要多少,可以算出一個非常具體的表。我們把這個表格公開在后續(xù)的資料當中,大家可以訪問這個表格,查看最適合大語言模型推理的顯卡型號有哪些。
對于大語言模型推理來說,因為大部分算子都是訪存密集型的,訪存的延遲總會比計算延遲要高。因為大語言模型的參數(shù)矩陣確實太大了,所以哪怕是在 A100/80G 上,batch size 開到 272 的時候,它的計算延遲都是較小的,訪存延遲反而會更高。因此,我們的許多優(yōu)化都是從訪存上著手的。而進行硬件選擇時,我們主要的方向就是選擇帶寬比較高、顯存比較大的設備。從而使得大語言模型在推理時,可以支撐更多的請求,支撐更快的訪存,相應的吞吐量也會更高。
以上就是本次分享的內(nèi)容。所有相關資料都放在了網(wǎng)盤中,鏈接參見上圖。我們所有的代碼也已經(jīng)開源在了 github 上。歡迎大家隨時與我們進行溝通。
四、Q & A
Q1:PPL.LLM 中有沒有優(yōu)化像 Flash Attention 中的 Softmax 這種訪存的問題?
A1:Decoding Attention 這個算子非常特殊,它的 Q 的長度永遠是 1,所以它不會像Flash Attention 那樣面臨 Softmax 里有非常大的訪存量。實際上在 Decoding Attention 的執(zhí)行過程當中,就是完整地執(zhí)行這次 Softmax 的過程,并不需要像 Flash Attention 那樣更快執(zhí)行。
Q2:INT4 的 Weight Only 量化為什么和 batch 線性相關,請問這是固定數(shù)量嗎?
A2:這是一個好問題,首先這個解量化不是像大家想的那樣,只需要把權重從 INT4 塞回 FP16 就行了,如果只做這件事情,那權重有多少就要解多少。實際上不是這樣的,因為這是一個融合在矩陣乘法里面的解量化,不能在執(zhí)行矩陣乘法之前,把所有權重解量化出來,放在那然后再去讀。這樣我們所做的 INT4 的量化就沒有意義了。它是在執(zhí)行過程當中不停地去解量化,因為我們會執(zhí)行分塊的矩陣乘,每一個權重所要讀寫的次數(shù)并不是 1,需要不停地拿過來計算,這個次數(shù)實際上跟 batch 有關。也就是區(qū)別于之前那些優(yōu)化量化的手段,會有單獨的量化的算子和解量化算子。兩個算子的插入,解量化還是直接融合在算子中的。我們執(zhí)行的是矩陣乘法,所以所要解量化的次數(shù)并不是一次。
Q3:KV Cache 中的反量化計算,可以被仿存掩蓋?
A3:根據(jù)我們的測試是可以被掩蓋的,而且其實還遠遠有剩余。KV 計算中的反量化以及量化都會被融合進 self attention 算子當中,具體來說就是 Decoding Attention。根據(jù)測試,這個算子即使在 10 倍的計算量,可能都可以掩蓋掉。就是訪存的延遲都掩蓋不了它,它主要的瓶頸在于訪存,它計算量還遠遠達不到可以掩蓋掉它訪存的那個程度。所以 KV cache 當中的反量化計算,對于這個算子來說,基本上是一個很好被掩蓋的東西。