OpenAI发布可扩展的元学习算法Reptile

2018 年 3 月 8 日 论智 OpenAI
来源:OpenAI
编译:Bing

今天凌晨,OpenAI发表文章,宣布他们开发了一种简单的元学习算法,名为Reptile。该算法通过对任务进行重复采样,对其执行随机梯度下降,并且将初始参数朝学习的最终参数更新。这种方法的表现和常用的元学习算法MAML一样好,甚至比它安装起来更简单、计算也更高效。

元学习是学习算法如何学习的过程。将任务的分布输入到元学习算法中,其中每个任务都是一个学习问题,然后它会生成一个快速学习者——该学习者可以从少数样本中生成结果。一个好的元学习问题是few-shot分类问题,每个任务都是分类型的问题,其中学习者在每个类别中只能看到1—5个输入输出样本,然后再对新的输入进行分类。

下面就是该算法的一个Demo,在Training Data中画三个训练图像,然后在Input区域画一个输入图像,算法会自动判断这个输入图像属于训练图像的哪一类,并用概率表示。

以下是论智君做的实验:

输入笑脸,与其最相近的是第一幅图

输入哭脸,最相似的是第二幅图

加上脸部轮廓,最相似的绝对是第三幅图啦

Reptile是如何工作的

与MAML类似,Reptile寻求神经网络参数的初始化,以便神经网络能够利用少量数据进行微调。但是当MAML通过梯度下降算法的计算图展开并求导时,Reptile只是以标准方式对每个任务执行随机梯度下降——它不展开计算图,也不计算任何二阶导数。这使得Reptile比MAML占用更少的计算和内存。伪代码如下:

为了替代最后一步,我们可以将Φ-W看成是一个梯度,并且把它输入进一个更复杂的优化器中,比如Adam。

令人惊讶的是,这种方法完全可以行得通。如果k=1,该算法将进行“联合训练”——将所有任务混合,对其进行随机梯度下降。尽管联合训练在一些情况下能学到有用的初始化,当零次学习(zero-shot)不可能时(例如当输出标签随机置换时),它学习得就很少。Reptile需要k>1,其中更新取决于损失函数的高阶导数;正如我们在论文中展示的,这与k=1时的表现完全不同。

为了分析Reptile为什么能工作,研究人员用一个泰勒级数进行大概的更新。他们证明了Reptile的更新使来自同一任务中的不同微粒梯度之间的内积最大化,相当于改进的泛化。这一发现可能会影响元学习环境对解释随机梯度下降的泛化特性。我们的分析表明,Reptile和MAML的更新非常相似,包括两个权重不同的相同项。

在他们的实验中,OpenAI展示了Reptile和MAML在Omniglot和Mini-ImageNet标准测试中对few-shot分类表现出同样的水平。由于Reptile的方差较低,它收敛到解决方案的速度也更快。

我们对Reptile的分析表明我们可以将随机梯度下降以不同方式结合,获得很多不同的算法。在下方的表格中,假设我们在每个任务中执行了k步随机梯度下降,生成了梯度g1,g2,…,gk。下面的图表显示出Omniglot上的学习曲线,它是将每个和看作元梯度获得的。g2对应的是MAML原始论文中提到的一阶MAML。由于方差缩减,梯度包含的越多就会生成越快的学习速率。需要注意的是,只使用g1(即对应k=1)在这一任务中是无法产生进步的,因为零次学习的性能无法改进。

实现

Reptile的实现步骤可以在GitHub上找到。它的内部计算使用TensorFlow,代码可在Omniglot和Mini-ImageNet上复现。我们同时还发布了一个较小的JavaScript版本的实现,它是通过微调一个在TensorFlow上预训练过的模型得到的。上面的Demo即由这个小版本实现。

