[Mlir-commits] [mlir] [mlir][linalg] Add support for scalable vectorization of linalg.mmt4d (PR #146531)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Jul 17 05:32:05 PDT 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/146531

>From 487db47dbae6de69f174868b625fb7730612c09a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 1 Jul 2025 13:39:46 +0000
Subject: [PATCH 1/5] [mlir][linalg] Add support for scalable vectorization of
 linalg.mmt4d

This patch introduces support for scalable vectorization of `linalg.mmt4d`.
The key design addition is a new state variable in the Linalg vectorizer:

  * `assumeScalableVecSizesMatchDimSize`

This flag informs the vectorizer that the memref/tensor dimensions
corresponding to scalable vector sizes (typically dynamic) _match the
vector sizes_ at runtime.

While this assumption is not generally valid, it does hold for
`linalg.mmt4d` because inputs and outputs are explicitly packed (via
`linalg.pack`). Packing includes padding, which ensures that dimension
sizes align with the scalable vector lengths (*).

See discussion here:
* https://github.com/llvm/llvm-project/issues/143920

(*) Provided that the tile sizes used for packing match the vector sizes used
during vectorization. Enforcing this is left to the user.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  11 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |   3 +-
 .../TransformOps/LinalgTransformOps.cpp       |   3 +-
 .../Linalg/Transforms/Vectorization.cpp       |  79 ++++++++----
 .../Linalg/vectorization/linalg-ops.mlir      | 117 ++++++++++++++----
 5 files changed, 157 insertions(+), 56 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index d64f94a49f781..baa17f75e53b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2440,12 +2440,11 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
-                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
-                          $static_vector_sizes,
-                       OptionalAttr<UnitAttr>:$vectorize_nd_extract,
-                       DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
-                          $scalable_sizes);
+      Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
+      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
+      OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+      OptionalAttr<UnitAttr>:$assume_scalable_sizes_match_dim_size,
+      DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
 
   let results = (outs);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2b4855f49695c..a6d697b43c0b7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -871,7 +871,8 @@ FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+          bool assumeScalableSizesMultipleOfDim = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8571d641e26d1..49b9a41831fc6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3921,7 +3921,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
     }
     FailureOr<VectorizationResult> vectorResults =
         linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                          getVectorizeNdExtract().value_or(false));
+                          getVectorizeNdExtract().value_or(false), false,
+                          getAssumeScalableSizesMatchDimSize().value_or(false));
     if (failed(vectorResults)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b467114c72f7d..3a533322a3c7f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -222,9 +222,11 @@ struct VectorizationState {
   /// canonical vector shape for vectorization.
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
-                          ArrayRef<bool> inputScalableVecDims);
+                          ArrayRef<bool> inputScalableVecDims,
+                          bool assumeScalableVecSizesMatchDimSize = false);
 
-  /// Returns the canonical vector shape used to vectorize the iteration space.
+  /// Returns the canonical vector shape used to vectorize the iteration
+  /// space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
 
   /// Returns the vector dimensions that are scalable in the canonical vector
