ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制
本文作者李宏康,美國倫斯勒理工大學(xué)電氣、計(jì)算機(jī)與系統(tǒng)工程系在讀博士生,本科畢業(yè)于中國科學(xué)技術(shù)大學(xué)。研究方向包括深度學(xué)習(xí)理論,大語言模型理論,統(tǒng)計(jì)機(jī)器學(xué)習(xí)等等。目前已在 ICLR/ICML/Neurips 等 AI 頂會(huì)發(fā)表多篇論文。
上下文學(xué)習(xí) (in-context learning, 簡(jiǎn)寫為 ICL) 已經(jīng)在很多 LLM 有關(guān)的應(yīng)用中展現(xiàn)了強(qiáng)大的能力,但是對(duì)其理論的分析仍然比較有限。人們依然試圖理解為什么基于 Transformer 架構(gòu)的 LLM 可以展現(xiàn)出 ICL 的能力。
近期,一個(gè)來自美國倫斯勒理工大學(xué)和 IBM 研究院的團(tuán)隊(duì)從優(yōu)化和泛化理論的角度分析了帶有非線性注意力模塊 (attention) 和多層感知機(jī) (MLP) 的 Transformer 的 ICL 能力。他們特別從理論端證明了單層 Transformer 首先在 attention 層根據(jù) query 選擇一些上下文示例,然后在 MLP 層根據(jù)標(biāo)簽嵌入進(jìn)行預(yù)測(cè)的 ICL 機(jī)制。該文章已收錄在 ICML 2024。
- 論文題目:How Do Nonlinear Transformers Learn and Generalize in In-Context Learning?
- 論文地址:https://arxiv.org/pdf/2402.15607
背景介紹
上下文學(xué)習(xí) in context learning (ICL)
上下文學(xué)習(xí) (ICL) 是一種新的學(xué)習(xí)范式,在大語言模型 (LLM) 中非常流行。它具體是指在測(cè)試查詢 (testing query)
前添加 N 個(gè)測(cè)試樣本 testing examples (上下文),即測(cè)試輸入
和測(cè)試輸出
的組合,從而構(gòu)成一個(gè) testing prompt:
,作為模型的輸入以引導(dǎo)模型作出正確的推斷。這種方式不同于經(jīng)典的對(duì)預(yù)訓(xùn)練模型進(jìn)行微調(diào)的方式,它不需要改變模型的權(quán)重,從而更加的高效。
ICL 理論工作的進(jìn)展
近期的很多理論工作都是基于 [1] 所提出的研究框架,即人們可以直接使用 prompt 的格式來對(duì) Transformer 進(jìn)行訓(xùn)練 (這一步也可以理解為在模擬一種簡(jiǎn)化的 LLM 預(yù)訓(xùn)練模式),從而使得模型具有 ICL 能力。已有的理論工作聚焦于模型的表達(dá)能力 (expressive power) 的角度 [2]。他們發(fā)現(xiàn),人們能夠找到一個(gè)有著 “完美” 的參數(shù)的 Transformer 可以通過前向運(yùn)算執(zhí)行 ICL,甚至隱含地執(zhí)行梯度下降等經(jīng)典機(jī)器學(xué)習(xí)算法。但是這些工作無法回答為什么 Transformer 可以被訓(xùn)練成這樣 “完美” 的,具有 ICL 能力的參數(shù)。因此,還有一些工作試圖從 Transformer 的訓(xùn)練或泛化的角度理解 ICL 機(jī)制 [3,4]。不過,受制于分析 Transformer 結(jié)構(gòu)的復(fù)雜性,這些工作目前止步于研究線性回歸任務(wù),而所考慮的模型通常會(huì)略去 Transformer 中的非線形部分。
本文從優(yōu)化和泛化理論的角度分析了帶有非線性 attention 和 MLP 的 Transformer 的 ICL 能力和機(jī)制:
- 基于一個(gè)簡(jiǎn)化的分類模型,本文具體量化了數(shù)據(jù)的特征如何影響了一層單頭 Transformer 的域內(nèi) (in-domain) 和域外 (out-of-domain, OOD) 的 ICL 泛化能力。
- 本文進(jìn)一步闡釋了 ICL 是如何通過被訓(xùn)練的 Transformer 來實(shí)現(xiàn)了。
- 基于被訓(xùn)練的 Transformer 的特點(diǎn),本文還分析了在 ICL 推斷的時(shí)候使用基于幅值的模型剪枝 (magnitude-based pruning) 的可行性。
理論部分
問題描述
本文考慮一個(gè)二分類問題,即將
通過一個(gè)任務(wù)
映射到
。為了解決這樣的一個(gè)問題,本文構(gòu)建了 prompt 來進(jìn)行學(xué)習(xí)。這里的 prompt 被表示為:
訓(xùn)練網(wǎng)絡(luò)為一個(gè)單層單頭 Transformer:
預(yù)訓(xùn)練過程是求解一個(gè)對(duì)所有訓(xùn)練任務(wù)的經(jīng)驗(yàn)風(fēng)險(xiǎn)最小化 (empirical risk minimization)。損失函數(shù)使用的是適合二分類問題的 Hinge loss,訓(xùn)練算法是隨機(jī)梯度下降。
本文定義了兩種 ICL 泛化的情況。一個(gè)是 in-domain 的,即泛化的時(shí)候測(cè)試數(shù)據(jù)的分布和訓(xùn)練數(shù)據(jù)一樣,注意這個(gè)情況里面測(cè)試任務(wù)不必和訓(xùn)練任務(wù)一樣,即這里已經(jīng)考慮了對(duì)未見任務(wù) (unseen task) 的泛化。另一個(gè)是 out-of-domain 的,即測(cè)試、訓(xùn)練數(shù)據(jù)分布不一樣。
本文還涉及了在 ICL 推斷的時(shí)候進(jìn)行 magnitude-based pruning 的分析,這里的剪枝方式是指對(duì)于訓(xùn)練得到的中的各個(gè)神經(jīng)元,根據(jù)其幅值大小,進(jìn)行從小到大的刪除。
對(duì)數(shù)據(jù)和任務(wù)的構(gòu)建
這一部分請(qǐng)參考原文的 Section 3.2,這里只做一個(gè)概述。本文的理論分析是基于最近比較火熱的 feature learning 路線,即通常將數(shù)據(jù)假設(shè)為可分(通常是正交)的 pattern,從而推導(dǎo)出基于不同 pattern 的梯度變化。本文首先定義了一組 in-domain-relevant (IDR) pattern 用于決定 in-domain 任務(wù)的分類,和一組與任務(wù)無關(guān)的 in-domain-irrelevant (IDI) pattern,這些 pattern 之間互相正交。IDR pattern 有
個(gè),IDI pattern 有
個(gè)。一個(gè)
被表示為一個(gè) IDR pattern 和一個(gè) IDI pattern 的和。一個(gè) in-domain 任務(wù)就被定義為基于某兩個(gè) IDR pattern 的分類問題。
類似地,本文通過定義 out-of-domain-relevant (ODR) pattern 和 out-of-domain-irrelevant (ODI) pattern,可以刻畫 OOD 泛化時(shí)候的數(shù)據(jù)和任務(wù)。
本文對(duì) prompt 的表示可以用下圖的例子來闡述,其中
是 IDR pattern,
是 IDI pattern。這里在做的任務(wù)是基于 x 中的
做分類,如果是
那么其標(biāo)簽為 + 1,對(duì)應(yīng)于 +q,如果是
那么其標(biāo)簽為 - 1,對(duì)應(yīng)于 -q。α,α' 分別被定義為訓(xùn)練和測(cè)試 prompt 中跟 query 的 IDR/ODR pattern 一樣的上下文示例。下圖中的例子里面,
。
理論結(jié)果
首先,對(duì)于 in-domain 的情況,本文先給了一個(gè) condition 3.2 來規(guī)定訓(xùn)練任務(wù)需要滿足的條件,即訓(xùn)練任務(wù)需要覆蓋所有的 IDR pattern 和標(biāo)簽。然后 in-domain 的結(jié)果如下:
這里表明:1,訓(xùn)練任務(wù)的數(shù)量只需要在全部任務(wù)中占比達(dá)到滿足 condition 3.2 的小比例,我們就可以對(duì) unseen task 實(shí)現(xiàn)很好的泛化;2,跟當(dāng)前任務(wù)相關(guān)的 IDR pattern 在 prompt 中的比例越高,就可以以更少的訓(xùn)練數(shù)據(jù),訓(xùn)練迭代次數(shù),以及更短的 training/testing prompt 實(shí)現(xiàn)理想的泛化。
接下來是 out-of-domain 泛化的結(jié)果。
這里說明,如果 ODR pattern 是 IDR pattern 的線性組合且系數(shù)和大于 1,那么此時(shí) OOD ICL 泛化可以達(dá)到理想的效果。這個(gè)結(jié)果給出了在 ICL 的框架下,好的 OOD 泛化所需要的訓(xùn)練和測(cè)試數(shù)據(jù)之間的內(nèi)在聯(lián)系。該定理也通過 GPT-2 的實(shí)驗(yàn)得到了驗(yàn)證。如下圖所示,當(dāng) (12) 中的系數(shù)和
大于 1 的時(shí)候,OOD 分類可以達(dá)到理想的結(jié)果。與此同時(shí),當(dāng)
,即 prompt 中和分類任務(wù)相關(guān)的 ODR/IDR pattern 比例越高的時(shí)候,所需要的 context 長(zhǎng)度越小。
然后,本文給出了帶有 magnitude-based pruning 的 ICL 泛化結(jié)果。
這個(gè)結(jié)果表明,首先,訓(xùn)練得到的
中有一部分(常數(shù)比例)神經(jīng)元的幅值很小,而剩下的相對(duì)比較大(公式 14)。當(dāng)我們只枝剪小神經(jīng)元的時(shí)候,對(duì)泛化結(jié)果基本沒有影響,而當(dāng)枝剪比例增加到要剪大神經(jīng)元的時(shí)候,泛化誤差會(huì)隨之顯著變大(公式 15,16)。以下實(shí)驗(yàn)驗(yàn)證了定理 3.7。下圖 A 中淺藍(lán)色的豎線表示訓(xùn)練得到的
呈現(xiàn)出了公式 14 的結(jié)果。而對(duì)小神經(jīng)元進(jìn)行枝剪不會(huì)使泛化變差,這個(gè)結(jié)果符合理論。圖 B 反映出當(dāng) prompt 中和任務(wù)相關(guān)的上下文越多的時(shí)候,我們可以允許更大的枝剪比例以達(dá)到相同的泛化性能。
ICL 機(jī)制
通過對(duì)預(yù)訓(xùn)練過程的刻畫,本文得到了單層單頭非線性 Transformer 做 ICL 的內(nèi)在機(jī)制,這一部分在原文的 Section 4。該過程可以用下圖表示。
簡(jiǎn)而言之,attention 層會(huì)選擇和 query 的 ODR/IDR pattern 一樣的上下文,賦予它們幾乎全部 attention 權(quán)重,然后 MLP 層會(huì)重點(diǎn)根據(jù) attention 層輸出中的標(biāo)簽嵌入來作出最后的分類。
總結(jié)
本文講解了在 ICL 當(dāng)中,非線性 Transformer 的訓(xùn)練機(jī)制,以及對(duì)于新任務(wù)和分布偏移數(shù)據(jù)的泛化能力。理論結(jié)果對(duì)于設(shè)計(jì) prompt 選擇算法和 LLM 剪枝算法有一定實(shí)際意義。
本文轉(zhuǎn)自 機(jī)器之心 ,作者:機(jī)器之心
