[Mlir-commits] [mlir] 34c4852 - [mlir][spirv] Add MatrixTimesMatrix operation

Lei Zhang llvmlistbot at llvm.org
Tue Jul 7 18:38:39 PDT 2020


Author: HazemAbdelhafez
Date: 2020-07-07T21:32:50-04:00
New Revision: 34c485201507de2602b70b0df368dbc6c0b0fecb

URL: https://github.com/llvm/llvm-project/commit/34c485201507de2602b70b0df368dbc6c0b0fecb
DIFF: https://github.com/llvm/llvm-project/commit/34c485201507de2602b70b0df368dbc6c0b0fecb.diff

LOG: [mlir][spirv] Add MatrixTimesMatrix operation

Add MatrixTimesMatrix operation to SPIRV Dialect and add NoSideEffect trait
to Matrix ops.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D82671

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
    mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
    mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
    mlir/test/Dialect/SPIRV/matrix-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index fab1f63f7f30..f114c878569d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3166,6 +3166,7 @@ def SPV_OC_OpSMod                      : I32EnumAttrCase<"OpSMod", 139>;
 def SPV_OC_OpFRem                      : I32EnumAttrCase<"OpFRem", 140>;
 def SPV_OC_OpFMod                      : I32EnumAttrCase<"OpFMod", 141>;
 def SPV_OC_OpMatrixTimesScalar         : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
+def SPV_OC_OpMatrixTimesMatrix         : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
 def SPV_OC_OpLogicalEqual              : I32EnumAttrCase<"OpLogicalEqual", 164>;
 def SPV_OC_OpLogicalNotEqual           : I32EnumAttrCase<"OpLogicalNotEqual", 165>;
 def SPV_OC_OpLogicalOr                 : I32EnumAttrCase<"OpLogicalOr", 166>;
@@ -3273,38 +3274,38 @@ def SPV_OpcodeAttr :
       SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
       SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
       SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
-      SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
-      SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
-      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
-      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
-      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
-      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
-      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
-      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
-      SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
-      SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
-      SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
-      SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
-      SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
-      SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
-      SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
-      SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement,
-      SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub,
-      SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax,
-      SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
-      SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
-      SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
-      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine,
-      SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
-      SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd,
-      SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul,
-      SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
-      SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin,
-      SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax,
-      SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
-      SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
-      SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
-      SPV_OC_OpCooperativeMatrixLengthNV
+      SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual,
+      SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd,
+      SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual,
+      SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual,
+      SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan,
+      SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual,
+      SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual,
+      SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan,
+      SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual,
+      SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual,
+      SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical,
+      SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr,
+      SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot,
+      SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract,
+      SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier,
+      SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicCompareExchangeWeak,
+      SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd,
+      SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin,
+      SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd,
+      SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge,
+      SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
+      SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
+      SPV_OC_OpUnreachable, SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed,
+      SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot,
+      SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
+      SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,
+      SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin,
+      SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax,
+      SPV_OC_OpGroupNonUniformUMax, SPV_OC_OpGroupNonUniformFMax,
+      SPV_OC_OpSubgroupBallotKHR, SPV_OC_OpTypeCooperativeMatrixNV,
+      SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV,
+      SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td
index 8545c7ad29e2..70479e743b0d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td
@@ -12,10 +12,65 @@
 
 #ifndef SPIRV_MATRIX_OPS
 #define SPIRV_MATRIX_OPS
+include "mlir/Interfaces/SideEffectInterfaces.td"
 
 // -----
 
