[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
Momchil Velikov via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed May 14 10:05:49 PDT 2025
================
@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Look through an extension op of kind `T` (restricted by the signature to
+// `arith.extsi` / `arith.extui`) that widens a vector with element type
+// `i8Ty` to a vector with element type `i32Ty`.
+//
+// Returns the operand being extended on a match. Returns an empty optional when
+// `v` is not produced by a `T`, when either side of the extension is not a
+// vector, or when the element types differ from `i8Ty` / `i32Ty`.
+template <typename T>
+inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
+                         std::is_base_of_v<arith::ExtUIOp, T>),
+                        std::optional<Value>>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  T ext = dyn_cast_or_null<T>(v.getDefiningOp());
+  if (!ext)
+    return std::nullopt;
+
+  // The narrow side of the extension must be a vector of `i8Ty`.
+  Value narrow = ext.getIn();
+  auto narrowTy = dyn_cast<VectorType>(narrow.getType());
+  const bool sourceMatches = narrowTy && narrowTy.getElementType() == i8Ty;
+  if (!sourceMatches)
+    return std::nullopt;
+
+  // The wide side of the extension must be a vector of `i32Ty`.
+  auto wideTy = dyn_cast<VectorType>(ext.getType());
+  const bool resultMatches = wideTy && wideTy.getElementType() == i32Ty;
+  if (!resultMatches)
+    return std::nullopt;
+
+  return narrow;
+}
+
+// Designates the operation (resp. instruction) used to perform the sub-tile
+// matrix multiplications. The pattern below picks the kind from whether the
+// LHS and RHS of the contraction are sign- or zero-extended from i8.
+enum class MMLA {
+  Signed,      // smmla: signed LHS, signed RHS
+  Unsigned,    // ummla: unsigned LHS, unsigned RHS
+  Mixed,       // usmmla: unsigned LHS, signed RHS
+  MixedSwapped // usmmla with LHS and RHS swapped: signed LHS, unsigned RHS
+};
+
+// Create the matrix multiply and accumulate operation selected by `op`.
+//
+// `acc`, `lhs` and `rhs` are forwarded to the corresponding arm_sve *MMLA op,
+// whose result type is `accType`. For `MMLA::MixedSwapped` the LHS and RHS are
+// exchanged, so the unsigned operand is passed first to usmmla. This matches
+// the order used for `MMLA::Mixed`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+    return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
+  }
+  // The switch covers every enumerator (no `default:` so -Wswitch still flags
+  // new ones), but GCC would otherwise warn that control reaches the end of a
+  // non-void function.
+  llvm_unreachable("Uncovered MMLA operation kind");
+}
+
+class LowerContractionToSVEI8MMPattern
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ mlir::VectorType lhsType = op.getLhsType();
+ mlir::VectorType rhsType = op.getRhsType();
+
+ // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+ // eventually expect from MMT4D. M and N dimensions must be even and at
+ // least 2.
+ if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+ rhsType.getRank() != 2)
+ return failure();
+
+ if (lhsType.isScalable() || !rhsType.isScalable())
+ return failure();
+
+ // M, N, and K are the conventional names for matrix dimensions in the
+ // context of matrix multiplication.
+ auto M = lhsType.getDimSize(0);
+ auto N = rhsType.getDimSize(0);
+ auto K = rhsType.getDimSize(1);
+
+ if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+ N % 2 != 0 || !rhsType.getScalableDims()[0])
+ return failure();
+
+ // Check permutation maps. For now only accept
+ // lhs: (d0, d1, d2) -> (d0, d2)
+ // rhs: (d0, d1, d2) -> (d1, d2)
+ // acc: (d0, d1, d2) -> (d0, d1)
+ // Note: RHS is transposed.
+ if (op.getIndexingMapsArray()[0] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[1] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[2] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+ op.getContext()))
+ return failure();
+
+ // Check iterator types for matrix multiplication.
+ auto itTypes = op.getIteratorTypesArray();
+ if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::parallel ||
+ itTypes[2] != vector::IteratorType::reduction)
+ return failure();
+
+ // Check the combining kind is addition.
+ if (op.getKind() != vector::CombiningKind::ADD)
+ return failure();
+
+ // Check the output is a vector of i32 elements.
+ auto outTy = dyn_cast<VectorType>(op.getType());
+ if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+ return failure();
+
+ // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
+ // before the extension. All four signed/unsigned combinations for input
+ // operands are supported, but they are lowered to different operations.
+ // Determina which is the appropriate operation to lower to.
+ MMLA mmlaOp = MMLA::Signed;
+ auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
+ op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ if (!maybeLhs) {
+ mmlaOp = MMLA::Unsigned;
+ maybeLhs = extractExtOperand<arith::ExtUIOp>(
+ op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ }
+ if (!maybeLhs)
+ return failure();
+
+ auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
+ op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ if (maybeRhs) {
+ if (mmlaOp == MMLA::Unsigned)
+ mmlaOp = MMLA::Mixed;
+ } else {
+ if (mmlaOp == MMLA::Signed)
+ mmlaOp = MMLA::MixedSwapped;
+ maybeRhs = extractExtOperand<arith::ExtUIOp>(
+ op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ }
+ if (!maybeRhs)
+ return failure();
+
+ // One-dimensional vector types for arm_sve.*mmla
+ auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
+ auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});
+
+ // Extract LHS sub-tiles.
----------------
momchil-velikov wrote:
Done.
https://github.com/llvm/llvm-project/pull/135636
More information about the llvm-branch-commits
mailing list