[Mlir-commits] [mlir] [MLIR] Add allow Insert/extract slice option to pack/unpack op (PR #117340)

Zhuoran Yin llvmlistbot at llvm.org
Mon Dec 9 14:02:00 PST 2024


https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/117340

>From 012f6d46ff6c187470d6ca102be513e7a5a78a21 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 22 Nov 2024 15:52:12 +0000
Subject: [PATCH 1/7] [NFC] Add allowInsertSliceLowering to packOp and
 allowExtractSliceLowering to UnPackOp

---
 .../Linalg/TransformOps/LinalgTransformOps.td        |  6 ++++--
 .../mlir/Dialect/Linalg/Transforms/Transforms.h      |  8 +++++---
 .../Linalg/TransformOps/LinalgTransformOps.cpp       |  8 ++++++--
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp    | 12 +++++++-----
 4 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index e3084530bd11b5..ea96da77b6c331 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -548,7 +548,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
     Return handles to the newly produced pad, expand_shape and transpose ops.
   }];
 
-  let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
+  let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target,
+                       DefaultValuedAttr<BoolAttr, "true">:$allowInsertSliceLowering);
   let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
                       Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
@@ -588,7 +589,8 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
     Return handles to the newly produced empty, transpose, collapse_shape and extract_slice ops.
   }];
 
