[Mlir-commits] [mlir] [MLIR] [Transforms] Let `transform.structured.convert_to_loops` return handles to loops (PR #83984)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 5 10:05:17 PST 2024


https://github.com/lhunloh updated https://github.com/llvm/llvm-project/pull/83984

>From 49b688c3858c7063a22c96eb47b86a91c0ef26a0 Mon Sep 17 00:00:00 2001
From: Lars <larshunloh at uni-muenster.de>
Date: Tue, 5 Mar 2024 11:02:20 +0100
Subject: [PATCH 1/3] Let `transform.structured.convert_to_loops` return
 handles to loops

This lets `transform.structured.convert_to_loops` return handles to the
generated loops, making this transformation more useful
for (transformation-)nesting purposes. This is modelled after SCF's `transform.loop.forall_to_for`, which returns handles to loops.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  6 ++---
 .../TransformOps/LinalgTransformOps.cpp       |  3 +++
 .../lower-to-loops-using-interface.mlir       | 24 ++++++++++++-------
 3 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 53ed31877c6f24..3d77e308f98f50 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1281,14 +1281,14 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
   let description = [{
     For operations that implement the `TilingInterface`, and implement
     the `generateScalarImplementation` method, lowers the operation to
-    loops. This operation does not return any handles.
+    loops. The return handles point to the generated loops.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$result);
 
   let assemblyFormat = [{
-    $target attr-dict `:` type($target)
+    $target attr-dict `:` functional-type(operands, results)
   }];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 0ac0a89dcc76ae..ef046e77b97375 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2122,6 +2122,9 @@ DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
   if (failed(loops))
     return emitDefaultDefiniteFailure(target);
   rewriter.eraseOp(target);
+  for (auto &loop : *loops) {
+    results.push_back(loop);
+  }
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 1b2c553b25ded0..e431e63134eefc 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -11,7 +11,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %matmul : !transform.any_op
+    %0:3 = transform.structured.convert_to_loops %matmul
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -66,7 +67,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %generic : !transform.any_op
+    %0:2 = transform.structured.convert_to_loops %generic
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -111,7 +113,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %conv : !transform.any_op
+    %0:7 = transform.structured.convert_to_loops %conv
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -165,7 +168,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %pool : !transform.any_op
+    %0:6 = transform.structured.convert_to_loops %pool
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -216,7 +220,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %map = transform.structured.match ops{["linalg.map"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %map : !transform.any_op
+    %0 = transform.structured.convert_to_loops %map
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -248,7 +253,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %transpose : !transform.any_op
+    %0:3 = transform.structured.convert_to_loops %transpose
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -285,7 +291,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %reduce : !transform.any_op
+    %0:3 = transform.structured.convert_to_loops %reduce
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -322,7 +329,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %broadcast : !transform.any_op
+    %0:3 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }

>From ac4d27c781a44a8f8cbdf0daa488fe2d2e10ba84 Mon Sep 17 00:00:00 2001
From: Lars <larshunloh at uni-muenster.de>
Date: Tue, 5 Mar 2024 18:43:14 +0100
Subject: [PATCH 2/3] Let `convert_to_loops` return handle to loops

This lets `transform.structured.convert_to_loops` return a handle to all
generated loops, making this transformation more useful
for (transformation-)nesting purposes. This handle may be split via `transform.split_handle` to get handles to each individual loop.
---
 .../Linalg/TransformOps/LinalgTransformOps.td | 19 ++---
 .../TransformOps/LinalgTransformOps.cpp       | 36 ++++++---
 .../lower-to-loops-using-interface.mlir       | 79 +++++++++++++++----
 3 files changed, 96 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 3d77e308f98f50..63a80919a1f1c1 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1274,33 +1274,28 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ConvertToLoopsOp
+//===----------------------------------------------------------------------===//
+
 def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformOpInterface, TransformEachOpTrait,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
     For operations that implement the `TilingInterface`, and implement
     the `generateScalarImplementation` method, lowers the operation to
-    loops. The return handles point to the generated loops.
+    loops. The return handle points to all generated loops.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs Variadic<TransformHandleTypeInterface>:$result);
+  let results = (outs TransformHandleTypeInterface:$result);
 
   let assemblyFormat = [{
     $target attr-dict `:` functional-type(operands, results)
   }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::TilingInterface target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
-
 //===----------------------------------------------------------------------===//
 // DecomposeInterfaceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ef046e77b97375..ae28049f02e391 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2112,19 +2112,31 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
 // ConvertToLoopsOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
-    transform::TransformRewriter &rewriter, TilingInterface target,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  rewriter.setInsertionPoint(target);
-  FailureOr<SmallVector<scf::ForOp>> loops =
-      scf::lowerToLoopsUsingSCFForOp(rewriter, target);
-  if (failed(loops))
-    return emitDefaultDefiniteFailure(target);
-  rewriter.eraseOp(target);
-  for (auto &loop : *loops) {
-    results.push_back(loop);
+DiagnosedSilenceableFailure
+transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
+                                   transform::TransformResults &results,
+                                   transform::TransformState &state) {
+  SmallVector<Operation *> loops;
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    auto tilingOp = dyn_cast<TilingInterface>(*target);
+    // Check the cast result, not `target`, which was already dereferenced.
+    if (!tilingOp) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError()
+          << "expected the payload to implement TilingInterface";
+      diag.attachNote(target->getLoc()) << "payload op";
+      return diag;
+    }
+    rewriter.setInsertionPoint(target);
+    FailureOr<SmallVector<scf::ForOp>> generatedLoops =
+        scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
+    if (failed(generatedLoops))
+      return emitDefaultDefiniteFailure(target);
+    for (scf::ForOp &loop : *generatedLoops)
+      loops.push_back(loop.getOperation());
+    rewriter.eraseOp(target);
   }
+  results.set(cast<OpResult>(getResult()), loops);
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index e431e63134eefc..8cbee3cbb758b2 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -11,8 +11,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %0:3 = transform.structured.convert_to_loops %matmul
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %0 = transform.structured.convert_to_loops %matmul
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -38,6 +38,57 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+  %arg2 : memref<?x?xf32>, %arg3 : memref<?xf32>, %arg4 : memref<?xf32>) {
+  linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>)
+  linalg.matvec ins(%arg0, %arg3 : memref<?x?xf32>, memref<?xf32>)
+      outs(%arg4 : memref<?xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %linalg_ops = transform.structured.match interface{TilingInterface} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %linalg_ops
+      : (!transform.any_op) -> (!transform.any_op)
+    %1:5 = transform.split_handle %0
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @gemm
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: memref<?xf32>
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: memref<?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
+//       CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
+//       CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
+//   CHECK-DAG:         %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV2]]]
+//   CHECK-DAG:         %[[RHS:.+]] = memref.load %[[ARG1]][%[[IV2]], %[[IV1]]]
+//   CHECK-DAG:         %[[OUT:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]]]
+//       CHECK:         %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
+//       CHECK:         %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
+//       CHECK:         memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]]
+//       CHECK:   scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
+//       CHECK:     scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
+//   CHECK-DAG:         %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV3]], %[[IV4]]]
+//   CHECK-DAG:         %[[RHS:.+]] = memref.load %[[ARG3]][%[[IV4]]]
+//   CHECK-DAG:         %[[OUT:.+]] = memref.load %[[ARG4]][%[[IV3]]]
+//       CHECK:         %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
+//       CHECK:         %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
+//       CHECK:         memref.store %[[ADDF]], %[[ARG4]][%[[IV3]]]
+
+// -----
+
 func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>,
     %arg2 : memref<200xi8>, %arg3 : memref<300x200xi64>) {
   linalg.generic {
@@ -67,8 +118,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %0:2 = transform.structured.convert_to_loops %generic
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %0 = transform.structured.convert_to_loops %generic
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -113,8 +164,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %0:7 = transform.structured.convert_to_loops %conv
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    %0 = transform.structured.convert_to_loops %conv
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -168,8 +219,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %0:6 = transform.structured.convert_to_loops %pool
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    %0 = transform.structured.convert_to_loops %pool
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -253,8 +304,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %0:3 = transform.structured.convert_to_loops %transpose
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %0 = transform.structured.convert_to_loops %transpose
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -291,8 +342,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %0:3 = transform.structured.convert_to_loops %reduce
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %0 = transform.structured.convert_to_loops %reduce
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -329,8 +380,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %0:3 = transform.structured.convert_to_loops %broadcast
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }

>From 4e6a64141359c491c918047816204b8852345996 Mon Sep 17 00:00:00 2001
From: Lars <larshunloh at uni-muenster.de>
Date: Tue, 5 Mar 2024 19:04:49 +0100
Subject: [PATCH 3/3] Adjusted ConvertToLoops Description

---
 .../mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td       | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 63a80919a1f1c1..bdeab55091b9f3 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1286,6 +1286,7 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
     For operations that implement the `TilingInterface`, and implement
     the `generateScalarImplementation` method, lowers the operation to
     loops. The return handle points to all generated loops.
+    Fails if the payload ops cannot be lowered to loops.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target);



More information about the Mlir-commits mailing list