[Mlir-commits] [mlir] [mlir][spirv] Add `CooperativeMatrixMulAdd` op (PR #65617)
Jakub Kuderski
llvmlistbot at llvm.org
Thu Sep 7 08:09:58 PDT 2023
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/65617:
>From f8bc88836262532f383e6fbdead003ca64b03e15 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <kubak at google.com>
Date: Wed, 19 Jul 2023 10:48:14 -0400
Subject: [PATCH 1/2] [mlir][spirv] Add `CooperativeMatrixMulAdd` op
This is the last remaining op from the `SPV_KHR_cooperative_matrix` extension.
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 21 +-
.../SPIRV/IR/SPIRVCooperativeMatrixOps.td | 108 ++++++++-
.../Dialect/SPIRV/IR/CooperativeMatrixOps.cpp | 57 +++++
.../SPIRV/IR/cooperative-matrix-ops.mlir | 214 ++++++++++++++++++
4 files changed, 398 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 0ac660b8c1c3e1..cc4417077d459c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4071,6 +4071,23 @@ def SPIRV_KHR_CooperativeMatrixLayoutAttr :
SPIRV_KHR_CML_RowMajor, SPIRV_KHR_CML_ColumnMajor
]>;
+// Cooperative Matrix Operands for the SPV_KHR_cooperative_matrix extension.
+def SPIRV_KHR_CMO_None : I32BitEnumAttrCaseNone<"None">;
+def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 1>;
+def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 2>;
+def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 4>;
+def SPIRV_KHR_CMO_Result_Signed : I32BitEnumAttrCaseBit<"ResultSigned", 8>;
+def SPIRV_KHR_CMO_AccSat : I32BitEnumAttrCaseBit<"AccSat", 16>;
+
+def SPIRV_KHR_CooperativeMatrixOperandsAttr :
+ SPIRV_BitEnumAttr<"CooperativeMatrixOperandsKHR",
+ "valid SPIR-V Cooperative Matrix Operands (KHR)",
+ "cooperative_matrix_operands_khr", [
+ SPIRV_KHR_CMO_None, SPIRV_KHR_CMO_MatrixA_Signed,
+ SPIRV_KHR_CMO_MatrixB_Signed, SPIRV_KHR_CMO_MatrixC_Signed,
+ SPIRV_KHR_CMO_Result_Signed, SPIRV_KHR_CMO_AccSat
+ ]>;
+
//===----------------------------------------------------------------------===//
// SPIR-V attribute definitions
//===----------------------------------------------------------------------===//
@@ -4447,6 +4464,7 @@ def SPIRV_OC_OpSUDotAccSat : I32EnumAttrCase<"OpSUDotAccSat", 445
def SPIRV_OC_OpTypeCooperativeMatrixKHR : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>;
def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMatrixLoadKHR", 4457>;
def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
+def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
@@ -4548,7 +4566,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
SPIRV_OC_OpSUDotAccSat,
SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
- SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+ SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR,
+ SPIRV_OC_OpCooperativeMatrixLengthKHR,
SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV,
SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV,
SPIRV_OC_OpCooperativeMatrixLengthNV,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 6de744039483b9..7060aa80dc113e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -203,6 +203,112 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
let results = (outs);
}
+// -----
+
+def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMulAdd",
+ [Pure, AllTypesMatch<["c", "result"]>]> {
+ let summary = "Returns the result of `(A x B) + C` of matrices A, B, and C";
+
+ let description = [{
+ Linear-algebraic matrix multiply of A by B and then component-wise add C.
+ The order of the operations is implementation-dependent. The internal
+ precision of floating-point operations is defined by the client API. Integer
+ operations used in the multiplication of A by B are performed at the
+ precision of the Result Type and the resulting value will equal the
+ low-order N bits of the correct result R, where N is the result width and R
+ is computed with enough precision to avoid overflow and underflow if the
+ SaturatingAccumulation Cooperative Matrix Operand is not present. If the
+ SaturatingAccumulation Cooperative Matrix Operand is present and overflow or
+ underflow occurs as part of calculating that intermediate result, the result
+ of the instruction is undefined. Integer additions of the elements of that
+ intermediate result with those of C are performed at the precision of Result
+ Type, are exact, and are saturating if the SaturatingAccumulation
+ Cooperative Matrix Operand is present, with the signedness of the saturation
+ being that of the components of Result Type. If the SaturatingAccumulation
+ Cooperative Matrix Operand is not present then the resulting value will
+ equal the low-order N bits of the correct result R, where N is the result
+ width and R is computed with enough precision to avoid overflow and
+ underflow.
+
+ Result Type must be a cooperative matrix type with M rows and N columns
+ whose Use must be MatrixAccumulatorKHR.
+
+ A is a cooperative matrix with M rows and K columns whose Use must be
+ MatrixAKHR.
+
+ B is a cooperative matrix with K rows and N columns whose Use must be
+ MatrixBKHR.
+
+ C is a cooperative matrix with M rows and N columns whose Use must be
+ MatrixAccumulatorKHR.
+
+ The values of M, N, and K must be consistent across the result and operands.
+ This is referred to as an MxNxK matrix multiply.
+
+ A, B, C, and Result Type must have the same scope, and this defines the
+ scope of the operation. A, B, C, and Result Type need not necessarily have
+ the same component type, this is defined by the client API.
+
+ If the Component Type of any matrix operand is an integer type, then its
+ components are treated as signed if the Matrix{A,B,C,Result}SignedComponents
+ Cooperative Matrix Operand is present and are treated as unsigned otherwise.
+
+ Cooperative Matrix Operands is an optional Cooperative Matrix Operand
+ literal. If not present, it is the same as specifying the Cooperative Matrix
+ Operand None.
+
+ For a given dynamic instance of this instruction, all invocations in a given
+ scope instance must be active or all must be inactive (where the scope is
+ the scope of the operation).
+
+ ``` {.ebnf}
+ cooperative-matrixmuladd-op ::= ssa-id `=` `spirv.KHR.CooperativeMatrixMulAdd`
+ ssa-use `,` ssa-use `,` ssa-use
+ (`,` `<` matrix-operands `>`)? `:`
+ a-cooperative-matrix-type `,`
+ b-cooperative-matrix-type `->`
+ result-cooperative-matrix-type
+ ```
+
+ #### Example:
+
+ ```
+ %0 = spirv.KHR.CooperativeMatrixMulAdd %matA, %matB, %matC :
+ !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<4x4xf32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>
+
+ %1 = spirv.KHR.CooperativeMatrixMulAdd %matA, %matB, %matC, <ASigned | AccSat> :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ ```
+ }];
+
+ let assemblyFormat = [{
+ $a `,` $b `,` $c ( `,` $matrix_operands^ )? attr-dict `:`
+ type($a) `,` type($b) `->` type($c)
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_6>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_KHR_cooperative_matrix]>,
+ Capability<[SPIRV_C_CooperativeMatrixKHR]>
+ ];
+
+ let arguments = (ins
+ SPIRV_AnyCooperativeMatrix:$a,
+ SPIRV_AnyCooperativeMatrix:$b,
+ SPIRV_AnyCooperativeMatrix:$c,
+ OptionalAttr<SPIRV_KHR_CooperativeMatrixOperandsAttr>:$matrix_operands
+ );
+
+ let results = (outs
+ SPIRV_AnyCooperativeMatrix:$result
+ );
+}
+
//===----------------------------------------------------------------------===//
// SPV_NV_cooperative_matrix extension ops.
//===----------------------------------------------------------------------===//
@@ -380,7 +486,7 @@ def SPIRV_NVCooperativeMatrixMulAddOp : SPIRV_NvVendorOp<"CooperativeMatrixMulAd
}];
let assemblyFormat = [{
- operands attr-dict`:` type($a) `,` type($b) `->` type($c)
+ operands attr-dict `:` type($a) `,` type($b) `->` type($c)
}];
let availability = [
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index bdd87677866501..cee39d7791aaf8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -11,7 +11,10 @@
//===----------------------------------------------------------------------===//
#include "SPIRVParsingUtils.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "llvm/ADT/STLExtras.h"
+#include <cstdint>
using namespace mlir::spirv::AttrNames;
@@ -151,6 +154,60 @@ LogicalResult KHRCooperativeMatrixStoreOp::verify() {
getObject().getType());
}
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixMulAdd
+//===----------------------------------------------------------------------===//
+
+LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
+ auto typeA = cast<spirv::CooperativeMatrixType>(getA().getType());
+ auto typeB = cast<spirv::CooperativeMatrixType>(getB().getType());
+ auto typeC = cast<spirv::CooperativeMatrixType>(getC().getType());
+
+ // Check element types. ODS enforces that `type(c) == type(result)`, so no
+ // need to check it here.
+ if (!llvm::all_equal({typeA.getElementType(), typeB.getElementType()}))
+ return emitOpError("matrix A and matrix B element type mismatch");
+
+ // Check the 'use' part of the type against the operands and the result.
+ if (typeA.getUse() != CooperativeMatrixUseKHR::MatrixA)
+ return emitOpError("operand #0 must be of use 'MatrixA'");
+ if (typeB.getUse() != CooperativeMatrixUseKHR::MatrixB)
+ return emitOpError("operand #1 must be of use 'MatrixB'");
+ if (typeC.getUse() != CooperativeMatrixUseKHR::MatrixAcc)
+ return emitOpError("operand #2 must be of use 'MatrixAcc'");
+
+ // Check the 'scope' part of the type.
+ if (!llvm::all_equal({typeA.getScope(), typeB.getScope(), typeC.getScope()}))
+ return emitOpError("matrix scope mismatch");
+
+ // Check dimension sizes. We expect 'MxK * KxN + MxN -> MxN'.
+ if (typeA.getRows() != typeC.getRows())
+ return emitOpError("matrix size mismatch on dimension 'M'");
+ if (typeB.getColumns() != typeC.getColumns())
+ return emitOpError("matrix size mismatch on dimension 'N'");
+ if (typeA.getColumns() != typeB.getRows())
+ return emitOpError("matrix size mismatch on dimension 'K'");
+
+ // The spec does not restrict the element types:
+ // > A, B, C, and Result Type need not necessarily have the same component
+ // > type, this is defined by the client API.
+
+ // Check that if Cooperative Matrix Operands are provided, all matrix element
+ // types are integer types.
+ if (getMatrixOperands()) {
+ Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
+ typeC.getElementType()};
+ if (!llvm::all_of(elementTypes,
+ [](Type ty) { return isa<IntegerType>(ty); })) {
+ return emitOpError("Matrix Operands require all matrix element types to "
+ "be Integer Types");
+ }
+ }
+
+ // Any further requirements need to be checked against VCE.
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.NV.CooperativeMatrixLength
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index ce9a61a6277a04..4b77f1bbf5b622 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -146,6 +146,220 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, Stor
// -----
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+ %p = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <AccSat> :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ %q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <ASigned | BSigned> :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <ASigned | BSigned | AccSat> :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+spirv.func @cooperative_matrix_muladd_f32(%a : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>) "None" {
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<4x4xf32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+spirv.func @cooperative_matrix_muladd_i8_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+spirv.func @cooperative_matrix_muladd_workgroup(%a : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixA>,
+ %b : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixB>,
+ %c : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixAcc>) "None" {
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<4x4xf16, Workgroup, MatrixA>,
+ !spirv.coopmatrix<4x4xf16, Workgroup, MatrixB> ->
+ !spirv.coopmatrix<4x4xf16, Workgroup, MatrixAcc>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+ %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+ // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #0 must be of use 'MatrixA'}}
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+ !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+ %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>) "None" {
+ // expected-error @+1 {{expected ','}}
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+ !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+ %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>) "None" {
+ // expected-error @+1 {{expected SSA operand}}
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, <ASigned> :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+ !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+ // expected-error @+1 {{expected '<'}}
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, %c :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixA>,
+ %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+ // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #1 must be of use 'MatrixB'}}
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xi32, Subgroup, MatrixA> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixB>) "None" {
+ // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #2 must be of use 'MatrixAcc'}}
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixB>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<10x4xi32, Subgroup, MatrixAcc>) "None" {
+ // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'M'}}
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<10x4xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<4x16xi32, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+ // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'N'}}
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<4x16xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<12x4xi32, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+ // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'K'}}
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<12x4xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xi32, Workgroup, MatrixAcc>) "None" {
+ // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix scope mismatch}}
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Workgroup, MatrixAcc>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd_i8(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+ // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix A and matrix B element type mismatch}}
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x16xf16, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xf16, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xf16, Subgroup, MatrixAcc>) "None" {
+ // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op Matrix Operands require all matrix element types to be Integer Types}}
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <AccSat> :
+ !spirv.coopmatrix<8x16xf16, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xf16, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xf16, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// NV.CooperativeMatrix
//===----------------------------------------------------------------------===//
>From 8c4ae505ec6db16c9d1c9d798a2c561ee815201d Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Thu, 7 Sep 2023 11:09:29 -0400
Subject: [PATCH 2/2] Allow matrix A and matrix B element types to differ
---
.../Dialect/SPIRV/IR/CooperativeMatrixOps.cpp | 2 --
.../SPIRV/IR/cooperative-matrix-ops.mlir | 23 ++++++++-----------
2 files changed, 10 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index cee39d7791aaf8..bc1d30f5551830 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -165,8 +165,6 @@ LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
// Check element types. ODS enforces that `type(c) == type(result)`, so no
// need to check it here.
- if (!llvm::all_equal({typeA.getElementType(), typeB.getElementType()}))
- return emitOpError("matrix A and matrix B element type mismatch");
// Check the 'use' part of the type against the operands and the result.
if (typeA.getUse() != CooperativeMatrixUseKHR::MatrixA)
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 4b77f1bbf5b622..aa6e072b03c5d3 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -194,6 +194,16 @@ spirv.func @cooperative_matrix_muladd_i8_i32(%a : !spirv.coopmatrix<8x16xi8, Sub
spirv.Return
}
+spirv.func @cooperative_matrix_muladd_i8_i16_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xi16, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xi16, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
spirv.func @cooperative_matrix_muladd_workgroup(%a : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixA>,
%b : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixB>,
%c : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixAcc>) "None" {
@@ -334,19 +344,6 @@ spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup,
// -----
-spirv.func @cooperative_matrix_muladd_i8(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
- %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
- %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
- // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix A and matrix B element type mismatch}}
- %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
- !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
- !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
- !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
- spirv.Return
-}
-
-// -----
-
spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x16xf16, Subgroup, MatrixA>,
%b : !spirv.coopmatrix<16x4xf16, Subgroup, MatrixB>,
%c : !spirv.coopmatrix<8x4xf16, Subgroup, MatrixAcc>) "None" {
More information about the Mlir-commits
mailing list