[Mlir-commits] [mlir] [mlir][linalg] Update vectorization of linalg.pack (PR #163539)
Han-Chung Wang
llvmlistbot at llvm.org
Thu Oct 23 11:31:40 PDT 2025
================
@@ -1901,12 +1801,120 @@ static VectorType getCollapsedVecType(VectorType type,
return VectorType::get(newShape, type.getElementType(), newScalableFlags);
}
+/// Vectorize `linalg.pack` as:
+/// * xfer_read -> shape_cast -> transpose -> xfer_write
+///
+/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
+/// sizes for the xfer_write operation). This is sufficient to infer the other
+/// vector sizes required here.
+///
+/// If the vector sizes are not provided:
+/// * the vector sizes are determined from the destination tensor static shape.
+/// * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+/// %pack = linalg.pack %src
+/// inner_dims_pos = [2, 1]
+/// inner_tiles = [16, 2]
+/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+/// ```
+/// is vectorized as:
+/// ```
+/// %read = vector.transfer_read %src
+/// : tensor<32x8x16xf32>, vector<32x8x16xf32>
+/// %sc = vector.shape_cast %read
+/// : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+/// %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
+/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+/// %write = vector.transfer_write %tr into %dest
+/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+/// ```
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ if (!inputVectorSizes.empty()) {
+ assert(inputVectorSizes.size() == packOp.getDestRank() &&
+ "Invalid number of input vector sizes!");
+ }
+
+ // TODO: Introduce a parent class that will handle the insertion point update.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(packOp);
+
+ Location loc = packOp.getLoc();
+ std::optional<Value> padValue = packOp.getPaddingValue()
+ ? std::optional(packOp.getPaddingValue())
+ : std::nullopt;
+
+ SmallVector<int64_t> destShape =
+ SmallVector<int64_t>(packOp.getDestType().getShape());
+
+ // This is just a convenience alias to clearly communicate that the input
+ // vector sizes determine the _write_ sizes.
+ ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
+
+ // In the absence of input-vector-sizes, use the _static_ destination tensor
+ // shape. In addition, use the inBounds attribute instead of masking.
+ bool useInBoundsInsteadOfMasking = false;
+ if (writeVectorSizes.empty()) {
+ if (ShapedType::isDynamicShape(destShape))
+ return rewriter.notifyMatchFailure(packOp,
+ "Unable to infer vector sizes!");
----------------
hanhanW wrote:
We usually start the first sentence with a lowercase letter, and finish the last sentence without a period or exclamation mark.
https://llvm.org/docs/CodingStandards.html#error-and-warning-messages
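Applied here, the message would read something like "unable to infer vector sizes". As a rough illustration of the rule, here is a minimal standalone sketch. `followsDiagnosticStyle` is a hypothetical helper written only for this comment; it is not part of LLVM or MLIR, and it checks only the two points above (first character and trailing punctuation).

```cpp
#include <cctype>
#include <string_view>

// Hypothetical helper, not an LLVM/MLIR API: returns true when `msg` does not
// start with an uppercase letter and does not end with '.' or '!'.
// Empty messages are rejected as well (an assumption made for this sketch).
bool followsDiagnosticStyle(std::string_view msg) {
  if (msg.empty())
    return false;
  const unsigned char first = static_cast<unsigned char>(msg.front());
  const char last = msg.back();
  return !std::isupper(first) && last != '.' && last != '!';
}
```

With this check, "Unable to infer vector sizes!" is rejected and "unable to infer vector sizes" is accepted.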
https://github.com/llvm/llvm-project/pull/163539