[Mlir-commits] [mlir] [mlir][TilingInterface] Use `LoopLikeOpInterface` in tiling using SCF to unify tiling with `scf.for` and `scf.forall`. (PR #77874)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 12 13:08:17 PST 2024


================
@@ -288,145 +402,131 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   // skips tiling a particular dimension. This convention is significantly
   // simpler to handle instead of adjusting affine maps to account for missing
   // dimensions.
-  SmallVector<OpFoldResult> tileSizeVector =
+  SmallVector<OpFoldResult> tileSizes =
       options.tileSizeComputationFunction(rewriter, op);
-  if (tileSizeVector.size() < iterationDomain.size()) {
+  if (tileSizes.size() < iterationDomain.size()) {
     auto zero = rewriter.getIndexAttr(0);
-    tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
+    tileSizes.append(numLoops - tileSizes.size(), zero);
   }
 
-  // 3. Find the destination tensors to use for the operation.
-  SmallVector<Value> destinationTensors;
-  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
-                                             destinationTensors))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "unable to create destination tensors");
+  // 3. If there is an interchange specified, permute the iteration domain and
+  // the tile sizes.
+  SmallVector<int64_t> interchangeVector;
+  if (!options.interchangeVector.empty()) {
+    interchangeVector = fillInterchangeVector(options.interchangeVector,
+                                              iterationDomain.size());
   }
-
-  SmallVector<OpFoldResult> offsets, sizes;
-  SmallVector<scf::ForOp> forLoops;
-  {
-    // If there is an interchange specified, permute the iteration domain and
-    // the tile sizes.
-    SmallVector<int64_t> interchangeVector;
-    if (!options.interchangeVector.empty()) {
-      interchangeVector = fillInterchangeVector(options.interchangeVector,
-                                                iterationDomain.size());
+  if (!interchangeVector.empty()) {
+    if (!isPermutationVector(interchangeVector)) {
+      return rewriter.notifyMatchFailure(
+          op, "invalid intechange vector, not a permutation of the entire "
+              "iteration space");
     }
-    if (!interchangeVector.empty()) {
-      if (!isPermutationVector(interchangeVector)) {
-        return rewriter.notifyMatchFailure(
-            op, "invalid intechange vector, not a permutation of the entire "
-                "iteration space");
-      }
 
-      applyPermutationToVector(iterationDomain, interchangeVector);
-      applyPermutationToVector(tileSizeVector, interchangeVector);
+    applyPermutationToVector(iterationDomain, interchangeVector);
+    applyPermutationToVector(tileSizes, interchangeVector);
+  }
+
+  // 4. Define the lambda function used later to generate the body of the
+  // innermost tiled loop.
+  auto innerTileLoopBodyFn =
+      [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
+          ValueRange regionIterArgs,
+          SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+          SmallVector<SmallVector<OpFoldResult>> &resultSizes)
+      -> FailureOr<TilingResult> {
+    // 4a. Compute the `offsets` and `sizes` to use for tiling.
+    SmallVector<OpFoldResult> offsets, sizes;
+    {
+      int materializedLoopNum = 0;
+      for (auto [tileSize, loopRange] : llvm::zip(tileSizes, iterationDomain)) {
+        if (isConstantIntValue(tileSize, 0)) {
+          offsets.push_back(loopRange.offset);
+          sizes.push_back(loopRange.size);
+          continue;
+        }
+        Value iv = ivs[materializedLoopNum++];
+        offsets.push_back(iv);
+        sizes.push_back(
+            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+      }
     }
 
-    // 4. Materialize an empty loop nest that iterates over the tiles. These
-    // loops for now do not return any values even if the original operation has
-    // results.
-    forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
-                                    tileSizeVector, offsets, sizes,
-                                    destinationTensors);
-
+    // 4b. If interchange was provided, apply inverse of the interchange
+    //     to get back the offsets/sizes in the order to be specified.
     if (!interchangeVector.empty()) {
       auto inversePermutation = invertPermutationVector(interchangeVector);
       applyPermutationToVector(offsets, inversePermutation);
       applyPermutationToVector(sizes, inversePermutation);
     }
-  }
 
-  LLVM_DEBUG({
-    if (!forLoops.empty()) {
-      llvm::dbgs() << "LoopNest shell :\n";
-      forLoops.front().dump();
-      llvm::dbgs() << "\n";
-    }
-  });
+    // 5. Generate the tiled implementation within the inner most loop.
 
-  // 5. Generate the tiled implementation within the inner most loop.
-  SmallVector<Value> clonedOpDestination = destinationTensors;
-  if (!forLoops.empty()) {
-    rewriter.setInsertionPointToEnd(forLoops.back().getBody());
-    clonedOpDestination =
-        llvm::map_to_vector(forLoops.back().getRegionIterArgs(),
-                            [](BlockArgument b) -> Value { return b; });
-  }
+    // 5a. Clone the operation within the loop body.
+    auto clonedOp = cast<TilingInterface>(
+        cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
 
-  // 5a. Clone the operation within the loop body.
-  auto clonedOp = cast<TilingInterface>(
-      cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination));
+    // 5b. Early return cloned op if tiling is not happening. We can not return
+    // the original op because it could lead to
+    // `rewriter.replaceOp(op, op->getResults())` and user would get crash.
----------------
MaheshRavishankar wrote:

Currently this is structured so that the replacement happens in the caller. That pushes the burden onto the caller to check every time whether the same operation was returned.

https://github.com/llvm/llvm-project/pull/77874


More information about the Mlir-commits mailing list