[llvm] [SPIR-V] Add saturation and float rounding mode decorations, a subset of arithmetic constrained floating-point intrinsics, and SPV_INTEL_float_controls2 extension (PR #119862)

via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 13 16:22:25 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR adds the following features and improvements:
* saturation and floating-point rounding mode decorations,
* a subset of the arithmetic constrained floating-point intrinsics (strict_fadd, strict_fsub, strict_fmul, strict_fdiv, strict_frem, strict_fma and strict_fldexp),
* the SPV_INTEL_float_controls2 extension,
* simplified pre- and post-legalizer steps and improved instruction selection, building on recent improvements to the emit-intrinsics step.

---

Patch is 43.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/119862.diff


18 Files Affected:

- (modified) llvm/docs/SPIRVUsage.rst (+2) 
- (modified) llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp (+4-5) 
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+18-5) 
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.h (+2-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+97-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+7-7) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+6) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+37-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+12-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+21-4) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp (+9-22) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+3-33) 
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+11) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_float_controls2/exec_mode_float_control_empty.ll (+18) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_float_controls2/exec_mode_float_control_intel.ll (+74) 
- (modified) llvm/test/CodeGen/SPIRV/instructions/integer-casts.ll (+28) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-arithmetic.ll (+44) 


``````````diff
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 8f7ac71f8026b3..b7b3d21545168c 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -159,6 +159,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
      - Adds instructions to convert between single-precision 32-bit floating-point values and 16-bit bfloat16 values.
    * - ``SPV_INTEL_cache_controls``
      - Allows cache control information to be applied to memory access instructions.
+   * - ``SPV_INTEL_float_controls2``
+     - Adds execution modes and decorations to control floating-point computations.
    * - ``SPV_INTEL_function_pointers``
      - Allows translation of function pointers.
    * - ``SPV_INTEL_inline_assembly``
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
index 42567f695395ef..68cc6a3a7aac1b 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
@@ -65,11 +65,10 @@ static bool hasType(const MCInst &MI, const MCInstrInfo &MII) {
   // If we define an output, and have at least one other argument.
   if (MCDesc.getNumDefs() == 1 && MCDesc.getNumOperands() >= 2) {
     // Check if we define an ID, and take a type as operand 1.
-    auto &DefOpInfo = MCDesc.operands()[0];
-    auto &FirstArgOpInfo = MCDesc.operands()[1];
-    return DefOpInfo.RegClass >= 0 && FirstArgOpInfo.RegClass >= 0 &&
-           DefOpInfo.RegClass != SPIRV::TYPERegClassID &&
-           FirstArgOpInfo.RegClass == SPIRV::TYPERegClassID;
+    return MCDesc.operands()[0].RegClass >= 0 &&
+           MCDesc.operands()[1].RegClass >= 0 &&
+           MCDesc.operands()[0].RegClass != SPIRV::TYPERegClassID &&
+           MCDesc.operands()[1].RegClass == SPIRV::TYPERegClassID;
   }
   return false;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index f4bfda4932b167..4bfa51e2cccdd8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -173,7 +173,8 @@ using namespace InstructionSet;
 
 namespace SPIRV {
 /// Parses the name part of the demangled builtin call.
-std::string lookupBuiltinNameHelper(StringRef DemangledCall) {
+std::string lookupBuiltinNameHelper(StringRef DemangledCall,
+                                    std::string *Postfix) {
   const static std::string PassPrefix = "(anonymous namespace)::";
   std::string BuiltinName;
   // Itanium Demangler result may have "(anonymous namespace)::" prefix
@@ -231,10 +232,13 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall) {
       "ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|"
       "SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|"
       "Convert|"
-      "UConvert|SConvert|FConvert|SatConvert).*)_R.*");
+      "UConvert|SConvert|FConvert|SatConvert).*)_R(.*)");
   std::smatch Match;
-  if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 2)
+  if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 3) {
     BuiltinName = Match[1].str();
+    if (Postfix)
+      *Postfix = Match[3].str();
+  }
 
   return BuiltinName;
 }
@@ -583,6 +587,15 @@ static Register buildScopeReg(Register CLScopeRegister,
   return buildConstantIntReg32(Scope, MIRBuilder, GR);
 }
 
