[Mlir-commits] [mlir] [mlir][tosa][tosa-to-linalg] tosa.cast: fix result mismatch when casting f64/f32 max value to i64/i32 (PR #130116)

Alaa Ali llvmlistbot at llvm.org
Thu Mar 6 19:47:24 PST 2025


https://github.com/alaa-ali updated https://github.com/llvm/llvm-project/pull/130116

>From 6a1342bd7eefe9915bc2f99b2329707262a30bb4 Mon Sep 17 00:00:00 2001
From: Alaa Ali <alaaali at ah-alaaali-l.dhcp.mathworks.com>
Date: Thu, 6 Mar 2025 03:36:47 -0500
Subject: [PATCH 1/3] tosa.cast: fix result mismatch when casting f64/f32 max
 value to i64/i32

---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 17 ++++++++---------
 .../TosaToLinalg/tosa-to-linalg.mlir          | 19 ++++++++++---------
 2 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8732ddafa24d4..8085f1104a4cb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -618,12 +618,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
             loc, rewriter.getIntegerAttr(
                      getElementTypeOrSelf(dstTy),
                      APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
-        auto intMax = rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getIntegerAttr(
-                     getElementTypeOrSelf(dstTy),
-                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
         auto maxClamped =
-            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
+            rewriter.create<arith::SelectOp>(loc, overflow, intMin, conv);
         return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
                                                 maxClamped);
       }
@@ -647,8 +643,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                      APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                          .getSExtValue()));
 
+        auto overflow = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, rounded, intMaxFP);
+        Value maxClampedFP = rewriter.create<arith::SelectOp>(loc, overflow, intMinFP, rounded);
+
         Value clamped =
-            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
+            clampFloatHelper(loc, maxClampedFP, intMinFP, intMaxFP, rewriter);
         return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
       }
 
@@ -664,17 +663,17 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                            .getSExtValue()) +
                        1.0f));
 
