[Mlir-commits] [mlir] [mlir][PartialReductionTilingInterface] Generalize implementation of `tileUsingSCF` for `ReductionTilingStrategy::PartialOuterReduction`. (PR #143467)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 9 19:08:36 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: None (MaheshRavishankar)

<details>
<summary>Changes</summary>

This is a precursor to generalizing `tileUsingSCF` to handle the
`ReductionTilingStrategy::PartialReductionOuterParallel` strategy. This
change itself generalizes and refactors the current implementation, which
supports only `ReductionTilingStrategy::PartialReductionOuterReduction`.
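
As a rough illustration of the three strategies described in the `ReductionTilingStrategy` comment of the patch, the sketch below sums a 1-D array in plain C++. It does not use any MLIR API; the function names (`sumFullReduction`, `sumOuterReduction`, `sumOuterParallel`) and the choice of integer addition as the reduction are assumptions made only for this example.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// FullReduction: [reduction] -> loop[reduction1] { [reduction2] }.
// Both the tile loop and the intra-tile loop feed a single accumulator.
int sumFullReduction(const std::vector<int> &data, std::size_t tile) {
  assert(tile > 0 && "tile size must be positive");
  int acc = 0;
  for (std::size_t start = 0; start < data.size(); start += tile)
    for (std::size_t i = start; i < std::min(start + tile, data.size()); ++i)
      acc += data[i];
  return acc;
}

// PartialReductionOuterReduction:
// [reduction] -> loop[reduction1] { [parallel2] }; merge.
// The tile loop is sequential; each lane of the tile accumulates into its own
// slot of a `tile`-sized partial result, which is merged afterwards.
int sumOuterReduction(const std::vector<int> &data, std::size_t tile) {
  assert(tile > 0 && "tile size must be positive");
  std::vector<int> partial(tile, 0);
  for (std::size_t start = 0; start < data.size(); start += tile)
    for (std::size_t lane = 0; lane < tile && start + lane < data.size(); ++lane)
      partial[lane] += data[start + lane];
  return std::accumulate(partial.begin(), partial.end(), 0); // merge step
}

// PartialReductionOuterParallel:
// [reduction] -> loop[parallel1] { [reduction2] }; merge.
// Each tile is reduced independently into one slot per tile, then merged.
int sumOuterParallel(const std::vector<int> &data, std::size_t tile) {
  assert(tile > 0 && "tile size must be positive");
  std::size_t numTiles = (data.size() + tile - 1) / tile;
  std::vector<int> partial(numTiles, 0);
  for (std::size_t t = 0; t < numTiles; ++t)
    for (std::size_t i = t * tile; i < std::min((t + 1) * tile, data.size()); ++i)
      partial[t] += data[i];
  return std::accumulate(partial.begin(), partial.end(), 0); // merge step
}
```

All three variants compute the same sum; they differ in the shape of the intermediate result (a scalar, one slot per tile lane, or one slot per tile) and in which loop could run in parallel.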

Changes in this PR
- Move the `ReductionTilingStrategy` enum out of
  `scf::SCFTilingOptions` and make it visible to `TilingInterface`.
- `PartialReductionOpInterface` changes
  - Pass the `tilingStrategy` used for partial reduction to
    `tileToPartialReduction`.
  - Pass the reduction dimensions along as
    `const llvm::SetVector<unsigned> &`.
- Allow `scf::SCFTilingOptions` to set the reduction dimensions that
  are to be tiled.
- Change `structured.tile_reduction_using_for` to allow specification
  of the reduction dimensions to be partially tiled.
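
The handling of an unspecified `reduction_dims` list can be sketched in standalone C++ as follows. This is a simplified model, not the MLIR code: `IteratorType` and `resolveReductionDims` are hypothetical names, a `std::vector` with manual de-duplication stands in for `llvm::SetVector<unsigned>`, and the bounds check is an assumption of the sketch rather than something the patch is known to enforce.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the loop iterator kinds of a tiled op.
enum class IteratorType { Parallel, Reduction };

// Returns the reduction dimensions to tile. An empty request selects every
// reduction dimension; otherwise the requested dimensions are kept in the
// order given, with duplicates dropped (insertion-ordered set semantics).
std::vector<unsigned>
resolveReductionDims(const std::vector<IteratorType> &iterators,
                     const std::vector<unsigned> &requested) {
  std::vector<unsigned> dims;
  if (requested.empty()) {
    for (std::size_t i = 0; i < iterators.size(); ++i)
      if (iterators[i] == IteratorType::Reduction)
        dims.push_back(static_cast<unsigned>(i));
    return dims;
  }
  for (unsigned d : requested) {
    // Assumption of this sketch: out-of-range dimensions are a caller error.
    assert(d < iterators.size() && "dimension out of range");
    if (std::find(dims.begin(), dims.end(), d) == dims.end())
      dims.push_back(d);
  }
  return dims;
}
```

For an op with iterators (parallel, reduction, reduction), an empty request resolves to dimensions 1 and 2, while a request of 2, 2, 1 resolves to 2 followed by 1.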

Depends on https://github.com/llvm/llvm-project/pull/143217. Please review only the top commit for this PR.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@<!-- -->gmail.com>

---

Patch is 58.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143467.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+7-1) 
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+27-35) 
- (modified) mlir/include/mlir/Interfaces/TilingInterface.h (+21) 
- (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+8-6) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+30-11) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (+8-7) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+93-76) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+112-150) 
- (modified) mlir/test/Dialect/Linalg/transform-tile-reduction.mlir (+165-2) 
- (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+1-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15ea5e7bf7159..d0591ae122fbb 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1767,6 +1767,10 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
       - the result-combining op,
       - the parent `for` op.
 
+    The `reduction_dims` can be used to specify the subset of reduction dimensions
+    of the operation to tile. If left unspecified, all reduction dimensions are
+    tiled.
+
     #### Example:
 
     ```
@@ -1817,7 +1821,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
 
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
   let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
                       TransformHandleTypeInterface:$split_op,
                       TransformHandleTypeInterface:$combining_op,
@@ -1830,6 +1835,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
 
   let assemblyFormat = [{
     $target
+    (`reduction_dims` `=` $reduction_dims^)?
     `by` `tile_sizes` `=` $tile_sizes
     attr-dict
     `:` functional-type(operands, results)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 33a43ce2ee7bb..01ad64b76b15e 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,28 +85,21 @@ struct SCFTilingOptions {
     return *this;
   }
 
+  /// Specify mapping of loops to devices. This is only respected when the loop
+  /// constructs support such a mapping (like `scf.forall`). Will be ignored
+  /// when using loop constructs that dont support such a mapping (like
+  /// `scf.for`)
+  SmallVector<Attribute> mappingVector = {};
+  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+    mappingVector = llvm::to_vector(mapping);
+    return *this;
+  }
+
+  //-------------------------------------------------------------------------//
+  // Options related reduction tiling
+  //-------------------------------------------------------------------------//
+
   /// Specify how reduction dimensions should be tiled.
-  ///
-  /// Tiling can be thought of as splitting a dimension into 2 and materializing
-  /// the outer dimension as a loop:
-  ///
-  /// op[original] -> op[original / x, x] -> loop[original] { op[x] }
-  ///
-  /// For parallel dimensions, the split can only happen in one way, with both
-  /// dimensions being parallel. For reduction dimensions however, there is a
-  /// choice in how we split the reduction dimension. This enum exposes this
-  /// choice.
-  enum class ReductionTilingStrategy {
-    // [reduction] -> [reduction1, reduction2]
-    // -> loop[reduction1] { [reduction2] }
-    FullReduction,
-    // [reduction] -> [reduction1, parallel2]
-    // -> loop[reduction1] { [parallel2] }; merge[reduction1]
-    PartialReductionOuterReduction,
-    // [reduction] -> [parallel1, reduction2]
-    // -> loop[parallel1] { [reduction2] }; merge[parallel1]
-    PartialReductionOuterParallel
-  };
   ReductionTilingStrategy reductionStrategy =
       ReductionTilingStrategy::FullReduction;
   SCFTilingOptions &
@@ -115,13 +108,10 @@ struct SCFTilingOptions {
     return *this;
   }
 
-  /// Specify mapping of loops to devices. This is only respected when the loop
-  /// constructs support such a mapping (like `scf.forall`). Will be ignored
-  /// when using loop constructs that dont support such a mapping (like
-  /// `scf.for`)
-  SmallVector<Attribute> mappingVector = {};
-  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
-    mappingVector = llvm::to_vector(mapping);
+  /// Specify the reduction dimensions to be tiled.
+  SetVector<unsigned> reductionDims;
+  SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
+    reductionDims.insert(dims.begin(), dims.end());
     return *this;
   }
 };
@@ -136,15 +126,17 @@ struct SCFTilingResult {
   SmallVector<Value> initialValues;
   /// The `scf.for` operations that iterate over the tiles.
   SmallVector<LoopLikeOpInterface> loops;
-  /// The result generated by the loop nest in tiling, may hold partial results,
-  /// which need to be merged to match the computation of the untiled operation.
-  /// `mergeResult` contains the operations used to perform this merge from
-  /// partial results and the values that can be used as replacements of
-  /// the untiled operation.
-  MergeResult mergeResult;
+  /// Values to use as replacements for the untiled op. Is the same size as the
+  /// number of results of the untiled op.
+  SmallVector<Value> replacements;
   /// Slices generated after tiling that can be used for fusing with the tiled
   /// producer.
   SmallVector<Operation *> generatedSlices;
+  /// In cases where there as an additional merge step after tiling
+  /// return the merged ops after tiling. This list is empty when reduction
+  /// tiling strategy is
+  /// `scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction.
+  SmallVector<Operation *> mergeOps;
 };
 
 /// Method to tile an op that implements the `TilingInterface` using
@@ -362,7 +354,7 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
 /// ```
 FailureOr<scf::SCFTilingResult>
 tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
-                      ArrayRef<OpFoldResult> tileSize);
+                      ArrayRef<OpFoldResult> tileSizes);
 
 } // namespace scf
 } // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index b33aa1489c311..8693cbea7f0b0 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -36,6 +36,27 @@ struct TilingResult {
   SmallVector<Operation *> generatedSlices;
 };
 
+/// Tiling can be thought of as splitting a dimension into 2 and
+/// materializing the outer dimension as a loop:
+///
+/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
+///
+/// For parallel dimensions, the split can only happen in one way, with both
+/// dimensions being parallel. For reduction dimensions however, there is a
+/// choice in how we split the reduction dimension. This enum exposes this
+/// choice.
+enum class ReductionTilingStrategy {
+  // [reduction] -> [reduction1, reduction2]
+  // -> loop[reduction1] { [reduction2] }
+  FullReduction,
+  // [reduction] -> [reduction1, parallel2]
+  // -> loop[reduction1] { [parallel2] }; merge[reduction1]
+  PartialReductionOuterReduction,
+  // [reduction] -> [parallel1, reduction2]
+  // -> loop[parallel1] { [reduction2] }; merge[parallel1]
+  PartialReductionOuterParallel
+};
+
 /// Container for the result of merge operation of tiling.
 /// - `mergeOps` contains operations created during the merge.
 /// - `replacements` contains the values that represents the result of the
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 50b69b8f8d833..9358d8b82abce 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -363,7 +363,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
   ];
 }
 
-def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
+def PartialReductionOpInterface :
+    OpInterface<"PartialReductionOpInterface", [TilingInterface]> {
   let description = [{
     Interface for allowing operations to expose information needed to
     tile reductions using partial reduction followed by merge. This is
@@ -383,7 +384,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
             "::mlir::OpBuilder &":$b,
             "Location":$loc,
             "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
-            "::mlir::ArrayRef<int>":$reductionDim),
+            "const ::mlir::SetVector<unsigned> &":$reductionDim),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -401,10 +402,11 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
         /*args=*/(ins
             "::mlir::OpBuilder &":$b,
             "Location ":$loc,
+            "::mlir::ReductionTilingStrategy":$tilingStrategy,
             "ValueRange":$init,
             "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
             "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
-            "::mlir::ArrayRef<int>":$reductionDims),
+            "const ::llvm::SetVector<unsigned> &":$reductionDims),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -422,7 +424,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
             "::mlir::OpBuilder &":$b,
             "Location ":$loc,
             "ValueRange":$partialReduce,
-            "::mlir::ArrayRef<int>":$reductionDim),
+            "const ::mlir::SetVector<unsigned> &":$reductionDims),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -442,9 +444,9 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
             "unsigned":$resultNumber,
             "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
             "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+            "const ::mlir::SetVector<unsigned> &":$reductionDims,
             "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
-            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
-            "::mlir::ArrayRef<int>":$reductionDims),
+            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1c3b621828315..c003825264920 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2381,7 +2381,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
     return emitDefaultDefiniteFailure(target);
 
   if (target->getNumResults())
-    rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
+    rewriter.replaceOp(target, maybeTilingResult->replacements);
   else
     rewriter.eraseOp(target);
 
@@ -2775,10 +2775,11 @@ void transform::TileReductionUsingForOp::build(
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   MLIRContext *ctx = builder.getContext();
   auto opTy = transform::AnyOpType::get(ctx);
-  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+  auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
   build(builder, result,
         /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
         /*target=*/target,
+        /*reduction_dims=*/nullptr,
         /*tile_sizes=*/staticTileSizesAttr);
 }
 
@@ -2794,18 +2795,36 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
         target->getLoc(),
         "Operation should implement PartialReductionOpInterface");
   }
-  FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
-      rewriter, partialReductionOp,
-      getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
 
-  if (failed(result))
-    return emitDefaultSilenceableFailure(target);
-  rewriter.replaceOp(target, result->mergeResult.replacements);
+  SmallVector<unsigned> reductionDims =
+      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
+  if (reductionDims.empty()) {
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+  }
+
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
+  options.setReductionTilingStrategy(
+      ReductionTilingStrategy::PartialReductionOuterReduction);
+  options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
+  options.setReductionDims(reductionDims);
+  FailureOr<scf::SCFTilingResult> result =
+      scf::tileUsingSCF(rewriter, partialReductionOp, options);
+
+  if (failed(result)) {
+    return emitSilenceableFailure(getLoc(),
+                                  "failed to tile using partial reduction");
+  }
+  rewriter.replaceOp(target, result->replacements);
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
   for (auto parallelTiledOp : result->tiledOps)
     results.push_back(parallelTiledOp);
-  for (auto mergeOp : result->mergeResult.mergeOps)
+  for (auto mergeOp : result->mergeOps)
     results.push_back(mergeOp);
   results.push_back(result->loops.front());
   return DiagnosedSilenceableFailure::success();
@@ -3229,7 +3248,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
     if (failed(maybeTilingResult))
       return DiagnosedSilenceableFailure::definiteFailure();
 
-    rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
+    rewriter.replaceOp(op, maybeTilingResult->replacements);
 
     tiled.append(maybeTilingResult->tiledOps);
     for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
@@ -3465,7 +3484,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
   if (failed(maybeTilingResult))
     return transformOp.emitDefaultSilenceableFailure(tileableOp);
 
-  rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
+  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
 
   tilingResult = *maybeTilingResult;
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 4162aa0b71e6d..8a5a2e54cdda2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -109,8 +109,7 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
 }
 
 FailureOr<StaticContinuousTileSizeSpecification>
-mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
-                                               unsigned dimension,
+mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
                                                unsigned targetSize) {
 
   assert(!op.hasDynamicShape() &&
@@ -183,8 +182,8 @@ mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
 
   // Find the trip count of the iteration space dimension for which the tile
   // sizes are computed.
-  Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
-                                                    loopRanges[dimension].size);
+  Value loopRange =
+      getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
   ContinuousTileSizeSpecification spec;
 
   // Compute the tile sizes and the respective numbers of tiles.
@@ -633,16 +632,18 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
     return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
                                     "many elements as number of threads");
-  int reductionDim = static_cast<int>(redDims.front());
 
   if (redDims.front() >= numThreads.size())
     return b.notifyMatchFailure(
         op, "reduction dimension must be mapped to threads");
 
   // 1. Create the inital tensor value.
+  unsigned reductionDim = redDims.front();
+  SetVector<unsigned> reductionDims;
+  reductionDims.insert(reductionDim);
   FailureOr<SmallVector<Value>> maybeInitTensors =
       op.generateInitialTensorForPartialReduction(b, loc, numThreads,
-                                                  reductionDim);
+                                                  reductionDims);
   if (failed(maybeInitTensors))
     return b.notifyMatchFailure(
         op, "Failed to create inital tensors for partial reduction");
@@ -780,7 +781,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   // 7. Merge the partial reductions.
   b.setInsertionPointAfter(forallOp);
   FailureOr<MergeResult> mergeResult =
-      op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
+      op.mergeReductions(b, loc, forallOp->getResults(), reductionDims);
   if (failed(mergeResult)) {
     return failure();
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 7c14cc16437fe..f649bc49a8fbd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include <optional>
@@ -327,23 +328,48 @@ struct LinalgOpTilingInterface
 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//
 
-/// Return an AffineMap for a partial result for the given result number,
-/// assuming the partial tiling strategy is outer-reduction loop +
-/// inner-parallel tile. The returned AffineMap can be used as the replacement
-/// AffineMap for the inner-parallel tile linalg op for the given result number.
-///
-/// The new AffineMap is the old AffineMap with reduction dimensions appended
-/// at end.
-static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
-                                           ArrayRef<int> reductionDims,
-                                           unsigned resultNumber) {
-  AffineMap map =
-      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
-  for (int redPos : reductionDims) {
-    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
-                           map.getNumResults());
+/// Return an AffineMaps to use for the `outs` operands of the linalg op
+/// generated for partial results. The new AffineMap is the AffineMap of the
+/// untiled op with reduction dimensions appended at end in order in which they
+/// were specified during tiling.
+static SmallVector<AffineMap>
+getPartialResultAffineMaps(LinalgOp linalgOp,
+                           const SetVector<unsigned> &reductionDims) {
+  auto partialReductionMaps = llvm::map_to_vector(
+      linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
+        AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
+        for (auto redPos : reductionDims) {
+          map =
+              map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+                               map.getNumResults());
+        }
+        return map;
+      });
+  return partialReductionMaps;
+}
+
+/// Return the slice of the `initValue` to use as input to the partial reduction
+/// op generated.
+static Operation *getInitSliceForOuterReduction(
+    OpBuilder &b, Location loc, Value initValue, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
+    AffineMap partialReductionMap) {
+  int64_t initRank = partialReductionMap.getNumResults();
+  SmallVector<OpFoldResult> initOffsets, initSizes;
+  SmallVector<OpFoldResult> initStrides(initRank, b.getIndexAttr(1));
+  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
+    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    if (reductionDims.cont...
[truncated]

``````````
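
The comment on `getPartialResultAffineMaps` in the patch says the map for a partial result is the original `outs` map with the tiled reduction dimensions appended in the order they were specified. A minimal standalone sketch of that rule is given below. It assumes the indexing map is a plain list of loop-dimension positions, which is a simplification of `AffineMap`; `DimList` and `appendReductionDims` are hypothetical names used only here.

```cpp
#include <cassert>
#include <vector>

// Simplified stand-in for an indexing map: result i reads loop dimension
// resultDims[i], so (d0, d1, d2) -> (d0, d1) is modeled as {0, 1}.
using DimList = std::vector<unsigned>;

// Appends each tiled reduction dimension to the result list, preserving the
// order in which the reduction dimensions were specified.
DimList appendReductionDims(DimList resultDims, const DimList &reductionDims,
                            unsigned numLoops) {
  for (unsigned d : reductionDims) {
    assert(d < numLoops && "reduction dimension out of range");
    resultDims.push_back(d);
  }
  return resultDims;
}
```

With three loops and an output map of (d0, d1), tiling reduction dimension 2 gives a partial-result map of (d0, d1, d2). Tiling dimensions 2 and then 1 of an op whose output map is (d0) gives (d0, d2, d1).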

</details>


https://github.com/llvm/llvm-project/pull/143467

