[llvm] [SPIRV] Add bfloat support (PR #155645)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 27 14:53:55 PDT 2025
https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/155645
>From 9e4f483ce4d7c0f710270cc21007b73ce56b59e1 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 03:42:59 -0700
Subject: [PATCH 1/3] initial commit for adding the support of bfloat in
machine inst
---
llvm/include/llvm/CodeGen/MachineInstr.h | 3 ++-
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 9 +++++++--
llvm/lib/CodeGen/MachineInstr.cpp | 3 +++
llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 4 +++-
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 4 ++++
llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td | 4 ++++
6 files changed, 23 insertions(+), 4 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/MachineInstr.h b/llvm/include/llvm/CodeGen/MachineInstr.h
index 10a9b1ff1411d..6f692ae32510b 100644
--- a/llvm/include/llvm/CodeGen/MachineInstr.h
+++ b/llvm/include/llvm/CodeGen/MachineInstr.h
@@ -123,8 +123,9 @@ class MachineInstr
NoUSWrap = 1 << 20, // Instruction supports geps
// no unsigned signed wrap.
SameSign = 1 << 21, // Both operands have the same sign.
- InBounds = 1 << 22 // Pointer arithmetic remains inbounds.
+ InBounds = 1 << 22, // Pointer arithmetic remains inbounds.
// Implies NoUSWrap.
+ BFloat16 = 1 << 23 // Instruction with bf16 type
};
private:
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 541269ab6bfce..2a6b66984c8ae 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
- if (containsBF16Type(U))
- return false;
+ // if (containsBF16Type(U))
+ // return false;
const CallInst &CI = cast<CallInst>(U);
const Function *F = CI.getCalledFunction();
@@ -2813,6 +2813,11 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
if (isa<FPMathOperator>(CI))
MIB->copyIRFlags(CI);
+ // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
+ if (containsBF16Type(U)) {
+ MIB->setFlag(MachineInstr::MIFlag::BFloat16);
+ }
+
for (const auto &Arg : enumerate(CI.args())) {
// If this is required to be an immediate, don't materialize it in a
// register.
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 79047f732808a..10ff667bcb522 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,6 +632,9 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
if (I.getMetadata(LLVMContext::MD_unpredictable))
MIFlags |= MachineInstr::MIFlag::Unpredictable;
+ if (I.getType()->isBFloatTy())
+ MIFlags |= MachineInstr::MIFlag::BFloat16;
+
return MIFlags;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e7da5504b2d58..bd13a3bae92cd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -147,7 +147,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_KHR_float_controls2",
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
{"SPV_INTEL_tensor_float32_conversion",
- SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
+ SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
+ {"SPV_KHR_bfloat16",
+ SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 8039cf0c432fa..5bba5cdce3753 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1267,6 +1267,10 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16)
Reqs.addCapability(SPIRV::Capability::Float16);
+ if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+ Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+ }
break;
}
case SPIRV::OpTypeVector: {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d2824ee2d2caf..9d630356e8ffb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -382,6 +382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
+defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
+defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
>From 32e6c031b24083ef668da44d2dd5c9785d9ba12d Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 12:29:09 -0700
Subject: [PATCH 2/3] debugging
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 3 ++-
llvm/lib/CodeGen/MachineInstr.cpp | 4 +++-
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 2 ++
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 10 ++++++----
4 files changed, 13 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 2a6b66984c8ae..3b112805e082c 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2815,6 +2815,7 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
// If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
if (containsBF16Type(U)) {
+ // assert(false && "bfloat detected at the IR Translator");
MIB->setFlag(MachineInstr::MIFlag::BFloat16);
}
@@ -2878,7 +2879,7 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
MIB.addUse(TokenReg, RegState::Implicit);
}
}
-
+
return true;
}
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 10ff667bcb522..1ec8c8a2d90c9 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,8 +632,10 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
if (I.getMetadata(LLVMContext::MD_unpredictable))
MIFlags |= MachineInstr::MIFlag::Unpredictable;
- if (I.getType()->isBFloatTy())
+ if (I.getType()->getScalarType()->isBFloatTy()) {
MIFlags |= MachineInstr::MIFlag::BFloat16;
+ // assert(false && "bfloat detected at the MachineInstr");
+ }
return MIFlags;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index f0b938d681dba..923c92c28d3f6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -169,6 +169,8 @@ def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness)
"$type = OpTypeInt $width $signedness">;
def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
"$type = OpTypeFloat $width">;
+def OpTypeBFloat: Op<40, (outs TYPE:$type), (ins i32imm:$width),
+ "$type = OpTypeFloat $width 0">;
def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
"$type = OpTypeVector $compType $compCount">;
def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 5bba5cdce3753..19c938c96a9e8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1267,10 +1267,12 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16)
Reqs.addCapability(SPIRV::Capability::Float16);
- if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
- Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
- Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
- }
+ break;
+ }
+ case SPIRV::OpTypeBFloat: {
+ assert(false && "bfloat detected at the file");
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+ Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
break;
}
case SPIRV::OpTypeVector: {
>From eccbcf845dd59f4a822a5fc9487d48b4e62d7a2c Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 14:53:45 -0700
Subject: [PATCH 3/3] Debugging
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 19 ++++++++++++-------
llvm/lib/CodeGen/GlobalISel/Legalizer.cpp | 4 ++++
llvm/lib/CodeGen/MachineInstr.cpp | 14 ++++++++++++--
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 3 ++-
.../SPIRV/SPIRVPreLegalizerCombiner.cpp | 6 ++++++
5 files changed, 36 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 3b112805e082c..4aece388749f8 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2195,6 +2195,8 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
if (translateSimpleIntrinsic(CI, ID, MIRBuilder))
return true;
+ LLVM_DEBUG(dbgs() << "IRTranslator translateKnownIntrinsic for CI: " << CI << '\n');
+ LLVM_DEBUG(dbgs() << "IRTranslator translateKnownIntrinsic for ID: " << ID << '\n');
switch (ID) {
default:
break;
@@ -2581,9 +2583,10 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
Value *FpValue = CI.getOperand(0);
ConstantInt *TestMaskValue = cast<ConstantInt>(CI.getOperand(1));
+ uint32_t Flags = MachineInstr::copyFlagsFromInstruction(CI);
MIRBuilder
.buildInstr(TargetOpcode::G_IS_FPCLASS, {getOrCreateVReg(CI)},
- {getOrCreateVReg(*FpValue)})
+ {getOrCreateVReg(*FpValue)}, Flags)
.addImm(TestMaskValue->getZExtValue());
return true;
@@ -2813,12 +2816,6 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
if (isa<FPMathOperator>(CI))
MIB->copyIRFlags(CI);
- // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
- if (containsBF16Type(U)) {
- // assert(false && "bfloat detected at the IR Translator");
- MIB->setFlag(MachineInstr::MIFlag::BFloat16);
- }
-
for (const auto &Arg : enumerate(CI.args())) {
// If this is required to be an immediate, don't materialize it in a
// register.
@@ -2879,6 +2876,13 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
MIB.addUse(TokenReg, RegState::Implicit);
}
}
+
+ // If the SPIR-V intrinsic uses bfloat, set the BFloat16 MachineInstr flag.
+ if (containsBF16Type(U)) {
+ dbgs() << "Flagged at IRTranslator: " << *MIB.getInstr() << "\n";
+ MIB.getInstr()->setFlag(MachineInstr::MIFlag::BFloat16);
+ }
+ MIB->copyIRFlags(CI);
return true;
}
@@ -4161,6 +4165,7 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
// Translate any debug-info attached to the instruction.
translateDbgInfo(Inst, *CurBuilder);
+ LLVM_DEBUG(dbgs() << "Inst at IRTranslator: " << Inst << "\n");
if (translate(Inst))
continue;
diff --git a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
index aef16b5f33af4..6013124558106 100644
--- a/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp
@@ -192,6 +192,10 @@ Legalizer::legalizeMachineFunction(MachineFunction &MF, const LegalizerInfo &LI,
if (MBB->empty())
continue;
for (MachineInstr &MI : *MBB) {
+ dbgs() << "Legalizer sees: " << MI << "\n";
+ if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
+ assert(false && "bfloat detected at the Legalizer");
+ }
// Only legalize pre-isel generic instructions: others don't have types
// and are assumed to be legal.
if (!isPreISelGenericOpcode(MI.getOpcode()))
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 1ec8c8a2d90c9..bbeeed76631c6 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,9 +632,19 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
if (I.getMetadata(LLVMContext::MD_unpredictable))
MIFlags |= MachineInstr::MIFlag::Unpredictable;
- if (I.getType()->getScalarType()->isBFloatTy()) {
+ auto *Sel = dyn_cast<SelectInst>(&I);
+ bool BFloatOpnd = !Sel && I.getType()->getScalarType()->isBFloatTy();
+
+ if (isa<CallInst>(&I)) {
+ for (const Value *Op : I.operands()) {
+ Type *OpTy = Op->getType();
+ BFloatOpnd |= OpTy->getScalarType()->isBFloatTy();
+ }
+ }
+
+ if (BFloatOpnd) {
MIFlags |= MachineInstr::MIFlag::BFloat16;
- // assert(false && "bfloat detected at the MachineInstr");
+ LLVM_DEBUG(dbgs() << "bfloat detected at the MachineInstr" << "\n");
}
return MIFlags;
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 19c938c96a9e8..ab26f0208d758 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1262,6 +1262,7 @@ void addInstrRequirements(const MachineInstr &MI,
break;
}
case SPIRV::OpTypeFloat: {
+ assert(false && "bfloat detected at the spirv module analysis");
unsigned BitWidth = MI.getOperand(1).getImm();
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
@@ -1270,7 +1271,7 @@ void addInstrRequirements(const MachineInstr &MI,
break;
}
case SPIRV::OpTypeBFloat: {
- assert(false && "bfloat detected at the file");
+ assert(false && "bfloat detected at the spirv module analysis");
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
break;
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
index 83567519355a8..a964db0b02746 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
@@ -156,6 +156,12 @@ SPIRVPreLegalizerCombinerImpl::SPIRVPreLegalizerCombinerImpl(
}
bool SPIRVPreLegalizerCombinerImpl::tryCombineAll(MachineInstr &MI) const {
+ // Inst combiner is disabled for instructions
+ // with bfloat type.
+ // TODO: add bfloat-specific inst combiner rules
+ if (MI.getFlag(MachineInstr::BFloat16))
+ return false;
+
return tryCombineAllImpl(MI);
}
More information about the llvm-commits
mailing list