[Mlir-commits] [mlir] f7fda6b - [mlir][linalg] Add extra parameter to tiling reduction to foreach_thread
Thomas Raoux
llvmlistbot at llvm.org
Wed Dec 7 10:37:23 PST 2022
Author: Thomas Raoux
Date: 2022-12-07T18:37:05Z
New Revision: f7fda6ba4a7bca12bc7ef62d27895aa24482eda5
URL: https://github.com/llvm/llvm-project/commit/f7fda6ba4a7bca12bc7ef62d27895aa24482eda5
DIFF: https://github.com/llvm/llvm-project/commit/f7fda6ba4a7bca12bc7ef62d27895aa24482eda5.diff
LOG: [mlir][linalg] Add extra parameter to tiling reduction to foreach_thread
This adds a `tile_sizes` parameter; when it is used, the tiles are
cyclically distributed onto the threads of the scf.foreach_thread op.
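A minimal usage sketch, mirroring the new test case added to
transform-tile-reduction.mlir below: passing both `num_threads` and
`tile_sizes` makes each thread reduce its cyclically assigned tiles (of size 3
here) inside an scf.for loop, and the partial results are merged after the
scf.foreach_thread. The three results are handles to the fill op, the partial
(split) linalg op, and the combining linalg op.

  transform.sequence failures(propagate) {
  ^bb0(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5], tile_sizes = [0, 3] }
  }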
Differential Revision: https://reviews.llvm.org/D139474
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 3ea0a66625776..f7b0c03ca2f07 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -751,6 +751,8 @@ def TileReductionUsingForeachThreadOp :
All the partial reduction values are parallel inserted to create a new
tensor. After the loop, a merge operation is created to do a final reduction
with the partial reduction tensor.
+ If an extra `tile_sizes` parameter is passed, the tiles are cyclically
+ distributed on the threads of the `scf.foreach_thread` loop.
#### Return modes
@@ -804,7 +806,8 @@ def TileReductionUsingForeachThreadOp :
}];
let arguments = (ins PDL_Operation:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads);
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
let results = (outs PDL_Operation:$fill_op,
PDL_Operation:$split_linalg_op,
PDL_Operation:$combining_linalg_op);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a58c9dc23c1fc..d7603d2c3dd1b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -496,7 +496,8 @@ struct ForeachThreadReductionTilingResult {
FailureOr<ForeachThreadReductionTilingResult>
tileReductionUsingForeachThread(RewriterBase &b, PartialReductionOpInterface op,
ArrayRef<OpFoldResult> numThreads,
- Optional<ArrayAttr> mapping);
+ ArrayRef<OpFoldResult> tileSizes = {},
+ Optional<ArrayAttr> mapping = llvm::None);
/// All indices returned by IndexOp should be invariant with respect to
/// tiling. Therefore, if an operation is tiled, we have to transform the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b46349874bd4f..9e94f101349a2 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1217,16 +1217,12 @@ transform::TileReductionUsingForeachThreadOp::applyToOne(
transform::TransformState &state) {
TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
- SmallVector<int64_t> numThreads = extractFromI64ArrayAttr(getNumThreads());
- SmallVector<OpFoldResult> numThreadResults;
- for (int64_t num : numThreads) {
- numThreadResults.push_back(rewriter.getIndexAttr(num));
- }
-
+ SmallVector<OpFoldResult> numThreads = getAsOpFoldResult(getNumThreads());
+ SmallVector<OpFoldResult> tileSizes = getAsOpFoldResult(getTileSizes());
FailureOr<linalg::ForeachThreadReductionTilingResult> result =
linalg::tileReductionUsingForeachThread(
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
- numThreadResults, /*mapping=*/std::nullopt);
+ numThreads, tileSizes, /*mapping=*/std::nullopt);
if (failed(result)) {
results.assign(3, nullptr);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index cde33ce740e7d..8c34c42ea3ff9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -410,152 +411,6 @@ linalg::tileToForeachThreadOpUsingTileSizes(RewriterBase &b, TilingInterface op,
/*omitTileOffsetBoundsCheck=*/true);
}
-FailureOr<linalg::ForeachThreadReductionTilingResult>
-linalg::tileReductionUsingForeachThread(RewriterBase &b,
- PartialReductionOpInterface op,
- ArrayRef<OpFoldResult> numThreads,
- Optional<ArrayAttr> mapping) {
- Location loc = op.getLoc();
- OpBuilder::InsertionGuard g(b);
- // Ops implementing PartialReductionOpInterface are expected to implement
- // TilingInterface.
- auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
- SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
- if (op->getNumResults() != 1)
- return b.notifyMatchFailure(
- op, "don't support ops with multiple results for now");
- SmallVector<utils::IteratorType> iterators =
- tilingInterfaceOp.getLoopIteratorTypes();
- SmallVector<unsigned> redDims;
- cast<linalg::LinalgOp>(op.getOperation()).getReductionDims(redDims);
- if (redDims.size() != 1)
- return b.notifyMatchFailure(
- op, "only support ops with one reduction dimension.");
- int reductionDim = static_cast<int>(redDims.front());
- // 1. create the inital tensor value.
- FailureOr<Operation *> identityTensor =
- op.generateInitialTensorForPartialReduction(b, loc, numThreads,
- reductionDim);
- if (failed(identityTensor))
- return b.notifyMatchFailure(op,
- "cannot create a tensor of identity value.");
-
- // Gather destination tensors.
- SmallVector<Value> dest;
- if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
- return b.notifyMatchFailure(op, "failed to get destination tensors");
-
- Operation *tiledOp = nullptr;
-
- SmallVector<OpFoldResult> nonZeroNumThreads =
- llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
- return !isConstantIntValue(ofr, 0);
- }));
- SmallVector<Value> materializedNonZeroNumThreads =
- llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
- return getValueOrCreateConstantIndexOp(b, loc, ofr);
- }));
-
- // 2. Create the ForeachThreadOp with an empty region.
- scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
- loc, identityTensor.value()->getResults(),
- ValueRange(materializedNonZeroNumThreads), mapping);
-
- // 3. calculate the tile offsets and sizes.
- b.setInsertionPointToStart(foreachThreadOp.getBody(0));
- SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
- calculateTileOffsetsAndSizes(
- b, loc, foreachThreadOp, numThreads, iterationDomain,
- /*omitTileOffsetBoundsCheck =*/false,
- /*nominalTileSizes=*/std::nullopt, tiledOffsets, tiledSizes);
-
- // 4. Clone the tileable op and update its destination operands to use the
- // output bbArgs of the ForeachThreadOp.
- ArrayRef<BlockArgument> destBbArgs =
- foreachThreadOp.getOutputBlockArguments();
- Operation *clonedOp = b.clone(*op.getOperation());
- auto destinationStyleOp = cast<DestinationStyleOpInterface>(clonedOp);
- for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
- auto *it = llvm::find(dest, outOperand->get());
- assert(it != dest.end() && "dest operand not found in dest");
- unsigned destNum = std::distance(dest.begin(), it);
- SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
- SmallVector<OpFoldResult> outOffsets(numThreads.size(), b.getIndexAttr(0));
- SmallVector<OpFoldResult> sizes = tiledSizes;
- sizes[reductionDim] = b.getIndexAttr(1);
- outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
- // TODO: use SubsetExtractOpInterface once it is available.
- Value patial = b.create<tensor::ExtractSliceOp>(
- loc, outOperand->get().getType().cast<RankedTensorType>(),
- destBbArgs[destNum], outOffsets, sizes, strides);
- outOperand->set(patial);
- }
-
- // 5. Tile the cloned op and delete the clone.
- SmallVector<Operation *> tiledOps =
- cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
- tiledSizes);
- b.eraseOp(clonedOp);
- assert(tiledOps.size() == 1 && "expected a single produced tiled op");
- tiledOp = tiledOps.front();
-
- // 6. Insert the partial reductions back into a new tensor.
- auto tiledInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
- assert(tiledInterfaceOp && "Tiled op does not implement TilingInterface");
- OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
- for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
- tiledInterfaceOp->getResults(), destBbArgs)) {
- b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
- SmallVector<OpFoldResult> resultOffsets, resultSizes;
- if (failed(tilingInterfaceOp.getResultTilePosition(
- b, std::get<0>(it), tiledOffsets, tiledSizes, resultOffsets,
- resultSizes)))
- return op->emitOpError("output offsets couldn't be calculated");
- SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
- int64_t offIdx = 0;
- int64_t sizeIdx = 0;
- for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
- if (i == reductionDim) {
- resultOffsetsRank.push_back(foreachThreadOp.getThreadIndices().front());
- resultSizesRank.push_back(b.getIndexAttr(1));
- continue;
- }
- resultOffsetsRank.push_back(resultOffsets[offIdx++]);
- resultSizesRank.push_back(resultSizes[sizeIdx++]);
- }
-
- SmallVector<OpFoldResult> strides(resultSizesRank.size(),
- b.getIndexAttr(1));
- b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
- b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
- std::get<2>(it), resultOffsetsRank,
- resultSizesRank, strides);
- }
- // 7. Merge the partial reductions.
- b.setInsertionPointAfter(foreachThreadOp);
- Operation *mergeOp =
- op.mergeReductions(b, loc, foreachThreadOp->getResults(), reductionDim);
- b.replaceOp(op, mergeOp->getResults());
- ForeachThreadReductionTilingResult results;
- results.initialOp = identityTensor.value();
- results.loops = foreachThreadOp;
- results.parallelTiledOp = tiledOp;
- results.mergeOp = mergeOp;
- return results;
-}
-
-// Insert a tile `source` into the destination tensor `dest`. The position at
-// which the tile is inserted (as well as size of tile) is taken from a given
-// ExtractSliceOp `sliceOp`.
-static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
- tensor::ExtractSliceOp sliceOp, Value source,
- Value dest) {
- return b.create<tensor::InsertSliceOp>(
- loc, sliceOp.getSource().getType(), source, dest, sliceOp.getOffsets(),
- sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
- sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
-}
-
template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
@@ -707,6 +562,165 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}
+FailureOr<linalg::ForeachThreadReductionTilingResult>
+linalg::tileReductionUsingForeachThread(RewriterBase &b,
+ PartialReductionOpInterface op,
+ ArrayRef<OpFoldResult> numThreads,
+ ArrayRef<OpFoldResult> tileSizes,
+ Optional<ArrayAttr> mapping) {
+ Location loc = op.getLoc();
+ OpBuilder::InsertionGuard g(b);
+ // Ops implementing PartialReductionOpInterface are expected to implement
+ // TilingInterface.
+ auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
+ SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
+ if (op->getNumResults() != 1)
+ return b.notifyMatchFailure(
+ op, "don't support ops with multiple results for now");
+ SmallVector<utils::IteratorType> iterators =
+ tilingInterfaceOp.getLoopIteratorTypes();
+ SmallVector<unsigned> redDims;
+ cast<linalg::LinalgOp>(op.getOperation()).getReductionDims(redDims);
+ if (redDims.size() != 1)
+ return b.notifyMatchFailure(
+ op, "only support ops with one reduction dimension.");
+ if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
+ return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
+ "many elements as number of threads");
+ int reductionDim = static_cast<int>(redDims.front());
+ // 1. Create the initial tensor value.
+ FailureOr<Operation *> identityTensor =
+ op.generateInitialTensorForPartialReduction(b, loc, numThreads,
+ reductionDim);
+ if (failed(identityTensor))
+ return b.notifyMatchFailure(op,
+ "cannot create a tensor of identity value.");
+
+ // Gather destination tensors.
+ SmallVector<Value> dest;
+ if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
+ return b.notifyMatchFailure(op, "failed to get destination tensors");
+
+ Operation *tiledOp = nullptr;
+
+ SmallVector<OpFoldResult> nonZeroNumThreads =
+ llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
+ return !isConstantIntValue(ofr, 0);
+ }));
+ SmallVector<Value> materializedNonZeroNumThreads =
+ getAsValues(b, loc, nonZeroNumThreads);
+
+ // 2. Create the ForeachThreadOp with an empty region.
+ scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
+ loc, identityTensor.value()->getResults(),
+ ValueRange(materializedNonZeroNumThreads), mapping);
+
+ // 3. calculate the tile offsets and sizes.
+ b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+ SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+ calculateTileOffsetsAndSizes(
+ b, loc, foreachThreadOp, numThreads, iterationDomain,
+ /*omitTileOffsetBoundsCheck =*/false,
+ /*nominalTileSizes=*/std::nullopt, tiledOffsets, tiledSizes);
+
+ // 4. Clone the tileable op and update its destination operands to use the
+ // output bbArgs of the ForeachThreadOp.
+ ArrayRef<BlockArgument> destBbArgs =
+ foreachThreadOp.getOutputBlockArguments();
+ Operation *clonedOp = b.clone(*op.getOperation());
+ b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+ auto destinationStyleOp = cast<DestinationStyleOpInterface>(clonedOp);
+ for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands()) {
+ auto *it = llvm::find(dest, initOperand->get());
+ assert(it != dest.end() && "dest operand not found in dest");
+ unsigned destNum = std::distance(dest.begin(), it);
+ SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
+ SmallVector<OpFoldResult> outOffsets(numThreads.size(), b.getIndexAttr(0));
+ SmallVector<OpFoldResult> sizes = tiledSizes;
+ sizes[reductionDim] = b.getIndexAttr(1);
+ outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
+ // TODO: use SubsetExtractOpInterface once it is available.
+ Value partial = b.create<tensor::ExtractSliceOp>(
+ loc, initOperand->get().getType().cast<RankedTensorType>(),
+ destBbArgs[destNum], outOffsets, sizes, strides);
+ initOperand->set(partial);
+ }
+ b.setInsertionPoint(clonedOp);
+
+ // 5. Tile the cloned op and delete the clone.
+ if (tileSizes.empty()) {
+ SmallVector<Operation *> tiledOps =
+ cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
+ tiledSizes);
+ assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+ tiledOp = tiledOps.front();
+ } else {
+ LinalgTilingOptions options;
+ auto tiled = tileLinalgOpImpl<scf::ForOp>(b, cast<LinalgOp>(clonedOp),
+ tileSizes, options);
+ SmallVector<Value> ids = foreachThreadOp.getThreadIndices();
+ mapLoopToProcessorIds(cast<scf::ForOp>(tiled->loops.back()), ids,
+ materializedNonZeroNumThreads);
+ assert(tiled->loops.size() == 1 && "expected a single produced loop");
+ tiledOp = tiled->loops.front();
+ }
+ b.eraseOp(clonedOp);
+
+ // 6. Insert the partial reductions back into a new tensor.
+ b.setInsertionPointAfter(tiledOp);
+ OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
+ for (auto [index, result, bbArg] :
+ llvm::zip(llvm::seq<unsigned>(0, dest.size()), tiledOp->getResults(),
+ destBbArgs)) {
+ b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
+ SmallVector<OpFoldResult> resultOffsets, resultSizes;
+ if (failed(tilingInterfaceOp.getResultTilePosition(
+ b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
+ return op->emitOpError("output offsets couldn't be calculated");
+ SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
+ int64_t offIdx = 0;
+ int64_t sizeIdx = 0;
+ for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
+ if (i == reductionDim) {
+ resultOffsetsRank.push_back(foreachThreadOp.getThreadIndices().front());
+ resultSizesRank.push_back(b.getIndexAttr(1));
+ continue;
+ }
+ resultOffsetsRank.push_back(resultOffsets[offIdx++]);
+ resultSizesRank.push_back(resultSizes[sizeIdx++]);
+ }
+
+ SmallVector<OpFoldResult> strides(resultSizesRank.size(),
+ b.getIndexAttr(1));
+ b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
+ b.create<tensor::ParallelInsertSliceOp>(
+ loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
+ }
+ // 7. Merge the partial reductions.
+ b.setInsertionPointAfter(foreachThreadOp);
+ Operation *mergeOp =
+ op.mergeReductions(b, loc, foreachThreadOp->getResults(), reductionDim);
+ b.replaceOp(op, mergeOp->getResults());
+ ForeachThreadReductionTilingResult results;
+ results.initialOp = identityTensor.value();
+ results.loops = foreachThreadOp;
+ results.parallelTiledOp = tiledOp;
+ results.mergeOp = mergeOp;
+ return results;
+}
+
+// Insert a tile `source` into the destination tensor `dest`. The position at
+// which the tile is inserted (as well as size of tile) is taken from a given
+// ExtractSliceOp `sliceOp`.
+static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
+ tensor::ExtractSliceOp sliceOp, Value source,
+ Value dest) {
+ return b.create<tensor::InsertSliceOp>(
+ loc, sliceOp.getSource().getType(), source, dest, sliceOp.getOffsets(),
+ sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
+ sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
+}
+
template <typename LoopTy>
FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 82925906e78c0..ad2dc0a4124d8 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -126,9 +126,9 @@ transform.sequence failures(propagate) {
// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
// CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
-// CHECK: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
-// CHECK: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
-// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
+// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
+// CHECK-DAG: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
+// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
// CHECK: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
// CHECK: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0] [%[[D0]]] [1] : tensor<?xf32> to tensor<?xf32>
@@ -180,9 +180,9 @@ transform.sequence failures(propagate) {
// CHECK: %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor<?x?x5xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
// CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
-// CHECK: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
-// CHECK: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
-// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32>
+// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
+// CHECK-DAG: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
+// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32>
// CHECK: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
// CHECK: %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
@@ -197,3 +197,68 @@ transform.sequence failures(propagate) {
// CHECK: linalg.yield
// CHECK: } -> tensor<?x?xf32>
// CHECK: return %[[R]] : tensor<?x?xf32>
+
+// -----
+
+func.func @reduction_tile_parallel_cyclic_dist(
+ %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
+ %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<?x?xf32>)
+ outs(%out : tensor<?xf32>) {
+ ^bb0(%arg7: f32, %arg9: f32):
+ %1 = arith.mulf %arg7, %arg7 : f32
+ %2 = arith.addf %1, %arg9 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?xf32>
+ return %red : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5], tile_sizes = [0, 3] }
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0)>
+
+// CHECK: func @reduction_tile_parallel_cyclic_dist(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
+// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
+// CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
+// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
+// CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[LB:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]
+// CHECK: %[[CARRY:.+]] = scf.for %[[IV1:.+]] = %[[LB]] to %[[D1]] step %[[C15]] iter_args(%[[ACC:.+]] = %[[ET]]) -> (tensor<?xf32>) {
+// CHECK: %[[TS0:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[D1]]]
+// CHECK: %[[D3:.+]] = tensor.dim %[[ACC]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV1]]] [%[[D0]], %[[TS0]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ACC]][0] [%[[D3]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK: %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?xf32>) {
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<?xf32>
+// CHECK: %[[INS:.+]] = tensor.insert_slice %[[PARTIAL]] into %[[ACC]][0] [%[[D3]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK: scf.yield %[[INS]] : tensor<?xf32>
+// CHECK: }
+// CHECK: scf.foreach_thread.perform_concurrently {
+// CHECK: tensor.parallel_insert_slice %[[CARRY]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<?xf32>
+// CHECK: return %[[R]] : tensor<?xf32>