[Mlir-commits] [mlir] [mlir][affine] fix the issue of ceildiv-mul-ceildiv expressions not satisfying commutativity (PR #109382)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 1 23:01:42 PDT 2024
https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/109382
>From 185863695e191a8dce6b37f38bd2dba086da08d3 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Fri, 20 Sep 2024 14:40:31 +0800
Subject: [PATCH 1/5] [mlir][affine] fix the issue of ceildiv-mul-ceildiv form
expressions not satisfying commutativity
Fixes https://github.com/llvm/llvm-project/issues/107508
---
mlir/lib/IR/AffineExpr.cpp | 164 ++++++++++++------
.../Dialect/Affine/simplify-structures.mlir | 22 ++-
2 files changed, 134 insertions(+), 52 deletions(-)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 0b078966aeb85b..cc8c4c21b96cee 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -9,6 +9,8 @@
#include <cmath>
#include <cstdint>
#include <limits>
+#include <numeric>
+#include <optional>
#include <utility>
#include "AffineExprDetail.h"
@@ -18,9 +20,8 @@
#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/MathExtras.h"
-#include <numeric>
-#include <optional>
using namespace mlir;
using namespace mlir::detail;
@@ -362,54 +363,119 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
opKind == AffineExprKind::CeilDiv) &&
"unexpected opKind");
- switch (expr.getKind()) {
- case AffineExprKind::Constant:
- return cast<AffineConstantExpr>(expr).getValue() == 0;
- case AffineExprKind::DimId:
- return false;
- case AffineExprKind::SymbolId:
- return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
- // Checks divisibility by the given symbol for both operands.
- case AffineExprKind::Add: {
- AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
- }
- // Checks divisibility by the given symbol for both operands. Consider the
- // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
- // this is a division by s1 and both the operands of modulo are divisible by
- // s1 but it is not divisible by s1 always. The third argument is
- // `AffineExprKind::Mod` for this reason.
- case AffineExprKind::Mod: {
- AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
- AffineExprKind::Mod) &&
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
- AffineExprKind::Mod);
- }
- // Checks if any of the operand divisible by the given symbol.
- case AffineExprKind::Mul: {
- AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
- }
- // Floordiv and ceildiv are divisible by the given symbol when the first
- // operand is divisible, and the affine expression kind of the argument expr
- // is same as the argument `opKind`. This can be inferred from commutative
- // property of floordiv and ceildiv operations and are as follow:
- // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
- // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
- // It will fail if operations are not same. For example:
- // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
- case AffineExprKind::FloorDiv:
- case AffineExprKind::CeilDiv: {
- AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- if (opKind != expr.getKind())
- return false;
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
- }
+ std::vector<std::tuple<AffineExpr, unsigned, AffineExprKind,
+ llvm::detail::scope_exit<std::function<void(void)>>>>
+ stack;
+ stack.emplace_back(expr, symbolPos, opKind, []() {});
+ bool result = false;
+
+ while (!stack.empty()) {
+ AffineExpr expr = std::get<0>(stack.back());
+ unsigned symbolPos = std::get<1>(stack.back());
+ AffineExprKind opKind = std::get<2>(stack.back());
+
+ switch (expr.getKind()) {
+ case AffineExprKind::Constant: {
+ // Note: Assignment must occur before pop, which will affect whether it
+ // enters other execution branches.
+ result = cast<AffineConstantExpr>(expr).getValue() == 0;
+ stack.pop_back();
+ break;
+ }
+ case AffineExprKind::DimId: {
+ result = false;
+ stack.pop_back();
+ break;
+ }
+ case AffineExprKind::SymbolId: {
+ result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos;
+ stack.pop_back();
+ break;
+ }
+ // Checks divisibility by the given symbol for both operands.
+ case AffineExprKind::Add: {
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+ stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
+ [&stack, &result, binaryExpr, symbolPos, opKind]() {
+ if (result) {
+ stack.emplace_back(
+ binaryExpr.getRHS(), symbolPos, opKind,
+ [&stack]() { stack.pop_back(); });
+ } else {
+ stack.pop_back();
+ }
+ });
+ break;
+ }
+ // Checks divisibility by the given symbol for both operands. Consider the
+ // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv
+ // s1`, this is a division by s1 and both the operands of modulo are
+ // divisible by s1 but it is not divisible by s1 always. The third argument
+ // is `AffineExprKind::Mod` for this reason.
+ case AffineExprKind::Mod: {
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+ stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
+ [&stack, &result, binaryExpr, symbolPos, opKind]() {
+ if (result) {
+ stack.emplace_back(
+ binaryExpr.getRHS(), symbolPos,
+ AffineExprKind::Mod,
+ [&stack]() { stack.pop_back(); });
+ } else {
+ stack.pop_back();
+ }
+ });
+ break;
+ }
+ // Checks if any of the operand divisible by the given symbol.
+ case AffineExprKind::Mul: {
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+ stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
+ [&stack, &result, binaryExpr, symbolPos, opKind]() {
+ if (!result) {
+ stack.emplace_back(
+ binaryExpr.getRHS(), symbolPos, opKind,
+ [&stack]() { stack.pop_back(); });
+ } else {
+ stack.pop_back();
+ }
+ });
+ break;
+ }
+ // Floordiv and ceildiv are divisible by the given symbol when the first
+ // operand is divisible, and the affine expression kind of the argument expr
+ // is same as the argument `opKind`. This can be inferred from commutative
+ // property of floordiv and ceildiv operations and are as follow:
+ // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
+ // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
+ // It will fail 1.if operations are not same. For example:
+ // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a
+ // multiplication operation in the expression. For example:
+ // (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
+ case AffineExprKind::FloorDiv:
+ case AffineExprKind::CeilDiv: {
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+ if (opKind != expr.getKind()) {
+ result = false;
+ stack.pop_back();
+ break;
+ }
+ if (llvm::any_of(stack, [](auto &it) {
+ return std::get<0>(it).getKind() == AffineExprKind::Mul;
+ })) {
+ result = false;
+ stack.pop_back();
+ break;
+ }
+
+ stack.emplace_back(binaryExpr.getLHS(), symbolPos, expr.getKind(),
+ [&stack]() { stack.pop_back(); });
+ break;
+ }
+ llvm_unreachable("Unknown AffineExpr");
+ }
}
- llvm_unreachable("Unknown AffineExpr");
+ return result;
}
/// Divides the given expression by the given symbol at position `symbolPos`. It
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 92d3d86bc93068..d1f34f20fa5dad 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
}
// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
-// CHECK-LABEL: func @semiaffine_composite_floor
-func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
+ %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+ // CHECK: %[[CST:.*]] = arith.constant 43
+ return %a : index
+}
+
+// Tests that a semi-affine expression with a nested ceildiv-mul-ceildiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_ceildiv_mul_ceildiv
+func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
- // CHECK: %[[CST:.*]] = arith.constant 47
+ // CHECK-NOT: arith.constant
+ return %a : index
+}
+
+// Tests that a semi-affine expression with a nested floordiv-mul-floordiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_floordiv_mul_floordiv
+func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
+ %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
+ // CHECK-NOT: arith.constant
return %a : index
}
>From c75822467ae32398e586cbe8badb1dfc1d89e386 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Fri, 20 Sep 2024 18:51:51 +0800
Subject: [PATCH 2/5] refine
---
mlir/lib/IR/AffineExpr.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc8c4c21b96cee..05ce0937fa9dba 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -415,7 +415,7 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
case AffineExprKind::Mod: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
- [&stack, &result, binaryExpr, symbolPos, opKind]() {
+ [&stack, &result, binaryExpr, symbolPos]() {
if (result) {
stack.emplace_back(
binaryExpr.getRHS(), symbolPos,
>From dee51cdf83bcf865c76ea792a874314b7b30f30e Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Tue, 24 Sep 2024 16:52:02 +0800
Subject: [PATCH 3/5] refine
---
mlir/lib/IR/AffineExpr.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 05ce0937fa9dba..b3d8f1cd0c7313 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -363,7 +363,7 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
opKind == AffineExprKind::CeilDiv) &&
"unexpected opKind");
- std::vector<std::tuple<AffineExpr, unsigned, AffineExprKind,
+ SmallVector<std::tuple<AffineExpr, unsigned, AffineExprKind,
llvm::detail::scope_exit<std::function<void(void)>>>>
stack;
stack.emplace_back(expr, symbolPos, opKind, []() {});
>From c70e79414ffef1d1f81a0d2d30e1d265aa4bdf5c Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sat, 28 Sep 2024 08:07:12 -0400
Subject: [PATCH 4/5] fix CI failure on the Windows platform
---
mlir/lib/IR/AffineExpr.cpp | 98 +++++++++++++++++++++++++-------------
1 file changed, 65 insertions(+), 33 deletions(-)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index b3d8f1cd0c7313..ef037a2e133657 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -379,32 +379,44 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
// Note: Assignment must occur before pop, which will affect whether it
// enters other execution branches.
result = cast<AffineConstantExpr>(expr).getValue() == 0;
+ llvm::detail::scope_exit<std::function<void(void)>> sexit(
+ std::move(std::get<3>(stack.back())));
stack.pop_back();
break;
}
case AffineExprKind::DimId: {
result = false;
+ llvm::detail::scope_exit<std::function<void(void)>> sexit(
+ std::move(std::get<3>(stack.back())));
stack.pop_back();
break;
}
case AffineExprKind::SymbolId: {
result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos;
+ llvm::detail::scope_exit<std::function<void(void)>> sexit(
+ std::move(std::get<3>(stack.back())));
stack.pop_back();
break;
}
// Checks divisibility by the given symbol for both operands.
case AffineExprKind::Add: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
- [&stack, &result, binaryExpr, symbolPos, opKind]() {
- if (result) {
- stack.emplace_back(
- binaryExpr.getRHS(), symbolPos, opKind,
- [&stack]() { stack.pop_back(); });
- } else {
- stack.pop_back();
- }
- });
+ stack.emplace_back(
+ binaryExpr.getLHS(), symbolPos, opKind,
+ [&stack, &result, binaryExpr, symbolPos, opKind]() {
+ if (result) {
+ stack.emplace_back(
+ binaryExpr.getRHS(), symbolPos, opKind, [&stack]() {
+ llvm::detail::scope_exit<std::function<void(void)>> sexit(
+ std::move(std::get<3>(stack.back())));
+ stack.pop_back();
+ });
+ } else {
+ llvm::detail::scope_exit<std::function<void(void)>> sexit(
+ std::move(std::get<3>(stack.back())));
+ stack.pop_back();
+ }
+ });
break;
}
// Checks divisibility by the given symbol for both operands. Consider the
@@ -414,32 +426,44 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
// is `AffineExprKind::Mod` for this reason.
case AffineExprKind::Mod: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
- [&stack, &result, binaryExpr, symbolPos]() {
- if (result) {
- stack.emplace_back(
- binaryExpr.getRHS(), symbolPos,
- AffineExprKind::Mod,
- [&stack]() { stack.pop_back(); });
- } else {
- stack.pop_back();
- }
- });
+ stack.emplace_back(
+ binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
+ [&stack, &result, binaryExpr, symbolPos]() {
+ if (result) {
+ stack.emplace_back(
+ binaryExpr.getRHS(), symbolPos, AffineExprKind::Mod,
+ [&stack]() {
+ llvm::detail::scope_exit<std::function<void(void)>> sexit(
+ std::move(std::get<3>(stack.back())));
+ stack.pop_back();
+ });
+ } else {
+ llvm::detail::scope_exit<std::function<void(void)>> sexit(
+ std::move(std::get<3>(stack.back())));
+ stack.pop_back();
+ }
+ });
break;
}
// Checks if any of the operand divisible by the given symbol.
case AffineExprKind::Mul: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
- [&stack, &result, binaryExpr, symbolPos, opKind]() {
- if (!result) {
- stack.emplace_back(
- binaryExpr.getRHS(), symbolPos, opKind,
- [&stack]() { stack.pop_back(); });
- } else {
- stack.pop_back();
- }
- });
+ stack.emplace_back(
+ binaryExpr.getLHS(), symbolPos, opKind,
+ [&stack, &result, binaryExpr, symbolPos, opKind]() {
+ if (!result) {
+ stack.emplace_back(
+ binaryExpr.getRHS(), symbolPos, opKind, [&stack]() {
+ llvm::detail::scope_exit<std::function<void(void)>> sexit(
+ std::move(std::get<3>(stack.back())));
+ stack.pop_back();
+ });
+ } else {
+ llvm::detail::scope_exit<std::function<void(void)>> sexit(
+ std::move(std::get<3>(stack.back())));
+ stack.pop_back();
+ }
+ });
break;
}
// Floordiv and ceildiv are divisible by the given symbol when the first
@@ -457,6 +481,8 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
if (opKind != expr.getKind()) {
result = false;
+ llvm::detail::scope_exit<std::function<void(void)>> sexit(
+ std::move(std::get<3>(stack.back())));
stack.pop_back();
break;
}
@@ -464,12 +490,18 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
return std::get<0>(it).getKind() == AffineExprKind::Mul;
})) {
result = false;
+ llvm::detail::scope_exit<std::function<void(void)>> sexit(
+ std::move(std::get<3>(stack.back())));
stack.pop_back();
break;
}
- stack.emplace_back(binaryExpr.getLHS(), symbolPos, expr.getKind(),
- [&stack]() { stack.pop_back(); });
+ stack.emplace_back(
+ binaryExpr.getLHS(), symbolPos, expr.getKind(), [&stack]() {
+ llvm::detail::scope_exit<std::function<void(void)>> sexit(
+ std::move(std::get<3>(stack.back())));
+ stack.pop_back();
+ });
break;
}
llvm_unreachable("Unknown AffineExpr");
>From 514d78c113ebf3340d0bafbad49baaaca655db97 Mon Sep 17 00:00:00 2001
From: "long.chen" <lipracer at gmail.com>
Date: Wed, 2 Oct 2024 04:51:01 +0000
Subject: [PATCH 5/5] refine
---
mlir/lib/IR/AffineExpr.cpp | 196 +++++++++++++++++++++----------------
1 file changed, 111 insertions(+), 85 deletions(-)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index ef037a2e133657..82b0bc193bfb71 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -350,6 +350,82 @@ unsigned AffineDimExpr::getPosition() const {
return static_cast<ImplType *>(expr)->position;
}
+/// An explicit stack that turns a recursive tree walk into a loop. Each
+/// entry holds the arguments of one pending "call" plus a `scope_exit`
+/// callback that runs after the entry has been popped. A driver loop inspects
+/// `top()` until `empty()` and either produces a result with `returnValue()`
+/// or schedules a child with `pushArgs()`. Unlike the native call stack, the
+/// pending calls can be inspected through `begin()`/`end()`, e.g. to check
+/// the kinds of all enclosing nodes.
+///
+/// A recursive visit such as:
+/// visit(node):
+/// r = visit(node->left)
+/// if (needRight(r)) r = visit(node->right)
+/// return r
+/// is expressed with this class as:
+/// on top == node: pushArgs(node->left, onExit = {
+/// if (needRight(getResult())) pushArgs(node->right) // node popped later
+/// else returnValue(getResult()) // pops node now
+/// })
+/// on top == leaf: returnValue(v) // records v and pops the leaf
+///
+/// Note: popping an entry runs its callback, and both the default callback
+/// and `returnValue()` call `pop()` again. These calls nest while a chain of
+/// enclosing entries completes (e.g. nested divisions that all use the
+/// default callback), so the native call depth can still grow with the
+/// expression depth; this class does not eliminate recursion entirely.
+template <typename RetT, typename... ArgsT>
+class CallStack {
+public:
+ using value_type =
+ std::tuple<ArgsT..., llvm::detail::scope_exit<std::function<void(void)>>>;
+ /// Starts the walk with a single root call whose exit callback does nothing.
+ CallStack(ArgsT... args) {
+ stack_.emplace_back(args..., []() {});
+ }
+
+ /// Pushes a pending call with the given arguments. `onExit` runs once that
+ /// entry has been popped (its result is then available via `getResult()`)
+ /// and may push further work. If `onExit` is empty, the default callback
+ /// pops the entry below it (its caller), propagating the result upward.
+ void pushArgs(ArgsT... args, const std::function<void(void)> &onExit = {}) {
+ if (onExit)
+ stack_.emplace_back(args..., onExit);
+ else
+ stack_.emplace_back(args..., [this]() { pop(); });
+ }
+
+ /// Returns the value most recently passed to `returnValue()`.
+ RetT getResult() const { return value_; }
+
+ /// Records `value` as the result of the top call and pops it, which runs
+ /// that entry's exit callback.
+ void returnValue(RetT value) {
+ value_ = value;
+ pop();
+ }
+
+ /// The innermost pending call.
+ value_type &top() { return stack_.back(); }
+ /// Iterate over the pending calls, outermost (the root) first.
+ auto begin() const { return stack_.begin(); }
+ auto end() const { return stack_.end(); }
+ bool empty() const { return stack_.empty(); }
+
+private:
+ /// Removes the top entry, then runs its exit callback when `_` is destroyed.
+ /// The entry is moved out with `pop_back_val()` first because the callback
+ /// may push to or pop from `stack_`; destroying it in place via `pop_back()`
+ /// would modify the container while it is erasing that element (observed
+ /// to crash on Windows).
+ void pop() { value_type _(stack_.pop_back_val()); }
+
+ SmallVector<value_type> stack_;
+ /// Value-initialized so `getResult()` never reads an indeterminate value.
+ RetT value_{};
+};
+
/// Returns true if the expression is divisible by the given symbol with
/// position `symbolPos`. The argument `opKind` specifies here what kind of
/// division or mod operation called this division. It helps in implementing the
@@ -363,60 +439,40 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
opKind == AffineExprKind::CeilDiv) &&
"unexpected opKind");
- SmallVector<std::tuple<AffineExpr, unsigned, AffineExprKind,
- llvm::detail::scope_exit<std::function<void(void)>>>>
- stack;
- stack.emplace_back(expr, symbolPos, opKind, []() {});
- bool result = false;
+ CallStack<bool, AffineExpr, unsigned, AffineExprKind> stack(expr, symbolPos,
+ opKind);
while (!stack.empty()) {
- AffineExpr expr = std::get<0>(stack.back());
- unsigned symbolPos = std::get<1>(stack.back());
- AffineExprKind opKind = std::get<2>(stack.back());
+ AffineExpr expr = std::get<0>(stack.top());
+ unsigned symbolPos = std::get<1>(stack.top());
+ AffineExprKind opKind = std::get<2>(stack.top());
switch (expr.getKind()) {
case AffineExprKind::Constant: {
// Note: Assignment must occur before pop, which will affect whether it
// enters other execution branches.
- result = cast<AffineConstantExpr>(expr).getValue() == 0;
- llvm::detail::scope_exit<std::function<void(void)>> sexit(
- std::move(std::get<3>(stack.back())));
- stack.pop_back();
+ stack.returnValue(cast<AffineConstantExpr>(expr).getValue() == 0);
break;
}
case AffineExprKind::DimId: {
- result = false;
- llvm::detail::scope_exit<std::function<void(void)>> sexit(
- std::move(std::get<3>(stack.back())));
- stack.pop_back();
+ stack.returnValue(false);
break;
}
case AffineExprKind::SymbolId: {
- result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos;
- llvm::detail::scope_exit<std::function<void(void)>> sexit(
- std::move(std::get<3>(stack.back())));
- stack.pop_back();
+ stack.returnValue(cast<AffineSymbolExpr>(expr).getPosition() ==
+ symbolPos);
break;
}
// Checks divisibility by the given symbol for both operands.
case AffineExprKind::Add: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- stack.emplace_back(
- binaryExpr.getLHS(), symbolPos, opKind,
- [&stack, &result, binaryExpr, symbolPos, opKind]() {
- if (result) {
- stack.emplace_back(
- binaryExpr.getRHS(), symbolPos, opKind, [&stack]() {
- llvm::detail::scope_exit<std::function<void(void)>> sexit(
- std::move(std::get<3>(stack.back())));
- stack.pop_back();
- });
- } else {
- llvm::detail::scope_exit<std::function<void(void)>> sexit(
- std::move(std::get<3>(stack.back())));
- stack.pop_back();
- }
- });
+ stack.pushArgs(binaryExpr.getLHS(), symbolPos, opKind,
+ [&stack, binaryExpr, symbolPos, opKind]() {
+ if (stack.getResult())
+ stack.pushArgs(binaryExpr.getRHS(), symbolPos, opKind);
+ else
+ stack.returnValue(stack.getResult());
+ });
break;
}
// Checks divisibility by the given symbol for both operands. Consider the
@@ -426,44 +482,26 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
// is `AffineExprKind::Mod` for this reason.
case AffineExprKind::Mod: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- stack.emplace_back(
- binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
- [&stack, &result, binaryExpr, symbolPos]() {
- if (result) {
- stack.emplace_back(
- binaryExpr.getRHS(), symbolPos, AffineExprKind::Mod,
- [&stack]() {
- llvm::detail::scope_exit<std::function<void(void)>> sexit(
- std::move(std::get<3>(stack.back())));
- stack.pop_back();
- });
- } else {
- llvm::detail::scope_exit<std::function<void(void)>> sexit(
- std::move(std::get<3>(stack.back())));
- stack.pop_back();
- }
- });
+ stack.pushArgs(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
+ [&stack, binaryExpr, symbolPos]() {
+ if (stack.getResult())
+ stack.pushArgs(binaryExpr.getRHS(), symbolPos,
+ AffineExprKind::Mod);
+ else
+ stack.returnValue(stack.getResult());
+ });
break;
}
// Checks if any of the operand divisible by the given symbol.
case AffineExprKind::Mul: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- stack.emplace_back(
- binaryExpr.getLHS(), symbolPos, opKind,
- [&stack, &result, binaryExpr, symbolPos, opKind]() {
- if (!result) {
- stack.emplace_back(
- binaryExpr.getRHS(), symbolPos, opKind, [&stack]() {
- llvm::detail::scope_exit<std::function<void(void)>> sexit(
- std::move(std::get<3>(stack.back())));
- stack.pop_back();
- });
- } else {
- llvm::detail::scope_exit<std::function<void(void)>> sexit(
- std::move(std::get<3>(stack.back())));
- stack.pop_back();
- }
- });
+ stack.pushArgs(binaryExpr.getLHS(), symbolPos, opKind,
+ [&stack, binaryExpr, symbolPos, opKind]() {
+ if (!stack.getResult())
+ stack.pushArgs(binaryExpr.getRHS(), symbolPos, opKind);
+ else
+ stack.returnValue(stack.getResult());
+ });
break;
}
// Floordiv and ceildiv are divisible by the given symbol when the first
@@ -480,34 +518,22 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
case AffineExprKind::CeilDiv: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
if (opKind != expr.getKind()) {
- result = false;
- llvm::detail::scope_exit<std::function<void(void)>> sexit(
- std::move(std::get<3>(stack.back())));
- stack.pop_back();
+ stack.returnValue(false);
break;
}
if (llvm::any_of(stack, [](auto &it) {
return std::get<0>(it).getKind() == AffineExprKind::Mul;
})) {
- result = false;
- llvm::detail::scope_exit<std::function<void(void)>> sexit(
- std::move(std::get<3>(stack.back())));
- stack.pop_back();
+ stack.returnValue(false);
break;
}
-
- stack.emplace_back(
- binaryExpr.getLHS(), symbolPos, expr.getKind(), [&stack]() {
- llvm::detail::scope_exit<std::function<void(void)>> sexit(
- std::move(std::get<3>(stack.back())));
- stack.pop_back();
- });
+ stack.pushArgs(binaryExpr.getLHS(), symbolPos, expr.getKind());
break;
}
llvm_unreachable("Unknown AffineExpr");
}
}
- return result;
+ return stack.getResult();
}
/// Divides the given expression by the given symbol at position `symbolPos`. It
More information about the Mlir-commits
mailing list