[Mlir-commits] [mlir] [MLIR] Add patterns to bubble-up pack and push-down unpack through collapse/expand shape ops (PR #85297)

Jerry Wu llvmlistbot at llvm.org
Tue Mar 19 10:44:56 PDT 2024


https://github.com/pzread updated https://github.com/llvm/llvm-project/pull/85297

>From 61f1d84869b4bee93d01c4c943640fbb073f62d9 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Fri, 8 Mar 2024 23:59:47 +0000
Subject: [PATCH 1/6] Bubble up pack through collapse_shape and push down unpack through expand_shape

---
 .../Transforms/DataLayoutPropagation.cpp      | 189 +++++++++++++++++-
 1 file changed, 188 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 5ceb85e7d9903b..4dc52891f4510c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -552,6 +552,192 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
   ControlPropagationFn controlFn;
 };
 
+static LogicalResult
+bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
+                                   tensor::PackOp packOp,
+                                   PatternRewriter &rewriter) {
+  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
+  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+
+  if (llvm::any_of(innerTileSizes,
+                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
+    return failure();
+  }
+
+  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
+  SmallVector<ReassociationIndices> reassocIndices =
+      collapseOp.getReassociationIndices();
+  SmallVector<int64_t> baseDimsPos;
+  for (auto pos : innerDimsPos) {
+    baseDimsPos.push_back(reassocIndices[pos].back());
+  }
+  // Check if the base dims before reassociation are divisible by the inner tile
+  // sizes.
+  for (auto [basePos, tileSize] :
+       llvm::zip_equal(baseDimsPos, innerTileSizes)) {
+    int64_t dim = srcShape[basePos];
+    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0) {
+      return failure();
+    }
+  }
+  // Expand the outer dims perm with associated src dims.
+  SmallVector<int64_t> newOuterDimsPerm;
+  for (auto outerPos : outerDimsPerm) {
+    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+                            reassocIndices[outerPos].begin(),
+                            reassocIndices[outerPos].end());
+  }
+
+  auto emptyOp = tensor::PackOp::createDestinationTensor(
+      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(), baseDimsPos,
+      newOuterDimsPerm);
+  auto newPackOp = rewriter.create<tensor::PackOp>(
+      packOp.getLoc(), collapseOp.getSrc(), emptyOp, baseDimsPos, packOp.getMixedTiles(),
+      packOp.getPaddingValue(), newOuterDimsPerm);
+
+  SmallVector<ReassociationIndices> newReassocIndices;
+  int64_t currPos = 0;
+  for (auto outerPos : outerDimsPerm) {
+    int64_t start = currPos;
+    int64_t end = start + reassocIndices[outerPos].size();
+    newReassocIndices.push_back(llvm::to_vector(llvm::seq(start, end)));
+    currPos = end;
+  }
+  for (auto unused : innerTileSizes) {
+    (void)unused;
+    newReassocIndices.push_back({currPos});
+    currPos += 1;
+  }
+
+  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
+      collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
+  rewriter.replaceOp(packOp, newCollapseOp);
+
+  return success();
+}
+
+class BubbleUpPackOpThroughReshapeOp final
+    : public OpRewritePattern<tensor::PackOp> {
+public:
+  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
+      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    if (packOp.getPaddingValue())
+      return failure();
+
+    Operation *srcOp = packOp.getSource().getDefiningOp();
+    if (!srcOp || !(srcOp->getNumResults() == 1) ||
+        !srcOp->getResult(0).hasOneUse())
+      return failure();
+
+    if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(srcOp)) {
+      return bubbleUpPackOpThroughCollapseShape(collapseOp, packOp, rewriter);
+    }
+    return failure();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
+static LogicalResult
+pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
+                                   tensor::ExpandShapeOp expandOp,
+                                   PatternRewriter &rewriter) {
+
+  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
+  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
+  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
+
+  if (llvm::any_of(innerTileSizes,
+                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
+    return failure();
+  }
+
+  ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
+  SmallVector<ReassociationIndices> reassocIndices =
+      expandOp.getReassociationIndices();
+  SmallVector<int64_t> baseDimsPos;
+  for (auto pos : innerDimsPos) {
+    baseDimsPos.push_back(reassocIndices[pos].back());
+  }
+  // Check if the base dims after reassociation are divisible by the inner tile
+  // sizes.
+  for (auto [basePos, tileSize] :
+       llvm::zip_equal(baseDimsPos, innerTileSizes)) {
+    int64_t dim = dstShape[basePos];
+    if (ShapedType::isDynamic(dim) || dstShape[basePos] % tileSize != 0) {
+      return failure();
+    }
+  }
+  // Expand the outer dims perm with associated src dims.
+  SmallVector<int64_t> newOuterDimsPerm;
+  for (auto outerPos : outerDimsPerm) {
+    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+                            reassocIndices[outerPos].begin(),
+                            reassocIndices[outerPos].end());
+  }
+
+  SmallVector<ReassociationIndices> newReassocIndices;
+  int64_t currPos = 0;
+  for (auto outerPos : outerDimsPerm) {
+    int64_t start = currPos;
+    int64_t end = start + reassocIndices[outerPos].size();
+    newReassocIndices.push_back(llvm::to_vector(llvm::seq(start, end)));
+    currPos = end;
+  }
+  for (auto unused : innerTileSizes) {
+    (void)unused;
+    newReassocIndices.push_back({currPos});
+    currPos += 1;
+  }
+
+  RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
+      expandOp.getType(), innerTileSizes, baseDimsPos, newOuterDimsPerm);
+  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
+      newReassocIndices);
+
+  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
+      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(), baseDimsPos,
+      newOuterDimsPerm);
+  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
+      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, baseDimsPos,
+      unPackOp.getMixedTiles(), newOuterDimsPerm);
+  rewriter.replaceOp(expandOp, newUnPackOp);
+
+  return success();
+}
+
+class PushDownUnPackOpThroughReshapeOp final
+    : public OpRewritePattern<tensor::UnPackOp> {
+public:
+  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
+                                   ControlPropagationFn fun)
+      : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
+  }
+
+  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
+                                PatternRewriter &rewriter) const override {
+    Value result = unPackOp.getResult();
+    if (!result.hasOneUse()) {
+      return failure();
+    }
+    Operation *userOp = *result.user_begin();
+
+    if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(userOp)) {
+      return pushDownUnPackOpThroughExpandShape(unPackOp, expandOp, rewriter);
+    }
+    return failure();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 // TODO: Relax this restriction. We should unpack a generic op also
 // in the presence of multiple unpack ops as producers.
 /// Return the unpacked operand, if present, for the current generic op.
