[llvm] [RFC][SPIR-V] Add llvm.arbitrary.fp.convert intrinsic (PR #164252)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 20 06:25:23 PDT 2025
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff origin/main HEAD --extensions h,cpp -- llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp llvm/lib/IR/Verifier.cpp llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp llvm/lib/Target/SPIRV/SPIRVUtils.cpp llvm/lib/Target/SPIRV/SPIRVUtils.h --diff_from_common_commit
``````````
:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against.
:warning:
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 58b801916..46f70531c 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -79,8 +79,8 @@
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/EHPersonalities.h"
-#include "llvm/IR/Function.h"
#include "llvm/IR/FPEnv.h"
+#include "llvm/IR/Function.h"
#include "llvm/IR/GCStrategy.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
@@ -5867,20 +5867,18 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
auto *RoundingMAV = dyn_cast<MetadataAsValue>(Call.getArgOperand(3));
Check(RoundingMAV, "missing rounding mode metadata operand", Call);
auto *RoundingStr = dyn_cast<MDString>(RoundingMAV->getMetadata());
- Check(RoundingStr, "rounding mode metadata operand must be a string",
- Call);
+ Check(RoundingStr, "rounding mode metadata operand must be a string", Call);
StringRef RoundingInterp = RoundingStr->getString();
- // Check that interpretation strings are not empty. The actual interpretation
- // values are target-specific and not validated here.
+ // Check that interpretation strings are not empty. The actual
+ // interpretation values are target-specific and not validated here.
Check(!ResultInterp.empty(),
"result interpretation metadata string must not be empty", Call);
Check(!InputInterp.empty(),
"input interpretation metadata string must not be empty", Call);
if (RoundingInterp != "none") {
- std::optional<RoundingMode> RM =
- convertStrToRoundingMode(RoundingInterp);
+ std::optional<RoundingMode> RM = convertStrToRoundingMode(RoundingInterp);
Check(RM && *RM != RoundingMode::Dynamic,
"unsupported rounding mode argument", Call);
}
@@ -5890,8 +5888,8 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
Check(SaturationOp, "saturation operand must be a constant integer", Call);
if (SaturationOp) {
uint64_t SatVal = SaturationOp->getZExtValue();
- Check(SatVal == 0 || SatVal == 1,
- "saturation operand must be 0 or 1", Call);
+ Check(SatVal == 0 || SatVal == 1, "saturation operand must be 0 or 1",
+ Call);
}
break;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 3d13e375c..650f97db4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -224,7 +224,8 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFloatWithEncoding(
const MachineFunction *TypeMF = Existing->getParent()->getParent();
if (TypeMF == &MIRBuilder.getMF())
return Existing;
- // Type is from a different function, need to create a new one for current function
+ // Type is from a different function, need to create a new one for current
+ // function
}
SPIRVType *SpvType = getOpTypeFloat(Width, MIRBuilder, FPEncode);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 47353fee1..6e888f070 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -40,8 +40,7 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
- DenseMap<std::pair<unsigned, unsigned>, SPIRVType *>
- FloatTypesWithEncoding;
+ DenseMap<std::pair<unsigned, unsigned>, SPIRVType *> FloatTypesWithEncoding;
// map a Function to its definition (as a machine instruction operand)
DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
@@ -416,9 +415,10 @@ public:
// Return the number of bits SPIR-V pointers and size_t variables require.
unsigned getPointerSize() const { return PointerSize; }
- SPIRVType *getOrCreateOpTypeFloatWithEncoding(
- uint32_t Width, MachineIRBuilder &MIRBuilder,
- SPIRV::FPEncoding::FPEncoding FPEncode);
+ SPIRVType *
+ getOrCreateOpTypeFloatWithEncoding(uint32_t Width,
+ MachineIRBuilder &MIRBuilder,
+ SPIRV::FPEncoding::FPEncoding FPEncode);
// Returns true if two types are defined and are compatible in a sense of
// OpBitcast instruction
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3963d126d..ec135bac4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -23,7 +23,6 @@
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
-#include "llvm/IR/FPEnv.h"
#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/InstructionSelector.h"
@@ -31,6 +30,7 @@
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/IR/FPEnv.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
@@ -2108,7 +2108,8 @@ bool SPIRVInstructionSelector::selectFloatDot(Register ResVReg,
static std::optional<SPIRV::FPEncoding::FPEncoding>
getFloat8EncodingFromString(StringRef Interpretation) {
- return StringSwitch<std::optional<SPIRV::FPEncoding::FPEncoding>>(Interpretation)
+ return StringSwitch<std::optional<SPIRV::FPEncoding::FPEncoding>>(
+ Interpretation)
.Case("spv.E4M3EXT", SPIRV::FPEncoding::Float8E4M3EXT)
.Case("spv.E5M2EXT", SPIRV::FPEncoding::Float8E5M2EXT)
.Default(std::nullopt);
@@ -2171,7 +2172,8 @@ struct ArbitraryConvertParams {
if (!ValueOp.isReg())
return std::nullopt;
- auto GetStringFromMD = [&](unsigned OperandIdx) -> std::optional<StringRef> {
+ auto GetStringFromMD =
+ [&](unsigned OperandIdx) -> std::optional<StringRef> {
const MachineOperand &Op = I.getOperand(OperandIdx);
if (!Op.isMetadata())
return std::nullopt;
@@ -2235,9 +2237,9 @@ struct ArbitraryConvertParams {
// Helper function to create Float8 type (scalar or vector)
static SPIRVType *createFloat8Type(unsigned ComponentCount,
- SPIRV::FPEncoding::FPEncoding Encoding,
- MachineIRBuilder &MIRBuilder,
- SPIRVGlobalRegistry &GR) {
+ SPIRV::FPEncoding::FPEncoding Encoding,
+ MachineIRBuilder &MIRBuilder,
+ SPIRVGlobalRegistry &GR) {
SPIRVType *Float8ScalarType =
GR.getOrCreateOpTypeFloatWithEncoding(8, MIRBuilder, Encoding);
if (ComponentCount > 1)
@@ -2250,9 +2252,8 @@ static SPIRVType *createFloat8Type(unsigned ComponentCount,
static std::optional<Register>
buildBitcastIfNeeded(Register SrcReg, SPIRVType *SrcType, SPIRVType *TargetType,
MachineInstr &I, const TargetInstrInfo &TII,
- const TargetRegisterInfo &TRI,
- const RegisterBankInfo &RBI, MachineRegisterInfo *MRI,
- SPIRVGlobalRegistry &GR) {
+ const TargetRegisterInfo &TRI, const RegisterBankInfo &RBI,
+ MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR) {
if (SrcType == TargetType)
return SrcReg;
@@ -2291,8 +2292,8 @@ bool SPIRVInstructionSelector::selectArbitraryFPConvert(
return false;
}
- auto GetComponentInfo = [&](const SPIRVType *Type)
- -> std::pair<const SPIRVType *, unsigned> {
+ auto GetComponentInfo =
+ [&](const SPIRVType *Type) -> std::pair<const SPIRVType *, unsigned> {
if (!Type)
return {nullptr, 0};
return {GR.getScalarOrVectorComponentType(Type),
@@ -2325,8 +2326,8 @@ bool SPIRVInstructionSelector::selectArbitraryFPConvert(
if (SrcComponentCount != ComponentCount)
return false;
- SPIRVType *Float8Type =
- createFloat8Type(ComponentCount, *Params.getSrcFP8Encoding(), MIRBuilder, GR);
+ SPIRVType *Float8Type = createFloat8Type(
+ ComponentCount, *Params.getSrcFP8Encoding(), MIRBuilder, GR);
std::optional<Register> Float8Reg = buildBitcastIfNeeded(
SrcReg, SrcType, Float8Type, I, TII, TRI, RBI, MRI, GR);
@@ -2335,11 +2336,11 @@ bool SPIRVInstructionSelector::selectArbitraryFPConvert(
if (!RBI.constrainGenericRegister(ResVReg, SPIRV::iIDRegClass, *MRI))
return false;
- auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
- TII.get(SPIRV::OpFConvert))
- .addDef(ResVReg)
- .addUse(GR.getSPIRVTypeID(ResType))
- .addUse(*Float8Reg);
+ auto MIB =
+ BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpFConvert))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(*Float8Reg);
return MIB.constrainAllUses(TII, TRI, RBI);
}
@@ -2368,18 +2369,17 @@ bool SPIRVInstructionSelector::selectArbitraryFPConvert(
if (GR.getScalarOrVectorComponentCount(ResType) != ComponentCount)
return false;
- SPIRVType *Float8Type =
- createFloat8Type(ComponentCount, *Params.getDstFP8Encoding(), MIRBuilder, GR);
+ SPIRVType *Float8Type = createFloat8Type(
+ ComponentCount, *Params.getDstFP8Encoding(), MIRBuilder, GR);
- Register ConvertedReg =
- MRI->createVirtualRegister(&SPIRV::iIDRegClass);
+ Register ConvertedReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
GR.assignSPIRVTypeToVReg(Float8Type, ConvertedReg, *I.getMF());
- auto ConvertMIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
- TII.get(SPIRV::OpFConvert))
- .addDef(ConvertedReg)
- .addUse(GR.getSPIRVTypeID(Float8Type))
- .addUse(SrcReg);
+ auto ConvertMIB =
+ BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpFConvert))
+ .addDef(ConvertedReg)
+ .addUse(GR.getSPIRVTypeID(Float8Type))
+ .addUse(SrcReg);
if (!ConvertMIB.constrainAllUses(TII, TRI, RBI))
return false;
@@ -2387,8 +2387,7 @@ bool SPIRVInstructionSelector::selectArbitraryFPConvert(
auto MaybeRM = toSPIRVRoundingMode(*RM);
if (!MaybeRM)
return false;
- buildOpDecorate(ConvertedReg, I, TII,
- SPIRV::Decoration::FPRoundingMode,
+ buildOpDecorate(ConvertedReg, I, TII, SPIRV::Decoration::FPRoundingMode,
{static_cast<uint32_t>(*MaybeRM)});
} else if (!RoundingNone) {
return false;
@@ -2396,24 +2395,25 @@ bool SPIRVInstructionSelector::selectArbitraryFPConvert(
// Add saturation decoration if requested
if (Params.UseSaturation) {
- buildOpDecorate(ConvertedReg, I, TII,
- SPIRV::Decoration::SaturatedToLargestFloat8NormalConversionEXT,
- {});
+ buildOpDecorate(
+ ConvertedReg, I, TII,
+ SPIRV::Decoration::SaturatedToLargestFloat8NormalConversionEXT, {});
}
if (!RBI.constrainGenericRegister(ResVReg, SPIRV::iIDRegClass, *MRI))
return false;
- auto BitcastMIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
- TII.get(SPIRV::OpBitcast))
- .addDef(ResVReg)
- .addUse(GR.getSPIRVTypeID(ResType))
- .addUse(ConvertedReg);
+ auto BitcastMIB =
+ BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpBitcast))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(ConvertedReg);
return BitcastMIB.constrainAllUses(TII, TRI, RBI);
}
// Conversion path 3: FP8 -> Int (e.g., spv.E4M3EXT -> signed/unsigned)
if ((Params.DstType == InterpretationType::Signed ||
- Params.DstType == InterpretationType::Unsigned) && Params.isSrcFP8()) {
+ Params.DstType == InterpretationType::Unsigned) &&
+ Params.isSrcFP8()) {
if (RM)
return false;
@@ -2446,8 +2446,8 @@ bool SPIRVInstructionSelector::selectArbitraryFPConvert(
if (SrcComponentCount != ComponentCount)
return false;
- SPIRVType *Float8Type =
- createFloat8Type(ComponentCount, *Params.getSrcFP8Encoding(), MIRBuilder, GR);
+ SPIRVType *Float8Type = createFloat8Type(
+ ComponentCount, *Params.getSrcFP8Encoding(), MIRBuilder, GR);
std::optional<Register> Float8Reg = buildBitcastIfNeeded(
SrcReg, SrcType, Float8Type, I, TII, TRI, RBI, MRI, GR);
@@ -2469,7 +2469,8 @@ bool SPIRVInstructionSelector::selectArbitraryFPConvert(
// Conversion path 4: Int -> FP8 (e.g., signed/unsigned -> spv.E5M2EXT)
if ((Params.SrcType == InterpretationType::Signed ||
- Params.SrcType == InterpretationType::Unsigned) && Params.isDstFP8()) {
+ Params.SrcType == InterpretationType::Unsigned) &&
+ Params.isDstFP8()) {
if (RM)
return false;
@@ -2501,38 +2502,37 @@ bool SPIRVInstructionSelector::selectArbitraryFPConvert(
if (ResComponentCount != ComponentCount)
return false;
- SPIRVType *Float8Type =
- createFloat8Type(ComponentCount, *Params.getDstFP8Encoding(), MIRBuilder, GR);
+ SPIRVType *Float8Type = createFloat8Type(
+ ComponentCount, *Params.getDstFP8Encoding(), MIRBuilder, GR);
- Register ConvertedReg =
- MRI->createVirtualRegister(&SPIRV::iIDRegClass);
+ Register ConvertedReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
GR.assignSPIRVTypeToVReg(Float8Type, ConvertedReg, *I.getMF());
unsigned Opcode = Params.SrcType == InterpretationType::Signed
? SPIRV::OpConvertSToF
: SPIRV::OpConvertUToF;
- auto ConvertMIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
- TII.get(Opcode))
- .addDef(ConvertedReg)
- .addUse(GR.getSPIRVTypeID(Float8Type))
- .addUse(SrcReg);
+ auto ConvertMIB =
+ BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
+ .addDef(ConvertedReg)
+ .addUse(GR.getSPIRVTypeID(Float8Type))
+ .addUse(SrcReg);
if (!ConvertMIB.constrainAllUses(TII, TRI, RBI))
return false;
// Add saturation decoration if requested
if (Params.UseSaturation) {
- buildOpDecorate(ConvertedReg, I, TII,
- SPIRV::Decoration::SaturatedToLargestFloat8NormalConversionEXT,
- {});
+ buildOpDecorate(
+ ConvertedReg, I, TII,
+ SPIRV::Decoration::SaturatedToLargestFloat8NormalConversionEXT, {});
}
if (!RBI.constrainGenericRegister(ResVReg, SPIRV::iIDRegClass, *MRI))
return false;
- auto BitcastMIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
- TII.get(SPIRV::OpBitcast))
- .addDef(ResVReg)
- .addUse(GR.getSPIRVTypeID(ResType))
- .addUse(ConvertedReg);
+ auto BitcastMIB =
+ BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpBitcast))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(ConvertedReg);
return BitcastMIB.constrainAllUses(TII, TRI, RBI);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 31f15f6bc..ef5b94ea0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -373,8 +373,8 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
if (ValueIdx + 3 >= MI->getNumOperands())
break;
- auto GetStringFromMD = [&](unsigned OperandIdx)
- -> std::optional<StringRef> {
+ auto GetStringFromMD =
+ [&](unsigned OperandIdx) -> std::optional<StringRef> {
const MachineOperand &OpMO = MI->getOperand(OperandIdx);
if (!OpMO.isMetadata())
return std::nullopt;
``````````
</details>
https://github.com/llvm/llvm-project/pull/164252
More information about the llvm-commits mailing list.