[Mlir-commits] [mlir] 4b56345 - [mlir][SCF] Unify tileUsingFor and tileReductionUsingFor implementation (#120115)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 18 05:24:51 PST 2024


Author: Kunwar Grover
Date: 2024-12-18T13:24:47Z
New Revision: 4b56345895729fda3bc3c094bc3f237ba3a49686

URL: https://github.com/llvm/llvm-project/commit/4b56345895729fda3bc3c094bc3f237ba3a49686
DIFF: https://github.com/llvm/llvm-project/commit/4b56345895729fda3bc3c094bc3f237ba3a49686.diff

LOG: [mlir][SCF] Unify tileUsingFor and tileReductionUsingFor implementation (#120115)

This patch unifies the tiling implementation for tileUsingFor and
tileReductionUsingFor. This is done by passing an additional option to
SCFTilingOptions, allowing it to set how reduction dimensions should be
tiled. Currently, there are 3 different options for reduction tiling:
FullReduction (old tileUsingFor), PartialReductionOuterReduction (old
tileReductionUsingFor) and PartialReductionOuterParallel
(linalg::tileReductionUsingForall, this isn't implemented in this
patch).

The patch makes tileReductionUsingFor use the tileUsingFor
implementation with the new reduction tiling options.

There are no test changes because the implementation was doing almost
exactly the same thing. This was also tested in IREE (which uses both
these APIs heavily) and there were no test changes.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 9f5f9f3fca97ad..d2cddfe00ac78e 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,6 +85,36 @@ struct SCFTilingOptions {
     return *this;
   }
 
+  /// Specify how reduction dimensions should be tiled.
+  ///
+  /// Tiling can be thought of as splitting a dimension into 2 and materializing
+  /// the outer dimension as a loop:
+  ///
+  /// op[original] -> op[original / x, x] -> loop[original] { op[x] }
+  ///
+  /// For parallel dimensions, the split can only happen in one way, with both
+  /// dimensions being parallel. For reduction dimensions however, there is a
+  /// choice in how we split the reduction dimension. This enum exposes this
+  /// choice.
+  enum class ReductionTilingStrategy {
+    // [reduction] -> [reduction1, reduction2]
+    // -> loop[reduction1] { [reduction2] }
+    FullReduction,
+    // [reduction] -> [reduction1, parallel2]
+    // -> loop[reduction1] { [parallel2] }; merge[reduction1]
+    PartialReductionOuterReduction,
+    // [reduction] -> [parallel1, reduction2]
+    // -> loop[parallel1] { [reduction2] }; merge[parallel1]
+    PartialReductionOuterParallel
+  };
+  ReductionTilingStrategy reductionStrategy =
+      ReductionTilingStrategy::FullReduction;
+  SCFTilingOptions &
+  setReductionTilingStrategy(ReductionTilingStrategy strategy) {
+    reductionStrategy = strategy;
+    return *this;
+  }
+
   /// Specify mapping of loops to devices. This is only respected when the loop
   /// constructs support such a mapping (like `scf.forall`). Will be ignored
   /// when using loop constructs that dont support such a mapping (like
@@ -102,11 +132,16 @@ struct SCFTilingResult {
   /// matter except the last op. The replacements are expected to be the results
   /// of the last op.
   SmallVector<Operation *> tiledOps;
+  /// The initial destination values passed to the tiled operations.
+  SmallVector<Value> initialValues;
   /// The `scf.for` operations that iterate over the tiles.
   SmallVector<LoopLikeOpInterface> loops;
-  /// Values to use as replacements for the untiled op. Is the same size as the
-  /// number of results of the untiled op.
-  SmallVector<Value> replacements;
+  /// The result generated by the loop nest in tiling, may hold partial results,
+  /// which need to be merged to match the computation of the untiled operation.
+  /// `mergeResult` contains the operations used to perform this merge from
+  /// partial results and the values that can be used as replacements of
+  /// the untiled operation.
+  MergeResult mergeResult;
   /// Slices generated after tiling that can be used for fusing with the tiled
   /// producer.
   SmallVector<Operation *> generatedSlices;
@@ -300,20 +335,6 @@ tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
 FailureOr<SmallVector<scf::ForOp>>
 lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
 
-/// Transformation information returned after reduction tiling.
-struct SCFReductionTilingResult {
-  /// The partial reduction tiled op generated.
-  SmallVector<Operation *> parallelTiledOps;
-  /// The final reduction operation merging all the partial reductions.
-  SmallVector<Operation *> mergeOps;
-  /// Initial values used for reduction.
-  SmallVector<Value> initialValues;
-  /// The loop operations that iterate over the tiles.
-  SmallVector<LoopLikeOpInterface> loops;
-  /// The replacements to use for the results of the tiled operation.
-  SmallVector<Value> replacements;
-};
-
 /// Method to tile a reduction and generate a parallel op within a serial loop.
 /// Each of the partial reductions are calculated in parallel. Then after the
 /// loop all the partial reduction are merged into a final reduction.
@@ -338,7 +359,7 @@ struct SCFReductionTilingResult {
 /// %6 = linalg.generic %1 ["parallel", "reduction"]
 ///   : tensor<7x4xf32> -> tensor<7xf32>
 /// ```
-FailureOr<scf::SCFReductionTilingResult>
+FailureOr<scf::SCFTilingResult>
 tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
                       ArrayRef<OpFoldResult> tileSize);
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8397652d1d8a8a..18fd24da395b76 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2223,7 +2223,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
     return emitDefaultDefiniteFailure(target);
 
   if (target->getNumResults())
-    rewriter.replaceOp(target, maybeTilingResult->replacements);
+    rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
   else
     rewriter.eraseOp(target);
 
@@ -2630,17 +2630,18 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
+  FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
       rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
 
   if (failed(result))
     return emitDefaultSilenceableFailure(target);
+  rewriter.replaceOp(target, result->mergeResult.replacements);
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
-  for (auto parallelTiledOp : result->parallelTiledOps)
+  for (auto parallelTiledOp : result->tiledOps)
     results.push_back(parallelTiledOp);
-  for (auto mergeOp : result->mergeOps)
+  for (auto mergeOp : result->mergeResult.mergeOps)
     results.push_back(mergeOp);
   results.push_back(result->loops.front());
   return DiagnosedSilenceableFailure::success();
@@ -3064,7 +3065,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
     if (failed(maybeTilingResult))
       return DiagnosedSilenceableFailure::definiteFailure();
 
-    rewriter.replaceOp(op, maybeTilingResult->replacements);
+    rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
 
     tiled.append(maybeTilingResult->tiledOps);
     for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
@@ -3303,7 +3304,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
   if (failed(maybeTilingResult))
     return transformOp.emitDefaultSilenceableFailure(tileableOp);
 
-  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
+  rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
 
   tilingResult = *maybeTilingResult;
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 6a4a6b43933806..ef5d4370e78102 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -570,6 +570,144 @@ static LogicalResult generateLoopNest(
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
 }
 
+/// Create the initial values of the loop-carried (`iter_args`) tensors of the
+/// tiled loop nest, according to `options.reductionStrategy`:
+///
+/// - FullReduction: the destination tensors of `op`, as returned by
+///   `tensor::getOrCreateDestinations`.
+/// - PartialReductionOuterReduction: the tensors returned by
+///   `PartialReductionOpInterface::generateInitialTensorForPartialReduction`
+///   for `tileSizes` and the reduction dimensions of `op`.
+///
+/// Fails if the tensors cannot be created, if `op` does not implement
+/// `PartialReductionOpInterface` when the strategy requires it, or if the
+/// strategy is not handled here.
+static FailureOr<SmallVector<Value>>
+createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
+                              ArrayRef<OpFoldResult> tileSizes,
+                              const scf::SCFTilingOptions &options) {
+  SmallVector<Value> initTensors;
+  Location loc = op->getLoc();
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
+      return failure();
+    return initTensors;
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!redOp) {
+      // Adjacent literals concatenate with no separator, so each piece but
+      // the last must end in a space.
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations "
+              "implementing PartialReductionOpInterface");
+    }
+    // Get reduction dimensions.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<int> reductionDims;
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(op.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+    return redOp.generateInitialTensorForPartialReduction(
+        rewriter, loc, tileSizes, reductionDims);
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+/// Materialize the tiled computation of `op` for the iteration-space tile
+/// given by `offsets`/`sizes`, following `options.reductionStrategy`.
+///
+/// FullReduction forwards to `TilingInterface::getTiledImplementation`.
+/// PartialReductionOuterReduction forwards to
+/// `PartialReductionOpInterface::tileToPartialReduction`, passing the
+/// loop-carried values `regionIterArg` and the reduction dimensions of `op`.
+/// Every other strategy is reported as a match failure.
+static FailureOr<TilingResult>
+getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
+                       ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
+                       ArrayRef<OpFoldResult> sizes,
+                       const scf::SCFTilingOptions &options) {
+  using Strategy = scf::SCFTilingOptions::ReductionTilingStrategy;
+  const Strategy strategy = options.reductionStrategy;
+
+  if (strategy == Strategy::FullReduction)
+    return op.getTiledImplementation(rewriter, offsets, sizes);
+
+  if (strategy != Strategy::PartialReductionOuterReduction) {
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+
+  auto partialRedOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!partialRedOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only "
+            "supported for operations "
+            "implementing PartialReductionOpInterface");
+  }
+
+  // Collect the loop dimensions whose iterator type is `reduction`.
+  // TODO: PartialReductionOpInterface should really query TilingInterface
+  // itself and find reduction dimensions.
+  SmallVector<utils::IteratorType> iteratorTypes = op.getLoopIteratorTypes();
+  SmallVector<int> reductionDims;
+  for (int dim = 0, numDims = iteratorTypes.size(); dim < numDims; ++dim) {
+    if (iteratorTypes[dim] == utils::IteratorType::reduction)
+      reductionDims.push_back(dim);
+  }
+
+  return partialRedOp.tileToPartialReduction(rewriter, op.getLoc(),
+                                             regionIterArg, offsets, sizes,
+                                             reductionDims);
+}
+
+/// Compute where the tile computed for result `index` is inserted into the
+/// corresponding loop-carried tensor. The insertion offsets are written to
+/// `resultOffset` and the sizes to `resultSize`.
+///
+/// - FullReduction: delegates to `TilingInterface::getResultTilePosition`
+///   with the iteration-space tile `offsets`/`sizes`.
+/// - PartialReductionOuterReduction: uses `offsets.size()` zero offsets and
+///   the size of `tiledResult` along each of its first `offsets.size()`
+///   dimensions, i.e. the whole partial-result tile is written at the origin.
+///
+/// Any other strategy is reported as a match failure.
+static LogicalResult
+getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
+                      TilingInterface op, ArrayRef<OpFoldResult> offsets,
+                      ArrayRef<OpFoldResult> sizes,
+                      SmallVector<OpFoldResult> &resultOffset,
+                      SmallVector<OpFoldResult> &resultSize,
+                      const scf::SCFTilingOptions &options) {
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    return op.getResultTilePosition(rewriter, index, offsets, sizes,
+                                    resultOffset, resultSize);
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    // TODO: This does not work for non identity accesses to the result tile.
+    // The proper fix is to add a getPartialResultTilePosition method to
+    // PartialReductionOpInterface.
+    resultOffset =
+        SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
+    // Overwrite (not append to) the sizes so that, like `resultOffset`, the
+    // output does not depend on the caller passing an empty vector.
+    resultSize.clear();
+    resultSize.reserve(offsets.size());
+    for (size_t i = 0; i < offsets.size(); i++) {
+      resultSize.push_back(
+          tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
+    }
+    return success();
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+/// Turn the values produced by the tiled loop nest (`partialResults`) into
+/// values that can replace the results of the untiled `op`.
+///
+/// Under FullReduction the loop nest already yields final values, so they are
+/// returned as the replacements and no merge operations are created. Under
+/// PartialReductionOuterReduction the partial values are combined with
+/// `PartialReductionOpInterface::mergeReductions` over the reduction
+/// dimensions of `op`. Every other strategy is reported as a match failure.
+static FailureOr<MergeResult>
+mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
+                   ValueRange partialResults,
+                   const scf::SCFTilingOptions &options) {
+  using Strategy = scf::SCFTilingOptions::ReductionTilingStrategy;
+  const Strategy strategy = options.reductionStrategy;
+
+  // Full-reduction tiling needs no merge step: pass the loop results through.
+  if (strategy == Strategy::FullReduction)
+    return MergeResult{/*mergeOps=*/{}, /*replacements=*/partialResults};
+
+  if (strategy != Strategy::PartialReductionOuterReduction) {
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+
+  auto partialRedOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!partialRedOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only "
+            "supported for operations "
+            "implementing PartialReductionOpInterface");
+  }
+
+  // Collect the loop dimensions whose iterator type is `reduction`.
+  // TODO: PartialReductionOpInterface should really query TilingInterface
+  // itself and find reduction dimensions.
+  SmallVector<utils::IteratorType> iteratorTypes = op.getLoopIteratorTypes();
+  SmallVector<int> reductionDims;
+  for (int dim = 0, numDims = iteratorTypes.size(); dim < numDims; ++dim) {
+    if (iteratorTypes[dim] == utils::IteratorType::reduction)
+      reductionDims.push_back(dim);
+  }
+
+  return partialRedOp.mergeReductions(rewriter, op.getLoc(), partialResults,
+                                      reductionDims);
+}
+
 /// Append the specified additional `newInitOperands` operands to the
 /// loops existing `init` operands (or similar), and replace `loopOp` with
 /// the new loop that has the additional init operands. The loop body of
