[Mlir-commits] [mlir] [mlir][affine] fix the issue of ceildiv mul ceildiv expression not satisfying commutativity (PR #109382)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 1 23:01:42 PDT 2024


https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/109382

>From 185863695e191a8dce6b37f38bd2dba086da08d3 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Fri, 20 Sep 2024 14:40:31 +0800
Subject: [PATCH 1/5] [mlir][affine] fix the issue of ceildiv-mul-ceildiv form
 expression not satisfying commutativity

Fixes https://github.com/llvm/llvm-project/issues/107508
---
 mlir/lib/IR/AffineExpr.cpp                    | 164 ++++++++++++------
 .../Dialect/Affine/simplify-structures.mlir   |  22 ++-
 2 files changed, 134 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 0b078966aeb85b..cc8c4c21b96cee 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -9,6 +9,8 @@
 #include <cmath>
 #include <cstdint>
 #include <limits>
+#include <numeric>
+#include <optional>
 #include <utility>
 
 #include "AffineExprDetail.h"
@@ -18,9 +20,8 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/MathExtras.h"
-#include <numeric>
-#include <optional>
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -362,54 +363,119 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
           opKind == AffineExprKind::CeilDiv) &&
          "unexpected opKind");
-  switch (expr.getKind()) {
-  case AffineExprKind::Constant:
-    return cast<AffineConstantExpr>(expr).getValue() == 0;
-  case AffineExprKind::DimId:
-    return false;
-  case AffineExprKind::SymbolId:
-    return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
-  // Checks divisibility by the given symbol for both operands.
-  case AffineExprKind::Add: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
-  }
-  // Checks divisibility by the given symbol for both operands. Consider the
-  // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
-  // this is a division by s1 and both the operands of modulo are divisible by
-  // s1 but it is not divisible by s1 always. The third argument is
-  // `AffineExprKind::Mod` for this reason.
-  case AffineExprKind::Mod: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
-                               AffineExprKind::Mod) &&
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
-                               AffineExprKind::Mod);
-  }
-  // Checks if any of the operand divisible by the given symbol.
-  case AffineExprKind::Mul: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
-  }
-  // Floordiv and ceildiv are divisible by the given symbol when the first
-  // operand is divisible, and the affine expression kind of the argument expr
-  // is same as the argument `opKind`. This can be inferred from commutative
-  // property of floordiv and ceildiv operations and are as follow:
-  // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
-  // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
-  // It will fail if operations are not same. For example:
-  // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
-  case AffineExprKind::FloorDiv:
-  case AffineExprKind::CeilDiv: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    if (opKind != expr.getKind())
-      return false;
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
-  }
+  std::vector<std::tuple<AffineExpr, unsigned, AffineExprKind,
+                         llvm::detail::scope_exit<std::function<void(void)>>>>
+      stack;
+  stack.emplace_back(expr, symbolPos, opKind, []() {});
+  bool result = false;
+
+  while (!stack.empty()) {
+    AffineExpr expr = std::get<0>(stack.back());
+    unsigned symbolPos = std::get<1>(stack.back());
+    AffineExprKind opKind = std::get<2>(stack.back());
+
+    switch (expr.getKind()) {
+    case AffineExprKind::Constant: {
+      // Note: Assignment must occur before pop, which will affect whether it
+      // enters other execution branches.
+      result = cast<AffineConstantExpr>(expr).getValue() == 0;
+      stack.pop_back();
+      break;
+    }
+    case AffineExprKind::DimId: {
+      result = false;
+      stack.pop_back();
+      break;
+    }
+    case AffineExprKind::SymbolId: {
+      result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos;
+      stack.pop_back();
+      break;
+    }
+    // Checks divisibility by the given symbol for both operands.
+    case AffineExprKind::Add: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
+                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                           if (result) {
+                             stack.emplace_back(
+                                 binaryExpr.getRHS(), symbolPos, opKind,
+                                 [&stack]() { stack.pop_back(); });
+                           } else {
+                             stack.pop_back();
+                           }
+                         });
+      break;
+    }
+    // Checks divisibility by the given symbol for both operands. Consider the
+    // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv
+    // s1`, this is a division by s1 and both the operands of modulo are
+    // divisible by s1 but it is not divisible by s1 always. The third argument
+    // is `AffineExprKind::Mod` for this reason.
+    case AffineExprKind::Mod: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
+                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                           if (result) {
+                             stack.emplace_back(
+                                 binaryExpr.getRHS(), symbolPos,
+                                 AffineExprKind::Mod,
+                                 [&stack]() { stack.pop_back(); });
+                           } else {
+                             stack.pop_back();
+                           }
+                         });
+      break;
+    }
+    // Checks if any of the operand divisible by the given symbol.
+    case AffineExprKind::Mul: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
+                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                           if (!result) {
+                             stack.emplace_back(
+                                 binaryExpr.getRHS(), symbolPos, opKind,
+                                 [&stack]() { stack.pop_back(); });
+                           } else {
+                             stack.pop_back();
+                           }
+                         });
+      break;
+    }
+    // Floordiv and ceildiv are divisible by the given symbol when the first
+    // operand is divisible, and the affine expression kind of the argument expr
+    // is same as the argument `opKind`. This can be inferred from commutative
+    // property of floordiv and ceildiv operations, as follows:
+    // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
+    // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv exp2
+    // It fails (1) if the operations differ, e.g.
+    // (exp1 ceildiv exp2) floordiv exp3 cannot be simplified, and (2) if the
+    // division is nested under a multiplication, e.g.
+    // ((exp1 ceildiv exp2) * exp3) ceildiv exp4 cannot be simplified.
+    case AffineExprKind::FloorDiv:
+    case AffineExprKind::CeilDiv: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      if (opKind != expr.getKind()) {
+        result = false;
+        stack.pop_back();
+        break;
+      }
+      if (llvm::any_of(stack, [](auto &it) {
+            return std::get<0>(it).getKind() == AffineExprKind::Mul;
+          })) {
+        result = false;
+        stack.pop_back();
+        break;
+      }
+
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, expr.getKind(),
+                         [&stack]() { stack.pop_back(); });
+      break;
+    }
+      llvm_unreachable("Unknown AffineExpr");
+    }
   }
