[Mlir-commits] [mlir] [mlir][arith] Add `flush_denormals` operations (PR #192641)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 17 04:49:43 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Add a new `arith.flush_denormals` operation. The operation takes a floating-point value as input. If the value is denormal (subnormal), it returns a zero with the same sign; otherwise, it returns the input unchanged. This commit also adds a folder for the new op, plus support for it in the `ArithToAPFloat` conversion and in the APFloat runtime wrappers.

Running example:
```mlir
%flush_a = arith.flush_denormals %a : f32
%flush_b = arith.flush_denormals %b : f32
%res = arith.addf %flush_a, %flush_b : f32
%flush_res = arith.flush_denormals %res : f32
```

The hardware lowering path depends on the backend and is not implemented as part of this PR; only the APFloat emulation path is. Possible approaches:
- Per-instruction mode. E.g., on NVIDIA architectures, the above example can lower to `add.ftz.f32 dest, a, b`.
- Global status register. E.g., on `x86_64`, the above example can lower to `_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); r = a + b; _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF); _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);`. When several such regions follow each other, the intermediate OFF/ON switch pairs are redundant and can be folded away.
- Emulation via integer arithmetic. Inspect the bit pattern of the input float (the layout depends on the specific FP type) and select either the input or a zero constant of the same sign. This lowering approach works on all architectures.

Assisted-by: claude-opus-4.7-thinking-high


---
Full diff: https://github.com/llvm/llvm-project/pull/192641.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+34) 
- (modified) mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp (+23-13) 
- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+22) 
- (modified) mlir/lib/ExecutionEngine/APFloatWrappers.cpp (+11) 
- (modified) mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir (+26) 
- (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+23) 
- (modified) mlir/test/Dialect/Arith/ops.mlir (+32) 
- (modified) mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir (+19) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index e6a29180066e9..6317f99ace410 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -990,6 +990,40 @@ def Arith_NegFOp : Arith_FloatUnaryOp<"negf"> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// FlushDenormalsOp
+//===----------------------------------------------------------------------===//
+
+def Arith_FlushDenormalsOp : Arith_FloatUnaryOp<"flush_denormals"> {
+  let summary = "flush denormal floating-point values to zero";
+  let description = [{
+    The `flush_denormals` operation takes a floating-point value and returns
+    the input value if it is a normal (or zero, infinity, or NaN) value, or
+    a zero of the same type if the input is a denormal (subnormal) value.
+    The sign is preserved when flushing: negative denormals flush to `-0.0`
+    and positive ones to `+0.0` (types without `-0.0` always yield `+0.0`).
+
+    The input and result are required to be the same type. This type may be
+    a floating-point scalar type, a vector whose element type is a
+    floating-point type, or a tensor of floats. When operating on vectors
+    or tensors, the operation is applied elementwise.
+
+    Example:
+
+    ```mlir
+    // Scalar denormal flushing.
+    %a = arith.flush_denormals %b : f32
+
+    // SIMD vector element-wise denormal flushing.
+    %f = arith.flush_denormals %g : vector<4xf32>
+
+    // Tensor element-wise denormal flushing.
+    %x = arith.flush_denormals %y : tensor<4x?xbf16>
+    ```
+  }];
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 98185697e4591..bf80a12429bdf 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -447,12 +447,17 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
   SymbolOpInterface symTable;
 };
 
-struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
-  NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
-                            PatternBenefit benefit = 1)
-      : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {}
+/// Rewrite a unary floating-point op (same input/output float type) to an
+/// APFloat runtime call of the form `(i32 semantics, i64 bits) -> i64 bits`.
+template <typename OpTy>
+struct UnaryFloatOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+  UnaryFloatOpToAPFloatConversion(MLIRContext *context, const char *APFloatName,
+                                  SymbolOpInterface symTable,
+                                  PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+        APFloatName(APFloatName) {}
 
-  LogicalResult matchAndRewrite(arith::NegFOp op,
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
     if (failed(checkPreconditions(rewriter, op)))
       return failure();
@@ -460,8 +465,9 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
     // Get APFloat function from runtime library.
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
-    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
-        rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type});
+    std::string funcName = (llvm::Twine("_mlir_apfloat_") + APFloatName).str();
+    FailureOr<FuncOp> fn =
+        lookupOrCreateFnDecl(rewriter, symTable, funcName, {i32Type, i64Type});
     if (failed(fn))
       return fn;
 
@@ -481,14 +487,14 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
           // Call APFloat function.
           Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
           SmallVector<Value> params = {semValue, operandBits};
-          Value negatedBits =
+          Value resultBits =
               func::CallOp::create(rewriter, loc, TypeRange(i64Type),
                                    SymbolRefAttr::get(*fn), params)
                   ->getResult(0);
 
           // Truncate result to the original width.
           Value truncatedBits =
-              arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+              arith::TruncIOp::create(rewriter, loc, intWType, resultBits);
           return arith::BitcastOp::create(rewriter, loc, floatTy,
                                           truncatedBits);
         });
