[Mlir-commits] [mlir] [MLIR][AArch64] Lower `vector.contract` with mixed signed/unsigned arguments to Neon FEAT_I8MM (PR #144698)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Jun 20 00:13:47 PDT 2025
================
@@ -37,12 +37,87 @@ static Type matchContainerType(Type element, Type container) {
return element;
}
+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `iN` (N <= 8) to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+ static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+ "Must be instantiated with either sign- or zero-extension op");
+
+ // If the operand is not defined by an explicit extend operation of the
+ // accepted operation type, allow for an implicit sign-extension.
+ auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+ if (!extOp) {
+ if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+ auto eltTy = cast<VectorType>(v.getType()).getElementType();
+ if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
+ return {};
+ return v;
+ }
+ return {};
+ }
+
+ // If the operand is defined by an explicit extend operation of the accepted
+ // operation type, check it's extended from `iN` (N <= 8) to `i32`.
+ auto inOp = extOp.getIn();
+ auto inTy = dyn_cast<VectorType>(inOp.getType());
+ if (!inTy)
+ return {};
+ auto inEltTy = inTy.getElementType();
+ if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
+ return {};
+
+ auto outTy = dyn_cast<VectorType>(extOp.getType());
+ if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
+ return {};
+
+ return inOp;
+}
+
+// Designate the operation (and the corresponding instruction) used to perform
+// sub-tile matrix multiplications.
+enum class MMLA {
+ Signed, // smmla
+ Unsigned, // ummla
+ Mixed, // usmmla
+ MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::Type accType, Value acc, Value lhs, Value rhs) {
+ switch (op) {
+ case MMLA::Signed:
+ return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
+ rhs);
+ case MMLA::Unsigned:
+ return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
+ rhs);
+ case MMLA::Mixed:
+ return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
+ rhs);
+ case MMLA::MixedSwapped:
+ // The accumulator comes transposed and the result will be transposed
+ // later, so all we have to do here is swap the operands.
+ return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
+ lhs);
+ }
+}
+
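For readers comparing this with the `SVE` pattern, the acceptance rule in `getExtOperand` above can be modelled without any MLIR dependency. This is a hypothetical, simplified sketch, not code from the patch: `Ext`, `Want`, `Operand` and `matchExtOperand` are illustrative names, only element bit widths are tracked (the signless-integer checks are omitted), and the returned width stands in for the `Value` the real function returns (the extension's input, or the operand itself in the implicit case).

```cpp
#include <cassert>
#include <optional>

// Hypothetical, MLIR-free model of `getExtOperand`; element bit widths only.
enum class Ext { None, SignExt, ZeroExt }; // op defining the contract operand
enum class Want { Signed, Unsigned };      // stands in for the `Op` template arg

struct Operand {
  Ext definedBy;     // Ext::None: not produced by arith.extsi/arith.extui
  unsigned width;    // element width of the `vector.contract` operand itself
  unsigned srcWidth; // element width of the extension's input (if any)
};

// Returns the element width of the value that would feed the MMLA, or
// std::nullopt if the operand is rejected.
std::optional<unsigned> matchExtOperand(const Operand &o, Want want) {
  bool explicitMatch = (o.definedBy == Ext::SignExt && want == Want::Signed) ||
                       (o.definedBy == Ext::ZeroExt && want == Want::Unsigned);
  if (!explicitMatch) {
    // Fallback: `vector.contract` implicitly sign-extends narrow operands, so
    // an iN operand (N <= 8) is accepted when a signed operand is requested.
    if (want == Want::Signed && o.width <= 8)
      return o.width;
    return std::nullopt;
  }
  // Explicit extension of the requested kind: must be iN (N <= 8) -> i32.
  if (o.srcWidth > 8 || o.width != 32)
    return std::nullopt;
  return o.srcWidth;
}
```

Note the fall-through: an operand produced by the other extension kind is not rejected outright but re-examined as a candidate for implicit sign-extension, mirroring the `dyn_cast_or_null` failure path in the diff. For example, an `arith.extui` from `i4` to `i8` is still accepted when a signed operand is requested, because the resulting `i8` operand qualifies for implicit sign-extension.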
----------------
banach-space wrote:
Thanks for the overview - let’s make sure this information is also captured in the code for future reference.
As a concrete request: could you expand the top-level block comment in this file with a brief summary of the supported cases? For example:
```cpp
// Supports:
// a) Arbitrary permutation maps.
// b) iN, N <= 8
// c) vecMat operations
```
Please also add a similar summary comment in the `SVE` pattern. This will make it much easier to compare both implementations side-by-side, and clarify what’s required to unify them:
* Either make the `SVE` pattern support arbitrary permutation maps, or reduce the `NEON` support accordingly.
* Either make the `SVE` pattern support sub-byte types, or restrict the `NEON` pattern to `i8`.
* Either add `vecMat` support to the `SVE` pattern, or remove it from `NEON`.
At this stage, it’s hard to say which direction is best without integrating everything (i.e., `Linalg` + `Vector` support) and evaluating. I’m fine with leaving this as a TODO for now.
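Relating this back to the `MMLA` enum and `createMMLA` in the diff, the mapping from operand extension kinds to instructions, and why swapping the operands is enough for `MixedSwapped`, can be checked with a scalar model. This is a sketch under stated assumptions: `selectMMLA`, `usmmlaRef`, `signedByUnsignedRef` and `mixedSwapped` are hypothetical helpers, and `usmmlaRef` assumes the FEAT_I8MM semantics in which the first 2x8 source is unsigned, the second 2x8 source (an 8x2 matrix stored column by column) is signed, and `acc[i][j] += sum_k a[i][k] * b[j][k]`.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the patch's `MMLA` enum (no MLIR dependency).
enum class MMLA { Signed, Unsigned, Mixed, MixedSwapped };

// Hypothetical helper: pick the instruction from how each operand was
// extended (true = sign-extended, false = zero-extended).
MMLA selectMMLA(bool lhsSigned, bool rhsSigned) {
  if (lhsSigned && rhsSigned)
    return MMLA::Signed; // smmla
  if (!lhsSigned && !rhsSigned)
    return MMLA::Unsigned; // ummla
  if (!lhsSigned)
    return MMLA::Mixed; // usmmla: unsigned first source, signed second
  return MMLA::MixedSwapped; // signed LHS, unsigned RHS: usmmla, swapped
}

using U8x2x8 = std::array<std::array<uint8_t, 8>, 2>;
using S8x2x8 = std::array<std::array<int8_t, 8>, 2>;
using I32x2x2 = std::array<std::array<int32_t, 2>, 2>;

// Scalar model of usmmla on one 2x2 tile: acc[i][j] += sum_k u[i][k] * s[j][k].
I32x2x2 usmmlaRef(I32x2x2 acc, const U8x2x8 &u, const S8x2x8 &s) {
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 8; ++k)
        acc[i][j] += int32_t(u[i][k]) * int32_t(s[j][k]);
  return acc;
}

I32x2x2 transpose(const I32x2x2 &m) {
  return {{{m[0][0], m[1][0]}, {m[0][1], m[1][1]}}};
}

// Direct reference for a signed LHS times an unsigned RHS.
I32x2x2 signedByUnsignedRef(I32x2x2 acc, const S8x2x8 &lhs,
                            const U8x2x8 &rhs) {
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 8; ++k)
        acc[i][j] += int32_t(lhs[i][k]) * int32_t(rhs[j][k]);
  return acc;
}

// MixedSwapped: run usmmla with the operands swapped on a transposed
// accumulator, then transpose the result back.
I32x2x2 mixedSwapped(const I32x2x2 &acc, const S8x2x8 &lhs,
                     const U8x2x8 &rhs) {
  return transpose(usmmlaRef(transpose(acc), rhs, lhs));
}
```

The identity being used is (A·Bᵀ)ᵀ = B·Aᵀ: FEAT_I8MM provides `usmmla` but no signed-by-unsigned counterpart, so handling a transposed accumulator lets the signed-LHS case reuse `usmmla` by changing only the operand order, as the comment in `createMMLA` describes.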
https://github.com/llvm/llvm-project/pull/144698