[Mlir-commits] [mlir] [mlir][arith] Fold `(a * b) / b -> a` (PR #121534)
Ivan Butygin
llvmlistbot at llvm.org
Thu Jan 2 18:33:47 PST 2025
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/121534
>From bdf8ed01b611b664245ca0bc0801a98eb44a312c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 3 Jan 2025 02:31:26 +0100
Subject: [PATCH 1/3] [mlir][arith] Fold `(a * b) / b`
Alive2 check: https://alive2.llvm.org/ce/z/5XWjWE
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 21 +++++++++++++++++++++
mlir/test/Dialect/Arith/canonicalize.mlir | 20 ++++++++++++++++++++
2 files changed, 41 insertions(+)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d8b314a3fa43c0..e486bb678ce33e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -580,11 +580,29 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns(
// DivUIOp
//===----------------------------------------------------------------------===//
+static Value foldDivMul(Value lhs, Value rhs,
+ arith::IntegerOverflowFlags ovfFlags) {
+ auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
+ if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags)) // The muli must carry every flag in ovfFlags (nuw for divui, nsw for divsi); without no-wrap, a * b may have wrapped and (a * b) / b need not equal a.
+ return {};
+
+ if (mul.getLhs() == rhs) // Divisor is the muli's LHS: (b * a) / b -> a.
+ return mul.getRhs();
+
+ if (mul.getRhs() == rhs) // Divisor is the muli's RHS: (a * b) / b -> a.
+ return mul.getLhs();
+
+ return {}; // Divisor is neither muli operand; a null Value tells the caller "no fold".
+}
+
OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
// divui (x, 1) -> x.
if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
+ if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
+ return val;
+
// Don't fold if it would require a division by zero.
bool div0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
@@ -621,6 +639,9 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
+ if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
+ return val;
+
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 6a186a0c6ceca0..1b80086c13eced 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2060,6 +2060,26 @@ func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) {
// -----
+// CHECK-LABEL: @test_divui_mul
+// CHECK-SAME: (%[[ARG:.*]]: index, %{{.*}}: index)
+// CHECK: return %[[ARG]]
+func.func @test_divui_mul(%arg0: index, %arg1: index) -> index {
+ %0 = arith.muli %arg0, %arg1 overflow<nuw> : index
+ %1 = arith.divui %0, %arg1 : index
+ return %1 : index
+}
+
+// CHECK-LABEL: @test_divsi_mul
+// CHECK-SAME: (%[[ARG:.*]]: index, %{{.*}}: index)
+// CHECK: return %[[ARG]]
+func.func @test_divsi_mul(%arg0: index, %arg1: index) -> index {
+ %0 = arith.muli %arg1, %arg0 overflow<nsw> : index
+ %1 = arith.divsi %0, %arg1 : index
+ return %1 : index
+}
+
+// -----
+
// CHECK-LABEL: @test_cmpf(
func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
// CHECK-DAG: %[[T:.*]] = arith.constant true
>From c67b426b04a4345e794e992646325bcd2cff1501 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 3 Jan 2025 02:48:18 +0100
Subject: [PATCH 2/3] comment
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index e486bb678ce33e..e016a6e16e59ff 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -580,6 +580,7 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns(
// DivUIOp
//===----------------------------------------------------------------------===//
+/// Fold `(a * b) / b -> a` (either operand order) when the muli carries all of `ovfFlags`.
static Value foldDivMul(Value lhs, Value rhs,
arith::IntegerOverflowFlags ovfFlags) {
auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
@@ -600,6 +601,7 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
+ // (a * b) / b -> a
if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
return val;
@@ -639,6 +641,7 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
+ // (a * b) / b -> a
if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
return val;
>From f4b5447ba3c261580f222061760c8a864d8e224b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 3 Jan 2025 03:09:05 +0100
Subject: [PATCH 3/3] more tests
---
mlir/test/Dialect/Arith/canonicalize.mlir | 64 +++++++++++++++++++----
1 file changed, 54 insertions(+), 10 deletions(-)
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1b80086c13eced..522711b08f289d 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2060,23 +2060,67 @@ func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) {
// -----
-// CHECK-LABEL: @test_divui_mul
-// CHECK-SAME: (%[[ARG:.*]]: index, %{{.*}}: index)
-// CHECK: return %[[ARG]]
-func.func @test_divui_mul(%arg0: index, %arg1: index) -> index {
- %0 = arith.muli %arg0, %arg1 overflow<nuw> : index
+func.func @fold_divui_of_muli_0(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 overflow<nuw> : index
+ %1 = arith.divui %0, %arg0 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @fold_divui_of_muli_0(
+// CHECK-SAME: %[[ARG0:.+]]: index,
+// CHECK-SAME: %[[ARG1:.+]]: index)
+// CHECK: return %[[ARG1]]
+
+func.func @fold_divui_of_muli_1(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 overflow<nuw> : index
%1 = arith.divui %0, %arg1 : index
return %1 : index
}
+// CHECK-LABEL: func @fold_divui_of_muli_1(
+// CHECK-SAME: %[[ARG0:.+]]: index,
+// CHECK-SAME: %[[ARG1:.+]]: index)
+// CHECK: return %[[ARG0]]
+
+func.func @fold_divsi_of_muli_0(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 overflow<nsw> : index
+ %1 = arith.divsi %0, %arg0 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @fold_divsi_of_muli_0(
+// CHECK-SAME: %[[ARG0:.+]]: index,
+// CHECK-SAME: %[[ARG1:.+]]: index)
+// CHECK: return %[[ARG1]]
-// CHECK-LABEL: @test_divsi_mul
-// CHECK-SAME: (%[[ARG:.*]]: index, %{{.*}}: index)
-// CHECK: return %[[ARG]]
-func.func @test_divsi_mul(%arg0: index, %arg1: index) -> index {
- %0 = arith.muli %arg1, %arg0 overflow<nsw> : index
+func.func @fold_divsi_of_muli_1(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 overflow<nsw> : index
%1 = arith.divsi %0, %arg1 : index
return %1 : index
}
+// CHECK-LABEL: func @fold_divsi_of_muli_1(
+// CHECK-SAME: %[[ARG0:.+]]: index,
+// CHECK-SAME: %[[ARG1:.+]]: index)
+// CHECK: return %[[ARG0]]
+
+// Do not fold divui(mul(a, v), v) -> a without nuw attribute.
+func.func @no_fold_divui_of_muli(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 : index
+ %1 = arith.divui %0, %arg0 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @no_fold_divui_of_muli
+// CHECK: %[[T0:.+]] = arith.muli
+// CHECK: %[[T1:.+]] = arith.divui %[[T0]],
+// CHECK: return %[[T1]]
+
+// Do not fold divsi(mul(a, v), v) -> a without nsw attribute.
+func.func @no_fold_divsi_of_muli(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 : index
+ %1 = arith.divsi %0, %arg0 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @no_fold_divsi_of_muli
+// CHECK: %[[T0:.+]] = arith.muli
+// CHECK: %[[T1:.+]] = arith.divsi %[[T0]],
+// CHECK: return %[[T1]]
// -----
More information about the Mlir-commits
mailing list