[Mlir-commits] [mlir] 90e632e - [MLIR][AArch64] Refactor lowering of vector.contract to Neon I8MM (#149810)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 22 08:31:22 PDT 2025
Author: Momchil Velikov
Date: 2025-07-22T16:31:17+01:00
New Revision: 90e632eb11a9ee49e8852c5356758024281fa26f
URL: https://github.com/llvm/llvm-project/commit/90e632eb11a9ee49e8852c5356758024281fa26f
DIFF: https://github.com/llvm/llvm-project/commit/90e632eb11a9ee49e8852c5356758024281fa26f.diff
LOG: [MLIR][AArch64] Refactor lowering of vector.contract to Neon I8MM (#149810)
This patch refactors the pattern in
`Transforms/LowerContractionToNeonI8MMPattern.cpp` using similar
approach as in https://github.com/llvm/llvm-project/pull/147052 to
prepare for adding BF16 support.
Added:
Modified:
mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
index 15de736480c5e..59acb362191a7 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
@@ -31,23 +31,15 @@ using namespace mlir;
using namespace mlir::arm_neon;
namespace {
-
-/// Return the shaped type with new element type.
-static Type matchContainerType(Type element, Type container) {
- if (auto shapedTy = dyn_cast<ShapedType>(container)) {
- return shapedTy.clone(element);
- }
- return element;
-}
-
-// Get the operand of a `vector.contract`. This function is intended to abstract
-// away from the particular way a value is extended before feeding it into the
-// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
-// (for implicit sign-extension see `vector.contract` documentation).
-//
-// The template parameter `Op` indicates the extension operation (explicit or
-// implicit) for which we are checking.
-//
+/// Get the operand of a `vector.contract`. This function is intended to
+/// abstract away from the particular way a value is extended before feeding it
+/// into the `vector.contract` - via zero-extend or an explicit or implicit
+/// sign-extend (for implicit sign-extension see `vector.contract`
+/// documentation).
+///
+/// The template parameter `Op` indicates the extension operation (explicit or
+/// implicit) for which we are checking.
+///
// Return success only for extensions from `iN` (N <= 8) to `i32`.
template <typename Op>
std::optional<Value> getExtOperand(Value v) {
@@ -85,202 +77,186 @@ std::optional<Value> getExtOperand(Value v) {
return inOp;
}
-// Designate the operation (resp. instruction) used to do sub-tile matrix
-// multiplications.
-enum class MMLA {
- Signed, // smmla
- Unsigned, // ummla
- Mixed, // usmmla
- MixedSwapped // usmmla with LHS and RHS swapped
-};
+/// Helper function to extend a vector with elements iN, N < 8 to
+/// a vector of i8. Do sign extension if the parameter `signExt` is true,
+/// zero extension otherwise.
+Value extendSmallIntVector(Location loc, VectorType srcTy, Value val,
+ bool signExt, PatternRewriter &rewriter) {
+ Type targetTy = srcTy.clone(rewriter.getI8Type());
+ return signExt ? rewriter.createOrFold<arith::ExtSIOp>(loc, targetTy, val)
+ : rewriter.createOrFold<arith::ExtUIOp>(loc, targetTy, val);
+}
-// Create the matrix mulitply and accumulate operation according to `op`.
-Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
- mlir::Type accType, Value acc, Value lhs, Value rhs) {
- switch (op) {
- case MMLA::Signed:
- return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
- rhs);
- case MMLA::Unsigned:
- return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
- rhs);
- case MMLA::Mixed:
- return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
- rhs);
- case MMLA::MixedSwapped:
- // The accumulator comes transposed and the result will be transposed
- // later, so all we have to do here is swap the operands.
- return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
- lhs);
+class VectorContractRewriter {
+protected:
+ // Designate the operation (resp. instruction) used to do sub-tile matrix
+ // multiplications.
+ enum class MMLA {
+ Nop,
+ Signed, // smmla
+ Unsigned, // ummla
+ Mixed, // usmmla
+ MixedSwapped // usmmla with LHS and RHS swapped
+ };
+
+ // Lower-level operation to be emitted.
+ MMLA mmlaOp = MMLA::Nop;
+
+ // The operand tiles. These are not necessarily the operands of
+ // `vector.contract`, for example they could be operands to `arith.extsi`
+ // that is in turn fed into `vector.contract`.
+ Value lhs;
+ Value rhs;
+ Value acc;
+
+ // The dimensions logically corresponding to matrix multiplication of
+ // MxK * KxN -> MxN. The operands and the result do not necessarily have these
+ // shapes, for example RHS could be NxK with a transposing indexing map.
+ int64_t dimM = 0;
+ int64_t dimN = 0;
+ int64_t dimK = 0;
+
+  // Unroll iteration bounds. See documentation for `StaticTileOffsetRange`.
+ SmallVector<int64_t> iterationBounds;
+
+ // Sub-tile shape. The algorithm handles operand shapes, which are multiples
+ // of this shape.
+ SmallVector<int64_t> subTileShape;
+
+ // Create the matrix multiply and accumulate operation according to `mmlaOp`.
+ Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+ Value lhs, Value rhs) {
+ switch (mmlaOp) {
+ case MMLA::Signed:
+ return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::Unsigned:
+ return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::Mixed:
+ return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::MixedSwapped:
+ // The accumulator comes transposed and the result will be transposed
+ // later, so all we have to do here is swap the operands.
+ return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+ rhs, lhs);
+ case MMLA::Nop:
+ llvm_unreachable("Uninitialized operation type");
+ }
}
-}
-/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
-/// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
-/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
-/// necessary, a single smmla instruction is emitted.
-class LowerContractionToNeonI8MMPattern
- : public OpRewritePattern<vector::ContractionOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
- // Note: RHS is not transposed.
- mlir::VectorType lhsType = op.getLhsType();
- mlir::VectorType rhsType = op.getRhsType();
+ // Check common preconditions for applying the patterns and initialize
+ // logical dimensions.
+ LogicalResult matchAndInit(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ // Check iterator types for matrix multiplication.
+ SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
+ if (!((itTypes.size() == 3 &&
+ (itTypes[0] == vector::IteratorType::parallel &&
+ itTypes[1] == vector::IteratorType::parallel &&
+ itTypes[2] == vector::IteratorType::reduction)) ||
+ (itTypes.size() == 2 &&
+ (itTypes[0] == vector::IteratorType::parallel &&
+ itTypes[1] == vector::IteratorType::reduction))))
+ return rewriter.notifyMatchFailure(
+ op, "iterator types do not correspond to matrix multiplication");
+
// Avoid 0-D vectors and 1-D rhs:
- if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
- return failure();
+ VectorType lhsType = op.getLhsType();
+ VectorType rhsType = op.getRhsType();
+ if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
+ rhsType.getRank() != 2)
+ return rewriter.notifyMatchFailure(op, "Invalid operand rank");
+
// This codegen does not work for scalable vectors. Return failure so this
// pattern is not accidentally chosen over patterns that lower to ArmSVE.
if (lhsType.isScalable() || rhsType.isScalable())
- return failure();
- auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
- auto dimN = rhsType.getDimSize(0);
- auto dimK = rhsType.getDimSize(1);
- bool isVecmat = dimM == 1 ? true : false;
- if (lhsType.getDimSize(lhsType.getRank() - 1) !=
- rhsType.getDimSize(rhsType.getRank() - 1)) {
- return failure(); // dimK mismatch
- }
- // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
- // tiling.
- if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
- return failure();
- }
-
- // Check iterator types for contract. All iterators except inner-most
- // dimension must be parallel.
- auto iteratorTypes = op.getIteratorTypesArray();
- if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
- vector::IteratorType::reduction) {
- return failure();
- }
- if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
- [](vector::IteratorType iteratorType) {
- return iteratorType != vector::IteratorType::parallel;
- })) {
- return failure();
+ return rewriter.notifyMatchFailure(op,
+ "Not applicable to scalable vectors");
+
+ // Initialize dimensions and check for a matching K dimension.
+ dimM = lhsType.getDimSize(0);
+ dimN = rhsType.getDimSize(0);
+ dimK = rhsType.getDimSize(1);
+
+ int64_t lhsDimK;
+ if (lhsType.getRank() == 1) {
+ dimM = 1;
+ lhsDimK = lhsType.getDimSize(0);
+ } else {
+ lhsDimK = lhsType.getDimSize(1);
}
- // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
- // values before the extension. All four signed/unsigned combinations for
- // input operands are supported, but they are lowered to different
- // operations. Determine which is the appropriate operation to lower to.
- MMLA mmlaOp = MMLA::Signed;
- auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
- if (!maybeLhs) {
- mmlaOp = MMLA::Unsigned;
- maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
- }
- if (!maybeLhs)
- return failure();
+ if (lhsDimK != dimK)
+ return rewriter.notifyMatchFailure(op, "Dimensions mismatch");
- auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
- if (maybeRhs) {
- if (mmlaOp == MMLA::Unsigned)
- mmlaOp = MMLA::Mixed;
- } else {
- if (mmlaOp == MMLA::Signed)
- mmlaOp = MMLA::MixedSwapped;
- maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
- }
- if (!maybeRhs)
- return failure();
+ return success();
+ }
- Value origLhs = *maybeLhs;
- Value origRhs = *maybeRhs;
-
- // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
- // following neon instruction. Check inputs for extsi are <=i8
- Value extLhs;
- Value extRhs;
- if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
- if (lhsExtInType.getElementTypeBitWidth() <= 8) {
- Type targetLhsExtTy =
- matchContainerType(rewriter.getI8Type(), lhsExtInType);
- if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
- extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
- origLhs);
- else
- extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
- origLhs);
- }
- }
- if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
- if (rhsExtInType.getElementTypeBitWidth() <= 8) {
- Type targetRhsExtTy =
- matchContainerType(rewriter.getI8Type(), rhsExtInType);
- if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
- extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
- origRhs);
- else
- extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
- origRhs);
- }
- }
+public:
+ void lower(vector::ContractionOp op, PatternRewriter &rewriter) {
+ // Create some convenience types.
+ auto inputElementType = cast<ShapedType>(lhs.getType()).getElementType();
+ auto accElementType = cast<ShapedType>(acc.getType()).getElementType();
+ auto inputExpandedType =
+ VectorType::get({2, subTileShape.back()}, inputElementType);
+ auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+
+ // One-dimensional representation of logical sub-tiles as required by the
+ // ArmNeon ops.
+ auto collapsedInputType =
+ VectorType::get(inputExpandedType.getNumElements(), inputElementType);
+ auto collapsedOutputType =
+ VectorType::get(outputExpandedType.getNumElements(), accElementType);
+
+ // Get indexing maps for a more concise/convenient access.
+ auto indexingMaps = op.getIndexingMapsArray();
+ AffineMap &lhsPermutationMap = indexingMaps[0];
+ AffineMap &rhsPermutationMap = indexingMaps[1];
+ AffineMap &accPermutationMap = indexingMaps[2];
- if (!extLhs || !extRhs) {
- return failure();
- }
+ Location loc = op.getLoc();
// Initial accumulator for the final result. This is the un-tiled result if
// tiling is done.
Value result = rewriter.create<arith::ConstantOp>(
loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
- SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
- SmallVector<int64_t> smmlaShape = {2, 8};
- SmallVector<int64_t> loopOrder = {0, 1};
- if (unrolledSize.size() == 3) {
- smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
+ SmallVector<int64_t, 3> loopOrder = {0, 1};
+ if (iterationBounds.size() == 3)
loopOrder.push_back(2);
- }
// Keep track of the previous accumulator when tiling over K.
Value kAcc;
for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+ StaticTileOffsetRange(iterationBounds, subTileShape, loopOrder)) {
// Helper to compute the new shape of each operand and extract the slice.
auto extractOperand = [&](Value operand, AffineMap permutationMap,
ArrayRef<int64_t> operandOffsets) {
- SmallVector<int64_t> operandShape =
- applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
+ SmallVector<int64_t> operandShape = applyPermutationMap(
+ permutationMap, ArrayRef<int64_t>(subTileShape));
SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, operand, operandOffsets, operandShape, operandStrides);
};
// Extract tiled lhs, rhs, and acc
- AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
SmallVector<int64_t> lhsOffsets =
applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
- Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets);
- AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
+ Value tiledLhs = extractOperand(lhs, lhsPermutationMap, lhsOffsets);
SmallVector<int64_t> rhsOffsets =
applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
- Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets);
- AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
+ Value tiledRhs = extractOperand(rhs, rhsPermutationMap, rhsOffsets);
SmallVector<int64_t> accOffsets =
applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
- Value tiledAcc =
- extractOperand(op.getAcc(), accPermutationMap, accOffsets);
-
- auto inputElementType =
- cast<ShapedType>(tiledLhs.getType()).getElementType();
- auto accElementType =
- cast<ShapedType>(tiledAcc.getType()).getElementType();
- auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
- auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+ Value tiledAcc = extractOperand(acc, accPermutationMap, accOffsets);
// With vecmat, tiled LHS and ACC will contain only one of 2 necessary
- // rows along dimM. Expand their shapes to match the smmla op.
- if (isVecmat) {
- auto expandForSMMLA = [&](Value tiledOperand,
- VectorType expandedTypeType) {
+ // rows along dimM. Expand their shapes to match the ArmNeon op.
+ if (dimM == 1) {
+ auto expandRowVector = [&](Value tiledOperand,
+ VectorType expandedTypeType) {
auto emptyOperand = rewriter.create<arith::ConstantOp>(
loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
SmallVector<int64_t> offsets(
@@ -290,8 +266,8 @@ class LowerContractionToNeonI8MMPattern
return rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, tiledOperand, emptyOperand, offsets, strides);
};
- tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
- tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
+ tiledLhs = expandRowVector(tiledLhs, inputExpandedType);
+ tiledAcc = expandRowVector(tiledAcc, outputExpandedType);
}
// Transpose ACC if doing signed by unsigned multiplication, because we're
@@ -301,15 +277,11 @@ class LowerContractionToNeonI8MMPattern
tiledAcc = rewriter.create<vector::TransposeOp>(
loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
- // Collapse tiled operands to 1D vectors required by smmla intrinsic
- auto collapsedInputType =
- VectorType::get(inputExpandedType.getNumElements(), inputElementType);
+ // Collapse tiled operands to 1D vectors required by the ArmNeon ops
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledLhs.getLoc(), collapsedInputType, tiledLhs);
auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledRhs.getLoc(), collapsedInputType, tiledRhs);
- auto collapsedOutputType =
- VectorType::get(outputExpandedType.getNumElements(), accElementType);
bool initialKAcc = offsets.back() == 0;
Value collapsedRes;
@@ -321,8 +293,8 @@ class LowerContractionToNeonI8MMPattern
}
// Insert contract op
- kAcc = createMMLA(rewriter, mmlaOp, op.getLoc(), collapsedRes.getType(),
- collapsedRes, collapsedLhs, collapsedRhs);
+ kAcc =
+ createMMLA(rewriter, loc, collapsedRes, collapsedLhs, collapsedRhs);
// Reshape output back to 2D
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
@@ -336,9 +308,8 @@ class LowerContractionToNeonI8MMPattern
// With vecmat, only one row of tiled ACC can be inserted into the final
// result
- if (isVecmat) {
+ if (dimM == 1)
tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
- }
// Insert the tiled result back into the non tiled result of the
// contract op.
@@ -349,6 +320,98 @@ class LowerContractionToNeonI8MMPattern
}
rewriter.replaceOp(op, result);
+ }
+};
+
+class VectorContractRewriterI8MM : public VectorContractRewriter {
+public:
+ LogicalResult matchAndInit(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
+ return failure();
+
+ // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
+ // tiling.
+ if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0)
+ return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
+
+ // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
+ // values before the extension. All four signed/unsigned combinations for
+ // input operands are supported, but they are lowered to different
+ // operations. Determine which is the appropriate operation to lower to.
+ mmlaOp = MMLA::Signed;
+ auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+ if (!maybeLhs) {
+ mmlaOp = MMLA::Unsigned;
+ maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+ }
+ if (!maybeLhs)
+ return rewriter.notifyMatchFailure(
+ op, "LHS is not a sign- or zero- extended iN, N <= 8");
+
+ auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
+ if (maybeRhs) {
+ if (mmlaOp == MMLA::Unsigned)
+ mmlaOp = MMLA::Mixed;
+ } else {
+ if (mmlaOp == MMLA::Signed)
+ mmlaOp = MMLA::MixedSwapped;
+ maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
+ }
+
+ if (!maybeRhs)
+ return rewriter.notifyMatchFailure(
+ op, "RHS is not a sign- or zero- extended iN, N <= 8");
+
+ lhs = *maybeLhs;
+ rhs = *maybeRhs;
+ acc = op.getAcc();
+
+ // Extend inputs from iN, N < 8 to i8.
+ Location loc = op.getLoc();
+ auto lhsExtInType = cast<VectorType>(lhs.getType());
+ if (lhsExtInType.getElementTypeBitWidth() < 8)
+ lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
+ /* signExt */ mmlaOp == MMLA::Signed ||
+ mmlaOp == MMLA::Mixed,
+ rewriter);
+
+ auto rhsExtInType = cast<VectorType>(rhs.getType());
+ if (rhsExtInType.getElementTypeBitWidth() < 8)
+
+ rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
+ /* signExt */ mmlaOp != MMLA::Unsigned &&
+ mmlaOp != MMLA::Mixed,
+ rewriter);
+
+ // Initialize parameters for unrolling.
+ iterationBounds = *op.getShapeForUnroll();
+ if (iterationBounds.size() == 3)
+ subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 8});
+ else
+ subTileShape = SmallVector<int64_t>({2, 8});
+
+ return success();
+ }
+};
+
+/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
+/// any vector.contract into multiple smmla instructions with unrolling so long
+/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
+/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
+/// necessary, a single smmla instruction is emitted.
+class LowerContractionToNeonI8MMPattern
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+
+ VectorContractRewriterI8MM vcr;
+ if (failed(vcr.matchAndInit(op, rewriter)))
+ return failure();
+ vcr.lower(op, rewriter);
+
return success();
}
};
More information about the Mlir-commits
mailing list