[Mlir-commits] [mlir] [MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16 operations (PR #148198)

Momchil Velikov llvmlistbot at llvm.org
Fri Jul 11 03:36:58 PDT 2025


https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/148198

This is split in two commits:
 * refactor I8MM lowering to make it easier to add ...
 * ... BF16 lowering


>From 5ec628726772f32ad2cb4ccfbbf0a43532562fca Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 10 Jul 2025 11:10:19 +0000
Subject: [PATCH 1/2] [MLIR][AArch64] Refactor lowering of vector.contract to
 Neon I8MM

This patch refactors the pattern in `Transforms/LowerContractionToNeonI8MMPattern.cpp`
using a similar approach to https://github.com/llvm/llvm-project/pull/147052
to prepare for adding BF16 support.
---
 .../LowerContractionToNeonI8MMPattern.cpp     | 431 ++++++++++--------
 1 file changed, 247 insertions(+), 184 deletions(-)

diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
index 7180884c77e98..f961d3cafb443 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
@@ -33,23 +33,15 @@ using namespace mlir;
 using namespace mlir::arm_neon;
 
 namespace {
-
-/// Return the shaped type with new element type.
-static Type matchContainerType(Type element, Type container) {
-  if (auto shapedTy = dyn_cast<ShapedType>(container)) {
-    return shapedTy.clone(element);
-  }
-  return element;
-}
-
-// Get the operand of a `vector.contract`. This function is intended to abstract
-// away from the particular way a value is extended before feeding it into the
-// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
-// (for implicit sign-extension see `vector.contract` documentation).
-//
-// The template parameter `Op` indicates the extension operation (explicit or
-// implicit) for which we are checking.
-//
+/// Get the operand of a `vector.contract`. This function is intended to
+/// abstract away from the particular way a value is extended before feeding it
+/// into the `vector.contract` - via zero-extend or an explicit or implicit
+/// sign-extend (for implicit sign-extension see `vector.contract`
+/// documentation).
+///
+/// The template parameter `Op` indicates the extension operation (explicit or
+/// implicit) for which we are checking.
+///
 // Return success only for extensions from `iN` (N <= 8) to `i32`.
 template <typename Op>
 std::optional<Value> getExtOperand(Value v) {
@@ -87,202 +79,186 @@ std::optional<Value> getExtOperand(Value v) {
   return inOp;
 }
 
-// Designate the operation (resp. instruction) used to do sub-tile matrix
-// multiplications.
-enum class MMLA {
-  Signed,      // smmla
-  Unsigned,    // ummla
-  Mixed,       // usmmla
-  MixedSwapped // usmmla with LHS and RHS swapped
-};
+/// Helper function to extend a vector with elements iN, N < 8 to
+/// a vector of i8. Do sign extension if the parameter `signExt` is true,
+/// zero extension otherwise.
+Value extendSmallIntVector(Location loc, VectorType srcTy, Value val,
+                           bool signExt, PatternRewriter &rewriter) {
+  Type targetTy = srcTy.clone(rewriter.getI8Type());
+  return signExt ? rewriter.createOrFold<arith::ExtSIOp>(loc, targetTy, val)
+                 : rewriter.createOrFold<arith::ExtUIOp>(loc, targetTy, val);
+}
 
-// Create the matrix mulitply and accumulate operation according to `op`.
-Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
-                 mlir::Type accType, Value acc, Value lhs, Value rhs) {
-  switch (op) {
-  case MMLA::Signed:
-    return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
-                                                    rhs);
-  case MMLA::Unsigned:
-    return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
-                                                    rhs);
-  case MMLA::Mixed:
-    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
-                                                     rhs);
-  case MMLA::MixedSwapped:
-    // The accumulator comes transposed and the result will be transposed
-    // later, so all we have to do here is swap the operands.
-    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
-                                                     lhs);
+class VectorContractRewriter {
+protected:
+  // Designate the operation (resp. instruction) used to do sub-tile matrix
+  // multiplications.
+  enum class MMLA {
+    Nop,
+    Signed,      // smmla
+    Unsigned,    // ummla
+    Mixed,       // usmmla
+    MixedSwapped // usmmla with LHS and RHS swapped
+  };
+
+  // Lower-level operation to be emitted.
+  MMLA mmlaOp = MMLA::Nop;
+
+  // The operand tiles. These are not necessarily the operands of
+  // `vector.contract`, for example they could be operands to `arith.extsi`
+  // that is in turn fed into `vector.contract`.
+  Value lhs;
+  Value rhs;
+  Value acc;
+
+  // The dimensions logically corresponding to matrix multiplication of
+  // MxK * KxN -> MxN. The operands and the result do not necessarily have these
+  // shapes, for example RHS could be NxK with a transposing indexing map.
+  int64_t dimM = 0;
+  int64_t dimN = 0;
+  int64_t dimK = 0;
+
+  // Unroll iteration bounds. See documentation for `StaticTileOffsetRange`.
+  SmallVector<int64_t> iterationBounds;
+
+  // Sub-tile shape. The algorithm handles operand shapes, which are multiples
+  // of this shape.
+  SmallVector<int64_t> subTileShape;
+
+  // Create the matrix multiply and accumulate operation according to `mmlaOp`.
+  Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+                   Value lhs, Value rhs) {
+    switch (mmlaOp) {
+    case MMLA::Signed:
+      return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
+                                                      lhs, rhs);
+    case MMLA::Unsigned:
+      return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
+                                                      lhs, rhs);
+    case MMLA::Mixed:
+      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+                                                       lhs, rhs);
+    case MMLA::MixedSwapped:
+      // The accumulator comes transposed and the result will be transposed
+      // later, so all we have to do here is swap the operands.
+      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+                                                       rhs, lhs);
+    case MMLA::Nop:
+      llvm_unreachable("Uninitialized operation type");
+    }
   }
-}
 
