[Mlir-commits] [mlir] 1402299 - [MLIR] Simplify semi-affine expressions using flattening

Uday Bondhugula llvmlistbot at llvm.org
Tue Nov 16 02:12:53 PST 2021


Author: Arnab Dutta
Date: 2021-11-16T15:42:22+05:30
New Revision: 1402299271c1d57784f77dd577949e4a546a4c10

URL: https://github.com/llvm/llvm-project/commit/1402299271c1d57784f77dd577949e4a546a4c10
DIFF: https://github.com/llvm/llvm-project/commit/1402299271c1d57784f77dd577949e4a546a4c10.diff

LOG: [MLIR] Simplify semi-affine expressions using flattening

For semi-affine expressions, whenever the RHS of a floordiv, ceildiv, mod
or product expression is a symbolic expression, we introduce a local variable
representing the result and store the floordiv/ceildiv, mod or product
affine expression in `localExprs`. In this way the expression is flattened,
and trivial addition- and subtraction-related simplifications are performed.
Also add rule-based matching in the simplifyAdd function that detects
"expr - q * (expr floordiv q)" and transforms it to "expr mod q", where q is a
symbolic expression.

Differential Revision: https://reviews.llvm.org/D112808

Added: 
    

Modified: 
    mlir/include/mlir/IR/AffineExpr.h
    mlir/include/mlir/IR/AffineExprVisitor.h
    mlir/lib/IR/AffineExpr.cpp
    mlir/test/Dialect/Affine/simplify-affine-structures.mlir
    mlir/test/Dialect/SCF/for-loop-peeling.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index a87a0180bc6dd..f09a4aee00e3f 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -293,10 +293,12 @@ U AffineExpr::cast() const {
   return U(expr);
 }
 
-/// Simplify an affine expression by flattening and some amount of
-/// simple analysis. This has complexity linear in the number of nodes in
-/// 'expr'. Returns the simplified expression, which is the same as the input
-///  expression if it can't be simplified.
+/// Simplify an affine expression by flattening and some amount of simple
+/// analysis. This has complexity linear in the number of nodes in 'expr'.
+/// Returns the simplified expression, which is the same as the input expression
+/// if it can't be simplified. When `expr` is semi-affine, a simplified
+/// semi-affine expression is constructed in the sorted order of dimension and
+/// symbol positions.
 AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
                               unsigned numSymbols);
 

diff  --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index 92db1390e30dd..259cc5e3f8717 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -299,7 +299,26 @@ class SimpleAffineExprFlattener
   virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
                                   AffineExpr localExpr);
 
+  /// Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul
+  /// expr) when the rhs is a symbolic expression. The local identifier added
+  /// may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine
+  /// function of other identifiers, coefficients of which are specified in the
+  /// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
+  /// symbolic rhs expression. `localExpr` is the simplified tree expression
+  /// (AffineExpr) corresponding to the quantifier.
+  virtual void addLocalIdSemiAffine(AffineExpr localExpr);
+
 private:
+  /// Adds `expr`, which may be a mul, ceildiv, floordiv or mod expression
+  /// representing the quantifier of a local variable. If that quantifier is
+  /// already a local, `result` gets a unit coefficient at its index; otherwise
+  /// a new local variable is added and `result` gets a unit coefficient at the
+  /// new index. All other entries of `result` are zeroed. `resultSize` is
+  /// expected to equal `result.size()`.
+  void addLocalVariableSemiAffine(AffineExpr expr,
+                                  SmallVectorImpl<int64_t> &result,
+                                  unsigned long resultSize);
+
   // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
   // A floordiv is thus flattened by introducing a new local variable q, and
   // replacing that expression with 'q' while adding the constraints

diff  --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index b98731031521e..36be0d4ee23d4 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -903,21 +903,213 @@ AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
   return expr;
 }
 
