[Mlir-commits] [mlir] [MLIR][AArch64] Lower `vector.contract` to SVE FEAT_BF16 operations (PR #147052)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Jul 7 01:43:05 PDT 2025
================
@@ -0,0 +1,607 @@
+//===- LowerContractToSVEPatterns.cpp - Contract to I8MM/BF16 ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the SVE FEAT_I8MM and FEAT_BF16 extensions.
+//
+// TODO: There may be opportunities to unify this with a similar pattern
+// for Neon. See:
+// https://github.com/llvm/llvm-project/issues/145559
+// LowerContractionToNeonI8MMPattern.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <numeric>
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+
+namespace {
+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `i8` to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+ static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+ "Must be instantiated with either sign- or zero- extension op");
+
+ // If the operand is not defined by an explicit extend operation of the
+ // accepted operation type, allow for an implicit sign-extension.
+ auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+ if (!extOp) {
+ if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+ auto vTy = cast<VectorType>(v.getType());
+ if (!vTy.getElementType().isSignlessInteger(8))
+ return {};
+ return v;
+ }
+ return {};
+ }
+
+ // If the operand is defined by an explicit extend operation of the accepted
+ // operation type, check it's extended from `i8` to `i32`.
+ auto inOp = extOp.getIn();
+ auto inTy = dyn_cast<VectorType>(inOp.getType());
+ if (!inTy || !inTy.getElementType().isSignlessInteger(8))
+ return {};
+
+ auto outTy = dyn_cast<VectorType>(extOp.getType());
+ if (!outTy || !outTy.getElementType().isSignlessInteger(32))
+ return {};
+
+ return inOp;
+}
+
+/// This class encapsulates the algorithm and parametrisation (in terms of types
+/// and dimensions) of lowering a `vector.contract` to "primitive" matrix
+/// multiplication operations of the SVE dialect (here "primitive" would mean
+/// corresponding to a single target instruction).
+///
+/// Supported are lowerings to FEAT_I8MM `smmla`, `ummla`, and `usmmla`, and to
+/// FEAT_BF16 `bfmmla`. All the transformations are very similar to each other;
+/// for concreteness, the description below is given for `smmla`.
+///
+/// The lowering triggers for a contraction operation that performs a matrix
+/// multiply of two 8-bit integer matrix tiles with logical dimensions
+/// <Mx8> and <8x[N]> for the left-hand side (LHS) and the right-hand side
+/// (RHS), respectively, added to a 32-bit integer accumulator operand (ACC)
+/// with dimensions <Mx[N]>, yielding a <Mx[N]> 32-bit integer result (OUT).
+///
+/// The operands' shapes are such that the operands can be evenly split into
+/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM
+/// instructions. The intent is that M and N are chosen (by higher level
+/// transforms) in such a way as to maximise register usage. The main use case
+/// we envision as of now is MMT4D, so the RHS operand is expected to be
+/// pre-transposed.
+///
+/// The matrix multiplication is performed by unrolling the usual tiled matrix
+/// multiplication algorithm using sub-tiles with dimensions <2x8> for the
+/// LHS, <8x[2]> for the RHS, and <2x[2]> for the result and the input
+/// accumulator.
+///
+/// One way to illustrate the operation is as follows:
+///
+/// RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]>
+/// +-----------------------------
+/// LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+/// ... | ... ... ... ...
+/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///
+/// The RHS operand is unpacked into N/2 values, each representing a sequence
+/// of VSCALE number of sub-tiles with dimensions <8x2>.
+/// The LHS operand is initially unpacked into M/2 values, each representing a
+/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
+/// VSCALE times. Multiplying each such replicated LHS sub-tile by the corresponding
+/// RHS sub-tile correctly computes an entire result sub-tile.
+/// The 2x2 sub-tiles of the ACC and OUT have rows that are not adjacent
+/// (in memory or when imposing a row-major layout on the 2D vector value).
+/// Reading the ACC is implemented as reading two consecutive rows and
+/// interleaving them by pairs to obtain a vector having length twice the length
+/// of an ACC row. This vector now is a sequence of one-dimensional tiles with
+/// the exact layout needed by the `smmla`/`bfmmla`/etc instructions; these
+/// tiles are extracted one by one. For illustration, if we have a 2x4 ACC tile
+/// a0 a1 b0 b1
+/// a2 a3 b2 b3
+/// we read the two rows as separate values and then interleave by pairs
+/// to obtain
+/// a0 a1 a2 a3 b0 b1 b2 b3
+/// from which we extract `a0 a1 a2 a3` and `b0 b1 b2 b3`.
+///
+/// Writing the OUT tile is done by reversing the above procedure:
+/// concatenate two "flattened" sub-tiles into
+/// c0 c1 c2 c3 d0 d1 d2 d3
+/// deinterleave by pairs to obtain as separate values
+/// c0 c1 d0 d1
+/// c2 c3 d2 d3
+/// which are then inserted into the final result.
+///
+/// Multiplication of a signed LHS by an unsigned RHS is performed by
+/// swapping the order of the operands and emitting a `usmmla` (since there
+/// isn't an `summla` instruction). Therefore each ACC sub-tile needs
+/// to be transposed before the addition and the sum, an OUT sub-tile,
+/// needs to be transposed before insertion into the final result.
+/// This is done very elegantly by a modification of the above to
+/// interleave/deinterleave not by pairs, but by individual elements, e.g.
+/// after ordinary interleave we obtain
+/// a0 a2 a1 a3 b0 b2 b1 b3
+/// which is exactly the desired layout of having each individual 2x2 tile
+/// transposed.
+///
+/// All of the above readily applies to FEAT_BF16 `bfmmla` with the
+/// difference that the shapes of the LHS and RHS are <Mx4> and <4x[N]>,
+/// respectively, that is, the "K" dimension is fixed to 4 instead of 8 (as
+/// in the integer case).
+class VectorContractRewriter {
+protected:
+ // Designate the operation (resp. instruction) used to do sub-tile matrix
+ // multiplications.
+ enum class MMLA {
+ Nop,
+ SignedInt, // smmla
+ UnsignedInt, // ummla
+ MixedInt, // usmmla
+ Bfloat // bfmmla
+ };
+
+ // Lower-level operation to be emitted.
+ MMLA mmlaOp = MMLA::Nop;
+
+ // The operand tiles. These are not necessarily the operands of
+ // `vector.contract`; for example, they could be operands to an `arith.extsi`
+ // that is in turn fed into `vector.contract`.
+ Value lhs;
+ Value rhs;
+ Value acc;
+
+ // Conventional names for matrix dimensions.
+ int64_t M = 0;
+ int64_t N = 0;
+ int64_t K = 0;
+
+ // Single-dimensional vector types for the operands of the ArmSVE dialect
+ // op.
+ VectorType flatLhsType;
+ VectorType flatRhsType;
+ VectorType flatAccType;
+
+ // Single-dimension vector type for the entire RHS tile.
+ VectorType flatRhsTileType;
+
+ // Vector type having the same number of elements as a row in the
+ // accumulator/output tile and the same element type.
+ VectorType accRowTy;
+
+ // Vector type having twice the number of elements as a row in the
+ // accumulator/output tile and the same element type.
+ VectorType accRowX2Ty;
+
+ // Vector type having half the number of elements as a row in the
+ // accumulator/output tile and an integer element type with twice the bit
+ // width.
+ VectorType accRow64Ty;
+ VectorType accRowX264Ty;
+
+ // Indicate if the operands for the ArmSVE dialect operation need to be
+ // swapped. Currently this is needed in order to emulate an "summla"
+ // operation.
+ bool swapOperands = false;
+
+ // Create the matrix multiply and accumulate operation according to
+ // `mmlaOp`.
+ Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+ Value lhs, Value rhs);
+
+ // Check general preconditions for applying the transformation, common to the
+ // integer and the bfloat16 case.
+ LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter);
+
+public:
+ VectorContractRewriter() = default;
+
+ // Do the actual rewrite. This member function is shared by both integer and
+ // bfloat16 rewrites.
+ Value rewrite(vector::ContractionOp op, PatternRewriter &rewriter);
+};
+
+Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter,
+ Location loc, Value acc, Value lhs,
+ Value rhs) {
+ if (swapOperands)
+ std::swap(lhs, rhs);
+
+ switch (mmlaOp) {
+ case MMLA::SignedInt:
+ return rewriter.create<arm_sve::SmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ case MMLA::UnsignedInt:
+ return rewriter.create<arm_sve::UmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ case MMLA::MixedInt:
+ return rewriter.create<arm_sve::UsmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ case MMLA::Bfloat:
+ return rewriter.create<arm_sve::BfmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ default:
+ llvm_unreachable("Uninitialized operation kind");
+ }
+}
+
+LogicalResult VectorContractRewriter::match(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ // Check iterator types for matrix multiplication.
+ auto itTypes = op.getIteratorTypesArray();
+ if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::parallel ||
+ itTypes[2] != vector::IteratorType::reduction)
+ return rewriter.notifyMatchFailure(
+ op, "iterator types do not correspond to matrix multiplication");
+
+ // Check permutation maps. For now only accept
+ // lhs: (d0, d1, d2) -> (d0, d2)
+ // rhs: (d0, d1, d2) -> (d1, d2)
+ // acc: (d0, d1, d2) -> (d0, d1)
+ // This corresponds to matrix multiplication with transposed RHS.
+ if (op.getIndexingMapsArray()[0] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[1] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[2] != AffineMap::getMultiDimMapWithTargets(
+ 3, ArrayRef{0u, 1u}, op.getContext()))
+ return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
+
+ // Check the combining kind is addition.
+ if (op.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(op, "combining kind is not an addition");
+
+ return success();
+}
+
+Value VectorContractRewriter::rewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ Location loc = op.getLoc();
+
+ // Extract LHS sub-tiles with logical shape <2xK>.
+ SmallVector<Value> lhsTile;
+ for (int64_t i = 0; i < M; i += 2) {
+ // Extract two consecutive rows of the LHS tile.
+ auto r0 =
+ rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i});
+ auto r1 =
+ rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i + 1});
+ // Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
+ SmallVector<int64_t> shuffleIdx(2 * K);
+ std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
+ auto t = rewriter.create<vector::ShuffleOp>(loc, r0, r1, shuffleIdx);
+ // Turn it into a scalable vector.
+ auto s = rewriter.create<vector::ScalableInsertOp>(
+ loc, t, rewriter.create<ub::PoisonOp>(loc, flatLhsType), 0);
+ // Replicate the sub-tile VSCALE times to fill the entire vector.
+ auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+ lhsTile.push_back(r);
+ }
+
+ // "Flatten" the RHS tile from <[N]xK> to <[N*K]>.
+ auto rhs = rewriter.create<vector::ShapeCastOp>(this->rhs.getLoc(),
+ flatRhsTileType, this->rhs);
+
+ // Extract the RHS sub-tiles with logical shape <Kx[2]>.
+ SmallVector<Value> rhsTile;
+ for (int64_t j = 0; j < N; j += 2)
+ rhsTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+ loc, flatRhsType, rhs, j * K));
+
+ // Extract and pack the ACC sub-tiles.
+ SmallVector<Value> accTile;
+ for (int64_t i = 0; i < M; i += 2) {
+ // Extract two consecutive rows of the accumulator tile.
+ auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+ ArrayRef<int64_t>{i});
+ auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+ ArrayRef<int64_t>{i + 1});
+ Value accTileVec;
+ if (swapOperands) {
+ // Since we are performing the operation with swapped LHS and RHS, we
+ // need to transpose each individual 2x2 tile of the accumulator and
+ // (later) the final result.
+ accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
+ } else {
+ // Bitcast accumulator rows to double-width integer elements, so
+ // subsequent interleave/deinterleave work on pairs of elements.
+ auto r0I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
+ auto r1I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
+
+ // Interleave the rows, effectively flattening each 2x2 tile into 4
+ // consecutive elements.
+ auto intrI64 = rewriter.create<vector::InterleaveOp>(loc, r0I64, r1I64);
+
+ // Bitcast back to original element type.
+ accTileVec = rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intrI64);
+ }
+ // Extract ACC sub-tiles.
+ for (int64_t j = 0; j < N; j += 2)
+ accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+ loc, flatAccType, accTileVec, j * 2));
+ }
+
+ // Emit sub-tile matrix multiplications.
+ SmallVector<Value> outTile;
+ for (int64_t i = 0; i < M / 2; ++i)
+ for (int64_t j = 0; j < N / 2; ++j) {
+ Value mmla = createMMLA(rewriter, loc, accTile[i * N / 2 + j], lhsTile[i],
+ rhsTile[j]);
+ outTile.push_back(mmla);
+ }
+
+ // Unpack the OUT sub-tiles and insert into the result.
+ Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
+ for (int64_t i = 0; i < M / 2; ++i) {
+ // Collect a number of sub-tiles in a row.
+ Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
+ for (int64_t j = 0; j < N / 2; ++j)
+ row = rewriter.create<vector::ScalableInsertOp>(
+ loc, outTile[i * N / 2 + j], row, j * 4);
+
+ // Unpack the row to obtain two rows of the output. If we have the out
+ // sub-tiles transposed we obtain two consecutive output rows by
+ // separating even and odd elements, i.e. a simple deinterleave.
+ // Otherwise, the deinterleave is by pairs.
+ Value out0, out1;
+ if (swapOperands) {
+ auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
+ out0 = tmp.getRes1();
+ out1 = tmp.getRes2();
+ } else {
+ // Deinterleave by pairs.
+ auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
+ auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
+
+ // Bitcast back into original element type and insert into the result.
+ out0 =
+ rewriter.create<vector::BitCastOp>(loc, accRowTy, deintr64.getRes1());
+ out1 =
+ rewriter.create<vector::BitCastOp>(loc, accRowTy, deintr64.getRes2());
+ }
+ result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
+ result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
+ }
+
+ return result;
+}
+
+class VectorContractRewriterI8MM : public VectorContractRewriter {
+public:
+ // Check the specific preconditions for the integer case. Initialise
+ // parametrisation types and dimensions.
+ LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter) {
+
+ if (failed(VectorContractRewriter::match(op, rewriter)))
+ return failure();
+
+ VectorType lhsType = op.getLhsType();
+ VectorType rhsType = op.getRhsType();
+
+ M = lhsType.getDimSize(0);
+ N = rhsType.getDimSize(0);
+ K = rhsType.getDimSize(1);
+
+ // Check the operands have the expected shape:
+ // * for LHS: fixed vector MxK
+ // * for RHS: scalable vector [N]xK
+ // * K == 8
+ // * M and N even and at least 2
+ if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
+ rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 ||
+ M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0)
+ return rewriter.notifyMatchFailure(op, "non-matching operand shape");
+
+ // Check the output is a vector of i32 elements.
+ auto outTy = dyn_cast<VectorType>(op.getResultType());
+ if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+ return rewriter.notifyMatchFailure(op,
+ "output type is not a vector of i32");
+
+ // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
+ // before the extension. All four signed/unsigned combinations for input
+ // operands are supported, but they are lowered to different operations.
+ // Determine which is the appropriate operation to lower to.
+ mmlaOp = MMLA::SignedInt;
+ swapOperands = false;
+ auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+ if (!maybeLhs) {
+ mmlaOp = MMLA::UnsignedInt;
+ maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+ }
+ if (!maybeLhs)
+ return rewriter.notifyMatchFailure(
+ op, "LHS is not a sign- or zero- extended i8");
+
+ auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
+ if (maybeRhs) {
+ if (mmlaOp == MMLA::UnsignedInt)
+ mmlaOp = MMLA::MixedInt;
+ } else {
+ if (mmlaOp == MMLA::SignedInt) {
+ mmlaOp = MMLA::MixedInt;
+ swapOperands = true;
+ }
+ maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
+ }
+ if (!maybeRhs)
+ return rewriter.notifyMatchFailure(
+ op, "RHS is not a sign- or zero- extended i8");
+
+ // Initialise algorithm parameters.
+ lhs = *maybeLhs;
+ rhs = *maybeRhs;
+ acc = op.getAcc();
+
+ flatLhsType = VectorType::get(/*shape=*/16, rewriter.getI8Type(),
+ /*scalableDims=*/{true});
+ flatRhsType = VectorType::get(/*shape=*/16, rewriter.getI8Type(),
+ /*scalableDims=*/{true});
+
+ flatAccType = VectorType::get(/*shape=*/4, rewriter.getI32Type(),
+ /*scalableDims=*/{true});
+
+ flatRhsTileType = VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(),
+ /*scalableDims=*/{true});
+
+ accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(),
+ /*scalableDims=*/{true});
+ accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getI32Type(),
+ /*scalableDims=*/{true});
+ accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
+ /*scalableDims=*/{true});
+ accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
+ /*scalableDims=*/{true});
+
+ return success();
+ }
+};
+
+class VectorContractRewriterBfloat : public VectorContractRewriter {
+public:
+ // Check the specific preconditions for the bfloat16 case. Initialise
+ // parametrisation types and dimensions.
+ LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter) {
----------------
banach-space wrote:
This does a bit more than just matching :) I would split this into two methods, but I am also fine with just renaming.
```suggestion
// Check the specific preconditions for the bfloat16 case. Initialise
// parametrisation types and dimensions.
LogicalResult matchAndInitialize(vector::ContractionOp op, PatternRewriter &rewriter) {
```
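A scalar model can make the ACC/OUT packing described in the doc comment easier to check by hand. The following is a minimal sketch, assuming vscale = 1 and using string labels in place of `i32` lanes; the helper names (`interleaveByPairs`, `interleaveByElements` and their inverses) are hypothetical and do not appear in the patch. Interleaving by pairs stands in for the bitcast-to-`i64`, `vector.interleave`, bitcast-back sequence, while interleaving by individual elements corresponds to the swapped-operand (`usmmla`) path.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Labels such as "a0" stand in for i32 accumulator lanes (vscale = 1 assumed).
using Row = std::vector<std::string>;

// Pair-wise interleave: r0[0] r0[1] r1[0] r1[1] r0[2] r0[3] r1[2] r1[3] ...
// Models bitcasting each row to i64 (one i64 = two i32 lanes), interleaving,
// and bitcasting back. Each group of 4 outputs is one row-major 2x2 tile.
Row interleaveByPairs(const Row &r0, const Row &r1) {
  assert(r0.size() == r1.size() && r0.size() % 2 == 0);
  Row out;
  for (std::size_t p = 0; p < r0.size(); p += 2)
    out.insert(out.end(), {r0[p], r0[p + 1], r1[p], r1[p + 1]});
  return out;
}

// Element-wise interleave: r0[0] r1[0] r0[1] r1[1] ...
// Each group of 4 outputs is a 2x2 tile in transposed order.
Row interleaveByElements(const Row &r0, const Row &r1) {
  assert(r0.size() == r1.size());
  Row out;
  for (std::size_t e = 0; e < r0.size(); ++e)
    out.insert(out.end(), {r0[e], r1[e]});
  return out;
}

// Inverse of interleaveByPairs: split flattened tiles back into two rows.
std::pair<Row, Row> deinterleaveByPairs(const Row &v) {
  assert(v.size() % 4 == 0);
  Row r0, r1;
  for (std::size_t p = 0; p < v.size(); p += 4) {
    r0.insert(r0.end(), {v[p], v[p + 1]});
    r1.insert(r1.end(), {v[p + 2], v[p + 3]});
  }
  return {r0, r1};
}

// Inverse of interleaveByElements: even positions -> row 0, odd -> row 1.
std::pair<Row, Row> deinterleaveByElements(const Row &v) {
  assert(v.size() % 2 == 0);
  Row r0, r1;
  for (std::size_t e = 0; e < v.size(); e += 2) {
    r0.push_back(v[e]);
    r1.push_back(v[e + 1]);
  }
  return {r0, r1};
}
```

With the 2x4 tile from the doc comment (rows `a0 a1 b0 b1` and `a2 a3 b2 b3`), `interleaveByPairs` yields `a0 a1 a2 a3 b0 b1 b2 b3` and `interleaveByElements` yields `a0 a2 a1 a3 b0 b2 b1 b3`; the matching deinterleave recovers the original rows in each case.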
https://github.com/llvm/llvm-project/pull/147052
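A scalar reference model can also make the operand-swap argument concrete. This is a minimal sketch under stated assumptions rather than the patch's code: `mmlaRef`, `transpose2x2` and `chooseMmla` are hypothetical names, tile values are assumed to be already extended to 32 bits, only the K = 8 integer case is modelled, and the RHS sub-tile is stored pre-transposed as the doc comment expects. Swapping LHS and RHS while feeding a transposed accumulator produces the transposed result, which is why the swapped path interleaves and deinterleaves by individual elements.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Values are assumed to be already sign- or zero-extended to 32 bits.
using Acc2x2 = std::array<int32_t, 4>;   // row-major 2x2 accumulator tile
using Tile2x8 = std::array<int32_t, 16>; // row-major 2x8 operand tile

// Reference semantics of one sub-tile MMLA: out = acc + lhs * rhsT^T, where
// rhsT is the <8x2> RHS sub-tile stored pre-transposed as <2x8>.
Acc2x2 mmlaRef(Acc2x2 acc, const Tile2x8 &lhs, const Tile2x8 &rhsT) {
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 8; ++k)
        acc[i * 2 + j] += lhs[i * 8 + k] * rhsT[j * 8 + k];
  return acc;
}

// Transpose a flattened 2x2 tile: a0 a1 a2 a3 -> a0 a2 a1 a3.
Acc2x2 transpose2x2(const Acc2x2 &t) { return {t[0], t[2], t[1], t[3]}; }

enum class MmlaKind { Smmla, Ummla, Usmmla };
struct MmlaChoice {
  MmlaKind kind;
  bool swapOperands;
};

// Mirrors the selection logic in VectorContractRewriterI8MM::match: usmmla
// takes an unsigned first and a signed second operand, so a signed LHS times
// an unsigned RHS is emitted with the operands swapped.
MmlaChoice chooseMmla(bool lhsSigned, bool rhsSigned) {
  if (lhsSigned && rhsSigned)
    return {MmlaKind::Smmla, false};
  if (!lhsSigned && !rhsSigned)
    return {MmlaKind::Ummla, false};
  if (!lhsSigned && rhsSigned)
    return {MmlaKind::Usmmla, false};
  return {MmlaKind::Usmmla, true};
}
```

For example, with an all-ones LHS, an all-twos RHS and a zero accumulator, each of the four outputs is 8 * 1 * 2 = 16.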