[llvm] [SPIR-V] Implement SPV_KHR_float_controls2 (PR #146941)

Nathan Gauër via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 7 07:19:08 PDT 2025


================
@@ -551,12 +569,85 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
     }
     if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") &&
         !M.getNamedMetadata("opencl.enable.FP_CONTRACT")) {
-      MCInst Inst;
-      Inst.setOpcode(SPIRV::OpExecutionMode);
-      Inst.addOperand(MCOperand::createReg(FReg));
-      unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
-      Inst.addOperand(MCOperand::createImm(EM));
-      outputMCInst(Inst);
+      if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
+        // When SPV_KHR_float_controls2 is enabled, ContractionOff is
+        // deprecated. We need to use FPFastMathDefault with the appropriate
+        // flags instead. Since FPFastMathDefault takes a target type, we need
+        // to emit it for each floating-point type to match the effect of
+        // ContractionOff. As of now, there are 4 FP types: fp16, fp32, fp64 and
+        // fp128.
+        constexpr size_t NumFPTypes = 4;
+        for (size_t i = 0; i < NumFPTypes; ++i) {
+          MCInst Inst;
+          Inst.setOpcode(SPIRV::OpExecutionMode);
+          Inst.addOperand(MCOperand::createReg(FReg));
+          unsigned EM =
+              static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault);
+          Inst.addOperand(MCOperand::createImm(EM));
+
+          Type *TargetType = nullptr;
+          switch (i) {
+          case 0:
+            TargetType = Type::getHalfTy(M.getContext());
+            break;
+          case 1:
+            TargetType = Type::getFloatTy(M.getContext());
+            break;
+          case 2:
+            TargetType = Type::getDoubleTy(M.getContext());
+            break;
+          case 3:
+            TargetType = Type::getFP128Ty(M.getContext());
+            break;
+          }
+          assert(TargetType && "Invalid target type for FPFastMathDefault");
+
+          // Find the SPIRV type matching the target type. We'll go over all the
+          // TypeConstVars instructions in the SPIRV module and find the one
+          // that matches the target type. We know the target type is a
+          // floating-point type, so we can skip anything different than
+          // OpTypeFloat. Then, we need to check the bitwidth.
+          bool SPIRVTypeFound = false;
+          for (const MachineInstr *MI :
+               MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
+            // Skip if the instruction is not OpTypeFloat.
+            if (MI->getOpcode() != SPIRV::OpTypeFloat)
+              continue;
+
+            // Skip if TargetTy bitwidth doesn't match MI->getOperand(1), which
+            // is the SPIRV type bit width.
+            if (TargetType->getScalarSizeInBits() != MI->getOperand(1).getImm())
+              continue;
+
+            SPIRVTypeFound = true;
+            const MachineFunction *MF = MI->getMF();
+            MCRegister TypeReg =
+                MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
+            Inst.addOperand(MCOperand::createReg(TypeReg));
+          }
+
+          if (!SPIRVTypeFound) {
+            // The module does not contain this FP type, so we don't need to
+            // emit FPFastMathDefault for it.
+            continue;
+          }
+          // We only end up here because there is no "spirv.ExecutionMode"
+          // metadata, so that means no FPFastMathDefault. Therefore, we only
+          // need to make sure AllowContract is set to 0, as the rest of flags.
+          // We still need to emit the OpExecutionMode instruction, otherwise
+          // it's up to the client API to define the flags.
+          Inst.addOperand(MCOperand::createImm(SPIRV::FPFastMathMode::None));
+          outputMCInst(Inst);
+        }
----------------
Keenuts wrote:

Unless I'm mistaken:
 - We always have N possible values for `TargetType->getScalarSizeInBits()`, and for now only `[16, 32, 64, 128]`.
 - We want to emit one `OpExecutionMode FPFastMathDefault %float_type` for each `%float_type` we find which has one of those sizes.

But this code iterates 4 times over all instructions to find a float type with the correct size (granted, the outer loop should be unrolled, but the instruction loop is still there).
Should the code be this instead?

Also, it looks like the inner loop is misleading: it never adds more than one TypeReg operand, since emitting the same SPIR-V type twice is not allowed for non-aggregates.

```suggestion
        constexpr std::array<unsigned, 4> KnownFloatSizes { 16, 32, 64, 128 };
        for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
          // Skip if the instruction is not OpTypeFloat.
          if (MI->getOpcode() != SPIRV::OpTypeFloat)
            continue;

          // std::count/find not constexpr in cpp17. Loop should be unrolled.
          bool IsOfInterest = false;
          for (unsigned Size : KnownFloatSizes)
            if (Size == MI->getOperand(1).getImm())
              IsOfInterest = true;
          if (!IsOfInterest)
            continue;

          MCInst Inst;
          Inst.setOpcode(SPIRV::OpExecutionMode);
          Inst.addOperand(MCOperand::createReg(FReg));
          unsigned EM =
              static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault);
          Inst.addOperand(MCOperand::createImm(EM));

          const MachineFunction *MF = MI->getMF();
          MCRegister TypeReg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
          Inst.addOperand(MCOperand::createReg(TypeReg));
          Inst.addOperand(MCOperand::createImm(SPIRV::FPFastMathMode::None));
          outputMCInst(Inst);
        }
```



https://github.com/llvm/llvm-project/pull/146941


More information about the llvm-commits mailing list