+/// Constructs a semi-affine expression from a flat ArrayRef. If there are
+/// local identifiers (neither dimensional nor symbolic) that appear in the sum
+/// of products expression, `localExprs` is expected to have the AffineExprs for
+/// them, and they are substituted in. `flatExprs` is expected to be in the
+/// format [dims, symbols, locals, constant term]. Dims, symbols, and locals of
+/// the form (dim|sym) op (dim|sym|const) are emitted in sorted order of their
+/// positions; other locals follow, then the constant. Note: local ids are used
+/// for mod and div as well as for symbolic-RHS terms that are not pure affine.
+static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
+                                                unsigned numDims,
+                                                unsigned numSymbols,
+                                                ArrayRef<AffineExpr> localExprs,
+                                                MLIRContext *context) {
+  assert(!flatExprs.empty() && "flatExprs cannot be empty");
+
+  // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
+  assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
+         "unexpected number of local expressions");
+
+  AffineExpr expr = getAffineConstantExpr(0, context);
+
+  // We design indices as a pair which helps us present the semi-affine map as
+  // a sum of products where terms are sorted based on dimension or symbol
+  // position: <keyA, keyB> for expressions of the form dimension * symbol,
+  // where keyA is the position number of the dimension and keyB is the
+  // position number of the symbol. For dimensional expressions we set the index
+  // as (position number of the dimension, -1), as we want dimensional
+  // expressions to appear before symbolic and product of dimensional and
+  // symbolic expressions having the dimension with the same position number.
+  // For symbolic expressions we set the index as (position number of the
+  // symbol, max(numDims, numSymbols)). For example, we want the expression we
+  // are constructing to look something like:
+  //   d0 + d0 * s0 + s0 + d1 * s1 + s1.
+
+  // Stores the affine expression corresponding to a given index.
+  DenseMap<std::pair<unsigned, signed>, AffineExpr> indexToExprMap;
+  // Stores the constant coefficient value corresponding to a given
+  // dimension, symbol or a non-pure affine expression stored in `localExprs`.
+  DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
+  // Stores the indices as defined above, and later sorted to produce
+  // the semi-affine expression in the desired form.
+  SmallVector<std::pair<unsigned, signed>, 8> indices;
+
+  // Example: expression = d0 + d0 * s0 + 2 * s0.
+  // indices = [{0,-1}, {0, 0}, {0, 1}]
+  // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
+  // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
+
+  // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
+  auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
+                      AffineExpr expr) {
+    assert(std::find(indices.begin(), indices.end(), index) == indices.end() &&
+           "Key is already present in indices vector and overwriting will "
+           "happen in `indexToExprMap` and `coefficients`!");
+
+    indices.push_back(index);
+    coefficients.insert({index, coefficient});
+    indexToExprMap.insert({index, expr});
+  };
+
+  // Design indices for dimensional or symbolic terms, and store the indices,
+  // constant coefficient corresponding to the indices in `coefficients` map,
+  // and affine expression corresponding to indices in `indexToExprMap` map.
+
+  for (unsigned j = 0; j < numDims; ++j) {
+    if (flatExprs[j] == 0)
+      continue;
+    // For dimensional expressions we set the index as <position number of the
+    // dimension, -1>, as we want dimensional expressions to appear before
+    // symbolic ones and products of dimensional and symbolic expressions
+    // having the dimension with the same position number.
+    std::pair<unsigned, signed> indexEntry(j, -1);
+    addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
+  }
+  for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
+    if (flatExprs[j] == 0)
+      continue;
+    // For symbolic expressions set the index as <position number of the
+    // symbol, max(numDims, numSymbols)>, as we want symbolic expressions to
+    // appear after dimensional expressions and dimension/symbol products
+    // having the same keyA.
+    std::pair<unsigned, signed> indexEntry(j - numDims,
+                                           std::max(numDims, numSymbols));
+    addEntry(indexEntry, flatExprs[j],
+             getAffineSymbolExpr(j - numDims, context));
+  }
+
+  // Denotes semi-affine product, modulo or division terms, which have been
+  // added to the `indexToExprMap` map.
+  SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
+                                  false);
+  unsigned lhsPos, rhsPos;
+  // Construct indices for product terms involving dimension, symbol or constant
+  // as lhs/rhs, and store the indices, constant coefficient corresponding to
+  // the indices in `coefficients` map, and affine expression corresponding to
+  // the indices in `indexToExprMap` map.
+  for (auto it : llvm::enumerate(localExprs)) {
+    AffineExpr expr = it.value();
+    if (flatExprs[numDims + numSymbols + it.index()] == 0)
+      continue;
+    AffineExpr lhs = expr.cast<AffineBinaryOpExpr>().getLHS();
+    AffineExpr rhs = expr.cast<AffineBinaryOpExpr>().getRHS();
+    if (!((lhs.isa<AffineDimExpr>() || lhs.isa<AffineSymbolExpr>()) &&
+          (rhs.isa<AffineDimExpr>() || rhs.isa<AffineSymbolExpr>() ||
+           rhs.isa<AffineConstantExpr>()))) {
+      continue;
+    }
+    if (rhs.isa<AffineConstantExpr>()) {
+      // For product/modulo/division expressions whose rhs is a constant, keyB
+      // is -1 for a dimensional lhs and max(numDims, numSymbols) for a symbolic
+      // lhs. NOTE(review): these indices equal those of the plain lhs dim/symbol
+      // term, so both having nonzero coefficients would trip addEntry's assert.
+      if (lhs.isa<AffineDimExpr>()) {
+        lhsPos = lhs.cast<AffineDimExpr>().getPosition();
+        std::pair<unsigned, signed> indexEntry(lhsPos, -1);
+        addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
+                 expr);
+      } else {
+        lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
+        std::pair<unsigned, signed> indexEntry(lhsPos,
+                                               std::max(numDims, numSymbols));
+        addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
+                 expr);
+      }
+    } else if (lhs.isa<AffineDimExpr>()) {
+      // For product/modulo/division expressions having lhs as dimension and rhs
+      // as symbol, terms are ordered by <keyA, keyB>, where keyA is the position
+      // number of the dimension and keyB that of the symbol. NOTE(review): a
+      // dimensional rhs passes the filter above but would fail the
+      // AffineSymbolExpr cast here (and in the symbol-lhs branch below).
+      lhsPos = lhs.cast<AffineDimExpr>().getPosition();
+      rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
+      std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
+      addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
+    } else {
+      // For product/modulo/division expressions having both lhs and rhs as
+      // symbols, keyA is the lhs and keyB the rhs symbol position. NOTE(review):
+      // this matches the index of a dim * symbol term with the same positions
+      // (e.g. d0 * s1 and s0 * s1 both map to <0, 1>).
+      lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
+      rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
+      std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
+      addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
+    }
+    addedToMap[it.index()] = true;
+  }
+
+  // Construct the simplified semi-affine sum of product/division/mod terms in
+  // the sorted order of indices. NOTE(review): the loop below iterates with
+  // std::pair<unsigned, unsigned>, not the map's std::pair<unsigned, signed>.
+  std::sort(indices.begin(), indices.end());
+  for (const std::pair<unsigned, unsigned> index : indices) {
+    assert(indexToExprMap.lookup(index) &&
+           "cannot find key in `indexToExprMap` map");
+    expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
+  }
+
+  // Remaining local identifiers that were not emitted in sorted order above.
+  for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
+       j++) {
+    // Skip locals with a zero coefficient as well as those already emitted
+    // in the sorted order above.
+    if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
+      continue;
+    auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
+    expr = expr + term;
+  }
+
+  // Constant term.
+  int64_t constTerm = flatExprs.back();
+  if (constTerm != 0)
+    expr = expr + constTerm;
+  return expr;
+}
+
 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
                                                      unsigned numSymbols)
     : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
   operandExprStack.reserve(8);
 }
 
