[Mlir-commits] [mlir] e5a28a3 - [mlir][spirv] Add MatrixTimesVector Op (#122302)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 17 12:47:38 PST 2025


Author: mishaobu
Date: 2025-01-17T15:47:34-05:00
New Revision: e5a28a3b4d09a3ab128439a0f4eb2659e0b1978b

URL: https://github.com/llvm/llvm-project/commit/e5a28a3b4d09a3ab128439a0f4eb2659e0b1978b
DIFF: https://github.com/llvm/llvm-project/commit/e5a28a3b4d09a3ab128439a0f4eb2659e0b1978b.diff

LOG: [mlir][spirv] Add MatrixTimesVector Op (#122302)

(From the SPIR-V reference:
https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpMatrixTimesVector)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
    mlir/test/Target/SPIRV/matrix.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index a4c01c0bc3418d..469a9a0ef01dd2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4171,6 +4171,7 @@ def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
 def SPIRV_IsCooperativeMatrixType :
   CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">;
 def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
+def SPIRV_IsVectorType : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
 def SPIRV_IsMatrixType : CPred<"::llvm::isa<::mlir::spirv::MatrixType>($_self)">;
 def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
 def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
@@ -4202,6 +4203,8 @@ def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
                                   "any SPIR-V cooperative matrix type">;
 def SPIRV_AnyImage : DialectType<SPIRV_Dialect, SPIRV_IsImageType,
                                 "any SPIR-V image type">;
+def SPIRV_AnyVector : DialectType<SPIRV_Dialect, SPIRV_IsVectorType,
+                                "any SPIR-V vector type">;
 def SPIRV_AnyMatrix : DialectType<SPIRV_Dialect, SPIRV_IsMatrixType,
                                 "any SPIR-V matrix type">;
 def SPIRV_AnyRTArray : DialectType<SPIRV_Dialect, SPIRV_IsRTArrayType,
@@ -4384,6 +4387,7 @@ def SPIRV_OC_OpFRem                         : I32EnumAttrCase<"OpFRem", 140>;
 def SPIRV_OC_OpFMod                         : I32EnumAttrCase<"OpFMod", 141>;
 def SPIRV_OC_OpVectorTimesScalar            : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
 def SPIRV_OC_OpMatrixTimesScalar            : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
+def SPIRV_OC_OpMatrixTimesVector            : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
 def SPIRV_OC_OpMatrixTimesMatrix            : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
 def SPIRV_OC_OpDot                          : I32EnumAttrCase<"OpDot", 148>;
 def SPIRV_OC_OpIAddCarry                    : I32EnumAttrCase<"OpIAddCarry", 149>;
@@ -4553,7 +4557,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
       SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
       SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
       SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
+      SPIRV_OC_OpMatrixTimesVector,
       SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
       SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
       SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index a6f0f41429bcbc..5bd99386e00858 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -114,6 +114,49 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
 
 // -----
 
+def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> {
+  let summary = "Linear-algebraic multiply of matrix X vector.";
+
+  let description = [{
+    Result Type must be a vector of floating-point type.
+
+    Matrix must be an OpTypeMatrix whose Column Type is Result Type.
+
+    Vector must be a vector with the same Component Type as the Component Type
+    in Result Type. Its number of components must equal the number of columns
+    in Matrix.
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.MatrixTimesVector %matrix, %vector :
+        !spirv.matrix<3 x vector<2xf32>>, vector<3xf32> -> vector<2xf32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[SPIRV_C_Matrix]>
+  ];
+
+  let arguments = (ins
+    SPIRV_AnyMatrix:$matrix,
+    SPIRV_AnyVector:$vector
+  );
+
+  let results = (outs
+    SPIRV_AnyVector:$result
+  );
+
+  let assemblyFormat = [{
+    operands attr-dict `:` type($matrix) `,` type($vector) `->` type($result)
+  }];
+}
+
+// -----
+
 def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
   let summary = "Transpose a matrix.";
 

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 26559c1321db5e..040bf6a34cea78 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1698,6 +1698,48 @@ LogicalResult spirv::TransposeOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.MatrixTimesVector
+//===----------------------------------------------------------------------===//
+
+/// Verifies the rank, size and element-type relations between the matrix,
+/// vector and result that the ODS constraints (SPIRV_AnyMatrix and
+/// SPIRV_AnyVector) do not express on their own.
+LogicalResult spirv::MatrixTimesVectorOp::verify() {
+  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
+  auto vectorType = llvm::cast<VectorType>(getVector().getType());
+  auto resultType = llvm::cast<VectorType>(getType());
+
+  // SPIRV_AnyVector only checks isa<VectorType>, but SPIR-V vectors are 1-D.
+  // Reject other ranks so that e.g. vector<2x2xf32> cannot pass the size
+  // checks below just because its total element count happens to match.
+  if (vectorType.getRank() != 1)
+    return emitOpError("vector operand must be a 1-D vector, but got rank ")
+           << vectorType.getRank();
+  if (resultType.getRank() != 1)
+    return emitOpError("result must be a 1-D vector, but got rank ")
+           << resultType.getRank();
+
+  // The vector supplies one component per matrix column; the result has one
+  // component per matrix row (it plays the role of the matrix column type).
+  if (matrixType.getNumColumns() != vectorType.getNumElements())
+    return emitOpError("matrix columns (")
+           << matrixType.getNumColumns() << ") must match vector operand size ("
+           << vectorType.getNumElements() << ")";
+
+  if (resultType.getNumElements() != matrixType.getNumRows())
+    return emitOpError("result size (")
+           << resultType.getNumElements() << ") must match the matrix rows ("
+           << matrixType.getNumRows() << ")";
+
+  auto matrixElementType = matrixType.getElementType();
+  if (matrixElementType != vectorType.getElementType() ||
+      matrixElementType != resultType.getElementType())
+    return emitOpError("matrix, vector, and result element types must match");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.MatrixTimesMatrix
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index 372fcc6e514b97..37e7514d664ef0 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -29,6 +29,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
   }
 
+  // CHECK-LABEL: @matrix_times_vector_1
+  spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    spirv.ReturnValue %result : vector<4xf32>
+  }
+
   // CHECK-LABEL: @matrix_times_matrix_1
   spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
     // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
@@ -124,3 +131,27 @@ func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3
    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
    return
 }
+
+// -----
+
+func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf16>) {
+  // expected-error @+1 {{matrix, vector, and result element types must match}}
+  %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf16> -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @matrix_times_vector_row_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf32>) {
+  // expected-error @+1 {{'spirv.MatrixTimesVector' op result size (4) must match the matrix rows (3)}}
+  %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf32> -> vector<4xf32>
+  return
+}
+
+// -----
+
+func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<3xf32>) {
+  // expected-error @+1 {{'spirv.MatrixTimesVector' op matrix columns (4) must match vector operand size (3)}}
+  %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32>
+  return
+}

diff  --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index 2a391df4bff396..0ec1dc27e4e932 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -36,6 +36,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.ReturnValue %result : !spirv.matrix<2 x vector<3xf32>>
   }
 
+  // CHECK-LABEL: @matrix_times_vector_1
+  spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    spirv.ReturnValue %result : vector<4xf32>
+  }
+
   // CHECK-LABEL: @matrix_times_matrix_1
   spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
     // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>


        


More information about the Mlir-commits mailing list