
[OnnxToTorch] Lower onnx.MeanVarianceNormalization op to torch dialect without expansion #4218

Description

Problem

Currently there is no direct lowering defined in torch-mlir for the ONNX MeanVarianceNormalization op. Instead, the op is expanded into its constituent primitive ops (as an ONNX function) before being imported to torch IR via torch-mlir's onnx_importer.
https://onnx.ai/onnx/operators/onnx__MeanVarianceNormalization.html
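
For reference, the ONNX spec defines MVN over the given axes as

Y = (X - E[X]) / sqrt(E[X^2] - (E[X])^2)

and the function expansion shown below additionally adds a small epsilon (the dumped constant 9.99999971E-10 is the float32 rounding of 1e-9) to the denominator before dividing.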

There are some potential issues with this approach:

  • The importer should ideally be responsible only for a 1:1 mapping of ONNX ops to torch dialect ops, without performing any IR modifications.
  • If MeanVarianceNormalization is expanded during import, it is no longer identifiable in torch IR, so each of its primitive ops has to be lowered to the corresponding torch op individually.
  • Expanding this op is also inconsistent with how other function-type ONNX ops are handled. For example:
    • CenterCrop, GroupNormalization etc. are functions too, and thus expandable, yet they have dedicated lowering routines to torch IR.

MeanVarianceNormalization imported without expansion, before lowering to torch IR

func.func @test_meanvar_norm(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %0 = torch.operator "onnx.MeanVarianceNormalization"(%arg0) {torch.onnx.axes = [0 : si64, 2 : si64, 3 : si64]} : (!torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32>
  return %0 : !torch.vtensor<[3,5,2,2],f32>
}

MeanVarianceNormalization expanded during import (current behavior), before lowering to torch IR

module {
  func.func @MVN_Graph(%arg0: !torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,3,224,224],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = call @"('MeanVarianceNormalization', '', 13, [tensor_type {\0A  elem_type: 1\0A  shape {\0A    dim {\0A      dim_value: 1\0A    }\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 224\0A    }\0A    dim {\0A      dim_value: 224\0A    }\0A  }\0A}\0A], [tensor_type {\0A  elem_type: 1\0A  shape {\0A    dim {\0A      dim_value: 1\0A    }\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 224\0A    }\0A    dim {\0A      dim_value: 224\0A    }\0A  }\0A}\0A], [name: \22axes\22\0Aints: 0\0Aints: 2\0Aints: 3\0Atype: INTS\0A])"(%arg0) : (!torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,3,224,224],f32>
    return %0 : !torch.vtensor<[1,3,224,224],f32>
  }
  func.func private @"('MeanVarianceNormalization', '', 13, [tensor_type {\0A  elem_type: 1\0A  shape {\0A    dim {\0A      dim_value: 1\0A    }\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 224\0A    }\0A    dim {\0A      dim_value: 224\0A    }\0A  }\0A}\0A], [tensor_type {\0A  elem_type: 1\0A  shape {\0A    dim {\0A      dim_value: 1\0A    }\0A    dim {\0A      dim_value: 3\0A    }\0A    dim {\0A      dim_value: 224\0A    }\0A    dim {\0A      dim_value: 224\0A    }\0A  }\0A}\0A], [name: \22axes\22\0Aints: 0\0Aints: 2\0Aints: 3\0Atype: INTS\0A])"(%arg0: !torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,3,224,224],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2.000000e+00> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<9.99999971E-10> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %2 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [0 : si64, 2 : si64, 3 : si64]} : (!torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,3,1,1],f32> 
    %3 = torch.operator "onnx.Pow"(%2, %0) : (!torch.vtensor<[1,3,1,1],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[1,3,1,1],f32> 
    %4 = torch.operator "onnx.Pow"(%arg0, %0) : (!torch.vtensor<[1,3,224,224],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[1,3,224,224],f32> 
    %5 = torch.operator "onnx.ReduceMean"(%4) {torch.onnx.axes = [0 : si64, 2 : si64, 3 : si64]} : (!torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,3,1,1],f32> 
    %6 = torch.operator "onnx.Sub"(%5, %3) : (!torch.vtensor<[1,3,1,1],f32>, !torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[1,3,1,1],f32> 
    %7 = torch.operator "onnx.Sqrt"(%6) : (!torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[1,3,1,1],f32> 
    %8 = torch.operator "onnx.Sub"(%arg0, %2) : (!torch.vtensor<[1,3,224,224],f32>, !torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[1,3,224,224],f32> 
    %9 = torch.operator "onnx.Add"(%7, %1) : (!torch.vtensor<[1,3,1,1],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[1,3,1,1],f32> 
    %10 = torch.operator "onnx.Div"(%8, %9) : (!torch.vtensor<[1,3,224,224],f32>, !torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[1,3,224,224],f32> 
    return %10 : !torch.vtensor<[1,3,224,224],f32>
  }
}
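
A direct lowering would instead match the single onnx.MeanVarianceNormalization operator and emit the equivalent torch dialect ops. Below is a minimal sketch of what such a pattern could look like, modeled on the existing TorchOnnxToTorch patterns. The OpBinder helpers and op builders follow torch-mlir's conversion code, but this is an illustrative, untested sketch rather than a finished patch; handling for dynamic shapes and negative axes is elided.

// Sketch of a direct lowering for onnx.MeanVarianceNormalization,
// modeled on existing patterns in lib/Conversion/TorchOnnxToTorch/.
// Illustrative and untested; dynamic shapes and negative axes elided.
patterns.onOp(
    "MeanVarianceNormalization", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) -> LogicalResult {
      Torch::ValueTensorType resultType;
      Value input;
      llvm::SmallVector<int64_t> axes;
      // The spec's default for `axes` is [0, 2, 3].
      if (binder.tensorOperand(input) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerArrayAttr(axes, "axes", {0, 2, 3}))
        return failure();
      Location loc = binder.getLoc();
      auto inputTy = cast<Torch::ValueTensorType>(input.getType());
      if (!inputTy.hasSizes())
        return failure();

      // Type of the keepdim-reduced intermediates: reduced axes become 1.
      SmallVector<int64_t> reducedShape(inputTy.getSizes());
      for (int64_t axis : axes)
        reducedShape[axis] = 1;
      auto reducedTy = rewriter.getType<Torch::ValueTensorType>(
          reducedShape, inputTy.getOptionalDtype());

      // Shared constants: the dim list, keepdim=true, dtype=none,
      // alpha=1, and the epsilon from the spec's expansion.
      SmallVector<Value> dimVals;
      for (int64_t axis : axes)
        dimVals.push_back(rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(axis)));
      Value dimList = rewriter.create<Torch::PrimListConstructOp>(
          loc,
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
          dimVals);
      Value keepDim = rewriter.create<Torch::ConstantBoolOp>(loc, true);
      Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
      Value one = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(1));
      Value eps = rewriter.create<Torch::ConstantFloatOp>(
          loc, rewriter.getF64FloatAttr(1e-9));

      // mean = E[x], sqMean = E[x^2], variance = E[x^2] - E[x]^2.
      Value mean = rewriter.create<Torch::AtenMeanDimOp>(
          loc, reducedTy, input, dimList, keepDim, none);
      Value inputSq = rewriter.create<Torch::AtenMulTensorOp>(
          loc, inputTy, input, input);
      Value sqMean = rewriter.create<Torch::AtenMeanDimOp>(
          loc, reducedTy, inputSq, dimList, keepDim, none);
      Value meanSq = rewriter.create<Torch::AtenMulTensorOp>(
          loc, reducedTy, mean, mean);
      Value variance = rewriter.create<Torch::AtenSubTensorOp>(
          loc, reducedTy, sqMean, meanSq, one);

      // y = (x - mean) / (sqrt(variance) + eps), broadcasting the
      // reduced-shape denominator against the full-shape numerator.
      Value stdDev =
          rewriter.create<Torch::AtenSqrtOp>(loc, reducedTy, variance);
      Value centered = rewriter.create<Torch::AtenSubTensorOp>(
          loc, resultType, input, mean, one);
      Value denom = rewriter.create<Torch::AtenAddScalarOp>(
          loc, reducedTy, stdDev, eps, one);
      rewriter.replaceOpWithNewOp<Torch::AtenDivTensorOp>(
          binder.op, resultType, centered, denom);
      return success();
    });

With such a pattern in place, the single onnx.MeanVarianceNormalization operator from the first IR dump would lower directly, and the op would remain identifiable up to the point of conversion instead of being scattered across primitive ops at import time.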