-      auto intMax = rewriter.create<arith::ConstantOp>(
+      auto intMin = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getIntegerAttr(
                    getElementTypeOrSelf(dstTy),
-                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
       auto minClampedFP =
           rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
       auto minClamped =
           rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
       auto overflow = rewriter.create<arith::CmpFOp>(
           loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
-      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
+      return rewriter.create<arith::SelectOp>(loc, overflow, intMin,
                                               minClamped);
     }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6ca260a5324a9..a10053c31a8e6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -541,13 +541,13 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
 
   // CHECK: linalg.generic
   // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f32
-  // CHECK: [[CSTMIN:%.+]] = arith.constant -2.14748365E+9 : f32
+  // CHECK: [[CSTMINF:%.+]] = arith.constant -2.14748365E+9 : f32
   // CHECK: [[CSTMAXP1:%.+]] = arith.constant 2.14748365E+9 : f32
-  // CHECK: [[CSTMAX:%.+]] = arith.constant 2147483647 : i32
-  // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
+  // CHECK: [[CSTMIN:%.+]] = arith.constant -2147483648 : i32
+  // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMINF]] : f32
   // CHECK: [[CONV:%.+]] = arith.fptosi [[MAX]] : f32 to i32
   // CHECK: [[CMP:%.+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
-  // CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
+  // CHECK: arith.select [[CMP]], [[CSTMIN]], [[CONV]] : i32
   %20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
@@ -591,7 +591,9 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f16
   // CHECK: [[CSTMIN:%.+]] = arith.constant -1.280000e+02 : f16
   // CHECK: [[CSTMAX:%.+]] = arith.constant 1.270000e+02 : f16
-  // CHECK: [[MIN:%.+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
+  // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ugt, [[ROUND]], [[CSTMAX]] : f16
+  // CHECK: [[CLAMPMAX:%.+]] = arith.select [[OVERFLOW]], [[CSTMIN]], [[ROUND]] : f16
+  // CHECK: [[MIN:%.+]] = arith.minimumf [[CLAMPMAX]], [[CSTMAX]] : f16
   // CHECK: [[CLAMP:%.+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
   // CHECK: arith.fptosi [[CLAMP]] : f16 to i8
   %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
@@ -604,8 +606,7 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[POSINF]] : f16
   // CHECK: [[UNDERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[NEGINF]] : f16
   // CHECK: [[MININT:%.+]] = arith.constant -2147483648 : i32
-  // CHECK: [[MAXINT:%.+]] = arith.constant 2147483647 : i32
-  // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MAXINT]], [[CONV]] : i32
+  // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MININT]], [[CONV]] : i32
   // CHECK: arith.select [[UNDERFLOW]], [[MININT]], [[CLAMPPOSINF]] : i32
   %2 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32>
   return
@@ -1980,11 +1981,11 @@ func.func @test_dynamic_fft2d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>
 // CHECK:             %[[ROUND_EVEN:.*]] = math.roundeven %[[IN]] : f32
 // CHECK:             %[[FP_INT_MIN:.*]] = arith.constant -9.22337203E+18 : f32
 // CHECK:             %[[FP_INT_MAX_PLUS_ONE:.*]] = arith.constant 9.22337203E+18 : f32
-// CHECK:             %[[INT_MAX:.*]] = arith.constant 9223372036854775807 : i64
+// CHECK:             %[[INT_MIN:.*]] = arith.constant -9223372036854775808 : i64
 // CHECK:             %[[MAX:.*]] = arith.maximumf %[[ROUND_EVEN]], %[[FP_INT_MIN]] : f32
 // CHECK:             %[[FPTOSI:.*]] = arith.fptosi %[[MAX]] : f32 to i64
 // CHECK:             %[[CMPF:.*]] = arith.cmpf uge, %[[ROUND_EVEN]], %[[FP_INT_MAX_PLUS_ONE]] : f32
-// CHECK:             %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MAX]], %[[FPTOSI]] : i64
+// CHECK:             %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MIN]], %[[FPTOSI]] : i64
 // CHECK:             linalg.yield %[[SELECT]] : i64
 // CHECK:           } -> tensor<1xi64>
 // CHECK:           return %[[RESULT]] : tensor<1xi64>

>From f8585ea09f1221b3b92572500da76672407523e3 Mon Sep 17 00:00:00 2001
From: Alaa Ali <alaaali at ah-alaaali-l.dhcp.mathworks.com>
Date: Thu, 6 Mar 2025 03:36:47 -0500
Subject: [PATCH 2/3] tosa.cast: fix result mismatch when casting f64/f32 max
 value to i64/i32

---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 17 ++++++++---------
 .../TosaToLinalg/tosa-to-linalg.mlir          | 19 ++++++++++---------
 2 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index a99cf293b9eac..8854b4690bdf5 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -618,12 +618,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
             loc, rewriter.getIntegerAttr(
                      getElementTypeOrSelf(dstTy),
                      APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
-        auto intMax = rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getIntegerAttr(
-                     getElementTypeOrSelf(dstTy),
-                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
         auto maxClamped =
-            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
+            rewriter.create<arith::SelectOp>(loc, overflow, intMin, conv);
         return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
                                                 maxClamped);
       }
@@ -647,8 +643,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                      APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                          .getSExtValue()));
 
+        auto overflow = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, rounded, intMaxFP);
+        Value maxClampedFP = rewriter.create<arith::SelectOp>(loc, overflow, intMinFP, rounded);
+
         Value clamped =
-            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
+            clampFloatHelper(loc, maxClampedFP, intMinFP, intMaxFP, rewriter);
         return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
       }
 
@@ -664,17 +663,17 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                            .getSExtValue()) +
                        1.0f));
 
