[llvm-branch-commits] [mlir] [mlir][arith] Add support for `arith.flush_denormals` emulation (PR #192660)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sat Apr 18 02:26:17 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-arith
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Add a lowering pattern and a new pass `arith-expand-flush-denormals` that rewrites `arith.flush_denormals` ops into integer arithmetic. This lowering is useful for target architectures that cannot pattern-match `arith.flush_denormals` + other FP arithmetic into special instructions with FTZ (flush-to-zero) semantics.
Assisted-by: claude-opus-4.7-thinking-high
---
Full diff: https://github.com/llvm/llvm-project/pull/192660.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.h (+5)
- (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.td (+15)
- (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+121)
- (added) mlir/test/Dialect/Arith/expand-flush-denormals.mlir (+108)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 18ac0dbc8d13e..5a07a01d0928a 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -68,6 +68,11 @@ void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
/// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns);
+/// Add patterns to expand `arith.flush_denormals` into integer arithmetic
+/// (bitcast + bit masks + compare + select). Only matches IEEE-like
+/// floating-point types; ops on any other element type are left untouched
+/// because the pattern reports a match failure for them.
+void populateExpandFlushDenormalsPatterns(RewritePatternSet &patterns);
+
/// Add patterns to expand Arith ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index c7370b83fdb6c..84986bd94a4ac 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -97,6 +97,21 @@ def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
let dependentDialects = ["vector::VectorDialect"];
}
+// Standalone pass that applies only the `arith.flush_denormals` expansion
+// pattern (populateExpandFlushDenormalsPatterns); it creates only arith ops.
+def ArithExpandFlushDenormalsPass : Pass<"arith-expand-flush-denormals"> {
+ let summary = "Expand arith.flush_denormals to integer arithmetic";
+ let description = [{
+ Lower `arith.flush_denormals` operations on IEEE-like floating-point
+ types to a bitcast + bit-mask + compare + select sequence using integer
+ arithmetic. A value is denormal iff its biased exponent field is zero and
+ its stored mantissa is non-zero. Denormal inputs are replaced by a
+ sign-preserved zero; all other inputs pass through unchanged.
+
+ This pass leaves `arith.flush_denormals` operating on non-IEEE-like
+ float types untouched.
+ }];
+ let dependentDialects = ["::mlir::arith::ArithDialect"];
+}
+
def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
let summary = "Emulate 2*N-bit integer operations using N-bit operations";
let description = [{
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 46f8c1037d47b..753f37b009870 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -13,10 +13,12 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
namespace mlir {
namespace arith {
#define GEN_PASS_DEF_ARITHEXPANDOPSPASS
+#define GEN_PASS_DEF_ARITHEXPANDFLUSHDENORMALSPASS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace arith
} // namespace mlir
@@ -34,6 +36,17 @@ static Value createConst(Location loc, Type type, int value,
return arith::ConstantOp::create(rewriter, loc, attr);
}
+/// Materialize an integer constant holding `value` at `loc`.
+/// For a shaped `type` (vector/tensor) the constant is a splat of `value`
+/// over that shape; otherwise it is a scalar of `type`.
+static Value createAPIntConst(Location loc, Type type, const APInt &value,
+                              PatternRewriter &rewriter) {
+  Type elemTy = getElementTypeOrSelf(type);
+  IntegerAttr scalarAttr = IntegerAttr::get(elemTy, value);
+  auto shapedTy = dyn_cast<ShapedType>(type);
+  if (!shapedTy)
+    return arith::ConstantOp::create(rewriter, loc, scalarAttr);
+  return arith::ConstantOp::create(
+      rewriter, loc, DenseElementsAttr::get(shapedTy, scalarAttr));
+}
+
/// Create a float constant.
static Value createFloatConst(Location loc, Type type, const APFloat &value,
PatternRewriter &rewriter) {
@@ -729,6 +742,109 @@ struct ScalingTruncFOpConverter
}
};
+/// Lowers `arith.flush_denormals` to pure integer arithmetic.
+///
+/// Applies to IEEE-like floating-point types, whose bit layout is
+/// sign|exponent|mantissa with the sign in the MSB. Following the same rule
+/// as `APFloat::isDenormal`, an input is denormal when its biased exponent
+/// field is all zeros while its stored mantissa field is not. Such inputs
+/// become a zero carrying the input's sign (the input bits masked down to the
+/// sign bit); every other input is forwarded bit-for-bit.
+///
+/// Emitted sequence (N = bit width of the float type):
+///   bits       = bitcast(x) : iN
+///   isDenormal = ((bits & expMask) == 0) && ((bits & manMask) != 0)
+///   resultBits = isDenormal ? (bits & signMask) : bits
+///   result     = bitcast(resultBits) : type of x
+struct FlushDenormalsOpConverter
+    : public OpRewritePattern<arith::FlushDenormalsOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(arith::FlushDenormalsOp op,
+                                PatternRewriter &rewriter) const final {
+    Value input = op.getOperand();
+    Type inputTy = input.getType();
+
+    auto elemFloatTy = dyn_cast<FloatType>(getElementTypeOrSelf(inputTy));
+    if (!elemFloatTy)
+      return rewriter.notifyMatchFailure(op, "operand is not a float type");
+
+    // Only IEEE-like encodings guarantee an MSB sign bit and the
+    // "zero exponent + non-zero mantissa" definition of a denormal.
+    const llvm::fltSemantics &semantics = elemFloatTy.getFloatSemantics();
+    if (!llvm::APFloatBase::isIEEELikeFP(semantics))
+      return rewriter.notifyMatchFailure(
+          op, "only IEEE-like floating-point types are supported");
+
+    // Field widths: the mantissa stores precision - 1 bits (the leading bit
+    // is implicit); the exponent takes what is left after the sign bit.
+    const unsigned width = llvm::APFloatBase::semanticsSizeInBits(semantics);
+    const unsigned precision =
+        llvm::APFloatBase::semanticsPrecision(semantics);
+    if (precision == 0 || width < precision)
+      return rewriter.notifyMatchFailure(op, "unexpected float semantics");
+    const unsigned manWidth = precision - 1;
+    const unsigned expWidth = width - manWidth - 1;
+    if (manWidth == 0 || expWidth == 0)
+      return rewriter.notifyMatchFailure(
+          op, "degenerate float encoding has no exponent or mantissa");
+
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+    Type intTy = cloneToShapedType(inputTy, rewriter.getIntegerType(width));
+    Value bits = arith::BitcastOp::create(b, intTy, input);
+
+    // Constants are built from APInts so that 64-bit (and wider) layouts
+    // work; `createConst` only takes an `int`.
+    auto intConst = [&](const APInt &constBits) {
+      return createAPIntConst(loc, intTy, constBits, rewriter);
+    };
+    Value manMask = intConst(APInt::getLowBitsSet(width, manWidth));
+    Value expMask =
+        intConst(APInt::getBitsSet(width, manWidth, manWidth + expWidth));
+    Value signMask = intConst(APInt::getSignMask(width));
+    Value zero = intConst(APInt::getZero(width));
+
+    // Denormal <=> exponent field all zeros and mantissa field non-zero.
+    Value expBits = arith::AndIOp::create(b, bits, expMask);
+    Value hasZeroExp =
+        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, expBits, zero);
+    Value manBits = arith::AndIOp::create(b, bits, manMask);
+    Value hasNonZeroMan =
+        arith::CmpIOp::create(b, arith::CmpIPredicate::ne, manBits, zero);
+    Value isDenormal = arith::AndIOp::create(b, hasZeroExp, hasNonZeroMan);
+
+    // Keeping only the sign bit turns a denormal into +0.0 or -0.0.
+    Value signedZero = arith::AndIOp::create(b, bits, signMask);
+    Value flushedBits =
+        arith::SelectOp::create(b, isDenormal, signedZero, bits);
+    Value flushed = arith::BitcastOp::create(b, inputTy, flushedBits);
+
+    rewriter.replaceOp(op, flushed);
+    return success();
+  }
+};
+
+/// Pass wrapper that applies only the `arith.flush_denormals` expansion.
+struct ArithExpandFlushDenormalsPass final
+    : public arith::impl::ArithExpandFlushDenormalsPassBase<
+          ArithExpandFlushDenormalsPass> {
+  using ArithExpandFlushDenormalsPassBase::ArithExpandFlushDenormalsPassBase;
+
+  void runOnOperation() override {
+    RewritePatternSet patternSet(&getContext());
+    arith::populateExpandFlushDenormalsPatterns(patternSet);
+    // One walk is enough: the expansion only emits ops that the pattern does
+    // not match, so no fixed-point iteration is required.
+    walkAndApplyPatterns(getOperation(), std::move(patternSet));
+  }
+};
+
struct ArithExpandOpsPass
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -831,6 +947,11 @@ void mlir::arith::populateExpandScalingExtTruncPatterns(
patterns.getContext());
}
+// Registers the single rewrite that lowers `arith.flush_denormals`.
+void mlir::arith::populateExpandFlushDenormalsPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<FlushDenormalsOpConverter>(ctx);
+}
+
void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
populateCeilFloorDivExpandOpsPatterns(patterns);
populateExpandScalingExtTruncPatterns(patterns);
diff --git a/mlir/test/Dialect/Arith/expand-flush-denormals.mlir b/mlir/test/Dialect/Arith/expand-flush-denormals.mlir
new file mode 100644
index 0000000000000..add6a5adc808b
--- /dev/null
+++ b/mlir/test/Dialect/Arith/expand-flush-denormals.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s -arith-expand-flush-denormals -split-input-file | FileCheck %s
+
+// Expansion for f32:
+// sign mask = 0x80000000 (10000000000000000000000000000000)
+// exp mask = 0x7f800000 (01111111100000000000000000000000)
+// mantissa mask = 0x007fffff (00000000011111111111111111111111)
+// Bit pattern for denormal: zero exponent and non-zero mantissa
+// Bit pattern for zero: 0/1 sign bit, remaining bits zero
+
+// Scalar f32: every op of the expansion is matched in order, including the
+// operand order of each andi/cmpi/select.
+// CHECK-LABEL: func @flush_denormals_f32
+// CHECK-SAME: (%[[ARG0:.+]]: f32) -> f32
+// CHECK: %[[BITS:.+]] = arith.bitcast %[[ARG0]] : f32 to i32
+// CHECK: %[[MAN_MASK:.+]] = arith.constant 8388607 : i32
+// CHECK: %[[EXP_MASK:.+]] = arith.constant 2139095040 : i32
+// CHECK: %[[SIGN_MASK:.+]] = arith.constant -2147483648 : i32
+// CHECK: %[[ZERO:.+]] = arith.constant 0 : i32
+// CHECK: %[[EXP:.+]] = arith.andi %[[BITS]], %[[EXP_MASK]] : i32
+// CHECK: %[[EXP_ZERO:.+]] = arith.cmpi eq, %[[EXP]], %[[ZERO]] : i32
+// CHECK: %[[MAN:.+]] = arith.andi %[[BITS]], %[[MAN_MASK]] : i32
+// CHECK: %[[MAN_NONZERO:.+]] = arith.cmpi ne, %[[MAN]], %[[ZERO]] : i32
+// CHECK: %[[IS_DEN:.+]] = arith.andi %[[EXP_ZERO]], %[[MAN_NONZERO]] : i1
+// CHECK: %[[SIGN_ONLY:.+]] = arith.andi %[[BITS]], %[[SIGN_MASK]] : i32
+// CHECK: %[[RES_BITS:.+]] = arith.select %[[IS_DEN]], %[[SIGN_ONLY]], %[[BITS]] : i32
+// CHECK: %[[RES:.+]] = arith.bitcast %[[RES_BITS]] : i32 to f32
+// CHECK: return %[[RES]] : f32
+func.func @flush_denormals_f32(%arg0: f32) -> f32 {
+ %0 = arith.flush_denormals %arg0 : f32
+ return %0 : f32
+}
+
+// -----
+
+// Expansion for bf16 (8 exponent bits, 7 stored mantissa bits); only the
+// bitcasts and the mask constants are matched:
+// sign mask = 0x8000
+// exp mask = 0x7f80
+// mantissa mask = 0x007f
+
+// CHECK-LABEL: func @flush_denormals_bf16
+// CHECK: arith.bitcast %{{.*}} : bf16 to i16
+// CHECK: %[[MAN_MASK:.+]] = arith.constant 127 : i16
+// CHECK: %[[EXP_MASK:.+]] = arith.constant 32640 : i16
+// CHECK: %[[SIGN_MASK:.+]] = arith.constant -32768 : i16
+// CHECK: arith.bitcast %{{.*}} : i16 to bf16
+func.func @flush_denormals_bf16(%arg0: bf16) -> bf16 {
+ %0 = arith.flush_denormals %arg0 : bf16
+ return %0 : bf16
+}
+
+// -----
+
+// Expansion for f16 (5 exponent bits, 10 stored mantissa bits); only the
+// bitcasts and the mask constants are matched:
+// sign mask = 0x8000
+// exp mask = 0x7c00
+// mantissa mask = 0x03ff
+
+// CHECK-LABEL: func @flush_denormals_f16
+// CHECK: arith.bitcast %{{.*}} : f16 to i16
+// CHECK: %[[MAN_MASK:.+]] = arith.constant 1023 : i16
+// CHECK: %[[EXP_MASK:.+]] = arith.constant 31744 : i16
+// CHECK: %[[SIGN_MASK:.+]] = arith.constant -32768 : i16
+// CHECK: arith.bitcast %{{.*}} : i16 to f16
+func.func @flush_denormals_f16(%arg0: f16) -> f16 {
+ %0 = arith.flush_denormals %arg0 : f16
+ return %0 : f16
+}
+
+// -----
+
+// Expansion for f64 (11 exponent bits, 52 stored mantissa bits). These masks
+// do not fit in 32 bits, so this verifies that wide APInt masks work:
+// sign mask = 0x8000000000000000 = -9223372036854775808 (signed i64)
+// exp mask = 0x7ff0000000000000 = 9218868437227405312
+// mantissa mask = 0x000fffffffffffff = 4503599627370495
+
+// CHECK-LABEL: func @flush_denormals_f64
+// CHECK: arith.bitcast %{{.*}} : f64 to i64
+// CHECK: %[[MAN_MASK:.+]] = arith.constant 4503599627370495 : i64
+// CHECK: %[[EXP_MASK:.+]] = arith.constant 9218868437227405312 : i64
+// CHECK: %[[SIGN_MASK:.+]] = arith.constant -9223372036854775808 : i64
+// CHECK: arith.bitcast %{{.*}} : i64 to f64
+func.func @flush_denormals_f64(%arg0: f64) -> f64 {
+ %0 = arith.flush_denormals %arg0 : f64
+ return %0 : f64
+}
+
+// -----
+
+// Vector operands are expanded elementwise: the compares and their
+// conjunction produce vector<4xi1>, which drives the select.
+// CHECK-LABEL: func @flush_denormals_vector
+// CHECK: arith.bitcast %{{.*}} : vector<4xf32> to vector<4xi32>
+// CHECK: arith.andi %{{.*}} : vector<4xi32>
+// CHECK: arith.cmpi eq, %{{.*}} : vector<4xi32>
+// CHECK: arith.cmpi ne, %{{.*}} : vector<4xi32>
+// CHECK: arith.andi %{{.*}} : vector<4xi1>
+// CHECK: arith.select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xi32>
+// CHECK: arith.bitcast %{{.*}} : vector<4xi32> to vector<4xf32>
+func.func @flush_denormals_vector(%arg0: vector<4xf32>) -> vector<4xf32> {
+ %0 = arith.flush_denormals %arg0 : vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// Tensor operands are accepted as well; only the bitcasts into and out of
+// the integer domain are matched.
+// CHECK-LABEL: func @flush_denormals_tensor
+// CHECK: arith.bitcast %{{.*}} : tensor<8xf32> to tensor<8xi32>
+// CHECK: arith.bitcast %{{.*}} : tensor<8xi32> to tensor<8xf32>
+func.func @flush_denormals_tensor(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+ %0 = arith.flush_denormals %arg0 : tensor<8xf32>
+ return %0 : tensor<8xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/192660
More information about the llvm-branch-commits
mailing list