[Mlir-commits] [mlir] [mlir][affine] fix the issue of ceildiv mul ceildiv expression not satisfying commutativity (PR #109382)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 20 07:41:34 PDT 2024


https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/109382

>From e8b927859a2b7cc7182dff706377200445de259a Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Fri, 20 Sep 2024 14:40:31 +0800
Subject: [PATCH 1/2] [mlir][affine] fix the issue of ceildiv mul ceildiv
 expression not satisfying commutativity

Fixes https://github.com/llvm/llvm-project/issues/107508
---
 mlir/lib/IR/AffineExpr.cpp                    | 164 ++++++++++++------
 .../Dialect/Affine/simplify-structures.mlir   |  22 ++-
 2 files changed, 134 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 0b078966aeb85b..cc8c4c21b96cee 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -9,6 +9,8 @@
 #include <cmath>
 #include <cstdint>
 #include <limits>
+#include <numeric>
+#include <optional>
 #include <utility>
 
 #include "AffineExprDetail.h"
@@ -18,9 +20,8 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/MathExtras.h"
-#include <numeric>
-#include <optional>
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -362,54 +363,119 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
           opKind == AffineExprKind::CeilDiv) &&
          "unexpected opKind");
-  switch (expr.getKind()) {
-  case AffineExprKind::Constant:
-    return cast<AffineConstantExpr>(expr).getValue() == 0;
-  case AffineExprKind::DimId:
-    return false;
-  case AffineExprKind::SymbolId:
-    return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
-  // Checks divisibility by the given symbol for both operands.
-  case AffineExprKind::Add: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
-  }
-  // Checks divisibility by the given symbol for both operands. Consider the
-  // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
-  // this is a division by s1 and both the operands of modulo are divisible by
-  // s1 but it is not divisible by s1 always. The third argument is
-  // `AffineExprKind::Mod` for this reason.
-  case AffineExprKind::Mod: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
-                               AffineExprKind::Mod) &&
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
-                               AffineExprKind::Mod);
-  }
-  // Checks if any of the operand divisible by the given symbol.
-  case AffineExprKind::Mul: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
-  }
-  // Floordiv and ceildiv are divisible by the given symbol when the first
-  // operand is divisible, and the affine expression kind of the argument expr
-  // is same as the argument `opKind`. This can be inferred from commutative
-  // property of floordiv and ceildiv operations and are as follow:
-  // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
-  // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
-  // It will fail if operations are not same. For example:
-  // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
-  case AffineExprKind::FloorDiv:
-  case AffineExprKind::CeilDiv: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    if (opKind != expr.getKind())
-      return false;
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
-  }
+  std::vector<std::tuple<AffineExpr, unsigned, AffineExprKind,
+                         llvm::detail::scope_exit<std::function<void(void)>>>>
+      stack;
+  stack.emplace_back(expr, symbolPos, opKind, []() {});
+  bool result = false;
+
+  while (!stack.empty()) {
+    AffineExpr expr = std::get<0>(stack.back());
+    unsigned symbolPos = std::get<1>(stack.back());
+    AffineExprKind opKind = std::get<2>(stack.back());
+
+    switch (expr.getKind()) {
+    case AffineExprKind::Constant: {
+      // Note: `result` must be assigned before pop_back(): popping runs the
+      // parent's scope_exit callback, which reads `result` to pick a branch.
+      result = cast<AffineConstantExpr>(expr).getValue() == 0;
+      stack.pop_back();
+      break;
+    }
+    case AffineExprKind::DimId: {
+      result = false;
+      stack.pop_back();
+      break;
+    }
+    case AffineExprKind::SymbolId: {
+      result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos;
+      stack.pop_back();
+      break;
+    }
+    // Checks divisibility by the given symbol for both operands.
+    case AffineExprKind::Add: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
+                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                           if (result) {
+                             stack.emplace_back(
+                                 binaryExpr.getRHS(), symbolPos, opKind,
+                                 [&stack]() { stack.pop_back(); });
+                           } else {
+                             stack.pop_back();
+                           }
+                         });
+      break;
+    }
+    // Checks divisibility by the given symbol for both operands. Consider the
+    // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv
+    // s1`, this is a division by s1 and both the operands of modulo are
+    // divisible by s1 but it is not divisible by s1 always. The third argument
+    // is `AffineExprKind::Mod` for this reason.
+    case AffineExprKind::Mod: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
+                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                           if (result) {
+                             stack.emplace_back(
+                                 binaryExpr.getRHS(), symbolPos,
+                                 AffineExprKind::Mod,
+                                 [&stack]() { stack.pop_back(); });
+                           } else {
+                             stack.pop_back();
+                           }
+                         });
+      break;
+    }
+    // Checks if any of the operand divisible by the given symbol.
+    case AffineExprKind::Mul: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
+                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                           if (!result) {
+                             stack.emplace_back(
+                                 binaryExpr.getRHS(), symbolPos, opKind,
+                                 [&stack]() { stack.pop_back(); });
+                           } else {
+                             stack.pop_back();
+                           }
+                         });
+      break;
+    }
+    // Floordiv and ceildiv are divisible by the given symbol when the first
+    // operand is divisible, and the affine expression kind of the argument expr
+    // is same as the argument `opKind`. This can be inferred from commutative
+    // property of floordiv and ceildiv operations, as follows:
+    // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
+    // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv exp2
+    // It fails (1) if the operations differ, e.g. (exp1 ceildiv exp2)
+    // floordiv exp3 cannot be simplified; or (2) if any enclosing expression
+    // is a multiplication, e.g. ((exp1 ceildiv exp2) * exp3) ceildiv exp4
+    // cannot be simplified.
+    case AffineExprKind::FloorDiv:
+    case AffineExprKind::CeilDiv: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      if (opKind != expr.getKind()) {
+        result = false;
+        stack.pop_back();
+        break;
+      }
+      if (llvm::any_of(stack, [](auto &it) {
+            return std::get<0>(it).getKind() == AffineExprKind::Mul;
+          })) {
+        result = false;
+        stack.pop_back();
+        break;
+      }
+
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, expr.getKind(),
+                         [&stack]() { stack.pop_back(); });
+      break;
+    }
+      llvm_unreachable("Unknown AffineExpr");
+    }
   }