+// In pure affine t = expr * c, we multiply each coefficient of lhs with c.
+//
+// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
+// introduce a local variable p (= expr * symbolic_expr), and the affine
+// expression expr * symbolic_expr is added to `localExprs`.
 void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
   assert(operandExprStack.size() >= 2);
-  // This is a pure affine expr; the RHS will be a constant.
-  assert(expr.getRHS().isa<AffineConstantExpr>());
-  // Get the RHS constant.
-  auto rhsConst = operandExprStack.back()[getConstantIndex()];
+  SmallVector<int64_t, 8> rhs = operandExprStack.back();
   operandExprStack.pop_back();
-  // Update the LHS in place instead of pop and push.
-  auto &lhs = operandExprStack.back();
+  SmallVector<int64_t, 8> &lhs = operandExprStack.back();
+
+  // Flatten semi-affine multiplication expressions by introducing a local
+  // variable in place of the product; the affine expression
+  // corresponding to the quantifier is added to `localExprs`.
+  if (!expr.getRHS().isa<AffineConstantExpr>()) {
+    MLIRContext *context = expr.getContext();
+    AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
+                                             localExprs, context);
+    AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
+                                             localExprs, context);
+    addLocalVariableSemiAffine(a * b, lhs, lhs.size());
+    return;
+  }
+
+  // Get the RHS constant.
+  auto rhsConst = rhs[getConstantIndex()];
   for (unsigned i = 0, e = lhs.size(); i < e; i++) {
     lhs[i] *= rhsConst;
   }
