[Mlir-commits] [mlir] [NFC] Simplify the tiling implementation using cloning. (PR #72178)
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 15 04:34:09 PST 2023
================
@@ -496,42 +496,59 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
reductionDims.push_back(idx);
}
- // 1. create the inital tensor value.
+ // 2. create the inital tensor value.
FailureOr<Operation *> identityTensor =
op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
reductionDims);
if (failed(identityTensor))
return b.notifyMatchFailure(op,
"cannot create a tensor of identity value.");
- // 2. Create the nested loops.
+ // 3. Create the nested loops.
SmallVector<OpFoldResult> offsets, sizes;
- SmallVector<scf::ForOp> loops = generateTileLoopNest(
- b, loc, iterationDomain, tileSizesVector, offsets, sizes);
+ SmallVector<scf::ForOp> loops =
+ generateTileLoopNest(b, loc, iterationDomain, tileSizesVector, offsets,
+ sizes, identityTensor.value()->getResults());
+
+ // 4. Generate the tiled implementation within the inner most loop.
+ // 4a. Clone the operation within the loop body.
+ SmallVector<Value> clonedOpDestination =
+ llvm::map_to_vector(identityTensor.value()->getResults(),
+ [](OpResult res) -> Value { return res; });
+ if (!loops.empty()) {
+ b.setInsertionPointToEnd(loops.back().getBody());
+ clonedOpDestination =
+ llvm::map_to_vector(loops.back().getRegionIterArgs(),
+ [](BlockArgument b) -> Value { return b; });
+ }
+ auto clonedOp = cast<PartialReductionOpInterface>(
+ cloneOpAndUpdateDestinationArgs(b, op, clonedOpDestination));
- // 3. Generate the tiled implementation within the inner most loop.
- b.setInsertionPoint(loops.back().getBody()->getTerminator());
- Operation *parallelOp = op.tileToPartialReduction(
- b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDims);
+ // 4b. Tile the cloned operation.
+ Operation *parallelOp = clonedOp.tileToPartialReduction(
+ b, loc, clonedOpDestination, offsets, sizes, reductionDims);
+ // 4c. Delete the cloned operation.
+ b.eraseOp(clonedOp);
- SmallVector<OpFoldResult> resultSizesList;
- for (size_t i = 0; i < offsets.size(); i++)
- resultSizesList.push_back(
+ SmallVector<OpFoldResult> outSizes;
+ for (size_t i = 0; i < offsets.size(); i++) {
+ outSizes.push_back(
tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
+ }
SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
- SmallVector<Value> replacements = yieldTiledValues(
- b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
- resultSizesList, loops);
-
- auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
- auto innerMostLoop = loops.back();
- SmallVector<Value> destinationTensors = llvm::to_vector(dstOp.getDpsInits());
- assert(destinationTensors.size() ==
- innerMostLoop.getRegionIterArgs().size() &&
- "unexpected number of outputs");
- updateDestinationOperandsForTiledOp(b, destinationTensors,
----------------
nicolasvasilache wrote:
++1
https://github.com/llvm/llvm-project/pull/72178
More information about the Mlir-commits mailing list