Skip to content

[mlir][tosa] Add NaN Propagation Mode Support #121951

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 23, 2025

Conversation

FranklandJack
Copy link
Contributor

The TOSA-V1.0 specification adds "nan propagation" modes as attributes for several operators. Adjust the ODS definitions of the relevant operations to include this attribute.

The defined modes are "PROPAGATE" and "IGNORE" and the PROPAGATE mode is set by default.

MAXIMUM, MINIMUM, REDUCE_MAX, REDUCE_MIN, MAX_POOL, CLAMP, and ARGMAX support this attribute.

@llvmbot
Copy link
Member

llvmbot commented Jan 7, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Jack Frankland (FranklandJack)

Changes

The TOSA-V1.0 specification adds "nan propagation" modes as attributes for several operators. Adjust the ODS definitions of the relevant operations to include this attribute.

The defined modes are "PROPAGATE" and "IGNORE" and the PROPAGATE mode is set by default.

MAXIMUM, MINIMUM, REDUCE_MAX, REDUCE_MIN, MAX_POOL, CLAMP, and ARGMAX support this attribute.


Full diff: https://github.com/llvm/llvm-project/pull/121951.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+14-7)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+8)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+1-1)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+1-1)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e3c725801d1629..3e5e612ac02848 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -42,7 +42,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
 
   let arguments = (ins
     Tosa_Tensor: $input,
-    I32Attr: $axis
+    I32Attr: $axis,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -284,7 +285,8 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
 
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
-    Tosa_IntArrayAttr4:$pad
+    Tosa_IntArrayAttr4:$pad,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -383,7 +385,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
     I64Attr:$min_int,
     I64Attr:$max_int,
     Tosa_FloatAttr:$min_fp,
-    Tosa_FloatAttr:$max_fp
+    Tosa_FloatAttr:$max_fp,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -747,7 +750,8 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    Tosa_Tensor:$input2
+    Tosa_Tensor:$input2,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -770,7 +774,8 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    Tosa_Tensor:$input2
+    Tosa_Tensor:$input2,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -1377,7 +1382,8 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I32Attr:$axis
+    I32Attr:$axis,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -1412,7 +1418,8 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I32Attr:$axis
+    I32Attr:$axis,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index a6d3163d4446fa..b7b33efba937ab 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,12 +202,20 @@ def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
 //===----------------------------------------------------------------------===//
 // Iterable attributes.
 //===----------------------------------------------------------------------===//
+// Defined in `section 3. Enumerations` of the TOSA specification.
+
 // Supported regimes for tosa.resize.
 def Tosa_ResizeTypeAttr : StringBasedAttr<
     CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\"  || " #
           "::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
     "Supported resize/upsampling strategies">;
 
+// Supported NaN propagation strategies.
+def Tosa_NanPropagationAttr : StringBasedAttr<
+    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\"  || " #
+          "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
+    "Supported NaN propagation strategies">;
+
 def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
 
 // Tensor to buffer types.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 39d0ee122b1630..865ac038be0b1f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -361,7 +361,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
           op, op.getType(), clampOp.getInput(),
           rewriter.getI64IntegerAttr(minInt),
           rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
-          rewriter.getF32FloatAttr(maxFp));
+          rewriter.getF32FloatAttr(maxFp), rewriter.getStringAttr("IGNORE"));
       return success();
     }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 67cd01f62f0bdf..7d9583a228ffe7 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -130,7 +130,7 @@ func.func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
 
 // CHECK-LABEL: @clamp_twice_is_single_clamp
 func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
-  // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
+  // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "IGNORE"}
   %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64} :  (tensor<4xi8>) -> tensor<4xi8>
   %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64} :  (tensor<4xi8>) -> tensor<4xi8>
   return %1 : tensor<4xi8>

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had a couple of comments on the clamp clamp optimization, otherwise looks good!

@FranklandJack FranklandJack force-pushed the nan_propagation branch 2 times, most recently from 5ba05b9 to fd01630 Compare January 22, 2025 15:56
Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, leaving to @lhutton1 for final approval

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spotted a possible typo, otherwise looks fantastic, thanks!

The TOSA-V1.0 specification adds "nan propagation" modes as attributes
for several operators. Adjust the ODS definitions of the relevant
operations to include this attribute.

The defined modes are "PROPAGATE" and "IGNORE" and the PROPAGATE mode is
set by default.

MAXIMUM, MINIMUM, REDUCE_MAX, REDUCE_MIN, MAX_POOL, CLAMP, and ARGMAX
support this attribute.

Refactor the clamp + clamp optimization in order to better handle edge
cases such as invalid NaN propagation combinations and disjoint clamp
ranges.

Add lit test including negative tests for valid and invalid
optimizations as part of TOSA canonicalizations.

Signed-off-by: Jack Frankland <[email protected]>
@GeorgeARM GeorgeARM merged commit 8388040 into llvm:main Jan 23, 2025
8 checks passed
justin-ngo-arm added a commit to justin-ngo-arm/torch-mlir that referenced this pull request Jan 29, 2025
* TOSA ops' maximum, minimum, reduce_max, reduce_min, max_pool, clamp,
  and argmax has a new NaN propagation mode added as an attribute.
  LLVM PR: llvm/llvm-project#121951
* Add NaN Propagation mode attributes to the ops mentioned above in
  Torch to TOSA legalization.

Signed-off-by: Justin Ngo <[email protected]>
Change-Id: I27dfda9d0146da8742f14718c713998c436899ce
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants