ImageNet错误率小于4%,数据量依然不够,N-Shot Learning或是终极解决之道?

2019 年 8 月 20 日 AI100

作者 | Heet Sankesara
译者 | 陆离, 编辑 | 夕颜
出品 | AI科技大本营(ID: rgznai100)

【导读】“ 如果人工智能是新的电力能源,那么数据就是新的煤炭能源。”由于人工智能(AI)和深度学习的快速发展,到现在为止,影响了无数的生命,改变了大千世界,这些都是我们曾经在科幻小说中梦寐以求的。 幸的是,正如我们已经看到的那样,目前世界上可消耗的煤炭资源濒临枯竭,许多 AI 应用系统几乎没有,甚至根本没有可以访问到它们的数据。

新技术弥补了物理资源的不足,同样,也需要新技术来满足在获得很少数据的情况下应用系统依然能正常地运行。那么,N-shot Learning 就成为了这个异常热门领域的核心话题。

N-Shot Learning

 

你可能会问,到底什么是“shot”?问得好。“ shot ”只不过是一个可供训练的实例,所以在 N-Shot Learning 中,我们有 N 个可以供训练的实例。对于“Few-Shot Learning”,“few”通常介于 0 到 5 之间,这就意味着利用零个实例进行模型训练的方式被称为“Zero-Shot Learning”,而只用一个实例的就被称为“One-Shot Learning”,以此类推。这些变量都试图用不同级别的训练目标来解决相同的问题。

为什么是N-Shot?

当我们已经在 ImageNet 中得到的错误率小于 4% 时,为什么还需要 N-Shot Learning 呢?
 
首先,ImageNet 的数据集包含了大量的机器学习的例子,包括了医学成像、药物研发和许多其它可能对人工智能至关重要的领域,而且并不仅涉及到这些。典型的深度学习体系结构总是依赖于大量的数据来获得足够的结果——例如,ImageNet 需要对数百个热狗图像进行训练,之后才能准确地确认新图像里有没有热狗。另外,有一些数据集,就像 7 月 4 日(美国吃热狗比赛日)庆祝活动之后的冰箱那样,里面的热狗非常少。
 
在数据缺乏的时候,机器学习有许多的用例,这就是引入 N-Shot Learning 的由来。我们需要训练一个深度学习模型,这个模型有数百万甚至数十亿个参数,所有的参数都是随机初始化的,然后利用不超过 5 个图像来学习如何对一个不可见的图像进行分类。简而言之,我们的模型必须能够在使用数量极其有限的热狗图像的条件下进行训练。
 
要处理像这样复杂的问题,我们首先需要定义明确。在 N-Shot Learning 领域里,每“K”个类中我们要标记“n”个实例,即总共有 N∗K 个实例,我们称之为支持数据集 S。我们还必须对查询数据集 Query Set Q 进行分类,每个实例都位于 K 个类中的一个。N-shot Learning 有三个主要的子领域:Zero-Shot Learning、One-Shot Learning, 以及Few-Shot Learning,每个子领域都值得我们研究。

Zero-Shot Learning

对我来说,这是一个最有趣的子领域。Zero-Shot Learning 的目的是,在没有一个训练实例的情况下对看不见的类进行分类。一个机器如何在没有获得任何数据的情况下进行“学习”呢?以这种方式思考的话,你能在不可见的情况下对一个对象进行分类吗?

        夜空中的仙后座

是的,如果你在它的外观、特征和作用几个方面有足够的数据的话,就可以这么做。回想一下你小时候是如何了解这个世界的,你可以在了解了火星的颜色和大概的位置之后,在夜空中找到它;或者可以仅凭别人告诉你仙后座基本上是一个难看的“W”型来发现它。
 
根据今年 NLP(自然语言处理)的发展趋势, Zero-Shot Learning 将变得更加的高效。机器会利用图像的元数据来执行相同的工作,元数据就是与图像关联的一些特征。以下是有关这方面的几篇论文,获得了不错的反响:

  • 学习比较:Few-Shot Learning 的关系网络

