[llvm-branch-commits] [mlir] [MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16 operations (PR #148198)
Momchil Velikov via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jul 22 07:23:07 PDT 2025
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/148198
From 14a83220d7aefdaa94bf771055fd398c273ec53b Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 11 Jul 2025 10:03:18 +0000
Subject: [PATCH 1/4] [MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16
operations
---
.../TransformOps/ArmNeonVectorTransformOps.td | 15 +-
.../include/mlir/Dialect/ArmNeon/Transforms.h | 4 +-
.../Dialect/ArmSVE/Transforms/Transforms.h | 3 +-
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 13 +-
.../ArmNeonVectorTransformOps.cpp | 7 +-
.../Dialect/ArmNeon/Transforms/CMakeLists.txt | 2 +-
...rn.cpp => LowerContractToNeonPatterns.cpp} | 126 +++++++---
.../TransformOps/ArmSVEVectorTransformOps.cpp | 2 +-
.../Transforms/LowerContractToSVEPatterns.cpp | 4 +-
mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir | 225 ++++++++++++++++++
.../CPU/ArmNeon/vector-contract-bfmmla.mlir | 176 ++++++++++++++
.../CPU/ArmNeon/vector-contract-i8mm.mlir | 2 +-
12 files changed, 535 insertions(+), 44 deletions(-)
rename mlir/lib/Dialect/ArmNeon/Transforms/{LowerContractionToNeonI8MMPattern.cpp => LowerContractToNeonPatterns.cpp} (81%)
create mode 100644 mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
index bcaca7da967fa..35747126d3db1 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
@@ -17,8 +17,19 @@ def ApplyArmNeonContractionToI8MMPatternsOp
"apply_patterns.arm_neon.vector_contract_to_i8mm",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Indicates that vector.contract operations should be lowered to
- finer-grained vector primitives from the ArmNeon dialect.
+ Indicates that vector contract operations should be lowered to
+ ArmNeon dialect operations mapping to instructions from FEAT_I8MM.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+def ApplyArmNeonContractionToBFMMLAPatternsOp
+ : Op<Transform_Dialect, "apply_patterns.arm_neon.vector_contract_to_bfmmla",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector contract operations should be lowered to
+ ArmNeon dialect operations mapping to instructions from FEAT_BF16.
}];
let assemblyFormat = "attr-dict";
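For context, the new op is driven the same way as the existing I8MM one; a
minimal transform sequence (mirroring the one in the vector-bfmmla.mlir test
added below) looks like:

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
      %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
      transform.apply_patterns to %func {
        transform.apply_patterns.arm_neon.vector_contract_to_bfmmla
      } : !transform.op<"func.func">
      transform.yield
    }
  }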
diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
index 2f0f634a96770..08065a3b25266 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
@@ -13,8 +13,8 @@ namespace mlir {
class RewritePatternSet;
namespace arm_neon {
-void populateLowerContractionToNeonI8MMPatternPatterns(
- RewritePatternSet &patterns);
+void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns);
+void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns);
} // namespace arm_neon
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index de160dbf8ed94..0019192a31a02 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,8 +20,7 @@ class RewritePatternSet;
void populateArmSVELegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
-void populateLowerContractionToSVEI8MMPatternPatterns(
- RewritePatternSet &patterns);
+void populateLowerContractionToSVEI8MMPatterns(RewritePatternSet &patterns);
void populateLowerContractionToSVEBFMMLAPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index d3d0a45eb2463..cf108690c3741 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -96,13 +96,16 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorGatherLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
- arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
+ arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
if (armSVE)
- populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+ populateLowerContractionToSVEI8MMPatterns(patterns);
+ }
+ if (armBF16) {
+ if (armNeon)
+ arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
+ if (armSVE)
+ populateLowerContractionToSVEBFMMLAPatterns(patterns);
}
- if (armBF16)
- populateLowerContractionToSVEBFMMLAPatterns(patterns);
-
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
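With this change the BFMMLA patterns become reachable from the conversion pass
whenever both the Neon and BF16 options are enabled, as the integration test
below does (`input.mlir` here is a placeholder file name):

  mlir-opt input.mlir --convert-vector-to-llvm='enable-arm-neon enable-arm-bf16'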
diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
index d07e6a52d8b5f..d069bde6d9979 100644
--- a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
+++ b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
@@ -20,7 +20,12 @@ using namespace mlir;
void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
+ arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
+}
+
+void transform::ApplyArmNeonContractionToBFMMLAPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
index 06bafde451cbb..368dacac7b835 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
add_mlir_dialect_library(MLIRArmNeonTransforms
- LowerContractionToNeonI8MMPattern.cpp
+ LowerContractToNeonPatterns.cpp
DEPENDS
MLIRArmNeonIncGen
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
similarity index 81%
rename from mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
rename to mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
index 59acb362191a7..5aadaece68732 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
@@ -1,4 +1,4 @@
-//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
+//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -93,15 +93,20 @@ class VectorContractRewriter {
// multiplications.
enum class MMLA {
Nop,
- Signed, // smmla
- Unsigned, // ummla
- Mixed, // usmmla
- MixedSwapped // usmmla with LHS and RHS swapped
+ SignedInt, // smmla
+ UnsignedInt, // ummla
+ MixedInt, // usmmla
+ Bfloat // bfmmla
};
// Lower-level operation to be emitted.
MMLA mmlaOp = MMLA::Nop;
+ // Indicates whether the operands of the ArmNeon dialect operation need to
+ // be swapped. Currently this is needed in order to emulate a "summla"
+ // operation.
+ bool swapOperands = false;
+
// The operand tiles. These are not necessarily the operands of
// `vector.contract`, for example they could be operands to `arith.extsi`
// that is in turn fed into `vector.contract`.
@@ -126,21 +131,22 @@ class VectorContractRewriter {
// Create the matrix multiply and accumulate operation according to `mmlaOp`.
Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
Value lhs, Value rhs) {
+
+ if (swapOperands)
+ std::swap(lhs, rhs);
switch (mmlaOp) {
- case MMLA::Signed:
+ case MMLA::SignedInt:
return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
lhs, rhs);
- case MMLA::Unsigned:
+ case MMLA::UnsignedInt:
return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
lhs, rhs);
- case MMLA::Mixed:
+ case MMLA::MixedInt:
return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
lhs, rhs);
- case MMLA::MixedSwapped:
- // The accumulator comes transposed and the result will be transposed
- // later, so all we have to do here is swap the operands.
- return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
- rhs, lhs);
+ case MMLA::Bfloat:
+ return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs,
+ rhs);
case MMLA::Nop:
llvm_unreachable("Uninitialized operation type");
}
@@ -273,7 +279,7 @@ class VectorContractRewriter {
// Transpose ACC if doing signed by unsigned multiplication, because we're
// using the instruction for unsigned by signed multiplication with
// reversed operands.
- if (mmlaOp == MMLA::MixedSwapped)
+ if (swapOperands)
tiledAcc = rewriter.create<vector::TransposeOp>(
loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
@@ -302,7 +308,7 @@ class VectorContractRewriter {
// Because of the reversed operands the result is obtained transposed.
// Transpose it back.
- if (mmlaOp == MMLA::MixedSwapped)
+ if (swapOperands)
tiledRes = rewriter.create<vector::TransposeOp>(
loc, tiledRes, ArrayRef<int64_t>({1, 0}));
@@ -339,10 +345,10 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
// values before the extension. All four signed/unsigned combinations for
// input operands are supported, but they are lowered to different
// operations. Determine which is the appropriate operation to lower to.
- mmlaOp = MMLA::Signed;
+ mmlaOp = MMLA::SignedInt;
auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
if (!maybeLhs) {
- mmlaOp = MMLA::Unsigned;
+ mmlaOp = MMLA::UnsignedInt;
maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
}
if (!maybeLhs)
@@ -351,11 +357,13 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
if (maybeRhs) {
- if (mmlaOp == MMLA::Unsigned)
- mmlaOp = MMLA::Mixed;
+ if (mmlaOp == MMLA::UnsignedInt)
+ mmlaOp = MMLA::MixedInt;
} else {
- if (mmlaOp == MMLA::Signed)
- mmlaOp = MMLA::MixedSwapped;
+ if (mmlaOp == MMLA::SignedInt) {
+ mmlaOp = MMLA::MixedInt;
+ swapOperands = true;
+ }
maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
}
@@ -372,16 +380,17 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
auto lhsExtInType = cast<VectorType>(lhs.getType());
if (lhsExtInType.getElementTypeBitWidth() < 8)
lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
- /* signExt */ mmlaOp == MMLA::Signed ||
- mmlaOp == MMLA::Mixed,
+ /* signExt */
+ (mmlaOp == MMLA::SignedInt ||
+ (mmlaOp == MMLA::MixedInt && !swapOperands)),
rewriter);
auto rhsExtInType = cast<VectorType>(rhs.getType());
if (rhsExtInType.getElementTypeBitWidth() < 8)
-
rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
- /* signExt */ mmlaOp != MMLA::Unsigned &&
- mmlaOp != MMLA::Mixed,
+ /* signExt */
+ (mmlaOp == MMLA::SignedInt ||
+ (mmlaOp == MMLA::MixedInt && swapOperands)),
rewriter);
// Initialize parameters for unrolling.
@@ -395,6 +404,47 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
}
};
+class VectorContractRewriterBFMMLA : public VectorContractRewriter {
+public:
+ LogicalResult matchAndInit(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+
+ if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
+ return failure();
+
+ // The unrolling patterns can handle any input shape that is a multiple of
+ // [2, 2, 4] for tiling.
+ if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
+ return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
+
+ // Check the output is a vector of Float32 elements.
+ auto outTy = dyn_cast<VectorType>(op.getResultType());
+ if (!outTy || outTy.getElementType() != rewriter.getF32Type())
+ return rewriter.notifyMatchFailure(op,
+ "output type is not a vector of f32");
+
+ // Check the inputs are vectors of BFloat16 elements.
+ if (op.getLhsType().getElementType() != rewriter.getBF16Type())
+ return rewriter.notifyMatchFailure(op,
+ "input type is not a vector of bf16");
+
+ mmlaOp = MMLA::Bfloat;
+ swapOperands = false;
+ lhs = op.getLhs();
+ rhs = op.getRhs();
+ acc = op.getAcc();
+
+ // Initialize parameters for unrolling.
+ iterationBounds = *op.getShapeForUnroll();
+ if (iterationBounds.size() == 3)
+ subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 4});
+ else
+ subTileShape = SmallVector<int64_t>({2, 4});
+
+ return success();
+ }
+};
+
/// Lowering from a vector::contractOp to an Arm Neon smmla intrinsic. This will tile
/// any vector.contract into multiple smmla instructions with unrolling so long
/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
@@ -416,10 +466,32 @@ class LowerContractionToNeonI8MMPattern
}
};
+class LowerContractionToNeonBFMMLAPattern
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+
+ VectorContractRewriterBFMMLA vcr;
+ if (failed(vcr.matchAndInit(op, rewriter)))
+ return failure();
+ vcr.lower(op, rewriter);
+
+ return success();
+ }
+};
+
} // namespace
-void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
+void mlir::arm_neon::populateLowerContractionToNeonI8MMPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
}
+
+void mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns(
+ RewritePatternSet &patterns) {
+ MLIRContext *context = patterns.getContext();
+ patterns.add<LowerContractionToNeonBFMMLAPattern>(context, /*benefit=*/2);
+}
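To make the unrolling concrete, here is a sketch (not part of the patch) of
the iteration space for the bf16 case, assuming an MxK LHS, an NxK
(transposed) RHS and an MxN accumulator, with M and N even and K divisible by
4 as required by matchAndInit above:

  // Each [2, 2, 4] sub-tile maps to one bfmmla; extraction offsets match the
  // comments in the vector-bfmmla.mlir test below.
  for (int64_t i = 0; i < dimM; i += 2)      // rows of LHS and ACC
    for (int64_t j = 0; j < dimN; j += 2)    // rows of RHS (RHS is transposed)
      for (int64_t k = 0; k < dimK; k += 4) {
        // acc[i..i+1][j..j+1] += lhs[i..i+1][k..k+3] * rhs[j..j+1][k..k+3]^T
      }

For dimM == dimN == 4 and dimK == 8 this yields 2 * 2 * 2 = 8 bfmmla
operations, which is what the first new test below checks.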
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
index 8572c34c8b12b..d355fe010821a 100644
--- a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
@@ -20,7 +20,7 @@ using namespace mlir;
void transform::ApplyArmSVELowerContractionToI8MMPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+ mlir::populateLowerContractionToSVEI8MMPatterns(patterns);
}
void transform::ApplyArmSVELowerContractionToBFMMLAPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
index f63eac91a38aa..ac1df3889ecfd 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
@@ -12,7 +12,7 @@
// TODO: There may be opportunities to unify this with a similar pattern
// for Neon. See:
// https://github.com/llvm/llvm-project/issues/145559
-// LowerContractionToNeonI8MMPattern.cpp
+// LowerContractToNeonPatterns.cpp
//
//===----------------------------------------------------------------------===//
@@ -580,7 +580,7 @@ class LowerContractionToSVEBFMMLAPattern
} // namespace
-void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
+void mlir::populateLowerContractionToSVEI8MMPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
diff --git a/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir b/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir
new file mode 100644
index 0000000000000..229c4e5b2dc3a
--- /dev/null
+++ b/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir
@@ -0,0 +1,225 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// Test lowering of vector.contract to BFMMLA operations.
+// For each iteration [I, J, K] sub-tiles are extracted from offsets as follows:
+// LHS: [2*I, 4*K]
+// RHS: [2*J, 4*K]
+// ACC: [2*I, 2*J]
+// Sub-tile insert offsets for the result are the same as for ACC (there are
+// redundant inserts).
+
+// CHECK-LABEL: func.func @vector_contract_to_bfmmla
+// CHECK-SAME: %[[LHS:.+]]: vector<4x8xbf16>, %[[RHS:.+]]: vector<4x8xbf16>, %[[ACC:.+]]: vector<4x4xf32>
+
+// %[[INIT_RES:.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+
+// Iteration [0, 0, 0]
+// Extract sub-tiles from each of LHS, RHS and ACC
+// %[[T0:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T1:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T2:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+
+// Flatten the operands to fit the `bfmmla` operation types
+// %[[T3:.+]] = vector.shape_cast %[[T0]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T4:.+]] = vector.shape_cast %[[T1]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T5:.+]] = vector.shape_cast %[[T2]] : vector<2x2xf32> to vector<4xf32>
+
+// Perform the matrix multiply and accumulate
+// %[[K_ACC_0:.+]] = arm_neon.intr.bfmmla %[[T5]], %[[T3]], %[[T4]] : vector<8xbf16> to vector<4xf32>
+
+// Un-flatten the output sub-tile and insert it into the result
+// %[[T7:.+]] = vector.shape_cast %[[K_ACC_0]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_0:.+]] = vector.insert_strided_slice %[[T7]], %[[INIT_RES]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 0, 1]
+// %[[T9:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T10:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T11:.+]] = vector.shape_cast %[[T9]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T12:.+]] = vector.shape_cast %[[T10]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T13:.+]] = arm_neon.intr.bfmmla %[[K_ACC_0]], %[[T11]], %[[T12]] : vector<8xbf16> to vector<4xf32>
+// %[[T14:.+]] = vector.shape_cast %[[T13]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_1:.+]] = vector.insert_strided_slice %[[T14]], %[[TMP_RES_0]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 1, 0]
+// %[[T16:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T17:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T18:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T19:.+]] = vector.shape_cast %[[T16]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T20:.+]] = vector.shape_cast %[[T17]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T21:.+]] = vector.shape_cast %[[T18]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_1:.+]] = arm_neon.intr.bfmmla %[[T21]], %[[T19]], %[[T20]] : vector<8xbf16> to vector<4xf32>
+// %[[T23:.+]] = vector.shape_cast %[[K_ACC_1]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_2:.+]] = vector.insert_strided_slice %[[T23]], %[[TMP_RES_1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 1, 1]
+// %[[T25:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T26:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T27:.+]] = vector.shape_cast %[[T25]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T28:.+]] = vector.shape_cast %[[T26]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T29:.+]] = arm_neon.intr.bfmmla %[[K_ACC_1]], %[[T27]], %[[T28]] : vector<8xbf16> to vector<4xf32>
+// %[[T30:.+]] = vector.shape_cast %[[T29]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_3:.+]] = vector.insert_strided_slice %[[T30]], %[[TMP_RES_2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 0, 0]
+// %[[T32:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T33:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T34:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T35:.+]] = vector.shape_cast %[[T32]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T36:.+]] = vector.shape_cast %[[T33]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T37:.+]] = vector.shape_cast %[[T34]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_2:.+]] = arm_neon.intr.bfmmla %[[T37]], %[[T35]], %[[T36]] : vector<8xbf16> to vector<4xf32>
+// %[[T39:.+]] = vector.shape_cast %[[K_ACC_2]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_4:.+]] = vector.insert_strided_slice %[[T39]], %[[TMP_RES_3]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 0, 1]
+// %[[T41:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T42:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T43:.+]] = vector.shape_cast %[[T41]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T44:.+]] = vector.shape_cast %[[T42]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T45:.+]] = arm_neon.intr.bfmmla %[[K_ACC_2]], %[[T43]], %[[T44]] : vector<8xbf16> to vector<4xf32>
+// %[[T46:.+]] = vector.shape_cast %[[T45]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_5:.+]] = vector.insert_strided_slice %[[T46]], %[[TMP_RES_4]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 1, 0]
+// %[[T48:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T49:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T50:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T51:.+]] = vector.shape_cast %[[T48]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T52:.+]] = vector.shape_cast %[[T49]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T53:.+]] = vector.shape_cast %[[T50]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_3:.+]] = arm_neon.intr.bfmmla %[[T53]], %[[T51]], %[[T52]] : vector<8xbf16> to vector<4xf32>
+// %[[T55:.+]] = vector.shape_cast %[[K_ACC_3]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_6:.+]] = vector.insert_strided_slice %[[T55]], %[[TMP_RES_5]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 1, 1]
+// %[[T57:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T58:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T59:.+]] = vector.shape_cast %[[T57]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T60:.+]] = vector.shape_cast %[[T58]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T61:.+]] = arm_neon.intr.bfmmla %[[K_ACC_3]], %[[T59]], %[[T60]] : vector<8xbf16> to vector<4xf32>
+// %[[T62:.+]] = vector.shape_cast %[[T61]] : vector<4xf32> to vector<2x2xf32>
+// %[[RESULT:.+]] = vector.insert_strided_slice %[[T62]], %[[TMP_RES_6]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// return %[[RESULT]] : vector<4x4xf32>
+
+func.func @vector_contract_to_bfmmla(%lhs: vector<4x8xbf16>,
+ %rhs: vector<4x8xbf16>,
+ %acc: vector<4x4xf32>) -> vector<4x4xf32> {
+ %0 = vector.contract { indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>
+ }
+ %lhs, %rhs, %acc : vector<4x8xbf16>, vector<4x8xbf16> into vector<4x4xf32>
+
+ return %0 : vector<4x4xf32>
+}
+
+// Test lowering of vector.contract, representing a vector-by-matrix multiply
+// and accumulate, to BFMMLA operations.
+
+// For each iteration [J, K] sub-tiles are extracted from offsets as follows:
+// LHS: [4*K]
+// RHS: [2*J, 4*K]
+// ACC: [2*J]
+// Sub-tile insert offsets for the result are the same as for ACC (there are
+// redundant inserts).
+// CHECK-LABEL: func.func @vector_contract_vecmat_to_bfmmla
+// CHECK-SAME: %[[LHS:.+]]: vector<8xbf16>, %[[RHS:.+]]: vector<4x8xbf16>, %[[ACC:.+]]: vector<4xf32>) -> vector<4xf32> {
+// CHECK: %[[ACC_PAD_Z:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[LHS_PAD_Z:.+]] = arith.constant dense<0.000000e+00> : vector<2x4xbf16>
+// CHECK: %[[RES_INIT:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+
+// Iteration [0, 0]
+// Extract sub-tiles
+// CHECK: %[[T0:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T1:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T2:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+
+// Pad LHS sub-tile/vector with an extra row of zeroes
+// CHECK: %[[T3:.+]] = vector.insert_strided_slice %[[T0]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+
+// Pad ACC sub-tile/vector with an extra row of zeroes
+// CHECK: %[[T4:.+]] = vector.insert_strided_slice %[[T2]], %[[ACC_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<2xf32> into vector<2x2xf32>
+
+// Flatten the operands to fit the `bfmmla` operation types
+// CHECK: %[[T5:.+]] = vector.shape_cast %[[T3]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T6:.+]] = vector.shape_cast %[[T1]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T7:.+]] = vector.shape_cast %[[T4]] : vector<2x2xf32> to vector<4xf32>
+
+// Perform the matrix multiply and accumulate
+// CHECK: %[[K_ACC_0:.+]] = arm_neon.intr.bfmmla %[[T7]], %[[T5]], %[[T6]] : vector<8xbf16> to vector<4xf32>
+
+// Un-flatten the output sub-tile
+// CHECK: %[[T9:.+]] = vector.shape_cast %[[K_ACC_0]] : vector<4xf32> to vector<2x2xf32>
+
+// Extract the first row (the second row is padding) and insert it into the result
+// CHECK: %[[T10:.+]] = vector.extract %[[T9]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_0:.+]] = vector.insert_strided_slice %[[T10]], %[[RES_INIT]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [0, 1]
+// CHECK: %[[T12:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T13:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T14:.+]] = vector.insert_strided_slice %[[T12]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T15:.+]] = vector.shape_cast %[[T14]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T16:.+]] = vector.shape_cast %[[T13]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T17:.+]] = arm_neon.intr.bfmmla %[[K_ACC_0]], %[[T15]], %[[T16]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T18:.+]] = vector.shape_cast %[[T17]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T19:.+]] = vector.extract %[[T18]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_1:.+]] = vector.insert_strided_slice %[[T19]], %[[TMP_RES_0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [1, 0]
+// CHECK: %[[T21:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T22:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T23:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: %[[T24:.+]] = vector.insert_strided_slice %[[T21]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T25:.+]] = vector.insert_strided_slice %[[T23]], %[[ACC_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<2xf32> into vector<2x2xf32>
+// CHECK: %[[T26:.+]] = vector.shape_cast %[[T24]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T27:.+]] = vector.shape_cast %[[T22]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T28:.+]] = vector.shape_cast %[[T25]] : vector<2x2xf32> to vector<4xf32>
+// CHECK: %[[K_ACC_1:.+]] = arm_neon.intr.bfmmla %[[T28]], %[[T26]], %[[T27]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T30:.+]] = vector.shape_cast %[[K_ACC_1]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T31:.+]] = vector.extract %[[T30]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_2:.+]] = vector.insert_strided_slice %[[T31]], %[[TMP_RES_1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [1, 1]
+// CHECK: %[[T33:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T34:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T35:.+]] = vector.insert_strided_slice %[[T33]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T36:.+]] = vector.shape_cast %[[T35]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T37:.+]] = vector.shape_cast %[[T34]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T38:.+]] = arm_neon.intr.bfmmla %[[K_ACC_1]], %[[T36]], %[[T37]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T39:.+]] = vector.shape_cast %[[T38]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T40:.+]] = vector.extract %[[T39]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[RESULT:.+]] = vector.insert_strided_slice %[[T40]], %[[TMP_RES_2]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: return %[[RESULT]] : vector<4xf32>
+func.func @vector_contract_vecmat_to_bfmmla(%lhs: vector<8xbf16>,
+ %rhs: vector<4x8xbf16>,
+ %acc: vector<4xf32>) -> vector<4xf32> {
+ %0 = vector.contract { indexing_maps = [
+ affine_map<(n, k) -> (k)>,
+ affine_map<(n, k) -> (n, k)>,
+ affine_map<(n, k) -> (n)>
+ ],
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<add>
+ }
+ %lhs, %rhs, %acc : vector<8xbf16>, vector<4x8xbf16> into vector<4xf32>
+
+ return %0 : vector<4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+ transform.apply_patterns to %func {
+ transform.apply_patterns.arm_neon.vector_contract_to_bfmmla
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
+}
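A note on the vecmat lowering checked above: bfmmla always computes a 2x2 f32
tile from two 2x4 bf16 tiles, so a 1-D LHS has no second row to supply. The
pattern therefore pads the LHS and ACC sub-tiles with a row of zeroes and
discards row 1 of every result tile; per sub-tile the computation is,
schematically,

  [ out ; 0 ] = [ lhs ; 0 ] * rhs_tile^T + [ acc ; 0 ]

with only row 0 inserted into the 1-D result.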
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
new file mode 100644
index 0000000000000..b62ae040f364b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
@@ -0,0 +1,176 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-neon enable-arm-bf16' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm \
+// DEFINE: --lower-affine --convert-arith-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+bf16" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+
+//
+// Test the lowering of `vector.contract` using the `LowerContractionToNeonBFMMLAPattern`
+//
+// The operation that the `vector.contract` in this test performs is matrix
+// multiplication with accumulate
+// OUT = ACC + LHS * RHS
+// of two BFloat16 matrices LHS and RHS, and a Float32 matrix ACC into a Float32 OUT.
+//
+// Tested are the computed values, as well as that the relevant `ArmNeon`
+// dialect operation (`arm_neon.intr.bfmmla`) is emitted.
+//
+// The pattern above handles (and therefore this test prepares) input/output vectors with
+// specific shapes:
+// * LHS: vector<MxKxbf16>
+// * RHS: vector<NxKxbf16>
+// * ACC, OUT: vector<MxNxf32>
+// where M and N are even and K is divisible by 4.
+// Note that the RHS is transposed.
+// This data layout makes it efficient to load data into SIMD
+// registers in the layout expected by the BFMMLA instruction.
+// Such a `vector.contract` is representative of the code we aim to generate
+// by vectorisation of `linalg.mmt4d`.
+//
+// In this specific test we use M == 4, N == 4, and K == 4.
+
+// CHECK-IR-LABEL: llvm.func @matrix_by_matrix_mul_and_acc
+// CHECK-IR-COUNT-4: arm_neon.intr.bfmmla
+func.func @matrix_by_matrix_mul_and_acc() {
+
+ %c0 = arith.constant 0 : index
+ %c0_f32 = arith.constant 0.0 : f32
+ %c0_bf16 = arith.constant 0.0 : bf16
+
+ // Accumulator test data
+ %acc_cst = arith.constant dense<[[ 0.7, 1.0, -0.1, 1.8],
+ [-0.5, 0.9, 0.7, -0.7],
+ [ 0.5, -1.3, -2.2, 0.1],
+ [-0.7, 1.0, 1.7, -1.0]]> : vector<4x4xf32>
+
+ %acc_mem = memref.alloc() : memref<4x4xf32>
+ vector.transfer_write %acc_cst, %acc_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xf32>, memref<4x4xf32>
+ %acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_f32 {in_bounds = [true, true]} : memref<4x4xf32>, vector<4x4xf32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[ 0.1, 0.7, -0.9, 1.3],
+ [-1.6, 0.7, -0.3, -0.3],
+ [-0.4, 0.6, 0.8, -0.5],
+ [-0.6, -1.0, -1.0, -1.0]]> : vector<4x4xbf16>
+
+ %lhs_mem = memref.alloc() : memref<4x4xbf16>
+ vector.transfer_write %lhs_cst, %lhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+ %lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[ 0.6, 1.3, 0.1, -0.9],
+ [ 0.5, 1.6, 1.8, 1.6],
+ [-0.2, 0.4, 1.0, 0.4],
+ [-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16>
+
+ %rhs_mem = memref.alloc() : memref<4x4xbf16>
+ vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+ %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+ // Matrix multiplication and accumulate with transposed RHS.
+ %0 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<4x4xbf16>, vector<4x4xbf16> into vector<4x4xf32>
+
+ // Display the result of the multiplication
+ vector.print str "Result(BFMMLA):\n"
+ %u0 = vector.extract %0[0] : vector<4xf32> from vector<4x4xf32>
+ %u1 = vector.extract %0[1] : vector<4xf32> from vector<4x4xf32>
+ %u2 = vector.extract %0[2] : vector<4xf32> from vector<4x4xf32>
+ %u3 = vector.extract %0[3] : vector<4xf32> from vector<4x4xf32>
+ vector.print %u0 : vector<4xf32>
+ vector.print %u1 : vector<4xf32>
+ vector.print %u2 : vector<4xf32>
+ vector.print %u3 : vector<4xf32>
+
+ return
+}
+
+// Test when the LHS is a one-dimensional vector.
+//
+// In the vector-by-matrix case the shapes are as follows:
+// * LHS: vector<Kxbf16>
+// * RHS: vector<NxKxbf16>
+// * ACC, OUT: vector<Nxf32>
+// N is even and K is divisible by 4.
+// In this specific test we use N == 4, and K == 4.
+
+// CHECK-IR-LABEL: llvm.func @vector_by_matrix_mul_and_acc
+// CHECK-IR-COUNT-2: arm_neon.intr.bfmmla
+func.func @vector_by_matrix_mul_and_acc() {
+ %c0 = arith.constant 0 : index
+ %c0_f32 = arith.constant 0.0 : f32
+ %c0_bf16 = arith.constant 0.0 : bf16
+
+ // Accumulator test data
+ %acc_cst = arith.constant dense<[0.7, 1.0, -0.1, 1.8]> : vector<4xf32>
+
+ %acc_mem = memref.alloc() : memref<4xf32>
+ vector.transfer_write %acc_cst, %acc_mem[%c0] {in_bounds = [true] } : vector<4xf32>, memref<4xf32>
+ %acc = vector.transfer_read %acc_mem[%c0], %c0_f32 {in_bounds = [true]} : memref<4xf32>, vector<4xf32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[0.1, 0.7, -0.9, 1.3]> : vector<4xbf16>
+
+ %lhs_mem = memref.alloc() : memref<4xbf16>
+ vector.transfer_write %lhs_cst, %lhs_mem[%c0] {in_bounds = [true] } : vector<4xbf16>, memref<4xbf16>
+ %lhs = vector.transfer_read %lhs_mem[%c0], %c0_bf16 {in_bounds = [true]} : memref<4xbf16>, vector<4xbf16>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[ 0.6, 1.3, 0.1, -0.9],
+ [ 0.5, 1.6, 1.8, 1.6],
+ [-0.2, 0.4, 1.0, 0.4],
+ [-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16>
+
+ %rhs_mem = memref.alloc() : memref<4x4xbf16>
+ vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+ %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+ // Vector by matrix multiplication and accumulate with transposed RHS.
+ %0 = vector.contract { indexing_maps = [
+ affine_map<(n, k) -> (k)>,
+ affine_map<(n, k) -> (n, k)>,
+ affine_map<(n, k) -> (n)>
+ ],
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<add>
+ }
+ %lhs, %rhs, %acc : vector<4xbf16>, vector<4x4xbf16> into vector<4xf32>
+
+ // Display the result of the multiplication
+ vector.print str "Result(BFMMLA, vecmat):\n"
+ vector.print %0 : vector<4xf32>
+
+ return
+}
+
+func.func @main() {
+ // CHECK-LABEL: Result(BFMMLA):
+ // CHECK: ( 0.411922, 2.63254, -0.219259, 3.89965 )
+ // CHECK: ( -0.316515, 0.196875, 0.879375, 1.80924 )
+ // CHECK: ( 1.56867, 0.101367, -1.2784, -1.41579 )
+ // CHECK: ( -1.56041, -4.30078, 0.0196488, 1.88269 )
+ func.call @matrix_by_matrix_mul_and_acc() : () -> ()
+
+ // CHECK-LABEL: Result(BFMMLA, vecmat):
+ // CHECK: ( 0.411922, 2.63254, -0.219259, 3.89965 )
+ func.call @vector_by_matrix_mul_and_acc() : () -> ()
+
+ return
+}
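As a sanity check of the expected values: with the transposed-RHS convention,
OUT[0][0] = ACC[0][0] + dot(LHS row 0, RHS row 0)
          = 0.7 + (0.1*0.6 + 0.7*1.3 + (-0.9)*0.1 + 1.3*(-0.9))
          = 0.7 - 0.29 = 0.41,
which agrees with the checked 0.411922 up to bf16 rounding of the inputs. The
vecmat test reuses LHS row 0, ACC row 0 and the same RHS, hence its result
equals the first row of the matrix test.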
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
index 1ce55ca05c90e..f6012bbd3d0b2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
@@ -240,7 +240,7 @@ func.func @test_usmmla() {
// Test the operation where LHS is interpreted as signed and RHS is interpreted
// as unsigned. In this test we ultimately emit and execute the `usmmla`
-// instruction with reversed operands, see `LowerContractionToNeonI8MMPattern.cpp`
+// instruction with reversed operands, see `LowerContractToNeonPatterns.cpp`
// for more details.
// CHECK-IR-LABEL: llvm.func @test_summla
From 6501a5b67a802a7ed0e13805156bbf8de19de25b Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 21 Jul 2025 11:16:02 +0000
Subject: [PATCH 2/4] [fixup] Rename a member function and change some allocs
to allocas
---
.../Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
index b62ae040f364b..9acc97da0d53c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
@@ -58,7 +58,7 @@ func.func @matrix_by_matrix_mul_and_acc() {
[ 0.5, -1.3, -2.2, 0.1],
[-0.7, 1.0, 1.7, -1.0]]> : vector<4x4xf32>
- %acc_mem = memref.alloc() : memref<4x4xf32>
+ %acc_mem = memref.alloca() : memref<4x4xf32>
vector.transfer_write %acc_cst, %acc_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xf32>, memref<4x4xf32>
%acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_f32 {in_bounds = [true, true]} : memref<4x4xf32>, vector<4x4xf32>
@@ -68,7 +68,7 @@ func.func @matrix_by_matrix_mul_and_acc() {
[-0.4, 0.6, 0.8, -0.5],
[-0.6, -1.0, -1.0, -1.0]]> : vector<4x4xbf16>
- %lhs_mem = memref.alloc() : memref<4x4xbf16>
+ %lhs_mem = memref.alloca() : memref<4x4xbf16>
vector.transfer_write %lhs_cst, %lhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
%lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
@@ -78,7 +78,7 @@ func.func @matrix_by_matrix_mul_and_acc() {
[-0.2, 0.4, 1.0, 0.4],
[-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16>
- %rhs_mem = memref.alloc() : memref<4x4xbf16>
+ %rhs_mem = memref.alloca() : memref<4x4xbf16>
vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
%rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
@@ -121,14 +121,14 @@ func.func @vector_by_matrix_mul_and_acc() {
// Accumulator test data
%acc_cst = arith.constant dense<[0.7, 1.0, -0.1, 1.8]> : vector<4xf32>
- %acc_mem = memref.alloc() : memref<4xf32>
+ %acc_mem = memref.alloca() : memref<4xf32>
vector.transfer_write %acc_cst, %acc_mem[%c0] {in_bounds = [true] } : vector<4xf32>, memref<4xf32>
%acc = vector.transfer_read %acc_mem[%c0], %c0_f32 {in_bounds = [true]} : memref<4xf32>, vector<4xf32>
// LHS test data
%lhs_cst = arith.constant dense<[0.1, 0.7, -0.9, 1.3]> : vector<4xbf16>
- %lhs_mem = memref.alloc() : memref<4xbf16>
+ %lhs_mem = memref.alloca() : memref<4xbf16>
vector.transfer_write %lhs_cst, %lhs_mem[%c0] {in_bounds = [true] } : vector<4xbf16>, memref<4xbf16>
%lhs = vector.transfer_read %lhs_mem[%c0], %c0_bf16 {in_bounds = [true]} : memref<4xbf16>, vector<4xbf16>
@@ -138,7 +138,7 @@ func.func @vector_by_matrix_mul_and_acc() {
[-0.2, 0.4, 1.0, 0.4],
[-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16>
- %rhs_mem = memref.alloc() : memref<4x4xbf16>
+ %rhs_mem = memref.alloca() : memref<4x4xbf16>
vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
%rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
From 5322c0f8149879235896e1fd0a15a2475ed14e1e Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 21 Jul 2025 13:01:35 +0000
Subject: [PATCH 3/4] [fixup] Add a comment about memory ops
---
.../Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
index 9acc97da0d53c..368f332e40602 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
@@ -110,6 +110,10 @@ func.func @matrix_by_matrix_mul_and_acc() {
// * ACC, OUT: vector<Nxf32>
// N is even and K is divisible by 4.
// In this specific test we use N == 4, and K == 4.
+// Note: the seemingly unnecessary writes of test vectors to memory are done
+// in order to introduce memory load operations on the path leading up to
+// `vector.contract` since that's more representative of the typical usage
+// when multiplying big, non-constant tensors.
// CHECK-IR-LABEL: llvm.func @vector_by_matrix_mul_and_acc
// CHECK-IR-COUNT-2: arm_neon.intr.bfmmla
From a95561a0c493e7e12ea0b22064da490623c02a23 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 21 Jul 2025 13:10:58 +0000
Subject: [PATCH 4/4] [fixup] Move a comment, it was accidentally in the wrong
place
---
.../Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
index 368f332e40602..4285260906251 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
@@ -44,6 +44,12 @@
//
// In this specific test we use M == 4, N == 4, and K == 4.
+// Note: In this and in the following test the seemingly unnecessary
+// writes of test vectors to memory are done in order to introduce memory
+// load operations on the path leading up to `vector.contract` since
+// that's more representative of the typical usage when multiplying
+// big, non-constant tensors.
+
// CHECK-IR-LABEL: llvm.func @matrix_by_matrix_mul_and_acc
// CHECK-IR-COUNT-4: arm_neon.intr.bfmmla
func.func @matrix_by_matrix_mul_and_acc() {
@@ -110,10 +116,6 @@ func.func @matrix_by_matrix_mul_and_acc() {
// * ACC, OUT: vector<Nxf32>
// N is even and K is divisible by 4.
// In this specific test we use N == 4, and K == 4.
-// Note: the seemingly unnecessary writes of test vectors to memory are done
-// in order to introduce memory load operations on the path leading up to
-// `vector.contract` since that's more representative of the typical usage
-// when multiplying big, non-constant tensors.
// CHECK-IR-LABEL: llvm.func @vector_by_matrix_mul_and_acc
// CHECK-IR-COUNT-2: arm_neon.intr.bfmmla