[Mlir-commits] [mlir] b23e518 - Fix TOSA FP16->INT16 CAST lowering (#79299)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 30 08:06:12 PST 2024


Author: Thomas Preud'homme
Date: 2024-01-30T16:06:08Z
New Revision: b23e518ce0df5b0835aba245cda50379bd896374

URL: https://github.com/llvm/llvm-project/commit/b23e518ce0df5b0835aba245cda50379bd896374
DIFF: https://github.com/llvm/llvm-project/commit/b23e518ce0df5b0835aba245cda50379bd896374.diff

LOG: Fix TOSA FP16->INT16 CAST lowering (#79299)

Currently cast from FP to int is implemented by clamping on the min and
max integer values in the floating-point domain and then converting to
integer. However, the max int values are often not representable in the
floating-point input type due to lack of mantissa bits.

This patch instead uses a select acting on a compare against max int + 1,
which is representable in floating-point. It also has a special lowering
for cases where the integer range is wider than the floating-point range,
which clamps infinite inputs to int min / int max.

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 647592395c876..1eb5678b41755 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -480,23 +480,88 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     }
 
     if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
-      auto intMin = rewriter.create<arith::ConstantOp>(
+      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+
+      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
+      // Check whether neither int min nor int max can be represented in the
+      // input floating-point type due to too short exponent range.
+      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
+          APFloat::semanticsMaxExponent(fltSemantics)) {
+        // Use cmp + select to replace infinites by int min / int max. Other
+        // integral values can be represented in the integer space.
+        auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
+        auto posInf = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
+                                       APFloat::getInf(fltSemantics)));
+        auto negInf = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getFloatAttr(
+                     getElementTypeOrSelf(srcTy),
+                     APFloat::getInf(fltSemantics, /*Negative=*/true)));
+        auto overflow = rewriter.create<arith::CmpFOp>(
+            loc, arith::CmpFPredicate::UEQ, rounded, posInf);
+        auto underflow = rewriter.create<arith::CmpFOp>(
+            loc, arith::CmpFPredicate::UEQ, rounded, negInf);
+        auto intMin = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getIntegerAttr(
+                     getElementTypeOrSelf(dstTy),
+                     APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
+        auto intMax = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getIntegerAttr(
+                     getElementTypeOrSelf(dstTy),
+                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+        auto maxClamped =
+            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
+        return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
+                                                maxClamped);
+      }
+
+      auto intMinFP = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getFloatAttr(
                    getElementTypeOrSelf(srcTy),
                    APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                        .getSExtValue()));
 
-      auto intMax = rewriter.create<arith::ConstantOp>(
+      // Check whether the mantissa has enough bits to represent int max.
+      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
+          dstTy.getIntOrFloatBitWidth() - 1) {
+        // Int min can also be represented since it is a power of two and thus
+        // consists of a single leading bit. Therefore we can clamp the input
+        // in the floating-point domain.
+
+        auto intMaxFP = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getFloatAttr(
+                     getElementTypeOrSelf(srcTy),
+                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+                         .getSExtValue()));
+
+        Value clamped =
+            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
+        return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+      }
+
+      // Due to earlier check we know exponant range is big enough to represent
+      // int min. We can therefore rely on int max + 1 being representable as
+      // well because it's just int min with a positive sign. So clamp the min
+      // value and compare against that to select the max int value if needed.
+      auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getFloatAttr(
                    getElementTypeOrSelf(srcTy),
                    APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
-                       .getSExtValue()));
-
-      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
-
-      auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
+                           .getSExtValue() +
+                       1));
 
-      return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+      auto intMax = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(
+                   getElementTypeOrSelf(dstTy),
+                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+      auto minClampedFP =
+          rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
+      auto minClamped =
+          rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
+      auto overflow = rewriter.create<arith::CmpFOp>(
+          loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
+      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
+                                              minClamped);
     }
 
     // Casting to boolean, integers need to only be checked as not-equal to

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1f63b7d5ca6c8..fc22a436526a6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -514,12 +514,14 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   %19 = tosa.sigmoid %0 : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
-  // CHECK: arith.constant -2.14748365E+9
-  // CHECK: arith.constant 2.14748365E+9
-  // CHECK: math.roundeven
-  // CHECK: arith.minimumf
-  // CHECK: arith.maximumf
-  // CHECK: arith.fptosi
+  // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f32
+  // CHECK: [[CSTMIN:%.+]] = arith.constant -2.14748365E+9 : f32
+  // CHECK: [[CSTMAXP1:%.+]] = arith.constant 2.14748365E+9 : f32
+  // CHECK: [[CSTMAX:%.+]] = arith.constant 2147483647 : i32
+  // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
+  // CHECK: [[CONV:%.+]] = arith.fptosi [[MAX]] : f32 to i32
+  // CHECK: [[CMP:%.+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
+  // CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
   %20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
@@ -552,13 +554,26 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
   %0 = tosa.cast %arg0 : (tensor<1xf16>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
-  // CHECK: arith.constant -1.280000e+02
-  // CHECK: arith.constant 1.270000e+02
-  // CHECK: math.roundeven
-  // CHECK: arith.minimumf
-  // CHECK: arith.maximumf
-  // CHECK: arith.fptosi
+  // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f16
+  // CHECK: [[CSTMIN:%.+]] = arith.constant -1.280000e+02 : f16
+  // CHECK: [[CSTMAX:%.+]] = arith.constant 1.270000e+02 : f16
+  // CHECK: [[MIN:%.+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
+  // CHECK: [[CLAMP:%.+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
+  // CHECK: arith.fptosi [[CLAMP]] : f16 to i8
   %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
+
+  // CHECK: linalg.generic
+  // CHECK: [[ROUND:%.+]] = math.roundeven {{%[a-z0-9_]+}} : f16
+  // CHECK: [[CONV:%.+]] = arith.fptosi [[ROUND]] : f16 to i32
+  // CHECK: [[POSINF:%.+]] = arith.constant 0x7C00 : f16
+  // CHECK: [[NEGINF:%.+]] = arith.constant 0xFC00 : f16
+  // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[POSINF]] : f16
+  // CHECK: [[UNDERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[NEGINF]] : f16
+  // CHECK: [[MININT:%.+]] = arith.constant -2147483648 : i32
+  // CHECK: [[MAXINT:%.+]] = arith.constant 2147483647 : i32
+  // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MAXINT]], [[CONV]] : i32
+  // CHECK: arith.select [[UNDERFLOW]], [[MININT]], [[CLAMPPOSINF]] : i32
+  %2 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32>
   return
 }
 


        


More information about the Mlir-commits mailing list