@@ -710,11 +848,11 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
       });
 }
 
-/// Method to add new init values to a loop nest. Updates `loops` in-place with
-/// new loops that use the `newInitValues`.
-/// The outer-loops are updated to yield the new result values of the inner
-/// loop. For the innermost loop, the call back `getNewYields` is invoked to get
-/// the additional values to yield form the innermost loop.
+/// Method to add new init values to a loop nest. Updates `loops` in-place
+/// with new loops that use the `newInitValues`. The outer-loops are updated
+/// to yield the new result values of the inner loop. For the innermost loop,
+/// the callback `getNewTiledYieldsFn` is invoked to get the additional values
+/// to yield from the innermost loop.
 static LogicalResult addInitOperandsToLoopNest(
     RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
     ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
@@ -852,9 +990,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     auto clonedOp = cast<TilingInterface>(
         cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
 
-    // 5b. Early return cloned op if tiling is not happening. We can not return
-    // the original op because it could lead to
-    // `rewriter.replaceOp(op, op->getResults())` and users would get crash.
+    // 5b. Early return cloned op if tiling is not happening. We can not
+    // return the original op because it could lead to `rewriter.replaceOp(op,
+    // op->getResults())` and users would get crash.
     if (llvm::all_of(tileSizes, isZeroIndex)) {
       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
       tilingResult =
@@ -864,7 +1002,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     }
 
     // 5c. Tile the cloned operation.
-    tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
+    tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
+                                          offsets, sizes, options);
     if (failed(tilingResult)) {
       rewriter.eraseOp(clonedOp);
       return op.emitOpError("faild to tile operation");
@@ -879,8 +1018,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
          llvm::enumerate(tilingResult->tiledValues)) {
       tiledResults.push_back(tiledValue);
       SmallVector<OpFoldResult> resultOffset, resultSize;
-      if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
-                                          resultOffset, resultSize))) {
+      if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
+                                       sizes, resultOffset, resultSize,
+                                       options))) {
         for (auto op : tilingResult->tiledOps) {
           rewriter.eraseOp(op);
         }
@@ -895,158 +1035,65 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   };
 
   // 6. Find the destination tensors to use for the operation.