@@ -774,6 +960,7 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
     const ControlPropagationFn &controlPackUnPackPropagation) {
   patterns
       .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
-              PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
+              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
+              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
           patterns.getContext(), controlPackUnPackPropagation);
 }

>From a0b00a0d92d69a762131baae98c6f23d842879bf Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Mon, 11 Mar 2024 21:25:34 +0000
Subject: [PATCH 2/6] Handle unit dims when projecting packed dims through reshapes

---
 .../Transforms/DataLayoutPropagation.cpp      | 46 +++++++++++++------
 1 file changed, 32 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 4dc52891f4510c..e230a11f9f2c0e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -552,6 +552,26 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
   ControlPropagationFn controlFn;
 };
 
+static SmallVector<int64_t>
+projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
+                                 ArrayRef<ReassociationIndices> reassocIndices,
+                                 ArrayRef<int64_t> baseShape) {
+  SmallVector<int64_t> projectedDimsPos;
+  for (auto pos : dimsPos) {
+    int64_t projectedPos = -1;
+    for (auto it = reassocIndices[pos].rbegin();
+         it != reassocIndices[pos].rend(); ++it) {
+      projectedPos = *it;
+      if (baseShape[projectedPos] > 1) {
+        break;
+      }
+    }
+    assert(projectedPos != -1 && "projected dim not found");
+    projectedDimsPos.push_back(projectedPos);
+  }
+  return projectedDimsPos;
+}
+
 static LogicalResult
 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                    tensor::PackOp packOp,
@@ -568,10 +588,9 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
   ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
   SmallVector<ReassociationIndices> reassocIndices =
       collapseOp.getReassociationIndices();
-  SmallVector<int64_t> baseDimsPos;
-  for (auto pos : innerDimsPos) {
-    baseDimsPos.push_back(reassocIndices[pos].back());
-  }
+  SmallVector<int64_t> baseDimsPos =
+      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
+
   // Check if the base dims before reassociation are divisible by the inner tile
   // sizes.
   for (auto [basePos, tileSize] :
@@ -590,11 +609,11 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
   }
 
   auto emptyOp = tensor::PackOp::createDestinationTensor(
-      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(), baseDimsPos,
-      newOuterDimsPerm);
+      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
+      baseDimsPos, newOuterDimsPerm);
   auto newPackOp = rewriter.create<tensor::PackOp>(
-      packOp.getLoc(), collapseOp.getSrc(), emptyOp, baseDimsPos, packOp.getMixedTiles(),
-      packOp.getPaddingValue(), newOuterDimsPerm);
+      packOp.getLoc(), collapseOp.getSrc(), emptyOp, baseDimsPos,
+      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
 
   SmallVector<ReassociationIndices> newReassocIndices;
   int64_t currPos = 0;
@@ -660,10 +679,9 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
   ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
   SmallVector<ReassociationIndices> reassocIndices =
       expandOp.getReassociationIndices();
-  SmallVector<int64_t> baseDimsPos;
-  for (auto pos : innerDimsPos) {
-    baseDimsPos.push_back(reassocIndices[pos].back());
-  }
+  SmallVector<int64_t> baseDimsPos =
+      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
+
   // Check if the base dims after reassociation are divisible by the inner tile
   // sizes.
   for (auto [basePos, tileSize] :
@@ -702,8 +720,8 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
       newReassocIndices);
 
   auto emptyOp = tensor::UnPackOp::createDestinationTensor(
-      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(), baseDimsPos,
-      newOuterDimsPerm);
+      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
+      baseDimsPos, newOuterDimsPerm);
   auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
       unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, baseDimsPos,
       unPackOp.getMixedTiles(), newOuterDimsPerm);

>From ff6aa3a1fc4b655efe1d4e2f827cc324e54fa810 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 14 Mar 2024 19:10:54 +0000
Subject: [PATCH 3/6] Add test draft

