[Mlir-commits] [mlir] [mlir][TilingInterface] Allow multiple results in PartialReductionOpInterface (PR #92624)
Kunwar Grover
llvmlistbot at llvm.org
Mon May 20 06:49:11 PDT 2024
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/92624
>From 8e890c8908b3267c41bb616de5749dca3fc43f17 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 20 May 2024 13:17:01 +0000
Subject: [PATCH] save
---
.../Linalg/TransformOps/LinalgTransformOps.td | 4 +-
.../Dialect/Linalg/Transforms/Transforms.h | 4 +-
.../SCF/Transforms/TileUsingInterface.h | 4 +-
.../mlir/Interfaces/TilingInterface.td | 5 +-
.../TransformOps/LinalgTransformOps.cpp | 6 +-
.../Transforms/MeshShardingInterfaceImpl.cpp | 13 +-
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 25 +--
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 170 +++++++++++-------
.../SCF/Transforms/TileUsingInterface.cpp | 51 +++---
.../Linalg/transform-tile-reduction.mlir | 50 +++++-
10 files changed, 211 insertions(+), 121 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 5585ba27fdad8..93e2c2db729da 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1681,7 +1681,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
// TODO: support mixed static-dynamic (see TileUsingForallOp).
let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
- let results = (outs TransformHandleTypeInterface:$fill_op,
+ let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
TransformHandleTypeInterface:$split_linalg_op,
TransformHandleTypeInterface:$combining_linalg_op,
TransformHandleTypeInterface:$for_op);
@@ -1787,7 +1787,7 @@ def TileReductionUsingForallOp :
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
- let results = (outs TransformHandleTypeInterface:$fill_op,
+ let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
TransformHandleTypeInterface:$split_linalg_op,
TransformHandleTypeInterface:$combining_linalg_op,
TransformHandleTypeInterface:$forall_op);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f77c19ed0fcce..308ce92e35520 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,8 +876,8 @@ struct ForallReductionTilingResult {
Operation *parallelTiledOp;
/// The final reduction operation merging all the partial reductions.
Operation *mergeOp;
- /// The op initializing the tensor used for partial reductions.
- Operation *initialOp;
+ /// Initial values used for partial reductions.
+ SmallVector<Value> initialValues;
/// The `scf.forall` operation that iterate over the tiles.
scf::ForallOp loops;
};
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be2..6d567171e185a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -250,8 +250,8 @@ struct SCFReductionTilingResult {
Operation *parallelTiledOp;
/// The final reduction operation merging all the partial reductions.
Operation *mergeOp;
- /// Initial op
- Operation *initialOp;
+ /// Initial values used for reduction.
+ SmallVector<Value> initialValues;
/// The loop operations that iterate over the tiles.
SmallVector<LoopLikeOpInterface> loops;
};
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c2424..6fff7e0da538a 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -170,11 +170,12 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
operation reduction. The tensor shape is equal to operation result
shape with new dimension for each non zero tile size.
}],
- /*retType=*/"FailureOr<Operation*>",
+ /*retType=*/"FailureOr<Value>",
/*methodName=*/"generateInitialTensorForPartialReduction",
/*args=*/(ins
"OpBuilder &":$b,
- "Location ":$loc,
+ "Location":$loc,
+ "int64_t":$resultNumber,
"ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<int>":$reductionDim),
/*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 13582a140a965..9b3121774ab3a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2523,7 +2523,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
if (failed(result))
return emitDefaultSilenceableFailure(target);
- results.push_back(result->initialOp);
+ for (Value initValue : result->initialValues)
+ results.push_back(initValue.getDefiningOp());
results.push_back(result->parallelTiledOp);
results.push_back(result->mergeOp);
results.push_back(result->loops.front());
@@ -2574,7 +2575,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
diag.attachNote(target.getLoc()) << "target operation";
return diag;
}
- results.push_back(result->initialOp);
+ for (Value initValue : result->initialValues)
+ results.push_back(initValue.getDefiningOp());
results.push_back(result->parallelTiledOp);
results.push_back(result->mergeOp);
results.push_back(result->loops);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 146e880765668..e0394f852fcc3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -155,12 +155,12 @@ static Value createDestinationPassingStyleInitOperand(
tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
PartialReductionOpInterface partialReductionIface =
llvm::cast<PartialReductionOpInterface>(op.getOperation());
- FailureOr<Operation *> reductionNeutralTensorOp =
+ assert(op->getNumResults() == 1 && "Multiple results not supported.");
+ FailureOr<Value> reductionNeutralTensor =
partialReductionIface.generateInitialTensorForPartialReduction(
- builder, builder.getLoc(), shape, {});
- assert(succeeded(reductionNeutralTensorOp));
- builder.create<scf::YieldOp>(
- reductionNeutralTensorOp.value()->getResult(0));
+ builder, builder.getLoc(), 0, shape, {});
+ assert(succeeded(reductionNeutralTensor));
+ builder.create<scf::YieldOp>(reductionNeutralTensor.value());
}
return ifOp.getResult(0);
}
@@ -173,8 +173,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
ImplicitLocOpBuilder &builder) {
// TODO: add support for multiple destination passing style initial value
// operands.
- // PartialReductionOpInterface::generateInitialTensorForPartialReduction
- // needs to also support multiple DPS initial operands.
+ assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
Value spmdizedInitOperand =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index df4089d61bfd7..d51c4d0af0366 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -692,12 +692,17 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
op, "reduction dimension must be mapped to threads");
// 1. Create the inital tensor value.
- FailureOr<Operation *> identityTensor =
- op.generateInitialTensorForPartialReduction(b, loc, numThreads,
- reductionDim);
- if (failed(identityTensor))
- return b.notifyMatchFailure(op,
- "cannot create a tensor of identity value.");
+ SmallVector<Value> initTensors;
+ initTensors.reserve(op->getNumResults());
+ for (int idx : llvm::seq<int>(0, op->getNumResults())) {
+ FailureOr<Value> initValue = op.generateInitialTensorForPartialReduction(
+ b, loc, idx, numThreads, reductionDim);
+ if (failed(initValue))
+ return b.notifyMatchFailure(
+ op, "cannot create a tensor of identity value for result " +
+ std::to_string(idx));
+ initTensors.push_back(initValue.value());
+ }
// Gather destination tensors.
SmallVector<Value> dest;
@@ -715,8 +720,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
// 2. Create the ForallOp with an empty region.
scf::ForallOp forallOp = b.create<scf::ForallOp>(
- loc, getAsOpFoldResult(materializedNonZeroNumThreads),
- (*identityTensor)->getResults(), mapping);
+ loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
+ mapping);
// 3. Calculate the tile offsets and sizes for the subsequent loop that will
// be nested under `forallOp`.
@@ -726,7 +731,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
/*nominalTileSizes=*/std::nullopt, tiledOffsets,
tiledSizes);
- // 4. Clone the tileable op and update its destination operands to use the
+ // 4b. Clone the tileable op and update its destination operands to use the
// output bbArgs of the ForallOp.
SmallVector<Value> tilingResults;
ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
@@ -838,7 +843,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
// 8. Return.
ForallReductionTilingResult results;
- results.initialOp = *identityTensor;
+ results.initialValues = initTensors;
results.loops = forallOp;
results.parallelTiledOp = tiledOp;
results.mergeOp = mergeOp;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5..528f120287181 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -250,9 +250,9 @@ template <typename LinalgOpTy>
struct LinalgOpPartialReductionInterface
: public PartialReductionOpInterface::ExternalModel<
LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
- FailureOr<Operation *> generateInitialTensorForPartialReduction(
- Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
- ArrayRef<int> reductionDims) const {
+ FailureOr<Value> generateInitialTensorForPartialReduction(
+ Operation *op, OpBuilder &b, Location loc, int64_t resultNumber,
+ ArrayRef<OpFoldResult> sizes, ArrayRef<int> reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);
OpBuilder::InsertionGuard guard(b);
@@ -262,7 +262,8 @@ struct LinalgOpPartialReductionInterface
// loops. This could be controlled by user for more flexibility.
SmallVector<Operation *, 4> combinerOps;
- if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
+ if (!matchReduction(linalgOp.getRegionOutputArgs(), resultNumber,
+ combinerOps) ||
combinerOps.size() != 1)
return op->emitOpError("Failed to anaysis the reduction operation.");
@@ -273,7 +274,7 @@ struct LinalgOpPartialReductionInterface
"Failed to get an identity value for the reduction operation.");
ArrayRef<int64_t> oldShape =
- linalgOp.getShape(linalgOp.getDpsInitOperand(0));
+ linalgOp.getShape(linalgOp.getDpsInitOperand(resultNumber));
// Calculate the new shape, we insert the new dimensions based on the index
// of the reduction dimensions.
@@ -293,15 +294,15 @@ struct LinalgOpPartialReductionInterface
newOutputShape.push_back(dim);
if (ShapedType::isDynamic(dim))
dynamicDims.push_back(b.create<tensor::DimOp>(
- loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
+ loc, linalgOp.getDpsInitOperand(resultNumber)->get(), oldIdx));
}
Value emptyTensor = b.create<tensor::EmptyOp>(
- loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(),
- dynamicDims);
+ loc, newOutputShape,
+ linalgOp.getRegionOutputArgs()[resultNumber].getType(), dynamicDims);
Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
auto identityTensor =
b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
- return identityTensor.getOperation();
+ return identityTensor.getResult(0);
}
Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
@@ -312,44 +313,64 @@ struct LinalgOpPartialReductionInterface
OpBuilder::InsertionGuard guard(b);
auto linalgOp = cast<LinalgOp>(op);
- AffineMap oldOutputMap =
- linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
- SmallVector<AffineExpr> outputExpr(oldOutputMap.getNumResults() +
- reductionDims.size());
-
- for (int idx : reductionDims)
- outputExpr[idx] = b.getAffineDimExpr(idx);
- int currExpr = 0;
- for (int idx : llvm::seq<int>(0, outputExpr.size())) {
- if (outputExpr[idx])
- continue;
- outputExpr[idx] = oldOutputMap.getResult(currExpr++);
+ // Step 1. Extend the init maps with the reduction dimensions, since we
+ // are converting them to parallel dimensions.
+ SmallVector<AffineMap> newInitMaps;
+ newInitMaps.reserve(linalgOp.getNumDpsInits());
+ for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+ // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+ // this with a for range loop when we have it.
+ AffineMap newMap =
+ linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+ for (int redPos : reductionDims) {
+ newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
+ newMap.getNumResults());
+ }
+ newInitMaps.push_back(newMap);
}
- // Step 1: Extract a slice of the input operands.
- SmallVector<Value> valuesToTile = linalgOp.getDpsInputs();
- SmallVector<Value, 4> tiledOperands = makeTiledShapes(
- b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+ // Step 2a: Extract a slice of the input operands.
+ SmallVector<Value, 4> tiledInputs = makeTiledShapes(
+ b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
+
+ // Step 2b: Extract a slice of the init operands.
+ SmallVector<Value, 1> tiledInits;
+ for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
+ int64_t initRank = valueMap.getNumResults();
+ SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0));
+ SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1));
+ SmallVector<OpFoldResult> initSizes;
+ for (AffineExpr dimExpr : valueMap.getResults()) {
+ auto dim = cast<AffineDimExpr>(dimExpr);
+ initSizes.push_back(sizes[dim.getPosition()]);
+ }
+ // TODO: Use SubsetExtractOpInterface here once available.
+ auto extractSlice = b.create<tensor::ExtractSliceOp>(
+ loc, valueToTile, initOffset, initSizes, initStride);
+ tiledInits.push_back(extractSlice);
+ }
- // Step 2: Extract the accumulator operands
- SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));
- SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
- // TODO: use SubsetExtractOpInterface once it is available.
- Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
- sizes, strides);
+ // Update the indexing maps.
+ SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
+ // Change the init maps.
+ for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+ // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+ // this with a for range loop when we have it.
+ OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
+ int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
+ newMaps[mapIdx] = newInitMaps[idx];
+ }
- // Step3. Create a generic op where the reduction dimensions are replaced
- // by a parallel dimension of the size of reduction.
+ // Step 3. Change the reduction dim iterator types.
SmallVector<utils::IteratorType> newIteratorTypes =
linalgOp.getIteratorTypesArray();
for (int dim : reductionDims)
newIteratorTypes[dim] = utils::IteratorType::parallel;
- SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
- newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
- linalgOp.getContext());
+
+ // Step 4. Create the new generic op.
auto genericOp =
- b.create<GenericOp>(loc, TypeRange({out.getType()}), tiledOperands,
- ValueRange({out}), newMaps, newIteratorTypes);
+ b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
+ tiledInits, newMaps, newIteratorTypes);
IRMapping mapping;
op->getRegion(0).cloneInto(&genericOp.getRegion(),
genericOp.getRegion().begin(), mapping);
@@ -361,40 +382,53 @@ struct LinalgOpPartialReductionInterface
ArrayRef<int> reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);
- DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
-
- // Then create a new reduction that only reduce the newly added dimensions
- // from the previous op.
- int64_t intermRank = cast<ShapedType>(partialReduce[0].getType()).getRank();
- AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
- SmallVector<utils::IteratorType> reductionIteratorTypes;
- SmallVector<AffineExpr> exprs;
-
- for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
- if (reductionDimsSet.contains(i)) {
- reductionIteratorTypes.push_back(utils::IteratorType::reduction);
- } else {
- exprs.push_back(b.getAffineDimExpr(i));
- reductionIteratorTypes.push_back(utils::IteratorType::parallel);
+ // Step 1. Recover the dims that actually need to be merged from the
+ // original operation. We can classify the original iterators as follows:
+ //
+ // parallel --> parallel
+ // reduction + not in reductionDims --> parallel (already reduced)
+ // reduction + in reductionDims --> reduction (will reduce now)
+ SmallVector<utils::IteratorType> iterators(linalgOp.getNumLoops(),
+ utils::IteratorType::parallel);
+ for (int redIdx : reductionDims)
+ iterators[redIdx] = utils::IteratorType::reduction;
+
+ // Step 2. For each partial result, create a map to index it. This map
+ // is simply the indexing map for the original result with reductionDims
+ // appended (as produced in tileToPartialReduction).
+ int64_t numInits = linalgOp.getNumDpsInits();
+ SmallVector<AffineMap> indexingMaps(numInits * 2);
+ for (int idx : llvm::seq<int>(0, numInits)) {
+ AffineMap &inputMap = indexingMaps[idx];
+ AffineMap &outputMap = indexingMaps[numInits + idx];
+
+ outputMap =
+ linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+ inputMap = outputMap;
+ for (int redPos : reductionDims) {
+ inputMap = inputMap.insertResult(b.getAffineDimExpr(redPos),
+ inputMap.getNumResults());
}
}
- AffineMap outputMap =
- AffineMap::get(intermRank, 0, exprs, op->getContext());
- SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
-
- SmallVector<Operation *, 4> combinerOps;
- matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps);
- Operation *reductionOp = combinerOps[0];
-
auto reduction = b.create<GenericOp>(
- loc, op->getResultTypes(), ValueRange({partialReduce[0]}),
- linalgOp.getDpsInits(), reductionMaps, reductionIteratorTypes,
- [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
- Operation *clonedReductionOp = b.clone(*reductionOp);
- clonedReductionOp->setOperand(0, inputs[0]);
- clonedReductionOp->setOperand(1, inputs[1]);
- b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+ loc, op->getResultTypes(), partialReduce, linalgOp.getDpsInits(),
+ indexingMaps, iterators,
+ [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
+ int64_t numInits = linalgOp.getNumDpsInits();
+ SmallVector<Value> yieldedValues;
+ for (int idx : llvm::seq<int>(0, numInits)) {
+ // Get the combiner op.
+ SmallVector<Operation *, 4> combinerOps;
+ matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
+ Operation *clonedReductionOp = b.clone(*combinerOps[0]);
+ // Combine the input at idx and output at numInits + idx.
+ clonedReductionOp->setOperand(0, inputs[idx]);
+ clonedReductionOp->setOperand(1, inputs[numInits + idx]);
+ // Yield.
+ yieldedValues.push_back(clonedReductionOp->getResult(0));
+ }
+ b.create<linalg::YieldOp>(loc, yieldedValues);
});
return reduction.getOperation();
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69d..32a683eff8cdb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -182,6 +182,9 @@ static LogicalResult generateLoopNestUsingForOp(
if (loops.empty())
return success();
+ assert(tiledResults.size() == destinationTensors.size() &&
+ "Number of results of body should be equal to number of iter args");
+
// 6. Yield all the results of the tiled operation.
SmallVector<Value> yieldedValues;
for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
@@ -694,9 +697,6 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
zero);
}
- if (op->getNumResults() != 1)
- return b.notifyMatchFailure(
- op, "don't support ops with multiple results for now");
SmallVector<utils::IteratorType> iterators =
tilingInterfaceOp.getLoopIteratorTypes();
@@ -708,12 +708,18 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
}
// 2. create the inital tensor value.
- FailureOr<Operation *> identityTensor =
- op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
- reductionDims);
- if (failed(identityTensor))
- return b.notifyMatchFailure(op,
- "cannot create a tensor of identity value.");
+ SmallVector<Value> initTensors;
+ initTensors.reserve(op->getNumResults());
+ for (int idx : llvm::seq<int>(0, op->getNumResults())) {
+ FailureOr<Value> initTensor = op.generateInitialTensorForPartialReduction(
+ b, loc, idx, tileSizesVector, reductionDims);
+ if (failed(initTensor)) {
+ return b.notifyMatchFailure(
+ op, "cannot create a tensor of identity value for result: " +
+ std::to_string(idx));
+ }
+ initTensors.push_back(initTensor.value());
+ }
// 3. Define the callback to use for generating the inner most tile loop body.
Operation *parallelOp = nullptr;
@@ -753,29 +759,26 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
tiledResult.append(parallelOp->result_begin(), parallelOp->result_end());
// 4d. Compute the offsets and sizes needed to insert the result of the
// tiled value back into destination before yielding the destination.
- SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
- resultOffsets.emplace_back(std::move(outOffsets));
-
- SmallVector<OpFoldResult> outSizes;
- for (size_t i = 0; i < offsets.size(); i++) {
- outSizes.push_back(
- tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
+ for (int resultIdx : llvm::seq<int>(0, parallelOp->getNumResults())) {
+ SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
+ resultOffsets.emplace_back(std::move(outOffsets));
+
+ SmallVector<OpFoldResult> outSizes;
+ for (size_t i = 0; i < offsets.size(); i++) {
+ outSizes.push_back(
+ tensor::getMixedSize(b, loc, parallelOp->getResult(resultIdx), i));
+ }
+ resultSizes.emplace_back(std::move(outSizes));
}
- resultSizes.emplace_back(std::move(outSizes));
return success();
};
// 5. Generate the tiled implementation using the destination tensors.
- SmallVector<Value> destinationTensors =
- llvm::map_to_vector(identityTensor.value()->getResults(),
- [](OpResult res) -> Value { return res; });
-
SmallVector<LoopLikeOpInterface> loops;
scf::SCFTilingOptions options;
options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
- destinationTensors, innerYieldTiledValuesFn,
- loops)))
+ initTensors, innerYieldTiledValuesFn, loops)))
return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
SmallVector<Value> replacements = llvm::map_to_vector(
@@ -787,7 +790,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
b.replaceOp(op, mergeOp->getResults());
SCFReductionTilingResult results;
- results.initialOp = *identityTensor;
+ results.initialValues = initTensors;
results.loops = loops;
results.parallelTiledOp = parallelOp;
results.mergeOp = mergeOp;
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 0e1512717a22d..f3cf7c4dffa05 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -80,13 +80,14 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: func @reduction_tile_transpose
// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
// CHECK: scf.for
// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
// CHECK: scf.yield {{.*}} : tensor<5x?xf32>
// CHECK: }
@@ -403,3 +404,48 @@ module {
// CHECK: scf.yield %[[L1]] : tensor<4096x2x64xf32>
// CHECK: %[[OUT2:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}], iterator_types = ["parallel", "reduction", "reduction"]} ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
// CHECK: return %[[OUT2]] : tensor<4096xf32>
+
+// -----
+
+func.func @reduction_tile_multiple_results(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>, %out2: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+ %red:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<?x?xf32>)
+ outs(%out, %out2 : tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%arg7: f32, %arg9: f32, %arg9_1: f32):
+ %1 = arith.mulf %arg7, %arg7 : f32
+ %2 = arith.addf %1, %arg9 : f32
+ %3 = arith.maximumf %1, %arg9_1 : f32
+ linalg.yield %2, %3 : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+ return %red#0, %red#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %12, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
+ by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: func @reduction_tile_multiple_results
+// CHECK-DAG: %[[SUM_ID:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[MAX_ID:.+]] = arith.constant 0xFF800000 : f32
+// CHECK-DAG: %[[SUM_INIT:.+]] = linalg.fill ins(%[[SUM_ID]] : f32) outs(%{{.*}} : tensor<?x5xf32>) -> tensor<?x5xf32>
+// CHECK-DAG: %[[MAX_INIT:.+]] = linalg.fill ins(%[[MAX_ID]] : f32) outs(%{{.*}} : tensor<?x5xf32>) -> tensor<?x5xf32>
+// CHECK: %[[OUT:.+]]:2 = scf.for
+// CHECK-SAME: iter_args(%[[SUM:.+]] = %[[SUM_INIT]], %[[MAX:.+]] = %[[MAX_INIT]])
+// CHECK: %[[UPDATED:.*]]:2 = linalg.generic
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: arith.maximumf
+// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[UPDATED]]#0 into %[[SUM]]
+// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[UPDATED]]#1 into %[[MAX]]
+// CHECK: scf.yield %[[INSERT1]], %[[INSERT2]]
+// CHECK: linalg.generic
+// CHECK: arith.addf
+// CHECK: arith.maximumf
More information about the Mlir-commits
mailing list