[Mlir-commits] [mlir] [mlir][TilingInterface] Use `LoopLikeOpInterface` in tiling using SCF to unify tiling with `scf.for` and `scf.forall`. (PR #77874)

lorenzo chelini llvmlistbot at llvm.org
Fri Jan 12 07:42:18 PST 2024


================
@@ -464,50 +564,73 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   if (failed(identityTensor))
     return b.notifyMatchFailure(op,
                                 "cannot create a tensor of identity value.");
-  // 3. Create the nested loops.
-  SmallVector<OpFoldResult> offsets, sizes;
-  SmallVector<scf::ForOp> loops =
-      generateTileLoopNest(b, loc, iterationDomain, tileSizesVector, offsets,
-                           sizes, identityTensor.value()->getResults());
-
-  // 4. Generate the tiled implementation within the inner most loop.
-  // 4a. Clone the operation within the loop body.
-  SmallVector<Value> clonedOpDestination =
+
+  // 3. Define the callback to use for generating the inner most tile loop body.
+  auto innerTileLoopBodyFn =
+      [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
+          ValueRange regionIterArgs,
+          SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+          SmallVector<SmallVector<OpFoldResult>> &resultSizes)
+      -> FailureOr<TilingResult> {
+    SmallVector<OpFoldResult> offsets, sizes;
+    {
+      int materializedLoopNum = 0;
+      for (auto [tileSize, loopRange] :
+           llvm::zip(tileSizesVector, iterationDomain)) {
+        if (isConstantIntValue(tileSize, 0)) {
+          offsets.push_back(loopRange.offset);
+          sizes.push_back(loopRange.size);
+          continue;
+        }
+        Value iv = ivs[materializedLoopNum++];
+        offsets.push_back(iv);
+        sizes.push_back(
+            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+      }
+    }
+
+    // 4a. Clone the operation.
+    auto clonedOp = cast<PartialReductionOpInterface>(
+        cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
+
+    // 4b. Tile the cloned operation.
+    Operation *parallelOp = clonedOp.tileToPartialReduction(
+        b, loc, regionIterArgs, offsets, sizes, reductionDims);
+    // 4c. Delete the cloned operation.
+    b.eraseOp(clonedOp);
+
+    // 4d. Compute the offsets and sizes needed to insert the result of the
+    // tiled
+    //     value back into destination before yielding the destination.
+    SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
+    resultOffsets.emplace_back(std::move(outOffsets));
+
+    SmallVector<OpFoldResult> outSizes;
+    for (size_t i = 0; i < offsets.size(); i++) {
+      outSizes.push_back(
+          tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
+    }
+    resultSizes.emplace_back(std::move(outSizes));
+    return TilingResult{{parallelOp}, parallelOp->getResults()};
+  };
+
+  // 5. Generate the tiled implementation using the destination tensors.
+  SmallVector<Value> destinationTensors =
       llvm::map_to_vector(identityTensor.value()->getResults(),
                           [](OpResult res) -> Value { return res; });
-  if (!loops.empty()) {
-    b.setInsertionPointToEnd(loops.back().getBody());
-    clonedOpDestination =
-        llvm::map_to_vector(loops.back().getRegionIterArgs(),
-                            [](BlockArgument b) -> Value { return b; });
-  }
-  auto clonedOp = cast<PartialReductionOpInterface>(
-      cloneOpAndUpdateDestinationArgs(b, op, clonedOpDestination));
-
-  // 4b. Tile the cloned operation.
-  Operation *parallelOp = clonedOp.tileToPartialReduction(
-      b, loc, clonedOpDestination, offsets, sizes, reductionDims);
-  // 4c. Delete the cloned operation.
-  b.eraseOp(clonedOp);
-
-  SmallVector<OpFoldResult> outSizes;
-  for (size_t i = 0; i < offsets.size(); i++) {
-    outSizes.push_back(
-        tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
-  }
-  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-  SmallVector<OpFoldResult> outStrides(outOffsets.size(), b.getIndexAttr(1));
-  SmallVector<Value> yieldedVals;
-  auto bbArgs = loops.back().getRegionIterArgs();
-  for (auto [result, bbArg] : llvm::zip(parallelOp->getResults(), bbArgs)) {
-    Value insert = b.create<tensor::InsertSliceOp>(
-        loc, result, bbArg, outOffsets, outSizes, outStrides);
-    yieldedVals.push_back(insert);
-  }
-  b.create<scf::YieldOp>(loc, yieldedVals);
+
+  SmallVector<LoopLikeOpInterface> loops;
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
+  FailureOr<TilingResult> tilingResult =
+      generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
+                       destinationTensors, innerTileLoopBodyFn, loops);
+  if (failed(tilingResult)) {
----------------
chelini wrote:

Drop the braces? The LLVM coding standards omit braces around a simple single-statement `if` body.

https://github.com/llvm/llvm-project/pull/77874


More information about the Mlir-commits mailing list