[llvm-branch-commits] [mlir] [mlir][arith] Add support for `arith.flush_denormals` emulation (PR #192660)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Apr 18 02:26:17 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Add a lowering pattern and a new pass `arith-expand-flush-denormals` that rewrites `arith.flush_denormals` ops using integer arithmetic. This lowering is useful for target architectures that cannot pattern-match `arith.flush_denormals` + other FP arithmetic into special instructions with flush-to-zero (FTZ) semantics.

Assisted-by: claude-opus-4.7-thinking-high


---
Full diff: https://github.com/llvm/llvm-project/pull/192660.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.h (+5) 
- (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.td (+15) 
- (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+121) 
- (added) mlir/test/Dialect/Arith/expand-flush-denormals.mlir (+108) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 18ac0dbc8d13e..5a07a01d0928a 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -68,6 +68,11 @@ void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
 /// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops
 void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns);
 
+/// Add patterns to expand `arith.flush_denormals` into integer arithmetic
+/// (bitcast + bit masks + compare + select). Only matches IEEE-like
+/// floating-point types.
+void populateExpandFlushDenormalsPatterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index c7370b83fdb6c..84986bd94a4ac 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -97,6 +97,21 @@ def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
   let dependentDialects = ["vector::VectorDialect"];
 }
 
+def ArithExpandFlushDenormalsPass : Pass<"arith-expand-flush-denormals"> {
+  let summary = "Expand arith.flush_denormals to integer arithmetic";
+  let description = [{
+    Lower `arith.flush_denormals` operations on IEEE-like floating-point
+    types to a bitcast + bit-mask + compare + select sequence using integer
+    arithmetic. A value is denormal iff its biased exponent field is zero and
+    its stored mantissa is non-zero. Denormal inputs are replaced by a
+    sign-preserved zero; all other inputs pass through unchanged.
+
+    This pass leaves `arith.flush_denormals` operating on non-IEEE-like
+    float types untouched.
+  }];
+  let dependentDialects = ["::mlir::arith::ArithDialect"];
+}
+
 def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
   let summary = "Emulate 2*N-bit integer operations using N-bit operations";
   let description = [{
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 46f8c1037d47b..753f37b009870 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -13,10 +13,12 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
 namespace mlir {
 namespace arith {
 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
+#define GEN_PASS_DEF_ARITHEXPANDFLUSHDENORMALSPASS
 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
 } // namespace arith
 } // namespace mlir
@@ -34,6 +36,17 @@ static Value createConst(Location loc, Type type, int value,
   return arith::ConstantOp::create(rewriter, loc, attr);
 }
 
+/// Create an integer constant from an APInt.
+static Value createAPIntConst(Location loc, Type type, const APInt &value,
+                              PatternRewriter &rewriter) {
+  auto attr = IntegerAttr::get(getElementTypeOrSelf(type), value);
+  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
+    return arith::ConstantOp::create(rewriter, loc,
+                                     DenseElementsAttr::get(shapedTy, attr));
+  }
+  return arith::ConstantOp::create(rewriter, loc, attr);
+}
+
 /// Create a float constant.
 static Value createFloatConst(Location loc, Type type, const APFloat &value,
                               PatternRewriter &rewriter) {
@@ -729,6 +742,109 @@ struct ScalingTruncFOpConverter
   }
 };
 
