[Mlir-commits] [mlir] 43a95a5 - [MLIR] Introduce full/partial tile separation using if/else

Uday Bondhugula llvmlistbot at llvm.org
Fri Mar 27 18:41:58 PDT 2020


Author: Uday Bondhugula
Date: 2020-03-28T06:58:35+05:30
New Revision: 43a95a543fbb1ed4b3903e88ce291444d4970f5a

URL: https://github.com/llvm/llvm-project/commit/43a95a543fbb1ed4b3903e88ce291444d4970f5a
DIFF: https://github.com/llvm/llvm-project/commit/43a95a543fbb1ed4b3903e88ce291444d4970f5a.diff

LOG: [MLIR] Introduce full/partial tile separation using if/else

This patch introduces a utility to separate full tiles from partial
tiles when tiling affine loop nests where trip counts are unknown or
where tile sizes don't divide trip counts. A conditional guard is
generated to separate out the full tile (with constant trip count loops)
into the then block of an 'affine.if' and the partial tile to the else
block. The separation allows the 'then' block (which has constant trip
count loops) to be optimized better subsequently: e.g., for
unroll-and-jam, register tiling, vectorization without leading to
cleanup code, or to offload to accelerators. Among techniques from the
literature, the if/else based separation leads to the most compact
cleanup code for multi-dimensional cases (because a single version is
used to model all partial tiles).

INPUT

  affine.for %i0 = 0 to %M {
    affine.for %i1 = 0 to %N {
      "foo"() : () -> ()
    }
  }

OUTPUT AFTER TILING W/O SEPARATION

  map0 = affine_map<(d0) -> (d0)>
  map1 = affine_map<(d0)[s0] -> (d0 + 32, s0)>

  affine.for %arg2 = 0 to %M step 32 {
    affine.for %arg3 = 0 to %N step 32 {
      affine.for %arg4 = #map0(%arg2) to min #map1(%arg2)[%M] {
        affine.for %arg5 = #map0(%arg3) to min #map1(%arg3)[%N] {
          "foo"() : () -> ()
        }
      }
    }
  }

OUTPUT AFTER TILING WITH SEPARATION

  map0 = affine_map<(d0) -> (d0)>
  map1 = affine_map<(d0) -> (d0 + 32)>
  map2 = affine_map<(d0)[s0] -> (d0 + 32, s0)>

  #set0 = affine_set<(d0, d1)[s0, s1] : (-d0 + s0 - 32 >= 0, -d1 + s1 - 32 >= 0)>

  affine.for %arg2 = 0 to %M step 32 {
    affine.for %arg3 = 0 to %N step 32 {
      affine.if #set0(%arg2, %arg3)[%M, %N] {
        // Full tile.
        affine.for %arg4 = #map0(%arg2) to #map1(%arg2) {
          affine.for %arg5 = #map0(%arg3) to #map1(%arg3) {
            "foo"() : () -> ()
          }
        }
      } else {
        // Partial tile.
        affine.for %arg4 = #map0(%arg2) to min #map2(%arg2)[%M] {
          affine.for %arg5 = #map0(%arg3) to min #map2(%arg3)[%N] {
            "foo"() : () -> ()
          }
        }
      }
    }
  }

The separation is tested via a cmd line flag on the loop tiling pass.
The utility itself allows one to pass in any band of contiguously nested
loops, and can be used by other transforms/utilities. The current
implementation works for hyperrectangular loop nests.

Signed-off-by: Uday Bondhugula <uday at polymagelabs.com>

Differential Revision: https://reviews.llvm.org/D76700

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/AffineStructures.h
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/include/mlir/Transforms/LoopUtils.h
    mlir/lib/Analysis/AffineStructures.cpp
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
    mlir/lib/Transforms/Utils/LoopUtils.cpp
    mlir/test/Dialect/Affine/loop-tiling.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index 91435d69147d..cd3195230834 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -210,6 +210,15 @@ class FlatAffineConstraints {
                                      ValueRange operands, bool eq,
                                      bool lower = true);
 
+  /// Returns the bound for the identifier at `pos` from the inequality at
+  /// `ineqPos` as a 1-d affine value map (affine map + operands). The returned
+  /// affine value map can either be a lower bound or an upper bound depending
+  /// on the sign of atIneq(ineqPos, pos). Asserts if the row at `ineqPos` does
+  /// not involve the `pos`th identifier.
+  void getIneqAsAffineValueMap(unsigned pos, unsigned ineqPos,
+                               AffineValueMap &vmap,
+                               MLIRContext *context) const;
+
   /// Returns the constraint system as an integer set. Returns a null integer
   /// set if the system has no constraints, or if an integer set couldn't be
   /// constructed as a result of a local variable's explicit representation not
