[Mlir-commits] [mlir] [mlir][affine] fix the issue of ceildiv-mul-ceildiv form expression n… (PR #111254)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 7 00:19:03 PDT 2024


https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/111254

>From a1dcc066c49afcc42dfa3a4899de8a1829d5778f Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sat, 5 Oct 2024 08:55:03 -0400
Subject: [PATCH 1/2] [mlir][affine] fix ceildiv-mul-ceildiv expressions being
 wrongly simplified via commutativity

Fixes https://github.com/llvm/llvm-project/issues/107508
---
 mlir/lib/IR/AffineExpr.cpp                    | 48 ++++++++++++++-----
 .../Dialect/Affine/simplify-structures.mlir   | 22 +++++++--
 2 files changed, 56 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 0b078966aeb85b..f947b8c3d54c6a 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -349,6 +349,8 @@ unsigned AffineDimExpr::getPosition() const {
   return static_cast<ImplType *>(expr)->position;
 }
 
+namespace {
+
 /// Returns true if the expression is divisible by the given symbol with
 /// position `symbolPos`. The argument `opKind` specifies here what kind of
 /// division or mod operation called this division. It helps in implementing the
@@ -356,12 +358,17 @@ unsigned AffineDimExpr::getPosition() const {
 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
 /// operation, then the commutative property can be used otherwise, the floordiv
 /// operation is not divisible. The same argument holds for ceildiv operation.
-static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
-                                AffineExprKind opKind) {
+bool isDivisibleBySymbolImpl(AffineExpr expr, unsigned symbolPos,
+                             AffineExprKind opKind,
+                             SmallVectorImpl<AffineExpr> &visitedExprs,
+                             size_t depth = 0) {
   // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
           opKind == AffineExprKind::CeilDiv) &&
          "unexpected opKind");
+  if (visitedExprs.size() > depth)
+    visitedExprs.resize(depth);
+  visitedExprs.emplace_back(expr);
   switch (expr.getKind()) {
   case AffineExprKind::Constant:
     return cast<AffineConstantExpr>(expr).getValue() == 0;
@@ -372,8 +379,10 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
   // Checks divisibility by the given symbol for both operands.
   case AffineExprKind::Add: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
+    return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos, opKind,
+                                   visitedExprs, depth + 1) &&
+           isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos, opKind,
+                                   visitedExprs, depth + 1);
   }
   // Checks divisibility by the given symbol for both operands. Consider the
   // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
@@ -382,16 +391,20 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
   // `AffineExprKind::Mod` for this reason.
   case AffineExprKind::Mod: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
-                               AffineExprKind::Mod) &&
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
-                               AffineExprKind::Mod);
+    return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos,
+                                   AffineExprKind::Mod, visitedExprs,
+                                   depth + 1) &&
+           isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos,
+                                   AffineExprKind::Mod, visitedExprs,
+                                   depth + 1);
   }
   // Checks if any of the operand divisible by the given symbol.
   case AffineExprKind::Mul: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
+    return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos, opKind,
+                                   visitedExprs, depth + 1) ||
+           isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos, opKind,
+                                   visitedExprs, depth + 1);
   }
   // Floordiv and ceildiv are divisible by the given symbol when the first
   // operand is divisible, and the affine expression kind of the argument expr
@@ -406,12 +419,25 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
     if (opKind != expr.getKind())
       return false;
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
+    if (llvm::any_of(visitedExprs, [](auto expr) {
+          return expr.getKind() == AffineExprKind::Mul;
+        }))
+      return false;
+    return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos,
+                                   expr.getKind(), visitedExprs, depth + 1);
   }
   }
   llvm_unreachable("Unknown AffineExpr");
 }
 
+bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
+                         AffineExprKind opKind) {
+  SmallVector<AffineExpr> visitedExprs;
+  return isDivisibleBySymbolImpl(expr, symbolPos, opKind, visitedExprs);
+}
+
+} // namespace
+
 /// Divides the given expression by the given symbol at position `symbolPos`. It
 /// considers the divisibility condition is checked before calling itself. A
 /// null expression is returned whenever the divisibility condition fails.
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 92d3d86bc93068..d1f34f20fa5dad 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
 }
 
 // Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
-// CHECK-LABEL: func @semiaffine_composite_floor
-func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+  // CHECK:       %[[CST:.*]] = arith.constant 43
+  return %a : index
+}
+
+// Tests that a semi-affine expression with a nested ceildiv-mul-ceildiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_ceildiv_mul_ceildiv
+func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
   %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
-  // CHECK:       %[[CST:.*]] = arith.constant 47
+  // CHECK-NOT:       arith.constant
+  return %a : index
+}
+
+// Tests that a semi-affine expression with a nested floordiv-mul-floordiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_floordiv_mul_floordiv
+func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
+  // CHECK-NOT:       arith.constant
   return %a : index
 }
 

