[Mlir-commits] [mlir] [MLIR][AArch64] Lower `vector.contract` to SVE FEAT_BF16 operations (PR #147052)
Momchil Velikov
llvmlistbot at llvm.org
Fri Jul 4 06:26:21 PDT 2025
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/147052
This patch adds lowering of BFloat16 widening matrix multiply and accumulate `vector.contract` operations, by parametrising and refactoring the existing pattern for 8-bit integers.
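For illustration, below is a minimal sketch (adapted from the vector-bfmmla.mlir test added by this patch, with M = 4, N = 4, K = 4) of the kind of `vector.contract` the new pattern handles. The contraction is decomposed into sub-tile multiply-accumulates, each emitted as an `arm_sve.intr.bfmmla` on `vector<[8]xbf16>` inputs and a `vector<[4]xf32>` accumulator (four of them for these shapes):

  %0 = vector.contract {
         indexing_maps = [
           affine_map<(m, n, k) -> (m, k)>,  // LHS
           affine_map<(m, n, k) -> (n, k)>,  // RHS (pre-transposed)
           affine_map<(m, n, k) -> (m, n)>   // ACC
         ],
         iterator_types = ["parallel", "parallel", "reduction"],
         kind = #vector.kind<add>
       } %lhs, %rhs, %acc
       : vector<4x4xbf16>, vector<[4]x4xbf16> into vector<4x[4]xf32>

The rewrite is enabled in the convert-vector-to-llvm pass via the new `enable-arm-bf16` option (e.g. --convert-vector-to-llvm='enable-arm-sve enable-arm-bf16') and is also exposed as the transform op `transform.apply_patterns.arm_sve.vector_contract_to_bfmmla`.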
>From c87e3c0bfc75d941af4748ee4d804ddb4f08ced3 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 4 Jul 2025 12:53:05 +0000
Subject: [PATCH] [MLIR][AArch64] Lower `vector.contract` to SVE FEAT_BF16
operations
This patch adds lowering of BFloat16 widening matrix multiply and
accumulate `vector.contract` operations, by parametrising and
refactoring the existing pattern for 8-bit integers.
---
mlir/include/mlir/Conversion/Passes.td | 4 +
.../TransformOps/ArmSVEVectorTransformOps.td | 12 +-
.../Dialect/ArmSVE/Transforms/Transforms.h | 2 +
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 3 +
.../LowerContractionToNeonI8MMPattern.cpp | 2 +-
.../TransformOps/ArmSVEVectorTransformOps.cpp | 7 +-
.../Dialect/ArmSVE/Transforms/CMakeLists.txt | 2 +-
.../Transforms/LowerContractToSVEPatterns.cpp | 607 ++++++++++++++++++
.../LowerContractionToSVEI8MMPattern.cpp | 366 -----------
.../Vector/CPU/ArmSVE/vector-bfmmla.mlir | 105 +++
.../CPU/ArmSVE/vector-contract-bfmmla.mlir | 201 ++++++
.../CPU/ArmSVE/vector-contract-i8mm.mlir | 6 +-
12 files changed, 944 insertions(+), 373 deletions(-)
create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
delete mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-bfmmla.mlir
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5a864865adffc..4f304b39a0528 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1437,6 +1437,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of Arm FEAT_I8MM instructions while lowering "
"the vector dialect.">,
+ Option<"armBF16", "enable-arm-bf16",
+ "bool", /*default=*/"false",
+ "Enables the use of Arm FEAT_BF16 instructions while lowering "
+ "the vector dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
index 53784982be6dc..81b3c736b93f3 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
@@ -12,7 +12,7 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
-def ApplyArmSVELowerContractionPatternsOp
+def ApplyArmSVELowerContractionToI8MMPatternsOp
: Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_i8mm",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
@@ -23,4 +23,14 @@ def ApplyArmSVELowerContractionPatternsOp
let assemblyFormat = "attr-dict";
}
+def ApplyArmSVELowerContractionToBFMMLAPatternsOp
+ : Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_bfmmla",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector contraction-like operations should be lowered to
+ finer-grained vector primitives using the ArmSVE dialect.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
#endif // ARMSVE_VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 232e2be29e574..de160dbf8ed94 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -23,6 +23,8 @@ void populateArmSVELegalizeForLLVMExportPatterns(
void populateLowerContractionToSVEI8MMPatternPatterns(
RewritePatternSet &patterns);
+void populateLowerContractionToSVEBFMMLAPatterns(RewritePatternSet &patterns);
+
/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
/// intrinsics.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 67c0eca15638a..4d74aabcaa50d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -89,6 +89,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
if (armSVE)
populateLowerContractionToSVEI8MMPatternPatterns(patterns);
}
+ if (armBF16)
+ populateLowerContractionToSVEBFMMLAPatterns(patterns);
+
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
index 7180884c77e98..a95fc51d562c2 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
@@ -12,7 +12,7 @@
// TODO: There may be opportunities to unify this with a similar pattern
// for SVE. See:
// https://github.com/llvm/llvm-project/issues/145559
-// LowerContractionToSVEI8MMPattern.cpp
+// LowerContractToSVEPatterns.cpp
//
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
index b2ca4fc1eaa8c..8572c34c8b12b 100644
--- a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
@@ -18,11 +18,16 @@ using namespace mlir;
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
-void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns(
+void transform::ApplyArmSVELowerContractionToI8MMPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
}
+void transform::ApplyArmSVELowerContractionToBFMMLAPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ mlir::populateLowerContractionToSVEBFMMLAPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index 65f98b44b1b69..c29eaca244b4a 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
add_mlir_dialect_library(MLIRArmSVETransforms
LegalizeForLLVMExport.cpp
LegalizeVectorStorage.cpp
- LowerContractionToSVEI8MMPattern.cpp
+ LowerContractToSVEPatterns.cpp
DEPENDS
MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
new file mode 100644
index 0000000000000..2987287afe9cd
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
@@ -0,0 +1,607 @@
+//===- LowerContractToSVEPatterns.cpp - Contract to I8MM/BF16 ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the SVE FEAT_I8MM and FEAT_BF16 extensions.
+//
+// TODO: There may be opportunities to unify this with a similar pattern
+// for Neon. See:
+// https://github.com/llvm/llvm-project/issues/145559
+// LowerContractionToNeonI8MMPattern.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <numeric>
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+
+namespace {
+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `i8` to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+ static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+ "Must be instantiated with either sign- or zero- extension op");
+
+ // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type, allow for an implicit sign-extension.
+ auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+ if (!extOp) {
+ if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+ auto vTy = cast<VectorType>(v.getType());
+ if (!vTy.getElementType().isSignlessInteger(8))
+ return {};
+ return v;
+ }
+ return {};
+ }
+
+ // If the operand is defined by an explicit extend operation of the accepted
+ // operation type, check it's extended from `i8` to `i32`.
+ auto inOp = extOp.getIn();
+ auto inTy = dyn_cast<VectorType>(inOp.getType());
+ if (!inTy || !inTy.getElementType().isSignlessInteger(8))
+ return {};
+
+ auto outTy = dyn_cast<VectorType>(extOp.getType());
+ if (!outTy || !outTy.getElementType().isSignlessInteger(32))
+ return {};
+
+ return inOp;
+}
+
+/// This class encapsulates the algorithm and parametrisation (in terms of types
+/// and dimensions) of lowering a `vector.contract` to "primitive" matrix
+/// multiplication operations of the SVE dialect (here "primitive" would mean
+/// corresponding to a single target instruction).
+///
+/// Supported are lowering to FEAT_I8MM `smmla`, `ummla`, and `usmmla`, and to
+/// FEAT_BF16 `bfmmla`. All the transformations are very similar to each other;
+/// for concreteness the description below is given for `smmla`.
+///
+/// The lowering triggers for a contraction operation that performs a matrix
+/// multiply of two 8-bit integer matrix tiles with logical dimensions
+/// <Mx8> and <8x[N]> for the left-hand side (LHS) and the right-hand side
+/// (RHS), respectively, added to a 32-bit integer accumulator operand (ACC)
+/// with dimensions <Mx[N]>, yielding a <Mx[N]> 32-bit integer result (OUT).
+///
+/// The operands' shapes are such that the operands can be evenly split into
+/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM
+/// instructions. The intent is that M and N are chosen (by higher level
+/// transforms) in such a way as to maximise register usage. The main use case
+/// we envision as of now is MMT4D, thus the RHS operand is expected
+/// pre-transposed.
+///
+/// The matrix multiplication is performed by unrolling the usual tiled matrix
+/// multiplication algorithm using sub-tiles with dimensions <2x8> for the
+/// LHS, <8x[2]> for the RHS, and <2x[2]> for the result and the input
+/// accumulator.
+///
+/// One way to illustrate the operation is as follows:
+///
+/// RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]>
+/// +-----------------------------
+/// LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+/// ... | ... ... ... ...
+/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///
+/// The RHS operand is unpacked into N/2 values, each representing a sequence
+/// of VSCALE number of sub-tiles with dimensions <8x2>.
+/// The LHS operand is initially unpacked into M/2 values, each representing a
+/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
+/// VSCALE times. Multiplying thus replicated LHS sub-tile by the corresponding
+/// RHS sub-tile correctly computes an entire result sub-tile.
+/// The 2x2 sub-tiles of the ACC and OUT have rows that are not adjacent
+/// (in memory or when imposing a row-major layout on the 2D vector value).
+/// Reading the ACC is implemented as reading two consecutive rows and
+/// interleaving them by pairs to obtain a vector having twice the length of
+/// an ACC row. This vector is now a sequence of one-dimensional tiles with
+/// the exact layout needed by the `smmla`/`bfmmla`/etc. instructions, and
+/// these tiles are extracted one by one. For illustration, if we have a 2x4
+/// ACC tile
+/// a0 a1 b0 b1
+/// a2 a3 b2 b3
+/// we read the two rows as separate values and then interleave by pairs
+/// to obtain
+/// a0 a1 a2 a3 b0 b1 b2 b3
+/// from which we extract `a0 a1 a2 a3` and `b0 b1 b2 b3`.
+///
+/// Writing the OUT tile is done by the reverse of the above procedure:
+/// concatenate two "flattened" sub-tiles into
+/// c0 c1 c2 c3 d0 d1 d2 d3
+/// deinterleave by pairs to obtain as separate values
+/// c0 c1 d0 d1
+/// c2 c3 d2 d3
+/// which are then inserted into the final result.
+///
+/// Multiplication of a signed LHS by an unsigned RHS is performed by
+/// swapping the order of the operands and emitting an `usmmla` (since there
+/// isn't an `summla` instruction). Therefore each ACC sub-tile needs to be
+/// transposed before the accumulation, and each resulting OUT sub-tile needs
+/// to be transposed before insertion into the final result.
+/// This is done very elegantly by a modification of the above to
+/// interleave/deinterleave not by pairs, but by individual elements, e.g.
+/// after ordinary interleave we obtain
+/// a0 a2 a1 a3 b0 b2 b1 b3
+/// which is exactly the desired layout of having each individual 2x2 tile
+/// transposed.
+///
+/// All of the above readily applies to FEAT_BF16 `bfmmla` with the
+/// difference that the shapes of the LHS and RHS are <Mx4> and <4x[N]>,
+/// respectively, that is, the "K" dimension is fixed to 4 instead of 8 (as in
+/// the integer case).
+class VectorContractRewriter {
+protected:
+ // Designate the operation (resp. instruction) used to do sub-tile matrix
+ // multiplications.
+ enum class MMLA {
+ Nop,
+ SignedInt, // smmla
+ UnsignedInt, // ummla
+ MixedInt, // usmmla
+ Bfloat // bfmmla
+ };
+
+ // Lower-level operation to be emitted.
+ MMLA mmlaOp = MMLA::Nop;
+
+  // The operand tiles. These are not necessarily the operands of
+ // `vector.contract`, for example they could be operands to `arith.extsi`
+ // that is in turn fed into `vector.contract`.
+ Value lhs;
+ Value rhs;
+ Value acc;
+
+ // Conventional names for matrix dimensions.
+ int64_t M = 0;
+ int64_t N = 0;
+ int64_t K = 0;
+
+ // Single-dimensional vector types for the operands of the ArmSVE dialect
+ // op.
+ VectorType flatLhsType;
+ VectorType flatRhsType;
+ VectorType flatAccType;
+
+ // Single-dimension vector type for the entire RHS tile.
+ VectorType flatRhsTileType;
+
+ // Vector type having the same number of elements as a row in the
+ // accumulator/output tile and the same element type.
+ VectorType accRowTy;
+
+ // Vector type having twice the number of elements as a row in the
+  // accumulator/output tile and the same element type.
+ VectorType accRowX2Ty;
+
+  // Vector types with an integer element type of twice the bit width of the
+  // accumulator element type, used for packing/unpacking rows via bitcasts:
+  // `accRow64Ty` has half the number of elements of an accumulator row,
+  // `accRowX264Ty` has as many elements as an accumulator row.
+ VectorType accRow64Ty;
+ VectorType accRowX264Ty;
+
+ // Indicate if the operands for the ArmSVE dialect operation need to be
+ // swapped. Currently this is needed in order to emulate an "summla"
+ // operation.
+ bool swapOperands = false;
+
+  // Create the matrix multiply and accumulate operation according to
+ // `mmlaOp`.
+ Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+ Value lhs, Value rhs);
+
+ // Check general preconditions for applying the transformation, common to the
+ // integer and the bfloat16 case.
+ LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter);
+
+public:
+ VectorContractRewriter() = default;
+
+  // Do the actual rewrite. This member function is shared by both integer and
+ // bfloat16 rewrites.
+ Value rewrite(vector::ContractionOp op, PatternRewriter &rewriter);
+};
+
+Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter,
+ Location loc, Value acc, Value lhs,
+ Value rhs) {
+ if (swapOperands)
+ std::swap(lhs, rhs);
+
+ switch (mmlaOp) {
+ case MMLA::SignedInt:
+ return rewriter.create<arm_sve::SmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ case MMLA::UnsignedInt:
+ return rewriter.create<arm_sve::UmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ case MMLA::MixedInt:
+ return rewriter.create<arm_sve::UsmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ case MMLA::Bfloat:
+ return rewriter.create<arm_sve::BfmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ default:
+ llvm_unreachable("Uninitialized operation kind");
+ }
+}
+
+LogicalResult VectorContractRewriter::match(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ // Check iterator types for matrix multiplication.
+ auto itTypes = op.getIteratorTypesArray();
+ if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::parallel ||
+ itTypes[2] != vector::IteratorType::reduction)
+ return rewriter.notifyMatchFailure(
+ op, "iterator types do not correspond to matrix multiplication");
+
+ // Check permutation maps. For now only accept
+ // lhs: (d0, d1, d2) -> (d0, d2)
+ // rhs: (d0, d1, d2) -> (d1, d2)
+ // acc: (d0, d1, d2) -> (d0, d1)
+ // This corresponds to matrix multiplication with transposed RHS.
+ if (op.getIndexingMapsArray()[0] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[1] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[2] != AffineMap::getMultiDimMapWithTargets(
+ 3, ArrayRef{0u, 1u}, op.getContext()))
+ return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
+
+ // Check the combining kind is addition.
+ if (op.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(op, "combining kind is not an addition");
+
+ return success();
+}
+
+Value VectorContractRewriter::rewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ Location loc = op.getLoc();
+
+ // Extract LHS sub-tiles with logical shape <2xK>.
+ SmallVector<Value> lhsTile;
+ for (int64_t i = 0; i < M; i += 2) {
+ // Extract two consecutive rows of the LHS tile.
+ auto r0 =
+ rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i});
+ auto r1 =
+ rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i + 1});
+ // Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
+ SmallVector<int64_t> shuffleIdx(2 * K);
+ std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
+ auto t = rewriter.create<vector::ShuffleOp>(loc, r0, r1, shuffleIdx);
+ // Turn it into a scalable vector.
+ auto s = rewriter.create<vector::ScalableInsertOp>(
+ loc, t, rewriter.create<ub::PoisonOp>(loc, flatLhsType), 0);
+ // Replicate the sub-tile VSCALE times to fill the entire vector.
+ auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+ lhsTile.push_back(r);
+ }
+
+ // "Flatten" the RHS tile from <[N]xK> to <[N*K]>.
+ auto rhs = rewriter.create<vector::ShapeCastOp>(this->rhs.getLoc(),
+ flatRhsTileType, this->rhs);
+
+ // Extract the RHS sub-tiles with logical shape <Kx[2]>.
+ SmallVector<Value> rhsTile;
+ for (int64_t j = 0; j < N; j += 2)
+ rhsTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+ loc, flatRhsType, rhs, j * K));
+
+ // Extract and pack the ACC sub-tiles.
+ SmallVector<Value> accTile;
+ for (int64_t i = 0; i < M; i += 2) {
+ // Extract two consecutive rows of the accumulator tile.
+ auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+ ArrayRef<int64_t>{i});
+ auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+ ArrayRef<int64_t>{i + 1});
+ Value accTileVec;
+ if (swapOperands) {
+      // We are performing the operation with swapped LHS and RHS, so we need to
+ // transpose each individual 2x2 tile of the accumulator and (later) the
+ // final result.
+ accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
+ } else {
+ // Bitcast accumulator rows to double-width integer elements, so
+ // subsequent interleave/deinterleave work on pairs of elements.
+ auto r0I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
+ auto r1I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
+
+ // Interleave the rows, effectively flattening each 2x2 tile into 4
+ // consecutive elements.
+ auto intrI64 = rewriter.create<vector::InterleaveOp>(loc, r0I64, r1I64);
+
+ // Bitcast back to original element type.
+ accTileVec = rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intrI64);
+ }
+ // Extract ACC sub-tiles.
+ for (int64_t j = 0; j < N; j += 2)
+ accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+ loc, flatAccType, accTileVec, j * 2));
+ }
+
+ // Emit sub-tile matrix multiplications.
+ SmallVector<Value> outTile;
+ for (int64_t i = 0; i < M / 2; ++i)
+ for (int64_t j = 0; j < N / 2; ++j) {
+ Value mmla = createMMLA(rewriter, loc, accTile[i * N / 2 + j], lhsTile[i],
+ rhsTile[j]);
+ outTile.push_back(mmla);
+ }
+
+ // Unpack the OUT sub-tiles and insert into the result.
+ Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
+ for (int64_t i = 0; i < M / 2; ++i) {
+ // Collect a number of sub-tiles in a row.
+ Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
+ for (int64_t j = 0; j < N / 2; ++j)
+ row = rewriter.create<vector::ScalableInsertOp>(
+ loc, outTile[i * N / 2 + j], row, j * 4);
+
+    // Unpack the row to obtain two rows of the output. If we have the out
+    // sub-tiles transposed, we obtain two consecutive output rows by
+    // separating even and odd elements, i.e. a simple deinterleave.
+    // Otherwise, the deinterleave is by pairs.
+ Value out0, out1;
+ if (swapOperands) {
+ auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
+ out0 = tmp.getRes1();
+ out1 = tmp.getRes2();
+ } else {
+ // Deinterleave by pairs.
+ auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
+ auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
+
+ // Bitcast back into original element type and insert into the result.
+ out0 =
+ rewriter.create<vector::BitCastOp>(loc, accRowTy, deintr64.getRes1());
+ out1 =
+ rewriter.create<vector::BitCastOp>(loc, accRowTy, deintr64.getRes2());
+ }
+ result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
+ result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
+ }
+
+ return result;
+}
+
+class VectorContractRewriterI8MM : public VectorContractRewriter {
+public:
+ // Check the specific preconditions for the integer case. Initialise
+ // parametrisation types and dimensions.
+ LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter) {
+
+ if (failed(VectorContractRewriter::match(op, rewriter)))
+ return failure();
+
+ VectorType lhsType = op.getLhsType();
+ VectorType rhsType = op.getRhsType();
+
+ M = lhsType.getDimSize(0);
+ N = rhsType.getDimSize(0);
+ K = rhsType.getDimSize(1);
+
+ // Check the operands have the expected shape:
+ // * for LHS: fixed vector MxK
+ // * for RHS: scalable vector [N]xK
+ // * K == 8
+ // * M and N even and at least 2
+ if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
+ rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 ||
+ M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
+ !rhsType.getScalableDims()[0])
+ return rewriter.notifyMatchFailure(op, "non-matching operand shape");
+
+ // Check the output is a vector of i32 elements.
+ auto outTy = dyn_cast<VectorType>(op.getResultType());
+ if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+ return rewriter.notifyMatchFailure(op,
+ "output type is not a vector of i32");
+
+ // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
+ // before the extension. All four signed/unsigned combinations for input
+ // operands are supported, but they are lowered to different operations.
+ // Determine which is the appropriate operation to lower to.
+ mmlaOp = MMLA::SignedInt;
+ swapOperands = false;
+ auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+ if (!maybeLhs) {
+ mmlaOp = MMLA::UnsignedInt;
+ maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+ }
+ if (!maybeLhs)
+ return rewriter.notifyMatchFailure(
+ op, "LHS is not a sign- or zero- extended i8");
+
+ auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
+ if (maybeRhs) {
+ if (mmlaOp == MMLA::UnsignedInt)
+ mmlaOp = MMLA::MixedInt;
+ } else {
+ if (mmlaOp == MMLA::SignedInt) {
+ mmlaOp = MMLA::MixedInt;
+ swapOperands = true;
+ }
+ maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
+ }
+ if (!maybeRhs)
+ return rewriter.notifyMatchFailure(
+ op, "RHS is not a sign- or zero- extended i8");
+
+ // Initialise algorithm parameters.
+ lhs = *maybeLhs;
+ rhs = *maybeRhs;
+ acc = op.getAcc();
+
+ flatLhsType = VectorType::get(/*shape=*/16, rewriter.getI8Type(),
+ /*scalableDims=*/{true});
+ flatRhsType = VectorType::get(/*shape=*/16, rewriter.getI8Type(),
+ /*scalableDims=*/{true});
+
+ flatAccType = VectorType::get(/*shape=*/4, rewriter.getI32Type(),
+ /*scalableDims=*/{true});
+
+ flatRhsTileType = VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(),
+ /*scalableDims=*/{true});
+
+ accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(),
+ /*scalableDims=*/{true});
+ accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getI32Type(),
+ /*scalableDims=*/{true});
+ accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
+ /*scalableDims=*/{true});
+ accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
+ /*scalableDims=*/{true});
+
+ return success();
+ }
+};
+
+class VectorContractRewriterBfloat : public VectorContractRewriter {
+public:
+ // Check the specific preconditions for the bfloat16 case. Initialise
+ // parametrisation types and dimensions.
+ LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter) {
+
+ if (failed(VectorContractRewriter::match(op, rewriter)))
+ return failure();
+
+ VectorType lhsType = op.getLhsType();
+ VectorType rhsType = op.getRhsType();
+
+ M = lhsType.getDimSize(0);
+ N = rhsType.getDimSize(0);
+ K = rhsType.getDimSize(1);
+
+ // Check the operands have the expected shape:
+ // * for LHS: fixed vector MxK
+ // * for RHS: scalable vector [N]xK
+ // * K == 4
+ // * M and N even and at least 2
+ if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
+ rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 4 ||
+ M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
+ !rhsType.getScalableDims()[0])
+ return rewriter.notifyMatchFailure(op, "non-matching operand shape");
+
+ // Check the output is a vector of Float32 elements.
+ auto outTy = dyn_cast<VectorType>(op.getResultType());
+ if (!outTy || outTy.getElementType() != rewriter.getF32Type())
+ return rewriter.notifyMatchFailure(op,
+ "output type is not a vector of f32");
+
+ // Check the inputs are vectors of BFloat16 elements.
+ if (lhsType.getElementType() != rewriter.getBF16Type())
+ return rewriter.notifyMatchFailure(op,
+ "input type is not a vector of bf16");
+
+ // Initialise algorithm parameters.
+ mmlaOp = MMLA::Bfloat;
+ swapOperands = false;
+ lhs = op.getLhs();
+ rhs = op.getRhs();
+ acc = op.getAcc();
+
+ flatLhsType = VectorType::get(/*shape=*/8, rewriter.getBF16Type(),
+ /*scalableDims=*/{true});
+ flatRhsType = VectorType::get(/*shape=*/8, rewriter.getBF16Type(),
+ /*scalableDims=*/{true});
+
+ flatAccType = VectorType::get(/*shape=*/4, rewriter.getF32Type(),
+ /*scalableDims=*/{true});
+
+ flatRhsTileType = VectorType::get(/*shape=*/4 * N, rewriter.getBF16Type(),
+ /*scalableDims=*/{true});
+
+ accRowTy = VectorType::get(/*shape=*/N, rewriter.getF32Type(),
+ /*scalableDims=*/{true});
+ accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getF32Type(),
+ /*scalableDims=*/{true});
+ accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
+ /*scalableDims=*/{true});
+ accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
+ /*scalableDims=*/{true});
+
+ return success();
+ }
+};
+
+class LowerContractionToSVEI8MMPattern
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+
+ // Match i8xi8 -> i32 matrix multiply and accumulate.
+ VectorContractRewriterI8MM vcr;
+ if (failed(vcr.match(op, rewriter)))
+ return failure();
+
+ Value result = vcr.rewrite(op, rewriter);
+ rewriter.replaceOp(op, result);
+
+ return success();
+ }
+};
+
+class LowerContractionToSVEBFMMLAPattern
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+
+ // Match bf16xbf16 -> f32 matrix multiply and accumulate.
+ VectorContractRewriterBfloat vcr;
+ if (failed(vcr.match(op, rewriter)))
+ return failure();
+
+ Value result = vcr.rewrite(op, rewriter);
+ rewriter.replaceOp(op, result);
+
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
+ RewritePatternSet &patterns) {
+ MLIRContext *context = patterns.getContext();
+ patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
+}
+
+void mlir::populateLowerContractionToSVEBFMMLAPatterns(
+ RewritePatternSet &patterns) {
+ MLIRContext *context = patterns.getContext();
+ patterns.add<LowerContractionToSVEBFMMLAPattern>(context, /*benefit=*/2);
+}
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
deleted file mode 100644
index b7703ff0393eb..0000000000000
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ /dev/null
@@ -1,366 +0,0 @@
-//===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM --*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements lowering patterns from vector.contract to operations
-// that map to instructions from the SVE FEAT_I8MM extension.
-//
-// TODO: There may be opportunities to unify this with a similar pattern
-// for Neon. See:
-// https://github.com/llvm/llvm-project/issues/145559
-// LowerContractionToNeonI8MMPattern.cpp
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
-#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#include "mlir/Dialect/UB/IR/UBOps.h"
-
-#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
-
-using namespace mlir;
-
-namespace {
-// Get the operand of a `vector.contract`. This function is intended to abstract
-// away from the particular way a value is extended before feeding it into the
-// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
-// (for implicit sign-extension see `vector.contract` documentation).
-//
-// The template parameter `Op` indicates the extension operation (explicit or
-// implicit) for which we are checking.
-//
-// Return success only for extensions from `i8` to `i32`.
-template <typename Op>
-std::optional<Value> getExtOperand(Value v) {
-
- static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
- "Must be instantiated with either sign- or zero- extension op");
-
- // If the operand is not defined by an explicit extend operation of the
- // accepted operation type allow for an implicit sign-extension.
- auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
- if (!extOp) {
- if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
- auto vTy = cast<VectorType>(v.getType());
- if (!vTy.getElementType().isSignlessInteger(8))
- return {};
- return v;
- }
- return {};
- }
-
- // If the operand is defined by an explicit extend operation of the accepted
- // operation type, check it's extended from `i8` to `i32`.
- auto inOp = extOp.getIn();
- auto inTy = dyn_cast<VectorType>(inOp.getType());
- if (!inTy || !inTy.getElementType().isSignlessInteger(8))
- return {};
-
- auto outTy = dyn_cast<VectorType>(extOp.getType());
- if (!outTy || !outTy.getElementType().isSignlessInteger(32))
- return {};
-
- return inOp;
-}
-
-// Designate the operation (resp. instruction) used to do sub-tile matrix
-// multiplications.
-enum class MMLA {
- Signed, // smmla
- Unsigned, // ummla
- Mixed, // usmmla
- MixedSwapped // usmmla with LHS and RHS swapped
-};
-
-// Create the matrix mulitply and accumulate operation according to `op`.
-Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
- mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
- switch (op) {
- case MMLA::Signed:
- return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
- case MMLA::Unsigned:
- return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
- case MMLA::Mixed:
- return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
- case MMLA::MixedSwapped:
- // The accumulator comes transposed and the result will be transposed
- // later, so all we have to do here is swap the operands.
- return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
- }
-}
-
-/// Lower a contraction operation that performs a matrix multiplication
-/// of two 8-bit integer matrix tiles with logical dimensions <Mx8> and <8x[N]>
-/// for the left-hand side and the right-hand side, respectively,
-/// yielding a <Mx[N]> 32-bit integer result.
-///
-/// The operands' shapes are such that the operands can be evenly split into
-/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM
-/// instructions. The intent is that M and N are chosen (by higher level
-/// transforms) in such a way as to maximise register usage. The main use case
-/// we envision as of now is MMT4D, thus the RHS operand is expected
-/// pre-transposed.
-///
-/// The matrix multiplication is performed by unrolling the usual tiled matrix
-/// multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS,
-/// <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator.
-///
-/// One way to illustrate the operation is as follows:
-///
-/// RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]>
-/// +-----------------------------
-/// LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
-/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
-/// ... | ... ... ... ...
-/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
-///
-/// The RHS operand is unpacked into N/2 values, each representing a sequence of
-/// VSCALE number of sub-tiles with dimensions <8x2>.
-/// The LHS operand is initially unpacked into M/2 values, each representing a
-/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
-/// VSCALE times.
-/// Multiplying thus replicated LHS sub-tile by the corresponding RHS sub-tile
-/// correctly computes an entire result sub-tile.
-class LowerContractionToSVEI8MMPattern
- : public OpRewritePattern<vector::ContractionOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override {
-
- Location loc = op.getLoc();
- mlir::VectorType lhsType = op.getLhsType();
- mlir::VectorType rhsType = op.getRhsType();
-
- // Check the rank the types so we can safely examine their dimensions.
- if (lhsType.getRank() != 2 || rhsType.getRank() != 2)
- return rewriter.notifyMatchFailure(op, "non-matching operand shape");
-
- auto M = lhsType.getDimSize(0);
- auto N = rhsType.getDimSize(0);
- auto K = rhsType.getDimSize(1);
-
- // Check the operands have the expected shape:
- // * for LHS: fixed vector MxK
- // * for RHS: scalable vector [N]xK
- // * K == 8
- // * M and N even and at least 2
- if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
- rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 ||
- M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
- !rhsType.getScalableDims()[0])
- return rewriter.notifyMatchFailure(op, "non-matching operand shape");
-
- // Check permutation maps. For now only accept
- // lhs: (d0, d1, d2) -> (d0, d2)
- // rhs: (d0, d1, d2) -> (d1, d2)
- // acc: (d0, d1, d2) -> (d0, d1)
- // This corresponds to matrix multiplication with transposed RHS.
- if (op.getIndexingMapsArray()[0] !=
- AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
- op.getContext()) ||
- op.getIndexingMapsArray()[1] !=
- AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
- op.getContext()) ||
- op.getIndexingMapsArray()[2] !=
- AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
- op.getContext()))
- return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
-
- // Check iterator types for matrix multiplication.
- auto itTypes = op.getIteratorTypesArray();
- if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
- itTypes[1] != vector::IteratorType::parallel ||
- itTypes[2] != vector::IteratorType::reduction)
- return rewriter.notifyMatchFailure(
- op, "iterator types do not correspond to matrix multiplication");
-
- // Check the combining kind is addition.
- if (op.getKind() != vector::CombiningKind::ADD)
- return rewriter.notifyMatchFailure(op,
- "combining kind is not an addition");
-
- // Check the output is a vector of i32 elements.
- auto outTy = dyn_cast<VectorType>(op.getResultType());
- if (!outTy || outTy.getElementType() != rewriter.getI32Type())
- return rewriter.notifyMatchFailure(op,
- "output type is not a vector of i32");
-
- // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
- // before the extension. All four signed/unsigned combinations for input
- // operands are supported, but they are lowered to different operations.
- // Determine which is the appropriate operation to lower to.
- MMLA mmlaOp = MMLA::Signed;
- auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
- if (!maybeLhs) {
- mmlaOp = MMLA::Unsigned;
- maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
- }
- if (!maybeLhs)
- return rewriter.notifyMatchFailure(
- op, "LHS is not a sign- or zero- extended i8");
-
- auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
- if (maybeRhs) {
- if (mmlaOp == MMLA::Unsigned)
- mmlaOp = MMLA::Mixed;
- } else {
- if (mmlaOp == MMLA::Signed)
- mmlaOp = MMLA::MixedSwapped;
- maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
- }
- if (!maybeRhs)
- return rewriter.notifyMatchFailure(
- op, "RHS is not a sign- or zero- extended i8");
-
- // One-dimensional vector types for arm_sve.*mmla
- auto nxv16i8 = VectorType::get(/*shape=*/16, rewriter.getI8Type(),
- /*scalableDims=*/{true});
- auto nxv4i32 = VectorType::get(/*shape=*/4, rewriter.getI32Type(),
- /*scalableDims=*/{true});
-
- // Extract LHS sub-tiles with logicall shape <2x8>.
- SmallVector<Value> lhsTile;
- for (int64_t i = 0; i < M; i += 2) {
- // Extract two consecutive rows of the LHS tile.
- auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
- ArrayRef<int64_t>{i});
- auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
- ArrayRef<int64_t>{i + 1});
- // Concatenate to obtain a 16 x i8 flattened sub-tile.
- auto t = rewriter.create<vector::ShuffleOp>(
- loc, r0, r1,
- llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
- 14, 15});
- // Turn it into a scalable vector.
- auto s = rewriter.create<vector::ScalableInsertOp>(
- loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
- // Replicate the sub-tile VSCALE times to fill the entire vector.
- auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
- lhsTile.push_back(r);
- }
-
- // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
- auto rhs = rewriter.create<vector::ShapeCastOp>(
- maybeRhs->getLoc(),
- VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(),
- /*scalableDims=*/{true}),
- *maybeRhs);
-
- // Extract the RHS sub-tiles with logical shape <8x[2]>.
- SmallVector<Value> rhsTile;
- for (int64_t j = 0; j < N; j += 2)
- rhsTile.push_back(
- rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, rhs, j * 8));
-
- // Handy types for packing/unpacking of the accumulator tile.
- auto accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(),
- /*scalableDims=*/{true});
- auto accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getI32Type(),
- /*scalableDims=*/{true});
- auto accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
- /*scalableDims=*/{true});
- auto accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
- /*scalableDims=*/{true});
-
- // Extract and pack the ACC sub-tiles.
- SmallVector<Value> accTile;
- for (int64_t i = 0; i < M; i += 2) {
- // Extract two consecutive rows of the accumulator tile.
- auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
- ArrayRef<int64_t>{i});
- auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
- ArrayRef<int64_t>{i + 1});
- Value accTileVec;
- if (mmlaOp == MMLA::MixedSwapped) {
- // We need to swap the positions of the LHS and RHS (since we don't have
- // a signed * unsigned operation), but then each individual 2x2 tile of
- // the acumulator and (later) the result need to be transposed.
- accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
- } else {
- // Bitcast them to 64-bit elements, so subsequent
- // interleave/deinterleave work on pairs of 32-bit numbers.
- auto r0I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
- auto r1I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
-
- // Interleave the rows, effectively flattening each 2x2 tile into 4
- // consecutive elements.
- auto intrI64 = rewriter.create<vector::InterleaveOp>(loc, r0I64, r1I64);
-
- // Bitcast back to 32-bit elements.
- accTileVec =
- rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intrI64);
- }
- // Extract ACC sub-tiles.
- for (int64_t j = 0; j < N; j += 2)
- accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
- loc, nxv4i32, accTileVec, j * 2));
- }
-
- // Emit sub-tile matrix multiplications.
- SmallVector<Value> outTile;
- for (int64_t i = 0; i < M / 2; ++i)
- for (int64_t j = 0; j < N / 2; ++j) {
- Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
- accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]);
- outTile.push_back(mmla);
- }
-
- // Unpack the OUT sub-tiles and insert into the result.
- Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
- for (int64_t i = 0; i < M / 2; ++i) {
- // Collect a number of sub-tiles in a row.
- Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
- for (int64_t j = 0; j < N / 2; ++j)
- row = rewriter.create<vector::ScalableInsertOp>(
- loc, outTile[i * N / 2 + j], row, j * 4);
-
- // Unpack the row to obtain two rows of the output. If we have the out
- // sub-tiles transposed we obtain two consecutive output rows by
- // separating even and odd elements, i.e. a simple deinterleave.
- // Otherwise, the interleave is by pairs.
- Value out0, out1;
- if (mmlaOp == MMLA::MixedSwapped) {
- auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
- out0 = tmp.getRes1();
- out1 = tmp.getRes2();
- } else {
- // Deinterleave by pairs.
- auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
- auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
-
- // Bitcast back into 32-bit elements and insert into the result.
- out0 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
- deintr64.getRes1());
- out1 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
- deintr64.getRes2());
- }
- result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
- result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
- }
-
- rewriter.replaceOp(op, result);
- return success();
- }
-};
-
-} // namespace
-
-void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
- RewritePatternSet &patterns) {
- MLIRContext *context = patterns.getContext();
- patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
-}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-bfmmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-bfmmla.mlir
new file mode 100644
index 0000000000000..ca9d91576b512
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-bfmmla.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+#attrs = {
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>
+}
+
+// CHECK-LABEL: @test_vector_contract_to_bfmmla
+// CHECK-SAME: %[[LHS:.+]]: vector<4x4xbf16>, %[[RHS:.+]]: vector<[4]x4xbf16>, %[[ACC:.+]]: vector<4x[4]xf32>) -> vector<4x[4]xf32> {
+// CHECK-NEXT: %[[T0:.+]] = ub.poison : vector<[8]xf32>
+// CHECK-NEXT: %[[UB:.+]] = ub.poison : vector<4x[4]xf32>
+// CHECK-NEXT: %[[T2:.+]] = ub.poison : vector<[8]xbf16>
+
+// Extract rows 0 and 1 of the LHS, concatenate them, and replicate the resulting 8xbf16 vector
+// VSCALE times to obtain a [8]xbf16 vector.
+// CHECK-NEXT: %[[T3:.+]] = vector.extract %[[LHS]][0] : vector<4xbf16> from vector<4x4xbf16>
+// CHECK-NEXT: %[[T4:.+]] = vector.extract %[[LHS]][1] : vector<4xbf16> from vector<4x4xbf16>
+// CHECK-NEXT: %[[T5:.+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xbf16>, vector<4xbf16>
+// CHECK-NEXT: %[[T6:.+]] = vector.scalable.insert %[[T5]], %[[T2]][0] : vector<8xbf16> into vector<[8]xbf16>
+// CHECK-NEXT: %[[LHS_00:.+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[8]xbf16>
+
+// Same for rows 2 and 3 of the LHS.
+// CHECK-NEXT: %[[T8:.+]] = vector.extract %[[LHS]][2] : vector<4xbf16> from vector<4x4xbf16>
+// CHECK-NEXT: %[[T9:.+]] = vector.extract %[[LHS]][3] : vector<4xbf16> from vector<4x4xbf16>
+// CHECK-NEXT: %[[T10:.+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xbf16>, vector<4xbf16>
+// CHECK-NEXT: %[[T11:.+]] = vector.scalable.insert %[[T10]], %[[T2]][0] : vector<8xbf16> into vector<[8]xbf16>
+// CHECK-NEXT: %[[LHS_10:.+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[8]xbf16>
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT: %[[T13:.+]] = vector.shape_cast %[[RHS]] : vector<[4]x4xbf16> to vector<[16]xbf16>
+// CHECK-NEXT: %[[RHS_00:.+]] = vector.scalable.extract %[[T13]][0] : vector<[8]xbf16> from vector<[16]xbf16>
+// CHECK-NEXT: %[[RHS_01:.+]] = vector.scalable.extract %[[T13]][8] : vector<[8]xbf16> from vector<[16]xbf16>
+
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT: %[[T16:.+]] = vector.extract %[[ACC]][0] : vector<[4]xf32> from vector<4x[4]xf32>
+// CHECK-NEXT: %[[T17:.+]] = vector.extract %[[ACC]][1] : vector<[4]xf32> from vector<4x[4]xf32>
+// CHECK-NEXT: %[[T18:.+]] = vector.bitcast %[[T16]] : vector<[4]xf32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T19:.+]] = vector.bitcast %[[T17]] : vector<[4]xf32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T20:.+]] = vector.interleave %[[T18]], %[[T19]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T21:.+]] = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xf32>
+// CHECK-NEXT: %[[ACC_00:.+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xf32> from vector<[8]xf32>
+// CHECK-NEXT: %[[ACC_01:.+]] = vector.scalable.extract %[[T21]][4] : vector<[4]xf32> from vector<[8]xf32>
+
+// Same for accumulator rows 2 and 3
+// CHECK-NEXT: %[[T24:.+]] = vector.extract %[[ACC]][2] : vector<[4]xf32> from vector<4x[4]xf32>
+// CHECK-NEXT: %[[T25:.+]] = vector.extract %[[ACC]][3] : vector<[4]xf32> from vector<4x[4]xf32>
+// CHECK-NEXT: %[[T26:.+]] = vector.bitcast %[[T24]] : vector<[4]xf32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T27:.+]] = vector.bitcast %[[T25]] : vector<[4]xf32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T28:.+]] = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T29:.+]] = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xf32>
+// CHECK-NEXT: %[[ACC_10:.+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xf32> from vector<[8]xf32>
+// CHECK-NEXT: %[[ACC_11:.+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xf32> from vector<[8]xf32>
+
+// Do the sub-tile matrix multiplications
+// CHECK-NEXT: %[[PACK_RES_00:.+]] = arm_sve.intr.bfmmla %[[ACC_00]], %[[LHS_00]], %[[RHS_00]] : vector<[8]xbf16> to vector<[4]xf32>
+// CHECK-NEXT: %[[PACK_RES_01:.+]] = arm_sve.intr.bfmmla %[[ACC_01]], %[[LHS_00]], %[[RHS_01]] : vector<[8]xbf16> to vector<[4]xf32>
+// CHECK-NEXT: %[[PACK_RES_10:.+]] = arm_sve.intr.bfmmla %[[ACC_10]], %[[LHS_10]], %[[RHS_00]] : vector<[8]xbf16> to vector<[4]xf32>
+// CHECK-NEXT: %[[PACK_RES_11:.+]] = arm_sve.intr.bfmmla %[[ACC_11]], %[[LHS_10]], %[[RHS_01]] : vector<[8]xbf16> to vector<[4]xf32>
+
+// Unpack (from "registers") and insert in the output result rows 0 and 1
+// CHECK-NEXT: %[[T36:.+]] = vector.scalable.insert %[[PACK_RES_00]], %[[T0]][0] : vector<[4]xf32> into vector<[8]xf32>
+// CHECK-NEXT: %[[T37:.+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xf32> into vector<[8]xf32>
+// CHECK-NEXT: %[[T38:.+]] = vector.bitcast %[[T37]] : vector<[8]xf32> to vector<[4]xi64>
+// CHECK-NEXT: %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT: %[[UNPACK_RES_00:.+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xf32>
+// CHECK-NEXT: %[[UNPACK_RES_01:.+]] = vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xf32>
+// CHECK-NEXT: %[[TMP_OUT_0:.+]] = vector.insert %[[UNPACK_RES_00]], %[[UB]] [0] : vector<[4]xf32> into vector<4x[4]xf32>
+// CHECK-NEXT: %[[TMP_OUT_1:.+]] = vector.insert %[[UNPACK_RES_01]], %[[TMP_OUT_0]] [1] : vector<[4]xf32> into vector<4x[4]xf32>
+
+// Same for result rows 2 and 3
+// CHECK-NEXT: %[[T43:.+]] = vector.scalable.insert %[[PACK_RES_10]], %[[T0]][0] : vector<[4]xf32> into vector<[8]xf32>
+// CHECK-NEXT: %[[T44:.+]] = vector.scalable.insert %[[PACK_RES_11]], %[[T43]][4] : vector<[4]xf32> into vector<[8]xf32>
+// CHECK-NEXT: %[[T45:.+]] = vector.bitcast %[[T44]] : vector<[8]xf32> to vector<[4]xi64>
+// CHECK-NEXT: %res1_0, %res2_1 = vector.deinterleave %[[T45]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT: %[[UNPACK_RES_10:.+]] = vector.bitcast %res1_0 : vector<[2]xi64> to vector<[4]xf32>
+// CHECK-NEXT: %[[UNPACK_RES_11:.+]] = vector.bitcast %res2_1 : vector<[2]xi64> to vector<[4]xf32>
+// CHECK-NEXT: %[[TMP_OUT_2:.+]] = vector.insert %[[UNPACK_RES_10]], %[[TMP_OUT_1]] [2] : vector<[4]xf32> into vector<4x[4]xf32>
+// CHECK-NEXT: %[[OUT:.+]] = vector.insert %[[UNPACK_RES_11]], %[[TMP_OUT_2]] [3] : vector<[4]xf32> into vector<4x[4]xf32>
+// CHECK-NEXT: return %[[OUT]] : vector<4x[4]xf32>
+func.func @test_vector_contract_to_bfmmla(%lhs: vector<4x4xbf16>,
+ %rhs: vector<[4]x4xbf16>,
+ %acc: vector<4x[4]xf32>) -> vector<4x[4]xf32> {
+ %0 = vector.contract #attrs %lhs, %rhs, %acc
+ : vector<4x4xbf16>, vector<[4]x4xbf16> into vector<4x[4]xf32>
+
+ return %0 : vector<4x[4]xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+ transform.apply_patterns to %func {
+ transform.apply_patterns.arm_sve.vector_contract_to_bfmmla
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir
new file mode 100644
index 0000000000000..0e988d1c2f42c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir
@@ -0,0 +1,201 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-bf16' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm \
+// DEFINE: --lower-affine --convert-arith-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+bf16" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+
+//
+// Test the lowering of `vector.contract` using the `LowerContractionToSVEBFMMLAPattern`
+//
+// The operation that the `vector.contract` in this test performs is matrix
+// multiplication with accumulate
+// OUT = ACC + LHS * RHS
+// of two BFloat16 matrices LHS and RHS, and a Float32 matrix ACC into a Float32 OUT.
+//
+// Tested are the calculations, as well as that the relevant `ArmSVE` dialect
+// operation (`arm_sve.intr.bfmmla`) is emitted.
+//
+// The pattern above handles (therefore this test prepares) input/output vectors with
+// specific shapes:
+// * LHS: vector<Mx4xbf16>
+// * RHS: vector<[N]x4xbf16>
+// * ACC, OUT: vector<Mx[N]xf32>
+// Note that the RHS is transposed.
+// This data layout makes it efficient to load data into SVE
+// registers in the layout expected by the BFMMLA instruction.
+// Such a `vector.contract` is representative of the code we aim to generate
+// by scalable vectorisation of `linalg.mmt4d`.
+// See mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
+// for more information and rationale about these shapes.
+//
+// In this specific test we use M == 4 and N == 4
+//
+
+// Allocate and initialise a memref containing test data for use as the ACC
+// operand. The memref has one dynamic dimension whose extent depends on the
+// runtime value of VSCALE.
+//
+// The input parameter `%in` is a vector that is replicated VSCALE times
+// across the columns of the memref.
+func.func private @prepareAccTestData(%in: vector<4x4xf32>) -> memref<4x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+
+ %vs = vector.vscale
+ %d = arith.muli %c4, %vs : index
+ %mem = memref.alloc(%d) : memref<4x?xf32>
+
+ scf.for %j = %c0 to %d step %c4 {
+ vector.transfer_write %in, %mem[%c0, %j] {in_bounds = [true, true]} : vector<4x4xf32>, memref<4x?xf32>
+ }
+
+ return %mem : memref<4x?xf32>
+}
+
+// Allocate and initialise a memref containing test data for use as the LHS
+// operand. This function just writes the parameter `%in` into the memref.
+// The size of the LHS does not depend on VSCALE.
+func.func private @prepareLHSTestData(%in: vector<4x4xbf16>) -> memref<4x4xbf16> {
+ %c0 = arith.constant 0 : index
+
+ %mem = memref.alloc() : memref<4x4xbf16>
+ vector.transfer_write %in, %mem[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xbf16>, memref<4x4xbf16>
+
+ return %mem : memref<4x4xbf16>
+}
+
+// Allocate and initialise a memref containing test data for use as the RHS
+// operand. The memref has one dynamic dimension whose extent depends on the
+// runtime value of VSCALE.
+//
+// The input parameter `%in` is a vector that is replicated VSCALE times
+// across the rows of the memref.
+//
+// For convenience, flatten the memref, since the RHS vector is read first as a
+// single-dimensional scalable vector and then cast into [N]x4 shape.
+func.func private @prepareRHSTestData(%in: vector<4x4xbf16>) -> memref<?xbf16> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+
+ %vs = vector.vscale
+ %d = arith.muli %c4, %vs : index
+ %mem = memref.alloc(%d) : memref<?x4xbf16>
+
+ scf.for %i = %c0 to %d step %c4 {
+ vector.transfer_write %in, %mem[%i, %c0] {in_bounds = [true, true]} : vector<4x4xbf16>, memref<?x4xbf16>
+ }
+
+ %mem_out = memref.collapse_shape %mem [[0, 1]] : memref<?x4xbf16> into memref<?xbf16>
+ return %mem_out : memref<?xbf16>
+}
+
+
+// CHECK-IR-LABEL: llvm.func @test_bfmmla
+// CHECK-IR-COUNT-4: arm_sve.intr.bfmmla
+func.func @test_bfmmla() {
+
+ %c0 = arith.constant 0 : index
+ %c0_f32 = arith.constant 0.0 : f32
+ %c0_bf16 = arith.constant 0.0 : bf16
+
+ // Accumulator test data
+ %acc_cst = arith.constant dense<[[ 0.7, 1.0, -0.1, 1.8],
+ [-0.5, 0.9, 0.7, -0.7],
+ [ 0.5, -1.3, -2.2, 0.1],
+ [-0.7, 1.0, 1.7, -1.0]]> : vector<4x4xf32>
+
+ %acc_mem = func.call @prepareAccTestData(%acc_cst) : (vector<4x4xf32>) -> memref<4x?xf32>
+ %acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_f32 {in_bounds = [true, true]} : memref<4x?xf32>, vector<4x[4]xf32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[ 0.1, 0.7, -0.9, 1.3],
+ [-1.6, 0.7, -0.3, -0.3],
+ [-0.4, 0.6, 0.8, -0.5],
+ [-0.6, -1.0, -1.0, -1.0]]> : vector<4x4xbf16>
+
+ %lhs_mem = func.call @prepareLHSTestData(%lhs_cst) : (vector<4x4xbf16>) -> memref<4x4xbf16>
+ %lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[ 0.6, 1.3, 0.1, -0.9],
+ [ 0.5, 1.6, 1.8, 1.6],
+ [-0.2, 0.4, 1.0, 0.4],
+ [-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16>
+
+ %rhs_mem = func.call @prepareRHSTestData(%rhs_cst) : (vector<4x4xbf16>) -> memref<?xbf16>
+ %rhs_flat = vector.transfer_read %rhs_mem[%c0], %c0_bf16 {in_bounds = [true]} : memref<?xbf16>, vector<[16]xbf16>
+ %rhs = vector.shape_cast %rhs_flat : vector<[16]xbf16> to vector<[4]x4xbf16>
+
+ // Matrix multiplication and accumulate with transposed RHS.
+ %0 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<4x4xbf16>, vector<[4]x4xbf16> into vector<4x[4]xf32>
+
+ // Display the result of the multiplication
+ vector.print str "Result(BFMMLA):\n"
+ %u0 = vector.extract %0[0] : vector<[4]xf32> from vector<4x[4]xf32>
+ %u1 = vector.extract %0[1] : vector<[4]xf32> from vector<4x[4]xf32>
+ %u2 = vector.extract %0[2] : vector<[4]xf32> from vector<4x[4]xf32>
+ %u3 = vector.extract %0[3] : vector<[4]xf32> from vector<4x[4]xf32>
+ vector.print %u0 : vector<[4]xf32>
+ vector.print %u1 : vector<[4]xf32>
+ vector.print %u2 : vector<[4]xf32>
+ vector.print %u3 : vector<[4]xf32>
+
+ // Deallocate the buffers.
+ memref.dealloc %acc_mem : memref<4x?xf32>
+ memref.dealloc %lhs_mem : memref<4x4xbf16>
+ memref.dealloc %rhs_mem : memref<?xbf16>
+
+ return
+}
+
+// Perform each test with SVE vector lengths 128 bits and 256 bits (i.e. VSCALEs
+// 1 and 2, respectively). The vector length is set via the `setArmVLBits`
+// function. The effect of setting a different vector length is that the tests
+// allocate and operate on different sized buffers (see `prepare<X>TestData`
+// functions).
+
+func.func @main() {
+ %c128 = arith.constant 128 : i32
+ %c256 = arith.constant 256 : i32
+
+// CHECK-LABEL: Result(BFMMLA):
+// CHECK: ( 0.411922, 2.63254, -0.219259, 3.89965 )
+// CHECK: ( -0.316515, 0.196875, 0.879375, 1.80924 )
+// CHECK: ( 1.56867, 0.101367, -1.2784, -1.41579 )
+// CHECK: ( -1.56041, -4.30078, 0.0196488, 1.88269 )
+ func.call @setArmVLBits(%c128) : (i32) -> ()
+ func.call @test_bfmmla() : () -> ()
+
+// CHECK: Result(BFMMLA):
+// CHECK: ( 0.411922, 2.63254, -0.219259, 3.89965, 0.411922, 2.63254, -0.219259, 3.89965 )
+// CHECK: ( -0.316515, 0.196875, 0.879375, 1.80924, -0.316515, 0.196875, 0.879375, 1.80924 )
+// CHECK: ( 1.56867, 0.101367, -1.2784, -1.41579, 1.56867, 0.101367, -1.2784, -1.41579 )
+// CHECK: ( -1.56041, -4.30078, 0.0196488, 1.88269, -1.56041, -4.30078, 0.0196488, 1.88269 )
+ func.call @setArmVLBits(%c256) : (i32) -> ()
+ func.call @test_bfmmla() : () -> ()
+
+ return
+}
+
+func.func private @setArmVLBits(%bits : i32)
+func.func private @printMemrefF32(%ptr : memref<*xf32>)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
index 5f6e8e4c30892..8504d664fa0c6 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
@@ -20,7 +20,7 @@
]
//
-// Test the lowering of `vector.contract` using the `LowerContractionToSVEI8MMPattern`
+// Test the lowering of `vector.contract` using the `LowerContractionToSVEI8MMPattern`
//
// The operation that the `vector.contract` in this test performs is matrix
// multiplication with accumulate
@@ -42,7 +42,7 @@
// registers in the layout expected by FEAT_I8MM instructions.
// Such a `vector.contract` is representative of the code we aim to generate
// by scalable vectorisation of `linalg.mmt4d`.
-// See mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+// See mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
// for more information and rationale about these shapes.
//
// In this specific test we use M == 4 and N == 4
@@ -316,7 +316,7 @@ func.func @test_usmmla() {
// Test the operation where LHS is interpreted as signed and RHS is interpreted
// as unsigned. In this test we ultimately emit and execute the `usmmla`
-// instruction with reversed operands, see `LowerContractionToSVEI8MMPattern.cpp`
+// instruction with reversed operands, see `LowerContractToSVEPatterns.cpp`
// for more details.
// CHECK-IR-LABEL: llvm.func @test_summla