-  let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target);
+  let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target,
+                       DefaultValuedAttr<BoolAttr, "true">:$allowExtractSliceLowering);
   let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
                       Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 51967f83fee377..fd27e7929764d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1121,7 +1121,8 @@ struct LowerPackResult {
 
 /// Rewrite pack as pad + reshape + transpose.
 FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
-                                     tensor::PackOp packOp);
+                                     tensor::PackOp packOp,
+                                     bool allowInsertSliceLowering = true);
 
 struct LowerUnPackOpResult {
   tensor::EmptyOp emptyOp;
@@ -1131,8 +1132,9 @@ struct LowerUnPackOpResult {
 };
 
 /// Rewrite pack as empty + transpose + reshape + extract_slice.
-FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
-                                           tensor::UnPackOp unPackOp);
+FailureOr<LowerUnPackOpResult>
+lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
+            bool allowExtractSliceLowering = true);
 
 /// Struct to hold the result of a `pack` call.
 struct PackResult {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ada80deacfdbfe..5117a5c58c381d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1171,7 +1171,9 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
+  bool allowInsertSliceLowering = getAllowInsertSliceLowering();
+  FailureOr<LowerPackResult> res =
+      lowerPack(rewriter, target, allowInsertSliceLowering);
   if (failed(res)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "cannot lower to pad + expand + transpose";
@@ -1191,7 +1193,9 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
+  bool allowExtractSliceLowering = getAllowExtractSliceLowering();
+  FailureOr<LowerUnPackOpResult> res =
+      lowerUnPack(rewriter, target, allowExtractSliceLowering);
   if (failed(res)) {
     DiagnosedSilenceableFailure diag =
         emitSilenceableError()
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index d92543d7264625..0717dad4c2852f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -217,7 +217,8 @@ struct PackedOperandsDimList {
 } // namespace
 
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
-                                             tensor::PackOp packOp) {
+                                             tensor::PackOp packOp,
+                                             bool allowInsertSliceLowering) {
   // 1. Filter out NYI cases.
   auto packedTensorType =
       cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -295,7 +296,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
       llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
       DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
 
-  if (packOp.isLikePad()) {
+  if (allowInsertSliceLowering && packOp.isLikePad()) {
     // Pack ops which operate as simple pads may not produce legal
     // tensor.insert_slice operations when the packed type does not rank reduce
     // to the padded type.
@@ -351,8 +352,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   return LowerPackResult{padOp, reshapeOp, transposeOp};
 }
 
-FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
-                                                   tensor::UnPackOp unPackOp) {
+FailureOr<LowerUnPackOpResult>
+linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
+                    bool allowExtractSliceLowering) {
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
@@ -362,7 +364,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
   auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
-  if (unPackOp.isLikeUnPad()) {
+  if (allowExtractSliceLowering && unPackOp.isLikeUnPad()) {
     // This unpack is just a plain unpad.
     // Just extract the slice from the higher ranked tensor.
     ArrayRef<int64_t> destShape = destTensorType.getShape();

>From 46b72028918f13f8faf7ee474d6da14f15a246ef Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 22 Nov 2024 17:51:40 +0000
Subject: [PATCH 2/7] This commit adds test cases for allowInsertSliceLowering
 and allowExtractSliceLowering

---
 .../Dialect/Linalg/transform-lower-pack.mlir  | 56 +++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 7aadf190695630..2e6a5ea97aaa33 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -96,6 +96,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is same as pack_as_pad but since we explicitly added {allowInsertSliceLowering = false}, it should not
+// be lowered to insert_slice.
+// CHECK-LABEL: func.func @pack_disallowed_as_pad(
+// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
+func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + tensor.insert_slice
+  //      CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
+  //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+  //  CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
+  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+  return %pack :  tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.pack">
+    transform.structured.lower_pack %pack {allowInsertSliceLowering = false}: (!transform.op<"tensor.pack">)
+      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+      transform.yield
+  }
+}
+
+// -----
+
 // Check that we don't lower the following pack as a pad.
 // Although all the outer most dimensions in the resulting shape are 1s,
 // some of the original dimensions are not part of the inner_dims_pos, hence
@@ -233,6 +261,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is same as upack_as_pad but since we explicitly added {allowExtractSlicelowering = false}, it should not 
+// be lowered to extract_slice.
+// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
+func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+  //  CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
+  return %pack : tensor<129x47x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack {allowExtractSliceLowering = false}: (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+          transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_with_outer_dims_perm(
 func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                      %dest: tensor<200x4x16x100x16x32xi32>)

>From 0fa54017fd955f7637f9c8289896b4691518537f Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 25 Nov 2024 15:57:32 +0000
Subject: [PATCH 3/7] Address review requests

- Renamed allowInsertSliceLowering to lowerPadLikeWithInsertSlice
- Renamed allowExtractSliceLowering to lowerUnpadLikeWithExtractSlice
- Removed the redundant unit test since this is an NFC change

This reverts commit 46b72028918f13f8faf7ee474d6da14f15a246ef.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  4 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |  4 +-
 .../TransformOps/LinalgTransformOps.cpp       |  8 +--
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  8 +--
 .../Dialect/Linalg/transform-lower-pack.mlir  | 56 -------------------
 5 files changed, 12 insertions(+), 68 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index ea96da77b6c331..675a766ec98b3c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -549,7 +549,7 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
   }];
 
   let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target,
-                       DefaultValuedAttr<BoolAttr, "true">:$allowInsertSliceLowering);
+                       DefaultValuedAttr<BoolAttr, "true">:$lowerPadLikeWithInsertSlice);
   let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
                       Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
@@ -590,7 +590,7 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
   }];
 
   let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target,
