[Mlir-commits] [mlir] [mlir][ArmNeon] Updates LowerContractionToSMMLAPattern with vecmat unroll patterns (PR #86005)

Diego Caballero llvmlistbot at llvm.org
Thu Mar 21 15:27:29 PDT 2024


================
@@ -40,41 +40,38 @@ static Type matchContainerType(Type element, Type container) {
 
 /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
 /// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single
-/// smmla instruction is emitted.
+/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
+/// = 1 (either explicitly or inferred if LHS has only dimK). If no unrolling is
+/// necessary, a single smmla instruction is emitted.
 class LowerContractionToSMMLAPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    // Check index maps that represent M N K in contract.
-    auto indexingMaps = op.getIndexingMapsArray();
-    if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
-          return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
-                 affineMap.getNumResults() != 2;
-        })) {
-      return failure();
-    }
-    // Check iterator types for contract.
-    auto iteratorTypes = op.getIteratorTypesArray();
-    if (iteratorTypes.size() != 3 ||
-        iteratorTypes[0] != vector::IteratorType::parallel ||
-        iteratorTypes[1] != vector::IteratorType::parallel ||
-        iteratorTypes[2] != vector::IteratorType::reduction) {
-      return failure();
-    }
-    // Infer tile sizes from operands; Note: RHS is not transposed.
+    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
+    // Note: RHS is not transposed.
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
-    auto dimM = lhsType.getDimSize(0);
+    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
-    auto dimK = lhsType.getDimSize(1);
-
+    auto dimK = rhsType.getDimSize(1);
+    bool isVecmat = dimM == 1 ? true : false;
+    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
+        rhsType.getDimSize(rhsType.getRank() - 1)) {
+      return failure(); // dimK mismatch
+    }
     // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
     // tiling.
-    if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
+    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
+      return failure();
+    }
+
+    // Check iterator types for contract.
+    auto iteratorTypes = op.getIteratorTypesArray();
+    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
+                                        vector::IteratorType::reduction) {
----------------
dcaballe wrote:

Shouldn't we check that the iterator types of the dimensions that we fold into the smmla instruction are parallel?

https://github.com/llvm/llvm-project/pull/86005


More information about the Mlir-commits mailing list