[Mlir-commits] [mlir] [mlir][arith] Add support for `arith.flush_denormals` emulation (PR #192660)

Matthias Springer llvmlistbot at llvm.org
Tue Apr 21 04:47:38 PDT 2026


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/192660

>From 649143c828d3169bc247fa09b54502ac33888f76 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 17 Apr 2026 13:48:01 +0000
Subject: [PATCH 1/2] [mlir][arith] Add support for `arith.flush_denormals`
 emulation

---
 .../mlir/Dialect/Arith/Transforms/Passes.h    |   5 +
 .../mlir/Dialect/Arith/Transforms/Passes.td   |  15 +++
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 121 ++++++++++++++++++
 .../Dialect/Arith/expand-flush-denormals.mlir | 108 ++++++++++++++++
 4 files changed, 249 insertions(+)
 create mode 100644 mlir/test/Dialect/Arith/expand-flush-denormals.mlir

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 18ac0dbc8d13e..5a07a01d0928a 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -68,6 +68,11 @@ void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
 /// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops
 void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns);
 
+/// Add patterns to expand `arith.flush_denormals` into integer arithmetic
+/// (bitcast + bit masks + compare + select). Only matches IEEE-like
+/// floating-point types; ops on other float types are left untouched.
+void populateExpandFlushDenormalsPatterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 
 
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index c7370b83fdb6c..84986bd94a4ac 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -97,6 +97,21 @@ def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
   let dependentDialects = ["vector::VectorDialect"];
 }
 
