首頁 >科技週邊 >人工智慧 >TRIBE實現領域適應的穩健性,在多真實場景下達到SOTA的AAAII 2024

TRIBE實現領域適應的穩健性,在多真實場景下達到SOTA的AAAII 2024

WBOY
WBOY轉載
2024-01-01 10:38:521374瀏覽
測試時領域適應(Test-Time Adaptation)的目的是使源域模型適應推理階段的測試數據,在適應未知的圖像損壞領域取得了出色的效果。然而,目前許多方法都缺乏對真實世界場景中測試資料流的考慮,例如:

  • #測試資料流應為時變分佈(而非傳統領域適應中的固定分佈)
  • 測試資料流可能存在局部類別相關性(而非完全獨立同分佈取樣)
  • 測試資料流在較長時間裡仍表現全域類別不平衡

#近日,華南理工、A* STAR 和港中大(深圳)團隊透過大量實驗證明,這些真實場景下的測試資料流會對現有方法帶來巨大挑戰。團隊認為,最先進方法的失敗首先是由於不加區分地根據不平衡測試資料調整歸一化層所造成的。

為此,研究團隊提出了一種創新的平衡批歸一化層(Balanced BatchNorm Layer),以取代推理階段的常規批歸一化層。同時,他們發現僅靠自我訓練(ST)在未知的測試資料流中進行學習,容易造成過度適應(偽標籤類別不平衡、目標域並非固定領域)而導致在領域不斷變化的情況下表現不佳。

因此,該團隊建議透過錨定損失(Anchored Loss) 對模型更新進行正規化處理,從而改進持續領域轉移下的自我訓練,有助於顯著提升模型的穩健性。最終,模型 TRIBE 在四個資料集、多種真實世界測試資料流設定下穩定達到 state-of-the-art 的表現,並大幅超越現有的先進方法。研究論文已被 AAAI 2024 接收。

AAAI 2024 | 测试时领域适应的鲁棒性得以保证,TRIBE在多真实场景下达到SOTA

論文連結:https://arxiv.org/abs/2309.14949
程式碼連結:https://github.com/Gorilla-Lab- SCUT/TRIBE

引言

#深度神經網路的成功取決於將訓練好的模型推廣到i.i.d. 測試域的假設。然而,在實際應用中,分佈外測試資料的穩健性,如不同的照明條件或惡劣天氣造成的視覺損壞,是一個需要關注的問題。最近的研究顯示,這種資料損失可能會嚴重影響預先訓練好的模型的表現。重要的是,在部署前,測試資料的損壞(分佈)通常是未知的,有時也不可預測。

因此,調整預訓練模型以適應推理階段的測試資料分佈是一個值得價值的新課題,即測試時領域適 (TTA)。先前,TTA 主要透過分佈對齊 (TTAC , TTT ),自監督訓練 (AdaContrast) 和自訓練 (Conjugate PL) 來實現,這些方法在多種視覺損壞測試資料中都帶來了顯著的穩健提升。

現有的測試時領域適應(TTA)方法通常基於一些嚴格的測試資料假設,如穩定的類別分佈、樣本服從獨立同分佈取樣以及固定的領域偏移。這些假設啟發了許多研究者去探討真實世界中的測驗資料流,如 CoTTA、NOTE、SAR 和 RoTTA 等。

最近,對真實世界的 TTA 研究,如 SAR(ICLR 2023)和 RoTTA(CVPR 2023)主要關注局部類別不平衡和連續的領域偏移對 TTA 帶來的挑戰。局部類別不平衡通常是由於測試資料並非獨立同分佈採樣而產生的。直接不加區分的領域適應將導致有偏壓的分佈估計。

最近有研究提出了指數式更新批次統計量(RoTTA)或實例層級判別更新批次統計量(NOTE)來解決這個挑戰。其研究目標是超越局部類別不平衡的挑戰,考慮到測試資料的整體分佈可能嚴重失衡,類別的分佈也可能隨著時間的推移而變化。在下圖 1 可以看到更具挑戰性的場景示意圖。

AAAI 2024 | 测试时领域适应的鲁棒性得以保证,TRIBE在多真实场景下达到SOTA

由於在推理階段之前,測試資料中的類別盛行率未知,而且模型可能會透過盲目的測試時間調整偏向多數類別,這使得現有的 TTA 方法變得無效。根據經驗觀察,對於依賴目前批次資料來估計全域統計量來更新歸一化層的方法來說,這個問題變得特別突出(BN, PL, TENT, CoTTA 等)。