最后,这是一个few-shot回归的简单例子,从十对(x,y)中随机正弦波。该示例基于PyTorch:

  
    
    
    
  1. import numpy as np

  2. import torch

  3. from torch import nn, autograd as ag

  4. import matplotlib.pyplot as plt

  5. from copy import deepcopy

  6. seed = 0

  7. plot = True

  8. innerstepsize = 0.02 # stepsize in inner SGD

  9. innerepochs = 1 # number of epochs of each inner SGD

  10. outerstepsize0 = 0.1 # stepsize of outer optimization, i.e., meta-optimization

  11. niterations = 30000 # number of outer updates; each iteration we sample one task and update on it

  12. rng = np.random.RandomState(seed)

  13. torch.manual_seed(seed)

  14. # Define task distribution

  15. x_all = np.linspace(-5, 5, 50)[:,None] # All of the x points

  16. ntrain = 10 # Size of training minibatches

  17. def gen_task():

  18.    "Generate classification problem"

  19.    phase = rng.uniform(low=0, high=2*np.pi)

  20.    ampl = rng.uniform(0.1, 5)

  21.    f_randomsine = lambda x : np.sin(x + phase) * ampl

  22.    return f_randomsine

  23. # Define model. Reptile paper uses ReLU, but Tanh gives slightly better results

  24. model = nn.Sequential(

  25.    nn.Linear(1, 64),

  26.    nn.Tanh(),

  27.    nn.Linear(64, 64),

  28.    nn.Tanh(),

  29.    nn.Linear(64, 1),

  30. )

  31. def totorch(x):

  32.    return ag.Variable(torch.Tensor(x))

  33. def train_on_batch(x, y):

  34.    x = totorch(x)

  35.    y = totorch(y)

  36.    model.zero_grad()

  37.    ypred = model(x)

  38.    loss = (ypred - y).pow(2).mean()

  39.    loss.backward()

  40.    for param in model.parameters():

  41.        param.data -= innerstepsize * param.grad.data

  42. def predict(x):

  43.    x = totorch(x)

  44.    return model(x).data.numpy()

  45. # Choose a fixed task and minibatch for visualization

  46. f_plot = gen_task()

  47. xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)]

  48. # Reptile training loop

  49. for iteration in range(niterations):

  50.    weights_before = deepcopy(model.state_dict())

  51.    # Generate task

  52.    f = gen_task()

  53.    y_all = f(x_all)

  54.    # Do SGD on this task

  55.    inds = rng.permutation(len(x_all))

  56.    for _ in range(innerepochs):

  57.        for start in range(0, len(x_all), ntrain):

  58.            mbinds = inds[start:start+ntrain]

  59.            train_on_batch(x_all[mbinds], y_all[mbinds])

  60.    # Interpolate between current weights and trained weights from this task

  61.    # I.e. (weights_before - weights_after) is the meta-gradient

  62.    weights_after = model.state_dict()

  63.    outerstepsize = outerstepsize0 * (1 - iteration / niterations) # linear schedule

  64.    model.load_state_dict({name :

  65.        weights_before[name] + (weights_after[name] - weights_before[name]) * outerstepsize

  66.        for name in weights_before})

  67.    # Periodically plot the results on a particular task and minibatch

  68.    if plot and iteration==0 or (iteration+1) % 1000 == 0:

  69.        plt.cla()

  70.        f = f_plot

  71.        weights_before = deepcopy(model.state_dict()) # save snapshot before evaluation

  72.        plt.plot(x_all, predict(x_all), label="pred after 0", color=(0,0,1))

  73.        for inneriter in range(32):

  74.            train_on_batch(xtrain_plot, f(xtrain_plot))

  75.            if (inneriter+1) % 8 == 0:

  76.                frac = (inneriter+1) / 32

  77.                plt.plot(x_all, predict(x_all), label="pred after %i"%(inneriter+1), color=(frac, 0, 1-frac))

  78.        plt.plot(x_all, f(x_all), label="true", color=(0,1,0))

  79.        lossval = np.square(predict(x_all) - f(x_all)).mean()

  80.        plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")

  81.        plt.ylim(-4,4)

  82.        plt.legend(loc="lower right")

  83.        plt.pause(0.01)

  84.        model.load_state_dict(weights_before) # restore from snapshot

  85.        print(f"-----------------------------")

  86.        print(f"iteration               {iteration+1}")

  87.        print(f"loss on plotted curve   {lossval:.3f}") # would be better to average loss over a set of examples, but this is optimized for brevity

原文地址:blog.openai.com/reptile/

GitHub地址:github.com/openai/supervised-reptile

论文地址:d4mucfpksywv.cloudfront.net/research-covers/reptile/reptile_update.pdf

登录查看更多
1

相关内容

Reptile是元学习(Meta learning)最经典的几个算法之一,出自论文《Reptile: a Scalable Metalearning Algorithm》。除了对算法本身的贡献,论文还给出了Reptile和MAML算法的数学解释与分析。 原文地址:https://d4mucfpksywv.cloudfront.net/research-covers/reptile/reptile_update.pdf
【浙江大学】使用MAML元学习的少样本图分类
专知会员服务
62+阅读 · 2020年3月22日
专知会员服务
87+阅读 · 2020年1月20日
Meta-Learning 元学习:学会快速学习
专知
24+阅读 · 2018年12月8日
入门 | 从零开始,了解元学习
机器之心
17+阅读 · 2018年5月6日
OpenAI提出Reptile:可扩展的元学习算法
深度学习世界
7+阅读 · 2018年3月9日
OpenAI发布大规模元学习算法Reptile
AI前线
6+阅读 · 2018年3月9日
Arxiv
14+阅读 · 2019年9月11日
Meta-Learning with Implicit Gradients
Arxiv
13+阅读 · 2019年9月10日
Meta-Transfer Learning for Few-Shot Learning
Arxiv
8+阅读 · 2018年12月6日
Arxiv
136+阅读 · 2018年10月8日
Arxiv
5+阅读 · 2018年9月11日
Arxiv
6+阅读 · 2018年4月24日
VIP会员
相关资讯
Meta-Learning 元学习:学会快速学习
专知
24+阅读 · 2018年12月8日
入门 | 从零开始,了解元学习
机器之心
17+阅读 · 2018年5月6日
OpenAI提出Reptile:可扩展的元学习算法
深度学习世界
7+阅读 · 2018年3月9日
OpenAI发布大规模元学习算法Reptile
AI前线
6+阅读 · 2018年3月9日
相关论文
Arxiv
14+阅读 · 2019年9月11日
Meta-Learning with Implicit Gradients
Arxiv
13+阅读 · 2019年9月10日
Meta-Transfer Learning for Few-Shot Learning
Arxiv
8+阅读 · 2018年12月6日
Arxiv
136+阅读 · 2018年10月8日
Arxiv
5+阅读 · 2018年9月11日
Arxiv
6+阅读 · 2018年4月24日
Top
微信扫码咨询专知VIP会员