[Mlir-commits] [mlir] [MLIR][AArch64] Lower `vector.contract` with mixed signed/unsigned arguments to Neon FEAT_I8MM (PR #144698)

Momchil Velikov llvmlistbot at llvm.org
Wed Jun 18 10:18:19 PDT 2025


================
@@ -37,12 +37,87 @@ static Type matchContainerType(Type element, Type container) {
   return element;
 }
 
+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
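+// Return the value before extension for an explicit extension from `iN`
+// (N <= 8) to `i32`, the value itself for an implicit sign-extension of `iN`
+// (N <= 8), and `std::nullopt` otherwise.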
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero- extension op");
+
+  // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type, allow for an implicit sign-extension.
+  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+  if (!extOp) {
+    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+      auto eltTy = cast<VectorType>(v.getType()).getElementType();
+      if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
+        return {};
+      return v;
+    }
+    return {};
+  }
+
+  // If the operand is defined by an explicit extend operation of the accepted
+  // operation type, check it's extended from `iN` (N <= 8) to `i32`.
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy)
+    return {};
+  auto inEltTy = inTy.getElementType();
+  if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
+    return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,      // smmla
+  Unsigned,    // ummla
+  Mixed,       // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::Type accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Unsigned:
+    return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Mixed:
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
+                                                     rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
+                                                     lhs);
+  }
+  llvm_unreachable("Unknown MMLA kind");
+}
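To see the acceptance rules of `getExtOperand` in isolation, here is a minimal, MLIR-free C++ sketch. It is a hypothetical model, not the patch's code: `ExtKind`, `ModelValue` and `getExtOperandModel` are invented names, a value is reduced to the kind of extension op that defines it plus its element bit widths, all element types are assumed to be signless integers, and the function returns the narrow element width instead of an `mlir::Value`.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-ins for "defined by arith.extsi / arith.extui / neither".
enum class ExtKind { None, SignExt, ZeroExt };

// Hypothetical model of a vector.contract operand (element types only).
struct ModelValue {
  ExtKind definingExt;  // extension op defining this value, if any
  unsigned inBitWidth;  // element width of the extension's input (unused for None)
  unsigned outBitWidth; // element width of this value
};

// Mirrors the logic of getExtOperand<Op>: `want` plays the role of `Op`.
// Returns the narrow element width, or std::nullopt if the value is rejected.
std::optional<unsigned> getExtOperandModel(const ModelValue &v, ExtKind want) {
  assert(want != ExtKind::None && "must ask for sign- or zero-extension");

  if (v.definingExt != want) {
    // No explicit extension of the requested kind: only a signed request may
    // treat a narrow (<= 8-bit) integer as implicitly sign-extended.
    if (want == ExtKind::SignExt && v.outBitWidth <= 8)
      return v.outBitWidth;
    return std::nullopt;
  }

  // Explicit extension of the requested kind: require iN (N <= 8) -> i32.
  if (v.inBitWidth > 8 || v.outBitWidth != 32)
    return std::nullopt;
  return v.inBitWidth;
}
```

Note how a plain `i8` operand is accepted only when sign-extension is requested (the implicit sign-extension rule), while an operand produced by a zero-extension to `i32` never qualifies as signed.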
+
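The `MixedSwapped` case relies on the identity (A * B^T)^T = B * A^T: `usmmla` takes an unsigned first operand and a signed second one, so a signed LHS with an unsigned RHS can be handled by swapping the operands and working on the transposed accumulator. The sketch below is a scalar reference model of one 2x8 by 8x2 MMLA step (acc[i][j] += sum over k of lhs[i][k] * rhs[j][k]) that checks this identity. `selectMMLA`, `mmla`, `transpose` and `mixedSwapped` are hypothetical names; the mapping from operand signedness to `MMLA` kind is an assumption based on the enum comments, and unlike the patch (where the transposes happen elsewhere) the model transposes inline.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Mirrors the MMLA enum from the patch so the sketch is self-contained.
enum class MMLA { Signed, Unsigned, Mixed, MixedSwapped };

// Hypothetical selection by operand signedness. usmmla only exists with an
// unsigned LHS and a signed RHS, so signed LHS x unsigned RHS is "swapped".
MMLA selectMMLA(bool lhsSigned, bool rhsSigned) {
  if (lhsSigned == rhsSigned)
    return lhsSigned ? MMLA::Signed : MMLA::Unsigned;
  return lhsSigned ? MMLA::MixedSwapped : MMLA::Mixed;
}

// A 2x8 tile of 8-bit elements and a 2x2 tile of 32-bit accumulators.
template <typename T> using Tile = std::array<std::array<T, 8>, 2>;
using Acc = std::array<std::array<int32_t, 2>, 2>;

// Reference model of one *MMLA step: acc += lhs * transpose(rhs), where both
// operands are stored as 2x8 row-major tiles.
template <typename L, typename R>
Acc mmla(Acc acc, const Tile<L> &lhs, const Tile<R> &rhs) {
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 8; ++k)
        acc[i][j] += int32_t(lhs[i][k]) * int32_t(rhs[j][k]);
  return acc;
}

Acc transpose(const Acc &m) {
  return {{{m[0][0], m[1][0]}, {m[0][1], m[1][1]}}};
}

// Signed LHS, unsigned RHS: run the usmmla-shaped step (unsigned first,
// signed second) on the transposed accumulator, then transpose back.
Acc mixedSwapped(const Acc &acc, const Tile<int8_t> &lhs,
                 const Tile<uint8_t> &rhs) {
  return transpose(mmla<uint8_t, int8_t>(transpose(acc), rhs, lhs));
}
```

Because the model's `mmla` is generic over element types, it can also compute the "signed x unsigned" product directly, which is what a test can compare `mixedSwapped` against.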
----------------
momchil-velikov wrote:

I can share the one duplicated function, and adapt the other so that it can be shared between both patterns, as soon as you suggest where to put them.

It depends on what you mean by "unify". The algorithms for the two transformations are completely different; I don't see how they could be unified.


https://github.com/llvm/llvm-project/pull/144698