+/// Expands `arith.flush_denormals` into integer arithmetic.
+///
+/// For an IEEE-like floating-point value with a sign|exponent|mantissa
+/// bit layout, this mirrors `APFloat::isDenormal`: a value is denormal
+/// iff its biased exponent field is zero *and* its stored mantissa is
+/// non-zero. Denormal inputs are replaced by a sign-preserved zero
+/// (i.e. the operand's bits AND'ed with the sign-bit mask); all other
+/// inputs pass through unchanged.
+///
+/// Pseudocode:
+///   bits        = bitcast(x, iN)
+///   expField    = bits & expMask
+///   manField    = bits & manMask
+///   isDenormal  = (expField == 0) AND (manField != 0)
+///   signZero    = bits & signMask
+///   resultBits  = select(isDenormal, signZero, bits)
+///   result      = bitcast(resultBits, floatTy)
+struct FlushDenormalsOpConverter
+    : public OpRewritePattern<arith::FlushDenormalsOp> {
+  using Base::Base;
+  LogicalResult matchAndRewrite(arith::FlushDenormalsOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    auto floatTy = dyn_cast<FloatType>(getElementTypeOrSelf(operandTy));
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(op, "operand is not a float type");
+
+    const llvm::fltSemantics &sem = floatTy.getFloatSemantics();
+    // Restrict to IEEE-like encodings, where the sign bit is the MSB and
+    // denormals are exactly "biased exponent == 0 and non-zero mantissa".
+    if (!llvm::APFloatBase::isIEEELikeFP(sem))
+      return rewriter.notifyMatchFailure(
+          op, "only IEEE-like floating-point types are supported");
+
+    unsigned totalBits = llvm::APFloatBase::semanticsSizeInBits(sem);
+    unsigned precision = llvm::APFloatBase::semanticsPrecision(sem);
+    // Stored mantissa bits = precision - 1 (implicit leading bit not stored).
+    // Exponent field bits = totalBits - 1 (sign) - storedMantissa.
+    if (precision < 1 || precision > totalBits)
+      return rewriter.notifyMatchFailure(op, "unexpected float semantics");
+    unsigned mantissaBits = precision - 1;
+    unsigned expBits = totalBits - 1 - mantissaBits;
+    if (expBits == 0 || mantissaBits == 0)
+      return rewriter.notifyMatchFailure(
+          op, "degenerate float encoding has no exponent or mantissa");
+
+    Type intTy =
+        cloneToShapedType(operandTy, rewriter.getIntegerType(totalBits));
+    Value bits = arith::BitcastOp::create(b, intTy, operand);
+
+    // Build bit masks using APInt to support widths like 64 bits that don't
+    // fit into an `int` parameter.
+    APInt mantissaMaskVal = APInt::getLowBitsSet(totalBits, mantissaBits);
+    APInt expMaskVal =
+        APInt::getBitsSet(totalBits, mantissaBits, mantissaBits + expBits);
+    APInt signMaskVal = APInt::getOneBitSet(totalBits, totalBits - 1);
+    APInt zeroVal = APInt::getZero(totalBits);
+
+    Value mantissaMask =
+        createAPIntConst(loc, intTy, mantissaMaskVal, rewriter);
+    Value expMask = createAPIntConst(loc, intTy, expMaskVal, rewriter);
+    Value signMask = createAPIntConst(loc, intTy, signMaskVal, rewriter);
+    Value zero = createAPIntConst(loc, intTy, zeroVal, rewriter);
+
+    // expField == 0
+    Value expField = arith::AndIOp::create(b, bits, expMask);
+    Value expIsZero =
+        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, expField, zero);
+
+    // mantissaField != 0
+    Value mantissaField = arith::AndIOp::create(b, bits, mantissaMask);
+    Value mantissaNonZero =
+        arith::CmpIOp::create(b, arith::CmpIPredicate::ne, mantissaField, zero);
+
+    // isDenormal = (exp == 0) AND (mantissa != 0)
+    Value isDenormal = arith::AndIOp::create(b, expIsZero, mantissaNonZero);
+
+    // Flushed bits preserve the sign bit only -> ±0.0.
+    Value signOnly = arith::AndIOp::create(b, bits, signMask);
+    Value resultBits = arith::SelectOp::create(b, isDenormal, signOnly, bits);
+    Value result = arith::BitcastOp::create(b, operandTy, resultBits);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct ArithExpandFlushDenormalsPass final
+    : public arith::impl::ArithExpandFlushDenormalsPassBase<
+          ArithExpandFlushDenormalsPass> {
+  using ArithExpandFlushDenormalsPassBase::ArithExpandFlushDenormalsPassBase;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    arith::populateExpandFlushDenormalsPatterns(patterns);
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
+  }
+};
+
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
   using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -831,6 +947,11 @@ void mlir::arith::populateExpandScalingExtTruncPatterns(
       patterns.getContext());
 }
 
+void mlir::arith::populateExpandFlushDenormalsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FlushDenormalsOpConverter>(patterns.getContext());
+}
+
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
   populateExpandScalingExtTruncPatterns(patterns);
