[Mlir-commits] [mlir] [mlir][ArmNeon] Update `LowerContractionToSMMLAPattern` to support proper unrolling for k dimension (PR #88591)

Kojo Acquah llvmlistbot at llvm.org
Thu Apr 18 13:00:47 PDT 2024


https://github.com/KoolJBlack updated https://github.com/llvm/llvm-project/pull/88591

>From 1ebf2e038656fd9c8b989e16cc2d6f49c23f88db Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Fri, 12 Apr 2024 23:02:41 +0000
Subject: [PATCH 1/2] implement proper unrolling for k dim

---
 .../LowerContractionToSMMLAPattern.cpp        | 11 ++-
 .../Dialect/ArmNeon/lower-to-arm-neon.mlir    | 77 +++++++++++++++++++
 2 files changed, 87 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 3ae894692089b3..abe7f216533d7a 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -133,8 +133,12 @@ class LowerContractionToSMMLAPattern
       smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
       loopOrder.push_back(2);
     }
+
+    // Keep track of the previous accumulator when tiling over K
+    Value kAcc;
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+      auto kTileIndex = offsets[offsets.size() - 1] / 8;
       // Helper to compute the new shape of each operand and extract the slice.
       auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                 ArrayRef<int64_t> operandOffsets) {
@@ -197,16 +201,21 @@ class LowerContractionToSMMLAPattern
       auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
 
+      if (kTileIndex != 0) {
+        collapsedRes = kAcc;
+      }
+
       // Insert contract op
       auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
           op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
           collapsedRhs);
+      kAcc = smmlaOp;
 
       // Reshape output back to 2D
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
           smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
 
-      // With vecmat, only one row of tiled ACC can be inserted inot file result
+      // With vecmat, only one row of tiled ACC can be inserted into file result
       if (isVecmat) {
         tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
       }
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index c276a5b0c2a14b..b70ca36c2d7f60 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -269,3 +269,80 @@ func.func @test_lower_vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8x8xi32>, vector<8xi32> into vector<8xi32>
   return %res : vector<8xi32>
 }