-  llvm_unreachable("Unknown AffineExpr");
+  return result;
 }
 
 /// Divides the given expression by the given symbol at position `symbolPos`. It
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 92d3d86bc93068..d1f34f20fa5dad 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
 }
 
 // Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
-// CHECK-LABEL: func @semiaffine_composite_floor
-func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+  // CHECK:       %[[CST:.*]] = arith.constant 43
+  return %a : index
+}
+
+// Tests that a semi-affine expression with a nested ceildiv-mul-ceildiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
   %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
-  // CHECK:       %[[CST:.*]] = arith.constant 47
+  // CHECK-NOT:       arith.constant
+  return %a : index
+}
+
+// Tests that a semi-affine expression with a nested floordiv-mul-floordiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_floordiv
+func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
+  // CHECK-NOT:       arith.constant
   return %a : index
 }
 

>From c75822467ae32398e586cbe8badb1dfc1d89e386 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Fri, 20 Sep 2024 18:51:51 +0800
Subject: [PATCH 2/5] refine

---
 mlir/lib/IR/AffineExpr.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc8c4c21b96cee..05ce0937fa9dba 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -415,7 +415,7 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
     case AffineExprKind::Mod: {
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
       stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
-                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                         [&stack, &result, binaryExpr, symbolPos]() {
                            if (result) {
                              stack.emplace_back(
                                  binaryExpr.getRHS(), symbolPos,

>From dee51cdf83bcf865c76ea792a874314b7b30f30e Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Tue, 24 Sep 2024 16:52:02 +0800
Subject: [PATCH 3/5] refine

---
 mlir/lib/IR/AffineExpr.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 05ce0937fa9dba..b3d8f1cd0c7313 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -363,7 +363,7 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
           opKind == AffineExprKind::CeilDiv) &&
          "unexpected opKind");
-  std::vector<std::tuple<AffineExpr, unsigned, AffineExprKind,
+  SmallVector<std::tuple<AffineExpr, unsigned, AffineExprKind,
                          llvm::detail::scope_exit<std::function<void(void)>>>>
       stack;
   stack.emplace_back(expr, symbolPos, opKind, []() {});

>From c70e79414ffef1d1f81a0d2d30e1d265aa4bdf5c Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sat, 28 Sep 2024 08:07:12 -0400
Subject: [PATCH 4/5] fix ci fail on windows platform

---
 mlir/lib/IR/AffineExpr.cpp | 98 +++++++++++++++++++++++++-------------
 1 file changed, 65 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index b3d8f1cd0c7313..ef037a2e133657 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -379,32 +379,44 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
       // Note: Assignment must occur before pop, which will affect whether it
       // enters other execution branches.
       result = cast<AffineConstantExpr>(expr).getValue() == 0;
+      llvm::detail::scope_exit<std::function<void(void)>> sexit(
+          std::move(std::get<3>(stack.back())));
       stack.pop_back();
       break;
     }
     case AffineExprKind::DimId: {
       result = false;
+      llvm::detail::scope_exit<std::function<void(void)>> sexit(
+          std::move(std::get<3>(stack.back())));
       stack.pop_back();
       break;
     }
     case AffineExprKind::SymbolId: {
       result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos;
+      llvm::detail::scope_exit<std::function<void(void)>> sexit(
+          std::move(std::get<3>(stack.back())));
       stack.pop_back();
       break;
     }
     // Checks divisibility by the given symbol for both operands.
     case AffineExprKind::Add: {
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-      stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
-                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
-                           if (result) {
-                             stack.emplace_back(
-                                 binaryExpr.getRHS(), symbolPos, opKind,
-                                 [&stack]() { stack.pop_back(); });
-                           } else {
-                             stack.pop_back();
-                           }
-                         });
+      stack.emplace_back(
+          binaryExpr.getLHS(), symbolPos, opKind,
+          [&stack, &result, binaryExpr, symbolPos, opKind]() {
+            if (result) {
+              stack.emplace_back(
+                  binaryExpr.getRHS(), symbolPos, opKind, [&stack]() {
+                    llvm::detail::scope_exit<std::function<void(void)>> sexit(
+                        std::move(std::get<3>(stack.back())));
+                    stack.pop_back();
+                  });
+            } else {
+              llvm::detail::scope_exit<std::function<void(void)>> sexit(
+                  std::move(std::get<3>(stack.back())));
+              stack.pop_back();
+            }
+          });
       break;
     }
     // Checks divisibility by the given symbol for both operands. Consider the