這主要是由於:
1.目前批次資料會受到局部類別不平衡的影響帶來有偏移的整體分佈估計;
2.從全域類別不平衡的整個測試資料中估計出單一的全域分佈,全域分佈很容易偏向多數類,導致內部協變數偏移。

為了避免有偏差的批歸一化(BN),該團隊提出了一種平衡的批歸一化層(Balanced Batch Normalization Layer),即對每個單獨類別的分佈進行建模,並從類別分佈中提取全局分佈。平衡的批歸一化層允許在局部和全局類別不平衡的測試資料流下得到分佈的類平衡估計。

隨著時間的推移,領域轉移在現實世界的測試數據中經常發生,例如照明 / 天氣條件的逐漸變化。這給現有的 TTA 方法帶來了另一個挑戰,TTA 模型可能會因為過度適應到領域 A 而當從領域 A 切換到領域 B 時出現矛盾。

為了緩解過度適應到某個短時領域,CoTTA 隨機還原參數,EATA 用 fisher information 對參數進行正規化約束。儘管如此,這些方法仍然沒有明確解決測試資料領域中層出不窮的挑戰。

本文在兩分支自訓練架構的基礎上引入了一個錨定網路(Anchor Network)組成三網路自訓練模型(Tri-Net Self-Training)。錨定網路是一個凍結的來源模型,但允許透過測試樣本調整批歸一化層中的統計量而非參數。並提出了一個錨定損失利用錨定網路的輸出來正則化教師模型的輸出以避免網路過度適應到局部分佈。

最終模型結合了三網絡自訓練模型和平衡的批歸一化層(TRI-net self-training with BalancEd normalization, TRIBE)在較為寬泛的的可調節學習率的範圍裡表現出一致的優越性能。在四個資料集和多種真實世界資料流中顯示了大幅效能提升,展現了獨一檔的穩定性和穩健性。

方法介紹

#論文方法分為三個部分:
  • #介紹真實世界下的TTA 協定;
  • 平衡的批次歸一化;
  • 三網路自訓練模型。

在真實世界下的TTA 協定

作者採用了數學機率模型對真實世界下具有局部類別不平衡和全局類別不平衡的測試資料流,以及隨著時間變化的領域分佈進行了建模。如下圖 2 所示。

AAAI 2024 | 测试时领域适应的鲁棒性得以保证,TRIBE在多真实场景下达到SOTA

平衡的批次歸一化

#為了修正不平衡測試資料對BN統計量產生的估計偏置,作者提出了一個平衡批歸一化層,該層為每個語義類別分別維護了一對統計量,表示為:

AAAI 2024 | 测试时领域适应的鲁棒性得以保证,TRIBE在多真实场景下达到SOTA

為了更新類別統計量,作者在偽標籤預測的幫助下應用了高效的迭代更新方法,如下所示:

AAAI 2024 | 测试时领域适应的鲁棒性得以保证,TRIBE在多真实场景下达到SOTA


透過偽標籤對各個類別資料的取樣點進行單獨統計,並透過下式重新得到類別平衡下的整體分佈統計量,以此來對齊用類別平衡的來源資料學習好的特徵空間。 AAAI 2024 | 测试时领域适应的鲁棒性得以保证,TRIBE在多真实场景下达到SOTA
在某些特殊情況下,作者發現當類別數量較多AAAI 2024 | 测试时领域适应的鲁棒性得以保证,TRIBE在多真实场景下达到SOTA或偽標籤準確率較低(accuracy<0.5) 的情況下,以上的類別獨立的更新策略效果沒那麼明顯。因此,他們進一步用超參數γ 來融合類別無關更新策略和類別獨立更新策略,如下式:

AAAI 2024 | 测试时领域适应的鲁棒性得以保证,TRIBE在多真实场景下达到SOTA

透過進一步分析和觀察,作者發現當γ=1時,整個更新策略就退化成了RoTTA 中的RobustBN 的更新策略,當γ=0 時是純粹的類別獨立的更新策略,因此,當γ 取值0~1 時可以適應到各種情況下。

三網路自訓練模型

作者在現有的學生- 教師模型的基礎上,添加了一個錨定網絡分支,並引入了錨定損失來約束教師網絡的預測分佈。這種設計受到了 TTAC 的啟發。 TTAC 指出在測試資料流上僅靠自我訓練會容易導致確認偏壓的積累,這個問題在本文中的真實世界中的測試資料流上更加嚴重。 TTAC 採用了從源域收集到的統計資訊實作領域對齊正規化,但對於 Fully TTA 設定來說,這個源域資訊不可收集。

