[Mlir-commits] [mlir] [mlir][arith] Add mulf(x, 0) -> 0 to mulf folder (PR #161395)
Xiang Li
llvmlistbot at llvm.org
Tue Sep 30 10:01:08 PDT 2025
https://github.com/python3kgae updated https://github.com/llvm/llvm-project/pull/161395
>From 7e4a0d4dd46002776015c2c715a0e301d07c31ac Mon Sep 17 00:00:00 2001
From: Xiang Li <xiagli at microsoft.com>
Date: Tue, 30 Sep 2025 15:45:15 +0000
Subject: [PATCH 1/2] [mlir][arith] Add mulf(x, 0) -> 0 to mulf folder
Fold `mulf(x, 0) -> 0`.
Updated the yield_constant_loop test in mlir/test/Dialect/SCF/loop-pipelining.mlir
to work around a [TODO](https://github.com/llvm/llvm-project/blob/main/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp#L163) in TestSCFUtils.cpp.
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp     |  9 +++++++++
mlir/test/Dialect/Arith/canonicalize.mlir  | 24 ++++++++++++++++++++++++
mlir/test/Dialect/SCF/loop-pipelining.mlir | 12 ++++++------
3 files changed, 39 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7cfd6d3a98df8..676297f56ac0f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1281,6 +1281,15 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
   // mulf(x, 1) -> x
   if (matchPattern(adaptor.getRhs(), m_OneFloat()))
     return getLhs();
+  // mulf(x, 0) -> 0. Only valid under fastmath<nnan, nsz>: with strict
+  // IEEE-754 semantics NaN * 0 and Inf * 0 are NaN, and a negative x times
+  // +0.0 is -0.0. nnan makes a NaN result poison (covering Inf * 0) and nsz
+  // lets the sign of the zero result be ignored.
+  const arith::FastMathFlags zeroFoldFlags =
+      arith::FastMathFlags::nnan | arith::FastMathFlags::nsz;
+  if (arith::bitEnumContainsAll(getFastmath(), zeroFoldFlags) &&
+      matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
+    return getRhs();
 
   return constFoldBinaryOp<FloatAttr>(
       adaptor.getOperands(),
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index ca3de3a2d7703..4c72a1bb27b01 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2216,6 +2216,30 @@ func.func @test_mulf1(%arg0 : f32, %arg1 : f32) -> (f32) {
   return %2 : f32
 }
 
+// CHECK-LABEL: @test_mulf2(
+func.func @test_mulf2(%arg0 : f32, %arg1 : f32) -> (f32, f32) {
+  // CHECK-NEXT: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK-NEXT: return %[[C0]], %[[C0]]
+  %c0 = arith.constant 0.0 : f32
+  %0 = arith.mulf %arg0, %c0 fastmath<nnan,nsz> : f32
+  %1 = arith.mulf %c0, %arg1 fastmath<nnan,nsz> : f32
+  return %0, %1 : f32, f32
+}
+
+// Without both nnan and nsz the fold is unsound (x may be NaN/Inf, giving
+// NaN, or negative, giving -0.0), so mulf(x, 0) must be kept.
+// CHECK-LABEL: @test_mulf2_nofold(
+func.func @test_mulf2_nofold(%arg0 : f32) -> (f32, f32) {
+  // CHECK: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[M0:.+]] = arith.mulf %arg0, %[[C0]] : f32
+  // CHECK: %[[M1:.+]] = arith.mulf %arg0, %[[C0]] fastmath<nnan> : f32
+  // CHECK: return %[[M0]], %[[M1]]
+  %c0 = arith.constant 0.0 : f32
+  %0 = arith.mulf %arg0, %c0 : f32
+  %1 = arith.mulf %arg0, %c0 fastmath<nnan> : f32
+  return %0, %1 : f32, f32
+}
+
 // -----
 
 // CHECK-LABEL: @test_divf(
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 86af637fc05d7..11dc55c7ebb17 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -930,7 +930,7 @@ func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: i
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[CST10:.*]] = arith.constant 1.000000e+01 : f32
// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
// Prologue:
// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
@@ -938,15 +938,15 @@ func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: i
// CHECK-NEXT: %[[L1:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
// CHECK-SAME: step %[[C1]] iter_args(%[[ARG0:.*]] = %[[CST2]], %[[ARG1:.*]] = %[[L0]]) -> (f32, f32) {
// CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG1]], %[[ARG0]] : f32
-// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CST0]] : f32
+// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CST10]] : f32
// CHECK-NEXT: memref.store %[[MUL0]], %[[A]][%[[IV]]] : memref<?xf32>
// CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
-// CHECK-NEXT: scf.yield %[[CST0]], %[[L2]] : f32
+// CHECK-NEXT: scf.yield %[[CST10]], %[[L2]] : f32
// CHECK-NEXT: }
// Epilogue:
-// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[L1]]#1, %[[CST0]] : f32
-// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CST0]] : f32
+// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[L1]]#1, %[[CST10]] : f32
+// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CST10]] : f32
// CHECK-NEXT: memref.store %[[MUL1]], %[[A]][%[[C3]]] : memref<?xf32>
// CHECK-NEXT: return %[[L1]]#0 : f32
@@ -954,7 +954,7 @@ func.func @yield_constant_loop(%A: memref<?xf32>) -> f32 {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
- %cf0 = arith.constant 0.0 : f32
+ %cf0 = arith.constant 10.0 : f32
%cf2 = arith.constant 2.0 : f32
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf2) -> f32 {
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>
>From 8d49f3fc48dd629332492a850ea20de054306c58 Mon Sep 17 00:00:00 2001
From: Xiang Li <xiagli at microsoft.com>
Date: Tue, 30 Sep 2025 17:00:30 +0000
Subject: [PATCH 2/2] [mlir][arith] Fold mulf(NaN, x) and mulf(x, NaN) to NaN
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp    |  7 +++++++
mlir/test/Dialect/Arith/canonicalize.mlir | 13 +++++++++++++
2 files changed, 20 insertions(+)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 676297f56ac0f..60269a2f77c0e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1281,6 +1281,13 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
   // mulf(x, 1) -> x
   if (matchPattern(adaptor.getRhs(), m_OneFloat()))
     return getLhs();
+  // mulf(NaN, x) -> NaN and mulf(x, NaN) -> NaN. Any NaN operand makes the
+  // product NaN, so no fast-math flags are required. Both sides are checked
+  // because commutative canonicalization moves constants to the RHS. The NaN
+  // operand is returned as-is (a signaling NaN is not quieted).
+  if (matchPattern(adaptor.getLhs(), m_NaNFloat()))
+    return getLhs();
+  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
+    return getRhs();
   // mulf(x, 0) -> 0. Only valid under fastmath<nnan, nsz>: with strict
   // IEEE-754 semantics NaN * 0 and Inf * 0 are NaN, and a negative x times
   // +0.0 is -0.0. nnan makes a NaN result poison (covering Inf * 0) and nsz
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4c72a1bb27b01..195d4fc8f5e92 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2240,6 +2240,19 @@ func.func @test_mulf2_nofold(%arg0 : f32) -> (f32, f32) {
   return %0, %1 : f32, f32
 }
 
+// CHECK-LABEL: @test_mulf3(
+func.func @test_mulf3(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) {
+  // CHECK-NEXT: %[[NAN:.+]] = arith.constant 0x7FC00000 : f32
+  // CHECK-NEXT: return %[[NAN]], %[[NAN]], %[[NAN]], %[[NAN]]
+  %c0 = arith.constant 0.0 : f32
+  %nan = arith.constant 0x7FC00000 : f32
+  %0 = arith.mulf %nan, %c0 : f32
+  %1 = arith.mulf %c0, %nan : f32
+  %2 = arith.mulf %arg0, %nan : f32
+  %3 = arith.mulf %nan, %arg1 : f32
+  return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+
 // -----
 
 // CHECK-LABEL: @test_divf(
More information about the Mlir-commits
mailing list