[Mlir-commits] [mlir] [mlir][spirv] Add definition for VectorTimesMatrixOp (PR #124571)

Igor Wodiany llvmlistbot at llvm.org
Tue Jan 28 10:16:59 PST 2025


https://github.com/IgWod-IMG updated https://github.com/llvm/llvm-project/pull/124571

>From f2dda8e67adecdf5b6138403dac7ef633147403e Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Mon, 27 Jan 2025 15:16:44 +0000
Subject: [PATCH] [mlir][spirv] Add definition for VectorTimesMatrixOp

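Add the spirv.VectorTimesMatrix op as defined in section 3.52.13
(Arithmetic Instructions) of the SPIR-V specification. Also restrict
spirv.MatrixTimesVector operands and result to floating-point element
types and check its vector/result element types with a trait.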
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        | 16 ++++-
 .../mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td   | 65 +++++++++++++++++--
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 30 +++++++--
 mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir    | 42 +++++++++++-
 mlir/test/Target/SPIRV/matrix.mlir            |  7 ++
 5 files changed, 144 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index c84677d26a8b69..ff738fc2555734 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4234,8 +4234,13 @@ class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
     "::llvm::cast<::mlir::spirv::CooperativeMatrixType>($_self).getElementType()",
     "Cooperative Matrix">;
 
+class SPIRV_MatrixOfType<list<Type> allowedTypes> :
+  ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsMatrixType,
+    "::llvm::cast<::mlir::spirv::MatrixType>($_self).getElementType()",
+    "Matrix">;
+
 class SPIRV_VectorOf<Type type> :
-    VectorOfLengthAndType<[2, 3, 4, 8,16], [type]>;
+    VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
 
 class SPIRV_ScalarOrVectorOf<Type type> :
     AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
@@ -4248,6 +4253,9 @@ class SPIRV_MatrixOrCoopMatrixOf<Type type> :
     AnyTypeOf<[SPIRV_AnyMatrix,
                SPIRV_CoopMatrixOfType<[type]>]>;
 
+class SPIRV_MatrixOf<Type type> :
+    SPIRV_MatrixOfType<[type]>;
+
 def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
 def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
 
@@ -4387,7 +4395,8 @@ def SPIRV_OC_OpFRem                         : I32EnumAttrCase<"OpFRem", 140>;
 def SPIRV_OC_OpFMod                         : I32EnumAttrCase<"OpFMod", 141>;
 def SPIRV_OC_OpVectorTimesScalar            : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
 def SPIRV_OC_OpMatrixTimesScalar            : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
-def SPIRV_OC_OpMatrixTimesVector           : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
+def SPIRV_OC_OpVectorTimesMatrix            : I32EnumAttrCase<"OpVectorTimesMatrix", 144>;
+def SPIRV_OC_OpMatrixTimesVector            : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
 def SPIRV_OC_OpMatrixTimesMatrix            : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
 def SPIRV_OC_OpDot                          : I32EnumAttrCase<"OpDot", 148>;
 def SPIRV_OC_OpIAddCarry                    : I32EnumAttrCase<"OpIAddCarry", 149>;
@@ -4559,7 +4568,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
       SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
       SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
-      SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesVector,
+      SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
+      SPIRV_OC_OpVectorTimesMatrix, SPIRV_OC_OpMatrixTimesVector,
       SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
       SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
       SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index 5bd99386e00858..f2796861cdf561 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -63,8 +63,7 @@ def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [Pure]> {
 
 // -----
 
-def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
-    "MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
+def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
   let summary = "Scale a floating-point matrix.";
 
   let description = [{
@@ -114,8 +113,11 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
 
 // -----
 
-def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> {
-  let summary = "Linear-algebraic multiply of matrix X vector.";
+def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [
+    Pure,
+    AllElementTypesMatch<["vector", "result"]>
+  ]> {
+  let summary = "Linear-algebraic Matrix X Vector.";
 
   let description = [{
     Result Type must be a vector of floating-point type.
@@ -140,12 +142,12 @@ def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> {
   ];
 
   let arguments = (ins
-    SPIRV_AnyMatrix:$matrix,
-    SPIRV_AnyVector:$vector
+    SPIRV_MatrixOf<SPIRV_Float>:$matrix,
+    SPIRV_VectorOf<SPIRV_Float>:$vector
   );
 
   let results = (outs
-    SPIRV_AnyVector:$result
+    SPIRV_VectorOf<SPIRV_Float>:$result
   );
 
   let assemblyFormat = [{
@@ -198,4 +200,53 @@ def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
 
 // -----
 
+def SPIRV_VectorTimesMatrixOp : SPIRV_Op<"VectorTimesMatrix", [
+    Pure,
+    AllElementTypesMatch<["vector", "result"]>
+  ]> {
+  let summary = "Linear-algebraic Vector X Matrix.";
+
+  let description = [{
+    Result Type must be a vector of floating-point type.
+
+    Vector must be a vector with the same Component Type as the Component
+    Type in Result Type. Its number of components must equal the number of
+    components in each column in Matrix.
+
+    Matrix must be a matrix with the same Component Type as the Component
+    Type in Result Type. Its number of columns must equal the number of
+    components in Result Type.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %result = spirv.VectorTimesMatrix %vector, %matrix : vector<4xf32>, !spirv.matrix<4 x vector<4xf32>> -> vector<4xf32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[SPIRV_C_Matrix]>
+  ];
+
+  let arguments = (ins
+    SPIRV_VectorOf<SPIRV_Float>:$vector,
+    SPIRV_MatrixOf<SPIRV_Float>:$matrix
+  );
+
+  let results = (outs
+    SPIRV_VectorOf<SPIRV_Float>:$result
+  );
+
+  let assemblyFormat = [{
+    operands attr-dict `:` type($vector) `,` type($matrix) `->` type($result)
+  }];
+}
+
+// -----
+
 #endif // MLIR_DIALECT_SPIRV_IR_MATRIX_OPS
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 040bf6a34cea78..f0f03e989cb475 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1717,10 +1717,32 @@ LogicalResult spirv::MatrixTimesVectorOp::verify() {
            << resultType.getNumElements() << ") must match the matrix rows ("
            << matrixType.getNumRows() << ")";
 
-  auto matrixElementType = matrixType.getElementType();
-  if (matrixElementType != vectorType.getElementType() ||
-      matrixElementType != resultType.getElementType())
-    return emitOpError("matrix, vector, and result element types must match");
+  if (matrixType.getElementType() != resultType.getElementType())
+    return emitOpError("matrix and result element types must match");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.VectorTimesMatrix
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::VectorTimesMatrixOp::verify() {
+  auto vectorType = llvm::cast<VectorType>(getVector().getType());
+  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
+  auto resultType = llvm::cast<VectorType>(getType());
+
+  if (matrixType.getNumRows() != vectorType.getNumElements())
+    return emitOpError("number of components in vector must equal the number "
+                       "of components in each column in matrix");
+
+  if (resultType.getNumElements() != matrixType.getNumColumns())
+    return emitOpError("number of columns in matrix must equal the number of "
+                       "components in result");
+
+  if (matrixType.getElementType() != resultType.getElementType())
+    return emitOpError("matrix must be a matrix with the same component type "
+                       "as the component type in result");
 
   return success();
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index 37e7514d664ef0..ba95322dbf38bb 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -36,6 +36,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.ReturnValue %result : vector<4xf32>
   }
 
+  // CHECK-LABEL: @vector_times_matrix_1
+  spirv.func @vector_times_matrix_1(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) -> vector<4xf32> "None" {
+    // CHECK: {{%.*}} = spirv.VectorTimesMatrix {{%.*}}, {{%.*}} : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+    %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+    spirv.ReturnValue %result : vector<4xf32>
+  }
+
   // CHECK-LABEL: @matrix_times_matrix_1
   spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
     // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
@@ -123,7 +130,6 @@ func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3
    return
 }
 
-
 // -----
 
 func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 x vector<3xf64>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
@@ -135,7 +141,7 @@ func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3
 // -----
 
 func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf16>) {
-  // expected-error @+1 {{matrix, vector, and result element types must match}}
+  // expected-error @+1 {{op failed to verify that all of {vector, result} have same element type}}
   %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf16> -> vector<3xf32>
   return
 }
@@ -155,3 +161,43 @@ func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3
   %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32>
   return
 }
+
+// -----
+
+func.func @matrix_times_vector_matrix_result_type_mismatch(%arg0: !spirv.matrix<4 x vector<3xf16>>, %arg1: vector<4xf32>) {
+  // expected-error @+1 {{matrix and result element types must match}}
+  %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf16>>, vector<4xf32> -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @vector_times_matrix_vector_matrix_mismatch(%arg0: vector<4xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
+  // expected-error @+1 {{number of components in vector must equal the number of components in each column in matrix}}
+  %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<4xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @vector_times_matrix_result_matrix_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
+  // expected-error @+1 {{number of columns in matrix must equal the number of components in result}}
+  %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @vector_times_matrix_vector_type_mismatch(%arg0: vector<3xf16>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
+  // expected-error @+1 {{op failed to verify that all of {vector, result} have same element type}}
+  %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf16>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+  return
+}
+
+// -----
+
+func.func @vector_times_matrix_matrix_type_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf16>>) {
+  // expected-error @+1 {{matrix must be a matrix with the same component type as the component type in result}}
+  %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf16>> -> vector<4xf32>
+  return
+}
diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index 0ec1dc27e4e932..452f8fc16f2588 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -42,6 +42,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
     spirv.ReturnValue %result : vector<4xf32>
   }
+
+  // CHECK-LABEL: @vector_times_matrix_1
+  spirv.func @vector_times_matrix_1(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) -> vector<4xf32> "None" {
+    // CHECK: {{%.*}} = spirv.VectorTimesMatrix {{%.*}}, {{%.*}} : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+    %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+    spirv.ReturnValue %result : vector<4xf32>
+  }
   
   // CHECK-LABEL: @matrix_times_matrix_1
   spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{



More information about the Mlir-commits mailing list