[Mlir-commits] [mlir] [mlir][tensor] Extend the logic to generalise tensor.pack (PR #109815)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 24 08:22:08 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

- **[mlir][tensor] Refine the semantics of `createPadHighOp`**
- **[mlir][tensor] Extend the logic to generalise tensor.pack**


---
Full diff: https://github.com/llvm/llvm-project/pull/109815.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tensor/Utils/Utils.h (+6-5) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+74-18) 
- (modified) mlir/lib/Dialect/Tensor/Utils/Utils.cpp (+32-10) 
- (modified) mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir (+35) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 84d06d456bb689..db5a15c9ec3550 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -14,12 +14,13 @@
 namespace mlir {
 namespace tensor {
 
-// Return a PadOp that pads `source` to `type` size where the static
-// sizes are assumed to be greater than the dynamic sizes. If `type` has dynamic
-// dimensions the padding width is set to zero. The op performs "high" padding
-// (i.e. it adds trailing padding values until the desired size is met).
+// Return a PadOp that pads `source` to `type` size. Output sizes (from `type`)
+// are assumed to be static and greater than the potentially dynamic input sizes
+// (from `source`). The op performs "high" padding (i.e. it adds trailing
+// padding values until the desired size is met).
 PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
-                      bool nofold, Location loc, OpBuilder &builder);
+                      bool nofold, Location loc, OpBuilder &builder,
+                      std::optional<Value> dynOutDim = {});
 
 // Creates dim ops for each dynamic dimension of the ranked tensor argument and
 // returns these as values.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e0dea8e78d55c1..42389b431566eb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1021,8 +1021,16 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
   return success();
 }
 
-/// Returns a tensor.pad op if padding value is set. Otherwise, returns the
-/// source directly. The method assumes that the `packOp` has static shapes.
+/// If padding value is set, returns a tensor.pad Op for the source tensor,
+/// with the output shape matching the output of `packOp`. Otherwise, returns
+/// the source directly.
+///
+/// This method assumes that all outer dims for this pack Op are 1.
+///
+/// At most _one_ inner tile size can be _dynamic_, all other inner tiles are
+/// required to have static sizes. The inner tile that's dynamic must be a
+/// multiple of vector.vscale (to support scalable tile sizes). This condition
+/// can be relaxed in the future.
 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                            tensor::PackOp packOp) {
   Value input = packOp.getSource();
@@ -1038,26 +1046,50 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   ShapedType inputType = packOp.getSourceType();
   int64_t inputRank = inputType.getRank();
 
-  SmallVector<int64_t> paddedShape;
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
       packOp.getDimAndTileMapping();
-  for (int64_t dim = 0; dim < inputRank; ++dim) {
-    int64_t size = inputType.getDimSize(dim);
-    if (!tileAndPosMapping.count(dim)) {
-      paddedShape.push_back(size);
+
+  // The size of a scalable tile (if present).
+  Value scalableSize;
+
+  // Collect dims for the padded shape.
+  SmallVector<int64_t> paddedShape;
+  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
+    int64_t inputDimSize = inputType.getDimSize(dimIdx);
+    // 1. Non-tiled outer dims.
+    // These dims should be 1 and we simply preserve them.
+    if (!tileAndPosMapping.count(dimIdx)) {
+      assert(inputDimSize == 1 &&
+             "with all outer dims == 1, this non-tiled input dim should be 1!");
+      paddedShape.push_back(inputDimSize);
+      continue;
+    }
+
+    // 2. Tiled outer dims
+    // As all outer dims == 1, it is safe to use the tile size for the padded
+    // shape.
+    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
+
+    // 2.1 Static tile sizes
+    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
+    if (cstTileSize.has_value()) {
+      paddedShape.push_back(cstTileSize.value());
       continue;
     }
 
-    // The size is less than or equal to tileSize because outer dims are all 1s.
-    std::optional<int64_t> tileSize =
-        getConstantIntValue(tileAndPosMapping.lookup(dim));
-    assert(tileSize.has_value() && "dynamic inner tile size is not supported");
-    paddedShape.push_back(tileSize.value());
+    // 2.2 Dynamic tile sizes
+    paddedShape.push_back(ShapedType::kDynamic);
+
+    // Get the value that holds the scalable size.
+    assert(!scalableSize && "Only one scalable size is supported ATM.");
+    scalableSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
+    assert(vector::getConstantVscaleMultiplier(scalableSize) &&
+           "This dynamic shape is not a multiple of vscale, this !");
   }
   auto resultType =
       RankedTensorType::get(paddedShape, inputType.getElementType());
   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
-                                 /*nofold=*/false, loc, builder);
+                                 /*nofold=*/false, loc, builder, scalableSize);
 }
 
 // Normalizes a permutation on a higher rank space to its actual size, e.g.
