本文为 AI 研习社编译的技术博客,原标题 :
Object detection and tracking in PyTorch
作者 | Chris Fotache
翻译 | 酱番梨、麦尔肯•诺埃、TripleZ
校对 | 酱番梨 整理 | 菠萝妹
原文链接:
https://towardsdatascience.com/object-detection-and-tracking-in-pytorch-b3cf1a696a98
注:本文的相关链接请点击文末【阅读原文】进行访问
在我之前的工作中,我尝试过用自己的图像在PyTorch中训练一个图像分类器,然后用它来进行图像识别。现在,我将向你们展示如何使用预训练的分类器在一张图像中检测多个目标,之后在整个视频中跟踪他们。
图像分类(识别)和目标检测之间有什么区别?在分类问题中,你识别出在图像中哪一个才是主要目标,然后将整张图片分类到一个单一类别中;在检测问题中,图像中有多个目标被识别、分类,而且目标的位置同样被确定下来(比如一个边界框)。
图像中目标检测
现有多种目标检测算法,其中YOLO,SSD是最受欢迎的方法,本文采用YOLOv3作为示例。本文不会对YOLO的技术细节进行分析,只是关注如何在自己的应用中实现。
直接上代码~YOLO检测的代码是基于Erik Lindernoren实现的Joseph Redmon and Ali Farhadi的文章。代码可以在Github中找到,下面为部分代码,在运行代码之前需要先在config文件夹中运行download_weights.sh脚本下载YOLO的权重文件,首先需要导入必要的模块:
from models import *
from utils import *
import os, sys, time, datetime, random
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
之后下载预训练的配置和权重,Darknet训练所用的COCO数据集的类别名称。在PyTorch中,在加载之后不要忘记将model设置为eval模式。
config_path='config/yolov3.cfg'
weights_path='config/yolov3.weights'
class_path='config/coco.names'
img_size=416
conf_thres=0.8
nms_thres=0.4
# Load model and weights
model = Darknet(config_path, img_size=img_size)
model.load_weights(weights_path)
model.cuda()
model.eval()
classes = utils.load_classes(class_path)
Tensor = torch.cuda.FloatTensor
上述代码中还有一些提前定义的值:图像尺寸(416*416像素),置信度阈值,非极大值抑制阈值。
下面是返回对特定图像的检测结果的基本函数。注意输入Pillow图像,大部分代码将图像resize至416*416,保持图像的纵横比并且填充溢出,实际的检测为最后的4行。
def detect_image(img):
# scale and pad image
ratio = min(img_size/img.size[0], img_size/img.size[1])
imw = round(img.size[0] * ratio)
imh = round(img.size[1] * ratio)
img_transforms=transforms.Compose([transforms.Resize((imh,imw)),
transforms.Pad((max(int((imh-imw)/2),0),
max(int((imw-imh)/2),0), max(int((imh-imw)/2),0),
max(int((imw-imh)/2),0)), (128,128,128)),
transforms.ToTensor(),
])
# convert image to Tensor
image_tensor = img_transforms(img).float()
image_tensor = image_tensor.unsqueeze_(0)
input_img = Variable(image_tensor.type(Tensor))
# run inference on the model and get detections
with torch.no_grad():
detections = model(input_img)
detections = utils.non_max_suppression(detections, 80
conf_thres, nms_thres)
return detections[0]
最后,将加载图像,获取检测结果,显示检测到的目标的边界框组合到一起。同样,这里大部分代码处理图像的放缩和填充,对每个不同的目标类别设置不同的颜色。
# load image and get detections
img_path = "images/blueangels.jpg"
prev_time = time.time()
img = Image.open(img_path)
detections = detect_image(img)
inference_time = datetime.timedelta(seconds=time.time() - prev_time)
print ('Inference Time: %s' % (inference_time))
# Get bounding-box colors
cmap = plt.get_cmap('tab20b')
colors = [cmap(i) for i in np.linspace(0, 1, 20)]
img = np.array(img)
plt.figure()
fig, ax = plt.subplots(1, figsize=(12,9))
ax.imshow(img)
pad_x = max(img.shape[0] - img.shape[1], 0) * (img_size / max(img.shape))
pad_y = max(img.shape[1] - img.shape[0], 0) * (img_size / max(img.shape))
unpad_h = img_size - pad_y
unpad_w = img_size - pad_x
if detections is not None:
unique_labels = detections[:, -1].cpu().unique()
n_cls_preds = len(unique_labels)
bbox_colors = random.sample(colors, n_cls_preds)
# browse detections and draw bounding boxes
for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
box_h = ((y2 - y1) / unpad_h) * img.shape[0]
box_w = ((x2 - x1) / unpad_w) * img.shape[1]
y1 = ((y1 - pad_y // 2) / unpad_h) * img.shape[0]
x1 = ((x1 - pad_x // 2) / unpad_w) * img.shape[1]
color = bbox_colors[int(np.where(
unique_labels == int(cls_pred))[0])]
bbox = patches.Rectangle((x1, y1), box_w, box_h,
linewidth=2, edgecolor=color, facecolor='none')
ax.add_patch(bbox)
plt.text(x1, y1, s=classes[int(cls_pred)],
color='white', verticalalignment='top',
bbox={'color': color, 'pad': 0})
plt.axis('off')
# save image
plt.savefig(img_path.replace(".jpg", "-det.jpg"),
bbox_inches='tight', pad_inches=0.0)
plt.show()
您可以将这些代码片段放在一起运行代码,或者从Github下载。下面是一些图像中目标检测的例子 :
视频中的物体追踪
所以,现在你知道了检测图像中的不同对象的方法。当你在视频中逐帧执行时,可视化可能非常酷,你会看到这些跟踪框四处移动。但是,如果这些视频帧中有多个对象,我们如何知道一帧中的对象是否与前一帧中的对象相同?这就是我们所说的“对象追踪”,并使用多个检测来识别特定对象随时间的变化。
有几种算法可以做到这一点,我决定使用SORT,它非常易于使用且速度非常快。SORT(简单在线和实时跟踪)是由Alex Bewley,Zongyuan Ge,Lionel Ott,Fabio Ramos,Ben Upcroft等人于2017年撰写的论文,其中提出使用卡尔曼滤波器来预测先前识别的对象的轨迹,并将它们与新的检测相匹配。作者Alex Bewley还写了一个多功能的Python实现,我将用它来讲述这个故事。确保从我的Github repo下载Sort版本,因为我必须进行一些小的更改才能将它集成到我的项目中。
现在我们来详细聊聊代码,前3个代码段将与单个图像检测中的相同,因为它们涉及在单个帧上获取YOLO检测。不同之处在于最后一部分,对于每个检测,我们调用Sort对象的Update函数以获取对图像中对象的引用。因此,除了上一个示例的常规检测(包括边界框的坐标和类预测)之外,我们将获得跟踪对象,除了上述参数之外,还包括对象ID。然后我们以几乎相同的方式显示,但添加该ID并使用不同的颜色,以便大家可以轻松地在视频帧中查看对象。
我还使用OpenCV来读取视频并显示视频帧。请注意,Jupyter笔记本在处理视频时速度很慢。你可以将它用于测试和简单可视化,我还提供了一个独立的Python脚本,它将读取源视频,并输出带有跟踪对象的副本。在笔记本电脑中播放OpenCV视频并不容易,因此你可以将此代码保留在其他实验中。
videopath = 'video/intersection.mp4'
%pylab inline
import cv2
from IPython.display import clear_output
cmap = plt.get_cmap('tab20b')
colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]
# initialize Sort object and video capture
from sort import *
vid = cv2.VideoCapture(videopath)
mot_tracker = Sort()
#while(True):
for ii in range(40):
ret, frame = vid.read()
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pilimg = Image.fromarray(frame)
detections = detect_image(pilimg)
img = np.array(pilimg)
pad_x = max(img.shape[0] - img.shape[1], 0) *
(img_size / max(img.shape))
pad_y = max(img.shape[1] - img.shape[0], 0) *
(img_size / max(img.shape))
unpad_h = img_size - pad_y
unpad_w = img_size - pad_x
if detections is not None:
tracked_objects = mot_tracker.update(detections.cpu())
unique_labels = detections[:, -1].cpu().unique()
n_cls_preds = len(unique_labels)
for x1, y1, x2, y2, obj_id, cls_pred in tracked_objects:
box_h = int(((y2 - y1) / unpad_h) * img.shape[0])
box_w = int(((x2 - x1) / unpad_w) * img.shape[1])
y1 = int(((y1 - pad_y // 2) / unpad_h) * img.shape[0])
x1 = int(((x1 - pad_x // 2) / unpad_w) * img.shape[1])
color = colors[int(obj_id) % len(colors)]
color = [i * 255 for i in color]
cls = classes[int(cls_pred)]
cv2.rectangle(frame, (x1, y1), (x1+box_w, y1+box_h),
color, 4)
cv2.rectangle(frame, (x1, y1-35), (x1+len(cls)*19+60,
y1), color, -1)
cv2.putText(frame, cls + "-" + str(int(obj_id)),
(x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX,
1, (255,255,255), 3)
fig=figure(figsize=(12, 8))
title("Video Stream")
imshow(frame)
show()
clear_output(wait=True)
使用笔记本后,你可以使用常规Python脚本进行实时处理(可以从相机获取输入)并保存视频。以下是我使用此程序生成的视频示例。
PyTorch中的对象检测和跟踪 [深度学习]
就是这样,你可以尝试自己检测图像中的多个对象并在视频帧中跟踪这些对象。你还可以对YOLO进行更多研究,并了解如何使用图像训练模型。 Chris Fotache是位于新泽西州的CYNET.ai的人工智能研究员。他介绍了与人生智能相关的主题,Python编程,机器学习,计算机视觉,自然语言处理等。
想要继续查看该篇文章相关链接和参考文献?
长按链接点击打开或点击底部【阅读原文】:
https://ai.yanxishe.com/page/TextTranslation/1333
AI研习社每日更新精彩内容,观看更多精彩内容:
等你来译:
点击 阅读原文 查看本文更多内容↙