+static void setRegClassIfNull(Register Reg, MachineRegisterInfo *MRI,
+                              SPIRVGlobalRegistry *GR) {
+  if (MRI->getRegClassOrNull(Reg))
+    return;
+  SPIRVType *SpvType = GR->getSPIRVTypeForVReg(Reg);
+  MRI->setRegClass(Reg,
+                   SpvType ? GR->getRegClass(SpvType) : &SPIRV::iIDRegClass);
+}
+
 static Register buildMemSemanticsReg(Register SemanticsRegister,
                                      Register PtrRegister, unsigned &Semantics,
                                      MachineIRBuilder &MIRBuilder,
@@ -1160,7 +1173,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
         MIRBuilder.buildInstr(TargetOpcode::G_BUILD_VECTOR).addDef(VecReg);
     for (unsigned i = 1; i < Call->Arguments.size(); i++) {
       MIB.addUse(Call->Arguments[i]);
-      MRI->setRegClass(Call->Arguments[i], &SPIRV::iIDRegClass);
+      setRegClassIfNull(Call->Arguments[i], MRI, GR);
     }
     insertAssignInstr(VecReg, nullptr, VecType, GR, MIRBuilder,
                       MIRBuilder.getMF().getRegInfo());
@@ -1176,7 +1189,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
     MIB.addImm(GroupBuiltin->GroupOperation);
   if (Call->Arguments.size() > 0) {
     MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
-    MRI->setRegClass(Call->Arguments[0], &SPIRV::iIDRegClass);
+    setRegClassIfNull(Call->Arguments[0], MRI, GR);
     if (VecReg.isValid())
       MIB.addUse(VecReg);
     else
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 42b452db8b9fb4..0182d9652d18c9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -20,7 +20,8 @@
 namespace llvm {
 namespace SPIRV {
 /// Parses the name part of the demangled builtin call.
-std::string lookupBuiltinNameHelper(StringRef DemangledCall);
+std::string lookupBuiltinNameHelper(StringRef DemangledCall,
+                                    std::string *Postfix = nullptr);
 /// Lowers a builtin function call using the provided \p DemangledCall skeleton
 /// and external instruction \p Set.
 ///
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index fb05c1fdbd1e3b..45b39c51164795 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -36,6 +36,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
          SPIRV::Extension::Extension::SPV_INTEL_arbitrary_precision_integers},
         {"SPV_INTEL_cache_controls",
          SPIRV::Extension::Extension::SPV_INTEL_cache_controls},
+        {"SPV_INTEL_float_controls2",
+         SPIRV::Extension::Extension::SPV_INTEL_float_controls2},
         {"SPV_INTEL_global_variable_fpga_decorations",
          SPIRV::Extension::Extension::
              SPV_INTEL_global_variable_fpga_decorations},
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 2b623136e602e5..433956f44917fb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -216,6 +216,8 @@ class SPIRVEmitIntrinsics
   bool processFunctionPointers(Module &M);
   void parseFunDeclarations(Module &M);
 
+  void useRoundingMode(ConstrainedFPIntrinsic *FPI, IRBuilder<> &B);
+
 public:
   static char ID;
   SPIRVEmitIntrinsics() : ModulePass(ID) {
@@ -1291,6 +1293,37 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
   }
 }
 
+static void createDecorationIntrinsic(Instruction *I, MDNode *Node,
+                                      IRBuilder<> &B) {
+  LLVMContext &Ctx = I->getContext();
+  setInsertPointAfterDef(B, I);
+  B.CreateIntrinsic(Intrinsic::spv_assign_decoration, {I->getType()},
+                    {I, MetadataAsValue::get(Ctx, MDNode::get(Ctx, {Node}))});
+}
+
+static void createRoundingModeDecoration(Instruction *I,
+                                         unsigned RoundingModeDeco,
+                                         IRBuilder<> &B) {
+  LLVMContext &Ctx = I->getContext();
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+  MDNode *RoundingModeNode = MDNode::get(
+      Ctx,
+      {ConstantAsMetadata::get(
+           ConstantInt::get(Int32Ty, SPIRV::Decoration::FPRoundingMode)),
+       ConstantAsMetadata::get(ConstantInt::get(Int32Ty, RoundingModeDeco))});
+  createDecorationIntrinsic(I, RoundingModeNode, B);
+}
+
+static void createSaturatedConversionDecoration(Instruction *I,
+                                                IRBuilder<> &B) {
+  LLVMContext &Ctx = I->getContext();
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+  MDNode *SaturatedConversionNode =
+      MDNode::get(Ctx, {ConstantAsMetadata::get(ConstantInt::get(
+                           Int32Ty, SPIRV::Decoration::SaturatedConversion))});
+  createDecorationIntrinsic(I, SaturatedConversionNode, B);
+}
+
 Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
   if (!Call.isInlineAsm())
     return &Call;
@@ -1312,6 +1345,40 @@ Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
   return &Call;
 }
 
+// Use a tip about rounding mode to create a decoration.
+void SPIRVEmitIntrinsics::useRoundingMode(ConstrainedFPIntrinsic *FPI,
+                                          IRBuilder<> &B) {
+  std::optional<RoundingMode> RM = FPI->getRoundingMode();
+  if (!RM.has_value())
+    return;
+  unsigned RoundingModeDeco = std::numeric_limits<unsigned>::max();
+  switch (RM.value()) {
+  default:
+    // ignore unknown rounding modes
+    break;
+  case RoundingMode::NearestTiesToEven:
+    RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTE;
+    break;
+  case RoundingMode::TowardNegative:
+    RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTN;
+    break;
+  case RoundingMode::TowardPositive:
+    RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTP;
+    break;
+  case RoundingMode::TowardZero:
+    RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTZ;
+    break;
+  case RoundingMode::Dynamic:
+  case RoundingMode::NearestTiesToAway:
+    // TODO: check if supported
+    break;
+  }
+  if (RoundingModeDeco == std::numeric_limits<unsigned>::max())
+    return;
+  // Convert the tip about rounding mode into a decoration record.
+  createRoundingModeDecoration(FPI, RoundingModeDeco, B);
+}
+
 Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
   BasicBlock *ParentBB = I.getParent();
   IRBuilder<> B(ParentBB);
@@ -1809,6 +1876,18 @@ bool SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
   return true;
 }
 
