为在线学习创建持续进化的神经网络,已经有好几次尝试。但是他们不可避免地遇到了所谓的灾难性遗忘(有时也称为灾难性干扰)问题,在这种情况下,适应新的任务会导致神经网络“忘记”它以前学过的东西。早在1989年,研究人员McCloskey 和 Cohen就首次发现了这种现象,当时他们测试了一个网络按顺序学习关联列表任务的能力。在他们的实验中,第一个任务包括从两组 A 和 B 中学习成对的任意单词,如“火车头 - 抹布,窗户 - 理由,自行车 - 树等”。然后它开始学习第二个任务,在这个任务中,A 与 C 组的不同单词配对,比如“火车头 - 云,窗户 - 书,自行车 - 沙发等” ,并在1、5、10和20次迭代学习 AC 列表后,测试它记住 AB 列表中的配对的能力。下面的图表 b)显示了在开始学习 AC 任务后,网络是如何迅速忘记 AB 任务的,相比之下,在相同的实验设置 a)中人类的表现,表明我们的大脑能够更有效地记住先前任务的知识。 毫无疑问,构建一个结构有限,但能够在连续的数据流中保留过去经验知识的网络是非常有挑战性的。克服灾难性遗忘的最初策略依赖于随着新类别的学习,逐步向网络分配更多的资源,这种方法对于大多数现实世界的应用程序来说最终是不可持续的。现在让我们来看看一些最新的策略,这些策略可以迫使网络记住已经学到的东西。
记忆的策略
正则化(Regularization)
处理灾难性遗忘的一个机制是正则化,已经被深入研究过。正如我们所知道的,一个网络通过调整连接的权重来适应学习新的任务,而正规化涉及到改变权重的可塑性,基于他们对以前的任务的重要性。 在2017年一篇高引用的论文“Overcoming catastrophic forgetting in neural networks”中,Kirkpatrick 等人引入了一种称为EWC(Elastic Weight Consolidation)的正则化技术。当遇到新任务时,EWC 通过约束权重尽量靠近学到的值,来保持对以前学习的任务重要的连接的准确性。 为了说明 EWC 是如何工作的,假设我们正在学习一个分类任务 A,我们的网络正在学习一组权重 θ。实际上,在 A 上有多种可以得到良好性能的 θ 设置——上图中灰色椭圆表示的权重范围。当网络继续学习一个与不同权重范围(奶油色椭圆)相关的另一个任务 B 时,它的重量因此被调整,以至于它们落在A表现好的权重范围外,如蓝色箭头所示,灾难性遗忘就发生了。
在 EWC 中,引入了二次惩罚项来约束网络参数,使其在学习 B 时保持在任务 A 的低误差区域内,如红色箭头所示。二次惩罚作为一种“弹簧”,限定了参数在以前学习到的解决方案范围内,因此得名Elastic Weight Consolidation。弹簧的弹性度,即二次惩罚的度,在权重之间的不同取决于权重对于先前任务的“重要性”。例如,在图表中,任务 A 的2D权重椭球体沿 x 维的长度比 y 维的长,表明 x 权重对于 A 更重要,因此在调整学习 B 时弹性比 y 权重小。若未能使弹簧以这种方式适应,而是对每个权重施加相同的弹性系数的话,将导致权重不能很好地适合任一任务,如图中的绿色箭头所示。
EWC 模型在一连串任务上进行训练,每个任务由一批数据组成。任务是手写 MNIST 数字图像,固定数量随机洗牌。一旦模型训练了一个任务的数据,它就会转移到下一个任务的批处理中,并且不会再次遇到前一个任务的数据,这就可以测试 EWC“记住”如何执行以前学过的任务的能力。下面的图表显示了 EWC 在一系列任务 A、 B 和 C 上的测试性能,这些任务是逐步进行训练的。 我们可以看到,尽管学习了新的任务,EWC 的性能在之前学习的任务中保持得相当稳定,作为对比的是采用对所有权重使用相同的二次惩罚的方法(绿线)和一个根本不包含惩罚,只是使用标准的随机梯度下降的方法(蓝线)——这两种方法都显示了任务 A 的灾难性遗忘,例如,当任务 B 和 C 被学习时。 重播(Replay) 重播是另一种减少遗忘的流行方法,它包括存储以前遇到的训练数据的一些代表(representation)。数据存储在被称为重播缓冲区的地方。这个技术最早在2016年底Rebuffi 等人发表的论文“iCaRL: Incremental Classifier and Representation Learning”中提出。在其重播缓冲区中,iCaRL 为训练期间遇到的每个类存储成组的图像,称为“样本”图像。我们的目标是让这些图片尽可能代表它们各自的类别。对于训练,iCaRL 一次处理一批(含各类别)。当遇到一个新类时,将使用所有存储的样本和新数据创建一个模拟训练集。所有这些数据都通过网络,之前学习的类的输出存储到下一步,在下一步中更新网络的参数。通过最小化损失函数对网络进行更新,该损失函数将分类损失和蒸馏损失结合在一起,分类损失让网络输出新遇到的类的正确标签,蒸馏损失则鼓励网络重新生成以前学过的类的标签。