来源“上海交大电院”
近年来,以学习通用环境表征为目的的预测学习(Predictive Learning)越来越多地被应用到工业制造、自动驾驶等场景的各种时空决策任务中。针对持续任务学习设定下的时空预测学习问题,电子信息与电气工程学院人工智能研究院杨小康教授带领的团队通过引入并改进已有的持续学习方法,开创性地提出了可持续时空预测学习框架CPL (Continual Predictive Learning)。由杨小康教授和王韫博助理教授指导的相关研究工作“Continual Predictive Learning from Videos”已被CVPR 2022收录并被选为口头报告(oral presentation)(每年Oral约占投稿数的5%)。
预测学习(Predictive Learning)最早由图灵奖获得者Yann LeCun在NIPS 2016大会主题报告中首先被提出。其核心思想可以简单总结为如何通过完成基于给定视频片段的数据预测未来连续帧这一无监督预测学习任务,使得智能体可以学习到数据所在环境中包含的动态先验信息,如物体在力的作用下的运动状态,从而进一步辅助智能体对于未来行为的决策推理。在已有的研究中,往往假设可以提前获得不同环境、不同预测任务的全部训练数据,然后进行模型训练。
然而,在实际场景中,如图1所示,模型所面临的环境或任务可能是动态变化的,即待学习的预测任务可能以序列化的非平稳的形式出现,比如机械臂需要首先完成推动的动作,再分别学习抓取和堆叠的动作。模型需要序列化地学习一连串不同的任务,而在学习当前任务时,我们无法获得或只能少量获得之前任务的训练数据。在这种持续学习(Continual Learning)的设定下,多数现有的预测学习方法会遭遇严重的灾难性遗忘(Catastrophic Forgetting)问题,即模型在学习任务序列的过程中,会逐渐遗忘掉之前已学习任务的知识,造成在之前任务上测试性能的降低,并且研究人员发现直接将已有作用在图像领域的持续学习方法应用到时空预测上并不能取得很好的效果。
图1 可持续时空预测问题定义及所提出架构在测试时的运行流程
针对以上问题,研究团队开创性地提出了一种可持续时空预测学习框架CPL(Continual Predictive Learning),整体结构如图2所示。在网络结构设计上,针对性地设计了混合世界模型(Mixture World Model),通过引入类别标签分离不同任务对应的时空动态信息。在遗忘数据增广上,提出了基于预测的经验回放(Predictive Experience Replay)策略,通过结合单帧图像生成和世界模型的复用,在内存受限的条件下实现了已有任务数据的生成,打破了数据限制。最后在模型测试流程中,引入了自适应的无参数任务推断机制(Non-Parametric Task Inference),进一步缓解预测阶段的标签遗忘问题。
图2 CPL整体框架