[Mlir-commits] [mlir] [MLIR][AArch64] Lower `vector.contract` with mixed signed/unsigned arguments to Neon FEAT_I8MM (PR #144698)

Andrzej WarzyƄski llvmlistbot at llvm.org
Wed Jun 18 09:36:30 PDT 2025


================
@@ -37,12 +37,87 @@ static Type matchContainerType(Type element, Type container) {
   return element;
 }
 
+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `iN` (N <= 8) to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero- extension op");
+
+  // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type, allow for an implicit sign-extension (only when
+  // checking for sign-extension).
+  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+  if (!extOp) {
+    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+      auto eltTy = cast<VectorType>(v.getType()).getElementType();
+      if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
+        return {};
+      return v;
+    }
+    return {};
+  }
+
+  // If the operand is defined by an explicit extend operation of the accepted
+  // operation type, check it's extended from `iN` (N <= 8) to `i32`.
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy)
+    return {};
+  auto inEltTy = inTy.getElementType();
+  if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
+    return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,      // smmla
+  Unsigned,    // ummla
+  Mixed,       // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply-accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::Type accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Unsigned:
+    return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Mixed:
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
+                                                     rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
+                                                     lhs);
+  }
+  llvm_unreachable("Unknown MMLA kind");
+}
+
----------------
banach-space wrote:

There is a lot of duplication with https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp

Is `vecMat` support the only other outstanding difference between the NEON and SVE lowerings? If we don't unify the implementations _now_, it would be good to at least leave some TODOs and open a GitHub issue.
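To make the acceptance rules of `getExtOperand` easier to compare against the SVE version, here is a small MLIR-free model of them. It is a sketch under stated assumptions: `DefKind`, `ExtKind`, `Operand` and `extOperandWidth` are hypothetical names, every operand is assumed to be a vector of signless integers, and instead of returning the `Value` to feed into the MMLA op, the model returns that value's element bit width.

```cpp
#include <optional>

// Hypothetical stand-ins for "which op defines this vector.contract operand".
enum class DefKind { None, ExtSI, ExtUI, Other };
enum class ExtKind { Signed, Unsigned };

// Hypothetical description of an operand (all element types assumed signless).
struct Operand {
  DefKind def;       // defining operation of the operand
  unsigned width;    // element bit width of the operand value itself
  unsigned srcWidth; // element bit width of the extension input (ExtSI/ExtUI only)
};

// Models getExtOperand<arith::ExtSIOp> (want == Signed) and
// getExtOperand<arith::ExtUIOp> (want == Unsigned). Returns the element width
// of the value that would be passed to the MMLA op, or nullopt on failure.
constexpr std::optional<unsigned> extOperandWidth(const Operand &v,
                                                  ExtKind want) {
  const DefKind expected =
      want == ExtKind::Signed ? DefKind::ExtSI : DefKind::ExtUI;
  if (v.def != expected) {
    // No explicit extension of the requested kind: only the signed probe may
    // treat a narrow (<= 8-bit) operand as implicitly sign-extended.
    if (want == ExtKind::Signed && v.width <= 8)
      return v.width;
    return std::nullopt;
  }
  // Explicit extension: must go from iN (N <= 8) to i32.
  if (v.srcWidth > 8 || v.width != 32)
    return std::nullopt;
  return v.srcWidth;
}
```

Note the asymmetry this exposes: an operand produced by `arith.extui` is rejected by the signed probe, because the implicit fallback only accepts operands whose own element type is at most 8 bits wide.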


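The comments on the `MMLA` enum imply a mapping from operand signedness to instruction: `usmmla` multiplies an unsigned first operand by a signed second operand, so a signed LHS paired with an unsigned RHS needs the operands swapped. The `selectMMLA` helper below is hypothetical (it is not part of the patch) and only spells out that mapping, assuming each operand's signedness has already been determined.

```cpp
// Mirrors the MMLA enum from the patch.
enum class MMLA { Signed, Unsigned, Mixed, MixedSwapped };

// Hypothetical helper: choose the MMLA variant from whether each operand
// was sign-extended (true) or zero-extended (false).
constexpr MMLA selectMMLA(bool lhsSigned, bool rhsSigned) {
  if (lhsSigned && rhsSigned)
    return MMLA::Signed; // smmla
  if (!lhsSigned && !rhsSigned)
    return MMLA::Unsigned; // ummla
  if (!lhsSigned)
    return MMLA::Mixed; // usmmla(lhs, rhs): unsigned LHS, signed RHS
  return MMLA::MixedSwapped; // usmmla(rhs, lhs): signed LHS, unsigned RHS
}
```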


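The `MixedSwapped` case relies on the identity (A·Bᔀ)ᔀ = B·Aᔀ: running `usmmla` on a transposed accumulator with the operands swapped, then transposing the result, gives the same answer as the signed-by-unsigned product. The sketch below checks this numerically with a scalar reference model. Assumptions: `Acc`, `Rows`, `usmmla`, `transpose`, `mixedReference`, `mixedSwapped` and `sameAcc` are hypothetical names; the 2x2 accumulator is modelled as a row-major 2D array rather than `vector<4xi32>`; and the second operand is held as two rows of eight (i.e. the 8x2 matrix transposed), mirroring the element indexing of the *MMLA instructions. In the patch the two transposes happen outside `createMMLA`; here they are folded into one function for brevity.

```cpp
#include <array>
#include <cstdint>

// 2x2 accumulator of 32-bit integers (row-major model).
using Acc = std::array<std::array<std::int32_t, 2>, 2>;
// Two rows of eight 8-bit elements.
template <typename T> using Rows = std::array<std::array<T, 8>, 2>;

// Scalar model of usmmla: acc[i][j] += sum_k lhs[i][k] * rhs[j][k],
// with an unsigned LHS and a signed RHS.
constexpr Acc usmmla(Acc acc, const Rows<std::uint8_t> &lhs,
                     const Rows<std::int8_t> &rhs) {
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 8; ++k)
        acc[i][j] += std::int32_t(lhs[i][k]) * std::int32_t(rhs[j][k]);
  return acc;
}

constexpr Acc transpose(const Acc &a) {
  Acc t{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      t[i][j] = a[j][i];
  return t;
}

// What a contraction with a signed LHS and an unsigned RHS should compute.
constexpr Acc mixedReference(Acc acc, const Rows<std::int8_t> &lhs,
                             const Rows<std::uint8_t> &rhs) {
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 8; ++k)
        acc[i][j] += std::int32_t(lhs[i][k]) * std::int32_t(rhs[j][k]);
  return acc;
}

// MixedSwapped path: transpose the accumulator, call usmmla with the
// unsigned RHS first, then transpose the result back.
constexpr Acc mixedSwapped(const Acc &acc, const Rows<std::int8_t> &lhs,
                           const Rows<std::uint8_t> &rhs) {
  return transpose(usmmla(transpose(acc), rhs, lhs));
}

constexpr bool sameAcc(const Acc &a, const Acc &b) {
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      if (a[i][j] != b[i][j])
        return false;
  return true;
}
```

Because the swap is exact integer arithmetic, no extra sign- or zero-extension handling is needed beyond choosing which operand goes first.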

https://github.com/llvm/llvm-project/pull/144698