-                       DefaultValuedAttr<BoolAttr, "true">:$allowExtractSliceLowering);
+                       DefaultValuedAttr<BoolAttr, "true">:$lowerUnpadLikeWithExtractSlice);
   let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
                       Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fd27e7929764d3..82558de0fbfe67 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1122,7 +1122,7 @@ struct LowerPackResult {
 /// Rewrite pack as pad + reshape + transpose.
 FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
                                      tensor::PackOp packOp,
-                                     bool allowInsertSliceLowering = true);
+                                     bool lowerPadLikeWithInsertSlice = true);
 
 struct LowerUnPackOpResult {
   tensor::EmptyOp emptyOp;
@@ -1134,7 +1134,7 @@ struct LowerUnPackOpResult {
 /// Rewrite pack as empty + transpose + reshape + extract_slice.
 FailureOr<LowerUnPackOpResult>
 lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
-            bool allowExtractSliceLowering = true);
+            bool lowerUnpadLikeWithExtractSlice = true);
 
 /// Struct to hold the result of a `pack` call.
 struct PackResult {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5117a5c58c381d..06f58d4943394f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1171,9 +1171,9 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  bool allowInsertSliceLowering = getAllowInsertSliceLowering();
+  bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
   FailureOr<LowerPackResult> res =
-      lowerPack(rewriter, target, allowInsertSliceLowering);
+      lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
   if (failed(res)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "cannot lower to pad + expand + transpose";
@@ -1193,9 +1193,9 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  bool allowExtractSliceLowering = getAllowExtractSliceLowering();
+  bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
   FailureOr<LowerUnPackOpResult> res =
-      lowerUnPack(rewriter, target, allowExtractSliceLowering);
+      lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
   if (failed(res)) {
     DiagnosedSilenceableFailure diag =
         emitSilenceableError()
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0717dad4c2852f..f597faa16cf60f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -218,7 +218,7 @@ struct PackedOperandsDimList {
 
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                              tensor::PackOp packOp,
-                                             bool allowInsertSliceLowering) {
+                                             bool lowerPadLikeWithInsertSlice) {
   // 1. Filter out NYI cases.
   auto packedTensorType =
       cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -296,7 +296,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
       llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
       DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
 
-  if (allowInsertSliceLowering && packOp.isLikePad()) {
+  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
     // Pack ops which operate as simple pads may not produce legal
     // tensor.insert_slice operations when the packed type does not rank reduce
     // to the padded type.
@@ -354,7 +354,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
 
 FailureOr<LowerUnPackOpResult>
 linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
-                    bool allowExtractSliceLowering) {
+                    bool lowerUnpadLikeWithExtractSlice) {
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
@@ -364,7 +364,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
   auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
-  if (allowExtractSliceLowering && unPackOp.isLikeUnPad()) {
+  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
     // This unpack is just a plain unpad.
     // Just extract the slice from the higher ranked tensor.
     ArrayRef<int64_t> destShape = destTensorType.getShape();
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 2e6a5ea97aaa33..7aadf190695630 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -96,34 +96,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// This is same as pack_as_pad but since we explicitly added {allowInsertSliceLowering = false}, it should not
-// be lowered to insert_slice.
-// CHECK-LABEL: func.func @pack_disallowed_as_pad(
-// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
-// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
-func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
-  %cst_0 = arith.constant 0.0 : f32
-  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + tensor.insert_slice
-  //      CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
-  //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
-  //  CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
-  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
-    : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
-  return %pack :  tensor<1x1x1x1x136x64x16x16xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
-      : (!transform.any_op) -> !transform.op<"tensor.pack">
-    transform.structured.lower_pack %pack {allowInsertSliceLowering = false}: (!transform.op<"tensor.pack">)
-      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
-      transform.yield
-  }
-}
-
-// -----
-
 // Check that we don't lower the following pack as a pad.
 // Although all the outer most dimensions in the resulting shape are 1s,
 // some of the original dimensions are not part of the inner_dims_pos, hence
@@ -261,34 +233,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// This is same as upack_as_pad but since we explicitly added {allowExtractSlicelowering = false}, it should not 
-// be lowered to extract_slice.
-// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
-func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
-  %cst_0 = arith.constant 0.0 : f32
-
-  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
-  //  CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
-  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
-    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
-  return %pack : tensor<129x47x16x16xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
-      : (!transform.any_op) -> !transform.op<"tensor.unpack">
-    transform.structured.lower_unpack %unpack {allowExtractSliceLowering = false}: (!transform.op<"tensor.unpack">)
-      -> (!transform.op<"tensor.empty">,
-          !transform.op<"linalg.transpose">,
-          !transform.op<"tensor.collapse_shape">,
-          !transform.op<"tensor.extract_slice">)
-          transform.yield
-  }
-}
-
-// -----
-
 // CHECK-LABEL: func.func @pack_with_outer_dims_perm(
 func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                      %dest: tensor<200x4x16x100x16x32xi32>)

>From 671829ee1c89b1d5ff82c866f7074a969414698f Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 2 Dec 2024 21:44:20 +0000
Subject: [PATCH 4/7] Adding test cases to allowInsertSliceLowering and
 allowExtractSliceLowering

---
 .../Dialect/Linalg/transform-lower-pack.mlir  | 60 +++++++++++++++++++
 1 file changed, 60 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 7aadf190695630..2f7f2ff5211bf7 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -96,6 +96,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is same as pack_as_pad but since we explicitly added {lowerPadLikeWithInsertSlice = false}, it should not
+// be lowered to insert_slice.
+// CHECK-LABEL: func.func @pack_disallowed_as_pad(
+func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
+  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32>
+  //      CHECK: %[[PAD:.*]] = tensor.pad %[[ARG0]]
+  //  CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]]
+  //      CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]]
+  //      CHECK: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
+  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+  return %pack :  tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.pack">
+    transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}: (!transform.op<"tensor.pack">)
+      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+      transform.yield
+  }
+}
+
+// -----
+
 // Check that we don't lower the following pack as a pad.
 // Although all the outer most dimensions in the resulting shape are 1s,
 // some of the original dimensions are not part of the inner_dims_pos, hence
