Yann LeCun團隊新研究成果:對自監督學習逆向工程,原來聚類是這樣實現的
自監督學習(SSL)在最近幾年取得了很大的進展,在許多下游任務上幾乎已經達到監督學習方法的水平。但是,由于模型的復雜性以及缺乏有標注訓練數據集,我們還一直難以理解學習到的表征及其底層的工作機制。此外,自監督學習中使用的 pretext 任務通常與特定下游任務的直接關系不大,這就進一步增大了解釋所學習到的表征的復雜性。而在監督式分類中,所學到的表征的結構往往很簡單。
相比于傳統的分類任務(目標是準確將樣本歸入特定類別),現代 SSL 算法的目標通常是最小化包含兩大成分的損失函數:一是對增強過的樣本進行聚類(不變性約束),二是防止表征坍縮(正則化約束)。舉個例子,對于同一樣本經過不同增強之后的數據,對比式學習方法的目標是讓這些樣本的分類結果一樣,同時又要能區分經過增強之后的不同樣本。另一方面,非對比式方法要使用正則化器(regularizer)來避免表征坍縮。
自監督學習可以利用輔助任務(pretext)無監督數據中挖掘自身的監督信息,通過這種構造的監督信息對網絡進行訓練,從而可以學習到對下游任務有價值的表征。近日,圖靈獎得主 Yann LeCun 在內的多位研究者發布了一項研究,宣稱對自監督學習進行了逆向工程,讓我們得以了解其訓練過程的內部行為。
論文地址:https://arxiv.org/abs/2305.15614v2
這篇論文通過一系列精心設計的實驗對使用 SLL 的表征學習進行了深度分析,幫助人們理解訓練期間的聚類過程。具體來說,研究揭示出增強過的樣本會表現出高度聚類的行為,這會圍繞共享同一圖像的增強樣本的含義嵌入形成質心。更出人意料的是,研究者觀察到:即便缺乏有關目標任務的明確信息,樣本也會根據語義標簽發生聚類。這表明 SSL 有能力根據語義相似性對樣本進行分組。
問題設置
由于自監督學習(SSL)通常用于預訓練,讓模型做好準備適應下游任務,這帶來了一個關鍵問題:SSL 訓練會對所學到的表征產生什么影響?具體來說,訓練期間 SSL 的底層工作機制是怎樣的,這些表征函數能學到什么類別?
為了調查這些問題,研究者在多種設置上訓練了 SSL 網絡并使用不同的技術分析了它們的行為。
數據和增強:本文提到的所有實驗都使用了 CIFAR100 圖像分類數據集。為了訓練模型,研究者使用了 SimCLR 中提出的圖像增強協議。每一個 SSL 訓練 session 都執行 1000 epoch,使用了帶動量的 SGD 優化器。
骨干架構:所有的實驗都使用了 RES-L-H 架構作為骨干,再加上了兩層多層感知器(MLP)投射頭。
線性探測(linear probing):為了評估從表征函數中提取給定離散函數(例如類別)的有效性,這里使用的方法是線性探測。這需要基于該表征訓練一個線性分類器(也稱為線性探針),這需要用到一些訓練樣本。
樣本層面的分類:為了評估樣本層面的可分離性,研究者創建了一個專門的新數據集。
其中訓練數據集包含來自 CIFAR-100 訓練集的 500 張隨機圖像。每張圖像都代表一個特定類別并會進行 100 種不同的增強。因此,訓練數據集包含 500 個類別的共計 50000 個樣本。測試集依然是用這 500 張圖像,但要使用 20 種不同的增強,這些增強都來自同一分布。因此,測試集中的結果由 10000 個樣本構成。為了在樣本層面衡量給定表征函數的線性或 NCC(nearest class-center / 最近類別中心)準確度,這里采用的方法是先使用訓練數據計算出一個相關的分類器,然后再在相應測試集上評估其準確率。
揭示自監督學習的聚類過程
在幫助分析深度學習模型方面,聚類過程一直以來都發揮著重要作用。為了直觀地理解 SSL 訓練,圖 1 通過 UMAP 可視化展示了網絡的訓練樣本的嵌入空間,其中包含訓練前后的情況并分了不同層級。
圖 1:SSL 訓練引起的語義聚類
正如預期的那樣,訓練過程成功地在樣本層面上對樣本進行了聚類,映射了同一圖像的不同增強(如第一行圖示)。考慮到目標函數本身就會鼓勵這種行為(通過不變性損失項),因此這樣的結果倒是不意外。然而,更值得注意的是,該訓練過程還會根據標準 CIFAR-100 數據集的原始「語義類別」進行聚類,即便該訓練過程期間缺乏標簽。有趣的是,更高的層級(超類別)也能被有效聚類。這個例子表明,盡管訓練流程直接鼓勵的是樣本層面的聚類,但 SSL 訓練的數據表征還會在不同層面上根據語義類別來進行聚類。
為了進一步量化這個聚類過程,研究者使用 VICReg 訓練了一個 RES-10-250。研究者衡量的是 NCC 訓練準確度,既有樣本層面的,也有基于原始類別的。值得注意的是,SSL 訓練的表征在樣本層面上展現出了神經坍縮(neural collapse,即 NCC 訓練準確度接近于 1.0),然而在語義類別方面的聚類也很顯著(在原始目標上約為 0.41)。
如圖 2 左圖所示,涉及增強(網絡直接基于其訓練的)的聚類過程大部分都發生在訓練過程初期,然后陷入停滯;而在語義類別方面的聚類(訓練目標中并未指定)則會在訓練過程中持續提升。
圖 2:SSL 算法根據語義目標對對數據的聚類
之前有研究者觀察到,監督式訓練樣本的頂層嵌入會逐漸向一個類質心的結構收斂。為了更好地理解 SSL 訓練的表征函數的聚類性質,研究者調查了 SSL 過程中的類似情況。其 NCC 分類器是一種線性分類器,其表現不會超過最佳的線性分類器。通過評估 NCC 分類器與同樣數據上訓練的線性分類器的準確度之比,能夠在不同粒度層級上研究數據聚類。圖 2 的中圖給出了樣本層面類別和原始目標類別上的這一比值的變化情況,其值根據初始化的值進行了歸一化。隨著 SSL 訓練的進行,NCC 準確度和線性準確度之間的差距會變小,這說明增強后的樣本會根據其樣本身份和語義屬性逐漸提升聚類水平。
此外,該圖還說明,樣本層面的比值起初會高一些,這說明增強后的樣本會根據它們的身份進行聚類,直到收斂至質心(NCC 準確度和線性準確度的比值在 100 epoch 時 ≥ 0.9)。但是,隨著訓練繼續,樣本層面的比值會飽和,而類別層面的比值會繼續增長并收斂至 0.75 左右。這說明增強后的樣本首先會根據樣本身份進行聚類,實現之后,再根據高層面的語義類別進行聚類。
SSL 訓練中隱含的信息壓縮
如果能有效進行壓縮,那么就能得到有益又有用的表征。但 SSL 訓練過程中是否會出現那樣的壓縮卻仍是少有人研究的課題。
為了了解這一點,研究者使用了互信息神經估計(Mutual Information Neural Estimation/MINE),這種方法可以估計訓練過程中輸入與其對應嵌入表征之間的互信息。這個度量可用于有效衡量表征的復雜度水平,其做法是展現其編碼的信息量(比特數量)。
圖 3 的中圖報告了在 5 個不同的 MINE 初始化種子上計算得到的平均互信息。如圖所示,訓練過程會有顯著的壓縮,最終形成高度緊湊的訓練表征。
圖 3:(左)一個 SSL 訓練的模型在訓練期間的正則化和不變性損失以及原始目標線性測試準確度。(中)訓練期間輸入和表征之間的互信息的壓縮。(右)SSL 訓練學習聚類的表征。
正則化損失的作用
目標函數包含兩項:不變性和正則化。不變性項的主要功能是強化同一樣本的不同增強的表征之間的相似性。而正則化項的目標是幫助防止表征坍縮。
為了探究這些分量對聚類過程的作用,研究者將目標函數分解為了不變性項和正則化項,并觀察它們在訓練過程中的行為。比較結果見圖 3 左圖,其中給出了原始語義目標上的損失項的演變以及線性測試準確度。不同于普遍流行的想法,不變性損失項在訓練過程中并不會顯著改善。相反,損失(以及下游的語義準確度)的改善是通過降低正則化損失實現的。
由此可以得出結論:SSL 的大部分訓練過程都是為了提升語義準確度和所學表征的聚類,而非樣本層面的分類準確度和聚類。
從本質上講,這里的發現表明:盡管自監督學習的直接目標是樣本層面的分類,但其實大部分訓練時間都用于不同層級上基于語義類別的數據聚類。這一觀察結果表明 SSL 方法有能力通過聚類生成有語義含義的表征,這也讓我們得以了解其底層機制。
監督學習和 SSL 聚類的比較
深度網絡分類器往往是基于訓練樣本的類別將它們聚類到各個質心。但學習得到的函數要能真正聚類,必須要求這一性質對測試樣本依然有效;這是我們期望得到的效果,但效果會差一點。
這里有一個有趣的問題:相比于監督學習的聚類,SSL 能在多大程度上根據樣本的語義類別來執行聚類?圖 3 右圖報告了在不同場景(使用和不使用增強的監督學習以及 SSL)的訓練結束時的 NCC 訓練和測試準確度比率。
盡管監督式分類器的 NCC 訓練準確度為 1.0,顯著高于 SSL 訓練的模型的 NCC 訓練準確度,但 SSL 模型的 NCC 測試準確度卻略高于監督式模型的 NCC 測試準確度。這說明兩種模型根據語義類別的聚類行為具有相似的程度。有意思的是,使用增強樣本訓練監督式模型會稍微降低 NCC 訓練準確度,卻會大幅提升 NCC 測試準確度。
探索語義類別學習和隨機性的影響
語義類別是根據輸入的內在模式來定義輸入和目標的關系。另一方面,如果將輸入映射到隨機目標,則會看到缺乏可辨別的模式,這會導致輸入和目標之間的連接看起來很任意。
研究者還探究了隨機性對模型學習所需目標的熟練程度的影響。為此,他們構建了一系列具有不同隨機度的目標系統,然后檢查了隨機度對所學表征的影響。他們在用于分類的同一數據集上訓練了一個神經網絡分類器,然后使用其不同 epoch 的目標預測作為具有不同隨機度的目標。在 epoch 0 時,網絡是完全隨機的,會得到確定的但看似任意的標簽。隨著訓練進行,其函數的隨機性下降,最終得到與基本真值目標對齊的目標(可認為是完全不隨機)。這里將隨機度歸一化到 0(完全不隨機,訓練結束時)到 1(完全隨機,初始化時)之間。
圖 4 左圖展示了不同隨機度目標的線性測試準確度。每條線都對應于不同隨機度的 SSL 不同訓練階段的準確度。可以看到,在訓練過程中,模型會更高效地捕獲與「語義」目標(更低隨機度)更接近的類別,同時在高隨機度的目標上沒有表現出顯著的性能改進。
圖 4:SSL 持續學習語義目標,而非隨機目標
深度學習的一個關鍵問題是理解中間層對分類不同類型類別的作用和影響。比如,不同的層會學到不同類型的類別嗎?研究者也探索了這個問題,其做法是在訓練結束時不同目標隨機度下評估不同層表征的線性測試準確度。如圖 4 中圖所示,隨著隨機度下降,線性測試準確度持續提升,更深度的層在所有類別類型上都表現更優,而對于接近語義類別的分類,性能差距會更大。
研究者還使用了其它一些度量來評估聚類的質量:NCC 準確度、CDNV、平均每類方差、類別均值之間的平均平方距離。為了衡量表征隨訓練進行的改進情況,研究者為語義目標和隨機目標計算了這些指標的比率。圖 4 右圖展示了這些比率,結果表明相比于隨機目標,表征會更加偏向根據語義目標來聚類數據。有趣的是,可以看到 CDNV(方差除以平方距離)會降低,其原因僅僅是平方距離的下降。方差比率在訓練期間相當穩定。這會鼓勵聚類之間的間距拉大,這一現象已被證明能帶來性能提升。
了解類別層級結構和中間層
之前的研究已經證明,在監督學習中,中間層會逐漸捕獲不同抽象層級的特征。初始的層傾向于低層級的特征,而更深的層會捕獲更抽象的特征。接下來,研究者探究了 SSL 網絡能否學習更高層面的層次屬性以及哪些層面與這些屬性的關聯性更好。
在實驗中,他們計算了三個層級的線性測試準確度:樣本層級、原始的 100 個類別、20 個超類別。圖 2 右圖給出了為這三個不同類別集計算的數量。可以觀察到,在訓練過程中,相較于樣本層級的類別,在原始類別和超類別層級上的表現的提升更顯著。
接下來是 SSL 訓練的模型的中間層的行為以及它們捕獲不同層級的目標的能力。圖 5 左和中圖給出了不同訓練階段在所有中間層上的線性測試準確度,這里度量了原始目標和超目標。圖 5 右圖給出超類別和原始類別之間的比率。
圖 5:SSL 能在整體中間層中有效學習語義類別
研究者基于這些結果得到了幾個結論。首先,可以觀察到隨著層的深入,聚類效果會持續提升。此外,與監督學習情況類似,研究者發現在 SSL 訓練期間,網絡每一層的線性準確度都有提升。值得注意的是,他們發現對于原始類別,最終層并不是最佳層。近期的一些 SSL 研究表明:下游任務能高度影響不同算法的性能。本文的研究拓展了這一觀察結果,并且表明網絡的不同部分可能適合不同的下游任務與任務層級。根據圖 5 右圖,可以看出,在網絡的更深層,超類別的準確度的提升幅度超過原始類別。