@@ -452,15 +461,17 @@ class FlatAffineConstraints {
   /// affine expressions involving only the symbolic identifiers. `lb` and
   /// `ub` (along with the `boundFloorDivisor`) are set to represent the lower
   /// and upper bound associated with the constant difference: `lb`, `ub` have
-  /// the coefficients, and boundFloorDivisor, their divisor.
+  /// the coefficients, and boundFloorDivisor, their divisor. `minLbPos` and
+  /// `minUbPos` if non-null are set to the position of the constant lower bound
+  /// and upper bound respectively (to the same if they are from an equality).
   /// Ex: if the lower bound is [(s0 + s2 - 1) floordiv 32] for a system with
-  /// three symbolic identifiers, *lb = [1, 0, 1], boundDivisor = 32. See
-  /// comments at function definition for examples.
-  Optional<int64_t>
-  getConstantBoundOnDimSize(unsigned pos,
-                            SmallVectorImpl<int64_t> *lb = nullptr,
-                            int64_t *boundFloorDivisor = nullptr,
-                            SmallVectorImpl<int64_t> *ub = nullptr) const;
+  /// three symbolic identifiers, *lb = [1, 0, 1], lbDivisor = 32. See comments
+  /// at function definition for examples.
+  Optional<int64_t> getConstantBoundOnDimSize(
+      unsigned pos, SmallVectorImpl<int64_t> *lb = nullptr,
+      int64_t *boundFloorDivisor = nullptr,
+      SmallVectorImpl<int64_t> *ub = nullptr, unsigned *minLbPos = nullptr,
+      unsigned *minUbPos = nullptr) const;
 
   /// Returns the constant lower bound for the pos^th identifier if there is
   /// one; None otherwise.
@@ -482,6 +493,20 @@ class FlatAffineConstraints {
                         unsigned symStartPos, ArrayRef<AffineExpr> localExprs,
                         MLIRContext *context) const;
 
+  /// Gather positions of all lower and upper bounds of the identifier at `pos`,
+  /// and optionally any equalities on it. In addition, the bounds are to be
+  /// independent of identifiers in position range [`offset`, `offset` + `num`).
+  void
+  getLowerAndUpperBoundIndices(unsigned pos,
+                               SmallVectorImpl<unsigned> *lbIndices,
+                               SmallVectorImpl<unsigned> *ubIndices,
+                               SmallVectorImpl<unsigned> *eqIndices = nullptr,
+                               unsigned offset = 0, unsigned num = 0) const;
+
+  /// Removes constraints that are independent of (i.e., do not have a
+  /// coefficient for) identifiers in the range [pos, pos + num).
+  void removeIndependentConstraints(unsigned pos, unsigned num);
+
   /// Returns true if the set can be trivially detected as being
   /// hyper-rectangular on the specified contiguous set of identifiers.
   bool isHyperRectangular(unsigned pos, unsigned num) const;

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 6994b5f17661..6d0148b4d0ad 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -275,6 +275,16 @@ def AffineIfOp : Affine_Op<"if",
     /// list of AffineIf is not resizable.
     void setConditional(IntegerSet set, ValueRange operands);
 
+    Block *getThenBlock() {
+      assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
+      return &thenRegion().front();
+    }
+
+    Block *getElseBlock() {
+      assert(!elseRegion().empty() && "Empty 'else' region.");
+      return &elseRegion().front();
+    }
+
     OpBuilder getThenBodyBuilder() {
       assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
       Block &body = thenRegion().front();
@@ -401,7 +411,7 @@ def AffineParallelOp : Affine_Op<"parallel", [ImplicitAffineTerminator]> {
 
     /// Get ranges as constants, may fail in dynamic case.
     Optional<SmallVector<int64_t, 8>> getConstantRanges();
-    
+
     Block *getBody();
     OpBuilder getBodyBuilder();
     void setSteps(ArrayRef<int64_t> newSteps);

diff  --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h
index 861a2821ee00..c0d323f790fb 100644
--- a/mlir/include/mlir/Transforms/LoopUtils.h
+++ b/mlir/include/mlir/Transforms/LoopUtils.h
@@ -24,6 +24,7 @@ class AffineForOp;
 class FuncOp;
 class OpBuilder;
 class Value;
+class ValueRange;
 struct MemRefRegion;
 
 namespace loop {
@@ -90,10 +91,12 @@ LogicalResult affineForOpBodySkew(AffineForOp forOp, ArrayRef<uint64_t> shifts,
                                   bool unrollPrologueEpilogue = false);
 
 /// Tiles the specified band of perfectly nested loops creating tile-space loops
-/// and intra-tile loops. A band is a contiguous set of loops.
+/// and intra-tile loops. A band is a contiguous set of loops. `tiledNest` when
+/// non-null is set to the loops of the tiled nest from outermost to innermost.
 LLVM_NODISCARD
 LogicalResult tileCodeGen(MutableArrayRef<AffineForOp> band,
-                          ArrayRef<unsigned> tileSizes);
+                          ArrayRef<unsigned> tileSizes,
+                          SmallVectorImpl<AffineForOp> *tiledNest = nullptr);
 
 /// Performs loop interchange on 'forOpA' and 'forOpB'. Requires that 'forOpA'
 /// and 'forOpB' are part of a perfectly nested sequence of loops.
@@ -271,6 +274,29 @@ void mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef<Value> processorId,
 void gatherLoops(FuncOp func,
                  std::vector<SmallVector<AffineForOp, 2>> &depthToLoops);
 
+/// Creates an AffineForOp while ensuring that the lower and upper bounds are
+/// canonicalized, i.e., unused and duplicate operands are removed, and any
+/// constant operands propagated/folded in.
+AffineForOp createCanonicalizedAffineForOp(OpBuilder b, Location loc,
+                                           ValueRange lbOperands,
+                                           AffineMap lbMap,
+                                           ValueRange ubOperands,
+                                           AffineMap ubMap, int64_t step = 1);
+
+/// Separates full tiles from partial tiles for a perfect nest `nest` by
+/// generating a conditional guard that selects between the full tile version
+/// and the partial tile version using an AffineIfOp. The original loop nest
+/// is replaced by this guarded two version form.
+///
+///    affine.if (cond)
+///      // full_tile
+///    else
+///      // partial tile
+///
+LogicalResult
+separateFullTiles(MutableArrayRef<AffineForOp> nest,
+                  SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);
+
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_LOOP_UTILS_H

diff  --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index c652aad19035..947087c549e2 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1200,38 +1200,58 @@ static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
   return false;
 }
 
-/// Gather all lower and upper bounds of the identifier at `pos`. The bounds are
-/// to be independent of [offset, offset + num) identifiers.
-static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst,
-                                         unsigned pos,
-                                         SmallVectorImpl<unsigned> *lbIndices,
-                                         SmallVectorImpl<unsigned> *ubIndices,
-                                         unsigned offset = 0,
-                                         unsigned num = 0) {
-  assert(pos < cst.getNumIds() && "invalid position");
+/// Gather all lower and upper bounds of the identifier at `pos`, and
+/// optionally any equalities on it. In addition, the bounds are to be
+/// independent of identifiers in position range [`offset`, `offset` + `num`).
+void FlatAffineConstraints::getLowerAndUpperBoundIndices(
+    unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
+    SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
+    unsigned offset, unsigned num) const {
+  assert(pos < getNumIds() && "invalid position");
+  assert(offset + num < getNumCols() && "invalid range");
 
-  // Gather all lower bounds and upper bounds of the variable. Since the
-  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
-  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
-  for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
-    // The bounds are to be independent of [offset, offset + num) columns.
+  // Checks for a constraint that has a non-zero coeff for the identifiers in
+  // the position range [offset, offset + num) while ignoring `pos`.
+  auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
     unsigned c, f;
+    auto cst = isEq ? getEquality(r) : getInequality(r);
     for (c = offset, f = offset + num; c < f; ++c) {
       if (c == pos)
         continue;
-      if (cst.atIneq(r, c) != 0)
+      if (cst[c] != 0)
         break;
     }
-    if (c < f)
+    return c < f;
+  };
+
+  // Gather all lower bounds and upper bounds of the variable. Since the
+  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
+  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
+  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+    // The bounds are to be independent of [offset, offset + num) columns.
+    if (containsConstraintDependentOnRange(r, /*isEq=*/false))
       continue;
-    if (cst.atIneq(r, pos) >= 1) {
+    if (atIneq(r, pos) >= 1) {
       // Lower bound.
       lbIndices->push_back(r);
-    } else if (cst.atIneq(r, pos) <= -1) {
+    } else if (atIneq(r, pos) <= -1) {
       // Upper bound.
       ubIndices->push_back(r);
     }
   }
+
+  // An equality is both a lower and upper bound. Record any equalities
+  // involving the pos^th identifier.
+  if (!eqIndices)
+    return;
+
+  for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
+    if (atEq(r, pos) == 0)
+      continue;
+    if (containsConstraintDependentOnRange(r, /*isEq=*/true))
+      continue;
+    eqIndices->push_back(r);
+  }
 }
 
 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
