[Mlir-commits] [mlir] 7181785 - [mlir][PartialReductionTilingInterface] Generalize implementation of `tileUsingSCF` for `ReductionTilingStrategy::PartialOuterReduction`. (#143467)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 23 11:23:50 PDT 2025


Author: MaheshRavishankar
Date: 2025-06-23T11:23:46-07:00
New Revision: 71817856f7f4c407d76a12fbbdde9ac3e89dd0a1

URL: https://github.com/llvm/llvm-project/commit/71817856f7f4c407d76a12fbbdde9ac3e89dd0a1
DIFF: https://github.com/llvm/llvm-project/commit/71817856f7f4c407d76a12fbbdde9ac3e89dd0a1.diff

LOG: [mlir][PartialReductionTilingInterface] Generalize implementation of `tileUsingSCF` for `ReductionTilingStrategy::PartialOuterReduction`. (#143467)

This is a precursor to generalizing `tileUsingSCF` to handle the
`ReductionTilingStrategy::PartialReductionOuterParallel` strategy. This change
itself generalizes/refactors the current implementation, which
supports only `ReductionTilingStrategy::PartialReductionOuterReduction`.

Changes in this PR
- Move the `ReductionTilingStrategy` enum out of
  `scf::SCFTilingOptions` and make it visible to `TilingInterface`.
- `PartialReductionOpInterface` changes
  - Pass the `tilingStrategy` used for partial reduction to
    `tileToPartialReduction`.
  - Pass the reduction dimensions along as `const
    llvm::SetVector<unsigned> &`.
- Allow `scf::SCFTilingOptions` to set the reduction dimensions that
  are to be tiled.
- Change `structured.tile_reduction_using_for` to allow specification
  of the reduction dimensions to be partially tiled.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
    mlir/include/mlir/Interfaces/TilingInterface.h
    mlir/include/mlir/Interfaces/TilingInterface.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index c5650470fdc8d..38c8734c47381 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1859,6 +1859,10 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
       - the result-combining op,
       - the parent `for` op.
 
+    The `reduction_dims` can be used to specify the subset of reduction dimensions
+    of the operation to tile. If left unspecified, all reduction dimensions are
+    tiled.
+
     #### Example:
 
     ```
@@ -1909,7 +1913,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
 
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
   let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
                       TransformHandleTypeInterface:$split_op,
                       TransformHandleTypeInterface:$combining_op,
@@ -1922,6 +1927,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
 
   let assemblyFormat = [{
     $target
+    (`reduction_dims` `=` $reduction_dims^)?
     `by` `tile_sizes` `=` $tile_sizes
     attr-dict
     `:` functional-type(operands, results)

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index f686ae07b9a99..9feb04dbe03c1 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,28 +85,21 @@ struct SCFTilingOptions {
     return *this;
   }
 
+  /// Specify mapping of loops to devices. This is only respected when the loop
+  /// constructs support such a mapping (like `scf.forall`). Will be ignored
+  /// when using loop constructs that dont support such a mapping (like
+  /// `scf.for`)
+  SmallVector<Attribute> mappingVector = {};
+  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+    mappingVector = llvm::to_vector(mapping);
+    return *this;
+  }
+
+  //-------------------------------------------------------------------------//
+  // Options related reduction tiling
+  //-------------------------------------------------------------------------//
+
   /// Specify how reduction dimensions should be tiled.
-  ///
-  /// Tiling can be thought of as splitting a dimension into 2 and materializing
-  /// the outer dimension as a loop:
-  ///
-  /// op[original] -> op[original / x, x] -> loop[original] { op[x] }
-  ///
-  /// For parallel dimensions, the split can only happen in one way, with both
-  /// dimensions being parallel. For reduction dimensions however, there is a
-  /// choice in how we split the reduction dimension. This enum exposes this
-  /// choice.
-  enum class ReductionTilingStrategy {
-    // [reduction] -> [reduction1, reduction2]
-    // -> loop[reduction1] { [reduction2] }
-    FullReduction,
-    // [reduction] -> [reduction1, parallel2]
-    // -> loop[reduction1] { [parallel2] }; merge[reduction1]
-    PartialReductionOuterReduction,
-    // [reduction] -> [parallel1, reduction2]
-    // -> loop[parallel1] { [reduction2] }; merge[parallel1]
-    PartialReductionOuterParallel
-  };
   ReductionTilingStrategy reductionStrategy =
       ReductionTilingStrategy::FullReduction;
   SCFTilingOptions &
@@ -115,13 +108,13 @@ struct SCFTilingOptions {
     return *this;
   }
 
-  /// Specify mapping of loops to devices. This is only respected when the loop
-  /// constructs support such a mapping (like `scf.forall`). Will be ignored
-  /// when using loop constructs that dont support such a mapping (like
-  /// `scf.for`)
-  SmallVector<Attribute> mappingVector = {};
-  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
-    mappingVector = llvm::to_vector(mapping);
+  /// Specify the reduction dimensions to be tiled. Note that this needs to be
+  /// specified. If left unspecified, then none of the reduction dimensions are
+  /// tiled.
+  SetVector<unsigned> reductionDims;
+  SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
+    reductionDims.clear();
+    reductionDims.insert(dims.begin(), dims.end());
     return *this;
   }
 };

diff  --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index b33aa1489c311..8693cbea7f0b0 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -36,6 +36,27 @@ struct TilingResult {
   SmallVector<Operation *> generatedSlices;
 };
 
+/// Tiling can be thought of as splitting a dimension into 2 and
+/// materializing the outer dimension as a loop:
+///
+/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
+///
+/// For parallel dimensions, the split can only happen in one way, with both
+/// dimensions being parallel. For reduction dimensions however, there is a
+/// choice in how we split the reduction dimension. This enum exposes this
+/// choice.
+enum class ReductionTilingStrategy {
+  // [reduction] -> [reduction1, reduction2]
+  // -> loop[reduction1] { [reduction2] }
+  FullReduction,
+  // [reduction] -> [reduction1, parallel2]
+  // -> loop[reduction1] { [parallel2] }; merge[reduction1]
+  PartialReductionOuterReduction,
+  // [reduction] -> [parallel1, reduction2]
+  // -> loop[parallel1] { [reduction2] }; merge[parallel1]
+  PartialReductionOuterParallel
+};
+
 /// Container for the result of merge operation of tiling.
 /// - `mergeOps` contains operations created during the merge.
 /// - `replacements` contains the values that represents the result of the

diff  --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index cdf3d01ce8a84..43a27e1cb6cdf 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -384,7 +384,7 @@ def PartialReductionOpInterface :
             "::mlir::OpBuilder &":$b,
             "Location":$loc,
             "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
-            "::mlir::ArrayRef<int>":$reductionDim),
+            "const ::mlir::SetVector<unsigned> &":$reductionDims),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -402,10 +402,11 @@ def PartialReductionOpInterface :
         /*args=*/(ins
             "::mlir::OpBuilder &":$b,
             "Location ":$loc,
+            "::mlir::ReductionTilingStrategy":$tilingStrategy,
             "ValueRange":$init,
             "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
             "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
-            "::mlir::ArrayRef<int>":$reductionDims),
+            "const ::llvm::SetVector<unsigned> &":$reductionDims),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -423,7 +424,7 @@ def PartialReductionOpInterface :
             "::mlir::OpBuilder &":$b,
             "Location ":$loc,
             "ValueRange":$partialReduce,
-            "::mlir::ArrayRef<int>":$reductionDim),
+            "const ::mlir::SetVector<unsigned> &":$reductionDims),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -443,9 +444,9 @@ def PartialReductionOpInterface :
             "unsigned":$resultNumber,
             "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
             "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+            "const ::mlir::SetVector<unsigned> &":$reductionDims,
             "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
-            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
-            "::mlir::ArrayRef<int>":$reductionDims),
+            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d9a0ba02f4fe4..f2b7b34256847 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2947,10 +2947,11 @@ void transform::TileReductionUsingForOp::build(
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   MLIRContext *ctx = builder.getContext();
   auto opTy = transform::AnyOpType::get(ctx);
-  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+  auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
   build(builder, result,
         /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
         /*target=*/target,
+        /*reduction_dims=*/nullptr,
         /*tile_sizes=*/staticTileSizesAttr);
 }
 
@@ -2966,12 +2967,30 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
         target->getLoc(),
         "Operation should implement PartialReductionOpInterface");
   }
-  FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
-      rewriter, partialReductionOp,
-      getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
 
-  if (failed(result))
-    return emitDefaultSilenceableFailure(target);
+  SmallVector<unsigned> reductionDims =
+      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
+  if (reductionDims.empty()) {
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+  }
+
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
+  options.setReductionTilingStrategy(
+      ReductionTilingStrategy::PartialReductionOuterReduction);
+  options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
+  options.setReductionDims(reductionDims);
+  FailureOr<scf::SCFTilingResult> result =
+      scf::tileUsingSCF(rewriter, partialReductionOp, options);
+
+  if (failed(result)) {
+    return emitSilenceableFailure(getLoc(),
+                                  "failed to tile using partial reduction");
+  }
   rewriter.replaceOp(target, result->replacements);
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 4162aa0b71e6d..8a5a2e54cdda2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -109,8 +109,7 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
 }
 
 FailureOr<StaticContinuousTileSizeSpecification>
-mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
-                                               unsigned dimension,
+mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
                                                unsigned targetSize) {
 
   assert(!op.hasDynamicShape() &&
@@ -183,8 +182,8 @@ mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
 
   // Find the trip count of the iteration space dimension for which the tile
   // sizes are computed.
-  Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
-                                                    loopRanges[dimension].size);
+  Value loopRange =
+      getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
   ContinuousTileSizeSpecification spec;
 
   // Compute the tile sizes and the respective numbers of tiles.
@@ -633,16 +632,18 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
     return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
                                     "many elements as number of threads");
-  int reductionDim = static_cast<int>(redDims.front());
 
   if (redDims.front() >= numThreads.size())
     return b.notifyMatchFailure(
         op, "reduction dimension must be mapped to threads");
 
   // 1. Create the inital tensor value.
+  unsigned reductionDim = redDims.front();
+  SetVector<unsigned> reductionDims;
+  reductionDims.insert(reductionDim);
   FailureOr<SmallVector<Value>> maybeInitTensors =
       op.generateInitialTensorForPartialReduction(b, loc, numThreads,
-                                                  reductionDim);
+                                                  reductionDims);
   if (failed(maybeInitTensors))
     return b.notifyMatchFailure(
         op, "Failed to create inital tensors for partial reduction");
@@ -780,7 +781,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   // 7. Merge the partial reductions.
   b.setInsertionPointAfter(forallOp);
   FailureOr<MergeResult> mergeResult =
-      op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
+      op.mergeReductions(b, loc, forallOp->getResults(), reductionDims);
   if (failed(mergeResult)) {
     return failure();
   }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 7c14cc16437fe..f649bc49a8fbd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include <optional>
@@ -327,23 +328,48 @@ struct LinalgOpTilingInterface
 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//
 
-/// Return an AffineMap for a partial result for the given result number,
-/// assuming the partial tiling strategy is outer-reduction loop +
-/// inner-parallel tile. The returned AffineMap can be used as the replacement
-/// AffineMap for the inner-parallel tile linalg op for the given result number.
-///
-/// The new AffineMap is the old AffineMap with reduction dimensions appended
-/// at end.
-static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
-                                           ArrayRef<int> reductionDims,
-                                           unsigned resultNumber) {
-  AffineMap map =
-      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
-  for (int redPos : reductionDims) {
-    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
-                           map.getNumResults());
+/// Return an AffineMaps to use for the `outs` operands of the linalg op
+/// generated for partial results. The new AffineMap is the AffineMap of the
+/// untiled op with reduction dimensions appended at end in order in which they
+/// were specified during tiling.
+static SmallVector<AffineMap>
+getPartialResultAffineMaps(LinalgOp linalgOp,
+                           const SetVector<unsigned> &reductionDims) {
+  auto partialReductionMaps = llvm::map_to_vector(
+      linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
+        AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
+        for (auto redPos : reductionDims) {
+          map =
+              map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+                               map.getNumResults());
+        }
+        return map;
+      });
+  return partialReductionMaps;
+}
+
+/// Return the slice of the `initValue` to use as input to the partial reduction
+/// op generated.
+static Operation *getInitSliceForOuterReduction(
+    OpBuilder &b, Location loc, Value initValue, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
+    AffineMap partialReductionMap) {
+  int64_t initRank = partialReductionMap.getNumResults();
+  SmallVector<OpFoldResult> initOffsets, initSizes;
+  SmallVector<OpFoldResult> initStrides(initRank, b.getIndexAttr(1));
+  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
+    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    if (reductionDims.contains(dim)) {
+      initOffsets.push_back(b.getIndexAttr(0));
+    } else {
+      initOffsets.push_back(offsets[dim]);
+    }
+    initSizes.push_back(sizes[dim]);
   }
-  return map;
+  // TODO: Use SubsetExtractOpInterface here once available.
+  auto extractSlice = b.create<tensor::ExtractSliceOp>(
+      loc, initValue, initOffsets, initSizes, initStrides);
+  return extractSlice;
 }
 
 /// External model implementation of PartialReductionInterface for
@@ -354,13 +380,16 @@ struct LinalgOpPartialReductionInterface
           LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
   FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
       Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
-      ArrayRef<int> reductionDims) const {
+      const SetVector<unsigned> &reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
-    OpBuilder::InsertionGuard guard(b);
 
+    OpBuilder::InsertionGuard guard(b);
     if (linalgOp.hasPureBufferSemantics())
       return op->emitOpError("expected operation to have tensor semantics");
 
+    SmallVector<AffineMap> partialResultMaps =
+        getPartialResultAffineMaps(linalgOp, reductionDims);
+
     // LinalgOp implements TilingInterface.
     auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
     SmallVector<OpFoldResult> shape =
@@ -377,8 +406,8 @@ struct LinalgOpPartialReductionInterface
     }
 
     SmallVector<Value> inits;
-    for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
-         ++initIdx) {
+    for (auto [initIdx, result, partialMap] :
+         llvm::enumerate(linalgOp->getResults(), partialResultMaps)) {
       SmallVector<Operation *, 4> combinerOps;
       if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                           combinerOps) ||
@@ -392,16 +421,13 @@ struct LinalgOpPartialReductionInterface
             "Failed to get an identity value for the reduction operation.");
 
       // Append the new partial result dimensions.
-      AffineMap partialMap =
-          getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
       SmallVector<OpFoldResult> partialResultShape;
       for (AffineExpr dimExpr : partialMap.getResults()) {
         auto dim = cast<AffineDimExpr>(dimExpr);
         partialResultShape.push_back(tiledShape[dim.getPosition()]);
       }
 
-      Type elType =
-          getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
+      Type elType = getElementTypeOrSelf(result.getType());
       Value emptyTensor =
           b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
       Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
@@ -415,23 +441,25 @@ struct LinalgOpPartialReductionInterface
 
   FailureOr<TilingResult>
   tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
+                         ReductionTilingStrategy tilingStrategy,
                          ValueRange init, ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes,
-                         ArrayRef<int> reductionDims) const {
+                         const SetVector<unsigned> &reductionDims) const {
+    if (tilingStrategy !=
+        ReductionTilingStrategy::PartialReductionOuterReduction) {
+      // TODO: Add support for `PartialReductionOuterParallel` strategy.
+      return op->emitOpError("unsupported partial reduction tiling with "
+                             "`PartialReductionOuterParallel` strategy");
+    }
     OpBuilder::InsertionGuard guard(b);
     auto linalgOp = cast<LinalgOp>(op);
 
+    SmallVector<AffineMap> partialReductionMaps =
+        getPartialResultAffineMaps(linalgOp, reductionDims);
+
     // Step 1. Extend init maps to have reduction dimension dims, since we
     // are converting them to parallel dimensions.
-    SmallVector<AffineMap> newInitMaps;
-    newInitMaps.reserve(linalgOp.getNumDpsInits());
-    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
-      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
-      // this with a for range loop when we have it.
-      AffineMap newMap =
-          getPartialResultAffineMap(linalgOp, reductionDims, idx);
-      newInitMaps.push_back(newMap);
-    }
+    SmallVector<AffineMap> newInitMaps = partialReductionMaps;
 
     // Step 2a: Extract a slice of the input operands.
     SmallVector<Value> tiledInputs = makeTiledShapes(
@@ -443,31 +471,21 @@ struct LinalgOpPartialReductionInterface
 
     // Step 2b: Extract a slice of the init operands.
     SmallVector<Value, 1> tiledInits;
-    for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
-      int64_t initRank = valueMap.getNumResults();
-      SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0));
-      SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1));
-      SmallVector<OpFoldResult> initSizes;
-      for (AffineExpr dimExpr : valueMap.getResults()) {
-        auto dim = cast<AffineDimExpr>(dimExpr);
-        initSizes.push_back(sizes[dim.getPosition()]);
-      }
-      // TODO: Use SubsetExtractOpInterface here once available.
-      auto extractSlice = b.create<tensor::ExtractSliceOp>(
-          loc, valueToTile, initOffset, initSizes, initStride);
-      tiledInits.push_back(extractSlice);
-      generatedSlices.push_back(extractSlice);
+    for (auto [partialReductionMap, valueToTile] :
+         llvm::zip_equal(partialReductionMaps, init)) {
+      Operation *sliceOp =
+          getInitSliceForOuterReduction(b, loc, valueToTile, offsets, sizes,
+                                        reductionDims, partialReductionMap);
+      tiledInits.push_back(sliceOp->getResult(0));
+      generatedSlices.push_back(sliceOp);
     }
 
     // Update the indexing maps.
     SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
-    // Change the init maps.
-    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
-      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
-      // this with a for range loop when we have it.
-      OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
-      int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
-      newMaps[mapIdx] = newInitMaps[idx];
+    for (auto [initOperand, newInitMap] :
+         llvm::zip_equal(linalgOp.getDpsInitsMutable(), newInitMaps)) {
+      int mapIdx = linalgOp.getIndexingMapIndex(&initOperand);
+      newMaps[mapIdx] = newInitMap;
     }
 
     // Step 3. Change the reduction dim iterator types.
@@ -477,9 +495,9 @@ struct LinalgOpPartialReductionInterface
       newIteratorTypes[dim] = utils::IteratorType::parallel;
 
     // Step 4. Create the new generic op.
-    auto genericOp =
-        b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
-                            tiledInits, newMaps, newIteratorTypes);
+    auto resultTypes = ValueRange(tiledInits).getTypes();
+    auto genericOp = b.create<GenericOp>(loc, resultTypes, tiledInputs,
+                                         tiledInits, newMaps, newIteratorTypes);
     IRMapping mapping;
     op->getRegion(0).cloneInto(&genericOp.getRegion(),
                                genericOp.getRegion().begin(), mapping);
@@ -490,23 +508,24 @@ struct LinalgOpPartialReductionInterface
         generatedSlices};
   }
 
-  FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
-                                         Location loc, ValueRange partialReduce,
-                                         ArrayRef<int> reductionDims) const {
+  FailureOr<MergeResult>
+  mergeReductions(Operation *op, OpBuilder &b, Location loc,
+                  ValueRange partialReduce,
+                  const SetVector<unsigned> &reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
+    SmallVector<AffineMap> partialReductionMaps =
+        getPartialResultAffineMaps(linalgOp, reductionDims);
 
     // Permute the reduction dims as permuted by the partial result map.
-
-    int64_t numInits = linalgOp.getNumDpsInits();
     SmallVector<Operation *> mergeOperations;
     SmallVector<Value> replacements;
-    for (int idx : llvm::seq(numInits)) {
+    for (auto [idx, init, partialResult, partialMap] : llvm::enumerate(
+             linalgOp.getDpsInits(), partialReduce, partialReductionMaps)) {
+      unsigned initIdx = idx;
       // linalg.reduce's iteration space is the tiled result's iteration space
       // (and not the tiled operation's iteration space). To account for this,
       // permute the reduction dimensions based on the partial result map of the
       // tiled result.
-      AffineMap partialMap =
-          getPartialResultAffineMap(linalgOp, reductionDims, idx);
       SmallVector<int64_t> partialReductionDims;
       for (auto [resultNum, dimExpr] :
            llvm::enumerate(partialMap.getResults())) {
@@ -516,15 +535,13 @@ struct LinalgOpPartialReductionInterface
         }
       }
 
-      Value partialResult = partialReduce[idx];
-      Value init = linalgOp.getDpsInits()[idx];
-
       auto reduction = b.create<linalg::ReduceOp>(
           loc, partialResult, init, partialReductionDims,
-          [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
+          [&linalgOp, &initIdx](OpBuilder &b, Location loc, ValueRange inputs) {
             // Get the combiner op.
             SmallVector<Operation *, 4> combinerOps;
-            matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
+            matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
+                           combinerOps);
             Operation *clonedReductionOp = b.clone(*combinerOps[0]);
             // Combine the input at idx and output at numInits + idx.
             clonedReductionOp->setOperand(0, inputs[0]);
@@ -542,14 +559,14 @@ struct LinalgOpPartialReductionInterface
   LogicalResult getPartialResultTilePosition(
       Operation *op, OpBuilder &b, unsigned resultNumber,
       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      const SetVector<unsigned> &reductionDims,
       SmallVector<OpFoldResult> &resultOffsets,
-      SmallVector<OpFoldResult> &resultSizes,
-      ArrayRef<int> reductionDims) const {
+      SmallVector<OpFoldResult> &resultSizes) const {
     auto linalgOp = cast<LinalgOp>(op);
+    SmallVector<AffineMap> partialReductionMaps =
+        getPartialResultAffineMaps(linalgOp, reductionDims);
 
-    AffineMap partialMap =
-        getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
-    for (AffineExpr dimExpr : partialMap.getResults()) {
+    for (AffineExpr dimExpr : partialReductionMaps[resultNumber].getResults()) {
       unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
       resultSizes.push_back(sizes[dim]);
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 3f29dd3ac5e48..e7c076024e67b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -77,9 +77,8 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
 //===----------------------------------------------------------------------===//
 
 /// Verify the tile size options are set in a consistent manner.
-static LogicalResult
-verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
-                      const scf::SCFTilingOptions &options) {
+static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc,
+                                   const scf::SCFTilingOptions &options) {
   // Specifying number of threads is only supported on `scf.forall` op.
   if (options.numThreadsComputationFunction &&
       options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
@@ -156,7 +155,9 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
 }
 
-/// Checks if any of the tiled loops are not parallel.
-static void checkSafeToTileToForall(TilingInterface op,
+/// Checks that tile sizes are valid for the loop type and reduction strategy.
+static LogicalResult checkTileSizes(TilingInterface op,
+                                    scf::SCFTilingOptions::LoopType loopType,
+                                    ReductionTilingStrategy reductionStrategy,
                                     ArrayRef<OpFoldResult> tileSizes,
                                     ArrayRef<OpFoldResult> numThreads) {
   auto iterators = op.getLoopIteratorTypes();
@@ -165,28 +166,46 @@ static void checkSafeToTileToForall(TilingInterface op,
   assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
          "when specified, expected number of threads to use for each loop");
 
+  bool isParallelTiling = false, isReductionTiling = false;
   for (auto [index, iterator, tileSize] :
        llvm::enumerate(iterators, tileSizes)) {
-    // If num threads is specified, check that it is greater than one only for
-    // parallel dimensions.
-    if (!numThreads.empty()) {
-      if (std::optional<int64_t> constNumThreads =
-              getConstantIntValue(numThreads[index])) {
-        if (constNumThreads.value() > 1 &&
+    if (!isConstantIntValue(tileSize, 0)) {
+      isParallelTiling |= iterator == utils::IteratorType::parallel;
+      isReductionTiling |= iterator == utils::IteratorType::reduction;
+    }
+
+    if (loopType == scf::SCFTilingOptions::LoopType::ForallOp &&
+        reductionStrategy == ReductionTilingStrategy::FullReduction) {
+      // If num threads is specified, check that it is greater than one only for
+      // parallel dimensions.
+      if (!numThreads.empty()) {
+        if (std::optional<int64_t> constNumThreads =
+                getConstantIntValue(numThreads[index])) {
+          if (constNumThreads.value() > 1 &&
+              iterator != utils::IteratorType::parallel) {
+            op.emitWarning() << "tiling is not thread safe at axis #" << index;
+          }
+        }
+        continue;
+      }
+
+      if (std::optional<int64_t> constTileSize =
+              getConstantIntValue(tileSize)) {
+        if (constTileSize.value() > 0 &&
             iterator != utils::IteratorType::parallel) {
           op.emitWarning() << "tiling is not thread safe at axis #" << index;
         }
       }
-      continue;
     }
+  }
 
-    if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
-      if (constTileSize.value() > 0 &&
-          iterator != utils::IteratorType::parallel) {
-        op.emitWarning() << "tiling is not thread safe at axis #" << index;
-      }
-    }
+  if (isParallelTiling && isReductionTiling &&
+      reductionStrategy != ReductionTilingStrategy::FullReduction) {
+    return op->emitOpError(
+        "combined parallel and reduction tiling is not supported with partial "
+        "reduction tiling strategies");
   }
+  return success();
 }
 
 /// Check if `stride` evenly divides the trip count `size - offset`.
@@ -575,35 +594,20 @@ createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
                               const scf::SCFTilingOptions &options) {
   SmallVector<Value> initTensors;
   Location loc = op->getLoc();
-  switch (options.reductionStrategy) {
-  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
     if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
       return failure();
     return initTensors;
-  case scf::SCFTilingOptions::ReductionTilingStrategy::
-      PartialReductionOuterReduction: {
-    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
-    if (!redOp) {
-      return rewriter.notifyMatchFailure(
-          op, "PartialReductionOuterReduction tiling strategy is only supported"
-              "for operations implementing PartialReductionOpInterface");
-    }
-    // Get reduction dimensions.
-    // TODO: PartialReductionOpInterface should really query TilingInterface
-    // itself and find reduction dimensions.
-    SmallVector<int> reductionDims;
-    for (auto [idx, iteratorType] :
-         llvm::enumerate(op.getLoopIteratorTypes())) {
-      if (iteratorType == utils::IteratorType::reduction)
-        reductionDims.push_back(idx);
-    }
-    return redOp.generateInitialTensorForPartialReduction(
-        rewriter, loc, tileSizes, reductionDims);
   }
-  default:
-    return rewriter.notifyMatchFailure(op,
-                                       "unhandled reduction tiling strategy");
+
+  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!redOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only supported "
+            "for operations implementing PartialReductionOpInterface");
   }
+  return redOp.generateInitialTensorForPartialReduction(
+      rewriter, loc, tileSizes, options.reductionDims);
 }
 
 static FailureOr<TilingResult>
@@ -611,34 +615,20 @@ getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
                        ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        const scf::SCFTilingOptions &options) {
-  switch (options.reductionStrategy) {
-  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
     return op.getTiledImplementation(rewriter, offsets, sizes);
-  case scf::SCFTilingOptions::ReductionTilingStrategy::
-      PartialReductionOuterReduction: {
-    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
-    if (!redOp) {
-      return rewriter.notifyMatchFailure(
-          op, "PartialReductionOuterReduction tiling strategy is only "
-              "supported for operations "
-              "implementing PartialReductionOpInterface");
-    }
-    // Get reduction dimensions.
-    // TODO: PartialReductionOpInterface should really query TilingInterface
-    // itself and find reduction dimensions.
-    SmallVector<int> reductionDims;
-    for (auto [idx, iteratorType] :
-         llvm::enumerate(op.getLoopIteratorTypes())) {
-      if (iteratorType == utils::IteratorType::reduction)
-        reductionDims.push_back(idx);
-    }
-    return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
-                                        offsets, sizes, reductionDims);
   }
-  default:
-    return rewriter.notifyMatchFailure(op,
-                                       "unhandled reduction tiling strategy");
+
+  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!redOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only "
+            "supported for operations "
+            "implementing PartialReductionOpInterface");
   }
+  return redOp.tileToPartialReduction(rewriter, op.getLoc(),
+                                      options.reductionStrategy, regionIterArg,
+                                      offsets, sizes, options.reductionDims);
 }
 
 static LogicalResult
@@ -649,70 +639,37 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
                       SmallVector<OpFoldResult> &resultSize,
                       const scf::SCFTilingOptions &options) {
 
-  switch (options.reductionStrategy) {
-  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
     return op.getResultTilePosition(rewriter, index, offsets, sizes,
                                     resultOffset, resultSize);
-  case scf::SCFTilingOptions::ReductionTilingStrategy::
-      PartialReductionOuterReduction: {
-    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
-    if (!redOp) {
-      return rewriter.notifyMatchFailure(
-          op, "PartialReductionOuterReduction tiling strategy is only supported"
-              "for operations implementing PartialReductionOpInterface");
-    }
-    // Get reduction dimensions.
-    // TODO: PartialReductionOpInterface should really query TilingInterface
-    // itself and find reduction dimensions.
-    SmallVector<int> reductionDims;
-    for (auto [idx, iteratorType] :
-         llvm::enumerate(op.getLoopIteratorTypes())) {
-      if (iteratorType == utils::IteratorType::reduction)
-        reductionDims.push_back(idx);
-    }
-    return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
-                                              resultOffset, resultSize,
-                                              reductionDims);
   }
-  default:
-    return rewriter.notifyMatchFailure(op,
-                                       "unhandled reduction tiling strategy");
+  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!redOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only supported "
+            "for operations implementing PartialReductionOpInterface");
   }
+  return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
+                                            options.reductionDims, resultOffset,
+                                            resultSize);
 }
 
 static FailureOr<MergeResult>
 mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
                    ValueRange partialResults,
                    const scf::SCFTilingOptions &options) {
-  switch (options.reductionStrategy) {
-  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
-    // No need to merge results for reduction tiling strategy.
-    return MergeResult{{}, partialResults};
-  case scf::SCFTilingOptions::ReductionTilingStrategy::
-      PartialReductionOuterReduction: {
-    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
-    if (!redOp) {
-      return rewriter.notifyMatchFailure(
-          op, "PartialReductionOuterReduction tiling strategy is only "
-              "supported for operations "
-              "implementing PartialReductionOpInterface");
-    }
-    // Get reduction dimensions.
-    // TODO: PartialReductionOpInterface should really query TilingInterface
-    // itself and find reduction dimensions.
-    SmallVector<int> reductionDims;
-    for (auto [idx, iteratorType] :
-         llvm::enumerate(op.getLoopIteratorTypes())) {
-      if (iteratorType == utils::IteratorType::reduction)
-        reductionDims.push_back(idx);
-    }
-    return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
-                                 reductionDims);
-  }
-  default:
-    return rewriter.notifyMatchFailure(op,
-                                       "unhandled reduction tiling strategy");
+  assert(options.reductionStrategy != ReductionTilingStrategy::FullReduction &&
+         "expected merge to be called for only partial reduction cases");
+
+  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!redOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only "
+            "supported for operations "
+            "implementing PartialReductionOpInterface");
   }
+  return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
+                               options.reductionDims);
 }
 
 /// Append the specified additional `newInitOperands` operands to the
@@ -932,7 +889,7 @@ static LogicalResult addInitOperandsToLoopNest(
 FailureOr<scf::SCFTilingResult>
 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
                         const scf::SCFTilingOptions &options) {
-  if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
+  if (failed(verifyOptions(rewriter, op.getLoc(), options))) {
     return failure();
   }
 
@@ -949,8 +906,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
 
   // Check if it is safe to tile. This is hold over from previous iterations
   // of tile to for-all. Consider dropping it.
-  if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
-    checkSafeToTileToForall(op, tileSizes, numThreads);
+  if (failed(checkTileSizes(op, options.loopType, options.reductionStrategy,
+                            tileSizes, numThreads))) {
+    return failure();
   }
 
   // 3. If there is an interchange specified, permute the iteration domain and
@@ -1073,8 +1031,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
                                          [](OpResult r) -> Value { return r; });
 
   // For the full reduction case, there is nothing more to do.
-  if (options.reductionStrategy ==
-      scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction) {
+  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
     return scf::SCFTilingResult{
         tilingResult->tiledOps,        initTensors, loops, loopResults,
         tilingResult->generatedSlices, {}};
@@ -1102,9 +1059,13 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   scf::SCFTilingOptions options;
   options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
   options.setReductionTilingStrategy(
-      scf::SCFTilingOptions::ReductionTilingStrategy::
-          PartialReductionOuterReduction);
+      ReductionTilingStrategy::PartialReductionOuterReduction);
   options.setTileSizes(tileSize);
+  SmallVector<unsigned> reductionDims;
+  for (auto [index, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes()))
+    if (iteratorType == utils::IteratorType::reduction)
+      reductionDims.push_back(index);
+  options.setReductionDims(reductionDims);
   return tileUsingSCF(b, op, options);
 }
 

diff  --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 9d34c80822d0e..009ab17786696 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -343,7 +343,6 @@ module attributes {transform.with_named_sequence} {
 module {
   func.func @fail_for_float_neutral(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
     // expected-error @below {{'linalg.generic' op Failed to get an identity value for the reduction operation.}}
-    // expected-note @below {{when applied to this op}}
     %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
     ^bb0(%in: f32, %out: f32):
       %1 = llvm.fmul %in, %in  : f32
@@ -355,7 +354,7 @@ module {
   module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
       %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-      // expected-error @below {{transform.structured.tile_reduction_using_for failed to apply}}
+      // expected-error @below {{failed to tile using partial reduction}}
       %fill_op, %split_linalg_op, %combining_linalg_op, %for_op = transform.structured.tile_reduction_using_for %0 by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
       transform.yield
     }
@@ -480,3 +479,167 @@ module attributes {transform.with_named_sequence} {
 //     CHECK:   }
 //     CHECK:   linalg.reduce
 //     CHECK:   return
+
+// -----
+
+// Check that just one of several reduction dimensions can be tiled (in this case the outer one).
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+  func.func @reduction_tile_single_of_multiple_reduction_outer(
+        %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+    %0 = linalg.generic {
+        indexing_maps = [#map, #map1, #map2],
+        iterator_types = ["parallel", "reduction", "reduction"]}
+        ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.mulf %in, %in_0 : f32
+      %2 = arith.addf %1, %out : f32
+      linalg.yield %2 : f32
+    } -> tensor<4096xf32>
+    return %0 : tensor<4096xf32>
+  }
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %fill_op, %split_linalg_op, %combining_linalg_op, %for_op =
+          transform.structured.tile_reduction_using_for %0 reduction_dims = [1] by tile_sizes = [0, 2]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+//      CHECK: #[[INIT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+//      CHECK: @reduction_tile_single_of_multiple_reduction_outer(
+// CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<4096xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C86:.+]] = arith.constant 86 : index
+//  CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<4096x2xf32>
+//      CHECK:   %[[FILL:.+]] = linalg.fill
+// CHECK-SAME:       outs(%[[EMPTY]] :
+//      CHECK:   %[[RESULT:.+]] = scf.for %[[IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[C86]] step %[[C2]]
+// CHECK-SAME:       iter_args(%[[ITER_ARG:.+]] = %[[FILL]])
+//      CHECK:     %[[PARTIAL_RESULT:.+]] = linalg.generic
+// CHECK-SAME:         indexing_maps = [#{{.+}}, #{{.+}}, #[[INIT_MAP]]]
+// CHECK-SAME:         iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME:         outs(%[[ITER_ARG]] :
+//      CHECK:     scf.yield %[[PARTIAL_RESULT]]
+//      CHECK:   %[[REDUCE:.+]] = linalg.reduce
+// CHECK-SAME:       ins(%[[RESULT]] :
+// CHECK-SAME:       outs(%[[INIT]] :
+// CHECK-SAME:       dimensions = [1]
+//      CHECK:   return %[[REDUCE]]
+
+// -----
+
+// Check that just one of several reduction dimensions can be tiled (in this case the inner one).
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+  func.func @reduction_tile_single_of_multiple_reduction_inner(
+        %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+    %0 = linalg.generic {
+        indexing_maps = [#map, #map1, #map2],
+        iterator_types = ["parallel", "reduction", "reduction"]}
+        ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.mulf %in, %in_0 : f32
+      %2 = arith.addf %1, %out : f32
+      linalg.yield %2 : f32
+    } -> tensor<4096xf32>
+    return %0 : tensor<4096xf32>
+  }
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %fill_op, %split_linalg_op, %combining_linalg_op, %for_op =
+          transform.structured.tile_reduction_using_for %0 reduction_dims = [2] by tile_sizes = [0, 0, 64]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+//      CHECK: #[[INIT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+//      CHECK: @reduction_tile_single_of_multiple_reduction_inner(
+// CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<4096xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C64:.+]] = arith.constant 64 : index
+//  CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
+//  CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<4096x64xf32>
+//      CHECK:   %[[FILL:.+]] = linalg.fill
+// CHECK-SAME:       outs(%[[EMPTY]] :
+//      CHECK:   %[[RESULT:.+]] = scf.for %[[IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C64]]
+// CHECK-SAME:       iter_args(%[[ITER_ARG:.+]] = %[[FILL]])
+//      CHECK:     %[[PARTIAL_RESULT:.+]] = linalg.generic
+// CHECK-SAME:         indexing_maps = [#{{.+}}, #{{.+}}, #[[INIT_MAP]]]
+// CHECK-SAME:         iterator_types = ["parallel", "reduction", "parallel"]
+// CHECK-SAME:         outs(%[[ITER_ARG]] :
+//      CHECK:     scf.yield %[[PARTIAL_RESULT]]
+//      CHECK:   %[[REDUCE:.+]] = linalg.reduce
+// CHECK-SAME:       ins(%[[RESULT]] :
+// CHECK-SAME:       outs(%[[INIT]] :
+// CHECK-SAME:       dimensions = [1]
+//      CHECK:   return %[[REDUCE]]
+
+// -----
+
+// Check that both reduction dimensions are tiled and that the partial result's extra dimensions follow the `reduction_dims` order (swapped here).
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+  func.func @reduction_tile_single_of_multiple_reduction_reversed(
+        %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+    %0 = linalg.generic {
+        indexing_maps = [#map, #map1, #map2],
+        iterator_types = ["parallel", "reduction", "reduction"]}
+        ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.mulf %in, %in_0 : f32
+      %2 = arith.addf %1, %out : f32
+      linalg.yield %2 : f32
+    } -> tensor<4096xf32>
+    return %0 : tensor<4096xf32>
+  }
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %fill_op, %split_linalg_op, %combining_linalg_op, %for_op =
+          transform.structured.tile_reduction_using_for %0 reduction_dims = [2, 1] by tile_sizes = [0, 2, 64]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+//      CHECK: #[[INIT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+//      CHECK: @reduction_tile_single_of_multiple_reduction_reversed(
+// CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<4096xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C64:.+]] = arith.constant 64 : index
+//  CHECK-DAG:   %[[C86:.+]] = arith.constant 86 : index
+//  CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
+//  CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<4096x64x2xf32>
+//      CHECK:   %[[FILL:.+]] = linalg.fill
+// CHECK-SAME:       outs(%[[EMPTY]] :
+//      CHECK:   %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C86]] step %[[C2]]
+// CHECK-SAME:       iter_args(%[[ITER_ARG:.+]] = %[[FILL]])
+//      CHECK:     %[[RESULT0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C64]]
+// CHECK-SAME:         iter_args(%[[ITER_ARG0:.+]] = %[[ITER_ARG]])
+//      CHECK:       %[[PARTIAL_RESULT:.+]] = linalg.generic
+// CHECK-SAME:           indexing_maps = [#{{.+}}, #{{.+}}, #[[INIT_MAP]]]
+// CHECK-SAME:           iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME:           outs(%[[ITER_ARG0]] :
+//      CHECK:       scf.yield %[[PARTIAL_RESULT]]
+//      CHECK:     scf.yield %[[RESULT0]]
+//      CHECK:   %[[REDUCE:.+]] = linalg.reduce
+// CHECK-SAME:       ins(%[[RESULT]] :
+// CHECK-SAME:       outs(%[[INIT]] :
+// CHECK-SAME:       dimensions = [1, 2]
+//      CHECK: return %[[REDUCE]]


        


More information about the Mlir-commits mailing list