Report from Synced (机器之心)
Remember Lightning, the lightweight PyTorch framework that looks like Keras? It has finally shipped version 1.0.0, adding many new features and rounding out metrics, optimization, logging, data flow, checkpointing, and more.
Blog post: https://medium.com/pytorch/pytorch-lightning-1-0-from-0-600k-80fc65e2fab0
GitHub: https://github.com/PyTorchLightning/pytorch-lightning
Metrics: the new pl.metrics package ships standard metric implementations such as Accuracy, which track their own state and can be logged directly:

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        ...
        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.train_acc(logits, y)
        # log step metric
        self.log('train_acc_step', self.train_acc)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc(logits, y)
        # logs epoch metrics
        self.log('valid_acc', self.valid_acc)
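Besides the class-based metrics, the metrics package also exposes functional versions for one-off computation. A minimal sketch, assuming pytorch_lightning.metrics.functional.accuracy keeps its 1.0 call signature:

import torch
from pytorch_lightning.metrics.functional import accuracy

preds = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 1, 2, 2])

# plain tensors in, a tensor out -- no state to accumulate or reset
print(accuracy(preds, target))  # tensor(0.7500)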
Custom metrics subclass Metric and implement update() and compute():

import torch
from pytorch_lightning.metrics import Metric

class MyAccuracy(Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # _input_format is a user-defined helper that normalizes shapes
        preds, target = self._input_format(preds, target)
        assert preds.shape == target.shape
        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        return self.correct.float() / self.total
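A quick usage sketch for the class above; it assumes the _input_format helper is actually defined (or that line removed), and the random tensors are purely illustrative:

import torch

my_acc = MyAccuracy()

# accumulate state across several batches
for _ in range(3):
    preds = torch.randint(0, 2, (8,))
    target = torch.randint(0, 2, (8,))
    my_acc.update(preds, target)

# accuracy over everything seen so far
print(my_acc.compute())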
Optimization: by default, Lightning automates the optimization loop, so training_step only needs to return the loss:

def training_step(self, batch, batch_idx):
    loss = self.encoder(batch[0])
    return loss
To take manual control of optimization instead, turn the automation off:

trainer = Trainer(automatic_optimization=False)
Then call backward and step yourself inside training_step, for example with several optimizers:

def training_step(self, batch, batch_idx, opt_idx):
    (opt_a, opt_b, opt_c) = self.optimizers()

    loss_a = self.generator(batch[0])

    # use this instead of loss.backward so we can automate half
    # precision, etc...
    self.manual_backward(loss_a, opt_a, retain_graph=True)
    self.manual_backward(loss_a, opt_a)
    opt_a.step()
    opt_a.zero_grad()

    loss_b = self.discriminator(batch[0])
    self.manual_backward(loss_b, opt_b)
    ...
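The snippet above assumes configure_optimizers returns several optimizers. A minimal sketch of what that could look like; the LitGAN class, its placeholder sub-networks, and the hyperparameters are illustrative:

import torch
from torch import nn
import pytorch_lightning as pl

class LitGAN(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # placeholder sub-networks; real models go here
        self.generator = nn.Linear(32, 32)
        self.discriminator = nn.Linear(32, 1)

    def configure_optimizers(self):
        # return as many optimizers as training_step unpacks;
        # with automatic_optimization=False, self.optimizers()
        # hands them back inside training_step
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
        return [opt_g, opt_d]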
Logging: self.log can be called from anywhere in a LightningModule to record a value:

def training_step(self, batch, batch_idx):
    self.log('my_metric', x)
Flags control how the value is reduced and where it shows up (per step or per epoch, progress bar, logger):

def training_step(self, batch, batch_idx):
    self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
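Values logged with logger=True are routed to whatever logger is attached to the trainer. A minimal sketch wiring up TensorBoard; save_dir and name here are illustrative:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

# everything passed to self.log(..., logger=True) lands in these event files
logger = TensorBoardLogger(save_dir='logs/', name='my_model')
trainer = Trainer(logger=logger)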
Data flow: the training, validation, and test loops now share one standardized set of hooks, where x stands for the loop name:

x_step
x_step_end
x_epoch_end
Conceptually, each epoch runs like this pseudocode:

outs = []
for batch in data:
    out = training_step(batch)
    outs.append(out)
training_epoch_end(outs)
Whatever a step returns is collected and handed to the matching epoch_end hook:

def training_step(self, batch, batch_idx):
    prediction = ...
    return {'loss': loss, 'preds': prediction}

def training_epoch_end(self, training_step_outputs):
    for out in training_step_outputs:
        prediction = out['preds']
        # do something with these
Checkpointing: to have Lightning save checkpoints based on a monitored quantity (as in the snippet below):

1. compute any metric or other quantity you want to monitor, such as the validation loss;
2. log the quantity with the log() method, under a key such as val_loss;
3. initialize the ModelCheckpoint callback, setting monitor to that key;
4. pass the callback to the trainer's checkpoint_callback flag.
import pytorch_lightning as pl
import torch.nn.functional as F
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

class LitAutoEncoder(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        # backbone is assumed to be defined in __init__
        y_hat = self.backbone(x)

        # 1. calculate loss
        loss = F.cross_entropy(y_hat, y)

        # 2. log `val_loss`
        self.log('val_loss', loss)

# 3. Init ModelCheckpoint callback, monitoring 'val_loss'
checkpoint_callback = ModelCheckpoint(monitor='val_loss')

# 4. Pass your callback to checkpoint_callback trainer flag
trainer = Trainer(checkpoint_callback=checkpoint_callback)
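Once training has run, the best weights can be restored with the standard load_from_checkpoint classmethod. A short sketch; model and the two dataloaders are assumed to exist:

# fit as usual; ModelCheckpoint tracks the best 'val_loss' checkpoint
trainer.fit(model, train_dataloader, val_dataloader)

# best_model_path is filled in by the callback during training
best_model = LitAutoEncoder.load_from_checkpoint(checkpoint_callback.best_model_path)
best_model.eval()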