轻松学Pytorch-使用ResNet50实现图像分类

2020 年 7 月 21 日 THU数据派


来源:磐创AI

本文 1161 ,建议阅读 4分钟
本文介绍pytorch中最重要的组件torchvision,它包含了常见的数据集、模型架构与预训练模型权重文件、常见图像变换、计算机视觉任务训练。


Hello大家好,这篇文章给大家详细介绍一下pytorch中最重要的组件torchvision,它包含了常见的数据集、模型架构与预训练模型权重文件、常见图像变换、计算机视觉任务训练。可以是说是pytorch中非常有用的模型迁移学习神器。本文将会介绍如何使用torchvison的预训练模型ResNet50实现图像分类。


模型


Torchvision.models包里面包含了常见的各种基础模型架构,主要包括:


AlexNet
VGG
ResNet
SqueezeNet
DenseNet
Inception v3
GoogLeNet
ShuffleNet v2
MobileNet v2
ResNeXt
Wide ResNet
MNASNet


这里我选择了ResNet50,基于ImageNet训练的基础网络来实现图像分类, 网络模型下载与加载如下:


model = torchvision.models.resnet50(pretrained=True).eval().cuda()tf = transforms.Compose([            transforms.Resize(256),            transforms.CenterCrop(224),            transforms.ToTensor(),            transforms.Normalize(            mean=[0.485, 0.456, 0.406],            std=[0.229, 0.224, 0.225]        )]


使用模型实现图像分类


这里首先需要加载ImageNet的分类标签,目的是最后显示分类的文本标签时候使用。然后对输入图像完成预处理,使用ResNet50模型实现分类预测,对预测结果解析之后,显示标签文本,完整的代码演示如下:


 1with open('imagenet_classes.txt'as f:
2    labels = [line.strip() for line in f.readlines()]
3
4src = cv.imread("D:/images/space_shuttle.jpg"# aeroplane.jpg
5image = cv.resize(src, (224224))
6image = np.float32(image) / 255.0
7image[:,:,] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406))
8image[:,:,] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225))
9image = image.transpose((201))
10input_x = torch.from_numpy(image).unsqueeze(0)
11print(input_x.size())
12pred = model(input_x.cuda())
13pred_index = torch.argmax(pred, 1).cpu().detach().numpy()
14print(pred_index)
15print("current predict class name : %s"%labels[pred_index[0]])
16cv.putText(src, labels[pred_index[0]], (5050), cv.FONT_HERSHEY_SIMPLEX, 1.0, (00255), 2)
17cv.imshow("input", src)
18cv.waitKey(0)
19cv.destroyAllWindows()


运行结果如下:



转ONNX支持


在torchvision中的模型基本上都可以转换为ONNX格式,而且被OpenCV DNN模块所支持,所以,很方便的可以对torchvision自带的模型转为ONNX,实现OpenCV DNN的调用,首先转为ONNX模型,直接使用torch.onnx.export即可转换(还不知道怎么转,快点看前面的例子)。转换之后使用OpenCV DNN调用的代码如下:


 1with open('imagenet_classes.txt'as f:
2    labels = [line.strip() for line in f.readlines()]
3net = cv.dnn.readNetFromONNX("resnet.onnx")
4src = cv.imread("D:/images/messi.jpg")  # aeroplane.jpg
5image = cv.resize(src, (224224))
6image = np.float32(image) / 255.0
7image[:, :, ] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406))
8image[:, :, ] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225))
9blob = cv.dnn.blobFromImage(image, 1.0, (224224), (000), False)
10net.setInput(blob)
11probs = net.forward()
12index = np.argmax(probs)
13cv.putText(src, labels[index], (5050), cv.FONT_HERSHEY_SIMPLEX, 1.0, (00255), 2)
14cv.imshow("input", src)
15cv.waitKey(0)
16cv.destroyAllWindows()


 运行结果见上图,这里就不再贴了。


——END——


登录查看更多
1

相关内容

【干货书】高级应用深度学习,294页pdf
专知会员服务
153+阅读 · 2020年6月20日
《深度学习》圣经花书的数学推导、原理与Python代码实现
【Google AI】开源NoisyStudent:自监督图像分类
专知会员服务
54+阅读 · 2020年2月18日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
《动手学深度学习》(Dive into Deep Learning)PyTorch实现
专知会员服务
119+阅读 · 2019年12月31日
【干货】用BRET进行多标签文本分类(附代码)
专知会员服务
84+阅读 · 2019年12月27日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
163+阅读 · 2019年10月28日
Opencv+TF-Slim实现图像分类及深度特征提取
极市平台
16+阅读 · 2019年8月19日
Mask-RCNN模型的实现自定义对象(无人机)检测
计算机视觉life
17+阅读 · 2019年8月12日
PyTorch使用总览
极市平台
5+阅读 · 2019年3月25日
如何用TensorFlow和TF-Slim实现图像标注、分类与分割
北京思腾合力科技有限公司
21+阅读 · 2017年11月24日
深度学习入门篇--手把手教你用 TensorFlow 训练模型
全球人工智能
4+阅读 · 2017年10月21日
手把手教TensorFlow(附代码)
深度学习世界
15+阅读 · 2017年10月17日
手把手教你由TensorFlow上手PyTorch(附代码)
数据派THU
5+阅读 · 2017年10月1日
Arxiv
15+阅读 · 2020年2月5日
Nocaps: novel object captioning at scale
Arxiv
6+阅读 · 2018年12月20日
Arxiv
8+阅读 · 2018年5月1日
Arxiv
5+阅读 · 2018年4月17日
Arxiv
7+阅读 · 2018年3月22日
Arxiv
3+阅读 · 2017年10月1日
VIP会员
相关VIP内容
【干货书】高级应用深度学习,294页pdf
专知会员服务
153+阅读 · 2020年6月20日
《深度学习》圣经花书的数学推导、原理与Python代码实现
【Google AI】开源NoisyStudent:自监督图像分类
专知会员服务
54+阅读 · 2020年2月18日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
《动手学深度学习》(Dive into Deep Learning)PyTorch实现
专知会员服务
119+阅读 · 2019年12月31日
【干货】用BRET进行多标签文本分类(附代码)
专知会员服务
84+阅读 · 2019年12月27日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
163+阅读 · 2019年10月28日
相关资讯
Opencv+TF-Slim实现图像分类及深度特征提取
极市平台
16+阅读 · 2019年8月19日
Mask-RCNN模型的实现自定义对象(无人机)检测
计算机视觉life
17+阅读 · 2019年8月12日
PyTorch使用总览
极市平台
5+阅读 · 2019年3月25日
如何用TensorFlow和TF-Slim实现图像标注、分类与分割
北京思腾合力科技有限公司
21+阅读 · 2017年11月24日
深度学习入门篇--手把手教你用 TensorFlow 训练模型
全球人工智能
4+阅读 · 2017年10月21日
手把手教TensorFlow(附代码)
深度学习世界
15+阅读 · 2017年10月17日
手把手教你由TensorFlow上手PyTorch(附代码)
数据派THU
5+阅读 · 2017年10月1日
相关论文
Arxiv
15+阅读 · 2020年2月5日
Nocaps: novel object captioning at scale
Arxiv
6+阅读 · 2018年12月20日
Arxiv
8+阅读 · 2018年5月1日
Arxiv
5+阅读 · 2018年4月17日
Arxiv
7+阅读 · 2018年3月22日
Arxiv
3+阅读 · 2017年10月1日
Top
微信扫码咨询专知VIP会员