+static unsigned roundingModeMDToDecorationConst(StringRef S) {
+  if (S == "rte")
+    return SPIRV::FPRoundingMode::FPRoundingMode::RTE;
+  if (S == "rtz")
+    return SPIRV::FPRoundingMode::FPRoundingMode::RTZ;
+  if (S == "rtp")
+    return SPIRV::FPRoundingMode::FPRoundingMode::RTP;
+  if (S == "rtn")
+    return SPIRV::FPRoundingMode::FPRoundingMode::RTN;
+  return std::numeric_limits<unsigned>::max();
+}
+
 void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
                                                 IRBuilder<> &B) {
   // TODO: extend the list of functions with known result types
@@ -1826,8 +1905,9 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
       Function *CalledF = CI->getCalledFunction();
       std::string DemangledName =
           getOclOrSpirvBuiltinDemangledName(CalledF->getName());
+      std::string Postfix;
       if (DemangledName.length() > 0)
-        DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName);
+        DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName, &Postfix);
       auto ResIt = ResTypeWellKnown.find(DemangledName);
       if (ResIt != ResTypeWellKnown.end()) {
         IsKnown = true;
@@ -1839,6 +1919,19 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
           break;
         }
       }
+      // check if a floating rounding mode info is present
+      StringRef S = Postfix;
+      SmallVector<StringRef, 8> Parts;
+      S.split(Parts, "_", -1, false);
+      if (Parts.size() > 1) {
+        // Convert the info about rounding mode into a decoration record.
+        unsigned RoundingModeDeco = roundingModeMDToDecorationConst(Parts[1]);
+        if (RoundingModeDeco != std::numeric_limits<unsigned>::max())
+          createRoundingModeDecoration(CI, RoundingModeDeco, B);
+        // Check if the SaturatedConversion info is present.
+        if (Parts[1] == "sat")
+          createSaturatedConversionDecoration(CI, B);
+      }
     }
   }
 
@@ -2264,6 +2357,9 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     // already, and force it to be i8 if not
     if (Postpone && !GR->findAssignPtrTypeInstr(I))
       insertAssignPtrTypeIntrs(I, B, true);