@@ -942,13 +1134,32 @@ void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
 // A mod expression "expr mod c" is thus flattened by introducing a new local
 // variable q (= expr floordiv c), such that expr mod c is replaced with
 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
+//
+// In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
+// introduce a local variable m (= expr mod symbolic_expr), and the affine
+// expression expr mod symbolic_expr is added to `localExprs`.
 void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
   assert(operandExprStack.size() >= 2);
-  // This is a pure affine expr; the RHS will be a constant.
-  assert(expr.getRHS().isa<AffineConstantExpr>());
-  auto rhsConst = operandExprStack.back()[getConstantIndex()];
+
+  SmallVector<int64_t, 8> rhs = operandExprStack.back();
   operandExprStack.pop_back();
-  auto &lhs = operandExprStack.back();
+  SmallVector<int64_t, 8> &lhs = operandExprStack.back();
+  MLIRContext *context = expr.getContext();
+
+  // Flatten semi affine modulo expressions by introducing a local
+  // variable in place of the modulo value, and the affine expression
+  // corresponding to the quantifier is added to `localExprs`.
+  if (!expr.getRHS().isa<AffineConstantExpr>()) {
+    AffineExpr dividendExpr = getAffineExprFromFlatForm(
+        lhs, numDims, numSymbols, localExprs, context);
+    AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
+                                                       localExprs, context);
+    AffineExpr modExpr = dividendExpr % divisorExpr;
+    addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
+    return;
+  }
+
+  int64_t rhsConst = rhs[getConstantIndex()];
   // TODO: handle modulo by zero case when this issue is fixed
   // at the other places in the IR.
   assert(rhsConst > 0 && "RHS constant has to be positive");
@@ -979,11 +1190,11 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
   int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
 
   // Construct the AffineExpr form of the floordiv to store in localExprs.
-  MLIRContext *context = expr.getContext();
-  auto dividendExpr = getAffineExprFromFlatForm(
+
+  AffineExpr dividendExpr = getAffineExprFromFlatForm(
       floorDividend, numDims, numSymbols, localExprs, context);
-  auto divisorExpr = getAffineConstantExpr(floorDivisor, context);
-  auto floorDivExpr = dividendExpr.floorDiv(divisorExpr);
+  AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
+  AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
   int loc;
   if ((loc = findLocalId(floorDivExpr)) == -1) {
     addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
@@ -1022,6 +1233,21 @@ void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
   eq[getConstantIndex()] = expr.getValue();
 }
 
+void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
+    AffineExpr expr, SmallVectorImpl<int64_t> &result,
+    unsigned long resultSize) {
+  assert(result.size() == resultSize &&
+         "`result` vector passed is not of correct size");
+  int loc; // Offset of an existing matching local, or -1 if none.
+  if ((loc = findLocalId(expr)) == -1)
+    addLocalIdSemiAffine(expr);
+  std::fill(result.begin(), result.end(), 0); // Result is just the local.
+  if (loc == -1)
+    result[getLocalVarStartIndex() + numLocals - 1] = 1; // Newly added local.
+  else
+    result[getLocalVarStartIndex() + loc] = 1;
+}
+
 // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
 // A floordiv is thus flattened by introducing a new local variable q, and
 // replacing that expression with 'q' while adding the constraints
@@ -1030,18 +1256,38 @@ void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
 //
 // A ceildiv is similarly flattened:
 // t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
+//
+// In case of semi affine division expressions, t = expr floordiv symbolic_expr
+// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
+// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
+// `localExprs`.
 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
                                              bool isCeil) {
   assert(operandExprStack.size() >= 2);
-  assert(expr.getRHS().isa<AffineConstantExpr>());
+
+  MLIRContext *context = expr.getContext();
+  SmallVector<int64_t, 8> rhs = operandExprStack.back();
+  operandExprStack.pop_back();
+  SmallVector<int64_t, 8> &lhs = operandExprStack.back();
+
+  // Flatten semi affine division expressions by introducing a local
+  // variable in place of the quotient, and the affine expression corresponding
+  // to the quantifier is added to `localExprs`.
+  if (!expr.getRHS().isa<AffineConstantExpr>()) {
+    AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
+                                             localExprs, context);
+    AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
+                                             localExprs, context);
+    AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
+    addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
+    return;
+  }
 
   // This is a pure affine expr; the RHS is a positive constant.
