[Mlir-commits] [mlir] [mlir][TilingInterface] Use `LoopLikeOpInterface` in tiling using SCF to unify tiling with `scf.for` and `scf.forall`. (PR #77874)
lorenzo chelini
llvmlistbot at llvm.org
Fri Jan 12 07:42:18 PST 2024
================
@@ -288,145 +402,131 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
// skips tiling a particular dimension. This convention is significantly
// simpler to handle instead of adjusting affine maps to account for missing
// dimensions.
- SmallVector<OpFoldResult> tileSizeVector =
+ SmallVector<OpFoldResult> tileSizes =
options.tileSizeComputationFunction(rewriter, op);
- if (tileSizeVector.size() < iterationDomain.size()) {
+ if (tileSizes.size() < iterationDomain.size()) {
auto zero = rewriter.getIndexAttr(0);
- tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
+ tileSizes.append(numLoops - tileSizes.size(), zero);
}
- // 3. Find the destination tensors to use for the operation.
- SmallVector<Value> destinationTensors;
- if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
- destinationTensors))) {
- return rewriter.notifyMatchFailure(op,
- "unable to create destination tensors");
+ // 3. If there is an interchange specified, permute the iteration domain and
+ // the tile sizes.
+ SmallVector<int64_t> interchangeVector;
+ if (!options.interchangeVector.empty()) {
+ interchangeVector = fillInterchangeVector(options.interchangeVector,
+ iterationDomain.size());
}
-
- SmallVector<OpFoldResult> offsets, sizes;
- SmallVector<scf::ForOp> forLoops;
- {
- // If there is an interchange specified, permute the iteration domain and
- // the tile sizes.
- SmallVector<int64_t> interchangeVector;
- if (!options.interchangeVector.empty()) {
- interchangeVector = fillInterchangeVector(options.interchangeVector,
- iterationDomain.size());
+ if (!interchangeVector.empty()) {
+ if (!isPermutationVector(interchangeVector)) {
+ return rewriter.notifyMatchFailure(
+ op, "invalid intechange vector, not a permutation of the entire "
+ "iteration space");
}
- if (!interchangeVector.empty()) {
- if (!isPermutationVector(interchangeVector)) {
- return rewriter.notifyMatchFailure(
- op, "invalid intechange vector, not a permutation of the entire "
- "iteration space");
- }
- applyPermutationToVector(iterationDomain, interchangeVector);
- applyPermutationToVector(tileSizeVector, interchangeVector);
+ applyPermutationToVector(iterationDomain, interchangeVector);
+ applyPermutationToVector(tileSizes, interchangeVector);
+ }
+
+ // 4. Define the lambda function used later to generate the body of the
+ // innermost tiled loop.
+ auto innerTileLoopBodyFn =
+ [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
+ ValueRange regionIterArgs,
+ SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+ SmallVector<SmallVector<OpFoldResult>> &resultSizes)
+ -> FailureOr<TilingResult> {
+ // 4a. Compute the `offsets` and `sizes` to use for tiling.
+ SmallVector<OpFoldResult> offsets, sizes;
+ {
+ int materializedLoopNum = 0;
+ for (auto [tileSize, loopRange] : llvm::zip(tileSizes, iterationDomain)) {
+ if (isConstantIntValue(tileSize, 0)) {
+ offsets.push_back(loopRange.offset);
+ sizes.push_back(loopRange.size);
+ continue;
+ }
+ Value iv = ivs[materializedLoopNum++];
+ offsets.push_back(iv);
+ sizes.push_back(
+ getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+ }
}
- // 4. Materialize an empty loop nest that iterates over the tiles. These
- // loops for now do not return any values even if the original operation has
- // results.
- forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
- tileSizeVector, offsets, sizes,
- destinationTensors);
-
+ // 4b. If interchange was provided, apply inverse of the interchange
+ // to get back the offsets/sizes in the order to be specified.
if (!interchangeVector.empty()) {
auto inversePermutation = invertPermutationVector(interchangeVector);
applyPermutationToVector(offsets, inversePermutation);
applyPermutationToVector(sizes, inversePermutation);
}
- }
- LLVM_DEBUG({
- if (!forLoops.empty()) {
- llvm::dbgs() << "LoopNest shell :\n";
- forLoops.front().dump();
- llvm::dbgs() << "\n";
- }
- });
+ // 5. Generate the tiled implementation within the inner most loop.
- // 5. Generate the tiled implementation within the inner most loop.
- SmallVector<Value> clonedOpDestination = destinationTensors;
- if (!forLoops.empty()) {
- rewriter.setInsertionPointToEnd(forLoops.back().getBody());
- clonedOpDestination =
- llvm::map_to_vector(forLoops.back().getRegionIterArgs(),
- [](BlockArgument b) -> Value { return b; });
- }
+ // 5a. Clone the operation within the loop body.
+ auto clonedOp = cast<TilingInterface>(
+ cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
- // 5a. Clone the operation within the loop body.
- auto clonedOp = cast<TilingInterface>(
- cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination));
+ // 5b. Early return cloned op if tiling is not happening. We can not return
+ // the original op because it could lead to
+ // `rewriter.replaceOp(op, op->getResults())` and user would get crash.
+ if (llvm::all_of(tileSizes, isZeroIndex)) {
----------------
chelini wrote:
nit: drop braces.
https://github.com/llvm/llvm-project/pull/77874
More information about the Mlir-commits mailing list