[llvm] [SPIRV] Add bfloat support (PR #155645)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 28 13:50:18 PDT 2025


https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/155645

>From 9e4f483ce4d7c0f710270cc21007b73ce56b59e1 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 03:42:59 -0700
Subject: [PATCH 1/6] Initial commit adding bfloat support to
 MachineInstr and the SPIR-V backend

---
 llvm/include/llvm/CodeGen/MachineInstr.h       | 3 ++-
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp   | 9 +++++++--
 llvm/lib/CodeGen/MachineInstr.cpp              | 3 +++
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp     | 4 +++-
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp  | 4 ++++
 llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td | 4 ++++
 6 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MachineInstr.h b/llvm/include/llvm/CodeGen/MachineInstr.h
index 10a9b1ff1411d..6f692ae32510b 100644
--- a/llvm/include/llvm/CodeGen/MachineInstr.h
+++ b/llvm/include/llvm/CodeGen/MachineInstr.h
@@ -123,8 +123,9 @@ class MachineInstr
     NoUSWrap = 1 << 20,      // Instruction supports geps
                              // no unsigned signed wrap.
     SameSign = 1 << 21,      // Both operands have the same sign.
-    InBounds = 1 << 22       // Pointer arithmetic remains inbounds.
+    InBounds = 1 << 22,       // Pointer arithmetic remains inbounds.
                              // Implies NoUSWrap.
+    BFloat16 = 1 << 23      // Instruction with bf16 type
   };
 
 private:
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 541269ab6bfce..2a6b66984c8ae 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
-    return false;
+  // if (containsBF16Type(U))
+  //   return false;
 
   const CallInst &CI = cast<CallInst>(U);
   const Function *F = CI.getCalledFunction();
@@ -2813,6 +2813,11 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
   if (isa<FPMathOperator>(CI))
     MIB->copyIRFlags(CI);
 
+  // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
+  if (containsBF16Type(U)) {
+    MIB->setFlag(MachineInstr::MIFlag::BFloat16);
+  }
+
   for (const auto &Arg : enumerate(CI.args())) {
     // If this is required to be an immediate, don't materialize it in a
     // register.
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 79047f732808a..10ff667bcb522 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,6 +632,9 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
   if (I.getMetadata(LLVMContext::MD_unpredictable))
     MIFlags |= MachineInstr::MIFlag::Unpredictable;
 
+  if (I.getType()->isBFloatTy())
+    MIFlags |= MachineInstr::MIFlag::BFloat16;
+
   return MIFlags;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e7da5504b2d58..bd13a3bae92cd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -147,7 +147,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
         {"SPV_KHR_float_controls2",
          SPIRV::Extension::Extension::SPV_KHR_float_controls2},
         {"SPV_INTEL_tensor_float32_conversion",
-         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
+         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
+        {"SPV_KHR_bfloat16",
+         SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 8039cf0c432fa..5bba5cdce3753 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1267,6 +1267,10 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16)
       Reqs.addCapability(SPIRV::Capability::Float16);
+    if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+      Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+    }
     break;
   }
   case SPIRV::OpTypeVector: {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d2824ee2d2caf..9d630356e8ffb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -382,6 +382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
 defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
 defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
 defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
 defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
 defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
 defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
+defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
+defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time

>From 32e6c031b24083ef668da44d2dd5c9785d9ba12d Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 12:29:09 -0700
Subject: [PATCH 2/6] Debugging: flag bfloat vectors and add OpTypeBFloat

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  |  3 ++-
 llvm/lib/CodeGen/MachineInstr.cpp             |  4 +++-
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |  2 ++
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 10 ++++++----
 4 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 2a6b66984c8ae..3b112805e082c 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2815,6 +2815,7 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
 
   // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
   if (containsBF16Type(U)) {
+    // assert(false && "bfloat detected at the IR Translator");
     MIB->setFlag(MachineInstr::MIFlag::BFloat16);
   }
 
@@ -2878,7 +2879,7 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
       MIB.addUse(TokenReg, RegState::Implicit);
     }
   }
