[Mlir-commits] [mlir] [mlir][SCF] Unify tileUsingFor and tileReductionUsingFor implementation (PR #120115)

Kunwar Grover llvmlistbot at llvm.org
Wed Dec 18 05:12:24 PST 2024


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/120115

>From facdf23493e348a2c3a03083a24b4b7a6c072a9c Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 13 Dec 2024 16:58:10 +0000
Subject: [PATCH 1/3] [mlir][SCF] Merge tileUsingFor and tileReductionUsingFor
 implementation

---
 .../SCF/Transforms/TileUsingInterface.h       |  57 ++-
 .../TransformOps/LinalgTransformOps.cpp       |  13 +-
 .../SCF/Transforms/TileUsingInterface.cpp     | 458 ++++++++++--------
 .../TestTilingInterfaceTransformOps.cpp       |   3 +-
 4 files changed, 306 insertions(+), 225 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 9f5f9f3fca97ad..d2cddfe00ac78e 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,6 +85,36 @@ struct SCFTilingOptions {
     return *this;
   }
 
+  /// Specify how reduction dimensions should be tiled.
+  ///
+  /// Tiling can be thought of as splitting a dimension into 2 and materializing
+  /// the outer dimension as a loop:
+  ///
+  /// op[original] -> op[original / x, x] -> loop[original / x] { op[x] }
+  ///
+  /// For parallel dimensions, the split can only happen in one way, with both
+  /// dimensions being parallel. For reduction dimensions however, there is a
+  /// choice in how we split the reduction dimension. This enum exposes this
+  /// choice.
+  enum class ReductionTilingStrategy {
+    // [reduction] -> [reduction1, reduction2]
+    // -> loop[reduction1] { [reduction2] }
+    FullReduction,
+    // [reduction] -> [reduction1, parallel2]
+    // -> loop[reduction1] { [parallel2] }; merge[reduction1]
+    PartialReductionOuterReduction,
+    // [reduction] -> [parallel1, reduction2]
+    // -> loop[parallel1] { [reduction2] }; merge[parallel1]
+    PartialReductionOuterParallel
+  };
+  ReductionTilingStrategy reductionStrategy =
+      ReductionTilingStrategy::FullReduction;
+  SCFTilingOptions &
+  setReductionTilingStrategy(ReductionTilingStrategy strategy) {
+    reductionStrategy = strategy;
+    return *this;
+  }
+
   /// Specify mapping of loops to devices. This is only respected when the loop
   /// constructs support such a mapping (like `scf.forall`). Will be ignored
   /// when using loop constructs that dont support such a mapping (like
@@ -102,11 +132,16 @@ struct SCFTilingResult {
   /// matter except the last op. The replacements are expected to be the results
   /// of the last op.
   SmallVector<Operation *> tiledOps;
+  /// The initial destination values passed to the tiled operations.
+  SmallVector<Value> initialValues;
   /// The `scf.for` operations that iterate over the tiles.
   SmallVector<LoopLikeOpInterface> loops;
-  /// Values to use as replacements for the untiled op. Is the same size as the
-  /// number of results of the untiled op.
-  SmallVector<Value> replacements;
+  /// The result values generated by the tiling loop nest. These may hold
+  /// partial results, which need to be merged to match the computation of
+  /// the untiled operation. `mergeResult` contains the operations used to
+  /// perform this merge from partial results and the values that can be used
+  /// as replacements of the untiled operation.
+  MergeResult mergeResult;
   /// Slices generated after tiling that can be used for fusing with the tiled
   /// producer.
   SmallVector<Operation *> generatedSlices;
@@ -300,20 +335,6 @@ tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
 FailureOr<SmallVector<scf::ForOp>>
 lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
 
-/// Transformation information returned after reduction tiling.
-struct SCFReductionTilingResult {
-  /// The partial reduction tiled op generated.
-  SmallVector<Operation *> parallelTiledOps;
-  /// The final reduction operation merging all the partial reductions.
-  SmallVector<Operation *> mergeOps;
-  /// Initial values used for reduction.
-  SmallVector<Value> initialValues;
-  /// The loop operations that iterate over the tiles.
-  SmallVector<LoopLikeOpInterface> loops;
-  /// The replacements to use for the results of the tiled operation.
-  SmallVector<Value> replacements;
-};
-
 /// Method to tile a reduction and generate a parallel op within a serial loop.
 /// Each of the partial reductions are calculated in parallel. Then after the
 /// loop all the partial reduction are merged into a final reduction.
@@ -338,7 +359,7 @@ struct SCFReductionTilingResult {
 /// %6 = linalg.generic %1 ["parallel", "reduction"]
 ///   : tensor<7x4xf32> -> tensor<7xf32>
 /// ```
-FailureOr<scf::SCFReductionTilingResult>
+FailureOr<scf::SCFTilingResult>
 tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
                       ArrayRef<OpFoldResult> tileSize);
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8839faf4cafb2d..66a3947e0f91fc 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2224,7 +2224,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
     return emitDefaultDefiniteFailure(target);
 
   if (target->getNumResults())
-    rewriter.replaceOp(target, maybeTilingResult->replacements);
+    rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
   else
     rewriter.eraseOp(target);
 
@@ -2631,17 +2631,18 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
+  FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
       rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
 
   if (failed(result))
     return emitDefaultSilenceableFailure(target);
+  rewriter.replaceOp(target, result->mergeResult.replacements);
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
-  for (auto parallelTiledOp : result->parallelTiledOps)
+  for (auto parallelTiledOp : result->tiledOps)
     results.push_back(parallelTiledOp);
-  for (auto mergeOp : result->mergeOps)
+  for (auto mergeOp : result->mergeResult.mergeOps)
     results.push_back(mergeOp);
   results.push_back(result->loops.front());
   return DiagnosedSilenceableFailure::success();
@@ -3065,7 +3066,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
     if (failed(maybeTilingResult))
       return DiagnosedSilenceableFailure::definiteFailure();
 
-    rewriter.replaceOp(op, maybeTilingResult->replacements);
+    rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
 
     tiled.append(maybeTilingResult->tiledOps);
     for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
@@ -3304,7 +3305,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
   if (failed(maybeTilingResult))
     return transformOp.emitDefaultSilenceableFailure(tileableOp);
 
-  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
+  rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
 
   tilingResult = *maybeTilingResult;
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 6a4a6b43933806..8ece9fb259ddd6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -570,6 +570,146 @@ static LogicalResult generateLoopNest(
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
 }
 
+static LogicalResult
+createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
+                              ArrayRef<OpFoldResult> tileSizes,
+                              SmallVector<Value> &initTensors,
+                              const scf::SCFTilingOptions &options) {
+  Location loc = op->getLoc();
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    return tensor::getOrCreateDestinations(rewriter, loc, op, initTensors);
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!redOp) {
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations "
+              "implementing PartialReductionOpInterface");
+    }
+    // Get reduction dimensions.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<int> reductionDims;
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(op.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+    FailureOr<SmallVector<Value>> maybeInitTensors =
+        redOp.generateInitialTensorForPartialReduction(rewriter, loc, tileSizes,
+                                                       reductionDims);
+    if (failed(maybeInitTensors)) {
+      return failure();
+    }
+    initTensors = maybeInitTensors.value();
+    return success();
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+static FailureOr<TilingResult>
+getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
+                       ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
+                       ArrayRef<OpFoldResult> sizes,
+                       const scf::SCFTilingOptions &options) {
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    return op.getTiledImplementation(rewriter, offsets, sizes);
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!redOp) {
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations "
+              "implementing PartialReductionOpInterface");
+    }
+    // Get reduction dimensions.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<int> reductionDims;
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(op.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+    return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
+                                        offsets, sizes, reductionDims);
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+static LogicalResult
+getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
+                      TilingInterface op, ArrayRef<OpFoldResult> offsets,
+                      ArrayRef<OpFoldResult> sizes,
+                      SmallVector<OpFoldResult> &resultOffset,
+                      SmallVector<OpFoldResult> &resultSize,
+                      const scf::SCFTilingOptions &options) {
+
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    return op.getResultTilePosition(rewriter, index, offsets, sizes,
+                                    resultOffset, resultSize);
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    resultOffset =
+        SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
+    for (size_t i = 0; i < offsets.size(); i++) {
+      resultSize.push_back(
+          tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
+    }
+    return success();
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+static FailureOr<MergeResult>
+mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
+                   ValueRange partialResults,
+                   const scf::SCFTilingOptions &options) {
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    // No need to merge results for the full reduction tiling strategy.
+    return MergeResult{{}, partialResults};
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!redOp) {
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations "
+              "implementing PartialReductionOpInterface");
+    }
+    // Get reduction dimensions.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<int> reductionDims;
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(op.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+    return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
+                                 reductionDims);
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
 /// Append the specified additional `newInitOperands` operands to the
 /// loops existing `init` operands (or similar), and replace `loopOp` with
 /// the new loop that has the additional init operands. The loop body of
@@ -710,11 +850,11 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
       });
 }
 
-/// Method to add new init values to a loop nest. Updates `loops` in-place with
-/// new loops that use the `newInitValues`.
-/// The outer-loops are updated to yield the new result values of the inner
-/// loop. For the innermost loop, the call back `getNewYields` is invoked to get
-/// the additional values to yield form the innermost loop.
+/// Method to add new init values to a loop nest. Updates `loops` in-place
+/// with new loops that use the `newInitValues`. The outer loops are updated
+/// to yield the new result values of the inner loop. For the innermost loop,
+/// the callback `getNewYields` is invoked to get the additional values to
+/// yield from the innermost loop.
 static LogicalResult addInitOperandsToLoopNest(
     RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
     ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
@@ -852,9 +992,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     auto clonedOp = cast<TilingInterface>(
         cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
 
-    // 5b. Early return cloned op if tiling is not happening. We can not return
-    // the original op because it could lead to
-    // `rewriter.replaceOp(op, op->getResults())` and users would get crash.
+    // 5b. Early return cloned op if tiling is not happening. We cannot
+    // return the original op because it could lead to `rewriter.replaceOp(op,
+    // op->getResults())` and users would get a crash.
     if (llvm::all_of(tileSizes, isZeroIndex)) {
       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
       tilingResult =
@@ -864,7 +1004,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     }
 
     // 5c. Tile the cloned operation.
-    tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
+    tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
+                                          offsets, sizes, options);
     if (failed(tilingResult)) {
       rewriter.eraseOp(clonedOp);
       return op.emitOpError("faild to tile operation");
@@ -879,8 +1020,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
          llvm::enumerate(tilingResult->tiledValues)) {
       tiledResults.push_back(tiledValue);
       SmallVector<OpFoldResult> resultOffset, resultSize;
-      if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
-                                          resultOffset, resultSize))) {
+      if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
+                                       sizes, resultOffset, resultSize,
+                                       options))) {
         for (auto op : tilingResult->tiledOps) {
           rewriter.eraseOp(op);
         }
@@ -895,158 +1037,64 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   };
 
   // 6. Find the destination tensors to use for the operation.
-  SmallVector<Value> destinationTensors;
-  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
-                                             destinationTensors))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "unable to create destination tensors");
+  SmallVector<Value> initTensors;
+  if (failed(createInitialTensorsForTiling(rewriter, op, tileSizes, initTensors,
+                                           options))) {
+    return rewriter.notifyMatchFailure(
+        op, "unable to create initial tensors for tiling");
   }
 
   // 7. Generate the tiled loops nest using the callback defined above.
   SmallVector<LoopLikeOpInterface> loops;
   if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
-                              tileSizes, numThreads, destinationTensors,
+                              tileSizes, numThreads, initTensors,
                               innerYieldTiledValuesFn, loops)))
     return op.emitOpError("failed to generate tiling loops");
   assert(succeeded(tilingResult) &&
          "expected tiling result to be computed after loop generation");
 
-  // If loops are empty, the tiled op is used as the replacement for the untiled
-  // op.
+  SmallVector<Value> partialResults;
   if (loops.empty()) {
-    return scf::SCFTilingResult{tilingResult->tiledOps, loops,
-                                tilingResult->tiledValues,
-                                tilingResult->generatedSlices};
+    // If loops are empty, the tiled op is used as the replacement for the
+    // untiled op.
+    partialResults = tilingResult->tiledValues;
+  } else {
+    partialResults = llvm::map_to_vector(loops.front()->getResults(),
+                                         [](OpResult r) -> Value { return r; });
+  }
+
+  FailureOr<MergeResult> mergeResult =
+      mergeTilingResults(rewriter, op, partialResults, options);
+  if (failed(mergeResult)) {
+    return rewriter.notifyMatchFailure(
+        op, "Failed to merge partial results from tiling");
   }
 
-  SmallVector<Value> replacements = llvm::map_to_vector(
-      loops.front()->getResults(), [](OpResult r) -> Value { return r; });
-  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
+  return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
+                              mergeResult.value(),
                               tilingResult->generatedSlices};
 }
 
-FailureOr<scf::SCFReductionTilingResult>
+FailureOr<scf::SCFTilingResult>
 mlir::scf::tileReductionUsingScf(RewriterBase &b,
                                  PartialReductionOpInterface op,
                                  ArrayRef<OpFoldResult> tileSizes) {
-  Location loc = op.getLoc();
-  // Ops implementing PartialReductionOpInterface are expected to implement
-  // TilingInterface.
-  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
-  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
-  auto tileSizesVector = llvm::to_vector(tileSizes);
-  if (tileSizesVector.size() < iterationDomain.size()) {
-    auto zero = b.getIndexAttr(0);
-    tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
-                           zero);
-  }
-  SmallVector<utils::IteratorType> iterators =
-      tilingInterfaceOp.getLoopIteratorTypes();
-
-  SmallVector<int> reductionDims;
-  for (auto [idx, iteratorType] :
-       llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
-    if (iteratorType == utils::IteratorType::reduction)
-      reductionDims.push_back(idx);
-  }
-
-  // 2. create the inital tensor value.
-  FailureOr<SmallVector<Value>> maybeInitTensors =
-      op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
-                                                  reductionDims);
-  if (failed(maybeInitTensors)) {
-    return b.notifyMatchFailure(op, "Failed to create initial tensors.");
-  }
-  SmallVector<Value> &initTensors = maybeInitTensors.value();
-
-  // 3. Define the callback to use for generating the inner most tile loop body.
-  SmallVector<Operation *> parallelTiledOps;
-  auto innerYieldTiledValuesFn =
-      [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
-          ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
-          SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
-          SmallVector<SmallVector<OpFoldResult>> &resultSizes)
-      -> LogicalResult {
-    SmallVector<OpFoldResult> offsets, sizes;
-    {
-      int materializedLoopNum = 0;
-      for (auto [tileSize, loopRange] :
-           llvm::zip_equal(tileSizesVector, iterationDomain)) {
-        if (isConstantIntValue(tileSize, 0)) {
-          offsets.push_back(loopRange.offset);
-          sizes.push_back(loopRange.size);
-          continue;
-        }
-        Value iv = ivs[materializedLoopNum++];
-        offsets.push_back(iv);
-        sizes.push_back(
-            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
-      }
-    }
-
-    // 4a. Clone the operation.
-    {
-      auto clonedOp = cast<PartialReductionOpInterface>(
-          cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
-
-      // 4b. Tile the cloned operation.
-      FailureOr<TilingResult> partialTilingResult =
-          clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
-                                          sizes, reductionDims);
-      if (failed(partialTilingResult)) {
-        return failure();
-      }
-      std::swap(parallelTiledOps, partialTilingResult->tiledOps);
-      std::swap(tiledResult, partialTilingResult->tiledValues);
-
-      // 4c. Delete the cloned operation.
-      b.eraseOp(clonedOp);
-    }
-
-    // 4d. Compute the offsets and sizes needed to insert the result of the
-    // tiled value back into destination before yielding the destination.
-    for (auto result : tiledResult) {
-      SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-      resultOffsets.emplace_back(std::move(outOffsets));
-
-      SmallVector<OpFoldResult> outSizes;
-      for (size_t i = 0; i < offsets.size(); i++) {
-        outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
-      }
-      resultSizes.emplace_back(std::move(outSizes));
-    }
-    return success();
-  };
-
-  // 5. Generate the tiled implementation using the destination tensors.
-  SmallVector<LoopLikeOpInterface> loops;
-  scf::SCFTilingOptions options;
-  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
-  if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
-                              /*numThreads=*/ArrayRef<OpFoldResult>{},
-                              initTensors, innerYieldTiledValuesFn, loops)))
-    return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
-
-  SmallVector<Value> replacements = llvm::map_to_vector(
-      loops.front()->getResults(), [](OpResult r) -> Value { return r; });
-
-  // 5. Apply the merge reduction to combine all the partial values.
-  b.setInsertionPointAfter(*loops.begin());
-  FailureOr<MergeResult> mergeResult =
-      op.mergeReductions(b, loc, replacements, reductionDims);
-  if (failed(mergeResult)) {
-    return failure();
-  }
-  b.replaceOp(op, mergeResult->replacements);
-
-  SCFReductionTilingResult reductionTilingResult;
-  std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
-  std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
-  std::swap(reductionTilingResult.initialValues, initTensors);
-  std::swap(reductionTilingResult.loops, loops);
-  std::swap(reductionTilingResult.replacements, mergeResult->replacements);
-
-  return reductionTilingResult;
+  SCFTilingOptions options;
+  options.setLoopType(SCFTilingOptions::LoopType::ForOp);
+  options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy::
+                                         PartialReductionOuterReduction);
+  options.setTileSizes(tileSizes);
+
+  TilingInterface tilingInterfaceOp =
+      dyn_cast<TilingInterface>(op.getOperation());
+  if (!tilingInterfaceOp) {
+    return b.notifyMatchFailure(
+        op,
+        "Operation implementing PartialReductionOpInterface should implement "
+        "TilingInterface");
+  }
+
+  return tileUsingSCF(b, tilingInterfaceOp, options);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1055,9 +1103,10 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
 
 /// Return the untiled producer whose slice is used in a tiled consumer. The
 /// method traverses the tile loop nest (`loops`) if needed, and returns the
-/// `iter_args` of the outer most that is encountered. Traversing the iter_args
-/// indicates that this is a destination operand of the consumer. If there was
-/// no loop traversal needed, the second value of the returned tuple is empty.
+/// `iter_args` of the outer most that is encountered. Traversing the
+/// iter_args indicates that this is a destination operand of the consumer. If
+/// there was no loop traversal needed, the second value of the returned tuple
+/// is empty.
 static std::tuple<OpResult, std::optional<OpOperand *>>
 getUntiledProducerFromSliceSource(OpOperand *source,
                                   ArrayRef<LoopLikeOpInterface> loops) {
@@ -1115,8 +1164,8 @@ mlir::scf::tileAndFuseProducerOfSlice(
   Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
       rewriter, fusableProducerOp, clonedOpDestinationTensors);
   // 2d. Update the source of the candidateSlice to be the cloned producer.
-  //     Easier to just clone the slice with different source since replacements
-  //     and DCE of cloned ops becomes easier
+  //     Easier to just clone the slice with a different source since
+  //     replacements and DCE of cloned ops become easier.
   SmallVector<Value> candidateSliceOpOperands =
       llvm::to_vector(candidateSliceOp->getOperands());
   candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
@@ -1250,13 +1299,13 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
         failed(tilableOp.getIterationDomainTileFromResultTile(
             rewriter, sliceResultNumber, sliceOffset, sliceSizes,
             iterDomainOffset, iterDomainSizes))) {
-      // In theory, it is unnecessary to raise an error here. Actually although
-      // it fails to reconstruct the result tensor, it should not broke current
-      // fusion anyway. The reason why we must return failure currently is that
-      // the callback function `newYieldValuesFn` will be called after new init
-      // operand(s) has already been appended. It will take more refactoring to
-      // make sure the init operands are added consistently in the future. For
-      // more details, please refer to:
+      // In theory, it is unnecessary to raise an error here. Actually,
+      // although it fails to reconstruct the result tensor, it should not
+      // break the current fusion anyway. The reason why we must return
+      // failure currently is that the callback function `newYieldValuesFn`
+      // will be called after the new init operand(s) have already been
+      // appended. It will take more refactoring to make sure the init
+      // operands are added consistently in the future. For more details see:
       // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
       return failure();
     }
@@ -1282,7 +1331,8 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
       }
     }
 
-    // d. create `extract_slice` for `iter_args` for DPS operation if necessary
+    // d. create `extract_slice` for `iter_args` for DPS operation if
+    // necessary
     if (auto tiledDestStyleOp =
             dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
       rewriter.setInsertionPoint(tiledDestStyleOp);
@@ -1334,9 +1384,10 @@ class SliceTrackingListener : public RewriterBase::Listener {
       std::optional<FrozenRewritePatternSet> patterns);
   SliceTrackingListener() = default;
 
-  /// Adds the given list of operations to the worklist, and if present, applies
-  /// the list of `patterns` to the newly added operations. This only processes
-  /// the given operations and any newly inserted ones by the pattern set.
+  /// Adds the given list of operations to the worklist, and if present,
+  /// applies the list of `patterns` to the newly added operations. This only
+  /// processes the given operations and any newly inserted ones by the
+  /// pattern set.
   LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
 
   /// Add to the new operation worklist if it is an extract_slice.
@@ -1357,7 +1408,8 @@ class SliceTrackingListener : public RewriterBase::Listener {
   std::deque<tensor::ExtractSliceOp> worklist;
 
 private:
-  /// Optional pattern set to apply when adding new operations to the worklist.
+  /// Optional pattern set to apply when adding new operations to the
+  /// worklist.
   std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
 };
 
@@ -1390,8 +1442,9 @@ void SliceTrackingListener::notifyOperationInserted(
   worklist.push_back(slice);
 }
 
-// Scan the worklist for the given op and remove it if present. The expectation
-// is for the worklist to be small and for removal to be relatively rare.
+// Scan the worklist for the given op and remove it if present. The
+// expectation is for the worklist to be small and for removal to be
+// relatively rare.
 void SliceTrackingListener::removeOp(Operation *op) {
   if (!isa<tensor::ExtractSliceOp>(op))
     return;
@@ -1445,17 +1498,18 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
   auto &loops = tilingResult->loops;
   if (loops.empty()) {
     DenseMap<Value, Value> replacements;
-    for (auto [origVal, replacement] :
-         llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
+    for (auto [origVal, replacement] : llvm::zip_equal(
+             consumer->getResults(), tilingResult->mergeResult.replacements)) {
       replacements[origVal] = replacement;
     }
     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
                                      replacements};
   }
 
-  // To keep track of replacements for now just record the map from the original
-  // untiled value to the result number of the for loop. Since the loop gets
-  // potentially replaced during fusion, keeping the value directly wont work.
+  // To keep track of replacements, for now just record the map from the
+  // original untiled value to the result number of the for loop. Since the
+  // loop potentially gets replaced during fusion, keeping the value directly
+  // won't work.
   DenseMap<Value, size_t> origValToResultNumber;
   for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
     origValToResultNumber[result] = index;
@@ -1463,11 +1517,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
 
   // 2. Typically, the operands of the tiled operation are slices of the
   //    operands of the untiled operation. These are expressed in IR using
-  //    `tensor.extract_slice` operations with source being the operands of the
-  //    untiled operation. Create a worklist of these `tensor.extract_slice`
-  //    operations. If the producers of the source of the `tensor.extract_slice`
-  //    can be tiled such that the tiled value is generated in-place, that
-  //    effectively tiles + fuses the operations.
+  //    `tensor.extract_slice` operations with source being the operands of
+  //    the untiled operation. Create a worklist of these
+  //    `tensor.extract_slice` operations. If the producers of the source of
+  //    the `tensor.extract_slice` can be tiled such that the tiled value is
+  //    generated in-place, that effectively tiles + fuses the operations.
   struct WorklistItem {
     tensor::ExtractSliceOp candidateSlice;
     SCFTileAndFuseOptions::ControlFnResult controlFnResult;
@@ -1511,9 +1565,10 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
     SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
 
     if (worklistItem.controlFnResult.yieldProducerReplacement) {
-      // Reconstruct and yield all opResult of fusableProducerOp by default. The
-      // caller can specific which one to yield by designating optional argument
-      // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
+      // Reconstruct and yield all opResults of fusableProducerOp by default.
+      // The caller can specify which one to yield by designating the
+      // optional argument named `yieldResultNumber` of
+      // `yieldReplacementForFusedProducer`.
       Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
       FailureOr<SmallVector<Operation *>> newSlices =
           yieldReplacementForFusedProducer(rewriter,
@@ -1582,8 +1637,8 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
   return success();
 }
 
-/// An utility to get the first user of the given loopOp. If any of user stay in
-/// different block of loopOp, return failure.
+/// A utility to get the first user of the given loopOp. If any user is in a
+/// different block than loopOp, return failure.
 static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
   if (!isa<LoopLikeOpInterface>(loopOp))
     return failure();
@@ -1616,11 +1671,11 @@ static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
   return firstUserOfLoop;
 }
 
-/// This utility currently checks whether the first userOp of loop is NOT before
-/// the last defineOp of consumer operand. Because that we need to move the
-/// whole loop structure right before the `firstUserOfLoop`. This utility thus
-/// helps ensuring that no invalid IR is formed, i.e. no backward slice of
-/// consumerOp is dominated by the `firstUserOfLoop`. Saying that:
+/// This utility currently checks whether the first userOp of the loop is NOT
+/// before the last defineOp of the consumer operand. This is needed because
+/// we move the whole loop structure right before the `firstUserOfLoop`. The
+/// utility thus helps ensure that no invalid IR is formed, i.e. no backward
+/// slice of consumerOp is dominated by the `firstUserOfLoop`. That is:
 ///
 /// ```
 /// %0 = scf.for() {
@@ -1634,9 +1689,9 @@ static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
 /// %3 = consumerOp(%2)
 /// ```
 ///
-/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it would
-/// be invalid to move the `loopOp` right before the `firstUserOfLoop`, a.k.a.
-/// use-def chain violation:
+/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
+/// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
+/// a.k.a. use-def chain violation:
 ///
 /// ```
 /// %0:2 = scf.for() {
@@ -1650,10 +1705,10 @@ static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
 ///
 /// @param loopOp: loop operation
 /// @param consumerOp: consumer operation
-/// @param reorderOperations: the flag controls whether to reorder the backward
-/// slice w.r.t. the defineOp of `consumerOp` operands.
-/// @return: computed backward slice of consumerOp, but excluding those already
-/// dominates `firstUserOfLoop`.
+/// @param reorderOperations: the flag controls whether to reorder the
+/// backward slice w.r.t. the defineOp of `consumerOp` operands.
+/// @return: the computed backward slice of consumerOp, excluding ops that
+/// already dominate `firstUserOfLoop`.
 static FailureOr<llvm::SetVector<Operation *>>
 checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
                        bool reorderOperations) {
@@ -1713,8 +1768,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
     if (!isa<TilingInterface>(consumerOp) ||
         !isa<DestinationStyleOpInterface>(consumerOp)) {
       // TODO: We have to init result of consumer before scf.for, use
-      // DestinationStyleOpInterface to get result shape from init for now. Add
-      // support for other op such as op has InferTypeOpInterface.
+      // DestinationStyleOpInterface to get result shape from init for now.
+      // Add support for other ops, such as ops that have InferTypeOpInterface.
       continue;
     }
     // Step 2. Check if user stay in the same block.
@@ -1729,7 +1784,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
         checkAssumptionForLoop(loopOp, consumerOp, true);
     if (failed(slice))
       continue;
-    // Step 5. If backward sice is not empty, move them before firstUserOfLoop.
+    // Step 5. If the backward slice is not empty, move it before
+    // firstUserOfLoop.
     if (!slice->empty()) {
       mlir::topologicalSort(*slice);
       FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
@@ -1743,8 +1799,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
   return failure();
 }
 
-/// Find the perfectly nested loops outside of given loop(included) sorted from
-/// outer to inner.
+/// Find the perfectly nested loops outside of given loop(included) sorted
+/// from outer to inner.
 ///
 /// E.g.
 ///
@@ -1997,10 +2053,11 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
     }
 
     // 10. Try to get iter domain position from input position. Use
-    // clonedConsumerOp instead of tiledConsumerOp, because the iteration domain
-    // may require index computation based on the result size. The sizes and
-    // offsets should be the same either way, but using tiledConsumerOp could
-    // lead to some chained unnecessary extra index computation.
+    // clonedConsumerOp instead of tiledConsumerOp, because the iteration
+    // domain may require index computation based on the result size. The
+    // sizes and offsets should be the same either way, but using
+    // tiledConsumerOp could lead to some chained unnecessary extra index
+    // computation.
     SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
     if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
             rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
@@ -2067,7 +2124,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
                                        "unable to add new inits to nest loop");
   }
 
-  // 15. Replace the result of scf loop and consumer op with new loop's results.
+  // 15. Replace the result of scf loop and consumer op with new loop's
+  // results.
 
   for (auto &&[oldResult, newResult] : llvm::zip(
            consumerOp->getResults(),
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 5e903e378daf82..7380b766935ffe 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -250,7 +250,8 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
       return failure();
 
     // Perform the replacement of tiled and fused values.
-    rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements);
+    rewriter.replaceOp(tilingInterfaceOp,
+                       tiledResults->mergeResult.replacements);
 
     // Report back the relevant handles to the transform op.
     tiledOps.push_back(tiledResults->tiledOps.front());
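
A quick illustration of the unified entry point for anyone driving this
downstream: after this first patch, reduction tiling is just tileUsingSCF with
a different strategy set on SCFTilingOptions, and the merged values come back
through SCFTilingResult::mergeResult. The sketch below is not part of the
patch; it assumes a `rewriter` (RewriterBase), an `op` implementing both
TilingInterface and PartialReductionOpInterface, and a `tileSizes`
ArrayRef<OpFoldResult>, with error handling elided.

// Hypothetical caller, mirroring the new tileReductionUsingScf in the diff.
scf::SCFTilingOptions options;
options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
options.setReductionTilingStrategy(
    scf::SCFTilingOptions::ReductionTilingStrategy::
        PartialReductionOuterReduction);
options.setTileSizes(tileSizes);

FailureOr<scf::SCFTilingResult> tilingResult =
    scf::tileUsingSCF(rewriter, cast<TilingInterface>(op.getOperation()),
                      options);
if (failed(tilingResult))
  return failure();

// Partial-reduction tiled ops, merge ops, and the final replacement values
// are all carried by the unified result struct.
rewriter.replaceOp(op, tilingResult->mergeResult.replacements);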

>From 39650a3ec37f59400d2f2d6ef095b2438783fddb Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 17 Dec 2024 14:57:33 +0000
Subject: [PATCH 2/3] Address comments

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 30 ++++++++-----------
 1 file changed, 13 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 8ece9fb259ddd6..ceb5fe046af791 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -570,23 +570,24 @@ static LogicalResult generateLoopNest(
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
 }
 
-static LogicalResult
+static FailureOr<SmallVector<Value>>
 createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
                               ArrayRef<OpFoldResult> tileSizes,
-                              SmallVector<Value> &initTensors,
                               const scf::SCFTilingOptions &options) {
+  SmallVector<Value> initTensors;
   Location loc = op->getLoc();
   switch (options.reductionStrategy) {
   case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
-    return tensor::getOrCreateDestinations(rewriter, loc, op, initTensors);
+    if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
+      return failure();
+    return initTensors;
   case scf::SCFTilingOptions::ReductionTilingStrategy::
       PartialReductionOuterReduction: {
     auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
     if (!redOp) {
       return rewriter.notifyMatchFailure(
-          op, "PartialReductionOuterReduction tiling strategy is only "
-              "supported for operations "
-              "implementing PartialReductionOpInterface");
+          op, "PartialReductionOuterReduction tiling strategy is only supported"
+              "for operations implementing PartialReductionOpInterface");
     }
     // Get reduction dimensions.
     // TODO: PartialReductionOpInterface should really query TilingInterface
@@ -597,14 +598,8 @@ createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
       if (iteratorType == utils::IteratorType::reduction)
         reductionDims.push_back(idx);
     }
-    FailureOr<SmallVector<Value>> maybeInitTensors =
-        redOp.generateInitialTensorForPartialReduction(rewriter, loc, tileSizes,
-                                                       reductionDims);
-    if (failed(maybeInitTensors)) {
-      return failure();
-    }
-    initTensors = maybeInitTensors.value();
-    return success();
+    return redOp.generateInitialTensorForPartialReduction(
+        rewriter, loc, tileSizes, reductionDims);
   }
   default:
     return rewriter.notifyMatchFailure(op,
@@ -1037,12 +1032,13 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   };
 
   // 6. Find the destination tensors to use for the operation.
-  SmallVector<Value> initTensors;
-  if (failed(createInitialTensorsForTiling(rewriter, op, tileSizes, initTensors,
-                                           options))) {
+  FailureOr<SmallVector<Value>> maybeInits =
+      createInitialTensorsForTiling(rewriter, op, tileSizes, options);
+  if (failed(maybeInits)) {
     return rewriter.notifyMatchFailure(
         op, "unable to create initial tensors for tiling");
   }
+  SmallVector<Value> &initTensors = maybeInits.value();
 
   // 7. Generate the tiled loops nest using the callback defined above.
   SmallVector<LoopLikeOpInterface> loops;

>From 5d7aade7f19e5dae35ed150545472e78d483fe98 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 18 Dec 2024 13:12:02 +0000
Subject: [PATCH 3/3] Add TODO

---
 mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index ceb5fe046af791..ef5d4370e78102 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -656,6 +656,9 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
                                     resultOffset, resultSize);
   case scf::SCFTilingOptions::ReductionTilingStrategy::
       PartialReductionOuterReduction: {
+    // TODO: This does not work for non identity accesses to the result tile.
+    // The proper fix is to add a getPartialResultTilePosition method to
+    // PartialReductionOpInterface.
     resultOffset =
         SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
     for (size_t i = 0; i < offsets.size(); i++) {
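
For downstream users of the removed SCFReductionTilingResult, the migration
is mechanical and mirrors the TileReductionUsingForOp change in the first
patch. Roughly (illustrative mapping, not part of the patch):

// Old (SCFReductionTilingResult)    ->  New (SCFTilingResult)
//   result->parallelTiledOps        ->  result->tiledOps
//   result->mergeOps                ->  result->mergeResult.mergeOps
//   result->initialValues           ->  result->initialValues
//   result->loops                   ->  result->loops
//   result->replacements            ->  result->mergeResult.replacements
//
// Note: tileReductionUsingScf no longer calls replaceOp on the original op
// itself; the caller now does
//   rewriter.replaceOp(op, result->mergeResult.replacements);
// as the updated TileReductionUsingForOp::applyToOne does above.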


