[Mlir-commits] [mlir] 32c6e08 - [mlir][linalg] Add attribute matcher to structured.match transform op
Matthias Springer
llvmlistbot at llvm.org
Fri Jul 22 04:58:48 PDT 2022
Author: Matthias Springer
Date: 2022-07-22T13:55:12+02:00
New Revision: 32c6e0815aa0691cb2044702b61caa895f4b316c
URL: https://github.com/llvm/llvm-project/commit/32c6e0815aa0691cb2044702b61caa895f4b316c
DIFF: https://github.com/llvm/llvm-project/commit/32c6e0815aa0691cb2044702b61caa895f4b316c.diff
LOG: [mlir][linalg] Add attribute matcher to structured.match transform op
This is useful for building small test cases and will be utilized in a subsequent commit that adds a fusion example.
Differential Revision: https://reviews.llvm.org/D130344
Added:
mlir/test/Dialect/Linalg/transform-op-match.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f97061d516d5..0ad37ada457d 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -193,10 +193,14 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
The following constraints are supported:
- interface: an optional MatchInterfaceEnum specifying an enum
- representation for an interface to target.
- - ops: an optional StrArrayAttr specifying the concrete name of an op.
+ representation for an interface to target.
+ - ops: an optional StrArrayAttr specifying the concrete name of an op.
+ Multiple names can be specified. Matched ops must have one of the specified
+ names.
+ - attribute: an optional StrAttr specifying the name of an attribute that
+ matched ops must have.
- Note: either `ops` or `interface` must be specified.
+ Note: Only ops that satisfy all specified constraints are matched.
TODO: Extend with regions to allow a limited form of constraints.
@@ -214,12 +218,17 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
let arguments = (ins PDL_Operation:$target,
OptionalAttr<StrArrayAttr>:$ops,
- OptionalAttr<MatchInterfaceEnum>:$interface);
+ OptionalAttr<MatchInterfaceEnum>:$interface,
+ OptionalAttr<StrAttr>:$attribute);
// TODO: variadic results when needed.
let results = (outs PDL_Operation:$results);
- let hasCustomAssemblyFormat = 1;
- let hasVerifier = 1;
+ let assemblyFormat = [{
+ (`ops` `{` $ops^ `}`)?
+ (`interface` `{` $interface^ `}`)?
+ (`attribute` `{` $attribute^ `}`)?
+ `in` $target attr-dict
+ }];
}
def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a74f3d4e3d3c..7590633eedce 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -430,15 +430,6 @@ LogicalResult transform::InterchangeOp::verify() {
// MatchOp
//===---------------------------------------------------------------------===//
-LogicalResult transform::MatchOp::verify() {
- bool opXorIface = getOps().hasValue() ^ getInterface().hasValue();
- if (!opXorIface)
- return this->emitOpError(
- "requires a either a match_op or a match_interface attribute (but not "
- "both)");
- return success();
-}
-
DiagnosedSilenceableFailure
transform::MatchOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
@@ -453,21 +444,29 @@ transform::MatchOp::apply(transform::TransformResults &results,
this->emitOpError("requires exactly one target handle"));
SmallVector<Operation *> res;
-
auto matchFun = [&](Operation *op) {
- if (strs.contains(op->getName().getStringRef()))
- res.push_back(op);
+ // An op is matched only if it satisfies every constraint that is set.
+ if (getOps().hasValue() && !strs.contains(op->getName().getStringRef()))
+ return WalkResult::advance();
+
// Interfaces cannot be matched by name, just by ID.
// So we specifically encode the interfaces we care about for this op.
if (getInterface().hasValue()) {
auto iface = getInterface().getValue();
if (iface == transform::MatchInterfaceEnum::LinalgOp &&
- isa<linalg::LinalgOp>(op))
- res.push_back(op);
+ !isa<linalg::LinalgOp>(op))
+ return WalkResult::advance();
if (iface == transform::MatchInterfaceEnum::TilingInterface &&
- isa<TilingInterface>(op))
- res.push_back(op);
+ !isa<TilingInterface>(op))
+ return WalkResult::advance();
}
+
+ if (getAttribute().hasValue() && !op->hasAttr(getAttribute().getValue()))
+ return WalkResult::advance();
+
+ // All constraints are satisfied.
+ res.push_back(op);
+ return WalkResult::advance();
};
payloadOps.front()->walk(matchFun);
@@ -475,65 +474,6 @@ transform::MatchOp::apply(transform::TransformResults &results,
return DiagnosedSilenceableFailure(success());
}
-ParseResult transform::MatchOp::parse(OpAsmParser &parser,
- OperationState &result) {
- // Parse 'match_op' or 'interface' clause.
- if (succeeded(parser.parseOptionalKeyword("ops"))) {
- ArrayAttr opsAttr;
- if (parser.parseLBrace() ||
- parser.parseCustomAttributeWithFallback(
- opsAttr, parser.getBuilder().getType<NoneType>(), "ops",
- result.attributes) ||
- parser.parseRBrace())
- return failure();
- } else if (succeeded(parser.parseOptionalKeyword("interface"))) {
- if (parser.parseLBrace())
- return failure();
- StringRef attrStr;
- auto loc = parser.getCurrentLocation();
- if (parser.parseKeyword(&attrStr))
- return failure();
- auto interfaceEnum = transform::symbolizeMatchInterfaceEnum(attrStr);
- if (!interfaceEnum)
- return parser.emitError(loc, "invalid ")
- << "match_interface attribute specification: \"" << attrStr << '"';
- transform::MatchInterfaceEnumAttr match_interfaceAttr =
- transform::MatchInterfaceEnumAttr::get(parser.getBuilder().getContext(),
- interfaceEnum.value());
- result.addAttribute("interface", match_interfaceAttr);
- if (parser.parseRBrace())
- return failure();
- } else {
- auto loc = parser.getCurrentLocation();
- return parser.emitError(loc, "expected ops or interface");
- }
-
- OpAsmParser::UnresolvedOperand targetRawOperands[1];
- ArrayRef<OpAsmParser::UnresolvedOperand> targetOperands(targetRawOperands);
- if (parser.parseKeyword("in") || parser.parseOperand(targetRawOperands[0]) ||
- parser.parseOptionalAttrDict(result.attributes))
- return failure();
- Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
- result.addTypes(pdlOpType);
- if (parser.resolveOperands(targetOperands, pdlOpType, result.operands))
- return failure();
- return success();
-}
-
-void transform::MatchOp::print(OpAsmPrinter &p) {
- if ((*this)->getAttr("ops")) {
- p << " ops{";
- p.printAttributeWithoutType(getOpsAttr());
- p << "}";
- }
- if ((*this)->getAttr("interface")) {
- p << " interface{" << stringifyMatchInterfaceEnum(*getInterface()) << "}";
- }
- p << " in " << getTarget();
- p.printOptionalAttrDict((*this)->getAttrs(),
- /*elidedAttrs=*/{"ops", "interface"});
-}
-
//===---------------------------------------------------------------------===//
// MultiTileSizesOp
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
new file mode 100644
index 000000000000..06af6c8dad96
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
+
+func.func @bar() {
+ // expected-remark @below {{matched op name}}
+ // expected-remark @below {{matched attr name}}
+ %0 = arith.constant {my_attr} 0 : i32
+ // expected-remark @below {{matched op name}}
+ %1 = arith.constant 1 : i32
+ return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %match_name = transform.structured.match ops{["arith.constant"]} in %arg1
+ transform.test_print_remark_at_operand %match_name, "matched op name"
+ transform.test_consume_operand %match_name
+
+ %match_attr = transform.structured.match ops{["arith.constant"]} attribute{"my_attr"} in %arg1
+ transform.test_print_remark_at_operand %match_attr, "matched attr name"
+ transform.test_consume_operand %match_attr
+ }
+}
More information about the Mlir-commits
mailing list