[Mlir-commits] [mlir] [mlir][affine] fix the issue of ceildiv mul ceildiv expression not satisfying commutative (PR #109382)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Sep 28 05:22:55 PDT 2024


https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/109382

>From 96dbdc30e4e1c8f816058a8be686cc9cf3aebf50 Mon Sep 17 00:00:00 2001
From: chenlonglong <chenlonglong at willingcore.com>
Date: Fri, 20 Sep 2024 14:40:31 +0800
Subject: [PATCH 1/4] [mlir][affine] fix the issue of ceildiv-mul-ceildiv form
 expression not satisfying commutative

Fixes https://github.com/llvm/llvm-project/issues/107508
---
 mlir/lib/IR/AffineExpr.cpp                    | 164 ++++++++++++------
 .../Dialect/Affine/simplify-structures.mlir   |  22 ++-
 2 files changed, 134 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 0b078966aeb85b..cc8c4c21b96cee 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -9,6 +9,8 @@
 #include <cmath>
 #include <cstdint>
 #include <limits>
+#include <numeric>
+#include <optional>
 #include <utility>
 
 #include "AffineExprDetail.h"
@@ -18,9 +20,8 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/MathExtras.h"
-#include <numeric>
-#include <optional>
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -362,54 +363,119 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
           opKind == AffineExprKind::CeilDiv) &&
          "unexpected opKind");
-  switch (expr.getKind()) {
-  case AffineExprKind::Constant:
-    return cast<AffineConstantExpr>(expr).getValue() == 0;
-  case AffineExprKind::DimId:
-    return false;
-  case AffineExprKind::SymbolId:
-    return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
-  // Checks divisibility by the given symbol for both operands.
-  case AffineExprKind::Add: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
-  }
-  // Checks divisibility by the given symbol for both operands. Consider the
-  // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
-  // this is a division by s1 and both the operands of modulo are divisible by
-  // s1 but it is not divisible by s1 always. The third argument is
-  // `AffineExprKind::Mod` for this reason.
-  case AffineExprKind::Mod: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
-                               AffineExprKind::Mod) &&
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
-                               AffineExprKind::Mod);
-  }
-  // Checks if any of the operand divisible by the given symbol.
-  case AffineExprKind::Mul: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
-  }
-  // Floordiv and ceildiv are divisible by the given symbol when the first
-  // operand is divisible, and the affine expression kind of the argument expr
-  // is same as the argument `opKind`. This can be inferred from commutative
-  // property of floordiv and ceildiv operations and are as follow:
-  // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
-  // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
-  // It will fail if operations are not same. For example:
-  // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
-  case AffineExprKind::FloorDiv:
-  case AffineExprKind::CeilDiv: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    if (opKind != expr.getKind())
-      return false;
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
-  }
+  std::vector<std::tuple<AffineExpr, unsigned, AffineExprKind,
+                         llvm::detail::scope_exit<std::function<void(void)>>>>
+      stack;
+  stack.emplace_back(expr, symbolPos, opKind, []() {});
+  bool result = false;
+
+  while (!stack.empty()) {
+    AffineExpr expr = std::get<0>(stack.back());
+    unsigned symbolPos = std::get<1>(stack.back());
+    AffineExprKind opKind = std::get<2>(stack.back());
+
+    switch (expr.getKind()) {
+    case AffineExprKind::Constant: {
+      // Note: Assignment must occur before pop, which will affect whether it
+      // enters other execution branches.
+      result = cast<AffineConstantExpr>(expr).getValue() == 0;
+      stack.pop_back();
+      break;
+    }
+    case AffineExprKind::DimId: {
+      result = false;
+      stack.pop_back();
+      break;
+    }
+    case AffineExprKind::SymbolId: {
+      result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos;
+      stack.pop_back();
+      break;
+    }
+    // Checks divisibility by the given symbol for both operands.
+    case AffineExprKind::Add: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
+                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                           if (result) {
+                             stack.emplace_back(
+                                 binaryExpr.getRHS(), symbolPos, opKind,
+                                 [&stack]() { stack.pop_back(); });
+                           } else {
+                             stack.pop_back();
+                           }
+                         });
+      break;
+    }
+    // Checks divisibility by the given symbol for both operands. Consider the
+    // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv
+    // s1`, this is a division by s1 and both the operands of modulo are
+    // divisible by s1 but it is not divisible by s1 always. The third argument
+    // is `AffineExprKind::Mod` for this reason.
+    case AffineExprKind::Mod: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
+                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                           if (result) {
+                             stack.emplace_back(
+                                 binaryExpr.getRHS(), symbolPos,
+                                 AffineExprKind::Mod,
+                                 [&stack]() { stack.pop_back(); });
+                           } else {
+                             stack.pop_back();
+                           }
+                         });
+      break;
+    }
+    // Checks if any of the operand divisible by the given symbol.
+    case AffineExprKind::Mul: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
+                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                           if (!result) {
+                             stack.emplace_back(
+                                 binaryExpr.getRHS(), symbolPos, opKind,
+                                 [&stack]() { stack.pop_back(); });
+                           } else {
+                             stack.pop_back();
+                           }
+                         });
+      break;
+    }
+    // Floordiv and ceildiv are divisible by the given symbol when the first
+    // operand is divisible, and the affine expression kind of the argument expr
+    // is same as the argument `opKind`. This can be inferred from commutative
+    // property of floordiv and ceildiv operations and are as follow:
+    // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
+    // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
+    // It will fail (1) if the operations are not the same. For example:
+    // (exps1 ceildiv exp2) floordiv exp3 can not be simplified; (2) if there is
+    // a multiplication operation in the expression. For example:
+    //  (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
+    case AffineExprKind::FloorDiv:
+    case AffineExprKind::CeilDiv: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      if (opKind != expr.getKind()) {
+        result = false;
+        stack.pop_back();
+        break;
+      }
+      if (llvm::any_of(stack, [](auto &it) {
+            return std::get<0>(it).getKind() == AffineExprKind::Mul;
+          })) {
+        result = false;
+        stack.pop_back();
+        break;
+      }
+
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, expr.getKind(),
+                         [&stack]() { stack.pop_back(); });
+      break;
+    }
+      llvm_unreachable("Unknown AffineExpr");
+    }
   }
