今天凌晨,OpenAI发表文章,宣布他们开发了一种简单的元学习算法,名为Reptile。该算法通过对任务进行重复采样,对其执行随机梯度下降,并且将初始参数朝学习的最终参数更新。这种方法的表现和常用的元学习算法MAML一样好,甚至比它安装起来更简单、计算也更高效。
元学习是学习算法如何学习的过程。将任务的分布输入到元学习算法中,其中每个任务都是一个学习问题,然后它会生成一个快速学习者——该学习者可以从少数样本中生成结果。一个好的元学习问题是few-shot分类问题,每个任务都是分类型的问题,其中学习者在每个类别中只能看到1—5个输入输出样本,然后再对新的输入进行分类。
下面就是该算法的一个Demo,在Training Data中画三个训练图像,然后在Input区域画一个输入图像,算法会自动判断这个输入图像属于训练图像的哪一类,并用概率表示。
以下是论智君做的实验:
输入笑脸,与其最相近的是第一幅图
输入哭脸,最相似的是第二幅图
加上脸部轮廓,最相似的绝对是第三幅图啦
与MAML类似,Reptile寻求神经网络参数的初始化,以便神经网络能够利用少量数据进行微调。但是当MAML通过梯度下降算法的计算图展开并求导时,Reptile只是以标准方式对每个任务执行随机梯度下降——它不展开计算图,也不计算任何二阶导数。这使得Reptile比MAML占用更少的计算和内存。伪代码如下:
为了替代最后一步,我们可以将Φ-W看成是一个梯度,并且把它输入进一个更复杂的优化器中,比如Adam。
令人惊讶的是,这种方法完全可以行得通。如果k=1,该算法将进行“联合训练”——将所有任务混合,对其进行随机梯度下降。尽管联合训练在一些情况下能学到有用的初始化,当零次学习(zero-shot)不可能时(例如当输出标签随机置换时),它学习得就很少。Reptile需要k>1,其中更新取决于损失函数的高阶导数;正如我们在论文中展示的,这与k=1时的表现完全不同。
为了分析Reptile为什么能工作,研究人员用一个泰勒级数进行大概的更新。他们证明了Reptile的更新使来自同一任务中的不同微粒梯度之间的内积最大化,相当于改进的泛化。这一发现可能会影响元学习环境对解释随机梯度下降的泛化特性。我们的分析表明,Reptile和MAML的更新非常相似,包括两个权重不同的相同项。
在他们的实验中,OpenAI展示了Reptile和MAML在Omniglot和Mini-ImageNet标准测试中对few-shot分类表现出同样的水平。由于Reptile的方差较低,它收敛到解决方案的速度也更快。
我们对Reptile的分析表明我们可以将随机梯度下降以不同方式结合,获得很多不同的算法。在下方的表格中,假设我们在每个任务中执行了k步随机梯度下降,生成了梯度g1,g2,…,gk。下面的图表显示出Omniglot上的学习曲线,它是将每个和看作元梯度获得的。g2对应的是MAML原始论文中提到的一阶MAML。由于方差缩减,梯度包含的越多就会生成越快的学习速率。需要注意的是,只使用g1(即对应k=1)在这一任务中是无法产生进步的,因为零次学习的性能无法改进。
Reptile的实现步骤可以在GitHub上找到。它的内部计算使用TensorFlow,代码可在Omniglot和Mini-ImageNet上复现。我们同时还发布了一个较小的JavaScript版本的实现,它是通过微调一个在TensorFlow上预训练过的模型得到的。上面的Demo即由这个小版本实现。
最后,这是一个few-shot回归的简单例子,从十对(x,y)中随机正弦波。该示例基于PyTorch:
import numpy as np
import torch
from torch import nn, autograd as ag
import matplotlib.pyplot as plt
from copy import deepcopy
seed = 0
plot = True
innerstepsize = 0.02 # stepsize in inner SGD
innerepochs = 1 # number of epochs of each inner SGD
outerstepsize0 = 0.1 # stepsize of outer optimization, i.e., meta-optimization
niterations = 30000 # number of outer updates; each iteration we sample one task and update on it
rng = np.random.RandomState(seed)
torch.manual_seed(seed)
# Define task distribution
x_all = np.linspace(-5, 5, 50)[:,None] # All of the x points
ntrain = 10 # Size of training minibatches
def gen_task():
"Generate classification problem"
phase = rng.uniform(low=0, high=2*np.pi)
ampl = rng.uniform(0.1, 5)
f_randomsine = lambda x : np.sin(x + phase) * ampl
return f_randomsine
# Define model. Reptile paper uses ReLU, but Tanh gives slightly better results
model = nn.Sequential(
nn.Linear(1, 64),
nn.Tanh(),
nn.Linear(64, 64),
nn.Tanh(),
nn.Linear(64, 1),
)
def totorch(x):
return ag.Variable(torch.Tensor(x))
def train_on_batch(x, y):
x = totorch(x)
y = totorch(y)
model.zero_grad()
ypred = model(x)
loss = (ypred - y).pow(2).mean()
loss.backward()
for param in model.parameters():
param.data -= innerstepsize * param.grad.data
def predict(x):
x = totorch(x)
return model(x).data.numpy()
# Choose a fixed task and minibatch for visualization
f_plot = gen_task()
xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)]
# Reptile training loop
for iteration in range(niterations):
weights_before = deepcopy(model.state_dict())
# Generate task
f = gen_task()
y_all = f(x_all)
# Do SGD on this task
inds = rng.permutation(len(x_all))
for _ in range(innerepochs):
for start in range(0, len(x_all), ntrain):
mbinds = inds[start:start+ntrain]
train_on_batch(x_all[mbinds], y_all[mbinds])
# Interpolate between current weights and trained weights from this task
# I.e. (weights_before - weights_after) is the meta-gradient
weights_after = model.state_dict()
outerstepsize = outerstepsize0 * (1 - iteration / niterations) # linear schedule
model.load_state_dict({name :
weights_before[name] + (weights_after[name] - weights_before[name]) * outerstepsize
for name in weights_before})
# Periodically plot the results on a particular task and minibatch
if plot and iteration==0 or (iteration+1) % 1000 == 0:
plt.cla()
f = f_plot
weights_before = deepcopy(model.state_dict()) # save snapshot before evaluation
plt.plot(x_all, predict(x_all), label="pred after 0", color=(0,0,1))
for inneriter in range(32):
train_on_batch(xtrain_plot, f(xtrain_plot))
if (inneriter+1) % 8 == 0:
frac = (inneriter+1) / 32
plt.plot(x_all, predict(x_all), label="pred after %i"%(inneriter+1), color=(frac, 0, 1-frac))
plt.plot(x_all, f(x_all), label="true", color=(0,1,0))
lossval = np.square(predict(x_all) - f(x_all)).mean()
plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")
plt.ylim(-4,4)
plt.legend(loc="lower right")
plt.pause(0.01)
model.load_state_dict(weights_before) # restore from snapshot
print(f"-----------------------------")
print(f"iteration {iteration+1}")
print(f"loss on plotted curve {lossval:.3f}") # would be better to average loss over a set of examples, but this is optimized for brevity
原文地址:blog.openai.com/reptile/
GitHub地址:github.com/openai/supervised-reptile
论文地址:d4mucfpksywv.cloudfront.net/research-covers/reptile/reptile_update.pdf