[Mlir-commits] [mlir] d906426 - [mlir] make transform.loop.outline also return the call handle
Alex Zinenko
llvmlistbot at llvm.org
Fri May 5 05:42:13 PDT 2023
Author: Alex Zinenko
Date: 2023-05-05T12:42:05Z
New Revision: d9064269442d8d6126021db9230bac97a2ea9e13
URL: https://github.com/llvm/llvm-project/commit/d9064269442d8d6126021db9230bac97a2ea9e13
DIFF: https://github.com/llvm/llvm-project/commit/d9064269442d8d6126021db9230bac97a2ea9e13.diff
LOG: [mlir] make transform.loop.outline also return the call handle
Outlining is particularly useful when the outlined function is later
replaced with something else, e.g., a microkernel. In that case, it is
useful to also have a handle to the call.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D149849
Added:
Modified:
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/python/mlir/dialects/_loop_transform_ops_ext.py
mlir/test/Dialect/SCF/transform-ops-invalid.mlir
mlir/test/Dialect/SCF/transform-ops.mlir
mlir/test/python/dialects/transform_loop_ext.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 0399a5a9afa5e..fef9e4bd17214 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -46,16 +46,22 @@ def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let summary = "Outlines a loop into a named function";
let description = [{
- Moves the loop into a separate function with the specified name and
- replaces the loop in the Payload IR with a call to that function. Takes
- care of forwarding values that are used in the loop as function arguments.
- If the operand is associated with more than one loop, each loop will be
- outlined into a separate function. The provided name is used as a _base_
- for forming actual function names following SymbolTable auto-renaming
- scheme to avoid duplicate symbols. Expects that all ops in the Payload IR
- have a SymbolTable ancestor (typically true because of the top-level
- module). Returns the handle to the list of outlined functions in the same
- order as the operand handle.
+ Moves the loop into a separate function with the specified name and replaces
+ the loop in the Payload IR with a call to that function. Takes care of
+ forwarding values that are used in the loop as function arguments. If the
+ operand is associated with more than one loop, each loop will be outlined
+ into a separate function. The provided name is used as a _base_ for forming
+ actual function names following `SymbolTable` auto-renaming scheme to avoid
+ duplicate symbols. Expects that all ops in the Payload IR have a
+ `SymbolTable` ancestor (typically true because of the top-level module).
+
+ #### Return Modes
+
+ Returns a handle to the list of outlined functions and a handle to the
+ corresponding function call operations in the same order as the operand
+ handle.
+
+ Produces a definite failure if outlining failed for any of the targets.
}];
// Note that despite the name of the transform operation and related utility
@@ -63,7 +69,8 @@ def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
// a loop.
let arguments = (ins TransformHandleTypeInterface:$target,
StrAttr:$func_name);
- let results = (outs TransformHandleTypeInterface:$transformed);
+ let results = (outs TransformHandleTypeInterface:$function,
+ TransformHandleTypeInterface:$call);
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 0c7c04ddc63c2..18425dea7b19f 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -85,7 +85,8 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
DiagnosedSilenceableFailure
transform::LoopOutlineOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
- SmallVector<Operation *> transformed;
+ SmallVector<Operation *> functions;
+ SmallVector<Operation *> calls;
DenseMap<Operation *, SymbolTable> symbolTables;
for (Operation *target : state.getPayloadOps(getTarget())) {
Location location = target->getLoc();
@@ -112,9 +113,11 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
symbolTable.insert(*outlined);
call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
}
- transformed.push_back(*outlined);
+ functions.push_back(*outlined);
+ calls.push_back(call);
}
- results.set(getTransformed().cast<OpResult>(), transformed);
+ results.set(getFunction().cast<OpResult>(), functions);
+ results.set(getCall().cast<OpResult>(), calls);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
index a275ea615378e..10079d32fd925 100644
--- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
@@ -39,7 +39,8 @@ class LoopOutlineOp:
def __init__(
self,
- result_type: Type,
+ function_type: Type,
+ call_type: Type,
target: Union[Operation, Value],
*,
func_name: Union[str, StringAttr],
@@ -47,7 +48,8 @@ def __init__(
loc=None,
):
super().__init__(
- result_type,
+ function_type,
+ call_type,
_get_op_result_or_value(target),
func_name=(func_name if isinstance(func_name, StringAttr) else
StringAttr.get(func_name)),
diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
index da040fff82733..2e15abdd260db 100644
--- a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
@@ -54,8 +54,8 @@ func.func @loop_outline_op_multi_region() {
}
transform.sequence failures(propagate) {
-^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["scf.while"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["scf.while"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{failed to outline}}
- transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation
+ transform.loop.outline %0 {func_name = "foo"} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
}
diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index d876d0f6be9c2..8fe9eddf6c193 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -75,11 +75,11 @@ func.func @loop_outline_op(%arg0: index, %arg1: index, %arg2: index) {
}
transform.sequence failures(propagate) {
-^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!pdl.operation) -> !pdl.operation
- %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
// CHECK: = transform.loop.outline %{{.*}}
- transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> !pdl.operation
+ transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
}
// -----
diff --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py
index 02d35f628c175..067a8b60d4f89 100644
--- a/mlir/test/python/dialects/transform_loop_ext.py
+++ b/mlir/test/python/dialects/transform_loop_ext.py
@@ -33,7 +33,7 @@ def loopOutline():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
[], transform.OperationType.get("scf.for"))
with InsertionPoint(sequence.body):
- loop.LoopOutlineOp(pdl.OperationType.get(), sequence.bodyTarget, func_name="foo")
+ loop.LoopOutlineOp(transform.AnyOpType.get(), transform.AnyOpType.get(), sequence.bodyTarget, func_name="foo")
transform.YieldOp()
# CHECK-LABEL: TEST: loopOutline
# CHECK: = transform.loop.outline %
More information about the Mlir-commits mailing list