# [Mlir-commits] [mlir] 56593fa - [MLIR] Simplify semi-affine expressions

Uday Bondhugula llvmlistbot at llvm.org
Tue Aug 4 09:41:44 PDT 2020

Author: Yash Jain
Date: 2020-08-04T22:07:18+05:30
New Revision: 56593fa370124a4d77703e7ddfa4dfca81e0c8f2

URL: https://github.com/llvm/llvm-project/commit/56593fa370124a4d77703e7ddfa4dfca81e0c8f2
DIFF: https://github.com/llvm/llvm-project/commit/56593fa370124a4d77703e7ddfa4dfca81e0c8f2.diff

LOG: [MLIR] Simplify semi-affine expressions

Simplify semi-affine expressions for operations like ceildiv, floordiv
and modulo by a symbol, by checking whether the dividend is divisible by
that symbol.
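The divisibility check can be sketched in isolation. The standalone C++ below is a simplified sketch built on a hypothetical toy expression tree. The names `Expr`, `make`, `cst`, `dim`, `sym`, `add`, `mul` and `eval` are made up for illustration. `isDivisibleBySymbol` and `symbolicDivide` only mirror the Constant, DimId, SymbolId, Add and Mul cases of the MLIR functions with those names in the diff; their floordiv, ceildiv and mod handling is left out.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy expression tree (not MLIR's AffineExpr). Only constants,
// dims, symbols, `+` and `*` are modeled; the div/mod cases are omitted.
struct Expr {
  enum Kind { Constant, Dim, Symbol, Add, Mul } kind;
  int64_t value;                   // Constant value, or dim/symbol position.
  std::shared_ptr<Expr> lhs, rhs;  // Operands of Add and Mul.
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr make(Expr::Kind k, int64_t v, ExprPtr l = nullptr, ExprPtr r = nullptr) {
  return std::make_shared<Expr>(Expr{k, v, std::move(l), std::move(r)});
}
ExprPtr cst(int64_t v) { return make(Expr::Constant, v); }
ExprPtr dim(int64_t pos) { return make(Expr::Dim, pos); }
ExprPtr sym(int64_t pos) { return make(Expr::Symbol, pos); }
ExprPtr add(ExprPtr l, ExprPtr r) { return make(Expr::Add, 0, l, r); }
ExprPtr mul(ExprPtr l, ExprPtr r) { return make(Expr::Mul, 0, l, r); }

// True if `e` is a multiple of symbol `pos` for all dim/symbol values:
// a sum needs both terms divisible, a product needs at least one factor.
bool isDivisibleBySymbol(const ExprPtr &e, int64_t pos) {
  switch (e->kind) {
  case Expr::Constant: return e->value == 0;  // Only 0 is always a multiple.
  case Expr::Dim:      return false;
  case Expr::Symbol:   return e->value == pos;
  case Expr::Add:
    return isDivisibleBySymbol(e->lhs, pos) && isDivisibleBySymbol(e->rhs, pos);
  case Expr::Mul:
    return isDivisibleBySymbol(e->lhs, pos) || isDivisibleBySymbol(e->rhs, pos);
  }
  return false;
}

// Divides `e` by symbol `pos`. Callers must check isDivisibleBySymbol first;
// nullptr is returned for a dim, which is never divisible.
ExprPtr symbolicDivide(const ExprPtr &e, int64_t pos) {
  switch (e->kind) {
  case Expr::Constant: return cst(0);  // Divisible constants are 0.
  case Expr::Dim:      return nullptr;
  case Expr::Symbol:   return cst(1);  // s / s == 1.
  case Expr::Add:
    return add(symbolicDivide(e->lhs, pos), symbolicDivide(e->rhs, pos));
  case Expr::Mul:
    // Divide whichever factor is divisible; keep the other one unchanged.
    if (!isDivisibleBySymbol(e->lhs, pos))
      return mul(e->lhs, symbolicDivide(e->rhs, pos));
    return mul(symbolicDivide(e->lhs, pos), e->rhs);
  }
  return nullptr;
}

// Evaluates `e` for concrete dim and symbol values.
int64_t eval(const ExprPtr &e, const std::vector<int64_t> &dims,
             const std::vector<int64_t> &syms) {
  switch (e->kind) {
  case Expr::Constant: return e->value;
  case Expr::Dim:      return dims[e->value];
  case Expr::Symbol:   return syms[e->value];
  case Expr::Add:      return eval(e->lhs, dims, syms) + eval(e->rhs, dims, syms);
  case Expr::Mul:      return eval(e->lhs, dims, syms) * eval(e->rhs, dims, syms);
  }
  return 0;
}
```

For example, `s0 * d0 + 3 * s0` passes the check for `s0`, and dividing it gives an expression that evaluates like `d0 + 3`. A mod of it by `s0` would therefore fold to 0.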

Some properties used in simplification are:

1) Commutative property of floordiv and ceildiv (for positive divisors):
((expr1 floordiv expr2) floordiv expr3) = ((expr1 floordiv expr3) floordiv expr2)
((expr1 ceildiv expr2) ceildiv expr3) = ((expr1 ceildiv expr3) ceildiv expr2)

When the two operations differ, no simplification is possible, because no
such property holds for expressions like ((expr1 ceildiv expr2) floordiv
expr3) or ((expr1 floordiv expr2) ceildiv expr3).

2) If both expr1 and expr2 are divisible by expr3, then:
(expr1 % expr2) / expr3 = ((expr1 / expr3) % (expr2 / expr3))
where / denotes division.

3) If expr1 is divisible by expr2 then expr1 % expr2 = 0.
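The three properties can be spot-checked numerically, along with the claim that mixed ceildiv/floordiv nesting cannot be reordered. The standalone C++ below is a sketch that assumes strictly positive divisors and an MLIR-style `mod` whose result lies in `[0, b)`. The helper names are hypothetical. The checks only cover small ranges, so they are evidence rather than a proof.

```cpp
#include <cassert>
#include <cstdint>

// Integer helpers; every divisor is assumed to be strictly positive.
int64_t floorDiv(int64_t a, int64_t b) { return a / b - (a % b < 0 ? 1 : 0); }
int64_t ceilDiv(int64_t a, int64_t b) { return a / b + (a % b > 0 ? 1 : 0); }
int64_t mod(int64_t a, int64_t b) { return a - b * floorDiv(a, b); }  // In [0, b).

// Property 1: swapping the divisors of a nested floordiv (or of a nested
// ceildiv) does not change the result.
bool checkProperty1(int64_t range, int64_t maxDiv) {
  for (int64_t a = -range; a <= range; ++a)
    for (int64_t b = 1; b <= maxDiv; ++b)
      for (int64_t c = 1; c <= maxDiv; ++c)
        if (floorDiv(floorDiv(a, b), c) != floorDiv(floorDiv(a, c), b) ||
            ceilDiv(ceilDiv(a, b), c) != ceilDiv(ceilDiv(a, c), b))
          return false;
  return true;
}

// Property 2: with a = k * c and b = m * c, (a mod b) / c == k mod m, and the
// division by c is exact. Property 3: (k * b) mod b == 0.
bool checkProperties2And3(int64_t range, int64_t maxDiv) {
  for (int64_t c = 1; c <= maxDiv; ++c)
    for (int64_t k = -range; k <= range; ++k)
      for (int64_t m = 1; m <= maxDiv; ++m) {
        int64_t a = k * c, b = m * c;
        if (mod(a, b) % c != 0 || mod(a, b) / c != mod(k, m))
          return false;
        if (mod(k * b, b) != 0)
          return false;
      }
  return true;
}

// Mixed kinds: compares (a ceildiv b) floordiv c with (a floordiv c) ceildiv b.
bool mixedKindsCommute(int64_t a, int64_t b, int64_t c) {
  return floorDiv(ceilDiv(a, b), c) == ceilDiv(floorDiv(a, c), b);
}
```

For instance, `mixedKindsCommute(2, 3, 2)` is false: `(2 ceildiv 3) floordiv 2` is 0, while `(2 floordiv 2) ceildiv 3` is 1.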

Signed-off-by: Yash Jain <yash.jain at polymagelabs.com>

Differential Revision: https://reviews.llvm.org/D84920

Modified:
mlir/lib/IR/AffineExpr.cpp
mlir/test/Dialect/Affine/simplify-affine-structures.mlir

Removed:

################################################################################
diff  --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 5ba9737a5245..0d4d9d08c935 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -245,6 +245,170 @@ unsigned AffineDimExpr::getPosition() const {
return static_cast<ImplType *>(expr)->position;
}

+/// Returns true if the expression is divisible by the given symbol with
+/// position `symbolPos`. The argument `opKind` specifies which kind of
+/// division or mod operation triggered this divisibility check. It helps in
+/// implementing the commutative property of the floordiv and ceildiv
+/// operations. If `opKind` is floordiv and `expr` is also a floordiv binary
+/// expression, the commutative property can be used; otherwise, the floordiv
+/// operation is not divisible. The same argument holds for ceildiv.
+static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
+                                AffineExprKind opKind) {
+  // The argument `opKind` can only be Mod, FloorDiv or CeilDiv.
+  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
+          opKind == AffineExprKind::CeilDiv) &&
+         "unexpected opKind");
+  switch (expr.getKind()) {
+  case AffineExprKind::Constant:
+    if (expr.cast<AffineConstantExpr>().getValue())
+      return false;
+    return true;
+  case AffineExprKind::DimId:
+    return false;
+  case AffineExprKind::SymbolId:
+    return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
+  // Checks divisibility by the given symbol for both operands.
+  case AffineExprKind::Add: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
+           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
+  }
+  // Checks divisibility by the given symbol for both operands. Consider the
+  // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`:
+  // this is a division by s1, and both operands of the modulo would pass a
+  // floordiv divisibility check for s1, yet the expression is not always
+  // divisible by s1. This is why the third argument is `AffineExprKind::Mod`.
+  case AffineExprKind::Mod: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
+                               AffineExprKind::Mod) &&
+           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
+                               AffineExprKind::Mod);
+  }
+  // Checks if either operand is divisible by the given symbol.
+  case AffineExprKind::Mul: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
+           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
+  }
+  // Floordiv and ceildiv are divisible by the given symbol when the first
+  // operand is divisible and the kind of the argument `expr` is the same as
+  // `opKind`. This follows from the commutative property of the floordiv and
+  // ceildiv operations:
+  // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
+  // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv exp2
+  // It fails if the operations are not the same. For example,
+  // (exp1 ceildiv exp2) floordiv exp3 cannot be simplified.
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    if (opKind != expr.getKind())
+      return false;
+    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
+  }
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
+/// Divides the given expression by the given symbol at position `symbolPos`. It
+/// assumes the divisibility condition has been checked before it is called. A
+/// null expression is returned whenever the divisibility condition fails.
+static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
+                                 AffineExprKind opKind) {
+  // The argument `opKind` can only be Mod, FloorDiv or CeilDiv.
+  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
+          opKind == AffineExprKind::CeilDiv) &&
+         "unexpected opKind");
+  switch (expr.getKind()) {
+  case AffineExprKind::Constant:
+    if (expr.cast<AffineConstantExpr>().getValue() != 0)
+      return nullptr;
+    return getAffineConstantExpr(0, expr.getContext());
+  case AffineExprKind::DimId:
+    return nullptr;
+  case AffineExprKind::SymbolId:
+    return getAffineConstantExpr(1, expr.getContext());
+  // Dividing both operands by the given symbol.
+  case AffineExprKind::Add: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    return getAffineBinaryOpExpr(
+        expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
+        symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
+  }
+  // Dividing both operands by the given symbol.
+  case AffineExprKind::Mod: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    return getAffineBinaryOpExpr(
+        expr.getKind(),
+        symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
+        symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
+  }
+  // Dividing whichever operand is divisible by the given symbol.
+  case AffineExprKind::Mul: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
+      return binaryExpr.getLHS() *
+             symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
+    return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
+           binaryExpr.getRHS();
+  }
+  // Dividing first operand only by the given symbol.
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    return getAffineBinaryOpExpr(
+        expr.getKind(),
+        symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
+        binaryExpr.getRHS());
+  }
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
+/// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
+/// operations when the second operand simplifies to a symbol and the first
+/// operand is divisible by that symbol. It can be applied to any semi-affine
+/// expression. Returned expression can either be a semi-affine or pure affine
+/// expression.
+static AffineExpr simplifySemiAffine(AffineExpr expr) {
+  switch (expr.getKind()) {
+  case AffineExprKind::Constant:
+  case AffineExprKind::DimId:
+  case AffineExprKind::SymbolId:
+    return expr;
+  case AffineExprKind::Mul: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    return getAffineBinaryOpExpr(expr.getKind(),
+                                 simplifySemiAffine(binaryExpr.getLHS()),
+                                 simplifySemiAffine(binaryExpr.getRHS()));
+  }
+  // Check if the simplification of the second operand is a symbol, and the
+  // first operand is divisible by it. If the operation is a modulo, a constant
+  // zero expression is returned. In the case of floordiv and ceildiv, the
+  // symbol from the simplification of the second operand divides the first
+  // operand. Otherwise, simplification is not possible.
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv:
+  case AffineExprKind::Mod: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
+    AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
+    AffineSymbolExpr symbolExpr =
+        simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>();
+    if (!symbolExpr)
+      return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
+    unsigned symbolPos = symbolExpr.getPosition();
+    if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
+      return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
+    if (expr.getKind() == AffineExprKind::Mod)
+      return getAffineConstantExpr(0, expr.getContext());
+    return symbolicDivide(sLHS, symbolPos, expr.getKind());
+  }
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
MLIRContext *context) {
auto assignCtx = [context](AffineDimExprStorage *storage) {
@@ -878,8 +1042,9 @@ int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
/// Simplify the affine expression by flattening it and reconstructing it.
AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
unsigned numSymbols) {
-  // TODO: only pure affine for now. The simplification here can
-  // be extended to semi-affine maps in the future.
+  // Simplify semi-affine expressions separately.
+  if (!expr.isPureAffine())
+    expr = simplifySemiAffine(expr);
if (!expr.isPureAffine())
return expr;

diff  --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
index 91f153f1fb21..11fb0b128d63 100644
--- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
@@ -281,3 +281,49 @@ func @simplify_zero_dim_map(%in : memref<f32>) -> f32 {
%out = affine.load %in[] : memref<f32>
return %out : f32
}
+
+// -----
+
+// Tests the simplification of a semi-affine expression in various cases.
+// CHECK-DAG: #[[\$map0:.*]] = affine_map<()[s0, s1] -> (-(s1 floordiv s0) + 2)>
+// CHECK-DAG: #[[\$map1:.*]] = affine_map<()[s0, s1] -> (-(s1 floordiv s0) + 42)>
+
+// Tests the simplification of a semi-affine expression with a modulo operation on a floordiv and multiplication.
+// CHECK-LABEL: func @semiaffine_mod
+func @semiaffine_mod(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->((-((d0 floordiv s0) * s0) + s0 * s0) mod s0)> (%arg0)[%arg1]
+  // CHECK:       %[[CST:.*]] = constant 0
+  return %a : index
+}
+
+// Tests the simplification of a semi-affine expression with a nested floordiv and a floordiv of a modulo operation.
+// CHECK-LABEL: func @semiaffine_floordiv
+func @semiaffine_floordiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->((-((d0 floordiv s0) * s0) + ((2 * s0) mod (3 * s0))) floordiv s0)> (%arg0)[%arg1]
+  // CHECK: affine.apply #[[\$map0]]()[%arg1, %arg0]
+  return %a : index
+}
+
+// Tests the simplification of a semi-affine expression with a ceildiv operation and a division of constant 0 by a symbol.
+// CHECK-LABEL: func @semiaffine_ceildiv
+func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->((-((d0 floordiv s0) * s0) + s0 * 42 + ((5-5) floordiv s0)) ceildiv  s0)> (%arg0)[%arg1]
+  // CHECK: affine.apply #[[\$map1]]()[%arg1, %arg0]
+  return %a : index
+}
+
+// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
+// CHECK-LABEL: func @semiaffine_composite_floor
+func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+  // CHECK:       %[[CST:.*]] = constant 47
+  return %a : index
+}
+
+// Tests the simplification of a semi-affine expression with a modulo operation whose second operand simplifies to a symbol.
+// CHECK-LABEL: func @semiaffine_unsimplified_symbol
+func @semiaffine_unsimplified_symbol(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->(s0 mod (2 * s0 - s0))> (%arg0)[%arg1]
+  // CHECK:       %[[CST:.*]] = constant 0
+  return %a : index
+}
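This is a rough sanity check of the expected outputs. The sketch evaluates the original map bodies of four of the new tests (`@semiaffine_mod`, `@semiaffine_floordiv`, `@semiaffine_ceildiv`, `@semiaffine_unsimplified_symbol`) on a small grid. It then compares them with the simplified forms their CHECK lines expect. It is plain C++ and does not run MLIR. It assumes `s0 > 0`, and the helper names `floorDiv`, `ceilDiv`, `mod` and `checkTests` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Helpers assume a strictly positive divisor.
int64_t floorDiv(int64_t a, int64_t b) { return a / b - (a % b < 0 ? 1 : 0); }
int64_t ceilDiv(int64_t a, int64_t b) { return a / b + (a % b > 0 ? 1 : 0); }
int64_t mod(int64_t a, int64_t b) { return a - b * floorDiv(a, b); }  // In [0, b).

// For each (d0, s0) on the grid, compares the original map body of a test
// with the form its CHECK line expects (q stands for d0 floordiv s0).
bool checkTests(int64_t range, int64_t maxS0) {
  for (int64_t s0 = 1; s0 <= maxS0; ++s0)
    for (int64_t d0 = -range; d0 <= range; ++d0) {
      int64_t q = floorDiv(d0, s0);
      // @semiaffine_mod: expected constant 0.
      if (mod(-(q * s0) + s0 * s0, s0) != 0)
        return false;
      // @semiaffine_floordiv: expected -(d0 floordiv s0) + 2.
      if (floorDiv(-(q * s0) + mod(2 * s0, 3 * s0), s0) != -q + 2)
        return false;
      // @semiaffine_ceildiv: expected -(d0 floordiv s0) + 42.
      if (ceilDiv(-(q * s0) + s0 * 42 + floorDiv(5 - 5, s0), s0) != -q + 42)
        return false;
      // @semiaffine_unsimplified_symbol: expected constant 0.
      if (mod(s0, 2 * s0 - s0) != 0)
        return false;
    }
  return true;
}
```

In `@semiaffine_floordiv` and `@semiaffine_ceildiv` the expected maps are written as `-(s1 floordiv s0) + ...` but applied as `()[%arg1, %arg0]`. So `s1` binds to the original `d0` operand and `s0` to the original `s0` operand, which is what the comparison above encodes.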