[Mlir-commits] [mlir] [MLIR][Transform] Add attribute in MatchOp to filter by operand type (PR #67994)
Pablo Antonio Martinez
llvmlistbot at llvm.org
Fri Nov 10 03:44:13 PST 2023
https://github.com/pabloantoniom updated https://github.com/llvm/llvm-project/pull/67994
>From c3b9f85d6d5436b9f2c6fb713b53eaf120d5e581 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Mon, 2 Oct 2023 16:13:41 +0100
Subject: [PATCH 1/4] [MLIR][Transform] Add attribute in MatchOp to filter by
operand type
---
.../Linalg/TransformOps/LinalgTransformOps.td | 5 ++++-
.../Linalg/TransformOps/LinalgTransformOps.cpp | 9 +++++++++
.../test/Dialect/Linalg/transform-op-match.mlir | 17 +++++++++++++++++
3 files changed, 30 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 1ff88d036bc036c..71985925b8d94c5 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -535,6 +535,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
- attribute: the matched op must have all specified attributes (with their
specified values).
- filter_result_type: the matched op must return exactly this one type.
+ - filter_operand_type: all the operands of the matched op must must be of this type.
Note: Only ops that satisfy all specified constraints are matched.
@@ -556,7 +557,8 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
OptionalAttr<StrArrayAttr>:$ops,
OptionalAttr<MatchInterfaceEnum>:$interface,
OptionalAttr<DictionaryAttr>:$op_attrs,
- OptionalAttr<TypeAttr>:$filter_result_type);
+ OptionalAttr<TypeAttr>:$filter_result_type,
+ OptionalAttr<TypeAttr>:$filter_operand_type);
// TODO: variadic results when needed.
let results = (outs TransformHandleTypeInterface:$results);
@@ -570,6 +572,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
(`interface` `{` $interface^ `}`)?
(`attributes` $op_attrs^)?
(`filter_result_type` `=` $filter_result_type^)?
+ (`filter_operand_type` `=` $filter_operand_type^)?
`in` $target attr-dict
`:` functional-type($target, results)
}];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9ce780d3d249cfb..260615ece7a9905 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1180,6 +1180,15 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
return;
}
+ if (getFilterOperandType().has_value()) {
+ Type t = getFilterOperandType().value();
+ for (auto type : op->getOperandTypes()) {
+ if (type != t) {
+ return;
+ }
+ }
+ }
+
// All constraints are satisfied.
res.push_back(op);
return;
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 9db63dc0696dab3..b98aaea9fc70f2f 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -39,6 +39,23 @@ transform.sequence failures(propagate) {
// -----
+func.func @by_operand_type() {
+ %c0 = arith.constant 1.0: f32
+ // expected-remark @below {{matched op name}}
+ %res = arith.fptoui %c0 : f32 to i32
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %match_name = transform.structured.match
+ ops{["arith.fptoui"]} filter_operand_type = f32 in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
+ transform.test_consume_operand %match_name : !transform.any_op
+}
+
+// -----
+
func.func @foo(%a: tensor<4x4xf32>, %b: tensor<4x4xf32>, %c: tensor<4x4xf32>) {
%c0 = arith.constant 0.0 : f32
// expected-remark @below {{tileable}}
>From e34f99517af901734380c8fba188d9feb750549f Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Tue, 7 Nov 2023 16:04:00 +0000
Subject: [PATCH 2/4] [MLIR][Transform] Add negative test case for
filter_operand_type in MatchOp
---
.../Dialect/Linalg/transform-op-match.mlir | 32 +++++++++++++++++++
1 file changed, 32 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index b98aaea9fc70f2f..5b8892ba3f7bae8 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -56,6 +56,38 @@ transform.sequence failures(propagate) {
// -----
+func.func @by_operand_type_negative_match() {
+ %c0 = arith.constant 1.0: f32
+ %res = arith.fptoui %c0 : f32 to i32
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %match_name = transform.structured.match
+ ops{["arith.fptoui"]} filter_operand_type = i32 in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
+ transform.test_consume_operand %match_name : !transform.any_op
+}
+
+// -----
+
+func.func @by_operand_type_negative_match() {
+ %c0 = arith.constant 1.0: f32
+ %res = arith.fptoui %c0 : f32 to i32
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %match_name = transform.structured.match
+ ops{["arith.fptoui"]} filter_operand_type = i32 in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
+ transform.test_consume_operand %match_name : !transform.any_op
+}
+
+// -----
+
func.func @foo(%a: tensor<4x4xf32>, %b: tensor<4x4xf32>, %c: tensor<4x4xf32>) {
%c0 = arith.constant 0.0 : f32
// expected-remark @below {{tileable}}
>From 5ba5bcbd879f768f217fed0999be6938db386bc8 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Fri, 10 Nov 2023 11:27:33 +0000
Subject: [PATCH 3/4] [MLIR][Transform] Simplify test cases. Use llvm::all_of
to filter types
---
.../Linalg/TransformOps/LinalgTransformOps.td | 3 +-
.../TransformOps/LinalgTransformOps.cpp | 9 +++--
.../Dialect/Linalg/transform-op-match.mlir | 33 ++-----------------
3 files changed, 9 insertions(+), 36 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 71985925b8d94c5..6d60ab07d01132d 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -535,7 +535,8 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
- attribute: the matched op must have all specified attributes (with their
specified values).
- filter_result_type: the matched op must return exactly this one type.
- - filter_operand_type: all the operands of the matched op must must be of this type.
+ - filter_operand_type: all the operands of the matched op must be of
+ this type (ops with no operands trivially satisfy this).
Note: Only ops that satisfy all specified constraints are matched.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 260615ece7a9905..64e991c8518282d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1182,11 +1182,10 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
if (getFilterOperandType().has_value()) {
Type t = getFilterOperandType().value();
- for (auto type : op->getOperandTypes()) {
- if (type != t) {
- return;
- }
- }
+ if (!llvm::all_of(op->getOperandTypes(), [&](Type operandType) {
+ return operandType == t;
+ }))
+ return;
}
// All constraints are satisfied.
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 5b8892ba3f7bae8..647afb6ffe11a0c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -52,38 +52,11 @@ transform.sequence failures(propagate) {
ops{["arith.fptoui"]} filter_operand_type = f32 in %arg1 : (!transform.any_op) -> !transform.any_op
transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
transform.test_consume_operand %match_name : !transform.any_op
-}
-
-// -----
-
-func.func @by_operand_type_negative_match() {
- %c0 = arith.constant 1.0: f32
- %res = arith.fptoui %c0 : f32 to i32
- return
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg1: !transform.any_op):
- %match_name = transform.structured.match
- ops{["arith.fptoui"]} filter_operand_type = i32 in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
- transform.test_consume_operand %match_name : !transform.any_op
-}
-
-// -----
-
-func.func @by_operand_type_negative_match() {
- %c0 = arith.constant 1.0: f32
- %res = arith.fptoui %c0 : f32 to i32
- return
-}
-transform.sequence failures(propagate) {
-^bb1(%arg1: !transform.any_op):
- %match_name = transform.structured.match
+ %no_match_name = transform.structured.match
ops{["arith.fptoui"]} filter_operand_type = i32 in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
- transform.test_consume_operand %match_name : !transform.any_op
+ transform.test_print_remark_at_operand %no_match_name, "should not match" : !transform.any_op
+ transform.test_consume_operand %no_match_name : !transform.any_op
}
// -----
>From c867db38194851c37f67f412a37ba6d61907f739 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Fri, 10 Nov 2023 11:42:38 +0000
Subject: [PATCH 4/4] [MLIR][Transform] Run clang-format
---
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 64e991c8518282d..68f2d1fb0ebcf48 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1182,9 +1182,8 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
if (getFilterOperandType().has_value()) {
Type t = getFilterOperandType().value();
- if (!llvm::all_of(op->getOperandTypes(), [&](Type operandType) {
- return operandType == t;
- }))
+ if (!llvm::all_of(op->getOperandTypes(),
+ [&](Type operandType) { return operandType == t; }))
return;
}
More information about the Mlir-commits
mailing list