[Mlir-commits] [mlir] 1cade86 - [mlir][arith] Fold `(a * b) / b -> a` (#121534)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 3 09:03:03 PST 2025


Author: Ivan Butygin
Date: 2025-01-03T20:02:59+03:00
New Revision: 1cade8699719c934a8debb7bef9fdc3ff11e9602

URL: https://github.com/llvm/llvm-project/commit/1cade8699719c934a8debb7bef9fdc3ff11e9602
DIFF: https://github.com/llvm/llvm-project/commit/1cade8699719c934a8debb7bef9fdc3ff11e9602.diff

LOG: [mlir][arith] Fold `(a * b) / b -> a` (#121534)

Only when the multiplication's overflow flags allow it: `nuw` for `divui`, `nsw` for `divsi`.

Alive2 check: https://alive2.llvm.org/ce/z/5XWjWE

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d8b314a3fa43c0..e016a6e16e59ff 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -580,11 +580,31 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns(
 // DivUIOp
 //===----------------------------------------------------------------------===//
 
+/// Fold `(a * b) / b -> a` (and `(b * a) / b -> a`). Returns the surviving
+/// factor when `lhs` is an `arith.muli` carrying all of `ovfFlags` and `rhs`
+/// is one of its operands; returns a null Value otherwise.
+static Value foldDivMul(Value lhs, Value rhs,
+                        arith::IntegerOverflowFlags ovfFlags) {
+  auto mulOp = lhs.getDefiningOp<mlir::arith::MulIOp>();
+  // Only sound when the multiply carries the caller's required overflow flags.
+  if (!mulOp || !bitEnumContainsAll(mulOp.getOverflowFlags(), ovfFlags))
+    return Value();
+  Value mulLhs = mulOp.getLhs();
+  Value mulRhs = mulOp.getRhs();
+  if (rhs == mulLhs)
+    return mulRhs;
+  return rhs == mulRhs ? mulLhs : Value();
+}
+
 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
   // divui (x, 1) -> x.
   if (matchPattern(adaptor.getRhs(), m_One()))
     return getLhs();
 
+  // (a * b) / b -> a
+  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
+    return val;
+
   // Don't fold if it would require a division by zero.
   bool div0 = false;
   auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
@@ -621,6 +641,10 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(adaptor.getRhs(), m_One()))
     return getLhs();
 
+  // (a * b) / b -> a
+  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
+    return val;
+
   // Don't fold if it would overflow or if it requires a division by zero.
   bool overflowOrDiv0 = false;
   auto result = constFoldBinaryOp<IntegerAttr>(

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 6a186a0c6ceca0..522711b08f289d 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2060,6 +2060,70 @@ func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) {
 
 // -----
 
+func.func @fold_divui_of_muli_0(%a : index, %b : index) -> index {
+  %prod = arith.muli %a, %b overflow<nuw> : index
+  %quot = arith.divui %prod, %a : index
+  return %quot : index
+}
+// CHECK-LABEL: func @fold_divui_of_muli_0(
+//  CHECK-SAME:     %[[ARG0:.+]]: index,
+//  CHECK-SAME:     %[[ARG1:.+]]: index)
+//       CHECK:   return %[[ARG1]]
+
+func.func @fold_divui_of_muli_1(%a : index, %b : index) -> index {
+  %prod = arith.muli %a, %b overflow<nuw> : index
+  %quot = arith.divui %prod, %b : index
+  return %quot : index
+}
+// CHECK-LABEL: func @fold_divui_of_muli_1(
+//  CHECK-SAME:     %[[ARG0:.+]]: index,
+//  CHECK-SAME:     %[[ARG1:.+]]: index)
+//       CHECK:   return %[[ARG0]]
+
+func.func @fold_divsi_of_muli_0(%a : index, %b : index) -> index {
+  %prod = arith.muli %a, %b overflow<nsw> : index
+  %quot = arith.divsi %prod, %a : index
+  return %quot : index
+}
+// CHECK-LABEL: func @fold_divsi_of_muli_0(
+//  CHECK-SAME:     %[[ARG0:.+]]: index,
+//  CHECK-SAME:     %[[ARG1:.+]]: index)
+//       CHECK:   return %[[ARG1]]
+
+func.func @fold_divsi_of_muli_1(%a : index, %b : index) -> index {
+  %prod = arith.muli %a, %b overflow<nsw> : index
+  %quot = arith.divsi %prod, %b : index
+  return %quot : index
+}
+// CHECK-LABEL: func @fold_divsi_of_muli_1(
+//  CHECK-SAME:     %[[ARG0:.+]]: index,
+//  CHECK-SAME:     %[[ARG1:.+]]: index)
+//       CHECK:   return %[[ARG0]]
+
+// Do not fold divui(mul(a, v), v) -> a without the nuw attribute.
+func.func @no_fold_divui_of_muli(%a : index, %b : index) -> index {
+  %prod = arith.muli %a, %b : index
+  %quot = arith.divui %prod, %a : index
+  return %quot : index
+}
+// CHECK-LABEL: func @no_fold_divui_of_muli
+//       CHECK:   %[[T0:.+]] = arith.muli
+//       CHECK:   %[[T1:.+]] = arith.divui %[[T0]],
+//       CHECK:   return %[[T1]]
+
+// Do not fold divsi(mul(a, v), v) -> a without the nsw attribute.
+func.func @no_fold_divsi_of_muli(%a : index, %b : index) -> index {
+  %prod = arith.muli %a, %b : index
+  %quot = arith.divsi %prod, %a : index
+  return %quot : index
+}
+// CHECK-LABEL: func @no_fold_divsi_of_muli
+//       CHECK:   %[[T0:.+]] = arith.muli
+//       CHECK:   %[[T1:.+]] = arith.divsi %[[T0]],
+//       CHECK:   return %[[T1]]
+
+// -----
+
 // CHECK-LABEL: @test_cmpf(
 func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
 //   CHECK-DAG:   %[[T:.*]] = arith.constant true


        


More information about the Mlir-commits mailing list