[Mlir-commits] [mlir] 99833cd - [mlir][linalg] Add reduction tiling using scf.foreachthread
Thomas Raoux
llvmlistbot at llvm.org
Mon Nov 14 10:06:06 PST 2022
Author: Thomas Raoux
Date: 2022-11-14T18:05:40Z
New Revision: 99833cd8188c1ddbf8c76f4683813060553af79a
URL: https://github.com/llvm/llvm-project/commit/99833cd8188c1ddbf8c76f4683813060553af79a
DIFF: https://github.com/llvm/llvm-project/commit/99833cd8188c1ddbf8c76f4683813060553af79a.diff
LOG: [mlir][linalg] Add reduction tiling using scf.foreachthread
This adds a transformation that tiles reduction operations into partial
reductions using scf.foreach_thread. It uses
PartialReductionOpInterface to create a merge operation that combines the
partial tiles.
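A minimal usage sketch of the new transform op (a condensed form of the
test added below; the %0 handle is assumed to already match a linalg
reduction op):

```mlir
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
  // Tile the reduction dimension across 5 threads. The three results are
  // handles to the fill op, the partial-reduction op, and the merge op.
  %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5] }
}
```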
Differential Revision: https://reviews.llvm.org/D137912
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 4cfec70b0c07e..d6fae79ecc9b0 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -626,7 +626,6 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
}];
}
-
def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_using_scf",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformEachOpTrait, TransformOpInterface]> {
@@ -714,6 +713,89 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
}];
}
+def TileReductionUsingForeachThreadOp :
+ Op<Transform_Dialect, "structured.tile_reduction_using_foreach_thread",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformEachOpTrait, TransformOpInterface]> {
+ let description = [{
+ Tile a PartialReductionOpInterface op to a tiled `scf.foreach_thread` performing
+ a partial reduction.
+
+ This transformation tiles the `target` along the reduction dimensions. It
+ creates a tensor initialized with the identity value. Then it creates a
+ `scf.foreach_thread` loop with the number of threads given by `num_threads`.
+ The op is tiled with a tile size equal to `ceildiv(size, num_threads)`.
+ All the partial reduction values are parallel-inserted to create a new
+ tensor. After the loop, a merge operation is created to do a final reduction
+ with the partial reduction tensor.
+
+ #### Return modes
+
+ The 3 returned handles point to:
+ - the fill op used to initialize the neutral element,
+ - the parallel tiled op and
+ - the result-combining op.
+
+ #### Example:
+
+ ```
+ %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<?x?xf32>)
+ outs(%out : tensor<?xf32>) {
+ ^bb0(%arg7: f32, %arg9: f32):
+ %1 = arith.addf %arg7, %arg9 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?xf32>
+ return %red : tensor<?xf32>
+ ```
+
+ is transformed into:
+
+ ```
+ %0 = tensor.empty(%dim_1) : tensor<?x5xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x5xf32>) -> tensor<?x5xf32>
+ %2 = scf.foreach_thread (%arg2) in (%c5) shared_outs(%arg3 = %1) -> (tensor<?x5xf32>) {
+ %4 = affine.min #map(%arg2)[%dim_0]
+ %5 = affine.max #map1(%4)
+ %extracted_slice = tensor.extract_slice %arg3[0, %arg2] [%dim, 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
+ %6 = affine.apply #map2(%arg2)[%dim_0]
+ %extracted_slice_2 = tensor.extract_slice %arg0[0, %6] [%dim, %5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ %extracted_slice_3 = tensor.extract_slice %extracted_slice[0] [%dim] [1] : tensor<?xf32> to tensor<?xf32>
+ %7 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice_2 : tensor<?x?xf32>) outs(%extracted_slice_3 : tensor<?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %9 = arith.addf %in, %out : f32
+ linalg.yield %9 : f32
+ } -> tensor<?xf32>
+ scf.foreach_thread.perform_concurrently {
+ tensor.parallel_insert_slice %7 into %arg3[0, %arg2] [%dim, 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
+ }
+ } {thread_dim_mapping = []}
+ %3 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<?x5xf32>) outs(%arg1 : tensor<?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %4 = arith.addf %in, %out : f32
+ linalg.yield %4 : f32
+ } -> tensor<?xf32>
+ ```
+ }];
+
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads);
+ let results = (outs PDL_Operation:$fill_op,
+ PDL_Operation:$split_linalg_op,
+ PDL_Operation:$combining_linalg_op);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::linalg::LinalgOp target,
+ ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
def TileOp : Op<Transform_Dialect, "structured.tile",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 758386fbc5bc1..dcdc532312460 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -445,6 +445,47 @@ tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
ArrayRef<OpFoldResult> tileSizes,
Optional<ArrayAttr> mapping);
+/// Transformation information returned after reduction tiling.
+struct ForeachThreadReductionTilingResult {
+ /// The partial reduction tiled op generated.
+ Operation *parallelTiledOp;
+ /// The final reduction operation merging all the partial reductions.
+ Operation *mergeOp;
+ /// The op initializing the tensor used for partial reductions.
+ Operation *initialOp;
+ /// The `scf.foreach_thread` operation that iterates over the tiles.
+ scf::ForeachThreadOp loops;
+};
+
+/// Method to tile a reduction to parallel iterations computing partial
+/// reductions. After the loop, all the partial reductions are merged into a final
+/// reduction. For example, the following sequence
+///
+/// ```mlir
+/// %0 = linalg.generic %in ["parallel", "reduction"]
+/// : tensor<7x9xf32> -> tensor<7xf32>
+/// ```
+///
+/// is transformed into:
+///
+/// ```mlir
+/// %0 = linalg.fill ... : tensor<7x4xf32>
+/// %1 = scf.foreach_thread (%iv) in (%c4) shared_outs(%arg0 = %0)
+/// -> (tensor<7x4xf32>) {
+/// %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> to tensor<7xf32>
+/// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
+/// %4 = linalg.generic %2, %3 ["parallel", "reduction"]
+/// : tensor<7x?xf32> -> tensor<7xf32>
+/// %5 = tensor.insert_slice %4 into %arg0[0, %iv] : tensor<7x4xf32>
+/// }
+/// %6 = linalg.generic %1 ["parallel", "reduction"]
+/// : tensor<7x4xf32> -> tensor<7xf32>
+/// ```
+FailureOr<ForeachThreadReductionTilingResult>
+tileReductionUsingForeachThread(RewriterBase &b, PartialReductionOpInterface op,
+ ArrayRef<OpFoldResult> numThreads,
+ Optional<ArrayAttr> mapping);
+
/// All indices returned by IndexOp should be invariant with respect to
/// tiling. Therefore, if an operation is tiled, we have to transform the
/// indices accordingly, i.e. offset them by the values of the corresponding
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 960fabee05b59..ff8549253cff5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1165,6 +1165,39 @@ DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
return DiagnosedSilenceableFailure(success());
}
+//===----------------------------------------------------------------------===//
+// TileReductionUsingForeachThreadOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::TileReductionUsingForeachThreadOp::applyToOne(
+ linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
+ transform::TransformState &state) {
+ SimpleRewriter rewriter(getContext());
+ rewriter.setInsertionPoint(target);
+ SmallVector<int64_t> numThreads = extractFromI64ArrayAttr(getNumThreads());
+ SmallVector<OpFoldResult> numThreadResults;
+ for (int64_t num : numThreads) {
+ numThreadResults.push_back(rewriter.getIndexAttr(num));
+ }
+
+ FailureOr<linalg::ForeachThreadReductionTilingResult> result =
+ linalg::tileReductionUsingForeachThread(
+ rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
+ numThreadResults, /*mapping=*/llvm::None);
+
+ if (failed(result)) {
+ results.assign(3, nullptr);
+ Diagnostic diag(target->getLoc(), DiagnosticSeverity::Remark);
+ diag << "could not tile reduction in target.";
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+ }
+ results.push_back(result->initialOp);
+ results.push_back(result->parallelTiledOp);
+ results.push_back(result->mergeOp);
+ return DiagnosedSilenceableFailure(success());
+}
+
//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 284109e6bf3f2..a63ddc496117c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -211,58 +211,21 @@ static OpFoldResult buildMin(OpBuilder &b, Location loc,
vals);
}
-/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`. The
-/// tiling is specified by the number of tiles/threads `numThreads` and the
-/// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
-/// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i],
-/// numThreads[i])`. If non-empty, the `mapping` is added as an
-/// attribute to the resulting `scf.foreach_thread`. A zero tile sizes indicate
-/// that the dimension is not tiled, and can be thought of as tiling by the full
-/// size of data.
-/// It is the user's responsibility to ensure that `numThreads` is a valid
-/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
-/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will
-/// assume that `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
-static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
- RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
+/// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given
+/// number of threads.
+static void calculateTileOffsetsAndSizes(
+ RewriterBase &b, Location loc, scf::ForeachThreadOp foreachThreadOp,
+ ArrayRef<OpFoldResult> numThreads, SmallVector<Range> loopRanges,
+ bool omitTileOffsetBoundsCheck,
Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
- Optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
- Location loc = op->getLoc();
- OpBuilder::InsertionGuard g(b);
- SmallVector<Range> loopRanges = op.getIterationDomain(b);
- if (loopRanges.empty())
- return op->emitOpError("expected non-empty loop ranges");
- auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
- if (llvm::any_of(loopRanges, hasStrideOne))
- return op->emitOpError("only stride-1 supported atm");
-
- // Gather destination tensors.
- SmallVector<Value> dest;
- if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
- return op->emitOpError("failed to get destination tensors");
-
+ SmallVector<OpFoldResult> &tiledOffsets,
+ SmallVector<OpFoldResult> &tiledSizes) {
+ ValueRange threadIds = foreachThreadOp.getThreadIndices();
SmallVector<OpFoldResult> nonZeroNumThreads =
llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
return !isConstantIntValue(ofr, 0);
}));
- SmallVector<Value> materializedNonZeroNumThreads =
- llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
- return getValueOrCreateConstantIndexOp(b, loc, ofr);
- }));
-
- Operation *tiledOp = nullptr;
-
- // Create the ForeachThreadOp. We don't use the lambda body-builder
- // version because we require the use of RewriterBase in the body, so we
- // manually move the insertion point to the body below.
- scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
- loc, dest, ValueRange(materializedNonZeroNumThreads), mapping);
-
- // Fill out the ForeachThreadOp body.
- b.setInsertionPointToStart(foreachThreadOp.getBody(0));
- ValueRange threadIds = foreachThreadOp.getThreadIndices();
int64_t nLoops = loopRanges.size();
- SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
tiledOffsets.reserve(nLoops);
tiledSizes.reserve(nLoops);
for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
@@ -316,6 +279,61 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
tiledSizes.push_back(tileSizePerThread);
++threadIdIdx;
}
+}
+
+/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`. The
+/// tiling is specified by the number of tiles/threads `numThreads` and the
+/// optional nominal tile size `nominalTileSizes`. If `nominalTileSizes` is
+/// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i],
+/// numThreads[i])`. If non-empty, the `mapping` is added as an
+/// attribute to the resulting `scf.foreach_thread`. A zero tile size indicates
+/// that the dimension is not tiled, and can be thought of as tiling by the full
+/// size of data.
+/// It is the user's responsibility to ensure that `numThreads` is a valid
+/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
+/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will
+/// assume that `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
+static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
+ RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
+ Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
+ Optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
+ Location loc = op->getLoc();
+ OpBuilder::InsertionGuard g(b);
+ SmallVector<Range> loopRanges = op.getIterationDomain(b);
+ if (loopRanges.empty())
+ return op->emitOpError("expected non-empty loop ranges");
+ auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
+ if (llvm::any_of(loopRanges, hasStrideOne))
+ return op->emitOpError("only stride-1 supported atm");
+
+ // Gather destination tensors.
+ SmallVector<Value> dest;
+ if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
+ return op->emitOpError("failed to get destination tensors");
+
+ SmallVector<OpFoldResult> nonZeroNumThreads =
+ llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
+ return !isConstantIntValue(ofr, 0);
+ }));
+ SmallVector<Value> materializedNonZeroNumThreads =
+ llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
+ return getValueOrCreateConstantIndexOp(b, loc, ofr);
+ }));
+
+ Operation *tiledOp = nullptr;
+
+ // Create the ForeachThreadOp. We don't use the lambda body-builder
+ // version because we require the use of RewriterBase in the body, so we
+ // manually move the insertion point to the body below.
+ scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
+ loc, dest, ValueRange(materializedNonZeroNumThreads), mapping);
+
+ // Fill out the ForeachThreadOp body.
+ b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+ SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+ calculateTileOffsetsAndSizes(b, loc, foreachThreadOp, numThreads, loopRanges,
+ omitTileOffsetBoundsCheck, nominalTileSizes,
+ tiledOffsets, tiledSizes);
// Clone the tileable op and update its destination operands to use the output
// bbArgs of the ForeachThreadOp.
@@ -392,6 +410,140 @@ linalg::tileToForeachThreadOpUsingTileSizes(RewriterBase &b, TilingInterface op,
/*omitTileOffsetBoundsCheck=*/true);
}
+FailureOr<linalg::ForeachThreadReductionTilingResult>
+linalg::tileReductionUsingForeachThread(RewriterBase &b,
+ PartialReductionOpInterface op,
+ ArrayRef<OpFoldResult> numThreads,
+ Optional<ArrayAttr> mapping) {
+ Location loc = op.getLoc();
+ OpBuilder::InsertionGuard g(b);
+ // Ops implementing PartialReductionOpInterface are expected to implement
+ // TilingInterface.
+ auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
+ SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
+ if (op->getNumResults() != 1)
+ return b.notifyMatchFailure(
+ op, "don't support ops with multiple results for now");
+ SmallVector<utils::IteratorType> iterators =
+ tilingInterfaceOp.getLoopIteratorTypes();
+ SmallVector<unsigned> redDims;
+ cast<linalg::LinalgOp>(op.getOperation()).getReductionDims(redDims);
+ if (redDims.size() != 1)
+ return b.notifyMatchFailure(
+ op, "only support ops with one reduction dimension.");
+ int reductionDim = static_cast<int>(redDims.front());
+ // 1. Create the initial tensor value.
+ FailureOr<Operation *> identityTensor =
+ op.generateInitialTensorForPartialReduction(b, loc, numThreads,
+ reductionDim);
+ if (failed(identityTensor))
+ return b.notifyMatchFailure(op,
+ "cannot create a tensor of identity value.");
+
+ // Gather destination tensors.
+ SmallVector<Value> dest;
+ if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
+ return b.notifyMatchFailure(op, "failed to get destination tensors");
+
+ Operation *tiledOp = nullptr;
+
+ SmallVector<OpFoldResult> nonZeroNumThreads =
+ llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
+ return !isConstantIntValue(ofr, 0);
+ }));
+ SmallVector<Value> materializedNonZeroNumThreads =
+ llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
+ return getValueOrCreateConstantIndexOp(b, loc, ofr);
+ }));
+
+ // 2. Create the ForeachThreadOp with an empty region.
+ scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
+ loc, identityTensor.value()->getResults(),
+ ValueRange(materializedNonZeroNumThreads), mapping);
+
+ // 3. Calculate the tile offsets and sizes.
+ b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+ SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+ calculateTileOffsetsAndSizes(
+ b, loc, foreachThreadOp, numThreads, iterationDomain,
+ /*omitTileOffsetBoundsCheck =*/false,
+ /*nominalTileSizes=*/llvm::None, tiledOffsets, tiledSizes);
+
+ // 4. Clone the tileable op and update its destination operands to use the
+ // output bbArgs of the ForeachThreadOp.
+ ArrayRef<BlockArgument> destBbArgs =
+ foreachThreadOp.getOutputBlockArguments();
+ Operation *clonedOp = b.clone(*op.getOperation());
+ auto destinationStyleOp = cast<DestinationStyleOpInterface>(clonedOp);
+ for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
+ auto *it = llvm::find(dest, outOperand->get());
+ assert(it != dest.end() && "dest operand not found in dest");
+ unsigned destNum = std::distance(dest.begin(), it);
+ SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
+ SmallVector<OpFoldResult> outOffsets(numThreads.size(), b.getIndexAttr(0));
+ SmallVector<OpFoldResult> sizes = tiledSizes;
+ sizes[reductionDim] = b.getIndexAttr(1);
+ outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
+ // TODO: use SubsetExtractOpInterface once it is available.
+ Value partial = b.create<tensor::ExtractSliceOp>(
+ loc, outOperand->get().getType().cast<RankedTensorType>(),
+ destBbArgs[destNum], outOffsets, sizes, strides);
+ outOperand->set(partial);
+ }
+
+ // 5. Tile the cloned op and delete the clone.
+ SmallVector<Operation *> tiledOps =
+ cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
+ tiledSizes);
+ b.eraseOp(clonedOp);
+ assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+ tiledOp = tiledOps.front();
+
+ // 6. Insert the partial reductions back into a new tensor.
+ auto tiledInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
+ assert(tiledInterfaceOp && "Tiled op does not implement TilingInterface");
+ OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
+ for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
+ tiledInterfaceOp->getResults(), destBbArgs)) {
+ b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
+ SmallVector<OpFoldResult> resultOffsets, resultSizes;
+ if (failed(tilingInterfaceOp.getResultTilePosition(
+ b, std::get<0>(it), tiledOffsets, tiledSizes, resultOffsets,
+ resultSizes)))
+ return op->emitOpError("output offsets couldn't be calculated");
+ SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
+ int64_t offIdx = 0;
+ int64_t sizeIdx = 0;
+ for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
+ if (i == reductionDim) {
+ resultOffsetsRank.push_back(foreachThreadOp.getThreadIndices().front());
+ resultSizesRank.push_back(b.getIndexAttr(1));
+ continue;
+ }
+ resultOffsetsRank.push_back(resultOffsets[offIdx++]);
+ resultSizesRank.push_back(resultSizes[sizeIdx++]);
+ }
+
+ SmallVector<OpFoldResult> strides(resultSizesRank.size(),
+ b.getIndexAttr(1));
+ b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
+ b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
+ std::get<2>(it), resultOffsetsRank,
+ resultSizesRank, strides);
+ }
+ // 7. Merge the partial reductions.
+ b.setInsertionPointAfter(foreachThreadOp);
+ Operation *mergeOp =
+ op.mergeReductions(b, loc, foreachThreadOp->getResults(), reductionDim);
+ b.replaceOp(op, mergeOp->getResults());
+ ForeachThreadReductionTilingResult results;
+ results.initialOp = identityTensor.value();
+ results.loops = foreachThreadOp;
+ results.parallelTiledOp = tiledOp;
+ results.mergeOp = mergeOp;
+ return results;
+}
+
// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as size of tile) is taken from a given
// ExtractSliceOp `sliceOp`.
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index dad2f8476d1ff..131c7a6193bde 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -canonicalize -cse | FileCheck %s
func.func @reduction_tile(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
%red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -86,3 +86,114 @@ transform.sequence failures(propagate) {
// CHECK: }
// CHECK: linalg.generic
// CHECK: return
+
+// -----
+
+func.func @reduction_tile_parallel(
+ %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
+ %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<?x?xf32>)
+ outs(%out : tensor<?xf32>) {
+ ^bb0(%arg7: f32, %arg9: f32):
+ %1 = arith.mulf %arg7, %arg7 : f32
+ %2 = arith.addf %1, %arg9 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?xf32>
+ return %red : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5] }
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 5))>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: func @reduction_tile_parallel(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
+// CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
+// CHECK: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
+// CHECK: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
+// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
+// CHECK: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
+// CHECK: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0] [%[[D0]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK: %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?xf32>) {
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<?xf32>
+// CHECK: scf.foreach_thread.perform_concurrently {
+// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<?xf32>
+// CHECK: return %[[R]] : tensor<?xf32>
+
+// -----
+
+func.func @matmul_tile_parallel(
+ %A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %matmul = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%out: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %matmul : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 0, 5] }
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 5))>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK: func @matmul_tile_parallel(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[D3:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[D4:.*]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor<?x?x5xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
+// CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
+// CHECK: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
+// CHECK: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
+// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32>
+// CHECK: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
+// CHECK: %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0, 0] [%[[D0]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[PARTIAL:.+]] = linalg.matmul ins(%[[INCHUNKA]], %[[INCHUNKB]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: scf.foreach_thread.perform_concurrently {
+// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x5xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) {
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<?x?xf32>
+// CHECK: return %[[R]] : tensor<?x?xf32>