[Mlir-commits] [mlir] [mlir][tosa][tosa-to-linalg] Ignore Int NaN Mode (PR #129041)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 27 03:36:42 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Jack Frankland (FranklandJack)

<details>
<summary>Changes</summary>

For non-floating-point operations, the NaN propagation mode has no meaning and can be safely ignored. For non-floating-point (e.g. integer) types, skip the compare-and-select materialization for NaN propagation, even in "IGNORE" mode. This fixes a bug where an unchecked `cast<FloatType>()` was called in the "IGNORE" case even when the operation acts on integers.

Update the lit tests for the NaN propagation lowering to check that the propagation logic is not materialized for a non-floating-point type, e.g. i8.

---
Full diff: https://github.com/llvm/llvm-project/pull/129041.diff


4 Files Affected:

- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+18-2) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+5) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+10) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+95) 


``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 06831a642664e..8732ddafa24d4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -49,6 +49,11 @@ using namespace mlir::tosa;
 // calculated result based on whether the lhs or rhs is NaN or not. In pseudo
 // code:
 //
+// In the case that the op is operating on non floating point types we ignore
+// the attribute completely, this is consistent with the TOSA spec which has
+// the following wording: "This attribute is ignored by non floating-point
+// types."
+//
 // binary<op>(lhs, rhs):
 //   result = op(lhs, rhs)
 //   if lhs == NaN return rhs
@@ -58,6 +63,10 @@ template <typename OpTy>
 static Value
 materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
                                     Value lhs, Value rhs, Value result) {
+  // NaN propagation has no meaning for non floating point types.
+  if (!isa<FloatType>(getElementTypeOrSelf(lhs)))
+    return result;
+
   auto nanMode = op.getNanMode();
   if (nanMode == "PROPAGATE")
     return result;
@@ -449,6 +458,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
     auto clampOp = llvm::cast<tosa::ClampOp>(op);
     const auto nanMode = clampOp.getNanMode();
+
+    // NaN propagation has no meaning for non floating point types.
+    if (!isa<FloatType>(elementTy))
+      return result;
+
     // In the case of "PROPAGATE" semantics no compare and selection is
     // required.
     if (nanMode == "PROPAGATE")
@@ -1192,7 +1206,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   bool isNanIgnoreMode = false;
   if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                 std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
-    if (op.getNanMode() == "IGNORE") {
+    // NaN propagation has no meaning for non floating point types.
+    if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
       isNanIgnoreMode = true;
       // Because the TOSA spec requires the result be NaN iff all elements in
       // the reduction are NaN we can't simply perform a compare and select.
@@ -2282,7 +2297,8 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
           // In the case "IGNORE" we check if the current argument is NaN and
           // select the old index and value otherwise take the updated index and
           // value.
-          if (const auto nanMode = argmaxOp.getNanMode(); nanMode == "IGNORE") {
+          if (const auto nanMode = argmaxOp.getNanMode();
+              isa<FloatType>(inElementTy) && nanMode == "IGNORE") {
             // Unordered comparison of NaN against itself will always return
             // true.
             Value isNaN = rewriter.create<arith::CmpFOp>(
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 006e35806d64f..e3400b9ba4358 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -748,6 +748,11 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
         dilationAttr);
 
     rewriter.replaceOp(op, resultOp);
+
+    // NaN propagation has no meaning for non floating point types.
+    if (!isa<FloatType>(getElementTypeOrSelf(inputTy)))
+      return success();
+
     // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
     // compare and select materialization is required.
     //
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 332b706871547..02d2f16b74ef8 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -940,6 +940,16 @@ func.func @max_pool2d_nan_propagate(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4
 
 // -----
 
+// CHECK-LABEL: @max_pool2d_nan_ignore_int
+func.func @max_pool2d_nan_ignore_int(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x4x32x62xi8>) {
+  // CHECK: linalg.pooling_nhwc_max
+  // CHECK-NOT: linalg.generic
+  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xi8>) -> tensor<1x4x32x62xi8>
+  return %0: tensor<1x4x32x62xi8>
+}
+
+// -----
+
 // CHECK-LABEL: @max_pool2d_nan_ignore
 func.func @max_pool2d_nan_ignore(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>) {
   // CHECK-NOT: linalg.pooling_nhwc_max
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 78f2e173d7cb1..c3992d2cda46e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -2033,6 +2033,44 @@ func.func @reduce_max_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf3
 
 // -----
 
+// CHECK-LABEL: @reduce_min_nan_ignore_int
+func.func @reduce_min_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>) -> () {
+  // CHECK: linalg.reduce
+  // CHECK: arith.minsi
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  // CHECK-NOT: arith.constant 0x7FC00000
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: linalg.fill
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: select
+  // CHECK: return
+  %5 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xi8>) -> tensor<1x4xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_max_nan_ignore_int
+func.func @reduce_max_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>) -> () {
+  // CHECK: linalg.reduce
+  // CHECK: arith.maxsi
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  // CHECK-NOT: arith.constant 0x7FC00000
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: linalg.fill
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: select
+  // CHECK: return
+  %6 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xi8>) -> tensor<1x4xi8>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @reduce_min_nan_ignore
 func.func @reduce_min_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
   // CHECK: linalg.reduce
@@ -2095,6 +2133,32 @@ func.func @maximum_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>)
 
 // -----
 
+// CHECK-LABEL: @minimum_nan_ignore_int
+func.func @minimum_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.minsi
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %9 = tosa.minimum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xi8>, tensor<5x4xi8>) -> tensor<5x4xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @maximum_nan_ignore_int
+func.func @maximum_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.maxsi
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %10 = tosa.maximum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xi8>, tensor<5x4xi8>) -> tensor<5x4xi8>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @minimum_nan_ignore
 func.func @minimum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
   // CHECK: linalg.generic
@@ -2142,6 +2206,23 @@ func.func @argmax_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>)
 
 // -----
 
+// CHECK-LABEL: @argmax_nan_ignore_int
+func.func @argmax_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.cmpi sgt
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+ %12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xi8>)  -> tensor<4xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @argmax_nan_ignore
 func.func @argmax_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
   // CHECK: linalg.generic
@@ -2172,6 +2253,20 @@ func.func @clamp_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -
 
 // -----
 
+// CHECK-LABEL: @clamp_nan_ignore_int
+func.func @clamp_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.maxsi
+  // CHECK: arith.minsi
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %14 = tosa.clamp %arg0 {min_val = 1 : i8, max_val = 5 : i8, nan_mode = "IGNORE"} : (tensor<5x4xi8>) -> tensor<5x4xi8>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @clamp_nan_ignore
 func.func @clamp_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
   // CHECK: linalg.generic

``````````

</details>


https://github.com/llvm/llvm-project/pull/129041


More information about the Mlir-commits mailing list