[Mlir-commits] [mlir] b396e54 - Reland "[MLIR][Transform] Add attribute in MatchOp to filter by operand type (#67994)"

Pablo Antonio Martinez llvmlistbot at llvm.org
Thu Dec 7 03:58:08 PST 2023


Author: Pablo Antonio Martinez
Date: 2023-12-07T11:57:02Z
New Revision: b396e5429c9d5d18517a67e5c086f1013f47944f

URL: https://github.com/llvm/llvm-project/commit/b396e5429c9d5d18517a67e5c086f1013f47944f
DIFF: https://github.com/llvm/llvm-project/commit/b396e5429c9d5d18517a67e5c086f1013f47944f.diff

LOG: Reland "[MLIR][Transform] Add attribute in MatchOp to filter by operand type (#67994)"

The test was failing because it used an outdated transform sequence declaration (an unnamed `transform.sequence` was used, whereas a `transform.named_sequence` is now required). The test is now fixed.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-op-match.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index de65f3176c46a..77ed9db5e71bd 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -574,6 +574,11 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
       - attribute: the matched op must have all specified attributes (with their
         specified values).
       - filter_result_type: the matched op must return exactly this one type.
+      - filter_operand_types: all the operands of the matched op must be of
+        this type. If more than one type is specified, then the length of the
+        list must be equal to the number of operands in the matched op, and the
+        match will succeed only if the operand types match all the types in the
+        list in the order in which they are specified.
 
     Note: Only ops that satisfy all specified constraints are matched.
 
@@ -595,7 +600,8 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
                        OptionalAttr<StrArrayAttr>:$ops,
                        OptionalAttr<MatchInterfaceEnum>:$interface,
                        OptionalAttr<DictionaryAttr>:$op_attrs,
-                       OptionalAttr<TypeAttr>:$filter_result_type);
+                       OptionalAttr<TypeAttr>:$filter_result_type,
+                       OptionalAttr<TypeArrayAttr>:$filter_operand_types);
   // TODO: variadic results when needed.
   let results = (outs TransformHandleTypeInterface:$results);
 
@@ -609,6 +615,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
     (`interface` `{` $interface^ `}`)?
     (`attributes` $op_attrs^)?
     (`filter_result_type` `=` $filter_result_type^)?
+    (`filter_operand_types` `=` $filter_operand_types^)?
     `in` $target attr-dict
     `:` functional-type($target, results)
   }];

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e3713457e8412..54055aefbc512 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1171,6 +1171,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
   }
 
   SmallVector<Operation *> res;
+  bool incorrectNumOperandTypes = false;
   auto matchFun = [&](Operation *op) {
     if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
       return;
@@ -1210,12 +1211,47 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
         return;
     }
 
+    if (getFilterOperandTypes().has_value()) {
+      mlir::ArrayAttr types = getFilterOperandTypes().value();
+      auto operandTypes = op->getOperandTypes();
+
+      if (types.size() == 1) {
+        // All the operands must be equal to the specified type.
+        auto typeattr =
+            dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
+        Type t = typeattr.getValue().cast<::mlir::Type>();
+        if (!llvm::all_of(op->getOperandTypes(),
+                          [&](Type operandType) { return operandType == t; }))
+          return;
+      } else {
+        // The operand types must match all the types in the list (in the same
+        // order in which they are specified).
+        if (types.size() != operandTypes.size()) {
+          incorrectNumOperandTypes = true;
+          return;
+        }
+
+        for (auto [attr, operandType] :
+             llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
+          auto typeattr = cast<mlir::TypeAttr>(attr);
+          Type type = typeattr.getValue().cast<::mlir::Type>();
+
+          if (type != operandType)
+            return;
+        }
+      }
+    }
+
     // All constraints are satisfied.
     res.push_back(op);
     return;
   };
 
   (*payloadOps.begin())->walk(matchFun);
+  if (incorrectNumOperandTypes)
+    return emitDefiniteFailure("If filter_operand_types contains more than a "
+                               "type, then it must contain as much types as "
+                               "the number of operands in the target ops");
   results.set(cast<OpResult>(getResult()), res);
   return DiagnosedSilenceableFailure::success();
 }

diff  --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 7d48b1f403b3b..fed3c007d9b6d 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -43,6 +43,44 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @by_operand_type() {
+  %c2 = arith.constant 2.0: f32
+  %v = arith.constant 8: i32
+  %r1 = math.fpowi %c2, %v : f32, i32
+  // expected-remark @below {{matched op name}}
+  %r2 = arith.addf %c2, %c2 : f32
+  // expected-remark @below {{matched op name}}
+  %r3 = arith.fptoui %r2 : f32 to i32
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %match_name1 = transform.structured.match
+      ops{["arith.fptoui"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.test_print_remark_at_operand %match_name1, "matched op name" : !transform.any_op
+    transform.test_consume_operand %match_name1 : !transform.any_op
+
+    %match_name2 = transform.structured.match
+      ops{["arith.addf"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.test_print_remark_at_operand %match_name2, "matched op name" : !transform.any_op
+    transform.test_consume_operand %match_name2 : !transform.any_op
+
+    %no_match_name1 = transform.structured.match
+      ops{["arith.fptoui"]} filter_operand_types = [i32] in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.test_print_remark_at_operand %no_match_name1, "should not match" : !transform.any_op
+    transform.test_consume_operand %no_match_name1 : !transform.any_op
+
+    %no_match_name2 = transform.structured.match
+      ops{["math.fpowi"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.test_print_remark_at_operand %no_match_name2, "should not match" : !transform.any_op
+    transform.test_consume_operand %no_match_name2 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @foo(%a: tensor<4x4xf32>, %b: tensor<4x4xf32>, %c: tensor<4x4xf32>) {
   %c0 = arith.constant 0.0 : f32
   // expected-remark @below {{tileable}}


        


More information about the Mlir-commits mailing list