[Mlir-commits] [mlir] [MLIR] Add allow Insert/extract slice option to pack/unpack op (PR #117340)
Zhuoran Yin
llvmlistbot at llvm.org
Mon Dec 9 14:06:15 PST 2024
https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/117340
>From 012f6d46ff6c187470d6ca102be513e7a5a78a21 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 22 Nov 2024 15:52:12 +0000
Subject: [PATCH 1/7] [NFC] Add allowInsertSliceLowering to packOp and
allowExtractSliceLowering to UnPackOp
---
.../Linalg/TransformOps/LinalgTransformOps.td | 6 ++++--
.../mlir/Dialect/Linalg/Transforms/Transforms.h | 8 +++++---
.../Linalg/TransformOps/LinalgTransformOps.cpp | 8 ++++++--
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 12 +++++++-----
4 files changed, 22 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index e3084530bd11b5..ea96da77b6c331 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -548,7 +548,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
Return handles to the newly produced pad, expand_shape and transpose ops.
}];
- let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
+ let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target,
+ DefaultValuedAttr<BoolAttr, "true">:$allowInsertSliceLowering);
let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
@@ -588,7 +589,8 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
Return handles to the newly produced empty, transpose, collapse_shape and extract_slice ops.
}];
- let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target);
+ let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target,
+ DefaultValuedAttr<BoolAttr, "true">:$allowExtractSliceLowering);
let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 51967f83fee377..fd27e7929764d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1121,7 +1121,8 @@ struct LowerPackResult {
/// Rewrite pack as pad + reshape + transpose.
FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
- tensor::PackOp packOp);
+ tensor::PackOp packOp,
+ bool allowInsertSliceLowering = true);
struct LowerUnPackOpResult {
tensor::EmptyOp emptyOp;
@@ -1131,8 +1132,9 @@ struct LowerUnPackOpResult {
};
/// Rewrite pack as empty + transpose + reshape + extract_slice.
-FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
- tensor::UnPackOp unPackOp);
+FailureOr<LowerUnPackOpResult>
+lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
+ bool allowExtractSliceLowering = true);
/// Struct to hold the result of a `pack` call.
struct PackResult {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ada80deacfdbfe..5117a5c58c381d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1171,7 +1171,9 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
transform::ApplyToEachResultList &transformResults,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
- FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
+ bool allowInsertSliceLowering = getAllowInsertSliceLowering();
+ FailureOr<LowerPackResult> res =
+ lowerPack(rewriter, target, allowInsertSliceLowering);
if (failed(res)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "cannot lower to pad + expand + transpose";
@@ -1191,7 +1193,9 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
transform::ApplyToEachResultList &transformResults,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
- FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
+ bool allowExtractSliceLowering = getAllowExtractSliceLowering();
+ FailureOr<LowerUnPackOpResult> res =
+ lowerUnPack(rewriter, target, allowExtractSliceLowering);
if (failed(res)) {
DiagnosedSilenceableFailure diag =
emitSilenceableError()
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index d92543d7264625..0717dad4c2852f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -217,7 +217,8 @@ struct PackedOperandsDimList {
} // namespace
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
- tensor::PackOp packOp) {
+ tensor::PackOp packOp,
+ bool allowInsertSliceLowering) {
// 1. Filter out NYI cases.
auto packedTensorType =
cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -295,7 +296,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
- if (packOp.isLikePad()) {
+ if (allowInsertSliceLowering && packOp.isLikePad()) {
// Pack ops which operate as simple pads may not produce legal
// tensor.insert_slice operations when the packed type does not rank reduce
// to the padded type.
@@ -351,8 +352,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
return LowerPackResult{padOp, reshapeOp, transposeOp};
}
-FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
- tensor::UnPackOp unPackOp) {
+FailureOr<LowerUnPackOpResult>
+linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
+ bool allowExtractSliceLowering) {
Location loc = unPackOp->getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
@@ -362,7 +364,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
- if (unPackOp.isLikeUnPad()) {
+ if (allowExtractSliceLowering && unPackOp.isLikeUnPad()) {
// This unpack is just a plain unpad.
// Just extract the slice from the higher ranked tensor.
ArrayRef<int64_t> destShape = destTensorType.getShape();
>From 46b72028918f13f8faf7ee474d6da14f15a246ef Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 22 Nov 2024 17:51:40 +0000
Subject: [PATCH 2/7] This commit adds test cases for allowInsertSliceLowering
 and allowExtractSliceLowering
---
.../Dialect/Linalg/transform-lower-pack.mlir | 56 +++++++++++++++++++
1 file changed, 56 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 7aadf190695630..2e6a5ea97aaa33 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -96,6 +96,34 @@ module attributes {transform.with_named_sequence} {
// -----
+// This is same as pack_as_pad but since we explicitly added {allowInsertSliceLowering = false}, it should not
+// be lowered to insert_slice.
+// CHECK-LABEL: func.func @pack_disallowed_as_pad(
+// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
+func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+ %cst_0 = arith.constant 0.0 : f32
+ // tensor.pack is lowered to tensor.pad + tensor.expand_shape + tensor.insert_slice
+ // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
+ // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+ // CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
+ %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+ : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+ return %pack : tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack {allowInsertSliceLowering = false}: (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+ transform.yield
+ }
+}
+
+// -----
+
// Check that we don't lower the following pack as a pad.
// Although all the outer most dimensions in the resulting shape are 1s,
// some of the original dimensions are not part of the inner_dims_pos, hence
@@ -233,6 +261,34 @@ module attributes {transform.with_named_sequence} {
// -----
+// This is same as upack_as_pad but since we explicitly added {allowExtractSlicelowering = false}, it should not
+// be lowered to extract_slice.
+// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
+func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+ %cst_0 = arith.constant 0.0 : f32
+
+ // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+ // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+ %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+ : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
+ return %pack : tensor<129x47x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"tensor.unpack">
+ transform.structured.lower_unpack %unpack {allowExtractSliceLowering = false}: (!transform.op<"tensor.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">)
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func.func @pack_with_outer_dims_perm(
func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
%dest: tensor<200x4x16x100x16x32xi32>)
>From 0fa54017fd955f7637f9c8289896b4691518537f Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 25 Nov 2024 15:57:32 +0000
Subject: [PATCH 3/7] Address review requests
- Renamed allowInsertSliceLowering to lowerPadLikeWithInsertSlice
- Renamed allowExtractSliceLowering to lowerUnpadLikeWithExtractSlice
- Removed the redundant unit test since this is an NFC change
This reverts commit 46b72028918f13f8faf7ee474d6da14f15a246ef.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 4 +-
.../Dialect/Linalg/Transforms/Transforms.h | 4 +-
.../TransformOps/LinalgTransformOps.cpp | 8 +--
.../Dialect/Linalg/Transforms/Transforms.cpp | 8 +--
.../Dialect/Linalg/transform-lower-pack.mlir | 56 -------------------
5 files changed, 12 insertions(+), 68 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index ea96da77b6c331..675a766ec98b3c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -549,7 +549,7 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
}];
let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target,
- DefaultValuedAttr<BoolAttr, "true">:$allowInsertSliceLowering);
+ DefaultValuedAttr<BoolAttr, "true">:$lowerPadLikeWithInsertSlice);
let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
@@ -590,7 +590,7 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
}];
let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target,
- DefaultValuedAttr<BoolAttr, "true">:$allowExtractSliceLowering);
+ DefaultValuedAttr<BoolAttr, "true">:$lowerUnpadLikeWithExtractSlice);
let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fd27e7929764d3..82558de0fbfe67 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1122,7 +1122,7 @@ struct LowerPackResult {
/// Rewrite pack as pad + reshape + transpose.
FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
tensor::PackOp packOp,
- bool allowInsertSliceLowering = true);
+ bool lowerPadLikeWithInsertSlice = true);
struct LowerUnPackOpResult {
tensor::EmptyOp emptyOp;
@@ -1134,7 +1134,7 @@ struct LowerUnPackOpResult {
/// Rewrite pack as empty + transpose + reshape + extract_slice.
FailureOr<LowerUnPackOpResult>
lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
- bool allowExtractSliceLowering = true);
+ bool lowerUnpadLikeWithExtractSlice = true);
/// Struct to hold the result of a `pack` call.
struct PackResult {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5117a5c58c381d..06f58d4943394f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1171,9 +1171,9 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
transform::ApplyToEachResultList &transformResults,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
- bool allowInsertSliceLowering = getAllowInsertSliceLowering();
+ bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
FailureOr<LowerPackResult> res =
- lowerPack(rewriter, target, allowInsertSliceLowering);
+ lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
if (failed(res)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "cannot lower to pad + expand + transpose";
@@ -1193,9 +1193,9 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
transform::ApplyToEachResultList &transformResults,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
- bool allowExtractSliceLowering = getAllowExtractSliceLowering();
+ bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
FailureOr<LowerUnPackOpResult> res =
- lowerUnPack(rewriter, target, allowExtractSliceLowering);
+ lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
if (failed(res)) {
DiagnosedSilenceableFailure diag =
emitSilenceableError()
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0717dad4c2852f..f597faa16cf60f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -218,7 +218,7 @@ struct PackedOperandsDimList {
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::PackOp packOp,
- bool allowInsertSliceLowering) {
+ bool lowerPadLikeWithInsertSlice) {
// 1. Filter out NYI cases.
auto packedTensorType =
cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -296,7 +296,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
- if (allowInsertSliceLowering && packOp.isLikePad()) {
+ if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
// Pack ops which operate as simple pads may not produce legal
// tensor.insert_slice operations when the packed type does not rank reduce
// to the padded type.
@@ -354,7 +354,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
- bool allowExtractSliceLowering) {
+ bool lowerUnpadLikeWithExtractSlice) {
Location loc = unPackOp->getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
@@ -364,7 +364,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
- if (allowExtractSliceLowering && unPackOp.isLikeUnPad()) {
+ if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
// This unpack is just a plain unpad.
// Just extract the slice from the higher ranked tensor.
ArrayRef<int64_t> destShape = destTensorType.getShape();
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 2e6a5ea97aaa33..7aadf190695630 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -96,34 +96,6 @@ module attributes {transform.with_named_sequence} {
// -----
-// This is same as pack_as_pad but since we explicitly added {allowInsertSliceLowering = false}, it should not
-// be lowered to insert_slice.
-// CHECK-LABEL: func.func @pack_disallowed_as_pad(
-// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
-// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
-func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
- %cst_0 = arith.constant 0.0 : f32
- // tensor.pack is lowered to tensor.pad + tensor.expand_shape + tensor.insert_slice
- // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
- // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
- // CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
- %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
- : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
- return %pack : tensor<1x1x1x1x136x64x16x16xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
- %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
- : (!transform.any_op) -> !transform.op<"tensor.pack">
- transform.structured.lower_pack %pack {allowInsertSliceLowering = false}: (!transform.op<"tensor.pack">)
- -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
- transform.yield
- }
-}
-
-// -----
-
// Check that we don't lower the following pack as a pad.
// Although all the outer most dimensions in the resulting shape are 1s,
// some of the original dimensions are not part of the inner_dims_pos, hence
@@ -261,34 +233,6 @@ module attributes {transform.with_named_sequence} {
// -----
-// This is same as upack_as_pad but since we explicitly added {allowExtractSlicelowering = false}, it should not
-// be lowered to extract_slice.
-// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
-func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
- %cst_0 = arith.constant 0.0 : f32
-
- // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
- // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
- %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
- : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
- return %pack : tensor<129x47x16x16xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
- %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
- : (!transform.any_op) -> !transform.op<"tensor.unpack">
- transform.structured.lower_unpack %unpack {allowExtractSliceLowering = false}: (!transform.op<"tensor.unpack">)
- -> (!transform.op<"tensor.empty">,
- !transform.op<"linalg.transpose">,
- !transform.op<"tensor.collapse_shape">,
- !transform.op<"tensor.extract_slice">)
- transform.yield
- }
-}
-
-// -----
-
// CHECK-LABEL: func.func @pack_with_outer_dims_perm(
func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
%dest: tensor<200x4x16x100x16x32xi32>)
>From 671829ee1c89b1d5ff82c866f7074a969414698f Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 2 Dec 2024 21:44:20 +0000
Subject: [PATCH 4/7] Add test cases for lowerPadLikeWithInsertSlice and
 lowerUnpadLikeWithExtractSlice
---
.../Dialect/Linalg/transform-lower-pack.mlir | 60 +++++++++++++++++++
1 file changed, 60 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 7aadf190695630..2f7f2ff5211bf7 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -96,6 +96,34 @@ module attributes {transform.with_named_sequence} {
// -----
+// This is same as pack_as_pad but since we explicitly added {lowerPadLikeWithInsertSlice = false}, it should not
+// be lowered to insert_slice.
+// CHECK-LABEL: func.func @pack_disallowed_as_pad(
+func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+ %cst_0 = arith.constant 0.0 : f32
+ // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
+ // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32>
+ // CHECK: %[[PAD:.*]] = tensor.pad %[[ARG0]]
+ // CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]]
+ // CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]]
+ // CHECK: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
+ %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+ : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+ return %pack : tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}: (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+ transform.yield
+ }
+}
+
+// -----
+
// Check that we don't lower the following pack as a pad.
// Although all the outer most dimensions in the resulting shape are 1s,
// some of the original dimensions are not part of the inner_dims_pos, hence
@@ -233,6 +261,38 @@ module attributes {transform.with_named_sequence} {
// -----
+// This is same as upack_as_pad but since we explicitly added {lowerUnpadLikeWithExtractSlice = false}, it should not
+// be lowered to extract_slice.
+// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
+func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+ %cst_0 = arith.constant 0.0 : f32
+
+ // tensor.unpack is lowered to tensor.extract_slice + linalg.transpose + tensor.collapse_shape
+ // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+ // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
+ // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
+ // CHECK: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
+ %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+ : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
+ return %pack : tensor<129x47x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"tensor.unpack">
+ transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}: (!transform.op<"tensor.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">)
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func.func @pack_with_outer_dims_perm(
func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
%dest: tensor<200x4x16x100x16x32xi32>)
>From 3ad8cd64f687afa5cadc57e40aa228d6a587ed03 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Thu, 5 Dec 2024 15:42:03 +0000
Subject: [PATCH 5/7] Add tests to verify pack-as-producer and unpack-as-consumer fusion
---
.../transform-tile-and-fuse-pack-unpack.mlir | 121 ++++++++++++++++++
1 file changed, 121 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
new file mode 100644
index 00000000000000..31c28a852eef23
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -0,0 +1,121 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file -canonicalize | FileCheck %s
+
+// For pack op, we use lowerPadLikeWithInsertSlice = false to ensure no insert_slice is generated.
+// This allows linalg.transpose to be fused as a producer operation. Alternatively, without this attribute
+// insert_slice will be generated and fusion blocked.
+
+module {
+ // CHECK-label: func @fuse_pack_as_producer
+ // CHECK: scf.forall {{.*}} {
+ // CHECK: linalg.transpose
+ // CHECK: linalg.generic
+ // CHECK: scf.forall.in_parallel
+ // CHECK: }
+ func.func @fuse_pack_as_producer(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
+ -> tensor<4x4x128x256xf32> {
+ %dest = tensor.empty() : tensor<1x1x128x256xf32>
+ %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+ into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32>
+
+ %out = tensor.empty() : tensor<4x4x128x256xf32>
+ %res = linalg.generic
+ {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>,
+ affine_map<(i, j, k, l) -> (i, j, k, l)>,
+ affine_map<(i, j, k, l) -> (i, j, k, l)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>)
+ outs(%out: tensor<4x4x128x256xf32>) {
+ ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32):
+ %r = arith.addf %pack_elem, %other_elem : f32
+ linalg.yield %r : f32
+ } -> tensor<4x4x128x256xf32>
+
+ return %res : tensor<4x4x128x256xf32>
+ }
+
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ // Find and lower pack operation.
+ %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.op<"tensor.pack">
+ %paded, %expanded, %transpose = transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}
+ : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">,
+ !transform.op<"tensor.expand_shape">,
+ !transform.op<"linalg.transpose">)
+
+ %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ // Tile the lialg operation with parallel forall loop tiling [4, 4].
+ %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ // Fuse the transpose operation into the tiled loop.
+ transform.structured.fuse_into_containing_op %transpose into %forall_op
+ : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+}
+
+// -----
+// For unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated.
+// This allows linalg.transpose to be fused as a consumer operation. Alternatively, without this attribute
+// extract_slice will be generated and fusion blocked.
+
+module {
+ // CHECK-label: func @fuse_unpack_as_consumer
+ // CHECK: scf.forall {{.*}} {
+ // CHECK: linalg.generic
+ // CHECK: linalg.transpose
+ // CHECK: scf.forall.in_parallel
+ // CHECK: }
+ func.func @fuse_unpack_as_consumer(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
+ -> tensor<128x256xf32> {
+ %out = tensor.empty() : tensor<1x1x128x256xf32>
+ %res = linalg.generic
+ {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
+ affine_map<(i, j, k, l) -> (i, j, k, l)>,
+ affine_map<(i, j, k, l) -> (0, 0, k, l)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>)
+ outs(%out: tensor<1x1x128x256xf32>) {
+ ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32):
+ %r = arith.addf %unpack_elem, %other_elem : f32
+ linalg.yield %r : f32
+ } -> tensor<1x1x128x256xf32>
+
+ %dest = tensor.empty() : tensor<128x256xf32>
+ %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+ into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32>
+
+ return %unpack : tensor<128x256xf32>
+ }
+
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ // Find and lower unpack operation.
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1
+ : (!transform.any_op) -> !transform.op<"tensor.unpack">
+ transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}
+ : (!transform.op<"tensor.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">)
+
+ %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ // Tile the lialg operation with parallel forall loop tiling [4, 4].
+ %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ // Fuse the consumer operation into the tiled loop.
+ %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
+ : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
+ transform.test.fuse_consumer %slice_op
+ : (!transform.op<"tensor.parallel_insert_slice">) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+}
>From 68b53288b08777ff1e895a7d0f54cdd97373bc25 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 9 Dec 2024 16:12:35 +0000
Subject: [PATCH 6/7] Add negative test cases
- Added test cases demonstrating that insert_slice/extract_slice
block producer/consumer fusion
- Readability enhancements
---
.../Dialect/Linalg/transform-lower-pack.mlir | 34 ++---
.../transform-tile-and-fuse-pack-unpack.mlir | 117 ++++++++++++++++++
2 files changed, 134 insertions(+), 17 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 2f7f2ff5211bf7..5f8ff36a165786 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -98,15 +98,15 @@ module attributes {transform.with_named_sequence} {
// This is same as pack_as_pad but since we explicitly added {lowerPadLikeWithInsertSlice = false}, it should not
// be lowered to insert_slice.
-// CHECK-LABEL: func.func @pack_disallowed_as_pad(
-func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+// CHECK-LABEL: func.func @pack_as_pad_disabled_insert_slice(
+func.func @pack_as_pad_disabled_insert_slice(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
%cst_0 = arith.constant 0.0 : f32
// tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
// CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32>
- // CHECK: %[[PAD:.*]] = tensor.pad %[[ARG0]]
+ // CHECK-DAG: %[[PAD:.*]] = tensor.pad %[[ARG0]]
// CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]]
// CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]]
- // CHECK: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
+ // CHECK-DAG: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
%pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
: tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
return %pack : tensor<1x1x1x1x136x64x16x16xf32>
@@ -261,18 +261,18 @@ module attributes {transform.with_named_sequence} {
// -----
-// This is same as upack_as_pad but since we explicitly added {lowerUnpadLikeWithExtractSlice = false}, it should not
+// This is the same as unpack_as_pad but since we explicitly added {lowerUnpadLikeWithExtractSlice = false}, it should not
// be lowered to extract_slice.
-// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
-func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+// CHECK-LABEL: func.func @unpack_as_pad_disabled_extract_slice(
+func.func @unpack_as_pad_disabled_extract_slice(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
%cst_0 = arith.constant 0.0 : f32
// tensor.unpack is lowered to tensor.extract_slice + linalg.transpose + tensor.collapse_shape
- // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
- // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
- // CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
- // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
- // CHECK: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
+ // CHECK-DAG: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+ // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
+ // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
+ // CHECK-DAG: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
%pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
: tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
return %pack : tensor<129x47x16x16xf32>
@@ -632,7 +632,7 @@ func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
- : (!transform.any_op) -> !transform.op<"tensor.unpack">
+ : (!transform.any_op) -> !transform.op<"tensor.unpack">
transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
@@ -687,9 +687,9 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: @unpack_with_outer_dims_perm
// CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32>
-// CHECK: %[[TRAN:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
+// CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
// CHECK-SAME: permutation = [1, 3, 0, 2]
// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
// CHECK-SAME: : tensor<4x8x2x32xf32> into tensor<32x64xf32>
@@ -698,7 +698,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: linalg.copy ins(%[[SLICE]]
// CHECK-SAME: : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32>
func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
- %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
+ %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32>
return %unpack : tensor<32x64xf32>
}
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
index 31c28a852eef23..ffed9ab6e06535 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -58,6 +58,62 @@ module {
}
}
+// -----
+// For pack op, by default lowerPadLikeWithInsertSlice = true, which generates insert_slice and blocks fusion.
+
+module {
+ // CHECK-label: func @fuse_pack_as_producer_blocked_by_insert_slice
+ // CHECK: tensor.insert_slice
+ // CHECK: scf.forall {{.*}} {
+ // CHECK: scf.forall.in_parallel
+ // CHECK: }
+ func.func @fuse_pack_as_producer_blocked_by_insert_slice(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
+ -> tensor<4x4x128x256xf32> {
+ %dest = tensor.empty() : tensor<1x1x128x256xf32>
+ %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+ into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32>
+
+ %out = tensor.empty() : tensor<4x4x128x256xf32>
+ %res = linalg.generic
+ {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>,
+ affine_map<(i, j, k, l) -> (i, j, k, l)>,
+ affine_map<(i, j, k, l) -> (i, j, k, l)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>)
+ outs(%out: tensor<4x4x128x256xf32>) {
+ ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32):
+ %r = arith.addf %pack_elem, %other_elem : f32
+ linalg.yield %r : f32
+ } -> tensor<4x4x128x256xf32>
+
+ return %res : tensor<4x4x128x256xf32>
+ }
+
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ // Find and lower pack operation.
+ %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.op<"tensor.pack">
+ %padded, %expanded, %transpose = transform.structured.lower_pack %pack
+ : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">,
+ !transform.op<"tensor.expand_shape">,
+ !transform.op<"linalg.transpose">)
+
+ %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ // Tile the linalg operation with parallel forall loop tiling [4, 4].
+ %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ // Fuse the transpose operation into the tiled loop.
+ transform.structured.fuse_into_containing_op %transpose into %forall_op
+ : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+}
+
// -----
// For unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated.
// This allows linalg.transpose to be fused as a consumer operation. Alternatively, without this attribute
@@ -119,3 +175,64 @@ module {
}
}
}
+
+// -----
+// For unpack op, by default lowerUnpadLikeWithExtractSlice = true, which generates extract_slice and blocks fusion.
+
+module {
+ // CHECK-label: func @fuse_unpack_as_consumer_blocked_by_extract_slice
+ // CHECK: scf.forall {{.*}} {
+ // CHECK: linalg.generic
+ // CHECK: scf.forall.in_parallel
+ // CHECK: }
+ // CHECK: tensor.extract_slice
+ func.func @fuse_unpack_as_consumer_blocked_by_extract_slice(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
+ -> tensor<128x256xf32> {
+ %out = tensor.empty() : tensor<1x1x128x256xf32>
+ %res = linalg.generic
+ {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
+ affine_map<(i, j, k, l) -> (i, j, k, l)>,
+ affine_map<(i, j, k, l) -> (0, 0, k, l)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>)
+ outs(%out: tensor<1x1x128x256xf32>) {
+ ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32):
+ %r = arith.addf %unpack_elem, %other_elem : f32
+ linalg.yield %r : f32
+ } -> tensor<1x1x128x256xf32>
+
+ %dest = tensor.empty() : tensor<128x256xf32>
+ %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+ into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32>
+
+ return %unpack : tensor<128x256xf32>
+ }
+
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ // Find and lower unpack operation.
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1
+ : (!transform.any_op) -> !transform.op<"tensor.unpack">
+ transform.structured.lower_unpack %unpack
+ : (!transform.op<"tensor.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">)
+
+ %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ // Tile the linalg operation with parallel forall loop tiling [4, 4].
+ %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ // Locate the parallel_insert_slice that consumer fusion would start from (no fusion is applied).
+ %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
+ : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
+ // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
+ // is not a qualified consumer operation. Forcing this will yield a "could not fetch consumer
+ // to fuse" error.
+ transform.yield
+ }
+ }
+}
>From f6c54e82d01aa6ec026cc8eeae6dc5da3aa8f7d4 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 9 Dec 2024 22:00:44 +0000
Subject: [PATCH 7/7] Add additional LIT variable
This helps to clearly demonstrate the producer fusion in the pack case and
the consumer fusion in the unpack case.
---
.../transform-tile-and-fuse-pack-unpack.mlir | 30 ++++++++++---------
1 file changed, 16 insertions(+), 14 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
index ffed9ab6e06535..faf7ff9ad7ed09 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -1,14 +1,14 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file -canonicalize | FileCheck %s
// For pack op, we use lowerPadLikeWithInsertSlice = false to ensure no insert_slice is generated.
-// This allows linalg.transpose to be fused as a producer operation. Alternatively, without this attribute
-// insert_slice will be generated and fusion blocked.
+// This allows linalg.transpose to be fused as a producer operation. In the testcase below, linalg.transpose
+// is fused as a producer operation into the scf.forall loop.
module {
// CHECK-label: func @fuse_pack_as_producer
// CHECK: scf.forall {{.*}} {
- // CHECK: linalg.transpose
- // CHECK: linalg.generic
+ // CHECK: %[[PRODUCER:.*]] = linalg.transpose
+ // CHECK: linalg.generic {{.*}} ins(%[[PRODUCER]]
// CHECK: scf.forall.in_parallel
// CHECK: }
func.func @fuse_pack_as_producer(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
@@ -60,11 +60,13 @@ module {
// -----
// For pack op, by default lowerPadLikeWithInsertSlice = true, which generates insert_slice and blocks fusion.
+// In the testcase below, tensor.insert_slice as a producer operation cannot be fused into the scf.forall loop.
module {
// CHECK-label: func @fuse_pack_as_producer_blocked_by_insert_slice
- // CHECK: tensor.insert_slice
+ // CHECK: %[[PRODUCER:.*]] = tensor.insert_slice
// CHECK: scf.forall {{.*}} {
+ // CHECK: linalg.generic {{.*}} ins(%[[PRODUCER]]
// CHECK: scf.forall.in_parallel
// CHECK: }
func.func @fuse_pack_as_producer_blocked_by_insert_slice(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
@@ -116,14 +118,13 @@ module {
// -----
// For unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated.
-// This allows linalg.transpose to be fused as a consumer operation. Alternatively, without this attribute
-// extract_slice will be generated and fusion blocked.
-
+// This allows linalg.transpose to be fused as a consumer operation. In the testcase below, linalg.transpose
+// is fused as a consumer operation into the scf.forall loop.
module {
// CHECK-label: func @fuse_unpack_as_consumer
// CHECK: scf.forall {{.*}} {
- // CHECK: linalg.generic
- // CHECK: linalg.transpose
+ // CHECK: %[[CONSUMER:.*]] = linalg.generic
+ // CHECK: linalg.transpose ins(%[[CONSUMER]]
// CHECK: scf.forall.in_parallel
// CHECK: }
func.func @fuse_unpack_as_consumer(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
@@ -178,14 +179,15 @@ module {
// -----
// For unpack op, by default lowerUnpadLikeWithExtractSlice = true, which generates extract_slice and blocks fusion.
-
+// In the testcase below, tensor.extract_slice as a consumer operation cannot be fused into the scf.forall loop.
module {
// CHECK-label: func @fuse_unpack_as_consumer_blocked_by_extract_slice
- // CHECK: scf.forall {{.*}} {
- // CHECK: linalg.generic
+ // CHECK: %[[CONSUMER:.*]] = scf.forall {{.*}} {
+ // CHECK: %[[ADDF:.*]] = linalg.generic
// CHECK: scf.forall.in_parallel
+ // CHECK: tensor.parallel_insert_slice %[[ADDF]]
// CHECK: }
- // CHECK: tensor.extract_slice
+ // CHECK: tensor.extract_slice %[[CONSUMER]]
func.func @fuse_unpack_as_consumer_blocked_by_extract_slice(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
-> tensor<128x256xf32> {
%out = tensor.empty() : tensor<1x1x128x256xf32>
More information about the Mlir-commits
mailing list