[Mlir-commits] [mlir] [mlir][arith] Add support for `arith.flush_denormals` emulation (PR #192660)
Matthias Springer
llvmlistbot at llvm.org
Tue Apr 21 04:47:38 PDT 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/192660
>From 649143c828d3169bc247fa09b54502ac33888f76 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 17 Apr 2026 13:48:01 +0000
Subject: [PATCH 1/2] [mlir][arith] Add support for `arith.flush_denormals`
emulation
---
.../mlir/Dialect/Arith/Transforms/Passes.h | 5 +
.../mlir/Dialect/Arith/Transforms/Passes.td | 15 +++
.../Dialect/Arith/Transforms/ExpandOps.cpp | 121 ++++++++++++++++++
.../Dialect/Arith/expand-flush-denormals.mlir | 108 ++++++++++++++++
4 files changed, 249 insertions(+)
create mode 100644 mlir/test/Dialect/Arith/expand-flush-denormals.mlir
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 18ac0dbc8d13e..5a07a01d0928a 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -68,6 +68,11 @@ void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
/// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns);
+/// Add patterns to expand `arith.flush_denormals` into integer arithmetic
+/// (bitcast + bit masks + compare + select). Only matches IEEE-like
+/// floating-point types.
+void populateExpandFlushDenormalsPatterns(RewritePatternSet &patterns);
+
/// Add patterns to expand Arith ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index c7370b83fdb6c..84986bd94a4ac 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -97,6 +97,21 @@ def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
let dependentDialects = ["vector::VectorDialect"];
}
+def ArithExpandFlushDenormalsPass : Pass<"arith-expand-flush-denormals"> {
+ let summary = "Expand arith.flush_denormals to integer arithmetic";
+ let description = [{
+ Lower `arith.flush_denormals` operations on IEEE-like floating-point
+ types to a bitcast + bit-mask + compare + select sequence using integer
+ arithmetic. A value is denormal iff its biased exponent field is zero and
+ its stored mantissa is non-zero. Denormal inputs are replaced by a
+ sign-preserved zero; all other inputs pass through unchanged.
+
+ This pass leaves `arith.flush_denormals` operating on non-IEEE-like
+ float types untouched.
+ }];
+ let dependentDialects = ["::mlir::arith::ArithDialect"];
+}
+
def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
let summary = "Emulate 2*N-bit integer operations using N-bit operations";
let description = [{
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 46f8c1037d47b..753f37b009870 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -13,10 +13,12 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
namespace mlir {
namespace arith {
#define GEN_PASS_DEF_ARITHEXPANDOPSPASS
+#define GEN_PASS_DEF_ARITHEXPANDFLUSHDENORMALSPASS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace arith
} // namespace mlir
@@ -34,6 +36,17 @@ static Value createConst(Location loc, Type type, int value,
return arith::ConstantOp::create(rewriter, loc, attr);
}
+/// Create an integer constant from an APInt.
+static Value createAPIntConst(Location loc, Type type, const APInt &value,
+ PatternRewriter &rewriter) {
+ auto attr = IntegerAttr::get(getElementTypeOrSelf(type), value);
+ if (auto shapedTy = dyn_cast<ShapedType>(type)) {
+ return arith::ConstantOp::create(rewriter, loc,
+ DenseElementsAttr::get(shapedTy, attr));
+ }
+ return arith::ConstantOp::create(rewriter, loc, attr);
+}
+
/// Create a float constant.
static Value createFloatConst(Location loc, Type type, const APFloat &value,
PatternRewriter &rewriter) {
@@ -729,6 +742,109 @@ struct ScalingTruncFOpConverter
}
};
+/// Expands `arith.flush_denormals` into integer arithmetic.
+///
+/// For an IEEE-like floating-point value with a sign|exponent|mantissa
+/// bit layout, this mirrors `APFloat::isDenormal`: a value is denormal
+/// iff its biased exponent field is zero *and* its stored mantissa is
+/// non-zero. Denormal inputs are replaced by a sign-preserved zero
+/// (i.e. the operand's bits AND'ed with the sign-bit mask); all other
+/// inputs pass through unchanged.
+///
+/// Pseudocode:
+/// bits = bitcast(x, iN)
+/// expField = bits & expMask
+/// manField = bits & manMask
+/// isDenormal = (expField == 0) AND (manField != 0)
+/// signZero = bits & signMask
+/// resultBits = select(isDenormal, signZero, bits)
+/// result = bitcast(resultBits, floatTy)
+struct FlushDenormalsOpConverter
+ : public OpRewritePattern<arith::FlushDenormalsOp> {
+ using Base::Base;
+ LogicalResult matchAndRewrite(arith::FlushDenormalsOp op,
+ PatternRewriter &rewriter) const final {
+ Location loc = op.getLoc();
+ ImplicitLocOpBuilder b(loc, rewriter);
+ Value operand = op.getOperand();
+ Type operandTy = operand.getType();
+ auto floatTy = dyn_cast<FloatType>(getElementTypeOrSelf(operandTy));
+ if (!floatTy)
+ return rewriter.notifyMatchFailure(op, "operand is not a float type");
+
+ const llvm::fltSemantics &sem = floatTy.getFloatSemantics();
+ // Restrict to IEEE-like encodings, where the sign bit is the MSB and
+ // denormals are exactly "biased exponent == 0 and non-zero mantissa".
+ if (!llvm::APFloatBase::isIEEELikeFP(sem))
+ return rewriter.notifyMatchFailure(
+ op, "only IEEE-like floating-point types are supported");
+
+ unsigned totalBits = llvm::APFloatBase::semanticsSizeInBits(sem);
+ unsigned precision = llvm::APFloatBase::semanticsPrecision(sem);
+ // Stored mantissa bits = precision - 1 (implicit leading bit not stored).
+ // Exponent field bits = totalBits - 1 (sign) - storedMantissa.
+ if (precision < 1 || precision > totalBits)
+ return rewriter.notifyMatchFailure(op, "unexpected float semantics");
+ unsigned mantissaBits = precision - 1;
+ unsigned expBits = totalBits - 1 - mantissaBits;
+ if (expBits == 0 || mantissaBits == 0)
+ return rewriter.notifyMatchFailure(
+ op, "degenerate float encoding has no exponent or mantissa");
+
+ Type intTy =
+ cloneToShapedType(operandTy, rewriter.getIntegerType(totalBits));
+ Value bits = arith::BitcastOp::create(b, intTy, operand);
+
+ // Build bit masks using APInt to support widths like 64 bits that don't
+ // fit into an `int` parameter.
+ APInt mantissaMaskVal = APInt::getLowBitsSet(totalBits, mantissaBits);
+ APInt expMaskVal =
+ APInt::getBitsSet(totalBits, mantissaBits, mantissaBits + expBits);
+ APInt signMaskVal = APInt::getOneBitSet(totalBits, totalBits - 1);
+ APInt zeroVal = APInt::getZero(totalBits);
+
+ Value mantissaMask =
+ createAPIntConst(loc, intTy, mantissaMaskVal, rewriter);
+ Value expMask = createAPIntConst(loc, intTy, expMaskVal, rewriter);
+ Value signMask = createAPIntConst(loc, intTy, signMaskVal, rewriter);
+ Value zero = createAPIntConst(loc, intTy, zeroVal, rewriter);
+
+ // expField == 0
+ Value expField = arith::AndIOp::create(b, bits, expMask);
+ Value expIsZero =
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, expField, zero);
+
+ // mantissaField != 0
+ Value mantissaField = arith::AndIOp::create(b, bits, mantissaMask);
+ Value mantissaNonZero =
+ arith::CmpIOp::create(b, arith::CmpIPredicate::ne, mantissaField, zero);
+
+ // isDenormal = (exp == 0) AND (mantissa != 0)
+ Value isDenormal = arith::AndIOp::create(b, expIsZero, mantissaNonZero);
+
+ // Flushed bits preserve the sign bit only -> ±0.0.
+ Value signOnly = arith::AndIOp::create(b, bits, signMask);
+ Value resultBits = arith::SelectOp::create(b, isDenormal, signOnly, bits);
+ Value result = arith::BitcastOp::create(b, operandTy, resultBits);
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+struct ArithExpandFlushDenormalsPass final
+ : public arith::impl::ArithExpandFlushDenormalsPassBase<
+ ArithExpandFlushDenormalsPass> {
+ using ArithExpandFlushDenormalsPassBase::ArithExpandFlushDenormalsPassBase;
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ arith::populateExpandFlushDenormalsPatterns(patterns);
+ walkAndApplyPatterns(getOperation(), std::move(patterns));
+ }
+};
+
struct ArithExpandOpsPass
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -831,6 +947,11 @@ void mlir::arith::populateExpandScalingExtTruncPatterns(
patterns.getContext());
}
+void mlir::arith::populateExpandFlushDenormalsPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<FlushDenormalsOpConverter>(patterns.getContext());
+}
+
void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
populateCeilFloorDivExpandOpsPatterns(patterns);
populateExpandScalingExtTruncPatterns(patterns);
diff --git a/mlir/test/Dialect/Arith/expand-flush-denormals.mlir b/mlir/test/Dialect/Arith/expand-flush-denormals.mlir
new file mode 100644
index 0000000000000..add6a5adc808b
--- /dev/null
+++ b/mlir/test/Dialect/Arith/expand-flush-denormals.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s -arith-expand-flush-denormals -split-input-file | FileCheck %s
+
+// Expansion for f32:
+// sign mask = 0x80000000 (10000000000000000000000000000000)
+// exp mask = 0x7f800000 (01111111100000000000000000000000)
+// mantissa mask = 0x007fffff (00000000011111111111111111111111)
+// Bit pattern for denormal: zero exponent and non-zero mantissa
+// Bit pattern for zero: 0/1 sign bit, remaining bits zero
+
+// CHECK-LABEL: func @flush_denormals_f32
+// CHECK-SAME: (%[[ARG0:.+]]: f32) -> f32
+// CHECK: %[[BITS:.+]] = arith.bitcast %[[ARG0]] : f32 to i32
+// CHECK: %[[MAN_MASK:.+]] = arith.constant 8388607 : i32
+// CHECK: %[[EXP_MASK:.+]] = arith.constant 2139095040 : i32
+// CHECK: %[[SIGN_MASK:.+]] = arith.constant -2147483648 : i32
+// CHECK: %[[ZERO:.+]] = arith.constant 0 : i32
+// CHECK: %[[EXP:.+]] = arith.andi %[[BITS]], %[[EXP_MASK]] : i32
+// CHECK: %[[EXP_ZERO:.+]] = arith.cmpi eq, %[[EXP]], %[[ZERO]] : i32
+// CHECK: %[[MAN:.+]] = arith.andi %[[BITS]], %[[MAN_MASK]] : i32
+// CHECK: %[[MAN_NONZERO:.+]] = arith.cmpi ne, %[[MAN]], %[[ZERO]] : i32
+// CHECK: %[[IS_DEN:.+]] = arith.andi %[[EXP_ZERO]], %[[MAN_NONZERO]] : i1
+// CHECK: %[[SIGN_ONLY:.+]] = arith.andi %[[BITS]], %[[SIGN_MASK]] : i32
+// CHECK: %[[RES_BITS:.+]] = arith.select %[[IS_DEN]], %[[SIGN_ONLY]], %[[BITS]] : i32
+// CHECK: %[[RES:.+]] = arith.bitcast %[[RES_BITS]] : i32 to f32
+// CHECK: return %[[RES]] : f32
+func.func @flush_denormals_f32(%arg0: f32) -> f32 {
+ %0 = arith.flush_denormals %arg0 : f32
+ return %0 : f32
+}
+
+// -----
+
+// Expansion for bf16:
+// sign mask = 0x8000
+// exp mask = 0x7f80
+// mantissa mask = 0x007f
+
+// CHECK-LABEL: func @flush_denormals_bf16
+// CHECK: arith.bitcast %{{.*}} : bf16 to i16
+// CHECK: %[[MAN_MASK:.+]] = arith.constant 127 : i16
+// CHECK: %[[EXP_MASK:.+]] = arith.constant 32640 : i16
+// CHECK: %[[SIGN_MASK:.+]] = arith.constant -32768 : i16
+// CHECK: arith.bitcast %{{.*}} : i16 to bf16
+func.func @flush_denormals_bf16(%arg0: bf16) -> bf16 {
+ %0 = arith.flush_denormals %arg0 : bf16
+ return %0 : bf16
+}
+
+// -----
+
+// Expansion for f16:
+// sign mask = 0x8000
+// exp mask = 0x7c00
+// mantissa mask = 0x03ff
+
+// CHECK-LABEL: func @flush_denormals_f16
+// CHECK: arith.bitcast %{{.*}} : f16 to i16
+// CHECK: %[[MAN_MASK:.+]] = arith.constant 1023 : i16
+// CHECK: %[[EXP_MASK:.+]] = arith.constant 31744 : i16
+// CHECK: %[[SIGN_MASK:.+]] = arith.constant -32768 : i16
+// CHECK: arith.bitcast %{{.*}} : i16 to f16
+func.func @flush_denormals_f16(%arg0: f16) -> f16 {
+ %0 = arith.flush_denormals %arg0 : f16
+ return %0 : f16
+}
+
+// -----
+
+// Expansion for f64 (verifies wide APInt masks work):
+// sign mask = 0x8000000000000000 = -9223372036854775808 (signed i64)
+// exp mask = 0x7ff0000000000000 = 9218868437227405312
+// man mask = 0x000fffffffffffff = 4503599627370495
+
+// CHECK-LABEL: func @flush_denormals_f64
+// CHECK: arith.bitcast %{{.*}} : f64 to i64
+// CHECK: %[[MAN_MASK:.+]] = arith.constant 4503599627370495 : i64
+// CHECK: %[[EXP_MASK:.+]] = arith.constant 9218868437227405312 : i64
+// CHECK: %[[SIGN_MASK:.+]] = arith.constant -9223372036854775808 : i64
+// CHECK: arith.bitcast %{{.*}} : i64 to f64
+func.func @flush_denormals_f64(%arg0: f64) -> f64 {
+ %0 = arith.flush_denormals %arg0 : f64
+ return %0 : f64
+}
+
+// -----
+
+// CHECK-LABEL: func @flush_denormals_vector
+// CHECK: arith.bitcast %{{.*}} : vector<4xf32> to vector<4xi32>
+// CHECK: arith.andi %{{.*}} : vector<4xi32>
+// CHECK: arith.cmpi eq, %{{.*}} : vector<4xi32>
+// CHECK: arith.cmpi ne, %{{.*}} : vector<4xi32>
+// CHECK: arith.andi %{{.*}} : vector<4xi1>
+// CHECK: arith.select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xi32>
+// CHECK: arith.bitcast %{{.*}} : vector<4xi32> to vector<4xf32>
+func.func @flush_denormals_vector(%arg0: vector<4xf32>) -> vector<4xf32> {
+ %0 = arith.flush_denormals %arg0 : vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @flush_denormals_tensor
+// CHECK: arith.bitcast %{{.*}} : tensor<8xf32> to tensor<8xi32>
+// CHECK: arith.bitcast %{{.*}} : tensor<8xi32> to tensor<8xf32>
+func.func @flush_denormals_tensor(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+ %0 = arith.flush_denormals %arg0 : tensor<8xf32>
+ return %0 : tensor<8xf32>
+}
>From 2196957f8ab08ba39a5b93dbc20e8a4401796698 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 18 Apr 2026 11:10:19 +0000
Subject: [PATCH 2/2] address comments
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 47 +++++++------------
.../Dialect/Arith/expand-flush-denormals.mlir | 42 ++++++-----------
2 files changed, 31 insertions(+), 58 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 753f37b009870..0de5f4e9e7d30 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -744,20 +744,19 @@ struct ScalingTruncFOpConverter
/// Expands `arith.flush_denormals` into integer arithmetic.
///
-/// For an IEEE-like floating-point value with a sign|exponent|mantissa
-/// bit layout, this mirrors `APFloat::isDenormal`: a value is denormal
-/// iff its biased exponent field is zero *and* its stored mantissa is
-/// non-zero. Denormal inputs are replaced by a sign-preserved zero
-/// (i.e. the operand's bits AND'ed with the sign-bit mask); all other
-/// inputs pass through unchanged.
+/// For an IEEE-like floating-point value with a sign|exponent|mantissa bit
+/// layout, a value is denormal iff its biased exponent field is zero and its
+/// stored mantissa is non-zero. When the exponent field is zero, the value is
+/// either pos/neg 0 (mantissa = 0) or a denormal (mantissa != 0); in both
+/// cases, clearing the mantissa bits produces the desired sign-preserved zero
+/// (a no-op for pos/neg 0, a flush for denormals). When the exponent field is
+/// non-zero, the value passes through unchanged.
///
/// Pseudocode:
/// bits = bitcast(x, iN)
-/// expField = bits & expMask
-/// manField = bits & manMask
-/// isDenormal = (expField == 0) AND (manField != 0)
-/// signZero = bits & signMask
-/// resultBits = select(isDenormal, signZero, bits)
+/// expIsZero = (bits & expMask) == 0
+/// cleared = bits & ~manMask
+/// resultBits = select(expIsZero, cleared, bits)
/// result = bitcast(resultBits, floatTy)
struct FlushDenormalsOpConverter
: public OpRewritePattern<arith::FlushDenormalsOp> {
@@ -794,19 +793,13 @@ struct FlushDenormalsOpConverter
Type intTy =
cloneToShapedType(operandTy, rewriter.getIntegerType(totalBits));
Value bits = arith::BitcastOp::create(b, intTy, operand);
-
- // Build bit masks using APInt to support widths like 64 bits that don't
- // fit into an `int` parameter.
- APInt mantissaMaskVal = APInt::getLowBitsSet(totalBits, mantissaBits);
APInt expMaskVal =
APInt::getBitsSet(totalBits, mantissaBits, mantissaBits + expBits);
- APInt signMaskVal = APInt::getOneBitSet(totalBits, totalBits - 1);
+ APInt clearMantissaMaskVal = ~APInt::getLowBitsSet(totalBits, mantissaBits);
APInt zeroVal = APInt::getZero(totalBits);
-
- Value mantissaMask =
- createAPIntConst(loc, intTy, mantissaMaskVal, rewriter);
Value expMask = createAPIntConst(loc, intTy, expMaskVal, rewriter);
- Value signMask = createAPIntConst(loc, intTy, signMaskVal, rewriter);
+ Value clearMantissaMask =
+ createAPIntConst(loc, intTy, clearMantissaMaskVal, rewriter);
Value zero = createAPIntConst(loc, intTy, zeroVal, rewriter);
// expField == 0
@@ -814,17 +807,9 @@ struct FlushDenormalsOpConverter
Value expIsZero =
arith::CmpIOp::create(b, arith::CmpIPredicate::eq, expField, zero);
- // mantissaField != 0
- Value mantissaField = arith::AndIOp::create(b, bits, mantissaMask);
- Value mantissaNonZero =
- arith::CmpIOp::create(b, arith::CmpIPredicate::ne, mantissaField, zero);
-
- // isDenormal = (exp == 0) AND (mantissa != 0)
- Value isDenormal = arith::AndIOp::create(b, expIsZero, mantissaNonZero);
-
- // Flushed bits preserve the sign bit only -> ±0.0.
- Value signOnly = arith::AndIOp::create(b, bits, signMask);
- Value resultBits = arith::SelectOp::create(b, isDenormal, signOnly, bits);
+ // Clear mantissa bits: when exp == 0, this produces pos/neg 0.0.
+ Value cleared = arith::AndIOp::create(b, bits, clearMantissaMask);
+ Value resultBits = arith::SelectOp::create(b, expIsZero, cleared, bits);
Value result = arith::BitcastOp::create(b, operandTy, resultBits);
rewriter.replaceOp(op, result);
diff --git a/mlir/test/Dialect/Arith/expand-flush-denormals.mlir b/mlir/test/Dialect/Arith/expand-flush-denormals.mlir
index add6a5adc808b..e48168cd1b5ad 100644
--- a/mlir/test/Dialect/Arith/expand-flush-denormals.mlir
+++ b/mlir/test/Dialect/Arith/expand-flush-denormals.mlir
@@ -1,26 +1,21 @@
// RUN: mlir-opt %s -arith-expand-flush-denormals -split-input-file | FileCheck %s
// Expansion for f32:
-// sign mask = 0x80000000 (10000000000000000000000000000000)
-// exp mask = 0x7f800000 (01111111100000000000000000000000)
-// mantissa mask = 0x007fffff (00000000011111111111111111111111)
-// Bit pattern for denormal: zero exponent and non-zero mantissa
-// Bit pattern for zero: 0/1 sign bit, remaining bits zero
+// exp mask = 0x7f800000 (sign 0, exp all 1, mantissa 0)
+// clear-man mask = 0xff800000 (sign 1, exp all 1, mantissa 0)
+// When the exponent field is zero (±0 or denormal), the mantissa bits are
+// cleared, yielding a sign-preserved zero. Otherwise the bits pass through.
// CHECK-LABEL: func @flush_denormals_f32
// CHECK-SAME: (%[[ARG0:.+]]: f32) -> f32
// CHECK: %[[BITS:.+]] = arith.bitcast %[[ARG0]] : f32 to i32
-// CHECK: %[[MAN_MASK:.+]] = arith.constant 8388607 : i32
// CHECK: %[[EXP_MASK:.+]] = arith.constant 2139095040 : i32
-// CHECK: %[[SIGN_MASK:.+]] = arith.constant -2147483648 : i32
+// CHECK: %[[CLEAR_MAN_MASK:.+]] = arith.constant -8388608 : i32
// CHECK: %[[ZERO:.+]] = arith.constant 0 : i32
// CHECK: %[[EXP:.+]] = arith.andi %[[BITS]], %[[EXP_MASK]] : i32
// CHECK: %[[EXP_ZERO:.+]] = arith.cmpi eq, %[[EXP]], %[[ZERO]] : i32
-// CHECK: %[[MAN:.+]] = arith.andi %[[BITS]], %[[MAN_MASK]] : i32
-// CHECK: %[[MAN_NONZERO:.+]] = arith.cmpi ne, %[[MAN]], %[[ZERO]] : i32
-// CHECK: %[[IS_DEN:.+]] = arith.andi %[[EXP_ZERO]], %[[MAN_NONZERO]] : i1
-// CHECK: %[[SIGN_ONLY:.+]] = arith.andi %[[BITS]], %[[SIGN_MASK]] : i32
-// CHECK: %[[RES_BITS:.+]] = arith.select %[[IS_DEN]], %[[SIGN_ONLY]], %[[BITS]] : i32
+// CHECK: %[[CLEARED:.+]] = arith.andi %[[BITS]], %[[CLEAR_MAN_MASK]] : i32
+// CHECK: %[[RES_BITS:.+]] = arith.select %[[EXP_ZERO]], %[[CLEARED]], %[[BITS]] : i32
// CHECK: %[[RES:.+]] = arith.bitcast %[[RES_BITS]] : i32 to f32
// CHECK: return %[[RES]] : f32
func.func @flush_denormals_f32(%arg0: f32) -> f32 {
@@ -31,15 +26,13 @@ func.func @flush_denormals_f32(%arg0: f32) -> f32 {
// -----
// Expansion for bf16:
-// sign mask = 0x8000
// exp mask = 0x7f80
-// mantissa mask = 0x007f
+// clear-man mask = 0xff80 (-128 as signed i16)
// CHECK-LABEL: func @flush_denormals_bf16
// CHECK: arith.bitcast %{{.*}} : bf16 to i16
-// CHECK: %[[MAN_MASK:.+]] = arith.constant 127 : i16
// CHECK: %[[EXP_MASK:.+]] = arith.constant 32640 : i16
-// CHECK: %[[SIGN_MASK:.+]] = arith.constant -32768 : i16
+// CHECK: %[[CLEAR_MAN_MASK:.+]] = arith.constant -128 : i16
// CHECK: arith.bitcast %{{.*}} : i16 to bf16
func.func @flush_denormals_bf16(%arg0: bf16) -> bf16 {
%0 = arith.flush_denormals %arg0 : bf16
@@ -49,15 +42,13 @@ func.func @flush_denormals_bf16(%arg0: bf16) -> bf16 {
// -----
// Expansion for f16:
-// sign mask = 0x8000
// exp mask = 0x7c00
-// mantissa mask = 0x03ff
+// clear-man mask = 0xfc00 (-1024 as signed i16)
// CHECK-LABEL: func @flush_denormals_f16
// CHECK: arith.bitcast %{{.*}} : f16 to i16
-// CHECK: %[[MAN_MASK:.+]] = arith.constant 1023 : i16
// CHECK: %[[EXP_MASK:.+]] = arith.constant 31744 : i16
-// CHECK: %[[SIGN_MASK:.+]] = arith.constant -32768 : i16
+// CHECK: %[[CLEAR_MAN_MASK:.+]] = arith.constant -1024 : i16
// CHECK: arith.bitcast %{{.*}} : i16 to f16
func.func @flush_denormals_f16(%arg0: f16) -> f16 {
%0 = arith.flush_denormals %arg0 : f16
@@ -67,15 +58,13 @@ func.func @flush_denormals_f16(%arg0: f16) -> f16 {
// -----
// Expansion for f64 (verifies wide APInt masks work):
-// sign mask = 0x8000000000000000 = -9223372036854775808 (signed i64)
-// exp mask = 0x7ff0000000000000 = 9218868437227405312
-// man mask = 0x000fffffffffffff = 4503599627370495
+// exp mask = 0x7ff0000000000000 = 9218868437227405312
+// clear-man mask = 0xfff0000000000000 = -4503599627370496 (signed i64)
// CHECK-LABEL: func @flush_denormals_f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64
-// CHECK: %[[MAN_MASK:.+]] = arith.constant 4503599627370495 : i64
// CHECK: %[[EXP_MASK:.+]] = arith.constant 9218868437227405312 : i64
-// CHECK: %[[SIGN_MASK:.+]] = arith.constant -9223372036854775808 : i64
+// CHECK: %[[CLEAR_MAN_MASK:.+]] = arith.constant -4503599627370496 : i64
// CHECK: arith.bitcast %{{.*}} : i64 to f64
func.func @flush_denormals_f64(%arg0: f64) -> f64 {
%0 = arith.flush_denormals %arg0 : f64
@@ -88,8 +77,7 @@ func.func @flush_denormals_f64(%arg0: f64) -> f64 {
// CHECK: arith.bitcast %{{.*}} : vector<4xf32> to vector<4xi32>
// CHECK: arith.andi %{{.*}} : vector<4xi32>
// CHECK: arith.cmpi eq, %{{.*}} : vector<4xi32>
-// CHECK: arith.cmpi ne, %{{.*}} : vector<4xi32>
-// CHECK: arith.andi %{{.*}} : vector<4xi1>
+// CHECK: arith.andi %{{.*}} : vector<4xi32>
// CHECK: arith.select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xi32>
// CHECK: arith.bitcast %{{.*}} : vector<4xi32> to vector<4xf32>
func.func @flush_denormals_vector(%arg0: vector<4xf32>) -> vector<4xf32> {
More information about the Mlir-commits mailing list