PyTorch 对类别张量进行 one-hot 编码

2021 年 12 月 18 日 极市平台
↑ 点击 蓝字  关注极市平台

作者丨Lart
编辑丨极市平台

极市导读

 

one-hot 形式的编码在深度学习任务中非常常见,但是却并不是一种很自然的数据存储方式。所以大多数情况下都需要我们自己手动转换。本文尽可能将基于 pytorch 中常用方法来实现one-hot编码的方式整理了下,希望对大家有用。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

前言

one-hot 形式的编码在深度学习任务中非常常见,但是却并不是一种很自然的数据存储方式。所以大多数情况下都需要我们自己手动转换。虽然思路很直接,就是将类别拆分成一一对应的 0-1 向量,但是具体实现起来确实还是需要思考下的。实际上 pytorch 自身在nn.functional中已经提供了one_hot方法来快速应用。但是这并不能影响我们的思考与实践:>!所以本文尽可能将基于 pytorch 中常用方法来实现one-hot编码的方式整理了下,希望有用。

主要的方式有这么几种:

  • for循环
  • scatter
  • index_select
原始文档:

https://www.yuque.com/lart/ugkv9f/src5w8

代码仓库:

https://github.com/lartpang/CodeForArticle/tree/main/OneHotEncoding.PyTorch

for循环

这种方法非常直观,说白了就是对一个空白(全零)张量中的指定位置进行赋值(赋 1)操作即可。关键在于如何设定索引。下面设计了两种本质相同但由于指定维度不同而导致些许差异的方案。