@@ -1247,7 +1267,7 @@ static bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
   assert(pos < cst.getNumIds() && "invalid position");
 
   SmallVector<unsigned, 4> lbIndices, ubIndices;
-  getLowerAndUpperBoundIndices(cst, pos, &lbIndices, &ubIndices);
+  cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices);
 
   // Check if any lower bound, upper bound pair is of the form:
   // divisor * id >=  expr - (divisor - 1)    <-- Lower bound for 'id'
@@ -1376,7 +1396,7 @@ std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
          "incorrect local exprs count");
 
   SmallVector<unsigned, 4> lbIndices, ubIndices;
-  getLowerAndUpperBoundIndices(*this, pos + offset, &lbIndices, &ubIndices);
+  getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices);
 
   /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
   auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
@@ -1872,9 +1892,23 @@ void FlatAffineConstraints::removeEquality(unsigned pos) {
   std::copy(equalities.begin() + inputIndex,
             equalities.begin() + inputIndex + numElemsToCopy,
             equalities.begin() + outputIndex);
+  assert(equalities.size() >= numReservedCols);
   equalities.resize(equalities.size() - numReservedCols);
 }
 
+void FlatAffineConstraints::removeInequality(unsigned pos) {
+  unsigned numInequalities = getNumInequalities();
+  assert(pos < numInequalities && "invalid position");
+  unsigned outputIndex = pos * numReservedCols; // Start of row `pos`.
+  unsigned inputIndex = (pos + 1) * numReservedCols; // Start of row `pos + 1`.
+  unsigned numElemsToCopy = (numInequalities - pos - 1) * numReservedCols;
+  std::copy(inequalities.begin() + inputIndex, // Shift later rows up by one.
+            inequalities.begin() + inputIndex + numElemsToCopy,
+            inequalities.begin() + outputIndex);
+  assert(inequalities.size() >= numReservedCols);
+  inequalities.resize(inequalities.size() - numReservedCols); // Drop last row.
+}
+
 /// Finds an equality that equates the specified identifier to a constant.
 /// Returns the position of the equality row. If 'symbolic' is set to true,
 /// symbols are also treated like a constant, i.e., an affine function of the
