[Mlir-commits] [mlir] [mlir][spirv] Add bfloat16 support (PR #141458)
llvmlistbot at llvm.org
Sun May 25 23:39:49 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Darren Wihandi (fairywreath)
Adds bf16 support to the SPIR-V dialect through the `SPV_KHR_bfloat16` extension. Only a few operations are supported: loading from and storing to memory, conversion to/from other types, cooperative matrix operations, and dot products.
Remaining TODO:
- Allow arithmetic ops on cooperative matrices with bf16 as the element type.
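For context, a minimal sketch of the IR this change enables, modeled on the new test cases in the diff below (function and value names are illustrative, not part of the PR):

```mlir
// bf16 conversions now accepted by the cast ops.
func.func @bf16_casts(%arg0 : bf16, %arg1 : i32) -> f32 {
  %0 = spirv.ConvertFToS %arg0 : bf16 to i32   // float to signed int
  %1 = spirv.ConvertSToF %arg1 : i32 to bf16   // signed int to float
  %2 = spirv.FConvert %arg0 : bf16 to f32      // float-width conversion
  spirv.ReturnValue %2 : f32
}

// Dot product on bf16 vectors.
func.func @bf16_dot(%a : vector<4xbf16>, %b : vector<4xbf16>) -> bf16 {
  %0 = spirv.Dot %a, %b : vector<4xbf16> -> bf16
  return %0 : bf16
}
```

On the serialization side, a bf16 type is emitted as an `OpTypeFloat` of width 16 with an additional `BFloat16KHR` FP encoding operand, as the `Serializer.cpp` hunk shows.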
---
Full diff: https://github.com/llvm/llvm-project/pull/141458.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td (+3-3)
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+35-5)
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td (+6-6)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+1-4)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+1-1)
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+3)
- (modified) mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir (-12)
- (modified) mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir (+8-1)
- (modified) mlir/test/Dialect/SPIRV/IR/cast-ops.mlir (+56)
- (modified) mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir (+29)
- (modified) mlir/test/Dialect/SPIRV/IR/types.mlir (-5)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 22d5afcd77381..daa1b2b328115 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -445,12 +445,12 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
}];
let arguments = (ins
- SPIRV_VectorOf<SPIRV_Float>:$vector1,
- SPIRV_VectorOf<SPIRV_Float>:$vector2
+ SPIRV_VectorOf<SPIRV_FloatOrBFloat16>:$vector1,
+ SPIRV_VectorOf<SPIRV_FloatOrBFloat16>:$vector2
);
let results = (outs
- SPIRV_Float:$result
+ SPIRV_FloatOrBFloat16:$result
);
let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 8fd533db83d9a..5d4469954e5b7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -344,6 +344,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup
def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
+def SPV_KHR_bfloat16 : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>;
def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
@@ -436,7 +437,7 @@ def SPIRV_ExtensionAttr :
SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
- SPV_KHR_cooperative_matrix,
+ SPV_KHR_cooperative_matrix, SPV_KHR_bfloat16,
SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
@@ -1412,6 +1413,23 @@ def SPIRV_C_ShaderStereoViewNV : I32EnumAttrCase<"Shade
Extension<[SPV_NV_stereo_view_rendering]>
];
}
+def SPIRV_C_BFloat16TypeKHR : I32EnumAttrCase<"BFloat16TypeKHR", 5116> {
+ list<Availability> availability = [
+ Extension<[SPV_KHR_bfloat16]>
+ ];
+}
+def SPIRV_C_BFloat16DotProductKHR : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR];
+ list<Availability> availability = [
+ Extension<[SPV_KHR_bfloat16]>
+ ];
+}
+def SPIRV_C_BFloat16CooperativeMatrixKHR : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR];
+ list<Availability> availability = [
+ Extension<[SPV_KHR_bfloat16]>
+ ];
+}
def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
list<Availability> availability = [
@@ -1518,7 +1536,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
- SPIRV_C_CacheControlsINTEL
+ SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
+ SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
]>;
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -3217,6 +3236,16 @@ def SPIRV_ExecutionModelAttr :
SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
]>;
+def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0> {
+ list<Availability> availability = [
+ Capability<[SPIRV_C_BFloat16TypeKHR]>
+ ];
+}
+def SPIRV_FPEncodingAttr :
+ SPIRV_I32EnumAttr<"FPEncoding", "valid SPIR-V FPEncoding", "f_p_encoding", [
+ SPIRV_FPE_BFloat16KHR
+ ]>;
+
def SPIRV_FC_None : I32BitEnumAttrCaseNone<"None">;
def SPIRV_FC_Inline : I32BitEnumAttrCaseBit<"Inline", 0>;
def SPIRV_FC_DontInline : I32BitEnumAttrCaseBit<"DontInline", 1>;
@@ -4163,8 +4192,9 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
+def SPIRV_FloatOrBFloat16 : AnyTypeOf<[SPIRV_Float, BF16]>;
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
- [SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
+ [SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
@@ -4194,9 +4224,9 @@ def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
def SPIRV_Type : AnyTypeOf<[
- SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
+ SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
- SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
]>;
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index b05ee0251df5b..29571cf138ebf 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -86,7 +86,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
// -----
-def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float, []> {
+def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
let summary = [{
Convert value numerically from floating point to signed integer, with
round toward 0.0.
@@ -111,7 +111,7 @@ def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float
// -----
-def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float, []> {
+def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
let summary = [{
Convert value numerically from floating point to unsigned integer, with
round toward 0.0.
@@ -138,7 +138,7 @@ def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float
// -----
def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
- SPIRV_Float,
+ SPIRV_FloatOrBFloat16,
SPIRV_Integer,
[SignedOp]> {
let summary = [{
@@ -165,7 +165,7 @@ def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
// -----
def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
- SPIRV_Float,
+ SPIRV_FloatOrBFloat16,
SPIRV_Integer,
[UnsignedOp]> {
let summary = [{
@@ -192,8 +192,8 @@ def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
// -----
def SPIRV_FConvertOp : SPIRV_CastOp<"FConvert",
- SPIRV_Float,
- SPIRV_Float,
+ SPIRV_FloatOrBFloat16,
+ SPIRV_FloatOrBFloat16,
[UsableInSpecConstantOp]> {
let summary = [{
Convert value numerically from one floating-point width to another
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 0cf5f0823be63..a21acef1c4b43 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -175,10 +175,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
// Check other allowed types
if (auto t = llvm::dyn_cast<FloatType>(type)) {
- if (type.isBF16()) {
- parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
- return Type();
- }
+ // TODO: All float types are allowed for now, but this should be fixed.
} else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
if (!ScalarType::isValid(t)) {
parser.emitError(typeLoc,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 337df3a5a65f0..5da3164ad4d14 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -505,7 +505,7 @@ bool ScalarType::classof(Type type) {
}
bool ScalarType::isValid(FloatType type) {
- return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
+ return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
}
bool ScalarType::isValid(IntegerType type) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 15e06616f4492..b43f22db55a2e 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -523,6 +523,9 @@ LogicalResult Serializer::prepareBasicType(
if (auto floatType = dyn_cast<FloatType>(type)) {
typeEnum = spirv::Opcode::OpTypeFloat;
operands.push_back(floatType.getWidth());
+ if (floatType.isBF16()) {
+ operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
+ }
return success();
}
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 82d750755ffe2..2e34c9ff54012 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -206,18 +206,6 @@ func.func @float64(%arg0: f64) { return }
// -----
-// Check that bf16 is not supported.
-module attributes {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
-} {
-
-// CHECK-NOT: spirv.func @bf16_type
-func.func @bf16_type(%arg0: bf16) { return }
-
-} // end module
-
-// -----
-
//===----------------------------------------------------------------------===//
// Complex types
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index 2d0c86e08de5a..301a5bab9ab1a 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -265,6 +265,13 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
// -----
+func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+ return %0 : bf16
+}
+
+// -----
+
// expected-note @+1 {{prior use here}}
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
// expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
@@ -283,7 +290,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
// -----
func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
- // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or bfloat16 type values of length 2/3/4/8/16}}
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
return %0 : i32
}
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 34d0109e6bb44..4480a1f3720f2 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -110,6 +110,14 @@ func.func @convert_f_to_s_vector(%arg0 : vector<3xf32>) -> vector<3xi32> {
// -----
+func.func @convert_bf16_to_s32_scalar(%arg0 : bf16) -> i32 {
+ // CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : bf16 to i32
+ %0 = spirv.ConvertFToS %arg0 : bf16 to i32
+ spirv.ReturnValue %0 : i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.ConvertFToU
//===----------------------------------------------------------------------===//
@@ -146,6 +154,14 @@ func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgrou
// -----
+func.func @convert_bf16_to_u32_scalar(%arg0 : bf16) -> i32 {
+ // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : bf16 to i32
+ %0 = spirv.ConvertFToU %arg0 : bf16 to i32
+ spirv.ReturnValue %0 : i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.ConvertSToF
//===----------------------------------------------------------------------===//
@@ -174,6 +190,14 @@ func.func @convert_s_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {
// -----
+func.func @convert_s32_to_bf16_scalar(%arg0 : i32) -> bf16 {
+ // CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i32 to bf16
+ %0 = spirv.ConvertSToF %arg0 : i32 to bf16
+ spirv.ReturnValue %0 : bf16
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.ConvertUToF
//===----------------------------------------------------------------------===//
@@ -202,6 +226,14 @@ func.func @convert_u_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {
// -----
+func.func @convert_u32_to_bf16_scalar(%arg0 : i32) -> bf16 {
+ // CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i32 to bf16
+ %0 = spirv.ConvertUToF %arg0 : i32 to bf16
+ spirv.ReturnValue %0 : bf16
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.FConvert
//===----------------------------------------------------------------------===//
@@ -238,6 +270,30 @@ func.func @f_convert_vector(%arg0 : f32) -> f32 {
// -----
+func.func @f_convert_bf16_to_f32_scalar(%arg0 : bf16) -> f32 {
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : bf16 to f32
+ %0 = spirv.FConvert %arg0 : bf16 to f32
+ spirv.ReturnValue %0 : f32
+}
+
+// -----
+
+func.func @f_convert_f32_to_bf16_vector(%arg0 : vector<3xf32>) -> vector<3xbf16> {
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : vector<3xf32> to vector<3xbf16>
+ %0 = spirv.FConvert %arg0 : vector<3xf32> to vector<3xbf16>
+ spirv.ReturnValue %0 : vector<3xbf16>
+}
+
+// -----
+
+func.func @f_convert_f32_to_bf16_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>) -> !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA> {
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+ %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+ spirv.ReturnValue %0 : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.SConvert
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index d3e1dbc229ef9..8929e63639c97 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -31,6 +31,15 @@ spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr<i32, StorageBuf
spirv.Return
}
+// CHECK-LABEL: @cooperative_matrix_load_bf16
+spirv.func @cooperative_matrix_load_bf16(%ptr : !spirv.ptr<bf16, StorageBuffer>, %stride : i32) "None" {
+ // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>
+ // CHECK-SAME: : !spirv.ptr<bf16, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xbf16, Workgroup, MatrixA>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
+ !spirv.ptr<bf16, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xbf16, Workgroup, MatrixA>
+ spirv.Return
+}
+
// CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32) "None" {
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Volatile> :
@@ -225,6 +234,26 @@ spirv.func @cooperative_matrix_muladd_f32(%a : !spirv.coopmatrix<4x4xf32, Subgro
spirv.Return
}
+spirv.func @cooperative_matrix_muladd_bf16_bf16(%a : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xbf16, Subgroup, MatrixAcc>) "None" {
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xbf16, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+spirv.func @cooperative_matrix_muladd_bf16_f32(%a : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xf32, Subgroup, MatrixAcc>) "None" {
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xf32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
spirv.func @cooperative_matrix_muladd_i8_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
%b : !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB>,
%c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index b63a08d96e6af..a81fe72a8362e 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -57,11 +57,6 @@ func.func private @tensor_type(!spirv.array<4xtensor<4xf32>>) -> ()
// -----
-// expected-error @+1 {{cannot use 'bf16' to compose SPIR-V types}}
-func.func private @bf16_type(!spirv.array<4xbf16>) -> ()
-
-// -----
-
// expected-error @+1 {{only 1/8/16/32/64-bit integer type allowed but found 'i256'}}
func.func private @i256_type(!spirv.array<4xi256>) -> ()
``````````
https://github.com/llvm/llvm-project/pull/141458