@@ -233,6 +261,38 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is same as upack_as_pad but since we explicitly added {lowerUnpadLikeWithExtractSlice = false}, it should not 
+// be lowered to extract_slice.
+// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
+func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // tensor.unpack is lowered to linalg.transpose + tensor.collapse_shape + tensor.extract_slice
+  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+  //  CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+  //      CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
+  //      CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
+  //      CHECK: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
+  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
+  return %pack : tensor<129x47x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}: (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+          transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_with_outer_dims_perm(
 func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                      %dest: tensor<200x4x16x100x16x32xi32>)

>From 3ad8cd64f687afa5cadc57e40aa228d6a587ed03 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Thu, 5 Dec 2024 15:42:03 +0000
Subject: [PATCH 5/7] Add test to verify pack/producer unpack/consumer fusion

---
 .../transform-tile-and-fuse-pack-unpack.mlir  | 121 ++++++++++++++++++
 1 file changed, 121 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir

diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
new file mode 100644
index 00000000000000..31c28a852eef23
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -0,0 +1,121 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file -canonicalize | FileCheck %s
+
+// For pack op, we use lowerPadLikeWithInsertSlice = false to ensure no insert_slice is generated.
+// This allows linalg.transpose to be fused as a producer operation. Alternatively, without this attribute
+// insert_slice will be generated and fusion blocked.
+
+module {
+  // CHECK-label: func @fuse_pack_as_producer
+  // CHECK:       scf.forall {{.*}} {
+  // CHECK:         linalg.transpose
+  // CHECK:         linalg.generic
+  // CHECK:         scf.forall.in_parallel
+  // CHECK:       }
+  func.func @fuse_pack_as_producer(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<4x4x128x256xf32> {
+    %dest = tensor.empty() : tensor<1x1x128x256xf32>
+    %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32>
+
+    %out = tensor.empty() : tensor<4x4x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<4x4x128x256xf32>) {
+      ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %pack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<4x4x128x256xf32>
+
+    return %res : tensor<4x4x128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower pack operation.
+      %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
+        : (!transform.any_op) -> !transform.op<"tensor.pack">
+      %paded, %expanded, %transpose = transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}
+        : (!transform.op<"tensor.pack">)
+        -> (!transform.op<"tensor.pad">,
+            !transform.op<"tensor.expand_shape">,
+            !transform.op<"linalg.transpose">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+          : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the transpose operation into the tiled loop.
+      transform.structured.fuse_into_containing_op %transpose into %forall_op
+          : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+// For unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated.
+// This allows linalg.transpose to be fused as a consumer operation. Alternatively, without this attribute
+// extract_slice will be generated and fusion blocked.
+
+module {
+  // CHECK-label: func @fuse_unpack_as_consumer
+  // CHECK:       scf.forall {{.*}} {
+  // CHECK:         linalg.generic
+  // CHECK:         linalg.transpose
+  // CHECK:         scf.forall.in_parallel
+  // CHECK:       }
+  func.func @fuse_unpack_as_consumer(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<128x256xf32> {
+    %out = tensor.empty() : tensor<1x1x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (0, 0, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<1x1x128x256xf32>) {
+      ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %unpack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<1x1x128x256xf32>
+
+    %dest = tensor.empty() : tensor<128x256xf32>
+    %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32>
+
+    return %unpack : tensor<128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower unpack operation.
+      %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1
+          : (!transform.any_op) -> !transform.op<"tensor.unpack">
+      transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}
+        : (!transform.op<"tensor.unpack">)
+        -> (!transform.op<"tensor.empty">,
+            !transform.op<"linalg.transpose">,
+            !transform.op<"tensor.collapse_shape">,
+            !transform.op<"tensor.extract_slice">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+          : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the consumer operation into the tiled loop.
+      %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
+          : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
+      transform.test.fuse_consumer %slice_op
+        : (!transform.op<"tensor.parallel_insert_slice">) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}

>From 68b53288b08777ff1e895a7d0f54cdd97373bc25 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 9 Dec 2024 16:12:35 +0000
Subject: [PATCH 6/7] Adding additional negative test cases

- Added additional test cases to demonstrate insert/extract slice will
  block producer/consumer fusion
- Readability enhancements
---
 .../Dialect/Linalg/transform-lower-pack.mlir  |  34 ++---
 .../transform-tile-and-fuse-pack-unpack.mlir  | 117 ++++++++++++++++++
 2 files changed, 134 insertions(+), 17 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 2f7f2ff5211bf7..5f8ff36a165786 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -98,15 +98,15 @@ module attributes {transform.with_named_sequence} {
 
 // This is same as pack_as_pad but since we explicitly added {lowerPadLikeWithInsertSlice = false}, it should not
 // be lowered to insert_slice.
-// CHECK-LABEL: func.func @pack_disallowed_as_pad(
-func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+// CHECK-LABEL: func.func @pack_as_pad_disabled_insert_slice(
+func.func @pack_as_pad_disabled_insert_slice(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
   // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
   // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32>
-  //      CHECK: %[[PAD:.*]] = tensor.pad %[[ARG0]]
+  //  CHECK-DAG: %[[PAD:.*]] = tensor.pad %[[ARG0]]
   //  CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]]
   //      CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]]
-  //      CHECK: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
+  //  CHECK-DAG: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
   %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
     : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
   return %pack :  tensor<1x1x1x1x136x64x16x16xf32>
@@ -261,18 +261,18 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// This is same as upack_as_pad but since we explicitly added {lowerUnpadLikeWithExtractSlice = false}, it should not 
+// This is same as unpack_as_pad but since we explicitly added {lowerUnpadLikeWithExtractSlice = false}, it should not
 // be lowered to extract_slice.
-// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
-func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+// CHECK-LABEL: func.func @unpack_as_pad_disabled_extract_slice(
+func.func @unpack_as_pad_disabled_extract_slice(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.unpack is lowered to tensor.extract_slice + linalg.transpose + tensor.collapse_shape
-  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
-  //  CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
-  //      CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
-  //      CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
-  //      CHECK: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
+  // CHECK-DAG: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+  // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+  //     CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
+  //     CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
+  // CHECK-DAG: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
   %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
     : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
   return %pack : tensor<129x47x16x16xf32>
@@ -632,7 +632,7 @@ func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
-      : (!transform.any_op) -> !transform.op<"tensor.unpack"> 
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
     transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
           -> (!transform.op<"tensor.empty">,
           !transform.op<"linalg.transpose">,
@@ -687,9 +687,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: @unpack_with_outer_dims_perm
 //  CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32>
 //       CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32>
-//       CHECK: %[[TRAN:.*]] = linalg.transpose 
-//  CHECK-SAME:   ins(%[[ARG1]] : tensor<2x4x32x8xf32>) 
-//  CHECK-SAME:   outs(%[[EMPTY]] : tensor<4x8x2x32xf32>) 
+//       CHECK: %[[TRAN:.*]] = linalg.transpose
+//  CHECK-SAME:   ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
+//  CHECK-SAME:   outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
 //  CHECK-SAME:   permutation = [1, 3, 0, 2]
 //       CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
 //  CHECK-SAME:   : tensor<4x8x2x32xf32> into tensor<32x64xf32>
@@ -698,7 +698,7 @@ module attributes {transform.with_named_sequence} {
 //       CHECK: linalg.copy ins(%[[SLICE]]
 //  CHECK-SAME:   : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32>
 func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
-  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0] 
+  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
     inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32>
   return %unpack : tensor<32x64xf32>
 }
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
index 31c28a852eef23..ffed9ab6e06535 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -58,6 +58,62 @@ module {
   }
 }
 
+// -----
+// For pack op, by default lowerPadLikeWithInsertSlice = true, which generates insert_slice and blocks fusion.
+
+module {
+  // CHECK-label: func @fuse_pack_as_producer_blocked_by_insert_slice
+  // CHECK:       tensor.insert_slice
+  // CHECK:       scf.forall {{.*}} {
+  // CHECK:         scf.forall.in_parallel
+  // CHECK:       }
+  func.func @fuse_pack_as_producer_blocked_by_insert_slice(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<4x4x128x256xf32> {
+    %dest = tensor.empty() : tensor<1x1x128x256xf32>
+    %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32>
+
+    %out = tensor.empty() : tensor<4x4x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<4x4x128x256xf32>) {
+      ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %pack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<4x4x128x256xf32>
+
+    return %res : tensor<4x4x128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower pack operation.
+      %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
+        : (!transform.any_op) -> !transform.op<"tensor.pack">
+      %paded, %expanded, %transpose = transform.structured.lower_pack %pack
+        : (!transform.op<"tensor.pack">)
+        -> (!transform.op<"tensor.pad">,
+            !transform.op<"tensor.expand_shape">,
+            !transform.op<"linalg.transpose">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+          : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the transpose operation into the tiled loop.
+      transform.structured.fuse_into_containing_op %transpose into %forall_op
+          : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
 // -----
 // For unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated.
 // This allows linalg.transpose to be fused as a consumer operation. Alternatively, without this attribute
@@ -119,3 +175,64 @@ module {
     }
   }
 }
+
+// -----
+// For unpack op, by default lowerUnpadLikeWithExtractSlice = true, which generates extract_slice and blocks fusion.
+
+module {
+  // CHECK-label: func @fuse_unpack_as_consumer_blocked_by_extract_slice
+  // CHECK:       scf.forall {{.*}} {
+  // CHECK:         linalg.generic
+  // CHECK:         scf.forall.in_parallel
+  // CHECK:       }
+  // CHECK:       tensor.extract_slice
+  func.func @fuse_unpack_as_consumer_blocked_by_extract_slice(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<128x256xf32> {
+    %out = tensor.empty() : tensor<1x1x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (0, 0, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<1x1x128x256xf32>) {
+      ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %unpack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<1x1x128x256xf32>
+
+    %dest = tensor.empty() : tensor<128x256xf32>
+    %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32>
+
+    return %unpack : tensor<128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower unpack operation.
+      %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1
+          : (!transform.any_op) -> !transform.op<"tensor.unpack">
+      transform.structured.lower_unpack %unpack
+        : (!transform.op<"tensor.unpack">)
+        -> (!transform.op<"tensor.empty">,
+            !transform.op<"linalg.transpose">,
+            !transform.op<"tensor.collapse_shape">,
+            !transform.op<"tensor.extract_slice">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+          : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the consumer operation into the tiled loop.
+      %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
+          : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
+      // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
+      // is not a qualified consumer operation. Forcing this will yield a "could not fetch consumer
+      // to fuse" error.
+      transform.yield
+    }
+  }
+}

>From fbd32d18ced9f7bd12ee4418e6c9bd71cf221158 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 9 Dec 2024 22:00:44 +0000
Subject: [PATCH 7/7] Add additional LIT variable

This helps to clearly demonstrate the producer fusion in the pack case and
the consumer fusion in the unpack case.
---
 .../transform-tile-and-fuse-pack-unpack.mlir  | 30 ++++++++++---------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
index ffed9ab6e06535..f4c9e613ad70b3 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -1,14 +1,14 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file -canonicalize | FileCheck %s
 
 // For pack op, we use lowerPadLikeWithInsertSlice = false to ensure no insert_slice is generated.
-// This allows linalg.transpose to be fused as a producer operation. Alternatively, without this attribute
-// insert_slice will be generated and fusion blocked.
+// This allows linalg.transpose to be fused as a producer operation. In below testcase, linalg.transpose
+// as a producer operation is fused into the scf.forall loop.
 
 module {
   // CHECK-label: func @fuse_pack_as_producer
   // CHECK:       scf.forall {{.*}} {
-  // CHECK:         linalg.transpose
-  // CHECK:         linalg.generic
+  // CHECK:         %[[PRODUCER:.*]] = linalg.transpose ins(%[[EXPANDED]]
+  // CHECK:         linalg.generic {{.*}} ins(%[[PRODUCER]]
   // CHECK:         scf.forall.in_parallel
   // CHECK:       }
   func.func @fuse_pack_as_producer(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
@@ -60,11 +60,13 @@ module {
 
 // -----
 // For pack op, by default lowerPadLikeWithInsertSlice = true, which generates insert_slice and blocks fusion.
+// In below testcase, tensor.insert_slice as a producer operation cannot be fused into the scf.forall loop.
 
 module {
   // CHECK-label: func @fuse_pack_as_producer_blocked_by_insert_slice
-  // CHECK:       tensor.insert_slice
+  // CHECK:       %[[PRODUCER:.*]] = tensor.insert_slice
   // CHECK:       scf.forall {{.*}} {
+  // CHECK:         linalg.generic {{.*}} ins(%[[PRODUCER]]
   // CHECK:         scf.forall.in_parallel
   // CHECK:       }
   func.func @fuse_pack_as_producer_blocked_by_insert_slice(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
@@ -116,14 +118,13 @@ module {
 
 // -----
 // For unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated.
-// This allows linalg.transpose to be fused as a consumer operation. Alternatively, without this attribute
-// extract_slice will be generated and fusion blocked.
-
+// This allows linalg.transpose to be fused as a consumer operation. In below testcase, linalg.transpose
+// as a consumer operation is fused into the scf.forall loop.
 module {
   // CHECK-label: func @fuse_unpack_as_consumer
   // CHECK:       scf.forall {{.*}} {
-  // CHECK:         linalg.generic
-  // CHECK:         linalg.transpose
+  // CHECK:         %[[CONSUMER:.*]] = linalg.generic
+  // CHECK:         linalg.transpose ins(%[[CONSUMER]]
   // CHECK:         scf.forall.in_parallel
   // CHECK:       }
   func.func @fuse_unpack_as_consumer(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
@@ -178,14 +179,15 @@ module {
 
 // -----
 // For unpack op, by default lowerUnpadLikeWithExtractSlice = true, which generates extract_slice and blocks fusion.
-
+// In below testcase, tensor.extract_slice as a consumer operation cannot be fused into the scf.forall loop.
 module {
   // CHECK-label: func @fuse_unpack_as_consumer_blocked_by_extract_slice
-  // CHECK:       scf.forall {{.*}} {
-  // CHECK:         linalg.generic
+  // CHECK:       %[[CONSUMER:.*]] = scf.forall {{.*}} {
+  // CHECK:         %[[ADDF:.*]] = linalg.generic
   // CHECK:         scf.forall.in_parallel
+  // CHECK:           tensor.parallel_insert_slice %[[ADDF]]
   // CHECK:       }
-  // CHECK:       tensor.extract_slice
+  // CHECK:       tensor.extract_slice %[[CONSUMER]]
   func.func @fuse_unpack_as_consumer_blocked_by_extract_slice(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
       -> tensor<128x256xf32> {
     %out = tensor.empty() : tensor<1x1x128x256xf32>



More information about the Mlir-commits mailing list