[Mlir-commits] [mlir] [mlir][SCF] Unify tileUsingFor and tileReductionUsingFor implementation (PR #120115)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 16 09:22:18 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

<details>
<summary>Changes</summary>

This patch unifies the tiling implementation for tileUsingFor and tileReductionUsingFor. This is done by adding an option to SCFTilingOptions that controls how reduction dimensions are tiled. Currently, there are 3 different reduction tiling strategies: FullReduction (the old tileUsingFor behavior), PartialReductionOuterReduction (the old tileReductionUsingFor behavior) and PartialReductionOuterParallel (corresponding to linalg::tileReductionUsingForall; not implemented in this patch).

The patch makes tileReductionUsingFor use the tileUsingFor implementation with the new reduction tiling options.

There are no test changes because the two implementations were doing almost exactly the same thing. This was also tested in IREE (which uses both of these APIs heavily), and no test changes were needed there either.

---

Patch is 39.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120115.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+39-18) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+7-6) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+258-200) 
- (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+2-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 9f5f9f3fca97ad..d2cddfe00ac78e 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,6 +85,36 @@ struct SCFTilingOptions {
     return *this;
   }
 
+  /// Specify how reduction dimensions should be tiled.
+  ///
+  /// Tiling can be thought of as splitting a dimension into 2 and materializing
+  /// the outer dimension as a loop:
+  ///
+  /// op[original] -> op[original / x, x] -> loop[original] { op[x] }
+  ///
+  /// For parallel dimensions, the split can only happen in one way, with both
+  /// dimensions being parallel. For reduction dimensions however, there is a
+  /// choice in how we split the reduction dimension. This enum exposes this
+  /// choice.
+  enum class ReductionTilingStrategy {
+    // [reduction] -> [reduction1, reduction2]
+    // -> loop[reduction1] { [reduction2] }
+    FullReduction,
+    // [reduction] -> [reduction1, parallel2]
+    // -> loop[reduction1] { [parallel2] }; merge[reduction1]
+    PartialReductionOuterReduction,
+    // [reduction] -> [parallel1, reduction2]
+    // -> loop[parallel1] { [reduction2] }; merge[parallel1]
+    PartialReductionOuterParallel
+  };
+  ReductionTilingStrategy reductionStrategy =
+      ReductionTilingStrategy::FullReduction;
+  SCFTilingOptions &
+  setReductionTilingStrategy(ReductionTilingStrategy strategy) {
+    reductionStrategy = strategy;
+    return *this;
+  }
+
   /// Specify mapping of loops to devices. This is only respected when the loop
   /// constructs support such a mapping (like `scf.forall`). Will be ignored
   /// when using loop constructs that dont support such a mapping (like
@@ -102,11 +132,16 @@ struct SCFTilingResult {
   /// matter except the last op. The replacements are expected to be the results
   /// of the last op.
   SmallVector<Operation *> tiledOps;
+  /// The initial destination values passed to the tiled operations.
+  SmallVector<Value> initialValues;
   /// The `scf.for` operations that iterate over the tiles.
   SmallVector<LoopLikeOpInterface> loops;
-  /// Values to use as replacements for the untiled op. Is the same size as the
-  /// number of results of the untiled op.
-  SmallVector<Value> replacements;
+  /// The result generated by the loop nest in tiling, may hold partial results,
+  /// which need to be merged to match the computation of the untiled operation.
+  /// `mergeResult` contains the operations used to perform this merge from
+  /// partial results and the values that can be used as replacements of
+  /// the untiled operation.
+  MergeResult mergeResult;
   /// Slices generated after tiling that can be used for fusing with the tiled
   /// producer.
   SmallVector<Operation *> generatedSlices;
@@ -300,20 +335,6 @@ tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
 FailureOr<SmallVector<scf::ForOp>>
 lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
 
-/// Transformation information returned after reduction tiling.
-struct SCFReductionTilingResult {
-  /// The partial reduction tiled op generated.
-  SmallVector<Operation *> parallelTiledOps;
-  /// The final reduction operation merging all the partial reductions.
-  SmallVector<Operation *> mergeOps;
-  /// Initial values used for reduction.
-  SmallVector<Value> initialValues;
-  /// The loop operations that iterate over the tiles.
-  SmallVector<LoopLikeOpInterface> loops;
-  /// The replacements to use for the results of the tiled operation.
-  SmallVector<Value> replacements;
-};
-
 /// Method to tile a reduction and generate a parallel op within a serial loop.
 /// Each of the partial reductions are calculated in parallel. Then after the
 /// loop all the partial reduction are merged into a final reduction.
@@ -338,7 +359,7 @@ struct SCFReductionTilingResult {
 /// %6 = linalg.generic %1 ["parallel", "reduction"]
 ///   : tensor<7x4xf32> -> tensor<7xf32>
 /// ```
-FailureOr<scf::SCFReductionTilingResult>
+FailureOr<scf::SCFTilingResult>
 tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
                       ArrayRef<OpFoldResult> tileSize);
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8839faf4cafb2d..66a3947e0f91fc 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2224,7 +2224,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
     return emitDefaultDefiniteFailure(target);
 
   if (target->getNumResults())
-    rewriter.replaceOp(target, maybeTilingResult->replacements);
+    rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
   else
     rewriter.eraseOp(target);
 
@@ -2631,17 +2631,18 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
+  FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
       rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
 
   if (failed(result))
     return emitDefaultSilenceableFailure(target);
+  rewriter.replaceOp(target, result->mergeResult.replacements);
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
-  for (auto parallelTiledOp : result->parallelTiledOps)
+  for (auto parallelTiledOp : result->tiledOps)
     results.push_back(parallelTiledOp);
-  for (auto mergeOp : result->mergeOps)
+  for (auto mergeOp : result->mergeResult.mergeOps)
     results.push_back(mergeOp);
   results.push_back(result->loops.front());
   return DiagnosedSilenceableFailure::success();
@@ -3065,7 +3066,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
     if (failed(maybeTilingResult))
       return DiagnosedSilenceableFailure::definiteFailure();
 
-    rewriter.replaceOp(op, maybeTilingResult->replacements);
+    rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
 
     tiled.append(maybeTilingResult->tiledOps);
     for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
@@ -3304,7 +3305,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
   if (failed(maybeTilingResult))
     return transformOp.emitDefaultSilenceableFailure(tileableOp);
 
-  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
+  rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
 
   tilingResult = *maybeTilingResult;
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 6a4a6b43933806..8ece9fb259ddd6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -570,6 +570,146 @@ static LogicalResult generateLoopNest(
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
 }
 
+static LogicalResult
+createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
+                              ArrayRef<OpFoldResult> tileSizes,
+                              SmallVector<Value> &initTensors,
+                              const scf::SCFTilingOptions &options) {
+  Location loc = op->getLoc();
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    return tensor::getOrCreateDestinations(rewriter, loc, op, initTensors);
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!redOp) {
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations "
+              "implementing PartialReductionOpInterface");
+    }
+    // Get reduction dimensions.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<int> reductionDims;
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(op.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+    FailureOr<SmallVector<Value>> maybeInitTensors =
+        redOp.generateInitialTensorForPartialReduction(rewriter, loc, tileSizes,
+                                                       reductionDims);
+    if (failed(maybeInitTensors)) {
+      return failure();
+    }
+    initTensors = maybeInitTensors.value();
+    return success();
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+static FailureOr<TilingResult>
+getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
+                       ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
+                       ArrayRef<OpFoldResult> sizes,
+                       const scf::SCFTilingOptions &options) {
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    return op.getTiledImplementation(rewriter, offsets, sizes);
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!redOp) {
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations "
+              "implementing PartialReductionOpInterface");
+    }
+    // Get reduction dimensions.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<int> reductionDims;
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(op.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+    return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
+                                        offsets, sizes, reductionDims);
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+static LogicalResult
+getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
+                      TilingInterface op, ArrayRef<OpFoldResult> offsets,
+                      ArrayRef<OpFoldResult> sizes,
+                      SmallVector<OpFoldResult> &resultOffset,
+                      SmallVector<OpFoldResult> &resultSize,
+                      const scf::SCFTilingOptions &options) {
+
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    return op.getResultTilePosition(rewriter, index, offsets, sizes,
+                                    resultOffset, resultSize);
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    resultOffset =
+        SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
+    for (size_t i = 0; i < offsets.size(); i++) {
+      resultSize.push_back(
+          tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
+    }
+    return success();
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+  }
+}
+
+static FailureOr<MergeResult>
+mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
+                   ValueRange partialResults,
+                   const scf::SCFTilingOptions &options) {
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    // No need to merge results for the FullReduction tiling strategy.
+    return MergeResult{{}, partialResults};
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
+      PartialReductionOuterReduction: {
+    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!redOp) {
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations "
+              "implementing PartialReductionOpInterface");
+    }
+    // Get reduction dimensions.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<int> reductionDims;
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(op.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+    return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
+                                 reductionDims);
+  }
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
 /// Append the specified additional `newInitOperands` operands to the
 /// loops existing `init` operands (or similar), and replace `loopOp` with
 /// the new loop that has the additional init operands. The loop body of