+
+
+// -----
+
+// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_k_unroll(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2x16xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<2x16xi4>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<2x2xi32>) -> vector<2x2xi32> {
+// CHECK:  %[[VAL_3:.*]] = arith.extsi %[[VAL_1]] : vector<2x16xi4> to vector<2x16xi8>
+// CHECK:  %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_3]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_6:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_7:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_8:.*]] = vector.shape_cast %[[VAL_2]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_9:.*]] = arm_neon.intr.smmla %[[VAL_8]], %[[VAL_6]], %[[VAL_7]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_3]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_12:.*]] = vector.shape_cast %[[VAL_10]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_13:.*]] = vector.shape_cast %[[VAL_11]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_14:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_12]], %[[VAL_13]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  return %[[VAL_15]] : vector<2x2xi32>
+// CHECK:  }
+func.func @test_lower_vector_arm_neon_k_unroll(%lhs: vector<2x16xi8>, %rhs: vector<2x16xi4>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<2x16xi8> to vector<2x16xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<2x16xi4> to vector<2x16xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x16xi32>, vector<2x16xi32> into vector<2x2xi32>
+  return %res : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_k_unroll_vecmat(
+// CHECK-SAME:                                                          %[[VAL_0:.*]]: vector<1x24xi8>,
+// CHECK-SAME:                                                          %[[VAL_1:.*]]: vector<2x24xi4>,
+// CHECK-SAME:                                                          %[[VAL_2:.*]]: vector<1x2xi32>) -> vector<1x2xi32> {
+// CHECK:           %[[VAL_3:.*]] = arith.constant dense<0> : vector<2x2xi32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant dense<0> : vector<2x8xi8>
+// CHECK:           %[[VAL_5:.*]] = arith.constant dense<0> : vector<1x2xi32>
+// CHECK:           %[[VAL_6:.*]] = arith.extsi %[[VAL_1]] : vector<2x24xi4> to vector<2x24xi8>
+// CHECK:           %[[VAL_7:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [1, 8], strides = [1, 1]} : vector<1x24xi8> to vector<1x8xi8>
+// CHECK:           %[[VAL_8:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<2x24xi8> to vector<2x8xi8>
+// CHECK:           %[[VAL_9:.*]] = vector.insert_strided_slice %[[VAL_7]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK:           %[[VAL_10:.*]] = vector.insert_strided_slice %[[VAL_2]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK:           %[[VAL_11:.*]] = vector.shape_cast %[[VAL_9]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:           %[[VAL_12:.*]] = vector.shape_cast %[[VAL_8]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:           %[[VAL_13:.*]] = vector.shape_cast %[[VAL_10]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:           %[[VAL_14:.*]] = arm_neon.intr.smmla %[[VAL_13]], %[[VAL_11]], %[[VAL_12]] : vector<16xi8> to vector<4xi32>
+// CHECK:           %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:           %[[VAL_16:.*]] = vector.extract %[[VAL_15]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:           %[[VAL_17:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[VAL_5]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32>
+// CHECK:           %[[VAL_18:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 8], sizes = [1, 8], strides = [1, 1]} : vector<1x24xi8> to vector<1x8xi8>
+// CHECK:           %[[VAL_19:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x24xi8> to vector<2x8xi8>
+// CHECK:           %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_18]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK:           %[[VAL_21:.*]] = vector.shape_cast %[[VAL_20]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:           %[[VAL_22:.*]] = vector.shape_cast %[[VAL_19]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:           %[[VAL_23:.*]] = arm_neon.intr.smmla %[[VAL_14]], %[[VAL_21]], %[[VAL_22]] : vector<16xi8> to vector<4xi32>
+// CHECK:           %[[VAL_24:.*]] = vector.shape_cast %[[VAL_23]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:           %[[VAL_25:.*]] = vector.extract %[[VAL_24]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:           %[[VAL_26:.*]] = vector.insert_strided_slice %[[VAL_25]], %[[VAL_17]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32>
+// CHECK:           %[[VAL_27:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 16], sizes = [1, 8], strides = [1, 1]} : vector<1x24xi8> to vector<1x8xi8>
+// CHECK:           %[[VAL_28:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 16], sizes = [2, 8], strides = [1, 1]} : vector<2x24xi8> to vector<2x8xi8>
+// CHECK:           %[[VAL_29:.*]] = vector.insert_strided_slice %[[VAL_27]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK:           %[[VAL_30:.*]] = vector.shape_cast %[[VAL_29]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:           %[[VAL_31:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:           %[[VAL_32:.*]] = arm_neon.intr.smmla %[[VAL_23]], %[[VAL_30]], %[[VAL_31]] : vector<16xi8> to vector<4xi32>
+// CHECK:           %[[VAL_33:.*]] = vector.shape_cast %[[VAL_32]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:           %[[VAL_34:.*]] = vector.extract %[[VAL_33]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:           %[[VAL_35:.*]] = vector.insert_strided_slice %[[VAL_34]], %[[VAL_26]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32>
+// CHECK:           return %[[VAL_35]] : vector<1x2xi32>
+// CHECK:         }
+func.func @test_lower_vector_arm_neon_k_unroll_vecmat(%lhs: vector<1x24xi8>, %rhs: vector<2x24xi4>, %acc : vector<1x2xi32>) -> vector<1x2xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<1x24xi8> to vector<1x24xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<2x24xi4> to vector<2x24xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x24xi32>, vector<2x24xi32> into vector<1x2xi32>
+  return %res : vector<1x2xi32>
+}

>From e2f57e287c05ad82ca0d95f289a45abc6f2773ed Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Thu, 18 Apr 2024 20:00:31 +0000
Subject: [PATCH 2/2] review comments

---
 .../LowerContractionToSMMLAPattern.cpp        |  17 ++-
 .../Dialect/ArmNeon/lower-to-arm-neon.mlir    | 140 +++++++++---------
 2 files changed, 83 insertions(+), 74 deletions(-)

diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index abe7f216533d7a..220798cf61b4b5 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -134,11 +134,10 @@ class LowerContractionToSMMLAPattern
       loopOrder.push_back(2);
     }
 
-    // Keep track of the previous accumulator when tiling over K
+    // Keep track of the previous accumulator when tiling over K.
     Value kAcc;
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
-      auto kTileIndex = offsets[offsets.size() - 1] / 8;
       // Helper to compute the new shape of each operand and extract the slice.
       auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                 ArrayRef<int64_t> operandOffsets) {
@@ -198,22 +197,24 @@ class LowerContractionToSMMLAPattern
           tiledRhs.getLoc(), collapsedInputType, tiledRhs);
       auto collapsedOutputType =
           VectorType::get(outputExpandedType.getNumElements(), accElementType);
