[Mlir-commits] [mlir] [MLIR] [Transforms] Let lowerToLoopsUsingSCFForOp delete target op, fixes #83252 (PR #83256)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 28 04:56:54 PST 2024
https://github.com/lhunloh updated https://github.com/llvm/llvm-project/pull/83256
>From 2f1ef500a3295a3763174b6f1b4718a5a16a38f6 Mon Sep 17 00:00:00 2001
From: Lars <larshunloh at uni-muenster.de>
Date: Wed, 28 Feb 2024 13:15:14 +0100
Subject: [PATCH 1/2] Let lowerToLoopsUsingSCFForOp delete target op.
The function mlir::scf::lowerToLoopsUsingSCFForOp did not erase its
(structured) target op. The resulting IR therefore contained the expected
loop nest followed by the still-present (structured) op, e.g. a linalg.matmul.
---
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69df..e1e9be858b251e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1133,5 +1133,6 @@ mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
return failure();
}
+ rewriter.eraseOp(op);
return loops;
}
>From 768c3dd6aa457e88aa319b2aad91d28575a66789 Mon Sep 17 00:00:00 2001
From: Lars <larshunloh at uni-muenster.de>
Date: Wed, 28 Feb 2024 13:54:20 +0100
Subject: [PATCH 2/2] Let transform.structured.convert_to_loops return handles
to loopnest
This lets transform.structured.convert_to_loops return handles to the generated loops (much like SCF's loop.forall_to_for). That makes the transformation more useful for nesting, because further transformations can be applied to the generated loops.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 6 ++---
.../TransformOps/LinalgTransformOps.cpp | 3 +++
.../lower-to-loops-using-interface.mlir | 22 ++++++++++++-------
3 files changed, 20 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 309573a562872f..3f5f7196bec565 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1281,14 +1281,14 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
let description = [{
For operations that implement the `TilingInterface`, and implement
the `generateScalarImplementation` method, lowers the operation to
- loops. This operation does not return any handles.
+ loops. The return handles point to the generated loops.
}];
let arguments = (ins TransformHandleTypeInterface:$target);
- let results = (outs);
+ let results = (outs Variadic<TransformHandleTypeInterface>:$result);
let assemblyFormat = [{
- $target attr-dict `:` type($target)
+ $target attr-dict `:` functional-type(operands, results)
}];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 299965bcfc3ab3..2a662bbacbd2c6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2121,6 +2121,9 @@ DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
scf::lowerToLoopsUsingSCFForOp(rewriter, target);
if (failed(loops))
return emitDefaultDefiniteFailure(target);
+ for(auto &loop: *loops) {
+ results.push_back(loop);
+ }
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 7969de0d456bb6..274da65d4f1a82 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -11,7 +11,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %matmul : !transform.any_op
+ %0:3 = transform.structured.convert_to_loops %matmul
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
@@ -65,7 +66,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %generic : !transform.any_op
+ %0:2 = transform.structured.convert_to_loops %generic : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
@@ -110,7 +111,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %conv : !transform.any_op
+ %0:7 = transform.structured.convert_to_loops %conv
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
@@ -164,7 +166,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %pool : !transform.any_op
+ %0:6 = transform.structured.convert_to_loops %pool
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
@@ -215,7 +218,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%map = transform.structured.match ops{["linalg.map"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %map : !transform.any_op
+ %0 = transform.structured.convert_to_loops %map : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
@@ -247,7 +250,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %transpose : !transform.any_op
+ %0:3 = transform.structured.convert_to_loops %transpose
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
@@ -284,7 +288,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %reduce : !transform.any_op
+ %0:3 = transform.structured.convert_to_loops %reduce
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
@@ -321,7 +326,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %broadcast : !transform.any_op
+ %0:3 = transform.structured.convert_to_loops %broadcast
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
More information about the Mlir-commits
mailing list