[Mlir-commits] [mlir] [mlir][tensor] Extend the logic to generalise tensor.pack (PR #109815)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Sep 26 06:39:07 PDT 2024


================
@@ -1181,7 +1221,23 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<int64_t> transpShape = readShape;
   applyPermutationToVector<int64_t>(transpShape, perm);
 
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  // If there's a tile with a scalable size, retrieve its size. ATM only 1
+  // scalable tile is allowed.
+  Value scalableSize;
+  for (auto tile : packOp.getMixedTiles()) {
+    if (tile.is<Value>()) {
+      assert(!scalableSize && "Only one scalable size is supported ATM.");
+      scalableSize = cast<Value>(tile);
+      assert(vector::getConstantVscaleMultiplier(scalableSize) &&
+             "This dynamic shape is not a multiple of vscale!");
+    }
+  }
+
+  Value empty =
+      ShapedType::isDynamicShape(cast<ShapedType>(input.getType()).getShape())
+          ? rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType,
+                                             scalableSize)
+          : rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
----------------
banach-space wrote:

You're probably right that this can be simplified, but it's not immediately obvious to me how :) Note that for this example:
```mlir
%0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x32x8xf32> -> tensor<1x1x1x1x8x32xf32>
```

`readSizes` and `readShape` are `[1, 1, 32, 8]` and `[32, 8]`, respectively. I need to take another look 😅 
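
For reference, roughly what the generalised sequence looks like for that example (a hand-written sketch, not the actual pattern output; shapes inferred from the types above): `readShape` is the `tensor.extract_slice` result shape, and `transpShape` is `readShape` permuted by `perm`, used for the `tensor.empty` fed into `linalg.transpose`:
```mlir
// Hand-written sketch of the generalised form, assuming the usual
// extract_slice -> transpose -> insert_slice decomposition.
%slice = tensor.extract_slice %arg0[0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
    : tensor<1x1x32x8xf32> to tensor<32x8xf32>
// transpShape = [8, 32], i.e. readShape = [32, 8] permuted by perm.
%empty = tensor.empty() : tensor<8x32xf32>
%transposed = linalg.transpose ins(%slice : tensor<32x8xf32>)
    outs(%empty : tensor<8x32xf32>) permutation = [1, 0]
%packed = tensor.insert_slice %transposed into %arg1[0, 0, 0, 0, 0, 0]
    [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
    : tensor<8x32xf32> into tensor<1x1x1x1x8x32xf32>
```
Since everything here is static, `tensor.empty` takes no dynamic sizes; the `scalableSize` operand in the diff only comes into play when one of the inner tiles is a `vscale`-based `Value`.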

https://github.com/llvm/llvm-project/pull/109815

