[llvm] [SPIRV] Add support for the SPIR-V extension SPV_KHR_bfloat16 (PR #155645)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 8 22:04:56 PDT 2025
https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/155645
>From d391c7c72d93f6561c0cbb4d1e091a5465c21d71 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 03:42:59 -0700
Subject: [PATCH 1/9] add support for the SPIR-V extension SPV_KHR_bfloat16
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 3 --
llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 3 +-
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 40 ++++++++++++++++---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 13 ++++++
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 38 ++++++++++++++++--
.../lib/Target/SPIRV/SPIRVSymbolicOperands.td | 4 ++
.../extensions/SPV_KHR_bfloat16/bfloat16.ll | 22 ++++++++++
.../bfloat16_cooperative_matrix.ll | 20 ++++++++++
.../SPV_KHR_bfloat16/bfloat16_dot.ll | 21 ++++++++++
9 files changed, 151 insertions(+), 13 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 541269ab6bfce..8a4b1afec3d26 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,9 +2765,6 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
- if (containsBF16Type(U))
- return false;
-
const CallInst &CI = cast<CallInst>(U);
const Function *F = CI.getCalledFunction();
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e7da5504b2d58..993de9e9f64ec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -147,7 +147,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_KHR_float_controls2",
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
{"SPV_INTEL_tensor_float32_conversion",
- SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
+ SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
+ {"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cfe24c84941a9..ce9ebb619f242 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1122,7 +1122,19 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
ExplicitLayoutRequired, EmitIR);
TypesInProcessing.erase(Ty);
- VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
+
+ // Record the FPVariant of the floating-point registers in the
+ // VRegFPVariantMap.
+ MachineFunction *MF = &MIRBuilder.getMF();
+ Register TypeReg = getSPIRVTypeID(SpirvType);
+ if (Ty->isFloatingPointTy()) {
+ if (Ty->isBFloatTy()) {
+ VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
+ } else {
+ VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
+ }
+ }
+ VRegToTypeMap[MF][TypeReg] = SpirvType;
// TODO: We could end up with two SPIR-V types pointing to the same llvm type.
// Is that a problem?
@@ -1679,11 +1691,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
const MachineInstr *NewMI =
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
- return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
- MIRBuilder.getDL(), TII.get(SPIRVOPcode))
- .addDef(createTypeVReg(CurMF->getRegInfo()))
- .addImm(BitWidth)
- .addImm(0);
+ auto MIB = BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
+ MIRBuilder.getDL(), TII.get(SPIRVOPcode))
+ .addDef(createTypeVReg(CurMF->getRegInfo()))
+ .addImm(BitWidth);
+
+ if (SPIRVOPcode != SPIRV::OpTypeFloat)
+ MIB.addImm(0);
+
+ return MIB;
});
add(Ty, false, NewMI);
return finishCreatingSPIRVType(Ty, NewMI);
@@ -2088,3 +2104,15 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
}
return false;
}
+
+SPIRVGlobalRegistry::FPVariant
+SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
+ const MachineFunction *MF) {
+ auto t = VRegFPVariantMap.find(MF ? MF : CurMF);
+ if (t != VRegFPVariantMap.end()) {
+ auto tt = t->second.find(VReg);
+ if (tt != t->second.end())
+ return tt->second;
+ }
+ return FPVariant::NONE;
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 7ef812828b7cc..1f8c30dc01f7f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -29,6 +29,10 @@ using SPIRVType = const MachineInstr;
using StructOffsetDecorator = std::function<void(Register)>;
class SPIRVGlobalRegistry : public SPIRVIRMapping {
+public:
+ enum class FPVariant { NONE, IEEE_FLOAT, BRAIN_FLOAT };
+
+private:
// Registers holding values which have types associated with them.
// Initialized upon VReg definition in IRTranslator.
// Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
@@ -88,6 +92,11 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// map of aliasing decorations to aliasing metadata
std::unordered_map<const MDNode *, MachineInstr *> AliasInstMDMap;
+ // Maps floating point Registers to their FPVariant (float type kind), given
+ // the MachineFunction.
+ DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>>
+ VRegFPVariantMap;
+
// Add a new OpTypeXXX instruction without checking for duplicates.
SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AQ,
@@ -422,6 +431,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// structures referring this instruction.
void invalidateMachineInstr(MachineInstr *MI);
+ // Return the FPVariant of to the given floating-point regiester.
+ FPVariant getFPVariantForVReg(Register VReg,
+ const MachineFunction *MF = nullptr);
+
private:
SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 8039cf0c432fa..b8041725c9050 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1261,12 +1261,35 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Int8);
break;
}
+ case SPIRV::OpDot: {
+ const MachineFunction *MF = MI.getMF();
+ SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+ SPIRVGlobalRegistry::FPVariant FPV =
+ GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
+ if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+ Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
+ }
+ break;
+ }
case SPIRV::OpTypeFloat: {
unsigned BitWidth = MI.getOperand(1).getImm();
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
- else if (BitWidth == 16)
+ else if (BitWidth == 16) {
+ SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+ const MachineFunction *MF = MI.getMF();
+ SPIRVGlobalRegistry::FPVariant FPV =
+ GR->getFPVariantForVReg(MI.getOperand(0).getReg(), MF);
+ if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
+ report_fatal_error("OpTypeFloat type with bfloat requires the "
+ "following SPIR-V extension: SPV_KHR_bfloat16",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+ Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+ }
Reqs.addCapability(SPIRV::Capability::Float16);
+ }
break;
}
case SPIRV::OpTypeVector: {
@@ -1593,15 +1616,24 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::AsmINTEL);
}
break;
- case SPIRV::OpTypeCooperativeMatrixKHR:
+ case SPIRV::OpTypeCooperativeMatrixKHR: {
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
report_fatal_error(
"OpTypeCooperativeMatrixKHR type requires the "
"following SPIR-V extension: SPV_KHR_cooperative_matrix",
false);
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
- Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
+ const MachineFunction *MF = MI.getMF();
+ SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+ SPIRVGlobalRegistry::FPVariant FPV =
+ GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
+ if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+ Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
+ } else {
+ Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
+ }
break;
+ }
case SPIRV::OpArithmeticFenceEXT:
if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
report_fatal_error("OpArithmeticFenceEXT requires the "
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d2824ee2d2caf..9d630356e8ffb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -382,6 +382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
+defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
+defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
new file mode 100644
index 0000000000000..bfc84691f6945
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -0,0 +1,22 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: OpTypeFloat type with bfloat requires the following SPIR-V extension: SPV_KHR_bfloat16
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_kernel void @test() {
+entry:
+ %addr1 = alloca bfloat
+ %addr2 = alloca <2 x bfloat>
+ %data1 = load bfloat, ptr %addr1
+ %data2 = load <2 x bfloat>, ptr %addr2
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
new file mode 100644
index 0000000000000..5a6e6d88ca6a0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -0,0 +1,20 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16CooperativeMatrixKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]] %[[#]] %[[#]] %[[#]] %[[#]]
+; CHECK: OpCompositeConstruct %[[#MatTy]] %[[#]]
+
+define spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) {
+entry:
+ %addr1 = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
+ %res = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
+ %m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat 1.0)
+ store target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) %m1, ptr %addr1, align 4
+ ret void
+}
+
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
new file mode 100644
index 0000000000000..7cfe29261f2cd
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -0,0 +1,21 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16DotProductKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
+; CHECK: OpDot
+
+declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>)
+
+define spir_kernel void @test() {
+entry:
+ %addrA = alloca <2 x bfloat>
+ %addrB = alloca <2 x bfloat>
+ %dataA = load <2 x bfloat>, ptr %addrA
+ %dataB = load <2 x bfloat>, ptr %addrB
+ %call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
+ ret void
+}
>From 601d68bf0633f318f45ec878babf971460339a71 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 5 Sep 2025 09:28:38 -0700
Subject: [PATCH 2/9] nit change
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 3 +++
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 16 +++++++++++-----
2 files changed, 14 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 8a4b1afec3d26..c6813df435154 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,6 +2765,9 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
+ if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
+ return false;
+
const CallInst &CI = cast<CallInst>(U);
const Function *F = CI.getCalledFunction();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index ce9ebb619f242..4c931b4f45e69 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -2108,11 +2108,17 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
SPIRVGlobalRegistry::FPVariant
SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
const MachineFunction *MF) {
- auto t = VRegFPVariantMap.find(MF ? MF : CurMF);
- if (t != VRegFPVariantMap.end()) {
- auto tt = t->second.find(VReg);
- if (tt != t->second.end())
- return tt->second;
+ const MachineFunction *Func = MF ? MF : CurMF;
+ DenseMap<const MachineFunction *,
+ DenseMap<Register, FPVariant>>::const_iterator FuncIt =
+ VRegFPVariantMap.find(Func);
+
+ if (FuncIt != VRegFPVariantMap.end()) {
+ const DenseMap<Register, FPVariant> &VRegMap = FuncIt->second;
+ DenseMap<Register, FPVariant>::const_iterator VRegIt = VRegMap.find(VReg);
+
+ if (VRegIt != VRegMap.end())
+ return VRegIt->second;
}
return FPVariant::NONE;
}
>From e0b7026d30fc7ac96b01742c1487a0f43af50a61 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 5 Sep 2025 10:29:03 -0700
Subject: [PATCH 3/9] revert prev change
---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 16 +++++-----------
1 file changed, 5 insertions(+), 11 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 4c931b4f45e69..ce9ebb619f242 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -2108,17 +2108,11 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
SPIRVGlobalRegistry::FPVariant
SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
const MachineFunction *MF) {
- const MachineFunction *Func = MF ? MF : CurMF;
- DenseMap<const MachineFunction *,
- DenseMap<Register, FPVariant>>::const_iterator FuncIt =
- VRegFPVariantMap.find(Func);
-
- if (FuncIt != VRegFPVariantMap.end()) {
- const DenseMap<Register, FPVariant> &VRegMap = FuncIt->second;
- DenseMap<Register, FPVariant>::const_iterator VRegIt = VRegMap.find(VReg);
-
- if (VRegIt != VRegMap.end())
- return VRegIt->second;
+ auto t = VRegFPVariantMap.find(MF ? MF : CurMF);
+ if (t != VRegFPVariantMap.end()) {
+ auto tt = t->second.find(VReg);
+ if (tt != t->second.end())
+ return tt->second;
}
return FPVariant::NONE;
}
>From 0d818de31d8e2ad39e320a7945e94a68896e9c7c Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 5 Sep 2025 10:46:21 -0700
Subject: [PATCH 4/9] Revert change
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index c6813df435154..679f77f6be411 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
- if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
- return false;
+ // if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
+ // return false;
const CallInst &CI = cast<CallInst>(U);
const Function *F = CI.getCalledFunction();
>From 91812e015a627d786726b0737f60e6498acaddca Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 8 Sep 2025 07:24:34 -0700
Subject: [PATCH 5/9] code clean up
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 4 ++--
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 12 +++++++-----
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 2 +-
.../SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll | 2 +-
.../SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll | 2 +-
.../extensions/SPV_KHR_bfloat16/bfloat16_dot.ll | 2 +-
6 files changed, 13 insertions(+), 11 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 679f77f6be411..c6813df435154 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
- // if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
- // return false;
+ if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
+ return false;
const CallInst &CI = cast<CallInst>(U);
const Function *F = CI.getCalledFunction();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index ce9ebb619f242..01c995bccb0ef 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -2108,11 +2108,13 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
SPIRVGlobalRegistry::FPVariant
SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
const MachineFunction *MF) {
- auto t = VRegFPVariantMap.find(MF ? MF : CurMF);
- if (t != VRegFPVariantMap.end()) {
- auto tt = t->second.find(VReg);
- if (tt != t->second.end())
- return tt->second;
+ const MachineFunction *Func = MF ? MF : CurMF;
+ auto FuncIt = VRegFPVariantMap.find(Func);
+ if (FuncIt != VRegFPVariantMap.end()) {
+ const DenseMap<Register, FPVariant> &VRegMap = FuncIt->second;
+ auto VRegIt = VRegMap.find(VReg);
+ if (VRegIt != VRegMap.end())
+ return VRegIt->second;
}
return FPVariant::NONE;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index b8041725c9050..144c9659d1ce3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1276,6 +1276,7 @@ void addInstrRequirements(const MachineInstr &MI,
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16) {
+ Reqs.addCapability(SPIRV::Capability::Float16);
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
const MachineFunction *MF = MI.getMF();
SPIRVGlobalRegistry::FPVariant FPV =
@@ -1288,7 +1289,6 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
}
- Reqs.addCapability(SPIRV::Capability::Float16);
}
break;
}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
index bfc84691f6945..45123eb15d8d7 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -4,8 +4,8 @@
; CHECK-ERROR: LLVM ERROR: OpTypeFloat type with bfloat requires the following SPIR-V extension: SPV_KHR_bfloat16
-; CHECK-DAG: OpCapability BFloat16TypeKHR
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-DAG: OpCapability BFloat16TypeKHR
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
index 5a6e6d88ca6a0..d54b8325c6783 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -1,9 +1,9 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK-DAG: OpCapability BFloat16TypeKHR
; CHECK-DAG: OpCapability BFloat16CooperativeMatrixKHR
-; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
; CHECK: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]] %[[#]] %[[#]] %[[#]] %[[#]]
; CHECK: OpCompositeConstruct %[[#MatTy]] %[[#]]
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
index 7cfe29261f2cd..0943170ae6785 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -1,9 +1,9 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK-DAG: OpCapability BFloat16TypeKHR
; CHECK-DAG: OpCapability BFloat16DotProductKHR
-; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
; CHECK: OpDot
>From 9f8c4ba5e571eddd6992a223257d967b39e2dab3 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 8 Sep 2025 15:35:18 -0700
Subject: [PATCH 6/9] The SPIR-V bfloat type should be emitted as "OpTypeFloat 16 0";
update the implementation accordingly
---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 52 +++++--------------
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 17 ++----
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 2 +-
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 20 +++----
.../extensions/SPV_KHR_bfloat16/bfloat16.ll | 2 +-
.../bfloat16_cooperative_matrix.ll | 2 +-
.../SPV_KHR_bfloat16/bfloat16_dot.ll | 2 +-
7 files changed, 27 insertions(+), 70 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 01c995bccb0ef..c445b7a4a6e95 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -195,11 +195,15 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
- MachineIRBuilder &MIRBuilder) {
+ MachineIRBuilder &MIRBuilder, bool isBfloatTy) {
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
- return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+ auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width);
+ if(isBfloatTy){
+ MIB.addImm(0);
+ }
+ return MIB;
});
}
@@ -1042,7 +1046,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
: getOpTypeInt(Width, MIRBuilder, false);
}
if (Ty->isFloatingPointTy())
- return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
+ return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, Ty->isBFloatTy());
if (Ty->isVoidTy())
return getOpTypeVoid(MIRBuilder);
if (Ty->isVectorTy()) {
@@ -1122,19 +1126,7 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
ExplicitLayoutRequired, EmitIR);
TypesInProcessing.erase(Ty);
-
- // Record the FPVariant of the floating-point registers in the
- // VRegFPVariantMap.
- MachineFunction *MF = &MIRBuilder.getMF();
- Register TypeReg = getSPIRVTypeID(SpirvType);
- if (Ty->isFloatingPointTy()) {
- if (Ty->isBFloatTy()) {
- VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
- } else {
- VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
- }
- }
- VRegToTypeMap[MF][TypeReg] = SpirvType;
+ VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
// TODO: We could end up with two SPIR-V types pointing to the same llvm type.
// Is that a problem?
@@ -1691,15 +1683,11 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
const MachineInstr *NewMI =
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
- auto MIB = BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
- MIRBuilder.getDL(), TII.get(SPIRVOPcode))
- .addDef(createTypeVReg(CurMF->getRegInfo()))
- .addImm(BitWidth);
-
- if (SPIRVOPcode != SPIRV::OpTypeFloat)
- MIB.addImm(0);
-
- return MIB;
+ return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
+ MIRBuilder.getDL(), TII.get(SPIRVOPcode))
+ .addDef(createTypeVReg(CurMF->getRegInfo()))
+ .addImm(BitWidth)
+ .addImm(0);
});
add(Ty, false, NewMI);
return finishCreatingSPIRVType(Ty, NewMI);
@@ -2104,17 +2092,3 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
}
return false;
}
-
-SPIRVGlobalRegistry::FPVariant
-SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
- const MachineFunction *MF) {
- const MachineFunction *Func = MF ? MF : CurMF;
- auto FuncIt = VRegFPVariantMap.find(Func);
- if (FuncIt != VRegFPVariantMap.end()) {
- const DenseMap<Register, FPVariant> &VRegMap = FuncIt->second;
- auto VRegIt = VRegMap.find(VReg);
- if (VRegIt != VRegMap.end())
- return VRegIt->second;
- }
- return FPVariant::NONE;
-}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 1f8c30dc01f7f..2c28484599656 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -29,10 +29,6 @@ using SPIRVType = const MachineInstr;
using StructOffsetDecorator = std::function<void(Register)>;
class SPIRVGlobalRegistry : public SPIRVIRMapping {
-public:
- enum class FPVariant { NONE, IEEE_FLOAT, BRAIN_FLOAT };
-
-private:
// Registers holding values which have types associated with them.
// Initialized upon VReg definition in IRTranslator.
// Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
@@ -92,11 +88,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// map of aliasing decorations to aliasing metadata
std::unordered_map<const MDNode *, MachineInstr *> AliasInstMDMap;
- // Maps floating point Registers to their FPVariant (float type kind), given
- // the MachineFunction.
- DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>>
- VRegFPVariantMap;
-
// Add a new OpTypeXXX instruction without checking for duplicates.
SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AQ,
@@ -431,10 +422,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// structures referring this instruction.
void invalidateMachineInstr(MachineInstr *MI);
- // Return the FPVariant of to the given floating-point regiester.
- FPVariant getFPVariantForVReg(Register VReg,
- const MachineFunction *MF = nullptr);
-
private:
SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
@@ -449,7 +436,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);
- SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
+ SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, bool isBfloatTy);
+
+ SPIRVType *getOpTypeBFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 8d10cd0ffb3dd..496dcba17c10d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -167,7 +167,7 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
"$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
+def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
"$type = OpTypeFloat $width">;
def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
"$type = OpTypeVector $compType $compCount">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 144c9659d1ce3..009de0d951198 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1262,11 +1262,10 @@ void addInstrRequirements(const MachineInstr &MI,
break;
}
case SPIRV::OpDot: {
- const MachineFunction *MF = MI.getMF();
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
- SPIRVGlobalRegistry::FPVariant FPV =
- GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
- if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+ const MachineFunction *MF = MI.getMF();
+ SPIRVType *RegType = GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF);
+ if (RegType->getNumOperands() == 3) {
Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
}
break;
@@ -1277,11 +1276,7 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16) {
Reqs.addCapability(SPIRV::Capability::Float16);
- SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
- const MachineFunction *MF = MI.getMF();
- SPIRVGlobalRegistry::FPVariant FPV =
- GR->getFPVariantForVReg(MI.getOperand(0).getReg(), MF);
- if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+ if (MI.getNumOperands() == 3) {
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
report_fatal_error("OpTypeFloat type with bfloat requires the "
"following SPIR-V extension: SPV_KHR_bfloat16",
@@ -1623,11 +1618,10 @@ void addInstrRequirements(const MachineInstr &MI,
"following SPIR-V extension: SPV_KHR_cooperative_matrix",
false);
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
- const MachineFunction *MF = MI.getMF();
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
- SPIRVGlobalRegistry::FPVariant FPV =
- GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
- if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+ const MachineFunction *MF = MI.getMF();
+ SPIRVType *RegType = GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF);
+ if (RegType->getNumOperands() == 3) {
Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
} else {
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
index 45123eb15d8d7..48514b60ad9b9 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -6,7 +6,7 @@
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK-DAG: OpCapability BFloat16TypeKHR
-; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
index d54b8325c6783..9f52a459616da 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -4,7 +4,7 @@
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK-DAG: OpCapability BFloat16TypeKHR
; CHECK-DAG: OpCapability BFloat16CooperativeMatrixKHR
-; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
; CHECK: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]] %[[#]] %[[#]] %[[#]] %[[#]]
; CHECK: OpCompositeConstruct %[[#MatTy]] %[[#]]
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
index 0943170ae6785..51d212788df72 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -4,7 +4,7 @@
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK-DAG: OpCapability BFloat16TypeKHR
; CHECK-DAG: OpCapability BFloat16DotProductKHR
-; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
; CHECK: OpDot
>From 55ed2d643d0629083384df58136a6235830877ce Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 8 Sep 2025 21:27:45 -0700
Subject: [PATCH 7/9] Add the enum class FPEncoding
---
.../Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h | 5 ++++
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 27 ++++++++++-------
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 4 +--
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 29 +++++++++----------
.../lib/Target/SPIRV/SPIRVSymbolicOperands.td | 28 +++++++++++++++++-
5 files changed, 65 insertions(+), 28 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
index c2c08f8831307..d76180ce97e9e 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
@@ -232,6 +232,11 @@ namespace SpecConstantOpOperands {
#include "SPIRVGenTables.inc"
} // namespace SpecConstantOpOperands
+namespace FPEncoding {
+#define GET_FPEncoding_DECL
+#include "SPIRVGenTables.inc"
+} // namespace FPEncoding
+
struct ExtendedBuiltin {
StringRef Name;
InstructionSet::InstructionSet Set;
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index c445b7a4a6e95..8d4e766545ba3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -195,16 +195,18 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
- MachineIRBuilder &MIRBuilder, bool isBfloatTy) {
- return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
- auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+ MachineIRBuilder &MIRBuilder) {
+ return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width);
- if(isBfloatTy){
- MIB.addImm(0);
- }
- return MIB;
- });
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
+ MachineIRBuilder &MIRBuilder, SPIRV::FPEncoding::FPEncoding FPEncode) {
+ return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+ .addDef(createTypeVReg(MIRBuilder))
+ .addImm(Width)
+ .addImm(FPEncode);
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
@@ -1045,8 +1047,13 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
return Width == 1 ? getOpTypeBool(MIRBuilder)
: getOpTypeInt(Width, MIRBuilder, false);
}
- if (Ty->isFloatingPointTy())
- return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, Ty->isBFloatTy());
+ if (Ty->isFloatingPointTy()) {
+ if(Ty->isBFloatTy()) {
+ return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, SPIRV::FPEncoding::BFloat16KHR);
+ } else {
+ return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
+ }
+ }
if (Ty->isVoidTy())
return getOpTypeVoid(MIRBuilder);
if (Ty->isVectorTy()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 2c28484599656..8104cb8177457 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -436,9 +436,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);
- SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, bool isBfloatTy);
+ SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
- SPIRVType *getOpTypeBFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
+ SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, SPIRV::FPEncoding::FPEncoding FPEncode);
SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 009de0d951198..6732850ab2fa9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1262,10 +1262,9 @@ void addInstrRequirements(const MachineInstr &MI,
break;
}
case SPIRV::OpDot: {
- SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
- const MachineFunction *MF = MI.getMF();
- SPIRVType *RegType = GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF);
- if (RegType->getNumOperands() == 3) {
+ const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+ SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+ if ((TypeDef->getNumOperands() == 3) && (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
}
break;
@@ -1275,14 +1274,15 @@ void addInstrRequirements(const MachineInstr &MI,
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16) {
- Reqs.addCapability(SPIRV::Capability::Float16);
- if (MI.getNumOperands() == 3) {
+ if ((MI.getNumOperands() == 3) && (MI.getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
report_fatal_error("OpTypeFloat type with bfloat requires the "
"following SPIR-V extension: SPV_KHR_bfloat16",
false);
- Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+ } else {
+ Reqs.addCapability(SPIRV::Capability::Float16);
}
}
break;
@@ -1304,8 +1304,9 @@ void addInstrRequirements(const MachineInstr &MI,
assert(MI.getOperand(2).isReg());
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
- if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
- TypeDef->getOperand(1).getImm() == 16)
+ if ((TypeDef->getNumOperands() == 2) &&
+ (TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
+ (TypeDef->getOperand(1).getImm() == 16))
Reqs.addCapability(SPIRV::Capability::Float16Buffer);
break;
}
@@ -1617,14 +1618,12 @@ void addInstrRequirements(const MachineInstr &MI,
"OpTypeCooperativeMatrixKHR type requires the "
"following SPIR-V extension: SPV_KHR_cooperative_matrix",
false);
+ Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
- SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
- const MachineFunction *MF = MI.getMF();
- SPIRVType *RegType = GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF);
- if (RegType->getNumOperands() == 3) {
+ const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+ SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+ if ((TypeDef->getNumOperands() == 3) && (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
- } else {
- Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
}
break;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 9d630356e8ffb..00b636d315d02 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -210,6 +210,7 @@ def CooperativeMatrixLayoutOperand : OperandCategory;
def CooperativeMatrixOperandsOperand : OperandCategory;
def SpecConstantOpOperandsOperand : OperandCategory;
def MatrixMultiplyAccumulateOperandsOperand : OperandCategory;
+def FPEncodingOperand : OperandCategory;
//===----------------------------------------------------------------------===//
// Definition of the Environments
@@ -382,7 +383,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
-defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -2000,3 +2001,28 @@ defm MatrixAPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x400,
defm MatrixBPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixAPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixBPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+
+//===----------------------------------------------------------------------===//
+// Multiclass used to define FPEncoding enum values and at the
+// same time SymbolicOperand entries extensions.
+//===----------------------------------------------------------------------===//
+def FPEncoding : GenericEnum, Operand<i32> {
+ let FilterClass = "FPEncoding";
+ let NameField = "Name";
+ let ValueField = "Value";
+ let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
+}
+
+class FPEncoding<string name, bits<32> value> {
+ string Name = name;
+ bits<32> Value = value;
+}
+
+multiclass FPEncodingOperand<bits<32> value, list<Extension> reqExtensions>{
+ def NAME : FPEncoding<NAME, value>;
+ defm : SymbolicOperandWithRequirements<
+ FPEncodingOperand, value, NAME, 0, 0,
+ reqExtensions, [], []>;
+}
+
+defm BFloat16KHR : FPEncodingOperand<0, [SPV_KHR_bfloat16]>;
>From 295865393b71def7d89b580aacb89008fd55d5e7 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 8 Sep 2025 21:56:15 -0700
Subject: [PATCH 8/9] nit change and update tests
---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 8 ++++++--
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 4 ++--
llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td | 2 +-
.../CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll | 2 +-
.../SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll | 4 +++-
.../SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll | 2 +-
6 files changed, 14 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 8d4e766545ba3..76b276fede679 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -196,17 +196,21 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
MachineIRBuilder &MIRBuilder) {
- return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+ return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+ return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width);
+ });
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
MachineIRBuilder &MIRBuilder, SPIRV::FPEncoding::FPEncoding FPEncode) {
- return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+ return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+ return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width)
.addImm(FPEncode);
+ });
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 6732850ab2fa9..91bfe91bf059f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1279,8 +1279,8 @@ void addInstrRequirements(const MachineInstr &MI,
report_fatal_error("OpTypeFloat type with bfloat requires the "
"following SPIR-V extension: SPV_KHR_bfloat16",
false);
- Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+ Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
} else {
Reqs.addCapability(SPIRV::Capability::Float16);
}
@@ -1618,8 +1618,8 @@ void addInstrRequirements(const MachineInstr &MI,
"OpTypeCooperativeMatrixKHR type requires the "
"following SPIR-V extension: SPV_KHR_cooperative_matrix",
false);
- Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
+ Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
if ((TypeDef->getNumOperands() == 3) && (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 00b636d315d02..501bcb94af2ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -2004,7 +2004,7 @@ defm MatrixBPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x200
//===----------------------------------------------------------------------===//
// Multiclass used to define FPEncoding enum values and at the
-// same time SymbolicOperand entries extensions.
+// same time SymbolicOperand entries with extensions.
//===----------------------------------------------------------------------===//
def FPEncoding : GenericEnum, Operand<i32> {
let FilterClass = "FPEncoding";
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
index 48514b60ad9b9..22668e71fb257 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -4,8 +4,8 @@
; CHECK-ERROR: LLVM ERROR: OpTypeFloat type with bfloat requires the following SPIR-V extension: SPV_KHR_bfloat16
-; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
index 9f52a459616da..d47b5d7440d18 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -1,9 +1,11 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
-; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability CooperativeMatrixKHR
; CHECK-DAG: OpCapability BFloat16CooperativeMatrixKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-DAG: OpExtension "SPV_KHR_cooperative_matrix"
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
; CHECK: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]] %[[#]] %[[#]] %[[#]] %[[#]]
; CHECK: OpCompositeConstruct %[[#MatTy]] %[[#]]
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
index 51d212788df72..4c248fea5c7f1 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -1,9 +1,9 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
-; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK-DAG: OpCapability BFloat16TypeKHR
; CHECK-DAG: OpCapability BFloat16DotProductKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
; CHECK: OpDot
>From 8e9389ddcccbbce15e484cbdf0a89f27a3c07256 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 8 Sep 2025 22:04:42 -0700
Subject: [PATCH 9/9] fix code format issue
---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 11 +++++++----
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 3 ++-
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 9 ++++++---
3 files changed, 15 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 76b276fede679..115766ce886c7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -203,8 +203,10 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
});
}
-SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
- MachineIRBuilder &MIRBuilder, SPIRV::FPEncoding::FPEncoding FPEncode) {
+SPIRVType *
+SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
+ MachineIRBuilder &MIRBuilder,
+ SPIRV::FPEncoding::FPEncoding FPEncode) {
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
.addDef(createTypeVReg(MIRBuilder))
@@ -1052,8 +1054,9 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
: getOpTypeInt(Width, MIRBuilder, false);
}
if (Ty->isFloatingPointTy()) {
- if(Ty->isBFloatTy()) {
- return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, SPIRV::FPEncoding::BFloat16KHR);
+ if (Ty->isBFloatTy()) {
+ return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder,
+ SPIRV::FPEncoding::BFloat16KHR);
} else {
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 8104cb8177457..a648defa0a888 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -438,7 +438,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
- SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, SPIRV::FPEncoding::FPEncoding FPEncode);
+ SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder,
+ SPIRV::FPEncoding::FPEncoding FPEncode);
SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 91bfe91bf059f..01f33ce1e21d9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1264,7 +1264,8 @@ void addInstrRequirements(const MachineInstr &MI,
case SPIRV::OpDot: {
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
- if ((TypeDef->getNumOperands() == 3) && (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
+ if ((TypeDef->getNumOperands() == 3) &&
+ (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
}
break;
@@ -1274,7 +1275,8 @@ void addInstrRequirements(const MachineInstr &MI,
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16) {
- if ((MI.getNumOperands() == 3) && (MI.getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
+ if ((MI.getNumOperands() == 3) &&
+ (MI.getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
report_fatal_error("OpTypeFloat type with bfloat requires the "
"following SPIR-V extension: SPV_KHR_bfloat16",
@@ -1622,7 +1624,8 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
- if ((TypeDef->getNumOperands() == 3) && (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
+ if ((TypeDef->getNumOperands() == 3) &&
+ (TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR)) {
Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
}
break;
More information about the llvm-commits mailing list