[Mlir-commits] [mlir] [mlir][TilingInterface] Use `LoopLikeOpInterface` in tiling using SCF to unify tiling with `scf.for` and `scf.forall`. (PR #77874)

lorenzo chelini llvmlistbot at llvm.org
Fri Jan 12 07:42:18 PST 2024


================
@@ -288,145 +402,131 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   // skips tiling a particular dimension. This convention is significantly
   // simpler to handle instead of adjusting affine maps to account for missing
   // dimensions.
-  SmallVector<OpFoldResult> tileSizeVector =
+  SmallVector<OpFoldResult> tileSizes =
       options.tileSizeComputationFunction(rewriter, op);
-  if (tileSizeVector.size() < iterationDomain.size()) {
+  if (tileSizes.size() < iterationDomain.size()) {
     auto zero = rewriter.getIndexAttr(0);
-    tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
+    tileSizes.append(numLoops - tileSizes.size(), zero);
   }
 
-  // 3. Find the destination tensors to use for the operation.
-  SmallVector<Value> destinationTensors;
-  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
-                                             destinationTensors))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "unable to create destination tensors");
+  // 3. If there is an interchange specified, permute the iteration domain and
+  // the tile sizes.
+  SmallVector<int64_t> interchangeVector;
+  if (!options.interchangeVector.empty()) {
+    interchangeVector = fillInterchangeVector(options.interchangeVector,
+                                              iterationDomain.size());
   }
-
-  SmallVector<OpFoldResult> offsets, sizes;
-  SmallVector<scf::ForOp> forLoops;
-  {
-    // If there is an interchange specified, permute the iteration domain and
-    // the tile sizes.
-    SmallVector<int64_t> interchangeVector;
-    if (!options.interchangeVector.empty()) {
-      interchangeVector = fillInterchangeVector(options.interchangeVector,
-                                                iterationDomain.size());
+  if (!interchangeVector.empty()) {
+    if (!isPermutationVector(interchangeVector)) {
+      return rewriter.notifyMatchFailure(
+          op, "invalid intechange vector, not a permutation of the entire "
+              "iteration space");
     }
-    if (!interchangeVector.empty()) {
-      if (!isPermutationVector(interchangeVector)) {
-        return rewriter.notifyMatchFailure(
-            op, "invalid intechange vector, not a permutation of the entire "
-                "iteration space");
-      }
 
-      applyPermutationToVector(iterationDomain, interchangeVector);
-      applyPermutationToVector(tileSizeVector, interchangeVector);
+    applyPermutationToVector(iterationDomain, interchangeVector);
+    applyPermutationToVector(tileSizes, interchangeVector);
+  }
+
+  // 4. Define the lambda function used later to generate the body of the
+  // innermost tiled loop.
+  auto innerTileLoopBodyFn =
+      [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
+          ValueRange regionIterArgs,
+          SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+          SmallVector<SmallVector<OpFoldResult>> &resultSizes)
+      -> FailureOr<TilingResult> {
+    // 4a. Compute the `offsets` and `sizes` to use for tiling.
+    SmallVector<OpFoldResult> offsets, sizes;
+    {
+      int materializedLoopNum = 0;
+      for (auto [tileSize, loopRange] : llvm::zip(tileSizes, iterationDomain)) {
+        if (isConstantIntValue(tileSize, 0)) {
+          offsets.push_back(loopRange.offset);
+          sizes.push_back(loopRange.size);
+          continue;
+        }
+        Value iv = ivs[materializedLoopNum++];
+        offsets.push_back(iv);
+        sizes.push_back(
+            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+      }
     }
 
-    // 4. Materialize an empty loop nest that iterates over the tiles. These
-    // loops for now do not return any values even if the original operation has
-    // results.
-    forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
-                                    tileSizeVector, offsets, sizes,
-                                    destinationTensors);
-
+    // 4b. If interchange was provided, apply inverse of the interchange
+    //     to get back the offsets/sizes in the order to be specified.
     if (!interchangeVector.empty()) {
       auto inversePermutation = invertPermutationVector(interchangeVector);
       applyPermutationToVector(offsets, inversePermutation);
       applyPermutationToVector(sizes, inversePermutation);
     }
-  }
 
-  LLVM_DEBUG({
-    if (!forLoops.empty()) {
-      llvm::dbgs() << "LoopNest shell :\n";
-      forLoops.front().dump();
-      llvm::dbgs() << "\n";
-    }
-  });
+    // 5. Generate the tiled implementation within the inner most loop.
 
-  // 5. Generate the tiled implementation within the inner most loop.
-  SmallVector<Value> clonedOpDestination = destinationTensors;
-  if (!forLoops.empty()) {
-    rewriter.setInsertionPointToEnd(forLoops.back().getBody());
-    clonedOpDestination =
-        llvm::map_to_vector(forLoops.back().getRegionIterArgs(),
-                            [](BlockArgument b) -> Value { return b; });
-  }
+    // 5a. Clone the operation within the loop body.
+    auto clonedOp = cast<TilingInterface>(
+        cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
 
-  // 5a. Clone the operation within the loop body.
-  auto clonedOp = cast<TilingInterface>(
-      cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination));
+    // 5b. Early return cloned op if tiling is not happening. We can not return
+    // the original op because it could lead to
+    // `rewriter.replaceOp(op, op->getResults())` and user would get crash.
+    if (llvm::all_of(tileSizes, isZeroIndex)) {
----------------
chelini wrote:

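nit: drop the braces around this `if` (LLVM style omits braces when the body is a single simple statement, which appears to be the case here).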

https://github.com/llvm/llvm-project/pull/77874


More information about the Mlir-commits mailing list