[llvm] [SPIRV] Add support for the SPIR-V extension SPV_KHR_bfloat16 (PR #155645)

via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 4 04:53:38 PDT 2025


https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/155645

>From 9e4f483ce4d7c0f710270cc21007b73ce56b59e1 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 03:42:59 -0700
Subject: [PATCH 01/20] initial commit adding bfloat support to
 MachineInstr

---
 llvm/include/llvm/CodeGen/MachineInstr.h       | 3 ++-
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp   | 9 +++++++--
 llvm/lib/CodeGen/MachineInstr.cpp              | 3 +++
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp     | 4 +++-
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp  | 4 ++++
 llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td | 4 ++++
 6 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MachineInstr.h b/llvm/include/llvm/CodeGen/MachineInstr.h
index 10a9b1ff1411d..6f692ae32510b 100644
--- a/llvm/include/llvm/CodeGen/MachineInstr.h
+++ b/llvm/include/llvm/CodeGen/MachineInstr.h
@@ -123,8 +123,9 @@ class MachineInstr
     NoUSWrap = 1 << 20,      // Instruction supports geps
                              // no unsigned signed wrap.
     SameSign = 1 << 21,      // Both operands have the same sign.
-    InBounds = 1 << 22       // Pointer arithmetic remains inbounds.
+    InBounds = 1 << 22,       // Pointer arithmetic remains inbounds.
                              // Implies NoUSWrap.
+    BFloat16 = 1 << 23      // Instruction with bf16 type
   };
 
 private:
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 541269ab6bfce..2a6b66984c8ae 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
-    return false;
+  // if (containsBF16Type(U))
+  //   return false;
 
   const CallInst &CI = cast<CallInst>(U);
   const Function *F = CI.getCalledFunction();
@@ -2813,6 +2813,11 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
   if (isa<FPMathOperator>(CI))
     MIB->copyIRFlags(CI);
 
+  // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
+  if (containsBF16Type(U)) {
+    MIB->setFlag(MachineInstr::MIFlag::BFloat16);
+  }
+
   for (const auto &Arg : enumerate(CI.args())) {
     // If this is required to be an immediate, don't materialize it in a
     // register.
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 79047f732808a..10ff667bcb522 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,6 +632,9 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
   if (I.getMetadata(LLVMContext::MD_unpredictable))
     MIFlags |= MachineInstr::MIFlag::Unpredictable;
 
+  if (I.getType()->isBFloatTy())
+    MIFlags |= MachineInstr::MIFlag::BFloat16;
+
   return MIFlags;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e7da5504b2d58..bd13a3bae92cd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -147,7 +147,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
         {"SPV_KHR_float_controls2",
          SPIRV::Extension::Extension::SPV_KHR_float_controls2},
         {"SPV_INTEL_tensor_float32_conversion",
-         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
+         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
+        {"SPV_KHR_bfloat16",
+         SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 8039cf0c432fa..5bba5cdce3753 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1267,6 +1267,10 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16)
       Reqs.addCapability(SPIRV::Capability::Float16);
+    if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+      Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+    }
     break;
   }
   case SPIRV::OpTypeVector: {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d2824ee2d2caf..9d630356e8ffb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -382,6 +382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
 defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
 defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
 defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
 defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
 defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
 defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
+defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
+defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time

>From 32e6c031b24083ef668da44d2dd5c9785d9ba12d Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 12:29:09 -0700
Subject: [PATCH 02/20] debugging

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  |  3 ++-
 llvm/lib/CodeGen/MachineInstr.cpp             |  4 +++-
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |  2 ++
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 10 ++++++----
 4 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 2a6b66984c8ae..3b112805e082c 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2815,6 +2815,7 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
 
   // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
   if (containsBF16Type(U)) {
+    // assert(false && "bfloat detected at the IR Translator");
     MIB->setFlag(MachineInstr::MIFlag::BFloat16);
   }
 
@@ -2878,7 +2879,7 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
       MIB.addUse(TokenReg, RegState::Implicit);
     }
   }
-
+  
   return true;
 }
 
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 10ff667bcb522..1ec8c8a2d90c9 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,8 +632,10 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
   if (I.getMetadata(LLVMContext::MD_unpredictable))
     MIFlags |= MachineInstr::MIFlag::Unpredictable;
 
-  if (I.getType()->isBFloatTy())
+  if (I.getType()->getScalarType()->isBFloatTy()) {
     MIFlags |= MachineInstr::MIFlag::BFloat16;
+    // assert(false && "bfloat detected at the MachineInstr");
+  }
 
   return MIFlags;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index f0b938d681dba..923c92c28d3f6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -169,6 +169,8 @@ def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness)
                   "$type = OpTypeInt $width $signedness">;
 def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
                   "$type = OpTypeFloat $width">;
+def OpTypeBFloat: Op<40, (outs TYPE:$type), (ins i32imm:$width),
+                  "$type = OpTypeFloat $width 0">;
 def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
                   "$type = OpTypeVector $compType $compCount">;
 def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 5bba5cdce3753..19c938c96a9e8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1267,10 +1267,12 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16)
       Reqs.addCapability(SPIRV::Capability::Float16);
-    if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
-      Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
-      Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
-    }
+    break;
+  }
+  case SPIRV::OpTypeBFloat: {
+    assert(false && "bfloat detected at the file");
+    Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+    Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
     break;
   }
   case SPIRV::OpTypeVector: {

>From eccbcf845dd59f4a822a5fc9487d48b4e62d7a2c Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 14:53:45 -0700
Subject: [PATCH 03/20] Debugging

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  | 19 ++++++++++++-------
 llvm/lib/CodeGen/GlobalISel/Legalizer.cpp     |  4 ++++
 llvm/lib/CodeGen/MachineInstr.cpp             | 14 ++++++++++++--
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  3 ++-
 .../SPIRV/SPIRVPreLegalizerCombiner.cpp       |  6 ++++++
 5 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 3b112805e082c..4aece388749f8 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2195,6 +2195,8 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
   if (translateSimpleIntrinsic(CI, ID, MIRBuilder))
     return true;
 
+  LLVM_DEBUG(dbgs() << "IRTranslator translateKnownIntrinsic for CI: " << CI << '\n');
+  LLVM_DEBUG(dbgs() << "IRTranslator translateKnownIntrinsic for ID: " << ID << '\n');
   switch (ID) {
   default:
     break;
@@ -2581,9 +2583,10 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
     Value *FpValue = CI.getOperand(0);
     ConstantInt *TestMaskValue = cast<ConstantInt>(CI.getOperand(1));
 
+    uint32_t Flags = MachineInstr::copyFlagsFromInstruction(CI);
     MIRBuilder
         .buildInstr(TargetOpcode::G_IS_FPCLASS, {getOrCreateVReg(CI)},
-                    {getOrCreateVReg(*FpValue)})
+                    {getOrCreateVReg(*FpValue)}, Flags)
         .addImm(TestMaskValue->getZExtValue());
 
     return true;
@@ -2813,12 +2816,6 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
   if (isa<FPMathOperator>(CI))
     MIB->copyIRFlags(CI);
 
-  // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
-  if (containsBF16Type(U)) {
-    // assert(false && "bfloat detected at the IR Translator");
-    MIB->setFlag(MachineInstr::MIFlag::BFloat16);
-  }
-
   for (const auto &Arg : enumerate(CI.args())) {
     // If this is required to be an immediate, don't materialize it in a
     // register.
@@ -2879,6 +2876,13 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
       MIB.addUse(TokenReg, RegState::Implicit);
     }
   }
+
+  // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
+  if (containsBF16Type(U)) {
+    dbgs() << "Flagged at IRTranslator: " << *MIB.getInstr() << "\n";
+    MIB.getInstr()->setFlag(MachineInstr::MIFlag::BFloat16);
+  }
+   MIB->copyIRFlags(CI);
   
   return true;
 }
@@ -4161,6 +4165,7 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
 
         // Translate any debug-info attached to the instruction.
         translateDbgInfo(Inst, *CurBuilder);
+        LLVM_DEBUG(dbgs() << "Inst at IRTranslator: " << Inst << "\n");
 
         if (translate(Inst))
           continue;