+
+    if (auto *FPI = dyn_cast<ConstrainedFPIntrinsic>(I))
+      useRoundingMode(FPI, B);
   }
 
   // Pass backward: use instructions results to specify/update/cast operands
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 5f72a41ddb8647..3e913646d57c80 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -126,14 +126,14 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
   Width = adjustOpTypeIntWidth(Width);
   const SPIRVSubtarget &ST =
       cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
-  if (ST.canUseExtension(
-          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
-    MIRBuilder.buildInstr(SPIRV::OpExtension)
-        .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
-    MIRBuilder.buildInstr(SPIRV::OpCapability)
-        .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
-  }
   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    if (ST.canUseExtension(
+            SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
+      MIRBuilder.buildInstr(SPIRV::OpExtension)
+          .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
+      MIRBuilder.buildInstr(SPIRV::OpCapability)
+          .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
+    }
     return MIRBuilder.buildInstr(SPIRV::OpTypeInt)
         .addDef(createTypeVReg(MIRBuilder))
         .addImm(Width)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index d95803fea56a58..1bc35c6e57a4f6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -491,16 +491,20 @@ def OpFNegate: UnOpTyped<"OpFNegate", 127, fID, fneg>;
 def OpFNegateV: UnOpTyped<"OpFNegate", 127, vfID, fneg>;
 defm OpIAdd: BinOpTypedGen<"OpIAdd", 128, add, 0, 1>;
 defm OpFAdd: BinOpTypedGen<"OpFAdd", 129, fadd, 1, 1>;
+defm OpStrictFAdd: BinOpTypedGen<"OpFAdd", 129, strict_fadd, 1, 1>;
 
 defm OpISub: BinOpTypedGen<"OpISub", 130, sub, 0, 1>;
 defm OpFSub: BinOpTypedGen<"OpFSub", 131, fsub, 1, 1>;
+defm OpStrictFSub: BinOpTypedGen<"OpFSub", 131, strict_fsub, 1, 1>;
 
 defm OpIMul: BinOpTypedGen<"OpIMul", 132, mul, 0, 1>;
 defm OpFMul: BinOpTypedGen<"OpFMul", 133, fmul, 1, 1>;
+defm OpStrictFMul: BinOpTypedGen<"OpFMul", 133, strict_fmul, 1, 1>;
 
 defm OpUDiv: BinOpTypedGen<"OpUDiv", 134, udiv, 0, 1>;
 defm OpSDiv: BinOpTypedGen<"OpSDiv", 135, sdiv, 0, 1>;
 defm OpFDiv: BinOpTypedGen<"OpFDiv", 136, fdiv, 1, 1>;
+defm OpStrictFDiv: BinOpTypedGen<"OpFDiv", 136, strict_fdiv, 1, 1>;
 
 defm OpUMod: BinOpTypedGen<"OpUMod", 137, urem, 0, 1>;
 defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 0, 1>;
@@ -508,6 +512,8 @@ defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 0, 1>;
 def OpSMod: BinOp<"OpSMod", 139>;
 
 defm OpFRem: BinOpTypedGen<"OpFRem", 140, frem, 1, 1>;
+defm OpStrictFRem: BinOpTypedGen<"OpFRem", 140, strict_frem, 1, 1>;
+
 def OpFMod: BinOp<"OpFMod", 141>;
 
 def OpVectorTimesScalar: BinOp<"OpVectorTimesScalar", 142>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b64030508cfc11..856caf2074fba4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -61,6 +61,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
   /// We need to keep track of the number we give to anonymous global values to
   /// generate the same name every time when this is needed.
   mutable DenseMap<const GlobalValue *, unsigned> UnnamedGlobalIDs;
+  SmallPtrSet<MachineInstr *, 8> DeadMIs;
 
 public:
   SPIRVInstructionSelector(const SPIRVTargetMachine &TM,
@@ -382,6 +383,24 @@ static bool isImm(const MachineOperand &MO, MachineRegisterInfo *MRI);
 // Defined in SPIRVLegalizerInfo.cpp.
 extern bool isTypeFoldingSupported(unsigned Opcode);
 
+bool isDead(const MachineInstr &MI, const MachineRegisterInfo &MRI) {
+  for (const auto &MO : MI.all_defs()) {
+    Register Reg = MO.getReg();
+    if (Reg.isPhysical() || !MRI.use_nodbg_empty(Reg))
+      return false;
+  }
+  if (MI.getOpcode() == TargetOpcode::LOCAL_ESCAPE || MI.isFakeUse() ||
+      MI.isLifetimeMarker())
+    return false;
+  if (MI.isPHI())
+    return true;
+  if (MI.mayStore() || MI.isCall() ||
+      (MI.mayLoad() && MI.hasOrderedMemoryRef()) || MI.isPosition() ||
+      MI.isDebugInstr() || MI.isTerminator() || MI.isJumpTableDebugInfo())
+    return false;
+  return true;
+}
+
 bool SPIRVInstructionSelector::select(MachineInstr &I) {
   resetVRegsType(*I.getParent()->getParent());
 
@@ -404,8 +423,11 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
           }
         });
         assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT);