-  llvm_unreachable("Unknown AffineExpr");
+  return result;
 }
 
 /// Divides the given expression by the given symbol at position `symbolPos`. It
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 92d3d86bc93068..d1f34f20fa5dad 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
 }
 
 // Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
-// CHECK-LABEL: func @semiaffine_composite_floor
-func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+  // CHECK:       %[[CST:.*]] = arith.constant 43
+  return %a : index
+}
+
+// Tests that a semi-affine expression with a nested ceildiv-mul-ceildiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_ceildiv_mul_ceildiv
+func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
   %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
-  // CHECK:       %[[CST:.*]] = arith.constant 47
+  // CHECK-NOT:       arith.constant
+  return %a : index
+}
+
+// Tests that a semi-affine expression with a nested floordiv-mul-floordiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_floordiv_mul_floordiv
+func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
+  // CHECK-NOT:       arith.constant
   return %a : index
 }
 

>From 4410645f4f2cfb1a4a4a4749c43f6965894c9134 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Fri, 20 Sep 2024 18:51:51 +0800
Subject: [PATCH 2/2] [mlir][affine] remove unused opKind capture in Mod lambda

---
 mlir/lib/IR/AffineExpr.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc8c4c21b96cee..05ce0937fa9dba 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -415,7 +415,7 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
     case AffineExprKind::Mod: {
       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
       stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
-                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                         [&stack, &result, binaryExpr, symbolPos]() {
                            if (result) {
                              stack.emplace_back(
                                  binaryExpr.getRHS(), symbolPos,



More information about the Mlir-commits mailing list