diff --git a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
index aef16b5f33af4..6013124558106 100644
--- a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
@@ -192,6 +192,10 @@ Legalizer::legalizeMachineFunction(MachineFunction &MF, const LegalizerInfo &LI,
     if (MBB->empty())
       continue;
     for (MachineInstr &MI : *MBB) {
+      dbgs() << "Legalizer sees: " << MI << "\n";
+      if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
+        assert(false && "bfloat detected at the Legalizer");
+      }
       // Only legalize pre-isel generic instructions: others don't have types
       // and are assumed to be legal.
       if (!isPreISelGenericOpcode(MI.getOpcode()))
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 1ec8c8a2d90c9..bbeeed76631c6 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,9 +632,19 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
   if (I.getMetadata(LLVMContext::MD_unpredictable))
     MIFlags |= MachineInstr::MIFlag::Unpredictable;
 
-  if (I.getType()->getScalarType()->isBFloatTy()) {
+  auto *Sel = dyn_cast<SelectInst>(&I);
+  bool BFloatOpnd = !Sel && I.getType()->getScalarType()->isBFloatTy();
+
+  if (isa<CallInst>(&I)) {
+    for (const Value *Op : I.operands()) {
+      Type *OpTy = Op->getType();
+      BFloatOpnd |= OpTy->getScalarType()->isBFloatTy();
+    }
+  }
+
+  if (BFloatOpnd) {
     MIFlags |= MachineInstr::MIFlag::BFloat16;
-    // assert(false && "bfloat detected at the MachineInstr");
+    LLVM_DEBUG(dbgs() << "bfloat detected at the MachineInstr" << "\n");
   }
 
   return MIFlags;
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 19c938c96a9e8..ab26f0208d758 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1262,6 +1262,7 @@ void addInstrRequirements(const MachineInstr &MI,
     break;
   }
   case SPIRV::OpTypeFloat: {
+    assert(false && "bfloat detected at the spirv module analysis");
     unsigned BitWidth = MI.getOperand(1).getImm();
     if (BitWidth == 64)
       Reqs.addCapability(SPIRV::Capability::Float64);
@@ -1270,7 +1271,7 @@ void addInstrRequirements(const MachineInstr &MI,
     break;
   }
   case SPIRV::OpTypeBFloat: {
-    assert(false && "bfloat detected at the file");
+    assert(false && "bfloat detected at the spirv module analysis");
     Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
     Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
     break;
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
index 83567519355a8..a964db0b02746 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
@@ -156,6 +156,12 @@ SPIRVPreLegalizerCombinerImpl::SPIRVPreLegalizerCombinerImpl(
 }
 
 bool SPIRVPreLegalizerCombinerImpl::tryCombineAll(MachineInstr &MI) const {
+  // Inst combiner is disabled for instructions
+  // with bfloat type.
+  // TODO: add bfloat-specific inst combiner rules
+  if (MI.getFlag(MachineInstr::BFloat16))
+    return false;
+
   return tryCombineAllImpl(MI);
 }
 

>From 2b080940fd7e57c6dc714aedca113b80e8871d54 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 28 Aug 2025 05:55:29 -0700
Subject: [PATCH 04/20] continue debugging

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 19 +++++++++++--------
 llvm/lib/CodeGen/GlobalISel/Legalizer.cpp    |  4 ++--
 llvm/lib/CodeGen/MachineInstr.cpp            |  4 +++-
 3 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 4aece388749f8..a396839c5061b 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2195,8 +2195,8 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
   if (translateSimpleIntrinsic(CI, ID, MIRBuilder))
     return true;
 
-  LLVM_DEBUG(dbgs() << "IRTranslator translateKnownIntrinsic for CI: " << CI << '\n');
-  LLVM_DEBUG(dbgs() << "IRTranslator translateKnownIntrinsic for ID: " << ID << '\n');
+  LLVM_DEBUG(dbgs() << "[BFLOAT] IRTranslator translateKnownIntrinsic for CI: " << CI << '\n');
+  LLVM_DEBUG(dbgs() << "[BFLOAT] IRTranslator translateKnownIntrinsic for ID: " << ID << '\n');
   switch (ID) {
   default:
     break;
@@ -2878,11 +2878,11 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
   }
 
   // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
+  MIB->copyIRFlags(CI);
   if (containsBF16Type(U)) {
-    dbgs() << "Flagged at IRTranslator: " << *MIB.getInstr() << "\n";
+    dbgs() << "[BFLOAT] Flagged at IRTranslator: " << *MIB.getInstr() << "\n";
     MIB.getInstr()->setFlag(MachineInstr::MIFlag::BFloat16);
   }
-   MIB->copyIRFlags(CI);
   
   return true;
 }
@@ -4105,8 +4105,8 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
   if (CLI->fallBackToDAGISel(*MF)) {
     OptimizationRemarkMissed R("gisel-irtranslator", "GISelFailure",
                                F.getSubprogram(), &F.getEntryBlock());
-    R << "unable to lower function: "
-      << ore::NV("Prototype", F.getFunctionType());
+    R << "u "
+      << ore::NV("Prototype", F.getnable to lower function:FunctionType());
     reportTranslationError(*MF, *TPC, *ORE, R);
     return false;
   }
@@ -4118,6 +4118,7 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
       continue; // Don't handle zero sized types.
     ArrayRef<Register> VRegs = getOrCreateVRegs(Arg);
     VRegArgs.push_back(VRegs);
+    LLVM_DEBUG(dbgs() << "[BFLOAT] Created vreg in IRTranslator: " << VRegs);
 
     if (Arg.hasSwiftErrorAttr()) {
       assert(VRegs.size() == 1 && "Too many vregs for Swift error");
@@ -4165,10 +4166,12 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
 
         // Translate any debug-info attached to the instruction.
         translateDbgInfo(Inst, *CurBuilder);
-        LLVM_DEBUG(dbgs() << "Inst at IRTranslator: " << Inst << "\n");
+        LLVM_DEBUG(dbgs() << "[BFLOAT] Inst at IRTranslator: " << Inst << "\n");
 
-        if (translate(Inst))
+        if (translate(Inst)) {
+          LLVM_DEBUG(dbgs() << "[BFLOAT] Inst at IRTranslator after translate(Inst): " << Inst << "\n");
           continue;
+        }
 
         OptimizationRemarkMissed R("gisel-irtranslator", "GISelFailure",
                                    Inst.getDebugLoc(), BB);
diff --git a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
index 6013124558106..fc1fede1a6d1c 100644
--- a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
@@ -192,9 +192,9 @@ Legalizer::legalizeMachineFunction(MachineFunction &MF, const LegalizerInfo &LI,
     if (MBB->empty())
       continue;
     for (MachineInstr &MI : *MBB) {
-      dbgs() << "Legalizer sees: " << MI << "\n";
+      LLVM_DEBUG(dbgs() << "[BFLOAT]  Legalizer sees: " << MI << "\n");
       if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
-        assert(false && "bfloat detected at the Legalizer");
+        LLVM_DEBUG(dbgs() << "[BFLOAT] bfloat detected at the Legalizer \n");
       }
       // Only legalize pre-isel generic instructions: others don't have types
       // and are assumed to be legal.
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index bbeeed76631c6..6b7efe671b572 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -61,6 +61,8 @@
 #include <cstring>
 #include <utility>
 
+#define DEBUG_TYPE "irtranslator"
+
 using namespace llvm;
 
 static cl::opt<bool>
@@ -644,7 +646,7 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
 
   if (BFloatOpnd) {
     MIFlags |= MachineInstr::MIFlag::BFloat16;
-    LLVM_DEBUG(dbgs() << "bfloat detected at the MachineInstr" << "\n");
+    LLVM_DEBUG(dbgs() << "[BFLOAT] bfloat detected at the MachineInstr \n");
   }
 
   return MIFlags;

>From 34ea97b632a5a71cec7311539c3a78b712fa542a Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 28 Aug 2025 08:57:30 -0700
Subject: [PATCH 05/20] continue debugging

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 11 +++++------
 llvm/lib/CodeGen/GlobalISel/Legalizer.cpp    |  3 ++-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index a396839c5061b..1649b8b4a41b3 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2046,9 +2046,10 @@ bool IRTranslator::translateSimpleIntrinsic(const CallInst &CI,
                                             MachineIRBuilder &MIRBuilder) {
 
   unsigned Op = getSimpleIntrinsicOpcode(ID);
+  LLVM_DEBUG(dbgs() << "[BFLOAT] translateSimpleIntrinsic is called for ID:" << ID << '\n');
 
   // Is this a simple intrinsic?
-  if (Op == Intrinsic::not_intrinsic)
+  if (Op == Intrinsic::not_intrinsic){}
     return false;
 
   // Yes. Let's translate it.
@@ -2768,8 +2769,6 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
-  // if (containsBF16Type(U))
-  //   return false;
 
   const CallInst &CI = cast<CallInst>(U);
   const Function *F = CI.getCalledFunction();
@@ -4105,8 +4104,8 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
   if (CLI->fallBackToDAGISel(*MF)) {
     OptimizationRemarkMissed R("gisel-irtranslator", "GISelFailure",
                                F.getSubprogram(), &F.getEntryBlock());
-    R << "u "
-      << ore::NV("Prototype", F.getnable to lower function:FunctionType());
+    R << "unable to lower function: "
+      << ore::NV("Prototype", F.getFunctionType());
     reportTranslationError(*MF, *TPC, *ORE, R);
     return false;
   }
@@ -4118,7 +4117,7 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
       continue; // Don't handle zero sized types.
     ArrayRef<Register> VRegs = getOrCreateVRegs(Arg);
     VRegArgs.push_back(VRegs);
-    LLVM_DEBUG(dbgs() << "[BFLOAT] Created vreg in IRTranslator: " << VRegs);
+    // LLVM_DEBUG(dbgs() << "[BFLOAT] Created vreg in IRTranslator: " << VRegs);
 
     if (Arg.hasSwiftErrorAttr()) {
       assert(VRegs.size() == 1 && "Too many vregs for Swift error");
diff --git a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
index fc1fede1a6d1c..77289d138bc9c 100644
--- a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
@@ -195,7 +195,8 @@ Legalizer::legalizeMachineFunction(MachineFunction &MF, const LegalizerInfo &LI,
       LLVM_DEBUG(dbgs() << "[BFLOAT]  Legalizer sees: " << MI << "\n");
       if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
         LLVM_DEBUG(dbgs() << "[BFLOAT] bfloat detected at the Legalizer \n");
-      }
+      } 
+
       // Only legalize pre-isel generic instructions: others don't have types
       // and are assumed to be legal.
       if (!isPreISelGenericOpcode(MI.getOpcode()))

>From eea9fd72db1e3fe7d6e83ecd9c547db75f6d144a Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 28 Aug 2025 13:45:57 -0700
Subject: [PATCH 06/20] use a new method of adding floating-point
 encoding information to OpTypeFloat

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 36 +++++++++++++++----
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  2 +-
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |  5 +--
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  7 ----
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   |  2 ++
 5 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cfe24c84941a9..cc524ee7e3cb1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -195,11 +195,12 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
-                                               MachineIRBuilder &MIRBuilder) {
+                                               MachineIRBuilder &MIRBuilder, uint32_t Fptype) {
   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
     return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
         .addDef(createTypeVReg(MIRBuilder))
-        .addImm(Width);
+        .addImm(Width)
+        .addImm(Fptype);
   });
 }
 
@@ -1041,8 +1042,11 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     return Width == 1 ? getOpTypeBool(MIRBuilder)
                       : getOpTypeInt(Width, MIRBuilder, false);
   }
-  if (Ty->isFloatingPointTy())
-    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
+  if (Ty->isFloatingPointTy()) {
+    if (Ty->isBFloatTy())
+      return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, 1);
+    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, 0);
+  }
   if (Ty->isVoidTy())
     return getOpTypeVoid(MIRBuilder);
   if (Ty->isVectorTy()) {
@@ -1677,14 +1681,33 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
     return MI;
   MachineBasicBlock &DepMBB = I.getMF()->front();
   MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
-  const MachineInstr *NewMI =
-      createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+  const MachineInstr *NewMI = nullptr;
+  if(SPIRVOPcode == SPIRV::OpTypeFloat) {
+    if (Ty->isBFloatTy()) {
+      NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+        return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
+                       MIRBuilder.getDL(), TII.get(SPIRVOPcode))
+            .addDef(createTypeVReg(CurMF->getRegInfo()))
+            .addImm(BitWidth)
+            .addImm(1);
+      });
+    } else {
+      NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
         return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
                        MIRBuilder.getDL(), TII.get(SPIRVOPcode))
             .addDef(createTypeVReg(CurMF->getRegInfo()))
             .addImm(BitWidth)
             .addImm(0);
       });
+    }
+  } else {
+    NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+        return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
+                       MIRBuilder.getDL(), TII.get(SPIRVOPcode))
+            .addDef(createTypeVReg(CurMF->getRegInfo()))
+            .addImm(BitWidth);
+      });
+  }
   add(Ty, false, NewMI);
   return finishCreatingSPIRVType(Ty, NewMI);
 }
@@ -1718,6 +1741,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
   default:
     llvm_unreachable("Bit width is of unexpected size.");
   }
+  
   return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 7ef812828b7cc..a2a2e7b289c5d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -436,7 +436,7 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
                           bool IsSigned = false);
 
-  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
+  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, uint32_t Fptype);
 
   SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 923c92c28d3f6..a823df89eab6c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -167,10 +167,7 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
 def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
 def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
                   "$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
-                  "$type = OpTypeFloat $width">;
-def OpTypeBFloat: Op<40, (outs TYPE:$type), (ins i32imm:$width),
-                  "$type = OpTypeFloat $width 0">;
+def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, i32imm:$fpencode), "$type = OpTypeFloat $width $fpencode">;
 def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
                   "$type = OpTypeVector $compType $compCount">;
 def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index ab26f0208d758..8039cf0c432fa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1262,7 +1262,6 @@ void addInstrRequirements(const MachineInstr &MI,
     break;
   }
   case SPIRV::OpTypeFloat: {
-    assert(false && "bfloat detected at the spirv module analysis");
     unsigned BitWidth = MI.getOperand(1).getImm();
     if (BitWidth == 64)
       Reqs.addCapability(SPIRV::Capability::Float64);
@@ -1270,12 +1269,6 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Float16);
     break;
   }
-  case SPIRV::OpTypeBFloat: {
-    assert(false && "bfloat detected at the spirv module analysis");
-    Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
-    Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
-    break;
-  }
   case SPIRV::OpTypeVector: {
     unsigned NumComponents = MI.getOperand(2).getImm();
     if (NumComponents == 8 || NumComponents == 16)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 1a08c6ac0dcaf..33cb8471d747e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -540,6 +540,7 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
           insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
                             MF.getRegInfo());
+        LLVM_DEBUG(dbgs() << "[BFLOAT] SPIRVPreLegalizer is lowering spv_assign_ptr_type to" << MI << '\n');
         ToErase.push_back(&MI);
       } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
         Register Reg = MI.getOperand(1).getReg();
