[Mlir-commits] [mlir] [MLIR][Transform] Add attribute in MatchOp to filter by operand type (PR #67994)

Pablo Antonio Martinez llvmlistbot at llvm.org
Mon Oct 2 08:15:07 PDT 2023


https://github.com/pabloantoniom updated https://github.com/llvm/llvm-project/pull/67994

>From c3b9f85d6d5436b9f2c6fb713b53eaf120d5e581 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Mon, 2 Oct 2023 16:13:41 +0100
Subject: [PATCH] [MLIR][Transform] Add attribute in MatchOp to filter by
 operand type

---
 .../Linalg/TransformOps/LinalgTransformOps.td   |  5 ++++-
 .../Linalg/TransformOps/LinalgTransformOps.cpp  |  9 +++++++++
 .../test/Dialect/Linalg/transform-op-match.mlir | 17 +++++++++++++++++
 3 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 1ff88d036bc036c..71985925b8d94c5 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -535,6 +535,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
       - attribute: the matched op must have all specified attributes (with their
         specified values).
       - filter_result_type: the matched op must return exactly this one type.
+      - filter_operand_type: all the operands of the matched op must be of this type.
 
     Note: Only ops that satisfy all specified constraints are matched.
 
@@ -556,7 +557,8 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
                        OptionalAttr<StrArrayAttr>:$ops,
                        OptionalAttr<MatchInterfaceEnum>:$interface,
                        OptionalAttr<DictionaryAttr>:$op_attrs,
-                       OptionalAttr<TypeAttr>:$filter_result_type);
+                       OptionalAttr<TypeAttr>:$filter_result_type,
+                       OptionalAttr<TypeAttr>:$filter_operand_type);
   // TODO: variadic results when needed.
   let results = (outs TransformHandleTypeInterface:$results);
 
@@ -570,6 +572,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
     (`interface` `{` $interface^ `}`)?
     (`attributes` $op_attrs^)?
     (`filter_result_type` `=` $filter_result_type^)?
+    (`filter_operand_type` `=` $filter_operand_type^)?
     `in` $target attr-dict
     `:` functional-type($target, results)
   }];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9ce780d3d249cfb..260615ece7a9905 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1180,6 +1180,15 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
         return;
     }
 
+    if (getFilterOperandType().has_value()) {
+      Type t = getFilterOperandType().value();
+      for (auto type : op->getOperandTypes()) {
+        if (type != t) {
+          return;
+        }
+      }
+    }
+
     // All constraints are satisfied.
     res.push_back(op);
     return;
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 9db63dc0696dab3..b98aaea9fc70f2f 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -39,6 +39,23 @@ transform.sequence failures(propagate) {
 
 // -----
 
+func.func @by_operand_type() {
+  %c0 = arith.constant 1.0: f32
+  // expected-remark @below {{matched op name}}
+  %res = arith.fptoui %c0 : f32 to i32
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %match_name = transform.structured.match
+    ops{["arith.fptoui"]} filter_operand_type = f32 in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.test_print_remark_at_operand %match_name, "matched op name" : !transform.any_op
+  transform.test_consume_operand %match_name : !transform.any_op
+}
+
+// -----
+
 func.func @foo(%a: tensor<4x4xf32>, %b: tensor<4x4xf32>, %c: tensor<4x4xf32>) {
   %c0 = arith.constant 0.0 : f32
   // expected-remark @below {{tileable}}



More information about the Mlir-commits mailing list