-  SmallVector<Value> destinationTensors;
-  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
-                                             destinationTensors))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "unable to create destination tensors");
+  FailureOr<SmallVector<Value>> maybeInits =
+      createInitialTensorsForTiling(rewriter, op, tileSizes, options);
+  if (failed(maybeInits)) {
+    return rewriter.notifyMatchFailure(
+        op, "unable to create initial tensors for tiling");
   }
+  SmallVector<Value> &initTensors = maybeInits.value();
 
   // 7. Generate the tiled loops nest using the callback defined above.
   SmallVector<LoopLikeOpInterface> loops;
   if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
-                              tileSizes, numThreads, destinationTensors,
+                              tileSizes, numThreads, initTensors,
                               innerYieldTiledValuesFn, loops)))
     return op.emitOpError("failed to generate tiling loops");
   assert(succeeded(tilingResult) &&
          "expected tiling result to be computed after loop generation");
 
-  // If loops are empty, the tiled op is used as the replacement for the untiled
-  // op.
+  SmallVector<Value> partialResults;
   if (loops.empty()) {
-    return scf::SCFTilingResult{tilingResult->tiledOps, loops,
-                                tilingResult->tiledValues,
-                                tilingResult->generatedSlices};
+    // If loops are empty, the tiled op is used as the replacement for the
+    // untiled op.
+    partialResults = tilingResult->tiledValues;
+  } else {
+    partialResults = llvm::map_to_vector(loops.front()->getResults(),
+                                         [](OpResult r) -> Value { return r; });
   }
 
