[Mlir-commits] [mlir] 894fdbc - [mlir][transform] Add transform.select op

Matthias Springer llvmlistbot at llvm.org
Tue Jul 11 07:25:19 PDT 2023


Author: Matthias Springer
Date: 2023-07-11T16:16:56+02:00
New Revision: 894fdbc7194bea573cb0c4beb64b056e2b1b74d2

URL: https://github.com/llvm/llvm-project/commit/894fdbc7194bea573cb0c4beb64b056e2b1b74d2
DIFF: https://github.com/llvm/llvm-project/commit/894fdbc7194bea573cb0c4beb64b056e2b1b74d2.diff

LOG: [mlir][transform] Add transform.select op

This transform op can be used to select all payload ops with a given name from a handle.

Differential Revision: https://reviews.llvm.org/D154956

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 138469aec5eadf..6820f3c7412a27 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -854,6 +854,30 @@ def ReplicateOp : TransformDialectOp<"replicate",
                        "type($pattern) `,` type($handles)";
 }
 
+def SelectOp : TransformDialectOp<"select",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+  let summary = "Select payload ops by name";
+  let description = [{
+    The handle defined by this Transform op corresponds to all operations among
+    `target` that have the specified properties. Currently the following
+    properties are supported:
+
+    - `op_name`: The op must have the specified name.
+
+    The result payload ops are in the same relative order as the targeted ops.
+    This transform op reads the `target` handle and produces the `result`
+    handle. It reads the payload, but does not modify it.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       StrAttr:$op_name); // Full name, e.g. "test.foo".
+  let results = (outs TransformHandleTypeInterface:$result); // May be empty.
+  let assemblyFormat = [{
+    $op_name `in` $target attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 def SequenceOp : TransformDialectOp<"sequence",
     [DeclareOpInterfaceMethods<RegionBranchOpInterface,
         ["getSuccessorEntryOperands", "getSuccessorRegions",

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 546bc1220c9fe5..20d572d994c7bc 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1586,6 +1586,24 @@ LogicalResult transform::NamedSequenceOp::verify() {
   return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
 }
 
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::SelectOp::apply(transform::TransformRewriter &rewriter,
+                           transform::TransformResults &results,
+                           transform::TransformState &state) {
+  // Keep only payload ops whose full name matches; target order is preserved.
+  StringRef wantedName = getOpName();
+  SmallVector<Operation *> selected;
+  for (Operation *candidate : state.getPayloadOps(getTarget()))
+    if (candidate->getName().getStringRef() == wantedName)
+      selected.push_back(candidate);
+  results.set(cast<OpResult>(getResult()), selected);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // SplitHandleOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 5d0921ce21c7d3..442ef625ffde9c 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1945,3 +1945,32 @@ transform.sequence failures(propagate) {
   // expected-error @below{{failed to verify payload op}}
   transform.verify %0 : !transform.any_op
 }
+
+// -----
+
+func.func @select() { // Payload ops filtered by name via transform.select.
+  // expected-remark @below{{found foo}}
+  "test.foo"() : () -> ()
+  // expected-remark @below{{found bar}}
+  "test.bar"() : () -> ()
+  // expected-remark @below{{found foo}}
+  "test.foo"() : () -> ()
+  func.return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // Find the function, then match all ops inside it (including the function itself).
+  %func_op = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %0 = transform.structured.match in %func_op : (!transform.any_op) -> !transform.any_op
+  // expected-remark @below{{5}}
+  test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op // func.func, 3 test ops, func.return
+
+  // Select "test.foo".
+  %foo = transform.select "test.foo" in %0 : (!transform.any_op) -> !transform.any_op
+  test_print_remark_at_operand %foo, "found foo" : !transform.any_op
+
+  // Select "test.bar".
+  %bar = transform.select "test.bar" in %0 : (!transform.any_op) -> !transform.any_op
+  test_print_remark_at_operand %bar, "found bar" : !transform.any_op
+}


        


More information about the Mlir-commits mailing list