[Mlir-commits] [mlir] [mlir][arith] Add mulf(x, 0) -> 0 to mulf folder (PR #161395)

Xiang Li llvmlistbot at llvm.org
Tue Sep 30 10:01:08 PDT 2025


https://github.com/python3kgae updated https://github.com/llvm/llvm-project/pull/161395

>From 7e4a0d4dd46002776015c2c715a0e301d07c31ac Mon Sep 17 00:00:00 2001
From: Xiang Li <xiagli at microsoft.com>
Date: Tue, 30 Sep 2025 15:45:15 +0000
Subject: [PATCH 1/2] [mlir][arith] Add mulf(x, 0) -> 0 to mulf folder

Fold `mulf(x, 0) -> 0`.

Updated the yield_constant_loop test in mlir/test/Dialect/SCF/loop-pipelining.mlir
 (its 0.0 constant is now 10.0, so the new fold does not apply) to work around a [TODO](https://github.com/llvm/llvm-project/blob/main/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp#L163) in TestSCFUtils.cpp
---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp     |  3 +++
 mlir/test/Dialect/Arith/canonicalize.mlir  | 10 ++++++++++
 mlir/test/Dialect/SCF/loop-pipelining.mlir | 12 ++++++------
 3 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7cfd6d3a98df8..676297f56ac0f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1281,6 +1281,9 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
   // mulf(x, 1) -> x
   if (matchPattern(adaptor.getRhs(), m_OneFloat()))
     return getLhs();
+  // mulf(x, 0) -> 0
+  if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
+    return getRhs();
 
   return constFoldBinaryOp<FloatAttr>(
       adaptor.getOperands(),
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index ca3de3a2d7703..4c72a1bb27b01 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2216,6 +2216,16 @@ func.func @test_mulf1(%arg0 : f32, %arg1 : f32) -> (f32) {
   return %2 : f32
 }
 
+// CHECK-LABEL: @test_mulf2(
+func.func @test_mulf2(%arg0 : f32, %arg1 : f32) -> (f32, f32) {
+  // CHECK-NEXT:  %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK-NEXT:  return %[[C0]], %[[C0]]
+  %c0 = arith.constant 0.0 : f32
+  %0 = arith.mulf %arg0, %c0 : f32
+  %1 = arith.mulf %c0, %arg1 : f32
+  return %0, %1  : f32, f32
+}
+
 // -----
 
 // CHECK-LABEL: @test_divf(
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 86af637fc05d7..11dc55c7ebb17 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -930,7 +930,7 @@ func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: i
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
-//   CHECK-DAG:   %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+//   CHECK-DAG:   %[[CST10:.*]] = arith.constant 1.000000e+01 : f32
 //   CHECK-DAG:   %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
 // Prologue:
 //       CHECK:   %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
@@ -938,15 +938,15 @@ func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: i
 //  CHECK-NEXT:   %[[L1:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
 //  CHECK-SAME:     step %[[C1]] iter_args(%[[ARG0:.*]] = %[[CST2]], %[[ARG1:.*]] = %[[L0]]) -> (f32, f32) {
 //  CHECK-NEXT:     %[[ADD0:.*]] = arith.addf %[[ARG1]], %[[ARG0]] : f32
-//  CHECK-NEXT:     %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CST0]] : f32
+//  CHECK-NEXT:     %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CST10]] : f32
 //  CHECK-NEXT:     memref.store %[[MUL0]], %[[A]][%[[IV]]] : memref<?xf32>
 //  CHECK-NEXT:     %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
 //  CHECK-NEXT:     %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
-//  CHECK-NEXT:     scf.yield %[[CST0]], %[[L2]] : f32
+//  CHECK-NEXT:     scf.yield %[[CST10]], %[[L2]] : f32
 //  CHECK-NEXT:   }
 // Epilogue:
-//  CHECK-NEXT:   %[[ADD1:.*]] = arith.addf %[[L1]]#1, %[[CST0]] : f32
-//  CHECK-NEXT:   %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CST0]] : f32
+//  CHECK-NEXT:   %[[ADD1:.*]] = arith.addf %[[L1]]#1, %[[CST10]] : f32
+//  CHECK-NEXT:   %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CST10]] : f32
 //  CHECK-NEXT:   memref.store %[[MUL1]], %[[A]][%[[C3]]] : memref<?xf32>
 //  CHECK-NEXT:   return %[[L1]]#0 : f32
 
@@ -954,7 +954,7 @@ func.func @yield_constant_loop(%A: memref<?xf32>) -> f32 {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
-  %cf0 = arith.constant 0.0 : f32
+  %cf0 = arith.constant 10.0 : f32
   %cf2 = arith.constant 2.0 : f32
   %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf2) -> f32 {
     %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>

>From 8d49f3fc48dd629332492a850ea20de054306c58 Mon Sep 17 00:00:00 2001
From: Xiang Li <xiagli at microsoft.com>
Date: Tue, 30 Sep 2025 17:00:30 +0000
Subject: [PATCH 2/2] mulf(NaN, x) -> NaN

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp    |  3 +++
 mlir/test/Dialect/Arith/canonicalize.mlir | 11 +++++++++++
 2 files changed, 14 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 676297f56ac0f..60269a2f77c0e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1281,6 +1281,9 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
   // mulf(x, 1) -> x
   if (matchPattern(adaptor.getRhs(), m_OneFloat()))
     return getLhs();
+  // mulf(NaN, x) -> NaN
+  if (matchPattern(adaptor.getLhs(), m_NaNFloat()))
+    return getLhs();
   // mulf(x, 0) -> 0
   if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
     return getRhs();
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4c72a1bb27b01..195d4fc8f5e92 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2226,6 +2226,17 @@ func.func @test_mulf2(%arg0 : f32, %arg1 : f32) -> (f32, f32) {
   return %0, %1  : f32, f32
 }
 
+// CHECK-LABEL: @test_mulf3(
+func.func @test_mulf3(%arg0 : f32, %arg1 : f32) -> (f32, f32) {
+  // CHECK-NEXT:  %[[NAN:.+]] = arith.constant 0x7FC00000 : f32
+  // CHECK-NEXT:  return %[[NAN]], %[[NAN]]
+  %c0 = arith.constant 0.0 : f32
+  %nan = arith.constant 0x7FC00000 : f32
+  %0 = arith.mulf %nan, %c0 : f32
+  %1 = arith.mulf %c0, %nan : f32
+  return %0, %1  : f32, f32
+}
+
 // -----
 
 // CHECK-LABEL: @test_divf(



More information about the Mlir-commits mailing list