[Mlir-commits] [mlir] 3c278e5 - [mlir][spirv] Fix spirv.MatrixTimesScalar for cooperative matrix

Lei Zhang llvmlistbot at llvm.org
Mon Dec 5 14:13:31 PST 2022


Author: Lei Zhang
Date: 2022-12-05T22:13:23Z
New Revision: 3c278e5e274b3aaa173eae71ccd861c8729b37c0

URL: https://github.com/llvm/llvm-project/commit/3c278e5e274b3aaa173eae71ccd861c8729b37c0
DIFF: https://github.com/llvm/llvm-project/commit/3c278e5e274b3aaa173eae71ccd861c8729b37c0.diff

LOG: [mlir][spirv] Fix spirv.MatrixTimesScalar for cooperative matrix

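spirv.MatrixTimesScalar is allowed to take a cooperative matrix as its matrix operand and result. The result type is now required to match the matrix operand type, so the op's assembly format no longer spells out the result type.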

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D139279

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
    mlir/test/Target/SPIRV/matrix.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 6a89ac3bdfe5..8fa89a121bdd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4119,6 +4119,9 @@ class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
     AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
                SPIRV_CoopMatrixOfType<[type]>]>;
 
+class SPIRV_MatrixOrCoopMatrixOf<Type type> :
+    AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixOfType<[type]>]>;
+
 def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
 def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index 93ba8940061b..b6b8c742dee4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -70,7 +70,8 @@ def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [Pure]> {
 
 // -----
 
-def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure]> {
+def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
+    "MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
   let summary = "Scale a floating-point matrix.";
 
   let description = [{
@@ -108,18 +109,16 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure]> {
   ];
 
   let arguments = (ins
-    SPIRV_AnyMatrix:$matrix,
+    SPIRV_MatrixOrCoopMatrixOf<SPIRV_Float>:$matrix,
     SPIRV_Float:$scalar
   );
 
   let results = (outs
-    SPIRV_AnyMatrix:$result
+    SPIRV_MatrixOrCoopMatrixOf<SPIRV_Float>:$result
   );
 
-  // TODO: we need just one matrix type given that the input and result are the
-  // same and the scalar's type can be deduced from it.
   let assemblyFormat = [{
-    operands attr-dict `:` type($matrix) `,` type($scalar) `->` type($result)
+    operands attr-dict `:` type($matrix) `,` type($scalar)
   }];
 
   let availability = [

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c87ac89efe17..1a93882e8f5f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -4128,35 +4128,20 @@ LogicalResult spirv::INTELJointMatrixMadOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::MatrixTimesScalarOp::verify() {
-  // We already checked that result and matrix are both of matrix type in the
-  // auto-generated verify method.
-
-  auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
-  auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
+  if (auto inputCoopmat =
+          getMatrix().getType().dyn_cast<spirv::CooperativeMatrixNVType>()) {
+    if (inputCoopmat.getElementType() != getScalar().getType())
+      return emitError("input matrix components' type and scaling value must "
+                       "have the same type");
+    return success();
+  }
 
   // Check that the scalar type is the same as the matrix element type.
+  auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
   if (getScalar().getType() != inputMatrix.getElementType())
     return emitError("input matrix components' type and scaling value must "
                      "have the same type");
 
-  // Note that the next three checks could be done using the AllTypesMatch
-  // trait in the Op definition file but it generates a vague error message.
-
-  // Check that the input and result matrices have the same columns' count
-  if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
-    return emitError("input and result matrices must have the same "
-                     "number of columns");
-
-  // Check that the input and result matrices' have the same rows count
-  if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
-    return emitError("input and result matrices' columns must have "
-                     "the same size");
-
-  // Check that the input and result matrices' have the same component type
-  if (inputMatrix.getElementType() != resultMatrix.getElementType())
-    return emitError("input and result matrices' columns must have "
-                     "the same component type");
-
   return success();
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index 25891c512d38..8cdf2390d723 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -1,13 +1,20 @@
 // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
 
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
-  // CHECK-LABEL: @matrix_times_scalar
-  spirv.func @matrix_times_scalar(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spirv.matrix<3 x vector<3xf32>> "None" {
-    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>>
-    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>>
+  // CHECK-LABEL: @matrix_times_scalar_1
+  spirv.func @matrix_times_scalar_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spirv.matrix<3 x vector<3xf32>> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32
+    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32
     spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
   }
 
+  // CHECK-LABEL: @matrix_times_scalar_2
+  spirv.func @matrix_times_scalar_2(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+    spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup>
+  }
+
   // CHECK-LABEL: @matrix_transpose_1
   spirv.func @matrix_transpose_1(%arg0 : !spirv.matrix<3 x vector<2xf32>>) -> !spirv.matrix<2 x vector<3xf32>> "None" {
     // CHECK: {{%.*}} = spirv.Transpose {{%.*}} : !spirv.matrix<3 x vector<2xf32>> -> !spirv.matrix<2 x vector<3xf32>>
@@ -39,54 +46,42 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
 
 // -----
 
-func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f16) -> () {
+func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f16) {
   // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
-  %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f16 -> !spirv.matrix<3 x vector<3xf32>>
+  %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f16
+  return
 }
 
 // -----
 
-func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f64) -> () {
+func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f64) {
   // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
-  %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f64 -> !spirv.matrix<3 x vector<3xf32>>
-}
-
-// -----
-
-func.func @input_output_component_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () {
-   // expected-error @+1 {{input and result matrices' columns must have the same component type}}
-   %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf64>>
-}
-
-// -----
-
-func.func @input_output_size_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () {
-   // expected-error @+1 {{input and result matrices must have the same number of columns}}
-   %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<4 x vector<3xf32>>
+  %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f64
+  return
 }
 
 // -----
 
-func.func @transpose_op_shape_mismatch_1(%arg0 : !spirv.matrix<3 x vector<4xf32>>) -> () {
+func.func @transpose_op_shape_mismatch_1(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
    // expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
    %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<3 x vector<3xf32>>
-   spirv.Return
+   return
 }
 
 // -----
 
-func.func @transpose_op_shape_mismatch_2(%arg0 : !spirv.matrix<3 x vector<4xf32>>) -> () {
+func.func @transpose_op_shape_mismatch_2(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
    // expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
    %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<2 x vector<4xf32>>
-   spirv.Return
+   return
 }
 
 // -----
 
-func.func @transpose_op_type_mismatch(%arg0 : !spirv.matrix<3 x vector<4xf32>>) -> () {
+func.func @transpose_op_type_mismatch(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
    // expected-error @+1 {{input and output matrices must have the same component type}}
    %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<4 x vector<3xf16>>
-   spirv.Return
+   return
 }
 
 // -----
@@ -94,6 +89,7 @@ func.func @transpose_op_type_mismatch(%arg0 : !spirv.matrix<3 x vector<4xf32>>)
 func.func @matrix_times_matrix_invalid_input_shape_1(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<3xf32>>){
    // expected-error @+1 {{right and result matrices must have equal columns' count}}
    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<3 x vector<2xf32>>
+   return
 }
 
 // -----
@@ -101,6 +97,7 @@ func.func @matrix_times_matrix_invalid_input_shape_1(%arg0 : !spirv.matrix<3 x v
 func.func @matrix_times_matrix_invalid_input_shape_2(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<3xf32>>){
    // expected-error @+1 {{left and result matrices must have equal rows' count}}
    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<2 x vector<3xf32>>
+   return
 }
 
 // -----
@@ -108,6 +105,7 @@ func.func @matrix_times_matrix_invalid_input_shape_2(%arg0 : !spirv.matrix<3 x v
 func.func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<2xf32>>){
    // expected-error @+1 {{left matrix columns' count must be equal to the right matrix rows' count}}
    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<2xf32>> -> !spirv.matrix<2 x vector<2xf32>>
+   return
 }
 
 // -----
@@ -115,6 +113,7 @@ func.func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spirv.matrix<3 x v
 func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
    // expected-error @+1 {{right and result matrices' component type must be the same}}
    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf64>>
+   return
 }
 
 
@@ -123,4 +122,5 @@ func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3
 func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 x vector<3xf64>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
    // expected-error @+1 {{left and result matrices' component type must be the same}}
    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
+   return
 }

diff  --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index 3c534d9e8f8d..0b71b3f24b19 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -10,17 +10,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
 
   // CHECK-LABEL: @matrix_times_scalar_1
   spirv.func @matrix_times_scalar_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spirv.matrix<3 x vector<3xf32>> "None" {
-    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>>
-    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>>
+    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32
+    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32
     spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
   }
 
   // CHECK-LABEL: @matrix_times_scalar_2
   spirv.func @matrix_times_scalar_2(%arg0 : !spirv.matrix<3 x vector<3xf16>>, %arg1 : f16) -> !spirv.matrix<3 x vector<3xf16>> "None" {
-    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf16>>, f16 -> !spirv.matrix<3 x vector<3xf16>>
-    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf16>>, f16 -> !spirv.matrix<3 x vector<3xf16>>
+    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf16>>, f16
+    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf16>>, f16
     spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf16>>
+  }
 
+  // CHECK-LABEL: @matrix_times_scalar_3
+  spirv.func @matrix_times_scalar_3(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+    spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup>
   }
 
   // CHECK-LABEL: @matrix_transpose_1


        


More information about the Mlir-commits mailing list