[Mlir-commits] [mlir] [mlir][spirv] Add MatrixTimesVector Op (PR #122302)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 17 12:31:34 PST 2025


https://github.com/mishaobu updated https://github.com/llvm/llvm-project/pull/122302

>From 7d8dc65b6834c474c13bd496d877a712e9cf2297 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 16:16:42 +0100
Subject: [PATCH 1/7] Add Target/SPIRV tests for MatrixTimesVector

---
 mlir/test/Target/SPIRV/matrix.mlir | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

Notes: adds Target/SPIRV coverage for spirv.MatrixTimesVector, multiplying a
!spirv.matrix<3 x vector<4xf32>> by a vector<3xf32> to produce vector<4xf32>.
NOTE(review): @matrix_times_vector_2 differs from @matrix_times_vector_1 only
in the order of the function arguments (PATCH 7/7 removes it), and the blank
line added after it contains trailing whitespace.

diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index 2a391df4bff396..fde379af4b54ee 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -36,6 +36,20 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.ReturnValue %result : !spirv.matrix<2 x vector<3xf32>>
   }
 
+  // CHECK-LABEL: @matrix_times_vector_1
+  spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    spirv.ReturnValue %result : vector<4xf32>
+  }
+
+  // CHECK-LABEL: @matrix_times_vector_2
+  spirv.func @matrix_times_vector_2(%arg0: vector<3xf32>, %arg1: !spirv.matrix<3 x vector<4xf32>>) -> vector<4xf32> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    %result = spirv.MatrixTimesVector %arg1, %arg0 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    spirv.ReturnValue %result : vector<4xf32>
+  }
+  
   // CHECK-LABEL: @matrix_times_matrix_1
   spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
     // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>

>From 8bd529f0dabd28eaeb4e082d6df19ca49355b055 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 16:16:53 +0100
Subject: [PATCH 2/7] Add SPIR-V dialect tests for MatrixTimesVector

---
 mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir | 31 ++++++++++++++++++++++
 1 file changed, 31 insertions(+)

Notes: adds one positive spirv.MatrixTimesVector test and three negative
tests covering element-type, row-count and column-count mismatches.
NOTE(review): the row/column mismatch patterns start with a stray quote
(`spirv.MatrixTimesVector' op ...`); they still match as substrings of the
diagnostic, but dropping the op prefix (as the element-type test does) would
read more cleanly. The file ends without a newline; PATCH 6/7 adds it.

diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index 372fcc6e514b97..04220a702da5b3 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -29,6 +29,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
   }
 
+  // CHECK-LABEL: @matrix_times_vector_1
+  spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    spirv.ReturnValue %result : vector<4xf32>
+  }
+
   // CHECK-LABEL: @matrix_times_matrix_1
   spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
     // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
@@ -124,3 +131,27 @@ func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3
    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
    return
 }
+
+// -----
+
+func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf16>) {
+  // expected-error @+1 {{matrix, vector, and result element types must match}}
+  %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf16> -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @matrix_times_vector_row_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf32>) {
+  // expected-error @+1 {{spirv.MatrixTimesVector' op result size (4) must match the matrix rows (3)}}
+  %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf32> -> vector<4xf32>
+  return
+}
+
+// -----
+
+func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<3xf32>) {
+  // expected-error @+1 {{spirv.MatrixTimesVector' op matrix columns (4) must match vector operand size (3)}}
+  %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32>
+  return
+}
\ No newline at end of file

>From 8813b824a6c8b1fd402fe41611c0378b492bfa79 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 16:17:38 +0100
Subject: [PATCH 3/7] Add MatrixTimesVector op to TableGen

---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  2 ++
 .../mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td   | 41 +++++++++++++++++++
 2 files changed, 43 insertions(+)

Notes: the vector operand and the result use the existing
SPIRV_VectorOf<SPIRV_Float> constraint instead of a new SPIRV_AnyVector
type. This lets ODS enforce the "vector of floating-point type" rule from
the op description, and avoids declaring the builtin VectorType as a SPIR-V
DialectType.

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index a4c01c0bc3418d..469a9a0ef01dd2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4384,6 +4384,7 @@ def SPIRV_OC_OpFRem                         : I32EnumAttrCase<"OpFRem", 140>;
 def SPIRV_OC_OpFMod                         : I32EnumAttrCase<"OpFMod", 141>;
 def SPIRV_OC_OpVectorTimesScalar            : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
 def SPIRV_OC_OpMatrixTimesScalar            : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
+def SPIRV_OC_OpMatrixTimesVector            : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
 def SPIRV_OC_OpMatrixTimesMatrix            : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
 def SPIRV_OC_OpDot                          : I32EnumAttrCase<"OpDot", 148>;
 def SPIRV_OC_OpIAddCarry                    : I32EnumAttrCase<"OpIAddCarry", 149>;
@@ -4553,7 +4554,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
       SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
       SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
       SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
+      SPIRV_OC_OpMatrixTimesVector,
       SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
       SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
       SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index a6f0f41429bcbc..5bd99386e00858 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -114,6 +114,47 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
 
 // -----
 