-/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
-/// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
-/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
-/// necessary, a single smmla instruction is emitted.
-class LowerContractionToNeonI8MMPattern
-    : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
-    // Note: RHS is not transposed.
-    mlir::VectorType lhsType = op.getLhsType();
-    mlir::VectorType rhsType = op.getRhsType();
+  // Check common preconditions for applying the patterns and initialize
+  // logical dimensions.
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+    // Check iterator types for matrix multiplication.
+    SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
+    if (!((itTypes.size() == 3 &&
+           (itTypes[0] == vector::IteratorType::parallel &&
+            itTypes[1] == vector::IteratorType::parallel &&
+            itTypes[2] == vector::IteratorType::reduction)) ||
+          (itTypes.size() == 2 &&
+           (itTypes[0] == vector::IteratorType::parallel &&
+            itTypes[1] == vector::IteratorType::reduction))))
+      return rewriter.notifyMatchFailure(
+          op, "iterator types do not correspond to matrix multiplication");
+
     // Avoid 0-D vectors and 1-D rhs:
-    if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
-      return failure();
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+    if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
+        rhsType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Invalid operand rank");
+
     // This codegen does not work for scalable vectors. Return failure so this
     // pattern is not accidentally chosen over patterns that lower to ArmSVE.
     if (lhsType.isScalable() || rhsType.isScalable())
-      return failure();
-    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
-    auto dimN = rhsType.getDimSize(0);
-    auto dimK = rhsType.getDimSize(1);
-    bool isVecmat = dimM == 1 ? true : false;
-    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
-        rhsType.getDimSize(rhsType.getRank() - 1)) {
-      return failure(); // dimK mismatch
-    }
-    // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
-    // tiling.
-    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
-      return failure();
-    }
-
-    // Check iterator types for contract. All iterators except inner-most
-    // dimension must be parallel.
-    auto iteratorTypes = op.getIteratorTypesArray();
-    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
-                                        vector::IteratorType::reduction) {
-      return failure();
-    }
-    if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
-                     [](vector::IteratorType iteratorType) {
-                       return iteratorType != vector::IteratorType::parallel;
-                     })) {
-      return failure();
+      return rewriter.notifyMatchFailure(op,
+                                         "Not applicable to scalable vectors");
+
+    // Initialize dimensions and check for a matching K dimension.
+    dimM = lhsType.getDimSize(0);
+    dimN = rhsType.getDimSize(0);
+    dimK = rhsType.getDimSize(1);
+
+    int64_t lhsDimK;
+    if (lhsType.getRank() == 1) {
+      dimM = 1;
+      lhsDimK = lhsType.getDimSize(0);
+    } else {
+      lhsDimK = lhsType.getDimSize(1);
     }
 
-    // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
-    // values before the extension. All four signed/unsigned combinations for
-    // input operands are supported, but they are lowered to different
-    // operations. Determine which is the appropriate operation to lower to.
-    MMLA mmlaOp = MMLA::Signed;
-    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
-    if (!maybeLhs) {
-      mmlaOp = MMLA::Unsigned;
-      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
-    }
-    if (!maybeLhs)
-      return failure();
+    if (lhsDimK != dimK)
+      return rewriter.notifyMatchFailure(op, "Dimensions mismatch");
 
-    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
-    if (maybeRhs) {
-      if (mmlaOp == MMLA::Unsigned)
-        mmlaOp = MMLA::Mixed;
-    } else {
-      if (mmlaOp == MMLA::Signed)
-        mmlaOp = MMLA::MixedSwapped;
-      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
-    }
-    if (!maybeRhs)
-      return failure();
+    return success();
+  }
 
-    Value origLhs = *maybeLhs;
-    Value origRhs = *maybeRhs;
-
-    // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
-    // following neon instruction. Check inputs for extsi are <=i8
-    Value extLhs;
-    Value extRhs;
-    if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
-      if (lhsExtInType.getElementTypeBitWidth() <= 8) {
-        Type targetLhsExtTy =
-            matchContainerType(rewriter.getI8Type(), lhsExtInType);
-        if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
-          extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
-                                                         origLhs);
-        else
-          extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
-                                                         origLhs);
-      }
-    }
-    if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
-      if (rhsExtInType.getElementTypeBitWidth() <= 8) {
-        Type targetRhsExtTy =
-            matchContainerType(rewriter.getI8Type(), rhsExtInType);
-        if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
-          extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
-                                                         origRhs);
-        else
-          extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
-                                                         origRhs);
-      }
-    }
+public:
+  void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
+    // Create some convenience types.
+    auto inputElementType = cast<ShapedType>(lhs.getType()).getElementType();
+    auto accElementType = cast<ShapedType>(acc.getType()).getElementType();
+    auto inputExpandedType =
+        VectorType::get({2, subTileShape.back()}, inputElementType);
+    auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+
+    // One-dimensional representation of logical sub-tiles as required by the
+    // ArmNeon ops.
+    auto collapsedInputType =
+        VectorType::get(inputExpandedType.getNumElements(), inputElementType);
+    auto collapsedOutputType =
+        VectorType::get(outputExpandedType.getNumElements(), accElementType);
+
+    // Get indexing maps for a more concise/convenient access.
+    auto indexingMaps = op.getIndexingMapsArray();
+    AffineMap &lhsPermutationMap = indexingMaps[0];
+    AffineMap &rhsPermutationMap = indexingMaps[1];
+    AffineMap &accPermutationMap = indexingMaps[2];
 
-    if (!extLhs || !extRhs) {
-      return failure();
-    }
+    Location loc = op.getLoc();
 
     // Initial accumulator for the final result. This is the un-tiled result if
     // tiling is done.
     Value result = rewriter.create<arith::ConstantOp>(
         loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
 
-    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
-    SmallVector<int64_t> smmlaShape = {2, 8};
-    SmallVector<int64_t> loopOrder = {0, 1};
-    if (unrolledSize.size() == 3) {
-      smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
+    SmallVector<int64_t, 3> loopOrder = {0, 1};
+    if (iterationBounds.size() == 3)
       loopOrder.push_back(2);
-    }
 
     // Keep track of the previous accumulator when tiling over K.
     Value kAcc;
     for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+         StaticTileOffsetRange(iterationBounds, subTileShape, loopOrder)) {
       // Helper to compute the new shape of each operand and extract the slice.
       auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                 ArrayRef<int64_t> operandOffsets) {
-        SmallVector<int64_t> operandShape =
-            applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
+        SmallVector<int64_t> operandShape = applyPermutationMap(
+            permutationMap, ArrayRef<int64_t>(subTileShape));
         SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
         return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
             loc, operand, operandOffsets, operandShape, operandStrides);
       };
 
       // Extract tiled lhs, rhs, and acc