@@ -550,6 +551,7 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
+        LLVM_DEBUG(dbgs() << "[BFLOAT] SPIRVPreLegalizer is lowering spv_assign_type to" << MI << '\n');
         ToErase.push_back(&MI);
       } else if (MIOp == TargetOpcode::FAKE_USE && MI.getNumOperands() > 0) {
         MachineInstr *MdMI = MI.getPrevNode();

>From cfaa1db3cf1d501bf200531c63e0cdce758187ee Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 28 Aug 2025 13:54:45 -0700
Subject: [PATCH 07/20] remove debug

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  | 9 +--------
 llvm/lib/CodeGen/GlobalISel/Legalizer.cpp     | 5 -----
 llvm/lib/CodeGen/MachineInstr.cpp             | 1 -
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 2 +-
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   | 2 --
 5 files changed, 2 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 1649b8b4a41b3..e29239c7706c2 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2046,10 +2046,9 @@ bool IRTranslator::translateSimpleIntrinsic(const CallInst &CI,
                                             MachineIRBuilder &MIRBuilder) {
 
   unsigned Op = getSimpleIntrinsicOpcode(ID);
-  LLVM_DEBUG(dbgs() << "[BFLOAT] translateSimpleIntrinsic is called for ID:" << ID << '\n');
 
   // Is this a simple intrinsic?
-  if (Op == Intrinsic::not_intrinsic){}
+  if (Op == Intrinsic::not_intrinsic)
     return false;
 
   // Yes. Let's translate it.
@@ -2196,8 +2195,6 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
   if (translateSimpleIntrinsic(CI, ID, MIRBuilder))
     return true;
 
-  LLVM_DEBUG(dbgs() << "[BFLOAT] IRTranslator translateKnownIntrinsic for CI: " << CI << '\n');
-  LLVM_DEBUG(dbgs() << "[BFLOAT] IRTranslator translateKnownIntrinsic for ID: " << ID << '\n');
   switch (ID) {
   default:
     break;
@@ -2879,7 +2876,6 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
   // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
   MIB->copyIRFlags(CI);
   if (containsBF16Type(U)) {
-    dbgs() << "[BFLOAT] Flagged at IRTranslator: " << *MIB.getInstr() << "\n";
     MIB.getInstr()->setFlag(MachineInstr::MIFlag::BFloat16);
   }
   
@@ -4117,7 +4113,6 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
       continue; // Don't handle zero sized types.
     ArrayRef<Register> VRegs = getOrCreateVRegs(Arg);
     VRegArgs.push_back(VRegs);
-    // LLVM_DEBUG(dbgs() << "[BFLOAT] Created vreg in IRTranslator: " << VRegs);
 
     if (Arg.hasSwiftErrorAttr()) {
       assert(VRegs.size() == 1 && "Too many vregs for Swift error");
@@ -4165,10 +4160,8 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
 
         // Translate any debug-info attached to the instruction.
         translateDbgInfo(Inst, *CurBuilder);
-        LLVM_DEBUG(dbgs() << "[BFLOAT] Inst at IRTranslator: " << Inst << "\n");
 
         if (translate(Inst)) {
-          LLVM_DEBUG(dbgs() << "[BFLOAT] Inst at IRTranslator after translate(Inst): " << Inst << "\n");
           continue;
         }
 
diff --git a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
index 77289d138bc9c..aef16b5f33af4 100644
--- a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
@@ -192,11 +192,6 @@ Legalizer::legalizeMachineFunction(MachineFunction &MF, const LegalizerInfo &LI,
     if (MBB->empty())
       continue;
     for (MachineInstr &MI : *MBB) {
-      LLVM_DEBUG(dbgs() << "[BFLOAT]  Legalizer sees: " << MI << "\n");
-      if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
-        LLVM_DEBUG(dbgs() << "[BFLOAT] bfloat detected at the Legalizer \n");
-      } 
-
       // Only legalize pre-isel generic instructions: others don't have types
       // and are assumed to be legal.
       if (!isPreISelGenericOpcode(MI.getOpcode()))
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 6b7efe671b572..159b27dd46b53 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -646,7 +646,6 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
 
   if (BFloatOpnd) {
     MIFlags |= MachineInstr::MIFlag::BFloat16;
-    LLVM_DEBUG(dbgs() << "[BFLOAT] bfloat detected at the MachineInstr \n");
   }
 
   return MIFlags;
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cc524ee7e3cb1..0249444ea2328 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1741,7 +1741,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
   default:
     llvm_unreachable("Bit width is of unexpected size.");
   }
-  
+
   return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 33cb8471d747e..1a08c6ac0dcaf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -540,7 +540,6 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
           insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
                             MF.getRegInfo());
-        LLVM_DEBUG(dbgs() << "[BFLOAT] SPIRVPreLegalizer is lowering spv_assign_ptr_type to" << MI << '\n');
         ToErase.push_back(&MI);
       } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
         Register Reg = MI.getOperand(1).getReg();
@@ -551,7 +550,6 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
-        LLVM_DEBUG(dbgs() << "[BFLOAT] SPIRVPreLegalizer is lowering spv_assign_type to" << MI << '\n');
         ToErase.push_back(&MI);
       } else if (MIOp == TargetOpcode::FAKE_USE && MI.getNumOperands() > 0) {
         MachineInstr *MdMI = MI.getPrevNode();

>From 75937a1a3b2c41ca75a33c3cdab9cc4c8a9c305a Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 29 Aug 2025 09:04:45 -0700
Subject: [PATCH 08/20] add tests

---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 21 ++++++++++++++++--
 .../extensions/SPV_KHR_bfloat16/bfloat16.ll   | 20 +++++++++++++++++
 .../SPV_KHR_bfloat16/bfloat16_dot.ll          | 22 +++++++++++++++++++
 3 files changed, 61 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 8039cf0c432fa..239221d3b3abd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1261,12 +1261,29 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Int8);
     break;
   }
+  case SPIRV::OpDot: {
+    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+    LLVM_DEBUG(dbgs() << "[BFLOAT] The OpDot MI is " << MI << '\n');
+    LLVM_DEBUG(dbgs() << "[BFLOAT] TypeDef found is " << TypeDef << '\n');
+    if ((TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
+        (TypeDef->getOperand(1).getImm() == 16)) {
+          Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
+        }
+  }
   case SPIRV::OpTypeFloat: {
     unsigned BitWidth = MI.getOperand(1).getImm();
+    unsigned Fptype = MI.getOperand(2).getImm();
     if (BitWidth == 64)
       Reqs.addCapability(SPIRV::Capability::Float64);
-    else if (BitWidth == 16)
-      Reqs.addCapability(SPIRV::Capability::Float16);
+    else if (BitWidth == 16) {
+      if(Fptype == 1) {
+        Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+        Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+      } else {
+        Reqs.addCapability(SPIRV::Capability::Float16);
+      }
+    }
     break;
   }
   case SPIRV::OpTypeVector: {
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
new file mode 100644
index 0000000000000..8a917dc219fa0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -0,0 +1,20 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+; XFAIL: *
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-ERROR
+
+; CHECK-ERROR: BFloat16TypeKHR requires the following SPIR-V extension: SPV_KHR_subgroup_rotate
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 1
+; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
+
+define spir_kernel void @test() {
+entry:
+  %addr1 = alloca bfloat
+  %addr2 = alloca <2 x bfloat>
+  %data1 = load bfloat, ptr %addr1
+  %data2 = load <2 x bfloat>, ptr %addr2
+  ret void
+}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
new file mode 100644
index 0000000000000..337eda538c196
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -0,0 +1,22 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+; XFAIL: *
+
+; CHECK-SPIRV-DAG: OpCapability BFloat16TypeKHR
+; CHECK-SPIRV-DAG: OpCapability BFloat16DotProductKHR
+; CHECK-SPIRV-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-SPIRV: %[[#BFLOAT:]] = OpTypeFloat 16 0
+; CHECK-SPIRV: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
+; CHECK-SPIRV: OpDot
+
+declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>)
+
+define spir_kernel void @test() {
+entry:
+  %addrA = alloca <2 x bfloat>
+  %addrB = alloca <2 x bfloat>
+  %dataA = load <2 x bfloat>, ptr %addrA
+  %dataB = load <2 x bfloat>, ptr %addrB
+  %call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
+  ret void
+}

>From cd7bbfbd51f95939c8cfdb6fa911a2c19724ddca Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 29 Aug 2025 13:44:06 -0700
Subject: [PATCH 09/20] remove adding bfloat flag to MIR

---
 llvm/include/llvm/CodeGen/MachineInstr.h            |  3 +--
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp        | 12 ++----------
 llvm/lib/CodeGen/MachineInstr.cpp                   | 13 -------------
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp       |  3 +--
 llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp |  6 ------
 .../SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll   |  2 +-
 6 files changed, 5 insertions(+), 34 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MachineInstr.h b/llvm/include/llvm/CodeGen/MachineInstr.h
index 6f692ae32510b..10a9b1ff1411d 100644
--- a/llvm/include/llvm/CodeGen/MachineInstr.h
+++ b/llvm/include/llvm/CodeGen/MachineInstr.h
@@ -123,9 +123,8 @@ class MachineInstr
     NoUSWrap = 1 << 20,      // Instruction supports geps
                              // no unsigned signed wrap.
     SameSign = 1 << 21,      // Both operands have the same sign.
-    InBounds = 1 << 22,       // Pointer arithmetic remains inbounds.
+    InBounds = 1 << 22       // Pointer arithmetic remains inbounds.
                              // Implies NoUSWrap.
-    BFloat16 = 1 << 23      // Instruction with bf16 type
   };
 
 private:
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index e29239c7706c2..d86d3551b2d81 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2581,10 +2581,9 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
     Value *FpValue = CI.getOperand(0);
     ConstantInt *TestMaskValue = cast<ConstantInt>(CI.getOperand(1));
 
-    uint32_t Flags = MachineInstr::copyFlagsFromInstruction(CI);
     MIRBuilder
         .buildInstr(TargetOpcode::G_IS_FPCLASS, {getOrCreateVReg(CI)},
-                    {getOrCreateVReg(*FpValue)}, Flags)
+                    {getOrCreateVReg(*FpValue)})
         .addImm(TestMaskValue->getZExtValue());
 
     return true;
@@ -2873,12 +2872,6 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
     }
   }
 
-  // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
-  MIB->copyIRFlags(CI);
-  if (containsBF16Type(U)) {
-    MIB.getInstr()->setFlag(MachineInstr::MIFlag::BFloat16);
-  }
-  
   return true;
 }
 
@@ -4161,9 +4154,8 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
         // Translate any debug-info attached to the instruction.
         translateDbgInfo(Inst, *CurBuilder);
 
-        if (translate(Inst)) {
+        if (translate(Inst)) 
           continue;
-        }
 
         OptimizationRemarkMissed R("gisel-irtranslator", "GISelFailure",
                                    Inst.getDebugLoc(), BB);
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 159b27dd46b53..29b30cf914208 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -61,8 +61,6 @@
 #include <cstring>
 #include <utility>
 
-#define DEBUG_TYPE "irtranslator"
-
 using namespace llvm;
 
 static cl::opt<bool>
@@ -637,17 +635,6 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
   auto *Sel = dyn_cast<SelectInst>(&I);
   bool BFloatOpnd = !Sel && I.getType()->getScalarType()->isBFloatTy();
 
-  if (isa<CallInst>(&I)) {
-    for (const Value *Op : I.operands()) {
-      Type *OpTy = Op->getType();
-      BFloatOpnd |= OpTy->getScalarType()->isBFloatTy();
-    }
-  }
-
-  if (BFloatOpnd) {
-    MIFlags |= MachineInstr::MIFlag::BFloat16;
-  }
-
   return MIFlags;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 239221d3b3abd..44024bde31067 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1264,12 +1264,11 @@ void addInstrRequirements(const MachineInstr &MI,
   case SPIRV::OpDot: {
     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
-    LLVM_DEBUG(dbgs() << "[BFLOAT] The OpDot MI is " << MI << '\n');
-    LLVM_DEBUG(dbgs() << "[BFLOAT] TypeDef found is " << TypeDef << '\n');
     if ((TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
         (TypeDef->getOperand(1).getImm() == 16)) {
           Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
         }
+    break;
   }
   case SPIRV::OpTypeFloat: {
     unsigned BitWidth = MI.getOperand(1).getImm();
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
index a964db0b02746..83567519355a8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
@@ -156,12 +156,6 @@ SPIRVPreLegalizerCombinerImpl::SPIRVPreLegalizerCombinerImpl(
 }
 
 bool SPIRVPreLegalizerCombinerImpl::tryCombineAll(MachineInstr &MI) const {
-  // Inst combiner is disabled for instructions
-  // with bfloat type.
-  // TODO: add bfloat-specific inst combiner rules
-  if (MI.getFlag(MachineInstr::BFloat16))
-    return false;
-
   return tryCombineAllImpl(MI);
 }
 
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
index 8a917dc219fa0..95eb5a8348d67 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -17,4 +17,4 @@ entry:
   %data1 = load bfloat, ptr %addr1
   %data2 = load <2 x bfloat>, ptr %addr2
   ret void
-}
\ No newline at end of file
+}

>From 7b5f9642662abb530481d32eae1ddfb4d5bdb8f1 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 29 Aug 2025 13:48:31 -0700
Subject: [PATCH 10/20] bail out of translateCall on bf16; drop dead bfloat code

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  | 4 +++-
 llvm/lib/CodeGen/MachineInstr.cpp             | 3 ---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 1 -
 3 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index d86d3551b2d81..541269ab6bfce 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,6 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
+  if (containsBF16Type(U))
+    return false;
 
   const CallInst &CI = cast<CallInst>(U);
   const Function *F = CI.getCalledFunction();
@@ -4154,7 +4156,7 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
         // Translate any debug-info attached to the instruction.
         translateDbgInfo(Inst, *CurBuilder);
 
-        if (translate(Inst)) 
+        if (translate(Inst))
           continue;
 
         OptimizationRemarkMissed R("gisel-irtranslator", "GISelFailure",
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 29b30cf914208..79047f732808a 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,9 +632,6 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
   if (I.getMetadata(LLVMContext::MD_unpredictable))
     MIFlags |= MachineInstr::MIFlag::Unpredictable;
 
-  auto *Sel = dyn_cast<SelectInst>(&I);
-  bool BFloatOpnd = !Sel && I.getType()->getScalarType()->isBFloatTy();
-
   return MIFlags;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 0249444ea2328..e021920f46982 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1741,7 +1741,6 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
   default:
     llvm_unreachable("Bit width is of unexpected size.");
   }
-
   return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
 }
 

