
利用ONNX Runtime高效运行PyTorch模型
本文将指导您如何使用ONNX Runtime运行经torch.onnx.export导出的PyTorch模型,并重点解决PyTorch张量与ONNX Runtime所需NumPy数组类型不兼容的问题。
首先,我们来看一个PyTorch模型导出示例:
import torchclass SumModule(torch.nn.Module): def forward(self, x): return torch.sum(x, dim=1)torch.onnx.export( SumModule(), (torch.ones(2, 2),), "onnx.pb", input_names=["x"], output_names=["sum"])
这段代码定义了一个简单的PyTorch模型SumModule,并将其导出为名为onnx.pb的ONNX模型文件。
直接使用PyTorch张量作为ONNX Runtime的输入会导致错误,因为ONNX Runtime期望的是NumPy数组。 错误信息通常提示输入类型错误。
为了解决这个问题,我们需要将PyTorch张量转换为NumPy数组。 正确的代码如下:
import onnxruntimeimport numpy as npimport torchort_session = onnxruntime.InferenceSession("onnx.pb")# 关键修改:将torch.Tensor转换为np.ndarrayx = np.ones((2, 2), dtype=np.float32)inputs = {ort_session.get_inputs()[0].name: x}print(ort_session.run(None, inputs))
这段代码加载onnx.pb文件,创建一个形状为(2, 2),数据类型为float32的NumPy数组作为模型输入。 ort_session.get_inputs()[0].name 获取输入张量的名称,确保输入数据与模型定义匹配。 ort_session.run 函数运行模型并打印输出结果。
更简洁的等效代码:
import onnxruntime as ortimport numpy as npsess = ort.InferenceSession("onnx.pb")input_data = np.ones((2, 2)).astype(np.float32)output_data = sess.run(None, {"x": input_data})[0]print(output_data)
这段代码功能相同,但更简洁易读。 关键在于使用NumPy数组作为输入。
通过以上方法,您可以成功加载并运行使用torch.onnx.export导出的PyTorch模型。 请确保输入数据的类型和形状与模型的预期输入相匹配。
以上就是如何用ONNX Runtime运行PyTorch导出的模型并解决类型不兼容问题?的详细内容,更多请关注创想鸟其它相关文章!
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/1359437.html
微信扫一扫
支付宝扫一扫