[Mlir-commits] [mlir] c511c90 - [mlir][ArmNeon] Updates LowerContractionToSMMLAPattern with vecmat unroll patterns (#86005)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 3 16:24:21 PDT 2024


Author: Kojo Acquah
Date: 2024-04-03T19:24:18-04:00
New Revision: c511c90680eecae2e4adb87f442f41d465feb0f2

URL: https://github.com/llvm/llvm-project/commit/c511c90680eecae2e4adb87f442f41d465feb0f2
DIFF: https://github.com/llvm/llvm-project/commit/c511c90680eecae2e4adb87f442f41d465feb0f2.diff

LOG: [mlir][ArmNeon] Updates LowerContractionToSMMLAPattern with vecmat unroll patterns (#86005)

Updates the smmla unrolling patterns to handle vecmat contractions where `dimM=1`. This covers both explicit vecmats of the form `<1x8xi8> x <8x8xi8> --> <1x8xi32>` and implicit vecmats where the leading dim is folded away: `<8xi8> x <8x8xi8> --> <8xi32>`.

Since the smmla instruction operates on two `<2x8xi8>` input vectors to produce a `<2x2xi32>` accumulator, half of each 2x2 accumulator tile is dummy data not pertinent to the computation, resulting in half throughput.

Added: 
    

Modified: 
    mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
    mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 1f48d27aa27b17..13740225749e46 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -40,8 +40,9 @@ static Type matchContainerType(Type element, Type container) {
 
 /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
 /// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single
-/// smmla instruction is emitted.
+/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
+/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
+/// necessary, a single smmla instruction is emitted.
 class LowerContractionToSMMLAPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
@@ -49,32 +50,35 @@ class LowerContractionToSMMLAPattern
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    // Check index maps that represent M N K in contract.
-    auto indexingMaps = op.getIndexingMapsArray();
-    if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
-          return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
-                 affineMap.getNumResults() != 2;
-        })) {
-      return failure();
-    }
-    // Check iterator types for contract.
-    auto iteratorTypes = op.getIteratorTypesArray();
-    if (iteratorTypes.size() != 3 ||
-        iteratorTypes[0] != vector::IteratorType::parallel ||
-        iteratorTypes[1] != vector::IteratorType::parallel ||
-        iteratorTypes[2] != vector::IteratorType::reduction) {
-      return failure();
-    }
-    // Infer tile sizes from operands; Note: RHS is not transposed.
+    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
+    // Note: RHS is not transposed.
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
-    auto dimM = lhsType.getDimSize(0);
+    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
-    auto dimK = lhsType.getDimSize(1);
-
+    auto dimK = rhsType.getDimSize(1);
+    bool isVecmat = dimM == 1 ? true : false;
+    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
+        rhsType.getDimSize(rhsType.getRank() - 1)) {
+      return failure(); // dimK mismatch
+    }
     // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
     // tiling.
-    if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
+    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
+      return failure();
+    }
+
+    // Check iterator types for contract. All iterators except inner-most
+    // dimension must be parallel.
+    auto iteratorTypes = op.getIteratorTypesArray();
+    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
+                                        vector::IteratorType::reduction) {
+      return failure();
+    }
+    if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
+                     [](vector::IteratorType iteratorType) {
+                       return iteratorType != vector::IteratorType::parallel;
+                     })) {
       return failure();
     }
 
@@ -120,11 +124,14 @@ class LowerContractionToSMMLAPattern
         loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
 
     SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
