[llvm-branch-commits] [mlir] [mlir][arith] Add support for `arith.flush_denormals` emulation (PR #192660)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sat Apr 18 04:10:55 PDT 2026
================
@@ -729,6 +742,109 @@ struct ScalingTruncFOpConverter
}
};
+/// Expands `arith.flush_denormals` into integer arithmetic.
+///
+/// For an IEEE-like floating-point value with a sign|exponent|mantissa
+/// bit layout, this mirrors `APFloat::isDenormal`: a value is denormal
+/// iff its biased exponent field is zero *and* its stored mantissa is
+/// non-zero. Denormal inputs are replaced by a zero of the same sign
+/// (i.e. the operand's bits AND'ed with the sign-bit mask); all other
+/// inputs, including zeros, infinities and NaNs, pass through unchanged.
+///
+/// Pseudocode:
+/// bits = bitcast(x, iN)
+/// expField = bits & expMask
+/// manField = bits & manMask
+/// isDenormal = (expField == 0) AND (manField != 0)
+/// signZero = bits & signMask
+/// resultBits = select(isDenormal, signZero, bits)
+/// result = bitcast(resultBits, floatTy)
+struct FlushDenormalsOpConverter
+ : public OpRewritePattern<arith::FlushDenormalsOp> {
+ using Base::Base;
+ LogicalResult matchAndRewrite(arith::FlushDenormalsOp op,
+ PatternRewriter &rewriter) const final {
+ Location loc = op.getLoc();
+ ImplicitLocOpBuilder b(loc, rewriter);
+ Value operand = op.getOperand();
+ Type operandTy = operand.getType();
+ auto floatTy = dyn_cast<FloatType>(getElementTypeOrSelf(operandTy));
+ if (!floatTy)
+ return rewriter.notifyMatchFailure(op, "operand is not a float type");
+
+ const llvm::fltSemantics &sem = floatTy.getFloatSemantics();
+ // Restrict to IEEE-like encodings, where the sign bit is the MSB and
+ // denormals are exactly "biased exponent == 0 and non-zero mantissa".
+ if (!llvm::APFloatBase::isIEEELikeFP(sem))
+ return rewriter.notifyMatchFailure(
+ op, "only IEEE-like floating-point types are supported");
+
+ unsigned totalBits = llvm::APFloatBase::semanticsSizeInBits(sem);
+ unsigned precision = llvm::APFloatBase::semanticsPrecision(sem);
+ // Stored mantissa bits = precision - 1 (implicit leading bit not stored).
+ // Exponent field bits = totalBits - 1 (sign) - storedMantissa.
+ if (precision < 1 || precision > totalBits)
+ return rewriter.notifyMatchFailure(op, "unexpected float semantics");
+ unsigned mantissaBits = precision - 1;
+ unsigned expBits = totalBits - 1 - mantissaBits;
+ if (expBits == 0 || mantissaBits == 0)
+ return rewriter.notifyMatchFailure(
+ op, "degenerate float encoding has no exponent or mantissa");
+
+ Type intTy =
+ cloneToShapedType(operandTy, rewriter.getIntegerType(totalBits));
+ Value bits = arith::BitcastOp::create(b, intTy, operand);
+
+ // Build bit masks using APInt to support widths like 64 bits that don't
+ // fit into an `int` parameter.
+ APInt mantissaMaskVal = APInt::getLowBitsSet(totalBits, mantissaBits);
+ APInt expMaskVal =
+ APInt::getBitsSet(totalBits, mantissaBits, mantissaBits + expBits);
+ APInt signMaskVal = APInt::getOneBitSet(totalBits, totalBits - 1);
+ APInt zeroVal = APInt::getZero(totalBits);
+
+ Value mantissaMask =
+ createAPIntConst(loc, intTy, mantissaMaskVal, rewriter);
+ Value expMask = createAPIntConst(loc, intTy, expMaskVal, rewriter);
+ Value signMask = createAPIntConst(loc, intTy, signMaskVal, rewriter);
+ Value zero = createAPIntConst(loc, intTy, zeroVal, rewriter);
+
+ // expField == 0
+ Value expField = arith::AndIOp::create(b, bits, expMask);
+ Value expIsZero =
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, expField, zero);
+
+ // mantissaField != 0
+ Value mantissaField = arith::AndIOp::create(b, bits, mantissaMask);
+ Value mantissaNonZero =
----------------
matthias-springer wrote:
Good idea, I didn't think of that.
https://github.com/llvm/llvm-project/pull/192660
More information about the llvm-branch-commits
mailing list