[Mlir-commits] [mlir] [MLIR] [Transforms] Let `transform.structured.convert_to_loops` return handles to loops (PR #83984)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 5 01:51:26 PST 2024


https://github.com/lhunloh updated https://github.com/llvm/llvm-project/pull/83984

>From 69a7f4c96d939b1ead84ec9812726b3efea1d653 Mon Sep 17 00:00:00 2001
From: Lars <larshunloh at uni-muenster.de>
Date: Tue, 5 Mar 2024 10:43:01 +0100
Subject: [PATCH 1/2] Let `transform.structured.convert_to_loops` return
 handles to loops

This lets `transform.structured.convert_to_loops` return handles to the
generated loops, making this transformation more useful for
(transformation-)nesting purposes. This is modelled after SCF's
`transform.loop.forall_to_for`, which returns handles to loops.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  6 ++---
 .../TransformOps/LinalgTransformOps.cpp       |  3 +++
 .../lower-to-loops-using-interface.mlir       | 24 ++++++++++++-------
 3 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 53ed31877c6f24..3d77e308f98f50 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1281,14 +1281,14 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
   let description = [{
     For operations that implement the `TilingInterface`, and implement
     the `generateScalarImplementation` method, lowers the operation to
-    loops. This operation does not return any handles.
+    loops. The return handles point to the generated loops.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$result);
 
   let assemblyFormat = [{
-    $target attr-dict `:` type($target)
+    $target attr-dict `:` functional-type(operands, results)
   }];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 0ac0a89dcc76ae..259d6e1b8f277e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2122,6 +2122,9 @@ DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
   if (failed(loops))
     return emitDefaultDefiniteFailure(target);
   rewriter.eraseOp(target);
+  for(auto &loop: *loops){
+    results.push_back(loop);
+  }
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 1b2c553b25ded0..e431e63134eefc 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -11,7 +11,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %matmul : !transform.any_op
+    %0:3 = transform.structured.convert_to_loops %matmul
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -66,7 +67,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %generic : !transform.any_op
+    %0:2 = transform.structured.convert_to_loops %generic
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -111,7 +113,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %conv : !transform.any_op
+    %0:7 = transform.structured.convert_to_loops %conv
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -165,7 +168,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %pool : !transform.any_op
+    %0:6 = transform.structured.convert_to_loops %pool
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -216,7 +220,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %map = transform.structured.match ops{["linalg.map"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %map : !transform.any_op
+    %0 = transform.structured.convert_to_loops %map
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -248,7 +253,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %transpose : !transform.any_op
+    %0:3 = transform.structured.convert_to_loops %transpose
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -285,7 +291,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %reduce : !transform.any_op
+    %0:3 = transform.structured.convert_to_loops %reduce
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -322,7 +329,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %broadcast : !transform.any_op
+    %0:3 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }

>From dd7a60ded72c8f3c55623c3cb0830f9dfd9c1a27 Mon Sep 17 00:00:00 2001
From: Lars <larshunloh at uni-muenster.de>
Date: Tue, 5 Mar 2024 10:50:54 +0100
Subject: [PATCH 2/2] Resolve clang-format issue in `ConvertToLoopsOp`

---
 mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 259d6e1b8f277e..ed6e9d36885769 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2122,7 +2122,7 @@ DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
   if (failed(loops))
     return emitDefaultDefiniteFailure(target);
   rewriter.eraseOp(target);
-  for(auto &loop: *loops){
+  for (auto &loop: *loops) {
     results.push_back(loop);
   }
   return DiagnosedSilenceableFailure::success();



More information about the Mlir-commits mailing list