@@ -414,32 +426,44 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
     // is `AffineExprKind::Mod` for this reason.
     case AffineExprKind::Mod: {
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-      stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
-                         [&stack, &result, binaryExpr, symbolPos]() {
-                           if (result) {
-                             stack.emplace_back(
-                                 binaryExpr.getRHS(), symbolPos,
-                                 AffineExprKind::Mod,
-                                 [&stack]() { stack.pop_back(); });
-                           } else {
-                             stack.pop_back();
-                           }
-                         });
+      stack.emplace_back(
+          binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
+          [&stack, &result, binaryExpr, symbolPos]() {
+            if (result) {
+              stack.emplace_back(
+                  binaryExpr.getRHS(), symbolPos, AffineExprKind::Mod,
+                  [&stack]() {
+                    llvm::detail::scope_exit<std::function<void(void)>> sexit(
+                        std::move(std::get<3>(stack.back())));
+                    stack.pop_back();
+                  });
+            } else {
+              llvm::detail::scope_exit<std::function<void(void)>> sexit(
+                  std::move(std::get<3>(stack.back())));
+              stack.pop_back();
+            }
+          });
       break;
     }
     // Checks if any of the operand divisible by the given symbol.
     case AffineExprKind::Mul: {
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-      stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
-                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
-                           if (!result) {
-                             stack.emplace_back(
-                                 binaryExpr.getRHS(), symbolPos, opKind,
-                                 [&stack]() { stack.pop_back(); });
-                           } else {
-                             stack.pop_back();
-                           }
-                         });
+      stack.emplace_back(
+          binaryExpr.getLHS(), symbolPos, opKind,
+          [&stack, &result, binaryExpr, symbolPos, opKind]() {
+            if (!result) {
+              stack.emplace_back(
+                  binaryExpr.getRHS(), symbolPos, opKind, [&stack]() {
+                    llvm::detail::scope_exit<std::function<void(void)>> sexit(
+                        std::move(std::get<3>(stack.back())));
+                    stack.pop_back();
+                  });
+            } else {
+              llvm::detail::scope_exit<std::function<void(void)>> sexit(
+                  std::move(std::get<3>(stack.back())));
+              stack.pop_back();
+            }
+          });
       break;
     }
     // Floordiv and ceildiv are divisible by the given symbol when the first
