[Mlir-commits] [mlir] 2cc5f5d - [mlir][Linalg] Implement tileReductionUsingScf for multiple reductions
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Aug 16 13:47:15 PDT 2023
Author: Groverkss
Date: 2023-08-17T02:17:03+05:30
New Revision: 2cc5f5d43c20428c9b9799176ae2d1da0bf9330d
URL: https://github.com/llvm/llvm-project/commit/2cc5f5d43c20428c9b9799176ae2d1da0bf9330d
DIFF: https://github.com/llvm/llvm-project/commit/2cc5f5d43c20428c9b9799176ae2d1da0bf9330d.diff
LOG: [mlir][Linalg] Implement tileReductionUsingScf for multiple reductions
This patch improves the reduction tiling for linalg to support multiple
reduction dimensions.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D158005
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index a1ab4d9fd03801..cedaa4344a2958 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -255,15 +255,11 @@ struct LinalgOpPartialReductionInterface
ArrayRef<int> reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);
OpBuilder::InsertionGuard guard(b);
- assert(reductionDims.size() == 1 &&
- "only support single reduction right now.");
+
if (linalgOp.hasBufferSemantics())
return op->emitOpError("expected operation to have tensor semantics");
// Insert the new parallel dimension based on the index of the reduction
- // loop. This could be controlled by user for more flexibility.
- int64_t insertSplitDimension = reductionDims[0];
- assert(sizes.size() >= static_cast<size_t>(insertSplitDimension) &&
- "reduction dimension must be tiled");
+ // loops. This could be controlled by user for more flexibility.
SmallVector<Operation *, 4> combinerOps;
if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
@@ -276,18 +272,31 @@ struct LinalgOpPartialReductionInterface
return op->emitOpError(
"Failed to get an identity value for the reduction operation.");
- // Calculate the new shape, we insert the new dimension based on the index
- // of the reduction dimension.
- SmallVector<int64_t> newOutputShape;
ArrayRef<int64_t> oldShape =
linalgOp.getShape(linalgOp.getDpsInitOperand(0));
+
+ // Extend tile size vector to the rank of the output tensor.
+ SmallVector<Value> tileSizeVector =
+ getValueOrCreateConstantIndexOp(b, loc, sizes);
+ if (tileSizeVector.size() < oldShape.size()) {
+ auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ tileSizeVector.append(oldShape.size() - tileSizeVector.size(), zero);
+ }
+
+ // Calculate the new shape, we insert the new dimensions based on the index
+ // of the reduction dimensions.
+ SmallVector<int64_t> newOutputShape;
SmallVector<Value> dynamicDims;
- for (int64_t idx : llvm::seq<int64_t>(0, oldShape.size() + 1)) {
- if (idx == insertSplitDimension) {
+ int64_t currReductionDims = 0;
+ DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
+ for (int64_t idx :
+ llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
+ if (reductionDimsSet.contains(idx)) {
dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
+ currReductionDims++;
continue;
}
- int64_t oldIdx = idx < insertSplitDimension ? idx : idx - 1;
+ int64_t oldIdx = idx - currReductionDims;
int64_t dim = oldShape[oldIdx];
newOutputShape.push_back(dim);
if (ShapedType::isDynamic(dim))
@@ -310,21 +319,20 @@ struct LinalgOpPartialReductionInterface
ArrayRef<int> reductionDims) const {
OpBuilder::InsertionGuard guard(b);
auto linalgOp = cast<LinalgOp>(op);
- assert(reductionDims.size() == 1 &&
- "only support single reduction right now.");
- int64_t insertSplitDimension = reductionDims[0];
AffineMap oldOutputMap =
linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
- SmallVector<AffineExpr> outputExpr;
- for (auto [idx, expr] : llvm::enumerate(oldOutputMap.getResults())) {
- if (static_cast<int64_t>(idx) == insertSplitDimension) {
- outputExpr.push_back(b.getAffineDimExpr(reductionDims[0]));
- }
- outputExpr.push_back(expr);
+ SmallVector<AffineExpr> outputExpr(oldOutputMap.getNumResults() +
+ reductionDims.size());
+
+ for (int idx : reductionDims)
+ outputExpr[idx] = b.getAffineDimExpr(idx);
+ int currExpr = 0;
+ for (int idx : llvm::seq<int>(0, outputExpr.size())) {
+ if (outputExpr[idx])
+ continue;
+ outputExpr[idx] = oldOutputMap.getResult(currExpr++);
}
- if (insertSplitDimension == oldOutputMap.getNumResults())
- outputExpr.push_back(b.getAffineDimExpr(reductionDims[0]));
// Step 1: Extract a slice of the input operands.
SmallVector<Value> valuesToTile = linalgOp.getDpsInputOperands();
@@ -338,11 +346,12 @@ struct LinalgOpPartialReductionInterface
Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
sizes, strides);
- // Step3. create a generic op where the reduction dimension is replaced by a
- // parallel dimension of the size of reduction.
+ // Step3. Create a generic op where the reduction dimensions are replaced
+ // by a parallel dimension of the size of reduction.
SmallVector<utils::IteratorType> newIteratorTypes =
linalgOp.getIteratorTypesArray();
- newIteratorTypes[reductionDims[0]] = utils::IteratorType::parallel;
+ for (int dim : reductionDims)
+ newIteratorTypes[dim] = utils::IteratorType::parallel;
SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
linalgOp.getContext());
@@ -359,24 +368,25 @@ struct LinalgOpPartialReductionInterface
ValueRange partialReduce,
ArrayRef<int> reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);
- assert(reductionDims.size() == 1 &&
- "only support single reduction right now.");
- int64_t dimToMerge = reductionDims[0];
- // Then create a new reduction that only reduce the newly added dimension
+ DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
+
+ // Then create a new reduction that only reduce the newly added dimensions
// from the previous op.
int64_t intermRank = cast<ShapedType>(partialReduce[0].getType()).getRank();
AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
SmallVector<utils::IteratorType> reductionIteratorTypes;
SmallVector<AffineExpr> exprs;
+
for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
- if (dimToMerge == i) {
+ if (reductionDimsSet.contains(i)) {
reductionIteratorTypes.push_back(utils::IteratorType::reduction);
} else {
exprs.push_back(b.getAffineDimExpr(i));
reductionIteratorTypes.push_back(utils::IteratorType::parallel);
}
}
+
AffineMap outputMap =
AffineMap::get(intermRank, 0, exprs, op->getContext());
SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 3f8fe9a23f55e5..597676a017bf48 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -419,26 +419,18 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
op, "don't support ops with multiple results for now");
SmallVector<utils::IteratorType> iterators =
tilingInterfaceOp.getLoopIteratorTypes();
- int64_t numReductionDims = llvm::count(
- tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction);
- if (numReductionDims != 1)
- return b.notifyMatchFailure(
- op, "only support ops with one reduction dimension.");
- int reductionDim;
+
+ SmallVector<int> reductionDims;
for (auto [idx, iteratorType] :
llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
- if (iteratorType == utils::IteratorType::reduction) {
- reductionDim = idx;
- break;
- }
+ if (iteratorType == utils::IteratorType::reduction)
+ reductionDims.push_back(idx);
}
- if (static_cast<size_t>(reductionDim) >= tileSize.size())
- return b.notifyMatchFailure(op, "reduction dimension must be tiled");
// 1. create the inital tensor value.
FailureOr<Operation *> identityTensor =
op.generateInitialTensorForPartialReduction(b, loc, tileSize,
- reductionDim);
+ reductionDims);
if (failed(identityTensor))
return b.notifyMatchFailure(op,
"cannot create a tensor of identity value.");
@@ -450,7 +442,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
// 3. Generate the tiled implementation within the inner most loop.
b.setInsertionPoint(loops.back().getBody()->getTerminator());
Operation *parallelOp = op.tileToPartialReduction(
- b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDim);
+ b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDims);
SmallVector<OpFoldResult> resultSizesList;
for (size_t i = 0; i < offsets.size(); i++)
@@ -472,7 +464,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
// 4. Apply the merge reduction to combine all the partial values.
b.setInsertionPointAfter(*loops.begin());
- Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim);
+ Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
b.replaceOp(op, mergeOp->getResults());
SCFReductionTilingResult results;
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 8c67b80c63f463..70e535b74f055b 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -353,3 +353,35 @@ module {
%for_op, %fill_op, %split_linalg_op, %combining_linalg_op = transform.structured.tile_reduction_using_scf %0 by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
}
}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+ func.func @reduction_tile_multiple_reduction(%arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %1, %out : f32
+ linalg.yield %2 : f32
+ } -> tensor<4096xf32>
+ return %0 : tensor<4096xf32>
+ }
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %for_op, %fill_op, %split_linalg_op, %combining_linalg_op = transform.structured.tile_reduction_using_scf %0 by tile_sizes = [0, 2, 64] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ }
+}
+
+// CHECK: func @reduction_tile_multiple_reduction(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<4096x2x64xf32>) -> tensor<4096x2x64xf32>
+// CHECK: %[[L0:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<4096x2x64xf32>)
+// CHECK: %[[L1:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]) -> (tensor<4096x2x64xf32>)
+// CHECK: %[[OUT:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}, {{.*}}], iterator_types = ["parallel", "parallel", "parallel"]} ins(%{{.*}}, %{{.*}}: tensor<2x64xf32>, tensor<4096x2x64xf32>) outs(%{{.*}}: tensor<4096x2x64xf32>)
+// CHECK: scf.yield %[[OUT]] : tensor<4096x2x64xf32>
+// CHECK: scf.yield %[[L1]] : tensor<4096x2x64xf32>
+// CHECK: %[[OUT2:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}], iterator_types = ["parallel", "reduction", "reduction"]} ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
+// CHECK: return %[[OUT2]] : tensor<4096xf32>
More information about the Mlir-commits mailing list