[Mlir-commits] [mlir] [mlir][linalg] Update vectorization of linalg.pack (PR #163539)

Han-Chung Wang llvmlistbot at llvm.org
Thu Oct 23 11:31:40 PDT 2025


================
@@ -1901,12 +1801,120 @@ static VectorType getCollapsedVecType(VectorType type,
   return VectorType::get(newShape, type.getElementType(), newScalableFlags);
 }
 
+/// Vectorize `linalg.pack` as:
+///   * xfer_read -> shape_cast -> transpose -> xfer_write
+///
+/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
+/// sizes for the xfer_write operation). This is sufficient to infer the other
+/// vector sizes required here.
+///
+/// If the vector sizes are not provided:
+///  * the vector sizes are determined from the destination tensor static shape.
+///  * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+///   %pack = tensor.pack %src
+///     inner_dims_pos = [2, 1]
+///     inner_tiles = [16, 2]
+///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+/// ``
+/// is vectorizes as:
+/// ```
+///   %read = vector.transfer_read %src
+///     : tensor<32x7x16xf32>, vector<32x8x16xf32>
+///   %sc = vector.shape_cast %read
+///     : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+///   %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
+///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+///   %write = vector.transfer_write %tr into %dest
+///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+/// ```
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
+                        ArrayRef<int64_t> inputVectorSizes,
+                        SmallVectorImpl<Value> &newResults) {
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == packOp.getDestRank() &&
+           "Invalid number of input vector sizes!");
+  }
+
+  // TODO: Introduce a parent class that will handle the insertion point update.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  Location loc = packOp.getLoc();
+  std::optional<Value> padValue = packOp.getPaddingValue()
+                                      ? std::optional(packOp.getPaddingValue())
+                                      : std::nullopt;
+
+  SmallVector<int64_t> destShape =
+      SmallVector<int64_t>(packOp.getDestType().getShape());
+
+  // This is just a convenience alias to clearly communicate that the input
+  // vector sizes determine the _write_ sizes.
+  ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
+
+  // In the absence of input-vector-sizes, use the _static_ input tensor shape.
+  // In addition, use the inBounds attribute instead of masking.
+  bool useInBoundsInsteadOfMasking = false;
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(destShape))
+      return rewriter.notifyMatchFailure(packOp,
+                                         "Unable to infer vector sizes!");
----------------
hanhanW wrote:

We usually start the first sentence with a lowercase letter and finish the last sentence without a period or exclamation mark.

https://llvm.org/docs/CodingStandards.html#error-and-warning-messages
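
Applied to the hunk above, the convention would turn "Unable to infer vector sizes!" into something like "unable to infer vector sizes". As a rough illustration only, the rule can be written as a small standalone check. The helper name `followsDiagnosticStyle` is hypothetical and is not part of LLVM or MLIR; it covers only the two points mentioned here (the first character and the trailing punctuation).

```cpp
#include <cctype>
#include <string_view>

// Hypothetical helper, not an LLVM/MLIR API: returns true if `msg` does not
// start with an uppercase letter and does not end in '.' or '!'.
// Simplification: a message that legitimately starts with a capitalized
// identifier or acronym would be rejected by this check.
bool followsDiagnosticStyle(std::string_view msg) {
  if (msg.empty())
    return false;
  // Cast to unsigned char before calling <cctype> functions to avoid
  // undefined behavior on negative char values.
  unsigned char first = static_cast<unsigned char>(msg.front());
  if (std::isupper(first))
    return false;
  char last = msg.back();
  return last != '.' && last != '!';
}
```

With this sketch, `followsDiagnosticStyle("Unable to infer vector sizes!")` is false and `followsDiagnosticStyle("unable to infer vector sizes")` is true.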

https://github.com/llvm/llvm-project/pull/163539

