I am trying to export a very complex PyTorch model to ONNX using torch.onnx.export. Behind the scenes, the torch.jit.trace mechanism is used (https://pytorch.org/docs/stable/generated/torch.jit.trace.html). Technically, I am calling the export function like this:
torch.onnx.export(model, (data,), "model.onnx", verbose=True, export_params=True, opset_version=11)
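For reference, the tracing step on its own can be exercised in isolation like this (a minimal sketch; model and data are the same objects passed to the export call above):

import torch

# Run the same tracing mechanism that torch.onnx.export uses internally,
# but without the subsequent ONNX graph optimization passes.
with torch.no_grad():
    traced = torch.jit.trace(model, (data,))

# The raw trace graph can be printed or dumped for inspection.
print(traced.graph)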
The tracing part of this function seems to work fine: I removed all warnings during this step and the trace graph could be created. However, during the optimization of the graph I receive the following exception:
File "/home/dnndev02/workspace_uidj9636/common-mmlab/yaaf/algorithm/network.py", line 206, in inference
torch.onnx.export(model, # or scripted_model
File "/home/dnndev02/miniconda3/envs/mm_deploy/lib/python3.8/site-packages/torch/onnx/utils.py", line 504, in export
_export(
File "/home/dnndev02/miniconda3/envs/mm_deploy/lib/python3.8/site-packages/torch/onnx/utils.py", line 1529, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/home/dnndev02/miniconda3/envs/mm_deploy/lib/python3.8/site-packages/torch/onnx/utils.py", line 1115, in _model_to_graph
graph = _optimize_graph(
File "/home/dnndev02/miniconda3/envs/mm_deploy/lib/python3.8/site-packages/torch/onnx/utils.py", line 605, in _optimize_graph
_C._jit_pass_peephole(graph, True)
IndexError: Argument passed to at() was not in the map.
python-BaseException
Unfortunately, this exception does not point me to the concrete location in the graph that caused the problem. I do have access to the graph as a string representation, but it consists of roughly 240k lines, which makes it almost impossible to identify the root cause by hand.
Example:
%207405 : Float(4, 1, 6, 40000, strides=[480000, 480000, 80000, 2], requires_grad=0, device=cuda:1) = aten::copy_(%207403, %207400, %188882), scope: plugins.mmdet3d_plugins.model.detectors.model_e2e.model::/plugins.mmdet3d_plugins.model.modules.encoder::encoder
%207410 : Float(4, 1, 6, 40000, 1, strides=[480000, 480000, 80000, 2, 1], requires_grad=0, device=cuda:1) = aten::slice(%207387, %189242, %188872, %189234, %188872), scope: plugins.mmdet3d_plugins.model.detectors.model_e2e.model::/plugins.mmdet3d_plugins.model.modules.encoder::encoder
%207411 : float = prim::Constant[value=0.](), scope: plugins.mmdet3d_plugins.model.detectors.model_e2e.model::/plugins.mmdet3d_plugins.uniad.modules.encoder::encoder
%207412 : Bool(4, 1, 6, 40000, 1, strides=[240000, 240000, 40000, 1, 1], requires_grad=0, device=cuda:1) = aten::gt(%207410, %207411), scope: plugins.mmdet3d_plugins.model.detectors.model_e2e.model::/plugins.mmdet3d_plugins.model.modules.encoder::encoder
%207413 : Bool(4, 1, 6, 40000, 1, strides=[240000, 240000, 40000, 1, 1], requires_grad=0, device=cuda:1) = aten::__and__(%207373, %207412), scope: plugins.mmdet3d_plugins.model.detectors.model_e2e.model::/plugins.mmdet3d_plugins.model.modules.encoder::encoder
%207418 : Float(4, 1, 6, 40000, 1, strides=[480000, 480000, 80000, 2, 1], requires_grad=0, device=cuda:1) = aten::slice(%207387, %189242, %188872, %189234, %188872), scope: plugins.mmdet3d_plugins.model.detectors.model_e2e.model::/plugins.mmdet3d_plugins.model.modules.encoder::encoder
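Since grepping 240k lines by hand is not feasible, the traced graph can also be walked programmatically to at least group nodes by op kind and scope (rough sketch using the private torch._C graph API, so the exact accessors may differ between versions; traced is the module returned by torch.jit.trace above):

from collections import Counter

graph = traced.inlined_graph  # inlined view, so calls into submodules are expanded

# Count node kinds to see which ops dominate the graph.
kinds = Counter(node.kind() for node in graph.nodes())
print(kinds.most_common(20))

# Dump scope and source information for a suspicious op kind,
# e.g. the in-place copy seen in the excerpt above.
for node in graph.nodes():
    if node.kind() == "aten::copy_":
        print(node.scopeName())
        print(node.sourceRange())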
Do you have an idea how to identify the root cause of the problem, or at least how to narrow down the location in the graph that might be causing it? Alternatively, do you have an idea how to resolve this concrete error ("Argument passed to at() was not in the map")?
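If it helps with debugging, the failing pass from the traceback can presumably also be invoked directly on a traced graph, roughly like this (untested sketch; torch.onnx.export runs several other passes before the peephole pass, so this isolated call may not reproduce the error exactly):

import torch

traced = torch.jit.trace(model, (data,))
graph = traced.inlined_graph

try:
    # Same internal call that fails inside _optimize_graph according to the traceback.
    torch._C._jit_pass_peephole(graph, True)
except IndexError as e:
    print("peephole pass failed on the isolated graph:", e)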
I am using PyTorch 1.13.1. (The problem seems to persist even in 2.x versions, and I do not have the possibility to upgrade.)
Thanks very much!
Related issues:
https://github.com/pytorch/pytorch/issues/97160#issuecomment-1610832426