>From 05246aaa6558f1561cfa4419f1e70f3028a66a9b Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Sun, 31 Aug 2025 10:22:36 -0700
Subject: [PATCH 11/20] change approach: add a map in the global registry
 class mapping each vreg to an is-bfloat flag

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  |  4 ++--
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 19 ++++++++++++++++++-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  6 ++++++
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 18 ++++++++++--------
 4 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 541269ab6bfce..04b3d6f30e023 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
-    return false;
+  // if (containsBF16Type(U))
+  //   return false;
 
   const CallInst &CI = cast<CallInst>(U);
   const Function *F = CI.getCalledFunction();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index e021920f46982..1a1624366bbf2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1125,9 +1125,15 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   TypesInProcessing.insert(Ty);
   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
                                          ExplicitLayoutRequired, EmitIR);
-  TypesInProcessing.erase(Ty);
+                                         
+  if(Ty->isFloatingPointTy() && Ty->isBFloatTy()){
+    VRegIsBFloatMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = true;
+  }
+
   VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
 
+  TypesInProcessing.erase(Ty);
+
   // TODO: We could end up with two SPIR-V types pointing to the same llvm type.
   // Is that a problem?
   SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
@@ -1162,6 +1168,17 @@ SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
   return nullptr;
 }
 
+bool SPIRVGlobalRegistry::getIsBFloatForVReg(Register VReg,
+                                         const MachineFunction *MF) {
+  auto t = VRegToTypeMap.find(MF ? MF : CurMF);
+  if (t != VRegToTypeMap.end()) {
+    auto tt = t->second.find(VReg);
+    if (tt != t->second.end())
+      return tt->second;
+  }
+  return false;
+}
+
 SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg,
                                               MachineFunction *MF) {
   if (!MF)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index a2a2e7b289c5d..2629cd515cca6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -37,6 +37,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // type-declaring ones).
   DenseMap<const MachineFunction *, DenseMap<Register, SPIRVType *>>
       VRegToTypeMap;
+  
+  DenseMap<const MachineFunction *, DenseMap<Register, bool>>
+      VRegIsBFloatMap;
 
   DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
 
@@ -340,6 +343,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // functions.
   SPIRVType *getSPIRVTypeForVReg(Register VReg,
                                  const MachineFunction *MF = nullptr) const;
+  
+  bool getIsBFloatForVReg(Register VReg,
+                                 const MachineFunction *MF = nullptr);
 
   // Return the result type of the instruction defining the register.
   SPIRVType *getResultType(Register VReg, MachineFunction *MF = nullptr);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 44024bde31067..38ffc51ad63be 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1262,21 +1262,23 @@ void addInstrRequirements(const MachineInstr &MI,
     break;
   }
   case SPIRV::OpDot: {
-    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
-    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
-    if ((TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
-        (TypeDef->getOperand(1).getImm() == 16)) {
-          Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
-        }
+    const MachineFunction *MF = MI.getMF();
+    SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+    bool isBfloat = GR->getIsBFloatForVReg(MI.getOperand(1).getReg(), MF);
+    if (isBfloat) {
+      Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
+    }
     break;
   }
   case SPIRV::OpTypeFloat: {
     unsigned BitWidth = MI.getOperand(1).getImm();
-    unsigned Fptype = MI.getOperand(2).getImm();
+    SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+    const MachineFunction *MF = MI.getMF();
+    bool isBfloat = GR->getIsBFloatForVReg(MI.getOperand(0).getReg(), MF);
     if (BitWidth == 64)
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16) {
-      if(Fptype == 1) {
+      if(isBfloat) {
         Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
         Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
       } else {

>From 7743453bd83978d620d6ea4f72890f98d0abbfd7 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Sun, 31 Aug 2025 15:37:52 -0700
Subject: [PATCH 12/20] replace the is-bfloat vreg map with an FPVariant map

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  |  3 --
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 50 +++++++++----------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   | 24 ++++++---
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |  3 +-
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 12 ++---
 5 files changed, 50 insertions(+), 42 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 04b3d6f30e023..8a4b1afec3d26 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,9 +2765,6 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
-  // if (containsBF16Type(U))
-  //   return false;
-
   const CallInst &CI = cast<CallInst>(U);
   const Function *F = CI.getCalledFunction();
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 1a1624366bbf2..633ef770ec202 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -195,12 +195,11 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
-                                               MachineIRBuilder &MIRBuilder, uint32_t Fptype) {
+                                               MachineIRBuilder &MIRBuilder) {
   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
     return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
         .addDef(createTypeVReg(MIRBuilder))
-        .addImm(Width)
-        .addImm(Fptype);
+        .addImm(Width);
   });
 }
 
@@ -1042,11 +1041,8 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     return Width == 1 ? getOpTypeBool(MIRBuilder)
                       : getOpTypeInt(Width, MIRBuilder, false);
   }
-  if (Ty->isFloatingPointTy()) {
-    if (Ty->isBFloatTy())
-      return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, 1);
-    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, 0);
-  }
+  if (Ty->isFloatingPointTy())
+    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
   if (Ty->isVoidTy())
     return getOpTypeVoid(MIRBuilder);
   if (Ty->isVectorTy()) {
@@ -1125,14 +1121,18 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   TypesInProcessing.insert(Ty);
   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
                                          ExplicitLayoutRequired, EmitIR);
-                                         
-  if(Ty->isFloatingPointTy() && Ty->isBFloatTy()){
-    VRegIsBFloatMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = true;
-  }
+  TypesInProcessing.erase(Ty);
 
-  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
+  MachineFunction *MF = &MIRBuilder.getMF();
+  Register TypeReg = getSPIRVTypeID(SpirvType);
 
-  TypesInProcessing.erase(Ty);
+  if (Ty->isFloatingPointTy()) {
+    if (Ty->isBFloatTy()) {
+        VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
+    } else {
+        VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
+    }
+  }
 
   // TODO: We could end up with two SPIR-V types pointing to the same llvm type.
   // Is that a problem?
@@ -1168,17 +1168,6 @@ SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
   return nullptr;
 }
 
-bool SPIRVGlobalRegistry::getIsBFloatForVReg(Register VReg,
-                                         const MachineFunction *MF) {
-  auto t = VRegToTypeMap.find(MF ? MF : CurMF);
-  if (t != VRegToTypeMap.end()) {
-    auto tt = t->second.find(VReg);
-    if (tt != t->second.end())
-      return tt->second;
-  }
-  return false;
-}
-
 SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg,
                                               MachineFunction *MF) {
   if (!MF)
@@ -2128,3 +2117,14 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
   }
   return false;
 }
+
+SPIRVGlobalRegistry::FPVariant SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
+                                         const MachineFunction *MF) {
+  auto t = VRegFPVariantMap.find(MF ? MF : CurMF);
+  if (t != VRegFPVariantMap.end()) {
+    auto tt = t->second.find(VReg);
+    if (tt != t->second.end())
+      return tt->second;
+  }
+  return FPVariant::NONE;
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 2629cd515cca6..8f9a5d3cfc68a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -29,6 +29,14 @@ using SPIRVType = const MachineInstr;
 using StructOffsetDecorator = std::function<void(Register)>;
 
 class SPIRVGlobalRegistry : public SPIRVIRMapping {
+  public:
+  enum class FPVariant {
+    NONE,
+    IEEE_FLOAT,
+    BRAIN_FLOAT
+  };
+
+  private:
   // Registers holding values which have types associated with them.
   // Initialized upon VReg definition in IRTranslator.
   // Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
@@ -37,9 +45,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // type-declaring ones).
   DenseMap<const MachineFunction *, DenseMap<Register, SPIRVType *>>
       VRegToTypeMap;
-  
-  DenseMap<const MachineFunction *, DenseMap<Register, bool>>
-      VRegIsBFloatMap;
 
   DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
 
@@ -91,6 +96,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // map of aliasing decorations to aliasing metadata
   std::unordered_map<const MDNode *, MachineInstr *> AliasInstMDMap;
 
+  // Maps floating point Registers to their FPVariant (float type kind), given the MachineFunction.
+  DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>>
+    VRegFPVariantMap;
+
   // Add a new OpTypeXXX instruction without checking for duplicates.
   SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
                              SPIRV::AccessQualifier::AccessQualifier AQ,
@@ -343,9 +352,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // functions.
   SPIRVType *getSPIRVTypeForVReg(Register VReg,
                                  const MachineFunction *MF = nullptr) const;
-  
-  bool getIsBFloatForVReg(Register VReg,
-                                 const MachineFunction *MF = nullptr);
 
   // Return the result type of the instruction defining the register.
   SPIRVType *getResultType(Register VReg, MachineFunction *MF = nullptr);
@@ -427,6 +433,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // Informs about removal of the machine instruction and invalidates data
   // structures referring this instruction.
   void invalidateMachineInstr(MachineInstr *MI);
+  
+  // Return the FPVariant of to the given floating-point regiester. 
+  FPVariant getFPVariantForVReg(Register VReg,
+                                 const MachineFunction *MF = nullptr);
 
 private:
   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
@@ -442,7 +452,7 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
                           bool IsSigned = false);
 
-  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, uint32_t Fptype);
+  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
 
   SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index a823df89eab6c..f0b938d681dba 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -167,7 +167,8 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
 def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
 def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
                   "$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, i32imm:$fpencode), "$type = OpTypeFloat $width $fpencode">;
+def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
+                  "$type = OpTypeFloat $width">;
 def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
                   "$type = OpTypeVector $compType $compCount">;
 def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 38ffc51ad63be..24798ff6b8f0e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1264,21 +1264,21 @@ void addInstrRequirements(const MachineInstr &MI,
   case SPIRV::OpDot: {
     const MachineFunction *MF = MI.getMF();
     SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
-    bool isBfloat = GR->getIsBFloatForVReg(MI.getOperand(1).getReg(), MF);
-    if (isBfloat) {
+    SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
+    if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
       Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
     }
     break;
   }
   case SPIRV::OpTypeFloat: {
     unsigned BitWidth = MI.getOperand(1).getImm();
-    SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
-    const MachineFunction *MF = MI.getMF();
-    bool isBfloat = GR->getIsBFloatForVReg(MI.getOperand(0).getReg(), MF);
     if (BitWidth == 64)
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16) {
-      if(isBfloat) {
+      SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+      const MachineFunction *MF = MI.getMF();
+      SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
+      if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
         Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
         Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
       } else {

>From d8b9deb4a5111bf112450b3807cab410a87dc006 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Sun, 31 Aug 2025 15:47:37 -0700
Subject: [PATCH 13/20] nit change

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 25 +++----------------
 1 file changed, 4 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 633ef770ec202..61a8ba19b28e4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1134,6 +1134,8 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
     }
   }
 
+  VRegToTypeMap[MF][TypeReg] = SpirvType;
+
   // TODO: We could end up with two SPIR-V types pointing to the same llvm type.
   // Is that a problem?
   SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
@@ -1687,33 +1689,14 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
     return MI;
   MachineBasicBlock &DepMBB = I.getMF()->front();
   MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
-  const MachineInstr *NewMI = nullptr;
-  if(SPIRVOPcode == SPIRV::OpTypeFloat) {
-    if (Ty->isBFloatTy()) {
-      NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-        return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
-                       MIRBuilder.getDL(), TII.get(SPIRVOPcode))
-            .addDef(createTypeVReg(CurMF->getRegInfo()))
-            .addImm(BitWidth)
-            .addImm(1);
-      });
-    } else {
-      NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+  const MachineInstr *NewMI =
+      createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
         return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
                        MIRBuilder.getDL(), TII.get(SPIRVOPcode))
             .addDef(createTypeVReg(CurMF->getRegInfo()))
             .addImm(BitWidth)
             .addImm(0);
       });
-    }
-  } else {
-    NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-        return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
-                       MIRBuilder.getDL(), TII.get(SPIRVOPcode))
-            .addDef(createTypeVReg(CurMF->getRegInfo()))
-            .addImm(BitWidth);
-      });
-  }
   add(Ty, false, NewMI);
   return finishCreatingSPIRVType(Ty, NewMI);
 }

>From f7f16ef1c39914f5bbdde9024b7b2a355e98a82f Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Sun, 31 Aug 2025 16:03:11 -0700
Subject: [PATCH 14/20] nit change

