[Mlir-commits] [mlir] 55d53d4 - [mlir][spirv] Add MatrixTimesScalar operation

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 17 15:34:50 PDT 2020


Author: HazemAbdelhafez
Date: 2020-06-17T18:33:47-04:00
New Revision: 55d53d4f5448db87ab0bec903be94b696c8ed3e8

URL: https://github.com/llvm/llvm-project/commit/55d53d4f5448db87ab0bec903be94b696c8ed3e8
DIFF: https://github.com/llvm/llvm-project/commit/55d53d4f5448db87ab0bec903be94b696c8ed3e8.diff

LOG: [mlir][spirv] Add MatrixTimesScalar operation

Summary:
- Define the MatrixTimesScalar operation and add roundtrip tests.
- Added a new base class for matrix-specific operations to avoid the invalid operand type mismatch check.
- Created a separate Matrix arithmetic operations td file to add more operations in the future.
- Augmented the automatically generated verify method to print more fine-grained error messages.
- Made minor updates to the matrix type tests.

Reviewers: antiagainst, rriddle, mravishankar

Reviewed By: antiagainst

Subscribers: mehdi_amini, jpienaar, shauheen, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, bader, grosul1, frgossen, Kayjukh, jurahul, msifontes

Tags: #mlir

Differential Revision: https://reviews.llvm.org/D81677

Added: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td
    mlir/test/Dialect/SPIRV/matrix-ops.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/Serialization/matrix.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index a95ed18ca4e7..e64abd70333d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -2994,10 +2994,12 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
 def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
 def SPV_IsCooperativeMatrixType :
   CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">;
+def SPV_IsMatrixType : CPred<"$_self.isa<::mlir::spirv::MatrixType>()">;
 def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
 def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
 def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
 
+
 // See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
 // for the definition of the following types and type categories.
 
@@ -3018,6 +3020,8 @@ def SPV_AnyArray : DialectType<SPIRV_Dialect, SPV_IsArrayType,
 def SPV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
                                SPV_IsCooperativeMatrixType,
                                "any SPIR-V cooperative matrix type">;
+def SPV_AnyMatrix : DialectType<SPIRV_Dialect, SPV_IsMatrixType,
+                                "any SPIR-V matrix type">;
 def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
                                  "any SPIR-V runtime array type">;
 def SPV_AnyStruct : DialectType<SPIRV_Dialect, SPV_IsStructType,
@@ -3028,11 +3032,11 @@ def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
 def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
 def SPV_Composite :
     AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
-               SPV_AnyCooperativeMatrix]>;
+               SPV_AnyCooperativeMatrix, SPV_AnyMatrix]>;
 def SPV_Type : AnyTypeOf<[
     SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
     SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
-    SPV_AnyCooperativeMatrix
+    SPV_AnyCooperativeMatrix, SPV_AnyMatrix
   ]>;
 
 def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;
@@ -3160,6 +3164,7 @@ def SPV_OC_OpSRem                      : I32EnumAttrCase<"OpSRem", 138>;
 def SPV_OC_OpSMod                      : I32EnumAttrCase<"OpSMod", 139>;
 def SPV_OC_OpFRem                      : I32EnumAttrCase<"OpFRem", 140>;
 def SPV_OC_OpFMod                      : I32EnumAttrCase<"OpFMod", 141>;
+def SPV_OC_OpMatrixTimesScalar         : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
 def SPV_OC_OpLogicalEqual              : I32EnumAttrCase<"OpLogicalEqual", 164>;
 def SPV_OC_OpLogicalNotEqual           : I32EnumAttrCase<"OpLogicalNotEqual", 165>;
 def SPV_OC_OpLogicalOr                 : I32EnumAttrCase<"OpLogicalOr", 166>;
@@ -3266,14 +3271,14 @@ def SPV_OpcodeAttr :
       SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
       SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
       SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
-      SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
-      SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
-      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
-      SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
-      SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
-      SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
-      SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
-      SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+      SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
+      SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
+      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
+      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
+      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
+      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
+      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
+      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
       SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
       SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
       SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td
new file mode 100644
index 000000000000..c71e0cd54e99
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td
@@ -0,0 +1,75 @@
+//===-- SPIRVMatrixOps.td - MLIR SPIR-V Matrix Ops ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains matrix operations for the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_MATRIX_OPS
+#define SPIRV_MATRIX_OPS
+
+// -----
+
+def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", [NoSideEffect]> {
+  let summary = "Scale a floating-point matrix.";
+
+  let description = [{
+    Result Type must be an OpTypeMatrix whose Column Type is a vector of
+    floating-point type.
+
+    The type of Matrix must be the same as Result Type. Each component in
+    each column in Matrix is multiplied by Scalar.
+
+    Scalar must have the same type as the Component Type in Result Type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    matrix-times-scalar-op ::= ssa-id `=` `spv.MatrixTimesScalar` ssa-use
+    `,` ssa-use `:` matrix-type `,` float-type `->` matrix-type
+
+    ```
+
+    #### Example:
+
+    ```mlir
+
+    %0 = spv.MatrixTimesScalar %matrix, %scalar :
+    !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
+
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_AnyMatrix:$matrix,
+    SPV_Float:$scalar
+  );
+
+  let results = (outs
+    SPV_AnyMatrix:$result
+  );
+
+  // TODO (Hazem): we need just one matrix type given that the input and result
+  // are the same and the scalar's type can be deduced from it.
+  let assemblyFormat = [{
+    operands attr-dict `:` type($matrix) `,` type($scalar) `->` type($result)
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_Matrix]>
+  ];
+
+  let verifier = [{ return verifyMatrixTimesScalar(*this); }];
+}
+
+// -----
+
+#endif // SPIRV_MATRIX_OPS
\ No newline at end of file

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index 520ed14c9624..8b3a25037078 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td"
 include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
 include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
 include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
