[Mlir-commits] [mlir] Fix TOSA FP16->INT16 CAST lowering (PR #79299)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Jan 29 10:20:58 PST 2024
================
@@ -480,23 +480,85 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
- auto intMin = rewriter.create<arith::ConstantOp>(
+ auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+
+ const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
+ // The range of integer values is wider than floating-point integral
+ // values, so we only need to clamp infinite values.
+ if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
+ APFloat::semanticsMaxExponent(fltSemantics)) {
+ auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
+ auto posInf = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
+ APFloat::getInf(fltSemantics)));
+ auto negInf = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ APFloat::getInf(fltSemantics, /*Negative=*/true)));
+ auto overflow = rewriter.create<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::UEQ, rounded, posInf);
+ auto underflow = rewriter.create<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::UEQ, rounded, negInf);
+ auto intMin = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
+ auto intMax = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+ auto maxClamped =
+ rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
+ return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
+ maxClamped);
+ }
+
+ auto intMinFP = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(
getElementTypeOrSelf(srcTy),
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue()));
- auto intMax = rewriter.create<arith::ConstantOp>(
+ // The input floating-point type has enough mantissa bits to represent
+ // the max int value (n-1 bits set for an n-bit integer), so just clamp the
+ // input in the floating-point domain and convert to int. Note: the min
+ // value can be represented in the mantissa because, being a power of 2,
+ // it consists of a single leading bit.
+ if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
+ dstTy.getIntOrFloatBitWidth() - 1) {
----------------
kuhar wrote:
Previous comment https://github.com/llvm/llvm-project/pull/79299#discussion_r1465461861 not addressed/resolved
https://github.com/llvm/llvm-project/pull/79299
More information about the Mlir-commits
mailing list