[Mlir-commits] [mlir] [mlir][tensor] Rewrite tensor.pack as a constant (PR #93954)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 31 05:23:43 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tensor

Author: Adam Siemieniuk (adam-smnk)

<details>
<summary>Changes</summary>

Adds a pattern to rewrite tensor.pack into arith.constant to avoid runtime packing of a constant tensor.

---

Patch is 22.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93954.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp (+154-1) 
- (modified) mlir/test/Dialect/Tensor/rewrite-as-constant.mlir (+208) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
index 11e1de543ac91..b63551c268ddc 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
@@ -6,10 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 //
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Threading.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -45,9 +48,159 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
   }
 };
 
+/// Rewrite tensor.pack of a dense arith.constant source into a pre-packed
+/// arith.constant. Requires a tensor.empty dest, static dest shape, no padding.
+struct PackToConstant : OpRewritePattern<tensor::PackOp> {
+  using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto constOp = packOp.getSource().getDefiningOp<arith::ConstantOp>();
+    if (!constOp)
+      return failure();
+    // Must be a dense constant.
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+    if (!denseAttr)
+      return failure();
+
+    // Bail out if the pack is used as a writing operation i.e.,
+    // the destination is not a tensor.empty.
+    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
+      return rewriter.notifyMatchFailure(packOp,
+                                         "expects empty tensor destination");
+    // Pack destination must have static shape.
+    if (!packOp.getDestType().hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          packOp, "expects destination with static shape");
+
+    // Pack with padding is not supported currently.
+    // TODO: Insert padding values as a part of rewrite.
+    if (packOp.getPaddingValue())
+      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+    OpBuilder::InsertionGuard guard(rewriter);
+
+    // If it is a splat constant, rewrite the pack directly.
+    if (denseAttr.isSplat()) {
+      DenseElementsAttr packedDenseShape =
+          denseAttr.reshape(packOp.getDestType());
+      rewriter.setInsertionPoint(constOp);
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);
+
+      return success();
+    }
+
+    // Constant contains non-splat dense values.
+    // Move the data into a new packed buffer. Each value is placed into its new
+    // position as defined by the pack operation.
+    ArrayRef<char> srcRawData = denseAttr.getRawData();
+    SmallVector<char> destRawData(srcRawData.size());
+
+    int64_t numberOfElements = denseAttr.getNumElements();
+    SmallVector<int64_t> strides =
+        computeStrides(packOp.getDestType().getShape());
+
+    // Parallelize raw data movement to speedup large constant packing.
+    parallelFor(
+        packOp.getContext(), 0, numberOfElements,
+        [&](size_t destLinearizedIdx) {
+          // Step 1: De-linearize destination index.
+          // f(lin) = tmp[A][B][C]
+          SmallVector<int64_t> destIndices =
+              delinearize(destLinearizedIdx, strides);
+
+          // Step 2: Arrange the indexes based on the packing information.
+          // Compute inverse of outerDimsPerm to bring the loops into the
+          // canonical form tmp[A][B][a][b].
+          if (!packOp.getOuterDimsPerm().empty()) {
+            SmallVector<int64_t> inversePermutation =
+                invertPermutationVector(packOp.getOuterDimsPerm());
+            SmallVector<int64_t> tileLoops;
+            for (int64_t i = 0; i < packOp.getSourceType().getRank(); i++)
+              tileLoops.push_back(destIndices[i]);
+            applyPermutationToVector(tileLoops, inversePermutation);
+
+            SmallVector<int64_t> pointLoops;
+            for (size_t i = packOp.getSourceType().getRank();
+                 i < destIndices.size(); i++) {
+              pointLoops.push_back(destIndices[i]);
+            }
+
+            destIndices = tileLoops;
+            destIndices.append(pointLoops.begin(), pointLoops.end());
+          }
+          assert(destIndices.size() ==
+                 static_cast<size_t>(packOp.getDestType().getRank()));
+
+          // After interchanging the outermost tiled loop we end up in the
+          // canonical form tmp[A][B][a][b]. Squash the point loops with the
+          // tiled ones.
+          llvm::DenseSet<int64_t> tiledLoops(packOp.getInnerDimsPos().begin(),
+                                             packOp.getInnerDimsPos().end());
+          llvm::DenseMap<int64_t, int64_t> mappingTileToPointLoops;
+          // Map each tiled source dim to its position within inner_dims_pos.
+          // The matching point loop sits at index (source rank + that position).
+          // For example, with inner_dims_pos = [0, 1]:
+          //   [A][B] -> [A][B][a][b]
+          //   A -> 0 (point loop a at index 2), B -> 1 (point loop b at index 3)
+          // and with inner_dims_pos = [1]:
+          //   [A][B] -> [A][B][b], B -> 0 (point loop b at index 2)
+          for (auto [idx, tileLoop] : llvm::enumerate(packOp.getInnerDimsPos()))
+            mappingTileToPointLoops[tileLoop] = idx;
+
+          // Tile sizes follow the order of inner_dims_pos, not of source dims.
+          SmallVector<int64_t> tilesSizes = packOp.getStaticTiles();
+          int64_t numberOfTileLoops = packOp.getSourceType().getRank();
+          SmallVector<int64_t> srcIndices;
+          for (int64_t i = 0; i < numberOfTileLoops; i++) {
+            if (!tiledLoops.count(i)) {
+              // Loop is not tiled.
+              srcIndices.push_back(destIndices[i]);
+            } else {
+              // Loop is tiled: scale by the tile size of this very dimension
+              // and add the offset of its point loop.
+              int64_t tileIdx = mappingTileToPointLoops.lookup(i);
+              srcIndices.push_back(destIndices[i] * tilesSizes[tileIdx] +
+                                   destIndices[numberOfTileLoops + tileIdx]);
+            }
+          }
+          assert(srcIndices.size() == static_cast<size_t>(numberOfTileLoops));
+
+          int64_t srcLinearizedIdx = linearize(
+              srcIndices, computeStrides(packOp.getSourceType().getShape()));
+          assert(srcLinearizedIdx < numberOfElements);
+
+          // Step 3: Do the packing.
+          // Copy the source element byte-wise to its packed destination
+          // position.
+          size_t elementByteSize =
+              denseAttr.getRawData().size() / denseAttr.getNumElements();
+          for (size_t i = 0; i < elementByteSize; i++) {
+            destRawData[destLinearizedIdx * elementByteSize + i] =
+                srcRawData[srcLinearizedIdx * elementByteSize + i];
+          }
+        });
+
+    // Fail gracefully if something went wrong.
+    bool detectSplat = false;
+    if (!DenseElementsAttr::isValidRawBuffer(packOp.getDestType(), destRawData,
+                                             detectSplat))
+      return rewriter.notifyMatchFailure(
+          packOp, "failed to create packed raw data buffer");
+
+    // Replace the pack with a new constant.
+    auto packedDenseShape =
+        DenseElementsAttr::getFromRawBuffer(packOp.getDestType(), destRawData);
+    rewriter.setInsertionPoint(constOp);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateRewriteAsConstantPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<GenerateToConstant>(patterns.getContext());
+  patterns.add<GenerateToConstant, PackToConstant>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
index 1a1cf9e407d80..045cb5a0da1d5 100644
--- a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
+++ b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
@@ -21,3 +21,211 @@ func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
   } : tensor<2x3x5xf32>
   return %0 : tensor<2x3x5xf32>
 }
