[llvm] [SPIRV] Add bfloat support (PR #155645)

via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 27 08:55:13 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: None (YixingZhang007)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/155645.diff


6 Files Affected:

- (modified) llvm/include/llvm/CodeGen/MachineInstr.h (+2-1) 
- (modified) llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp (+7-2) 
- (modified) llvm/lib/CodeGen/MachineInstr.cpp (+3) 
- (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+3-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+4) 
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+4) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/MachineInstr.h b/llvm/include/llvm/CodeGen/MachineInstr.h
index 10a9b1ff1411d..6f692ae32510b 100644
--- a/llvm/include/llvm/CodeGen/MachineInstr.h
+++ b/llvm/include/llvm/CodeGen/MachineInstr.h
@@ -123,8 +123,9 @@ class MachineInstr
     NoUSWrap = 1 << 20,      // Instruction supports geps
                              // no unsigned signed wrap.
     SameSign = 1 << 21,      // Both operands have the same sign.
-    InBounds = 1 << 22       // Pointer arithmetic remains inbounds.
+    InBounds = 1 << 22,       // Pointer arithmetic remains inbounds.
                              // Implies NoUSWrap.
+    BFloat16 = 1 << 23      // Instruction with bf16 type
   };
 
 private:
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 541269ab6bfce..2a6b66984c8ae 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
-    return false;
+  // if (containsBF16Type(U))
+  //   return false;
 
   const CallInst &CI = cast<CallInst>(U);
   const Function *F = CI.getCalledFunction();
@@ -2813,6 +2813,11 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
   if (isa<FPMathOperator>(CI))
     MIB->copyIRFlags(CI);
 
+  // If the SPIR-V intrinsic uses bfloat, set the MachineInstr BFloat16 flag.
+  if (containsBF16Type(U)) {
+    MIB->setFlag(MachineInstr::MIFlag::BFloat16);
+  }
+
   for (const auto &Arg : enumerate(CI.args())) {
     // If this is required to be an immediate, don't materialize it in a
     // register.
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 79047f732808a..10ff667bcb522 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,6 +632,9 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
   if (I.getMetadata(LLVMContext::MD_unpredictable))
     MIFlags |= MachineInstr::MIFlag::Unpredictable;
 
+  if (I.getType()->isBFloatTy())
+    MIFlags |= MachineInstr::MIFlag::BFloat16;
+
   return MIFlags;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e7da5504b2d58..bd13a3bae92cd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -147,7 +147,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
         {"SPV_KHR_float_controls2",
          SPIRV::Extension::Extension::SPV_KHR_float_controls2},
         {"SPV_INTEL_tensor_float32_conversion",
-         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
+         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
+        {"SPV_KHR_bfloat16",
+         SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 8039cf0c432fa..5bba5cdce3753 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1267,6 +1267,10 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16)
       Reqs.addCapability(SPIRV::Capability::Float16);
+    if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+      Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+    }
     break;
   }
   case SPIRV::OpTypeVector: {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d2824ee2d2caf..9d630356e8ffb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -382,6 +382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
 defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
 defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
 defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
 defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
 defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
 defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
+defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
+defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time

``````````

</details>


https://github.com/llvm/llvm-project/pull/155645


More information about the llvm-commits mailing list