+def ArithExpandFlushDenormalsPass : Pass<"arith-expand-flush-denormals"> {
+  let summary = "Expand arith.flush_denormals to integer arithmetic";
+  let description = [{
+    Lower `arith.flush_denormals` operations on IEEE-like floating-point
+    types to a bitcast + bit-mask + compare + select sequence using integer
+    arithmetic. A value is denormal iff its biased exponent field is zero and
+    its stored mantissa is non-zero. Denormal inputs are replaced by a
+    sign-preserved zero; all other inputs pass through unchanged.
+
+    `arith.flush_denormals` ops on non-IEEE-like float types are left
+    untouched by this pass.
+  }];
+  let dependentDialects = ["::mlir::arith::ArithDialect"];
+}
+
 def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
   let summary = "Emulate 2*N-bit integer operations using N-bit operations";
   let description = [{
+
 def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
   let summary = "Emulate 2*N-bit integer operations using N-bit operations";
   let description = [{
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 46f8c1037d47b..753f37b009870 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -13,10 +13,12 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
 namespace mlir {
 namespace arith {
 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
+#define GEN_PASS_DEF_ARITHEXPANDFLUSHDENORMALSPASS
 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
 } // namespace arith
 } // namespace mlir
@@ -34,6 +36,17 @@ static Value createConst(Location loc, Type type, int value,
   return arith::ConstantOp::create(rewriter, loc, attr);
 }
 
+/// Create an integer constant from an APInt (splatted for shaped types).
+static Value createAPIntConst(Location loc, Type type, const APInt &value,
+                              PatternRewriter &rewriter) {
+  auto attr = IntegerAttr::get(getElementTypeOrSelf(type), value);
+  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
+    return arith::ConstantOp::create(rewriter, loc,
+                                     DenseElementsAttr::get(shapedTy, attr));
+  }
+  return arith::ConstantOp::create(rewriter, loc, attr);
+}
+
 /// Create a float constant.
 static Value createFloatConst(Location loc, Type type, const APFloat &value,
                               PatternRewriter &rewriter) {
@@ -729,6 +742,109 @@ struct ScalingTruncFOpConverter
   }
 };
 
+/// Expands `arith.flush_denormals` into integer arithmetic.
+///
+/// For an IEEE-like floating-point value with a sign|exponent|mantissa
+/// bit layout, this mirrors `APFloat::isDenormal`: a value is denormal
+/// iff its biased exponent field is zero *and* its stored mantissa is
+/// non-zero. Denormal inputs are replaced by a sign-preserved zero
+/// (i.e. the operand's bits AND'ed with the sign-bit mask); all other
+/// inputs pass through unchanged.
+///
+/// Pseudocode:
+///   bits        = bitcast(x, iN)
+///   expField    = bits & expMask
+///   manField    = bits & manMask
+///   isDenormal  = (expField == 0) AND (manField != 0)
+///   signZero    = bits & signMask
+///   resultBits  = select(isDenormal, signZero, bits)
+///   result      = bitcast(resultBits, floatTy)
+struct FlushDenormalsOpConverter
+    : public OpRewritePattern<arith::FlushDenormalsOp> {
+  using Base::Base;
+  LogicalResult matchAndRewrite(arith::FlushDenormalsOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    auto floatTy = dyn_cast<FloatType>(getElementTypeOrSelf(operandTy));
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(op, "operand is not a float type");
+
+    const llvm::fltSemantics &sem = floatTy.getFloatSemantics();
+    // Restrict to IEEE-like encodings, where the sign bit is the MSB and
+    // denormals are exactly "biased exponent == 0 and non-zero mantissa".
+    if (!llvm::APFloatBase::isIEEELikeFP(sem))
+      return rewriter.notifyMatchFailure(
+          op, "only IEEE-like floating-point types are supported");
+
+    unsigned totalBits = llvm::APFloatBase::semanticsSizeInBits(sem);
+    unsigned precision = llvm::APFloatBase::semanticsPrecision(sem);
+    // Stored mantissa bits = precision - 1 (implicit leading bit not stored).
+    // Exponent field bits = totalBits - 1 (sign) - mantissaBits.
+    if (precision < 1 || precision > totalBits)
+      return rewriter.notifyMatchFailure(op, "unexpected float semantics");
+    unsigned mantissaBits = precision - 1;
+    unsigned expBits = totalBits - 1 - mantissaBits;
+    if (expBits == 0 || mantissaBits == 0)
+      return rewriter.notifyMatchFailure(
+          op, "degenerate float encoding has no exponent or mantissa");
+
+    Type intTy =
+        cloneToShapedType(operandTy, rewriter.getIntegerType(totalBits));
+    Value bits = arith::BitcastOp::create(b, intTy, operand);
+
+    // Build bit masks using APInt to support widths like 64 bits that don't
+    // fit into an `int` parameter.
+    APInt mantissaMaskVal = APInt::getLowBitsSet(totalBits, mantissaBits);
+    APInt expMaskVal =
+        APInt::getBitsSet(totalBits, mantissaBits, mantissaBits + expBits);
+    APInt signMaskVal = APInt::getOneBitSet(totalBits, totalBits - 1);
+    APInt zeroVal = APInt::getZero(totalBits);
+
+    Value mantissaMask =
+        createAPIntConst(loc, intTy, mantissaMaskVal, rewriter);
+    Value expMask = createAPIntConst(loc, intTy, expMaskVal, rewriter);
+    Value signMask = createAPIntConst(loc, intTy, signMaskVal, rewriter);
+    Value zero = createAPIntConst(loc, intTy, zeroVal, rewriter);
+
+    // expField == 0
+    Value expField = arith::AndIOp::create(b, bits, expMask);
+    Value expIsZero =
+        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, expField, zero);
+
+    // mantissaField != 0
+    Value mantissaField = arith::AndIOp::create(b, bits, mantissaMask);
+    Value mantissaNonZero =
+        arith::CmpIOp::create(b, arith::CmpIPredicate::ne, mantissaField, zero);
+
+    // isDenormal = (exp == 0) AND (mantissa != 0)
+    Value isDenormal = arith::AndIOp::create(b, expIsZero, mantissaNonZero);
+
+    // Flushed bits preserve the sign bit only -> ±0.0.
+    Value signOnly = arith::AndIOp::create(b, bits, signMask);
+    Value resultBits = arith::SelectOp::create(b, isDenormal, signOnly, bits);
+    Value result = arith::BitcastOp::create(b, operandTy, resultBits);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct ArithExpandFlushDenormalsPass final
+    : public arith::impl::ArithExpandFlushDenormalsPassBase<
+          ArithExpandFlushDenormalsPass> {
+  using ArithExpandFlushDenormalsPassBase::ArithExpandFlushDenormalsPassBase;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    arith::populateExpandFlushDenormalsPatterns(patterns);
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
+  }
+};
+
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
   using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -831,6 +947,11 @@ void mlir::arith::populateExpandScalingExtTruncPatterns(
       patterns.getContext());
 }
 
+void mlir::arith::populateExpandFlushDenormalsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FlushDenormalsOpConverter>(patterns.getContext());
+}
+
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
   populateExpandScalingExtTruncPatterns(patterns);
