[Mlir-commits] [mlir] 06ca5c8 - [mlir][Linalg] Apply fixes to TileReductionUsingForeachThreadOp
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Dec 9 07:51:18 PST 2022
Author: Nicolas Vasilache
Date: 2022-12-09T07:51:12-08:00
New Revision: 06ca5c81a4d88d9c33018d5a33e38c449109e5d6
URL: https://github.com/llvm/llvm-project/commit/06ca5c81a4d88d9c33018d5a33e38c449109e5d6
DIFF: https://github.com/llvm/llvm-project/commit/06ca5c81a4d88d9c33018d5a33e38c449109e5d6.diff
LOG: [mlir][Linalg] Apply fixes to TileReductionUsingForeachThreadOp
In the process, numerous insertion-point issues were found and fixed.
RAII on insertion points is now used more diligently (see the sketch below).
Differential Revision: https://reviews.llvm.org/D139714
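The central idiom of the fix is worth spelling out: rather than saving and
restoring insertion points by hand, every insertion-point change is scoped
with OpBuilder::InsertionGuard. A minimal sketch of the idiom, with
illustrative names that are not taken from the patch:

    #include "mlir/IR/Builders.h"

    using namespace mlir;

    static void buildBeforeTerminator(OpBuilder &b, Operation *terminator) {
      // The guard snapshots the current insertion point and restores it when
      // this scope exits, including on early returns.
      OpBuilder::InsertionGuard g(b);
      b.setInsertionPoint(terminator);
      // Ops created here land immediately before `terminator`.
    }
    // The caller's insertion point is unchanged once the guard is destroyed.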
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f7b0c03ca2f07..f2b3fb795723c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -796,7 +796,7 @@ def TileReductionUsingForeachThreadOp :
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %7 into %arg3[0, %arg2] [%dim, 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
}
- } {thread_dim_mapping = []}
+ } {mapping = []}
%3 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<?x5xf32>) outs(%arg1 : tensor<?xf32>) {
^bb0(%in: f32, %out: f32):
%4 = arith.addf %in, %out : f32
@@ -807,7 +807,8 @@ def TileReductionUsingForeachThreadOp :
let arguments = (ins PDL_Operation:$target,
DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+ OptionalAttr<DeviceMappingArrayAttr>:$mapping);
let results = (outs PDL_Operation:$fill_op,
PDL_Operation:$split_linalg_op,
PDL_Operation:$combining_linalg_op);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9e94f101349a2..8fdd6cb2b6dee 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1222,7 +1222,7 @@ transform::TileReductionUsingForeachThreadOp::applyToOne(
FailureOr<linalg::ForeachThreadReductionTilingResult> result =
linalg::tileReductionUsingForeachThread(
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
- numThreads, tileSizes, /*mapping=*/std::nullopt);
+ numThreads, tileSizes, getMapping());
if (failed(result)) {
results.assign(3, nullptr);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 8c34c42ea3ff9..f5cbd81762a8b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -25,8 +25,11 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include <utility>
@@ -221,6 +224,9 @@ static void calculateTileOffsetsAndSizes(
Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
SmallVector<OpFoldResult> &tiledOffsets,
SmallVector<OpFoldResult> &tiledSizes) {
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+
ValueRange threadIds = foreachThreadOp.getThreadIndices();
SmallVector<OpFoldResult> nonZeroNumThreads =
llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
@@ -300,6 +306,7 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
Optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
Location loc = op->getLoc();
OpBuilder::InsertionGuard g(b);
+
SmallVector<Range> loopRanges = op.getIterationDomain(b);
if (loopRanges.empty())
return op->emitOpError("expected non-empty loop ranges");
@@ -323,54 +330,64 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
Operation *tiledOp = nullptr;
- // Create the ForeachThreadOp. We don't use the lambda body-builder
+ // 1. Create the ForeachThreadOp. We don't use the lambda body-builder
// version because we require the use of RewriterBase in the body, so we
// manually move the insertion point to the body below.
scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
loc, dest, ValueRange(materializedNonZeroNumThreads), mapping);
- // Fill out the ForeachThreadOp body.
- b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+ // 2. Fill out the ForeachThreadOp body.
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
calculateTileOffsetsAndSizes(b, loc, foreachThreadOp, numThreads, loopRanges,
omitTileOffsetBoundsCheck, nominalTileSizes,
tiledOffsets, tiledSizes);
- // Clone the tileable op and update its destination operands to use the output
- // bbArgs of the ForeachThreadOp.
+ // 3. Clone the tileable op and update its destination operands to use the
+ // output bbArgs of the ForeachThreadOp.
ArrayRef<BlockArgument> destBbArgs =
foreachThreadOp.getOutputBlockArguments();
- Operation *clonedOp = b.clone(*op.getOperation());
- auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
- if (destinationStyleOp) {
- for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
- auto *it = llvm::find(dest, outOperand->get());
- assert(it != dest.end() && "dest operand not found in dest");
- unsigned destNum = std::distance(dest.begin(), it);
- outOperand->set(destBbArgs[destNum]);
+ {
+ // 3.a. RAII guard, inserting within foreachThreadOp, before terminator.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(foreachThreadOp.getTerminator());
+ Operation *clonedOp = b.clone(*op.getOperation());
+ auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
+ if (destinationStyleOp) {
+ for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
+ auto *it = llvm::find(dest, outOperand->get());
+ assert(it != dest.end() && "dest operand not found in dest");
+ unsigned destNum = std::distance(dest.begin(), it);
+ outOperand->set(destBbArgs[destNum]);
+ }
}
- }
- // Tile the cloned op and delete the clone.
- SmallVector<Operation *> tiledOps =
- cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
- tiledSizes);
- b.eraseOp(clonedOp);
- assert(tiledOps.size() == 1 && "expected a single produced tiled op");
- tiledOp = tiledOps.front();
+ // 4. Tile the cloned op and delete the clone.
+ SmallVector<Operation *> tiledOps =
+ cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
+ tiledSizes);
+ b.eraseOp(clonedOp);
+ assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+ tiledOp = tiledOps.front();
+ }
+ // 5. Parallel insert back into the result tensor.
auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
- OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
tilingInterfaceOp->getResults(), destBbArgs)) {
- b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
+ // 5.a. Partial subset information is inserted just before the terminator.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(foreachThreadOp.getTerminator());
+
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
tiledSizes, resultOffsets,
resultSizes)))
return op->emitOpError("output offsets couldn't be calculated");
SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
+
+ // 5.b. Parallel insertions are inserted at the end of the combining
+ // terminator.
b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
std::get<2>(it), resultOffsets,
@@ -415,6 +432,8 @@ template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
const LinalgTilingOptions &options) {
+ OpBuilder::InsertionGuard g(b);
+
auto nLoops = op.getNumLoops();
// Initial tile sizes may be too big, only take the first nLoops.
tileSizes = tileSizes.take_front(nLoops);
@@ -570,17 +589,35 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
Optional<ArrayAttr> mapping) {
Location loc = op.getLoc();
OpBuilder::InsertionGuard g(b);
+
// Ops implementing PartialReductionOpInterface are expected to implement
// TilingInterface.
+ // TODO: proper core mechanism to tie interfaces together.
auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
+
+ // Ops implementing PartialReductionOpInterface are not necessarily expected
+ // to implement TilingInterface. This cast is unsafe atm.
+ // TODO: proper core mechanism to tie interfaces together.
+ // TODO: this function requires a pair of interfaces.
+ auto destinationStyleOp =
+ dyn_cast<DestinationStyleOpInterface>(op.getOperation());
+ if (!destinationStyleOp)
+ return b.notifyMatchFailure(op, "not a destination style op");
+
+ // Actually this only works for Linalg ops atm.
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
+ if (!linalgOp)
+ return b.notifyMatchFailure(op, "not a linalg op");
+
SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
if (op->getNumResults() != 1)
return b.notifyMatchFailure(
op, "don't support ops with multiple results for now");
+
SmallVector<utils::IteratorType> iterators =
tilingInterfaceOp.getLoopIteratorTypes();
SmallVector<unsigned> redDims;
- cast<linalg::LinalgOp>(op.getOperation()).getReductionDims(redDims);
+ linalgOp.getReductionDims(redDims);
if (redDims.size() != 1)
return b.notifyMatchFailure(
op, "only support ops with one reduction dimension.");
@@ -588,7 +625,8 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
"many elements as number of threads");
int reductionDim = static_cast<int>(redDims.front());
- // 1. create the inital tensor value.
+
+ // 1. Create the initial tensor value.
FailureOr<Operation *> identityTensor =
op.generateInitialTensorForPartialReduction(b, loc, numThreads,
reductionDim);
@@ -615,8 +653,8 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
loc, identityTensor.value()->getResults(),
ValueRange(materializedNonZeroNumThreads), mapping);
- // 3. calculate the tile offsets and sizes.
- b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+ // 3. Calculate the tile offsets and sizes for the subsequent loop that will
+ // be nested under `foreachThreadOp`.
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
calculateTileOffsetsAndSizes(
b, loc, foreachThreadOp, numThreads, iterationDomain,
@@ -625,54 +663,77 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
// 4. Clone the tileable op and update its destination operands to use the
// output bbArgs of the ForeachThreadOp.
+ ValueRange tilingResults;
ArrayRef<BlockArgument> destBbArgs =
foreachThreadOp.getOutputBlockArguments();
- Operation *clonedOp = b.clone(*op.getOperation());
- b.setInsertionPointToStart(foreachThreadOp.getBody(0));
- auto destinationStyleOp = cast<DestinationStyleOpInterface>(clonedOp);
- for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands()) {
- auto *it = llvm::find(dest, initOperand->get());
- assert(it != dest.end() && "dest operand not found in dest");
- unsigned destNum = std::distance(dest.begin(), it);
- SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
- SmallVector<OpFoldResult> outOffsets(numThreads.size(), b.getIndexAttr(0));
- SmallVector<OpFoldResult> sizes = tiledSizes;
- sizes[reductionDim] = b.getIndexAttr(1);
- outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
- // TODO: use SubsetExtractOpInterface once it is available.
- Value patial = b.create<tensor::ExtractSliceOp>(
- loc, initOperand->get().getType().cast<RankedTensorType>(),
- destBbArgs[destNum], outOffsets, sizes, strides);
- initOperand->set(patial);
- }
- b.setInsertionPoint(clonedOp);
+ {
+ // 4.a. RAII guard, inserting within foreachThreadOp, before terminator.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(foreachThreadOp.getTerminator());
+
+ SmallVector<Value> tiledDpsInitOperands;
+ for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands()) {
+ auto *it = llvm::find(dest, initOperand->get());
+ assert(it != dest.end() && "dest operand not found in dest");
+ unsigned destNum = std::distance(dest.begin(), it);
+ SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
+ SmallVector<OpFoldResult> outOffsets(numThreads.size(),
+ b.getIndexAttr(0));
+ SmallVector<OpFoldResult> sizes = tiledSizes;
+ sizes[reductionDim] = b.getIndexAttr(1);
+ outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
+ // TODO: use SubsetExtractOpInterface once it is available.
+ tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
+ loc, initOperand->get().getType().cast<RankedTensorType>(),
+ destBbArgs[destNum], outOffsets, sizes, strides));
+ }
- // 5. Tile the cloned op and delete the clone.
- if (tileSizes.empty()) {
- SmallVector<Operation *> tiledOps =
- cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
- tiledSizes);
- assert(tiledOps.size() == 1 && "expected a single produced tiled op");
- tiledOp = tiledOps.front();
- } else {
- LinalgTilingOptions options;
- auto tiled = tileLinalgOpImpl<scf::ForOp>(b, cast<LinalgOp>(clonedOp),
- tileSizes, options);
- SmallVector<Value> ids = foreachThreadOp.getThreadIndices();
- mapLoopToProcessorIds(cast<scf::ForOp>(tiled->loops.back()), ids,
- materializedNonZeroNumThreads);
- assert(tiled->loops.size() == 1 && "expected a single produced loop");
- tiledOp = tiled->loops.front();
+ // 4.b. Clone the op and update init operands.
+ // We cannot use a BlockAndValueMapping here because it can replace
+ // different OpOperands with the same value.
+ Operation *clonedOp = b.clone(*op.getOperation());
+ b.updateRootInPlace(clonedOp, [&]() {
+ for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
+ cast<DestinationStyleOpInterface>(clonedOp).getDpsInitOperands(),
+ tiledDpsInitOperands)) {
+ initOperandPtr->set(tiledInitValue);
+ }
+ });
+
+ // 5. Tile the cloned op and delete the clone.
+ if (tileSizes.empty()) {
+ SmallVector<Operation *> tiledOps =
+ cast<TilingInterface>(clonedOp).getTiledImplementation(
+ b, tiledOffsets, tiledSizes);
+ assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+ tiledOp = tiledOps.front();
+ tilingResults = tiledOp->getResults();
+ } else {
+ LinalgTilingOptions options;
+ FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
+ b, cast<LinalgOp>(clonedOp), tileSizes, options);
+ if (failed(maybeTiled))
+ return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");
+
+ SmallVector<Value> ids = foreachThreadOp.getThreadIndices();
+ mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
+ materializedNonZeroNumThreads);
+ assert(maybeTiled->loops.size() == 1 &&
+ "expected a single produced loop");
+ tiledOp = maybeTiled->op;
+ tilingResults = maybeTiled->loops.front()->getResults();
+ }
+
+ b.eraseOp(clonedOp);
}
- b.eraseOp(clonedOp);
// 6. Insert the partial reductions back into a new tensor.
- b.setInsertionPointAfter(tiledOp);
- OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
- for (auto [index, result, bbArg] :
- llvm::zip(llvm::seq<unsigned>(0, dest.size()), tiledOp->getResults(),
- destBbArgs)) {
- b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
+ for (auto [index, result, bbArg] : llvm::zip(
+ llvm::seq<unsigned>(0, dest.size()), tilingResults, destBbArgs)) {
+ // 6.a. Partial subset information is inserted just before the terminator.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(foreachThreadOp.getTerminator());
+
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(tilingInterfaceOp.getResultTilePosition(
b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
@@ -689,18 +750,23 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
resultOffsetsRank.push_back(resultOffsets[offIdx++]);
resultSizesRank.push_back(resultSizes[sizeIdx++]);
}
-
SmallVector<OpFoldResult> strides(resultSizesRank.size(),
b.getIndexAttr(1));
+
+ // 6.b. Parallel insertions are inserted at the end of the combining
+ // terminator.
b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
b.create<tensor::ParallelInsertSliceOp>(
loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
}
+
// 7. Merge the partial reductions.
b.setInsertionPointAfter(foreachThreadOp);
Operation *mergeOp =
op.mergeReductions(b, loc, foreachThreadOp->getResults(), reductionDim);
b.replaceOp(op, mergeOp->getResults());
+
+ // 8. Return.
ForeachThreadReductionTilingResult results;
results.initialOp = identityTensor.value();
results.loops = foreachThreadOp;
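The subtlest part of the patch is step 6 above: subset offsets and sizes are
computed just before the scf.foreach_thread terminator (6.a), while the
tensor.parallel_insert_slice must be created inside the terminator's own
single-block region (6.b). The removed saveInsertionPoint/setInsertionPoint
pair restored a raw block/iterator position manually and leaked the final
position out of the loop; a per-iteration guard undoes both moves
automatically. A condensed sketch of the discipline; the helper name and
parameter list are illustrative, not code from the patch:

    #include "mlir/Dialect/SCF/IR/SCF.h"
    #include "mlir/Dialect/Tensor/IR/Tensor.h"
    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;

    static void insertPartialResult(RewriterBase &b, Location loc,
                                    scf::ForeachThreadOp foreachThreadOp,
                                    Value result, BlockArgument bbArg,
                                    ArrayRef<OpFoldResult> offsets,
                                    ArrayRef<OpFoldResult> sizes) {
      // 6.a. Anything feeding the subset computation is created right before
      // the terminator; the guard undoes both moves below on return.
      OpBuilder::InsertionGuard g(b);
      b.setInsertionPoint(foreachThreadOp.getTerminator());
      SmallVector<OpFoldResult> strides(sizes.size(), b.getIndexAttr(1));

      // 6.b. The insertion itself must sit at the end of the combining
      // terminator's block (scf.foreach_thread.perform_concurrently).
      b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
      b.create<tensor::ParallelInsertSliceOp>(loc, result, bbArg, offsets,
                                              sizes, strides);
    }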
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index f7f062156c05b..76f6485c9ff9b 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -874,19 +874,19 @@ void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
DiagnosedSilenceableFailure
transform::PrintOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
- llvm::errs() << "[[[ IR printer: ";
+ llvm::outs() << "[[[ IR printer: ";
if (getName().has_value())
- llvm::errs() << *getName() << " ";
+ llvm::outs() << *getName() << " ";
if (!getTarget()) {
- llvm::errs() << "top-level ]]]\n" << *state.getTopLevel() << "\n";
+ llvm::outs() << "top-level ]]]\n" << *state.getTopLevel() << "\n";
return DiagnosedSilenceableFailure::success();
}
- llvm::errs() << "]]]\n";
+ llvm::outs() << "]]]\n";
ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
for (Operation *target : targets)
- llvm::errs() << *target << "\n";
+ llvm::outs() << *target << "\n";
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index ad2dc0a4124d8..cd0d6d71113cc 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -218,7 +218,8 @@ func.func @reduction_tile_parallel_cyclic_dist(
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
- %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5], tile_sizes = [0, 3] }
+ %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0
+ { num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] }
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 3)>
@@ -262,3 +263,39 @@ transform.sequence failures(propagate) {
// CHECK: linalg.yield
// CHECK: } -> tensor<?xf32>
// CHECK: return %[[R]] : tensor<?xf32>
+
+// -----
+
+func.func @reduction_tile_parallel_cyclic_dist(
+ %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
+ %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<?x?xf32>)
+ outs(%out : tensor<?xf32>) {
+ ^bb0(%arg7: f32, %arg9: f32):
+ %1 = arith.mulf %arg7, %arg7 : f32
+ %2 = arith.addf %1, %arg9 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?xf32>
+ return %red : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0
+ { num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] }
+
+ // CHECK: expecting fill
+ // CHECK-NEXT: linalg.fill
+ transform.print %1 {name = "expecting fill"} : !pdl.operation
+ // CHECK: expecting parallel reduction
+ // CHECK-NEXT: linalg.generic
+ // CHECK: iterator_types = ["parallel", "reduction"]
+ transform.print %2 {name = "expecting parallel reduction"} : !pdl.operation
+ // CHECK: expecting parallel reduction
+ // CHECK-NEXT: linalg.generic
+ // CHECK: iterator_types = ["parallel", "reduction"]
+ transform.print %3 {name = "expecting parallel reduction"} : !pdl.operation
+}
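Two changes come together in this new test: transform.print now writes to
llvm::outs(), so FileCheck can match the printed payload ops, and the mapping
attribute is threaded through to tileReductionUsingForeachThread. For
completeness, a hedged sketch of a C++ caller of the updated entry point;
the wrapper, its signature, and the GPUThreadMappingAttr spelling are
assumptions of this sketch, not code from the patch:

    #include "mlir/Dialect/GPU/IR/GPUDialect.h"
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/Interfaces/TilingInterface.h"

    using namespace mlir;

    // Tile the reduction across 5 threads with per-thread tile size 3 along
    // the reduction dimension, mapped to #gpu.thread<x>; mirrors the
    // { num_threads = [0, 5], tile_sizes = [0, 3],
    //   mapping = [#gpu.thread<x>] } transform in the test above.
    static LogicalResult tileReductionForGpu(RewriterBase &rewriter,
                                             PartialReductionOpInterface op) {
      SmallVector<OpFoldResult> numThreads = {rewriter.getIndexAttr(0),
                                              rewriter.getIndexAttr(5)};
      SmallVector<OpFoldResult> tileSizes = {rewriter.getIndexAttr(0),
                                             rewriter.getIndexAttr(3)};
      // Assumed spelling of the #gpu.thread<x> attribute builder.
      ArrayAttr mapping = rewriter.getArrayAttr(gpu::GPUThreadMappingAttr::get(
          rewriter.getContext(), gpu::Threads::DimX));
      FailureOr<linalg::ForeachThreadReductionTilingResult> result =
          linalg::tileReductionUsingForeachThread(rewriter, op, numThreads,
                                                  tileSizes, mapping);
      return failed(result) ? failure() : success();
    }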