[Mlir-commits] [mlir] [MLIR][Arith] Add subf(-0, x) -> negf canonicalization rule for SubFOp (PR #194245)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Apr 26 09:38:27 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Max Graey (MaxGraey)

<details>
<summary>Changes</summary>

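Unlike LLVM, the Arith dialect currently does not canonicalize `subf(-0.0, x)` to `negf(x)`. This PR adds that pattern; it is skipped when a non-default rounding mode is specified.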
Example: https://godbolt.org/z/1v5jGTsh1

Proof: https://alive2.llvm.org/ce/z/Wq8ALG

---
Full diff: https://github.com/llvm/llvm-project/pull/194245.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+1) 
- (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+18) 
- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+5) 
- (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+40-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ba9ccb6a01d66..f359070b6842f 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1089,6 +1089,7 @@ def Arith_SubFOp : Arith_FloatBinaryOpWithRoundingMode<"subf"> {
     ```
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index b822f1eadf0eb..66f4ace265201 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -432,6 +432,24 @@ def UIToFPOfExtUI :
     Pat<(Arith_UIToFPOp (Arith_ExtUIOp $x, $nneg1), $nneg2),
         (Arith_UIToFPOp $x, $nneg1)>;
 
+//===----------------------------------------------------------------------===//
+// SubFOp
+//===----------------------------------------------------------------------===//
+
+// Match a constant (scalar or splat) -0.0 float.
+def IsNegZeroFloat :
+    Constraint<CPred<"m_NegZeroFloat().match($0)">,
+               "is negative-zero float constant">;
+
+// subf(-0, x) -> negf(x)
+// TODO: Verify if this canonicalization is safe when a rounding mode is
+// specified. For the moment, bail on custom rounding modes.
+def SubFOfNegZero :
+    Pat<(Arith_SubFOp (ConstantLikeMatcher AnyAttr:$c), $x, $fmf, $rm),
+        (Arith_NegFOp $x, $fmf),
+        [(IsNegZeroFloat $c),
+         (Constraint<CPred<"$0 == nullptr">, "default rounding mode"> $rm)]>;
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 36d0f093c6917..32ac5d606c9a0 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1167,6 +1167,11 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
       });
 }
 
+void arith::SubFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.add<SubFOfNegZero>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MaximumFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 9fde39a110473..02626bef856d3 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2500,6 +2500,38 @@ func.func @test_subf(%arg0 : f16) -> (f16, f16, f16) {
   return %0, %1, %2 : f16, f16, f16
 }
 
+// CHECK-LABEL: @test_subf_negzero(
+//  CHECK-SAME: %[[ARG0:.+]]: f16
+func.func @test_subf_negzero(%arg0 : f16) -> f16 {
+  // CHECK-NEXT:  %[[X:.+]] = arith.negf %[[ARG0]] : f16
+  // CHECK-NEXT:  return %[[X]] : f16
+  %c-0 = arith.constant -0.0 : f16
+  %0 = arith.subf %c-0, %arg0 : f16
+  return %0 : f16
+}
+
+// subf(+0, x) must NOT fold to negf(x)
+// CHECK-LABEL: @test_subf_poszero_no_negf(
+//  CHECK-SAME: %[[ARG0:.+]]: f16
+func.func @test_subf_poszero_no_negf(%arg0 : f16) -> f16 {
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
+  // CHECK-NEXT:  %[[X:.+]] = arith.subf %[[C0]], %[[ARG0]] : f16
+  // CHECK-NEXT:  return %[[X]] : f16
+  %c0 = arith.constant 0.0 : f16
+  %0 = arith.subf %c0, %arg0 : f16
+  return %0 : f16
+}
+
+// CHECK-LABEL: @test_subf_negzero_splat(
+//  CHECK-SAME: %[[ARG0:.+]]: vector<4xf32>
+func.func @test_subf_negzero_splat(%arg0 : vector<4xf32>) -> vector<4xf32> {
+  // CHECK-NEXT:  %[[X:.+]] = arith.negf %[[ARG0]] : vector<4xf32>
+  // CHECK-NEXT:  return %[[X]] : vector<4xf32>
+  %c-0 = arith.constant dense<-0.0> : vector<4xf32>
+  %0 = arith.subf %c-0, %arg0 : vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
 // -----
 
 // CHECK-LABEL: @test_mulf(
@@ -2587,10 +2619,12 @@ func.func @test_addf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
 
 // CHECK-LABEL: @test_subf_rounding_mode(
 // CHECK-SAME: %[[ARG0:.+]]: f32
-func.func @test_subf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+func.func @test_subf_rounding_mode(%arg0 : f32) -> (f32, f32, f32, f32) {
+  // CHECK-DAG:  %[[NZ:.+]] = arith.constant -0.000000e+00 : f32
   // CHECK-DAG:  %[[UP:.+]] = arith.constant 2.00000024 : f32
   // CHECK-DAG:  %[[DOWN:.+]] = arith.constant 2.000000e+00 : f32
-  // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+  // CHECK:      %[[NEG:.+]] = arith.subf %[[NZ]], %[[ARG0]] downward : f32
+  // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]], %[[NEG]]
   %a = arith.constant 1.0000001 : f32
   %b = arith.constant -1.0 : f32
   // subf(x, +0) folds even with a rounding mode.
@@ -2598,7 +2632,10 @@ func.func @test_subf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
   %0 = arith.subf %arg0, %c0 downward : f32
   %1 = arith.subf %a, %b upward : f32
   %2 = arith.subf %a, %b downward : f32
-  return %0, %1, %2 : f32, f32, f32
+  // subf(-0, x) must NOT fold to negf when a custom rounding mode is set.
+  %c-0 = arith.constant -0.0 : f32
+  %3 = arith.subf %c-0, %arg0 downward : f32
+  return %0, %1, %2, %3 : f32, f32, f32, f32
 }
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/194245


More information about the Mlir-commits mailing list