[Mlir-commits] [mlir] 4b56e2e - [mlir][Analysis][NFC] Remove code duplication around getFlattenedAffineExprs

Matthias Springer llvmlistbot at llvm.org
Wed Aug 11 00:05:52 PDT 2021


Author: Matthias Springer
Date: 2021-08-11T16:02:10+09:00
New Revision: 4b56e2ee1dd46a5c239bd9bbf0408332fba523a9

URL: https://github.com/llvm/llvm-project/commit/4b56e2ee1dd46a5c239bd9bbf0408332fba523a9
DIFF: https://github.com/llvm/llvm-project/commit/4b56e2ee1dd46a5c239bd9bbf0408332fba523a9.diff

LOG: [mlir][Analysis][NFC] Remove code duplication around getFlattenedAffineExprs

Remove code duplication in `addLowerOrUpperBound` and `composeMatchingMap`.

Differential Revision: https://reviews.llvm.org/D107814

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/AffineStructures.h
    mlir/lib/Analysis/AffineStructures.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index 1dc55da9f915..e2f4c10e1078 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -602,6 +602,19 @@ class FlatAffineConstraints {
   /// must already have a corresponding dim/symbol in this constraint system.
   AffineMap computeAlignedMap(AffineMap map, ValueRange operands) const;
 
+  /// Given an affine map that is aligned with this constraint system:
+  /// * Flatten the map.
+  /// * Add newly introduced local columns at the beginning of this constraint
+  ///   system (local column pos 0).
+  /// * Add equalities that define the new local columns to this constraint
+  ///   system.
+  /// * Return the flattened expressions via `flattenedExprs`.
+  ///
+  /// Note: This is a shared helper function of `addLowerOrUpperBound` and
+  ///       `composeMatchingMap`.
+  LogicalResult flattenAlignedMapAndMergeLocals(
+      AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs);
+
   // Eliminates a single identifier at 'position' from equality and inequality
   // constraints. Returns 'success' if the identifier was eliminated, and
   // 'failure' otherwise.

diff  --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index cd5b4e9cfe37..984500e94dbd 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -400,28 +400,10 @@ LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
   assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
 
   std::vector<SmallVector<int64_t, 8>> flatExprs;
-  FlatAffineConstraints localCst;
-  if (failed(getFlattenedAffineExprs(other, &flatExprs, &localCst))) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "composition unimplemented for semi-affine maps\n");
+  if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
     return failure();
-  }
   assert(flatExprs.size() == other.getNumResults());
 
-  // Add localCst information.
-  if (localCst.getNumLocalIds() > 0) {
-    unsigned numLocalIds = getNumLocalIds();
-    // Insert local dims of localCst at the beginning.
-    for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; ++l)
-      addLocalId(0);
-    // Insert local dims of `this` at the end of localCst.
-    for (unsigned l = 0; l < numLocalIds; ++l)
-      localCst.addLocalId(localCst.getNumLocalIds());
-    // Dimensions of localCst and this constraint set match. Append localCst to
-    // this constraint set.
-    append(localCst);
-  }
-
   // Add dimensions corresponding to the map's results.
   for (unsigned t = 0, e = other.getNumResults(); t < e; t++) {
     addDimId(0);
@@ -429,25 +411,24 @@ LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
 
   // We add one equality for each result connecting the result dim of the map to
   // the other identifiers.
-  // For eg: if the expression is 16*i0 + i1, and this is the r^th
+  // E.g.: if the expression is 16*i0 + i1, and this is the r^th
   // iteration/result of the value map, we are adding the equality:
-  //  d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
-  //  add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
+  // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
+  // add two equalities: d_0 - i0 - 1 == 0, d_1 - i0 - 8*i2 == 0.
   for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
     const auto &flatExpr = flatExprs[r];
     assert(flatExpr.size() >= other.getNumInputs() + 1);
 
-    // eqToAdd is the equality corresponding to the flattened affine expression.
     SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
     // Set the coefficient for this result to one.
     eqToAdd[r] = 1;
 
     // Dims and symbols.
     for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
-      // Negate 'eq[r]' since the newly added dimension will be set to this one.
+      // Negate `flatExpr[i]`; the new dim d_r equals this expression.
       eqToAdd[e + i] = -flatExpr[i];
     }
-    // Local vars common to eq and localCst are at the beginning.
+    // Local columns of `eq` are at the beginning.
     unsigned j = getNumDimIds() + getNumSymbolIds();
     unsigned end = flatExpr.size() - 1;
     for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
@@ -1872,27 +1853,14 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
   }
 }
 
-LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
-                                                          AffineMap boundMap,
-                                                          bool eq, bool lower) {
-  assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch");
-  assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
-  assert(pos < getNumDimAndSymbolIds() && "invalid position");
-
-  // Equality follows the logic of lower bound except that we add an equality
-  // instead of an inequality.
-  assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
-  if (eq)
-    lower = true;
-
-  std::vector<SmallVector<int64_t, 8>> flatExprs;
+LogicalResult FlatAffineConstraints::flattenAlignedMapAndMergeLocals(
+    AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) {
   FlatAffineConstraints localCst;
-  if (failed(getFlattenedAffineExprs(boundMap, &flatExprs, &localCst))) {
+  if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) {
     LLVM_DEBUG(llvm::dbgs()
                << "composition unimplemented for semi-affine maps\n");
     return failure();
   }
-  assert(flatExprs.size() == boundMap.getNumResults());
 
   // Add localCst information.
   if (localCst.getNumLocalIds() > 0) {
@@ -1908,6 +1876,27 @@ LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
     append(localCst);
   }
 
+  return success();
+}
+
+LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
+                                                          AffineMap boundMap,
+                                                          bool eq, bool lower) {
+  assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch");
+  assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
+  assert(pos < getNumDimAndSymbolIds() && "invalid position");
+
+  // Equality follows the logic of lower bound except that we add an equality
+  // instead of an inequality.
+  assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
+  if (eq)
+    lower = true;
+
+  std::vector<SmallVector<int64_t, 8>> flatExprs;
+  if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs)))
+    return failure();
+  assert(flatExprs.size() == boundMap.getNumResults());
+
   // Add one (in)equality for each result.
   for (const auto &flatExpr : flatExprs) {
     SmallVector<int64_t> ineq(getNumCols(), 0);
@@ -1921,7 +1910,7 @@ LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
     if (ineq[pos] != 0)
       continue;
     ineq[pos] = lower ? 1 : -1;
-    // Local vars common to eq and localCst are at the beginning.
+    // Local columns of `ineq` are at the beginning.
     unsigned j = getNumDimIds() + getNumSymbolIds();
     unsigned end = flatExpr.size() - 1;
     for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {


        


More information about the Mlir-commits mailing list