[Mlir-commits] [mlir] 3ef062a - [mlir][transform] Add transform.get_result op
Matthias Springer
llvmlistbot at llvm.org
Wed Feb 15 05:21:20 PST 2023
Author: Matthias Springer
Date: 2023-02-15T14:16:50+01:00
New Revision: 3ef062a4bd86bd81110fd1822d1ead24afe53c42
URL: https://github.com/llvm/llvm-project/commit/3ef062a4bd86bd81110fd1822d1ead24afe53c42
DIFF: https://github.com/llvm/llvm-project/commit/3ef062a4bd86bd81110fd1822d1ead24afe53c42.diff
LOG: [mlir][transform] Add transform.get_result op
This transform op returns a value handle pointing to the specified OpResult of the targeted op.
Differential Revision: https://reviews.llvm.org/D144087
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index dd66e61880416..7a746174da2e8 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -230,6 +230,27 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
"functional-type(operands, results)";
}
+def GetResultOp : TransformDialectOp<"get_result",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+ let summary = "Get a handle to a result of the targeted op";
+ let description = [{
+ The handle defined by this Transform op corresponds to the OpResult at
+ position `result_number` of each operation associated with the `target`
+ handle (one value per payload operation, in order).
+
+ This transform produces a silenceable failure if any targeted operation
+ does not have enough results. It reads the target handle and produces the
+ result handle.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ I64Attr:$result_number);
+ let results = (outs TransformValueHandleTypeInterface:$result);
+ let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
+ "functional-type(operands, results)";
+}
+
def MergeHandlesOp : TransformDialectOp<"merge_handles",
[DeclareOpInterfaceMethods<TransformOpInterface, ["allowsRepeatedHandleOperands"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 5995485f5a79a..ef5bfd2a85aeb 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -722,6 +722,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
if (opResult.getType().isa<TransformParamTypeInterface>())
results.setParams(opResult, {});
+ else if (opResult.getType().isa<TransformValueHandleTypeInterface>()) // value handles get an empty value list, like params/ops above and below
+ results.setValues(opResult, {});
else
results.set(opResult, {});
}
@@ -831,7 +833,7 @@ void transform::TransformResults::setParams(
void transform::TransformResults::setValues(OpResult handle,
ValueRange values) {
int64_t position = handle.getResultNumber();
- assert(position < static_cast<int64_t>(values.size()) &&
+ assert(position < static_cast<int64_t>(this->values.size()) && // bound by the member storage; the `values` parameter shadows it
"setting values for a non-existent handle");
assert(this->values[position].data() == nullptr && "values already set");
assert(operations[position].data() == nullptr &&
@@ -861,8 +863,8 @@ transform::TransformResults::getParams(unsigned resultNumber) const {
ArrayRef<Value>
transform::TransformResults::getValues(unsigned resultNumber) const {
- assert(resultNumber < params.size() &&
- "querying params for a non-existent handle");
+ assert(resultNumber < values.size() && // bound-check the same storage that is indexed below
+ "querying values for a non-existent handle");
assert(values[resultNumber].data() != nullptr &&
"querying unset values (ops or params expected?)");
return values[resultNumber];
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 134911322a4b3..a7b370cdd0fda 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -475,6 +475,36 @@ transform::GetProducerOfOperand::apply(transform::TransformResults &results,
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// GetResultOp
+//===----------------------------------------------------------------------===//
+
+/// Associates the `result` value handle with OpResult #`result_number` of
+/// every payload op associated with the `target` handle, in payload order.
+/// Produces a silenceable failure if the result number is negative or if any
+/// targeted op has too few results.
+DiagnosedSilenceableFailure
+transform::GetResultOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ int64_t resultNumber = getResultNumber();
+ // `result_number` is a signed I64Attr. A negative value would pass the
+ // bounds check below and reach getOpResult with an invalid position.
+ if (resultNumber < 0)
+ return emitSilenceableError() << "result number must be non-negative";
+ SmallVector<Value> opResults;
+ for (Operation *target : state.getPayloadOps(getTarget())) {
+ if (resultNumber >= static_cast<int64_t>(target->getNumResults())) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "targeted op does not have enough results";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+ opResults.push_back(target->getOpResult(resultNumber));
+ }
+ results.setValues(getResult().cast<OpResult>(), opResults);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// MergeHandlesOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 7e2804d1621b2..f6056864e0bac 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1190,3 +1190,35 @@ transform.sequence failures(propagate) {
// expected-error @below {{unexpectedly consumed a value that is not a handle as operand #0}}
test_consume_operand %0 : !transform.test_dialect_param
}
+
+// -----
+
+func.func @get_result_of_op(%lhs: index, %rhs: index) -> index {
+ // expected-remark @below {{addi result}}
+ // expected-note @below {{value handle points to an op result #0}}
+ %sum = arith.addi %lhs, %rhs : index
+ return %sum : index
+}
+
+transform.sequence failures(propagate) {
+^bb0(%root: !transform.any_op):
+ %add_op = transform.structured.match ops{["arith.addi"]} in %root : (!transform.any_op) -> !transform.any_op
+ %add_result = transform.get_result %add_op[0] : (!transform.any_op) -> !transform.any_value
+ transform.test_print_remark_at_operand_value %add_result, "addi result" : !transform.any_value
+}
+
+// -----
+
+func.func @get_out_of_bounds_result_of_op(%lhs: index, %rhs: index) -> index {
+ // expected-note @below {{target op}}
+ %sum = arith.addi %lhs, %rhs : index
+ return %sum : index
+}
+
+transform.sequence failures(propagate) {
+^bb0(%root: !transform.any_op):
+ %add_op = transform.structured.match ops{["arith.addi"]} in %root : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{targeted op does not have enough results}}
+ %add_result = transform.get_result %add_op[1] : (!transform.any_op) -> !transform.any_value
+ transform.test_print_remark_at_operand_value %add_result, "addi result" : !transform.any_value
+}
More information about the Mlir-commits
mailing list