@@ -497,6 +503,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
   }
 
   SymbolOpInterface symTable;
+  const char *APFloatName;
 };
 
 namespace {
@@ -528,10 +535,13 @@ void ArithToAPFloatConversionPass::runOnOperation() {
       context, "minimum", getOperation());
   patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>(
       context, "maximum", getOperation());
-  patterns
-      .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
-           CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
-          context, getOperation());
+  patterns.add<FpToFpConversion<arith::ExtFOp>,
+               FpToFpConversion<arith::TruncFOp>, CmpFOpToAPFloatConversion>(
+      context, getOperation());
+  patterns.add<UnaryFloatOpToAPFloatConversion<arith::NegFOp>>(context, "neg",
+                                                               getOperation());
+  patterns.add<UnaryFloatOpToAPFloatConversion<arith::FlushDenormalsOp>>(
+      context, "flush_denormals", getOperation());
   patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
                                                    /*isUnsigned=*/false);
   patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index e11a38ffec50c..fef8fd210a495 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1104,6 +1104,28 @@ OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
                                      [](const APFloat &a) { return -a; });
 }
 
+//===----------------------------------------------------------------------===//
+// FlushDenormalsOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::FlushDenormalsOp::fold(FoldAdaptor adaptor) {
+  // TODO: Fold flush_denormals if the floating-point type does not support
+  // denormals. There is currently no API to query this information from
+  // APFloat.
+
+  // Idempotent: flush_denormals(flush_denormals(x)) -> flush_denormals(x).
+  if (auto op = this->getOperand().getDefiningOp<arith::FlushDenormalsOp>())
+    return op.getResult(); // Reuse the inner op's result.
+
+  // Constant-fold: a denormal becomes a zero carrying the input's sign.
+  return constFoldUnaryOp<FloatAttr>(
+      adaptor.getOperands(), [](const APFloat &a) {
+        if (a.isDenormal())
+          return APFloat::getZero(a.getSemantics(), a.isNegative());
+        return a; // Normal, zero, Inf and NaN values pass through.
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index b30a9072c0094..f23bcdda9ad1b 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -163,6 +163,17 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_abs(int32_t semantics,
   return abs(x).bitcastToAPInt().getZExtValue();
 }
 
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t
+_mlir_apfloat_flush_denormals(int32_t semantics, uint64_t a) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); // `a` holds the raw bits.
+  if (x.isDenormal()) // Only subnormals are flushed.
+    x = llvm::APFloat::getZero(sem, x.isNegative()); // Keep the sign.
+  return x.bitcastToAPInt().getZExtValue(); // Zero-extended to 64 bits.
+}
+
 MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isfinite(int32_t semantics,
                                                          uint64_t a) {
   const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
diff --git a/mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir
index bd4a9dac597e7..1128e5c39c68f 100644
--- a/mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir
@@ -226,6 +226,32 @@ func.func @negf(%arg0: f32) {
 
 // -----
 
+// CHECK: func.func private @_mlir_apfloat_flush_denormals(i32, i64) -> i64
+// CHECK-LABEL: func.func @flush_denormals
+// CHECK: %[[bc:.*]] = arith.bitcast %{{.*}} : f32 to i32
+// CHECK: %[[ext:.*]] = arith.extui %[[bc]] : i32 to i64
+// CHECK: %[[sem:.*]] = arith.constant 2 : i32
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_flush_denormals(%[[sem]], %[[ext]]) : (i32, i64) -> i64
+// CHECK: %[[trunc:.*]] = arith.trunci %[[res]] : i64 to i32
+// CHECK: arith.bitcast %[[trunc]] : i32 to f32
+func.func @flush_denormals(%arg0: f32) {
+  %0 = arith.flush_denormals %arg0 : f32
+  return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_flush_denormals(i32, i64) -> i64
+// CHECK-LABEL: func.func @flush_denormals_f8
+// CHECK: %[[sem:.*]] = arith.constant 10 : i32
+// CHECK: call @_mlir_apfloat_flush_denormals(%[[sem]], %{{.*}}) : (i32, i64) -> i64
+func.func @flush_denormals_f8(%arg0: f8E4M3FN) {
+  %0 = arith.flush_denormals %arg0 : f8E4M3FN
+  return
+}
+
+// -----
+
 // CHECK: func.func private @_mlir_apfloat_minimum(i32, i64, i64) -> i64
 // CHECK: %[[sem:.*]] = arith.constant 2 : i32
 // CHECK: %[[res:.*]] = call @_mlir_apfloat_minimum(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index b153bd7c32261..a3a0dc7adf7cc 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3094,6 +3094,29 @@ func.func @test_negf1(%f : f32) -> (f32) {
 
 // -----
 
+// CHECK-LABEL: @test_flush_denormals_const(
+// CHECK: %[[res:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: return %[[res]]
+func.func @test_flush_denormals_const() -> (f32) {
+  %c = arith.constant 1.0e-40 : f32
+  %0 = arith.flush_denormals %c : f32
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_flush_denormals_idempotent(
+// CHECK-SAME: %[[arg0:.+]]:
+// CHECK: %[[res:.+]] = arith.flush_denormals %[[arg0]] : f32
+// CHECK: return %[[res]]
+func.func @test_flush_denormals_idempotent(%f : f32) -> (f32) {
+  %0 = arith.flush_denormals %f : f32
+  %1 = arith.flush_denormals %0 : f32
+  return %1 : f32
+}
+
+// -----
+
 // CHECK-LABEL: @test_remui(
 // CHECK: %[[res:.+]] = arith.constant dense<[0, 0, 4, 2]> : vector<4xi32>
 // CHECK: return %[[res]]
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 3874c85818eb4..059e35c384dac 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -481,6 +481,36 @@ func.func @test_negf_scalable_vector(%arg0 : vector<[8]xf64>) -> vector<[8]xf64>
   return %0 : vector<[8]xf64>
 }
 
+// CHECK-LABEL: test_flush_denormals
+func.func @test_flush_denormals(%arg0 : f32) -> f32 {
+  %0 = arith.flush_denormals %arg0 : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: test_flush_denormals_tensor
+func.func @test_flush_denormals_tensor(%arg0 : tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %0 = arith.flush_denormals %arg0 : tensor<8x8xf32>
+  return %0 : tensor<8x8xf32>
+}
+
+// CHECK-LABEL: test_flush_denormals_vector
+func.func @test_flush_denormals_vector(%arg0 : vector<8xf32>) -> vector<8xf32> {
+  %0 = arith.flush_denormals %arg0 : vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: test_flush_denormals_scalable_vector
+func.func @test_flush_denormals_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xf32> {
+  %0 = arith.flush_denormals %arg0 : vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
+
+// CHECK-LABEL: test_flush_denormals_bf16
+func.func @test_flush_denormals_bf16(%arg0 : bf16) -> bf16 {
+  %0 = arith.flush_denormals %arg0 : bf16
+  return %0 : bf16
+}
+
 // CHECK-LABEL: test_addf
 func.func @test_addf(%arg0 : f64, %arg1 : f64) -> f64 {
   %0 = arith.addf %arg0, %arg1 : f64
@@ -1216,12 +1246,14 @@ func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
 // CHECK: {{.*}} = arith.divf %arg0, %arg1 fastmath<fast> : f32
 // CHECK: {{.*}} = arith.remf %arg0, %arg1 fastmath<fast> : f32
 // CHECK: {{.*}} = arith.negf %arg0 fastmath<fast> : f32
+// CHECK: {{.*}} = arith.flush_denormals %arg0 fastmath<fast> : f32
   %0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
   %1 = arith.subf %arg0, %arg1 fastmath<fast> : f32
   %2 = arith.mulf %arg0, %arg1 fastmath<fast> : f32
   %3 = arith.divf %arg0, %arg1 fastmath<fast> : f32
   %4 = arith.remf %arg0, %arg1 fastmath<fast> : f32
   %5 = arith.negf %arg0 fastmath<fast> : f32
+  %flush = arith.flush_denormals %arg0 fastmath<fast> : f32
 // CHECK: {{.*}} = arith.addf %arg0, %arg1 : f32
   %6 = arith.addf %arg0, %arg1 fastmath<none> : f32
 // CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index e6c60001e5aee..6095a55ac89ca 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -23,6 +23,14 @@ func.func @foo() -> (f8E4M3FN, f32) {
   return %cst1, %cst2 : f8E4M3FN, f32
 }
 
+// Put the denormal in a separate function so that the flush_denormals folder
+// does not collapse it before we reach the runtime.
+// 2^-9 = 0.001953125 is the smallest positive f8E4M3FN denormal.
+func.func @denormal_f8() -> f8E4M3FN {
+  %cst = arith.constant 0.001953125 : f8E4M3FN
+  return %cst : f8E4M3FN
+}
+
 func.func @entry() {
   %a1 = arith.constant 1.4 : f8E4M3FN
   %a2 = arith.constant 1.4 : f32
@@ -78,5 +86,16 @@ func.func @entry() {
   %cvt_from_unsigned_int = arith.uitofp %c9 : i16 to f4E2M1FN
   vector.print %cvt_from_unsigned_int : f4E2M1FN
 
+  // flush_denormals on an f8E4M3FN denormal: returns +0.0.
+  %denormal = func.call @denormal_f8() : () -> f8E4M3FN
+  %flushed_denormal = arith.flush_denormals %denormal : f8E4M3FN
+  // CHECK-NEXT: 0
+  vector.print %flushed_denormal : f8E4M3FN
+
+  // flush_denormals on a normal f8E4M3FN value (%cvt == 2.25): passes through.
+  %flushed_normal = arith.flush_denormals %cvt : f8E4M3FN
+  // CHECK-NEXT: 2.25
+  vector.print %flushed_normal : f8E4M3FN
+
   return
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/192641


More information about the Mlir-commits mailing list