首頁 >科技週邊 >人工智慧 >解決Batch Norm層等短板的開放環境解決方案

解決Batch Norm層等短板的開放環境解決方案

WBOY
WBOY轉載
2023-04-26 10:01:07784瀏覽

測試時自適應(Test-Time Adaptation, TTA)方法在測試階段指導模型進行快速無監督 / 自監督學習,是目前用於提升深度模型分佈外泛化能力的一種強有效工具。然而在動態開放場景中,穩定性不足仍是現有 TTA 方法的一大短板,嚴重阻礙了其實際部署。為此,來自華南理工大學、騰訊AI Lab 及新加坡國立大學的研究團隊,從統一的角度對現有TTA 方法在動態場景下不穩定原因進行分析,指出依賴於Batch 的歸一化層是導致不穩定的關鍵原因之一,另外測試資料流中某些具有雜訊/ 大規模梯度的樣本容易將模型最佳化至退化的平凡解。基於此進一步提出銳利度敏感且可靠的測試時熵最小化方法 SAR,實現動態開放場景下穩定、高效的測試時模型在線遷移泛化。本工作已入選 ICLR 2023 Oral (Top-5% among accepted papers)。

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

  • 論文標題:Towards Stable Test-time Adaptation in Dynamic Wild World
  • 論文網址:https://openreview.net/forum?id=g2YraF75Tj
  • 開源程式碼:https://github.com/ mr-eggplant/SAR

什麼是Test-Time Adaptation?

傳統機器學習技術通常在預先收集好的大量訓練資料上進行學習,之後固定模型進行推理預測。這種範式在測試與訓練資料來自相同資料分佈時,往往會取得十分優異的表現。但在實際應用中,測試資料的分佈很容易偏離原始訓練資料的分佈(distribution shift),例如在擷取測試資料的時候:1)天氣的變化使得影像中包含有雨雪、霧的遮蔽;2)由於拍攝不當使得影像模糊,或感測器退化導致影像中包含雜訊;3)模型基於北方城市擷取資料進行訓練,卻被部署到了南方城市。以上種種情況十分常見,但對於深度模型而言往往是很致命的,因為在這些場景下其性能可能會大幅下降,嚴重製約了其在現實世界中(尤其是類似於自動駕駛等高風險應用)的廣泛部署。

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

#圖1 Test-Time Adaptation 示意圖(參考[5])及其與現有方法特點對比

不同於傳統機器學習範式,如圖1 所示在測試樣本到來後,Test-Time Adaptation (TTA) 首先基於該資料利用自監督或無監督的方式對模型進行精細化微調,而後再使用更新後的模型做出最終預測。典型的自 / 無監督學習目標包括:旋轉預測、對比學習、熵最小化等等。這些方法均展現出了優異的分佈外泛化(Out-of-Distribution Generalization)效能。相較於傳統的 Fine-Tuning 以及 Unsupervised Domain Adaptation 方法,Test-Time Adaptation 能夠做到線上遷移,效率更高也更普適。另外完全測試時適應方法 [2] 其可以針對任意預訓練模型進行適應,無需原始訓練資料也無需干涉模型原始的訓練過程。以上優點極大增強了 TTA 方法的現實通用性,再加上其展現出來的優異性能,使得 TTA 成為遷移、泛化等相關領域極為熱點的研究方向。

為什麼要 Wild Test-Time Adaptation?

儘管現有TTA 方法在分佈外泛化方面已表現出了極大的潛力,但這種優異的性能往往是在一些特定的測試條件下所獲得的,例如測試資料流在一段時間內的樣本均來自於同一種分佈偏移類型、測試樣本的真實類別分佈是均勻且隨機的,以及每次需要有一個mini-batch 的樣本後才可以進行適應。但事實上,以上這些潛在假設在現實開放世界中是很難被一直滿足的。在實際中,測試資料流可能以任意的組合方式到來,而理想情況下模型不應對測試資料流的到來形式做出任何假設。如圖2 所示,測試資料流完全可能遇到:(a)樣本來自不同的分佈偏移(即混合樣本偏移);(b)樣本batch size 非常小(甚至為1);(c)樣本在一段時間內的真實類別分佈是不均衡的且會動態變化的。本文將上述場景下的 TTA 統稱為 Wild TTA。但不幸的是,現有 TTA 方法在這些 Wild 場景下經常會表現得十分脆弱、不穩定,遷移效能有限,甚至可能損壞原始模型的效能。因此,若想真正實現 TTA 方法在實際場景中的大範圍、深度化應用部署,則解決 Wild TTA 問題即是其中不可避免的重要一環。

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

圖2 模型測試時自適應中的動態開放情境

解決想法與技術方案

本文從統一角度對TTA 在眾多Wild 場景下失敗原因進行分析,進而給出解決方案。

1. 為何 Wild TTA 會不穩定?

(1)Batch Normalization (BN) 是動態場景下TTA 不穩定的關鍵原因之一:現有TTA 方法通常是建立在BN 統計量自適應基礎之上的,即使用測試資料來計算BN 層中的平均值及標準差。然而,在3 個實際動態場景中,BN 層內的統計量估計準確度都會出現偏差,進而引發不穩定的TTA:

  • ##場景(a) :由於BN 的統計量實際上代表了某一種測試資料分佈,使用一組統計量參數同時估計多個分佈不可避免會獲得有限的效能,請參閱圖3;
  • ##場景(b):BN 的統計量依賴batch size 大小,在小batch size 樣本上很難得到準確的BN 的統計量估計,參見圖4;
  • 情境(c):非均衡標籤分佈的樣本會導致BN 層內統計量有偏差,即統計量偏向某一特定類別(此batch 中佔比較大的類別),參見圖5;
為進一步驗證上述分析,本文考慮3 種廣泛應用的模型(搭載不同的BatchLayerGroup Norm),基於兩種代表性TTA 方法(TTT [1] 和Tent [2])進行分析驗證。最終得出結論為:

batch 無關的 Norm 層(Group 和 Layer Norm)一定程度上規避了 Batch Norm 局限性,更適合在動態開放場景中執行 TTA,其穩定性也更高。因此,本文也將以搭載 GroupLayer Norm 的模型進行方法設計。

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

#圖3 不同方法與模型(不同歸一化層)在混合分佈偏移下性能表現

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

#######圖4 不同方法與模型(不同歸一化層)在不同batch size 下效能表現。圖中陰影區域表示此模型效能的標準差,ResNet50-BN 和 ResNet50-GN 的標準差太小導致在圖中不顯著(下圖同)############

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

圖5 不同方法與模型(不同歸一化層)在線上不平衡標籤分佈偏移下效能表現,圖中橫軸Imbalance Ratio 越大代表的標籤不平衡程度越嚴重

#(2)在線熵最小化易將模型優化至退化的平凡解,即將任意樣本預測到同一個類別:根據圖6 (a) 和(b) 顯示,在分佈偏移程度嚴重(level 5)時,在線自適應過程中突然出現了模型退化崩潰現象,即所有樣本(真實類別不同)被預測到同一類;同時,模型梯度的 範數在模型崩潰前後快速增大而後降至幾乎為0,見圖6(c),側面說明可能是某些大尺度/ 雜訊梯度破壞了模型參數,進而導致模型崩潰。

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

圖6 線上測試時熵最小化中的失敗案例分析

2. 銳利度敏感且可靠的測試時熵最小化方法

為了緩解上述模型退化問題,本文提出了銳利度敏感且可靠的測試時熵最小化方法(Sharpness-aware and Reliable Entropy Minimization Method, SAR)。其從兩個方面緩解此問題:1)可靠熵最小化從模型自適應更新移除部分產生較大/ 雜訊梯度的樣本;2)模型銳利度最佳化使得模型對剩餘樣本中所產生的某些雜訊梯度不敏感。具體細節闡述如下:

可靠熵最小化:基於Entropy 建立梯度選擇的替代判斷指標,將高熵樣本(包含圖6 (d) 中區域1 和2 的樣本)排除在模型自適應之外不參與模型更新:

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