---
 .../Linalg/data-layout-propagation.mlir       | 56 +++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index e036695a2ac9fd..0344c483226af6 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -905,3 +905,59 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [16]
 // CHECK-SAME:      into %[[UNPACK_NEW_DEST]]
 // CHECK:         return %[[UNPACK]] : tensor<16x540x960xi32>
+
+func.func @bubble_up_pack_through_collapse(%1: tensor<192x16x64x4xf32>) -> tensor<384x256x8x1xf32> {
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2, 3]] : tensor<192x16x64x4xf32> into tensor<3072x256xf32>
+  %2 = tensor.empty() : tensor<384x256x8x1xf32>
+  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<3072x256xf32> -> tensor<384x256x8x1xf32>
+  func.return %pack : tensor<384x256x8x1xf32>
+}
+
+func.func @bubble_up_permuted_pack_through_collapse(%1: tensor<4x192x16x256xf32>) -> tensor<4x32x3072x8x1xf32> {
+  %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3]] : tensor<4x192x16x256xf32> into tensor<4x3072x256xf32>
+  %2 = tensor.empty() : tensor<4x32x3072x8x1xf32>
+  %pack = tensor.pack %collapsed outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 1] into %2 : tensor<4x3072x256xf32> -> tensor<4x32x3072x8x1xf32>
+  func.return %pack : tensor<4x32x3072x8x1xf32>
+}
+
+func.func @bubble_up_pack_through_unit_collapse(%1: tensor<1x64x1x4xf32>) -> tensor<8x4x8x1xf32> {
+  %collapsed = tensor.collapse_shape %1 [[0, 1, 2], [3]] : tensor<1x64x1x4xf32> into tensor<64x4xf32>
+  %2 = tensor.empty() : tensor<8x4x8x1xf32>
+  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<64x4xf32> -> tensor<8x4x8x1xf32>
+  func.return %pack : tensor<8x4x8x1xf32>
+}
+
+func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4xf32>) -> tensor<384x32x8x8xf32> {
+  %collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
+  %2 = tensor.empty() : tensor<384x32x8x8xf32>
+  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %2 : tensor<3072x256xf32> -> tensor<384x32x8x8xf32>
+  func.return %pack : tensor<384x32x8x8xf32>
+}
+
+func.func @push_down_unpack_through_expand(%5: tensor<384x32x8x8xf32>) -> tensor<12x256x256xf32> {
+  %6 = tensor.empty() : tensor<3072x256xf32>
+  %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<3072x256xf32> into tensor<12x256x256xf32>
+  func.return %expanded : tensor<12x256x256xf32>
+}
+
+func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> {
+  %6 = tensor.empty() : tensor<4x3072x256xf32>
+  %unpack = tensor.unpack %5 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %6 : tensor<4x32x384x8x8xf32> -> tensor<4x3072x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
+  func.return %expanded : tensor<4x12x256x256xf32>
+}
+
+func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> tensor<3x16x1x256xf32> {
+  %6 = tensor.empty() : tensor<48x256xf32>
+  %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<6x32x8x8xf32> -> tensor<48x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0, 1, 2], [3]] : tensor<48x256xf32> into tensor<3x16x1x256xf32>
+  func.return %expanded : tensor<3x16x1x256xf32>
+}
+
+func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x8xf32>) -> tensor<256x12x256xf32> {
+  %6 = tensor.empty() : tensor<3072x256xf32>
+  %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32>
+  func.return %expanded : tensor<256x12x256xf32>
+}

>From 3f79dc6924054456a444ee6edf92fed3334d3e9c Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 14 Mar 2024 20:56:52 +0000
Subject: [PATCH 4/6] Refactor

---
 .../Transforms/DataLayoutPropagation.cpp      | 178 +++++++++++-------
 1 file changed, 110 insertions(+), 68 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index e230a11f9f2c0e..0d53205b8170c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
 
@@ -572,6 +573,39 @@ projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
   return projectedDimsPos;
 }
 
+static int64_t applyPermutationAndReindexReassoc(
+    SmallVector<ReassociationIndices> &reassociationIndices,
+    ArrayRef<int64_t> dimsPerm) {
+  applyPermutationToVector<ReassociationIndices>(reassociationIndices,
+                                                 dimsPerm);
+  int64_t lastPos = 0;
+  for (ReassociationIndices &indices : reassociationIndices) {
+    for (auto &index : indices) {
+      index = lastPos;
+      lastPos += 1;
+    }
+  }
+  return lastPos;
+}
+
+/// Bubble up pack op through collapse shape op when the packed dims can be
+/// mapped to the source dims before collapsing. This is possible when the inner
+/// tile sizes can divide the mapped source dims.
+///
+/// For example:
+///
+/// %collapsed = tensor.collapse_shape %in [[0, 1], 2] : tensor<?x16x4xf32> into
+/// tensor<?x4xf32> %out = tensor.empty() : tensor<?x4x8x1xf32> %pack =
+/// tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
+/// inner_tiles = [8, 1] into %out : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+///
+/// Can be transformed into:
+///
+/// %out = tensor.empty() : tensor<?x2x4x8x1xf32>
+/// %pack = tensor.pack %in outer_dims_perm = [1, 2] inner_dims_pos = [1, 2]
+/// inner_tiles = [8, 1] into %out : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+/// %collapsed = tensor.collapse_shape %1 [[0, 1], 2, 3, 4] :
+/// tensor<?x2x4x8x1xf32> into tensor<?x4x8x1>
 static LogicalResult
 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                    tensor::PackOp packOp,
