[Mlir-commits] [mlir] 297ba16 - [mlir][linalg] Add tile_size option to `structured.tile_to_foreach_thread_op`
Christopher Bate
llvmlistbot at llvm.org
Thu Jul 21 09:36:27 PDT 2022
Author: Christopher Bate
Date: 2022-07-21T10:32:01-06:00
New Revision: 297ba167ded073a47dd9ea7e408aa95acdfcedf1
URL: https://github.com/llvm/llvm-project/commit/297ba167ded073a47dd9ea7e408aa95acdfcedf1
DIFF: https://github.com/llvm/llvm-project/commit/297ba167ded073a47dd9ea7e408aa95acdfcedf1.diff
LOG: [mlir][linalg] Add tile_size option to `structured.tile_to_foreach_thread_op`
This change modifies `structured.tile_to_foreach_thread_op` so that
it accepts either `tile_sizes` or `num_threads` parameters. If
`tile_sizes` are specified, then the number of threads required is
derived from the tile sizes rather than the other way around. In both cases,
more aggressive folding of loop parameters is enabled during the
transformation, allowing for the potential elimination of `affine.min`
and `affine.max` operations in the static shape case when calculating
the final adjusted tile size.
Differential Revision: https://reviews.llvm.org/D130139
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 9f1cd681315e3..b8bcf136ee383 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -608,15 +608,17 @@ def TileToForeachThreadOp :
TransformEachOpTrait,
TransformOpInterface]> {
let description = [{
- Tile a TilingInterface `op` to a tiled `scf.foreach_thread`, applying
- tiling by `num_threads`.
- If non-empty, the `thread_dim_mapping` is added as an attribute to the
+ Tile a TilingInterface op to a tiled `scf.foreach_thread`. Tiling is
+ applied by either specifying `num_threads` or `tile_size`. If `num_threads`
+ is specified, then the tile size for each dimension `i` is calculated
+ dynamically via `ceilDiv(dimSize[i], num_threads[i])`.
+ If non-empty, the `thread_dim_mapping` is added as an attribute to the
resulting `scf.foreach_thread`.
- Zero tile sizes indicate that the dimension is not tiled, and can be thought
- of as tiling by the full size of data.
- It is the user's responsibility to ensure that `num_threads` is a valid
- tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
- Linalg case).
+ Zero tile sizes indicate that the dimension is not tiled and can be
+ thought of as tiling by the full size of data.
+ It is the user's responsibility to ensure that `num_threads/tile_sizes` is
+ a valid tiling specification (i.e. that only tiles parallel dimensions,
+ e.g. in the Linalg case).
#### Return modes
@@ -627,24 +629,39 @@ def TileToForeachThreadOp :
successfully, the transform succeeds.
Otherwise the transform silently fails.
- The 2 returned handles point to only the subset of successfully produced
+ The two returned handles point to only the subset of successfully produced
tiled operations, which can all be empty.
- These 2 returned handles point to:
+ These two returned handles point to:
- the new scf.foreach_thread op,
- the tiled op that implements TilingInterface.
+
+ ### Example using `num_threads`
+
+ ```
+ %0 = pdl_match @match_matmul in %arg1
+ %3:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20]
+ ```
+
+ ### Example using `tile_sizes`
+
+ ```
+ %0 = pdl_match @match_matmul in %arg1
+ %3:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 20, 0]
+ ```
}];
let arguments = (ins PDL_Operation:$target,
// TODO: dynamic number of threads.
- DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
+ OptionalAttr<DefaultValuedAttr<I64ArrayAttr, "{}">>:$num_threads,
+ OptionalAttr<DefaultValuedAttr<I64ArrayAttr, "{}">>:$tile_sizes,
OptionalAttr<I64ArrayAttr>:$thread_dim_mapping);
let results = (outs PDL_Operation:$foreach_thread_op,
PDL_Operation:$tiled_op);
let assemblyFormat = [{
- $target $num_threads (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)?
- attr-dict
+ $target (`num_threads` $num_threads^) : (`tile_sizes` $tile_sizes)?
+ (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict
}];
let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index c81a7ee2ac326..569faf383a0ee 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -466,10 +466,17 @@ struct ForeachThreadTilingResult {
Operation *tiledOp;
};
FailureOr<ForeachThreadTilingResult>
-tileToForeachThreadOp(OpBuilder &builder, TilingInterface op,
+tileToForeachThreadOp(RewriterBase &builder, TilingInterface op,
ArrayRef<OpFoldResult> numThreads,
ArrayRef<int64_t> threadDimMapping = {});
+/// Same as `tileToForeachThreadOp`, but calculate the number of threads
+/// required using the given tileSizes.
+FailureOr<ForeachThreadTilingResult>
+tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
+ ArrayRef<OpFoldResult> tileSizes,
+ ArrayRef<int64_t> threadDimMapping = {});
+
/// All indices returned by IndexOp should be invariant with respect to tiling.
/// Therefore, if an operation is tiled, we have to transform the indices
/// accordingly, i.e. offset them by the values of the corresponding induction
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d2dc66d610789..345e27742fbbd 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -721,8 +721,13 @@ static void materializeConstants(OpBuilder &b, Location loc,
actualValues.push_back(value);
continue;
}
- constants.push_back(dialect->materializeConstant(b, ofr.get<Attribute>(),
- b.getIndexType(), loc));
+ // Since we are directly specifying `index` as the result type, we need to
+ // ensure the provided attribute is also an index type. Otherwise, the
+ // AffineDialect materializer will create invalid `arith.constant`
+ // operations if the provided Attribute is any other kind of integer.
+ constants.push_back(dialect->materializeConstant(
+ b, b.getIndexAttr(ofr.get<Attribute>().cast<IntegerAttr>().getInt()),
+ b.getIndexType(), loc));
actualValues.push_back(constants.back()->getResult(0));
}
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index adab8da3d518b..070b1fc4eb821 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -909,12 +909,20 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
IRRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
auto maybeThreadDimMappingAttr = getThreadDimMapping();
- FailureOr<ForeachThreadTilingResult> tilingResult =
- linalg::tileToForeachThreadOp(
- rewriter, target, getAsOpFoldResult(getNumThreads()),
- maybeThreadDimMappingAttr
- ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
- : ArrayRef<int64_t>{});
+ auto dimMapping =
+ llvm::to_vector(maybeThreadDimMappingAttr
+ ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
+ : ArrayRef<int64_t>{});
+
+ FailureOr<ForeachThreadTilingResult> tilingResult = failure();
+ if (Optional<ArrayAttr> numThreads = getNumThreads())
+ tilingResult = linalg::tileToForeachThreadOp(
+ rewriter, target, getAsOpFoldResult(*numThreads), dimMapping);
+
+ if (Optional<ArrayAttr> tileSizes = getTileSizes())
+ tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
+ rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping);
+
if (failed(tilingResult))
return emitDefaultSilenceableFailure(target);
rewriter.replaceOp(target, tilingResult->tileOp->getResults());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index a8112dbe50b84..6d97dfc6d84fb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -41,9 +41,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRAnalysis
MLIRArithmeticDialect
MLIRArithmeticTransforms
+ MLIRArithmeticUtils
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRComplexDialect
+ MLIRDialectUtils
MLIRFuncDialect
MLIRFuncToLLVM
MLIRFuncTransforms
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 0571ff5432afb..1dfaf69efa728 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -13,6 +13,7 @@
#include <utility>
#include "PassDetail.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
@@ -182,23 +183,43 @@ createMatchingParallelSubsetInsertOp(OpBuilder &b, Location loc,
}
/// Build an `affine_max` of all the `vals`.
-static Value buildMax(OpBuilder &b, Location loc, ValueRange vals) {
+static OpFoldResult buildMax(OpBuilder &b, Location loc,
+ ArrayRef<OpFoldResult> vals) {
+ SmallVector<Value> args = getValueOrCreateConstantIndexOp(b, loc, vals);
return b.createOrFold<AffineMaxOp>(
loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
- vals);
+ args);
}
-/// Build an `affine_min` of all the `vals`.
-static Value buildMin(OpBuilder &b, Location loc, ValueRange vals) {
- return b.createOrFold<AffineMinOp>(
- loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
- vals);
+/// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
+/// less than `iterationSize`.
+static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
+ OpFoldResult numThreads,
+ OpFoldResult iterationSize) {
+ Optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
+ Optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
+ Optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
+ if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
+ return false;
+ return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
}
-FailureOr<ForeachThreadTilingResult>
-linalg::tileToForeachThreadOp(OpBuilder &b, TilingInterface op,
- ArrayRef<OpFoldResult> numThreads,
- ArrayRef<int64_t> threadDimMapping) {
+/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`. The
+/// tiling is specified by the number of tiles/threads `numThreads` and the
+/// optional nominal tile size `nominalTileSizes`. If `nominalTileSizes` is
+/// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i],
+/// numThreads[i])`. If non-empty, the `threadDimMapping` is added as an
+/// attribute to the resulting `scf.foreach_thread`. A zero tile size indicates
+/// that the dimension is not tiled, and can be thought of as tiling by the full
+/// size of data.
+/// It is the user's responsibility to ensure that `numThreads` is a valid
+/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
+/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will
+/// assume that `tileSize[i] * (numThread[i] - 1) <= dimSize[i]` holds.
+static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
+ RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
+ Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
+ ArrayRef<int64_t> threadDimMapping, bool omitTileOffsetBoundsCheck) {
Location loc = op->getLoc();
OpBuilder::InsertionGuard g(b);
SmallVector<Range> loopRanges = op.getIterationDomain(b);
@@ -224,80 +245,128 @@ linalg::tileToForeachThreadOp(OpBuilder &b, TilingInterface op,
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Operation *tiledOp = nullptr;
+
+ // Create the ForeachThreadOp. We don't use the lambda body-builder
+ // version because we require the use of RewriterBase in the body, so we
+ // manually move the insertion point to the body below.
scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
- loc, materializedNonZeroNumThreads, threadDimMapping,
- [&](OpBuilder &b, Location loc, ValueRange threadIds) {
- int64_t nLoops = loopRanges.size();
- SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
- tiledOffsets.reserve(nLoops);
- tiledSizes.reserve(nLoops);
- for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops;
- ++loopIdx) {
- bool overflow = loopIdx >= numThreads.size();
- bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
- // Degenerate case: take the whole domain.
- if (overflow || isZero) {
- tiledOffsets.push_back(loopRanges[loopIdx].offset);
- tiledSizes.push_back(loopRanges[loopIdx].size);
- continue;
- }
-
- // Tiled case: compute the offset and size.
- AffineExpr i, j, M, N, O;
- bindDims(b.getContext(), i, j);
- bindSymbols(b.getContext(), M, N, O);
- Value size = loopRanges[loopIdx].size;
- Value offset = loopRanges[loopIdx].offset;
- Value threadId = threadIds[threadIdIdx];
- // TODO: more aggressive foldings.
- // Symbolic fixed max size per thread.
- // TODO: floor + 0/1 depending on case for better load-balancing.
- Value maxSizePerThread = b.createOrFold<AffineApplyOp>(
- loc, M.ceilDiv(N),
- ValueRange{size, materializedNonZeroNumThreads[threadIdIdx]});
- // Dynamic offset shifted by threadId * maxSizePerThread.
- Value offsetPerThread = b.createOrFold<AffineApplyOp>(
- loc, i + j * M, ValueRange{offset, threadId, maxSizePerThread});
- // Dynamic upper-bound depending on the threadId.
- Value sizeMinusOffsetPerThread = b.createOrFold<AffineApplyOp>(
- loc, -i + M, ValueRange{offsetPerThread, size});
- Value tileSizePerThread = buildMin(
- b, loc, ValueRange{sizeMinusOffsetPerThread, maxSizePerThread});
- tiledOffsets.push_back(offsetPerThread);
- // TODO: if tileSizePerThread <= 0 early exit.
- tiledSizes.push_back(
- buildMax(b, loc, ValueRange{zero, tileSizePerThread}));
- ++threadIdIdx;
- }
-
- SmallVector<Operation *> tiledOps =
- op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
- /*tileDestOperands=*/true);
- assert(tiledOps.size() == 1 && "expected a single produced tiled op");
- tiledOp = tiledOps.front();
-
- auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
- assert(tilingInterfaceOp &&
- "Tiled op does not implement TilingInterface");
-
- auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
-
- // Create terminator with parallel subset insert operations.
- auto performConcurrentlyOp = b.create<scf::PerformConcurrentlyOp>(loc);
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPointToStart(performConcurrentlyOp.getBody());
- for (auto it :
- llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
- destOperands)) {
- createMatchingParallelSubsetInsertOp(
- b, loc,
- cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
- std::get<1>(it), std::get<2>(it));
- }
- });
+ loc, op->getResultTypes(), ValueRange(materializedNonZeroNumThreads),
+ threadDimMapping);
+
+ // Fill out the ForeachThreadOp body.
+ b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+ ValueRange threadIds = foreachThreadOp.getThreadIndices();
+ int64_t nLoops = loopRanges.size();
+ SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+ tiledOffsets.reserve(nLoops);
+ tiledSizes.reserve(nLoops);
+ for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
+ bool overflow = loopIdx >= numThreads.size();
+ bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
+ // Degenerate case: take the whole domain.
+ if (overflow || isZero) {
+ tiledOffsets.push_back(loopRanges[loopIdx].offset);
+ tiledSizes.push_back(loopRanges[loopIdx].size);
+ continue;
+ }
+
+ // Tiled case: compute the offset and size.
+ AffineExpr i, j, M, N, O;
+ bindDims(b.getContext(), i, j);
+ bindSymbols(b.getContext(), M, N, O);
+ Value size = loopRanges[loopIdx].size;
+ Value offset = loopRanges[loopIdx].offset;
+ Value threadId = threadIds[threadIdIdx];
+ // Symbolic fixed max size per thread.
+ // TODO: floor + 0/1 depending on case for better load-balancing.
+ OpFoldResult tileSizePerThread =
+ nominalTileSizes.hasValue()
+ ? (*nominalTileSizes)[loopIdx]
+ : makeComposedFoldedAffineApply(
+ b, loc, M.ceilDiv(N),
+ ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});
+
+ // Dynamic offset shifted by threadId * maxSizePerThread.
+ OpFoldResult offsetPerThread = makeComposedFoldedAffineApply(
+ b, loc, i + j * M, {offset, threadId, tileSizePerThread});
+ // Dynamic upper-bound depending on the threadId.
+ OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
+ b, loc, i + j * M - N,
+ {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
+ if (!isConstantIntValue(residualTileSize, 0)) {
+ OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
+ b, loc, -i + M, {offsetPerThread, size});
+ tileSizePerThread = makeComposedFoldedAffineMin(
+ b, loc, AffineMap::getMultiDimIdentityMap(2, b.getContext()),
+ ArrayRef<OpFoldResult>{sizeMinusOffsetPerThread, tileSizePerThread});
+ }
+
+ tiledOffsets.push_back(offsetPerThread);
+ // TODO: if tileSizePerThread <= 0 early exit.
+ if (!omitTileOffsetBoundsCheck &&
+ !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
+ nonZeroNumThreads[threadIdIdx], size))
+ tileSizePerThread = buildMax(b, loc, {zero, tileSizePerThread});
+
+ tiledSizes.push_back(tileSizePerThread);
+ ++threadIdIdx;
+ }
+
+ SmallVector<Operation *> tiledOps =
+ op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
+ /*tileDestOperands=*/true);
+ assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+ tiledOp = tiledOps.front();
+
+ auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
+ assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
+
+ auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
+
+ // Create terminator with parallel subset insert operations.
+ b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody());
+ for (auto it : llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
+ destOperands)) {
+ createMatchingParallelSubsetInsertOp(
+ b, loc, cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
+ std::get<1>(it), std::get<2>(it));
+ }
return ForeachThreadTilingResult{foreachThreadOp, tiledOp};
}
+FailureOr<ForeachThreadTilingResult>
+linalg::tileToForeachThreadOp(RewriterBase &b, TilingInterface op,
+ ArrayRef<OpFoldResult> numThreads,
+ ArrayRef<int64_t> threadDimMapping) {
+ return tileToForeachThreadOpImpl(b, op, numThreads, /*nominalTileSizes=*/None,
+ threadDimMapping,
+ /*omitTileOffsetBoundsCheck=*/false);
+}
+
+FailureOr<ForeachThreadTilingResult>
+linalg::tileToForeachThreadOpUsingTileSizes(
+ RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> tileSizes,
+ ArrayRef<int64_t> threadDimMapping) {
+ SmallVector<Range> loopRanges = op.getIterationDomain(b);
+ unsigned nLoops = loopRanges.size();
+ SmallVector<OpFoldResult> numThreads;
+ numThreads.reserve(nLoops);
+ AffineExpr s0, s1;
+ bindSymbols(b.getContext(), s0, s1);
+ AffineExpr divExpr = s0.ceilDiv(s1);
+ for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
+ OpFoldResult numTiles = std::get<0>(it);
+ if (!isConstantIntValue(numTiles, 0))
+ numTiles = makeComposedFoldedAffineApply(
+ b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
+ numThreads.push_back(numTiles);
+ }
+ return tileToForeachThreadOpImpl(b, op, numThreads,
+ /*nominalTileSizes=*/tileSizes,
+ threadDimMapping,
+ /*omitTileOffsetBoundsCheck=*/true);
+}
+
// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as size of tile) is taken from a given
// ExtractSliceOp `sliceOp`.
diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
index 2529f4a50556f..ab97df3611571 100644
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize -split-input-file | FileCheck %s
// Offset per thread:
// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
@@ -37,7 +37,143 @@ module {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
- %1:2 = transform.structured.tile_to_foreach_thread_op %0 [10, 20] (mapped to dims [1, 0])
+ %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] (mapped to dims [1, 0])
}
}
}
+
+// -----
+
+// Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 300, 15)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 15)>
+
+// CHECK-LABEL: matmul_static(
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor
+// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor
+// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor
+func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
+ // CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
+ // CHECK-DAG: %[[c21:.+]] = arith.constant 21 : index
+ // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c21]])
+ // CHECK: %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV1]])
+ // CHECK: %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
+ // CHECK-NOT: affine.min
+ // CHECK-NOT: affine.max
+ // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+ // CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
+ // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
+ // CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
+ // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+ // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
+ // CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
+ // CHECK: linalg.matmul
+ // CHECK: scf.foreach_thread.perform_concurrently
+ // CHECK-NEXT: tensor.parallel_insert_slice
+ %0 = linalg.matmul ins(%A, %B : tensor<100x200xf32>, tensor<200x300xf32>)
+ outs(%C : tensor<100x300xf32>) -> (tensor<100x300xf32>)
+ return %0 : tensor<100x300xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 21]
+ }
+}
+
+
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic(
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
+ // CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+ // CHECK: %[[NT0:.+]] = affine.apply #map0()[%[[M]]]
+ // CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
+ // CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
+ // CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+ // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]])
+ // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+ // CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+ // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+ // CHECK tensor.extract_slice %[[A]]
+ // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+ // CHECK tensor.extract_slice %[[B]]
+ // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+ // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+ // CHECK tensor.extract_slice %[[C]]
+ // CHECK: linalg.matmul
+ // CHECK: scf.foreach_thread.perform_concurrently
+ // CHECK-NEXT: tensor.parallel_insert_slice
+ %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+ return %0 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 20]
+ }
+}
+
+// -----
+
+// Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -21 + 300, 21)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 21)>
+
+// CHECK-LABEL: matmul_tile_size_static(
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor
+// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor
+// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor
+func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
+ // CHECK-DAG: %[[c10:.+]] = arith.constant 10 :
+ // CHECK-DAG: %[[c15:.+]] = arith.constant 15 :
+ // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c15]])
+ // CHECK: %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]])
+ // CHECK-NOT: affine.max
+ // CHECK-NOT: affine.min
+ // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+ // CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
+ // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
+ // CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
+ // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+ // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
+ // CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
+ // CHECK: linalg.matmul
+ // CHECK: scf.foreach_thread.perform_concurrently
+ // CHECK-NEXT: tensor.parallel_insert_slice
+ %0 = linalg.matmul ins(%A, %B : tensor<100x200xf32>, tensor<200x300xf32>)
+ outs(%C : tensor<100x300xf32>) -> (tensor<100x300xf32>)
+ return %0 : tensor<100x300xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 21]
+ }
+}
More information about the Mlir-commits
mailing list