[Mlir-commits] [mlir] 32c6e08 - [mlir][linalg] Add attribute matcher to structured.match transform op

Matthias Springer llvmlistbot at llvm.org
Fri Jul 22 04:58:48 PDT 2022


Author: Matthias Springer
Date: 2022-07-22T13:55:12+02:00
New Revision: 32c6e0815aa0691cb2044702b61caa895f4b316c

URL: https://github.com/llvm/llvm-project/commit/32c6e0815aa0691cb2044702b61caa895f4b316c
DIFF: https://github.com/llvm/llvm-project/commit/32c6e0815aa0691cb2044702b61caa895f4b316c.diff

LOG: [mlir][linalg] Add attribute matcher to structured.match transform op

This is useful for building small test cases and will be utilized in a subsequent commit that adds a fusion example.

Differential Revision: https://reviews.llvm.org/D130344

Added: 
    mlir/test/Dialect/Linalg/transform-op-match.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f97061d516d5..0ad37ada457d 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -193,10 +193,14 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
 
     The following constraints are supported:
       - interface: an optional MatchInterfaceEnum specifying an enum
-      representation for an interface to target.
-      - ops: an optional StrArrayAttr specifying the concrete name of an op. 
+        representation for an interface to target.
+      - ops: an optional StrArrayAttr specifying the concrete name of an op.
+        Multiple names can be specified. Matched ops must have one of the
+        specified names.
+      - attribute: an optional StrAttr specifying the name of an attribute that
+        matched ops must have.
       
-    Note: either `ops` or `interface` must be specified.
+    Note: Only ops that satisfy all specified constraints are matched.
 
     TODO: Extend with regions to allow a limited form of constraints.
 
@@ -214,12 +218,17 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
 
   let arguments = (ins PDL_Operation:$target,
                        OptionalAttr<StrArrayAttr>:$ops,
-                       OptionalAttr<MatchInterfaceEnum>:$interface);
+                       OptionalAttr<MatchInterfaceEnum>:$interface,
+                       OptionalAttr<StrAttr>:$attribute);
   // TODO: variadic results when needed.
   let results = (outs PDL_Operation:$results);
 
