[Mlir-commits] [mlir] 6370f75 - [mlir][Transform] Add support for dynamically unpacking tile_sizes / num_threads in tile_to_foreach_thread
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Nov 14 04:40:09 PST 2022
Author: Nicolas Vasilache
Date: 2022-11-14T04:39:57-08:00
New Revision: 6370f75ad70c508b52a3d98250bad51e8d8138b6
URL: https://github.com/llvm/llvm-project/commit/6370f75ad70c508b52a3d98250bad51e8d8138b6
DIFF: https://github.com/llvm/llvm-project/commit/6370f75ad70c508b52a3d98250bad51e8d8138b6.diff
LOG: [mlir][Transform] Add support for dynamically unpacking tile_sizes / num_threads in tile_to_foreach_thread
This commit adds automatic unpacking of Values of type pdl::OperationType to the underlying single-result OpResult.
This allows mixing index attributes with single-value and multi-value pdl::OperationType handles in the tile_sizes and num_threads of TileToForeachThreadOp.
Differential Revision: https://reviews.llvm.org/D137896
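For illustration (a sketch, not part of the commit): with this change, index attributes and handles can appear side by side in the same tile_sizes list, along the lines of the test added below. The test.dummy op and the handle names here are assumptions for the example.

transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
  %sz = transform.structured.match ops{["test.dummy"]} in %arg1
  // Mix a static tile size (index attribute) with a dynamic one (a handle
  // whose payload ops each have a single index result). A handle mapped to
  // several payload ops expands into several tile sizes.
  %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [4, %sz]
}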
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b92ed19863069..4cfec70b0c07e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -796,6 +796,10 @@ def TileToForeachThreadOp :
If non-empty, the `mapping` is added as an attribute to the
resulting `scf.foreach_thread`.
+ Note: `tile_sizes` and `num_threads` are variadic. Each tile size/number of
+ threads can be an index attribute or a transform handle that is mapped to
+ payload ops with exactly one index result each; a handle mapped to several
+ payload ops expands into several entries.
+
#### Return modes
This operation ignores ops that do not implement the TilingInterface and
@@ -857,7 +861,7 @@ def TileToForeachThreadOp :
"ArrayRef<OpFoldResult>":$mixedNumThreads,
CArg<"::mlir::transform::NumThreadsSpec",
"::mlir::transform::NumThreadsSpec()">,
- CArg<"ArrayAttr", "{}">:$mapping)>,
+ CArg<"ArrayAttr", "{}">:$mapping)>
];
let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index cdd4e158f9b1c..960fabee05b59 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1327,9 +1327,12 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder,
ArrayRef<int64_t> staticTileSizes,
transform::TileSizesSpec,
ArrayAttr mapping) {
- return build(builder, result, target,
+ return build(builder, result,
+ /*target=*/target,
+ /*mixedTileSizes=*/
getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
- TileSizesSpec(), mapping);
+ /*_=*/TileSizesSpec(),
+ /*mapping=*/mapping);
}
void transform::TileToForeachThreadOp::build(
@@ -1346,9 +1349,14 @@ void transform::TileToForeachThreadOp::build(
MLIRContext *ctx = builder.getContext();
auto operationType = pdl::OperationType::get(ctx);
auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
- build(builder, result, TypeRange{operationType, operationType}, target,
- /*numThreads=*/ValueRange{}, dynamicTileSizes,
- /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mapping);
+ build(builder, result,
+ /*resultTypes=*/TypeRange{operationType, operationType},
+ /*target=*/target,
+ /*numThreads=*/ValueRange{},
+ /*tileSizes=*/dynamicTileSizes,
+ /*staticNumThreads=*/builder.getI64ArrayAttr({}),
+ /*staticTileSizes=*/staticTileSizesAttr,
+ /*mapping=*/mapping);
}
void transform::TileToForeachThreadOp::build(OpBuilder &builder,
@@ -1376,11 +1384,48 @@ void transform::TileToForeachThreadOp::build(
MLIRContext *ctx = builder.getContext();
auto operationType = pdl::OperationType::get(ctx);
auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads);
- build(builder, result, TypeRange{operationType, operationType}, target,
- dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr,
- /*staticTileSizes=*/ArrayAttr(), mapping);
+ build(builder, result,
+ /*resultTypes=*/TypeRange{operationType, operationType},
+ /*target=*/target,
+ /*numThreads=*/dynamicNumThreads,
+ /*tileSizes=*/ValueRange{},
+ /*staticNumThreads=*/staticNumThreadsAttr,
+ /*staticTileSizes=*/builder.getI64ArrayAttr({}),
+ /*mapping=*/mapping);
}
+// Given a list of OpFoldResults that are either index attrs or op
+// handles, return a list of OpFoldResults where all op handles are
+// replaced with the index OpResults of the mapped payload ops. (Each
+// mapped payload op must have exactly one index result; a handle may
+// be mapped to several payload ops, each contributing one result.)
+static DiagnosedSilenceableFailure unpackPDLOperations(
+ transform::TransformState &state, TransformOpInterface transformOp,
+ SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
+ for (OpFoldResult ofr : ofrs) {
+ // Pass through attributes and values that are not PDL operation handles.
+ if (ofr.is<Attribute>() ||
+ !ofr.get<Value>().getType().isa<pdl::OperationType>()) {
+ result.push_back(ofr);
+ continue;
+ }
+ ArrayRef<Operation *> payloadOps = state.getPayloadOps(ofr.get<Value>());
+ for (Operation *op : payloadOps) {
+ if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "payload op must have exactly 1 index result";
+ diag.attachNote(op->getLoc())
+ << "has " << op->getNumResults() << " results";
+ return diag;
+ }
+ result.push_back(op->getResult(0));
+ }
+ }
+
+ return DiagnosedSilenceableFailure(success());
+}
+
DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
RewriterBase &rewriter, transform::TransformState &state,
TransformOpInterface transformOp, ArrayRef<Operation *> targets,
@@ -1390,56 +1435,18 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
if (targets.empty())
return DiagnosedSilenceableFailure(success());
- // Given a list of OpFoldResults that are either index attrs or op handles,
- // return a list of OpFoldResults where all op handles are replaced with the
- // first (and only) OpResult of that payload op. (There must be exactly one
- // mapped payload op and it must have exactly one index result.)
- auto getOpResultsOrIndexAttrs =
- [&](SmallVector<OpFoldResult> &result,
- ArrayRef<OpFoldResult> opHandlesOrIndexAttrs) {
- for (OpFoldResult ofr : opHandlesOrIndexAttrs) {
- if (ofr.is<Attribute>()) {
- result.push_back(ofr);
- continue;
- }
- ArrayRef<Operation *> dynamicNumThreads =
- state.getPayloadOps(ofr.get<Value>());
- if (dynamicNumThreads.size() != 1) {
- DiagnosedSilenceableFailure diag =
- transformOp.emitSilenceableError()
- << "handle must be mapped to exactly 1 payload op";
- diag.attachNote(ofr.get<Value>().getLoc())
- << "mapped to " << dynamicNumThreads.size() << " ops";
- return diag;
- }
- Operation *op = dynamicNumThreads[0];
- if (op->getNumResults() != 1 ||
- !op->getResult(0).getType().isIndex()) {
- DiagnosedSilenceableFailure diag =
- transformOp.emitSilenceableError()
- << "payload op must have exactly 1 index result";
- diag.attachNote(op->getLoc())
- << "has " << op->getNumResults() << " results";
- return diag;
- }
- result.push_back(op->getResult(0));
- }
-
- return DiagnosedSilenceableFailure(success());
- };
-
// getMixedNumThreads are OpFoldResults[index attributes or PDL operation].
// Convert to OpFoldResults[index attributes or payload op].
SmallVector<OpFoldResult> numThreads;
DiagnosedSilenceableFailure status =
- getOpResultsOrIndexAttrs(numThreads, mixedNumThreads);
+ unpackPDLOperations(state, transformOp, numThreads, mixedNumThreads);
if (!status.succeeded())
return status;
// getMixedTileSizes are OpFoldResults[index attributes or PDL operation].
// Convert to OpFoldResults[index attributes or payload op].
SmallVector<OpFoldResult> tileSizes;
- status = getOpResultsOrIndexAttrs(tileSizes, mixedTileSizes);
+ status = unpackPDLOperations(state, transformOp, tileSizes, mixedTileSizes);
if (!status.succeeded())
return status;
@@ -1488,8 +1495,11 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
getMixedNumThreads(), getMixedTileSizes(), getMapping(), tileOps,
tiledOps);
- if (!diag.succeeded())
+ if (!diag.succeeded()) {
+ transformResults.set(getForeachThreadOp().cast<OpResult>(), {});
+ transformResults.set(getTiledOp().cast<OpResult>(), {});
return diag;
+ }
transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
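Aside (not part of the commit): the hunk above also sets both result handles to empty on failure, so a failed application leaves well-defined mappings instead of unset ones. For a feel of the new diagnostics, a sketch of input that unpackPDLOperations would reject, assuming a hypothetical test.dummy_two_results op:

  // Each matched payload op has two results, so using %szs as a tile size
  // emits the silenceable error "payload op must have exactly 1 index result"
  // and both result handles of the transform op are set to empty.
  %szs = transform.structured.match ops{["test.dummy_two_results"]} in %arg1
  %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%szs]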
diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
index f015cdff93fa4..4dda150ca5341 100644
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -41,6 +41,48 @@ module {
// -----
+// In this test case, the matmul dims and the tile sizes are dynamic.
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic_dynamic(
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[tile_size_1:.*]] = "test.dummy"()
+ // CHECK-DAG: %[[tile_size_2:.*]] = "test.dummy"()
+ // CHECK-DAG: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
+ // CHECK-DAG: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+ // CHECK-DAG: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]], %[[tile_size_1]]]
+ // CHECK-DAG: %[[NT1:.+]] = affine.apply #[[$map0]]()[%[[N]], %[[tile_size_2]]]
+ // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+ // CHECK: tensor.extract_slice %[[A]]
+ // CHECK: tensor.extract_slice %[[B]]
+ // CHECK: tensor.extract_slice %[[C_BLK]]
+ // CHECK: linalg.matmul
+ // CHECK: scf.foreach_thread.perform_concurrently
+ // CHECK-NEXT: tensor.parallel_insert_slice
+ %tile_size_1 = "test.dummy"() : () -> (index)
+ %tile_size_2 = "test.dummy"() : () -> (index)
+ %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+ return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %sz = transform.structured.match ops{["test.dummy"]} in %arg1
+ %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%sz]
+}
+
+// -----
+
// Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 300, 15)>