[Mlir-commits] [mlir] [MLIR][AArch64] Lower `vector.contract` to SVE FEAT_BF16 operations (PR #147052)

Momchil Velikov llvmlistbot at llvm.org
Fri Jul 4 06:26:21 PDT 2025


https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/147052

This patch adds lowering of Bfloat16 widening matrix multiply and accumulate `vector.contract` operations, by parametrising and refactoring the existing pattern for 8-bit integers.
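
For illustration, the kind of contraction handled by the new pattern (mirroring the test added in this patch) is shown below; it is rewritten into ArmSVE `bfmmla` operations acting on flattened scalable sub-tiles:

  #attrs = {
    indexing_maps = [
      affine_map<(m, n, k) -> (m, k)>,
      affine_map<(m, n, k) -> (n, k)>,
      affine_map<(m, n, k) -> (m, n)>
    ],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>
  }

  %0 = vector.contract #attrs %lhs, %rhs, %acc
    : vector<4x4xbf16>, vector<[4]x4xbf16> into vector<4x[4]xf32>

Note that the RHS is in a pre-transposed ([N]xK) layout, as produced e.g. by scalable vectorisation of `linalg.mmt4d`.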

>From c87e3c0bfc75d941af4748ee4d804ddb4f08ced3 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 4 Jul 2025 12:53:05 +0000
Subject: [PATCH] [MLIR][AArch64] Lower `vector.contract` to SVE FEAT_BF16
 operations

This patch adds lowering of Bfloat16 widening matrix multiply and
accumulate `vector.contract` operations, by parametrising and refactoring
the existing pattern for 8-bit integers.
---
 mlir/include/mlir/Conversion/Passes.td        |   4 +
 .../TransformOps/ArmSVEVectorTransformOps.td  |  12 +-
 .../Dialect/ArmSVE/Transforms/Transforms.h    |   2 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   3 +
 .../LowerContractionToNeonI8MMPattern.cpp     |   2 +-
 .../TransformOps/ArmSVEVectorTransformOps.cpp |   7 +-
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt  |   2 +-
 .../Transforms/LowerContractToSVEPatterns.cpp | 607 ++++++++++++++++++
 .../LowerContractionToSVEI8MMPattern.cpp      | 366 -----------
 .../Vector/CPU/ArmSVE/vector-bfmmla.mlir      | 105 +++
 .../CPU/ArmSVE/vector-contract-bfmmla.mlir    | 201 ++++++
 .../CPU/ArmSVE/vector-contract-i8mm.mlir      |   6 +-
 12 files changed, 944 insertions(+), 373 deletions(-)
 create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
 delete mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-bfmmla.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5a864865adffc..4f304b39a0528 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1437,6 +1437,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of Arm FEAT_I8MM instructions while lowering "
            "the vector dialect.">,
+    Option<"armBF16", "enable-arm-bf16",
+           "bool", /*default=*/"false",
+           "Enables the use of Arm FEAT_BF16 instructions while lowering "
+           "the vector dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
index 53784982be6dc..81b3c736b93f3 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
@@ -12,7 +12,7 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
 
-def ApplyArmSVELowerContractionPatternsOp
+def ApplyArmSVELowerContractionToI8MMPatternsOp
     : Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_i8mm",
          [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
@@ -23,4 +23,14 @@ def ApplyArmSVELowerContractionPatternsOp
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyArmSVELowerContractionToBFMMLAPatternsOp
+    : Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_bfmmla",
+         [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector contraction-like operations should be lowered to
+    finer-grained ArmSVE dialect operations targeting FEAT_BF16 BFMMLA.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
 #endif // ARMSVE_VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 232e2be29e574..de160dbf8ed94 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -23,6 +23,8 @@ void populateArmSVELegalizeForLLVMExportPatterns(
 void populateLowerContractionToSVEI8MMPatternPatterns(
     RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEBFMMLAPatterns(RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 67c0eca15638a..4d74aabcaa50d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -89,6 +89,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
       if (armSVE)
         populateLowerContractionToSVEI8MMPatternPatterns(patterns);
     }
+    if (armBF16)
+      populateLowerContractionToSVEBFMMLAPatterns(patterns);
+
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
index 7180884c77e98..a95fc51d562c2 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
@@ -12,7 +12,7 @@
 // TODO: There may be opportunities to unify this with a similar pattern
 // for SVE. See:
 //   https://github.com/llvm/llvm-project/issues/145559
-//   LowerContractionToSVEI8MMPattern.cpp
+//   LowerContractToSVEPatterns.cpp
 //
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
index b2ca4fc1eaa8c..8572c34c8b12b 100644
--- a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
@@ -18,11 +18,16 @@ using namespace mlir;
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//
 
-void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns(
+void transform::ApplyArmSVELowerContractionToI8MMPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
 }
 
+void transform::ApplyArmSVELowerContractionToBFMMLAPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  mlir::populateLowerContractionToSVEBFMMLAPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index 65f98b44b1b69..c29eaca244b4a 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_mlir_dialect_library(MLIRArmSVETransforms
   LegalizeForLLVMExport.cpp
   LegalizeVectorStorage.cpp
-  LowerContractionToSVEI8MMPattern.cpp
+  LowerContractToSVEPatterns.cpp
 
   DEPENDS
   MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
new file mode 100644
index 0000000000000..2987287afe9cd
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
@@ -0,0 +1,607 @@
+//===- LowerContractToSVEPatterns.cpp - Contract to I8MM/BF16 ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the SVE FEAT_I8MM and FEAT_BF16 extensions.
+//
+// TODO: There may be opportunities to unify this with a similar pattern
+// for Neon. See:
+//   https://github.com/llvm/llvm-project/issues/145559
+//   LowerContractionToNeonI8MMPattern.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <numeric>
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+
+namespace {
+// Get the operand of a `vector.contract`. This function abstracts away the
+// particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see the `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `i8` to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero- extension op");
+
+  // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type, allow for an implicit sign-extension.
+  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+  if (!extOp) {
+    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+      auto vTy = cast<VectorType>(v.getType());
+      if (!vTy.getElementType().isSignlessInteger(8))
+        return {};
+      return v;
+    }
+    return {};
+  }
+
+  // If the operand is defined by an explicit extend operation of the accepted
+  // operation type, check it's extended from `i8` to `i32`.
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy || !inTy.getElementType().isSignlessInteger(8))
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!outTy || !outTy.getElementType().isSignlessInteger(32))
+    return {};
+
+  return inOp;
+}
+
+/// This class encapsulates the algorithm and parametrisation (in terms of types
+/// and dimensions) of lowering a `vector.contract` to "primitive" matrix
+/// multiplication operations of the SVE dialect (here "primitive" would mean
+/// corresponding to a single target instruction).
+///
+/// Supported lowerings are to FEAT_I8MM `smmla`, `ummla`, and `usmmla`, and to
+/// FEAT_BF16 `bfmmla`. All the transformations are very similar to each other;
+/// for concreteness, the description below is given for `smmla`.
+///
+/// The lowering triggers for a contraction operation that performs a matrix
+/// multiply of two 8-bit integer matrix tiles with logical dimensions
+/// <Mx8> and <8x[N]> for the left-hand side (LHS) and the right-hand side
+/// (RHS), respectively, added to a 32-bit integer accumulator operand (ACC)
+/// with dimensions <Mx[N]>, yielding a <Mx[N]> 32-bit integer result (OUT).
+///
+/// The operands' shapes are such that the operands can be evenly split into
+/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM
+/// instructions. The intent is that M and N are chosen (by higher level
+/// transforms) in such a way as to maximise register usage. The main use case
+/// we envision as of now is MMT4D, thus the RHS operand is expected
+/// pre-transposed.
+///
+/// The matrix multiplication is performed by unrolling the usual tiled matrix
+/// multiplication algorithm using sub-tiles with dimensions <2x8> for the
+/// LHS, <8x[2]> for the RHS, and <2x[2]> for the result and the input
+/// accumulator.
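+///
+/// For example, with M = 4 and N = 4 the LHS is split into two <2x8>
+/// sub-tiles, the RHS into two <8x[2]> sub-tiles, and M/2 * N/2 = 4 MMLA
+/// operations are emitted, one for each <2x[2]> sub-tile of the result.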
+///
+/// One way to illustrate the operation is as follows:
+///
+/// RHS<8x[N]>:       <8x[2]> <8x[2]> ... <8x[2]>
+///                 +-----------------------------
+/// LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///           <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///            ...  |   ...     ...   ...   ...
+///           <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///
+/// The RHS operand is unpacked into N/2 values, each representing a sequence
+/// of VSCALE number of sub-tiles with dimensions <8x2>.
+/// The LHS operand is initially unpacked into M/2 values, each representing a
+/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
+/// VSCALE times. Multiplying such a replicated LHS sub-tile by the
+/// corresponding RHS sub-tile correctly computes an entire result sub-tile.
+/// The 2x2 sub-tiles of the ACC and OUT have rows that are not adjacent
+/// (in memory or when imposing a row-major layout on the 2D vector value).
+/// Reading the ACC is implemented as reading two consecutive rows and
+/// interleaving them by pairs to obtain a vector with twice the length of an
+/// ACC row. This vector is now a sequence of one-dimensional tiles with the
+/// exact layout needed by the `smmla`/`bfmmla`/etc. instructions, and these
+/// tiles are extracted one by one. For illustration, if we have a 2x4 ACC tile
+///   a0 a1 b0 b1
+///   a2 a3 b2 b3
+/// we read the two rows as separate values and then interleave by pairs
+/// to obtain
+///   a0 a1 a2 a3 b0 b1 b2 b3
+/// from which we extract `a0 a1 a2 a3` and `b0 b1 b2 b3`.
+///
+/// Writing the OUT tile is done by the reverse of the above procedure:
+/// concatenate two "flattened" sub-tiles into
+///   c0 c1 c2 c3 d0 d1 d2 d3
+/// deinterleave by pairs to obtain as separate values
+///   c0 c1 d0 d1
+///   c2 c3 d2 d3
+/// which are then inserted into the final result.
+///
+/// Multiplication of a signed LHS by an unsigned RHS is performed by
+/// swapping the order of the operands and emitting an `usmmla` (since there
+/// isn't an `summla` instruction). Therefore each ACC sub-tile needs to be
+/// transposed before the addition, and each sum (an OUT sub-tile) needs to be
+/// transposed before insertion into the final result.
+/// This is done by modifying the above scheme to interleave/deinterleave not
+/// by pairs, but by individual elements, e.g. after an ordinary interleave we
+/// obtain
+///   a0 a2 a1 a3 b0 b2 b1 b3
+/// which is exactly the desired layout of having each individual 2x2 tile
+/// transposed.
+///
+/// All of the above readily applies to FEAT_BF16 `bfmmla`, with the
+/// difference that the logical shapes of the LHS and RHS are <Mx4> and
+/// <4x[N]>, respectively, i.e. the "K" dimension is fixed to 4 instead of 8
+/// (as in the integer case).
+class VectorContractRewriter {
+protected:
+  // Designate the operation (resp. instruction) used to do sub-tile matrix
+  // multiplications.
+  enum class MMLA {
+    Nop,
+    SignedInt,   // smmla
+    UnsignedInt, // ummla
+    MixedInt,    // usmmla
+    Bfloat       // bfmmla
+  };
+
+  // Lower-level operation to be emitted.
+  MMLA mmlaOp = MMLA::Nop;
+
+  // The operand tiles. These are not necessarily the operands of
+  // `vector.contract`; for example, they could be the operands of an
+  // `arith.extsi` that in turn feeds into the `vector.contract`.
+  Value lhs;
+  Value rhs;
+  Value acc;
+
+  // Conventional names for matrix dimensions.
+  int64_t M = 0;
+  int64_t N = 0;
+  int64_t K = 0;
+
+  // Single-dimensional vector types for the operands of the ArmSVE dialect
+  // op.
+  VectorType flatLhsType;
+  VectorType flatRhsType;
+  VectorType flatAccType;
+
+  // Single-dimension vector type for the entire RHS tile.
+  VectorType flatRhsTileType;
+
+  // Vector type having the same number of elements as a row in the
+  // accumulator/output tile and the same element type.
+  VectorType accRowTy;
+
+  // Vector type having twice the number of elements as a row in the
+  // accumulator/output tile and the same element type.
+  VectorType accRowX2Ty;
+
+  // Vector type having half the number of elements as a row in the
+  // accumulator/output tile and an integer element type with twice the bit
+  // width.
+  VectorType accRow64Ty;
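+  // Like `accRow64Ty`, but with twice the number of elements, i.e. the
+  // 64-bit element view of `accRowX2Ty`.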
+  VectorType accRowX264Ty;
+
+  // Indicate if the operands for the ArmSVE dialect operation need to be
+  // swapped. Currently this is needed in order to emulate an "summla"
+  // operation.
+  bool swapOperands = false;
+
+  // Create the matrix multiply and accumulate operation according to
+  // `mmlaOp`.
+  Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+                   Value lhs, Value rhs);
+
+  // Check general preconditions for applying the transformation, common to the
+  // integer and the bfloat16 case.
+  LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter);
+
+public:
+  VectorContractRewriter() = default;
+
+  // Do the actual rewrite. This member function is shared by both the integer
+  // and the bfloat16 rewrites.
+  Value rewrite(vector::ContractionOp op, PatternRewriter &rewriter);
+};
+
+Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter,
+                                         Location loc, Value acc, Value lhs,
+                                         Value rhs) {
+  if (swapOperands)
+    std::swap(lhs, rhs);
+
+  switch (mmlaOp) {
+  case MMLA::SignedInt:
+    return rewriter.create<arm_sve::SmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+  case MMLA::UnsignedInt:
+    return rewriter.create<arm_sve::UmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+  case MMLA::MixedInt:
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+  case MMLA::Bfloat:
+    return rewriter.create<arm_sve::BfmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+  default:
+    llvm_unreachable("Uninitialized operation kind");
+  }
+}
+
+LogicalResult VectorContractRewriter::match(vector::ContractionOp op,
+                                            PatternRewriter &rewriter) {
+  // Check iterator types for matrix multiplication.
+  auto itTypes = op.getIteratorTypesArray();
+  if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+      itTypes[1] != vector::IteratorType::parallel ||
+      itTypes[2] != vector::IteratorType::reduction)
+    return rewriter.notifyMatchFailure(
+        op, "iterator types do not correspond to matrix multiplication");
+
+  // Check permutation maps. For now only accept
+  //   lhs: (d0, d1, d2) -> (d0, d2)
+  //   rhs: (d0, d1, d2) -> (d1, d2)
+  //   acc: (d0, d1, d2) -> (d0, d1)
+  // This corresponds to matrix multiplication with transposed RHS.
+  if (op.getIndexingMapsArray()[0] !=
+          AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+                                               op.getContext()) ||
+      op.getIndexingMapsArray()[1] !=
+          AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+                                               op.getContext()) ||
+      op.getIndexingMapsArray()[2] != AffineMap::getMultiDimMapWithTargets(
+                                          3, ArrayRef{0u, 1u}, op.getContext()))
+    return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
+
+  // Check the combining kind is addition.
+  if (op.getKind() != vector::CombiningKind::ADD)
+    return rewriter.notifyMatchFailure(op, "combining kind is not an addition");
+
+  return success();
+}
+
+Value VectorContractRewriter::rewrite(vector::ContractionOp op,
+                                      PatternRewriter &rewriter) {
+  Location loc = op.getLoc();
+
+  // Extract LHS sub-tiles with logical shape <2xK>.
+  SmallVector<Value> lhsTile;
+  for (int64_t i = 0; i < M; i += 2) {
+    // Extract two consecutive rows of the LHS tile.
+    auto r0 =
+        rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i});
+    auto r1 =
+        rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i + 1});
+    // Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
+    SmallVector<int64_t> shuffleIdx(2 * K);
+    std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
+    auto t = rewriter.create<vector::ShuffleOp>(loc, r0, r1, shuffleIdx);
+    // Turn it into a scalable vector.
+    auto s = rewriter.create<vector::ScalableInsertOp>(
+        loc, t, rewriter.create<ub::PoisonOp>(loc, flatLhsType), 0);
+    // Replicate the sub-tile VSCALE times to fill the entire vector.
+    auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+    lhsTile.push_back(r);
+  }
+
+  // "Flatten" the RHS tile from <[N]xK> to <[N*K]>.
+  auto rhs = rewriter.create<vector::ShapeCastOp>(this->rhs.getLoc(),
+                                                  flatRhsTileType, this->rhs);
+
+  // Extract the RHS sub-tiles with logical shape <Kx[2]>.
+  SmallVector<Value> rhsTile;
+  for (int64_t j = 0; j < N; j += 2)
+    rhsTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+        loc, flatRhsType, rhs, j * K));
+
+  // Extract and pack the ACC sub-tiles.
+  SmallVector<Value> accTile;
+  for (int64_t i = 0; i < M; i += 2) {
+    // Extract two consecutive rows of the accumulator tile.
+    auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+                                                 ArrayRef<int64_t>{i});
+    auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+                                                 ArrayRef<int64_t>{i + 1});
+    Value accTileVec;
+    if (swapOperands) {
+      // Since we are performing the operation with swapped LHS and RHS, we
+      // need to transpose each individual 2x2 tile of the accumulator and
+      // (later) the final result.
+      accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
+    } else {
+      // Bitcast accumulator rows to double-width integer elements, so
+      // subsequent interleave/deinterleave work on pairs of elements.
+      auto r0I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
+      auto r1I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
+
+      // Interleave the rows, effectively flattening each 2x2 tile into 4
+      // consecutive elements.
+      auto intrI64 = rewriter.create<vector::InterleaveOp>(loc, r0I64, r1I64);
+
+      // Bitcast back to original element type.
+      accTileVec = rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intrI64);
+    }
+    // Extract ACC sub-tiles.
+    for (int64_t j = 0; j < N; j += 2)
+      accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+          loc, flatAccType, accTileVec, j * 2));
+  }
+
+  // Emit sub-tile matrix multiplications.
+  SmallVector<Value> outTile;
+  for (int64_t i = 0; i < M / 2; ++i)
+    for (int64_t j = 0; j < N / 2; ++j) {
+      Value mmla = createMMLA(rewriter, loc, accTile[i * N / 2 + j], lhsTile[i],
+                              rhsTile[j]);
+      outTile.push_back(mmla);
+    }
+
+  // Unpack the OUT sub-tiles and insert into the result.
+  Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
+  for (int64_t i = 0; i < M / 2; ++i) {
+    // Collect a number of sub-tiles in a row.
+    Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
+    for (int64_t j = 0; j < N / 2; ++j)
+      row = rewriter.create<vector::ScalableInsertOp>(
+          loc, outTile[i * N / 2 + j], row, j * 4);
+
+    // Unpack the row to obtain two rows of the output. If the OUT sub-tiles
+    // are transposed, we obtain two consecutive output rows by separating
+    // even and odd elements, i.e. with a simple deinterleave. Otherwise, the
+    // deinterleave is by pairs.
+    Value out0, out1;
+    if (swapOperands) {
+      auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
+      out0 = tmp.getRes1();
+      out1 = tmp.getRes2();
+    } else {
+      // Deinterleave by pairs.
+      auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
+      auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
+
+      // Bitcast back into original element type and insert into the result.
+      out0 =
+          rewriter.create<vector::BitCastOp>(loc, accRowTy, deintr64.getRes1());
+      out1 =
+          rewriter.create<vector::BitCastOp>(loc, accRowTy, deintr64.getRes2());
+    }
+    result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
+    result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
+  }
+
+  return result;
+}
+
+class VectorContractRewriterI8MM : public VectorContractRewriter {
+public:
+  // Check the specific preconditions for the integer case. Initialise
+  // parametrisation types and dimensions.
+  LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter) {
+
+    if (failed(VectorContractRewriter::match(op, rewriter)))
+      return failure();
+
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+
+    M = lhsType.getDimSize(0);
+    N = rhsType.getDimSize(0);
+    K = rhsType.getDimSize(1);
+
+    // Check the operands have the expected shape:
+    //  * for LHS: fixed vector MxK
+    //  * for RHS: scalable vector [N]xK
+    //  * K == 8
+    //  * M and N even and at least 2
+    if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
+        rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 ||
+        M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
+        !rhsType.getScalableDims()[0])
+      return rewriter.notifyMatchFailure(op, "non-matching operand shape");
+
+    // Check the output is a vector of i32 elements.
+    auto outTy = dyn_cast<VectorType>(op.getResultType());
+    if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+      return rewriter.notifyMatchFailure(op,
+                                         "output type is not a vector of i32");
+
+    // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
+    // before the extension. All four signed/unsigned combinations for input
+    // operands are supported, but they are lowered to different operations.
+    // Determine which is the appropriate operation to lower to.
+    mmlaOp = MMLA::SignedInt;
+    swapOperands = false;
+    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::UnsignedInt;
+      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+    }
+    if (!maybeLhs)
+      return rewriter.notifyMatchFailure(
+          op, "LHS is not a sign- or zero- extended i8");
+
+    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::UnsignedInt)
+        mmlaOp = MMLA::MixedInt;
+    } else {
+      if (mmlaOp == MMLA::SignedInt) {
+        mmlaOp = MMLA::MixedInt;
+        swapOperands = true;
+      }
+      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
+    }
+    if (!maybeRhs)
+      return rewriter.notifyMatchFailure(
+          op, "RHS is not a sign- or zero- extended i8");
+
+    // Initialise algorithm parameters.
+    lhs = *maybeLhs;
+    rhs = *maybeRhs;
+    acc = op.getAcc();
+
+    flatLhsType = VectorType::get(/*shape=*/16, rewriter.getI8Type(),
+                                  /*scalableDims=*/{true});
+    flatRhsType = VectorType::get(/*shape=*/16, rewriter.getI8Type(),
+                                  /*scalableDims=*/{true});
+
+    flatAccType = VectorType::get(/*shape=*/4, rewriter.getI32Type(),
+                                  /*scalableDims=*/{true});
+
+    flatRhsTileType = VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(),
+                                      /*scalableDims=*/{true});
+
+    accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(),
+                               /*scalableDims=*/{true});
+    accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getI32Type(),
+                                 /*scalableDims=*/{true});
+    accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
+                                 /*scalableDims=*/{true});
+    accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
+                                   /*scalableDims=*/{true});
+
+    return success();
+  }
+};
+
+class VectorContractRewriterBfloat : public VectorContractRewriter {
+public:
+  // Check the specific preconditions for the bfloat16 case. Initialise
+  // parametrisation types and dimensions.
+  LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter) {
+
+    if (failed(VectorContractRewriter::match(op, rewriter)))
+      return failure();
+
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+
+    M = lhsType.getDimSize(0);
+    N = rhsType.getDimSize(0);
+    K = rhsType.getDimSize(1);
+
+    // Check the operands have the expected shape:
+    //  * for LHS: fixed vector MxK
+    //  * for RHS: scalable vector [N]xK
+    //  * K == 4
+    //  * M and N even and at least 2
+    if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
+        rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 4 ||
+        M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
+        !rhsType.getScalableDims()[0])
+      return rewriter.notifyMatchFailure(op, "non-matching operand shape");
+
+    // Check the output is a vector of Float32 elements.
+    auto outTy = dyn_cast<VectorType>(op.getResultType());
+    if (!outTy || outTy.getElementType() != rewriter.getF32Type())
+      return rewriter.notifyMatchFailure(op,
+                                         "output type is not a vector of f32");
+
+    // Check the inputs are vectors of BFloat16 elements.
+    if (lhsType.getElementType() != rewriter.getBF16Type())
+      return rewriter.notifyMatchFailure(op,
+                                         "input type is not a vector of bf16");
+
+    // Initialise algorithm parameters.
+    mmlaOp = MMLA::Bfloat;
+    swapOperands = false;
+    lhs = op.getLhs();
+    rhs = op.getRhs();
+    acc = op.getAcc();
+
+    flatLhsType = VectorType::get(/*shape=*/8, rewriter.getBF16Type(),
+                                  /*scalableDims=*/{true});
+    flatRhsType = VectorType::get(/*shape=*/8, rewriter.getBF16Type(),
+                                  /*scalableDims=*/{true});
+
+    flatAccType = VectorType::get(/*shape=*/4, rewriter.getF32Type(),
+                                  /*scalableDims=*/{true});
+
+    flatRhsTileType = VectorType::get(/*shape=*/4 * N, rewriter.getBF16Type(),
+                                      /*scalableDims=*/{true});
+
+    accRowTy = VectorType::get(/*shape=*/N, rewriter.getF32Type(),
+                               /*scalableDims=*/{true});
+    accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getF32Type(),
+                                 /*scalableDims=*/{true});
+    accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
+                                 /*scalableDims=*/{true});
+    accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
+                                   /*scalableDims=*/{true});
+
+    return success();
+  }
+};
+
+class LowerContractionToSVEI8MMPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // Match i8xi8 -> i32 matrix multiply and accumulate.
+    VectorContractRewriterI8MM vcr;
+    if (failed(vcr.match(op, rewriter)))
+      return failure();
+
+    Value result = vcr.rewrite(op, rewriter);
+    rewriter.replaceOp(op, result);
+
+    return success();
+  }
+};
+
+class LowerContractionToSVEBFMMLAPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // Match bf16xbf16 -> f32 matrix multiply and accumulate.
+    VectorContractRewriterBfloat vcr;
+    if (failed(vcr.match(op, rewriter)))
+      return failure();
+
+    Value result = vcr.rewrite(op, rewriter);
+    rewriter.replaceOp(op, result);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
+}
+
+void mlir::populateLowerContractionToSVEBFMMLAPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LowerContractionToSVEBFMMLAPattern>(context, /*benefit=*/2);
+}
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
deleted file mode 100644
index b7703ff0393eb..0000000000000
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ /dev/null
@@ -1,366 +0,0 @@
-//===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM --*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements lowering patterns from vector.contract to operations
-// that map to instructions from the SVE FEAT_I8MM extension.
-//
-// TODO: There may be opportunities to unify this with a similar pattern
-// for Neon. See:
-//   https://github.com/llvm/llvm-project/issues/145559
-//   LowerContractionToNeonI8MMPattern.cpp
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
-#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#include "mlir/Dialect/UB/IR/UBOps.h"
-
-#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
-
-using namespace mlir;
-
-namespace {
-// Get the operand of a `vector.contract`. This function is intended to abstract
-// away from the particular way a value is extended before feeding it into the
-// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
-// (for implicit sign-extension see `vector.contract` documentation).
-//
-// The template parameter `Op` indicates the extension operation (explicit or
-// implicit) for which we are checking.
-//
-// Return success only for extensions from `i8` to `i32`.
-template <typename Op>
-std::optional<Value> getExtOperand(Value v) {
-
-  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
-                "Must be instantiated with either sign- or zero- extension op");
-
-  // If the operand is not defined by an explicit extend operation of the
-  // accepted operation type allow for an implicit sign-extension.
-  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
-  if (!extOp) {
-    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
-      auto vTy = cast<VectorType>(v.getType());
-      if (!vTy.getElementType().isSignlessInteger(8))
-        return {};
-      return v;
-    }
-    return {};
-  }
-
-  // If the operand is defined by an explicit extend operation of the accepted
-  // operation type, check it's extended from `i8` to `i32`.
-  auto inOp = extOp.getIn();
-  auto inTy = dyn_cast<VectorType>(inOp.getType());
-  if (!inTy || !inTy.getElementType().isSignlessInteger(8))
-    return {};
-
-  auto outTy = dyn_cast<VectorType>(extOp.getType());
-  if (!outTy || !outTy.getElementType().isSignlessInteger(32))
-    return {};
-
-  return inOp;
-}
-
-// Designate the operation (resp. instruction) used to do sub-tile matrix
-// multiplications.
-enum class MMLA {
-  Signed,      // smmla
-  Unsigned,    // ummla
-  Mixed,       // usmmla
-  MixedSwapped // usmmla with LHS and RHS swapped
-};
-
-// Create the matrix mulitply and accumulate operation according to `op`.
-Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
-                 mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
-  switch (op) {
-  case MMLA::Signed:
-    return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
-  case MMLA::Unsigned:
-    return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
-  case MMLA::Mixed:
-    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
-  case MMLA::MixedSwapped:
-    // The accumulator comes transposed and the result will be transposed
-    // later, so all we have to do here is swap the operands.
-    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
-  }
-}
-
-/// Lower a contraction operation that performs a matrix multiplication
-/// of two 8-bit integer matrix tiles with logical dimensions <Mx8> and <8x[N]>
-/// for the left-hand side and the right-hand side, respectively,
-/// yielding a <Mx[N]> 32-bit integer result.
-///
-/// The operands' shapes are such that the operands can be evenly split into
-/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM
-/// instructions. The intent is that M and N are chosen (by higher level
-/// transforms) in such a way as to maximise register usage. The main use case
-/// we envision as of now is MMT4D, thus the RHS operand is expected
-/// pre-transposed.
-///
-/// The matrix multiplication is performed by unrolling the usual tiled matrix
-/// multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS,
-/// <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator.
-///
-/// One way to illustrate the operation is as follows:
-///
-/// RHS<8x[N]>:       <8x[2]> <8x[2]> ... <8x[2]>
-///                 +-----------------------------
-/// LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
-///           <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
-///            ...  |   ...     ...   ...   ...
-///           <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
-///
-/// The RHS operand is unpacked into N/2 values, each representing a sequence of
-/// VSCALE number of sub-tiles with dimensions <8x2>.
-/// The LHS operand is initially unpacked into M/2 values, each representing a
-/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
-/// VSCALE times.
-/// Multiplying thus replicated LHS sub-tile by the corresponding RHS sub-tile
-/// correctly computes an entire result sub-tile.
-class LowerContractionToSVEI8MMPattern
-    : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override {
-
-    Location loc = op.getLoc();
-    mlir::VectorType lhsType = op.getLhsType();
-    mlir::VectorType rhsType = op.getRhsType();
-
-    // Check the rank the types so we can safely examine their dimensions.
-    if (lhsType.getRank() != 2 || rhsType.getRank() != 2)
-      return rewriter.notifyMatchFailure(op, "non-matching operand shape");
-
-    auto M = lhsType.getDimSize(0);
-    auto N = rhsType.getDimSize(0);
-    auto K = rhsType.getDimSize(1);
-
-    // Check the operands have the expected shape:
-    //  * for LHS: fixed vector MxK
-    //  * for RHS: scalable vector [N]xK
-    //  * K == 8
-    //  * M and N even and at least 2
-    if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
-        rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 ||
-        M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
-        !rhsType.getScalableDims()[0])
-      return rewriter.notifyMatchFailure(op, "non-matching operand shape");
-
-    // Check permutation maps. For now only accept
-    //   lhs: (d0, d1, d2) -> (d0, d2)
-    //   rhs: (d0, d1, d2) -> (d1, d2)
-    //   acc: (d0, d1, d2) -> (d0, d1)
-    // This corresponds to matrix multiplication with transposed RHS.
-    if (op.getIndexingMapsArray()[0] !=
-            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
-                                                 op.getContext()) ||
-        op.getIndexingMapsArray()[1] !=
-            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
-                                                 op.getContext()) ||
-        op.getIndexingMapsArray()[2] !=
-            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
-                                                 op.getContext()))
-      return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
-
-    // Check iterator types for matrix multiplication.
-    auto itTypes = op.getIteratorTypesArray();
-    if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
-        itTypes[1] != vector::IteratorType::parallel ||
-        itTypes[2] != vector::IteratorType::reduction)
-      return rewriter.notifyMatchFailure(
-          op, "iterator types do not correspond to matrix multiplication");
-
-    // Check the combining kind is addition.
-    if (op.getKind() != vector::CombiningKind::ADD)
-      return rewriter.notifyMatchFailure(op,
-                                         "combining kind is not an addition");
-
-    // Check the output is a vector of i32 elements.
-    auto outTy = dyn_cast<VectorType>(op.getResultType());
-    if (!outTy || outTy.getElementType() != rewriter.getI32Type())
-      return rewriter.notifyMatchFailure(op,
-                                         "output type is not a vector of i32");
-
-    // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
-    // before the extension. All four signed/unsigned combinations for input
-    // operands are supported, but they are lowered to different operations.
-    // Determine which is the appropriate operation to lower to.
-    MMLA mmlaOp = MMLA::Signed;
-    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
-    if (!maybeLhs) {
-      mmlaOp = MMLA::Unsigned;
-      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
-    }
-    if (!maybeLhs)
-      return rewriter.notifyMatchFailure(
-          op, "LHS is not a sign- or zero- extended i8");
-
-    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
-    if (maybeRhs) {
-      if (mmlaOp == MMLA::Unsigned)
-        mmlaOp = MMLA::Mixed;
-    } else {
-      if (mmlaOp == MMLA::Signed)
-        mmlaOp = MMLA::MixedSwapped;
-      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
-    }
-    if (!maybeRhs)
-      return rewriter.notifyMatchFailure(
-          op, "RHS is not a sign- or zero- extended i8");
-
-    // One-dimensional vector types for arm_sve.*mmla
-    auto nxv16i8 = VectorType::get(/*shape=*/16, rewriter.getI8Type(),
-                                   /*scalableDims=*/{true});
-    auto nxv4i32 = VectorType::get(/*shape=*/4, rewriter.getI32Type(),
-                                   /*scalableDims=*/{true});
-
-    // Extract LHS sub-tiles with logicall shape <2x8>.
-    SmallVector<Value> lhsTile;
-    for (int64_t i = 0; i < M; i += 2) {
-      // Extract two consecutive rows of the LHS tile.
-      auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
-                                                   ArrayRef<int64_t>{i});
-      auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
-                                                   ArrayRef<int64_t>{i + 1});
-      // Concatenate to obtain a 16 x i8 flattened sub-tile.
-      auto t = rewriter.create<vector::ShuffleOp>(
-          loc, r0, r1,
-          llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
-                                  14, 15});
-      // Turn it into a scalable vector.
-      auto s = rewriter.create<vector::ScalableInsertOp>(
-          loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
-      // Replicate the sub-tile VSCALE times to fill the entire vector.
-      auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
-      lhsTile.push_back(r);
-    }
-
-    // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
-    auto rhs = rewriter.create<vector::ShapeCastOp>(
-        maybeRhs->getLoc(),
-        VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(),
-                        /*scalableDims=*/{true}),
-        *maybeRhs);
-
-    // Extract the RHS sub-tiles with logical shape <8x[2]>.
-    SmallVector<Value> rhsTile;
-    for (int64_t j = 0; j < N; j += 2)
-      rhsTile.push_back(
-          rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, rhs, j * 8));
-
-    // Handy types for packing/unpacking of the accumulator tile.
-    auto accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(),
-                                    /*scalableDims=*/{true});
-    auto accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getI32Type(),
-                                      /*scalableDims=*/{true});
-    auto accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
-                                      /*scalableDims=*/{true});
-    auto accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
-                                        /*scalableDims=*/{true});
-
-    // Extract and pack the ACC sub-tiles.
-    SmallVector<Value> accTile;
-    for (int64_t i = 0; i < M; i += 2) {
-      // Extract two consecutive rows of the accumulator tile.
-      auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
-                                                   ArrayRef<int64_t>{i});
-      auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
-                                                   ArrayRef<int64_t>{i + 1});
-      Value accTileVec;
-      if (mmlaOp == MMLA::MixedSwapped) {
-        // We need to swap the positions of the LHS and RHS (since we don't have
-        // a signed * unsigned operation), but then each individual 2x2 tile of
-        // the acumulator and (later) the result need to be transposed.
-        accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
-      } else {
-        // Bitcast them to 64-bit elements, so subsequent
-        // interleave/deinterleave work on pairs of 32-bit numbers.
-        auto r0I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
-        auto r1I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
-
-        // Interleave the rows, effectively flattening each 2x2 tile into 4
-        // consecutive elements.
-        auto intrI64 = rewriter.create<vector::InterleaveOp>(loc, r0I64, r1I64);
-
-        // Bitcast back to 32-bit elements.
-        accTileVec =
-            rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intrI64);
-      }
-      // Extract ACC sub-tiles.
-      for (int64_t j = 0; j < N; j += 2)
-        accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
-            loc, nxv4i32, accTileVec, j * 2));
-    }
-
-    // Emit sub-tile matrix multiplications.
-    SmallVector<Value> outTile;
-    for (int64_t i = 0; i < M / 2; ++i)
-      for (int64_t j = 0; j < N / 2; ++j) {
-        Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
-                                accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]);
-        outTile.push_back(mmla);
-      }
-
-    // Unpack the OUT sub-tiles and insert into the result.
-    Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
-    for (int64_t i = 0; i < M / 2; ++i) {
-      // Collect a number of sub-tiles in a row.
-      Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
-      for (int64_t j = 0; j < N / 2; ++j)
-        row = rewriter.create<vector::ScalableInsertOp>(
-            loc, outTile[i * N / 2 + j], row, j * 4);
-
-      // Unpack the row to obtain two rows of the output. If we have the out
-      // sub-tiles transposed we obtain two consecutive output rows by
-      // separating even and odd elements, i.e. a simple deinterleave.
-      // Otherwise, the interleave is by pairs.
-      Value out0, out1;
-      if (mmlaOp == MMLA::MixedSwapped) {
-        auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
-        out0 = tmp.getRes1();
-        out1 = tmp.getRes2();
-      } else {
-        // Deinterleave by pairs.
-        auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
-        auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
-
-        // Bitcast back into 32-bit elements and insert into the result.
-        out0 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
-                                                  deintr64.getRes1());
-        out1 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
-                                                  deintr64.getRes2());
-      }
-      result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
-      result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
-    }
-
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-} // namespace
-
-void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
-    RewritePatternSet &patterns) {
-  MLIRContext *context = patterns.getContext();
-  patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
-}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-bfmmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-bfmmla.mlir
new file mode 100644
index 0000000000000..ca9d91576b512
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-bfmmla.mlir
@@ -0,0 +1,105 @@
+// RUN:  mlir-opt %s --transform-interpreter | FileCheck %s
+
+#attrs = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d0, d2)>,
+    affine_map<(d0, d1, d2) -> (d1, d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1)>
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  kind = #vector.kind<add>
+}
+
+// CHECK-LABEL: @test_vector_contract_to_bfmmla
+// CHECK-SAME:    %[[LHS:.+]]: vector<4x4xbf16>, %[[RHS:.+]]: vector<[4]x4xbf16>, %[[ACC:.+]]: vector<4x[4]xf32>) -> vector<4x[4]xf32> {
+// CHECK-NEXT:    %[[T0:.+]]  = ub.poison : vector<[8]xf32>
+// CHECK-NEXT:    %[[UB:.+]] = ub.poison : vector<4x[4]xf32>
+// CHECK-NEXT:    %[[T2:.+]]  = ub.poison : vector<[8]xbf16>
+
+// Extract rows 0 and 1 of the LHS, concatenate them, and replicate the resulting 8xbf16 vector
+// VSCALE times to obtain a [8]xbf16 vector.
+// CHECK-NEXT:    %[[T3:.+]]  = vector.extract %[[LHS]][0] : vector<4xbf16> from vector<4x4xbf16>
+// CHECK-NEXT:    %[[T4:.+]]  = vector.extract %[[LHS]][1] : vector<4xbf16> from vector<4x4xbf16>
+// CHECK-NEXT:    %[[T5:.+]]  = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xbf16>, vector<4xbf16>
+// CHECK-NEXT:    %[[T6:.+]]  = vector.scalable.insert %[[T5]], %[[T2]][0] : vector<8xbf16> into vector<[8]xbf16>
+// CHECK-NEXT:    %[[LHS_00:.+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[8]xbf16>
+
+// Same for rows 2 and 3 of the LHS.
+// CHECK-NEXT:    %[[T8:.+]]  = vector.extract %[[LHS]][2] : vector<4xbf16> from vector<4x4xbf16>
+// CHECK-NEXT:    %[[T9:.+]]  = vector.extract %[[LHS]][3] : vector<4xbf16> from vector<4x4xbf16>
+// CHECK-NEXT:    %[[T10:.+]]  = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xbf16>, vector<4xbf16>
+// CHECK-NEXT:    %[[T11:.+]]  = vector.scalable.insert %[[T10]], %[[T2]][0] : vector<8xbf16> into vector<[8]xbf16>
+// CHECK-NEXT:    %[[LHS_10:.+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[8]xbf16>
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT:    %[[T13:.+]]  = vector.shape_cast %[[RHS]] : vector<[4]x4xbf16> to vector<[16]xbf16>
+// CHECK-NEXT:    %[[RHS_00:.+]] = vector.scalable.extract %[[T13]][0] : vector<[8]xbf16> from vector<[16]xbf16>
+// CHECK-NEXT:    %[[RHS_01:.+]] = vector.scalable.extract %[[T13]][8] : vector<[8]xbf16> from vector<[16]xbf16>
+
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT:    %[[T16:.+]]  = vector.extract %[[ACC]][0] : vector<[4]xf32> from vector<4x[4]xf32>
+// CHECK-NEXT:    %[[T17:.+]]  = vector.extract %[[ACC]][1] : vector<[4]xf32> from vector<4x[4]xf32>
+// CHECK-NEXT:    %[[T18:.+]]  = vector.bitcast %[[T16]] : vector<[4]xf32> to vector<[2]xi64>
+// CHECK-NEXT:    %[[T19:.+]]  = vector.bitcast %[[T17]] : vector<[4]xf32> to vector<[2]xi64>
+// CHECK-NEXT:    %[[T20:.+]]  = vector.interleave %[[T18]], %[[T19]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT:    %[[T21:.+]]  = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xf32>
+// CHECK-NEXT:    %[[ACC_00:.+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xf32> from vector<[8]xf32>
+// CHECK-NEXT:    %[[ACC_01:.+]] = vector.scalable.extract %[[T21]][4] : vector<[4]xf32> from vector<[8]xf32>
+
+// Same for accumulator rows 2 and 3
+// CHECK-NEXT:    %[[T24:.+]]  = vector.extract %[[ACC]][2] : vector<[4]xf32> from vector<4x[4]xf32>
+// CHECK-NEXT:    %[[T25:.+]]  = vector.extract %[[ACC]][3] : vector<[4]xf32> from vector<4x[4]xf32>
+// CHECK-NEXT:    %[[T26:.+]]  = vector.bitcast %[[T24]] : vector<[4]xf32> to vector<[2]xi64>
+// CHECK-NEXT:    %[[T27:.+]]  = vector.bitcast %[[T25]] : vector<[4]xf32> to vector<[2]xi64>
+// CHECK-NEXT:    %[[T28:.+]]  = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT:    %[[T29:.+]]  = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xf32>
+// CHECK-NEXT:    %[[ACC_10:.+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xf32> from vector<[8]xf32>
+// CHECK-NEXT:    %[[ACC_11:.+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xf32> from vector<[8]xf32>
+
+// Do the sub-tile matrix multiplications
+// CHECK-NEXT:    %[[PACK_RES_00:.+]] = arm_sve.intr.bfmmla %[[ACC_00]], %[[LHS_00]], %[[RHS_00]] : vector<[8]xbf16> to vector<[4]xf32>
+// CHECK-NEXT:    %[[PACK_RES_01:.+]] = arm_sve.intr.bfmmla %[[ACC_01]], %[[LHS_00]], %[[RHS_01]] : vector<[8]xbf16> to vector<[4]xf32>
+// CHECK-NEXT:    %[[PACK_RES_10:.+]] = arm_sve.intr.bfmmla %[[ACC_10]], %[[LHS_10]], %[[RHS_00]] : vector<[8]xbf16> to vector<[4]xf32>
+// CHECK-NEXT:    %[[PACK_RES_11:.+]] = arm_sve.intr.bfmmla %[[ACC_11]], %[[LHS_10]], %[[RHS_01]] : vector<[8]xbf16> to vector<[4]xf32>
+
+// Unpack (from "registers") and insert in the output result rows 0 and 1
+// CHECK-NEXT:    %[[T36:.+]]  = vector.scalable.insert %[[PACK_RES_00]], %[[T0]][0] : vector<[4]xf32> into vector<[8]xf32>
+// CHECK-NEXT:    %[[T37:.+]]  = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xf32> into vector<[8]xf32>
+// CHECK-NEXT:    %[[T38:.+]]  = vector.bitcast %[[T37]] : vector<[8]xf32> to vector<[4]xi64>
+// CHECK-NEXT:    %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT:    %[[UNPACK_RES_00:.+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xf32>
+// CHECK-NEXT:    %[[UNPACK_RES_01:.+]] = vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xf32>
+// CHECK-NEXT:    %[[TMP_OUT_0:.+]] = vector.insert %[[UNPACK_RES_00]], %[[UB]] [0] : vector<[4]xf32> into vector<4x[4]xf32>
+// CHECK-NEXT:    %[[TMP_OUT_1:.+]] = vector.insert %[[UNPACK_RES_01]], %[[TMP_OUT_0]] [1] : vector<[4]xf32> into vector<4x[4]xf32>
+
+// Same for result rows 2 and 3
+// CHECK-NEXT:    %[[T43:.+]]  = vector.scalable.insert %[[PACK_RES_10]], %[[T0]][0] : vector<[4]xf32> into vector<[8]xf32>
+// CHECK-NEXT:    %[[T44:.+]]  = vector.scalable.insert %[[PACK_RES_11]], %[[T43]][4] : vector<[4]xf32> into vector<[8]xf32>
+// CHECK-NEXT:    %[[T45:.+]]  = vector.bitcast %[[T44]] : vector<[8]xf32> to vector<[4]xi64>
+// CHECK-NEXT:    %res1_0, %res2_1 = vector.deinterleave %[[T45]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT:    %[[UNPACK_RES_10:.+]] = vector.bitcast %res1_0 : vector<[2]xi64> to vector<[4]xf32>
+// CHECK-NEXT:    %[[UNPACK_RES_11:.+]] = vector.bitcast %res2_1 : vector<[2]xi64> to vector<[4]xf32>
+// CHECK-NEXT:    %[[TMP_OUT_2:.+]] = vector.insert %[[UNPACK_RES_10]], %[[TMP_OUT_1]] [2] : vector<[4]xf32> into vector<4x[4]xf32>
+// CHECK-NEXT:    %[[OUT:.+]] = vector.insert %[[UNPACK_RES_11]], %[[TMP_OUT_2]] [3] : vector<[4]xf32> into vector<4x[4]xf32>
+// CHECK-NEXT:    return %[[OUT]] : vector<4x[4]xf32>
+func.func @test_vector_contract_to_bfmmla(%lhs: vector<4x4xbf16>,
+                                          %rhs: vector<[4]x4xbf16>,
+                                          %acc: vector<4x[4]xf32>) -> vector<4x[4]xf32> {
+  %0 = vector.contract #attrs %lhs, %rhs, %acc
+    : vector<4x4xbf16>, vector<[4]x4xbf16> into vector<4x[4]xf32>
+
+  return %0 : vector<4x[4]xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+    transform.apply_patterns to %func {
+      transform.apply_patterns.arm_sve.vector_contract_to_bfmmla
+    } : !transform.op<"func.func">
+
+    transform.yield
+  }
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir
new file mode 100644
index 0000000000000..0e988d1c2f42c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir
@@ -0,0 +1,201 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-bf16' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  \
+// DEFINE:   --lower-affine --convert-arith-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+bf16" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+
+//
+// Test the lowering of `vector.contract` using the `LowerContractionToSVEBFMMLAPattern`
+//
+// The operation that the `vector.contract` in this test performs is matrix
+// multiplication with accumulate
+//     OUT = ACC + LHS * RHS
+// of two BFloat16 matrices LHS and RHS, and a Float32 matrix ACC into a Float32 OUT.
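+// In element-wise terms (following the indexing maps above, with the RHS
+// transposed), the computation is:
+//     OUT[m][n] = ACC[m][n] + sum_k LHS[m][k] * RHS[n][k]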
+//
+// Tested are both the computed values and that the relevant `ArmSVE` dialect
+// operation (`arm_sve.intr.bfmmla`) is emitted.
+//
+// The pattern above handles (and therefore this test prepares) input/output
+// vectors with specific shapes:
+//   * LHS:      vector<Mx4xbf16>
+//   * RHS:      vector<[N]x4xbf16>
+//   * ACC, OUT: vector<Mx[N]xf32>
+// Note that the RHS is transposed.
+// This data layout makes it efficient to load data into SVE
+// registers in the layout expected by the BFMMLA instruction.
+// Such a `vector.contract` is representative of the code we aim to generate
+// by scalable vectorisation of `linalg.mmt4d`.
+// See mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
+// for more information and rationale about these shapes.
+//
+// In this specific test we use M == 4 and N == 4
+//
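+// For illustration, with these values the operand types used in this file are:
+//   * LHS:      vector<4x4xbf16>
+//   * RHS:      vector<[4]x4xbf16>
+//   * ACC, OUT: vector<4x[4]xf32>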
+
+// Allocate and initialise a memref containing test data for use as the ACC
+// operand. The memref has one dynamic dimension whose extent depends on the
+// runtime value of VSCALE.
+//
+// The input parameter `%in` is a vector that is replicated VSCALE times
+// across the columns of the memref.
+func.func private @prepareAccTestData(%in: vector<4x4xf32>) -> memref<4x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+
+  %vs = vector.vscale
+  %d = arith.muli %c4, %vs : index
+  %mem = memref.alloc(%d) : memref<4x?xf32>
+
+  scf.for %j = %c0 to %d step %c4 {
+    vector.transfer_write %in, %mem[%c0, %j] {in_bounds = [true, true]} : vector<4x4xf32>, memref<4x?xf32>
+  }
+
+  return %mem : memref<4x?xf32>
+}
+
+// Allocate and initialise a memref containing test data for use as the LHS
+// operand. This function just writes the parameter `%in` into the memref.
+// The size of the LHS does not depend on VSCALE.
+func.func private @prepareLHSTestData(%in: vector<4x4xbf16>) -> memref<4x4xbf16> {
+  %c0 = arith.constant 0 : index
+
+  %mem = memref.alloc() : memref<4x4xbf16>
+  vector.transfer_write %in, %mem[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xbf16>, memref<4x4xbf16>
+
+  return %mem : memref<4x4xbf16>
+}
+
+// Allocate and initialise a memref containing test data for use as the RHS
+// operand. The memref has one dynamic dimension whose extent depends on the
+// runtime value of VSCALE.
+//
+// The input parameter `%in` is a vector that is replicated VSCALE times
+// across the rows of the memref.
+//
+// For convenience, the memref is flattened, since the RHS vector is first read
+// as a one-dimensional scalable vector and then cast into the [N]x4 shape.
+func.func private @prepareRHSTestData(%in: vector<4x4xbf16>) -> memref<?xbf16> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+
+  %vs = vector.vscale
+  %d = arith.muli %c4, %vs : index
+  %mem = memref.alloc(%d) : memref<?x4xbf16>
+
+  scf.for %i = %c0 to %d step %c4 {
+    vector.transfer_write %in, %mem[%i, %c0] {in_bounds = [true, true]} : vector<4x4xbf16>, memref<?x4xbf16>
+  }
+
+  %mem_out = memref.collapse_shape %mem [[0, 1]] : memref<?x4xbf16> into memref<?xbf16>
+  return %mem_out : memref<?xbf16>
+}
+
+
+// CHECK-IR-LABEL: llvm.func @test_bfmmla
+// CHECK-IR-COUNT-4: arm_sve.intr.bfmmla
+func.func @test_bfmmla() {
+
+  %c0 = arith.constant 0 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %c0_bf16 = arith.constant 0.0 : bf16
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[ 0.7,  1.0, -0.1,  1.8],
+                                   [-0.5,  0.9,  0.7, -0.7],
+                                   [ 0.5, -1.3, -2.2,  0.1],
+                                   [-0.7,  1.0,  1.7, -1.0]]> : vector<4x4xf32>
+
+  %acc_mem = func.call @prepareAccTestData(%acc_cst) : (vector<4x4xf32>) -> memref<4x?xf32>
+  %acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_f32 {in_bounds = [true, true]} : memref<4x?xf32>, vector<4x[4]xf32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[ 0.1,  0.7, -0.9,  1.3],
+                                   [-1.6,  0.7, -0.3, -0.3],
+                                   [-0.4,  0.6,  0.8, -0.5],
+                                   [-0.6, -1.0, -1.0, -1.0]]> : vector<4x4xbf16>
+
+  %lhs_mem = func.call @prepareLHSTestData(%lhs_cst) : (vector<4x4xbf16>) -> memref<4x4xbf16>
+  %lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[ 0.6,  1.3,  0.1, -0.9],
+                                   [ 0.5,  1.6,  1.8,  1.6],
+                                   [-0.2,  0.4,  1.0,  0.4],
+                                   [-1.3, -0.2, -2.2,  0.3]]> : vector<4x4xbf16>
+
+  %rhs_mem = func.call @prepareRHSTestData(%rhs_cst) : (vector<4x4xbf16>) -> memref<?xbf16>
+  %rhs_flat = vector.transfer_read %rhs_mem[%c0], %c0_bf16 {in_bounds = [true]} :  memref<?xbf16>, vector<[16]xbf16>
+  %rhs = vector.shape_cast %rhs_flat : vector<[16]xbf16> to vector<[4]x4xbf16>
+
+  // Matrix multiplication and accumulate with transposed RHS.
+  %0 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<4x4xbf16>, vector<[4]x4xbf16> into vector<4x[4]xf32>
+
+  // Display the result of the multiplication
+  vector.print str "Result(BFMMLA):\n"
+  %u0 = vector.extract %0[0] : vector<[4]xf32> from vector<4x[4]xf32>
+  %u1 = vector.extract %0[1] : vector<[4]xf32> from vector<4x[4]xf32>
+  %u2 = vector.extract %0[2] : vector<[4]xf32> from vector<4x[4]xf32>
+  %u3 = vector.extract %0[3] : vector<[4]xf32> from vector<4x[4]xf32>
+  vector.print %u0 : vector<[4]xf32>
+  vector.print %u1 : vector<[4]xf32>
+  vector.print %u2 : vector<[4]xf32>
+  vector.print %u3 : vector<[4]xf32>
+
+  // Deallocate the buffers.
+  memref.dealloc %acc_mem : memref<4x?xf32>
+  memref.dealloc %lhs_mem : memref<4x4xbf16>
+  memref.dealloc %rhs_mem : memref<?xbf16>
+
+  return
+}
+
+// Perform each test with SVE vector lengths of 128 bits and 256 bits (i.e.
+// VSCALE values of 1 and 2, respectively). The vector length is set via the
+// `setArmVLBits` function. The effect of setting a different vector length is
+// that the tests allocate and operate on different-sized buffers (see the
+// `prepare<X>TestData` functions).
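+//
+// For example, with a 256-bit vector length (VSCALE == 2), a vector<4x[4]xf32>
+// holds 4x8 f32 elements, which is why the expected rows in the 256-bit run
+// below have 8 columns (the 4-column input data replicated twice).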
+
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  %c256 = arith.constant 256 : i32
+
+// CHECK-LABEL: Result(BFMMLA):
+// CHECK: (  0.411922, 2.63254,  -0.219259,  3.89965 )
+// CHECK: ( -0.316515, 0.196875,  0.879375,  1.80924 )
+// CHECK: (  1.56867,  0.101367, -1.2784,   -1.41579 )
+// CHECK: ( -1.56041, -4.30078,   0.0196488, 1.88269 )
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+  func.call @test_bfmmla() : () -> ()
+
+// CHECK: Result(BFMMLA):
+// CHECK: (  0.411922, 2.63254,  -0.219259,  3.89965,  0.411922, 2.63254,  -0.219259,  3.89965 )
+// CHECK: ( -0.316515, 0.196875,  0.879375,  1.80924, -0.316515, 0.196875,  0.879375,  1.80924 )
+// CHECK: (  1.56867,  0.101367, -1.2784,   -1.41579,  1.56867,  0.101367, -1.2784,   -1.41579 )
+// CHECK: ( -1.56041, -4.30078,   0.0196488, 1.88269, -1.56041, -4.30078,   0.0196488, 1.88269 )
+  func.call @setArmVLBits(%c256) : (i32) -> ()
+  func.call @test_bfmmla() : () -> ()
+
+  return
+}
+
+func.func private @setArmVLBits(%bits : i32)
+func.func private @printMemrefF32(%ptr : memref<*xf32>)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
index 5f6e8e4c30892..8504d664fa0c6 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
@@ -20,7 +20,7 @@
 ]
 
 //
-// Test the lowering of `vector.contract` using the `LowerContractionToSVEI8MMPattern`
+// Test the lowering of `vector.contract` using the `LowerContractionToSVEI8MMPattern`
 //
 // The operation that the `vector.contract` in this test performs is matrix
 // multiplication with accumulate
@@ -42,7 +42,7 @@
 // registers in the layout expected by FEAT_I8MM instructions.
 // Such a `vector.contract` is representative of the code we aim to generate
 // by scalable vectorisation of `linalg.mmt4d`.
-// See mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+// See mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
 // for more information and rationale about these shapes.
 //
 // In this specific test we use M == 4 and N == 4
@@ -316,7 +316,7 @@ func.func @test_usmmla() {
 
 // Test the operation where LHS is interpreted as signed and RHS is interpreted
 // as unsigned. In this test we ultimately emit and execute the `usmmla`
-// instruction with reversed operands, see `LowerContractionToSVEI8MMPattern.cpp`
+// instruction with reversed operands, see `LowerContractToSVEPatterns.cpp`
 // for more details.
 
 // CHECK-IR-LABEL: llvm.func @test_summla


