[Mlir-commits] [mlir] [mlir][SCF] Unify tileUsingFor and tileReductionUsingFor implementation (PR #120115)
Kunwar Grover
llvmlistbot at llvm.org
Mon Dec 16 09:22:06 PST 2024
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/120115
>From facdf23493e348a2c3a03083a24b4b7a6c072a9c Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 13 Dec 2024 16:58:10 +0000
Subject: [PATCH] [mlir][SCF] Merge tileUsingFor and tileReductionUsingFor
implementation
---
.../SCF/Transforms/TileUsingInterface.h | 57 ++-
.../TransformOps/LinalgTransformOps.cpp | 13 +-
.../SCF/Transforms/TileUsingInterface.cpp | 458 ++++++++++--------
.../TestTilingInterfaceTransformOps.cpp | 3 +-
4 files changed, 306 insertions(+), 225 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 9f5f9f3fca97ad..d2cddfe00ac78e 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,6 +85,36 @@ struct SCFTilingOptions {
return *this;
}
+ /// Specify how reduction dimensions should be tiled.
+ ///
+ /// Tiling can be thought of as splitting a dimension into 2 and materializing
+ /// the outer dimension as a loop:
+ ///
+ /// op[original] -> op[original / x, x] -> loop[original] { op[x] }
+ ///
+ /// For parallel dimensions, the split can only happen in one way, with both
+ /// dimensions being parallel. For reduction dimensions however, there is a
+ /// choice in how we split the reduction dimension. This enum exposes this
+ /// choice.
+ enum class ReductionTilingStrategy {
+ // [reduction] -> [reduction1, reduction2]
+ // -> loop[reduction1] { [reduction2] }
+ // The loop results are used directly; no merge step is generated.
+ FullReduction,
+ // [reduction] -> [reduction1, parallel2]
+ // -> loop[reduction1] { [parallel2] }; merge[reduction1]
+ // The loop produces partial results that are merged after the loop.
+ PartialReductionOuterReduction,
+ // [reduction] -> [parallel1, reduction2]
+ // -> loop[parallel1] { [reduction2] }; merge[parallel1]
+ // NOTE(review): the tiling helpers added to TileUsingInterface.cpp by this
+ // patch have no case for this strategy; it reaches their `default:` branch
+ // and fails with "unhandled reduction tiling strategy".
+ PartialReductionOuterParallel
+ };
+ // Strategy applied to reduction dimensions; defaults to `FullReduction`.
+ ReductionTilingStrategy reductionStrategy =
+ ReductionTilingStrategy::FullReduction;
+ // Sets `reductionStrategy` and returns `*this` so that calls can be chained.
+ SCFTilingOptions &
+ setReductionTilingStrategy(ReductionTilingStrategy strategy) {
+ reductionStrategy = strategy;
+ return *this;
+ }
+
/// Specify mapping of loops to devices. This is only respected when the loop
/// constructs support such a mapping (like `scf.forall`). Will be ignored
/// when using loop constructs that dont support such a mapping (like
@@ -102,11 +132,16 @@ struct SCFTilingResult {
/// matter except the last op. The replacements are expected to be the results
/// of the last op.
SmallVector<Operation *> tiledOps;
+ /// The initial destination values passed to the tiled operations.
+ SmallVector<Value> initialValues;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<LoopLikeOpInterface> loops;
- /// Values to use as replacements for the untiled op. Is the same size as the
- /// number of results of the untiled op.
- SmallVector<Value> replacements;
+ /// The results generated by the tiled loop nest may be partial results that
+ /// need to be merged to match the computation of the untiled operation.
+ /// `mergeResult` contains the operations used to perform this merge and the
+ /// values that can be used as replacements for the results of the untiled
+ /// operation.
+ MergeResult mergeResult;
/// Slices generated after tiling that can be used for fusing with the tiled
/// producer.
SmallVector<Operation *> generatedSlices;
@@ -300,20 +335,6 @@ tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
FailureOr<SmallVector<scf::ForOp>>
lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
-/// Transformation information returned after reduction tiling.
-struct SCFReductionTilingResult {
- /// The partial reduction tiled op generated.
- SmallVector<Operation *> parallelTiledOps;
- /// The final reduction operation merging all the partial reductions.
- SmallVector<Operation *> mergeOps;
- /// Initial values used for reduction.
- SmallVector<Value> initialValues;
- /// The loop operations that iterate over the tiles.
- SmallVector<LoopLikeOpInterface> loops;
- /// The replacements to use for the results of the tiled operation.
- SmallVector<Value> replacements;
-};
-
/// Method to tile a reduction and generate a parallel op within a serial loop.
/// Each of the partial reductions are calculated in parallel. Then after the
/// loop all the partial reduction are merged into a final reduction.
@@ -338,7 +359,7 @@ struct SCFReductionTilingResult {
/// %6 = linalg.generic %1 ["parallel", "reduction"]
/// : tensor<7x4xf32> -> tensor<7xf32>
/// ```
-FailureOr<scf::SCFReductionTilingResult>
+FailureOr<scf::SCFTilingResult>
tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
ArrayRef<OpFoldResult> tileSize);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8839faf4cafb2d..66a3947e0f91fc 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2224,7 +2224,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
return emitDefaultDefiniteFailure(target);
if (target->getNumResults())
- rewriter.replaceOp(target, maybeTilingResult->replacements);
+ rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
else
rewriter.eraseOp(target);
@@ -2631,17 +2631,18 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
- FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
+ FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
if (failed(result))
return emitDefaultSilenceableFailure(target);
+ rewriter.replaceOp(target, result->mergeResult.replacements);
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
- for (auto parallelTiledOp : result->parallelTiledOps)
+ for (auto parallelTiledOp : result->tiledOps)
results.push_back(parallelTiledOp);
- for (auto mergeOp : result->mergeOps)
+ for (auto mergeOp : result->mergeResult.mergeOps)
results.push_back(mergeOp);
results.push_back(result->loops.front());
return DiagnosedSilenceableFailure::success();
@@ -3065,7 +3066,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
if (failed(maybeTilingResult))
return DiagnosedSilenceableFailure::definiteFailure();
- rewriter.replaceOp(op, maybeTilingResult->replacements);
+ rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
tiled.append(maybeTilingResult->tiledOps);
for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
@@ -3304,7 +3305,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
if (failed(maybeTilingResult))
return transformOp.emitDefaultSilenceableFailure(tileableOp);
- rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
+ rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
tilingResult = *maybeTilingResult;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 6a4a6b43933806..8ece9fb259ddd6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -570,6 +570,146 @@ static LogicalResult generateLoopNest(
return rewriter.notifyMatchFailure(loc, "unhandled loop type");
}
+/// Creates the tensors the tiling loop nest is seeded with and stores them in
+/// `initTensors`. What gets created depends on `options.reductionStrategy`:
+///  - `FullReduction`: the destinations produced by
+///    `tensor::getOrCreateDestinations`.
+///  - `PartialReductionOuterReduction`: the tensors produced by
+///    `PartialReductionOpInterface::generateInitialTensorForPartialReduction`
+///    for `tileSizes` and all reduction dimensions of `op`.
+/// Every other strategy results in a match failure.
+static LogicalResult
+createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
+                              ArrayRef<OpFoldResult> tileSizes,
+                              SmallVector<Value> &initTensors,
+                              const scf::SCFTilingOptions &options) {
+  using Strategy = scf::SCFTilingOptions::ReductionTilingStrategy;
+  Location loc = op->getLoc();
+  switch (options.reductionStrategy) {
+  case Strategy::FullReduction:
+    return tensor::getOrCreateDestinations(rewriter, loc, op, initTensors);
+  case Strategy::PartialReductionOuterReduction: {
+    auto partialReductionOp =
+        dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!partialReductionOp)
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations "
+              "implementing PartialReductionOpInterface");
+
+    // Collect the positions of all reduction loops of the op.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<utils::IteratorType> iteratorTypes = op.getLoopIteratorTypes();
+    SmallVector<int> reductionDims;
+    for (auto [dim, iteratorType] : llvm::enumerate(iteratorTypes))
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(dim);
+
+    FailureOr<SmallVector<Value>> partialInits =
+        partialReductionOp.generateInitialTensorForPartialReduction(
+            rewriter, loc, tileSizes, reductionDims);
+    if (failed(partialInits))
+      return failure();
+    initTensors = *partialInits;
+    return success();
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+
+/// Produces the tiled implementation of `op` for the tile described by
+/// `offsets` and `sizes`, dispatching on `options.reductionStrategy`:
+///  - `FullReduction`: `TilingInterface::getTiledImplementation`.
+///  - `PartialReductionOuterReduction`:
+///    `PartialReductionOpInterface::tileToPartialReduction`, which is also
+///    given `regionIterArg` and all reduction dimensions of `op`.
+/// Every other strategy results in a match failure.
+static FailureOr<TilingResult>
+getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
+                       ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
+                       ArrayRef<OpFoldResult> sizes,
+                       const scf::SCFTilingOptions &options) {
+  using Strategy = scf::SCFTilingOptions::ReductionTilingStrategy;
+  switch (options.reductionStrategy) {
+  case Strategy::FullReduction:
+    return op.getTiledImplementation(rewriter, offsets, sizes);
+  case Strategy::PartialReductionOuterReduction: {
+    auto partialReductionOp =
+        dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!partialReductionOp)
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations "
+              "implementing PartialReductionOpInterface");
+
+    // Collect the positions of all reduction loops of the op.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<utils::IteratorType> iteratorTypes = op.getLoopIteratorTypes();
+    SmallVector<int> reductionDims;
+    for (auto [dim, iteratorType] : llvm::enumerate(iteratorTypes))
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(dim);
+
+    return partialReductionOp.tileToPartialReduction(
+        rewriter, op.getLoc(), regionIterArg, offsets, sizes, reductionDims);
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+
+/// Computes the offsets (`resultOffset`) and sizes (`resultSize`) at which
+/// `tiledResult`, the tiled value for result number `index` of `op`, is
+/// inserted back into its destination. Dispatches on
+/// `options.reductionStrategy`:
+///  - `FullReduction`: defers to `TilingInterface::getResultTilePosition`
+///    with the iteration-space tile given by `offsets` and `sizes`.
+///  - `PartialReductionOuterReduction`: `resultOffset` is `offsets.size()`
+///    zeros, and the mixed size of each of the first `offsets.size()`
+///    dimensions of `tiledResult` is appended to `resultSize`.
+/// Every other strategy results in a match failure.
+static LogicalResult
+getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
+                      TilingInterface op, ArrayRef<OpFoldResult> offsets,
+                      ArrayRef<OpFoldResult> sizes,
+                      SmallVector<OpFoldResult> &resultOffset,
+                      SmallVector<OpFoldResult> &resultSize,
+                      const scf::SCFTilingOptions &options) {
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    return op.getResultTilePosition(rewriter, index, offsets, sizes,
+                                    resultOffset, resultSize);
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    resultOffset =
+        SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
+    for (size_t i = 0; i < offsets.size(); i++) {
+      resultSize.push_back(
+          tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
+    }
+    return success();
+  }
+  // The case block above is closed before `default:` so that the braces
+  // match the case structure of the switch.
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+
+/// Turns the `partialResults` produced by tiling `op` into the final merge
+/// result, based on `options.reductionStrategy`:
+///  - `FullReduction`: no merge is performed; the returned `MergeResult` is
+///    built from an empty list and `partialResults`.
+///  - `PartialReductionOuterReduction`: defers to
+///    `PartialReductionOpInterface::mergeReductions` over all reduction
+///    dimensions of `op`.
+/// Every other strategy results in a match failure.
+static FailureOr<MergeResult>
+mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
+                   ValueRange partialResults,
+                   const scf::SCFTilingOptions &options) {
+  using Strategy = scf::SCFTilingOptions::ReductionTilingStrategy;
+  switch (options.reductionStrategy) {
+  case Strategy::FullReduction:
+    // Nothing to merge for the FullReduction strategy.
+    return MergeResult{{}, partialResults};
+  case Strategy::PartialReductionOuterReduction: {
+    auto partialReductionOp =
+        dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!partialReductionOp)
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations "
+              "implementing PartialReductionOpInterface");
+
+    // Collect the positions of all reduction loops of the op.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<utils::IteratorType> iteratorTypes = op.getLoopIteratorTypes();
+    SmallVector<int> reductionDims;
+    for (auto [dim, iteratorType] : llvm::enumerate(iteratorTypes))
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(dim);
+
+    return partialReductionOp.mergeReductions(rewriter, op.getLoc(),
+                                              partialResults, reductionDims);
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+
/// Append the specified additional `newInitOperands` operands to the
/// loops existing `init` operands (or similar), and replace `loopOp` with
/// the new loop that has the additional init operands. The loop body of
@@ -710,11 +850,11 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
});
}
-/// Method to add new init values to a loop nest. Updates `loops` in-place with
-/// new loops that use the `newInitValues`.
-/// The outer-loops are updated to yield the new result values of the inner
-/// loop. For the innermost loop, the call back `getNewYields` is invoked to get
-/// the additional values to yield form the innermost loop.
+/// Method to add new init values to a loop nest. Updates `loops` in-place
+/// with new loops that use the `newInitValues`. The outer-loops are updated
+/// to yield the new result values of the inner loop. For the innermost loop,
+/// the callback `getNewTiledYieldsFn` is invoked to get the additional values
+/// to yield from the innermost loop.
static LogicalResult addInitOperandsToLoopNest(
RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
@@ -852,9 +992,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
auto clonedOp = cast<TilingInterface>(
cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
- // 5b. Early return cloned op if tiling is not happening. We can not return
- // the original op because it could lead to
- // `rewriter.replaceOp(op, op->getResults())` and users would get crash.
+ // 5b. Return the cloned op early if no tiling is happening. We cannot
+ // return the original op because that could lead to
+ // `rewriter.replaceOp(op, op->getResults())` and crash for users.
if (llvm::all_of(tileSizes, isZeroIndex)) {
tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
tilingResult =
@@ -864,7 +1004,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
}
// 5c. Tile the cloned operation.
- tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
+ tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
+ offsets, sizes, options);
if (failed(tilingResult)) {
rewriter.eraseOp(clonedOp);
return op.emitOpError("faild to tile operation");
@@ -879,8 +1020,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
llvm::enumerate(tilingResult->tiledValues)) {
tiledResults.push_back(tiledValue);
SmallVector<OpFoldResult> resultOffset, resultSize;
- if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
- resultOffset, resultSize))) {
+ if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
+ sizes, resultOffset, resultSize,
+ options))) {
for (auto op : tilingResult->tiledOps) {
rewriter.eraseOp(op);
}
@@ -895,158 +1037,64 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
};
// 6. Find the destination tensors to use for the operation.
- SmallVector<Value> destinationTensors;
- if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
- destinationTensors))) {
- return rewriter.notifyMatchFailure(op,
- "unable to create destination tensors");
+ SmallVector<Value> initTensors;
+ if (failed(createInitialTensorsForTiling(rewriter, op, tileSizes, initTensors,
+ options))) {
+ return rewriter.notifyMatchFailure(
+ op, "unable to create initial tensors for tiling");
}
// 7. Generate the tiled loops nest using the callback defined above.
SmallVector<LoopLikeOpInterface> loops;
if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
- tileSizes, numThreads, destinationTensors,
+ tileSizes, numThreads, initTensors,
innerYieldTiledValuesFn, loops)))
return op.emitOpError("failed to generate tiling loops");
assert(succeeded(tilingResult) &&
"expected tiling result to be computed after loop generation");
- // If loops are empty, the tiled op is used as the replacement for the untiled
- // op.
+ SmallVector<Value> partialResults;
if (loops.empty()) {
- return scf::SCFTilingResult{tilingResult->tiledOps, loops,
- tilingResult->tiledValues,
- tilingResult->generatedSlices};
+ // If loops are empty, the tiled op is used as the replacement for the
+ // untiled op.
+ partialResults = tilingResult->tiledValues;
+ } else {
+ partialResults = llvm::map_to_vector(loops.front()->getResults(),
+ [](OpResult r) -> Value { return r; });
+ }
+
+ FailureOr<MergeResult> mergeResult =
+ mergeTilingResults(rewriter, op, partialResults, options);
+ if (failed(mergeResult)) {
+ return rewriter.notifyMatchFailure(
+ op, "Failed to merge partial results from tiling");
}
- SmallVector<Value> replacements = llvm::map_to_vector(
- loops.front()->getResults(), [](OpResult r) -> Value { return r; });
- return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
+ return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
+ mergeResult.value(),
tilingResult->generatedSlices};
}
-FailureOr<scf::SCFReductionTilingResult>
+FailureOr<scf::SCFTilingResult>
mlir::scf::tileReductionUsingScf(RewriterBase &b,
PartialReductionOpInterface op,
ArrayRef<OpFoldResult> tileSizes) {
- Location loc = op.getLoc();
- // Ops implementing PartialReductionOpInterface are expected to implement
- // TilingInterface.
- auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
- SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
- auto tileSizesVector = llvm::to_vector(tileSizes);
- if (tileSizesVector.size() < iterationDomain.size()) {
- auto zero = b.getIndexAttr(0);
- tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
- zero);
- }
- SmallVector<utils::IteratorType> iterators =
- tilingInterfaceOp.getLoopIteratorTypes();
-
- SmallVector<int> reductionDims;
- for (auto [idx, iteratorType] :
- llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
- if (iteratorType == utils::IteratorType::reduction)
- reductionDims.push_back(idx);
- }
-
- // 2. create the inital tensor value.
- FailureOr<SmallVector<Value>> maybeInitTensors =
- op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
- reductionDims);
- if (failed(maybeInitTensors)) {
- return b.notifyMatchFailure(op, "Failed to create initial tensors.");
- }
- SmallVector<Value> &initTensors = maybeInitTensors.value();
-
- // 3. Define the callback to use for generating the inner most tile loop body.
- SmallVector<Operation *> parallelTiledOps;
- auto innerYieldTiledValuesFn =
- [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
- ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
- SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
- SmallVector<SmallVector<OpFoldResult>> &resultSizes)
- -> LogicalResult {
- SmallVector<OpFoldResult> offsets, sizes;
- {
- int materializedLoopNum = 0;
- for (auto [tileSize, loopRange] :
- llvm::zip_equal(tileSizesVector, iterationDomain)) {
- if (isConstantIntValue(tileSize, 0)) {
- offsets.push_back(loopRange.offset);
- sizes.push_back(loopRange.size);
- continue;
- }
- Value iv = ivs[materializedLoopNum++];
- offsets.push_back(iv);
- sizes.push_back(
- getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
- }
- }
-
- // 4a. Clone the operation.
- {
- auto clonedOp = cast<PartialReductionOpInterface>(
- cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
-
- // 4b. Tile the cloned operation.
- FailureOr<TilingResult> partialTilingResult =
- clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
- sizes, reductionDims);
- if (failed(partialTilingResult)) {
- return failure();
- }
- std::swap(parallelTiledOps, partialTilingResult->tiledOps);
- std::swap(tiledResult, partialTilingResult->tiledValues);
-
- // 4c. Delete the cloned operation.
- b.eraseOp(clonedOp);
- }
-
- // 4d. Compute the offsets and sizes needed to insert the result of the
- // tiled value back into destination before yielding the destination.
- for (auto result : tiledResult) {
- SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
- resultOffsets.emplace_back(std::move(outOffsets));
-
- SmallVector<OpFoldResult> outSizes;
- for (size_t i = 0; i < offsets.size(); i++) {
- outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
- }
- resultSizes.emplace_back(std::move(outSizes));
- }
- return success();
- };
-
- // 5. Generate the tiled implementation using the destination tensors.
- SmallVector<LoopLikeOpInterface> loops;
- scf::SCFTilingOptions options;
- options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
- if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
- /*numThreads=*/ArrayRef<OpFoldResult>{},
- initTensors, innerYieldTiledValuesFn, loops)))
- return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
-
- SmallVector<Value> replacements = llvm::map_to_vector(
- loops.front()->getResults(), [](OpResult r) -> Value { return r; });
-
- // 5. Apply the merge reduction to combine all the partial values.
- b.setInsertionPointAfter(*loops.begin());
- FailureOr<MergeResult> mergeResult =
- op.mergeReductions(b, loc, replacements, reductionDims);
- if (failed(mergeResult)) {
- return failure();
- }
- b.replaceOp(op, mergeResult->replacements);
-
- SCFReductionTilingResult reductionTilingResult;
- std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
- std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
- std::swap(reductionTilingResult.initialValues, initTensors);
- std::swap(reductionTilingResult.loops, loops);
- std::swap(reductionTilingResult.replacements, mergeResult->replacements);
-
- return reductionTilingResult;
+ SCFTilingOptions options;
+ options.setLoopType(SCFTilingOptions::LoopType::ForOp);
+ options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy::
+ PartialReductionOuterReduction);
+ options.setTileSizes(tileSizes);
+
+ TilingInterface tilingInterfaceOp =
+ dyn_cast<TilingInterface>(op.getOperation());
+ if (!tilingInterfaceOp) {
+ return b.notifyMatchFailure(
+ op,
+ "Operation implementing PartialReductionOpInterface should implement "
+ "TilingInterface");
+ }
+
+ return tileUsingSCF(b, tilingInterfaceOp, options);
}
//===----------------------------------------------------------------------===//
@@ -1055,9 +1103,10 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
/// Return the untiled producer whose slice is used in a tiled consumer. The
/// method traverses the tile loop nest (`loops`) if needed, and returns the
-/// `iter_args` of the outer most that is encountered. Traversing the iter_args
-/// indicates that this is a destination operand of the consumer. If there was
-/// no loop traversal needed, the second value of the returned tuple is empty.
+/// `iter_args` of the outer most that is encountered. Traversing the
+/// iter_args indicates that this is a destination operand of the consumer. If
+/// there was no loop traversal needed, the second value of the returned tuple
+/// is empty.
static std::tuple<OpResult, std::optional<OpOperand *>>
getUntiledProducerFromSliceSource(OpOperand *source,
ArrayRef<LoopLikeOpInterface> loops) {
@@ -1115,8 +1164,8 @@ mlir::scf::tileAndFuseProducerOfSlice(
Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
rewriter, fusableProducerOp, clonedOpDestinationTensors);
// 2d. Update the source of the candidateSlice to be the cloned producer.
- // Easier to just clone the slice with different source since replacements
- // and DCE of cloned ops becomes easier
+ // Easier to just clone the slice with different source since
+ // replacements and DCE of cloned ops becomes easier
SmallVector<Value> candidateSliceOpOperands =
llvm::to_vector(candidateSliceOp->getOperands());
candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
@@ -1250,13 +1299,13 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
failed(tilableOp.getIterationDomainTileFromResultTile(
rewriter, sliceResultNumber, sliceOffset, sliceSizes,
iterDomainOffset, iterDomainSizes))) {
- // In theory, it is unnecessary to raise an error here. Actually although
- // it fails to reconstruct the result tensor, it should not broke current
- // fusion anyway. The reason why we must return failure currently is that
- // the callback function `newYieldValuesFn` will be called after new init
- // operand(s) has already been appended. It will take more refactoring to
- // make sure the init operands are added consistently in the future. For
- // more details, please refer to:
+ // In theory, it is unnecessary to raise an error here. Although failing
+ // to reconstruct the result tensor should not break the current fusion,
+ // we must currently return failure because the callback function
+ // `newYieldValuesFn` is called after the new init operand(s) have
+ // already been appended. It will take more refactoring to make sure the
+ // init operands are added consistently in the future. For more details,
+ // please refer to:
// https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
return failure();
}
@@ -1282,7 +1331,8 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
}
}
- // d. create `extract_slice` for `iter_args` for DPS operation if necessary
+ // d. create `extract_slice` for `iter_args` for DPS operation if
+ // necessary
if (auto tiledDestStyleOp =
dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
rewriter.setInsertionPoint(tiledDestStyleOp);
@@ -1334,9 +1384,10 @@ class SliceTrackingListener : public RewriterBase::Listener {
std::optional<FrozenRewritePatternSet> patterns);
SliceTrackingListener() = default;
- /// Adds the given list of operations to the worklist, and if present, applies
- /// the list of `patterns` to the newly added operations. This only processes
- /// the given operations and any newly inserted ones by the pattern set.
+ /// Adds the given list of operations to the worklist, and if present,
+ /// applies the list of `patterns` to the newly added operations. This only
+ /// processes the given operations and any newly inserted ones by the
+ /// pattern set.
LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
/// Add to the new operation worklist if it is an extract_slice.
@@ -1357,7 +1408,8 @@ class SliceTrackingListener : public RewriterBase::Listener {
std::deque<tensor::ExtractSliceOp> worklist;
private:
- /// Optional pattern set to apply when adding new operations to the worklist.
+ /// Optional pattern set to apply when adding new operations to the
+ /// worklist.
std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
};
@@ -1390,8 +1442,9 @@ void SliceTrackingListener::notifyOperationInserted(
worklist.push_back(slice);
}
-// Scan the worklist for the given op and remove it if present. The expectation
-// is for the worklist to be small and for removal to be relatively rare.
+// Scan the worklist for the given op and remove it if present. The
+// expectation is for the worklist to be small and for removal to be
+// relatively rare.
void SliceTrackingListener::removeOp(Operation *op) {
if (!isa<tensor::ExtractSliceOp>(op))
return;
@@ -1445,17 +1498,18 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
auto &loops = tilingResult->loops;
if (loops.empty()) {
DenseMap<Value, Value> replacements;
- for (auto [origVal, replacement] :
- llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
+ for (auto [origVal, replacement] : llvm::zip_equal(
+ consumer->getResults(), tilingResult->mergeResult.replacements)) {
replacements[origVal] = replacement;
}
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
replacements};
}
- // To keep track of replacements for now just record the map from the original
- // untiled value to the result number of the for loop. Since the loop gets
- // potentially replaced during fusion, keeping the value directly wont work.
+ // To keep track of replacements for now just record the map from the
+ // original untiled value to the result number of the for loop. Since the
+ // loop gets potentially replaced during fusion, keeping the value directly
+ // won't work.
DenseMap<Value, size_t> origValToResultNumber;
for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
origValToResultNumber[result] = index;
@@ -1463,11 +1517,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
// 2. Typically, the operands of the tiled operation are slices of the
// operands of the untiled operation. These are expressed in IR using
- // `tensor.extract_slice` operations with source being the operands of the
- // untiled operation. Create a worklist of these `tensor.extract_slice`
- // operations. If the producers of the source of the `tensor.extract_slice`
- // can be tiled such that the tiled value is generated in-place, that
- // effectively tiles + fuses the operations.
+ // `tensor.extract_slice` operations with source being the operands of
+ // the untiled operation. Create a worklist of these
+ // `tensor.extract_slice` operations. If the producers of the source of
+ // the `tensor.extract_slice` can be tiled such that the tiled value is
+ // generated in-place, that effectively tiles + fuses the operations.
struct WorklistItem {
tensor::ExtractSliceOp candidateSlice;
SCFTileAndFuseOptions::ControlFnResult controlFnResult;
@@ -1511,9 +1565,10 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
if (worklistItem.controlFnResult.yieldProducerReplacement) {
- // Reconstruct and yield all opResult of fusableProducerOp by default. The
- // caller can specific which one to yield by designating optional argument
- // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
+ // Reconstruct and yield all opResult of fusableProducerOp by default.
+ // The caller can specify which one to yield by designating the optional
+ // argument named `yieldResultNumber` of
+ // `yieldReplacementForFusedProducer`.
Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
FailureOr<SmallVector<Operation *>> newSlices =
yieldReplacementForFusedProducer(rewriter,
@@ -1582,8 +1637,8 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
return success();
}
-/// An utility to get the first user of the given loopOp. If any of user stay in
-/// different block of loopOp, return failure.
+/// A utility to get the first user of the given loopOp. If any user is in a
+/// different block than loopOp, return failure.
static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
if (!isa<LoopLikeOpInterface>(loopOp))
return failure();
@@ -1616,11 +1671,11 @@ static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
return firstUserOfLoop;
}
-/// This utility currently checks whether the first userOp of loop is NOT before
-/// the last defineOp of consumer operand. Because that we need to move the
-/// whole loop structure right before the `firstUserOfLoop`. This utility thus
-/// helps ensuring that no invalid IR is formed, i.e. no backward slice of
-/// consumerOp is dominated by the `firstUserOfLoop`. Saying that:
+/// This utility currently checks whether the first userOp of loop is NOT
+/// before the last defineOp of consumer operand, because we need to move
+/// the whole loop structure right before the `firstUserOfLoop`. This utility
+/// thus helps ensure that no invalid IR is formed, i.e. no backward slice
+/// of consumerOp is dominated by the `firstUserOfLoop`. For example:
///
/// ```
/// %0 = scf.for() {
@@ -1634,9 +1689,9 @@ static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
/// %3 = consumerOp(%2)
/// ```
///
-/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it would
-/// be invalid to move the `loopOp` right before the `firstUserOfLoop`, a.k.a.
-/// use-def chain violation:
+/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
+/// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
+/// a.k.a. use-def chain violation:
///
/// ```
/// %0:2 = scf.for() {
@@ -1650,10 +1705,10 @@ static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
///
/// @param loopOp: loop operation
/// @param consumerOp: consumer operation
-/// @param reorderOperations: the flag controls whether to reorder the backward
-/// slice w.r.t. the defineOp of `consumerOp` operands.
-/// @return: computed backward slice of consumerOp, but excluding those already
-/// dominates `firstUserOfLoop`.
+/// @param reorderOperations: the flag controls whether to reorder the
+/// backward slice w.r.t. the defineOp of `consumerOp` operands.
+/// @return: computed backward slice of consumerOp, but excluding those
+/// that already dominate `firstUserOfLoop`.
static FailureOr<llvm::SetVector<Operation *>>
checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
bool reorderOperations) {
@@ -1713,8 +1768,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
if (!isa<TilingInterface>(consumerOp) ||
!isa<DestinationStyleOpInterface>(consumerOp)) {
// TODO: We have to init result of consumer before scf.for, use
- // DestinationStyleOpInterface to get result shape from init for now. Add
- // support for other op such as op has InferTypeOpInterface.
+ // DestinationStyleOpInterface to get result shape from init for now.
+ // Add support for other ops, such as ops that have InferTypeOpInterface.
continue;
}
// Step 2. Check if user stay in the same block.
@@ -1729,7 +1784,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
checkAssumptionForLoop(loopOp, consumerOp, true);
if (failed(slice))
continue;
- // Step 5. If backward sice is not empty, move them before firstUserOfLoop.
+ // Step 5. If backward slice is not empty, move them before
+ // firstUserOfLoop.
if (!slice->empty()) {
mlir::topologicalSort(*slice);
FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
@@ -1743,8 +1799,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
return failure();
}
-/// Find the perfectly nested loops outside of given loop(included) sorted from
-/// outer to inner.
+/// Find the perfectly nested loops outside of given loop(included) sorted
+/// from outer to inner.
///
/// E.g.
///
@@ -1997,10 +2053,11 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
}
// 10. Try to get iter domain position from input position. Use
- // clonedConsumerOp instead of tiledConsumerOp, because the iteration domain
- // may require index computation based on the result size. The sizes and
- // offsets should be the same either way, but using tiledConsumerOp could
- // lead to some chained unnecessary extra index computation.
+ // clonedConsumerOp instead of tiledConsumerOp, because the iteration
+ // domain may require index computation based on the result size. The
+ // sizes and offsets should be the same either way, but using
+ // tiledConsumerOp could lead to some chained unnecessary extra index
+ // computation.
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
@@ -2067,7 +2124,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
"unable to add new inits to nest loop");
}
- // 15. Replace the result of scf loop and consumer op with new loop's results.
+ // 15. Replace the result of scf loop and consumer op with new loop's
+ // results.
for (auto &&[oldResult, newResult] : llvm::zip(
consumerOp->getResults(),
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 5e903e378daf82..7380b766935ffe 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -250,7 +250,8 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
return failure();
// Perform the replacement of tiled and fused values.
- rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements);
+ rewriter.replaceOp(tilingInterfaceOp,
+ tiledResults->mergeResult.replacements);
// Report back the relevant handles to the transform op.
tiledOps.push_back(tiledResults->tiledOps.front());
More information about the Mlir-commits
mailing list