-
+  
   return true;
 }
 
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 10ff667bcb522..1ec8c8a2d90c9 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,8 +632,10 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
   if (I.getMetadata(LLVMContext::MD_unpredictable))
     MIFlags |= MachineInstr::MIFlag::Unpredictable;
 
-  if (I.getType()->isBFloatTy())
+  if (I.getType()->getScalarType()->isBFloatTy()) {
     MIFlags |= MachineInstr::MIFlag::BFloat16;
+    // assert(false && "bfloat detected at the MachineInstr");
+  }
 
   return MIFlags;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index f0b938d681dba..923c92c28d3f6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -169,6 +169,8 @@ def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness)
                   "$type = OpTypeInt $width $signedness">;
 def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
                   "$type = OpTypeFloat $width">;
+def OpTypeBFloat: Op<40, (outs TYPE:$type), (ins i32imm:$width),
+                  "$type = OpTypeFloat $width 0">;
 def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
                   "$type = OpTypeVector $compType $compCount">;
 def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 5bba5cdce3753..19c938c96a9e8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1267,10 +1267,12 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16)
       Reqs.addCapability(SPIRV::Capability::Float16);
-    if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
-      Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
-      Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
-    }
+    break;
+  }
+  case SPIRV::OpTypeBFloat: {
+    assert(false && "bfloat detected at the file");
+    Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+    Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
     break;
   }
   case SPIRV::OpTypeVector: {

>From eccbcf845dd59f4a822a5fc9487d48b4e62d7a2c Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 14:53:45 -0700
Subject: [PATCH 3/6] Debugging: flag bfloat calls and skip pre-legalizer combining

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  | 19 ++++++++++++-------
 llvm/lib/CodeGen/GlobalISel/Legalizer.cpp     |  4 ++++
 llvm/lib/CodeGen/MachineInstr.cpp             | 14 ++++++++++++--
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  3 ++-
 .../SPIRV/SPIRVPreLegalizerCombiner.cpp       |  6 ++++++
 5 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 3b112805e082c..4aece388749f8 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2195,6 +2195,8 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
   if (translateSimpleIntrinsic(CI, ID, MIRBuilder))
     return true;
 
+  LLVM_DEBUG(dbgs() << "IRTranslator translateKnownIntrinsic for CI: " << CI << '\n');
+  LLVM_DEBUG(dbgs() << "IRTranslator translateKnownIntrinsic for ID: " << ID << '\n');
   switch (ID) {
   default:
     break;
@@ -2581,9 +2583,10 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
     Value *FpValue = CI.getOperand(0);
     ConstantInt *TestMaskValue = cast<ConstantInt>(CI.getOperand(1));
 
+    uint32_t Flags = MachineInstr::copyFlagsFromInstruction(CI);
     MIRBuilder
         .buildInstr(TargetOpcode::G_IS_FPCLASS, {getOrCreateVReg(CI)},
-                    {getOrCreateVReg(*FpValue)})
+                    {getOrCreateVReg(*FpValue)}, Flags)
         .addImm(TestMaskValue->getZExtValue());
 
     return true;
@@ -2813,12 +2816,6 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
   if (isa<FPMathOperator>(CI))
     MIB->copyIRFlags(CI);
 
-  // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
-  if (containsBF16Type(U)) {
-    // assert(false && "bfloat detected at the IR Translator");
-    MIB->setFlag(MachineInstr::MIFlag::BFloat16);
-  }
-
   for (const auto &Arg : enumerate(CI.args())) {
     // If this is required to be an immediate, don't materialize it in a
     // register.
@@ -2879,6 +2876,13 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
       MIB.addUse(TokenReg, RegState::Implicit);
     }
   }
+
+  // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
+  if (containsBF16Type(U)) {
+    dbgs() << "Flagged at IRTranslator: " << *MIB.getInstr() << "\n";
+    MIB.getInstr()->setFlag(MachineInstr::MIFlag::BFloat16);
+  }
+   MIB->copyIRFlags(CI);
   
   return true;
 }
