[Mlir-commits] [mlir] 66f84c8 - [mlir][tensor] Extend the logic to generalise tensor.pack (#109815)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 2 01:44:16 PDT 2024
Author: Andrzej Warzyński
Date: 2024-10-02T09:44:13+01:00
New Revision: 66f84c8b8a762832af39e91370018f8f8307a0fc
URL: https://github.com/llvm/llvm-project/commit/66f84c8b8a762832af39e91370018f8f8307a0fc
DIFF: https://github.com/llvm/llvm-project/commit/66f84c8b8a762832af39e91370018f8f8307a0fc.diff
LOG: [mlir][tensor] Extend the logic to generalise tensor.pack (#109815)
Extends the logic to generalise tensor.pack (into e.g. tensor.pad +
linalg.transpose) so that it also works when one of the inner tile sizes
is scalable (i.e. a multiple of `vector.vscale`). For example:
```mlir
%c8 = arith.constant 8 : index
%vscale = vector.vscale
%c8_vscale = arith.muli %vscale, %c8 : index
%0 = tensor.pack %input
padding_value(%pad : f32)
inner_dims_pos = [0, 1]
inner_tiles = [%c8_vscale, 2]
into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
```
is generalised as:
```mlir
%c8 = arith.constant 8 : index
%vscale = vector.vscale
%c8_vscale = arith.muli %vscale, %c8 : index
%0 = affine.apply #map()[%c8_vscale, %c5]
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg3: index, %arg4: index):
tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
```
At the Tensor level, we model scalability using dynamic shapes and this
change basically extends the relevant logic so that it also works for
dynamic shapes.
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Tensor/Utils/Utils.cpp
mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 84d06d456bb689..ed1ec1e871482d 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -14,12 +14,22 @@
namespace mlir {
namespace tensor {
-// Return a PadOp that pads `source` to `type` size where the static
-// sizes are assumed to be greater than the dynamic sizes. If `type` has dynamic
-// dimensions the padding width is set to zero. The op performs "high" padding
-// (i.e. it adds trailing padding values until the desired size is met).
-PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
- bool nofold, Location loc, OpBuilder &builder);
+// Return a PadOp that pads `source` to `resType` size. The op performs "high"
+// padding, i.e. it adds trailing padding values until the desired size is met.
+// Output sizes are assumed to be greater than the input sizes. The padding
+// width is calculated as: resDim - sourceDim.
+//
+// Handling static sizes is trivial. Dynamic dimensions are trickier (*):
+// 1. dynamic input sizes are extracted from `source`
+// 2. for dynamic output dims, there are two options:
+// 2.1 all output dynamic dim sizes are specified in `dynOutDim`,
+// 2.2 `dynOutDim` is empty and the corresponding padding width is set to 0.
+//
+// (*) Note that `resType` is just a shape and it only encodes the actual sizes
+// for _static_ dimensions.
+PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad,
+ bool nofold, Location loc, OpBuilder &builder,
+ SmallVector<Value> dynOutDim = {});
// Creates dim ops for each dynamic dimension of the ranked tensor argument and
// returns these as values.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e0dea8e78d55c1..729b2653cd83c0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1021,8 +1021,11 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
return success();
}
-/// Returns a tensor.pad op if padding value is set. Otherwise, returns the
-/// source directly. The method assumes that the `packOp` has static shapes.
+/// If padding value is set, returns a tensor.pad Op for the source tensor,
+/// with the output shape matching the output of `packOp`. Otherwise, returns
+/// the source directly.
+///
+/// This method assumes that all outer dims for this pack Op are 1.
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
tensor::PackOp packOp) {
Value input = packOp.getSource();
@@ -1038,26 +1041,48 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
ShapedType inputType = packOp.getSourceType();
int64_t inputRank = inputType.getRank();
- SmallVector<int64_t> paddedShape;
DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
packOp.getDimAndTileMapping();
- for (int64_t dim = 0; dim < inputRank; ++dim) {
- int64_t size = inputType.getDimSize(dim);
- if (!tileAndPosMapping.count(dim)) {
- paddedShape.push_back(size);
+
+ // The sizes of dynamic tiles
+ SmallVector<Value> dynamicTileSizes;
+
+ // Collect dims for the padded shape.
+ SmallVector<int64_t> paddedShape;
+ for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
+ // 1. Non-tiled outer dims.
+ // These dims should be 1 and we simply preserve them.
+ if (!tileAndPosMapping.count(dimIdx)) {
+ int64_t inputDimSize = inputType.getDimSize(dimIdx);
+ assert(inputDimSize == 1 &&
+ "with all outer dims == 1, this non-tiled input dim should be 1!");
+ paddedShape.push_back(inputDimSize);
+ continue;
+ }
+
+ // 2. Tiled outer dims
+ // As all outer dims == 1, it is safe to use the tile size for the padded
+ // shape.
+ OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
+
+ // 2.1 Static tile sizes
+ std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
+ if (cstTileSize.has_value()) {
+ paddedShape.push_back(cstTileSize.value());
continue;
}
- // The size is less than or equal to tileSize because outer dims are all 1s.
- std::optional<int64_t> tileSize =
- getConstantIntValue(tileAndPosMapping.lookup(dim));
- assert(tileSize.has_value() && "dynamic inner tile size is not supported");
- paddedShape.push_back(tileSize.value());
+ // 2.2 Dynamic tile sizes
+ paddedShape.push_back(ShapedType::kDynamic);
+
+ // Get the value that holds the dynamic size.
+ dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
}
auto resultType =
RankedTensorType::get(paddedShape, inputType.getElementType());
return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
- /*nofold=*/false, loc, builder);
+ /*nofold=*/false, loc, builder,
+ dynamicTileSizes);
}
// Normalizes a permutation on a higher rank space to its actual size, e.g.
@@ -1120,10 +1145,10 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
tensor::PackOp packOp, PatternRewriter &rewriter) const {
- if (llvm::any_of(packOp.getMixedTiles(),
- [](OpFoldResult tile) { return tile.is<Value>(); })) {
- return rewriter.notifyMatchFailure(packOp,
- "require inner tile sizes being static");
+ if (llvm::count_if(packOp.getMixedTiles(),
+ [](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
+ return rewriter.notifyMatchFailure(
+ packOp, "at most one dynamic tile size is supported");
}
// TODO: support the case that outer dimensions are not all 1s. A
@@ -1147,12 +1172,15 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
SmallVector<OpFoldResult> readSizes;
- SmallVector<int64_t> readShape;
+ SmallVector<OpFoldResult> transShapeForEmpty;
+ SmallVector<int64_t> readShapeForExtractSlice;
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
- readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
- .value_or(ShapedType::kDynamic));
+ readShapeForExtractSlice.push_back(
+ getConstantIntValue(dimAndTileMapping[i])
+ .value_or(ShapedType::kDynamic));
readSizes.push_back(dimAndTileMapping[i]);
+ transShapeForEmpty.push_back(dimAndTileMapping[i]);
continue;
}
if (ShapedType::isDynamic(inputShape[i])) {
@@ -1161,12 +1189,14 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
} else {
readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
}
- if (inputShape[i] != 1)
- readShape.push_back(inputShape[i]);
+ if (inputShape[i] != 1) {
+ readShapeForExtractSlice.push_back(inputShape[i]);
+ transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+ }
}
Type elemType = packOp.getSourceType().getElementType();
- auto readType = RankedTensorType::get(readShape, elemType);
+ auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
Value tile = rewriter.create<tensor::ExtractSliceOp>(
loc, readType, input, readOffsets, readSizes, readStrides);
@@ -1178,10 +1208,10 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
- SmallVector<int64_t> transpShape = readShape;
- applyPermutationToVector<int64_t>(transpShape, perm);
+ applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
- Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+ Value empty =
+ rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
auto transposedOp =
rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index a0d8a08fc6ba47..1cb040b6dca414 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -16,28 +16,48 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR//VectorOps.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
using namespace mlir::tensor;
-PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
+PadOp mlir::tensor::createPadHighOp(RankedTensorType resType, Value source,
Value pad, bool nofold, Location loc,
- OpBuilder &b) {
- SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
- SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
- for (const auto &en : enumerate(type.getShape())) {
- // Pad only the static dimensions of the result tensor type.
- if (ShapedType::isDynamic(en.value()))
+ OpBuilder &b,
+ SmallVector<Value> dynOutDims) {
+
+ assert((resType.getNumDynamicDims() == dynOutDims.size()) ||
+ dynOutDims.empty() &&
+ "Either none or all output dynamic dims must be specified!");
+
+ // Init "low" and "high" padding values ("low" is kept as is, "high" is
+ // computed below).
+ SmallVector<OpFoldResult> low(resType.getRank(), b.getIndexAttr(0));
+ SmallVector<OpFoldResult> high(resType.getRank(), b.getIndexAttr(0));
+
+ size_t outDimIdx = 0;
+
+ for (const auto [idx, val] : enumerate(resType.getShape())) {
+ bool isDimDynamic = ShapedType::isDynamic(val);
+ bool updatePadHigh = !isDimDynamic || !dynOutDims.empty();
+
+ // Keep the default padding width (i.e. "0") when the output dim is dynamic
+ // and no actual output sizes have been provided.
+ if (!updatePadHigh)
continue;
- // Compute the padding width.
- AffineExpr d0;
- bindDims(b.getContext(), d0);
- OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
- high[en.index()] =
- affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
+
+ // Compute the padding width: resDim - sourceDim.
+ AffineExpr d0, d1;
+ bindDims(b.getContext(), d0, d1);
+ OpFoldResult sourceDim = tensor::getMixedSize(b, loc, source, idx);
+ OpFoldResult outDim = isDimDynamic ? OpFoldResult(dynOutDims[outDimIdx++])
+ : OpFoldResult(b.getIndexAttr(val));
+
+ high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
+ {outDim, sourceDim});
}
- return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
+ return b.create<PadOp>(loc, resType, source, low, high, pad, nofold);
}
SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 7d87a0994004fe..bb23a869a9cc5b 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -23,6 +23,8 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
%0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
return %0 : tensor<1x1x8x2xf32>
}
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 - 5)>
+
// CHECK-LABEL: func.func @simple_pad_and_pack
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
@@ -34,6 +36,59 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]
+/// Same as example above, but with dynamic tile size.
+
+func.func @simple_pad_and_pack_dynamic(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %high: index) -> tensor<1x1x?x2xf32> {
+ %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+ return %0 : tensor<1x1x?x2xf32>
+}
+
+// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[HIGH_VAL:.*]]: index) -> tensor<1x1x?x2xf32> {
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL]]]
+// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK: tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT: linalg.transpose
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[VAL_10:.*]][0, 0] {{\[}}%[[HIGH_VAL]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK: %[[DIM:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
+
+/// Same as example above, but with scalable tile size.
+
+/// NOTE: For this example to make sense in practice, the "?" in the output shape
+/// should effectively be 8 * vector.vscale (and that's what tensor.dim
+/// below should return).
+
+func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
+ %c8 = arith.constant 8 : index
+ %vscale = vector.vscale
+ %c8_vscale = arith.muli %vscale, %c8 : index
+ %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+ return %0 : tensor<1x1x?x2xf32>
+}
+
+// CHECK-LABEL: func.func @simple_pad_and_pack_scalable(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
+// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[VS:.+]] = vector.vscale
+// CHECK: %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
+// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK: tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT: linalg.transpose
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
+// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
+
// -----
func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{
More information about the Mlir-commits
mailing list