[Mlir-commits] [mlir] [mlir][tensor] Extend the logic to generalise tensor.pack (PR #109815)

Andrzej WarzyƄski llvmlistbot at llvm.org
Tue Oct 1 08:50:08 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/109815

>From bd5809e04aabf2a35cd925314157599ae46f215a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 19 Sep 2024 17:44:51 +0000
Subject: [PATCH 1/5] [mlir][tensor] Extend the logic to generalise tensor.pack

Extends the logic to generalise tensor.pack (into e.g. tensor.pad +
linalg.transpose) so that it also works when one of the inner tile sizes
is scalable (i.e. a constant multiple of `vector.vscale`). For example:
```mlir
  %c8 = arith.constant 8 : index
  %vscale = vector.vscale
  %c8_vscale = arith.muli %vscale, %c8 : index
  %0 = tensor.pack %input
      padding_value(%pad : f32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8_vscale, 2]
      into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
```
is generalised as:
```mlir
  %c8 = arith.constant 8 : index
  %vscale = vector.vscale
  %c8_vscale = arith.muli %vscale, %c8 : index
  %c5 = arith.constant 5 : index
  %0 = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%c8_vscale, %c5]
  %padded = tensor.pad %input low[0, 0] high[%0, 1] {
  ^bb0(%arg3: index, %arg4: index):
    tensor.yield %pad : f32
  } : tensor<5x1xf32> to tensor<?x2xf32>
```
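
For context, after the tensor.pad the remaining steps of the generalised
sequence (based on the CHECK lines of the test added in this patch) look
roughly as follows. This is only a sketch, with value names chosen for
readability; no linalg.transpose is needed here since the permutation is
the identity:
```mlir
  %slice = tensor.extract_slice %padded[0, 0] [%c8_vscale, 2] [1, 1]
      : tensor<?x2xf32> to tensor<?x2xf32>
  %c2 = arith.constant 2 : index
  %dim = tensor.dim %output, %c2 : tensor<1x1x?x2xf32>
  %res = tensor.insert_slice %slice into %output[0, 0, 0, 0] [1, 1, %dim, 2] [1, 1, 1, 1]
      : tensor<?x2xf32> into tensor<1x1x?x2xf32>
```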

At the Tensor level, scalability is modelled using dynamic shapes, so
this change essentially extends the relevant logic to also work for
dynamic shapes. However, rather than allowing an arbitrary number of
tile sizes to be dynamic, only _one_ tile size is allowed to be dynamic.
In addition, it is required to be a constant multiple of
`vector.vscale`.

While the requirements above could be relaxed, I wanted to avoid full
generality for now, primarily to avoid complexity that is not yet needed
and to make reviewing a bit easier.
---
 .../include/mlir/Dialect/Tensor/Utils/Utils.h |  3 +-
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 92 +++++++++++++++----
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       | 41 +++++++--
 .../Linalg/generalize-tensor-pack.mlir        | 35 +++++++
 4 files changed, 142 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 84d06d456bb689..b2366748804156 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -19,7 +19,8 @@ namespace tensor {
 // dimensions the padding width is set to zero. The op performs "high" padding
 // (i.e. it adds trailing padding values until the desired size is met).
 PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
-                      bool nofold, Location loc, OpBuilder &builder);
+                      bool nofold, Location loc, OpBuilder &builder,
+                      std::optional<Value> dynOutDim = {});
 
 // Creates dim ops for each dynamic dimension of the ranked tensor argument and
 // returns these as values.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e0dea8e78d55c1..42389b431566eb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1021,8 +1021,16 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
   return success();
 }
 
-/// Returns a tensor.pad op if padding value is set. Otherwise, returns the
-/// source directly. The method assumes that the `packOp` has static shapes.
+/// If padding value is set, returns a tensor.pad Op for the source tensor,
+/// with the output shape matching the output of `packOp`. Otherwise, returns
+/// the source directly.
+///
+/// This method assumes that all outer dims for this pack Op are 1.
+///
+/// At most _one_ inner tile size can be _dynamic_, all other inner tiles are
+/// required to have static sizes. The inner tile that's dynamic must be a
+/// multiple of vector.vscale (to support scalable tile sizes). This condition
+/// can be relaxed in the future.
 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                            tensor::PackOp packOp) {
   Value input = packOp.getSource();
@@ -1038,26 +1046,50 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   ShapedType inputType = packOp.getSourceType();
   int64_t inputRank = inputType.getRank();
 
-  SmallVector<int64_t> paddedShape;
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
       packOp.getDimAndTileMapping();
-  for (int64_t dim = 0; dim < inputRank; ++dim) {
-    int64_t size = inputType.getDimSize(dim);
-    if (!tileAndPosMapping.count(dim)) {
-      paddedShape.push_back(size);
+
+  // The size of a scalable tile (if present).
+  Value scalableSize;
+
+  // Collect dims for the padded shape.
+  SmallVector<int64_t> paddedShape;
+  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
+    int64_t inputDimSize = inputType.getDimSize(dimIdx);
+    // 1. Non-tiled outer dims.
+    // These dims should be 1 and we simply preserve them.
+    if (!tileAndPosMapping.count(dimIdx)) {
+      assert(inputDimSize == 1 &&
+             "with all outer dims == 1, this non-tiled input dim should be 1!");
+      paddedShape.push_back(inputDimSize);
+      continue;
+    }
+
+    // 2. Tiled outer dims
+    // As all outer dims == 1, it is safe to use the tile size for the padded
+    // shape.
+    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
+
+    // 2.1 Static tile sizes
+    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
+    if (cstTileSize.has_value()) {
+      paddedShape.push_back(cstTileSize.value());
       continue;
     }
 
-    // The size is less than or equal to tileSize because outer dims are all 1s.
-    std::optional<int64_t> tileSize =
-        getConstantIntValue(tileAndPosMapping.lookup(dim));
-    assert(tileSize.has_value() && "dynamic inner tile size is not supported");
-    paddedShape.push_back(tileSize.value());
+    // 2.2 Dynamic tile sizes
+    paddedShape.push_back(ShapedType::kDynamic);
+
+    // Get the value that holds the scalable size.
+    assert(!scalableSize && "Only one scalable size is supported ATM.");
+    scalableSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
+    assert(vector::getConstantVscaleMultiplier(scalableSize) &&
+           "This dynamic shape is not a multiple of vscale, this !");
   }
   auto resultType =
       RankedTensorType::get(paddedShape, inputType.getElementType());
   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
-                                 /*nofold=*/false, loc, builder);
+                                 /*nofold=*/false, loc, builder, scalableSize);
 }
 
 // Normalizes a permutation on a higher rank space to its actual size, e.g.
@@ -1120,10 +1152,18 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
-  if (llvm::any_of(packOp.getMixedTiles(),
-                   [](OpFoldResult tile) { return tile.is<Value>(); })) {
-    return rewriter.notifyMatchFailure(packOp,
-                                       "require inner tile sizes being static");
+  if (llvm::any_of(packOp.getMixedTiles(), [](OpFoldResult tile) {
+        return tile.is<Value>() && !vector::getConstantVscaleMultiplier(
+                                       llvm::dyn_cast<Value>(tile));
+      })) {
+    return rewriter.notifyMatchFailure(
+        packOp, "require inner tile sizes to be either static or a constant "
+                "multiple of vector.vscale");
+  }
+  if (llvm::count_if(packOp.getMixedTiles(),
+                     [](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
+    return rewriter.notifyMatchFailure(
+        packOp, "at most one dynamic tile size is supported");
   }
 
   // TODO: support the case that outer dimensions are not all 1s. A
@@ -1181,7 +1221,23 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<int64_t> transpShape = readShape;
   applyPermutationToVector<int64_t>(transpShape, perm);
 
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  // If there's a tile with a scalable size, retrieve its size. ATM only 1
+  // scalable tile is allowed.
+  Value scalableSize;
+  for (auto tile : packOp.getMixedTiles()) {
+    if (tile.is<Value>()) {
+      assert(!scalableSize && "Only one scalable size is supported ATM.");
+      scalableSize = cast<Value>(tile);
+      assert(vector::getConstantVscaleMultiplier(scalableSize) &&
+             "This dynamic shape is not a multiple of vscale!");
+    }
+  }
+
+  Value empty =
+      ShapedType::isDynamicShape(cast<ShapedType>(input.getType()).getShape())
+          ? rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType,
+                                             scalableSize)
+          : rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
 
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index a0d8a08fc6ba47..ffa58b796f91ae 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR//VectorOps.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 using namespace mlir;
@@ -23,19 +24,39 @@ using namespace mlir::tensor;
 
 PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
                                     Value pad, bool nofold, Location loc,
-                                    OpBuilder &b) {
+                                    OpBuilder &b,
+                                    std::optional<Value> dynOutDim) {
+  assert(llvm::count_if(
+             type.getShape(),
+             [](int64_t dim) { return ShapedType::isDynamic(dim); }) <= 1 &&
+         "At most one output dim can be dynamic!");
+
+  // Init "low" and "high" padding values ("low" is kept as is, "high" is
+  // computed below).
   SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
   SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
   for (const auto &en : enumerate(type.getShape())) {
-    // Pad only the static dimensions of the result tensor type.
-    if (ShapedType::isDynamic(en.value()))
-      continue;
-    // Compute the padding width.
-    AffineExpr d0;
-    bindDims(b.getContext(), d0);
-    OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
-    high[en.index()] =
-        affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
+    if (!ShapedType::isDynamic(en.value())) {
+      // Static sizes - the "high" value is computed based on the input and
+      // output dims. Compute the padding width.
+      AffineExpr d0;
+      bindDims(b.getContext(), d0);
+      OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
+      high[en.index()] =
+          affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
+    } else {
+      // Dynamic sizes - the "high" value is computed based on the input dim
+      // and `dynOutDim`.
+      assert(dynOutDim.has_value() &&
+             "dynamic output dim requires dynOutDim to be set");
+
+      // Compute the padding width.
+      AffineExpr d0, d1;
+      auto dimVal = b.create<tensor::DimOp>(loc, source, en.index());
+      bindDims(b.getContext(), d0, d1);
+      high[en.index()] = affine::makeComposedFoldedAffineApply(
+          b, loc, d0 - d1, {dynOutDim.value(), dimVal.getResult()});
+    }
   }
   return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
 }
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 7d87a0994004fe..66a220005ebf36 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -23,6 +23,8 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
   %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
   return %0 : tensor<1x1x8x2xf32>
 }
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 - s1)>
+
 // CHECK-LABEL: func.func @simple_pad_and_pack
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
@@ -34,6 +36,39 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
 // CHECK-SAME:      [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
 
+/// Same as example above, but with scalable sizes.
+
+/// NOTE: For this example to make sense in practice, the "?" in the output shape
+///       should effectively be 8 * vector.vscale (and that's what tensor.dim
+///       below should return).
+
+func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
+  %c8 = arith.constant 8 : index
+  %vscale = vector.vscale
+  %c8_vscale = arith.muli %vscale, %c8 : index
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+  return %0 : tensor<1x1x?x2xf32>
+}
+
+
+// CHECK-LABEL:   func.func @simple_pad_and_pack_scalable(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
+// CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
+// CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
+// CHECK:           %[[C2:.+]] = arith.constant 2 : index
+// CHECK:           %[[C5:.+]] = arith.constant 5 : index
+// CHECK:           %[[C8:.+]] = arith.constant 8 : index
+// CHECK:           %[[VS:.+]] = vector.vscale
+// CHECK:           %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
+// CHECK:           %[[PAD_HIGH:.+]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]], %[[C5]]]
+// CHECK:           %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK:             tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT:       linalg.transpose
+// CHECK:           %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK:           %[[DIM:.+]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
+// CHECK:           %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
+
 // -----
 
 func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{

>From 68a67139c56f0e026b8ee6acc7c3900a5f75e020 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 26 Sep 2024 14:13:30 +0100
Subject: [PATCH 2/5] fixup! [mlir][tensor] Extend the logic to generalise
 tensor.pack

Address PR comments from Han-Chung. Also some clean-up, plus relax the
requirement that the dynamic inner tile size has to be a constant
multiple of vector.vscale.
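
For example, with this restriction relaxed, a pack whose inner tile size
is a plain dynamic value `%high` (see the new @simple_pad_and_pack_dynamic
test) is generalised to roughly the following. This is only a sketch
derived from the test's CHECK lines, with value names chosen for
readability:
```mlir
  %c2 = arith.constant 2 : index
  %pad_high = affine.apply affine_map<()[s0] -> (s0 - 5)>()[%high]
  %padded = tensor.pad %input low[0, 0] high[%pad_high, 1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %pad : f32
  } : tensor<5x1xf32> to tensor<?x2xf32>
  %slice = tensor.extract_slice %padded[0, 0] [%high, 2] [1, 1]
      : tensor<?x2xf32> to tensor<?x2xf32>
  %dim = tensor.dim %output, %c2 : tensor<1x1x?x2xf32>
  %res = tensor.insert_slice %slice into %output[0, 0, 0, 0] [1, 1, %dim, 2] [1, 1, 1, 1]
      : tensor<?x2xf32> into tensor<1x1x?x2xf32>
```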
---
 .../include/mlir/Dialect/Tensor/Utils/Utils.h |  8 ++--
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 42 +++++++-----------
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       | 43 ++++++++-----------
 .../Linalg/generalize-tensor-pack.mlir        | 30 ++++++++++---
 4 files changed, 62 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index b2366748804156..4c3d725fbf56e2 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -14,10 +14,10 @@
 namespace mlir {
 namespace tensor {
 
-// Return a PadOp that pads `source` to `type` size where the static
-// sizes are assumed to be greater than the dynamic sizes. If `type` has dynamic
-// dimensions the padding width is set to zero. The op performs "high" padding
-// (i.e. it adds trailing padding values until the desired size is met).
+// Return a PadOp that pads `source` to `type` size. Output sizes (from `type`)
+// are assumed to be static and greater than the potentially dynamic input sizes
+// (from `source`). The op performs "high" padding (i.e. it adds trailing
+// padding values until the desired size is met).
 PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
                       bool nofold, Location loc, OpBuilder &builder,
                       std::optional<Value> dynOutDim = {});
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 42389b431566eb..47c37a347506d7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1028,9 +1028,8 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
 /// This method assumes that all outer dims for this pack Op are 1.
 ///
 /// At most _one_ inner tile size can be _dynamic_, all other inner tiles are
-/// required to have static sizes. The inner tile that's dynamic must be a
-/// multiple of vector.vscale (to support scalable tile sizes). This condition
-/// can be relaxed in the future.
+/// required to have static sizes. This restriction can be relaxed in the
+/// future.
 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                            tensor::PackOp packOp) {
   Value input = packOp.getSource();
@@ -1049,8 +1048,8 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
       packOp.getDimAndTileMapping();
 
-  // The size of a scalable tile (if present).
-  Value scalableSize;
+  // The size of a dynamic tile (if present).
+  Value dynamicTileSize;
 
   // Collect dims for the padded shape.
   SmallVector<int64_t> paddedShape;
@@ -1080,16 +1079,15 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
     // 2.2 Dynamic tile sizes
     paddedShape.push_back(ShapedType::kDynamic);
 
-    // Get the value that holds the scalable size.
-    assert(!scalableSize && "Only one scalable size is supported ATM.");
-    scalableSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
-    assert(vector::getConstantVscaleMultiplier(scalableSize) &&
-           "This dynamic shape is not a multiple of vscale, this !");
+    // Get the value that holds the dynamic size.
+    assert(!dynamicTileSize && "Only one dynamic tile is supported ATM.");
+    dynamicTileSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
   }
   auto resultType =
       RankedTensorType::get(paddedShape, inputType.getElementType());
   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
-                                 /*nofold=*/false, loc, builder, scalableSize);
+                                 /*nofold=*/false, loc, builder,
+                                 dynamicTileSize);
 }
 
 // Normalizes a permutation on a higher rank space to its actual size, e.g.
@@ -1152,14 +1150,6 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
-  if (llvm::any_of(packOp.getMixedTiles(), [](OpFoldResult tile) {
-        return tile.is<Value>() && !vector::getConstantVscaleMultiplier(
-                                       llvm::dyn_cast<Value>(tile));
-      })) {
-    return rewriter.notifyMatchFailure(
-        packOp, "require inner tile sizes to be either static or a constant "
-                "multiple of vector.vscale");
-  }
   if (llvm::count_if(packOp.getMixedTiles(),
                      [](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
     return rewriter.notifyMatchFailure(
@@ -1221,22 +1211,20 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<int64_t> transpShape = readShape;
   applyPermutationToVector<int64_t>(transpShape, perm);
 
-  // If there's a tile with a scalable size, retrieve its size. ATM only 1
-  // scalable tile is allowed.
-  Value scalableSize;
+  // If there's a tile with a dynamic size, retrieve its size. ATM only 1
+  // dynamic tile is allowed.
+  Value dynDimSize;
   for (auto tile : packOp.getMixedTiles()) {
     if (tile.is<Value>()) {
-      assert(!scalableSize && "Only one scalable size is supported ATM.");
-      scalableSize = cast<Value>(tile);
-      assert(vector::getConstantVscaleMultiplier(scalableSize) &&
-             "This dynamic shape is not a multiple of vscale!");
+      assert(!dynDimSize && "Only one dynamic tile size is supported ATM.");
+      dynDimSize = cast<Value>(tile);
     }
   }
 
   Value empty =
       ShapedType::isDynamicShape(cast<ShapedType>(input.getType()).getShape())
           ? rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType,
-                                             scalableSize)
+                                             dynDimSize)
           : rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index ffa58b796f91ae..77bb56a6969014 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -26,37 +26,30 @@ PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
                                     Value pad, bool nofold, Location loc,
                                     OpBuilder &b,
                                     std::optional<Value> dynOutDim) {
-  assert(llvm::count_if(
-             type.getShape(),
-             [](int64_t dim) { return ShapedType::isDynamic(dim); }) <= 1 &&
+
+  assert(type.getNumDynamicDims() <= 1 &&
          "At most one output dim can be dynamic!");
 
   // Init "low" and "high" padding values ("low" is kept as is, "high" is
   // computed below).
   SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
   SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
-  for (const auto &en : enumerate(type.getShape())) {
-    if (!ShapedType::isDynamic(en.value())) {
-      // Static sizes - the "high" value is computed based on the input and
-      // output dims. Compute the padding width.
-      AffineExpr d0;
-      bindDims(b.getContext(), d0);
-      OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
-      high[en.index()] =
-          affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
-    } else {
-      // Dynamic sizes - the "high" value is computed based on the input dim
-      // and `dynOutDim`.
-      assert(dynOutDim.has_value() &&
-             "dynamic output dim requires dynOutDim to be set");
-
-      // Compute the padding width.
-      AffineExpr d0, d1;
-      auto dimVal = b.create<tensor::DimOp>(loc, source, en.index());
-      bindDims(b.getContext(), d0, d1);
-      high[en.index()] = affine::makeComposedFoldedAffineApply(
-          b, loc, d0 - d1, {dynOutDim.value(), dimVal.getResult()});
-    }
+
+  for (const auto [idx, val] : enumerate(type.getShape())) {
+    bool isOutDimDynamic = ShapedType::isDynamic(val);
+    assert((!isOutDimDynamic || dynOutDim.has_value()) &&
+           "dynamic output dim requires dynOutDim to be set");
+
+    // Compute the padding width: outDim - srcDim.
+    AffineExpr d0, d1;
+    bindDims(b.getContext(), d0, d1);
+    OpFoldResult srcDim = tensor::getMixedSize(b, loc, source, idx);
+    Value outDim = isOutDimDynamic
+                       ? dynOutDim.value()
+                       : b.create<arith::ConstantIndexOp>(loc, val).getResult();
+
+    high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
+                                                      {outDim, srcDim});
   }
   return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
 }
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 66a220005ebf36..d663fc02eb6a47 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -23,7 +23,7 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
   %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
   return %0 : tensor<1x1x8x2xf32>
 }
-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 - s1)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 - 5)>
 
 // CHECK-LABEL: func.func @simple_pad_and_pack
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
@@ -36,7 +36,29 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
 // CHECK-SAME:      [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
 
-/// Same as example above, but with scalable sizes.
+/// Same as example above, but with a dynamic tile size.
+
+func.func @simple_pad_and_pack_dynamic(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %high: index) -> tensor<1x1x?x2xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+  return %0 : tensor<1x1x?x2xf32>
+}
+
+// CHECK-LABEL:   func.func @simple_pad_and_pack_dynamic(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[HIGH_VAL:.*]]: index) -> tensor<1x1x?x2xf32> {
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL]]]
+// CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK:             tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT:       linalg.transpose
+// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[VAL_10:.*]][0, 0] {{\[}}%[[HIGH_VAL]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK:           %[[DIM:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
+
+/// Same as example above, but with a scalable tile size.
 
 /// NOTE: For this example to make sense in practice, the "?" in the output shape
 ///       should effectively be 8 * vector.vscale (and that's what tensor.dim
@@ -50,17 +72,15 @@ func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor
   return %0 : tensor<1x1x?x2xf32>
 }
 
-
 // CHECK-LABEL:   func.func @simple_pad_and_pack_scalable(
 // CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
 // CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
 // CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
 // CHECK:           %[[C2:.+]] = arith.constant 2 : index
-// CHECK:           %[[C5:.+]] = arith.constant 5 : index
 // CHECK:           %[[C8:.+]] = arith.constant 8 : index
 // CHECK:           %[[VS:.+]] = vector.vscale
 // CHECK:           %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
-// CHECK:           %[[PAD_HIGH:.+]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]], %[[C5]]]
+// CHECK:           %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]]]
 // CHECK:           %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
 // CHECK:             tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NOT:       linalg.transpose

>From 894974f16e6c4bd3f7a84021193c6eaef314172e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 26 Sep 2024 20:04:18 +0100
Subject: [PATCH 3/5] fixup! [mlir][tensor] Extend the logic to generalise
 tensor.pack

* Use `OpFoldResult` for the output shape of tensor::EmptyOp (to
  simplify the invocation)
* Rename `readShape` to `readShapeForExtractSlice`
* Rename `transpShape` to `transShapeForEmpty`
---
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 25 ++++++++++---------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 47c37a347506d7..e482231109b179 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1177,12 +1177,15 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> readSizes;
-  SmallVector<int64_t> readShape;
+  SmallVector<OpFoldResult> transShapeForEmpty;
+  SmallVector<int64_t> readShapeForExtractSlice;
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
-                              .value_or(ShapedType::kDynamic));
+      readShapeForExtractSlice.push_back(
+          getConstantIntValue(dimAndTileMapping[i])
+              .value_or(ShapedType::kDynamic));
       readSizes.push_back(dimAndTileMapping[i]);
+      transShapeForEmpty.push_back(dimAndTileMapping[i]);
       continue;
     }
     if (ShapedType::isDynamic(inputShape[i])) {
@@ -1191,12 +1194,14 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     } else {
       readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
     }
-    if (inputShape[i] != 1)
-      readShape.push_back(inputShape[i]);
+    if (inputShape[i] != 1) {
+      readShapeForExtractSlice.push_back(inputShape[i]);
+      transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+    }
   }
 
   Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShape, elemType);
+  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
 
   Value tile = rewriter.create<tensor::ExtractSliceOp>(
       loc, readType, input, readOffsets, readSizes, readStrides);
@@ -1208,8 +1213,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
 
-  SmallVector<int64_t> transpShape = readShape;
-  applyPermutationToVector<int64_t>(transpShape, perm);
+  applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
 
   // If there's a tile with a dynamic size, retrieve its size. ATM only 1
   // dynamic tile is allowed.
@@ -1222,10 +1226,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   }
 
   Value empty =
-      ShapedType::isDynamicShape(cast<ShapedType>(input.getType()).getShape())
-          ? rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType,
-                                             dynDimSize)
-          : rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
 

>From b3da00e4d95e12cba482204b508a4e6b0c79689f Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 30 Sep 2024 08:36:12 +0100
Subject: [PATCH 4/5] fixup! fixup! [mlir][tensor] Extend the logic to
 generalise tensor.pack

Allow multiple dynamic dims and remove the dependency on #109667
---
 .../include/mlir/Dialect/Tensor/Utils/Utils.h | 21 +++++++---
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  9 ++---
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       | 39 +++++++++++--------
 3 files changed, 42 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 4c3d725fbf56e2..ed1ec1e871482d 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -14,13 +14,22 @@
 namespace mlir {
 namespace tensor {
 
-// Return a PadOp that pads `source` to `type` size. Output sizes (from `type`)
-// are assumed to be static and greater than the potentially dynamic input sizes
-// (from `source`). The op performs "high" padding (i.e. it adds trailing
-// padding values until the desired size is met).
-PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
+// Return a PadOp that pads `source` to `resType` size. The op performs "high"
+// padding, i.e. it adds trailing padding values until the desired size is met.
+// Output sizes are assumed to be greater than the input sizes. The padding
+// width is calculated as: resDim - sourceDim.
+//
+// Handling static sizes is trivial. Dynamic dimensions are trickier (*):
+//  1. dynamic input sizes are extracted from `source`
+//  2. for dynamic output dims, there are two options:
+//    2.1 all output dynamic dim sizes are specified in `dynOutDims`,
+//    2.2 `dynOutDims` is empty and the corresponding padding width is set to 0.
+//
+// (*) Note that `resType` is just a shape and it only encodes the actual sizes
+// for _static_ dimensions.
+PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad,
                       bool nofold, Location loc, OpBuilder &builder,
-                      std::optional<Value> dynOutDim = {});
+                      SmallVector<Value> dynOutDims = {});
 
 // Creates dim ops for each dynamic dimension of the ranked tensor argument and
 // returns these as values.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e482231109b179..a260172ad8a13a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1048,8 +1048,8 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
       packOp.getDimAndTileMapping();
 
-  // The size of a dynamic tile (if present).
-  Value dynamicTileSize;
+  // The sizes of dynamic tiles (if present).
+  SmallVector<Value> dynamicTileSizes;
 
   // Collect dims for the padded shape.
   SmallVector<int64_t> paddedShape;
@@ -1080,14 +1080,13 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
     paddedShape.push_back(ShapedType::kDynamic);
 
     // Get the value that holds the dynamic size.
-    assert(!dynamicTileSize && "Only one dynamic tile is supported ATM.");
-    dynamicTileSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
+    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
   }
   auto resultType =
       RankedTensorType::get(paddedShape, inputType.getElementType());
   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                  /*nofold=*/false, loc, builder,
-                                 dynamicTileSize);
+                                 dynamicTileSizes);
 }
 
 // Normalizes a permutation on a higher rank space to its actual size, e.g.
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 77bb56a6969014..555048f7e10611 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -22,36 +22,43 @@
 using namespace mlir;
 using namespace mlir::tensor;
 
-PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
+PadOp mlir::tensor::createPadHighOp(RankedTensorType resType, Value source,
                                     Value pad, bool nofold, Location loc,
                                     OpBuilder &b,
-                                    std::optional<Value> dynOutDim) {
+                                    SmallVector<Value> dynOutDims) {
 
-  assert(type.getNumDynamicDims() <= 1 &&
-         "At most one output dim can be dynamic!");
+  assert((resType.getNumDynamicDims() == dynOutDims.size() ||
+          dynOutDims.empty()) &&
+         "Either none or all output dynamic dims must be specified!");
 
   // Init "low" and "high" padding values ("low" is kept as is, "high" is
   // computed below).
-  SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
-  SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
+  SmallVector<OpFoldResult> low(resType.getRank(), b.getIndexAttr(0));
+  SmallVector<OpFoldResult> high(resType.getRank(), b.getIndexAttr(0));
 
-  for (const auto [idx, val] : enumerate(type.getShape())) {
-    bool isOutDimDynamic = ShapedType::isDynamic(val);
-    assert((!isOutDimDynamic || dynOutDim.has_value()) &&
-           "dynamic output dim requires dynOutDim to be set");
+  size_t outDimIdx = 0;
 
-    // Compute the padding width: outDim - srcDim.
+  for (const auto [idx, val] : enumerate(resType.getShape())) {
+    bool isDimDynamic = ShapedType::isDynamic(val);
+    bool updatePadHigh = !isDimDynamic || !dynOutDims.empty();
+
+    // Keep the default padding width (i.e. "0") when the output dim is dynamic
+    // and no actual output sizes have been provided.
+    if (!updatePadHigh)
+      continue;
+
+    // Compute the padding width: resDim - sourceDim.
     AffineExpr d0, d1;
     bindDims(b.getContext(), d0, d1);
-    OpFoldResult srcDim = tensor::getMixedSize(b, loc, source, idx);
-    Value outDim = isOutDimDynamic
-                       ? dynOutDim.value()
+    OpFoldResult sourceDim = tensor::getMixedSize(b, loc, source, idx);
+    Value outDim = isDimDynamic
+                       ? dynOutDims[outDimIdx++]
                        : b.create<arith::ConstantIndexOp>(loc, val).getResult();
 
     high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
-                                                      {outDim, srcDim});
+                                                      {outDim, sourceDim});
   }
-  return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
+  return b.create<PadOp>(loc, resType, source, low, high, pad, nofold);
 }
 
 SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,

>From 3e9521c5cb4e8cd5016b321d20150560b3b4e189 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 1 Oct 2024 16:49:26 +0100
Subject: [PATCH 5/5] fixup! fixup! fixup! [mlir][tensor] Extend the logic to
 generalise tensor.pack

Incorporating suggestions from @hanhanW
---
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp    | 12 +-----------
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp              |  5 ++---
 mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir |  6 +++---
 3 files changed, 6 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index a260172ad8a13a..ab8971a9c5a966 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1054,10 +1054,10 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   // Collect dims for the padded shape.
   SmallVector<int64_t> paddedShape;
   for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
-    int64_t inputDimSize = inputType.getDimSize(dimIdx);
     // 1. Non-tiled outer dims.
     // These dims should be 1 and we simply preserve them.
     if (!tileAndPosMapping.count(dimIdx)) {
+      int64_t inputDimSize = inputType.getDimSize(dimIdx);
       assert(inputDimSize == 1 &&
              "with all outer dims == 1, this non-tiled input dim should be 1!");
       paddedShape.push_back(inputDimSize);
@@ -1214,16 +1214,6 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
 
   applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
 
-  // If there's a tile with a dynamic size, retrieve its size. ATM only 1
-  // dynamic tile is allowed.
-  Value dynDimSize;
-  for (auto tile : packOp.getMixedTiles()) {
-    if (tile.is<Value>()) {
-      assert(!dynDimSize && "Only one dynamic tile size is supported ATM.");
-      dynDimSize = cast<Value>(tile);
-    }
-  }
-
   Value empty =
       rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
   auto transposedOp =
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 555048f7e10611..1cb040b6dca414 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -51,9 +51,8 @@ PadOp mlir::tensor::createPadHighOp(RankedTensorType resType, Value source,
     AffineExpr d0, d1;
     bindDims(b.getContext(), d0, d1);
     OpFoldResult sourceDim = tensor::getMixedSize(b, loc, source, idx);
-    Value outDim = isDimDynamic
-                       ? dynOutDims[outDimIdx++]
-                       : b.create<arith::ConstantIndexOp>(loc, val).getResult();
+    OpFoldResult outDim = isDimDynamic ? OpFoldResult(dynOutDims[outDimIdx++])
+                                       : OpFoldResult(b.getIndexAttr(val));
 
     high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
                                                       {outDim, sourceDim});
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index d663fc02eb6a47..bb23a869a9cc5b 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -76,9 +76,9 @@ func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor
 // CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
 // CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
 // CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
-// CHECK:           %[[C2:.+]] = arith.constant 2 : index
-// CHECK:           %[[C8:.+]] = arith.constant 8 : index
-// CHECK:           %[[VS:.+]] = vector.vscale
+// CHECK-DAG:       %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:       %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VS:.+]] = vector.vscale
 // CHECK:           %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
 // CHECK:           %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]]]
 // CHECK:           %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {


