[Mlir-commits] [mlir] 744279b - [mlir][arith] Add `arith.flush_denormals` operation (#192641)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 21 04:46:03 PDT 2026


Author: Matthias Springer
Date: 2026-04-21T13:45:59+02:00
New Revision: 744279b9f17344433dbdb471cb47fb225f2954c0

URL: https://github.com/llvm/llvm-project/commit/744279b9f17344433dbdb471cb47fb225f2954c0
DIFF: https://github.com/llvm/llvm-project/commit/744279b9f17344433dbdb471cb47fb225f2954c0.diff

LOG: [mlir][arith] Add `arith.flush_denormals` operation (#192641)

Add a new `arith.flush_denormals` operation. The operation takes a
floating-point value as input and returns a zero of the same sign if the
value is denormal. Otherwise, the operation returns the input unchanged.
This commit also adds support for the new op to the `ArithToAPFloat`
infrastructure.

Running example:
```mlir
%flush_a = arith.flush_denormals %a : f32
%flush_b = arith.flush_denormals %b : f32
%res = arith.addf %flush_a, %flush_b : f32
%flush_res = arith.flush_denormals %res : f32
```

The backend-specific lowering is not implemented as part of this PR (only
the `ArithToAPFloat` emulation is). Possible lowering strategies include:
- Per-instruction mode. E.g., on NVIDIA architectures, the above example
can lower to `add.ftz.f32 dest, a, b`.
- Global status register. E.g., on `x86_64`, the above example can lower
to `_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); r = a + b;
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF);
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);`. When several such
sequences follow each other, the intermediate OFF-ON switches can be folded
away.
- Emulation via integer arithmetic. Check the bit pattern of the input
float (depending on the specific FP type) and pass through either the
input or a zero constant of the same sign. This lowering approach works
on all architectures.

Assisted-by: claude-opus-4.7-thinking-high

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/lib/ExecutionEngine/APFloatWrappers.cpp
    mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir
    mlir/test/Dialect/Arith/canonicalize.mlir
    mlir/test/Dialect/Arith/ops.mlir
    mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index e6a29180066e9..ba9ccb6a01d66 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -990,6 +990,45 @@ def Arith_NegFOp : Arith_FloatUnaryOp<"negf"> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// FlushDenormalsOp
+//===----------------------------------------------------------------------===//
+
+def Arith_FlushDenormalsOp : Arith_FloatUnaryOp<"flush_denormals"> {
+  let summary = "flush denormal floating-point values to zero";
+  let description = [{
+    The `flush_denormals` operation takes a floating-point value and returns
+    the input value if it is a normal (or zero, infinity, or NaN) value, or
+    a zero of the same type if the input is a denormal (subnormal) value.
+    The sign of zero is preserved when flushing a denormal: negative
+    denormals flush to `-0.0`, positive denormals flush to `+0.0`.
+
+    A denormal number ("subnormal number" in IEEE-754) is a non-zero floating
+    point number that is smaller (closer to zero) than the smallest normal
+    number. Denormals fill the underflow gap around zero in floating-point
+    arithmetics, but may come at a runtime cost on some architectures.
+
+    The input and result are required to be the same type. This type may be
+    a floating-point scalar type, a vector whose element type is a
+    floating-point type, or a tensor of floats. When operating on vectors
+    or tensors, the operation is applied elementwise.
+
+    Example:
+
+    ```mlir
+    // Scalar denormal flushing.
+    %a = arith.flush_denormals %b : f32
+
+    // SIMD vector element-wise denormal flushing.
+    %f = arith.flush_denormals %g : vector<4xf32>
+
+    // Tensor element-wise denormal flushing.
+    %x = arith.flush_denormals %y : tensor<4x?xbf16>
+    ```
+  }];
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 98185697e4591..bf80a12429bdf 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -447,12 +447,17 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
   SymbolOpInterface symTable;
 };
 
-struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
-  NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
-                            PatternBenefit benefit = 1)
-      : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {}
+/// Rewrite a unary floating-point op (same input/output float type) to an
+/// APFloat runtime call of the form `(i32 semantics, i64 bits) -> i64 bits`.
+template <typename OpTy>
+struct UnaryFloatOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+  UnaryFloatOpToAPFloatConversion(MLIRContext *context, const char *APFloatName,
+                                  SymbolOpInterface symTable,
+                                  PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+        APFloatName(APFloatName) {}
 
-  LogicalResult matchAndRewrite(arith::NegFOp op,
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
     if (failed(checkPreconditions(rewriter, op)))
       return failure();
@@ -460,8 +465,9 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
     // Get APFloat function from runtime library.
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
-    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
-        rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type});
+    std::string funcName = (llvm::Twine("_mlir_apfloat_") + APFloatName).str();
+    FailureOr<FuncOp> fn =
+        lookupOrCreateFnDecl(rewriter, symTable, funcName, {i32Type, i64Type});
     if (failed(fn))
       return fn;
 
@@ -481,14 +487,14 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
           // Call APFloat function.
           Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
           SmallVector<Value> params = {semValue, operandBits};
-          Value negatedBits =
+          Value resultBits =
               func::CallOp::create(rewriter, loc, TypeRange(i64Type),
                                    SymbolRefAttr::get(*fn), params)
                   ->getResult(0);
 
           // Truncate result to the original width.
           Value truncatedBits =
-              arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+              arith::TruncIOp::create(rewriter, loc, intWType, resultBits);
           return arith::BitcastOp::create(rewriter, loc, floatTy,
                                           truncatedBits);
         });
