[llvm-branch-commits] [mlir] [MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16 operations (PR #148198)
    Momchil Velikov via llvm-branch-commits 
    llvm-branch-commits at lists.llvm.org
       
    Tue Jul 22 07:23:07 PDT 2025
    
    
  
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/148198
>From 14a83220d7aefdaa94bf771055fd398c273ec53b Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 11 Jul 2025 10:03:18 +0000
Subject: [PATCH 1/4] [MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16
 operations
---
 .../TransformOps/ArmNeonVectorTransformOps.td |  15 +-
 .../include/mlir/Dialect/ArmNeon/Transforms.h |   4 +-
 .../Dialect/ArmSVE/Transforms/Transforms.h    |   3 +-
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  13 +-
 .../ArmNeonVectorTransformOps.cpp             |   7 +-
 .../Dialect/ArmNeon/Transforms/CMakeLists.txt |   2 +-
 ...rn.cpp => LowerContractToNeonPatterns.cpp} | 126 +++++++---
 .../TransformOps/ArmSVEVectorTransformOps.cpp |   2 +-
 .../Transforms/LowerContractToSVEPatterns.cpp |   4 +-
 mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir  | 225 ++++++++++++++++++
 .../CPU/ArmNeon/vector-contract-bfmmla.mlir   | 176 ++++++++++++++
 .../CPU/ArmNeon/vector-contract-i8mm.mlir     |   2 +-
 12 files changed, 535 insertions(+), 44 deletions(-)
 rename mlir/lib/Dialect/ArmNeon/Transforms/{LowerContractionToNeonI8MMPattern.cpp => LowerContractToNeonPatterns.cpp} (81%)
 create mode 100644 mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
index bcaca7da967fa..35747126d3db1 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
@@ -17,8 +17,19 @@ def ApplyArmNeonContractionToI8MMPatternsOp
          "apply_patterns.arm_neon.vector_contract_to_i8mm",
          [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Indicates that vector.contract operations should be lowered to
-    finer-grained vector primitives from the ArmNeon dialect.
+    Indicates that vector contract operations should be lowered
+    to ArmNeon dialect operations mapping to instructions from FEAT_I8MM.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplyArmNeonContractionToBFMMLAPatternsOp
+    : Op<Transform_Dialect, "apply_patterns.arm_neon.vector_contract_to_bfmmla",
+         [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector contract operations should be lowered
+    to ArmNeon dialect operations mapping to instructions from FEAT_BF16.
   }];
 
   let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
index 2f0f634a96770..08065a3b25266 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
@@ -13,8 +13,8 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace arm_neon {
-void populateLowerContractionToNeonI8MMPatternPatterns(
-    RewritePatternSet &patterns);
+void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns);
+void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns);
 } // namespace arm_neon
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index de160dbf8ed94..0019192a31a02 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,8 +20,7 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
-void populateLowerContractionToSVEI8MMPatternPatterns(
-    RewritePatternSet &patterns);
+void populateLowerContractionToSVEI8MMPatterns(RewritePatternSet &patterns);
 
 void populateLowerContractionToSVEBFMMLAPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index d3d0a45eb2463..cf108690c3741 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -96,13 +96,16 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorGatherLoweringPatterns(patterns);
     if (armI8MM) {
       if (armNeon)
-        arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
+        arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
       if (armSVE)
-        populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+        populateLowerContractionToSVEI8MMPatterns(patterns);
+    }
+    if (armBF16) {
+      if (armNeon)
+        arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
+      if (armSVE)
+        populateLowerContractionToSVEBFMMLAPatterns(patterns);
     }
-    if (armBF16)
-      populateLowerContractionToSVEBFMMLAPatterns(patterns);
-
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
index d07e6a52d8b5f..d069bde6d9979 100644
--- a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
+++ b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
@@ -20,7 +20,12 @@ using namespace mlir;
 
 void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