@@ -457,6 +481,8 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
       if (opKind != expr.getKind()) {
         result = false;
+        llvm::detail::scope_exit<std::function<void(void)>> sexit(
+            std::move(std::get<3>(stack.back())));
         stack.pop_back();
         break;
       }
@@ -464,12 +490,18 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
             return std::get<0>(it).getKind() == AffineExprKind::Mul;
           })) {
         result = false;
+        llvm::detail::scope_exit<std::function<void(void)>> sexit(
+            std::move(std::get<3>(stack.back())));
         stack.pop_back();
         break;
       }
 
-      stack.emplace_back(binaryExpr.getLHS(), symbolPos, expr.getKind(),
-                         [&stack]() { stack.pop_back(); });
+      stack.emplace_back(
+          binaryExpr.getLHS(), symbolPos, expr.getKind(), [&stack]() {
+            llvm::detail::scope_exit<std::function<void(void)>> sexit(
+                std::move(std::get<3>(stack.back())));
+            stack.pop_back();
+          });
       break;
     }
       llvm_unreachable("Unknown AffineExpr");

>From 514d78c113ebf3340d0bafbad49baaaca655db97 Mon Sep 17 00:00:00 2001
From: "long.chen" <lipracer at gmail.com>
Date: Wed, 2 Oct 2024 04:51:01 +0000
Subject: [PATCH 5/5] refine

---
 mlir/lib/IR/AffineExpr.cpp | 196 +++++++++++++++++++++----------------
 1 file changed, 111 insertions(+), 85 deletions(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index ef037a2e133657..82b0bc193bfb71 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -350,6 +350,82 @@ unsigned AffineDimExpr::getPosition() const {
   return static_cast<ImplType *>(expr)->position;
 }
 
+/// A manually managed call stack that turns a recursive tree traversal into a
+/// loop. This has two benefits: the traversal can inspect the current stack
+/// (the chain of enclosing nodes), and descending into deeply nested trees
+/// does not consume native stack frames. Typically, recursive
+/// calls take the form of the following:
+/// push node
+/// visit tree node
+/// push node->left_node
+/// ...
+/// pop left_node
+/// check result and do something
+/// push node->right_node
+/// pop right_node
+/// pop node
+/// ...
+/// This form can be converted into the following form:
+/// push node
+/// visit tree node
+/// push node->left_node
+/// ...
+/// pop left_node and do {
+///  check result and do something
+///  push node->right_node
+///  pop right_node and do { pop node  }
+/// }
+/// ...
+/// so some operations must run after an element is popped off the stack. We
+/// use `scope_exit` to run them. These callbacks may pop further frames, so
+/// unwinding a chain of finished frames still nests native calls.
+template <typename RetT, typename... ArgsT>
+class CallStack {
+public:
+  using value_type =
+      std::tuple<ArgsT..., llvm::detail::scope_exit<std::function<void(void)>>>;
+  CallStack(ArgsT... args) {
+    stack_.emplace_back(args..., []() {});
+  }
+
+  /// Push a new frame with the given arguments and record the action to run
+  /// when that frame finishes. By default, finishing the frame also pops the
+  /// frame below it, i.e. its caller returns too. To inspect the result of the
+  /// pushed frame instead, pass `onExit`; it is then responsible for pushing
+  /// the next frame or calling `returnValue`.
+  void pushArgs(ArgsT... args, const std::function<void(void)> &onExit = {}) {
+    if (onExit)
+      stack_.emplace_back(args..., onExit);
+    else
+      stack_.emplace_back(args..., [this]() { pop(); });
+  }
+
+  RetT getResult() const { return value_; }
+
+  void returnValue(RetT value) {
+    value_ = value;
+    pop();
+  }
+
+  value_type &top() { return stack_.back(); }
+
+  auto begin() const { return stack_.begin(); }
+
+  auto end() const { return stack_.end(); }
+
+  bool empty() const { return stack_.empty(); }
+
+private:
+  /// Note: The top element must be moved out of the stack before its
+  /// `scope_exit` runs. Destroying it in place would run the callback, which
+  /// itself modifies the stack, from inside `pop_back`; that reentrancy is not
+  /// guaranteed to work (it crashed on Windows but happened to work on Ubuntu).
+  void pop() { value_type _(stack_.pop_back_val()); }
+
+  SmallVector<value_type> stack_;
+  RetT value_;
+};
+
 /// Returns true if the expression is divisible by the given symbol with
 /// position `symbolPos`. The argument `opKind` specifies here what kind of
 /// division or mod operation called this division. It helps in implementing the
@@ -363,60 +439,40 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
           opKind == AffineExprKind::CeilDiv) &&
          "unexpected opKind");