@@ -580,27 +614,23 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
   ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
   ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
 
-  if (llvm::any_of(innerTileSizes,
-                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
-    return failure();
-  }
-
   ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
   SmallVector<ReassociationIndices> reassocIndices =
       collapseOp.getReassociationIndices();
-  SmallVector<int64_t> baseDimsPos =
+  SmallVector<int64_t> projectedInnerDimsPos =
       projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
 
-  // Check if the base dims before reassociation are divisible by the inner tile
+  // Check if the projected dims on the source are divisible by the inner tile
   // sizes.
-  for (auto [basePos, tileSize] :
-       llvm::zip_equal(baseDimsPos, innerTileSizes)) {
-    int64_t dim = srcShape[basePos];
-    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0) {
+  for (auto [projectedPos, tileSize] :
+       llvm::zip_equal(projectedInnerDimsPos, innerTileSizes)) {
+    int64_t dim = srcShape[projectedPos];
+    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
       return failure();
-    }
   }
-  // Expand the outer dims perm with associated src dims.
+  // Expand the outer dims permutation with the associated source dims for the
+  // new permutation after bubbling. This is because moving a collapsed dim is
+  // equivalent to moving the associated source dims together.
   SmallVector<int64_t> newOuterDimsPerm;
   for (auto outerPos : outerDimsPerm) {
     newOuterDimsPerm.insert(newOuterDimsPerm.end(),
@@ -610,23 +640,19 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
 
   auto emptyOp = tensor::PackOp::createDestinationTensor(
       rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
-      baseDimsPos, newOuterDimsPerm);
+      projectedInnerDimsPos, newOuterDimsPerm);
   auto newPackOp = rewriter.create<tensor::PackOp>(
-      packOp.getLoc(), collapseOp.getSrc(), emptyOp, baseDimsPos,
+      packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
       packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
 
-  SmallVector<ReassociationIndices> newReassocIndices;
-  int64_t currPos = 0;
-  for (auto outerPos : outerDimsPerm) {
-    int64_t start = currPos;
-    int64_t end = start + reassocIndices[outerPos].size();
-    newReassocIndices.push_back(llvm::to_vector(llvm::seq(start, end)));
-    currPos = end;
-  }
-  for (auto unused : innerTileSizes) {
-    (void)unused;
-    newReassocIndices.push_back({currPos});
-    currPos += 1;
+  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
+  // First build reassociations on the outer dims after the permutation.
+  int64_t lastPos =
+      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
+  // Then add direct mapping for the inner tile dims.
+  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
+    newReassocIndices.push_back({lastPos});
+    lastPos += 1;
   }
 
   auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
@@ -644,18 +670,28 @@ class BubbleUpPackOpThroughReshapeOp final
 
   LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                 PatternRewriter &rewriter) const override {
-    if (packOp.getPaddingValue())
+    // User controlled propagation function.
+    if (!controlFn(packOp))
       return failure();
 
     Operation *srcOp = packOp.getSource().getDefiningOp();
+    // Currently only support when the pack op is the only user.
     if (!srcOp || !(srcOp->getNumResults() == 1) ||
-        !srcOp->getResult(0).hasOneUse())
+        !srcOp->getResult(0).hasOneUse()) {
       return failure();
-
-    if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(srcOp)) {
-      return bubbleUpPackOpThroughCollapseShape(collapseOp, packOp, rewriter);
     }
-    return failure();
+    // Currently only support static inner tile sizes.
+    if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
+          return ShapedType::isDynamic(size);
+        })) {
+      return failure();
+    }
+
+    return TypeSwitch<Operation *, LogicalResult>(srcOp)
+        .Case([&](tensor::CollapseShapeOp op) {
+          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
+        })
+        .Default([](Operation *) { return failure(); });
   }
 
 private:
@@ -666,32 +702,29 @@ static LogicalResult
 pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
                                    tensor::ExpandShapeOp expandOp,
                                    PatternRewriter &rewriter) {
-
   SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
   ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
   ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
 
-  if (llvm::any_of(innerTileSizes,
-                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
-    return failure();
-  }
-
   ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
   SmallVector<ReassociationIndices> reassocIndices =
       expandOp.getReassociationIndices();
-  SmallVector<int64_t> baseDimsPos =
+  SmallVector<int64_t> projectedInnerDimsPos =
       projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
 
-  // Check if the base dims after reassociation are divisible by the inner tile
+  // Check if the projected dims on the dest are divisible by the inner tile
   // sizes.
-  for (auto [basePos, tileSize] :
-       llvm::zip_equal(baseDimsPos, innerTileSizes)) {
-    int64_t dim = dstShape[basePos];
-    if (ShapedType::isDynamic(dim) || dstShape[basePos] % tileSize != 0) {
+  for (auto [projectedPos, tileSize] :
+       llvm::zip_equal(projectedInnerDimsPos, innerTileSizes)) {
+    int64_t dim = dstShape[projectedPos];
+    if (ShapedType::isDynamic(dim) ||
+        (dstShape[projectedPos] % tileSize) != 0) {
       return failure();
     }
   }
-  // Expand the outer dims perm with associated src dims.
+  // Expand the outer dims permutation with the associated expanded dims for the
+  // new permutation after pushing. This is because moving a source dim is
+  // equivalent to moving the associated expanded dims together.
   SmallVector<int64_t> newOuterDimsPerm;
   for (auto outerPos : outerDimsPerm) {
     newOuterDimsPerm.insert(newOuterDimsPerm.end(),
@@ -699,32 +732,29 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
                             reassocIndices[outerPos].end());
   }
 
-  SmallVector<ReassociationIndices> newReassocIndices;
-  int64_t currPos = 0;
-  for (auto outerPos : outerDimsPerm) {
-    int64_t start = currPos;
-    int64_t end = start + reassocIndices[outerPos].size();
-    newReassocIndices.push_back(llvm::to_vector(llvm::seq(start, end)));
-    currPos = end;
-  }
-  for (auto unused : innerTileSizes) {
-    (void)unused;
-    newReassocIndices.push_back({currPos});
-    currPos += 1;
+  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
+  // First build reassociations on the outer dims after the permutation.
+  int64_t lastPos =
+      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
+  // Then add direct mapping for the inner tile dims.
+  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
+    newReassocIndices.push_back({lastPos});
+    lastPos += 1;
   }
 
-  RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
-      expandOp.getType(), innerTileSizes, baseDimsPos, newOuterDimsPerm);
+  RankedTensorType newExpandType =
+      tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
+                                      projectedInnerDimsPos, newOuterDimsPerm);
   auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
       expandOp.getLoc(), newExpandType, unPackOp.getSource(),
       newReassocIndices);
 
   auto emptyOp = tensor::UnPackOp::createDestinationTensor(
       rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
-      baseDimsPos, newOuterDimsPerm);
+      projectedInnerDimsPos, newOuterDimsPerm);
   auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
-      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, baseDimsPos,
-      unPackOp.getMixedTiles(), newOuterDimsPerm);
+      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
+      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
   rewriter.replaceOp(expandOp, newUnPackOp);
 
   return success();
@@ -740,16 +770,28 @@ class PushDownUnPackOpThroughReshapeOp final
 
   LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
+    // User controlled propagation function.
+    if (!controlFn(unPackOp))
+      return failure();
+
     Value result = unPackOp.getResult();
+    // Currently only support unpack op with the single user.
     if (!result.hasOneUse()) {
       return failure();
     }
-    Operation *userOp = *result.user_begin();
-
-    if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(userOp)) {
-      return pushDownUnPackOpThroughExpandShape(unPackOp, expandOp, rewriter);
+    // Currently only support static inner tile sizes.
+    if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
+          return ShapedType::isDynamic(size);
+        })) {
+      return failure();
     }
-    return failure();
+
+    Operation *userOp = *result.user_begin();
+    return TypeSwitch<Operation *, LogicalResult>(userOp)
+        .Case([&](tensor::ExpandShapeOp op) {
+          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
+        })
+        .Default([](Operation *) { return failure(); });
   }
 
 private:

>From 370ae551f5933266bd7bbeaedcb0eb2f596ca1f3 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Mon, 18 Mar 2024 22:45:19 +0000
Subject: [PATCH 5/6] Finish tests