@@ -1951,14 +1985,22 @@ void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) {
 //       ceil(s0 - 7 / 8) = floor(s0 / 8)).
 Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
     unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
-    SmallVectorImpl<int64_t> *ub) const {
+    SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
+    unsigned *minUbPos) const {
   assert(pos < getNumDimIds() && "Invalid identifier position");
-  assert(getNumLocalIds() == 0);
 
   // Find an equality for 'pos'^th identifier that equates it to some function
   // of the symbolic identifiers (+ constant).
   int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
   if (eqPos != -1) {
+    auto eq = getEquality(eqPos);
+    // If the equality involves a local var, punt for now.
+    // TODO: this can be handled in the future by using the explicit
+    // representation of the local vars.
+    if (!std::all_of(eq.begin() + getNumDimAndSymbolIds(), eq.end() - 1,
+                     [](int64_t coeff) { return coeff == 0; }))
+      return None;
+
     // This identifier can only take a single value.
     if (lb) {
       // Set lb to that symbolic value.
@@ -1979,6 +2021,10 @@ Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
              "both lb and divisor or none should be provided");
       *boundFloorDivisor = 1;
     }
+    if (minLbPos)
+      *minLbPos = eqPos;
+    if (minUbPos)
+      *minUbPos = eqPos;
     return 1;
   }
 
@@ -1999,8 +2045,8 @@ Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
   // the bounds can only involve symbolic (and local) identifiers. Since the
   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
-  getLowerAndUpperBoundIndices(*this, pos, &lbIndices, &ubIndices,
-                               /*offset=*/0,
+  getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
+                               /*eqIndices=*/nullptr, /*offset=*/0,
                                /*num=*/getNumDimIds());
 
   Optional<int64_t> minDiff = None;
@@ -2054,6 +2100,10 @@ Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
     // the constant term for the lower bound.
     (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
   }
+  if (minLbPos)
+    *minLbPos = minLbPosition;
+  if (minUbPos)
+    *minUbPos = minUbPosition;
   return minDiff;
 }
 
@@ -2726,6 +2776,51 @@ static LogicalResult computeLocalVars(const FlatAffineConstraints &cst,
       llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
 }
 
+void FlatAffineConstraints::getIneqAsAffineValueMap(
+    unsigned pos, unsigned ineqPos, AffineValueMap &vmap,
+    MLIRContext *context) const {
+  unsigned numDims = getNumDimIds();
+  unsigned numSyms = getNumSymbolIds();
+  assert(pos < numDims && "invalid position");
+  assert(ineqPos < getNumInequalities() && "invalid inequality position");
+  assert(atIneq(ineqPos, pos) != 0 && "inequality does not involve `pos`");
+
+  // Get expressions for local vars; each one needs an explicit representation.
+  SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
+  if (failed(computeLocalVars(*this, memo, context)))
+    assert(false &&
+           "one or more local exprs do not have an explicit representation");
+  auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
+
+  // Compute the flattened lower/upper bound expression for this inequality.
+  ArrayRef<int64_t> inequality = getInequality(ineqPos);
+  SmallVector<int64_t, 8> bound;
+  bound.reserve(getNumCols() - 1);
+  // Copy every column other than the coefficient at `pos`.
+  bound.append(inequality.begin(), inequality.begin() + pos);
+  bound.append(inequality.begin() + pos + 1, inequality.end());
+
+  if (inequality[pos] > 0)
+    // Lower bound: negate the remaining terms.
+    std::transform(bound.begin(), bound.end(), bound.begin(),
+                   std::negate<int64_t>());
+  else
+    // Upper bound (which is exclusive): add one to the constant term.
+    bound.back() += 1;
+
+  // Convert to AffineExpr (tree) form; the dim at `pos` has been dropped.
+  auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms,
+                                             localExprs, context);
+
+  // Get the values to bind to this affine expr (dims and symbols minus `pos`).
+  SmallVector<Value, 4> operands;
+  getIdValues(0, pos, &operands);
+  SmallVector<Value, 4> trailingOperands;
+  getIdValues(pos + 1, getNumDimAndSymbolIds(), &trailingOperands);
+  operands.append(trailingOperands.begin(), trailingOperands.end());
+  vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands);
+}
+
 /// Returns true if the pos^th column is all zero for both inequalities and
 /// equalities..
 static bool isColZero(const FlatAffineConstraints &cst, unsigned pos) {
@@ -2739,7 +2834,7 @@ IntegerSet FlatAffineConstraints::getAsIntegerSet(MLIRContext *context) const {
     // Return universal set (always true): 0 == 0.
     return IntegerSet::get(getNumDimIds(), getNumSymbolIds(),
                            getAffineConstantExpr(/*constant=*/0, context),
-                           true);
+                           /*eqFlags=*/true);
 
   // Construct local references.
   SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
@@ -2778,3 +2873,52 @@ IntegerSet FlatAffineConstraints::getAsIntegerSet(MLIRContext *context) const {
                                               numSyms, localExprs, context));
   return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
 }