-  SmallVector<std::tuple<AffineExpr, unsigned, AffineExprKind,
-                         llvm::detail::scope_exit<std::function<void(void)>>>>
-      stack;
-  stack.emplace_back(expr, symbolPos, opKind, []() {});
-  bool result = false;
+  CallStack<bool, AffineExpr, unsigned, AffineExprKind> stack(expr, symbolPos,
+                                                              opKind);
 
   while (!stack.empty()) {
-    AffineExpr expr = std::get<0>(stack.back());
-    unsigned symbolPos = std::get<1>(stack.back());
-    AffineExprKind opKind = std::get<2>(stack.back());
+    AffineExpr expr = std::get<0>(stack.top());
+    unsigned symbolPos = std::get<1>(stack.top());
+    AffineExprKind opKind = std::get<2>(stack.top());
 
     switch (expr.getKind()) {
     case AffineExprKind::Constant: {
       // Note: Assignment must occur before pop, which will affect whether it
       // enters other execution branches.
-      result = cast<AffineConstantExpr>(expr).getValue() == 0;
-      llvm::detail::scope_exit<std::function<void(void)>> sexit(
-          std::move(std::get<3>(stack.back())));
-      stack.pop_back();
+      stack.returnValue(cast<AffineConstantExpr>(expr).getValue() == 0);
       break;
     }
     case AffineExprKind::DimId: {
-      result = false;
-      llvm::detail::scope_exit<std::function<void(void)>> sexit(
-          std::move(std::get<3>(stack.back())));
-      stack.pop_back();
+      stack.returnValue(false);
       break;
     }
     case AffineExprKind::SymbolId: {
-      result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos;
-      llvm::detail::scope_exit<std::function<void(void)>> sexit(
-          std::move(std::get<3>(stack.back())));
-      stack.pop_back();
+      stack.returnValue(cast<AffineSymbolExpr>(expr).getPosition() ==
+                        symbolPos);
       break;
     }
     // Checks divisibility by the given symbol for both operands.
     case AffineExprKind::Add: {
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-      stack.emplace_back(
-          binaryExpr.getLHS(), symbolPos, opKind,
-          [&stack, &result, binaryExpr, symbolPos, opKind]() {
-            if (result) {
-              stack.emplace_back(
-                  binaryExpr.getRHS(), symbolPos, opKind, [&stack]() {
-                    llvm::detail::scope_exit<std::function<void(void)>> sexit(
-                        std::move(std::get<3>(stack.back())));
-                    stack.pop_back();
-                  });
-            } else {
-              llvm::detail::scope_exit<std::function<void(void)>> sexit(
-                  std::move(std::get<3>(stack.back())));
-              stack.pop_back();
-            }
-          });
+      stack.pushArgs(binaryExpr.getLHS(), symbolPos, opKind,
+                     [&stack, binaryExpr, symbolPos, opKind]() {
+                       if (stack.getResult())
+                         stack.pushArgs(binaryExpr.getRHS(), symbolPos, opKind);
+                       else
+                         stack.returnValue(stack.getResult());
+                     });
       break;
     }
     // Checks divisibility by the given symbol for both operands. Consider the
@@ -426,44 +482,26 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
     // is `AffineExprKind::Mod` for this reason.
     case AffineExprKind::Mod: {
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-      stack.emplace_back(
-          binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
-          [&stack, &result, binaryExpr, symbolPos]() {
-            if (result) {
-              stack.emplace_back(
-                  binaryExpr.getRHS(), symbolPos, AffineExprKind::Mod,
-                  [&stack]() {
-                    llvm::detail::scope_exit<std::function<void(void)>> sexit(
-                        std::move(std::get<3>(stack.back())));
-                    stack.pop_back();
-                  });
-            } else {
-              llvm::detail::scope_exit<std::function<void(void)>> sexit(
-                  std::move(std::get<3>(stack.back())));
-              stack.pop_back();
-            }
-          });
+      stack.pushArgs(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
+                     [&stack, binaryExpr, symbolPos]() {
+                       if (stack.getResult())
+                         stack.pushArgs(binaryExpr.getRHS(), symbolPos,
+                                        AffineExprKind::Mod);
+                       else
+                         stack.returnValue(stack.getResult());
+                     });
       break;
     }
     // Checks if any of the operand divisible by the given symbol.
     case AffineExprKind::Mul: {
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-      stack.emplace_back(
-          binaryExpr.getLHS(), symbolPos, opKind,
-          [&stack, &result, binaryExpr, symbolPos, opKind]() {
-            if (!result) {
-              stack.emplace_back(
-                  binaryExpr.getRHS(), symbolPos, opKind, [&stack]() {
-                    llvm::detail::scope_exit<std::function<void(void)>> sexit(
-                        std::move(std::get<3>(stack.back())));
-                    stack.pop_back();
-                  });
-            } else {
-              llvm::detail::scope_exit<std::function<void(void)>> sexit(
-                  std::move(std::get<3>(stack.back())));
-              stack.pop_back();
-            }
-          });
+      stack.pushArgs(binaryExpr.getLHS(), symbolPos, opKind,
+                     [&stack, binaryExpr, symbolPos, opKind]() {
+                       if (!stack.getResult())
+                         stack.pushArgs(binaryExpr.getRHS(), symbolPos, opKind);
+                       else
+                         stack.returnValue(stack.getResult());
+                     });
       break;
     }
     // Floordiv and ceildiv are divisible by the given symbol when the first
@@ -480,34 +518,22 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
     case AffineExprKind::CeilDiv: {
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
       if (opKind != expr.getKind()) {
-        result = false;
-        llvm::detail::scope_exit<std::function<void(void)>> sexit(
-            std::move(std::get<3>(stack.back())));
-        stack.pop_back();
+        stack.returnValue(false);
         break;
       }
       if (llvm::any_of(stack, [](auto &it) {
             return std::get<0>(it).getKind() == AffineExprKind::Mul;
           })) {
-        result = false;
-        llvm::detail::scope_exit<std::function<void(void)>> sexit(
-            std::move(std::get<3>(stack.back())));
-        stack.pop_back();
+        stack.returnValue(false);
         break;
       }
-
-      stack.emplace_back(
-          binaryExpr.getLHS(), symbolPos, expr.getKind(), [&stack]() {
-            llvm::detail::scope_exit<std::function<void(void)>> sexit(
-                std::move(std::get<3>(stack.back())));
-            stack.pop_back();
-          });
+      stack.pushArgs(binaryExpr.getLHS(), symbolPos, expr.getKind());
       break;
     }
       llvm_unreachable("Unknown AffineExpr");
     }
   }
-  return result;
+  return stack.getResult();
 }
 
 /// Divides the given expression by the given symbol at position `symbolPos`. It



More information about the Mlir-commits mailing list