[llvm-branch-commits] [mlir] release/18.x: [MLIR] [Transforms] Let `transform.structured.convert_to_loops` return handles to loops (#83984) (PR #85942)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Mar 20 08:02:04 PDT 2024
https://github.com/llvmbot created https://github.com/llvm/llvm-project/pull/85942
Backport 0597644a6466ae9148b0b41cb8f95d5022e045c2 47bc565ca7990a2de20af4030baf08ac62739aca
Requested by: @lhunloh
>From d5933a73516f3bdfc37216d52278e0ca3d42859d Mon Sep 17 00:00:00 2001
From: Congcong Cai <congcongcai0907 at 163.com>
Date: Tue, 5 Mar 2024 03:58:12 +0800
Subject: [PATCH 1/2] [mlir][transform] replace original op to loop ops
(#83537)
(cherry picked from commit 0597644a6466ae9148b0b41cb8f95d5022e045c2)
---
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 1 +
.../TilingInterface/lower-to-loops-using-interface.mlir | 1 +
2 files changed, 2 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 140bdd1f2db361..be875297fc93ca 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2092,6 +2092,7 @@ DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
scf::lowerToLoopsUsingSCFForOp(rewriter, target);
if (failed(loops))
return emitDefaultDefiniteFailure(target);
+ rewriter.eraseOp(target);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 7969de0d456bb6..1b2c553b25ded0 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -33,6 +33,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
// CHECK: %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
// CHECK: memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK-NOT: linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
// -----
>From 6db21232a2f1d47c61fb3bf5985487ee695491f3 Mon Sep 17 00:00:00 2001
From: lhunloh <8047408+lhunloh at users.noreply.github.com>
Date: Wed, 6 Mar 2024 22:07:30 +0000
Subject: [PATCH 2/2] [MLIR] [Transforms] Let
`transform.structured.convert_to_loops` return handles to loops (#83984)
This lets `transform.structured.convert_to_loops` return handles to the
generated loops, making the transformation more useful for
(transformation-)nesting purposes. This is modelled after SCF's
`transform.loop.forall_to_for`, which returns handles to loops.
The op was introduced in commit aa2a96a24ae3a8cc04635ab6ede474c5f2665053,
with a note that they might move out of the `Linalg` dialect, but no reason
was given for not returning handles. As far as I can see, this transform
always produces loops.
(cherry picked from commit 47bc565ca7990a2de20af4030baf08ac62739aca)
---
.../Linalg/TransformOps/LinalgTransformOps.td | 22 +++---
.../TransformOps/LinalgTransformOps.cpp | 35 ++++++---
.../lower-to-loops-using-interface.mlir | 75 +++++++++++++++++--
3 files changed, 101 insertions(+), 31 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b139f1ef58b3a9..da7183dae75ffc 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1274,33 +1274,29 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
}];
}
+//===----------------------------------------------------------------------===//
+// ConvertToLoopsOp
+//===----------------------------------------------------------------------===//
+
def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
For operations that implement the `TilingInterface`, and implement
the `generateScalarImplementation` method, lowers the operation to
- loops. This operation does not return any handles.
+ loops. The returned handle points to all generated loops.
+ Fails if any payload op cannot be lowered to loops.
}];
let arguments = (ins TransformHandleTypeInterface:$target);
- let results = (outs);
+ let results = (outs TransformHandleTypeInterface:$result);
let assemblyFormat = [{
- $target attr-dict `:` type($target)
- }];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::TilingInterface target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
+ $target attr-dict `:` functional-type(operands, results)
}];
}
-
//===----------------------------------------------------------------------===//
// DecomposeInterfaceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index be875297fc93ca..905875ae43ce8a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2083,16 +2083,37 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
// ConvertToLoopsOp
//===----------------------------------------------------------------------===//
-DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
- transform::TransformRewriter &rewriter, TilingInterface target,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- rewriter.setInsertionPoint(target);
- FailureOr<SmallVector<scf::ForOp>> loops =
- scf::lowerToLoopsUsingSCFForOp(rewriter, target);
- if (failed(loops))
- return emitDefaultDefiniteFailure(target);
- rewriter.eraseOp(target);
+/// Lowers every payload op associated with `target` to scf.for loops through
+/// its TilingInterface implementation, erases the original op, and maps the
+/// result handle to all generated loops, concatenated in payload-op order.
+/// Emits a silenceable failure for payload ops that do not implement
+/// TilingInterface and a definite failure if the lowering itself fails.
+DiagnosedSilenceableFailure
+transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ SmallVector<Operation *> loops;
+ for (Operation *target : state.getPayloadOps(getTarget())) {
+ auto tilingOp = dyn_cast<TilingInterface>(*target);
+ // Check the cast result so unsupported payloads never reach the lowering.
+ if (!tilingOp) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError()
+ << "expected the payload to implement TilingInterface";
+ diag.attachNote(target->getLoc()) << "payload op";
+ return diag;
+ }
+ rewriter.setInsertionPoint(target);
+ FailureOr<SmallVector<scf::ForOp>> generatedLoops =
+ scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
+ if (failed(generatedLoops))
+ return emitDefaultDefiniteFailure(target);
+ for (scf::ForOp &loop : *generatedLoops) {
+ loops.push_back(loop.getOperation());
+ }
+ rewriter.eraseOp(target);
+ }
+ results.set(cast<OpResult>(getResult()), loops);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 1b2c553b25ded0..8cbee3cbb758b2 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -11,7 +11,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %matmul : !transform.any_op
+ %0 = transform.structured.convert_to_loops %matmul
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -37,6 +38,57 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+ %arg2 : memref<?x?xf32>, %arg3 : memref<?xf32>, %arg4 : memref<?xf32>) {
+ linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%arg2 : memref<?x?xf32>)
+ linalg.matvec ins(%arg0, %arg3 : memref<?x?xf32>, memref<?xf32>)
+ outs(%arg4 : memref<?xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %linalg_ops = transform.structured.match interface{TilingInterface} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %linalg_ops
+ : (!transform.any_op) -> (!transform.any_op)
+ %1:5 = transform.split_handle %0
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func @gemm
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
+// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
+// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
+// CHECK-DAG: %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV2]]]
+// CHECK-DAG: %[[RHS:.+]] = memref.load %[[ARG1]][%[[IV2]], %[[IV1]]]
+// CHECK-DAG: %[[OUT:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
+// CHECK: %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
+// CHECK: memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
+// CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
+// CHECK-DAG: %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV3]], %[[IV4]]]
+// CHECK-DAG: %[[RHS:.+]] = memref.load %[[ARG3]][%[[IV4]]]
+// CHECK-DAG: %[[OUT:.+]] = memref.load %[[ARG4]][%[[IV3]]]
+// CHECK: %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
+// CHECK: %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
+// CHECK: memref.store %[[ADDF]], %[[ARG4]][%[[IV3]]]
+
+// -----
+
func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>,
%arg2 : memref<200xi8>, %arg3 : memref<300x200xi64>) {
linalg.generic {
@@ -66,7 +118,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %generic : !transform.any_op
+ %0 = transform.structured.convert_to_loops %generic
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -111,7 +164,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %conv : !transform.any_op
+ %0 = transform.structured.convert_to_loops %conv
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -165,7 +219,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %pool : !transform.any_op
+ %0 = transform.structured.convert_to_loops %pool
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -216,7 +271,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%map = transform.structured.match ops{["linalg.map"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %map : !transform.any_op
+ %0 = transform.structured.convert_to_loops %map
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -248,7 +304,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %transpose : !transform.any_op
+ %0 = transform.structured.convert_to_loops %transpose
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -285,7 +342,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %reduce : !transform.any_op
+ %0 = transform.structured.convert_to_loops %reduce
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -322,7 +380,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %broadcast : !transform.any_op
+ %0 = transform.structured.convert_to_loops %broadcast
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
More information about the llvm-branch-commits
mailing list