[Mlir-commits] [mlir] [mlir][spirv] Add definition for VectorTimesMatrixOp (PR #124571)
Igor Wodiany
llvmlistbot at llvm.org
Mon Jan 27 07:44:33 PST 2025
https://github.com/IgWod-IMG created https://github.com/llvm/llvm-project/pull/124571
Add the `spirv.VectorTimesMatrix` op (OpVectorTimesMatrix), as defined in section 3.52.13 (Arithmetic Instructions) of the SPIR-V specification.
>From a01c5487dd7039deb24e7827e5bf3f96e1f1dd14 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Mon, 27 Jan 2025 15:16:44 +0000
Subject: [PATCH] [mlir][spirv] Add definition for VectorTimesMatrixOp
Add the spirv.VectorTimesMatrix op (OpVectorTimesMatrix), as defined in
section 3.52.13 (Arithmetic Instructions) of the SPIR-V specification.
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 6 ++-
.../mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td | 51 +++++++++++++++++--
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 28 ++++++++++
mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir | 40 ++++++++++++++-
mlir/test/Target/SPIRV/matrix.mlir | 7 +++
5 files changed, 126 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index c84677d26a8b69..2f50f9b6111822 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4387,7 +4387,8 @@ def SPIRV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
def SPIRV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPIRV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
def SPIRV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
-def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
+def SPIRV_OC_OpVectorTimesMatrix : I32EnumAttrCase<"OpVectorTimesMatrix", 144>;
+def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
def SPIRV_OC_OpDot : I32EnumAttrCase<"OpDot", 148>;
def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>;
@@ -4559,7 +4560,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
- SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesVector,
+ SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
+ SPIRV_OC_OpVectorTimesMatrix, SPIRV_OC_OpMatrixTimesVector,
SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index 5bd99386e00858..78b5fa2c228dc2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -63,8 +63,7 @@ def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [Pure]> {
// -----
-def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
- "MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
+def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
let summary = "Scale a floating-point matrix.";
let description = [{
@@ -115,7 +114,7 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
// -----
def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> {
- let summary = "Linear-algebraic multiply of matrix X vector.";
+ let summary = "Linear-algebraic Matrix X Vector.";
let description = [{
Result Type must be a vector of floating-point type.
@@ -198,4 +197,53 @@ def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
// -----
+def SPIRV_VectorTimesMatrixOp : SPIRV_Op<"VectorTimesMatrix", [Pure]> {
+ let summary = "Linear-algebraic Vector X Matrix.";
+
+ let description = [{
+ Result Type must be a vector of floating-point type.
+
+ Vector must be a vector with the same Component Type as the Component
+ Type in Result Type. Its number of components must equal the number of
+ components in each column in Matrix.
+
+ Matrix must be a matrix with the same Component Type as the Component
+ Type in Result Type. Its number of columns must equal the number of
+ components in Result Type.
+
+ <!-- End of AutoGen section -->
+
+ Component `j` of the result is the dot product of Vector with column `j`
+ of Matrix, i.e. Vector is treated as a row vector.
+
+ #### Example:
+
+ ```mlir
+ %result = spirv.VectorTimesMatrix %vector, %matrix : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[]>,
+ Capability<[SPIRV_C_Matrix]>
+ ];
+
+ let arguments = (ins
+ SPIRV_AnyVector:$vector,
+ SPIRV_AnyMatrix:$matrix
+ );
+
+ let results = (outs
+ SPIRV_VectorOf<SPIRV_Float>:$result
+ );
+
+ let assemblyFormat = [{
+ operands attr-dict `:` type($vector) `,` type($matrix) `->` type($result)
+ }];
+}
+
+// -----
+
#endif // MLIR_DIALECT_SPIRV_IR_MATRIX_OPS
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 040bf6a34cea78..2273e8073503d4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1725,6 +1725,44 @@ LogicalResult spirv::MatrixTimesVectorOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.VectorTimesMatrix
+//===----------------------------------------------------------------------===//
+
+/// Verifies OpVectorTimesMatrix: the vector length must equal the number of
+/// components in each matrix column, the matrix column count must equal the
+/// result length, and vector, matrix and result must share a component type.
+LogicalResult spirv::VectorTimesMatrixOp::verify() {
+  // The ODS type constraints (SPIRV_AnyVector, SPIRV_AnyMatrix and
+  // SPIRV_VectorOf<SPIRV_Float>) are checked before this hook runs, so these
+  // casts cannot fail.
+  auto vecTy = llvm::cast<VectorType>(getVector().getType());
+  auto matTy = llvm::cast<spirv::MatrixType>(getMatrix().getType());
+  auto resTy = llvm::cast<VectorType>(getType());
+
+  // Shape rules: the vector is combined with every column, so its length must
+  // match the column height, and each column yields one result component.
+  if (vecTy.getNumElements() != matTy.getNumRows())
+    return emitOpError("number of components in vector must equal the number "
+                       "of components in each column in matrix");
+
+  if (matTy.getNumColumns() != resTy.getNumElements())
+    return emitOpError("number of columns in matrix must equal the number of "
+                       "components in result");
+
+  // Component-type rules, both stated relative to the result's element type.
+  Type resElemTy = resTy.getElementType();
+  if (vecTy.getElementType() != resElemTy)
+    return emitOpError("vector must be a vector with the same component type "
+                       "as the component type in result");
+
+  if (matTy.getElementType() != resElemTy)
+    return emitOpError("matrix must be a matrix with the same component type "
+                       "as the component type in result");
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.MatrixTimesMatrix
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index 37e7514d664ef0..79379b45805ac4 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -36,6 +36,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.ReturnValue %result : vector<4xf32>
}
+ // CHECK-LABEL: @vector_times_matrix_1
+ spirv.func @vector_times_matrix_1(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) -> vector<4xf32> "None" {
+ // CHECK: {{%.*}} = spirv.VectorTimesMatrix {{%.*}}, {{%.*}} : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+ %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+ spirv.ReturnValue %result : vector<4xf32>
+ }
+
// CHECK-LABEL: @matrix_times_matrix_1
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
// CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
@@ -123,7 +130,6 @@ func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3
return
}
-
// -----
func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 x vector<3xf64>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
@@ -155,3 +161,43 @@ func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32>
return
}
+
+// -----
+
+func.func @vector_times_matrix_vector_matrix_mismatch(%arg0: vector<4xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
+ // expected-error @+1 {{number of components in vector must equal the number of components in each column in matrix}}
+ %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<4xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<3xf32>
+ return
+}
+
+// -----
+
+func.func @vector_times_matrix_result_matrix_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
+ // expected-error @+1 {{number of columns in matrix must equal the number of components in result}}
+ %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<3xf32>
+ return
+}
+
+// -----
+
+func.func @vector_times_matrix_vector_type_mismatch(%arg0: vector<3xi32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
+ // expected-error @+1 {{vector must be a vector with the same component type as the component type in result}}
+ %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xi32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+ return
+}
+
+// -----
+
+func.func @vector_times_matrix_matrix_type_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf16>>) {
+ // expected-error @+1 {{matrix must be a matrix with the same component type as the component type in result}}
+ %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf16>> -> vector<4xf32>
+ return
+}
+
+// -----
+
+func.func @vector_times_matrix_result_not_float(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
+ // expected-error @+1 {{result #0 must be vector of}}
+ %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xi32>
+ return
+}
diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index 0ec1dc27e4e932..452f8fc16f2588 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -42,6 +42,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
spirv.ReturnValue %result : vector<4xf32>
}
+
+ // CHECK-LABEL: @vector_times_matrix_1
+ spirv.func @vector_times_matrix_1(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) -> vector<4xf32> "None" {
+ // CHECK: {{%.*}} = spirv.VectorTimesMatrix {{%.*}}, {{%.*}} : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+ %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+ spirv.ReturnValue %result : vector<4xf32>
+ }
// CHECK-LABEL: @matrix_times_matrix_1
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
More information about the Mlir-commits
mailing list