[Mlir-commits] [mlir] [mlir][tensor] Rewrite tensor.pack as a constant (PR #93954)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 31 05:23:43 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
Author: Adam Siemieniuk (adam-smnk)
<details>
<summary>Changes</summary>
Adds a pattern to rewrite tensor.pack into arith.constant to avoid runtime packing of a constant tensor.
---
Patch is 22.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93954.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp (+154-1)
- (modified) mlir/test/Dialect/Tensor/rewrite-as-constant.mlir (+208)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
index 11e1de543ac91..b63551c268ddc 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
@@ -6,10 +6,13 @@
//
//===----------------------------------------------------------------------===//
//
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Threading.h"
using namespace mlir;
using namespace mlir::tensor;
@@ -45,9 +48,159 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
}
};
+/// Rewrite tensor.pack with arith.constant if the pack is writing
+/// to an empty tensor and the destination shape is static.
+///
+/// The source must be a dense arith.constant and the pack must not have a
+/// padding value. Splat constants are simply reshaped; for non-splat constants
+/// the raw element bytes are moved to their packed positions.
+struct PackToConstant : OpRewritePattern<tensor::PackOp> {
+  using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto constOp = packOp.getSource().getDefiningOp<arith::ConstantOp>();
+    if (!constOp)
+      return failure();
+    // Must be a dense constant.
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+    if (!denseAttr)
+      return failure();
+
+    // Bail out if the pack is used as a writing operation i.e.,
+    // the destination is not a tensor.empty.
+    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
+      return rewriter.notifyMatchFailure(packOp,
+                                         "expects empty tensor destination");
+    // Pack destination must have static shape.
+    if (!packOp.getDestType().hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          packOp, "expects destination with static shape");
+
+    // Pack with padding is not supported currently.
+    // TODO: Insert padding values as a part of rewrite.
+    if (packOp.getPaddingValue())
+      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+    OpBuilder::InsertionGuard guard(rewriter);
+
+    // If it is a splat constant, rewrite the pack directly.
+    if (denseAttr.isSplat()) {
+      DenseElementsAttr packedDenseShape =
+          denseAttr.reshape(packOp.getDestType());
+      rewriter.setInsertionPoint(constOp);
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);
+
+      return success();
+    }
+
+    // Constant contains non-splat dense values.
+    // Move the data into a new packed buffer. Each value is placed into its new
+    // position as defined by the pack operation.
+    ArrayRef<char> srcRawData = denseAttr.getRawData();
+    int64_t numberOfElements = denseAttr.getNumElements();
+
+    // The copy below moves whole elements byte-wise, so every element must
+    // occupy the same whole number of bytes in the raw buffer. Storage that
+    // packs several elements into one byte (e.g. i1, which DenseElementsAttr
+    // may bit-pack) would yield an element size of 0 and a silently wrong
+    // constant, so reject it.
+    size_t elementByteSize =
+        numberOfElements > 0 ? srcRawData.size() / numberOfElements : 0;
+    if (numberOfElements > 0 &&
+        (elementByteSize == 0 ||
+         elementByteSize * static_cast<size_t>(numberOfElements) !=
+             srcRawData.size()))
+      return rewriter.notifyMatchFailure(
+          packOp, "expects constant elements with whole-byte storage");
+
+    SmallVector<char> destRawData(srcRawData.size());
+
+    // The values below are invariant across elements; compute them once
+    // instead of once per element inside the parallel loop.
+    int64_t srcRank = packOp.getSourceType().getRank();
+    SmallVector<int64_t> destStrides =
+        computeStrides(packOp.getDestType().getShape());
+    SmallVector<int64_t> srcStrides =
+        computeStrides(packOp.getSourceType().getShape());
+
+    // Inverse of outer_dims_perm; empty when there is no outer permutation.
+    SmallVector<int64_t> inversePermutation;
+    if (!packOp.getOuterDimsPerm().empty())
+      inversePermutation = invertPermutationVector(packOp.getOuterDimsPerm());
+
+    // For each source dimension, its position in inner_dims_pos, or -1 if the
+    // dimension is not tiled. That position indexes both the inner tile sizes
+    // and the point loops that follow the srcRank outer loops. For example,
+    // with inner_dims_pos = [1, 0]:
+    //   dim 0 -> 1: tile size innerTileSizes[1], point loop srcRank + 1
+    //   dim 1 -> 0: tile size innerTileSizes[0], point loop srcRank + 0
+    SmallVector<int64_t> tilePosOfDim(srcRank, -1);
+    for (auto [idx, dim] : llvm::enumerate(packOp.getInnerDimsPos()))
+      tilePosOfDim[dim] = idx;
+
+    // The destination shape is static, so its trailing (point) dimensions are
+    // the inner tile sizes, listed in inner_dims_pos order.
+    ArrayRef<int64_t> innerTileSizes =
+        packOp.getDestType().getShape().take_back(
+            packOp.getInnerDimsPos().size());
+
+    // Parallelize raw data movement to speed up large constant packing.
+    parallelFor(
+        packOp.getContext(), 0, numberOfElements,
+        [&](size_t destLinearizedIdx) {
+          // Step 1: De-linearize destination index.
+          // f(lin) = tmp[A][B][C]
+          SmallVector<int64_t> destIndices =
+              delinearize(destLinearizedIdx, destStrides);
+
+          // Step 2: Undo outer_dims_perm on the srcRank outer indices to
+          // bring the loops into the canonical form tmp[A][B][a][b]. The
+          // trailing point indices are not permuted.
+          if (!inversePermutation.empty()) {
+            SmallVector<int64_t> tileLoops(destIndices.begin(),
+                                           destIndices.begin() + srcRank);
+            applyPermutationToVector(tileLoops, inversePermutation);
+            llvm::copy(tileLoops, destIndices.begin());
+          }
+          assert(destIndices.size() ==
+                 static_cast<size_t>(packOp.getDestType().getRank()));
+
+          // Step 3: Squash each tiled outer loop with its point loop to get
+          // the source index. The tile size is selected by the dimension's
+          // position in inner_dims_pos, not by the order in which tiled
+          // dimensions appear in the source, so that a permuted
+          // inner_dims_pos (e.g. [1, 0]) maps back correctly.
+          SmallVector<int64_t> srcIndices;
+          srcIndices.reserve(srcRank);
+          for (int64_t i = 0; i < srcRank; i++) {
+            int64_t tilePos = tilePosOfDim[i];
+            if (tilePos < 0) {
+              // Loop is not tiled.
+              srcIndices.push_back(destIndices[i]);
+              continue;
+            }
+            // Loop is tiled, account for the point loop distance.
+            srcIndices.push_back(destIndices[i] * innerTileSizes[tilePos] +
+                                 destIndices[srcRank + tilePos]);
+          }
+
+          int64_t srcLinearizedIdx = linearize(srcIndices, srcStrides);
+          assert(srcLinearizedIdx < numberOfElements);
+
+          // Step 4: Do the packing.
+          // Copy the source element byte-wise to its packed destination
+          // position.
+          llvm::copy(srcRawData.slice(srcLinearizedIdx * elementByteSize,
+                                      elementByteSize),
+                     destRawData.begin() + destLinearizedIdx * elementByteSize);
+        });
+
+    // Fail gracefully if something went wrong.
+    bool detectSplat = false;
+    if (!DenseElementsAttr::isValidRawBuffer(packOp.getDestType(), destRawData,
+                                             detectSplat))
+      return rewriter.notifyMatchFailure(
+          packOp, "failed to create packed raw data buffer");
+
+    // Replace the pack with a new constant.
+    auto packedDenseShape =
+        DenseElementsAttr::getFromRawBuffer(packOp.getDestType(), destRawData);
+    rewriter.setInsertionPoint(constOp);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);
+
+    return success();
+  }
+};
+
} // namespace
void mlir::tensor::populateRewriteAsConstantPatterns(
RewritePatternSet &patterns) {
- patterns.add<GenerateToConstant>(patterns.getContext());
+ patterns.add<GenerateToConstant, PackToConstant>(patterns.getContext());
}
diff --git a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
index 1a1cf9e407d80..045cb5a0da1d5 100644
--- a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
+++ b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
@@ -21,3 +21,211 @@ func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
} : tensor<2x3x5xf32>
return %0 : tensor<2x3x5xf32>
}
+
+// CHECK-LABEL: func.func @fold_pack_with_splat
+// CHECK: %[[CST:.+]] = arith.constant dense<1> : tensor<8x2x1x1x32x32xi64>
+// CHECK-NEXT: return %[[CST]] : tensor<8x2x1x1x32x32xi64>
+func.func @fold_pack_with_splat() -> tensor<8x2x1x1x32x32xi64> {
+ %cst = arith.constant dense<1> : tensor<1x1x64x256xi64>
+ %0 = tensor.empty() : tensor<8x2x1x1x32x32xi64>
+ %pack = tensor.pack %cst outer_dims_perm = [3, 2, 0, 1] inner_dims_pos = [2, 3] inner_tiles = [32, 32]
+ into %0 : tensor<1x1x64x256xi64> -> tensor<8x2x1x1x32x32xi64>
+ return %pack : tensor<8x2x1x1x32x32xi64>
+}
+
+// CHECK-LABEL: func.func @fold_pack_with_non_splat
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00], [8.000000e+00, 9.000000e+00], [1.600000e+01, 1.700000e+01], [2.400000e+01, 2.500000e+01]
+// CHECK-SAME: [2.000000e+00, 3.000000e+00], [1.000000e+01, 1.100000e+01], [1.800000e+01, 1.900000e+01], [2.600000e+01, 2.700000e+01]
+// CHECK-SAME: [4.000000e+00, 5.000000e+00], [1.200000e+01, 1.300000e+01], [2.000000e+01, 2.100000e+01], [2.800000e+01, 2.900000e+01]
+// CHECK-SAME: [6.000000e+00, 7.000000e+00], [1.400000e+01, 1.500000e+01], [2.200000e+01, 2.300000e+01], [3.000000e+01, 3.100000e+01]
+// CHECK-SAME: [3.200000e+01, 3.300000e+01], [4.000000e+01, 4.100000e+01], [4.900000e+01, 5.000000e+01], [5.700000e+01, 5.800000e+01]
+// CHECK-SAME: [3.400000e+01, 3.500000e+01], [4.200000e+01, 4.300000e+01], [5.100000e+01, 5.200000e+01], [5.900000e+01, 6.000000e+01]
+// CHECK-SAME: [3.600000e+01, 3.700000e+01], [4.400000e+01, 4.500000e+01], [5.300000e+01, 5.400000e+01], [6.100000e+01, 6.200000e+01]
+// CHECK-SAME: [3.800000e+01, 3.900000e+01], [4.600000e+01, 4.700000e+01], [5.500000e+01, 5.600000e+01], [6.300000e+01, 6.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<2x4x4x2xf32>
+func.func @fold_pack_with_non_splat() -> tensor<2x4x4x2xf32> {
+ %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+ [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+ [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+ [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+ [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+ [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+ [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+ [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+ %0 = tensor.empty() : tensor<2x4x4x2xf32>
+ %pack = tensor.pack %cst inner_dims_pos = [0, 1] inner_tiles = [4, 2]
+ into %0 : tensor<8x8xf32> -> tensor<2x4x4x2xf32>
+ return %pack : tensor<2x4x4x2xf32>
+}
+
+// inner_dims_pos = [1, 0] tiles dim 1 by 2 and dim 0 by 4, so the result is
+// packed[A][B][b][a] = cst[4 * A + a][2 * B + b]; every source element
+// appears exactly once.
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_inner_dims_reordered
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 8.000000e+00, 1.600000e+01, 2.400000e+01], [1.000000e+00, 9.000000e+00, 1.700000e+01, 2.500000e+01]
+// CHECK-SAME: [2.000000e+00, 1.000000e+01, 1.800000e+01, 2.600000e+01], [3.000000e+00, 1.100000e+01, 1.900000e+01, 2.700000e+01]
+// CHECK-SAME: [4.000000e+00, 1.200000e+01, 2.000000e+01, 2.800000e+01], [5.000000e+00, 1.300000e+01, 2.100000e+01, 2.900000e+01]
+// CHECK-SAME: [6.000000e+00, 1.400000e+01, 2.200000e+01, 3.000000e+01], [7.000000e+00, 1.500000e+01, 2.300000e+01, 3.100000e+01]
+// CHECK-SAME: [3.200000e+01, 4.000000e+01, 4.900000e+01, 5.700000e+01], [3.300000e+01, 4.100000e+01, 5.000000e+01, 5.800000e+01]
+// CHECK-SAME: [3.400000e+01, 4.200000e+01, 5.100000e+01, 5.900000e+01], [3.500000e+01, 4.300000e+01, 5.200000e+01, 6.000000e+01]
+// CHECK-SAME: [3.600000e+01, 4.400000e+01, 5.300000e+01, 6.100000e+01], [3.700000e+01, 4.500000e+01, 5.400000e+01, 6.200000e+01]
+// CHECK-SAME: [3.800000e+01, 4.600000e+01, 5.500000e+01, 6.300000e+01], [3.900000e+01, 4.700000e+01, 5.600000e+01, 6.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<2x4x2x4xf32>
+func.func @fold_pack_with_non_splat_with_inner_dims_reordered() -> tensor<2x4x2x4xf32> {
+ %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+ [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+ [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+ [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+ [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+ [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+ [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+ [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+ %0 = tensor.empty() : tensor<2x4x2x4xf32>
+ %pack = tensor.pack %cst inner_dims_pos = [1, 0] inner_tiles = [2, 4]
+ into %0 : tensor<8x8xf32> -> tensor<2x4x2x4xf32>
+ return %pack : tensor<2x4x2x4xf32>
+}
+
+// inner_dims_pos = [0, 1] with inner_tiles = [2, 4] tiles dim 0 by 2 and
+// dim 1 by 4, so packed[A][B][a][b] = cst[2 * A + a][4 * B + b].
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_inner_tiles_reordered
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [8.000000e+00, 9.000000e+00, 1.000000e+01, 1.100000e+01]
+// CHECK-SAME: [4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00], [1.200000e+01, 1.300000e+01, 1.400000e+01, 1.500000e+01]
+// CHECK-SAME: [1.600000e+01, 1.700000e+01, 1.800000e+01, 1.900000e+01], [2.400000e+01, 2.500000e+01, 2.600000e+01, 2.700000e+01]
+// CHECK-SAME: [2.000000e+01, 2.100000e+01, 2.200000e+01, 2.300000e+01], [2.800000e+01, 2.900000e+01, 3.000000e+01, 3.100000e+01]
+// CHECK-SAME: [3.200000e+01, 3.300000e+01, 3.400000e+01, 3.500000e+01], [4.000000e+01, 4.100000e+01, 4.200000e+01, 4.300000e+01]
+// CHECK-SAME: [3.600000e+01, 3.700000e+01, 3.800000e+01, 3.900000e+01], [4.400000e+01, 4.500000e+01, 4.600000e+01, 4.700000e+01]
+// CHECK-SAME: [4.900000e+01, 5.000000e+01, 5.100000e+01, 5.200000e+01], [5.700000e+01, 5.800000e+01, 5.900000e+01, 6.000000e+01]
+// CHECK-SAME: [5.300000e+01, 5.400000e+01, 5.500000e+01, 5.600000e+01], [6.100000e+01, 6.200000e+01, 6.300000e+01, 6.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<4x2x2x4xf32>
+func.func @fold_pack_with_non_splat_with_inner_tiles_reordered() -> tensor<4x2x2x4xf32> {
+ %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+ [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+ [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+ [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+ [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+ [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+ [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+ [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+ %0 = tensor.empty() : tensor<4x2x2x4xf32>
+ %pack = tensor.pack %cst inner_dims_pos = [0, 1] inner_tiles = [2, 4]
+ into %0 : tensor<8x8xf32> -> tensor<4x2x2x4xf32>
+ return %pack : tensor<4x2x2x4xf32>
+}
+
+// outer_dims_perm = [1, 0] swaps the two outer (tile-grid) dimensions of the
+// [4, 2]-tiled result, so packed[B][A][a][b] = cst[4 * A + a][2 * B + b].
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_outer_permutation
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00], [8.000000e+00, 9.000000e+00], [1.600000e+01, 1.700000e+01], [2.400000e+01, 2.500000e+01]
+// CHECK-SAME: [3.200000e+01, 3.300000e+01], [4.000000e+01, 4.100000e+01], [4.900000e+01, 5.000000e+01], [5.700000e+01, 5.800000e+01]
+// CHECK-SAME: [2.000000e+00, 3.000000e+00], [1.000000e+01, 1.100000e+01], [1.800000e+01, 1.900000e+01], [2.600000e+01, 2.700000e+01]
+// CHECK-SAME: [3.400000e+01, 3.500000e+01], [4.200000e+01, 4.300000e+01], [5.100000e+01, 5.200000e+01], [5.900000e+01, 6.000000e+01]
+// CHECK-SAME: [4.000000e+00, 5.000000e+00], [1.200000e+01, 1.300000e+01], [2.000000e+01, 2.100000e+01], [2.800000e+01, 2.900000e+01]
+// CHECK-SAME: [3.600000e+01, 3.700000e+01], [4.400000e+01, 4.500000e+01], [5.300000e+01, 5.400000e+01], [6.100000e+01, 6.200000e+01]
+// CHECK-SAME: [6.000000e+00, 7.000000e+00], [1.400000e+01, 1.500000e+01], [2.200000e+01, 2.300000e+01], [3.000000e+01, 3.100000e+01]
+// CHECK-SAME: [3.800000e+01, 3.900000e+01], [4.600000e+01, 4.700000e+01], [5.500000e+01, 5.600000e+01], [6.300000e+01, 6.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<4x2x4x2xf32>
+func.func @fold_pack_with_non_splat_with_outer_permutation() -> tensor<4x2x4x2xf32> {
+ %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+ [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+ [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+ [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+ [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+ [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+ [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+ [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+ %0 = tensor.empty() : tensor<4x2x4x2xf32>
+ %pack = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [4, 2]
+ into %0 : tensor<8x8xf32> -> tensor<4x2x4x2xf32>
+ return %pack : tensor<4x2x4x2xf32>
+}
+
+// Only dim 3 is tiled (by 2); outer_dims_perm = [0, 3, 1, 2] moves its
+// tile-grid dimension to position 1, so
+// packed[0][K][I][J][p] = cst[0][I][J][2 * K + p].
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_inner_and_outer_permutations
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00], [4.000000e+00, 5.000000e+00]
+// CHECK-SAME: [8.000000e+00, 9.000000e+00], [1.200000e+01, 1.300000e+01]
+// CHECK-SAME: [2.000000e+00, 3.000000e+00], [6.000000e+00, 7.000000e+00]
+// CHECK-SAME: [1.000000e+01, 1.100000e+01], [1.400000e+01, 1.500000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<1x2x2x2x2xf32>
+func.func @fold_pack_with_non_splat_with_inner_and_outer_permutations() -> tensor<1x2x2x2x2xf32> {
+ %cst = arith.constant dense <[[[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]],
+ [[8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0]]]]> : tensor<1x2x2x4xf32>
+ %0 = tensor.empty() : tensor<1x2x2x2x2xf32>
+ %1 = tensor.pack %cst outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2]
+ into %0 : tensor<1x2x2x4xf32> -> tensor<1x2x2x2x2xf32>
+ return %1 : tensor<1x2x2x2x2xf32>
+}
+
+// The pack writes into a function argument rather than a tensor.empty, so the
+// rewrite must not fire and the tensor.pack is kept.
+// CHECK-LABEL: func.func @no_fold_pack_into_non_empty_with_non_splat
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK: return %[[PACK]] : tensor<2x4x2x4xf32>
+func.func @no_fold_pack_into_non_empty_with_non_splat(%arg0: tensor<2x4x2x4xf32>) -> tensor<2x4x2x4xf32> {
+ %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+ [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+ [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+ [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+ [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+ [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+ [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+ [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+ %pack = tensor.pack %cst inner_dims_pos = [1, 0] inner_tiles = [2, 4]
+ into %arg0 : tensor<8x8xf32> -> tensor<2x4x2x4xf32>
+ return %pack : tensor<2x4x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @no_fold_dynamic_inner_tile_pack_with_non_splat
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK: return %[[PACK]] : tensor<?x4x2x?xf32>
+func.func @no_fold_dynamic_inner_tile_pack_with_non_splat(%outer: index, %tile: index) -> tensor<?x4x2x?xf32> {
+ %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+ [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+ [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+ [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+ [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+ [40.0, 41....
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/93954
More information about the Mlir-commits mailing list