為什么大語言模型難以處理長上下文?從 Transformer 到 Mamba 原創(chuàng) 精華
編者按: 大語言模型真的能像人類一樣高效處理海量信息嗎?我們今天為大家?guī)淼倪@篇文章,作者揭示了大語言模型在長上下文處理中的技術挑戰(zhàn)與未來發(fā)展路徑。
文章重點聚焦于三個關鍵層面:首先,解析了 Transformer 模型注意力機制的計算成本問題,指出隨著上下文長度增加,計算復雜度呈指數(shù)級增長;其次,探討了 Mamba 等新興架構在突破傳統(tǒng)模型局限性方面的潛力;最后,強調需要跳出現(xiàn)有思維模式,尋找處理海量信息的創(chuàng)新方法。
作者 | Timothy B. Lee
編譯 | 岳揚
OpenAI 在兩年前推出 ChatGPT 時,其能夠處理的上下文信息僅有 8,192 個 tokens1。換言之,如果輸入的文本超過大約 15 頁,它就會“遺忘”最初的上下文內容。這一限制使得 ChatGPT 在處理任務時的規(guī)模和復雜度都受到了影響。
而現(xiàn)今的 LLMs 能力有了顯著提升:
- OpenAI 的 GPT-4o[1] 現(xiàn)在能夠處理多達 128,000 個 tokens 的上下文。
- Anthropic 的 Claude 3.5 Sonnet[2] 可以處理 200,000 個 tokens 的上下文。
- Google 的 Gemini 1.5 Pro[3] 更是擁有 2 百萬個 tokens 的上下文處理能力。
盡管如此,要想讓 AI 系統(tǒng)達到人類水平的認知能力,我們還需要取得更多的進步。
許多人展望未來,認為 AI 將能夠承擔大部分甚至全部的人類工作。然而,人類在工作生涯會閱讀和聽到數(shù)以億計的文字,并且還能通過視覺、聽覺和嗅覺從周圍環(huán)境中獲取更多信息。要使 AI 達到人類智能水平,它們也需要具備處理如此大量信息的能力。
目前,處理大量信息的最流行 LLM 系統(tǒng)構建方法是“檢索增強生成”(RAG)。這類系統(tǒng)會尋找與用戶查詢相關的文檔,并將最相關的部分嵌入到 LLM 的上下文中。
盡管 RAG 系統(tǒng)在某些情況下能夠有超越傳統(tǒng)搜索引擎的表現(xiàn),但目前這類系統(tǒng)仍存在諸多不足。它們只有在成功將最關鍵的文檔嵌入 LLM 的上下文時,才能產出滿意的結果。然而,用于檢索這些文檔的技術,通常是在向量數(shù)據(jù)庫[4]中進行搜索 —— 并不夠精細。如果用戶提出的問題復雜或含糊不清,RAG 系統(tǒng)很可能會錯誤地檢索文檔,導致聊天機器人給出錯誤的回答。
此外,RAG 系統(tǒng)并未讓 LLM 在處理大量文檔時展現(xiàn)出更高級的推理能力:
- 例如,律師可能需要 AI 幫助審閱和總結數(shù)十萬封電子郵件。
- 工程師可能需要 AI 分析數(shù)千小時的工廠監(jiān)控視頻。
- 醫(yī)學研究者可能需要 AI 在數(shù)以萬計的患者病歷中識別趨勢。
這些任務任何一個都可能需要超過 200 萬個 tokens 的上下文處理能力。而且,我們希望 AI 系統(tǒng)在完成這些任務后,不是一切從頭開始,而是能夠像人類工作者一樣,通過經驗積累不斷提升。計算機的超強記憶力和耐力一直是其重要優(yōu)勢,在 AI 時代,我們并不想放棄這些特性。但目前 LLMs 在吸收和解讀大量信息的能力上,還遠未能達到人類水平。
確實,LLMs 在訓練過程中吸收的信息量遠遠超過了人類。最新的人工智能模型已經在數(shù)萬億個 tokens 上進行了訓練,這遠遠超過了一個人一生中所能閱讀或聽到的信息量。然而,許多有價值的資料是保密的、具有時效性的,或者因為其他原因無法用于訓練。
因此,我們希望 AI 模型在推理階段能夠閱讀并記住遠超 200 萬個 tokens 的信息。但這并非易事。
基于 transformer 的 LLMs 的核心創(chuàng)新在于“注意力”機制,這是一種數(shù)學運算,使得模型能夠“回顧”之前的 tokens。在 LLM 生成新 token 之前,它會執(zhí)行一次注意力操作,將當前 token 與之前的所有 tokens 進行比較。這導致傳統(tǒng)的 LLMs 在上下文增長時效率逐漸降低。
目前,許多人正在研究解決這一問題的方法,我將在本文后續(xù)部分討論其中的一些方案。但在此之前,我需要解釋一下我們是如何從一開始就形成了這樣一個復雜的架構。
01 GPUs 讓深度學習成為現(xiàn)實
個人電腦的核心 —— 中央處理單元(CPUs) ,曾是通過提高時鐘頻率來提升性能的。但進入 21 世紀初期,由于過熱問題,芯片制造商大多放棄了這種提速方法。
芯片制造商轉而開始研發(fā)能夠同時處理多個指令的 CPU[5]。然而,它們的進步受到了需要指令按順序執(zhí)行的傳統(tǒng)編程模式的限制。
為了充分發(fā)揮摩爾定律[6]的潛力,一種全新的架構應運而生,那就是 Nvidia 推出的 GPUs。
1999 年,Nvidia 開始銷售 GPU,旨在加快 3D 游戲如《Quake III Arena》的渲染速度。這些作為 PC 擴展卡的 GPU,任務是迅速繪制構成游戲中墻壁、武器、怪物等物體的成千上萬的三角形。
這種任務不需要順序編程:屏幕上不同區(qū)域的三角形可以任意順序繪制。因此,Nvidia 的首款 GPU[7] 并不是采用單個處理器逐個執(zhí)行指令,而是擁有十幾個專用核心 —— 類似于微型的CPU —— 它們并行作業(yè),共同繪制場景。
隨著摩爾定律的發(fā)展,Nvidia 制造的 GPU 計算核心數(shù)量從數(shù)十個增加到數(shù)百個,最終甚至達到數(shù)千個。人們逐漸意識到,GPU 強大的并行計算能力不僅可以用于視頻游戲,還能應用于其他領域。
2012 年,多倫多大學的計算機科學家 Alex Krizhevsky、Ilya Sutskever 和 Geoffrey Hinton 利用兩塊 Nvidia GTX 580 GPUs[8] 訓練了一個用于圖像識別的神經網絡。這兩塊 GPU 各自擁有 512 個核心,提供了巨大的計算能力,使他們能夠訓練出一個擁有 6000 萬個參數(shù)的神經網絡。他們在 ImageNet 圖像分類競賽[9]中取得了新的準確率紀錄[10](該競賽的目標是將圖像歸類到 1000 個不同類別之一)。
不久之后,研究人員開始將這些技術應用于更多領域,自然語言處理便是其中之一。
02 Transformers 打破了自然語言理解的瓶頸
在 2010 年代初,循環(huán)神經網絡(RNNs)是處理自然語言的主流架構。RNNs 采用逐詞處理的方式。神經網絡在處理完每個單詞后,會更新其隱藏狀態(tài)(hidden state),這是一組數(shù)字,代表了神經網絡對句子當前理解的程度。
RNNs 在處理短句時表現(xiàn)尚可,但面對長句時就顯得力不從心,更別提段落或更長的文本了。在分析長句時,RNN 有時會“遺忘”句子開頭的關鍵詞。 2014 年,計算機科學家 Dzmitry Bahdanau、KyungHyun Cho 和 Yoshua Bengio 發(fā)現(xiàn)[11],通過引入一個注意力機制,允許網絡“回顧”句子中的早期單詞,可以提升循環(huán)神經網絡的性能。
2017 年,谷歌發(fā)布了《Attention Is All You Need》[12]這篇論文,它被譽為機器學習史上最重要的論文之一。在 Bahdanau 及其團隊的研究基礎上,谷歌的研究人員摒棄了 RNN 及其隱藏狀態(tài)的概念。他們采用的模型利用注意力機制來掃描先前的單詞,以獲取相關的上下文信息。
這種被谷歌命名為 transformer 的新架構,其重要性不言而喻,因為它消除了擴展語言模型的一個關鍵障礙。
以下動畫展示了 RNNs 為何難以擴展:
在這個假想的 RNN2 中,神經網絡試圖預測句子中的下一個單詞,預測結果展示在圖表的頂部。這個神經網絡由三層組成,每層用一個矩形表示。它的處理方式是線性的:必須先完成對第一個單詞“How”的分析,然后將隱藏狀態(tài)傳遞回底層,網絡才能開始分析第二個單詞“are”。
這種限制在機器學習算法在 CPU 上運行時還不是大問題。但當人們開始利用 GPU 的并行計算能力時,RNN 的線性架構就成為了瓶頸。
transformer 通過讓神經網絡能夠同時“思考”輸入中的所有單詞,從而突破了這一限制:
如圖所示,基于 transformer 的模型進行的計算量與前面圖中的 RNN 模型相當。因此,在(單核)CPU 上,它的運行速度可能不會更快。但由于模型不需要在處理“are”、“you”或“doing”之前完成對“How”的分析,它可以同時處理這些單詞。這意味著在擁有多個并行執(zhí)行單元的 GPU 上,它的運行速度可以大幅提升。
速度提升有多大?速度的潛在提升與輸入單詞的數(shù)量成正比。以我的動畫為例,transformer 模型處理四詞輸入的速度大約是 RNN 的四倍。而對于 LLMs,其輸入可能包含數(shù)千個單詞。因此,在強大的 GPU 支持下,基于 transformer 的模型速度可以比類似的 RNN 快出幾個數(shù)量級。
你可能會問,為何不能同時用多個文檔來訓練RNN—— 即在文檔層面而非單個單詞層面利用 GPU 的并行處理能力。
這是因為訓練過程的第二階段——反向傳播的限制。在這個過程中,訓練軟件會“逆向”工作,通過微積分調整模型的參數(shù),以提高得出正確答案的概率。對于 RNN 來說,反向傳播需要從輸入的最后一個單詞反向追溯到第一個單詞。如下圖中紅色箭頭所示:
反向傳播需要保存前向傳遞中每一步的中間結果——也就是說,訓練軟件需要存儲圖表中每個矩形的輸出。對于大模型,這些數(shù)據(jù)占用的空間極大,以至于無法同時并行訓練大量實例。 3
簡而言之,transformer 釋放了 GPU 的全部處理能力,推動了語言模型規(guī)模的飛速增長。領先 LLMs 的參數(shù)量從 2018 年的數(shù)億[13]增長到了 2020 年的數(shù)千億[14]。由于傳統(tǒng)的基于 RNN 的模型受到線性架構的限制,它們無法在 GPU 上高效訓練,因此無法達到如此龐大的規(guī)模。
03 Transformers 模型存在擴展問題
我曾提到,在本文動畫中,循環(huán)神經網絡與 transformer 模型“大致完成了相等的工作量”。然而,兩者的工作量并非完全一致。我們再來看看 transformer 模型的工作圖:
注意到各層間那些交錯的對角線箭頭了嗎?它們代表了注意力機制的運轉。基于 transformer 的語言模型在創(chuàng)造新 token 前,會“審視”之前每一個已有的標記,以確定哪些最為相關。
在較小規(guī)模的上下文中,這些比較的成本微不足道。例如,對于僅有 10 個、100 個甚至 1000 個 tokens 的上下文,這些成本并不構成負擔。但隨著上下文長度的增加,注意力機制的計算成本也隨之攀升。上下文越長,為了生成下一個 token,所需的注意力操作(以及相應的計算資源)就越多。
這導致了一個問題:注意力機制總的計算能力需求與 tokens 總數(shù)成二次方關系增長。舉例來說,如果一個 10 個 tokens 的提示詞需要 414,720 次注意力操作4,那么:
- 處理一個 100 個 tokens 的提示詞,將需要 4560 萬次注意力操作。
- 處理一個 1000 個 tokens 的提示詞,將需要 46 億次注意力操作。
- 處理一個 10000 個 tokens 的提示詞,將需要 4600 億次注意力操作。
這或許也解釋了為何當上下文超過 128,000 個 tokens 時,谷歌會對 Gemini 1.5 Pro 的收費翻倍。因為生成第 128,001 個 token 時,需要與前面 128,000 個 tokens 進行比較,其成本遠高于生成第一個、第十個或第一百個 token。
04 提升注意力的效率和可擴展性
研究者們投入了大量精力優(yōu)化注意力機制。其中一條研究路徑旨在最大化單個 GPU 的運算效率。
我們在前文了解到,現(xiàn)代 GPU 包含了成千上萬的執(zhí)行單元。但在 GPU 開始進行數(shù)學運算之前,它需要將數(shù)據(jù)從較慢的共享內存(即高帶寬內存)轉移到特定執(zhí)行單元內更快的內存(即SRAM)。有時,GPU 在移動數(shù)據(jù)上耗費的時間甚至超過了執(zhí)行計算的時間。
在一系列論文中[15][16][17],普林斯頓大學的計算機科學家 Tri Dao 及其合作者開發(fā)了 FlashAttention,這種計算注意力的方式能夠最大限度地減少慢速內存操作的需求。Dao 等人的工作顯著提升了現(xiàn)代 GPU 上 transformers 的表現(xiàn)。
另一條研究路徑則著眼于如何在多個 GPU 上高效擴展注意力。其中一篇被廣泛引用的論文介紹了環(huán)形注意力機制(ring attention)[18],它通過將 input tokens 分成塊,并將每個塊分配給不同的 GPU 來工作。之所以稱為環(huán)形注意力,是因為 GPU 被構想為一個環(huán)形結構,每個 GPU 將其數(shù)據(jù)傳遞給相鄰的 GPU。
這讓我想起了曾參加過的一堂交誼舞課,舞伴們圍成一圈,女性保持不動,而男性則輪換舞伴。最終,每個男性都能與每位女性共舞。環(huán)形注意力的原理與之類似。"女性"代表查詢(query)向量(描述每個 token 所“尋找”的內容),"男性"代表鍵(key)向量(描述每個 token 的特征)。鍵向量在一連串 GPU 中傳遞,依次與所有查詢向量相乘。
總的來說,環(huán)形注意力機制通過在多個 GPU 間分配計算任務,使得大語言模型(LLM)能夠處理更大的上下文窗口。然而,它并未降低單個注意力計算的成本。
05 RNN 能否卷土重來?
由于 RNN 擁有固定大小的隱藏狀態(tài)(hidden state),因此它不會存在與 transformer 相同的擴展難題。無論是生成第一個、第一百個還是第一百萬個 token,RNN 所需的計算資源都相差無幾。這一點,相較于基于注意力機制的模型,RNN 具有顯著優(yōu)勢。
盡管在 transformer 問世后,RNN 的地位有所下滑,但研究者們并未放棄,他們繼續(xù)探索適合在現(xiàn)代 GPU 上訓練的 RNN 新版本。
今年 4 月,谷歌推出了一款名為 Infini-attention[19] 的新模型。這個模型可謂是 transformer 與 RNN 的“混血兒”。Infini-attention 像傳統(tǒng) transformer 那樣處理最近的 tokens,利用注意力機制記住它們并召回它們。
不過,Infini-attention 并未試圖記住所有上下文中的 tokens。相反,它采用一種“壓縮記憶(compressive memory)”來存儲較舊的 tokens,這種方式與 RNN 的隱藏狀態(tài)有幾分相似。這種數(shù)據(jù)結構能夠完美地存儲和召回少量 tokens,但隨著 tokens 數(shù)量的增加,召回率也會越來越低。
然而,機器學習領域的 YouTube 紅人 Yannic Kilcher 對谷歌的這種做法并不感冒[20]。
“我非常愿意相信這個方法確實有效,也認同這是實現(xiàn)無限注意力的一種途徑,但我還是持懷疑態(tài)度,”Kilcher表示。“它使用的是一種邊走邊存的壓縮記憶方法,并沒有真正學會如何存儲,只是按照一種確定性的方式在存儲,這意味著我們對存儲的內容和方式幾乎沒有控制權?!?/p>
06 Mamba 會是未來嗎?
在復興循環(huán)神經網絡(RNN)的眾多嘗試中,Mamba 架構無疑是最引人注目的。它是 2023 年 12 月發(fā)表的一篇論文[21]中公布的一種架構,其開發(fā)者是計算機科學家 Tri Dao(他也是我之前提到的 FlashAttention 的負責人)和 Albert Gu。
與傳統(tǒng)的 RNN 一樣,Mamba 并不依賴于注意力機制。它擁有一個充當“記憶”角色的隱藏狀態(tài)。由于這個隱藏狀態(tài)的大小是固定的,因此即使輸入的提示詞更長,也不會增加 Mamba 處理每個 token 的成本。
我在三月著手撰寫這篇文章時,本打算深入剖析 Mamba 的架構。然而,到了五月,研究團隊推出了 Mamba-2[22],其架構較之初代 Mamba 有了顯著的改變。坦白說,我一直在努力理解初代 Mamba 的原理,而對于 Mamba-2 的工作機制更是尚未完全弄清。
然而,我們需要明白的是,Mamba 有潛力將 transformer 模型的性能和傳統(tǒng) RNN 的效率結合起來。
在六月,Dao 和 Gu 與 Nvidia 的研究人員合作發(fā)表了一篇論文[23],對擁有 80 億參數(shù)的 Mamba 模型進行了評估。研究發(fā)現(xiàn),Mamba 模型在多項任務中與同等規(guī)模的模型不相上下,但在“上下文學習”和“從上下文中提取信息”的能力上,Mamba 模型仍略遜一籌。
transformer 模型之所以擅長信息提取,是因為它們能夠“記住”上下文中的每一個 token —— 這也是為什么隨著上下文長度的增加,transformer 模型的效率會降低。而 Mamba 則試圖將整個上下文壓縮到一個固定大小的狀態(tài)中,這意味著在處理長上下文時,它不得不舍棄一部分信息。
Nvidia 團隊發(fā)現(xiàn),通過采用一種混合架構,該架構將 24 個 Mamba 層與 4 個注意力層交錯排列,他們獲得了最佳性能。這種混合架構的表現(xiàn)優(yōu)于單純的 transformer 模型或單純的 Mamba 模型。
模型需要一些注意力層來記住其早期上下文中的關鍵細節(jié)。但是,似乎只需要少量的注意力層就夠了;其余的注意力層可以由成本更低的 Mamba 層替換,而對模型的整體性能影響很小。
在八月,一家名為 AI21 的以色列初創(chuàng)公司發(fā)布了其 Jamba 1.5 系列模型[24]。其中最大版本的參數(shù)數(shù)量達到了 3980 億,使其在規(guī)模上與 Meta 的 Llama 405B 模型相當。Jamba 1.5 Large 模型的 Mamba 層數(shù)量是注意力層的七倍。因此,Jamba 1.5 Large 所需的內存遠少于 Meta 和其他公司的同類模型。例如,AI21 估計 Llama 3.1 70B 需要 80 GB 的內存來跟蹤 256,000 個上下文 token ,而 Jamba 1.5 Large 只需要 9 GB,這使得模型能夠在性能較弱的硬件上運行。
Jamba 1.5 Large 模型的 MMLU 得分為 80,顯著低于 Llama 3.1 70B 的 86 分。因此,按照這個標準,Mamba 并沒有完全超越 transformer 模型。然而,這可能并不是一個完全公平的比較。像 Meta 這樣的前沿實驗室在訓練數(shù)據(jù)和后訓練基礎設施上投入了大量資金,以在 MMLU 等基準測試中提高幾個百分點的性能。同樣的高強度優(yōu)化可能會縮小 Jamba 與前沿模型之間的差距。
因此,雖然更長上下文窗口的好處顯而易見,但達到這一目標的最優(yōu)策略尚不明確。短期內, AI 公司可能會繼續(xù)使用巧妙的效率和擴展技巧(如 FlashAttention 和 Ring Attention)來擴展標準的 LLMs。長期來看,我們可能會看到對 Mamba 以及其他無注意力架構的興趣日益增長?;蛘咭苍S有人會提出一種全新的架構,使 transformers 過時。
但我確信,僅僅依靠擴大基于 transformers 的前沿模型規(guī)模并不是一個完整的解決方案。如果我們想要能夠處理數(shù)十億個 tokens 的模型——許多人都有這樣的需求,我們就需要跳出固有的思維模式,尋找新的方法。
1 有網絡消息指出,ChatGPT 起初設定的上下文窗口為 4,096 個 tokens,但發(fā)布后不久的一次實驗[25]顯示,它能夠記憶超過這個數(shù)量的信息。
2 在十年前,循環(huán)神經網絡(RNN)通常會包含編碼器和解碼器兩部分,而像 GPT-3 這樣的現(xiàn)代大語言模型(LLM)則只有解碼器。出于教學目的,我展示了一個與歷史不符的、僅包含解碼器的 RNN 模型,這樣可以更容易地與 GPT-3 等現(xiàn)代 LLM 進行對比。同樣的分析方法也適用于 2010 年代初的真實 RNN 模型,但那時的模型圖會更加復雜。
3 GPU 能夠將中間計算結果傳輸?shù)剿拇笕萘扛邘拑却嬷?。但是,由于高帶寬內存(HBM)[26]的速度限制,這一操作并不會提升訓練速度。
4 這是我針對擁有 1750 億參數(shù)版本的 GPT-3 的一個初步估算,該模型包含 96 層,每層有 96 個注意力頭。因此,實際上每對 tokens 之間需要進行 9,216 次注意力計算。
5 Jamba 模型是一種混合專家模型,這意味著對于任何一個 token,只有網絡中的一部分(980 億參數(shù)中的 3980 億)會被激活和使用。
Thanks for reading!
Hope you have enjoyed and learned new things from this blog!
About the author
Timothy B. Lee
I write the newsletter Understanding AI and cohost the AI Summer podcast. Previously I was a reporter at Ars Technica, Vox, and the Washington Post. ??twitter.com/binarybits??
END
本期互動內容 ??
?在您看來,長上下文處理能力對 AI 發(fā)展意味著什么?人類和 AI 在信息處理能力上最大的差距是什么?
??文中鏈接??
[1]??https://platform.openai.com/docs/models/gp??
[2]??https://www.anthropic.com/news/claude-3-5-sonnet??
[3]??https://blog.google/technology/ai/google-gemini-update-flash-ai-assistant-io-2024/??
[4]??https://en.wikipedia.org/wiki/Vector_database??
[5]??https://en.wikipedia.org/wiki/Multithreading_(computer_architecture)??
[6]??
[7]??https://en.wikipedia.org/wiki/GeForce_256??
[8]??https://www.techpowerup.com/gpu-specs/geforce-gtx-580.c270??
[9]??https://www.understandingai.org/p/why-the-deep-learning-boom-caught??
[10]??https://arstechnica.com/science/2018/12/how-computers-got-shockingly-good-at-recognizing-images/3/??
[11]??https://arxiv.org/pdf/1409.0473??
[12]??https://arxiv.org/abs/1706.03762??
[13]??https://en.wikipedia.org/wiki/GPT-1??
[14]??https://en.wikipedia.org/wiki/GPT-3??
[15]??https://arxiv.org/abs/2205.14135??
[16]??https://arxiv.org/abs/2307.08691??
[17]??https://arxiv.org/abs/2407.08608??
[18]??https://arxiv.org/abs/2310.01889??
[19]??https://arxiv.org/abs/2404.07143??
[20]??https://www.youtube.com/watch?v=r_UBBfTPcF0&t=2s??
[21]??https://arxiv.org/abs/2312.00752??
[22]??https://arxiv.org/abs/2405.21060??
[23]??https://arxiv.org/abs/2406.07887??
[24]??https://arxiv.org/abs/2408.12570??
[25]??https://x.com/goodside/status/1598874674204618753??
[26]??https://en.wikipedia.org/wiki/High_Bandwidth_Memory??
原文鏈接:
??https://www.understandingai.org/p/why-large-language-models-struggle??