-  SmallVector<Value> replacements = llvm::map_to_vector(
-      loops.front()->getResults(), [](OpResult r) -> Value { return r; });
-  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
+  FailureOr<MergeResult> mergeResult =
+      mergeTilingResults(rewriter, op, partialResults, options);
+  if (failed(mergeResult)) {
+    return rewriter.notifyMatchFailure(
+        op, "Failed to merge partial results from tiling");
+  }
+
+  return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
+                              mergeResult.value(),
                               tilingResult->generatedSlices};
 }
 
-FailureOr<scf::SCFReductionTilingResult>
+FailureOr<scf::SCFTilingResult>
 mlir::scf::tileReductionUsingScf(RewriterBase &b,
                                  PartialReductionOpInterface op,
                                  ArrayRef<OpFoldResult> tileSizes) {
-  Location loc = op.getLoc();
-  // Ops implementing PartialReductionOpInterface are expected to implement
-  // TilingInterface.
-  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
-  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
-  auto tileSizesVector = llvm::to_vector(tileSizes);
-  if (tileSizesVector.size() < iterationDomain.size()) {
-    auto zero = b.getIndexAttr(0);
-    tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
-                           zero);
-  }
-  SmallVector<utils::IteratorType> iterators =
-      tilingInterfaceOp.getLoopIteratorTypes();
-
-  SmallVector<int> reductionDims;
-  for (auto [idx, iteratorType] :
-       llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
-    if (iteratorType == utils::IteratorType::reduction)
-      reductionDims.push_back(idx);
-  }
-
-  // 2. create the inital tensor value.
-  FailureOr<SmallVector<Value>> maybeInitTensors =
-      op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
-                                                  reductionDims);
-  if (failed(maybeInitTensors)) {
-    return b.notifyMatchFailure(op, "Failed to create initial tensors.");
-  }
-  SmallVector<Value> &initTensors = maybeInitTensors.value();
-
-  // 3. Define the callback to use for generating the inner most tile loop body.
-  SmallVector<Operation *> parallelTiledOps;
-  auto innerYieldTiledValuesFn =
-      [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
-          ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
-          SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
-          SmallVector<SmallVector<OpFoldResult>> &resultSizes)
-      -> LogicalResult {
-    SmallVector<OpFoldResult> offsets, sizes;
-    {
-      int materializedLoopNum = 0;
-      for (auto [tileSize, loopRange] :
-           llvm::zip_equal(tileSizesVector, iterationDomain)) {
-        if (isConstantIntValue(tileSize, 0)) {
-          offsets.push_back(loopRange.offset);
-          sizes.push_back(loopRange.size);
-          continue;
-        }
-        Value iv = ivs[materializedLoopNum++];
-        offsets.push_back(iv);
-        sizes.push_back(
-            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
-      }
-    }
-
-    // 4a. Clone the operation.
-    {
-      auto clonedOp = cast<PartialReductionOpInterface>(
-          cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
-
-      // 4b. Tile the cloned operation.
-      FailureOr<TilingResult> partialTilingResult =
-          clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
-                                          sizes, reductionDims);
-      if (failed(partialTilingResult)) {
-        return failure();
-      }
-      std::swap(parallelTiledOps, partialTilingResult->tiledOps);
-      std::swap(tiledResult, partialTilingResult->tiledValues);
-
-      // 4c. Delete the cloned operation.
-      b.eraseOp(clonedOp);
-    }
-
-    // 4d. Compute the offsets and sizes needed to insert the result of the
-    // tiled value back into destination before yielding the destination.
-    for (auto result : tiledResult) {
-      SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-      resultOffsets.emplace_back(std::move(outOffsets));
-
-      SmallVector<OpFoldResult> outSizes;
-      for (size_t i = 0; i < offsets.size(); i++) {
-        outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
-      }
-      resultSizes.emplace_back(std::move(outSizes));
-    }
-    return success();
-  };
-
-  // 5. Generate the tiled implementation using the destination tensors.
-  SmallVector<LoopLikeOpInterface> loops;
-  scf::SCFTilingOptions options;
-  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
-  if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
-                              /*numThreads=*/ArrayRef<OpFoldResult>{},
-                              initTensors, innerYieldTiledValuesFn, loops)))
-    return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
-
-  SmallVector<Value> replacements = llvm::map_to_vector(
-      loops.front()->getResults(), [](OpResult r) -> Value { return r; });
-
-  // 5. Apply the merge reduction to combine all the partial values.
-  b.setInsertionPointAfter(*loops.begin());
-  FailureOr<MergeResult> mergeResult =
-      op.mergeReductions(b, loc, replacements, reductionDims);
-  if (failed(mergeResult)) {
-    return failure();
-  }
-  b.replaceOp(op, mergeResult->replacements);
-
-  SCFReductionTilingResult reductionTilingResult;
-  std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
-  std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
-  std::swap(reductionTilingResult.initialValues, initTensors);
-  std::swap(reductionTilingResult.loops, loops);
-  std::swap(reductionTilingResult.replacements, mergeResult->replacements);
-
-  return reductionTilingResult;
+  // Reuse the generic tiling driver: tile with scf.for loops and tile the
+  // reduction dimensions as partial reductions that are merged after the
+  // loop nest (see ReductionTilingStrategy::PartialReductionOuterReduction).
+  SCFTilingOptions options;
+  options.setLoopType(SCFTilingOptions::LoopType::ForOp);
+  options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy::
+                                         PartialReductionOuterReduction);
+  options.setTileSizes(tileSizes);
+
+  // tileUsingSCF operates on TilingInterface, so `op` must implement it in
+  // addition to PartialReductionOpInterface.
+  TilingInterface tilingInterfaceOp =
+      dyn_cast<TilingInterface>(op.getOperation());
+  if (!tilingInterfaceOp) {
+    return b.notifyMatchFailure(
+        op,
+        "Operation implementing PartialReductionOpInterface should implement "
+        "TilingInterface");
+  }
+
+  // NOTE: unlike the removed implementation, `op` is not replaced here;
+  // callers are expected to replace it with `mergeResult.replacements` of the
+  // returned SCFTilingResult (as TileReductionUsingForOp::applyToOne does).
+  return tileUsingSCF(b, tilingInterfaceOp, options);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1055,9 +1102,10 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
 
 /// Return the untiled producer whose slice is used in a tiled consumer. The
 /// method traverses the tile loop nest (`loops`) if needed, and returns the
