[llvm] [SPIR-V] Implement SPV_KHR_float_controls2 (PR #146941)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 9 05:52:35 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-backend-spir-v

Author: Marcos Maronas (maarquitos14)

<details>
<summary>Changes</summary>

Implementation of [SPV_KHR_float_controls2](https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_float_controls2.html) extension, and corresponding tests.

Some of the tests make use of `!spirv.ExecutionMode` LLVM named metadata. This is because some SPIR-V instructions don't have a direct equivalent in LLVM IR, so the SPIR-V Target uses different LLVM named metadata to convey the necessary information. Below, you will find an example from one of the newly added tests:
```llvm
!spirv.ExecutionMode = !{!19, !20, !21, !22, !23, !24, !25, !26, !27}
!19 = !{ptr @k_float_controls_float, i32 6028, float poison, i32 131079}
!20 = !{ptr @k_float_controls_all, i32 6028, float poison, i32 131079}
!21 = !{ptr @k_float_controls_float, i32 31}
!22 = !{ptr @k_float_controls_all, i32 31}
!23 = !{ptr @k_float_controls_float, i32 4461, i32 32}
!24 = !{ptr @k_float_controls_all, i32 4461, i32 16}
!25 = !{ptr @k_float_controls_all, i32 4461, i32 32}
!26 = !{ptr @k_float_controls_all, i32 4461, i32 64}
!27 = !{ptr @k_float_controls_all, i32 4461, i32 128}
```
`!spirv.ExecutionMode` contains a list of metadata nodes, each of which specifies the operands of one `OpExecutionMode` instruction in SPIR-V. For example, `!19 = !{ptr @k_float_controls_float, i32 6028, float poison, i32 131079}` is lowered to `OpExecutionMode [[k_float_controls_float_ID]] FPFastMathDefault [[float_type_ID]] 131079`: 6028 is the `FPFastMathDefault` execution mode, the `float poison` operand only identifies the target type, and 131079 is the FPFastMathMode flag mask.
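To make the operand layout concrete, here is a minimal sketch that models one `!spirv.ExecutionMode` entry in plain C++ and renders it the way the example above is lowered. It is not the backend's implementation: `ExecModeMD`, `modeName`, `decodeFastMathMask` and `toSpirvText` are hypothetical helpers, SPIR-V IDs are printed as `%name` placeholders, and only the numeric values (execution modes 31, 4461 and 6028, and the FPFastMathMode bits) are taken from the SPIR-V specification.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical model of one !spirv.ExecutionMode entry:
// <entry point>, <mode id>, [target type], literal operands...
struct ExecModeMD {
  std::string Entry;              // name of the `ptr @...` operand
  uint32_t Mode;                  // the `i32` execution-mode id
  std::string TargetType;         // e.g. "float" for `float poison`; empty if absent
  std::vector<uint32_t> Literals; // trailing `i32` literal operands
};

// Execution-mode ids used in the example metadata (SPIR-V spec values).
std::string modeName(uint32_t Mode) {
  switch (Mode) {
  case 31:
    return "ContractionOff";
  case 4461:
    return "SignedZeroInfNanPreserve";
  case 6028:
    return "FPFastMathDefault";
  default:
    return std::to_string(Mode);
  }
}

// Names the FPFastMathMode bits set in Mask; unknown bits are ignored.
std::string decodeFastMathMask(uint32_t Mask) {
  struct Bit {
    uint32_t Value;
    const char *Name;
  };
  static const Bit kBits[] = {
      {0x1, "NotNaN"},           {0x2, "NotInf"},
      {0x4, "NSZ"},              {0x8, "AllowRecip"},
      {0x10, "Fast"},            {0x10000, "AllowContract"},
      {0x20000, "AllowReassoc"}, {0x40000, "AllowTransform"}};
  if (Mask == 0)
    return "None";
  std::string Out;
  for (const Bit &B : kBits) {
    if (Mask & B.Value) {
      if (!Out.empty())
        Out += "|";
      Out += B.Name;
    }
  }
  return Out;
}

// Renders the entry as SPIR-V assembly-like text.
std::string toSpirvText(const ExecModeMD &MD) {
  std::string Out = "OpExecutionMode %" + MD.Entry + " " + modeName(MD.Mode);
  if (!MD.TargetType.empty())
    Out += " %" + MD.TargetType;
  for (uint32_t L : MD.Literals)
    Out += " " + std::to_string(L);
  return Out;
}
```

With this decoding, the mask 131079 (0x20007) in `!19` reads as `NotNaN|NotInf|NSZ|AllowReassoc`, while `!23` (`i32 4461, i32 32`) becomes `SignedZeroInfNanPreserve` with a literal target width of 32.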

---

Patch is 66.26 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146941.diff


19 Files Affected:

- (modified) llvm/docs/SPIRVUsage.rst (+29-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp (+174-8) 
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+25-4) 
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.h (+1-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+3-3) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+1-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp (+20-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.h (+2-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+3-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+299-20) 
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h (+34) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+2-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+4) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+57-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/capability-FloatControl2.ll (+1-1) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_float_controls2/decoration.ll (+92) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_float_controls2/exec_mode.ll (+88) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_float_controls2/exec_mode2.ll (+50) 


``````````diff
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 1f563fbfb725a..3f6fa241da8ac 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -218,7 +218,7 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
    * - ``SPV_INTEL_int4``
      - Adds support for 4-bit integer type, and allow this type to be used in cooperative matrices.
    * - ``SPV_KHR_float_controls2``
-     - Adds ability to specify the floating-point environment in shaders. It can be used on whole modules and individual instructions.
+     - Adds execution modes and decorations to control floating-point computations in both kernels and shaders. It can be used on whole modules and individual instructions.
 
 To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:
 
@@ -585,3 +585,31 @@ Group and Subgroup Operations
 For workgroup and subgroup operations, LLVM uses function calls to represent SPIR-V's
 group-based instructions. These builtins facilitate group synchronization, data sharing,
 and collective operations essential for efficient parallel computation.
+
+SPIR-V Instructions Mapped to LLVM Metadata
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Some SPIR-V instructions don't have a direct equivalent in the LLVM IR language. To
+address this, the SPIR-V Target uses specific LLVM named metadata to convey the
+necessary information. The SPIR-V specification allows multiple module-scope
+instructions of the same kind, whereas LLVM named metadata must be unique. Therefore,
+such instructions are encoded in the following format:
+
+.. code-block:: llvm
+
+  !spirv.<OpCodeName> = !{!<InstructionMetadata1>, !<InstructionMetadata2>, ..}
+  !<InstructionMetadata1> = !{<Operand1>, <Operand2>, ..}
+  !<InstructionMetadata2> = !{<Operand1>, <Operand2>, ..}
+
+Below, you will find the mappings between SPIR-V instructions and their corresponding
+LLVM IR representations.
+
++--------------------+---------------------------------------------------------+
+| SPIR-V instruction | LLVM IR                                                 |
++====================+=========================================================+
+| OpExecutionMode    | .. code-block:: llvm                                    |
+|                    |                                                         |
+|                    |    !spirv.ExecutionMode = !{!0}                         |
+|                    |    !0 = !{ptr @worker, i32 30, i32 262149}              |
+|                    |    ; Set execution mode with id 30 (VecTypeHint) and    |
+|                    |    ; literal `262149` operand.                          |
++--------------------+---------------------------------------------------------+
diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
index 1ebfde2a603b9..24e4e390f98f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -80,6 +80,7 @@ class SPIRVAsmPrinter : public AsmPrinter {
   void outputExecutionMode(const Module &M);
   void outputAnnotations(const Module &M);
   void outputModuleSections();
+  void outputFPFastMathDefaultInfo();
   bool isHidden() {
     return MF->getFunction()
         .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
@@ -497,11 +498,27 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
   NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
   if (Node) {
     for (unsigned i = 0; i < Node->getNumOperands(); i++) {
+      // If SPV_KHR_float_controls2 is enabled, skip the FPFastMathDefault,
+      // ContractionOff and SignedZeroInfNanPreserve execution modes here; they
+      // are handled later by outputFPFastMathDefaultInfo().
+      if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
+        const auto EM =
+            cast<ConstantInt>(
+                cast<ConstantAsMetadata>((Node->getOperand(i))->getOperand(1))
+                    ->getValue())
+                ->getZExtValue();
+        if (EM == SPIRV::ExecutionMode::FPFastMathDefault ||
+            EM == SPIRV::ExecutionMode::ContractionOff ||
+            EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve)
+          continue;
+      }
+
       MCInst Inst;
       Inst.setOpcode(SPIRV::OpExecutionMode);
       addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI);
       outputMCInst(Inst);
     }
+    outputFPFastMathDefaultInfo();
   }
   for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
     const Function &F = *FI;
@@ -551,12 +568,85 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
     }
     if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") &&
         !M.getNamedMetadata("opencl.enable.FP_CONTRACT")) {
-      MCInst Inst;
-      Inst.setOpcode(SPIRV::OpExecutionMode);
-      Inst.addOperand(MCOperand::createReg(FReg));
-      unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
-      Inst.addOperand(MCOperand::createImm(EM));
-      outputMCInst(Inst);
+      if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
+        // When SPV_KHR_float_controls2 is enabled, ContractionOff is
+        // deprecated. We need to use FPFastMathDefault with the appropriate
+        // flags instead. Since FPFastMathDefault takes a target type, we need
+        // to emit it for each floating-point type to match the effect of
+        // ContractionOff. As of now, there are 4 FP types: fp16, fp32, fp64 and
+        // fp128.
+        constexpr size_t NumFPTypes = 4;
+        for (size_t i = 0; i < NumFPTypes; ++i) {
+          MCInst Inst;
+          Inst.setOpcode(SPIRV::OpExecutionMode);
+          Inst.addOperand(MCOperand::createReg(FReg));
+          unsigned EM =
+              static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault);
+          Inst.addOperand(MCOperand::createImm(EM));
+
+          Type *TargetType = nullptr;
+          switch (i) {
+          case 0:
+            TargetType = Type::getHalfTy(M.getContext());
+            break;
+          case 1:
+            TargetType = Type::getFloatTy(M.getContext());
+            break;
+          case 2:
+            TargetType = Type::getDoubleTy(M.getContext());
+            break;
+          case 3:
+            TargetType = Type::getFP128Ty(M.getContext());
+            break;
+          }
+          assert(TargetType && "Invalid target type for FPFastMathDefault");
+
+          // Find the SPIRV type matching the target type. We'll go over all the
+          // TypeConstVars instructions in the SPIRV module and find the one
+          // that matches the target type. We know the target type is a
+          // floating-point type, so we can skip anything different than
+          // OpTypeFloat. Then, we need to check the bitwidth.
+          bool SPIRVTypeFound = false;
+          for (const MachineInstr *MI :
+               MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
+            // Skip if the instruction is not OpTypeFloat.
+            if (MI->getOpcode() != SPIRV::OpTypeFloat)
+              continue;
+
+            // Skip if TargetTy bitwidth doesn't match MI->getOperand(1), which
+            // is the SPIRV type bit width.
+            if (TargetType->getScalarSizeInBits() != MI->getOperand(1).getImm())
+              continue;
+
+            SPIRVTypeFound = true;
+            const MachineFunction *MF = MI->getMF();
+            MCRegister TypeReg =
+                MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
+            Inst.addOperand(MCOperand::createReg(TypeReg));
+          }
+
+          if (!SPIRVTypeFound) {
+            // The module does not contain this FP type, so we don't need to
+            // emit FPFastMathDefault for it.
+            continue;
+          }
+          // We only end up here because there is no "spirv.ExecutionMode"
+          // metadata, hence no FPFastMathDefault either, so all flags,
+          // including AllowContract, must be 0. We still need to emit the
+          // OpExecutionMode instruction; otherwise, it is up to the client API
+          // to define the flags.
+          Inst.addOperand(MCOperand::createImm(SPIRV::FPFastMathMode::None));
+          outputMCInst(Inst);
+        }
+      } else {
+        MCInst Inst;
+        Inst.setOpcode(SPIRV::OpExecutionMode);
+        Inst.addOperand(MCOperand::createReg(FReg));
+        unsigned EM =
+            static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
+        Inst.addOperand(MCOperand::createImm(EM));
+        outputMCInst(Inst);
+      }
     }
   }
 }
@@ -603,6 +693,80 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
   }
 }
 
+void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() {
+  for (const auto &[Func, FPFastMathDefaultInfoVec] :
+       MAI->FPFastMathDefaultInfoMap) {
+    for (const auto &FPFastMathDefaultInfo : FPFastMathDefaultInfoVec) {
+      MCInst Inst;
+      Inst.setOpcode(SPIRV::OpExecutionMode);
+      MCRegister FuncReg = MAI->getFuncReg(Func);
+      assert(FuncReg.isValid());
+      Inst.addOperand(MCOperand::createReg(FuncReg));
+      Inst.addOperand(
+          MCOperand::createImm(SPIRV::ExecutionMode::FPFastMathDefault));
+
+      // Find the SPIRV type matching the target type. We'll go over all the
+      // TypeConstVars instructions in the SPIRV module and find the one that
+      // matches the target type. We know the target type is a floating-point
+      // type, so we can skip anything different than OpTypeFloat. Then, we
+      // need to check the bitwidth.
+      const Type *TargetTy = FPFastMathDefaultInfo.Ty;
+      assert(TargetTy && "Expected target type");
+      bool SPIRVTypeFound = false;
+      for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
+        // Skip if the instruction is not OpTypeFloat.
+        if (MI->getOpcode() != SPIRV::OpTypeFloat)
+          continue;
+
+        // Skip if TargetTy bitwidth doesn't match MI->getOperand(1), which is
+        // the SPIRV type bit width.
+        if (TargetTy->getScalarSizeInBits() != MI->getOperand(1).getImm())
+          continue;
+
+        SPIRVTypeFound = true;
+        const MachineFunction *MF = MI->getMF();
+        MCRegister TypeReg =
+            MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
+        Inst.addOperand(MCOperand::createReg(TypeReg));
+      }
+      if (!SPIRVTypeFound) {
+        // The module does not contain this FP type, so we don't need to emit
+        // FPFastMathDefault for it.
+        continue;
+      }
+
+      unsigned Flags = FPFastMathDefaultInfo.FastMathFlags;
+      if (FPFastMathDefaultInfo.ContractionOff &&
+          (Flags & SPIRV::FPFastMathMode::AllowContract) &&
+          FPFastMathDefaultInfo.FPFastMathDefault)
+        report_fatal_error(
+            "Conflicting FPFastMathFlags: ContractionOff and AllowContract");
+
+      if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
+          !(Flags &
+            (SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
+             SPIRV::FPFastMathMode::NSZ))) {
+        if (FPFastMathDefaultInfo.FPFastMathDefault)
+          report_fatal_error("Conflicting FPFastMathFlags: "
+                             "SignedZeroInfNanPreserve but at least one of "
+                             "NotNaN/NotInf/NSZ is disabled.");
+
+        Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
+                 SPIRV::FPFastMathMode::NSZ;
+      }
+
+      // Don't emit if none of the execution modes was used.
+      if (Flags == SPIRV::FPFastMathMode::None &&
+          !FPFastMathDefaultInfo.ContractionOff &&
+          !FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
+          !FPFastMathDefaultInfo.FPFastMathDefault)
+        continue;
+      Inst.addOperand(MCOperand::createImm(Flags));
+      outputMCInst(Inst);
+    }
+  }
+}
+
 void SPIRVAsmPrinter::outputModuleSections() {
   const Module *M = MMI->getModule();
   // Get the global subtarget to output module-level info.
@@ -611,7 +775,8 @@ void SPIRVAsmPrinter::outputModuleSections() {
   MAI = &SPIRVModuleAnalysis::MAI;
   assert(ST && TII && MAI && M && "Module analysis is required");
   // Output instructions according to the Logical Layout of a Module:
-  // 1,2. All OpCapability instructions, then optional OpExtension instructions.
+  // 1,2. All OpCapability instructions, then optional OpExtension
+  // instructions.
   outputGlobalRequirements();
   // 3. Optional OpExtInstImport instructions.
   outputOpExtInstImports(*M);
@@ -619,7 +784,8 @@ void SPIRVAsmPrinter::outputModuleSections() {
   outputOpMemoryModel();
   // 5. All entry point declarations, using OpEntryPoint.
   outputEntryPoints();
-  // 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId.
+  // 6. Execution-mode declarations, using OpExecutionMode or
+  // OpExecutionModeId.
   outputExecutionMode(*M);
   // 7a. Debug: all OpString, OpSourceExtension, OpSource, and
   // OpSourceContinued, without forward references.
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 6ec7544767c52..280a0197513c0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -697,7 +697,8 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
                                  MachineIRBuilder &MIRBuilder,
                                  SPIRVGlobalRegistry *GR) {
   if (Call->isSpirvOp())
-    return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
+    return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call,
+                              Register(0));
 
   Register ScopeRegister =
       buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
@@ -1125,11 +1126,24 @@ static unsigned getNumSizeComponents(SPIRVType *imgType) {
 
 static bool generateExtInst(const SPIRV::IncomingCall *Call,
                             MachineIRBuilder &MIRBuilder,
-                            SPIRVGlobalRegistry *GR) {
+                            SPIRVGlobalRegistry *GR, const CallBase &CB) {
   // Lookup the extended instruction number in the TableGen records.
   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
   uint32_t Number =
       SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set)->Number;
+  // fmin_common and fmax_common are now deprecated; use fmin and fmax with
+  // the NotInf and NotNaN flags instead. Keep the original number so that the
+  // NoNans and NoInfs flags can be added later.
+  uint32_t OrigNumber = Number;
+  const SPIRVSubtarget &ST =
+      cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+  if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2) &&
+      (Number == SPIRV::OpenCLExtInst::fmin_common ||
+       Number == SPIRV::OpenCLExtInst::fmax_common)) {
+    Number = (Number == SPIRV::OpenCLExtInst::fmin_common)
+                 ? SPIRV::OpenCLExtInst::fmin
+                 : SPIRV::OpenCLExtInst::fmax;
+  }
 
   // Build extended instruction.
   auto MIB =
@@ -1141,6 +1155,13 @@ static bool generateExtInst(const SPIRV::IncomingCall *Call,
 
   for (auto Argument : Call->Arguments)
     MIB.addUse(Argument);
+  MIB.getInstr()->copyIRFlags(CB);
+  if (OrigNumber == SPIRV::OpenCLExtInst::fmin_common ||
+      OrigNumber == SPIRV::OpenCLExtInst::fmax_common) {
+    // Add NoNans and NoInfs flags to fmin/fmax instruction.
+    MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoNans);
+    MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoInfs);
+  }
   return true;
 }
 
@@ -2844,7 +2865,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  MachineIRBuilder &MIRBuilder,
                                  const Register OrigRet, const Type *OrigRetTy,
                                  const SmallVectorImpl<Register> &Args,
-                                 SPIRVGlobalRegistry *GR) {
+                                 SPIRVGlobalRegistry *GR, const CallBase &CB) {
   LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
 
   // Lookup the builtin in the TableGen records.
@@ -2867,7 +2888,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
   // Match the builtin with implementation based on the grouping.
   switch (Call->Builtin->Group) {
   case SPIRV::Extended:
-    return generateExtInst(Call.get(), MIRBuilder, GR);
+    return generateExtInst(Call.get(), MIRBuilder, GR, CB);
   case SPIRV::Relational:
     return generateRelationalInst(Call.get(), MIRBuilder, GR);
   case SPIRV::Group:
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 1a8641a8328dd..f6a5234cd3c73 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -39,7 +39,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  MachineIRBuilder &MIRBuilder,
                                  const Register OrigRet, const Type *OrigRetTy,
                                  const SmallVectorImpl<Register> &Args,
-                                 SPIRVGlobalRegistry *GR);
+                                 SPIRVGlobalRegistry *GR, const CallBase &CB);
 
 /// Helper function for finding a builtin function attributes
 /// by a demangled function name. Defined in SPIRVBuiltins.cpp.
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index a412887e51adb..1a7c02c676465 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -641,9 +641,9 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
                                    GR->getPointerSize()));
       }
     }
-    if (auto Res =
-            SPIRV::lowerBuiltin(DemangledName, ST->getPreferredInstructionSet(),
-                                MIRBuilder, ResVReg, OrigRetTy, ArgVRegs, GR))
+    if (auto Res = SPIRV::lowerBuiltin(
+            DemangledName, ST->getPreferredInstructionSet(), MIRBuilder,
+            ResVReg, OrigRetTy, ArgVRegs, GR, *Info.CB))
       return *Res;
   }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 83fccdc2bdba3..bc275b09674be 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -794,7 +794,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
   // arguments.
   MDNode *GVarMD = nullptr;
   if (GVar && (GVarMD = GVar->getMetadata("spirv.Decorations")) != nullptr)
-    buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD);
+    buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD, ST);
 
   return Reg;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
index f658b67a4c2a5..357aab2f580c9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
@@ -130,7 +130,8 @@ bool SPIRVInstrInfo::isHeaderInstr(const MachineInstr &MI) const {
   }
 }
 
-bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI) const {
+bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI,
+                                         bool KHRFloatControls2) const {
   switch (MI.getOpcode()) {
   case SPIRV::OpFAddS:
   case SPIRV::OpFSubS:
@@ -144,6 +145,24 @@ bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI) const {
   case SPIRV::OpFRemV:
   case SPIRV::OpFMod:
     return true;
+  case SPIRV::OpFNegateV:
+  case SPIRV::OpFNegate:
+  case SPIRV::OpOrdered:
+  case SPIRV::OpUnordered:
+  case SPIRV::OpFOrdEqual:
+  case SPIRV::OpFOrdNotEqual:
+  case SPIRV::OpFOrdLessThan:
+  case SPIRV::OpFOrdLessThanEqual:
+  case SPIRV::OpFOrdGreaterThan:
+  case SPIRV::OpFOrdGreaterThanEqual:
+  case SPIRV::OpFUnordEqual:
+  case SPIRV::OpFUnordNotEqual:
+  case SPIRV::OpFUnordLessThan:
+  case SPIRV::OpFUnordLessThanEqual:
+  case SPIRV::OpFUnordGreaterThan:
+  case SPIRV::OpFUnordGreaterThanEqual:
+  case SPIRV::OpExtInst:
+    return KHRFloatControls2;
   default:
     return false;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
in...
[truncated]

``````````
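The documentation table added by the patch sets `VecTypeHint` (execution mode 30) with the literal `262149`. As a hedged illustration, assuming the operand layout given in the SPIR-V specification (component count in the high-order 16 bits, data-type code in the low-order 16 bits, with codes 0 to 6 meaning 8/16/32/64-bit integer and 16/32/64-bit float), the sketch below decodes that literal. `VecTypeHint`, `decodeVecTypeHint` and `encodeVecTypeHint` are hypothetical helpers, not backend APIs.

```cpp
#include <cstdint>
#include <string>

// Decoded form of a VecTypeHint literal.
struct VecTypeHint {
  uint32_t NumComponents;  // high-order 16 bits
  std::string ElementType; // low-order 16 bits, mapped to a readable name
};

VecTypeHint decodeVecTypeHint(uint32_t Literal) {
  // Data-type codes 0..6: 8/16/32/64-bit integers, then 16/32/64-bit floats.
  static const char *const kTypes[] = {"i8",   "i16",   "i32",   "i64",
                                       "half", "float", "double"};
  const uint32_t TypeCode = Literal & 0xFFFFu;
  const uint32_t Count = Literal >> 16;
  return {Count, TypeCode < 7 ? kTypes[TypeCode] : "unknown"};
}

// Inverse of decodeVecTypeHint for a known type code.
uint32_t encodeVecTypeHint(uint32_t NumComponents, uint32_t TypeCode) {
  return (NumComponents << 16) | (TypeCode & 0xFFFFu);
}
```

Under that layout, `262149` is `(4 << 16) | 5`, i.e. a hint for a four-component 32-bit float vector.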

</details>
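The kernel fallback added to `SPIRVAsmPrinter::outputExecutionMode` replaces a single `ContractionOff` with one `FPFastMathDefault` (flags `None`) per floating-point width that the module actually declares. The sketch below is a simplified model of that decision, not the backend code: it assumes the only inputs that matter are whether `SPV_KHR_float_controls2` is usable and which `OpTypeFloat` widths exist, and `EmittedMode` and `defaultContractionModes` are hypothetical names.

```cpp
#include <cstdint>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for one emitted OpExecutionMode instruction.
struct EmittedMode {
  std::string Mode;   // "ContractionOff" or "FPFastMathDefault"
  uint32_t TypeWidth; // width of the OpTypeFloat operand; 0 if there is none
  uint32_t Flags;     // FPFastMathMode mask; only meaningful for FPFastMathDefault
};

// Models the fallback for kernels without "spirv.ExecutionMode" and
// "opencl.enable.FP_CONTRACT" metadata.
std::vector<EmittedMode>
defaultContractionModes(bool HasFloatControls2,
                        const std::set<uint32_t> &FloatWidthsInModule) {
  std::vector<EmittedMode> Out;
  if (!HasFloatControls2) {
    // Legacy path: a single ContractionOff without a type operand.
    Out.push_back({"ContractionOff", 0, 0});
    return Out;
  }
  // fp16, fp32, fp64 and fp128, in the order the patch visits them.
  static const uint32_t kWidths[] = {16, 32, 64, 128};
  for (uint32_t Width : kWidths) {
    // No OpTypeFloat of this width in the module: nothing to emit.
    if (FloatWidthsInModule.count(Width) == 0)
      continue;
    // Flags = None, so AllowContract (and every other flag) stays cleared.
    Out.push_back({"FPFastMathDefault", Width, 0});
  }
  return Out;
}
```

Skipping absent widths mirrors the patch's `SPIRVTypeFound` check: `FPFastMathDefault` takes a type `<id>` operand, so it can only be emitted for a float type that is present in the module.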


https://github.com/llvm/llvm-project/pull/146941
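In `generateExtInst`, the patch rewrites the deprecated `fmin_common`/`fmax_common` OpenCL.std instructions to `fmin`/`fmax` when `SPV_KHR_float_controls2` is usable, copies the call's IR flags, and tags the result with the `FmNoNans` and `FmNoInfs` machine-instruction flags. The following plain C++ model of that control flow is only a sketch: the `ExtInst` enum, `LoweredExtInst` and `lowerExtInst` are hypothetical, and the numeric OpenCL.std opcodes are intentionally left out.

```cpp
// Hypothetical enum: names follow OpenCL.std, numeric opcodes are omitted.
enum class ExtInst { fmin, fmax, fmin_common, fmax_common, other };

// Hypothetical summary of the lowered call: chosen instruction plus the two
// fast-math flags this sketch tracks.
struct LoweredExtInst {
  ExtInst Inst;
  bool NoNaNs;
  bool NoInfs;
};

LoweredExtInst lowerExtInst(ExtInst Requested, bool HasFloatControls2,
                            bool CallNoNaNs, bool CallNoInfs) {
  const bool IsCommon =
      Requested == ExtInst::fmin_common || Requested == ExtInst::fmax_common;
  // Start from the call's own flags (the analogue of copyIRFlags(CB)).
  LoweredExtInst R{Requested, CallNoNaNs, CallNoInfs};
  // Only with SPV_KHR_float_controls2 is the *_common form renamed.
  if (HasFloatControls2 && IsCommon)
    R.Inst = Requested == ExtInst::fmin_common ? ExtInst::fmin : ExtInst::fmax;
  // As in the diff, the flags are added whenever the original was *_common.
  if (IsCommon) {
    R.NoNaNs = true;
    R.NoInfs = true;
  }
  return R;
}
```

As written in the diff, the two flags are set whenever the original call was a `*_common` builtin, even when the extension is not enabled and the instruction itself keeps its deprecated name.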

