Python – 使用onnxruntime加载和推理onnx模型
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Python – 使用onnxruntime加载和推理onnx模型
原文链接:https://www.stubbornhuang.com/2427/
发布于:2022年11月30日 13:59:20
修改于:2022年12月01日 13:19:02

1 onnxruntime
Onnx runtime是一个跨平台的机器学习模型加速器,可以在不同的硬件和操作系统上运行,可以加载和推理任意机器学习框架导出的onnx模型并进行加速。
如要使用onnxruntime,一般通过以下步骤:
- 从机器学习框架中将模型导出为onnx
- 使用onnxruntime加载onnx模型并进行推理
onnxruntime官网:https://onnxruntime.ai/
Github地址:https://github.com/microsoft/onnxruntime
1.1 onnxruntime安装
onnxruntime在python上有两个版本:cpu和gpu版本,在一个python环境中只能安装一个版本,gpu版本包含了大部分cpu版本的内容,所以在有gpu的情况下,尽量安装gpu版本。
cpu版本
pip install onnxruntime
gpu版本
pip install onnxruntime-gpu
1.2 从pytorch导出onnx模型
使用pytorch
的torch.onnx.export
函数导出onnx模型,这里以pytorch的resnet18预训练模型为例
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torchvision.models
import onnx
import onnxruntime
if __name__ == '__main__':
resnet18 = torchvision.models.resnet18(pretrained=True)
input_image = torch.rand([1, 3, 224, 224],dtype=torch.float)
onnx_outpath = 'resnet18.onnx'
torch.onnx.export(resnet18,
input_image,
onnx_outpath,
opset_version=13,
verbose=True,
do_constant_folding=True,
input_names=['input'],
output_names=['output']
)
# 检查导出的onnx模型
onnx_model = onnx.load(onnx_outpath)
onnx.checker.check_model(onnx_model, full_check=True)
inferred = onnx.shape_inference.infer_shapes(onnx_model, check_type=True)
1.3 使用onnxruntime对onnx模型进行推理
从pytorch导出onnx模型之后,就可以使用onnxruntime加载模型并进行推理,还是以resnet18为例,示例代码如下
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torchvision.models
import onnx
import onnxruntime
if __name__ == '__main__':
resnet18 = torchvision.models.resnet18(pretrained=True)
input_image = torch.rand([1, 3, 224, 224],dtype=torch.float)
onnx_outpath = 'resnet18.onnx'
torch.onnx.export(resnet18,
input_image,
onnx_outpath,
opset_version=13,
verbose=True,
do_constant_folding=True,
input_names=['input'],
output_names=['output']
)
# 检查导出的onnx模型
onnx_model = onnx.load(onnx_outpath)
onnx.checker.check_model(onnx_model, full_check=True)
inferred = onnx.shape_inference.infer_shapes(onnx_model, check_type=True)
# 使用onnxruntime对onnx模型进行推理
providers = ["CUDAExecutionProvider"]
ort_session = onnxruntime.InferenceSession(onnx_outpath, providers=providers)
output = ort_session.run(None,{'input': input_image.numpy()})
result = output[0]
print(result)
1.4 onnxruntime API
onnxruntime的python版本的API文档地址为:https://onnxruntime.ai/docs/api/python/api_summary.html,有兴趣的可以仔细看看,下面简要介绍onnxruntime中经常使用的api。
1.4.1 onnxruntime.InferenceSession
1.4.1.1 onnxruntime.InferenceSession类
类原型
class onnxruntime.InferenceSession(path_or_bytes, sess_options=None, providers=None, provider_options=None, **kwargs)
类初始化参数
- path_or_bytes:onnx或者ort模型的文件名或者序列化模型
- sess_options:会话选项
- providers:按优先级递减顺序排列的可选程序序列,如
CUDAExecutionProvider
,CPUExecutionProvider
,具体的可参考Execution Providers - provider_options:与providers列出的程序相对应的可选字典序列
1.4.1.2 onnxruntime.InferenceSession.run成员函数
函数原型
run(output_names, input_feed, run_options=None)
函数参数
- output_names:输出的名称
- input_feed:格式为
{input_name:input_value}
的字典 - run_options:参考onnxruntime.RunOptions
函数返回值
返回一个列表,列表中的每一个输出结果为numpy数组、sparse tensor(稀疏张量)、列表或者字典。
当前分类随机文章推荐
- TensorRT - 使用C++ SDK出现无法解析的外部符号 "class sample::Logger sample::gLogger"错误 阅读106次,点赞0次
- TensorRT - 使用trtexec工具转换模型、运行模型、测试网络性能 阅读3124次,点赞1次
- TensorRT - 扩展TensorRT C++API的模型输入维度,增加Dims5,Dims6,Dims7,Dims8 阅读1515次,点赞0次
- TensorRT - 使用Polygraphy工具比较onnx模型和TensorRT模型的推理结果是否一致 阅读121次,点赞0次
- TensortRT - 转换模型出现Could not locate zlibwapi.dll. Please make sure it is in your library path!错误 阅读309次,点赞0次
- TensorRT - Windows下TensorRT下载与配置 阅读1453次,点赞0次
- TensorRT - 解决INVALID_ARGUMENT: getPluginCreator could not find plugin ScatterND version 1,TensorRT找不到ScatterND插件的问题 阅读2824次,点赞0次
- TensorRT - workspace的作用 阅读157次,点赞0次
- TensorRT - 转换onnx模型出现Slice_74 requires bool or uint8 I/O but node can not be handled by Myelin错误 阅读172次,点赞0次
- TensorRT - Polygraphy工具的使用 阅读3827次,点赞0次
全站随机文章推荐
- 深度学习 - 图解Transformer,小白也能看懂的Transformer处理过程 阅读576次,点赞0次
- 默认的左手坐标系与右手坐标系的比较 阅读3736次,点赞2次
- 书籍翻译 – Fundamentals of Computer Graphics, Fourth Edition,第12章 Data Structures for Graphics中文翻译 阅读996次,点赞3次
- Python - 配置Yolov5出现ImportError: cannot import name 'PILLOW_VERSION' from 'PIL'错误 阅读1003次,点赞0次
- C++11 - std::chrono - 使用std::chrono::duration_cast进行时间转换,hours/minutes/seconds/milliseconds/microseconds相互转换,以及自定义duration进行转换 阅读1944次,点赞0次
- 深度学习 - 语音识别框架wenet中的CTC Prefix Beam Search算法的实现 阅读74次,点赞0次
- C++ - 在CTC解码算法后移除相邻重复和blank索引 阅读117次,点赞0次
- C++11 - 使用std::thread::join()/std::thread::detach()方法需要注意的点 阅读2469次,点赞0次
- 资源分享 - Advanced High Dynamic Range Imaging, First Edition 英文高清PDF下载 阅读1224次,点赞0次
- opencv-python - 读取视频,不改变视频分辨率修改视频帧率 阅读4509次,点赞2次
评论
164