[Mlir-commits] [mlir] [mlir][linalg] Fix empty outer dim case for packing reshape op (PR #96732)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 26 18:24:54 PDT 2024


https://github.com/yifeizh2 updated https://github.com/llvm/llvm-project/pull/96732

>From 52b58dc5bb5562e0bb77614991b4a14d6d375859 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei" <yifei.zhang at intel.com>
Date: Tue, 25 Jun 2024 22:22:07 -0700
Subject: [PATCH 1/5] [mlir][linalg] Fix empty outer dim case for packing
 reshape op

---
 .../Linalg/Transforms/DataLayoutPropagation.cpp | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index e51ae2264a36a..699bf56f96581 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -641,7 +641,14 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                    PatternRewriter &rewriter) {
   SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
   ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
-  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+  auto numOuterDims =
+      dyn_cast<RankedTensorType>(packOp.getDpsInputs()[0].getType())
+          .getShape()
+          .size();
+  SmallVector<int64_t> outerDimsPerm =
+      packOp.getOuterDimsPerm().empty()
+          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
+          : SmallVector<int64_t>(packOp.getOuterDimsPerm());
 
   ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
   SmallVector<ReassociationIndices> reassocIndices =
@@ -885,8 +892,12 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
                                    PatternRewriter &rewriter) {
   SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
   ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
-  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
-
+  auto numOuterDims =
+      dyn_cast<RankedTensorType>(unPackOp.getType()).getShape().size();
+  SmallVector<int64_t> outerDimsPerm =
+      unPackOp.getOuterDimsPerm().empty()
+          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
+          : SmallVector<int64_t>(unPackOp.getOuterDimsPerm());
   auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
   if (!expandTy)
     return failure();

>From c74adf754c33a7cdf985c5be7046b7d40a46034c Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei" <yifei.zhang at intel.com>
Date: Tue, 25 Jun 2024 23:36:46 -0700
Subject: [PATCH 2/5] add unit tests

---
 .../Linalg/data-layout-propagation.mlir       | 33 +++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 78505d0aa4140..88014594d6129 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -926,6 +926,24 @@ func.func @bubble_up_pack_through_collapse(%1: tensor<?x16x4xf32>, %dim : index)
 
 // -----
 
+func.func @bubble_up_pack_through_collapse_empty_outer_dims_perm(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x4x8x1xf32> {
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
+  %2 = tensor.empty(%dim) : tensor<?x4x8x1xf32>
+  %pack = tensor.pack %collapsed inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+  func.return %pack : tensor<?x4x8x1xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_through_collapse_empty_outer_dims_perm
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
+// CHECK:         return %[[COLLAPSED]] : tensor<?x4x8x1xf32>
+
+// -----
+
 func.func @bubble_up_permuted_pack_through_collapse(%1: tensor<4x192x16x256xf32>) -> tensor<4x32x3072x8x1xf32> {
   %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3]] : tensor<4x192x16x256xf32> into tensor<4x3072x256xf32>
   %2 = tensor.empty() : tensor<4x32x3072x8x1xf32>
@@ -1269,6 +1287,21 @@ func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index
 
 // -----
 
+func.func @push_down_unpack_through_expand_empty_outer_dims_perm(%5: tensor<4x384x32x8x8xf32>) -> tensor<4x12x256x256xf32> {
+  %6 = tensor.empty() : tensor<4x3072x256xf32>
+  %unpack = tensor.unpack %5 inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %6 : tensor<4x384x32x8x8xf32> -> tensor<4x3072x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] output_shape [4, 12, 256, 256] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
+  func.return %expanded : tensor<4x12x256x256xf32>
+}
+// CHECK-LABEL: func.func @push_down_unpack_through_expand_empty_outer_dims_perm
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4], [5]] output_shape [4, 12, 32, 32, 8, 8] : tensor<4x384x32x8x8xf32> into tensor<4x12x32x32x8x8xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x12x32x32x8x8xf32> -> tensor<4x12x256x256xf32>
+// CHECK:         return %[[UNPACK]] : tensor<4x12x256x256xf32>
+
+// -----
+
 func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> {
   %6 = tensor.empty() : tensor<4x3072x256xf32>
   %unpack = tensor.unpack %5 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %6 : tensor<4x32x384x8x8xf32> -> tensor<4x3072x256xf32>

>From 63cdfcbf20a502024fccbe4d08048ddc6efd8580 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei" <yifei.zhang at intel.com>
Date: Tue, 25 Jun 2024 23:56:39 -0700
Subject: [PATCH 3/5] update unit test

---
 .../Linalg/data-layout-propagation.mlir       | 24 ++++++++++++-------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 88014594d6129..f3ebb494f94b6 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1287,18 +1287,24 @@ func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index
 
 // -----
 
-func.func @push_down_unpack_through_expand_empty_outer_dims_perm(%5: tensor<4x384x32x8x8xf32>) -> tensor<4x12x256x256xf32> {
-  %6 = tensor.empty() : tensor<4x3072x256xf32>
-  %unpack = tensor.unpack %5 inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %6 : tensor<4x384x32x8x8xf32> -> tensor<4x3072x256xf32>
-  %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] output_shape [4, 12, 256, 256] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
-  func.return %expanded : tensor<4x12x256x256xf32>
+func.func @push_down_unpack_through_expand_empty_outer_dims_perm(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
+  %6 = tensor.empty(%dim) : tensor<?x256xf32>
+  %unpack = tensor.unpack %5 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor<?x256xf32> into tensor<?x256x256xf32>
+  func.return %expanded : tensor<?x256x256xf32>
 }
 // CHECK-LABEL: func.func @push_down_unpack_through_expand_empty_outer_dims_perm
 // CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4], [5]] output_shape [4, 12, 32, 32, 8, 8] : tensor<4x384x32x8x8xf32> into tensor<4x12x32x32x8x8xf32>
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32>
-// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x12x32x32x8x8xf32> -> tensor<4x12x256x256xf32>
-// CHECK:         return %[[UNPACK]] : tensor<4x12x256x256xf32>
+// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[C32:.+]] = arith.constant 32 : index
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x8x8xf32>
+// CHECK:         %[[SZ0:.+]] = arith.divui %[[DIM0]], %[[C32]] : index
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape [%[[SZ0]], 32, 32, 8, 8] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
+// CHECK:         %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
+// CHECK:         return %[[UNPACK]] : tensor<?x256x256xf32>
 
 // -----
 

>From b3137a6214c5968b91c0bfb57e32e9aba87c9ec1 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei" <yifei.zhang at intel.com>
Date: Wed, 26 Jun 2024 01:36:00 -0700
Subject: [PATCH 4/5] update based on comment

---
 .../Transforms/DataLayoutPropagation.cpp      | 21 ++++++-------------
 .../Linalg/data-layout-propagation.mlir       |  4 ++--
 2 files changed, 8 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 699bf56f96581..6b9a073bf5e3d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -605,7 +605,9 @@ static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
 static int64_t applyPermutationAndReindexReassoc(
     SmallVector<ReassociationIndices> &reassocIndices,
     ArrayRef<int64_t> permutation) {
-  applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
+  if (!permutation.empty()) {
+    applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
+  }
   int64_t nextPos = 0;
   for (ReassociationIndices &indices : reassocIndices) {
     for (auto &index : indices) {
@@ -641,14 +643,7 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                    PatternRewriter &rewriter) {
   SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
   ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
-  auto numOuterDims =
-      dyn_cast<RankedTensorType>(packOp.getDpsInputs()[0].getType())
-          .getShape()
-          .size();
-  SmallVector<int64_t> outerDimsPerm =
-      packOp.getOuterDimsPerm().empty()
-          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
-          : SmallVector<int64_t>(packOp.getOuterDimsPerm());
+  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
 
   ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
   SmallVector<ReassociationIndices> reassocIndices =
@@ -892,12 +887,8 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
                                    PatternRewriter &rewriter) {
   SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
   ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
-  auto numOuterDims =
-      dyn_cast<RankedTensorType>(unPackOp.getType()).getShape().size();
-  SmallVector<int64_t> outerDimsPerm =
-      unPackOp.getOuterDimsPerm().empty()
-          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
-          : SmallVector<int64_t>(unPackOp.getOuterDimsPerm());
+  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
+
   auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
   if (!expandTy)
     return failure();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index f3ebb494f94b6..626dd8b697e59 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -938,7 +938,7 @@ func.func @bubble_up_pack_through_collapse_empty_outer_dims_perm(%1: tensor<?x16
 // CHECK:         %[[C0:.+]] = arith.constant 0 : index
 // CHECK:         %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
 // CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
-// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
 // CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
 // CHECK:         return %[[COLLAPSED]] : tensor<?x4x8x1xf32>
 
@@ -1303,7 +1303,7 @@ func.func @push_down_unpack_through_expand_empty_outer_dims_perm(%5: tensor<?x32
 // CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape [%[[SZ0]], 32, 32, 8, 8] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
 // CHECK:         %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
 // CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
-// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
 // CHECK:         return %[[UNPACK]] : tensor<?x256x256xf32>
 
 // -----

>From fd98e1cdfb64d6e6b4e6e9f868d79937e3946cf2 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei" <yifei.zhang at intel.com>
Date: Wed, 26 Jun 2024 18:24:16 -0700
Subject: [PATCH 5/5] fix comment

---
 mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6b9a073bf5e3d..6984bc2dff498 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -605,9 +605,8 @@ static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
 static int64_t applyPermutationAndReindexReassoc(
     SmallVector<ReassociationIndices> &reassocIndices,
     ArrayRef<int64_t> permutation) {
-  if (!permutation.empty()) {
+  if (!permutation.empty())
     applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
-  }
   int64_t nextPos = 0;
   for (ReassociationIndices &indices : reassocIndices) {
     for (auto &index : indices) {



More information about the Mlir-commits mailing list