+
+/// Appends to `nbIneqIndices` and `nbEqIndices` the positions of inequalities
+/// and equalities with all-zero coefficients for identifiers [pos, pos + num).
+static void getIndependentConstraints(const FlatAffineConstraints &cst,
+                                      unsigned pos, unsigned num,
+                                      SmallVectorImpl<unsigned> &nbIneqIndices,
+                                      SmallVectorImpl<unsigned> &nbEqIndices) {
+  assert(pos < cst.getNumIds() && "invalid start position");
+  assert(pos + num <= cst.getNumIds() && "invalid limit");
+
+  for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
+    // Independent iff every coefficient in columns [pos, pos + num) is zero.
+    unsigned c;
+    for (c = pos; c < pos + num; ++c) {
+      if (cst.atIneq(r, c) != 0)
+        break;
+    }
+    if (c == pos + num)
+      nbIneqIndices.push_back(r);
+  }
+
+  for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
+    // Independent iff every coefficient in columns [pos, pos + num) is zero.
+    unsigned c;
+    for (c = pos; c < pos + num; ++c) {
+      if (cst.atEq(r, c) != 0)
+        break;
+    }
+    if (c == pos + num)
+      nbEqIndices.push_back(r);
+  }
+}
+
+void FlatAffineConstraints::removeIndependentConstraints(unsigned pos,
+                                                         unsigned num) {
+  assert(pos + num <= getNumIds() && "invalid range");
+
+  // Collect the constraints that do not involve identifiers [pos, pos + num).
+  SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
+  getIndependentConstraints(*this, pos, num, nbIneqIndices, nbEqIndices);
+
+  // Iterate in reverse so that indices don't have to be updated.
+  // TODO: This method can be made more efficient (because removal of each
+  // inequality leads to much shifting/copying in the underlying buffer).
+  for (auto nbIndex : llvm::reverse(nbIneqIndices))
+    removeInequality(nbIndex);
+  for (auto nbIndex : llvm::reverse(nbEqIndices))
+    removeEquality(nbIndex);
+}

diff  --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 2b14b6a11086..bef227e83564 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -1011,14 +1011,12 @@ bool mlir::isLoopParallel(AffineForOp forOp) {
 
 IntegerSet mlir::simplifyIntegerSet(IntegerSet set) {
   FlatAffineConstraints fac(set);
-  MLIRContext *context = set.getContext();
   if (fac.isEmpty())
     return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
-                                   context);
+                                   set.getContext());
   fac.removeTrivialRedundancy();
 
-  auto simplifiedSet = fac.getAsIntegerSet(context);
+  auto simplifiedSet = fac.getAsIntegerSet(set.getContext());
   assert(simplifiedSet && "guaranteed to succeed while roundtripping");
-
   return simplifiedSet;
 }

diff  --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
index 3f08315170c8..7568098530f7 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
@@ -15,7 +15,9 @@
 #include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/LoopUtils.h"
@@ -33,6 +35,12 @@ static llvm::cl::opt<unsigned long long>
                    llvm::cl::desc("Set size of cache to tile for in KiB"),
                    llvm::cl::cat(clOptionsCategory));
 
+// Separate full and partial tiles.
+static llvm::cl::opt<bool>
+    clSeparate("affine-tile-separate",
+               llvm::cl::desc("Separate full and partial tiles"),
+               llvm::cl::cat(clOptionsCategory));
+
 // Tile size to use for all loops (overrides -tile-sizes if provided).
 static llvm::cl::opt<unsigned>
     clTileSize("affine-tile-size",
@@ -176,11 +184,12 @@ constructTiledIndexSetHyperRect(MutableArrayRef<AffineForOp> origLoops,
 /// and intra-tile loops. A band is a contiguous set of loops.
 //  TODO(bondhugula): handle non hyper-rectangular spaces.
 LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
-                                ArrayRef<unsigned> tileSizes) {
+                                ArrayRef<unsigned> tileSizes,
+                                SmallVectorImpl<AffineForOp> *tiledNest) {
+  // Check if the supplied for op's are all successively nested.
   assert(!band.empty() && "no loops in band");
   assert(band.size() == tileSizes.size() && "Too few/many tile sizes");
 
-  // Check if the supplied for op's are all successively nested.
   for (unsigned i = 1, e = band.size(); i < e; i++)
     assert(band[i].getParentOp() == band[i - 1] && "not a perfect nest / band");
 
@@ -248,6 +257,9 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
   // Erase the old loop nest.
   rootAffineForOp.erase();
 
+  if (tiledNest)
+    *tiledNest = std::move(tiledLoops);
+
   return success();
 }
 
@@ -393,8 +405,16 @@ void LoopTiling::runOnFunction() {
         diag << tSize << ' ';
       diag << "]\n";
     }
-    if (failed(tileCodeGen(band, tileSizes)))
+    SmallVector<AffineForOp, 6> tiledNest;
+    if (failed(tileCodeGen(band, tileSizes, &tiledNest)))
       return signalPassFailure();
+
+    // Separate full and partial tiles.
+    if (clSeparate) {
+      auto intraTileLoops =
+          MutableArrayRef<AffineForOp>(tiledNest).drop_front(band.size());
+      separateFullTiles(intraTileLoops);
+    }
   }
 }
 