-      auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
-          tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
 
-      if (kTileIndex != 0) {
+      bool initialKAcc = offsets.back() == 0;
+      Value collapsedRes;
+      if (!initialKAcc) {
         collapsedRes = kAcc;
+      } else {
+        collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
+            tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
       }
 
       // Insert contract op
-      auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
+      kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
           op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
           collapsedRhs);
-      kAcc = smmlaOp;
 
       // Reshape output back to 2D
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
-          smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
+          kAcc.getLoc(), tiledAcc.getType(), kAcc);
 
       // With vecmat, only one row of tiled ACC can be inserted into file result
       if (isVecmat) {
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index b70ca36c2d7f60..297be91e772838 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -test-lower-to-arm-neon -verify-diagnostics -split-input-file %s | FileCheck %s
 
-// CHECK-LABEL: test_lower_vector_arm_neon_mixed_types
+// CHECK-LABEL: vector_arm_neon_mixed_types
 // CHECK-SAME:    %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi4>, %[[A2:.*]]: vector<2x2xi32>
 // CHECK-DAG: %[[D0:.*]] = arith.extsi %[[A1]] : vector<2x8xi4> to vector<2x8xi8>
 // CHECK-DAG: %[[D1:.*]] = vector.shape_cast %[[A0]] : vector<2x8xi8> to vector<16xi8>
@@ -8,7 +8,7 @@
 // CHECK-DAG: %[[D3:.*]] = vector.shape_cast %[[A2]] : vector<2x2xi32> to vector<4xi32>
 // CHECK-DAG: %[[D4:.*]] = arm_neon.intr.smmla %[[D3]], %[[D1]], %[[D2]] : vector<16xi8> to vector<4xi32>
 // CHECK-DAG: %[[D5:.*]] = vector.shape_cast %[[D4]] : vector<4xi32> to vector<2x2xi32>
-func.func @test_lower_vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+func.func @vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
   %lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
   %rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32>
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
@@ -17,14 +17,14 @@ func.func @test_lower_vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: ve
 
 // -----
 
-// CHECK-LABEL: test_lower_vector_arm_neon_same_types
+// CHECK-LABEL: vector_arm_neon_same_types
 // CHECK-SAME:    %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi8>, %[[A2:.*]]: vector<2x2xi32>
 // CHECK-DAG: %[[D0:.*]] = vector.shape_cast %[[A0]] : vector<2x8xi8> to vector<16xi8>
 // CHECK-DAG: %[[D1:.*]] = vector.shape_cast %[[A1]] : vector<2x8xi8> to vector<16xi8>
 // CHECK-DAG: %[[D2:.*]] = vector.shape_cast %[[A2]] : vector<2x2xi32> to vector<4xi32>
 // CHECK-DAG: %[[D3:.*]] = arm_neon.intr.smmla %[[D2]], %[[D0]], %[[D1]] : vector<16xi8> to vector<4xi32>
 // CHECK-DAG: %[[D4:.*]] = vector.shape_cast %[[D3]] : vector<4xi32> to vector<2x2xi32>
-func.func @test_lower_vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+func.func @vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
   %lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
   %rhs_extsi = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32>
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
@@ -33,17 +33,17 @@ func.func @test_lower_vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vec
 
 // -----
 
-// CHECK-LABEL: test_lower_vector_arm_neon_without_extsi
+// CHECK-LABEL: vector_arm_neon_without_extsi
 // CHECK-SAME:    %[[A0:.*]]: vector<2x8xi32>, %[[A1:.*]]: vector<2x8xi32>, %[[A2:.*]]: vector<2x2xi32>
 // CHECK-DAG: %[[D0:.*]] = vector.contract
-func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs: vector<2x8xi32>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+func.func @vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs: vector<2x8xi32>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
   return %res : vector<2x2xi32>
 }
 
 // -----
 
-// CHECK-LABEL: test_lower_vector_arm_neon_unroll
+// CHECK-LABEL: vector_arm_neon_unroll
 // CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>, %[[VAL_1:.*]]: vector<4x8xi8>, %[[VAL_2:.*]]: vector<4x4xi32>
 // CHECK-DAG:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x4xi32>
 // CHECK-DAG:  %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
@@ -84,7 +84,7 @@ func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs:
 // CHECK-DAG:  %[[VAL_39:.*]] = vector.insert_strided_slice %[[VAL_38]], %[[VAL_30]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
 // CHECK-DAG:  return %[[VAL_39]] : vector<4x4xi32>
 // CHECK-DAG:  }