-  llvm_unreachable("Unknown AffineExpr");
+  return result;
 }
 
 /// Divides the given expression by the given symbol at position `symbolPos`. It
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 92d3d86bc93068..d1f34f20fa5dad 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
 }
 
 // Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
-// CHECK-LABEL: func @semiaffine_composite_floor
-func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+  // CHECK:       %[[CST:.*]] = arith.constant 43
+  return %a : index
+}
+
+// Tests that a semi-affine expression with a nested ceildiv-mul-ceildiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
   %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
-  // CHECK:       %[[CST:.*]] = arith.constant 47
+  // CHECK-NOT:       arith.constant
+  return %a : index
+}
+
+// Tests that a semi-affine expression with a nested floordiv-mul-floordiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_floordiv
+func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
+  // CHECK-NOT:       arith.constant
   return %a : index
 }
 

>From 3b757f5824515338506b068c9ebad6b4a33b00db Mon Sep 17 00:00:00 2001
From: chenlonglong <chenlonglong at willingcore.com>
Date: Fri, 20 Sep 2024 18:51:51 +0800
Subject: [PATCH 2/4] refine

---
 mlir/lib/IR/AffineExpr.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc8c4c21b96cee..05ce0937fa9dba 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -415,7 +415,7 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
     case AffineExprKind::Mod: {
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
       stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
-                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                         [&stack, &result, binaryExpr, symbolPos]() {
                            if (result) {
                              stack.emplace_back(
                                  binaryExpr.getRHS(), symbolPos,

>From 89b5924957721ea24351885ef27a8c305fd56545 Mon Sep 17 00:00:00 2001
From: chenlonglong <chenlonglong at willingcore.com>
Date: Tue, 24 Sep 2024 16:52:02 +0800
Subject: [PATCH 3/4] refine

---
 mlir/lib/IR/AffineExpr.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 05ce0937fa9dba..b3d8f1cd0c7313 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -363,7 +363,7 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
           opKind == AffineExprKind::CeilDiv) &&
          "unexpected opKind");
-  std::vector<std::tuple<AffineExpr, unsigned, AffineExprKind,
+  SmallVector<std::tuple<AffineExpr, unsigned, AffineExprKind,
                          llvm::detail::scope_exit<std::function<void(void)>>>>
       stack;
   stack.emplace_back(expr, symbolPos, opKind, []() {});

>From 437549c295f67b598e1e318c748f503cafdeb86d Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sat, 28 Sep 2024 08:07:12 -0400
Subject: [PATCH 4/4] fix ci fail on windows platform

---
 mlir/lib/IR/AffineExpr.cpp | 98 +++++++++++++++++++++++++-------------
 1 file changed, 65 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index b3d8f1cd0c7313..ef037a2e133657 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -379,32 +379,44 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
       // Note: Assignment must occur before pop, which will affect whether it
       // enters other execution branches.
       result = cast<AffineConstantExpr>(expr).getValue() == 0;
+      llvm::detail::scope_exit<std::function<void(void)>> sexit(
+          std::move(std::get<3>(stack.back())));
       stack.pop_back();
       break;
     }
     case AffineExprKind::DimId: {
       result = false;
+      llvm::detail::scope_exit<std::function<void(void)>> sexit(
+          std::move(std::get<3>(stack.back())));
       stack.pop_back();
       break;
     }
     case AffineExprKind::SymbolId: {
       result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos;
+      llvm::detail::scope_exit<std::function<void(void)>> sexit(
+          std::move(std::get<3>(stack.back())));
       stack.pop_back();
       break;
     }
     // Checks divisibility by the given symbol for both operands.
     case AffineExprKind::Add: {
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-      stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
-                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
-                           if (result) {
-                             stack.emplace_back(
-                                 binaryExpr.getRHS(), symbolPos, opKind,
-                                 [&stack]() { stack.pop_back(); });
-                           } else {
-                             stack.pop_back();
-                           }
-                         });
+      stack.emplace_back(
+          binaryExpr.getLHS(), symbolPos, opKind,
+          [&stack, &result, binaryExpr, symbolPos, opKind]() {
+            if (result) {
+              stack.emplace_back(
+                  binaryExpr.getRHS(), symbolPos, opKind, [&stack]() {
+                    llvm::detail::scope_exit<std::function<void(void)>> sexit(
+                        std::move(std::get<3>(stack.back())));
+                    stack.pop_back();
+                  });
+            } else {
+              llvm::detail::scope_exit<std::function<void(void)>> sexit(
+                  std::move(std::get<3>(stack.back())));
+              stack.pop_back();
+            }
+          });
       break;
     }
     // Checks divisibility by the given symbol for both operands. Consider the
