[Mlir-commits] [mlir] 04bf1a4 - Update `LowerContractionToSMMLAPattern` to ignore matvec (#88288)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 10 10:18:51 PDT 2024


Author: Kojo Acquah
Date: 2024-04-10T13:18:47-04:00
New Revision: 04bf1a4090c535e3a1033ab9a8ef92068166461f

URL: https://github.com/llvm/llvm-project/commit/04bf1a4090c535e3a1033ab9a8ef92068166461f
DIFF: https://github.com/llvm/llvm-project/commit/04bf1a4090c535e3a1033ab9a8ef92068166461f.diff

LOG: Update `LowerContractionToSMMLAPattern` to ignore matvec (#88288)

Patterns in `LowerContractionToSMMLAPattern` are designed to handle
vector-matrix multiplication (vecmat) but not matrix-vector
multiplication (matvec). The pattern reads dimensions 0 and 1 of the
`rhs` type unconditionally, which leads to the following error when
processing an `rhs` with rank < 2:

```
iree-compile: /usr/local/google/home/kooljblack/code/iree-build/llvm-project/tools/mlir/include/mlir/IR/BuiltinTypeInterfaces.h.inc:268: int64_t mlir::detail::ShapedTypeTrait<mlir::VectorType>::getDimSize(unsigned int) const [ConcreteType = mlir::VectorType]: Assertion `idx < getRank() && "invalid index for shaped type"' failed.
```

Update the pattern to explicitly check the rhs rank and return failure
for cases it cannot process.
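
The guard can be illustrated outside of MLIR with a minimal, standalone
sketch. Everything here is hypothetical: `Shape` is a stand-in for
`mlir::VectorType`, `extractDims` stands in for the start of the pattern's
match logic, and `std::nullopt` plays the role of `return failure();`.
Rejecting a 0-D lhs by rank is an assumption about the intent of the
"Avoid 0-D vectors" comment in the patch, not a copy of the actual check.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for mlir::VectorType: only a list of dimension sizes.
struct Shape {
  std::vector<int64_t> dims;

  int64_t getRank() const { return static_cast<int64_t>(dims.size()); }

  int64_t getDimSize(unsigned idx) const {
    // Same kind of bounds assertion as the one quoted in the log message.
    assert(idx < dims.size() && "invalid index for shaped type");
    return dims[idx];
  }
};

// The three sizes the pattern derives from the operand types.
struct MNK {
  int64_t m, n, k;
};

// Returns std::nullopt for operand shapes the sketch cannot handle,
// instead of running into the assertion in getDimSize.
std::optional<MNK> extractDims(const Shape &lhs, const Shape &rhs) {
  // Assumed guard: reject a 0-D lhs and any rhs with fewer than two dims.
  if (lhs.getRank() < 1 || rhs.getRank() < 2)
    return std::nullopt;
  // A 1-D lhs is treated as a single row (vecmat), as in the patched code.
  int64_t m = lhs.getRank() == 1 ? 1 : lhs.getDimSize(0);
  // The rhs is not transposed, so dim 0 is N and dim 1 is K.
  return MNK{m, rhs.getDimSize(0), rhs.getDimSize(1)};
}
```

With this sketch, an 8x8 lhs paired with a 1-D rhs of size 8 (the matvec
shape from the new test) yields no result, while a 1x8 lhs with an 8x8 rhs
yields M = 1, N = 8, K = 8.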

Added: 
    

Modified: 
    mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
    mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 13740225749e46..3ae894692089b3 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -54,6 +54,9 @@ class LowerContractionToSMMLAPattern
     // Note: RHS is not transposed.
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
+    // Avoid 0-D vectors and 1-D rhs:
+    if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
+      return failure();
     auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
     auto dimK = rhsType.getDimSize(1);

diff  --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index 46c4026d13b660..c276a5b0c2a14b 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -258,3 +258,14 @@ func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32>
   return %res : vector<1x8xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_matvec
+// CHECK-NOT: arm_neon.intr.smmla
+func.func @test_lower_vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+  %rhs_extsi= arith.extsi %rhs : vector<8xi8> to vector<8xi32>
+  %lhs_extsi = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8x8xi32>, vector<8xi32> into vector<8xi32>
+  return %res : vector<8xi32>
+}


        