-/// `iter_args` of the outer most that is encountered. Traversing the iter_args
-/// indicates that this is a destination operand of the consumer. If there was
-/// no loop traversal needed, the second value of the returned tuple is empty.
+/// `iter_args` of the outer most that is encountered. Traversing the
+/// iter_args indicates that this is a destination operand of the consumer. If
+/// there was no loop traversal needed, the second value of the returned tuple
+/// is empty.
 static std::tuple<OpResult, std::optional<OpOperand *>>
 getUntiledProducerFromSliceSource(OpOperand *source,
                                   ArrayRef<LoopLikeOpInterface> loops) {
@@ -1115,8 +1163,8 @@ mlir::scf::tileAndFuseProducerOfSlice(
   Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
       rewriter, fusableProducerOp, clonedOpDestinationTensors);
   // 2d. Update the source of the candidateSlice to be the cloned producer.
-  //     Easier to just clone the slice with different source since replacements
-  //     and DCE of cloned ops becomes easier
+  //     Easier to just clone the slice with different source since
+  //     replacements and DCE of cloned ops becomes easier
   SmallVector<Value> candidateSliceOpOperands =
       llvm::to_vector(candidateSliceOp->getOperands());
   candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
@@ -1250,13 +1298,13 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
         failed(tilableOp.getIterationDomainTileFromResultTile(
             rewriter, sliceResultNumber, sliceOffset, sliceSizes,
             iterDomainOffset, iterDomainSizes))) {
-      // In theory, it is unnecessary to raise an error here. Actually although
-      // it fails to reconstruct the result tensor, it should not broke current
-      // fusion anyway. The reason why we must return failure currently is that
-      // the callback function `newYieldValuesFn` will be called after new init
-      // operand(s) has already been appended. It will take more refactoring to
-      // make sure the init operands are added consistently in the future. For
-      // more details, please refer to:
+      // In theory, it is unnecessary to raise an error here. Actually
+      // although it fails to reconstruct the result tensor, it should not
+      // broke current fusion anyway. The reason why we must return failure
+      // currently is that the callback function `newYieldValuesFn` will be
+      // called after new init operand(s) has already been appended. It will
+      // take more refactoring to make sure the init operands are added
+      // consistently in the future. For more details, please refer to:
       // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
       return failure();
     }
