[Mlir-commits] [mlir] [mlir][ArmNeon] Updates LowerContractionToSMMLAPattern with vecmat unroll patterns (PR #86005)
Kojo Acquah
llvmlistbot at llvm.org
Wed Mar 20 14:09:25 PDT 2024
https://github.com/KoolJBlack created https://github.com/llvm/llvm-project/pull/86005
None
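For context, this change teaches LowerContractionToSMMLAPattern to also unroll vector-matrix (vecmat) contractions, where the LHS is rank 1 and dimM is implicitly 1. A minimal sketch of the kind of contract op this covers, adapted from the test added in this patch (the function name is illustrative):

  func.func @vecmat_example(%lhs: vector<8xi8>, %rhs: vector<8x8xi8>, %acc: vector<8xi32>) -> vector<8xi32> {
    // Rank-1 LHS: dimM is implicitly 1; dimN = 8 and dimK = 8 are multiples of the [2, 8] vecmat tile.
    %lhs_extsi = arith.extsi %lhs : vector<8xi8> to vector<8xi32>
    %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
    %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>,
                                             affine_map<(d0, d1) -> (d0, d1)>,
                                             affine_map<(d0, d1) -> (d0)>],
                            iterator_types = ["parallel", "reduction"],
                            kind = #vector.kind<add>}
        %lhs_extsi, %rhs_extsi, %acc : vector<8xi32>, vector<8x8xi32> into vector<8xi32>
    return %res : vector<8xi32>
  }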
From a6b07cc143a542288dd4b4bf83bc551c334dc76b Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Wed, 20 Mar 2024 20:56:08 +0000
Subject: [PATCH] implement vecmat unroll for i8mm
---
.../LowerContractionToSMMLAPattern.cpp | 66 ++++++++++++-------
.../Dialect/ArmNeon/lower-to-arm-neon.mlir | 60 +++++++++++++++++
2 files changed, 102 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 1f48d27aa27b17..a37bebc325c09d 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -49,32 +49,28 @@ class LowerContractionToSMMLAPattern
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- // Check index maps that represent M N K in contract.
- auto indexingMaps = op.getIndexingMapsArray();
- if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
- return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
- affineMap.getNumResults() != 2;
- })) {
- return failure();
- }
- // Check iterator types for contract.
- auto iteratorTypes = op.getIteratorTypesArray();
- if (iteratorTypes.size() != 3 ||
- iteratorTypes[0] != vector::IteratorType::parallel ||
- iteratorTypes[1] != vector::IteratorType::parallel ||
- iteratorTypes[2] != vector::IteratorType::reduction) {
- return failure();
- }
- // Infer tile sizes from operands; Note: RHS is not transposed.
+ // Infer tile sizes from operands. For vecmat, the LHS may have only one dim.
+ // Note: RHS is not transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
- auto dimM = lhsType.getDimSize(0);
+ auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
- auto dimK = lhsType.getDimSize(1);
-
+ auto dimK = rhsType.getDimSize(1);
+ bool isVecmat = dimM == 1;
+ if (lhsType.getDimSize(lhsType.getRank() - 1) !=
+ rhsType.getDimSize(rhsType.getRank() - 1)) {
+ return failure(); // dimK mismatch between LHS and RHS
+ }
// Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
// tiling.
- if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
+ if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
+ return failure();
+ }
+
+ // Check iterator types for contract.
+ auto iteratorTypes = op.getIteratorTypesArray();
+ if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
+ vector::IteratorType::reduction) {
return failure();
}
@@ -122,9 +118,12 @@ class LowerContractionToSMMLAPattern
SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
SmallVector<int64_t> smmlaShape{2, 2, 8};
SmallVector<int64_t> loopOrder{0, 1, 2};
+ if (isVecmat) {
+ smmlaShape = {2, 8};
+ loopOrder = {0, 1};
+ }
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
-
// Helper to compute the new shape of each operand and extract the slice.
auto extractOperand = [&](Value operand, AffineMap permutationMap,
ArrayRef<int64_t> operandOffsets) {
@@ -150,16 +149,30 @@ class LowerContractionToSMMLAPattern
Value tiledAcc =
extractOperand(op.getAcc(), accPermutationMap, accOffsets);
+ // With vecmat, the tiled LHS and ACC contain only one of the two rows
+ // needed along dimM. Broadcast both to the full tile width.
+ if (isVecmat) {
+ auto lhsBroadcastType = VectorType::get(
+ {2, 8}, tiledLhs.getType().cast<ShapedType>().getElementType());
+ tiledLhs = rewriter.create<vector::BroadcastOp>(loc, lhsBroadcastType,
+ tiledLhs);
+ auto accBroadcastType = VectorType::get(
+ {2, 2}, tiledAcc.getType().cast<ShapedType>().getElementType());
+ tiledAcc = rewriter.create<vector::BroadcastOp>(loc, accBroadcastType,
+ tiledAcc);
+ }
+
// Collapse tiled operands to 1D vectors required by smmla intrinsic
auto collapsedInputType = VectorType::get(
tiledLhs.getType().cast<ShapedType>().getNumElements(),
tiledLhs.getType().cast<ShapedType>().getElementType());
- auto collapsedOutputType = VectorType::get(
- {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledLhs.getLoc(), collapsedInputType, tiledLhs);
auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+ auto collapsedOutputType = VectorType::get(
+ tiledAcc.getType().cast<ShapedType>().getNumElements(),
+ tiledAcc.getType().cast<ShapedType>().getElementType());
auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
@@ -172,6 +185,11 @@ class LowerContractionToSMMLAPattern
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
+ // With vecmat, only one row of the tiled ACC can be inserted into the final result.
+ if (isVecmat) {
+ tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
+ }
+
// Insert the tiled result back into the non tiled result of the
// contract op.
SmallVector<int64_t> strides(
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index e2be87453bf6f2..0615d155bf4f6a 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -134,3 +134,63 @@ func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x1
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32>
return %res : vector<4x4xi32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_vecmat_unroll(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<8xi32>
+// CHECK: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_5]] : vector<2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_11:.*]] = arm_neon.intr.smmla %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_12]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_14:.*]] = vector.insert_strided_slice %[[VAL_13]], %[[VAL_3]] {offsets = [0], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_16:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_17:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_18:.*]] = vector.broadcast %[[VAL_16]] : vector<2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_18]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_22:.*]] = arm_neon.intr.smmla %[[VAL_21]], %[[VAL_19]], %[[VAL_20]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[VAL_22]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_23]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_25:.*]] = vector.insert_strided_slice %[[VAL_24]], %[[VAL_14]] {offsets = [2], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_26:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_27:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [4], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_28:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_29:.*]] = vector.broadcast %[[VAL_27]] : vector<2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_31:.*]] = vector.shape_cast %[[VAL_26]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_32:.*]] = vector.shape_cast %[[VAL_29]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_33:.*]] = arm_neon.intr.smmla %[[VAL_32]], %[[VAL_30]], %[[VAL_31]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_33]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_35:.*]] = vector.extract %[[VAL_34]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_36:.*]] = vector.insert_strided_slice %[[VAL_35]], %[[VAL_25]] {offsets = [4], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_37:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_38:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [6], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_39:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_40:.*]] = vector.broadcast %[[VAL_38]] : vector<2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_41:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_42:.*]] = vector.shape_cast %[[VAL_37]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_43:.*]] = vector.shape_cast %[[VAL_40]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_44:.*]] = arm_neon.intr.smmla %[[VAL_43]], %[[VAL_41]], %[[VAL_42]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_45:.*]] = vector.shape_cast %[[VAL_44]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_46:.*]] = vector.extract %[[VAL_45]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_47:.*]] = vector.insert_strided_slice %[[VAL_46]], %[[VAL_36]] {offsets = [6], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: return %[[VAL_47]] : vector<8xi32>
+// CHECK: }
+func.func @test_lower_vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: vector<8x8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+ %lhs_extsi = arith.extsi %lhs : vector<8xi8> to vector<8xi32>
+ %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8xi32>, vector<8x8xi32> into vector<8xi32>
+ return %res : vector<8xi32>
+}
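
For each [2, 8] vecmat tile, the unrolled IR looks roughly as follows (a sketch condensed from the CHECK lines above for the first tile; the SSA names are illustrative, with %lhs, %rhs, and %acc standing for the function arguments and %init for the zero-filled vector<8xi32> result). The rank-1 LHS and the 2-element accumulator slice are broadcast to the 2x8 and 2x2 shapes an smmla tile expects, collapsed to the 1-D vectors the intrinsic requires, and only row 0 of the 2x2 result is inserted back:

  // Slice one [2, 8] tile of the RHS and the matching 2 accumulator lanes.
  %rhs_tile  = vector.extract_strided_slice %rhs {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
  %acc_tile  = vector.extract_strided_slice %acc {offsets = [0], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
  // Broadcast the rank-1 LHS and the accumulator slice to the 2x8 and 2x2 tile shapes.
  %lhs_bcast = vector.broadcast %lhs : vector<8xi8> to vector<2x8xi8>
  %acc_bcast = vector.broadcast %acc_tile : vector<2xi32> to vector<2x2xi32>
  // Collapse to 1-D, run smmla, then keep only row 0 of the 2x2 result.
  %lhs_flat  = vector.shape_cast %lhs_bcast : vector<2x8xi8> to vector<16xi8>
  %rhs_flat  = vector.shape_cast %rhs_tile : vector<2x8xi8> to vector<16xi8>
  %acc_flat  = vector.shape_cast %acc_bcast : vector<2x2xi32> to vector<4xi32>
  %mm        = arm_neon.intr.smmla %acc_flat, %lhs_flat, %rhs_flat : vector<16xi8> to vector<4xi32>
  %mm_2d     = vector.shape_cast %mm : vector<4xi32> to vector<2x2xi32>
  %row0      = vector.extract %mm_2d[0] : vector<2xi32> from vector<2x2xi32>
  %res_0     = vector.insert_strided_slice %row0, %init {offsets = [0], strides = [1]} : vector<2xi32> into vector<8xi32>

The remaining three tiles repeat the same sequence at offsets 2, 4, and 6 along dimN.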