将TVM集成到PyTorch中。
* TVM:深度学习编译器
Github项目链接:
https://github.com/pytorch/tvm
你需要在这个PR的基础上构建PyTorch:https://github.com/pytorch/pytorch/pull/18588
cd pytorch
git fetch origin pull/18588/head:tvm_dev
git checkout tvm_dev
python setup.py install
然后,你需要单独构建本仓库:
# Make sure the right llvm-config is in your PATH
python setup.py install
python setup.py test
这个包非常显然地挂钩到 PyTorch 的 JIT 中,因此适用相同的工具(可查看 @torch.jit.script,torch.jit.trace 和 graph_for)。以下是使用示例:
from tvm import relay # This imports all the topi operators
import torch_tvm
torch_tvm.enable()
# The following function will be compiled with TVM
def my_func(a, b, c):
return a * b + c
如果要禁用JIT挂钩,请使用 torch_tvm.disable() 。
register.cpp:设置pybind绑定并调用TVM后端的注册。
compiler.{h,cpp}:用TVM编译PyTorch JIT图的主要逻辑。
operators.{h,cpp}:从JIT IR映射到TVM操作符的位置。
TODO
添加从Python中将不透明op名称的翻译注册到TVM中(如在operator.cpp中完成)的功能。
零拷贝 set_input
纾困机制(调用PyTorch JIT后备)
Threadpool 集成
分配器集成
操作符翻译
加
乘
卷积
BatchNorm
RELU
AveragePool
MaxPool
线性
张量操作
重塑
查看
备受大家期待的强化学习课程终于上线啦!
扫描下方邀请卡,解锁更多课时