diff  --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index d5976d0278ed..986f523ccdd4 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -17,10 +17,12 @@
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Function.h"
+#include "mlir/IR/IntegerSet.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/DenseMap.h"
@@ -2068,3 +2070,213 @@ void mlir::gatherLoops(FuncOp func,
     depthToLoops.pop_back();
   }
 }
+
+// TODO: if necessary, this can be extended to also compose in any
+// affine.applys, fold to constant if all result dimensions of the map are
+// constant (canonicalizeMapAndOperands below already does this for single
+// result bound maps), and use simplifyMap to perform algebraic simplification.
+AffineForOp mlir::createCanonicalizedAffineForOp(
+    OpBuilder b, Location loc, ValueRange lbOperands, AffineMap lbMap,
+    ValueRange ubOperands, AffineMap ubMap, int64_t step) {
+  SmallVector<Value, 4> lowerOperands(lbOperands);
+  SmallVector<Value, 4> upperOperands(ubOperands);
+
+  fullyComposeAffineMapAndOperands(&lbMap, &lowerOperands);
+  canonicalizeMapAndOperands(&lbMap, &lowerOperands);
+  fullyComposeAffineMapAndOperands(&ubMap, &upperOperands);
+  canonicalizeMapAndOperands(&ubMap, &upperOperands);
+
+  return b.create<AffineForOp>(loc, lowerOperands, lbMap, upperOperands, ubMap,
+                               step);
+}
+
+/// Creates an AffineIfOp that encodes the conditional to choose between
+/// the constant trip count version and an unknown trip count version of this
+/// nest of loops. This is used to separate partial and full tiles if `loops`
+/// has the intra-tile loops. The affine.if op is inserted at the builder
+/// insertion point of `b`.
+static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops,
+                                            OpBuilder b) {
+  if (loops.empty())
+    return nullptr;
+
+  auto *context = loops[0].getContext();
+
+  FlatAffineConstraints cst;
+  getIndexSet(loops, &cst);
+
+  // Remove constraints that are independent of these loop IVs.
+  cst.removeIndependentConstraints(/*pos=*/0, /*num=*/loops.size());
+
+  // Construct the constraint set representing the guard for full tiles. The
+  // lower bound (and upper bound) corresponding to the full tile should be
+  // larger (and resp. smaller) than any other lower (or upper bound).
+  SmallVector<int64_t, 8> fullTileLb, fullTileUb;
+  for (auto loop : loops) {
+    // TODO: Non-unit stride is not an issue to generalize to.
+    assert(loop.getStep() == 1 && "point loop step expected to be one");
+    // Mark everything symbols for the purpose of finding a constant diff pair.
+    cst.setDimSymbolSeparation(/*newSymbolCount=*/cst.getNumDimAndSymbolIds() -
+                               1);
+    unsigned fullTileLbPos, fullTileUbPos;
+    if (!cst.getConstantBoundOnDimSize(0, /*lb=*/nullptr,
+                                       /*lbFloorDivisor=*/nullptr,
+                                       /*ub=*/nullptr, &fullTileLbPos,
+                                       &fullTileUbPos)) {
+      LLVM_DEBUG(llvm::dbgs() << "Can't get constant diff pair for a loop\n");
+      return nullptr;
+    }
+
+    SmallVector<unsigned, 4> lbIndices, ubIndices;
+    cst.getLowerAndUpperBoundIndices(/*pos=*/0, &lbIndices, &ubIndices);
+
+    auto fLb = cst.getInequality(fullTileLbPos);
+    auto fUb = cst.getInequality(fullTileUbPos);
+    fullTileLb.assign(fLb.begin(), fLb.end());
+    fullTileUb.assign(fUb.begin(), fUb.end());
+
+    // Full tile lower bound should be >= than any other lower bound.
+    for (auto lbIndex : lbIndices)
+      for (unsigned i = 0, e = cst.getNumCols(); i < e; ++i)
+        cst.atIneq(lbIndex, i) = fullTileLb[i] - cst.atIneq(lbIndex, i);
+
+    // Full tile upper bound should be <= any other upper bound.
+    for (auto ubIndex : ubIndices)
+      for (unsigned i = 0, e = cst.getNumCols(); i < e; ++i)
+        cst.atIneq(ubIndex, i) -= fullTileUb[i];
+
+    cst.removeId(0);
+  }
+
+  // The previous step leads to all zeros for the full tile lb and ub position
+  // itself; remove those and any other duplicates / trivial redundancies.
+  cst.removeTrivialRedundancy();
+
+  // Turn everything into dims conservatively since we earlier turned all
+  // trailing ids past point loop IV into symbols. Some of these could be outer
+  // loop IVs; we'll canonicalize anyway.
+  cst.setDimSymbolSeparation(0);
+
+  IntegerSet ifCondSet = cst.getAsIntegerSet(context);
+  // ifCondSet can be null if cst was empty -- this can happen if all loops
+  // in the nest have constant trip counts.
+  if (!ifCondSet)
+    return nullptr;
+
+  SmallVector<Value, 4> setOperands;
+  cst.getIdValues(0, cst.getNumDimAndSymbolIds(), &setOperands);
+  canonicalizeSetAndOperands(&ifCondSet, &setOperands);
+  return b.create<AffineIfOp>(loops[0].getLoc(), ifCondSet, setOperands,
+                              /*withElseRegion=*/true);
+}
+
+/// Create the full tile loop nest (along with its body).
+static LogicalResult
+createFullTiles(MutableArrayRef<AffineForOp> inputNest,
+                SmallVectorImpl<AffineForOp> &fullTileLoops, OpBuilder b) {
+  fullTileLoops.reserve(inputNest.size());
+
+  // For each loop in the original nest identify a lower/upper bound pair such
+  // that their difference is a constant.
+  FlatAffineConstraints cst;
+  for (auto loop : inputNest) {
+    // TODO: straightforward to generalize to a non-unit stride.
+    if (loop.getStep() != 1) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "[tile separation] non-unit stride not implemented\n");
+      return failure();
+    }
+    getIndexSet({loop}, &cst);
+    // We will mark everything other than this loop IV as symbol for getting a
+    // pair of <lb, ub> with a constant difference.
+    cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - 1);
+    unsigned lbPos, ubPos;
+    if (!cst.getConstantBoundOnDimSize(/*pos=*/0, /*lb=*/nullptr,
+                                       /*lbDivisor=*/nullptr, /*ub=*/nullptr,
+                                       &lbPos, &ubPos) ||
+        lbPos == ubPos) {
+      LLVM_DEBUG(llvm::dbgs() << "[tile separation] Can't get constant diff / "
+                                 "equalities not yet handled\n");
+      return failure();
+    }
+
+    // Set all identifiers as dimensions uniformly since some of those marked as
+    // symbols above could be outer loop IVs (corresponding tile space IVs).
+    cst.setDimSymbolSeparation(/*newSymbolCount=*/0);
+
+    AffineValueMap lbVmap, ubVmap;
+    cst.getIneqAsAffineValueMap(/*pos=*/0, lbPos, lbVmap, b.getContext());
+    cst.getIneqAsAffineValueMap(/*pos=*/0, ubPos, ubVmap, b.getContext());
+    AffineForOp fullTileLoop = createCanonicalizedAffineForOp(
+        b, loop.getLoc(), lbVmap.getOperands(), lbVmap.getAffineMap(),
+        ubVmap.getOperands(), ubVmap.getAffineMap());
+    b = fullTileLoop.getBodyBuilder();
+    fullTileLoops.push_back(fullTileLoop);
+  }
+
+  // Add the body for the full tile loop nest.
+  BlockAndValueMapping operandMap;
+  for (auto loopEn : llvm::enumerate(inputNest))
+    operandMap.map(loopEn.value().getInductionVar(),
+                   fullTileLoops[loopEn.index()].getInductionVar());
+  b = fullTileLoops.back().getBodyBuilder();
+  for (auto &op : inputNest.back().getBody()->without_terminator())
+    b.clone(op, operandMap);
+  return success();
+}
+
+LogicalResult
+mlir::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
+                        SmallVectorImpl<AffineForOp> *fullTileNest) {
+  if (inputNest.empty())
+    return success();
+
+  auto firstLoop = inputNest[0];
+
+  // Each successive for op has to be nested in the other.
+  auto prevLoop = firstLoop;
+  for (auto loop : inputNest.drop_front(1)) {
+    assert(loop.getParentOp() == prevLoop && "input not contiguously nested");
+    prevLoop = loop;
+  }
+
+  // Create the full tile loop nest.
+  SmallVector<AffineForOp, 4> fullTileLoops;
+  OpBuilder b(firstLoop);
+  if (failed(createFullTiles(inputNest, fullTileLoops, b))) {
+    if (!fullTileLoops.empty())
+      fullTileLoops.front().erase();
+    return failure();
+  }
+
+  // Create and insert the version select right before the root of the nest.
+  b = OpBuilder(firstLoop);
+  AffineIfOp ifOp = createSeparationCondition(inputNest, b);
+  if (!ifOp) {
+    fullTileLoops.front().erase();
+    LLVM_DEBUG(llvm::dbgs() << "All tiles are full tiles, or failure creating "
+                               "separation condition\n");
+    return failure();
+  }
+
+  // Move the full tile into the then block.
+  Block *thenBlock = ifOp.getThenBlock();
+  AffineForOp outermostFullTileLoop = fullTileLoops[0];
+  thenBlock->getOperations().splice(
+      std::prev(thenBlock->end()),
+      outermostFullTileLoop.getOperation()->getBlock()->getOperations(),
+      Block::iterator(outermostFullTileLoop));
+
+  // Move the partial tile into the else block. The partial tile is the same as
+  // the original loop nest.
+  Block *elseBlock = ifOp.getElseBlock();
+  elseBlock->getOperations().splice(
+      std::prev(elseBlock->end()),
+      firstLoop.getOperation()->getBlock()->getOperations(),
+      Block::iterator(firstLoop));
+
+  if (fullTileNest)
+    *fullTileNest = std::move(fullTileLoops);
+
+  return success();
+}

