[Mlir-commits] [mlir] [mlir][SCF] Unify tileUsingFor and tileReductionUsingFor implementation (PR #120115)
Kunwar Grover
llvmlistbot at llvm.org
Mon Dec 16 09:21:44 PST 2024
https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/120115
This patch unifies the tiling implementation for tileUsingFor and tileReductionUsingFor. This is done by passing an additional option to SCFTilingOptions, allowing callers to set how reduction dimensions should be tiled. Currently, there are 3 different options for reduction tiling: FullReduction (old tileUsingFor), PartialReductionOuterReduction (old tileReductionUsingFor) and PartialReductionOuterParallel (linalg::tileReductionUsingForall, this isn't implemented in this patch).
The patch makes tileReductionUsingFor use the tileUsingFor implementation with the new reduction tiling options.
There are no test changes because the implementation was doing almost exactly the same thing. This was also tested in IREE (which uses both of these APIs heavily), and no test changes were needed there either.
>From facdf23493e348a2c3a03083a24b4b7a6c072a9c Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 13 Dec 2024 16:58:10 +0000
Subject: [PATCH] [mlir][SCF] Merge tileUsingFor and tileReductionUsingFor
implementation
---
.../SCF/Transforms/TileUsingInterface.h | 57 ++-
.../TransformOps/LinalgTransformOps.cpp | 13 +-
.../SCF/Transforms/TileUsingInterface.cpp | 458 ++++++++++--------
.../TestTilingInterfaceTransformOps.cpp | 3 +-
4 files changed, 306 insertions(+), 225 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 9f5f9f3fca97ad..d2cddfe00ac78e 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,6 +85,36 @@ struct SCFTilingOptions {
return *this;
}
+ /// Specify how reduction dimensions should be tiled.
+ ///
+ /// Tiling can be thought of as splitting a dimension into 2 and materializing
+ /// the outer dimension as a loop:
+ ///
+ /// op[original] -> op[original / x, x] -> loop[original] { op[x] }
+ ///
+ /// For parallel dimensions, the split can only happen in one way, with both
+ /// dimensions being parallel. For reduction dimensions however, there is a
+ /// choice in how we split the reduction dimension. This enum exposes this
+ /// choice.
+ enum class ReductionTilingStrategy {
+ /// [reduction] -> [reduction1, reduction2]
+ /// -> loop[reduction1] { [reduction2] }
+ /// Each tile reduces directly into the destination tensors of the op, so
+ /// no merge step is needed after the loop.
+ FullReduction,
+ /// [reduction] -> [reduction1, parallel2]
+ /// -> loop[reduction1] { [parallel2] }; merge[reduction1]
+ /// Requires the op to implement PartialReductionOpInterface: the loop
+ /// carries partial results, which are merged after the loop.
+ PartialReductionOuterReduction,
+ /// [reduction] -> [parallel1, reduction2]
+ /// -> loop[parallel1] { [reduction2] }; merge[parallel1]
+ /// Not implemented yet; the SCF tiling driver rejects it as an unhandled
+ /// strategy.
+ PartialReductionOuterParallel
+ };
+ /// The strategy used to tile reduction dimensions. Defaults to
+ /// `FullReduction`.
+ ReductionTilingStrategy reductionStrategy =
+ ReductionTilingStrategy::FullReduction;
+ /// Sets `reductionStrategy` and returns `*this` so calls can be chained.
+ SCFTilingOptions &
+ setReductionTilingStrategy(ReductionTilingStrategy strategy) {
+ reductionStrategy = strategy;
+ return *this;
+ }
+
/// Specify mapping of loops to devices. This is only respected when the loop
/// constructs support such a mapping (like `scf.forall`). Will be ignored
/// when using loop constructs that dont support such a mapping (like
@@ -102,11 +132,16 @@ struct SCFTilingResult {
/// matter except the last op. The replacements are expected to be the results
/// of the last op.
SmallVector<Operation *> tiledOps;
+ /// The initial destination values passed to the tiled operations.
+ SmallVector<Value> initialValues;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<LoopLikeOpInterface> loops;
- /// Values to use as replacements for the untiled op. Is the same size as the
- /// number of results of the untiled op.
- SmallVector<Value> replacements;
+ /// The result generated by the loop nest in tiling, may hold partial results,
+ /// which need to be merged to match the computation of the untiled operation.
+ /// `mergeResult` contains the operations used to perform this merge from
+ /// partial results and the values that can be used as replacements of
+ /// the untiled operation.
+ MergeResult mergeResult;
/// Slices generated after tiling that can be used for fusing with the tiled
/// producer.
SmallVector<Operation *> generatedSlices;
@@ -300,20 +335,6 @@ tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
FailureOr<SmallVector<scf::ForOp>>
lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
-/// Transformation information returned after reduction tiling.
-struct SCFReductionTilingResult {
- /// The partial reduction tiled op generated.
- SmallVector<Operation *> parallelTiledOps;
- /// The final reduction operation merging all the partial reductions.
- SmallVector<Operation *> mergeOps;
- /// Initial values used for reduction.
- SmallVector<Value> initialValues;
- /// The loop operations that iterate over the tiles.
- SmallVector<LoopLikeOpInterface> loops;
- /// The replacements to use for the results of the tiled operation.
- SmallVector<Value> replacements;
-};
-
/// Method to tile a reduction and generate a parallel op within a serial loop.
/// Each of the partial reductions are calculated in parallel. Then after the
/// loop all the partial reduction are merged into a final reduction.
@@ -338,7 +359,7 @@ struct SCFReductionTilingResult {
/// %6 = linalg.generic %1 ["parallel", "reduction"]
/// : tensor<7x4xf32> -> tensor<7xf32>
/// ```
+///
+/// This function does not replace `op`. Callers are expected to replace it
+/// with `mergeResult.replacements` of the returned result.
-FailureOr<scf::SCFReductionTilingResult>
+FailureOr<scf::SCFTilingResult>
tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
ArrayRef<OpFoldResult> tileSize);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8839faf4cafb2d..66a3947e0f91fc 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2224,7 +2224,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
return emitDefaultDefiniteFailure(target);
if (target->getNumResults())
- rewriter.replaceOp(target, maybeTilingResult->replacements);
+ rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
else
rewriter.eraseOp(target);
@@ -2631,17 +2631,18 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
- FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
+ FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
if (failed(result))
return emitDefaultSilenceableFailure(target);
+ rewriter.replaceOp(target, result->mergeResult.replacements);
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
- for (auto parallelTiledOp : result->parallelTiledOps)
+ for (auto parallelTiledOp : result->tiledOps)
results.push_back(parallelTiledOp);
- for (auto mergeOp : result->mergeOps)
+ for (auto mergeOp : result->mergeResult.mergeOps)
results.push_back(mergeOp);
results.push_back(result->loops.front());
return DiagnosedSilenceableFailure::success();
@@ -3065,7 +3066,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
if (failed(maybeTilingResult))
return DiagnosedSilenceableFailure::definiteFailure();
- rewriter.replaceOp(op, maybeTilingResult->replacements);
+ rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
tiled.append(maybeTilingResult->tiledOps);
for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
@@ -3304,7 +3305,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
if (failed(maybeTilingResult))
return transformOp.emitDefaultSilenceableFailure(tileableOp);
- rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
+ rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
tilingResult = *maybeTilingResult;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 6a4a6b43933806..8ece9fb259ddd6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -570,6 +570,146 @@ static LogicalResult generateLoopNest(
return rewriter.notifyMatchFailure(loc, "unhandled loop type");
}
+/// Returns the indices of the loops of `op` whose iterator type is
+/// `utils::IteratorType::reduction`.
+// TODO: PartialReductionOpInterface should really query TilingInterface
+// itself and find reduction dimensions.
+static SmallVector<int> collectReductionDims(TilingInterface op) {
+  SmallVector<int> reductionDims;
+  for (auto [idx, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes())) {
+    if (iteratorType == utils::IteratorType::reduction)
+      reductionDims.push_back(idx);
+  }
+  return reductionDims;
+}
+
+/// Casts `op` to PartialReductionOpInterface. If `op` does not implement that
+/// interface, which the `PartialReductionOuterReduction` strategy requires,
+/// this notifies a match failure and returns failure.
+static FailureOr<PartialReductionOpInterface>
+castToPartialReductionOp(RewriterBase &rewriter, TilingInterface op) {
+  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!redOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only "
+            "supported for operations "
+            "implementing PartialReductionOpInterface");
+  }
+  return redOp;
+}
+
+/// Creates the tensors used as the initial values of the tiling loop nest,
+/// according to `options.reductionStrategy`:
+///  - FullReduction: the destination tensors of `op`.
+///  - PartialReductionOuterReduction: the tensors returned by
+///    `PartialReductionOpInterface::generateInitialTensorForPartialReduction`
+///    for `tileSizes`.
+/// On success the values are stored in `initTensors`. Any other strategy is
+/// reported as a match failure.
+static LogicalResult
+createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
+                              ArrayRef<OpFoldResult> tileSizes,
+                              SmallVector<Value> &initTensors,
+                              const scf::SCFTilingOptions &options) {
+  Location loc = op->getLoc();
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    return tensor::getOrCreateDestinations(rewriter, loc, op, initTensors);
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    FailureOr<PartialReductionOpInterface> redOp =
+        castToPartialReductionOp(rewriter, op);
+    if (failed(redOp))
+      return failure();
+    FailureOr<SmallVector<Value>> maybeInitTensors =
+        redOp->generateInitialTensorForPartialReduction(
+            rewriter, loc, tileSizes, collectReductionDims(op));
+    if (failed(maybeInitTensors))
+      return failure();
+    initTensors = *maybeInitTensors;
+    return success();
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+/// Generates the tiled implementation of `op` for the tile described by
+/// `offsets` and `sizes`:
+///  - FullReduction: `TilingInterface::getTiledImplementation`.
+///  - PartialReductionOuterReduction:
+///    `PartialReductionOpInterface::tileToPartialReduction`, which also
+///    receives the loop-carried values `regionIterArg`.
+static FailureOr<TilingResult>
+getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
+                       ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
+                       ArrayRef<OpFoldResult> sizes,
+                       const scf::SCFTilingOptions &options) {
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    return op.getTiledImplementation(rewriter, offsets, sizes);
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    FailureOr<PartialReductionOpInterface> redOp =
+        castToPartialReductionOp(rewriter, op);
+    if (failed(redOp))
+      return failure();
+    return redOp->tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
+                                         offsets, sizes,
+                                         collectReductionDims(op));
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+/// Computes the offsets (`resultOffset`) and sizes (`resultSize`) at which
+/// the `index`-th tiled result `tiledResult` is inserted back into its
+/// loop-carried destination:
+///  - FullReduction: delegates to `TilingInterface::getResultTilePosition`.
+///  - PartialReductionOuterReduction: all of `tiledResult` is inserted at
+///    offset zero, so every offset is 0 and the sizes are the sizes of
+///    `tiledResult`.
+// NOTE(review): the partial-reduction path queries `offsets.size()`
+// dimensions of `tiledResult`. It therefore assumes the partial result has one
+// dimension per loop of `op`; confirm this for every implementer.
+static LogicalResult
+getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
+                      TilingInterface op, ArrayRef<OpFoldResult> offsets,
+                      ArrayRef<OpFoldResult> sizes,
+                      SmallVector<OpFoldResult> &resultOffset,
+                      SmallVector<OpFoldResult> &resultSize,
+                      const scf::SCFTilingOptions &options) {
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    return op.getResultTilePosition(rewriter, index, offsets, sizes,
+                                    resultOffset, resultSize);
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    resultOffset =
+        SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
+    resultSize.clear();
+    for (size_t i = 0, e = offsets.size(); i < e; ++i) {
+      resultSize.push_back(
+          tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
+    }
+    return success();
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+/// Combines the values produced by the tiling loop nest (`partialResults`)
+/// into replacements for the results of `op`:
+///  - FullReduction: `partialResults` already hold the final values. They are
+///    returned unchanged, with no merge operations.
+///  - PartialReductionOuterReduction: the partial results are combined by
+///    `PartialReductionOpInterface::mergeReductions`.
+// NOTE(review): this function does not set the insertion point. Merge ops are
+// presumably created at the rewriter's current insertion point, which must be
+// after the loop nest; confirm at the call site.
+static FailureOr<MergeResult>
+mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
+                   ValueRange partialResults,
+                   const scf::SCFTilingOptions &options) {
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    // Full reduction produces the final values directly; nothing to merge.
+    return MergeResult{{}, partialResults};
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    FailureOr<PartialReductionOpInterface> redOp =
+        castToPartialReductionOp(rewriter, op);
+    if (failed(redOp))
+      return failure();
+    return redOp->mergeReductions(rewriter, op.getLoc(), partialResults,
+                                  collectReductionDims(op));
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
/// Append the specified additional `newInitOperands` operands to the
/// loops existing `init` operands (or similar), and replace `loopOp` with
/// the new loop that has the additional init operands. The loop body of
@@ -710,11 +850,11 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
});
}
-/// Method to add new init values to a loop nest. Updates `loops` in-place with
-/// new loops that use the `newInitValues`.
-/// The outer-loops are updated to yield the new result values of the inner
-/// loop. For the innermost loop, the call back `getNewYields` is invoked to get
-/// the additional values to yield form the innermost loop.
+/// Method to add new init values to a loop nest. Updates `loops` in-place
+/// with new loops that use the `newInitValues`. The outer loops are updated
+/// to yield the new result values of the inner loop. For the innermost loop,
+/// the callback `getNewTiledYieldsFn` is invoked to get the additional values
+/// to yield from the innermost loop.
static LogicalResult addInitOperandsToLoopNest(
RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
@@ -852,9 +992,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
auto clonedOp = cast<TilingInterface>(
cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
- // 5b. Early return cloned op if tiling is not happening. We can not return
- // the original op because it could lead to
- // `rewriter.replaceOp(op, op->getResults())` and users would get crash.
+ // 5b. Early return cloned op if tiling is not happening. We can not
+ // return the original op because it could lead to `rewriter.replaceOp(op,
+ // op->getResults())` and users would get crash.
if (llvm::all_of(tileSizes, isZeroIndex)) {
tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
tilingResult =
@@ -864,7 +1004,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
}
// 5c. Tile the cloned operation.
- tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
+ tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
+ offsets, sizes, options);
if (failed(tilingResult)) {
rewriter.eraseOp(clonedOp);
return op.emitOpError("faild to tile operation");
@@ -879,8 +1020,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
llvm::enumerate(tilingResult->tiledValues)) {
tiledResults.push_back(tiledValue);
SmallVector<OpFoldResult> resultOffset, resultSize;
- if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
- resultOffset, resultSize))) {
+ if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
+ sizes, resultOffset, resultSize,
+ options))) {
for (auto op : tilingResult->tiledOps) {
rewriter.eraseOp(op);
}
@@ -895,158 +1037,64 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
};
// 6. Find the destination tensors to use for the operation.
- SmallVector<Value> destinationTensors;
- if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
- destinationTensors))) {
- return rewriter.notifyMatchFailure(op,
- "unable to create destination tensors");
+ SmallVector<Value> initTensors;
+ if (failed(createInitialTensorsForTiling(rewriter, op, tileSizes, initTensors,
+ options))) {
+ return rewriter.notifyMatchFailure(
+ op, "unable to create initial tensors for tiling");
}
// 7. Generate the tiled loops nest using the callback defined above.
SmallVector<LoopLikeOpInterface> loops;
if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
- tileSizes, numThreads, destinationTensors,
+ tileSizes, numThreads, initTensors,
innerYieldTiledValuesFn, loops)))
return op.emitOpError("failed to generate tiling loops");
assert(succeeded(tilingResult) &&
"expected tiling result to be computed after loop generation");
- // If loops are empty, the tiled op is used as the replacement for the untiled
- // op.
+ SmallVector<Value> partialResults;
if (loops.empty()) {
- return scf::SCFTilingResult{tilingResult->tiledOps, loops,
- tilingResult->tiledValues,
- tilingResult->generatedSlices};
+ // If loops are empty, the tiled op is used as the replacement for the
+ // untiled op.
+ partialResults = tilingResult->tiledValues;
+ } else {
+ partialResults = llvm::map_to_vector(loops.front()->getResults(),
+ [](OpResult r) -> Value { return r; });
+ }
+
+ FailureOr<MergeResult> mergeResult =
+ mergeTilingResults(rewriter, op, partialResults, options);
+ if (failed(mergeResult)) {
+ return rewriter.notifyMatchFailure(
+ op, "Failed to merge partial results from tiling");
}
- SmallVector<Value> replacements = llvm::map_to_vector(
- loops.front()->getResults(), [](OpResult r) -> Value { return r; });
- return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
+ return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
+ mergeResult.value(),
tilingResult->generatedSlices};
}
-FailureOr<scf::SCFReductionTilingResult>
+FailureOr<scf::SCFTilingResult>
mlir::scf::tileReductionUsingScf(RewriterBase &b,
PartialReductionOpInterface op,
ArrayRef<OpFoldResult> tileSizes) {
- Location loc = op.getLoc();
- // Ops implementing PartialReductionOpInterface are expected to implement
- // TilingInterface.
- auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
- SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
- auto tileSizesVector = llvm::to_vector(tileSizes);
- if (tileSizesVector.size() < iterationDomain.size()) {
- auto zero = b.getIndexAttr(0);
- tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
- zero);
- }
- SmallVector<utils::IteratorType> iterators =
- tilingInterfaceOp.getLoopIteratorTypes();
-
- SmallVector<int> reductionDims;
- for (auto [idx, iteratorType] :
- llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
- if (iteratorType == utils::IteratorType::reduction)
- reductionDims.push_back(idx);
- }
-
- // 2. create the inital tensor value.
- FailureOr<SmallVector<Value>> maybeInitTensors =
- op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
- reductionDims);
- if (failed(maybeInitTensors)) {
- return b.notifyMatchFailure(op, "Failed to create initial tensors.");
- }
- SmallVector<Value> &initTensors = maybeInitTensors.value();
-
- // 3. Define the callback to use for generating the inner most tile loop body.
- SmallVector<Operation *> parallelTiledOps;
- auto innerYieldTiledValuesFn =
- [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
- ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
- SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
- SmallVector<SmallVector<OpFoldResult>> &resultSizes)
- -> LogicalResult {
- SmallVector<OpFoldResult> offsets, sizes;
- {
- int materializedLoopNum = 0;
- for (auto [tileSize, loopRange] :
- llvm::zip_equal(tileSizesVector, iterationDomain)) {
- if (isConstantIntValue(tileSize, 0)) {
- offsets.push_back(loopRange.offset);
- sizes.push_back(loopRange.size);
- continue;
- }
- Value iv = ivs[materializedLoopNum++];
- offsets.push_back(iv);
- sizes.push_back(
- getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
- }
- }
-
- // 4a. Clone the operation.
- {
- auto clonedOp = cast<PartialReductionOpInterface>(
- cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
-
- // 4b. Tile the cloned operation.
- FailureOr<TilingResult> partialTilingResult =
- clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
- sizes, reductionDims);
- if (failed(partialTilingResult)) {
- return failure();
- }
- std::swap(parallelTiledOps, partialTilingResult->tiledOps);
- std::swap(tiledResult, partialTilingResult->tiledValues);
-
- // 4c. Delete the cloned operation.
- b.eraseOp(clonedOp);
- }
-
- // 4d. Compute the offsets and sizes needed to insert the result of the
- // tiled value back into destination before yielding the destination.
- for (auto result : tiledResult) {
- SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
- resultOffsets.emplace_back(std::move(outOffsets));
-
- SmallVector<OpFoldResult> outSizes;
- for (size_t i = 0; i < offsets.size(); i++) {
- outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
- }
- resultSizes.emplace_back(std::move(outSizes));
- }
- return success();
- };
-
- // 5. Generate the tiled implementation using the destination tensors.
- SmallVector<LoopLikeOpInterface> loops;
- scf::SCFTilingOptions options;
- options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
- if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
- /*numThreads=*/ArrayRef<OpFoldResult>{},
- initTensors, innerYieldTiledValuesFn, loops)))
- return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
-
- SmallVector<Value> replacements = llvm::map_to_vector(
- loops.front()->getResults(), [](OpResult r) -> Value { return r; });
-
- // 5. Apply the merge reduction to combine all the partial values.
- b.setInsertionPointAfter(*loops.begin());
- FailureOr<MergeResult> mergeResult =
- op.mergeReductions(b, loc, replacements, reductionDims);
- if (failed(mergeResult)) {
- return failure();
- }
- b.replaceOp(op, mergeResult->replacements);
-
- SCFReductionTilingResult reductionTilingResult;
- std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
- std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
- std::swap(reductionTilingResult.initialValues, initTensors);
- std::swap(reductionTilingResult.loops, loops);
- std::swap(reductionTilingResult.replacements, mergeResult->replacements);
-
- return reductionTilingResult;
+ // Delegate to the generic SCF tiling driver. It uses `scf.for` loops and the
+ // PartialReductionOuterReduction strategy: the loop carries partial results
+ // that are merged after the loop.
+ SCFTilingOptions options;
+ options.setLoopType(SCFTilingOptions::LoopType::ForOp);
+ options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy::
+ PartialReductionOuterReduction);
+ options.setTileSizes(tileSizes);
+
+ // `tileUsingSCF` takes a TilingInterface op, so `op` must implement that
+ // interface in addition to PartialReductionOpInterface.
+ TilingInterface tilingInterfaceOp =
+ dyn_cast<TilingInterface>(op.getOperation());
+ if (!tilingInterfaceOp) {
+ return b.notifyMatchFailure(
+ op,
+ "Operation implementing PartialReductionOpInterface should implement "
+ "TilingInterface");
+ }
+
+ // `op` is not replaced here. Callers replace it using
+ // `mergeResult.replacements` of the returned result.
+ return tileUsingSCF(b, tilingInterfaceOp, options);
}
//===----------------------------------------------------------------------===//
@@ -1055,9 +1103,10 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
/// Return the untiled producer whose slice is used in a tiled consumer. The
/// method traverses the tile loop nest (`loops`) if needed, and returns the
-/// `iter_args` of the outer most that is encountered. Traversing the iter_args
-/// indicates that this is a destination operand of the consumer. If there was
-/// no loop traversal needed, the second value of the returned tuple is empty.
+/// `iter_args` of the outermost loop that is encountered. Traversing the
+/// iter_args indicates that this is a destination operand of the consumer. If
+/// no loop traversal was needed, the second value of the returned tuple is
+/// empty.
static std::tuple<OpResult, std::optional<OpOperand *>>
getUntiledProducerFromSliceSource(OpOperand *source,
ArrayRef<LoopLikeOpInterface> loops) {
@@ -1115,8 +1164,8 @@ mlir::scf::tileAndFuseProducerOfSlice(
Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
rewriter, fusableProducerOp, clonedOpDestinationTensors);
// 2d. Update the source of the candidateSlice to be the cloned producer.
- // Easier to just clone the slice with different source since replacements
- // and DCE of cloned ops becomes easier
+ // Easier to just clone the slice with different source since
+ // replacements and DCE of cloned ops becomes easier
SmallVector<Value> candidateSliceOpOperands =
llvm::to_vector(candidateSliceOp->getOperands());
candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
@@ -1250,13 +1299,13 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
failed(tilableOp.getIterationDomainTileFromResultTile(
rewriter, sliceResultNumber, sliceOffset, sliceSizes,
iterDomainOffset, iterDomainSizes))) {
- // In theory, it is unnecessary to raise an error here. Actually although
- // it fails to reconstruct the result tensor, it should not broke current
- // fusion anyway. The reason why we must return failure currently is that
- // the callback function `newYieldValuesFn` will be called after new init
- // operand(s) has already been appended. It will take more refactoring to
- // make sure the init operands are added consistently in the future. For
- // more details, please refer to:
+ // In theory, it is unnecessary to raise an error here. Actually
+ // although it fails to reconstruct the result tensor, it should not
+ // broke current fusion anyway. The reason why we must return failure
+ // currently is that the callback function `newYieldValuesFn` will be
+ // called after new init operand(s) has already been appended. It will
+ // take more refactoring to make sure the init operands are added
+ // consistently in the future. For more details, please refer to:
// https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
return failure();
}
@@ -1282,7 +1331,8 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
}
}
- // d. create `extract_slice` for `iter_args` for DPS operation if necessary
+ // d. create `extract_slice` for `iter_args` for DPS operation if
+ // necessary
if (auto tiledDestStyleOp =
dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
rewriter.setInsertionPoint(tiledDestStyleOp);
@@ -1334,9 +1384,10 @@ class SliceTrackingListener : public RewriterBase::Listener {
std::optional<FrozenRewritePatternSet> patterns);
SliceTrackingListener() = default;
- /// Adds the given list of operations to the worklist, and if present, applies
- /// the list of `patterns` to the newly added operations. This only processes
- /// the given operations and any newly inserted ones by the pattern set.
+ /// Adds the given list of operations to the worklist, and if present,
+ /// applies the list of `patterns` to the newly added operations. This only
+ /// processes the given operations and any newly inserted ones by the
+ /// pattern set.
LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
/// Add to the new operation worklist if it is an extract_slice.
@@ -1357,7 +1408,8 @@ class SliceTrackingListener : public RewriterBase::Listener {
std::deque<tensor::ExtractSliceOp> worklist;
private:
- /// Optional pattern set to apply when adding new operations to the worklist.
+ /// Optional pattern set to apply when adding new operations to the
+ /// worklist.
std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
};
@@ -1390,8 +1442,9 @@ void SliceTrackingListener::notifyOperationInserted(
worklist.push_back(slice);
}
-// Scan the worklist for the given op and remove it if present. The expectation
-// is for the worklist to be small and for removal to be relatively rare.
+// Scan the worklist for the given op and remove it if present. The
+// expectation is for the worklist to be small and for removal to be
+// relatively rare.
void SliceTrackingListener::removeOp(Operation *op) {
if (!isa<tensor::ExtractSliceOp>(op))
return;
@@ -1445,17 +1498,18 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
auto &loops = tilingResult->loops;
if (loops.empty()) {
DenseMap<Value, Value> replacements;
- for (auto [origVal, replacement] :
- llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
+ for (auto [origVal, replacement] : llvm::zip_equal(
+ consumer->getResults(), tilingResult->mergeResult.replacements)) {
replacements[origVal] = replacement;
}
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
replacements};
}
- // To keep track of replacements for now just record the map from the original
- // untiled value to the result number of the for loop. Since the loop gets
- // potentially replaced during fusion, keeping the value directly wont work.
+ // To keep track of replacements for now just record the map from the
+ // original untiled value to the result number of the for loop. Since the
+ // loop gets potentially replaced during fusion, keeping the value directly
+ // wont work.
DenseMap<Value, size_t> origValToResultNumber;
for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
origValToResultNumber[result] = index;
@@ -1463,11 +1517,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
// 2. Typically, the operands of the tiled operation are slices of the
// operands of the untiled operation. These are expressed in IR using
- // `tensor.extract_slice` operations with source being the operands of the
- // untiled operation. Create a worklist of these `tensor.extract_slice`
- // operations. If the producers of the source of the `tensor.extract_slice`
- // can be tiled such that the tiled value is generated in-place, that
- // effectively tiles + fuses the operations.
+ // `tensor.extract_slice` operations with source being the operands of
+ // the untiled operation. Create a worklist of these
+ // `tensor.extract_slice` operations. If the producers of the source of
+ // the `tensor.extract_slice` can be tiled such that the tiled value is
+ // generated in-place, that effectively tiles + fuses the operations.
struct WorklistItem {
tensor::ExtractSliceOp candidateSlice;
SCFTileAndFuseOptions::ControlFnResult controlFnResult;
@@ -1511,9 +1565,10 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
if (worklistItem.controlFnResult.yieldProducerReplacement) {
- // Reconstruct and yield all opResult of fusableProducerOp by default. The
- // caller can specific which one to yield by designating optional argument
- // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
+ // Reconstruct and yield all opResult of fusableProducerOp by default.
+ // The caller can specific which one to yield by designating optional
+ // argument named `yieldResultNumber` of
+ // `yieldReplacementForFusedProducer`.
Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
FailureOr<SmallVector<Operation *>> newSlices =
yieldReplacementForFusedProducer(rewriter,
@@ -1582,8 +1637,8 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
return success();
}
-/// An utility to get the first user of the given loopOp. If any of user stay in
-/// different block of loopOp, return failure.
+/// A utility to get the first user of the given loopOp. If any user is in
+/// a different block from loopOp, return failure.
static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
if (!isa<LoopLikeOpInterface>(loopOp))
return failure();
@@ -1616,11 +1671,11 @@ static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
return firstUserOfLoop;
}
-/// This utility currently checks whether the first userOp of loop is NOT before
-/// the last defineOp of consumer operand. Because that we need to move the
-/// whole loop structure right before the `firstUserOfLoop`. This utility thus
-/// helps ensuring that no invalid IR is formed, i.e. no backward slice of
-/// consumerOp is dominated by the `firstUserOfLoop`. Saying that:
+/// This utility currently checks whether the first userOp of loop is NOT
+/// before the last defineOp of consumer operand, because we need to move
+/// the whole loop structure right before the `firstUserOfLoop`. This utility
+/// thus helps ensure that no invalid IR is formed, i.e. no backward slice
+/// of consumerOp is dominated by the `firstUserOfLoop`. For example:
///
/// ```
/// %0 = scf.for() {
@@ -1634,9 +1689,9 @@ static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
/// %3 = consumerOp(%2)
/// ```
///
-/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it would
-/// be invalid to move the `loopOp` right before the `firstUserOfLoop`, a.k.a.
-/// use-def chain violation:
+/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
+/// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
+/// a.k.a. use-def chain violation:
///
/// ```
/// %0:2 = scf.for() {
@@ -1650,10 +1705,10 @@ static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
///
/// @param loopOp: loop operation
/// @param consumerOp: consumer operation
-/// @param reorderOperations: the flag controls whether to reorder the backward
-/// slice w.r.t. the defineOp of `consumerOp` operands.
-/// @return: computed backward slice of consumerOp, but excluding those already
-/// dominates `firstUserOfLoop`.
+/// @param reorderOperations: the flag controls whether to reorder the
+/// backward slice w.r.t. the defineOp of `consumerOp` operands.
+/// @return: computed backward slice of consumerOp, excluding ops that
+/// already dominate `firstUserOfLoop`.
static FailureOr<llvm::SetVector<Operation *>>
checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
bool reorderOperations) {
@@ -1713,8 +1768,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
if (!isa<TilingInterface>(consumerOp) ||
!isa<DestinationStyleOpInterface>(consumerOp)) {
// TODO: We have to init result of consumer before scf.for, use
- // DestinationStyleOpInterface to get result shape from init for now. Add
- // support for other op such as op has InferTypeOpInterface.
+ // DestinationStyleOpInterface to get result shape from init for now.
+ // Add support for other ops, such as ops implementing InferTypeOpInterface.
continue;
}
// Step 2. Check if user stay in the same block.
@@ -1729,7 +1784,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
checkAssumptionForLoop(loopOp, consumerOp, true);
if (failed(slice))
continue;
- // Step 5. If backward sice is not empty, move them before firstUserOfLoop.
+ // Step 5. If the backward slice is not empty, move it before
+ // firstUserOfLoop.
if (!slice->empty()) {
mlir::topologicalSort(*slice);
FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
@@ -1743,8 +1799,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
return failure();
}
-/// Find the perfectly nested loops outside of given loop(included) sorted from
-/// outer to inner.
+/// Find the perfectly nested loops outside of the given loop (inclusive),
+/// sorted from outer to inner.
///
/// E.g.
///
@@ -1997,10 +2053,11 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
}
// 10. Try to get iter domain position from input position. Use
- // clonedConsumerOp instead of tiledConsumerOp, because the iteration domain
- // may require index computation based on the result size. The sizes and
- // offsets should be the same either way, but using tiledConsumerOp could
- // lead to some chained unnecessary extra index computation.
+ // clonedConsumerOp instead of tiledConsumerOp, because the iteration
+ // domain may require index computation based on the result size. The
+ // sizes and offsets should be the same either way, but using
+ // tiledConsumerOp could lead to some chained unnecessary extra index
+ // computation.
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
@@ -2067,7 +2124,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
"unable to add new inits to nest loop");
}
- // 15. Replace the result of scf loop and consumer op with new loop's results.
+ // 15. Replace the result of scf loop and consumer op with new loop's
+ // results.
for (auto &&[oldResult, newResult] : llvm::zip(
consumerOp->getResults(),
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 5e903e378daf82..7380b766935ffe 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -250,7 +250,8 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
return failure();
// Perform the replacement of tiled and fused values.
- rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements);
+ rewriter.replaceOp(tilingInterfaceOp,
+ tiledResults->mergeResult.replacements);
// Report back the relevant handles to the transform op.
tiledOps.push_back(tiledResults->tiledOps.front());
More information about the Mlir-commits
mailing list