[Mlir-commits] [mlir] [mlir][spirv] Add `CooperativeMatrixMulAdd` op (PR #65617)

Jakub Kuderski llvmlistbot at llvm.org
Thu Sep 7 08:09:58 PDT 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/65617:

>From f8bc88836262532f383e6fbdead003ca64b03e15 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <kubak at google.com>
Date: Wed, 19 Jul 2023 10:48:14 -0400
Subject: [PATCH 1/2] [mlir][spirv] Add `CooperativeMatrixMulAdd` op

This is the last remaining op from the `SPV_KHR_cooperative_matrix` extension.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  21 +-
 .../SPIRV/IR/SPIRVCooperativeMatrixOps.td     | 108 ++++++++-
 .../Dialect/SPIRV/IR/CooperativeMatrixOps.cpp |  57 +++++
 .../SPIRV/IR/cooperative-matrix-ops.mlir      | 214 ++++++++++++++++++
 4 files changed, 398 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 0ac660b8c1c3e1..cc4417077d459c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4071,6 +4071,23 @@ def SPIRV_KHR_CooperativeMatrixLayoutAttr :
       SPIRV_KHR_CML_RowMajor, SPIRV_KHR_CML_ColumnMajor
     ]>;
 
+// Cooperative Matrix Operands (KHR). CaseBit takes a bit index, not a mask.
+def SPIRV_KHR_CMO_None           : I32BitEnumAttrCaseNone<"None">;
+def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 0>;
+def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 1>;
+def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 2>;
+def SPIRV_KHR_CMO_Result_Signed  : I32BitEnumAttrCaseBit<"ResultSigned", 3>;
+def SPIRV_KHR_CMO_AccSat         : I32BitEnumAttrCaseBit<"AccSat", 4>;
+
+def SPIRV_KHR_CooperativeMatrixOperandsAttr :
+    SPIRV_BitEnumAttr<"CooperativeMatrixOperandsKHR",
+                      "valid SPIR-V Cooperative Matrix Operands (KHR)",
+                      "cooperative_matrix_operands_khr", [
+      SPIRV_KHR_CMO_None, SPIRV_KHR_CMO_MatrixA_Signed,
+      SPIRV_KHR_CMO_MatrixB_Signed, SPIRV_KHR_CMO_MatrixC_Signed,
+      SPIRV_KHR_CMO_Result_Signed, SPIRV_KHR_CMO_AccSat
+    ]>;
+
 //===----------------------------------------------------------------------===//
 // SPIR-V attribute definitions
 //===----------------------------------------------------------------------===//
@@ -4447,6 +4464,7 @@ def SPIRV_OC_OpSUDotAccSat                : I32EnumAttrCase<"OpSUDotAccSat", 445
 def SPIRV_OC_OpTypeCooperativeMatrixKHR   : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>;
 def SPIRV_OC_OpCooperativeMatrixLoadKHR   : I32EnumAttrCase<"OpCooperativeMatrixLoadKHR", 4457>;
 def SPIRV_OC_OpCooperativeMatrixStoreKHR  : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
+def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
 def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
 def SPIRV_OC_OpTypeCooperativeMatrixNV    : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
 def SPIRV_OC_OpCooperativeMatrixLoadNV    : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
@@ -4548,7 +4566,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
       SPIRV_OC_OpSUDotAccSat,
       SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
-      SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+      SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR,
+      SPIRV_OC_OpCooperativeMatrixLengthKHR,
       SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV,
       SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV,
       SPIRV_OC_OpCooperativeMatrixLengthNV,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 6de744039483b9..7060aa80dc113e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -203,6 +203,112 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
   let results = (outs);
 }
 