+  arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
+}
+
+void transform::ApplyArmNeonContractionToBFMMLAPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
index 06bafde451cbb..368dacac7b835 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
 add_mlir_dialect_library(MLIRArmNeonTransforms
-  LowerContractionToNeonI8MMPattern.cpp
+  LowerContractToNeonPatterns.cpp
 
   DEPENDS
   MLIRArmNeonIncGen
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
similarity index 81%
rename from mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
rename to mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
index 59acb362191a7..5aadaece68732 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
@@ -1,4 +1,4 @@
-//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
+//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -93,15 +93,20 @@ class VectorContractRewriter {
   // multiplications.
   enum class MMLA {
     Nop,
-    Signed,      // smmla
-    Unsigned,    // ummla
-    Mixed,       // usmmla
-    MixedSwapped // usmmla with LHS and RHS swapped
+    SignedInt,   // smmla
+    UnsignedInt, // ummla
+    MixedInt,    // usmmla
+    Bfloat       // bfmmla
   };
 
   // Lower-level operation to be emitted.
   MMLA mmlaOp = MMLA::Nop;
 
+  // Indicates whether the operands of the ArmNeon dialect operation need to
+  // be swapped. Currently this is needed to emulate a signed-by-unsigned
+  // ("summla") multiply-accumulate using "usmmla".
+  bool swapOperands = false;
+
   // The operand tiles. These are not necessarily the operands of
   // `vector.contract`, for example they could be operands to `arith.extsi`
   // that is in turn fed into `vector.contract`.
@@ -126,21 +131,22 @@ class VectorContractRewriter {
   // Create the matrix multiply and accumulate operation according to `mmlaOp`.
   Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
                    Value lhs, Value rhs) {
+
+    if (swapOperands)
+      std::swap(lhs, rhs);
     switch (mmlaOp) {
-    case MMLA::Signed:
+    case MMLA::SignedInt:
       return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
                                                       lhs, rhs);
-    case MMLA::Unsigned:
+    case MMLA::UnsignedInt:
       return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
                                                       lhs, rhs);
-    case MMLA::Mixed:
+    case MMLA::MixedInt:
       return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
                                                        lhs, rhs);
-    case MMLA::MixedSwapped:
-      // The accumulator comes transposed and the result will be transposed
-      // later, so all we have to do here is swap the operands.
-      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
-                                                       rhs, lhs);
+    case MMLA::Bfloat:
+      return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs,
+                                                 rhs);
     case MMLA::Nop:
       llvm_unreachable("Uninitialized operation type");
     }
@@ -273,7 +279,7 @@ class VectorContractRewriter {
       // Transpose ACC if doing signed by unsigned multiplication, because we're
       // using the instruction for unsigned by signed multiplication with
       // reversed operands.
-      if (mmlaOp == MMLA::MixedSwapped)
+      if (swapOperands)
         tiledAcc = rewriter.create<vector::TransposeOp>(
             loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
 
@@ -302,7 +308,7 @@ class VectorContractRewriter {
 
       // Because of the reversed operands the result is obtained transposed.
       // Transpose it back,
-      if (mmlaOp == MMLA::MixedSwapped)
+      if (swapOperands)
         tiledRes = rewriter.create<vector::TransposeOp>(
             loc, tiledRes, ArrayRef<int64_t>({1, 0}));
 
@@ -339,10 +345,10 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
     // values before the extension. All four signed/unsigned combinations for
     // input operands are supported, but they are lowered to different
     // operations. Determine which is the appropriate operation to lower to.
-    mmlaOp = MMLA::Signed;
+    mmlaOp = MMLA::SignedInt;
     auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
     if (!maybeLhs) {
-      mmlaOp = MMLA::Unsigned;
+      mmlaOp = MMLA::UnsignedInt;
       maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
     }
     if (!maybeLhs)
@@ -351,11 +357,13 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
 
     auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
     if (maybeRhs) {
-      if (mmlaOp == MMLA::Unsigned)
-        mmlaOp = MMLA::Mixed;
+      if (mmlaOp == MMLA::UnsignedInt)
+        mmlaOp = MMLA::MixedInt;
     } else {
-      if (mmlaOp == MMLA::Signed)
-        mmlaOp = MMLA::MixedSwapped;
+      if (mmlaOp == MMLA::SignedInt) {
+        mmlaOp = MMLA::MixedInt;
+        swapOperands = true;
+      }
       maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
     }
 
@@ -372,16 +380,17 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
     auto lhsExtInType = cast<VectorType>(lhs.getType());
     if (lhsExtInType.getElementTypeBitWidth() < 8)
       lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
-                                 /* signExt */ mmlaOp == MMLA::Signed ||
-                                     mmlaOp == MMLA::Mixed,
+                                 /* signExt */
+                                 (mmlaOp == MMLA::SignedInt ||
+                                  (mmlaOp == MMLA::MixedInt && !swapOperands)),
                                  rewriter);
 
     auto rhsExtInType = cast<VectorType>(rhs.getType());
     if (rhsExtInType.getElementTypeBitWidth() < 8)
-
       rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
-                                 /* signExt */ mmlaOp != MMLA::Unsigned &&
-                                     mmlaOp != MMLA::Mixed,
+                                 /* signExt */
+                                 (mmlaOp == MMLA::SignedInt ||
+                                  (mmlaOp == MMLA::MixedInt && swapOperands)),
                                  rewriter);
 
     // Initialize parameters for unrolling.
@@ -395,6 +404,47 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
   }
 };
 
