測試時領域適應的魯棒性得以保證,TRIBE在多真實場景下達到SOTA
測試時領域適應(Test-Time Adaptation)的目的是使源域模型適應推理階段的測試數據,在適應未知的圖像損壞領域取得了出色的效果。然而,當前許多方法都缺乏對真實世界場景中測試數據流的考慮,例如:
- 測試數據流應當是時變分布(而非傳統領域適應中的固定分布)
- 測試數據流可能存在局部類別相關性(而非完全獨立同分布采樣)
- 測試數據流在較長時間里仍表現全局類別不平衡
近日,華南理工、A*STAR 和港中大(深圳)團隊通過大量實驗證明,這些真實場景下的測試數據流會對現有方法帶來巨大挑戰。該團隊認為,最先進方法的失敗首先是由于不加區分地根據不平衡測試數據調整歸一化層造成的。
為此,研究團隊提出了一種創新的平衡批歸一化層 (Balanced BatchNorm Layer),以取代推理階段的常規批歸一化層。同時,他們發現僅靠自我訓練(ST)在未知的測試數據流中進行學習,容易造成過度適應(偽標簽類別不平衡、目標域并非固定領域)而導致在領域不斷變化的情況下性能不佳。
因此,該團隊建議通過錨定損失 (Anchored Loss) 對模型更新進行正則化處理,從而改進持續領域轉移下的自我訓練,有助于顯著提升模型的魯棒性。最終,模型 TRIBE 在四個數據集、多種真實世界測試數據流設定下穩定達到 state-of-the-art 的表現,并大幅度超越已有的先進方法。研究論文已被 AAAI 2024 接收。
論文鏈接:https://arxiv.org/abs/2309.14949
代碼鏈接:https://github.com/Gorilla-Lab-SCUT/TRIBE
引言
深度神經網絡的成功依賴于將訓練好的模型推廣到 i.i.d. 測試域的假設。然而,在實際應用中,分布外測試數據的魯棒性,如不同的照明條件或惡劣天氣造成的視覺損壞,是一個需要關注的問題。最近的研究顯示,這種數據損失可能會嚴重影響預先訓練好的模型的性能。重要的是,在部署前,測試數據的損壞(分布)通常是未知的,有時也不可預測。
因此,調整預訓練模型以適應推理階段的測試數據分布是一個值得價值的新課題,即測試時領域適 (TTA)。此前,TTA 主要通過分布對齊 (TTAC++, TTT++),自監督訓練 (AdaContrast) 和自訓練 (Conjugate PL) 來實現,這些方法在多種視覺損壞測試數據中都帶來了顯著的穩健提升。
現有的測試時領域適應(TTA)方法通?;谝恍﹪栏竦臏y試數據假設,如穩定的類別分布、樣本服從獨立同分布采樣以及固定的領域偏移。這些假設啟發了許多研究者去探究真實世界中的測試數據流,如 CoTTA、NOTE、SAR 和 RoTTA 等。
最近,對真實世界的 TTA 研究,如 SAR(ICLR 2023)和 RoTTA(CVPR 2023)主要關注局部類別不平衡和連續的領域偏移對 TTA 帶來的挑戰。局部類別不平衡通常是由于測試數據并非獨立同分布采樣而產生的。直接不加區分的領域適應將導致有偏置的分布估計。
最近有研究提出了指數式更新批歸一化統計量(RoTTA)或實例級判別更新批歸一化統計量(NOTE)來解決這個挑戰。其研究目標是超越局部類不平衡的挑戰,考慮到測試數據的總體分布可能嚴重失衡,類的分布也可能隨著時間的推移而變化。在下圖 1 中可以看到更具挑戰性的場景示意圖。
由于在推理階段之前,測試數據中的類別流行率未知,而且模型可能會通過盲目的測試時間調整偏向于多數類別,這使得現有的 TTA 方法變得無效。根據經驗觀察,對于依靠當前批數據來估計全局統計量來更新歸一化層的方法來說,這個問題變得尤為突出(BN, PL, TENT, CoTTA 等)。
這主要是由于:
1.當前批數據會受到局部類別不平衡的影響帶來有偏置的整體分布估計;
2.從全局類別不平衡的整個測試數據中估計出單一的全局分布,全局分布很容易偏向多數類,導致內部協變量偏移。
為了避免有偏差的批歸一化(BN),該團隊提出了一種平衡的批歸一化層(Balanced Batch Normalization Layer),即對每個單獨類別的分布進行建模,并從類別分布中提取全局分布。平衡的批歸一化層允許在局部和全局類別不平衡的測試數據流下得到分布的類平衡估計。
隨著時間的推移,領域轉移在現實世界的測試數據中經常發生,例如照明 / 天氣條件的逐漸變化。這給現有的 TTA 方法帶來了另一個挑戰,TTA 模型可能由于過度適應到領域 A 而當從領域 A 切換到領域 B 時出現矛盾。
為了緩解過度適應到某個短時領域,CoTTA 隨機還原參數,EATA 用 fisher information 對參數進行正則化約束。盡管如此,這些方法仍然沒有明確解決測試數據領域中層出不窮的挑戰。
本文在兩分支自訓練架構的基礎上引入了一個錨定網絡(Anchor Network)組成三網絡自訓練模型(Tri-Net Self-Training)。錨定網絡是一個凍結的源模型,但允許通過測試樣本調整批歸一化層中的統計量而非參數。并提出了一個錨定損失利用錨定網絡的輸出來正則化教師模型的輸出以避免網絡過度適應到局部分布中。
最終模型結合了三網絡自訓練模型和平衡的批歸一化層(TRI-net self-training with BalancEd normalization, TRIBE)在較為寬泛的的可調節學習率的范圍里表現出一致的優越性能。在四個數據集和多種真實世界數據流下顯示了大幅性能提升,展示了獨一檔的穩定性和魯棒性。
方法介紹
論文方法分為三部分:
- 介紹真實世界下的 TTA 協議;
- 平衡的批歸一化;
- 三網絡自訓練模型。
真實世界下的 TTA 協議
作者采用了數學概率模型對真實世界下具有局部類別不平衡和全局類別不平衡的測試數據流,以及隨著時間變化的領域分布進行了建模。如下圖 2 所示。
平衡的批歸一化
為了糾正不平衡測試數據對 BN 統計量產生的估計偏置,作者提出了一個平衡批歸一化層,該層為每個語義類分別維護了一對統計量,表示為:
為了更新類別統計量,作者在偽標簽預測的幫助下應用了高效的迭代更新方法,如下所示:
通過偽標簽對各個類別數據的采樣點進行單獨統計,并通過下式重新得到類別平衡下的整體分布統計量,以此來對齊用類別平衡的源數據學習好的特征空間。
在某些特殊情況下,作者發現當類別數量較多或偽標簽準確率較低 (accuracy<0.5) 的情況下,以上的類別獨立的更新策略效果沒那么明顯。因此,他們進一步用超參數 γ 來融合類別無關更新策略和類別獨立更新策略,如下式:
通過進一步分析和觀察,作者發現當 γ=1 時,整個更新策略就退化成了 RoTTA 中的 RobustBN 的更新策略,當 γ=0 時是純粹的類別獨立的更新策略,因此,當 γ 取值 0~1 時可以適應到各種情況下。
網絡自訓練模型
作者在現有的學生 - 教師模型的基礎上,添加了一個錨定網絡分支,并引入了錨定損失來約束教師網絡的預測分布。這種設計受到了 TTAC++ 的啟發。TTAC++ 指出在測試數據流上僅靠自我訓練會容易導致確認偏置的積累,這個問題在本文中的真實世界中的測試數據流上更加嚴重。TTAC++ 采用了從源域收集到的統計信息實現領域對齊正則化,但對于 Fully TTA 設定來說,這個源域信息不可收集。
同時,作者也收獲了另一個啟示,無監督領域對齊的成功是基于兩個領域分布相對高重疊率的假設。因此,作者僅調整了 BN 統計量的凍結源域模型來對教師模型進行正則化,避免教師模型的預測分布偏離源模型的預測分布太遠(這破壞了之前的兩者分布高重合率的經驗觀測)。大量實驗證明,本文中的發現與創新是正確的且魯棒的。以下是錨定損失的表達式:
下圖展示了 TRIBE 網絡的框架圖:
實驗部分
論文作者在 4 個數據集上,以兩種真實世界 TTA 協議為基準,對 TRIBE 進行了驗證。兩種真實世界 TTA 協議分別是全局類分布固定的 GLI-TTA-F 和全局類分布不固定的 GLI-TTA-V。
上表展示了 CIFAR10-C 數據集兩種協議不同不平衡系數下的表現,可以得到以下結論:
1.只有 LAME, TTAC, NOTE, RoTTA 和論文提出的 TRIBE 超過了 TEST 的基準線,表明了真實測試流下更加魯棒的 TTA 方法的必要性。
2.全局類別不平衡對現有的 TTA 方法帶來了巨大挑戰,如先前的 SOTA 方法 RoTTA 在 I.F.=1 時表現為錯誤率 25.20% 但在 I.F.=200 時錯誤率升到了 32.45%,相比之下,TRIBE 能穩定地展示相對較好的性能。
3. TRIBE 的一致性具有絕對優勢,超越了先前的所有方法,并在全局類別平衡的設定下 (I.F.=1) 超越先前 SOTA (TTAC) 約 7%,在更加困難的全局類別不平衡 (I.F.=200) 的設定下獲得了約 13% 的性能提升。
4.從 I.F.=10 到 I.F.=200,其他 TTA 方法隨著不平衡度增加,呈現性能下跌的趨勢。而 TRIBE 能維持較為穩定的性能表現。這歸因于引入了平衡批歸一化層,更好地考慮了嚴重的類別不平衡和錨定損失,這避免了跨不同領域的過度適應。
更多數據集的結果可查閱論文原文。
此外,表 4 展示了詳細的模塊化消融,有以下幾個觀測性結論:
1.僅將 BN 替換成平衡批歸一化層 (Balanced BN),不更新任何模型參數,只通過 forward 更新 BN 統計量,就能帶來 10.24% (44.62 -> 34.28) 的性能提升,并超越了 Robust BN 的錯誤率 41.97%。
2.Anchored Loss 結合 Self-Training,無論是在之前 BN 結構下還是最新的 Balanced BN 結構下,都得到了性能的提升,并超越了 EMA Model 的正則化效果。
本文的其余部分和長達 9 頁的附錄最終呈現了 17 個詳細表格結果,從多個維度展示了 TRIBE 的穩定性、魯棒性和優越性。附錄中也含有對平衡批歸一化層的更加詳細的理論推導和解釋。
總結和展望
為應對真實世界中 non-i.i.d. 測試數據流、全局類不平衡和持續的領域轉移等諸多挑戰,研究團隊深入探索了如何改進測試時領域適應算法的魯棒性。為了適應不平衡的測試數據,作者提出了一個平衡批歸一化層(Balanced Batchnorm Layer),以實現對統計量的無偏估計,進而提出了一種包含學生網絡、教師網絡和錨定網絡的三層網絡結構,以規范基于自我訓練的 TTA。
但本文仍然存在不足和改進的空間,由于大量的實驗和出發點都基于分類任務和 BN 模塊,因此對于其他任務和基于 Transformer 模型的適配程度仍然未知。這些問題值得后續工作進一步研究和探索。