[Mlir-commits] [mlir] [mlir][scf] Return `replacements` explicitly in `SCFTilingResult`. (PR #143217)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jun 8 14:14:17 PDT 2025


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/143217

>From 7fc45f4fa18900ec4302c39d3786a53203ae4dc4 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 5 Jun 2025 15:03:42 -0700
Subject: [PATCH] [mlir][scf] Return `replacements` explicitly in
 `SCFTilingResult`.

In #120115 the replacements for the tiled operations were wrapped
within the `MergeResult` object. That is a bit of an obfuscation, and
it is not immediately obvious where to get the replacements after
tiling. This change makes `replacements` an explicit field of
`SCFTilingResult` (as it was before that change).

It also makes the `mergeOps` a separate field of `SCFTilingResult`,
which is empty when the reduction type is `FullReduction`.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../SCF/Transforms/TileUsingInterface.h       | 16 ++---
 .../mlir/Interfaces/TilingInterface.td        |  3 +-
 .../TransformOps/LinalgTransformOps.cpp       | 10 ++--
 .../SCF/Transforms/TileUsingInterface.cpp     | 58 ++++++++++---------
 .../TestTilingInterfaceTransformOps.cpp       |  2 +-
 5 files changed, 47 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 33a43ce2ee7bb..f686ae07b9a99 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -136,15 +136,17 @@ struct SCFTilingResult {
   SmallVector<Value> initialValues;
   /// The `scf.for` operations that iterate over the tiles.
   SmallVector<LoopLikeOpInterface> loops;
-  /// The result generated by the loop nest in tiling, may hold partial results,
-  /// which need to be merged to match the computation of the untiled operation.
-  /// `mergeResult` contains the operations used to perform this merge from
-  /// partial results and the values that can be used as replacements of
-  /// the untiled operation.
-  MergeResult mergeResult;
+  /// Values to use as replacements for the untiled op. Is the same size as the
+  /// number of results of the untiled op.
+  SmallVector<Value> replacements;
   /// Slices generated after tiling that can be used for fusing with the tiled
   /// producer.
   SmallVector<Operation *> generatedSlices;
+  /// In cases where there is an additional merge step after tiling,
+  /// return the merged ops after tiling. This list is empty when reduction
+  /// tiling strategy is
+  /// `scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction`.
+  SmallVector<Operation *> mergeOps;
 };
 
 /// Method to tile an op that implements the `TilingInterface` using
@@ -362,7 +364,7 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
 /// ```
 FailureOr<scf::SCFTilingResult>
 tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
-                      ArrayRef<OpFoldResult> tileSize);
+                      ArrayRef<OpFoldResult> tileSizes);
 
 } // namespace scf
 } // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 50b69b8f8d833..cdf3d01ce8a84 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -363,7 +363,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
   ];
 }
 
