[Mlir-commits] [mlir] [mlir][affine] fix the issue of ceildiv-mul-ceildiv form expression n… (PR #111254)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Oct 5 06:08:57 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-affine
Author: long.chen (lipracer)
<details>
<summary>Changes</summary>
…ot satisfying commutative
Fixes https://github.com/llvm/llvm-project/issues/107508
---
Full diff: https://github.com/llvm/llvm-project/pull/111254.diff
2 Files Affected:
- (modified) mlir/lib/IR/AffineExpr.cpp (+37-11)
- (modified) mlir/test/Dialect/Affine/simplify-structures.mlir (+19-3)
``````````diff
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 0b078966aeb85b..f947b8c3d54c6a 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -349,6 +349,8 @@ unsigned AffineDimExpr::getPosition() const {
return static_cast<ImplType *>(expr)->position;
}
+namespace {
+
/// Returns true if the expression is divisible by the given symbol with
/// position `symbolPos`. The argument `opKind` specifies here what kind of
/// division or mod operation called this division. It helps in implementing the
@@ -356,12 +358,17 @@ unsigned AffineDimExpr::getPosition() const {
///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
/// operation, then the commutative property can be used otherwise, the floordiv
/// operation is not divisible. The same argument holds for ceildiv operation.
-static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
- AffineExprKind opKind) {
+bool isDivisibleBySymbolImpl(AffineExpr expr, unsigned symbolPos,
+ AffineExprKind opKind,
+ SmallVectorImpl<AffineExpr> &visitedExprs,
+ size_t depth = 0) {
// The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
opKind == AffineExprKind::CeilDiv) &&
"unexpected opKind");
+ if (visitedExprs.size() > depth)
+ visitedExprs.resize(depth);
+ visitedExprs.emplace_back(expr);
switch (expr.getKind()) {
case AffineExprKind::Constant:
return cast<AffineConstantExpr>(expr).getValue() == 0;
@@ -372,8 +379,10 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
// Checks divisibility by the given symbol for both operands.
case AffineExprKind::Add: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
+ return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos, opKind,
+ visitedExprs, depth + 1) &&
+ isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos, opKind,
+ visitedExprs, depth + 1);
}
// Checks divisibility by the given symbol for both operands. Consider the
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
@@ -382,16 +391,20 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
// `AffineExprKind::Mod` for this reason.
case AffineExprKind::Mod: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
- AffineExprKind::Mod) &&
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
- AffineExprKind::Mod);
+ return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos,
+ AffineExprKind::Mod, visitedExprs,
+ depth + 1) &&
+ isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos,
+ AffineExprKind::Mod, visitedExprs,
+ depth + 1);
}
// Checks if any of the operand divisible by the given symbol.
case AffineExprKind::Mul: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
+ return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos, opKind,
+ visitedExprs, depth + 1) ||
+ isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos, opKind,
+ visitedExprs, depth + 1);
}
// Floordiv and ceildiv are divisible by the given symbol when the first
// operand is divisible, and the affine expression kind of the argument expr
@@ -406,12 +419,25 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
if (opKind != expr.getKind())
return false;
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
+ if (llvm::any_of(visitedExprs, [](auto expr) {
+ return expr.getKind() == AffineExprKind::Mul;
+ }))
+ return false;
+ return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos,
+ expr.getKind(), visitedExprs, depth + 1);
}
}
llvm_unreachable("Unknown AffineExpr");
}
+bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
+ AffineExprKind opKind) {
+ SmallVector<AffineExpr> visitedExprs;
+ return isDivisibleBySymbolImpl(expr, symbolPos, opKind, visitedExprs);
+}
+
+} // namespace
+
/// Divides the given expression by the given symbol at position `symbolPos`. It
/// considers the divisibility condition is checked before calling itself. A
/// null expression is returned whenever the divisibility condition fails.
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 92d3d86bc93068..d1f34f20fa5dad 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
}
// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
-// CHECK-LABEL: func @semiaffine_composite_floor
-func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
+ %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+ // CHECK: %[[CST:.*]] = arith.constant 43
+ return %a : index
+}
+
// Tests that a semi-affine expression with a nested ceildiv-mul-ceildiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
- // CHECK: %[[CST:.*]] = arith.constant 47
+ // CHECK-NOT: arith.constant
+ return %a : index
+}
+
// Tests that a semi-affine expression with a nested floordiv-mul-floordiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_floordiv
+func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
+ %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
+ // CHECK-NOT: arith.constant
return %a : index
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/111254
More information about the Mlir-commits
mailing list