[Mlir-commits] [mlir] [mlir][tensor] Extend the logic to generalise tensor.pack (PR #109815)
Han-Chung Wang
llvmlistbot at llvm.org
Wed Sep 25 10:48:05 PDT 2024
================
@@ -1181,7 +1221,23 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
SmallVector<int64_t> transpShape = readShape;
applyPermutationToVector<int64_t>(transpShape, perm);
- Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+ // If there's a tile with a scalable size, retrieve its size. ATM only 1
+ // scalable tile is allowed.
+ Value scalableSize;
+ for (auto tile : packOp.getMixedTiles()) {
+ if (tile.is<Value>()) {
+ assert(!scalableSize && "Only one scalable size is supported ATM.");
+ scalableSize = cast<Value>(tile);
+ assert(vector::getConstantVscaleMultiplier(scalableSize) &&
+ "This dynamic shape is not a multiple of vscale!");
+ }
+ }
+
+ Value empty =
+ ShapedType::isDynamicShape(cast<ShapedType>(input.getType()).getShape())
+ ? rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType,
+ scalableSize)
+ : rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
----------------
hanhanW wrote:
Can we simplify this by using `SmallVector<OpFoldResult>`? I think `readSizes` already holds the shape (before transposition), so it can be something like:
```cpp
SmallVector<OpFoldResult> transpShape = readSizes;
applyPermutationToVector<OpFoldResult>(transpShape, perm);
Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
```
https://github.com/llvm/llvm-project/pull/109815