[Mlir-commits] [mlir] [mlir][linalg] Vectorize directly to a named contraction (PR #147296)

Adam Siemieniuk llvmlistbot at llvm.org
Mon Jul 21 09:42:03 PDT 2025


https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/147296

From 143c13ca25f4ee26f92d6b3c8dd5ebabc5bd0816 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 3 Jul 2025 13:19:43 +0200
Subject: [PATCH 1/8] [mlir][linalg] Vectorize directly to a named contraction

Extends the linalg vectorizer with a path to lower contraction ops
directly into `vector.contract`.

The direct rewriting preserves high-level op semantics and provides a
more progressive lowering compared to reconstructing the contraction
from a multi-dimensional reduction.
The added lowering focuses on named linalg ops and leverages their
well-defined semantics to avoid complex precondition verification.

The new path is optional and disabled by default to avoid changing
the default vectorizer behavior.
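
For illustration, a static `linalg.matmul` roughly vectorizes along the
new path as follows (shapes are arbitrary; see the added tests for the
exact output):

  %vA = vector.transfer_read %A ... : tensor<8x4xf32>, vector<8x4xf32>
  %vB = vector.transfer_read %B ... : tensor<4x16xf32>, vector<4x16xf32>
  %vC = vector.transfer_read %C ... : tensor<8x16xf32>, vector<8x16xf32>
  %res = vector.contract ... %vA, %vB, %vC
           : vector<8x4xf32>, vector<4x16xf32> into vector<8x16xf32>
  vector.transfer_write %res, %C ... : vector<8x16xf32>, tensor<8x16xf32>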
---
 .../Linalg/TransformOps/LinalgTransformOps.td |   2 +
 .../Dialect/Linalg/Transforms/Transforms.h    |   5 +-
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   |   3 +-
 .../TransformOps/LinalgTransformOps.cpp       |   3 +-
 .../Linalg/Transforms/Vectorization.cpp       | 104 ++++-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp |  13 +-
 .../vectorization/contraction-named.mlir      | 400 ++++++++++++++++++
 7 files changed, 520 insertions(+), 10 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index bafeca924e4c5..9eda469bdf930 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2435,6 +2435,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
       OptionalAttr<UnitAttr>:$vectorize_nd_extract,
       OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
+      OptionalAttr<UnitAttr>:$flatten1D_depthwise_conv,
+      OptionalAttr<UnitAttr>:$create_named_contraction,
       DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
 
   let results = (outs);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 9e62d0dcc7890..38e53648e7c34 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,12 +876,15 @@ struct VectorizationResult {
 /// greater than or equal to their counterpart iteration space sizes, if static.
 /// `inputVectorShapes` also allows the vectorization of operations with dynamic
 /// shapes.
+/// Optionally, `createNamedContraction` can force compatible contractions to be
+/// vectorized directly to a vector.contract operation.
 FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
           bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
-          bool assumeDynamicDimsMatchVecSizes = false);
+          bool assumeDynamicDimsMatchVecSizes = false,
+          bool createNamedContraction = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index cc8421b23a074..9b765d0b8ede6 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -226,7 +226,8 @@ bool isLinearizableVector(VectorType type);
 /// Note: all read offsets are set to 0.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              ArrayRef<int64_t> inputVectorSizes, Value padValue,
-                             bool useInBoundsInsteadOfMasking = false);
+                             bool useInBoundsInsteadOfMasking = false,
+                             ArrayRef<bool> scalableDims = {});
 
 /// Returns success if `inputVectorSizes` is a valid masking configuration for
 /// given `shape`, i.e., it meets:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c959310136319..b23641a1ceb54 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3921,7 +3921,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
     FailureOr<VectorizationResult> vectorResults =
         linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
                           getVectorizeNdExtract().value_or(false), false,
-                          getAssumeDynamicDimsMatchVecSizes().value_or(false));
+                          getAssumeDynamicDimsMatchVecSizes().value_or(false),
+                          getCreateNamedContraction().value_or(false));
     if (failed(vectorResults)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4add50f4b36e5..f2f4330f025f4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -1709,10 +1710,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
     return write;
 
   // Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());
 
   SmallVector<OpFoldResult> destSizes =
-      tensor::getMixedSizes(builder, loc, dest);
+      isa<MemRefType>(dest.getType())
+          ? memref::getMixedSizes(builder, loc, dest)
+          : tensor::getMixedSizes(builder, loc, dest);
   SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                       destSizes.end());
 
@@ -2118,6 +2122,84 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+/// Vectorize a named linalg contraction op into:
+///   vector::TransferReadOp - Reads vectors from the operands
+///   vector::ContractionOp - Performs contraction
+///   vector::TransferWriteOp - Writes the result vector back to the
+///   destination
+/// The operands' shapes are preserved and loaded directly into vectors.
+/// Any further permutations or numerical casting remain within contraction.
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+                             LinalgOp linalgOp,
+                             SmallVectorImpl<Value> &newResults) {
+  Location loc = linalgOp.getLoc();
+  MLIRContext *ctx = linalgOp.getContext();
+
+  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+    return failure();
+
+  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+  Operation *reduceOp = matchLinalgReduction(outOperand);
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  if (!maybeKind)
+    return failure();
+
+  // Check that all dimensions are present in the input operands.
+  // Arbitrary broadcasts are not supported by the vector contraction.
+  // Broadcasts are expected to be materialized before vectorization.
+  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+    return failure();
+
+  // Load operands.
+  SmallVector<Value> vecOperands;
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    // The operand vector shape is computed by mapping the canonical vector
+    // shape to the operand's domain. Further permutations are left as part
+    // of the contraction.
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+        indexingMap.getNumResults(), rewriter.getContext());
+    Type elemType = getElementTypeOrSelf(opOperand.get());
+    VectorType readType =
+        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+    Value read = mlir::vector::createReadOrMaskedRead(
+        rewriter, loc, opOperand.get(), readType.getShape(),
+        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+    vecOperands.push_back(read);
+  }
+
+  // Remap iterators from linalg to vector.
+  SmallVector<Attribute> iterAttrs;
+  auto iterators = linalgOp.getIteratorTypesArray();
+  for (utils::IteratorType iter : iterators) {
+    auto vecIter = iter == utils::IteratorType::parallel
+                       ? vector::IteratorType::parallel
+                       : vector::IteratorType::reduction;
+    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+  }
+
+  // Create contraction.
+  Value contractOp = rewriter.create<vector::ContractionOp>(
+      loc, /*lhs=*/vecOperands[0],
+      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+
+  // Store result.
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, contractOp, outOperand->get());
+
+  // Finalize.
+  if (!write->getResults().empty())
+    newResults.push_back(write->getResult(0));
+
+  return success();
+}
+
 namespace {
 enum class ConvOperationKind { Conv, Pool };
 } // namespace
@@ -2557,7 +2639,8 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
 FailureOr<VectorizationResult> mlir::linalg::vectorize(
     RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
     ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
-    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
+    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
+    bool createNamedContraction) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2604,6 +2687,21 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
               return failure();
             }
 
+            // For simplicity, contraction vectorization is limited to linalg
+            // named ops. Generic op is ignored as not every arbitrary
+            // contraction body can be expressed by a vector.contract.
+            if (createNamedContraction &&
+                isa<ContractionOpInterface>(linalgOp.getOperation())) {
+              // Attempt vectorizing directly into a named contraction.
+              // In case of failure, fall back to the generic path.
+              LogicalResult res = vectorizeAsLinalgContraction(
+                  rewriter, state, linalgOp, results);
+              if (succeeded(res))
+                return success();
+
+              LDBG("Failed to vectorize as a named contraction.\n");
+            }
+
             LDBG("Vectorize generic by broadcasting to the canonical vector "
                  "shape\n");
 
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 7e4984582b373..9b055853fc8b0 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -320,14 +320,16 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      Value source,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      Value padValue,
-                                     bool useInBoundsInsteadOfMasking) {
+                                     bool useInBoundsInsteadOfMasking,
+                                     ArrayRef<bool> scalableDims) {
   assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+  auto vectorType =
+      VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
@@ -352,9 +354,12 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(builder, loc, source);
+      isa<MemRefType>(source.getType())
+          ? memref::getMixedSizes(builder, loc, source)
+          : tensor::getMixedSizes(builder, loc, source);
 
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+  auto maskType =
+      VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
   Value mask =
       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
new file mode 100644
index 0000000000000..1831acf092afb
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
@@ -0,0 +1,400 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+
+func.func @matmul(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul(
+// CHECK-SAME:    %[[A:.*]]: tensor<8x4xf32>, %[[B:.*]]: tensor<4x16xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<8x16xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<8x4xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic(
+// CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x16xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>,
+    %C: memref<?x?xf32>) {
+  linalg.matmul
+    ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+    outs(%C: memref<?x?xf32>)
+  return
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_memref(
+// CHECK-SAME:    %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: memref<?x?xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: memref<?x?xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: memref<?x?xf32>, vector<4x16xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: memref<?x?xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, memref<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_scalable(
+// CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x[16]xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x[16]xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x[16]xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, [16], 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_transpose(%A: tensor<4x8xf32>, %B: tensor<16x4xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+                     affine_map<(m, n, k) -> (n, k)>, // transpose B
+                     affine_map<(m, n, k) -> (m, n)>]
+    ins(%A, %B : tensor<4x8xf32>, tensor<16x4xf32>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_transpose(
+// CHECK-SAME:    %[[A:.*]]: tensor<4x8xf32>, %[[B:.*]]: tensor<16x4xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<8x16xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<4x8xf32>, vector<4x8xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<16x4xf32>, vector<16x4xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic_transpose(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+                     affine_map<(m, n, k) -> (n, k)>, // transpose B
+                     affine_map<(m, n, k) -> (m, n)>]
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_transpose(
+// CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<4x8xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<16x4xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast
+                     affine_map<(m, n, k) -> (k, n)>,
+                     affine_map<(m, n, k) -> (m, n)>]
+    ins(%A, %B : tensor<4xf32>, tensor<4x16xf32>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @negative_matmul_broadcast(
+// CHECK-NOT: vector.contract
+// CHECK: vector.multi_reduction
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_mixed_precision(%A: tensor<8x4xf16>, %B: tensor<4x16xf16>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<8x4xf16>, tensor<4x16xf16>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_mixed_precision(
+// CHECK-SAME:    %[[A:.*]]: tensor<8x4xf16>, %[[B:.*]]: tensor<4x16xf16>,
+// CHECK-SAME:    %[[C:.*]]: tensor<8x16xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<8x4xf16>, vector<8x4xf16>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf16>, vector<4x16xf16>
+//      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @batch_matmul(%A: tensor<3x8x4xf16>, %B: tensor<3x4x16xf16>,
+    %C: tensor<3x8x16xf32>) -> tensor<3x8x16xf32> {
+  %0 = linalg.batch_matmul
+    ins(%A, %B : tensor<3x8x4xf16>, tensor<3x4x16xf16>)
+    outs(%C: tensor<3x8x16xf32>) -> tensor<3x8x16xf32>
+  return %0 : tensor<3x8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: func.func @batch_matmul(
+// CHECK-SAME:    %[[A:.*]]: tensor<3x8x4xf16>, %[[B:.*]]: tensor<3x4x16xf16>,
+// CHECK-SAME:    %[[C:.*]]: tensor<3x8x16xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<3x8x4xf16>, vector<3x8x4xf16>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<3x4x16xf16>, vector<3x4x16xf16>
+//      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<3x8x16xf32>, vector<3x8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<3x8x16xf32>, tensor<3x8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @batch_reduce_matmul(%A: tensor<3x8x4xf16>, %B: tensor<3x4x16xf16>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.batch_reduce_matmul
+    ins(%A, %B : tensor<3x8x4xf16>, tensor<3x4x16xf16>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK-LABEL: func.func @batch_reduce_matmul(
+// CHECK-SAME:    %[[A:.*]]: tensor<3x8x4xf16>, %[[B:.*]]: tensor<3x4x16xf16>,
+// CHECK-SAME:    %[[C:.*]]: tensor<8x16xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<3x8x4xf16>, vector<3x8x4xf16>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<3x4x16xf16>, vector<3x4x16xf16>
+//      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @contract(%A: tensor<4x8x2xf16>, %B: tensor<8x16x2xf16>,
+    %C: tensor<4x16xf32>) -> tensor<4x16xf32> {
+  %0 = linalg.contract
+    indexing_maps = [affine_map<(m, n, k, vnni) -> (m, k, vnni)>,
+                     affine_map<(m, n, k, vnni) -> (k, n, vnni)>,
+                     affine_map<(m, n, k, vnni) -> (m, n)>]
+    ins(%A, %B : tensor<4x8x2xf16>, tensor<8x16x2xf16>)
+    outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
+  return %0 : tensor<4x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract(
+// CHECK-SAME:    %[[A:.*]]: tensor<4x8x2xf16>, %[[B:.*]]: tensor<8x16x2xf16>,
+// CHECK-SAME:    %[[C:.*]]: tensor<4x16xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<4x8x2xf16>, vector<4x8x2xf16>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<8x16x2xf16>, vector<8x16x2xf16>
+//      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<4x16xf32>, tensor<4x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Generic is currently ignored in direct lowering to a named contraction.
+
+func.func @negative_generic(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                     affine_map<(m, n, k) -> (k, n)>,
+                     affine_map<(m, n, k) -> (m, n)>],
+    iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
+    outs(%C : tensor<8x16xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.mulf %in, %in_0 : f32
+      %2 = arith.addf %out, %1 : f32
+      linalg.yield %2 : f32
+    } -> tensor<8x16xf32>
+    return %0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @negative_generic(
+// CHECK-NOT: vector.contract
+// CHECK: vector.multi_reduction
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}

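For reference, a minimal sketch of driving the new path through the C++
entry point added above (assumes `rewriter` and a matched contraction
`linalgOp` are already in scope):

  // Sketch only: enable the direct contraction path; all other options
  // keep their defaults.
  FailureOr<linalg::VectorizationResult> result = linalg::vectorize(
      rewriter, linalgOp,
      /*inputVectorSizes=*/{}, /*inputScalableVecDims=*/{},
      /*vectorizeNDExtract=*/false, /*flatten1DDepthwiseConv=*/false,
      /*assumeDynamicDimsMatchVecSizes=*/false,
      /*createNamedContraction=*/true);
  if (failed(result))
    return failure();
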
From c9b6f8b55dd44fc0bd9e76f9b7963eecd55b752b Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 16 Jul 2025 13:52:29 +0200
Subject: [PATCH 2/8] Remove flatten transform flag

---
 .../mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td     | 1 -
 mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp    | 3 ++-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 9eda469bdf930..8d45c40a93e2b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2435,7 +2435,6 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
       OptionalAttr<UnitAttr>:$vectorize_nd_extract,
       OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
-      OptionalAttr<UnitAttr>:$flatten1D_depthwise_conv,
       OptionalAttr<UnitAttr>:$create_named_contraction,
       DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b23641a1ceb54..109e5b7f95ec0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3920,7 +3920,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
     }
     FailureOr<VectorizationResult> vectorResults =
         linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                          getVectorizeNdExtract().value_or(false), false,
+                          getVectorizeNdExtract().value_or(false),
+                          /*flatten1DDepthwiseConv=*/false,
                           getAssumeDynamicDimsMatchVecSizes().value_or(false),
                           getCreateNamedContraction().value_or(false));
     if (failed(vectorResults)) {

From 5355c14e270700108dec1dc9ae4c17cb3cfbd014 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 16 Jul 2025 14:12:59 +0200
Subject: [PATCH 3/8] Rename test + descriptions

---
 ...ion-named.mlir => contraction-interface.mlir} | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)
 rename mlir/test/Dialect/Linalg/vectorization/{contraction-named.mlir => contraction-interface.mlir} (95%)

diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
similarity index 95%
rename from mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
rename to mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
index 1831acf092afb..a3c8e61a29fdf 100644
--- a/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
@@ -1,5 +1,11 @@
 // RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
 
+///----------------------------------------------------------------------------------------
+/// Tests for vectorizing operations implementing the contraction op interface.
+/// Ops implementing the contraction interface are vectorized directly to their
+/// named vector dialect counterparts.
+///----------------------------------------------------------------------------------------
+
 func.func @matmul(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
     %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = linalg.matmul
@@ -208,6 +214,12 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+/// Contractions' arbitrary broadcasts are not supported in contraction interface
+/// vectorization.
+/// Dimension broadcasts are expected to be decomposed first, which removes the
+/// ambiguity caused by possible variants of dimension materialization.
+/// For example, whether the target LHS input layout below is (m, k) or (k, m).
+
 func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>,
     %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = linalg.matmul
@@ -368,7 +380,9 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// Generic is currently ignored in direct lowering to a named contraction.
+/// Generic can represent contractions but it does not implement contraction interface.
+/// Thus, direct lowering to vector.contract is not supported.
+/// Vectorization still works and applies generic rewrite logic.
 
 func.func @negative_generic(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
     %C: tensor<8x16xf32>) -> tensor<8x16xf32> {

From a2eb7f04b5d87824ca0eee720c8022bb150941db Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 16 Jul 2025 14:23:51 +0200
Subject: [PATCH 4/8] Disable contraction vectorization with broadcasts

---
 .../Linalg/Transforms/Vectorization.cpp       | 31 +++++++++----------
 .../vectorization/contraction-interface.mlir  |  9 ++----
 2 files changed, 17 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f2f4330f025f4..5bcd52f5e18be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2128,7 +2128,7 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
 ///   vector::TransferWriteOp - Writes the result vector back to the
 ///   destination
 /// The operands' shapes are preserved and loaded directly into vectors.
-/// Any further permutations or numerical casting remain within contraction.
+/// Any further permutations or numerical casting remain within the contraction op.
 static LogicalResult
 vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
                              LinalgOp linalgOp,
@@ -2136,22 +2136,29 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
   Location loc = linalgOp.getLoc();
   MLIRContext *ctx = linalgOp.getContext();
 
+  // For simplicity, contraction vectorization is limited to linalg named ops.
+  // A generic op is ignored, as not every arbitrary contraction body can be
+  // expressed by a vector.contract.
   if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
     return failure();
 
   OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
   Operation *reduceOp = matchLinalgReduction(outOperand);
   auto maybeKind = getCombinerOpKind(reduceOp);
-  if (!maybeKind)
+  if (!maybeKind) {
+    LDBG("Failed to determine contraction combining kind.\n");
     return failure();
+  }
 
   // Check that all dimensions are present in the input operands.
   // Arbitrary broadcasts are not supported by the vector contraction.
-  // Broadcasts are expected to be materialized before vectorization.
+  // Broadcasts are expected to be decomposed before vectorization.
   AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
   AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
-  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
+    LDBG("Contractions with broadcasts are not supported.\n");
     return failure();
+  }
 
   // Load operands.
   SmallVector<Value> vecOperands;
@@ -2687,20 +2694,10 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
               return failure();
             }
 
-            // For simplicity, contraction vectorization is limited to linalg
-            // named ops. Generic op is ignored as not every arbitrary
-            // contraction body can be expressed by a vector.contract.
             if (createNamedContraction &&
-                isa<ContractionOpInterface>(linalgOp.getOperation())) {
-              // Attempt vectorizing directly into a named contraction.
-              // In case of failure, fall back to the generic path.
-              LogicalResult res = vectorizeAsLinalgContraction(
-                  rewriter, state, linalgOp, results);
-              if (succeeded(res))
-                return success();
-
-              LDBG("Failed to vectorize as a named contraction.\n");
-            }
+                isa<ContractionOpInterface>(linalgOp.getOperation()))
+              return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
+                                                  results);
 
             LDBG("Vectorize generic by broadcasting to the canonical vector "
                  "shape\n");
diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
index a3c8e61a29fdf..f2d09aee87a4b 100644
--- a/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
 
 ///----------------------------------------------------------------------------------------
 /// Tests for vectorizing operations implementing the contraction op interface.
@@ -214,7 +214,7 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-/// Contractions' arbitrary broadcasts are not supported in contraction interface
+/// Contractions with arbitrary broadcasts are not supported in contraction interface
 /// vectorization.
 /// Dimension broadcasts are expected to be decomposed first, which removes the
 /// ambiguity caused by possible variants of dimension materialization.
@@ -222,6 +222,7 @@ module attributes {transform.with_named_sequence} {
 
 func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>,
     %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
   %0 = linalg.matmul
     indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast
                      affine_map<(m, n, k) -> (k, n)>,
@@ -231,10 +232,6 @@ func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>,
   return %0 : tensor<8x16xf32>
 }
 
-// CHECK-LABEL: func.func @negative_matmul_broadcast(
-// CHECK-NOT: vector.contract
-// CHECK: vector.multi_reduction
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op

From 8e4b553c0f5c81c58103be6212318dff8ca12d15 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 16 Jul 2025 14:39:59 +0200
Subject: [PATCH 5/8] Rename test case

---
 .../Dialect/Linalg/vectorization/contraction-interface.mlir   | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
index f2d09aee87a4b..77ddc004ed13a 100644
--- a/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
@@ -107,7 +107,7 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @matmul_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+func.func @matmul_dynamic_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
     %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.matmul
     ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
@@ -118,7 +118,7 @@ func.func @matmul_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
 // CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 // CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
 // CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL: func.func @matmul_scalable(
+// CHECK-LABEL: func.func @matmul_dynamic_scalable(
 // CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
 // CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
 //      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>

From 6ffb252650bc4381e6eddfea7d9ba3773ab542d9 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 21 Jul 2025 13:15:29 +0200
Subject: [PATCH 6/8] Refactor tests

---
 .../vectorization/contraction-interface.mlir  | 36 +++++++++----------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
index 77ddc004ed13a..31fb9ca769796 100644
--- a/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
@@ -275,10 +275,10 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @batch_matmul(%A: tensor<3x8x4xf16>, %B: tensor<3x4x16xf16>,
+func.func @batch_matmul(%A: tensor<3x8x4xf32>, %B: tensor<3x4x16xf32>,
     %C: tensor<3x8x16xf32>) -> tensor<3x8x16xf32> {
   %0 = linalg.batch_matmul
-    ins(%A, %B : tensor<3x8x4xf16>, tensor<3x4x16xf16>)
+    ins(%A, %B : tensor<3x8x4xf32>, tensor<3x4x16xf32>)
     outs(%C: tensor<3x8x16xf32>) -> tensor<3x8x16xf32>
   return %0 : tensor<3x8x16xf32>
 }
@@ -287,10 +287,10 @@ func.func @batch_matmul(%A: tensor<3x8x4xf16>, %B: tensor<3x4x16xf16>,
 // CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 // CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 // CHECK-LABEL: func.func @batch_matmul(
-// CHECK-SAME:    %[[A:.*]]: tensor<3x8x4xf16>, %[[B:.*]]: tensor<3x4x16xf16>,
+// CHECK-SAME:    %[[A:.*]]: tensor<3x8x4xf32>, %[[B:.*]]: tensor<3x4x16xf32>,
 // CHECK-SAME:    %[[C:.*]]: tensor<3x8x16xf32>)
-//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<3x8x4xf16>, vector<3x8x4xf16>
-//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<3x4x16xf16>, vector<3x4x16xf16>
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<3x8x4xf32>, vector<3x8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<3x4x16xf32>, vector<3x4x16xf32>
 //      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<3x8x16xf32>, vector<3x8x16xf32>
 //      CHECK: %[[CONTRACT:.*]] = vector.contract
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
@@ -308,10 +308,10 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @batch_reduce_matmul(%A: tensor<3x8x4xf16>, %B: tensor<3x4x16xf16>,
+func.func @batch_reduce_matmul(%A: tensor<3x8x4xf32>, %B: tensor<3x4x16xf32>,
     %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = linalg.batch_reduce_matmul
-    ins(%A, %B : tensor<3x8x4xf16>, tensor<3x4x16xf16>)
+    ins(%A, %B : tensor<3x8x4xf32>, tensor<3x4x16xf32>)
     outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
   return %0 : tensor<8x16xf32>
 }
@@ -320,10 +320,10 @@ func.func @batch_reduce_matmul(%A: tensor<3x8x4xf16>, %B: tensor<3x4x16xf16>,
 // CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 // CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 // CHECK-LABEL: func.func @batch_reduce_matmul(
-// CHECK-SAME:    %[[A:.*]]: tensor<3x8x4xf16>, %[[B:.*]]: tensor<3x4x16xf16>,
+// CHECK-SAME:    %[[A:.*]]: tensor<3x8x4xf32>, %[[B:.*]]: tensor<3x4x16xf32>,
 // CHECK-SAME:    %[[C:.*]]: tensor<8x16xf32>)
-//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<3x8x4xf16>, vector<3x8x4xf16>
-//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<3x4x16xf16>, vector<3x4x16xf16>
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<3x8x4xf32>, vector<3x8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<3x4x16xf32>, vector<3x4x16xf32>
 //      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
 //      CHECK: %[[CONTRACT:.*]] = vector.contract
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
@@ -341,13 +341,13 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @contract(%A: tensor<4x8x2xf16>, %B: tensor<8x16x2xf16>,
+func.func @contract(%A: tensor<4x8x2xf32>, %B: tensor<8x16x2xf32>,
     %C: tensor<4x16xf32>) -> tensor<4x16xf32> {
   %0 = linalg.contract
-    indexing_maps = [affine_map<(m, n, k, vnni) -> (m, k, vnni)>,
-                     affine_map<(m, n, k, vnni) -> (k, n, vnni)>,
-                     affine_map<(m, n, k, vnni) -> (m, n)>]
-    ins(%A, %B : tensor<4x8x2xf16>, tensor<8x16x2xf16>)
+    indexing_maps = [affine_map<(m, n, k, kk) -> (m, k, kk)>,
+                     affine_map<(m, n, k, kk) -> (k, n, kk)>,
+                     affine_map<(m, n, k, kk) -> (m, n)>]
+    ins(%A, %B : tensor<4x8x2xf32>, tensor<8x16x2xf32>)
     outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
   return %0 : tensor<4x16xf32>
 }
@@ -356,10 +356,10 @@ func.func @contract(%A: tensor<4x8x2xf16>, %B: tensor<8x16x2xf16>,
 // CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
 // CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
 // CHECK-LABEL: func.func @contract(
-// CHECK-SAME:    %[[A:.*]]: tensor<4x8x2xf16>, %[[B:.*]]: tensor<8x16x2xf16>,
+// CHECK-SAME:    %[[A:.*]]: tensor<4x8x2xf32>, %[[B:.*]]: tensor<8x16x2xf32>,
 // CHECK-SAME:    %[[C:.*]]: tensor<4x16xf32>)
-//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<4x8x2xf16>, vector<4x8x2xf16>
-//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<8x16x2xf16>, vector<8x16x2xf16>
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<4x8x2xf32>, vector<4x8x2xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<8x16x2xf32>, vector<8x16x2xf32>
 //      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
 //      CHECK: %[[CONTRACT:.*]] = vector.contract
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]

From dc5ce08b0cdb1ed54e20a120a6c39ae35b97762b Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 21 Jul 2025 13:47:02 +0200
Subject: [PATCH 7/8] Mask contraction

---
 mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp      | 7 ++++---
 .../Linalg/vectorization/contraction-interface.mlir       | 8 ++++----
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5bcd52f5e18be..77c85abab9aa0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2191,14 +2191,15 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
   }
 
   // Create contraction.
-  Value contractOp = rewriter.create<vector::ContractionOp>(
+  Operation *contractOp = rewriter.create<vector::ContractionOp>(
       loc, /*lhs=*/vecOperands[0],
       /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
       linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
 
   // Store result.
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, contractOp, outOperand->get());
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, contractOp->getResult(0), outOperand->get());
 
   // Finalize.
   if (!write->getResults().empty())
diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
index 31fb9ca769796..c577efe66257a 100644
--- a/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
@@ -56,7 +56,7 @@ func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
 //      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
 //      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x16xf32>
 //      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
-//      CHECK: %[[CONTRACT:.*]] = vector.contract
+//      CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
 // CHECK-SAME:   kind = #vector.kind<add>
 // CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
@@ -90,7 +90,7 @@ func.func @matmul_dynamic_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>,
 //      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: memref<?x?xf32>, vector<8x4xf32>
 //      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: memref<?x?xf32>, vector<4x16xf32>
 //      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: memref<?x?xf32>, vector<8x16xf32>
-//      CHECK: %[[CONTRACT:.*]] = vector.contract
+//      CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
 // CHECK-SAME:   kind = #vector.kind<add>
 // CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
@@ -124,7 +124,7 @@ func.func @matmul_dynamic_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
 //      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
 //      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x[16]xf32>
 //      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x[16]xf32>
-//      CHECK: %[[CONTRACT:.*]] = vector.contract
+//      CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
 // CHECK-SAME:   kind = #vector.kind<add>
 // CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
@@ -197,7 +197,7 @@ func.func @matmul_dynamic_transpose(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
 //      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<4x8xf32>
 //      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<16x4xf32>
 //      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
-//      CHECK: %[[CONTRACT:.*]] = vector.contract
+//      CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
 // CHECK-SAME:   kind = #vector.kind<add>
 // CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]

>From 9fb7b1e6657419dcbaa5376786002a2be7b52103 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 21 Jul 2025 18:41:48 +0200
Subject: [PATCH 8/8] Test - check vector masks + minor refactor

---
 .../vectorization/contraction-interface.mlir  | 113 ++++++++++++++----
 1 file changed, 93 insertions(+), 20 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
index c577efe66257a..d8f897cca958d 100644
--- a/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
@@ -27,7 +27,7 @@ func.func @matmul(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
 // CHECK-SAME:   kind = #vector.kind<add>
 // CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
-// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+//      CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -53,14 +53,82 @@ func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
 // CHECK-LABEL: func.func @matmul_dynamic(
 // CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
 // CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
-//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
-//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x16xf32>
-//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
-//      CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
-// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
-// CHECK-SAME:   kind = #vector.kind<add>
-// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
-// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+
+/// Get the contraction dimensions
+//  CHECK: %[[MATMUL_DIM_M_IDX:.*]] = arith.constant 0 : index
+//  CHECK: %[[MATMUL_DIM_M:.*]] = tensor.dim %[[A]], %[[MATMUL_DIM_M_IDX]] : tensor<?x?xf32>
+//  CHECK: %[[MATMUL_DIM_N_IDX:.*]] = arith.constant 1 : index
+//  CHECK: %[[MATMUL_DIM_N:.*]] = tensor.dim %[[B]], %[[MATMUL_DIM_N_IDX]] : tensor<?x?xf32>
+//  CHECK: %[[MATMUL_DIM_K_IDX:.*]] = arith.constant 1 : index
+//  CHECK: %[[MATMUL_DIM_K:.*]] = tensor.dim %[[A]], %[[MATMUL_DIM_K_IDX]] : tensor<?x?xf32>
+
+/// Create a mask for the A matrix
+//      CHECK: %[[A_OFFSET:.*]] = arith.constant 0 : index
+//      CHECK: %[[A_DIM_M_IDX:.*]] = arith.constant 0 : index
+//      CHECK: %[[A_DIM_M:.*]] = tensor.dim %[[A]], %[[A_DIM_M_IDX]] : tensor<?x?xf32>
+//      CHECK: %[[A_DIM_K_IDX:.*]] = arith.constant 1 : index
+//      CHECK: %[[A_DIM_K:.*]] = tensor.dim %[[A]], %[[A_DIM_K_IDX]] : tensor<?x?xf32>
+//      CHECK: %[[LOAD_A_MASK:.*]] = vector.create_mask
+// CHECK-SAME:   %[[A_DIM_M]], %[[A_DIM_K]] : vector<8x4xi1>
+/// Read the A matrix
+//      CHECK: %[[LOAD_A:.*]] = vector.mask %[[LOAD_A_MASK]]
+// CHECK-SAME:   { vector.transfer_read %[[A]]{{\[}}%[[A_OFFSET]], %[[A_OFFSET]]{{\]}}
+// CHECK-SAME:     : tensor<?x?xf32>, vector<8x4xf32> }
+// CHECK-SAME:   : vector<8x4xi1> -> vector<8x4xf32>
+
+/// Create a mask for the B matrix
+//      CHECK: %[[B_OFFSET:.*]] = arith.constant 0 : index
+//      CHECK: %[[B_DIM_K_IDX:.*]] = arith.constant 0 : index
+//      CHECK: %[[B_DIM_K:.*]] = tensor.dim %[[B]], %[[B_DIM_K_IDX]] : tensor<?x?xf32>
+//      CHECK: %[[B_DIM_N_IDX:.*]] = arith.constant 1 : index
+//      CHECK: %[[B_DIM_N:.*]] = tensor.dim %[[B]], %[[B_DIM_N_IDX]] : tensor<?x?xf32>
+//      CHECK: %[[LOAD_B_MASK:.*]] = vector.create_mask
+// CHECK-SAME:   %[[B_DIM_K]], %[[B_DIM_N]] : vector<4x16xi1>
+/// Read the B matrix
+//      CHECK: %[[LOAD_B:.*]] = vector.mask %[[LOAD_B_MASK]]
+// CHECK-SAME:   { vector.transfer_read %[[B]]{{\[}}%[[B_OFFSET]], %[[B_OFFSET]]{{\]}}
+// CHECK-SAME:     : tensor<?x?xf32>, vector<4x16xf32> }
+// CHECK-SAME:   : vector<4x16xi1> -> vector<4x16xf32>
+
+/// Create a mask for the C matrix
+//      CHECK: %[[C_OFFSET:.*]] = arith.constant 0 : index
+//      CHECK: %[[C_DIM_M_IDX:.*]] = arith.constant 0 : index
+//      CHECK: %[[C_DIM_M:.*]] = tensor.dim %[[C]], %[[C_DIM_M_IDX]] : tensor<?x?xf32>
+//      CHECK: %[[C_DIM_N_IDX:.*]] = arith.constant 1 : index
+//      CHECK: %[[C_DIM_N:.*]] = tensor.dim %[[C]], %[[C_DIM_N_IDX]] : tensor<?x?xf32>
+//      CHECK: %[[LOAD_C_MASK:.*]] = vector.create_mask
+// CHECK-SAME:   %[[C_DIM_M]], %[[C_DIM_N]] : vector<8x16xi1>
+/// Read the C matrix
+//      CHECK: %[[LOAD_C:.*]] = vector.mask %[[LOAD_C_MASK]]
+// CHECK-SAME:   { vector.transfer_read %[[C]]{{\[}}%[[C_OFFSET]], %[[C_OFFSET]]{{\]}}
+// CHECK-SAME:     : tensor<?x?xf32>, vector<8x16xf32> }
+// CHECK-SAME:   : vector<8x16xi1> -> vector<8x16xf32>
+
+/// Create a mask for the contraction
+//      CHECK: %[[CONTRACTION_MASK:.*]] = vector.create_mask
+// CHECK-SAME:   %[[MATMUL_DIM_M]], %[[MATMUL_DIM_N]], %[[MATMUL_DIM_K]]
+// CHECK-SAME:   : vector<8x16x4xi1>
+/// Perform the contraction
+//      CHECK: %[[D:.*]] = vector.mask %[[CONTRACTION_MASK]]
+// CHECK-SAME:   { vector.contract
+// CHECK-SAME:     indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:     kind = #vector.kind<add>
+// CHECK-SAME:     %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK-SAME:   } : vector<8x16x4xi1> -> vector<8x16xf32>
+
+/// Create a mask for the result
+//      CHECK: %[[D_OFFSET:.*]] = arith.constant 0 : index
+//      CHECK: %[[D_DIM_M_IDX:.*]] = arith.constant 0 : index
+//      CHECK: %[[D_DIM_M:.*]] = tensor.dim %[[C]], %[[D_DIM_M_IDX]] : tensor<?x?xf32>
+//      CHECK: %[[D_DIM_N_IDX:.*]] = arith.constant 1 : index
+//      CHECK: %[[D_DIM_N:.*]] = tensor.dim %[[C]], %[[D_DIM_N_IDX]] : tensor<?x?xf32>
+//      CHECK: %[[WRITE_D_MASK:.*]] = vector.create_mask
+// CHECK-SAME:   %[[D_DIM_M]], %[[D_DIM_N]] : vector<8x16xi1>
+/// Write the result
+//      CHECK: vector.mask %[[WRITE_D_MASK]]
+// CHECK-SAME: { vector.transfer_write %[[D]], %[[C]]{{\[}}%[[D_OFFSET]], %[[D_OFFSET]]{{\]}}
+// CHECK-SAME:   : vector<8x16xf32>, tensor<?x?xf32> }
+// CHECK-SAME: : vector<8x16xi1> -> tensor<?x?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -94,7 +162,7 @@ func.func @matmul_dynamic_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>,
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
 // CHECK-SAME:   kind = #vector.kind<add>
 // CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
-// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, memref<?x?xf32>
+//      CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, memref<?x?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -121,14 +189,19 @@ func.func @matmul_dynamic_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
 // CHECK-LABEL: func.func @matmul_dynamic_scalable(
 // CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
 // CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
-//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
-//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x[16]xf32>
-//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x[16]xf32>
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32> }
+// CHECK-SAME:   : vector<8x4xi1> -> vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x[16]xf32> }
+// CHECK-SAME:   : vector<4x[16]xi1> -> vector<4x[16]xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x[16]xf32> }
+// CHECK-SAME:   : vector<8x[16]xi1> -> vector<8x[16]xf32>
 //      CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
 // CHECK-SAME:   kind = #vector.kind<add>
 // CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
-// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x[16]xf32>, tensor<?x?xf32>
+// CHECK-SAME:   } : vector<8x[16]x4xi1> -> vector<8x[16]xf32>
+//      CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x[16]xf32>, tensor<?x?xf32> }
+// CHECK-SAME:   : vector<8x[16]xi1> -> tensor<?x?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -165,7 +238,7 @@ func.func @matmul_transpose(%A: tensor<4x8xf32>, %B: tensor<16x4xf32>,
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
 // CHECK-SAME:   kind = #vector.kind<add>
 // CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
-// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+//      CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -201,7 +274,7 @@ func.func @matmul_dynamic_transpose(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
 // CHECK-SAME:   kind = #vector.kind<add>
 // CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
-// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+//      CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -263,7 +336,7 @@ func.func @matmul_mixed_precision(%A: tensor<8x4xf16>, %B: tensor<4x16xf16>,
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
 // CHECK-SAME:   kind = #vector.kind<add>
 // CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
-// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+//      CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -296,7 +369,7 @@ func.func @batch_matmul(%A: tensor<3x8x4xf32>, %B: tensor<3x4x16xf32>,
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
 // CHECK-SAME:   kind = #vector.kind<add>
 // CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
-// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<3x8x16xf32>, tensor<3x8x16xf32>
+//      CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<3x8x16xf32>, tensor<3x8x16xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -329,7 +402,7 @@ func.func @batch_reduce_matmul(%A: tensor<3x8x4xf32>, %B: tensor<3x4x16xf32>,
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
 // CHECK-SAME:   kind = #vector.kind<add>
 // CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
-// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+//      CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -365,7 +438,7 @@ func.func @contract(%A: tensor<4x8x2xf32>, %B: tensor<8x16x2xf32>,
 // CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
 // CHECK-SAME:   kind = #vector.kind<add>
 // CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
-// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<4x16xf32>, tensor<4x16xf32>
+//      CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<4x16xf32>, tensor<4x16xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
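
(For reference, a hedged sketch of driving this path through the transform
dialect; the exact assembly for the new unit attribute is an assumption based
on the `create_named_contraction` flag added to `transform.structured.vectorize`.)

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
      %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
      // Vector sizes follow the matmul's (m, n, k) iteration space.
      transform.structured.vectorize %matmul vector_sizes [8, 16, 4] {create_named_contraction}
        : !transform.any_op
      transform.yield
    }
  }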