---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 24798ff6b8f0e..653b67e464908 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1277,7 +1277,7 @@ void addInstrRequirements(const MachineInstr &MI,
     else if (BitWidth == 16) {
       SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
       const MachineFunction *MF = MI.getMF();
-      SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
+      SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(0).getReg(), MF);
       if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
         Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
         Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);

>From 07039fdbc7a0d7fce0a0e1a71baa908c1dc9c353 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Tue, 2 Sep 2025 05:19:27 -0700
Subject: [PATCH 15/20] add the test for cooperative matrix

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp       |  3 +--
 .../SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll | 13 +++++++++++++
 .../extensions/SPV_KHR_bfloat16/bfloat16_dot.ll     |  1 -
 3 files changed, 14 insertions(+), 3 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 61a8ba19b28e4..cb458be51c0c6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1123,9 +1123,9 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
                                          ExplicitLayoutRequired, EmitIR);
   TypesInProcessing.erase(Ty);
 
+  // Record the FPVariant of floating-point type registers in VRegFPVariantMap.
   MachineFunction *MF = &MIRBuilder.getMF();
   Register TypeReg = getSPIRVTypeID(SpirvType);
-
   if (Ty->isFloatingPointTy()) {
     if (Ty->isBFloatTy()) {
         VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
@@ -1133,7 +1133,6 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
         VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
     }
   }
-
   VRegToTypeMap[MF][TypeReg] = SpirvType;
 
   // TODO: We could end up with two SPIR-V types pointing to the same llvm type.
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
new file mode 100644
index 0000000000000..e9eec4fdac126
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -0,0 +1,13 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+define spir_kernel void @matr() {
+entry:
+    %addr1 = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
+    %res = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
+    %m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat 0.0)
+    store target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) %m1, ptr %addr1, align 4
+    ret void
+}
+
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
index 337eda538c196..bd7e3c71e0cc9 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -1,6 +1,5 @@
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
-; XFAIL: *
 
 ; CHECK-SPIRV-DAG: OpCapability BFloat16TypeKHR
 ; CHECK-SPIRV-DAG: OpCapability BFloat16DotProductKHR

>From 82c56b9c8d518c2d8de0d82e785a4f2ce3441206 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Tue, 2 Sep 2025 14:48:01 -0700
Subject: [PATCH 16/20] test is failing for cooperative matrix

---
 .../Target/SPIRV/SPIRVInstructionSelector.cpp    |  1 +
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp    | 16 ++++++++--------
 .../bfloat16_cooperative_matrix.ll               |  5 -----
 3 files changed, 9 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 98c7709acf938..82977319b4779 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -569,6 +569,7 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
   bool HasDefs = I.getNumDefs() > 0;
   Register ResVReg = HasDefs ? I.getOperand(0).getReg() : Register(0);
   SPIRVType *ResType = HasDefs ? GR.getSPIRVTypeForVReg(ResVReg) : nullptr;
+
   assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
          I.getOpcode() == TargetOpcode::G_IMPLICIT_DEF);
   if (spvSelect(ResVReg, ResType, I)) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 653b67e464908..360dac9364cdc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1611,15 +1611,15 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::AsmINTEL);
     }
     break;
-  case SPIRV::OpTypeCooperativeMatrixKHR:
-    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
-      report_fatal_error(
-          "OpTypeCooperativeMatrixKHR type requires the "
-          "following SPIR-V extension: SPV_KHR_cooperative_matrix",
-          false);
-    Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
-    Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
+  case SPIRV::OpTypeCooperativeMatrixKHR: {
+    const MachineFunction *MF = MI.getMF();
+    SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+    SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
+    if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+      Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
+    }
     break;
+  }
   case SPIRV::OpArithmeticFenceEXT:
     if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
       report_fatal_error("OpArithmeticFenceEXT requires the "
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
index e9eec4fdac126..fa938fbf648a9 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -4,10 +4,5 @@
 define spir_kernel void @matr() {
 entry:
     %addr1 = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
-    %res = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
-    %m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat 0.0)
-    store target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) %m1, ptr %addr1, align 4
     ret void
 }
-
-declare dso_local spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat)

>From da902e0fa0c7eb98484649f057e3ede70db68747 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 3 Sep 2025 07:53:35 -0700
Subject: [PATCH 17/20] the cooperative matrix implementation is now passing

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp  | 18 +++++++++++-------
 .../Target/SPIRV/SPIRVInstructionSelector.cpp  |  1 -
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp  |  8 ++++++++
 .../bfloat16_cooperative_matrix.ll             |  5 +++++
 4 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cb458be51c0c6..946da8fcfff24 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1689,13 +1689,17 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
   MachineBasicBlock &DepMBB = I.getMF()->front();
   MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
   const MachineInstr *NewMI =
-      createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-        return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
-                       MIRBuilder.getDL(), TII.get(SPIRVOPcode))
-            .addDef(createTypeVReg(CurMF->getRegInfo()))
-            .addImm(BitWidth)
-            .addImm(0);
-      });
+    createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+      auto MIB = BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
+                         MIRBuilder.getDL(), TII.get(SPIRVOPcode))
+                     .addDef(createTypeVReg(CurMF->getRegInfo()))
+                     .addImm(BitWidth);
+
+      if (SPIRVOPcode != SPIRV::OpTypeFloat)
+        MIB.addImm(0);
+
+      return MIB;
+    });
   add(Ty, false, NewMI);
   return finishCreatingSPIRVType(Ty, NewMI);
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 82977319b4779..98c7709acf938 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -569,7 +569,6 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
   bool HasDefs = I.getNumDefs() > 0;
   Register ResVReg = HasDefs ? I.getOperand(0).getReg() : Register(0);
   SPIRVType *ResType = HasDefs ? GR.getSPIRVTypeForVReg(ResVReg) : nullptr;
-
   assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
          I.getOpcode() == TargetOpcode::G_IMPLICIT_DEF);
   if (spvSelect(ResVReg, ResType, I)) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 360dac9364cdc..07987c67f243b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1612,11 +1612,19 @@ void addInstrRequirements(const MachineInstr &MI,
     }
     break;
   case SPIRV::OpTypeCooperativeMatrixKHR: {
+    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
+      report_fatal_error(
+          "OpTypeCooperativeMatrixKHR type requires the "
+          "following SPIR-V extension: SPV_KHR_cooperative_matrix",
+          false);
+    Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
     const MachineFunction *MF = MI.getMF();
     SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
     SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
     if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
       Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
+    } else {
+      Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
     }
     break;
   }
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
index fa938fbf648a9..a6edb9099d29d 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -4,5 +4,10 @@
 define spir_kernel void @matr() {
 entry:
     %addr1 = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
+    %res = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
+    %m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat 1.0)
+    store target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) %m1, ptr %addr1, align 4
     ret void
 }
+
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat)
\ No newline at end of file

>From a6e336489b293aef7f94661b14317eff483340bf Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 3 Sep 2025 13:10:31 -0700
Subject: [PATCH 18/20] update the tests

---
 .../SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll |  6 +-----
 .../bfloat16_cooperative_matrix.ll                | 15 +++++++++++----
 .../extensions/SPV_KHR_bfloat16/bfloat16_dot.ll   |  2 +-
 3 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
index 95eb5a8348d67..b019ce60a366f 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -1,13 +1,9 @@
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
-; XFAIL: *
-; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-ERROR
-
-; CHECK-ERROR: BFloat16TypeKHR requires the following SPIR-V extension: SPV_KHR_subgroup_rotate
 
 ; CHECK-DAG: OpCapability BFloat16TypeKHR
 ; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
-; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 1
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
 ; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
 
 define spir_kernel void @test() {
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
index a6edb9099d29d..6889063458f4e 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -1,7 +1,14 @@
-; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
 
-define spir_kernel void @matr() {
+; CHECK-SPIRV-DAG: OpCapability BFloat16TypeKHR
+; CHECK-SPIRV-DAG: OpCapability BFloat16CooperativeMatrixKHR
+; CHECK-SPIRV-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-SPIRV: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK-SPIRV: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]]  %[[#]] %[[#]] %[[#]] %[[#]]
+; CHECK-SPIRV: OpCompositeConstruct %[[#MatTy]] %[[#]]
+
+define spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) {
 entry:
     %addr1 = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
     %res = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
@@ -10,4 +17,4 @@ entry:
     ret void
 }
 
-declare dso_local spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat)
\ No newline at end of file
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
index bd7e3c71e0cc9..02390daf5533d 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -4,7 +4,7 @@
 ; CHECK-SPIRV-DAG: OpCapability BFloat16TypeKHR
 ; CHECK-SPIRV-DAG: OpCapability BFloat16DotProductKHR
 ; CHECK-SPIRV-DAG: OpExtension "SPV_KHR_bfloat16"
-; CHECK-SPIRV: %[[#BFLOAT:]] = OpTypeFloat 16 0
+; CHECK-SPIRV: %[[#BFLOAT:]] = OpTypeFloat 16
 ; CHECK-SPIRV: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
 ; CHECK-SPIRV: OpDot
 

>From c508dccafae5ca81fbde1c2d1cd0d4909deba0bd Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 3 Sep 2025 15:27:21 -0700
Subject: [PATCH 19/20] update the error in the tests

---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp        |  5 +++++
 .../SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll    |  6 ++++++
 .../SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll  | 12 ++++++------
 .../extensions/SPV_KHR_bfloat16/bfloat16_dot.ll      | 12 ++++++------
 4 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 07987c67f243b..1cc76b5e30036 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1279,6 +1279,11 @@ void addInstrRequirements(const MachineInstr &MI,
       const MachineFunction *MF = MI.getMF();
       SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(0).getReg(), MF);
       if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+        if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
+      report_fatal_error(
+          "OpTypeFloat type with bfloat requires the "
+          "following SPIR-V extension: SPV_KHR_bfloat16",
+          false);
         Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
         Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
       } else {
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
index b019ce60a366f..bfc84691f6945 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -1,11 +1,17 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
 
+; CHECK-ERROR: LLVM ERROR: OpTypeFloat type with bfloat requires the following SPIR-V extension: SPV_KHR_bfloat16
+
 ; CHECK-DAG: OpCapability BFloat16TypeKHR
 ; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
 ; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
 ; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
 
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
 define spir_kernel void @test() {
 entry:
   %addr1 = alloca bfloat
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
index 6889063458f4e..5a6e6d88ca6a0 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -1,12 +1,12 @@
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
 
-; CHECK-SPIRV-DAG: OpCapability BFloat16TypeKHR
-; CHECK-SPIRV-DAG: OpCapability BFloat16CooperativeMatrixKHR
-; CHECK-SPIRV-DAG: OpExtension "SPV_KHR_bfloat16"
-; CHECK-SPIRV: %[[#BFLOAT:]] = OpTypeFloat 16
-; CHECK-SPIRV: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]]  %[[#]] %[[#]] %[[#]] %[[#]]
-; CHECK-SPIRV: OpCompositeConstruct %[[#MatTy]] %[[#]]
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16CooperativeMatrixKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]]  %[[#]] %[[#]] %[[#]] %[[#]]
+; CHECK: OpCompositeConstruct %[[#MatTy]] %[[#]]
 
 define spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) {
 entry:
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
index 02390daf5533d..7cfe29261f2cd 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -1,12 +1,12 @@
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
 
-; CHECK-SPIRV-DAG: OpCapability BFloat16TypeKHR
-; CHECK-SPIRV-DAG: OpCapability BFloat16DotProductKHR
-; CHECK-SPIRV-DAG: OpExtension "SPV_KHR_bfloat16"
-; CHECK-SPIRV: %[[#BFLOAT:]] = OpTypeFloat 16
-; CHECK-SPIRV: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
-; CHECK-SPIRV: OpDot
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16DotProductKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
+; CHECK: OpDot
 
 declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>)
 

>From 72e060daf0ef19729ed99e92455ec8c247a33030 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 4 Sep 2025 04:53:16 -0700
Subject: [PATCH 20/20] solve the public test failure

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  | 412 +++++++++---------
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp    |   3 +-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |  28 +-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  23 +-
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  19 +-
 5 files changed, 240 insertions(+), 245 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 8a4b1afec3d26..7aeff7f2396ab 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -101,14 +101,14 @@ static cl::opt<bool>
 char IRTranslator::ID = 0;
 
 INITIALIZE_PASS_BEGIN(IRTranslator, DEBUG_TYPE, "IRTranslator LLVM IR -> MI",
-                false, false)
+                      false, false)
 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
 INITIALIZE_PASS_DEPENDENCY(GISelCSEAnalysisWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(StackProtector)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
 INITIALIZE_PASS_END(IRTranslator, DEBUG_TYPE, "IRTranslator LLVM IR -> MI",
-                false, false)
+                    false, false)
 
 static void reportTranslationError(MachineFunction &MF,
                                    const TargetPassConfig &TPC,
@@ -168,7 +168,6 @@ class DILocationVerifier : public GISelChangeObserver {
 } // namespace
 #endif // ifndef NDEBUG
 
-
 void IRTranslator::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addRequired<StackProtector>();
   AU.addRequired<TargetPassConfig>();
@@ -360,8 +359,8 @@ bool IRTranslator::translateCompare(const User &U,
   if (CmpInst::isIntPredicate(Pred))
     MIRBuilder.buildICmp(Pred, Res, Op0, Op1, Flags);
   else if (Pred == CmpInst::FCMP_FALSE)
-    MIRBuilder.buildCopy(
-        Res, getOrCreateVReg(*Constant::getNullValue(U.getType())));
+    MIRBuilder.buildCopy(Res,
+                         getOrCreateVReg(*Constant::getNullValue(U.getType())));
   else if (Pred == CmpInst::FCMP_TRUE)
     MIRBuilder.buildCopy(
         Res, getOrCreateVReg(*Constant::getAllOnesValue(U.getType())));
@@ -951,8 +950,8 @@ void IRTranslator::emitSwitchCase(SwitchCG::CaseBlock &CB,
     assert(CB.PredInfo.Pred == CmpInst::ICMP_SLE &&
            "Can only handle SLE ranges");
 
-    const APInt& Low = cast<ConstantInt>(CB.CmpLHS)->getValue();
-    const APInt& High = cast<ConstantInt>(CB.CmpRHS)->getValue();
+    const APInt &Low = cast<ConstantInt>(CB.CmpLHS)->getValue();
+    const APInt &High = cast<ConstantInt>(CB.CmpRHS)->getValue();
 
     Register CmpOpReg = getOrCreateVReg(*CB.CmpMHS);
     if (cast<ConstantInt>(CB.CmpLHS)->isMinValue(true)) {
@@ -987,16 +986,12 @@ void IRTranslator::emitSwitchCase(SwitchCG::CaseBlock &CB,
   MIB.setDebugLoc(OldDbgLoc);
 }
 
-bool IRTranslator::lowerJumpTableWorkItem(SwitchCG::SwitchWorkListItem W,
-                                          MachineBasicBlock *SwitchMBB,
-                                          MachineBasicBlock *CurMBB,
-                                          MachineBasicBlock *DefaultMBB,
-                                          MachineIRBuilder &MIB,
-                                          MachineFunction::iterator BBI,
-                                          BranchProbability UnhandledProbs,
-                                          SwitchCG::CaseClusterIt I,
-                                          MachineBasicBlock *Fallthrough,
-                                          bool FallthroughUnreachable) {
+bool IRTranslator::lowerJumpTableWorkItem(
+    SwitchCG::SwitchWorkListItem W, MachineBasicBlock *SwitchMBB,
+    MachineBasicBlock *CurMBB, MachineBasicBlock *DefaultMBB,
+    MachineIRBuilder &MIB, MachineFunction::iterator BBI,
+    BranchProbability UnhandledProbs, SwitchCG::CaseClusterIt I,
+    MachineBasicBlock *Fallthrough, bool FallthroughUnreachable) {
   using namespace SwitchCG;
   MachineFunction *CurMF = SwitchMBB->getParent();
   // FIXME: Optimize away range check based on pivot comparisons.
@@ -1058,14 +1053,11 @@ bool IRTranslator::lowerJumpTableWorkItem(SwitchCG::SwitchWorkListItem W,
   }
   return true;
 }
-bool IRTranslator::lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I,
-                                            Value *Cond,
-                                            MachineBasicBlock *Fallthrough,
-                                            bool FallthroughUnreachable,
-                                            BranchProbability UnhandledProbs,
-                                            MachineBasicBlock *CurMBB,
-                                            MachineIRBuilder &MIB,
-                                            MachineBasicBlock *SwitchMBB) {
+bool IRTranslator::lowerSwitchRangeWorkItem(
+    SwitchCG::CaseClusterIt I, Value *Cond, MachineBasicBlock *Fallthrough,
+    bool FallthroughUnreachable, BranchProbability UnhandledProbs,
+    MachineBasicBlock *CurMBB, MachineIRBuilder &MIB,
+    MachineBasicBlock *SwitchMBB) {
   using namespace SwitchCG;
   const Value *RHS, *LHS, *MHS;
   CmpInst::Predicate Pred;
@@ -1700,8 +1692,7 @@ bool IRTranslator::translateGetElementPtr(const User &U,
   }
 
   if (Offset != 0) {
-    auto OffsetMIB =
-        MIRBuilder.buildConstant(OffsetTy, Offset);
+    auto OffsetMIB = MIRBuilder.buildConstant(OffsetTy, Offset);
 
     MIRBuilder.buildPtrAdd(getOrCreateVReg(U), BaseReg, OffsetMIB.getReg(0),
                            PtrAddFlagsWithConst(Offset));
@@ -1898,145 +1889,145 @@ bool IRTranslator::translateFixedPointIntrinsic(unsigned Op, const CallInst &CI,
   Register Src0 = getOrCreateVReg(*CI.getOperand(0));
   Register Src1 = getOrCreateVReg(*CI.getOperand(1));
   uint64_t Scale = cast<ConstantInt>(CI.getOperand(2))->getZExtValue();
-  MIRBuilder.buildInstr(Op, {Dst}, { Src0, Src1, Scale });
+  MIRBuilder.buildInstr(Op, {Dst}, {Src0, Src1, Scale});
   return true;
 }
 
 unsigned IRTranslator::getSimpleIntrinsicOpcode(Intrinsic::ID ID) {
   switch (ID) {
-    default:
-      break;
-    case Intrinsic::acos:
-      return TargetOpcode::G_FACOS;
-    case Intrinsic::asin:
-      return TargetOpcode::G_FASIN;
-    case Intrinsic::atan:
-      return TargetOpcode::G_FATAN;
-    case Intrinsic::atan2:
-      return TargetOpcode::G_FATAN2;
-    case Intrinsic::bswap:
-      return TargetOpcode::G_BSWAP;
-    case Intrinsic::bitreverse:
-      return TargetOpcode::G_BITREVERSE;
-    case Intrinsic::fshl:
-      return TargetOpcode::G_FSHL;
-    case Intrinsic::fshr:
-      return TargetOpcode::G_FSHR;
-    case Intrinsic::ceil:
-      return TargetOpcode::G_FCEIL;
-    case Intrinsic::cos:
-      return TargetOpcode::G_FCOS;
-    case Intrinsic::cosh:
-      return TargetOpcode::G_FCOSH;
-    case Intrinsic::ctpop:
-      return TargetOpcode::G_CTPOP;
-    case Intrinsic::exp:
-      return TargetOpcode::G_FEXP;
-    case Intrinsic::exp2:
-      return TargetOpcode::G_FEXP2;
-    case Intrinsic::exp10:
-      return TargetOpcode::G_FEXP10;
-    case Intrinsic::fabs:
-      return TargetOpcode::G_FABS;
-    case Intrinsic::copysign:
-      return TargetOpcode::G_FCOPYSIGN;
-    case Intrinsic::minnum:
-      return TargetOpcode::G_FMINNUM;
-    case Intrinsic::maxnum:
-      return TargetOpcode::G_FMAXNUM;
-    case Intrinsic::minimum:
-      return TargetOpcode::G_FMINIMUM;
-    case Intrinsic::maximum:
-      return TargetOpcode::G_FMAXIMUM;
-    case Intrinsic::minimumnum:
-      return TargetOpcode::G_FMINIMUMNUM;
-    case Intrinsic::maximumnum:
-      return TargetOpcode::G_FMAXIMUMNUM;
-    case Intrinsic::canonicalize:
-      return TargetOpcode::G_FCANONICALIZE;
-    case Intrinsic::floor:
-      return TargetOpcode::G_FFLOOR;
-    case Intrinsic::fma:
-      return TargetOpcode::G_FMA;
-    case Intrinsic::log:
-      return TargetOpcode::G_FLOG;
-    case Intrinsic::log2:
-      return TargetOpcode::G_FLOG2;
-    case Intrinsic::log10:
-      return TargetOpcode::G_FLOG10;
-    case Intrinsic::ldexp:
-      return TargetOpcode::G_FLDEXP;
-    case Intrinsic::nearbyint:
-      return TargetOpcode::G_FNEARBYINT;
-    case Intrinsic::pow:
-      return TargetOpcode::G_FPOW;
-    case Intrinsic::powi:
-      return TargetOpcode::G_FPOWI;
-    case Intrinsic::rint:
-      return TargetOpcode::G_FRINT;
-    case Intrinsic::round:
-      return TargetOpcode::G_INTRINSIC_ROUND;
-    case Intrinsic::roundeven:
-      return TargetOpcode::G_INTRINSIC_ROUNDEVEN;
-    case Intrinsic::sin:
-      return TargetOpcode::G_FSIN;
-    case Intrinsic::sinh:
-      return TargetOpcode::G_FSINH;
-    case Intrinsic::sqrt:
-      return TargetOpcode::G_FSQRT;
-    case Intrinsic::tan:
-      return TargetOpcode::G_FTAN;
-    case Intrinsic::tanh:
-      return TargetOpcode::G_FTANH;
-    case Intrinsic::trunc:
-      return TargetOpcode::G_INTRINSIC_TRUNC;
-    case Intrinsic::readcyclecounter:
-      return TargetOpcode::G_READCYCLECOUNTER;
-    case Intrinsic::readsteadycounter:
-      return TargetOpcode::G_READSTEADYCOUNTER;
-    case Intrinsic::ptrmask:
-      return TargetOpcode::G_PTRMASK;
-    case Intrinsic::lrint:
-      return TargetOpcode::G_INTRINSIC_LRINT;
-    case Intrinsic::llrint:
-      return TargetOpcode::G_INTRINSIC_LLRINT;
-    // FADD/FMUL require checking the FMF, so are handled elsewhere.
-    case Intrinsic::vector_reduce_fmin:
-      return TargetOpcode::G_VECREDUCE_FMIN;
-    case Intrinsic::vector_reduce_fmax:
-      return TargetOpcode::G_VECREDUCE_FMAX;
-    case Intrinsic::vector_reduce_fminimum:
-      return TargetOpcode::G_VECREDUCE_FMINIMUM;
-    case Intrinsic::vector_reduce_fmaximum:
-      return TargetOpcode::G_VECREDUCE_FMAXIMUM;
-    case Intrinsic::vector_reduce_add:
-      return TargetOpcode::G_VECREDUCE_ADD;
-    case Intrinsic::vector_reduce_mul:
-      return TargetOpcode::G_VECREDUCE_MUL;
-    case Intrinsic::vector_reduce_and:
-      return TargetOpcode::G_VECREDUCE_AND;
-    case Intrinsic::vector_reduce_or:
-      return TargetOpcode::G_VECREDUCE_OR;
-    case Intrinsic::vector_reduce_xor:
-      return TargetOpcode::G_VECREDUCE_XOR;
-    case Intrinsic::vector_reduce_smax:
-      return TargetOpcode::G_VECREDUCE_SMAX;
-    case Intrinsic::vector_reduce_smin:
-      return TargetOpcode::G_VECREDUCE_SMIN;
-    case Intrinsic::vector_reduce_umax:
-      return TargetOpcode::G_VECREDUCE_UMAX;
-    case Intrinsic::vector_reduce_umin:
-      return TargetOpcode::G_VECREDUCE_UMIN;
-    case Intrinsic::experimental_vector_compress:
-      return TargetOpcode::G_VECTOR_COMPRESS;
-    case Intrinsic::lround:
-      return TargetOpcode::G_LROUND;
-    case Intrinsic::llround:
-      return TargetOpcode::G_LLROUND;
-    case Intrinsic::get_fpenv:
-      return TargetOpcode::G_GET_FPENV;
-    case Intrinsic::get_fpmode:
-      return TargetOpcode::G_GET_FPMODE;
+  default:
+    break;
+  case Intrinsic::acos:
+    return TargetOpcode::G_FACOS;
+  case Intrinsic::asin:
+    return TargetOpcode::G_FASIN;
+  case Intrinsic::atan:
+    return TargetOpcode::G_FATAN;
+  case Intrinsic::atan2:
+    return TargetOpcode::G_FATAN2;
+  case Intrinsic::bswap:
+    return TargetOpcode::G_BSWAP;
+  case Intrinsic::bitreverse:
+    return TargetOpcode::G_BITREVERSE;
+  case Intrinsic::fshl:
+    return TargetOpcode::G_FSHL;
+  case Intrinsic::fshr:
+    return TargetOpcode::G_FSHR;
+  case Intrinsic::ceil:
+    return TargetOpcode::G_FCEIL;
+  case Intrinsic::cos:
+    return TargetOpcode::G_FCOS;
+  case Intrinsic::cosh:
+    return TargetOpcode::G_FCOSH;
+  case Intrinsic::ctpop:
+    return TargetOpcode::G_CTPOP;
+  case Intrinsic::exp:
+    return TargetOpcode::G_FEXP;
+  case Intrinsic::exp2:
+    return TargetOpcode::G_FEXP2;
+  case Intrinsic::exp10:
+    return TargetOpcode::G_FEXP10;
+  case Intrinsic::fabs:
+    return TargetOpcode::G_FABS;
+  case Intrinsic::copysign:
+    return TargetOpcode::G_FCOPYSIGN;
+  case Intrinsic::minnum:
+    return TargetOpcode::G_FMINNUM;
+  case Intrinsic::maxnum:
+    return TargetOpcode::G_FMAXNUM;
+  case Intrinsic::minimum:
+    return TargetOpcode::G_FMINIMUM;
+  case Intrinsic::maximum:
+    return TargetOpcode::G_FMAXIMUM;
+  case Intrinsic::minimumnum:
+    return TargetOpcode::G_FMINIMUMNUM;
+  case Intrinsic::maximumnum:
+    return TargetOpcode::G_FMAXIMUMNUM;
+  case Intrinsic::canonicalize:
+    return TargetOpcode::G_FCANONICALIZE;
+  case Intrinsic::floor:
+    return TargetOpcode::G_FFLOOR;
+  case Intrinsic::fma:
+    return TargetOpcode::G_FMA;
+  case Intrinsic::log:
+    return TargetOpcode::G_FLOG;
+  case Intrinsic::log2:
+    return TargetOpcode::G_FLOG2;
+  case Intrinsic::log10:
+    return TargetOpcode::G_FLOG10;
+  case Intrinsic::ldexp:
+    return TargetOpcode::G_FLDEXP;
+  case Intrinsic::nearbyint:
+    return TargetOpcode::G_FNEARBYINT;
+  case Intrinsic::pow:
+    return TargetOpcode::G_FPOW;
+  case Intrinsic::powi:
+    return TargetOpcode::G_FPOWI;
+  case Intrinsic::rint:
+    return TargetOpcode::G_FRINT;
+  case Intrinsic::round:
+    return TargetOpcode::G_INTRINSIC_ROUND;
+  case Intrinsic::roundeven:
+    return TargetOpcode::G_INTRINSIC_ROUNDEVEN;
+  case Intrinsic::sin:
+    return TargetOpcode::G_FSIN;
+  case Intrinsic::sinh:
+    return TargetOpcode::G_FSINH;
+  case Intrinsic::sqrt:
+    return TargetOpcode::G_FSQRT;
+  case Intrinsic::tan:
+    return TargetOpcode::G_FTAN;
+  case Intrinsic::tanh:
+    return TargetOpcode::G_FTANH;
+  case Intrinsic::trunc:
+    return TargetOpcode::G_INTRINSIC_TRUNC;
+  case Intrinsic::readcyclecounter:
+    return TargetOpcode::G_READCYCLECOUNTER;
+  case Intrinsic::readsteadycounter:
+    return TargetOpcode::G_READSTEADYCOUNTER;
+  case Intrinsic::ptrmask:
+    return TargetOpcode::G_PTRMASK;
+  case Intrinsic::lrint:
+    return TargetOpcode::G_INTRINSIC_LRINT;
+  case Intrinsic::llrint:
+    return TargetOpcode::G_INTRINSIC_LLRINT;
+  // FADD/FMUL require checking the FMF, so are handled elsewhere.
+  case Intrinsic::vector_reduce_fmin:
+    return TargetOpcode::G_VECREDUCE_FMIN;
+  case Intrinsic::vector_reduce_fmax:
+    return TargetOpcode::G_VECREDUCE_FMAX;
+  case Intrinsic::vector_reduce_fminimum:
+    return TargetOpcode::G_VECREDUCE_FMINIMUM;
+  case Intrinsic::vector_reduce_fmaximum:
+    return TargetOpcode::G_VECREDUCE_FMAXIMUM;
+  case Intrinsic::vector_reduce_add:
+    return TargetOpcode::G_VECREDUCE_ADD;
+  case Intrinsic::vector_reduce_mul:
+    return TargetOpcode::G_VECREDUCE_MUL;
+  case Intrinsic::vector_reduce_and:
+    return TargetOpcode::G_VECREDUCE_AND;
+  case Intrinsic::vector_reduce_or:
+    return TargetOpcode::G_VECREDUCE_OR;
+  case Intrinsic::vector_reduce_xor:
+    return TargetOpcode::G_VECREDUCE_XOR;
+  case Intrinsic::vector_reduce_smax:
+    return TargetOpcode::G_VECREDUCE_SMAX;
+  case Intrinsic::vector_reduce_smin:
+    return TargetOpcode::G_VECREDUCE_SMIN;
+  case Intrinsic::vector_reduce_umax:
+    return TargetOpcode::G_VECREDUCE_UMAX;
+  case Intrinsic::vector_reduce_umin:
+    return TargetOpcode::G_VECREDUCE_UMIN;
+  case Intrinsic::experimental_vector_compress:
+    return TargetOpcode::G_VECTOR_COMPRESS;
+  case Intrinsic::lround:
+    return TargetOpcode::G_LROUND;
+  case Intrinsic::llround:
+    return TargetOpcode::G_LLROUND;
+  case Intrinsic::get_fpenv:
+    return TargetOpcode::G_GET_FPENV;
+  case Intrinsic::get_fpmode:
+    return TargetOpcode::G_GET_FPMODE;
   }
   return Intrinsic::not_intrinsic;
 }
@@ -2086,7 +2077,7 @@ static unsigned getConstrainedOpcode(Intrinsic::ID ID) {
 }
 
 bool IRTranslator::translateConstrainedFPIntrinsic(
-  const ConstrainedFPIntrinsic &FPI, MachineIRBuilder &MIRBuilder) {
+    const ConstrainedFPIntrinsic &FPI, MachineIRBuilder &MIRBuilder) {
   fp::ExceptionBehavior EB = *FPI.getExceptionBehavior();
 
   unsigned Opcode = getConstrainedOpcode(FPI.getIntrinsicID());
@@ -2226,17 +2217,18 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
   case Intrinsic::dbg_declare: {
     const DbgDeclareInst &DI = cast<DbgDeclareInst>(CI);
     assert(DI.getVariable() && "Missing variable");
-    translateDbgDeclareRecord(DI.getAddress(), DI.hasArgList(), DI.getVariable(),
-                       DI.getExpression(), DI.getDebugLoc(), MIRBuilder);
+    translateDbgDeclareRecord(DI.getAddress(), DI.hasArgList(),
+                              DI.getVariable(), DI.getExpression(),
+                              DI.getDebugLoc(), MIRBuilder);
     return true;
   }
   case Intrinsic::dbg_label: {
     const DbgLabelInst &DI = cast<DbgLabelInst>(CI);
     assert(DI.getLabel() && "Missing label");
 
-    assert(DI.getLabel()->isValidLocationForIntrinsic(
-               MIRBuilder.getDebugLoc()) &&
-           "Expected inlined-at fields to agree");
+    assert(
+        DI.getLabel()->isValidLocationForIntrinsic(MIRBuilder.getDebugLoc()) &&
+        "Expected inlined-at fields to agree");
 
     MIRBuilder.buildDbgLabel(DI.getLabel());
     return true;
@@ -2267,7 +2259,7 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
     // This form of DBG_VALUE is target-independent.
     const DbgValueInst &DI = cast<DbgValueInst>(CI);
     translateDbgValueRecord(DI.getValue(), DI.hasArgList(), DI.getVariable(),
-                       DI.getExpression(), DI.getDebugLoc(), MIRBuilder);
+                            DI.getExpression(), DI.getDebugLoc(), MIRBuilder);
     return true;
   }
   case Intrinsic::uadd_with_overflow:
@@ -2306,21 +2298,29 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
     // TODO: Preserve "int min is poison" arg in GMIR?
     return translateUnaryOp(TargetOpcode::G_ABS, CI, MIRBuilder);
   case Intrinsic::smul_fix:
-    return translateFixedPointIntrinsic(TargetOpcode::G_SMULFIX, CI, MIRBuilder);
+    return translateFixedPointIntrinsic(TargetOpcode::G_SMULFIX, CI,
+                                        MIRBuilder);
   case Intrinsic::umul_fix:
-    return translateFixedPointIntrinsic(TargetOpcode::G_UMULFIX, CI, MIRBuilder);
+    return translateFixedPointIntrinsic(TargetOpcode::G_UMULFIX, CI,
+                                        MIRBuilder);
   case Intrinsic::smul_fix_sat:
-    return translateFixedPointIntrinsic(TargetOpcode::G_SMULFIXSAT, CI, MIRBuilder);
+    return translateFixedPointIntrinsic(TargetOpcode::G_SMULFIXSAT, CI,
+                                        MIRBuilder);
   case Intrinsic::umul_fix_sat:
-    return translateFixedPointIntrinsic(TargetOpcode::G_UMULFIXSAT, CI, MIRBuilder);
+    return translateFixedPointIntrinsic(TargetOpcode::G_UMULFIXSAT, CI,
+                                        MIRBuilder);
   case Intrinsic::sdiv_fix:
-    return translateFixedPointIntrinsic(TargetOpcode::G_SDIVFIX, CI, MIRBuilder);
+    return translateFixedPointIntrinsic(TargetOpcode::G_SDIVFIX, CI,
+                                        MIRBuilder);
   case Intrinsic::udiv_fix:
-    return translateFixedPointIntrinsic(TargetOpcode::G_UDIVFIX, CI, MIRBuilder);
+    return translateFixedPointIntrinsic(TargetOpcode::G_UDIVFIX, CI,
+                                        MIRBuilder);
   case Intrinsic::sdiv_fix_sat:
-    return translateFixedPointIntrinsic(TargetOpcode::G_SDIVFIXSAT, CI, MIRBuilder);
+    return translateFixedPointIntrinsic(TargetOpcode::G_SDIVFIXSAT, CI,
+                                        MIRBuilder);
   case Intrinsic::udiv_fix_sat:
-    return translateFixedPointIntrinsic(TargetOpcode::G_UDIVFIXSAT, CI, MIRBuilder);
+    return translateFixedPointIntrinsic(TargetOpcode::G_UDIVFIXSAT, CI,
+                                        MIRBuilder);
   case Intrinsic::fmuladd: {
     const TargetMachine &TM = MF->getTarget();
     Register Dst = getOrCreateVReg(CI);
@@ -2435,11 +2435,11 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
   case Intrinsic::ctlz: {
     ConstantInt *Cst = cast<ConstantInt>(CI.getArgOperand(1));
     bool isTrailing = ID == Intrinsic::cttz;
-    unsigned Opcode = isTrailing
-                          ? Cst->isZero() ? TargetOpcode::G_CTTZ
-                                          : TargetOpcode::G_CTTZ_ZERO_UNDEF
-                          : Cst->isZero() ? TargetOpcode::G_CTLZ
-                                          : TargetOpcode::G_CTLZ_ZERO_UNDEF;
+    unsigned Opcode = isTrailing      ? Cst->isZero()
+                                            ? TargetOpcode::G_CTTZ
+                                            : TargetOpcode::G_CTTZ_ZERO_UNDEF
+                      : Cst->isZero() ? TargetOpcode::G_CTLZ
+                                      : TargetOpcode::G_CTLZ_ZERO_UNDEF;
     MIRBuilder.buildInstr(Opcode, {getOrCreateVReg(CI)},
                           {getOrCreateVReg(*CI.getArgOperand(0))});
     return true;
@@ -2478,8 +2478,8 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
   case Intrinsic::write_register: {
     Value *Arg = CI.getArgOperand(0);
     MIRBuilder.buildInstr(TargetOpcode::G_WRITE_REGISTER)
-      .addMetadata(cast<MDNode>(cast<MetadataAsValue>(Arg)->getMetadata()))
-      .addUse(getOrCreateVReg(*CI.getArgOperand(1)));
+        .addMetadata(cast<MDNode>(cast<MetadataAsValue>(Arg)->getMetadata()))
+        .addUse(getOrCreateVReg(*CI.getArgOperand(1)));
     return true;
   }
   case Intrinsic::localescape: {
@@ -2659,7 +2659,7 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
     return translateVectorDeinterleave2Intrinsic(CI, MIRBuilder);
   }
 
-#define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC)  \
+#define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC)                         \
   case Intrinsic::INTRINSIC:
 #include "llvm/IR/ConstrainedOps.def"
     return translateConstrainedFPIntrinsic(cast<ConstrainedFPIntrinsic>(CI),
@@ -2875,8 +2875,7 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
 }
 
 bool IRTranslator::findUnwindDestinations(
-    const BasicBlock *EHPadBB,
-    BranchProbability Prob,
+    const BasicBlock *EHPadBB, BranchProbability Prob,
     SmallVectorImpl<std::pair<MachineBasicBlock *, BranchProbability>>
         &UnwindDests) {
   EHPersonality Personality = classifyEHPersonality(
@@ -2959,8 +2958,8 @@ bool IRTranslator::translateInvoke(const User &U,
   // FIXME: support Windows dllimport function calls and calls through
   // weak symbols.
   if (Fn && (Fn->hasDLLImportStorageClass() ||
-            (MF->getTarget().getTargetTriple().isOSWindows() &&
-             Fn->hasExternalWeakLinkage())))
+             (MF->getTarget().getTargetTriple().isOSWindows() &&
+              Fn->hasExternalWeakLinkage())))
     return false;
 
   bool LowerInlineAsm = I.isInlineAsm();
@@ -3047,8 +3046,7 @@ bool IRTranslator::translateLandingPad(const User &U,
 
   // Add a label to mark the beginning of the landing pad.  Deletion of the
   // landing pad can thus be detected via the MachineModuleInfo.
-  MIRBuilder.buildInstr(TargetOpcode::EH_LABEL)
-    .addSym(MF->addLandingPad(&MBB));
+  MIRBuilder.buildInstr(TargetOpcode::EH_LABEL).addSym(MF->addLandingPad(&MBB));
 
   // If the unwinder does not preserve all registers, ensure that the
   // function marks the clobbered registers as used.
@@ -3476,8 +3474,7 @@ bool IRTranslator::translateAtomicRMW(const User &U,
   return true;
 }
 
-bool IRTranslator::translateFence(const User &U,
-                                  MachineIRBuilder &MIRBuilder) {
+bool IRTranslator::translateFence(const User &U, MachineIRBuilder &MIRBuilder) {
   const FenceInst &Fence = cast<FenceInst>(U);
   MIRBuilder.buildFence(static_cast<unsigned>(Fence.getOrdering()),
                         Fence.getSyncScopeID());
@@ -3535,10 +3532,10 @@ void IRTranslator::finishPendingPhis() {
 }
 
 void IRTranslator::translateDbgValueRecord(Value *V, bool HasArgList,
-                                     const DILocalVariable *Variable,
-                                     const DIExpression *Expression,
-                                     const DebugLoc &DL,
-                                     MachineIRBuilder &MIRBuilder) {
+                                           const DILocalVariable *Variable,
+                                           const DIExpression *Expression,
+                                           const DebugLoc &DL,
+                                           MachineIRBuilder &MIRBuilder) {
   assert(Variable->isValidLocationForIntrinsic(DL) &&
          "Expected inlined-at fields to agree");
   // Act as if we're handling a debug intrinsic.
@@ -3581,10 +3578,10 @@ void IRTranslator::translateDbgValueRecord(Value *V, bool HasArgList,
 }
 
 void IRTranslator::translateDbgDeclareRecord(Value *Address, bool HasArgList,
-                                     const DILocalVariable *Variable,
-                                     const DIExpression *Expression,
-                                     const DebugLoc &DL,
-                                     MachineIRBuilder &MIRBuilder) {
+                                             const DILocalVariable *Variable,
+                                             const DIExpression *Expression,
+                                             const DebugLoc &DL,
+                                             MachineIRBuilder &MIRBuilder) {
   if (!Address || isa<UndefValue>(Address)) {
     LLVM_DEBUG(dbgs() << "Dropping debug info for " << *Variable << "\n");
     return;
@@ -3596,13 +3593,12 @@ void IRTranslator::translateDbgDeclareRecord(Value *Address, bool HasArgList,
   if (AI && AI->isStaticAlloca()) {
     // Static allocas are tracked at the MF level, no need for DBG_VALUE
     // instructions (in fact, they get ignored if they *do* exist).
-    MF->setVariableDbgInfo(Variable, Expression,
-                           getOrCreateFrameIndex(*AI), DL);
+    MF->setVariableDbgInfo(Variable, Expression, getOrCreateFrameIndex(*AI),
+                           DL);
     return;
   }
 
-  if (translateIfEntryValueArgument(true, Address, Variable,
-                                    Expression, DL,
+  if (translateIfEntryValueArgument(true, Address, Variable, Expression, DL,
                                     MIRBuilder))
     return;
 
@@ -3614,7 +3610,7 @@ void IRTranslator::translateDbgDeclareRecord(Value *Address, bool HasArgList,
 }
 
 void IRTranslator::translateDbgInfo(const Instruction &Inst,
-                                      MachineIRBuilder &MIRBuilder) {
+                                    MachineIRBuilder &MIRBuilder) {
   for (DbgRecord &DR : Inst.getDbgRecordRange()) {
     if (DbgLabelRecord *DLR = dyn_cast<DbgLabelRecord>(&DR)) {
       MIRBuilder.setDebugLoc(DLR->getDebugLoc());
@@ -3705,7 +3701,7 @@ bool IRTranslator::translate(const Constant &C, Register Reg) {
     }
     EntryBuilder->buildBuildVector(Reg, Ops);
   } else if (auto CE = dyn_cast<ConstantExpr>(&C)) {
-    switch(CE->getOpcode()) {
+    switch (CE->getOpcode()) {
 #define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
   case Instruction::OPCODE:                                                    \
     return translate##OPCODE(*CE, *EntryBuilder.get());
@@ -4071,7 +4067,7 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
 
   // Create all blocks, in IR order, to preserve the layout.
   FuncInfo.MBBMap.resize(F.getMaxBlockNumber());
-  for (const BasicBlock &BB: F) {
+  for (const BasicBlock &BB : F) {
     auto *&MBB = FuncInfo.MBBMap[BB.getNumber()];
 
     MBB = MF->CreateMachineBasicBlock(&BB);
@@ -4100,7 +4096,7 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
 
   // Lower the actual args into this basic block.
   SmallVector<ArrayRef<Register>, 8> VRegArgs;
-  for (const Argument &Arg: F.args()) {
+  for (const Argument &Arg : F.args()) {
     if (DL->getTypeStoreSize(Arg.getType()).isZero())
       continue; // Don't handle zero sized types.
     ArrayRef<Register> VRegs = getOrCreateVRegs(Arg);
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index bd13a3bae92cd..993de9e9f64ec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -148,8 +148,7 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
          SPIRV::Extension::Extension::SPV_KHR_float_controls2},
         {"SPV_INTEL_tensor_float32_conversion",
          SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
-        {"SPV_KHR_bfloat16",
-         SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
+        {"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 946da8fcfff24..ce9ebb619f242 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1123,14 +1123,15 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
                                          ExplicitLayoutRequired, EmitIR);
   TypesInProcessing.erase(Ty);
 
-  // Record the FPVariant of the floating-point registers in the VRegFPVariantMap.
+  // Record the FPVariant of floating-point type registers in the
+  // VRegFPVariantMap.
   MachineFunction *MF = &MIRBuilder.getMF();
   Register TypeReg = getSPIRVTypeID(SpirvType);
   if (Ty->isFloatingPointTy()) {
     if (Ty->isBFloatTy()) {
-        VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
+      VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
     } else {
-        VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
+      VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
     }
   }
   VRegToTypeMap[MF][TypeReg] = SpirvType;
@@ -1689,17 +1690,17 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
   MachineBasicBlock &DepMBB = I.getMF()->front();
   MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
   const MachineInstr *NewMI =
-    createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-      auto MIB = BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
-                         MIRBuilder.getDL(), TII.get(SPIRVOPcode))
-                     .addDef(createTypeVReg(CurMF->getRegInfo()))
-                     .addImm(BitWidth);
+      createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+        auto MIB = BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
+                           MIRBuilder.getDL(), TII.get(SPIRVOPcode))
+                       .addDef(createTypeVReg(CurMF->getRegInfo()))
+                       .addImm(BitWidth);
 
-      if (SPIRVOPcode != SPIRV::OpTypeFloat)
-        MIB.addImm(0);
+        if (SPIRVOPcode != SPIRV::OpTypeFloat)
+          MIB.addImm(0);
 
-      return MIB;
-    });
+        return MIB;
+      });
   add(Ty, false, NewMI);
   return finishCreatingSPIRVType(Ty, NewMI);
 }
@@ -2104,7 +2105,8 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
   return false;
 }
 
-SPIRVGlobalRegistry::FPVariant SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
+SPIRVGlobalRegistry::FPVariant
+SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
                                          const MachineFunction *MF) {
   auto t = VRegFPVariantMap.find(MF ? MF : CurMF);
   if (t != VRegFPVariantMap.end()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 8f9a5d3cfc68a..1f8c30dc01f7f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -29,14 +29,10 @@ using SPIRVType = const MachineInstr;
 using StructOffsetDecorator = std::function<void(Register)>;
 
 class SPIRVGlobalRegistry : public SPIRVIRMapping {
-  public:
-  enum class FPVariant {
-    NONE,
-    IEEE_FLOAT,
-    BRAIN_FLOAT
-  };
-
-  private:
+public:
+  enum class FPVariant { NONE, IEEE_FLOAT, BRAIN_FLOAT };
+
+private:
   // Registers holding values which have types associated with them.
   // Initialized upon VReg definition in IRTranslator.
   // Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
@@ -96,9 +92,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // map of aliasing decorations to aliasing metadata
   std::unordered_map<const MDNode *, MachineInstr *> AliasInstMDMap;
 
-  // Maps floating point Registers to their FPVariant (float type kind), given the MachineFunction.
+  // Maps registers that define floating-point SPIR-V types to their FPVariant
+  // (float encoding kind), per MachineFunction.
   DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>>
-    VRegFPVariantMap;
+      VRegFPVariantMap;
 
   // Add a new OpTypeXXX instruction without checking for duplicates.
   SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
@@ -433,10 +430,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // Informs about removal of the machine instruction and invalidates data
   // structures referring this instruction.
   void invalidateMachineInstr(MachineInstr *MI);
-  
-  // Return the FPVariant of to the given floating-point regiester. 
+
+  // Return the FPVariant recorded for the given floating-point type register.
   FPVariant getFPVariantForVReg(Register VReg,
-                                 const MachineFunction *MF = nullptr);
+                                const MachineFunction *MF = nullptr);
 
 private:
   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 1cc76b5e30036..b8041725c9050 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1264,7 +1264,8 @@ void addInstrRequirements(const MachineInstr &MI,
   case SPIRV::OpDot: {
     const MachineFunction *MF = MI.getMF();
     SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
-    SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
+    SPIRVGlobalRegistry::FPVariant FPV =
+        GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
     if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
       Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
     }
@@ -1277,18 +1278,18 @@ void addInstrRequirements(const MachineInstr &MI,
     else if (BitWidth == 16) {
       SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
       const MachineFunction *MF = MI.getMF();
-      SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(0).getReg(), MF);
+      SPIRVGlobalRegistry::FPVariant FPV =
+          GR->getFPVariantForVReg(MI.getOperand(0).getReg(), MF);
       if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
         if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
-      report_fatal_error(
-          "OpTypeFloat type with bfloat requires the "
-          "following SPIR-V extension: SPV_KHR_bfloat16",
-          false);
+          report_fatal_error("OpTypeFloat type with bfloat requires the "
+                             "following SPIR-V extension: SPV_KHR_bfloat16",
+                             false);
         Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
         Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
       } else {
         Reqs.addCapability(SPIRV::Capability::Float16);
       }
     }
     break;
   }
@@ -1625,7 +1626,8 @@ void addInstrRequirements(const MachineInstr &MI,
     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
     const MachineFunction *MF = MI.getMF();
     SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
-    SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
+    SPIRVGlobalRegistry::FPVariant FPV =
+        GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
     if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
       Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
     } else {



More information about the llvm-commits mailing list