def warp_conv(x, conv, factor: int = 32):
    """Evaluate ``conv(x)`` through a down-scaled input.

    The input is divided by ``factor`` before the convolution and the result
    is scaled back afterwards, using the identity::

        (W * (x / p) + b) * p - b * (p - 1) == W * x + b

    so ``conv(x) == warp_conv(x, conv)`` up to floating-point rounding.

    Args:
        x: Input tensor accepted by ``conv``; the output is expected to keep
            its channels on dim 1.
        conv: Convolution module exposing a ``bias`` attribute, which may be
            ``None``.
        factor: The divisor ``p`` applied to the input.

    Returns:
        A tensor with the same shape as ``conv(x)``.
    """
    x_tmp = conv(x / factor)
    if conv.bias is None:
        # Without a bias the layer is purely linear, so scaling back suffices.
        return factor * x_tmp
    # Shape (1, C, 1, ...) broadcasts over the batch and every spatial dim,
    # whatever the dimensionality of the convolution.
    bias = conv.bias.reshape(1, -1, *([1] * (x_tmp.dim() - 2)))
    return factor * x_tmp - (factor - 1) * bias
# Normalisation written as a per-channel affine transform; presumably the
# inference-time BatchNorm formula, since warp_bn below applies the same
# expressions to a bn layer's attributes.
# input: x, output: out
# `w` is the per-channel scale and `(bias - running_mean * w)` the shift.
w = weight / torch.sqrt(running_var + eps)
out = x * w + (bias - running_mean * w)
def warp_bn(x, bn, factor: int = 32):
    """Evaluate ``bn(x)`` through a down-scaled input.

    The normalisation is treated as the affine map ``x * scale + bias`` built
    from the layer's running statistics, so that::

        bn(x / p) * p - (p - 1) * bias == bn(x)

    The identity relies on ``bn`` normalising with its running statistics,
    i.e. the layer being in eval mode.

    Args:
        x: Input tensor accepted by ``bn``, with channels on dim 1.
        bn: Normalisation module exposing ``running_mean``, ``running_var``,
            ``eps`` and optionally ``weight`` / ``bias`` (``None`` when the
            layer has no affine parameters).
        factor: The divisor ``p`` applied to the input.

    Returns:
        A tensor with the same shape as ``bn(x)``.
    """
    import torch

    std = torch.sqrt(bn.running_var + bn.eps)
    # A layer without affine parameters behaves like weight == 1, bias == 0.
    scale = 1.0 / std if bn.weight is None else bn.weight / std
    shift = bn.running_mean * scale
    bias = -shift if bn.bias is None else bn.bias - shift
    # Shape (1, C, 1, ...) broadcasts over the batch and any trailing dims.
    bias_t = bias.reshape(1, -1, *([1] * (x.dim() - 2)))
    return bn(x / factor) * factor - (factor - 1) * bias_t
# Sanity check: the summed difference between each layer and its warped
# counterpart should be (approximately) zero. The subtraction is
# parenthesised so that `.sum()` reduces the difference rather than only
# the warped output.
print((conv(x) - warp_conv(x, conv)).sum())
print((bn(x) - warp_bn(x, bn)).sum())
整个网络计算过程都有数值溢出
每次提前返回结果,二分地导出 ONNX 再导出 TensorRT 模型,未被导出的部分继续以 PyTorch 代码衔接到 TensorRT 的计算结果后。
from pyclbr import Function
from typing import Callable, Iterator, Sequence

import torch
def _iter_tensors(obj):
    """Yield every ``torch.Tensor`` leaf nested inside dicts and sequences."""
    if isinstance(obj, torch.Tensor):
        yield obj
    elif isinstance(obj, dict):
        for value in obj.values():
            yield from _iter_tensors(value)
    elif isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)):
        # Strings are sequences of strings; descending into them never ends.
        for value in obj:
            yield from _iter_tensors(value)
    # Any other leaf (None, numbers, strings, ...) is ignored.


