[llvm] [SPIRV] Add support for the SPIR-V extension SPV_KHR_bfloat16 (PR #155645)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 15 05:12:49 PDT 2025


https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/155645

>From 301c7e0df9d5ae33386a89035ffa71d33c0698d6 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 03:42:59 -0700
Subject: [PATCH 01/12] add support for the SPIR-V extension SPV_KHR_bfloat16

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  |  3 --
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp    |  3 +-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 40 ++++++++++++++++---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   | 13 ++++++
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 38 ++++++++++++++++--
 .../lib/Target/SPIRV/SPIRVSymbolicOperands.td |  4 ++
 .../extensions/SPV_KHR_bfloat16/bfloat16.ll   | 22 ++++++++++
 .../bfloat16_cooperative_matrix.ll            | 20 ++++++++++
 .../SPV_KHR_bfloat16/bfloat16_dot.ll          | 21 ++++++++++
 9 files changed, 151 insertions(+), 13 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 768e3713f78e2..b1e68fb86d286 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,9 +2765,6 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
-    return false;
-
   const CallInst &CI = cast<CallInst>(U);
   const Function *F = CI.getCalledFunction();
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e7da5504b2d58..993de9e9f64ec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -147,7 +147,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
         {"SPV_KHR_float_controls2",
          SPIRV::Extension::Extension::SPV_KHR_float_controls2},
         {"SPV_INTEL_tensor_float32_conversion",
-         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
+         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
+        {"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cfe24c84941a9..ce9ebb619f242 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1122,7 +1122,19 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
                                          ExplicitLayoutRequired, EmitIR);
   TypesInProcessing.erase(Ty);
-  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
+
+  // Record the FPVariant of the floating-point registers in the
+  // VRegFPVariantMap.
+  MachineFunction *MF = &MIRBuilder.getMF();
+  Register TypeReg = getSPIRVTypeID(SpirvType);
+  if (Ty->isFloatingPointTy()) {
+    if (Ty->isBFloatTy()) {
+      VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
+    } else {
+      VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
+    }
+  }
+  VRegToTypeMap[MF][TypeReg] = SpirvType;
 
   // TODO: We could end up with two SPIR-V types pointing to the same llvm type.
   // Is that a problem?
@@ -1679,11 +1691,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
   MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
   const MachineInstr *NewMI =
       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-        return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
-                       MIRBuilder.getDL(), TII.get(SPIRVOPcode))
-            .addDef(createTypeVReg(CurMF->getRegInfo()))
-            .addImm(BitWidth)
-            .addImm(0);
+        auto MIB = BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
+                           MIRBuilder.getDL(), TII.get(SPIRVOPcode))
+                       .addDef(createTypeVReg(CurMF->getRegInfo()))
+                       .addImm(BitWidth);
+
+        if (SPIRVOPcode != SPIRV::OpTypeFloat)
+          MIB.addImm(0);
+
+        return MIB;
       });
   add(Ty, false, NewMI);
   return finishCreatingSPIRVType(Ty, NewMI);
@@ -2088,3 +2104,15 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
   }
   return false;
 }
+
+SPIRVGlobalRegistry::FPVariant
+SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
+                                         const MachineFunction *MF) {
+  auto t = VRegFPVariantMap.find(MF ? MF : CurMF);
+  if (t != VRegFPVariantMap.end()) {
+    auto tt = t->second.find(VReg);
+    if (tt != t->second.end())
+      return tt->second;
+  }
+  return FPVariant::NONE;
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 7ef812828b7cc..1f8c30dc01f7f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -29,6 +29,10 @@ using SPIRVType = const MachineInstr;
 using StructOffsetDecorator = std::function<void(Register)>;
 
 class SPIRVGlobalRegistry : public SPIRVIRMapping {
+public:
+  enum class FPVariant { NONE, IEEE_FLOAT, BRAIN_FLOAT };
+
+private:
   // Registers holding values which have types associated with them.
   // Initialized upon VReg definition in IRTranslator.
   // Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
@@ -88,6 +92,11 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // map of aliasing decorations to aliasing metadata
   std::unordered_map<const MDNode *, MachineInstr *> AliasInstMDMap;
 
+  // Maps floating point Registers to their FPVariant (float type kind), given
+  // the MachineFunction.
+  DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>>
+      VRegFPVariantMap;
+
   // Add a new OpTypeXXX instruction without checking for duplicates.
   SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
                              SPIRV::AccessQualifier::AccessQualifier AQ,
@@ -422,6 +431,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // structures referring this instruction.
   void invalidateMachineInstr(MachineInstr *MI);
 
+  // Return the FPVariant of to the given floating-point regiester.
+  FPVariant getFPVariantForVReg(Register VReg,
+                                const MachineFunction *MF = nullptr);
+
 private:
   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index b7e371d190866..6cf00f078e8e3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1261,12 +1261,35 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Int8);
     break;
   }
+  case SPIRV::OpDot: {
+    const MachineFunction *MF = MI.getMF();
+    SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+    SPIRVGlobalRegistry::FPVariant FPV =
+        GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
+    if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+      Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
+    }
+    break;
+  }
   case SPIRV::OpTypeFloat: {
     unsigned BitWidth = MI.getOperand(1).getImm();
     if (BitWidth == 64)
       Reqs.addCapability(SPIRV::Capability::Float64);
-    else if (BitWidth == 16)
+    else if (BitWidth == 16) {
+      SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+      const MachineFunction *MF = MI.getMF();
+      SPIRVGlobalRegistry::FPVariant FPV =
+          GR->getFPVariantForVReg(MI.getOperand(0).getReg(), MF);
+      if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+        if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
+          report_fatal_error("OpTypeFloat type with bfloat requires the "
+                             "following SPIR-V extension: SPV_KHR_bfloat16",
+                             false);
+        Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+        Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+      }
       Reqs.addCapability(SPIRV::Capability::Float16);
+    }
     break;
   }
   case SPIRV::OpTypeVector: {
@@ -1593,15 +1616,24 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::AsmINTEL);
     }
     break;
-  case SPIRV::OpTypeCooperativeMatrixKHR:
+  case SPIRV::OpTypeCooperativeMatrixKHR: {
     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
       report_fatal_error(
           "OpTypeCooperativeMatrixKHR type requires the "
           "following SPIR-V extension: SPV_KHR_cooperative_matrix",
           false);
     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
-    Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
+    const MachineFunction *MF = MI.getMF();
+    SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+    SPIRVGlobalRegistry::FPVariant FPV =
+        GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
+    if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+      Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
+    } else {
+      Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
+    }
     break;
+  }
   case SPIRV::OpArithmeticFenceEXT:
     if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
       report_fatal_error("OpArithmeticFenceEXT requires the "
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d2824ee2d2caf..9d630356e8ffb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -382,6 +382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
 defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
 defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
 defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
 defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
 defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
 defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
+defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
+defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
new file mode 100644
index 0000000000000..bfc84691f6945
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -0,0 +1,22 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: OpTypeFloat type with bfloat requires the following SPIR-V extension: SPV_KHR_bfloat16
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_kernel void @test() {
+entry:
+  %addr1 = alloca bfloat
+  %addr2 = alloca <2 x bfloat>
+  %data1 = load bfloat, ptr %addr1
+  %data2 = load <2 x bfloat>, ptr %addr2
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
new file mode 100644
index 0000000000000..5a6e6d88ca6a0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -0,0 +1,20 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16CooperativeMatrixKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]]  %[[#]] %[[#]] %[[#]] %[[#]]
+; CHECK: OpCompositeConstruct %[[#MatTy]] %[[#]]
+
+define spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) {
+entry:
+    %addr1 = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
+    %res = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
+    %m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat 1.0)
+    store target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) %m1, ptr %addr1, align 4
+    ret void
+}
+
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
new file mode 100644
index 0000000000000..7cfe29261f2cd
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -0,0 +1,21 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16DotProductKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
+; CHECK: OpDot
+
+declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>)
+
+define spir_kernel void @test() {
+entry:
+  %addrA = alloca <2 x bfloat>
+  %addrB = alloca <2 x bfloat>
+  %dataA = load <2 x bfloat>, ptr %addrA
+  %dataB = load <2 x bfloat>, ptr %addrB
+  %call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
+  ret void
+}

>From 1ea205e308c12d936c781d3ff23bcc8616aae7cb Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 5 Sep 2025 09:28:38 -0700
Subject: [PATCH 02/12] Keep the IRTranslator bf16 call check for non-SPIR-V
 targets; refactor getFPVariantForVReg

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  |  3 +++
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 16 +++++++++++-----
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index b1e68fb86d286..12b735e053bde 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,6 +2765,9 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
+    return false;
+
   const CallInst &CI = cast<CallInst>(U);
   const Function *F = CI.getCalledFunction();
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index ce9ebb619f242..4c931b4f45e69 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -2108,11 +2108,17 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
 SPIRVGlobalRegistry::FPVariant
 SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
                                          const MachineFunction *MF) {
-  auto t = VRegFPVariantMap.find(MF ? MF : CurMF);
-  if (t != VRegFPVariantMap.end()) {
-    auto tt = t->second.find(VReg);
-    if (tt != t->second.end())
-      return tt->second;
+  const MachineFunction *Func = MF ? MF : CurMF;
+  DenseMap<const MachineFunction *,
+           DenseMap<Register, FPVariant>>::const_iterator FuncIt =
+      VRegFPVariantMap.find(Func);
+
+  if (FuncIt != VRegFPVariantMap.end()) {
+    const DenseMap<Register, FPVariant> &VRegMap = FuncIt->second;
+    DenseMap<Register, FPVariant>::const_iterator VRegIt = VRegMap.find(VReg);
+
+    if (VRegIt != VRegMap.end())
+      return VRegIt->second;
   }
   return FPVariant::NONE;
 }

>From bdb40692524890dc4206a989467778e6251f1221 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 5 Sep 2025 10:29:03 -0700
Subject: [PATCH 03/12] Revert the getFPVariantForVReg refactoring from the
 previous patch

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 4c931b4f45e69..ce9ebb619f242 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -2108,17 +2108,11 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
 SPIRVGlobalRegistry::FPVariant
 SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
                                          const MachineFunction *MF) {
-  const MachineFunction *Func = MF ? MF : CurMF;
-  DenseMap<const MachineFunction *,
-           DenseMap<Register, FPVariant>>::const_iterator FuncIt =
-      VRegFPVariantMap.find(Func);
-
-  if (FuncIt != VRegFPVariantMap.end()) {
-    const DenseMap<Register, FPVariant> &VRegMap = FuncIt->second;
-    DenseMap<Register, FPVariant>::const_iterator VRegIt = VRegMap.find(VReg);
-
-    if (VRegIt != VRegMap.end())
-      return VRegIt->second;
+  auto t = VRegFPVariantMap.find(MF ? MF : CurMF);
+  if (t != VRegFPVariantMap.end()) {
+    auto tt = t->second.find(VReg);
+    if (tt != t->second.end())
+      return tt->second;
   }
   return FPVariant::NONE;
 }

>From 7877ca6711d4b75f753cad1ff8c4ace1b915a3dc Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 5 Sep 2025 10:46:21 -0700
Subject: [PATCH 04/12] Comment out the non-SPIR-V bf16 check in
 IRTranslator::translateCall

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 12b735e053bde..9fc89ac700b88 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
-  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
-    return false;
+  // if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
+  //   return false;
 
   const CallInst &CI = cast<CallInst>(U);
   const Function *F = CI.getCalledFunction();

>From 2e54ae80d3e5f2076f1fbee2fd9b1d47d6080812 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 8 Sep 2025 07:24:34 -0700
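Subject: [PATCH 05/12] Code clean-up: re-enable the non-SPIR-V bf16 check in
 IRTranslator, tidy getFPVariantForVReg, and reorder test CHECK-DAG lines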

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp         |  4 ++--
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp        | 12 +++++++-----
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp        |  2 +-
 .../SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll    |  2 +-
 .../SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll  |  2 +-
 .../extensions/SPV_KHR_bfloat16/bfloat16_dot.ll      |  2 +-
 6 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 9fc89ac700b88..12b735e053bde 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
-  // if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
-  //   return false;
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
+    return false;
 
   const CallInst &CI = cast<CallInst>(U);
   const Function *F = CI.getCalledFunction();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index ce9ebb619f242..01c995bccb0ef 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -2108,11 +2108,13 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
 SPIRVGlobalRegistry::FPVariant
 SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
                                          const MachineFunction *MF) {
-  auto t = VRegFPVariantMap.find(MF ? MF : CurMF);
-  if (t != VRegFPVariantMap.end()) {
-    auto tt = t->second.find(VReg);
-    if (tt != t->second.end())
-      return tt->second;
+  const MachineFunction *Func = MF ? MF : CurMF;
+  auto FuncIt = VRegFPVariantMap.find(Func);
+  if (FuncIt != VRegFPVariantMap.end()) {
+    const DenseMap<Register, FPVariant> &VRegMap = FuncIt->second;
+    auto VRegIt = VRegMap.find(VReg);
+    if (VRegIt != VRegMap.end())
+      return VRegIt->second;
   }
   return FPVariant::NONE;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 6cf00f078e8e3..b4b04e08c8cd4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1276,6 +1276,7 @@ void addInstrRequirements(const MachineInstr &MI,
     if (BitWidth == 64)
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16) {
+      Reqs.addCapability(SPIRV::Capability::Float16);
       SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
       const MachineFunction *MF = MI.getMF();
       SPIRVGlobalRegistry::FPVariant FPV =
@@ -1288,7 +1289,6 @@ void addInstrRequirements(const MachineInstr &MI,
         Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
         Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
       }
-      Reqs.addCapability(SPIRV::Capability::Float16);
     }
     break;
   }
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
index bfc84691f6945..45123eb15d8d7 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -4,8 +4,8 @@
 
 ; CHECK-ERROR: LLVM ERROR: OpTypeFloat type with bfloat requires the following SPIR-V extension: SPV_KHR_bfloat16
 
-; CHECK-DAG: OpCapability BFloat16TypeKHR
 ; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-DAG: OpCapability BFloat16TypeKHR
 ; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
 ; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
 
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
index 5a6e6d88ca6a0..d54b8325c6783 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -1,9 +1,9 @@
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
 
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
 ; CHECK-DAG: OpCapability BFloat16TypeKHR
 ; CHECK-DAG: OpCapability BFloat16CooperativeMatrixKHR
-; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
 ; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
 ; CHECK: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]]  %[[#]] %[[#]] %[[#]] %[[#]]
 ; CHECK: OpCompositeConstruct %[[#MatTy]] %[[#]]
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
index 7cfe29261f2cd..0943170ae6785 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -1,9 +1,9 @@
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
 
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
 ; CHECK-DAG: OpCapability BFloat16TypeKHR
 ; CHECK-DAG: OpCapability BFloat16DotProductKHR
-; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
 ; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
 ; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
 ; CHECK: OpDot

>From 16e93a3bd803f4d3d878fe2e7f9d5560b0da8dd0 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 8 Sep 2025 15:35:18 -0700
Subject: [PATCH 06/12] Emit the SPIR-V bfloat type as OpTypeFloat 16 0 and
 update the implementation accordingly

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 52 +++++--------------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   | 17 ++----
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |  2 +-
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 20 +++----
 .../extensions/SPV_KHR_bfloat16/bfloat16.ll   |  2 +-
 .../bfloat16_cooperative_matrix.ll            |  2 +-
 .../SPV_KHR_bfloat16/bfloat16_dot.ll          |  2 +-
 7 files changed, 27 insertions(+), 70 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 01c995bccb0ef..c445b7a4a6e95 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -195,11 +195,15 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
-                                               MachineIRBuilder &MIRBuilder) {
+                                               MachineIRBuilder &MIRBuilder, bool isBfloatTy) {
   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-    return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+    auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
         .addDef(createTypeVReg(MIRBuilder))
         .addImm(Width);
+    if(isBfloatTy){
+      MIB.addImm(0);
+    }
+    return MIB;
   });
 }
 
@@ -1042,7 +1046,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
                       : getOpTypeInt(Width, MIRBuilder, false);
   }
   if (Ty->isFloatingPointTy())
-    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
+    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, Ty->isBFloatTy());
   if (Ty->isVoidTy())
     return getOpTypeVoid(MIRBuilder);
   if (Ty->isVectorTy()) {
@@ -1122,19 +1126,7 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
                                          ExplicitLayoutRequired, EmitIR);
   TypesInProcessing.erase(Ty);
-
-  // Record the FPVariant of the floating-point registers in the
-  // VRegFPVariantMap.
-  MachineFunction *MF = &MIRBuilder.getMF();
-  Register TypeReg = getSPIRVTypeID(SpirvType);
-  if (Ty->isFloatingPointTy()) {
-    if (Ty->isBFloatTy()) {
-      VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
-    } else {
-      VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
-    }
-  }
-  VRegToTypeMap[MF][TypeReg] = SpirvType;
+  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
 
   // TODO: We could end up with two SPIR-V types pointing to the same llvm type.
   // Is that a problem?
@@ -1691,15 +1683,11 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
   MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
   const MachineInstr *NewMI =
       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-        auto MIB = BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
-                           MIRBuilder.getDL(), TII.get(SPIRVOPcode))
-                       .addDef(createTypeVReg(CurMF->getRegInfo()))
-                       .addImm(BitWidth);
-
-        if (SPIRVOPcode != SPIRV::OpTypeFloat)
-          MIB.addImm(0);
-
-        return MIB;
+        return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
+                       MIRBuilder.getDL(), TII.get(SPIRVOPcode))
+            .addDef(createTypeVReg(CurMF->getRegInfo()))
+            .addImm(BitWidth)
+            .addImm(0);
       });
   add(Ty, false, NewMI);
   return finishCreatingSPIRVType(Ty, NewMI);
@@ -2104,17 +2092,3 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
   }
   return false;
 }
-
-SPIRVGlobalRegistry::FPVariant
-SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
-                                         const MachineFunction *MF) {
-  const MachineFunction *Func = MF ? MF : CurMF;
-  auto FuncIt = VRegFPVariantMap.find(Func);
-  if (FuncIt != VRegFPVariantMap.end()) {
-    const DenseMap<Register, FPVariant> &VRegMap = FuncIt->second;
-    auto VRegIt = VRegMap.find(VReg);
-    if (VRegIt != VRegMap.end())
-      return VRegIt->second;
-  }
-  return FPVariant::NONE;
-}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 1f8c30dc01f7f..2c28484599656 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -29,10 +29,6 @@ using SPIRVType = const MachineInstr;
 using StructOffsetDecorator = std::function<void(Register)>;
 
 class SPIRVGlobalRegistry : public SPIRVIRMapping {
-public:
-  enum class FPVariant { NONE, IEEE_FLOAT, BRAIN_FLOAT };
-
-private:
   // Registers holding values which have types associated with them.
   // Initialized upon VReg definition in IRTranslator.
   // Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
@@ -92,11 +88,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // map of aliasing decorations to aliasing metadata
   std::unordered_map<const MDNode *, MachineInstr *> AliasInstMDMap;
 
-  // Maps floating point Registers to their FPVariant (float type kind), given
-  // the MachineFunction.
-  DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>>
-      VRegFPVariantMap;
-
   // Add a new OpTypeXXX instruction without checking for duplicates.
   SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
                              SPIRV::AccessQualifier::AccessQualifier AQ,
@@ -431,10 +422,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // structures referring this instruction.
   void invalidateMachineInstr(MachineInstr *MI);
 
-  // Return the FPVariant of to the given floating-point regiester.
-  FPVariant getFPVariantForVReg(Register VReg,
-                                const MachineFunction *MF = nullptr);
-
 private:
   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
 
@@ -449,7 +436,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
                           bool IsSigned = false);
 
-  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
+  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, bool isBfloatTy);
+
+  SPIRVType *getOpTypeBFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
 
   SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 8d10cd0ffb3dd..496dcba17c10d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -167,7 +167,7 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
 def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
 def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
                   "$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
+def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
                   "$type = OpTypeFloat $width">;
 def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
                   "$type = OpTypeVector $compType $compCount">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index b4b04e08c8cd4..a64c65b4a0dd1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1262,11 +1262,10 @@ void addInstrRequirements(const MachineInstr &MI,
     break;
   }
   case SPIRV::OpDot: {
-    const MachineFunction *MF = MI.getMF();
     SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
-    SPIRVGlobalRegistry::FPVariant FPV =
-        GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
-    if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+    const MachineFunction *MF = MI.getMF();
+    SPIRVType *RegType = GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF);
+    if (RegType->getNumOperands() == 3) {
       Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
     }
     break;
@@ -1277,11 +1276,7 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16) {
       Reqs.addCapability(SPIRV::Capability::Float16);
-      SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
-      const MachineFunction *MF = MI.getMF();
-      SPIRVGlobalRegistry::FPVariant FPV =
-          GR->getFPVariantForVReg(MI.getOperand(0).getReg(), MF);
-      if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+      if (MI.getNumOperands() == 3) {
         if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
           report_fatal_error("OpTypeFloat type with bfloat requires the "
                              "following SPIR-V extension: SPV_KHR_bfloat16",
@@ -1623,11 +1618,10 @@ void addInstrRequirements(const MachineInstr &MI,
           "following SPIR-V extension: SPV_KHR_cooperative_matrix",
           false);
     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
-    const MachineFunction *MF = MI.getMF();
     SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
-    SPIRVGlobalRegistry::FPVariant FPV =
-        GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
-    if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+    const MachineFunction *MF = MI.getMF();
+    SPIRVType *RegType = GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF);
+    if (RegType->getNumOperands() == 3) {
       Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
     } else {
       Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
index 45123eb15d8d7..48514b60ad9b9 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -6,7 +6,7 @@
 
 ; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
 ; CHECK-DAG: OpCapability BFloat16TypeKHR
-; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
 ; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
 
 target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
index d54b8325c6783..9f52a459616da 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -4,7 +4,7 @@
 ; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
 ; CHECK-DAG: OpCapability BFloat16TypeKHR
 ; CHECK-DAG: OpCapability BFloat16CooperativeMatrixKHR
-; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
 ; CHECK: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]]  %[[#]] %[[#]] %[[#]] %[[#]]
 ; CHECK: OpCompositeConstruct %[[#MatTy]] %[[#]]
 
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
index 0943170ae6785..51d212788df72 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -4,7 +4,7 @@
 ; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
 ; CHECK-DAG: OpCapability BFloat16TypeKHR
 ; CHECK-DAG: OpCapability BFloat16DotProductKHR
-; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
 ; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
 ; CHECK: OpDot
 

>From 91859188bdb54d0f2312e84e37ee8253dc0175f0 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 8 Sep 2025 21:27:45 -0700
Subject: [PATCH 07/12] Add the enum class FPEncoding

---
 .../Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h |  5 ++++
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 27 ++++++++++-------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  4 +--
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 29 +++++++++----------
 .../lib/Target/SPIRV/SPIRVSymbolicOperands.td | 28 +++++++++++++++++-
 5 files changed, 65 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
index c2c08f8831307..d76180ce97e9e 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
@@ -232,6 +232,11 @@ namespace SpecConstantOpOperands {
 #include "SPIRVGenTables.inc"
 } // namespace SpecConstantOpOperands
 
+namespace FPEncoding {
+#define GET_FPEncoding_DECL
+#include "SPIRVGenTables.inc"
+} // namespace FPEncoding
+
 struct ExtendedBuiltin {
   StringRef Name;
   InstructionSet::InstructionSet Set;
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index c445b7a4a6e95..8d4e766545ba3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -195,16 +195,18 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
-                                               MachineIRBuilder &MIRBuilder, bool isBfloatTy) {
-  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-    auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+                                               MachineIRBuilder &MIRBuilder) {
+  return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
         .addDef(createTypeVReg(MIRBuilder))
         .addImm(Width);
-    if(isBfloatTy){
-      MIB.addImm(0);
-    }
-    return MIB;
-  });
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
+                                               MachineIRBuilder &MIRBuilder, SPIRV::FPEncoding::FPEncoding FPEncode) {
+  return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+        .addDef(createTypeVReg(MIRBuilder))
+        .addImm(Width)
+        .addImm(FPEncode);
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
@@ -1045,8 +1047,13 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     return Width == 1 ? getOpTypeBool(MIRBuilder)
                       : getOpTypeInt(Width, MIRBuilder, false);
   }
-  if (Ty->isFloatingPointTy())
-    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, Ty->isBFloatTy());
+  if (Ty->isFloatingPointTy()) {
+    if(Ty->isBFloatTy()) {
+      return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, SPIRV::FPEncoding::BFloat16KHR);
+    } else {
+      return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
+    }
+  }
   if (Ty->isVoidTy())
     return getOpTypeVoid(MIRBuilder);
   if (Ty->isVectorTy()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 2c28484599656..8104cb8177457 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -436,9 +436,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
                           bool IsSigned = false);
 
-  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, bool isBfloatTy);
+  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
 
-  SPIRVType *getOpTypeBFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
+  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, SPIRV::FPEncoding::FPEncoding FPEncode);
 
   SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index a64c65b4a0dd1..de1ee17e4b559 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1262,10 +1262,9 @@ void addInstrRequirements(const MachineInstr &MI,
     break;
   }
   case SPIRV::OpDot: {
-    SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
-    const MachineFunction *MF = MI.getMF();
-    SPIRVType *RegType = GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF);
-    if (RegType->getNumOperands() == 3) {
+    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+    if ((TypeDef->getNumOperands() == 3) && (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
       Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
     }
     break;
@@ -1275,14 +1274,15 @@ void addInstrRequirements(const MachineInstr &MI,
     if (BitWidth == 64)
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16) {
-      Reqs.addCapability(SPIRV::Capability::Float16);
-      if (MI.getNumOperands() == 3) {
+      if ((MI.getNumOperands() == 3) && (MI.getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
         if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
           report_fatal_error("OpTypeFloat type with bfloat requires the "
                              "following SPIR-V extension: SPV_KHR_bfloat16",
                              false);
-        Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
         Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+        Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+      } else {
+        Reqs.addCapability(SPIRV::Capability::Float16);
       }
     }
     break;
@@ -1304,8 +1304,9 @@ void addInstrRequirements(const MachineInstr &MI,
     assert(MI.getOperand(2).isReg());
     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
-    if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
-        TypeDef->getOperand(1).getImm() == 16)
+    if ((TypeDef->getNumOperands() == 2) &&
+        (TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
+        (TypeDef->getOperand(1).getImm() == 16))
       Reqs.addCapability(SPIRV::Capability::Float16Buffer);
     break;
   }
@@ -1617,14 +1618,12 @@ void addInstrRequirements(const MachineInstr &MI,
           "OpTypeCooperativeMatrixKHR type requires the "
           "following SPIR-V extension: SPV_KHR_cooperative_matrix",
           false);
+    Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
-    SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
-    const MachineFunction *MF = MI.getMF();
-    SPIRVType *RegType = GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF);
-    if (RegType->getNumOperands() == 3) {
+    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+    if ((TypeDef->getNumOperands() == 3) && (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
       Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
-    } else {
-      Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
     }
     break;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 9d630356e8ffb..00b636d315d02 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -210,6 +210,7 @@ def CooperativeMatrixLayoutOperand : OperandCategory;
 def CooperativeMatrixOperandsOperand : OperandCategory;
 def SpecConstantOpOperandsOperand : OperandCategory;
 def MatrixMultiplyAccumulateOperandsOperand : OperandCategory;
+def FPEncodingOperand : OperandCategory;
 
 //===----------------------------------------------------------------------===//
 // Definition of the Environments
@@ -382,7 +383,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
 defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
 defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
 defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
-defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -2000,3 +2001,28 @@ defm MatrixAPackedFloat16INTEL :  MatrixMultiplyAccumulateOperandsOperand<0x400,
 defm MatrixBPackedFloat16INTEL :  MatrixMultiplyAccumulateOperandsOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
 defm MatrixAPackedBFloat16INTEL :  MatrixMultiplyAccumulateOperandsOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
 defm MatrixBPackedBFloat16INTEL :  MatrixMultiplyAccumulateOperandsOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+
+//===----------------------------------------------------------------------===//
+// Multiclass used to define FPEncoding enum values and at the
+// same time SymbolicOperand entries extensions.
+//===----------------------------------------------------------------------===//
+def FPEncoding : GenericEnum, Operand<i32> {
+  let FilterClass = "FPEncoding";
+  let NameField = "Name";
+  let ValueField = "Value";
+  let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
+}
+
+class FPEncoding<string name, bits<32> value> {
+  string Name = name;
+  bits<32> Value = value;
+}
+
+multiclass FPEncodingOperand<bits<32> value, list<Extension> reqExtensions>{
+  def NAME : FPEncoding<NAME, value>;
+  defm : SymbolicOperandWithRequirements<
+             FPEncodingOperand, value, NAME, 0, 0,
+             reqExtensions, [], []>;
+}
+
+defm BFloat16KHR : FPEncodingOperand<0, [SPV_KHR_bfloat16]>;

>From 874b618fc088f6cfafbba3a55a2abed124f9ccf0 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 8 Sep 2025 21:56:15 -0700
Subject: [PATCH 08/12] nit change and update tests

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp             | 8 ++++++--
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp             | 4 ++--
 llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td            | 2 +-
 .../CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll | 2 +-
 .../SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll       | 4 +++-
 .../SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll     | 2 +-
 6 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 8d4e766545ba3..76b276fede679 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -196,17 +196,21 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
                                                MachineIRBuilder &MIRBuilder) {
-  return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
         .addDef(createTypeVReg(MIRBuilder))
         .addImm(Width);
+  });
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
                                                MachineIRBuilder &MIRBuilder, SPIRV::FPEncoding::FPEncoding FPEncode) {
-  return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
         .addDef(createTypeVReg(MIRBuilder))
         .addImm(Width)
         .addImm(FPEncode);
+  });
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index de1ee17e4b559..49ac72e71f863 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1279,8 +1279,8 @@ void addInstrRequirements(const MachineInstr &MI,
           report_fatal_error("OpTypeFloat type with bfloat requires the "
                              "following SPIR-V extension: SPV_KHR_bfloat16",
                              false);
-        Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
         Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+        Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
       } else {
         Reqs.addCapability(SPIRV::Capability::Float16);
       }
@@ -1618,8 +1618,8 @@ void addInstrRequirements(const MachineInstr &MI,
           "OpTypeCooperativeMatrixKHR type requires the "
           "following SPIR-V extension: SPV_KHR_cooperative_matrix",
           false);
-    Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
+    Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
     if ((TypeDef->getNumOperands() == 3) && (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 00b636d315d02..501bcb94af2ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -2004,7 +2004,7 @@ defm MatrixBPackedBFloat16INTEL :  MatrixMultiplyAccumulateOperandsOperand<0x200
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define FPEncoding enum values and at the
-// same time SymbolicOperand entries extensions.
+// same time SymbolicOperand entries with extensions.
 //===----------------------------------------------------------------------===//
 def FPEncoding : GenericEnum, Operand<i32> {
   let FilterClass = "FPEncoding";
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
index 48514b60ad9b9..22668e71fb257 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -4,8 +4,8 @@
 
 ; CHECK-ERROR: LLVM ERROR: OpTypeFloat type with bfloat requires the following SPIR-V extension: SPV_KHR_bfloat16
 
-; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
 ; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
 ; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
 ; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
 
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
index 9f52a459616da..d47b5d7440d18 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -1,9 +1,11 @@
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
 
-; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
 ; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability CooperativeMatrixKHR
 ; CHECK-DAG: OpCapability BFloat16CooperativeMatrixKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-DAG: OpExtension "SPV_KHR_cooperative_matrix"
 ; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
 ; CHECK: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]]  %[[#]] %[[#]] %[[#]] %[[#]]
 ; CHECK: OpCompositeConstruct %[[#MatTy]] %[[#]]
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
index 51d212788df72..4c248fea5c7f1 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -1,9 +1,9 @@
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
 
-; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
 ; CHECK-DAG: OpCapability BFloat16TypeKHR
 ; CHECK-DAG: OpCapability BFloat16DotProductKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
 ; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
 ; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
 ; CHECK: OpDot

>From 19593a7ea8de92ee50f61f48ec301ed6dfcd1897 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 8 Sep 2025 22:04:42 -0700
Subject: [PATCH 09/12] fix code format issue

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 11 +++++++----
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  3 ++-
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  9 ++++++---
 3 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 76b276fede679..115766ce886c7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -203,8 +203,10 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
   });
 }
 
-SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
-                                               MachineIRBuilder &MIRBuilder, SPIRV::FPEncoding::FPEncoding FPEncode) {
+SPIRVType *
+SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
+                                    MachineIRBuilder &MIRBuilder,
+                                    SPIRV::FPEncoding::FPEncoding FPEncode) {
   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
     return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
         .addDef(createTypeVReg(MIRBuilder))
@@ -1052,8 +1054,9 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
                       : getOpTypeInt(Width, MIRBuilder, false);
   }
   if (Ty->isFloatingPointTy()) {
-    if(Ty->isBFloatTy()) {
-      return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, SPIRV::FPEncoding::BFloat16KHR);
+    if (Ty->isBFloatTy()) {
+      return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder,
+                            SPIRV::FPEncoding::BFloat16KHR);
     } else {
       return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
     }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 8104cb8177457..a648defa0a888 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -438,7 +438,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
 
   SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
 
-  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, SPIRV::FPEncoding::FPEncoding FPEncode);
+  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder,
+                            SPIRV::FPEncoding::FPEncoding FPEncode);
 
   SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 49ac72e71f863..e0eb27d716d72 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1264,7 +1264,8 @@ void addInstrRequirements(const MachineInstr &MI,
   case SPIRV::OpDot: {
     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
-    if ((TypeDef->getNumOperands() == 3) && (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
+    if ((TypeDef->getNumOperands() == 3) &&
+        (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
       Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
     }
     break;
@@ -1274,7 +1275,8 @@ void addInstrRequirements(const MachineInstr &MI,
     if (BitWidth == 64)
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16) {
-      if ((MI.getNumOperands() == 3) && (MI.getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
+      if ((MI.getNumOperands() == 3) &&
+          (MI.getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
         if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
           report_fatal_error("OpTypeFloat type with bfloat requires the "
                              "following SPIR-V extension: SPV_KHR_bfloat16",
@@ -1622,7 +1624,8 @@ void addInstrRequirements(const MachineInstr &MI,
     Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
-    if ((TypeDef->getNumOperands() == 3) && (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
+    if ((TypeDef->getNumOperands() == 3) &&
+        (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
       Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
     }
     break;

>From 91c4ab8a5b168d6cde6ae919a4e2071a4b21faeb Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 8 Sep 2025 23:01:46 -0700
Subject: [PATCH 10/12] fix the CI check issue

---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index e0eb27d716d72..2bc1e2b8cef6a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1222,6 +1222,13 @@ static void AddDotProductRequirements(const MachineInstr &MI,
   }
 }
 
+static bool isBFloat16Type(const SPIRVType *TypeDef) {
+  return TypeDef &&
+         TypeDef->getNumOperands() == 3 &&
+         TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
+         TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
+}
+
 void addInstrRequirements(const MachineInstr &MI,
                           SPIRV::RequirementHandler &Reqs,
                           const SPIRVSubtarget &ST) {
@@ -1264,10 +1271,8 @@ void addInstrRequirements(const MachineInstr &MI,
   case SPIRV::OpDot: {
     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
-    if ((TypeDef->getNumOperands() == 3) &&
-        (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
+    if (isBFloat16Type(TypeDef))
       Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
-    }
     break;
   }
   case SPIRV::OpTypeFloat: {
@@ -1624,10 +1629,8 @@ void addInstrRequirements(const MachineInstr &MI,
     Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
-    if ((TypeDef->getNumOperands() == 3) &&
-        (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
+    if (isBFloat16Type(TypeDef))
       Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
-    }
     break;
   }
   case SPIRV::OpArithmeticFenceEXT:

>From a37c02f8d65f0402a2f759e25ba64890a2b9bd77 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Tue, 9 Sep 2025 06:44:26 -0700
Subject: [PATCH 11/12] fix the clang format issue

---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 2bc1e2b8cef6a..8721adddc5aff 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1223,8 +1223,7 @@ static void AddDotProductRequirements(const MachineInstr &MI,
 }
 
 static bool isBFloat16Type(const SPIRVType *TypeDef) {
-  return TypeDef &&
-         TypeDef->getNumOperands() == 3 &&
+  return TypeDef && TypeDef->getNumOperands() == 3 &&
          TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
          TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
 }

>From 389096ad50884c22ffcda52167f34b404e95e4be Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 15 Sep 2025 04:57:50 -0700
Subject: [PATCH 12/12] check bit width in isBFloat16Type and add bfloat test

---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  4 ++--
 llvm/test/CodeGen/SPIRV/basic_float_types.ll  | 13 ++++++++++++-
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 8721adddc5aff..a47cf295c6fff 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1225,6 +1225,7 @@ static void AddDotProductRequirements(const MachineInstr &MI,
 static bool isBFloat16Type(const SPIRVType *TypeDef) {
   return TypeDef && TypeDef->getNumOperands() == 3 &&
          TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
+         TypeDef->getOperand(1).getImm() == 16 &&
          TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
 }
 
@@ -1279,8 +1280,7 @@ void addInstrRequirements(const MachineInstr &MI,
     if (BitWidth == 64)
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16) {
-      if ((MI.getNumOperands() == 3) &&
-          (MI.getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
+      if (isBFloat16Type(&MI)) {
         if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
           report_fatal_error("OpTypeFloat type with bfloat requires the "
                              "following SPIR-V extension: SPV_KHR_bfloat16",
diff --git a/llvm/test/CodeGen/SPIRV/basic_float_types.ll b/llvm/test/CodeGen/SPIRV/basic_float_types.ll
index dfee1ace2205d..009355416c475 100644
--- a/llvm/test/CodeGen/SPIRV/basic_float_types.ll
+++ b/llvm/test/CodeGen/SPIRV/basic_float_types.ll
@@ -7,8 +7,10 @@ entry:
 
 ; CHECK-DAG: OpCapability Float16
 ; CHECK-DAG: OpCapability Float64
+; CHECK-DAG: OpCapability BFloat16TypeKHR
 
-; CHECK-DAG:     %[[#half:]] = OpTypeFloat 16
+; CHECK-DAG:     %[[#half:]] = OpTypeFloat 16{{$}}
+; CHECK-DAG:   %[[#bfloat:]] = OpTypeFloat 16 0
 ; CHECK-DAG:    %[[#float:]] = OpTypeFloat 32
 ; CHECK-DAG:   %[[#double:]] = OpTypeFloat 64
 
@@ -25,11 +27,14 @@ entry:
 ; CHECK-DAG: %[[#v4double:]] = OpTypeVector %[[#double]] 4
+; CHECK-DAG: %[[#v2bfloat:]] = OpTypeVector %[[#bfloat]] 2
 
 ; CHECK-DAG:     %[[#ptr_Function_half:]] = OpTypePointer Function %[[#half]]
+; CHECK-DAG:    %[[#ptr_Function_bfloat:]] = OpTypePointer Function %[[#bfloat]]
 ; CHECK-DAG:    %[[#ptr_Function_float:]] = OpTypePointer Function %[[#float]]
 ; CHECK-DAG:   %[[#ptr_Function_double:]] = OpTypePointer Function %[[#double]]
 ; CHECK-DAG:   %[[#ptr_Function_v2half:]] = OpTypePointer Function %[[#v2half]]
 ; CHECK-DAG:   %[[#ptr_Function_v3half:]] = OpTypePointer Function %[[#v3half]]
 ; CHECK-DAG:   %[[#ptr_Function_v4half:]] = OpTypePointer Function %[[#v4half]]
+; CHECK-DAG:  %[[#ptr_Function_v2bfloat:]] = OpTypePointer Function %[[#v2bfloat]]
 ; CHECK-DAG:  %[[#ptr_Function_v2float:]] = OpTypePointer Function %[[#v2float]]
 ; CHECK-DAG:  %[[#ptr_Function_v3float:]] = OpTypePointer Function %[[#v3float]]
 ; CHECK-DAG:  %[[#ptr_Function_v4float:]] = OpTypePointer Function %[[#v4float]]
@@ -40,6 +45,9 @@ entry:
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_half]] Function
   %half_Val = alloca half, align 2
 
+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_bfloat]] Function
+  %bfloat_Val = alloca bfloat, align 2
+
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_float]] Function
   %float_Val = alloca float, align 4
 
@@ -55,6 +63,9 @@ entry:
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4half]] Function
   %half4_Val = alloca <4 x half>, align 8
 
+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2bfloat]] Function
+  %bfloat2_Val = alloca <2 x bfloat>, align 4
+
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2float]] Function
   %float2_Val = alloca <2 x float>, align 8
 



More information about the llvm-commits mailing list