-      AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
       SmallVector<int64_t> lhsOffsets =
           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets);
-      AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
+      Value tiledLhs = extractOperand(lhs, lhsPermutationMap, lhsOffsets);
       SmallVector<int64_t> rhsOffsets =
           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets);
-      AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
+      Value tiledRhs = extractOperand(rhs, rhsPermutationMap, rhsOffsets);
       SmallVector<int64_t> accOffsets =
           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledAcc =
-          extractOperand(op.getAcc(), accPermutationMap, accOffsets);
-
-      auto inputElementType =
-          cast<ShapedType>(tiledLhs.getType()).getElementType();
-      auto accElementType =
-          cast<ShapedType>(tiledAcc.getType()).getElementType();
-      auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
-      auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+      Value tiledAcc = extractOperand(acc, accPermutationMap, accOffsets);
 
       // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
-      // rows along dimM. Expand their shapes to match the smmla op.
-      if (isVecmat) {
-        auto expandForSMMLA = [&](Value tiledOperand,
-                                  VectorType expandedTypeType) {
+      // rows along dimM. Expand their shapes to match the ArmNeon op.
+      if (dimM == 1) {
+        auto expandRowVector = [&](Value tiledOperand,
+                                   VectorType expandedTypeType) {
           auto emptyOperand = rewriter.create<arith::ConstantOp>(
               loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
           SmallVector<int64_t> offsets(
@@ -292,8 +268,8 @@ class LowerContractionToNeonI8MMPattern
           return rewriter.createOrFold<vector::InsertStridedSliceOp>(
               loc, tiledOperand, emptyOperand, offsets, strides);
         };
-        tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
-        tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
+        tiledLhs = expandRowVector(tiledLhs, inputExpandedType);
+        tiledAcc = expandRowVector(tiledAcc, outputExpandedType);
       }
 
       // Transpose ACC if doing signed by unsigned multiplication, because we're
@@ -303,15 +279,11 @@ class LowerContractionToNeonI8MMPattern
         tiledAcc = rewriter.create<vector::TransposeOp>(
             loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
 
-      // Collapse tiled operands to 1D vectors required by smmla intrinsic
-      auto collapsedInputType =
-          VectorType::get(inputExpandedType.getNumElements(), inputElementType);
+      // Collapse tiled operands to 1D vectors required by the ArmNeon ops
       auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledLhs.getLoc(), collapsedInputType, tiledLhs);
       auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledRhs.getLoc(), collapsedInputType, tiledRhs);
-      auto collapsedOutputType =
-          VectorType::get(outputExpandedType.getNumElements(), accElementType);
 
       bool initialKAcc = offsets.back() == 0;
       Value collapsedRes;
@@ -323,8 +295,8 @@ class LowerContractionToNeonI8MMPattern
       }
 
       // Insert contract op
-      kAcc = createMMLA(rewriter, mmlaOp, op.getLoc(), collapsedRes.getType(),
-                        collapsedRes, collapsedLhs, collapsedRhs);
+      kAcc =
+          createMMLA(rewriter, loc, collapsedRes, collapsedLhs, collapsedRhs);
 
       // Reshape output back to 2D
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
@@ -338,9 +310,8 @@ class LowerContractionToNeonI8MMPattern
 
       // With vecmat, only one row of tiled ACC can be inserted into the final
       // result
-      if (isVecmat) {
+      if (dimM == 1)
         tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
-      }
 
       // Insert the tiled result back into the non tiled result of the
       // contract op.
@@ -351,6 +322,98 @@ class LowerContractionToNeonI8MMPattern
     }
 
     rewriter.replaceOp(op, result);
+  }
+};
+
+class VectorContractRewriterI8MM : public VectorContractRewriter {
+public:
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+    if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
+      return failure();
+
+    // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
+    // tiling.
+    if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0)
+      return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
+
+    // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
+    // values before the extension. All four signed/unsigned combinations for
+    // input operands are supported, but they are lowered to different
+    // operations. Determine which is the appropriate operation to lower to.
+    mmlaOp = MMLA::Signed;
+    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::Unsigned;
+      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+    }
+    if (!maybeLhs)
+      return rewriter.notifyMatchFailure(
+          op, "LHS is not a sign- or zero- extended iN, N <= 8");
+
+    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::Unsigned)
+        mmlaOp = MMLA::Mixed;
+    } else {
+      if (mmlaOp == MMLA::Signed)
+        mmlaOp = MMLA::MixedSwapped;
+      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
+    }
+
+    if (!maybeRhs)
+      return rewriter.notifyMatchFailure(
+          op, "RHS is not a sign- or zero- extended iN, N <= 8");
+
+    lhs = *maybeLhs;
+    rhs = *maybeRhs;
+    acc = op.getAcc();
+
+    // Extend inputs from iN, N < 8 to i8.
+    Location loc = op.getLoc();
+    auto lhsExtInType = cast<VectorType>(lhs.getType());
+    if (lhsExtInType.getElementTypeBitWidth() < 8)
+      lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
+                                 /* signExt */ mmlaOp == MMLA::Signed ||
+                                     mmlaOp == MMLA::Mixed,
+                                 rewriter);
+
+    auto rhsExtInType = cast<VectorType>(rhs.getType());
+    if (rhsExtInType.getElementTypeBitWidth() < 8)
+
+      rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
+                                 /* signExt */ mmlaOp != MMLA::Unsigned &&
+                                     mmlaOp != MMLA::Mixed,
+                                 rewriter);
+
+    // Initialize parameters for unrolling.
+    iterationBounds = *op.getShapeForUnroll();
+    if (iterationBounds.size() == 3)
+      subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 8});
+    else
+      subTileShape = SmallVector<int64_t>({2, 8});
+
+    return success();
+  }
+};
+
+/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
+/// any vector.contract into multiple smmla instructions with unrolling so long
+/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
+/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
+/// necessary, a single smmla instruction is emitted.
+class LowerContractionToNeonI8MMPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    VectorContractRewriterI8MM vcr;
+    if (failed(vcr.matchAndInit(op, rewriter)))
+      return failure();
+    vcr.rewrite(op, rewriter);
+
     return success();
   }
 };

