[Mlir-commits] [mlir] [MLIR][Transform] Add attribute in MatchOp to filter by operand type (PR #67994)

Pablo Antonio Martinez llvmlistbot at llvm.org
Mon Dec 4 07:47:26 PST 2023


https://github.com/pabloantoniom updated https://github.com/llvm/llvm-project/pull/67994

>From c3b9f85d6d5436b9f2c6fb713b53eaf120d5e581 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Mon, 2 Oct 2023 16:13:41 +0100
Subject: [PATCH 1/6] [MLIR][Transform] Add attribute in MatchOp to filter by
 operand type

---
 .../Linalg/TransformOps/LinalgTransformOps.td   |  5 ++++-
 .../Linalg/TransformOps/LinalgTransformOps.cpp  |  9 +++++++++
 .../test/Dialect/Linalg/transform-op-match.mlir | 17 +++++++++++++++++
 3 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 1ff88d036bc03..71985925b8d94 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -535,6 +535,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
       - attribute: the matched op must have all specified attributes (with their
         specified values).
       - filter_result_type: the matched op must return exactly this one type.
+      - filter_operand_type: all the operands of the matched op must must be of this type.
 
     Note: Only ops that satisfy all specified constraints are matched.
 
@@ -556,7 +557,8 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
                        OptionalAttr<StrArrayAttr>:$ops,
                        OptionalAttr<MatchInterfaceEnum>:$interface,
                        OptionalAttr<DictionaryAttr>:$op_attrs,
-                       OptionalAttr<TypeAttr>:$filter_result_type);
+                       OptionalAttr<TypeAttr>:$filter_result_type,
+                       OptionalAttr<TypeAttr>:$filter_operand_type);
   // TODO: variadic results when needed.
   let results = (outs TransformHandleTypeInterface:$results);
 
@@ -570,6 +572,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
     (`interface` `{` $interface^ `}`)?
     (`attributes` $op_attrs^)?
     (`filter_result_type` `=` $filter_result_type^)?
+    (`filter_operand_type` `=` $filter_operand_type^)?
     `in` $target attr-dict
     `:` functional-type($target, results)
   }];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9ce780d3d249c..260615ece7a99 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1180,6 +1180,15 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
         return;
     }
 
