[Mlir-commits] [mlir] 3c278e5 - [mlir][spirv] Fix spirv.MatrixTimesScalar for cooperative matrix
Lei Zhang
llvmlistbot at llvm.org
Mon Dec 5 14:13:31 PST 2022
Author: Lei Zhang
Date: 2022-12-05T22:13:23Z
New Revision: 3c278e5e274b3aaa173eae71ccd861c8729b37c0
URL: https://github.com/llvm/llvm-project/commit/3c278e5e274b3aaa173eae71ccd861c8729b37c0
DIFF: https://github.com/llvm/llvm-project/commit/3c278e5e274b3aaa173eae71ccd861c8729b37c0.diff
LOG: [mlir][spirv] Fix spirv.MatrixTimesScalar for cooperative matrix
spirv.MatrixTimesScalar is allowed to take cooperative matrices as its matrix operand and result.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D139279
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
mlir/test/Target/SPIRV/matrix.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 6a89ac3bdfe5..8fa89a121bdd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4119,6 +4119,9 @@ class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
SPIRV_CoopMatrixOfType<[type]>]>;
+class SPIRV_MatrixOrCoopMatrixOf<Type type> :
+ AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixOfType<[type]>]>;
+
def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index 93ba8940061b..b6b8c742dee4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -70,7 +70,8 @@ def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [Pure]> {
// -----
-def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure]> {
+def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
+ "MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
let summary = "Scale a floating-point matrix.";
let description = [{
@@ -108,18 +109,16 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure]> {
];
let arguments = (ins
- SPIRV_AnyMatrix:$matrix,
+ SPIRV_MatrixOrCoopMatrixOf<SPIRV_Float>:$matrix,
SPIRV_Float:$scalar
);
let results = (outs
- SPIRV_AnyMatrix:$result
+ SPIRV_MatrixOrCoopMatrixOf<SPIRV_Float>:$result
);
- // TODO: we need just one matrix type given that the input and result are the
- // same and the scalar's type can be deduced from it.
let assemblyFormat = [{
- operands attr-dict `:` type($matrix) `,` type($scalar) `->` type($result)
+ operands attr-dict `:` type($matrix) `,` type($scalar)
}];
let availability = [
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c87ac89efe17..1a93882e8f5f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -4128,35 +4128,20 @@ LogicalResult spirv::INTELJointMatrixMadOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::MatrixTimesScalarOp::verify() {
- // We already checked that result and matrix are both of matrix type in the
- // auto-generated verify method.
-
- auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
- auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
+ if (auto inputCoopmat =
+ getMatrix().getType().dyn_cast<spirv::CooperativeMatrixNVType>()) {
+ if (inputCoopmat.getElementType() != getScalar().getType())
+ return emitError("input matrix components' type and scaling value must "
+ "have the same type");
+ return success();
+ }
// Check that the scalar type is the same as the matrix element type.
+ auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
if (getScalar().getType() != inputMatrix.getElementType())
return emitError("input matrix components' type and scaling value must "
"have the same type");
- // Note that the next three checks could be done using the AllTypesMatch
- // trait in the Op definition file but it generates a vague error message.
-
- // Check that the input and result matrices have the same columns' count
- if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
- return emitError("input and result matrices must have the same "
- "number of columns");
-
- // Check that the input and result matrices' have the same rows count
- if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
- return emitError("input and result matrices' columns must have "
- "the same size");
-
- // Check that the input and result matrices' have the same component type
- if (inputMatrix.getElementType() != resultMatrix.getElementType())
- return emitError("input and result matrices' columns must have "
- "the same component type");
-
return success();
}
diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index 25891c512d38..8cdf2390d723 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -1,13 +1,20 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
- // CHECK-LABEL: @matrix_times_scalar
- spirv.func @matrix_times_scalar(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spirv.matrix<3 x vector<3xf32>> "None" {
- // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>>
- %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>>
+ // CHECK-LABEL: @matrix_times_scalar_1
+ spirv.func @matrix_times_scalar_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spirv.matrix<3 x vector<3xf32>> "None" {
+ // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32
+ %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32
spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
}
+ // CHECK-LABEL: @matrix_times_scalar_2
+ spirv.func @matrix_times_scalar_2(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" {
+ // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+ %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+ spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup>
+ }
+
// CHECK-LABEL: @matrix_transpose_1
spirv.func @matrix_transpose_1(%arg0 : !spirv.matrix<3 x vector<2xf32>>) -> !spirv.matrix<2 x vector<3xf32>> "None" {
// CHECK: {{%.*}} = spirv.Transpose {{%.*}} : !spirv.matrix<3 x vector<2xf32>> -> !spirv.matrix<2 x vector<3xf32>>
@@ -39,54 +46,42 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// -----
-func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f16) -> () {
+func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f16) {
// expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
- %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f16 -> !spirv.matrix<3 x vector<3xf32>>
+ %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f16
+ return
}
// -----
-func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f64) -> () {
+func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f64) {
// expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
- %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f64 -> !spirv.matrix<3 x vector<3xf32>>
-}
-
-// -----
-
-func.func @input_output_component_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () {
- // expected-error @+1 {{input and result matrices' columns must have the same component type}}
- %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf64>>
-}
-
-// -----
-
-func.func @input_output_size_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () {
- // expected-error @+1 {{input and result matrices must have the same number of columns}}
- %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<4 x vector<3xf32>>
+ %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f64
+ return
}
// -----
-func.func @transpose_op_shape_mismatch_1(%arg0 : !spirv.matrix<3 x vector<4xf32>>) -> () {
+func.func @transpose_op_shape_mismatch_1(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
// expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
%result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<3 x vector<3xf32>>
- spirv.Return
+ return
}
// -----
-func.func @transpose_op_shape_mismatch_2(%arg0 : !spirv.matrix<3 x vector<4xf32>>) -> () {
+func.func @transpose_op_shape_mismatch_2(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
// expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
%result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<2 x vector<4xf32>>
- spirv.Return
+ return
}
// -----
-func.func @transpose_op_type_mismatch(%arg0 : !spirv.matrix<3 x vector<4xf32>>) -> () {
+func.func @transpose_op_type_mismatch(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
// expected-error @+1 {{input and output matrices must have the same component type}}
%result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<4 x vector<3xf16>>
- spirv.Return
+ return
}
// -----
@@ -94,6 +89,7 @@ func.func @transpose_op_type_mismatch(%arg0 : !spirv.matrix<3 x vector<4xf32>>)
func.func @matrix_times_matrix_invalid_input_shape_1(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<3xf32>>){
// expected-error @+1 {{right and result matrices must have equal columns' count}}
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<3 x vector<2xf32>>
+ return
}
// -----
@@ -101,6 +97,7 @@ func.func @matrix_times_matrix_invalid_input_shape_1(%arg0 : !spirv.matrix<3 x v
func.func @matrix_times_matrix_invalid_input_shape_2(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<3xf32>>){
// expected-error @+1 {{left and result matrices must have equal rows' count}}
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<2 x vector<3xf32>>
+ return
}
// -----
@@ -108,6 +105,7 @@ func.func @matrix_times_matrix_invalid_input_shape_2(%arg0 : !spirv.matrix<3 x v
func.func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<2xf32>>){
// expected-error @+1 {{left matrix columns' count must be equal to the right matrix rows' count}}
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<2xf32>> -> !spirv.matrix<2 x vector<2xf32>>
+ return
}
// -----
@@ -115,6 +113,7 @@ func.func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spirv.matrix<3 x v
func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
// expected-error @+1 {{right and result matrices' component type must be the same}}
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf64>>
+ return
}
@@ -123,4 +122,5 @@ func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3
func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 x vector<3xf64>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
// expected-error @+1 {{left and result matrices' component type must be the same}}
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
+ return
}
diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index 3c534d9e8f8d..0b71b3f24b19 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -10,17 +10,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK-LABEL: @matrix_times_scalar_1
spirv.func @matrix_times_scalar_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spirv.matrix<3 x vector<3xf32>> "None" {
- // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>>
- %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>>
+ // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32
+ %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32
spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
}
// CHECK-LABEL: @matrix_times_scalar_2
spirv.func @matrix_times_scalar_2(%arg0 : !spirv.matrix<3 x vector<3xf16>>, %arg1 : f16) -> !spirv.matrix<3 x vector<3xf16>> "None" {
- // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf16>>, f16 -> !spirv.matrix<3 x vector<3xf16>>
- %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf16>>, f16 -> !spirv.matrix<3 x vector<3xf16>>
+ // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf16>>, f16
+ %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf16>>, f16
spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf16>>
+ }
+ // CHECK-LABEL: @matrix_times_scalar_3
+ spirv.func @matrix_times_scalar_3(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" {
+ // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+ %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+ spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup>
}
// CHECK-LABEL: @matrix_transpose_1
More information about the Mlir-commits mailing list