Bengio等人新作:注意力可被視為RNN,新模型媲美Transformer,但超級省內存
序列建模的進展具有極大的影響力,因為它們在廣泛的應用中發揮著重要作用,包括強化學習(例如,機器人和自動駕駛)、時間序列分類(例如,金融欺詐檢測和醫學診斷)等。
在過去的幾年里,Transformer 的出現標志著序列建模中的一個重大突破,這主要得益于 Transformer 提供了一種能夠利用 GPU 并行處理的高性能架構。
然而,Transformer 在推理時計算開銷很大,主要在于內存和計算需求呈二次擴展,從而限制了其在低資源環境中的應用(例如,移動和嵌入式設備)。盡管可以采用 KV 緩存等技術提高推理效率,但 Transformer 對于低資源領域來說仍然非常昂貴,原因在于:(1)隨 token 數量線性增加的內存,以及(2)緩存所有先前的 token 到模型中。在具有長上下文(即大量 token)的環境中,這一問題對 Transformer 推理的影響更大。
為了解決這個問題,加拿大皇家銀行 AI 研究所 Borealis AI、蒙特利爾大學的研究者在論文《Attention as an RNN 》中給出了解決方案。值得一提的是,我們發現圖靈獎得主 Yoshua Bengio 出現在作者一欄里。
- 論文地址:https://arxiv.org/pdf/2405.13956
- 論文標題:Attention as an RNN?
具體而言,研究者首先檢查了 Transformer 中的注意力機制,這是導致 Transformer 計算復雜度呈二次增長的組件。該研究表明注意力機制可以被視為一種特殊的循環神經網絡(RNN),具有高效計算的多對一(many-to-one)RNN 輸出的能力。利用注意力的 RNN 公式,該研究展示了流行的基于注意力的模型(例如 Transformer 和 Perceiver)可以被視為 RNN 變體。
然而,與 LSTM、GRU 等傳統 RNN 不同,Transformer 和 Perceiver 等流行的注意力模型雖然可以被視為 RNN 變體。但遺憾的是,它們無法高效地使用新 token 進行更新。
為了解決這個問題,該研究引入了一種基于并行前綴掃描(prefix scan)算法的新的注意力公式,該公式能夠高效地計算注意力的多對多(many-to-many)RNN 輸出,從而實現高效的更新。
在此新注意力公式的基礎上,該研究提出了 Aaren([A] ttention [a] s a [re] current neural [n] etwork),這是一種計算效率很高的模塊,不僅可以像 Transformer 一樣并行訓練,還可以像 RNN 一樣高效更新。
實驗結果表明,Aaren 在 38 個數據集上的表現與 Transformer 相當,這些數據集涵蓋了四種常見的序列數據設置:強化學習、事件預測、時間序列分類和時間序列預測任務,同時在時間和內存方面更加高效。
方法介紹
為了解決上述問題,作者提出了一種基于注意力的高效模塊,它能夠利用 GPU 并行性,同時又能高效更新。
首先,作者在第 3.1 節中表明,注意力可被視為一種 RNN,具有高效計算多對一 RNN(圖 1a)輸出的特殊能力。利用注意力的 RNN 形式,作者進一步說明,基于注意力的流行模型,如 Transformer(圖 1b)和 Perceiver(圖 1c),可以被視為 RNN。然而,與傳統的 RNN 不同的是,這些模型無法根據新 token 有效地更新自身,從而限制了它們在數據以流的形式到達的序列問題中的潛力。
為了解決這個問題,作者在第 3.2 節中介紹了一種基于并行前綴掃描算法的多對多 RNN 計算注意力的高效方法。在此基礎上,作者在第 3.3 節中介紹了 Aaren—— 一個計算效率高的模塊,它不僅可以并行訓練(就像 Transformer),還可以在推理時用新 token 高效更新,推理只需要恒定的內存(就像傳統 RNN)。
將注意力視為一個多對一 RNN
查詢向量 q 的注意力可被視為一個函數,它通過 N 個上下文 token x_1:N 的鍵和值
將其映射到單一輸出 o_N = Attention (q, k_1:N , v_1:N ) 。給定 s_i = dot (q,k_i),輸出 o_N 可表述為:
,分母為
。將注意力視為 RNN,可以在 k = 1,...,...... 時,以滾動求和的方式迭代計算
和
。然而,在實踐中,這種實現方式并不穩定,會因有限的精度表示和可能非常小或非常大的指數(即 exp (s))而遇到數值問題。為了緩解這一問題,作者用累積最大值項
來重寫遞推公式,計算
和
。值得注意的是,最終結果是相同的
,m_k 的循環計算如下:
計算注意力的方法:通過將注意力視為一個 RNN,可以看到計算注意力的不同方法:在 O (1) 內存中逐個 token 循環計算(即順序計算);或以傳統方式計算(即并行計算),需要線性 O (N) 內存。由于注意力可以被看作是一個 RNN,因此計算注意力的傳統方法也可以被看作是計算注意力多對一 RNN 輸出的高效方法,即 RNN 的輸出以多個上下文 token 為輸入,但在 RNN 結束時只輸出一個 token(見圖 1a)。最后,也可以將注意力計算為一個逐塊處理 token 的 RNN,而不是完全按順序或完全并行計算,這需要 O (b) 內存,其中 b 是塊的大小。
將現有的注意力模型視為 RNN。通過將注意力視為 RNN,現有的基于注意力的模型也可以被視為 RNN 的變體。例如,Transformer 的自注意力是 RNN(圖 1b),上下文 token 是其初始隱藏狀態。Perceiver 的交叉注意力是 RNN(圖 1c),其初始隱藏狀態是與上下文相關的潛變量。通過利用其注意力機制的 RNN 形式,這些現有模型可以高效地計算其輸出存儲。
然而,當將現有的基于注意力的模型(如 Transformers)視為 RNN 時,這些模型又缺乏傳統 RNN(如 LSTM 和 GRU)中常見的重要屬性。
值得注意的是,LSTM 和 GRU 能夠僅在 O (1) 常量內存和計算中使用新 token 有效地更新自身,相比之下, Transformer 的 RNN 視圖(見圖 1b)會通過將一個新的 token 作為初始狀態添加一個新的 RNN 來處理新 token。這個新的 RNN 處理所有先前的 token,需要 O (N) 的線性計算量。
在 Perceiver 中,由于其架構的原因,潛變量(圖 1c 中的 L_i)是依賴于輸入的,這意味著它們的值在接收新 token 時會發生變化。由于其 RNN 的初始隱藏狀態(即潛變量)發生變化,Perceiver 因此需要從頭開始重新計算其 RNN,需要 O (NL) 的線性計算量,其中 N 是 token 的數量,L 是潛變量的數量。
將注意力視為一個多對多 RNN
針對這些局限性,作者建議開發一種基于注意力的模型,利用 RNN 公式的能力來執行高效更新。為此,作者首先引入了一種高效的并行化方法,將注意力作為多對多 RNN 計算,即并行計算
的方法。為此,作者利用并行前綴掃描算法(見算法 1),這是一種通過關聯算子 ⊕ 從 N 個連續數據點計算 N 個前綴的并行計算方法。該算法可高效計算
回顧
,其中
,
為了高效計算
,可以通過并行掃描算法計算
和
,然后結合 a_k 和 c_k 計算
。
為此,作者提出了以下關聯算子⊕,該算子作用于形式為(m_A、u_A、w_A)的三元組,其中 A 是一組索引,
,
,
。并行掃描算法的輸入為
。該算法遞歸應用算子 ⊕,其工作原理如下:
,其中,
,
。
在完成遞歸應用算子后,算法輸出
。也被稱作
。結合輸出元組的最后兩個值,檢索
從而產生一種高效的并行方法,將注意力計算為多對多 RNN(圖 3)。
Aaren:[A] ttention [a] s a [re] current neural [n] etwork
Aaren 的接口與 Transformer 相同,即將 N 個輸入映射到 N 個輸出,而第 i 個輸出是第 1 到第 i 個輸入的聚合。此外,Aaren 還自然可堆疊,并且能夠計算每個序列 token 的單獨損失項。然而,與使用因果自注意力的 Transformers 不同,Aaren 使用上述計算注意力的方法作為多對多 RNN,使其更加高效。Aaren 形式如下:
與 Transformer 不同,在 Transformer 中查詢是輸入到注意力的 token 之一,而在 Aaren 中,查詢 token q 是在訓練過程中通過反向傳播學習得到的。
下圖展示了一個堆疊 Aaren 模型的例子,該模型的輸入上下文 token 為 x_1:3,輸出為 y_1:3。值得注意的是,由于 Aaren 利用了 RNN 形式的注意力機制,堆疊 Aarens 也相當于堆疊 RNN。因此,Aarens 也能夠高效地用新 token 進行更新,即 y_k 的迭代計算僅需要常量計算,因為它僅依賴于 h_k-1 和 x_k。
基于 Transformer 的模型需要線性內存(使用 KV 緩存時)并且需要存儲所有先前的 token ,包括中間 Transformer 層中的那些,但基于 Aaren 的模型只需要常量內存,并且不需要存儲所有先前的 token ,這使得 Aarens 在計算效率上顯著優于 Transformer。
實驗
實驗部分的目標是比較 Aaren 和 Transformer 在性能和所需資源(時間和內存)方面的表現。為了進行全面比較,作者在四個問題上進行了評估:強化學習、事件預測、時間序列預測和時間序列分類。
強化學習
作者首先比較了 Aaren 和 Transformer 在強化學習方面的表現。強化學習在機器人、推薦引擎和交通控制等交互式環境中很受歡迎。
表 1 中的結果表明,在所有 12 個數據集和 4 種環境中,Aaren 與 Transformer 的性能都不相上下。不過,與 Transformer 不同的是,Aaren 也是一種 RNN,因此能夠在持續計算中高效處理新的環境交互,從而更適合強化學習。
事件預測
接下來,作者比較了 Aaren 和 Transformer 在事件預測方面的表現。事件預測在許多現實環境中都很流行,例如金融(如交易)、醫療保健(如患者觀察)和電子商務(如購買)。
表 2 中的結果顯示,Aaren 在所有數據集上的表現都與 Transformer 相當。Aaren 能夠高效處理新輸入,這在事件預測環境中尤為有用,因為在這種環境中,事件會以不規則流的形式出現。
時間序列預測
然后,作者比較了 Aaren 和 Transformer 在時間序列預測方面的表現。時間序列預測模型通常用在與氣候(如天氣)、能源(如供需)和經濟(如股票價格)相關的領域。
表 3 中的結果顯示,在所有數據集上,Aaren 與 Transformer 的性能相當。不過,與 Transformer 不同的是,Aaren 能高效處理時間序列數據,因此更適合與時間序列相關的領域。
時間序列分類
接下來,作者比較了 Aaren 和 Transformer 在時間序列分類方面的表現。時間序列分類在許多重要的應用中很常見,例如模式識別(如心電圖)、異常檢測(如銀行欺詐)或故障預測(如電網波動)。
從表 4 中可以看出,在所有數據集上,Aaren 與 Transformer 的表現不相上下。
分析
最后,作者比較了 Aaren 和 Transformer 所需的資源。
內存復雜性:在圖 5(左)中,作者比較了 Aaren 和 Transformer(使用 KV 緩存)在推理時的內存使用情況。可以看到,伴隨 KV 緩存技術的使用,Transformer 的內存使用量呈線性增長。相比之下,Aaren 只使用恒定的內存,無論 token 數量如何增長,因此它的效率要高得多。
時間復雜度:在圖 5(右圖)中,作者比較了 Aaren 和 Transformer(使用 KV 緩存)按順序處理一串 token 所需的累計時間。對于 Transformer,累計計算量是 token 數的二次方,即 O (1 + 2 + ... + N) = O (N^2 )。相比之下,Aaren 的累計計算量是線性的。在圖中,可以看到模型所需的累計時間也是類似的結果。具體來說,Transformer 所需的累計時間呈二次增長,而 Aaren 所需的累計時間呈線性增長。
參數數量:由于要學習初始隱藏狀態 q,Aaren 模塊需要的參數略多于 Transformer 模塊。不過,由于 q 只是一個向量,因此差別不大。通過在同類模型中進行實證測量,作者發現 Transformer 使用了 3, 152, 384 個參數。相比之下,等效的 Aaren 使用了 3, 152, 896 個參數,參數增加量僅為 0.016%—— 對于內存和時間復雜性的顯著差異來說,這只是微不足道的代價。
本文轉自 機器之心 ,作者:機器之心
