[Mlir-commits] [mlir] Fix TOSA FP16->INT16 CAST lowering (PR #79299)

Jakub Kuderski llvmlistbot at llvm.org
Mon Jan 29 10:20:58 PST 2024


================
@@ -480,23 +480,85 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     }
 
     if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
-      auto intMin = rewriter.create<arith::ConstantOp>(
+      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+
+      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
+      // The range of integer values is wider than floating-point integral
+      // values so we only need to clamp infinite values.
+      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
+          APFloat::semanticsMaxExponent(fltSemantics)) {
+        auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
+        auto posInf = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
+                                       APFloat::getInf(fltSemantics)));
+        auto negInf = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getFloatAttr(
+                     getElementTypeOrSelf(srcTy),
+                     APFloat::getInf(fltSemantics, /*Negative=*/true)));
+        auto overflow = rewriter.create<arith::CmpFOp>(
+            loc, arith::CmpFPredicate::UEQ, rounded, posInf);
+        auto underflow = rewriter.create<arith::CmpFOp>(
+            loc, arith::CmpFPredicate::UEQ, rounded, negInf);
+        auto intMin = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getIntegerAttr(
+                     getElementTypeOrSelf(dstTy),
+                     APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
+        auto intMax = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getIntegerAttr(
+                     getElementTypeOrSelf(dstTy),
+                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+        auto maxClamped =
+            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
+        return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
+                                                maxClamped);
+      }
+
+      auto intMinFP = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getFloatAttr(
                    getElementTypeOrSelf(srcTy),
                    APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                        .getSExtValue()));
 
-      auto intMax = rewriter.create<arith::ConstantOp>(
+      // The input floating-point type has enough mantissa bits to represent
+      // the max int value (n-1 bits set for a n-bit integer) so just clamp the
+      // input in the floating-point domain and convert to int. Note: the min
+      // value can be represented in the mantissa because, being a power of 2,
+      // it consists of a single leading bit.
+      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
+          dstTy.getIntOrFloatBitWidth() - 1) {
----------------
kuhar wrote:

The previous comment (https://github.com/llvm/llvm-project/pull/79299#discussion_r1465461861) has not been addressed or resolved.

https://github.com/llvm/llvm-project/pull/79299


More information about the Mlir-commits mailing list