
[ONNX] Fix conversion of onnx maxpool 1D with indices to torch IR #4212

Open
@zahidwx


Problem

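Converting an ONNX model to Torch IR fails when the model contains a MaxPool op over one spatial dimension (1D) that also produces the optional indices output: the --convert-torch-onnx-to-torch pass leaves the onnx.MaxPool operator unlegalized.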

maxpool1d.mlir:4:12: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
    %0:2 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64], torch.onnx.pads = [0 : si64, 0 : si64], torch.onnx.strides = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> (!torch.vtensor<[1,3,16],f32>, !torch.vtensor<[48],si64>) 
           ^
maxpool1d.mlir:4:12: note: see current operation: %1:2 = "torch.operator"(%arg0) <{name = "onnx.MaxPool"}> {torch.onnx.kernel_shape = [2 : si64], torch.onnx.pads = [0 : si64, 0 : si64], torch.onnx.strides = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> (!torch.vtensor<[1,3,16],f32>, !torch.vtensor<[48],si64>)
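The result types in this error follow from the node attributes. Below is a minimal sketch of the ONNX MaxPool output-length formula, assuming explicit pads with auto_pad=NOTSET and ceil_mode=0 (floor rounding); the helper name maxpool1d_out_len is hypothetical. Per the ONNX operator spec, the Indices output has the same dimensions as Y ([1, 3, 16] here). The repro declares it as a flat [48], which holds the same 1 x 3 x 16 = 48 elements.

```python
def maxpool1d_out_len(length: int, kernel: int, stride: int,
                      pad_begin: int = 0, pad_end: int = 0, dilation: int = 1) -> int:
    """Output length of a 1D ONNX MaxPool with explicit pads (auto_pad=NOTSET, ceil_mode=0)."""
    # A dilated kernel covers dilation * (kernel - 1) + 1 input positions.
    effective_kernel = dilation * (kernel - 1) + 1
    # Floor division implements the floor rounding used when ceil_mode=0.
    return (length + pad_begin + pad_end - effective_kernel) // stride + 1


# Attributes from the repro: kernel_shape=[2], strides=[2], pads=[0, 0]; input [1, 3, 32].
n, c, length = 1, 3, 32
out_len = maxpool1d_out_len(length, kernel=2, stride=2)
print(out_len)           # 16 -> Y is [1, 3, 16]
print(n * c * out_len)   # 48 -> number of index elements
```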

Repro steps:

python generate_onnx.py  # Dumps an ONNX file
python -m torch_mlir.tools.import_onnx maxpool1d_with_indices.onnx > maxpool1d.mlir  # Emits MLIR
torch-mlir-opt --convert-torch-onnx-to-torch maxpool1d.mlir  # Fails here
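For repeated runs, the three steps can be scripted. This is a minimal sketch using only the standard library. It assumes generate_onnx.py sits in the working directory and that the torch-mlir Python package and the torch-mlir-opt binary are installed. The helper run_repro is hypothetical. Only the importer's stdout is written to maxpool1d.mlir, so diagnostics printed on stderr stay out of the MLIR file.

```python
import os
import shutil
import subprocess
import sys

# The three repro steps from the issue; file names are taken from it verbatim.
STEPS = [
    [sys.executable, "generate_onnx.py"],
    [sys.executable, "-m", "torch_mlir.tools.import_onnx", "maxpool1d_with_indices.onnx"],
    ["torch-mlir-opt", "--convert-torch-onnx-to-torch", "maxpool1d.mlir"],
]
MLIR_FILE = "maxpool1d.mlir"


def run_repro(workdir: str = ".") -> int:
    """Run the steps in order; return the first non-zero exit code, or 0 if all pass."""
    if shutil.which("torch-mlir-opt") is None:
        raise FileNotFoundError("torch-mlir-opt is not on PATH")
    for step, cmd in enumerate(STEPS, start=1):
        if step == 2:
            # The importer prints MLIR on stdout; save only stdout to the .mlir file.
            with open(os.path.join(workdir, MLIR_FILE), "w") as out:
                result = subprocess.run(cmd, cwd=workdir, stdout=out)
        else:
            result = subprocess.run(cmd, cwd=workdir)
        if result.returncode != 0:
            print(f"step {step} failed: {' '.join(cmd)}", file=sys.stderr)
            return result.returncode
    return 0
```

Calling run_repro() from the directory containing generate_onnx.py should, on a build without a fix, stop at step 3 with the legalization error above.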

Create the MaxPool op as an ONNX model [generate_onnx.py]

import onnx
import onnx.helper as helper
from onnx import TensorProto

x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 3, 32])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 3, 16])
indices = helper.make_tensor_value_info('indices', TensorProto.INT64, [48])

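# MaxPool over one spatial dimension. ONNX has no return_indices attribute;
# declaring a second output ('indices') requests the argmax indices.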
node = helper.make_node(
    'MaxPool',
    inputs=['x'],
    outputs=['y', 'indices'],
    kernel_shape=[2],
    strides=[2],
    pads=[0, 0],
)

graph = helper.make_graph(
    [node],
    'MaxPool1DWithIndices',
    [x],
    [y, indices]
)

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12)])
onnx.save(model, 'maxpool1d_with_indices.onnx')

print("Saved maxpool1d_with_indices.onnx")
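What the second output should contain matters for any conversion. The ONNX MaxPool spec describes Indices as offsets into the flattened input, ranging over [0, N x C x L), whereas PyTorch's max_pool1d(..., return_indices=True) reports positions along L within each (n, c) row. The pure-Python reference below is a sketch of those semantics for the no-padding, no-dilation case used in the repro. The function names are hypothetical, and picking the first maximum on ties is an assumption of this sketch.

```python
from typing import List, Tuple

Tensor3D = List[List[List[float]]]
Index3D = List[List[List[int]]]


def onnx_maxpool1d_with_indices(x: Tensor3D, kernel: int, stride: int) -> Tuple[Tensor3D, Index3D]:
    """Reference MaxPool over an [N, C, L] input (no padding, no dilation).

    Both results are shaped [N, C, L_out]. Each index is an offset into the
    row-major flattened input, so it lies in [0, N * C * L).
    """
    channels, length = len(x[0]), len(x[0][0])
    out_len = (length - kernel) // stride + 1
    values: Tensor3D = []
    indices: Index3D = []
    for n, batch in enumerate(x):
        values.append([])
        indices.append([])
        for c, row in enumerate(batch):
            base = (n * channels + c) * length  # flat offset of x[n][c][0]
            v_row, i_row = [], []
            for o in range(out_len):
                start = o * stride
                window = row[start:start + kernel]
                # max() returns the first maximal position, so ties pick the lowest index.
                best = max(range(kernel), key=lambda k: window[k])
                v_row.append(window[best])
                i_row.append(base + start + best)
            values[n].append(v_row)
            indices[n].append(i_row)
    return values, indices


def to_per_channel_indices(indices: Index3D, length: int) -> Index3D:
    """Convert flat indices to positions along L within each (n, c) row."""
    # flat = (n * C + c) * L + l with 0 <= l < L, so flat % L recovers l.
    return [[[i % length for i in row] for row in batch] for batch in indices]


x = [[[1.0, 5.0, 3.0, 2.0], [0.0, -1.0, 7.0, 7.0]]]  # shape [1, 2, 4]
vals, idx = onnx_maxpool1d_with_indices(x, kernel=2, stride=2)
print(vals)                            # [[[5.0, 3.0], [0.0, 7.0]]]
print(idx)                             # [[[1, 2], [4, 6]]]
print(to_per_channel_indices(idx, 4))  # [[[1, 2], [0, 2]]]
```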

Generated MLIR [maxpool1d.mlir]

module {
  func.func @MaxPool1DWithIndices(%arg0: !torch.vtensor<[1,3,32],f32>) -> (!torch.vtensor<[1,3,16],f32>, !torch.vtensor<[48],si64>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0:2 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64], torch.onnx.pads = [0 : si64, 0 : si64], torch.onnx.strides = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> (!torch.vtensor<[1,3,16],f32>, !torch.vtensor<[48],si64>) 
    return %0#0, %0#1 : !torch.vtensor<[1,3,16],f32>, !torch.vtensor<[48],si64>
  }
}

Applying the ONNX-to-Torch conversion pass fails:

maxpool1d.mlir:4:12: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
    %0:2 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64], torch.onnx.pads = [0 : si64, 0 : si64], torch.onnx.strides = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> (!torch.vtensor<[1,3,16],f32>, !torch.vtensor<[48],si64>) 
           ^
maxpool1d.mlir:4:12: note: see current operation: %1:2 = "torch.operator"(%arg0) <{name = "onnx.MaxPool"}> {torch.onnx.kernel_shape = [2 : si64], torch.onnx.pads = [0 : si64, 0 : si64], torch.onnx.strides = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> (!torch.vtensor<[1,3,16],f32>, !torch.vtensor<[48],si64>)