@@ -233,8 +235,8 @@ struct VectorizationState {
 
   /// Returns a vector type of the provided `elementType` with the canonical
   /// vector shape and the corresponding fixed/scalable dimensions bit. If
-  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
-  /// accordingly.
+  /// `dimPermutation` is provided, the canonical vector dimensions are
+  /// permuted accordingly.
   VectorType getCanonicalVecType(
       Type elementType,
       std::optional<AffineMap> dimPermutation = std::nullopt) const {
@@ -254,9 +256,9 @@ struct VectorizationState {
   }
 
   /// Masks an operation with the canonical vector mask if the operation needs
-  /// masking. Returns the masked operation or the original operation if masking
-  /// is not needed. If provided, the canonical mask for this operation is
-  /// permuted using `maybeIndexingMap`.
+  /// masking. Returns the masked operation or the original operation if
+  /// masking is not needed. If provided, the canonical mask for this
+  /// operation is permuted using `maybeIndexingMap`.
   Operation *
   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
@@ -276,15 +278,15 @@ struct VectorizationState {
 
   /// Create or retrieve an existing mask value to mask `opToMask` in the
   /// canonical vector iteration space. If `maybeMaskingMap` the mask is
-  /// permuted using that permutation map. If a new mask is created, it will be
-  /// cached for future users.
+  /// permuted using that permutation map. If a new mask is created, it will
+  /// be cached for future users.
   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                            LinalgOp linalgOp,
                            std::optional<AffineMap> maybeMaskingMap);
 
   /// Check whether this permutation map can be used for masking. At the
-  /// moment we only make sure that there are no broadcast dimensions, but this
-  /// might change if indexing maps evolve.
+  /// moment we only make sure that there are no broadcast dimensions, but
+  /// this might change if indexing maps evolve.
   bool isValidMaskingMap(AffineMap maskingMap) {
     return maskingMap.getBroadcastDims().size() == 0;
   }
@@ -324,13 +326,24 @@ struct VectorizationState {
   /// shape.
   SmallVector<bool> scalableVecDims;
 
-  /// Holds the active masks for permutations of the canonical vector iteration
-  /// space.
+  /// Holds the active masks for permutations of the canonical vector
+  /// iteration space.
   DenseMap<AffineMap, Value> activeMaskCache;
 
   /// Global vectorization guard for the incoming rewriter. It's initialized
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
+
+  /// Do all scalable vector sizes match the corresponding input dim sizes?
+  /// (tensor or memref)
+  ///
+  /// At the Tensor + MemRef levels, scalable sizes are modelled using
+  /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
+  /// e.g. "scalable packing + tiling" and are known to always match the
+  /// scalable vector sizes. In such cases, masking can be safely skipped,
+  /// despite the presence of dynamic shapes. Use this flag with care and
+  /// only for cases where you are confident the assumption holds.
+  bool assumeScalableVecSizesMatchDimSize = false;
 };
 
 LogicalResult
@@ -367,10 +380,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
 /// Initializes the vectorization state, including the computation of the
 /// canonical vector shape for vectorization.
 // TODO: Move this to the constructor when we can remove the failure cases.
-LogicalResult
-VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              ArrayRef<bool> inputScalableVecDims) {
+LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+                                            LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            ArrayRef<bool> inputScalableVecDims,
+                                            bool assumeScalableSizes) {
+  assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
 
@@ -470,6 +485,21 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
+  if (assumeScalableVecSizesMatchDimSize) {
+    // Given that all _scalable vector sizes_ match the corresponding
+    // memref/tensor dim sizes, masking can be skipped provided that:
+    // * all vector sizes corresponding to dynamic dims are scalable.
+    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
+                     [](auto it) {
+                       return std::get<0>(it) == ShapedType::kDynamic
+                                  ? std::get<1>(it)
+                                  : false;
+                     }))
+      LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+    activeMaskCache[maskingMap] = Value();
+    return Value();
+  }
+
   // Permute the iteration space value sizes to compute the mask upper bounds.
   SmallVector<Value> upperBounds =
       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
@@ -2479,7 +2509,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+                 hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2535,11 +2566,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
              tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool assumeScalableSizesMultipleOfDim) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2559,7 +2589,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   VectorizationState state(rewriter);
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                               inputScalableVecDims))) {
+                               inputScalableVecDims,
+                               assumeScalableSizesMultipleOfDim))) {
       LDBG("Vectorization state couldn't be initialized\n");
       return failure();
     }
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 6722de817f6bf..188f03069938f 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -840,6 +840,99 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.mmt4d
+///----------------------------------------------------------------------------------------
+
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+               outs(%C_in: memref<16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @mmt4d(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL:   func.func @mmt4d_scalable(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME:      %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 16 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
+// CHECK:           %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
+// CHECK:           %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
+// CHECK:           %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
+// CHECK:           vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL:   func.func @mmt4d_scalable_with_assume(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME:      %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK-NOT:       mask
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
+// CHECK:           %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
+// CHECK:           vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_scalable_sizes_match_dim_size} : !transform.any_op
+    transform.yield
+  }
+}
+
 ///----------------------------------------------------------------------------------------
 /// Tests for other Ops
 ///----------------------------------------------------------------------------------------
@@ -1094,30 +1187,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// -----
-
-func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
-  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
-               outs(%C_in: memref<16x16x8x8xf32>)
-  return
-}
-
-// CHECK-LABEL:   func.func @mmt4d(
-// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
-// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
-// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
-// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %mmt4d : !transform.any_op
-    transform.yield
-  }
-}
 
 // -----
 

>From 8ef66618eeb45c2ad8a435f47715d0cf0b15fc86 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 10 Jul 2025 16:44:41 +0000
Subject: [PATCH 2/5] fixup! [mlir][linalg] Add support for scalable
 vectorization of linalg.mmt4d

Revert changes in comments
---
 .../Linalg/Transforms/Vectorization.cpp       | 25 +++++++++----------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3a533322a3c7f..38bf37d844be4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -225,8 +225,7 @@ struct VectorizationState {
                           ArrayRef<bool> inputScalableVecDims,
                           bool assumeScalableVecSizesMatchDimSize = false);
 
-  /// Returns the canonical vector shape used to vectorize the iteration
-  /// space.
+  /// Returns the canonical vector shape used to vectorize the iteration space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
 
   /// Returns the vector dimensions that are scalable in the canonical vector
