[Mlir-commits] [mlir] [mlir][linalg] Add support for scalable vectorization of linalg.mmt4d (PR #146531)

llvmlistbot at llvm.org
Tue Jul 1 06:54:37 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Andrzej WarzyƄski (banach-space)

Changes:

This patch introduces support for scalable vectorization of `linalg.mmt4d`.
The key design addition is a new state variable in the Linalg vectorizer:

  * `assumeScalableVecSizesMatchDimSize`

This flag informs the vectorizer that the memref/tensor dimensions
corresponding to scalable vector sizes (typically dynamic) _match the
vector sizes_ at runtime.
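
For example (mirroring the tests added below), the assumption is enabled via
a new unit attribute on `transform.structured.vectorize`:

```mlir
transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1]
    {assume_scalable_sizes_match_dim_size} : !transform.any_op
```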

While this assumption is not generally valid, it does hold for
`linalg.mmt4d` because inputs and outputs are explicitly packed (via
`linalg.pack`). Packing includes padding, which ensures that dimension
sizes align with the scalable vector lengths (*).

See discussion here:
* https://github.com/llvm/llvm-project/issues/143920

(*) Provided that the tile sizes used for packing match the vector sizes used
during vectorization. Enforcing this is left to the user.
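
As an illustration of (*), here is a minimal, hypothetical packing sketch (not
part of this patch) in which the inner tile size is tied to the runtime vector
length, so the inner (packed) dimension is exactly `4 * vscale` at runtime and
thus matches a scalable vector size of `[4]`:

```mlir
// Hypothetical example: pack with an inner tile of 4 * vscale so that the
// resulting inner dimension (dynamic at compile time) matches a scalable
// vector size of [4] at runtime.
func.func @pack_to_vscale_multiple(%src: tensor<?x?xf32>,
                                   %dest: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  %pad = arith.constant 0.0 : f32
  %vscale = vector.vscale
  %c4 = arith.constant 4 : index
  %tile = arith.muli %c4, %vscale : index
  %packed = linalg.pack %src padding_value(%pad : f32)
      inner_dims_pos = [1] inner_tiles = [%tile]
      into %dest : tensor<?x?xf32> -> tensor<?x?x?xf32>
  return %packed : tensor<?x?x?xf32>
}
```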


---
Full diff: https://github.com/llvm/llvm-project/pull/146531.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+5-6) 
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+2-1) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+2-1) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+55-24) 
- (modified) mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir (+93-24) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index d64f94a49f781..baa17f75e53b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2440,12 +2440,11 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
-                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
-                          $static_vector_sizes,
-                       OptionalAttr<UnitAttr>:$vectorize_nd_extract,
-                       DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
-                          $scalable_sizes);
+      Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
+      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
+      OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+      OptionalAttr<UnitAttr>:$assume_scalable_sizes_match_dim_size,
+      DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
 
   let results = (outs);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2b4855f49695c..a6d697b43c0b7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -871,7 +871,8 @@ FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+          bool assumeScalableSizesMatchDimSize = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8571d641e26d1..49b9a41831fc6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3921,7 +3921,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
     }
     FailureOr<VectorizationResult> vectorResults =
         linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                          getVectorizeNdExtract().value_or(false));
+                          getVectorizeNdExtract().value_or(false), false,
+                          getAssumeScalableSizesMatchDimSize().value_or(false));
     if (failed(vectorResults)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b467114c72f7d..3a533322a3c7f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -222,9 +222,11 @@ struct VectorizationState {
   /// canonical vector shape for vectorization.
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
-                          ArrayRef<bool> inputScalableVecDims);
+                          ArrayRef<bool> inputScalableVecDims,
+                          bool assumeScalableVecSizesMatchDimSize = false);
 
-  /// Returns the canonical vector shape used to vectorize the iteration space.
+  /// Returns the canonical vector shape used to vectorize the iteration
+  /// space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
 
   /// Returns the vector dimensions that are scalable in the canonical vector
@@ -233,8 +235,8 @@ struct VectorizationState {
 
   /// Returns a vector type of the provided `elementType` with the canonical
   /// vector shape and the corresponding fixed/scalable dimensions bit. If
-  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
-  /// accordingly.
+  /// `dimPermutation` is provided, the canonical vector dimensions are
+  /// permuted accordingly.
   VectorType getCanonicalVecType(
       Type elementType,
       std::optional<AffineMap> dimPermutation = std::nullopt) const {
@@ -254,9 +256,9 @@ struct VectorizationState {
   }
 
   /// Masks an operation with the canonical vector mask if the operation needs
-  /// masking. Returns the masked operation or the original operation if masking
-  /// is not needed. If provided, the canonical mask for this operation is
-  /// permuted using `maybeIndexingMap`.
+  /// masking. Returns the masked operation or the original operation if
+  /// masking is not needed. If provided, the canonical mask for this
+  /// operation is permuted using `maybeIndexingMap`.
   Operation *
   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
@@ -276,15 +278,15 @@ struct VectorizationState {
 
   /// Create or retrieve an existing mask value to mask `opToMask` in the
   /// canonical vector iteration space. If `maybeMaskingMap` the mask is
-  /// permuted using that permutation map. If a new mask is created, it will be
-  /// cached for future users.
+  /// permuted using that permutation map. If a new mask is created, it will
+  /// be cached for future users.
   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                            LinalgOp linalgOp,
                            std::optional<AffineMap> maybeMaskingMap);
 
   /// Check whether this permutation map can be used for masking. At the
-  /// moment we only make sure that there are no broadcast dimensions, but this
-  /// might change if indexing maps evolve.
+  /// moment we only make sure that there are no broadcast dimensions, but
+  /// this might change if indexing maps evolve.
   bool isValidMaskingMap(AffineMap maskingMap) {
     return maskingMap.getBroadcastDims().size() == 0;
   }
@@ -324,13 +326,24 @@ struct VectorizationState {
   /// shape.
   SmallVector<bool> scalableVecDims;
 
-  /// Holds the active masks for permutations of the canonical vector iteration
-  /// space.
+  /// Holds the active masks for permutations of the canonical vector
+  /// iteration space.
   DenseMap<AffineMap, Value> activeMaskCache;
 
   /// Global vectorization guard for the incoming rewriter. It's initialized
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
+
+  /// Do all scalable vector sizes match the corresponding input dim sizes?
+  /// (tensor or memref)
+  ///
+  /// At the Tensor + MemRef levels, scalable sizes are modelled using
+  /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
+  /// e.g. "scalable packing + tiling" and are known to always match the
+  /// scalable vector sizes. In such cases, masking can be safely skipped,
+  /// despite the presence of dynamic shapes. Use this flag with care and
+  /// only for cases where you are confident the assumption holds.
+  bool assumeScalableVecSizesMatchDimSize = false;
 };
 
 LogicalResult
@@ -367,10 +380,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
 /// Initializes the vectorization state, including the computation of the
 /// canonical vector shape for vectorization.
 // TODO: Move this to the constructor when we can remove the failure cases.
-LogicalResult
-VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              ArrayRef<bool> inputScalableVecDims) {
+LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+                                            LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            ArrayRef<bool> inputScalableVecDims,
+                                            bool assumeScalableSizes) {
+  assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
 
@@ -470,6 +485,22 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
+  if (assumeScalableVecSizesMatchDimSize) {
+    // Given that all _scalable vector sizes_ match the corresponding
+    // memref/tensor dim sizes, masking can be skipped provided that:
+    // * all vector sizes corresponding to dynamic dims are scalable.
+    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
+                     [](auto it) {
+                       return std::get<0>(it) == ShapedType::kDynamic
+                                  ? std::get<1>(it)
+                                  : true;
+                     })) {
+      LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+      activeMaskCache[maskingMap] = Value();
+      return Value();
+    }
+  }
+
   // Permute the iteration space value sizes to compute the mask upper bounds.
   SmallVector<Value> upperBounds =
       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
@@ -2479,7 +2509,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+                 hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2535,11 +2566,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
              tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool assumeScalableSizesMatchDimSize) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2559,7 +2589,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   VectorizationState state(rewriter);
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                               inputScalableVecDims))) {
+                               inputScalableVecDims,
+                               assumeScalableSizesMatchDimSize))) {
       LDBG("Vectorization state couldn't be initialized\n");
       return failure();
     }
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 6722de817f6bf..188f03069938f 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -840,6 +840,99 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.mmt4d
+///----------------------------------------------------------------------------------------
+
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+               outs(%C_in: memref<16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @mmt4d(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL:   func.func @mmt4d_scalable(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME:      %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 16 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
+// CHECK:           %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
+// CHECK:           %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
+// CHECK:           %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
+// CHECK:           vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL:   func.func @mmt4d_scalable_with_assume(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME:      %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK-NOT:       mask
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
+// CHECK:           %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
+// CHECK:           vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_scalable_sizes_match_dim_size} : !transform.any_op
+    transform.yield
+  }
+}
+
 ///----------------------------------------------------------------------------------------
 /// Tests for other Ops
 ///----------------------------------------------------------------------------------------
@@ -1094,30 +1187,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// -----
-
-func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
-  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
-               outs(%C_in: memref<16x16x8x8xf32>)
-  return
-}
-
-// CHECK-LABEL:   func.func @mmt4d(
-// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
-// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
-// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
-// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %mmt4d : !transform.any_op
-    transform.yield
-  }
-}
 
 // -----
 

``````````



https://github.com/llvm/llvm-project/pull/146531

