Problem
Currently there is no direct lowering defined for the ONNX MeanVarianceNormalization op to torch IR. Instead, the op is expanded within ONNX itself before being imported to torch IR via torch-mlir's onnx_importer.
https://onnx.ai/onnx/operators/onnx__MeanVarianceNormalization.html
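For reference, the computation MVN performs is small. A minimal PyTorch sketch (function name and defaults are illustrative; `eps = 1e-9` is added after the square root, matching the ONNX function expansion shown later in this issue):

```python
import torch

def mean_variance_normalization(x: torch.Tensor, axes=(0, 2, 3), eps: float = 1e-9) -> torch.Tensor:
    # ONNX MVN: normalize x to zero mean / unit variance over `axes`.
    mean = x.mean(dim=axes, keepdim=True)
    var = (x * x).mean(dim=axes, keepdim=True) - mean * mean  # E[x^2] - E[x]^2
    return (x - mean) / (var.sqrt() + eps)
```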
There are some potential issues with this approach:

- The importer should ideally be responsible for a 1-1 mapping of ONNX ops to torch dialect ops, without performing any IR modifications.
- If MeanVarianceNormalization is expanded during import, it is no longer identifiable in torch IR, so each of its primitive ops has to be lowered to the corresponding torch op individually.
- Making an exception for this op is inconsistent with other ONNX ops that are also functions and could likewise be expanded, but nonetheless have dedicated lowering routines to torch IR — for example CenterCropPad and GroupNormalization (see the sketch after this list).
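Whether an op ships such a spec-defined expansion can be checked from the ONNX op registry. A small sketch, assuming the `onnx` Python package and the `has_function`/`function_body` accessors I believe the OpSchema bindings expose:

```python
import onnx

# MVN is a "function op": its opset-13 schema carries a FunctionProto body,
# which is what gets inlined today instead of emitting the op directly.
schema = onnx.defs.get_schema("MeanVarianceNormalization", max_inclusive_version=13)
print(schema.has_function)   # expected: True
print(schema.function_body)  # the expansion shown later in this issue
```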
MeanVarianceNormalization before lowering to torch IR:
```mlir
func.func @test_meanvar_norm(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %0 = torch.operator "onnx.MeanVarianceNormalization"(%arg0) {torch.onnx.axes = [0 : si64, 2 : si64, 3 : si64]} : (!torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32>
  return %0 : !torch.vtensor<[3,5,2,2],f32>
}
```
Expanded MeanVarianceNormalization before lowering to torch IR:
```mlir
module {
  func.func @MVN_Graph(%arg0: !torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,3,224,224],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = call @"('MeanVarianceNormalization', '', 13, [tensor_type {\0A elem_type: 1\0A shape {\0A dim {\0A dim_value: 1\0A }\0A dim {\0A dim_value: 3\0A }\0A dim {\0A dim_value: 224\0A }\0A dim {\0A dim_value: 224\0A }\0A }\0A}\0A], [tensor_type {\0A elem_type: 1\0A shape {\0A dim {\0A dim_value: 1\0A }\0A dim {\0A dim_value: 3\0A }\0A dim {\0A dim_value: 224\0A }\0A dim {\0A dim_value: 224\0A }\0A }\0A}\0A], [name: \22axes\22\0Aints: 0\0Aints: 2\0Aints: 3\0Atype: INTS\0A])"(%arg0) : (!torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,3,224,224],f32>
    return %0 : !torch.vtensor<[1,3,224,224],f32>
  }
  func.func private @"('MeanVarianceNormalization', '', 13, [tensor_type {\0A elem_type: 1\0A shape {\0A dim {\0A dim_value: 1\0A }\0A dim {\0A dim_value: 3\0A }\0A dim {\0A dim_value: 224\0A }\0A dim {\0A dim_value: 224\0A }\0A }\0A}\0A], [tensor_type {\0A elem_type: 1\0A shape {\0A dim {\0A dim_value: 1\0A }\0A dim {\0A dim_value: 3\0A }\0A dim {\0A dim_value: 224\0A }\0A dim {\0A dim_value: 224\0A }\0A }\0A}\0A], [name: \22axes\22\0Aints: 0\0Aints: 2\0Aints: 3\0Atype: INTS\0A])"(%arg0: !torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,3,224,224],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2.000000e+00> : tensor<f32>} : () -> !torch.vtensor<[],f32>
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<9.99999971E-10> : tensor<f32>} : () -> !torch.vtensor<[],f32>
    %2 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [0 : si64, 2 : si64, 3 : si64]} : (!torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,3,1,1],f32>
    %3 = torch.operator "onnx.Pow"(%2, %0) : (!torch.vtensor<[1,3,1,1],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[1,3,1,1],f32>
    %4 = torch.operator "onnx.Pow"(%arg0, %0) : (!torch.vtensor<[1,3,224,224],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[1,3,224,224],f32>
    %5 = torch.operator "onnx.ReduceMean"(%4) {torch.onnx.axes = [0 : si64, 2 : si64, 3 : si64]} : (!torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,3,1,1],f32>
    %6 = torch.operator "onnx.Sub"(%5, %3) : (!torch.vtensor<[1,3,1,1],f32>, !torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[1,3,1,1],f32>
    %7 = torch.operator "onnx.Sqrt"(%6) : (!torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[1,3,1,1],f32>
    %8 = torch.operator "onnx.Sub"(%arg0, %2) : (!torch.vtensor<[1,3,224,224],f32>, !torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[1,3,224,224],f32>
    %9 = torch.operator "onnx.Add"(%7, %1) : (!torch.vtensor<[1,3,1,1],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[1,3,1,1],f32>
    %10 = torch.operator "onnx.Div"(%8, %9) : (!torch.vtensor<[1,3,224,224],f32>, !torch.vtensor<[1,3,1,1],f32>) -> !torch.vtensor<[1,3,224,224],f32>
    return %10 : !torch.vtensor<[1,3,224,224],f32>
  }
}
```
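To make the expansion easier to read, here is a line-by-line PyTorch mirror of the graph above (variable names are mine; each line is annotated with the SSA value it corresponds to):

```python
import torch

x = torch.randn(1, 3, 224, 224)
axes = (0, 2, 3)

two = torch.tensor(2.0)                     # %0: onnx.Constant
eps = torch.tensor(1e-9)                    # %1: onnx.Constant (9.99999971E-10)
mean = x.mean(dim=axes, keepdim=True)       # %2: onnx.ReduceMean
mean_sq = mean.pow(two)                     # %3: onnx.Pow
x_sq = x.pow(two)                           # %4: onnx.Pow
e_x_sq = x_sq.mean(dim=axes, keepdim=True)  # %5: onnx.ReduceMean
var = e_x_sq - mean_sq                      # %6: onnx.Sub  (E[x^2] - E[x]^2)
std = var.sqrt()                            # %7: onnx.Sqrt
centered = x - mean                         # %8: onnx.Sub
denom = std + eps                           # %9: onnx.Add
out = centered / denom                      # %10: onnx.Div
```

A direct lowering could emit essentially this computation with a handful of torch ops (e.g. aten.mean.dim, aten.sub.Tensor, aten.sqrt, aten.div.Tensor) instead of requiring every primitive in the inlined function to be handled separately.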