[Mlir-commits] [mlir] [mlir][ArmNeon] Update `LowerContractionToSMMLAPattern` to support proper unrolling for k dimension (PR #88591)
Kojo Acquah
llvmlistbot at llvm.org
Mon Apr 15 16:46:14 PDT 2024
https://github.com/KoolJBlack updated https://github.com/llvm/llvm-project/pull/88591
>From eed2afdba18328bd4c7d24d863c4ac4118e37b6b Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Fri, 12 Apr 2024 23:02:41 +0000
Subject: [PATCH] Implement proper unrolling for the K dimension
---
.../LowerContractionToSMMLAPattern.cpp | 11 ++-
.../Dialect/ArmNeon/lower-to-arm-neon.mlir | 77 +++++++++++++++++++
2 files changed, 87 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 3ae894692089b3..abe7f216533d7a 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -133,8 +133,12 @@ class LowerContractionToSMMLAPattern
smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
loopOrder.push_back(2);
}
+
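+ // Keep track of the previous accumulator when tiling over K (NOTE(review): assumes the K tiles of each M/N tile are visited consecutively in loopOrder — confirm).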
+ Value kAcc;
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+ auto kTileIndex = offsets[offsets.size() - 1] / 8;
// Helper to compute the new shape of each operand and extract the slice.
auto extractOperand = [&](Value operand, AffineMap permutationMap,
ArrayRef<int64_t> operandOffsets) {
@@ -197,16 +201,21 @@ class LowerContractionToSMMLAPattern
auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
+ if (kTileIndex != 0) {
+ collapsedRes = kAcc;
+ }
+
// Insert contract op
auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
collapsedRhs);
+ kAcc = smmlaOp;
// Reshape output back to 2D
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
- // With vecmat, only one row of tiled ACC can be inserted inot file result
+ // With vecmat, only one row of tiled ACC can be inserted into the final result
if (isVecmat) {
tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
}
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index c276a5b0c2a14b..b70ca36c2d7f60 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -269,3 +269,80 @@ func.func @test_lower_vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<
%res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8x8xi32>, vector<8xi32> into vector<8xi32>
return %res : vector<8xi32>
}
+
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_k_unroll(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2x16xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<2x16xi4>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<2x2xi32>) -> vector<2x2xi32> {
+// CHECK: %[[VAL_3:.*]] = arith.extsi %[[VAL_1]] : vector<2x16xi4> to vector<2x16xi8>
+// CHECK: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_3]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_6:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_2]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_9:.*]] = arm_neon.intr.smmla %[[VAL_8]], %[[VAL_6]], %[[VAL_7]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_3]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_10]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_13:.*]] = vector.shape_cast %[[VAL_11]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_14:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_12]], %[[VAL_13]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: return %[[VAL_15]] : vector<2x2xi32>
+// CHECK: }
+func.func @test_lower_vector_arm_neon_k_unroll(%lhs: vector<2x16xi8>, %rhs: vector<2x16xi4>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+ %lhs_extsi = arith.extsi %lhs : vector<2x16xi8> to vector<2x16xi32>
+ %rhs_extsi = arith.extsi %rhs : vector<2x16xi4> to vector<2x16xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x16xi32>, vector<2x16xi32> into vector<2x2xi32>
+ return %res : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_k_unroll_vecmat(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x24xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<2x24xi4>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<1x2xi32>) -> vector<1x2xi32> {
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<2x2xi32>
+// CHECK: %[[VAL_4:.*]] = arith.constant dense<0> : vector<2x8xi8>
+// CHECK: %[[VAL_5:.*]] = arith.constant dense<0> : vector<1x2xi32>
+// CHECK: %[[VAL_6:.*]] = arith.extsi %[[VAL_1]] : vector<2x24xi4> to vector<2x24xi8>
+// CHECK: %[[VAL_7:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [1, 8], strides = [1, 1]} : vector<1x24xi8> to vector<1x8xi8>
+// CHECK: %[[VAL_8:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<2x24xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_9:.*]] = vector.insert_strided_slice %[[VAL_7]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_10:.*]] = vector.insert_strided_slice %[[VAL_2]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_9]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_8]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_13:.*]] = vector.shape_cast %[[VAL_10]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_14:.*]] = arm_neon.intr.smmla %[[VAL_13]], %[[VAL_11]], %[[VAL_12]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_15]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_17:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[VAL_5]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32>
+// CHECK: %[[VAL_18:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 8], sizes = [1, 8], strides = [1, 1]} : vector<1x24xi8> to vector<1x8xi8>
+// CHECK: %[[VAL_19:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x24xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_18]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_20]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_22:.*]] = vector.shape_cast %[[VAL_19]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_23:.*]] = arm_neon.intr.smmla %[[VAL_14]], %[[VAL_21]], %[[VAL_22]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_24:.*]] = vector.shape_cast %[[VAL_23]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_25:.*]] = vector.extract %[[VAL_24]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_26:.*]] = vector.insert_strided_slice %[[VAL_25]], %[[VAL_17]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32>
+// CHECK: %[[VAL_27:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 16], sizes = [1, 8], strides = [1, 1]} : vector<1x24xi8> to vector<1x8xi8>
+// CHECK: %[[VAL_28:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 16], sizes = [2, 8], strides = [1, 1]} : vector<2x24xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_29:.*]] = vector.insert_strided_slice %[[VAL_27]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[VAL_29]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_31:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_32:.*]] = arm_neon.intr.smmla %[[VAL_23]], %[[VAL_30]], %[[VAL_31]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_33:.*]] = vector.shape_cast %[[VAL_32]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_34:.*]] = vector.extract %[[VAL_33]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_35:.*]] = vector.insert_strided_slice %[[VAL_34]], %[[VAL_26]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32>
+// CHECK: return %[[VAL_35]] : vector<1x2xi32>
+// CHECK: }
+func.func @test_lower_vector_arm_neon_k_unroll_vecmat(%lhs: vector<1x24xi8>, %rhs: vector<2x24xi4>, %acc : vector<1x2xi32>) -> vector<1x2xi32> {
+ %lhs_extsi = arith.extsi %lhs : vector<1x24xi8> to vector<1x24xi32>
+ %rhs_extsi = arith.extsi %rhs : vector<2x24xi4> to vector<2x24xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x24xi32>, vector<2x24xi32> into vector<1x2xi32>
+ return %res : vector<1x2xi32>
+}
More information about the Mlir-commits mailing list