[Mlir-commits] [mlir] [MLIR][AArch64] Lower `vector.contract` with mixed signed/unsigned arguments to Neon FEAT_I8MM (PR #144698)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Jun 20 00:13:47 PDT 2025


================
@@ -37,12 +37,87 @@ static Type matchContainerType(Type element, Type container) {
   return element;
 }
 
+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `iN` (N <= 8) to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero- extension op");
+
+  // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type, allow for an implicit sign-extension.
+  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+  if (!extOp) {
+    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+      auto eltTy = cast<VectorType>(v.getType()).getElementType();
+      if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
+        return {};
+      return v;
+    }
+    return {};
+  }
+
+  // If the operand is defined by an explicit extend operation of the accepted
+  // operation type, check it's extended from `iN` (N <= 8) to `i32`.
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy)
+    return {};
+  auto inEltTy = inTy.getElementType();
+  if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
+    return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,      // smmla
+  Unsigned,    // ummla
+  Mixed,       // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::Type accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Unsigned:
+    return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Mixed:
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
+                                                     rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
+                                                     lhs);
+  }
+}
+
----------------
banach-space wrote:

Thanks for the overview - let’s make sure this information is also captured in the code for future reference.

As a concrete request: could you expand the top-level block comment in this file with a brief summary of the supported cases? For example:
```cpp
// Supports:
// a) Arbitrary permutation maps.
// b) iN, N <= 8
// c) vecMat operations
```

Please also add a similar summary comment in the `SVE` pattern. This will make it much easier to compare both implementations side-by-side, and clarify what’s required to unify them:
* Either make the `SVE` pattern support arbitrary permutation maps, or reduce the `NEON` support.
* Either make the `SVE` pattern support sub-byte types, or restrict the `NEON` pattern to `i8`.
* Either add `vecMat` support to the `SVE` pattern, or remove it from `NEON`.

At this stage, it’s hard to say which direction is best without integrating everything (i.e., `Linalg` + `Vector` support) and evaluating. I’m fine with leaving this as a TODO for now.
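
For reference, here is a minimal standalone sketch of the operand-signedness mapping that the `MMLA` enum in the hunk above appears to encode. It is not taken from the patch: `Ext` and `selectMMLA` are hypothetical names introduced only for illustration, and the sketch assumes each operand has already been classified as sign- or zero-extended (which is roughly what `getExtOperand` checks). The one fact it relies on is that `usmmla` takes an unsigned first source and a signed second source, so a signed LHS with an unsigned RHS needs the operands swapped.

```cpp
#include <cassert>
#include <string_view>

// Hypothetical stand-in for how a `vector.contract` operand was extended.
enum class Ext { Signed, Unsigned };

// Mirrors the `MMLA` enum from the patch.
enum class MMLA { Signed, Unsigned, Mixed, MixedSwapped };

// Hypothetical helper (an assumption, not the patch's code): pick the
// sub-tile multiply variant from the extension kind of each operand.
constexpr MMLA selectMMLA(Ext lhs, Ext rhs) {
  if (lhs == Ext::Signed && rhs == Ext::Signed)
    return MMLA::Signed; // smmla
  if (lhs == Ext::Unsigned && rhs == Ext::Unsigned)
    return MMLA::Unsigned; // ummla
  if (lhs == Ext::Unsigned)
    return MMLA::Mixed; // usmmla: unsigned LHS, signed RHS
  // Signed LHS, unsigned RHS: usmmla with the operands swapped.
  return MMLA::MixedSwapped;
}

// Instruction mnemonic for each variant, as listed in the enum comments.
inline std::string_view mnemonic(MMLA op) {
  switch (op) {
  case MMLA::Signed:
    return "smmla";
  case MMLA::Unsigned:
    return "ummla";
  case MMLA::Mixed:
  case MMLA::MixedSwapped:
    return "usmmla";
  }
  assert(false && "unhandled MMLA variant");
  return "";
}

// Compile-time checks of the two mixed cases.
static_assert(selectMMLA(Ext::Unsigned, Ext::Signed) == MMLA::Mixed);
static_assert(selectMMLA(Ext::Signed, Ext::Unsigned) == MMLA::MixedSwapped);
```

A table like this, kept next to the summary comment, could make it easier to see which signedness combinations each of the `NEON` and `SVE` patterns is expected to handle.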

https://github.com/llvm/llvm-project/pull/144698

