[Mlir-commits] [mlir] [mlir][tensor] Extend the logic to generalise tensor.pack (PR #109815)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 24 08:22:08 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-linalg
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
- **[mlir][tensor] Refine the semantics of `createPadHighOp`**
- **[mlir][tensor] Extend the logic to generalise tensor.pack**
---
Full diff: https://github.com/llvm/llvm-project/pull/109815.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tensor/Utils/Utils.h (+6-5)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+74-18)
- (modified) mlir/lib/Dialect/Tensor/Utils/Utils.cpp (+32-10)
- (modified) mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir (+35)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 84d06d456bb689..db5a15c9ec3550 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -14,12 +14,13 @@
namespace mlir {
namespace tensor {
-// Return a PadOp that pads `source` to `type` size where the static
-// sizes are assumed to be greater than the dynamic sizes. If `type` has dynamic
-// dimensions the padding width is set to zero. The op performs "high" padding
-// (i.e. it adds trailing padding values until the desired size is met).
+// Return a PadOp that pads `source` to `type` size. Output sizes (from `type`)
+// are assumed to be static and greater than the potentially dynamic input sizes
+// (from `source`). The op performs "high" padding (i.e. it adds trailing padding
+// values until the desired size is met).
PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
- bool nofold, Location loc, OpBuilder &builder);
+ bool nofold, Location loc, OpBuilder &builder,
+ std::optional<Value> dynOutDim = {});
// Creates dim ops for each dynamic dimension of the ranked tensor argument and
// returns these as values.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e0dea8e78d55c1..42389b431566eb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1021,8 +1021,16 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
return success();
}
-/// Returns a tensor.pad op if padding value is set. Otherwise, returns the
-/// source directly. The method assumes that the `packOp` has static shapes.
+/// If padding value is set, returns a tensor.pad Op for the source tensor,
+/// with the output shape matching the output of `packOp`. Otherwise, returns
+/// the source directly.
+///
+/// This method assumes that all outer dims for this pack Op are 1.
+///
+/// At most _one_ inner tile size can be _dynamic_, all other inner tiles are
+/// required to have static sizes. The inner tile that's dynamic must be a
+/// multiple of vector.vscale (to support scalable tile sizes). This condition
+/// can be relaxed in the future.
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
tensor::PackOp packOp) {
Value input = packOp.getSource();
@@ -1038,26 +1046,50 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
ShapedType inputType = packOp.getSourceType();
int64_t inputRank = inputType.getRank();
- SmallVector<int64_t> paddedShape;
DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
packOp.getDimAndTileMapping();
- for (int64_t dim = 0; dim < inputRank; ++dim) {
- int64_t size = inputType.getDimSize(dim);
- if (!tileAndPosMapping.count(dim)) {
- paddedShape.push_back(size);
+
+ // The size of a scalable tile (if present).
+ Value scalableSize;
+
+ // Collect dims for the padded shape.
+ SmallVector<int64_t> paddedShape;
+ for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
+ int64_t inputDimSize = inputType.getDimSize(dimIdx);
+ // 1. Non-tiled outer dims.
+ // These dims should be 1 and we simply preserve them.
+ if (!tileAndPosMapping.count(dimIdx)) {
+ assert(inputDimSize == 1 &&
+ "with all outer dims == 1, this non-tiled input dim should be 1!");
+ paddedShape.push_back(inputDimSize);
+ continue;
+ }
+
+ // 2. Tiled outer dims
+ // As all outer dims == 1, it is safe to use the tile size for the padded
+ // shape.
+ OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
+
+ // 2.1 Static tile sizes
+ std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
+ if (cstTileSize.has_value()) {
+ paddedShape.push_back(cstTileSize.value());
continue;
}
- // The size is less than or equal to tileSize because outer dims are all 1s.
- std::optional<int64_t> tileSize =
- getConstantIntValue(tileAndPosMapping.lookup(dim));
- assert(tileSize.has_value() && "dynamic inner tile size is not supported");
- paddedShape.push_back(tileSize.value());
+ // 2.2 Dynamic tile sizes
+ paddedShape.push_back(ShapedType::kDynamic);
+
+ // Get the value that holds the scalable size.
+ assert(!scalableSize && "Only one scalable size is supported ATM.");
+ scalableSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
+ assert(vector::getConstantVscaleMultiplier(scalableSize) &&
+ "This dynamic shape is not a multiple of vscale!");
}
auto resultType =
RankedTensorType::get(paddedShape, inputType.getElementType());
return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
- /*nofold=*/false, loc, builder);
+ /*nofold=*/false, loc, builder, scalableSize);
}
// Normalizes a permutation on a higher rank space to its actual size, e.g.
@@ -1120,10 +1152,18 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
tensor::PackOp packOp, PatternRewriter &rewriter) const {
- if (llvm::any_of(packOp.getMixedTiles(),
- [](OpFoldResult tile) { return tile.is<Value>(); })) {
- return rewriter.notifyMatchFailure(packOp,
- "require inner tile sizes being static");
+ if (llvm::any_of(packOp.getMixedTiles(), [](OpFoldResult tile) {
+ return tile.is<Value>() && !vector::getConstantVscaleMultiplier(
+ llvm::dyn_cast<Value>(tile));
+ })) {
+ return rewriter.notifyMatchFailure(
+ packOp, "require inner tile sizes to be either static or a constant "
+ "multiple of vector.vscale");
+ }
+ if (llvm::count_if(packOp.getMixedTiles(),
+ [](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
+ return rewriter.notifyMatchFailure(
+ packOp, "at most one dynamic tile size is supported");
}
// TODO: support the case that outer dimensions are not all 1s. A
@@ -1181,7 +1221,23 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
SmallVector<int64_t> transpShape = readShape;
applyPermutationToVector<int64_t>(transpShape, perm);
- Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+ // If there's a tile with a scalable size, retrieve its size. ATM only 1
+ // scalable tile is allowed.
+ Value scalableSize;
+ for (auto tile : packOp.getMixedTiles()) {
+ if (tile.is<Value>()) {
+ assert(!scalableSize && "Only one scalable size is supported ATM.");
+ scalableSize = cast<Value>(tile);
+ assert(vector::getConstantVscaleMultiplier(scalableSize) &&
+ "This dynamic shape is not a multiple of vscale!");
+ }
+ }
+
+ Value empty =
+ ShapedType::isDynamicShape(cast<ShapedType>(input.getType()).getShape())
+ ? rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType,
+ scalableSize)
+ : rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
auto transposedOp =
rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index a0d8a08fc6ba47..7b25d9747827e3 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
@@ -23,19 +24,40 @@ using namespace mlir::tensor;
PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
Value pad, bool nofold, Location loc,
- OpBuilder &b) {
+ OpBuilder &b,
+ std::optional<Value> dynOutDim) {
+ assert(llvm::count_if(
+ type.getShape(),
+ [](int64_t dim) { return ShapedType::isDynamic(dim); }) <= 1 &&
+ "At most one output dim can be dynamic!");
+
+ // Init "low" and "high" padding values ("low" is kept as is, "high" is
+ // computed below).
SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
+
for (const auto &en : enumerate(type.getShape())) {
- // Pad only the static dimensions of the result tensor type.
- if (ShapedType::isDynamic(en.value()))
- continue;
- // Compute the padding width.
- AffineExpr d0;
- bindDims(b.getContext(), d0);
- OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
- high[en.index()] =
- affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
+ if (!ShapedType::isDynamic(en.value())) {
+ // Static sizes - the "high" value is computed based on the input and
+ // output dims. Compute the padding width.
+ AffineExpr d0;
+ bindDims(b.getContext(), d0);
+ OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
+ high[en.index()] =
+ affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
+ } else {
+ // Dynamic sizes - the "high" value is computed based on the input dim
+ // and `dynOutDim`.
+ assert(dynOutDim.has_value() &&
+ "dynamic output dim requires dynOutDim to be set");
+
+ // Compute the padding width.
+ AffineExpr d0, d1;
+ auto dimVal = b.create<tensor::DimOp>(loc, source, en.index());
+ bindDims(b.getContext(), d0, d1);
+ high[en.index()] = affine::makeComposedFoldedAffineApply(
+ b, loc, d0 - d1, {dynOutDim.value(), dimVal.getResult()});
+ }
}
return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
}
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 7d87a0994004fe..66a220005ebf36 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -23,6 +23,8 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
%0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
return %0 : tensor<1x1x8x2xf32>
}
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 - s1)>
+
// CHECK-LABEL: func.func @simple_pad_and_pack
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
@@ -34,6 +36,39 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]
+/// Same as example above, but with scalable sizes.
+
+/// NOTE: For this example to make sense in practice, the "?" in the output shape
+/// should effectively be 8 * vector.vscale (and that's what tensor.dim
+/// below should return).
+
+func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
+ %c8 = arith.constant 8 : index
+ %vscale = vector.vscale
+ %c8_vscale = arith.muli %vscale, %c8 : index
+ %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+ return %0 : tensor<1x1x?x2xf32>
+}
+
+
+// CHECK-LABEL: func.func @simple_pad_and_pack_scalable(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
+// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[C5:.+]] = arith.constant 5 : index
+// CHECK: %[[C8:.+]] = arith.constant 8 : index
+// CHECK: %[[VS:.+]] = vector.vscale
+// CHECK: %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
+// CHECK: %[[PAD_HIGH:.+]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]], %[[C5]]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK: tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT: linalg.transpose
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[PAD]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
+// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
+
// -----
func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{
``````````
</details>
https://github.com/llvm/llvm-project/pull/109815
More information about the Mlir-commits
mailing list