[Mlir-commits] [mlir] [mlir][PartialReductionTilingInterface] Add support for `ReductionTilingStrategy::PartialReductionOuterParallel` in `tileUsingSCF`. (PR #143988)

llvmlistbot at llvm.org
Thu Jun 12 15:55:34 PDT 2025


https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/143988

Following up on https://github.com/llvm/llvm-project/pull/143467,
this PR adds support for
`ReductionTilingStrategy::PartialReductionOuterParallel` to
`tileUsingSCF`. The implementation of
`PartialReductionTilingInterface` for `Linalg` ops has been updated to
support this strategy as well. This brings `tileUsingSCF` on
par with `linalg::tileReductionUsingForall`, which will be deprecated
subsequently.
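To make the difference between the strategies concrete, here is a sketch in plain C++ (not MLIR) of a 1-D sum whose reduction dimension is split according to each `ReductionTilingStrategy` value, following the loop structures described in the enum's comments. The function `tiledSum` is hypothetical, and its loops only stand in for the `scf.for`/`scf.forall` nests that tiling would generate. The outer-parallel loop is written sequentially here, but each iteration writes only its own partial-result slot, which is what allows it to become an `scf.forall`.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Mirrors the enum added to TilingInterface.h; used here only to pick a
// plain C++ loop nest, not any MLIR API.
enum class ReductionTilingStrategy {
  FullReduction,
  PartialReductionOuterReduction,
  PartialReductionOuterParallel
};

// Hypothetical helper: sums `a` after splitting its reduction dimension of
// size N into [ceil(N / x), x] according to `strategy`.
long tiledSum(const std::vector<long> &a, std::size_t x,
              ReductionTilingStrategy strategy) {
  assert(x > 0 && "tile size must be positive");
  const std::size_t n = a.size();
  const std::size_t numTiles = (n + x - 1) / x;
  switch (strategy) {
  case ReductionTilingStrategy::FullReduction: {
    // loop[reduction1] { [reduction2] }: one accumulator carried through
    // every iteration; no merge step.
    long acc = 0;
    for (std::size_t t = 0; t < numTiles; ++t)
      for (std::size_t j = 0; j < x && t * x + j < n; ++j)
        acc += a[t * x + j];
    return acc;
  }
  case ReductionTilingStrategy::PartialReductionOuterReduction: {
    // loop[reduction1] { [parallel2] }: one accumulator per position inside
    // the tile, so the inner dimension becomes parallel.
    std::vector<long> partial(x, 0);
    for (std::size_t t = 0; t < numTiles; ++t)
      for (std::size_t j = 0; j < x && t * x + j < n; ++j)
        partial[j] += a[t * x + j];
    // Merge the partial results after the loop.
    return std::accumulate(partial.begin(), partial.end(), 0L);
  }
  case ReductionTilingStrategy::PartialReductionOuterParallel: {
    // loop[parallel1] { [reduction2] }: each outer iteration owns one slot
    // of the partial results, so outer iterations are independent.
    std::vector<long> partial(numTiles, 0);
    for (std::size_t t = 0; t < numTiles; ++t)
      for (std::size_t j = 0; j < x && t * x + j < n; ++j)
        partial[t] += a[t * x + j];
    // Merge the per-iteration partial results after the loop.
    return std::accumulate(partial.begin(), partial.end(), 0L);
  }
  }
  return 0; // unreachable: all enumerators are handled above
}
```

For example, summing `{1, ..., 7}` with a tile size of 3 gives 28 under all three strategies; only the shape of the intermediate partial results differs (none, one accumulator per tile position, or one per outer iteration).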

Changes summary
- `PartialReductionTilingInterface` changes:
  - The `tileToPartialReduction` method now needs the induction
    variables of the generated tile loops. This keeps the generated
    code similar to that of `linalg::tileReductionUsingForall`,
    specifically by creating a simplified access for slicing the
    intermediate partial results tensor when tiled in `num_threads` mode.
  - The `getPartialResultTilePosition` method needs the induction
    variables of the generated tile loops for the same reason, and
    also needs the `tilingStrategy` to be passed in to generate
    correct code.
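The per-strategy slicing of the partial results tensor can be modeled in plain C++ as below. The outer-reduction branch follows `getInitSliceForOuterReduction` from the first commit (offset 0 and the full tile size along reduction dimensions). The outer-parallel branch is an assumption: it supposes that in `num_threads` mode the partial results tensor has one slot per thread along each tiled reduction dimension (consistent with `tileReductionUsingForall` passing `numThreads` as the sizes to `generateInitialTensorForPartialReduction`), so each `scf.forall` iteration can address its slot directly with its induction variable. `Slice` and `partialResultSlice` are hypothetical names, and affine maps are reduced to lists of loop-dimension positions.

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

enum class ReductionTilingStrategy {
  FullReduction,
  PartialReductionOuterReduction,
  PartialReductionOuterParallel
};

// Hypothetical stand-in for the offsets/sizes of a tensor.extract_slice.
struct Slice {
  std::vector<std::int64_t> offsets;
  std::vector<std::int64_t> sizes;
};

// `partialMap[r]` is the loop dimension indexing result dimension `r` of the
// partial results tensor (the init map with reduction dims appended).
// `offsets`/`sizes` describe the iteration-space tile, and `ivs` holds the
// induction variable of each loop (a thread index in num_threads mode).
Slice partialResultSlice(const std::vector<unsigned> &partialMap,
                         const std::vector<std::int64_t> &offsets,
                         const std::vector<std::int64_t> &sizes,
                         const std::vector<std::int64_t> &ivs,
                         const std::set<unsigned> &reductionDims,
                         ReductionTilingStrategy strategy) {
  assert(strategy != ReductionTilingStrategy::FullReduction &&
         "full reduction has no partial results tensor");
  Slice slice;
  for (unsigned dim : partialMap) {
    if (reductionDims.count(dim) == 0) {
      // Parallel dimensions follow the iteration-space tile unchanged.
      slice.offsets.push_back(offsets[dim]);
      slice.sizes.push_back(sizes[dim]);
    } else if (strategy ==
               ReductionTilingStrategy::PartialReductionOuterReduction) {
      // Every iteration reuses the same tile-sized accumulator.
      slice.offsets.push_back(0);
      slice.sizes.push_back(sizes[dim]);
    } else {
      // Assumed outer-parallel/num_threads layout: one slot per thread,
      // selected directly by the induction variable.
      slice.offsets.push_back(ivs[dim]);
      slice.sizes.push_back(1);
    }
  }
  return slice;
}
```

Indexing by the induction variable avoids recomputing a slot index from the tile offset, which is one plausible reading of the "simplified access" mentioned above.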

The tests in `transform-tile-reduction.mlir` that exercised
`linalg::tileReductionUsingForall` have been moved over to test
`scf::tileUsingSCF` with the
`ReductionTilingStrategy::PartialReductionOuterParallel`
strategy. Some of the tests that additionally applied a cyclic distribution
to the code produced by tiling have been removed. Those look like two
separate transformations merged into one; ideally, the distribution
should happen when resolving the `scf.forall` rather than during
tiling.

Please review only the top commit.  Depends on https://github.com/llvm/llvm-project/pull/143467

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>

From c348b19112718f5ba4c75b2e8bcb8ca072cdf992 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Sun, 8 Jun 2025 16:20:06 -0700
Subject: [PATCH 1/2] [mlir][PartialReductionTilingInterface] Generalize
 implementation of `tileUsingSCF` for
 `ReductionTilingStrategy::PartialReductionOuterReduction`.

This is a precursor to generalizing `tileUsingSCF` to handle the
`ReductionTilingStrategy::PartialReductionOuterParallel` strategy. This
change itself generalizes/refactors the current implementation, which
supports only `ReductionTilingStrategy::PartialReductionOuterReduction`.

Changes in this PR
- Move the `ReductionTilingStrategy` enum out of
  `scf::SCFTilingOptions` and make it visible to `TilingInterface`.
- `PartialReductionOpInterface` changes:
  - Pass the `tilingStrategy` used for partial reduction to
    `tileToPartialReduction`.
  - Pass the reduction dimensions along as
    `const llvm::SetVector<unsigned> &`.
- Allow `scf::SCFTilingOptions` to set the reduction dimensions that
  are to be tiled.
- Change `structured.tile_reduction_using_for` to allow specifying
  the reduction dimensions to be partially tiled.
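The reduction-dimension handling described above can be modeled in a few lines of plain C++. The sketch mirrors two pieces of this commit under simplifying assumptions: when `reduction_dims` is empty, `TileReductionUsingForOp::applyToOne` falls back to every loop whose iterator type is `reduction`, and `getPartialResultAffineMaps` appends the reduction dimensions, in the order given, to each init operand's indexing map. Affine maps are reduced to lists of dimension positions, and `orderedUnique` stands in for the insertion-ordered deduplication of `llvm::SetVector`; `IteratorType`, `selectReductionDims`, and `partialResultMap` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

enum class IteratorType { parallel, reduction };

// Keeps the first occurrence of each dimension, like inserting into an
// llvm::SetVector.
std::vector<unsigned> orderedUnique(const std::vector<unsigned> &dims) {
  std::vector<unsigned> out;
  for (unsigned d : dims)
    if (std::find(out.begin(), out.end(), d) == out.end())
      out.push_back(d);
  return out;
}

// Uses the requested dimensions if any, otherwise every reduction loop.
std::vector<unsigned>
selectReductionDims(const std::vector<IteratorType> &iterators,
                    const std::vector<unsigned> &requested) {
  if (!requested.empty())
    return orderedUnique(requested);
  std::vector<unsigned> dims;
  for (std::size_t i = 0; i < iterators.size(); ++i)
    if (iterators[i] == IteratorType::reduction)
      dims.push_back(static_cast<unsigned>(i));
  return dims;
}

// Appends the reduction dims, in order, to an init map given as the list of
// loop dimensions it indexes.
std::vector<unsigned>
partialResultMap(std::vector<unsigned> initMapDims,
                 const std::vector<unsigned> &reductionDims) {
  initMapDims.insert(initMapDims.end(), reductionDims.begin(),
                     reductionDims.end());
  return initMapDims;
}
```

Preserving the user-specified order matters because it fixes the order of the trailing dimensions of the partial results tensor.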

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../Linalg/TransformOps/LinalgTransformOps.td |   8 +-
 .../SCF/Transforms/TileUsingInterface.h       |  46 ++--
 .../include/mlir/Interfaces/TilingInterface.h |  21 ++
 .../mlir/Interfaces/TilingInterface.td        |  11 +-
 .../TransformOps/LinalgTransformOps.cpp       |  31 ++-
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp |  15 +-
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 169 +++++++-------
 .../SCF/Transforms/TileUsingInterface.cpp     | 213 +++++++-----------
 .../Linalg/transform-tile-reduction.mlir      | 167 +++++++++++++-
 9 files changed, 430 insertions(+), 251 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15ea5e7bf7159..d0591ae122fbb 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1767,6 +1767,10 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
       - the result-combining op,
       - the parent `for` op.
 
+    The `reduction_dims` can be used to specify the subset of reduction dimensions
+    of the operation to tile. If left unspecified, all reduction dimensions are
+    tiled.
+
     #### Example:
 
     ```
@@ -1817,7 +1821,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
 
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
   let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
                       TransformHandleTypeInterface:$split_op,
                       TransformHandleTypeInterface:$combining_op,
@@ -1830,6 +1835,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
 
   let assemblyFormat = [{
     $target
+    (`reduction_dims` `=` $reduction_dims^)?
     `by` `tile_sizes` `=` $tile_sizes
     attr-dict
     `:` functional-type(operands, results)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index f686ae07b9a99..01ad64b76b15e 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,28 +85,21 @@ struct SCFTilingOptions {
     return *this;
   }
 
+  /// Specify mapping of loops to devices. This is only respected when the loop
+  /// constructs support such a mapping (like `scf.forall`). Will be ignored
+  /// when using loop constructs that dont support such a mapping (like
+  /// `scf.for`)
+  SmallVector<Attribute> mappingVector = {};
+  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+    mappingVector = llvm::to_vector(mapping);
+    return *this;
+  }
+
+  //-------------------------------------------------------------------------//
+  // Options related reduction tiling
+  //-------------------------------------------------------------------------//
+
   /// Specify how reduction dimensions should be tiled.
-  ///
-  /// Tiling can be thought of as splitting a dimension into 2 and materializing
-  /// the outer dimension as a loop:
-  ///
-  /// op[original] -> op[original / x, x] -> loop[original] { op[x] }
-  ///
-  /// For parallel dimensions, the split can only happen in one way, with both
-  /// dimensions being parallel. For reduction dimensions however, there is a
-  /// choice in how we split the reduction dimension. This enum exposes this
-  /// choice.
-  enum class ReductionTilingStrategy {
-    // [reduction] -> [reduction1, reduction2]
-    // -> loop[reduction1] { [reduction2] }
-    FullReduction,
-    // [reduction] -> [reduction1, parallel2]
-    // -> loop[reduction1] { [parallel2] }; merge[reduction1]
-    PartialReductionOuterReduction,
-    // [reduction] -> [parallel1, reduction2]
-    // -> loop[parallel1] { [reduction2] }; merge[parallel1]
-    PartialReductionOuterParallel
-  };
   ReductionTilingStrategy reductionStrategy =
       ReductionTilingStrategy::FullReduction;
   SCFTilingOptions &
@@ -115,13 +108,10 @@ struct SCFTilingOptions {
     return *this;
   }
 
-  /// Specify mapping of loops to devices. This is only respected when the loop
-  /// constructs support such a mapping (like `scf.forall`). Will be ignored
-  /// when using loop constructs that dont support such a mapping (like
-  /// `scf.for`)
-  SmallVector<Attribute> mappingVector = {};
-  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
-    mappingVector = llvm::to_vector(mapping);
+  /// Specify the reduction dimensions to be tiled.
+  SetVector<unsigned> reductionDims;
+  SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
+    reductionDims.insert(dims.begin(), dims.end());
     return *this;
   }
 };
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index b33aa1489c311..8693cbea7f0b0 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -36,6 +36,27 @@ struct TilingResult {
   SmallVector<Operation *> generatedSlices;
 };
 
+/// Tiling can be thought of as splitting a dimension into 2 and
+/// materializing the outer dimension as a loop:
+///
+/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
+///
+/// For parallel dimensions, the split can only happen in one way, with both
+/// dimensions being parallel. For reduction dimensions however, there is a
+/// choice in how we split the reduction dimension. This enum exposes this
+/// choice.
+enum class ReductionTilingStrategy {
+  // [reduction] -> [reduction1, reduction2]
+  // -> loop[reduction1] { [reduction2] }
+  FullReduction,
+  // [reduction] -> [reduction1, parallel2]
+  // -> loop[reduction1] { [parallel2] }; merge[reduction1]
+  PartialReductionOuterReduction,
+  // [reduction] -> [parallel1, reduction2]
+  // -> loop[parallel1] { [reduction2] }; merge[parallel1]
+  PartialReductionOuterParallel
+};
+
 /// Container for the result of merge operation of tiling.
 /// - `mergeOps` contains operations created during the merge.
 /// - `replacements` contains the values that represents the result of the
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index cdf3d01ce8a84..9358d8b82abce 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -384,7 +384,7 @@ def PartialReductionOpInterface :
             "::mlir::OpBuilder &":$b,
             "Location":$loc,
             "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
-            "::mlir::ArrayRef<int>":$reductionDim),
+            "const ::mlir::SetVector<unsigned> &":$reductionDim),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -402,10 +402,11 @@ def PartialReductionOpInterface :
         /*args=*/(ins
             "::mlir::OpBuilder &":$b,
             "Location ":$loc,
+            "::mlir::ReductionTilingStrategy":$tilingStrategy,
             "ValueRange":$init,
             "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
             "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
-            "::mlir::ArrayRef<int>":$reductionDims),
+            "const ::llvm::SetVector<unsigned> &":$reductionDims),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -423,7 +424,7 @@ def PartialReductionOpInterface :
             "::mlir::OpBuilder &":$b,
             "Location ":$loc,
             "ValueRange":$partialReduce,
-            "::mlir::ArrayRef<int>":$reductionDim),
+            "const ::mlir::SetVector<unsigned> &":$reductionDims),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -443,9 +444,9 @@ def PartialReductionOpInterface :
             "unsigned":$resultNumber,
             "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
             "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+            "const ::mlir::SetVector<unsigned> &":$reductionDims,
             "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
-            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
-            "::mlir::ArrayRef<int>":$reductionDims),
+            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b2c28f5eed33c..c003825264920 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2775,10 +2775,11 @@ void transform::TileReductionUsingForOp::build(
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   MLIRContext *ctx = builder.getContext();
   auto opTy = transform::AnyOpType::get(ctx);
-  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+  auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
   build(builder, result,
         /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
         /*target=*/target,
+        /*reduction_dims=*/nullptr,
         /*tile_sizes=*/staticTileSizesAttr);
 }
 
@@ -2794,12 +2795,30 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
         target->getLoc(),
         "Operation should implement PartialReductionOpInterface");
   }
-  FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
-      rewriter, partialReductionOp,
-      getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
 
-  if (failed(result))
-    return emitDefaultSilenceableFailure(target);
+  SmallVector<unsigned> reductionDims =
+      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
+  if (reductionDims.empty()) {
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+  }
+
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
+  options.setReductionTilingStrategy(
+      ReductionTilingStrategy::PartialReductionOuterReduction);
+  options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
+  options.setReductionDims(reductionDims);
+  FailureOr<scf::SCFTilingResult> result =
+      scf::tileUsingSCF(rewriter, partialReductionOp, options);
+
+  if (failed(result)) {
+    return emitSilenceableFailure(getLoc(),
+                                  "failed to tile using partial reduction");
+  }
   rewriter.replaceOp(target, result->replacements);
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 4162aa0b71e6d..8a5a2e54cdda2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -109,8 +109,7 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
 }
 
 FailureOr<StaticContinuousTileSizeSpecification>
-mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
-                                               unsigned dimension,
+mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
                                                unsigned targetSize) {
 
   assert(!op.hasDynamicShape() &&
@@ -183,8 +182,8 @@ mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
 
   // Find the trip count of the iteration space dimension for which the tile
   // sizes are computed.
-  Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
-                                                    loopRanges[dimension].size);
+  Value loopRange =
+      getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
   ContinuousTileSizeSpecification spec;
 
   // Compute the tile sizes and the respective numbers of tiles.
@@ -633,16 +632,18 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
     return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
                                     "many elements as number of threads");
-  int reductionDim = static_cast<int>(redDims.front());
 
   if (redDims.front() >= numThreads.size())
     return b.notifyMatchFailure(
         op, "reduction dimension must be mapped to threads");
 
   // 1. Create the inital tensor value.
+  unsigned reductionDim = redDims.front();
+  SetVector<unsigned> reductionDims;
+  reductionDims.insert(reductionDim);
   FailureOr<SmallVector<Value>> maybeInitTensors =
       op.generateInitialTensorForPartialReduction(b, loc, numThreads,
-                                                  reductionDim);
+                                                  reductionDims);
   if (failed(maybeInitTensors))
     return b.notifyMatchFailure(
         op, "Failed to create inital tensors for partial reduction");
@@ -780,7 +781,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   // 7. Merge the partial reductions.
   b.setInsertionPointAfter(forallOp);
   FailureOr<MergeResult> mergeResult =
-      op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
+      op.mergeReductions(b, loc, forallOp->getResults(), reductionDims);
   if (failed(mergeResult)) {
     return failure();
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 7c14cc16437fe..f649bc49a8fbd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include <optional>
@@ -327,23 +328,48 @@ struct LinalgOpTilingInterface
 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//
 
-/// Return an AffineMap for a partial result for the given result number,
-/// assuming the partial tiling strategy is outer-reduction loop +
-/// inner-parallel tile. The returned AffineMap can be used as the replacement
-/// AffineMap for the inner-parallel tile linalg op for the given result number.
-///
-/// The new AffineMap is the old AffineMap with reduction dimensions appended
-/// at end.
-static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
-                                           ArrayRef<int> reductionDims,
-                                           unsigned resultNumber) {
-  AffineMap map =
-      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
-  for (int redPos : reductionDims) {
-    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
-                           map.getNumResults());
+/// Return an AffineMaps to use for the `outs` operands of the linalg op
+/// generated for partial results. The new AffineMap is the AffineMap of the
+/// untiled op with reduction dimensions appended at end in order in which they
+/// were specified during tiling.
+static SmallVector<AffineMap>
+getPartialResultAffineMaps(LinalgOp linalgOp,
+                           const SetVector<unsigned> &reductionDims) {
+  auto partialReductionMaps = llvm::map_to_vector(
+      linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
+        AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
+        for (auto redPos : reductionDims) {
+          map =
+              map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+                               map.getNumResults());
+        }
+        return map;
+      });
+  return partialReductionMaps;
+}
+
+/// Return the slice of the `initValue` to use as input to the partial reduction
+/// op generated.
+static Operation *getInitSliceForOuterReduction(
+    OpBuilder &b, Location loc, Value initValue, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
+    AffineMap partialReductionMap) {
+  int64_t initRank = partialReductionMap.getNumResults();
+  SmallVector<OpFoldResult> initOffsets, initSizes;
+  SmallVector<OpFoldResult> initStrides(initRank, b.getIndexAttr(1));
+  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
+    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    if (reductionDims.contains(dim)) {
+      initOffsets.push_back(b.getIndexAttr(0));
+    } else {
+      initOffsets.push_back(offsets[dim]);
+    }
+    initSizes.push_back(sizes[dim]);
   }
-  return map;
+  // TODO: Use SubsetExtractOpInterface here once available.
+  auto extractSlice = b.create<tensor::ExtractSliceOp>(
+      loc, initValue, initOffsets, initSizes, initStrides);
+  return extractSlice;
 }
 
 /// External model implementation of PartialReductionInterface for
@@ -354,13 +380,16 @@ struct LinalgOpPartialReductionInterface
           LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
   FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
       Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
-      ArrayRef<int> reductionDims) const {
+      const SetVector<unsigned> &reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
-    OpBuilder::InsertionGuard guard(b);
 
+    OpBuilder::InsertionGuard guard(b);
     if (linalgOp.hasPureBufferSemantics())
       return op->emitOpError("expected operation to have tensor semantics");
 
+    SmallVector<AffineMap> partialResultMaps =
+        getPartialResultAffineMaps(linalgOp, reductionDims);
+
     // LinalgOp implements TilingInterface.
     auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
     SmallVector<OpFoldResult> shape =
@@ -377,8 +406,8 @@ struct LinalgOpPartialReductionInterface
     }
 
     SmallVector<Value> inits;
-    for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
-         ++initIdx) {
+    for (auto [initIdx, result, partialMap] :
+         llvm::enumerate(linalgOp->getResults(), partialResultMaps)) {
       SmallVector<Operation *, 4> combinerOps;
       if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                           combinerOps) ||
@@ -392,16 +421,13 @@ struct LinalgOpPartialReductionInterface
             "Failed to get an identity value for the reduction operation.");
 
       // Append the new partial result dimensions.
-      AffineMap partialMap =
-          getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
       SmallVector<OpFoldResult> partialResultShape;
       for (AffineExpr dimExpr : partialMap.getResults()) {
         auto dim = cast<AffineDimExpr>(dimExpr);
         partialResultShape.push_back(tiledShape[dim.getPosition()]);
       }
 
-      Type elType =
-          getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
+      Type elType = getElementTypeOrSelf(result.getType());
       Value emptyTensor =
           b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
       Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
@@ -415,23 +441,25 @@ struct LinalgOpPartialReductionInterface
 
   FailureOr<TilingResult>
   tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
+                         ReductionTilingStrategy tilingStrategy,
                          ValueRange init, ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes,
-                         ArrayRef<int> reductionDims) const {
+                         const SetVector<unsigned> &reductionDims) const {
+    if (tilingStrategy !=
+        ReductionTilingStrategy::PartialReductionOuterReduction) {
+      // TODO: Add support for `PartialReductionOuterParallel` strategy.
+      return op->emitOpError("unsupported partial reduction tiling with "
+                             "`PartialReductionOuterParallel` strategy");
+    }
     OpBuilder::InsertionGuard guard(b);
     auto linalgOp = cast<LinalgOp>(op);
 
+    SmallVector<AffineMap> partialReductionMaps =
+        getPartialResultAffineMaps(linalgOp, reductionDims);
+
     // Step 1. Extend init maps to have reduction dimension dims, since we
     // are converting them to parallel dimensions.
-    SmallVector<AffineMap> newInitMaps;
-    newInitMaps.reserve(linalgOp.getNumDpsInits());
-    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
-      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
-      // this with a for range loop when we have it.
-      AffineMap newMap =
-          getPartialResultAffineMap(linalgOp, reductionDims, idx);
-      newInitMaps.push_back(newMap);
-    }
+    SmallVector<AffineMap> newInitMaps = partialReductionMaps;
 
     // Step 2a: Extract a slice of the input operands.
     SmallVector<Value> tiledInputs = makeTiledShapes(
@@ -443,31 +471,21 @@ struct LinalgOpPartialReductionInterface
 
     // Step 2b: Extract a slice of the init operands.
     SmallVector<Value, 1> tiledInits;
-    for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
-      int64_t initRank = valueMap.getNumResults();
-      SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0));
-      SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1));
-      SmallVector<OpFoldResult> initSizes;
-      for (AffineExpr dimExpr : valueMap.getResults()) {
-        auto dim = cast<AffineDimExpr>(dimExpr);
-        initSizes.push_back(sizes[dim.getPosition()]);
-      }
-      // TODO: Use SubsetExtractOpInterface here once available.
-      auto extractSlice = b.create<tensor::ExtractSliceOp>(
-          loc, valueToTile, initOffset, initSizes, initStride);
-      tiledInits.push_back(extractSlice);
-      generatedSlices.push_back(extractSlice);
+    for (auto [partialReductionMap, valueToTile] :
+         llvm::zip_equal(partialReductionMaps, init)) {
+      Operation *sliceOp =
+          getInitSliceForOuterReduction(b, loc, valueToTile, offsets, sizes,
+                                        reductionDims, partialReductionMap);
+      tiledInits.push_back(sliceOp->getResult(0));
+      generatedSlices.push_back(sliceOp);
     }
 
     // Update the indexing maps.
     SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
-    // Change the init maps.
-    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
-      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
-      // this with a for range loop when we have it.
-      OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
-      int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
-      newMaps[mapIdx] = newInitMaps[idx];
+    for (auto [initOperand, newInitMap] :
+         llvm::zip_equal(linalgOp.getDpsInitsMutable(), newInitMaps)) {
+      int mapIdx = linalgOp.getIndexingMapIndex(&initOperand);
+      newMaps[mapIdx] = newInitMap;
     }
 
     // Step 3. Change the reduction dim iterator types.
@@ -477,9 +495,9 @@ struct LinalgOpPartialReductionInterface
       newIteratorTypes[dim] = utils::IteratorType::parallel;
 
     // Step 4. Create the new generic op.
-    auto genericOp =
-        b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
-                            tiledInits, newMaps, newIteratorTypes);
+    auto resultTypes = ValueRange(tiledInits).getTypes();
+    auto genericOp = b.create<GenericOp>(loc, resultTypes, tiledInputs,
+                                         tiledInits, newMaps, newIteratorTypes);
     IRMapping mapping;
     op->getRegion(0).cloneInto(&genericOp.getRegion(),
                                genericOp.getRegion().begin(), mapping);
@@ -490,23 +508,24 @@ struct LinalgOpPartialReductionInterface
         generatedSlices};
   }
 
-  FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
-                                         Location loc, ValueRange partialReduce,
-                                         ArrayRef<int> reductionDims) const {
+  FailureOr<MergeResult>
+  mergeReductions(Operation *op, OpBuilder &b, Location loc,
+                  ValueRange partialReduce,
+                  const SetVector<unsigned> &reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
+    SmallVector<AffineMap> partialReductionMaps =
+        getPartialResultAffineMaps(linalgOp, reductionDims);
 
     // Permute the reduction dims as permuted by the partial result map.
-
-    int64_t numInits = linalgOp.getNumDpsInits();
     SmallVector<Operation *> mergeOperations;
     SmallVector<Value> replacements;
-    for (int idx : llvm::seq(numInits)) {
+    for (auto [idx, init, partialResult, partialMap] : llvm::enumerate(
+             linalgOp.getDpsInits(), partialReduce, partialReductionMaps)) {
+      unsigned initIdx = idx;
       // linalg.reduce's iteration space is the tiled result's iteration space
       // (and not the tiled operation's iteration space). To account for this,
       // permute the reduction dimensions based on the partial result map of the
       // tiled result.
-      AffineMap partialMap =
-          getPartialResultAffineMap(linalgOp, reductionDims, idx);
       SmallVector<int64_t> partialReductionDims;
       for (auto [resultNum, dimExpr] :
            llvm::enumerate(partialMap.getResults())) {
@@ -516,15 +535,13 @@ struct LinalgOpPartialReductionInterface
         }
       }
 
-      Value partialResult = partialReduce[idx];
-      Value init = linalgOp.getDpsInits()[idx];
-
       auto reduction = b.create<linalg::ReduceOp>(
           loc, partialResult, init, partialReductionDims,
-          [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
+          [&linalgOp, &initIdx](OpBuilder &b, Location loc, ValueRange inputs) {
             // Get the combiner op.
             SmallVector<Operation *, 4> combinerOps;
-            matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
+            matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
+                           combinerOps);
             Operation *clonedReductionOp = b.clone(*combinerOps[0]);
             // Combine the input at idx and output at numInits + idx.
             clonedReductionOp->setOperand(0, inputs[0]);
@@ -542,14 +559,14 @@ struct LinalgOpPartialReductionInterface
   LogicalResult getPartialResultTilePosition(
       Operation *op, OpBuilder &b, unsigned resultNumber,
       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      const SetVector<unsigned> &reductionDims,
       SmallVector<OpFoldResult> &resultOffsets,
-      SmallVector<OpFoldResult> &resultSizes,
-      ArrayRef<int> reductionDims) const {
+      SmallVector<OpFoldResult> &resultSizes) const {
     auto linalgOp = cast<LinalgOp>(op);
+    SmallVector<AffineMap> partialReductionMaps =
+        getPartialResultAffineMaps(linalgOp, reductionDims);
 
-    AffineMap partialMap =
-        getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
-    for (AffineExpr dimExpr : partialMap.getResults()) {
+    for (AffineExpr dimExpr : partialReductionMaps[resultNumber].getResults()) {
       unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
       resultSizes.push_back(sizes[dim]);
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 3f29dd3ac5e48..e7c076024e67b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -77,9 +77,8 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
 //===----------------------------------------------------------------------===//
 
 /// Verify the tile size options are set in a consistent manner.
-static LogicalResult
-verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
-                      const scf::SCFTilingOptions &options) {
+static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc,
+                                   const scf::SCFTilingOptions &options) {
   // Specifying number of threads is only supported on `scf.forall` op.
   if (options.numThreadsComputationFunction &&
       options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
@@ -156,7 +155,9 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
 }
 
 /// Checks if any of the tiled loops are not parallel.
-static void checkSafeToTileToForall(TilingInterface op,
+static LogicalResult checkTileSizes(TilingInterface op,
+                                    scf::SCFTilingOptions::LoopType loopType,
+                                    ReductionTilingStrategy reductionStrategy,
                                     ArrayRef<OpFoldResult> tileSizes,
                                     ArrayRef<OpFoldResult> numThreads) {
   auto iterators = op.getLoopIteratorTypes();
@@ -165,28 +166,46 @@ static void checkSafeToTileToForall(TilingInterface op,
   assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
          "when specified, expected number of threads to use for each loop");
 
+  bool isParallelTiling = false, isReductionTiling = false;
   for (auto [index, iterator, tileSize] :
        llvm::enumerate(iterators, tileSizes)) {
-    // If num threads is specified, check that it is greater than one only for
-    // parallel dimensions.
-    if (!numThreads.empty()) {
-      if (std::optional<int64_t> constNumThreads =
-              getConstantIntValue(numThreads[index])) {
-        if (constNumThreads.value() > 1 &&
+    if (!isConstantIntValue(tileSize, 0)) {
+      isParallelTiling |= iterator == utils::IteratorType::parallel;
+      isReductionTiling |= iterator == utils::IteratorType::reduction;
+    }
+
+    if (loopType == scf::SCFTilingOptions::LoopType::ForallOp &&
+        reductionStrategy == ReductionTilingStrategy::FullReduction) {
+      // If num threads is specified, check that it is greater than one only for
+      // parallel dimensions.
+      if (!numThreads.empty()) {
+        if (std::optional<int64_t> constNumThreads =
+                getConstantIntValue(numThreads[index])) {
+          if (constNumThreads.value() > 1 &&
+              iterator != utils::IteratorType::parallel) {
+            op.emitWarning() << "tiling is not thread safe at axis #" << index;
+          }
+        }
+        continue;
+      }
+
+      if (std::optional<int64_t> constTileSize =
+              getConstantIntValue(tileSize)) {
+        if (constTileSize.value() > 0 &&
             iterator != utils::IteratorType::parallel) {
           op.emitWarning() << "tiling is not thread safe at axis #" << index;
         }
       }
-      continue;
     }
+  }
 
-    if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
-      if (constTileSize.value() > 0 &&
-          iterator != utils::IteratorType::parallel) {
-        op.emitWarning() << "tiling is not thread safe at axis #" << index;
-      }
-    }
+  if (isParallelTiling && isReductionTiling &&
+      reductionStrategy != ReductionTilingStrategy::FullReduction) {
+    return op->emitOpError(
+        "combined parallel and reduction tiling is not supported with partial "
+        "reduction tiling strategies");
   }
+  return success();
 }
 
 /// Check if `stride` evenly divides the trip count `size - offset`.
@@ -575,35 +594,20 @@ createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
                               const scf::SCFTilingOptions &options) {
   SmallVector<Value> initTensors;
   Location loc = op->getLoc();
-  switch (options.reductionStrategy) {
-  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
     if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
       return failure();
     return initTensors;
-  case scf::SCFTilingOptions::ReductionTilingStrategy::
-      PartialReductionOuterReduction: {
-    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
-    if (!redOp) {
-      return rewriter.notifyMatchFailure(
-          op, "PartialReductionOuterReduction tiling strategy is only supported"
-              "for operations implementing PartialReductionOpInterface");
-    }
-    // Get reduction dimensions.
-    // TODO: PartialReductionOpInterface should really query TilingInterface
-    // itself and find reduction dimensions.
-    SmallVector<int> reductionDims;
-    for (auto [idx, iteratorType] :
-         llvm::enumerate(op.getLoopIteratorTypes())) {
-      if (iteratorType == utils::IteratorType::reduction)
-        reductionDims.push_back(idx);
-    }
-    return redOp.generateInitialTensorForPartialReduction(
-        rewriter, loc, tileSizes, reductionDims);
   }
-  default:
-    return rewriter.notifyMatchFailure(op,
-                                       "unhandled reduction tiling strategy");
+
+  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!redOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only supported "
+            "for operations implementing PartialReductionOpInterface");
   }
+  return redOp.generateInitialTensorForPartialReduction(
+      rewriter, loc, tileSizes, options.reductionDims);
 }
 
 static FailureOr<TilingResult>
@@ -611,34 +615,20 @@ getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
                        ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        const scf::SCFTilingOptions &options) {
-  switch (options.reductionStrategy) {
-  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
     return op.getTiledImplementation(rewriter, offsets, sizes);
-  case scf::SCFTilingOptions::ReductionTilingStrategy::
-      PartialReductionOuterReduction: {
-    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
-    if (!redOp) {
-      return rewriter.notifyMatchFailure(
-          op, "PartialReductionOuterReduction tiling strategy is only "
-              "supported for operations "
-              "implementing PartialReductionOpInterface");
-    }
-    // Get reduction dimensions.
-    // TODO: PartialReductionOpInterface should really query TilingInterface
-    // itself and find reduction dimensions.
-    SmallVector<int> reductionDims;
-    for (auto [idx, iteratorType] :
-         llvm::enumerate(op.getLoopIteratorTypes())) {
-      if (iteratorType == utils::IteratorType::reduction)
-        reductionDims.push_back(idx);
-    }
-    return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
-                                        offsets, sizes, reductionDims);
   }
-  default:
-    return rewriter.notifyMatchFailure(op,
-                                       "unhandled reduction tiling strategy");
+
+  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!redOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only "
+            "supported for operations "
+            "implementing PartialReductionOpInterface");
   }
+  return redOp.tileToPartialReduction(rewriter, op.getLoc(),
+                                      options.reductionStrategy, regionIterArg,
+                                      offsets, sizes, options.reductionDims);
 }
 
 static LogicalResult
@@ -649,70 +639,37 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
                       SmallVector<OpFoldResult> &resultSize,
                       const scf::SCFTilingOptions &options) {
 
-  switch (options.reductionStrategy) {
-  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
     return op.getResultTilePosition(rewriter, index, offsets, sizes,
                                     resultOffset, resultSize);
-  case scf::SCFTilingOptions::ReductionTilingStrategy::
-      PartialReductionOuterReduction: {
-    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
-    if (!redOp) {
-      return rewriter.notifyMatchFailure(
-          op, "PartialReductionOuterReduction tiling strategy is only supported"
-              "for operations implementing PartialReductionOpInterface");
-    }
-    // Get reduction dimensions.
-    // TODO: PartialReductionOpInterface should really query TilingInterface
-    // itself and find reduction dimensions.
-    SmallVector<int> reductionDims;
-    for (auto [idx, iteratorType] :
-         llvm::enumerate(op.getLoopIteratorTypes())) {
-      if (iteratorType == utils::IteratorType::reduction)
-        reductionDims.push_back(idx);
-    }
-    return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
-                                              resultOffset, resultSize,
-                                              reductionDims);
   }
-  default:
-    return rewriter.notifyMatchFailure(op,
-                                       "unhandled reduction tiling strategy");
+  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!redOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only supported "
+            "for operations implementing PartialReductionOpInterface");
   }
+  return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
+                                            options.reductionDims, resultOffset,
+                                            resultSize);
 }
 
 static FailureOr<MergeResult>
 mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
                    ValueRange partialResults,
                    const scf::SCFTilingOptions &options) {
-  switch (options.reductionStrategy) {
-  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
-    // No need to merge results for reduction tiling strategy.
-    return MergeResult{{}, partialResults};
-  case scf::SCFTilingOptions::ReductionTilingStrategy::
-      PartialReductionOuterReduction: {
-    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
-    if (!redOp) {
-      return rewriter.notifyMatchFailure(
-          op, "PartialReductionOuterReduction tiling strategy is only "
-              "supported for operations "
-              "implementing PartialReductionOpInterface");
-    }
-    // Get reduction dimensions.
-    // TODO: PartialReductionOpInterface should really query TilingInterface
-    // itself and find reduction dimensions.
-    SmallVector<int> reductionDims;
-    for (auto [idx, iteratorType] :
-         llvm::enumerate(op.getLoopIteratorTypes())) {
-      if (iteratorType == utils::IteratorType::reduction)
-        reductionDims.push_back(idx);
-    }
-    return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
-                                 reductionDims);
-  }
-  default:
-    return rewriter.notifyMatchFailure(op,
-                                       "unhandled reduction tiling strategy");
+  assert(options.reductionStrategy != ReductionTilingStrategy::FullReduction &&
+         "expected merge to be called for only partial reduction cases");
+
+  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!redOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only "
+            "supported for operations "
+            "implementing PartialReductionOpInterface");
   }
+  return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
+                               options.reductionDims);
 }
 
 /// Append the specified additional `newInitOperands` operands to the
@@ -932,7 +889,7 @@ static LogicalResult addInitOperandsToLoopNest(
 FailureOr<scf::SCFTilingResult>
 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
                         const scf::SCFTilingOptions &options) {
-  if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
+  if (failed(verifyOptions(rewriter, op.getLoc(), options))) {
     return failure();
   }
 
@@ -949,8 +906,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
 
   // Check if it is safe to tile. This is hold over from previous iterations
   // of tile to for-all. Consider dropping it.
-  if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
-    checkSafeToTileToForall(op, tileSizes, numThreads);
+  if (failed(checkTileSizes(op, options.loopType, options.reductionStrategy,
+                            tileSizes, numThreads))) {
+    return failure();
   }
 
   // 3. If there is an interchange specified, permute the iteration domain and
@@ -1073,8 +1031,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
                                          [](OpResult r) -> Value { return r; });
 
   // For the full reduction case, there is nothing more to do.
-  if (options.reductionStrategy ==
-      scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction) {
+  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
     return scf::SCFTilingResult{
         tilingResult->tiledOps,        initTensors, loops, loopResults,
         tilingResult->generatedSlices, {}};
@@ -1102,9 +1059,13 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   scf::SCFTilingOptions options;
   options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
   options.setReductionTilingStrategy(
-      scf::SCFTilingOptions::ReductionTilingStrategy::
-          PartialReductionOuterReduction);
+      ReductionTilingStrategy::PartialReductionOuterReduction);
   options.setTileSizes(tileSize);
+  SmallVector<unsigned> reductionDims;
+  for (auto [index, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes()))
+    if (iteratorType == utils::IteratorType::reduction)
+      reductionDims.push_back(index);
+  options.setReductionDims(reductionDims);
   return tileUsingSCF(b, op, options);
 }
 
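Taken together, `checkTileSizes` now does two jobs. It rejects tiling that touches both parallel and reduction loops under a partial-reduction strategy, and it keeps the old thread-safety warnings only for `scf.forall` with `FullReduction`.

The standalone C++ sketch below models that decision logic for constant sizes only. It rests on these assumptions:

- `IteratorType`, `Strategy`, `LoopType`, `CheckResult` and `checkTileSizesModel` are hypothetical stand-ins, not MLIR API.
- A tile size of 0 means the loop is not tiled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-ins for utils::IteratorType, ReductionTilingStrategy and
// scf::SCFTilingOptions::LoopType.
enum class IteratorType { Parallel, Reduction };
enum class Strategy {
  FullReduction,
  PartialReductionOuterReduction,
  PartialReductionOuterParallel
};
enum class LoopType { ForOp, ForallOp };

struct CheckResult {
  bool ok = true;                    // false where the patch emits an op error
  std::vector<std::string> warnings; // where the patch emits op warnings
};

// Models `checkTileSizes` for constant values only: a tile size of 0 means
// "not tiled", and `numThreads` is either empty or has one entry per loop.
CheckResult checkTileSizesModel(const std::vector<IteratorType> &iterators,
                                LoopType loopType, Strategy strategy,
                                const std::vector<int64_t> &tileSizes,
                                const std::vector<int64_t> &numThreads) {
  assert(tileSizes.size() == iterators.size());
  assert(numThreads.empty() || numThreads.size() == iterators.size());
  CheckResult result;
  bool isParallelTiling = false, isReductionTiling = false;
  for (std::size_t i = 0; i < iterators.size(); ++i) {
    if (tileSizes[i] != 0) {
      isParallelTiling |= iterators[i] == IteratorType::Parallel;
      isReductionTiling |= iterators[i] == IteratorType::Reduction;
    }
    // Thread-safety warnings are only issued for scf.forall + FullReduction.
    if (loopType != LoopType::ForallOp || strategy != Strategy::FullReduction)
      continue;
    // With num_threads, more than one thread on a non-parallel loop is
    // flagged; otherwise any positive tile size on such a loop is.
    bool flagged = numThreads.empty() ? tileSizes[i] > 0 : numThreads[i] > 1;
    if (flagged && iterators[i] != IteratorType::Parallel)
      result.warnings.push_back("tiling is not thread safe at axis #" +
                                std::to_string(i));
  }
  // Partial-reduction strategies reject mixing parallel and reduction tiling.
  if (isParallelTiling && isReductionTiling &&
      strategy != Strategy::FullReduction)
    result.ok = false;
  return result;
}
```

With these definitions and a parallel/reduction loop pair:

- `{4, 8}` under `PartialReductionOuterParallel` is rejected.
- `{0, 8}` under the same strategy is accepted.
- `{0, 8}` under `FullReduction` on a forall is accepted, with a single warning for axis #1.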
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 9d34c80822d0e..009ab17786696 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -343,7 +343,6 @@ module attributes {transform.with_named_sequence} {
 module {
   func.func @fail_for_float_neutral(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
     // expected-error @below {{'linalg.generic' op Failed to get an identity value for the reduction operation.}}
-    // expected-note @below {{when applied to this op}}
     %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
     ^bb0(%in: f32, %out: f32):
       %1 = llvm.fmul %in, %in  : f32
@@ -355,7 +354,7 @@ module {
   module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
       %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-      // expected-error @below {{transform.structured.tile_reduction_using_for failed to apply}}
+      // expected-error @below {{failed to tile using partial reduction}}
       %fill_op, %split_linalg_op, %combining_linalg_op, %for_op = transform.structured.tile_reduction_using_for %0 by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
       transform.yield
     }
@@ -480,3 +479,167 @@ module attributes {transform.with_named_sequence} {
 //     CHECK:   }
 //     CHECK:   linalg.reduce
 //     CHECK:   return
+
+// -----
+
+// Check that a single one of several reduction dimensions can be tiled (here, the outer one).
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+  func.func @reduction_tile_single_of_multiple_reduction_outer(
+        %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+    %0 = linalg.generic {
+        indexing_maps = [#map, #map1, #map2],
+        iterator_types = ["parallel", "reduction", "reduction"]}
+        ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.mulf %in, %in_0 : f32
+      %2 = arith.addf %1, %out : f32
+      linalg.yield %2 : f32
+    } -> tensor<4096xf32>
+    return %0 : tensor<4096xf32>
+  }
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %fill_op, %split_linalg_op, %combining_linalg_op, %for_op =
+          transform.structured.tile_reduction_using_for %0 reduction_dims = [1] by tile_sizes = [0, 2]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+//      CHECK: #[[INIT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+//      CHECK: @reduction_tile_single_of_multiple_reduction_outer(
+// CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<4096xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C86:.+]] = arith.constant 86 : index
+//  CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<4096x2xf32>
+//      CHECK:   %[[FILL:.+]] = linalg.fill
+// CHECK-SAME:       outs(%[[EMPTY]] :
+//      CHECK:   %[[RESULT:.+]] = scf.for %[[IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[C86]] step %[[C2]]
+// CHECK-SAME:       iter_args(%[[ITER_ARG:.+]] = %[[FILL]])
+//      CHECK:     %[[PARTIAL_RESULT:.+]] = linalg.generic
+// CHECK-SAME:         indexing_maps = [#{{.+}}, #{{.+}}, #[[INIT_MAP]]]
+// CHECK-SAME:         iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME:         outs(%[[ITER_ARG]] :
+//      CHECK:     scf.yield %[[PARTIAL_RESULT]]
+//      CHECK:   %[[REDUCE:.+]] = linalg.reduce
+// CHECK-SAME:       ins(%[[RESULT]] :
+// CHECK-SAME:       outs(%[[INIT]] :
+// CHECK-SAME:       dimensions = [1]
+//      CHECK:   return %[[REDUCE]]
+
+// -----
+
+// Check that a single one of several reduction dimensions can be tiled (here, the inner one).
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+  func.func @reduction_tile_single_of_multiple_reduction_inner(
+        %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+    %0 = linalg.generic {
+        indexing_maps = [#map, #map1, #map2],
+        iterator_types = ["parallel", "reduction", "reduction"]}
+        ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.mulf %in, %in_0 : f32
+      %2 = arith.addf %1, %out : f32
+      linalg.yield %2 : f32
+    } -> tensor<4096xf32>
+    return %0 : tensor<4096xf32>
+  }
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %fill_op, %split_linalg_op, %combining_linalg_op, %for_op =
+          transform.structured.tile_reduction_using_for %0 reduction_dims = [2] by tile_sizes = [0, 0, 64]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+//      CHECK: #[[INIT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+//      CHECK: @reduction_tile_single_of_multiple_reduction_inner(
+// CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<4096xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C64:.+]] = arith.constant 64 : index
+//  CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
+//  CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<4096x64xf32>
+//      CHECK:   %[[FILL:.+]] = linalg.fill
+// CHECK-SAME:       outs(%[[EMPTY]] :
+//      CHECK:   %[[RESULT:.+]] = scf.for %[[IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C64]]
+// CHECK-SAME:       iter_args(%[[ITER_ARG:.+]] = %[[FILL]])
+//      CHECK:     %[[PARTIAL_RESULT:.+]] = linalg.generic
+// CHECK-SAME:         indexing_maps = [#{{.+}}, #{{.+}}, #[[INIT_MAP]]]
+// CHECK-SAME:         iterator_types = ["parallel", "reduction", "parallel"]
+// CHECK-SAME:         outs(%[[ITER_ARG]] :
+//      CHECK:     scf.yield %[[PARTIAL_RESULT]]
+//      CHECK:   %[[REDUCE:.+]] = linalg.reduce
+// CHECK-SAME:       ins(%[[RESULT]] :
+// CHECK-SAME:       outs(%[[INIT]] :
+// CHECK-SAME:       dimensions = [1]
+//      CHECK:   return %[[REDUCE]]
+
+// -----
+
+// Check that both reduction dimensions can be tiled, with their order swapped in the partial result.
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+  func.func @reduction_tile_single_of_multiple_reduction_reversed(
+        %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+    %0 = linalg.generic {
+        indexing_maps = [#map, #map1, #map2],
+        iterator_types = ["parallel", "reduction", "reduction"]}
+        ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.mulf %in, %in_0 : f32
+      %2 = arith.addf %1, %out : f32
+      linalg.yield %2 : f32
+    } -> tensor<4096xf32>
+    return %0 : tensor<4096xf32>
+  }
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %fill_op, %split_linalg_op, %combining_linalg_op, %for_op =
+          transform.structured.tile_reduction_using_for %0 reduction_dims = [2, 1] by tile_sizes = [0, 2, 64]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+//      CHECK: #[[INIT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+//      CHECK: @reduction_tile_single_of_multiple_reduction_reversed(
+// CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<4096xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C64:.+]] = arith.constant 64 : index
+//  CHECK-DAG:   %[[C86:.+]] = arith.constant 86 : index
+//  CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
+//  CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<4096x64x2xf32>
+//      CHECK:   %[[FILL:.+]] = linalg.fill
+// CHECK-SAME:       outs(%[[EMPTY]] :
+//      CHECK:   %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C86]] step %[[C2]]
+// CHECK-SAME:       iter_args(%[[ITER_ARG:.+]] = %[[FILL]])
+//      CHECK:     %[[RESULT0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C64]]
+// CHECK-SAME:         iter_args(%[[ITER_ARG0:.+]] = %[[ITER_ARG]])
+//      CHECK:       %[[PARTIAL_RESULT:.+]] = linalg.generic
+// CHECK-SAME:           indexing_maps = [#{{.+}}, #{{.+}}, #[[INIT_MAP]]]
+// CHECK-SAME:           iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME:           outs(%[[ITER_ARG0]] :
+//      CHECK:       scf.yield %[[PARTIAL_RESULT]]
+//      CHECK:     scf.yield %[[RESULT0]]
+//      CHECK:   %[[REDUCE:.+]] = linalg.reduce
+// CHECK-SAME:       ins(%[[RESULT]] :
+// CHECK-SAME:       outs(%[[INIT]] :
+// CHECK-SAME:       dimensions = [1, 2]
+//      CHECK: return %[[REDUCE]]

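The three new tests pin down the layout of the intermediate partial-result tensor. It keeps the output's own dimensions and appends the tiled reduction dimensions in the order given by `reduction_dims`, each sized by its tile size. For example, `reduction_dims = [2, 1]` with `tile_sizes = [0, 2, 64]` yields `tensor<4096x64x2xf32>` with the map `(d0, d1, d2) -> (d0, d2, d1)`.

A standalone C++ sketch of that computation follows. It makes these assumptions:

- The op has a single result.
- Its parallel dimensions are left untiled, which `checkTileSizes` enforces for partial strategies.
- `PartialLayout`, `partialResultLayout` and `renderResults` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical description of the partial-result tensor.
struct PartialLayout {
  std::vector<unsigned> mapDims; // loop dims used by the partial-result map
  std::vector<int64_t> shape;    // static shape of the partial-result tensor
};

// outputDims:    loop dims referenced by the original output indexing map.
// loopBounds:    static loop ranges of the op.
// reductionDims: reduction dims to tile, in the order the user gave them.
// tileSizes:     per-loop tile sizes (0 = untiled).
PartialLayout partialResultLayout(const std::vector<unsigned> &outputDims,
                                  const std::vector<int64_t> &loopBounds,
                                  const std::vector<unsigned> &reductionDims,
                                  const std::vector<int64_t> &tileSizes) {
  PartialLayout layout;
  // Untiled output dims keep their full extent.
  for (unsigned d : outputDims) {
    layout.mapDims.push_back(d);
    layout.shape.push_back(loopBounds[d]);
  }
  // Tiled reduction dims are appended in the given order, sized by tile size.
  for (unsigned d : reductionDims) {
    assert(tileSizes[d] > 0 && "reduction dim is expected to be tiled");
    layout.mapDims.push_back(d);
    layout.shape.push_back(tileSizes[d]);
  }
  return layout;
}

// Renders map results the way the CHECK lines spell them, e.g. "(d0, d2, d1)".
std::string renderResults(const std::vector<unsigned> &dims) {
  std::string s = "(";
  for (std::size_t i = 0; i < dims.size(); ++i) {
    if (i != 0)
      s += ", ";
    s += "d" + std::to_string(dims[i]);
  }
  return s + ")";
}
```

Feeding it the loop bounds `{4096, 86, 128}` and output dims `{0}` reproduces the `INIT_MAP` and `tensor.empty` shapes checked by all three tests.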
From b03fcbe0bf7478aa9854554e1e2a42e5cce6a8a6 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Sun, 8 Jun 2025 16:22:00 -0700
Subject: [PATCH 2/2] [mlir][PartialReductionTilingInterface] Add support for
 `ReductionTilingStrategy::PartialReductionOuterParallel` in `tileUsingSCF`.

Following up on https://github.com/llvm/llvm-project/pull/143467,
this PR adds support for
`ReductionTilingStrategy::PartialReductionOuterParallel` to
`tileUsingSCF`. The implementation of
`PartialReductionTilingInterface` for `Linalg` ops has been updated to
support this strategy as well. This brings `tileUsingSCF` on par with
`linalg::tileReductionUsingForall`, which will be deprecated
subsequently.

Changes summary
- `PartialReductionTilingInterface` changes:
  - The `tileToPartialReduction` method now needs the induction
    variables of the generated tile loops. This keeps the generated
    code similar to that of `linalg::tileReductionUsingForall`,
    specifically by allowing a simplified access when slicing the
    intermediate partial-results tensor in `num_threads` mode.
  - The `getPartialResultTilePosition` method needs the induction
    variables of the generated tile loops for the same reason. It
    also needs the `tilingStrategy` to be passed in to generate
    correct code.
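For `PartialReductionOuterReduction`, the visible part of `getInitSliceInfoForOuterReduction` (in `TilingInterfaceImpl.cpp` below) slices the partial-result tensor as follows:

- along every reduction dimension, at offset 0;
- along every other dimension, at the tile offset;
- throughout, with the tile size and unit strides.

Why the induction variables help in `num_threads` mode can be sketched under an assumption this excerpt does not spell out. Suppose the partial-result tensor holds one slot per thread along a tiled reduction dimension. A thread can then index its slot with its induction variable directly, instead of recovering it as `offset / tileSize`.

`Slice`, `outerReductionSlice`, `slotFromIv` and `slotFromOffset` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model of a tensor slice: offsets, sizes and strides.
struct Slice {
  std::vector<int64_t> offsets, sizes, strides;
};

// Outer-reduction strategy, following the visible logic of
// `getInitSliceInfoForOuterReduction`: for each loop dim referenced by the
// partial-result map, reduction dims start at 0 and all other dims start at
// the tile offset; every dim takes the tile size with a unit stride.
Slice outerReductionSlice(const std::vector<unsigned> &partialMapDims,
                          const std::vector<unsigned> &reductionDims,
                          const std::vector<int64_t> &offsets,
                          const std::vector<int64_t> &sizes) {
  Slice slice;
  for (unsigned dim : partialMapDims) {
    bool isReduction = std::find(reductionDims.begin(), reductionDims.end(),
                                 dim) != reductionDims.end();
    slice.offsets.push_back(isReduction ? 0 : offsets[dim]);
    slice.sizes.push_back(sizes[dim]);
    slice.strides.push_back(1);
  }
  return slice;
}

// Assumed num_threads layout (not taken from the patch): one partial-result
// slot per thread along a tiled reduction dim, where thread `iv` owns the
// tile starting at `iv * tileSize`.
int64_t slotFromIv(int64_t iv) { return iv; } // direct access via the IV

int64_t slotFromOffset(int64_t offset, int64_t tileSize) {
  // Without the IV, the slot has to be recomputed from the tile offset.
  assert(tileSize > 0 && offset % tileSize == 0);
  return offset / tileSize;
}
```

Both slot computations agree whenever `offset == iv * tileSize`. Passing the IV simply avoids materializing the division in the generated IR.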

The tests in `transform-tile-reduction.mlir` that exercised
`linalg::tileReductionUsingForall` have been moved over to test
`scf::tileUsingSCF` with the
`ReductionTilingStrategy::PartialReductionOuterParallel`
strategy. Some of the tests that additionally performed a cyclic
distribution of the tiled code have been removed. Those combined two
separate transformations into one. Ideally, that distribution should
happen when resolving the `scf.forall` rather than during tiling.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../mlir/Dialect/Utils/StaticValueUtils.h     |   2 +-
 .../mlir/Interfaces/TilingInterface.td        |   3 +
 .../TransformOps/LinalgTransformOps.cpp       |  32 +++-
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 179 +++++++++++++-----
 .../SCF/Transforms/TileUsingInterface.cpp     |  81 ++++----
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  14 +-
 mlir/lib/Dialect/Utils/StaticValueUtils.cpp   |   2 +-
 .../Linalg/transform-tile-reduction.mlir      | 125 +-----------
 8 files changed, 217 insertions(+), 221 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index b37fb55b67931..77c376fb9973a 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -156,7 +156,7 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
 /// corresponding pair of arrays. This is the inverse function of
 /// `getMixedValues`.
 std::pair<SmallVector<int64_t>, SmallVector<Value>>
-decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues);
+decomposeMixedValues(ArrayRef<OpFoldResult> mixedValues);
 
 /// Helper to sort `values` according to matching `keys`.
 SmallVector<Value>
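Switching the parameter from `const SmallVectorImpl<OpFoldResult> &` to `ArrayRef<OpFoldResult>` lets callers pass any contiguous sequence of `OpFoldResult`s without building a `SmallVector` first. An `ArrayRef` they already hold is one example.

For reference, the function splits mixed values into two lists:

- a static list, in which dynamic entries become a sentinel (`ShapedType::kDynamic` in MLIR);
- a list of the dynamic values.

Below is a standalone C++17 model of that behavior. The hypothetical `DynamicValue`, `MixedValue` and `decomposeMixedValuesModel` stand in for `Value`, `OpFoldResult` and the real function.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Stand-in for an SSA `Value` (hypothetical: just a name).
struct DynamicValue {
  std::string name;
};
// Stand-in for `OpFoldResult`: either a constant or a dynamic value.
using MixedValue = std::variant<int64_t, DynamicValue>;
// Stand-in for ShapedType::kDynamic.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Static entries are kept in place; each dynamic entry leaves a sentinel in
// the static list and is appended, in order, to the dynamic list.
std::pair<std::vector<int64_t>, std::vector<DynamicValue>>
decomposeMixedValuesModel(const std::vector<MixedValue> &mixedValues) {
  std::vector<int64_t> staticValues;
  std::vector<DynamicValue> dynamicValues;
  for (const MixedValue &v : mixedValues) {
    if (const int64_t *constant = std::get_if<int64_t>(&v)) {
      staticValues.push_back(*constant);
    } else {
      staticValues.push_back(kDynamic);
      dynamicValues.push_back(std::get<DynamicValue>(v));
    }
  }
  return {staticValues, dynamicValues};
}
```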
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 9358d8b82abce..2d50a454710c2 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -404,6 +404,7 @@ def PartialReductionOpInterface :
             "Location ":$loc,
             "::mlir::ReductionTilingStrategy":$tilingStrategy,
             "ValueRange":$init,
+            "ValueRange":$ivs,
             "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
             "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
             "const ::llvm::SetVector<unsigned> &":$reductionDims),
@@ -442,6 +443,8 @@ def PartialReductionOpInterface :
         /*args=*/(ins
             "::mlir::OpBuilder &":$b,
             "unsigned":$resultNumber,
+            "ValueRange":$ivs,
+            "ReductionTilingStrategy":$tilingStrategy,
             "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
             "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
             "const ::mlir::SetVector<unsigned> &":$reductionDims,
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c003825264920..1f298185750dc 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2864,23 +2864,41 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
   SmallVector<OpFoldResult> tileSizes =
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
-  FailureOr<linalg::ForallReductionTilingResult> result =
-      linalg::tileReductionUsingForall(
-          rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
-          numThreads, tileSizes, getMapping());
+
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
+  options.setReductionTilingStrategy(
+      ReductionTilingStrategy::PartialReductionOuterParallel);
+  if (!getNumThreads().empty()) {
+    options.setNumThreads(numThreads);
+  } else {
+    options.setTileSizes(tileSizes);
+  }
+  if (auto mapping = getMapping()) {
+    options.setMapping(mapping.value().getValue());
+  }
+  SmallVector<unsigned> reductionDims;
+  for (auto [idx, iteratorType] :
+       llvm::enumerate(target.getIteratorTypesArray()))
+    if (iteratorType == utils::IteratorType::reduction)
+      reductionDims.push_back(idx);
+  options.setReductionDims(reductionDims);
+  FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
+      rewriter, cast<TilingInterface>(target.getOperation()), options);
 
   if (failed(result)) {
     auto diag = emitSilenceableError() << "could not tile reduction";
-    diag.attachNote(target.getLoc()) << "target operation";
     return diag;
   }
+  rewriter.replaceOp(target, result->replacements);
+
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
-  for (auto parallelTiledOp : result->parallelTiledOps)
+  for (auto parallelTiledOp : result->tiledOps)
     results.push_back(parallelTiledOp);
   for (auto mergeOp : result->mergeOps)
     results.push_back(mergeOp);
-  results.push_back(result->loops);
+  results.push_back(result->loops.front());
   return DiagnosedSilenceableFailure::success();
 }
 
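`TileReductionUsingForallOp::applyToOne` now builds `scf::SCFTilingOptions` and calls `scf::tileUsingSCF`. In doing so it:

- uses `num_threads` when that list is non-empty, and `tile_sizes` otherwise;
- forwards the optional mapping;
- marks every `reduction` iterator as a reduction dimension, the same loop that `tileReductionUsingScf` now uses.

The option-selection part can be modeled in standalone C++ as below. `TilingOptionsModel`, `IteratorType` and `buildOptions` are hypothetical stand-ins for the MLIR types.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for utils::IteratorType.
enum class IteratorType { Parallel, Reduction };

// Hypothetical subset of scf::SCFTilingOptions relevant here.
struct TilingOptionsModel {
  bool useNumThreads = false;
  std::vector<int64_t> numThreads;
  std::vector<int64_t> tileSizes;
  std::vector<unsigned> reductionDims;
};

TilingOptionsModel buildOptions(const std::vector<IteratorType> &iterators,
                                const std::vector<int64_t> &numThreads,
                                const std::vector<int64_t> &tileSizes) {
  TilingOptionsModel options;
  // num_threads takes precedence; tile sizes are used only when it is empty.
  if (!numThreads.empty()) {
    options.useNumThreads = true;
    options.numThreads = numThreads;
  } else {
    options.tileSizes = tileSizes;
  }
  // Every reduction iterator becomes a reduction dim, in loop order.
  for (unsigned i = 0; i < iterators.size(); ++i)
    if (iterators[i] == IteratorType::Reduction)
      options.reductionDims.push_back(i);
  return options;
}
```

Because all reduction iterators are collected, the `checkTileSizes` restriction means a forall-based partial reduction here tiles reduction loops only.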
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index f649bc49a8fbd..2dfe4448019b6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -328,6 +328,17 @@ struct LinalgOpTilingInterface
 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//
 
+/// Return the position of `value` in `reductionDims`, if it is present.
+static std::optional<int>
+getPositionIn(const llvm::SetVector<unsigned> &reductionDims, unsigned value) {
+  for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
+    if (reductionDim == value) {
+      return index;
+    }
+  }
+  return std::nullopt;
+}
+
 /// Return an AffineMaps to use for the `outs` operands of the linalg op
 /// generated for partial results. The new AffineMap is the AffineMap of the
 /// untiled op with reduction dimensions appended at end in order in which they
@@ -348,28 +359,79 @@ getPartialResultAffineMaps(LinalgOp linalgOp,
   return partialReductionMaps;
 }
 
-/// Return the slice of the `initValue` to use as input to the partial reduction
-/// op generated.
-static Operation *getInitSliceForOuterReduction(
-    OpBuilder &b, Location loc, Value initValue, ArrayRef<OpFoldResult> offsets,
+struct InitSliceInfo {
+  SmallVector<int64_t> resultShape;
+  SmallVector<OpFoldResult> offsets;
+  SmallVector<OpFoldResult> sizes;
+  SmallVector<OpFoldResult> strides;
+};
+
+/// Return the result shape, offsets, sizes and strides of the slice of the
+/// `initValue` to use as input to the partial reduction op generated with the
+/// outer reduction strategy.
+static InitSliceInfo getInitSliceInfoForOuterReduction(
+    MLIRContext *context, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
     AffineMap partialReductionMap) {
   int64_t initRank = partialReductionMap.getNumResults();
   SmallVector<OpFoldResult> initOffsets, initSizes;
-  SmallVector<OpFoldResult> initStrides(initRank, b.getIndexAttr(1));
+  Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
+  Attribute one = IntegerAttr::get(IndexType::get(context), 1);
+  SmallVector<OpFoldResult> initStrides(initRank, one);
   for (AffineExpr dimExpr : partialReductionMap.getResults()) {
     unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
     if (reductionDims.contains(dim)) {
-      initOffsets.push_back(b.getIndexAttr(0));
+      initOffsets.push_back(zero);
     } else {
       initOffsets.push_back(offsets[dim]);
     }
     initSizes.push_back(sizes[dim]);
   }
-  // TODO: Use SubsetExtractOpInterface here once available.
-  auto extractSlice = b.create<tensor::ExtractSliceOp>(
-      loc, initValue, initOffsets, initSizes, initStrides);
-  return extractSlice;
+  SmallVector<int64_t> resultShape;
+  std::tie(resultShape, std::ignore) = decomposeMixedValues(initSizes);
+  return {resultShape, initOffsets, initSizes, initStrides};
+}
+
+/// Return the result shape, offsets, sizes and strides of the rank-reduced
+/// slice of the `initValue` to use as input to the partial reduction op
+/// generated with the outer parallel strategy.
+static InitSliceInfo getInitSliceInfoForOuterParallel(
+    MLIRContext *context, ValueRange ivs, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
+    AffineMap partialReductionMap) {
+  int64_t initRank = partialReductionMap.getNumResults();
+  SmallVector<OpFoldResult> initOffsets, initSizes;
+  Attribute one = IntegerAttr::get(IndexType::get(context), 1);
+  SmallVector<OpFoldResult> initStrides(initRank, one);
+  SmallVector<OpFoldResult> resultShape;
+  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
+    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    if (std::optional<int> dimPos = getPositionIn(reductionDims, dim)) {
+      initOffsets.push_back(ivs[dimPos.value()]);
+      initSizes.push_back(one);
+    } else {
+      initOffsets.push_back(offsets[dim]);
+      initSizes.push_back(sizes[dim]);
+      resultShape.push_back(sizes[dim]);
+    }
+  }
+  SmallVector<int64_t> staticShapes;
+  std::tie(staticShapes, std::ignore) = decomposeMixedValues(resultShape);
+  return {staticShapes, initOffsets, initSizes, initStrides};
+}
+
+static InitSliceInfo getInitSliceInfo(
+    MLIRContext *context, ReductionTilingStrategy strategy, ValueRange ivs,
+    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+    const SetVector<unsigned> &reductionDims, AffineMap partialReductionMap) {
+  if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) {
+    return getInitSliceInfoForOuterReduction(
+        context, offsets, sizes, reductionDims, partialReductionMap);
+  }
+  assert(strategy == ReductionTilingStrategy::PartialReductionOuterParallel &&
+         "unexpected ReductionTilingStrategy");
+  return getInitSliceInfoForOuterParallel(context, ivs, offsets, sizes,
+                                          reductionDims, partialReductionMap);
 }
 
 /// External model implementation of PartialReductionInterface for
@@ -439,18 +501,11 @@ struct LinalgOpPartialReductionInterface
     return inits;
   }
 
-  FailureOr<TilingResult>
-  tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
-                         ReductionTilingStrategy tilingStrategy,
-                         ValueRange init, ArrayRef<OpFoldResult> offsets,
-                         ArrayRef<OpFoldResult> sizes,
-                         const SetVector<unsigned> &reductionDims) const {
-    if (tilingStrategy !=
-        ReductionTilingStrategy::PartialReductionOuterReduction) {
-      // TODO: Add support for `PartialReductionOuterParallel` strategy.
-      return op->emitOpError("unsupported partial reduction tiling with "
-                             "`PartialReductionOuterParallel` strategy");
-    }
+  FailureOr<TilingResult> tileToPartialReduction(
+      Operation *op, OpBuilder &b, Location loc,
+      ReductionTilingStrategy tilingStrategy, ValueRange init, ValueRange ivs,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      const SetVector<unsigned> &reductionDims) const {
     OpBuilder::InsertionGuard guard(b);
     auto linalgOp = cast<LinalgOp>(op);
 
@@ -459,7 +514,16 @@ struct LinalgOpPartialReductionInterface
 
     // Step 1. Extend init maps to have reduction dimension dims, since we
     // are converting them to parallel dimensions.
-    SmallVector<AffineMap> newInitMaps = partialReductionMaps;
+    SmallVector<AffineMap> newInitMaps;
+    if (tilingStrategy ==
+        ReductionTilingStrategy::PartialReductionOuterReduction) {
+      newInitMaps = llvm::to_vector(partialReductionMaps);
+    } else {
+      newInitMaps = llvm::map_to_vector(
+          linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
+            return linalgOp.getMatchingIndexingMap(&opOperand);
+          });
+    }
 
     // Step 2a: Extract a slice of the input operands.
     SmallVector<Value> tiledInputs = makeTiledShapes(
@@ -473,10 +537,17 @@ struct LinalgOpPartialReductionInterface
     SmallVector<Value, 1> tiledInits;
     for (auto [partialReductionMap, valueToTile] :
          llvm::zip_equal(partialReductionMaps, init)) {
-      Operation *sliceOp =
-          getInitSliceForOuterReduction(b, loc, valueToTile, offsets, sizes,
-                                        reductionDims, partialReductionMap);
-      tiledInits.push_back(sliceOp->getResult(0));
+      InitSliceInfo sliceInfo =
+          getInitSliceInfo(b.getContext(), tilingStrategy, ivs, offsets, sizes,
+                           reductionDims, partialReductionMap);
+      auto valueToTileType = cast<RankedTensorType>(valueToTile.getType());
+      RankedTensorType sliceResultType = RankedTensorType::get(
+          sliceInfo.resultShape, valueToTileType.getElementType(),
+          valueToTileType.getEncoding());
+      auto sliceOp = b.create<tensor::ExtractSliceOp>(
+          loc, sliceResultType, valueToTile, sliceInfo.offsets, sliceInfo.sizes,
+          sliceInfo.strides);
+      tiledInits.push_back(sliceOp.getResult());
       generatedSlices.push_back(sliceOp);
     }
 
@@ -491,19 +562,31 @@ struct LinalgOpPartialReductionInterface
     // Step 3. Change the reduction dim iterator types.
     SmallVector<utils::IteratorType> newIteratorTypes =
         linalgOp.getIteratorTypesArray();
-    for (int dim : reductionDims)
-      newIteratorTypes[dim] = utils::IteratorType::parallel;
+    if (tilingStrategy ==
+        ReductionTilingStrategy::PartialReductionOuterReduction) {
+      for (int dim : reductionDims)
+        newIteratorTypes[dim] = utils::IteratorType::parallel;
+    }
 
     // Step 4. Create the new generic op.
+    Operation *partialReductionOp;
     auto resultTypes = ValueRange(tiledInits).getTypes();
-    auto genericOp = b.create<GenericOp>(loc, resultTypes, tiledInputs,
-                                         tiledInits, newMaps, newIteratorTypes);
-    IRMapping mapping;
-    op->getRegion(0).cloneInto(&genericOp.getRegion(),
-                               genericOp.getRegion().begin(), mapping);
+    if (tilingStrategy ==
+        ReductionTilingStrategy::PartialReductionOuterReduction) {
+      auto genericOp = b.create<GenericOp>(
+          loc, resultTypes, tiledInputs, tiledInits, newMaps, newIteratorTypes);
+      IRMapping mapping;
+      op->getRegion(0).cloneInto(&genericOp.getRegion(),
+                                 genericOp.getRegion().begin(), mapping);
+      partialReductionOp = genericOp.getOperation();
+    } else {
+      SmallVector<Value> operands = std::move(tiledInputs);
+      llvm::append_range(operands, tiledInits);
+      partialReductionOp = mlir::clone(b, op, resultTypes, operands);
+    }
     return TilingResult{
-        {genericOp.getOperation()},
-        llvm::map_to_vector(genericOp->getResults(),
+        {partialReductionOp},
+        llvm::map_to_vector(partialReductionOp->getResults(),
                             [](OpResult r) -> Value { return r; }),
         generatedSlices};
   }
@@ -557,27 +640,19 @@ struct LinalgOpPartialReductionInterface
   }
 
   LogicalResult getPartialResultTilePosition(
-      Operation *op, OpBuilder &b, unsigned resultNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
-      const SetVector<unsigned> &reductionDims,
+      Operation *op, OpBuilder &b, unsigned resultNumber, ValueRange ivs,
+      ReductionTilingStrategy tilingStrategy, ArrayRef<OpFoldResult> offsets,
+      ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
       SmallVector<OpFoldResult> &resultOffsets,
       SmallVector<OpFoldResult> &resultSizes) const {
     auto linalgOp = cast<LinalgOp>(op);
     SmallVector<AffineMap> partialReductionMaps =
         getPartialResultAffineMaps(linalgOp, reductionDims);
-
-    for (AffineExpr dimExpr : partialReductionMaps[resultNumber].getResults()) {
-      unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
-      resultSizes.push_back(sizes[dim]);
-
-      if (llvm::is_contained(reductionDims, dim)) {
-        // Reduction dims are reduced, and are always outputed in the same
-        // place. So use offset 0 for them.
-        resultOffsets.push_back(b.getIndexAttr(0));
-      } else {
-        resultOffsets.push_back(offsets[dim]);
-      }
-    }
+    InitSliceInfo sliceInfo =
+        getInitSliceInfo(b.getContext(), tilingStrategy, ivs, offsets, sizes,
+                         reductionDims, partialReductionMaps[resultNumber]);
+    std::swap(resultOffsets, sliceInfo.offsets);
+    std::swap(resultSizes, sliceInfo.sizes);
 
     return success();
   }
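The two strategies slice the partial-result tensor differently. The following is a minimal plain-C++ model of `getInitSliceInfo` that shows the difference. It relies on these assumptions:

- A string-or-integer `FoldResult` stands in for `OpFoldResult`.
- A `kDynamic` sentinel stands in for `ShapedType::kDynamic`.
- `getInitSliceInfo`, `positionIn` and `staticOrDynamic` are hypothetical re-creations, not the MLIR functions themselves.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for mlir::OpFoldResult: a static integer or the name
// of a dynamic SSA value.
using FoldResult = std::variant<int64_t, std::string>;
// Stand-in for ShapedType::kDynamic.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

enum class Strategy { OuterReduction, OuterParallel };

struct SliceInfo {
  std::vector<int64_t> resultShape;
  std::vector<FoldResult> offsets, sizes, strides;
};

// Position of `value` in the ordered set `reductionDims`, if present.
std::optional<size_t> positionIn(const std::vector<unsigned> &reductionDims,
                                 unsigned value) {
  for (size_t i = 0; i < reductionDims.size(); ++i)
    if (reductionDims[i] == value)
      return i;
  return std::nullopt;
}

// Static part of a size, or kDynamic for a named (dynamic) value.
int64_t staticOrDynamic(const FoldResult &r) {
  if (const int64_t *v = std::get_if<int64_t>(&r))
    return *v;
  return kDynamic;
}

// `mapDims[i]` is the loop dimension read by result i of the partial
// reduction map; `ivs[k]` names the induction variable of the k-th
// reduction dim in `reductionDims`.
SliceInfo getInitSliceInfo(Strategy strategy,
                           const std::vector<unsigned> &mapDims,
                           const std::vector<std::string> &ivs,
                           const std::vector<FoldResult> &offsets,
                           const std::vector<FoldResult> &sizes,
                           const std::vector<unsigned> &reductionDims) {
  SliceInfo info;
  for (unsigned dim : mapDims) {
    info.strides.push_back(int64_t{1});
    std::optional<size_t> pos = positionIn(reductionDims, dim);
    if (strategy == Strategy::OuterReduction) {
      // Every dim is kept; reduction dims start at offset 0.
      info.offsets.push_back(pos ? FoldResult(int64_t{0}) : offsets[dim]);
      info.sizes.push_back(sizes[dim]);
      info.resultShape.push_back(staticOrDynamic(sizes[dim]));
    } else if (pos) {
      // Outer parallel: take this iteration's unit slice; the dim is
      // dropped from the (rank-reduced) result shape.
      info.offsets.push_back(ivs[*pos]);
      info.sizes.push_back(int64_t{1});
    } else {
      info.offsets.push_back(offsets[dim]);
      info.sizes.push_back(sizes[dim]);
      info.resultShape.push_back(staticOrDynamic(sizes[dim]));
    }
  }
  return info;
}
```

Consider the 2-D reduction used in the tests: `d0` is parallel, `d1` is the reduction dim, and the partial map is `(d0, d1)`.

- With `Strategy::OuterParallel`, the slice is `[0, %iv] [%d0, 1]` and its shape is the rank-reduced `[?]`. This matches the `tensor<?x5xf32> to tensor<?xf32>` extract in the updated CHECK lines.
- With `Strategy::OuterReduction`, the slice is `[0, 0] [%d0, 3]` with shape `[?, 3]`.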
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e7c076024e67b..54128fe72d6c4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -166,12 +166,11 @@ static LogicalResult checkTileSizes(TilingInterface op,
   assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
          "when specified, expected number of threads to use for each loop");
 
-  bool isParallelTiling = false, isReductionTiling = false;
+  bool isParallelTiling = false;
   for (auto [index, iterator, tileSize] :
        llvm::enumerate(iterators, tileSizes)) {
     if (!isConstantIntValue(tileSize, 0)) {
       isParallelTiling |= iterator == utils::IteratorType::parallel;
-      isReductionTiling |= iterator == utils::IteratorType::reduction;
     }
 
     if (loopType == scf::SCFTilingOptions::LoopType::ForallOp &&
@@ -199,11 +198,11 @@ static LogicalResult checkTileSizes(TilingInterface op,
     }
   }
 
-  if (isParallelTiling && isReductionTiling &&
-      reductionStrategy != ReductionTilingStrategy::FullReduction) {
-    return op->emitOpError(
-        "combined parallel and reduction tiling is not supported with partial "
-        "reduction tiling strategies");
+  if (reductionStrategy != ReductionTilingStrategy::FullReduction) {
+    if (isParallelTiling) {
+      return op->emitOpError("tiling parallel dimensions is not supported with "
+                             "partial reduction tiling strategies");
+    }
   }
   return success();
 }
@@ -264,10 +263,12 @@ static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
 /// `offset`s and `size`s of the tile of the iteration space that the
 /// innermost loop body of the generated tiled loops corresponds to.
 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
-getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
+getTileOffsetAndSizes(RewriterBase &rewriter, Location loc,
+                      ReductionTilingStrategy strategy, ValueRange ivs,
                       ArrayRef<Range> iterationDomain,
                       ArrayRef<OpFoldResult> tileSizes,
-                      ArrayRef<OpFoldResult> numThreads) {
+                      ArrayRef<OpFoldResult> numThreads,
+                      const llvm::SetVector<unsigned> &reductionDims) {
   SmallVector<OpFoldResult> offsets, sizes;
   int materializedLoopNum = 0;
 
@@ -279,8 +280,8 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
     offsetExpr = d0 + d1 * s0;
     residualTileSizeExpr = s1 - (d0 + d1 * s0);
 
-    for (auto [nt, tileSize, loopRange] :
-         llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
+    for (auto [index, nt, tileSize, loopRange] :
+         llvm::enumerate(numThreads, tileSizes, iterationDomain)) {
 
       // Non-tiled cases, set the offset and size to the
       // `loopRange.offset/size`.
@@ -590,6 +591,7 @@ static LogicalResult generateLoopNest(
 
 static FailureOr<SmallVector<Value>>
 createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
+                              ArrayRef<OpFoldResult> numThreads,
                               ArrayRef<OpFoldResult> tileSizes,
                               const scf::SCFTilingOptions &options) {
   SmallVector<Value> initTensors;
@@ -606,15 +608,20 @@ createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
         op, "PartialReductionOuterReduction tiling strategy is only supported"
             "for operations implementing PartialReductionOpInterface");
   }
-  return redOp.generateInitialTensorForPartialReduction(
-      rewriter, loc, tileSizes, options.reductionDims);
+  ArrayRef<OpFoldResult> sizes;
+  if (options.numThreadsComputationFunction) {
+    sizes = numThreads;
+  } else {
+    sizes = tileSizes;
+  }
+  return redOp.generateInitialTensorForPartialReduction(rewriter, loc, sizes,
+                                                        options.reductionDims);
 }
 
-static FailureOr<TilingResult>
-getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
-                       ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
-                       ArrayRef<OpFoldResult> sizes,
-                       const scf::SCFTilingOptions &options) {
+static FailureOr<TilingResult> getTiledImplementation(
+    RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg,
+    ValueRange ivs, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, const scf::SCFTilingOptions &options) {
   if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
     return op.getTiledImplementation(rewriter, offsets, sizes);
   }
@@ -626,18 +633,17 @@ getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
             "supported for operations "
             "implementing PartialReductionOpInterface");
   }
-  return redOp.tileToPartialReduction(rewriter, op.getLoc(),
-                                      options.reductionStrategy, regionIterArg,
-                                      offsets, sizes, options.reductionDims);
+  return redOp.tileToPartialReduction(
+      rewriter, op.getLoc(), options.reductionStrategy, regionIterArg, ivs,
+      offsets, sizes, options.reductionDims);
 }
 
-static LogicalResult
-getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
-                      TilingInterface op, ArrayRef<OpFoldResult> offsets,
-                      ArrayRef<OpFoldResult> sizes,
-                      SmallVector<OpFoldResult> &resultOffset,
-                      SmallVector<OpFoldResult> &resultSize,
-                      const scf::SCFTilingOptions &options) {
+static LogicalResult getResultTilePosition(
+    RewriterBase &rewriter, int64_t index, Value tiledResult,
+    TilingInterface op, ValueRange ivs, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffset,
+    SmallVector<OpFoldResult> &resultSize,
+    const scf::SCFTilingOptions &options) {
 
   if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
     return op.getResultTilePosition(rewriter, index, offsets, sizes,
@@ -649,9 +655,9 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
         op, "PartialReductionOuterReduction tiling strategy is only supported"
             "for operations implementing PartialReductionOpInterface");
   }
-  return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
-                                            options.reductionDims, resultOffset,
-                                            resultSize);
+  return redOp.getPartialResultTilePosition(
+      rewriter, index, ivs, options.reductionStrategy, offsets, sizes,
+      options.reductionDims, resultOffset, resultSize);
 }
 
 static FailureOr<MergeResult>
@@ -938,7 +944,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     // 4a. Compute the `offsets` and `sizes` to use for tiling.
     SmallVector<OpFoldResult> offsets, sizes;
     std::tie(offsets, sizes) = getTileOffsetAndSizes(
-        rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
+        rewriter, loc, options.reductionStrategy, ivs, iterationDomain,
+        tileSizes, numThreads, options.reductionDims);
 
     // 4b. If interchange was provided, apply inverse of the interchange
     //     to get back the offsets/sizes in the order to be specified.
@@ -967,7 +974,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
 
     // 5c. Tile the cloned operation.
     tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
-                                          offsets, sizes, options);
+                                          ivs, offsets, sizes, options);
     if (failed(tilingResult)) {
       rewriter.eraseOp(clonedOp);
       return op.emitOpError("faild to tile operation");
@@ -982,8 +989,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
          llvm::enumerate(tilingResult->tiledValues)) {
       tiledResults.push_back(tiledValue);
       SmallVector<OpFoldResult> resultOffset, resultSize;
-      if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
-                                       sizes, resultOffset, resultSize,
+      if (failed(getResultTilePosition(rewriter, index, tiledValue, op, ivs,
+                                       offsets, sizes, resultOffset, resultSize,
                                        options))) {
         for (auto op : tilingResult->tiledOps) {
           rewriter.eraseOp(op);
@@ -999,8 +1006,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   };
 
   // 6. Find the destination tensors to use for the operation.
-  FailureOr<SmallVector<Value>> maybeInits =
-      createInitialTensorsForTiling(rewriter, op, tileSizes, options);
+  FailureOr<SmallVector<Value>> maybeInits = createInitialTensorsForTiling(
+      rewriter, op, numThreads, tileSizes, options);
   if (failed(maybeInits)) {
     return rewriter.notifyMatchFailure(
         op, "unable to create initial tensors for tiling");
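The changes to `TileUsingInterface.cpp` come down to two small decisions:

1. Reject tiling of any parallel loop under a partial-reduction strategy.
2. Size the initial partial-result tensor by the thread counts when a num-threads computation function is set, and by the tile sizes otherwise.

The plain-C++ sketch below covers both. The enums are hypothetical stand-ins, and `checkTileSizes` here models only the new check, not the full upstream function. `initialTensorSizes` and its `usesNumThreads` flag are also hypothetical:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for the MLIR enums.
enum class IteratorType { parallel, reduction };
enum class ReductionTilingStrategy {
  FullReduction,
  PartialReductionOuterReduction,
  PartialReductionOuterParallel
};

// Models the new rule: partial reduction strategies may only tile reduction
// loops. A tile size of 0 means "not tiled". Returns an error message, or
// std::nullopt if the tile sizes are accepted.
std::optional<std::string>
checkTileSizes(ReductionTilingStrategy strategy,
               const std::vector<IteratorType> &iterators,
               const std::vector<int64_t> &tileSizes) {
  bool isParallelTiling = false;
  for (size_t i = 0; i < iterators.size(); ++i)
    if (tileSizes[i] != 0)
      isParallelTiling |= iterators[i] == IteratorType::parallel;
  if (strategy != ReductionTilingStrategy::FullReduction && isParallelTiling)
    return "tiling parallel dimensions is not supported with partial "
           "reduction tiling strategies";
  return std::nullopt;
}

// Models the size selection in createInitialTensorsForTiling: with a
// num-threads function the partial-result tensor is sized by the thread
// counts (e.g. tensor<?x5xf32> for num_threads = [0, 5]), otherwise by the
// tile sizes.
std::vector<int64_t> initialTensorSizes(bool usesNumThreads,
                                        const std::vector<int64_t> &numThreads,
                                        const std::vector<int64_t> &tileSizes) {
  return usesNumThreads ? numThreads : tileSizes;
}
```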
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 04242cad9ecb6..72144ec71c5d2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2315,13 +2315,13 @@ RankedTensorType ExtractSliceOp::inferResultType(
 RankedTensorType ExtractSliceOp::inferResultType(
     RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
-  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
-  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
-  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
-  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
-  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
-                                         staticSizes, staticStrides);
+  SmallVector<int64_t> staticSizes;
+  std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
+  assert(static_cast<int64_t>(staticSizes.size()) ==
+             sourceTensorType.getRank() &&
+         "expected number of sizes to match the rank of the source tensor");
+  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
+                               sourceTensorType.getEncoding());
 }
 
 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 29f7bd6857c27..6616a2f164802 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -209,7 +209,7 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
 /// Decompose a vector of mixed static or dynamic values into the corresponding
 /// pair of arrays. This is the inverse function of `getMixedValues`.
 std::pair<SmallVector<int64_t>, SmallVector<Value>>
-decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) {
+decomposeMixedValues(ArrayRef<OpFoldResult> mixedValues) {
   SmallVector<int64_t> staticValues;
   SmallVector<Value> dynamicValues;
   for (const auto &it : mixedValues) {
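`ExtractSliceOp::inferResultType` now derives the result type from the mixed sizes alone, via `decomposeMixedValues`. That helper now takes an `ArrayRef`, so callers can pass any contiguous range of `OpFoldResult`s rather than only a `SmallVectorImpl`.

The standalone model below uses the same hypothetical `FoldResult`/`kDynamic` stand-ins as a plain-C++ illustration. For simplicity it takes a `std::vector` reference rather than modelling `ArrayRef`, and `inferSliceShape` is a hypothetical name:

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-ins for OpFoldResult and ShapedType::kDynamic.
using FoldResult = std::variant<int64_t, std::string>;
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Split mixed values into a static vector (kDynamic for dynamic entries)
// and the list of dynamic values, like decomposeMixedValues.
std::pair<std::vector<int64_t>, std::vector<std::string>>
decomposeMixedValues(const std::vector<FoldResult> &mixedValues) {
  std::vector<int64_t> staticValues;
  std::vector<std::string> dynamicValues;
  for (const FoldResult &v : mixedValues) {
    if (const int64_t *c = std::get_if<int64_t>(&v)) {
      staticValues.push_back(*c);
    } else {
      staticValues.push_back(kDynamic);
      dynamicValues.push_back(std::get<std::string>(v));
    }
  }
  return {staticValues, dynamicValues};
}

// The inferred (non-rank-reduced) extract_slice result shape depends only
// on the sizes; offsets and strides do not affect it.
std::vector<int64_t> inferSliceShape(int64_t sourceRank,
                                     const std::vector<FoldResult> &sizes) {
  std::vector<int64_t> staticSizes = decomposeMixedValues(sizes).first;
  assert(static_cast<int64_t>(staticSizes.size()) == sourceRank &&
         "expected number of sizes to match the rank of the source tensor");
  return staticSizes;
}
```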
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 009ab17786696..aa580b1acce21 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -134,10 +134,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
 // CHECK-DAG:     %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
 // CHECK-DAG:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
-//     CHECK:     %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
-//     CHECK:     %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-//     CHECK:     %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0] [%[[D0]]] [1] : tensor<?xf32> to tensor<?xf32>
-//     CHECK:     %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?xf32>) {
+// CHECK-DAG:     %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
+// CHECK-DAG:     %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//     CHECK:     %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor<?x?xf32>) outs(%[[ET]] : tensor<?xf32>) {
 //     CHECK:       arith.mulf
 //     CHECK:       arith.addf
 //     CHECK:       linalg.yield
@@ -187,11 +186,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
 // CHECK-DAG:     %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
 // CHECK-DAG:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32>
-//     CHECK:     %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
-//     CHECK:     %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-//     CHECK:     %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-//     CHECK:     %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0, 0] [%[[D0]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-//     CHECK:     %[[PARTIAL:.+]] = linalg.matmul ins(%[[INCHUNKA]], %[[INCHUNKB]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-DAG:     %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
+// CHECK-DAG:     %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK-DAG:     %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//     CHECK:     %[[PARTIAL:.+]] = linalg.matmul ins(%[[INCHUNKA]], %[[INCHUNKB]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ET]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 //     CHECK:     scf.forall.in_parallel {
 //     CHECK:       tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x5xf32>
 //     CHECK:     }
@@ -204,113 +202,9 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @reduction_tile_parallel_cyclic_dist(
-  %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
-  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
-                                          affine_map<(d0, d1) -> (d0)>],
-   iterator_types = ["parallel", "reduction"]}
-   ins(%arg0 : tensor<?x?xf32>)
-   outs(%out : tensor<?xf32>) {
-    ^bb0(%arg7: f32, %arg9: f32):
-      %1 = arith.mulf %arg7, %arg7 : f32
-      %2 = arith.addf %1, %arg9 : f32
-      linalg.yield %2 : f32
-    } -> tensor<?xf32>
-  return %red : tensor<?xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0
-      by num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-      transform.yield
-  }
-}
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 3)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0)>
-
-//     CHECK: func @reduction_tile_parallel_cyclic_dist(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
-// CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C15:.*]] = arith.constant 15 : index
-// CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-//     CHECK:   %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
-//     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
-//     CHECK:   %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
-//     CHECK:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
-//     CHECK:     %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-//     CHECK:     %[[LB:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]
-//     CHECK:     %[[CARRY:.+]] = scf.for %[[IV1:.+]] = %[[LB]] to %[[D1]] step %[[C15]] iter_args(%[[ACC:.+]] = %[[ET]]) -> (tensor<?xf32>) {
-//     CHECK:       %[[TS0:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[D1]]]
-//     CHECK:       %[[D3:.+]] = tensor.dim %[[ACC]], %[[C0]] : tensor<?xf32>
-//     CHECK:       %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV1]]] [%[[D0]], %[[TS0]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-//     CHECK:       %[[TEMPEXT:.+]] = tensor.extract_slice %[[ACC]][0] [%[[D3]]] [1] : tensor<?xf32> to tensor<?xf32>
-//     CHECK:       %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?xf32>) {
-//     CHECK:         arith.mulf
-//     CHECK:         arith.addf
-//     CHECK:         linalg.yield
-//     CHECK:       } -> tensor<?xf32>
-//     CHECK:       %[[INS:.+]] = tensor.insert_slice %[[PARTIAL]] into %[[ACC]][0] [%[[D3]]] [1] : tensor<?xf32> into tensor<?xf32>
-//     CHECK:       scf.yield %[[INS]] : tensor<?xf32>
-//     CHECK:     }
-//     CHECK:     scf.forall.in_parallel {
-//     CHECK:       tensor.parallel_insert_slice %[[CARRY]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
-//     CHECK:     }
-//     CHECK:   }
-//     CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
-//     CHECK:     arith.addf
-//     CHECK:     linalg.yield
-//     CHECK:   }
-//     CHECK:   return %[[R]] : tensor<?xf32>
-
-// -----
-
-func.func @reduction_tile_parallel_cyclic_dist(
-  %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
-  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
-                                          affine_map<(d0, d1) -> (d0)>],
-   iterator_types = ["parallel", "reduction"]}
-   ins(%arg0 : tensor<?x?xf32>)
-   outs(%out : tensor<?xf32>) {
-    ^bb0(%arg7: f32, %arg9: f32):
-      %1 = arith.mulf %arg7, %arg7 : f32
-      %2 = arith.addf %1, %arg9 : f32
-      linalg.yield %2 : f32
-    } -> tensor<?xf32>
-  return %red : tensor<?xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0
-      by num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-
-    //      CHECK:     expecting fill
-    // CHECK-NEXT:     linalg.fill
-    transform.print %1 {name = "expecting fill"} : !transform.any_op
-    //      CHECK:     expecting parallel reduction
-    // CHECK-NEXT:     linalg.generic
-    //      CHECK:     iterator_types = ["parallel", "reduction"]
-    transform.print %2 {name = "expecting parallel reduction"} : !transform.any_op
-    //      CHECK:     expecting parallel reduction
-    // CHECK-NEXT:     linalg.reduce
-    //      CHECK:     iterator_types = ["parallel", "reduction"]
-    transform.print %3 {name = "expecting parallel reduction"} : !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
 func.func @reduction_untiled_forall(
   %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
-  // expected-note @below {{target operation}}
+  // expected-error @below {{tiling parallel dimensions is not supported with partial reduction tiling strategies}}
   %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                           affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
@@ -330,8 +224,7 @@ module attributes {transform.with_named_sequence} {
     // expected-error @below {{could not tile reduction}}
     %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0
       by num_threads = [5], tile_sizes = [3], mapping = [#gpu.thread<x>] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-
-      transform.yield
+    transform.yield
   }
 }
 

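The updated tests check the IR that `PartialReductionOuterParallel` produces for a sum-of-squares reduction with `num_threads = [0, 5]`:

- Each `scf.forall` iteration reduces its share of the columns into its own column of a `tensor<?x5xf32>` filled with the identity (zero for `arith.addf`).
- A final merge reduces the 5 partial results into the original output.

The plain-C++ model below checks that this scheme gives the same result as a sequential reduction. It is only an illustration. The function names are hypothetical. The ceil-division split of the columns is an assumption about how the `affine.min`/`affine.max` bounds in the CHECK lines partition the work.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Sequential reference: out[r] += sum_c x[r][c] * x[r][c].
std::vector<double> sumOfSquares(const Matrix &x, std::vector<double> out) {
  for (size_t r = 0; r < x.size(); ++r)
    for (double v : x[r])
      out[r] += v * v;
  return out;
}

// Outer-parallel partial reduction: iteration `iv` reduces its chunk of
// columns into column `iv` of a rows x numThreads buffer filled with 0,
// then a merge step reduces over the thread dimension into `out`.
std::vector<double> sumOfSquaresOuterParallel(const Matrix &x,
                                              std::vector<double> out,
                                              size_t numThreads) {
  size_t rows = x.size();
  size_t cols = rows ? x[0].size() : 0;
  size_t chunk = (cols + numThreads - 1) / numThreads; // ceil-div (assumed)
  Matrix partial(rows, std::vector<double>(numThreads, 0.0));
  for (size_t iv = 0; iv < numThreads; ++iv) { // models the scf.forall
    size_t begin = std::min(iv * chunk, cols);
    size_t end = std::min(begin + chunk, cols); // trailing chunks may be empty
    for (size_t r = 0; r < rows; ++r)
      for (size_t c = begin; c < end; ++c)
        partial[r][iv] += x[r][c] * x[r][c];
  }
  for (size_t r = 0; r < rows; ++r) // models the merge over dim 1
    for (size_t iv = 0; iv < numThreads; ++iv)
      out[r] += partial[r][iv];
  return out;
}
```

With 7 columns and 5 threads, the chunk size is 2, so the last iteration gets an empty range. That corresponds to the case the `affine.max(..., 0)` bound in the CHECK lines appears to guard against.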

