[Mlir-commits] [mlir] [mlir][tensor] Extend the logic to generalise tensor.pack (PR #109815)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Sep 26 06:43:24 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/109815

From 7a5a87592426f094c5b5af32310e747a011dc72f Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 23 Sep 2024 14:32:47 +0100
Subject: [PATCH 1/3] [mlir][tensor] Refine the semantics of `createPadHighOp`

Refine `createPadHighOp` so that the output tensor is required to be
statically shaped. This is to prevent the current behaviour, which is
incorrect:

>  // If `type` has dynamic dimensions the padding width is set to zero.

The actual padding width should be set to `%new_dim - %old_dim`, where
`%new_dim` and `%old_dim` are defined via, e.g., the `tensor.dim` Op applied
to the output and input tensors, respectively.
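The intended arithmetic can be modelled with plain integers. The sketch below is illustrative only and does not use the MLIR API. `computeHighPadding` and `kDynamic` are hypothetical names, with `kDynamic` standing in for `ShapedType::kDynamic`. The runtime size of a dynamic output dim is passed in explicitly, which is the role the `dynOutDim` parameter takes on later in this series.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Stand-in for ShapedType::kDynamic (an assumption of this sketch).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical helper: high[i] = outDim[i] - srcDim[i] for every dim.
// `srcSizes` holds the runtime input sizes (i.e. what tensor.dim returns).
// A dynamic output dim takes its runtime size from `dynOutDim`.
std::vector<int64_t> computeHighPadding(const std::vector<int64_t> &outShape,
                                        const std::vector<int64_t> &srcSizes,
                                        std::optional<int64_t> dynOutDim) {
  assert(outShape.size() == srcSizes.size() && "rank mismatch");
  std::vector<int64_t> high;
  high.reserve(outShape.size());
  for (size_t i = 0; i < outShape.size(); ++i) {
    int64_t outDim = outShape[i];
    if (outDim == kDynamic) {
      assert(dynOutDim.has_value() && "dynamic output dim needs a runtime size");
      outDim = *dynOutDim;
    }
    // Unlike the old behaviour, dynamic dims are not skipped (width 0):
    // the width is always %new_dim - %old_dim.
    high.push_back(outDim - srcSizes[i]);
  }
  return high;
}
```

For example, padding a 5x1 input to 8x2 yields `high = [3, 1]`. With a dynamic first output dim whose runtime size is 16, the result is `high = [11, 1]` rather than `[0, 1]`.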

This PR is an attempt to clarify the semantics surrounding dynamic
shapes in preparation for adding support for scalable vectors to the
pack/unpack logic in Tensor/Linalg (dynamic shapes are what we use to
model scalable (*) sizes at the Tensor/MemRef level).

(*) Scalable as in Arm's Scalable Vector Extension (SVE)
---
 mlir/include/mlir/Dialect/Tensor/Utils/Utils.h |  8 ++++----
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp        | 10 +++++++---
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 84d06d456bb689..e63749eb384316 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -14,10 +14,10 @@
 namespace mlir {
 namespace tensor {
 
-// Return a PadOp that pads `source` to `type` size where the static
-// sizes are assumed to be greater than the dynamic sizes. If `type` has dynamic
-// dimensions the padding width is set to zero. The op performs "high" padding
-// (i.e. it adds trailing padding values until the desired size is met).
+// Return a PadOp that pads `source` to `type` size. Output sizes (from `type`)
+// are assumed to be static and greater than the potentially dynamic input sizes
+// (from `source). The op performs "high" padding (i.e. it adds trailing padding
+// values until the desired size is met).
 PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
                       bool nofold, Location loc, OpBuilder &builder);
 
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index a0d8a08fc6ba47..c8e0c05bfb2b87 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -24,12 +24,16 @@ using namespace mlir::tensor;
 PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
                                     Value pad, bool nofold, Location loc,
                                     OpBuilder &b) {
+
+  assert(!ShapedType::isDynamicShape(type.getShape()) &&
+         "The output type is dynamic - that's not supported ATM.");
+
+  // Init "low" and "high" padding values ("low" is kept as is, "high" is
+  // computed below).
   SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
   SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
+
   for (const auto &en : enumerate(type.getShape())) {
-    // Pad only the static dimensions of the result tensor type.
-    if (ShapedType::isDynamic(en.value()))
-      continue;
     // Compute the padding width.
     AffineExpr d0;
     bindDims(b.getContext(), d0);

From 58c8ec0b2a74547176090a94ad22ff2d7cbc66a7 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 19 Sep 2024 17:44:51 +0000
Subject: [PATCH 2/3] [mlir][tensor] Extend the logic to generalise tensor.pack

Extends the logic that generalises tensor.pack (into, e.g., tensor.pad +
linalg.transpose) so that it also works when one of the inner tile sizes
is scalable (i.e. a multiple of `vector.vscale`). For example:
```mlir
  %c8 = arith.constant 8 : index
  %vscale = vector.vscale
  %c8_vscale = arith.muli %vscale, %c8 : index
  %0 = tensor.pack %input
      padding_value(%pad : f32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8_vscale, 2]
      into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
```
is generalised as:
```mlir
#map = affine_map<()[s0, s1] -> (s0 - s1)>

  %c5 = arith.constant 5 : index
  %c8 = arith.constant 8 : index
  %vscale = vector.vscale
  %c8_vscale = arith.muli %vscale, %c8 : index
  %0 = affine.apply #map()[%c8_vscale, %c5]
  %padded = tensor.pad %input low[0, 0] high[%0, 1] {
  ^bb0(%arg3: index, %arg4: index):
    tensor.yield %pad : f32
  } : tensor<5x1xf32> to tensor<?x2xf32>
```
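The shape and padding derivation behind this example can be sketched in plain C++. The sketch models the loop in `getPackOpSourceOrPaddedSource` under the same assumption that all outer dims are 1. `Tile`, `PaddedSource`, `padSourceForPack` and `kDynamic` are hypothetical names, and the runtime value of `8 * vscale` is simply supplied as an integer.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Stand-in for ShapedType::kDynamic (assumption).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical model of an inner tile: a constant, or a runtime value such
// as 8 * vscale that is unknown at compile time.
struct Tile {
  bool isStatic;
  int64_t runtimeValue; // equals the constant when isStatic is true
};

struct PaddedSource {
  std::vector<int64_t> type; // static shape of the padded tensor
  std::vector<int64_t> high; // runtime "high" padding per dim
};

// Assumes all outer dims are 1, so each tiled dim is padded up to its tile
// size. `tiles[i]` is std::nullopt when input dim i is not tiled.
PaddedSource padSourceForPack(const std::vector<int64_t> &srcShape,
                              const std::vector<std::optional<Tile>> &tiles) {
  assert(srcShape.size() == tiles.size() && "rank mismatch");
  PaddedSource r;
  for (size_t i = 0; i < srcShape.size(); ++i) {
    if (!tiles[i]) {
      // Non-tiled dims are expected to be 1 and are preserved.
      assert(srcShape[i] == 1 && "non-tiled dims must be 1");
      r.type.push_back(srcShape[i]);
      r.high.push_back(0);
      continue;
    }
    const Tile &t = *tiles[i];
    // Static tile -> static padded dim; dynamic tile -> "?" in the type.
    r.type.push_back(t.isStatic ? t.runtimeValue : kDynamic);
    r.high.push_back(t.runtimeValue - srcShape[i]);
  }
  return r;
}
```

Assuming `vscale == 2`, the example above pads `tensor<5x1xf32>` to `tensor<?x2xf32>` with `high = [8 * 2 - 5, 2 - 1] = [11, 1]`.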

At the Tensor level, we model scalability using dynamic shapes, so this
change extends the relevant logic to also work for dynamic shapes.
However, rather than allowing an arbitrary number of tile sizes (with
arbitrary values) to be dynamic, only _one_ tile size is allowed to be
dynamic. In addition, that tile size is required to be a constant
multiple of `vector.vscale`.
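The two restrictions above correspond to the two `notifyMatchFailure` checks added to `GeneralizeOuterUnitDimsPackOpPattern` later in this patch. The sketch below models them with a hypothetical `TileSize` enum in place of inspecting `OpFoldResult`s and calling `vector::getConstantVscaleMultiplier`. It reuses the failure messages from the diff.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of an inner tile size.
struct TileSize {
  enum Kind { Static, VscaleMultiple, OtherDynamic } kind;
  // Static: the constant size; VscaleMultiple: `c` in `c * vscale`;
  // unused for OtherDynamic.
  int64_t value;
};

// Same order as the pattern: first reject any dynamic tile that is not a
// constant multiple of vscale, then reject more than one dynamic tile.
// Returns the failure reason, or std::nullopt if the op would be accepted.
std::optional<std::string> whyRejected(const std::vector<TileSize> &tiles) {
  int numDynamic = 0;
  for (const TileSize &t : tiles) {
    if (t.kind == TileSize::OtherDynamic)
      return std::string("require inner tile sizes to be either static or a "
                         "constant multiple of vector.vscale");
    if (t.kind == TileSize::VscaleMultiple)
      ++numDynamic;
  }
  if (numDynamic > 1)
    return std::string("at most one dynamic tile size is supported");
  return std::nullopt;
}
```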

While the requirements above could be relaxed, I wanted to avoid full
generality for now, primarily to avoid complexity that's not yet needed
and to make reviewing a bit easier.
---
 .../include/mlir/Dialect/Tensor/Utils/Utils.h |  3 +-
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 92 +++++++++++++++----
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       | 38 ++++++--
 .../Linalg/generalize-tensor-pack.mlir        | 35 +++++++
 4 files changed, 139 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index e63749eb384316..db5a15c9ec3550 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -19,7 +19,8 @@ namespace tensor {
 // (from `source). The op performs "high" padding (i.e. it adds trailing padding
 // values until the desired size is met).
 PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
-                      bool nofold, Location loc, OpBuilder &builder);
+                      bool nofold, Location loc, OpBuilder &builder,
+                      std::optional<Value> dynOutDim = {});
 
 // Creates dim ops for each dynamic dimension of the ranked tensor argument and
 // returns these as values.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e0dea8e78d55c1..42389b431566eb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1021,8 +1021,16 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
   return success();
 }
 
-/// Returns a tensor.pad op if padding value is set. Otherwise, returns the
-/// source directly. The method assumes that the `packOp` has static shapes.
+/// If padding value is set, returns a tensor.pad Op for the source tensor,
+/// with the output shape matching the output of `packOp`. Otherwise, returns
+/// the source directly.
+///
+/// This method assumes that all outer dims for this pack Op are 1.
+///
+/// At most _one_ inner tile size can be _dynamic_, all other inner tiles are
+/// required to have static sizes. The inner tile that's dynamic must be a
+/// multiple of vector.vscale (to support scalable tile sizes). This condition
+/// can be relaxed in the future.
 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                            tensor::PackOp packOp) {
   Value input = packOp.getSource();
@@ -1038,26 +1046,50 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   ShapedType inputType = packOp.getSourceType();
   int64_t inputRank = inputType.getRank();
 
-  SmallVector<int64_t> paddedShape;
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
       packOp.getDimAndTileMapping();
-  for (int64_t dim = 0; dim < inputRank; ++dim) {
-    int64_t size = inputType.getDimSize(dim);
-    if (!tileAndPosMapping.count(dim)) {
-      paddedShape.push_back(size);
+
+  // The size of a scalable tile (if present).
+  Value scalableSize;
+
+  // Collect dims for the padded shape.
+  SmallVector<int64_t> paddedShape;
+  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
+    int64_t inputDimSize = inputType.getDimSize(dimIdx);
+    // 1. Non-tiled outer dims.
+    // These dims should be 1 and we simply preserve them.
+    if (!tileAndPosMapping.count(dimIdx)) {
+      assert(inputDimSize == 1 &&
+             "with all outer dims == 1, this non-tiled input dim should be 1!");
+      paddedShape.push_back(inputDimSize);
+      continue;
+    }
+
+    // 2. Tiled outer dims
+    // As all outer dims == 1, it is safe to use the tile size for the padded
+    // shape.
+    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
+
+    // 2.1 Static tile sizes
+    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
+    if (cstTileSize.has_value()) {
+      paddedShape.push_back(cstTileSize.value());
       continue;
     }
 
-    // The size is less than or equal to tileSize because outer dims are all 1s.
-    std::optional<int64_t> tileSize =
-        getConstantIntValue(tileAndPosMapping.lookup(dim));
-    assert(tileSize.has_value() && "dynamic inner tile size is not supported");
-    paddedShape.push_back(tileSize.value());
+    // 2.2 Dynamic tile sizes
+    paddedShape.push_back(ShapedType::kDynamic);
+
+    // Get the value that holds the scalable size.
+    assert(!scalableSize && "Only one scalable size is supported ATM.");
+    scalableSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
+    assert(vector::getConstantVscaleMultiplier(scalableSize) &&
+           "This dynamic shape is not a multiple of vscale, this !");
   }
   auto resultType =
       RankedTensorType::get(paddedShape, inputType.getElementType());
   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
-                                 /*nofold=*/false, loc, builder);
+                                 /*nofold=*/false, loc, builder, scalableSize);
 }
 
 // Normalizes a permutation on a higher rank space to its actual size, e.g.
@@ -1120,10 +1152,18 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
-  if (llvm::any_of(packOp.getMixedTiles(),
-                   [](OpFoldResult tile) { return tile.is<Value>(); })) {
-    return rewriter.notifyMatchFailure(packOp,
-                                       "require inner tile sizes being static");
+  if (llvm::any_of(packOp.getMixedTiles(), [](OpFoldResult tile) {
+        return tile.is<Value>() && !vector::getConstantVscaleMultiplier(
+                                       llvm::dyn_cast<Value>(tile));
+      })) {
+    return rewriter.notifyMatchFailure(
+        packOp, "require inner tile sizes to be either static or a constant "
+                "multiple of vector.vscale");
+  }
+  if (llvm::count_if(packOp.getMixedTiles(),
+                     [](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
+    return rewriter.notifyMatchFailure(
+        packOp, "at most one dynamic tile size is supported");
   }
 
   // TODO: support the case that outer dimensions are not all 1s. A
@@ -1181,7 +1221,23 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<int64_t> transpShape = readShape;
   applyPermutationToVector<int64_t>(transpShape, perm);
 
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  // If there's a tile with a scalable size, retrieve its size. ATM only 1
+  // scalable tile is allowed.
+  Value scalableSize;
+  for (auto tile : packOp.getMixedTiles()) {
+    if (tile.is<Value>()) {
+      assert(!scalableSize && "Only one scalable size is supported ATM.");
+      scalableSize = cast<Value>(tile);
+      assert(vector::getConstantVscaleMultiplier(scalableSize) &&
+             "This dynamic shape is not a multiple of vscale!");
+    }
+  }
+
+  Value empty =
+      ShapedType::isDynamicShape(cast<ShapedType>(input.getType()).getShape())
+          ? rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType,
+                                             scalableSize)
+          : rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
 
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index c8e0c05bfb2b87..7b25d9747827e3 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR//VectorOps.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 using namespace mlir;
@@ -23,10 +24,12 @@ using namespace mlir::tensor;
 
 PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
                                     Value pad, bool nofold, Location loc,
-                                    OpBuilder &b) {
-
-  assert(!ShapedType::isDynamicShape(type.getShape()) &&
-         "The output type is dynamic - that's not supported ATM.");
+                                    OpBuilder &b,
+                                    std::optional<Value> dynOutDim) {
+  assert(llvm::count_if(
+             type.getShape(),
+             [](int64_t dim) { return ShapedType::isDynamic(dim); }) <= 1 &&
+         "At most one output dim can be dynamic!");
 
   // Init "low" and "high" padding values ("low" is kept as is, "high" is
   // computed below).
@@ -34,12 +37,27 @@ PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
   SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
 
   for (const auto &en : enumerate(type.getShape())) {
-    // Compute the padding width.
-    AffineExpr d0;
-    bindDims(b.getContext(), d0);
-    OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
-    high[en.index()] =
-        affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
+    if (!ShapedType::isDynamic(en.value())) {
+      // Static sizes - the "high" value is computed based on the input and
+      // output dims. Compute the padding width.
+      AffineExpr d0;
+      bindDims(b.getContext(), d0);
+      OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
+      high[en.index()] =
+          affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
+    } else {
+      // Dynamic sizes - the "high" value is computed based on the input dim
+      // and `dynOutDim`.
+      assert(dynOutDim.has_value() &&
+             "dynamic output dim requires dynOutDim to be set");
+
+      // Compute the padding width.
+      AffineExpr d0, d1;
+      auto dimVal = b.create<tensor::DimOp>(loc, source, en.index());
+      bindDims(b.getContext(), d0, d1);
+      high[en.index()] = affine::makeComposedFoldedAffineApply(
+          b, loc, d0 - d1, {dynOutDim.value(), dimVal.getResult()});
+    }
   }
   return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
 }
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 7d87a0994004fe..66a220005ebf36 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -23,6 +23,8 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
   %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
   return %0 : tensor<1x1x8x2xf32>
 }
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 - s1)>
+
 // CHECK-LABEL: func.func @simple_pad_and_pack
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
@@ -34,6 +36,39 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
 // CHECK-SAME:      [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
 
+/// Same as example above, but with scalable sizes.
+
+/// NOTE: For this example to make sense in practice, the "?" in the output shape
+///       should effectively be 8 * vector.vscale (and that's what tensor.dim
+///       below should return).
+
+func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
+  %c8 = arith.constant 8 : index
+  %vscale = vector.vscale
+  %c8_vscale = arith.muli %vscale, %c8 : index
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+  return %0 : tensor<1x1x?x2xf32>
+}
+
+
+// CHECK-LABEL:   func.func @simple_pad_and_pack_scalable(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
+// CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
+// CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
+// CHECK:           %[[C2:.+]] = arith.constant 2 : index
+// CHECK:           %[[C5:.+]] = arith.constant 5 : index
+// CHECK:           %[[C8:.+]] = arith.constant 8 : index
+// CHECK:           %[[VS:.+]] = vector.vscale
+// CHECK:           %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
+// CHECK:           %[[PAD_HIGH:.+]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]], %[[C5]]]
+// CHECK:           %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK:             tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT:       linalg.transpose
+// CHECK:           %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK:           %[[DIM:.+]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
+// CHECK:           %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
+
 // -----
 
 func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{

From 42360fdbea6863d342ddf8fb2c46b655e4e757c5 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 26 Sep 2024 14:13:30 +0100
Subject: [PATCH 3/3] fixup! [mlir][tensor] Extend the logic to generalise
 tensor.pack

Addresses PR comments from Han-Chung: some clean-up, plus relaxing the
requirement that the dynamic dim has to be a constant multiple of
vector.vscale.
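With the vscale requirement dropped, the only remaining restriction on inner tiles is the count of dynamic ones. Any runtime SSA value (e.g. a function argument such as `%high`) is now accepted as the single dynamic tile. The sketch below is a minimal model of that predicate. `Tile` and `tilesSupported` are hypothetical names; the real pattern counts `Value`s in `packOp.getMixedTiles()`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model: a tile is either static or an arbitrary runtime value.
struct Tile {
  bool isDynamic;
  int64_t value; // constant size for static tiles, unused otherwise
};

// Remaining restriction after this fixup: at most one dynamic inner tile.
bool tilesSupported(const std::vector<Tile> &tiles) {
  auto numDynamic = std::count_if(tiles.begin(), tiles.end(),
                                  [](const Tile &t) { return t.isDynamic; });
  return numDynamic <= 1;
}
```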
---
 .../include/mlir/Dialect/Tensor/Utils/Utils.h |  4 +-
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 42 +++++++------------
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       | 42 ++++++++-----------
 .../Linalg/generalize-tensor-pack.mlir        | 30 ++++++++++---
 4 files changed, 59 insertions(+), 59 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index db5a15c9ec3550..4c3d725fbf56e2 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -16,8 +16,8 @@ namespace tensor {
 
 // Return a PadOp that pads `source` to `type` size. Output sizes (from `type`)
 // are assumed to be static and greater than the potentially dynamic input sizes
-// (from `source). The op performs "high" padding (i.e. it adds trailing padding
-// values until the desired size is met).
+// (from `source`). The op performs "high" padding (i.e. it adds trailing
+// padding values until the desired size is met).
 PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
                       bool nofold, Location loc, OpBuilder &builder,
                       std::optional<Value> dynOutDim = {});
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 42389b431566eb..47c37a347506d7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1028,9 +1028,8 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
 /// This method assumes that all outer dims for this pack Op are 1.
 ///
 /// At most _one_ inner tile size can be _dynamic_, all other inner tiles are
-/// required to have static sizes. The inner tile that's dynamic must be a
-/// multiple of vector.vscale (to support scalable tile sizes). This condition
-/// can be relaxed in the future.
+/// required to have static sizes. This restriction can be relaxed in the
+/// future.
 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                            tensor::PackOp packOp) {
   Value input = packOp.getSource();
@@ -1049,8 +1048,8 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
       packOp.getDimAndTileMapping();
 
-  // The size of a scalable tile (if present).
-  Value scalableSize;
+  // The size of a dynamic tile (if present).
+  Value dynamicTileSize;
 
   // Collect dims for the padded shape.
   SmallVector<int64_t> paddedShape;
@@ -1080,16 +1079,15 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
     // 2.2 Dynamic tile sizes
     paddedShape.push_back(ShapedType::kDynamic);
 
-    // Get the value that holds the scalable size.
-    assert(!scalableSize && "Only one scalable size is supported ATM.");
-    scalableSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
-    assert(vector::getConstantVscaleMultiplier(scalableSize) &&
-           "This dynamic shape is not a multiple of vscale, this !");
+    // Get the value that holds the dynamic size.
+    assert(!dynamicTileSize && "Only one dynamic tile is supported ATM.");
+    dynamicTileSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
   }
   auto resultType =
       RankedTensorType::get(paddedShape, inputType.getElementType());
   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
-                                 /*nofold=*/false, loc, builder, scalableSize);
+                                 /*nofold=*/false, loc, builder,
+                                 dynamicTileSize);
 }
 
 // Normalizes a permutation on a higher rank space to its actual size, e.g.
@@ -1152,14 +1150,6 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
-  if (llvm::any_of(packOp.getMixedTiles(), [](OpFoldResult tile) {
-        return tile.is<Value>() && !vector::getConstantVscaleMultiplier(
-                                       llvm::dyn_cast<Value>(tile));
-      })) {
-    return rewriter.notifyMatchFailure(
-        packOp, "require inner tile sizes to be either static or a constant "
-                "multiple of vector.vscale");
-  }
   if (llvm::count_if(packOp.getMixedTiles(),
                      [](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
     return rewriter.notifyMatchFailure(
@@ -1221,22 +1211,20 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<int64_t> transpShape = readShape;
   applyPermutationToVector<int64_t>(transpShape, perm);
 
-  // If there's a tile with a scalable size, retrieve its size. ATM only 1
-  // scalable tile is allowed.
-  Value scalableSize;
+  // If there's a tile with a dynamic size, retrieve its size. ATM only 1
+  // dynamic tile is allowed.
+  Value dynDimSize;
   for (auto tile : packOp.getMixedTiles()) {
     if (tile.is<Value>()) {
-      assert(!scalableSize && "Only one scalable size is supported ATM.");
-      scalableSize = cast<Value>(tile);
-      assert(vector::getConstantVscaleMultiplier(scalableSize) &&
-             "This dynamic shape is not a multiple of vscale!");
+      assert(!dynDimSize && "Only one scalable size is supported ATM.");
+      dynDimSize = cast<Value>(tile);
     }
   }
 
   Value empty =
       ShapedType::isDynamicShape(cast<ShapedType>(input.getType()).getShape())
           ? rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType,
-                                             scalableSize)
+                                             dynDimSize)
           : rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 7b25d9747827e3..77bb56a6969014 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -26,9 +26,8 @@ PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
                                     Value pad, bool nofold, Location loc,
                                     OpBuilder &b,
                                     std::optional<Value> dynOutDim) {
-  assert(llvm::count_if(
-             type.getShape(),
-             [](int64_t dim) { return ShapedType::isDynamic(dim); }) <= 1 &&
+
+  assert(type.getNumDynamicDims() <= 1 &&
          "At most one output dim can be dynamic!");
 
   // Init "low" and "high" padding values ("low" is kept as is, "high" is
@@ -36,28 +35,21 @@ PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
   SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
   SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
 
-  for (const auto &en : enumerate(type.getShape())) {
-    if (!ShapedType::isDynamic(en.value())) {
-      // Static sizes - the "high" value is computed based on the input and
-      // output dims. Compute the padding width.
-      AffineExpr d0;
-      bindDims(b.getContext(), d0);
-      OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
-      high[en.index()] =
-          affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
-    } else {
-      // Dynamic sizes - the "high" value is computed based on the input dim
-      // and `dynOutDim`.
-      assert(dynOutDim.has_value() &&
-             "dynamic output dim requires dynOutDim to be set");
-
-      // Compute the padding width.
-      AffineExpr d0, d1;
-      auto dimVal = b.create<tensor::DimOp>(loc, source, en.index());
-      bindDims(b.getContext(), d0, d1);
-      high[en.index()] = affine::makeComposedFoldedAffineApply(
-          b, loc, d0 - d1, {dynOutDim.value(), dimVal.getResult()});
-    }
+  for (const auto [idx, val] : enumerate(type.getShape())) {
+    bool isOutDimDynamic = ShapedType::isDynamic(val);
+    assert((!isOutDimDynamic || dynOutDim.has_value()) &&
+           "dynamic output dim requires dynOutDim to be set");
+
+    // Compute the padding width: outDim - srcDim.
+    AffineExpr d0, d1;
+    bindDims(b.getContext(), d0, d1);
+    OpFoldResult srcDim = tensor::getMixedSize(b, loc, source, idx);
+    Value outDim = isOutDimDynamic
+                       ? dynOutDim.value()
+                       : b.create<arith::ConstantIndexOp>(loc, val).getResult();
+
+    high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
+                                                      {outDim, srcDim});
   }
   return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
 }
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 66a220005ebf36..d663fc02eb6a47 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -23,7 +23,7 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
   %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
   return %0 : tensor<1x1x8x2xf32>
 }
-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 - s1)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 - 5)>
 
 // CHECK-LABEL: func.func @simple_pad_and_pack
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
@@ -36,7 +36,29 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
 // CHECK-SAME:      [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
 
-/// Same as example above, but with scalable sizes.
+/// Same as example above, but with dynamic tile size.
+
+func.func @simple_pad_and_pack_dynamic(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %high: index) -> tensor<1x1x?x2xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+  return %0 : tensor<1x1x?x2xf32>
+}
+
+// CHECK-LABEL:   func.func @simple_pad_and_pack_dynamic(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[HIGH_VAL:.*]]: index) -> tensor<1x1x?x2xf32> {
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL]]]
+// CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK:             tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT:       linalg.transpose
+// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[VAL_10:.*]][0, 0] {{\[}}%[[HIGH_VAL]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK:           %[[DIM:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
+
+/// Same as example above, but with scalable tile size.
 
 /// NOTE: For this example to make sense in practice, the "?" in the output shape
 ///       should effectively be 8 * vector.vscale (and that's what tensor.dim
@@ -50,17 +72,15 @@ func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor
   return %0 : tensor<1x1x?x2xf32>
 }
 
-
 // CHECK-LABEL:   func.func @simple_pad_and_pack_scalable(
 // CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
 // CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
 // CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
 // CHECK:           %[[C2:.+]] = arith.constant 2 : index
-// CHECK:           %[[C5:.+]] = arith.constant 5 : index
 // CHECK:           %[[C8:.+]] = arith.constant 8 : index
 // CHECK:           %[[VS:.+]] = vector.vscale
 // CHECK:           %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
-// CHECK:           %[[PAD_HIGH:.+]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]], %[[C5]]]
+// CHECK:           %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]]]
 // CHECK:           %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
 // CHECK:             tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NOT:       linalg.transpose


