An open environment solution that solves shortcomings such as the Batch Norm layer

WBOY
2023-04-26 10:01:07

Test-Time Adaptation (TTA) guides a model to perform rapid unsupervised or self-supervised learning during the test phase, and is currently an effective tool for improving the out-of-distribution generalization of deep models. In dynamic open scenarios, however, insufficient stability remains a major shortcoming of existing TTA methods and seriously hinders their practical deployment. To address this, a research team from South China University of Technology, Tencent AI Lab and the National University of Singapore analyzed, from a unified perspective, why existing TTA methods are unstable in dynamic scenarios. They point out that normalization layers that rely on batch statistics are one of the key causes of instability, and that some samples in the test data stream with noisy or large gradients can easily drive the model to a degenerate trivial solution. Based on this analysis, they propose SAR, a sharpness-aware and reliable test-time entropy minimization method that achieves stable and efficient online test-time adaptation and generalization in dynamic open scenarios. The work was accepted to ICLR 2023 as an Oral (Top-5% among accepted papers).

  • Paper title: Towards Stable Test-time Adaptation in Dynamic Wild World
  • Paper address: https://openreview.net/forum?id=g2YraF75Tj
  • Open source code: https://github.com/mr-eggplant/SAR

What is Test-Time Adaptation?

Traditional machine learning usually trains on a large amount of data collected in advance and then freezes the model for inference. This paradigm often performs very well when the test and training data come from the same distribution. In practice, however, the test distribution can easily deviate from the original training distribution (distribution shift). For example, when test data is collected: 1) weather changes introduce rain, snow, or fog occlusion into images; 2) images are blurred by improper shooting, or contain noise due to sensor degradation; 3) the model is trained on data collected in northern cities but deployed in southern cities. Such situations are very common, yet they are often fatal for deep models: performance may drop significantly, which seriously restricts widespread real-world deployment, especially in high-risk applications such as autonomous driving.

Figure 1 Schematic diagram of Test-Time Adaptation (refer to [5]) and a comparison of its characteristics with those of existing methods

Unlike the traditional machine learning paradigm, Test-Time Adaptation (TTA), as shown in Figure 1, first fine-tunes the model on each arriving test sample in a self-supervised or unsupervised manner, and then uses the updated model to make the final prediction. Typical self-supervised or unsupervised objectives include rotation prediction, contrastive learning, and entropy minimization. These methods all exhibit excellent out-of-distribution generalization. Compared with traditional fine-tuning and unsupervised domain adaptation, TTA adapts online, which makes it more efficient and more widely applicable. In addition, fully test-time adaptation methods [2] can be applied to any pre-trained model without access to the original training data and without interfering with the original training process. These advantages greatly enhance the practical versatility of TTA. Coupled with its strong performance, TTA has become a very active research direction in transfer learning, generalization, and related fields.
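
As a small illustration of the entropy-minimization objective mentioned above, the following sketch computes the Shannon entropy of a softmax prediction in plain Python. The function names `softmax` and `prediction_entropy` and the example logits are illustrative assumptions, not code from the paper; the point is only that a confident prediction has low entropy, which is the quantity such TTA methods try to reduce.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def prediction_entropy(logits):
    """Shannon entropy (in nats) of the softmax distribution."""
    probs = softmax(logits)
    return -sum(p * math.log(p) for p in probs if p > 0.0)

# Hypothetical logits for a 4-class problem.
confident = prediction_entropy([8.0, 0.0, 0.0, 0.0])  # peaked prediction
uncertain = prediction_entropy([0.0, 0.0, 0.0, 0.0])  # uniform prediction
```

The uniform prediction reaches the maximum entropy ln(4), while the peaked one is close to zero.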

Why Wild Test-Time Adaptation?

Although existing TTA methods have shown great potential for out-of-distribution generalization, this performance is often obtained under specific test conditions: the samples in the data stream over a period of time all come from the same type of distribution shift, the ground-truth class distribution of the test samples is uniform and random, and a mini-batch of samples is required before each adaptation step. In the real open world, these implicit assumptions are hard to satisfy at all times. In practice, the test data stream may arrive in any combination, and ideally the model should make no assumptions about the form in which it arrives. As shown in Figure 2, the test data stream may well involve: (a) samples drawn from different distribution shifts (i.e., mixed distribution shifts); (b) very small batch sizes (even 1); (c) a ground-truth class distribution that is imbalanced over a period of time and changes dynamically. This article refers to TTA under these scenarios as Wild TTA. Unfortunately, existing TTA methods often prove fragile and unstable in these wild scenarios: adaptation gains are limited, and the original model's performance may even be harmed. Solving the Wild TTA problem is therefore an unavoidable and important step toward deploying TTA widely in real applications.

Figure 2 Dynamic open scenarios in test-time model adaptation

Solution ideas and technical solutions

This article analyzes the reasons why TTA fails in many Wild scenarios from a unified perspective, and then provides solutions.

1. Why is Wild TTA unstable?

(1) Batch Normalization (BN) is one of the key causes of TTA instability in dynamic scenarios: existing TTA methods are usually built on adapting BN statistics, that is, the mean and standard deviation in the BN layers are computed from the test data. However, in the three dynamic scenarios above, the statistics estimated within the BN layers become biased, which makes TTA unstable:

  • Scenario (a): BN statistics effectively represent one particular test data distribution, so using a single set of statistics to estimate multiple distributions at once inevitably limits performance; see Figure 3.
  • Scenario (b): BN statistics depend on the batch size, and accurate estimates are difficult to obtain from small batches; see Figure 4.
  • Scenario (c): Samples with an imbalanced label distribution bias the statistics within the BN layers toward a specific class (the class that makes up the larger share of the batch); see Figure 5.
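
The effect described in scenarios (b) and (c) can be reproduced with a toy calculation. The sketch below is only an illustration under stated assumptions: the one-dimensional "activations" for two classes are made-up numbers, and `batch_stats` is a hypothetical helper that mimics the per-batch mean and biased variance a BN layer would compute at test time.

```python
from statistics import fmean, pvariance

# Hypothetical 1-D activations for two classes with different feature means.
class_a = [0.9, 1.0, 1.1, 1.0]  # centered near 1.0
class_b = [4.9, 5.0, 5.1, 5.0]  # centered near 5.0

def batch_stats(batch):
    """Mean and biased variance, as a BN layer would estimate them per batch."""
    return fmean(batch), pvariance(batch)

balanced = class_a[:2] + class_b[:2]    # two samples per class
imbalanced = class_a[:1] + class_b[:3]  # batch dominated by class B
single = class_a[:1]                    # batch size 1

mean_bal, var_bal = batch_stats(balanced)
mean_imb, var_imb = batch_stats(imbalanced)
mean_one, var_one = batch_stats(single)
```

With the imbalanced batch, the estimated mean drifts toward the dominant class, and with a single sample the variance estimate degenerates to zero, so the statistics no longer describe the underlying test distribution.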

To further verify the above analysis, this article considers three widely used models (equipped with Batch, Layer, and Group Norm respectively) and performs an analysis based on two representative TTA methods (TTT [1] and Tent [2]). The conclusion is that batch-independent normalization layers (Group Norm and Layer Norm) circumvent the limitations of Batch Norm to a certain extent, are better suited to TTA in dynamic open scenarios, and are more stable. The method in this article is therefore designed on models equipped with Group or Layer Norm.
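
To make the notion of "batch-independent" concrete, here is a minimal sketch, assuming simplified normalization layers without learnable scale and shift. `layer_norm` normalizes each sample over its own features, while `batch_norm` normalizes each feature over the batch; both are hypothetical helpers written for this illustration (Group Norm behaves like the per-sample case, applied to groups of channels).

```python
from statistics import fmean, pvariance

EPS = 1e-5  # small constant for numerical stability (assumed value)

def layer_norm(sample):
    """Normalize one sample over its own features (batch-independent)."""
    mu, var = fmean(sample), pvariance(sample)
    return [(v - mu) / (var + EPS) ** 0.5 for v in sample]

def batch_norm(batch):
    """Normalize each feature over the batch dimension (batch-dependent)."""
    n_feat = len(batch[0])
    mus = [fmean([s[j] for s in batch]) for j in range(n_feat)]
    variances = [pvariance([s[j] for s in batch]) for j in range(n_feat)]
    return [
        [(s[j] - mus[j]) / (variances[j] + EPS) ** 0.5 for j in range(n_feat)]
        for s in batch
    ]

x = [1.0, 2.0, 3.0, 4.0]
batch_1 = [x, [0.0, 0.0, 0.0, 0.0]]      # x paired with one neighbor
batch_2 = [x, [10.0, 10.0, 10.0, 10.0]]  # x paired with a different neighbor

ln_same = layer_norm(batch_1[0]) == layer_norm(batch_2[0])
bn_same = batch_norm(batch_1)[0] == batch_norm(batch_2)[0]
```

The per-sample normalization of `x` is identical in both batches, whereas the batch-normalized output of `x` changes with whatever else happens to be in the batch.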

Figure 3 Performance of different methods and models (different normalization layers) under mixed distribution shifts

Figure 4 Performance of different methods and models (different normalization layers) under different batch sizes. The shaded area represents the standard deviation of the model's performance. The standard deviations of ResNet50-BN and ResNet50-GN are too small to be visible in the figure (the same applies to the figure below)

Figure 5 Performance of different methods and models (different normalization layers) under online imbalanced label distribution shift. The larger the Imbalance Ratio on the horizontal axis, the more severe the label imbalance.

(2) Online entropy minimization can easily optimize the model to a degenerate trivial solution, that is, predicting every sample as the same class: according to Figure 6 (a) and (b), when the distribution shift is severe (level 5), the model suddenly degrades and collapses during online adaptation, and all samples (with different ground-truth classes) are predicted as the same class. At the same time, the norm of the model gradient increases rapidly around the time of collapse and then drops to almost 0 (see Figure 6 (c)). This indirectly suggests that some large or noisy gradients may have destroyed the model parameters and caused the collapse.

Figure 6 Analysis of failure cases in entropy minimization during online testing

2. Sharpness-aware and reliable test-time entropy minimization method

To alleviate the model degradation problem described above, this paper proposes the Sharpness-Aware and Reliable entropy minimization method (SAR) for test-time adaptation. It alleviates the problem in two ways: 1) reliable entropy minimization removes some samples that produce large or noisy gradients from the model's adaptive update; 2) sharpness-aware optimization makes the model insensitive to the noisy gradients produced by some of the remaining samples. The details are as follows:

Reliable entropy minimization: entropy is used as a surrogate criterion for gradient-based sample selection, and high-entropy samples (including the samples in areas 1 and 2 of Figure 6 (d)) are excluded from model adaptation and do not participate in the model update:

$$\min_{\Theta} \; S(x)\, E(x; \Theta), \qquad S(x) = \mathbb{I}_{\{E(x; \Theta) < E_0\}}(x)$$

where $x$ represents the test sample, $\Theta$ represents the model parameters, $\mathbb{I}_{\{\cdot\}}(\cdot)$ represents the indicator function, $E(x; \Theta)$ represents the entropy of the sample's prediction, and $E_0$ is a hyperparameter. Only if $E(x; \Theta) < E_0$ will the sample participate in the backpropagation calculation.
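
The selection rule can be sketched in a few lines of plain Python. This is a minimal illustration rather than the released implementation: `reliable_entropy_loss` is a hypothetical helper, the predicted probabilities are made-up numbers, and the threshold `0.4 * ln(num_classes)` is an assumed value chosen for the example.

```python
import math

def entropy(probs):
    """Shannon entropy (in nats) of a probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def reliable_entropy_loss(batch_probs, e0):
    """Average entropy over samples whose entropy is below e0.

    Returns (loss, mask); samples with mask False are excluded from the update.
    """
    entropies = [entropy(p) for p in batch_probs]
    mask = [e < e0 for e in entropies]
    kept = [e for e, keep in zip(entropies, mask) if keep]
    loss = sum(kept) / len(kept) if kept else 0.0
    return loss, mask

num_classes = 3
e0 = 0.4 * math.log(num_classes)  # assumed threshold, for illustration only
batch = [
    [0.98, 0.01, 0.01],  # confident prediction: kept
    [0.40, 0.30, 0.30],  # near-uniform prediction: filtered out
    [0.90, 0.05, 0.05],  # fairly confident prediction: kept
]
loss, mask = reliable_entropy_loss(batch, e0)
```

In an automatic-differentiation framework the mask would be treated as a constant, so gradients flow only through the entropies of the retained samples.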

Sharpness-aware entropy optimization: the samples that pass the reliable sample selection mechanism inevitably still include samples from area 4 of Figure 6 (d), which may produce noisy or large gradients that continue to disturb the model. To address this, this article optimizes the model toward a flat minimum, so that it is insensitive to model updates caused by noisy gradients, that is, such updates do not degrade its original performance. The optimization objective is:

$$\min_{\Theta} \; E^{SA}(x; \Theta), \qquad E^{SA}(x; \Theta) \triangleq \max_{\|\epsilon\|_2 \le \rho} E(x; \Theta + \epsilon)$$

where $E(x; \Theta)$ is the entropy of the prediction for sample $x$ and $\rho$ bounds the norm of the weight perturbation $\epsilon$. The final gradient update for the above objective takes the form:

$$\nabla_{\Theta} E^{SA}(x; \Theta) \approx \nabla_{\Theta} E(x; \Theta)\big|_{\Theta + \hat{\epsilon}(\Theta)}$$

Here $\hat{\epsilon}(\Theta) = \rho \, \nabla_{\Theta} E(x; \Theta) / \|\nabla_{\Theta} E(x; \Theta)\|_2$ is inspired by SAM [4] and is obtained as an approximate solution via a first-order Taylor expansion. For details, please refer to the original paper and code.
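
The two-step structure of a SAM-style update (perturb the weights along the normalized gradient, then descend using the gradient taken at the perturbed point) can be sketched on a toy problem. Everything below is an assumption made for illustration: `sam_step` is a hypothetical helper, the quadratic `toy_loss` stands in for the entropy loss, and the values of `rho` and `lr` are arbitrary.

```python
import math

def sam_step(theta, grad_fn, rho=0.05, lr=0.1):
    """One sharpness-aware update: the gradient is evaluated at theta + eps_hat."""
    g = grad_fn(theta)
    norm = math.sqrt(sum(gi * gi for gi in g))
    if norm == 0.0:
        return list(theta)  # stationary point: no direction to perturb along
    eps_hat = [rho * gi / norm for gi in g]              # first-order worst case
    perturbed = [t + e for t, e in zip(theta, eps_hat)]
    g_sa = grad_fn(perturbed)                            # gradient at theta + eps_hat
    return [t - lr * gi for t, gi in zip(theta, g_sa)]

def toy_loss(theta):
    """Toy stand-in for the entropy loss: theta_0^2 + 2 * theta_1^2."""
    return theta[0] ** 2 + 2.0 * theta[1] ** 2

def toy_grad(theta):
    """Analytic gradient of toy_loss."""
    return [2.0 * theta[0], 4.0 * theta[1]]

theta = [1.0, 1.0]
for _ in range(50):
    theta = sam_step(theta, toy_grad)
final_loss = toy_loss(theta)
```

Each step costs two gradient evaluations, which is the usual price of sharpness-aware optimization compared with a plain gradient step.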

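At this point, the overall optimization objective of this article is:

$$\min_{\Theta} \; S(x)\, E^{SA}(x; \Theta)$$

where $S(x)$ is the indicator that selects reliable (low-entropy) samples and $E^{SA}(x; \Theta)$ is the worst-case entropy within a small neighborhood of $\Theta$.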

In addition, to guard against the possibility that the above scheme still fails under extreme conditions, a model recovery strategy is introduced: a moving monitor tracks whether the model has degraded or collapsed, and the updated model parameters are restored to their original values when necessary.
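
One possible way to realize such a recovery rule is sketched below. It assumes that collapse shows up as an unusually small moving average of the entropy loss (a collapsed model predicts one class with very high confidence); the class name `RecoveryMonitor`, the threshold `0.2`, and the momentum `0.9` are illustrative assumptions, not values confirmed by this article.

```python
class RecoveryMonitor:
    """Tracks a moving average of the entropy loss and flags suspected collapse."""

    def __init__(self, threshold=0.2, momentum=0.9):
        self.threshold = threshold  # assumed collapse threshold
        self.momentum = momentum    # assumed smoothing factor
        self.ema = None

    def update(self, loss_value):
        """Return True when the caller should restore the original weights."""
        if self.ema is None:
            self.ema = loss_value
        else:
            self.ema = self.momentum * self.ema + (1.0 - self.momentum) * loss_value
        if self.ema < self.threshold:
            self.ema = None  # restart monitoring after a reset
            return True
        return False

healthy = RecoveryMonitor()
healthy_flags = [healthy.update(v) for v in (1.2, 1.1, 1.0)]

collapsing = RecoveryMonitor()
collapsing_flags = [collapsing.update(v) for v in [0.5] + [0.01] * 11]
```

When `update` returns True, the caller would copy the saved pre-adaptation parameters back into the model before processing the next batch.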

Experimental evaluation

Performance comparison in dynamic open scenarios

SAR was evaluated on the ImageNet-C dataset under the three dynamic open scenarios described above, namely a) mixed distribution shifts, b) single-sample adaptation, and c) online imbalanced class distribution shift; the results are shown in Tables 1, 2, and 3. SAR achieves strong results in all three scenarios. In scenarios b) and c) in particular, with ViT-Base as the base model, its accuracy exceeds the current SOTA method EATA by nearly 10%.

Table 1 Performance comparison between SAR and existing methods on a mixture of the 15 corruption types in ImageNet-C, corresponding to dynamic scenario (a), together with an efficiency comparison with existing methods

Table 2 Performance comparison between SAR and existing methods in the single-sample adaptation scenario on ImageNet-C, corresponding to dynamic scenario (b)

Table 3 Performance comparison between SAR and existing methods in the online imbalanced class distribution shift scenario on ImageNet-C, corresponding to dynamic scenario (c)

Ablation experiment

Comparison with gradient clipping: gradient clipping is a simple and direct way to prevent large gradients from affecting model updates (or even causing collapse). SAR is compared here with two variants of gradient clipping (by value and by norm). As shown in the figure below, gradient clipping is very sensitive to the choice of the clipping threshold δ: a smaller δ is equivalent to not updating the model at all, while a larger δ can hardly prevent model collapse. In contrast, SAR requires no elaborate hyperparameter selection and performs significantly better than gradient clipping.
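
For readers unfamiliar with the two clipping variants used as baselines, the sketch below shows what each does to a gradient vector. The helper names `clip_by_value` and `clip_by_norm` and the example gradient are assumptions made for illustration; deep learning frameworks ship their own equivalents.

```python
import math

def clip_by_value(grad, delta):
    """Clamp every gradient component to the interval [-delta, delta]."""
    return [max(-delta, min(delta, g)) for g in grad]

def clip_by_norm(grad, delta):
    """Rescale the whole gradient so that its L2 norm is at most delta."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= delta:
        return list(grad)
    scale = delta / norm
    return [g * scale for g in grad]

grad = [3.0, -4.0]                    # hypothetical gradient with L2 norm 5
by_value = clip_by_value(grad, 1.0)   # each component clamped independently
by_norm = clip_by_norm(grad, 1.0)     # direction preserved, length reduced
```

Clipping by value can change the gradient direction, while clipping by norm preserves it; in both cases a very small δ shrinks every update toward zero, which matches the sensitivity to δ described above.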

Figure 7 Comparison with gradient clipping on ImageNet-C (shot noise, level 5) in the online imbalanced label distribution shift scenario. Accuracy is calculated online over all previous test samples

Impact of different modules on algorithm performance: as shown in the table below, the modules of SAR work together to effectively improve the stability of test-time adaptation in dynamic open scenarios.

Table 4 Ablation experiments of SAR on ImageNet-C (level 5) in the online imbalanced label distribution shift scenario

Loss surface sharpness visualization: the figure below visualizes the loss function by adding perturbations to the model weights. SAR has a larger area within the lowest-loss contour (the dark blue region) than Tent, indicating that the solution obtained by SAR is flatter, more robust to noisy or large gradients, and more resistant to interference.

Figure 8 Entropy loss surface visualization

Conclusion

This article addresses the instability of online test-time adaptation in dynamic open scenarios. It first analyzes, from a unified perspective, why existing methods fail in real dynamic scenarios, and designs thorough experiments to verify the analysis. Based on this analysis, it proposes a sharpness-aware and reliable test-time entropy minimization method, which achieves stable and efficient online test-time adaptation by suppressing the influence of test samples with large or noisy gradients on model updates.

Statement:
This article is reproduced from 51cto.com. In case of any infringement, please contact admin@php.cn for removal.