[Mlir-commits] [mlir] 4adf89f - [mlir][Transform] Add a transform.get_consumers_of_result navigation op
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Jan 17 07:01:53 PST 2023
Author: Nicolas Vasilache
Date: 2023-01-17T06:58:38-08:00
New Revision: 4adf89fc65f6ffda8add795d002334b6bbea55f5
URL: https://github.com/llvm/llvm-project/commit/4adf89fc65f6ffda8add795d002334b6bbea55f5
DIFF: https://github.com/llvm/llvm-project/commit/4adf89fc65f6ffda8add795d002334b6bbea55f5.diff
LOG: [mlir][Transform] Add a transform.get_consumers_of_result navigation op
Differential Revision: https://reviews.llvm.org/D141930
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 72a819cffb9c9..d72d38a365b7b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -189,6 +189,27 @@ def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent
"$target attr-dict `:` functional-type(operands, results)";
}
+def GetConsumersOfResult : TransformDialectOp<"get_consumers_of_result",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+  let summary = "Get handle to the consumers of this operation's result number";
+  let description = [{
+    The handle defined by this Transform op corresponds to all operations that
+    consume the SSA value defined by the `target` and `result_number`
+    arguments.
+
+    The `target` handle must be mapped to at most one payload operation: if it
+    is mapped to several payload operations, this op definitely fails; if it
+    is mapped to none, the returned handle is empty.
+
+    `result_number` is expected to be a non-negative index smaller than the
+    number of results of the payload operation; an index past the last result
+    makes this op definitely fail.
+
+    The returned handle points to the consuming operations, and may be empty
+    when the selected result has no users.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       I64Attr:$result_number);
+  let results = (outs TransformHandleTypeInterface:$consumers);
+  let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
+                       "functional-type(operands, results)";
+}
+
def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 0b3391ea448e2..f711419a11a2f 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -399,6 +399,31 @@ DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// GetConsumersOfResult
+//===----------------------------------------------------------------------===//
+
+/// Maps the `consumers` result handle to the users of result `result_number`
+/// of the single payload op associated with `target`.
+///
+/// An empty `target` handle yields an empty result handle. A definite failure
+/// is emitted when `target` is mapped to more than one payload op, or when
+/// `result_number` does not name an existing result of that payload op
+/// (including negative values).
+DiagnosedSilenceableFailure
+transform::GetConsumersOfResult::apply(transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  int64_t resultNumber = getResultNumber();
+  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+  // Nothing to navigate from: propagate an empty handle.
+  if (payloadOps.empty()) {
+    results.set(getResult().cast<OpResult>(), {});
+    return DiagnosedSilenceableFailure::success();
+  }
+  if (payloadOps.size() != 1)
+    return emitDefiniteFailure()
+           << "handle must be mapped to exactly one payload op";
+
+  Operation *target = payloadOps.front();
+  // The attribute is an unconstrained 64-bit integer, so it can be negative.
+  // Reject that explicitly: passed to `getResult`, a negative value would be
+  // converted to a huge unsigned index and access out of bounds.
+  if (resultNumber < 0 ||
+      resultNumber >= static_cast<int64_t>(target->getNumResults()))
+    return emitDefiniteFailure() << "result number overflow";
+  results.set(getResult().cast<OpResult>(),
+              llvm::to_vector(target->getResult(resultNumber).getUsers()));
+  return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// GetProducerOfOperand
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index da48fe234633c..30d1155655960 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -775,6 +775,53 @@ transform.sequence failures(propagate) {
// -----
+// The only user of the single `arith.muli` result is the `arith.addi`, so the
+// consumers handle must point exactly at it.
+func.func @get_consumer(%lhs: index, %rhs: index) {
+  %prod = arith.muli %lhs, %rhs : index
+  // expected-remark @below {{found addi}}
+  arith.addi %prod, %rhs : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%root: !pdl.operation):
+  %mul = transform.structured.match ops{["arith.muli"]} in %root
+  %consumers = get_consumers_of_result %mul[0] : (!pdl.operation) -> !pdl.operation
+  transform.test_print_remark_at_operand %consumers, "found addi" : !pdl.operation
+}
+
+// -----
+
+// Two `arith.muli` ops are matched, so navigating from the handle must fail
+// definitely.
+func.func @get_consumer_fail_1(%lhs: index, %rhs: index) {
+  %first = arith.muli %lhs, %rhs : index
+  %second = arith.muli %lhs, %rhs : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%root: !pdl.operation):
+  %muls = transform.structured.match ops{["arith.muli"]} in %root
+  // expected-error @below {{handle must be mapped to exactly one payload op}}
+  %consumers = get_consumers_of_result %muls[0] : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// `arith.muli` has a single result, so asking for result #1 is out of range.
+func.func @get_consumer_fail_2(%lhs: index, %rhs: index) {
+  %prod = arith.muli %lhs, %rhs : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%root: !pdl.operation):
+  %mul = transform.structured.match ops{["arith.muli"]} in %root
+  // expected-error @below {{result number overflow}}
+  %consumers = get_consumers_of_result %mul[1] : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
func.func @split_handles(%a: index, %b: index, %c: index) {
%0 = arith.muli %a, %b : index
%1 = arith.muli %a, %c : index
More information about the Mlir-commits
mailing list