同時,作者也收穫了另一個啟示,無監督領域對齊的成功是基於兩個領域分佈相對高重疊率的假設。因此,作者僅調整了BN 統計量的凍結源域模型來對教師模型進行正則化,避免教師模型的預測分佈偏離源模型的預測分佈太遠(這破壞了之前的兩者分佈高重合率的經驗觀測)。大量實驗證明,本文的發現與創新是正確的且穩健的。以下是錨定損失的表達式:

AAAI 2024 | 测试时领域适应的鲁棒性得以保证,TRIBE在多真实场景下达到SOTA

 下圖展示了TRIBE 網路的框架圖:

AAAI 2024 | 测试时领域适应的鲁棒性得以保证,TRIBE在多真实场景下达到SOTA

##實驗部分

論文作者在4 個資料集上,以兩個真實世界TTA 協定為基準,對TRIBE 進行了驗證。兩種真實世界 TTA 協定分別是全域類別分佈固定的 GLI-TTA-F 和全域類別分佈不固定的 GLI-TTA-V。

AAAI 2024 | 测试时领域适应的鲁棒性得以保证,TRIBE在多真实场景下达到SOTA

上表展示了CIFAR10-C 資料集兩種協定不同不平衡係數下的表現,可以得到以下結論:

##1.只有LAME, TTAC, NOTE, RoTTA 和論文提出的TRIBE 超過了TEST 的基準線,顯示了真實測試流下更穩健的TTA 方法的必要性。

2.全域類別不平衡對現有的TTA 方法帶來了巨大挑戰,如先前的SOTA 方法RoTTA 在I.F.=1 時表現為錯誤率25.20%但在I.F.=200 時錯誤率升到了32.45%,相較之下,TRIBE 能穩定地展現相對較好的表現。

3. TRIBE 的一致性具有絕對優勢,超越了先前所有的方法,並在全域類別平衡的設定下(I.F.=1) 超越先前SOTA ( TTAC) 約7%,在更加困難的全局類別不平衡(I.F.=200) 的設定下獲得了約13% 的性能提升。

4.從 I.F.=10 到 I.F.=200,其他 TTA 方法隨著不平衡度增加,呈現效能下跌的趨勢。而 TRIBE 能維持較穩定的性能表現。這歸因於引入了平衡批歸一化層,更好地考慮了嚴重的類別不平衡和錨定損失,這避免了跨不同領域的過度適應。
 
更多資料集的結果可參考論文原文。

此外,表4 展示了詳細的模組化消融,有以下幾個觀測性結論:

AAAI 2024 | 测试时领域适应的鲁棒性得以保证,TRIBE在多真实场景下达到SOTA

1.僅將BN 替換成平衡批歸一化層(Balanced BN),不更新任何模型參數,只透過forward 更新BN 統計量,就能帶來10.24% (44.62 -> 34.28) 的效能提升,並超越了Robust BN 的錯誤率41.97%。

2.Anchored Loss 結合Self-Training,無論是在之前BN 結構下還是最新的Balanced BN 結構下,都得到了性能的提升,並超越了EMA Model 的正規化效果。
 
本文的其餘部分和長達 9 頁的附錄最終呈現了 17 個詳細表格結果,從多個維度展示了 TRIBE 的穩定性、穩健性和優越性。附錄中也含有平衡批歸一化層的更詳細的理論推導與解釋。

總結與展望

#為回應真實世界中non-i.i.d. 測試資料流、全局類別不平衡和持續的領域轉移等諸多挑戰,研究團隊深入探索如何改進測試時領域適應演算法的穩健性。為了適應不平衡的測試數據,作者提出了一個平衡批歸一化層(Balanced Batchnorm Layer),以實現對統計量的無偏估計,進而提出了一種包含學生網絡、教師網絡和錨定網絡的三層網路結構,以規範基於自我訓練的TTA。

但本文仍然存在不足和改進的空間,由於大量的實驗和出發點都基於分類任務和BN 模組,因此對於其他任務和基於Transformer 模型的適配程度仍然未知。這些問題值得後續工作進一步研究和探索。

以上是TRIBE實現領域適應的穩健性,在多真實場景下達到SOTA的AAAII 2024的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:jiqizhixin.com。如有侵權,請聯絡admin@php.cn刪除