@@ -4161,6 +4165,7 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
 
         // Translate any debug-info attached to the instruction.
         translateDbgInfo(Inst, *CurBuilder);
+        LLVM_DEBUG(dbgs() << "Inst at IRTranslator: " << Inst << "\n");
 
         if (translate(Inst))
           continue;
diff --git a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
index aef16b5f33af4..6013124558106 100644
--- a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
@@ -192,6 +192,10 @@ Legalizer::legalizeMachineFunction(MachineFunction &MF, const LegalizerInfo &LI,
     if (MBB->empty())
       continue;
     for (MachineInstr &MI : *MBB) {
+      dbgs() << "Legalizer sees: " << MI << "\n";
+      if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
+        assert(false && "bfloat detected at the Legalizer");
+      }
       // Only legalize pre-isel generic instructions: others don't have types
       // and are assumed to be legal.
       if (!isPreISelGenericOpcode(MI.getOpcode()))
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 1ec8c8a2d90c9..bbeeed76631c6 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,9 +632,19 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
   if (I.getMetadata(LLVMContext::MD_unpredictable))
     MIFlags |= MachineInstr::MIFlag::Unpredictable;
 
-  if (I.getType()->getScalarType()->isBFloatTy()) {
+  auto *Sel = dyn_cast<SelectInst>(&I);
+  bool BFloatOpnd = !Sel && I.getType()->getScalarType()->isBFloatTy();
+
+  if (isa<CallInst>(&I)) {
+    for (const Value *Op : I.operands()) {
+      Type *OpTy = Op->getType();
+      BFloatOpnd |= OpTy->getScalarType()->isBFloatTy();
+    }
+  }
+
+  if (BFloatOpnd) {
     MIFlags |= MachineInstr::MIFlag::BFloat16;
-    // assert(false && "bfloat detected at the MachineInstr");
+    LLVM_DEBUG(dbgs() << "bfloat detected at the MachineInstr" << "\n");
   }
 
   return MIFlags;
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 19c938c96a9e8..ab26f0208d758 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1262,6 +1262,7 @@ void addInstrRequirements(const MachineInstr &MI,
     break;
   }
   case SPIRV::OpTypeFloat: {
+    assert(false && "bfloat detected at the spirv module analysis");
     unsigned BitWidth = MI.getOperand(1).getImm();
     if (BitWidth == 64)
       Reqs.addCapability(SPIRV::Capability::Float64);
@@ -1270,7 +1271,7 @@ void addInstrRequirements(const MachineInstr &MI,
     break;
   }
   case SPIRV::OpTypeBFloat: {
-    assert(false && "bfloat detected at the file");
+    assert(false && "bfloat detected at the spirv module analysis");
     Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
     Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
     break;
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
index 83567519355a8..a964db0b02746 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
@@ -156,6 +156,12 @@ SPIRVPreLegalizerCombinerImpl::SPIRVPreLegalizerCombinerImpl(
 }
 
 bool SPIRVPreLegalizerCombinerImpl::tryCombineAll(MachineInstr &MI) const {
+  // Inst combiner is disabled for instructions
+  // with bfloat type.
+  // TODO: add bfloat-specific inst combiner rules
+  if (MI.getFlag(MachineInstr::BFloat16))
+    return false;
+
   return tryCombineAllImpl(MI);
 }
 

>From 2b080940fd7e57c6dc714aedca113b80e8871d54 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 28 Aug 2025 05:55:29 -0700
Subject: [PATCH 4/6] Continue debugging: tag bfloat debug output with [BFLOAT]

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 19 +++++++++++--------
 llvm/lib/CodeGen/GlobalISel/Legalizer.cpp    |  4 ++--
 llvm/lib/CodeGen/MachineInstr.cpp            |  4 +++-
 3 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 4aece388749f8..a396839c5061b 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2195,8 +2195,8 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
   if (translateSimpleIntrinsic(CI, ID, MIRBuilder))
     return true;
 
-  LLVM_DEBUG(dbgs() << "IRTranslator translateKnownIntrinsic for CI: " << CI << '\n');
-  LLVM_DEBUG(dbgs() << "IRTranslator translateKnownIntrinsic for ID: " << ID << '\n');
+  LLVM_DEBUG(dbgs() << "[BFLOAT] IRTranslator translateKnownIntrinsic for CI: " << CI << '\n');
+  LLVM_DEBUG(dbgs() << "[BFLOAT] IRTranslator translateKnownIntrinsic for ID: " << ID << '\n');
   switch (ID) {
   default:
     break;
@@ -2878,11 +2878,11 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
   }
 
   // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
+  MIB->copyIRFlags(CI);
   if (containsBF16Type(U)) {
-    dbgs() << "Flagged at IRTranslator: " << *MIB.getInstr() << "\n";
+    dbgs() << "[BFLOAT] Flagged at IRTranslator: " << *MIB.getInstr() << "\n";
     MIB.getInstr()->setFlag(MachineInstr::MIFlag::BFloat16);
   }
-   MIB->copyIRFlags(CI);
   
   return true;
 }
@@ -4105,8 +4105,8 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
   if (CLI->fallBackToDAGISel(*MF)) {
     OptimizationRemarkMissed R("gisel-irtranslator", "GISelFailure",
                                F.getSubprogram(), &F.getEntryBlock());
-    R << "unable to lower function: "
-      << ore::NV("Prototype", F.getFunctionType());
+    R << "u "
+      << ore::NV("Prototype", F.getnable to lower function:FunctionType());
     reportTranslationError(*MF, *TPC, *ORE, R);
     return false;
   }
@@ -4118,6 +4118,7 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
       continue; // Don't handle zero sized types.
     ArrayRef<Register> VRegs = getOrCreateVRegs(Arg);
     VRegArgs.push_back(VRegs);
+    LLVM_DEBUG(dbgs() << "[BFLOAT] Created vreg in IRTranslator: " << VRegs);
 
     if (Arg.hasSwiftErrorAttr()) {
       assert(VRegs.size() == 1 && "Too many vregs for Swift error");
@@ -4165,10 +4166,12 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
 
         // Translate any debug-info attached to the instruction.
         translateDbgInfo(Inst, *CurBuilder);
-        LLVM_DEBUG(dbgs() << "Inst at IRTranslator: " << Inst << "\n");
+        LLVM_DEBUG(dbgs() << "[BFLOAT] Inst at IRTranslator: " << Inst << "\n");
 
-        if (translate(Inst))
+        if (translate(Inst)) {
+          LLVM_DEBUG(dbgs() << "[BFLOAT] Inst at IRTranslator after translate(Inst): " << Inst << "\n");
           continue;
+        }
 
         OptimizationRemarkMissed R("gisel-irtranslator", "GISelFailure",
                                    Inst.getDebugLoc(), BB);
diff --git a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
index 6013124558106..fc1fede1a6d1c 100644
--- a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
@@ -192,9 +192,9 @@ Legalizer::legalizeMachineFunction(MachineFunction &MF, const LegalizerInfo &LI,
     if (MBB->empty())
       continue;
     for (MachineInstr &MI : *MBB) {
-      dbgs() << "Legalizer sees: " << MI << "\n";
+      LLVM_DEBUG(dbgs() << "[BFLOAT]  Legalizer sees: " << MI << "\n");
       if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
-        assert(false && "bfloat detected at the Legalizer");
+        LLVM_DEBUG(dbgs() << "[BFLOAT] bfloat detected at the Legalizer \n");
       }
       // Only legalize pre-isel generic instructions: others don't have types
       // and are assumed to be legal.
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index bbeeed76631c6..6b7efe671b572 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -61,6 +61,8 @@
 #include <cstring>
 #include <utility>
 
+#define DEBUG_TYPE "irtranslator"
+
 using namespace llvm;
 
 static cl::opt<bool>
@@ -644,7 +646,7 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
 
   if (BFloatOpnd) {
     MIFlags |= MachineInstr::MIFlag::BFloat16;
-    LLVM_DEBUG(dbgs() << "bfloat detected at the MachineInstr" << "\n");
+    LLVM_DEBUG(dbgs() << "[BFLOAT] bfloat detected at the MachineInstr \n");
   }
 
   return MIFlags;

>From 34ea97b632a5a71cec7311539c3a78b712fa542a Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 28 Aug 2025 08:57:30 -0700
Subject: [PATCH 5/6] Continue debugging: trace simple intrinsics, restore remark text

---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 11 +++++------
 llvm/lib/CodeGen/GlobalISel/Legalizer.cpp    |  3 ++-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index a396839c5061b..1649b8b4a41b3 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2046,9 +2046,10 @@ bool IRTranslator::translateSimpleIntrinsic(const CallInst &CI,
                                             MachineIRBuilder &MIRBuilder) {
 
   unsigned Op = getSimpleIntrinsicOpcode(ID);
+  LLVM_DEBUG(dbgs() << "[BFLOAT] translateSimpleIntrinsic is called for ID:" << ID << '\n');
 
   // Is this a simple intrinsic?
-  if (Op == Intrinsic::not_intrinsic)
+  if (Op == Intrinsic::not_intrinsic){}
     return false;
 
   // Yes. Let's translate it.
@@ -2768,8 +2769,6 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
-  // if (containsBF16Type(U))
-  //   return false;
 
   const CallInst &CI = cast<CallInst>(U);
   const Function *F = CI.getCalledFunction();
@@ -4105,8 +4104,8 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
   if (CLI->fallBackToDAGISel(*MF)) {
     OptimizationRemarkMissed R("gisel-irtranslator", "GISelFailure",
                                F.getSubprogram(), &F.getEntryBlock());
-    R << "u "
-      << ore::NV("Prototype", F.getnable to lower function:FunctionType());
+    R << "unable to lower function: "
+      << ore::NV("Prototype", F.getFunctionType());
     reportTranslationError(*MF, *TPC, *ORE, R);
     return false;
   }
@@ -4118,7 +4117,7 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
       continue; // Don't handle zero sized types.
     ArrayRef<Register> VRegs = getOrCreateVRegs(Arg);
     VRegArgs.push_back(VRegs);
-    LLVM_DEBUG(dbgs() << "[BFLOAT] Created vreg in IRTranslator: " << VRegs);
+    // LLVM_DEBUG(dbgs() << "[BFLOAT] Created vreg in IRTranslator: " << VRegs);
 
     if (Arg.hasSwiftErrorAttr()) {
       assert(VRegs.size() == 1 && "Too many vregs for Swift error");
diff --git a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
index fc1fede1a6d1c..77289d138bc9c 100644
--- a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
@@ -195,7 +195,8 @@ Legalizer::legalizeMachineFunction(MachineFunction &MF, const LegalizerInfo &LI,
       LLVM_DEBUG(dbgs() << "[BFLOAT]  Legalizer sees: " << MI << "\n");
       if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
         LLVM_DEBUG(dbgs() << "[BFLOAT] bfloat detected at the Legalizer \n");
-      }
+      } 
+
       // Only legalize pre-isel generic instructions: others don't have types
       // and are assumed to be legal.
       if (!isPreISelGenericOpcode(MI.getOpcode()))

>From eea9fd72db1e3fe7d6e83ecd9c547db75f6d144a Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 28 Aug 2025 13:45:57 -0700
Subject: [PATCH 6/6] Encode bfloat through a floating-point encoding operand
 on OpTypeFloat

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 36 +++++++++++++++----
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  2 +-
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |  5 +--
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  7 ----
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   |  2 ++
 5 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cfe24c84941a9..cc524ee7e3cb1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -195,11 +195,12 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
-                                               MachineIRBuilder &MIRBuilder) {
+                                               MachineIRBuilder &MIRBuilder, uint32_t Fptype) {
   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
     return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
         .addDef(createTypeVReg(MIRBuilder))
-        .addImm(Width);
+        .addImm(Width)
+        .addImm(Fptype);
   });
 }
 
@@ -1041,8 +1042,11 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     return Width == 1 ? getOpTypeBool(MIRBuilder)
                       : getOpTypeInt(Width, MIRBuilder, false);
   }
-  if (Ty->isFloatingPointTy())
-    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
+  if (Ty->isFloatingPointTy()) {
+    if (Ty->isBFloatTy())
+      return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, 1);
+    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, 0);
+  }
   if (Ty->isVoidTy())
     return getOpTypeVoid(MIRBuilder);
   if (Ty->isVectorTy()) {
@@ -1677,14 +1681,33 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
     return MI;
   MachineBasicBlock &DepMBB = I.getMF()->front();
   MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
-  const MachineInstr *NewMI =
-      createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+  const MachineInstr *NewMI = nullptr;
+  if(SPIRVOPcode == SPIRV::OpTypeFloat) {
+    if (Ty->isBFloatTy()) {
+      NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+        return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
+                       MIRBuilder.getDL(), TII.get(SPIRVOPcode))
+            .addDef(createTypeVReg(CurMF->getRegInfo()))
+            .addImm(BitWidth)
+            .addImm(1);
+      });
+    } else {
+      NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
         return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
                        MIRBuilder.getDL(), TII.get(SPIRVOPcode))
             .addDef(createTypeVReg(CurMF->getRegInfo()))
             .addImm(BitWidth)
             .addImm(0);
       });
