[llvm] f91e0bf - [SPIRV] Add support for the SPIR-V extension SPV_KHR_bfloat16 (#155645)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 22 05:53:02 PDT 2025


Author: YixingZhang007
Date: 2025-09-22T14:52:57+02:00
New Revision: f91e0bf16098decbee75233f67109751f2a2e79b

URL: https://github.com/llvm/llvm-project/commit/f91e0bf16098decbee75233f67109751f2a2e79b
DIFF: https://github.com/llvm/llvm-project/commit/f91e0bf16098decbee75233f67109751f2a2e79b.diff

LOG: [SPIRV] Add support for the SPIR-V extension SPV_KHR_bfloat16 (#155645)

This PR introduces support for the SPIR-V extension
`SPV_KHR_bfloat16`. This extension extends the `OpTypeFloat` instruction
to enable the use of bfloat16 types with cooperative matrices and dot
products.

TODO:
Per the `SPV_KHR_bfloat16` extension, only a limited number of
instructions can use the bfloat16 type. For example, arithmetic
instructions like `FAdd` or `FMul` can't operate on `bfloat16` values.
Therefore, a future patch should either emit an error or fall back to
FP32 for arithmetic in cases where bfloat16 must not be used.

Reference Specification:

https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_bfloat16.asciidoc

Added: 
    llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
    llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
    llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll

Modified: 
    llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
    llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
    llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
    llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
    llvm/test/CodeGen/SPIRV/basic_float_types.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 768e3713f78e2..12b735e053bde 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,7 +2765,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
     return false;
 
   const CallInst &CI = cast<CallInst>(U);

diff  --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e7da5504b2d58..993de9e9f64ec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -147,7 +147,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
         {"SPV_KHR_float_controls2",
          SPIRV::Extension::Extension::SPV_KHR_float_controls2},
         {"SPV_INTEL_tensor_float32_conversion",
-         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
+         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
+        {"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,

diff  --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index b7e371d190866..a95f393b75605 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1222,6 +1222,13 @@ static void AddDotProductRequirements(const MachineInstr &MI,
   }
 }
 
+static bool isBFloat16Type(const SPIRVType *TypeDef) {
+  return TypeDef && TypeDef->getNumOperands() == 3 &&
+         TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
+         TypeDef->getOperand(1).getImm() == 16 &&
+         TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
+}
+
 void addInstrRequirements(const MachineInstr &MI,
                           SPIRV::RequirementHandler &Reqs,
                           const SPIRVSubtarget &ST) {
@@ -1261,12 +1268,29 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Int8);
     break;
   }
+  case SPIRV::OpDot: {
+    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+    if (isBFloat16Type(TypeDef))
+      Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
+    break;
+  }
   case SPIRV::OpTypeFloat: {
     unsigned BitWidth = MI.getOperand(1).getImm();
     if (BitWidth == 64)
       Reqs.addCapability(SPIRV::Capability::Float64);
-    else if (BitWidth == 16)
-      Reqs.addCapability(SPIRV::Capability::Float16);
+    else if (BitWidth == 16) {
+      if (isBFloat16Type(&MI)) {
+        if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
+          report_fatal_error("OpTypeFloat type with bfloat requires the "
+                             "following SPIR-V extension: SPV_KHR_bfloat16",
+                             false);
+        Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+        Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+      } else {
+        Reqs.addCapability(SPIRV::Capability::Float16);
+      }
+    }
     break;
   }
   case SPIRV::OpTypeVector: {
@@ -1286,8 +1310,9 @@ void addInstrRequirements(const MachineInstr &MI,
     assert(MI.getOperand(2).isReg());
     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
-    if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
-        TypeDef->getOperand(1).getImm() == 16)
+    if ((TypeDef->getNumOperands() == 2) &&
+        (TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
+        (TypeDef->getOperand(1).getImm() == 16))
       Reqs.addCapability(SPIRV::Capability::Float16Buffer);
     break;
   }
@@ -1593,7 +1618,7 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::AsmINTEL);
     }
     break;
-  case SPIRV::OpTypeCooperativeMatrixKHR:
+  case SPIRV::OpTypeCooperativeMatrixKHR: {
     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
       report_fatal_error(
           "OpTypeCooperativeMatrixKHR type requires the "
@@ -1601,7 +1626,12 @@ void addInstrRequirements(const MachineInstr &MI,
           false);
     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
     Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
+    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+    if (isBFloat16Type(TypeDef))
+      Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
     break;
+  }
   case SPIRV::OpArithmeticFenceEXT:
     if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
       report_fatal_error("OpArithmeticFenceEXT requires the "

diff  --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index ed933f872d136..501bcb94af2ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -383,6 +383,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
 defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
 defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
 defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -595,6 +596,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
 defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
 defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
 defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
+defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
+defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
@@ -2021,4 +2025,4 @@ multiclass FPEncodingOperand<bits<32> value, list<Extension> reqExtensions>{
              reqExtensions, [], []>;
 }
 
-defm BFloat16KHR : FPEncodingOperand<0, []>;
+defm BFloat16KHR : FPEncodingOperand<0, [SPV_KHR_bfloat16]>;

diff  --git a/llvm/test/CodeGen/SPIRV/basic_float_types.ll b/llvm/test/CodeGen/SPIRV/basic_float_types.ll
index 486f6358ce5de..a0ba97e1d1f14 100644
--- a/llvm/test/CodeGen/SPIRV/basic_float_types.ll
+++ b/llvm/test/CodeGen/SPIRV/basic_float_types.ll
@@ -1,12 +1,13 @@
-; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
-; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
-; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
 
 define void @main() {
 entry:
 
 ; CHECK-DAG: OpCapability Float16
 ; CHECK-DAG: OpCapability Float64
+; CHECK-DAG: OpCapability BFloat16TypeKHR
 
 ; CHECK-DAG:     %[[#half:]] = OpTypeFloat 16{{$}}
 ; CHECK-DAG:   %[[#bfloat:]] = OpTypeFloat 16 0{{$}}

diff  --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
new file mode 100644
index 0000000000000..22668e71fb257
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -0,0 +1,22 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: OpTypeFloat type with bfloat requires the following SPIR-V extension: SPV_KHR_bfloat16
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
+; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_kernel void @test() {
+entry:
+  %addr1 = alloca bfloat
+  %addr2 = alloca <2 x bfloat>
+  %data1 = load bfloat, ptr %addr1
+  %data2 = load <2 x bfloat>, ptr %addr2
+  ret void
+}

diff  --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
new file mode 100644
index 0000000000000..d47b5d7440d18
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -0,0 +1,22 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability CooperativeMatrixKHR
+; CHECK-DAG: OpCapability BFloat16CooperativeMatrixKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-DAG: OpExtension "SPV_KHR_cooperative_matrix"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
+; CHECK: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]]  %[[#]] %[[#]] %[[#]] %[[#]]
+; CHECK: OpCompositeConstruct %[[#MatTy]] %[[#]]
+
+define spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) {
+entry:
+    %addr1 = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
+    %res = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
+    %m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat 1.0)
+    store target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) %m1, ptr %addr1, align 4
+    ret void
+}
+
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat)

diff  --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
new file mode 100644
index 0000000000000..4c248fea5c7f1
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -0,0 +1,21 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16DotProductKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
+; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
+; CHECK: OpDot
+
+declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>)
+
+define spir_kernel void @test() {
+entry:
+  %addrA = alloca <2 x bfloat>
+  %addrB = alloca <2 x bfloat>
+  %dataA = load <2 x bfloat>, ptr %addrA
+  %dataB = load <2 x bfloat>, ptr %addrB
+  %call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
+  ret void
+}


        


More information about the llvm-commits mailing list