[llvm] [SPIR-V] Implement SPV_KHR_float_controls2 (PR #146941)
Nathan Gauër via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 7 07:19:08 PDT 2025
================
@@ -551,12 +569,85 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
}
if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") &&
!M.getNamedMetadata("opencl.enable.FP_CONTRACT")) {
- MCInst Inst;
- Inst.setOpcode(SPIRV::OpExecutionMode);
- Inst.addOperand(MCOperand::createReg(FReg));
- unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
- Inst.addOperand(MCOperand::createImm(EM));
- outputMCInst(Inst);
+ if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
+ // When SPV_KHR_float_controls2 is enabled, ContractionOff is
+ // deprecated. We need to use FPFastMathDefault with the appropriate
+ // flags instead. Since FPFastMathDefault takes a target type, we need
+ // to emit it for each floating-point type to match the effect of
+ // ContractionOff. As of now, there are 4 FP types: fp16, fp32, fp64 and
+ // fp128.
+ constexpr size_t NumFPTypes = 4;
+ for (size_t i = 0; i < NumFPTypes; ++i) {
+ MCInst Inst;
+ Inst.setOpcode(SPIRV::OpExecutionMode);
+ Inst.addOperand(MCOperand::createReg(FReg));
+ unsigned EM =
+ static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault);
+ Inst.addOperand(MCOperand::createImm(EM));
+
+ Type *TargetType = nullptr;
+ switch (i) {
+ case 0:
+ TargetType = Type::getHalfTy(M.getContext());
+ break;
+ case 1:
+ TargetType = Type::getFloatTy(M.getContext());
+ break;
+ case 2:
+ TargetType = Type::getDoubleTy(M.getContext());
+ break;
+ case 3:
+ TargetType = Type::getFP128Ty(M.getContext());
+ break;
+ }
+ assert(TargetType && "Invalid target type for FPFastMathDefault");
+
+ // Find the SPIRV type matching the target type. We'll go over all the
+ // TypeConstVars instructions in the SPIRV module and find the one
+ // that matches the target type. We know the target type is a
+ // floating-point type, so we can skip anything different than
+ // OpTypeFloat. Then, we need to check the bitwidth.
+ bool SPIRVTypeFound = false;
+ for (const MachineInstr *MI :
+ MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
+ // Skip if the instruction is not OpTypeFloat.
+ if (MI->getOpcode() != SPIRV::OpTypeFloat)
+ continue;
+
+ // Skip if TargetTy bitwidth doesn't match MI->getOperand(1), which
+ // is the SPIRV type bit width.
+ if (TargetType->getScalarSizeInBits() != MI->getOperand(1).getImm())
+ continue;
+
+ SPIRVTypeFound = true;
+ const MachineFunction *MF = MI->getMF();
+ MCRegister TypeReg =
+ MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
+ Inst.addOperand(MCOperand::createReg(TypeReg));
+ }
+
+ if (!SPIRVTypeFound) {
+ // The module does not contain this FP type, so we don't need to
+ // emit FPFastMathDefault for it.
+ continue;
+ }
+ // We only end up here because there is no "spirv.ExecutionMode"
+ // metadata, so that means no FPFastMathDefault. Therefore, we only
+ // need to make sure AllowContract is set to 0, as the rest of flags.
+ // We still need to emit the OpExecutionMode instruction, otherwise
+ // it's up to the client API to define the flags.
+ Inst.addOperand(MCOperand::createImm(SPIRV::FPFastMathMode::None));
+ outputMCInst(Inst);
+ }
----------------
Keenuts wrote:
Unless I'm mistaken:
- We always have N possible values for `TargetType->getScalarSizeInBits()`, and for now only `[16, 32, 64, 128]`.
- We want to emit one `OpExecutionMode FPFastMathDefault %float_type` for each `%float_type` we find which has one of those sizes.
But this code iterates over all instructions 4 times to find a float type with the correct size (the outer loop should get unrolled, but the inner loop over the instructions is still there).
Should the code be this instead?
Also, the inner loop looks misleading: it never adds more than one TypeReg operand, since the same non-aggregate SPIR-V type is not allowed to be emitted twice.
```suggestion
constexpr std::array<unsigned, 4> KnownFloatSizes { 16, 32, 64, 128 };
for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
// Skip if the instruction is not OpTypeFloat.
if (MI->getOpcode() != SPIRV::OpTypeFloat)
continue;
// std::count/find not constexpr in cpp17. Loop should be unrolled.
bool IsOfInterest = false;
for (unsigned Size : KnownFloatSizes)
if (Size == MI->getOperand(1).getImm())
IsOfInterest = true;
if (!IsOfInterest)
continue;
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
Inst.addOperand(MCOperand::createReg(FReg));
unsigned EM =
static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault);
Inst.addOperand(MCOperand::createImm(EM));
const MachineFunction *MF = MI->getMF();
MCRegister TypeReg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
Inst.addOperand(MCOperand::createReg(TypeReg));
Inst.addOperand(MCOperand::createImm(SPIRV::FPFastMathMode::None));
outputMCInst(Inst);
}
```
https://github.com/llvm/llvm-project/pull/146941
More information about the llvm-commits mailing list