Pytorch – 使用torch.onnx.export将Pytorch模型导出为ONNX模型
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – 使用torch.onnx.export将Pytorch模型导出为ONNX模型
原文链接:https://www.stubbornhuang.com/1694/
发布于:2021年09月16日 15:02:52
修改于:2021年09月16日 15:28:49
1 torch.onnx.export
torch.onnx.export(
model,
args,
f,
export_params=True,
verbose=False,
training=<TrainingMode.EVAL: 0>,
input_names=None,
output_names=None,
aten=False,
export_raw_ir=False,
operator_export_type=None,
opset_version=None,
_retain_param_name=True,
do_constant_folding=True,
example_outputs=None,
strip_doc_string=True,
dynamic_axes=None,
keep_initializers_as_inputs=None,
custom_opsets=None,
enable_onnx_checker=True,
use_external_data_format=False)
1.1 函数作用
pytorch模型导出为ONNX格式,这个导出器通过运行一次Pytorch模型获得模型的执行轨迹。目前,它支持一组有限的动态模型(例如,RNN)。
1.2 函数参数
- model - torch.nn.Module,要导出的模型;
- args - 参数元祖,模型的输入,例如model(*args),支持单个参数或者多个参数。任何非变量参数将被硬编码到导出的模型中,按照它们在参数中出现的顺序,任何变量参数都将成为输出模型的输入;
- f - 输出文件对象或者包含文件名的字符串;
- export_params - 布尔值,默认为True,如果指定为True,则所有参数将被导出。如果要导出未经训练的模型,请将该参数设置为False;
- verbose - 布尔值,默认为False,如果指定为True,将打印一个调试描述的导出轨迹;
- training - 布尔值,默认为False,如果指定为True,将在训练模式下导出模型。目前,ONNX只是为了导出模型,所以你通常不需要将其设置为True;
- input_names - 字符串列表,默认为空列表。按顺序分配给图形输入节点的名称;
- output_names - 字符串列表,默认为空列表。按顺序分配给图形输出节点的名称;
- aten - bool ,默认为 False。 [已弃用。使用 operator_export_type] 以 aten 模式导出模型。如果使用aten模式,所有由symbolic_opset
.py中的函数导出的原始ops都被导出为ATen ops。 - export_raw_ir - bool,默认为 False。[已弃用。使用 operator_export_type] 直接导出内部 IR,而不是将其转换为 ONNX 操作。
- operator_export_type - enum ,默认 OperatorExportTypes.ONNX
- OperatorExportTypes.ONNX:所有操作都导出为常规 ONNX 操作(带有 ONNX 命名空间)
- OperatorExportTypes.ONNX_ATEN:所有操作都导出为 ATen 操作(带有 aten 命名空间)
- OperatorExportTypes.ONNX_ATEN_FALLBACK:如果 ONNX 不支持 ATen 操作或其符号丢失,请回退到 ATen 操作。注册的操作会定期导出到 ONNX
- opset_version - int , default is 9。默认情况下,我们将模型导出到 onnx 子模块的 opset 版本。由于 ONNX 的最新 opset 可能会在下一个稳定版本之前发展,因此默认情况下我们导出到一个稳定的 opset 版本。现在,支持的稳定 opset 版本是 9。
- do_constant_folding - bool , default False。如果为 True,则在导出期间将常量折叠优化应用于模型。常量折叠优化将用预先计算的常量节点替换一些具有所有常量输入的操作
- example_outputs - tuple of Tensors , list of Tensor , Tensor , int , float , bool , default None。正在导出的模型示例输出。导出 ScriptModule 或 TorchScript 函数时必须提供“example_outputs”。如果有多个项目,则应以元组格式传递,例如:example_outputs = (x, y, z)。否则,只应传递一项作为示例输出,例如 example_outputs=x。导出 ScriptModule 或 TorchScript 函数时必须提供 example_outputs。
- strip_doc_string - bool , default True。如果为 True,则从导出的模型中删除字段“doc_string”,其中包含有关堆栈跟踪的信息;
- dynamic_axes - ( dict
>或dict ,默认为空 dict )。用于指定输入/输出的动态轴的字典,例如: - KEY:输入和/或输出名称 - VALUE:给定键的动态轴索引以及可能用于导出动态轴的名称。通常,根据以下方式之一或两者的组合来定义该值: (1)。指定所提供输入的动态轴的整数列表。在这种情况下,将生成自动名称并将其应用于导出期间提供的输入/输出的动态轴。或 (2)。一个内部字典,指定从相应输入/输出中的动态轴索引到导出期间希望应用于此类输入/输出的此类轴的名称的映射; - keep_initializers_as_inputs - bool,默认 None。如果为 True,导出图中的所有初始值设定项(通常对应于参数)也将作为输入添加到图中。如果为 False,则初始值设定项不会作为输入添加到图形中,而仅将非参数输入添加为输入;
- custom_opsets - dict
, default empty dict。用于在导出时指示自定义 opset 域和版本的字典。如果模型包含自定义操作集,则可以选择在字典中指定域和操作集版本: - KEY:操作集域名 - VALUE:操作集版本如果此字典中未提供自定义操作集,则操作集版本设置为 1默认; - enable_onnx_checker - bool , default True。如果为 True,onnx 模型检查器将作为导出的一部分运行,以确保导出的模型是有效的 ONNX 模型;
- use_external_data_format - bool , default False。如果为 True,则模型以 ONNX 外部数据格式导出,在这种情况下,某些模型参数存储在外部二进制文件中,而不是存储在 ONNX 模型文件本身中;
2 使用示例
2.1 导出示例1
一个简单的示例,将torchvision中定义的预训练的AlexNet导出到ONNX模型中,并将模型保存到alexnet.onnx
import torch
import torchvision
dummy_input = torch.randn(10, 3, 224, 224, device='cuda')
model = torchvision.models.alexnet(pretrained=True).cuda()
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]
torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)
2.2 导出示例2
一个简单的示例,将torchvision中densenet169中定义的预训练模型DenseNet169导出到ONNX模型,并将模型保存到./onnx/dense.onnx中
import torch
from torchvision.models import densenet169
dummy_input = torch.randn(10, 3, 224, 224, device='cuda')
model = densenet169(pretrained=True).cuda()
input_names = ["features"]
output_names = ["classifier"]
output = torch.onnx.export(model, dummy_input, "./onnx/dense.onnx", verbose=True, input_names=input_names, output_names=output_names)
当前分类随机文章推荐
- Pytorch - RuntimeError: No rendezvous handler for env://错误 阅读1057次,点赞0次
- Pytorch - .to()和.cuda()的区别 阅读848次,点赞0次
- Pytorch - 为什么要设置随机数种子? 阅读560次,点赞0次
- Pytorch - 用Pytorch实现ResNet 阅读506次,点赞0次
- Pytorch - 多GPU训练方式nn.DataParallel与nn.parallel.DistributedDataParallel的区别 阅读923次,点赞0次
- Pytorch - torch.distributed.init_process_group函数详解 阅读623次,点赞0次
- Pytorch – 使用torch.matmul()替换torch.einsum('bhxyd,md->bhxym',(a,b))算子模式 阅读1115次,点赞0次
- Pytorch - 梯度累积/梯度累加trick,在显存有限的情况下使用更大batch_size训练模型 阅读402次,点赞0次
- Pytorch - 内置的CTC损失函数torch.nn.CTCLoss参数详解与使用示例 阅读1311次,点赞1次
- Pytorch - 内置的LSTM网络torch.nn.LSTM参数详解与使用示例 阅读2084次,点赞0次
全站随机文章推荐
- C++ - 最简单的将文本文件的内容一次性读取到std::string的方法 阅读4939次,点赞4次
- WordPress - 升级WordPress5.8后切换回旧版的小工具管理页面 阅读1620次,点赞0次
- 深度学习 - 语音识别框架wenet源码wenet/utils/mask.py中的mask机制 阅读881次,点赞1次
- 资源分享 - Mathematics for 3D Game Programming and Computer Graphics, Third Edition英文高清PDF下载 阅读2984次,点赞0次
- FFmpeg - 将某个文件夹下的图片按标号顺序合成为指定编码格式和指定帧率的视频 阅读4443次,点赞0次
- 风铃发卡系统配置Payjs支付 阅读181次,点赞0次
- C++ - 求解std::vector
中topk数值以及topk数值对应的索引 阅读2480次,点赞0次 - 资源分享 - C++程序设计语言(第1- 3部分),原书第4版 高清PDF下载 阅读2934次,点赞2次
- 深度学习 - CTC解码算法详解 阅读887次,点赞0次
- 资源分享 - Computational Geometry - An Introduction Through Randomized Algorithms 英文高清PDF下载 阅读1525次,点赞0次
评论
169