[Mlir-commits] [mlir] [MLIR][AArch64] Refactor lowering of vector.contract to Neon I8MM (PR #149810)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Jul 21 06:45:34 PDT 2025
================
@@ -85,202 +77,186 @@ std::optional<Value> getExtOperand(Value v) {
return inOp;
}
-// Designate the operation (resp. instruction) used to do sub-tile matrix
-// multiplications.
-enum class MMLA {
- Signed, // smmla
- Unsigned, // ummla
- Mixed, // usmmla
- MixedSwapped // usmmla with LHS and RHS swapped
-};
+/// Helper function to extend a vector with elements of type iN, N < 8, to
+/// a vector of i8. Do sign extension if the parameter `signExt` is true,
+/// zero extension otherwise.
+Value extendSmallIntVector(Location loc, VectorType srcTy, Value val,
+ bool signExt, PatternRewriter &rewriter) {
+ Type targetTy = srcTy.clone(rewriter.getI8Type());
+ return signExt ? rewriter.createOrFold<arith::ExtSIOp>(loc, targetTy, val)
+ : rewriter.createOrFold<arith::ExtUIOp>(loc, targetTy, val);
+}
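As a rough sketch of what this helper emits (the `vector<2x8xi4>` shape and the
value names below are only illustrative), the two variants map to
`arith.extsi` / `arith.extui` on an i8 clone of the source type:

```mlir
// Illustrative only: extend a sub-byte i4 vector to i8.
%s = arith.extsi %a : vector<2x8xi4> to vector<2x8xi8>  // signExt == true
%u = arith.extui %b : vector<2x8xi4> to vector<2x8xi8>  // signExt == false
```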
-// Create the matrix mulitply and accumulate operation according to `op`.
-Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
- mlir::Type accType, Value acc, Value lhs, Value rhs) {
- switch (op) {
- case MMLA::Signed:
- return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
- rhs);
- case MMLA::Unsigned:
- return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
- rhs);
- case MMLA::Mixed:
- return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
- rhs);
- case MMLA::MixedSwapped:
- // The accumulator comes transposed and the result will be transposed
- // later, so all we have to do here is swap the operands.
- return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
- lhs);
+class VectorContractRewriter {
+protected:
+ // Designate the operation (resp. instruction) used to do sub-tile matrix
+ // multiplications.
+ enum class MMLA {
+ Nop,
+ Signed, // smmla
+ Unsigned, // ummla
+ Mixed, // usmmla
+ MixedSwapped // usmmla with LHS and RHS swapped
+ };
+
+ // Lower-level operation to be emitted.
+ MMLA mmlaOp = MMLA::Nop;
+
+ // The operand tiles. These are not necessarily the operands of the
+ // `vector.contract`; for example, they could be the operands of an
+ // `arith.extsi` that in turn feeds the `vector.contract`.
+ Value lhs;
+ Value rhs;
+ Value acc;
+
+ // The dimensions logically corresponding to a matrix multiplication
+ // MxK * KxN -> MxN. The operands and the result do not necessarily have
+ // these shapes; for example, RHS could be NxK with a transposing indexing
+ // map.
+ int64_t dimM = 0;
+ int64_t dimN = 0;
+ int64_t dimK = 0;
+
+ // Unroll iteration bounds. See documentation for `StaticTileOffsetRange`.
+ SmallVector<int64_t> iterationBounds;
+
+ // Sub-tile shape. The algorithm handles operand shapes that are multiples
+ // of this shape.
+ SmallVector<int64_t> subTileShape;
+
+ // Create the matrix multiply and accumulate operation according to `mmlaOp`.
+ Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+ Value lhs, Value rhs) {
+ switch (mmlaOp) {
+ case MMLA::Signed:
+ return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::Unsigned:
+ return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::Mixed:
+ return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::MixedSwapped:
+ // The accumulator comes transposed and the result will be transposed
+ // later, so all we have to do here is swap the operands.
+ return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+ rhs, lhs);
+ case MMLA::Nop:
+ llvm_unreachable("Uninitialized operation type");
+ }
}
-}
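For context on the `MixedSwapped` case: per the Arm I8MM definition, each of
these intrinsics multiplies a 2x8 LHS sub-tile by the transpose of a 2x8 RHS
sub-tile and accumulates into a 2x2 tile of i32, roughly

$$C \mathrel{+}= A\,B^{\mathsf T}.$$

Since `usmmla` requires the unsigned operand to come first, a signed-LHS /
unsigned-RHS contraction is handled via the identity

$$A_s\,B_u^{\mathsf T} = \bigl(B_u\,A_s^{\mathsf T}\bigr)^{\mathsf T},$$

hence the operand swap and the transposed accumulator/result mentioned in the
comment above (this is a sketch of the reasoning; the tiling itself is set up
elsewhere in the pattern).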
-/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
-/// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
-/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
-/// necessary, a single smmla instruction is emitted.
-class LowerContractionToNeonI8MMPattern
- : public OpRewritePattern<vector::ContractionOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
- // Note: RHS is not transposed.
- mlir::VectorType lhsType = op.getLhsType();
- mlir::VectorType rhsType = op.getRhsType();
+ // Check common preconditions for applying the patterns and initialize
+ // logical dimensions.
+ LogicalResult matchAndInit(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ // Check iterator types for matrix multiplication.
+ SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
+ if (!((itTypes.size() == 3 &&
+ (itTypes[0] == vector::IteratorType::parallel &&
+ itTypes[1] == vector::IteratorType::parallel &&
+ itTypes[2] == vector::IteratorType::reduction)) ||
+ (itTypes.size() == 2 &&
+ (itTypes[0] == vector::IteratorType::parallel &&
+ itTypes[1] == vector::IteratorType::reduction))))
+ return rewriter.notifyMatchFailure(
+ op, "iterator types do not correspond to matrix multiplication");
+
// Avoid 0-D vectors and 1-D rhs:
- if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
- return failure();
+ VectorType lhsType = op.getLhsType();
+ VectorType rhsType = op.getRhsType();
+ if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
+ rhsType.getRank() != 2)
+ return rewriter.notifyMatchFailure(op, "Invalid operand rank");
+
// This codegen does not work for scalable vectors. Return failure so this
// pattern is not accidentally chosen over patterns that lower to ArmSVE.
if (lhsType.isScalable() || rhsType.isScalable())
- return failure();
- auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
- auto dimN = rhsType.getDimSize(0);
- auto dimK = rhsType.getDimSize(1);
- bool isVecmat = dimM == 1 ? true : false;
- if (lhsType.getDimSize(lhsType.getRank() - 1) !=
- rhsType.getDimSize(rhsType.getRank() - 1)) {
- return failure(); // dimK mismatch
- }
- // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
- // tiling.
- if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
- return failure();
- }
-
- // Check iterator types for contract. All iterators except inner-most
- // dimension must be parallel.
- auto iteratorTypes = op.getIteratorTypesArray();
- if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
- vector::IteratorType::reduction) {
- return failure();
- }
- if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
- [](vector::IteratorType iteratorType) {
- return iteratorType != vector::IteratorType::parallel;
- })) {
- return failure();
+ return rewriter.notifyMatchFailure(op,
+ "Not applicable to scalable vectors");
+
+ // Initialize dimensions and check for a matching K dimension.
+ dimM = lhsType.getDimSize(0);
+ dimN = rhsType.getDimSize(0);
+ dimK = rhsType.getDimSize(1);
+
+ int64_t lhsDimK;
+ if (lhsType.getRank() == 1) {
+ dimM = 1;
+ lhsDimK = lhsType.getDimSize(0);
+ } else {
+ lhsDimK = lhsType.getDimSize(1);
}
- // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
- // values before the extension. All four signed/unsigned combinations for
- // input operands are supported, but they are lowered to different
- // operations. Determine which is the appropriate operation to lower to.
- MMLA mmlaOp = MMLA::Signed;
- auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
- if (!maybeLhs) {
- mmlaOp = MMLA::Unsigned;
- maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
- }
- if (!maybeLhs)
- return failure();
+ if (lhsDimK != dimK)
+ return rewriter.notifyMatchFailure(op, "Dimensions mismatch");
- auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
- if (maybeRhs) {
- if (mmlaOp == MMLA::Unsigned)
- mmlaOp = MMLA::Mixed;
- } else {
- if (mmlaOp == MMLA::Signed)
- mmlaOp = MMLA::MixedSwapped;
- maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
- }
- if (!maybeRhs)
- return failure();
+ return success();
+ }
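To make the accepted form concrete, here is a sketch of a contraction that
passes these checks (shapes, indexing maps, and value names are illustrative,
not taken from the patch or its tests): a 2-D matmul with LHS MxK, RHS NxK,
accumulator MxN, and parallel/parallel/reduction iterators, fed by extension
ops from i8:

```mlir
// Illustrative: M = 4, N = 2, K = 8; %a, %b, %acc defined elsewhere. Both
// inputs are sign-extended, so this would map to the `Signed` (smmla) case.
%lhs = arith.extsi %a : vector<4x8xi8> to vector<4x8xi32>
%rhs = arith.extsi %b : vector<2x8xi8> to vector<2x8xi32>
%res = vector.contract {
         indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                          affine_map<(m, n, k) -> (n, k)>,
                          affine_map<(m, n, k) -> (m, n)>],
         iterator_types = ["parallel", "parallel", "reduction"],
         kind = #vector.kind<add>
       } %lhs, %rhs, %acc
         : vector<4x8xi32>, vector<2x8xi32> into vector<4x2xi32>
```

With the 2x2x8 sub-tile handled by the I8MM instructions, such a shape unrolls
into (4/2) * (2/2) * (8/8) = 2 sub-tile multiply-accumulates.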
- Value origLhs = *maybeLhs;
- Value origRhs = *maybeRhs;
-
- // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
- // following neon instruction. Check inputs for extsi are <=i8
- Value extLhs;
- Value extRhs;
- if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
- if (lhsExtInType.getElementTypeBitWidth() <= 8) {
- Type targetLhsExtTy =
- matchContainerType(rewriter.getI8Type(), lhsExtInType);
- if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
- extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
- origLhs);
- else
- extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
- origLhs);
- }
- }
- if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
- if (rhsExtInType.getElementTypeBitWidth() <= 8) {
- Type targetRhsExtTy =
- matchContainerType(rewriter.getI8Type(), rhsExtInType);
- if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
- extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
- origRhs);
- else
- extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
- origRhs);
- }
- }
+public:
+ void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
----------------
banach-space wrote:
```suggestion
void lower(vector::ContractionOp op, PatternRewriter &rewriter) {
```
https://github.com/llvm/llvm-project/pull/149810