---
 .../Transforms/DataLayoutPropagation.cpp      | 94 ++++++++++++-------
 .../Linalg/data-layout-propagation.mlir       | 88 +++++++++++++++--
 2 files changed, 139 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0d53205b8170c6..9b76da2cf97368 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -559,20 +559,34 @@ projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
                                  ArrayRef<int64_t> baseShape) {
   SmallVector<int64_t> projectedDimsPos;
   for (auto pos : dimsPos) {
-    int64_t projectedPos = -1;
+    // In the case all dims are unit, this will return the inner-most one.
+    int64_t projectedPos = reassocIndices[pos].back();
     for (auto it = reassocIndices[pos].rbegin();
          it != reassocIndices[pos].rend(); ++it) {
-      projectedPos = *it;
-      if (baseShape[projectedPos] > 1) {
+      int64_t dim = baseShape[*it];
+      if (dim > 1 || ShapedType::isDynamic(dim)) {
+        projectedPos = *it;
         break;
       }
     }
-    assert(projectedPos != -1 && "projected dim not found");
     projectedDimsPos.push_back(projectedPos);
   }
   return projectedDimsPos;
 }
 
+static bool
+isProjectedDimsDivisibleByTileSizes(ArrayRef<int64_t> projectedDimsPos,
+                                    ArrayRef<int64_t> targetShape,
+                                    ArrayRef<int64_t> tileSizes) {
+  for (auto [projectedPos, tileSize] :
+       llvm::zip_equal(projectedDimsPos, tileSizes)) {
+    int64_t dim = targetShape[projectedPos];
+    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
+      return false;
+  }
+  return true;
+}
+
 static int64_t applyPermutationAndReindexReassoc(
     SmallVector<ReassociationIndices> &reassociationIndices,
     ArrayRef<int64_t> dimsPerm) {
@@ -589,23 +603,24 @@ static int64_t applyPermutationAndReindexReassoc(
 }
 
 /// Bubble up pack op through collapse shape op when the packed dims can be
-/// mapped to the source dims before collapsing. This is possible when the inner
-/// tile sizes can divide the mapped source dims.
+/// projected to the dims before collapsing. This is possible when the inner
+/// tile sizes can divide the projected dims.
 ///
 /// For example:
 ///
-/// %collapsed = tensor.collapse_shape %in [[0, 1], 2] : tensor<?x16x4xf32> into
-/// tensor<?x4xf32> %out = tensor.empty() : tensor<?x4x8x1xf32> %pack =
-/// tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
-/// inner_tiles = [8, 1] into %out : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+/// %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
+///     : tensor<?x16x4xf32> into tensor<?x4xf32>
+/// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
+///     inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
+///     : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
 ///
 /// Can be transformed into:
 ///
-/// %out = tensor.empty() : tensor<?x2x4x8x1xf32>
-/// %pack = tensor.pack %in outer_dims_perm = [1, 2] inner_dims_pos = [1, 2]
-/// inner_tiles = [8, 1] into %out : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
-/// %collapsed = tensor.collapse_shape %1 [[0, 1], 2, 3, 4] :
-/// tensor<?x2x4x8x1xf32> into tensor<?x4x8x1>
+/// %pack = tensor.pack %in outer_dims_perm = [0, 1, 2]
+///     inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
+///     : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+/// %collapsed = tensor.collapse_shape %pack [[0, 1], [2], [3], [4]]
+///     : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
 static LogicalResult
 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                    tensor::PackOp packOp,
@@ -620,13 +635,9 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
   SmallVector<int64_t> projectedInnerDimsPos =
       projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
 
-  // Check if the projected dims on the source are divisible by the inner tile
-  // sizes.
-  for (auto [projectedPos, tileSize] :
-       llvm::zip_equal(projectedInnerDimsPos, innerTileSizes)) {
-    int64_t dim = srcShape[projectedPos];
-    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
-      return failure();
+  if (!isProjectedDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
+                                           innerTileSizes)) {
+    return failure();
   }
   // Expand the outer dims permutation with the associated source dims for the
   // new permutation after bubbling. This is because moving a collapsed dim is
@@ -646,7 +657,9 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
       packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
 
   SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
-  // First build reassociations on the outer dims after the permutation.
+  // First apply the permutation on the reassociations of the outer dims.
+  // For example given the permutation [1, 0], the reassociations: [[0, 1], [2]]
+  // -> [[0], [1, 2]]
   int64_t lastPos =
       applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
   // Then add direct mapping for the inner tile dims.
@@ -698,6 +711,25 @@ class BubbleUpPackOpThroughReshapeOp final
   ControlPropagationFn controlFn;
 };
 
+/// Push down unpack op through expand shape op when the packed dims can be
+/// projected to the dims after expanding. This is possible when the inner tile
+/// sizes can divide the projected dims.
+///
+/// For example:
+///
+/// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
+///     inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
+///     : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
+/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
+///     : tensor<?x256xf32> into tensor<?x256x256xf32>
+///
+/// Can be transformed into:
+///
+/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
+///     : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
+/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
+///     inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
+///     : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
 static LogicalResult
 pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
                                    tensor::ExpandShapeOp expandOp,
@@ -712,15 +744,9 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
   SmallVector<int64_t> projectedInnerDimsPos =
       projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
 
-  // Check if the projected dims on the dest are divisible by the inner tile
-  // sizes.
-  for (auto [projectedPos, tileSize] :
-       llvm::zip_equal(projectedInnerDimsPos, innerTileSizes)) {
-    int64_t dim = dstShape[projectedPos];
-    if (ShapedType::isDynamic(dim) ||
-        (dstShape[projectedPos] % tileSize) != 0) {
-      return failure();
-    }
+  if (!isProjectedDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
+                                           innerTileSizes)) {
+    return failure();
   }
   // Expand the outer dims permutation with the associated expanded dims for the
   // new permutation after pushing. This is because moving a source dim is
@@ -733,7 +759,9 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
   }
 
   SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
-  // First build reassociations on the outer dims after the permutation.
+  // First apply the permutation on the reassociations of the outer dims.
+  // For example given the permutation [1, 0], the reassociations: [[0, 1], [2]]
+  // -> [[0], [1, 2]]
   int64_t lastPos =
       applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
   // Then add direct mapping for the inner tile dims.
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 0344c483226af6..0c6977139402b1 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -906,12 +906,25 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME:      into %[[UNPACK_NEW_DEST]]
 // CHECK:         return %[[UNPACK]] : tensor<16x540x960xi32>
 
