[Mlir-commits] [mlir] [mlir][spirv] Add bfloat16 support (PR #141458)

Darren Wihandi llvmlistbot at llvm.org
Thu Jun 5 14:20:02 PDT 2025


https://github.com/fairywreath updated https://github.com/llvm/llvm-project/pull/141458

>From 45349e623c10722a7576ff81d1dc889cd13889d1 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Mon, 26 May 2025 02:33:38 -0400
Subject: [PATCH 01/13] [mlir][spirv] Add bfloat16 support

---
 .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td    |  6 +-
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        | 40 +++++++++++--
 .../mlir/Dialect/SPIRV/IR/SPIRVCastOps.td     | 12 ++--
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |  5 +-
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      |  2 +-
 .../Target/SPIRV/Serialization/Serializer.cpp |  3 +
 .../FuncToSPIRV/types-to-spirv.mlir           | 12 ----
 .../test/Dialect/SPIRV/IR/arithmetic-ops.mlir |  9 ++-
 mlir/test/Dialect/SPIRV/IR/cast-ops.mlir      | 56 +++++++++++++++++++
 .../SPIRV/IR/khr-cooperative-matrix-ops.mlir  | 29 ++++++++++
 mlir/test/Dialect/SPIRV/IR/types.mlir         |  5 --
 11 files changed, 142 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 22d5afcd77381..daa1b2b328115 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -445,12 +445,12 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
   }];
 
   let arguments = (ins
-    SPIRV_VectorOf<SPIRV_Float>:$vector1,
-    SPIRV_VectorOf<SPIRV_Float>:$vector2
+    SPIRV_VectorOf<SPIRV_FloatOrBFloat16>:$vector1,
+    SPIRV_VectorOf<SPIRV_FloatOrBFloat16>:$vector2
   );
 
   let results = (outs
-    SPIRV_Float:$result
+    SPIRV_FloatOrBFloat16:$result
   );
 
   let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 8fd533db83d9a..5d4469954e5b7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -344,6 +344,7 @@ def SPV_KHR_subgroup_rotate                  : I32EnumAttrCase<"SPV_KHR_subgroup
 def SPV_KHR_non_semantic_info                : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
 def SPV_KHR_terminate_invocation             : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
 def SPV_KHR_cooperative_matrix               : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
+def SPV_KHR_bfloat16                         : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>;
 
 def SPV_EXT_demote_to_helper_invocation  : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
 def SPV_EXT_descriptor_indexing          : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
@@ -436,7 +437,7 @@ def SPIRV_ExtensionAttr :
       SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
       SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
       SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
-      SPV_KHR_cooperative_matrix,
+      SPV_KHR_cooperative_matrix, SPV_KHR_bfloat16,
       SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
       SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
       SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
@@ -1412,6 +1413,23 @@ def SPIRV_C_ShaderStereoViewNV                          : I32EnumAttrCase<"Shade
     Extension<[SPV_NV_stereo_view_rendering]>
   ];
 }
+def SPIRV_C_BFloat16TypeKHR                             : I32EnumAttrCase<"BFloat16TypeKHR", 5116> {
+  list<Availability> availability = [
+    Extension<[SPV_KHR_bfloat16]>
+  ];
+}
+def SPIRV_C_BFloat16DotProductKHR                       : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR];
+  list<Availability> availability = [
+    Extension<[SPV_KHR_bfloat16]>
+  ];
+}
+def SPIRV_C_BFloat16CooperativeMatrixKHR                : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR];
+  list<Availability> availability = [
+    Extension<[SPV_KHR_bfloat16]>
+  ];
+}
 
 def SPIRV_C_Bfloat16ConversionINTEL                         : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
   list<Availability> availability = [
@@ -1518,7 +1536,8 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
       SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
       SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
-      SPIRV_C_CacheControlsINTEL
+      SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
+      SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
     ]>;
 
 def SPIRV_AM_Logical                 : I32EnumAttrCase<"Logical", 0>;
@@ -3217,6 +3236,16 @@ def SPIRV_ExecutionModelAttr :
       SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
     ]>;
 
+def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_BFloat16TypeKHR]>
+  ];
+}
+def SPIRV_FPEncodingAttr :
+    SPIRV_I32EnumAttr<"FPEncoding", "valid SPIR-V FPEncoding", "f_p_encoding", [
+      SPIRV_FPE_BFloat16KHR
+    ]>;
+
 def SPIRV_FC_None         : I32BitEnumAttrCaseNone<"None">;
 def SPIRV_FC_Inline       : I32BitEnumAttrCaseBit<"Inline", 0>;
 def SPIRV_FC_DontInline   : I32BitEnumAttrCaseBit<"DontInline", 1>;
@@ -4163,8 +4192,9 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
 def SPIRV_Float32 : TypeAlias<F32, "Float32">;
 def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
 def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
+def SPIRV_FloatOrBFloat16 : AnyTypeOf<[SPIRV_Float, BF16]>;
 def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
