[Mlir-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

Andrzej Warzyński llvmlistbot at llvm.org
Fri May 30 08:33:53 PDT 2025


================
@@ -0,0 +1,358 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the SVE FEAT_I8MM extension.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Return the i8 vector value feeding the LHS or RHS of a vector contraction,
+// looking through an extension op of kind `T` (arith::ExtSIOp or
+// arith::ExtUIOp):
+//   * if `v` is produced by a `T` op extending a vector of i8 into a vector of
+//     i32, return that op's input;
+//   * if `v` is not produced by a `T` op, `T` is arith::ExtSIOp and `v` is
+//     itself a vector of i8, return `v` unchanged (an i8 operand is treated as
+//     implicitly sign-extended);
+//   * otherwise return std::nullopt.
+// `v` is expected to be vector-typed; the `cast` in the fallback path asserts
+// otherwise.
+template <typename T>
+std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  static_assert(llvm::is_one_of<T, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero- extension op");
+
+  // Explicit extension: accept only an element-wise i8 -> i32 vector extend.
+  if (auto extOp = dyn_cast_or_null<T>(v.getDefiningOp())) {
+    Value src = extOp.getIn();
+    auto srcTy = dyn_cast<VectorType>(src.getType());
+    auto dstTy = dyn_cast<VectorType>(extOp.getType());
+    const bool isI8ToI32 = srcTy && srcTy.getElementType() == i8Ty && dstTy &&
+                           dstTy.getElementType() == i32Ty;
+    if (!isI8ToI32)
+      return std::nullopt;
+    return src;
+  }
+
+  // No explicit extension: a bare i8 vector counts as implicitly sign-extended,
+  // so only the signed instantiation accepts `v` as-is.
+  if constexpr (std::is_same_v<T, arith::ExtSIOp>) {
+    if (cast<VectorType>(v.getType()).getElementType() == i8Ty)
+      return v;
+  }
+  return std::nullopt;
+}
+
+// Designates the operation (and hence the instruction) used to perform the
+// sub-tile matrix multiplications. The pattern picks it from how the LHS and
+// RHS operands are extended from i8 to i32:
+//   LHS sign-extended, RHS sign-extended -> Signed
+//   LHS zero-extended, RHS zero-extended -> Unsigned
+//   LHS zero-extended, RHS sign-extended -> Mixed
+//   LHS sign-extended, RHS zero-extended -> MixedSwapped
+enum class MMLA {
+  Signed,      // smmla (arm_sve::SmmlaOp)
+  Unsigned,    // ummla (arm_sve::UmmlaOp)
+  Mixed,       // usmmla (arm_sve::UsmmlaOp)
+  MixedSwapped // usmmla with LHS and RHS operands passed in swapped order
+};
+
+// Create the matrix multiply-accumulate operation selected by `op`.
+//
+// `acc` is the accumulator (of type `accType`, which is also the result type);
+// `lhs` and `rhs` are the two multiplicands. For MMLA::MixedSwapped the
+// multiplicands are handed to arm_sve::UsmmlaOp in swapped order.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+    return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
+  }
+  // Every enumerator is handled above. Without this, control can reach the
+  // end of a non-void function (GCC -Wreturn-type) if `op` ever holds a value
+  // outside the enumeration.
+  llvm_unreachable("unhandled MMLA kind");
+}
+
+// Lower a contraction operation that performs a matrix multiplication
+// of two 8-bit integer matrix tiles with logical dimensions <Mx8> and <8x[N]>
+// for the left-hand side and the right-hand side, respectively,
+// yielding a <Mx[N]> 32-bit integer result.
+//
+// The operands' shapes are such that the operands can be evenly split into
+// sub-tiles with dimensions as expected by the targeted FEAT_I8MM instructions.
+// The intent is that M and N are chosen (by higher level transforms) in such a
+// way as to maximise register usage. The main use case we envision as of now is
+// MMT4D, thus the RHS operand is expected pre-transposed.
+//
+// The matrix multiplication is performed by unrolling the usual tiled matrix
+// multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS,
+// <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator.
+//
+// One way to illustrate the operation is as follows:
+//
+// RHS<8x[N]>:       <8x[2]> <8x[2]> ... <8x[2]>
+//                 +-----------------------------
+// LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+//           <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+//            ...  |   ...     ...   ...   ...
+//           <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+//
+// The RHS operand is unpacked into N/2 values, each representing a sequence of
+// VSCALE number of sub-tiles with dimensions <8x2>.
+// The LHS operand is initially unpacked into M/2 values, each representing a
+// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
+// VSCALE times.
+// Multiplying a thus-replicated LHS sub-tile by the corresponding RHS sub-tile
+// correctly computes an entire result sub-tile.
+class LowerContractionToSVEI8MMPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    mlir::VectorType lhsType = op.getLhsType();
+    mlir::VectorType rhsType = op.getRhsType();
+
+    // Check the operands have the expected shape. M and N dimensions must be
+    // even and at least 2.
+    if (lhsType.getRank() != 2 || rhsType.getRank() != 2 ||
+        lhsType.isScalable() || !rhsType.isScalable())
+      return rewriter.notifyMatchFailure(op, "non-matching operand shape");
+
+    // M, N, and K are the conventional names for matrix dimensions in the
+    // context of matrix multiplication.
+    auto M = lhsType.getDimSize(0);
+    auto N = rhsType.getDimSize(0);
+    auto K = rhsType.getDimSize(1);
+
+    if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+        N % 2 != 0 || !rhsType.getScalableDims()[0])
+      return rewriter.notifyMatchFailure(op, "non-matching operand shape");
+
+    // Check permutation maps. For now only accept
+    //   lhs: (d0, d1, d2) -> (d0, d2)
+    //   rhs: (d0, d1, d2) -> (d1, d2)
+    //   acc: (d0, d1, d2) -> (d0, d1)
+    // Note: RHS is transposed.
+    if (op.getIndexingMapsArray()[0] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[1] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[2] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+                                                 op.getContext()))
+      return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
+
+    // Check iterator types for matrix multiplication.
+    auto itTypes = op.getIteratorTypesArray();
+    if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+        itTypes[1] != vector::IteratorType::parallel ||
+        itTypes[2] != vector::IteratorType::reduction)
+      return rewriter.notifyMatchFailure(
+          op, "iterator types do not correspond to matrix multiplication");
+
+    // Check the combining kind is addition.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(op,
+                                         "combining kind is not an addition");
+
+    // Check the output is a vector of i32 elements.
+    auto outTy = dyn_cast<VectorType>(op.getResultType());
+    if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+      return rewriter.notifyMatchFailure(op,
+                                         "output type is not a vector of i32");
+
+    // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
+    // before the extension. All four signed/unsigned combinations for input
+    // operands are supported, but they are lowered to different operations.
+    // Determine which is the appropriate operation to lower to.
+    MMLA mmlaOp = MMLA::Signed;
+    auto maybeLhs = getExtOperand<arith::ExtSIOp>(
+        op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::Unsigned;
+      maybeLhs = getExtOperand<arith::ExtUIOp>(
+          op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    }
+    if (!maybeLhs)
+      return rewriter.notifyMatchFailure(
+          op, "LHS is not a sign- or zero- extended i8");
+
+    auto maybeRhs = getExtOperand<arith::ExtSIOp>(
+        op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::Unsigned)
+        mmlaOp = MMLA::Mixed;
+    } else {
+      if (mmlaOp == MMLA::Signed)
+        mmlaOp = MMLA::MixedSwapped;
+      maybeRhs = getExtOperand<arith::ExtUIOp>(
+          op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    }
+    if (!maybeRhs)
+      return rewriter.notifyMatchFailure(
+          op, "RHS is not a sign- or zero- extended i8");
+
+    // One-dimensional vector types for arm_sve.*mmla
+    auto nxv16i8 = VectorType::get(/*shape=*/16, rewriter.getI8Type(),
+                                   /*scalableDims=*/{true});
+    auto nxv4i32 = VectorType::get(/*shape=*/4, rewriter.getI32Type(),
+                                   /*scalableDims=*/{true});
+
+    // Extract LHS sub-tiles with logicall shape <2x8>.
+    SmallVector<Value> lhsTile;
+    for (int64_t i = 0; i < M; i += 2) {
+      // Exract two consective rows of the LHS tile.
+      auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+                                                   ArrayRef<int64_t>{i});
+      auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+                                                   ArrayRef<int64_t>{i + 1});
+      // Concatenate to obtain a 16 x i8 flattened sub-tile.
+      auto t = rewriter.create<vector::ShuffleOp>(
+          loc, r0, r1,
+          llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
+                                  14, 15});
+      // Turn it into a scalable vector.
+      auto s = rewriter.create<vector::ScalableInsertOp>(
+          loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
+      // Replicate the sub-tile VSCALE times to fill the entire vector.
+      auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+      lhsTile.push_back(r);
+    }
+
+    // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
+    auto RHS = rewriter.create<vector::ShapeCastOp>(
----------------
banach-space wrote:

> IMHO M, N and K are fine the way they are.

I am a bit uneasy about diverging from the coding guidelines, but in this case renaming would hurt readability, so let's keep it as is. This is complex enough as it is.

https://github.com/llvm/llvm-project/pull/135636


More information about the Mlir-commits mailing list