@@ -1282,7 +1330,8 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
       }
     }
 
-    // d. create `extract_slice` for `iter_args` for DPS operation if necessary
+    // d. create `extract_slice` for `iter_args` for DPS operation if
+    // necessary
     if (auto tiledDestStyleOp =
             dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
       rewriter.setInsertionPoint(tiledDestStyleOp);
@@ -1334,9 +1383,10 @@ class SliceTrackingListener : public RewriterBase::Listener {
       std::optional<FrozenRewritePatternSet> patterns);
   SliceTrackingListener() = default;
 
-  /// Adds the given list of operations to the worklist, and if present, applies
-  /// the list of `patterns` to the newly added operations. This only processes
-  /// the given operations and any newly inserted ones by the pattern set.
+  /// Adds the given list of operations to the worklist, and if present,
+  /// applies the list of `patterns` to the newly added operations. This only
+  /// processes the given operations and any newly inserted ones by the
+  /// pattern set.
   LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
 
   /// Add to the new operation worklist if it is an extract_slice.
@@ -1357,7 +1407,8 @@ class SliceTrackingListener : public RewriterBase::Listener {
   std::deque<tensor::ExtractSliceOp> worklist;
 
 private:
-  /// Optional pattern set to apply when adding new operations to the worklist.
+  /// Optional pattern set to apply when adding new operations to the
+  /// worklist.
   std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
 };
 
@@ -1390,8 +1441,9 @@ void SliceTrackingListener::notifyOperationInserted(
   worklist.push_back(slice);
 }
 
-// Scan the worklist for the given op and remove it if present. The expectation
-// is for the worklist to be small and for removal to be relatively rare.
+// Scan the worklist for the given op and remove it if present. The
+// expectation is for the worklist to be small and for removal to be
+// relatively rare.
 void SliceTrackingListener::removeOp(Operation *op) {
   if (!isa<tensor::ExtractSliceOp>(op))
     return;
@@ -1445,17 +1497,18 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
   auto &loops = tilingResult->loops;
   if (loops.empty()) {
     DenseMap<Value, Value> replacements;
-    for (auto [origVal, replacement] :
-         llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
+    for (auto [origVal, replacement] : llvm::zip_equal(
+             consumer->getResults(), tilingResult->mergeResult.replacements)) {
       replacements[origVal] = replacement;
     }
     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
                                      replacements};
   }
 
-  // To keep track of replacements for now just record the map from the original
-  // untiled value to the result number of the for loop. Since the loop gets
-  // potentially replaced during fusion, keeping the value directly wont work.
+  // To keep track of replacements for now just record the map from the
+  // original untiled value to the result number of the for loop. Since the
+  // loop gets potentially replaced during fusion, keeping the value directly
+  // won't work.
   DenseMap<Value, size_t> origValToResultNumber;
   for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
     origValToResultNumber[result] = index;
@@ -1463,11 +1516,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
 
   // 2. Typically, the operands of the tiled operation are slices of the
   //    operands of the untiled operation. These are expressed in IR using
-  //    `tensor.extract_slice` operations with source being the operands of the
-  //    untiled operation. Create a worklist of these `tensor.extract_slice`
-  //    operations. If the producers of the source of the `tensor.extract_slice`
-  //    can be tiled such that the tiled value is generated in-place, that
-  //    effectively tiles + fuses the operations.
+  //    `tensor.extract_slice` operations with source being the operands of
+  //    the untiled operation. Create a worklist of these
+  //    `tensor.extract_slice` operations. If the producers of the source of
+  //    the `tensor.extract_slice` can be tiled such that the tiled value is
+  //    generated in-place, that effectively tiles + fuses the operations.
   struct WorklistItem {
     tensor::ExtractSliceOp candidateSlice;
     SCFTileAndFuseOptions::ControlFnResult controlFnResult;
@@ -1511,9 +1564,10 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
     SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
 
     if (worklistItem.controlFnResult.yieldProducerReplacement) {
-      // Reconstruct and yield all opResult of fusableProducerOp by default. The
-      // caller can specific which one to yield by designating optional argument
-      // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
+      // Reconstruct and yield all opResult of fusableProducerOp by default.
+      // The caller can specify which one to yield by designating optional
+      // argument named `yieldResultNumber` of
+      // `yieldReplacementForFusedProducer`.
       Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
       FailureOr<SmallVector<Operation *>> newSlices =
           yieldReplacementForFusedProducer(rewriter,
@@ -1582,8 +1636,8 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
   return success();
 }
 
-/// An utility to get the first user of the given loopOp. If any of user stay in
-/// 
diff erent block of loopOp, return failure.
+/// A utility to get the first user of the given loopOp. If any user is
+/// in a different block from loopOp, return failure.
 static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
   if (!isa<LoopLikeOpInterface>(loopOp))
     return failure();
@@ -1616,11 +1670,11 @@ static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
   return firstUserOfLoop;
 }
 
-/// This utility currently checks whether the first userOp of loop is NOT before
-/// the last defineOp of consumer operand. Because that we need to move the
-/// whole loop structure right before the `firstUserOfLoop`. This utility thus
-/// helps ensuring that no invalid IR is formed, i.e. no backward slice of
-/// consumerOp is dominated by the `firstUserOfLoop`. Saying that:
+/// This utility currently checks whether the first userOp of loop is NOT
+/// before the last defineOp of consumer operand, because we need to move
+/// the whole loop structure right before the `firstUserOfLoop`. This utility
+/// thus helps ensuring that no invalid IR is formed, i.e. no backward slice
+/// of consumerOp is dominated by the `firstUserOfLoop`. Saying that:
 ///
 /// ```
 /// %0 = scf.for() {
@@ -1634,9 +1688,9 @@ static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
 /// %3 = consumerOp(%2)
 /// ```
 ///
-/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it would
-/// be invalid to move the `loopOp` right before the `firstUserOfLoop`, a.k.a.
-/// use-def chain violation:
+/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
+/// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
+/// a.k.a. use-def chain violation:
 ///
 /// ```
 /// %0:2 = scf.for() {
@@ -1650,10 +1704,10 @@ static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
 ///
 /// @param loopOp: loop operation
 /// @param consumerOp: consumer operation
-/// @param reorderOperations: the flag controls whether to reorder the backward
-/// slice w.r.t. the defineOp of `consumerOp` operands.
-/// @return: computed backward slice of consumerOp, but excluding those already
-/// dominates `firstUserOfLoop`.
+/// @param reorderOperations: the flag controls whether to reorder the
+/// backward slice w.r.t. the defineOp of `consumerOp` operands.
+/// @return: computed backward slice of consumerOp, but excluding those
+/// already dominating `firstUserOfLoop`.
 static FailureOr<llvm::SetVector<Operation *>>
 checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
                        bool reorderOperations) {
@@ -1713,8 +1767,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
     if (!isa<TilingInterface>(consumerOp) ||
         !isa<DestinationStyleOpInterface>(consumerOp)) {
       // TODO: We have to init result of consumer before scf.for, use
-      // DestinationStyleOpInterface to get result shape from init for now. Add
-      // support for other op such as op has InferTypeOpInterface.
+      // DestinationStyleOpInterface to get result shape from init for now.
+      // Add support for other op such as op has InferTypeOpInterface.
       continue;
     }
     // Step 2. Check if user stay in the same block.
@@ -1729,7 +1783,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
         checkAssumptionForLoop(loopOp, consumerOp, true);
     if (failed(slice))
       continue;
-    // Step 5. If backward sice is not empty, move them before firstUserOfLoop.
+    // Step 5. If backward slice is not empty, move them before
+    // firstUserOfLoop.
     if (!slice->empty()) {
       mlir::topologicalSort(*slice);
       FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
@@ -1743,8 +1798,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
   return failure();
 }
 
-/// Find the perfectly nested loops outside of given loop(included) sorted from
-/// outer to inner.
+/// Find the perfectly nested loops outside of given loop(included) sorted
+/// from outer to inner.
 ///
 /// E.g.
 ///
@@ -1997,10 +2052,11 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
     }
 
     // 10. Try to get iter domain position from input position. Use
-    // clonedConsumerOp instead of tiledConsumerOp, because the iteration domain
-    // may require index computation based on the result size. The sizes and
-    // offsets should be the same either way, but using tiledConsumerOp could
-    // lead to some chained unnecessary extra index computation.
+    // clonedConsumerOp instead of tiledConsumerOp, because the iteration
+    // domain may require index computation based on the result size. The
+    // sizes and offsets should be the same either way, but using
+    // tiledConsumerOp could lead to some chained unnecessary extra index
+    // computation.
     SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
     if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
             rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
@@ -2067,7 +2123,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
                                        "unable to add new inits to nest loop");
   }
 
-  // 15. Replace the result of scf loop and consumer op with new loop's results.
+  // 15. Replace the result of scf loop and consumer op with new loop's
+  // results.
 
   for (auto &&[oldResult, newResult] : llvm::zip(
            consumerOp->getResults(),

diff  --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 5e903e378daf82..7380b766935ffe 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -250,7 +250,8 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
       return failure();
 
     // Perform the replacement of tiled and fused values.
-    rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements);
+    rewriter.replaceOp(tilingInterfaceOp,
+                       tiledResults->mergeResult.replacements);
 
     // Report back the relevant handles to the transform op.
     tiledOps.push_back(tiledResults->tiledOps.front());


        


More information about the Mlir-commits mailing list