[Mlir-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

Andrzej Warzyński llvmlistbot at llvm.org
Wed May 21 01:16:16 PDT 2025


================
@@ -0,0 +1,358 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the SVE FEAT_I8MM extension.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Get the LHS or RHS side operand of a vector contract. Handle two cases
+//   * if the operand is a sign- or zero- extend operation of type `T` from i8
+//     to i32, return the value before the extension, otherwise
+//   * if the operand is of i8 type and the operation is sign-extend, return the
+//     operand itself.
+//
+// This way we handle both explicit sign- or zero- extension or implicit
+// sign-extension.
+template <typename T>
+std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
+
+  static_assert(llvm::is_one_of<T, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero- extension op");
+
+  auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
+  if (!extOp) {
+    if constexpr (std::is_same<T, arith::ExtSIOp>::value) {
+      auto vTy = cast<VectorType>(v.getType());
+      if (vTy.getElementType() != i8Ty)
+        return {};
+      return v;
+    }
+    return {};
+  }
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+    return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
----------------
banach-space wrote:

To me, these are just different variants of MMLA. However, after your comment, I see that others might read this differently. 

https://github.com/llvm/llvm-project/pull/135636