diff --git a/mlir/test/Dialect/Arith/expand-flush-denormals.mlir b/mlir/test/Dialect/Arith/expand-flush-denormals.mlir
new file mode 100644
index 0000000000000..add6a5adc808b
--- /dev/null
+++ b/mlir/test/Dialect/Arith/expand-flush-denormals.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s -arith-expand-flush-denormals -split-input-file | FileCheck %s
+
+// Expansion for f32:
+//   sign mask      = 0x80000000   (10000000000000000000000000000000)
+//   exp  mask      = 0x7f800000   (01111111100000000000000000000000)
+//   mantissa mask  = 0x007fffff   (00000000011111111111111111111111)
+// Bit pattern for denormal: zero exponent and non-zero mantissa
+// Bit pattern for zero: 0/1 sign bit, remaining bits zero
+
+// CHECK-LABEL: func @flush_denormals_f32
+// CHECK-SAME:    (%[[ARG0:.+]]: f32) -> f32
+// CHECK:         %[[BITS:.+]] = arith.bitcast %[[ARG0]] : f32 to i32
+// CHECK:         %[[MAN_MASK:.+]] = arith.constant 8388607 : i32
+// CHECK:         %[[EXP_MASK:.+]] = arith.constant 2139095040 : i32
+// CHECK:         %[[SIGN_MASK:.+]] = arith.constant -2147483648 : i32
+// CHECK:         %[[ZERO:.+]] = arith.constant 0 : i32
+// CHECK:         %[[EXP:.+]] = arith.andi %[[BITS]], %[[EXP_MASK]] : i32
+// CHECK:         %[[EXP_ZERO:.+]] = arith.cmpi eq, %[[EXP]], %[[ZERO]] : i32
+// CHECK:         %[[MAN:.+]] = arith.andi %[[BITS]], %[[MAN_MASK]] : i32
+// CHECK:         %[[MAN_NONZERO:.+]] = arith.cmpi ne, %[[MAN]], %[[ZERO]] : i32
+// CHECK:         %[[IS_DEN:.+]] = arith.andi %[[EXP_ZERO]], %[[MAN_NONZERO]] : i1
+// CHECK:         %[[SIGN_ONLY:.+]] = arith.andi %[[BITS]], %[[SIGN_MASK]] : i32
+// CHECK:         %[[RES_BITS:.+]] = arith.select %[[IS_DEN]], %[[SIGN_ONLY]], %[[BITS]] : i32
+// CHECK:         %[[RES:.+]] = arith.bitcast %[[RES_BITS]] : i32 to f32
+// CHECK:         return %[[RES]] : f32
+func.func @flush_denormals_f32(%arg0: f32) -> f32 {
+  %0 = arith.flush_denormals %arg0 : f32
+  return %0 : f32
+}
+
+// -----
+
+// Expansion for bf16:
+//   sign mask      = 0x8000
+//   exp  mask      = 0x7f80
+//   mantissa mask  = 0x007f
+
+// CHECK-LABEL: func @flush_denormals_bf16
+// CHECK:         arith.bitcast %{{.*}} : bf16 to i16
+// CHECK:         %[[MAN_MASK:.+]] = arith.constant 127 : i16
+// CHECK:         %[[EXP_MASK:.+]] = arith.constant 32640 : i16
+// CHECK:         %[[SIGN_MASK:.+]] = arith.constant -32768 : i16
+// CHECK:         arith.bitcast %{{.*}} : i16 to bf16
+func.func @flush_denormals_bf16(%arg0: bf16) -> bf16 {
+  %0 = arith.flush_denormals %arg0 : bf16
+  return %0 : bf16
+}
+
+// -----
+
+// Expansion for f16:
+//   sign mask      = 0x8000
+//   exp  mask      = 0x7c00
+//   mantissa mask  = 0x03ff
+
+// CHECK-LABEL: func @flush_denormals_f16
+// CHECK:         arith.bitcast %{{.*}} : f16 to i16
+// CHECK:         %[[MAN_MASK:.+]] = arith.constant 1023 : i16
+// CHECK:         %[[EXP_MASK:.+]] = arith.constant 31744 : i16
+// CHECK:         %[[SIGN_MASK:.+]] = arith.constant -32768 : i16
+// CHECK:         arith.bitcast %{{.*}} : i16 to f16
+func.func @flush_denormals_f16(%arg0: f16) -> f16 {
+  %0 = arith.flush_denormals %arg0 : f16
+  return %0 : f16
+}
+
+// -----
+
+// Expansion for f64 (verifies wide APInt masks work):
+//   sign mask = 0x8000000000000000 = -9223372036854775808 (signed i64)
+//   exp  mask = 0x7ff0000000000000 =  9218868437227405312
+//   man  mask = 0x000fffffffffffff =     4503599627370495
+
+// CHECK-LABEL: func @flush_denormals_f64
+// CHECK:         arith.bitcast %{{.*}} : f64 to i64
+// CHECK:         %[[MAN_MASK:.+]] = arith.constant 4503599627370495 : i64
+// CHECK:         %[[EXP_MASK:.+]] = arith.constant 9218868437227405312 : i64
+// CHECK:         %[[SIGN_MASK:.+]] = arith.constant -9223372036854775808 : i64
+// CHECK:         arith.bitcast %{{.*}} : i64 to f64
+func.func @flush_denormals_f64(%arg0: f64) -> f64 {
+  %0 = arith.flush_denormals %arg0 : f64
+  return %0 : f64
+}
+
+// -----
+
+// CHECK-LABEL: func @flush_denormals_vector
+// CHECK:         arith.bitcast %{{.*}} : vector<4xf32> to vector<4xi32>
+// CHECK:         arith.andi %{{.*}} : vector<4xi32>
+// CHECK:         arith.cmpi eq, %{{.*}} : vector<4xi32>
+// CHECK:         arith.cmpi ne, %{{.*}} : vector<4xi32>
+// CHECK:         arith.andi %{{.*}} : vector<4xi1>
+// CHECK:         arith.select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xi32>
+// CHECK:         arith.bitcast %{{.*}} : vector<4xi32> to vector<4xf32>
+func.func @flush_denormals_vector(%arg0: vector<4xf32>) -> vector<4xf32> {
+  %0 = arith.flush_denormals %arg0 : vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @flush_denormals_tensor
+// CHECK:         arith.bitcast %{{.*}} : tensor<8xf32> to tensor<8xi32>
+// CHECK:         arith.bitcast %{{.*}} : tensor<8xi32> to tensor<8xf32>
+func.func @flush_denormals_tensor(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  %0 = arith.flush_denormals %arg0 : tensor<8xf32>
+  return %0 : tensor<8xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/192660


More information about the llvm-branch-commits mailing list