[Mlir-commits] [mlir] [mlir][vector] Disable `vector.matrix_multiply` for scalable vectors (PR #102573)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Aug 9 02:24:44 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/102573

>From c0334a95251ca6ef700647a85e5de96c1c2cd12e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 9 Aug 2024 06:58:58 +0100
Subject: [PATCH 1/2] [mlir][vector] Disable `vector.matrix_multiply` for
 scalable vectors

Disables `vector.matrix_multiply` for scalable vectors. As per the docs:

>  This is the counterpart of llvm.matrix.multiply in MLIR

I'm not aware of any use of the matrix-multiply intrinsics in the context of
scalable vectors, hence this patch disables scalable vectors for this Op.
---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 11 +++++++----
 mlir/include/mlir/IR/CommonTypeConstraints.td    |  8 ++++++++
 mlir/test/Dialect/Vector/invalid.mlir            | 13 +++++++++++++
 3 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index bc97a5ae7d2f70..b8559efda13e99 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2683,13 +2683,13 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
                     TCresVTEtIsSameAsOpBase<0, 1>>]>,
       Arguments<(
         // TODO: tighten vector element types that make sense.
-        ins VectorOfRankAndType<[1],
+        ins FixedVectorOfRankAndType<[1],
               [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs,
-            VectorOfRankAndType<[1],
+            FixedVectorOfRankAndType<[1],
               [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs,
             I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
       Results<(
-        outs VectorOfRankAndType<[1],
+        outs FixedVectorOfRankAndType<[1],
                [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)>
 {
   let summary = "Vector matrix multiplication op that operates on flattened 1-D"
@@ -2707,7 +2707,10 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
     <rhs_columns> and multiplies them. The result matrix is returned embedded in
     the result vector.
 
-    Also see:
+    Note, the semantics of the corresponding LLVM intrinsic,
+    `@llvm.matrix.multiply.*`, are not clear in the context of scalable
+    vectors. Hence, this Op is only available for fixed-width vectors. Also
+    see:
 
     http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 5b6ec167fa2420..2eec2c6073bbf2 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -494,6 +494,14 @@ class VectorOfRankAndType<list<int> allowedRanks,
   VectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
   "::mlir::VectorType">;
 
+// Fixed-width vector where the rank is from the given `allowedRanks` list and
+// the type is from the given `allowedTypes` list
+class FixedVectorOfRankAndType<list<int> allowedRanks,
+                          list<Type> allowedTypes> : AllOfType<
+  [FixedVectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
+  FixedVectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
+  "::mlir::VectorType">;
+
 // Whether the number of elements of a vector is from the given
 // `allowedLengths` list
 class IsVectorOfLengthPred<list<int> allowedLengths> :
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ba1efe8b3c2d38..6e077a2fb4cee4 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1862,3 +1862,16 @@ func.func @invalid_step_2d() {
   vector.step : vector<2x4xf32>
   return
 }
+
+// -----
+
+func.func @matrix_matmul_scalable(%a: vector<[4]xf64>, %b: vector<4xf64>) {
+  // expected-error @+1 {{'vector.matrix_multiply' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[4]xf64>'}}
+  %c = vector.matrix_multiply %a, %b {
+    lhs_rows = 2: i32,
+    lhs_columns = 2: i32 ,
+    rhs_columns = 2: i32 }
+  : (vector<[4]xf64>, vector<4xf64>) -> vector<4xf64>
+
+  return
+}

>From e6d8031e6d378f14be015d2d7dd33e8343427cf7 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 9 Aug 2024 10:24:19 +0100
Subject: [PATCH 2/2] fixup! [mlir][vector] Disable `vector.matrix_multiply`
 for scalable vectors

Include suggestions from Cullen
---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 7 +++----
 mlir/include/mlir/IR/CommonTypeConstraints.td    | 2 +-
 mlir/test/Dialect/Vector/invalid.mlir            | 2 +-
 3 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b8559efda13e99..a2a317109e29d8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2707,10 +2707,9 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
     <rhs_columns> and multiplies them. The result matrix is returned embedded in
     the result vector.
 
-    Note, the semantics of the corresponding LLVM intrinsic,
-    `@llvm.matrix.multiply.*`, are not clear in the context of scalable
-    vectors. Hence, this Op is only available for fixed-width vectors. Also
-    see:
+    Note, the corresponding LLVM intrinsic, `@llvm.matrix.multiply.*`, does not
+    support scalable vectors. Hence, this Op is only available for fixed-width
+    vectors. Also see:
 
     http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 2eec2c6073bbf2..2493f212a356a4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -600,7 +600,7 @@ class VectorOfLengthAndType<list<int> allowedLengths,
 // Any fixed-length vector where the number of elements is from the given
 // `allowedLengths` list and the type is from the given `allowedTypes` list
 class FixedVectorOfLengthAndType<list<int> allowedLengths,
-                                    list<Type> allowedTypes> : AllOfType<
+                                 list<Type> allowedTypes> : AllOfType<
   [FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
   FixedVectorOf<allowedTypes>.summary #
   FixedVectorOfLength<allowedLengths>.summary,
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 6e077a2fb4cee4..c95b8bd5ed6147 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1865,7 +1865,7 @@ func.func @invalid_step_2d() {
 
 // -----
 
-func.func @matrix_matmul_scalable(%a: vector<[4]xf64>, %b: vector<4xf64>) {
+func.func @matrix_multiply_scalable(%a: vector<[4]xf64>, %b: vector<4xf64>) {
   // expected-error @+1 {{'vector.matrix_multiply' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[4]xf64>'}}
   %c = vector.matrix_multiply %a, %b {
     lhs_rows = 2: i32,



More information about the Mlir-commits mailing list