[Mlir-commits] [mlir] [mlir][TilingInterface] Use `LoopLikeOpInterface` in tiling using SCF to unify tiling with `scf.for` and `scf.forall`. (PR #77874)
lorenzo chelini
llvmlistbot at llvm.org
Fri Jan 12 07:42:18 PST 2024
================
@@ -464,50 +564,73 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
if (failed(identityTensor))
return b.notifyMatchFailure(op,
"cannot create a tensor of identity value.");
- // 3. Create the nested loops.
- SmallVector<OpFoldResult> offsets, sizes;
- SmallVector<scf::ForOp> loops =
- generateTileLoopNest(b, loc, iterationDomain, tileSizesVector, offsets,
- sizes, identityTensor.value()->getResults());
-
- // 4. Generate the tiled implementation within the inner most loop.
- // 4a. Clone the operation within the loop body.
- SmallVector<Value> clonedOpDestination =
+
+ // 3. Define the callback to use for generating the inner most tile loop body.
+ auto innerTileLoopBodyFn =
+ [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
+ ValueRange regionIterArgs,
+ SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+ SmallVector<SmallVector<OpFoldResult>> &resultSizes)
+ -> FailureOr<TilingResult> {
+ SmallVector<OpFoldResult> offsets, sizes;
+ {
+ int materializedLoopNum = 0;
+ for (auto [tileSize, loopRange] :
+ llvm::zip(tileSizesVector, iterationDomain)) {
+ if (isConstantIntValue(tileSize, 0)) {
+ offsets.push_back(loopRange.offset);
+ sizes.push_back(loopRange.size);
+ continue;
+ }
+ Value iv = ivs[materializedLoopNum++];
+ offsets.push_back(iv);
+ sizes.push_back(
+ getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+ }
+ }
+
+ // 4a. Clone the operation.
+ auto clonedOp = cast<PartialReductionOpInterface>(
+ cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
+
+ // 4b. Tile the cloned operation.
+ Operation *parallelOp = clonedOp.tileToPartialReduction(
+ b, loc, regionIterArgs, offsets, sizes, reductionDims);
+ // 4c. Delete the cloned operation.
+ b.eraseOp(clonedOp);
+
+ // 4d. Compute the offsets and sizes needed to insert the result of the
+ // tiled
+ // value back into destination before yielding the destination.
+ SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
+ resultOffsets.emplace_back(std::move(outOffsets));
+
+ SmallVector<OpFoldResult> outSizes;
+ for (size_t i = 0; i < offsets.size(); i++) {
+ outSizes.push_back(
+ tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
+ }
+ resultSizes.emplace_back(std::move(outSizes));
+ return TilingResult{{parallelOp}, parallelOp->getResults()};
+ };
+
+ // 5. Generate the tiled implementation using the destination tensors.
+ SmallVector<Value> destinationTensors =
llvm::map_to_vector(identityTensor.value()->getResults(),
[](OpResult res) -> Value { return res; });
- if (!loops.empty()) {
- b.setInsertionPointToEnd(loops.back().getBody());
- clonedOpDestination =
- llvm::map_to_vector(loops.back().getRegionIterArgs(),
- [](BlockArgument b) -> Value { return b; });
- }
- auto clonedOp = cast<PartialReductionOpInterface>(
- cloneOpAndUpdateDestinationArgs(b, op, clonedOpDestination));
-
- // 4b. Tile the cloned operation.
- Operation *parallelOp = clonedOp.tileToPartialReduction(
- b, loc, clonedOpDestination, offsets, sizes, reductionDims);
- // 4c. Delete the cloned operation.
- b.eraseOp(clonedOp);
-
- SmallVector<OpFoldResult> outSizes;
- for (size_t i = 0; i < offsets.size(); i++) {
- outSizes.push_back(
- tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
- }
- SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
- SmallVector<OpFoldResult> outStrides(outOffsets.size(), b.getIndexAttr(1));
- SmallVector<Value> yieldedVals;
- auto bbArgs = loops.back().getRegionIterArgs();
- for (auto [result, bbArg] : llvm::zip(parallelOp->getResults(), bbArgs)) {
- Value insert = b.create<tensor::InsertSliceOp>(
- loc, result, bbArg, outOffsets, outSizes, outStrides);
- yieldedVals.push_back(insert);
- }
- b.create<scf::YieldOp>(loc, yieldedVals);
+
+ SmallVector<LoopLikeOpInterface> loops;
+ scf::SCFTilingOptions options;
+ options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
+ FailureOr<TilingResult> tilingResult =
+ generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
+ destinationTensors, innerTileLoopBodyFn, loops);
+ if (failed(tilingResult)) {
----------------
chelini wrote:
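Nit: drop the braces? LLVM style omits braces around an `if` body that is a single statement (assuming the body here is just the early `return`).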
https://github.com/llvm/llvm-project/pull/77874
More information about the Mlir-commits
mailing list