+    }
+  } else {
+    NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+        return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
+                       MIRBuilder.getDL(), TII.get(SPIRVOPcode))
+            .addDef(createTypeVReg(CurMF->getRegInfo()))
+            .addImm(BitWidth);
+      });
+  }
   add(Ty, false, NewMI);
   return finishCreatingSPIRVType(Ty, NewMI);
 }
@@ -1718,6 +1741,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
   default:
     llvm_unreachable("Bit width is of unexpected size.");
   }
+  
   return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 7ef812828b7cc..a2a2e7b289c5d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -436,7 +436,7 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
                           bool IsSigned = false);
 
-  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
+  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, uint32_t Fptype);
 
   SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 923c92c28d3f6..a823df89eab6c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -167,10 +167,7 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
 def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
 def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
                   "$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
-                  "$type = OpTypeFloat $width">;
-def OpTypeBFloat: Op<40, (outs TYPE:$type), (ins i32imm:$width),
-                  "$type = OpTypeFloat $width 0">;
+def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, i32imm:$fpencode), "$type = OpTypeFloat $width $fpencode">;
 def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
                   "$type = OpTypeVector $compType $compCount">;
 def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index ab26f0208d758..8039cf0c432fa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1262,7 +1262,6 @@ void addInstrRequirements(const MachineInstr &MI,
     break;
   }
   case SPIRV::OpTypeFloat: {
-    assert(false && "bfloat detected at the spirv module analysis");
     unsigned BitWidth = MI.getOperand(1).getImm();
     if (BitWidth == 64)
       Reqs.addCapability(SPIRV::Capability::Float64);
@@ -1270,12 +1269,6 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Float16);
     break;
   }
-  case SPIRV::OpTypeBFloat: {
-    assert(false && "bfloat detected at the spirv module analysis");
-    Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
-    Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
-    break;
-  }
   case SPIRV::OpTypeVector: {
     unsigned NumComponents = MI.getOperand(2).getImm();
     if (NumComponents == 8 || NumComponents == 16)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 1a08c6ac0dcaf..33cb8471d747e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -540,6 +540,7 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
           insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
                             MF.getRegInfo());
+        LLVM_DEBUG(dbgs() << "[BFLOAT] SPIRVPreLegalizer is lowering spv_assign_ptr_type to" << MI << '\n');
         ToErase.push_back(&MI);
       } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
         Register Reg = MI.getOperand(1).getReg();
@@ -550,6 +551,7 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
+        LLVM_DEBUG(dbgs() << "[BFLOAT] SPIRVPreLegalizer is lowering spv_assign_type to" << MI << '\n');
         ToErase.push_back(&MI);
       } else if (MIOp == TargetOpcode::FAKE_USE && MI.getNumOperands() > 0) {
         MachineInstr *MdMI = MI.getPrevNode();



More information about the llvm-commits mailing list