+    if (getFilterOperandType().has_value()) {
+      Type t = getFilterOperandType().value();
+      for (auto type : op->getOperandTypes()) {
+        if (type != t) {
+          return;
+        }
+      }
+    }
+
     // All constraints are satisfied.
     res.push_back(op);
     return;
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 9db63dc0696da..b98aaea9fc70f 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -39,6 +39,23 @@ transform.sequence failures(propagate) {
 
 // -----
 
+func.func @by_operand_type() {
+  %c0 = arith.constant 1.0: f32
+  // expected-remark @below {{matched op name}}
+  %res = arith.fptoui %c0 : f32 to i32
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %match_name = transform.structured.match
+    ops{["arith.fptoui"]} filter_operand_type = f32 in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
+  transform.test_consume_operand %match_name : !transform.any_op
+}
+
+// -----
+
 func.func @foo(%a: tensor<4x4xf32>, %b: tensor<4x4xf32>, %c: tensor<4x4xf32>) {
   %c0 = arith.constant 0.0 : f32
   // expected-remark @below {{tileable}}

>From e34f99517af901734380c8fba188d9feb750549f Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Tue, 7 Nov 2023 16:04:00 +0000
Subject: [PATCH 2/6] [MLIR][Transform] Add negative test case for
 filter_operand_type in MatchOp

---
 .../Dialect/Linalg/transform-op-match.mlir    | 32 +++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index b98aaea9fc70f..5b8892ba3f7ba 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -56,6 +56,38 @@ transform.sequence failures(propagate) {
 
 // -----
 
+func.func @by_operand_type_negative_match() {
+  %c0 = arith.constant 1.0: f32
+  %res = arith.fptoui %c0 : f32 to i32
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %match_name = transform.structured.match
+    ops{["arith.fptoui"]} filter_operand_type = i32 in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
+  transform.test_consume_operand %match_name : !transform.any_op
+}
+
+// -----
+
+func.func @by_operand_type_negative_match() {
+  %c0 = arith.constant 1.0: f32
+  %res = arith.fptoui %c0 : f32 to i32
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %match_name = transform.structured.match
+    ops{["arith.fptoui"]} filter_operand_type = i32 in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
+  transform.test_consume_operand %match_name : !transform.any_op
+}
+
+// -----
+
 func.func @foo(%a: tensor<4x4xf32>, %b: tensor<4x4xf32>, %c: tensor<4x4xf32>) {
   %c0 = arith.constant 0.0 : f32
   // expected-remark @below {{tileable}}

>From 5ba5bcbd879f768f217fed0999be6938db386bc8 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Fri, 10 Nov 2023 11:27:33 +0000
Subject: [PATCH 3/6] [MLIR][Transform] Simplify test cases. Use llvm::all_of
 to filter types

---
 .../Linalg/TransformOps/LinalgTransformOps.td |  3 +-
 .../TransformOps/LinalgTransformOps.cpp       |  9 +++--
 .../Dialect/Linalg/transform-op-match.mlir    | 33 ++-----------------
 3 files changed, 9 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 71985925b8d94..6d60ab07d0113 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -535,7 +535,8 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
       - attribute: the matched op must have all specified attributes (with their
         specified values).
       - filter_result_type: the matched op must return exactly this one type.
-      - filter_operand_type: all the operands of the matched op must must be of this type.
+      - filter_operand_type: all the operands of the matched op must must be of
+        this type.
 
     Note: Only ops that satisfy all specified constraints are matched.
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 260615ece7a99..64e991c851828 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1182,11 +1182,10 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
 
     if (getFilterOperandType().has_value()) {
       Type t = getFilterOperandType().value();
-      for (auto type : op->getOperandTypes()) {
-        if (type != t) {
-          return;
-        }
-      }
+      if (!llvm::all_of(op->getOperandTypes(), [&](Type operandType) {
+        return operandType == t;
+      }))
+        return;
     }
 
     // All constraints are satisfied.
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 5b8892ba3f7ba..647afb6ffe11a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -52,38 +52,11 @@ transform.sequence failures(propagate) {
     ops{["arith.fptoui"]} filter_operand_type = f32 in %arg1 : (!transform.any_op) -> !transform.any_op
   transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
   transform.test_consume_operand %match_name : !transform.any_op
-}
-
-// -----
-
-func.func @by_operand_type_negative_match() {
-  %c0 = arith.constant 1.0: f32
-  %res = arith.fptoui %c0 : f32 to i32
-  return
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg1: !transform.any_op):
-  %match_name = transform.structured.match
-    ops{["arith.fptoui"]} filter_operand_type = i32 in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
-  transform.test_consume_operand %match_name : !transform.any_op
-}
-
-// -----
-
-func.func @by_operand_type_negative_match() {
-  %c0 = arith.constant 1.0: f32
-  %res = arith.fptoui %c0 : f32 to i32
-  return
-}
 
-transform.sequence failures(propagate) {
-^bb1(%arg1: !transform.any_op):
-  %match_name = transform.structured.match
+  %no_match_name = transform.structured.match
     ops{["arith.fptoui"]} filter_operand_type = i32 in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
-  transform.test_consume_operand %match_name : !transform.any_op
+  transform.test_print_remark_at_operand %no_match_name, "should not match" : !transform.any_op
+  transform.test_consume_operand %no_match_name : !transform.any_op
 }
 
 // -----

>From c867db38194851c37f67f412a37ba6d61907f739 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Fri, 10 Nov 2023 11:42:38 +0000
Subject: [PATCH 4/6] [MLIR][Transform] Run clang-format

---
 mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 64e991c851828..68f2d1fb0ebcf 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1182,9 +1182,8 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
 
     if (getFilterOperandType().has_value()) {
       Type t = getFilterOperandType().value();
-      if (!llvm::all_of(op->getOperandTypes(), [&](Type operandType) {
-        return operandType == t;
-      }))
+      if (!llvm::all_of(op->getOperandTypes(),
+                        [&](Type operandType) { return operandType == t; }))
         return;
     }
 

>From 50932215cbabd5ffb63b0d4875e3a860bf96d892 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Thu, 16 Nov 2023 16:03:03 +0000
Subject: [PATCH 5/6] [MLIR][Transform] Convert filter_operand_type into a list
 of types

---
 .../Linalg/TransformOps/LinalgTransformOps.td | 11 +++++----
 .../TransformOps/LinalgTransformOps.cpp       | 24 +++++++++++++++----
 .../Dialect/Linalg/transform-op-match.mlir    |  8 +++++--
 3 files changed, 33 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 6d60ab07d0113..b3e47b4e343c1 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -535,8 +535,11 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
       - attribute: the matched op must have all specified attributes (with their
         specified values).
       - filter_result_type: the matched op must return exactly this one type.
-      - filter_operand_type: all the operands of the matched op must must be of
-        this type.
+      - filter_operand_types: all the operands of the matched op must be of
+        this type. If more than one type is specified, then the length of the
+        list must be equal to the number of operands in the matched op, and the
+        match will succeed only if the operand types match all the types in the
+        list in the order in which they are specified.
 
     Note: Only ops that satisfy all specified constraints are matched.
 
@@ -559,7 +562,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
                        OptionalAttr<MatchInterfaceEnum>:$interface,
                        OptionalAttr<DictionaryAttr>:$op_attrs,
                        OptionalAttr<TypeAttr>:$filter_result_type,
-                       OptionalAttr<TypeAttr>:$filter_operand_type);
+                       OptionalAttr<TypeArrayAttr>:$filter_operand_types);
   // TODO: variadic results when needed.
   let results = (outs TransformHandleTypeInterface:$results);
 
@@ -573,7 +576,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
     (`interface` `{` $interface^ `}`)?
     (`attributes` $op_attrs^)?
     (`filter_result_type` `=` $filter_result_type^)?
-    (`filter_operand_type` `=` $filter_operand_type^)?
+    (`filter_operand_types` `=` $filter_operand_types^)?
     `in` $target attr-dict
     `:` functional-type($target, results)
   }];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 68f2d1fb0ebcf..702b4544a18e4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1141,6 +1141,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
   }
 
   SmallVector<Operation *> res;
