[Mlir-commits] [mlir] [MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16 operations (PR #148198)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Jul 15 11:52:58 PDT 2025
================
@@ -0,0 +1,499 @@
+//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the Neon FEAT_I8MM and FEAT_BF16 extensions.
+//
+// TODO: There may be opportunities to unify this with a similar pattern
+// for SVE. See:
+// https://github.com/llvm/llvm-project/issues/145559
+// LowerContractionToSVEI8MMPattern.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-neon"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+/// Get the operand of a `vector.contract`. This function is intended to
+/// abstract away from the particular way a value is extended before feeding it
+/// into the `vector.contract` - via zero-extend or an explicit or implicit
+/// sign-extend (for implicit sign-extension see `vector.contract`
+/// documentation).
+///
+/// The template parameter `Op` indicates the extension operation (explicit or
+/// implicit) for which we are checking.
+///
+/// Returns the pre-extension value when `v` is defined by an `Op` that extends
+/// a vector of signless `iN` (N <= 8) to a vector of `i32`. When `v` is not
+/// defined by an `Op` and `Op` is `arith::ExtSIOp`, returns `v` itself if its
+/// elements are signless integers of at most 8 bits (implicit sign-extension).
+/// Returns an empty optional in all other cases.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+ static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+ "Must be instantiated with either sign- or zero- extension op");
+
+ // If the operand is not defined by an explicit extend operation of the
+ // accepted operation type allow for an implicit sign-extension.
+ auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+ if (!extOp) {
+ if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+ // Note: `cast` asserts that `v` has a vector type.
+ auto eltTy = cast<VectorType>(v.getType()).getElementType();
+ if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
+ return {};
+ return v;
+ }
+ return {};
+ }
+
+ // If the operand is defined by an explicit extend operation of the accepted
+ // operation type, check it's extended from `iN` (N <= 8) to `i32`.
+ auto inOp = extOp.getIn();
+ auto inTy = dyn_cast<VectorType>(inOp.getType());
+ if (!inTy)
+ return {};
+ auto inEltTy = inTy.getElementType();
+ if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
+ return {};
+
+ auto outTy = dyn_cast<VectorType>(extOp.getType());
+ if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
+ return {};
+
+ return inOp;
+}
+
+/// Helper function to extend a vector with elements iN, N < 8 to
+/// a vector of i8. Do sign extension if the parameter `signExt` is true,
+/// zero extension otherwise.
+///
+/// The result type is `srcTy` with its element type replaced by `i8`.
+/// `srcTy` is presumably the type of `val`; this is not checked here. The
+/// extension is built with `createOrFold` at location `loc`.
+Value extendSmallIntVector(Location loc, VectorType srcTy, Value val,
+ bool signExt, PatternRewriter &rewriter) {
+ Type targetTy = srcTy.clone(rewriter.getI8Type());
+ return signExt ? rewriter.createOrFold<arith::ExtSIOp>(loc, targetTy, val)
+ : rewriter.createOrFold<arith::ExtUIOp>(loc, targetTy, val);
+}
+
+class VectorContractRewriter {
+protected:
+ // Designate the operation (resp. instruction) used to do sub-tile matrix
+ // multiplications.
+ enum class MMLA {
+ Nop,
+ SignedInt, // smmla
+ UnsignedInt, // ummla
+ MixedInt, // usmmla
+ Bfloat // bfmmla
+ };
+
+ // Lower-level operation to be emitted.
+ MMLA mmlaOp = MMLA::Nop;
+
+ // Indicate if the operands for the ArmNeon dialect operation need to be
+ // swapped. Currently this is needed in order to emulate an "summla"
+ // operation.
+ bool swapOperands = false;
+
+ // The operand tiles. These are not necessarily the operands of
+ // `vector.contract`, for example they could be operands to `arith.extsi`
+ // that is in turn fed into `vector.contract`.
+ Value lhs;
+ Value rhs;
+ Value acc;
+
+ // The dimensions logically corresponding to matrix multiplication of
+ // MxK * KxN -> MxN. The operands and the result do not necessarily have these
+ // shapes, for example RHS could be NxK with a transposing indexing map.
+ int64_t dimM = 0;
+ int64_t dimN = 0;
+ int64_t dimK = 0;
+
+ // Unroll iteration bounds. See documentation for `StaticTileOffsetRange`.
+ SmallVector<int64_t> iterationBounds;
+
+ // Sub-tile shape. The algorithm handles operand shapes, which are multiples
+ // of this shape.
+ SmallVector<int64_t> subTileShape;
+
+ // Create the matrix multiply and accumulate operation according to `mmlaOp`.
+ // `lhs` and `rhs` are exchanged first when `swapOperands` is set. The result
+ // type of the created operation is the type of `acc`. `mmlaOp` must have
+ // been set to something other than `MMLA::Nop` before this is called.
+ Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+ Value lhs, Value rhs) {
+
+ if (swapOperands)
+ std::swap(lhs, rhs);
+ switch (mmlaOp) {
+ case MMLA::SignedInt:
+ return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::UnsignedInt:
+ return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::MixedInt:
+ return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::Bfloat:
+ return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs,
+ rhs);
+ case MMLA::Nop:
+ llvm_unreachable("Uninitialized operation type");
+ }
+ // The switch above is fully covered, but some compilers (e.g. GCC) assume an
+ // enum may hold a value outside its enumerators and would otherwise warn
+ // that control reaches the end of a non-void function.
+ llvm_unreachable("Unknown MMLA operation type");
+ }
+
+ // Check common preconditions for applying the patterns and initialize
+ // logical dimensions.
+ //
+ // Succeeds only if:
+ // * the iterator types are [parallel, parallel, reduction] or
+ // [parallel, reduction],
+ // * the LHS has rank at most 2 and the RHS has rank exactly 2,
+ // * neither the LHS nor the RHS type is scalable,
+ // * the last LHS dimension equals the second RHS dimension.
+ //
+ // On success `dimM`, `dimN` and `dimK` are set from the operand shapes.
+ // Otherwise a match failure is reported through `rewriter` and returned; the
+ // dimension members may already have been overwritten in that case.
+ // The indexing maps of `op` are not inspected by this function.
+ LogicalResult matchAndInit(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ // Check iterator types for matrix multiplication.
+ SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
+ if (!((itTypes.size() == 3 &&
+ (itTypes[0] == vector::IteratorType::parallel &&
+ itTypes[1] == vector::IteratorType::parallel &&
+ itTypes[2] == vector::IteratorType::reduction)) ||
+ (itTypes.size() == 2 &&
+ (itTypes[0] == vector::IteratorType::parallel &&
+ itTypes[1] == vector::IteratorType::reduction))))
+ return rewriter.notifyMatchFailure(
+ op, "iterator types do not correspond to matrix multiplication");
+
+ // Avoid 0-D vectors and 1-D rhs:
+ VectorType lhsType = op.getLhsType();
+ VectorType rhsType = op.getRhsType();
+ if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
+ rhsType.getRank() != 2)
+ return rewriter.notifyMatchFailure(op, "Invalid operand rank");
+
+ // This codegen does not work for scalable vectors. Return failure so this
+ // pattern is not accidentally chosen over patterns that lower to ArmSVE.
+ if (lhsType.isScalable() || rhsType.isScalable())
+ return rewriter.notifyMatchFailure(op,
+ "Not applicable to scalable vectors");
+
+ // Initialize dimensions and check for a matching K dimension.
+ // The RHS shape is read as NxK: dimension 0 gives N, dimension 1 gives K.
+ dimM = lhsType.getDimSize(0);
+ dimN = rhsType.getDimSize(0);
+ dimK = rhsType.getDimSize(1);
+
+ // A rank-1 LHS is treated as a single row: M is 1 and its only dimension
+ // is K. For a rank-2 LHS, K is the second dimension.
+ int64_t lhsDimK;
+ if (lhsType.getRank() == 1) {
+ dimM = 1;
+ lhsDimK = lhsType.getDimSize(0);
+ } else {
+ lhsDimK = lhsType.getDimSize(1);
+ }
+
+ if (lhsDimK != dimK)
+ return rewriter.notifyMatchFailure(op, "Dimensions mismatch");
+
+ return success();
+ }
+
+public:
+ void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
----------------
banach-space wrote:
[nit] I was going through https://mlir.llvm.org/deprecation/ and realised that `match` and `rewrite` are deprecated. While you don't use `rewrite` here (this method is something different), I am thinking that it would be good to use some other, more descriptive name. How about `lowerToNEON`? Or just `lower`?
https://github.com/llvm/llvm-project/pull/148198
More information about the Mlir-commits
mailing list