[Mlir-commits] [mlir] fe84369 - [mlir][ArmNeon] Implements unrolling patterns for LowerContractionToSMMLAPattern (#84848)
llvmlistbot at llvm.org
Tue Mar 19 10:09:38 PDT 2024
Author: Kojo Acquah
Date: 2024-03-19T13:09:33-04:00
New Revision: fe84369cc6759194e006f3f624a064bce13c84d4
URL: https://github.com/llvm/llvm-project/commit/fe84369cc6759194e006f3f624a064bce13c84d4
DIFF: https://github.com/llvm/llvm-project/commit/fe84369cc6759194e006f3f624a064bce13c84d4.diff
LOG: [mlir][ArmNeon] Implements unrolling patterns for LowerContractionToSMMLAPattern (#84848)
This patch updates `LowerContractionToSMMLAPattern` to unroll larger vector contracts into multiple smmla instructions.
The pattern now accepts any contract whose [M, N, K] shape is a multiple of the [2, 2, 8] smmla tile (previously only exactly [2, 2, 8]): M and N must be even and K a multiple of 8. `vector.extract_strided_slice`/`vector.insert_strided_slice` divide the contract into tiles, each lowered to one smmla instruction; an [8, 8, 8] contract, for example, unrolls into (8/2)*(8/2)*(8/8) = 16 smmla instructions.
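For illustration, here is a minimal sketch, not part of the patch, of how the `StaticTileOffsetRange` utility the pattern uses (from mlir/Dialect/Utils/IndexingUtils.h) walks the tile grid; the helper function is hypothetical, and the printed offsets match the [4, 4, 8] test case below:

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical demo: enumerate the [2, 2, 8] tile offsets the pattern
// visits for a [4, 4, 8] vector.contract.
static void printSmmlaTileOffsets() {
  llvm::SmallVector<int64_t> shape{4, 4, 8}; // [M, N, K] of the contract
  llvm::SmallVector<int64_t> tile{2, 2, 8};  // one smmla per tile
  llvm::SmallVector<int64_t> order{0, 1, 2}; // row-major traversal
  for (llvm::SmallVector<int64_t> offsets :
       mlir::StaticTileOffsetRange(shape, tile, order)) {
    // Prints 0, 0, 0 / 0, 2, 0 / 2, 0, 0 / 2, 2, 0: four smmla tiles.
    llvm::interleaveComma(offsets, llvm::outs());
    llvm::outs() << "\n";
  }
}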
Added:
Modified:
mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 47c84708f3c38b..1f48d27aa27b17 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -16,7 +16,9 @@
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -36,8 +38,10 @@ static Type matchContainerType(Type element, Type container) {
return element;
}
-/// Lowering from a single vector::contractOp directly to the arm neon smmla
-/// intrinsic. The shapes of the contract and intrinsic must match.
+/// Lowering from a vector::contractOp to the arm neon smmla intrinsic. This
+/// tiles any vector.contract into multiple smmla instructions via unrolling,
+/// so long as [2, 2, 8] divides its shape. If no unrolling is necessary, a
+/// single smmla instruction is emitted.
class LowerContractionToSMMLAPattern
: public OpRewritePattern<vector::ContractionOp> {
public:
@@ -45,10 +49,6 @@ class LowerContractionToSMMLAPattern
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- Value lhs = op.getLhs();
- Value rhs = op.getRhs();
- Value res = op.getAcc();
-
// Check index maps that represent M N K in contract.
auto indexingMaps = op.getIndexingMapsArray();
if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
@@ -57,7 +57,6 @@ class LowerContractionToSMMLAPattern
})) {
return failure();
}
-
// Check iterator types for contract.
auto iteratorTypes = op.getIteratorTypesArray();
if (iteratorTypes.size() != 3 ||
@@ -66,22 +65,24 @@ class LowerContractionToSMMLAPattern
iteratorTypes[2] != vector::IteratorType::reduction) {
return failure();
}
-
- // Check the tile size by mapping the dimensions of the contract.
+    // Infer the tile sizes from the operands; note that the RHS is not
+    // transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
auto dimM = lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
auto dimK = lhsType.getDimSize(1);
- if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {
+
+    // The unrolling pattern can handle any input whose shape is a multiple
+    // of the [2, 2, 8] tile.
+ if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
return failure();
}
// Check two extsi inputs Rhs Lhs for contract.
arith::ExtSIOp origLhsExtOp =
- dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp());
+ dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
arith::ExtSIOp origRhsExtOp =
- dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp());
+ dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
if (!origLhsExtOp || !origRhsExtOp) {
return failure();
}
@@ -113,26 +114,73 @@ class LowerContractionToSMMLAPattern
return failure();
}
- // Collapse to 1D vectors required by smmla intrinsic
- auto collapsedInputType = VectorType::get(
- {16}, extsiLhs.getType().cast<ShapedType>().getElementType());
- auto collapsedOutputType =
- VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
- auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
- extsiLhs.getLoc(), collapsedInputType, extsiLhs);
- auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
- extsiRhs.getLoc(), collapsedInputType, extsiRhs);
- auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
- res.getLoc(), collapsedOutputType, res);
-
- // Replace the contract with a neon op
- auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
- op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
- collapsedRhs);
-
- // Reshape output back to 2D
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
- smmlaOp);
+    // Initial accumulator for the final result; when tiling occurs, each
+    // tiled partial result is inserted into this full-size value.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
+
+ SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
+ SmallVector<int64_t> smmlaShape{2, 2, 8};
+ SmallVector<int64_t> loopOrder{0, 1, 2};
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+
+ // Helper to compute the new shape of each operand and extract the slice.
+ auto extractOperand = [&](Value operand, AffineMap permutationMap,
+ ArrayRef<int64_t> operandOffsets) {
+ SmallVector<int64_t> operandShape =
+ applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
+ SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
+ return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, operand, operandOffsets, operandShape, operandStrides);
+ };
+
+ // Extract tiled lhs, rhs, and acc
+ AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
+ SmallVector<int64_t> lhsOffsets =
+ applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+ Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
+ AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
+ SmallVector<int64_t> rhsOffsets =
+ applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
+ Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
+ AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
+ SmallVector<int64_t> accOffsets =
+ applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
+ Value tiledAcc =
+ extractOperand(op.getAcc(), accPermutationMap, accOffsets);
+
+ // Collapse tiled operands to 1D vectors required by smmla intrinsic
+ auto collapsedInputType = VectorType::get(
+ tiledLhs.getType().cast<ShapedType>().getNumElements(),
+ tiledLhs.getType().cast<ShapedType>().getElementType());
+ auto collapsedOutputType = VectorType::get(
+ {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
+ auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
+ tiledLhs.getLoc(), collapsedInputType, tiledLhs);
+ auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
+ tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+ auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
+ tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
+
+      // Emit the smmla intrinsic for this tile
+ auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
+ op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
+ collapsedRhs);
+
+ // Reshape output back to 2D
+ Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
+ smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
+
+      // Insert the tiled result back into the untiled result of the
+      // contract op.
+ SmallVector<int64_t> strides(
+ tiledRes.getType().cast<ShapedType>().getRank(), 1);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, tiledRes, result, accOffsets, strides);
+ }
+
+ rewriter.replaceOp(op, result);
return success();
}
};
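
For context, here is a minimal usage sketch, not part of this patch; it assumes the populate entry point declared in mlir/Dialect/ArmNeon/Transforms.h (name as in this tree; verify against your checkout) and the greedy rewrite driver:

#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: run the SMMLA lowering over `root`.
static mlir::LogicalResult lowerContractsToSMMLA(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
  // Greedily rewrites every matching vector.contract whose [M, N, K] shape
  // is a multiple of [2, 2, 8]; other contracts are left untouched.
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}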
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index cba7b00ba77a82..e2be87453bf6f2 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -40,3 +40,97 @@ func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs:
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
return %res : vector<2x2xi32>
}
+
+// -----
+
+// CHECK-LABEL: test_lower_vector_arm_neon_unroll
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>, %[[VAL_1:.*]]: vector<4x8xi8>, %[[VAL_2:.*]]: vector<4x4xi32>
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x4xi32>
+// CHECK-DAG: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK-DAG: %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_16:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[VAL_19:.*]] = arm_neon.intr.smmla %[[VAL_18]], %[[VAL_16]], %[[VAL_17]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_21:.*]] = vector.insert_strided_slice %[[VAL_20]], %[[VAL_12]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK-DAG: %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_23:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_24:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_25:.*]] = vector.shape_cast %[[VAL_22]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_26:.*]] = vector.shape_cast %[[VAL_23]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_27:.*]] = vector.shape_cast %[[VAL_24]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[VAL_28:.*]] = arm_neon.intr.smmla %[[VAL_27]], %[[VAL_25]], %[[VAL_26]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[VAL_29:.*]] = vector.shape_cast %[[VAL_28]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_21]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK-DAG: %[[VAL_31:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_32:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_33:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_35:.*]] = vector.shape_cast %[[VAL_32]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_36:.*]] = vector.shape_cast %[[VAL_33]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[VAL_37:.*]] = arm_neon.intr.smmla %[[VAL_36]], %[[VAL_34]], %[[VAL_35]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[VAL_38:.*]] = vector.shape_cast %[[VAL_37]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_39:.*]] = vector.insert_strided_slice %[[VAL_38]], %[[VAL_30]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK-DAG: return %[[VAL_39]] : vector<4x4xi32>
+// CHECK-DAG: }
+func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<4x8xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
+ %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %rhs_extsi = arith.extsi %rhs : vector<4x8xi8> to vector<4x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<4x8xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_mixed_unroll(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<2x8xi4>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<4x2xi32>) -> vector<4x2xi32> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x2xi32>
+// CHECK-DAG: %[[VAL_4:.*]] = arith.extsi %[[VAL_1]] : vector<2x8xi4> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x2xi32>
+// CHECK-DAG: %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_16:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[VAL_18:.*]] = arm_neon.intr.smmla %[[VAL_17]], %[[VAL_15]], %[[VAL_16]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_18]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_19]], %[[VAL_12]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x2xi32>
+// CHECK-DAG: return %[[VAL_20]] : vector<4x2xi32>
+// CHECK-DAG: }
+func.func @test_lower_vector_arm_neon_mixed_unroll(%lhs: vector<4x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<4x2xi32>) -> vector<4x2xi32> {
+ %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<2x8xi32> into vector<4x2xi32>
+ return %res : vector<4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(
+// CHECK-DAG: %[[result:.*]] = vector.contract
+func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x12xi8>, %rhs: vector<4x12xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
+ %lhs_extsi = arith.extsi %lhs : vector<4x12xi8> to vector<4x12xi32>
+ %rhs_extsi = arith.extsi %rhs : vector<4x12xi8> to vector<4x12xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}