>From 169f415d54be20e28613d0b6412c8b7ec1d6c7e3 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Mon, 7 Oct 2024 03:13:34 -0400
Subject: [PATCH 2/2] [mlir][affine] replace visited-expression stack with a fromMul flag

---
 mlir/lib/IR/AffineExpr.cpp | 54 +++++++++++---------------------------
 1 file changed, 16 insertions(+), 38 deletions(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index f947b8c3d54c6a..bfd6fad5022afc 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -349,8 +349,6 @@ unsigned AffineDimExpr::getPosition() const {
   return static_cast<ImplType *>(expr)->position;
 }
 
-namespace {
-
 /// Returns true if the expression is divisible by the given symbol with
 /// position `symbolPos`. The argument `opKind` specifies here what kind of
 /// division or mod operation called this division. It helps in implementing the
@@ -358,17 +356,12 @@ namespace {
 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
 /// operation, then the commutative property can be used otherwise, the floordiv
 /// operation is not divisible. The same argument holds for ceildiv operation.
-bool isDivisibleBySymbolImpl(AffineExpr expr, unsigned symbolPos,
-                             AffineExprKind opKind,
-                             SmallVectorImpl<AffineExpr> &visitedExprs,
-                             size_t depth = 0) {
+static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
+                                AffineExprKind opKind, bool fromMul = false) {
   // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
           opKind == AffineExprKind::CeilDiv) &&
          "unexpected opKind");
-  if (visitedExprs.size() > depth)
-    visitedExprs.resize(depth);
-  visitedExprs.emplace_back(expr);
   switch (expr.getKind()) {
   case AffineExprKind::Constant:
     return cast<AffineConstantExpr>(expr).getValue() == 0;
@@ -379,10 +372,8 @@ bool isDivisibleBySymbolImpl(AffineExpr expr, unsigned symbolPos,
   // Checks divisibility by the given symbol for both operands.
   case AffineExprKind::Add: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos, opKind,
-                                   visitedExprs, depth + 1) &&
-           isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos, opKind,
-                                   visitedExprs, depth + 1);
+    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
+           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
   }
   // Checks divisibility by the given symbol for both operands. Consider the
   // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
@@ -391,20 +382,16 @@ bool isDivisibleBySymbolImpl(AffineExpr expr, unsigned symbolPos,
   // `AffineExprKind::Mod` for this reason.
   case AffineExprKind::Mod: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos,
-                                   AffineExprKind::Mod, visitedExprs,
-                                   depth + 1) &&
-           isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos,
-                                   AffineExprKind::Mod, visitedExprs,
-                                   depth + 1);
+    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
+                               AffineExprKind::Mod) &&
+           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
+                               AffineExprKind::Mod);
   }
   // Checks if any of the operand divisible by the given symbol.
   case AffineExprKind::Mul: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos, opKind,
-                                   visitedExprs, depth + 1) ||
-           isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos, opKind,
-                                   visitedExprs, depth + 1);
+    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind, true) ||
+           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind, true);
   }
   // Floordiv and ceildiv are divisible by the given symbol when the first
   // operand is divisible, and the affine expression kind of the argument expr
@@ -412,32 +399,23 @@ bool isDivisibleBySymbolImpl(AffineExpr expr, unsigned symbolPos,
   // property of floordiv and ceildiv operations and are as follow:
   // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
   // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
-  // It will fail if operations are not same. For example:
-  // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
+  // It fails if (1) the operations differ, e.g. (exp1 ceildiv exp2) floordiv
+  // exp3, or (2) the expression is a direct operand of a mul, e.g.
+  // ((exp1 ceildiv exp2) * exp3) ceildiv exp4. NOTE(review): the Add case
+  // does not forward `fromMul`, so mul(add(div, ..)) is not rejected.
   case AffineExprKind::FloorDiv:
   case AffineExprKind::CeilDiv: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
     if (opKind != expr.getKind())
       return false;
-    if (llvm::any_of(visitedExprs, [](auto expr) {
-          return expr.getKind() == AffineExprKind::Mul;
-        }))
+    if (fromMul)
       return false;
-    return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos,
-                                   expr.getKind(), visitedExprs, depth + 1);
+    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
   }
   }
   llvm_unreachable("Unknown AffineExpr");
 }
 
-bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
-                         AffineExprKind opKind) {
-  SmallVector<AffineExpr> visitedExprs;
-  return isDivisibleBySymbolImpl(expr, symbolPos, opKind, visitedExprs);
-}
-
-} // namespace
-
 /// Divides the given expression by the given symbol at position `symbolPos`. It
 /// considers the divisibility condition is checked before calling itself. A
 /// null expression is returned whenever the divisibility condition fails.



More information about the Mlir-commits mailing list