[Mlir-commits] [mlir] Update `LowerContractionToSMMLAPattern` to ignore matvec (PR #88288)
llvmlistbot at llvm.org
Wed Apr 10 09:11:40 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-neon
Author: Kojo Acquah (KoolJBlack)
<details>
<summary>Changes</summary>
Patterns in `LowerContractionToSMMLAPattern` are designed to handle vector-to-matrix multiplication but not matrix-to-vector. This leads to the following error when processing `rhs` with rank < 2:
```
iree-compile: /usr/local/google/home/kooljblack/code/iree-build/llvm-project/tools/mlir/include/mlir/IR/BuiltinTypeInterfaces.h.inc:268: int64_t mlir::detail::ShapedTypeTrait<mlir::VectorType>::getDimSize(unsigned int) const [ConcreteType = mlir::VectorType]: Assertion `idx < getRank() && "invalid index for shaped type"' failed.
```
This patch explicitly checks the `rhs` rank and fails the match for cases the pattern cannot process.
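
To illustrate why the early return avoids the assertion, here is a minimal standalone sketch that does not use MLIR. `ToyVectorType`, `ContractionDims`, and `inferDims` are hypothetical names invented for this example. They only model the rank check and the `getDimSize` calls visible in the diff, with `std::nullopt` standing in for `return failure();`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for mlir::VectorType: it carries only a shape.
struct ToyVectorType {
  std::vector<int64_t> shape;

  int64_t getRank() const { return static_cast<int64_t>(shape.size()); }

  int64_t getDimSize(unsigned idx) const {
    // Models the assertion from the error message: an index at or past
    // the rank aborts instead of returning a value.
    assert(idx < shape.size() && "invalid index for shaped type");
    return shape[idx];
  }
};

// Hypothetical holder for the three sizes the pattern reads.
struct ContractionDims {
  int64_t m, n, k;
};

// Returns std::nullopt (the analogue of `return failure();`) when the RHS
// has rank < 2, so getDimSize(1) is never called on a rank-1 vector.
std::optional<ContractionDims> inferDims(const ToyVectorType &lhs,
                                         const ToyVectorType &rhs) {
  if (rhs.getRank() < 2)
    return std::nullopt;
  int64_t m = lhs.getRank() == 1 ? 1 : lhs.getDimSize(0);
  return ContractionDims{m, rhs.getDimSize(0), rhs.getDimSize(1)};
}
```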
---
Full diff: https://github.com/llvm/llvm-project/pull/88288.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+2)
- (modified) mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir (+11)
``````````diff
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 13740225749e46..efdaeeda4fec5d 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -54,6 +54,8 @@ class LowerContractionToSMMLAPattern
// Note: RHS is not transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
+ if (rhsType.getRank() < 2)
+ return failure();
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
auto dimK = rhsType.getDimSize(1);
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index 46c4026d13b660..c276a5b0c2a14b 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -258,3 +258,14 @@ func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32>
return %res : vector<1x8xi32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_matvec
+// CHECK-NOT: arm_neon.intr.smmla
+func.func @test_lower_vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+ %rhs_extsi= arith.extsi %rhs : vector<8xi8> to vector<8xi32>
+ %lhs_extsi = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8x8xi32>, vector<8xi32> into vector<8xi32>
+ return %res : vector<8xi32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/88288