[Mlir-commits] [mlir] [mlir][linalg] Add pattern to bubble-up pack through expand shape op (PR #93529)
Adam Siemieniuk
llvmlistbot at llvm.org
Fri Jun 14 04:45:10 PDT 2024
https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/93529
>From f9f1f9beac68832970da00e450ed93753cfc936d Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 27 May 2024 17:36:43 +0200
Subject: [PATCH 1/7] [mlir][linalg] Add pattern to bubble-up pack through
expand shape op
Extends bubble-up pack through reshape pattern to handle pack propagation
through expand shape ops.
---
.../Transforms/DataLayoutPropagation.cpp | 104 +++++++++
.../Linalg/data-layout-propagation.mlir | 204 ++++++++++++++++++
2 files changed, 308 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 2bea083ac2d78..73a86caa2fbcb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -17,6 +17,8 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -694,6 +696,105 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
return success();
}
+/// Project dimsPos to their collapsed positions in the reassocIndices.
+///
+/// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
+/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
+/// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
+/// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
+static SmallVector<int64_t>
+projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
+ ArrayRef<ReassociationIndices> reassocIndices) {
+ SmallVector<int64_t> projectedPos;
+
+ // Map each dimension to the position of the corresponding reassociation index.
+ for (auto pos : dimsPos) {
+ for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
+ // If the dimension is present in the current indices group, the group
+ // position within the reassociation map is the desired projected
+ // dimension position.
+ if (llvm::any_of(indices,
+ [&](int64_t expandDim) { return expandDim == pos; })) {
+ projectedPos.push_back(idx);
+ break;
+ }
+ }
+ }
+ assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
+
+ return projectedPos;
+}
+
+/// Bubble up pack op through expand shape op.
+static LogicalResult
+bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
+ tensor::PackOp packOp,
+ PatternRewriter &rewriter) {
+ // Cannot propagate shape expansion if there is outer dimensions permutation.
+ ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+ if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
+ return rewriter.notifyMatchFailure(
+ packOp, "expects outer_dims_perm is empty or an identity permutation");
+ }
+
+ // Validate dimensions' relations between shape expansion and packing.
+ SmallVector<ReassociationIndices, 4> reassoc =
+ expandOp.getReassociationIndices();
+ ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
+ llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
+ packInnerDims.end());
+
+ for (auto [idx, indices] : llvm::enumerate(reassoc)) {
+ llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
+ llvm::SetVector<int64_t> packedDims =
+ llvm::set_intersection(packDimsPos, expandDimPos);
+
+ // The expanded dimension is not packed - simply continue.
+ if (packedDims.empty())
+ continue;
+ // Shape expansion cannot be propagated when multiple expanded dimensions are
+ // packed.
+ if (packedDims.size() > 1)
+ return rewriter.notifyMatchFailure(
+ packOp, "only one of the expanded dimensions can be packed");
+ // Only the inner-most dim should be packed. Otherwise, the element order will
+ // be affected after operation reordering.
+ if (packedDims[0] != indices.back())
+ return rewriter.notifyMatchFailure(
+ packOp, "can only pack the inner-most expanded dimension");
+ }
+
+ // Project pack.inner_dims_pos to positions before shape expansion.
+ SmallVector<int64_t> projectedInnerDimsPos =
+ projectDimsPosIntoReassocPos(packInnerDims, reassoc);
+
+ // Project the shape expansion to new packed shape.
+ // The pack.outer_dims_perm is restricted to identity, so the permutation can
+ // be omitted for simplicity.
+ RankedTensorType newPackType = tensor::PackOp::inferPackedType(
+ expandOp.getSrcType(), packOp.getStaticInnerTiles(),
+ projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
+ auto reassocExpand =
+ getReassociationIndicesForReshape(newPackType, packOp.getDestType());
+ if (!reassocExpand)
+ return rewriter.notifyMatchFailure(
+ packOp, "could not reassociate dims after bubbling up");
+
+ Value destTensor = tensor::PackOp::createDestinationTensor(
+ rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
+ projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
+ Value packedVal = rewriter.create<tensor::PackOp>(
+ packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
+ packOp.getMixedTiles(), packOp.getPaddingValue(),
+ /*outerDimsPerm=*/SmallVector<int64_t>{});
+
+ Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+ packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
+ rewriter.replaceOp(packOp, newExpandOp);
+
+ return success();
+}
+
class BubbleUpPackOpThroughReshapeOp final
: public OpRewritePattern<tensor::PackOp> {
public:
@@ -723,6 +824,9 @@ class BubbleUpPackOpThroughReshapeOp final
.Case([&](tensor::CollapseShapeOp op) {
return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
})
+ .Case([&](tensor::ExpandShapeOp op) {
+ return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
+ })
.Default([](Operation *) { return failure(); });
}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 9140904620acd..43f9799357df5 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -988,6 +988,210 @@ func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4
// -----
+func.func @bubble_up_pack_outer_expanded_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x64x4xf32> {
+ %empty = tensor.empty() : tensor<4x2x64x4xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [1] inner_tiles = [4] into %empty : tensor<4x8x64xf32> -> tensor<4x2x64x4xf32>
+ return %pack : tensor<4x2x64x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_outer_expanded_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<8x64x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3]] output_shape [4, 2, 64, 4] : tensor<8x64x4xf32> into tensor<4x2x64x4xf32>
+// CHECK: return %[[EXPANDED]] : tensor<4x2x64x4xf32>
+
+// -----
+
+func.func @bubble_up_pack_inner_expanded_through_expand(%arg0: tensor<32x64xf32>) -> tensor<32x4x4x4xf32> {
+ %empty = tensor.empty() : tensor<32x4x4x4xf32>
+ %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [2] inner_tiles = [4] into %empty : tensor<32x4x16xf32> -> tensor<32x4x4x4xf32>
+ return %pack : tensor<32x4x4x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_inner_expanded_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x16x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [4] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<32x16x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]] output_shape [32, 4, 4, 4] : tensor<32x16x4xf32> into tensor<32x4x4x4xf32>
+// CHECK: return %[[EXPANDED]] : tensor<32x4x4x4xf32>
+
+// -----
+
+func.func @bubble_up_pack_non_expanded_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<8x2x32x16x4xf32> {
+ %empty = tensor.empty() : tensor<8x2x32x16x4xf32>
+ %expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [32, 2, 32, 16] : tensor<32x64x16xf32> into tensor<32x2x32x16xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [4] into %empty : tensor<32x2x32x16xf32> -> tensor<8x2x32x16x4xf32>
+ return %pack : tensor<8x2x32x16x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_non_expanded_dims_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x16x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<32x64x16xf32> -> tensor<8x64x16x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4]] output_shape [8, 2, 32, 16, 4] : tensor<8x64x16x4xf32> into tensor<8x2x32x16x4xf32>
+// CHECK: return %[[EXPANDED]] : tensor<8x2x32x16x4xf32>
+
+// -----
+
+func.func @bubble_up_pack_through_expand_dynamic(%arg0: tensor<?x64xf32>) -> tensor<?x4x2x8xf32> {
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?x64xf32>
+ %empty = tensor.empty(%dim) : tensor<?x4x2x8xf32>
+ %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%dim, 4, 16] : tensor<?x64xf32> into tensor<?x4x16xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [2] inner_tiles = [8] into %empty : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
+ return %pack : tensor<?x4x2x8xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_through_expand_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM_INPUT:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x64xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM_INPUT]]) : tensor<?x8x8xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]] : tensor<?x64xf32> -> tensor<?x8x8xf32>
+// CHECK: %[[DIM_PACK:.+]] = tensor.dim %[[PACK]], %[[C0]] : tensor<?x8x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]] output_shape [%[[DIM_PACK]], 4, 2, 8] : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
+// CHECK: return %[[EXPANDED]] : tensor<?x4x2x8xf32>
+
+// -----
+
+func.func @bubble_up_pack_non_expanded_padding_through_expand(%arg0: tensor<32x60xf32>) -> tensor<4x2x8x4x8xf32> {
+ %cst = arith.constant 3.000000e+00 : f32
+ %empty = tensor.empty() : tensor<4x2x8x4x8xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x60xf32> into tensor<4x8x60xf32>
+ %pack = tensor.pack %expanded padding_value(%cst : f32) inner_dims_pos = [1, 2] inner_tiles = [4, 8] into %empty : tensor<4x8x60xf32> -> tensor<4x2x8x4x8xf32>
+ return %pack : tensor<4x2x8x4x8xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_non_expanded_padding_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[CST:.+]] = arith.constant 3.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[CST]] : f32) inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %[[EMPTY]] : tensor<32x60xf32> -> tensor<8x8x4x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] output_shape [4, 2, 8, 4, 8] : tensor<8x8x4x8xf32> into tensor<4x2x8x4x8xf32>
+// CHECK: return %[[EXPANDED]] : tensor<4x2x8x4x8xf32>
+
+// -----
+
+func.func @bubble_up_pack_outer_dims_perm_identity_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x32x4x2xf32> {
+ %empty = tensor.empty() : tensor<4x2x32x4x2xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+ %pack = tensor.pack %expanded outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %empty : tensor<4x8x64xf32> -> tensor<4x2x32x4x2xf32>
+ return %pack : tensor<4x2x32x4x2xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_outer_dims_perm_identity_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32x4x2xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [4, 2] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<8x32x4x2xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] output_shape [4, 2, 32, 4, 2] : tensor<8x32x4x2xf32> into tensor<4x2x32x4x2xf32>
+// CHECK: return %[[EXPANDED]] : tensor<4x2x32x4x2xf32>
+
+// -----
+
+func.func @bubble_up_pack_multiple_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<8x2x4x8x4x8x2xf32> {
+ %empty = tensor.empty() : tensor<8x2x4x8x4x8x2xf32>
+ %expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [32, 2, 32, 16] : tensor<32x64x16xf32> into tensor<32x2x32x16xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [0, 2, 3] inner_tiles = [4, 8, 2] into %empty : tensor<32x2x32x16xf32> -> tensor<8x2x4x8x4x8x2xf32>
+ return %pack : tensor<8x2x4x8x4x8x2xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_multiple_dims_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x8x4x8x2xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1, 2] inner_tiles = [4, 8, 2] into %[[EMPTY]] : tensor<32x64x16xf32> -> tensor<8x8x8x4x8x2xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4], [5], [6]] output_shape [8, 2, 4, 8, 4, 8, 2] : tensor<8x8x8x4x8x2xf32> into tensor<8x2x4x8x4x8x2xf32>
+// CHECK: return %[[EXPANDED]] : tensor<8x2x4x8x4x8x2xf32>
+
+// -----
+
+func.func @bubble_up_pack_inner_dims_reorder_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x4x16x4xf32> {
+ %empty = tensor.empty() : tensor<4x2x4x16x4xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [2, 1] inner_tiles = [16, 4] into %empty : tensor<4x8x64xf32> -> tensor<4x2x4x16x4xf32>
+ return %pack : tensor<4x2x4x16x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_inner_dims_reorder_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<8x4x16x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] output_shape [4, 2, 4, 16, 4] : tensor<8x4x16x4xf32> into tensor<4x2x4x16x4xf32>
+// CHECK: return %[[EXPANDED]] : tensor<4x2x4x16x4xf32>
+
+// -----
+
+func.func @bubble_up_pack_multiple_different_expanded_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<4x2x2x8x16x4x4xf32> {
+ %empty = tensor.empty() : tensor<4x2x2x8x16x4x4xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] output_shape [4, 8, 2, 32, 16] : tensor<32x64x16xf32> into tensor<4x8x2x32x16xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [1, 3] inner_tiles = [4, 4] into %empty : tensor<4x8x2x32x16xf32> -> tensor<4x2x2x8x16x4x4xf32>
+ return %pack : tensor<4x2x2x8x16x4x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_multiple_different_expanded_dims_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x16x16x4x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %[[EMPTY]] : tensor<32x64x16xf32> -> tensor<8x16x16x4x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2, 3], [4], [5], [6]] output_shape [4, 2, 2, 8, 16, 4, 4] : tensor<8x16x16x4x4xf32> into tensor<4x2x2x8x16x4x4xf32>
+// CHECK: return %[[EXPANDED]] : tensor<4x2x2x8x16x4x4xf32>
+
+// -----
+
+func.func @no_bubble_up_pack_outer_dims_permutation_through_expand(%arg0: tensor<32x64xf32>) -> tensor<32x4x2x4x2xf32> {
+ %empty = tensor.empty() : tensor<32x4x2x4x2xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+ %pack = tensor.pack %expanded outer_dims_perm = [2, 0, 1] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %empty : tensor<4x8x64xf32> -> tensor<32x4x2x4x2xf32>
+ return %pack : tensor<32x4x2x4x2xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_outer_dims_permutation_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x4x2x4x2xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] outer_dims_perm = [2, 0, 1] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %[[EMPTY]] : tensor<4x8x64xf32> -> tensor<32x4x2x4x2xf32>
+// CHECK: return %[[PACK]] : tensor<32x4x2x4x2xf32>
+
+// -----
+
+func.func @no_bubble_up_pack_multiple_same_expanded_dim_through_expand(%arg0: tensor<32x64xf32>) -> tensor<2x2x64x2x4xf32> {
+ %empty = tensor.empty() : tensor<2x2x64x2x4xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %empty : tensor<4x8x64xf32> -> tensor<2x2x64x2x4xf32>
+ return %pack : tensor<2x2x64x2x4xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_multiple_same_expanded_dim_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x64x2x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %[[EMPTY]] : tensor<4x8x64xf32> -> tensor<2x2x64x2x4xf32>
+// CHECK: return %[[PACK]] : tensor<2x2x64x2x4xf32>
+
+// -----
+
+func.func @no_bubble_up_pack_non_innermost_expanded_dim_through_expand(%arg0: tensor<32x64xf32>) -> tensor<2x8x64x2xf32> {
+ %empty = tensor.empty() : tensor<2x8x64x2xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [2] into %empty : tensor<4x8x64xf32> -> tensor<2x8x64x2xf32>
+ return %pack : tensor<2x8x64x2xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_non_innermost_expanded_dim_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x8x64x2xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] inner_dims_pos = [0] inner_tiles = [2] into %[[EMPTY]] : tensor<4x8x64xf32> -> tensor<2x8x64x2xf32>
+// CHECK: return %[[PACK]] : tensor<2x8x64x2xf32>
+
+// -----
+
+func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(%arg0: tensor<30x60xf32>) -> tensor<3x2x60x8xf32> {
+ %cst = arith.constant 3.000000e+00 : f32
+ %empty = tensor.empty() : tensor<3x2x60x8xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 10, 60] : tensor<30x60xf32> into tensor<3x10x60xf32>
+ %pack = tensor.pack %expanded padding_value(%cst : f32) inner_dims_pos = [1] inner_tiles = [8] into %empty : tensor<3x10x60xf32> -> tensor<3x2x60x8xf32>
+ return %pack : tensor<3x2x60x8xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[CST:.+]] = arith.constant 3.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x2x60x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] output_shape [3, 10, 60] : tensor<30x60xf32> into tensor<3x10x60xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] padding_value(%[[CST]] : f32) inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]] : tensor<3x10x60xf32> -> tensor<3x2x60x8xf32>
+// CHECK: return %[[PACK]] : tensor<3x2x60x8xf32>
+
+// -----
+
func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
%6 = tensor.empty(%dim) : tensor<?x256xf32>
%unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
>From 7f7931a54a35787f799a2a5543196988f3949379 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 29 May 2024 18:51:57 +0200
Subject: [PATCH 2/7] Add TODOs
---
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 73a86caa2fbcb..4af9238c3fcde 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -730,7 +730,8 @@ static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
tensor::PackOp packOp,
PatternRewriter &rewriter) {
- // Cannot propagate shape expansion if there is outer dimensions permutation.
+ // Outer dimensions permutation is not supported currently.
+ // TODO: Handle outer_dims_perm variants.
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
return rewriter.notifyMatchFailure(
@@ -771,6 +772,7 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
// Project the shape expansion to new packed shape.
// The pack.outer_dims_perm is restricted to identity, so the permutation can
// be omitted for simplicity.
+ // TODO: Account for outer dimensions permutation.
RankedTensorType newPackType = tensor::PackOp::inferPackedType(
expandOp.getSrcType(), packOp.getStaticInnerTiles(),
projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
>From 85cc0c668419890bf027d4df96964742e2539180 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 31 May 2024 12:29:34 +0200
Subject: [PATCH 3/7] Improve test readability
---
.../Linalg/data-layout-propagation.mlir | 90 +++++++++++++------
1 file changed, 64 insertions(+), 26 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 43f9799357df5..f2d73611186ce 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -997,8 +997,10 @@ func.func @bubble_up_pack_outer_expanded_through_expand(%arg0: tensor<32x64xf32>
// CHECK-LABEL: func.func @bubble_up_pack_outer_expanded_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x4xf32>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<8x64x4xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3]] output_shape [4, 2, 64, 4] : tensor<8x64x4xf32> into tensor<4x2x64x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<8x64x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3]]
+// CHECK-SAME: output_shape [4, 2, 64, 4] : tensor<8x64x4xf32> into tensor<4x2x64x4xf32>
// CHECK: return %[[EXPANDED]] : tensor<4x2x64x4xf32>
// -----
@@ -1012,8 +1014,11 @@ func.func @bubble_up_pack_inner_expanded_through_expand(%arg0: tensor<32x64xf32>
// CHECK-LABEL: func.func @bubble_up_pack_inner_expanded_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x16x4xf32>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [4] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<32x16x4xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]] output_shape [32, 4, 4, 4] : tensor<32x16x4xf32> into tensor<32x4x4x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [4] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x64xf32> -> tensor<32x16x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]]
+// CHECK-SAME: output_shape [32, 4, 4, 4] : tensor<32x16x4xf32> into tensor<32x4x4x4xf32>
// CHECK: return %[[EXPANDED]] : tensor<32x4x4x4xf32>
// -----
@@ -1027,8 +1032,11 @@ func.func @bubble_up_pack_non_expanded_dims_through_expand(%arg0: tensor<32x64x1
// CHECK-LABEL: func.func @bubble_up_pack_non_expanded_dims_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x16x4xf32>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<32x64x16xf32> -> tensor<8x64x16x4xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4]] output_shape [8, 2, 32, 16, 4] : tensor<8x64x16x4xf32> into tensor<8x2x32x16x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG0]] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x64x16xf32> -> tensor<8x64x16x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4]]
+// CHECK-SAME: output_shape [8, 2, 32, 16, 4] : tensor<8x64x16x4xf32> into tensor<8x2x32x16x4xf32>
// CHECK: return %[[EXPANDED]] : tensor<8x2x32x16x4xf32>
// -----
@@ -1046,9 +1054,12 @@ func.func @bubble_up_pack_through_expand_dynamic(%arg0: tensor<?x64xf32>) -> ten
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM_INPUT:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x64xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM_INPUT]]) : tensor<?x8x8xf32>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]] : tensor<?x64xf32> -> tensor<?x8x8xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]]
+// CHECK-SAME: : tensor<?x64xf32> -> tensor<?x8x8xf32>
// CHECK: %[[DIM_PACK:.+]] = tensor.dim %[[PACK]], %[[C0]] : tensor<?x8x8xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]] output_shape [%[[DIM_PACK]], 4, 2, 8] : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]]
+// CHECK-SAME: output_shape [%[[DIM_PACK]], 4, 2, 8] : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
// CHECK: return %[[EXPANDED]] : tensor<?x4x2x8xf32>
// -----
@@ -1064,8 +1075,11 @@ func.func @bubble_up_pack_non_expanded_padding_through_expand(%arg0: tensor<32x6
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[CST:.+]] = arith.constant 3.000000e+00 : f32
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[CST]] : f32) inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %[[EMPTY]] : tensor<32x60xf32> -> tensor<8x8x4x8xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] output_shape [4, 2, 8, 4, 8] : tensor<8x8x4x8xf32> into tensor<4x2x8x4x8xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[CST]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x60xf32> -> tensor<8x8x4x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]]
+// CHECK-SAME: output_shape [4, 2, 8, 4, 8] : tensor<8x8x4x8xf32> into tensor<4x2x8x4x8xf32>
// CHECK: return %[[EXPANDED]] : tensor<4x2x8x4x8xf32>
// -----
@@ -1079,8 +1093,11 @@ func.func @bubble_up_pack_outer_dims_perm_identity_through_expand(%arg0: tensor<
// CHECK-LABEL: func.func @bubble_up_pack_outer_dims_perm_identity_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32x4x2xf32>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [4, 2] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<8x32x4x2xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] output_shape [4, 2, 32, 4, 2] : tensor<8x32x4x2xf32> into tensor<4x2x32x4x2xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 2] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x64xf32> -> tensor<8x32x4x2xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]]
+// CHECK-SAME: output_shape [4, 2, 32, 4, 2] : tensor<8x32x4x2xf32> into tensor<4x2x32x4x2xf32>
// CHECK: return %[[EXPANDED]] : tensor<4x2x32x4x2xf32>
// -----
@@ -1094,8 +1111,11 @@ func.func @bubble_up_pack_multiple_dims_through_expand(%arg0: tensor<32x64x16xf3
// CHECK-LABEL: func.func @bubble_up_pack_multiple_dims_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x8x4x8x2xf32>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1, 2] inner_tiles = [4, 8, 2] into %[[EMPTY]] : tensor<32x64x16xf32> -> tensor<8x8x8x4x8x2xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4], [5], [6]] output_shape [8, 2, 4, 8, 4, 8, 2] : tensor<8x8x8x4x8x2xf32> into tensor<8x2x4x8x4x8x2xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0, 1, 2] inner_tiles = [4, 8, 2] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x64x16xf32> -> tensor<8x8x8x4x8x2xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4], [5], [6]]
+// CHECK-SAME: output_shape [8, 2, 4, 8, 4, 8, 2] : tensor<8x8x8x4x8x2xf32> into tensor<8x2x4x8x4x8x2xf32>
// CHECK: return %[[EXPANDED]] : tensor<8x2x4x8x4x8x2xf32>
// -----
@@ -1109,8 +1129,11 @@ func.func @bubble_up_pack_inner_dims_reorder_through_expand(%arg0: tensor<32x64x
// CHECK-LABEL: func.func @bubble_up_pack_inner_dims_reorder_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x4xf32>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<8x4x16x4xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] output_shape [4, 2, 4, 16, 4] : tensor<8x4x16x4xf32> into tensor<4x2x4x16x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x64xf32> -> tensor<8x4x16x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]]
+// CHECK-SAME: output_shape [4, 2, 4, 16, 4] : tensor<8x4x16x4xf32> into tensor<4x2x4x16x4xf32>
// CHECK: return %[[EXPANDED]] : tensor<4x2x4x16x4xf32>
// -----
@@ -1124,8 +1147,11 @@ func.func @bubble_up_pack_multiple_different_expanded_dims_through_expand(%arg0:
// CHECK-LABEL: func.func @bubble_up_pack_multiple_different_expanded_dims_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x16x16x4x4xf32>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %[[EMPTY]] : tensor<32x64x16xf32> -> tensor<8x16x16x4x4xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2, 3], [4], [5], [6]] output_shape [4, 2, 2, 8, 16, 4, 4] : tensor<8x16x16x4x4xf32> into tensor<4x2x2x8x16x4x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x64x16xf32> -> tensor<8x16x16x4x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2, 3], [4], [5], [6]]
+// CHECK-SAME: output_shape [4, 2, 2, 8, 16, 4, 4] : tensor<8x16x16x4x4xf32> into tensor<4x2x2x8x16x4x4xf32>
// CHECK: return %[[EXPANDED]] : tensor<4x2x2x8x16x4x4xf32>
// -----
@@ -1139,8 +1165,11 @@ func.func @no_bubble_up_pack_outer_dims_permutation_through_expand(%arg0: tensor
// CHECK-LABEL: func.func @no_bubble_up_pack_outer_dims_permutation_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x4x2x4x2xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] outer_dims_perm = [2, 0, 1] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %[[EMPTY]] : tensor<4x8x64xf32> -> tensor<32x4x2x4x2xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
+// CHECK-SAME: outer_dims_perm = [2, 0, 1] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %[[EMPTY]]
+// CHECK-SAME: : tensor<4x8x64xf32> -> tensor<32x4x2x4x2xf32>
// CHECK: return %[[PACK]] : tensor<32x4x2x4x2xf32>
// -----
@@ -1154,8 +1183,11 @@ func.func @no_bubble_up_pack_multiple_same_expanded_dim_through_expand(%arg0: te
// CHECK-LABEL: func.func @no_bubble_up_pack_multiple_same_expanded_dim_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x64x2x4xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %[[EMPTY]] : tensor<4x8x64xf32> -> tensor<2x2x64x2x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %[[EMPTY]]
+// CHECK-SAME: : tensor<4x8x64xf32> -> tensor<2x2x64x2x4xf32>
// CHECK: return %[[PACK]] : tensor<2x2x64x2x4xf32>
// -----
@@ -1169,8 +1201,11 @@ func.func @no_bubble_up_pack_non_innermost_expanded_dim_through_expand(%arg0: te
// CHECK-LABEL: func.func @no_bubble_up_pack_non_innermost_expanded_dim_through_expand(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x8x64x2xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] inner_dims_pos = [0] inner_tiles = [2] into %[[EMPTY]] : tensor<4x8x64xf32> -> tensor<2x8x64x2xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [2] into %[[EMPTY]]
+// CHECK-SAME: : tensor<4x8x64xf32> -> tensor<2x8x64x2xf32>
// CHECK: return %[[PACK]] : tensor<2x8x64x2xf32>
// -----
@@ -1186,8 +1221,11 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[CST:.+]] = arith.constant 3.000000e+00 : f32
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x2x60x8xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] output_shape [3, 10, 60] : tensor<30x60xf32> into tensor<3x10x60xf32>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] padding_value(%[[CST]] : f32) inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]] : tensor<3x10x60xf32> -> tensor<3x2x60x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [3, 10, 60] : tensor<30x60xf32> into tensor<3x10x60xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] padding_value(%[[CST]] : f32)
+// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]]
+// CHECK-SAME: : tensor<3x10x60xf32> -> tensor<3x2x60x8xf32>
// CHECK: return %[[PACK]] : tensor<3x2x60x8xf32>
// -----
>From c5f2dfb40cdb994f65d4128d7da34ec8f9ff6171 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 7 Jun 2024 12:23:10 +0200
Subject: [PATCH 4/7] NYI message
---
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 4af9238c3fcde..4472b0d4bc1b7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -734,8 +734,8 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
// TODO: Handle outer_dims_perm variants.
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
- return rewriter.notifyMatchFailure(
- packOp, "expects outer_dims_perm is empty or an identity permutation");
+ return rewriter.notifyMatchFailure(packOp,
+ "non-identity outer dims perm NYI");
}
// Validate dimensions' relations between shape expansion and packing.
>From 8fa83252ac31c586a2d7262ecb8138cb3f87b8e6 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 13 Jun 2024 11:28:12 +0200
Subject: [PATCH 5/7] Update
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
Co-authored-by: Prashant Kumar <pk5561 at gmail.com>
---
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 4472b0d4bc1b7..bd797892382ab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -760,7 +760,7 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
packOp, "only one of the expanded dimensions can be packed");
 // Only the inner-most dim should be packed. Otherwise, element order will
