Exception during torch.onnx.export: IndexError: Argument passed to at() was not in the map


I am trying to export a very complex PyTorch model to ONNX using torch.onnx.export. Behind the scenes, the torch.jit.trace mechanism is used (https://pytorch.org/docs/stable/generated/torch.jit.trace.html). Concretely, I call the export function like this:

torch.onnx.export(model, (data,), "model.onnx", verbose=True, export_params=True, opset_version=11)

The tracing part of this function seems to work fine: I removed all warnings raised during this step, and the trace graph is created. However, during the optimization of the graph I receive the following exception:

  File "/home/dnndev02/workspace_uidj9636/common-mmlab/yaaf/algorithm/network.py", line 206, in inference
    torch.onnx.export(model,  # or scripted_model
  File "/home/dnndev02/miniconda3/envs/mm_deploy/lib/python3.8/site-packages/torch/onnx/utils.py", line 504, in export
    _export(
  File "/home/dnndev02/miniconda3/envs/mm_deploy/lib/python3.8/site-packages/torch/onnx/utils.py", line 1529, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/dnndev02/miniconda3/envs/mm_deploy/lib/python3.8/site-packages/torch/onnx/utils.py", line 1115, in _model_to_graph
    graph = _optimize_graph(
  File "/home/dnndev02/miniconda3/envs/mm_deploy/lib/python3.8/site-packages/torch/onnx/utils.py", line 605, in _optimize_graph
    _C._jit_pass_peephole(graph, True)
IndexError: Argument passed to at() was not in the map.

Unfortunately, this exception does not point me to the concrete location in the graph that causes the problem. I have access to the graph as a string representation, but it consists of 240k lines, which makes it almost impossible to identify the root cause by reading it.
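Since the graph is available as text, a small script can at least narrow the search down. The sketch below parses each node line of the dump, assuming the `%out : Type = kind(%in, ...), scope: ...` layout shown in the excerpt, counts op kinds and innermost module scopes, and lists where a given value is produced and consumed. The helper names (`parse_graph`, `summarize`, `trace_value`) are hypothetical, and lines that do not match the assumed layout are simply skipped.

```python
import re
from collections import Counter
from typing import List, NamedTuple, Optional, Tuple

VALUE_RE = re.compile(r"%[\w.]+")      # value names such as %207410 or %x.1
KIND_RE = re.compile(r"\s*([\w:]+)")   # op kind such as aten::slice


class Node(NamedTuple):
    lineno: int          # 1-based line number in the dump
    outputs: List[str]
    kind: str
    inputs: List[str]
    scope: str           # text after ", scope: " (empty if absent)


def parse_line(lineno: int, line: str) -> Optional[Node]:
    """Parse one '%out : Type = kind(%in, ...), scope: ...' line, else None.

    Type annotations contain '=' (strides=[...]) but not ' = ', so the first
    ' = ' separates the outputs from the node body (an assumption about the dump).
    """
    if " = " not in line:
        return None
    lhs, rhs = line.split(" = ", 1)
    outputs = VALUE_RE.findall(lhs)
    kind_match = KIND_RE.match(rhs)
    if not outputs or kind_match is None:
        return None
    body, _, scope = rhs.partition(", scope: ")
    return Node(lineno, outputs, kind_match.group(1), VALUE_RE.findall(body), scope.strip())


def parse_graph(text: str) -> List[Node]:
    """Parse every recognizable node line of a graph dump."""
    nodes = (parse_line(i, line) for i, line in enumerate(text.splitlines(), start=1))
    return [n for n in nodes if n is not None]


def summarize(nodes: List[Node], top: int = 10) -> Tuple[list, list]:
    """Most frequent op kinds and innermost module scopes."""
    kinds = Counter(n.kind for n in nodes)
    # The innermost module is the part of the scope after the last '/'.
    scopes = Counter(n.scope.rsplit("/", 1)[-1] for n in nodes if n.scope)
    return kinds.most_common(top), scopes.most_common(top)


def trace_value(nodes: List[Node], value: str) -> Tuple[Optional[Node], List[Node]]:
    """Return the node that produces `value` and the nodes that consume it."""
    producer = next((n for n in nodes if value in n.outputs), None)
    consumers = [n for n in nodes if value in n.inputs]
    return producer, consumers
```

For example, `summarize(parse_graph(open("graph.txt").read()))` (the file name is a placeholder) gives a quick overview of which modules dominate the 240k lines. This only helps navigate the dump; it does not reveal which node the peephole pass fails on.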

Example excerpt from the graph:

   %207405 : Float(4, 1, 6, 40000, strides=[480000, 480000, 80000, 2], requires_grad=0, device=cuda:1) = aten::copy_(%207403, %207400, %188882), scope: plugins.mmdet3d_plugins.model.detectors.model_e2e.model::/plugins.mmdet3d_plugins.model.modules.encoder::encoder
  %207410 : Float(4, 1, 6, 40000, 1, strides=[480000, 480000, 80000, 2, 1], requires_grad=0, device=cuda:1) = aten::slice(%207387, %189242, %188872, %189234, %188872), scope: plugins.mmdet3d_plugins.model.detectors.model_e2e.model::/plugins.mmdet3d_plugins.model.modules.encoder::encoder
  %207411 : float = prim::Constant[value=0.](), scope: plugins.mmdet3d_plugins.model.detectors.model_e2e.model::/plugins.mmdet3d_plugins.uniad.modules.encoder::encoder
  %207412 : Bool(4, 1, 6, 40000, 1, strides=[240000, 240000, 40000, 1, 1], requires_grad=0, device=cuda:1) = aten::gt(%207410, %207411), scope: plugins.mmdet3d_plugins.model.detectors.model_e2e.model::/plugins.mmdet3d_plugins.model.modules.encoder::encoder
  %207413 : Bool(4, 1, 6, 40000, 1, strides=[240000, 240000, 40000, 1, 1], requires_grad=0, device=cuda:1) = aten::__and__(%207373, %207412), scope: plugins.mmdet3d_plugins.model.detectors.model_e2e.model::/plugins.mmdet3d_plugins.model.modules.encoder::encoder
  %207418 : Float(4, 1, 6, 40000, 1, strides=[480000, 480000, 80000, 2, 1], requires_grad=0, device=cuda:1) = aten::slice(%207387, %189242, %188872, %189234, %188872), scope: plugins.mmdet3d_plugins.model.detectors.model_e2e.model::/plugins.mmdet3d_plugins.model.modules.encoder::encoder

Do you have an idea how to identify the root cause of the problem, or how to locate the part of the graph that might cause it? Alternatively, do you know how to resolve this concrete error ("Argument passed to at() was not in the map")?
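One way to localize the failure without reading the graph is to bisect over submodules: record the inputs each direct child receives during a normal forward pass, then export each child on its own and see which ones reproduce the `IndexError`. This is a minimal sketch, assuming the children are called with positional tensor arguments (forward hooks in PyTorch 1.13 do not see keyword arguments); `capture_child_inputs`, `find_failing_children` and `onnx_export_to_memory` are hypothetical helpers, not part of torch.

```python
import io
from typing import Callable, Dict, Tuple

import torch
from torch import nn


def onnx_export_to_memory(module: nn.Module, args: Tuple) -> None:
    # Same settings as the failing call in the question; the ONNX bytes are discarded.
    torch.onnx.export(module, args, io.BytesIO(), export_params=True, opset_version=11)


def capture_child_inputs(model: nn.Module, example_inputs: Tuple) -> Dict[str, Tuple[nn.Module, Tuple]]:
    """Run the model once and record (child, positional inputs) for each direct child.

    Assumption: children receive their tensors as positional arguments.
    Only the first call of each child is kept.
    """
    captured = {}
    handles = []
    for name, child in model.named_children():
        def hook(module, inputs, output, name=name):
            captured.setdefault(name, (module, inputs))
        handles.append(child.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(*example_inputs)
    finally:
        for handle in handles:
            handle.remove()
    return captured


def find_failing_children(
    model: nn.Module,
    example_inputs: Tuple,
    export_fn: Callable[[nn.Module, Tuple], object] = onnx_export_to_memory,
) -> Dict[str, str]:
    """Export every direct child separately and return {child name: error message}."""
    failures = {}
    for name, (child, inputs) in capture_child_inputs(model, example_inputs).items():
        try:
            export_fn(child, inputs)
        except Exception as exc:  # record the error and continue with the next child
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures
```

To go one level deeper, pass a failing child and its inputs from `capture_child_inputs` back into `find_failing_children`. A child that exports fine in isolation can still fail as part of the full model, so an empty result is inconclusive rather than proof.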

I am using PyTorch 1.13.1. The problem seems to persist even in 2.x versions (https://github.com/pytorch/pytorch/issues/97160#issuecomment-1610832426), and I do not have the possibility to upgrade.

Thanks very much