-                                       [SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
+                                       [SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16]>;
 // Component type check is done in the type parser for the following SPIR-V
 // dialect-specific types so we use "Any" here.
 def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
@@ -4194,9 +4224,9 @@ def SPIRV_Composite :
     AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
                SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
 def SPIRV_Type : AnyTypeOf<[
-    SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
+    SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
-    SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
+    SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
   ]>;
 
 def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index b05ee0251df5b..29571cf138ebf 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -86,7 +86,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
 
 // -----
 
-def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float, []> {
+def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
   let summary = [{
     Convert value numerically from floating point to signed integer, with
     round toward 0.0.
@@ -111,7 +111,7 @@ def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float
 
 // -----
 
-def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float, []> {
+def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
   let summary = [{
     Convert value numerically from floating point to unsigned integer, with
     round toward 0.0.
@@ -138,7 +138,7 @@ def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float
 // -----
 
 def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
-                                   SPIRV_Float,
+                                   SPIRV_FloatOrBFloat16,
                                    SPIRV_Integer,
                                    [SignedOp]> {
   let summary = [{
@@ -165,7 +165,7 @@ def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
 // -----
 
 def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
-                                   SPIRV_Float,
+                                   SPIRV_FloatOrBFloat16,
                                    SPIRV_Integer,
                                    [UnsignedOp]> {
   let summary = [{
@@ -192,8 +192,8 @@ def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
 // -----
 
 def SPIRV_FConvertOp : SPIRV_CastOp<"FConvert",
-                                SPIRV_Float,
-                                SPIRV_Float,
+                                SPIRV_FloatOrBFloat16,
+                                SPIRV_FloatOrBFloat16,
                                 [UsableInSpecConstantOp]> {
   let summary = [{
     Convert value numerically from one floating-point width to another
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 0cf5f0823be63..a21acef1c4b43 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -175,10 +175,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
 
   // Check other allowed types
   if (auto t = llvm::dyn_cast<FloatType>(type)) {
-    if (type.isBF16()) {
-      parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
-      return Type();
-    }
+    // TODO: Every float type is accepted for now; reject unsupported ones.
   } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
     if (!ScalarType::isValid(t)) {
       parser.emitError(typeLoc,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 337df3a5a65f0..5da3164ad4d14 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -505,7 +505,7 @@ bool ScalarType::classof(Type type) {
 }
 
 bool ScalarType::isValid(FloatType type) {
-  return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
+  return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
 }
 
 bool ScalarType::isValid(IntegerType type) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 15e06616f4492..b43f22db55a2e 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -523,6 +523,9 @@ LogicalResult Serializer::prepareBasicType(
   if (auto floatType = dyn_cast<FloatType>(type)) {
     typeEnum = spirv::Opcode::OpTypeFloat;
     operands.push_back(floatType.getWidth());
+    if (floatType.isBF16()) {
+      operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
+    }
     return success();
   }
 
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 82d750755ffe2..2e34c9ff54012 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -206,18 +206,6 @@ func.func @float64(%arg0: f64) { return }
 
 // -----
 
-// Check that bf16 is not supported.
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
-} {
-
-// CHECK-NOT: spirv.func @bf16_type
-func.func @bf16_type(%arg0: bf16) { return }
-
-} // end module
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // Complex types
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index 2d0c86e08de5a..301a5bab9ab1a 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -265,6 +265,13 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
 
 // -----
 
+func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
+  %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+  return %0 : bf16
+}
+
+// -----
+
 // expected-note @+1 {{prior use here}}
 func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
   // expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
@@ -283,7 +290,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
 // -----
 
 func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
-  // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or bfloat16 type values of length 2/3/4/8/16}}
   %0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
   return %0 : i32
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 34d0109e6bb44..4480a1f3720f2 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -110,6 +110,14 @@ func.func @convert_f_to_s_vector(%arg0 : vector<3xf32>) -> vector<3xi32> {
 
 // -----
 
+func.func @convert_bf16_to_s32_scalar(%arg0 : bf16) -> i32 {
+  // CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : bf16 to i32
+  %0 = spirv.ConvertFToS %arg0 : bf16 to i32
+  spirv.ReturnValue %0 : i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.ConvertFToU
 //===----------------------------------------------------------------------===//
@@ -146,6 +154,14 @@ func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgrou
 
 // -----
 
+func.func @convert_bf16_to_u32_scalar(%arg0 : bf16) -> i32 {
+  // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : bf16 to i32
+  %0 = spirv.ConvertFToU %arg0 : bf16 to i32
+  spirv.ReturnValue %0 : i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.ConvertSToF
 //===----------------------------------------------------------------------===//
@@ -174,6 +190,14 @@ func.func @convert_s_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {
 
 // -----
 
+func.func @convert_s32_to_bf16_scalar(%arg0 : i32) -> bf16 {
+  // CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i32 to bf16
+  %0 = spirv.ConvertSToF %arg0 : i32 to bf16
+  spirv.ReturnValue %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.ConvertUToF
 //===----------------------------------------------------------------------===//
@@ -202,6 +226,14 @@ func.func @convert_u_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {
 
 // -----
 
+func.func @convert_u32_to_bf16_scalar(%arg0 : i32) -> bf16 {
+  // CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i32 to bf16
+  %0 = spirv.ConvertUToF %arg0 : i32 to bf16
+  spirv.ReturnValue %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FConvert
 //===----------------------------------------------------------------------===//
@@ -238,6 +270,30 @@ func.func @f_convert_vector(%arg0 : f32) -> f32 {
 
 // -----
 
+func.func @f_convert_bf16_to_f32_scalar(%arg0 : bf16) -> f32 {
+  // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : bf16 to f32
+  %0 = spirv.FConvert %arg0 : bf16 to f32
+  spirv.ReturnValue %0 : f32
+}
+
+// -----
+
+func.func @f_convert_f32_to_bf16_vector(%arg0 : vector<3xf32>) -> vector<3xbf16> {
+  // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : vector<3xf32> to vector<3xbf16>
+  %0 = spirv.FConvert %arg0 : vector<3xf32> to vector<3xbf16>
+  spirv.ReturnValue %0 : vector<3xbf16>
+}
+
+// -----
+
+func.func @f_convert_f32_to_bf16_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>) -> !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA> {
+  // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+  %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+  spirv.ReturnValue %0 : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.SConvert
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index d3e1dbc229ef9..8929e63639c97 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -31,6 +31,15 @@ spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr<i32, StorageBuf
   spirv.Return
 }
 
+// CHECK-LABEL: @cooperative_matrix_load_bf16
+spirv.func @cooperative_matrix_load_bf16(%ptr : !spirv.ptr<bf16, StorageBuffer>, %stride : i32) "None" {
+  // CHECK:      {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>
+  // CHECK-SAME:   : !spirv.ptr<bf16, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xbf16, Workgroup, MatrixA>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
+    !spirv.ptr<bf16, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xbf16, Workgroup, MatrixA>
+  spirv.Return
+}
+
 // CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
 spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32) "None" {
   // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Volatile> :
@@ -225,6 +234,26 @@ spirv.func @cooperative_matrix_muladd_f32(%a : !spirv.coopmatrix<4x4xf32, Subgro
   spirv.Return
 }
 
+spirv.func @cooperative_matrix_muladd_bf16_bf16(%a : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
+                                             %b : !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB>,
+                                             %c : !spirv.coopmatrix<8x4xbf16, Subgroup, MatrixAcc>) "None" {
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+        !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xbf16, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
+spirv.func @cooperative_matrix_muladd_bf16_f32(%a : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
+                                             %b : !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB>,
+                                             %c : !spirv.coopmatrix<8x4xf32, Subgroup, MatrixAcc>) "None" {
+  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+        !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
+        !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB> ->
+          !spirv.coopmatrix<8x4xf32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
 spirv.func @cooperative_matrix_muladd_i8_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
                                              %b : !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB>,
                                              %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index b63a08d96e6af..a81fe72a8362e 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -57,11 +57,6 @@ func.func private @tensor_type(!spirv.array<4xtensor<4xf32>>) -> ()
 
 // -----
 
-// expected-error @+1 {{cannot use 'bf16' to compose SPIR-V types}}
-func.func private @bf16_type(!spirv.array<4xbf16>) -> ()
-
-// -----
-
 // expected-error @+1 {{only 1/8/16/32/64-bit integer type allowed but found 'i256'}}
 func.func private @i256_type(!spirv.array<4xi256>) -> ()
 

>From d9815d4addcfa9ed4c5efeb6c572d0b6c1576863 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Mon, 26 May 2025 11:35:10 -0600
Subject: [PATCH 02/13] Properly implement arithmetic coop matrix ops with bf16

---
 .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td    | 96 ++++++++++++++-----
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        | 12 ++-
 .../SPIRV/IR/khr-cooperative-matrix-ops.mlir  | 68 +++++++++++--
 3 files changed, 138 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index daa1b2b328115..850e0d165f4cf 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -24,8 +24,8 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
       SPIRV_BinaryOp<mnemonic, type, type,
                    !listconcat(traits,
                                [Pure, SameOperandsAndResultType])> {
-  // In addition to normal types arithmetic instructions can support cooperative
-  // matrix.
+  // TODO: Ops using this definition do not support cooperative matrices, yet
+  // the constraints below still accept them; this needs to be fixed.
   let arguments = (ins
     SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
     SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
@@ -37,20 +37,43 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
   let assemblyFormat = "operands attr-dict `:` type($result)";
 }
 
-class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
+class SPIRV_ArithmeticWithCoopMatrixBinaryOp<string mnemonic,
+                                             Type scalarVectorType,
+                                             Type coopMatrixType,
+                                             list<Trait> traits = []> :
+      // Operand types are the same as the result type.
+      SPIRV_BinaryOp<mnemonic, coopMatrixType, coopMatrixType,
+                   !listconcat(traits,
+                                [Pure, SameOperandsAndResultType])> {
+  // In addition to normal types these arithmetic instructions can support
+  // cooperative matrix.
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$operand1,
+    SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$operand2
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$result
+  );
+  let assemblyFormat = "operands attr-dict `:` type($result)";
+}
+
+class SPIRV_ArithmeticUnaryOp<string mnemonic,
+                              Type scalarVectorType,
+                              Type coopMatrixType,
                               list<Trait> traits = []> :
       // Operand type same as result type.
-      SPIRV_UnaryOp<mnemonic, type, type,
+      SPIRV_UnaryOp<mnemonic, coopMatrixType, coopMatrixType,
                    !listconcat(traits,
                                [Pure, SameOperandsAndResultType])> {
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
   let arguments = (ins
-    SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand
+    SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$operand1
   );
 
   let results = (outs
-    SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$result
+    SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$result
   );
   let assemblyFormat = "operands attr-dict `:` type($result)";
 }
@@ -82,7 +105,10 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
 
 // -----
 
-def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> {
+def SPIRV_FAddOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FAdd",
+                                                      SPIRV_Float,
+                                                      SPIRV_FloatOrBFloat16,
+                                                      [Commutative]> {
   let summary = "Floating-point addition of Operand 1 and Operand 2.";
 
   let description = [{
@@ -104,7 +130,10 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]>
 
 // -----
 
-def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> {
+def SPIRV_FDivOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FDiv",
+                                                      SPIRV_Float,
+                                                      SPIRV_FloatOrBFloat16,
+                                                      []> {
   let summary = "Floating-point division of Operand 1 divided by Operand 2.";
 
   let description = [{
@@ -154,7 +183,10 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> {
 
 // -----
 
-def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> {
+def SPIRV_FMulOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FMul",
+                                                      SPIRV_Float,
+                                                      SPIRV_FloatOrBFloat16,
+                                                      [Commutative]> {
   let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
 
   let description = [{
@@ -176,7 +208,10 @@ def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]>
 
 // -----
 
-def SPIRV_FNegateOp : SPIRV_ArithmeticUnaryOp<"FNegate", SPIRV_Float, []> {
+def SPIRV_FNegateOp : SPIRV_ArithmeticUnaryOp<"FNegate",
+                                          SPIRV_Float,
+                                          SPIRV_FloatOrBFloat16,
+                                          []> {
   let summary = [{
     Inverts the sign bit of Operand. (Note, however, that OpFNegate is still
     considered a floating-point instruction, and so is subject to the
@@ -229,7 +264,10 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> {
 
 // -----
 
-def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
+def SPIRV_FSubOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FSub",
+                                                      SPIRV_Float,
+                                                      SPIRV_FloatOrBFloat16,
+                                                      []> {
   let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
 
   let description = [{
@@ -251,9 +289,10 @@ def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
 
 // -----
 
-def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp<"IAdd",
-                                        SPIRV_Integer,
-                                        [Commutative, UsableInSpecConstantOp]> {
+def SPIRV_IAddOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"IAdd",
+                                                      SPIRV_Integer,
+                                                      SPIRV_Integer,
+                                                      [Commutative, UsableInSpecConstantOp]> {
   let summary = "Integer addition of Operand 1 and Operand 2.";
 
   let description = [{
@@ -322,9 +361,10 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
 
 // -----
 
-def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
-                                        SPIRV_Integer,
-                                        [Commutative, UsableInSpecConstantOp]> {
+def SPIRV_IMulOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"IMul",
+                                                      SPIRV_Integer,
+                                                      SPIRV_Integer,
+                                                      [Commutative, UsableInSpecConstantOp]> {
   let summary = "Integer multiplication of Operand 1 and Operand 2.";
 
   let description = [{
@@ -354,9 +394,10 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
 
 // -----
 
-def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
-                                        SPIRV_Integer,
-                                        [UsableInSpecConstantOp]> {
+def SPIRV_ISubOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"ISub",
+                                                      SPIRV_Integer,
+                                                      SPIRV_Integer,
+                                                      [UsableInSpecConstantOp]> {
   let summary = "Integer subtraction of Operand 2 from Operand 1.";
 
   let description = [{
@@ -460,9 +501,10 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
 
 // -----
 
-def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
-                                        SPIRV_Integer,
-                                        [UsableInSpecConstantOp]> {
+def SPIRV_SDivOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"SDiv",
+                                                      SPIRV_Integer,
+                                                      SPIRV_Integer,
+                                                      [UsableInSpecConstantOp]> {
   let summary = "Signed-integer division of Operand 1 divided by Operand 2.";
 
   let description = [{
@@ -560,6 +602,7 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
 // -----
 
 def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
+                                          SPIRV_Integer,
                                           SPIRV_Integer,
                                           [UsableInSpecConstantOp]> {
   let summary = "Signed-integer subtract of Operand from zero.";
@@ -622,9 +665,10 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
 
 // -----
 
-def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
-                                        SPIRV_Integer,
-                                        [UnsignedOp, UsableInSpecConstantOp]> {
+def SPIRV_UDivOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"UDiv",
+                                                      SPIRV_Integer,
+                                                      SPIRV_Integer,
+                                                      [UnsignedOp, UsableInSpecConstantOp]> {
   let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
 
   let description = [{
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 5d4469954e5b7..ccc6bd76ca1d0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4245,16 +4245,24 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
 class SPIRV_VectorOf<Type type> :
     VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
 
+class SPIRV_CoopMatrixOf<Type type> :
+    SPIRV_CoopMatrixOfType<[type]>;
+
 class SPIRV_ScalarOrVectorOf<Type type> :
     AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
 
 class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
     AnyTypeOf<[type, SPIRV_VectorOf<type>,
-               SPIRV_CoopMatrixOfType<[type]>]>;
+               SPIRV_CoopMatrixOf<type>]>;
+
+class SPIRV_ScalarOrVectorOfOrCoopMatrixOf<Type scalarVectorType,
+                                           Type coopMatrixType> :
+    AnyTypeOf<[scalarVectorType, SPIRV_VectorOf<scalarVectorType>,
+               SPIRV_CoopMatrixOf<coopMatrixType>]>;
 
 class SPIRV_MatrixOrCoopMatrixOf<Type type> :
     AnyTypeOf<[SPIRV_AnyMatrix,
-               SPIRV_CoopMatrixOfType<[type]>]>;
+               SPIRV_CoopMatrixOf<type>]>;
 
 class SPIRV_MatrixOf<Type type> :
     SPIRV_MatrixOfType<[type]>;
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 8929e63639c97..4d161a3193505 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -437,6 +437,9 @@ spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x1
 !matA_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
 !matB_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixB>
 
+!matA_bf16 = !spirv.coopmatrix<2x2xbf16, Subgroup, MatrixA>
+!matB_bf16 = !spirv.coopmatrix<2x2xbf16, Subgroup, MatrixB>
+
 // These tests are kept in the same order as the list of compatible ops in the
 // SPV_KHR_cooperative_matrix extension spec.
 
@@ -449,8 +452,8 @@ spirv.func @snegate(%a: !matA_i32, %b: !matB_i32) "None" {
   spirv.Return
 }
 
-// CHECK-LABEL: @fnegate
-spirv.func @fnegate(%a: !matA_f32, %b: !matB_f32) "None" {
+// CHECK-LABEL: @fnegate_f32
+spirv.func @fnegate_f32(%a: !matA_f32, %b: !matB_f32) "None" {
   // CHECK:      spirv.FNegate {{%.*}} : !spirv.coopmatrix
   // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix
   %p = spirv.FNegate %a : !matA_f32
@@ -458,6 +461,15 @@ spirv.func @fnegate(%a: !matA_f32, %b: !matB_f32) "None" {
   spirv.Return
 }
 
+// CHECK-LABEL: @fnegate_bf16
+spirv.func @fnegate_bf16(%a: !matA_bf16, %b: !matB_bf16) "None" {
+  // CHECK:      spirv.FNegate {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FNegate %a : !matA_bf16
+  %q = spirv.FNegate %b : !matB_bf16
+  spirv.Return
+}
+
 // CHECK-LABEL: @iadd
 spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" {
   // CHECK:      spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
@@ -467,8 +479,8 @@ spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" {
   spirv.Return
 }
 
-// CHECK-LABEL: @fadd
-spirv.func @fadd(%a: !matA_f32, %b: !matB_f32) "None" {
+// CHECK-LABEL: @fadd_f32
+spirv.func @fadd_f32(%a: !matA_f32, %b: !matB_f32) "None" {
   // CHECK:      spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
   // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
   %p = spirv.FAdd %a, %a : !matA_f32
@@ -476,6 +488,15 @@ spirv.func @fadd(%a: !matA_f32, %b: !matB_f32) "None" {
   spirv.Return
 }
 
+// CHECK-LABEL: @fadd_bf16
+spirv.func @fadd_bf16(%a: !matA_bf16, %b: !matB_bf16) "None" {
+  // CHECK:      spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FAdd %a, %a : !matA_bf16
+  %q = spirv.FAdd %b, %b : !matB_bf16
+  spirv.Return
+}
+
 // CHECK-LABEL: @isub
 spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" {
   // CHECK:      spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
@@ -485,8 +506,8 @@ spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" {
   spirv.Return
 }
 
-// CHECK-LABEL: @fsub
-spirv.func @fsub(%a: !matA_f32, %b: !matB_f32) "None" {
+// CHECK-LABEL: @fsub_f32
+spirv.func @fsub_f32(%a: !matA_f32, %b: !matB_f32) "None" {
   // CHECK:      spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
   // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
   %p = spirv.FSub %a, %a : !matA_f32
@@ -494,8 +515,17 @@ spirv.func @fsub(%a: !matA_f32, %b: !matB_f32) "None" {
   spirv.Return
 }
 
-// CHECK-LABEL: @fmul
-spirv.func @fmul(%a: !matA_f32, %b: !matB_f32) "None" {
+// CHECK-LABEL: @fsub_bf16
+spirv.func @fsub_bf16(%a: !matA_bf16, %b: !matB_bf16) "None" {
+  // CHECK:      spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FSub %a, %a : !matA_bf16
+  %q = spirv.FSub %b, %b : !matB_bf16
+  spirv.Return
+}
+
+// CHECK-LABEL: @fmul_f32
+spirv.func @fmul_f32(%a: !matA_f32, %b: !matB_f32) "None" {
   // CHECK:      spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
   // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
   %p = spirv.FMul %a, %a : !matA_f32
@@ -503,6 +533,15 @@ spirv.func @fmul(%a: !matA_f32, %b: !matB_f32) "None" {
   spirv.Return
 }
 
+// CHECK-LABEL: @fmul_bf16
+spirv.func @fmul_bf16(%a: !matA_bf16, %b: !matB_bf16) "None" {
+  // CHECK:      spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FMul %a, %a : !matA_bf16
+  %q = spirv.FMul %b, %b : !matB_bf16
+  spirv.Return
+}
+
 // CHECK-LABEL: @imul
 spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" {
   // CHECK:      spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
@@ -512,8 +551,8 @@ spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" {
   spirv.Return
 }
 
-// CHECK-LABEL: @fdiv
-spirv.func @fdiv(%a: !matA_f32, %b: !matB_f32) "None" {
+// CHECK-LABEL: @fdiv_f32
+spirv.func @fdiv_f32(%a: !matA_f32, %b: !matB_f32) "None" {
   // CHECK:      spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
   // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
   %p = spirv.FDiv %a, %a : !matA_f32
@@ -521,6 +560,15 @@ spirv.func @fdiv(%a: !matA_f32, %b: !matB_f32) "None" {
   spirv.Return
 }
 
+// CHECK-LABEL: @fdiv_bf16
+spirv.func @fdiv_bf16(%a: !matA_bf16, %b: !matB_bf16) "None" {
+  // CHECK:      spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FDiv %a, %a : !matA_bf16
+  %q = spirv.FDiv %b, %b : !matB_bf16
+  spirv.Return
+}
+
 // CHECK-LABEL: @sdiv
 spirv.func @sdiv(%a: !matA_i32, %b: !matB_i32) "None" {
   // CHECK:      spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix

>From 80565f0004adc673c1b75b7dc496439a561c5ec4 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Wed, 28 May 2025 23:03:14 -0400
Subject: [PATCH 03/13] Address review comments

---
 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index ccc6bd76ca1d0..bbb0d65b70853 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4226,7 +4226,7 @@ def SPIRV_Composite :
 def SPIRV_Type : AnyTypeOf<[
     SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
-    SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
+    SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
   ]>;
 
 def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4245,24 +4245,21 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
 class SPIRV_VectorOf<Type type> :
     VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
 
-class SPIRV_CoopMatrixOf<Type type> :
-    SPIRV_CoopMatrixOfType<[type]>;
-
 class SPIRV_ScalarOrVectorOf<Type type> :
     AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
 
 class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
     AnyTypeOf<[type, SPIRV_VectorOf<type>,
-               SPIRV_CoopMatrixOf<type>]>;
+               SPIRV_CoopMatrixOfType<[type]>]>;
 
 class SPIRV_ScalarOrVectorOfOrCoopMatrixOf<Type scalarVectorType,
                                            Type coopMatrixType> :
     AnyTypeOf<[scalarVectorType, SPIRV_VectorOf<scalarVectorType>,
-               SPIRV_CoopMatrixOf<coopMatrixType>]>;
+               SPIRV_CoopMatrixOfType<[coopMatrixType]>]>;
 
 class SPIRV_MatrixOrCoopMatrixOf<Type type> :
     AnyTypeOf<[SPIRV_AnyMatrix,
-               SPIRV_CoopMatrixOf<type>]>;
+               SPIRV_CoopMatrixOfType<[type]>]>;
 
 class SPIRV_MatrixOf<Type type> :
     SPIRV_MatrixOfType<[type]>;

>From 5d2984b202e6a778c2c0d2601d0009154d1f8cb2 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Wed, 28 May 2025 23:07:56 -0400
Subject: [PATCH 04/13] Address review and scope down changes

---
 .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td    | 102 +++++-------------
 .../test/Dialect/SPIRV/IR/arithmetic-ops.mlir |   9 +-
 .../SPIRV/IR/khr-cooperative-matrix-ops.mlir  |  97 ++---------------
 3 files changed, 40 insertions(+), 168 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 850e0d165f4cf..22d5afcd77381 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -24,8 +24,8 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
       SPIRV_BinaryOp<mnemonic, type, type,
                    !listconcat(traits,
                                [Pure, SameOperandsAndResultType])> {
-  // TODO: Arithmetic operations that use this definition do not support cooperative matrices,
-  // these need to be fixed.
+  // In addition to normal types arithmetic instructions can support cooperative
+  // matrix.
   let arguments = (ins
     SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
     SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
@@ -37,43 +37,20 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
   let assemblyFormat = "operands attr-dict `:` type($result)";
 }
 
-class SPIRV_ArithmeticWithCoopMatrixBinaryOp<string mnemonic,
-                                             Type scalarVectorType,
-                                             Type coopMatrixType,
-                                             list<Trait> traits = []> :
-      // Operands type same as result type.
-      SPIRV_BinaryOp<mnemonic, coopMatrixType, coopMatrixType,
-                   !listconcat(traits,
-                                [Pure, SameOperandsAndResultType])> {
-  // In addition to normal types these arithmetic instructions can support
-  // cooperative matrix.
-  let arguments = (ins
-    SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$operand1,
-    SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$operand2
-  );
-
-  let results = (outs
-    SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$result
-  );
-  let assemblyFormat = "operands attr-dict `:` type($result)";
-}
-
-class SPIRV_ArithmeticUnaryOp<string mnemonic,
-                              Type scalarVectorType,
-                              Type coopMatrixType,
+class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
                               list<Trait> traits = []> :
       // Operand type same as result type.
-      SPIRV_UnaryOp<mnemonic, coopMatrixType, coopMatrixType,
+      SPIRV_UnaryOp<mnemonic, type, type,
                    !listconcat(traits,
                                [Pure, SameOperandsAndResultType])> {
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
   let arguments = (ins
-    SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$operand1
+    SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand
   );
 
   let results = (outs
-    SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$result
+    SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$result
   );
   let assemblyFormat = "operands attr-dict `:` type($result)";
 }
@@ -105,10 +82,7 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
 
 // -----
 
-def SPIRV_FAddOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FAdd",
-                                                      SPIRV_Float,
-                                                      SPIRV_FloatOrBFloat16,
-                                                      [Commutative]> {
+def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> {
   let summary = "Floating-point addition of Operand 1 and Operand 2.";
 
   let description = [{
@@ -130,10 +104,7 @@ def SPIRV_FAddOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FAdd",
 
 // -----
 
-def SPIRV_FDivOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FDiv",
-                                                      SPIRV_Float,
-                                                      SPIRV_FloatOrBFloat16,
-                                                      []> {
+def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> {
   let summary = "Floating-point division of Operand 1 divided by Operand 2.";
 
   let description = [{
@@ -183,10 +154,7 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> {
 
 // -----
 
-def SPIRV_FMulOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FMul",
-                                                      SPIRV_Float,
-                                                      SPIRV_FloatOrBFloat16,
-                                                      [Commutative]> {
+def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> {
   let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
 
   let description = [{
@@ -208,10 +176,7 @@ def SPIRV_FMulOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FMul",
 
 // -----
 
-def SPIRV_FNegateOp : SPIRV_ArithmeticUnaryOp<"FNegate",
-                                          SPIRV_Float,
-                                          SPIRV_FloatOrBFloat16,
-                                          []> {
+def SPIRV_FNegateOp : SPIRV_ArithmeticUnaryOp<"FNegate", SPIRV_Float, []> {
   let summary = [{
     Inverts the sign bit of Operand. (Note, however, that OpFNegate is still
     considered a floating-point instruction, and so is subject to the
@@ -264,10 +229,7 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> {
 
 // -----
 
-def SPIRV_FSubOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FSub",
-                                                      SPIRV_Float,
-                                                      SPIRV_FloatOrBFloat16,
-                                                      []> {
+def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
   let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
 
   let description = [{
@@ -289,10 +251,9 @@ def SPIRV_FSubOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FSub",
 
 // -----
 
-def SPIRV_IAddOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"IAdd",
-                                                      SPIRV_Integer,
-                                                      SPIRV_Integer,
-                                                      [Commutative, UsableInSpecConstantOp]> {
+def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp<"IAdd",
+                                        SPIRV_Integer,
+                                        [Commutative, UsableInSpecConstantOp]> {
   let summary = "Integer addition of Operand 1 and Operand 2.";
 
   let description = [{
@@ -361,10 +322,9 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
 
 // -----
 
-def SPIRV_IMulOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"IMul",
-                                                      SPIRV_Integer,
-                                                      SPIRV_Integer,
-                                                      [Commutative, UsableInSpecConstantOp]> {
+def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
+                                        SPIRV_Integer,
+                                        [Commutative, UsableInSpecConstantOp]> {
   let summary = "Integer multiplication of Operand 1 and Operand 2.";
 
   let description = [{
@@ -394,10 +354,9 @@ def SPIRV_IMulOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"IMul",
 
 // -----
 
-def SPIRV_ISubOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"ISub",
-                                                      SPIRV_Integer,
-                                                      SPIRV_Integer,
-                                                      [UsableInSpecConstantOp]> {
+def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
+                                        SPIRV_Integer,
+                                        [UsableInSpecConstantOp]> {
   let summary = "Integer subtraction of Operand 2 from Operand 1.";
 
   let description = [{
@@ -486,12 +445,12 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
   }];
 
   let arguments = (ins
-    SPIRV_VectorOf<SPIRV_FloatOrBFloat16>:$vector1,
-    SPIRV_VectorOf<SPIRV_FloatOrBFloat16>:$vector2
+    SPIRV_VectorOf<SPIRV_Float>:$vector1,
+    SPIRV_VectorOf<SPIRV_Float>:$vector2
   );
 
   let results = (outs
-    SPIRV_FloatOrBFloat16:$result
+    SPIRV_Float:$result
   );
 
   let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
@@ -501,10 +460,9 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
 
 // -----
 
-def SPIRV_SDivOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"SDiv",
-                                                      SPIRV_Integer,
-                                                      SPIRV_Integer,
-                                                      [UsableInSpecConstantOp]> {
+def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
+                                        SPIRV_Integer,
+                                        [UsableInSpecConstantOp]> {
   let summary = "Signed-integer division of Operand 1 divided by Operand 2.";
 
   let description = [{
@@ -602,7 +560,6 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
 // -----
 
 def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
-                                          SPIRV_Integer,
                                           SPIRV_Integer,
                                           [UsableInSpecConstantOp]> {
   let summary = "Signed-integer subtract of Operand from zero.";
@@ -665,10 +622,9 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
 
 // -----
 
-def SPIRV_UDivOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"UDiv",
-                                                      SPIRV_Integer,
-                                                      SPIRV_Integer,
-                                                      [UnsignedOp, UsableInSpecConstantOp]> {
+def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
+                                        SPIRV_Integer,
+                                        [UnsignedOp, UsableInSpecConstantOp]> {
   let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
 
   let description = [{
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index 301a5bab9ab1a..2d0c86e08de5a 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -265,13 +265,6 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
 
 // -----
 
-func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
-  %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
-  return %0 : bf16
-}
-
-// -----
-
 // expected-note @+1 {{prior use here}}
 func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
   // expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
@@ -290,7 +283,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
 // -----
 
 func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
-  // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or bfloat16 type values of length 2/3/4/8/16}}
+  // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
   %0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
   return %0 : i32
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 4d161a3193505..d3e1dbc229ef9 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -31,15 +31,6 @@ spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr<i32, StorageBuf
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_load_bf16
-spirv.func @cooperative_matrix_load_bf16(%ptr : !spirv.ptr<bf16, StorageBuffer>, %stride : i32) "None" {
-  // CHECK:      {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>
-  // CHECK-SAME:   : !spirv.ptr<bf16, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xbf16, Workgroup, MatrixA>
-  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
-    !spirv.ptr<bf16, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xbf16, Workgroup, MatrixA>
-  spirv.Return
-}
-
 // CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
 spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32) "None" {
   // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Volatile> :
@@ -234,26 +225,6 @@ spirv.func @cooperative_matrix_muladd_f32(%a : !spirv.coopmatrix<4x4xf32, Subgro
   spirv.Return
 }
 
-spirv.func @cooperative_matrix_muladd_bf16_bf16(%a : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
-                                             %b : !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB>,
-                                             %c : !spirv.coopmatrix<8x4xbf16, Subgroup, MatrixAcc>) "None" {
-  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
-        !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
-        !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB> ->
-          !spirv.coopmatrix<8x4xbf16, Subgroup, MatrixAcc>
-  spirv.Return
-}
-
-spirv.func @cooperative_matrix_muladd_bf16_f32(%a : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
-                                             %b : !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB>,
-                                             %c : !spirv.coopmatrix<8x4xf32, Subgroup, MatrixAcc>) "None" {
-  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
-        !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
-        !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB> ->
-          !spirv.coopmatrix<8x4xf32, Subgroup, MatrixAcc>
-  spirv.Return
-}
-
 spirv.func @cooperative_matrix_muladd_i8_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
                                              %b : !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB>,
                                              %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
@@ -437,9 +408,6 @@ spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x1
 !matA_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
 !matB_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixB>
 
-!matA_bf16 = !spirv.coopmatrix<2x2xbf16, Subgroup, MatrixA>
-!matB_bf16 = !spirv.coopmatrix<2x2xbf16, Subgroup, MatrixB>
-
 // These tests are kept in the same order as the list of compatible ops in the
 // SPV_KHR_cooperative_matrix extension spec.
 
@@ -452,8 +420,8 @@ spirv.func @snegate(%a: !matA_i32, %b: !matB_i32) "None" {
   spirv.Return
 }
 
-// CHECK-LABEL: @fnegate_f32
-spirv.func @fnegate_f32(%a: !matA_f32, %b: !matB_f32) "None" {
+// CHECK-LABEL: @fnegate
+spirv.func @fnegate(%a: !matA_f32, %b: !matB_f32) "None" {
   // CHECK:      spirv.FNegate {{%.*}} : !spirv.coopmatrix
   // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix
   %p = spirv.FNegate %a : !matA_f32
@@ -461,15 +429,6 @@ spirv.func @fnegate_f32(%a: !matA_f32, %b: !matB_f32) "None" {
   spirv.Return
 }
 
-// CHECK-LABEL: @fnegate_bf16
-spirv.func @fnegate_bf16(%a: !matA_bf16, %b: !matB_bf16) "None" {
-  // CHECK:      spirv.FNegate {{%.*}} : !spirv.coopmatrix
-  // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix
-  %p = spirv.FNegate %a : !matA_bf16
-  %q = spirv.FNegate %b : !matB_bf16
-  spirv.Return
-}
-
 // CHECK-LABEL: @iadd
 spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" {
   // CHECK:      spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
@@ -479,8 +438,8 @@ spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" {
   spirv.Return
 }
 
-// CHECK-LABEL: @fadd_f32
-spirv.func @fadd_f32(%a: !matA_f32, %b: !matB_f32) "None" {
+// CHECK-LABEL: @fadd
+spirv.func @fadd(%a: !matA_f32, %b: !matB_f32) "None" {
   // CHECK:      spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
   // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
   %p = spirv.FAdd %a, %a : !matA_f32
@@ -488,15 +447,6 @@ spirv.func @fadd_f32(%a: !matA_f32, %b: !matB_f32) "None" {
   spirv.Return
 }
 
-// CHECK-LABEL: @fadd_bf16
-spirv.func @fadd_bf16(%a: !matA_bf16, %b: !matB_bf16) "None" {
-  // CHECK:      spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
-  // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
-  %p = spirv.FAdd %a, %a : !matA_bf16
-  %q = spirv.FAdd %b, %b : !matB_bf16
-  spirv.Return
-}
-
 // CHECK-LABEL: @isub
 spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" {
   // CHECK:      spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
@@ -506,8 +456,8 @@ spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" {
   spirv.Return
 }
 
-// CHECK-LABEL: @fsub_f32
-spirv.func @fsub_f32(%a: !matA_f32, %b: !matB_f32) "None" {
+// CHECK-LABEL: @fsub
+spirv.func @fsub(%a: !matA_f32, %b: !matB_f32) "None" {
   // CHECK:      spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
   // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
   %p = spirv.FSub %a, %a : !matA_f32
@@ -515,17 +465,8 @@ spirv.func @fsub_f32(%a: !matA_f32, %b: !matB_f32) "None" {
   spirv.Return
 }
 
-// CHECK-LABEL: @fsub_bf16
-spirv.func @fsub_bf16(%a: !matA_bf16, %b: !matB_bf16) "None" {
-  // CHECK:      spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
-  // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
-  %p = spirv.FSub %a, %a : !matA_bf16
-  %q = spirv.FSub %b, %b : !matB_bf16
-  spirv.Return
-}
-
-// CHECK-LABEL: @fmul_f32
-spirv.func @fmul_f32(%a: !matA_f32, %b: !matB_f32) "None" {
+// CHECK-LABEL: @fmul
+spirv.func @fmul(%a: !matA_f32, %b: !matB_f32) "None" {
   // CHECK:      spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
   // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
   %p = spirv.FMul %a, %a : !matA_f32
@@ -533,15 +474,6 @@ spirv.func @fmul_f32(%a: !matA_f32, %b: !matB_f32) "None" {
   spirv.Return
 }
 
-// CHECK-LABEL: @fmul_bf16
-spirv.func @fmul_bf16(%a: !matA_bf16, %b: !matB_bf16) "None" {
-  // CHECK:      spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
-  // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
-  %p = spirv.FMul %a, %a : !matA_bf16
-  %q = spirv.FMul %b, %b : !matB_bf16
-  spirv.Return
-}
-
 // CHECK-LABEL: @imul
 spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" {
   // CHECK:      spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
@@ -551,8 +483,8 @@ spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" {
   spirv.Return
 }
 
-// CHECK-LABEL: @fdiv_f32
-spirv.func @fdiv_f32(%a: !matA_f32, %b: !matB_f32) "None" {
+// CHECK-LABEL: @fdiv
+spirv.func @fdiv(%a: !matA_f32, %b: !matB_f32) "None" {
   // CHECK:      spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
   // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
   %p = spirv.FDiv %a, %a : !matA_f32
@@ -560,15 +492,6 @@ spirv.func @fdiv_f32(%a: !matA_f32, %b: !matB_f32) "None" {
   spirv.Return
 }
 
-// CHECK-LABEL: @fdiv_bf16
-spirv.func @fdiv_bf16(%a: !matA_bf16, %b: !matB_bf16) "None" {
-  // CHECK:      spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
-  // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
-  %p = spirv.FDiv %a, %a : !matA_bf16
-  %q = spirv.FDiv %b, %b : !matB_bf16
-  spirv.Return
-}
-
 // CHECK-LABEL: @sdiv
 spirv.func @sdiv(%a: !matA_i32, %b: !matB_i32) "None" {
   // CHECK:      spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix

>From 3a4b9364ff2e65d1476fe50d428839066b8dda8c Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Thu, 29 May 2025 00:24:17 -0400
Subject: [PATCH 05/13] Add roundtrip tests and implement deserialization

---
 .../SPIRV/Deserialization/Deserializer.cpp    | 21 ++++++++++--
 .../FuncToSPIRV/types-to-spirv.mlir           |  6 ++++
 mlir/test/Target/SPIRV/cast-ops.mlir          | 32 ++++++++++++++++++-
 3 files changed, 55 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 7afd6e9b25b77..b71595ced93bd 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -866,11 +866,12 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
     typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
   } break;
   case spirv::Opcode::OpTypeFloat: {
-    if (operands.size() != 2)
+    if (operands.size() < 2)
       return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
+    uint32_t bitWidth = operands[1];
 
     Type floatTy;
-    switch (operands[1]) {
+    switch (bitWidth) {
     case 16:
       floatTy = opBuilder.getF16Type();
       break;
@@ -882,8 +883,22 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
       break;
     default:
       return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
-             << operands[1];
+             << bitWidth;
     }
+
+    if (operands.size() == 3) {
+      if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR) {
+        return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
+               << operands[2];
+      }
+      if (bitWidth != 16) {
+        return emitError(unknownLoc,
+                         "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
+               << bitWidth << " (expected 16)";
+      }
+      floatTy = opBuilder.getBF16Type();
+    }
+
     typeMap[operands[0]] = floatTy;
   } break;
   case spirv::Opcode::OpTypeVector: {
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 2e34c9ff54012..1737f4a906bf8 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -173,6 +173,12 @@ func.func @float16(%arg0: f16) { return }
 // NOEMU-SAME: f64
 func.func @float64(%arg0: f64) { return }
 
+// CHECK-LABEL: spirv.func @bfloat16
+// CHECK-SAME: f32
+// NOEMU-LABEL: func.func @bfloat16
+// NOEMU-SAME: bf16
+func.func @bfloat16(%arg0: bf16) { return }
+
 // f80 is not supported by SPIR-V.
 // CHECK-LABEL: func.func @float80
 // CHECK-SAME: f80
diff --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir
index ede0bf30511ef..04a468b39b645 100644
--- a/mlir/test/Target/SPIRV/cast-ops.mlir
+++ b/mlir/test/Target/SPIRV/cast-ops.mlir
@@ -25,6 +25,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.ConvertFToS %arg0 : f64 to i32
     spirv.ReturnValue %0 : i32
   }
+  spirv.func @convert_bf16_to_s32(%arg0 : bf16) -> i32 "None" {
+    // CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : bf16 to i32
+    %0 = spirv.ConvertFToS %arg0 : bf16 to i32
+    spirv.ReturnValue %0 : i32
+  }
   spirv.func @convert_f_to_u(%arg0 : f32) -> i32 "None" {
     // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : f32 to i32
     %0 = spirv.ConvertFToU %arg0 : f32 to i32
@@ -35,6 +40,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.ConvertFToU %arg0 : f64 to i32
     spirv.ReturnValue %0 : i32
   }
+  spirv.func @convert_bf16_to_u32(%arg0 : bf16) -> i32 "None" {
+    // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : bf16 to i32
+    %0 = spirv.ConvertFToU %arg0 : bf16 to i32
+    spirv.ReturnValue %0 : i32
+  }
   spirv.func @convert_s_to_f(%arg0 : i32) -> f32 "None" {
     // CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i32 to f32
     %0 = spirv.ConvertSToF %arg0 : i32 to f32
@@ -45,6 +55,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.ConvertSToF %arg0 : i64 to f32
     spirv.ReturnValue %0 : f32
   }
+  spirv.func @convert_s64_to_bf16(%arg0 : i64) -> bf16 "None" {
+    // CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i64 to bf16
+    %0 = spirv.ConvertSToF %arg0 : i64 to bf16
+    spirv.ReturnValue %0 : bf16
+  }
   spirv.func @convert_u_to_f(%arg0 : i32) -> f32 "None" {
     // CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i32 to f32
     %0 = spirv.ConvertUToF %arg0 : i32 to f32
@@ -55,11 +70,26 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.ConvertUToF %arg0 : i64 to f32
     spirv.ReturnValue %0 : f32
   }
-  spirv.func @f_convert(%arg0 : f32) -> f64 "None" {
+  spirv.func @convert_u64_to_bf16(%arg0 : i64) -> bf16 "None" {
+    // CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i64 to bf16
+    %0 = spirv.ConvertUToF %arg0 : i64 to bf16
+    spirv.ReturnValue %0 : bf16
+  }
+  spirv.func @convert_f32_to_f64(%arg0 : f32) -> f64 "None" {
     // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : f32 to f64
     %0 = spirv.FConvert %arg0 : f32 to f64
     spirv.ReturnValue %0 : f64
   }
+  spirv.func @convert_f32_to_bf16(%arg0 : f32) -> bf16 "None" {
+    // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : f32 to bf16
+    %0 = spirv.FConvert %arg0 : f32 to bf16
+    spirv.ReturnValue %0 : bf16
+  }
+  spirv.func @convert_bf16_to_f32(%arg0 : bf16) -> f32 "None" {
+    // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : bf16 to f32
+    %0 = spirv.FConvert %arg0 : bf16 to f32
+    spirv.ReturnValue %0 : f32
+  }
   spirv.func @s_convert(%arg0 : i32) -> i64 "None" {
     // CHECK: {{%.*}} = spirv.SConvert {{%.*}} : i32 to i64
     %0 = spirv.SConvert %arg0 : i32 to i64

>From 40ae9196155cd409d74dbfbf7329cca115f8e16f Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Thu, 29 May 2025 00:32:47 -0400
Subject: [PATCH 06/13] Remove dead code and add simple array test

---
 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td | 5 -----
 mlir/test/Dialect/SPIRV/IR/types.mlir           | 3 +++
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index bbb0d65b70853..354f398f73eb3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4252,11 +4252,6 @@ class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
     AnyTypeOf<[type, SPIRV_VectorOf<type>,
                SPIRV_CoopMatrixOfType<[type]>]>;
 
-class SPIRV_ScalarOrVectorOfOrCoopMatrixOf<Type scalarVectorType,
-                                           Type coopMatrixType> :
-    AnyTypeOf<[scalarVectorType, SPIRV_VectorOf<scalarVectorType>,
-               SPIRV_CoopMatrixOfType<[coopMatrixType]>]>;
-
 class SPIRV_MatrixOrCoopMatrixOf<Type type> :
     AnyTypeOf<[SPIRV_AnyMatrix,
                SPIRV_CoopMatrixOfType<[type]>]>;
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index a81fe72a8362e..c23894c62826b 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -15,6 +15,9 @@ func.func private @vector_array_type(!spirv.array< 32 x vector<4xf32> >) -> ()
 // CHECK: func private @array_type_stride(!spirv.array<4 x !spirv.array<4 x f32, stride=4>, stride=128>)
 func.func private @array_type_stride(!spirv.array< 4 x !spirv.array<4 x f32, stride=4>, stride = 128>) -> ()
 
+// CHECK: func private @vector_array_type_bf16(!spirv.array<32 x vector<4xbf16>>)
+func.func private @vector_array_type_bf16(!spirv.array<32 x vector<4xbf16> >) -> ()
+
 // -----
 
 // expected-error @+1 {{expected '<'}}

>From 47ba69613281c5736cd8230877024e31e805f15e Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Thu, 29 May 2025 00:41:57 -0400
Subject: [PATCH 07/13] Remove non-standard comma at end of array

---
 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 354f398f73eb3..a14a5d6d5fd81 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -1537,7 +1537,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
       SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
       SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
-      SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
+      SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
     ]>;
 
 def SPIRV_AM_Logical                 : I32EnumAttrCase<"Logical", 0>;

>From 49e8e86b8463fc5b0e5c66e30fdc79e244f6bbfc Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Fri, 30 May 2025 00:35:09 -0400
Subject: [PATCH 08/13] Address review comments

---
 .../lib/Target/SPIRV/Deserialization/Deserializer.cpp | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b71595ced93bd..3bb38585b5bf4 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -866,8 +866,9 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
     typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
   } break;
   case spirv::Opcode::OpTypeFloat: {
-    if (operands.size() < 2)
-      return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
+    if (operands.size() != 2 && operands.size() != 3)
+      return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter "
+                                   "and optional floating point encoding");
     uint32_t bitWidth = operands[1];
 
     Type floatTy;
@@ -887,15 +888,13 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
     }
 
     if (operands.size() == 3) {
-      if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR) {
+      if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
         return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
                << operands[2];
-      }
-      if (bitWidth != 16) {
+      if (bitWidth != 16)
         return emitError(unknownLoc,
                          "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
                << bitWidth << " (expected 16)";
-      }
       floatTy = opBuilder.getBF16Type();
     }
 

>From 4248d7c54c89cb13408d8cc32eeedb5740d6339c Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Thu, 5 Jun 2025 00:59:56 -0400
Subject: [PATCH 09/13] Add Intel downstream changes and address review
 comments

---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  9 ++++----
 .../mlir/Dialect/SPIRV/IR/SPIRVCastOps.td     | 12 +++++-----
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      | 22 +++++++++++++++++--
 .../SPIRV/Deserialization/Deserializer.cpp    |  3 +++
 .../Target/SPIRV/Serialization/Serializer.cpp |  8 ++++---
 5 files changed, 39 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index a14a5d6d5fd81..19ef961d6c0e7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4190,11 +4190,12 @@ def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
 def SPIRV_Int16 : TypeAlias<I16, "Int16">;
 def SPIRV_Int32 : TypeAlias<I32, "Int32">;
 def SPIRV_Float32 : TypeAlias<F32, "Float32">;
+def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
 def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
 def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
-def SPIRV_FloatOrBFloat16 : AnyTypeOf<[SPIRV_Float, BF16]>;
+def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
 def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
-                                       [SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16]>;
+                                       [SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
 // Component type check is done in the type parser for the following SPIR-V
 // dialect-specific types so we use "Any" here.
 def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
@@ -4217,14 +4218,14 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
 def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
                                 "any SPIR-V sampled image type">;
 
-def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float]>;
+def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
 def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
 def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
 def SPIRV_Composite :
     AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
                SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
 def SPIRV_Type : AnyTypeOf<[
-    SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16, SPIRV_Vector,
+    SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
     SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
   ]>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index 29571cf138ebf..a5c8aa8fb450c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -86,7 +86,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
 
 // -----
 
-def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
+def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_AnyFloat, []> {
   let summary = [{
     Convert value numerically from floating point to signed integer, with
     round toward 0.0.
@@ -111,7 +111,7 @@ def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float
 
 // -----
 
-def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
+def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_AnyFloat, []> {
   let summary = [{
     Convert value numerically from floating point to unsigned integer, with
     round toward 0.0.
@@ -138,7 +138,7 @@ def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float
 // -----
 
 def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
-                                   SPIRV_FloatOrBFloat16,
+                                   SPIRV_AnyFloat,
                                    SPIRV_Integer,
                                    [SignedOp]> {
   let summary = [{
@@ -165,7 +165,7 @@ def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
 // -----
 
 def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
-                                   SPIRV_FloatOrBFloat16,
+                                   SPIRV_AnyFloat,
                                    SPIRV_Integer,
                                    [UnsignedOp]> {
   let summary = [{
@@ -192,8 +192,8 @@ def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
 // -----
 
 def SPIRV_FConvertOp : SPIRV_CastOp<"FConvert",
-                                SPIRV_FloatOrBFloat16,
-                                SPIRV_FloatOrBFloat16,
+                                SPIRV_AnyFloat,
+                                SPIRV_AnyFloat,
                                 [UsableInSpecConstantOp]> {
   let summary = [{
     Convert value numerically from one floating-point width to another
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 5da3164ad4d14..f5bcb5318e08d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -514,6 +514,12 @@ bool ScalarType::isValid(IntegerType type) {
 
 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                                std::optional<StorageClass> storage) {
+  if (llvm::isa<BFloat16Type>(*this)) {
+    static const Extension exts[] = {Extension::SPV_KHR_bfloat16};
+    ArrayRef<Extension> ref(exts, std::size(exts));
+    extensions.push_back(ref);
+  }
+
   // 8- or 16-bit integer/floating-point numbers will require extra extensions
   // to appear in interface storage classes. See SPV_KHR_16bit_storage and
   // SPV_KHR_8bit_storage for more details.
@@ -532,7 +538,7 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
     [[fallthrough]];
   case StorageClass::Input:
   case StorageClass::Output:
-    if (getIntOrFloatBitWidth() == 16) {
+    if (getIntOrFloatBitWidth() == 16 && !llvm::isa<BFloat16Type>(*this)) {
       static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
       ArrayRef<Extension> ref(exts, std::size(exts));
       extensions.push_back(ref);
@@ -619,7 +625,19 @@ void ScalarType::getCapabilities(
   } else {
     assert(llvm::isa<FloatType>(*this));
     switch (bitwidth) {
-      WIDTH_CASE(Float, 16);
+    case 16: {
+      if (llvm::isa<BFloat16Type>(*this)) {
+        static const Capability caps[] = {Capability::BFloat16TypeKHR};
+        ArrayRef<Capability> ref(caps, std::size(caps));
+        capabilities.push_back(ref);
+
+      } else {
+        static const Capability caps[] = {Capability::Float16};
+        ArrayRef<Capability> ref(caps, std::size(caps));
+        capabilities.push_back(ref);
+      }
+      break;
+    }
       WIDTH_CASE(Float, 64);
     case 32:
       break;
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 900eb0fd3178b..7af1a3f276235 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1413,6 +1413,9 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
     } else if (floatType.isF16()) {
       APInt data(16, operands[2]);
       value = APFloat(APFloat::IEEEhalf(), data);
+    } else if (floatType.isBF16()) {
+      APInt data(16, operands[2]);
+      value = APFloat(APFloat::BFloat(), data);
     }
 
     auto attr = opBuilder.getFloatAttr(floatType, value);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index b43f22db55a2e..5de498cb454a7 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -999,21 +999,23 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
 
   auto resultID = getNextID();
   APFloat value = floatAttr.getValue();
+  const llvm::fltSemantics *semantics = &value.getSemantics();
 
   auto opcode =
       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
 
-  if (&value.getSemantics() == &APFloat::IEEEsingle()) {
+  if (semantics == &APFloat::IEEEsingle()) {
     uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
-  } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
+  } else if (semantics == &APFloat::IEEEdouble()) {
     struct DoubleWord {
       uint32_t word1;
       uint32_t word2;
     } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
     encodeInstructionInto(typesGlobalValues, opcode,
                           {typeID, resultID, words.word1, words.word2});
-  } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
+  } else if (semantics == &APFloat::IEEEhalf() ||
+             semantics == &APFloat::BFloat()) {
     uint32_t word =
         static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});

>From f2ba086a26d1ae754ecc5421093a1d2685fd52c4 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Thu, 5 Jun 2025 01:15:05 -0400
Subject: [PATCH 10/13] Add tests for errors for arithmetic ops

---
 .../test/Dialect/SPIRV/IR/arithmetic-ops.mlir | 64 +++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index 2d0c86e08de5a..d58c27598f2b8 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -12,6 +12,14 @@ func.func @fadd_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fadd_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FAdd %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FDiv
 //===----------------------------------------------------------------------===//
@@ -24,6 +32,14 @@ func.func @fdiv_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fdiv_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FDiv %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FMod
 //===----------------------------------------------------------------------===//
@@ -36,6 +52,14 @@ func.func @fmod_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fmod_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FMod %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FMul
 //===----------------------------------------------------------------------===//
@@ -70,6 +94,14 @@ func.func @fmul_bf16(%arg: bf16) -> bf16 {
 
 // -----
 
+func.func @fmul_bf16_vector(%arg: vector<4xbf16>) -> vector<4xbf16> {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FMul %arg, %arg : vector<4xbf16>
+  return %0 : vector<4xbf16>
+}
+
+// -----
+
 func.func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
   // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
   %0 = spirv.FMul %arg, %arg : tensor<4xf32>
@@ -90,6 +122,14 @@ func.func @fnegate_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fnegate_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FNegate %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FRem
 //===----------------------------------------------------------------------===//
@@ -102,6 +142,14 @@ func.func @frem_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @frem_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FRem %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FSub
 //===----------------------------------------------------------------------===//
@@ -114,6 +162,14 @@ func.func @fsub_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fsub_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FSub %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.IAdd
 //===----------------------------------------------------------------------===//
@@ -489,3 +545,11 @@ func.func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f32) -> vector<3
   %0 = spirv.VectorTimesScalar %vector, %scalar : (vector<4xf32>, f32) -> vector<3xf32>
   return %0 : vector<3xf32>
 }
+
+// -----
+
+func.func @vector_bf16_times_scalar_bf16(%vector: vector<4xbf16>, %scalar: bf16) -> vector<4xbf16> {
+  // expected-error @+1 {{op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4}}
+  %0 = spirv.VectorTimesScalar %vector, %scalar : (vector<4xbf16>, bf16) -> vector<4xbf16>
+  return %0 : vector<4xbf16>
+}

>From c0fee39c2a5c3f1d9205286a0872dca348287a71 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Thu, 5 Jun 2025 01:43:08 -0400
Subject: [PATCH 11/13] Add more error checking tests and select/constant tests

---
 mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir    |  8 +++++++
 mlir/test/Dialect/SPIRV/IR/composite-ops.mlir |  7 ++++++
 mlir/test/Dialect/SPIRV/IR/gl-ops.mlir        | 17 ++++++++++++++
 mlir/test/Dialect/SPIRV/IR/logical-ops.mlir   |  8 +++++++
 .../Dialect/SPIRV/IR/non-uniform-ops.mlir     | 16 +++++++++++++
 mlir/test/Target/SPIRV/logical-ops.mlir       | 23 +++++++++++++++++++
 6 files changed, 79 insertions(+)

diff --git a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
index cc0abd3a42dcb..661497d5fff38 100644
--- a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
@@ -272,3 +272,11 @@ func.func @atomic_fadd(%ptr : !spirv.ptr<f32, StorageBuffer>, %value : f32) -> f
   %0 = spirv.EXT.AtomicFAdd <Device> <Acquire|Release> %ptr, %value : !spirv.ptr<f32, StorageBuffer>
   return %0 : f32
 }
+
+// -----
+
+func.func @atomic_bf16_fadd(%ptr : !spirv.ptr<bf16, StorageBuffer>, %value : bf16) -> bf16 {
+  // expected-error @+1 {{op operand #1 must be 16/32/64-bit float, but got 'bf16'}}
+  %0 = spirv.EXT.AtomicFAdd <Device> <None> %ptr, %value : !spirv.ptr<bf16, StorageBuffer>
+  return %0 : bf16
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 3fc8dfb2767d1..e71b545de11df 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -11,6 +11,13 @@ func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> ve
   return %0: vector<3xf32>
 }
 
+// CHECK-LABEL: func @composite_construct_bf16_vector
+func.func @composite_construct_bf16_vector(%arg0: bf16, %arg1: bf16, %arg2 : bf16) -> vector<3xbf16> {
+  // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (bf16, bf16, bf16) -> vector<3xbf16>
+  %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (bf16, bf16, bf16) -> vector<3xbf16>
+  return %0: vector<3xbf16>
+}
+
 // CHECK-LABEL: func @composite_construct_struct
 func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
   // CHECK: spirv.CompositeConstruct
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index 0be047932c1f3..4c141a285cd30 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -50,6 +50,14 @@ func.func @exp(%arg0 : i32) -> () {
 
 // -----
 
+func.func @exp_bf16(%arg0 : bf16) -> () {
+  // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
+  %2 = spirv.GL.Exp %arg0 : bf16
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GL.{F|S|U}{Max|Min}
 //===----------------------------------------------------------------------===//
@@ -92,6 +100,15 @@ func.func @iminmax(%arg0: i32, %arg1: i32) {
 
 // -----
 
+func.func @fmaxminbf16vec(%arg0 : vector<3xbf16>, %arg1 : vector<3xbf16>) {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %1 = spirv.GL.FMax %arg0, %arg1 : vector<3xbf16>
+  %2 = spirv.GL.FMin %arg0, %arg1 : vector<3xbf16>
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GL.InverseSqrt
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index 5c24f0e6a7d33..d6c34645f5746 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -201,6 +201,14 @@ func.func @select_op_float(%arg0: i1) -> () {
   return
 }
 
+func.func @select_op_bfloat16(%arg0: i1) -> () {
+  %0 = spirv.Constant 2.0 : bf16
+  %1 = spirv.Constant 3.0 : bf16
+  // CHECK: spirv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, bf16
+  %2 = spirv.Select %arg0, %0, %1 : i1, bf16
+  return
+}
+
 func.func @select_op_ptr(%arg0: i1) -> () {
   %0 = spirv.Variable : !spirv.ptr<f32, Function>
   %1 = spirv.Variable : !spirv.ptr<f32, Function>
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 5f56de6ad1fa9..7ab94f17360d5 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -184,6 +184,14 @@ func.func @group_non_uniform_fmul_clustered_reduce(%val: vector<2xf32>) -> vecto
 
 // -----
 
+func.func @group_non_uniform_bf16_fmul_reduce(%val: bf16) -> bf16 {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+  %0 = spirv.GroupNonUniformFMul <Workgroup> <Reduce> %val : bf16 -> bf16
+  return %0: bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformFMax
 //===----------------------------------------------------------------------===//
@@ -197,6 +205,14 @@ func.func @group_non_uniform_fmax_reduce(%val: f32) -> f32 {
 
 // -----
 
+func.func @group_non_uniform_bf16_fmax_reduce(%val: bf16) -> bf16 {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+  %0 = spirv.GroupNonUniformFMax <Workgroup> <Reduce> %val : bf16 -> bf16
+  return %0: bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformFMin
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir
index 16846ac84e38c..b2008719b021c 100644
--- a/mlir/test/Target/SPIRV/logical-ops.mlir
+++ b/mlir/test/Target/SPIRV/logical-ops.mlir
@@ -108,3 +108,26 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.Return
   }
 }
+
+// -----
+
+// Test that select works with bf16 scalars and vectors.
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+  spirv.SpecConstant @condition_scalar = true
+  spirv.func @select_bf16() -> () "None" {
+    %0 = spirv.Constant 4.0 : bf16
+    %1 = spirv.Constant 5.0 : bf16
+    %2 = spirv.mlir.referenceof @condition_scalar : i1
+    // CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : i1, bf16
+    %3 = spirv.Select %2, %0, %1 : i1, bf16
+    %4 = spirv.Constant dense<[2.0, 3.0, 4.0, 5.0]> : vector<4xbf16>
+    %5 = spirv.Constant dense<[6.0, 7.0, 8.0, 9.0]> : vector<4xbf16>
+    // CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : i1, vector<4xbf16>
+    %6 = spirv.Select %2, %4, %5 : i1, vector<4xbf16>
+    %7 = spirv.Constant dense<[true, true, true, true]> : vector<4xi1>
+    // CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : vector<4xi1>, vector<4xbf16>
+    %8 = spirv.Select %7, %4, %5 : vector<4xi1>, vector<4xbf16>
+    spirv.Return
+  }
+}

>From e01695be7cf4056441c403eeeeabd0cccc32db12 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Thu, 5 Jun 2025 10:32:28 -0600
Subject: [PATCH 12/13] Address review comments

---
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      | 22 ++++++++-----------
 .../SPIRV/Deserialization/Deserializer.cpp    |  5 +++--
 2 files changed, 12 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index f5bcb5318e08d..c7ec00218c164 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -514,10 +514,9 @@ bool ScalarType::isValid(IntegerType type) {
 
 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                                std::optional<StorageClass> storage) {
-  if (llvm::isa<BFloat16Type>(*this)) {
-    static const Extension exts[] = {Extension::SPV_KHR_bfloat16};
-    ArrayRef<Extension> ref(exts, std::size(exts));
-    extensions.push_back(ref);
+  if (isa<BFloat16Type>(*this)) {
+    static const Extension ext = Extension::SPV_KHR_bfloat16;
+    extensions.push_back(ext);
   }
 
   // 8- or 16-bit integer/floating-point numbers will require extra extensions
@@ -538,7 +537,7 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
     [[fallthrough]];
   case StorageClass::Input:
   case StorageClass::Output:
-    if (getIntOrFloatBitWidth() == 16 && !llvm::isa<BFloat16Type>(*this)) {
+    if (getIntOrFloatBitWidth() == 16 && !isa<BFloat16Type>(*this)) {
       static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
       ArrayRef<Extension> ref(exts, std::size(exts));
       extensions.push_back(ref);
@@ -626,15 +625,12 @@ void ScalarType::getCapabilities(
     assert(llvm::isa<FloatType>(*this));
     switch (bitwidth) {
     case 16: {
-      if (llvm::isa<BFloat16Type>(*this)) {
-        static const Capability caps[] = {Capability::BFloat16TypeKHR};
-        ArrayRef<Capability> ref(caps, std::size(caps));
-        capabilities.push_back(ref);
-
+      if (isa<BFloat16Type>(*this)) {
+        static const Capability cap = Capability::BFloat16TypeKHR;
+        capabilities.push_back(cap);
       } else {
-        static const Capability caps[] = {Capability::Float16};
-        ArrayRef<Capability> ref(caps, std::size(caps));
-        capabilities.push_back(ref);
+        static const Capability cap = Capability::Float16;
+        capabilities.push_back(cap);
       }
       break;
     }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 7af1a3f276235..3a9310ba52adb 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -868,8 +868,9 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
   } break;
   case spirv::Opcode::OpTypeFloat: {
     if (operands.size() != 2 && operands.size() != 3)
-      return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter "
-                                   "and optional floating point encoding");
+      return emitError(unknownLoc,
+                       "OpTypeFloat expects either 2 operands (type, bitwidth) "
+                       "or 3 operands (type, bitwidth, encoding)");
     uint32_t bitWidth = operands[1];
 
     Type floatTy;

>From ef930f46b4ed13f6b8e84c17bcce4938a73fd830 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Thu, 5 Jun 2025 15:19:38 -0600
Subject: [PATCH 13/13] Fix VCE and add test

---
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp           |  2 +-
 .../Dialect/SPIRV/Transforms/vce-deduction.mlir    | 14 ++++++++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index c7ec00218c164..1e71f4277f660 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -537,7 +537,7 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
     [[fallthrough]];
   case StorageClass::Input:
   case StorageClass::Output:
-    if (getIntOrFloatBitWidth() == 16 && !isa<BFloat16Type>(*this)) {
+    if (getIntOrFloatBitWidth() == 16) {
       static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
       ArrayRef<Extension> ref(exts, std::size(exts));
       extensions.push_back(ref);
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index ff5ac7cea8fc6..2b237665ffc4a 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -217,3 +217,17 @@ spirv.module Logical GLSL450 attributes {
   spirv.GlobalVariable @data : !spirv.ptr<!spirv.struct<(i8 [0], f16 [2], i64 [4])>, Uniform>
   spirv.GlobalVariable @img  : !spirv.ptr<!spirv.image<f32, Buffer, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Rg32f>, UniformConstant>
 }
+
+// Using bfloat16 requires the BFloat16TypeKHR capability and the SPV_KHR_bfloat16 extension; accessing it through StorageBuffer additionally requires 16-bit storage support.
+// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+spirv.module Logical GLSL450 attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader, StorageBuffer16BitAccess, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>,
+    #spirv.resource_limits<>
+  >
+} {
+  spirv.func @load_bf16(%ptr : !spirv.ptr<bf16, StorageBuffer>) -> bf16 "None" {
+    %val = spirv.Load "StorageBuffer" %ptr : bf16
+    spirv.ReturnValue %val : bf16
+  }
+}



More information about the Mlir-commits mailing list