选自Github
项目作者:learnables
机器之心编译
元学习似乎一直比较「高级」,毕竟学习如何学习这个概念听起来就很难实现。在本文中,我们介绍了这两天新开源的元学习库 learn2learn,它是用 PyTorch 写的,只需要三四行代码就能构建元学习最为核心的部分。
项目地址:https://github.com/learnables/learn2learn
模块化 API:使用这个库中的底层工具实现你自己的训练循环;
提供多个元学习算法(如 MAML、FOMAML、MetaSGD、ProtoNets、DiCE);
具有统一 API 的任务生成器,兼容 torchvision、torchtext、torchaudio 和 cherry;
提供标准化的视觉(Omniglot、mini-ImageNet)、强化学习(Particles、Mujoco)甚至文本(新闻分类)元学习任务;
100% 兼容 PyTorch——使用你自己的模块、数据集或库。
import learn2learn as l2l
mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True)
mnist = l2l.data.MetaDataset(mnist)
task_generator = l2l.data.TaskGenerator(mnist,
ways=3,
classes=[0, 1, 4, 6, 8, 9],
tasks=10)
model = Net()
maml = l2l.algorithms.MAML(model, lr=1e-3, first_order=False)
opt = optim.Adam(maml.parameters(), lr=4e-3)
for iteration in range(num_iterations):
learner = maml.clone() # Creates a clone of model
adaptation_task = task_generator.sample(shots=1)
# Fast adapt
for step in range(adaptation_steps):
error = compute_loss(adaptation_task)
learner.adapt(error)
# Compute evaluation loss
evaluation_task = task_generator.sample(shots=1,
task=adaptation_task.sampled_task)
evaluation_error = compute_loss(evaluation_task)
# Meta-update the model parameters
opt.zero_grad()
evaluation_error.backward()
opt.step()
文档地址:http://learn2learn.net/docs/learn2learn/
教程地址:http://learn2learn.net/tutorials/getting_started/