The main challenge in domain generalization (DG) is handling the distribution shift between the training and test data. Recent studies suggest that test-time training (TTT), which adapts the learned model with test data, may be a promising solution to this problem. In general, the performance of a TTT strategy hinges on two main factors: selecting an appropriate auxiliary TTT task for updating, and identifying reliable parameters to update during the test phase. Both prior work and our experiments indicate that, if these two factors are not properly considered, TTT may not improve the learned model and can even be detrimental to it. This work addresses both factors by proposing an Improved Test-Time Adaptation (ITTA) method. First, instead of heuristically defining an auxiliary objective, we propose a learnable consistency loss for the TTT task, whose learnable parameters can be adjusted toward better alignment between the TTT task and the main prediction task. Second, we introduce additional adaptive parameters into the trained model and update only these adaptive parameters during the test phase. Extensive experiments show that the two proposed strategies benefit the learned model (see Figure 1) and that ITTA achieves superior performance over current state-of-the-art methods on several DG benchmarks. Code is available at https://github.com/liangchen527/ITTA.
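To make the two ideas concrete, below is a minimal PyTorch-style sketch, not the authors' implementation: the names (`AdaptiveLayer`, `consistency_loss`, `augment`), the residual 1x1-convolution adapter, and the two-view consistency formulation are illustrative assumptions. It shows (i) an auxiliary consistency loss whose per-channel weights are learnable, and (ii) a test-time step that updates only lightweight adaptive parameters while the backbone stays frozen.

```python
# Minimal sketch of the two ideas above, assuming a PyTorch setup.
# AdaptiveLayer, consistency_loss, and the two-view augmentation are
# illustrative assumptions, not the authors' actual implementation.
import torch
import torch.nn as nn

class AdaptiveLayer(nn.Module):
    """Lightweight residual adapter; only these parameters are updated at test time."""
    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1)
        nn.init.zeros_(self.conv.weight)  # start as an identity mapping
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        return x + self.conv(x)

def consistency_loss(feat_a, feat_b, log_w):
    """Learnable consistency loss: per-channel feature discrepancy weighted by
    learnable parameters log_w (fit jointly with the main task during training)."""
    diff = (feat_a - feat_b).pow(2).mean(dim=(0, 2, 3))  # per-channel MSE
    return (log_w.exp() * diff).mean()

def test_time_step(backbone, adapter, head, log_w, x, augment, lr=1e-3):
    """One adaptation step on a test batch: minimize the auxiliary loss
    w.r.t. the adapter only, then predict with the adapted model."""
    for p in backbone.parameters():  # freeze everything except the adapter
        p.requires_grad_(False)
    opt = torch.optim.SGD(adapter.parameters(), lr=lr)
    feat_a = adapter(backbone(augment(x)))  # two stochastic views of the batch
    feat_b = adapter(backbone(augment(x)))
    loss = consistency_loss(feat_a, feat_b, log_w.detach())  # weights fixed at test time
    opt.zero_grad()
    loss.backward()
    opt.step()
    with torch.no_grad():
        return head(adapter(backbone(x)))  # prediction after adaptation
```

In this sketch, `log_w` would be learned alongside the network during source training, so that the auxiliary loss is already aligned with the main prediction objective before deployment; at test time it is detached and only the adapter is optimized.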