@@ -710,11 +850,11 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
       });
 }
 
-/// Method to add new init values to a loop nest. Updates `loops` in-place with
-/// new loops that use the `newInitValues`.
-/// The outer-loops are updated to yield the new result values of the inner
-/// loop. For the innermost loop, the call back `getNewYields` is invoked to get
-/// the additional values to yield form the innermost loop.
+/// Method to add new init values to a loop nest. Updates `loops` in-place
+/// with new loops that use the `newInitValues`. The outer-loops are updated
+/// to yield the new result values of the inner loop. For the innermost loop,
+/// the call back `getNewYields` is invoked to get the additional values to
+/// yield from the innermost loop.
 static LogicalResult addInitOperandsToLoopNest(
     RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
     ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
@@ -852,9 +992,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     auto clonedOp = cast<TilingInterface>(
         cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
 
-    // 5b. Early return cloned op if tiling is not happening. We can not return
-    // the original op because it could lead to
-    // `rewriter.replaceOp(op, op->getResults())` and users would get crash.
+    // 5b. Early return cloned op if tiling is not happening. We cannot
+    // return the original op because it could lead to `rewriter.replaceOp(op,
+    // op->getResults())` and users would then hit a crash.
     if (llvm::all_of(tileSizes, isZeroIndex)) {
       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
       tilingResult =
@@ -864,7 +1004,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     }
 
     // 5c. Tile the cloned operation.
-    tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
+    tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
+                                          offsets, sizes, options);
     if (failed(tilingResult)) {
       rewriter.eraseOp(clonedOp);
       return op.emitOpError("faild to tile operation");
@@ -879,8 +1020,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
          llvm::enumerate(tilingResult->tiledValues)) {
       tiledResults.push_back(tiledValue);
       SmallVector<OpFoldResult> resultOffset, resultSize;
-      if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
-                                          resultOffset, resultSize))) {
+      if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
+                                       sizes, resultOffset, resultSize,
+                                       options))) {
         for (auto op : tilingResult->tiledOps) {
           rewriter.eraseOp(op);
         }
@@ -895,158 +1037,64 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   };
 
   // 6. Find the destination tensors to use for the operation.