+
+// A splat constant folds directly: the result is the same splat value carrying
+// the packed destination type.
+// CHECK-LABEL: func.func @fold_pack_with_splat
+// CHECK: %[[CST:.+]] = arith.constant dense<1> : tensor<8x2x1x1x32x32xi64>
+// CHECK-NEXT: return %[[CST]] : tensor<8x2x1x1x32x32xi64>
+func.func @fold_pack_with_splat() ->  tensor<8x2x1x1x32x32xi64> {
+  %cst = arith.constant dense<1> : tensor<1x1x64x256xi64>
+  %0 = tensor.empty() : tensor<8x2x1x1x32x32xi64>
+  %pack = tensor.pack %cst outer_dims_perm = [3, 2, 0, 1] inner_dims_pos = [2, 3] inner_tiles = [32, 32]
+    into %0 : tensor<1x1x64x256xi64> -> tensor<8x2x1x1x32x32xi64>
+  return  %pack : tensor<8x2x1x1x32x32xi64>
+}
+
+// Non-splat 8x8 source split into 4x2 tiles. Each CHECK-SAME line below lists
+// the four rows of one tile, walking the tiles left-to-right, then top-to-bottom.
+// Note that the source values skip 48.0: the last two rows are 49..56 and 57..64.
+// CHECK-LABEL: func.func @fold_pack_with_non_splat
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00], [8.000000e+00, 9.000000e+00], [1.600000e+01, 1.700000e+01], [2.400000e+01, 2.500000e+01]
+// CHECK-SAME: [2.000000e+00, 3.000000e+00], [1.000000e+01, 1.100000e+01], [1.800000e+01, 1.900000e+01], [2.600000e+01, 2.700000e+01]
+// CHECK-SAME: [4.000000e+00, 5.000000e+00], [1.200000e+01, 1.300000e+01], [2.000000e+01, 2.100000e+01], [2.800000e+01, 2.900000e+01]
+// CHECK-SAME: [6.000000e+00, 7.000000e+00], [1.400000e+01, 1.500000e+01], [2.200000e+01, 2.300000e+01], [3.000000e+01, 3.100000e+01]
+// CHECK-SAME: [3.200000e+01, 3.300000e+01], [4.000000e+01, 4.100000e+01], [4.900000e+01, 5.000000e+01], [5.700000e+01, 5.800000e+01]
+// CHECK-SAME: [3.400000e+01, 3.500000e+01], [4.200000e+01, 4.300000e+01], [5.100000e+01, 5.200000e+01], [5.900000e+01, 6.000000e+01]
+// CHECK-SAME: [3.600000e+01, 3.700000e+01], [4.400000e+01, 4.500000e+01], [5.300000e+01, 5.400000e+01], [6.100000e+01, 6.200000e+01]
+// CHECK-SAME: [3.800000e+01, 3.900000e+01], [4.600000e+01, 4.700000e+01], [5.500000e+01, 5.600000e+01], [6.300000e+01, 6.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<2x4x4x2xf32>
+func.func @fold_pack_with_non_splat() -> tensor<2x4x4x2xf32> {
+  %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+                               [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+                               [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+                               [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+                               [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+                               [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+                               [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+                               [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %pack = tensor.pack %cst inner_dims_pos = [0, 1] inner_tiles = [4, 2]
+    into %0 : tensor<8x8xf32> -> tensor<2x4x4x2xf32>
+  return %pack : tensor<2x4x4x2xf32>
+}
+
+// inner_dims_pos = [1, 0] with inner_tiles = [2, 4]: source dim 1 is tiled by 2
+// and source dim 0 by 4, so dest[i][j][a][b] = src[i * 4 + b][j * 2 + a].
+// Each CHECK-SAME line below lists one dest[i][j] tile (two rows of four values).
+// A pack without padding is a permutation: every source value appears once.
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_inner_dims_reordered
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 8.000000e+00, 1.600000e+01, 2.400000e+01], [1.000000e+00, 9.000000e+00, 1.700000e+01, 2.500000e+01]
+// CHECK-SAME: [2.000000e+00, 1.000000e+01, 1.800000e+01, 2.600000e+01], [3.000000e+00, 1.100000e+01, 1.900000e+01, 2.700000e+01]
+// CHECK-SAME: [4.000000e+00, 1.200000e+01, 2.000000e+01, 2.800000e+01], [5.000000e+00, 1.300000e+01, 2.100000e+01, 2.900000e+01]
+// CHECK-SAME: [6.000000e+00, 1.400000e+01, 2.200000e+01, 3.000000e+01], [7.000000e+00, 1.500000e+01, 2.300000e+01, 3.100000e+01]
+// CHECK-SAME: [3.200000e+01, 4.000000e+01, 4.900000e+01, 5.700000e+01], [3.300000e+01, 4.100000e+01, 5.000000e+01, 5.800000e+01]
+// CHECK-SAME: [3.400000e+01, 4.200000e+01, 5.100000e+01, 5.900000e+01], [3.500000e+01, 4.300000e+01, 5.200000e+01, 6.000000e+01]
+// CHECK-SAME: [3.600000e+01, 4.400000e+01, 5.300000e+01, 6.100000e+01], [3.700000e+01, 4.500000e+01, 5.400000e+01, 6.200000e+01]
+// CHECK-SAME: [3.800000e+01, 4.600000e+01, 5.500000e+01, 6.300000e+01], [3.900000e+01, 4.700000e+01, 5.600000e+01, 6.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<2x4x2x4xf32>
+func.func @fold_pack_with_non_splat_with_inner_dims_reordered() -> tensor<2x4x2x4xf32> {
+  %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+                               [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+                               [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+                               [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+                               [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+                               [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+                               [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+                               [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+  %0 = tensor.empty() : tensor<2x4x2x4xf32>
+  %pack = tensor.pack %cst inner_dims_pos = [1, 0] inner_tiles = [2, 4]
+    into %0 : tensor<8x8xf32> -> tensor<2x4x2x4xf32>
+  return %pack : tensor<2x4x2x4xf32>
+}
+
+// Same 8x8 source, but tiled 2x4 (inner_dims_pos = [0, 1], inner_tiles = [2, 4]).
+// Each CHECK-SAME line below lists the two rows of one 2x4 tile.
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_inner_tiles_reordered
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [8.000000e+00, 9.000000e+00, 1.000000e+01, 1.100000e+01]
+// CHECK-SAME: [4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00], [1.200000e+01, 1.300000e+01, 1.400000e+01, 1.500000e+01]
+// CHECK-SAME: [1.600000e+01, 1.700000e+01, 1.800000e+01, 1.900000e+01], [2.400000e+01, 2.500000e+01, 2.600000e+01, 2.700000e+01]
+// CHECK-SAME: [2.000000e+01, 2.100000e+01, 2.200000e+01, 2.300000e+01], [2.800000e+01, 2.900000e+01, 3.000000e+01, 3.100000e+01]
+// CHECK-SAME: [3.200000e+01, 3.300000e+01, 3.400000e+01, 3.500000e+01], [4.000000e+01, 4.100000e+01, 4.200000e+01, 4.300000e+01]
+// CHECK-SAME: [3.600000e+01, 3.700000e+01, 3.800000e+01, 3.900000e+01], [4.400000e+01, 4.500000e+01, 4.600000e+01, 4.700000e+01]
+// CHECK-SAME: [4.900000e+01, 5.000000e+01, 5.100000e+01, 5.200000e+01], [5.700000e+01, 5.800000e+01, 5.900000e+01, 6.000000e+01]
+// CHECK-SAME: [5.300000e+01, 5.400000e+01, 5.500000e+01, 5.600000e+01], [6.100000e+01, 6.200000e+01, 6.300000e+01, 6.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<4x2x2x4xf32>
+func.func @fold_pack_with_non_splat_with_inner_tiles_reordered() -> tensor<4x2x2x4xf32> {
+  %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+                               [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+                               [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+                               [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+                               [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+                               [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+                               [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+                               [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+  %0 = tensor.empty() : tensor<4x2x2x4xf32>
+  %pack = tensor.pack %cst inner_dims_pos = [0, 1] inner_tiles = [2, 4]
+    into %0 : tensor<8x8xf32> -> tensor<4x2x2x4xf32>
+  return %pack : tensor<4x2x2x4xf32>
+}
+
+// 4x2 tiles with outer_dims_perm = [1, 0]: the two outer (tile) loops are
+// swapped, so consecutive CHECK-SAME lines (one tile each) walk down a column
+// of tiles before moving to the next column.
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_outer_permutation
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00], [8.000000e+00, 9.000000e+00], [1.600000e+01, 1.700000e+01], [2.400000e+01, 2.500000e+01]
+// CHECK-SAME: [3.200000e+01, 3.300000e+01], [4.000000e+01, 4.100000e+01], [4.900000e+01, 5.000000e+01], [5.700000e+01, 5.800000e+01]
+// CHECK-SAME: [2.000000e+00, 3.000000e+00], [1.000000e+01, 1.100000e+01], [1.800000e+01, 1.900000e+01], [2.600000e+01, 2.700000e+01]
+// CHECK-SAME: [3.400000e+01, 3.500000e+01], [4.200000e+01, 4.300000e+01], [5.100000e+01, 5.200000e+01], [5.900000e+01, 6.000000e+01]
+// CHECK-SAME: [4.000000e+00, 5.000000e+00], [1.200000e+01, 1.300000e+01], [2.000000e+01, 2.100000e+01], [2.800000e+01, 2.900000e+01]
+// CHECK-SAME: [3.600000e+01, 3.700000e+01], [4.400000e+01, 4.500000e+01], [5.300000e+01, 5.400000e+01], [6.100000e+01, 6.200000e+01]
+// CHECK-SAME: [6.000000e+00, 7.000000e+00], [1.400000e+01, 1.500000e+01], [2.200000e+01, 2.300000e+01], [3.000000e+01, 3.100000e+01]
+// CHECK-SAME: [3.800000e+01, 3.900000e+01], [4.600000e+01, 4.700000e+01], [5.500000e+01, 5.600000e+01], [6.300000e+01, 6.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<4x2x4x2xf32>
+func.func @fold_pack_with_non_splat_with_outer_permutation() -> tensor<4x2x4x2xf32> {
+  %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+                               [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+                               [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+                               [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+                               [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+                               [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+                               [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+                               [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+  %0 = tensor.empty() : tensor<4x2x4x2xf32>
+  %pack = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [4, 2]
+    into %0 : tensor<8x8xf32> -> tensor<4x2x4x2xf32>
+  return %pack : tensor<4x2x4x2xf32>
+}
+
+// Combines an outer permutation ([0, 3, 1, 2]) with a single tiled dimension
+// (source dim 3, tile size 2). Each CHECK-SAME line lists two packed 2-element
+// tiles. The CHECK-LABEL below is a prefix of the full function name.
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_inner_and_outer
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00], [4.000000e+00, 5.000000e+00]
+// CHECK-SAME: [8.000000e+00, 9.000000e+00], [1.200000e+01, 1.300000e+01]
+// CHECK-SAME: [2.000000e+00, 3.000000e+00], [6.000000e+00, 7.000000e+00]
+// CHECK-SAME: [1.000000e+01, 1.100000e+01], [1.400000e+01, 1.500000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<1x2x2x2x2xf32>
+func.func @fold_pack_with_non_splat_with_inner_and_outer_permutations() -> tensor<1x2x2x2x2xf32> {
+  %cst = arith.constant dense <[[[[0.0, 1.0, 2.0, 3.0],   [4.0, 5.0, 6.0, 7.0]],
+                                 [[8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0]]]]> : tensor<1x2x2x4xf32>
+  %0 = tensor.empty() : tensor<1x2x2x2x2xf32>
+  %1 = tensor.pack %cst outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2]
+    into %0 : tensor<1x2x2x4xf32> -> tensor<1x2x2x2x2xf32>
+  return %1 : tensor<1x2x2x2x2xf32>
+}
+
+// Negative test: the destination is a function argument rather than a
+// tensor.empty, so the tensor.pack is expected to remain in the output.
+// CHECK-LABEL: func.func @no_fold_pack_into_non_empty_with_non_splat
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK: return %[[PACK]] : tensor<2x4x2x4xf32>
+func.func @no_fold_pack_into_non_empty_with_non_splat(%arg0: tensor<2x4x2x4xf32>) -> tensor<2x4x2x4xf32> {
+  %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+                               [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+                               [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+                               [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+                               [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+                               [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+                               [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+                               [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+  %pack = tensor.pack %cst inner_dims_pos = [1, 0] inner_tiles = [2, 4]
+    into %arg0 : tensor<8x8xf32> -> tensor<2x4x2x4xf32>
+  return %pack : tensor<2x4x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @no_fold_dynamic_inner_tile_pack_with_non_splat
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK: return %[[PACK]] : tensor<?x4x2x?xf32>
+func.func @no_fold_dynamic_inner_tile_pack_with_non_splat(%outer: index, %tile: index) -> tensor<?x4x2x?xf32> {
+  %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+                               [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+                               [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+                               [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+                               [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+                               [40.0, 41....
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/93954


More information about the Mlir-commits mailing list