[Mlir-commits] [mlir] Add support for param types in transform.structured.tile_using_forall (PR #72097)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 16 07:19:50 PST 2024
https://github.com/jinchen62 updated https://github.com/llvm/llvm-project/pull/72097
>From 27a4e49152933d3771a9005666dd4deef295ffd3 Mon Sep 17 00:00:00 2001
From: jinchen62 <jinchenye62 at gmail.com>
Date: Mon, 13 Nov 2023 02:19:39 -0800
Subject: [PATCH] Add support for param types in
transform.structured.tile_using_forall
---
.../Linalg/TransformOps/LinalgTransformOps.td | 10 +-
.../TransformOps/LinalgTransformOps.cpp | 47 ++++--
mlir/test/Dialect/Linalg/tile-to-forall.mlir | 146 +++++++++++++++++-
.../TestTransformDialectExtension.cpp | 7 +
.../TestTransformDialectExtension.td | 19 +++
5 files changed, 214 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 7d10ba0ae829e5..7b981c4106e69b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -23,7 +23,7 @@ include "mlir/IR/RegionKindInterface.td"
// value in the payload IR.
def TransformParamTypeOrAnyHandle : Type<
Or<[TransformHandleTypeInterface.predicate,
- Transform_ParamType.predicate]>,
+ TransformParamTypeInterface.predicate]>,
"transform 'param' type or any handle type">;
//===----------------------------------------------------------------------===//
@@ -1965,10 +1965,10 @@ def TileUsingForallOp :
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- Variadic<TransformHandleTypeInterface>:$num_threads,
- Variadic<TransformHandleTypeInterface>:$tile_sizes,
- Optional<TransformHandleTypeInterface>:$packed_num_threads,
- Optional<TransformHandleTypeInterface>:$packed_tile_sizes,
+ Variadic<TransformParamTypeOrAnyHandle>:$num_threads,
+ Variadic<TransformParamTypeOrAnyHandle>:$tile_sizes,
+ Optional<TransformParamTypeOrAnyHandle>:$packed_num_threads,
+ Optional<TransformParamTypeOrAnyHandle>:$packed_tile_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 139566d350fe83..25a871e346b2e1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -86,8 +86,9 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
return cast<LinalgOp>(result->getOperation());
}
-/// Assuming that `ofr` is an index attr or a transform dialect handle mapped
-/// to exactly one op with one index result, return that value.
+/// Assuming that `ofr` is an index attr, a transform param (of any param type)
+/// associated with exactly one integer attr, or a transform dialect handle
+/// mapped to exactly one op with one index result, return that value.
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
transform::TransformState &state, TransformOpInterface transformOp,
SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
@@ -98,12 +99,26 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
result.push_back(ofr);
continue;
}
- auto payloadOps = state.getPayloadOps(ofr.get<Value>());
+
+ Value transformValue = ofr.get<Value>();
+ if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+ ArrayRef<Attribute> params = state.getParams(transformValue);
+ if (params.size() != 1)
+ return transformOp.emitDefiniteFailure()
+ << "requires exactly one parameter associated";
+ // Like the packed-handle path, only integer attrs are accepted.
+ if (!isa<IntegerAttr>(params[0]))
+ return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+ result.push_back(params[0]);
+ continue;
+ }
+
+ auto payloadOps = state.getPayloadOps(transformValue);
if (!llvm::hasSingleElement(payloadOps)) {
DiagnosedSilenceableFailure diag =
transformOp.emitSilenceableError()
<< "handle must be mapped to exactly one payload op";
- diag.attachNote(ofr.get<Value>().getLoc())
+ diag.attachNote(transformValue.getLoc())
<< "mapped to " << llvm::range_size(payloadOps) << " payload ops";
return diag;
}
@@ -123,14 +138,31 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
return DiagnosedSilenceableFailure::success();
}
-// Given a list of OpFoldResults that are either index attrs or op
-// handles, return a list of OpFoldResults where all op handles are
-// replaced with the first (and only) OpResult of that payload op. (There
-// must be exactly one mapped payload op and it must have exactly one
-// index result.)
+// Given a packed handle that is either a transform param or an op handle,
+// append the values it packs to `result`. A param must be associated with
+// exactly one ArrayAttr whose elements are all IntegerAttrs; those attrs are
+// appended. For an op handle, the first (and only) OpResult of each mapped
+// payload op is appended instead. (Each mapped payload op must have exactly
+// one index result.)
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
transform::TransformState &state, TransformOpInterface transformOp,
SmallVector<OpFoldResult> &result, Value packedHandle) {
+ if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
+ ArrayRef<Attribute> params = state.getParams(packedHandle);
+ if (params.size() != 1)
+ return transformOp.emitDefiniteFailure()
+ << "requires exactly one parameter associated";
+ ArrayAttr paramsArray = dyn_cast<ArrayAttr>(params[0]);
+ if (!paramsArray)
+ return transformOp.emitDefiniteFailure() << "expected ArrayAttr";
+ for (Attribute param : paramsArray.getValue()) {
+ if (!isa<IntegerAttr>(param))
+ return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+ result.push_back(param);
+ }
+ return DiagnosedSilenceableFailure::success();
+ }
+
for (Operation *op : state.getPayloadOps(packedHandle)) {
if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
DiagnosedSilenceableFailure diag =
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index 2192d160b1150f..792bff8cf0c694 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter -canonicalize -cse -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
// Offset per thread:
// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
@@ -451,3 +451,147 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic(
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
+ // CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+ // CHECK: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]]]
+ // CHECK: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
+ // CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+ // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+ // CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+ // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+ // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+ // CHECK: tensor.extract_slice %[[A]]
+ // CHECK: tensor.extract_slice %[[B]]
+ // CHECK: tensor.extract_slice %[[C_BLK]]
+ // CHECK: linalg.matmul
+ // CHECK: scf.forall.in_parallel
+ // CHECK-NEXT: tensor.parallel_insert_slice
+ %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+ return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %sz = transform.param.constant 10 : i64 -> !transform.param<i64>
+ %1:2 = transform.structured.tile_using_forall %0 tile_sizes [%sz : !transform.param<i64>, 20]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic(
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
+ // CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+ // CHECK: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]]]
+ // CHECK: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
+ // CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+ // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+ // CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+ // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+ // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+ // CHECK: tensor.extract_slice %[[A]]
+ // CHECK: tensor.extract_slice %[[B]]
+ // CHECK: tensor.extract_slice %[[C_BLK]]
+ // CHECK: linalg.matmul
+ // CHECK: scf.forall.in_parallel
+ // CHECK-NEXT: tensor.parallel_insert_slice
+ %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+ return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %sz = transform.param.constant [10 : i64, 20 : i64] -> !transform.any_param
+ %1:2 = transform.structured.tile_using_forall %0 tile_sizes *(%sz : !transform.any_param)
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+ return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %sz = transform.param.constant "[10 : i64, 20 : i64]" -> !transform.any_param
+ // expected-error @below {{expected ArrayAttr}}
+ %1:2 = transform.structured.tile_using_forall %0 tile_sizes *(%sz : !transform.any_param)
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+ return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %sz = transform.param.constant ["10", "20"] -> !transform.any_param
+ // expected-error @below {{expected IntegerAttr}}
+ %1:2 = transform.structured.tile_using_forall %0 tile_sizes *(%sz : !transform.any_param)
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+ return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %sz = transform.test_produce_empty_param %0 : (!transform.any_op) -> !transform.any_param
+ // expected-error @below {{requires exactly one parameter associated}}
+ %1:2 = transform.structured.tile_using_forall %0 tile_sizes *(%sz : !transform.any_param)
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 50caf8f9cfc709..4ef5b27100c404 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -803,6 +803,16 @@ void mlir::test::TestProduceInvalidIR::getEffects(
transform::modifiesPayload(effects);
}
+// Pushes nothing into `results`. NOTE(review): presumably only meant to be
+// applied to a `target` handle mapped to no payload ops, so that the param
+// result ends up empty; confirm against TransformEachOpTrait's result checks.
+DiagnosedSilenceableFailure mlir::test::TestProduceEmptyParamOp::applyToOne(
+ transform::TransformRewriter &rewriter, Operation *target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ return DiagnosedSilenceableFailure::success();
+}
+
namespace {
/// Test conversion pattern that replaces ops with the "replace_with_new_op"
/// attribute with "test.new_op".
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 54036f7929d1b8..49832fa63e2293 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -548,4 +548,25 @@ def TestProduceInvalidIR
}];
}
+// Test-only op whose param result is meant to carry no parameters, used e.g.
+// to exercise "requires exactly one parameter associated" diagnostics.
+def TestProduceEmptyParamOp :
+ Op<Transform_Dialect, "test_produce_empty_param",
+ [MemoryEffectsOpInterface,
+ TransformOpInterface,
+ TransformEachOpTrait,
+ ParamProducerTransformOpTrait]> {
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformParamTypeInterface:$out);
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+ let cppNamespace = "mlir::test";
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
More information about the Mlir-commits
mailing list