[Mlir-commits] [mlir] [mlir][vector] Remove MatrixMultiplyOp and FlatTransposeOp from Vector dialect (PR #144307)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Jul 17 08:49:44 PDT 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/144307

From 1a2d4812d7d97ea8bd3e0a563c2f5cdb08f6d902 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 29 May 2025 20:35:03 +0100
Subject: [PATCH 1/5] [mlir][vector] Remove MatrixMultiplyOp and
 FlatTransposeOp from Vector dialect
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch deletes `vector.matrix_multiply` and `vector.flat_transpose`,
which are thin wrappers around the corresponding LLVM intrinsics:
  - `llvm.intr.matrix.multiply`
  - `llvm.intr.matrix.transpose`

These Vector dialect ops did not provide additional semantics or
abstraction beyond the LLVM intrinsics. Their removal simplifies the
lowering pipeline without losing any functionality.

The lowering chains:
  - `vector.contract` → `vector.matrix_multiply` → `llvm.intr.matrix.multiply`
  - `vector.transpose` → `vector.flat_transpose` → `llvm.intr.matrix.transpose`

are now replaced with:
  - `vector.contract` → `llvm.intr.matrix.multiply`
  - `vector.transpose` → `llvm.intr.matrix.transpose`

This was accomplished by directly replacing:
  - `vector::MatrixMultiplyOp` with `LLVM::MatrixMultiplyOp`
  - `vector::FlatTransposeOp` with `LLVM::MatrixTransposeOp`
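
For illustration, a sketch of what this means for the intermediate IR
when lowering the transpose of a `vector<2x4xf32>` and a 2x4 by 4x3
matmul. Shapes and attributes are taken from the tests updated in this
patch; the SSA names are hypothetical, and the snippet shows the old and
new forms side by side, so it is not meant to parse as a single module:

```mlir
// Before: intermediate Vector dialect ops (removed by this patch).
%t_old = vector.flat_transpose %m {columns = 2 : i32, rows = 4 : i32}
  : vector<8xf32> -> vector<8xf32>
%c_old = vector.matrix_multiply %a, %b
  {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32}
  : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>

// After: the LLVM dialect intrinsic ops are created directly.
%t_new = llvm.intr.matrix.transpose %m {columns = 2 : i32, rows = 4 : i32}
  : vector<8xf32> into vector<8xf32>
%c_new = llvm.intr.matrix.multiply %a, %b
  {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32}
  : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
```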

Note: This change introduces a build-time dependency from the Vector
transforms (`MLIRVectorTransforms`) on the LLVM dialect
(`MLIRLLVMDialect`). Ideally, such dependencies should be confined to
dialect conversion (`ConvertVectorToLLVM`). However, moving the lowering
code there would introduce notable churn, so this patch leaves the new
dependency in place for now.
---
 .../VectorToLLVM/ConvertVectorToLLVM.h        |   6 -
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 118 ------------------
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  41 ------
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   2 -
 .../Transforms/EmulateUnsupportedFloats.cpp   |   6 +-
 .../Dialect/Vector/Transforms/CMakeLists.txt  |   1 +
 .../Vector/Transforms/LowerVectorContract.cpp |  11 +-
 .../Transforms/LowerVectorTranspose.cpp       |   2 +-
 .../VectorToLLVM/vector-to-llvm.mlir          |  80 ------------
 mlir/test/Dialect/Vector/invalid.mlir         |  27 ----
 mlir/test/Dialect/Vector/ops.mlir             |  16 ---
 ...tract-to-matrix-intrinsics-transforms.mlir |   4 +-
 .../Vector/vector-transpose-lowering.mlir     |   4 +-
 .../Vector/CPU/flat-transpose-col.mlir        |   8 +-
 .../Vector/CPU/flat-transpose-row.mlir        |   8 +-
 .../Vector/CPU/matrix-multiply-col.mlir       |   2 +-
 .../Vector/CPU/matrix-multiply-row.mlir       |   2 +-
 17 files changed, 26 insertions(+), 312 deletions(-)

diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index f6b09deb4e44c..cfb6cc313bc63 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -13,12 +13,6 @@
 namespace mlir {
 class LLVMTypeConverter;
 
-/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
-/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
-/// will be needed when invoking LLVM.
-void populateVectorToLLVMMatrixConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns);
-
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index cbe490f6e4dd1..38a950023c772 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2788,124 +2788,6 @@ def Vector_PrintOp :
     }];
 }
 
-//===----------------------------------------------------------------------===//
-// Ops used for supporting progressive lowering and conversion type changes.
-// The Ops are typically not used directly by higher level dialects, but are
-// used by intra-dialect rewriting rules to bring vector operations closer
-// to the hardware ISA.
-//===----------------------------------------------------------------------===//
-
-/// Vector dialect matrix multiplication op that operates on flattened 1-D
-/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR.
-/// This may seem redundant with vector.contract but it serves the purposes of
-/// more progressive lowering and localized type conversion on the path:
-///   `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
-def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
-        PredOpTrait<"lhs operand and result have same element type",
-                    TCresVTEtIsSameAsOpBase<0, 0>>,
-        PredOpTrait<"rhs operand and result have same element type",
-                    TCresVTEtIsSameAsOpBase<0, 1>>]>,
-      Arguments<(
-        // TODO: tighten vector element types that make sense.
-        ins FixedVectorOfRankAndType<[1],
-              [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs,
-            FixedVectorOfRankAndType<[1],
-              [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs,
-            I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
-      Results<(
-        outs FixedVectorOfRankAndType<[1],
-               [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)>
-{
-  let summary = "Vector matrix multiplication op that operates on flattened 1-D"
-    " MLIR vectors";
-  let description = [{
-    This is the counterpart of llvm.matrix.multiply in MLIR. It serves the
-    purposes of more progressive lowering and localized type conversion.
-    Higher levels typically lower matrix multiplications into 'vector.contract'
-    operations. Subsequent rewriting rule progressively lower these operations
-    into 'vector.matrix_multiply' operations to bring the operations closer
-    to the hardware ISA.
-
-    The ‘vector.matrix_multiply’ op treats `lhs` as matrix with <lhs_rows> rows
-    and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and
-    <rhs_columns> and multiplies them. The result matrix is returned embedded in
-    the result vector.
-
-    Note, the corresponding LLVM intrinsic, `@llvm.matrix.multiply.*`, does not
-    support scalable vectors. Hence, this Op is only available for fixed-width
-    vectors. Also see:
-
-    http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic
-
-    Example:
-
-    ```mlir
-    %C = vector.matrix_multiply %A, %B
-      { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
-      (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
-    ```
-  }];
-  let builders = [
-   OpBuilder<(ins "Value":$lhs, "Value":$rhs, "unsigned":$lhsRows,
-     "unsigned":$lhsColumns, "unsigned":$rhsColumns),
-   [{
-     $_state.addOperands({lhs, rhs});
-     $_state.addAttribute("lhs_rows",$_builder.getI32IntegerAttr(lhsRows));
-     $_state.addAttribute("lhs_columns",$_builder.getI32IntegerAttr(lhsColumns));
-     $_state.addAttribute("rhs_columns",$_builder.getI32IntegerAttr(rhsColumns));
-     $_state.addTypes(VectorType::get(lhsRows * rhsColumns,
-       ::llvm::cast<VectorType>(lhs.getType()).getElementType()));
-   }]>,
-  ];
-  let assemblyFormat = "$lhs `,` $rhs attr-dict "
-    "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
-}
-
-/// Vector dialect matrix transposition op that operates on flattened 1-D
-/// MLIR vectors. This is the counterpart of llvm.matrix.transpose in MLIR.
-/// This may seem redundant with vector.transpose but it serves the purposes of
-/// more progressive lowering and localized type conversion on the path:
-///   `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
-def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure,
-  PredOpTrait<"source operand and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(
-      // TODO: tighten vector element types that make sense.
-      ins FixedVectorOfRankAndType<[1],
-            [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$matrix,
-          I32Attr:$rows, I32Attr:$columns)>,
-    Results<(
-      outs FixedVectorOfRankAndType<[1],
-             [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> {
-  let summary = "Vector matrix transposition on flattened 1-D MLIR vectors";
-  let description = [{
-    This is the counterpart of llvm.matrix.transpose in MLIR. It serves
-    the purposes of more progressive lowering and localized type conversion.
-    Higher levels typically lower matrix transpositions into 'vector.transpose'
-    operations. Subsequent rewriting rule progressively lower these operations
-    into 'vector.flat_transpose' operations to bring the operations closer
-    to the hardware ISA.
-
-    The `vector.flat_transpose` op treats the 1-D input `matrix` as
-    a 2-D matrix with <rows> rows and <columns> columns, and returns the
-    transposed matrix in flattened form in 'res'.
-
-    Note, the corresponding LLVM intrinsic, `@llvm.matrix.transpose.*`, does not
-    support scalable vectors. Hence, this Op is only available for fixed-width
-    vectors. Also see:
-
-    http://llvm.org/docs/LangRef.html#llvm-matrix-transpose-intrinsic
-
-    Example:
-
-    ```mlir
-    %1 = vector.flat_transpose %0 {columns = 4 : i32, rows = 4 : i32}
-       : vector<16xf32> -> vector<16xf32>
-    ```
-  }];
-  let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
-}
-
 //===----------------------------------------------------------------------===//
 // SplatOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 501d98862672d..760dbea150758 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -184,41 +184,6 @@ class VectorBitCastOpConversion
   }
 };
 
-/// Conversion pattern for a vector.matrix_multiply.
-/// This is lowered directly to the proper llvm.intr.matrix.multiply.
-class VectorMatmulOpConversion
-    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
-public:
-  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
-        matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
-        adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
-        matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
-    return success();
-  }
-};
-
-/// Conversion pattern for a vector.flat_transpose.
-/// This is lowered directly to the proper llvm.intr.matrix.transpose.
-class VectorFlatTransposeOpConversion
-    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
-public:
-  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
-        transOp, typeConverter->convertType(transOp.getRes().getType()),
-        adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
-    return success();
-  }
-};
-
 /// Overloaded utility that replaces a vector.load, vector.store,
 /// vector.maskedload and vector.maskedstore with their respective LLVM
 /// couterparts.
@@ -2071,12 +2036,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
       converter);
 }
 
-void mlir::populateVectorToLLVMMatrixConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<VectorMatmulOpConversion>(converter);
-  patterns.add<VectorFlatTransposeOpConversion>(converter);
-}
-
 namespace {
 struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 549d0210af7ad..150ef41a24373 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -96,11 +96,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   LLVMTypeConverter converter(&getContext(), options);
   RewritePatternSet patterns(&getContext());
   populateVectorTransferLoweringPatterns(patterns);
-  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(
       converter, patterns, reassociateFPReductions, force32BitVectorIndices,
       useVectorAlignment);
-  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
 
   // Architecture specific augmentations.
   LLVMConversionTarget target(getContext());
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 62022bfb7df1e..f14264e2f55f3 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -118,9 +118,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
         return converter.isLegal(op);
       });
   // Manually mark arithmetic-performing vector instructions.
-  target.addDynamicallyLegalOp<
-      vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
-      vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
+  target.addDynamicallyLegalOp<vector::ContractionOp, vector::ReductionOp,
+                               vector::MultiDimReductionOp, vector::FMAOp,
+                               vector::OuterProductOp, vector::ScanOp>(
       [&](Operation *op) { return converter.isLegal(op); });
   target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
                     arith::ConstantOp, vector::SplatOp>();
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 9e287fc109990..ff8a1f9a52910 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -50,6 +50,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   MLIRTensorDialect
   MLIRTransforms
   MLIRVectorDialect
+  MLIRLLVMDialect
   MLIRVectorInterfaces
   MLIRVectorUtils
   )
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 0cd3f786d8101..147821fe2dade 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1271,12 +1271,11 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
 ///    %mtb = maybe_transpose
 ///    %flattened_a = vector.shape_cast %mta
 ///    %flattened_b = vector.shape_cast %mtb
-///    %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
+///    %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
 ///    %mtd = vector.shape_cast %flattened_d
 ///    %d = maybe_untranspose %mtd
 ///    %e = add %c, %d
 /// ```
-/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
 //
 /// This only kicks in when vectorContractLowering is set to `Matmul`.
 /// vector.transpose operations are inserted if the vector.contract op is not a
@@ -1353,8 +1352,12 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
   rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
 
-  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
-                                           rhsColumns);
+  Value mul = rew.create<LLVM::MatrixMultiplyOp>(
+      loc,
+      VectorType::get(lhsRows * rhsColumns,
+                      cast<VectorType>(lhs.getType()).getElementType()),
+      lhs, rhs, lhsRows, lhsColumns, rhsColumns);
+
   mul = rew.create<vector::ShapeCastOp>(
       loc,
       VectorType::get({lhsRows, rhsColumns},
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 1fac967e4a2f9..b5290cb14e079 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -337,7 +337,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
           rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
       auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
       auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
-      Value trans = rewriter.create<vector::FlatTransposeOp>(
+      Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
           loc, flattenedType, matrix, rows, columns);
       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
       return success();
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 64e51f5554628..72810b5dddaa3 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1424,36 +1424,6 @@ func.func @fma_scalable(%vec_1d: vector<[8]xf32>, %vec_2d: vector<2x[4]xf32>, %v
 
   return %0, %1, %2: vector<[8]xf32>, vector<2x[4]xf32>, vector<1x1x[1]xf32>
 }
-// -----
-
-//===----------------------------------------------------------------------===//
-// vector.matrix_multiply
-//===----------------------------------------------------------------------===//
-
-//                          4x16                16x3               4x3
-func.func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
-  %C = vector.matrix_multiply %A, %B
-    { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
-    (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
-  return %C: vector<12xf64>
-}
-// CHECK-LABEL: @matrix_ops
-//       CHECK:   llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
-//  CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
-//  CHECK-SAME: } : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
-
-// -----
-
-func.func @matrix_ops_index(%A: vector<64xindex>, %B: vector<48xindex>) -> vector<12xindex> {
-  %C = vector.matrix_multiply %A, %B
-    { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
-    (vector<64xindex>, vector<48xindex>) -> vector<12xindex>
-  return %C: vector<12xindex>
-}
-// CHECK-LABEL: @matrix_ops_index
-//       CHECK:   llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
-//  CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
-//  CHECK-SAME: } : (vector<64xi64>, vector<48xi64>) -> vector<12xi64>
 
 // -----
 
@@ -1602,56 +1572,6 @@ func.func @create_mask_1d_scalable(%num_elems : index) -> vector<[4]xi1> {
 
 // -----
 
-//===----------------------------------------------------------------------===//
-// vector.flat_transpose
-//===----------------------------------------------------------------------===//
-
-func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
-  %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
-     : vector<16xf32> -> vector<16xf32>
-  return %0 : vector<16xf32>
-}
-
-// CHECK-LABEL: func @flat_transpose
-// CHECK-SAME:  %[[A:.*]]: vector<16xf32>
-// CHECK:       %[[T:.*]] = llvm.intr.matrix.transpose %[[A]]
-// CHECK-SAME:      {columns = 4 : i32, rows = 4 : i32} :
-// CHECK-SAME:      vector<16xf32> into vector<16xf32>
-// CHECK:       return %[[T]] : vector<16xf32>
-
-// -----
-
-func.func @flat_transpose_index(%arg0: vector<16xindex>) -> vector<16xindex> {
-  %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
-     : vector<16xindex> -> vector<16xindex>
-  return %0 : vector<16xindex>
-}
-// CHECK-LABEL: func @flat_transpose_index
-// CHECK-SAME:  %[[A:.*]]: vector<16xindex>
-// CHECK:       %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<16xindex> to vector<16xi64>
-// CHECK:       %[[T1:.*]] = llvm.intr.matrix.transpose %[[T0]]
-// CHECK-SAME:      {columns = 4 : i32, rows = 4 : i32} :
-// CHECK-SAME:      vector<16xi64> into vector<16xi64>
-// CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<16xi64> to vector<16xindex>
-// CHECK:       return %[[T2]] : vector<16xindex>
-
-// -----
-
-func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
-  %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
-     : vector<16xf32> -> vector<16xf32>
-  return %0 : vector<16xf32>
-}
-
-// CHECK-LABEL: func @flat_transpose
-// CHECK-SAME:  %[[A:.*]]: vector<16xf32>
-// CHECK:       %[[T:.*]] = llvm.intr.matrix.transpose %[[A]]
-// CHECK-SAME:      {columns = 4 : i32, rows = 4 : i32} :
-// CHECK-SAME:      vector<16xf32> into vector<16xf32>
-// CHECK:       return %[[T]] : vector<16xf32>
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // vector.gather
 //
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5038646e1f026..dd00fd70ccc21 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1321,13 +1321,6 @@ func.func @transpose_dim_size_mismatch(%arg0: vector<11x7x3x2xi32>) {
 
 // -----
 
-func.func @flat_transpose_type_mismatch(%arg0: vector<16xf32>) {
-  // expected-error at +1 {{'vector.flat_transpose' op failed to verify that source operand and result have same element type}}
-  %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } : vector<16xf32> -> vector<16xf64>
-}
-
-// -----
-
 func.func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) {
   // expected-error at +1 {{expects operand to be a memref with identity layout}}
   %0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref<vector<4x3xf32>>
@@ -1939,26 +1932,6 @@ func.func @invalid_step_2d() {
 
 // -----
 
-func.func @matrix_multiply_scalable(%a: vector<[4]xf64>, %b: vector<4xf64>) {
-  // expected-error @+1 {{'vector.matrix_multiply' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[4]xf64>'}}
-  %c = vector.matrix_multiply %a, %b {
-    lhs_rows = 2: i32,
-    lhs_columns = 2: i32 ,
-    rhs_columns = 2: i32 }
-  : (vector<[4]xf64>, vector<4xf64>) -> vector<4xf64>
-
-  return
-}
-
-// -----
-
-func.func @flat_transpose_scalable(%arg0: vector<[16]xf32>) -> vector<[16]xf32> {
-  // expected-error @+1 {{'vector.flat_transpose' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[16]xf32>'}}
-  %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
-     : vector<[16]xf32> -> vector<[16]xf32>
-  return %0 : vector<[16]xf32>
-}
-
 //===----------------------------------------------------------------------===//
 // vector.splat
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 10bf0f1620568..a24fe410c1440 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -738,22 +738,6 @@ func.func @transpose_int_0d(%arg0: vector<i32>) -> vector<i32> {
   return %0 : vector<i32>
 }
 
-// CHECK-LABEL: @flat_transpose_fp
-func.func @flat_transpose_fp(%arg0: vector<16xf32>) -> vector<16xf32> {
-  // CHECK: %[[X:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
-  %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } : vector<16xf32> -> vector<16xf32>
-  // CHECK: return %[[X]] : vector<16xf32>
-  return %0 : vector<16xf32>
-}
-
-// CHECK-LABEL: @flat_transpose_int
-func.func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
-  // CHECK: %[[X:.*]] = vector.flat_transpose %{{.*}} {columns = 8 : i32, rows = 2 : i32} : vector<16xi32> -> vector<16xi32>
-  %0 = vector.flat_transpose %arg0 { rows = 2: i32, columns = 8: i32 } : vector<16xi32> -> vector<16xi32>
-  // CHECK: return %[[X]] : vector<16xi32>
-  return %0 : vector<16xi32>
-}
-
 // CHECK-LABEL: @vector_load_and_store_0d_scalar_memref
 func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
                                                   %i : index, %j : index) {
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
index 08ac2ac5bb7d5..6960ab33925e5 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
@@ -29,7 +29,7 @@
 //      CHECK:  %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
 //      CHECK:  %[[b6:.*]] = vector.extract %[[B]][3] : vector<3xf32> from vector<4x3xf32>
 //      CHECK:  %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
-//      CHECK:  %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
+//      CHECK:  %[[mm1:.*]] = llvm.intr.matrix.multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
 //      CHECK:  %[[mm2:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
 //      CHECK:  %[[mm3:.*]] = vector.insert %[[mm2]], %[[ub_1]] [0] : vector<3xf32> into vector<2x3xf32>
 //      CHECK:  %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
@@ -44,7 +44,7 @@ func.func @matmul(%arg0: vector<2x4xf32>,
 }
 
 // CHECK-LABEL: func @matmul_scalable
-// CHECK-NOT: vector.matrix_multiply
+// CHECK-NOT: llvm.intr.matrix.multiply
 func.func @matmul_scalable(%arg0: vector<2x4xf32>,
                            %arg1: vector<4x[3]xf32>,
                            %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index a730f217f027d..ca130538c908a 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -139,7 +139,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: func @transpose(
 func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
   // CHECK:       vector.shape_cast {{.*}} : vector<2x4xf32> to vector<8xf32>
-  // CHECK:       vector.flat_transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> -> vector<8xf32>
+  // CHECK:       llvm.intr.matrix.transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> into vector<8xf32>
   // CHECK:       vector.shape_cast {{.*}} : vector<8xf32> to vector<4x2xf32>
   %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
   return %0 : vector<4x2xf32>
@@ -150,7 +150,7 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
 // CHECK-LABEL: func @transpose_scalable(
 func.func @transpose_scalable(%arg0: vector<2x[4]xf32>) -> vector<[4]x2xf32> {
   // CHECK-NOT:       vector.shape_cast
-  // CHECK-NOT:       vector.flat_transpose
+  // CHECK-NOT:       llvm.intr.matrix.transpose
   // CHECK:           vector.transpose
   %0 = vector.transpose %arg0, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
   return %0 : vector<[4]x2xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-col.mlir b/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-col.mlir
index b414242b34cc0..86bd0b1e09763 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-col.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-col.mlir
@@ -57,10 +57,10 @@ func.func @entry() {
   // ( 1, 4 )    ->  ( 3, 4, 5 )
   // ( 2, 5 )
   //
-  %d = vector.flat_transpose %a { rows = 2: i32, columns = 2: i32 } : vector<4xf64> -> vector<4xf64>
-  %e = vector.flat_transpose %b { rows = 2: i32, columns = 2: i32 } : vector<4xf64> -> vector<4xf64>
-  %f = vector.flat_transpose %c { rows = 2: i32, columns = 3: i32 } : vector<6xf64> -> vector<6xf64>
-  %g = vector.flat_transpose %c { rows = 3: i32, columns = 2: i32 } : vector<6xf64> -> vector<6xf64>
+  %d = llvm.intr.matrix.transpose %a { rows = 2: i32, columns = 2: i32 } : vector<4xf64> into vector<4xf64>
+  %e = llvm.intr.matrix.transpose %b { rows = 2: i32, columns = 2: i32 } : vector<4xf64> into vector<4xf64>
+  %f = llvm.intr.matrix.transpose %c { rows = 2: i32, columns = 3: i32 } : vector<6xf64> into vector<6xf64>
+  %g = llvm.intr.matrix.transpose %c { rows = 3: i32, columns = 2: i32 } : vector<6xf64> into vector<6xf64>
 
   vector.print %d : vector<4xf64>
   vector.print %e : vector<4xf64>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-row.mlir b/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-row.mlir
index 95b178e04a2bb..55103bc686fb2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-row.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-row.mlir
@@ -57,10 +57,10 @@ func.func @entry() {
   // ( 2, 3 )    ->  ( 1, 3, 5 )
   // ( 4, 5 )
   //
-  %d = vector.flat_transpose %a { rows = 2: i32, columns = 2: i32 } : vector<4xf64> -> vector<4xf64>
-  %e = vector.flat_transpose %b { rows = 2: i32, columns = 2: i32 } : vector<4xf64> -> vector<4xf64>
-  %f = vector.flat_transpose %c { rows = 2: i32, columns = 3: i32 } : vector<6xf64> -> vector<6xf64>
-  %g = vector.flat_transpose %c { rows = 3: i32, columns = 2: i32 } : vector<6xf64> -> vector<6xf64>
+  %d = llvm.intr.matrix.transpose %a { rows = 2: i32, columns = 2: i32 } : vector<4xf64> into vector<4xf64>
+  %e = llvm.intr.matrix.transpose %b { rows = 2: i32, columns = 2: i32 } : vector<4xf64> into vector<4xf64>
+  %f = llvm.intr.matrix.transpose %c { rows = 2: i32, columns = 3: i32 } : vector<6xf64> into vector<6xf64>
+  %g = llvm.intr.matrix.transpose %c { rows = 3: i32, columns = 2: i32 } : vector<6xf64> into vector<6xf64>
 
   vector.print %d : vector<4xf64>
   vector.print %e : vector<4xf64>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-col.mlir b/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-col.mlir
index 8f75ec98465ca..09941192cbc42 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-col.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-col.mlir
@@ -39,7 +39,7 @@ func.func @entry() {
   //           x          =                  |/ | column-major!
   // ( 1, 3 )     (5, 7)     ( 19, 27 )
   //
-  %c = vector.matrix_multiply %a, %b
+  %c = llvm.intr.matrix.multiply %a, %b
       { lhs_rows = 2: i32, lhs_columns = 2: i32 , rhs_columns = 2: i32 }
       : (vector<4xf64>, vector<4xf64>) -> vector<4xf64>
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-row.mlir b/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-row.mlir
index b7d27c45226ef..d5f511c8ac119 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-row.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-row.mlir
@@ -39,7 +39,7 @@ func.func @entry() {
   //           x          =
   // ( 2, 3 )     (6, 7)     ( 26, 31 )
   //
-  %c = vector.matrix_multiply %a, %b
+  %c = llvm.intr.matrix.multiply %a, %b
       { lhs_rows = 2: i32, lhs_columns = 2: i32 , rhs_columns = 2: i32 }
       : (vector<4xf64>, vector<4xf64>) -> vector<4xf64>
 

>From 4d86070cf67b75563bdeeda1e1c8428213f1815d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 9 Jul 2025 19:45:42 +0100
Subject: [PATCH 2/5] fixup! [mlir][vector] Remove MatrixMultiplyOp and
 FlatTransposeOp from Vector dialect

Remove the dependency of the Vector transforms on the LLVM dialect by moving
the matrix-intrinsic lowerings into ConvertVectorToLLVM.
---
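Illustration (not part of the patch): the `TransposeOpLowering` pattern moved
in this fixup flattens the row-major input with `vector.shape_cast` and then
passes the *result* shape as `rows`/`columns` to `llvm.intr.matrix.transpose`.
The standalone C++ sketch below emulates that sequence on plain buffers to
show why it yields a transpose. It assumes the intrinsic interprets its
operand in LLVM's default column-major layout; `flatTransposeColMajor` and
`transposeRowMajor2D` are hypothetical helper names, not MLIR or LLVM APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Emulates a flat transpose: `in` holds a `rows x columns` matrix in
// column-major order; the result holds the `columns x rows` transpose,
// also in column-major order. (Assumed to mirror the intrinsic's default
// layout; this is an illustration, not the LLVM implementation.)
std::vector<double> flatTransposeColMajor(const std::vector<double> &in,
                                          std::size_t rows,
                                          std::size_t columns) {
  assert(in.size() == rows * columns);
  std::vector<double> out(in.size());
  for (std::size_t c = 0; c < columns; ++c)
    for (std::size_t r = 0; r < rows; ++r)
      // in(r, c) lives at c * rows + r; out(c, r) lives at r * columns + c.
      out[r * columns + c] = in[c * rows + r];
  return out;
}

// Emulates the lowering of a 2-D transpose of a row-major
// `srcRows x srcCols` matrix: flatten (a no-op on a flat buffer), call the
// flat transpose with the *result* shape, then read the buffer back as a
// row-major `srcCols x srcRows` matrix.
std::vector<double> transposeRowMajor2D(const std::vector<double> &in,
                                        std::size_t srcRows,
                                        std::size_t srcCols) {
  // Result shape is srcCols x srcRows, so rows = srcCols, columns = srcRows.
  return flatTransposeColMajor(in, /*rows=*/srcCols, /*columns=*/srcRows);
}
```

For a 2x4 input holding 0..7 in row-major order, `transposeRowMajor2D(in, 2, 4)`
produces `{0, 4, 1, 5, 2, 6, 3, 7}`, i.e. the 4x2 transpose in row-major order.
This corresponds to the `rows = 4`, `columns = 2` attributes checked in the
transpose test.
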
 .../Vector/Transforms/LoweringPatterns.h      |   6 +
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 204 ++++++++++++++++++
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   7 +
 .../Dialect/Vector/Transforms/CMakeLists.txt  |   1 -
 .../Vector/Transforms/LowerVectorContract.cpp | 171 +--------------
 .../Transforms/LowerVectorTranspose.cpp       |  15 --
 ...tract-to-matrix-intrinsics-transforms.mlir |  91 ++++----
 .../Vector/vector-transpose-lowering.mlir     |  32 ---
 ...nspose-to-matrix-intrinsics-transform.mlir |  18 ++
 9 files changed, 287 insertions(+), 258 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 6761cd65c5009..fd742188246ce 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -303,6 +303,12 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
 void populateVectorToFromElementsToShuffleTreePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+/// Add a `vector.contract` -> `llvm.intr.matrix.multiply` lowering pattern.
+void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns);
+
+/// Add a `vector.transpose` -> `llvm.intr.matrix.transpose` lowering pattern.
+void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 760dbea150758..6e6400ba6e60e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -2000,6 +2000,190 @@ struct VectorScalableStepOpLowering
   }
 };
 
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+///    %flattened_a = vector.shape_cast %a
+///    %flattened_b = vector.shape_cast %b
+///    %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
+///    %d = vector.shape_cast %flattened_d
+///    %e = add %c, %d
+/// ```
+/// The LLVM matrix intrinsic is created directly, with no intermediate op.
+///
+/// This only kicks in when vectorContractLowering is set to Matmul; operands
+/// that are not row-major are transposed first (see the definition below).
+class ContractionOpToMatmulOpLowering
+    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
+public:
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  ContractionOpToMatmulOpLowering(
+      vector::VectorContractLowering vectorContractLowering,
+      MLIRContext *context, PatternBenefit benefit = 100)
+      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
+};
+
+/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+///    %mta = maybe_transpose
+///    %mtb = maybe_transpose
+///    %flattened_a = vector.shape_cast %mta
+///    %flattened_b = vector.shape_cast %mtb
+///    %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
+///    %mtd = vector.shape_cast %flattened_d
+///    %d = maybe_untranspose %mtd
+///    %e = add %c, %d
+/// ```
+///
+/// This only kicks in when vectorContractLowering is set to `Matmul`.
+/// vector.transpose operations are inserted if the vector.contract op is not a
+/// row-major matrix multiply.
+///
+/// Scalable vectors are not supported.
+FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rew) const {
+  // TODO: Support vector.mask.
+  if (maskOp)
+    return failure();
+
+  auto iteratorTypes = op.getIteratorTypes().getValue();
+  if (!isParallelIterator(iteratorTypes[0]) ||
+      !isParallelIterator(iteratorTypes[1]) ||
+      !isReductionIterator(iteratorTypes[2]))
+    return failure();
+
+  Type opResType = op.getType();
+  VectorType vecType = dyn_cast<VectorType>(opResType);
+  if (vecType && vecType.isScalable()) {
+    // Note - this is sufficient to reject all cases with scalable vectors.
+    return failure();
+  }
+
+  Type elementType = op.getLhsType().getElementType();
+  if (!elementType.isIntOrFloat())
+    return failure();
+
+  Type dstElementType = vecType ? vecType.getElementType() : opResType;
+  if (elementType != dstElementType)
+    return failure();
+
+  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
+  // Bail out if the contraction cannot be put in this form.
+  MLIRContext *ctx = op.getContext();
+  Location loc = op.getLoc();
+  AffineExpr m, n, k;
+  bindDims(rew.getContext(), m, n, k);
+  // LHS must be A(m, k) or A(k, m).
+  Value lhs = op.getLhs();
+  auto lhsMap = op.getIndexingMapsArray()[0];
+  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
+    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
+  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
+    return failure();
+
+  // RHS must be B(k, n) or B(n, k).
+  Value rhs = op.getRhs();
+  auto rhsMap = op.getIndexingMapsArray()[1];
+  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
+    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
+  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
+    return failure();
+
+  // At this point lhs and rhs are in row-major.
+  VectorType lhsType = cast<VectorType>(lhs.getType());
+  VectorType rhsType = cast<VectorType>(rhs.getType());
+  int64_t lhsRows = lhsType.getDimSize(0);
+  int64_t lhsColumns = lhsType.getDimSize(1);
+  int64_t rhsColumns = rhsType.getDimSize(1);
+
+  Type flattenedLHSType =
+      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
+  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
+
+  Type flattenedRHSType =
+      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
+  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
+
+  Value mul = rew.create<LLVM::MatrixMultiplyOp>(
+      loc,
+      VectorType::get(lhsRows * rhsColumns,
+                      cast<VectorType>(lhs.getType()).getElementType()),
+      lhs, rhs, lhsRows, lhsColumns, rhsColumns);
+
+  mul = rew.create<vector::ShapeCastOp>(
+      loc,
+      VectorType::get({lhsRows, rhsColumns},
+                      getElementTypeOrSelf(op.getAcc().getType())),
+      mul);
+
+  // ACC must be C(m, n) or C(n, m).
+  auto accMap = op.getIndexingMapsArray()[2];
+  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
+    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
+  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
+    llvm_unreachable("invalid contraction semantics");
+
+  Value res =
+      isa<IntegerType>(elementType)
+          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
+          : static_cast<Value>(
+                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
+
+  return res;
+}
+
+/// Lowers a 2-D `vector.transpose` to the LLVM matrix-transpose intrinsic.
+/// One:
+///   %x = vector.transpose %y, [1, 0]
+/// is replaced by:
+///   %0 = vector.shape_cast %y      // flatten to 1-D
+///   %1 = llvm.intr.matrix.transpose %0 {rows = .., columns = ..}
+///   %x = vector.shape_cast %1      // restore the 2-D result shape
+///
+/// Scalable vectors are not supported.
+class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern<TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value input = op.getVector();
+    VectorType inputType = op.getSourceVectorType();
+    VectorType resType = op.getResultVectorType();
+
+    if (inputType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "This lowering does not support scalable vectors");
+
+    // Set up convenience transposition table.
+    ArrayRef<int64_t> transp = op.getPermutation();
+
+    if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
+      return failure();
+    }
+
+    Type flattenedType =
+        VectorType::get(resType.getNumElements(), resType.getElementType());
+    auto matrix =
+        rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
+    auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
+    auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
+    Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
+        loc, flattenedType, matrix, rows, columns);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2007,6 +2191,26 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
   patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
 }
 
+/// Pattern to lower `vector.contract` to `llvm.intr.matrix.multiply`.
+///
+/// Given the high benefit, this will be prioritised over other
+/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
+/// only run this registration conditionally.
+void mlir::vector::populateVectorContractToMatrixMultiply(
+    RewritePatternSet &patterns) {
+  patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), 100);
+}
+
+/// Pattern to lower `vector.transpose` to `llvm.intr.matrix.transpose`.
+///
+/// Given the high benefit, this will be prioritised over other
+/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
+/// only run this registration conditionally.
+void mlir::vector::populateVectorTransposeToFlatTranspose(
+    RewritePatternSet &patterns) {
+  patterns.add<TransposeOpLowering>(patterns.getContext(), 100);
+}
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 150ef41a24373..b8eb0f58e98ca 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -70,10 +70,17 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorBitCastLoweringPatterns(patterns);
     populateVectorBroadcastLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
+    if (vectorContractLowering == vector::VectorContractLowering::Matmul) {
+      populateVectorContractToMatrixMultiply(patterns);
+    }
     populateVectorMaskOpLoweringPatterns(patterns);
     populateVectorShapeCastLoweringPatterns(patterns);
     populateVectorInterleaveLoweringPatterns(patterns);
     populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
+    if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat) {
+      populateVectorTransposeToFlatTranspose(patterns);
+    }
+    populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
     populateVectorMaskMaterializationPatterns(patterns,
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index ff8a1f9a52910..9e287fc109990 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -50,7 +50,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
   MLIRTensorDialect
   MLIRTransforms
   MLIRVectorDialect
-  MLIRLLVMDialect
   MLIRVectorInterfaces
   MLIRVectorUtils
   )
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 147821fe2dade..fc6c90f5132c7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -201,49 +201,6 @@ static Value createMul(Location loc, Value x, Value y, bool isInt,
 
 namespace {
 
-/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to:
-/// ```
-///    %flattened_a = vector.shape_cast %a
-///    %flattened_b = vector.shape_cast %b
-///    %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
-///    %d = vector.shape_cast %%flattened_d
-///    %e = add %c, %d
-/// ```
-/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
-//
-/// This only kicks in when vectorContractLowering is set to Matmul and
-/// the vector.contract op is a row-major matrix multiply.
-class ContractionOpToMatmulOpLowering
-    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
-public:
-  using MaskableOpRewritePattern::MaskableOpRewritePattern;
-
-  using FilterConstraintType =
-      std::function<LogicalResult(vector::ContractionOp op)>;
-
-  static LogicalResult defaultFilter(vector::ContractionOp op) {
-    return success();
-  }
-
-  ContractionOpToMatmulOpLowering(
-      vector::VectorContractLowering vectorContractLowering,
-      MLIRContext *context, PatternBenefit benefit = 1,
-      FilterConstraintType constraint = defaultFilter)
-      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorContractLowering(vectorContractLowering),
-        filter(std::move(constraint)) {}
-
-  FailureOr<Value>
-  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
-                            PatternRewriter &rewriter) const override;
-
-private:
-  /// Options to control the vector patterns.
-  vector::VectorContractLowering vectorContractLowering;
-  FilterConstraintType filter;
-};
-
 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
 /// semantics to a reduction_size-unrolled sequence:
 /// ```
@@ -939,24 +896,18 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
   // TODO: implement benefits, cost models.
   MLIRContext *ctx = op.getContext();
 
-  ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx);
+  ContractionOpToOuterProductOpLowering pat1(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal1 =
       pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
   if (!failed(newVal1))
     return newVal1;
 
-  ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx);
+  ContractionOpToDotLowering pat2(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal2 =
       pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
   if (!failed(newVal2))
     return newVal2;
 
-  ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx);
-  FailureOr<Value> newVal3 =
-      pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
-  if (!failed(newVal3))
-    return newVal3;
-
   ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal4 =
       pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
@@ -1264,121 +1215,6 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
   }
 };
 
-/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to:
-/// ```
-///    %mta = maybe_transpose
-///    %mtb = maybe_transpose
-///    %flattened_a = vector.shape_cast %mta
-///    %flattened_b = vector.shape_cast %mtb
-///    %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
-///    %mtd = vector.shape_cast %flattened_d
-///    %d = maybe_untranspose %mtd
-///    %e = add %c, %d
-/// ```
-//
-/// This only kicks in when vectorContractLowering is set to `Matmul`.
-/// vector.transpose operations are inserted if the vector.contract op is not a
-/// row-major matrix multiply.
-///
-/// Scalable vectors are not supported.
-FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
-    vector::ContractionOp op, MaskingOpInterface maskOp,
-    PatternRewriter &rew) const {
-  // TODO: Support vector.mask.
-  if (maskOp)
-    return failure();
-
-  if (vectorContractLowering != vector::VectorContractLowering::Matmul)
-    return failure();
-  if (failed(filter(op)))
-    return failure();
-
-  auto iteratorTypes = op.getIteratorTypes().getValue();
-  if (!isParallelIterator(iteratorTypes[0]) ||
-      !isParallelIterator(iteratorTypes[1]) ||
-      !isReductionIterator(iteratorTypes[2]))
-    return failure();
-
-  Type opResType = op.getType();
-  VectorType vecType = dyn_cast<VectorType>(opResType);
-  if (vecType && vecType.isScalable()) {
-    // Note - this is sufficient to reject all cases with scalable vectors.
-    return failure();
-  }
-
-  Type elementType = op.getLhsType().getElementType();
-  if (!elementType.isIntOrFloat())
-    return failure();
-
-  Type dstElementType = vecType ? vecType.getElementType() : opResType;
-  if (elementType != dstElementType)
-    return failure();
-
-  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
-  // Bail out if the contraction cannot be put in this form.
-  MLIRContext *ctx = op.getContext();
-  Location loc = op.getLoc();
-  AffineExpr m, n, k;
-  bindDims(rew.getContext(), m, n, k);
-  // LHS must be A(m, k) or A(k, m).
-  Value lhs = op.getLhs();
-  auto lhsMap = op.getIndexingMapsArray()[0];
-  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
-    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
-  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
-    return failure();
-
-  // RHS must be B(k, n) or B(n, k).
-  Value rhs = op.getRhs();
-  auto rhsMap = op.getIndexingMapsArray()[1];
-  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
-    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
-  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
-    return failure();
-
-  // At this point lhs and rhs are in row-major.
-  VectorType lhsType = cast<VectorType>(lhs.getType());
-  VectorType rhsType = cast<VectorType>(rhs.getType());
-  int64_t lhsRows = lhsType.getDimSize(0);
-  int64_t lhsColumns = lhsType.getDimSize(1);
-  int64_t rhsColumns = rhsType.getDimSize(1);
-
-  Type flattenedLHSType =
-      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
-  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
-
-  Type flattenedRHSType =
-      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
-  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
-
-  Value mul = rew.create<LLVM::MatrixMultiplyOp>(
-      loc,
-      VectorType::get(lhsRows * rhsColumns,
-                      cast<VectorType>(lhs.getType()).getElementType()),
-      lhs, rhs, lhsRows, lhsColumns, rhsColumns);
-
-  mul = rew.create<vector::ShapeCastOp>(
-      loc,
-      VectorType::get({lhsRows, rhsColumns},
-                      getElementTypeOrSelf(op.getAcc().getType())),
-      mul);
-
-  // ACC must be C(m, n) or C(n, m).
-  auto accMap = op.getIndexingMapsArray()[2];
-  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
-    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
-  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
-    llvm_unreachable("invalid contraction semantics");
-
-  Value res =
-      isa<IntegerType>(elementType)
-          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
-          : static_cast<Value>(
-                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
-
-  return res;
-}
 } // namespace
 
 void mlir::vector::populateVectorContractLoweringPatterns(
@@ -1387,8 +1223,7 @@ void mlir::vector::populateVectorContractLoweringPatterns(
     bool disableOuterProductLowering) {
   if (!disableOuterProductLowering)
     patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
-  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
-               ContractionOpToOuterProductOpLowering>(
+  patterns.add<ContractionOpLowering, ContractionOpToOuterProductOpLowering>(
       vectorContractLoweringOption, patterns.getContext(), benefit);
 }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index b5290cb14e079..bb9a6832146e8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -328,21 +328,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
-    // Handle a true 2-D matrix transpose differently when requested.
-    if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat &&
-        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
-      Type flattenedType =
-          VectorType::get(resType.getNumElements(), resType.getElementType());
-      auto matrix =
-          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
-      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
-      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
-      Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
-          loc, flattenedType, matrix, rows, columns);
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
-      return success();
-    }
-
     // Generate unrolled extract/insert ops. We do not unroll the rightmost
     // (i.e., highest-order) dimensions that are not transposed and leave them
     // in vector form to improve performance. Therefore, we prune those
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
index 6960ab33925e5..3950e54006eec 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --convert-vector-to-llvm='vector-contract-lowering=matmul' | FileCheck %s
 
 #matmat_accesses = [
   affine_map<(i, j, k) -> (i, k)>,
@@ -10,31 +10,54 @@
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
-// CHECK-LABEL: func @matmul
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
-//  CHECK-DAG:  %[[ub:.*]] = ub.poison : vector<8xf32>
-//  CHECK-DAG:  %[[ub_0:.*]] = ub.poison : vector<12xf32>
-//  CHECK-DAG:  %[[ub_1:.*]] = ub.poison : vector<2x3xf32>
-//      CHECK:  %[[a0:.*]] = vector.extract %[[A]][0] : vector<4xf32> from vector<2x4xf32>
-//      CHECK:  %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[ub]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
-//      CHECK:  %[[a2:.*]] = vector.extract %[[A]][1] : vector<4xf32> from vector<2x4xf32>
-//      CHECK:  %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
-//      CHECK:  %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<4x3xf32>
-//      CHECK:  %[[b1:.*]] = vector.insert_strided_slice %[[b0]], %[[ub_0]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32>
-//      CHECK:  %[[b2:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<4x3xf32>
-//      CHECK:  %[[b3:.*]] = vector.insert_strided_slice %[[b2]], %[[b1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32>
-//      CHECK:  %[[b4:.*]] = vector.extract %[[B]][2] : vector<3xf32> from vector<4x3xf32>
-//      CHECK:  %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
-//      CHECK:  %[[b6:.*]] = vector.extract %[[B]][3] : vector<3xf32> from vector<4x3xf32>
-//      CHECK:  %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
-//      CHECK:  %[[mm1:.*]] = llvm.intr.matrix.multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
-//      CHECK:  %[[mm2:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
-//      CHECK:  %[[mm3:.*]] = vector.insert %[[mm2]], %[[ub_1]] [0] : vector<3xf32> into vector<2x3xf32>
-//      CHECK:  %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
-//      CHECK:  %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
-//      CHECK:  %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32>
+// CHECK-LABEL:   func.func @matmul(
+// CHECK-SAME:                      %[[ARG0:.*]]: vector<2x4xf32>,
+// CHECK-SAME:                      %[[ARG1:.*]]: vector<4x3xf32>,
+// CHECK-SAME:                      %[[ARG2:.*]]: vector<2x3xf32>) -> vector<2x3xf32> {
+// CHECK:           %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<4x3xf32> to !llvm.array<4 x vector<3xf32>>
+// CHECK:           %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
+// CHECK:           %[[VAL_2:.*]] = ub.poison : vector<2x3xf32>
+// CHECK:           %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>>
+// CHECK:           %[[POISON_RHS:.*]] = ub.poison : vector<12xf32>
+// CHECK:           %[[POISON_LHS:.*]] = ub.poison : vector<8xf32>
+
+// ===> Extract LHS
+//       | ROW_1 |
+//       | ----- | --> | ROW_1 | ROW_2 |
+//       | ROW_2 |
+//
+// CHECK:           %[[LHS_ROW_1:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<2 x vector<4xf32>>
+// CHECK:           %[[TP_1:.*]] = llvm.shufflevector %[[LHS_ROW_1]], %[[LHS_ROW_1]] [0, 1, 2, 3, 0, 0, 0, 0] : vector<4xf32>
+// CHECK:           %[[TP_2:.*]] = llvm.shufflevector %[[TP_1]], %[[POISON_LHS]] [0, 1, 2, 3, 12, 13, 14, 15] : vector<8xf32>
+// CHECK:           %[[LHS_ROW_2:.*]] = llvm.extractvalue %[[VAL_1]][1] : !llvm.array<2 x vector<4xf32>>
+// CHECK:           %[[TP_3:.*]] = llvm.shufflevector %[[LHS_ROW_2]], %[[LHS_ROW_2]] [0, 1, 2, 3, 0, 0, 0, 0] : vector<4xf32>
+// CHECK:           %[[LHS:.*]] = llvm.shufflevector %[[TP_3]], %[[TP_2]] [8, 9, 10, 11, 0, 1, 2, 3] : vector<8xf32>
+
+// ===> Extract RHS
+//       | ROW_1 |
+//       | ----- |
+//       | ROW_2 |
+//       | ----- | --> | ROW_1 | ROW_2 | ROW_3 | ROW_4 |
+//       | ROW_3 |
+//       | ----- |
+//       | ROW_4 |
+// CHECK:           %[[RHS_ROW_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.array<4 x vector<3xf32>>
+// CHECK:           %[[TP_4:.*]] = llvm.shufflevector %[[RHS_ROW_1]], %[[RHS_ROW_1]] [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<3xf32>
+// CHECK:           %[[TP_5:.*]] = llvm.shufflevector %[[TP_4]], %[[POISON_RHS]] [0, 1, 2, 15, 16, 17, 18, 19, 20, 21, 22, 23] : vector<12xf32>
+// CHECK:           %[[RHS_ROW_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.array<4 x vector<3xf32>>
+// CHECK:           %[[TP_6:.*]] = llvm.shufflevector %[[RHS_ROW_2]], %[[RHS_ROW_2]] [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<3xf32>
+// CHECK:           %[[TP_7:.*]] = llvm.shufflevector %[[TP_6]], %[[TP_5]] [12, 13, 14, 0, 1, 2, 18, 19, 20, 21, 22, 23] : vector<12xf32>
+// CHECK:           %[[RHS_ROW_3:.*]] = llvm.extractvalue %[[VAL_0]][2] : !llvm.array<4 x vector<3xf32>>
+// CHECK:           %[[TP_8:.*]] = llvm.shufflevector %[[RHS_ROW_3]], %[[RHS_ROW_3]] [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<3xf32>
+// CHECK:           %[[TP_9:.*]] = llvm.shufflevector %[[TP_8]], %[[TP_7]] [12, 13, 14, 15, 16, 17, 0, 1, 2, 21, 22, 23] : vector<12xf32>
+// CHECK:           %[[RHS_ROW_4:.*]] = llvm.extractvalue %[[VAL_0]][3] : !llvm.array<4 x vector<3xf32>>
+// CHECK:           %[[TP_10:.*]] = llvm.shufflevector %[[RHS_ROW_4]], %[[RHS_ROW_4]] [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<3xf32>
+// CHECK:           %[[RHS:.*]] = llvm.shufflevector %[[TP_10]], %[[TP_9]] [12, 13, 14, 15, 16, 17, 18, 19, 20, 0, 1, 2] : vector<12xf32>
+
+// ===> Matrix multiply
+// CHECK:           %[[MM:.*]] = llvm.intr.matrix.multiply %[[LHS]], %[[RHS]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
+// CHECK:           %[[RES:.*]] = arith.addf %[[ARG2]], %{{.*}} : vector<2x3xf32>
+// CHECK:           return %[[RES]] : vector<2x3xf32>
 func.func @matmul(%arg0: vector<2x4xf32>,
                   %arg1: vector<4x3xf32>,
                   %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
@@ -52,19 +75,3 @@ func.func @matmul_scalable(%arg0: vector<2x4xf32>,
     : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %f = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
-
-    transform.apply_patterns to %f {
-      transform.apply_patterns.vector.lower_contraction lowering_strategy = "matmulintrinsics"
-    } : !transform.any_op
-
-    transform.apply_patterns to %f {
-      transform.apply_patterns.vector.lower_shape_cast
-    } : !transform.any_op
-    transform.yield
-  }
-}
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index ca130538c908a..7838aad1825bc 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -136,38 +136,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: func @transpose(
-func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
-  // CHECK:       vector.shape_cast {{.*}} : vector<2x4xf32> to vector<8xf32>
-  // CHECK:       llvm.intr.matrix.transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> into vector<8xf32>
-  // CHECK:       vector.shape_cast {{.*}} : vector<8xf32> to vector<4x2xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
-  return %0 : vector<4x2xf32>
-}
-
-/// Scalable vectors are not supported
-
-// CHECK-LABEL: func @transpose_scalable(
-func.func @transpose_scalable(%arg0: vector<2x[4]xf32>) -> vector<[4]x2xf32> {
-  // CHECK-NOT:       vector.shape_cast
-  // CHECK-NOT:       llvm.intr.matrix.transpose
-  // CHECK:           vector.transpose
-  %0 = vector.transpose %arg0, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
-  return %0 : vector<[4]x2xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
-    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
-    transform.apply_patterns to %func_op {
-      transform.apply_patterns.vector.lower_transpose lowering_strategy = "flat_transpose"
-    } : !transform.op<"func.func">
-    transform.yield
-  }
-}
-
-// -----
-
 // CHECK-LABEL: @transpose_shuffle16x16xf32(
 func.func @transpose_shuffle16x16xf32(%arg0: vector<16x16xf32>) -> vector<16x16xf32> {
   // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir b/mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir
new file mode 100644
index 0000000000000..94689fa0dfb88
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s --convert-vector-to-llvm='vector-transpose-lowering=flat' --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @transpose(
+func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
+  // CHECK:       llvm.intr.matrix.transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> into vector<8xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+/// Scalable vectors are not supported
+
+// CHECK-LABEL: func @transpose_scalable(
+func.func @transpose_scalable(%arg0: vector<2x[4]xf32>) -> vector<[4]x2xf32> {
+  // CHECK-NOT:       llvm.intr.matrix.transpose
+  // CHECK:           vector.transpose
+  %0 = vector.transpose %arg0, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
+  return %0 : vector<[4]x2xf32>
+}

>From 4b70c322c0da9f2d8d68c11ca8693df0b942cd8e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 16 Jul 2025 14:26:08 +0100
Subject: [PATCH 3/5] fixup! fixup! [mlir][vector] Remove MatrixMultiplyOp and
 FlatTransposeOp from Vector dialect

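Address comment from Adam. This fixup:
  - replaces the `TODO` placeholders in `LoweringPatterns.h` with
    documentation for `populateVectorContractToMatrixMultiply` and
    `populateVectorTransposeToFlatTranspose`,
  - exposes the pattern benefit as a `PatternBenefit` parameter
    (default: 100) instead of hard-coding it,
  - renames `TransposeOpLowering` to
    `TransposeOpToMatrixTransposeOpLowering`,
  - removes a duplicate call to `populateVectorTransposeLoweringPatterns`
    in `ConvertVectorToLLVMPass`.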
---
 .../Vector/Transforms/LoweringPatterns.h      | 24 +++++++++++---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 32 +++++--------------
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  7 +++-
 3 files changed, 34 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index fd742188246ce..9505ada8a5a5e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -303,11 +303,27 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
 void populateVectorToFromElementsToShuffleTreePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// TODO
-void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns);
+/// Populate the pattern set with the following patterns:
+///
+/// [ContractionOpToMatmulOpLowering]
+/// Lowers `vector.transpose` to `llvm.intr.matrix.flat_transpose`.
+///
+/// Given the high benefit, this will be prioriotised over other
+/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
+/// only run this registration conditionally.
+void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns,
+                                            PatternBenefit benefit = 100);
 
-/// TODO
-void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns);
+/// Populate the pattern set with the following patterns:
+///
+/// [TransposeOpLowering]
+/// Lowers `vector.contract` to `llvm.intr.matrix.multiply`.
+///
+/// Given the high benefit, this will be prioriotised over other
+/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
+/// only run this registration conditionally.
+void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns,
+                                            PatternBenefit benefit = 100);
 
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6e6400ba6e60e..e4c1880738e4d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -2139,16 +2139,9 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
   return res;
 }
 
-/// Progressive lowering of TransposeOp.
-/// One:
-///   %x = vector.transpose %y, [1, 0]
-/// is replaced by:
-///   %z = arith.constant dense<0.000000e+00>
-///   %0 = vector.extract %y[0, 0]
-///   %1 = vector.insert %0, %z [0, 0]
-///   ..
-///   %x = vector.insert .., .. [.., ..]
-class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
+/// Lowers vector.transpose to llvm.intr.matrix.transpose
+class TransposeOpToMatrixTransposeOpLowering
+    : public OpRewritePattern<vector::TransposeOp> {
 public:
   using OpRewritePattern<TransposeOp>::OpRewritePattern;
 
@@ -2191,24 +2184,15 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
   patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
 }
 
-/// Pattern to lower `vector.contract` to `llvm.intr.matrix.multiply`.
-///
-/// Given the high benefit, this will be prioriotised over other
-/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
-/// only run this registration conditionally.
 void mlir::vector::populateVectorContractToMatrixMultiply(
-    RewritePatternSet &patterns) {
-  patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), 100);
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), benefit);
 }
 
-/// Pattern to lower `vector.transpose` to `llvm.intr.matrix.flat_transpose`.
-///
-/// Given the high benefit, this will be prioriotised over other
-/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
-/// only run this registration conditionally.
 void mlir::vector::populateVectorTransposeToFlatTranspose(
-    RewritePatternSet &patterns) {
-  patterns.add<TransposeOpLowering>(patterns.getContext(), 100);
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<TransposeOpToMatrixTransposeOpLowering>(patterns.getContext(),
+                                                       benefit);
 }
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index b8eb0f58e98ca..66d2ffbde1751 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -71,6 +71,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorBroadcastLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
     if (vectorContractLowering == vector::VectorContractLowering::Matmul) {
+      // This pattern creates a dependency on the LLVM dialect, hence we don't
+      // include it in `populateVectorContractLoweringPatterns` that is part of
+      // the Vector dialect (and should not depend on LLVM).
       populateVectorContractToMatrixMultiply(patterns);
     }
     populateVectorMaskOpLoweringPatterns(patterns);
@@ -78,9 +81,11 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorInterleaveLoweringPatterns(patterns);
     populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
     if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat) {
+      // This pattern creates a dependency on the LLVM dialect, hence we don't
+      // include it in `populateVectorTransposeLoweringPatterns` that is part of
+      // the Vector dialect (and should not depend on LLVM).
       populateVectorTransposeToFlatTranspose(patterns);
     }
-    populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
     populateVectorMaskMaterializationPatterns(patterns,

>From 18e7d785d409629092224c86fb916858c3d719f7 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 16 Jul 2025 15:35:32 +0100
Subject: [PATCH 4/5] fixup! fixup! fixup! [mlir][vector] Remove
 MatrixMultiplyOp and FlatTransposeOp from Vector dialect

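Fix comment: the "Lowers ..." descriptions of
`populateVectorContractToMatrixMultiply` and
`populateVectorTransposeToFlatTranspose` were swapped.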
---
 .../include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 9505ada8a5a5e..8a67712ecd8c7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -306,7 +306,7 @@ void populateVectorToFromElementsToShuffleTreePatterns(
 /// Populate the pattern set with the following patterns:
 ///
 /// [ContractionOpToMatmulOpLowering]
-/// Lowers `vector.transpose` to `llvm.intr.matrix.flat_transpose`.
+/// Lowers `vector.contract` to `llvm.intr.matrix.multiply`.
 ///
 /// Given the high benefit, this will be prioriotised over other
 /// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
@@ -317,7 +317,7 @@ void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns,
 /// Populate the pattern set with the following patterns:
 ///
 /// [TransposeOpLowering]
-/// Lowers `vector.contract` to `llvm.intr.matrix.multiply`.
+/// Lowers `vector.transpose` to `llvm.intr.matrix.flat_transpose`.
 ///
 /// Given the high benefit, this will be prioriotised over other
 /// contract-lowering patterns. As such, the convert-vector-to-llvm pass will

>From 145e1b9716d4727d20a9043d8c81fa2cab27ddb4 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 17 Jul 2025 16:49:23 +0100
Subject: [PATCH 5/5] fixup! fixup! fixup! fixup! [mlir][vector] Remove
 MatrixMultiplyOp and FlatTransposeOp from Vector dialect

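Fix comment: the references to "contract-lowering" and
"transpose-lowering" patterns in the two `populate*` descriptions were
swapped.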
---
 .../include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 8a67712ecd8c7..e03f0dabece52 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -309,7 +309,7 @@ void populateVectorToFromElementsToShuffleTreePatterns(
 /// Lowers `vector.contract` to `llvm.intr.matrix.multiply`.
 ///
 /// Given the high benefit, this will be prioriotised over other
-/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
+/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
 /// only run this registration conditionally.
 void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 100);
@@ -320,7 +320,7 @@ void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns,
 /// Lowers `vector.transpose` to `llvm.intr.matrix.flat_transpose`.
 ///
 /// Given the high benefit, this will be prioriotised over other
-/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
+/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
 /// only run this registration conditionally.
 void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 100);


