©作者 | 清川
单位 | 上海交通大学博士生
研究方向 | 联邦学习与端云协同推断
本文主要讨论 PyTorch 模型训练中的两种可复现性:一种是在完全不改动代码的情况下重复运行,获得相同的准确率曲线;另一种是改动有限的代码,改动部分不影响训练过程的前提下,获得相同的曲线。
第一种情况,浅显地讲,我们只需要固定所有随机数种子就行
我们知道,计算机一般会使用混合线性同余法来生成伪随机数序列。在我们每次调用 rand() 函数时,就会执行一次或若干次下面的递推公式:
def seed_everything(seed):
torch.manual_seed(seed) # Current CPU
torch.cuda.manual_seed(seed) # Current GPU
np.random.seed(seed) # Numpy module
random.seed(seed) # Python random module
torch.backends.cudnn.benchmark = False # Close optimization
torch.backends.cudnn.deterministic = True # Close optimization
torch.cuda.manual_seed_all(seed) # All GPU (Optional)
>>> import torch
>>> from utils import seed_everything
>>> seed_everything(0)
>>> torch.rand(5)
tensor([0.4963, 0.7682, 0.0885, 0.1320, 0.3074])
>>> seed_everything(0)
>>> _ = torch.rand(1)
>>> torch.rand(5)
tensor([0.7682, 0.0885, 0.1320, 0.3074, 0.6341])
import torch
from torch.utils.data import TensorDataset, DataLoader
from utils import seed_everything
seed_everything(0)
dataset = TensorDataset(torch.rand((10, 3)), torch.rand(10))
dataloader = DataLoader(dataset, shuffle=False, batch_size=2)
print(torch.rand(5))
# tensor([0.5263, 0.2437, 0.5846, 0.0332, 0.1387])
seed_everything(0)
dataset = TensorDataset(torch.rand((10, 3)), torch.rand(10))
dataloader = DataLoader(dataset, shuffle=False, batch_size=2)
for inputs, labels in dataloader:
pass
print(torch.rand(5))
tensor([0.5846, 0.0332, 0.1387, 0.2422, 0.8155])
class _BaseDataLoaderIter(object):
def __init__(self, loader: DataLoader) -> None:
...
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
...
...
w = multiprocessing_context.Process(
target=_utils.worker._worker_loop,
args=(self._dataset_kind, self._dataset, index_queue,
self._worker_result_queue, self._workers_done_event,
self._auto_collation, self._collate_fn, self._drop_last,
self._base_seed, self._worker_init_fn, i, self._num_workers,
self._persistent_workers))
w.daemon = True
w.start()
...
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
num_workers, persistent_workers):
...
seed = base_seed + worker_id
random.seed(seed)
torch.manual_seed(seed)
if HAS_NUMPY:
np_seed = _generate_state(base_seed, worker_id)
import numpy as np
np.random.seed(np_seed)
...
for inputs, labels in DataLoader(...):
pass
# in操作符会调用如下
DataLoader()
DataLoader.self.__iter__()
DataLoader.self._get_iterator()
_MultiProcessingDataLoaderIter(DataLoader.self)
_BaseDataLoaderIter(DataLoader.self)
_BaseDataLoaderIter.self._base_seed = torch.empty(
(), dtype=torch.int64).random_(generator=DataLoader.generator).item()
# 一般来说generator是None,我们不指定,random_没有from和to时,会取数据类型最大范围,这里相当于随机生成一个大整数
def stable(dataloader, seed):
seed_everything(seed)
return dataloader
for inputs, labels in stable(DataLoader(...), seed):
pass
for epoch in range(MAX_EPOCH): # training
for inputs, labels in stable(DataLoader(...), seed + epoch):
pass
import random, numpy, torch
from torch.utils.data import DataLoader, TensorDataset
from utils import seed_everything
seed_everything(0)
BATCH_SIZE, NUM_WORKERS = 8, 4
dataset = TensorDataset(torch.rand((100, 3)), torch.rand(100))
g = torch.Generator()
g.manual_seed(0)
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2 ** 32
numpy.random.seed(worker_seed)
random.seed(worker_seed)
train_data = DataLoader(
dataset, shuffle=True,
batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
worker_init_fn=seed_worker, generator=g)
test_data = DataLoader(
dataset, shuffle=False,
batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
worker_init_fn=seed_worker, generator=g)
def test():
for inputs, labels in test_data:
pass
def train():
for inputs, labels in train_data:
print(labels)
break
if __name__ == "__main__":
# case 1
# Result: tensor([0.8174, 0.1753, 0.5049, 0.8947, 0.8472, 0.2588, 0.2568, 0.7127])
train()
# case 2
# Result: tensor([0.8947, 0.1753, 0.7802, 0.2161, 0.9094, 0.7335, 0.3245, 0.6152])
test()
train()
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」