-func.func @bubble_up_pack_through_collapse(%1: tensor<192x16x64x4xf32>) -> tensor<384x256x8x1xf32> {
-  %collapsed = tensor.collapse_shape %1 [[0, 1], [2, 3]] : tensor<192x16x64x4xf32> into tensor<3072x256xf32>
-  %2 = tensor.empty() : tensor<384x256x8x1xf32>
-  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<3072x256xf32> -> tensor<384x256x8x1xf32>
-  func.return %pack : tensor<384x256x8x1xf32>
+// -----
+
+func.func @bubble_up_pack_through_collapse(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x4x8x1xf32> {
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
+  %2 = tensor.empty(%dim) : tensor<?x4x8x1xf32>
+  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+  func.return %pack : tensor<?x4x8x1xf32>
 }
+// CHECK-LABEL: func.func @bubble_up_pack_through_collapse
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
+// CHECK:         return %[[COLLAPSED]] : tensor<?x4x8x1xf32>
+
+// -----
 
 func.func @bubble_up_permuted_pack_through_collapse(%1: tensor<4x192x16x256xf32>) -> tensor<4x32x3072x8x1xf32> {
   %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3]] : tensor<4x192x16x256xf32> into tensor<4x3072x256xf32>
@@ -919,6 +932,14 @@ func.func @bubble_up_permuted_pack_through_collapse(%1: tensor<4x192x16x256xf32>
   %pack = tensor.pack %collapsed outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 1] into %2 : tensor<4x3072x256xf32> -> tensor<4x32x3072x8x1xf32>
   func.return %pack : tensor<4x32x3072x8x1xf32>
 }
+// CHECK-LABEL: func.func @bubble_up_permuted_pack_through_collapse
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<4x32x192x16x8x1xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<4x192x16x256xf32> -> tensor<4x32x192x16x8x1xf32>
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %pack {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x192x16x8x1xf32> into tensor<4x32x3072x8x1xf32>
+// CHECK:         return %[[COLLAPSED]] : tensor<4x32x3072x8x1xf32>
+
+// -----
 
 func.func @bubble_up_pack_through_unit_collapse(%1: tensor<1x64x1x4xf32>) -> tensor<8x4x8x1xf32> {
   %collapsed = tensor.collapse_shape %1 [[0, 1, 2], [3]] : tensor<1x64x1x4xf32> into tensor<64x4xf32>
@@ -926,6 +947,14 @@ func.func @bubble_up_pack_through_unit_collapse(%1: tensor<1x64x1x4xf32>) -> ten
   %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<64x4xf32> -> tensor<8x4x8x1xf32>
   func.return %pack : tensor<8x4x8x1xf32>
 }
+// CHECK-LABEL: func.func @bubble_up_pack_through_unit_collapse
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x8x1x4x8x1xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<1x64x1x4xf32> -> tensor<1x8x1x4x8x1xf32>
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<1x8x1x4x8x1xf32> into tensor<8x4x8x1xf32>
+// CHECK:         return %[[COLLAPSED]] : tensor<8x4x8x1xf32>
+
+// -----
 
 func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4xf32>) -> tensor<384x32x8x8xf32> {
   %collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
@@ -933,13 +962,31 @@ func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4
   %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %2 : tensor<3072x256xf32> -> tensor<384x32x8x8xf32>
   func.return %pack : tensor<384x32x8x8xf32>
 }
+// CHECK-LABEL: func.func @no_bubble_up_pack_through_non_divisible_collapse
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[COLLAPSED]]
+// CHECK:         return %[[PACK]] : tensor<384x32x8x8xf32>
 
-func.func @push_down_unpack_through_expand(%5: tensor<384x32x8x8xf32>) -> tensor<12x256x256xf32> {
-  %6 = tensor.empty() : tensor<3072x256xf32>
-  %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>
-  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<3072x256xf32> into tensor<12x256x256xf32>
-  func.return %expanded : tensor<12x256x256xf32>
+// -----
+
+func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index) -> tensor<?x256x256xf32> {
+  %6 = tensor.empty(%dim) : tensor<?x256xf32>
+  %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<?x256xf32> into tensor<?x256x256xf32>
+  func.return %expanded : tensor<?x256x256xf32>
 }
+// CHECK-LABEL: func.func @push_down_unpack_through_expand
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
+// CHECK:         %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
+// CHECK:         return %[[UNPACK]] : tensor<?x256x256xf32>
+
+// -----
 
 func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> {
   %6 = tensor.empty() : tensor<4x3072x256xf32>
@@ -947,6 +994,14 @@ func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>
   %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
   func.return %expanded : tensor<4x12x256x256xf32>
 }
+// CHECK-LABEL: @push_down_permuted_unpack_through_expand
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32>
+// CHECK:         %[[UNPACL:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32>
+// CHECK:         return %[[UNPACK]] : tensor<4x12x256x256xf32>
+
+// -----
 
 func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> tensor<3x16x1x256xf32> {
   %6 = tensor.empty() : tensor<48x256xf32>
@@ -954,6 +1009,14 @@ func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> ten
   %expanded = tensor.expand_shape %unpack [[0, 1, 2], [3]] : tensor<48x256xf32> into tensor<3x16x1x256xf32>
   func.return %expanded : tensor<3x16x1x256xf32>
 }
+// CHECK-LABEL: func.func @push_down_unpack_through_unit_expand
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<6x32x8x8xf32> into tensor<3x2x1x32x8x8xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<3x16x1x256xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<3x2x1x32x8x8xf32> -> tensor<3x16x1x256xf32>
+// CHECK:         return %[[UNPACK]] : tensor<3x16x1x256xf32>
+
+// -----
 
 func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x8xf32>) -> tensor<256x12x256xf32> {
   %6 = tensor.empty() : tensor<3072x256xf32>
@@ -961,3 +1024,8 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
   %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32>
   func.return %expanded : tensor<256x12x256xf32>
 }
+// CHECK-LABEL: func.func @no_push_down_unpack_through_non_divisible_expand
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<256x12x256xf32>