-        if (Res)
+        if (Res) {
+          if (!isTriviallyDead(*Def, *MRI) && isDead(*Def, *MRI))
+            DeadMIs.insert(Def);
           return Res;
+        }
       }
       MRI->setRegClass(SrcReg, MRI->getRegClass(DstReg));
       MRI->replaceRegWith(SrcReg, DstReg);
@@ -418,6 +440,15 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
     return constrainSelectedInstRegOperands(I, TII, TRI, RBI);
   }
 
+  if (DeadMIs.contains(&I)) {
+    // if the instruction has been already made dead by folding it away
+    // erase it
+    LLVM_DEBUG(dbgs() << "Instruction is folded and dead.\n");
+    salvageDebugInfo(*MRI, I);
+    I.eraseFromParent();
+    return true;
+  }
+
   if (I.getNumOperands() != I.getNumExplicitOperands()) {
     LLVM_DEBUG(errs() << "Generic instr has unexpected implicit operands\n");
     return false;
@@ -557,9 +588,13 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_UCMP:
     return selectSUCmp(ResVReg, ResType, I, false);
 
+  case TargetOpcode::G_STRICT_FMA:
   case TargetOpcode::G_FMA:
     return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);
 
+  case TargetOpcode::G_STRICT_FLDEXP:
+    return selectExtInst(ResVReg, ResType, I, CL::ldexp);
+
   case TargetOpcode::G_FPOW:
     return selectExtInst(ResVReg, ResType, I, CL::pow, GL::Pow);
   case TargetOpcode::G_FPOWI:
@@ -618,6 +653,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_FTANH:
     return selectExtInst(ResVReg, ResType, I, CL::tanh, GL::Tanh);
 
+  case TargetOpcode::G_STRICT_FSQRT:
   case TargetOpcode::G_FSQRT:
     return selectExtInst(ResVReg, ResType, I, CL::sqrt, GL::Sqrt);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 7230e0e6b9fca1..b22027cd2cb931 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -24,19 +24,25 @@ using namespace llvm;
 using namespace llvm::LegalizeActions;
 using namespace llvm::LegalityPredicates;
 
+// clang-format off
 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
     TargetOpcode::G_ADD,
     TargetOpcode::G_FADD,
+    TargetOpcode::G_STRICT_FADD,
     TargetOpcode::G_SUB,
     TargetOpcode::G_FSUB,
+    TargetOpcode::G_STRICT_FSUB,
     TargetOpcode::G_MUL,
     TargetOpcode::G_FMUL,
+    TargetOpcode::G_STRICT_FMUL,
     TargetOpcode::G_SDIV,
     TargetOpcode::G_UDIV,
     TargetOpcode::G_FDIV,
+    TargetOpcode::G_STRICT_FDIV,
     TargetOpcode::G_SREM,
     TargetOpcode::G_UREM,
     TargetOpcode::G_FREM,
+    TargetOpcode::G_STRICT_FREM,
     TargetOpcode::G_FNEG,
     TargetOpcode::G_CONSTANT,
     TargetOpcode::G_FCONSTANT,
@@ -49,6 +55,7 @@ static const std::set<unsigned> TypeFoldingSupportingOpcs = {
     TargetOpcode::G_SELECT,
     TargetOpcode::G_EXTRACT_VECTOR_ELT,
 };
+// clang-format on
 
 bool isTypeFoldingSupported(unsigned Opcode) {
   return TypeFoldingSupportingOpcs.count(Opcode) > 0;
@@ -219,7 +226,11 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       .legalFor(allIntScalarsAndVectors)
       .legalIf(extendedScalarsAndVectors);
 
-  getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/119862


More information about the llvm-commits mailing list