[Mlir-commits] [mlir] 7181785 - [mlir][PartialReductionTilingInterface] Generalize implementation of `tileUsingSCF` for `ReductionTilingStrategy::PartialOuterReduction`. (#143467)
llvmlistbot at llvm.org
Mon Jun 23 11:23:50 PDT 2025
Author: MaheshRavishankar
Date: 2025-06-23T11:23:46-07:00
New Revision: 71817856f7f4c407d76a12fbbdde9ac3e89dd0a1
URL: https://github.com/llvm/llvm-project/commit/71817856f7f4c407d76a12fbbdde9ac3e89dd0a1
DIFF: https://github.com/llvm/llvm-project/commit/71817856f7f4c407d76a12fbbdde9ac3e89dd0a1.diff
LOG: [mlir][PartialReductionTilingInterface] Generalize implementation of `tileUsingSCF` for `ReductionTilingStrategy::PartialOuterReduction`. (#143467)
This is a precursor to generalizing `tileUsingSCF` to handle the
`ReductionTilingStrategy::PartialReductionOuterParallel` strategy. This
change generalizes/refactors the current implementation, which supports
only `ReductionTilingStrategy::PartialReductionOuterReduction`.
Changes in this PR:
- Move the `ReductionTilingStrategy` enum out of
  `scf::SCFTilingOptions` and make it visible to `TilingInterface`.
- `PartialReductionOpInterface` changes:
  - Pass the `tilingStrategy` used for partial reduction to
    `tileToPartialReduction`.
  - Pass the reduction dimensions along as `const
    llvm::SetVector<unsigned> &`.
- Allow `scf::SCFTilingOptions` to set the reduction dimensions that
  are to be tiled.
- Change `structured.tile_reduction_using_for` to allow specification
  of the reduction dimensions to be partially tiled (see the usage
  sketch below).
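
For reference, a minimal C++ sketch of how a caller can drive the
generalized `tileUsingSCF` entry point, adapted from the transform op
implementation in this change; `rewriter`, `op`, and `tileSizes` are
placeholders assumed to be provided by the surrounding code:

#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

// Tile only reduction dimension 1 of `op` (a TilingInterface op) with the
// outer-reduction partial tiling strategy. `rewriter` is a RewriterBase and
// `tileSizes` an ArrayRef<OpFoldResult> covering the iteration space.
scf::SCFTilingOptions options;
options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
options.setReductionTilingStrategy(
    ReductionTilingStrategy::PartialReductionOuterReduction);
options.setTileSizes(tileSizes);
// New in this change: explicitly pick the reduction dimensions to tile.
// If left unspecified, none of the reduction dimensions are tiled.
options.setReductionDims({1});
FailureOr<scf::SCFTilingResult> result =
    scf::tileUsingSCF(rewriter, op, options);
if (failed(result))
  return failure();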
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
mlir/include/mlir/Interfaces/TilingInterface.h
mlir/include/mlir/Interfaces/TilingInterface.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index c5650470fdc8d..38c8734c47381 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1859,6 +1859,10 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
- the result-combining op,
- the parent `for` op.
+ The `reduction_dims` can be used to specify the subset of reduction dimensions
+ of the operation to tile. If left unspecified, all reduction dimensions are
+ tiled.
+
#### Example:
```
@@ -1909,7 +1913,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
// TODO: support mixed static-dynamic (see TileUsingForallOp).
let arguments = (ins TransformHandleTypeInterface:$target,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
TransformHandleTypeInterface:$split_op,
TransformHandleTypeInterface:$combining_op,
@@ -1922,6 +1927,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
let assemblyFormat = [{
$target
+ (`reduction_dims` `=` $reduction_dims^)?
`by` `tile_sizes` `=` $tile_sizes
attr-dict
`:` functional-type(operands, results)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index f686ae07b9a99..9feb04dbe03c1 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,28 +85,21 @@ struct SCFTilingOptions {
return *this;
}
+ /// Specify mapping of loops to devices. This is only respected when the loop
+ /// constructs support such a mapping (like `scf.forall`). Will be ignored
+ /// when using loop constructs that dont support such a mapping (like
+ /// `scf.for`)
+ SmallVector<Attribute> mappingVector = {};
+ SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+ mappingVector = llvm::to_vector(mapping);
+ return *this;
+ }
+
+ //-------------------------------------------------------------------------//
+ // Options related reduction tiling
+ //-------------------------------------------------------------------------//
+
/// Specify how reduction dimensions should be tiled.
- ///
- /// Tiling can be thought of as splitting a dimension into 2 and materializing
- /// the outer dimension as a loop:
- ///
- /// op[original] -> op[original / x, x] -> loop[original] { op[x] }
- ///
- /// For parallel dimensions, the split can only happen in one way, with both
- /// dimensions being parallel. For reduction dimensions however, there is a
- /// choice in how we split the reduction dimension. This enum exposes this
- /// choice.
- enum class ReductionTilingStrategy {
- // [reduction] -> [reduction1, reduction2]
- // -> loop[reduction1] { [reduction2] }
- FullReduction,
- // [reduction] -> [reduction1, parallel2]
- // -> loop[reduction1] { [parallel2] }; merge[reduction1]
- PartialReductionOuterReduction,
- // [reduction] -> [parallel1, reduction2]
- // -> loop[parallel1] { [reduction2] }; merge[parallel1]
- PartialReductionOuterParallel
- };
ReductionTilingStrategy reductionStrategy =
ReductionTilingStrategy::FullReduction;
SCFTilingOptions &
@@ -115,13 +108,13 @@ struct SCFTilingOptions {
return *this;
}
- /// Specify mapping of loops to devices. This is only respected when the loop
- /// constructs support such a mapping (like `scf.forall`). Will be ignored
- /// when using loop constructs that dont support such a mapping (like
- /// `scf.for`)
- SmallVector<Attribute> mappingVector = {};
- SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
- mappingVector = llvm::to_vector(mapping);
+ /// Specify the reduction dimensions to be tiled. Note that this needs to be
+ /// specified. If left unspecified, then none of the reduction dimensions are
+ /// tiled.
+ SetVector<unsigned> reductionDims;
+ SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
+ reductionDims.clear();
+ reductionDims.insert(dims.begin(), dims.end());
return *this;
}
};
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index b33aa1489c311..8693cbea7f0b0 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -36,6 +36,27 @@ struct TilingResult {
SmallVector<Operation *> generatedSlices;
};
+/// Tiling can be thought of as splitting a dimension into 2 and
+/// materializing the outer dimension as a loop:
+///
+/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
+///
+/// For parallel dimensions, the split can only happen in one way, with both
+/// dimensions being parallel. For reduction dimensions however, there is a
+/// choice in how we split the reduction dimension. This enum exposes this
+/// choice.
+enum class ReductionTilingStrategy {
+ // [reduction] -> [reduction1, reduction2]
+ // -> loop[reduction1] { [reduction2] }
+ FullReduction,
+ // [reduction] -> [reduction1, parallel2]
+ // -> loop[reduction1] { [parallel2] }; merge[reduction1]
+ PartialReductionOuterReduction,
+ // [reduction] -> [parallel1, reduction2]
+ // -> loop[parallel1] { [reduction2] }; merge[parallel1]
+ PartialReductionOuterParallel
+};
+
/// Container for the result of merge operation of tiling.
/// - `mergeOps` contains operations created during the merge.
/// - `replacements` contains the values that represents the result of the
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index cdf3d01ce8a84..43a27e1cb6cdf 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -384,7 +384,7 @@ def PartialReductionOpInterface :
"::mlir::OpBuilder &":$b,
"Location":$loc,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
- "::mlir::ArrayRef<int>":$reductionDim),
+ "const ::mlir::SetVector<unsigned> &":$reductionDims),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -402,10 +402,11 @@ def PartialReductionOpInterface :
/*args=*/(ins
"::mlir::OpBuilder &":$b,
"Location ":$loc,
+ "::mlir::ReductionTilingStrategy":$tilingStrategy,
"ValueRange":$init,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
- "::mlir::ArrayRef<int>":$reductionDims),
+ "const ::llvm::SetVector<unsigned> &":$reductionDims),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -423,7 +424,7 @@ def PartialReductionOpInterface :
"::mlir::OpBuilder &":$b,
"Location ":$loc,
"ValueRange":$partialReduce,
- "::mlir::ArrayRef<int>":$reductionDim),
+ "const ::mlir::SetVector<unsigned> &":$reductionDims),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -443,9 +444,9 @@ def PartialReductionOpInterface :
"unsigned":$resultNumber,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+ "const ::mlir::SetVector<unsigned> &":$reductionDims,
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
- "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
- "::mlir::ArrayRef<int>":$reductionDims),
+ "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d9a0ba02f4fe4..f2b7b34256847 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2947,10 +2947,11 @@ void transform::TileReductionUsingForOp::build(
// TODO: support mixed static-dynamic (see TileUsingForallOp).
MLIRContext *ctx = builder.getContext();
auto opTy = transform::AnyOpType::get(ctx);
- auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+ auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
build(builder, result,
/*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
/*target=*/target,
+ /*reduction_dims=*/nullptr,
/*tile_sizes=*/staticTileSizesAttr);
}
@@ -2966,12 +2967,30 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
target->getLoc(),
"Operation should implement PartialReductionOpInterface");
}
- FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
- rewriter, partialReductionOp,
- getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
- if (failed(result))
- return emitDefaultSilenceableFailure(target);
+ SmallVector<unsigned> reductionDims =
+ extractFromIntegerArrayAttr<unsigned>(getReductionDims());
+ if (reductionDims.empty()) {
+ for (auto [idx, iteratorType] :
+ llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
+ if (iteratorType == utils::IteratorType::reduction)
+ reductionDims.push_back(idx);
+ }
+ }
+
+ scf::SCFTilingOptions options;
+ options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
+ options.setReductionTilingStrategy(
+ ReductionTilingStrategy::PartialReductionOuterReduction);
+ options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
+ options.setReductionDims(reductionDims);
+ FailureOr<scf::SCFTilingResult> result =
+ scf::tileUsingSCF(rewriter, partialReductionOp, options);
+
+ if (failed(result)) {
+ return emitSilenceableFailure(getLoc(),
+ "failed to tile using partial reduction");
+ }
rewriter.replaceOp(target, result->replacements);
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 4162aa0b71e6d..8a5a2e54cdda2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -109,8 +109,7 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
}
FailureOr<StaticContinuousTileSizeSpecification>
-mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
- unsigned dimension,
+mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
unsigned targetSize) {
assert(!op.hasDynamicShape() &&
@@ -183,8 +182,8 @@ mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
// Find the trip count of the iteration space dimension for which the tile
// sizes are computed.
- Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
- loopRanges[dimension].size);
+ Value loopRange =
+ getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
ContinuousTileSizeSpecification spec;
// Compute the tile sizes and the respective numbers of tiles.
@@ -633,16 +632,18 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
"many elements as number of threads");
- int reductionDim = static_cast<int>(redDims.front());
if (redDims.front() >= numThreads.size())
return b.notifyMatchFailure(
op, "reduction dimension must be mapped to threads");
// 1. Create the inital tensor value.
+ unsigned reductionDim = redDims.front();
+ SetVector<unsigned> reductionDims;
+ reductionDims.insert(reductionDim);
FailureOr<SmallVector<Value>> maybeInitTensors =
op.generateInitialTensorForPartialReduction(b, loc, numThreads,
- reductionDim);
+ reductionDims);
if (failed(maybeInitTensors))
return b.notifyMatchFailure(
op, "Failed to create inital tensors for partial reduction");
@@ -780,7 +781,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
// 7. Merge the partial reductions.
b.setInsertionPointAfter(forallOp);
FailureOr<MergeResult> mergeResult =
- op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
+ op.mergeReductions(b, loc, forallOp->getResults(), reductionDims);
if (failed(mergeResult)) {
return failure();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 7c14cc16437fe..f649bc49a8fbd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include <optional>
@@ -327,23 +328,48 @@ struct LinalgOpTilingInterface
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//
-/// Return an AffineMap for a partial result for the given result number,
-/// assuming the partial tiling strategy is outer-reduction loop +
-/// inner-parallel tile. The returned AffineMap can be used as the replacement
-/// AffineMap for the inner-parallel tile linalg op for the given result number.
-///
-/// The new AffineMap is the old AffineMap with reduction dimensions appended
-/// at end.
-static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
- ArrayRef<int> reductionDims,
- unsigned resultNumber) {
- AffineMap map =
- linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
- for (int redPos : reductionDims) {
- map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
- map.getNumResults());
+/// Return an AffineMaps to use for the `outs` operands of the linalg op
+/// generated for partial results. The new AffineMap is the AffineMap of the
+/// untiled op with reduction dimensions appended at end in order in which they
+/// were specified during tiling.
+static SmallVector<AffineMap>
+getPartialResultAffineMaps(LinalgOp linalgOp,
+ const SetVector<unsigned> &reductionDims) {
+ auto partialReductionMaps = llvm::map_to_vector(
+ linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
+ AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
+ for (auto redPos : reductionDims) {
+ map =
+ map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+ map.getNumResults());
+ }
+ return map;
+ });
+ return partialReductionMaps;
+}
+
+/// Return the slice of the `initValue` to use as input to the partial reduction
+/// op generated.
+static Operation *getInitSliceForOuterReduction(
+ OpBuilder &b, Location loc, Value initValue, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
+ AffineMap partialReductionMap) {
+ int64_t initRank = partialReductionMap.getNumResults();
+ SmallVector<OpFoldResult> initOffsets, initSizes;
+ SmallVector<OpFoldResult> initStrides(initRank, b.getIndexAttr(1));
+ for (AffineExpr dimExpr : partialReductionMap.getResults()) {
+ unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+ if (reductionDims.contains(dim)) {
+ initOffsets.push_back(b.getIndexAttr(0));
+ } else {
+ initOffsets.push_back(offsets[dim]);
+ }
+ initSizes.push_back(sizes[dim]);
}
- return map;
+ // TODO: Use SubsetExtractOpInterface here once available.
+ auto extractSlice = b.create<tensor::ExtractSliceOp>(
+ loc, initValue, initOffsets, initSizes, initStrides);
+ return extractSlice;
}
/// External model implementation of PartialReductionInterface for
@@ -354,13 +380,16 @@ struct LinalgOpPartialReductionInterface
LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
- ArrayRef<int> reductionDims) const {
+ const SetVector<unsigned> &reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);
- OpBuilder::InsertionGuard guard(b);
+ OpBuilder::InsertionGuard guard(b);
if (linalgOp.hasPureBufferSemantics())
return op->emitOpError("expected operation to have tensor semantics");
+ SmallVector<AffineMap> partialResultMaps =
+ getPartialResultAffineMaps(linalgOp, reductionDims);
+
// LinalgOp implements TilingInterface.
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
SmallVector<OpFoldResult> shape =
@@ -377,8 +406,8 @@ struct LinalgOpPartialReductionInterface
}
SmallVector<Value> inits;
- for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
- ++initIdx) {
+ for (auto [initIdx, result, partialMap] :
+ llvm::enumerate(linalgOp->getResults(), partialResultMaps)) {
SmallVector<Operation *, 4> combinerOps;
if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
combinerOps) ||
@@ -392,16 +421,13 @@ struct LinalgOpPartialReductionInterface
"Failed to get an identity value for the reduction operation.");
// Append the new partial result dimensions.
- AffineMap partialMap =
- getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
SmallVector<OpFoldResult> partialResultShape;
for (AffineExpr dimExpr : partialMap.getResults()) {
auto dim = cast<AffineDimExpr>(dimExpr);
partialResultShape.push_back(tiledShape[dim.getPosition()]);
}
- Type elType =
- getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
+ Type elType = getElementTypeOrSelf(result.getType());
Value emptyTensor =
b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
@@ -415,23 +441,25 @@ struct LinalgOpPartialReductionInterface
FailureOr<TilingResult>
tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
+ ReductionTilingStrategy tilingStrategy,
ValueRange init, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- ArrayRef<int> reductionDims) const {
+ const SetVector<unsigned> &reductionDims) const {
+ if (tilingStrategy !=
+ ReductionTilingStrategy::PartialReductionOuterReduction) {
+ // TODO: Add support for `PartialReductionOuterParallel` strategy.
+ return op->emitOpError("unsupported partial reduction tiling with "
+ "`PartialReductionOuterParallel` strategy");
+ }
OpBuilder::InsertionGuard guard(b);
auto linalgOp = cast<LinalgOp>(op);
+ SmallVector<AffineMap> partialReductionMaps =
+ getPartialResultAffineMaps(linalgOp, reductionDims);
+
// Step 1. Extend init maps to have reduction dimension dims, since we
// are converting them to parallel dimensions.
- SmallVector<AffineMap> newInitMaps;
- newInitMaps.reserve(linalgOp.getNumDpsInits());
- for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
- // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
- // this with a for range loop when we have it.
- AffineMap newMap =
- getPartialResultAffineMap(linalgOp, reductionDims, idx);
- newInitMaps.push_back(newMap);
- }
+ SmallVector<AffineMap> newInitMaps = partialReductionMaps;
// Step 2a: Extract a slice of the input operands.
SmallVector<Value> tiledInputs = makeTiledShapes(
@@ -443,31 +471,21 @@ struct LinalgOpPartialReductionInterface
// Step 2b: Extract a slice of the init operands.
SmallVector<Value, 1> tiledInits;
- for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
- int64_t initRank = valueMap.getNumResults();
- SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0));
- SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1));
- SmallVector<OpFoldResult> initSizes;
- for (AffineExpr dimExpr : valueMap.getResults()) {
- auto dim = cast<AffineDimExpr>(dimExpr);
- initSizes.push_back(sizes[dim.getPosition()]);
- }
- // TODO: Use SubsetExtractOpInterface here once available.
- auto extractSlice = b.create<tensor::ExtractSliceOp>(
- loc, valueToTile, initOffset, initSizes, initStride);
- tiledInits.push_back(extractSlice);
- generatedSlices.push_back(extractSlice);
+ for (auto [partialReductionMap, valueToTile] :
+ llvm::zip_equal(partialReductionMaps, init)) {
+ Operation *sliceOp =
+ getInitSliceForOuterReduction(b, loc, valueToTile, offsets, sizes,
+ reductionDims, partialReductionMap);
+ tiledInits.push_back(sliceOp->getResult(0));
+ generatedSlices.push_back(sliceOp);
}
// Update the indexing maps.
SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
- // Change the init maps.
- for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
- // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
- // this with a for range loop when we have it.
- OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
- int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
- newMaps[mapIdx] = newInitMaps[idx];
+ for (auto [initOperand, newInitMap] :
+ llvm::zip_equal(linalgOp.getDpsInitsMutable(), newInitMaps)) {
+ int mapIdx = linalgOp.getIndexingMapIndex(&initOperand);
+ newMaps[mapIdx] = newInitMap;
}
// Step 3. Change the reduction dim iterator types.
@@ -477,9 +495,9 @@ struct LinalgOpPartialReductionInterface
newIteratorTypes[dim] = utils::IteratorType::parallel;
// Step 4. Create the new generic op.
- auto genericOp =
- b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
- tiledInits, newMaps, newIteratorTypes);
+ auto resultTypes = ValueRange(tiledInits).getTypes();
+ auto genericOp = b.create<GenericOp>(loc, resultTypes, tiledInputs,
+ tiledInits, newMaps, newIteratorTypes);
IRMapping mapping;
op->getRegion(0).cloneInto(&genericOp.getRegion(),
genericOp.getRegion().begin(), mapping);
@@ -490,23 +508,24 @@ struct LinalgOpPartialReductionInterface
generatedSlices};
}
- FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
- Location loc, ValueRange partialReduce,
- ArrayRef<int> reductionDims) const {
+ FailureOr<MergeResult>
+ mergeReductions(Operation *op, OpBuilder &b, Location loc,
+ ValueRange partialReduce,
+ const SetVector<unsigned> &reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);
+ SmallVector<AffineMap> partialReductionMaps =
+ getPartialResultAffineMaps(linalgOp, reductionDims);
// Permute the reduction dims as permuted by the partial result map.
-
- int64_t numInits = linalgOp.getNumDpsInits();
SmallVector<Operation *> mergeOperations;
SmallVector<Value> replacements;
- for (int idx : llvm::seq(numInits)) {
+ for (auto [idx, init, partialResult, partialMap] : llvm::enumerate(
+ linalgOp.getDpsInits(), partialReduce, partialReductionMaps)) {
+ unsigned initIdx = idx;
// linalg.reduce's iteration space is the tiled result's iteration space
// (and not the tiled operation's iteration space). To account for this,
// permute the reduction dimensions based on the partial result map of the
// tiled result.
- AffineMap partialMap =
- getPartialResultAffineMap(linalgOp, reductionDims, idx);
SmallVector<int64_t> partialReductionDims;
for (auto [resultNum, dimExpr] :
llvm::enumerate(partialMap.getResults())) {
@@ -516,15 +535,13 @@ struct LinalgOpPartialReductionInterface
}
}
- Value partialResult = partialReduce[idx];
- Value init = linalgOp.getDpsInits()[idx];
-
auto reduction = b.create<linalg::ReduceOp>(
loc, partialResult, init, partialReductionDims,
- [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
+ [&linalgOp, &initIdx](OpBuilder &b, Location loc, ValueRange inputs) {
// Get the combiner op.
SmallVector<Operation *, 4> combinerOps;
- matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
+ matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
+ combinerOps);
Operation *clonedReductionOp = b.clone(*combinerOps[0]);
// Combine the input at idx and output at numInits + idx.
clonedReductionOp->setOperand(0, inputs[0]);
@@ -542,14 +559,14 @@ struct LinalgOpPartialReductionInterface
LogicalResult getPartialResultTilePosition(
Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ const SetVector<unsigned> &reductionDims,
SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes,
- ArrayRef<int> reductionDims) const {
+ SmallVector<OpFoldResult> &resultSizes) const {
auto linalgOp = cast<LinalgOp>(op);
+ SmallVector<AffineMap> partialReductionMaps =
+ getPartialResultAffineMaps(linalgOp, reductionDims);
- AffineMap partialMap =
- getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
- for (AffineExpr dimExpr : partialMap.getResults()) {
+ for (AffineExpr dimExpr : partialReductionMaps[resultNumber].getResults()) {
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
resultSizes.push_back(sizes[dim]);
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 3f29dd3ac5e48..e7c076024e67b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -77,9 +77,8 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
//===----------------------------------------------------------------------===//
/// Verify the tile size options are set in a consistent manner.
-static LogicalResult
-verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
- const scf::SCFTilingOptions &options) {
+static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc,
+ const scf::SCFTilingOptions &options) {
// Specifying number of threads is only supported on `scf.forall` op.
if (options.numThreadsComputationFunction &&
options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
@@ -156,7 +155,9 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
}
/// Checks if any of the tiled loops are not parallel.
-static void checkSafeToTileToForall(TilingInterface op,
+static LogicalResult checkTileSizes(TilingInterface op,
+ scf::SCFTilingOptions::LoopType loopType,
+ ReductionTilingStrategy reductionStrategy,
ArrayRef<OpFoldResult> tileSizes,
ArrayRef<OpFoldResult> numThreads) {
auto iterators = op.getLoopIteratorTypes();
@@ -165,28 +166,46 @@ static void checkSafeToTileToForall(TilingInterface op,
assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
"when specified, expected number of threads to use for each loop");
+ bool isParallelTiling = false, isReductionTiling = false;
for (auto [index, iterator, tileSize] :
llvm::enumerate(iterators, tileSizes)) {
- // If num threads is specified, check that it is greater than one only for
- // parallel dimensions.
- if (!numThreads.empty()) {
- if (std::optional<int64_t> constNumThreads =
- getConstantIntValue(numThreads[index])) {
- if (constNumThreads.value() > 1 &&
+ if (!isConstantIntValue(tileSize, 0)) {
+ isParallelTiling |= iterator == utils::IteratorType::parallel;
+ isReductionTiling |= iterator == utils::IteratorType::reduction;
+ }
+
+ if (loopType == scf::SCFTilingOptions::LoopType::ForallOp &&
+ reductionStrategy == ReductionTilingStrategy::FullReduction) {
+ // If num threads is specified, check that it is greater than one only for
+ // parallel dimensions.
+ if (!numThreads.empty()) {
+ if (std::optional<int64_t> constNumThreads =
+ getConstantIntValue(numThreads[index])) {
+ if (constNumThreads.value() > 1 &&
+ iterator != utils::IteratorType::parallel) {
+ op.emitWarning() << "tiling is not thread safe at axis #" << index;
+ }
+ }
+ continue;
+ }
+
+ if (std::optional<int64_t> constTileSize =
+ getConstantIntValue(tileSize)) {
+ if (constTileSize.value() > 0 &&
iterator != utils::IteratorType::parallel) {
op.emitWarning() << "tiling is not thread safe at axis #" << index;
}
}
- continue;
}
+ }
- if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
- if (constTileSize.value() > 0 &&
- iterator != utils::IteratorType::parallel) {
- op.emitWarning() << "tiling is not thread safe at axis #" << index;
- }
- }
+ if (isParallelTiling && isReductionTiling &&
+ reductionStrategy != ReductionTilingStrategy::FullReduction) {
+ return op->emitOpError(
+ "combined parallel and reduction tiling is not supported with partial "
+ "reduction tiling strategies");
}
+ return success();
}
/// Check if `stride` evenly divides the trip count `size - offset`.
@@ -575,35 +594,20 @@ createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
const scf::SCFTilingOptions &options) {
SmallVector<Value> initTensors;
Location loc = op->getLoc();
- switch (options.reductionStrategy) {
- case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+ if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
return failure();
return initTensors;
- case scf::SCFTilingOptions::ReductionTilingStrategy::
- PartialReductionOuterReduction: {
- auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
- if (!redOp) {
- return rewriter.notifyMatchFailure(
- op, "PartialReductionOuterReduction tiling strategy is only supported"
- "for operations implementing PartialReductionOpInterface");
- }
- // Get reduction dimensions.
- // TODO: PartialReductionOpInterface should really query TilingInterface
- // itself and find reduction dimensions.
- SmallVector<int> reductionDims;
- for (auto [idx, iteratorType] :
- llvm::enumerate(op.getLoopIteratorTypes())) {
- if (iteratorType == utils::IteratorType::reduction)
- reductionDims.push_back(idx);
- }
- return redOp.generateInitialTensorForPartialReduction(
- rewriter, loc, tileSizes, reductionDims);
}
- default:
- return rewriter.notifyMatchFailure(op,
- "unhandled reduction tiling strategy");
+
+ auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+ if (!redOp) {
+ return rewriter.notifyMatchFailure(
+ op, "PartialReductionOuterReduction tiling strategy is only supported"
+ "for operations implementing PartialReductionOpInterface");
}
+ return redOp.generateInitialTensorForPartialReduction(
+ rewriter, loc, tileSizes, options.reductionDims);
}
static FailureOr<TilingResult>
@@ -611,34 +615,20 @@ getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
const scf::SCFTilingOptions &options) {
- switch (options.reductionStrategy) {
- case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+ if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
return op.getTiledImplementation(rewriter, offsets, sizes);
- case scf::SCFTilingOptions::ReductionTilingStrategy::
- PartialReductionOuterReduction: {
- auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
- if (!redOp) {
- return rewriter.notifyMatchFailure(
- op, "PartialReductionOuterReduction tiling strategy is only "
- "supported for operations "
- "implementing PartialReductionOpInterface");
- }
- // Get reduction dimensions.
- // TODO: PartialReductionOpInterface should really query TilingInterface
- // itself and find reduction dimensions.
- SmallVector<int> reductionDims;
- for (auto [idx, iteratorType] :
- llvm::enumerate(op.getLoopIteratorTypes())) {
- if (iteratorType == utils::IteratorType::reduction)
- reductionDims.push_back(idx);
- }
- return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
- offsets, sizes, reductionDims);
}
- default:
- return rewriter.notifyMatchFailure(op,
- "unhandled reduction tiling strategy");
+
+ auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+ if (!redOp) {
+ return rewriter.notifyMatchFailure(
+ op, "PartialReductionOuterReduction tiling strategy is only "
+ "supported for operations "
+ "implementing PartialReductionOpInterface");
}
+ return redOp.tileToPartialReduction(rewriter, op.getLoc(),
+ options.reductionStrategy, regionIterArg,
+ offsets, sizes, options.reductionDims);
}
static LogicalResult
@@ -649,70 +639,37 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
SmallVector<OpFoldResult> &resultSize,
const scf::SCFTilingOptions &options) {
- switch (options.reductionStrategy) {
- case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+ if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
return op.getResultTilePosition(rewriter, index, offsets, sizes,
resultOffset, resultSize);
- case scf::SCFTilingOptions::ReductionTilingStrategy::
- PartialReductionOuterReduction: {
- auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
- if (!redOp) {
- return rewriter.notifyMatchFailure(
- op, "PartialReductionOuterReduction tiling strategy is only supported"
- "for operations implementing PartialReductionOpInterface");
- }
- // Get reduction dimensions.
- // TODO: PartialReductionOpInterface should really query TilingInterface
- // itself and find reduction dimensions.
- SmallVector<int> reductionDims;
- for (auto [idx, iteratorType] :
- llvm::enumerate(op.getLoopIteratorTypes())) {
- if (iteratorType == utils::IteratorType::reduction)
- reductionDims.push_back(idx);
- }
- return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
- resultOffset, resultSize,
- reductionDims);
}
- default:
- return rewriter.notifyMatchFailure(op,
- "unhandled reduction tiling strategy");
+ auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+ if (!redOp) {
+ return rewriter.notifyMatchFailure(
+ op, "PartialReductionOuterReduction tiling strategy is only supported"
+ "for operations implementing PartialReductionOpInterface");
}
+ return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
+ options.reductionDims, resultOffset,
+ resultSize);
}
static FailureOr<MergeResult>
mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
ValueRange partialResults,
const scf::SCFTilingOptions &options) {
- switch (options.reductionStrategy) {
- case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
- // No need to merge results for reduction tiling strategy.
- return MergeResult{{}, partialResults};
- case scf::SCFTilingOptions::ReductionTilingStrategy::
- PartialReductionOuterReduction: {
- auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
- if (!redOp) {
- return rewriter.notifyMatchFailure(
- op, "PartialReductionOuterReduction tiling strategy is only "
- "supported for operations "
- "implementing PartialReductionOpInterface");
- }
- // Get reduction dimensions.
- // TODO: PartialReductionOpInterface should really query TilingInterface
- // itself and find reduction dimensions.
- SmallVector<int> reductionDims;
- for (auto [idx, iteratorType] :
- llvm::enumerate(op.getLoopIteratorTypes())) {
- if (iteratorType == utils::IteratorType::reduction)
- reductionDims.push_back(idx);
- }
- return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
- reductionDims);
- }
- default:
- return rewriter.notifyMatchFailure(op,
- "unhandled reduction tiling strategy");
+ assert(options.reductionStrategy != ReductionTilingStrategy::FullReduction &&
+ "expected merge to be called for only partial reduction cases");
+
+ auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+ if (!redOp) {
+ return rewriter.notifyMatchFailure(
+ op, "PartialReductionOuterReduction tiling strategy is only "
+ "supported for operations "
+ "implementing PartialReductionOpInterface");
}
+ return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
+ options.reductionDims);
}
/// Append the specified additional `newInitOperands` operands to the
@@ -932,7 +889,7 @@ static LogicalResult addInitOperandsToLoopNest(
FailureOr<scf::SCFTilingResult>
mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
const scf::SCFTilingOptions &options) {
- if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
+ if (failed(verifyOptions(rewriter, op.getLoc(), options))) {
return failure();
}
@@ -949,8 +906,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// Check if it is safe to tile. This is hold over from previous iterations
// of tile to for-all. Consider dropping it.
- if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
- checkSafeToTileToForall(op, tileSizes, numThreads);
+ if (failed(checkTileSizes(op, options.loopType, options.reductionStrategy,
+ tileSizes, numThreads))) {
+ return failure();
}
// 3. If there is an interchange specified, permute the iteration domain and
@@ -1073,8 +1031,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
[](OpResult r) -> Value { return r; });
// For the full reduction case, there is nothing more to do.
- if (options.reductionStrategy ==
- scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction) {
+ if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
return scf::SCFTilingResult{
tilingResult->tiledOps, initTensors, loops, loopResults,
tilingResult->generatedSlices, {}};
@@ -1102,9 +1059,13 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
scf::SCFTilingOptions options;
options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
options.setReductionTilingStrategy(
- scf::SCFTilingOptions::ReductionTilingStrategy::
- PartialReductionOuterReduction);
+ ReductionTilingStrategy::PartialReductionOuterReduction);
options.setTileSizes(tileSize);
+ SmallVector<unsigned> reductionDims;
+ for (auto [index, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes()))
+ if (iteratorType == utils::IteratorType::reduction)
+ reductionDims.push_back(index);
+ options.setReductionDims(reductionDims);
return tileUsingSCF(b, op, options);
}
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 9d34c80822d0e..009ab17786696 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -343,7 +343,6 @@ module attributes {transform.with_named_sequence} {
module {
func.func @fail_for_float_neutral(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
// expected-error @below {{'linalg.generic' op Failed to get an identity value for the reduction operation.}}
- // expected-note @below {{when applied to this op}}
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
^bb0(%in: f32, %out: f32):
%1 = llvm.fmul %in, %in : f32
@@ -355,7 +354,7 @@ module {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{transform.structured.tile_reduction_using_for failed to apply}}
+ // expected-error @below {{failed to tile using partial reduction}}
%fill_op, %split_linalg_op, %combining_linalg_op, %for_op = transform.structured.tile_reduction_using_for %0 by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -480,3 +479,167 @@ module attributes {transform.with_named_sequence} {
// CHECK: }
// CHECK: linalg.reduce
// CHECK: return
+
+// -----
+
+// Check that only one of the reduction dimension can be tiled (in this case outer).
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+ func.func @reduction_tile_single_of_multiple_reduction_outer(
+ %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "reduction"]}
+ ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %1, %out : f32
+ linalg.yield %2 : f32
+ } -> tensor<4096xf32>
+ return %0 : tensor<4096xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %fill_op, %split_linalg_op, %combining_linalg_op, %for_op =
+ transform.structured.tile_reduction_using_for %0 reduction_dims = [1] by tile_sizes = [0, 2]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+}
+// CHECK: #[[INIT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK: @reduction_tile_single_of_multiple_reduction_outer(
+// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<4096xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C86:.+]] = arith.constant 86 : index
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<4096x2xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[EMPTY]] :
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[C86]] step %[[C2]]
+// CHECK-SAME: iter_args(%[[ITER_ARG:.+]] = %[[FILL]])
+// CHECK: %[[PARTIAL_RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#{{.+}}, #{{.+}}, #[[INIT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: outs(%[[ITER_ARG]] :
+// CHECK: scf.yield %[[PARTIAL_RESULT]]
+// CHECK: %[[REDUCE:.+]] = linalg.reduce
+// CHECK-SAME: ins(%[[RESULT]] :
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK-SAME: dimensions = [1]
+// CHECK: return %[[REDUCE]]
+
+// -----
+
+// Check that only one of the reduction dimension can be tiled (in this case inner).
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+ func.func @reduction_tile_single_of_multiple_reduction_inner(
+ %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "reduction"]}
+ ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %1, %out : f32
+ linalg.yield %2 : f32
+ } -> tensor<4096xf32>
+ return %0 : tensor<4096xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %fill_op, %split_linalg_op, %combining_linalg_op, %for_op =
+ transform.structured.tile_reduction_using_for %0 reduction_dims = [2] by tile_sizes = [0, 0, 64]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+}
+// CHECK: #[[INIT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: @reduction_tile_single_of_multiple_reduction_inner(
+// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<4096xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
+// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<4096x64xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[EMPTY]] :
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C64]]
+// CHECK-SAME: iter_args(%[[ITER_ARG:.+]] = %[[FILL]])
+// CHECK: %[[PARTIAL_RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#{{.+}}, #{{.+}}, #[[INIT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"]
+// CHECK-SAME: outs(%[[ITER_ARG]] :
+// CHECK: scf.yield %[[PARTIAL_RESULT]]
+// CHECK: %[[REDUCE:.+]] = linalg.reduce
+// CHECK-SAME: ins(%[[RESULT]] :
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK-SAME: dimensions = [1]
+// CHECK: return %[[REDUCE]]
+
+// -----
+
+// Check that both the reduction dimensions are tiled but the dimensions in the output are swapped.
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+ func.func @reduction_tile_single_of_multiple_reduction_reversed(
+ %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "reduction"]}
+ ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %1, %out : f32
+ linalg.yield %2 : f32
+ } -> tensor<4096xf32>
+ return %0 : tensor<4096xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %fill_op, %split_linalg_op, %combining_linalg_op, %for_op =
+ transform.structured.tile_reduction_using_for %0 reduction_dims = [2, 1] by tile_sizes = [0, 2, 64]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+}
+// CHECK: #[[INIT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK: @reduction_tile_single_of_multiple_reduction_reversed(
+// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<4096xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
+// CHECK-DAG: %[[C86:.+]] = arith.constant 86 : index
+// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<4096x64x2xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[EMPTY]] :
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C86]] step %[[C2]]
+// CHECK-SAME: iter_args(%[[ITER_ARG:.+]] = %[[FILL]])
+// CHECK: %[[RESULT0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C64]]
+// CHECK-SAME: iter_args(%[[ITER_ARG0:.+]] = %[[ITER_ARG]])
+// CHECK: %[[PARTIAL_RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#{{.+}}, #{{.+}}, #[[INIT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: outs(%[[ITER_ARG0]] :
+// CHECK: scf.yield %[[PARTIAL_RESULT]]
+// CHECK scf.yield %[[RESULT0]]
+// CHECK: %[[REDUCE:.+]] = linalg.reduce
+// CHECK-SAME: ins(%[[RESULT]] :
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK-SAME: dimensions = [1, 2]
+// CHECK: return %[[REDUCE]]