// be affected after operation reordering.
- if (packedDims[0] != indices.back())
+ if (packedDims.front() != indices.back())
return rewriter.notifyMatchFailure(
packOp, "can only pack the inner-most expanded dimension");
}
>From 6484bdfe68fafaa551232fbc57bd89be8b1bd9ea Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 13 Jun 2024 11:42:43 +0200
Subject: [PATCH 6/7] Address comment
---
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index bd797892382ab..c03a3a7c32450 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -755,7 +755,7 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
continue;
 // Shape expansion cannot be propagated when multiple expanded dimensions are
// packed.
- if (packedDims.size() > 1)
+ if (packedDims.size() != 1)
return rewriter.notifyMatchFailure(
packOp, "only one of the expanded dimensions can be packed");
// Only the inner-most dim should be packed. Otherwise, elements order will
>From 752e75a29643d4d1183bacdd93f529a1883e0d12 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 14 Jun 2024 13:44:52 +0200
Subject: [PATCH 7/7] Improve docs + extra test
---
.../Transforms/DataLayoutPropagation.cpp | 32 ++++++++++++++++---
.../Linalg/data-layout-propagation.mlir | 18 +++++++++++
2 files changed, 46 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index c03a3a7c32450..e51ae2264a36a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -726,6 +726,22 @@ projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
}
/// Bubble up pack op through expand shape op.
+///
+/// For example:
+///
+/// %expand = tensor.expand_shape %in [[0], [1, 2]]
+/// : tensor<?x64xf32> into tensor<?x4x16xf32>
+/// %pack = tensor.pack %expand outer_dims_perm = [0, 1]
+/// inner_dims_pos = [2] inner_tiles = [8] into %empty
+/// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
+///
+/// can be transformed into:
+///
+/// %pack = tensor.pack %in outer_dims_perm = [1, 2]
+/// inner_dims_pos = [1] inner_tiles = [8] into %empty
+/// : tensor<?x64xf32> -> tensor<?x8x8xf32>
+/// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
+/// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
tensor::PackOp packOp,
@@ -746,20 +762,24 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
packInnerDims.end());
for (auto [idx, indices] : llvm::enumerate(reassoc)) {
+ // For each expand_shape reassociation, figure out which dimensions get
+ // packed if any.
llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
llvm::SetVector<int64_t> packedDims =
llvm::set_intersection(packDimsPos, expandDimPos);
- // The expanded dimension is not packed - simply continue.
+ // The expanded dimension is not packed, so it does not affect moving pack
+ // before shape expansion - simply continue.
if (packedDims.empty())
continue;
 // Shape expansion cannot be propagated when multiple expanded dimensions are
- // packed.
+ // packed - in this case operation reordering would affect final element
+ // positions and/or shapes can no longer be projected.
if (packedDims.size() != 1)
return rewriter.notifyMatchFailure(
packOp, "only one of the expanded dimensions can be packed");
- // Only the inner-most dim should be packed. Otherwise, elements order will
- // be affected after operation reordering.
+ // Only the inner-most expanded dimension should be packed. Otherwise,
+ // element order will be affected after operation reordering.
if (packedDims.front() != indices.back())
return rewriter.notifyMatchFailure(
packOp, "can only pack the inner-most expanded dimension");
@@ -773,6 +793,10 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
// The pack.outer_dims_perm is restricted to identity so, the permutation can
// be omitted for simplicity.
// TODO: Account for outer dimensions permutation.
+ //
+ // If reassociation is not possible, then reordering cannot happen.
+ // This can be caused by pack padding affecting previously expanded
+ // dimensions or packing extending dimensions.
RankedTensorType newPackType = tensor::PackOp::inferPackedType(
expandOp.getSrcType(), packOp.getStaticInnerTiles(),
projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index f2d73611186ce..78505d0aa4140 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1230,6 +1230,24 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
// -----
+func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
+ %empty = tensor.empty() : tensor<8x4x16x8xf32>
+ %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
+ return %pack : tensor<8x4x16x8xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
+// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
+// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32>
+
+// -----
+
func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
%6 = tensor.empty(%dim) : tensor<?x256xf32>
%unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
More information about the Mlir-commits
mailing list