[Mlir-commits] [mlir] [mlir][PartialReductionTilingInterface] Add support for `ReductionTilingStrategy::PartialReductionOuterParallel` in `tileUsingSCF`. (PR #143988)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 12 15:56:06 PDT 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir-tensor

Author: MaheshRavishankar

<details>
<summary>Changes</summary>

Following up from https://github.com/llvm/llvm-project/pull/143467,
this PR adds support for
`ReductionTilingStrategy::PartialReductionOuterParallel` to
`tileUsingSCF`. The implementation of
`PartialReductionTilingInterface` for `Linalg` ops has been updated to
support this strategy as well. This brings `tileUsingSCF` on
par with `linalg::tileReductionUsingForall`, which will be deprecated
subsequently.
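
For readers unfamiliar with the strategy names, here is a minimal sketch in plain C++ (no MLIR) of the two partial-reduction splits described in the `ReductionTilingStrategy` comments in the patch. The function names `reduceOuterParallel` and `reduceOuterReduction` are hypothetical and exist only for this illustration; the sketch assumes an integer sum, whereas the real interface works on tensors and arbitrary combiners.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// PartialReductionOuterParallel: [reduction] -> [parallel1, reduction2].
// Each outer iteration is independent and reduces its own contiguous chunk
// into a private slot; a final merge combines the slots.
int reduceOuterParallel(const std::vector<int> &input, std::size_t numChunks) {
  assert(numChunks > 0);
  std::vector<int> partial(numChunks, 0); // one slot per outer iteration
  const std::size_t n = input.size();
  const std::size_t chunk = (n + numChunks - 1) / numChunks; // ceil-div
  // Stand-in for the parallel loop: iterations do not depend on each other.
  for (std::size_t iv = 0; iv < numChunks; ++iv) {
    const std::size_t begin = std::min(n, iv * chunk);
    const std::size_t end = std::min(n, begin + chunk);
    for (std::size_t i = begin; i < end; ++i)
      partial[iv] += input[i];
  }
  // Merge step over the formerly parallel dimension.
  return std::accumulate(partial.begin(), partial.end(), 0);
}

// PartialReductionOuterReduction: [reduction] -> [reduction1, parallel2].
// The outer loop is sequential and accumulates into a tile-sized buffer;
// the elements inside one tile are independent of each other.
int reduceOuterReduction(const std::vector<int> &input, std::size_t tileSize) {
  assert(tileSize > 0);
  std::vector<int> partial(tileSize, 0);
  for (std::size_t offset = 0; offset < input.size(); offset += tileSize) {
    const std::size_t end = std::min(input.size(), offset + tileSize);
    for (std::size_t i = offset; i < end; ++i)
      partial[i - offset] += input[i];
  }
  // Merge step over the inner (parallel) dimension.
  return std::accumulate(partial.begin(), partial.end(), 0);
}
```

Both variants produce the same result as a direct sum; they differ only in which half of the split dimension becomes the loop and which half remains inside the tiled operation.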

Changes summary
- `PartialReductionTilingInterface` changes:
  - The `tileToPartialReduction` method now needs the induction
    variables of the generated tile loops. This keeps the
    generated code similar to `linalg::tileReductionUsingForall`,
    specifically by allowing a simplified access for slicing the
    intermediate partial results tensor when tiled in `num_threads` mode.
  - The `getPartialResultTilePosition` method needs the induction
    variables of the generated tile loops for the same reason,
    and also needs the `tilingStrategy` to be passed in to generate
    correct code.
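
One plausible reading of the "simplified access" point, sketched in plain C++ rather than MLIR: in `num_threads` mode each loop iteration owns one slot of the partial results, so the slot index can be the induction variable itself instead of a value recovered from the tile offset. The names `TileRange`, `tileForIv`, `slotFromIv`, and `slotFromOffset` are hypothetical and do not appear in the patch, and the ceil-div tile size is an assumption of this sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Hypothetical helper type: the part of the reduction dimension owned by
// one loop iteration.
struct TileRange {
  std::size_t offset;
  std::size_t size;
};

// Compute the tile for induction variable `iv` when `extent` elements are
// split across `numThreads` iterations (ceil-div tile size, clamped at the
// end of the dimension).
TileRange tileForIv(std::size_t iv, std::size_t extent,
                    std::size_t numThreads) {
  assert(numThreads > 0);
  const std::size_t tileSize = (extent + numThreads - 1) / numThreads;
  const std::size_t offset = std::min(extent, iv * tileSize);
  return {offset, std::min(tileSize, extent - offset)};
}

// With the induction variable available, the slot in the partial results
// is simply `iv`.
std::size_t slotFromIv(std::size_t iv) { return iv; }

// Without it, the slot has to be recovered from the tile offset, which
// costs an extra division and is only valid for unclamped offsets.
std::size_t slotFromOffset(std::size_t offset, std::size_t tileSize) {
  assert(tileSize > 0);
  return offset / tileSize;
}
```

In this sketch the two slot computations agree whenever the offset is not clamped, for example `iv = 1` with 8 elements over 4 threads. They diverge for an empty trailing tile, such as `iv = 3` with 5 elements over 4 threads, where the clamped offset maps back to slot 2 rather than 3.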

The tests in `transform-tile-reduction.mlir` that exercised
`linalg::tileReductionUsingForall` have been moved over to test
`scf::tileUsingSCF` with the
`ReductionTilingStrategy::PartialReductionOuterParallel`
strategy. Some of the tests that performed further cyclic distribution
of the tiled code have been removed. Tiling and cyclic distribution
seem like two separate transformations that were merged into one.
Ideally, the distribution would happen when resolving the `scf.forall`
rather than during tiling.

Please review only the top commit.  Depends on https://github.com/llvm/llvm-project/pull/143467

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>

---

Patch is 75.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143988.diff


12 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+7-1) 
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+18-28) 
- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+1-1) 
- (modified) mlir/include/mlir/Interfaces/TilingInterface.h (+21) 
- (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+9-5) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+50-13) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (+8-7) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+191-99) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+116-148) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-7) 
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+1-1) 
- (modified) mlir/test/Dialect/Linalg/transform-tile-reduction.mlir (+174-118) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15ea5e7bf7159..d0591ae122fbb 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1767,6 +1767,10 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
       - the result-combining op,
       - the parent `for` op.
 
+    The `reduction_dims` can be used to specify the subset of reduction dimensions
+    of the operation to tile. If left unspecified, all reduction dimensions are
+    tiled.
+
     #### Example:
 
     ```
@@ -1817,7 +1821,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
 
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
   let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
                       TransformHandleTypeInterface:$split_op,
                       TransformHandleTypeInterface:$combining_op,
@@ -1830,6 +1835,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
 
   let assemblyFormat = [{
     $target
+    (`reduction_dims` `=` $reduction_dims^)?
     `by` `tile_sizes` `=` $tile_sizes
     attr-dict
     `:` functional-type(operands, results)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index f686ae07b9a99..01ad64b76b15e 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,28 +85,21 @@ struct SCFTilingOptions {
     return *this;
   }
 
+  /// Specify mapping of loops to devices. This is only respected when the loop
+  /// constructs support such a mapping (like `scf.forall`). Will be ignored
+  /// when using loop constructs that dont support such a mapping (like
+  /// `scf.for`)
+  SmallVector<Attribute> mappingVector = {};
+  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+    mappingVector = llvm::to_vector(mapping);
+    return *this;
+  }
+
+  //-------------------------------------------------------------------------//
+  // Options related reduction tiling
+  //-------------------------------------------------------------------------//
+
   /// Specify how reduction dimensions should be tiled.
-  ///
-  /// Tiling can be thought of as splitting a dimension into 2 and materializing
-  /// the outer dimension as a loop:
-  ///
-  /// op[original] -> op[original / x, x] -> loop[original] { op[x] }
-  ///
-  /// For parallel dimensions, the split can only happen in one way, with both
-  /// dimensions being parallel. For reduction dimensions however, there is a
-  /// choice in how we split the reduction dimension. This enum exposes this
-  /// choice.
-  enum class ReductionTilingStrategy {
-    // [reduction] -> [reduction1, reduction2]
-    // -> loop[reduction1] { [reduction2] }
-    FullReduction,
-    // [reduction] -> [reduction1, parallel2]
-    // -> loop[reduction1] { [parallel2] }; merge[reduction1]
-    PartialReductionOuterReduction,
-    // [reduction] -> [parallel1, reduction2]
-    // -> loop[parallel1] { [reduction2] }; merge[parallel1]
-    PartialReductionOuterParallel
-  };
   ReductionTilingStrategy reductionStrategy =
       ReductionTilingStrategy::FullReduction;
   SCFTilingOptions &
@@ -115,13 +108,10 @@ struct SCFTilingOptions {
     return *this;
   }
 
-  /// Specify mapping of loops to devices. This is only respected when the loop
-  /// constructs support such a mapping (like `scf.forall`). Will be ignored
-  /// when using loop constructs that dont support such a mapping (like
-  /// `scf.for`)
-  SmallVector<Attribute> mappingVector = {};
-  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
-    mappingVector = llvm::to_vector(mapping);
+  /// Specify the reduction dimensions to be tiled.
+  SetVector<unsigned> reductionDims;
+  SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
+    reductionDims.insert(dims.begin(), dims.end());
     return *this;
   }
 };
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index b37fb55b67931..77c376fb9973a 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -156,7 +156,7 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
 /// corresponding pair of arrays. This is the inverse function of
 /// `getMixedValues`.
 std::pair<SmallVector<int64_t>, SmallVector<Value>>
-decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues);
+decomposeMixedValues(ArrayRef<OpFoldResult> mixedValues);
 
 /// Helper to sort `values` according to matching `keys`.
 SmallVector<Value>
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index b33aa1489c311..8693cbea7f0b0 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -36,6 +36,27 @@ struct TilingResult {
   SmallVector<Operation *> generatedSlices;
 };
 
+/// Tiling can be thought of as splitting a dimension into 2 and
+/// materializing the outer dimension as a loop:
+///
+/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
+///
+/// For parallel dimensions, the split can only happen in one way, with both
+/// dimensions being parallel. For reduction dimensions however, there is a
+/// choice in how we split the reduction dimension. This enum exposes this
+/// choice.
+enum class ReductionTilingStrategy {
+  // [reduction] -> [reduction1, reduction2]
+  // -> loop[reduction1] { [reduction2] }
+  FullReduction,
+  // [reduction] -> [reduction1, parallel2]
+  // -> loop[reduction1] { [parallel2] }; merge[reduction1]
+  PartialReductionOuterReduction,
+  // [reduction] -> [parallel1, reduction2]
+  // -> loop[parallel1] { [reduction2] }; merge[parallel1]
+  PartialReductionOuterParallel
+};
+
 /// Container for the result of merge operation of tiling.
 /// - `mergeOps` contains operations created during the merge.
 /// - `replacements` contains the values that represents the result of the
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index cdf3d01ce8a84..2d50a454710c2 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -384,7 +384,7 @@ def PartialReductionOpInterface :
             "::mlir::OpBuilder &":$b,
             "Location":$loc,
             "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
-            "::mlir::ArrayRef<int>":$reductionDim),
+            "const ::mlir::SetVector<unsigned> &":$reductionDim),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -402,10 +402,12 @@ def PartialReductionOpInterface :
         /*args=*/(ins
             "::mlir::OpBuilder &":$b,
             "Location ":$loc,
+            "::mlir::ReductionTilingStrategy":$tilingStrategy,
             "ValueRange":$init,
+            "ValueRange":$ivs,
             "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
             "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
-            "::mlir::ArrayRef<int>":$reductionDims),
+            "const ::llvm::SetVector<unsigned> &":$reductionDims),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -423,7 +425,7 @@ def PartialReductionOpInterface :
             "::mlir::OpBuilder &":$b,
             "Location ":$loc,
             "ValueRange":$partialReduce,
-            "::mlir::ArrayRef<int>":$reductionDim),
+            "const ::mlir::SetVector<unsigned> &":$reductionDims),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -441,11 +443,13 @@ def PartialReductionOpInterface :
         /*args=*/(ins
             "::mlir::OpBuilder &":$b,
             "unsigned":$resultNumber,
+            "ValueRange":$ivs,
+            "ReductionTilingStrategy":$tilingStrategy,
             "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
             "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+            "const ::mlir::SetVector<unsigned> &":$reductionDims,
             "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
-            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
-            "::mlir::ArrayRef<int>":$reductionDims),
+            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b2c28f5eed33c..1f298185750dc 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2775,10 +2775,11 @@ void transform::TileReductionUsingForOp::build(
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   MLIRContext *ctx = builder.getContext();
   auto opTy = transform::AnyOpType::get(ctx);
-  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+  auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
   build(builder, result,
         /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
         /*target=*/target,
+        /*reduction_dims=*/nullptr,
         /*tile_sizes=*/staticTileSizesAttr);
 }
 
@@ -2794,12 +2795,30 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
         target->getLoc(),
         "Operation should implement PartialReductionOpInterface");
   }
-  FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
-      rewriter, partialReductionOp,
-      getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
 
-  if (failed(result))
-    return emitDefaultSilenceableFailure(target);
+  SmallVector<unsigned> reductionDims =
+      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
+  if (reductionDims.empty()) {
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+  }
+
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
+  options.setReductionTilingStrategy(
+      ReductionTilingStrategy::PartialReductionOuterReduction);
+  options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
+  options.setReductionDims(reductionDims);
+  FailureOr<scf::SCFTilingResult> result =
+      scf::tileUsingSCF(rewriter, partialReductionOp, options);
+
+  if (failed(result)) {
+    return emitSilenceableFailure(getLoc(),
+                                  "failed to tile using partial reduction");
+  }
   rewriter.replaceOp(target, result->replacements);
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
@@ -2845,23 +2864,41 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
   SmallVector<OpFoldResult> tileSizes =
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
-  FailureOr<linalg::ForallReductionTilingResult> result =
-      linalg::tileReductionUsingForall(
-          rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
-          numThreads, tileSizes, getMapping());
+
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
+  options.setReductionTilingStrategy(
+      ReductionTilingStrategy::PartialReductionOuterParallel);
+  if (!getNumThreads().empty()) {
+    options.setNumThreads(numThreads);
+  } else {
+    options.setTileSizes(tileSizes);
+  }
+  if (auto mapping = getMapping()) {
+    options.setMapping(mapping.value().getValue());
+  }
+  SmallVector<unsigned> reductionDims;
+  for (auto [idx, iteratorType] :
+       llvm::enumerate(target.getIteratorTypesArray()))
+    if (iteratorType == utils::IteratorType::reduction)
+      reductionDims.push_back(idx);
+  options.setReductionDims(reductionDims);
+  FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
+      rewriter, cast<TilingInterface>(target.getOperation()), options);
 
   if (failed(result)) {
     auto diag = emitSilenceableError() << "could not tile reduction";
-    diag.attachNote(target.getLoc()) << "target operation";
     return diag;
   }
+  rewriter.replaceOp(target, result->replacements);
+
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
-  for (auto parallelTiledOp : result->parallelTiledOps)
+  for (auto parallelTiledOp : result->tiledOps)
     results.push_back(parallelTiledOp);
   for (auto mergeOp : result->mergeOps)
     results.push_back(mergeOp);
-  results.push_back(result->loops);
+  results.push_back(result->loops.front());
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 4162aa0b71e6d..8a5a2e54cdda2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -109,8 +109,7 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
 }
 
 FailureOr<StaticContinuousTileSizeSpecification>
-mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
-                                               unsigned dimension,
+mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
                                                unsigned targetSize) {
 
   assert(!op.hasDynamicShape() &&
@@ -183,8 +182,8 @@ mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
 
   // Find the trip count of the iteration space dimension for which the tile
   // sizes are computed.
-  Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
-                                                    loopRanges[dimension].size);
+  Value loopRange =
+      getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
   ContinuousTileSizeSpecification spec;
 
   // Compute the tile sizes and the respective numbers of tiles.
@@ -633,16 +632,18 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
     return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
                                     "many elements as number of threads");
-  int reductionDim = static_cast<int>(redDims.front());
 
   if (redDims.front() >= numThreads.size())
     return b.notifyMatchFailure(
         op, "reduction dimension must be mapped to threads");
 
   // 1. Create the inital tensor value.
+  unsigned reductionDim = redDims.front();
+  SetVector<unsigned> reductionDims;
+  reductionDims.insert(reductionDim);
   FailureOr<SmallVector<Value>> maybeInitTensors =
       op.generateInitialTensorForPartialReduction(b, loc, numThreads,
-                                                  reductionDim);
+                                                  reductionDims);
   if (failed(maybeInitTensors))
     return b.notifyMatchFailure(
         op, "Failed to create inital tensors for partial reduction");
@@ -780,7 +781,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   // 7. Merge the partial reductions.
   b.setInsertionPointAfter(forallOp);
   FailureOr<MergeResult> mergeResult =
-      op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
+      op.mergeReductions(b, loc, forallOp->getResults(), reductionDims);
   if (failed(mergeResult)) {
     return failure();
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 7c14cc16437fe..2dfe4448019b6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include <optional>
@@ -327,23 +328,110 @@ struct LinalgOpTilingInterface
 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//
 
-/// Return an AffineMap for a partial result for the given result number,
-/// assuming the partial tiling strategy is outer-reduction loop +
-/// inner-parallel tile. The returned AffineMap can be used as the replacement
-/// AffineMap for the inner-parallel tile linalg op for the given result number.
-///
-/// The new AffineMap is the old AffineMap with reduction dimensions appended
-/// at end.
-static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
-                                           ArrayRef<int> reductionDims,
-                                           unsigned resultNumber) {
-  AffineMap map =
-      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
-  for (int redPos : reductionDims) {
-    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
-                           map.getNumResults());
+/// In a given set vector, get the position of a particular element.
+std::optional<int> getPositionIn(const llvm::SetVector<unsigned> &reductionDims,
+                                 unsigned value) {
+  for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
+    if (reductionDim == value) {
+      return index;
+    }
+  }
+  return std::nullopt;
+}
+
+/// Return an AffineMaps to use for the `outs` operands of the linalg op
+/// generated for partial results. The new AffineMap is the AffineMap of the
+/// untiled op with reduction dimensions appended at end in order in which they
+/// were specified during tiling.
+static SmallVector<AffineMap>
+getPartialResultAffineMaps(LinalgOp linalgOp,
+                           const SetVector<unsigned> &reductionDims) {
+  auto partialReductionMaps = llvm::map_to_vector(
+      linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
+        AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
+        for (auto redPos : reductionDims) {
+          map =
+              map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+                               map.getNumResults());
+        }
+        return map;
+      });
+  return partialReductionMaps;
+}
+
+struct InitSliceInfo {
+  SmallVector<int64_t> resultShape;
+  SmallVector<OpFoldResult> offsets;
+  SmallVector<OpFoldResult> sizes;
+  SmallVector<OpFoldResult> strides;
+};
+
+/// Return the result type, offsets, sizes and strides of the slice of the
+/// `initValue` to use as input to the partial reduction op generated with
+/// outer reduction strategy.
+static InitSliceInfo getInitSliceInfoForOuterReduction(
+    MLIRContext *context, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
+    AffineMap partialReductionMap) {
+  int64_t initRank = partialReductionMap.getNumResults();
+  SmallVector<OpFoldResult> initOffsets, initSizes;
+  Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
+  Attribute one = IntegerAttr::get(IndexType::get(context), 1);
+  SmallVector<OpFoldResult> initStrides(initRank, one);
+  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
+    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    if (reductionDims.contains(dim)) {
+      initOffsets.push_back(zero);
+    } else {
+      initOffsets.push...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/143988

