[Mlir-commits] [mlir] [mlir][linalg] Vectorize directly to a named contraction (PR #147296)

Adam Siemieniuk llvmlistbot at llvm.org
Wed Jul 16 05:38:11 PDT 2025


https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/147296

From 7ea54e9126d78c45f6d6167e941245277e919c8c Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 3 Jul 2025 13:19:43 +0200
Subject: [PATCH 1/4] [mlir][linalg] Vectorize directly to a named contraction

Extends the linalg vectorizer with a path to lower contraction ops
directly into `vector.contract`.

The direct rewriting preserves high-level op semantics and provides
a more progressive lowering than reconstructing the contraction from
a multi-dimensional reduction.
The added lowering focuses on named linalg ops and leverages their
well-defined semantics to avoid complex precondition verification.

The new path is optional and disabled by default to avoid changing
the default vectorizer behavior.
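
For illustration, a rough sketch of the new rewrite (mirroring the
tests added below; transfer indices and attribute details are elided):

  %0 = linalg.matmul
    ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
    outs(%C : tensor<8x16xf32>) -> tensor<8x16xf32>

becomes

  %lhs = vector.transfer_read %A ... : tensor<8x4xf32>, vector<8x4xf32>
  %rhs = vector.transfer_read %B ... : tensor<4x16xf32>, vector<4x16xf32>
  %acc = vector.transfer_read %C ... : tensor<8x16xf32>, vector<8x16xf32>
  %res = vector.contract {indexing_maps = [...], iterator_types = [...],
      kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x4xf32>, vector<4x16xf32> into vector<8x16xf32>
  %0 = vector.transfer_write %res, %C ... : vector<8x16xf32>, tensor<8x16xf32>

The path is opted into through the transform dialect, e.g.:

  transform.structured.vectorize %0 {create_named_contraction}
    : !transform.any_op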
---
 .../Linalg/TransformOps/LinalgTransformOps.td |   2 +
 .../Dialect/Linalg/Transforms/Transforms.h    |   5 +-
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   |   3 +-
 .../TransformOps/LinalgTransformOps.cpp       |   4 +-
 .../Linalg/Transforms/Vectorization.cpp       | 110 ++++-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp |  13 +-
 .../vectorization/contraction-named.mlir      | 400 ++++++++++++++++++
 7 files changed, 523 insertions(+), 14 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b4dde776822a1..0f51ac2f61c18 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2435,6 +2435,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
                           $static_vector_sizes,
                        OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+                       OptionalAttr<UnitAttr>:$flatten1D_depthwise_conv,
+                       OptionalAttr<UnitAttr>:$create_named_contraction,
                        DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
                           $scalable_sizes);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 74280fdd82f4e..eee856f3eba68 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,11 +876,14 @@ struct VectorizationResult {
 /// greater than or equal to their counterpart iteration space sizes, if static.
 /// `inputVectorShapes` also allows the vectorization of operations with dynamic
 /// shapes.
+/// Optionally, `createNamedContraction` can force compatible contractions to be
+/// vectorized directly to a vector.contract operation.
 FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+          bool createNamedContraction = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index cc8421b23a074..9b765d0b8ede6 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -226,7 +226,8 @@ bool isLinearizableVector(VectorType type);
 /// Note: all read offsets are set to 0.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              ArrayRef<int64_t> inputVectorSizes, Value padValue,
-                             bool useInBoundsInsteadOfMasking = false);
+                             bool useInBoundsInsteadOfMasking = false,
+                             ArrayRef<bool> scalableDims = {});
 
 /// Returns success if `inputVectorSizes` is a valid masking configuration for
 /// given `shape`, i.e., it meets:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d5f9de465561..bcab8197c9dd4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3920,7 +3920,9 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
     }
     FailureOr<VectorizationResult> vectorResults =
         linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                          getVectorizeNdExtract().value_or(false));
+                          getVectorizeNdExtract().value_or(false),
+                          getFlatten1DDepthwiseConv().value_or(false),
+                          getCreateNamedContraction().value_or(false));
     if (failed(vectorResults)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5a8c5eab3f444..6e428646eaad1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -1681,10 +1682,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
     return write;
 
   // Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());
 
   SmallVector<OpFoldResult> destSizes =
-      tensor::getMixedSizes(builder, loc, dest);
+      isa<MemRefType>(dest.getType())
+          ? memref::getMixedSizes(builder, loc, dest)
+          : tensor::getMixedSizes(builder, loc, dest);
   SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                       destSizes.end());
 
@@ -2093,6 +2097,84 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+/// Vectorize a named linalg contraction op into:
+///   vector::TransferReadOp - Reads vectors from the operands
+///   vector::ContractionOp - Performs the contraction
+///   vector::TransferWriteOp - Writes the result vector back to the
+///   destination
+/// The operand shapes are preserved and the operands are loaded directly into vectors.
+/// Any further permutations or numerical casting remain within the contraction.
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+                             LinalgOp linalgOp,
+                             SmallVectorImpl<Value> &newResults) {
+  Location loc = linalgOp.getLoc();
+  MLIRContext *ctx = linalgOp.getContext();
+
+  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+    return failure();
+
+  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+  Operation *reduceOp = matchLinalgReduction(outOperand);
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  if (!maybeKind)
+    return failure();
+
+  // Check that all dimensions are present in the input operands.
+  // Arbitrary broadcasts are not supported by the vector contraction.
+  // Broadcasts are expected to be materialized before vectorization.
+  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+    return failure();
+
+  // Load operands.
+  SmallVector<Value> vecOperands;
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    // The operand vector shape is computed by mapping the canonical vector
+    // shape to the operand's domain. Further permutations are left as part of
+    // the contraction.
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+        indexingMap.getNumResults(), rewriter.getContext());
+    Type elemType = getElementTypeOrSelf(opOperand.get());
+    VectorType readType =
+        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+    Value read = mlir::vector::createReadOrMaskedRead(
+        rewriter, loc, opOperand.get(), readType.getShape(),
+        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+    vecOperands.push_back(read);
+  }
+
+  // Remap iterators from linalg to vector.
+  SmallVector<Attribute> iterAttrs;
+  auto iterators = linalgOp.getIteratorTypesArray();
+  for (utils::IteratorType iter : iterators) {
+    auto vecIter = iter == utils::IteratorType::parallel
+                       ? vector::IteratorType::parallel
+                       : vector::IteratorType::reduction;
+    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+  }
+
+  // Create contraction.
+  Value contractOp = rewriter.create<vector::ContractionOp>(
+      loc, /*lhs=*/vecOperands[0],
+      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+
+  // Store result.
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, contractOp, outOperand->get());
+
+  // Finalize.
+  if (!write->getResults().empty())
+    newResults.push_back(write->getResult(0));
+
+  return success();
+}
+
 namespace {
 enum class ConvOperationKind { Conv, Pool };
 } // namespace
@@ -2528,11 +2610,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
              tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool createNamedContraction) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2578,6 +2659,21 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
               return failure();
             }
 
+            // For simplicity, contraction vectorization is limited to linalg
+            // named ops. A generic op is ignored, as not every arbitrary
+            // contraction body can be expressed by a vector.contract.
+            if (createNamedContraction &&
+                isa<ContractionOpInterface>(linalgOp.getOperation())) {
+              // Attempt vectorizing directly into a named contraction.
+              // In case of failure, fall back to the generic path.
+              LogicalResult res = vectorizeAsLinalgContraction(
+                  rewriter, state, linalgOp, results);
+              if (succeeded(res))
+                return success();
+
+              LDBG("Failed to vectorize as a named contraction.\n");
+            }
+
             LDBG("Vectorize generic by broadcasting to the canonical vector "
                  "shape\n");
 
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 7e4984582b373..9b055853fc8b0 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -320,14 +320,16 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      Value source,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      Value padValue,
-                                     bool useInBoundsInsteadOfMasking) {
+                                     bool useInBoundsInsteadOfMasking,
+                                     ArrayRef<bool> scalableDims) {
   assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+  auto vectorType =
+      VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
@@ -352,9 +354,12 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(builder, loc, source);
+      isa<MemRefType>(source.getType())
+          ? memref::getMixedSizes(builder, loc, source)
+          : tensor::getMixedSizes(builder, loc, source);
 
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+  auto maskType =
+      VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
   Value mask =
       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
new file mode 100644
index 0000000000000..1831acf092afb
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
@@ -0,0 +1,400 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+
+func.func @matmul(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul(
+// CHECK-SAME:    %[[A:.*]]: tensor<8x4xf32>, %[[B:.*]]: tensor<4x16xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<8x16xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<8x4xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic(
+// CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x16xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>,
+    %C: memref<?x?xf32>) {
+  linalg.matmul
+    ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+    outs(%C: memref<?x?xf32>)
+  return
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_memref(
+// CHECK-SAME:    %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: memref<?x?xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: memref<?x?xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: memref<?x?xf32>, vector<4x16xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: memref<?x?xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, memref<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_scalable(
+// CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x[16]xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x[16]xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x[16]xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, [16], 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_transpose(%A: tensor<4x8xf32>, %B: tensor<16x4xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+                     affine_map<(m, n, k) -> (n, k)>, // transpose B
+                     affine_map<(m, n, k) -> (m, n)>]
+    ins(%A, %B : tensor<4x8xf32>, tensor<16x4xf32>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_transpose(
+// CHECK-SAME:    %[[A:.*]]: tensor<4x8xf32>, %[[B:.*]]: tensor<16x4xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<8x16xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<4x8xf32>, vector<4x8xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<16x4xf32>, vector<16x4xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic_transpose(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+                     affine_map<(m, n, k) -> (n, k)>, // transpose B
+                     affine_map<(m, n, k) -> (m, n)>]
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_transpose(
+// CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<4x8xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<16x4xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast
+                     affine_map<(m, n, k) -> (k, n)>,
+                     affine_map<(m, n, k) -> (m, n)>]
+    ins(%A, %B : tensor<4xf32>, tensor<4x16xf32>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @negative_matmul_broadcast(
+// CHECK-NOT: vector.contract
+// CHECK: vector.multi_reduction
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_mixed_precision(%A: tensor<8x4xf16>, %B: tensor<4x16xf16>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<8x4xf16>, tensor<4x16xf16>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_mixed_precision(
+// CHECK-SAME:    %[[A:.*]]: tensor<8x4xf16>, %[[B:.*]]: tensor<4x16xf16>,
+// CHECK-SAME:    %[[C:.*]]: tensor<8x16xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<8x4xf16>, vector<8x4xf16>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf16>, vector<4x16xf16>
+//      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @batch_matmul(%A: tensor<3x8x4xf16>, %B: tensor<3x4x16xf16>,
+    %C: tensor<3x8x16xf32>) -> tensor<3x8x16xf32> {
+  %0 = linalg.batch_matmul
+    ins(%A, %B : tensor<3x8x4xf16>, tensor<3x4x16xf16>)
+    outs(%C: tensor<3x8x16xf32>) -> tensor<3x8x16xf32>
+  return %0 : tensor<3x8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: func.func @batch_matmul(
+// CHECK-SAME:    %[[A:.*]]: tensor<3x8x4xf16>, %[[B:.*]]: tensor<3x4x16xf16>,
+// CHECK-SAME:    %[[C:.*]]: tensor<3x8x16xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<3x8x4xf16>, vector<3x8x4xf16>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<3x4x16xf16>, vector<3x4x16xf16>
+//      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<3x8x16xf32>, vector<3x8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<3x8x16xf32>, tensor<3x8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @batch_reduce_matmul(%A: tensor<3x8x4xf16>, %B: tensor<3x4x16xf16>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.batch_reduce_matmul
+    ins(%A, %B : tensor<3x8x4xf16>, tensor<3x4x16xf16>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK-LABEL: func.func @batch_reduce_matmul(
+// CHECK-SAME:    %[[A:.*]]: tensor<3x8x4xf16>, %[[B:.*]]: tensor<3x4x16xf16>,
+// CHECK-SAME:    %[[C:.*]]: tensor<8x16xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<3x8x4xf16>, vector<3x8x4xf16>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<3x4x16xf16>, vector<3x4x16xf16>
+//      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @contract(%A: tensor<4x8x2xf16>, %B: tensor<8x16x2xf16>,
+    %C: tensor<4x16xf32>) -> tensor<4x16xf32> {
+  %0 = linalg.contract
+    indexing_maps = [affine_map<(m, n, k, vnni) -> (m, k, vnni)>,
+                     affine_map<(m, n, k, vnni) -> (k, n, vnni)>,
+                     affine_map<(m, n, k, vnni) -> (m, n)>]
+    ins(%A, %B : tensor<4x8x2xf16>, tensor<8x16x2xf16>)
+    outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
+  return %0 : tensor<4x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract(
+// CHECK-SAME:    %[[A:.*]]: tensor<4x8x2xf16>, %[[B:.*]]: tensor<8x16x2xf16>,
+// CHECK-SAME:    %[[C:.*]]: tensor<4x16xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<4x8x2xf16>, vector<4x8x2xf16>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<8x16x2xf16>, vector<8x16x2xf16>
+//      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<4x16xf32>, tensor<4x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Generic is currently ignored in direct lowering to a named contraction.
+
+func.func @negative_generic(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                     affine_map<(m, n, k) -> (k, n)>,
+                     affine_map<(m, n, k) -> (m, n)>],
+    iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
+    outs(%C : tensor<8x16xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.mulf %in, %in_0 : f32
+      %2 = arith.addf %out, %1 : f32
+      linalg.yield %2 : f32
+    } -> tensor<8x16xf32>
+    return %0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @negative_generic(
+// CHECK-NOT: vector.contract
+// CHECK: vector.multi_reduction
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}

From 69f6aca7274d87b33cab4e81d823e4864378111d Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 16 Jul 2025 13:52:29 +0200
Subject: [PATCH 2/4] Remove flatten transform flag

---
 .../mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td      | 1 -
 mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp     | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 0f51ac2f61c18..f87e0c331d824 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2435,7 +2435,6 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
                           $static_vector_sizes,
                        OptionalAttr<UnitAttr>:$vectorize_nd_extract,
-                       OptionalAttr<UnitAttr>:$flatten1D_depthwise_conv,
                        OptionalAttr<UnitAttr>:$create_named_contraction,
                        DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
                           $scalable_sizes);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bcab8197c9dd4..ccb772d521103 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3921,7 +3921,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
     FailureOr<VectorizationResult> vectorResults =
         linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
                           getVectorizeNdExtract().value_or(false),
-                          getFlatten1DDepthwiseConv().value_or(false),
+                          /*flatten1DDepthwiseConv=*/false,
                           getCreateNamedContraction().value_or(false));
     if (failed(vectorResults)) {
       return mlir::emitSilenceableFailure(target->getLoc())

From 42a16329529a0233cf08f3805f72046d691c13fe Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 16 Jul 2025 14:12:59 +0200
Subject: [PATCH 3/4] Rename test + descriptions

---
 ...ion-named.mlir => contraction-interface.mlir} | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)
 rename mlir/test/Dialect/Linalg/vectorization/{contraction-named.mlir => contraction-interface.mlir} (95%)

diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
similarity index 95%
rename from mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
rename to mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
index 1831acf092afb..a3c8e61a29fdf 100644
--- a/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
@@ -1,5 +1,11 @@
 // RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
 
+///----------------------------------------------------------------------------------------
+/// Tests for vectorizing operations implementing contraction op interface.
+/// Ops implementing the contraction interface are vectorized directly to their
+/// vector dialect named counterparts.
+///----------------------------------------------------------------------------------------
+
 func.func @matmul(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
     %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = linalg.matmul
@@ -208,6 +214,12 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+/// Contractions' arbitrary broadcasts are not supported in contraction interface
+/// vectorization.
+/// Dimension broadcasts are expected to be decomposed first, which removes the
+/// ambiguity of how the broadcast dimensions could be materialized.
+/// For example, whether the target layout of the LHS input below is (m, k) or (k, m).
+
 func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>,
     %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = linalg.matmul
@@ -368,7 +380,9 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// Generic is currently ignored in direct lowering to a named contraction.
+/// A generic op can represent contractions, but it does not implement the
+/// contraction interface. Thus, direct lowering to vector.contract is not supported.
+/// Vectorization still works and applies the generic rewrite logic.
 
 func.func @negative_generic(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
     %C: tensor<8x16xf32>) -> tensor<8x16xf32> {

From 452b053c807ad06a8baece00c72da2386ef24187 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 16 Jul 2025 14:23:51 +0200
Subject: [PATCH 4/4] Disable contraction vectorization with broadcasts
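
With the flag set, ops implementing the contraction interface now either
vectorize directly into vector.contract or fail, instead of silently
falling back to the generic vectorization path. Broadcasts are expected
to be decomposed before vectorization. As a hypothetical illustration
(not part of this patch), the rejected broadcast case could first be
rewritten with an explicit linalg.broadcast, after which the direct
lowering applies:

  %empty = tensor.empty() : tensor<8x4xf32>
  %lhs = linalg.broadcast ins(%A : tensor<4xf32>)
           outs(%empty : tensor<8x4xf32>) dimensions = [0]
  %0 = linalg.matmul
    ins(%lhs, %B : tensor<8x4xf32>, tensor<4x16xf32>)
    outs(%C : tensor<8x16xf32>) -> tensor<8x16xf32>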

---
 .../Linalg/Transforms/Vectorization.cpp       | 31 +++++++++----------
 .../vectorization/contraction-interface.mlir  |  9 ++----
 2 files changed, 17 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 6e428646eaad1..39c183ff76f0c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2103,7 +2103,7 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
 ///   vector::TransferWriteOp - Writes the result vector back to the
 ///   destination
 /// The operand shapes are preserved and the operands are loaded directly into vectors.
-/// Any further permutations or numerical casting remain within the contraction.
+/// Any further permutations or numerical casting remain within the contraction op.
 static LogicalResult
 vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
                              LinalgOp linalgOp,
@@ -2111,22 +2111,29 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
   Location loc = linalgOp.getLoc();
   MLIRContext *ctx = linalgOp.getContext();
 
+  // For simplicity, contraction vectorization is limited to linalg named ops.
+  // A generic op is ignored, as not every arbitrary contraction body can be
+  // expressed by a vector.contract.
   if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
     return failure();
 
   OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
   Operation *reduceOp = matchLinalgReduction(outOperand);
   auto maybeKind = getCombinerOpKind(reduceOp);
-  if (!maybeKind)
+  if (!maybeKind) {
+    LDBG("Failed to determine contraction combining kind.\n");
     return failure();
+  }
 
   // Check that all dimensions are present in the input operands.
   // Arbitrary broadcasts are not supported by the vector contraction.
-  // Broadcasts are expected to be materialized before vectorization.
+  // Broadcasts are expected to be decomposed before vectorization.
   AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
   AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
-  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
+    LDBG("Contractions with broadcasts are not supported.\n");
     return failure();
+  }
 
   // Load operands.
   SmallVector<Value> vecOperands;
@@ -2659,20 +2666,10 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
               return failure();
             }
 
-            // For simplicity, contraction vectorization is limited to linalg
-            // named ops. A generic op is ignored, as not every arbitrary
-            // contraction body can be expressed by a vector.contract.
             if (createNamedContraction &&
-                isa<ContractionOpInterface>(linalgOp.getOperation())) {
-              // Attempt vectorizing directly into a named contraction.
-              // In case of failure, fall back to the generic path.
-              LogicalResult res = vectorizeAsLinalgContraction(
-                  rewriter, state, linalgOp, results);
-              if (succeeded(res))
-                return success();
-
-              LDBG("Failed to vectorize as a named contraction.\n");
-            }
+                isa<ContractionOpInterface>(linalgOp.getOperation()))
+              return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
+                                                  results);
 
             LDBG("Vectorize generic by broadcasting to the canonical vector "
                  "shape\n");
diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
index a3c8e61a29fdf..f2d09aee87a4b 100644
--- a/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
 
 ///----------------------------------------------------------------------------------------
 /// Tests for vectorizing operations implementing the contraction op interface.
@@ -214,7 +214,7 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-/// Contractions' arbitrary broadcasts are not supported in contraction interface
+/// Contractions with arbitrary broadcasts are not supported in contraction interface
 /// vectorization.
 /// Dimension broadcasts are expected to be decomposed first, which removes the
 /// ambiguity of how the broadcast dimensions could be materialized.
@@ -222,6 +222,7 @@ module attributes {transform.with_named_sequence} {
 
 func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>,
     %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
   %0 = linalg.matmul
     indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast
                      affine_map<(m, n, k) -> (k, n)>,
@@ -231,10 +232,6 @@ func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>,
   return %0 : tensor<8x16xf32>
 }
 
-// CHECK-LABEL: func.func @negative_matmul_broadcast(
-// CHECK-NOT: vector.contract
-// CHECK: vector.multi_reduction
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op


