[Mlir-commits] [mlir] 593c14d - [mlir][Linalg] Add return type filter to the transform dialect

Nicolas Vasilache llvmlistbot at llvm.org
Wed Sep 14 08:51:04 PDT 2022


Author: Nicolas Vasilache
Date: 2022-09-14T08:50:31-07:00
New Revision: 593c14d422e01cd7d6698321c62e3ac266e8cfb3

URL: https://github.com/llvm/llvm-project/commit/593c14d422e01cd7d6698321c62e3ac266e8cfb3
DIFF: https://github.com/llvm/llvm-project/commit/593c14d422e01cd7d6698321c62e3ac266e8cfb3.diff

LOG: [mlir][Linalg] Add return type filter to the transform dialect

This allows matching ops by additionally specifying a result type: only ops that return exactly one result, of that type, are matched.

Differential Revision: https://reviews.llvm.org/D133862

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-op-match.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 571d01b83d15f..46d70a6561b0a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -199,6 +199,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
         names.
       - attribute: the matched op must have all specified attributes (with their
         specified values).
+      - filter_result_type: the matched op must return exactly this one type.
       
     Note: Only ops that satisfy all specified constraints are matched.
 
@@ -219,7 +220,8 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
   let arguments = (ins PDL_Operation:$target,
                        OptionalAttr<StrArrayAttr>:$ops,
                        OptionalAttr<MatchInterfaceEnum>:$interface,
-                       OptionalAttr<DictionaryAttr>:$op_attrs);
+                       OptionalAttr<DictionaryAttr>:$op_attrs,
+                       OptionalAttr<TypeAttr>:$filter_result_type);
   // TODO: variadic results when needed.
   let results = (outs PDL_Operation:$results);
 
@@ -227,6 +229,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
     (`ops` `{` $ops^ `}`)?
     (`interface` `{` $interface^ `}`)?
     (`attributes` $op_attrs^)?
+    (`filter_result_type` `=` $filter_result_type^)?
     `in` $target attr-dict
   }];
 }

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f4241f44ea0f6..29b13e27de7ed 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -458,7 +458,7 @@ transform::MatchOp::apply(transform::TransformResults &results,
   SmallVector<Operation *> res;
   auto matchFun = [&](Operation *op) {
     if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
-      return WalkResult::advance();
+      return;
 
     // Interfaces cannot be matched by name, just by ID.
     // So we specifically encode the interfaces we care about for this op.
@@ -466,10 +466,10 @@ transform::MatchOp::apply(transform::TransformResults &results,
       auto iface = getInterface().value();
       if (iface == transform::MatchInterfaceEnum::LinalgOp &&
           !isa<linalg::LinalgOp>(op))
-        return WalkResult::advance();
+        return;
       if (iface == transform::MatchInterfaceEnum::TilingInterface &&
           isa<TilingInterface>(op))
-        return WalkResult::advance();
+        return;
     }
 
     // Check if all specified attributes match.
@@ -480,15 +480,21 @@ transform::MatchOp::apply(transform::TransformResults &results,
             attr.getName() == getOpsAttrName())
           continue;
         if (!op->hasAttr(attr.getName()))
-          return WalkResult::advance();
+          return;
         if (op->getAttr(attr.getName()) != attr.getValue())
-          return WalkResult::advance();
+          return;
       }
     }
 
+    if (getFilterResultType().has_value()) {
+      Type t = getFilterResultType().value();
+      if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
+        return;
+    }
+
     // All constraints are satisfied.
     res.push_back(op);
-    return WalkResult::advance();
+    return;
   };
 
   payloadOps.front()->walk(matchFun);

diff  --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 2696cf5b16ef9..7d31fc4b42844 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -25,6 +25,26 @@ transform.with_pdl_patterns {
 
 // -----
 
+func.func @by_type() {
+  %0 = arith.constant 0: i32
+  // expected-remark @below {{matched op name}}
+  %1 = arith.constant 1.0 : f32
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %match_name = transform.structured.match
+      ops{["arith.constant"]} filter_result_type = f32 in %arg1
+    transform.test_print_remark_at_operand %match_name, "matched op name"
+    transform.test_consume_operand %match_name
+  }
+}
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 func.func @match_complex_attribute(%arg0: tensor<12x128x32xf32>)


        


More information about the Mlir-commits mailing list