[Mlir-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
Momchil Velikov
llvmlistbot at llvm.org
Tue May 27 07:42:15 PDT 2025
================
@@ -0,0 +1,358 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the SVE FEAT_I8MM extension.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Recover the 8-bit value feeding one side (LHS or RHS) of a vector contract.
+//
+// `T` selects the extension kind being looked for and must be either
+// `arith::ExtSIOp` or `arith::ExtUIOp`. Two situations are recognised:
+//   * `v` is produced by an extension op of kind `T` whose input is a vector
+//     of `i8Ty` and whose result is a vector of `i32Ty`: the input of the
+//     extension is returned;
+//   * `v` is not produced by an op of kind `T`, `T` is the sign-extension op,
+//     and `v` is itself a vector of `i8Ty`: `v` is returned unchanged, i.e.
+//     it is treated as implicitly sign-extended.
+//
+// In every other case an empty optional is returned.
+template <typename T>
+std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
+
+  static_assert(llvm::is_one_of<T, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero- extension op");
+
+  // Explicit extension: validate the element types on both sides of the op.
+  if (auto ext = v.getDefiningOp<T>()) {
+    Value narrow = ext.getIn();
+
+    auto narrowTy = dyn_cast<VectorType>(narrow.getType());
+    const bool fromI8 = narrowTy && narrowTy.getElementType() == i8Ty;
+    if (!fromI8)
+      return std::nullopt;
+
+    auto wideTy = dyn_cast<VectorType>(ext.getType());
+    const bool toI32 = wideTy && wideTy.getElementType() == i32Ty;
+    if (!toI32)
+      return std::nullopt;
+
+    return narrow;
+  }
+
+  // No explicit extension of kind `T`. Only the signed flavour accepts a bare
+  // i8 vector as an implicitly sign-extended operand.
+  if constexpr (std::is_same<T, arith::ExtSIOp>::value) {
+    if (cast<VectorType>(v.getType()).getElementType() == i8Ty)
+      return v;
+  }
+
+  return std::nullopt;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications. Each enumerator selects which ArmSVE operation
+// `createMMLA` builds and in which order the LHS/RHS operands are passed.
+enum class MMLA {
+ Signed, // smmla: built as arm_sve::SmmlaOp(acc, lhs, rhs)
+ Unsigned, // ummla: built as arm_sve::UmmlaOp(acc, lhs, rhs)
+ Mixed, // usmmla: built as arm_sve::UsmmlaOp(acc, lhs, rhs)
+ MixedSwapped // usmmla with LHS and RHS swapped: UsmmlaOp(acc, rhs, lhs)
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+//
+// Parameters:
+//   rewriter - builder used to create the new operation.
+//   op       - selects which of smmla/ummla/usmmla is emitted and whether the
+//              multiplicands are swapped.
+//   loc      - location attached to the created operation.
+//   accType  - result type of the created operation.
+//   acc      - accumulator operand.
+//   lhs, rhs - multiplicand operands, in the order of the original contract.
+//
+// Returns the result value of the newly created operation.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+    return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
+  }
+  // The switch above covers every enumerator, so no `default` label is used
+  // (this keeps -Wswitch able to flag newly added enumerators). Without this
+  // marker control could fall off the end of a non-void function, which GCC
+  // warns about and which is undefined behaviour for an out-of-range value.
+  llvm_unreachable("unhandled MMLA kind");
+}
+
+// Lower a contraction operation that performs a matrix multiplication
+// of two 8-bit integer matrix tiles with logical dimensions <Mx8> and <8x[N]>
+// for the left-hand side and the right-hand side, respectively,
+// yielding a <Mx[N]> 32-bit integer result.
+//
+// The operands shapes are such that the operands can be evenly split into
+// sub-tiles with dimensions as expected by the targeted FEAT_I8MM instructions.
+// The intent is that M and N are chosen (by higher level transforms) in such a
+// way as to maximise register usage. The main use case we envision as of now is
+// MMT4D, thus the RHS operand is expected pre-transposed.
+//
+// The matrix multiplication is performed by unrolling the usual tiled matrix
+// multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS,
+// <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator.
+//
+// One way to illustrate the operation is as follows:
+//
+// RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]>
+// +-----------------------------
+// LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+// ... | ... ... ... ...
+// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+//
+// The RHS operand is unpacked into N/2 values, each representing a sequence of
+// VSCALE number of sub-tiles with dimensions <8x2>.
+// The LHS operand is initially unpacked into M/2 values, each representing a
+// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
+// VSCALE times.
+// Multiplying the thus replicated LHS sub-tile by the corresponding RHS
+// sub-tile correctly computes an entire result sub-tile.
+class LowerContractionToSVEI8MMPattern
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ mlir::VectorType lhsType = op.getLhsType();
+ mlir::VectorType rhsType = op.getRhsType();
+
+ // Check the operands have the expected shape. M and N dimensions must be
+ // even and at least 2.
----------------
momchil-velikov wrote:
Rewrote this.
https://github.com/llvm/llvm-project/pull/135636
More information about the Mlir-commits
mailing list