加入极市专业CV交流群,与6000+来自腾讯,华为,百度,北大,清华,中科院等名企名校视觉开发者互动交流!更有机会与李开复老师等大牛群内互动!
同时提供每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流。关注 极市平台 公众号 ,回复 加群,立刻申请入群~
import math
import torch
import torch.nn
as nn
import torch.nn.functional
as F
class OctConv2d(nn.Conv2d):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
alpha_in=0.5,
alpha_out=0.5,):
assert alpha_in >=
0
and alpha_in <=
1
assert alpha_out >=
0
and alpha_out <=
1
super(OctConv2d, self).__init__(in_channels, out_channels,
kernel_size, stride, padding,
dilation, groups, bias)
self.avgpool = nn.AvgPool2d(kernel_size=
2, stride=
2)
self.alpha_in = alpha_in
self.alpha_out = alpha_out
self.inChannelSplitIndex = math.floor(
self.alpha_in * self.in_channels)
self.outChannelSplitIndex = math.floor(
self.alpha_out * self.out_channels)
def forward(self, input):
if
not isinstance(input, tuple):
assert self.alpha_in ==
0
or self.alpha_in ==
1
inputLow = input
if self.alpha_in ==
1
else
None
inputHigh = input
if self.alpha_in ==
0
else
None
else:
inputLow = input[
0]
inputHigh = input[
1]
output = [
0,
0]
# H->H
if self.outChannelSplitIndex != self.out_channels
and self.inChannelSplitIndex != self.in_channels:
outputH2H = F.conv2d(
inputHigh,
self.weight[
self.outChannelSplitIndex:,
self.inChannelSplitIndex:,
:,
:],
self.bias[
self.outChannelSplitIndex:],
self.stride,
self.padding,
self.dilation,
self.groups)
output[
1] += outputH2H
# H->L
if self.outChannelSplitIndex !=
0
and self.inChannelSplitIndex != self.in_channels:
outputH2L = F.conv2d(
self.avgpool(inputHigh),
self.weight[
:self.outChannelSplitIndex,
self.inChannelSplitIndex:,
:,
:],
self.bias[
:self.outChannelSplitIndex],
self.stride,
self.padding,
self.dilation,
self.groups)
output[
0] += outputH2L
# L->L
if self.outChannelSplitIndex !=
0
and self.inChannelSplitIndex !=
0:
outputL2L = F.conv2d(
inputLow,
self.weight[
:self.outChannelSplitIndex,
:self.inChannelSplitIndex,
:,
:],
self.bias[
:self.outChannelSplitIndex],
self.stride,
self.padding,
self.dilation,
self.groups)
output[
0] += outputL2L
# L->H
if self.outChannelSplitIndex != self.out_channels
and self.inChannelSplitIndex !=
0:
outputL2H = F.conv2d(
F.interpolate(inputLow, scale_factor=
2),
self.weight[
self.outChannelSplitIndex:,
:self.inChannelSplitIndex,
:,
:],
self.bias[
self.outChannelSplitIndex:],
self.stride,
self.padding,
self.dilation,
self.groups)
output[
1] += outputL2H
return tuple(output)
-End-
*延伸阅读
CV细分方向交流群
添加极市小助手微信(ID : cv-mart),备注:研究方向-姓名-学校/公司-城市(如:目标检测-小极-北大-深圳),即可申请加入目标检测、目标跟踪、人脸、工业检测、医学影像、三维&SLAM、图像分割等极市技术交流群(已经添加小助手的好友直接私信),更有每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流,一起来让思想之光照的更远吧~
△长按添加极市小助手
△长按关注极市平台
觉得有用麻烦给个在看啦~