>From 442e29ac0b3ad54aa521ff02029889fa899a9c75 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 11 Jul 2025 10:03:18 +0000
Subject: [PATCH 2/2] [MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16
 operations

---
 mlir/include/mlir/Conversion/Passes.td        |   4 +
 .../TransformOps/ArmNeonVectorTransformOps.td |  15 +-
 .../include/mlir/Dialect/ArmNeon/Transforms.h |   4 +-
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   4 +-
 .../ArmNeonVectorTransformOps.cpp             |   7 +-
 .../Dialect/ArmNeon/Transforms/CMakeLists.txt |   2 +-
 ...rn.cpp => LowerContractToNeonPatterns.cpp} | 126 +++++++---
 .../LowerContractionToSVEI8MMPattern.cpp      |   2 +-
 mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir  | 225 ++++++++++++++++++
 .../CPU/ArmNeon/vector-contract-bfmmla.mlir   | 176 ++++++++++++++
 .../CPU/ArmNeon/vector-contract-i8mm.mlir     |   2 +-
 11 files changed, 531 insertions(+), 36 deletions(-)
 rename mlir/lib/Dialect/ArmNeon/Transforms/{LowerContractionToNeonI8MMPattern.cpp => LowerContractToNeonPatterns.cpp} (81%)
 create mode 100644 mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5a864865adffc..4f304b39a0528 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1437,6 +1437,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of Arm FEAT_I8MM instructions while lowering "
            "the vector dialect.">,
+    Option<"armBF16", "enable-arm-bf16",
+           "bool", /*default=*/"false",
+           "Enables the use of Arm FEAT_BF16 instructions while lowering "
+           "the vector dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
index bcaca7da967fa..35747126d3db1 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
@@ -17,8 +17,19 @@ def ApplyArmNeonContractionToI8MMPatternsOp
          "apply_patterns.arm_neon.vector_contract_to_i8mm",
          [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Indicates that vector.contract operations should be lowered to
-    finer-grained vector primitives from the ArmNeon dialect.
+    Indicates that vector contract operations should be lowered to
+    ArmNeon dialect operations mapping to instructions from FEAT_I8MM.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplyArmNeonContractionToBFMMLAPatternsOp
+    : Op<Transform_Dialect, "apply_patterns.arm_neon.vector_contract_to_bfmmla",
+         [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector contract operations should be lowered to
+    ArmNeon dialect operations mapping to instructions from FEAT_BF16.
   }];
 
   let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
index 2f0f634a96770..08065a3b25266 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
@@ -13,8 +13,8 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace arm_neon {
-void populateLowerContractionToNeonI8MMPatternPatterns(
-    RewritePatternSet &patterns);
+void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns);
+void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns);
 } // namespace arm_neon
 
 } // namespace mlir
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 549d0210af7ad..1045824c437ab 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -84,10 +84,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorGatherLoweringPatterns(patterns);
     if (armI8MM) {
       if (armNeon)
-        arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
+        arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
       if (armSVE)
         populateLowerContractionToSVEI8MMPatternPatterns(patterns);
     }
+    if (armBF16 && armNeon)
+      arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
index d07e6a52d8b5f..d069bde6d9979 100644
--- a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
+++ b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
@@ -20,7 +20,12 @@ using namespace mlir;
 
 void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
+  arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
+}
+
+void transform::ApplyArmNeonContractionToBFMMLAPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
index 06bafde451cbb..368dacac7b835 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
 add_mlir_dialect_library(MLIRArmNeonTransforms
-  LowerContractionToNeonI8MMPattern.cpp
+  LowerContractToNeonPatterns.cpp
 
   DEPENDS
   MLIRArmNeonIncGen
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
similarity index 81%
rename from mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
rename to mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
index f961d3cafb443..06746daa8075b 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
@@ -1,4 +1,4 @@
-//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
+//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -95,15 +95,20 @@ class VectorContractRewriter {
   // multiplications.
   enum class MMLA {
     Nop,
-    Signed,      // smmla
-    Unsigned,    // ummla
-    Mixed,       // usmmla
-    MixedSwapped // usmmla with LHS and RHS swapped
+    SignedInt,   // smmla
+    UnsignedInt, // ummla
+    MixedInt,    // usmmla
+    Bfloat       // bfmmla
   };
 
   // Lower-level operation to be emitted.
   MMLA mmlaOp = MMLA::Nop;
 
+  // Indicate if the operands for the ArmNeon dialect operation need to be
+  // swapped. Currently this is needed in order to emulate an "summla"
+  // operation.
+  bool swapOperands = false;
+
   // The operand tiles. These are not necessarily the operands of
   // `vector.contract`, for example they could be operands to `arith.extsi`
   // that is in turn fed into `vector.contract`.
@@ -128,21 +133,22 @@ class VectorContractRewriter {
   // Create the matrix multiply and accumulate operation according to `mmlaOp`.
   Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
                    Value lhs, Value rhs) {
+
+    if (swapOperands)
+      std::swap(lhs, rhs);
     switch (mmlaOp) {
-    case MMLA::Signed:
+    case MMLA::SignedInt:
       return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
                                                       lhs, rhs);
-    case MMLA::Unsigned:
+    case MMLA::UnsignedInt:
       return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
                                                       lhs, rhs);
-    case MMLA::Mixed:
+    case MMLA::MixedInt:
       return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
                                                        lhs, rhs);
-    case MMLA::MixedSwapped:
-      // The accumulator comes transposed and the result will be transposed
-      // later, so all we have to do here is swap the operands.
-      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
-                                                       rhs, lhs);
+    case MMLA::Bfloat:
+      return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs,
+                                                 rhs);
     case MMLA::Nop:
       llvm_unreachable("Uninitialized operation type");
     }