@@ -235,8 +234,8 @@ struct VectorizationState {
 
   /// Returns a vector type of the provided `elementType` with the canonical
   /// vector shape and the corresponding fixed/scalable dimensions bit. If
-  /// `dimPermutation` is provided, the canonical vector dimensions are
-  /// permuted accordingly.
+  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
+  /// accordingly.
   VectorType getCanonicalVecType(
       Type elementType,
       std::optional<AffineMap> dimPermutation = std::nullopt) const {
@@ -256,9 +255,9 @@ struct VectorizationState {
   }
 
   /// Masks an operation with the canonical vector mask if the operation needs
-  /// masking. Returns the masked operation or the original operation if
-  /// masking is not needed. If provided, the canonical mask for this
-  /// operation is permuted using `maybeIndexingMap`.
+  /// masking. Returns the masked operation or the original operation if masking
+  /// is not needed. If provided, the canonical mask for this operation is
+  /// permuted using `maybeIndexingMap`.
   Operation *
   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
@@ -278,15 +277,15 @@ struct VectorizationState {
 
   /// Create or retrieve an existing mask value to mask `opToMask` in the
   /// canonical vector iteration space. If `maybeMaskingMap` the mask is
-  /// permuted using that permutation map. If a new mask is created, it will
-  /// be cached for future users.
+  /// permuted using that permutation map. If a new mask is created, it will be
+  /// cached for future users.
   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                            LinalgOp linalgOp,
                            std::optional<AffineMap> maybeMaskingMap);
 
   /// Check whether this permutation map can be used for masking. At the
-  /// moment we only make sure that there are no broadcast dimensions, but
-  /// this might change if indexing maps evolve.
+  /// moment we only make sure that there are no broadcast dimensions, but this
+  /// might change if indexing maps evolve.
   bool isValidMaskingMap(AffineMap maskingMap) {
     return maskingMap.getBroadcastDims().size() == 0;
   }
@@ -326,8 +325,8 @@ struct VectorizationState {
   /// shape.
   SmallVector<bool> scalableVecDims;
 
-  /// Holds the active masks for permutations of the canonical vector
-  /// iteration space.
+  /// Holds the active masks for permutations of the canonical vector iteration
+  /// space.
   DenseMap<AffineMap, Value> activeMaskCache;
 
   /// Global vectorization guard for the incoming rewriter. It's initialized

>From 2b6019caf04f37715ca5857b31b3f4fe9c48a417 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 10 Jul 2025 17:23:24 +0000
Subject: [PATCH 3/5] fixup! [mlir][linalg] Add support for scalable
 vectorization of linalg.mmt4d

Rename the bool to assumeDynamicDimsMatchVecSizes
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  2 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |  2 +-
 .../TransformOps/LinalgTransformOps.cpp       |  2 +-
 .../Linalg/Transforms/Vectorization.cpp       | 27 +++++++++----------
 4 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index baa17f75e53b6..472df21cb464e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2443,7 +2443,7 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
       Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
       OptionalAttr<UnitAttr>:$vectorize_nd_extract,
-      OptionalAttr<UnitAttr>:$assume_scalable_sizes_match_dim_size,
+      OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
       DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
 
   let results = (outs);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a6d697b43c0b7..8ba4f8f218721 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -872,7 +872,7 @@ vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
           bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
-          bool assumeScalableSizesMultipleOfDim = false);
+          bool assumeDynamicDimsMatchVecSizes = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 49b9a41831fc6..7e1911a56693f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3922,7 +3922,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
     FailureOr<VectorizationResult> vectorResults =
         linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
                           getVectorizeNdExtract().value_or(false), false,
-                          getAssumeScalableSizesMatchDimSize().value_or(false));
+                          getAssumeDynamicDimsMatchVecSizes().value_or(false));
     if (failed(vectorResults)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 38bf37d844be4..11ba3eebc128a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -223,7 +223,7 @@ struct VectorizationState {
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
                           ArrayRef<bool> inputScalableVecDims,
-                          bool assumeScalableVecSizesMatchDimSize = false);
+                          bool assumeDynamicDimsMatchVecSizes = false);
 
   /// Returns the canonical vector shape used to vectorize the iteration space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
