[Mlir-commits] [mlir] [mlir][ArmNeon] Update `LowerContractionToSMMLAPattern` to support proper unrolling for k dimension (PR #88591)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Apr 16 08:15:39 PDT 2024


================
@@ -269,3 +269,80 @@ func.func @test_lower_vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8x8xi32>, vector<8xi32> into vector<8xi32>
   return %res : vector<8xi32>
 }
+
+
+// -----
+
+// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_k_unroll(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2x16xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<2x16xi4>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<2x2xi32>) -> vector<2x2xi32> {
+// CHECK:  %[[VAL_3:.*]] = arith.extsi %[[VAL_1]] : vector<2x16xi4> to vector<2x16xi8>
+// CHECK:  %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_3]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_6:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_7:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_8:.*]] = vector.shape_cast %[[VAL_2]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_9:.*]] = arm_neon.intr.smmla %[[VAL_8]], %[[VAL_6]], %[[VAL_7]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_3]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_12:.*]] = vector.shape_cast %[[VAL_10]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_13:.*]] = vector.shape_cast %[[VAL_11]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_14:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_12]], %[[VAL_13]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  return %[[VAL_15]] : vector<2x2xi32>
+// CHECK:  }
----------------
banach-space wrote:

[nit] Not needed

https://github.com/llvm/llvm-project/pull/88591
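The CHECK lines in the test above encode how a contraction with K = 16 is split into two SMMLA operations over K = 8 slices. The second `smmla` accumulates into the result of the first.

The Python model below sketches why that unrolling is equivalent to the full contraction. It assumes that `smmla` computes `acc[i][j] + sum_k lhs[i][k] * rhs[j][k]` on row-major 2x8 operands and a row-major 2x2 accumulator. The helper names (`smmla`, `extract_k_slice`, `lower_k_unrolled`, `reference`) are hypothetical and are not MLIR APIs.

```python
from typing import List

Matrix = List[List[int]]


def smmla(acc: List[int], lhs: List[int], rhs: List[int]) -> List[int]:
    """Assumed semantics of arm_neon.intr.smmla on flattened operands.

    acc: 4 int32 values (row-major 2x2 accumulator).
    lhs: 16 int8 values (row-major 2x8, rows indexed by m).
    rhs: 16 int8 values (row-major 2x8, rows indexed by n).
    int32 wraparound is not modeled; the small inputs used here cannot overflow.
    """
    assert len(acc) == 4 and len(lhs) == 16 and len(rhs) == 16
    out = list(acc)
    for i in range(2):
        for j in range(2):
            out[i * 2 + j] += sum(lhs[i * 8 + k] * rhs[j * 8 + k] for k in range(8))
    return out


def extract_k_slice(mat: Matrix, k0: int, width: int = 8) -> Matrix:
    # Mirrors vector.extract_strided_slice with offsets = [0, k0], sizes = [2, width].
    return [row[k0:k0 + width] for row in mat]


def flatten(mat: Matrix) -> List[int]:
    # Mirrors vector.shape_cast from vector<2xNxi*> to a 1-D vector.
    return [x for row in mat for x in row]


def lower_k_unrolled(lhs: Matrix, rhs: Matrix, acc: Matrix) -> Matrix:
    # rhs is assumed to already hold i4 values sign-extended to i8 (as arith.extsi
    # does in the test), so its entries lie in [-8, 7].
    k_total = len(lhs[0])
    assert k_total % 8 == 0, "K must be a multiple of the SMMLA K-tile (8)"
    flat_acc = flatten(acc)
    # Each K-slice feeds one smmla; its result becomes the next accumulator.
    for k0 in range(0, k_total, 8):
        flat_acc = smmla(flat_acc,
                         flatten(extract_k_slice(lhs, k0)),
                         flatten(extract_k_slice(rhs, k0)))
    # Final vector.shape_cast from vector<4xi32> back to vector<2x2xi32>.
    return [flat_acc[0:2], flat_acc[2:4]]


def reference(lhs: Matrix, rhs: Matrix, acc: Matrix) -> Matrix:
    # Direct contraction: (m, k) x (n, k) -> (m, n), accumulated into acc.
    k_total = len(lhs[0])
    return [[acc[i][j] + sum(lhs[i][k] * rhs[j][k] for k in range(k_total))
             for j in range(2)] for i in range(2)]


if __name__ == "__main__":
    lhs = [[(3 * i + k) % 256 - 128 for k in range(16)] for i in range(2)]
    rhs = [[(5 * j - k) % 16 - 8 for k in range(16)] for j in range(2)]
    acc = [[1, 2], [3, 4]]
    assert lower_k_unrolled(lhs, rhs, acc) == reference(lhs, rhs, acc)
```

Chaining the accumulator through successive `smmla` calls is what keeps the unrolled form exact. Integer addition is associative, so summing the two K = 8 partial products in sequence gives the same 2x2 result as a single K = 16 reduction.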