@@ -275,7 +281,7 @@ class VectorContractRewriter {
       // Transpose ACC if doing signed by unsigned multiplication, because we're
       // using the instruction for unsigned by signed multiplication with
       // reversed operands.
-      if (mmlaOp == MMLA::MixedSwapped)
+      if (swapOperands)
         tiledAcc = rewriter.create<vector::TransposeOp>(
             loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
 
@@ -304,7 +310,7 @@ class VectorContractRewriter {
 
       // Because of the reversed operands the result is obtained transposed.
       // Transpose it back,
-      if (mmlaOp == MMLA::MixedSwapped)
+      if (swapOperands)
         tiledRes = rewriter.create<vector::TransposeOp>(
             loc, tiledRes, ArrayRef<int64_t>({1, 0}));
 
@@ -341,10 +347,10 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
     // values before the extension. All four signed/unsigned combinations for
     // input operands are supported, but they are lowered to different
     // operations. Determine which is the appropriate operation to lower to.
-    mmlaOp = MMLA::Signed;
+    mmlaOp = MMLA::SignedInt;
     auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
     if (!maybeLhs) {
-      mmlaOp = MMLA::Unsigned;
+      mmlaOp = MMLA::UnsignedInt;
       maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
     }
     if (!maybeLhs)
@@ -353,11 +359,13 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
 
     auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
     if (maybeRhs) {
-      if (mmlaOp == MMLA::Unsigned)
-        mmlaOp = MMLA::Mixed;
+      if (mmlaOp == MMLA::UnsignedInt)
+        mmlaOp = MMLA::MixedInt;
     } else {
-      if (mmlaOp == MMLA::Signed)
-        mmlaOp = MMLA::MixedSwapped;
+      if (mmlaOp == MMLA::SignedInt) {
+        mmlaOp = MMLA::MixedInt;
+        swapOperands = true;
+      }
       maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
     }
 
@@ -374,16 +382,17 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
     auto lhsExtInType = cast<VectorType>(lhs.getType());
     if (lhsExtInType.getElementTypeBitWidth() < 8)
       lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
-                                 /* signExt */ mmlaOp == MMLA::Signed ||
-                                     mmlaOp == MMLA::Mixed,
+                                 /* signExt */
+                                 (mmlaOp == MMLA::SignedInt ||
+                                  (mmlaOp == MMLA::MixedInt && !swapOperands)),
                                  rewriter);
 
     auto rhsExtInType = cast<VectorType>(rhs.getType());
     if (rhsExtInType.getElementTypeBitWidth() < 8)
-
       rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
-                                 /* signExt */ mmlaOp != MMLA::Unsigned &&
-                                     mmlaOp != MMLA::Mixed,
+                                 /* signExt */
+                                 (mmlaOp == MMLA::SignedInt ||
+                                  (mmlaOp == MMLA::MixedInt && swapOperands)),
                                  rewriter);
 
     // Initialize parameters for unrolling.
@@ -397,6 +406,47 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
   }
 };
 
+class VectorContractRewriterBFMMLA : public VectorContractRewriter {
+public:
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+
+    if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
+      return failure();
+
+    // Unrolling patterns can handle any [2, 2, 4] shaped multiple of inputs for
+    // tiling.
+    if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
+      return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
+
+    // Check the output is a vector of Float32 elements.
+    auto outTy = dyn_cast<VectorType>(op.getResultType());
+    if (!outTy || outTy.getElementType() != rewriter.getF32Type())
+      return rewriter.notifyMatchFailure(op,
+                                         "output type is not a vector of f32");
+
+    // Check the inputs are vectors of BFloat16 elements.
+    if (op.getLhsType().getElementType() != rewriter.getBF16Type())
+      return rewriter.notifyMatchFailure(op,
+                                         "input type is not a vector of bf16");
+
+    mmlaOp = MMLA::Bfloat;
+    swapOperands = false;
+    lhs = op.getLhs();
+    rhs = op.getRhs();
+    acc = op.getAcc();
+
+    // Initialize parameters for unrolling.
+    iterationBounds = *op.getShapeForUnroll();
+    if (iterationBounds.size() == 3)
+      subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 4});
+    else
+      subTileShape = SmallVector<int64_t>({2, 4});
+
+    return success();
+  }
+};
+
 /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
 /// any vector.contract into multiple smmla instructions with unrolling so long
 /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
@@ -418,10 +468,32 @@ class LowerContractionToNeonI8MMPattern
   }
 };
 
+class LowerContractionToNeonBFMMLAPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    VectorContractRewriterBFMMLA vcr;
+    if (failed(vcr.matchAndInit(op, rewriter)))
+      return failure();
+    vcr.rewrite(op, rewriter);
+
+    return success();
+  }
+};
+
 } // namespace
 
-void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
+void mlir::arm_neon::populateLowerContractionToNeonI8MMPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
   patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
 }
+
+void mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LowerContractionToNeonBFMMLAPattern>(context, /*benefit=*/2);
+}
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
index b7703ff0393eb..f7a9499e2db07 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -12,7 +12,7 @@
 // TODO: There may be opportunities to unify this with a similar pattern
 // for Neon. See:
 //   https://github.com/llvm/llvm-project/issues/145559
