[Mlir-commits] [mlir] 894fdbc - [mlir][transform] Add transform.select op
Matthias Springer
llvmlistbot at llvm.org
Tue Jul 11 07:25:19 PDT 2023
Author: Matthias Springer
Date: 2023-07-11T16:16:56+02:00
New Revision: 894fdbc7194bea573cb0c4beb64b056e2b1b74d2
URL: https://github.com/llvm/llvm-project/commit/894fdbc7194bea573cb0c4beb64b056e2b1b74d2
DIFF: https://github.com/llvm/llvm-project/commit/894fdbc7194bea573cb0c4beb64b056e2b1b74d2.diff
LOG: [mlir][transform] Add transform.select op
This transform op can be used to select all payload ops with a given name from a handle.
Differential Revision: https://reviews.llvm.org/D154956
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 138469aec5eadf..6820f3c7412a27 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -854,6 +854,30 @@ def ReplicateOp : TransformDialectOp<"replicate",
"type($pattern) `,` type($handles)";
}
+def SelectOp : TransformDialectOp<"select",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+ let summary = "Select payload ops by name";
+ let description = [{
+ The handle defined by this Transform op corresponds to all operations among
+ `target` that have the specified properties. Currently the following
+ properties are supported:
+
+ - `op_name`: The op must have the specified name. The name is compared
+ exactly against the full operation name, including the dialect prefix
+ (e.g., `"test.foo"`).
+
+ The result payload ops are in the same relative order as the targeted ops.
+ If no targeted op matches, the `result` handle is empty; this is not
+ considered a failure.
+
+ This transform op reads the `target` handle and produces the `result`
+ handle. It reads the payload, but does not modify it.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ StrAttr:$op_name);
+ let results = (outs TransformHandleTypeInterface:$result);
+ let assemblyFormat = [{
+ $op_name `in` $target attr-dict `:` functional-type(operands, results)
+ }];
+}
+
def SequenceOp : TransformDialectOp<"sequence",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getSuccessorEntryOperands", "getSuccessorRegions",
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 546bc1220c9fe5..20d572d994c7bc 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1586,6 +1586,24 @@ LogicalResult transform::NamedSequenceOp::verify() {
return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
}
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+/// Filters the payload ops associated with `target`, keeping only those whose
+/// operation name equals the `op_name` attribute. Relative order is preserved,
+/// the result may be empty, and the op always reports success.
+DiagnosedSilenceableFailure
+transform::SelectOp::apply(transform::TransformRewriter &rewriter,
+                           transform::TransformResults &results,
+                           transform::TransformState &state) {
+  StringRef wantedName = getOpName();
+  SmallVector<Operation *> selected;
+  for (Operation *candidate : state.getPayloadOps(getTarget())) {
+    // Skip any payload op whose full name (dialect prefix included) differs.
+    if (candidate->getName().getStringRef() != wantedName)
+      continue;
+    selected.push_back(candidate);
+  }
+  results.set(cast<OpResult>(getResult()), selected);
+  return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// SplitHandleOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 5d0921ce21c7d3..442ef625ffde9c 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1945,3 +1945,32 @@ transform.sequence failures(propagate) {
// expected-error @below{{failed to verify payload op}}
transform.verify %0 : !transform.any_op
}
+
+// -----
+
+// Payload for the transform.select test: two "test.foo" ops and one "test.bar"
+// op, each annotated with the remark it is expected to receive from the
+// transform script that follows.
+func.func @select() {
+ // expected-remark @below{{found foo}}
+ "test.foo"() : () -> ()
+ // expected-remark @below{{found bar}}
+ "test.bar"() : () -> ()
+ // expected-remark @below{{found foo}}
+ "test.foo"() : () -> ()
+ func.return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // Match all ops inside the function (including the function itself).
+ %func_op = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match in %func_op : (!transform.any_op) -> !transform.any_op
+ // expected-remark @below{{5}}
+ test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+
+ // Select "test.foo": both occurrences must be kept.
+ %foo = transform.select "test.foo" in %0 : (!transform.any_op) -> !transform.any_op
+ test_print_remark_at_operand %foo, "found foo" : !transform.any_op
+ // expected-remark @below{{2}}
+ test_print_number_of_associated_payload_ir_ops %foo : !transform.any_op
+
+ // Select "test.bar".
+ %bar = transform.select "test.bar" in %0 : (!transform.any_op) -> !transform.any_op
+ test_print_remark_at_operand %bar, "found bar" : !transform.any_op
+
+ // Selecting a name that no payload op has yields an empty handle and does
+ // not fail.
+ %none = transform.select "test.baz" in %0 : (!transform.any_op) -> !transform.any_op
+ // expected-remark @below{{0}}
+ test_print_number_of_associated_payload_ir_ops %none : !transform.any_op
+}
More information about the Mlir-commits mailing list