騰訊優(yōu)圖&廈門大學(xué)提出無需訓(xùn)練的ViT結(jié)構(gòu)搜索算法
近期,ViT 在計(jì)算機(jī)視覺領(lǐng)域展現(xiàn)了出強(qiáng)大的競(jìng)爭(zhēng)力、在多個(gè)任務(wù)里取得了驚人的進(jìn)展。隨著許多人工設(shè)計(jì)的 ViT 結(jié)構(gòu)(如Swin-Transformer、PVT、XCiT 等)的出現(xiàn),面向 ViT 的結(jié)構(gòu)搜索(TAS) 開始受到越來越多的關(guān)注。TAS 旨在以自動(dòng)化的方式在 ViT 搜索空間(如MSA 的 head 數(shù)量、channel ratio 等)中找到更優(yōu)的網(wǎng)絡(luò)結(jié)構(gòu)。基于 one-shot NAS 的方案(如AutoFormer、GLiT 等)已經(jīng)取得了初步進(jìn)展,但他們?nèi)匀恍枰芨叩挠?jì)算成本(如24 GPU days 以上)。主要原因有以下兩點(diǎn):
1.在空間的復(fù)雜度上,ViT 搜索空間(如,GLiT 空間的量級(jí)約 10^30)在數(shù)量上遠(yuǎn)遠(yuǎn)超過 CNN 搜索空間(如,DARTS 空間的量級(jí)約 10^18);
2.ViT 模型通常需要更多的訓(xùn)練周期(如300 epochs)才能知道其對(duì)應(yīng)的效果。
在近期的一篇論文《Training-free Transformer Architecture Search》中,來自騰訊優(yōu)圖實(shí)驗(yàn)室、廈門大學(xué)、鵬城實(shí)驗(yàn)室等結(jié)構(gòu)的研究者回顧近些年 NAS 領(lǐng)域的進(jìn)展,并注意到:為了提高搜索效率,研究社區(qū)提出了若干零成本代理(zero-cost proxy)的評(píng)估指標(biāo)(如GraSP、TE-score 和 NASWOT)。這些方法讓我們能夠在無需訓(xùn)練的條件下就能評(píng)估出不同 CNN 結(jié)構(gòu)的排序關(guān)系,從而極大程度上節(jié)省計(jì)算成本。
- 論文地址:https://arxiv.org/pdf/2203.12217.pdf
- 項(xiàng)目地址:https://github.com/decemberzhou/TF_TAS
從技術(shù)上來說,一個(gè)典型的 CNN 模型主要由卷積模塊組成,而一個(gè) ViT 模型主要由多頭注意力模塊(MSA)和多層感知機(jī)模塊(MLP)組成。這種網(wǎng)絡(luò)結(jié)構(gòu)上的差異會(huì)讓現(xiàn)有的、在 CNN 搜索空間上驗(yàn)證有效的零成本代理無法保證其在 ViT 搜索空間上模型評(píng)估效果(見下圖 1)。
因此,研究一種更適合 ViT 結(jié)構(gòu)評(píng)估、有利于 TAS 訓(xùn)練效率的零成本代理指標(biāo)是有必要且值得探索的。這一問題也將促使研究者進(jìn)一步研究和更好地理解 ViT 結(jié)構(gòu),從而設(shè)計(jì)一種有效的、無需訓(xùn)練的 TAS 搜索算法。
圖 1. (a)研究者采樣的 1000 個(gè) ViT 模型的參數(shù)量和效果分布。(b-e)在 CNN 搜索空間效果好的 zero-cost proxy 方法并不適用于 ViT 搜索空間。(f)他們的 DSS-indicator 更適合用來評(píng)估不同的 ViT 模型。
方法主體
為了達(dá)到這個(gè)目的,研究者對(duì) MSA 和 MLP 模塊進(jìn)行了理論分析,希望找到某種可量化的屬性來有效地評(píng)估 ViT 網(wǎng)絡(luò)。
基于量化結(jié)果,他們觀察到:在 ViT 中,MSA 和 MLP 確實(shí)具有各自不同的、適合用來揭示模型效果的性質(zhì)。研究者有如下定義:衡量一個(gè) MSA 的秩復(fù)雜程度,將其計(jì)作突觸多樣性(synaptic diversity);估計(jì)一個(gè) MLP 內(nèi)重要參數(shù)的數(shù)量,將其計(jì)作突觸顯著性(synaptic saliency)。當(dāng) MSA 擁有更高的突觸多樣性或者當(dāng) MLP 有更多的突觸顯著性時(shí),其對(duì)應(yīng)的 ViT 模型總是擁有更好的效果。
基于這個(gè)重要的結(jié)果,研究者設(shè)計(jì)了一個(gè)有效且高效的零代價(jià)代理評(píng)估指標(biāo) DSS-indicator(下圖 2),并基于此設(shè)計(jì)了一個(gè)包含模塊化策略的無訓(xùn)練 Transformer 結(jié)構(gòu)搜索算法(Transformer Architecture Search,TF-TAS)。
圖 2. 方法的整體框架圖。
具體來說,DSS-indicator 通過計(jì)算 MSA 的突觸多樣性和 MLP 的突觸顯著性來得到 ViT 結(jié)構(gòu)的評(píng)價(jià)分?jǐn)?shù)。這是學(xué)術(shù)界首次提出基于 MSA 的突觸多樣性和 MLP 的突觸顯著性來作為評(píng)價(jià) ViT 結(jié)構(gòu)的代理評(píng)估指標(biāo)。而且需要注意的是,TF-TAS 與搜索空間設(shè)計(jì)和權(quán)值共享策略是正交的。因此,可以靈活地將 TF-TAS 與其他 ViT 搜索空間或 TAS 方法相結(jié)合,進(jìn)一步提高搜索效率。
與人工設(shè)計(jì)的 ViT 和自動(dòng)搜索的 ViT 相比,研究者設(shè)計(jì)的 TF-TAS 實(shí)現(xiàn)了具有競(jìng)爭(zhēng)力的效果,將搜索過程從 24 GPU 天數(shù)縮短到不到 0.5 GPU 天數(shù),大約快 48 倍。
MSA 的突觸多樣性
MSA 是 ViT 結(jié)構(gòu)的一個(gè)基本組件,其多樣性對(duì) ViT 效果有重要意義。基于已有的工作可以知道:MSA 模塊學(xué)到的特征表示存在秩崩潰(rank collapse)的現(xiàn)象。隨著輸入在網(wǎng)絡(luò)中前向傳播和深度的不斷加深,ViT 中 MSA 的輸出會(huì)逐漸收斂到秩為 1、并最終退化為一個(gè)秩為 1 的矩陣(每一行的值不變,即多樣性出現(xiàn)稀疏的情況)。秩崩潰意味著 ViT 模型效果很差。因此,我們可以通過估計(jì)秩崩潰的程度來推測(cè) ViT 模型的效果。
然而,在高維空間中估計(jì)秩崩潰需要大量計(jì)算量。實(shí)際上,已經(jīng)被證明矩陣的秩包含特征中多樣性信息的代表性線索。基于這些理解,MSA 模塊中權(quán)重參數(shù)的秩可以作為評(píng)價(jià) ViT 結(jié)構(gòu)的指標(biāo)。
對(duì)于 MSA 模塊,直接對(duì)其權(quán)值矩陣的秩進(jìn)行度量,存在計(jì)算量較大的問題。為了加速計(jì)算,研究者利用 MSA 權(quán)重矩陣的核范數(shù)近似其秩作為多樣性指標(biāo)。理論上,當(dāng)權(quán)重矩陣的 Frobenius 范數(shù)(F 范數(shù))滿足一定條件時(shí),權(quán)重矩陣的核范數(shù)可視為其秩的等價(jià)替換。具體來說,研究者將 MSA 模塊的權(quán)值參數(shù)矩陣表示為
。m 表示 MSA 中第 m 個(gè)線性層。因此,
的 F 范數(shù)可以定義為:
其中,
表示
中第 i 行 j 列的元素,根據(jù)算術(shù)均值和幾何均值的不等式,
的上界為:
上式表示
的上界即為
的最大線性獨(dú)立的向量數(shù),即矩陣的秩。隨機(jī)給定
中的兩個(gè)向量
,
。當(dāng)
,
獨(dú)立時(shí),
的值相應(yīng)的會(huì)更大。這表明:
的 F 范數(shù)越大,
的秩越接近
的多樣性。當(dāng)
時(shí),
的核范數(shù)可以是
秩的近似。形式上,
的核范數(shù)被定義為:
其中,
表示相應(yīng)矩陣的跡,從而容易得到:
。因此,
的秩可近似為
。理論上,
和
成正比,這也表明利用的核范數(shù)可以測(cè)度
的多樣性。為了更好地估計(jì)權(quán)重隨機(jī)初始化的 ViT 網(wǎng)絡(luò)中 MSA 模塊的突觸多樣性,研究者在每個(gè) MSA 模塊的梯度矩陣
(L 為損失函數(shù)) 上進(jìn)一步考慮上述步驟。
總的來說,研究者對(duì)第 l 個(gè) MSA 模塊中權(quán)重參數(shù)的突觸多樣性定義如下:
為了驗(yàn)證 MSA 的突觸多樣性與給定 ViT 架構(gòu)的測(cè)試精度之間的正相關(guān)關(guān)系,研究者對(duì)從 AutoFormer 搜索空間中采樣的 200 個(gè) ViT 網(wǎng)絡(luò)進(jìn)行完整的訓(xùn)練,得到其對(duì)應(yīng)的 MSA 模塊的分類效果和突觸多樣性。它們之間的 Kentall’s τ 相關(guān)系數(shù)為 0.65,如下圖 3a 所示。表明 MSA 的突觸多樣性與每個(gè)輸入 ViT 架構(gòu)的效果之間的正相關(guān)聯(lián)系。
圖 3. (a)MSA 的突觸多樣性(紅)以及 MLP 的突觸顯著性(藍(lán))的評(píng)估效果;(b-c)MSA 和 MLP 具有不同的剪枝敏感性。
MLP 的突觸顯著性
模型剪枝對(duì)于 CNN 領(lǐng)域已經(jīng)取得了很多進(jìn)展,并開始在 Transformer 上得到應(yīng)用。目前已經(jīng)有幾種有效的 CNN 剪枝方法被提出用來衡量早期訓(xùn)練階段模型權(quán)重的重要性。主要有以下兩派方法:
- 測(cè)量在初始化狀態(tài)下衡量突觸的顯著性用于 CNN 模型的剪枝;
- 由于 Transformer 中不同模塊在初始化階段也有不同程度的冗余,因而可以通過對(duì)不同大小的 Transformer 進(jìn)行剪枝。
與剪枝相似,TAS 主要搜索幾個(gè)重要維度,包括注意力頭數(shù)量、MSA 和 MLP 比值等。受這些剪枝方法的啟發(fā),研究者嘗試使用突觸顯著性來評(píng)估不同的 ViT。然而, MSA 和 MLP 的結(jié)構(gòu)差異較大,因此需要分析剪枝敏感性對(duì)度量 ViT 中不同模塊的影響。
為了進(jìn)一步分析 MSA 和 MLP 對(duì)剪枝的敏感性不同對(duì)評(píng)估 ViT 模型的影響,研究者通過剪枝敏感性實(shí)驗(yàn)給出了一些定量結(jié)果。如圖 3b 所示,他們從 AutoFormer 搜索空間中隨機(jī)抽樣 5 個(gè) ViT 架構(gòu),分析 MSA 和 MLP 對(duì)剪枝的敏感性。結(jié)果顯示,MLP 對(duì)修剪比 MSA 更敏感。他們還對(duì) PiT 搜索空間進(jìn)行了分析,得到了類似的觀察結(jié)果 (圖 3c)。
此外,研究者采用 MSA 和 MLP 模塊上的突觸顯著性作為代理,分別計(jì)算代理 ViT 基準(zhǔn)上的 Kendall’s τ 相關(guān)性系數(shù)。最終結(jié)果表明在 MLP 上突觸顯著性的 Kendall’s τ 為 0.47,優(yōu)于 MSA (0.24)、MLP 和 MSA (0.41)。
由于突觸顯著性通常以總和的形式計(jì)算,冗余的權(quán)重往往帶來負(fù)面的累積效應(yīng)。MSA 模塊對(duì)剪枝不敏感,說明 MSA 的權(quán)值參數(shù)具有較高的冗余性。在剪枝領(lǐng)域中被證明冗余權(quán)參數(shù)的值要比非冗余權(quán)參數(shù)的值小得多。盡管這些冗余參數(shù)的值相對(duì)較小,但超過 50% 的冗余往往會(huì)產(chǎn)生較大的累積效應(yīng),尤其是在區(qū)分相似的 ViT 結(jié)構(gòu)時(shí)。
對(duì)于累積效應(yīng),一般的零成本代理中不加區(qū)分地將 MSA 的冗余權(quán)重參數(shù)考慮在內(nèi)來衡量顯著性,導(dǎo)致相應(yīng)的零成本代理中的累加形式存在 MSA 的累積效應(yīng)。累積效應(yīng)可能會(huì)使零成本代理給差的網(wǎng)絡(luò)更高的排名。同時(shí),權(quán)重冗余對(duì) MLP 模塊突觸顯著性的影響較小,因此可以作為評(píng)估 MLP 模塊權(quán)重次數(shù)秩的復(fù)雜性的一個(gè)指標(biāo),從一個(gè)方面指示模型的優(yōu)劣。
為了評(píng)估 ViT 中的 MLP,研究者基于突觸顯著性設(shè)計(jì)了評(píng)估的代理指標(biāo)。在網(wǎng)絡(luò)剪枝中,對(duì)模型權(quán)值的重要性進(jìn)行了廣泛的研究。由于神經(jīng)網(wǎng)絡(luò)主要由卷積層組成,有幾種基于剪枝的零成本代理可以直接用于測(cè)量神經(jīng)網(wǎng)絡(luò)的突觸顯著性。另一方面,ViT 體系結(jié)構(gòu)主要由 MLP 和 MSA 模塊組成,它們具有不同的剪枝特性。通過對(duì) MSA 和 MLP 模塊的剪枝敏感性分析,他們驗(yàn)證了 MLP 模塊對(duì)剪枝更加敏感。因此,突觸顯著性可以更好地反映 MLP 模塊中權(quán)重重要性的差異。相比之下,MSA 模塊對(duì)剪枝相對(duì)不敏感,其突觸顯著性往往受到冗余權(quán)重的影響。
基于 MLP 的修剪敏感性,研究者建議以模塊化的方式測(cè)量突觸顯著性。具體來說,所提出的模塊化策略測(cè)量了作為 ViT 結(jié)構(gòu)的一個(gè)重要部分的 MLPs 的突觸顯著性。給定一個(gè) ViT 架構(gòu),第 l 個(gè) MLP 模塊的顯著性得分為:
其中 n 為指定 ViT 網(wǎng)絡(luò)中第 l 個(gè) MLP 的線性層數(shù),通常設(shè)為 2。圖 3a 顯示了一些定性結(jié)果,以驗(yàn)證
在評(píng)估 ViT 架構(gòu)方面的有效性。
無需訓(xùn)練的 TAS
基于上述分析,研究者設(shè)計(jì)了一種基于模塊化策略的無需訓(xùn)練的 TAS(TF-TAS),來提高搜索 TAS 的搜索效率。如下公式所示,DSS-indicator 同時(shí)考慮 MSA 的突觸多樣性和 MLP 的突觸顯著性來對(duì)模型進(jìn)行評(píng)分:
總的來說,DSS-indicator 從兩個(gè)不同的維度評(píng)估每個(gè) ViT 結(jié)構(gòu)。TF-TAS 在輸入模型經(jīng)過一個(gè)前向傳播和后向更新后計(jì)算
,作為相應(yīng)的 ViT 模型的代理分?jǐn)?shù)。研究者保持模型的輸入數(shù)據(jù)的每個(gè)像素為 1,以消除輸入數(shù)據(jù)對(duì)權(quán)重計(jì)算的影響。因此,
對(duì)隨機(jī)種子具有不變性,與真實(shí)的圖片輸入數(shù)據(jù)無關(guān)。
實(shí)驗(yàn)結(jié)果
1.Image-Net
研究者首先在 ImageNet 數(shù)據(jù)集上進(jìn)行搜索效果測(cè)試,結(jié)果如下所示。在三種參數(shù)量級(jí)上,研究者都能找到不亞于、甚至比基于 one-shot NAS 的 TAS 方法更好的模型結(jié)果。而且所需要的耗時(shí)(0.5 GPU days)要遠(yuǎn)小于現(xiàn)有 TAS 方法所需的計(jì)算成本(24 GPU days 以上)。
2. 遷移實(shí)驗(yàn)
為了進(jìn)一步驗(yàn)證搜索得到的模型的效果,研究者在 CIFAR-10、CIFAR-100 數(shù)據(jù)集上驗(yàn)證其遷移性。按照 AutoFormer 論文的設(shè)定,他們將模型在 384 x 384 大小的圖像上進(jìn)行 fintune,效果如下所示。基于 DSS-indicator 找到的模型與基于 one-shot NAS 找到的模型在遷移性上不相上下。
3. 在其他 ViT 搜索空間的搜索效果
此外,研究者也在 PiT 搜索空間上進(jìn)行了搜索測(cè)試,并按照論文的設(shè)定,在 COCO 數(shù)據(jù)集上測(cè)試了搜索到的模型結(jié)果對(duì)應(yīng)的檢測(cè)效果。結(jié)果如下表所示:他們搜索找到的 PiT 模型 TF-TAS-Ti、TF-TAS-XS 和 TF-TAS-S 和基于手工設(shè)計(jì)的 PiT 的效果不相上下,而且遠(yuǎn)好于隨機(jī)搜索的模型結(jié)果。并且在檢測(cè)效果上,研究者的方法也有一定的優(yōu)勢(shì)。這些結(jié)果驗(yàn)證了該方法的有效性和普適性。
? ?