-def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
+def SPV_MatrixTimesMatrixOp : SPV_Op<"MatrixTimesMatrix", [NoSideEffect]> {
+  let summary = "Linear-algebraic multiply of LeftMatrix X RightMatrix.";
+
+  let description = [{
+    Result Type must be an OpTypeMatrix whose Column Type is a vector of
+    floating-point type.
+
+    LeftMatrix must be a matrix whose Column Type is the same as the Column
+    Type in Result Type.
+
+    RightMatrix must be a matrix with the same Component Type as the
+    Component Type in Result Type. Its number of columns must equal the
+    number of columns in Result Type. Its columns must have the same number
+    of components as the number of columns in LeftMatrix.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    matrix-times-matrix-op ::= ssa-id `=` `spv.MatrixTimesMatrix` ssa-use,
+    ssa-use `:` matrix-type `,` matrix-type `->` matrix-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.MatrixTimesMatrix %matrix_1, %matrix_2 :
+        !spv.matrix<4 x vector<3xf32>>, !spv.matrix<3 x vector<4xf32>> ->
+        !spv.matrix<3 x vector<3xf32>>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_Matrix]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyMatrix:$leftmatrix,
+    SPV_AnyMatrix:$rightmatrix
+  );
+
+  let results = (outs
+    SPV_AnyMatrix:$result
+  );
+  let assemblyFormat = [{
+    operands attr-dict `:` type($leftmatrix) `,` type($rightmatrix) `->` type($result)
+  }];
+  let verifier = [{ return verifyMatrixTimesMatrix(*this); }];
+}
+
+// -----
+
+def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", [NoSideEffect]> {
   let summary = "Scale a floating-point matrix.";
 
   let description = [{
@@ -79,7 +134,7 @@ def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
 
 // -----
 
-def SPV_TransposeOp : SPV_Op<"Transpose", []> {
+def SPV_TransposeOp : SPV_Op<"Transpose", [NoSideEffect]> {
   let summary = "Transpose a matrix.";
 
   let description = [{

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index d2dac563bfcf..91ad96fed7b0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -410,13 +410,23 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
                                                     Type columnType,
                                                     uint32_t columnCount);
 
-  /// Returns true if the matrix elements are vectors of float elements
+  /// Returns true if the matrix elements are vectors of float elements.
   static bool isValidColumnType(Type columnType);
 
-  Type getElementType() const;
+  Type getColumnType() const;
+
+  /// Returns the number of rows.
+  unsigned getNumRows() const;
+
+  /// Returns the number of columns.
+  unsigned getNumColumns() const;
 
+  /// Returns total number of elements (rows*columns).
   unsigned getNumElements() const;
 
+  /// Returns the elements' type (i.e., single element type).
+  Type getElementType() const;
+
   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                      Optional<spirv::StorageClass> storage = llvm::None);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index fbc644d38ae3..47440265239d 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -723,7 +723,7 @@ static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
 }
 
 static void print(MatrixType type, DialectAsmPrinter &os) {
-  os << "matrix<" << type.getNumElements() << " x " << type.getElementType();
+  os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
   os << ">";
 }
 

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 32d13e6afd61..1f1b6e11c336 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2779,37 +2779,30 @@ static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
   // auto-generated verify method.
 
   auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
-  // Check that the scalar type is the same as the matrix components type.
-  if (auto inputMatrixColumns =
-          inputMatrix.getElementType().dyn_cast<VectorType>()) {
-    if (op.scalar().getType() != inputMatrixColumns.getElementType())
-      return op.emitError("input matrix components' type and scaling "
-                          "value must have the same type");
-
-    // Note that the next three checks could be done using the AllTypesMatch
-    // trait in the Op definition file but it generates a vague error message.
-
-    // Check that the input and result matrices have the same size
-    auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
-    if (inputMatrix.getNumElements() != resultMatrix.getNumElements())
-      return op.emitError("input and result matrices must have "
-                          "the same number of columns");
-
-    if (auto resultMatrixColumns =
-            resultMatrix.getElementType().dyn_cast<VectorType>()) {
-      // Check that the input and result matrices' columns have the same type
-      if (inputMatrixColumns.getElementType() !=
-          resultMatrixColumns.getElementType())
-        return op.emitError("input and result matrices' columns must "
-                            "have the same component type");
-
-      // Check that the input and result matrices' columns have the same size
-      if (inputMatrixColumns.getNumElements() !=
-          resultMatrixColumns.getNumElements())
-        return op.emitError("input and result matrices' columns must "
-                            "have the same size");
-    }
-  }
+  auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
+
+  // Check that the scalar type is the same as the matrix element type.
+  if (op.scalar().getType() != inputMatrix.getElementType())
+    return op.emitError("input matrix components' type and scaling value must "
+                        "have the same type");
+
+  // Note that the next three checks could be done using the AllTypesMatch
+  // trait in the Op definition file but it generates a vague error message.
+
+  // Check that the input and result matrices have the same columns' count
+  if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
+    return op.emitError("input and result matrices must have the same "
+                        "number of columns");
+
+  // Check that the input and result matrices' have the same rows count
+  if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
+    return op.emitError("input and result matrices' columns must have "
+                        "the same size");
+
+  // Check that the input and result matrices' have the same component type
+  if (inputMatrix.getElementType() != resultMatrix.getElementType())
+    return op.emitError("input and result matrices' columns must have "
+                        "the same component type");
 
   return success();
 }
@@ -2902,24 +2895,56 @@ static LogicalResult verifyTranspose(spirv::TransposeOp op) {
   auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
 
   // Verify that the input and output matrices have correct shapes.
-  if (auto inputMatrixColumns =
-          inputMatrix.getElementType().dyn_cast<VectorType>()) {
-    if (inputMatrixColumns.getNumElements() != resultMatrix.getNumElements())
-      return op.emitError("input matrix rows count must be equal to "
-                          "output matrix columns count");
-    if (auto resultMatrixColumns =
-            resultMatrix.getElementType().dyn_cast<VectorType>()) {
-      if (resultMatrixColumns.getNumElements() != inputMatrix.getNumElements())
-        return op.emitError("input matrix columns count must be equal "
-                            "to output matrix rows count");
-
-      // Verify that the input and output matrices have the same component type
-      if (inputMatrixColumns.getElementType() !=
-          resultMatrixColumns.getElementType())
-        return op.emitError("input and output matrices must have the "
-                            "same component type");
-    }
-  }
+  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
+    return op.emitError("input matrix rows count must be equal to "
+                        "output matrix columns count");
+
+  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
+    return op.emitError("input matrix columns count must be equal to "
+                        "output matrix rows count");
+
+  // Verify that the input and output matrices have the same component type
+  if (inputMatrix.getElementType() != resultMatrix.getElementType())
+    return op.emitError("input and output matrices must have the same "
+                        "component type");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.MatrixTimesMatrix
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) {
+  auto leftMatrix = op.leftmatrix().getType().cast<spirv::MatrixType>();
+  auto rightMatrix = op.rightmatrix().getType().cast<spirv::MatrixType>();
+  auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
+
+  // left matrix columns' count and right matrix rows' count must be equal
+  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
+    return op.emitError("left matrix columns' count must be equal to "
+                        "the right matrix rows' count");
+
+  // right and result matrices columns' count must be the same
+  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
+    return op.emitError(
+        "right and result matrices must have equal columns' count");
+
+  // right and result matrices component type must be the same
+  if (rightMatrix.getElementType() != resultMatrix.getElementType())
+    return op.emitError("right and result matrices' component type must"
+                        " be the same");
+
+  // left and result matrices component type must be the same
+  if (leftMatrix.getElementType() != resultMatrix.getElementType())
+    return op.emitError("left and result matrices' component type"
+                        " must be the same");
+
+  // left and result matrices rows count must be the same
+  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
+    return op.emitError("left and result matrices must have equal rows'"
+                        " count");
+
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index b0396bfc1163..07ffe4046483 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -182,7 +182,7 @@ Type CompositeType::getElementType(unsigned index) const {
   case spirv::TypeKind::CooperativeMatrix:
     return cast<CooperativeMatrixNVType>().getElementType();
   case spirv::TypeKind::Matrix:
-    return cast<MatrixType>().getElementType();
+    return cast<MatrixType>().getColumnType();
   case spirv::TypeKind::RuntimeArray:
     return cast<RuntimeArrayType>().getElementType();
   case spirv::TypeKind::Struct:
@@ -202,7 +202,7 @@ unsigned CompositeType::getNumElements() const {
     llvm_unreachable(
         "invalid to query number of elements of spirv::CooperativeMatrix type");
   case spirv::TypeKind::Matrix:
-    return cast<MatrixType>().getNumElements();
+    return cast<MatrixType>().getNumColumns();
   case spirv::TypeKind::RuntimeArray:
     llvm_unreachable(
         "invalid to query number of elements of spirv::RuntimeArray type");
@@ -1086,13 +1086,25 @@ bool MatrixType::isValidColumnType(Type columnType) {
   return false;
 }
 
-Type MatrixType::getElementType() const { return getImpl()->columnType; }
+Type MatrixType::getColumnType() const { return getImpl()->columnType; }
 
-unsigned MatrixType::getNumElements() const { return getImpl()->columnCount; }
+Type MatrixType::getElementType() const {
+  return getImpl()->columnType.cast<VectorType>().getElementType();
+}
+
+unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
+
+unsigned MatrixType::getNumRows() const {
+  return getImpl()->columnType.cast<VectorType>().getShape()[0];
+}
+
+unsigned MatrixType::getNumElements() const {
+  return (getImpl()->columnCount) * getNumRows();
+}
 
 void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                                Optional<StorageClass> storage) {
-  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
+  getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
 }
 
 void MatrixType::getCapabilities(
@@ -1104,5 +1116,5 @@ void MatrixType::getCapabilities(
     capabilities.push_back(ref);
   }
   // Add any capabilities associated with the underlying vectors (i.e., columns)
-  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
+  getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
 }

diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 8f6e02de27e7..b015c2391a06 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1127,12 +1127,12 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
 
   if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
     uint32_t elementTypeID = 0;
-    if (failed(processType(loc, matrixType.getElementType(), elementTypeID))) {
+    if (failed(processType(loc, matrixType.getColumnType(), elementTypeID))) {
       return failure();
     }
     typeEnum = spirv::Opcode::OpTypeMatrix;
     operands.push_back(elementTypeID);
-    operands.push_back(matrixType.getNumElements());
+    operands.push_back(matrixType.getNumColumns());
     return success();
   }
 

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
index 6db8a7666768..8900ff6137db 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
@@ -29,6 +29,20 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
     spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>>
   }
+
+  // CHECK-LABEL: @matrix_times_matrix_1
+  spv.func @matrix_times_matrix_1(%arg0: !spv.matrix<3 x vector<3xf32>>, %arg1: !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None"{
+    // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+    %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+    spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
+  }
+
+  // CHECK-LABEL: @matrix_times_matrix_2
+  spv.func @matrix_times_matrix_2(%arg0: !spv.matrix<3 x vector<2xf32>>, %arg1: !spv.matrix<2 x vector<3xf32>>) -> !spv.matrix<2 x vector<2xf32>> "None"{
+    // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
+    %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
+    spv.ReturnValue %result : !spv.matrix<2 x vector<2xf32>>
+  }
 }
 
 // -----

diff  --git a/mlir/test/Dialect/SPIRV/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/matrix-ops.mlir
index 09bdf3983005..4ec78344bc38 100644
--- a/mlir/test/Dialect/SPIRV/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/matrix-ops.mlir
@@ -21,6 +21,20 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
     spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
   }
+
+  // CHECK-LABEL: @matrix_times_matrix_1
+  spv.func @matrix_times_matrix_1(%arg0: !spv.matrix<3 x vector<3xf32>>, %arg1: !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None"{
+    // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+    %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+    spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
+  }
+
+  // CHECK-LABEL: @matrix_times_matrix_2
+  spv.func @matrix_times_matrix_2(%arg0: !spv.matrix<3 x vector<2xf32>>, %arg1: !spv.matrix<2 x vector<3xf32>>) -> !spv.matrix<2 x vector<2xf32>> "None"{
+    // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
+    %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
+    spv.ReturnValue %result : !spv.matrix<2 x vector<2xf32>>
+  }
 }
 
 // -----