def fp16_check(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor) -> None:
    """Forward hook reporting where values cross the FP16 range.

    Every tensor found in ``input`` is paired with every tensor found in
    ``output`` (both may be arbitrarily nested dicts / sequences). For each
    pair it prints:

    * ``from:  <name>`` when the input stays below 65504 but the output
      exceeds it;
    * ``to:  <name>`` when the input exceeds 65504 but the output is back
      below it.

    ``<name>`` is ``module.finspect_name`` when present, otherwise the
    module's class name. Non-tensor leaves and empty tensors are skipped.

    Args:
        module: The module whose forward pass just ran.
        input: Positional inputs handed to the hook.
        output: Value returned by the module's forward.
    """
    fp16_max = 65504  # largest finite float16 value

    def abs_maxes(obj):
        # One reduction per tensor, reused for every pair below.
        return [
            tensor.detach().float().abs().max().item()
            for tensor in _iter_tensors(obj)
            if tensor.numel() > 0
        ]

    in_maxes = abs_maxes(input)
    out_maxes = abs_maxes(output)
    name = getattr(module, 'finspect_name', type(module).__name__)
    for in_max in in_maxes:
        for out_max in out_maxes:
            if in_max < fp16_max and out_max > fp16_max:
                print('from: ', name)
            if in_max > fp16_max and out_max < fp16_max:
                print('to: ', name)
from contextlib import contextmanager
class FInspect:
    """Attach one forward hook to a module and all of its descendants.

    Every hooked module receives a ``finspect_name`` attribute holding its
    path from the root, joined with ``'->'`` (the root itself is ``'model'``).
    """

    # Stack of names from the root down to the module currently visited.
    module_names = ['model']
    # Handles returned by register_forward_hook, removed when hook_all exits.
    handlers = []

    def hook_all_impl(cls, module: torch.nn.Module, hook_func: Callable) -> None:
        """Recursively name and hook ``module`` and its children.

        This is a plain function that takes the class explicitly, i.e. it is
        called as ``FInspect.hook_all_impl(FInspect, module, hook_func)``.

        Args:
            cls: The class holding ``module_names`` and ``handlers``.
            module: Root of the sub-tree to hook.
            hook_func: Forward hook ``(module, input, output)`` to register.
        """
        for name, child in module.named_children():
            cls.module_names.append(name)
            try:
                cls.hook_all_impl(cls, module=child, hook_func=hook_func)
            finally:
                # Pop only what this level pushed, so the root entry stays in
                # place and the class can be used more than once.
                cls.module_names.pop()
        module.finspect_name = '->'.join(cls.module_names)
        handler = module.register_forward_hook(hook=hook_func)
        cls.handlers.append(handler)

    @classmethod
    @contextmanager
    def hook_all(cls, module: torch.nn.Module, hook_func: Callable) -> Iterator[None]:
        """Context manager hooking ``module`` for the duration of the block.

        All registered hooks are removed on exit, including when hooking or
        the body of the ``with`` statement raises.

        Args:
            module: Root module to inspect.
            hook_func: Forward hook ``(module, input, output)`` to register.
        """
        try:
            cls.hook_all_impl(cls, module, hook_func)
            yield
        finally:
            for handler in cls.handlers:
                handler.remove()
            cls.handlers.clear()
# Run a single forward pass with every sub-module hooked by fp16_check; the
# hooks are registered on entry and removed when the block exits.
with FInspect.hook_all(patched_model, fp16_check):
    patched_model(inputs)
mmocr.models.textdet.necks.FPEM_FFM.forward
mmdet.models.backbones.resnet.BasicBlock.forward
import torch
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils.constants import Backend
# Divisor applied to activations by the rewriters below, and the multiplier
# used to scale them back.
FACTOR = 32
# Set to True by basic_block__forward__trt when it returns a down-scaled
# output; read by fpem_ffm__forward__trt to decide whether to compensate.
ENABLE = False
# BasicBlocks whose conv1.in_channels is below this value keep their
# original forward.
CHANNEL_THRESH = 400
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmocr.models.textdet.necks.FPEM_FFM.forward',
    backend=Backend.TENSORRT.value)
