[Mlir-commits] [mlir] d906426 - [mlir] make transform.loop.outline also return the call handle

Alex Zinenko llvmlistbot at llvm.org
Fri May 5 05:42:13 PDT 2023


Author: Alex Zinenko
Date: 2023-05-05T12:42:05Z
New Revision: d9064269442d8d6126021db9230bac97a2ea9e13

URL: https://github.com/llvm/llvm-project/commit/d9064269442d8d6126021db9230bac97a2ea9e13
DIFF: https://github.com/llvm/llvm-project/commit/d9064269442d8d6126021db9230bac97a2ea9e13.diff

LOG: [mlir] make transform.loop.outline also return the call handle

Outlining is particularly useful when the outlined function is later
replaced with something else, e.g., a microkernel. In that case, a handle
to the call operation lets subsequent transforms target the call directly.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D149849

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
    mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
    mlir/python/mlir/dialects/_loop_transform_ops_ext.py
    mlir/test/Dialect/SCF/transform-ops-invalid.mlir
    mlir/test/Dialect/SCF/transform-ops.mlir
    mlir/test/python/dialects/transform_loop_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 0399a5a9afa5e..fef9e4bd17214 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -46,16 +46,22 @@ def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let summary = "Outlines a loop into a named function";
   let description = [{
-     Moves the loop into a separate function with the specified name and
-     replaces the loop in the Payload IR with a call to that function. Takes
-     care of forwarding values that are used in the loop as function arguments.
-     If the operand is associated with more than one loop, each loop will be
-     outlined into a separate function. The provided name is used as a _base_
-     for forming actual function names following SymbolTable auto-renaming
-     scheme to avoid duplicate symbols. Expects that all ops in the Payload IR
-     have a SymbolTable ancestor (typically true because of the top-level
-     module). Returns the handle to the list of outlined functions in the same
-     order as the operand handle.
+    Moves the loop into a separate function with the specified name and replaces
+    the loop in the Payload IR with a call to that function. Takes care of
+    forwarding values that are used in the loop as function arguments. If the
+    operand is associated with more than one loop, each loop will be outlined
+    into a separate function. The provided name is used as a _base_ for forming
+    actual function names following `SymbolTable` auto-renaming scheme to avoid
+    duplicate symbols. Expects that all ops in the Payload IR have a
+    `SymbolTable` ancestor (typically true because of the top-level module).
+
+    #### Return Modes
+
+    Returns a handle to the list of outlined functions and a handle to the
+    corresponding function call operations in the same order as the operand
+    handle.
+
+    Produces a definite failure if outlining failed for any of the targets.
   }];
 
   // Note that despite the name of the transform operation and related utility
@@ -63,7 +69,8 @@ def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
   // a loop.
   let arguments = (ins TransformHandleTypeInterface:$target,
                    StrAttr:$func_name);
-  let results = (outs TransformHandleTypeInterface:$transformed);
+  let results = (outs TransformHandleTypeInterface:$function,
+                      TransformHandleTypeInterface:$call);
 
   let assemblyFormat =
     "$target attr-dict `:` functional-type(operands, results)";

diff  --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 0c7c04ddc63c2..18425dea7b19f 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -85,7 +85,8 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
 DiagnosedSilenceableFailure
 transform::LoopOutlineOp::apply(transform::TransformResults &results,
                                 transform::TransformState &state) {
-  SmallVector<Operation *> transformed;
+  SmallVector<Operation *> functions;
+  SmallVector<Operation *> calls;
   DenseMap<Operation *, SymbolTable> symbolTables;
   for (Operation *target : state.getPayloadOps(getTarget())) {
     Location location = target->getLoc();
@@ -112,9 +113,11 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
       symbolTable.insert(*outlined);
       call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
     }
-    transformed.push_back(*outlined);
+    functions.push_back(*outlined);
+    calls.push_back(call);
   }
-  results.set(getTransformed().cast<OpResult>(), transformed);
+  results.set(getFunction().cast<OpResult>(), functions);
+  results.set(getCall().cast<OpResult>(), calls);
   return DiagnosedSilenceableFailure::success();
 }
 

diff  --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
index a275ea615378e..10079d32fd925 100644
--- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
@@ -39,7 +39,8 @@ class LoopOutlineOp:
 
   def __init__(
       self,
-      result_type: Type,
+      function_type: Type,
+      call_type: Type,
       target: Union[Operation, Value],
       *,
       func_name: Union[str, StringAttr],
@@ -47,7 +48,8 @@ def __init__(
       loc=None,
   ):
     super().__init__(
-        result_type,
+        function_type,
+        call_type,
         _get_op_result_or_value(target),
         func_name=(func_name if isinstance(func_name, StringAttr) else
                    StringAttr.get(func_name)),

diff  --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
index da040fff82733..2e15abdd260db 100644
--- a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
@@ -54,8 +54,8 @@ func.func @loop_outline_op_multi_region() {
 }
 
 transform.sequence failures(propagate) {
-^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["scf.while"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["scf.while"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   // expected-error @below {{failed to outline}}
-  transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation
+  transform.loop.outline %0 {func_name = "foo"} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 }

diff  --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index d876d0f6be9c2..8fe9eddf6c193 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -75,11 +75,11 @@ func.func @loop_outline_op(%arg0: index, %arg1: index, %arg2: index) {
 }
 
 transform.sequence failures(propagate) {
-^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-  %1 = transform.loop.get_parent_for %0  : (!pdl.operation) -> !transform.op<"scf.for">
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.loop.get_parent_for %0  : (!transform.any_op) -> !transform.op<"scf.for">
   // CHECK: = transform.loop.outline %{{.*}}
-  transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> !pdl.operation
+  transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
 }
 
 // -----

diff  --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py
index 02d35f628c175..067a8b60d4f89 100644
--- a/mlir/test/python/dialects/transform_loop_ext.py
+++ b/mlir/test/python/dialects/transform_loop_ext.py
@@ -33,7 +33,7 @@ def loopOutline():
   sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
                                   [], transform.OperationType.get("scf.for"))
   with InsertionPoint(sequence.body):
-    loop.LoopOutlineOp(pdl.OperationType.get(), sequence.bodyTarget, func_name="foo")
+    loop.LoopOutlineOp(transform.AnyOpType.get(), transform.AnyOpType.get(), sequence.bodyTarget, func_name="foo")
     transform.YieldOp()
   # CHECK-LABEL: TEST: loopOutline
   # CHECK: = transform.loop.outline %


        


More information about the Mlir-commits mailing list