[Mlir-commits] [mlir] 51a2f50 - [mlir][affine] fix the issue of ceildiv-mul-ceildiv form expression not satisfying commutative (#111254)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Oct 12 08:26:01 PDT 2024


Author: long.chen
Date: 2024-10-12T11:25:57-04:00
New Revision: 51a2f50ee76a52d4a542822bfda9c4c17904b72f

URL: https://github.com/llvm/llvm-project/commit/51a2f50ee76a52d4a542822bfda9c4c17904b72f
DIFF: https://github.com/llvm/llvm-project/commit/51a2f50ee76a52d4a542822bfda9c4c17904b72f.diff

LOG: [mlir][affine] fix the issue of ceildiv-mul-ceildiv form expression not satisfying commutative (#111254)

My proof:
We can simplify `(n * s) ceildiv a ceildiv s` to `n ceildiv a`
because `(n * s) ceildiv a ceildiv s` <=> `(n * s) ceildiv s ceildiv a`
<=> `n ceildiv a`.

Let's prove that `s floordiv a floordiv b` <=> `s floordiv b floordiv a`.
Let `s = k*a + m` (with `m < a`), so `s floordiv a` <=> `s / a - m / a`.

Similarly, it can be proven that:
`s floordiv a floordiv b` <=> `s / (a * b) - m / (a * b) - n / b`, with the constraint `n < b`,
<=> `s / (a * b) - (m + a*n) / (a*b)`.

Because `a*b - (m + a*n)` = `a*b - a*n - m` >= `a - m` > `0` (since `n <= b - 1`),
we have `s floordiv a floordiv b` <=> `[s / (a*b)]` <=> `s floordiv b floordiv a`.
However, if `s floordiv b` is multiplied by a factor, the above does not always hold.

Fixes https://github.com/llvm/llvm-project/issues/107508

Added: 
    

Modified: 
    mlir/lib/IR/AffineExpr.cpp
    mlir/test/Dialect/Affine/simplify-structures.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 0b078966aeb85b..2291d64c50a560 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -356,8 +356,9 @@ unsigned AffineDimExpr::getPosition() const {
 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
 /// operation, then the commutative property can be used otherwise, the floordiv
 /// operation is not divisible. The same argument holds for ceildiv operation.
-static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
-                                AffineExprKind opKind) {
+static bool canSimplifyDivisionBySymbol(AffineExpr expr, unsigned symbolPos,
+                                        AffineExprKind opKind,
+                                        bool fromMul = false) {
   // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
           opKind == AffineExprKind::CeilDiv) &&
@@ -372,8 +373,9 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
   // Checks divisibility by the given symbol for both operands.
   case AffineExprKind::Add: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
+    return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
+                                       opKind) &&
+           canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
   }
   // Checks divisibility by the given symbol for both operands. Consider the
   // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
@@ -382,16 +384,18 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
   // `AffineExprKind::Mod` for this reason.
   case AffineExprKind::Mod: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
-                               AffineExprKind::Mod) &&
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
-                               AffineExprKind::Mod);
+    return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
+                                       AffineExprKind::Mod) &&
+           canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos,
+                                       AffineExprKind::Mod);
   }
   // Checks if any of the operand divisible by the given symbol.
   case AffineExprKind::Mul: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
+    return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind,
+                                       true) ||
+           canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind,
+                                       true);
   }
   // Floordiv and ceildiv are divisible by the given symbol when the first
   // operand is divisible, and the affine expression kind of the argument expr
@@ -399,14 +403,19 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
   // property of floordiv and ceildiv operations and are as follow:
   // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
   // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
-  // It will fail if operations are not same. For example:
-  // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
+  // It will fail 1.if operations are not same. For example:
+  // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a
+  // multiplication operation in the expression. For example:
+  //  (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
   case AffineExprKind::FloorDiv:
   case AffineExprKind::CeilDiv: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
     if (opKind != expr.getKind())
       return false;
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
+    if (fromMul)
+      return false;
+    return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
+                                       expr.getKind());
   }
   }
   llvm_unreachable("Unknown AffineExpr");
@@ -448,7 +457,7 @@ static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
   // Dividing any of the operand by the given symbol.
   case AffineExprKind::Mul: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
+    if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
       return binaryExpr.getLHS() *
              symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
     return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
@@ -583,7 +592,8 @@ static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims,
     if (!symbolExpr)
       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
     unsigned symbolPos = symbolExpr.getPosition();
-    if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
+    if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
+                                     expr.getKind()))
       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
     if (expr.getKind() == AffineExprKind::Mod)
       return getAffineConstantExpr(0, expr.getContext());

diff  --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 92d3d86bc93068..d1f34f20fa5dad 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
 }
 
 // Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
-// CHECK-LABEL: func @semiaffine_composite_floor
-func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+  // CHECK:       %[[CST:.*]] = arith.constant 43
+  return %a : index
+}
+
+// Tests the do not simplification of a semi-affine expression with a nested ceildiv-mul-ceildiv operation.
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
   %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
-  // CHECK:       %[[CST:.*]] = arith.constant 47
+  // CHECK-NOT:       arith.constant
+  return %a : index
+}
+
+// Tests the do not simplification of a semi-affine expression with a nested floordiv_mul_floordiv operation
+// CHECK-LABEL: func @semiaffine_composite_floordiv
+func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
+  // CHECK-NOT:       arith.constant
   return %a : index
 }
 


        


More information about the Mlir-commits mailing list