[Mlir-commits] [mlir] [MLIR] Add patterns to bubble-up pack and push-down unpack through collapse/expand shape ops (PR #85297)
Jerry Wu
llvmlistbot at llvm.org
Wed Mar 27 10:46:41 PDT 2024
https://github.com/pzread updated https://github.com/llvm/llvm-project/pull/85297
>From f20ca9d10e07418f1a776f19858fdfdf1f9e0342 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Fri, 8 Mar 2024 23:59:47 +0000
Subject: [PATCH 1/8] Test collapse pack and unpack expand
---
.../Transforms/DataLayoutPropagation.cpp | 189 +++++++++++++++++-
1 file changed, 188 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 5ceb85e7d9903b..4dc52891f4510c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -552,6 +552,192 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
ControlPropagationFn controlFn;
};
+/// Bubbles `packOp` up through its producer `collapseOp`: the collapse source
+/// is packed directly (tiling, for each packed dim, the last source dim of its
+/// collapse group) and the packed result is then collapsed back to `packOp`'s
+/// result type. Returns failure without rewriting if any inner tile size is
+/// dynamic, or if a selected source dim is dynamic or not divisible by its tile.
+static LogicalResult
+bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
+ tensor::PackOp packOp,
+ PatternRewriter &rewriter) {
+ SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
+ ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+ ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+
+ // Only static inner tile sizes are supported.
+ if (llvm::any_of(innerTileSizes,
+ [](int64_t size) { return ShapedType::isDynamic(size); })) {
+ return failure();
+ }
+
+ ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
+ SmallVector<ReassociationIndices> reassocIndices =
+ collapseOp.getReassociationIndices();
+ // Map each packed dim to the last (innermost) source dim of its collapse
+ // group. A trailing unit dim is selected as-is, so the divisibility check
+ // below rejects such a group unless its tile size is 1.
+ SmallVector<int64_t> baseDimsPos;
+ for (auto pos : innerDimsPos) {
+ baseDimsPos.push_back(reassocIndices[pos].back());
+ }
+ // Check if the base dims before reassociation are divisible by the inner tile
+ // sizes.
+ for (auto [basePos, tileSize] :
+ llvm::zip_equal(baseDimsPos, innerTileSizes)) {
+ int64_t dim = srcShape[basePos];
+ if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0) {
+ return failure();
+ }
+ }
+ // Expand the outer dims perm with the associated source dims: moving a
+ // collapsed dim moves all source dims of its group together.
+ SmallVector<int64_t> newOuterDimsPerm;
+ for (auto outerPos : outerDimsPerm) {
+ newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+ reassocIndices[outerPos].begin(),
+ reassocIndices[outerPos].end());
+ }
+
+ auto emptyOp = tensor::PackOp::createDestinationTensor(
+ rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(), baseDimsPos,
+ newOuterDimsPerm);
+ auto newPackOp = rewriter.create<tensor::PackOp>(
+ packOp.getLoc(), collapseOp.getSrc(), emptyOp, baseDimsPos, packOp.getMixedTiles(),
+ packOp.getPaddingValue(), newOuterDimsPerm);
+
+ // Reassociation for the new collapse: one contiguous group per outer dim in
+ // permuted order, followed by one singleton group per inner tile dim.
+ // NOTE(review): the outer groups come only from the entries of
+ // outerDimsPerm. If the pack has no outer_dims_perm (presumably an empty
+ // ArrayRef here — confirm), no outer groups are emitted and this
+ // reassociation does not cover the outer dims of the new pack.
+ SmallVector<ReassociationIndices> newReassocIndices;
+ int64_t currPos = 0;
+ for (auto outerPos : outerDimsPerm) {
+ int64_t start = currPos;
+ int64_t end = start + reassocIndices[outerPos].size();
+ newReassocIndices.push_back(llvm::to_vector(llvm::seq(start, end)));
+ currPos = end;
+ }
+ for (auto unused : innerTileSizes) {
+ (void)unused;
+ newReassocIndices.push_back({currPos});
+ currPos += 1;
+ }
+
+ auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
+ collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
+ rewriter.replaceOp(packOp, newCollapseOp);
+
+ return success();
+}
+
+/// Rewrite pattern that moves a tensor.pack above its producer when that
+/// producer is a tensor.collapse_shape (see bubbleUpPackOpThroughCollapseShape).
+/// It only applies when the pack has no padding value and the producer has a
+/// single result whose only use is this pack.
+class BubbleUpPackOpThroughReshapeOp final
+ : public OpRewritePattern<tensor::PackOp> {
+public:
+ BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
+ : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
+
+ LogicalResult matchAndRewrite(tensor::PackOp packOp,
+ PatternRewriter &rewriter) const override {
+ // Packs that carry a padding value are not handled.
+ if (packOp.getPaddingValue())
+ return failure();
+
+ // Require a single-result producer whose result is used only by this pack.
+ Operation *srcOp = packOp.getSource().getDefiningOp();
+ if (!srcOp || !(srcOp->getNumResults() == 1) ||
+ !srcOp->getResult(0).hasOneUse())
+ return failure();
+
+ if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(srcOp)) {
+ return bubbleUpPackOpThroughCollapseShape(collapseOp, packOp, rewriter);
+ }
+ return failure();
+ }
+
+private:
+ // NOTE(review): stored but not consulted by matchAndRewrite in this revision.
+ ControlPropagationFn controlFn;
+};
+
+/// Pushes `unPackOp` down through its user `expandOp`: the packed source is
+/// expanded first (into the type tensor::PackOp::inferPackedType computes for
+/// the expanded shape), then unpacked into the expand's result shape, and
+/// `expandOp` is replaced by the new unpack. For each unpacked dim, the tile
+/// applies to the last dim of its expand group. Returns failure without
+/// rewriting if any inner tile size is dynamic, or if a selected expanded dim
+/// is dynamic or not divisible by its tile.
+static LogicalResult
+pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
+ tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) {
+
+ SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
+ ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
+ ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
+
+ // Only static inner tile sizes are supported.
+ if (llvm::any_of(innerTileSizes,
+ [](int64_t size) { return ShapedType::isDynamic(size); })) {
+ return failure();
+ }
+
+ ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
+ SmallVector<ReassociationIndices> reassocIndices =
+ expandOp.getReassociationIndices();
+ // Map each unpacked dim to the last (innermost) expanded dim of its group.
+ SmallVector<int64_t> baseDimsPos;
+ for (auto pos : innerDimsPos) {
+ baseDimsPos.push_back(reassocIndices[pos].back());
+ }
+ // Check if the base dims after reassociation are divisible by the inner tile
+ // sizes.
+ for (auto [basePos, tileSize] :
+ llvm::zip_equal(baseDimsPos, innerTileSizes)) {
+ int64_t dim = dstShape[basePos];
+ if (ShapedType::isDynamic(dim) || dstShape[basePos] % tileSize != 0) {
+ return failure();
+ }
+ }
+ // Expand the outer dims perm with the associated expanded dims: moving a
+ // source dim moves all dims it expands into together.
+ SmallVector<int64_t> newOuterDimsPerm;
+ for (auto outerPos : outerDimsPerm) {
+ newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+ reassocIndices[outerPos].begin(),
+ reassocIndices[outerPos].end());
+ }
+
+ // Reassociation for the new expand: one contiguous group per outer dim in
+ // permuted order, followed by one singleton group per inner tile dim.
+ // NOTE(review): the outer groups come only from the entries of
+ // outerDimsPerm. If the unpack has no outer_dims_perm (presumably an empty
+ // ArrayRef here — confirm), no outer groups are emitted and this
+ // reassociation does not cover the packed source's outer dims.
+ SmallVector<ReassociationIndices> newReassocIndices;
+ int64_t currPos = 0;
+ for (auto outerPos : outerDimsPerm) {
+ int64_t start = currPos;
+ int64_t end = start + reassocIndices[outerPos].size();
+ newReassocIndices.push_back(llvm::to_vector(llvm::seq(start, end)));
+ currPos = end;
+ }
+ for (auto unused : innerTileSizes) {
+ (void)unused;
+ newReassocIndices.push_back({currPos});
+ currPos += 1;
+ }
+
+ RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
+ expandOp.getType(), innerTileSizes, baseDimsPos, newOuterDimsPerm);
+ auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+ expandOp.getLoc(), newExpandType, unPackOp.getSource(),
+ newReassocIndices);
+
+ auto emptyOp = tensor::UnPackOp::createDestinationTensor(
+ rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(), baseDimsPos,
+ newOuterDimsPerm);
+ auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
+ unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, baseDimsPos,
+ unPackOp.getMixedTiles(), newOuterDimsPerm);
+ // Only the expand is replaced; the original unpack is left without users.
+ rewriter.replaceOp(expandOp, newUnPackOp);
+
+ return success();
+}
+
+/// Rewrite pattern that moves a tensor.unpack below its user when that user is
+/// a tensor.expand_shape (see pushDownUnPackOpThroughExpandShape). It only
+/// applies when the unpack result has exactly one use.
+class PushDownUnPackOpThroughReshapeOp final
+ : public OpRewritePattern<tensor::UnPackOp> {
+public:
+ PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
+ ControlPropagationFn fun)
+ : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
+ }
+
+ LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
+ PatternRewriter &rewriter) const override {
+ // Only handle an unpack whose result has a single user.
+ Value result = unPackOp.getResult();
+ if (!result.hasOneUse()) {
+ return failure();
+ }
+ Operation *userOp = *result.user_begin();
+
+ if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(userOp)) {
+ return pushDownUnPackOpThroughExpandShape(unPackOp, expandOp, rewriter);
+ }
+ return failure();
+ }
+
+private:
+ // NOTE(review): stored but not consulted by matchAndRewrite in this revision.
+ ControlPropagationFn controlFn;
+};
+
// TODO: Relax this restriction. We should unpack a generic op also
// in the presence of multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
@@ -774,6 +960,7 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
const ControlPropagationFn &controlPackUnPackPropagation) {
patterns
.insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
- PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
+ BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
+ PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
}
>From 4f2da368e57a8e06af61ace8c03a8184c96a22a5 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Mon, 11 Mar 2024 21:25:34 +0000
Subject: [PATCH 2/8] Handle unit dim
---
.../Transforms/DataLayoutPropagation.cpp | 46 +++++++++++++------
1 file changed, 32 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 4dc52891f4510c..e230a11f9f2c0e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -552,6 +552,26 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
ControlPropagationFn controlFn;
};
+/// For each position in `dimsPos`, returns the index (into `baseShape`) of the
+/// innermost dim in reassociation group `reassocIndices[pos]` whose size is
+/// greater than 1. If no dim of the group qualifies, the reverse scan ends on
+/// the group's first (outermost) index, and that index is returned.
+static SmallVector<int64_t>
+projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
+ ArrayRef<ReassociationIndices> reassocIndices,
+ ArrayRef<int64_t> baseShape) {
+ SmallVector<int64_t> projectedDimsPos;
+ for (auto pos : dimsPos) {
+ int64_t projectedPos = -1;
+ // Scan the group from its innermost dim outwards.
+ // NOTE(review): ShapedType::kDynamic is a negative sentinel in current MLIR
+ // (confirm). If so, a dynamic dim fails the `> 1` test below and is skipped
+ // like a unit dim, so the projection can land on an outer static dim of the
+ // group instead of the dynamic innermost one.
+ for (auto it = reassocIndices[pos].rbegin();
+ it != reassocIndices[pos].rend(); ++it) {
+ projectedPos = *it;
+ if (baseShape[projectedPos] > 1) {
+ break;
+ }
+ }
+ // projectedPos is still -1 here only when the group is empty.
+ assert(projectedPos != -1 && "projected dim not found");
+ projectedDimsPos.push_back(projectedPos);
+ }
+ return projectedDimsPos;
+}
+
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
tensor::PackOp packOp,
@@ -568,10 +588,9 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
SmallVector<ReassociationIndices> reassocIndices =
collapseOp.getReassociationIndices();
- SmallVector<int64_t> baseDimsPos;
- for (auto pos : innerDimsPos) {
- baseDimsPos.push_back(reassocIndices[pos].back());
- }
+ SmallVector<int64_t> baseDimsPos =
+ projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
+
// Check if the base dims before reassociation are divisible by the inner tile
// sizes.
for (auto [basePos, tileSize] :
@@ -590,11 +609,11 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
}
auto emptyOp = tensor::PackOp::createDestinationTensor(
- rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(), baseDimsPos,
- newOuterDimsPerm);
+ rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
+ baseDimsPos, newOuterDimsPerm);
auto newPackOp = rewriter.create<tensor::PackOp>(
- packOp.getLoc(), collapseOp.getSrc(), emptyOp, baseDimsPos, packOp.getMixedTiles(),
- packOp.getPaddingValue(), newOuterDimsPerm);
+ packOp.getLoc(), collapseOp.getSrc(), emptyOp, baseDimsPos,
+ packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
SmallVector<ReassociationIndices> newReassocIndices;
int64_t currPos = 0;
@@ -660,10 +679,9 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
SmallVector<ReassociationIndices> reassocIndices =
expandOp.getReassociationIndices();
- SmallVector<int64_t> baseDimsPos;
- for (auto pos : innerDimsPos) {
- baseDimsPos.push_back(reassocIndices[pos].back());
- }
+ SmallVector<int64_t> baseDimsPos =
+ projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
+
// Check if the base dims after reassociation are divisible by the inner tile
// sizes.
for (auto [basePos, tileSize] :
@@ -702,8 +720,8 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
newReassocIndices);
auto emptyOp = tensor::UnPackOp::createDestinationTensor(
- rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(), baseDimsPos,
- newOuterDimsPerm);
+ rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
+ baseDimsPos, newOuterDimsPerm);
auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, baseDimsPos,
unPackOp.getMixedTiles(), newOuterDimsPerm);
>From a49d91d8284593f607fa212ba9ae205d4115341e Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 14 Mar 2024 19:10:54 +0000
Subject: [PATCH 3/8] Add test draft
---
.../Linalg/data-layout-propagation.mlir | 56 +++++++++++++++++++
1 file changed, 56 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index e036695a2ac9fd..0344c483226af6 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -905,3 +905,59 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
// CHECK-SAME: into %[[UNPACK_NEW_DEST]]
// CHECK: return %[[UNPACK]] : tensor<16x540x960xi32>
+
+// Pack tiles collapsed dim 0 by 8 and dim 1 by 1; the last source dims of the
+// two collapse groups (16 and 4) are divisible by those tiles, so the pack is
+// expected to bubble up above the collapse.
+func.func @bubble_up_pack_through_collapse(%1: tensor<192x16x64x4xf32>) -> tensor<384x256x8x1xf32> {
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2, 3]] : tensor<192x16x64x4xf32> into tensor<3072x256xf32>
+ %2 = tensor.empty() : tensor<384x256x8x1xf32>
+ %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<3072x256xf32> -> tensor<384x256x8x1xf32>
+ func.return %pack : tensor<384x256x8x1xf32>
+}
+
+// Permuted variant: outer_dims_perm = [0, 2, 1] and inner_dims_pos = [2, 1].
+// The tiled dims map to source dims of size 256 (tile 8) and 16 (tile 1).
+func.func @bubble_up_permuted_pack_through_collapse(%1: tensor<4x192x16x256xf32>) -> tensor<4x32x3072x8x1xf32> {
+ %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3]] : tensor<4x192x16x256xf32> into tensor<4x3072x256xf32>
+ %2 = tensor.empty() : tensor<4x32x3072x8x1xf32>
+ %pack = tensor.pack %collapsed outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 1] into %2 : tensor<4x3072x256xf32> -> tensor<4x32x3072x8x1xf32>
+ func.return %pack : tensor<4x32x3072x8x1xf32>
+}
+
+// Collapse group [0, 1, 2] of 1x64x1 contains unit dims; packed dim 0 should be
+// projected onto the non-unit source dim (64), which is divisible by 8.
+func.func @bubble_up_pack_through_unit_collapse(%1: tensor<1x64x1x4xf32>) -> tensor<8x4x8x1xf32> {
+ %collapsed = tensor.collapse_shape %1 [[0, 1, 2], [3]] : tensor<1x64x1x4xf32> into tensor<64x4xf32>
+ %2 = tensor.empty() : tensor<8x4x8x1xf32>
+ %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<64x4xf32> -> tensor<8x4x8x1xf32>
+ func.return %pack : tensor<8x4x8x1xf32>
+}
+
+// Negative case: collapsed dim 1 (256) is tiled by 8, but the last source dim
+// of its group (4) is not divisible by 8, so no bubbling is expected.
+func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4xf32>) -> tensor<384x32x8x8xf32> {
+ %collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
+ %2 = tensor.empty() : tensor<384x32x8x8xf32>
+ %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %2 : tensor<3072x256xf32> -> tensor<384x32x8x8xf32>
+ func.return %pack : tensor<384x32x8x8xf32>
+}
+
+// Unpack followed by an expand of dim 0 (3072 -> 12x256); the last expanded
+// dims of both groups (256 and 256) are divisible by the tile size 8.
+func.func @push_down_unpack_through_expand(%5: tensor<384x32x8x8xf32>) -> tensor<12x256x256xf32> {
+ %6 = tensor.empty() : tensor<3072x256xf32>
+ %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<3072x256xf32> into tensor<12x256x256xf32>
+ func.return %expanded : tensor<12x256x256xf32>
+}
+
+// Permuted variant: outer_dims_perm = [0, 2, 1] and inner_dims_pos = [2, 1];
+// dim 1 (3072) expands to 12x256 and both tiled dims end in a 256 dim.
+func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> {
+ %6 = tensor.empty() : tensor<4x3072x256xf32>
+ %unpack = tensor.unpack %5 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %6 : tensor<4x32x384x8x8xf32> -> tensor<4x3072x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
+ func.return %expanded : tensor<4x12x256x256xf32>
+}
+
+// Expand 48 -> 3x16x1: unpacked dim 0 should be projected past the trailing
+// unit dim onto 16, which is divisible by 8.
+func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> tensor<3x16x1x256xf32> {
+ %6 = tensor.empty() : tensor<48x256xf32>
+ %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<6x32x8x8xf32> -> tensor<48x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0, 1, 2], [3]] : tensor<48x256xf32> into tensor<3x16x1x256xf32>
+ func.return %expanded : tensor<3x16x1x256xf32>
+}
+
+// Negative case: 3072 expands to 256x12 and the last expanded dim of that
+// group (12) is not divisible by 8, so no push-down is expected.
+func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x8xf32>) -> tensor<256x12x256xf32> {
+ %6 = tensor.empty() : tensor<3072x256xf32>
+ %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32>
+ func.return %expanded : tensor<256x12x256xf32>
+}
>From 92020144089fee6ab328c25eece55d18072cedc5 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 14 Mar 2024 20:56:52 +0000
Subject: [PATCH 4/8] Refactor
---
.../Transforms/DataLayoutPropagation.cpp | 178 +++++++++++-------
1 file changed, 110 insertions(+), 68 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index e230a11f9f2c0e..0d53205b8170c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -572,6 +573,39 @@ projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
return projectedDimsPos;
}
+/// Permutes the reassociation groups by `dimsPerm` (group i of the result is
+/// group dimsPerm[i] of the input), then renumbers every index so the groups
+/// are contiguous and start at 0; group sizes are kept. For example, with
+/// dimsPerm = [1, 0], [[0, 1], [2]] becomes [[0], [1, 2]].
+///
+/// An empty `dimsPerm` is treated as the identity, so the groups are only
+/// renumbered. Pack/unpack ops without an explicit outer_dims_perm presumably
+/// report an empty permutation, and applyPermutationToVector (in current MLIR)
+/// asserts that the permutation has the same size as the vector it permutes.
+///
+/// Returns the total number of indices, i.e. the first free dim position after
+/// the permuted groups.
+static int64_t applyPermutationAndReindexReassoc(
+ SmallVector<ReassociationIndices> &reassociationIndices,
+ ArrayRef<int64_t> dimsPerm) {
+ // Empty permutation means identity: nothing to reorder.
+ if (!dimsPerm.empty())
+ applyPermutationToVector<ReassociationIndices>(reassociationIndices,
+ dimsPerm);
+ int64_t lastPos = 0;
+ for (ReassociationIndices &indices : reassociationIndices) {
+ for (auto &index : indices) {
+ index = lastPos;
+ lastPos += 1;
+ }
+ }
+ return lastPos;
+}
+
+/// Bubble up pack op through collapse shape op when the packed dims can be
+/// mapped to the source dims before collapsing. This is possible when the inner
+/// tile sizes can divide the mapped source dims.
+///
+/// For example:
+///
+/// %collapsed = tensor.collapse_shape %in [[0, 1], 2] : tensor<?x16x4xf32> into
+/// tensor<?x4xf32> %out = tensor.empty() : tensor<?x4x8x1xf32> %pack =
+/// tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
+/// inner_tiles = [8, 1] into %out : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+///
+/// Can be transformed into:
+///
+/// %out = tensor.empty() : tensor<?x2x4x8x1xf32>
+/// %pack = tensor.pack %in outer_dims_perm = [1, 2] inner_dims_pos = [1, 2]
+/// inner_tiles = [8, 1] into %out : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+/// %collapsed = tensor.collapse_shape %1 [[0, 1], 2, 3, 4] :
+/// tensor<?x2x4x8x1xf32> into tensor<?x4x8x1>
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
tensor::PackOp packOp,
@@ -580,27 +614,23 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
- if (llvm::any_of(innerTileSizes,
- [](int64_t size) { return ShapedType::isDynamic(size); })) {
- return failure();
- }
-
ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
SmallVector<ReassociationIndices> reassocIndices =
collapseOp.getReassociationIndices();
- SmallVector<int64_t> baseDimsPos =
+ SmallVector<int64_t> projectedInnerDimsPos =
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
- // Check if the base dims before reassociation are divisible by the inner tile
+ // Check if the projected dims on the source are divisible by the inner tile
// sizes.
- for (auto [basePos, tileSize] :
- llvm::zip_equal(baseDimsPos, innerTileSizes)) {
- int64_t dim = srcShape[basePos];
- if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0) {
+ for (auto [projectedPos, tileSize] :
+ llvm::zip_equal(projectedInnerDimsPos, innerTileSizes)) {
+ int64_t dim = srcShape[projectedPos];
+ if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
return failure();
- }
}
- // Expand the outer dims perm with associated src dims.
+ // Expand the outer dims permutation with the associated source dims for the
+ // new permutation after bubbling. This is because moving a collapsed dim is
+ // equivalent to moving the associated source dims together.
SmallVector<int64_t> newOuterDimsPerm;
for (auto outerPos : outerDimsPerm) {
newOuterDimsPerm.insert(newOuterDimsPerm.end(),
@@ -610,23 +640,19 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
auto emptyOp = tensor::PackOp::createDestinationTensor(
rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
- baseDimsPos, newOuterDimsPerm);
+ projectedInnerDimsPos, newOuterDimsPerm);
auto newPackOp = rewriter.create<tensor::PackOp>(
- packOp.getLoc(), collapseOp.getSrc(), emptyOp, baseDimsPos,
+ packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
- SmallVector<ReassociationIndices> newReassocIndices;
- int64_t currPos = 0;
- for (auto outerPos : outerDimsPerm) {
- int64_t start = currPos;
- int64_t end = start + reassocIndices[outerPos].size();
- newReassocIndices.push_back(llvm::to_vector(llvm::seq(start, end)));
- currPos = end;
- }
- for (auto unused : innerTileSizes) {
- (void)unused;
- newReassocIndices.push_back({currPos});
- currPos += 1;
+ SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
+ // First build reassociations on the outer dims after the permutation.
+ int64_t lastPos =
+ applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
+ // Then add direct mapping for the inner tile dims.
+ for (size_t i = 0; i < innerDimsPos.size(); ++i) {
+ newReassocIndices.push_back({lastPos});
+ lastPos += 1;
}
auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
@@ -644,18 +670,28 @@ class BubbleUpPackOpThroughReshapeOp final
LogicalResult matchAndRewrite(tensor::PackOp packOp,
PatternRewriter &rewriter) const override {
- if (packOp.getPaddingValue())
+ // User controlled propagation function.
+ if (!controlFn(packOp))
return failure();
Operation *srcOp = packOp.getSource().getDefiningOp();
+ // Currently only support when the pack op is the only user.
if (!srcOp || !(srcOp->getNumResults() == 1) ||
- !srcOp->getResult(0).hasOneUse())
+ !srcOp->getResult(0).hasOneUse()) {
return failure();
-
- if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(srcOp)) {
- return bubbleUpPackOpThroughCollapseShape(collapseOp, packOp, rewriter);
}
- return failure();
+ // Currently only support static inner tile sizes.
+ if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
+ return ShapedType::isDynamic(size);
+ })) {
+ return failure();
+ }
+
+ return TypeSwitch<Operation *, LogicalResult>(srcOp)
+ .Case([&](tensor::CollapseShapeOp op) {
+ return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
+ })
+ .Default([](Operation *) { return failure(); });
}
private:
@@ -666,32 +702,29 @@ static LogicalResult
pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
tensor::ExpandShapeOp expandOp,
PatternRewriter &rewriter) {
-
SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
- if (llvm::any_of(innerTileSizes,
- [](int64_t size) { return ShapedType::isDynamic(size); })) {
- return failure();
- }
-
ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
SmallVector<ReassociationIndices> reassocIndices =
expandOp.getReassociationIndices();
- SmallVector<int64_t> baseDimsPos =
+ SmallVector<int64_t> projectedInnerDimsPos =
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
- // Check if the base dims after reassociation are divisible by the inner tile
+ // Check if the projected dims on the dest are divisible by the inner tile
// sizes.
- for (auto [basePos, tileSize] :
- llvm::zip_equal(baseDimsPos, innerTileSizes)) {
- int64_t dim = dstShape[basePos];
- if (ShapedType::isDynamic(dim) || dstShape[basePos] % tileSize != 0) {
+ for (auto [projectedPos, tileSize] :
+ llvm::zip_equal(projectedInnerDimsPos, innerTileSizes)) {
+ int64_t dim = dstShape[projectedPos];
+ if (ShapedType::isDynamic(dim) ||
+ (dstShape[projectedPos] % tileSize) != 0) {
return failure();
}
}
- // Expand the outer dims perm with associated src dims.
+ // Expand the outer dims permutation with the associated expanded dims for the
+ // new permutation after pushing. This is because moving a source dim is
+ // equivalent to moving the associated expanded dims together.
SmallVector<int64_t> newOuterDimsPerm;
for (auto outerPos : outerDimsPerm) {
newOuterDimsPerm.insert(newOuterDimsPerm.end(),
@@ -699,32 +732,29 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
reassocIndices[outerPos].end());
}
- SmallVector<ReassociationIndices> newReassocIndices;
- int64_t currPos = 0;
- for (auto outerPos : outerDimsPerm) {
- int64_t start = currPos;
- int64_t end = start + reassocIndices[outerPos].size();
- newReassocIndices.push_back(llvm::to_vector(llvm::seq(start, end)));
- currPos = end;
- }
- for (auto unused : innerTileSizes) {
- (void)unused;
- newReassocIndices.push_back({currPos});
- currPos += 1;
+ SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
+ // First build reassociations on the outer dims after the permutation.
+ int64_t lastPos =
+ applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
+ // Then add direct mapping for the inner tile dims.
+ for (size_t i = 0; i < innerDimsPos.size(); ++i) {
+ newReassocIndices.push_back({lastPos});
+ lastPos += 1;
}
- RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
- expandOp.getType(), innerTileSizes, baseDimsPos, newOuterDimsPerm);
+ RankedTensorType newExpandType =
+ tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
+ projectedInnerDimsPos, newOuterDimsPerm);
auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
expandOp.getLoc(), newExpandType, unPackOp.getSource(),
newReassocIndices);
auto emptyOp = tensor::UnPackOp::createDestinationTensor(
rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
- baseDimsPos, newOuterDimsPerm);
+ projectedInnerDimsPos, newOuterDimsPerm);
auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
- unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, baseDimsPos,
- unPackOp.getMixedTiles(), newOuterDimsPerm);
+ unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
+ projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
rewriter.replaceOp(expandOp, newUnPackOp);
return success();
@@ -740,16 +770,28 @@ class PushDownUnPackOpThroughReshapeOp final
LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
+ // User controlled propagation function.
+ if (!controlFn(unPackOp))
+ return failure();
+
Value result = unPackOp.getResult();
+ // Currently only support unpack op with the single user.
if (!result.hasOneUse()) {
return failure();
}
- Operation *userOp = *result.user_begin();
-
- if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(userOp)) {
- return pushDownUnPackOpThroughExpandShape(unPackOp, expandOp, rewriter);
+ // Currently only support static inner tile sizes.
+ if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
+ return ShapedType::isDynamic(size);
+ })) {
+ return failure();
}
- return failure();
+
+ Operation *userOp = *result.user_begin();
+ return TypeSwitch<Operation *, LogicalResult>(userOp)
+ .Case([&](tensor::ExpandShapeOp op) {
+ return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
+ })
+ .Default([](Operation *) { return failure(); });
}
private:
>From 376f8d578d12709c257e5c2e757d5109a1d528e5 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Mon, 18 Mar 2024 22:45:19 +0000
Subject: [PATCH 5/8] Finish tests
---
.../Transforms/DataLayoutPropagation.cpp | 94 ++++++++++++-------
.../Linalg/data-layout-propagation.mlir | 88 +++++++++++++++--
2 files changed, 139 insertions(+), 43 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0d53205b8170c6..9b76da2cf97368 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -559,20 +559,34 @@ projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
ArrayRef<int64_t> baseShape) {
SmallVector<int64_t> projectedDimsPos;
for (auto pos : dimsPos) {
- int64_t projectedPos = -1;
+ // In the case all dims are unit, this will return the inner-most one.
+ int64_t projectedPos = reassocIndices[pos].back();
for (auto it = reassocIndices[pos].rbegin();
it != reassocIndices[pos].rend(); ++it) {
- projectedPos = *it;
- if (baseShape[projectedPos] > 1) {
+ int64_t dim = baseShape[*it];
+ if (dim > 1 || ShapedType::isDynamic(dim)) {
+ projectedPos = *it;
break;
}
}
- assert(projectedPos != -1 && "projected dim not found");
projectedDimsPos.push_back(projectedPos);
}
return projectedDimsPos;
}
+/// Returns true if, for every (position, tile size) pair zipped from
+/// `projectedDimsPos` and `tileSizes`, the entry of `targetShape` at that
+/// position is a static size that the tile size divides evenly. The two lists
+/// are walked with llvm::zip_equal, which expects them to have equal length.
+static bool
+isProjectedDimsDivisibleByTileSizes(ArrayRef<int64_t> projectedDimsPos,
+ ArrayRef<int64_t> targetShape,
+ ArrayRef<int64_t> tileSizes) {
+ return llvm::all_of(
+ llvm::zip_equal(projectedDimsPos, tileSizes), [&](auto posAndTile) {
+ int64_t size = targetShape[std::get<0>(posAndTile)];
+ // A dynamic size can never be proven divisible.
+ return !ShapedType::isDynamic(size) &&
+ size % std::get<1>(posAndTile) == 0;
+ });
+}
+
static int64_t applyPermutationAndReindexReassoc(
SmallVector<ReassociationIndices> &reassociationIndices,
ArrayRef<int64_t> dimsPerm) {
@@ -589,23 +603,24 @@ static int64_t applyPermutationAndReindexReassoc(
}
/// Bubble up pack op through collapse shape op when the packed dims can be
-/// mapped to the source dims before collapsing. This is possible when the inner
-/// tile sizes can divide the mapped source dims.
+/// projected to the dims before collapsing. This is possible when the inner
+/// tile sizes can divide the projected dims.
///
/// For example:
///
-/// %collapsed = tensor.collapse_shape %in [[0, 1], 2] : tensor<?x16x4xf32> into
-/// tensor<?x4xf32> %out = tensor.empty() : tensor<?x4x8x1xf32> %pack =
-/// tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
-/// inner_tiles = [8, 1] into %out : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+/// %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
+/// : tensor<?x16x4xf32> into tensor<?x4xf32>
+/// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
+/// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
+/// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
///
/// Can be transformed into:
///
-/// %out = tensor.empty() : tensor<?x2x4x8x1xf32>
-/// %pack = tensor.pack %in outer_dims_perm = [1, 2] inner_dims_pos = [1, 2]
-/// inner_tiles = [8, 1] into %out : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
-/// %collapsed = tensor.collapse_shape %1 [[0, 1], 2, 3, 4] :
-/// tensor<?x2x4x8x1xf32> into tensor<?x4x8x1>
+/// %pack = tensor.pack %in outer_dims_perm = [0, 1, 2]
+/// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
+/// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+/// %collapsed = tensor.collapse_shape %pack [[0, 1], [2], [3], [4]]
+/// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
tensor::PackOp packOp,
@@ -620,13 +635,9 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
SmallVector<int64_t> projectedInnerDimsPos =
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
- // Check if the projected dims on the source are divisible by the inner tile
- // sizes.
- for (auto [projectedPos, tileSize] :
- llvm::zip_equal(projectedInnerDimsPos, innerTileSizes)) {
- int64_t dim = srcShape[projectedPos];
- if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
- return failure();
+ if (!isProjectedDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
+ innerTileSizes)) {
+ return failure();
}
// Expand the outer dims permutation with the associated source dims for the
// new permutation after bubbling. This is because moving a collapsed dim is
@@ -646,7 +657,9 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
- // First build reassociations on the outer dims after the permutation.
+ // First apply the permutation on the reassociations of the outer dims.
+ // For example given the permutation [1, 0], the reassociations: [[0, 1], [2]]
+ // -> [[0], [1, 2]]
int64_t lastPos =
applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
// Then add direct mapping for the inner tile dims.
@@ -698,6 +711,25 @@ class BubbleUpPackOpThroughReshapeOp final
ControlPropagationFn controlFn;
};
+/// Push down unpack op through expand shape op when the packed dims can be
+/// projected to the dims after expanding. This is possible when the inner tile
+/// sizes can divide the projected dims.
+///
+/// For example:
+///
+/// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
+/// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
+/// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
+/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
+/// : tensor<?x256xf32> into tensor<?x256x256xf32>
+///
+/// Can be transformed into:
+///
+/// %expanded = tensor.expand_shape %ain [[0, 1], [2], [3], [4]]
+/// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
+/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
+/// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
+/// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
static LogicalResult
pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
tensor::ExpandShapeOp expandOp,
@@ -712,15 +744,9 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
SmallVector<int64_t> projectedInnerDimsPos =
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
- // Check if the projected dims on the dest are divisible by the inner tile
- // sizes.
- for (auto [projectedPos, tileSize] :
- llvm::zip_equal(projectedInnerDimsPos, innerTileSizes)) {
- int64_t dim = dstShape[projectedPos];
- if (ShapedType::isDynamic(dim) ||
- (dstShape[projectedPos] % tileSize) != 0) {
- return failure();
- }
+ if (!isProjectedDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
+ innerTileSizes)) {
+ return failure();
}
// Expand the outer dims permutation with the associated expanded dims for the
// new permutation after pushing. This is because moving a source dim is
@@ -733,7 +759,9 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
}
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
- // First build reassociations on the outer dims after the permutation.
+ // First apply the permutation on the reassociations of the outer dims.
+ // For example given the permutation [1, 0], the reassociations: [[0, 1], [2]]
+ // -> [[0], [1, 2]]
int64_t lastPos =
applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
// Then add direct mapping for the inner tile dims.
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 0344c483226af6..0c6977139402b1 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -906,12 +906,25 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
// CHECK-SAME: into %[[UNPACK_NEW_DEST]]
// CHECK: return %[[UNPACK]] : tensor<16x540x960xi32>
-func.func @bubble_up_pack_through_collapse(%1: tensor<192x16x64x4xf32>) -> tensor<384x256x8x1xf32> {
- %collapsed = tensor.collapse_shape %1 [[0, 1], [2, 3]] : tensor<192x16x64x4xf32> into tensor<3072x256xf32>
- %2 = tensor.empty() : tensor<384x256x8x1xf32>
- %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<3072x256xf32> -> tensor<384x256x8x1xf32>
- func.return %pack : tensor<384x256x8x1xf32>
+// -----
+
+func.func @bubble_up_pack_through_collapse(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x4x8x1xf32> {
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
+ %2 = tensor.empty(%dim) : tensor<?x4x8x1xf32>
+ %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+ func.return %pack : tensor<?x4x8x1xf32>
}
+// CHECK-LABEL: func.func @bubble_up_pack_through_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<?x4x8x1xf32>
+
+// -----
func.func @bubble_up_permuted_pack_through_collapse(%1: tensor<4x192x16x256xf32>) -> tensor<4x32x3072x8x1xf32> {
%collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3]] : tensor<4x192x16x256xf32> into tensor<4x3072x256xf32>
@@ -919,6 +932,14 @@ func.func @bubble_up_permuted_pack_through_collapse(%1: tensor<4x192x16x256xf32>
%pack = tensor.pack %collapsed outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 1] into %2 : tensor<4x3072x256xf32> -> tensor<4x32x3072x8x1xf32>
func.return %pack : tensor<4x32x3072x8x1xf32>
}
+// CHECK-LABEL: func.func @bubble_up_permuted_pack_through_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x32x192x16x8x1xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<4x192x16x256xf32> -> tensor<4x32x192x16x8x1xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x192x16x8x1xf32> into tensor<4x32x3072x8x1xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<4x32x3072x8x1xf32>
+
+// -----
func.func @bubble_up_pack_through_unit_collapse(%1: tensor<1x64x1x4xf32>) -> tensor<8x4x8x1xf32> {
%collapsed = tensor.collapse_shape %1 [[0, 1, 2], [3]] : tensor<1x64x1x4xf32> into tensor<64x4xf32>
@@ -926,6 +947,14 @@ func.func @bubble_up_pack_through_unit_collapse(%1: tensor<1x64x1x4xf32>) -> ten
%pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<64x4xf32> -> tensor<8x4x8x1xf32>
func.return %pack : tensor<8x4x8x1xf32>
}
+// CHECK-LABEL: func.func @bubble_up_pack_through_unit_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x8x1x4x8x1xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<1x64x1x4xf32> -> tensor<1x8x1x4x8x1xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<1x8x1x4x8x1xf32> into tensor<8x4x8x1xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<8x4x8x1xf32>
+
+// -----
func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4xf32>) -> tensor<384x32x8x8xf32> {
%collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
@@ -933,13 +962,31 @@ func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4
%pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %2 : tensor<3072x256xf32> -> tensor<384x32x8x8xf32>
func.return %pack : tensor<384x32x8x8xf32>
}
+// CHECK-LABEL: func.func @no_bubble_up_pack_through_non_divisible_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[COLLAPSED]]
+// CHECK: return %[[PACK]] : tensor<384x32x8x8xf32>
-func.func @push_down_unpack_through_expand(%5: tensor<384x32x8x8xf32>) -> tensor<12x256x256xf32> {
- %6 = tensor.empty() : tensor<3072x256xf32>
- %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>
- %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<3072x256xf32> into tensor<12x256x256xf32>
- func.return %expanded : tensor<12x256x256xf32>
+// -----
+
+func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index) -> tensor<?x256x256xf32> {
+ %6 = tensor.empty(%dim) : tensor<?x256xf32>
+ %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<?x256xf32> into tensor<?x256x256xf32>
+ func.return %expanded : tensor<?x256x256xf32>
}
+// CHECK-LABEL: func.func @push_down_unpack_through_expand
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
+// CHECK: return %[[UNPACK]] : tensor<?x256x256xf32>
+
+// -----
func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> {
%6 = tensor.empty() : tensor<4x3072x256xf32>
@@ -947,6 +994,14 @@ func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>
%expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
func.return %expanded : tensor<4x12x256x256xf32>
}
+// CHECK-LABEL: @push_down_permuted_unpack_through_expand
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32>
+// CHECK: %[[UNPACL:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32>
+// CHECK: return %[[UNPACK]] : tensor<4x12x256x256xf32>
+
+// -----
func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> tensor<3x16x1x256xf32> {
%6 = tensor.empty() : tensor<48x256xf32>
@@ -954,6 +1009,14 @@ func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> ten
%expanded = tensor.expand_shape %unpack [[0, 1, 2], [3]] : tensor<48x256xf32> into tensor<3x16x1x256xf32>
func.return %expanded : tensor<3x16x1x256xf32>
}
+// CHECK-LABEL: func.func @push_down_unpack_through_unit_expand
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<6x32x8x8xf32> into tensor<3x2x1x32x8x8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x16x1x256xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<3x2x1x32x8x8xf32> -> tensor<3x16x1x256xf32>
+// CHECK: return %[[UNPACK]] : tensor<3x16x1x256xf32>
+
+// -----
func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x8xf32>) -> tensor<256x12x256xf32> {
%6 = tensor.empty() : tensor<3072x256xf32>
@@ -961,3 +1024,8 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
%expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32>
func.return %expanded : tensor<256x12x256xf32>
}
+// CHECK-LABEL: func.func @no_push_down_unpack_through_non_divisible_expand
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32>
+// CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
>From c60506c31fc667ef641d9e72afbd722375ce226c Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Tue, 19 Mar 2024 17:24:34 +0000
Subject: [PATCH 6/8] Refactor and fix tests
---
.../Transforms/DataLayoutPropagation.cpp | 91 ++++++++++++-------
.../Linalg/data-layout-propagation.mlir | 2 +-
2 files changed, 60 insertions(+), 33 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 9b76da2cf97368..470e7cc5474cb5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -553,17 +553,26 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
ControlPropagationFn controlFn;
};
+/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
+///
+/// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
+/// targetShape [16, 16, 32, 1], it returns [1, 2]: for pos 0, the inner-most
+/// non-unit projected dim in pos [0, 1] is 1, and for pos 2, the inner-most
+/// non-unit projected dim in pos [2, 3] is 2.
+///
+/// If all candidates in a reassociation are unit dims, it chooses the
+/// inner-most dim pos.
static SmallVector<int64_t>
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
ArrayRef<ReassociationIndices> reassocIndices,
- ArrayRef<int64_t> baseShape) {
+ ArrayRef<int64_t> targetShape) {
SmallVector<int64_t> projectedDimsPos;
for (auto pos : dimsPos) {
// In the case all dims are unit, this will return the inner-most one.
int64_t projectedPos = reassocIndices[pos].back();
for (auto it = reassocIndices[pos].rbegin();
it != reassocIndices[pos].rend(); ++it) {
- int64_t dim = baseShape[*it];
+ int64_t dim = targetShape[*it];
if (dim > 1 || ShapedType::isDynamic(dim)) {
projectedPos = *it;
break;
@@ -574,32 +583,36 @@ projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
return projectedDimsPos;
}
-static bool
-isProjectedDimsDivisibleByTileSizes(ArrayRef<int64_t> projectedDimsPos,
- ArrayRef<int64_t> targetShape,
- ArrayRef<int64_t> tileSizes) {
- for (auto [projectedPos, tileSize] :
- llvm::zip_equal(projectedDimsPos, tileSizes)) {
- int64_t dim = targetShape[projectedPos];
+/// Check if all dims in dimsPos are static and divisible by their tile sizes.
+static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
+ ArrayRef<int64_t> shape,
+ ArrayRef<int64_t> tileSizes) {
+ for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
+ int64_t dim = shape[pos];
if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
return false;
}
return true;
}
+/// Permute the reassociation indices and reindex them in sequence order.
+/// Returns the next dim pos in the sequence.
+///
+/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
+/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
+/// [[0], [1, 2]].
static int64_t applyPermutationAndReindexReassoc(
- SmallVector<ReassociationIndices> &reassociationIndices,
- ArrayRef<int64_t> dimsPerm) {
- applyPermutationToVector<ReassociationIndices>(reassociationIndices,
- dimsPerm);
- int64_t lastPos = 0;
- for (ReassociationIndices &indices : reassociationIndices) {
+ SmallVector<ReassociationIndices> &reassocIndices,
+ ArrayRef<int64_t> permutation) {
+ applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
+ int64_t nextPos = 0;
+ for (ReassociationIndices &indices : reassocIndices) {
for (auto &index : indices) {
- index = lastPos;
- lastPos += 1;
+ index = nextPos;
+ nextPos += 1;
}
}
- return lastPos;
+ return nextPos;
}
/// Bubble up pack op through collapse shape op when the packed dims can be
@@ -614,7 +627,7 @@ static int64_t applyPermutationAndReindexReassoc(
/// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
/// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
///
-/// Can be transformed into:
+/// can be transformed into:
///
/// %pack = tensor.pack %in outer_dims_perm = [1, 2]
/// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
@@ -632,11 +645,18 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
SmallVector<ReassociationIndices> reassocIndices =
collapseOp.getReassociationIndices();
+ // Project inner tile pos to the dim pos before collapsing. For example, if
+ // dims [x, y] are collapsed into [z], packing on dim z can be projected back
+ // to pack on dim y.
+ //
+ // Project to inner-most non-unit dims to increase the chance that they can be
+ // divided by the inner tile sizes. This is correct because for [..., x, 1],
+ // packing on the trailing unit dim is equivalent to packing on dim x.
SmallVector<int64_t> projectedInnerDimsPos =
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
- if (!isProjectedDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
- innerTileSizes)) {
+ if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
+ innerTileSizes)) {
return failure();
}
// Expand the outer dims permutation with the associated source dims for the
@@ -658,14 +678,14 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
// First apply the permutation on the reassociations of the outer dims.
- // For example given the permutation [1, 0], the reassociations: [[0, 1], [2]]
+ // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
// -> [[0], [1, 2]]
- int64_t lastPos =
+ int64_t nextPos =
applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
// Then add direct mapping for the inner tile dims.
for (size_t i = 0; i < innerDimsPos.size(); ++i) {
- newReassocIndices.push_back({lastPos});
- lastPos += 1;
+ newReassocIndices.push_back({nextPos});
+ nextPos += 1;
}
auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
@@ -723,7 +743,7 @@ class BubbleUpPackOpThroughReshapeOp final
/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
/// : tensor<?x256xf32> into tensor<?x256x256xf32>
///
-/// Can be transformed into:
+/// can be transformed into:
///
/// %expanded = tensor.expand_shape %ain [[0, 1], [2], [3], [4]]
/// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
@@ -741,11 +761,18 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
SmallVector<ReassociationIndices> reassocIndices =
expandOp.getReassociationIndices();
+ // Project inner tile pos to the dim pos after expanding. For example, if dim
+ // [z] is expanded into [x, y], unpacking on dim z can be projected to unpack
+ // on dim y.
+ //
+ // Project to inner-most non-unit dims to increase the chance that they can be
+ // divided by the inner tile sizes. This is correct because for [..., x, 1],
+ // unpacking on the trailing unit dim is equivalent to unpacking on dim x.
SmallVector<int64_t> projectedInnerDimsPos =
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
- if (!isProjectedDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
- innerTileSizes)) {
+ if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
+ innerTileSizes)) {
return failure();
}
// Expand the outer dims permutation with the associated expanded dims for the
@@ -760,14 +787,14 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
// First apply the permutation on the reassociations of the outer dims.
- // For example given the permutation [1, 0], the reassociations: [[0, 1], [2]]
+ // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
// -> [[0], [1, 2]]
- int64_t lastPos =
+ int64_t nextPos =
applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
// Then add direct mapping for the inner tile dims.
for (size_t i = 0; i < innerDimsPos.size(); ++i) {
- newReassocIndices.push_back({lastPos});
- lastPos += 1;
+ newReassocIndices.push_back({nextPos});
+ nextPos += 1;
}
RankedTensorType newExpandType =
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 0c6977139402b1..10c9f5bafb5c03 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -998,7 +998,7 @@ func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32>
-// CHECK: %[[UNPACL:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32>
// CHECK: return %[[UNPACK]] : tensor<4x12x256x256xf32>
// -----
>From 143243f79fe4989c73bf9e132812df551d27eb1c Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Wed, 20 Mar 2024 18:31:36 +0000
Subject: [PATCH 7/8] Fix control function
---
.../Transforms/DataLayoutPropagation.cpp | 20 +++++++++----------
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 470e7cc5474cb5..e01653b8940672 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -703,10 +703,6 @@ class BubbleUpPackOpThroughReshapeOp final
LogicalResult matchAndRewrite(tensor::PackOp packOp,
PatternRewriter &rewriter) const override {
- // User controlled propagation function.
- if (!controlFn(packOp))
- return failure();
-
Operation *srcOp = packOp.getSource().getDefiningOp();
// Currently only support when the pack op is the only user.
if (!srcOp || !(srcOp->getNumResults() == 1) ||
@@ -720,6 +716,10 @@ class BubbleUpPackOpThroughReshapeOp final
return failure();
}
+ // User controlled propagation function.
+ if (!controlFn(srcOp))
+ return failure();
+
return TypeSwitch<Operation *, LogicalResult>(srcOp)
.Case([&](tensor::CollapseShapeOp op) {
return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
@@ -825,10 +825,6 @@ class PushDownUnPackOpThroughReshapeOp final
LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
- // User controlled propagation function.
- if (!controlFn(unPackOp))
- return failure();
-
Value result = unPackOp.getResult();
// Currently only support unpack op with the single user.
if (!result.hasOneUse()) {
@@ -841,8 +837,12 @@ class PushDownUnPackOpThroughReshapeOp final
return failure();
}
- Operation *userOp = *result.user_begin();
- return TypeSwitch<Operation *, LogicalResult>(userOp)
+ Operation *consumerOp = *result.user_begin();
+ // User controlled propagation function.
+ if (!controlFn(consumerOp))
+ return failure();
+
+ return TypeSwitch<Operation *, LogicalResult>(consumerOp)
.Case([&](tensor::ExpandShapeOp op) {
return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
})
>From 86024e8f789529a6e3da94d91b1f00948c0d0d44 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Wed, 27 Mar 2024 17:32:44 +0000
Subject: [PATCH 8/8] Add new tests
---
.../Linalg/data-layout-propagation.mlir | 36 +++++++++++++++++++
1 file changed, 36 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 10c9f5bafb5c03..79d61ab757e327 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -956,6 +956,24 @@ func.func @bubble_up_pack_through_unit_collapse(%1: tensor<1x64x1x4xf32>) -> ten
// -----
+func.func @bubble_up_pack_through_collapse_on_outer_dims(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x1x4xf32> {
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
+ %2 = tensor.empty(%dim) : tensor<?x1x4xf32>
+ %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [4] into %2 : tensor<?x4xf32> -> tensor<?x1x4xf32>
+ func.return %pack : tensor<?x1x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_through_collapse_on_outer_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x16x1x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [2] inner_tiles = [4] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x16x1x4xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3]] : tensor<?x16x1x4xf32> into tensor<?x1x4xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<?x1x4xf32>
+
+// -----
+
func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4xf32>) -> tensor<384x32x8x8xf32> {
%collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
%2 = tensor.empty() : tensor<384x32x8x8xf32>
@@ -1018,6 +1036,24 @@ func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> ten
// -----
+func.func @push_down_unpack_through_expand_on_outer_dims(%5: tensor<?x32x8xf32>, %dim: index) -> tensor<?x256x256xf32> {
+ %6 = tensor.empty(%dim) : tensor<?x256xf32>
+ %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [8] into %6 : tensor<?x32x8xf32> -> tensor<?x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<?x256xf32> into tensor<?x256x256xf32>
+ func.return %expanded : tensor<?x256x256xf32>
+}
+// CHECK-LABEL: func.func @push_down_unpack_through_expand_on_outer_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]] : tensor<?x32x8xf32> into tensor<?x256x32x8xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x256x32x8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [2] inner_tiles = [8] into %[[EMPTY]] : tensor<?x256x32x8xf32> -> tensor<?x256x256xf32>
+// CHECK: return %[[UNPACK]] : tensor<?x256x256xf32>
+
+// -----
+
func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x8xf32>) -> tensor<256x12x256xf32> {
%6 = tensor.empty() : tensor<3072x256xf32>
%unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>
More information about the Mlir-commits
mailing list