[Mlir-commits] [mlir] [mlir][affine] fix the issue of ceildiv mul ceildiv expression not satisfying commutativity (PR #109382)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 20 07:41:34 PDT 2024
https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/109382
>From e8b927859a2b7cc7182dff706377200445de259a Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Fri, 20 Sep 2024 14:40:31 +0800
Subject: [PATCH 1/2] [mlir][affine] fix the issue of ceildiv mul ceildiv
expression not satisfying commutativity
Fixes https://github.com/llvm/llvm-project/issues/107508
---
mlir/lib/IR/AffineExpr.cpp | 164 ++++++++++++------
.../Dialect/Affine/simplify-structures.mlir | 22 ++-
2 files changed, 134 insertions(+), 52 deletions(-)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 0b078966aeb85b..cc8c4c21b96cee 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -9,6 +9,8 @@
#include <cmath>
#include <cstdint>
#include <limits>
+#include <numeric>
+#include <optional>
#include <utility>
#include "AffineExprDetail.h"
@@ -18,9 +20,8 @@
#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/MathExtras.h"
-#include <numeric>
-#include <optional>
using namespace mlir;
using namespace mlir::detail;
@@ -362,54 +363,119 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
opKind == AffineExprKind::CeilDiv) &&
"unexpected opKind");
- switch (expr.getKind()) {
- case AffineExprKind::Constant:
- return cast<AffineConstantExpr>(expr).getValue() == 0;
- case AffineExprKind::DimId:
- return false;
- case AffineExprKind::SymbolId:
- return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
- // Checks divisibility by the given symbol for both operands.
- case AffineExprKind::Add: {
- AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
- }
- // Checks divisibility by the given symbol for both operands. Consider the
- // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
- // this is a division by s1 and both the operands of modulo are divisible by
- // s1 but it is not divisible by s1 always. The third argument is
- // `AffineExprKind::Mod` for this reason.
- case AffineExprKind::Mod: {
- AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
- AffineExprKind::Mod) &&
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
- AffineExprKind::Mod);
- }
- // Checks if any of the operand divisible by the given symbol.
- case AffineExprKind::Mul: {
- AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
- }
- // Floordiv and ceildiv are divisible by the given symbol when the first
- // operand is divisible, and the affine expression kind of the argument expr
- // is same as the argument `opKind`. This can be inferred from commutative
- // property of floordiv and ceildiv operations and are as follow:
- // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
- // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
- // It will fail if operations are not same. For example:
- // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
- case AffineExprKind::FloorDiv:
- case AffineExprKind::CeilDiv: {
- AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- if (opKind != expr.getKind())
- return false;
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
- }
+ std::vector<std::tuple<AffineExpr, unsigned, AffineExprKind,
+ llvm::detail::scope_exit<std::function<void(void)>>>>
+ stack;
+ stack.emplace_back(expr, symbolPos, opKind, []() {});
+ bool result = false;
+
+ while (!stack.empty()) {
+ AffineExpr expr = std::get<0>(stack.back());
+ unsigned symbolPos = std::get<1>(stack.back());
+ AffineExprKind opKind = std::get<2>(stack.back());
+
+ switch (expr.getKind()) {
+ case AffineExprKind::Constant: {
+ // Note: `result` must be set before the pop; popping destroys the frame,
+ // which runs its scope-exit callback, and that callback reads `result`.
+ result = cast<AffineConstantExpr>(expr).getValue() == 0;
+ stack.pop_back();
+ break;
+ }
+ case AffineExprKind::DimId: {
+ result = false;
+ stack.pop_back();
+ break;
+ }
+ case AffineExprKind::SymbolId: {
+ result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos;
+ stack.pop_back();
+ break;
+ }
+ // Checks divisibility by the given symbol for both operands.
+ case AffineExprKind::Add: {
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+ stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
+ [&stack, &result, binaryExpr, symbolPos, opKind]() {
+ if (result) {
+ stack.emplace_back(
+ binaryExpr.getRHS(), symbolPos, opKind,
+ [&stack]() { stack.pop_back(); });
+ } else {
+ stack.pop_back();
+ }
+ });
+ break;
+ }
+ // Checks divisibility by the given symbol for both operands. Consider the
+ // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv
+ // s1`, this is a division by s1 and both the operands of modulo are
+ // divisible by s1 but it is not divisible by s1 always. The third argument
+ // is `AffineExprKind::Mod` for this reason.
+ case AffineExprKind::Mod: {
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+ stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
+ [&stack, &result, binaryExpr, symbolPos, opKind]() {
+ if (result) {
+ stack.emplace_back(
+ binaryExpr.getRHS(), symbolPos,
+ AffineExprKind::Mod,
+ [&stack]() { stack.pop_back(); });
+ } else {
+ stack.pop_back();
+ }
+ });
+ break;
+ }
+ // Checks if any of the operand divisible by the given symbol.
+ case AffineExprKind::Mul: {
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+ stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
+ [&stack, &result, binaryExpr, symbolPos, opKind]() {
+ if (!result) {
+ stack.emplace_back(
+ binaryExpr.getRHS(), symbolPos, opKind,
+ [&stack]() { stack.pop_back(); });
+ } else {
+ stack.pop_back();
+ }
+ });
+ break;
+ }
+ // Floordiv and ceildiv are divisible by the given symbol when the first
+ // operand is divisible, and the affine expression kind of the argument expr
+ // is same as the argument `opKind`. This can be inferred from commutative
+ // property of floordiv and ceildiv operations and are as follow:
+ // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
+ // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
+ // It will fail if: 1. the operations are not the same. For example:
+ // (exp1 ceildiv exp2) floordiv exp3 cannot be simplified. 2. there is a
+ // multiplication operation in the expression. For example:
+ // ((exp1 ceildiv exp2) mul exp3) ceildiv exp4 cannot be simplified.
+ case AffineExprKind::FloorDiv:
+ case AffineExprKind::CeilDiv: {
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+ if (opKind != expr.getKind()) {
+ result = false;
+ stack.pop_back();
+ break;
+ }
+ if (llvm::any_of(stack, [](auto &it) {
+ return std::get<0>(it).getKind() == AffineExprKind::Mul;
+ })) {
+ result = false;
+ stack.pop_back();
+ break;
+ }
+
+ stack.emplace_back(binaryExpr.getLHS(), symbolPos, expr.getKind(),
+ [&stack]() { stack.pop_back(); });
+ break;
+ }
+ llvm_unreachable("Unknown AffineExpr");
+ }
}
- llvm_unreachable("Unknown AffineExpr");
+ return result;
}
/// Divides the given expression by the given symbol at position `symbolPos`. It
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 92d3d86bc93068..d1f34f20fa5dad 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
}
// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
-// CHECK-LABEL: func @semiaffine_composite_floor
-func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
+ %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+ // CHECK: %[[CST:.*]] = arith.constant 43
+ return %a : index
+}
+
+// Tests that a semi-affine expression with a nested ceildiv-mul-ceildiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
- // CHECK: %[[CST:.*]] = arith.constant 47
+ // CHECK-NOT: arith.constant
+ return %a : index
+}
+
+// Tests that a semi-affine expression with a nested floordiv-mul-floordiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_floordiv
+func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
+ %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
+ // CHECK-NOT: arith.constant
return %a : index
}
>From 4410645f4f2cfb1a4a4a4749c43f6965894c9134 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Fri, 20 Sep 2024 18:51:51 +0800
Subject: [PATCH 2/2] refine
---
mlir/lib/IR/AffineExpr.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc8c4c21b96cee..05ce0937fa9dba 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -415,7 +415,7 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
case AffineExprKind::Mod: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
- [&stack, &result, binaryExpr, symbolPos, opKind]() {
+ [&stack, &result, binaryExpr, symbolPos]() {
if (result) {
stack.emplace_back(
binaryExpr.getRHS(), symbolPos,
More information about the Mlir-commits
mailing list