@@ -497,6 +503,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
   }
 
   SymbolOpInterface symTable;
+  const char *APFloatName;
 };
 
 namespace {
@@ -528,10 +535,13 @@ void ArithToAPFloatConversionPass::runOnOperation() {
       context, "minimum", getOperation());
   patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>(
       context, "maximum", getOperation());
-  patterns
-      .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
-           CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
-          context, getOperation());
+  patterns.add<FpToFpConversion<arith::ExtFOp>,
+               FpToFpConversion<arith::TruncFOp>, CmpFOpToAPFloatConversion>(
+      context, getOperation());
+  patterns.add<UnaryFloatOpToAPFloatConversion<arith::NegFOp>>(context, "neg",
+                                                               getOperation());
+  patterns.add<UnaryFloatOpToAPFloatConversion<arith::FlushDenormalsOp>>(
+      context, "flush_denormals", getOperation());
   patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
                                                    /*isUnsigned=*/false);
   patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index e11a38ffec50c..fef8fd210a495 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1104,6 +1104,28 @@ OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
                                      [](const APFloat &a) { return -a; });
 }
 
+//===----------------------------------------------------------------------===//
+// FlushDenormalsOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::FlushDenormalsOp::fold(FoldAdaptor adaptor) {
+  // TODO: Fold flush_denormals if the floating-point type does not support
+  // denormals. There is currently no API to query this information from
+  // APFloat.
+
+  // flush_denormals(flush_denormals(x)) -> flush_denormals(x)
+  if (auto op = this->getOperand().getDefiningOp<arith::FlushDenormalsOp>())
+    return op.getResult();
+
+  // Constant-fold flush_denormals if the operand is a constant.
+  return constFoldUnaryOp<FloatAttr>(
+      adaptor.getOperands(), [](const APFloat &a) {
+        if (a.isDenormal())
+          return APFloat::getZero(a.getSemantics(), a.isNegative());
+        return a;
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index b30a9072c0094..f23bcdda9ad1b 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -163,6 +163,17 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_abs(int32_t semantics,
   return abs(x).bitcastToAPInt().getZExtValue();
 }
 
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t
+_mlir_apfloat_flush_denormals(int32_t semantics, uint64_t a) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+  if (x.isDenormal())
+    x = llvm::APFloat::getZero(sem, x.isNegative());
+  return x.bitcastToAPInt().getZExtValue();
+}
+
 MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isfinite(int32_t semantics,
                                                          uint64_t a) {
   const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(

diff  --git a/mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir
index bd4a9dac597e7..1128e5c39c68f 100644
--- a/mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir
@@ -226,6 +226,32 @@ func.func @negf(%arg0: f32) {
 
 // -----
 
+// CHECK: func.func private @_mlir_apfloat_flush_denormals(i32, i64) -> i64
+// CHECK-LABEL: func.func @flush_denormals
+// CHECK: %[[bc:.*]] = arith.bitcast %{{.*}} : f32 to i32
+// CHECK: %[[ext:.*]] = arith.extui %[[bc]] : i32 to i64
+// CHECK: %[[sem:.*]] = arith.constant 2 : i32
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_flush_denormals(%[[sem]], %[[ext]]) : (i32, i64) -> i64
+// CHECK: %[[trunc:.*]] = arith.trunci %[[res]] : i64 to i32
+// CHECK: arith.bitcast %[[trunc]] : i32 to f32
+func.func @flush_denormals(%arg0: f32) {
+  %0 = arith.flush_denormals %arg0 : f32
+  return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_flush_denormals(i32, i64) -> i64
+// CHECK-LABEL: func.func @flush_denormals_f8
+// CHECK: %[[sem:.*]] = arith.constant 10 : i32
+// CHECK: call @_mlir_apfloat_flush_denormals(%[[sem]], %{{.*}}) : (i32, i64) -> i64
+func.func @flush_denormals_f8(%arg0: f8E4M3FN) {
+  %0 = arith.flush_denormals %arg0 : f8E4M3FN
+  return
+}
+
+// -----
+
 // CHECK: func.func private @_mlir_apfloat_minimum(i32, i64, i64) -> i64
 // CHECK: %[[sem:.*]] = arith.constant 2 : i32
 // CHECK: %[[res:.*]] = call @_mlir_apfloat_minimum(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index b153bd7c32261..a3a0dc7adf7cc 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3094,6 +3094,29 @@ func.func @test_negf1(%f : f32) -> (f32) {
 
 // -----
 
+// CHECK-LABEL: @test_flush_denormals_const(
+// CHECK: %[[res:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: return %[[res]]
+func.func @test_flush_denormals_const() -> (f32) {
+  %c = arith.constant 1.0e-40 : f32
+  %0 = arith.flush_denormals %c : f32
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_flush_denormals_idempotent(
+// CHECK-SAME: %[[arg0:.+]]:
+// CHECK: %[[res:.+]] = arith.flush_denormals %[[arg0]] : f32
+// CHECK: return %[[res]]
+func.func @test_flush_denormals_idempotent(%f : f32) -> (f32) {
+  %0 = arith.flush_denormals %f : f32
+  %1 = arith.flush_denormals %0 : f32
+  return %1 : f32
+}
+
+// -----
+
 // CHECK-LABEL: @test_remui(
 // CHECK: %[[res:.+]] = arith.constant dense<[0, 0, 4, 2]> : vector<4xi32>
 // CHECK: return %[[res]]

diff  --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 3874c85818eb4..059e35c384dac 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -481,6 +481,36 @@ func.func @test_negf_scalable_vector(%arg0 : vector<[8]xf64>) -> vector<[8]xf64>
   return %0 : vector<[8]xf64>
 }
 
+// CHECK-LABEL: test_flush_denormals
+func.func @test_flush_denormals(%arg0 : f32) -> f32 {
+  %0 = arith.flush_denormals %arg0 : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: test_flush_denormals_tensor
+func.func @test_flush_denormals_tensor(%arg0 : tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %0 = arith.flush_denormals %arg0 : tensor<8x8xf32>
+  return %0 : tensor<8x8xf32>
+}
+
+// CHECK-LABEL: test_flush_denormals_vector
+func.func @test_flush_denormals_vector(%arg0 : vector<8xf32>) -> vector<8xf32> {
+  %0 = arith.flush_denormals %arg0 : vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: test_flush_denormals_scalable_vector
+func.func @test_flush_denormals_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xf32> {
+  %0 = arith.flush_denormals %arg0 : vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
+
+// CHECK-LABEL: test_flush_denormals_bf16
+func.func @test_flush_denormals_bf16(%arg0 : bf16) -> bf16 {
+  %0 = arith.flush_denormals %arg0 : bf16
+  return %0 : bf16
+}
+
 // CHECK-LABEL: test_addf
 func.func @test_addf(%arg0 : f64, %arg1 : f64) -> f64 {
   %0 = arith.addf %arg0, %arg1 : f64
@@ -1216,12 +1246,14 @@ func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
 // CHECK: {{.*}} = arith.divf %arg0, %arg1 fastmath<fast> : f32
 // CHECK: {{.*}} = arith.remf %arg0, %arg1 fastmath<fast> : f32
 // CHECK: {{.*}} = arith.negf %arg0 fastmath<fast> : f32
+// CHECK: {{.*}} = arith.flush_denormals %arg0 fastmath<fast> : f32
   %0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
   %1 = arith.subf %arg0, %arg1 fastmath<fast> : f32
   %2 = arith.mulf %arg0, %arg1 fastmath<fast> : f32
   %3 = arith.divf %arg0, %arg1 fastmath<fast> : f32
   %4 = arith.remf %arg0, %arg1 fastmath<fast> : f32
   %5 = arith.negf %arg0 fastmath<fast> : f32
+  %flush = arith.flush_denormals %arg0 fastmath<fast> : f32
 // CHECK: {{.*}} = arith.addf %arg0, %arg1 : f32
   %6 = arith.addf %arg0, %arg1 fastmath<none> : f32
 // CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32

diff  --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index e6c60001e5aee..6095a55ac89ca 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -23,6 +23,14 @@ func.func @foo() -> (f8E4M3FN, f32) {
   return %cst1, %cst2 : f8E4M3FN, f32
 }
 
+// Put the denormal in a separate function so that the flush_denormals folder
+// does not collapse it before we reach the runtime.
+// 2^-9 = 0.001953125 is the smallest positive f8E4M3FN denormal.
+func.func @denormal_f8() -> f8E4M3FN {
+  %cst = arith.constant 0.001953125 : f8E4M3FN
+  return %cst : f8E4M3FN
+}
+
 func.func @entry() {
   %a1 = arith.constant 1.4 : f8E4M3FN
   %a2 = arith.constant 1.4 : f32
@@ -78,5 +86,16 @@ func.func @entry() {
   %cvt_from_unsigned_int = arith.uitofp %c9 : i16 to f4E2M1FN
   vector.print %cvt_from_unsigned_int : f4E2M1FN
 
+  // flush_denormals on an f8E4M3FN denormal: returns +0.0.
+  %denormal = func.call @denormal_f8() : () -> f8E4M3FN
+  %flushed_denormal = arith.flush_denormals %denormal : f8E4M3FN
+  // CHECK-NEXT: 0
+  vector.print %flushed_denormal : f8E4M3FN
+
+  // flush_denormals on a normal f8E4M3FN value (%cvt == 2.25): passes through.
+  %flushed_normal = arith.flush_denormals %cvt : f8E4M3FN
+  // CHECK-NEXT: 2.25
+  vector.print %flushed_normal : f8E4M3FN
+
   return
 }


        


More information about the Mlir-commits mailing list