[Mlir-commits] [mlir] [mlir][scf] Return `replacements` explicitly in `SCFTilingResult`. (PR #143217)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jun 8 14:14:17 PDT 2025
https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/143217
>From 7fc45f4fa18900ec4302c39d3786a53203ae4dc4 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 5 Jun 2025 15:03:42 -0700
Subject: [PATCH] [mlir][scf] Return `replacements` explicitly in
`SCFTilingResult`.
In #120115 the replacements for the tiled operations were wrapped
within the `MergeResult` object. That obscures where the replacements
can be found after tiling. This change makes `replacements` an
explicit field of `SCFTilingResult` again (as it was before that
change).
It also makes `mergeOps` a separate field of `SCFTilingResult`;
this field is empty when the reduction strategy is `FullReduction`.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
.../SCF/Transforms/TileUsingInterface.h | 16 ++---
.../mlir/Interfaces/TilingInterface.td | 3 +-
.../TransformOps/LinalgTransformOps.cpp | 10 ++--
.../SCF/Transforms/TileUsingInterface.cpp | 58 ++++++++++---------
.../TestTilingInterfaceTransformOps.cpp | 2 +-
5 files changed, 47 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 33a43ce2ee7bb..f686ae07b9a99 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -136,15 +136,17 @@ struct SCFTilingResult {
SmallVector<Value> initialValues;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<LoopLikeOpInterface> loops;
- /// The result generated by the loop nest in tiling, may hold partial results,
- /// which need to be merged to match the computation of the untiled operation.
- /// `mergeResult` contains the operations used to perform this merge from
- /// partial results and the values that can be used as replacements of
- /// the untiled operation.
- MergeResult mergeResult;
+ /// Values to use as replacements for the untiled op. Has the same size as
+ /// the number of results of the untiled op.
+ SmallVector<Value> replacements;
/// Slices generated after tiling that can be used for fusing with the tiled
/// producer.
SmallVector<Operation *> generatedSlices;
+ /// In cases where there is an additional merge step after tiling, the
+ /// ops that perform the merge are returned here. This list is empty when
+ /// the reduction tiling strategy is
+ /// `scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction`.
+ SmallVector<Operation *> mergeOps;
};
/// Method to tile an op that implements the `TilingInterface` using
@@ -362,7 +364,7 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
/// ```
FailureOr<scf::SCFTilingResult>
tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
- ArrayRef<OpFoldResult> tileSize);
+ ArrayRef<OpFoldResult> tileSizes);
} // namespace scf
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 50b69b8f8d833..cdf3d01ce8a84 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -363,7 +363,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
];
}
-def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
+def PartialReductionOpInterface :
+ OpInterface<"PartialReductionOpInterface", [TilingInterface]> {
let description = [{
Interface for allowing operations to expose information needed to
tile reductions using partial reduction followed by merge. This is
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1c3b621828315..b2c28f5eed33c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2381,7 +2381,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
return emitDefaultDefiniteFailure(target);
if (target->getNumResults())
- rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
+ rewriter.replaceOp(target, maybeTilingResult->replacements);
else
rewriter.eraseOp(target);
@@ -2800,12 +2800,12 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
if (failed(result))
return emitDefaultSilenceableFailure(target);
- rewriter.replaceOp(target, result->mergeResult.replacements);
+ rewriter.replaceOp(target, result->replacements);
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
for (auto parallelTiledOp : result->tiledOps)
results.push_back(parallelTiledOp);
- for (auto mergeOp : result->mergeResult.mergeOps)
+ for (auto mergeOp : result->mergeOps)
results.push_back(mergeOp);
results.push_back(result->loops.front());
return DiagnosedSilenceableFailure::success();
@@ -3229,7 +3229,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
if (failed(maybeTilingResult))
return DiagnosedSilenceableFailure::definiteFailure();
- rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
+ rewriter.replaceOp(op, maybeTilingResult->replacements);
tiled.append(maybeTilingResult->tiledOps);
for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
@@ -3465,7 +3465,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
if (failed(maybeTilingResult))
return transformOp.emitDefaultSilenceableFailure(tileableOp);
- rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
+ rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
tilingResult = *maybeTilingResult;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 57ee0f52e7491..a0f9b599d1351 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1058,48 +1058,50 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
assert(succeeded(tilingResult) &&
"expected tiling result to be computed after loop generation");
- SmallVector<Value> partialResults;
if (loops.empty()) {
// If loops are empty, the tiled op is used as the replacement for the
// untiled op.
- partialResults = tilingResult->tiledValues;
- } else {
- partialResults = llvm::map_to_vector(loops.front()->getResults(),
+ return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
+ tilingResult->tiledValues,
+ tilingResult->generatedSlices};
+ }
+
+ auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
[](OpResult r) -> Value { return r; });
+
+ // For the full reduction case, there is nothing more to do.
+ if (options.reductionStrategy ==
+ scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction) {
+ return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
+ loopResults, tilingResult->generatedSlices};
}
+ // The results of the loop need to be merged.
FailureOr<MergeResult> mergeResult =
- mergeTilingResults(rewriter, op, partialResults, options);
+ mergeTilingResults(rewriter, op, loopResults, options);
if (failed(mergeResult)) {
return rewriter.notifyMatchFailure(
op, "Failed to merge partial results from tiling");
}
-
- return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
- mergeResult.value(),
- tilingResult->generatedSlices};
+ return scf::SCFTilingResult{tilingResult->tiledOps,
+ initTensors,
+ loops,
+ mergeResult->replacements,
+ tilingResult->generatedSlices,
+ mergeResult->mergeOps};
}
FailureOr<scf::SCFTilingResult>
mlir::scf::tileReductionUsingScf(RewriterBase &b,
PartialReductionOpInterface op,
- ArrayRef<OpFoldResult> tileSizes) {
- SCFTilingOptions options;
- options.setLoopType(SCFTilingOptions::LoopType::ForOp);
- options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy::
- PartialReductionOuterReduction);
- options.setTileSizes(tileSizes);
-
- TilingInterface tilingInterfaceOp =
- dyn_cast<TilingInterface>(op.getOperation());
- if (!tilingInterfaceOp) {
- return b.notifyMatchFailure(
- op,
- "Operation implementing PartialReductionOpInterface should implement "
- "TilingInterface");
- }
-
- return tileUsingSCF(b, tilingInterfaceOp, options);
+ ArrayRef<OpFoldResult> tileSize) {
+ scf::SCFTilingOptions options;
+ options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
+ options.setReductionTilingStrategy(
+ scf::SCFTilingOptions::ReductionTilingStrategy::
+ PartialReductionOuterReduction);
+ options.setTileSizes(tileSize);
+ return tileUsingSCF(b, op, options);
}
//===----------------------------------------------------------------------===//
@@ -1539,8 +1541,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
tiledAndFusedOps.insert_range(tilingResult->tiledOps);
DenseMap<Value, Value> replacements;
- for (auto [origVal, replacement] : llvm::zip_equal(
- consumer->getResults(), tilingResult->mergeResult.replacements)) {
+ for (auto [origVal, replacement] :
+ llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
replacements[origVal] = replacement;
}
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 45d6ae3820159..de32e6a59b4c0 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -261,7 +261,7 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
// Perform the replacement of tiled and fused values.
rewriter.replaceOp(tilingInterfaceOp,
- tiledResults->mergeResult.replacements);
+ tiledResults->replacements);
// Report back the relevant handles to the transform op.
tiledOps.push_back(tiledResults->tiledOps.front());
More information about the Mlir-commits
mailing list