[Mlir-commits] [mlir] [mlir][ArmNeon] Updates LowerContractionToSMMLAPattern with vecmat unroll patterns (PR #86005)

Kojo Acquah llvmlistbot at llvm.org
Wed Mar 20 20:17:40 PDT 2024


https://github.com/KoolJBlack updated https://github.com/llvm/llvm-project/pull/86005

>From d39ae36969b287cd950bbb7031a936b9128c73e1 Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Wed, 20 Mar 2024 20:56:08 +0000
Subject: [PATCH] implement vecmat unroll for i8mm

---
 .../LowerContractionToSMMLAPattern.cpp        |  68 ++++++----
 .../Dialect/ArmNeon/lower-to-arm-neon.mlir    | 122 ++++++++++++++++++
 2 files changed, 165 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 1f48d27aa27b17..cdd3065e0752df 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -49,32 +49,28 @@ class LowerContractionToSMMLAPattern
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    // Check index maps that represent M N K in contract.
-    auto indexingMaps = op.getIndexingMapsArray();
-    if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
-          return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
-                 affineMap.getNumResults() != 2;
-        })) {
-      return failure();
-    }
-    // Check iterator types for contract.
-    auto iteratorTypes = op.getIteratorTypesArray();
-    if (iteratorTypes.size() != 3 ||
-        iteratorTypes[0] != vector::IteratorType::parallel ||
-        iteratorTypes[1] != vector::IteratorType::parallel ||
-        iteratorTypes[2] != vector::IteratorType::reduction) {
-      return failure();
-    }
-    // Infer tile sizes from operands; Note: RHS is not transposed.
+    // Infer tile sizes from operands. For vecmat, the LHS may have only 1 dim.
+    // Note: RHS is not transposed.
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
-    auto dimM = lhsType.getDimSize(0);
+    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
-    auto dimK = lhsType.getDimSize(1);
-
+    auto dimK = rhsType.getDimSize(1);
+    bool isVecmat = dimM == 1 ? true : false;
+    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
+        rhsType.getDimSize(rhsType.getRank() - 1)) {
+      return failure(); // dimK mismatch
+    }
     // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
     // tiling.
-    if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
+    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
+      return failure();
+    }
+
+    // Check iterator types for contract.
+    auto iteratorTypes = op.getIteratorTypesArray();
+    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
+                                        vector::IteratorType::reduction) {
       return failure();
     }
 
@@ -120,11 +116,14 @@ class LowerContractionToSMMLAPattern
         loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
 
     SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
-    SmallVector<int64_t> smmlaShape{2, 2, 8};
+    SmallVector<int64_t> smmlaShape{isVecmat ? 1 : 2, 2, 8};
     SmallVector<int64_t> loopOrder{0, 1, 2};