其中x 表示測試樣本,Θ 表示模型參數,Batch Norm层等暴露TTA短板,开放环境下解决方案来了表示指示函數,Batch Norm层等暴露TTA短板,开放环境下解决方案来了表示樣本預測結果的熵,Batch Norm层等暴露TTA短板,开放环境下解决方案来了為超參數。只有當 Batch Norm层等暴露TTA短板,开放环境下解决方案来了

時樣本才會參與反向傳播計算。

銳利度敏感的熵優化:透過可靠樣本選擇機制過濾後的樣本中,無法避免仍含有圖6 (d) 區域4中的樣本,這些樣本可能產生雜訊/ 較大梯度繼續幹擾模型。為此,本文考慮將模型最佳化至一個flat minimum,使其能夠對雜訊梯度帶來的模型更新不敏感,即不影響其原始模型效能,最佳化目標為:

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

上述目標的最終梯度更新形式如下:

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

其中 Batch Norm层等暴露TTA短板,开放环境下解决方案来了 受启发于 SAM [4] 通过一阶泰勒展开近似求解得到,具体细节可参见本论文原文与代码。

至此,本文的总体优化目标为:

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

此外,为了防止极端条件下上述方案仍可能失败的情况,进一步引入了一个模型复原策略:通过移动监测模型是否出现退化崩溃,决定在必要时刻对模型更新参数进行原始值恢复。

实验评估

在动态开放场景下的性能对比

SAR 基于上述三种动态开放场景,即 a)混合分布偏移、b)单样本适应和 c)在线不平衡类别分布偏移,在 ImageNet-C 数据集上进行实验验证,结果如表 1, 2, 3 所示。SAR 在三种场景中均取得显著效果,特别是在场景 b)和 c)中,SAR 以 VitBase 作为基础模型,准确率超过当前 SOTA 方法 EATA 接近 10%。

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

表 1 SAR 与现有方法在 ImageNet-C 的 15 种损坏类型混合场景下性能对比,对应动态场景 (a);以及和现有方法的效率对比

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

表 2 SAR 与现有方法在 ImageNet-C 上单样本适应场景中的性能对比,对应动态场景 (b)

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

表 3 SAR 与现有方法在 ImageNet-C 上在线非均衡类别分布偏移场景中性能对比,对应动态场景(c)

消融实验

与梯度裁剪方法的对比:梯度裁剪避免大梯度影响模型更新(甚至导致坍塌)的一种简单且直接的方法。此处与梯度裁剪的两个变种(即:by value or by norm)进行对比。如下图所示,梯度裁剪对于梯度裁剪阈值 δ 的选取很敏感,较小的 δ 与模型不更新的结果相当,较大的 δ 又难以避免模型坍塌。相反,SAR 不需要繁杂的超参数筛选过程且性能显著优于梯度裁剪。

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

图 7 与梯度裁剪方法的在 ImageNet-C(shot nosise, level 5) 上在线不平衡标签分布偏移场景中的性能对比。其中准确率是基于所有之前的测试样本在线计算得出

不同模块对算法性能的影响:如下表所示,SAR 的不同模块协同作用,有效提升了动态开放场景下测试时模型自适应稳定性。

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

表 4 SAR 在 ImageNet-C (level 5) 上在线不平衡标签分布偏移场景下的消融实验

Loss 表面的銳利度視覺化:透過在模型權重增加擾動對損失函數視覺化的結果如下圖所示。其中,SAR 相較於 Tent 在最低損失等高線內的區域(深藍色區域)更大,表明 SAR 獲得的解更加平坦,對於噪聲 / 較大梯度更加魯棒,抗干擾能力更強。

Batch Norm层等暴露TTA短板,开放环境下解决方案来了

圖8 熵損失表面視覺化

結語

本文致力於解決在動態開放場景中模型線上測試時自適應不穩定的難題。為此,本文首先從統一的角度對已有方法在實際動態場景失效的原因進行分析,並設計完整的實驗進行深度驗證。基於這些分析,本文最終提出銳利度敏感且可靠的測試時熵最小化方法,透過抑制某些具有較大梯度/ 雜訊測試樣本對模型更新的影響,實現了穩定、高效的模型在線測試時自適應。

以上是解決Batch Norm層等短板的開放環境解決方案的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:51cto.com。如有侵權,請聯絡admin@php.cn刪除