[Mlir-commits] [mlir] 3b11aaa - [mlir][linalg] Add support for scalable vectorization of linalg.mmt4d (#146531)

llvmlistbot at llvm.org
Thu Jul 17 11:02:12 PDT 2025


Author: Andrzej Warzyński
Date: 2025-07-17T19:02:08+01:00
New Revision: 3b11aaaf94fe6c7b4ccfd031f952265f706c1b68

URL: https://github.com/llvm/llvm-project/commit/3b11aaaf94fe6c7b4ccfd031f952265f706c1b68
DIFF: https://github.com/llvm/llvm-project/commit/3b11aaaf94fe6c7b4ccfd031f952265f706c1b68.diff

LOG: [mlir][linalg] Add support for scalable vectorization of linalg.mmt4d (#146531)

This patch adds support for scalable vectorization of linalg.mmt4d. The
key design change is the introduction of a new vectorizer state variable:

* `assumeDynamicDimsMatchVecSizes`

...along with the corresponding Transform dialect attribute:

* `assume_dynamic_dims_match_vec_sizes`.

This flag instructs the vectorizer to assume that dynamic memref/tensor
dimensions match the corresponding vector sizes (fixed or scalable). With this
assumption, masking becomes unnecessary, which simplifies the lowering pipeline
significantly.
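
For example, using the Transform dialect (this usage is taken from the tests
added in this patch; %mmt4d is a handle to a linalg.mmt4d op):

  transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1]
      {assume_dynamic_dims_match_vec_sizes} : !transform.any_op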

While this assumption is not universally valid, it typically holds for
`linalg.mmt4d`. Inputs and outputs are explicitly packed using `linalg.pack`,
and this packing includes padding, ensuring that dimension sizes align with
vector sizes (*).

* Related discussion: https://github.com/llvm/llvm-project/issues/143920

An upcoming patch will include an end-to-end test that leverages scalable
vectorization of linalg.mmt4d to demonstrate the newly enabled functionality.
This would not be feasible without the changes introduced here, as it would
otherwise require additional logic to handle complex, but ultimately
redundant, masks.

(*) This holds provided that the tile sizes used for packing match the vector
sizes used during vectorization. It is the user’s responsibility to enforce
this.
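
As a minimal, hypothetical sketch of that correspondence (the SSA names,
types and sizes below are illustrative, not part of this patch): packing with
a runtime inner tile of 4 * vscale yields a dynamic packed dimension that, by
construction, equals the scalable vector size [4] used during vectorization.

  // Illustrative only: a runtime tile size of 4 * vscale.
  %c4 = arith.constant 4 : index
  %vscale = vector.vscale
  %tile = arith.muli %c4, %vscale : index
  // Pack (with padding) so that the inner dims match the vector sizes;
  // the dynamic packed dim equals %tile by construction.
  %packed = linalg.pack %src padding_value(%pad : f32)
      inner_dims_pos = [0, 1] inner_tiles = [%tile, 1]
      into %dest : tensor<?x?xf32> -> tensor<?x?x?x1xf32>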

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b4dde776822a1..bafeca924e4c5 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2431,12 +2431,11 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
-                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
-                          $static_vector_sizes,
-                       OptionalAttr<UnitAttr>:$vectorize_nd_extract,
-                       DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
-                          $scalable_sizes);
+      Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
+      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
+      OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+      OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
+      DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
 
   let results = (outs);
 

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 74280fdd82f4e..9e62d0dcc7890 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -880,7 +880,8 @@ FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+          bool assumeDynamicDimsMatchVecSizes = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d5f9de465561..c959310136319 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3920,7 +3920,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
     }
     FailureOr<VectorizationResult> vectorResults =
         linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                          getVectorizeNdExtract().value_or(false));
+                          getVectorizeNdExtract().value_or(false), false,
+                          getAssumeDynamicDimsMatchVecSizes().value_or(false));
     if (failed(vectorResults)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 458ed543b8216..4add50f4b36e5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -219,7 +219,8 @@ struct VectorizationState {
   /// canonical vector shape for vectorization.
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
-                          ArrayRef<bool> inputScalableVecDims);
+                          ArrayRef<bool> inputScalableVecDims,
+                          bool assumeDynamicDimsMatchVecSizes = false);
 
   /// Returns the canonical vector shape used to vectorize the iteration space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
@@ -328,6 +329,14 @@ struct VectorizationState {
   /// Global vectorization guard for the incoming rewriter. It's initialized
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
+
+  /// Do all dynamic dims match the corresponding vector sizes?
+  ///
+  /// When a dynamic tensor/memref dimension matches the corresponding vector
+  /// dimension, masking can be safely skipped, despite the presence of dynamic
+  /// shapes. Use this flag with care and only for cases where you are
+  /// confident the assumption holds.
+  bool assumeDynamicDimsMatchVecSizes = false;
 };
 
 LogicalResult
@@ -364,10 +373,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
 /// Initializes the vectorization state, including the computation of the
 /// canonical vector shape for vectorization.
 // TODO: Move this to the constructor when we can remove the failure cases.
-LogicalResult
-VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              ArrayRef<bool> inputScalableVecDims) {
+LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+                                            LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            ArrayRef<bool> inputScalableVecDims,
+                                            bool assumeDimsMatchVec) {
+  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
 
@@ -467,6 +478,23 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
+  if (assumeDynamicDimsMatchVecSizes) {
+    // While for _dynamic_ dim sizes we can _assume_ that the corresponding
+    // vector sizes match, we still need to check the _static_ dim sizes. Only
+    // then we can be 100% sure that masking is not required.
+    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
+                     [](auto it) {
+                       return std::get<0>(it) == ShapedType::kDynamic
+                                  ? true
+                                  : std::get<0>(it) == std::get<1>(it);
+                     })) {
+      LDBG("Dynamic + static dimensions match vector sizes, masking is not "
+           "required.\n");
+      activeMaskCache[maskingMap] = Value();
+      return Value();
+    }
+  }
+
   // Permute the iteration space value sizes to compute the mask upper bounds.
   SmallVector<Value> upperBounds =
       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
@@ -2469,7 +2497,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+                 hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2525,11 +2554,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
              tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2549,7 +2577,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   VectorizationState state(rewriter);
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                               inputScalableVecDims))) {
+                               inputScalableVecDims,
+                               assumeDynamicDimsMatchVecSizes))) {
       LDBG("Vectorization state couldn't be initialized\n");
       return failure();
     }

diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 679adf0a52175..4fc39e220f86d 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -840,6 +840,99 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.mmt4d
+///----------------------------------------------------------------------------------------
+
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+               outs(%C_in: memref<16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @mmt4d(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL:   func.func @mmt4d_scalable(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME:      %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 16 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
+// CHECK:           %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
+// CHECK:           %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
+// CHECK:           %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
+// CHECK:           vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL:   func.func @mmt4d_scalable_with_assume(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME:      %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK-NOT:       mask
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
+// CHECK:           %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
+// CHECK:           vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
+    transform.yield
+  }
+}
+
 ///----------------------------------------------------------------------------------------
 /// Tests for other Ops
 ///----------------------------------------------------------------------------------------
@@ -1094,30 +1187,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// -----
-
-func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
-  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
-               outs(%C_in: memref<16x16x8x8xf32>)
-  return
-}
-
-// CHECK-LABEL:   func.func @mmt4d(
-// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
-// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
-// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
-// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %mmt4d : !transform.any_op
-    transform.yield
-  }
-}
 
 // -----
 


        