-func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<4x8xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
+func.func @vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<4x8xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
   %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
   %rhs_extsi = arith.extsi %rhs : vector<4x8xi8> to vector<4x8xi32>
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<4x8xi32> into vector<4x4xi32>
@@ -93,7 +93,7 @@ func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<
 
 // -----
 
-// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_mixed_unroll(
+// CHECK-LABEL:   func.func @vector_arm_neon_mixed_unroll(
 // CHECK-SAME:                                                       %[[VAL_0:.*]]: vector<4x8xi8>,
 // CHECK-SAME:                                                       %[[VAL_1:.*]]: vector<2x8xi4>,
 // CHECK-SAME:                                                       %[[VAL_2:.*]]: vector<4x2xi32>) -> vector<4x2xi32> {
@@ -117,7 +117,7 @@ func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<
 // CHECK-DAG:  %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_19]], %[[VAL_12]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x2xi32>
 // CHECK-DAG:  return %[[VAL_20]] : vector<4x2xi32>
 // CHECK-DAG:  }
-func.func @test_lower_vector_arm_neon_mixed_unroll(%lhs: vector<4x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<4x2xi32>) -> vector<4x2xi32> {
+func.func @vector_arm_neon_mixed_unroll(%lhs: vector<4x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<4x2xi32>) -> vector<4x2xi32> {
   %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
   %rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32>
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<2x8xi32> into vector<4x2xi32>
@@ -126,9 +126,9 @@ func.func @test_lower_vector_arm_neon_mixed_unroll(%lhs: vector<4x8xi8>, %rhs: v
 
 // -----
 
-// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(
+// CHECK-LABEL:   func.func @vector_arm_neon_unroll_incompatible_shape(
 // CHECK-DAG:  %[[result:.*]] = vector.contract
-func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x12xi8>, %rhs: vector<4x12xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
+func.func @vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x12xi8>, %rhs: vector<4x12xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
   %lhs_extsi = arith.extsi %lhs : vector<4x12xi8> to vector<4x12xi32>
   %rhs_extsi = arith.extsi %rhs : vector<4x12xi8> to vector<4x12xi32>
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32>
@@ -137,7 +137,7 @@ func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x1
 
 // -----
 
-// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_vecmat_unroll(
+// CHECK-LABEL:   func.func @vector_arm_neon_vecmat_unroll(
 // CHECK-SAME:  %[[VAL_0:.*]]: vector<8xi8>,
 // CHECK-SAME:  %[[VAL_1:.*]]: vector<8x8xi8>,
 // CHECK-SAME:  %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
@@ -190,7 +190,7 @@ func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x1
 // CHECK:  %[[VAL_49:.*]] = vector.insert_strided_slice %[[VAL_48]], %[[VAL_38]] {offsets = [6], strides = [1]} : vector<2xi32> into vector<8xi32>
 // CHECK:  return %[[VAL_49]] : vector<8xi32>
 // CHECK:  }
-func.func @test_lower_vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: vector<8x8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+func.func @vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: vector<8x8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
   %lhs_extsi= arith.extsi %lhs : vector<8xi8> to vector<8xi32>
   %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8xi32>, vector<8x8xi32> into vector<8xi32>
@@ -199,7 +199,7 @@ func.func @test_lower_vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: ve
 
 // -----
 
-// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(
+// CHECK-LABEL:   func.func @vector_arm_neon_vecmat_unroll_leading_dim(
 // CHECK-SAME:  %[[VAL_0:.*]]: vector<1x8xi8>,
 // CHECK-SAME:  %[[VAL_1:.*]]: vector<8x8xi8>,
 // CHECK-SAME:  %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
@@ -252,7 +252,7 @@ func.func @test_lower_vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: ve
 // CHECK:  %[[VAL_49:.*]] = vector.insert_strided_slice %[[VAL_48]], %[[VAL_38]] {offsets = [0, 6], strides = [1]} : vector<2xi32> into vector<1x8xi32>
 // CHECK:  return %[[VAL_49]] : vector<1x8xi32>
 // CHECK:  }
-func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8xi8>, %rhs: vector<8x8xi8>, %acc : vector<1x8xi32>) -> vector<1x8xi32> {
+func.func @vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8xi8>, %rhs: vector<8x8xi8>, %acc : vector<1x8xi32>) -> vector<1x8xi32> {
   %lhs_extsi= arith.extsi %lhs : vector<1x8xi8> to vector<1x8xi32>
   %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32>
@@ -261,9 +261,9 @@ func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8
 
 // -----
 
-// CHECK-LABEL: func.func @test_lower_vector_arm_neon_matvec
+// CHECK-LABEL: func.func @vector_arm_neon_matvec
 // CHECK-NOT: arm_neon.intr.smmla
-func.func @test_lower_vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+func.func @vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
   %rhs_extsi= arith.extsi %rhs : vector<8xi8> to vector<8xi32>
   %lhs_extsi = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8x8xi32>, vector<8xi32> into vector<8xi32>
@@ -273,7 +273,7 @@ func.func @test_lower_vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<
 
 // -----
 
-// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_k_unroll(
+// CHECK-LABEL:   func.func @vector_arm_neon_k_unroll(
 // CHECK-SAME: %[[VAL_0:.*]]: vector<2x16xi8>,
 // CHECK-SAME: %[[VAL_1:.*]]: vector<2x16xi4>,
 // CHECK-SAME: %[[VAL_2:.*]]: vector<2x2xi32>) -> vector<2x2xi32> {
@@ -283,16 +283,16 @@ func.func @test_lower_vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<
 // CHECK:  %[[VAL_6:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
 // CHECK:  %[[VAL_7:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
 // CHECK:  %[[VAL_8:.*]] = vector.shape_cast %[[VAL_2]] : vector<2x2xi32> to vector<4xi32>
-// CHECK:  %[[VAL_9:.*]] = arm_neon.intr.smmla %[[VAL_8]], %[[VAL_6]], %[[VAL_7]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[KACC_0:.*]] = arm_neon.intr.smmla %[[VAL_8]], %[[VAL_6]], %[[VAL_7]] : vector<16xi8> to vector<4xi32>
 // CHECK:  %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8>
 // CHECK:  %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_3]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8>
 // CHECK:  %[[VAL_12:.*]] = vector.shape_cast %[[VAL_10]] : vector<2x8xi8> to vector<16xi8>
 // CHECK:  %[[VAL_13:.*]] = vector.shape_cast %[[VAL_11]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:  %[[VAL_14:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_12]], %[[VAL_13]] : vector<16xi8> to vector<4xi32>
-// CHECK:  %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[KACC_1:.*]] = arm_neon.intr.smmla %[[KACC_0]], %[[VAL_12]], %[[VAL_13]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_15:.*]] = vector.shape_cast %[[KACC_1]] : vector<4xi32> to vector<2x2xi32>
 // CHECK:  return %[[VAL_15]] : vector<2x2xi32>
 // CHECK:  }
-func.func @test_lower_vector_arm_neon_k_unroll(%lhs: vector<2x16xi8>, %rhs: vector<2x16xi4>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+func.func @vector_arm_neon_k_unroll(%lhs: vector<2x16xi8>, %rhs: vector<2x16xi4>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
   %lhs_extsi = arith.extsi %lhs : vector<2x16xi8> to vector<2x16xi32>
   %rhs_extsi = arith.extsi %rhs : vector<2x16xi4> to vector<2x16xi32>
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x16xi32>, vector<2x16xi32> into vector<2x2xi32>
@@ -301,48 +301,56 @@ func.func @test_lower_vector_arm_neon_k_unroll(%lhs: vector<2x16xi8>, %rhs: vect
 
 // -----
 
-// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_k_unroll_vecmat(
-// CHECK-SAME:                                                          %[[VAL_0:.*]]: vector<1x24xi8>,
-// CHECK-SAME:                                                          %[[VAL_1:.*]]: vector<2x24xi4>,
-// CHECK-SAME:                                                          %[[VAL_2:.*]]: vector<1x2xi32>) -> vector<1x2xi32> {
-// CHECK:           %[[VAL_3:.*]] = arith.constant dense<0> : vector<2x2xi32>
-// CHECK:           %[[VAL_4:.*]] = arith.constant dense<0> : vector<2x8xi8>
-// CHECK:           %[[VAL_5:.*]] = arith.constant dense<0> : vector<1x2xi32>
-// CHECK:           %[[VAL_6:.*]] = arith.extsi %[[VAL_1]] : vector<2x24xi4> to vector<2x24xi8>
-// CHECK:           %[[VAL_7:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [1, 8], strides = [1, 1]} : vector<1x24xi8> to vector<1x8xi8>
-// CHECK:           %[[VAL_8:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<2x24xi8> to vector<2x8xi8>
-// CHECK:           %[[VAL_9:.*]] = vector.insert_strided_slice %[[VAL_7]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
-// CHECK:           %[[VAL_10:.*]] = vector.insert_strided_slice %[[VAL_2]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
-// CHECK:           %[[VAL_11:.*]] = vector.shape_cast %[[VAL_9]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:           %[[VAL_12:.*]] = vector.shape_cast %[[VAL_8]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:           %[[VAL_13:.*]] = vector.shape_cast %[[VAL_10]] : vector<2x2xi32> to vector<4xi32>
-// CHECK:           %[[VAL_14:.*]] = arm_neon.intr.smmla %[[VAL_13]], %[[VAL_11]], %[[VAL_12]] : vector<16xi8> to vector<4xi32>
-// CHECK:           %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<4xi32> to vector<2x2xi32>
-// CHECK:           %[[VAL_16:.*]] = vector.extract %[[VAL_15]][0] : vector<2xi32> from vector<2x2xi32>
-// CHECK:           %[[VAL_17:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[VAL_5]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32>
-// CHECK:           %[[VAL_18:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 8], sizes = [1, 8], strides = [1, 1]} : vector<1x24xi8> to vector<1x8xi8>
-// CHECK:           %[[VAL_19:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x24xi8> to vector<2x8xi8>
-// CHECK:           %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_18]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
-// CHECK:           %[[VAL_21:.*]] = vector.shape_cast %[[VAL_20]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:           %[[VAL_22:.*]] = vector.shape_cast %[[VAL_19]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:           %[[VAL_23:.*]] = arm_neon.intr.smmla %[[VAL_14]], %[[VAL_21]], %[[VAL_22]] : vector<16xi8> to vector<4xi32>
-// CHECK:           %[[VAL_24:.*]] = vector.shape_cast %[[VAL_23]] : vector<4xi32> to vector<2x2xi32>
-// CHECK:           %[[VAL_25:.*]] = vector.extract %[[VAL_24]][0] : vector<2xi32> from vector<2x2xi32>
-// CHECK:           %[[VAL_26:.*]] = vector.insert_strided_slice %[[VAL_25]], %[[VAL_17]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32>
-// CHECK:           %[[VAL_27:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 16], sizes = [1, 8], strides = [1, 1]} : vector<1x24xi8> to vector<1x8xi8>
-// CHECK:           %[[VAL_28:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 16], sizes = [2, 8], strides = [1, 1]} : vector<2x24xi8> to vector<2x8xi8>
-// CHECK:           %[[VAL_29:.*]] = vector.insert_strided_slice %[[VAL_27]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
-// CHECK:           %[[VAL_30:.*]] = vector.shape_cast %[[VAL_29]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:           %[[VAL_31:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:           %[[VAL_32:.*]] = arm_neon.intr.smmla %[[VAL_23]], %[[VAL_30]], %[[VAL_31]] : vector<16xi8> to vector<4xi32>
-// CHECK:           %[[VAL_33:.*]] = vector.shape_cast %[[VAL_32]] : vector<4xi32> to vector<2x2xi32>
-// CHECK:           %[[VAL_34:.*]] = vector.extract %[[VAL_33]][0] : vector<2xi32> from vector<2x2xi32>
-// CHECK:           %[[VAL_35:.*]] = vector.insert_strided_slice %[[VAL_34]], %[[VAL_26]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32>
-// CHECK:           return %[[VAL_35]] : vector<1x2xi32>
-// CHECK:         }
-func.func @test_lower_vector_arm_neon_k_unroll_vecmat(%lhs: vector<1x24xi8>, %rhs: vector<2x24xi4>, %acc : vector<1x2xi32>) -> vector<1x2xi32> {
-  %lhs_extsi = arith.extsi %lhs : vector<1x24xi8> to vector<1x24xi32>
-  %rhs_extsi = arith.extsi %rhs : vector<2x24xi4> to vector<2x24xi32>
-  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x24xi32>, vector<2x24xi32> into vector<1x2xi32>
+// CHECK-LABEL:   func.func @vector_arm_neon_k_unroll_vecmat(
+// CHECK-SAME:                                               %[[VAL_0:.*]]: vector<1x32xi8>,
+// CHECK-SAME:                                               %[[VAL_1:.*]]: vector<2x32xi4>,
+// CHECK-SAME:                                               %[[VAL_2:.*]]: vector<1x2xi32>) -> vector<1x2xi32> {
+// CHECK:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<2x2xi32>
+// CHECK:  %[[VAL_4:.*]] = arith.constant dense<0> : vector<2x8xi8>
+// CHECK:  %[[VAL_5:.*]] = arith.constant dense<0> : vector<1x2xi32>
+// CHECK:  %[[VAL_6:.*]] = arith.extsi %[[VAL_1]] : vector<2x32xi4> to vector<2x32xi8>
+// CHECK:  %[[VAL_7:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [1, 8], strides = [1, 1]} : vector<1x32xi8> to vector<1x8xi8>
+// CHECK:  %[[VAL_8:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<2x32xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_9:.*]] = vector.insert_strided_slice %[[VAL_7]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK:  %[[VAL_10:.*]] = vector.insert_strided_slice %[[VAL_2]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK:  %[[VAL_11:.*]] = vector.shape_cast %[[VAL_9]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_12:.*]] = vector.shape_cast %[[VAL_8]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_13:.*]] = vector.shape_cast %[[VAL_10]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[KACC_0:.*]] = arm_neon.intr.smmla %[[VAL_13]], %[[VAL_11]], %[[VAL_12]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_15:.*]] = vector.shape_cast %[[KACC_0]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_16:.*]] = vector.extract %[[VAL_15]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_17:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[VAL_5]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32>
+// CHECK:  %[[VAL_18:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 8], sizes = [1, 8], strides = [1, 1]} : vector<1x32xi8> to vector<1x8xi8>
+// CHECK:  %[[VAL_19:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x32xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_18]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK:  %[[VAL_21:.*]] = vector.shape_cast %[[VAL_20]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_22:.*]] = vector.shape_cast %[[VAL_19]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[KACC_1:.*]] = arm_neon.intr.smmla %[[KACC_0]], %[[VAL_21]], %[[VAL_22]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_24:.*]] = vector.shape_cast %[[KACC_1]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_25:.*]] = vector.extract %[[VAL_24]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_26:.*]] = vector.insert_strided_slice %[[VAL_25]], %[[VAL_17]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32>
+// CHECK:  %[[VAL_27:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 16], sizes = [1, 8], strides = [1, 1]} : vector<1x32xi8> to vector<1x8xi8>
+// CHECK:  %[[VAL_28:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 16], sizes = [2, 8], strides = [1, 1]} : vector<2x32xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_29:.*]] = vector.insert_strided_slice %[[VAL_27]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK:  %[[VAL_30:.*]] = vector.shape_cast %[[VAL_29]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_31:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[KACC_2:.*]] = arm_neon.intr.smmla %[[KACC_1]], %[[VAL_30]], %[[VAL_31]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_33:.*]] = vector.shape_cast %[[KACC_2]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_34:.*]] = vector.extract %[[VAL_33]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_35:.*]] = vector.insert_strided_slice %[[VAL_34]], %[[VAL_26]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32>
+// CHECK:  %[[VAL_36:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 24], sizes = [1, 8], strides = [1, 1]} : vector<1x32xi8> to vector<1x8xi8>
+// CHECK:  %[[VAL_37:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 24], sizes = [2, 8], strides = [1, 1]} : vector<2x32xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_38:.*]] = vector.insert_strided_slice %[[VAL_36]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK:  %[[VAL_39:.*]] = vector.shape_cast %[[VAL_38]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_40:.*]] = vector.shape_cast %[[VAL_37]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[KACC_3:.*]] = arm_neon.intr.smmla %[[KACC_2]], %[[VAL_39]], %[[VAL_40]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_42:.*]] = vector.shape_cast %[[KACC_3]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_43:.*]] = vector.extract %[[VAL_42]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_44:.*]] = vector.insert_strided_slice %[[VAL_43]], %[[VAL_35]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32>
+// CHECK:  return %[[VAL_44]] : vector<1x2xi32>
+func.func @vector_arm_neon_k_unroll_vecmat(%lhs: vector<1x32xi8>, %rhs: vector<2x32xi4>, %acc : vector<1x2xi32>) -> vector<1x2xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<1x32xi8> to vector<1x32xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<2x32xi4> to vector<2x32xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x32xi32>, vector<2x32xi32> into vector<1x2xi32>
   return %res : vector<1x2xi32>
 }



More information about the Mlir-commits mailing list