@@ -74,3 +88,39 @@ func @transpose_op_type_mismatch(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
    %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<4 x vector<3xf16>>
    spv.Return
 }
+
+// -----
+
+func @matrix_times_matrix_invalid_input_shape_1(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<3xf32>>){
+   // expected-error @+1 {{right and result matrices must have equal columns' count}}
+   %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<3 x vector<2xf32>>
+}
+
+// -----
+
+func @matrix_times_matrix_invalid_input_shape_2(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<3xf32>>){
+   // expected-error @+1 {{left and result matrices must have equal rows' count}}
+   %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<3xf32>>
+}
+
+// -----
+
+func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<2xf32>>){
+   // expected-error @+1 {{left matrix columns' count must be equal to the right matrix rows' count}}
+   %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<2xf32>> -> !spv.matrix<2 x vector<2xf32>>
+}
+
+// -----
+
+func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : !spv.matrix<3x vector<3xf32>>){
+   // expected-error @+1 {{right and result matrices' component type must be the same}}
+   %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf64>>
+}
+
+
+// -----
+
+func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spv.matrix<3 x vector<3xf64>>, %arg1 : !spv.matrix<3x vector<3xf32>>){
+   // expected-error @+1 {{left and result matrices' component type must be the same}}
+   %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf64>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+}


        


More information about the Mlir-commits mailing list