[Mlir-commits] [mlir] [mlir][spirv] Add MatrixTimesVector Op (PR #122302)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 17 09:40:09 PST 2025
https://github.com/mishaobu updated https://github.com/llvm/llvm-project/pull/122302
>From 7d8dc65b6834c474c13bd496d877a712e9cf2297 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 16:16:42 +0100
Subject: [PATCH 1/7] target tests
---
mlir/test/Target/SPIRV/matrix.mlir | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index 2a391df4bff396..fde379af4b54ee 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -36,6 +36,20 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.ReturnValue %result : !spirv.matrix<2 x vector<3xf32>>
}
+  // spirv.MatrixTimesVector: a 3-column matrix with vector<4xf32> columns
+  // times a vector<3xf32> yields a vector<4xf32> (one component per row).
+  // CHECK-LABEL: @matrix_times_vector_1
+  spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    spirv.ReturnValue %result : vector<4xf32>
+  }
+
+  // CHECK-LABEL: @matrix_times_vector_2
+  spirv.func @matrix_times_vector_2(%arg0: vector<3xf32>, %arg1: !spirv.matrix<3 x vector<4xf32>>) -> vector<4xf32> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    %result = spirv.MatrixTimesVector %arg1, %arg0 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    spirv.ReturnValue %result : vector<4xf32>
+  }
+
// CHECK-LABEL: @matrix_times_matrix_1
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
// CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
>From 8bd529f0dabd28eaeb4e082d6df19ca49355b055 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 16:16:53 +0100
Subject: [PATCH 2/7] dialect tests
---
mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir | 31 ++++++++++++++++++++++
1 file changed, 31 insertions(+)
diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index 372fcc6e514b97..04220a702da5b3 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -29,6 +29,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
}
+  // Valid case: the vector length (3) equals the matrix column count and the
+  // result length (4) equals the row count; all element types are f32.
+  // CHECK-LABEL: @matrix_times_vector_1
+  spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
+    spirv.ReturnValue %result : vector<4xf32>
+  }
+
// CHECK-LABEL: @matrix_times_matrix_1
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
// CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
@@ -124,3 +131,27 @@ func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
return
}
+
+// -----
+
+func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf16>) {
+  // Sizes are consistent (4 columns vs. 4-element vector, 3 rows vs. 3-element
+  // result); only the f16 vector element type disagrees with the f32 matrix.
+  // expected-error @+1 {{matrix, vector, and result element types must match}}
+  %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf16> -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @matrix_times_vector_row_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf32>) {
+  // The vector matches the 4 matrix columns, but the 4-element result does
+  // not match the 3 matrix rows.
+  // expected-error @+1 {{spirv.MatrixTimesVector' op result size (4) must match the matrix rows (3)}}
+  %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf32> -> vector<4xf32>
+  return
+}
+
+// -----
+
+func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<3xf32>) {
+  // The 3-element vector operand does not match the 4 matrix columns.
+  // expected-error @+1 {{spirv.MatrixTimesVector' op matrix columns (4) must match vector operand size (3)}}
+  %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32>
+  return
+}
\ No newline at end of file
>From 8813b824a6c8b1fd402fe41611c0378b492bfa79 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 16:17:38 +0100
Subject: [PATCH 3/7] add MatrixTimesVector op to tablegen
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 6 ++-
.../mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td | 41 +++++++++++++++++++
2 files changed, 46 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index a4c01c0bc3418d..469a9a0ef01dd2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4171,6 +4171,7 @@ def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
def SPIRV_IsCooperativeMatrixType :
CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">;
def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
+// Matches the builtin ::mlir::VectorType.
+// NOTE(review): this is not a SPIR-V dialect type, and the predicate places no
+// restriction on rank, length, or element type.
+def SPIRV_IsVectorType : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
def SPIRV_IsMatrixType : CPred<"::llvm::isa<::mlir::spirv::MatrixType>($_self)">;
def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
@@ -4202,6 +4203,8 @@ def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
"any SPIR-V cooperative matrix type">;
def SPIRV_AnyImage : DialectType<SPIRV_Dialect, SPIRV_IsImageType,
"any SPIR-V image type">;
+// NOTE(review): declared under SPIRV_Dialect, but it accepts any builtin
+// VectorType through SPIRV_IsVectorType. It does not enforce SPIR-V vector
+// restrictions (1-D shape, allowed lengths, element types). Prefer the
+// dialect's vector constraints for op operands/results.
+def SPIRV_AnyVector : DialectType<SPIRV_Dialect, SPIRV_IsVectorType,
+                                  "any SPIR-V vector type">;
def SPIRV_AnyMatrix : DialectType<SPIRV_Dialect, SPIRV_IsMatrixType,
"any SPIR-V matrix type">;
def SPIRV_AnyRTArray : DialectType<SPIRV_Dialect, SPIRV_IsRTArrayType,
@@ -4384,6 +4387,7 @@ def SPIRV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
def SPIRV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPIRV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
def SPIRV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
+// Opcode value as assigned by the SPIR-V specification.
+def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
def SPIRV_OC_OpDot : I32EnumAttrCase<"OpDot", 148>;
def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>;
@@ -4553,7 +4557,7 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
- SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
+ SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesVector,
SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index a6f0f41429bcbc..5bd99386e00858 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -114,6 +114,47 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
// -----
+def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> {
+  let summary = "Linear-algebraic multiply of matrix X vector.";
+
+  let description = [{
+    Result Type must be a vector of floating-point type.
+
+    Matrix must be an OpTypeMatrix whose Column Type is Result Type.
+
+    Vector must be a vector with the same Component Type as the Component
+    Type in Result Type. Its number of components must equal the number of
+    columns in Matrix.
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.MatrixTimesVector %matrix, %vector :
+      !spirv.matrix<3 x vector<2xf32>>, vector<3xf32> -> vector<2xf32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[SPIRV_C_Matrix]>
+  ];
+
+  // The description requires floating-point vectors, so constrain the vector
+  // operand and the result to SPIR-V float vectors instead of accepting any
+  // builtin vector type. Size and element-type agreement with the matrix is
+  // checked by the C++ verifier.
+  let arguments = (ins
+    SPIRV_AnyMatrix:$matrix,
+    SPIRV_VectorOf<SPIRV_Float>:$vector
+  );
+
+  let results = (outs
+    SPIRV_VectorOf<SPIRV_Float>:$result
+  );
+
+  let assemblyFormat = [{
+    operands attr-dict `:` type($matrix) `,` type($vector) `->` type($result)
+  }];
+}
+
+// -----
+
def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
let summary = "Transpose a matrix.";
>From 646004106b9bc6273b91eb8483d1f017353c102b Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 16:18:33 +0100
Subject: [PATCH 4/7] implement MatrixTimesVector verify
---
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 29 ++++++++++++++++++++++++++
1 file changed, 29 insertions(+)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 26559c1321db5e..7eebd0c989fde3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1698,6 +1698,35 @@ LogicalResult spirv::TransposeOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.MatrixTimesVector
+//===----------------------------------------------------------------------===//
+
+/// Verifies spirv.MatrixTimesVector:
+///  - the vector operand and the result are fixed-length 1-D vectors,
+///  - the vector length equals the matrix column count,
+///  - the result length equals the matrix row count,
+///  - the matrix, vector, and result element types all agree.
+LogicalResult spirv::MatrixTimesVectorOp::verify() {
+  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
+  auto vectorType = llvm::cast<VectorType>(getVector().getType());
+  auto resultType = llvm::cast<VectorType>(getType());
+
+  // getNumElements() is the SPIR-V component count only for fixed-length 1-D
+  // vectors. Reject other shapes up front so that, e.g., vector<2x2xf32> is
+  // not accepted as a 4-component vector by the size checks below.
+  if (vectorType.getRank() != 1 || vectorType.isScalable() ||
+      resultType.getRank() != 1 || resultType.isScalable())
+    return emitOpError(
+        "vector operand and result must be fixed-length 1-D vectors");
+
+  if (matrixType.getNumColumns() != vectorType.getNumElements())
+    return emitOpError("matrix columns (")
+           << matrixType.getNumColumns()
+           << ") must match vector operand size ("
+           << vectorType.getNumElements() << ")";
+
+  if (resultType.getNumElements() != matrixType.getNumRows())
+    return emitOpError("result size (")
+           << resultType.getNumElements()
+           << ") must match the matrix rows ("
+           << matrixType.getNumRows() << ")";
+
+  auto matrixElementType = matrixType.getElementType();
+  // The matrix, vector, and result must share one scalar component type.
+  if (matrixElementType != vectorType.getElementType() ||
+      matrixElementType != resultType.getElementType())
+    return emitOpError("matrix, vector, and result element types must match");
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.MatrixTimesMatrix
//===----------------------------------------------------------------------===//
>From 9d0d52595af4efc36394be7f0222c65899e19283 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 16:22:10 +0100
Subject: [PATCH 5/7] apply clang format
---
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 7eebd0c989fde3..040bf6a34cea78 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1709,14 +1709,12 @@ LogicalResult spirv::MatrixTimesVectorOp::verify() {
if (matrixType.getNumColumns() != vectorType.getNumElements())
return emitOpError("matrix columns (")
- << matrixType.getNumColumns()
- << ") must match vector operand size ("
+ << matrixType.getNumColumns() << ") must match vector operand size ("
<< vectorType.getNumElements() << ")";
if (resultType.getNumElements() != matrixType.getNumRows())
return emitOpError("result size (")
- << resultType.getNumElements()
- << ") must match the matrix rows ("
+ << resultType.getNumElements() << ") must match the matrix rows ("
<< matrixType.getNumRows() << ")";
auto matrixElementType = matrixType.getElementType();
>From c5a8c94eceb21a04d69aa21db3a7feeed14f5b39 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 18:47:50 +0100
Subject: [PATCH 6/7] newline
---
mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index 04220a702da5b3..37e7514d664ef0 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -154,4 +154,4 @@ func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3
// expected-error @+1 {{spirv.MatrixTimesVector' op matrix columns (4) must match vector operand size (3)}}
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32>
return
-}
\ No newline at end of file
+}
>From f3fc230c5956c7c6acffb85869ceb54cb6c1d1a0 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 18:35:39 +0100
Subject: [PATCH 7/7] Remove redundant test
---
mlir/test/Target/SPIRV/matrix.mlir | 7 -------
1 file changed, 7 deletions(-)
diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index fde379af4b54ee..0ec1dc27e4e932 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -42,13 +42,6 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
spirv.ReturnValue %result : vector<4xf32>
}
-
- // CHECK-LABEL: @matrix_times_vector_2
- spirv.func @matrix_times_vector_2(%arg0: vector<3xf32>, %arg1: !spirv.matrix<3 x vector<4xf32>>) -> vector<4xf32> "None" {
- // CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
- %result = spirv.MatrixTimesVector %arg1, %arg0 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
- spirv.ReturnValue %result : vector<4xf32>
- }
// CHECK-LABEL: @matrix_times_matrix_1
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
More information about the Mlir-commits
mailing list