-  int64_t rhsConst = operandExprStack.back()[getConstantIndex()];
+  int64_t rhsConst = rhs[getConstantIndex()];
   // TODO: handle division by zero at the same time the issue is
   // fixed at other places.
   assert(rhsConst > 0 && "RHS constant has to be positive");
-  operandExprStack.pop_back();
-  auto &lhs = operandExprStack.back();
 
   // Simplify the floordiv, ceildiv if possible by canceling out the greatest
   // common divisors of the numerator and denominator.
@@ -1063,13 +1309,12 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
   // the ceil/floor expr (simplified up until here). Add an existential
   // quantifier to express its result, i.e., expr1 div expr2 is replaced
   // by a new identifier, q.
-  MLIRContext *context = expr.getContext();
-  auto a =
+  AffineExpr a =
       getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context);
-  auto b = getAffineConstantExpr(divisor, context);
+  AffineExpr b = getAffineConstantExpr(divisor, context);
 
   int loc;
-  auto divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
+  AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
   if ((loc = findLocalId(divExpr)) == -1) {
     if (!isCeil) {
       SmallVector<int64_t, 8> dividend(lhs);
@@ -1099,13 +1344,20 @@ void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
                                                    int64_t divisor,
                                                    AffineExpr localExpr) {
   assert(divisor > 0 && "positive constant divisor expected");
-  for (auto &subExpr : operandExprStack)
+  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
     subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
   localExprs.push_back(localExpr);
   numLocals++;
   // dividend and divisor are not used here; an override of this method uses it.
 }
 
+void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr) {
+  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
+    subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
+  localExprs.push_back(localExpr); // Its column was zero-filled above.
+  ++numLocals;
+}
+
 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
   SmallVectorImpl<AffineExpr>::iterator it;
   if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
@@ -1119,17 +1371,24 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
   // Simplify semi-affine expressions separately.
   if (!expr.isPureAffine())
     expr = simplifySemiAffine(expr);
-  if (!expr.isPureAffine())
-    return expr;
 
   SimpleAffineExprFlattener flattener(numDims, numSymbols);
   flattener.walkPostOrder(expr);
   ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
-  auto simplifiedExpr =
-      getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
-                                flattener.localExprs, expr.getContext());
+  if (!expr.isPureAffine() &&
+      expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
+                                        flattener.localExprs,
+                                        expr.getContext()))
+    return expr;
+  AffineExpr simplifiedExpr =
+      expr.isPureAffine()
+          ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
+                                      flattener.localExprs, expr.getContext())
+          : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
+                                          flattener.localExprs,
+                                          expr.getContext());
+
   flattener.operandExprStack.pop_back();
   assert(flattener.operandExprStack.empty());
-
   return simplifiedExpr;
 }

diff  --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
index 395f873c945c0..837226204f5e9 100644
--- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
@@ -479,3 +479,57 @@ func @test_not_trivially_true_or_false_returning_three_results() -> (index, inde
   }
   return %res#0, %res#1, %res#2 : index, index, index
 }