-//   LowerContractionToNeonI8MMPattern.cpp
+//   LowerContractToNeonPatterns.cpp
 //
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir b/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir
new file mode 100644
index 0000000000000..229c4e5b2dc3a
--- /dev/null
+++ b/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir
@@ -0,0 +1,225 @@
+// RUN:  mlir-opt %s --transform-interpreter | FileCheck %s
+
+// Test lowering of vector.contract to BFMMLA operations.
+// For each iteration [I, J, K] sub-tiles are extracted from offsets as follows:
+//   LHS: [2*I, 4*K]
+//   RHS: [2*J, 4*K]
+//   ACC: [2*I, 2*J]
+// Sub-tile insert offsets for the result are the same as for ACC (there are
+// redundant inserts).
+
+// CHECK-LABEL: func.func @vector_contract_to_bfmmla
+// CHECK-SAME:    %[[LHS:.+]]: vector<4x8xbf16>, %[[RHS:.+]]: vector<4x8xbf16>, %[[ACC:.+]]: vector<4x4xf32>
+
+// %[[INIT_RES:.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+
+// Iteration [0, 0, 0]
+// Extract sub-tiles from each of LHS, RHS and ACC
+// %[[T0:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T1:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T2:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+
+// Flatten the operands to fit the `bfmmla` operation types
+// %[[T3:.+]] = vector.shape_cast %[[T0]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T4:.+]] = vector.shape_cast %[[T1]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T5:.+]] = vector.shape_cast %[[T2]] : vector<2x2xf32> to vector<4xf32>
+
+// Perform the matrix multiply and accumulate
+// %[[K_ACC_0:.+]] = arm_neon.intr.bfmmla %[[T5]], %[[T3]], %[[T4]] : vector<8xbf16> to vector<4xf32>
+
+// Un-flatten the output sub-tile and insert into the result
+// %[[T7:.+]] = vector.shape_cast %[[K_ACC_0]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_0:.+]] = vector.insert_strided_slice %[[T7]], %[[INIT_RES]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 0, 1]
+// %[[T9:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T10:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T11:.+]] = vector.shape_cast %[[T9]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T12:.+]] = vector.shape_cast %[[T10]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T13:.+]] = arm_neon.intr.bfmmla %[[K_ACC_0]], %[[T11]], %[[T12]] : vector<8xbf16> to vector<4xf32>
+// %[[T14:.+]] = vector.shape_cast %[[T13]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_1:.+]] = vector.insert_strided_slice %[[T14]], %[[TMP_RES_0]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 1, 0]
+// %[[T16:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T17:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T18:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T19:.+]] = vector.shape_cast %[[T16]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T20:.+]] = vector.shape_cast %[[T17]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T21:.+]] = vector.shape_cast %[[T18]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_1:.+]] = arm_neon.intr.bfmmla %[[T21]], %[[T19]], %[[T20]] : vector<8xbf16> to vector<4xf32>
+// %[[T23:.+]] = vector.shape_cast %[[K_ACC_1]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_2:.+]] = vector.insert_strided_slice %[[T23]], %[[TMP_RES_1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 1, 1]
+// %[[T25:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T26:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T27:.+]] = vector.shape_cast %[[T25]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T28:.+]] = vector.shape_cast %[[T26]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T29:.+]] = arm_neon.intr.bfmmla %[[K_ACC_1]], %[[T27]], %[[T28]] : vector<8xbf16> to vector<4xf32>
+// %[[T30:.+]] = vector.shape_cast %[[T29]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_3:.+]] = vector.insert_strided_slice %[[T30]], %[[TMP_RES_2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 0, 0]
+// %[[T32:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T33:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T34:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T35:.+]] = vector.shape_cast %[[T32]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T36:.+]] = vector.shape_cast %[[T33]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T37:.+]] = vector.shape_cast %[[T34]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_2:.+]] = arm_neon.intr.bfmmla %[[T37]], %[[T35]], %[[T36]] : vector<8xbf16> to vector<4xf32>
+// %[[T39:.+]] = vector.shape_cast %[[K_ACC_2]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_4:.+]] = vector.insert_strided_slice %[[T39]], %[[TMP_RES_3]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 0, 1]
+// %[[T41:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T42:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T43:.+]] = vector.shape_cast %[[T41]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T44:.+]] = vector.shape_cast %[[T42]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T45:.+]] = arm_neon.intr.bfmmla %[[K_ACC_2]], %[[T43]], %[[T44]] : vector<8xbf16> to vector<4xf32>
+// %[[T46:.+]] = vector.shape_cast %[[T45]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_5:.+]] = vector.insert_strided_slice %[[T46]], %[[TMP_RES_4]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 1, 0]
+// %[[T48:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T49:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T50:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T51:.+]] = vector.shape_cast %[[T48]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T52:.+]] = vector.shape_cast %[[T49]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T53:.+]] = vector.shape_cast %[[T50]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_3:.+]] = arm_neon.intr.bfmmla %[[T53]], %[[T51]], %[[T52]] : vector<8xbf16> to vector<4xf32>
+// %[[T55:.+]] = vector.shape_cast %[[K_ACC_3]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_6:.+]] = vector.insert_strided_slice %[[T55]], %[[TMP_RES_5]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 1, 1]
+// %[[T57:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T58:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T59:.+]] = vector.shape_cast %[[T57]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T60:.+]] = vector.shape_cast %[[T58]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T61:.+]] = arm_neon.intr.bfmmla %[[K_ACC_3]], %[[T59]], %[[T60]] : vector<8xbf16> to vector<4xf32>
+// %[[T62:.+]] = vector.shape_cast %[[T61]] : vector<4xf32> to vector<2x2xf32>
+// %[[RESULT:.+]] = vector.insert_strided_slice %[[T62]], %[[TMP_RES_6]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// return %[[RESULT]] : vector<4x4xf32>
+
+func.func @vector_contract_to_bfmmla(%lhs: vector<4x8xbf16>,
+                                     %rhs: vector<4x8xbf16>,
+                                     %acc: vector<4x4xf32>) -> vector<4x4xf32> {
+  %0 = vector.contract { indexing_maps = [
+                          affine_map<(m, n, k) -> (m, k)>,
+                          affine_map<(m, n, k) -> (n, k)>,
+                          affine_map<(m, n, k) -> (m, n)>
+                        ],
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>
+                      }
+    %lhs, %rhs, %acc : vector<4x8xbf16>, vector<4x8xbf16> into vector<4x4xf32>
+
+  return %0 : vector<4x4xf32>
+}
+
+// Test lowering of vector.contract, representing vector by matrix multiply and
+// accumulate, to BFMMLA operations.
+
+// For each iteration [J, K] sub-tiles are extracted from offsets as follows:
+//   LHS: [4*K]
+//   RHS: [2*J, 4*K]
+//   ACC: [2*J]
+// Sub-tile insert offsets for the result are the same as for ACC (there are
+// redundant inserts).
+// CHECK-LABEL: func.func @vector_contract_vecmat_to_bfmmla
+// CHECK-SAME:   %[[LHS:.+]]: vector<8xbf16>, %[[RHS:.+]]: vector<4x8xbf16>, %[[ACC:.+]]: vector<4xf32>) -> vector<4xf32> {
+// CHECK: %[[ACC_PAD_Z:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[LHS_PAD_Z:.+]] = arith.constant dense<0.000000e+00> : vector<2x4xbf16>
+// CHECK: %[[RES_INIT:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+
+// Iteration [0, 0]
+// Extract sub-tiles
+// CHECK: %[[T0:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T1:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T2:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+
+// Pad LHS sub-tile/vector with an extra row of zeroes
+// CHECK: %[[T3:.+]] = vector.insert_strided_slice %[[T0]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+
+// Pad ACC sub-tile/vector with an extra row of zeroes
+// CHECK: %[[T4:.+]] = vector.insert_strided_slice %[[T2]], %[[ACC_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<2xf32> into vector<2x2xf32>
+
+// Flatten the operands to fit the `bfmmla` operation types
+// CHECK: %[[T5:.+]] = vector.shape_cast %[[T3]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T6:.+]] = vector.shape_cast %[[T1]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T7:.+]] = vector.shape_cast %[[T4]] : vector<2x2xf32> to vector<4xf32>
+
+// Perform the matrix multiply and accumulate
+// CHECK: %[[K_ACC_0:.+]] = arm_neon.intr.bfmmla %[[T7]], %[[T5]], %[[T6]] : vector<8xbf16> to vector<4xf32>
+
+// Un-flatten the output sub-tile
+// CHECK: %[[T9:.+]] = vector.shape_cast %[[K_ACC_0]] : vector<4xf32> to vector<2x2xf32>
+
+// Extract the first rows (the second row is padding) and insert into the result
+// CHECK: %[[T10:.+]] = vector.extract %[[T9]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_0:.+]] = vector.insert_strided_slice %[[T10]], %[[RES_INIT]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [0, 1]
+// CHECK: %[[T12:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T13:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T14:.+]] = vector.insert_strided_slice %[[T12]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T15:.+]] = vector.shape_cast %[[T14]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T16:.+]] = vector.shape_cast %[[T13]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T17:.+]] = arm_neon.intr.bfmmla %[[K_ACC_0]], %[[T15]], %[[T16]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T18:.+]] = vector.shape_cast %[[T17]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T19:.+]] = vector.extract %[[T18]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_1:.+]] = vector.insert_strided_slice %[[T19]], %[[TMP_RES_0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [1, 0]
+// CHECK: %[[T21:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T22:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T23:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: %[[T24:.+]] = vector.insert_strided_slice %[[T21]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T25:.+]] = vector.insert_strided_slice %[[T23]], %[[ACC_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<2xf32> into vector<2x2xf32>
+// CHECK: %[[T26:.+]] = vector.shape_cast %[[T24]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T27:.+]] = vector.shape_cast %[[T22]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T28:.+]] = vector.shape_cast %[[T25]] : vector<2x2xf32> to vector<4xf32>
+// CHECK: %[[K_ACC_1:.+]] = arm_neon.intr.bfmmla %[[T28]], %[[T26]], %[[T27]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T30:.+]] = vector.shape_cast %[[K_ACC_1]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T31:.+]] = vector.extract %[[T30]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_2:.+]] = vector.insert_strided_slice %[[T31]], %[[TMP_RES_1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [1, 1]
+// CHECK: %[[T33:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T34:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T35:.+]] = vector.insert_strided_slice %[[T33]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T36:.+]] = vector.shape_cast %[[T35]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T37:.+]] = vector.shape_cast %[[T34]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T38:.+]] = arm_neon.intr.bfmmla %[[K_ACC_1]], %[[T36]], %[[T37]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T39:.+]] = vector.shape_cast %[[T38]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T40:.+]] = vector.extract %[[T39]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[RESULT:.+]] = vector.insert_strided_slice %[[T40]], %[[TMP_RES_2]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: return %[[RESULT]] : vector<4xf32>
+func.func @vector_contract_vecmat_to_bfmmla(%lhs: vector<8xbf16>,
+                                            %rhs: vector<4x8xbf16>,
+                                            %acc: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.contract { indexing_maps = [
+                          affine_map<(n, k) -> (k)>,
+                          affine_map<(n, k) -> (n, k)>,
+                          affine_map<(n, k) -> (n)>
+                        ],
+                        iterator_types = ["parallel", "reduction"],
+                        kind = #vector.kind<add>
+                      }
+    %lhs, %rhs, %acc : vector<8xbf16>, vector<4x8xbf16> into vector<4xf32>
+
+  return %0 : vector<4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+    transform.apply_patterns to %func {
+      transform.apply_patterns.arm_neon.vector_contract_to_bfmmla
+    } : !transform.op<"func.func">
+
+    transform.yield
+  }
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
new file mode 100644
index 0000000000000..b62ae040f364b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
@@ -0,0 +1,176 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-neon enable-arm-bf16' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  \
+// DEFINE:   --lower-affine --convert-arith-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+bf16" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+
+//
+// Test the lowering of `vector.contract` using the `LowerContractionToNeonBFMMLAPattern`
+//
+// The operation that the `vector.contract` in this test performs is matrix
+// multiplication with accumulate
+//     OUT = ACC + LHS * RHS
+// of two BFloat16 matrices LHS and RHS, and a Float32 matrix ACC into a Float32 OUT.
+//
+// Tested are calculations as well as that the relevant `ArmNeon` dialect
+// operation (`arm_neon.intr.bfmmla`) is emitted.
+//
+// The pattern above handles (therefore this test prepares) input/output vectors with
+// specific shapes:
+//   * LHS:      vector<MxKxbf16>
+//   * RHS:      vector<NxKxbf16>
+//   * ACC, OUT: vector<MxNxf32>
+// where the M and N are even and K is divisible by 4.
+// Note that the RHS is transposed.
+// This data layout makes it efficient to load data into SIMD
+// registers in the layout expected by BFMMLA instruction.
+// Such a `vector.contract` is representative of the code we aim to generate
+// by vectorisation of `linalg.mmt4d`.
+//
+// In this specific test we use M == 4, N == 4, and K == 4.
+
+// CHECK-IR-LABEL: llvm.func @matrix_by_matrix_mul_and_acc
+// CHECK-IR-COUNT-4: arm_neon.intr.bfmmla
+func.func @matrix_by_matrix_mul_and_acc() {
+
+  %c0 = arith.constant 0 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %c0_bf16 = arith.constant 0.0 : bf16
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[ 0.7,  1.0, -0.1,  1.8],
+                                   [-0.5,  0.9,  0.7, -0.7],
+                                   [ 0.5, -1.3, -2.2,  0.1],
+                                   [-0.7,  1.0,  1.7, -1.0]]> : vector<4x4xf32>
+
+  %acc_mem = memref.alloc() : memref<4x4xf32>
+  vector.transfer_write %acc_cst, %acc_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xf32>, memref<4x4xf32>
+  %acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_f32 {in_bounds = [true, true]} : memref<4x4xf32>, vector<4x4xf32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[ 0.1,  0.7, -0.9,  1.3],
+                                   [-1.6,  0.7, -0.3, -0.3],
+                                   [-0.4,  0.6,  0.8, -0.5],
+                                   [-0.6, -1.0, -1.0, -1.0]]> : vector<4x4xbf16>
+
+  %lhs_mem = memref.alloc() : memref<4x4xbf16>
+  vector.transfer_write %lhs_cst, %lhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+  %lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[ 0.6,  1.3,  0.1, -0.9],
+                                   [ 0.5,  1.6,  1.8,  1.6],
+                                   [-0.2,  0.4,  1.0,  0.4],
+                                   [-1.3, -0.2, -2.2,  0.3]]> : vector<4x4xbf16>
+
+  %rhs_mem = memref.alloc() : memref<4x4xbf16>
+  vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+  %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+  // Matrix multiplication and accumulate with transposed RHS.
+  %0 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<4x4xbf16>, vector<4x4xbf16> into vector<4x4xf32>
+
+  // Display the result of the multiplication
+  vector.print str "Result(BFMMLA):\n"
+  %u0 = vector.extract %0[0] : vector<4xf32> from vector<4x4xf32>
+  %u1 = vector.extract %0[1] : vector<4xf32> from vector<4x4xf32>
+  %u2 = vector.extract %0[2] : vector<4xf32> from vector<4x4xf32>
+  %u3 = vector.extract %0[3] : vector<4xf32> from vector<4x4xf32>
+  vector.print %u0 : vector<4xf32>
+  vector.print %u1 : vector<4xf32>
+  vector.print %u2 : vector<4xf32>
+  vector.print %u3 : vector<4xf32>
+
+  return
+}
+
+// Test when the LHS is a one-dimensional vector.
+// 
+// In the vector by matrix case the shapes are as follows:
+//   * LHS:      vector<Kxbf16>
+//   * RHS:      vector<NxKxbf16>
+//   * ACC, OUT: vector<Nxf32>
+// N is even and K is divisible by 4.
+// In this specific test we use N == 4, and K == 4.
+
+// CHECK-IR-LABEL: llvm.func @vector_by_matrix_mul_and_acc
+// CHECK-IR-COUNT-2: arm_neon.intr.bfmmla
+func.func @vector_by_matrix_mul_and_acc() {
+  %c0 = arith.constant 0 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %c0_bf16 = arith.constant 0.0 : bf16
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[0.7,  1.0, -0.1,  1.8]> : vector<4xf32>
+
+  %acc_mem = memref.alloc() : memref<4xf32>
+  vector.transfer_write %acc_cst, %acc_mem[%c0] {in_bounds = [true] } : vector<4xf32>, memref<4xf32>
+  %acc = vector.transfer_read %acc_mem[%c0], %c0_f32 {in_bounds = [true]} : memref<4xf32>, vector<4xf32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[0.1,  0.7, -0.9,  1.3]> : vector<4xbf16>
+
+  %lhs_mem = memref.alloc() : memref<4xbf16>
+  vector.transfer_write %lhs_cst, %lhs_mem[%c0] {in_bounds = [true] } : vector<4xbf16>, memref<4xbf16>
+  %lhs = vector.transfer_read %lhs_mem[%c0], %c0_bf16 {in_bounds = [true]} : memref<4xbf16>, vector<4xbf16>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[ 0.6,  1.3,  0.1, -0.9],
+                                   [ 0.5,  1.6,  1.8,  1.6],
+                                   [-0.2,  0.4,  1.0,  0.4],
+                                   [-1.3, -0.2, -2.2,  0.3]]> : vector<4x4xbf16>
+
+  %rhs_mem = memref.alloc() : memref<4x4xbf16>
+  vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+  %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+  // Vector by matrix multiplication and accumulate with transposed RHS.
+  %0 = vector.contract { indexing_maps = [
+                           affine_map<(n, k) -> (k)>,
+                           affine_map<(n, k) -> (n, k)>,
+                           affine_map<(n, k) -> (n)>
+                         ],
+                         iterator_types = ["parallel", "reduction"],
+                         kind = #vector.kind<add>
+                       }
+    %lhs, %rhs, %acc : vector<4xbf16>, vector<4x4xbf16> into vector<4xf32>
+
+  // Display the result of the multiplication
+  vector.print str "Result(BFMMLA, vecmat):\n"
+  vector.print %0 : vector<4xf32>
+  
+  return
+}
+
+func.func @main() {
+  // CHECK-LABEL: Result(BFMMLA):
+  // CHECK: (  0.411922, 2.63254,  -0.219259,  3.89965 )
+  // CHECK: ( -0.316515, 0.196875,  0.879375,  1.80924 )
+  // CHECK: (  1.56867,  0.101367, -1.2784,   -1.41579 )
+  // CHECK: ( -1.56041, -4.30078,   0.0196488, 1.88269 )
+  func.call @matrix_by_matrix_mul_and_acc() : () -> ()
+
+  // CHECK-LABEL: Result(BFMMLA, vecmat):
+  // CHECK: ( 0.411922, 2.63254, -0.219259, 3.89965 )
+  func.call @vector_by_matrix_mul_and_acc() : () -> ()
+
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
index 1ce55ca05c90e..f6012bbd3d0b2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
@@ -240,7 +240,7 @@ func.func @test_usmmla() {
 
 // Test the operation where LHS is interpreted as signed and RHS is interpreted
 // as unsigned. In this test we ultimately emit end execute the `usmmla`
-// instruction with reversed operands, see `LowerContractionToNeonI8MMPattern.cpp`
+// instruction with reversed operands, see `LowerContractToNeonPatterns.cpp`
 // for more details.
 
 // CHECK-IR-LABEL: llvm.func @test_summla



More information about the Mlir-commits mailing list