[Mlir-commits] [mlir] [MLIR][AArch64] Lower `vector.contract` with mixed signed/unsigned arguments to Neon FEAT_I8MM (PR #144698)

Momchil Velikov llvmlistbot at llvm.org
Wed Jun 18 10:35:08 PDT 2025


================
@@ -37,12 +37,87 @@ static Type matchContainerType(Type element, Type container) {
   return element;
 }
 
+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `iN` (N <= 8) to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero- extension op");
+
+  // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type, allow for an implicit sign-extension.
+  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+  if (!extOp) {
+    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+      auto eltTy = cast<VectorType>(v.getType()).getElementType();
+      if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
+        return {};
+      return v;
+    }
+    return {};
+  }
+
+  // If the operand is defined by an explicit extend operation of the accepted
+  // operation type, check it's extended from `iN` (N <= 8) to `i32`.
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy)
+    return {};
+  auto inEltTy = inTy.getElementType();
+  if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
+    return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,      // smmla
+  Unsigned,    // ummla
+  Mixed,       // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply-accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::Type accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Unsigned:
+    return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Mixed:
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
+                                                     rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
+                                                     lhs);
+  }
+}
+
----------------
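The `MMLA::MixedSwapped` case in the hunk above only swaps the operands and relies on the accumulator and result being transposed elsewhere. The following C++ sketch is a scalar model written for illustration, not code from the patch. The `Tile`/`Bytes` types and the `selectMMLA`, `reference`, `usmmla`, `transpose` and `mixedSwapped` helpers are hypothetical. The sketch assumes the Arm I8MM tile layout, in which each instruction multiplies a 2x8 LHS by the transpose of a 2x8 RHS operand into a 2x2 `i32` accumulator, and `usmmla` treats its first operand as unsigned and its second as signed. It also treats an operand that is not extended at all as sign-extended, mirroring the implicit case in `getExtOperand`.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical scalar model of one Neon I8MM step: a 2x2 i32 accumulator and
// two 2x8 byte operands. The RHS operand holds the *transposed* RHS tile, so
// acc[i][j] += sum_k lhs[i][k] * rhs[j][k].
using Tile = std::array<std::array<int32_t, 2>, 2>;
using Bytes = std::array<std::array<uint8_t, 8>, 2>;

enum class MMLA { Signed, Unsigned, Mixed, MixedSwapped };

// Hypothetical: pick the instruction from how each operand is extended.
// `true` = sign-extended (explicitly, or implicitly by not being extended).
MMLA selectMMLA(bool lhsSigned, bool rhsSigned) {
  if (lhsSigned && rhsSigned)
    return MMLA::Signed; // smmla
  if (!lhsSigned && !rhsSigned)
    return MMLA::Unsigned; // ummla
  // usmmla wants the unsigned operand first, so a signed LHS needs a swap.
  return lhsSigned ? MMLA::MixedSwapped : MMLA::Mixed;
}

// Reference semantics: extend each byte as requested, then multiply-accumulate.
Tile reference(Tile acc, const Bytes &lhs, bool lhsSigned, const Bytes &rhs,
               bool rhsSigned) {
  auto ext = [](uint8_t b, bool isSigned) {
    return isSigned ? int32_t(int8_t(b)) : int32_t(b);
  };
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 8; ++k)
        acc[i][j] += ext(lhs[i][k], lhsSigned) * ext(rhs[j][k], rhsSigned);
  return acc;
}

// usmmla model: first operand zero-extended, second operand sign-extended.
Tile usmmla(const Tile &acc, const Bytes &a, const Bytes &b) {
  return reference(acc, a, /*lhsSigned=*/false, b, /*rhsSigned=*/true);
}

Tile transpose(const Tile &t) {
  return {{{t[0][0], t[1][0]}, {t[0][1], t[1][1]}}};
}

// Signed LHS x unsigned RHS via usmmla: since (L * R^T)^T == R * L^T, swapping
// the operands computes the transposed tile, so the accumulator goes in
// transposed and the result is transposed back.
Tile mixedSwapped(const Tile &acc, const Bytes &lhs, const Bytes &rhs) {
  return transpose(usmmla(transpose(acc), rhs, lhs));
}
```

In this model, the signed-by-unsigned tile reuses the unsigned-by-signed instruction, and the only extra cost is the two transposes. Per the comment in `createMMLA`, the patch handles those transposes outside that function.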
momchil-velikov wrote:

> Is `vecMat` support the only other outstanding difference between NEON and SVE support?

The differences in functionality:
a) Neon handles arbitrary permutation maps; SVE handles only the usual "identities + transposed RHS" one.
b) Neon can handle `iN` with `N <= 8`; SVE only `N == 8`.
c) SVE does not handle "vecmat" (it requires an LHS with an even number of rows).

b) and c) can _probably_ be added to the SVE version without too much trouble (at least as far as `vector.contract` lowering is concerned).
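One reason b) plausibly needs little work is that an `iN` value with `N <= 8` fits in `i8`. It can therefore be widened to `i8` with the same kind of extension and fed to the 8-bit MMLA instructions without changing any product. This is an assumption about how narrow types could be handled, not something stated in the patch. The following exhaustive C++ check models `arith.extsi`/`arith.extui` on raw bit patterns. `sextBits`, `zextBits` and `widenThroughI8IsExact` are hypothetical helpers.

```cpp
#include <cassert>
#include <cstdint>

// Sign-extend the low `n` bits of `bits` (1 <= n <= 8) to 32 bits, i.e. a
// scalar model of `arith.extsi` from iN to i32.
int32_t sextBits(uint32_t bits, unsigned n) {
  uint32_t v = bits & ((1u << n) - 1);
  uint32_t signBit = 1u << (n - 1);
  return int32_t(v ^ signBit) - int32_t(signBit);
}

// Zero-extend the low `n` bits of `bits` to 32 bits (`arith.extui`).
int32_t zextBits(uint32_t bits, unsigned n) {
  return int32_t(bits & ((1u << n) - 1));
}

// For every N in [1, 8] and every N-bit pattern, extending iN -> i32 directly
// gives the same value as extending iN -> i8 and then i8 -> i32 with the same
// kind of extension.
bool widenThroughI8IsExact() {
  for (unsigned n = 1; n <= 8; ++n) {
    for (uint32_t bits = 0; bits < (1u << n); ++bits) {
      uint32_t s8 = uint32_t(sextBits(bits, n)) & 0xFFu; // iN -> i8, signed
      uint32_t z8 = uint32_t(zextBits(bits, n)) & 0xFFu; // iN -> i8, unsigned
      if (sextBits(s8, 8) != sextBits(bits, n) ||
          zextBits(z8, 8) != zextBits(bits, n))
        return false;
    }
  }
  return true;
}
```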

I don't have high expectations for a); the SVE handling of the RHS in particular was rather tricky even with the simple maps. Moreover, I'd question whether supporting arbitrary permutation maps is necessary at all: a great chunk (I'll even wave my hands and say the "greatest chunk") of the performance comes from being able to do sequential loads and stores.
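The following hypothetical C++ sketch (not from the patch) makes the sequential-access point concrete. It lists the buffer offsets of the eight elements forming one row of a 2x8 operand tile, read from a row-major buffer with leading dimension `ld`. It does this twice: once with the tile row running along a buffer row, and once with the tile row running down a buffer column (a transposed access). The names `tileRowOffsets` and `stride` are illustrative only.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Offsets (in elements) of the 8 values forming row `r` of a 2x8 operand
// tile, read from a row-major buffer with `ld` elements per row.
// With `transposed == false`, tile element (r, k) maps to buffer element
// (r0 + r, k0 + k); with `transposed == true`, to (k0 + k, r0 + r).
std::array<std::size_t, 8> tileRowOffsets(std::size_t ld, std::size_t r0,
                                          std::size_t k0, std::size_t r,
                                          bool transposed) {
  std::array<std::size_t, 8> offs{};
  for (std::size_t k = 0; k < 8; ++k)
    offs[k] = transposed ? (k0 + k) * ld + (r0 + r) : (r0 + r) * ld + (k0 + k);
  return offs;
}

// Distance between consecutive (increasing) offsets, or 0 if they are not
// evenly spaced.
std::size_t stride(const std::array<std::size_t, 8> &offs) {
  std::size_t s = offs[1] - offs[0];
  for (std::size_t k = 2; k < 8; ++k)
    if (offs[k] - offs[k - 1] != s)
      return 0;
  return s;
}
```

With `ld = 64`, the first case has unit stride, so one contiguous 8-byte load covers the row. The transposed case has a stride of 64 and would typically need a gather or a separate transpose.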


https://github.com/llvm/llvm-project/pull/144698

