[Mlir-commits] [mlir] [mlir][linalg] Add support for scalable vectorization of linalg.mmt4d (PR #146531)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Jul 17 05:32:05 PDT 2025
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/146531
From 487db47dbae6de69f174868b625fb7730612c09a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 1 Jul 2025 13:39:46 +0000
Subject: [PATCH 1/5] [mlir][linalg] Add support for scalable vectorization of
linalg.mmt4d
This patch introduces support for scalable vectorization of `linalg.mmt4d`.
The key design addition is a new state variable in the Linalg vectorizer:
* `assumeScalableVecSizesMatchDimSize`
This flag informs the vectorizer that the memref/tensor dimensions
corresponding to scalable vector sizes (typically dynamic) _match the
vector sizes_ at runtime.
While this assumption is not generally valid, it does hold for
`linalg.mmt4d` because inputs and outputs are explicitly packed (via
`linalg.pack`). Packing includes padding, which ensures that dimension
sizes align with the scalable vector lengths (*).
See discussion here:
* https://github.com/llvm/llvm-project/issues/143920
(*) Provided that the tile sizes used for packing match the vector sizes used
during vectorization. Enforcing this is left to the user.
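For reference, here is a minimal end-to-end sketch of the intended usage. It
mirrors the tests added in this patch (the vector sizes and the new unit
attribute — renamed to `assume_dynamic_dims_match_vec_sizes` in a later fixup —
are taken verbatim from those tests); the assumption that the dynamic "n" dims
of %B and %C_in equal 4 * vscale at runtime must be guaranteed by the earlier
packing step:

  func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>,
                   %C_in: memref<16x16x8x?xf32>) {
    // The `?` dims are assumed to match the scalable vector size [4].
    linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
                 outs(%C_in: memref<16x16x8x?xf32>)
    return
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
      // Vector sizes are (M, N, K, m, n, k); only "n" ([4]) is scalable.
      transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1]
        {assume_scalable_sizes_match_dim_size} : !transform.any_op
      transform.yield
    }
  }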
---
.../Linalg/TransformOps/LinalgTransformOps.td | 11 +-
.../Dialect/Linalg/Transforms/Transforms.h | 3 +-
.../TransformOps/LinalgTransformOps.cpp | 3 +-
.../Linalg/Transforms/Vectorization.cpp | 79 ++++++++----
.../Linalg/vectorization/linalg-ops.mlir | 117 ++++++++++++++----
5 files changed, 157 insertions(+), 56 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index d64f94a49f781..baa17f75e53b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2440,12 +2440,11 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
- $static_vector_sizes,
- OptionalAttr<UnitAttr>:$vectorize_nd_extract,
- DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
- $scalable_sizes);
+ Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
+ OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+ OptionalAttr<UnitAttr>:$assume_scalable_sizes_match_dim_size,
+ DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
let results = (outs);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2b4855f49695c..a6d697b43c0b7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -871,7 +871,8 @@ FailureOr<VectorizationResult>
vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
- bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+ bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+ bool assumeScalableSizesMultipleOfDim = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8571d641e26d1..49b9a41831fc6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3921,7 +3921,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
}
FailureOr<VectorizationResult> vectorResults =
linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
- getVectorizeNdExtract().value_or(false));
+ getVectorizeNdExtract().value_or(false), false,
+ getAssumeScalableSizesMatchDimSize().value_or(false));
if (failed(vectorResults)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b467114c72f7d..3a533322a3c7f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -222,9 +222,11 @@ struct VectorizationState {
/// canonical vector shape for vectorization.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes,
- ArrayRef<bool> inputScalableVecDims);
+ ArrayRef<bool> inputScalableVecDims,
+ bool assumeScalableVecSizesMatchDimSize = false);
- /// Returns the canonical vector shape used to vectorize the iteration space.
+ /// Returns the canonical vector shape used to vectorize the iteration
+ /// space.
ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
/// Returns the vector dimensions that are scalable in the canonical vector
@@ -233,8 +235,8 @@ struct VectorizationState {
/// Returns a vector type of the provided `elementType` with the canonical
/// vector shape and the corresponding fixed/scalable dimensions bit. If
- /// `dimPermutation` is provided, the canonical vector dimensions are permuted
- /// accordingly.
+ /// `dimPermutation` is provided, the canonical vector dimensions are
+ /// permuted accordingly.
VectorType getCanonicalVecType(
Type elementType,
std::optional<AffineMap> dimPermutation = std::nullopt) const {
@@ -254,9 +256,9 @@ struct VectorizationState {
}
/// Masks an operation with the canonical vector mask if the operation needs
- /// masking. Returns the masked operation or the original operation if masking
- /// is not needed. If provided, the canonical mask for this operation is
- /// permuted using `maybeIndexingMap`.
+ /// masking. Returns the masked operation or the original operation if
+ /// masking is not needed. If provided, the canonical mask for this
+ /// operation is permuted using `maybeIndexingMap`.
Operation *
maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
std::optional<AffineMap> maybeIndexingMap = std::nullopt);
@@ -276,15 +278,15 @@ struct VectorizationState {
/// Create or retrieve an existing mask value to mask `opToMask` in the
/// canonical vector iteration space. If `maybeMaskingMap` the mask is
- /// permuted using that permutation map. If a new mask is created, it will be
- /// cached for future users.
+ /// permuted using that permutation map. If a new mask is created, it will
+ /// be cached for future users.
Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
LinalgOp linalgOp,
std::optional<AffineMap> maybeMaskingMap);
/// Check whether this permutation map can be used for masking. At the
- /// moment we only make sure that there are no broadcast dimensions, but this
- /// might change if indexing maps evolve.
+ /// moment we only make sure that there are no broadcast dimensions, but
+ /// this might change if indexing maps evolve.
bool isValidMaskingMap(AffineMap maskingMap) {
return maskingMap.getBroadcastDims().size() == 0;
}
@@ -324,13 +326,24 @@ struct VectorizationState {
/// shape.
SmallVector<bool> scalableVecDims;
- /// Holds the active masks for permutations of the canonical vector iteration
- /// space.
+ /// Holds the active masks for permutations of the canonical vector
+ /// iteration space.
DenseMap<AffineMap, Value> activeMaskCache;
/// Global vectorization guard for the incoming rewriter. It's initialized
/// when the vectorization state is initialized.
OpBuilder::InsertionGuard rewriterGuard;
+
+ /// Do all scalable vector sizes match the corresponding input dim sizes?
+ /// (tensor or memref)
+ ///
+ /// At the Tensor + MemRef levels, scalable sizes are modelled using
+ /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
+ /// e.g. "scalable packing + tiling" and are known to always match the
+ /// scalable vector sizes. In such cases, masking can be safely skipped,
+ /// despite the presence of dynamic shapes. Use this flag with care and
+ /// only for cases where you are confident the assumption holds.
+ bool assumeScalableVecSizesMatchDimSize = false;
};
LogicalResult
@@ -367,10 +380,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
/// Initializes the vectorization state, including the computation of the
/// canonical vector shape for vectorization.
// TODO: Move this to the constructor when we can remove the failure cases.
-LogicalResult
-VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
- ArrayRef<int64_t> inputVectorSizes,
- ArrayRef<bool> inputScalableVecDims) {
+LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+ LinalgOp linalgOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims,
+ bool assumeScalableSizes) {
+ assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
// Initialize the insertion point.
rewriter.setInsertionPoint(linalgOp);
@@ -470,6 +485,21 @@ Value VectorizationState::getOrCreateMaskFor(
return Value();
}
+ if (assumeScalableVecSizesMatchDimSize) {
+ // Given that all _scalable vector sizes_ match the corresponding
+ // memref/tensor dim sizes, masking can be skipped provided that:
+ // * all vector sizes corresponding to dynamic dims are scalable.
+ if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
+ [](auto it) {
+ return std::get<0>(it) == ShapedType::kDynamic
+ ? std::get<1>(it)
+ : false;
+ }))
+ LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+ activeMaskCache[maskingMap] = Value();
+ return Value();
+ }
+
// Permute the iteration space value sizes to compute the mask upper bounds.
SmallVector<Value> upperBounds =
applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
@@ -2479,7 +2509,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
- isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+ isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+ hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2535,11 +2566,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
tensor::InsertSliceOp>(op);
}
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
- ArrayRef<int64_t> inputVectorSizes,
- ArrayRef<bool> inputScalableVecDims,
- bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+ RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+ bool flatten1DDepthwiseConv, bool assumeScalableSizesMultipleOfDim) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2559,7 +2589,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
VectorizationState state(rewriter);
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
- inputScalableVecDims))) {
+ inputScalableVecDims,
+ assumeScalableSizesMultipleOfDim))) {
LDBG("Vectorization state couldn't be initialized\n");
return failure();
}
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 6722de817f6bf..188f03069938f 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -840,6 +840,99 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.mmt4d
+///----------------------------------------------------------------------------------------
+
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+ linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+ outs(%C_in: memref<16x16x8x8xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @mmt4d(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %mmt4d : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+ linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+ outs(%C_in: memref<16x16x8x?xf32>)
+ return
+}
+// CHECK-LABEL: func.func @mmt4d_scalable(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
+// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
+// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
+// CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
+// CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
+// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
+
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+ linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+ outs(%C_in: memref<16x16x8x?xf32>)
+ return
+}
+// CHECK-LABEL: func.func @mmt4d_scalable_with_assume(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK-NOT: mask
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
+// CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
+// CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_scalable_sizes_match_dim_size} : !transform.any_op
+ transform.yield
+ }
+}
+
///----------------------------------------------------------------------------------------
/// Tests for other Ops
///----------------------------------------------------------------------------------------
@@ -1094,30 +1187,6 @@ module attributes {transform.with_named_sequence} {
}
}
-// -----
-
-func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
- linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
- outs(%C_in: memref<16x16x8x8xf32>)
- return
-}
-
-// CHECK-LABEL: func.func @mmt4d(
-// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
-// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
-// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
-// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
-// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %mmt4d : !transform.any_op
- transform.yield
- }
-}
// -----
From 8ef66618eeb45c2ad8a435f47715d0cf0b15fc86 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 10 Jul 2025 16:44:41 +0000
Subject: [PATCH 2/5] fixup! [mlir][linalg] Add support for scalable
vectorization of linalg.mmt4d
Revert changes in comments
---
.../Linalg/Transforms/Vectorization.cpp | 25 +++++++++----------
1 file changed, 12 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3a533322a3c7f..38bf37d844be4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -225,8 +225,7 @@ struct VectorizationState {
ArrayRef<bool> inputScalableVecDims,
bool assumeScalableVecSizesMatchDimSize = false);
- /// Returns the canonical vector shape used to vectorize the iteration
- /// space.
+ /// Returns the canonical vector shape used to vectorize the iteration space.
ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
/// Returns the vector dimensions that are scalable in the canonical vector
@@ -235,8 +234,8 @@ struct VectorizationState {
/// Returns a vector type of the provided `elementType` with the canonical
/// vector shape and the corresponding fixed/scalable dimensions bit. If
- /// `dimPermutation` is provided, the canonical vector dimensions are
- /// permuted accordingly.
+ /// `dimPermutation` is provided, the canonical vector dimensions are permuted
+ /// accordingly.
VectorType getCanonicalVecType(
Type elementType,
std::optional<AffineMap> dimPermutation = std::nullopt) const {
@@ -256,9 +255,9 @@ struct VectorizationState {
}
/// Masks an operation with the canonical vector mask if the operation needs
- /// masking. Returns the masked operation or the original operation if
- /// masking is not needed. If provided, the canonical mask for this
- /// operation is permuted using `maybeIndexingMap`.
+ /// masking. Returns the masked operation or the original operation if masking
+ /// is not needed. If provided, the canonical mask for this operation is
+ /// permuted using `maybeIndexingMap`.
Operation *
maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
std::optional<AffineMap> maybeIndexingMap = std::nullopt);
@@ -278,15 +277,15 @@ struct VectorizationState {
/// Create or retrieve an existing mask value to mask `opToMask` in the
/// canonical vector iteration space. If `maybeMaskingMap` the mask is
- /// permuted using that permutation map. If a new mask is created, it will
- /// be cached for future users.
+ /// permuted using that permutation map. If a new mask is created, it will be
+ /// cached for future users.
Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
LinalgOp linalgOp,
std::optional<AffineMap> maybeMaskingMap);
/// Check whether this permutation map can be used for masking. At the
- /// moment we only make sure that there are no broadcast dimensions, but
- /// this might change if indexing maps evolve.
+ /// moment we only make sure that there are no broadcast dimensions, but this
+ /// might change if indexing maps evolve.
bool isValidMaskingMap(AffineMap maskingMap) {
return maskingMap.getBroadcastDims().size() == 0;
}
@@ -326,8 +325,8 @@ struct VectorizationState {
/// shape.
SmallVector<bool> scalableVecDims;
- /// Holds the active masks for permutations of the canonical vector
- /// iteration space.
+ /// Holds the active masks for permutations of the canonical vector iteration
+ /// space.
DenseMap<AffineMap, Value> activeMaskCache;
/// Global vectorization guard for the incoming rewriter. It's initialized
From 2b6019caf04f37715ca5857b31b3f4fe9c48a417 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 10 Jul 2025 17:23:24 +0000
Subject: [PATCH 3/5] fixup! [mlir][linalg] Add support for scalable
vectorization of linalg.mmt4d
Rename the bool to assumeDynamicDimsMatchVecSizes
---
.../Linalg/TransformOps/LinalgTransformOps.td | 2 +-
.../Dialect/Linalg/Transforms/Transforms.h | 2 +-
.../TransformOps/LinalgTransformOps.cpp | 2 +-
.../Linalg/Transforms/Vectorization.cpp | 27 +++++++++----------
4 files changed, 15 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index baa17f75e53b6..472df21cb464e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2443,7 +2443,7 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
OptionalAttr<UnitAttr>:$vectorize_nd_extract,
- OptionalAttr<UnitAttr>:$assume_scalable_sizes_match_dim_size,
+ OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
let results = (outs);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a6d697b43c0b7..8ba4f8f218721 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -872,7 +872,7 @@ vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
- bool assumeScalableSizesMultipleOfDim = false);
+ bool assumeDynamicDimsMatchVecSizes = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 49b9a41831fc6..7e1911a56693f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3922,7 +3922,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
FailureOr<VectorizationResult> vectorResults =
linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
getVectorizeNdExtract().value_or(false), false,
- getAssumeScalableSizesMatchDimSize().value_or(false));
+ getAssumeDynamicDimsMatchVecSizes().value_or(false));
if (failed(vectorResults)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 38bf37d844be4..11ba3eebc128a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -223,7 +223,7 @@ struct VectorizationState {
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
- bool assumeScalableVecSizesMatchDimSize = false);
+ bool assumeDynamicDimsMatchVecSizes = false);
/// Returns the canonical vector shape used to vectorize the iteration space.
ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
@@ -333,16 +333,13 @@ struct VectorizationState {
/// when the vectorization state is initialized.
OpBuilder::InsertionGuard rewriterGuard;
- /// Do all scalable vector sizes match the corresponding input dim sizes?
- /// (tensor or memref)
+ /// Do all dynamic dims match the corresponding vector sizes?
///
- /// At the Tensor + MemRef levels, scalable sizes are modelled using
- /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
- /// e.g. "scalable packing + tiling" and are known to always match the
- /// scalable vector sizes. In such cases, masking can be safely skipped,
- /// despite the presence of dynamic shapes. Use this flag with care and
- /// only for cases where you are confident the assumption holds.
- bool assumeScalableVecSizesMatchDimSize = false;
+ /// When a dynamic tensor/memref dimension matches the corresponding vector
+ /// dimension, masking can be safely skipped, despite the presence of dynamic
+ /// shapes. Use this flag with care and only for cases where you are
+ /// confident the assumption holds.
+ bool assumeDynamicDimsMatchVecSizes = false;
};
LogicalResult
@@ -383,8 +380,8 @@ LogicalResult VectorizationState::initState(RewriterBase &rewriter,
LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
- bool assumeScalableSizes) {
- assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
+ bool assumeDimsMatchVec) {
+ assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
// Initialize the insertion point.
rewriter.setInsertionPoint(linalgOp);
@@ -484,7 +481,7 @@ Value VectorizationState::getOrCreateMaskFor(
return Value();
}
- if (assumeScalableVecSizesMatchDimSize) {
+ if (assumeDynamicDimsMatchVecSizes) {
// Given that all _scalable vector sizes_ match the corresponding
// memref/tensor dim sizes, masking can be skipped provided that:
// * all vector sizes corresponding to dynamic dims are scalable.
@@ -2568,7 +2565,7 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
FailureOr<VectorizationResult> mlir::linalg::vectorize(
RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
- bool flatten1DDepthwiseConv, bool assumeScalableSizesMultipleOfDim) {
+ bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2589,7 +2586,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
inputScalableVecDims,
- assumeScalableSizesMultipleOfDim))) {
+ assumeDynamicDimsMatchVecSizes))) {
LDBG("Vectorization state couldn't be initialized\n");
return failure();
}
From 1ec6935800afaa3527746297bdc8b8485459e534 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 15 Jul 2025 09:58:39 +0000
Subject: [PATCH 4/5] fixup! fixup! [mlir][linalg] Add support for scalable
vectorization of linalg.mmt4d
Fix the condition that checks whether masks are needed, and fix the test accordingly.
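For illustration, a hypothetical example (not part of the test suite) of why
the static sizes still need checking: with the same vector sizes as in the
tests, [16, 16, 16, 8, [4], 1], shrinking the static M dim from 16 to 8 means
the static sizes no longer match, so masks must still be emitted for %A and
%C_in even when assume_dynamic_dims_match_vec_sizes is set:

  // Hypothetical: the M dim (8) != the vector size (16), so the assume
  // flag alone does not make masking redundant here.
  func.func @mmt4d_static_mismatch(%A: memref<8x16x8x1xf32>,
                                   %B: memref<16x16x?x1xf32>,
                                   %C_in: memref<8x16x8x?xf32>) {
    linalg.mmt4d ins(%A, %B: memref<8x16x8x1xf32>, memref<16x16x?x1xf32>)
                 outs(%C_in: memref<8x16x8x?xf32>)
    return
  }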
---
.../Linalg/Transforms/Vectorization.cpp | 19 ++++++++++---------
.../Linalg/vectorization/linalg-ops.mlir | 2 +-
2 files changed, 11 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 11ba3eebc128a..b2f7caeb84a18 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -482,18 +482,19 @@ Value VectorizationState::getOrCreateMaskFor(
}
if (assumeDynamicDimsMatchVecSizes) {
- // Given that all _scalable vector sizes_ match the corresponding
- // memref/tensor dim sizes, masking can be skipped provided that:
- // * all vector sizes corresponding to dynamic dims are scalable.
- if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
+ // While we can _assume_ that for dynamic dim sizes the corresponding
+ // vector sizes match, we still need to check the static dim sizes to be
+ // 100% sure that masking is indeed not required.
+ if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
[](auto it) {
return std::get<0>(it) == ShapedType::kDynamic
- ? std::get<1>(it)
- : false;
- }))
+ ? true
+ : std::get<0>(it) == std::get<1>(it);
+ })) {
LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
- activeMaskCache[maskingMap] = Value();
- return Value();
+ activeMaskCache[maskingMap] = Value();
+ return Value();
+ }
}
// Permute the iteration space value sizes to compute the mask upper bounds.
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 188f03069938f..c5e6cce6f125b 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -928,7 +928,7 @@ func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x1
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_scalable_sizes_match_dim_size} : !transform.any_op
+ transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
transform.yield
}
}
From 4ca2b53de6d0382b0a4d5fd7ec356c140528d031 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 17 Jul 2025 12:31:51 +0000
Subject: [PATCH 5/5] fixup! fixup! fixup! [mlir][linalg] Add support for
scalable vectorization of linalg.mmt4d
Refine comment
---
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b2f7caeb84a18..a30cf57352b74 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -482,9 +482,9 @@ Value VectorizationState::getOrCreateMaskFor(
}
if (assumeDynamicDimsMatchVecSizes) {
- // While we can _assume_ that for dynamic dim sizes the corresponding
- // vector sizes match, we still need to check the static dim sizes to be
- // 100% sure that masking is indeed not required.
+ // While for _dynamic_ dim sizes we can _assume_ that the corresponding
+ // vector sizes match, we still need to check the _static_ dim sizes. Only
+ // then we can be 100% sure that masking is not required.
if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
[](auto it) {
return std::get<0>(it) == ShapedType::kDynamic