@@ -333,16 +333,13 @@ struct VectorizationState {
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
 
-  /// Do all scalable vector sizes match the corresponding input dim sizes?
-  /// (tensor or memref)
+  /// Do all dynamic dims match the corresponding vector sizes?
   ///
-  /// At the Tensor + MemRef levels, scalable sizes are modelled using
-  /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
-  /// e.g. "scalable packing + tiling" and are known to always match the
-  /// scalable vector sizes. In such cases, masking can be safely skipped,
-  /// despite the presence of dynamic shapes. Use this flag with care and
-  /// only for cases where you are confident the assumption holds.
-  bool assumeScalableVecSizesMatchDimSize = false;
+  /// When a dynamic tensor/memref dimension matches the corresponding vector
+  /// dimension, masking can be safely skipped, despite the presence of dynamic
+  /// shapes. Use this flag with care and only for cases where you are
+  /// confident the assumption holds.
+  bool assumeDynamicDimsMatchVecSizes = false;
 };
 
 LogicalResult
@@ -383,8 +380,8 @@ LogicalResult VectorizationState::initState(RewriterBase &rewriter,
                                             LinalgOp linalgOp,
                                             ArrayRef<int64_t> inputVectorSizes,
                                             ArrayRef<bool> inputScalableVecDims,
-                                            bool assumeScalableSizes) {
-  assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
+                                            bool assumeDimsMatchVec) {
+  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
 
@@ -484,7 +481,7 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
-  if (assumeScalableVecSizesMatchDimSize) {
+  if (assumeDynamicDimsMatchVecSizes) {
     // Given that all _scalable vector sizes_ match the corresponding
     // memref/tensor dim sizes, masking can be skipped provided that:
     // * all vector sizes corresponding to dynamic dims are scalable.
@@ -2568,7 +2565,7 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
 FailureOr<VectorizationResult> mlir::linalg::vectorize(
     RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
     ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
-    bool flatten1DDepthwiseConv, bool assumeScalableSizesMultipleOfDim) {
+    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2589,7 +2586,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                                inputScalableVecDims,
-                               assumeScalableSizesMultipleOfDim))) {
+                               assumeDynamicDimsMatchVecSizes))) {
       LDBG("Vectorization state couldn't be initialized\n");
       return failure();
     }

>From 1ec6935800afaa3527746297bdc8b8485459e534 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 15 Jul 2025 09:58:39 +0000
Subject: [PATCH 4/5] fixup! fixup! [mlir][linalg] Add support for scalable
 vectorization of linalg.mmt4d

Fix the condition that checks whether masks are needed, fix test
---
 .../Linalg/Transforms/Vectorization.cpp       | 19 ++++++++++---------
 .../Linalg/vectorization/linalg-ops.mlir      |  2 +-
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 11ba3eebc128a..b2f7caeb84a18 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -482,18 +482,19 @@ Value VectorizationState::getOrCreateMaskFor(
   }
 
   if (assumeDynamicDimsMatchVecSizes) {
-    // Given that all _scalable vector sizes_ match the corresponding
-    // memref/tensor dim sizes, masking can be skipped provided that:
-    // * all vector sizes corresponding to dynamic dims are scalable.
-    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
+    // While we can _assume_ that for dynamic dim sizes the corresponding
+    // vector sizes match, we still need to check the static dim sizes to be
+    // 100% sure that masking is indeed not required.
+    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
                      [](auto it) {
                        return std::get<0>(it) == ShapedType::kDynamic
-                                  ? std::get<1>(it)
-                                  : false;
-                     }))
+                                  ? true
+                                  : std::get<0>(it) == std::get<1>(it);
+                     })) {
       LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
-    activeMaskCache[maskingMap] = Value();
-    return Value();
+      activeMaskCache[maskingMap] = Value();
+      return Value();
+    }
   }
 
   // Permute the iteration space value sizes to compute the mask upper bounds.
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 188f03069938f..c5e6cce6f125b 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -928,7 +928,7 @@ func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x1
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_scalable_sizes_match_dim_size} : !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
     transform.yield
   }
 }

>From 4ca2b53de6d0382b0a4d5fd7ec356c140528d031 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 17 Jul 2025 12:31:51 +0000
Subject: [PATCH 5/5] fixup! fixup! fixup! [mlir][linalg] Add support for
 scalable vectorization of linalg.mmt4d

Refine comment
---
 mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b2f7caeb84a18..a30cf57352b74 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -482,9 +482,9 @@ Value VectorizationState::getOrCreateMaskFor(
   }
 
   if (assumeDynamicDimsMatchVecSizes) {
-    // While we can _assume_ that for dynamic dim sizes the corresponding
-    // vector sizes match, we still need to check the static dim sizes to be
-    // 100% sure that masking is indeed not required.
+    // While for _dynamic_ dim sizes we can _assume_ that the corresponding
+    // vector sizes match, we still need to check the _static_ dim sizes. Only
+    // then can we be 100% sure that masking is not required.
     if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
                      [](auto it) {
                        return std::get<0>(it) == ShapedType::kDynamic



More information about the Mlir-commits mailing list