diff --git a/mlir/test/Dialect/Affine/loop-tiling.mlir b/mlir/test/Dialect/Affine/loop-tiling.mlir
index 2f8223b37eeb..029c42ae0434 100644
--- a/mlir/test/Dialect/Affine/loop-tiling.mlir
+++ b/mlir/test/Dialect/Affine/loop-tiling.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -split-input-file  -affine-loop-tile -affine-tile-size=32 | FileCheck %s
 // RUN: mlir-opt %s -split-input-file -affine-loop-tile -affine-tile-cache-size=512 | FileCheck %s --check-prefix=MODEL
+// RUN: mlir-opt %s -split-input-file -affine-loop-tile -affine-tile-size=32 -affine-tile-separate | FileCheck %s --check-prefix=SEPARATE
 
 // -----
 
@@ -169,6 +170,8 @@ func @tile_with_loop_upper_bounds_in_two_symbols(%arg0: memref<?xf32>, %limit: i
 
 // -----
 
+// CHECK-LABEL: func @trip_count_1
+// SEPARATE-LABEL: func @trip_count_1
 func @trip_count_1(%arg0: memref<196608x1xf32>, %arg1: memref<196608x1xf32>)
     -> memref<196608x1xf32> {
   affine.for %i1 = 0 to 196608 {
@@ -177,8 +180,65 @@ func @trip_count_1(%arg0: memref<196608x1xf32>, %arg1: memref<196608x1xf32>)
       affine.store %4, %arg1[%i1, %i3] : memref<196608x1xf32>
     }
   }
+  // CHECK: affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<196608x1xf32>
   return %arg1 : memref<196608x1xf32>
 }