(https://arxiv.org/pdf/1711.06025v2.pdf)
  • 粒度可视化描述中的学习深度的表示方式

(https://arxiv.org/pdf/1605.05395v1.pdf)

  • 通过减少 Hubness 问题来提高 Zero-Shot Learning 的效率

(https://arxiv.org/abs/1412.6568v3)

One-Shot Learning

在 One-Shot Learning 中,我们每个类只有一个实例。现在的任务是利用这个限制将任何测试图像分配到一个类里面。为了达到这一目的,我们开发了许多不同的体系结构,例如 Siamese 神经网络(https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf),它取得了重大的进步并得到了显著的成果,然后就是匹配网络(,这也帮助我们在这一领域获得了巨大的飞跃。
 
现在有许多关于 One-Shot Learning 方面的优秀论文可以参考,如下所示:

  •  对于深度网络快速适应(Fast Adaptation of Deep Networks)的模型未知元学习(Model-Agnostic Meta-Learning)
    (https://arxiv.org/pdf/1703.03400v3.pdf)

  • 基于记忆增强神经网络学习(Memory-Augmented Neural Networks)的One-Shot Learning
    (https://arxiv.org/pdf/1605.06065v1.pdf)


  • 对于Few-Shot Learning的原型网络(Prototypical Networks)(https://arxiv.org/pdf/1703.05175v2.pdf)


Few-Shot Learning

Few-Shot Learning 只是 One-Shot Learning 的一个灵活版本,我们有多个训练实例(通常是2到5个图像,虽然上述大多数模型也可以用于 Few-Shot Learning)。
 
在 2019 年的计算机视觉与模式识别的大会上,有人针对 Few-Shot Learning 提出了元迁移学习(https://arxiv.org/pdf/1812.02391v3.pdf)。该模型为未来的研究奠定了基础,它获得了最高水平的成果,并为更复杂的元迁移学习方法铺平了道路。
 
通过许多元学习和强化学习算法与典型的深度学习算法相结合,获得了显著的成果。原型网络(Prototypical networks )是最流行的深度学习算法之一,就经常被用于此。
 
在本文中,我们将使用 原型网络(https://arxiv.org/pdf/1703.05175v2.pdf) 来完成这项任务,并了解它的工作过程以及原理。

原型网络的思想

             
上图是原型网络的功能图。编码器将图像映射到嵌入空间(黑圈)中的向量中。辅助图像用于定义原型(星型)。原型和被编码的查询图像之间的距离用于对它们进行分类
 

与典型的深度学习体系结构不同,原型网络不直接对图像进行分类,而是在度量空间(https://en.wikipedia.org/wiki/Metric_space)中学习图像的映射。

 
对于那些需要复习一下数学知识的人来说,度量空间是处理“距离”方面的概念。它没有一个可以分辨的“起点”;相反,在度量空间中,我们只是计算一个点到另一个点的距离。因此,你缺少在向量空间中进行的加法和标量乘法的运算(因为与向量不同,一个点只能代表一个坐标,而添加两个坐标或是改变一个坐标的大小是没有意义的)。请点击此链接 (https://math.stackexchange.com/questions/114940/what-is-the-difference-between-metric-spaces-and-vector-spaces) 来更多的了解向量空间和度量空间之间的差异。
 
 
幽默时段
 
既然我们有了这样的条件,那么就可以开始了解原型网络是如何学习度量空间中图像的映射,而不是直接对图像进行分类的了。如上图所示,编码器在很近的距离内映射同一类的图像,而不同类之间的距离则相当大。这就意味着,无论什么时候给出一个新的实例,网络只检查距离最近的集群,并将该实例分配到其相应的类里。将图像映射到度量空间的原型网络中的底层模型,可以称为“Image2Vector”模型,这是一种基于卷积神经网络(Convolutional Neural Network,CNN)的体系结构。
 
现在,对于那些不太了解 CNN 的人来说,可以点击下面的链接获得更多相关的资料:

  • 点击此链接查看最佳深度学习课程列表
    (https://blog.floydhub.com/best-deep-learning-courses-updated-for-2019/)


  • 点击这里查看最佳深度学习书籍的列表
    (https://blog.floydhub.com/best-deep-learning-books-updated-for-2019/)


  • 想要快速的学习和应用,请参考创建你的第一个ConvNet
    (https://blog.floydhub.com/building-your-first-convnet/)

原型网络简介

简单地说,原型网络的目标就是训练一个分类器。然后,这个分类器可以对训练期间不可用的新类进行标准化,并且仅仅需要每个新类的少量实例。因此,训练数据集中包含一组类的图像,而测试数据集中包含另一组类的图像,这些类与前一组类完全不相关。在该模型中,实例会被随机分为支持数据集和查询数据集。

原型网络概述

通过 Few-shot 原型 Ck 计算出用以作为每个类的嵌入式支持实例的平均值。编码器映射新图像(X),并将其分类到最接近的类,如上图中的C2,图源:arXiv
 
在 Few-Shot Learning 的环境中,训练迭代就是一个片段。一个片段只不过是我们用来训练一次网络、计算损失和反向传播错误的一个步骤。在每个片段中,我们从训练数据集中随机地选择Nc 类。对于每个类,我们随机地抽取 Ns 类图像。这些图像属于支持数据集,学习模型称为 Ns-Shot 模型。而另一个随机采样的 Nq 类图像属于查询集。这里 Nc、Ns 和 Nq 只是模型中的超参数,其中 Nc 是每次迭代的类的数量,Ns 是每个类的支持实例的数量,Nq 是每个类的查询实例的数量。
 
然后,我们通过“Image2Vector”模型从支持数据集的图像中检索 D 维度的节点。该模型对图像进行编码,使其在度量空间中具有相应的节点。对于每个类,现在有多个节点,但是需要将它们表示为对于每个类的一个节点。因此,我们给每个类计算几何中心,即节点的平均值。之后,我们还需要对查询图像进行分类。
 
为此,我们首先需要将查询数据集中的每个图像编码为一个节点。之后,计算中心(centroid)到每个查询点的距离。最后,预测每个查询图像都位于最接近它的类之中。一般来说,这就是模型的工作原理。
 
现在的问题是,这个“Image2Vector”模型的体系结构是什么样的呢?

Image2Vector 作用

              本文使用的 Image2vector CNN 体系结构
 
出于实践的目的,使用 4-5 个 CNN 块。 如上图所示,每个块由一个 CNN 层组成,随后是批量规范化,然后是一个 ReLu 激活函数,该函数将引入最大池化层。 在所有的块之后,列出剩余的输出并作为结果返回。 这里(https://arxiv.org/pdf/1703.05175v2.pdf)是本文中所使用的体系结构,你可以使用自己喜欢的体系结构。 有必要知道我们为何称之为“Image2Vector”模型,但它实际上是将图像转换为度量空间中的64维的节点。 要进一步了解其中的差异,请查看数学堆栈(https://math.stackexchange.com/questions/645672/what-is-the-difference-between-a-point-and-a-vector)交换中的结论吧。

损失函数

 
Negative log-likelihood 函数的工作(来源)
 
现在我们知道了模型是如何工作的,你可能想知道我们是如何计算损失函数的。 这需要一个足够强大的损失函数,使我们的模型能够快速有效地学习表示。 原型网络使用的是 log-softmax 损失函数,它只不过是在softMax损失函数上再做多一次 log 运算。 当无法预测正确的类时,log-softmax 会对模型产生重大的不利影响,这正是我们所需要的。 要了解更多的有关损失函数的内容,请点击这里(https://ljvmiranda921.github.io/notebook/2017/08/13/softmax-and-the-negative-log-likelihood/)。 另外,这里(https://discuss.pytorch.org/t/logsoftmax-vs-softmax/21386)有一个关于 softMax 和 log-softmax 不错的概述。

数据集概述

 
                Omniglot 数据集中的几个图像类,图源:GitHub
 
网络是在 Omniglot 数据集(https://github.com/brendenlake/omniglot)上训练的。Omniglot 数据集是为开发更多的类似人类的学习算法而设计的。它包含 1623 个不同的手写字符,来自50个不同的字母。然后,为了增加类的数量,将所有图像分别旋转 90、180 和 270 度,每次旋转都会产生一个额外的类。因此,类的总数达到了 6492(1623*4)个。我们将 4200 个类的图像分割成训练数据集,而其余的则分割成测试数据集。对于每一个片段,我们基于 5 个实例训练的模型,而这5个实例是来自随机选择的 64 个类中的一个。我们对模型进行了 1 小时的训练,获得了大约 88% 的准确度。而官方则声称,经过了几个小时的训练和一些参数的调整后,准确度达到了 99.7%。
 
代码:
https://github.com/Hsankesara/Prototypical-Networks
 
让我们深入代码:
               
上面的代码是 Image2Vector CNN 体系结构的一个实现。它获取维度为 28x28x3 的图像,并返回一个长度是 64 的向量。
               
上面的代码是原型网络中一个片段的实现。如果你有任何疑问,请点击此链接(https://github.com/Hsankesara/DeepResearch/)进行评论或者提问。
 
网络概述,图源:YouTube

代码的结构与解释算法的格式相同。我们为原型网络函数提供了以下输入:输入图像数据、输入标签、每次迭代的类数量(即   )、每个类的支持实例数量(即   )和每个类的查询实例数量(即     )。函数返回了  ,它是从每个查询节点到每个平均节点的距离矩阵,    是包含与     对应的标签的向量。   存放了    的图像实际所属的类。

在上图中,我们可以看到 3 个类被使用了,即   =3,而且对于每个类,总共有 5 个实例用于训练,即      =5。上面的 S 表示包含了那 15(     )个图像的支持数据集,而 X 则表示查询数据集。请注意,支持数据集和查询数据集都通过了 f,而f只是我们的“Image2Vector”函数。它映射了度量空间中所有的图像。让我们一步一步地把整个过程分解开看看。
 
首先,我们从输入数据中随机地选择 Nc 的类。对于每个类,我们使用random_sample_cls函数从图像中随机选择一个支持数据集和一个查询数据集。在上图中,S 是支持数据集,而 X 是查询数据集。现在,我们选择几个类(C1、C2和C3),通过“Image2Vector”模型,并使用get_centroid 函数来计算每个类的中心。在以 C1 和 C2 为中心的附近的图像中也可以观察到同样的情况,这是用邻近点计算出来的。每个中心代表了一个类,这将用于分类查询。
 
网络中的中心计算,图源:YouTube
 
在为每个类计算完了中心之后,我们现在必须预测其中一个类的查询图像。为此,我们需要那些与每个查询相对应的实际标签,这些标签是通过调用get_query_y函数来获取的。   是分过类的数据,并且函数将这个分过类的文本数据转换为一个单热(one-hot)向量,该向量在列节点对应的图像实际所属的行标签中仅为“1”,而在列中则为“0”。
 
之后,需要对应于每个    图像的节点来对其进行分类。我们使用“Image2Vector”模型来获取这些节点,现在需要对它们进行分类。为此,我们计算   中每个点到每个类中心的距离。这就给了我们一个矩阵,其中索引ij表示与第 i 个查询图像相对应的节点与第 j 个类的中心的距离。我们使用get_query_x函数来构造矩阵,并将矩阵保存在   变量之中。这对于附近的图像也是一样的。对于查询数据集中的每个实例,则正在计算它与 C1、C2 和 C3 之间的距离。在这种情况下,x 最接近 C2,因此我们可以得出预测,x 属于 C2 这个类。
 
通过编程的方式,我们可以通过使用一个简单的 argmin 函数来完成同样的工作,即找出预测图像所在的类。然后利用预测出的类和实际的类来计算损失,并对误差进行反概率分析。

如果你想使用训练过的模型,或者只是不得不重新进行训练,那么点击此链接(https://github.com/Hsankesara/DeepResearch/tree/master/Prototypical_Nets),查看我的代码。你可以将它看作API,并使用几行代码来训练模型。你可以点击此链接(https://www.kaggle.com/hsankesara/prototypical-net/)找到相关的网络。

相关资源

以下列出了一些有助于你彻底了解相关主题的资源:

  • 使用 Keras 的关于 Siamese 网络的 One-Shot Learning

    (https://sorenbouma.github.io/blog/oneshot/)


  • One-Shot Learning:使用 Siamese 神经网络的人脸识别

(https://towardsdatascience.com/one-shot-learning-face-recognition-using-siamese-neural-network-a13dcf739e)


  • 匹配网络的官方实现

(https://github.com/AntreasAntoniou/MatchingNetworks)


  • 原型网络的官方实现

(https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch)


  • 对于半监督的 Few-Shot 分类的元学习

(https://arxiv.org/abs/1803.00676)


局限性

尽管原型网络产生了非常好的成果,但仍然有一些局限性。第一个问题是缺乏泛化。原型网络在Omniglot 数据集上表现的很好,因为数据集中的所有图像都是一个字符的图像,因此具有一些相似的特征。然而,如果我们尝试利用这个模型来对不同种类的猫进行分类,它就不会给我们准确的结果了。猫和字符图像之间具有较少的共性,可用于将图像映射到相应度量空间的常见特征数量几乎是可以忽略不计的。
 
原型网络的另一个局限性是它们只使用平均值来确定中心,而忽略了支持数据集的方差。这就阻碍了当图像有噪音的时候模型的分类能力。这一限制是通过使用高斯原型网络(https://arxiv.org/abs/1708.02735)来克服的,它利用了类中的方差,通过使用高斯公式对嵌入点进行建模。

结论

Few-Shot Learning 成为一个热门的研究课题已经有一段时间了。有许多新奇的使用原型网络的方法,如元学习网络(https://arxiv.org/abs/1803.00676),并且已经显示出了显著成果。研究人员也在通过强化学习来进行探索,这也有着相当大的潜力。Few-Shot Learning 这个模型最好的一点是它简单易懂,并且能产生令人难以置信的结果。

原文链接:

https://blog.floydhub.com/n-shot-learning/


(*本文为 AI科技大本营翻译文章,转载请联系微信 1092722531)

福利时刻



入群参与每周抽奖~


扫码添加小助手,回复:大会,加入福利群,参与抽奖送礼!


AI ProCon 2019 邀请到了亚马逊首席科学家@李沐,在大会的前一天(9.5)亲授「深度学习实训营」 ,通过动手实操,帮助开发者全面了解深度学习的基础知识和开发技巧。还有  9大技术论坛、60+主题分享,百余家企业、千余名开发者共同相约 2019 AI ProCon! 5折优惠票抢购中!      
 


推荐阅读


你点的每个“在看”,我都认真当成了喜欢
登录查看更多
0

相关内容

零样本文本分类,Zero-Shot Learning for Text Classification
专知会员服务
95+阅读 · 2020年5月31日
【SIGIR2020】学习词项区分性,Learning Term Discrimination
专知会员服务
15+阅读 · 2020年4月28日
【Uber AI新论文】持续元学习,Learning to Continually Learn
专知会员服务
36+阅读 · 2020年2月27日
深度学习的冬天什么时候到来?
中国计算机学会
14+阅读 · 2019年7月17日
深度学习训练数据不平衡问题,怎么解决?
AI研习社
7+阅读 · 2018年7月3日
机器学习不能做什么?
引力空间站
5+阅读 · 2018年3月28日
为什么深度学习不能取代传统的计算机视觉技术?
人工智能头条
3+阅读 · 2018年3月14日
Anomalous Instance Detection in Deep Learning: A Survey
A Survey on Deep Learning for Named Entity Recognition
Arxiv
26+阅读 · 2020年3月13日
Meta-Transfer Learning for Few-Shot Learning
Arxiv
8+阅读 · 2018年12月6日
Multi-task Deep Reinforcement Learning with PopArt
Arxiv
4+阅读 · 2018年9月12日
Arxiv
22+阅读 · 2018年8月30日
Few Shot Learning with Simplex
Arxiv
5+阅读 · 2018年7月27日
Arxiv
8+阅读 · 2018年5月15日
Arxiv
15+阅读 · 2018年2月4日
VIP会员
相关资讯
相关论文
Anomalous Instance Detection in Deep Learning: A Survey
A Survey on Deep Learning for Named Entity Recognition
Arxiv
26+阅读 · 2020年3月13日
Meta-Transfer Learning for Few-Shot Learning
Arxiv
8+阅读 · 2018年12月6日
Multi-task Deep Reinforcement Learning with PopArt
Arxiv
4+阅读 · 2018年9月12日
Arxiv
22+阅读 · 2018年8月30日
Few Shot Learning with Simplex
Arxiv
5+阅读 · 2018年7月27日
Arxiv
8+阅读 · 2018年5月15日
Arxiv
15+阅读 · 2018年2月4日
Top
微信扫码咨询专知VIP会员