[Mlir-commits] [mlir] [MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16 operations (PR #148198)
Momchil Velikov
llvmlistbot at llvm.org
Fri Jul 11 03:36:58 PDT 2025
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/148198
This is split into two commits:
* refactor I8MM lowering to make it easier to add ...
* ... BF16 lowering
From 5ec628726772f32ad2cb4ccfbbf0a43532562fca Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 10 Jul 2025 11:10:19 +0000
Subject: [PATCH 1/2] [MLIR][AArch64] Refactor lowering of vector.contract to
Neon I8MM
This patch refactors the pattern in `Transforms/LowerContractionToNeonI8MMPattern.cpp`,
using a similar approach to the one in https://github.com/llvm/llvm-project/pull/147052,
to prepare for adding BF16 support.
---
.../LowerContractionToNeonI8MMPattern.cpp | 431 ++++++++++--------
1 file changed, 247 insertions(+), 184 deletions(-)
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
index 7180884c77e98..f961d3cafb443 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
@@ -33,23 +33,15 @@ using namespace mlir;
using namespace mlir::arm_neon;
namespace {
-
-/// Return the shaped type with new element type.
-static Type matchContainerType(Type element, Type container) {
- if (auto shapedTy = dyn_cast<ShapedType>(container)) {
- return shapedTy.clone(element);
- }
- return element;
-}
-
-// Get the operand of a `vector.contract`. This function is intended to abstract
-// away from the particular way a value is extended before feeding it into the
-// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
-// (for implicit sign-extension see `vector.contract` documentation).
-//
-// The template parameter `Op` indicates the extension operation (explicit or
-// implicit) for which we are checking.
-//
+/// Get the operand of a `vector.contract`. This function is intended to
+/// abstract away from the particular way a value is extended before feeding it
+/// into the `vector.contract` - via zero-extend or an explicit or implicit
+/// sign-extend (for implicit sign-extension see `vector.contract`
+/// documentation).
+///
+/// The template parameter `Op` indicates the extension operation (explicit or
+/// implicit) for which we are checking.
+///
// Return success only for extensions from `iN` (N <= 8) to `i32`.
template <typename Op>
std::optional<Value> getExtOperand(Value v) {
@@ -87,202 +79,186 @@ std::optional<Value> getExtOperand(Value v) {
return inOp;
}
-// Designate the operation (resp. instruction) used to do sub-tile matrix
-// multiplications.
-enum class MMLA {
- Signed, // smmla
- Unsigned, // ummla
- Mixed, // usmmla
- MixedSwapped // usmmla with LHS and RHS swapped
-};
+/// Helper function to extend a vector with iN (N < 8) elements to a vector of
+/// i8. Performs sign extension if the parameter `signExt` is true, zero
+/// extension otherwise.
+Value extendSmallIntVector(Location loc, VectorType srcTy, Value val,
+ bool signExt, PatternRewriter &rewriter) {
+ Type targetTy = srcTy.clone(rewriter.getI8Type());
+ return signExt ? rewriter.createOrFold<arith::ExtSIOp>(loc, targetTy, val)
+ : rewriter.createOrFold<arith::ExtUIOp>(loc, targetTy, val);
+}
-// Create the matrix mulitply and accumulate operation according to `op`.
-Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
- mlir::Type accType, Value acc, Value lhs, Value rhs) {
- switch (op) {
- case MMLA::Signed:
- return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
- rhs);
- case MMLA::Unsigned:
- return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
- rhs);
- case MMLA::Mixed:
- return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
- rhs);
- case MMLA::MixedSwapped:
- // The accumulator comes transposed and the result will be transposed
- // later, so all we have to do here is swap the operands.
- return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
- lhs);
+class VectorContractRewriter {
+protected:
+ // Designate the operation (resp. instruction) used to do sub-tile matrix
+ // multiplications.
+ enum class MMLA {
+ Nop,
+ Signed, // smmla
+ Unsigned, // ummla
+ Mixed, // usmmla
+ MixedSwapped // usmmla with LHS and RHS swapped
+ };
+
+ // Lower-level operation to be emitted.
+ MMLA mmlaOp = MMLA::Nop;
+
+ // The operand tiles. These are not necessarily the operands of
+ // `vector.contract`, for example they could be operands to `arith.extsi`
+ // that is in turn fed into `vector.contract`.
+ Value lhs;
+ Value rhs;
+ Value acc;
+
+ // The dimensions logically corresponding to matrix multiplication of
+ // MxK * KxN -> MxN. The operands and the result do not necessarily have these
+ // shapes, for example RHS could be NxK with a transposing indexing map.
+ int64_t dimM = 0;
+ int64_t dimN = 0;
+ int64_t dimK = 0;
+
+  // Unroll iteration bounds. See documentation for `StaticTileOffsetRange`.
+ SmallVector<int64_t> iterationBounds;
+
+  // Sub-tile shape. The algorithm handles operand shapes that are multiples
+ // of this shape.
+ SmallVector<int64_t> subTileShape;
+
+ // Create the matrix multiply and accumulate operation according to `mmlaOp`.
+ Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+ Value lhs, Value rhs) {
+ switch (mmlaOp) {
+ case MMLA::Signed:
+ return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::Unsigned:
+ return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::Mixed:
+ return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::MixedSwapped:
+ // The accumulator comes transposed and the result will be transposed
+ // later, so all we have to do here is swap the operands.
+ return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+ rhs, lhs);
+ case MMLA::Nop:
+ llvm_unreachable("Uninitialized operation type");
+ }
}
-}
-/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
-/// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
-/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
-/// necessary, a single smmla instruction is emitted.
-class LowerContractionToNeonI8MMPattern
- : public OpRewritePattern<vector::ContractionOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
- // Note: RHS is not transposed.
- mlir::VectorType lhsType = op.getLhsType();
- mlir::VectorType rhsType = op.getRhsType();
+ // Check common preconditions for applying the patterns and initialize
+ // logical dimensions.
+ LogicalResult matchAndInit(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ // Check iterator types for matrix multiplication.
+ SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
+ if (!((itTypes.size() == 3 &&
+ (itTypes[0] == vector::IteratorType::parallel &&
+ itTypes[1] == vector::IteratorType::parallel &&
+ itTypes[2] == vector::IteratorType::reduction)) ||
+ (itTypes.size() == 2 &&
+ (itTypes[0] == vector::IteratorType::parallel &&
+ itTypes[1] == vector::IteratorType::reduction))))
+ return rewriter.notifyMatchFailure(
+ op, "iterator types do not correspond to matrix multiplication");
+
// Avoid 0-D vectors and 1-D rhs:
- if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
- return failure();
+ VectorType lhsType = op.getLhsType();
+ VectorType rhsType = op.getRhsType();
+ if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
+ rhsType.getRank() != 2)
+ return rewriter.notifyMatchFailure(op, "Invalid operand rank");
+
// This codegen does not work for scalable vectors. Return failure so this
// pattern is not accidentally chosen over patterns that lower to ArmSVE.
if (lhsType.isScalable() || rhsType.isScalable())
- return failure();
- auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
- auto dimN = rhsType.getDimSize(0);
- auto dimK = rhsType.getDimSize(1);
- bool isVecmat = dimM == 1 ? true : false;
- if (lhsType.getDimSize(lhsType.getRank() - 1) !=
- rhsType.getDimSize(rhsType.getRank() - 1)) {
- return failure(); // dimK mismatch
- }
- // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
- // tiling.
- if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
- return failure();
- }
-
- // Check iterator types for contract. All iterators except inner-most
- // dimension must be parallel.
- auto iteratorTypes = op.getIteratorTypesArray();
- if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
- vector::IteratorType::reduction) {
- return failure();
- }
- if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
- [](vector::IteratorType iteratorType) {
- return iteratorType != vector::IteratorType::parallel;
- })) {
- return failure();
+ return rewriter.notifyMatchFailure(op,
+ "Not applicable to scalable vectors");
+
+ // Initialize dimensions and check for a matching K dimension.
+ dimM = lhsType.getDimSize(0);
+ dimN = rhsType.getDimSize(0);
+ dimK = rhsType.getDimSize(1);
+
+ int64_t lhsDimK;
+ if (lhsType.getRank() == 1) {
+ dimM = 1;
+ lhsDimK = lhsType.getDimSize(0);
+ } else {
+ lhsDimK = lhsType.getDimSize(1);
}
- // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
- // values before the extension. All four signed/unsigned combinations for
- // input operands are supported, but they are lowered to different
- // operations. Determine which is the appropriate operation to lower to.
- MMLA mmlaOp = MMLA::Signed;
- auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
- if (!maybeLhs) {
- mmlaOp = MMLA::Unsigned;
- maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
- }
- if (!maybeLhs)
- return failure();
+ if (lhsDimK != dimK)
+ return rewriter.notifyMatchFailure(op, "Dimensions mismatch");
- auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
- if (maybeRhs) {
- if (mmlaOp == MMLA::Unsigned)
- mmlaOp = MMLA::Mixed;
- } else {
- if (mmlaOp == MMLA::Signed)
- mmlaOp = MMLA::MixedSwapped;
- maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
- }
- if (!maybeRhs)
- return failure();
+ return success();
+ }
- Value origLhs = *maybeLhs;
- Value origRhs = *maybeRhs;
-
- // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
- // following neon instruction. Check inputs for extsi are <=i8
- Value extLhs;
- Value extRhs;
- if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
- if (lhsExtInType.getElementTypeBitWidth() <= 8) {
- Type targetLhsExtTy =
- matchContainerType(rewriter.getI8Type(), lhsExtInType);
- if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
- extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
- origLhs);
- else
- extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
- origLhs);
- }
- }
- if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
- if (rhsExtInType.getElementTypeBitWidth() <= 8) {
- Type targetRhsExtTy =
- matchContainerType(rewriter.getI8Type(), rhsExtInType);
- if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
- extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
- origRhs);
- else
- extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
- origRhs);
- }
- }
+public:
+ void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
+ // Create some convenience types.
+ auto inputElementType = cast<ShapedType>(lhs.getType()).getElementType();
+ auto accElementType = cast<ShapedType>(acc.getType()).getElementType();
+ auto inputExpandedType =
+ VectorType::get({2, subTileShape.back()}, inputElementType);
+ auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+
+ // One-dimensional representation of logical sub-tiles as required by the
+ // ArmNeon ops.
+ auto collapsedInputType =
+ VectorType::get(inputExpandedType.getNumElements(), inputElementType);
+ auto collapsedOutputType =
+ VectorType::get(outputExpandedType.getNumElements(), accElementType);
+
+    // Get indexing maps for more concise/convenient access.
+ auto indexingMaps = op.getIndexingMapsArray();
+ AffineMap &lhsPermutationMap = indexingMaps[0];
+ AffineMap &rhsPermutationMap = indexingMaps[1];
+ AffineMap &accPermutationMap = indexingMaps[2];
- if (!extLhs || !extRhs) {
- return failure();
- }
+ Location loc = op.getLoc();
// Initial accumulator for the final result. This is the un-tiled result if
// tiling is done.
Value result = rewriter.create<arith::ConstantOp>(
loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
- SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
- SmallVector<int64_t> smmlaShape = {2, 8};
- SmallVector<int64_t> loopOrder = {0, 1};
- if (unrolledSize.size() == 3) {
- smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
+ SmallVector<int64_t, 3> loopOrder = {0, 1};
+ if (iterationBounds.size() == 3)
loopOrder.push_back(2);
- }
// Keep track of the previous accumulator when tiling over K.
Value kAcc;
for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+ StaticTileOffsetRange(iterationBounds, subTileShape, loopOrder)) {
// Helper to compute the new shape of each operand and extract the slice.
auto extractOperand = [&](Value operand, AffineMap permutationMap,
ArrayRef<int64_t> operandOffsets) {
- SmallVector<int64_t> operandShape =
- applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
+ SmallVector<int64_t> operandShape = applyPermutationMap(
+ permutationMap, ArrayRef<int64_t>(subTileShape));
SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, operand, operandOffsets, operandShape, operandStrides);
};
// Extract tiled lhs, rhs, and acc
- AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
SmallVector<int64_t> lhsOffsets =
applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
- Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets);
- AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
+ Value tiledLhs = extractOperand(lhs, lhsPermutationMap, lhsOffsets);
SmallVector<int64_t> rhsOffsets =
applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
- Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets);
- AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
+ Value tiledRhs = extractOperand(rhs, rhsPermutationMap, rhsOffsets);
SmallVector<int64_t> accOffsets =
applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
- Value tiledAcc =
- extractOperand(op.getAcc(), accPermutationMap, accOffsets);
-
- auto inputElementType =
- cast<ShapedType>(tiledLhs.getType()).getElementType();
- auto accElementType =
- cast<ShapedType>(tiledAcc.getType()).getElementType();
- auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
- auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+ Value tiledAcc = extractOperand(acc, accPermutationMap, accOffsets);
// With vecmat, tiled LHS and ACC will contain only one of 2 necessary
- // rows along dimM. Expand their shapes to match the smmla op.
- if (isVecmat) {
- auto expandForSMMLA = [&](Value tiledOperand,
- VectorType expandedTypeType) {
+ // rows along dimM. Expand their shapes to match the ArmNeon op.
+ if (dimM == 1) {
+ auto expandRowVector = [&](Value tiledOperand,
+ VectorType expandedTypeType) {
auto emptyOperand = rewriter.create<arith::ConstantOp>(
loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
SmallVector<int64_t> offsets(
@@ -292,8 +268,8 @@ class LowerContractionToNeonI8MMPattern
return rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, tiledOperand, emptyOperand, offsets, strides);
};
- tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
- tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
+ tiledLhs = expandRowVector(tiledLhs, inputExpandedType);
+ tiledAcc = expandRowVector(tiledAcc, outputExpandedType);
}
// Transpose ACC if doing signed by unsigned multiplication, because we're
@@ -303,15 +279,11 @@ class LowerContractionToNeonI8MMPattern
tiledAcc = rewriter.create<vector::TransposeOp>(
loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
- // Collapse tiled operands to 1D vectors required by smmla intrinsic
- auto collapsedInputType =
- VectorType::get(inputExpandedType.getNumElements(), inputElementType);
+ // Collapse tiled operands to 1D vectors required by the ArmNeon ops
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledLhs.getLoc(), collapsedInputType, tiledLhs);
auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledRhs.getLoc(), collapsedInputType, tiledRhs);
- auto collapsedOutputType =
- VectorType::get(outputExpandedType.getNumElements(), accElementType);
bool initialKAcc = offsets.back() == 0;
Value collapsedRes;
@@ -323,8 +295,8 @@ class LowerContractionToNeonI8MMPattern
}
// Insert contract op
- kAcc = createMMLA(rewriter, mmlaOp, op.getLoc(), collapsedRes.getType(),
- collapsedRes, collapsedLhs, collapsedRhs);
+ kAcc =
+ createMMLA(rewriter, loc, collapsedRes, collapsedLhs, collapsedRhs);
// Reshape output back to 2D
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
@@ -338,9 +310,8 @@ class LowerContractionToNeonI8MMPattern
// With vecmat, only one row of tiled ACC can be inserted into the final
// result
- if (isVecmat) {
+ if (dimM == 1)
tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
- }
// Insert the tiled result back into the non tiled result of the
// contract op.
@@ -351,6 +322,98 @@ class LowerContractionToNeonI8MMPattern
}
rewriter.replaceOp(op, result);
+ }
+};
+
+class VectorContractRewriterI8MM : public VectorContractRewriter {
+public:
+ LogicalResult matchAndInit(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
+ return failure();
+
+    // The unrolling patterns can handle any operand shapes that are multiples
+    // of [2, 2, 8].
+ if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0)
+ return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
+
+ // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
+ // values before the extension. All four signed/unsigned combinations for
+ // input operands are supported, but they are lowered to different
+ // operations. Determine which is the appropriate operation to lower to.
+ mmlaOp = MMLA::Signed;
+ auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+ if (!maybeLhs) {
+ mmlaOp = MMLA::Unsigned;
+ maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+ }
+ if (!maybeLhs)
+ return rewriter.notifyMatchFailure(
+ op, "LHS is not a sign- or zero- extended iN, N <= 8");
+
+ auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
+ if (maybeRhs) {
+ if (mmlaOp == MMLA::Unsigned)
+ mmlaOp = MMLA::Mixed;
+ } else {
+ if (mmlaOp == MMLA::Signed)
+ mmlaOp = MMLA::MixedSwapped;
+ maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
+ }
+
+ if (!maybeRhs)
+ return rewriter.notifyMatchFailure(
+ op, "RHS is not a sign- or zero- extended iN, N <= 8");
+
+ lhs = *maybeLhs;
+ rhs = *maybeRhs;
+ acc = op.getAcc();
+
+ // Extend inputs from iN, N < 8 to i8.
+ Location loc = op.getLoc();
+ auto lhsExtInType = cast<VectorType>(lhs.getType());
+ if (lhsExtInType.getElementTypeBitWidth() < 8)
+ lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
+ /* signExt */ mmlaOp == MMLA::Signed ||
+ mmlaOp == MMLA::Mixed,
+ rewriter);
+
+ auto rhsExtInType = cast<VectorType>(rhs.getType());
+ if (rhsExtInType.getElementTypeBitWidth() < 8)
+
+ rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
+ /* signExt */ mmlaOp != MMLA::Unsigned &&
+ mmlaOp != MMLA::Mixed,
+ rewriter);
+
+ // Initialize parameters for unrolling.
+ iterationBounds = *op.getShapeForUnroll();
+ if (iterationBounds.size() == 3)
+ subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 8});
+ else
+ subTileShape = SmallVector<int64_t>({2, 8});
+
+ return success();
+ }
+};
+
+/// Lowering from a vector::ContractionOp to the Arm Neon smmla intrinsic. This
+/// will tile any vector.contract into multiple smmla instructions with unrolling
+/// so long as [2, 2, 8] is a divisor of its shape. It can also process vecmats
+/// with dimM = 1 (either explicitly or inferred if LHS has only dimK). If no
+/// unrolling is necessary, a single smmla instruction is emitted.
+class LowerContractionToNeonI8MMPattern
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+
+ VectorContractRewriterI8MM vcr;
+ if (failed(vcr.matchAndInit(op, rewriter)))
+ return failure();
+ vcr.rewrite(op, rewriter);
+
return success();
}
};
From 442e29ac0b3ad54aa521ff02029889fa899a9c75 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 11 Jul 2025 10:03:18 +0000
Subject: [PATCH 2/2] [MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16
operations
---
mlir/include/mlir/Conversion/Passes.td | 4 +
.../TransformOps/ArmNeonVectorTransformOps.td | 15 +-
.../include/mlir/Dialect/ArmNeon/Transforms.h | 4 +-
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 4 +-
.../ArmNeonVectorTransformOps.cpp | 7 +-
.../Dialect/ArmNeon/Transforms/CMakeLists.txt | 2 +-
...rn.cpp => LowerContractToNeonPatterns.cpp} | 126 +++++++---
.../LowerContractionToSVEI8MMPattern.cpp | 2 +-
mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir | 225 ++++++++++++++++++
.../CPU/ArmNeon/vector-contract-bfmmla.mlir | 176 ++++++++++++++
.../CPU/ArmNeon/vector-contract-i8mm.mlir | 2 +-
11 files changed, 531 insertions(+), 36 deletions(-)
rename mlir/lib/Dialect/ArmNeon/Transforms/{LowerContractionToNeonI8MMPattern.cpp => LowerContractToNeonPatterns.cpp} (81%)
create mode 100644 mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5a864865adffc..4f304b39a0528 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1437,6 +1437,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of Arm FEAT_I8MM instructions while lowering "
"the vector dialect.">,
+ Option<"armBF16", "enable-arm-bf16",
+ "bool", /*default=*/"false",
+ "Enables the use of Arm FEAT_BF16 instructions while lowering "
+ "the vector dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
index bcaca7da967fa..35747126d3db1 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
@@ -17,8 +17,19 @@ def ApplyArmNeonContractionToI8MMPatternsOp
"apply_patterns.arm_neon.vector_contract_to_i8mm",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Indicates that vector.contract operations should be lowered to
- finer-grained vector primitives from the ArmNeon dialect.
+    Indicates that vector contract operations should be lowered to
+    ArmNeon dialect operations mapping to instructions from FEAT_I8MM.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+def ApplyArmNeonContractionToBFMMLAPatternsOp
+ : Op<Transform_Dialect, "apply_patterns.arm_neon.vector_contract_to_bfmmla",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+    Indicates that vector contract operations should be lowered to
+    ArmNeon dialect operations mapping to instructions from FEAT_BF16.
}];
let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
index 2f0f634a96770..08065a3b25266 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
@@ -13,8 +13,8 @@ namespace mlir {
class RewritePatternSet;
namespace arm_neon {
-void populateLowerContractionToNeonI8MMPatternPatterns(
- RewritePatternSet &patterns);
+void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns);
+void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns);
} // namespace arm_neon
} // namespace mlir
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 549d0210af7ad..1045824c437ab 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -84,10 +84,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorGatherLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
- arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
+ arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
if (armSVE)
populateLowerContractionToSVEI8MMPatternPatterns(patterns);
}
+ if (armBF16 && armNeon)
+ arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
index d07e6a52d8b5f..d069bde6d9979 100644
--- a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
+++ b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
@@ -20,7 +20,12 @@ using namespace mlir;
void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
+ arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
+}
+
+void transform::ApplyArmNeonContractionToBFMMLAPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
index 06bafde451cbb..368dacac7b835 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
add_mlir_dialect_library(MLIRArmNeonTransforms
- LowerContractionToNeonI8MMPattern.cpp
+ LowerContractToNeonPatterns.cpp
DEPENDS
MLIRArmNeonIncGen
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
similarity index 81%
rename from mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
rename to mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
index f961d3cafb443..06746daa8075b 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
@@ -1,4 +1,4 @@
-//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
+//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -95,15 +95,20 @@ class VectorContractRewriter {
// multiplications.
enum class MMLA {
Nop,
- Signed, // smmla
- Unsigned, // ummla
- Mixed, // usmmla
- MixedSwapped // usmmla with LHS and RHS swapped
+ SignedInt, // smmla
+ UnsignedInt, // ummla
+ MixedInt, // usmmla
+ Bfloat // bfmmla
};
// Lower-level operation to be emitted.
MMLA mmlaOp = MMLA::Nop;
+ // Indicate if the operands for the ArmNeon dialect operation need to be
+ // swapped. Currently this is needed in order to emulate an "summla"
+ // operation.
+ bool swapOperands = false;
+
// The operand tiles. These are not necessarily the operands of
// `vector.contract`, for example they could be operands to `arith.extsi`
// that is in turn fed into `vector.contract`.
@@ -128,21 +133,22 @@ class VectorContractRewriter {
// Create the matrix multiply and accumulate operation according to `mmlaOp`.
Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
Value lhs, Value rhs) {
+
+ if (swapOperands)
+ std::swap(lhs, rhs);
switch (mmlaOp) {
- case MMLA::Signed:
+ case MMLA::SignedInt:
return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
lhs, rhs);
- case MMLA::Unsigned:
+ case MMLA::UnsignedInt:
return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
lhs, rhs);
- case MMLA::Mixed:
+ case MMLA::MixedInt:
return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
lhs, rhs);
- case MMLA::MixedSwapped:
- // The accumulator comes transposed and the result will be transposed
- // later, so all we have to do here is swap the operands.
- return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
- rhs, lhs);
+ case MMLA::Bfloat:
+ return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs,
+ rhs);
case MMLA::Nop:
llvm_unreachable("Uninitialized operation type");
}
@@ -275,7 +281,7 @@ class VectorContractRewriter {
// Transpose ACC if doing signed by unsigned multiplication, because we're
// using the instruction for unsigned by signed multiplication with
// reversed operands.
- if (mmlaOp == MMLA::MixedSwapped)
+ if (swapOperands)
tiledAcc = rewriter.create<vector::TransposeOp>(
loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
@@ -304,7 +310,7 @@ class VectorContractRewriter {
// Because of the reversed operands the result is obtained transposed.
// Transpose it back,
- if (mmlaOp == MMLA::MixedSwapped)
+ if (swapOperands)
tiledRes = rewriter.create<vector::TransposeOp>(
loc, tiledRes, ArrayRef<int64_t>({1, 0}));
@@ -341,10 +347,10 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
// values before the extension. All four signed/unsigned combinations for
// input operands are supported, but they are lowered to different
// operations. Determine which is the appropriate operation to lower to.
- mmlaOp = MMLA::Signed;
+ mmlaOp = MMLA::SignedInt;
auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
if (!maybeLhs) {
- mmlaOp = MMLA::Unsigned;
+ mmlaOp = MMLA::UnsignedInt;
maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
}
if (!maybeLhs)
@@ -353,11 +359,13 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
if (maybeRhs) {
- if (mmlaOp == MMLA::Unsigned)
- mmlaOp = MMLA::Mixed;
+ if (mmlaOp == MMLA::UnsignedInt)
+ mmlaOp = MMLA::MixedInt;
} else {
- if (mmlaOp == MMLA::Signed)
- mmlaOp = MMLA::MixedSwapped;
+ if (mmlaOp == MMLA::SignedInt) {
+ mmlaOp = MMLA::MixedInt;
+ swapOperands = true;
+ }
maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
}
@@ -374,16 +382,17 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
auto lhsExtInType = cast<VectorType>(lhs.getType());
if (lhsExtInType.getElementTypeBitWidth() < 8)
lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
- /* signExt */ mmlaOp == MMLA::Signed ||
- mmlaOp == MMLA::Mixed,
+ /* signExt */
+ (mmlaOp == MMLA::SignedInt ||
+ (mmlaOp == MMLA::MixedInt && !swapOperands)),
rewriter);
auto rhsExtInType = cast<VectorType>(rhs.getType());
if (rhsExtInType.getElementTypeBitWidth() < 8)
-
rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
- /* signExt */ mmlaOp != MMLA::Unsigned &&
- mmlaOp != MMLA::Mixed,
+ /* signExt */
+ (mmlaOp == MMLA::SignedInt ||
+ (mmlaOp == MMLA::MixedInt && swapOperands)),
rewriter);
// Initialize parameters for unrolling.
@@ -397,6 +406,47 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
}
};
+class VectorContractRewriterBFMMLA : public VectorContractRewriter {
+public:
+ LogicalResult matchAndInit(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+
+ if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
+ return failure();
+
+    // The unrolling patterns can handle any operand shapes that are multiples
+    // of [2, 2, 4].
+ if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
+ return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
+
+ // Check the output is a vector of Float32 elements.
+ auto outTy = dyn_cast<VectorType>(op.getResultType());
+ if (!outTy || outTy.getElementType() != rewriter.getF32Type())
+ return rewriter.notifyMatchFailure(op,
+ "output type is not a vector of f32");
+
+ // Check the inputs are vectors of BFloat16 elements.
+ if (op.getLhsType().getElementType() != rewriter.getBF16Type())
+ return rewriter.notifyMatchFailure(op,
+ "input type is not a vector of bf16");
+
+ mmlaOp = MMLA::Bfloat;
+ swapOperands = false;
+ lhs = op.getLhs();
+ rhs = op.getRhs();
+ acc = op.getAcc();
+
+ // Initialize parameters for unrolling.
+ iterationBounds = *op.getShapeForUnroll();
+ if (iterationBounds.size() == 3)
+ subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 4});
+ else
+ subTileShape = SmallVector<int64_t>({2, 4});
+
+ return success();
+ }
+};
+
/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
/// any vector.contract into multiple smmla instructions with unrolling so long
/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
@@ -418,10 +468,32 @@ class LowerContractionToNeonI8MMPattern
}
};
+class LowerContractionToNeonBFMMLAPattern
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+
+ VectorContractRewriterBFMMLA vcr;
+ if (failed(vcr.matchAndInit(op, rewriter)))
+ return failure();
+ vcr.rewrite(op, rewriter);
+
+ return success();
+ }
+};
+
} // namespace
-void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
+void mlir::arm_neon::populateLowerContractionToNeonI8MMPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
}
+
+void mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns(
+ RewritePatternSet &patterns) {
+ MLIRContext *context = patterns.getContext();
+ patterns.add<LowerContractionToNeonBFMMLAPattern>(context, /*benefit=*/2);
+}
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
index b7703ff0393eb..f7a9499e2db07 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -12,7 +12,7 @@
// TODO: There may be opportunities to unify this with a similar pattern
// for Neon. See:
// https://github.com/llvm/llvm-project/issues/145559
-// LowerContractionToNeonI8MMPattern.cpp
+// LowerContractToNeonPatterns.cpp
//
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir b/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir
new file mode 100644
index 0000000000000..229c4e5b2dc3a
--- /dev/null
+++ b/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir
@@ -0,0 +1,225 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// Test lowering of vector.contract to BFMMLA operations.
+// For each iteration [I, J, K] sub-tiles are extracted from offsets as follows:
+// LHS: [2*I, 4*K]
+// RHS: [2*J, 4*K]
+// ACC: [2*I, 2*J]
+// Sub-tile insert offsets for the result are the same as for ACC (there are
+// redundant inserts).
+
+// CHECK-LABEL: func.func @vector_contract_to_bfmmla
+// CHECK-SAME: %[[LHS:.+]]: vector<4x8xbf16>, %[[RHS:.+]]: vector<4x8xbf16>, %[[ACC:.+]]: vector<4x4xf32>
+
+// %[[INIT_RES:.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+
+// Iteration [0, 0, 0]
+// Extract sub-tiles from each of LHS, RHS and ACC
+// %[[T0:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T1:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T2:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+
+// Flatten the operands to fit the `bfmmla` operation types
+// %[[T3:.+]] = vector.shape_cast %[[T0]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T4:.+]] = vector.shape_cast %[[T1]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T5:.+]] = vector.shape_cast %[[T2]] : vector<2x2xf32> to vector<4xf32>
+
+// Perform the matrix multiply and accumulate
+// %[[K_ACC_0:.+]] = arm_neon.intr.bfmmla %[[T5]], %[[T3]], %[[T4]] : vector<8xbf16> to vector<4xf32>
+
+// Un-flatten the output sub-tile and insert it into the result
+// %[[T7:.+]] = vector.shape_cast %[[K_ACC_0]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_0:.+]] = vector.insert_strided_slice %[[T7]], %[[INIT_RES]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 0, 1]
+// %[[T9:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T10:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T11:.+]] = vector.shape_cast %[[T9]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T12:.+]] = vector.shape_cast %[[T10]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T13:.+]] = arm_neon.intr.bfmmla %[[K_ACC_0]], %[[T11]], %[[T12]] : vector<8xbf16> to vector<4xf32>
+// %[[T14:.+]] = vector.shape_cast %[[T13]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_1:.+]] = vector.insert_strided_slice %[[T14]], %[[TMP_RES_0]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 1, 0]
+// %[[T16:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T17:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T18:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T19:.+]] = vector.shape_cast %[[T16]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T20:.+]] = vector.shape_cast %[[T17]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T21:.+]] = vector.shape_cast %[[T18]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_1:.+]] = arm_neon.intr.bfmmla %[[T21]], %[[T19]], %[[T20]] : vector<8xbf16> to vector<4xf32>
+// %[[T23:.+]] = vector.shape_cast %[[K_ACC_1]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_2:.+]] = vector.insert_strided_slice %[[T23]], %[[TMP_RES_1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 1, 1]
+// %[[T25:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T26:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T27:.+]] = vector.shape_cast %[[T25]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T28:.+]] = vector.shape_cast %[[T26]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T29:.+]] = arm_neon.intr.bfmmla %[[K_ACC_1]], %[[T27]], %[[T28]] : vector<8xbf16> to vector<4xf32>
+// %[[T30:.+]] = vector.shape_cast %[[T29]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_3:.+]] = vector.insert_strided_slice %[[T30]], %[[TMP_RES_2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 0, 0]
+// %[[T32:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T33:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T34:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T35:.+]] = vector.shape_cast %[[T32]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T36:.+]] = vector.shape_cast %[[T33]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T37:.+]] = vector.shape_cast %[[T34]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_2:.+]] = arm_neon.intr.bfmmla %[[T37]], %[[T35]], %[[T36]] : vector<8xbf16> to vector<4xf32>
+// %[[T39:.+]] = vector.shape_cast %[[K_ACC_2]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_4:.+]] = vector.insert_strided_slice %[[T39]], %[[TMP_RES_3]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 0, 1]
+// %[[T41:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T42:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T43:.+]] = vector.shape_cast %[[T41]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T44:.+]] = vector.shape_cast %[[T42]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T45:.+]] = arm_neon.intr.bfmmla %[[K_ACC_2]], %[[T43]], %[[T44]] : vector<8xbf16> to vector<4xf32>
+// %[[T46:.+]] = vector.shape_cast %[[T45]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_5:.+]] = vector.insert_strided_slice %[[T46]], %[[TMP_RES_4]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 1, 0]
+// %[[T48:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T49:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T50:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T51:.+]] = vector.shape_cast %[[T48]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T52:.+]] = vector.shape_cast %[[T49]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T53:.+]] = vector.shape_cast %[[T50]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_3:.+]] = arm_neon.intr.bfmmla %[[T53]], %[[T51]], %[[T52]] : vector<8xbf16> to vector<4xf32>
+// %[[T55:.+]] = vector.shape_cast %[[K_ACC_3]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_6:.+]] = vector.insert_strided_slice %[[T55]], %[[TMP_RES_5]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 1, 1]
+// %[[T57:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T58:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T59:.+]] = vector.shape_cast %[[T57]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T60:.+]] = vector.shape_cast %[[T58]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T61:.+]] = arm_neon.intr.bfmmla %[[K_ACC_3]], %[[T59]], %[[T60]] : vector<8xbf16> to vector<4xf32>
+// %[[T62:.+]] = vector.shape_cast %[[T61]] : vector<4xf32> to vector<2x2xf32>
+// %[[RESULT:.+]] = vector.insert_strided_slice %[[T62]], %[[TMP_RES_6]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// return %[[RESULT]] : vector<4x4xf32>
+
+func.func @vector_contract_to_bfmmla(%lhs: vector<4x8xbf16>,
+ %rhs: vector<4x8xbf16>,
+ %acc: vector<4x4xf32>) -> vector<4x4xf32> {
+ %0 = vector.contract { indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>
+ }
+ %lhs, %rhs, %acc : vector<4x8xbf16>, vector<4x8xbf16> into vector<4x4xf32>
+
+ return %0 : vector<4x4xf32>
+}
+
+// Test lowering of vector.contract, representing a vector-by-matrix multiply
+// and accumulate, to BFMMLA operations.
+
+// For each iteration [J, K] sub-tiles are extracted from offsets as follows:
+// LHS: [4*K]
+// RHS: [2*J, 4*K]
+// ACC: [2*J]
+// Sub-tile insert offsets for the result are the same as for ACC (there are
+// redundant inserts).
+// CHECK-LABEL: func.func @vector_contract_vecmat_to_bfmmla
+// CHECK-SAME: %[[LHS:.+]]: vector<8xbf16>, %[[RHS:.+]]: vector<4x8xbf16>, %[[ACC:.+]]: vector<4xf32>) -> vector<4xf32> {
+// CHECK: %[[ACC_PAD_Z:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[LHS_PAD_Z:.+]] = arith.constant dense<0.000000e+00> : vector<2x4xbf16>
+// CHECK: %[[RES_INIT:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+
+// Iteration [0, 0]
+// Extract sub-tiles
+// CHECK: %[[T0:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T1:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T2:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+
+// Pad LHS sub-tile/vector with an extra row of zeroes
+// CHECK: %[[T3:.+]] = vector.insert_strided_slice %[[T0]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+
+// Pad ACC sub-tile/vector with an extra row of zeroes
+// CHECK: %[[T4:.+]] = vector.insert_strided_slice %[[T2]], %[[ACC_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<2xf32> into vector<2x2xf32>
+
+// Flatten the operands to fit the `bfmmla` operation types
+// CHECK: %[[T5:.+]] = vector.shape_cast %[[T3]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T6:.+]] = vector.shape_cast %[[T1]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T7:.+]] = vector.shape_cast %[[T4]] : vector<2x2xf32> to vector<4xf32>
+
+// Perform the matrix multiply and accumulate
+// CHECK: %[[K_ACC_0:.+]] = arm_neon.intr.bfmmla %[[T7]], %[[T5]], %[[T6]] : vector<8xbf16> to vector<4xf32>
+
+// Un-flatten the output sub-tile
+// CHECK: %[[T9:.+]] = vector.shape_cast %[[K_ACC_0]] : vector<4xf32> to vector<2x2xf32>
+
+// Extract the first row (the second row is padding) and insert it into the result
+// CHECK: %[[T10:.+]] = vector.extract %[[T9]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_0:.+]] = vector.insert_strided_slice %[[T10]], %[[RES_INIT]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [0, 1]
+// CHECK: %[[T12:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T13:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T14:.+]] = vector.insert_strided_slice %[[T12]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T15:.+]] = vector.shape_cast %[[T14]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T16:.+]] = vector.shape_cast %[[T13]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T17:.+]] = arm_neon.intr.bfmmla %[[K_ACC_0]], %[[T15]], %[[T16]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T18:.+]] = vector.shape_cast %[[T17]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T19:.+]] = vector.extract %[[T18]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_1:.+]] = vector.insert_strided_slice %[[T19]], %[[TMP_RES_0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [1, 0]
+// CHECK: %[[T21:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T22:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T23:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: %[[T24:.+]] = vector.insert_strided_slice %[[T21]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T25:.+]] = vector.insert_strided_slice %[[T23]], %[[ACC_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<2xf32> into vector<2x2xf32>
+// CHECK: %[[T26:.+]] = vector.shape_cast %[[T24]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T27:.+]] = vector.shape_cast %[[T22]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T28:.+]] = vector.shape_cast %[[T25]] : vector<2x2xf32> to vector<4xf32>
+// CHECK: %[[K_ACC_1:.+]] = arm_neon.intr.bfmmla %[[T28]], %[[T26]], %[[T27]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T30:.+]] = vector.shape_cast %[[K_ACC_1]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T31:.+]] = vector.extract %[[T30]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_2:.+]] = vector.insert_strided_slice %[[T31]], %[[TMP_RES_1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [1, 1]
+// CHECK: %[[T33:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T34:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T35:.+]] = vector.insert_strided_slice %[[T33]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T36:.+]] = vector.shape_cast %[[T35]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T37:.+]] = vector.shape_cast %[[T34]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T38:.+]] = arm_neon.intr.bfmmla %[[K_ACC_1]], %[[T36]], %[[T37]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T39:.+]] = vector.shape_cast %[[T38]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T40:.+]] = vector.extract %[[T39]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[RESULT:.+]] = vector.insert_strided_slice %[[T40]], %[[TMP_RES_2]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: return %[[RESULT]] : vector<4xf32>
+func.func @vector_contract_vecmat_to_bfmmla(%lhs: vector<8xbf16>,
+ %rhs: vector<4x8xbf16>,
+ %acc: vector<4xf32>) -> vector<4xf32> {
+ %0 = vector.contract { indexing_maps = [
+ affine_map<(n, k) -> (k)>,
+ affine_map<(n, k) -> (n, k)>,
+ affine_map<(n, k) -> (n)>
+ ],
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<add>
+ }
+ %lhs, %rhs, %acc : vector<8xbf16>, vector<4x8xbf16> into vector<4xf32>
+
+ return %0 : vector<4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+ transform.apply_patterns to %func {
+ transform.apply_patterns.arm_neon.vector_contract_to_bfmmla
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
new file mode 100644
index 0000000000000..b62ae040f364b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
@@ -0,0 +1,176 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-neon enable-arm-bf16' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm \
+// DEFINE: --lower-affine --convert-arith-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+bf16" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+
+//
+// Test the lowering of `vector.contract` using the `LowerContractionToNeonBFMMLAPattern`
+//
+// The operation that the `vector.contract` in this test performs is matrix
+// multiplication with accumulate
+// OUT = ACC + LHS * RHS
+// of two BFloat16 matrices LHS and RHS, and a Float32 matrix ACC into a Float32 OUT.
+//
+// The test checks both the computed values and that the relevant `ArmNeon`
+// dialect operation (`arm_neon.intr.bfmmla`) is emitted.
+//
+// The pattern above handles (and therefore this test prepares) input/output
+// vectors with specific shapes:
+// * LHS: vector<MxKxbf16>
+// * RHS: vector<NxKxbf16>
+// * ACC, OUT: vector<MxNxf32>
+// where M and N are even and K is divisible by 4.
+// Note that the RHS is transposed.
+// This data layout makes it efficient to load data into SIMD
+// registers in the layout expected by the BFMMLA instruction.
+// Such a `vector.contract` is representative of the code we aim to generate
+// by vectorisation of `linalg.mmt4d`.
+//
+// In this specific test we use M == 4, N == 4, and K == 4.
+
+// CHECK-IR-LABEL: llvm.func @matrix_by_matrix_mul_and_acc
+// CHECK-IR-COUNT-4: arm_neon.intr.bfmmla
+func.func @matrix_by_matrix_mul_and_acc() {
+
+ %c0 = arith.constant 0 : index
+ %c0_f32 = arith.constant 0.0 : f32
+ %c0_bf16 = arith.constant 0.0 : bf16
+
+ // Accumulator test data
+ %acc_cst = arith.constant dense<[[ 0.7, 1.0, -0.1, 1.8],
+ [-0.5, 0.9, 0.7, -0.7],
+ [ 0.5, -1.3, -2.2, 0.1],
+ [-0.7, 1.0, 1.7, -1.0]]> : vector<4x4xf32>
+
+ %acc_mem = memref.alloc() : memref<4x4xf32>
+ vector.transfer_write %acc_cst, %acc_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xf32>, memref<4x4xf32>
+ %acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_f32 {in_bounds = [true, true]} : memref<4x4xf32>, vector<4x4xf32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[ 0.1, 0.7, -0.9, 1.3],
+ [-1.6, 0.7, -0.3, -0.3],
+ [-0.4, 0.6, 0.8, -0.5],
+ [-0.6, -1.0, -1.0, -1.0]]> : vector<4x4xbf16>
+
+ %lhs_mem = memref.alloc() : memref<4x4xbf16>
+ vector.transfer_write %lhs_cst, %lhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+ %lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[ 0.6, 1.3, 0.1, -0.9],
+ [ 0.5, 1.6, 1.8, 1.6],
+ [-0.2, 0.4, 1.0, 0.4],
+ [-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16>
+
+ %rhs_mem = memref.alloc() : memref<4x4xbf16>
+ vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+ %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+ // Matrix multiplication and accumulate with transposed RHS.
+ %0 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<4x4xbf16>, vector<4x4xbf16> into vector<4x4xf32>
+
+ // Display the result of the multiplication
+ vector.print str "Result(BFMMLA):\n"
+ %u0 = vector.extract %0[0] : vector<4xf32> from vector<4x4xf32>
+ %u1 = vector.extract %0[1] : vector<4xf32> from vector<4x4xf32>
+ %u2 = vector.extract %0[2] : vector<4xf32> from vector<4x4xf32>
+ %u3 = vector.extract %0[3] : vector<4xf32> from vector<4x4xf32>
+ vector.print %u0 : vector<4xf32>
+ vector.print %u1 : vector<4xf32>
+ vector.print %u2 : vector<4xf32>
+ vector.print %u3 : vector<4xf32>
+
+ return
+}
+
+// Test when the LHS is a one-dimensional vector.
+//
+// In the vector-by-matrix case the shapes are as follows:
+// * LHS: vector<Kxbf16>
+// * RHS: vector<NxKxbf16>
+// * ACC, OUT: vector<Nxf32>
+// N is even and K is divisible by 4.
+// In this specific test we use N == 4, and K == 4.
+
+// CHECK-IR-LABEL: llvm.func @vector_by_matrix_mul_and_acc
+// CHECK-IR-COUNT-2: arm_neon.intr.bfmmla
+func.func @vector_by_matrix_mul_and_acc() {
+ %c0 = arith.constant 0 : index
+ %c0_f32 = arith.constant 0.0 : f32
+ %c0_bf16 = arith.constant 0.0 : bf16
+
+ // Accumulator test data
+ %acc_cst = arith.constant dense<[0.7, 1.0, -0.1, 1.8]> : vector<4xf32>
+
+ %acc_mem = memref.alloc() : memref<4xf32>
+ vector.transfer_write %acc_cst, %acc_mem[%c0] {in_bounds = [true] } : vector<4xf32>, memref<4xf32>
+ %acc = vector.transfer_read %acc_mem[%c0], %c0_f32 {in_bounds = [true]} : memref<4xf32>, vector<4xf32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[0.1, 0.7, -0.9, 1.3]> : vector<4xbf16>
+
+ %lhs_mem = memref.alloc() : memref<4xbf16>
+ vector.transfer_write %lhs_cst, %lhs_mem[%c0] {in_bounds = [true] } : vector<4xbf16>, memref<4xbf16>
+ %lhs = vector.transfer_read %lhs_mem[%c0], %c0_bf16 {in_bounds = [true]} : memref<4xbf16>, vector<4xbf16>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[ 0.6, 1.3, 0.1, -0.9],
+ [ 0.5, 1.6, 1.8, 1.6],
+ [-0.2, 0.4, 1.0, 0.4],
+ [-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16>
+
+ %rhs_mem = memref.alloc() : memref<4x4xbf16>
+ vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+ %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+ // Vector by matrix multiplication and accumulate with transposed RHS.
+ %0 = vector.contract { indexing_maps = [
+ affine_map<(n, k) -> (k)>,
+ affine_map<(n, k) -> (n, k)>,
+ affine_map<(n, k) -> (n)>
+ ],
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<add>
+ }
+ %lhs, %rhs, %acc : vector<4xbf16>, vector<4x4xbf16> into vector<4xf32>
+
+ // Display the result of the multiplication
+ vector.print str "Result(BFMMLA, vecmat):\n"
+ vector.print %0 : vector<4xf32>
+
+ return
+}
+
+func.func @main() {
+ // CHECK-LABEL: Result(BFMMLA):
+ // CHECK: ( 0.411922, 2.63254, -0.219259, 3.89965 )
+ // CHECK: ( -0.316515, 0.196875, 0.879375, 1.80924 )
+ // CHECK: ( 1.56867, 0.101367, -1.2784, -1.41579 )
+ // CHECK: ( -1.56041, -4.30078, 0.0196488, 1.88269 )
+ func.call @matrix_by_matrix_mul_and_acc() : () -> ()
+
+ // CHECK-LABEL: Result(BFMMLA, vecmat):
+ // CHECK: ( 0.411922, 2.63254, -0.219259, 3.89965 )
+ func.call @vector_by_matrix_mul_and_acc() : () -> ()
+
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
index 1ce55ca05c90e..f6012bbd3d0b2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
@@ -240,7 +240,7 @@ func.func @test_usmmla() {
// Test the operation where LHS is interpreted as signed and RHS is interpreted
// as unsigned. In this test we ultimately emit end execute the `usmmla`
-// instruction with reversed operands, see `LowerContractionToNeonI8MMPattern.cpp`
+// instruction with reversed operands, see `LowerContractToNeonPatterns.cpp`
// for more details.
// CHECK-IR-LABEL: llvm.func @test_summla