[Mlir-commits] [mlir] [mlir][tensor] Extend the logic to generalise tensor.pack (PR #109815)

Han-Chung Wang llvmlistbot at llvm.org
Thu Sep 26 09:36:43 PDT 2024


================
@@ -1181,7 +1221,23 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<int64_t> transpShape = readShape;
   applyPermutationToVector<int64_t>(transpShape, perm);
 
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  // If there's a tile with a scalable size, retrieve its size. ATM only 1
+  // scalable tile is allowed.
+  Value scalableSize;
+  for (auto tile : packOp.getMixedTiles()) {
+    if (tile.is<Value>()) {
+      assert(!scalableSize && "Only one scalable size is supported ATM.");
+      scalableSize = cast<Value>(tile);
+      assert(vector::getConstantVscaleMultiplier(scalableSize) &&
+             "This dynamic shape is not a multiple of vscale!");
+    }
+  }
+
+  Value empty =
+      ShapedType::isDynamicShape(cast<ShapedType>(input.getType()).getShape())
+          ? rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType,
+                                             scalableSize)
+          : rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
----------------
hanhanW wrote:

Oh I see, I misread the code. The unit dims are not inserted into `readShape`. Perhaps we can add a `SmallVector<OpFoldResult> transpShape` at l.1190, insert the `OpFoldResult` values into that vector, and use it here.
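
To illustrate the idea, here is a rough standalone sketch rather than the actual patch. It models a mixed static/dynamic shape with `std::variant` instead of the real `OpFoldResult`, so it compiles without MLIR. The names `FoldResult`, `permuteShape`, `splitShape`, and `kDynamicSentinel` are hypothetical, and the permutation convention (`result[i] = input[perm[i]]`) is an assumption about how `applyPermutationToVector` behaves.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-in for OpFoldResult: either a static size (int64_t)
// or a dynamic size (std::string here models an SSA Value).
using FoldResult = std::variant<int64_t, std::string>;

// Hypothetical marker for a dynamic dimension in a static shape.
constexpr int64_t kDynamicSentinel = std::numeric_limits<int64_t>::min();

// Permutes a mixed shape, assuming result[i] = input[perm[i]].
std::vector<FoldResult> permuteShape(const std::vector<FoldResult> &shape,
                                     const std::vector<int64_t> &perm) {
  assert(shape.size() == perm.size() && "permutation rank mismatch");
  std::vector<FoldResult> result;
  result.reserve(shape.size());
  for (int64_t idx : perm)
    result.push_back(shape[idx]);
  return result;
}

// Splits a mixed shape into a static shape (dynamic dims replaced by the
// sentinel) and the list of dynamic sizes, in order of appearance.
std::pair<std::vector<int64_t>, std::vector<std::string>>
splitShape(const std::vector<FoldResult> &shape) {
  std::vector<int64_t> staticShape;
  std::vector<std::string> dynamicSizes;
  for (const FoldResult &dim : shape) {
    if (const auto *size = std::get_if<int64_t>(&dim)) {
      staticShape.push_back(*size);
    } else {
      staticShape.push_back(kDynamicSentinel);
      dynamicSizes.push_back(std::get<std::string>(dim));
    }
  }
  return {staticShape, dynamicSizes};
}
```

With the shape kept as mixed values from the start, the dynamic (scalable) size travels through the permutation together with the static sizes, so the `tensor.empty` could be built from a single vector without the separate `scalableSize` lookup or the `isDynamicShape` branch.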

https://github.com/llvm/llvm-project/pull/109815