@@ -414,32 +426,44 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
     // is `AffineExprKind::Mod` for this reason.
     case AffineExprKind::Mod: {
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-      stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
-                         [&stack, &result, binaryExpr, symbolPos]() {
-                           if (result) {
-                             stack.emplace_back(
-                                 binaryExpr.getRHS(), symbolPos,
-                                 AffineExprKind::Mod,
-                                 [&stack]() { stack.pop_back(); });
-                           } else {
-                             stack.pop_back();
-                           }
-                         });
+      stack.emplace_back(
+          binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
+          [&stack, &result, binaryExpr, symbolPos]() {
+            if (result) {
+              stack.emplace_back(
+                  binaryExpr.getRHS(), symbolPos, AffineExprKind::Mod,
+                  [&stack]() {
+                    llvm::detail::scope_exit<std::function<void(void)>> sexit(
+                        std::move(std::get<3>(stack.back())));
+                    stack.pop_back();
+                  });
+            } else {
+              llvm::detail::scope_exit<std::function<void(void)>> sexit(
+                  std::move(std::get<3>(stack.back())));
+              stack.pop_back();
+            }
+          });
       break;
     }
     // Checks if any of the operand divisible by the given symbol.
     case AffineExprKind::Mul: {
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-      stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
-                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
-                           if (!result) {
-                             stack.emplace_back(
-                                 binaryExpr.getRHS(), symbolPos, opKind,
-                                 [&stack]() { stack.pop_back(); });
-                           } else {
-                             stack.pop_back();
-                           }
-                         });
+      stack.emplace_back(
+          binaryExpr.getLHS(), symbolPos, opKind,
+          [&stack, &result, binaryExpr, symbolPos, opKind]() {
+            if (!result) {
+              stack.emplace_back(
+                  binaryExpr.getRHS(), symbolPos, opKind, [&stack]() {
+                    llvm::detail::scope_exit<std::function<void(void)>> sexit(
+                        std::move(std::get<3>(stack.back())));
+                    stack.pop_back();
+                  });
+            } else {
+              llvm::detail::scope_exit<std::function<void(void)>> sexit(
+                  std::move(std::get<3>(stack.back())));
+              stack.pop_back();
+            }
+          });
       break;
     }
     // Floordiv and ceildiv are divisible by the given symbol when the first
@@ -457,6 +481,8 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
       if (opKind != expr.getKind()) {
         result = false;
+        llvm::detail::scope_exit<std::function<void(void)>> sexit(
+            std::move(std::get<3>(stack.back())));
         stack.pop_back();
         break;
       }
@@ -464,12 +490,18 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
             return std::get<0>(it).getKind() == AffineExprKind::Mul;
           })) {
         result = false;
+        llvm::detail::scope_exit<std::function<void(void)>> sexit(
+            std::move(std::get<3>(stack.back())));
         stack.pop_back();
         break;
       }
 
-      stack.emplace_back(binaryExpr.getLHS(), symbolPos, expr.getKind(),
-                         [&stack]() { stack.pop_back(); });
+      stack.emplace_back(
+          binaryExpr.getLHS(), symbolPos, expr.getKind(), [&stack]() {
+            llvm::detail::scope_exit<std::function<void(void)>> sexit(
+                std::move(std::get<3>(stack.back())));
+            stack.pop_back();
+          });
       break;
     }
       llvm_unreachable("Unknown AffineExpr");



More information about the Mlir-commits mailing list