[Mlir-commits] [mlir] 3ef062a - [mlir][transform] Add transform.get_result op

Matthias Springer llvmlistbot at llvm.org
Wed Feb 15 05:21:20 PST 2023


Author: Matthias Springer
Date: 2023-02-15T14:16:50+01:00
New Revision: 3ef062a4bd86bd81110fd1822d1ead24afe53c42

URL: https://github.com/llvm/llvm-project/commit/3ef062a4bd86bd81110fd1822d1ead24afe53c42
DIFF: https://github.com/llvm/llvm-project/commit/3ef062a4bd86bd81110fd1822d1ead24afe53c42.diff

LOG: [mlir][transform] Add transform.get_result op

This transform op returns a value handle pointing to the specified OpResult of the targeted op.

Differential Revision: https://reviews.llvm.org/D144087

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index dd66e61880416..7a746174da2e8 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -230,6 +230,25 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
                        "functional-type(operands, results)";
 }
 
+def GetResultOp : TransformDialectOp<"get_result",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+  let summary = "Get a handle to a result of the targeted op";
+  let description = [{
+    The handle defined by this Transform op corresponds to the OpResult at
+    position `result_number` (non-negative) of the given `target` operation.
+
+    This transform fails silently if the targeted operation does not have enough
+    results. It reads the target handle and produces the result handle.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       ConfinedAttr<I64Attr, [IntNonNegative]>:$result_number);
+  let results = (outs TransformValueHandleTypeInterface:$result);
+  let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
+                       "functional-type(operands, results)";
+}
+
 def MergeHandlesOp : TransformDialectOp<"merge_handles",
     [DeclareOpInterfaceMethods<TransformOpInterface, ["allowsRepeatedHandleOperands"]>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 5995485f5a79a..ef5bfd2a85aeb 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -722,6 +722,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
 
       if (opResult.getType().isa<TransformParamTypeInterface>())
         results.setParams(opResult, {});
+      else if (opResult.getType().isa<TransformValueHandleTypeInterface>())
+        results.setValues(opResult, {});
       else
         results.set(opResult, {});
     }
@@ -831,7 +833,7 @@ void transform::TransformResults::setParams(
 void transform::TransformResults::setValues(OpResult handle,
                                             ValueRange values) {
   int64_t position = handle.getResultNumber();
-  assert(position < static_cast<int64_t>(values.size()) &&
+  assert(position < static_cast<int64_t>(this->values.size()) &&
          "setting values for a non-existent handle");
   assert(this->values[position].data() == nullptr && "values already set");
   assert(operations[position].data() == nullptr &&
@@ -861,8 +863,8 @@ transform::TransformResults::getParams(unsigned resultNumber) const {
 
 ArrayRef<Value>
 transform::TransformResults::getValues(unsigned resultNumber) const {
-  assert(resultNumber < params.size() &&
-         "querying params for a non-existent handle");
+  assert(resultNumber < values.size() &&
+         "querying values for a non-existent handle");
   assert(values[resultNumber].data() != nullptr &&
          "querying unset values (ops or params expected?)");
   return values[resultNumber];

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 134911322a4b3..a7b370cdd0fda 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -475,6 +475,34 @@ transform::GetProducerOfOperand::apply(transform::TransformResults &results,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// GetResultOp
+//===----------------------------------------------------------------------===//
+
+/// Associates the result handle with OpResult #`result_number` of each payload
+/// op associated with `target`. Fails silenceably, with a note on the
+/// offending op, if any payload op has no result at that position.
+DiagnosedSilenceableFailure
+transform::GetResultOp::apply(transform::TransformResults &results,
+                              transform::TransformState &state) {
+  int64_t resultNumber = getResultNumber();
+  SmallVector<Value> opResults;
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    // `resultNumber` is signed: a negative value would pass the upper-bound
+    // check and then be converted to an out-of-range unsigned index by
+    // getOpResult, so it is rejected explicitly.
+    if (resultNumber < 0 || resultNumber >= target->getNumResults()) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "targeted op does not have enough results";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+    opResults.push_back(target->getOpResult(resultNumber));
+  }
+  results.setValues(getResult().cast<OpResult>(), opResults);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // MergeHandlesOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 7e2804d1621b2..f6056864e0bac 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1190,3 +1190,35 @@ transform.sequence failures(propagate) {
   // expected-error @below {{unexpectedly consumed a value that is not a handle as operand #0}}
   test_consume_operand %0 : !transform.test_dialect_param
 }
+
+// -----
+// get_result must yield a value handle to result #0 of the matched arith.addi.
+func.func @get_result_of_op(%arg0: index, %arg1: index) -> index {
+  // expected-remark @below {{addi result}}
+  // expected-note @below {{value handle points to an op result #0}}
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %result = transform.get_result %addi[0] : (!transform.any_op) -> !transform.any_value
+  transform.test_print_remark_at_operand_value %result, "addi result" : !transform.any_value
+}
+
+// -----
+// Asking for result #1 of the single-result arith.addi must report an error.
+func.func @get_out_of_bounds_result_of_op(%arg0: index, %arg1: index) -> index {
+  // expected-note @below {{target op}}
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  // expected-error @below {{targeted op does not have enough results}}
+  %result = transform.get_result %addi[1] : (!transform.any_op) -> !transform.any_value
+  transform.test_print_remark_at_operand_value %result, "addi result" : !transform.any_value
+}


        


More information about the Mlir-commits mailing list