Reported by 新智元 (Xinzhiyuan)
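PPYOLO and PPYOLOv2 are notoriously hard to export and deploy, because they rely on operators that are unfriendly to deployment, such as deformable convolution and MatrixNMS.
The author has implemented custom ncnn layers for deformable convolution (DCNv2), CoordConcat, and PPYOLO Decode MatrixNMS, which makes it possible to deploy PPYOLO and PPYOLOv2 with ncnn. The deformable convolution layer has already been merged into the official ncnn repository.
When an image is preprocessed in ncnn, it is first converted from BGR to RGB, then resized to 640x640 with cubic interpolation (the equivalent of cv2.INTER_CUBIC), and finally normalized with the same mean and standard deviation. All of this matches the original PPYOLOv2, which ensures that the C++ side and the Python side feed exactly the same image tensor to the network.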
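The channel swap and normalization can be sketched in a few lines of plain Python. This is only an illustration: the mean and standard deviation below are the common ImageNet values and the division by 255 is assumed, since the article does not list the actual numbers, and the bicubic resize to 640x640 is assumed to have happened already.

```python
# Assumed ImageNet statistics in RGB order; the article does not state the values.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def preprocess(bgr_image):
    """Convert an HxWx3 BGR image (nested lists, values 0-255) to a normalized 3xHxW RGB tensor.

    The bicubic resize to the network input size is assumed to be done beforehand.
    """
    height = len(bgr_image)
    width = len(bgr_image[0])
    chw = [[[0.0] * width for _ in range(height)] for _ in range(3)]
    for y in range(height):
        for x in range(width):
            b, g, r = bgr_image[y][x]
            for c, value in enumerate((r, g, b)):  # BGR -> RGB
                # Scaling to [0, 1] before normalizing is an assumption.
                chw[c][y][x] = (value / 255.0 - MEAN[c]) / STD[c]
    return chw


# A 1x2 image: one pure-blue pixel and one pure-red pixel (BGR order).
tensor = preprocess([[[255, 0, 0], [0, 0, 255]]])
print(len(tensor), len(tensor[0]), len(tensor[0][0]))
```

Running the same function on both the Python side and a C++ port is one way to check that the two pipelines produce the same tensor.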
Finally, the figure below compares the ncnn output with the miemiedetection output:
python tools/demo.py image -f exps/ppyolo/ppyolov2_r50vd_365e.py -c ppyolov2_r50vd_365e.pth --path assets/000000013659.jpg --conf 0.15 --tsize 640 --save_result --device gpu
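Converting PyTorch directly to ncnn
After reading part of the ncnn source code and making sure I fully understood the *.bin and *.param files, I wrote a tool called ncnn_utils; its source is at mmdet/models/ncnn_utils.py in miemiedetection. It exports the *.bin and *.param files used by ncnn from a single pass through the forward computation: you only need to add an export_ncnn() method to each PyTorch layer, and export_ncnn() is almost a line-by-line copy of that layer's forward() method.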
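To give a feel for what such an exporter has to write, here is a minimal sketch of the text layout of a *.param file: the magic number 7767517, a line with the layer count and blob count, then one line per layer. The helpers format_layer and format_param_file are hypothetical and are not the API of ncnn_utils; the parameter ids are the ones read by DeformableConv2D::load_param in the next section.

```python
def format_layer(layer_type, name, inputs, outputs, params):
    """Return one layer line of an ncnn .param file (hypothetical helper)."""
    fields = [layer_type, name, str(len(inputs)), str(len(outputs))]
    fields += list(inputs) + list(outputs)
    fields += ["%d=%s" % (key, value) for key, value in sorted(params.items())]
    return " ".join(fields)


def format_param_file(layers, blob_count):
    """Join layer lines under the magic number and the layer/blob counts."""
    header = ["7767517", "%d %d" % (len(layers), blob_count)]
    return "\n".join(header + layers) + "\n"


layers = [
    format_layer("Input", "images", [], ["images"], {}),
    format_layer("Input", "offset", [], ["offset"], {}),
    # ids as in DeformableConv2D::load_param: 0=num_output, 1=kernel_w,
    # 4=pad_left, 5=bias_term, 6=weight_data_size (8 * 3 * 3 * 3 here)
    format_layer("DeformableConv2D", "dcn0", ["images", "offset"], ["dcn0_out"],
                 {0: 8, 1: 3, 4: 1, 5: 1, 6: 8 * 3 * 3 * 3}),
]
text = format_param_file(layers, blob_count=3)
print(text)
```

The weights themselves go into the *.bin file in the same layer order; that part is omitted from this sketch.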
Deformable convolution
...
#include "deformableconv2d.h"
#include "fused_activation.h"
namespace ncnn {
DeformableConv2D::DeformableConv2D()
{
one_blob_only = false;
support_inplace = false;
}
int DeformableConv2D::load_param(const ParamDict& pd)
{
num_output = pd.get(0, 0);
kernel_w = pd.get(1, 0);
kernel_h = pd.get(11, kernel_w);
dilation_w = pd.get(2, 1);
dilation_h = pd.get(12, dilation_w);
stride_w = pd.get(3, 1);
stride_h = pd.get(13, stride_w);
pad_left = pd.get(4, 0);
pad_right = pd.get(15, pad_left);
pad_top = pd.get(14, pad_left);
pad_bottom = pd.get(16, pad_top);
bias_term = pd.get(5, 0);
weight_data_size = pd.get(6, 0);
activation_type = pd.get(9, 0);
activation_params = pd.get(10, Mat());
return 0;
}
int DeformableConv2D::load_model(const ModelBin& mb)
{
weight_data = mb.load(weight_data_size, 0);
if (weight_data.empty())
return -100;
if (bias_term)
{
bias_data = mb.load(num_output, 1);
if (bias_data.empty())
return -100;
}
return 0;
}
int DeformableConv2D::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
const Mat& bottom_blob = bottom_blobs[0];
const Mat& offset = bottom_blobs[1];
const bool has_mask = (bottom_blobs.size() == 3);
const int w = bottom_blob.w;
const int h = bottom_blob.h;
const int in_c = bottom_blob.c;
const size_t elemsize = bottom_blob.elemsize;
const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
const int kernel_extent_h = dilation_h * (kernel_h - 1) + 1;
const int out_w = (w + pad_left + pad_right - kernel_extent_w) / stride_w + 1;
const int out_h = (h + pad_top + pad_bottom - kernel_extent_h) / stride_h + 1;
// output.shape is [num_output, out_h, out_w] (in python).
Mat& output = top_blobs[0];
output.create(out_w, out_h, num_output, elemsize, opt.blob_allocator);
if (output.empty())
return -100;
const float* weight_ptr = weight_data;
const float* bias_ptr = weight_data;
if (bias_term)
bias_ptr = bias_data;
// deformable conv
#pragma omp parallel for num_threads(opt.num_threads)
for (int h_col = 0; h_col < out_h; h_col++)
{
for (int w_col = 0; w_col < out_w; w_col++)
{
int h_in = h_col * stride_h - pad_top;
int w_in = w_col * stride_w - pad_left;
for (int oc = 0; oc < num_output; oc++)
{
float sum = 0.f;
if (bias_term)
sum = bias_ptr[oc];
for (int i = 0; i < kernel_h; i++)
{
for (int j = 0; j < kernel_w; j++)
{
const float offset_h = offset.channel((i * kernel_w + j) * 2).row(h_col)[w_col];
const float offset_w = offset.channel((i * kernel_w + j) * 2 + 1).row(h_col)[w_col];
const float mask_ = has_mask ? bottom_blobs[2].channel(i * kernel_w + j).row(h_col)[w_col] : 1.f;
const float h_im = h_in + i * dilation_h + offset_h;
const float w_im = w_in + j * dilation_w + offset_w;
// Bilinear
const bool cond = h_im > -1 && w_im > -1 && h_im < h && w_im < w;
int h_low = 0;
int w_low = 0;
int h_high = 0;
int w_high = 0;
float w1 = 0.f;
float w2 = 0.f;
float w3 = 0.f;
float w4 = 0.f;
bool v1_cond = false;
bool v2_cond = false;
bool v3_cond = false;
bool v4_cond = false;
if (cond)
{
h_low = floor(h_im);
w_low = floor(w_im);
h_high = h_low + 1;
w_high = w_low + 1;
float lh = h_im - h_low;
float lw = w_im - w_low;
float hh = 1 - lh;
float hw = 1 - lw;
v1_cond = (h_low >= 0 && w_low >= 0);
v2_cond = (h_low >= 0 && w_high <= w - 1);
v3_cond = (h_high <= h - 1 && w_low >= 0);
v4_cond = (h_high <= h - 1 && w_high <= w - 1);
w1 = hh * hw;
w2 = hh * lw;
w3 = lh * hw;
w4 = lh * lw;
}
for (int c_im = 0; c_im < in_c; c_im++)
{
float val = 0.f;
if (cond)
{
float v1 = v1_cond ? bottom_blob.channel(c_im).row(h_low)[w_low] : 0.f;
float v2 = v2_cond ? bottom_blob.channel(c_im).row(h_low)[w_high] : 0.f;
float v3 = v3_cond ? bottom_blob.channel(c_im).row(h_high)[w_low] : 0.f;
float v4 = v4_cond ? bottom_blob.channel(c_im).row(h_high)[w_high] : 0.f;
val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}
sum += val * mask_ * weight_ptr[((oc * in_c + c_im) * kernel_h + i) * kernel_w + j];
}
}
}
output.channel(oc).row(h_col)[w_col] = activation_ss(sum, activation_type, activation_params);
}
}
}
return 0;
}
} // namespace ncnn
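The core of the layer is the bilinear sampling at the fractional position (h_im, w_im). The following Python sketch mirrors that part of the C++ loop for a single channel, using the same bounds rule (samples outside the image contribute 0). It is meant as a small reference for checking a port, not as a replacement for the layer.

```python
import math


def bilinear_sample(image, h_im, w_im):
    """Sample a 2-D list `image` at fractional (h_im, w_im); pixels outside count as 0."""
    h = len(image)
    w = len(image[0])
    if not (h_im > -1 and w_im > -1 and h_im < h and w_im < w):
        return 0.0
    h_low = math.floor(h_im)
    w_low = math.floor(w_im)
    h_high = h_low + 1
    w_high = w_low + 1
    lh = h_im - h_low  # distance to the upper row
    lw = w_im - w_low  # distance to the left column
    hh = 1 - lh
    hw = 1 - lw

    def pixel(y, x):
        # Neighbours that fall outside the image are treated as 0.
        return image[y][x] if 0 <= y <= h - 1 and 0 <= x <= w - 1 else 0.0

    return (hh * hw * pixel(h_low, w_low) + hh * lw * pixel(h_low, w_high)
            + lh * hw * pixel(h_high, w_low) + lh * lw * pixel(h_high, w_high))


image = [[0.0, 10.0],
         [20.0, 30.0]]
print(bilinear_sample(image, 0.5, 0.5))   # centre of the four pixels
print(bilinear_sample(image, 1.0, 1.0))   # exactly on the bottom-right pixel
print(bilinear_sample(image, -0.5, 1.0))  # half a pixel above the top edge
```

In the layer, this sampled value is multiplied by the modulation mask (1 when no mask is given) and by the kernel weight, then accumulated over all input channels and kernel positions.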
Decoding the PPYOLOv2 output
# mmdet/models/heads/yolov3_head.py
...
if self.iou_aware:
    na = len(self.anchors[i])
    ioup, x = out[:, 0:na, :, :], out[:, na:, :, :]
    b, c, h, w = x.shape
    no = c // na
    x = x.reshape((b, na, no, h * w))
    ioup = ioup.reshape((b, na, 1, h * w))
    obj = x[:, :, 4:5, :]
    ioup = torch.sigmoid(ioup)
    obj = torch.sigmoid(obj)
    obj_t = (obj**(1 - self.iou_aware_factor)) * (
        ioup**self.iou_aware_factor)
    obj_t = _de_sigmoid(obj_t)
    loc_t = x[:, :, :4, :]
    cls_t = x[:, :, 5:, :]
    y_t = torch.cat([loc_t, obj_t, cls_t], 2)
    out = y_t.reshape((b, c, h, w))
box, score = paddle_yolo_box(out, self._anchors[self.anchor_masks[i]], self.downsample[i],
                             self.num_classes, self.scale_x_y, im_size, self.clip_bbox,
                             conf_thresh=self.nms_cfg['score_threshold'])
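That is, ioup and obj are each passed through a sigmoid, and obj_t = (obj ** (1 - self.iou_aware_factor)) * (ioup ** self.iou_aware_factor) becomes the new obj. The new obj is passed through the inverse of the sigmoid to restore it to its undecoded (logit) state and is then written back into x. The resulting out has 255 channels, so it can be decoded exactly like the output of the original YOLOv3.
The reason for doing it this way is that paddle_yolo_box() decodes the output of the original YOLOv3; by reusing paddle_yolo_box(), no separate decoding code has to be written. Hence the roundabout route.
The takeaway is that ioup is simply combined with obj through obj_t = (obj ** (1 - self.iou_aware_factor)) * (ioup ** self.iou_aware_factor) to produce the new obj; everything else is decoded just as in YOLOv3.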
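The fusion and the "inverse sigmoid" step can be sketched for a single scalar as follows. The default factor 0.5 is the default that the C++ listing reads for iou_aware_factor; the clamping with eps is an assumption added here to keep the logarithm finite, and de_sigmoid is a stand-in for the _de_sigmoid helper whose body is not shown in the article.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def de_sigmoid(p, eps=1e-7):
    """Inverse of sigmoid (logit); the clamp is an assumption to avoid log(0)."""
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))


def fuse_objectness(obj_logit, ioup_logit, iou_aware_factor=0.5):
    """Return the logit of obj ** (1 - f) * ioup ** f."""
    obj = sigmoid(obj_logit)
    ioup = sigmoid(ioup_logit)
    fused = obj ** (1.0 - iou_aware_factor) * ioup ** iou_aware_factor
    return de_sigmoid(fused)


new_logit = fuse_objectness(2.0, -1.0)
# Applying sigmoid again recovers the fused objectness, which is what a plain
# YOLOv3 decoder does with the objectness channel.
print(sigmoid(new_logit))
```

Because a YOLOv3 decoder applies the sigmoid itself, storing the logit rather than the probability is what allows the fused objectness to be pasted back into the tensor unchanged.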
So in ncnn, I implement the PPYOLOv2 decoding like this:
// examples/test2_06_ppyolo_ncnn.cpp
...
class PPYOLODecodeMatrixNMS : public ncnn::Layer
{
public:
PPYOLODecodeMatrixNMS()
{
// miemie2013: if num of input tensors > 1 or num of output tensors > 1, you must set one_blob_only = false
// And ncnn will use forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) method
// or forward_inplace(std::vector<Mat>& bottom_top_blobs, const Option& opt) method
one_blob_only = false;
support_inplace = false;
}
virtual int load_param(const ncnn::ParamDict& pd)
{
num_classes = pd.get(0, 80);
anchors = pd.get(1, ncnn::Mat());
strides = pd.get(2, ncnn::Mat());
scale_x_y = pd.get(3, 1.f);
iou_aware_factor = pd.get(4, 0.5f);
score_threshold = pd.get(5, 0.1f);
anchor_per_stride = pd.get(6, 3);
post_threshold = pd.get(7, 0.1f);
nms_top_k = pd.get(8, 500);
keep_top_k = pd.get(9, 100);
kernel = pd.get(10, 0);
gaussian_sigma = pd.get(11, 2.f);
return 0;
}
virtual int forward(const std::vector<ncnn::Mat>& bottom_blobs, std::vector<ncnn::Mat>& top_blobs, const ncnn::Option& opt) const
{
const ncnn::Mat& bottom_blob = bottom_blobs[0];
const int tensor_num = bottom_blobs.size() - 1;
const size_t elemsize = bottom_blob.elemsize;
const ncnn::Mat& im_scale = bottom_blobs[tensor_num];
const float scale_x = im_scale[0];
const float scale_y = im_scale[1];
int out_num = 0;
for (size_t b = 0; b < tensor_num; b++)
{
const ncnn::Mat& tensor = bottom_blobs[b];
const int w = tensor.w;
const int h = tensor.h;
out_num += anchor_per_stride * h * w;
}
ncnn::Mat bboxes;
bboxes.create(4 * out_num, elemsize, opt.blob_allocator);
if (bboxes.empty())
return -100;
ncnn::Mat scores;
scores.create(num_classes * out_num, elemsize, opt.blob_allocator);
if (scores.empty())
return -100;
float* bboxes_ptr = bboxes;
float* scores_ptr = scores;
// decode
for (size_t b = 0; b < tensor_num; b++)
{
const ncnn::Mat& tensor = bottom_blobs[b];
const int w = tensor.w;
const int h = tensor.h;
const int c = tensor.c;
const bool use_iou_aware = (c == anchor_per_stride * (num_classes + 6));
const int channel_stride = use_iou_aware ? (c / anchor_per_stride) - 1 : (c / anchor_per_stride);
const int cx_pos = use_iou_aware ? anchor_per_stride : 0;
const int cy_pos = use_iou_aware ? anchor_per_stride + 1 : 1;
const int w_pos = use_iou_aware ? anchor_per_stride + 2 : 2;
const int h_pos = use_iou_aware ? anchor_per_stride + 3 : 3;
const int obj_pos = use_iou_aware ? anchor_per_stride + 4 : 4;
const int cls_pos = use_iou_aware ? anchor_per_stride + 5 : 5;
float stride = strides[b];
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < h; i++)
{
for (int j = 0; j < w; j++)
{
for (int k = 0; k < anchor_per_stride; k++)
{
float obj = tensor.channel(obj_pos + k * channel_stride).row(i)[j];
obj = static_cast<float>(1.f / (1.f + expf(-obj)));
if (use_iou_aware)
{
float ioup = tensor.channel(k).row(i)[j];
ioup = static_cast<float>(1.f / (1.f + expf(-ioup)));
obj = static_cast<float>(pow(obj, 1.f - iou_aware_factor) * pow(ioup, iou_aware_factor));
}
if (obj > score_threshold)
{
// Grid Sensitive
float cx = static_cast<float>(scale_x_y / (1.f + expf(-tensor.channel(cx_pos + k * channel_stride).row(i)[j])) + j - (scale_x_y - 1.f) * 0.5f);
float cy = static_cast<float>(scale_x_y / (1.f + expf(-tensor.channel(cy_pos + k * channel_stride).row(i)[j])) + i - (scale_x_y - 1.f) * 0.5f);
cx *= stride;
cy *= stride;
float dw = static_cast<float>(expf(tensor.channel(w_pos + k * channel_stride).row(i)[j]) * anchors[(b * anchor_per_stride + k) * 2]);
float dh = static_cast<float>(expf(tensor.channel(h_pos + k * channel_stride).row(i)[j]) * anchors[(b * anchor_per_stride + k) * 2 + 1]);
float x0 = cx - dw * 0.5f;
float y0 = cy - dh * 0.5f;
float x1 = cx + dw * 0.5f;
float y1 = cy + dh * 0.5f;
bboxes_ptr[((i * w + j) * anchor_per_stride + k) * 4] = x0 / scale_x;
bboxes_ptr[((i * w + j) * anchor_per_stride + k) * 4 + 1] = y0 / scale_y;
bboxes_ptr[((i * w + j) * anchor_per_stride + k) * 4 + 2] = x1 / scale_x;
bboxes_ptr[((i * w + j) * anchor_per_stride + k) * 4 + 3] = y1 / scale_y;
for (int r = 0; r < num_classes; r++)
{
float score = static_cast<float>(obj / (1.f + expf(-tensor.channel(cls_pos + k * channel_stride + r).row(i)[j])));
scores_ptr[((i * w + j) * anchor_per_stride + k) * num_classes + r] = score;
}
}else
{
bboxes_ptr[((i * w + j) * anchor_per_stride + k) * 4] = 0.f;
bboxes_ptr[((i * w + j) * anchor_per_stride + k) * 4 + 1] = 0.f;
bboxes_ptr[((i * w + j) * anchor_per_stride + k) * 4 + 2] = 1.f;
bboxes_ptr[((i * w + j) * anchor_per_stride + k) * 4 + 3] = 1.f;
for (int r = 0; r < num_classes; r++)
{
scores_ptr[((i * w + j) * anchor_per_stride + k) * num_classes + r] = -1.f;
}
}
}
}
}
bboxes_ptr += h * w * anchor_per_stride * 4;
scores_ptr += h * w * anchor_per_stride * num_classes;
}
...
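The box arithmetic inside the innermost loop is easier to follow for a single prediction. The sketch below reproduces the Grid Sensitive centre decoding and the anchor-based width and height from the listing. The value scale_x_y=1.05 and the 116x90 anchor are illustrative assumptions, and the final division by scale_x and scale_y that maps the box back to the original image is left out.

```python
import math


def decode_box(tx, ty, tw, th, col, row, stride, anchor_w, anchor_h, scale_x_y=1.05):
    """Decode one raw prediction into (x0, y0, x1, y1) in network-input pixels."""
    bias = (scale_x_y - 1.0) * 0.5
    # Grid Sensitive: the sigmoid output is stretched by scale_x_y and re-centred.
    cx = (scale_x_y / (1.0 + math.exp(-tx)) + col - bias) * stride
    cy = (scale_x_y / (1.0 + math.exp(-ty)) + row - bias) * stride
    bw = math.exp(tw) * anchor_w
    bh = math.exp(th) * anchor_h
    return (cx - bw * 0.5, cy - bh * 0.5, cx + bw * 0.5, cy + bh * 0.5)


# All-zero logits place the box at the centre of cell (row=2, col=3)
# with exactly the anchor's size.
box = decode_box(0.0, 0.0, 0.0, 0.0, col=3, row=2, stride=32, anchor_w=116, anchor_h=90)
print(box)
```

With scale_x_y = 1 the bias term vanishes and the formula reduces to the usual YOLOv3 centre decoding.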
MatrixNMS
// examples/test2_06_ppyolo_ncnn.cpp
...
struct Bbox
{
float x0;
float y0;
float x1;
float y1;
int clsid;
float score;
};
bool compare_desc(Bbox bbox1, Bbox bbox2)
{
return bbox1.score > bbox2.score;
}
float calc_iou(Bbox bbox1, Bbox bbox2)
{
float area_1 = (bbox1.y1 - bbox1.y0) * (bbox1.x1 - bbox1.x0);
float area_2 = (bbox2.y1 - bbox2.y0) * (bbox2.x1 - bbox2.x0);
float inter_x0 = std::max(bbox1.x0, bbox2.x0);
float inter_y0 = std::max(bbox1.y0, bbox2.y0);
float inter_x1 = std::min(bbox1.x1, bbox2.x1);
float inter_y1 = std::min(bbox1.y1, bbox2.y1);
float inter_w = std::max(0.f, inter_x1 - inter_x0);
float inter_h = std::max(0.f, inter_y1 - inter_y0);
float inter_area = inter_w * inter_h;
float union_area = area_1 + area_2 - inter_area + 0.000000001f;
return inter_area / union_area;
}
...
class PPYOLODecodeMatrixNMS : public ncnn::Layer
{
public:
...
virtual int forward(const std::vector<ncnn::Mat>& bottom_blobs, std::vector<ncnn::Mat>& top_blobs, const ncnn::Option& opt) const
{
...
// keep bbox whose score > score_threshold
std::vector<Bbox> bboxes_vec;
for (int i = 0; i < out_num; i++)
{
float x0 = bboxes[i * 4];
float y0 = bboxes[i * 4 + 1];
float x1 = bboxes[i * 4 + 2];
float y1 = bboxes[i * 4 + 3];
for (int j = 0; j < num_classes; j++)
{
float score = scores[i * num_classes + j];
if (score > score_threshold)
{
Bbox bbox;
bbox.x0 = x0;
bbox.y0 = y0;
bbox.x1 = x1;
bbox.y1 = y1;
bbox.clsid = j;
bbox.score = score;
bboxes_vec.push_back(bbox);
}
}
}
if (bboxes_vec.size() == 0)
{
ncnn::Mat& pred = top_blobs[0];
pred.create(0, 0, elemsize, opt.blob_allocator);
if (pred.empty())
return -100;
return 0;
}
// sort and keep top nms_top_k
int nms_top_k_ = nms_top_k;
if (bboxes_vec.size() < nms_top_k)
nms_top_k_ = bboxes_vec.size();
size_t count {(size_t)nms_top_k_};
std::partial_sort(std::begin(bboxes_vec), std::begin(bboxes_vec) + count, std::end(bboxes_vec), compare_desc);
if (bboxes_vec.size() > nms_top_k)
bboxes_vec.resize(nms_top_k);
// ---------------------- Matrix NMS ----------------------
// calc a iou matrix whose shape is [n, n], n is bboxes_vec.size()
int n = bboxes_vec.size();
std::vector<float> decay_iou(n * n); // freed automatically; the raw new[] was never deleted
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
if (j < i + 1)
{
decay_iou[i * n + j] = 0.f;
}else
{
bool same_clsid = bboxes_vec[i].clsid == bboxes_vec[j].clsid;
if (same_clsid)
{
float iou = calc_iou(bboxes_vec[i], bboxes_vec[j]);
decay_iou[i * n + j] = iou;
}else
{
decay_iou[i * n + j] = 0.f;
}
}
}
}
// get max iou of each col
std::vector<float> compensate_iou(n); // freed automatically; the raw new[] was never deleted
for (int i = 0; i < n; i++)
{
float max_iou = decay_iou[i];
for (int j = 0; j < n; j++)
{
if (decay_iou[j * n + i] > max_iou)
max_iou = decay_iou[j * n + i];
}
compensate_iou[i] = max_iou;
}
std::vector<float> decay_matrix(n * n); // freed automatically; the raw new[] was never deleted
// get min decay_value of each col
std::vector<float> decay_coefficient(n);
if (kernel == 0) // gaussian
{
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
decay_matrix[i * n + j] = static_cast<float>(expf(gaussian_sigma * (compensate_iou[i] * compensate_iou[i] - decay_iou[i * n + j] * decay_iou[i * n + j])));
}
}
}else if (kernel == 1) // linear
{
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
decay_matrix[i * n + j] = (1.f - decay_iou[i * n + j]) / (1.f - compensate_iou[i]);
}
}
}
for (int i = 0; i < n; i++)
{
float min_v = decay_matrix[i];
for (int j = 0; j < n; j++)
{
if (decay_matrix[j * n + i] < min_v)
min_v = decay_matrix[j * n + i];
}
decay_coefficient[i] = min_v;
}
for (int i = 0; i < n; i++)
{
bboxes_vec[i].score *= decay_coefficient[i];
}
// ---------------------- Matrix NMS (end) ----------------------
std::vector<Bbox> bboxes_vec_keep;
for (int i = 0; i < n; i++)
{
if (bboxes_vec[i].score > post_threshold)
{
bboxes_vec_keep.push_back(bboxes_vec[i]);
}
}
n = bboxes_vec_keep.size();
if (n == 0)
{
ncnn::Mat& pred = top_blobs[0];
pred.create(0, 0, elemsize, opt.blob_allocator);
if (pred.empty())
return -100;
return 0;
}
// sort and keep keep_top_k
int keep_top_k_ = keep_top_k;
if (n < keep_top_k)
keep_top_k_ = n;
size_t keep_count {(size_t)keep_top_k_};
std::partial_sort(std::begin(bboxes_vec_keep), std::begin(bboxes_vec_keep) + keep_count, std::end(bboxes_vec_keep), compare_desc);
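if (bboxes_vec_keep.size() > keep_top_k)
bboxes_vec_keep.resize(keep_top_k);
// n must match the truncated size, otherwise the copy loop below reads past the end
n = bboxes_vec_keep.size();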
ncnn::Mat& pred = top_blobs[0];
pred.create(6 * n, elemsize, opt.blob_allocator);
if (pred.empty())
return -100;
float* pred_ptr = pred;
for (int i = 0; i < n; i++)
{
pred_ptr[i * 6] = (float)bboxes_vec_keep[i].clsid;
pred_ptr[i * 6 + 1] = bboxes_vec_keep[i].score;
pred_ptr[i * 6 + 2] = bboxes_vec_keep[i].x0;
pred_ptr[i * 6 + 3] = bboxes_vec_keep[i].y0;
pred_ptr[i * 6 + 4] = bboxes_vec_keep[i].x1;
pred_ptr[i * 6 + 5] = bboxes_vec_keep[i].y1;
}
pred = pred.reshape(6, n);
return 0;
}
...
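Step 1: the predicted boxes whose score exceeds score_threshold are saved into bboxes_vec. This is the first score filter. If no box scores above score_threshold, a Mat of shape (0, 0) is returned to indicate that there are no objects.
Step 2: the top nms_top_k boxes in bboxes_vec are sorted by descending score, and only those nms_top_k boxes are kept.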
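Keeping only the best k boxes in descending order is what std::partial_sort followed by resize does in the C++ code. A rough Python equivalent, using dictionaries with a hypothetical "score" key in place of the Bbox struct, could look like this:

```python
import heapq


def keep_top_k(boxes, k):
    """Return the k highest-scoring boxes in descending score order."""
    return heapq.nlargest(k, boxes, key=lambda box: box["score"])


boxes = [{"id": 0, "score": 0.3}, {"id": 1, "score": 0.9},
         {"id": 2, "score": 0.6}, {"id": 3, "score": 0.1}]
print([box["id"] for box in keep_top_k(boxes, 2)])
```

When fewer than k boxes are available, all of them are returned in sorted order, which matches the clamping of nms_top_k_ in the listing.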
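Step 3: MatrixNMS. Suppose bboxes_vec now holds n boxes. We compute an n * n matrix decay_iou that holds the pairwise IoU of the boxes in bboxes_vec; its lower triangle (including the diagonal) is 0. IoU is computed only between boxes of the same class and is set to 0 for boxes of different classes.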
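The code that follows is harder to understand, so here is an example. Suppose that after the first score filter and the descending sort, three boxes of the same class remain, numbered 0, 1, and 2, and that decay_iou is as follows (the values are reconstructed from the numbers quoted in this passage):
0    0.9  0.2
0    0    0.8
0    0    0
A naive approach uses 1 - decay_iou as the decay and takes the minimum of each column, which gives the decay coefficient vector decay_coefficient=[1, 0.1, 0.2]. Multiplying each bbox's score by the corresponding entry of this vector lowers its score. However, box 2 gets 0.2 only because of its IoU of 0.8 with box 1, and box 1 is itself almost entirely suppressed by box 0, so its opinion should carry little weight.
With compensation, row i of 1 - decay_iou is divided by (1 - compensate_iou[i]), where compensate_iou=[0, 0.9, 0.8] is the column-wise maximum of decay_iou. This gives decay_matrix:
1    0.1  0.8
10   10   2
5    5    5
Taking the column-wise minimum of decay_matrix gives decay_coefficient=[1, 0.1, 0.8]. Box 2's score is now multiplied by 0.8, which comes from its IoU of 0.2 with box 0, so it loses much less. The entry between box 1 and box 2 in decay_matrix has been compensated (amplified) to 2; it carries little weight and is never picked by the column-wise minimum.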
With this example in mind, the formula for decay_matrix in the code should be easier to understand:
decay_matrix[i * n + j] = (1.f - decay_iou[i * n + j]) / (1.f - compensate_iou[i]);
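The IoU between box i and box j is decay_iou[i * n + j], so box i "thinks" the decay coefficient of box j should be (1.f - decay_iou[i * n + j]). But is box i's opinion reliable?
That depends on whether box i is itself suppressed. If box i is not suppressed, (1.f - decay_iou[i * n + j]) is meaningful; if box i is suppressed, it means little.
So we divide by (1.f - compensate_iou[i]) as compensation, where compensate_iou[i] is the highest IoU (max_iou) between box i and any box that scores higher than it.
If this max_iou is large, the decay coefficient is amplified considerably, and box i's opinion about box j carries little weight. If max_iou is small, the coefficient is amplified only slightly (not at all when max_iou == 0), and box i's opinion about box j is meaningful.
Then the column-wise minimum of decay_matrix is taken. The minimum of column j is supplied by the box i for which decay_iou[i * n + j] is as large as possible and compensate_iou[i] is as small as possible.
When kernel == 0, a different (Gaussian) function is used for the decay coefficient and the compensation, but the idea is the same. Every box's score is then multiplied by the corresponding entry of decay_coefficient to lower it, and MatrixNMS is finished.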
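The whole decay computation fits in a short Python function, which makes it easy to replay the three-box example. This is a sketch of the same arithmetic as the C++ code, assuming all boxes belong to one class and are already sorted by descending score. The IoU values are illustrative, chosen so that the linear kernel reproduces the coefficients [1, 0.1, 0.8] quoted above.

```python
import math


def matrix_nms_decay(iou, kernel="linear", sigma=2.0):
    """Return the per-box decay coefficient for boxes sorted by descending score.

    `iou` is an n x n list of pairwise IoUs; only entries above the diagonal are used.
    """
    n = len(iou)
    decay_iou = [[iou[i][j] if j > i else 0.0 for j in range(n)] for i in range(n)]
    # Highest IoU of box i with any higher-scoring box (column-wise maximum).
    compensate = [max(decay_iou[j][i] for j in range(n)) for i in range(n)]
    coefficient = []
    for j in range(n):
        values = []
        for i in range(n):
            if kernel == "gaussian":
                values.append(math.exp(sigma * (compensate[i] ** 2 - decay_iou[i][j] ** 2)))
            else:
                values.append((1.0 - decay_iou[i][j]) / (1.0 - compensate[i]))
        coefficient.append(min(values))  # column-wise minimum
    return coefficient


iou = [[0.0, 0.9, 0.2],
       [0.0, 0.0, 0.8],
       [0.0, 0.0, 0.0]]
print([round(c, 3) for c in matrix_nms_decay(iou)])
print([round(c, 3) for c in matrix_nms_decay(iou, kernel="gaussian")])
```

As in the C++ code, the linear kernel divides by (1 - compensate[i]), so two boxes with an IoU of exactly 1 would need a guard against division by zero.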
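Step 4: the boxes whose score exceeds post_threshold are saved into bboxes_vec_keep. This is the second score filter. If no box scores above post_threshold, a Mat of shape (0, 0) is returned to indicate that there are no objects.
Step 5: the top keep_top_k boxes in bboxes_vec_keep are sorted by descending score, and only those keep_top_k boxes are kept.
Finally, a Mat of shape (n, 6) holding all the final boxes is written, one row of [clsid, score, x0, y0, x1, y1] per box, and post-processing is complete.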
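On the consuming side, the (n, 6) result is just a flat run of floats in the order written by the loop over pred_ptr. A minimal sketch of unpacking it, with parse_detections as a hypothetical helper name:

```python
def parse_detections(flat):
    """Split a flat list of 6 * n floats into per-box dictionaries."""
    if len(flat) % 6 != 0:
        raise ValueError("expected a multiple of 6 values")
    detections = []
    for start in range(0, len(flat), 6):
        clsid, score, x0, y0, x1, y1 = flat[start:start + 6]
        detections.append({"clsid": int(clsid), "score": score,
                           "box": (x0, y0, x1, y1)})
    return detections


flat = [0.0, 0.92, 10.0, 20.0, 110.0, 220.0,
        16.0, 0.41, 5.0, 6.0, 50.0, 60.0]
for det in parse_detections(flat):
    print(det)
```

An empty list stands for the (0, 0) Mat that is returned when no box survives the score filters.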
How to export
wget https://paddledet.bj.bcebos.com/models/ppyolo_r50vd_dcn_2x_coco.pdparams
wget https://paddledet.bj.bcebos.com/models/ppyolo_r18vd_coco.pdparams
wget https://paddledet.bj.bcebos.com/models/ppyolov2_r50vd_dcn_365e_coco.pdparams
wget https://paddledet.bj.bcebos.com/models/ppyolov2_r101vd_dcn_365e_coco.pdparams
python tools/convert_weights.py -f exps/ppyolo/ppyolo_r50vd_2x.py -c ppyolo_r50vd_dcn_2x_coco.pdparams -oc ppyolo_r50vd_2x.pth -nc 80
python tools/convert_weights.py -f exps/ppyolo/ppyolo_r18vd.py -c ppyolo_r18vd_coco.pdparams -oc ppyolo_r18vd.pth -nc 80
python tools/convert_weights.py -f exps/ppyolo/ppyolov2_r50vd_365e.py -c ppyolov2_r50vd_dcn_365e_coco.pdparams -oc ppyolov2_r50vd_365e.pth -nc 80
python tools/convert_weights.py -f exps/ppyolo/ppyolov2_r101vd_365e.py -c ppyolov2_r101vd_dcn_365e_coco.pdparams -oc ppyolov2_r101vd_365e.pth -nc 80
(3) Step 3: in the miemiedetection root directory, run these commands to convert the PyTorch models to ncnn models:
python tools/demo.py ncnn -f exps/ppyolo/ppyolo_r18vd.py -c ppyolo_r18vd.pth --ncnn_output_path ppyolo_r18vd --conf 0.15
python tools/demo.py ncnn -f exps/ppyolo/ppyolo_r50vd_2x.py -c ppyolo_r50vd_2x.pth --ncnn_output_path ppyolo_r50vd_2x --conf 0.15
python tools/demo.py ncnn -f exps/ppyolo/ppyolov2_r50vd_365e.py -c ppyolov2_r50vd_365e.pth --ncnn_output_path ppyolov2_r50vd_365e --conf 0.15
python tools/demo.py ncnn -f exps/ppyolo/ppyolov2_r101vd_365e.py -c ppyolov2_r101vd_365e.pth --ncnn_output_path ppyolov2_r101vd_365e --conf 0.15
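-c is the weight file to load, and --ncnn_output_path is the base file name of the *.param and *.bin files saved for ncnn. --conf 0.15 sets both score_threshold and post_threshold to 0.15 in the PPYOLODecodeMatrixNMS layer. You can also change score_threshold and post_threshold in the exported *.param file: they are the 5=xxx and 7=xxx attributes of the PPYOLODecodeMatrixNMS layer.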
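Editing the thresholds by hand works, but a small script is less error-prone when several models are exported. The sketch below rewrites the 5= and 7= fields on the PPYOLODecodeMatrixNMS line. The sample line is made up for illustration (a real exported layer has more inputs and attributes), set_layer_params is a hypothetical helper, and the script assumes that collapsing runs of spaces in the edited line is harmless to the ncnn loader.

```python
def set_layer_params(param_text, layer_type, updates):
    """Return param_text with key=value fields replaced on lines of the given layer type."""
    out_lines = []
    for line in param_text.splitlines():
        fields = line.split()
        if fields and fields[0] == layer_type:
            for idx, field in enumerate(fields):
                key, sep, _ = field.partition("=")
                if sep and key.lstrip("-").isdigit() and int(key) in updates:
                    fields[idx] = "%s=%s" % (key, updates[int(key)])
            line = " ".join(fields)
        out_lines.append(line)
    return "\n".join(out_lines) + "\n"


# Illustrative file content, not a real export.
sample = (
    "7767517\n"
    "2 2\n"
    "Input images 0 1 images\n"
    "PPYOLODecodeMatrixNMS output 1 1 images output 0=80 5=1.500000e-01 7=1.500000e-01\n"
)
updated = set_layer_params(sample, "PPYOLODecodeMatrixNMS", {5: 0.3, 7: 0.25})
print(updated)
```

Pass the new thresholds as floats such as 0.3 so that the written value keeps a decimal point and stays a floating-point field.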
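Next, download the ncnn_ppyolov2 repository (it bundles glslang and implements PPYOLOv2 inference) and build ncnn by following the official how-to-build document.
After the build finishes, copy the files obtained above (ppyolov2_r50vd_365e.param, ppyolov2_r50vd_365e.bin, ...) into the build/examples/ directory of ncnn_ppyolov2, then run the following commands from the ncnn_ppyolov2 root directory to run PPYOLOv2 prediction: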
cd build/examples
./test2_06_ppyolo_ncnn ../../my_tests/000000013659.jpg ppyolo_r18vd.param ppyolo_r18vd.bin 416
./test2_06_ppyolo_ncnn ../../my_tests/000000013659.jpg ppyolo_r50vd_2x.param ppyolo_r50vd_2x.bin 608
./test2_06_ppyolo_ncnn ../../my_tests/000000013659.jpg ppyolov2_r50vd_365e.param ppyolov2_r50vd_365e.bin 640
./test2_06_ppyolo_ncnn ../../my_tests/000000013659.jpg ppyolov2_r101vd_365e.param ppyolov2_r101vd_365e.bin 640
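The last argument of each command (416, 608, or 640) is the size the image is resized to for inference, i.e. the target_size parameter. A window then pops up showing the prediction results.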