+// Matches bf16 x bf16 -> f32 vector.contract ops for lowering to BFMMLA.
+class VectorContractRewriterBFMMLA : public VectorContractRewriter {
+public:
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+    if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
+      return failure();
+
+    // Unrolling patterns can handle any [2, 2, 4]-shaped multiple of inputs.
+    if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
+      return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
+
+    // Check the output is a vector of Float32 elements.
+    auto outTy = dyn_cast<VectorType>(op.getResultType());
+    if (!outTy || outTy.getElementType() != rewriter.getF32Type())
+      return rewriter.notifyMatchFailure(op,
+                                         "output type is not a vector of f32");
+
+    // Check both inputs (not just the LHS) are vectors of BFloat16 elements.
+    if (op.getLhsType().getElementType() != rewriter.getBF16Type() ||
+        op.getRhsType().getElementType() != rewriter.getBF16Type())
+      return rewriter.notifyMatchFailure(op,
+                                         "input type is not a vector of bf16");
+
+    mmlaOp = MMLA::Bfloat;
+    swapOperands = false;
+    lhs = op.getLhs();
+    rhs = op.getRhs();
+    acc = op.getAcc();
+
+    // Initialize parameters for unrolling.
+    iterationBounds = *op.getShapeForUnroll();
+    if (iterationBounds.size() == 3)
+      subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 4});
+    else
+      subTileShape = SmallVector<int64_t>({2, 4});
+
+    return success();
+  }
+};
+
 /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
 /// any vector.contract into multiple smmla instructions with unrolling so long
 /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
@@ -416,10 +466,32 @@ class LowerContractionToNeonI8MMPattern
   }
 };
 
+class LowerContractionToNeonBFMMLAPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  // Rewrites `op` into BFMMLA-based code if it matches; otherwise fails.
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorContractRewriterBFMMLA bfmmlaRewriter;
+    if (LogicalResult matched = bfmmlaRewriter.matchAndInit(op, rewriter);
+        failed(matched))
+      return matched;
+    bfmmlaRewriter.lower(op, rewriter);
+    return success();
+  }
+};
+
 } // namespace
 
-void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
+void mlir::arm_neon::populateLowerContractionToNeonI8MMPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
   patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
 }
+
+void mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<LowerContractionToNeonBFMMLAPattern>(patterns.getContext(),
+                                                    /*benefit=*/2);
+}
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
index 8572c34c8b12b..d355fe010821a 100644
--- a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
@@ -20,7 +20,7 @@ using namespace mlir;
 
 void transform::ApplyArmSVELowerContractionToI8MMPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+  mlir::populateLowerContractionToSVEI8MMPatterns(patterns);
 }
 
 void transform::ApplyArmSVELowerContractionToBFMMLAPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
index f63eac91a38aa..ac1df3889ecfd 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
@@ -12,7 +12,7 @@
 // TODO: There may be opportunities to unify this with a similar pattern
 // for Neon. See:
 //   https://github.com/llvm/llvm-project/issues/145559
-//   LowerContractionToNeonI8MMPattern.cpp
+//   LowerContractToNeonPatterns.cpp
 //
 //===----------------------------------------------------------------------===//
 
@@ -580,7 +580,7 @@ class LowerContractionToSVEBFMMLAPattern
 
 } // namespace
 
-void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
+void mlir::populateLowerContractionToSVEI8MMPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
   patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