+// -----
+
+def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMulAdd",
+  [Pure, AllTypesMatch<["c", "result"]>]> {
+  let summary = "Returns the result of `(A x B) + C` of matrices A, B, and C";
+
+  let description = [{
+    Linear-algebraic matrix multiply of A by B and then component-wise add C.
+    The order of the operations is implementation-dependent. The internal
+    precision of floating-point operations is defined by the client API. Integer
+    operations used in the multiplication of A by B are performed at the
+    precision of the Result Type and the resulting value will equal the
+    low-order N bits of the correct result R, where N is the result width and R
+    is computed with enough precision to avoid overflow and underflow if the
+    SaturatingAccumulation Cooperative Matrix Operand is not present. If the
+    SaturatingAccumulation Cooperative Matrix Operand is present and overflow or
+    underflow occurs as part of calculating that intermediate result, the result
+    of the instruction is undefined. Integer additions of the elements of that
+    intermediate result with those of C are performed at the precision of Result
+    Type, are exact, and are saturating if the SaturatingAccumulation
+    Cooperative Matrix Operand is present, with the signedness of the saturation
+    being that of the components of Result Type. If the SaturatingAccumulation
+    Cooperative Matrix Operand is not present then the resulting value will
+    equal the low-order N bits of the correct result R, where N is the result
+    width and R is computed with enough precision to avoid overflow and
+    underflow.
+
+    Result Type must be a cooperative matrix type with M rows and N columns
+    whose Use must be MatrixAccumulatorKHR.
+
+    A is a cooperative matrix with M rows and K columns whose Use must be
+    MatrixAKHR.
+
+    B is a cooperative matrix with K rows and N columns whose Use must be
+    MatrixBKHR.
+
+    C is a cooperative matrix with M rows and N columns whose Use must be
+    MatrixAccumulatorKHR.
+
+    The values of M, N, and K must be consistent across the result and operands.
+    This is referred to as an MxNxK matrix multiply.
+
+    A, B, C, and Result Type must have the same scope, and this defines the
+    scope of the operation. A, B, C, and Result Type need not necessarily have
+    the same component type, this is defined by the client API.
+
+    If the Component Type of any matrix operand is an integer type, then its
+    components are treated as signed if the Matrix{A,B,C,Result}SignedComponents
+    Cooperative Matrix Operand is present and are treated as unsigned otherwise.
+
+    Cooperative Matrix Operands is an optional Cooperative Matrix Operand
+    literal. If not present, it is the same as specifying the Cooperative Matrix
+    Operand None.
+
+    For a given dynamic instance of this instruction, all invocations in a given
+    scope instance must be active or all must be inactive (where the scope is
+    the scope of the operation).
+
+    ``` {.ebnf}
+    cooperative-matrixmuladd-op ::= ssa-id `=` `spirv.KHR.CooperativeMatrixMulAdd`
+                              ssa-use `,` ssa-use `,` ssa-use
+                              (`,` `<` matrix-operands `>`)? `:`
+                              a-cooperative-matrix-type `,`
+                              b-cooperative-matrix-type `->`
+                                result-cooperative-matrix-type
+    ```
+
+    #### Example:
+
+    ```
+    %0 = spirv.KHR.CooperativeMatrixMulAdd %matA, %matB, %matC :
+      !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
+      !spirv.coopmatrix<4x4xf32, Subgroup, MatrixB> ->
+        !spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>
+
+    %1 = spirv.KHR.CooperativeMatrixMulAdd %matA, %matB, %matC, <ASigned | AccSat> :
+      !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+      !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+        !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $a `,` $b `,` $c ( `,` $matrix_operands^ )? attr-dict `:`
+      type($a) `,` type($b) `->` type($c)
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_6>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_KHR_cooperative_matrix]>,
+    Capability<[SPIRV_C_CooperativeMatrixKHR]>
+  ];
+
+  let arguments = (ins
+    SPIRV_AnyCooperativeMatrix:$a,
+    SPIRV_AnyCooperativeMatrix:$b,
+    SPIRV_AnyCooperativeMatrix:$c,
+    OptionalAttr<SPIRV_KHR_CooperativeMatrixOperandsAttr>:$matrix_operands
+  );
+
+  let results = (outs
+    SPIRV_AnyCooperativeMatrix:$result
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // SPV_NV_cooperative_matrix extension ops.
 //===----------------------------------------------------------------------===//
@@ -380,7 +486,7 @@ def SPIRV_NVCooperativeMatrixMulAddOp : SPIRV_NvVendorOp<"CooperativeMatrixMulAd
   }];
 
   let assemblyFormat = [{
-    operands attr-dict`:` type($a) `,` type($b) `->` type($c)
+    operands attr-dict `:` type($a) `,` type($b) `->` type($c)
   }];
 
   let availability = [
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index bdd87677866501..cee39d7791aaf8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -11,7 +11,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRVParsingUtils.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "llvm/ADT/STLExtras.h"
+#include <cstdint>
 
 using namespace mlir::spirv::AttrNames;
 
@@ -151,6 +154,60 @@ LogicalResult KHRCooperativeMatrixStoreOp::verify() {
                                         getObject().getType());
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixMulAdd
+//===----------------------------------------------------------------------===//
+
+LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
+  auto typeA = cast<spirv::CooperativeMatrixType>(getA().getType());
+  auto typeB = cast<spirv::CooperativeMatrixType>(getB().getType());
+  auto typeC = cast<spirv::CooperativeMatrixType>(getC().getType());
+
+  // Check element types. ODS enforces that `type(c) == type(result)`, so no
+  // need to check it here.
+  if (!llvm::all_equal({typeA.getElementType(), typeB.getElementType()}))
+    return emitOpError("matrix A and matrix B element type mismatch");
+
+  // Check the 'use' part of the type against the operands and the result.
+  if (typeA.getUse() != CooperativeMatrixUseKHR::MatrixA)
+    return emitOpError("operand #0 must be of use 'MatrixA'");
+  if (typeB.getUse() != CooperativeMatrixUseKHR::MatrixB)
+    return emitOpError("operand #1 must be of use 'MatrixB'");
+  if (typeC.getUse() != CooperativeMatrixUseKHR::MatrixAcc)
+    return emitOpError("operand #2 must be of use 'MatrixAcc'");
+
+  // A, B, and C must share a single scope, which is the op's scope.
+  if (!llvm::all_equal({typeA.getScope(), typeB.getScope(), typeC.getScope()}))
+    return emitOpError("matrix scope mismatch");
+
+  // Check dimension sizes: A is MxK, B is KxN, C (and the result) is MxN.
+  if (typeA.getRows() != typeC.getRows())
+    return emitOpError("matrix size mismatch on dimension 'M'");
+  if (typeB.getColumns() != typeC.getColumns())
+    return emitOpError("matrix size mismatch on dimension 'N'");
+  if (typeA.getColumns() != typeB.getRows())
+    return emitOpError("matrix size mismatch on dimension 'K'");
+
+  // The spec does not restrict the element types:
+  //  > A, B, C, and Result Type need not necessarily have the same component
+  //  > type, this is defined by the client API.
+
+  // Cooperative Matrix Operands only encode integer signedness/saturation, so
+  // reject them unless the A, B, and C element types are all integers.
+  if (getMatrixOperands()) {
+    Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
+                           typeC.getElementType()};
+    if (!llvm::all_of(elementTypes,
+                      [](Type ty) { return isa<IntegerType>(ty); })) {
+      return emitOpError("Matrix Operands require all matrix element types to "
+                         "be Integer Types");
+    }
+  }
+
+  // Any further requirements need to be checked against VCE.
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.NV.CooperativeMatrixLength
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index ce9a61a6277a04..4b77f1bbf5b622 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -146,6 +146,220 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, Stor
 
 // -----
 
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+                                      %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+                                      %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+// Integer matrices with optional Cooperative Matrix Operands.
+spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+                                                      %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+                                                      %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+  %p = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <AccSat> :
+        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+  %q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <ASigned | BSigned> :
+        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <ASigned | BSigned | AccSat> :
+        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+// Floating-point matrices without Cooperative Matrix Operands.
+spirv.func @cooperative_matrix_muladd_f32(%a : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
+                                          %b : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixB>,
+                                          %c : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>) "None" {
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+        !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
+        !spirv.coopmatrix<4x4xf32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+// C (and the result) may have a different element type than A and B.
+spirv.func @cooperative_matrix_muladd_i8_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                             %b : !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB>,
+                                             %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+        !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
+spirv.func @cooperative_matrix_muladd_workgroup(%a : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixA>,
+                                                %b : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixB>,
+                                                %c : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixAcc>) "None" {
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+        !spirv.coopmatrix<4x4xf16, Workgroup, MatrixA>,
+        !spirv.coopmatrix<4x4xf16, Workgroup, MatrixB> ->
+          !spirv.coopmatrix<4x4xf16, Workgroup, MatrixAcc>
+  spirv.Return
+}
+
+// -----
+// Rejected: operand A has use MatrixB instead of MatrixA.
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+                                      %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+                                      %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #0 must be of use 'MatrixA'}}
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
+// -----
+// Parse error: the C operand is missing.
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+                                      %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>) "None" {
+  // expected-error @+1 {{expected ','}}
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b :
+        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
+// -----
+// Parse error: matrix operands given in place of the C operand.
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+                                      %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>) "None" {
+  // expected-error @+1 {{expected SSA operand}}
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, <ASigned> :
+        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
+// -----
+// Parse error: a fourth SSA value where matrix operands are expected.
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+                                      %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+                                      %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+  // expected-error @+1 {{expected '<'}}
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, %c :
+        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
+// -----
+// Rejected: operand B has use MatrixA instead of MatrixB.
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+                                      %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixA>,
+                                      %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #1 must be of use 'MatrixB'}}
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixA> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
+// -----
+// Rejected: operand C has use MatrixB instead of MatrixAcc.
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+                                      %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+                                      %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixB>) "None" {
+  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #2 must be of use 'MatrixAcc'}}
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixB>
+  spirv.Return
+}
+
+// -----
+// Rejected: A has 8 rows but C has 10 (dimension M).
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+                                      %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+                                      %c : !spirv.coopmatrix<10x4xi32, Subgroup, MatrixAcc>) "None" {
+  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'M'}}
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<10x4xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
+// -----
+// Rejected: B has 16 columns but C has 4 (dimension N).
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+                                      %b : !spirv.coopmatrix<4x16xi32, Subgroup, MatrixB>,
+                                      %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'N'}}
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+        !spirv.coopmatrix<4x16xi32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
+// -----
+// Rejected: A has 16 columns but B has 12 rows (dimension K).
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+                                      %b : !spirv.coopmatrix<12x4xi32, Subgroup, MatrixB>,
+                                      %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'K'}}
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+        !spirv.coopmatrix<12x4xi32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
+// -----
+// Rejected: C has Workgroup scope while A and B have Subgroup scope.
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+                                      %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+                                      %c : !spirv.coopmatrix<8x4xi32, Workgroup, MatrixAcc>) "None" {
+  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix scope mismatch}}
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Workgroup, MatrixAcc>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd_i8(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                         %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+                                         %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix A and matrix B element type mismatch}}
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+        !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x16xf16, Subgroup, MatrixA>,
+                                                      %b : !spirv.coopmatrix<16x4xf16, Subgroup, MatrixB>,
+                                                      %c : !spirv.coopmatrix<8x4xf16, Subgroup, MatrixAcc>) "None" {
+  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op Matrix Operands require all matrix element types to be Integer Types}}
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <AccSat> :
+        !spirv.coopmatrix<8x16xf16, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xf16, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xf16, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // NV.CooperativeMatrix
 //===----------------------------------------------------------------------===//

>From 8c4ae505ec6db16c9d1c9d798a2c561ee815201d Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Thu, 7 Sep 2023 11:09:29 -0400
Subject: [PATCH 2/2] Allow matrix A and matrix B element types to differ

---
 .../Dialect/SPIRV/IR/CooperativeMatrixOps.cpp |  2 --
 .../SPIRV/IR/cooperative-matrix-ops.mlir      | 23 ++++++++-----------
 2 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index cee39d7791aaf8..bc1d30f5551830 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -165,8 +165,6 @@ LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
 
   // Check element types. ODS enforces that `type(c) == type(result)`, so no
   // need to check it here.
-  if (!llvm::all_equal({typeA.getElementType(), typeB.getElementType()}))
-    return emitOpError("matrix A and matrix B element type mismatch");
 
   // Check the 'use' part of the type against the operands and the result.
   if (typeA.getUse() != CooperativeMatrixUseKHR::MatrixA)
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 4b77f1bbf5b622..aa6e072b03c5d3 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -194,6 +194,16 @@ spirv.func @cooperative_matrix_muladd_i8_i32(%a : !spirv.coopmatrix<8x16xi8, Sub
   spirv.Return
 }
 
+spirv.func @cooperative_matrix_muladd_i8_i16_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                                 %b : !spirv.coopmatrix<16x4xi16, Subgroup, MatrixB>,
+                                                 %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c : // A, B, and C element types may all differ.
+        !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xi16, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
 spirv.func @cooperative_matrix_muladd_workgroup(%a : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixA>,
                                                 %b : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixB>,
                                                 %c : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixAcc>) "None" {
@@ -334,19 +344,6 @@ spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup,
 
 // -----
 
-spirv.func @cooperative_matrix_muladd_i8(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
-                                         %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
-                                         %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
-  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix A and matrix B element type mismatch}}
-  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
-        !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
-        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
-          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
-  spirv.Return
-}
-
-// -----
-
 spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x16xf16, Subgroup, MatrixA>,
                                                       %b : !spirv.coopmatrix<16x4xf16, Subgroup, MatrixB>,
                                                       %c : !spirv.coopmatrix<8x4xf16, Subgroup, MatrixAcc>) "None" {



More information about the Mlir-commits mailing list