[Mlir-commits] [mlir] 04bf1a4 - Update `LowerContractionToSMMLAPattern` to ignore matvec (#88288)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 10 10:18:51 PDT 2024
Author: Kojo Acquah
Date: 2024-04-10T13:18:47-04:00
New Revision: 04bf1a4090c535e3a1033ab9a8ef92068166461f
URL: https://github.com/llvm/llvm-project/commit/04bf1a4090c535e3a1033ab9a8ef92068166461f
DIFF: https://github.com/llvm/llvm-project/commit/04bf1a4090c535e3a1033ab9a8ef92068166461f.diff
LOG: Update `LowerContractionToSMMLAPattern` to ignore matvec (#88288)
Patterns in `LowerContractionToSMMLAPattern` are designed to handle
vector-to-matrix multiplication but not matrix-to-vector. This leads to
the following error when processing `rhs` with rank < 2:
```
iree-compile: /usr/local/google/home/kooljblack/code/iree-build/llvm-project/tools/mlir/include/mlir/IR/BuiltinTypeInterfaces.h.inc:268: int64_t mlir::detail::ShapedTypeTrait<mlir::VectorType>::getDimSize(unsigned int) const [ConcreteType = mlir::VectorType]: Assertion `idx < getRank() && "invalid index for shaped type"' failed.
```
This change updates the pattern to explicitly check the rhs rank and to
fail on cases it cannot process.
Added:
Modified:
mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 13740225749e46..3ae894692089b3 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -54,6 +54,9 @@ class LowerContractionToSMMLAPattern
// Note: RHS is not transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
+ // Avoid 0-D vectors and 1-D rhs:
+ if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
+ return failure();
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
auto dimK = rhsType.getDimSize(1);
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index 46c4026d13b660..c276a5b0c2a14b 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -258,3 +258,14 @@ func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32>
return %res : vector<1x8xi32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_matvec
+// CHECK-NOT: arm_neon.intr.smmla
+func.func @test_lower_vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+ %rhs_extsi= arith.extsi %rhs : vector<8xi8> to vector<8xi32>
+ %lhs_extsi = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8x8xi32>, vector<8xi32> into vector<8xi32>
+ return %res : vector<8xi32>
+}
More information about the Mlir-commits
mailing list