[Mlir-commits] [mlir] Fix TOSA FP16->INT16 CAST lowering (PR #79299)

Thomas Preud'homme llvmlistbot at llvm.org
Wed Jan 24 13:36:27 PST 2024


================
@@ -480,23 +480,53 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     }
 
     if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
-      auto intMin = rewriter.create<arith::ConstantOp>(
+      auto intMinFP = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getFloatAttr(
                    getElementTypeOrSelf(srcTy),
                    APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                        .getSExtValue()));
 
-      auto intMax = rewriter.create<arith::ConstantOp>(
+      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+
+      // The input floating-point type has enough mantissa bits to represent
+      // the max int value so just clamp the input in the floating-point
+      // domain and convert to int. Note: the min value can be represented
+      // because it consists of a mantissa with only the lsb set.
+      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
+          dstTy.getIntOrFloatBitWidth() - 1) {
----------------
RoboTux wrote:

> Can you expand on the INT_MIN case? I'd think that the minimum 16-bit int would be -2^15 and if the FP format has a dedicated sign bit (and in turn +0 and -0), the mantissa needs to have 16 bits as well to represent 2^15 (i.e., sign bit set, msb mantissa bit set).

-2¹⁵ is -(1 << 15) [1], so we only need 1 bit of mantissa, but you made me realize that I didn't consider whether the exponent can be encoded. For FP16 the bias is 15 and the exponent has 5 bits, which allows a max exponent of 2⁵-1-1-15 = 30-15 = 15 (the all-ones exponent field is reserved for infinities and NaNs), so all good. But clearly FP8->I16 would be a problem, and likewise FP16->I32, though in both cases we don't need to clamp since the integer type can encode every integer value that the floating-point type can encode. So I think we need more cases:
- no need to clamp (e.g. FP8 -> I16 or FP16 -> I32)
- clamp needed but int min cannot be encoded in FP
- clamp needed and both min and max can be encoded in FP
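
A minimal sketch of the two representability questions raised above, assuming an IEEE-754-style format in which the all-ones exponent is reserved for infinities and NaNs. `FPFormat`, `maxExponent`, `canEncodeIntMin` and `mantissaCoversIntMax` are hypothetical names used for illustration only; they are not MLIR or LLVM APIs.

```cpp
#include <cassert>

// Hypothetical description of an IEEE-754-style binary format.
struct FPFormat {
  int exponentBits; // width of the exponent field
  int precision;    // mantissa width, including the implicit leading bit
};

// Illustrative format descriptions (assumed parameters for FP16 and FP32).
constexpr FPFormat kFP16{5, 11};
constexpr FPFormat kFP32{8, 24};

// Largest unbiased exponent of a finite value: 2^e - 1 - 1 - bias,
// with bias = 2^(e-1) - 1.
constexpr int maxExponent(FPFormat f) {
  int bias = (1 << (f.exponentBits - 1)) - 1;
  return (1 << f.exponentBits) - 1 - 1 - bias;
}

// -2^(n-1) needs a single mantissa bit, so only the exponent range matters.
constexpr bool canEncodeIntMin(FPFormat f, int intBits) {
  return maxExponent(f) >= intBits - 1;
}

// 2^(n-1) - 1 is (n-1) consecutive one bits, so the mantissa must hold
// at least n-1 bits. This mirrors the width check in the diff above.
constexpr bool mantissaCoversIntMax(FPFormat f, int intBits) {
  return f.precision >= intBits - 1;
}
```

Under these assumptions, FP16 can encode the I16 minimum (max exponent 15) but not the I32 minimum, and its 11-bit mantissa is too narrow for the I16 maximum, whereas FP32 covers both I16 bounds.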

[1] My understanding is that the real (as in real numbers) value of a non-subnormal floating-point value is (-1)^sign * 1.mantissa * 2^(exponent - bias).
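
As a worked instance of the footnote, here is a small sketch that decodes a normal FP16 bit pattern using the standard IEEE-754 binary16 layout (1 sign bit, 5 exponent bits, 10 stored mantissa bits, bias 15). The helper name `decodeNormalFP16` is hypothetical, and the function deliberately ignores subnormals, infinities and NaNs.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Decode a normal IEEE-754 binary16 bit pattern into a double.
// Subnormals, infinities and NaNs are intentionally not handled.
double decodeNormalFP16(uint16_t bits) {
  int sign = (bits >> 15) & 0x1;
  int exponent = (bits >> 10) & 0x1F; // biased exponent field
  int fraction = bits & 0x3FF;        // 10 stored mantissa bits
  const int bias = 15;
  double significand = 1.0 + fraction / 1024.0; // add implicit leading 1
  double magnitude = std::ldexp(significand, exponent - bias);
  return sign ? -magnitude : magnitude;
}
```

With this layout, the pattern `0xF800` (sign set, exponent field 30, mantissa 0) decodes to -2¹⁵, which is the I16 minimum discussed above.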
> 
> I haven't thought about this too much so maybe I'm missing something, but shouldn't the condition be
> 
> ```
> if (MantissaWidth() >= IntBitWidth())
> ```

For an n-bit signed int, the max int value is a 0 followed by (n-1) 1 bits, so the mantissa width must be at least IntBitWidth() - 1. Note that the mantissa width includes any implicit leading bit, so we don't have to worry about that.
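
The width argument can be checked empirically with a round trip through `float`. This is only a sketch, and it assumes `float` is IEEE-754 binary32 with a 24-bit mantissa (implicit bit included); the helper name `exactlyRepresentableInFloat` is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

static_assert(std::numeric_limits<float>::digits == 24,
              "sketch assumes a 24-bit float mantissa");

// True when v survives conversion to float unchanged. The comparison is done
// in double, which represents every value used here exactly.
bool exactlyRepresentableInFloat(int64_t v) {
  float f = static_cast<float>(v);
  return static_cast<double>(f) == static_cast<double>(v);
}
```

Here the I16 maximum needs 15 mantissa bits and fits in 24, while the I32 maximum needs 31 and does not. The I32 minimum is still exact because it is a single power of two.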

https://github.com/llvm/llvm-project/pull/79299