-  SmallVector<Value> destinationTensors;
-  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
-                                             destinationTensors))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "unable to create destination tensors");
+  SmallVector<Value> initTensors;
+  if (failed(createInitialTensorsForTiling(rewriter, op, tileSizes, initTensors,
+                                           options))) {
+    return rewriter.notifyMatchFailure(
+        op, "unable to create initial tensors for tiling");
   }
 
   // 7. Generate the tiled loops nest using the callback defined above.
   SmallVector<LoopLikeOpInterface> loops;
   if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
-                              tileSizes, numThreads, destinationTensors,
+                              tileSizes, numThreads, initTensors,
                               innerYieldTiledValuesFn, loops)))
     return op.emitOpError("failed to generate tiling loops");
   assert(succeeded(tilingResult) &&
          "expected tiling result to be computed after loop generation");
 
-  // If loops are empty, the tiled op is used as the replacement for the untiled
-  // op.
+  SmallVector<Value> partialResults;
   if (loops.empty()) {
-    return scf::SCFTilingResult{tilingResult->tiledOps, loops,
-                                tilingResult->tiledValues,
-                                tilingResult->generatedSlices};
+    // If loops are empty, the tiled op is used as the replacement for the
+    // untiled op.
+    partialResults = tilingResult->tiledValues;
+  } else {
+    partialResults = llvm::map_to_vector(loops.front()->getResults(),
+                                         [](OpResult r) -> Value { return r; });
+  }
+
+  FailureOr<MergeResult> mergeResult =
+      mergeTilingResults(rewriter, op, partialResults, options);
+  if (failed(mergeResult)) {
+    return rewriter.notifyMatchFailure(
+        op, "Failed to merge partial results from tiling");
   }
 
-  SmallVector<Value> replacements = llvm::map_to_vector(
-      loops.front()->getResults(), [](OpResult r) -> Value { return r; });
-  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
+  return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
+                              mergeResult.value(),
                               tilingResult->generatedSlices};
 }
 
-FailureOr<scf::SCFReductionTilingResult>
+FailureOr<scf::SCFTilingResult>
 mlir::scf::tileReductionUsingScf(RewriterBase &b,
                                  PartialReductionOpInterface op,
                                  ArrayRef<OpFoldResult> tileSizes) {
-  Location loc = op.getLoc();
-  // Ops implementing PartialReductionOpInterface are expected to implement
-  // TilingInterface.
-  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
-  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
-  auto tileSizesVector = llvm::to_vector(tileSizes);
-  if (tileSizesVector.size() < iterationDomain.size()) {
-    auto zero = b.getIndexAttr(0);
-    tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
-                           zero);
-  }
-  SmallVector<utils::IteratorType> iterators =
-      tilingInterfaceOp.getLoopIteratorTypes();
-
-  SmallVector<int> reductionDims;
-  for (auto [idx, iteratorType] :
-       llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
-    if (iteratorType == utils::IteratorType::reduction)
-      reductionDims.push_back(idx);
-  }
-
-  // 2. create the ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/120115


More information about the Mlir-commits mailing list