[Mlir-commits] [mlir] [mlir][PartialReductionTilingInterface] Generalize implementation of `tileUsingSCF` for `ReductionTilingStrategy::PartialOuterReduction`. (PR #143467)
llvmlistbot at llvm.org
Mon Jun 9 19:08:35 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir
Author: None (MaheshRavishankar)
<details>
<summary>Changes</summary>
This is a precursor to generalizing `tileUsingSCF` to handle the
`ReductionTilingStrategy::PartialReductionOuterParallel` strategy. The
change itself generalizes and refactors the current implementation, which
supports only `ReductionTilingStrategy::PartialReductionOuterReduction`.
Changes in this PR:
- Move the `ReductionTilingStrategy` enum out of
  `scf::SCFTilingOptions` and make it visible to `TilingInterface`.
- `PartialReductionOpInterface` changes:
  - Pass the `tilingStrategy` used for partial reduction to
    `tileToPartialReduction`.
  - Pass the reduction dimensions along as
    `const llvm::SetVector<unsigned> &`.
- Allow `scf::SCFTilingOptions` to set the reduction dimensions that
  are to be tiled.
- Change `structured.tile_reduction_using_for` to allow specification
  of the reduction dimensions to be partially tiled.
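As a rough illustration of the choice that the `ReductionTilingStrategy` enum exposes, the sketch below applies each of the three splits to a plain sum over a `std::vector<int>`. It is not MLIR code. The function names `sumFullReduction`, `sumOuterReduction`, and `sumOuterParallel` are hypothetical, and the loops only mirror the structures described in the enum's comments.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical helpers for illustration only; none of these exist in MLIR.

// FullReduction: loop[reduction1] { [reduction2] }.
// Both loops reduce into the same accumulator, so no merge step is needed.
int sumFullReduction(const std::vector<int> &data, std::size_t tile) {
  assert(tile > 0 && "tile size must be positive");
  int acc = 0;
  for (std::size_t start = 0; start < data.size(); start += tile) {
    std::size_t end = std::min(start + tile, data.size());
    for (std::size_t i = start; i < end; ++i)
      acc += data[i];
  }
  return acc;
}

// PartialReductionOuterReduction: loop[reduction1] { [parallel2] }, then merge.
// The outer loop walks the tiles sequentially; inside a tile every lane
// updates its own partial sum, so the inner iterations are independent.
int sumOuterReduction(const std::vector<int> &data, std::size_t tile) {
  assert(tile > 0 && "tile size must be positive");
  std::vector<int> partial(tile, 0); // one partial sum per lane, initialized to 0
  for (std::size_t start = 0; start < data.size(); start += tile) {
    std::size_t end = std::min(start + tile, data.size());
    for (std::size_t i = start; i < end; ++i)
      partial[i - start] += data[i];
  }
  // Merge step: combine the per-lane partial sums.
  return std::accumulate(partial.begin(), partial.end(), 0);
}

// PartialReductionOuterParallel: loop[parallel1] { [reduction2] }, then merge.
// Every tile reduces into its own partial sum, so the outer iterations are
// independent of each other.
int sumOuterParallel(const std::vector<int> &data, std::size_t tile) {
  assert(tile > 0 && "tile size must be positive");
  std::size_t numTiles = (data.size() + tile - 1) / tile;
  std::vector<int> partial(numTiles, 0); // one partial sum per tile
  for (std::size_t t = 0; t < numTiles; ++t) {
    std::size_t start = t * tile;
    std::size_t end = std::min(start + tile, data.size());
    for (std::size_t i = start; i < end; ++i)
      partial[t] += data[i];
  }
  // Merge step: combine the per-tile partial sums.
  return std::accumulate(partial.begin(), partial.end(), 0);
}
```

All three functions return the same total for any positive tile size. They differ only in which loop carries the reduction and in whether a merge over partial results follows.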
Depends on https://github.com/llvm/llvm-project/pull/143217. Please review only the top commit for this PR.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
---
Patch is 58.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143467.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+7-1)
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+27-35)
- (modified) mlir/include/mlir/Interfaces/TilingInterface.h (+21)
- (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+8-6)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+30-11)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (+8-7)
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+93-76)
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+112-150)
- (modified) mlir/test/Dialect/Linalg/transform-tile-reduction.mlir (+165-2)
- (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+1-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15ea5e7bf7159..d0591ae122fbb 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1767,6 +1767,10 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
- the result-combining op,
- the parent `for` op.
+ The `reduction_dims` can be used to specify the subset of reduction dimensions
+ of the operation to tile. If left unspecified, all reduction dimensions are
+ tiled.
+
#### Example:
```
@@ -1817,7 +1821,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
// TODO: support mixed static-dynamic (see TileUsingForallOp).
let arguments = (ins TransformHandleTypeInterface:$target,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
TransformHandleTypeInterface:$split_op,
TransformHandleTypeInterface:$combining_op,
@@ -1830,6 +1835,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
let assemblyFormat = [{
$target
+ (`reduction_dims` `=` $reduction_dims^)?
`by` `tile_sizes` `=` $tile_sizes
attr-dict
`:` functional-type(operands, results)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 33a43ce2ee7bb..01ad64b76b15e 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,28 +85,21 @@ struct SCFTilingOptions {
return *this;
}
+ /// Specify mapping of loops to devices. This is only respected when the loop
+ /// constructs support such a mapping (like `scf.forall`). Will be ignored
+ /// when using loop constructs that dont support such a mapping (like
+ /// `scf.for`)
+ SmallVector<Attribute> mappingVector = {};
+ SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+ mappingVector = llvm::to_vector(mapping);
+ return *this;
+ }
+
+ //-------------------------------------------------------------------------//
+ // Options related reduction tiling
+ //-------------------------------------------------------------------------//
+
/// Specify how reduction dimensions should be tiled.
- ///
- /// Tiling can be thought of as splitting a dimension into 2 and materializing
- /// the outer dimension as a loop:
- ///
- /// op[original] -> op[original / x, x] -> loop[original] { op[x] }
- ///
- /// For parallel dimensions, the split can only happen in one way, with both
- /// dimensions being parallel. For reduction dimensions however, there is a
- /// choice in how we split the reduction dimension. This enum exposes this
- /// choice.
- enum class ReductionTilingStrategy {
- // [reduction] -> [reduction1, reduction2]
- // -> loop[reduction1] { [reduction2] }
- FullReduction,
- // [reduction] -> [reduction1, parallel2]
- // -> loop[reduction1] { [parallel2] }; merge[reduction1]
- PartialReductionOuterReduction,
- // [reduction] -> [parallel1, reduction2]
- // -> loop[parallel1] { [reduction2] }; merge[parallel1]
- PartialReductionOuterParallel
- };
ReductionTilingStrategy reductionStrategy =
ReductionTilingStrategy::FullReduction;
SCFTilingOptions &
@@ -115,13 +108,10 @@ struct SCFTilingOptions {
return *this;
}
- /// Specify mapping of loops to devices. This is only respected when the loop
- /// constructs support such a mapping (like `scf.forall`). Will be ignored
- /// when using loop constructs that dont support such a mapping (like
- /// `scf.for`)
- SmallVector<Attribute> mappingVector = {};
- SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
- mappingVector = llvm::to_vector(mapping);
+ /// Specify the reduction dimensions to be tiled.
+ SetVector<unsigned> reductionDims;
+ SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
+ reductionDims.insert(dims.begin(), dims.end());
return *this;
}
};
@@ -136,15 +126,17 @@ struct SCFTilingResult {
SmallVector<Value> initialValues;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<LoopLikeOpInterface> loops;
- /// The result generated by the loop nest in tiling, may hold partial results,
- /// which need to be merged to match the computation of the untiled operation.
- /// `mergeResult` contains the operations used to perform this merge from
- /// partial results and the values that can be used as replacements of
- /// the untiled operation.
- MergeResult mergeResult;
+ /// Values to use as replacements for the untiled op. Is the same size as the
+ /// number of results of the untiled op.
+ SmallVector<Value> replacements;
/// Slices generated after tiling that can be used for fusing with the tiled
/// producer.
SmallVector<Operation *> generatedSlices;
+ /// In cases where there as an additional merge step after tiling
+ /// return the merged ops after tiling. This list is empty when reduction
+ /// tiling strategy is
+ /// `scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction.
+ SmallVector<Operation *> mergeOps;
};
/// Method to tile an op that implements the `TilingInterface` using
@@ -362,7 +354,7 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
/// ```
FailureOr<scf::SCFTilingResult>
tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
- ArrayRef<OpFoldResult> tileSize);
+ ArrayRef<OpFoldResult> tileSizes);
} // namespace scf
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index b33aa1489c311..8693cbea7f0b0 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -36,6 +36,27 @@ struct TilingResult {
SmallVector<Operation *> generatedSlices;
};
+/// Tiling can be thought of as splitting a dimension into 2 and
+/// materializing the outer dimension as a loop:
+///
+/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
+///
+/// For parallel dimensions, the split can only happen in one way, with both
+/// dimensions being parallel. For reduction dimensions however, there is a
+/// choice in how we split the reduction dimension. This enum exposes this
+/// choice.
+enum class ReductionTilingStrategy {
+ // [reduction] -> [reduction1, reduction2]
+ // -> loop[reduction1] { [reduction2] }
+ FullReduction,
+ // [reduction] -> [reduction1, parallel2]
+ // -> loop[reduction1] { [parallel2] }; merge[reduction1]
+ PartialReductionOuterReduction,
+ // [reduction] -> [parallel1, reduction2]
+ // -> loop[parallel1] { [reduction2] }; merge[parallel1]
+ PartialReductionOuterParallel
+};
+
/// Container for the result of merge operation of tiling.
/// - `mergeOps` contains operations created during the merge.
/// - `replacements` contains the values that represents the result of the
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 50b69b8f8d833..9358d8b82abce 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -363,7 +363,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
];
}
-def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
+def PartialReductionOpInterface :
+ OpInterface<"PartialReductionOpInterface", [TilingInterface]> {
let description = [{
Interface for allowing operations to expose information needed to
tile reductions using partial reduction followed by merge. This is
@@ -383,7 +384,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
"::mlir::OpBuilder &":$b,
"Location":$loc,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
- "::mlir::ArrayRef<int>":$reductionDim),
+ "const ::mlir::SetVector<unsigned> &":$reductionDim),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -401,10 +402,11 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
/*args=*/(ins
"::mlir::OpBuilder &":$b,
"Location ":$loc,
+ "::mlir::ReductionTilingStrategy":$tilingStrategy,
"ValueRange":$init,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
- "::mlir::ArrayRef<int>":$reductionDims),
+ "const ::llvm::SetVector<unsigned> &":$reductionDims),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -422,7 +424,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
"::mlir::OpBuilder &":$b,
"Location ":$loc,
"ValueRange":$partialReduce,
- "::mlir::ArrayRef<int>":$reductionDim),
+ "const ::mlir::SetVector<unsigned> &":$reductionDims),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -442,9 +444,9 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
"unsigned":$resultNumber,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+ "const ::mlir::SetVector<unsigned> &":$reductionDims,
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
- "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
- "::mlir::ArrayRef<int>":$reductionDims),
+ "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1c3b621828315..c003825264920 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2381,7 +2381,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
return emitDefaultDefiniteFailure(target);
if (target->getNumResults())
- rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
+ rewriter.replaceOp(target, maybeTilingResult->replacements);
else
rewriter.eraseOp(target);
@@ -2775,10 +2775,11 @@ void transform::TileReductionUsingForOp::build(
// TODO: support mixed static-dynamic (see TileUsingForallOp).
MLIRContext *ctx = builder.getContext();
auto opTy = transform::AnyOpType::get(ctx);
- auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+ auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
build(builder, result,
/*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
/*target=*/target,
+ /*reduction_dims=*/nullptr,
/*tile_sizes=*/staticTileSizesAttr);
}
@@ -2794,18 +2795,36 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
target->getLoc(),
"Operation should implement PartialReductionOpInterface");
}
- FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
- rewriter, partialReductionOp,
- getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
- if (failed(result))
- return emitDefaultSilenceableFailure(target);
- rewriter.replaceOp(target, result->mergeResult.replacements);
+ SmallVector<unsigned> reductionDims =
+ extractFromIntegerArrayAttr<unsigned>(getReductionDims());
+ if (reductionDims.empty()) {
+ for (auto [idx, iteratorType] :
+ llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
+ if (iteratorType == utils::IteratorType::reduction)
+ reductionDims.push_back(idx);
+ }
+ }
+
+ scf::SCFTilingOptions options;
+ options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
+ options.setReductionTilingStrategy(
+ ReductionTilingStrategy::PartialReductionOuterReduction);
+ options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
+ options.setReductionDims(reductionDims);
+ FailureOr<scf::SCFTilingResult> result =
+ scf::tileUsingSCF(rewriter, partialReductionOp, options);
+
+ if (failed(result)) {
+ return emitSilenceableFailure(getLoc(),
+ "failed to tile using partial reduction");
+ }
+ rewriter.replaceOp(target, result->replacements);
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
for (auto parallelTiledOp : result->tiledOps)
results.push_back(parallelTiledOp);
- for (auto mergeOp : result->mergeResult.mergeOps)
+ for (auto mergeOp : result->mergeOps)
results.push_back(mergeOp);
results.push_back(result->loops.front());
return DiagnosedSilenceableFailure::success();
@@ -3229,7 +3248,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
if (failed(maybeTilingResult))
return DiagnosedSilenceableFailure::definiteFailure();
- rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
+ rewriter.replaceOp(op, maybeTilingResult->replacements);
tiled.append(maybeTilingResult->tiledOps);
for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
@@ -3465,7 +3484,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
if (failed(maybeTilingResult))
return transformOp.emitDefaultSilenceableFailure(tileableOp);
- rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
+ rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
tilingResult = *maybeTilingResult;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 4162aa0b71e6d..8a5a2e54cdda2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -109,8 +109,7 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
}
FailureOr<StaticContinuousTileSizeSpecification>
-mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
- unsigned dimension,
+mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
unsigned targetSize) {
assert(!op.hasDynamicShape() &&
@@ -183,8 +182,8 @@ mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
// Find the trip count of the iteration space dimension for which the tile
// sizes are computed.
- Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
- loopRanges[dimension].size);
+ Value loopRange =
+ getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
ContinuousTileSizeSpecification spec;
// Compute the tile sizes and the respective numbers of tiles.
@@ -633,16 +632,18 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
"many elements as number of threads");
- int reductionDim = static_cast<int>(redDims.front());
if (redDims.front() >= numThreads.size())
return b.notifyMatchFailure(
op, "reduction dimension must be mapped to threads");
// 1. Create the inital tensor value.
+ unsigned reductionDim = redDims.front();
+ SetVector<unsigned> reductionDims;
+ reductionDims.insert(reductionDim);
FailureOr<SmallVector<Value>> maybeInitTensors =
op.generateInitialTensorForPartialReduction(b, loc, numThreads,
- reductionDim);
+ reductionDims);
if (failed(maybeInitTensors))
return b.notifyMatchFailure(
op, "Failed to create inital tensors for partial reduction");
@@ -780,7 +781,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
// 7. Merge the partial reductions.
b.setInsertionPointAfter(forallOp);
FailureOr<MergeResult> mergeResult =
- op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
+ op.mergeReductions(b, loc, forallOp->getResults(), reductionDims);
if (failed(mergeResult)) {
return failure();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 7c14cc16437fe..f649bc49a8fbd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include <optional>
@@ -327,23 +328,48 @@ struct LinalgOpTilingInterface
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//
-/// Return an AffineMap for a partial result for the given result number,
-/// assuming the partial tiling strategy is outer-reduction loop +
-/// inner-parallel tile. The returned AffineMap can be used as the replacement
-/// AffineMap for the inner-parallel tile linalg op for the given result number.
-///
-/// The new AffineMap is the old AffineMap with reduction dimensions appended
-/// at end.
-static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
- ArrayRef<int> reductionDims,
- unsigned resultNumber) {
- AffineMap map =
- linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
- for (int redPos : reductionDims) {
- map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
- map.getNumResults());
+/// Return an AffineMaps to use for the `outs` operands of the linalg op
+/// generated for partial results. The new AffineMap is the AffineMap of the
+/// untiled op with reduction dimensions appended at end in order in which they
+/// were specified during tiling.
+static SmallVector<AffineMap>
+getPartialResultAffineMaps(LinalgOp linalgOp,
+ const SetVector<unsigned> &reductionDims) {
+ auto partialReductionMaps = llvm::map_to_vector(
+ linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
+ AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
+ for (auto redPos : reductionDims) {
+ map =
+ map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+ map.getNumResults());
+ }
+ return map;
+ });
+ return partialReductionMaps;
+}
+
+/// Return the slice of the `initValue` to use as input to the partial reduction
+/// op generated.
+static Operation *getInitSliceForOuterReduction(
+ OpBuilder &b, Location loc, Value initValue, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
+ AffineMap partialReductionMap) {
+ int64_t initRank = partialReductionMap.getNumResults();
+ SmallVector<OpFoldResult> initOffsets, initSizes;
+ SmallVector<OpFoldResult> initStrides(initRank, b.getIndexAttr(1));
+ for (AffineExpr dimExpr : partialReductionMap.getResults()) {
+ unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+ if (reductionDims.cont...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/143467