[llvm] [SPIRV] Add bfloat support (PR #155645)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 2 14:48:19 PDT 2025
https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/155645
>From 9e4f483ce4d7c0f710270cc21007b73ce56b59e1 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 03:42:59 -0700
Subject: [PATCH 01/16] initial commit for adding the support of bfloat in
machine inst
---
llvm/include/llvm/CodeGen/MachineInstr.h | 3 ++-
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 9 +++++++--
llvm/lib/CodeGen/MachineInstr.cpp | 3 +++
llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 4 +++-
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 4 ++++
llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td | 4 ++++
6 files changed, 23 insertions(+), 4 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/MachineInstr.h b/llvm/include/llvm/CodeGen/MachineInstr.h
index 10a9b1ff1411d..6f692ae32510b 100644
--- a/llvm/include/llvm/CodeGen/MachineInstr.h
+++ b/llvm/include/llvm/CodeGen/MachineInstr.h
@@ -123,8 +123,9 @@ class MachineInstr
NoUSWrap = 1 << 20, // Instruction supports geps
// no unsigned signed wrap.
SameSign = 1 << 21, // Both operands have the same sign.
- InBounds = 1 << 22 // Pointer arithmetic remains inbounds.
+ InBounds = 1 << 22, // Pointer arithmetic remains inbounds.
// Implies NoUSWrap.
+ BFloat16 = 1 << 23 // Instruction with bf16 type
};
private:
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 541269ab6bfce..2a6b66984c8ae 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
- if (containsBF16Type(U))
- return false;
+ // if (containsBF16Type(U))
+ // return false;
const CallInst &CI = cast<CallInst>(U);
const Function *F = CI.getCalledFunction();
@@ -2813,6 +2813,11 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
if (isa<FPMathOperator>(CI))
MIB->copyIRFlags(CI);
+ // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
+ if (containsBF16Type(U)) {
+ MIB->setFlag(MachineInstr::MIFlag::BFloat16);
+ }
+
for (const auto &Arg : enumerate(CI.args())) {
// If this is required to be an immediate, don't materialize it in a
// register.
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 79047f732808a..10ff667bcb522 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,6 +632,9 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
if (I.getMetadata(LLVMContext::MD_unpredictable))
MIFlags |= MachineInstr::MIFlag::Unpredictable;
+ if (I.getType()->isBFloatTy())
+ MIFlags |= MachineInstr::MIFlag::BFloat16;
+
return MIFlags;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e7da5504b2d58..bd13a3bae92cd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -147,7 +147,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_KHR_float_controls2",
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
{"SPV_INTEL_tensor_float32_conversion",
- SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
+ SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
+ {"SPV_KHR_bfloat16",
+ SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 8039cf0c432fa..5bba5cdce3753 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1267,6 +1267,10 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16)
Reqs.addCapability(SPIRV::Capability::Float16);
+ if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+ Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+ }
break;
}
case SPIRV::OpTypeVector: {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d2824ee2d2caf..9d630356e8ffb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -382,6 +382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
+defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
+defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
>From 32e6c031b24083ef668da44d2dd5c9785d9ba12d Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 12:29:09 -0700
Subject: [PATCH 02/16] debugging
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 3 ++-
llvm/lib/CodeGen/MachineInstr.cpp | 4 +++-
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 2 ++
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 10 ++++++----
4 files changed, 13 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 2a6b66984c8ae..3b112805e082c 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2815,6 +2815,7 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
// If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
if (containsBF16Type(U)) {
+ // assert(false && "bfloat detected at the IR Translator");
MIB->setFlag(MachineInstr::MIFlag::BFloat16);
}
@@ -2878,7 +2879,7 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
MIB.addUse(TokenReg, RegState::Implicit);
}
}
-
+
return true;
}
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 10ff667bcb522..1ec8c8a2d90c9 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,8 +632,10 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
if (I.getMetadata(LLVMContext::MD_unpredictable))
MIFlags |= MachineInstr::MIFlag::Unpredictable;
- if (I.getType()->isBFloatTy())
+ if (I.getType()->getScalarType()->isBFloatTy()) {
MIFlags |= MachineInstr::MIFlag::BFloat16;
+ // assert(false && "bfloat detected at the MachineInstr");
+ }
return MIFlags;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index f0b938d681dba..923c92c28d3f6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -169,6 +169,8 @@ def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness)
"$type = OpTypeInt $width $signedness">;
def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
"$type = OpTypeFloat $width">;
+def OpTypeBFloat: Op<40, (outs TYPE:$type), (ins i32imm:$width),
+ "$type = OpTypeFloat $width 0">;
def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
"$type = OpTypeVector $compType $compCount">;
def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 5bba5cdce3753..19c938c96a9e8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1267,10 +1267,12 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16)
Reqs.addCapability(SPIRV::Capability::Float16);
- if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
- Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
- Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
- }
+ break;
+ }
+ case SPIRV::OpTypeBFloat: {
+ assert(false && "bfloat detected at the file");
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+ Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
break;
}
case SPIRV::OpTypeVector: {
>From eccbcf845dd59f4a822a5fc9487d48b4e62d7a2c Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 14:53:45 -0700
Subject: [PATCH 03/16] Debugging
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 19 ++++++++++++-------
llvm/lib/CodeGen/GlobalISel/Legalizer.cpp | 4 ++++
llvm/lib/CodeGen/MachineInstr.cpp | 14 ++++++++++++--
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 3 ++-
.../SPIRV/SPIRVPreLegalizerCombiner.cpp | 6 ++++++
5 files changed, 36 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 3b112805e082c..4aece388749f8 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2195,6 +2195,8 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
if (translateSimpleIntrinsic(CI, ID, MIRBuilder))
return true;
+ LLVM_DEBUG(dbgs() << "IRTranslator translateKnownIntrinsic for CI: " << CI << '\n');
+ LLVM_DEBUG(dbgs() << "IRTranslator translateKnownIntrinsic for ID: " << ID << '\n');
switch (ID) {
default:
break;
@@ -2581,9 +2583,10 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
Value *FpValue = CI.getOperand(0);
ConstantInt *TestMaskValue = cast<ConstantInt>(CI.getOperand(1));
+ uint32_t Flags = MachineInstr::copyFlagsFromInstruction(CI);
MIRBuilder
.buildInstr(TargetOpcode::G_IS_FPCLASS, {getOrCreateVReg(CI)},
- {getOrCreateVReg(*FpValue)})
+ {getOrCreateVReg(*FpValue)}, Flags)
.addImm(TestMaskValue->getZExtValue());
return true;
@@ -2813,12 +2816,6 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
if (isa<FPMathOperator>(CI))
MIB->copyIRFlags(CI);
- // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
- if (containsBF16Type(U)) {
- // assert(false && "bfloat detected at the IR Translator");
- MIB->setFlag(MachineInstr::MIFlag::BFloat16);
- }
-
for (const auto &Arg : enumerate(CI.args())) {
// If this is required to be an immediate, don't materialize it in a
// register.
@@ -2879,6 +2876,13 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
MIB.addUse(TokenReg, RegState::Implicit);
}
}
+
+ // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
+ if (containsBF16Type(U)) {
+ dbgs() << "Flagged at IRTranslator: " << *MIB.getInstr() << "\n";
+ MIB.getInstr()->setFlag(MachineInstr::MIFlag::BFloat16);
+ }
+ MIB->copyIRFlags(CI);
return true;
}
@@ -4161,6 +4165,7 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
// Translate any debug-info attached to the instruction.
translateDbgInfo(Inst, *CurBuilder);
+ LLVM_DEBUG(dbgs() << "Inst at IRTranslator: " << Inst << "\n");
if (translate(Inst))
continue;
diff --git a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
index aef16b5f33af4..6013124558106 100644
--- a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
@@ -192,6 +192,10 @@ Legalizer::legalizeMachineFunction(MachineFunction &MF, const LegalizerInfo &LI,
if (MBB->empty())
continue;
for (MachineInstr &MI : *MBB) {
+ dbgs() << "Legalizer sees: " << MI << "\n";
+ if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
+ assert(false && "bfloat detected at the Legalizer");
+ }
// Only legalize pre-isel generic instructions: others don't have types
// and are assumed to be legal.
if (!isPreISelGenericOpcode(MI.getOpcode()))
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 1ec8c8a2d90c9..bbeeed76631c6 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,9 +632,19 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
if (I.getMetadata(LLVMContext::MD_unpredictable))
MIFlags |= MachineInstr::MIFlag::Unpredictable;
- if (I.getType()->getScalarType()->isBFloatTy()) {
+ auto *Sel = dyn_cast<SelectInst>(&I);
+ bool BFloatOpnd = !Sel && I.getType()->getScalarType()->isBFloatTy();
+
+ if (isa<CallInst>(&I)) {
+ for (const Value *Op : I.operands()) {
+ Type *OpTy = Op->getType();
+ BFloatOpnd |= OpTy->getScalarType()->isBFloatTy();
+ }
+ }
+
+ if (BFloatOpnd) {
MIFlags |= MachineInstr::MIFlag::BFloat16;
- // assert(false && "bfloat detected at the MachineInstr");
+ LLVM_DEBUG(dbgs() << "bfloat detected at the MachineInstr" << "\n");
}
return MIFlags;
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 19c938c96a9e8..ab26f0208d758 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1262,6 +1262,7 @@ void addInstrRequirements(const MachineInstr &MI,
break;
}
case SPIRV::OpTypeFloat: {
+ assert(false && "bfloat detected at the spirv module analysis");
unsigned BitWidth = MI.getOperand(1).getImm();
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
@@ -1270,7 +1271,7 @@ void addInstrRequirements(const MachineInstr &MI,
break;
}
case SPIRV::OpTypeBFloat: {
- assert(false && "bfloat detected at the file");
+ assert(false && "bfloat detected at the spirv module analysis");
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
break;
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
index 83567519355a8..a964db0b02746 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
@@ -156,6 +156,12 @@ SPIRVPreLegalizerCombinerImpl::SPIRVPreLegalizerCombinerImpl(
}
bool SPIRVPreLegalizerCombinerImpl::tryCombineAll(MachineInstr &MI) const {
+ // Inst combiner is disabled for instructions
+ // with bfloat type.
+ // TODO: add bfloat-specific inst combiner rules
+ if (MI.getFlag(MachineInstr::BFloat16))
+ return false;
+
return tryCombineAllImpl(MI);
}
>From 2b080940fd7e57c6dc714aedca113b80e8871d54 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 28 Aug 2025 05:55:29 -0700
Subject: [PATCH 04/16] continue debugging
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 19 +++++++++++--------
llvm/lib/CodeGen/GlobalISel/Legalizer.cpp | 4 ++--
llvm/lib/CodeGen/MachineInstr.cpp | 4 +++-
3 files changed, 16 insertions(+), 11 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 4aece388749f8..a396839c5061b 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2195,8 +2195,8 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
if (translateSimpleIntrinsic(CI, ID, MIRBuilder))
return true;
- LLVM_DEBUG(dbgs() << "IRTranslator translateKnownIntrinsic for CI: " << CI << '\n');
- LLVM_DEBUG(dbgs() << "IRTranslator translateKnownIntrinsic for ID: " << ID << '\n');
+ LLVM_DEBUG(dbgs() << "[BFLOAT] IRTranslator translateKnownIntrinsic for CI: " << CI << '\n');
+ LLVM_DEBUG(dbgs() << "[BFLOAT] IRTranslator translateKnownIntrinsic for ID: " << ID << '\n');
switch (ID) {
default:
break;
@@ -2878,11 +2878,11 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
}
// If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
+ MIB->copyIRFlags(CI);
if (containsBF16Type(U)) {
- dbgs() << "Flagged at IRTranslator: " << *MIB.getInstr() << "\n";
+ dbgs() << "[BFLOAT] Flagged at IRTranslator: " << *MIB.getInstr() << "\n";
MIB.getInstr()->setFlag(MachineInstr::MIFlag::BFloat16);
}
- MIB->copyIRFlags(CI);
return true;
}
@@ -4105,8 +4105,8 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
if (CLI->fallBackToDAGISel(*MF)) {
OptimizationRemarkMissed R("gisel-irtranslator", "GISelFailure",
F.getSubprogram(), &F.getEntryBlock());
- R << "unable to lower function: "
- << ore::NV("Prototype", F.getFunctionType());
+ R << "u "
+ << ore::NV("Prototype", F.getnable to lower function:FunctionType());
reportTranslationError(*MF, *TPC, *ORE, R);
return false;
}
@@ -4118,6 +4118,7 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
continue; // Don't handle zero sized types.
ArrayRef<Register> VRegs = getOrCreateVRegs(Arg);
VRegArgs.push_back(VRegs);
+ LLVM_DEBUG(dbgs() << "[BFLOAT] Created vreg in IRTranslator: " << VRegs);
if (Arg.hasSwiftErrorAttr()) {
assert(VRegs.size() == 1 && "Too many vregs for Swift error");
@@ -4165,10 +4166,12 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
// Translate any debug-info attached to the instruction.
translateDbgInfo(Inst, *CurBuilder);
- LLVM_DEBUG(dbgs() << "Inst at IRTranslator: " << Inst << "\n");
+ LLVM_DEBUG(dbgs() << "[BFLOAT] Inst at IRTranslator: " << Inst << "\n");
- if (translate(Inst))
+ if (translate(Inst)) {
+ LLVM_DEBUG(dbgs() << "[BFLOAT] Inst at IRTranslator after translate(Inst): " << Inst << "\n");
continue;
+ }
OptimizationRemarkMissed R("gisel-irtranslator", "GISelFailure",
Inst.getDebugLoc(), BB);
diff --git a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
index 6013124558106..fc1fede1a6d1c 100644
--- a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
@@ -192,9 +192,9 @@ Legalizer::legalizeMachineFunction(MachineFunction &MF, const LegalizerInfo &LI,
if (MBB->empty())
continue;
for (MachineInstr &MI : *MBB) {
- dbgs() << "Legalizer sees: " << MI << "\n";
+ LLVM_DEBUG(dbgs() << "[BFLOAT] Legalizer sees: " << MI << "\n");
if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
- assert(false && "bfloat detected at the Legalizer");
+ LLVM_DEBUG(dbgs() << "[BFLOAT] bfloat detected at the Legalizer \n");
}
// Only legalize pre-isel generic instructions: others don't have types
// and are assumed to be legal.
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index bbeeed76631c6..6b7efe671b572 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -61,6 +61,8 @@
#include <cstring>
#include <utility>
+#define DEBUG_TYPE "irtranslator"
+
using namespace llvm;
static cl::opt<bool>
@@ -644,7 +646,7 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
if (BFloatOpnd) {
MIFlags |= MachineInstr::MIFlag::BFloat16;
- LLVM_DEBUG(dbgs() << "bfloat detected at the MachineInstr" << "\n");
+ LLVM_DEBUG(dbgs() << "[BFLOAT] bfloat detected at the MachineInstr \n");
}
return MIFlags;
>From 34ea97b632a5a71cec7311539c3a78b712fa542a Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 28 Aug 2025 08:57:30 -0700
Subject: [PATCH 05/16] continue debugging
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 11 +++++------
llvm/lib/CodeGen/GlobalISel/Legalizer.cpp | 3 ++-
2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index a396839c5061b..1649b8b4a41b3 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2046,9 +2046,10 @@ bool IRTranslator::translateSimpleIntrinsic(const CallInst &CI,
MachineIRBuilder &MIRBuilder) {
unsigned Op = getSimpleIntrinsicOpcode(ID);
+ LLVM_DEBUG(dbgs() << "[BFLOAT] translateSimpleIntrinsic is called for ID:" << ID << '\n');
// Is this a simple intrinsic?
- if (Op == Intrinsic::not_intrinsic)
+ if (Op == Intrinsic::not_intrinsic){}
return false;
// Yes. Let's translate it.
@@ -2768,8 +2769,6 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
- // if (containsBF16Type(U))
- // return false;
const CallInst &CI = cast<CallInst>(U);
const Function *F = CI.getCalledFunction();
@@ -4105,8 +4104,8 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
if (CLI->fallBackToDAGISel(*MF)) {
OptimizationRemarkMissed R("gisel-irtranslator", "GISelFailure",
F.getSubprogram(), &F.getEntryBlock());
- R << "u "
- << ore::NV("Prototype", F.getnable to lower function:FunctionType());
+ R << "unable to lower function: "
+ << ore::NV("Prototype", F.getFunctionType());
reportTranslationError(*MF, *TPC, *ORE, R);
return false;
}
@@ -4118,7 +4117,7 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
continue; // Don't handle zero sized types.
ArrayRef<Register> VRegs = getOrCreateVRegs(Arg);
VRegArgs.push_back(VRegs);
- LLVM_DEBUG(dbgs() << "[BFLOAT] Created vreg in IRTranslator: " << VRegs);
+ // LLVM_DEBUG(dbgs() << "[BFLOAT] Created vreg in IRTranslator: " << VRegs);
if (Arg.hasSwiftErrorAttr()) {
assert(VRegs.size() == 1 && "Too many vregs for Swift error");
diff --git a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
index fc1fede1a6d1c..77289d138bc9c 100644
--- a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
@@ -195,7 +195,8 @@ Legalizer::legalizeMachineFunction(MachineFunction &MF, const LegalizerInfo &LI,
LLVM_DEBUG(dbgs() << "[BFLOAT] Legalizer sees: " << MI << "\n");
if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
LLVM_DEBUG(dbgs() << "[BFLOAT] bfloat detected at the Legalizer \n");
- }
+ }
+
// Only legalize pre-isel generic instructions: others don't have types
// and are assumed to be legal.
if (!isPreISelGenericOpcode(MI.getOpcode()))
>From eea9fd72db1e3fe7d6e83ecd9c547db75f6d144a Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 28 Aug 2025 13:45:57 -0700
Subject: [PATCH 06/16] use a new method of adding the floating point
information to optypefloat
---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 36 +++++++++++++++----
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 2 +-
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 5 +--
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 7 ----
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 2 ++
5 files changed, 34 insertions(+), 18 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cfe24c84941a9..cc524ee7e3cb1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -195,11 +195,12 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
- MachineIRBuilder &MIRBuilder) {
+ MachineIRBuilder &MIRBuilder, uint32_t Fptype) {
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
.addDef(createTypeVReg(MIRBuilder))
- .addImm(Width);
+ .addImm(Width)
+ .addImm(Fptype);
});
}
@@ -1041,8 +1042,11 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
return Width == 1 ? getOpTypeBool(MIRBuilder)
: getOpTypeInt(Width, MIRBuilder, false);
}
- if (Ty->isFloatingPointTy())
- return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
+ if (Ty->isFloatingPointTy()) {
+ if (Ty->isBFloatTy())
+ return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, 1);
+ return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, 0);
+ }
if (Ty->isVoidTy())
return getOpTypeVoid(MIRBuilder);
if (Ty->isVectorTy()) {
@@ -1677,14 +1681,33 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
return MI;
MachineBasicBlock &DepMBB = I.getMF()->front();
MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
- const MachineInstr *NewMI =
- createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+ const MachineInstr *NewMI = nullptr;
+ if(SPIRVOPcode == SPIRV::OpTypeFloat) {
+ if (Ty->isBFloatTy()) {
+ NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+ return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
+ MIRBuilder.getDL(), TII.get(SPIRVOPcode))
+ .addDef(createTypeVReg(CurMF->getRegInfo()))
+ .addImm(BitWidth)
+ .addImm(1);
+ });
+ } else {
+ NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
MIRBuilder.getDL(), TII.get(SPIRVOPcode))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(BitWidth)
.addImm(0);
});
+ }
+ } else {
+ NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+ return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
+ MIRBuilder.getDL(), TII.get(SPIRVOPcode))
+ .addDef(createTypeVReg(CurMF->getRegInfo()))
+ .addImm(BitWidth);
+ });
+ }
add(Ty, false, NewMI);
return finishCreatingSPIRVType(Ty, NewMI);
}
@@ -1718,6 +1741,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
default:
llvm_unreachable("Bit width is of unexpected size.");
}
+
return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 7ef812828b7cc..a2a2e7b289c5d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -436,7 +436,7 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);
- SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
+ SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, uint32_t Fptype);
SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 923c92c28d3f6..a823df89eab6c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -167,10 +167,7 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
"$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
- "$type = OpTypeFloat $width">;
-def OpTypeBFloat: Op<40, (outs TYPE:$type), (ins i32imm:$width),
- "$type = OpTypeFloat $width 0">;
+def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, i32imm:$fpencode), "$type = OpTypeFloat $width $fpencode">;
def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
"$type = OpTypeVector $compType $compCount">;
def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index ab26f0208d758..8039cf0c432fa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1262,7 +1262,6 @@ void addInstrRequirements(const MachineInstr &MI,
break;
}
case SPIRV::OpTypeFloat: {
- assert(false && "bfloat detected at the spirv module analysis");
unsigned BitWidth = MI.getOperand(1).getImm();
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
@@ -1270,12 +1269,6 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Float16);
break;
}
- case SPIRV::OpTypeBFloat: {
- assert(false && "bfloat detected at the spirv module analysis");
- Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
- Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
- break;
- }
case SPIRV::OpTypeVector: {
unsigned NumComponents = MI.getOperand(2).getImm();
if (NumComponents == 8 || NumComponents == 16)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 1a08c6ac0dcaf..33cb8471d747e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -540,6 +540,7 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
Def->getOpcode() != SPIRV::ASSIGN_TYPE)
insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
MF.getRegInfo());
+ LLVM_DEBUG(dbgs() << "[BFLOAT] SPIRVPreLegalizer is lowering spv_assign_ptr_type to" << MI << '\n');
ToErase.push_back(&MI);
} else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
Register Reg = MI.getOperand(1).getReg();
@@ -550,6 +551,7 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
Def->getOpcode() != SPIRV::ASSIGN_TYPE)
insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
+ LLVM_DEBUG(dbgs() << "[BFLOAT] SPIRVPreLegalizer is lowering spv_assign_type to" << MI << '\n');
ToErase.push_back(&MI);
} else if (MIOp == TargetOpcode::FAKE_USE && MI.getNumOperands() > 0) {
MachineInstr *MdMI = MI.getPrevNode();
>From cfaa1db3cf1d501bf200531c63e0cdce758187ee Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 28 Aug 2025 13:54:45 -0700
Subject: [PATCH 07/16] remove debug
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 9 +--------
llvm/lib/CodeGen/GlobalISel/Legalizer.cpp | 5 -----
llvm/lib/CodeGen/MachineInstr.cpp | 1 -
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 2 +-
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 2 --
5 files changed, 2 insertions(+), 17 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 1649b8b4a41b3..e29239c7706c2 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2046,10 +2046,9 @@ bool IRTranslator::translateSimpleIntrinsic(const CallInst &CI,
MachineIRBuilder &MIRBuilder) {
unsigned Op = getSimpleIntrinsicOpcode(ID);
- LLVM_DEBUG(dbgs() << "[BFLOAT] translateSimpleIntrinsic is called for ID:" << ID << '\n');
// Is this a simple intrinsic?
- if (Op == Intrinsic::not_intrinsic){}
+ if (Op == Intrinsic::not_intrinsic)
return false;
// Yes. Let's translate it.
@@ -2196,8 +2195,6 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
if (translateSimpleIntrinsic(CI, ID, MIRBuilder))
return true;
- LLVM_DEBUG(dbgs() << "[BFLOAT] IRTranslator translateKnownIntrinsic for CI: " << CI << '\n');
- LLVM_DEBUG(dbgs() << "[BFLOAT] IRTranslator translateKnownIntrinsic for ID: " << ID << '\n');
switch (ID) {
default:
break;
@@ -2879,7 +2876,6 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
// If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
MIB->copyIRFlags(CI);
if (containsBF16Type(U)) {
- dbgs() << "[BFLOAT] Flagged at IRTranslator: " << *MIB.getInstr() << "\n";
MIB.getInstr()->setFlag(MachineInstr::MIFlag::BFloat16);
}
@@ -4117,7 +4113,6 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
continue; // Don't handle zero sized types.
ArrayRef<Register> VRegs = getOrCreateVRegs(Arg);
VRegArgs.push_back(VRegs);
- // LLVM_DEBUG(dbgs() << "[BFLOAT] Created vreg in IRTranslator: " << VRegs);
if (Arg.hasSwiftErrorAttr()) {
assert(VRegs.size() == 1 && "Too many vregs for Swift error");
@@ -4165,10 +4160,8 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
// Translate any debug-info attached to the instruction.
translateDbgInfo(Inst, *CurBuilder);
- LLVM_DEBUG(dbgs() << "[BFLOAT] Inst at IRTranslator: " << Inst << "\n");
if (translate(Inst)) {
- LLVM_DEBUG(dbgs() << "[BFLOAT] Inst at IRTranslator after translate(Inst): " << Inst << "\n");
continue;
}
diff --git a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
index 77289d138bc9c..aef16b5f33af4 100644
--- a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
@@ -192,11 +192,6 @@ Legalizer::legalizeMachineFunction(MachineFunction &MF, const LegalizerInfo &LI,
if (MBB->empty())
continue;
for (MachineInstr &MI : *MBB) {
- LLVM_DEBUG(dbgs() << "[BFLOAT] Legalizer sees: " << MI << "\n");
- if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
- LLVM_DEBUG(dbgs() << "[BFLOAT] bfloat detected at the Legalizer \n");
- }
-
// Only legalize pre-isel generic instructions: others don't have types
// and are assumed to be legal.
if (!isPreISelGenericOpcode(MI.getOpcode()))
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 6b7efe671b572..159b27dd46b53 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -646,7 +646,6 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
if (BFloatOpnd) {
MIFlags |= MachineInstr::MIFlag::BFloat16;
- LLVM_DEBUG(dbgs() << "[BFLOAT] bfloat detected at the MachineInstr \n");
}
return MIFlags;
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cc524ee7e3cb1..0249444ea2328 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1741,7 +1741,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
default:
llvm_unreachable("Bit width is of unexpected size.");
}
-
+
return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 33cb8471d747e..1a08c6ac0dcaf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -540,7 +540,6 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
Def->getOpcode() != SPIRV::ASSIGN_TYPE)
insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
MF.getRegInfo());
- LLVM_DEBUG(dbgs() << "[BFLOAT] SPIRVPreLegalizer is lowering spv_assign_ptr_type to" << MI << '\n');
ToErase.push_back(&MI);
} else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
Register Reg = MI.getOperand(1).getReg();
@@ -551,7 +550,6 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
Def->getOpcode() != SPIRV::ASSIGN_TYPE)
insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
- LLVM_DEBUG(dbgs() << "[BFLOAT] SPIRVPreLegalizer is lowering spv_assign_type to" << MI << '\n');
ToErase.push_back(&MI);
} else if (MIOp == TargetOpcode::FAKE_USE && MI.getNumOperands() > 0) {
MachineInstr *MdMI = MI.getPrevNode();
>From 75937a1a3b2c41ca75a33c3cdab9cc4c8a9c305a Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 29 Aug 2025 09:04:45 -0700
Subject: [PATCH 08/16] add the testing
---
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 21 ++++++++++++++++--
.../extensions/SPV_KHR_bfloat16/bfloat16.ll | 20 +++++++++++++++++
.../SPV_KHR_bfloat16/bfloat16_dot.ll | 22 +++++++++++++++++++
3 files changed, 61 insertions(+), 2 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 8039cf0c432fa..239221d3b3abd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1261,12 +1261,29 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Int8);
break;
}
+ case SPIRV::OpDot: {
+ const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+ SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+ LLVM_DEBUG(dbgs() << "[BFLOAT] The OpDot MI is " << MI << '\n');
+ LLVM_DEBUG(dbgs() << "[BFLOAT] TypeDef found is " << TypeDef << '\n');
+ if ((TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
+ (TypeDef->getOperand(1).getImm() == 16)) {
+ Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
+ }
+ }
case SPIRV::OpTypeFloat: {
unsigned BitWidth = MI.getOperand(1).getImm();
+ unsigned Fptype = MI.getOperand(2).getImm();
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
- else if (BitWidth == 16)
- Reqs.addCapability(SPIRV::Capability::Float16);
+ else if (BitWidth == 16) {
+ if(Fptype == 1) {
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+ Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+ } else {
+ Reqs.addCapability(SPIRV::Capability::Float16);
+ }
+ }
break;
}
case SPIRV::OpTypeVector: {
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
new file mode 100644
index 0000000000000..8a917dc219fa0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -0,0 +1,20 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+; XFAIL: *
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-ERROR
+
+; CHECK-ERROR: BFloat16TypeKHR requires the following SPIR-V extension: SPV_KHR_subgroup_rotate
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 1
+; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
+
+define spir_kernel void @test() {
+entry:
+ %addr1 = alloca bfloat
+ %addr2 = alloca <2 x bfloat>
+ %data1 = load bfloat, ptr %addr1
+ %data2 = load <2 x bfloat>, ptr %addr2
+ ret void
+}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
new file mode 100644
index 0000000000000..337eda538c196
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -0,0 +1,22 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+; XFAIL: *
+
+; CHECK-SPIRV-DAG: OpCapability BFloat16TypeKHR
+; CHECK-SPIRV-DAG: OpCapability BFloat16DotProductKHR
+; CHECK-SPIRV-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-SPIRV: %[[#BFLOAT:]] = OpTypeFloat 16 0
+; CHECK-SPIRV: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
+; CHECK-SPIRV: OpDot
+
+declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>)
+
+define spir_kernel void @test() {
+entry:
+ %addrA = alloca <2 x bfloat>
+ %addrB = alloca <2 x bfloat>
+ %dataA = load <2 x bfloat>, ptr %addrA
+ %dataB = load <2 x bfloat>, ptr %addrB
+ %call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
+ ret void
+}
>From cd7bbfbd51f95939c8cfdb6fa911a2c19724ddca Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 29 Aug 2025 13:44:06 -0700
Subject: [PATCH 09/16] remove adding bfloat flag to MIR
---
llvm/include/llvm/CodeGen/MachineInstr.h | 3 +--
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 12 ++----------
llvm/lib/CodeGen/MachineInstr.cpp | 13 -------------
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 3 +--
llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp | 6 ------
.../SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll | 2 +-
6 files changed, 5 insertions(+), 34 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/MachineInstr.h b/llvm/include/llvm/CodeGen/MachineInstr.h
index 6f692ae32510b..10a9b1ff1411d 100644
--- a/llvm/include/llvm/CodeGen/MachineInstr.h
+++ b/llvm/include/llvm/CodeGen/MachineInstr.h
@@ -123,9 +123,8 @@ class MachineInstr
NoUSWrap = 1 << 20, // Instruction supports geps
// no unsigned signed wrap.
SameSign = 1 << 21, // Both operands have the same sign.
- InBounds = 1 << 22, // Pointer arithmetic remains inbounds.
+ InBounds = 1 << 22 // Pointer arithmetic remains inbounds.
// Implies NoUSWrap.
- BFloat16 = 1 << 23 // Instruction with bf16 type
};
private:
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index e29239c7706c2..d86d3551b2d81 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2581,10 +2581,9 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
Value *FpValue = CI.getOperand(0);
ConstantInt *TestMaskValue = cast<ConstantInt>(CI.getOperand(1));
- uint32_t Flags = MachineInstr::copyFlagsFromInstruction(CI);
MIRBuilder
.buildInstr(TargetOpcode::G_IS_FPCLASS, {getOrCreateVReg(CI)},
- {getOrCreateVReg(*FpValue)}, Flags)
+ {getOrCreateVReg(*FpValue)})
.addImm(TestMaskValue->getZExtValue());
return true;
@@ -2873,12 +2872,6 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
}
}
- // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
- MIB->copyIRFlags(CI);
- if (containsBF16Type(U)) {
- MIB.getInstr()->setFlag(MachineInstr::MIFlag::BFloat16);
- }
-
return true;
}
@@ -4161,9 +4154,8 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
// Translate any debug-info attached to the instruction.
translateDbgInfo(Inst, *CurBuilder);
- if (translate(Inst)) {
+ if (translate(Inst))
continue;
- }
OptimizationRemarkMissed R("gisel-irtranslator", "GISelFailure",
Inst.getDebugLoc(), BB);
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 159b27dd46b53..29b30cf914208 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -61,8 +61,6 @@
#include <cstring>
#include <utility>
-#define DEBUG_TYPE "irtranslator"
-
using namespace llvm;
static cl::opt<bool>
@@ -637,17 +635,6 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
auto *Sel = dyn_cast<SelectInst>(&I);
bool BFloatOpnd = !Sel && I.getType()->getScalarType()->isBFloatTy();
- if (isa<CallInst>(&I)) {
- for (const Value *Op : I.operands()) {
- Type *OpTy = Op->getType();
- BFloatOpnd |= OpTy->getScalarType()->isBFloatTy();
- }
- }
-
- if (BFloatOpnd) {
- MIFlags |= MachineInstr::MIFlag::BFloat16;
- }
-
return MIFlags;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 239221d3b3abd..44024bde31067 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1264,12 +1264,11 @@ void addInstrRequirements(const MachineInstr &MI,
case SPIRV::OpDot: {
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
- LLVM_DEBUG(dbgs() << "[BFLOAT] The OpDot MI is " << MI << '\n');
- LLVM_DEBUG(dbgs() << "[BFLOAT] TypeDef found is " << TypeDef << '\n');
if ((TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
(TypeDef->getOperand(1).getImm() == 16)) {
Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
}
+ break;
}
case SPIRV::OpTypeFloat: {
unsigned BitWidth = MI.getOperand(1).getImm();
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
index a964db0b02746..83567519355a8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
@@ -156,12 +156,6 @@ SPIRVPreLegalizerCombinerImpl::SPIRVPreLegalizerCombinerImpl(
}
bool SPIRVPreLegalizerCombinerImpl::tryCombineAll(MachineInstr &MI) const {
- // Inst combiner is disabled for instructions
- // with bfloat type.
- // TODO: add bfloat-specific inst combiner rules
- if (MI.getFlag(MachineInstr::BFloat16))
- return false;
-
return tryCombineAllImpl(MI);
}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
index 8a917dc219fa0..95eb5a8348d67 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -17,4 +17,4 @@ entry:
%data1 = load bfloat, ptr %addr1
%data2 = load <2 x bfloat>, ptr %addr2
ret void
-}
\ No newline at end of file
+}
>From 7b5f9642662abb530481d32eae1ddfb4d5bdb8f1 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 29 Aug 2025 13:48:31 -0700
Subject: [PATCH 10/16] nit
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 4 +++-
llvm/lib/CodeGen/MachineInstr.cpp | 3 ---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 1 -
3 files changed, 3 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index d86d3551b2d81..541269ab6bfce 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,6 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
+ if (containsBF16Type(U))
+ return false;
const CallInst &CI = cast<CallInst>(U);
const Function *F = CI.getCalledFunction();
@@ -4154,7 +4156,7 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
// Translate any debug-info attached to the instruction.
translateDbgInfo(Inst, *CurBuilder);
- if (translate(Inst))
+ if (translate(Inst))
continue;
OptimizationRemarkMissed R("gisel-irtranslator", "GISelFailure",
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 29b30cf914208..79047f732808a 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,9 +632,6 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
if (I.getMetadata(LLVMContext::MD_unpredictable))
MIFlags |= MachineInstr::MIFlag::Unpredictable;
- auto *Sel = dyn_cast<SelectInst>(&I);
- bool BFloatOpnd = !Sel && I.getType()->getScalarType()->isBFloatTy();
-
return MIFlags;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 0249444ea2328..e021920f46982 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1741,7 +1741,6 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
default:
llvm_unreachable("Bit width is of unexpected size.");
}
-
return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
}
>From 05246aaa6558f1561cfa4419f1e70f3028a66a9b Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Sun, 31 Aug 2025 10:22:36 -0700
Subject: [PATCH 11/16] change of approach: add a map in the global registry
class that maps each vreg to an is-bfloat flag
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 4 ++--
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 19 ++++++++++++++++++-
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 6 ++++++
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 18 ++++++++++--------
4 files changed, 36 insertions(+), 11 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 541269ab6bfce..04b3d6f30e023 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
- if (containsBF16Type(U))
- return false;
+ // if (containsBF16Type(U))
+ // return false;
const CallInst &CI = cast<CallInst>(U);
const Function *F = CI.getCalledFunction();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index e021920f46982..1a1624366bbf2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1125,9 +1125,15 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
TypesInProcessing.insert(Ty);
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
ExplicitLayoutRequired, EmitIR);
- TypesInProcessing.erase(Ty);
+
+ if(Ty->isFloatingPointTy() && Ty->isBFloatTy()){
+ VRegIsBFloatMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = true;
+ }
+
VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
+ TypesInProcessing.erase(Ty);
+
// TODO: We could end up with two SPIR-V types pointing to the same llvm type.
// Is that a problem?
SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
@@ -1162,6 +1168,17 @@ SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
return nullptr;
}
+bool SPIRVGlobalRegistry::getIsBFloatForVReg(Register VReg,
+ const MachineFunction *MF) {
+ auto t = VRegToTypeMap.find(MF ? MF : CurMF);
+ if (t != VRegToTypeMap.end()) {
+ auto tt = t->second.find(VReg);
+ if (tt != t->second.end())
+ return tt->second;
+ }
+ return false;
+}
+
SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg,
MachineFunction *MF) {
if (!MF)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index a2a2e7b289c5d..2629cd515cca6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -37,6 +37,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// type-declaring ones).
DenseMap<const MachineFunction *, DenseMap<Register, SPIRVType *>>
VRegToTypeMap;
+
+ DenseMap<const MachineFunction *, DenseMap<Register, bool>>
+ VRegIsBFloatMap;
DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
@@ -340,6 +343,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// functions.
SPIRVType *getSPIRVTypeForVReg(Register VReg,
const MachineFunction *MF = nullptr) const;
+
+ bool getIsBFloatForVReg(Register VReg,
+ const MachineFunction *MF = nullptr);
// Return the result type of the instruction defining the register.
SPIRVType *getResultType(Register VReg, MachineFunction *MF = nullptr);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 44024bde31067..38ffc51ad63be 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1262,21 +1262,23 @@ void addInstrRequirements(const MachineInstr &MI,
break;
}
case SPIRV::OpDot: {
- const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
- SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
- if ((TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
- (TypeDef->getOperand(1).getImm() == 16)) {
- Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
- }
+ const MachineFunction *MF = MI.getMF();
+ SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+ bool isBfloat = GR->getIsBFloatForVReg(MI.getOperand(1).getReg(), MF);
+ if (isBfloat) {
+ Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
+ }
break;
}
case SPIRV::OpTypeFloat: {
unsigned BitWidth = MI.getOperand(1).getImm();
- unsigned Fptype = MI.getOperand(2).getImm();
+ SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+ const MachineFunction *MF = MI.getMF();
+ bool isBfloat = GR->getIsBFloatForVReg(MI.getOperand(0).getReg(), MF);
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16) {
- if(Fptype == 1) {
+ if(isBfloat) {
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
} else {
>From 7743453bd83978d620d6ea4f72890f98d0abbfd7 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Sun, 31 Aug 2025 15:37:52 -0700
Subject: [PATCH 12/16] nit change
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 3 --
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 50 +++++++++----------
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 24 ++++++---
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 3 +-
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 12 ++---
5 files changed, 50 insertions(+), 42 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 04b3d6f30e023..8a4b1afec3d26 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,9 +2765,6 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
- // if (containsBF16Type(U))
- // return false;
-
const CallInst &CI = cast<CallInst>(U);
const Function *F = CI.getCalledFunction();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 1a1624366bbf2..633ef770ec202 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -195,12 +195,11 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
- MachineIRBuilder &MIRBuilder, uint32_t Fptype) {
+ MachineIRBuilder &MIRBuilder) {
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
.addDef(createTypeVReg(MIRBuilder))
- .addImm(Width)
- .addImm(Fptype);
+ .addImm(Width);
});
}
@@ -1042,11 +1041,8 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
return Width == 1 ? getOpTypeBool(MIRBuilder)
: getOpTypeInt(Width, MIRBuilder, false);
}
- if (Ty->isFloatingPointTy()) {
- if (Ty->isBFloatTy())
- return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, 1);
- return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, 0);
- }
+ if (Ty->isFloatingPointTy())
+ return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
if (Ty->isVoidTy())
return getOpTypeVoid(MIRBuilder);
if (Ty->isVectorTy()) {
@@ -1125,14 +1121,18 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
TypesInProcessing.insert(Ty);
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
ExplicitLayoutRequired, EmitIR);
-
- if(Ty->isFloatingPointTy() && Ty->isBFloatTy()){
- VRegIsBFloatMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = true;
- }
+ TypesInProcessing.erase(Ty);
- VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
+ MachineFunction *MF = &MIRBuilder.getMF();
+ Register TypeReg = getSPIRVTypeID(SpirvType);
- TypesInProcessing.erase(Ty);
+ if (Ty->isFloatingPointTy()) {
+ if (Ty->isBFloatTy()) {
+ VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
+ } else {
+ VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
+ }
+ }
// TODO: We could end up with two SPIR-V types pointing to the same llvm type.
// Is that a problem?
@@ -1168,17 +1168,6 @@ SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
return nullptr;
}
-bool SPIRVGlobalRegistry::getIsBFloatForVReg(Register VReg,
- const MachineFunction *MF) {
- auto t = VRegToTypeMap.find(MF ? MF : CurMF);
- if (t != VRegToTypeMap.end()) {
- auto tt = t->second.find(VReg);
- if (tt != t->second.end())
- return tt->second;
- }
- return false;
-}
-
SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg,
MachineFunction *MF) {
if (!MF)
@@ -2128,3 +2117,14 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
}
return false;
}
+
+SPIRVGlobalRegistry::FPVariant SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
+ const MachineFunction *MF) {
+ auto t = VRegFPVariantMap.find(MF ? MF : CurMF);
+ if (t != VRegFPVariantMap.end()) {
+ auto tt = t->second.find(VReg);
+ if (tt != t->second.end())
+ return tt->second;
+ }
+ return FPVariant::NONE;
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 2629cd515cca6..8f9a5d3cfc68a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -29,6 +29,14 @@ using SPIRVType = const MachineInstr;
using StructOffsetDecorator = std::function<void(Register)>;
class SPIRVGlobalRegistry : public SPIRVIRMapping {
+ public:
+ enum class FPVariant {
+ NONE,
+ IEEE_FLOAT,
+ BRAIN_FLOAT
+ };
+
+ private:
// Registers holding values which have types associated with them.
// Initialized upon VReg definition in IRTranslator.
// Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
@@ -37,9 +45,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// type-declaring ones).
DenseMap<const MachineFunction *, DenseMap<Register, SPIRVType *>>
VRegToTypeMap;
-
- DenseMap<const MachineFunction *, DenseMap<Register, bool>>
- VRegIsBFloatMap;
DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
@@ -91,6 +96,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// map of aliasing decorations to aliasing metadata
std::unordered_map<const MDNode *, MachineInstr *> AliasInstMDMap;
+ // Maps floating point Registers to their FPVariant (float type kind), given the MachineFunction.
+ DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>>
+ VRegFPVariantMap;
+
// Add a new OpTypeXXX instruction without checking for duplicates.
SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AQ,
@@ -343,9 +352,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// functions.
SPIRVType *getSPIRVTypeForVReg(Register VReg,
const MachineFunction *MF = nullptr) const;
-
- bool getIsBFloatForVReg(Register VReg,
- const MachineFunction *MF = nullptr);
// Return the result type of the instruction defining the register.
SPIRVType *getResultType(Register VReg, MachineFunction *MF = nullptr);
@@ -427,6 +433,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// Informs about removal of the machine instruction and invalidates data
// structures referring this instruction.
void invalidateMachineInstr(MachineInstr *MI);
+
+ // Return the FPVariant of the given floating-point register.
+ FPVariant getFPVariantForVReg(Register VReg,
+ const MachineFunction *MF = nullptr);
private:
SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
@@ -442,7 +452,7 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);
- SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, uint32_t Fptype);
+ SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index a823df89eab6c..f0b938d681dba 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -167,7 +167,8 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
"$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, i32imm:$fpencode), "$type = OpTypeFloat $width $fpencode">;
+def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
+ "$type = OpTypeFloat $width">;
def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
"$type = OpTypeVector $compType $compCount">;
def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 38ffc51ad63be..24798ff6b8f0e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1264,21 +1264,21 @@ void addInstrRequirements(const MachineInstr &MI,
case SPIRV::OpDot: {
const MachineFunction *MF = MI.getMF();
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
- bool isBfloat = GR->getIsBFloatForVReg(MI.getOperand(1).getReg(), MF);
- if (isBfloat) {
+ SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
+ if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
}
break;
}
case SPIRV::OpTypeFloat: {
unsigned BitWidth = MI.getOperand(1).getImm();
- SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
- const MachineFunction *MF = MI.getMF();
- bool isBfloat = GR->getIsBFloatForVReg(MI.getOperand(0).getReg(), MF);
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16) {
- if(isBfloat) {
+ SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+ const MachineFunction *MF = MI.getMF();
+ SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
+ if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
} else {
>From d8b9deb4a5111bf112450b3807cab410a87dc006 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Sun, 31 Aug 2025 15:47:37 -0700
Subject: [PATCH 13/16] nit change
---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 25 +++----------------
1 file changed, 4 insertions(+), 21 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 633ef770ec202..61a8ba19b28e4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1134,6 +1134,8 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
}
}
+ VRegToTypeMap[MF][TypeReg] = SpirvType;
+
// TODO: We could end up with two SPIR-V types pointing to the same llvm type.
// Is that a problem?
SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
@@ -1687,33 +1689,14 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
return MI;
MachineBasicBlock &DepMBB = I.getMF()->front();
MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
- const MachineInstr *NewMI = nullptr;
- if(SPIRVOPcode == SPIRV::OpTypeFloat) {
- if (Ty->isBFloatTy()) {
- NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
- return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
- MIRBuilder.getDL(), TII.get(SPIRVOPcode))
- .addDef(createTypeVReg(CurMF->getRegInfo()))
- .addImm(BitWidth)
- .addImm(1);
- });
- } else {
- NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+ const MachineInstr *NewMI =
+ createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
MIRBuilder.getDL(), TII.get(SPIRVOPcode))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(BitWidth)
.addImm(0);
});
- }
- } else {
- NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
- return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
- MIRBuilder.getDL(), TII.get(SPIRVOPcode))
- .addDef(createTypeVReg(CurMF->getRegInfo()))
- .addImm(BitWidth);
- });
- }
add(Ty, false, NewMI);
return finishCreatingSPIRVType(Ty, NewMI);
}
>From f7f16ef1c39914f5bbdde9024b7b2a355e98a82f Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Sun, 31 Aug 2025 16:03:11 -0700
Subject: [PATCH 14/16] nit change
---
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 24798ff6b8f0e..653b67e464908 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1277,7 +1277,7 @@ void addInstrRequirements(const MachineInstr &MI,
else if (BitWidth == 16) {
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
const MachineFunction *MF = MI.getMF();
- SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
+ SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(0).getReg(), MF);
if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
>From 07039fdbc7a0d7fce0a0e1a71baa908c1dc9c353 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Tue, 2 Sep 2025 05:19:27 -0700
Subject: [PATCH 15/16] add the test for cooperative matrix
---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 3 +--
.../SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll | 13 +++++++++++++
.../extensions/SPV_KHR_bfloat16/bfloat16_dot.ll | 1 -
3 files changed, 14 insertions(+), 3 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 61a8ba19b28e4..cb458be51c0c6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1123,9 +1123,9 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
ExplicitLayoutRequired, EmitIR);
TypesInProcessing.erase(Ty);
+ // Record the FPVariant of the floating-point registers in the VRegFPVariantMap.
MachineFunction *MF = &MIRBuilder.getMF();
Register TypeReg = getSPIRVTypeID(SpirvType);
-
if (Ty->isFloatingPointTy()) {
if (Ty->isBFloatTy()) {
VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
@@ -1133,7 +1133,6 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
}
}
-
VRegToTypeMap[MF][TypeReg] = SpirvType;
// TODO: We could end up with two SPIR-V types pointing to the same llvm type.
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
new file mode 100644
index 0000000000000..e9eec4fdac126
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -0,0 +1,13 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+define spir_kernel void @matr() {
+entry:
+ %addr1 = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
+ %res = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
+ %m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat 0.0)
+ store target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) %m1, ptr %addr1, align 4
+ ret void
+}
+
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
index 337eda538c196..bd7e3c71e0cc9 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -1,6 +1,5 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
-; XFAIL: *
; CHECK-SPIRV-DAG: OpCapability BFloat16TypeKHR
; CHECK-SPIRV-DAG: OpCapability BFloat16DotProductKHR
>From 82c56b9c8d518c2d8de0d82e785a4f2ce3441206 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Tue, 2 Sep 2025 14:48:01 -0700
Subject: [PATCH 16/16] test is failing for cooperative matrix
---
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 1 +
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 16 ++++++++--------
.../bfloat16_cooperative_matrix.ll | 5 -----
3 files changed, 9 insertions(+), 13 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 98c7709acf938..82977319b4779 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -569,6 +569,7 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
bool HasDefs = I.getNumDefs() > 0;
Register ResVReg = HasDefs ? I.getOperand(0).getReg() : Register(0);
SPIRVType *ResType = HasDefs ? GR.getSPIRVTypeForVReg(ResVReg) : nullptr;
+
assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
I.getOpcode() == TargetOpcode::G_IMPLICIT_DEF);
if (spvSelect(ResVReg, ResType, I)) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 653b67e464908..360dac9364cdc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1611,15 +1611,15 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::AsmINTEL);
}
break;
- case SPIRV::OpTypeCooperativeMatrixKHR:
- if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
- report_fatal_error(
- "OpTypeCooperativeMatrixKHR type requires the "
- "following SPIR-V extension: SPV_KHR_cooperative_matrix",
- false);
- Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
- Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
+ case SPIRV::OpTypeCooperativeMatrixKHR: {
+ const MachineFunction *MF = MI.getMF();
+ SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+ SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
+ if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
+ Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
+ }
break;
+ }
case SPIRV::OpArithmeticFenceEXT:
if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
report_fatal_error("OpArithmeticFenceEXT requires the "
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
index e9eec4fdac126..fa938fbf648a9 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -4,10 +4,5 @@
define spir_kernel void @matr() {
entry:
%addr1 = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
- %res = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
- %m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat 0.0)
- store target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) %m1, ptr %addr1, align 4
ret void
}
-
-declare dso_local spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat)
More information about the llvm-commits
mailing list