[Mlir-commits] [mlir] [mlir][tosa] Add NaN Propagation Mode Support (PR #121951)

Jack Frankland llvmlistbot at llvm.org
Wed Jan 22 06:33:28 PST 2025


https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/121951

>From 5ba05b9f87967dd37614f3e3a5e875693cd63a5b Mon Sep 17 00:00:00 2001
From: TatWai Chong <tatwai.chong at arm.com>
Date: Thu, 3 Oct 2024 15:18:55 -0700
Subject: [PATCH] [mlir][tosa] Add NaN Propagation Mode Support

The TOSA v1.0 specification adds "NaN propagation" modes as attributes
for several operators. Adjust the ODS definitions of the relevant
operations to include this attribute.

The defined modes are "PROPAGATE" and "IGNORE"; PROPAGATE is the
default.

MAXIMUM, MINIMUM, REDUCE_MAX, REDUCE_MIN, MAX_POOL, CLAMP, and ARGMAX
support this attribute.

Refactor the clamp + clamp optimization to better handle edge cases
such as invalid NaN propagation combinations and disjoint clamp
ranges.

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  | 21 +++--
 .../mlir/Dialect/Tosa/IR/TosaTypesBase.td     |  8 ++
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 85 +++++++++++++++----
 mlir/test/Dialect/Tosa/canonicalize.mlir      |  2 +-
 4 files changed, 91 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e3c725801d1629..3e5e612ac02848 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -42,7 +42,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
 
   let arguments = (ins
     Tosa_Tensor: $input,
-    I32Attr: $axis
+    I32Attr: $axis,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -284,7 +285,8 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
 
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
-    Tosa_IntArrayAttr4:$pad
+    Tosa_IntArrayAttr4:$pad,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -383,7 +385,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
     I64Attr:$min_int,
     I64Attr:$max_int,
     Tosa_FloatAttr:$min_fp,
-    Tosa_FloatAttr:$max_fp
+    Tosa_FloatAttr:$max_fp,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -747,7 +750,8 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    Tosa_Tensor:$input2
+    Tosa_Tensor:$input2,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -770,7 +774,8 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    Tosa_Tensor:$input2
+    Tosa_Tensor:$input2,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -1377,7 +1382,8 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I32Attr:$axis
+    I32Attr:$axis,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -1412,7 +1418,8 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I32Attr:$axis
+    I32Attr:$axis,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index a6d3163d4446fa..b7b33efba937ab 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,12 +202,20 @@ def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
 //===----------------------------------------------------------------------===//
 // Iterable attributes.
 //===----------------------------------------------------------------------===//
+// Defined in `section 3. Enumerations` of the TOSA specification.
+
 // Supported regimes for tosa.resize.
 def Tosa_ResizeTypeAttr : StringBasedAttr<
     CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\"  || " #
           "::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
     "Supported resize/upsampling strategies">;
 
+// NaN propagation mode for `nan_mode` attributes: "PROPAGATE" or "IGNORE".
+def Tosa_NanPropagationAttr : StringBasedAttr<
+    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\"  || " #
+          "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
+    "Supported NaN propagation strategies">;
+
 def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
 
 // Tensor to buffer types.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 39d0ee122b1630..b2c44a3d92e5f4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -339,33 +339,84 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
   }
 };
 
+// Attempts the following transformation:
+//
+// For scalars a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
+// tensor X the following identity holds:
+//
+// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
+//
+// subject to these NaN mode rules (op = outer CLAMP, clamp = inner CLAMP):
+// --------------------------------------------
+// | opNanMode | clampNanMode | resultNanMode |
+// |-----------|--------------|---------------|
+// | PROPAGATE | PROPAGATE    | PROPAGATE     |
+// | PROPAGATE | IGNORE       | IGNORE        |
+// | IGNORE    | PROPAGATE    | INVALID       |
+// | IGNORE    | IGNORE       | IGNORE        |
+// |------------------------------------------|
+
 struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
   using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
 
+  // Helper structure describing the closed range [start, end] of a clamp.
+  template <typename T>
+  struct ClampRange {
+    ClampRange(const T &start, const T &end) : start(start), end(end) {}
+    T start;
+    T end;
+
+    // Returns true if the two closed ranges share at least one value.
+    bool intersects(const ClampRange<T> &otherRange) const {
+      return start <= otherRange.end && otherRange.start <= end;
+    }
+  };
+
   LogicalResult matchAndRewrite(tosa::ClampOp op,
                                 PatternRewriter &rewriter) const override {
-    Value input = op.getInput();
-
-    Operation *definingOp = input.getDefiningOp();
-    if (!definingOp)
+    // Check the input to the CLAMP op is itself a CLAMP.
+    auto clampOp =
+        dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
+    if (!clampOp)
       return failure();
 
-    if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
-      auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
-      auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();
+    // An outer IGNORE over an inner PROPAGATE has no single-CLAMP equivalent.
+    const auto opNanMode = op.getNanMode();
+    const auto clampNanMode = clampOp.getNanMode();
+    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
+      return failure();
 
-      auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
-      auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());
+    // Reject disjoint ranges: the merged CLAMP would have min > max.
+    const auto opMinInt = op.getMinInt();
+    const auto opMaxInt = op.getMaxInt();
+    const auto clampOpMinInt = clampOp.getMinInt();
+    const auto clampOpMaxInt = clampOp.getMaxInt();
+    ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
+    ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
+    if (!opRangeIntRange.intersects(clampRangeIntRange))
+      return failure();
 
-      rewriter.replaceOpWithNewOp<tosa::ClampOp>(
-          op, op.getType(), clampOp.getInput(),
-          rewriter.getI64IntegerAttr(minInt),
-          rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
-          rewriter.getF32FloatAttr(maxFp));
-      return success();
-    }
+    const auto opMinFloat = op.getMinFp();
+    const auto opMaxFloat = op.getMaxFp();
+    const auto clampOpMinFloat = clampOp.getMinFp();
+    const auto clampOpMaxFloat = clampOp.getMaxFp();
+    ClampRange opRangeFloatRange(opMinFloat, opMaxFloat);
+    ClampRange clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
+    if (!opRangeFloatRange.intersects(clampRangeFloatRange))
+      return failure();
 
-    return failure();
+    // Merge: intersect both ranges; mixed NaN modes collapse to IGNORE.
+    const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
+    const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
+    const auto minInt = std::max(opMinInt, clampOpMinInt);
+    const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
+    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
+        op, op.getType(), clampOp.getInput(),
+        rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
+        rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
+        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
+                                                           : opNanMode));
+    return success();
   }
 };
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 67cd01f62f0bdf..7d9583a228ffe7 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -130,7 +130,7 @@ func.func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
 
 // CHECK-LABEL: @clamp_twice_is_single_clamp
 func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
-  // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
+  // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
   %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64} :  (tensor<4xi8>) -> tensor<4xi8>
   %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64} :  (tensor<4xi8>) -> tensor<4xi8>
   return %1 : tensor<4xi8>



More information about the Mlir-commits mailing list