[Mlir-commits] [mlir] 416ba22 - [mlir][linalg][transform] Support dynamic tile sizes in TileToForeachThreadOp
Matthias Springer
llvmlistbot at llvm.org
Mon Aug 22 07:52:28 PDT 2022
Author: Matthias Springer
Date: 2022-08-22T16:48:45+02:00
New Revision: 416ba2256d2310abbfc185d1bce53715cf29092f
URL: https://github.com/llvm/llvm-project/commit/416ba2256d2310abbfc185d1bce53715cf29092f
DIFF: https://github.com/llvm/llvm-project/commit/416ba2256d2310abbfc185d1bce53715cf29092f.diff
LOG: [mlir][linalg][transform] Support dynamic tile sizes in TileToForeachThreadOp
TileToForeachThreadOp now accepts a mix of SSA value operands and index attributes for tile_sizes and num_threads (reusing OperandsOrIntegersSizesList). In case of an operand, a PDL_Operation must be specified that maps to exactly one payload op with a single index result, which provides the tile size or number of threads.
Differential Revision: https://reviews.llvm.org/D131949
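For illustration, a minimal sketch of the new mixed form (the @match_* pattern names are placeholders, mirroring the examples in the op documentation below):

```
%0 = pdl_match @match_matmul in %arg1
// %nt must map to exactly one payload op with exactly one index result.
%nt = pdl_match @match_num_threads_op in %arg1
%1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [%nt, 20]
```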
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 11c9b898ef58d..571d01b83d15f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -660,25 +660,31 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
def TileToForeachThreadOp :
Op<Transform_Dialect, "structured.tile_to_foreach_thread_op",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformEachOpTrait,
+ [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformOpInterface]> {
let description = [{
- Tile a TilingInterface op to a tiled `scf.foreach_thread`. Tiling is
- applied by either specifying `num_threads` or `tile_size`. If `num_threads`
- is specified, then the tile size for each dimension `i` is calculated
- dynamically via `ceilDiv(dimSize[i], num_threads[i])`.
- If non-empty, the `thread_dim_mapping` is added as an attribute to the
- resulting `scf.foreach_thread`.
- Zero tile sizes indicate that the dimension is not tiled and can be
+ Tile a TilingInterface op to a tiled `scf.foreach_thread`.
+
+ Tiling is applied by either specifying `num_threads` or `tile_sizes`. If
+ `num_threads` is specified, then the tile size for each dimension `i` is
+ calculated dynamically via `ceilDiv(dimSize[i], num_threads[i])`.
+ `num_threads` and `tile_sizes` can be either static index attributes or SSA
+ values of PDL operation handle type (or a mix thereof). Operation handles
+ must be mapped to exactly one op that has exactly one result of index type.
+
+ Static zero tile sizes indicate that the dimension is not tiled and can be
thought of as tiling by the full size of data.
+
It is the user's responsibility to ensure that `num_threads/tile_sizes` is
a valid tiling specification (i.e. that it only tiles parallel dimensions,
e.g. in the Linalg case).
-
+
+ If non-empty, the `thread_dim_mapping` is added as an attribute to the
+ resulting `scf.foreach_thread`.
+
#### Return modes
-
+
This operation ignores ops that do not implement the TilingInterface and
drops them in the return.
@@ -696,36 +702,46 @@ def TileToForeachThreadOp :
### Example using `num_threads`
```
- %0 = pdl_match @match_matmul in %arg1
+ %0 = pdl_match @match_matmul in %arg1
%3:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20]
```
### Example using `tile_sizes`
-
+
```
- %0 = pdl_match @match_matmul in %arg1
- %3:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 20, 0]
+ %0 = pdl_match @match_matmul in %arg1
+ %sz = pdl_match @match_size_op in %arg1
+ %3:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [0, %sz, 20]
```
}];
let arguments = (ins PDL_Operation:$target,
- // TODO: dynamic number of threads.
- OptionalAttr<DefaultValuedAttr<I64ArrayAttr, "{}">>:$num_threads,
- OptionalAttr<DefaultValuedAttr<I64ArrayAttr, "{}">>:$tile_sizes,
+ Variadic<PDL_Operation>:$num_threads,
+ Variadic<PDL_Operation>:$tile_sizes,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$static_num_threads,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$static_tile_sizes,
OptionalAttr<I64ArrayAttr>:$thread_dim_mapping);
let results = (outs PDL_Operation:$foreach_thread_op,
PDL_Operation:$tiled_op);
-
let assemblyFormat = [{
- $target (`num_threads` $num_threads^) : (`tile_sizes` $tile_sizes)?
- (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict
+ $target oilist(
+ `num_threads` custom<DynamicIndexList>($num_threads,
+ $static_num_threads,
+ "ShapedType::kDynamicSize") |
+ `tile_sizes` custom<DynamicIndexList>($tile_sizes,
+ $static_tile_sizes,
+ "ShapedType::kDynamicSize"))
+ (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict
}];
+ let hasVerifier = 1;
let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::TilingInterface target,
- ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformResults &transformResults,
::mlir::transform::TransformState &state);
+
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedNumThreads();
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedTileSizes();
}];
}
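As a worked example of the ceilDiv rule in the description above: for dimension sizes [100, 300] and num_threads [10, 20], the tile sizes are ceilDiv(100, 10) = 10 and ceilDiv(300, 20) = 15. A sketch of the printed form with mixed entries, as produced by the custom<DynamicIndexList> directive (the producer of %sz is hypothetical):

```
%1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [0, %sz, 20] (mapped to dims [2, 1, 0])
```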
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index be4f202cc0b39..819f34596f6a0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1020,33 +1020,137 @@ void transform::TileOp::getEffects(
// TileToForeachThreadOp
//===----------------------------------------------------------------------===//
-DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
- TilingInterface target, SmallVectorImpl<Operation *> &results,
+DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
+ transform::TransformResults &transformResults,
transform::TransformState &state) {
IRRewriter rewriter(getContext());
- rewriter.setInsertionPoint(target);
- auto maybeThreadDimMappingAttr = getThreadDimMapping();
- auto dimMapping =
- llvm::to_vector(maybeThreadDimMappingAttr
- ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
- : ArrayRef<int64_t>{});
-
- FailureOr<ForeachThreadTilingResult> tilingResult = failure();
- if (Optional<ArrayAttr> numThreads = getNumThreads())
- tilingResult = linalg::tileToForeachThreadOp(
- rewriter, target, getAsOpFoldResult(*numThreads), dimMapping);
-
- if (Optional<ArrayAttr> tileSizes = getTileSizes())
- tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
- rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping);
-
- if (failed(tilingResult))
- return emitDefaultSilenceableFailure(target);
- rewriter.replaceOp(target, tilingResult->tileOp->getResults());
- results.assign({tilingResult->tileOp, tilingResult->tiledOp});
+ ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+
+ // If the target payload ops are empty, there is nothing to do.
+ if (targets.empty()) {
+ transformResults.set(getForeachThreadOp().cast<OpResult>(), {});
+ transformResults.set(getTiledOp().cast<OpResult>(), {});
+ return DiagnosedSilenceableFailure(success());
+ }
+
+ // Result payload ops.
+ SmallVector<Operation *> tileOps;
+ SmallVector<Operation *> tiledOps;
+
+ // Given a list of OpFoldResults that are either index attrs or op handles,
+ // return a list of OpFoldResults where all op handles are replaced with the
+ // first (and only) OpResult of that payload op. (There must be exactly one
+ // mapped payload op and it must have exactly one index result.)
+ auto getOpResultsOrIndexAttrs =
+ [&](SmallVector<OpFoldResult> &result,
+ ArrayRef<OpFoldResult> opHandlesOrIndexAttrs) {
+ for (OpFoldResult ofr : opHandlesOrIndexAttrs) {
+ if (ofr.is<Attribute>()) {
+ result.push_back(ofr);
+ continue;
+ }
+ ArrayRef<Operation *> dynamicNumThreads =
+ state.getPayloadOps(ofr.get<Value>());
+ if (dynamicNumThreads.size() != 1) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError()
+ << "handle must be mapped to exactly 1 payload op";
+ diag.attachNote(ofr.get<Value>().getLoc())
+ << "mapped to " << dynamicNumThreads.size() << " ops";
+ return diag;
+ }
+ Operation *op = dynamicNumThreads[0];
+ if (op->getNumResults() != 1 ||
+ !op->getResult(0).getType().isIndex()) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError()
+ << "payload op must have exactly 1 index result";
+ diag.attachNote(op->getLoc())
+ << "has " << op->getNumResults() << " results";
+ return diag;
+ }
+ result.push_back(op->getResult(0));
+ }
+
+ return DiagnosedSilenceableFailure(success());
+ };
+
+ // getMixedNumThreads returns OpFoldResults that are either index attributes
+ // or PDL operation handles. Convert the handles to payload op results.
+ SmallVector<OpFoldResult> numThreads;
+ DiagnosedSilenceableFailure status =
+ getOpResultsOrIndexAttrs(numThreads, getMixedNumThreads());
+ if (!status.succeeded())
+ return status;
+
+ // getMixedTileSizes returns OpFoldResults that are either index attributes
+ // or PDL operation handles. Convert the handles to payload op results.
+ SmallVector<OpFoldResult> tileSizes;
+ status = getOpResultsOrIndexAttrs(tileSizes, getMixedTileSizes());
+ if (!status.succeeded())
+ return status;
+
+ // Transform all targets one by one.
+ for (Operation *target : targets) {
+ auto tilableOp = dyn_cast<TilingInterface>(target);
+ if (!tilableOp) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "only TilingInterface ops are supported";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+ rewriter.setInsertionPoint(tilableOp);
+ auto maybeThreadDimMappingAttr = getThreadDimMapping();
+ auto dimMapping = llvm::to_vector(
+ maybeThreadDimMappingAttr
+ ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
+ : ArrayRef<int64_t>{});
+
+ FailureOr<ForeachThreadTilingResult> tilingResult = failure();
+ if (!getMixedNumThreads().empty()) {
+ tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
+ numThreads, dimMapping);
+ } else {
+ tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
+ rewriter, tilableOp, tileSizes, dimMapping);
+ }
+
+ if (failed(tilingResult))
+ return emitDefaultSilenceableFailure(tilableOp);
+ rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults());
+
+ tileOps.push_back(tilingResult->tileOp);
+ tiledOps.push_back(tilingResult->tiledOp);
+ }
+
+ transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
+ transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
+
return DiagnosedSilenceableFailure(success());
}
+void transform::TileToForeachThreadOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTarget(), effects);
+ onlyReadsHandle(getTileSizes(), effects);
+ onlyReadsHandle(getNumThreads(), effects);
+ producesHandle(getResults(), effects);
+}
+
+SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedNumThreads() {
+ return getMixedSizes(getStaticNumThreads(), getNumThreads());
+}
+
+SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedTileSizes() {
+ return getMixedSizes(getStaticTileSizes(), getTileSizes());
+}
+
+LogicalResult TileToForeachThreadOp::verify() {
+ if (getMixedNumThreads().empty() == getMixedTileSizes().empty())
+ return emitOpError("either num_threads or tile_sizes must be specified");
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// VectorizeOp
//===----------------------------------------------------------------------===//
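Note that the new verifier requires exactly one of the two clauses, and apply() additionally enforces the handle-mapping invariants at runtime; minimal sketches of cases that are now rejected:

```
// Verifier error: either num_threads or tile_sizes must be specified
%1:2 = transform.structured.tile_to_foreach_thread_op %0

// Silenceable failure if %sz maps to more than one payload op, or to an op
// without exactly one index result:
//   error: handle must be mapped to exactly 1 payload op
%2:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%sz, 20]
```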
diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
index dd062908c0bf1..d519adb4dfd2c 100644
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -203,3 +203,49 @@ module {
// CHECK: %[[OFF:.*]] = affine.apply #[[$map0]](%[[ARG]])
// CHECK: scf.foreach_thread.perform_concurrently {
// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%[[OFF]]] [2] [1] : tensor<2xf32> into tensor<4xf32>
+
+// -----
+
+// In this test case, the matmul dims and one of the tile sizes are dynamic.
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic_dynamic(
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[tile_size:.*]] = "test.dummy"()
+ // CHECK-DAG: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
+ // CHECK-DAG: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+ // CHECK-DAG: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]], %[[tile_size]]]
+ // CHECK-DAG: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
+ // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]])
+ // CHECK: tensor.extract_slice %[[A]]
+ // CHECK: tensor.extract_slice %[[B]]
+ // CHECK: tensor.extract_slice %[[C]]
+ // CHECK: linalg.matmul
+ // CHECK: scf.foreach_thread.perform_concurrently
+ // CHECK-NEXT: tensor.parallel_insert_slice
+ %tile_size = "test.dummy"() : () -> (index)
+ %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+ return %0 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %sz = transform.structured.match ops{["test.dummy"]} in %arg1
+ %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%sz, 20]
+ }
+}
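For intuition, the IR after the transform, sketched from the CHECK lines above (abbreviated; not verbatim output):

```
%ts = "test.dummy"() : () -> index
%nt0 = affine.apply #map0()[%M, %ts]  // ceildiv of dim 0 by the dynamic tile size
%nt1 = affine.apply #map1()[%N]       // ceildiv of dim 1 by the static size 20
scf.foreach_thread (%iv0, %iv1) in (%nt0, %nt1) ... {
  // extract slices of %A, %B and %C, linalg.matmul on the slices, then:
  scf.foreach_thread.perform_concurrently {
    tensor.parallel_insert_slice ...
  }
}
```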