[Mlir-commits] [mlir] ecd9dc0 - [mlir][Transform] Add a new navigation op to retrieve the producer of an operand

Nicolas Vasilache llvmlistbot at llvm.org
Mon Sep 19 04:21:16 PDT 2022


Author: Nicolas Vasilache
Date: 2022-09-19T04:16:15-07:00
New Revision: ecd9dc0499880f2a89e4e03e9ffd3b368fe7e7ff

URL: https://github.com/llvm/llvm-project/commit/ecd9dc0499880f2a89e4e03e9ffd3b368fe7e7ff
DIFF: https://github.com/llvm/llvm-project/commit/ecd9dc0499880f2a89e4e03e9ffd3b368fe7e7ff.diff

LOG: [mlir][Transform] Add a new navigation op to retrieve the producer of an operand

Given an OpOperand uniquely determined by the operation `%op` and the operand number `num`,
`transform.get_producer_of_operand %op[num]` returns a handle to the unique operation
that produced the SSA value used as that OpOperand.

The transform emits a silenceable failure if the operand is a block argument.

Differential Revision: https://reviews.llvm.org/D134171

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 361973da1e44b..99408eb3a60d7 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -169,6 +169,25 @@ def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent
   let assemblyFormat = "$target attr-dict";
 }
 
+def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+  let summary = "Get handle to the producer of this operation's operand number";
+  let description = [{
+    The handle defined by this Transform op corresponds to the operations that
+    produce the SSA value used as the `operand_number`-th operand of each
+    payload op associated with `target`, in the same order as the targets.
+    If, for any target, `operand_number` is out of range or the operand is not
+    produced by an operation (i.e. it is a block argument), the transform emits
+    a silenceable failure and the returned handle is empty.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                       I64Attr:$operand_number);
+  // NOTE(review): the result is named `parent`, which does not describe a
+  // producer; renaming it (e.g. to `producer`) would change the generated
+  // C++ accessor, so it is left as-is here.
+  let results = (outs PDL_Operation:$parent);
+  let assemblyFormat = "$target `[` $operand_number `]` attr-dict";
+}
+
 def MergeHandlesOp : TransformDialectOp<"merge_handles",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index ecab661f92d21..b9f9d44f9b275 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -386,6 +386,36 @@ DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// GetProducerOfOperand
+//===----------------------------------------------------------------------===//
+
+/// Maps every payload op associated with `target` to the op that defines its
+/// `operand_number`-th operand, preserving target order. As soon as one target
+/// has no such producer, the result is set to an empty list and a silenceable
+/// failure (with a note pointing at the target) is returned.
+DiagnosedSilenceableFailure
+transform::GetProducerOfOperand::apply(transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  int64_t operandNumber = getOperandNumber();
+  SmallVector<Operation *> producers;
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    // The attribute is a signed I64Attr: reject negative values explicitly,
+    // otherwise getOperand() would be called with an out-of-range index.
+    // A null result also covers block-argument operands, for which
+    // getDefiningOp() yields null.
+    Operation *producer =
+        operandNumber < 0 || target->getNumOperands() <= operandNumber
+            ? nullptr
+            : target->getOperand(operandNumber).getDefiningOp();
+    if (!producer) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError()
+          << "could not find a producer for operand number: " << operandNumber
+          << " of " << *target;
+      diag.attachNote(target->getLoc()) << "target op";
+      results.set(getResult().cast<OpResult>(),
+                  SmallVector<mlir::Operation *>{});
+      return diag;
+    }
+    producers.push_back(producer);
+  }
+  results.set(getResult().cast<OpResult>(), producers);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // MergeHandlesOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 494c6160357f9..0a5688f2239ce 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -727,3 +727,36 @@ transform.with_pdl_patterns {
     transform.test_print_remark_at_operand %results, "transform applied"
   }
 }
+
+// -----
+
+// get_producer_of_operand returns the op defining the requested operand.
+func.func @get_producer_of_operand(%arg0: index, %arg1: index) {
+  // expected-remark @below {{found muli}}
+  %0 = arith.muli %arg0, %arg1 : index
+  arith.addi %0, %arg1 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %addi = transform.structured.match ops{["arith.addi"]} in %arg1
+  %muli = get_producer_of_operand %addi[0]
+  transform.test_print_remark_at_operand %muli, "found muli"
+}
+
+// -----
+
+// get_producer_of_operand fails when the operand is a block argument.
+func.func @get_producer_of_operand_block_arg(%arg0: index, %arg1: index) {
+  // expected-note @below {{target op}}
+  %0 = arith.muli %arg0, %arg1 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %muli = transform.structured.match ops{["arith.muli"]} in %arg1
+  // expected-error @below {{could not find a producer for operand number: 0 of}}
+  %bbarg = get_producer_of_operand %muli[0]
+}
+


        


More information about the Mlir-commits mailing list