+
+// -----
+
+// Test simplification of mod expressions.
+// CHECK-DAG:   #[[$MOD:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 * s1 + (s0 - s1) mod s2)>
+// CHECK-DAG:   #[[$SIMPLIFIED_MOD_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 mod (s2 - s0 * s1))>
+// CHECK-DAG:   #[[$MODULO_AND_PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s3 - (-s0 + s3) mod s2)>
+// CHECK-LABEL: func @semiaffine_simplification_mod
+// CHECK-SAME:  (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index)
+func @semiaffine_simplification_mod(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index, index) {
+  %a = affine.apply affine_map<(d0, d1)[s0, s1, s2, s3] -> ((-(d1 * s0 - (s0 - s1) mod s2) + s3) + (d0 * s1 + d1 * s0))>(%arg0, %arg1)[%arg2, %arg3, %arg4, %arg5]
+  %b = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 mod (s0 - s1 * s2 + s3 - s0))>(%arg0)[%arg0, %arg1, %arg2, %arg3]
+  %c = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 + (d0 + s0) mod s2 + s0 * s1 - (d0 + s0) mod s2 - (d0 - s0) mod s2)>(%arg0)[%arg1, %arg2, %arg3]
+  return %a, %b, %c : index, index, index
+}
+// CHECK-NEXT: %[[RESULT0:.*]] = affine.apply #[[$MOD]]()[%[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG0]]]
+// CHECK-NEXT: %[[RESULT1:.*]] = affine.apply #[[$SIMPLIFIED_MOD_RHS]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG0]]]
+// CHECK-NEXT: %[[RESULT2:.*]] = affine.apply #[[$MODULO_AND_PRODUCT]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG0]]]
+// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]], %[[RESULT2]]
+
+// -----
+
+// Test simplification of floordiv and ceildiv expressions.
+// CHECK-DAG:   #[[$SIMPLIFIED_FLOORDIV_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 floordiv (s2 - s0 * s1))>
+// CHECK-DAG:   #[[$FLOORDIV:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 + s3 + (s0 - s1) floordiv s2)>
+// CHECK-DAG:   #[[$SIMPLIFIED_CEILDIV_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 ceildiv (s2 - s0 * s1))>
+// CHECK-LABEL: func @semiaffine_simplification_floordiv_and_ceildiv
+// CHECK-SAME:  (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
+func @semiaffine_simplification_floordiv_and_ceildiv(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %a = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 floordiv (s0 - s1 * s2 + s3 - s0))>(%arg0)[%arg0, %arg1, %arg2, %arg3]
+  %b = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> ((-(d0 * s1 - (s0 - s1) floordiv s2) + s3) + (d0 * s1 + s0))>(%arg0)[%arg1, %arg2, %arg3, %arg4]
+  %c = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 ceildiv (s0 - s1 * s2 + s3 - s0))>(%arg0)[%arg0, %arg1, %arg2, %arg3]
+  return %a, %b, %c : index, index, index
+}
+// CHECK-NEXT: %[[RESULT0:.*]] = affine.apply #[[$SIMPLIFIED_FLOORDIV_RHS]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG0]]]
+// CHECK-NEXT: %[[RESULT1:.*]] = affine.apply #[[$FLOORDIV]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]]
+// CHECK-NEXT: %[[RESULT2:.*]] = affine.apply #[[$SIMPLIFIED_CEILDIV_RHS]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG0]]]
+// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]], %[[RESULT2]]
+
+// -----
+
+// Test simplification of product expressions.
+// CHECK-DAG:   #[[$PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 + (s0 - s1) * s2)>
+// CHECK-DAG:   #[[$SUM_OF_PRODUCTS:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s2 * s0 + s2 + s3 * s0 + s3 * s1 + s3 + s4 * s1 + s4)>
+// CHECK-LABEL: func @semiaffine_simplification_product
+// CHECK-SAME:  (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index)
+func @semiaffine_simplification_product(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index) {
+  %a = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> ((-(s0 - (s0 - s1) * s2) + s3) + (d0 + s0))>(%arg0)[%arg1, %arg2, %arg3, %arg4]
+  %b = affine.apply affine_map<(d0, d1, d2)[s0, s1] -> (d0 + d1 * s1 + d1 + d0 * s0 + d1 * s0 + d2 * s1 + d2)>(%arg0, %arg1, %arg2)[%arg3, %arg4]
+  return %a, %b : index, index
+}
+// CHECK-NEXT: %[[RESULT0:.*]] = affine.apply #[[$PRODUCT]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG0]]]
+// CHECK-NEXT: %[[RESULT1:.*]] = affine.apply #[[$SUM_OF_PRODUCTS]]()[%[[ARG3]], %[[ARG4]], %[[ARG0]], %[[ARG1]], %[[ARG2]]]
+// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]]

diff  --git a/mlir/test/Dialect/SCF/for-loop-peeling.mlir b/mlir/test/Dialect/SCF/for-loop-peeling.mlir
index 2a7ce9650fbfe..e138dcf2bd102 100644
--- a/mlir/test/Dialect/SCF/for-loop-peeling.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-peeling.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -for-loop-peeling -canonicalize -split-input-file | FileCheck %s
 // RUN: mlir-opt %s -for-loop-peeling=skip-partial=false -canonicalize -split-input-file | FileCheck %s -check-prefix=CHECK-NO-SKIP
 
-//  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (s1 - s0) mod s2)>
+//  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (-s0 + s1) mod s2)>
 //  CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0)>
 //      CHECK: func @fully_dynamic_bounds(
 // CHECK-SAME:     %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index


        


More information about the Mlir-commits mailing list