[Mlir-commits] [mlir] [mlir][scf] Add getPartialResultTilePosition to PartialReductionOpInterface (PR #120465)
Kunwar Grover
llvmlistbot at llvm.org
Fri Dec 27 08:38:24 PST 2024
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/120465
>From ab58988015d599ddf9390b6ab49a3d4827c3755b Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 18 Dec 2024 15:45:52 +0000
Subject: [PATCH 1/3] Fix invalid use of PartialReductionOpInterface in
 MeshShardingInterfaceImpl
---
.../Transforms/MeshShardingInterfaceImpl.cpp | 34 ++++++++++++-------
1 file changed, 21 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 5bf2f91c2c7bc8..92cfba2549a3f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -105,13 +105,13 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
SymbolTableCollection &symbolTable) {
- for (const MeshSharding& sharding : operandShardings) {
+ for (const MeshSharding &sharding : operandShardings) {
if (sharding) {
return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
}
}
- for (const MeshSharding& sharding : resultShardings) {
+ for (const MeshSharding &sharding : resultShardings) {
if (sharding) {
return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
}
@@ -129,8 +129,9 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
// the original operand.
// The other processes would use the reduction operation neutral tensor.
static Value createDestinationPassingStyleInitOperand(
- LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
- MeshOp meshOp, ImplicitLocOpBuilder &builder) {
+ LinalgOp op, int operandNumber, Value spmdizedOperand,
+ ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
+ ImplicitLocOpBuilder &builder) {
Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
meshOp.getSymName(), reductionMeshAxes, builder);
Value zero = builder.create<arith::ConstantIndexOp>(0);
@@ -152,14 +153,21 @@ static Value createDestinationPassingStyleInitOperand(
builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
SmallVector<OpFoldResult> shape =
tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
- PartialReductionOpInterface partialReductionIface =
- llvm::cast<PartialReductionOpInterface>(op.getOperation());
- assert(op->getNumResults() == 1 && "Multiple results not supported.");
- FailureOr<SmallVector<Value>> reductionNeutralTensor =
- partialReductionIface.generateInitialTensorForPartialReduction(
- builder, builder.getLoc(), shape, {});
- assert(succeeded(reductionNeutralTensor));
- builder.create<scf::YieldOp>(reductionNeutralTensor.value());
+
+ SmallVector<Operation *> combinerOps;
+ matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
+  assert(combinerOps.size() == 1 && "expected a single combiner op");
+  std::optional<TypedAttr> neutralEl =
+      arith::getNeutralElement(combinerOps[0]);
+  assert(neutralEl.has_value() &&
+         "expected a neutral element for the combiner op");
+
+ Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
+ neutralEl.value().getType());
+ Value constant =
+ builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
+ Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
+ .getResult(0);
+
+ builder.create<scf::YieldOp>(fill);
}
return ifOp.getResult(0);
}
@@ -178,7 +186,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
Value spmdizedInitOperand =
spmdizationMap.lookup(op->getOperands()[operandIdx]);
newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
- op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
+ op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
return newOperands;
}
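
For reference, the rewritten else-branch above materializes the
reduction-neutral tensor directly instead of going through
PartialReductionOpInterface. A minimal sketch of the IR it produces for an
f32 add-reduction (shapes are placeholders, not taken from the patch):

    // arith::getNeutralElement(arith.addf) yields 0.0 for f32, so the
    // branch builds a zero-filled tensor and yields it:
    //   %empty = tensor.empty(%d0) : tensor<?xf32>
    //   %cst   = arith.constant 0.000000e+00 : f32
    //   %fill  = linalg.fill ins(%cst : f32) outs(%empty : tensor<?xf32>)
    //            -> tensor<?xf32>
    //   scf.yield %fill : tensor<?xf32>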
>From e5a56f5fd975a54cefe1c950b2e547a53a279759 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 18 Dec 2024 18:22:01 +0000
Subject: [PATCH 2/3] [mlir][scf] Add getPartialResultTilePosition to
PartialReductionOpInterface
---
.../mlir/Interfaces/TilingInterface.td | 22 +++
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 157 ++++++++++++------
.../SCF/Transforms/TileUsingInterface.cpp | 28 ++--
.../Linalg/transform-tile-reduction.mlir | 67 ++++++--
4 files changed, 196 insertions(+), 78 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index b75fc5e806afbe..50b69b8f8d833e 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -427,6 +427,28 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
/*defaultImplementation=*/[{
return failure();
}]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to return the position of the partial result tile computed by
+        the tiled operation. This is the same as
+        TilingInterface::getResultTilePosition, but determines the result
+        tile position for the partial reduction.
+ }],
+ /*retType=*/"::llvm::LogicalResult",
+ /*methodName=*/"getPartialResultTilePosition",
+ /*args=*/(ins
+ "::mlir::OpBuilder &":$b,
+ "unsigned":$resultNumber,
+ "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
+ "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+ "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
+ "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
+ "::mlir::ArrayRef<int>":$reductionDims),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
>
];
}
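
A hedged caller sketch for the new interface method, mirroring the call site
added in TileUsingInterface.cpp later in this patch; `redOp`, `builder`,
`offsets`, `sizes`, and `reductionDims` are assumed to be in scope:

    // Hypothetical call site; variable names are illustrative only.
    SmallVector<OpFoldResult> resultOffsets, resultSizes;
    if (failed(redOp.getPartialResultTilePosition(
            builder, /*resultNumber=*/0, offsets, sizes, resultOffsets,
            resultSizes, reductionDims)))
      return failure();
    // resultOffsets/resultSizes now locate the partial tile of result 0
    // within the partial reduction tensor.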
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index f86715a94b268a..098016cd0fd226 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -324,7 +324,20 @@ struct LinalgOpTilingInterface
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//
-/// External model implementation of PartialReductionInterface for LinalgOps.
+static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
+ ArrayRef<int> reductionDims,
+ unsigned resultNumber) {
+ AffineMap map =
+ linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
+ for (int redPos : reductionDims) {
+ map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+ map.getNumResults());
+ }
+ return map;
+}
+
+/// External model implementation of PartialReductionInterface for
+/// LinalgOps.
template <typename LinalgOpTy>
struct LinalgOpPartialReductionInterface
: public PartialReductionOpInterface::ExternalModel<
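
To illustrate the helper: for an init operand with indexing map
(d0, d1) -> (d0) and reductionDims = {1}, it returns (d0, d1) -> (d0, d1).
A standalone sketch of the same AffineMap surgery:

    MLIRContext ctx;
    // initMap: (d0, d1) -> (d0)
    AffineMap initMap = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                       {getAffineDimExpr(0, &ctx)}, &ctx);
    // Append the reduction dim d1: (d0, d1) -> (d0, d1).
    AffineMap partialMap = initMap.insertResult(getAffineDimExpr(1, &ctx),
                                                initMap.getNumResults());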
@@ -338,11 +351,24 @@ struct LinalgOpPartialReductionInterface
if (linalgOp.hasPureBufferSemantics())
return op->emitOpError("expected operation to have tensor semantics");
+ // LinalgOp implements TilingInterface.
+ auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+ SmallVector<OpFoldResult> shape =
+ llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
+ [](Range x) { return x.size; });
+
+ SmallVector<OpFoldResult> tiledShape;
+ for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
+ if (isZeroIndex(tileSize)) {
+ tiledShape.push_back(dimSize);
+ } else {
+ tiledShape.push_back(tileSize);
+ }
+ }
+
SmallVector<Value> inits;
for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
++initIdx) {
- // Insert the new parallel dimension based on the index of the reduction
- // loops. This could be controlled by user for more flexibility.
SmallVector<Operation *, 4> combinerOps;
if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
combinerOps) ||
@@ -355,33 +381,19 @@ struct LinalgOpPartialReductionInterface
return op->emitOpError(
"Failed to get an identity value for the reduction operation.");
- ArrayRef<int64_t> oldShape =
- linalgOp.getShape(linalgOp.getDpsInitOperand(initIdx));
-
- // Calculate the new shape, we insert the new dimensions based on the
- // index of the reduction dimensions.
- SmallVector<int64_t> newOutputShape;
- SmallVector<Value> dynamicDims;
- int64_t currReductionDims = 0;
- DenseSet<int> reductionDimsSet(reductionDims.begin(),
- reductionDims.end());
- for (int64_t idx :
- llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
- if (reductionDimsSet.contains(idx)) {
- dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
- currReductionDims++;
- continue;
- }
- int64_t oldIdx = idx - currReductionDims;
- int64_t dim = oldShape[oldIdx];
- newOutputShape.push_back(dim);
- if (ShapedType::isDynamic(dim))
- dynamicDims.push_back(b.create<tensor::DimOp>(
- loc, linalgOp.getDpsInitOperand(initIdx)->get(), oldIdx));
+ // Append the new partial result dimensions.
+ AffineMap partialMap =
+ getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
+ SmallVector<OpFoldResult> partialResultShape;
+ for (AffineExpr dimExpr : partialMap.getResults()) {
+ auto dim = cast<AffineDimExpr>(dimExpr);
+ partialResultShape.push_back(tiledShape[dim.getPosition()]);
}
- Value emptyTensor = b.create<tensor::EmptyOp>(
- loc, newOutputShape,
- linalgOp.getRegionOutputArgs()[initIdx].getType(), dynamicDims);
+
+ Type elType =
+ getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
+ Value emptyTensor =
+ b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
auto identityTensor =
b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
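
A worked example of the shape computation above, matching the first updated
test below (a ?x? add-reduction tiled with tile_sizes = [0, 5]):

    //   iteration domain : [%D0, %D1]
    //   tile sizes       : [0, 5]      (0 => keep the full domain size %D0)
    //   tiledShape       : [%D0, 5]
    //   partial map      : (d0, d1) -> (d0, d1)
    //   partial init     : tensor.empty(%D0) : tensor<?x5xf32>
    // The init is now sized from the iteration domain instead of from the
    // init operand, hence the removed tensor.dim on ARG1 in the tests.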
@@ -407,11 +419,7 @@ struct LinalgOpPartialReductionInterface
// TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
// this with a for range loop when we have it.
AffineMap newMap =
- linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
- for (int redPos : reductionDims) {
- newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
- newMap.getNumResults());
- }
+ getPartialResultAffineMap(linalgOp, reductionDims, idx);
newInitMaps.push_back(newMap);
}
@@ -476,29 +484,74 @@ struct LinalgOpPartialReductionInterface
Location loc, ValueRange partialReduce,
ArrayRef<int> reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);
- SmallVector<int64_t> reductionDimsInt64(reductionDims);
- auto reduction = b.create<linalg::ReduceOp>(
- loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
- [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
- int64_t numInits = linalgOp.getNumDpsInits();
- SmallVector<Value> yieldedValues;
- for (int idx : llvm::seq<int>(0, numInits)) {
+
+    // Permute the reduction dims to match their order in the partial
+    // result map.
+
+ int64_t numInits = linalgOp.getNumDpsInits();
+ SmallVector<Operation *> mergeOperations;
+ SmallVector<Value> replacements;
+ for (int idx : llvm::seq(numInits)) {
+ // linalg.reduce's iteration space is the result's iteration space (and
+ // not the operations iteration space). To account for this, permute the
+ // reduction dimensions based on the partial result map.
+ AffineMap partialMap =
+ getPartialResultAffineMap(linalgOp, reductionDims, idx);
+ SmallVector<int64_t> partialReductionDims;
+ for (auto [resultNum, dimExpr] :
+ llvm::enumerate(partialMap.getResults())) {
+ unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+ if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+ partialReductionDims.push_back(resultNum);
+ }
+ }
+
+ Value partialResult = partialReduce[idx];
+ Value init = linalgOp.getDpsInits()[idx];
+
+ auto reduction = b.create<linalg::ReduceOp>(
+ loc, partialResult, init, partialReductionDims,
+ [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
// Get the combiner op.
SmallVector<Operation *, 4> combinerOps;
matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
Operation *clonedReductionOp = b.clone(*combinerOps[0]);
// Combine the input at idx and output at numInits + idx.
- clonedReductionOp->setOperand(0, inputs[idx]);
- clonedReductionOp->setOperand(1, inputs[numInits + idx]);
- // Yield.
- yieldedValues.push_back(clonedReductionOp->getResult(0));
- }
- b.create<linalg::YieldOp>(loc, yieldedValues);
- });
- return MergeResult{
- {reduction.getOperation()},
- llvm::map_to_vector(reduction->getResults(),
- [](OpResult r) -> Value { return r; })};
+ clonedReductionOp->setOperand(0, inputs[0]);
+ clonedReductionOp->setOperand(1, inputs[1]);
+ b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+ });
+
+ mergeOperations.push_back(reduction);
+ replacements.push_back(reduction->getResult(0));
+ }
+
+ return MergeResult{mergeOperations, replacements};
+ }
+
+ LogicalResult getPartialResultTilePosition(
+ Operation *op, OpBuilder &b, unsigned resultNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes,
+ ArrayRef<int> reductionDims) const {
+ auto linalgOp = cast<LinalgOp>(op);
+
+ AffineMap partialMap =
+ getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
+ for (AffineExpr dimExpr : partialMap.getResults()) {
+ unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+ resultSizes.push_back(sizes[dim]);
+
+ if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+        // Reduction dims are reduced, and are always written to the same
+        // place, so use offset 0 for them.
+ resultOffsets.push_back(b.getIndexAttr(0));
+ } else {
+ resultOffsets.push_back(offsets[dim]);
+ }
+ }
+
+ return success();
}
};
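
To make the offset logic concrete, a worked trace against the multi-dim
transpose test added later in this patch:

    //   partial map   : (d0, d1, d2) -> (d2, d0, d1), reductionDims = {1}
    //   offsets       : [0, %iv, 0],  sizes : [%D0, %K, %D2]
    //   resultSizes   : [%D2, %D0, %K]  (sizes permuted by the map)
    //   resultOffsets : [0, 0, 0]       (d1 is a reduction dim => offset 0;
    //                                    d0 and d2 are untiled => offset 0)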
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 2277989bf8411b..b548f8ce8b560b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -657,21 +657,29 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
resultOffset, resultSize);
case scf::SCFTilingOptions::ReductionTilingStrategy::
PartialReductionOuterReduction: {
- // TODO: This does not work for non identity accesses to the result tile.
- // The proper fix is to add a getPartialResultTilePosition method to
- // PartialReductionOpInterface.
- resultOffset =
- SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
- for (size_t i = 0; i < offsets.size(); i++) {
- resultSize.push_back(
- tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
+ auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+ if (!redOp) {
+ return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations implementing "
+              "PartialReductionOpInterface");
}
- return success();
+ // Get reduction dimensions.
+ // TODO: PartialReductionOpInterface should really query TilingInterface
+ // itself and find reduction dimensions.
+ SmallVector<int> reductionDims;
+ for (auto [idx, iteratorType] :
+ llvm::enumerate(op.getLoopIteratorTypes())) {
+ if (iteratorType == utils::IteratorType::reduction)
+ reductionDims.push_back(idx);
+ }
+ return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
+ resultOffset, resultSize,
+ reductionDims);
+ }
default:
return rewriter.notifyMatchFailure(op,
"unhandled reduction tiling strategy");
}
- }
}
static FailureOr<MergeResult>
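
A hedged sketch of driving this code path through the SCF tiling options;
the setter spellings follow the existing SCFTilingOptions API but should be
treated as assumptions, not as part of this patch:

    scf::SCFTilingOptions options;
    options.setReductionTilingStrategy(
        scf::SCFTilingOptions::ReductionTilingStrategy::
            PartialReductionOuterReduction);
    // A zero tile size leaves that dimension untiled; here only the
    // reduction dimension is tiled (by 5), as in the tests below.
    options.setTileSizes({rewriter.getIndexAttr(0), rewriter.getIndexAttr(5)});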
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index cce4b4efa61c8b..9d34c80822d0e1 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -32,8 +32,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
// CHECK: %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
// CHECK: %[[PS:.*]] = affine.min #[[MAP0]](%[[K]])[%[[D1]]]
@@ -81,13 +80,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK: func @reduction_tile_transpose
-// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
-// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
+// CHECK: tensor.empty(%{{.*}}) : tensor<?x5xf32>
+// CHECK: linalg.fill {{.*}} : tensor<?x5xf32>) -> tensor<?x5xf32>
// CHECK: scf.for
-// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
+// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
-// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
-// CHECK: scf.yield {{.*}} : tensor<5x?xf32>
+// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
+// CHECK: scf.yield {{.*}} : tensor<?x5xf32>
// CHECK: }
// CHECK: linalg.reduce
// CHECK: return
@@ -129,8 +128,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
@@ -183,9 +181,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D3:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D4:.*]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor<?x?x5xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D0]], %[[D2]]) : tensor<?x?x5xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
@@ -243,8 +239,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
@@ -422,8 +417,8 @@ func.func @reduction_tile_multiple_results(%arg0: tensor<?x?xf32>, %out: tensor<
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %12, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
- by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ %1, %12, %2, %3, %4, %loop = transform.structured.tile_reduction_using_for %0
+ by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
@@ -444,4 +439,44 @@ module attributes {transform.with_named_sequence} {
// CHECK: scf.yield %[[INSERT1]], %[[INSERT1]]
// CHECK: linalg.reduce
// CHECK: arith.addf
+// CHECK: linalg.reduce
// CHECK: arith.maximumf
+
+// -----
+
+func.func @reduction_tile_multi_dim_transpose(%arg0: tensor<?x?x?xf32>, %out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %red = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d0)>],
+ iterator_types = ["parallel", "reduction", "parallel"]}
+ ins(%arg0 : tensor<?x?x?xf32>)
+ outs(%out : tensor<?x?xf32>) {
+ ^bb0(%arg7: f32, %arg9: f32):
+ %42 = arith.addf %arg7, %arg9 : f32
+ linalg.yield %42 : f32
+ } -> tensor<?x?xf32>
+ return %red : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
+ by tile_sizes = [0, 5, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+// CHECK: func @reduction_tile_multi_dim_transpose
+// CHECK: tensor.empty(%{{.*}}) : tensor<?x?x5xf32>
+// CHECK: linalg.fill {{.*}} : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
+// CHECK: scf.for
+// CHECK: %[[K:.*]] = affine.min
+// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0, 0] [%[[D2:.*]], %[[D0:.*]], %[[K]]] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?x?xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?x?xf32>) outs(%[[EXT]] : tensor<?x?x?xf32>)
+// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0, 0] [%[[D2]], %[[D0]], %[[K]]] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x5xf32>
+// CHECK: scf.yield {{.*}} : tensor<?x?x5xf32>
+// CHECK: }
+// CHECK: linalg.reduce
+// CHECK: return
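
Why MAP2 is (d0, d1, d2) -> (d2, d0, d1) in this new test:

    // The init's indexing map is (d0, d1, d2) -> (d2, d0); appending the
    // reduction dim d1 gives (d0, d1, d2) -> (d2, d0, d1), so the partial
    // result becomes tensor<?x?x5xf32> once d1 is tiled by 5.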
>From 93d07b9d3c89fe5f264bdab06611f83aa51089ac Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 27 Dec 2024 16:38:03 +0000
Subject: [PATCH 3/3] Address comments
---
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 14 +++++++++++---
1 file changed, 11 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 098016cd0fd226..b7764da26a7f47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -324,6 +324,13 @@ struct LinalgOpTilingInterface
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//
+/// Return an AffineMap for a partial result for the given result number,
+/// assuming the partial tiling strategy is outer-reduction loop +
+/// inner-parallel tile. The returned AffineMap can be used as the replacement
+/// AffineMap for the inner-parallel tile linalg op for the given result number.
+///
+/// The new AffineMap is the old AffineMap with the reduction dimensions
+/// appended at the end.
static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
ArrayRef<int> reductionDims,
unsigned resultNumber) {
@@ -491,9 +498,10 @@ struct LinalgOpPartialReductionInterface
SmallVector<Operation *> mergeOperations;
SmallVector<Value> replacements;
for (int idx : llvm::seq(numInits)) {
- // linalg.reduce's iteration space is the result's iteration space (and
- // not the operations iteration space). To account for this, permute the
- // reduction dimensions based on the partial result map.
+ // linalg.reduce's iteration space is the tiled result's iteration space
+ // (and not the tiled operation's iteration space). To account for this,
+ // permute the reduction dimensions based on the partial result map of the
+ // tiled result.
AffineMap partialMap =
getPartialResultAffineMap(linalgOp, reductionDims, idx);
SmallVector<int64_t> partialReductionDims;
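
A worked trace of the permutation described in the updated comment, using
the transposed multi-dim test above:

    //   partial map (d0, d1, d2) -> (d2, d0, d1), reductionDims = {1}
    //   d1 appears at result position 2, so partialReductionDims = {2}:
    //   linalg.reduce collapses dimension 2 of the partial tensor (the
    //   trailing 5 in tensor<?x?x5xf32>) back onto the original init.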