+    if (unrolledSize.size() == 2) {
+      smmlaShape = {2, 8};
+      loopOrder = {0, 1};
+    }
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
-
       // Helper to compute the new shape of each operand and extract the slice.
       auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                 ArrayRef<int64_t> operandOffsets) {
@@ -150,16 +149,30 @@ class LowerContractionToSMMLAPattern
       Value tiledAcc =
           extractOperand(op.getAcc(), accPermutationMap, accOffsets);
 
+      // With vecmat, the tiled LHS and ACC contain only one of the 2 required
+      // rows along dimM. Broadcast both to fill the 2-row smmla tile.
+      if (isVecmat) {
+        auto lhsBroadcastType = VectorType::get(
+            {2, 8}, tiledLhs.getType().cast<ShapedType>().getElementType());
+        tiledLhs = rewriter.create<vector::BroadcastOp>(loc, lhsBroadcastType,
+                                                        tiledLhs);
+        auto accBroadcastType = VectorType::get(
+            {2, 2}, tiledAcc.getType().cast<ShapedType>().getElementType());
+        tiledAcc = rewriter.create<vector::BroadcastOp>(loc, accBroadcastType,
+                                                        tiledAcc);
+      }
+
       // Collapse tiled operands to 1D vectors required by smmla intrinsic
       auto collapsedInputType = VectorType::get(
           tiledLhs.getType().cast<ShapedType>().getNumElements(),
           tiledLhs.getType().cast<ShapedType>().getElementType());
-      auto collapsedOutputType = VectorType::get(
-          {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
       auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledLhs.getLoc(), collapsedInputType, tiledLhs);
       auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+      auto collapsedOutputType = VectorType::get(
+          tiledAcc.getType().cast<ShapedType>().getNumElements(),
+          tiledAcc.getType().cast<ShapedType>().getElementType());
       auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
 
@@ -172,6 +185,11 @@ class LowerContractionToSMMLAPattern
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
           smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
 
+      // With vecmat, only row 0 of the tiled ACC is inserted into the result.
+      if (isVecmat) {
+        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
+      }
+
       // Insert the tiled result back into the non tiled result of the
       // contract op.
       SmallVector<int64_t> strides(
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index e2be87453bf6f2..b2cc3e611ea3d9 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -134,3 +134,125 @@ func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x1
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32>
   return %res : vector<4x4xi32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_vecmat_unroll(
+// CHECK-SAME:  %[[VAL_0:.*]]: vector<8xi8>,
+// CHECK-SAME:  %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME:  %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
+// CHECK:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<8xi32>
+// CHECK:  %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK:  %[[VAL_6:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_7:.*]] = vector.broadcast %[[VAL_5]] : vector<2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_8:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_9:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_11:.*]] = arm_neon.intr.smmla %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_13:.*]] = vector.extract %[[VAL_12]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_14:.*]] = vector.insert_strided_slice %[[VAL_13]], %[[VAL_3]] {offsets = [0], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK:  %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_16:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK:  %[[VAL_17:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_18:.*]] = vector.broadcast %[[VAL_16]] : vector<2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_19:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_21:.*]] = vector.shape_cast %[[VAL_18]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_22:.*]] = arm_neon.intr.smmla %[[VAL_21]], %[[VAL_19]], %[[VAL_20]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_23:.*]] = vector.shape_cast %[[VAL_22]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_24:.*]] = vector.extract %[[VAL_23]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_25:.*]] = vector.insert_strided_slice %[[VAL_24]], %[[VAL_14]] {offsets = [2], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK:  %[[VAL_26:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_27:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [4], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK:  %[[VAL_28:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_29:.*]] = vector.broadcast %[[VAL_27]] : vector<2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_30:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_31:.*]] = vector.shape_cast %[[VAL_26]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_32:.*]] = vector.shape_cast %[[VAL_29]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_33:.*]] = arm_neon.intr.smmla %[[VAL_32]], %[[VAL_30]], %[[VAL_31]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_34:.*]] = vector.shape_cast %[[VAL_33]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_35:.*]] = vector.extract %[[VAL_34]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_36:.*]] = vector.insert_strided_slice %[[VAL_35]], %[[VAL_25]] {offsets = [4], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK:  %[[VAL_37:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_38:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [6], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK:  %[[VAL_39:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_40:.*]] = vector.broadcast %[[VAL_38]] : vector<2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_41:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_42:.*]] = vector.shape_cast %[[VAL_37]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_43:.*]] = vector.shape_cast %[[VAL_40]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_44:.*]] = arm_neon.intr.smmla %[[VAL_43]], %[[VAL_41]], %[[VAL_42]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_45:.*]] = vector.shape_cast %[[VAL_44]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_46:.*]] = vector.extract %[[VAL_45]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_47:.*]] = vector.insert_strided_slice %[[VAL_46]], %[[VAL_36]] {offsets = [6], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK:  return %[[VAL_47]] : vector<8xi32>
+// CHECK:  }
+func.func @test_lower_vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: vector<8x8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+  %lhs_extsi= arith.extsi %lhs : vector<8xi8> to vector<8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8xi32>, vector<8x8xi32> into vector<8xi32>
+  return %res : vector<8xi32>
+}
+
+// -----
+
+
+// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(
+// CHECK-SAME:  %[[VAL_0:.*]]: vector<1x8xi8>,
+// CHECK-SAME:  %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME:  %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<1x8xi32>
+// CHECK:  %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK:  %[[VAL_6:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_7:.*]] = vector.broadcast %[[VAL_5]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_8:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_9:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_11:.*]] = arm_neon.intr.smmla %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_13:.*]] = vector.extract %[[VAL_12]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_14:.*]] = vector.insert_strided_slice %[[VAL_13]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK:  %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_16:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK:  %[[VAL_17:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_18:.*]] = vector.broadcast %[[VAL_16]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_19:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_21:.*]] = vector.shape_cast %[[VAL_18]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_22:.*]] = arm_neon.intr.smmla %[[VAL_21]], %[[VAL_19]], %[[VAL_20]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_23:.*]] = vector.shape_cast %[[VAL_22]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_24:.*]] = vector.extract %[[VAL_23]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_25:.*]] = vector.insert_strided_slice %[[VAL_24]], %[[VAL_14]] {offsets = [0, 2], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK:  %[[VAL_26:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_27:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 4], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK:  %[[VAL_28:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_29:.*]] = vector.broadcast %[[VAL_27]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_30:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_31:.*]] = vector.shape_cast %[[VAL_26]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_32:.*]] = vector.shape_cast %[[VAL_29]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_33:.*]] = arm_neon.intr.smmla %[[VAL_32]], %[[VAL_30]], %[[VAL_31]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_34:.*]] = vector.shape_cast %[[VAL_33]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_35:.*]] = vector.extract %[[VAL_34]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_36:.*]] = vector.insert_strided_slice %[[VAL_35]], %[[VAL_25]] {offsets = [0, 4], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK:  %[[VAL_37:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_38:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 6], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK:  %[[VAL_39:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_40:.*]] = vector.broadcast %[[VAL_38]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_41:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_42:.*]] = vector.shape_cast %[[VAL_37]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_43:.*]] = vector.shape_cast %[[VAL_40]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_44:.*]] = arm_neon.intr.smmla %[[VAL_43]], %[[VAL_41]], %[[VAL_42]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_45:.*]] = vector.shape_cast %[[VAL_44]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_46:.*]] = vector.extract %[[VAL_45]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_47:.*]] = vector.insert_strided_slice %[[VAL_46]], %[[VAL_36]] {offsets = [0, 6], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK:  return %[[VAL_47]] : vector<1x8xi32>
+// CHECK:  }
+
+func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8xi8>, %rhs: vector<8x8xi8>, %acc : vector<1x8xi32>) -> vector<1x8xi32> {
+  %lhs_extsi= arith.extsi %lhs : vector<1x8xi8> to vector<1x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32>
+  return %res : vector<1x8xi32>
+}



More information about the Mlir-commits mailing list