diff --git a/mlir/test/Dialect/Arith/expand-flush-denormals.mlir b/mlir/test/Dialect/Arith/expand-flush-denormals.mlir
new file mode 100644
index 0000000000000..add6a5adc808b
--- /dev/null
+++ b/mlir/test/Dialect/Arith/expand-flush-denormals.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s -arith-expand-flush-denormals -split-input-file | FileCheck %s
+
+// Expansion for f32:
+//   sign mask      = 0x80000000   (00000000011111111111111111111111)
+//   exp  mask      = 0x7f800000   (01111111100000000000000000000000)
+//   mantissa mask  = 0x007fffff   (00000000011111111111111111111111)
+// Bit pattern for denormal: zero exponent and non-zero mantissa
+// Bit pattern for zero: 0/1 sign bit, remaining bits zero
+
+// CHECK-LABEL: func @flush_denormals_f32
+// CHECK-SAME:    (%[[ARG0:.+]]: f32) -> f32
+// CHECK:         %[[BITS:.+]] = arith.bitcast %[[ARG0]] : f32 to i32
+// CHECK:         %[[MAN_MASK:.+]] = arith.constant 8388607 : i32
+// CHECK:         %[[EXP_MASK:.+]] = arith.constant 2139095040 : i32
+// CHECK:         %[[SIGN_MASK:.+]] = arith.constant -2147483648 : i32
+// CHECK:         %[[ZERO:.+]] = arith.constant 0 : i32
+// CHECK:         %[[EXP:.+]] = arith.andi %[[BITS]], %[[EXP_MASK]] : i32
+// CHECK:         %[[EXP_ZERO:.+]] = arith.cmpi eq, %[[EXP]], %[[ZERO]] : i32
+// CHECK:         %[[MAN:.+]] = arith.andi %[[BITS]], %[[MAN_MASK]] : i32
+// CHECK:         %[[MAN_NONZERO:.+]] = arith.cmpi ne, %[[MAN]], %[[ZERO]] : i32
+// CHECK:         %[[IS_DEN:.+]] = arith.andi %[[EXP_ZERO]], %[[MAN_NONZERO]] : i1
+// CHECK:         %[[SIGN_ONLY:.+]] = arith.andi %[[BITS]], %[[SIGN_MASK]] : i32
+// CHECK:         %[[RES_BITS:.+]] = arith.select %[[IS_DEN]], %[[SIGN_ONLY]], %[[BITS]] : i32
+// CHECK:         %[[RES:.+]] = arith.bitcast %[[RES_BITS]] : i32 to f32
+// CHECK:         return %[[RES]] : f32
+func.func @flush_denormals_f32(%arg0: f32) -> f32 {
+  %0 = arith.flush_denormals %arg0 : f32
+  return %0 : f32
+}
+
+// -----
+
+// Expansion for bf16:
+//   sign mask      = 0x8000
+//   exp  mask      = 0x7f80
+//   mantissa mask  = 0x007f
+
+// CHECK-LABEL: func @flush_denormals_bf16
+// CHECK:         arith.bitcast %{{.*}} : bf16 to i16
+// CHECK:         %[[MAN_MASK:.+]] = arith.constant 127 : i16
+// CHECK:         %[[EXP_MASK:.+]] = arith.constant 32640 : i16
+// CHECK:         %[[SIGN_MASK:.+]] = arith.constant -32768 : i16
+// CHECK:         arith.bitcast %{{.*}} : i16 to bf16
+func.func @flush_denormals_bf16(%arg0: bf16) -> bf16 {
+  %0 = arith.flush_denormals %arg0 : bf16
+  return %0 : bf16
+}
+
+// -----
+
+// Expansion for f16:
+//   sign mask      = 0x8000
+//   exp  mask      = 0x7c00
+//   mantissa mask  = 0x03ff
+
+// CHECK-LABEL: func @flush_denormals_f16
+// CHECK:         arith.bitcast %{{.*}} : f16 to i16
+// CHECK:         %[[MAN_MASK:.+]] = arith.constant 1023 : i16
+// CHECK:         %[[EXP_MASK:.+]] = arith.constant 31744 : i16
+// CHECK:         %[[SIGN_MASK:.+]] = arith.constant -32768 : i16
+// CHECK:         arith.bitcast %{{.*}} : i16 to f16
+func.func @flush_denormals_f16(%arg0: f16) -> f16 {
+  %0 = arith.flush_denormals %arg0 : f16
+  return %0 : f16
+}
+
+// -----
+
+// Expansion for f64 (verifies wide APInt masks work):
+//   sign mask = 0x8000000000000000 = -9223372036854775808 (signed i64)
+//   exp  mask = 0x7ff0000000000000 =  9218868437227405312
+//   man  mask = 0x000fffffffffffff =     4503599627370495
+
+// CHECK-LABEL: func @flush_denormals_f64
+// CHECK:         arith.bitcast %{{.*}} : f64 to i64
+// CHECK:         %[[MAN_MASK:.+]] = arith.constant 4503599627370495 : i64
+// CHECK:         %[[EXP_MASK:.+]] = arith.constant 9218868437227405312 : i64
+// CHECK:         %[[SIGN_MASK:.+]] = arith.constant -9223372036854775808 : i64
+// CHECK:         arith.bitcast %{{.*}} : i64 to f64
+func.func @flush_denormals_f64(%arg0: f64) -> f64 {
+  %0 = arith.flush_denormals %arg0 : f64
+  return %0 : f64
+}
+
+// -----
+
+// CHECK-LABEL: func @flush_denormals_vector
+// CHECK:         arith.bitcast %{{.*}} : vector<4xf32> to vector<4xi32>
+// CHECK:         arith.andi %{{.*}} : vector<4xi32>
+// CHECK:         arith.cmpi eq, %{{.*}} : vector<4xi32>
+// CHECK:         arith.cmpi ne, %{{.*}} : vector<4xi32>
+// CHECK:         arith.andi %{{.*}} : vector<4xi1>
+// CHECK:         arith.select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xi32>
+// CHECK:         arith.bitcast %{{.*}} : vector<4xi32> to vector<4xf32>
+func.func @flush_denormals_vector(%arg0: vector<4xf32>) -> vector<4xf32> {
+  %0 = arith.flush_denormals %arg0 : vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @flush_denormals_tensor
+// CHECK:         arith.bitcast %{{.*}} : tensor<8xf32> to tensor<8xi32>
+// CHECK:         arith.bitcast %{{.*}} : tensor<8xi32> to tensor<8xf32>
+func.func @flush_denormals_tensor(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  %0 = arith.flush_denormals %arg0 : tensor<8xf32>
+  return %0 : tensor<8xf32>
+}

>From 2196957f8ab08ba39a5b93dbc20e8a4401796698 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 18 Apr 2026 11:10:19 +0000
Subject: [PATCH 2/2] address comments

---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 47 +++++++------------
 .../Dialect/Arith/expand-flush-denormals.mlir | 42 ++++++-----------
 2 files changed, 31 insertions(+), 58 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 753f37b009870..0de5f4e9e7d30 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -744,20 +744,19 @@ struct ScalingTruncFOpConverter
 
 /// Expands `arith.flush_denormals` into integer arithmetic.
 ///
-/// For an IEEE-like floating-point value with a sign|exponent|mantissa
-/// bit layout, this mirrors `APFloat::isDenormal`: a value is denormal
-/// iff its biased exponent field is zero *and* its stored mantissa is
-/// non-zero. Denormal inputs are replaced by a sign-preserved zero
-/// (i.e. the operand's bits AND'ed with the sign-bit mask); all other
-/// inputs pass through unchanged.
+/// For an IEEE-like floating-point value with a sign|exponent|mantissa bit
+/// layout, a value is denormal iff its biased exponent field is zero and its
+/// stored mantissa is non-zero. When the exponent field is zero, the value is
+/// either pos/neg 0 (mantissa = 0) or a denormal (mantissa != 0); in both
+/// cases, clearing the mantissa bits produces the desired sign-preserved zero
+/// (a no-op for pos/neg 0, a flush for denormals). When the exponent field is
+/// non-zero, the value passes through unchanged.
 ///
 /// Pseudocode:
 ///   bits        = bitcast(x, iN)
-///   expField    = bits & expMask
-///   manField    = bits & manMask
-///   isDenormal  = (expField == 0) AND (manField != 0)
-///   signZero    = bits & signMask
-///   resultBits  = select(isDenormal, signZero, bits)
+///   expIsZero   = (bits & expMask) == 0
+///   cleared     = bits & ~manMask
+///   resultBits  = select(expIsZero, cleared, bits)
 ///   result      = bitcast(resultBits, floatTy)
 struct FlushDenormalsOpConverter
     : public OpRewritePattern<arith::FlushDenormalsOp> {
@@ -794,19 +793,13 @@ struct FlushDenormalsOpConverter
     Type intTy =
         cloneToShapedType(operandTy, rewriter.getIntegerType(totalBits));
     Value bits = arith::BitcastOp::create(b, intTy, operand);
-
-    // Build bit masks using APInt to support widths like 64 bits that don't
-    // fit into an `int` parameter.
-    APInt mantissaMaskVal = APInt::getLowBitsSet(totalBits, mantissaBits);
     APInt expMaskVal =
         APInt::getBitsSet(totalBits, mantissaBits, mantissaBits + expBits);
-    APInt signMaskVal = APInt::getOneBitSet(totalBits, totalBits - 1);
+    APInt clearMantissaMaskVal = ~APInt::getLowBitsSet(totalBits, mantissaBits);
     APInt zeroVal = APInt::getZero(totalBits);
-
-    Value mantissaMask =
-        createAPIntConst(loc, intTy, mantissaMaskVal, rewriter);
     Value expMask = createAPIntConst(loc, intTy, expMaskVal, rewriter);
-    Value signMask = createAPIntConst(loc, intTy, signMaskVal, rewriter);
+    Value clearMantissaMask =
+        createAPIntConst(loc, intTy, clearMantissaMaskVal, rewriter);
     Value zero = createAPIntConst(loc, intTy, zeroVal, rewriter);
 
     // expField == 0
@@ -814,17 +807,9 @@ struct FlushDenormalsOpConverter
     Value expIsZero =
         arith::CmpIOp::create(b, arith::CmpIPredicate::eq, expField, zero);
 
-    // mantissaField != 0
-    Value mantissaField = arith::AndIOp::create(b, bits, mantissaMask);
-    Value mantissaNonZero =
-        arith::CmpIOp::create(b, arith::CmpIPredicate::ne, mantissaField, zero);
-
-    // isDenormal = (exp == 0) AND (mantissa != 0)
-    Value isDenormal = arith::AndIOp::create(b, expIsZero, mantissaNonZero);
-
-    // Flushed bits preserve the sign bit only -> ±0.0.
-    Value signOnly = arith::AndIOp::create(b, bits, signMask);
-    Value resultBits = arith::SelectOp::create(b, isDenormal, signOnly, bits);
+    // Clear mantissa bits; only selected when exp == 0, yielding +/-0.0.
+    Value cleared = arith::AndIOp::create(b, bits, clearMantissaMask);
+    Value resultBits = arith::SelectOp::create(b, expIsZero, cleared, bits);
     Value result = arith::BitcastOp::create(b, operandTy, resultBits);
 
     rewriter.replaceOp(op, result);
diff --git a/mlir/test/Dialect/Arith/expand-flush-denormals.mlir b/mlir/test/Dialect/Arith/expand-flush-denormals.mlir
index add6a5adc808b..e48168cd1b5ad 100644
--- a/mlir/test/Dialect/Arith/expand-flush-denormals.mlir
+++ b/mlir/test/Dialect/Arith/expand-flush-denormals.mlir
@@ -1,26 +1,21 @@
 // RUN: mlir-opt %s -arith-expand-flush-denormals -split-input-file | FileCheck %s
 
 // Expansion for f32:
-//   sign mask      = 0x80000000   (00000000011111111111111111111111)
-//   exp  mask      = 0x7f800000   (01111111100000000000000000000000)
-//   mantissa mask  = 0x007fffff   (00000000011111111111111111111111)
-// Bit pattern for denormal: zero exponent and non-zero mantissa
-// Bit pattern for zero: 0/1 sign bit, remaining bits zero
+//   exp  mask        = 0x7f800000   (sign 0, exp all 1, mantissa 0)
+//   clear-man mask   = 0xff800000   (sign 1, exp all 1, mantissa 0)
+// When the exponent field is zero (±0 or denormal), the mantissa bits are
+// cleared, yielding a sign-preserved zero. Otherwise the bits pass through.
 
 // CHECK-LABEL: func @flush_denormals_f32
 // CHECK-SAME:    (%[[ARG0:.+]]: f32) -> f32
 // CHECK:         %[[BITS:.+]] = arith.bitcast %[[ARG0]] : f32 to i32
-// CHECK:         %[[MAN_MASK:.+]] = arith.constant 8388607 : i32
 // CHECK:         %[[EXP_MASK:.+]] = arith.constant 2139095040 : i32
-// CHECK:         %[[SIGN_MASK:.+]] = arith.constant -2147483648 : i32
+// CHECK:         %[[CLEAR_MAN_MASK:.+]] = arith.constant -8388608 : i32
 // CHECK:         %[[ZERO:.+]] = arith.constant 0 : i32
 // CHECK:         %[[EXP:.+]] = arith.andi %[[BITS]], %[[EXP_MASK]] : i32
 // CHECK:         %[[EXP_ZERO:.+]] = arith.cmpi eq, %[[EXP]], %[[ZERO]] : i32
-// CHECK:         %[[MAN:.+]] = arith.andi %[[BITS]], %[[MAN_MASK]] : i32
-// CHECK:         %[[MAN_NONZERO:.+]] = arith.cmpi ne, %[[MAN]], %[[ZERO]] : i32
-// CHECK:         %[[IS_DEN:.+]] = arith.andi %[[EXP_ZERO]], %[[MAN_NONZERO]] : i1
-// CHECK:         %[[SIGN_ONLY:.+]] = arith.andi %[[BITS]], %[[SIGN_MASK]] : i32
-// CHECK:         %[[RES_BITS:.+]] = arith.select %[[IS_DEN]], %[[SIGN_ONLY]], %[[BITS]] : i32
+// CHECK:         %[[CLEARED:.+]] = arith.andi %[[BITS]], %[[CLEAR_MAN_MASK]] : i32
+// CHECK:         %[[RES_BITS:.+]] = arith.select %[[EXP_ZERO]], %[[CLEARED]], %[[BITS]] : i32
 // CHECK:         %[[RES:.+]] = arith.bitcast %[[RES_BITS]] : i32 to f32
 // CHECK:         return %[[RES]] : f32
 func.func @flush_denormals_f32(%arg0: f32) -> f32 {
@@ -31,15 +26,13 @@ func.func @flush_denormals_f32(%arg0: f32) -> f32 {
 // -----
 
 // Expansion for bf16:
-//   sign mask      = 0x8000
 //   exp  mask      = 0x7f80
-//   mantissa mask  = 0x007f
+//   clear-man mask = 0xff80 (-128 as signed i16)
 
 // CHECK-LABEL: func @flush_denormals_bf16
 // CHECK:         arith.bitcast %{{.*}} : bf16 to i16
-// CHECK:         %[[MAN_MASK:.+]] = arith.constant 127 : i16
 // CHECK:         %[[EXP_MASK:.+]] = arith.constant 32640 : i16
-// CHECK:         %[[SIGN_MASK:.+]] = arith.constant -32768 : i16
+// CHECK:         %[[CLEAR_MAN_MASK:.+]] = arith.constant -128 : i16
 // CHECK:         arith.bitcast %{{.*}} : i16 to bf16
 func.func @flush_denormals_bf16(%arg0: bf16) -> bf16 {
   %0 = arith.flush_denormals %arg0 : bf16
@@ -49,15 +42,13 @@ func.func @flush_denormals_bf16(%arg0: bf16) -> bf16 {
 // -----
 
 // Expansion for f16:
-//   sign mask      = 0x8000
 //   exp  mask      = 0x7c00
-//   mantissa mask  = 0x03ff
+//   clear-man mask = 0xfc00 (-1024 as signed i16)
 
 // CHECK-LABEL: func @flush_denormals_f16
 // CHECK:         arith.bitcast %{{.*}} : f16 to i16
-// CHECK:         %[[MAN_MASK:.+]] = arith.constant 1023 : i16
 // CHECK:         %[[EXP_MASK:.+]] = arith.constant 31744 : i16
-// CHECK:         %[[SIGN_MASK:.+]] = arith.constant -32768 : i16
+// CHECK:         %[[CLEAR_MAN_MASK:.+]] = arith.constant -1024 : i16
 // CHECK:         arith.bitcast %{{.*}} : i16 to f16
 func.func @flush_denormals_f16(%arg0: f16) -> f16 {
   %0 = arith.flush_denormals %arg0 : f16
@@ -67,15 +58,13 @@ func.func @flush_denormals_f16(%arg0: f16) -> f16 {
 // -----
 
 // Expansion for f64 (verifies wide APInt masks work):
-//   sign mask = 0x8000000000000000 = -9223372036854775808 (signed i64)
-//   exp  mask = 0x7ff0000000000000 =  9218868437227405312
-//   man  mask = 0x000fffffffffffff =     4503599627370495
+//   exp  mask      = 0x7ff0000000000000 =  9218868437227405312
+//   clear-man mask = 0xfff0000000000000 = -4503599627370496 (signed i64)
 
 // CHECK-LABEL: func @flush_denormals_f64
 // CHECK:         arith.bitcast %{{.*}} : f64 to i64
-// CHECK:         %[[MAN_MASK:.+]] = arith.constant 4503599627370495 : i64
 // CHECK:         %[[EXP_MASK:.+]] = arith.constant 9218868437227405312 : i64
-// CHECK:         %[[SIGN_MASK:.+]] = arith.constant -9223372036854775808 : i64
+// CHECK:         %[[CLEAR_MAN_MASK:.+]] = arith.constant -4503599627370496 : i64
 // CHECK:         arith.bitcast %{{.*}} : i64 to f64
 func.func @flush_denormals_f64(%arg0: f64) -> f64 {
   %0 = arith.flush_denormals %arg0 : f64
@@ -88,8 +77,7 @@ func.func @flush_denormals_f64(%arg0: f64) -> f64 {
 // CHECK:         arith.bitcast %{{.*}} : vector<4xf32> to vector<4xi32>
 // CHECK:         arith.andi %{{.*}} : vector<4xi32>
 // CHECK:         arith.cmpi eq, %{{.*}} : vector<4xi32>
-// CHECK:         arith.cmpi ne, %{{.*}} : vector<4xi32>
-// CHECK:         arith.andi %{{.*}} : vector<4xi1>
+// CHECK:         arith.andi %{{.*}} : vector<4xi32>
 // CHECK:         arith.select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xi32>
 // CHECK:         arith.bitcast %{{.*}} : vector<4xi32> to vector<4xf32>
 func.func @flush_denormals_vector(%arg0: vector<4xf32>) -> vector<4xf32> {
 func.func @flush_denormals_vector(%arg0: vector<4xf32>) -> vector<4xf32> {



More information about the Mlir-commits mailing list