-    SmallVector<int64_t> smmlaShape{2, 2, 8};
-    SmallVector<int64_t> loopOrder{0, 1, 2};
+    SmallVector<int64_t> smmlaShape{2, 8};
+    SmallVector<int64_t> loopOrder{0, 1};
+    if (unrolledSize.size() == 3) {
+      smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
+      loopOrder.push_back(2);
+    }
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
-
       // Helper to compute the new shape of each operand and extract the slice.
       auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                 ArrayRef<int64_t> operandOffsets) {
@@ -150,16 +157,40 @@ class LowerContractionToSMMLAPattern
       Value tiledAcc =
           extractOperand(op.getAcc(), accPermutationMap, accOffsets);
 
+      auto inputElementType =
+          tiledLhs.getType().cast<ShapedType>().getElementType();
+      auto accElementType =
+          tiledAcc.getType().cast<ShapedType>().getElementType();
+      auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
+      auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+
+      // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
+      // rows along dimM. Expand their shapes to match the smmla op.
+      if (isVecmat) {
+        auto expandForSMMLA = [&](Value tiledOperand,
+                                  VectorType expandedTypeType) {
+          auto emptyOperand = rewriter.create<arith::ConstantOp>(
+              loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
+          SmallVector<int64_t> offsets(
+              emptyOperand.getType().cast<ShapedType>().getRank(), 0);
+          SmallVector<int64_t> strides(
+              tiledOperand.getType().cast<ShapedType>().getRank(), 1);
+          return rewriter.createOrFold<vector::InsertStridedSliceOp>(
+              loc, tiledOperand, emptyOperand, offsets, strides);
+        };
+        tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
+        tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
+      }
+
       // Collapse tiled operands to 1D vectors required by smmla intrinsic
-      auto collapsedInputType = VectorType::get(
-          tiledLhs.getType().cast<ShapedType>().getNumElements(),
-          tiledLhs.getType().cast<ShapedType>().getElementType());
-      auto collapsedOutputType = VectorType::get(
-          {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
+      auto collapsedInputType =
+          VectorType::get(inputExpandedType.getNumElements(), inputElementType);
       auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledLhs.getLoc(), collapsedInputType, tiledLhs);
       auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+      auto collapsedOutputType =
+          VectorType::get(outputExpandedType.getNumElements(), accElementType);
       auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
 
@@ -172,6 +203,11 @@ class LowerContractionToSMMLAPattern
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
           smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
 
+      // With vecmat, only one row of tiled ACC can be inserted inot file result
+      if (isVecmat) {
+        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
+      }
+
       // Insert the tiled result back into the non tiled result of the
       // contract op.
       SmallVector<int64_t> strides(

diff  --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index e2be87453bf6f2..46c4026d13b660 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -134,3 +134,127 @@ func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x1
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32>
   return %res : vector<4x4xi32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_vecmat_unroll(
+// CHECK-SAME:  %[[VAL_0:.*]]: vector<8xi8>,
+// CHECK-SAME:  %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME:  %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
+// CHECK:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<2x2xi32>
+// CHECK:  %[[VAL_4:.*]] = arith.constant dense<0> : vector<2x8xi8>
+// CHECK:  %[[VAL_5:.*]] = arith.constant dense<0> : vector<8xi32>
+// CHECK:  %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_7:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK:  %[[VAL_8:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1]} : vector<8xi8> into vector<2x8xi8>
+// CHECK:  %[[VAL_9:.*]] = vector.insert_strided_slice %[[VAL_7]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<2x2xi32>
+// CHECK:  %[[VAL_10:.*]] = vector.shape_cast %[[VAL_8]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_11:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_12:.*]] = vector.shape_cast %[[VAL_9]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_13:.*]] = arm_neon.intr.smmla %[[VAL_12]], %[[VAL_10]], %[[VAL_11]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_14:.*]] = vector.shape_cast %[[VAL_13]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_15:.*]] = vector.extract %[[VAL_14]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_16:.*]] = vector.insert_strided_slice %[[VAL_15]], %[[VAL_5]] {offsets = [0], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK:  %[[VAL_17:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_18:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK:  %[[VAL_19:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1]} : vector<8xi8> into vector<2x8xi8>
+// CHECK:  %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_18]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<2x2xi32>
+// CHECK:  %[[VAL_21:.*]] = vector.shape_cast %[[VAL_19]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_22:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_23:.*]] = vector.shape_cast %[[VAL_20]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_24:.*]] = arm_neon.intr.smmla %[[VAL_23]], %[[VAL_21]], %[[VAL_22]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_25:.*]] = vector.shape_cast %[[VAL_24]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_26:.*]] = vector.extract %[[VAL_25]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_27:.*]] = vector.insert_strided_slice %[[VAL_26]], %[[VAL_16]] {offsets = [2], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK:  %[[VAL_28:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_29:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [4], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK:  %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1]} : vector<8xi8> into vector<2x8xi8>
+// CHECK:  %[[VAL_31:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<2x2xi32>
+// CHECK:  %[[VAL_32:.*]] = vector.shape_cast %[[VAL_30]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_33:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_35:.*]] = arm_neon.intr.smmla %[[VAL_34]], %[[VAL_32]], %[[VAL_33]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_36:.*]] = vector.shape_cast %[[VAL_35]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_37:.*]] = vector.extract %[[VAL_36]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_38:.*]] = vector.insert_strided_slice %[[VAL_37]], %[[VAL_27]] {offsets = [4], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK:  %[[VAL_39:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_40:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [6], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK:  %[[VAL_41:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1]} : vector<8xi8> into vector<2x8xi8>
+// CHECK:  %[[VAL_42:.*]] = vector.insert_strided_slice %[[VAL_40]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<2x2xi32>
+// CHECK:  %[[VAL_43:.*]] = vector.shape_cast %[[VAL_41]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_44:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_45:.*]] = vector.shape_cast %[[VAL_42]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_46:.*]] = arm_neon.intr.smmla %[[VAL_45]], %[[VAL_43]], %[[VAL_44]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_47:.*]] = vector.shape_cast %[[VAL_46]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_48:.*]] = vector.extract %[[VAL_47]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_49:.*]] = vector.insert_strided_slice %[[VAL_48]], %[[VAL_38]] {offsets = [6], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK:  return %[[VAL_49]] : vector<8xi32>
+// CHECK:  }
+func.func @test_lower_vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: vector<8x8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+  %lhs_extsi= arith.extsi %lhs : vector<8xi8> to vector<8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8xi32>, vector<8x8xi32> into vector<8xi32>
+  return %res : vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(
+// CHECK-SAME:  %[[VAL_0:.*]]: vector<1x8xi8>,
+// CHECK-SAME:  %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME:  %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<2x2xi32>
+// CHECK:  %[[VAL_4:.*]] = arith.constant dense<0> : vector<2x8xi8>
+// CHECK:  %[[VAL_5:.*]] = arith.constant dense<0> : vector<1x8xi32>
+// CHECK:  %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_7:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK:  %[[VAL_8:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK:  %[[VAL_9:.*]] = vector.insert_strided_slice %[[VAL_7]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK:  %[[VAL_10:.*]] = vector.shape_cast %[[VAL_8]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_11:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_12:.*]] = vector.shape_cast %[[VAL_9]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_13:.*]] = arm_neon.intr.smmla %[[VAL_12]], %[[VAL_10]], %[[VAL_11]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_14:.*]] = vector.shape_cast %[[VAL_13]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_15:.*]] = vector.extract %[[VAL_14]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_16:.*]] = vector.insert_strided_slice %[[VAL_15]], %[[VAL_5]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK:  %[[VAL_17:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_18:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK:  %[[VAL_19:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK:  %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_18]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK:  %[[VAL_21:.*]] = vector.shape_cast %[[VAL_19]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_22:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_23:.*]] = vector.shape_cast %[[VAL_20]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_24:.*]] = arm_neon.intr.smmla %[[VAL_23]], %[[VAL_21]], %[[VAL_22]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_25:.*]] = vector.shape_cast %[[VAL_24]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_26:.*]] = vector.extract %[[VAL_25]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_27:.*]] = vector.insert_strided_slice %[[VAL_26]], %[[VAL_16]] {offsets = [0, 2], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK:  %[[VAL_28:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_29:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 4], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK:  %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK:  %[[VAL_31:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK:  %[[VAL_32:.*]] = vector.shape_cast %[[VAL_30]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_33:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_35:.*]] = arm_neon.intr.smmla %[[VAL_34]], %[[VAL_32]], %[[VAL_33]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_36:.*]] = vector.shape_cast %[[VAL_35]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_37:.*]] = vector.extract %[[VAL_36]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_38:.*]] = vector.insert_strided_slice %[[VAL_37]], %[[VAL_27]] {offsets = [0, 4], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK:  %[[VAL_39:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_40:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 6], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK:  %[[VAL_41:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK:  %[[VAL_42:.*]] = vector.insert_strided_slice %[[VAL_40]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK:  %[[VAL_43:.*]] = vector.shape_cast %[[VAL_41]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_44:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_45:.*]] = vector.shape_cast %[[VAL_42]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_46:.*]] = arm_neon.intr.smmla %[[VAL_45]], %[[VAL_43]], %[[VAL_44]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_47:.*]] = vector.shape_cast %[[VAL_46]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_48:.*]] = vector.extract %[[VAL_47]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_49:.*]] = vector.insert_strided_slice %[[VAL_48]], %[[VAL_38]] {offsets = [0, 6], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK:  return %[[VAL_49]] : vector<1x8xi32>
+// CHECK:  }
+func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8xi8>, %rhs: vector<8x8xi8>, %acc : vector<1x8xi32>) -> vector<1x8xi32> {
+  %lhs_extsi= arith.extsi %lhs : vector<1x8xi8> to vector<1x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32>
+  return %res : vector<1x8xi32>
+}


        


More information about the Mlir-commits mailing list