[Mlir-commits] [mlir] tosa.cast: fix answer mismatch to cast f64/f32 max value to i64/i32 (PR #130116)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 6 06:56:12 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Alaa Ali (alaa-ali)

<details>
<summary>Changes</summary>

This PR fixes an integer overflow issue that occurs when a float32 or float64 value greater than or equal to
float MAX is cast to int32 or int64 during the `tosa-to-linalg` pass for `tosa.cast`.

This issue was found while debugging a numerical mismatch between tf.cast and tfl.cast.
tfl.cast is lowered to tosa.cast, which casts between these types. The expected values were also confirmed in PyTorch, using torch.Tensor.to to cast between the corresponding dtypes. We chose to fix this overflow issue so that the results match those of TensorFlow and PyTorch casting.

Example of casting the F64 min / max values to I64:
EXPECTED (tf.cast results):
[-9223372036854775808,  **-9223372036854775808**]
FOUND (tosa.cast results):
[-9223372036854775808, **9223372036854775807**]

Example of casting the F32 min / max values to I32:
EXPECTED (tf.cast results):
[-2147483648, **-2147483648** ]
FOUND (tosa.cast results):
[-2147483648, **2147483647** ]

---
Full diff: https://github.com/llvm/llvm-project/pull/130116.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+8-9) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+10-9) 


``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8732ddafa24d4..8085f1104a4cb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -618,12 +618,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
             loc, rewriter.getIntegerAttr(
                      getElementTypeOrSelf(dstTy),
                      APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
-        auto intMax = rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getIntegerAttr(
-                     getElementTypeOrSelf(dstTy),
-                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
         auto maxClamped =
-            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
+            rewriter.create<arith::SelectOp>(loc, overflow, intMin, conv);
         return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
                                                 maxClamped);
       }
@@ -647,8 +643,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                      APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                          .getSExtValue()));
 
+        auto overflow = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, rounded, intMaxFP);
+        Value maxClampedFP = rewriter.create<arith::SelectOp>(loc, overflow, intMinFP, rounded);
+
         Value clamped =
-            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
+            clampFloatHelper(loc, maxClampedFP, intMinFP, intMaxFP, rewriter);
         return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
       }
 
@@ -664,17 +663,17 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                            .getSExtValue()) +
                        1.0f));
 
-      auto intMax = rewriter.create<arith::ConstantOp>(
+      auto intMin = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getIntegerAttr(
                    getElementTypeOrSelf(dstTy),
-                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
       auto minClampedFP =
           rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
       auto minClamped =
           rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
       auto overflow = rewriter.create<arith::CmpFOp>(
           loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
-      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
+      return rewriter.create<arith::SelectOp>(loc, overflow, intMin,
                                               minClamped);
     }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6ca260a5324a9..a10053c31a8e6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -541,13 +541,13 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
 
   // CHECK: linalg.generic
   // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f32
-  // CHECK: [[CSTMIN:%.+]] = arith.constant -2.14748365E+9 : f32
+  // CHECK: [[CSTMINF:%.+]] = arith.constant -2.14748365E+9 : f32
   // CHECK: [[CSTMAXP1:%.+]] = arith.constant 2.14748365E+9 : f32
-  // CHECK: [[CSTMAX:%.+]] = arith.constant 2147483647 : i32
-  // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
+  // CHECK: [[CSTMIN:%.+]] = arith.constant -2147483648 : i32
+  // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMINF]] : f32
   // CHECK: [[CONV:%.+]] = arith.fptosi [[MAX]] : f32 to i32
   // CHECK: [[CMP:%.+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
-  // CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
+  // CHECK: arith.select [[CMP]], [[CSTMIN]], [[CONV]] : i32
   %20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
@@ -591,7 +591,9 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f16
   // CHECK: [[CSTMIN:%.+]] = arith.constant -1.280000e+02 : f16
   // CHECK: [[CSTMAX:%.+]] = arith.constant 1.270000e+02 : f16
-  // CHECK: [[MIN:%.+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
+  // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ugt, [[ROUND]], [[CSTMAX]] : f16
+  // CHECK: [[CLAMPMAX:%.+]] = arith.select [[OVERFLOW]], [[CSTMIN]], [[ROUND]] : f16
+  // CHECK: [[MIN:%.+]] = arith.minimumf [[CLAMPMAX]], [[CSTMAX]] : f16
   // CHECK: [[CLAMP:%.+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
   // CHECK: arith.fptosi [[CLAMP]] : f16 to i8
   %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
@@ -604,8 +606,7 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[POSINF]] : f16
   // CHECK: [[UNDERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[NEGINF]] : f16
   // CHECK: [[MININT:%.+]] = arith.constant -2147483648 : i32
-  // CHECK: [[MAXINT:%.+]] = arith.constant 2147483647 : i32
-  // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MAXINT]], [[CONV]] : i32
+  // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MININT]], [[CONV]] : i32
   // CHECK: arith.select [[UNDERFLOW]], [[MININT]], [[CLAMPPOSINF]] : i32
   %2 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32>
   return
@@ -1980,11 +1981,11 @@ func.func @test_dynamic_fft2d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>
 // CHECK:             %[[ROUND_EVEN:.*]] = math.roundeven %[[IN]] : f32
 // CHECK:             %[[FP_INT_MIN:.*]] = arith.constant -9.22337203E+18 : f32
 // CHECK:             %[[FP_INT_MAX_PLUS_ONE:.*]] = arith.constant 9.22337203E+18 : f32
-// CHECK:             %[[INT_MAX:.*]] = arith.constant 9223372036854775807 : i64
+// CHECK:             %[[INT_MIN:.*]] = arith.constant -9223372036854775808 : i64
 // CHECK:             %[[MAX:.*]] = arith.maximumf %[[ROUND_EVEN]], %[[FP_INT_MIN]] : f32
 // CHECK:             %[[FPTOSI:.*]] = arith.fptosi %[[MAX]] : f32 to i64
 // CHECK:             %[[CMPF:.*]] = arith.cmpf uge, %[[ROUND_EVEN]], %[[FP_INT_MAX_PLUS_ONE]] : f32
-// CHECK:             %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MAX]], %[[FPTOSI]] : i64
+// CHECK:             %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MIN]], %[[FPTOSI]] : i64
 // CHECK:             linalg.yield %[[SELECT]] : i64
 // CHECK:           } -> tensor<1xi64>
 // CHECK:           return %[[RESULT]] : tensor<1xi64>

``````````

</details>


https://github.com/llvm/llvm-project/pull/130116


More information about the Mlir-commits mailing list