def bhw_to_onehot_by_for(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """

    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    one_hot = bhw_tensor.new_zeros(size=(num_classes, *bhw_tensor.shape))
    for i in range(num_classes):
        one_hot[i, bhw_tensor == i] = 1
    one_hot = one_hot.permute(1230)
    return one_hot


def bhw_to_onehot_by_for_V1(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """

    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    one_hot = bhw_tensor.new_zeros(size=(*bhw_tensor.shape, num_classes))
    for i in range(num_classes):
        one_hot[..., i][bhw_tensor == i] = 1
    return one_hot

scatter

该方法应该是网上大多数简洁的one_hot写法的常用形式了。其实际上主要的作用是向 tensor 中指定的位置上赋值。

由于其可以使用专门构造的索引矩阵来作为索引,所以更加灵活。当然,灵活带来的也就是理解上的困难。官方文档中提供的解释非常直观:

'''
https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html

* (int dim, Tensor index, Tensor src)
 * (int dim, Tensor index, Tensor src, *, str reduce)
 * (int dim, Tensor index, Number value)
 * (int dim, Tensor index, Number value, *, str reduce)
'''


self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

文档中使用的是原地置换(in-place)版本,并且基于替换值为src,即 tensor 的情况下来解释。实际上在我们的应用中主要基于原地置换版本并搭配替换值为标量浮点数value的形式。

上述的形式中,我们可以看到,通过指定参数 tensor index,我们就可以将src(i,j,k)的值放置到方法调用者(这里是self)的指定位置上。该指定位置由index(i,j,k)处的值替换坐标(i,j,k)中的dim位置的值来构成(这里也反映出来了index tensor 的一个要求,就是维度数量要和selfsrc(如果src为 tensor 的话。后文中使用的是具体的标量值 1,即src替换为value)一致)。这倒是和one-hot的概念非常吻合。因为one-hot本身形式上的含义就是对于第i类数据,第i个位置为 1,其余位置为 0。所以对全零 tensor 使用scatter_是可以非常容易的构造出one-hottensor 的,即对对应于类别编号的位置放置 1 即可。

对于我们的问题而言,index非常适合使用输入的包含类别编号的 tensor(形状为B,H,W)来表示。基于这样的思考,可以构思出两种不同的策略:

def bhw_to_onehot_by_scatter(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """

    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    one_hot = torch.zeros(size=(math.prod(bhw_tensor.shape), num_classes))
    one_hot.scatter_(dim=1, index=bhw_tensor.reshape(-11), value=1)
    one_hot = one_hot.reshape(*bhw_tensor.shape, num_classes)
    return one_hot


def bhw_to_onehot_by_scatter_V1(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """

    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    one_hot = torch.zeros(size=(*bhw_tensor.shape, num_classes))
    one_hot.scatter_(dim=-1, index=bhw_tensor[..., None], value=1)
    return one_hot

这两种形式的差异的根源在于对形状的处理上。由此带来了scatter不同的应用形式。

对于第一种形式,将B,H,W三个维度合并,这样的好处是对通道(类别)的索引的理解变得直观起来。

    one_hot = torch.zeros(size=(math.prod(bhw_tensor.shape), num_classes))
    one_hot.scatter_(dim=1, index=bhw_tensor.reshape(-11), value=1)

这里将类别维度和其他维度直接分离,移到了末位。通过dim指定该维度,于是就有了这样的对应关系:

zero_tensor[abc, index[abc][d]] = value  # d=0

而在第二种情况下仍然保留了前面的三个维度,类别维度依然移动到最后一位。

    one_hot = torch.zeros(size=(*bhw_tensor.shape, num_classes))
    one_hot.scatter_(dim=-1, index=bhw_tensor[..., None], value=1)

此时的对应关系是这样的:

zero_tensor[a,b,c, index[a][b][c][d]] = value # d=0

另外在 pytorch 分类模型库 timm 中,也使用了类似的方法:

# https://github.com/rwightman/pytorch-image-models/blob/2c33ca6d8ce5d9257edf8cab5ab7ece81780aaf7/timm/data/mixup.py#L17-L19
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
    x = x.long().view(-11)
    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)

index_select

torch.index_select(input, dim, index, *, out=None) → Tensor

- input (Tensor) – the input tensor.
- dim (int) – the dimension in which we index
- index (IntTensor or LongTensor) – the 1-D tensor containing the indices to index

该函数如其名,就是用索引来选择 tensor 的指定维度的子 tensor 的。

想要理解这一方法的动机,实际上需要反过来,从类别标签的角度看待one-hot编码。

对于原始从小到大排布的类别序号对应的one-hot编码成的矩阵就是一个单位矩阵。所以每个类别对应的就是该单位矩阵的特定的列(或者行)。这一需求恰好符合index_select的功能。所以我们可以使用其实现one_hot编码,只需要使用类别序号索引特定的列或者行即可。下面就是一个例子:

def bhw_to_onehot_by_index_select(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """

    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    one_hot = torch.eye(num_classes).index_select(dim=0, index=bhw_tensor.reshape(-1))
    one_hot = one_hot.reshape(*bhw_tensor.shape, num_classes)
    return one_hot

性能对比

整体代码可见:https://github.com/lartpang/CodeForArticle/tree/main/OneHotEncoding.PyTorch

下面展示了不同方法的大致的相对性能(因为后台在跑程序,可能并不是十分准确,建议大家自行测试)。可以看到,pytorch 自带的函数在 CPU 上效率并不是很高,但是在 GPU 上表现良好。其中有趣的是,基于index_select的形式表现非常亮眼。

1.10.0 GeForce RTX 2080 Ti

cpu
('bhw_to_onehot_by_for', 0.5411529541015625)
('bhw_to_onehot_by_for_V1', 0.4515676498413086)
('bhw_to_onehot_by_scatter', 0.0686192512512207)
('bhw_to_onehot_by_scatter_V1', 0.08529376983642578)
('bhw_to_onehot_by_index_select', 0.05156970024108887)
('F.one_hot', 0.07366824150085449)

gpu
('bhw_to_onehot_by_for', 0.005235433578491211)
('bhw_to_onehot_by_for_V1', 0.045584678649902344)
('bhw_to_onehot_by_scatter', 0.0025513172149658203)
('bhw_to_onehot_by_scatter_V1', 0.0024869441986083984)
('bhw_to_onehot_by_index_select', 0.002012014389038086)
('F.one_hot', 0.0024051666259765625)

如果觉得有用,就请分享到朋友圈吧!

△点击卡片关注极市平台,获取 最新CV干货

公众号后台回复“transformer”获取最新Transformer综述论文下载~


极市干货
课程/比赛: 珠港澳人工智能算法大赛 保姆级零基础人工智能教程
算法trick 目标检测比赛中的tricks集锦 从39个kaggle竞赛中总结出来的图像分割的Tips和Tricks
技术综述: 一文弄懂各种loss function 工业图像异常检测最新研究总结(2019-2020)

极市平台签约作者#


Lart

知乎:人民艺术家

CSDN:有为少年



大连理工大学在读博士

研究领域:主要方向为图像分割,但多从事于二值图像分割的研究。也会关注其他领域,例如分类和检测等方向的发展。


作品精选

实践教程 | PyTorch中相对位置编码的理解
实操教程 | 使用Docker为无网络环境搭建深度学习环境
实践教程 | 一文让你把Docker用起来!

投稿方式:
添加小编微信Fengcall(微信号:fengcall19),备注:姓名-投稿
△长按添加极市平台小编

觉得有用麻烦给个在看啦~   
登录查看更多
0

相关内容

【AAAI 2022】用于文本摘要任务的序列级对比学习模型
专知会员服务
24+阅读 · 2022年1月11日
《深度学习》圣经花书的数学推导、原理与Python代码实现
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
【GitHub实战】Pytorch实现的小样本逼真的视频到视频转换
专知会员服务
35+阅读 · 2019年12月15日
开源书:PyTorch深度学习起步
专知会员服务
50+阅读 · 2019年10月11日
计算机视觉最佳实践、代码示例和相关文档
专知会员服务
18+阅读 · 2019年10月9日
实例:手写 CUDA 算子,让 Pytorch 提速 20 倍
极市平台
4+阅读 · 2022年3月8日
实践教程|卷积神经网络 C++ 从零开始实现
极市平台
2+阅读 · 2022年2月24日
Pytorch中的RNN及变种模型与变长序列的处理
极市平台
1+阅读 · 2022年1月17日
实践教程 | 浅谈 PyTorch 中的 tensor 及使用
极市平台
1+阅读 · 2021年12月14日
pytorch中六种常用的向量相似度评估方法
极市平台
22+阅读 · 2021年12月9日
深度学习Pytorch框架Tensor张量
极市平台
0+阅读 · 2021年11月1日
如何给你PyTorch里的Dataloader打鸡血
极市平台
15+阅读 · 2019年5月21日
基于PyTorch/TorchText的自然语言处理库
专知
28+阅读 · 2019年4月22日
tensorflow LSTM + CTC实现端到端OCR
机器学习研究会
26+阅读 · 2017年11月16日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
5+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
1+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
Arxiv
0+阅读 · 2022年4月18日
Arxiv
22+阅读 · 2018年8月30日
VIP会员
相关VIP内容
【AAAI 2022】用于文本摘要任务的序列级对比学习模型
专知会员服务
24+阅读 · 2022年1月11日
《深度学习》圣经花书的数学推导、原理与Python代码实现
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
【GitHub实战】Pytorch实现的小样本逼真的视频到视频转换
专知会员服务
35+阅读 · 2019年12月15日
开源书:PyTorch深度学习起步
专知会员服务
50+阅读 · 2019年10月11日
计算机视觉最佳实践、代码示例和相关文档
专知会员服务
18+阅读 · 2019年10月9日
相关资讯
实例:手写 CUDA 算子,让 Pytorch 提速 20 倍
极市平台
4+阅读 · 2022年3月8日
实践教程|卷积神经网络 C++ 从零开始实现
极市平台
2+阅读 · 2022年2月24日
Pytorch中的RNN及变种模型与变长序列的处理
极市平台
1+阅读 · 2022年1月17日
实践教程 | 浅谈 PyTorch 中的 tensor 及使用
极市平台
1+阅读 · 2021年12月14日
pytorch中六种常用的向量相似度评估方法
极市平台
22+阅读 · 2021年12月9日
深度学习Pytorch框架Tensor张量
极市平台
0+阅读 · 2021年11月1日
如何给你PyTorch里的Dataloader打鸡血
极市平台
15+阅读 · 2019年5月21日
基于PyTorch/TorchText的自然语言处理库
专知
28+阅读 · 2019年4月22日
tensorflow LSTM + CTC实现端到端OCR
机器学习研究会
26+阅读 · 2017年11月16日
相关基金
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
5+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
1+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
Top
微信扫码咨询专知VIP会员