@@ -1120,10 +1152,18 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
-  if (llvm::any_of(packOp.getMixedTiles(),
-                   [](OpFoldResult tile) { return tile.is<Value>(); })) {
-    return rewriter.notifyMatchFailure(packOp,
-                                       "require inner tile sizes being static");
+  if (llvm::any_of(packOp.getMixedTiles(), [](OpFoldResult tile) {
+        return tile.is<Value>() && !vector::getConstantVscaleMultiplier(
+                                       llvm::dyn_cast<Value>(tile));
+      })) {
+    return rewriter.notifyMatchFailure(
+        packOp, "require inner tile sizes to be either static or a constant "
+                "multiple of vector.vscale");
+  }
+  if (llvm::count_if(packOp.getMixedTiles(),
+                     [](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
+    return rewriter.notifyMatchFailure(
+        packOp, "at most one dynamic tile size is supported");
   }
 
   // TODO: support the case that outer dimensions are not all 1s. A
@@ -1181,7 +1221,23 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<int64_t> transpShape = readShape;
   applyPermutationToVector<int64_t>(transpShape, perm);
 
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  // If there's a tile with a scalable size, retrieve its size. ATM only 1
+  // scalable tile is allowed.
+  Value scalableSize;
+  for (auto tile : packOp.getMixedTiles()) {
+    if (tile.is<Value>()) {
+      assert(!scalableSize && "Only one scalable size is supported ATM.");
+      scalableSize = cast<Value>(tile);
+      assert(vector::getConstantVscaleMultiplier(scalableSize) &&
+             "This dynamic shape is not a multiple of vscale!");
+    }
+  }
+
+  Value empty =
+      ShapedType::isDynamicShape(cast<ShapedType>(input.getType()).getShape())
+          ? rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType,
+                                             scalableSize)
+          : rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
 
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index a0d8a08fc6ba47..7b25d9747827e3 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR//VectorOps.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 using namespace mlir;
@@ -23,19 +24,40 @@ using namespace mlir::tensor;
 
 PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
                                     Value pad, bool nofold, Location loc,
-                                    OpBuilder &b) {
+                                    OpBuilder &b,
+                                    std::optional<Value> dynOutDim) {
+  assert(llvm::count_if(
+             type.getShape(),
+             [](int64_t dim) { return ShapedType::isDynamic(dim); }) <= 1 &&
+         "At most one output dim can be dynamic!");
+
+  // Init "low" and "high" padding values ("low" is kept as is, "high" is
+  // computed below).
   SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
   SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
+
   for (const auto &en : enumerate(type.getShape())) {
-    // Pad only the static dimensions of the result tensor type.
-    if (ShapedType::isDynamic(en.value()))
-      continue;
-    // Compute the padding width.
-    AffineExpr d0;
-    bindDims(b.getContext(), d0);
-    OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
-    high[en.index()] =
-        affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
+    if (!ShapedType::isDynamic(en.value())) {
+      // Static sizes - the "high" value is computed based on the input and
+      // output dims. Compute the padding width.
+      AffineExpr d0;
+      bindDims(b.getContext(), d0);
+      OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
+      high[en.index()] =
+          affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
+    } else {
+      // Dynamic sizes - the "high" value is computed based on the input dim
+      // and `dynOutDim`.
+      assert(dynOutDim.has_value() &&
+             "dynamic output dim requires dynOutDim to be set");
+
+      // Compute the padding width.
+      AffineExpr d0, d1;
+      auto dimVal = b.create<tensor::DimOp>(loc, source, en.index());
+      bindDims(b.getContext(), d0, d1);
+      high[en.index()] = affine::makeComposedFoldedAffineApply(
+          b, loc, d0 - d1, {dynOutDim.value(), dimVal.getResult()});
+    }
   }
   return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
 }
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 7d87a0994004fe..66a220005ebf36 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -23,6 +23,8 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
   %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
   return %0 : tensor<1x1x8x2xf32>
 }
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 - s1)>
+
 // CHECK-LABEL: func.func @simple_pad_and_pack
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
@@ -34,6 +36,39 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
 // CHECK-SAME:      [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
 
+/// Same as example above, but with scalable sizes.
+
+/// NOTE: For this example to make sense in practice, the "?" in the output shape
+///       should effectively be 8 * vector.vscale (and that's what tensor.dim
+///       below should return).
+
+func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
+  %c8 = arith.constant 8 : index
+  %vscale = vector.vscale
+  %c8_vscale = arith.muli %vscale, %c8 : index
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+  return %0 : tensor<1x1x?x2xf32>
+}
+
+
+// CHECK-LABEL:   func.func @simple_pad_and_pack_scalable(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
+// CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
+// CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
+// CHECK:           %[[C2:.+]] = arith.constant 2 : index
+// CHECK:           %[[C5:.+]] = arith.constant 5 : index
+// CHECK:           %[[C8:.+]] = arith.constant 8 : index
+// CHECK:           %[[VS:.+]] = vector.vscale
+// CHECK:           %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
+// CHECK:           %[[PAD_HIGH:.+]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]], %[[C5]]]
+// CHECK:           %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK:             tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT:       linalg.transpose
+// CHECK:           %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK:           %[[DIM:.+]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
+// CHECK:           %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
+
 // -----
 
 func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{

``````````

</details>


https://github.com/llvm/llvm-project/pull/109815


More information about the Mlir-commits mailing list