+  bool wrong_operand_filter = false;
   auto matchFun = [&](Operation *op) {
     if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
       return;
@@ -1180,11 +1181,23 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
         return;
     }
 
-    if (getFilterOperandType().has_value()) {
-      Type t = getFilterOperandType().value();
-      if (!llvm::all_of(op->getOperandTypes(),
-                        [&](Type operandType) { return operandType == t; }))
+    if (getFilterOperandTypes().has_value()) {
+      mlir::ArrayAttr types = getFilterOperandTypes().value();
+      auto operandTypes = op->getOperandTypes();
+      if (types.size() != operandTypes.size()) {
+        wrong_operand_filter = true;
         return;
+      }
+
+      for (auto const &it :
+           llvm::zip(getFilterOperandTypes().value(), operandTypes)) {
+        auto attr = dyn_cast<mlir::TypeAttr>(std::get<0>(it));
+        Type type = attr.getValue().cast<::mlir::Type>();
+        Type t = getElementTypeOrSelf(std::get<1>(it));
+
+        if (type != t)
+          return;
+      }
     }
 
     // All constraints are satisfied.
@@ -1193,6 +1206,9 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
   };
 
   (*payloadOps.begin())->walk(matchFun);
+  if (wrong_operand_filter)
+    return emitDefiniteFailure("filter_operand_types length must be equal to "
+                               "the number of operands in the target ops");
   results.set(cast<OpResult>(getResult()), res);
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 647afb6ffe11a..28c53bc6f35a2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -49,14 +49,18 @@ func.func @by_operand_type() {
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %match_name = transform.structured.match
-    ops{["arith.fptoui"]} filter_operand_type = f32 in %arg1 : (!transform.any_op) -> !transform.any_op
+    ops{["arith.fptoui"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
   transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
   transform.test_consume_operand %match_name : !transform.any_op
 
   %no_match_name = transform.structured.match
-    ops{["arith.fptoui"]} filter_operand_type = i32 in %arg1 : (!transform.any_op) -> !transform.any_op
+    ops{["arith.fptoui"]} filter_operand_types = [i32] in %arg1 : (!transform.any_op) -> !transform.any_op
   transform.test_print_remark_at_operand %no_match_name, "should not match" : !transform.any_op
   transform.test_consume_operand %no_match_name : !transform.any_op
+
+  // expected-error @+1 {{filter_operand_types length must be equal to the number of operands in the target ops}}
+  %failure_match = transform.structured.match
+    ops{["arith.fptoui"]} filter_operand_types = [i32, i32] in %arg1 : (!transform.any_op) -> !transform.any_op
 }
 
 // -----

>From e2545a9da9f32f16a468ca54d7d238dd552ecef3 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Mon, 4 Dec 2023 15:46:14 +0000
Subject: [PATCH 6/6] [MLIR][Transform] Fix the behaviour of
 filter_operand_type with more than one type. Small fixes. Add more testcases

---
 .../TransformOps/LinalgTransformOps.cpp       | 41 ++++++++++++-------
 .../Dialect/Linalg/transform-op-match.mlir    | 32 +++++++++++----
 2 files changed, 50 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 702b4544a18e4..421ff81680a48 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1141,7 +1141,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
   }
 
   SmallVector<Operation *> res;
-  bool wrong_operand_filter = false;
+  bool incorrectNumOperandTypes = false;
   auto matchFun = [&](Operation *op) {
     if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
       return;
@@ -1184,19 +1184,31 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
     if (getFilterOperandTypes().has_value()) {
       mlir::ArrayAttr types = getFilterOperandTypes().value();
       auto operandTypes = op->getOperandTypes();
-      if (types.size() != operandTypes.size()) {
-        wrong_operand_filter = true;
-        return;
-      }
 
-      for (auto const &it :
-           llvm::zip(getFilterOperandTypes().value(), operandTypes)) {
-        auto attr = dyn_cast<mlir::TypeAttr>(std::get<0>(it));
-        Type type = attr.getValue().cast<::mlir::Type>();
-        Type t = getElementTypeOrSelf(std::get<1>(it));
-
-        if (type != t)
+      if (types.size() == 1) {
+        // All the operands must be equal to the specified type
+        auto typeattr =
+            dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
+        Type t = typeattr.getValue().cast<::mlir::Type>();
+        if (!llvm::all_of(op->getOperandTypes(),
+                          [&](Type operandType) { return operandType == t; }))
+          return;
+      } else {
+        // The operand types must match all the types in the list (in the same
+        // order in which they are specified)
+        if (types.size() != operandTypes.size()) {
+          incorrectNumOperandTypes = true;
           return;
+        }
+
+        for (auto [attr, operandType] :
+             llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
+          auto typeattr = dyn_cast<mlir::TypeAttr>(attr);
+          Type type = typeattr.getValue().cast<::mlir::Type>();
+
+          if (type != operandType)
+            return;
+        }
       }
     }
 
@@ -1206,8 +1218,9 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
   };
 
   (*payloadOps.begin())->walk(matchFun);
-  if (wrong_operand_filter)
-    return emitDefiniteFailure("filter_operand_types length must be equal to "
+  if (incorrectNumOperandTypes)
+    return emitDefiniteFailure("If filter_operand_types contains more than a "
+                               "type, then it must contain as much types as "
                                "the number of operands in the target ops");
   results.set(cast<OpResult>(getResult()), res);
   return DiagnosedSilenceableFailure::success();
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 28c53bc6f35a2..85e3d2bfec17e 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -40,25 +40,39 @@ transform.sequence failures(propagate) {
 // -----
 
 func.func @by_operand_type() {
-  %c0 = arith.constant 1.0: f32
+  %c2 = arith.constant 2.0: f32
+  %v = arith.constant 8: i32
+  %r1 = math.fpowi %c2, %v : f32, i32
   // expected-remark @below {{matched op name}}
-  %res = arith.fptoui %c0 : f32 to i32
+  %r2 = arith.addf %c2, %c2 : f32
+  // expected-remark @below {{matched op name}}
+  %r3 = arith.fptoui %r2 : f32 to i32
   return
 }
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
-  %match_name = transform.structured.match
+  %match_name1 = transform.structured.match
     ops{["arith.fptoui"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
-  transform.test_consume_operand %match_name : !transform.any_op
+  transform.test_print_remark_at_operand %match_name1, "matched op name" : !transform.any_op
+  transform.test_consume_operand %match_name1 : !transform.any_op
+
+  %match_name2 = transform.structured.match
+    ops{["arith.addf"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.test_print_remark_at_operand %match_name2, "matched op name" : !transform.any_op
+  transform.test_consume_operand %match_name2 : !transform.any_op
 
-  %no_match_name = transform.structured.match
+  %no_match_name1 = transform.structured.match
     ops{["arith.fptoui"]} filter_operand_types = [i32] in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.test_print_remark_at_operand %no_match_name, "should not match" : !transform.any_op
-  transform.test_consume_operand %no_match_name : !transform.any_op
+  transform.test_print_remark_at_operand %no_match_name1, "should not match" : !transform.any_op
+  transform.test_consume_operand %no_match_name1 : !transform.any_op
+
+  %no_match_name2 = transform.structured.match
+    ops{["math.fpowi"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.test_print_remark_at_operand %no_match_name2, "should not match" : !transform.any_op
+  transform.test_consume_operand %no_match_name2 : !transform.any_op
 
-  // expected-error @+1 {{filter_operand_types length must be equal to the number of operands in the target ops}}
+  // expected-error @+1 {{If filter_operand_types contains more than a type, then it must contain as much types as the number of operands in the target ops}}
   %failure_match = transform.structured.match
     ops{["arith.fptoui"]} filter_operand_types = [i32, i32] in %arg1 : (!transform.any_op) -> !transform.any_op
 }



More information about the Mlir-commits mailing list