+include "mlir/Dialect/SPIRV/SPIRVMatrixOps.td"
 include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td"
 include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index aee22c4422e6..efe685873bd7 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2760,6 +2760,49 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.MatrixTimesScalar
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
+  // The ODS-generated verifier has already checked that both the `matrix`
+  // operand and the result are of spirv::MatrixType.
+  auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
+  auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
+
+  // Matrix columns are expected to be vectors; report an error instead of
+  // silently skipping the remaining checks if that ever does not hold.
+  auto inputColumns = inputMatrix.getElementType().dyn_cast<VectorType>();
+  auto resultColumns = resultMatrix.getElementType().dyn_cast<VectorType>();
+  if (!inputColumns || !resultColumns)
+    return op.emitError("input and result matrices' columns must be vectors");
+
+  // Check that the scalar type is the same as the matrix components type.
+  if (op.scalar().getType() != inputColumns.getElementType())
+    return op.emitError("input matrix components' type and scaling "
+                        "value must have the same type");
+
+  // Note that the next three checks could be done using the AllTypesMatch
+  // trait in the Op definition file but it generates a vague error message.
+
+  // Check that the input and result matrices have the same number of columns.
+  if (inputMatrix.getNumElements() != resultMatrix.getNumElements())
+    return op.emitError("input and result matrices must have "
+                        "the same number of columns");
+
+  // Check that the input and result matrices' columns have the same type.
+  if (inputColumns.getElementType() != resultColumns.getElementType())
+    return op.emitError("input and result matrices' columns must "
+                        "have the same component type");
+
+  // Check that the input and result matrices' columns have the same size.
+  if (inputColumns.getNumElements() != resultColumns.getNumElements())
+    return op.emitError("input and result matrices' columns must "
+                        "have the same size");
+
+  return success();
+}
+
 namespace mlir {
 namespace spirv {
 

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
index b27702bf50d8..8dc90cb504e8 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
@@ -1,10 +1,25 @@
 // RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s
 
 spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
-  spv.func @matrix_type(%arg0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>, %arg1 : i32) "None" {
-    // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
-    %2 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
-    spv.Return
+  // CHECK-LABEL: @matrix_access_chain
+  spv.func @matrix_access_chain(%arg0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, Function>, %arg1 : i32) -> !spv.ptr<vector<3xf32>, Function> "None" {
+    // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, Function>
+    %0 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, Function>
+    spv.ReturnValue %0 : !spv.ptr<vector<3xf32>, Function>
+  }
+
+  // CHECK-LABEL: @matrix_times_scalar_1
+  spv.func @matrix_times_scalar_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
+    // CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
+    %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
+    spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
+  }
+
+  // CHECK-LABEL: @matrix_times_scalar_2
+  spv.func @matrix_times_scalar_2(%arg0 : !spv.matrix<3 x vector<3xf16>>, %arg1 : f16) -> !spv.matrix<3 x vector<3xf16>> "None" {
+    // CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf16>>, f16 -> !spv.matrix<3 x vector<3xf16>>
+    %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf16>>, f16 -> !spv.matrix<3 x vector<3xf16>>
+    spv.ReturnValue %result : !spv.matrix<3 x vector<3xf16>>
   }
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/matrix-ops.mlir
new file mode 100644
index 000000000000..8079b4ba88f6
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/matrix-ops.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  // CHECK-LABEL: @matrix_times_scalar_1
+  spv.func @matrix_times_scalar_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
+    // CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
+    %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
+    spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
+  }
+}
+
+// -----
+
+func @input_type_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f16) -> () {
+  // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
+  %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f16 -> !spv.matrix<3 x vector<3xf32>>
+}
+
+// -----
+
+func @input_type_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f64) -> () {
+  // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
+  %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f64 -> !spv.matrix<3 x vector<3xf32>>
+}
+
+// -----
+
+func @input_output_component_type_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () {
+   // expected-error @+1 {{input and result matrices' columns must have the same component type}}
+   %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf64>>
+}
+
+// -----
+
+func @input_output_size_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () {
+   // expected-error @+1 {{input and result matrices must have the same number of columns}}
+   %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<4 x vector<3xf32>>
+}
+
+// -----
+
+func @input_output_column_size_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () {
+   // expected-error @+1 {{input and result matrices' columns must have the same size}}
+   %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<4xf32>>
+}


        


More information about the Mlir-commits mailing list