[Mlir-commits] [mlir] [mlir][tensor] Extend the logic to generalise tensor.pack (PR #109815)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Sep 26 06:39:07 PDT 2024
================
```diff
@@ -1181,7 +1221,23 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<int64_t> transpShape = readShape;
   applyPermutationToVector<int64_t>(transpShape, perm);
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  // If there's a tile with a scalable size, retrieve its size. ATM only 1
+  // scalable tile is allowed.
+  Value scalableSize;
+  for (auto tile : packOp.getMixedTiles()) {
+    if (tile.is<Value>()) {
+      assert(!scalableSize && "Only one scalable size is supported ATM.");
+      scalableSize = cast<Value>(tile);
+      assert(vector::getConstantVscaleMultiplier(scalableSize) &&
+             "This dynamic shape is not a multiple of vscale!");
+    }
+  }
+
+  Value empty =
+      ShapedType::isDynamicShape(cast<ShapedType>(input.getType()).getShape())
+          ? rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType,
+                                             scalableSize)
+          : rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
```
----------------
banach-space wrote:
You're probably right that this can be simplified, but it's not immediately obvious to me :) Note that for this example:
```mlir
%0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x32x8xf32> -> tensor<1x1x1x1x8x32xf32>
```
`readSizes` and `readShape` are `[1, 1, 32, 8]` and `[32, 8]`, respectively. I need to take another look 😅
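To make the example concrete, here is a minimal, MLIR-free sketch of how the two vectors relate for a static input shape. `ReadInfo`, `computeReadInfo` and `permute` are hypothetical helpers, not MLIR APIs; the real pattern works on `OpFoldResult`s and also handles dynamic sizes, which this sketch ignores. `permute` assumes the `result[i] = v[perm[i]]` convention for `applyPermutationToVector`.

```cpp
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical model (not MLIR code): for every source dim, a tiled dim
// contributes its tile size to both vectors, while an untiled dim contributes
// its size to `readSizes` but is dropped from `readShape` if it is a unit dim.
struct ReadInfo {
  std::vector<int64_t> readSizes;
  std::vector<int64_t> readShape;
};

ReadInfo computeReadInfo(const std::vector<int64_t> &inputShape,
                         const std::map<int64_t, int64_t> &dimAndTile) {
  ReadInfo info;
  for (int64_t i = 0; i < static_cast<int64_t>(inputShape.size()); ++i) {
    auto it = dimAndTile.find(i);
    if (it != dimAndTile.end()) {
      // Tiled dim: read a full tile.
      info.readSizes.push_back(it->second);
      info.readShape.push_back(it->second);
      continue;
    }
    // Untiled outer dim: always part of the read sizes...
    info.readSizes.push_back(inputShape[i]);
    // ...but unit dims are rank-reduced away from the read shape.
    if (inputShape[i] != 1)
      info.readShape.push_back(inputShape[i]);
  }
  return info;
}

// Assumed permutation convention: result[i] = v[perm[i]].
std::vector<int64_t> permute(const std::vector<int64_t> &v,
                             const std::vector<int64_t> &perm) {
  std::vector<int64_t> out;
  out.reserve(perm.size());
  for (int64_t p : perm)
    out.push_back(v[p]);
  return out;
}
```

For the `tensor.pack` above (`inner_dims_pos = [3, 2]`, `inner_tiles = [8, 32]`), `computeReadInfo({1, 1, 32, 8}, {{2, 32}, {3, 8}})` yields `readSizes = {1, 1, 32, 8}` and `readShape = {32, 8}`. Permuting `readShape` with `{1, 0}` gives `{8, 32}`, which matches the `8x32` inner tile of the destination.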
https://github.com/llvm/llvm-project/pull/109815