[Mlir-commits] [mlir] [mlir][tosa] Add NaN Propagation Mode Support (PR #121951)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 7 07:25:55 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jack Frankland (FranklandJack)

<details>
<summary>Changes</summary>

The TOSA v1.0 specification adds NaN propagation modes as attributes on several operators. This patch adjusts the ODS definitions of the affected operations to include this attribute.

The defined modes are "PROPAGATE" and "IGNORE"; "PROPAGATE" is the default.

The following operators support this attribute: MAXIMUM, MINIMUM, REDUCE_MAX, REDUCE_MIN, MAX_POOL2D, CLAMP, and ARGMAX.

---
Full diff: https://github.com/llvm/llvm-project/pull/121951.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+14-7) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+8) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+1-1) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e3c725801d1629..3e5e612ac02848 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -42,7 +42,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
 
   let arguments = (ins
     Tosa_Tensor: $input,
-    I32Attr: $axis
+    I32Attr: $axis,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -284,7 +285,8 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
 
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
-    Tosa_IntArrayAttr4:$pad
+    Tosa_IntArrayAttr4:$pad,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -383,7 +385,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
     I64Attr:$min_int,
     I64Attr:$max_int,
     Tosa_FloatAttr:$min_fp,
-    Tosa_FloatAttr:$max_fp
+    Tosa_FloatAttr:$max_fp,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -747,7 +750,8 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    Tosa_Tensor:$input2
+    Tosa_Tensor:$input2,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -770,7 +774,8 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    Tosa_Tensor:$input2
+    Tosa_Tensor:$input2,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -1377,7 +1382,8 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I32Attr:$axis
+    I32Attr:$axis,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -1412,7 +1418,8 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I32Attr:$axis
+    I32Attr:$axis,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index a6d3163d4446fa..b7b33efba937ab 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,12 +202,20 @@ def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
 //===----------------------------------------------------------------------===//
 // Iterable attributes.
 //===----------------------------------------------------------------------===//
+// Defined in `section 3. Enumerations` of the TOSA specification.
+
 // Supported regimes for tosa.resize.
 def Tosa_ResizeTypeAttr : StringBasedAttr<
     CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\"  || " #
           "::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
     "Supported resize/upsampling strategies">;
 
+// Supported NaN propagation strategies.
+def Tosa_NanPropagationAttr : StringBasedAttr<
+    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\"  || " #
+          "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
+    "Supported NaN propagation strategies">;
+
 def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
 
 // Tensor to buffer types.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 39d0ee122b1630..865ac038be0b1f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -361,7 +361,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
           op, op.getType(), clampOp.getInput(),
           rewriter.getI64IntegerAttr(minInt),
           rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
-          rewriter.getF32FloatAttr(maxFp));
+          rewriter.getF32FloatAttr(maxFp), rewriter.getStringAttr("IGNORE"));
       return success();
     }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 67cd01f62f0bdf..7d9583a228ffe7 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -130,7 +130,7 @@ func.func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
 
 // CHECK-LABEL: @clamp_twice_is_single_clamp
 func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
-  // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
+  // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "IGNORE"}
   %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64} :  (tensor<4xi8>) -> tensor<4xi8>
   %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64} :  (tensor<4xi8>) -> tensor<4xi8>
   return %1 : tensor<4xi8>

``````````

</details>


https://github.com/llvm/llvm-project/pull/121951


More information about the Mlir-commits mailing list