[Mlir-commits] [mlir] 593c14d - [mlir][Linalg] Add return type filter to the transform dialect
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Sep 14 08:51:04 PDT 2022
Author: Nicolas Vasilache
Date: 2022-09-14T08:50:31-07:00
New Revision: 593c14d422e01cd7d6698321c62e3ac266e8cfb3
URL: https://github.com/llvm/llvm-project/commit/593c14d422e01cd7d6698321c62e3ac266e8cfb3
DIFF: https://github.com/llvm/llvm-project/commit/593c14d422e01cd7d6698321c62e3ac266e8cfb3.diff
LOG: [mlir][Linalg] Add return type filter to the transform dialect
This allows matching ops by additionally specifying the single result type that a matched op must return.
Differential Revision: https://reviews.llvm.org/D133862
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-match.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 571d01b83d15f..46d70a6561b0a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -199,6 +199,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
names.
- attribute: the matched op must have all specified attributes (with their
specified values).
+ - filter_result_type: the matched op must return exactly this one type.
Note: Only ops that satisfy all specified constraints are matched.
@@ -219,7 +220,8 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
let arguments = (ins PDL_Operation:$target,
OptionalAttr<StrArrayAttr>:$ops,
OptionalAttr<MatchInterfaceEnum>:$interface,
- OptionalAttr<DictionaryAttr>:$op_attrs);
+ OptionalAttr<DictionaryAttr>:$op_attrs,
+ OptionalAttr<TypeAttr>:$filter_result_type);
// TODO: variadic results when needed.
let results = (outs PDL_Operation:$results);
@@ -227,6 +229,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
(`ops` `{` $ops^ `}`)?
(`interface` `{` $interface^ `}`)?
(`attributes` $op_attrs^)?
+ (`filter_result_type` `=` $filter_result_type^)?
`in` $target attr-dict
}];
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f4241f44ea0f6..29b13e27de7ed 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -458,7 +458,7 @@ transform::MatchOp::apply(transform::TransformResults &results,
SmallVector<Operation *> res;
auto matchFun = [&](Operation *op) {
if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
- return WalkResult::advance();
+ return;
// Interfaces cannot be matched by name, just by ID.
// So we specifically encode the interfaces we care about for this op.
@@ -466,10 +466,10 @@ transform::MatchOp::apply(transform::TransformResults &results,
auto iface = getInterface().value();
if (iface == transform::MatchInterfaceEnum::LinalgOp &&
!isa<linalg::LinalgOp>(op))
- return WalkResult::advance();
+ return;
if (iface == transform::MatchInterfaceEnum::TilingInterface &&
isa<TilingInterface>(op))
- return WalkResult::advance();
+ return;
}
// Check if all specified attributes match.
@@ -480,15 +480,21 @@ transform::MatchOp::apply(transform::TransformResults &results,
attr.getName() == getOpsAttrName())
continue;
if (!op->hasAttr(attr.getName()))
- return WalkResult::advance();
+ return;
if (op->getAttr(attr.getName()) != attr.getValue())
- return WalkResult::advance();
+ return;
}
}
+ if (getFilterResultType().has_value()) {
+ Type t = getFilterResultType().value();
+ if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
+ return;
+ }
+
// All constraints are satisfied.
res.push_back(op);
- return WalkResult::advance();
+ return;
};
payloadOps.front()->walk(matchFun);
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 2696cf5b16ef9..7d31fc4b42844 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -25,6 +25,26 @@ transform.with_pdl_patterns {
// -----
+func.func @by_type() {
+ %0 = arith.constant 0: i32
+ // expected-remark @below {{matched op name}}
+ %1 = arith.constant 1.0 : f32
+ return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %match_name = transform.structured.match
+ ops{["arith.constant"]} filter_result_type = f32 in %arg1
+ transform.test_print_remark_at_operand %match_name, "matched op name"
+ transform.test_consume_operand %match_name
+ }
+}
+
+// -----
+
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
func.func @match_complex_attribute(%arg0: tensor<12x128x32xf32>)
More information about the Mlir-commits mailing list