[Mlir-commits] [mlir] 4cf936d - [mlir][transform] Add transform.get_defining_op op

Matthias Springer llvmlistbot at llvm.org
Tue Feb 21 00:03:33 PST 2023


Author: Matthias Springer
Date: 2023-02-21T09:03:25+01:00
New Revision: 4cf936d01f36adaa268fdc1411f27be197434cb2

URL: https://github.com/llvm/llvm-project/commit/4cf936d01f36adaa268fdc1411f27be197434cb2
DIFF: https://github.com/llvm/llvm-project/commit/4cf936d01f36adaa268fdc1411f27be197434cb2.diff

LOG: [mlir][transform] Add transform.get_defining_op op

This op is the inverse of `transform.get_result`: it maps a handle to payload values to a handle to the ops that define those values.

Differential Revision: https://reviews.llvm.org/D144409

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 7a746174da2e8..886586513dc85 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -210,6 +210,23 @@ def GetConsumersOfResult : TransformDialectOp<"get_consumers_of_result",
                        "functional-type(operands, results)";
 }
 
+def GetDefiningOp : TransformDialectOp<"get_defining_op",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+  let summary = "Get handle to the defining op of a value";
+  let description = [{
+    The handle defined by this Transform op corresponds to the defining ops of
+    the targeted values (one op per value, in the same order).
+
+    Produces a silenceable failure if any targeted value is a block argument.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$result);
+  let assemblyFormat = "$target attr-dict `:` "
+                       "functional-type(operands, results)";
+}
+
 def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index a7b370cdd0fda..663220695a7ba 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -447,6 +447,32 @@ transform::GetConsumersOfResult::apply(transform::TransformResults &results,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// GetDefiningOp
+//===----------------------------------------------------------------------===//
+
+/// Sets the result handle to the defining op of every payload value associated
+/// with the target handle, in order. Produces a silenceable failure (with a
+/// note pointing at the offending value) if any value is a block argument.
+DiagnosedSilenceableFailure
+transform::GetDefiningOp::apply(transform::TransformResults &results,
+                                transform::TransformState &state) {
+  SmallVector<Operation *> definingOps;
+  for (Value v : state.getPayloadValues(getTarget())) {
+    // Block arguments are the only values without a defining op.
+    Operation *definingOp = v.getDefiningOp();
+    if (!definingOp) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "cannot get defining op of block argument";
+      diag.attachNote(v.getLoc()) << "target value";
+      return diag;
+    }
+    definingOps.push_back(definingOp);
+  }
+  results.set(getResult().cast<OpResult>(), definingOps);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // GetProducerOfOperand
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index f6056864e0bac..3a7f42015f38b 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1222,3 +1222,36 @@ transform.sequence failures(propagate) {
   %result = transform.get_result %addi[1] : (!transform.any_op) -> !transform.any_value
   transform.test_print_remark_at_operand_value %result, "addi result" : !transform.any_value
 }
+
+// -----
+
+func.func @get_defining_op_of_result(%arg0: index, %arg1: index) -> index {
+  // expected-remark @below {{matched}}
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %result = transform.get_result %addi[0] : (!transform.any_op) -> !transform.any_value
+  %op = transform.get_defining_op %result : (!transform.any_value) -> !transform.any_op
+  transform.test_print_remark_at_operand %op, "matched" : !transform.any_op
+}
+
+// -----
+
+// expected-note @below {{target value}}
+func.func @get_defining_op_of_bbarg(%arg0: index, %arg1: index) -> index {
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %bbarg = test_produce_value_handle_to_argument_of_parent_block %addi, 0 : (!transform.any_op) -> !transform.any_value
+  // expected-error @below {{cannot get defining op of block argument}}
+  %op = transform.get_defining_op %bbarg : (!transform.any_value) -> !transform.any_op
+  transform.test_print_remark_at_operand %op, "matched" : !transform.any_op
+}


        


More information about the Mlir-commits mailing list