[Mlir-commits] [mlir] 2d06374 - [mlir][arith] Add mulf(x, 0) -> 0 to mulf folder (#161395)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 1 19:47:57 PDT 2025
Author: Xiang Li
Date: 2025-10-01T22:47:53-04:00
New Revision: 2d0637494936be3742750ab95b856e3cb86d1198
URL: https://github.com/llvm/llvm-project/commit/2d0637494936be3742750ab95b856e3cb86d1198
DIFF: https://github.com/llvm/llvm-project/commit/2d0637494936be3742750ab95b856e3cb86d1198.diff
LOG: [mlir][arith] Add mulf(x, 0) -> 0 to mulf folder (#161395)
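Fold `mulf(x, 0) -> 0` when both the `nnan` and `nsz` fast-math flags are set on the op. Both are required: without `nnan`, `x * 0` may be NaN (e.g. for infinite `x`), and without `nsz`, the sign of the resulting zero depends on the sign of `x`.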
Added:
Modified:
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7cfd6d3a98df8..898d76ce8d9b5 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1282,6 +1282,13 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_OneFloat()))
return getLhs();
+ if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
+ arith::FastMathFlags::nsz)) {
+ // mulf(x, 0) -> 0
+ if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
+ return getRhs();
+ }
+
return constFoldBinaryOp<FloatAttr>(
adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) { return a * b; });
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index ca3de3a2d7703..2fe0995c9d4df 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2216,6 +2216,18 @@ func.func @test_mulf1(%arg0 : f32, %arg1 : f32) -> (f32) {
return %2 : f32
}
+// CHECK-LABEL: @test_mulf2(
+func.func @test_mulf2(%arg0 : f32) -> (f32, f32) {
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: %[[C0n:.+]] = arith.constant -0.000000e+00 : f32
+ // CHECK-NEXT: return %[[C0]], %[[C0n]]
+ %c0 = arith.constant 0.0 : f32
+ %c0n = arith.constant -0.0 : f32
+ %0 = arith.mulf %c0, %arg0 fastmath<nnan,nsz> : f32
+ %1 = arith.mulf %c0n, %arg0 fastmath<nnan,nsz> : f32
+ return %0, %1 : f32, f32
+}
+
// -----
// CHECK-LABEL: @test_divf(
More information about the Mlir-commits
mailing list