加入极市专业CV交流群,与6000+来自腾讯,华为,百度,北大,清华,中科院等名企名校视觉开发者互动交流!更有机会与李开复老师等大牛群内互动!
同时提供每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流。关注 极市平台 公众号 ,回复 加群,立刻申请入群~
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class OctConv2d(nn.Conv2d):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
alpha_in=0.5,
alpha_out=0.5,):
assert alpha_in >= 0 and alpha_in <= 1
assert alpha_out >= 0 and alpha_out <= 1
super(OctConv2d, self).__init__(in_channels, out_channels,
kernel_size, stride, padding,
dilation, groups, bias)
self.avgpool = nn.AvgPool2d(kernel_size= 2, stride= 2)
self.alpha_in = alpha_in
self.alpha_out = alpha_out
self.inChannelSplitIndex = math.floor(
self.alpha_in * self.in_channels)
self.outChannelSplitIndex = math.floor(
self.alpha_out * self.out_channels)
def forward(self, input):
if not isinstance(input, tuple):
assert self.alpha_in == 0 or self.alpha_in == 1
inputLow = input if self.alpha_in == 1 else None
inputHigh = input if self.alpha_in == 0 else None
else:
inputLow = input[ 0]
inputHigh = input[ 1]
output = [ 0, 0]
# H->H
if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != self.in_channels:
outputH2H = F.conv2d(
inputHigh,
self.weight[
self.outChannelSplitIndex:,
self.inChannelSplitIndex:,
:,
:],
self.bias[
self.outChannelSplitIndex:],
self.stride,
self.padding,
self.dilation,
self.groups)
output[ 1] += outputH2H
# H->L
if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != self.in_channels:
outputH2L = F.conv2d(
self.avgpool(inputHigh),
self.weight[
:self.outChannelSplitIndex,
self.inChannelSplitIndex:,
:,
:],
self.bias[
:self.outChannelSplitIndex],
self.stride,
self.padding,
self.dilation,
self.groups)
output[ 0] += outputH2L
# L->L
if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != 0:
outputL2L = F.conv2d(
inputLow,
self.weight[
:self.outChannelSplitIndex,
:self.inChannelSplitIndex,
:,
:],
self.bias[
:self.outChannelSplitIndex],
self.stride,
self.padding,
self.dilation,
self.groups)
output[ 0] += outputL2L
# L->H
if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != 0:
outputL2H = F.conv2d(
F.interpolate(inputLow, scale_factor= 2),
self.weight[
self.outChannelSplitIndex:,
:self.inChannelSplitIndex,
:,
:],
self.bias[
self.outChannelSplitIndex:],
self.stride,
self.padding,
self.dilation,
self.groups)
output[ 1] += outputL2H
return tuple(output)
-End-
*延伸阅读
CV细分方向交流群
添加极市小助手微信(ID : cv-mart),备注:研究方向-姓名-学校/公司-城市(如:目标检测-小极-北大-深圳),即可申请加入目标检测、目标跟踪、人脸、工业检测、医学影像、三维&SLAM、图像分割等极市技术交流群(已经添加小助手的好友直接私信),更有每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流,一起来让思想之光照的更远吧~
△长按添加极市小助手
△长按关注极市平台
觉得有用麻烦给个在看啦~