[Mlir-commits] [mlir] 0a0960d - [mlir][spirv] Add bfloat16 support (#141458)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 13 07:14:49 PDT 2025


Author: Darren Wihandi
Date: 2025-06-13T10:14:45-04:00
New Revision: 0a0960dac69fc88a3c8bd5e2099f8d45b0292c78

URL: https://github.com/llvm/llvm-project/commit/0a0960dac69fc88a3c8bd5e2099f8d45b0292c78
DIFF: https://github.com/llvm/llvm-project/commit/0a0960dac69fc88a3c8bd5e2099f8d45b0292c78.diff

LOG: [mlir][spirv] Add bfloat16 support (#141458)

Adds bf16 support to SPIR-V by using the `SPV_KHR_bfloat16` extension.
Only a few operations are supported, including loading from and storing
to memory, conversion to/from other types, cooperative matrix operations
(including coop matrix arithmetic ops), and dot product operations.

This PR adds the type definition and implements the basic cast
operations. Arithmetic/coop matrix ops will be added in a separate PR.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
    mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
    mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
    mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
    mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
    mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
    mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
    mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
    mlir/test/Dialect/SPIRV/IR/types.mlir
    mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
    mlir/test/Target/SPIRV/cast-ops.mlir
    mlir/test/Target/SPIRV/logical-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index b143cf9a5f509..e413503bbd672 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -344,6 +344,7 @@ def SPV_KHR_subgroup_rotate                  : I32EnumAttrCase<"SPV_KHR_subgroup
 def SPV_KHR_non_semantic_info                : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
 def SPV_KHR_terminate_invocation             : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
 def SPV_KHR_cooperative_matrix               : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
+def SPV_KHR_bfloat16                         : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>;
 
 def SPV_EXT_demote_to_helper_invocation  : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
 def SPV_EXT_descriptor_indexing          : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
@@ -436,7 +437,7 @@ def SPIRV_ExtensionAttr :
       SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
       SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
       SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
-      SPV_KHR_cooperative_matrix,
+      SPV_KHR_cooperative_matrix, SPV_KHR_bfloat16,
       SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
       SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
       SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
@@ -1412,6 +1413,23 @@ def SPIRV_C_ShaderStereoViewNV                          : I32EnumAttrCase<"Shade
     Extension<[SPV_NV_stereo_view_rendering]>
   ];
 }
+def SPIRV_C_BFloat16TypeKHR                             : I32EnumAttrCase<"BFloat16TypeKHR", 5116> {
+  list<Availability> availability = [
+    Extension<[SPV_KHR_bfloat16]>
+  ];
+}
+def SPIRV_C_BFloat16DotProductKHR                       : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR];
+  list<Availability> availability = [
+    Extension<[SPV_KHR_bfloat16]>
+  ];
+}
+def SPIRV_C_BFloat16CooperativeMatrixKHR                : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR];
+  list<Availability> availability = [
+    Extension<[SPV_KHR_bfloat16]>
+  ];
+}
 
 def SPIRV_C_Bfloat16ConversionINTEL                         : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
   list<Availability> availability = [
@@ -1518,7 +1536,8 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
       SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
       SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
-      SPIRV_C_CacheControlsINTEL
+      SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
+      SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
     ]>;
 
 def SPIRV_AM_Logical                 : I32EnumAttrCase<"Logical", 0>;
@@ -3217,6 +3236,16 @@ def SPIRV_ExecutionModelAttr :
       SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
     ]>;
 
+def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_BFloat16TypeKHR]>
+  ];
+}
+def SPIRV_FPEncodingAttr :
+    SPIRV_I32EnumAttr<"FPEncoding", "valid SPIR-V FPEncoding", "f_p_encoding", [
+      SPIRV_FPE_BFloat16KHR
+    ]>;
+
 def SPIRV_FC_None         : I32BitEnumAttrCaseNone<"None">;
 def SPIRV_FC_Inline       : I32BitEnumAttrCaseBit<"Inline", 0>;
 def SPIRV_FC_DontInline   : I32BitEnumAttrCaseBit<"DontInline", 1>;