diff --git a/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir b/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir
new file mode 100644
index 0000000000000..229c4e5b2dc3a
--- /dev/null
+++ b/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir
@@ -0,0 +1,225 @@
+// RUN:  mlir-opt %s --transform-interpreter | FileCheck %s
+
+// Test lowering of vector.contract to BFMMLA operations.
+// For each iteration [I, J, K] sub-tiles are extracted from offsets as follows:
+//   LHS: [2*I, 4*K]
+//   RHS: [2*J, 4*K]
+//   ACC: [2*I, 2*J]
+// Sub-tile insert offsets for the result are the same as for ACC (there are
+// redundant inserts).
+
+// CHECK-LABEL: func.func @vector_contract_to_bfmmla
+// CHECK-SAME:    %[[LHS:.+]]: vector<4x8xbf16>, %[[RHS:.+]]: vector<4x8xbf16>, %[[ACC:.+]]: vector<4x4xf32>
+
+// %[[INIT_RES:.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+
+// Iteration [0, 0, 0]
+// Extract sub-tiles from each of LHS, RHS and ACC
+// %[[T0:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T1:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T2:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+
+// Flatten the operands to fit the `bfmmla` operation types
+// %[[T3:.+]] = vector.shape_cast %[[T0]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T4:.+]] = vector.shape_cast %[[T1]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T5:.+]] = vector.shape_cast %[[T2]] : vector<2x2xf32> to vector<4xf32>
+
+// Perform the matrix multiply and accumulate
+// %[[K_ACC_0:.+]] = arm_neon.intr.bfmmla %[[T5]], %[[T3]], %[[T4]] : vector<8xbf16> to vector<4xf32>
+
+// Un-flatten the output sub-tile and insert into the result
+// %[[T7:.+]] = vector.shape_cast %[[K_ACC_0]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_0:.+]] = vector.insert_strided_slice %[[T7]], %[[INIT_RES]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 0, 1]
+// %[[T9:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T10:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T11:.+]] = vector.shape_cast %[[T9]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T12:.+]] = vector.shape_cast %[[T10]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T13:.+]] = arm_neon.intr.bfmmla %[[K_ACC_0]], %[[T11]], %[[T12]] : vector<8xbf16> to vector<4xf32>
+// %[[T14:.+]] = vector.shape_cast %[[T13]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_1:.+]] = vector.insert_strided_slice %[[T14]], %[[TMP_RES_0]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 1, 0]
+// %[[T16:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T17:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T18:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T19:.+]] = vector.shape_cast %[[T16]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T20:.+]] = vector.shape_cast %[[T17]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T21:.+]] = vector.shape_cast %[[T18]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_1:.+]] = arm_neon.intr.bfmmla %[[T21]], %[[T19]], %[[T20]] : vector<8xbf16> to vector<4xf32>
+// %[[T23:.+]] = vector.shape_cast %[[K_ACC_1]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_2:.+]] = vector.insert_strided_slice %[[T23]], %[[TMP_RES_1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 1, 1]
+// %[[T25:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T26:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T27:.+]] = vector.shape_cast %[[T25]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T28:.+]] = vector.shape_cast %[[T26]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T29:.+]] = arm_neon.intr.bfmmla %[[K_ACC_1]], %[[T27]], %[[T28]] : vector<8xbf16> to vector<4xf32>
+// %[[T30:.+]] = vector.shape_cast %[[T29]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_3:.+]] = vector.insert_strided_slice %[[T30]], %[[TMP_RES_2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 0, 0]
+// %[[T32:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T33:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T34:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T35:.+]] = vector.shape_cast %[[T32]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T36:.+]] = vector.shape_cast %[[T33]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T37:.+]] = vector.shape_cast %[[T34]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_2:.+]] = arm_neon.intr.bfmmla %[[T37]], %[[T35]], %[[T36]] : vector<8xbf16> to vector<4xf32>
+// %[[T39:.+]] = vector.shape_cast %[[K_ACC_2]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_4:.+]] = vector.insert_strided_slice %[[T39]], %[[TMP_RES_3]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 0, 1]
+// %[[T41:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T42:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T43:.+]] = vector.shape_cast %[[T41]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T44:.+]] = vector.shape_cast %[[T42]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T45:.+]] = arm_neon.intr.bfmmla %[[K_ACC_2]], %[[T43]], %[[T44]] : vector<8xbf16> to vector<4xf32>
+// %[[T46:.+]] = vector.shape_cast %[[T45]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_5:.+]] = vector.insert_strided_slice %[[T46]], %[[TMP_RES_4]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 1, 0]
+// %[[T48:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T49:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T50:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T51:.+]] = vector.shape_cast %[[T48]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T52:.+]] = vector.shape_cast %[[T49]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T53:.+]] = vector.shape_cast %[[T50]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_3:.+]] = arm_neon.intr.bfmmla %[[T53]], %[[T51]], %[[T52]] : vector<8xbf16> to vector<4xf32>
+// %[[T55:.+]] = vector.shape_cast %[[K_ACC_3]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_6:.+]] = vector.insert_strided_slice %[[T55]], %[[TMP_RES_5]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 1, 1]
+// %[[T57:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T58:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T59:.+]] = vector.shape_cast %[[T57]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T60:.+]] = vector.shape_cast %[[T58]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T61:.+]] = arm_neon.intr.bfmmla %[[K_ACC_3]], %[[T59]], %[[T60]] : vector<8xbf16> to vector<4xf32>
+// %[[T62:.+]] = vector.shape_cast %[[T61]] : vector<4xf32> to vector<2x2xf32>
+// %[[RESULT:.+]] = vector.insert_strided_slice %[[T62]], %[[TMP_RES_6]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// return %[[RESULT]] : vector<4x4xf32>
+
+func.func @vector_contract_to_bfmmla(%lhs: vector<4x8xbf16>,
+                                     %rhs: vector<4x8xbf16>,
+                                     %acc: vector<4x4xf32>) -> vector<4x4xf32> {
+  %0 = vector.contract { indexing_maps = [
+                          affine_map<(m, n, k) -> (m, k)>,
+                          affine_map<(m, n, k) -> (n, k)>,
+                          affine_map<(m, n, k) -> (m, n)>
+                        ],
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>
+                      }
+    %lhs, %rhs, %acc : vector<4x8xbf16>, vector<4x8xbf16> into vector<4x4xf32>
+
+  return %0 : vector<4x4xf32>
+}
+
+// Test lowering of vector.contract, representing vector by matrix multiply and
+// accumulate, to BFMMLA operations.
+
+// For each iteration [J, K] sub-tiles are extracted from offsets as follows:
+//   LHS: [4*K]
+//   RHS: [2*J, 4*K]
+//   ACC: [2*J]
+// Sub-tile insert offsets for the result are the same as for ACC (there are
+// redundant inserts).
+// CHECK-LABEL: func.func @vector_contract_vecmat_to_bfmmla
+// CHECK-SAME:   %[[LHS:.+]]: vector<8xbf16>, %[[RHS:.+]]: vector<4x8xbf16>, %[[ACC:.+]]: vector<4xf32>) -> vector<4xf32> {
+// CHECK: %[[ACC_PAD_Z:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[LHS_PAD_Z:.+]] = arith.constant dense<0.000000e+00> : vector<2x4xbf16>
+// CHECK: %[[RES_INIT:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+
+// Iteration [0, 0]
+// Extract sub-tiles
+// CHECK: %[[T0:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T1:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T2:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+
+// Pad LHS sub-tile/vector with an extra row of zeroes
+// CHECK: %[[T3:.+]] = vector.insert_strided_slice %[[T0]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+
+// Pad ACC sub-tile/vector with an extra row of zeroes
+// CHECK: %[[T4:.+]] = vector.insert_strided_slice %[[T2]], %[[ACC_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<2xf32> into vector<2x2xf32>
+
+// Flatten the operands to fit the `bfmmla` operation types
+// CHECK: %[[T5:.+]] = vector.shape_cast %[[T3]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T6:.+]] = vector.shape_cast %[[T1]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T7:.+]] = vector.shape_cast %[[T4]] : vector<2x2xf32> to vector<4xf32>
+
+// Perform the matrix multiply and accumulate
+// CHECK: %[[K_ACC_0:.+]] = arm_neon.intr.bfmmla %[[T7]], %[[T5]], %[[T6]] : vector<8xbf16> to vector<4xf32>
+
+// Un-flatten the output sub-tile
+// CHECK: %[[T9:.+]] = vector.shape_cast %[[K_ACC_0]] : vector<4xf32> to vector<2x2xf32>
+
+// Extract the first rows (the second row is padding) and insert into the result
+// CHECK: %[[T10:.+]] = vector.extract %[[T9]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_0:.+]] = vector.insert_strided_slice %[[T10]], %[[RES_INIT]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [0, 1]
+// CHECK: %[[T12:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T13:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T14:.+]] = vector.insert_strided_slice %[[T12]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T15:.+]] = vector.shape_cast %[[T14]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T16:.+]] = vector.shape_cast %[[T13]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T17:.+]] = arm_neon.intr.bfmmla %[[K_ACC_0]], %[[T15]], %[[T16]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T18:.+]] = vector.shape_cast %[[T17]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T19:.+]] = vector.extract %[[T18]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_1:.+]] = vector.insert_strided_slice %[[T19]], %[[TMP_RES_0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [1, 0]
+// CHECK: %[[T21:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T22:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T23:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: %[[T24:.+]] = vector.insert_strided_slice %[[T21]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T25:.+]] = vector.insert_strided_slice %[[T23]], %[[ACC_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<2xf32> into vector<2x2xf32>
+// CHECK: %[[T26:.+]] = vector.shape_cast %[[T24]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T27:.+]] = vector.shape_cast %[[T22]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T28:.+]] = vector.shape_cast %[[T25]] : vector<2x2xf32> to vector<4xf32>
+// CHECK: %[[K_ACC_1:.+]] = arm_neon.intr.bfmmla %[[T28]], %[[T26]], %[[T27]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T30:.+]] = vector.shape_cast %[[K_ACC_1]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T31:.+]] = vector.extract %[[T30]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_2:.+]] = vector.insert_strided_slice %[[T31]], %[[TMP_RES_1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [1, 1]
+// CHECK: %[[T33:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T34:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T35:.+]] = vector.insert_strided_slice %[[T33]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T36:.+]] = vector.shape_cast %[[T35]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T37:.+]] = vector.shape_cast %[[T34]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T38:.+]] = arm_neon.intr.bfmmla %[[K_ACC_1]], %[[T36]], %[[T37]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T39:.+]] = vector.shape_cast %[[T38]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T40:.+]] = vector.extract %[[T39]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[RESULT:.+]] = vector.insert_strided_slice %[[T40]], %[[TMP_RES_2]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: return %[[RESULT]] : vector<4xf32>
+func.func @vector_contract_vecmat_to_bfmmla(%lhs: vector<8xbf16>,
+                                            %rhs: vector<4x8xbf16>,
+                                            %acc: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.contract { indexing_maps = [
+                          affine_map<(n, k) -> (k)>,
+                          affine_map<(n, k) -> (n, k)>,
+                          affine_map<(n, k) -> (n)>
+                        ],
+                        iterator_types = ["parallel", "reduction"],
+                        kind = #vector.kind<add>
+                      }
+    %lhs, %rhs, %acc : vector<8xbf16>, vector<4x8xbf16> into vector<4xf32>
+
+  return %0 : vector<4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+    transform.apply_patterns to %func {
+      transform.apply_patterns.arm_neon.vector_contract_to_bfmmla
+    } : !transform.op<"func.func">
+
+    transform.yield
+  }
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
new file mode 100644
index 0000000000000..b62ae040f364b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
@@ -0,0 +1,176 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-neon enable-arm-bf16' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  \
+// DEFINE:   --lower-affine --convert-arith-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+bf16" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+
+//
+// Test the lowering of `vector.contract` using the `LowerContractionToNeonBFMMLAPattern`
+//
+// The operation that the `vector.contract` in this test performs is matrix
+// multiplication with accumulate
+//     OUT = ACC + LHS * RHS
+// of two BFloat16 matrices LHS and RHS, and a Float32 matrix ACC into a Float32 OUT.
+//
+// Tested are calculations as well as that the relevant `ArmNeon` dialect
+// operation (`arm_neon.intr.bfmmla`) is emitted.
+//
+// The pattern above handles (and therefore this test prepares) input/output
+// vectors with specific shapes:
+//   * LHS:      vector<MxKxbf16>
+//   * RHS:      vector<NxKxbf16>
+//   * ACC, OUT: vector<MxNxf32>
+// where M and N are even and K is divisible by 4.
+// Note that the RHS is transposed.
+// This data layout makes it efficient to load data into SIMD
+// registers in the layout expected by the BFMMLA instruction.
+// Such a `vector.contract` is representative of the code we aim to generate
+// by vectorisation of `linalg.mmt4d`.
+//
+// In this specific test we use M == 4, N == 4, and K == 4.
+
+// CHECK-IR-LABEL: llvm.func @matrix_by_matrix_mul_and_acc
+// CHECK-IR-COUNT-4: arm_neon.intr.bfmmla
+func.func @matrix_by_matrix_mul_and_acc() {
+  // Computes ACC + LHS * RHS^T on fixed 4x4 data and prints the 4x4 result row by row.
+  %c0 = arith.constant 0 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %c0_bf16 = arith.constant 0.0 : bf16
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[ 0.7,  1.0, -0.1,  1.8],
+                                   [-0.5,  0.9,  0.7, -0.7],
+                                   [ 0.5, -1.3, -2.2,  0.1],
+                                   [-0.7,  1.0,  1.7, -1.0]]> : vector<4x4xf32>
+
+  %acc_mem = memref.alloc() : memref<4x4xf32>
+  vector.transfer_write %acc_cst, %acc_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xf32>, memref<4x4xf32>
+  %acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_f32 {in_bounds = [true, true]} : memref<4x4xf32>, vector<4x4xf32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[ 0.1,  0.7, -0.9,  1.3],
+                                   [-1.6,  0.7, -0.3, -0.3],
+                                   [-0.4,  0.6,  0.8, -0.5],
+                                   [-0.6, -1.0, -1.0, -1.0]]> : vector<4x4xbf16>
+
+  %lhs_mem = memref.alloc() : memref<4x4xbf16>
+  vector.transfer_write %lhs_cst, %lhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+  %lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[ 0.6,  1.3,  0.1, -0.9],
+                                   [ 0.5,  1.6,  1.8,  1.6],
+                                   [-0.2,  0.4,  1.0,  0.4],
+                                   [-1.3, -0.2, -2.2,  0.3]]> : vector<4x4xbf16>
+
+  %rhs_mem = memref.alloc() : memref<4x4xbf16>
+  vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+  %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+  // Matrix multiplication and accumulate with transposed RHS.
+  %0 = vector.contract {indexing_maps = #packed_maps,  // LHS (m, k), RHS (n, k), ACC (m, n)
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<4x4xbf16>, vector<4x4xbf16> into vector<4x4xf32>
+
+  // Display the result of the multiplication
+  vector.print str "Result(BFMMLA):\n"
+  %u0 = vector.extract %0[0] : vector<4xf32> from vector<4x4xf32>
+  %u1 = vector.extract %0[1] : vector<4xf32> from vector<4x4xf32>
+  %u2 = vector.extract %0[2] : vector<4xf32> from vector<4x4xf32>
+  %u3 = vector.extract %0[3] : vector<4xf32> from vector<4x4xf32>
+  vector.print %u0 : vector<4xf32>
+  vector.print %u1 : vector<4xf32>
+  vector.print %u2 : vector<4xf32>
+  vector.print %u3 : vector<4xf32>
+
+  return
+}
+
+// Test when the LHS is a one-dimensional vector.
+//
+// In the vector-by-matrix case the shapes are as follows:
+//   * LHS:      vector<Kxbf16>
+//   * RHS:      vector<NxKxbf16>
+//   * ACC, OUT: vector<Nxf32>
+// N is even and K is divisible by 4.
+// In this specific test we use N == 4, and K == 4.
+
+// CHECK-IR-LABEL: llvm.func @vector_by_matrix_mul_and_acc
+// CHECK-IR-COUNT-2: arm_neon.intr.bfmmla
+func.func @vector_by_matrix_mul_and_acc() {
+  %c0 = arith.constant 0 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %c0_bf16 = arith.constant 0.0 : bf16
+  // Data below: row 0 of the matrix test's ACC and LHS, and the same RHS.
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[0.7,  1.0, -0.1,  1.8]> : vector<4xf32>
+
+  %acc_mem = memref.alloc() : memref<4xf32>
+  vector.transfer_write %acc_cst, %acc_mem[%c0] {in_bounds = [true] } : vector<4xf32>, memref<4xf32>
+  %acc = vector.transfer_read %acc_mem[%c0], %c0_f32 {in_bounds = [true]} : memref<4xf32>, vector<4xf32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[0.1,  0.7, -0.9,  1.3]> : vector<4xbf16>
+
+  %lhs_mem = memref.alloc() : memref<4xbf16>
+  vector.transfer_write %lhs_cst, %lhs_mem[%c0] {in_bounds = [true] } : vector<4xbf16>, memref<4xbf16>
+  %lhs = vector.transfer_read %lhs_mem[%c0], %c0_bf16 {in_bounds = [true]} : memref<4xbf16>, vector<4xbf16>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[ 0.6,  1.3,  0.1, -0.9],
+                                   [ 0.5,  1.6,  1.8,  1.6],
+                                   [-0.2,  0.4,  1.0,  0.4],
+                                   [-1.3, -0.2, -2.2,  0.3]]> : vector<4x4xbf16>
+
+  %rhs_mem = memref.alloc() : memref<4x4xbf16>
+  vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+  %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+  // Vector by matrix multiplication and accumulate with transposed RHS.
+  %0 = vector.contract { indexing_maps = [  // acc[n] += sum_k lhs[k] * rhs[n, k]
+                           affine_map<(n, k) -> (k)>,
+                           affine_map<(n, k) -> (n, k)>,
+                           affine_map<(n, k) -> (n)>
+                         ],
+                         iterator_types = ["parallel", "reduction"],
+                         kind = #vector.kind<add>
+                       }
+    %lhs, %rhs, %acc : vector<4xbf16>, vector<4x4xbf16> into vector<4xf32>
+
+  // Display the result of the multiplication
+  vector.print str "Result(BFMMLA, vecmat):\n"
+  vector.print %0 : vector<4xf32>
+
+  return
+}
+
+func.func @main() {  // entry point selected via -e in the run command
+  // CHECK-LABEL: Result(BFMMLA):
+  // CHECK: (  0.411922, 2.63254,  -0.219259,  3.89965 )
+  // CHECK: ( -0.316515, 0.196875,  0.879375,  1.80924 )
+  // CHECK: (  1.56867,  0.101367, -1.2784,   -1.41579 )
+  // CHECK: ( -1.56041, -4.30078,   0.0196488, 1.88269 )
+  func.call @matrix_by_matrix_mul_and_acc() : () -> ()
+  // The vecmat test reuses the RHS and row 0 of the LHS/ACC above, so it prints row 0 again.
+  // CHECK-LABEL: Result(BFMMLA, vecmat):
+  // CHECK: ( 0.411922, 2.63254, -0.219259, 3.89965 )
+  func.call @vector_by_matrix_mul_and_acc() : () -> ()
+
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
index 1ce55ca05c90e..f6012bbd3d0b2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
@@ -240,7 +240,7 @@ func.func @test_usmmla() {
 
 // Test the operation where LHS is interpreted as signed and RHS is interpreted
 // as unsigned. In this test we ultimately emit end execute the `usmmla`
-// instruction with reversed operands, see `LowerContractionToNeonI8MMPattern.cpp`
+// instruction with reversed operands, see `LowerContractToNeonPatterns.cpp`
 // for more details.
 
 // CHECK-IR-LABEL: llvm.func @test_summla
>From 6501a5b67a802a7ed0e13805156bbf8de19de25b Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 21 Jul 2025 11:16:02 +0000
Subject: [PATCH 2/4] [fixup] Rename a member function and change some allocs
 to allocas
---
 .../Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir   | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
index b62ae040f364b..9acc97da0d53c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
@@ -58,7 +58,7 @@ func.func @matrix_by_matrix_mul_and_acc() {
                                    [ 0.5, -1.3, -2.2,  0.1],
                                    [-0.7,  1.0,  1.7, -1.0]]> : vector<4x4xf32>
 
-  %acc_mem = memref.alloc() : memref<4x4xf32>
+  %acc_mem = memref.alloca() : memref<4x4xf32>
   vector.transfer_write %acc_cst, %acc_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xf32>, memref<4x4xf32>
   %acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_f32 {in_bounds = [true, true]} : memref<4x4xf32>, vector<4x4xf32>
 
@@ -68,7 +68,7 @@ func.func @matrix_by_matrix_mul_and_acc() {
                                    [-0.4,  0.6,  0.8, -0.5],
                                    [-0.6, -1.0, -1.0, -1.0]]> : vector<4x4xbf16>
 
-  %lhs_mem = memref.alloc() : memref<4x4xbf16>
+  %lhs_mem = memref.alloca() : memref<4x4xbf16>
   vector.transfer_write %lhs_cst, %lhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
   %lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
 
@@ -78,7 +78,7 @@ func.func @matrix_by_matrix_mul_and_acc() {
                                    [-0.2,  0.4,  1.0,  0.4],
                                    [-1.3, -0.2, -2.2,  0.3]]> : vector<4x4xbf16>
 
-  %rhs_mem = memref.alloc() : memref<4x4xbf16>
+  %rhs_mem = memref.alloca() : memref<4x4xbf16>
   vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
   %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
 
@@ -121,14 +121,14 @@ func.func @vector_by_matrix_mul_and_acc() {
   // Accumulator test data
   %acc_cst = arith.constant dense<[0.7,  1.0, -0.1,  1.8]> : vector<4xf32>
 
-  %acc_mem = memref.alloc() : memref<4xf32>
+  %acc_mem = memref.alloca() : memref<4xf32>
   vector.transfer_write %acc_cst, %acc_mem[%c0] {in_bounds = [true] } : vector<4xf32>, memref<4xf32>
   %acc = vector.transfer_read %acc_mem[%c0], %c0_f32 {in_bounds = [true]} : memref<4xf32>, vector<4xf32>
 
   // LHS test data
   %lhs_cst = arith.constant dense<[0.1,  0.7, -0.9,  1.3]> : vector<4xbf16>
 
-  %lhs_mem = memref.alloc() : memref<4xbf16>
+  %lhs_mem = memref.alloca() : memref<4xbf16>
   vector.transfer_write %lhs_cst, %lhs_mem[%c0] {in_bounds = [true] } : vector<4xbf16>, memref<4xbf16>
   %lhs = vector.transfer_read %lhs_mem[%c0], %c0_bf16 {in_bounds = [true]} : memref<4xbf16>, vector<4xbf16>
 
@@ -138,7 +138,7 @@ func.func @vector_by_matrix_mul_and_acc() {
                                    [-0.2,  0.4,  1.0,  0.4],
                                    [-1.3, -0.2, -2.2,  0.3]]> : vector<4x4xbf16>
 
-  %rhs_mem = memref.alloc() : memref<4x4xbf16>
+  %rhs_mem = memref.alloca() : memref<4x4xbf16>
   vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
   %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
 
>From 5322c0f8149879235896e1fd0a15a2475ed14e1e Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 21 Jul 2025 13:01:35 +0000
Subject: [PATCH 3/4] [fixup] Add a comment about memory ops
---
 .../Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir    | 4 ++++
 1 file changed, 4 insertions(+)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
index 9acc97da0d53c..368f332e40602 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
@@ -110,6 +110,10 @@ func.func @matrix_by_matrix_mul_and_acc() {
 //   * ACC, OUT: vector<Nxf32>
 // N is even and K is divisible by 4.
 // In this specific test we use N == 4, and K == 4.
+// Note: the seemingly unnecessary writes of test vectors to memory are done
+// in order to introduce memory load operations on the path leading up to
+// `vector.contract` since that's more representation of the typical usage
+// when multiplying big, non-constant tensors.
 
 // CHECK-IR-LABEL: llvm.func @vector_by_matrix_mul_and_acc
 // CHECK-IR-COUNT-2: arm_neon.intr.bfmmla
>From a95561a0c493e7e12ea0b22064da490623c02a23 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 21 Jul 2025 13:10:58 +0000
Subject: [PATCH 4/4] [fixup] Move a comment, it was accidentally in the wrong
 place
---
 .../Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir     | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
index 368f332e40602..4285260906251 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
@@ -44,6 +44,12 @@
 //
 // In this specific test we use M == 4, N == 4, and K == 4.
 
+// Note: In this and in the following test the seemingly unnecessary
+// writes of test vectors to memory are done in order to introduce memory
+// load operations on the path leading up to `vector.contract` since
+// that's more representative of the typical usage when multiplying
+// big, non-constant tensors.
+
 // CHECK-IR-LABEL: llvm.func @matrix_by_matrix_mul_and_acc
 // CHECK-IR-COUNT-4: arm_neon.intr.bfmmla
 func.func @matrix_by_matrix_mul_and_acc() {
@@ -110,10 +116,6 @@ func.func @matrix_by_matrix_mul_and_acc() {
 //   * ACC, OUT: vector<Nxf32>
 // N is even and K is divisible by 4.
 // In this specific test we use N == 4, and K == 4.
-// Note: the seemingly unnecessary writes of test vectors to memory are done
-// in order to introduce memory load operations on the path leading up to
-// `vector.contract` since that's more representation of the typical usage
-// when multiplying big, non-constant tensors.
 
 // CHECK-IR-LABEL: llvm.func @vector_by_matrix_mul_and_acc
 // CHECK-IR-COUNT-2: arm_neon.intr.bfmmla
    
    
More information about the llvm-branch-commits
mailing list