+// SEPARATE: return
 
-// CHECK: %{{.*}} = affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<196608x1xf32>
+// -----
+
+func @separate_full_tile_2d(%M : index, %N : index) {
+  affine.for %i0 = 0 to %M {
+    affine.for %i1 = 0 to %N {
+      "foo"() : () -> ()
+    }
+  }
+  return
+}
+
+// SEPARATE-DAG: #[[SEP_COND:.*]] = affine_set<(d0, d1)[s0, s1] : (-d0 + s0 - 32 >= 0, -d1 + s1 - 32 >= 0)>
+// SEPARATE-DAG: #[[LB:.*]] = affine_map<(d0) -> (d0)>
+// SEPARATE-DAG: #[[FULL_TILE_UB:.*]] = affine_map<(d0) -> (d0 + 32)>
+// SEPARATE-DAG: #[[PART_TILE_UB:.*]] = affine_map<(d0)[s0] -> (d0 + 32, s0)>
+
+// SEPARATE:       affine.for %arg2
+// SEPARATE-NEXT:    affine.for %arg3
+// SEPARATE-NEXT:      affine.if #[[SEP_COND]](%arg2, %arg3)[%arg0, %arg1] {
+// SEPARATE-NEXT:        affine.for %arg4 = #[[LB]](%arg2) to #[[FULL_TILE_UB]](%arg2) {
+// SEPARATE-NEXT:          affine.for %arg5 = #[[LB]](%arg3) to #[[FULL_TILE_UB]](%arg3) {
+// SEPARATE-NEXT:           "foo"
+// SEPARATE-NEXT:          }
+// SEPARATE-NEXT:        }
+// SEPARATE-NEXT:      } else {
+// SEPARATE-NEXT:        affine.for %arg4 = #[[LB]](%arg2) to min #[[PART_TILE_UB]](%arg2)[%arg0] {
+// SEPARATE-NEXT:          affine.for %arg5 = #[[LB]](%arg3) to min #[[PART_TILE_UB]](%arg3)[%arg1] {
+// SEPARATE-NEXT:           "foo"
+// SEPARATE-NEXT:          }
+// SEPARATE-NEXT:        }
+// SEPARATE-NEXT:      }
+// SEPARATE-NEXT:    }
+// SEPARATE-NEXT:  }
+// SEPARATE-NEXT:  return
+
+// -----
+
+func @separate_full_tile_1d_max_min(%M : index, %N : index, %P : index, %Q : index) {
+  affine.for %i0 = max affine_map<(d0, d1) -> (d0, d1)>  (%M, %N) to min affine_map< (d0, d1) -> (d0, d1)> (%P, %Q) {
+  }
+  return
+}
 
+// SEPARATE-DAG: #[[SEP_COND:.*]] = affine_set<(d0)[s0, s1] : (-d0 + s0 - 32 >= 0, -d0 + s1 - 32 >= 0)>
+// SEPARATE-DAG: #[[TILE_LB:.*]] = affine_map<(d0) -> (d0)>
+// SEPARATE-DAG: #[[FULL_TILE_UB:.*]] = affine_map<(d0) -> (d0 + 32)>
+// SEPARATE-DAG: #[[PARTIAL_TILE_UB:.*]] = affine_map<(d0, d1, d2) -> (d2 + 32, d0, d1)>
+
+// SEPARATE:         affine.for %arg4
+// SEPARATE-NEXT:      affine.if #[[SEP_COND]](%arg4)[%arg2, %arg3] {
+// SEPARATE-NEXT:        affine.for %arg5 = #[[TILE_LB]](%arg4) to #[[FULL_TILE_UB]](%arg4) {
+// SEPARATE-NEXT:        }
+// SEPARATE-NEXT:      } else {
+// SEPARATE-NEXT:        affine.for %arg5 = #[[TILE_LB]](%arg4) to min #[[PARTIAL_TILE_UB]](%arg2, %arg3, %arg4) {
+// SEPARATE-NEXT:        }
+// SEPARATE-NEXT:      }
+// SEPARATE-NEXT:    }


        


More information about the Mlir-commits mailing list