def fpem_ffm__forward__trt(ctx, self, x, *args, **kwargs):
    """Rewrite of ``FPEM_FFM.forward`` registered for the TensorRT backend.

    When the module-level ``ENABLE`` flag is set, the conv + norm part of
    ``reduce_conv_c5`` runs on ``c5`` and its result is multiplied by
    ``FACTOR`` with ``(FACTOR - 1)`` times the folded conv/norm bias
    subtracted, before the final layer of ``reduce_conv_c5`` is applied.
    This matches a ``c5`` that was divided by ``FACTOR`` upstream, as
    ``basic_block__forward__trt`` does when it sets ``ENABLE``.

    Args:
        ctx: Rewriter context (unused here).
        self: The ``FPEM_FFM`` instance.
        x: Iterable of four feature maps ``(c2, c3, c4, c5)``.

    Returns:
        Tuple ``(c2, c3, c4, c5)`` of fused maps; the last three are
        bilinearly resized to the spatial size of the first.
    """
    c2, c3, c4, c5 = x

    # Channel reduction.
    c2 = self.reduce_conv_c2(c2)
    c3 = self.reduce_conv_c3(c3)
    c4 = self.reduce_conv_c4(c4)
    if not ENABLE:
        c5 = self.reduce_conv_c5(c5)
    else:
        conv, norm = self.reduce_conv_c5[0], self.reduce_conv_c5[1]

        def tile(vec):
            # Expand a per-channel vector to (1, C, H, W) using c5's H and W.
            return vec.reshape(1, -1, 1, 1).repeat(1, 1, c5.size(2),
                                                   c5.size(3))

        # Norm layer expressed as a per-channel scale and shift.
        norm_scale = norm.weight / torch.sqrt(norm.running_var + norm.eps)
        norm_shift = norm.bias - norm.running_mean * norm_scale
        norm_scale = tile(norm_scale)
        norm_shift = tile(norm_shift)
        conv_bias = tile(conv.bias)

        # Everything except the last layer, then undo the scaling.
        pre_act = self.reduce_conv_c5[:-1](c5)
        folded_bias = norm_scale * conv_bias + norm_shift
        c5 = FACTOR * pre_act - (FACTOR - 1) * folded_bias
        c5 = self.reduce_conv_c5[-1](c5)

    # FPEM: the first stage's outputs start the running sums, later stages
    # are accumulated in place.
    for idx, fpem in enumerate(self.fpems):
        c2, c3, c4, c5 = fpem(c2, c3, c4, c5)
        if idx:
            c2_ffm += c2
            c3_ffm += c3
            c4_ffm += c4
            c5_ffm += c5
        else:
            c2_ffm, c3_ffm, c4_ffm, c5_ffm = c2, c3, c4, c5

    # FFM: bring every level to the spatial size of c2.
    def upsample(feat):
        return F.interpolate(
            feat,
            c2_ffm.size()[-2:],
            mode='bilinear',
            align_corners=self.align_corners)

    c5 = upsample(c5_ffm)
    c4 = upsample(c4_ffm)
    c3 = upsample(c3_ffm)
    return (c2_ffm, c3, c4, c5)
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.backbones.resnet.BasicBlock.forward',
    backend=Backend.TENSORRT.value)
def basic_block__forward__trt(ctx, self, x):
    """Rewrite of ``BasicBlock.forward`` registered for the TensorRT backend.

    Blocks whose ``conv1.in_channels`` is below ``CHANNEL_THRESH`` use the
    original forward. For the others, the output of ``norm2`` is inspected:

    * if its absolute maximum is below 65504, the usual residual sum and
      ReLU are returned;
    * otherwise ``norm2`` is applied as an explicit scale/shift on
      ``out / FACTOR``, the block returns its regular output divided by
      ``FACTOR``, and the module-level ``ENABLE`` flag is set so that
      ``fpem_ffm__forward__trt`` compensates for the scaling.

    NOTE(review): only ``fpem_ffm__forward__trt`` is seen to undo the
    scaling, so this presumably assumes the overflowing block is the last
    one before the neck -- confirm for other backbones.

    Args:
        ctx: Rewriter context; ``ctx.origin_func`` is the original forward.
        self: The ``BasicBlock`` instance.
        x: Input feature map.

    Returns:
        The block output, divided by ``FACTOR`` in the overflow branch.
    """
    global ENABLE

    if self.conv1.in_channels < CHANNEL_THRESH:
        return ctx.origin_func(self, x)

    identity = x
    # Honour the shortcut projection when the block has one, as the original
    # forward does; otherwise the residual would use the raw input.
    downsample = getattr(self, 'downsample', None)
    if downsample is not None:
        identity = downsample(x)

    out = self.conv1(x)
    out = self.norm1(out)
    out = self.relu(out)
    out = self.conv2(out)

    # Evaluate norm2 once and reuse it for both the range test and the result.
    normed = self.norm2(out)
    if torch.abs(normed).max() < 65504:
        normed += identity
        return self.relu(normed)

    ENABLE = True
    # the output of the last bn layer exceeds the range of fp16
    w1 = self.norm2.weight / torch.sqrt(self.norm2.running_var +
                                        self.norm2.eps)
    bias = self.norm2.bias - self.norm2.running_mean * w1
    w1 = w1.reshape(1, -1, 1, 1).repeat(1, 1, out.size(2), out.size(3))
    bias = bias.reshape(1, -1, 1, 1).repeat(1, 1, out.size(2),
                                            out.size(3)) + identity
    # Scale before multiplying so the intermediate values stay small.
    out = self.relu(w1 * (out / FACTOR) + bias / FACTOR)
    return out
一个替换多个算子,从原始模型解决 FP16 数值溢出的方法。
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