[Mlir-commits] [mlir] [mlir][ArmNeon] Updates LowerContractionToSMMLAPattern with vecmat unroll patterns (PR #86005)
Diego Caballero
llvmlistbot at llvm.org
Thu Mar 21 15:27:29 PDT 2024
================
@@ -40,41 +40,38 @@ static Type matchContainerType(Type element, Type container) {
/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
/// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single
-/// smmla instruction is emitted.
+/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
+/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
+/// necessary, a single smmla instruction is emitted.
class LowerContractionToSMMLAPattern
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- // Check index maps that represent M N K in contract.
- auto indexingMaps = op.getIndexingMapsArray();
- if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
- return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
- affineMap.getNumResults() != 2;
- })) {
- return failure();
- }
- // Check iterator types for contract.
- auto iteratorTypes = op.getIteratorTypesArray();
- if (iteratorTypes.size() != 3 ||
- iteratorTypes[0] != vector::IteratorType::parallel ||
- iteratorTypes[1] != vector::IteratorType::parallel ||
- iteratorTypes[2] != vector::IteratorType::reduction) {
- return failure();
- }
- // Infer tile sizes from operands; Note: RHS is not transposed.
+ // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
+ // Note: RHS is not transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
- auto dimM = lhsType.getDimSize(0);
+ auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
- auto dimK = lhsType.getDimSize(1);
-
+ auto dimK = rhsType.getDimSize(1);
+ bool isVecmat = dimM == 1 ? true : false;
+ if (lhsType.getDimSize(lhsType.getRank() - 1) !=
+ rhsType.getDimSize(rhsType.getRank() - 1)) {
+ return failure(); // dimK mismatch
+ }
// Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
// tiling.
- if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
+ if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
+ return failure();
+ }
+
+ // Check iterator types for contract.
+ auto iteratorTypes = op.getIteratorTypesArray();
+ if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
+ vector::IteratorType::reduction) {
----------------
dcaballe wrote:
Shouldn't we also check that the iterators we fold into the smmla instruction are parallel?
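
For illustration only, here is one way such a check could look. It is a standalone sketch, not code from the PR: `IteratorType` is a hypothetical stand-in for `vector::IteratorType`, and `hasSmmlaCompatibleIterators` is a made-up helper name. It assumes the intended shape is "at most three iterators, the last one a reduction (K), every earlier one parallel (M and/or N)", and it also rejects an empty list so that the `size() - 1` index cannot underflow.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for vector::IteratorType; not the MLIR type.
enum class IteratorType { Parallel, Reduction };

// Returns true when the iterator list has the assumed shape: at most three
// iterators, the last one a reduction (K), and every preceding one
// (M and/or N) parallel.
bool hasSmmlaCompatibleIterators(const std::vector<IteratorType> &iters) {
  // Reject empty lists up front so iters.back() is always valid.
  if (iters.empty() || iters.size() > 3)
    return false;
  // The innermost iterator must be the K reduction.
  if (iters.back() != IteratorType::Reduction)
    return false;
  // Every iterator before the reduction must be parallel.
  for (std::size_t i = 0; i + 1 < iters.size(); ++i)
    if (iters[i] != IteratorType::Parallel)
      return false;
  return true;
}
```

Under these assumptions, a matmul-style `[parallel, parallel, reduction]` list and a vecmat-style `[parallel, reduction]` list are accepted, while a list such as `[reduction, parallel, reduction]` is rejected even though its last entry is a reduction.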
https://github.com/llvm/llvm-project/pull/86005