@@ -4161,10 +4190,12 @@ def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
 def SPIRV_Int16 : TypeAlias<I16, "Int16">;
 def SPIRV_Int32 : TypeAlias<I32, "Int32">;
 def SPIRV_Float32 : TypeAlias<F32, "Float32">;
+def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
 def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
 def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
+def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
 def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
-                                       [SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
+                                       [SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
 // Component type check is done in the type parser for the following SPIR-V
 // dialect-specific types so we use "Any" here.
 def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
@@ -4187,14 +4218,14 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
 def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
                                 "any SPIR-V sampled image type">;
 
-def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float]>;
+def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
 def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
 def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
 def SPIRV_Composite :
     AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
                SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
 def SPIRV_Type : AnyTypeOf<[
-    SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
+    SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
     SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
     SPIRV_AnyImage

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index b05ee0251df5b..a5c8aa8fb450c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -86,7 +86,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
 
 // -----
 
-def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float, []> {
+def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_AnyFloat, []> {
   let summary = [{
     Convert value numerically from floating point to signed integer, with
     round toward 0.0.
@@ -111,7 +111,7 @@ def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float
 
 // -----
 
-def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float, []> {
+def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_AnyFloat, []> {
   let summary = [{
     Convert value numerically from floating point to unsigned integer, with
     round toward 0.0.
@@ -138,7 +138,7 @@ def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float
 // -----
 
 def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
-                                   SPIRV_Float,
+                                   SPIRV_AnyFloat,
                                    SPIRV_Integer,
                                    [SignedOp]> {
   let summary = [{
@@ -165,7 +165,7 @@ def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
 // -----
 
 def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
-                                   SPIRV_Float,
+                                   SPIRV_AnyFloat,
                                    SPIRV_Integer,
                                    [UnsignedOp]> {
   let summary = [{
@@ -192,8 +192,8 @@ def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
 // -----
 
 def SPIRV_FConvertOp : SPIRV_CastOp<"FConvert",
-                                SPIRV_Float,
-                                SPIRV_Float,
+                                SPIRV_AnyFloat,
+                                SPIRV_AnyFloat,
                                 [UsableInSpecConstantOp]> {
   let summary = [{
     Convert value numerically from one floating-point width to another

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 0cf5f0823be63..a21acef1c4b43 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -175,10 +175,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
 
   // Check other allowed types
   if (auto t = llvm::dyn_cast<FloatType>(type)) {
-    if (type.isBF16()) {
-      parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
-      return Type();
-    }
+    // TODO: All float types are allowed for now, but this should be fixed.
   } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
     if (!ScalarType::isValid(t)) {
       parser.emitError(typeLoc,

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 1aff43c301334..93e0c9b33c546 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -526,7 +526,7 @@ bool ScalarType::classof(Type type) {
 }
 
 bool ScalarType::isValid(FloatType type) {
-  return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
+  return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
 }
 
 bool ScalarType::isValid(IntegerType type) {
@@ -535,6 +535,11 @@ bool ScalarType::isValid(IntegerType type) {
 
 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                                std::optional<StorageClass> storage) {
+  if (isa<BFloat16Type>(*this)) {
+    static const Extension ext = Extension::SPV_KHR_bfloat16;
+    extensions.push_back(ext);
+  }
+
   // 8- or 16-bit integer/floating-point numbers will require extra extensions
   // to appear in interface storage classes. See SPV_KHR_16bit_storage and
   // SPV_KHR_8bit_storage for more details.
@@ -640,7 +645,16 @@ void ScalarType::getCapabilities(
   } else {
     assert(llvm::isa<FloatType>(*this));
     switch (bitwidth) {
-      WIDTH_CASE(Float, 16);
+    case 16: {
+      if (isa<BFloat16Type>(*this)) {
+        static const Capability cap = Capability::BFloat16TypeKHR;
+        capabilities.push_back(cap);
+      } else {
+        static const Capability cap = Capability::Float16;
+        capabilities.push_back(cap);
+      }
+      break;
+    }
       WIDTH_CASE(Float, 64);
     case 32:
       break;

diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index c43d584d7b913..b9d9a9015eb61 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -867,11 +867,15 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
     typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
   } break;
   case spirv::Opcode::OpTypeFloat: {
-    if (operands.size() != 2)
-      return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
+    if (operands.size() != 2 && operands.size() != 3)
+      return emitError(unknownLoc,
+                       "OpTypeFloat expects either 2 operands (type, bitwidth) "
+                       "or 3 operands (type, bitwidth, encoding), but got ")
+             << operands.size();
+    uint32_t bitWidth = operands[1];
 
     Type floatTy;
-    switch (operands[1]) {
+    switch (bitWidth) {
     case 16:
       floatTy = opBuilder.getF16Type();
       break;
@@ -883,8 +887,20 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
       break;
     default:
       return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
-             << operands[1];
+             << bitWidth;
+    }
+
+    if (operands.size() == 3) {
+      if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
+        return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
+               << operands[2];
+      if (bitWidth != 16)
+        return emitError(unknownLoc,
+                         "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
+               << bitWidth << " (expected 16)";
+      floatTy = opBuilder.getBF16Type();
     }
+
     typeMap[operands[0]] = floatTy;
   } break;
   case spirv::Opcode::OpTypeVector: {
@@ -1399,6 +1415,9 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
     } else if (floatType.isF16()) {
       APInt data(16, operands[2]);
       value = APFloat(APFloat::IEEEhalf(), data);
+    } else if (floatType.isBF16()) {
+      APInt data(16, operands[2]);
+      value = APFloat(APFloat::BFloat(), data);
     }
 
     auto attr = opBuilder.getFloatAttr(floatType, value);

diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 647535809554c..d258bfd852961 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -523,6 +523,9 @@ LogicalResult Serializer::prepareBasicType(
   if (auto floatType = dyn_cast<FloatType>(type)) {
     typeEnum = spirv::Opcode::OpTypeFloat;
     operands.push_back(floatType.getWidth());
+    if (floatType.isBF16()) {
+      operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
+    }
     return success();
   }
 
@@ -1022,21 +1025,23 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
 
   auto resultID = getNextID();
   APFloat value = floatAttr.getValue();
+  const llvm::fltSemantics *semantics = &value.getSemantics();
 
   auto opcode =
       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
 
-  if (&value.getSemantics() == &APFloat::IEEEsingle()) {
+  if (semantics == &APFloat::IEEEsingle()) {
     uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
-  } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
+  } else if (semantics == &APFloat::IEEEdouble()) {
     struct DoubleWord {
       uint32_t word1;
       uint32_t word2;
     } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
     encodeInstructionInto(typesGlobalValues, opcode,
                           {typeID, resultID, words.word1, words.word2});
-  } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
+  } else if (semantics == &APFloat::IEEEhalf() ||
+             semantics == &APFloat::BFloat()) {
     uint32_t word =
         static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});

diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 82d750755ffe2..1737f4a906bf8 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -173,6 +173,12 @@ func.func @float16(%arg0: f16) { return }
 // NOEMU-SAME: f64
 func.func @float64(%arg0: f64) { return }
 
+// CHECK-LABEL: spirv.func @bfloat16
+// CHECK-SAME: f32
+// NOEMU-LABEL: func.func @bfloat16
+// NOEMU-SAME: bf16
+func.func @bfloat16(%arg0: bf16) { return }
+
 // f80 is not supported by SPIR-V.
 // CHECK-LABEL: func.func @float80
 // CHECK-SAME: f80
@@ -206,18 +212,6 @@ func.func @float64(%arg0: f64) { return }
 
 // -----
 
-// Check that bf16 is not supported.
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
-} {
-
-// CHECK-NOT: spirv.func @bf16_type
-func.func @bf16_type(%arg0: bf16) { return }
-
-} // end module
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // Complex types
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index 2d0c86e08de5a..d58c27598f2b8 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -12,6 +12,14 @@ func.func @fadd_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fadd_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FAdd %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FDiv
 //===----------------------------------------------------------------------===//
@@ -24,6 +32,14 @@ func.func @fdiv_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fdiv_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FDiv %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FMod
 //===----------------------------------------------------------------------===//
@@ -36,6 +52,14 @@ func.func @fmod_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fmod_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FMod %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FMul
 //===----------------------------------------------------------------------===//
@@ -70,6 +94,14 @@ func.func @fmul_bf16(%arg: bf16) -> bf16 {
 
 // -----
 
+func.func @fmul_bf16_vector(%arg: vector<4xbf16>) -> vector<4xbf16> {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FMul %arg, %arg : vector<4xbf16>
+  return %0 : vector<4xbf16>
+}
+
+// -----
+
 func.func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
   // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
   %0 = spirv.FMul %arg, %arg : tensor<4xf32>
@@ -90,6 +122,14 @@ func.func @fnegate_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fnegate_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FNegate %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FRem
 //===----------------------------------------------------------------------===//
@@ -102,6 +142,14 @@ func.func @frem_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @frem_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FRem %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FSub
 //===----------------------------------------------------------------------===//
@@ -114,6 +162,14 @@ func.func @fsub_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fsub_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FSub %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.IAdd
 //===----------------------------------------------------------------------===//
@@ -489,3 +545,11 @@ func.func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f32) -> vector<3
   %0 = spirv.VectorTimesScalar %vector, %scalar : (vector<4xf32>, f32) -> vector<3xf32>
   return %0 : vector<3xf32>
 }
+
+// -----
+
+func.func @vector_bf16_times_scalar_bf16(%vector: vector<4xbf16>, %scalar: bf16) -> vector<4xbf16> {
+  // expected-error @+1 {{op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4}}
+  %0 = spirv.VectorTimesScalar %vector, %scalar : (vector<4xbf16>, bf16) -> vector<4xbf16>
+  return %0 : vector<4xbf16>
+}

diff --git a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
index cc0abd3a42dcb..661497d5fff38 100644
--- a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
@@ -272,3 +272,11 @@ func.func @atomic_fadd(%ptr : !spirv.ptr<f32, StorageBuffer>, %value : f32) -> f
   %0 = spirv.EXT.AtomicFAdd <Device> <Acquire|Release> %ptr, %value : !spirv.ptr<f32, StorageBuffer>
   return %0 : f32
 }
+
+// -----
+
+func.func @atomic_bf16_fadd(%ptr : !spirv.ptr<bf16, StorageBuffer>, %value : bf16) -> bf16 {
+  // expected-error @+1 {{op operand #1 must be 16/32/64-bit float, but got 'bf16'}}
+  %0 = spirv.EXT.AtomicFAdd <Device> <None> %ptr, %value : !spirv.ptr<bf16, StorageBuffer>
+  return %0 : bf16
+}

diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 34d0109e6bb44..4480a1f3720f2 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -110,6 +110,14 @@ func.func @convert_f_to_s_vector(%arg0 : vector<3xf32>) -> vector<3xi32> {
 
 // -----
 
+func.func @convert_bf16_to_s32_scalar(%arg0 : bf16) -> i32 {
+  // CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : bf16 to i32
+  %0 = spirv.ConvertFToS %arg0 : bf16 to i32
+  spirv.ReturnValue %0 : i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.ConvertFToU
 //===----------------------------------------------------------------------===//
@@ -146,6 +154,14 @@ func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgrou
 
 // -----
 
+func.func @convert_bf16_to_u32_scalar(%arg0 : bf16) -> i32 {
+  // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : bf16 to i32
+  %0 = spirv.ConvertFToU %arg0 : bf16 to i32
+  spirv.ReturnValue %0 : i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.ConvertSToF
 //===----------------------------------------------------------------------===//
@@ -174,6 +190,14 @@ func.func @convert_s_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {
 
 // -----
 
+func.func @convert_s32_to_bf16_scalar(%arg0 : i32) -> bf16 {
+  // CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i32 to bf16
+  %0 = spirv.ConvertSToF %arg0 : i32 to bf16
+  spirv.ReturnValue %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.ConvertUToF
 //===----------------------------------------------------------------------===//
@@ -202,6 +226,14 @@ func.func @convert_u_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {
 
 // -----
 
+func.func @convert_u32_to_bf16_scalar(%arg0 : i32) -> bf16 {
+  // CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i32 to bf16
+  %0 = spirv.ConvertUToF %arg0 : i32 to bf16
+  spirv.ReturnValue %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FConvert
 //===----------------------------------------------------------------------===//
@@ -238,6 +270,30 @@ func.func @f_convert_vector(%arg0 : f32) -> f32 {
 
 // -----
 
+func.func @f_convert_bf16_to_f32_scalar(%arg0 : bf16) -> f32 {
+  // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : bf16 to f32
+  %0 = spirv.FConvert %arg0 : bf16 to f32
+  spirv.ReturnValue %0 : f32
+}
+
+// -----
+
+func.func @f_convert_f32_to_bf16_vector(%arg0 : vector<3xf32>) -> vector<3xbf16> {
+  // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : vector<3xf32> to vector<3xbf16>
+  %0 = spirv.FConvert %arg0 : vector<3xf32> to vector<3xbf16>
+  spirv.ReturnValue %0 : vector<3xbf16>
+}
+
+// -----
+
+func.func @f_convert_f32_to_bf16_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>) -> !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA> {
+  // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+  %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+  spirv.ReturnValue %0 : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.SConvert
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 3fc8dfb2767d1..e71b545de11df 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -11,6 +11,13 @@ func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> ve
   return %0: vector<3xf32>
 }
 
+// CHECK-LABEL: func @composite_construct_bf16_vector
+func.func @composite_construct_bf16_vector(%arg0: bf16, %arg1: bf16, %arg2 : bf16) -> vector<3xbf16> {
+  // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (bf16, bf16, bf16) -> vector<3xbf16>
+  %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (bf16, bf16, bf16) -> vector<3xbf16>
+  return %0: vector<3xbf16>
+}
+
 // CHECK-LABEL: func @composite_construct_struct
 func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
   // CHECK: spirv.CompositeConstruct

diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index 2b75767feaf92..642346cc40b0d 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -50,6 +50,14 @@ func.func @exp(%arg0 : i32) -> () {
 
 // -----
 
+func.func @exp_bf16(%arg0 : bf16) -> () {
+  // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
+  %2 = spirv.GL.Exp %arg0 : bf16
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GL.{F|S|U}{Max|Min}
 //===----------------------------------------------------------------------===//
@@ -92,6 +100,15 @@ func.func @iminmax(%arg0: i32, %arg1: i32) {
 
 // -----
 
+func.func @fmaxminbf16vec(%arg0 : vector<3xbf16>, %arg1 : vector<3xbf16>) {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %1 = spirv.GL.FMax %arg0, %arg1 : vector<3xbf16>
+  %2 = spirv.GL.FMin %arg0, %arg1 : vector<3xbf16>
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GL.InverseSqrt
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index 5c24f0e6a7d33..d6c34645f5746 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -201,6 +201,14 @@ func.func @select_op_float(%arg0: i1) -> () {
   return
 }
 
+func.func @select_op_bfloat16(%arg0: i1) -> () {
+  %0 = spirv.Constant 2.0 : bf16
+  %1 = spirv.Constant 3.0 : bf16
+  // CHECK: spirv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, bf16
+  %2 = spirv.Select %arg0, %0, %1 : i1, bf16
+  return
+}
+
 func.func @select_op_ptr(%arg0: i1) -> () {
   %0 = spirv.Variable : !spirv.ptr<f32, Function>
   %1 = spirv.Variable : !spirv.ptr<f32, Function>

diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 5f56de6ad1fa9..7ab94f17360d5 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -184,6 +184,14 @@ func.func @group_non_uniform_fmul_clustered_reduce(%val: vector<2xf32>) -> vecto
 
 // -----
 
+func.func @group_non_uniform_bf16_fmul_reduce(%val: bf16) -> bf16 {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+  %0 = spirv.GroupNonUniformFMul <Workgroup> <Reduce> %val : bf16 -> bf16
+  return %0: bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformFMax
 //===----------------------------------------------------------------------===//
@@ -197,6 +205,14 @@ func.func @group_non_uniform_fmax_reduce(%val: f32) -> f32 {
 
 // -----
 
+func.func @group_non_uniform_bf16_fmax_reduce(%val: bf16) -> bf16 {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+  %0 = spirv.GroupNonUniformFMax <Workgroup> <Reduce> %val : bf16 -> bf16
+  return %0: bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformFMin
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index b63a08d96e6af..c23894c62826b 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -15,6 +15,9 @@ func.func private @vector_array_type(!spirv.array< 32 x vector<4xf32> >) -> ()
 // CHECK: func private @array_type_stride(!spirv.array<4 x !spirv.array<4 x f32, stride=4>, stride=128>)
 func.func private @array_type_stride(!spirv.array< 4 x !spirv.array<4 x f32, stride=4>, stride = 128>) -> ()
 
+// CHECK: func private @vector_array_type_bf16(!spirv.array<32 x vector<4xbf16>>)
+func.func private @vector_array_type_bf16(!spirv.array<32 x vector<4xbf16> >) -> ()
+
 // -----
 
 // expected-error @+1 {{expected '<'}}
@@ -57,11 +60,6 @@ func.func private @tensor_type(!spirv.array<4xtensor<4xf32>>) -> ()
 
 // -----
 
-// expected-error @+1 {{cannot use 'bf16' to compose SPIR-V types}}
-func.func private @bf16_type(!spirv.array<4xbf16>) -> ()
-
-// -----
-
 // expected-error @+1 {{only 1/8/16/32/64-bit integer type allowed but found 'i256'}}
 func.func private @i256_type(!spirv.array<4xi256>) -> ()
 

diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index ff5ac7cea8fc6..2b237665ffc4a 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -217,3 +217,17 @@ spirv.module Logical GLSL450 attributes {
   spirv.GlobalVariable @data : !spirv.ptr<!spirv.struct<(i8 [0], f16 [2], i64 [4])>, Uniform>
   spirv.GlobalVariable @img  : !spirv.ptr<!spirv.image<f32, Buffer, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Rg32f>, UniformConstant>
 }
+
+// Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension.
+// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+spirv.module Logical GLSL450 attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader, StorageBuffer16BitAccess, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>,
+    #spirv.resource_limits<>
+  >
+} {
+  spirv.func @load_bf16(%ptr : !spirv.ptr<bf16, StorageBuffer>) -> bf16 "None" {
+    %val = spirv.Load "StorageBuffer" %ptr : bf16
+    spirv.ReturnValue %val : bf16
+  }
+}

diff  --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir
index ede0bf30511ef..04a468b39b645 100644
--- a/mlir/test/Target/SPIRV/cast-ops.mlir
+++ b/mlir/test/Target/SPIRV/cast-ops.mlir
@@ -25,6 +25,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.ConvertFToS %arg0 : f64 to i32
     spirv.ReturnValue %0 : i32
   }
+  spirv.func @convert_bf16_to_s32(%arg0 : bf16) -> i32 "None" {
+    // CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : bf16 to i32
+    %0 = spirv.ConvertFToS %arg0 : bf16 to i32
+    spirv.ReturnValue %0 : i32
+  }
   spirv.func @convert_f_to_u(%arg0 : f32) -> i32 "None" {
     // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : f32 to i32
     %0 = spirv.ConvertFToU %arg0 : f32 to i32
@@ -35,6 +40,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.ConvertFToU %arg0 : f64 to i32
     spirv.ReturnValue %0 : i32
   }
+  spirv.func @convert_bf16_to_u32(%arg0 : bf16) -> i32 "None" {
+    // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : bf16 to i32
+    %0 = spirv.ConvertFToU %arg0 : bf16 to i32
+    spirv.ReturnValue %0 : i32
+  }
   spirv.func @convert_s_to_f(%arg0 : i32) -> f32 "None" {
     // CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i32 to f32
     %0 = spirv.ConvertSToF %arg0 : i32 to f32
@@ -45,6 +55,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.ConvertSToF %arg0 : i64 to f32
     spirv.ReturnValue %0 : f32
   }
+  spirv.func @convert_s64_to_bf16(%arg0 : i64) -> bf16 "None" {
+    // CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i64 to bf16
+    %0 = spirv.ConvertSToF %arg0 : i64 to bf16
+    spirv.ReturnValue %0 : bf16
+  }
   spirv.func @convert_u_to_f(%arg0 : i32) -> f32 "None" {
     // CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i32 to f32
     %0 = spirv.ConvertUToF %arg0 : i32 to f32
@@ -55,11 +70,26 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.ConvertUToF %arg0 : i64 to f32
     spirv.ReturnValue %0 : f32
   }
-  spirv.func @f_convert(%arg0 : f32) -> f64 "None" {
+  spirv.func @convert_u64_to_bf16(%arg0 : i64) -> bf16 "None" {
+    // CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i64 to bf16
+    %0 = spirv.ConvertUToF %arg0 : i64 to bf16
+    spirv.ReturnValue %0 : bf16
+  }
+  spirv.func @convert_f32_to_f64(%arg0 : f32) -> f64 "None" {
     // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : f32 to f64
     %0 = spirv.FConvert %arg0 : f32 to f64
     spirv.ReturnValue %0 : f64
   }
+  spirv.func @convert_f32_to_bf16(%arg0 : f32) -> bf16 "None" {
+    // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : f32 to bf16
+    %0 = spirv.FConvert %arg0 : f32 to bf16
+    spirv.ReturnValue %0 : bf16
+  }
+  spirv.func @convert_bf16_to_f32(%arg0 : bf16) -> f32 "None" {
+    // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : bf16 to f32
+    %0 = spirv.FConvert %arg0 : bf16 to f32
+    spirv.ReturnValue %0 : f32
+  }
   spirv.func @s_convert(%arg0 : i32) -> i64 "None" {
     // CHECK: {{%.*}} = spirv.SConvert {{%.*}} : i32 to i64
     %0 = spirv.SConvert %arg0 : i32 to i64

diff  --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir
index 16846ac84e38c..b2008719b021c 100644
--- a/mlir/test/Target/SPIRV/logical-ops.mlir
+++ b/mlir/test/Target/SPIRV/logical-ops.mlir
@@ -108,3 +108,26 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.Return
   }
 }
+
+// -----
+
+// Test select works with bf16 scalar and vectors.
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+  spirv.SpecConstant @condition_scalar = true
+  spirv.func @select_bf16() -> () "None" {
+    %0 = spirv.Constant 4.0 : bf16
+    %1 = spirv.Constant 5.0 : bf16
+    %2 = spirv.mlir.referenceof @condition_scalar : i1
+    // CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : i1, bf16
+    %3 = spirv.Select %2, %0, %1 : i1, bf16
+    %4 = spirv.Constant dense<[2.0, 3.0, 4.0, 5.0]> : vector<4xbf16>
+    %5 = spirv.Constant dense<[6.0, 7.0, 8.0, 9.0]> : vector<4xbf16>
+    // CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : i1, vector<4xbf16>
+    %6 = spirv.Select %2, %4, %5 : i1, vector<4xbf16>
+    %7 = spirv.Constant dense<[true, true, true, true]> : vector<4xi1>
+    // CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : vector<4xi1>, vector<4xbf16>
+    %8 = spirv.Select %7, %4, %5 : vector<4xi1>, vector<4xbf16>
+    spirv.Return
+  }
+}


        


More information about the Mlir-commits mailing list