[Mlir-commits] [mlir] Let `transform.structured.convert_to_loops` return handles to loops (PR #83984)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 5 01:47:49 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: None (lhunloh)

<details>
<summary>Changes</summary>

This lets `transform.structured.convert_to_loops` return handles to the generated loops, making the transformation more useful for (transformation-)nesting purposes. This is modelled after SCF's `transform.loop.forall_to_for`, which returns handles to loops.

The op was introduced in commit #aa2a96a, with a note that it might move out of the `Linalg` dialect, but no reason was given for not returning handles. As far as I can see, this transform always returns loops.

---
Full diff: https://github.com/llvm/llvm-project/pull/83984.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+3-3) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+3) 
- (modified) mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir (+16-8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 53ed31877c6f24..3d77e308f98f50 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1281,14 +1281,14 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
   let description = [{
     For operations that implement the `TilingInterface`, and implement
     the `generateScalarImplementation` method, lowers the operation to
-    loops. This operation does not return any handles.
+    loops. The return handles point to the generated loops.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$result);
 
   let assemblyFormat = [{
-    $target attr-dict `:` type($target)
+    $target attr-dict `:` functional-type(operands, results)
   }];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 0ac0a89dcc76ae..259d6e1b8f277e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2122,6 +2122,9 @@ DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
   if (failed(loops))
     return emitDefaultDefiniteFailure(target);
   rewriter.eraseOp(target);
+  for(auto &loop: *loops){
+    results.push_back(loop);
+  }
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 1b2c553b25ded0..e431e63134eefc 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -11,7 +11,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %matmul : !transform.any_op
+    %0:3 = transform.structured.convert_to_loops %matmul
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -66,7 +67,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %generic : !transform.any_op
+    %0:2 = transform.structured.convert_to_loops %generic
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -111,7 +113,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %conv : !transform.any_op
+    %0:7 = transform.structured.convert_to_loops %conv
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -165,7 +168,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %pool : !transform.any_op
+    %0:6 = transform.structured.convert_to_loops %pool
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -216,7 +220,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %map = transform.structured.match ops{["linalg.map"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %map : !transform.any_op
+    %0 = transform.structured.convert_to_loops %map
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -248,7 +253,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %transpose : !transform.any_op
+    %0:3 = transform.structured.convert_to_loops %transpose
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -285,7 +291,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %reduce : !transform.any_op
+    %0:3 = transform.structured.convert_to_loops %reduce
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -322,7 +329,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %broadcast : !transform.any_op
+    %0:3 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/83984


More information about the Mlir-commits mailing list