-      auto intMax = rewriter.create<arith::ConstantOp>(
+      auto intMin = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getIntegerAttr(
                    getElementTypeOrSelf(dstTy),
-                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
       auto minClampedFP =
           rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
       auto minClamped =
           rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
       auto overflow = rewriter.create<arith::CmpFOp>(
           loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
-      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
+      return rewriter.create<arith::SelectOp>(loc, overflow, intMin,
                                               minClamped);
     }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 9ba9965315fd3..bd6381bedf65c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -541,13 +541,13 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
 
   // CHECK: linalg.generic
   // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f32
-  // CHECK: [[CSTMIN:%.+]] = arith.constant -2.14748365E+9 : f32
+  // CHECK: [[CSTMINF:%.+]] = arith.constant -2.14748365E+9 : f32
   // CHECK: [[CSTMAXP1:%.+]] = arith.constant 2.14748365E+9 : f32
-  // CHECK: [[CSTMAX:%.+]] = arith.constant 2147483647 : i32
-  // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
+  // CHECK: [[CSTMIN:%.+]] = arith.constant -2147483648 : i32
+  // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMINF]] : f32
   // CHECK: [[CONV:%.+]] = arith.fptosi [[MAX]] : f32 to i32
   // CHECK: [[CMP:%.+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
-  // CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
+  // CHECK: arith.select [[CMP]], [[CSTMIN]], [[CONV]] : i32
   %20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
@@ -591,7 +591,9 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f16
   // CHECK: [[CSTMIN:%.+]] = arith.constant -1.280000e+02 : f16
   // CHECK: [[CSTMAX:%.+]] = arith.constant 1.270000e+02 : f16
-  // CHECK: [[MIN:%.+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
+  // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ugt, [[ROUND]], [[CSTMAX]] : f16
+  // CHECK: [[CLAMPMAX:%.+]] = arith.select [[OVERFLOW]], [[CSTMIN]], [[ROUND]] : f16
+  // CHECK: [[MIN:%.+]] = arith.minimumf [[CLAMPMAX]], [[CSTMAX]] : f16
   // CHECK: [[CLAMP:%.+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
   // CHECK: arith.fptosi [[CLAMP]] : f16 to i8
   %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
@@ -604,8 +606,7 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[POSINF]] : f16
   // CHECK: [[UNDERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[NEGINF]] : f16
   // CHECK: [[MININT:%.+]] = arith.constant -2147483648 : i32
-  // CHECK: [[MAXINT:%.+]] = arith.constant 2147483647 : i32
-  // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MAXINT]], [[CONV]] : i32
+  // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MININT]], [[CONV]] : i32
   // CHECK: arith.select [[UNDERFLOW]], [[MININT]], [[CLAMPPOSINF]] : i32
   %2 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32>
   return
@@ -1980,11 +1981,11 @@ func.func @test_dynamic_fft2d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>
 // CHECK:             %[[ROUND_EVEN:.*]] = math.roundeven %[[IN]] : f32
 // CHECK:             %[[FP_INT_MIN:.*]] = arith.constant -9.22337203E+18 : f32
 // CHECK:             %[[FP_INT_MAX_PLUS_ONE:.*]] = arith.constant 9.22337203E+18 : f32
-// CHECK:             %[[INT_MAX:.*]] = arith.constant 9223372036854775807 : i64
+// CHECK:             %[[INT_MIN:.*]] = arith.constant -9223372036854775808 : i64
 // CHECK:             %[[MAX:.*]] = arith.maximumf %[[ROUND_EVEN]], %[[FP_INT_MIN]] : f32
 // CHECK:             %[[FPTOSI:.*]] = arith.fptosi %[[MAX]] : f32 to i64
 // CHECK:             %[[CMPF:.*]] = arith.cmpf uge, %[[ROUND_EVEN]], %[[FP_INT_MAX_PLUS_ONE]] : f32
-// CHECK:             %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MAX]], %[[FPTOSI]] : i64
+// CHECK:             %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MIN]], %[[FPTOSI]] : i64
 // CHECK:             linalg.yield %[[SELECT]] : i64
 // CHECK:           } -> tensor<1xi64>
 // CHECK:           return %[[RESULT]] : tensor<1xi64>

>From 6a4de5921292117444a108afb0bcca8f73c00cce Mon Sep 17 00:00:00 2001
From: Alaa Ali <alaaali at ah-alaaali-l.dhcp.mathworks.com>
Date: Thu, 6 Mar 2025 22:46:57 -0500
Subject: [PATCH 3/3] Fix code formatting errors

---
 mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8854b4690bdf5..17ebc7dc32372 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -643,8 +643,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                      APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                          .getSExtValue()));
 
-        auto overflow = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, rounded, intMaxFP);
-        Value maxClampedFP = rewriter.create<arith::SelectOp>(loc, overflow, intMinFP, rounded);
+        auto overflow = rewriter.create<arith::CmpFOp>(
+            loc, arith::CmpFPredicate::UGT, rounded, intMaxFP);
+        Value maxClampedFP =
+            rewriter.create<arith::SelectOp>(loc, overflow, intMinFP, rounded);
 
         Value clamped =
             clampFloatHelper(loc, maxClampedFP, intMinFP, intMaxFP, rewriter);



More information about the Mlir-commits mailing list