-def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
+def PartialReductionOpInterface :
+    OpInterface<"PartialReductionOpInterface", [TilingInterface]> {
   let description = [{
     Interface for allowing operations to expose information needed to
     tile reductions using partial reduction followed by merge. This is
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1c3b621828315..b2c28f5eed33c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2381,7 +2381,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
     return emitDefaultDefiniteFailure(target);
 
   if (target->getNumResults())
-    rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
+    rewriter.replaceOp(target, maybeTilingResult->replacements);
   else
     rewriter.eraseOp(target);
 
@@ -2800,12 +2800,12 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
 
   if (failed(result))
     return emitDefaultSilenceableFailure(target);
-  rewriter.replaceOp(target, result->mergeResult.replacements);
+  rewriter.replaceOp(target, result->replacements);
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
   for (auto parallelTiledOp : result->tiledOps)
     results.push_back(parallelTiledOp);
-  for (auto mergeOp : result->mergeResult.mergeOps)
+  for (auto mergeOp : result->mergeOps)
     results.push_back(mergeOp);
   results.push_back(result->loops.front());
   return DiagnosedSilenceableFailure::success();
@@ -3229,7 +3229,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
     if (failed(maybeTilingResult))
       return DiagnosedSilenceableFailure::definiteFailure();
 
-    rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
+    rewriter.replaceOp(op, maybeTilingResult->replacements);
 
     tiled.append(maybeTilingResult->tiledOps);
     for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
@@ -3465,7 +3465,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
   if (failed(maybeTilingResult))
     return transformOp.emitDefaultSilenceableFailure(tileableOp);
 
-  rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
+  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
 
   tilingResult = *maybeTilingResult;
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 57ee0f52e7491..a0f9b599d1351 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1058,48 +1058,50 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   assert(succeeded(tilingResult) &&
          "expected tiling result to be computed after loop generation");
 
-  SmallVector<Value> partialResults;
   if (loops.empty()) {
     // If loops are empty, the tiled op is used as the replacement for the
     // untiled op.
-    partialResults = tilingResult->tiledValues;
-  } else {
-    partialResults = llvm::map_to_vector(loops.front()->getResults(),
+    return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
+                                tilingResult->tiledValues,
+                                tilingResult->generatedSlices};
+  }
+
+  auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
                                          [](OpResult r) -> Value { return r; });
+
+  // For the full reduction case, there is nothing more to do.
+  if (options.reductionStrategy ==
+      scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction) {
+    return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
+                                loopResults, tilingResult->generatedSlices};
   }
 
+  // The results of the loop need to be merged.
   FailureOr<MergeResult> mergeResult =
-      mergeTilingResults(rewriter, op, partialResults, options);
+      mergeTilingResults(rewriter, op, loopResults, options);
   if (failed(mergeResult)) {
     return rewriter.notifyMatchFailure(
         op, "Failed to merge partial results from tiling");
   }
-
-  return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
-                              mergeResult.value(),
-                              tilingResult->generatedSlices};
+  return scf::SCFTilingResult{tilingResult->tiledOps,
+                              initTensors,
+                              loops,
+                              mergeResult->replacements,
+                              tilingResult->generatedSlices,
+                              mergeResult->mergeOps};
 }
 
 FailureOr<scf::SCFTilingResult>
 mlir::scf::tileReductionUsingScf(RewriterBase &b,
                                  PartialReductionOpInterface op,
-                                 ArrayRef<OpFoldResult> tileSizes) {
-  SCFTilingOptions options;
-  options.setLoopType(SCFTilingOptions::LoopType::ForOp);
-  options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy::
-                                         PartialReductionOuterReduction);
-  options.setTileSizes(tileSizes);
-
-  TilingInterface tilingInterfaceOp =
-      dyn_cast<TilingInterface>(op.getOperation());
-  if (!tilingInterfaceOp) {
-    return b.notifyMatchFailure(
-        op,
-        "Operation implementing PartialReductionOpInterface should implement "
-        "TilingInterface");
-  }
-
-  return tileUsingSCF(b, tilingInterfaceOp, options);
+                                 ArrayRef<OpFoldResult> tileSize) {
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
+  options.setReductionTilingStrategy(
+      scf::SCFTilingOptions::ReductionTilingStrategy::
+          PartialReductionOuterReduction);
+  options.setTileSizes(tileSize);
+  return tileUsingSCF(b, op, options);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1539,8 +1541,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
   tiledAndFusedOps.insert_range(tilingResult->tiledOps);
 
   DenseMap<Value, Value> replacements;
-  for (auto [origVal, replacement] : llvm::zip_equal(
-           consumer->getResults(), tilingResult->mergeResult.replacements)) {
+  for (auto [origVal, replacement] :
+       llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
     replacements[origVal] = replacement;
   }
 
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 45d6ae3820159..de32e6a59b4c0 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -261,7 +261,7 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
 
     // Perform the replacement of tiled and fused values.
     rewriter.replaceOp(tilingInterfaceOp,
-                       tiledResults->mergeResult.replacements);
+                       tiledResults->replacements);
 
     // Report back the relevant handles to the transform op.
     tiledOps.push_back(tiledResults->tiledOps.front());



More information about the Mlir-commits mailing list