+def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> {
+  let summary = "Linear-algebraic multiply of matrix X vector.";
+
+  let description = [{
+    Result Type must be a vector of floating-point type.
+
+    Matrix must be an OpTypeMatrix whose Column Type is Result Type.
+
+    Vector must be a vector with the same Component Type as the Component Type in Result Type. Its number of components must equal the number of columns in Matrix.
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.MatrixTimesVector %matrix, %vector :
+        !spirv.matrix<3 x vector<2xf32>>, vector<3xf32> -> vector<2xf32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[SPIRV_C_Matrix]>
+  ];
+
+  let arguments = (ins
+    SPIRV_AnyMatrix:$matrix,
+    SPIRV_VectorOf<SPIRV_Float>:$vector
+  );
+
+  let results = (outs
+    SPIRV_VectorOf<SPIRV_Float>:$result
+  );
+
+  let assemblyFormat = [{
+    operands attr-dict `:` type($matrix) `,` type($vector) `->` type($result)
+  }];
+}
+
+// -----
+
 def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
   let summary = "Transpose a matrix.";
 

>From 646004106b9bc6273b91eb8483d1f017353c102b Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 16:18:33 +0100
Subject: [PATCH 4/7] Implement MatrixTimesVectorOp::verify

---
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 29 ++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

Notes: MatrixTimesVectorOp::verify() checks, in order, that the matrix column
count equals the vector operand length, that the result length equals the
matrix row count, and that the matrix, vector and result element types match.
NOTE(review): there is no explicit floating-point element check here; this
presumably relies on spirv::MatrixType only admitting float columns - confirm.

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 26559c1321db5e..7eebd0c989fde3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1698,6 +1698,35 @@ LogicalResult spirv::TransposeOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.MatrixTimesVector
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::MatrixTimesVectorOp::verify() {
+  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
+  auto vectorType = llvm::cast<VectorType>(getVector().getType());
+  auto resultType = llvm::cast<VectorType>(getType());
+
+  if (matrixType.getNumColumns() != vectorType.getNumElements())
+    return emitOpError("matrix columns (")
+           << matrixType.getNumColumns()
+           << ") must match vector operand size ("
+           << vectorType.getNumElements() << ")";
+
+  if (resultType.getNumElements() != matrixType.getNumRows())
+    return emitOpError("result size (")
+           << resultType.getNumElements()
+           << ") must match the matrix rows ("
+           << matrixType.getNumRows() << ")";
+
+  auto matrixElementType = matrixType.getElementType();
+  if (matrixElementType != vectorType.getElementType() ||
+      matrixElementType != resultType.getElementType())
+    return emitOpError("matrix, vector, and result element types must match");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.MatrixTimesMatrix
 //===----------------------------------------------------------------------===//

>From 9d0d52595af4efc36394be7f0222c65899e19283 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 16:22:10 +0100
Subject: [PATCH 5/7] Apply clang-format

---
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

Notes: formatting only - re-wraps the two emitOpError diagnostics in
MatrixTimesVectorOp::verify(); the emitted messages are unchanged.

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 7eebd0c989fde3..040bf6a34cea78 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1709,14 +1709,12 @@ LogicalResult spirv::MatrixTimesVectorOp::verify() {
 
   if (matrixType.getNumColumns() != vectorType.getNumElements())
     return emitOpError("matrix columns (")
-           << matrixType.getNumColumns()
-           << ") must match vector operand size ("
+           << matrixType.getNumColumns() << ") must match vector operand size ("
            << vectorType.getNumElements() << ")";
 
   if (resultType.getNumElements() != matrixType.getNumRows())
     return emitOpError("result size (")
-           << resultType.getNumElements()
-           << ") must match the matrix rows ("
+           << resultType.getNumElements() << ") must match the matrix rows ("
            << matrixType.getNumRows() << ")";
 
   auto matrixElementType = matrixType.getElementType();

>From c5a8c94eceb21a04d69aa21db3a7feeed14f5b39 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 18:47:50 +0100
Subject: [PATCH 6/7] Add missing newline at end of matrix-ops.mlir

---
 mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

Notes: only adds the missing newline after the closing brace at the end of
matrix-ops.mlir; no test content changes.

diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index 04220a702da5b3..37e7514d664ef0 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -154,4 +154,4 @@ func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3
   // expected-error @+1 {{spirv.MatrixTimesVector' op matrix columns (4) must match vector operand size (3)}}
   %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32>
   return
-}
\ No newline at end of file
+}

>From f3fc230c5956c7c6acffb85869ceb54cb6c1d1a0 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 18:35:39 +0100
Subject: [PATCH 7/7] Remove redundant test

---
 mlir/test/Target/SPIRV/matrix.mlir | 7 -------
 1 file changed, 7 deletions(-)

Notes: @matrix_times_vector_2 only swapped the function argument order, so it
exercised nothing beyond @matrix_times_vector_1. This also drops the
whitespace-only line that PATCH 1/7 added after it, keeping the single empty
line that separates the remaining tests.

diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index fde379af4b54ee..0ec1dc27e4e932 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -43,13 +43,6 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.ReturnValue %result : vector<4xf32>
   }
 
-  // CHECK-LABEL: @matrix_times_vector_2
-  spirv.func @matrix_times_vector_2(%arg0: vector<3xf32>, %arg1: !spirv.matrix<3 x vector<4xf32>>) -> vector<4xf32> "None" {
-    // CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
-    %result = spirv.MatrixTimesVector %arg1, %arg0 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
-    spirv.ReturnValue %result : vector<4xf32>
-  }
-  
   // CHECK-LABEL: @matrix_times_matrix_1
   spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
     // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>



More information about the Mlir-commits mailing list