>From 46b89b5b1647a9467c89ae15703e2538a1884ffb Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Tue, 19 Mar 2024 17:24:34 +0000
Subject: [PATCH 6/6] Refactor and fix tests

---
 .../Transforms/DataLayoutPropagation.cpp      | 54 +++++++++++++------
 .../Linalg/data-layout-propagation.mlir       |  2 +-
 2 files changed, 40 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 9b76da2cf97368..82a27484e31c4a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -553,17 +553,24 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
   ControlPropagationFn controlFn;
 };
 
+/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
+/// For example: Given dimsPos: [0, 1], reassocIndices: [[0, 1], [2, 3]], and
+/// targetShape: [3, 4, 5, 1], it returns [1, 2]. This is because for pos 0,
+/// the inner-most non-unit projected dim in [0, 1] is 1, and for pos 1, the
+/// inner-most non-unit projected dim in [2, 3] is 2.
+///
+/// If all projected dims are unit dims, it chooses the inner-most dim pos.
 static SmallVector<int64_t>
 projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
                                  ArrayRef<ReassociationIndices> reassocIndices,
-                                 ArrayRef<int64_t> baseShape) {
+                                 ArrayRef<int64_t> targetShape) {
   SmallVector<int64_t> projectedDimsPos;
   for (auto pos : dimsPos) {
     // In the case all dims are unit, this will return the inner-most one.
     int64_t projectedPos = reassocIndices[pos].back();
     for (auto it = reassocIndices[pos].rbegin();
          it != reassocIndices[pos].rend(); ++it) {
-      int64_t dim = baseShape[*it];
+      int64_t dim = targetShape[*it];
       if (dim > 1 || ShapedType::isDynamic(dim)) {
         projectedPos = *it;
         break;
@@ -574,24 +581,27 @@ projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
   return projectedDimsPos;
 }
 
-static bool
-isProjectedDimsDivisibleByTileSizes(ArrayRef<int64_t> projectedDimsPos,
-                                    ArrayRef<int64_t> targetShape,
-                                    ArrayRef<int64_t> tileSizes) {
-  for (auto [projectedPos, tileSize] :
-       llvm::zip_equal(projectedDimsPos, tileSizes)) {
-    int64_t dim = targetShape[projectedPos];
+/// Check if all dims at dimsPos are static and divisible by their tile sizes.
+static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
+                                       ArrayRef<int64_t> shape,
+                                       ArrayRef<int64_t> tileSizes) {
+  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
+    int64_t dim = shape[pos];
     if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
       return false;
   }
   return true;
 }
 
+/// Permute the reassociation indices and reindex them in sequential order.
+/// For example: given reassociationIndices: [[0, 1], [2]] and permutation:
+/// [1, 0], it applies the permutation to get [[2], [0, 1]] and reindexes the
+/// indices into [[0], [1, 2]].
 static int64_t applyPermutationAndReindexReassoc(
     SmallVector<ReassociationIndices> &reassociationIndices,
-    ArrayRef<int64_t> dimsPerm) {
+    ArrayRef<int64_t> permutation) {
   applyPermutationToVector<ReassociationIndices>(reassociationIndices,
-                                                 dimsPerm);
+                                                 permutation);
   int64_t lastPos = 0;
   for (ReassociationIndices &indices : reassociationIndices) {
     for (auto &index : indices) {
@@ -632,11 +642,18 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
   ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
   SmallVector<ReassociationIndices> reassocIndices =
       collapseOp.getReassociationIndices();
+  // Project inner tile pos to the dim pos before collapsing. For example, if
+  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
+  // to packing on dim y.
+  //
+  // Project to inner-most non-unit dims to increase the chance that they can be
+  // divided by the inner tile sizes, while preserving correctness. This is
+  // because in [..., x, 1], packing on dim 1 is equivalent to packing on dim x.
   SmallVector<int64_t> projectedInnerDimsPos =
       projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
 
-  if (!isProjectedDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
-                                           innerTileSizes)) {
+  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
+                                  innerTileSizes)) {
     return failure();
   }
   // Expand the outer dims permutation with the associated source dims for the
@@ -741,11 +758,18 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
   ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
   SmallVector<ReassociationIndices> reassocIndices =
       expandOp.getReassociationIndices();
+  // Project inner tile pos to the dim pos after expanding. For example, if dim
+  // [z] is expanded into [x, y], unpacking on dim z can be projected to
+  // unpacking on dim y.
+  //
+  // Project to inner-most non-unit dims to increase the chance that they can be
+  // divided by the inner tile sizes, while preserving correctness. This is
+  // because in [..., x, 1], unpacking on dim 1 is equivalent to unpacking on x.
   SmallVector<int64_t> projectedInnerDimsPos =
       projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
 
-  if (!isProjectedDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
-                                           innerTileSizes)) {
+  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
+                                  innerTileSizes)) {
     return failure();
   }
   // Expand the outer dims permutation with the associated expanded dims for the
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 0c6977139402b1..10c9f5bafb5c03 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -998,7 +998,7 @@ func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>
 // CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32>
 // CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32>
-// CHECK:         %[[UNPACL:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32>
 // CHECK:         return %[[UNPACK]] : tensor<4x12x256x256xf32>
 
 // -----



More information about the Mlir-commits mailing list