-  let hasCustomAssemblyFormat = 1;
-  let hasVerifier = 1;
+  let assemblyFormat = [{
+    (`ops` `{` $ops^ `}`)? 
+    (`interface` `{` $interface^ `}`)? 
+    (`attribute` `{` $attribute^ `}`)? 
+    `in` $target attr-dict
+  }];
 }
 
 def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a74f3d4e3d3c..7590633eedce 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -430,15 +430,6 @@ LogicalResult transform::InterchangeOp::verify() {
 // MatchOp
 //===---------------------------------------------------------------------===//
 
-LogicalResult transform::MatchOp::verify() {
-  bool opXorIface = getOps().hasValue() ^ getInterface().hasValue();
-  if (!opXorIface)
-    return this->emitOpError(
-        "requires a either a match_op or a match_interface attribute (but not "
-        "both)");
-  return success();
-}
-
 DiagnosedSilenceableFailure
 transform::MatchOp::apply(transform::TransformResults &results,
                           transform::TransformState &state) {
@@ -453,21 +444,28 @@ transform::MatchOp::apply(transform::TransformResults &results,
         this->emitOpError("requires exactly one target handle"));
 
   SmallVector<Operation *> res;
-
   auto matchFun = [&](Operation *op) {
-    if (strs.contains(op->getName().getStringRef()))
-      res.push_back(op);
+    if (getOps().hasValue() && !strs.contains(op->getName().getStringRef()))
+      return WalkResult::advance();
+
     // Interfaces cannot be matched by name, just by ID.
     // So we specifically encode the interfaces we care about for this op.
     if (getInterface().hasValue()) {
       auto iface = getInterface().getValue();
       if (iface == transform::MatchInterfaceEnum::LinalgOp &&
-          isa<linalg::LinalgOp>(op))
-        res.push_back(op);
+          !isa<linalg::LinalgOp>(op))
+        return WalkResult::advance();
       if (iface == transform::MatchInterfaceEnum::TilingInterface &&
           isa<TilingInterface>(op))
-        res.push_back(op);
+        return WalkResult::advance();
     }
+
+    if (getAttribute().hasValue() && !op->hasAttr(getAttribute().getValue()))
+      return WalkResult::advance();
+
+    // All constraints are satisfied.
+    res.push_back(op);
+    return WalkResult::advance();
   };
 
   payloadOps.front()->walk(matchFun);
@@ -475,65 +473,6 @@ transform::MatchOp::apply(transform::TransformResults &results,
   return DiagnosedSilenceableFailure(success());
 }
 
-ParseResult transform::MatchOp::parse(OpAsmParser &parser,
-                                      OperationState &result) {
-  // Parse 'match_op' or 'interface' clause.
-  if (succeeded(parser.parseOptionalKeyword("ops"))) {
-    ArrayAttr opsAttr;
-    if (parser.parseLBrace() ||
-        parser.parseCustomAttributeWithFallback(
-            opsAttr, parser.getBuilder().getType<NoneType>(), "ops",
-            result.attributes) ||
-        parser.parseRBrace())
-      return failure();
-  } else if (succeeded(parser.parseOptionalKeyword("interface"))) {
-    if (parser.parseLBrace())
-      return failure();
-    StringRef attrStr;
-    auto loc = parser.getCurrentLocation();
-    if (parser.parseKeyword(&attrStr))
-      return failure();
-    auto interfaceEnum = transform::symbolizeMatchInterfaceEnum(attrStr);
-    if (!interfaceEnum)
-      return parser.emitError(loc, "invalid ")
-             << "match_interface attribute specification: \"" << attrStr << '"';
-    transform::MatchInterfaceEnumAttr match_interfaceAttr =
-        transform::MatchInterfaceEnumAttr::get(parser.getBuilder().getContext(),
-                                               interfaceEnum.value());
-    result.addAttribute("interface", match_interfaceAttr);
-    if (parser.parseRBrace())
-      return failure();
-  } else {
-    auto loc = parser.getCurrentLocation();
-    return parser.emitError(loc, "expected ops or interface");
-  }
-
-  OpAsmParser::UnresolvedOperand targetRawOperands[1];
-  ArrayRef<OpAsmParser::UnresolvedOperand> targetOperands(targetRawOperands);
-  if (parser.parseKeyword("in") || parser.parseOperand(targetRawOperands[0]) ||
-      parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-  Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
-  result.addTypes(pdlOpType);
-  if (parser.resolveOperands(targetOperands, pdlOpType, result.operands))
-    return failure();
-  return success();
-}
-
-void transform::MatchOp::print(OpAsmPrinter &p) {
-  if ((*this)->getAttr("ops")) {
-    p << " ops{";
-    p.printAttributeWithoutType(getOpsAttr());
-    p << "}";
-  }
-  if ((*this)->getAttr("interface")) {
-    p << " interface{" << stringifyMatchInterfaceEnum(*getInterface()) << "}";
-  }
-  p << " in " << getTarget();
-  p.printOptionalAttrDict((*this)->getAttrs(),
-                          /*elidedAttrs=*/{"ops", "interface"});
-}
-
 //===---------------------------------------------------------------------===//
 // MultiTileSizesOp
 //===---------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
new file mode 100644
index 000000000000..06af6c8dad96
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
+
+func.func @bar() {
+  // expected-remark @below {{matched op name}}
+  // expected-remark @below {{matched attr name}}
+  %0 = arith.constant {my_attr} 0: i32
+  // expected-remark @below {{matched op name}}
+  %1 = arith.constant 1 : i32
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %match_name = transform.structured.match ops{["arith.constant"]} in %arg1
+    transform.test_print_remark_at_operand %match_name, "matched op name"
+    transform.test_consume_operand %match_name
+
+    %match_attr = transform.structured.match ops{["arith.constant"]} attribute{"my_attr"} in %arg1
+    transform.test_print_remark_at_operand %match_attr, "matched attr name"
+    transform.test_consume_operand %match_attr
+  }
+}


        


More information about the Mlir-commits mailing list