[Mlir-commits] [mlir] 66f84c8 - [mlir][tensor] Extend the logic to generalise tensor.pack (#109815)

llvmlistbot at llvm.org
Wed Oct 2 01:44:16 PDT 2024


Author: Andrzej WarzyƄski
Date: 2024-10-02T09:44:13+01:00
New Revision: 66f84c8b8a762832af39e91370018f8f8307a0fc

URL: https://github.com/llvm/llvm-project/commit/66f84c8b8a762832af39e91370018f8f8307a0fc
DIFF: https://github.com/llvm/llvm-project/commit/66f84c8b8a762832af39e91370018f8f8307a0fc.diff

LOG: [mlir][tensor] Extend the logic to generalise tensor.pack (#109815)

Extends the logic that generalises tensor.pack (into e.g. tensor.pad +
linalg.transpose) so that it also works when one of the inner tile sizes
is scalable (i.e. a multiple of `vector.vscale`). For example:
```mlir
  %c8 = arith.constant 8 : index
  %vscale = vector.vscale
  %c8_vscale = arith.muli %vscale, %c8 : index
  %0 = tensor.pack %input
      padding_value(%pad : f32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8_vscale, 2]
      into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
```
is generalised as:
```mlir
  #map = affine_map<()[s0, s1] -> (s0 - s1)>
  %c5 = arith.constant 5 : index
  %c8 = arith.constant 8 : index
  %vscale = vector.vscale
  %c8_vscale = arith.muli %vscale, %c8 : index
  %0 = affine.apply #map()[%c8_vscale, %c5]
  %padded = tensor.pad %input low[0, 0] high[%0, 1] {
  ^bb0(%arg3: index, %arg4: index):
    tensor.yield %pad : f32
  } : tensor<5x1xf32> to tensor<?x2xf32>
```

At the Tensor level, scalability is modelled using dynamic shapes; this
change extends the relevant logic so that it also works for dynamic
shapes.
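
For reference, here is a sketch of the complete generalised sequence for the
scalable example, reconstructed from the test expectations added below (SSA
names are illustrative; no `linalg.transpose` is needed since the permutation
is the identity):

```mlir
#map = affine_map<()[s0] -> (s0 - 5)>

func.func @pad_and_pack_scalable(%input: tensor<5x1xf32>,
                                 %output: tensor<1x1x?x2xf32>,
                                 %pad: f32) -> tensor<1x1x?x2xf32> {
  %c2 = arith.constant 2 : index
  %c8 = arith.constant 8 : index
  %vscale = vector.vscale
  %c8_vscale = arith.muli %vscale, %c8 : index
  // Padding width for dim 0: 8 * vscale - 5.
  %high = affine.apply #map()[%c8_vscale]
  %padded = tensor.pad %input low[0, 0] high[%high, 1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %pad : f32
  } : tensor<5x1xf32> to tensor<?x2xf32>
  %slice = tensor.extract_slice %padded[0, 0] [%c8_vscale, 2] [1, 1]
      : tensor<?x2xf32> to tensor<?x2xf32>
  %dim = tensor.dim %output, %c2 : tensor<1x1x?x2xf32>
  %res = tensor.insert_slice %slice into %output[0, 0, 0, 0]
      [1, 1, %dim, 2] [1, 1, 1, 1]
      : tensor<?x2xf32> into tensor<1x1x?x2xf32>
  return %res : tensor<1x1x?x2xf32>
}
```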

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Tensor/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 84d06d456bb689..ed1ec1e871482d 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -14,12 +14,22 @@
 namespace mlir {
 namespace tensor {
 
-// Return a PadOp that pads `source` to `type` size where the static
-// sizes are assumed to be greater than the dynamic sizes. If `type` has dynamic
-// dimensions the padding width is set to zero. The op performs "high" padding
-// (i.e. it adds trailing padding values until the desired size is met).
-PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
-                      bool nofold, Location loc, OpBuilder &builder);
+// Return a PadOp that pads `source` to `resType` size. The op performs "high"
+// padding, i.e. it adds trailing padding values until the desired size is met.
+// Output sizes are assumed to be no smaller than the input sizes. The padding
+// width is calculated as: resDim - sourceDim.
+//
+// Handling static sizes is trivial. Dynamic dimensions are trickier (*):
+//  1. dynamic input sizes are extracted from `source`
+//  2. for dynamic output dims, there are two options:
+//    2.1 all output dynamic dim sizes are specified in `dynOutDims`,
+//    2.2 `dynOutDims` is empty and the corresponding padding width is set to 0.
+//
+// (*) Note that `resType` is just a shape and it only encodes the actual sizes
+// for _static_ dimensions.
+PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad,
+                      bool nofold, Location loc, OpBuilder &builder,
+                      SmallVector<Value> dynOutDims = {});
 
 // Creates dim ops for each dynamic dimension of the ranked tensor argument and
 // returns these as values.

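To illustrate the new `dynOutDims` path, here is a hypothetical sketch (not
part of the patch): assuming `resType = tensor<?x2xf32>`, a `%source` of type
`tensor<5x1xf32>`, and the single dynamic output size passed as `%tile` (all
names illustrative), the helper builds IR along these lines:

```mlir
// Padding width for the dynamic dim 0: %tile - 5 (i.e. resDim - sourceDim);
// the static dim 1 is padded from 1 to 2.
%high = affine.apply affine_map<()[s0] -> (s0 - 5)>()[%tile]
%padded = tensor.pad %source low[0, 0] high[%high, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %pad_val : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
```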
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e0dea8e78d55c1..729b2653cd83c0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1021,8 +1021,11 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
   return success();
 }
 
-/// Returns a tensor.pad op if padding value is set. Otherwise, returns the
-/// source directly. The method assumes that the `packOp` has static shapes.
+/// If a padding value is set, returns a tensor.pad op that pads the source of
+/// `packOp` so that each tiled dim reaches its inner tile size. Otherwise,
+/// returns the source directly.
+///
+/// This method assumes that all outer dims for this pack Op are 1.
 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                            tensor::PackOp packOp) {
   Value input = packOp.getSource();
@@ -1038,26 +1041,48 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   ShapedType inputType = packOp.getSourceType();
   int64_t inputRank = inputType.getRank();
 
-  SmallVector<int64_t> paddedShape;
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
       packOp.getDimAndTileMapping();
-  for (int64_t dim = 0; dim < inputRank; ++dim) {
-    int64_t size = inputType.getDimSize(dim);
-    if (!tileAndPosMapping.count(dim)) {
-      paddedShape.push_back(size);
+
+  // The sizes of dynamic tiles.
+  SmallVector<Value> dynamicTileSizes;
+
+  // Collect dims for the padded shape.
+  SmallVector<int64_t> paddedShape;
+  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
+    // 1. Non-tiled outer dims.
+    // These dims should be 1 and we simply preserve them.
+    if (!tileAndPosMapping.count(dimIdx)) {
+      int64_t inputDimSize = inputType.getDimSize(dimIdx);
+      assert(inputDimSize == 1 &&
+             "with all outer dims == 1, this non-tiled input dim should be 1!");
+      paddedShape.push_back(inputDimSize);
+      continue;
+    }
+
+    // 2. Tiled outer dims
+    // As all outer dims == 1, it is safe to use the tile size for the padded
+    // shape.
+    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
+
+    // 2.1 Static tile sizes
+    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
+    if (cstTileSize.has_value()) {
+      paddedShape.push_back(cstTileSize.value());
       continue;
     }
 
-    // The size is less than or equal to tileSize because outer dims are all 1s.
-    std::optional<int64_t> tileSize =
-        getConstantIntValue(tileAndPosMapping.lookup(dim));
-    assert(tileSize.has_value() && "dynamic inner tile size is not supported");
-    paddedShape.push_back(tileSize.value());
+    // 2.2 Dynamic tile sizes
+    paddedShape.push_back(ShapedType::kDynamic);
+
+    // Get the value that holds the dynamic size.
+    dynamicTileSizes.push_back(llvm::cast<Value>(tileSizeForDim));
   }
   auto resultType =
       RankedTensorType::get(paddedShape, inputType.getElementType());
   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
-                                 /*nofold=*/false, loc, builder);
+                                 /*nofold=*/false, loc, builder,
+                                 dynamicTileSizes);
 }
 
 // Normalizes a permutation on a higher rank space to its actual size, e.g.
@@ -1120,10 +1145,10 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
-  if (llvm::any_of(packOp.getMixedTiles(),
-                   [](OpFoldResult tile) { return tile.is<Value>(); })) {
-    return rewriter.notifyMatchFailure(packOp,
-                                       "require inner tile sizes being static");
+  if (llvm::count_if(packOp.getMixedTiles(),
+                     [](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
+    return rewriter.notifyMatchFailure(
+        packOp, "at most one dynamic tile size is supported");
   }
 
   // TODO: support the case that outer dimensions are not all 1s. A
@@ -1147,12 +1172,15 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> readSizes;
-  SmallVector<int64_t> readShape;
+  SmallVector<OpFoldResult> transShapeForEmpty;
+  SmallVector<int64_t> readShapeForExtractSlice;
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
-                              .value_or(ShapedType::kDynamic));
+      readShapeForExtractSlice.push_back(
+          getConstantIntValue(dimAndTileMapping[i])
+              .value_or(ShapedType::kDynamic));
       readSizes.push_back(dimAndTileMapping[i]);
+      transShapeForEmpty.push_back(dimAndTileMapping[i]);
       continue;
     }
     if (ShapedType::isDynamic(inputShape[i])) {
@@ -1161,12 +1189,14 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     } else {
       readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
     }
-    if (inputShape[i] != 1)
-      readShape.push_back(inputShape[i]);
+    if (inputShape[i] != 1) {
+      readShapeForExtractSlice.push_back(inputShape[i]);
+      transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+    }
   }
 
   Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShape, elemType);
+  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
 
   Value tile = rewriter.create<tensor::ExtractSliceOp>(
       loc, readType, input, readOffsets, readSizes, readStrides);
@@ -1178,10 +1208,10 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
 
-  SmallVector<int64_t> transpShape = readShape;
-  applyPermutationToVector<int64_t>(transpShape, perm);
+  applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
 
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  Value empty =
+      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
 

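Note that the relaxed precondition still bails out when more than one inner
tile is dynamic. A hypothetical example that would hit the new
"at most one dynamic tile size is supported" match failure:

```mlir
// Two dynamic inner tiles (%t0, %t1): GeneralizeOuterUnitDimsPackOpPattern
// rejects this op.
%0 = tensor.pack %src padding_value(%pad : f32)
    inner_dims_pos = [0, 1] inner_tiles = [%t0, %t1]
    into %dst : tensor<5x1xf32> -> tensor<1x1x?x?xf32>
```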
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index a0d8a08fc6ba47..1cb040b6dca414 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -16,28 +16,48 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
 
-PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
+PadOp mlir::tensor::createPadHighOp(RankedTensorType resType, Value source,
                                     Value pad, bool nofold, Location loc,
-                                    OpBuilder &b) {
-  SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
-  SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
-  for (const auto &en : enumerate(type.getShape())) {
-    // Pad only the static dimensions of the result tensor type.
-    if (ShapedType::isDynamic(en.value()))
+                                    OpBuilder &b,
+                                    SmallVector<Value> dynOutDims) {
+
+  assert((resType.getNumDynamicDims() == dynOutDims.size() ||
+          dynOutDims.empty()) &&
+         "Either none or all output dynamic dims must be specified!");
+
+  // Init "low" and "high" padding values ("low" is kept as is, "high" is
+  // computed below).
+  SmallVector<OpFoldResult> low(resType.getRank(), b.getIndexAttr(0));
+  SmallVector<OpFoldResult> high(resType.getRank(), b.getIndexAttr(0));
+
+  size_t outDimIdx = 0;
+
+  for (const auto [idx, val] : enumerate(resType.getShape())) {
+    bool isDimDynamic = ShapedType::isDynamic(val);
+    bool updatePadHigh = !isDimDynamic || !dynOutDims.empty();
+
+    // Keep the default padding width (i.e. "0") when the output dim is dynamic
+    // and no actual output sizes have been provided.
+    if (!updatePadHigh)
       continue;
-    // Compute the padding width.
-    AffineExpr d0;
-    bindDims(b.getContext(), d0);
-    OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
-    high[en.index()] =
-        affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
+
+    // Compute the padding width: resDim - sourceDim.
+    AffineExpr d0, d1;
+    bindDims(b.getContext(), d0, d1);
+    OpFoldResult sourceDim = tensor::getMixedSize(b, loc, source, idx);
+    OpFoldResult outDim = isDimDynamic ? OpFoldResult(dynOutDims[outDimIdx++])
+                                       : OpFoldResult(b.getIndexAttr(val));
+
+    high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
+                                                      {outDim, sourceDim});
   }
-  return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
+  return b.create<PadOp>(loc, resType, source, low, high, pad, nofold);
 }
 
 SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,

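When `dynOutDims` is empty, the pre-existing behaviour is preserved: dynamic
result dims keep a zero padding width. A hypothetical sketch, assuming
`resType = tensor<?x4xf32>` and a `%source` of type `tensor<?x3xf32>`:

```mlir
// Dim 0 is dynamic and no output size was provided, so "high" stays 0;
// dim 1 is padded from 3 to 4.
%padded = tensor.pad %source low[0, 0] high[0, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %pad_val : f32
} : tensor<?x3xf32> to tensor<?x4xf32>
```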
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 7d87a0994004fe..bb23a869a9cc5b 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -23,6 +23,8 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
   %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
   return %0 : tensor<1x1x8x2xf32>
 }
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 - 5)>
+
 // CHECK-LABEL: func.func @simple_pad_and_pack
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
@@ -34,6 +36,59 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
 // CHECK-SAME:      [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
 
+/// Same as the example above, but with a dynamic tile size.
+
+func.func @simple_pad_and_pack_dynamic(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %high: index) -> tensor<1x1x?x2xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+  return %0 : tensor<1x1x?x2xf32>
+}
+
+// CHECK-LABEL:   func.func @simple_pad_and_pack_dynamic(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[HIGH_VAL:.*]]: index) -> tensor<1x1x?x2xf32> {
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL]]]
+// CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK:             tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT:       linalg.transpose
+// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[PAD]][0, 0] {{\[}}%[[HIGH_VAL]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK:           %[[DIM:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
+
+/// Same as the example above, but with a scalable tile size.
+
+/// NOTE: For this example to make sense in practice, the "?" in the output shape
+///       should effectively be 8 * vector.vscale (and that's what tensor.dim
+///       below should return).
+
+func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
+  %c8 = arith.constant 8 : index
+  %vscale = vector.vscale
+  %c8_vscale = arith.muli %vscale, %c8 : index
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+  return %0 : tensor<1x1x?x2xf32>
+}
+
+// CHECK-LABEL:   func.func @simple_pad_and_pack_scalable(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
+// CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
+// CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
+// CHECK-DAG:       %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:       %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VS:.+]] = vector.vscale
+// CHECK:           %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
+// CHECK:           %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]]]
+// CHECK:           %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK:             tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT:       linalg.transpose
+// CHECK:           %[[SLICE:.+]] = tensor.extract_slice %[[PAD]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK:           %[[DIM:.+]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
+// CHECK:           %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
+
 // -----
 
 func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{

