[llvm] [SPIR-V] Add saturation and float rounding mode decorations, a subset of arithmetic constrained floating-point intrinsics, and SPV_INTEL_float_controls2 extension (PR #119862)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 13 16:22:25 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)
<details>
<summary>Changes</summary>
This PR adds the following features:
* saturation and float rounding mode decorations,
* a subset of the arithmetic constrained floating-point intrinsics (strict_fadd, strict_fsub, strict_fmul, strict_fdiv, strict_frem, strict_fma and strict_fldexp),
* and the SPV_INTEL_float_controls2 extension.
* Building on recent improvements to the emit-intrinsics step, this PR also simplifies the pre- and post-legalizer steps and improves instruction selection.
---
The patch is 43.96 KiB and is truncated to 20.00 KiB below; the full version is available at https://github.com/llvm/llvm-project/pull/119862.diff
18 Files Affected:
- (modified) llvm/docs/SPIRVUsage.rst (+2)
- (modified) llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp (+4-5)
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+18-5)
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.h (+2-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+2)
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+97-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+7-7)
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+6)
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+37-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+12-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+21-4)
- (modified) llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp (+9-22)
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+3-33)
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+11)
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_float_controls2/exec_mode_float_control_empty.ll (+18)
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_float_controls2/exec_mode_float_control_intel.ll (+74)
- (modified) llvm/test/CodeGen/SPIRV/instructions/integer-casts.ll (+28)
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-arithmetic.ll (+44)
``````````diff
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 8f7ac71f8026b3..b7b3d21545168c 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -159,6 +159,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
- Adds instructions to convert between single-precision 32-bit floating-point values and 16-bit bfloat16 values.
* - ``SPV_INTEL_cache_controls``
- Allows cache control information to be applied to memory access instructions.
+ * - ``SPV_INTEL_float_controls2``
+ - Adds execution modes and decorations to control floating-point computations.
* - ``SPV_INTEL_function_pointers``
- Allows translation of function pointers.
* - ``SPV_INTEL_inline_assembly``
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
index 42567f695395ef..68cc6a3a7aac1b 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
@@ -65,11 +65,10 @@ static bool hasType(const MCInst &MI, const MCInstrInfo &MII) {
// If we define an output, and have at least one other argument.
if (MCDesc.getNumDefs() == 1 && MCDesc.getNumOperands() >= 2) {
// Check if we define an ID, and take a type as operand 1.
- auto &DefOpInfo = MCDesc.operands()[0];
- auto &FirstArgOpInfo = MCDesc.operands()[1];
- return DefOpInfo.RegClass >= 0 && FirstArgOpInfo.RegClass >= 0 &&
- DefOpInfo.RegClass != SPIRV::TYPERegClassID &&
- FirstArgOpInfo.RegClass == SPIRV::TYPERegClassID;
+ return MCDesc.operands()[0].RegClass >= 0 &&
+ MCDesc.operands()[1].RegClass >= 0 &&
+ MCDesc.operands()[0].RegClass != SPIRV::TYPERegClassID &&
+ MCDesc.operands()[1].RegClass == SPIRV::TYPERegClassID;
}
return false;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index f4bfda4932b167..4bfa51e2cccdd8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -173,7 +173,8 @@ using namespace InstructionSet;
namespace SPIRV {
/// Parses the name part of the demangled builtin call.
-std::string lookupBuiltinNameHelper(StringRef DemangledCall) {
+std::string lookupBuiltinNameHelper(StringRef DemangledCall,
+ std::string *Postfix) {
const static std::string PassPrefix = "(anonymous namespace)::";
std::string BuiltinName;
// Itanium Demangler result may have "(anonymous namespace)::" prefix
@@ -231,10 +232,13 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall) {
"ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|"
"SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|"
"Convert|"
- "UConvert|SConvert|FConvert|SatConvert).*)_R.*");
+ "UConvert|SConvert|FConvert|SatConvert).*)_R(.*)");
std::smatch Match;
- if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 2)
+ if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 3) {
BuiltinName = Match[1].str();
+ if (Postfix)
+ *Postfix = Match[3].str();
+ }
return BuiltinName;
}
@@ -583,6 +587,15 @@ static Register buildScopeReg(Register CLScopeRegister,
return buildConstantIntReg32(Scope, MIRBuilder, GR);
}
+static void setRegClassIfNull(Register Reg, MachineRegisterInfo *MRI,
+ SPIRVGlobalRegistry *GR) {
+ if (MRI->getRegClassOrNull(Reg))
+ return;
+ SPIRVType *SpvType = GR->getSPIRVTypeForVReg(Reg);
+ MRI->setRegClass(Reg,
+ SpvType ? GR->getRegClass(SpvType) : &SPIRV::iIDRegClass);
+}
+
static Register buildMemSemanticsReg(Register SemanticsRegister,
Register PtrRegister, unsigned &Semantics,
MachineIRBuilder &MIRBuilder,
@@ -1160,7 +1173,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
MIRBuilder.buildInstr(TargetOpcode::G_BUILD_VECTOR).addDef(VecReg);
for (unsigned i = 1; i < Call->Arguments.size(); i++) {
MIB.addUse(Call->Arguments[i]);
- MRI->setRegClass(Call->Arguments[i], &SPIRV::iIDRegClass);
+ setRegClassIfNull(Call->Arguments[i], MRI, GR);
}
insertAssignInstr(VecReg, nullptr, VecType, GR, MIRBuilder,
MIRBuilder.getMF().getRegInfo());
@@ -1176,7 +1189,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
MIB.addImm(GroupBuiltin->GroupOperation);
if (Call->Arguments.size() > 0) {
MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
- MRI->setRegClass(Call->Arguments[0], &SPIRV::iIDRegClass);
+ setRegClassIfNull(Call->Arguments[0], MRI, GR);
if (VecReg.isValid())
MIB.addUse(VecReg);
else
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 42b452db8b9fb4..0182d9652d18c9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -20,7 +20,8 @@
namespace llvm {
namespace SPIRV {
/// Parses the name part of the demangled builtin call.
-std::string lookupBuiltinNameHelper(StringRef DemangledCall);
+std::string lookupBuiltinNameHelper(StringRef DemangledCall,
+ std::string *Postfix = nullptr);
/// Lowers a builtin function call using the provided \p DemangledCall skeleton
/// and external instruction \p Set.
///
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index fb05c1fdbd1e3b..45b39c51164795 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -36,6 +36,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
SPIRV::Extension::Extension::SPV_INTEL_arbitrary_precision_integers},
{"SPV_INTEL_cache_controls",
SPIRV::Extension::Extension::SPV_INTEL_cache_controls},
+ {"SPV_INTEL_float_controls2",
+ SPIRV::Extension::Extension::SPV_INTEL_float_controls2},
{"SPV_INTEL_global_variable_fpga_decorations",
SPIRV::Extension::Extension::
SPV_INTEL_global_variable_fpga_decorations},
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 2b623136e602e5..433956f44917fb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -216,6 +216,8 @@ class SPIRVEmitIntrinsics
bool processFunctionPointers(Module &M);
void parseFunDeclarations(Module &M);
+ void useRoundingMode(ConstrainedFPIntrinsic *FPI, IRBuilder<> &B);
+
public:
static char ID;
SPIRVEmitIntrinsics() : ModulePass(ID) {
@@ -1291,6 +1293,37 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
}
}
+static void createDecorationIntrinsic(Instruction *I, MDNode *Node,
+ IRBuilder<> &B) {
+ LLVMContext &Ctx = I->getContext();
+ setInsertPointAfterDef(B, I);
+ B.CreateIntrinsic(Intrinsic::spv_assign_decoration, {I->getType()},
+ {I, MetadataAsValue::get(Ctx, MDNode::get(Ctx, {Node}))});
+}
+
+static void createRoundingModeDecoration(Instruction *I,
+ unsigned RoundingModeDeco,
+ IRBuilder<> &B) {
+ LLVMContext &Ctx = I->getContext();
+ Type *Int32Ty = Type::getInt32Ty(Ctx);
+ MDNode *RoundingModeNode = MDNode::get(
+ Ctx,
+ {ConstantAsMetadata::get(
+ ConstantInt::get(Int32Ty, SPIRV::Decoration::FPRoundingMode)),
+ ConstantAsMetadata::get(ConstantInt::get(Int32Ty, RoundingModeDeco))});
+ createDecorationIntrinsic(I, RoundingModeNode, B);
+}
+
+static void createSaturatedConversionDecoration(Instruction *I,
+ IRBuilder<> &B) {
+ LLVMContext &Ctx = I->getContext();
+ Type *Int32Ty = Type::getInt32Ty(Ctx);
+ MDNode *SaturatedConversionNode =
+ MDNode::get(Ctx, {ConstantAsMetadata::get(ConstantInt::get(
+ Int32Ty, SPIRV::Decoration::SaturatedConversion))});
+ createDecorationIntrinsic(I, SaturatedConversionNode, B);
+}
+
Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
if (!Call.isInlineAsm())
return &Call;
@@ -1312,6 +1345,40 @@ Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
return &Call;
}
+// Use a tip about rounding mode to create a decoration.
+void SPIRVEmitIntrinsics::useRoundingMode(ConstrainedFPIntrinsic *FPI,
+ IRBuilder<> &B) {
+ std::optional<RoundingMode> RM = FPI->getRoundingMode();
+ if (!RM.has_value())
+ return;
+ unsigned RoundingModeDeco = std::numeric_limits<unsigned>::max();
+ switch (RM.value()) {
+ default:
+ // ignore unknown rounding modes
+ break;
+ case RoundingMode::NearestTiesToEven:
+ RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTE;
+ break;
+ case RoundingMode::TowardNegative:
+ RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTN;
+ break;
+ case RoundingMode::TowardPositive:
+ RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTP;
+ break;
+ case RoundingMode::TowardZero:
+ RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTZ;
+ break;
+ case RoundingMode::Dynamic:
+ case RoundingMode::NearestTiesToAway:
+ // TODO: check if supported
+ break;
+ }
+ if (RoundingModeDeco == std::numeric_limits<unsigned>::max())
+ return;
+ // Convert the tip about rounding mode into a decoration record.
+ createRoundingModeDecoration(FPI, RoundingModeDeco, B);
+}
+
Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
BasicBlock *ParentBB = I.getParent();
IRBuilder<> B(ParentBB);
@@ -1809,6 +1876,18 @@ bool SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
return true;
}
+static unsigned roundingModeMDToDecorationConst(StringRef S) {
+ if (S == "rte")
+ return SPIRV::FPRoundingMode::FPRoundingMode::RTE;
+ if (S == "rtz")
+ return SPIRV::FPRoundingMode::FPRoundingMode::RTZ;
+ if (S == "rtp")
+ return SPIRV::FPRoundingMode::FPRoundingMode::RTP;
+ if (S == "rtn")
+ return SPIRV::FPRoundingMode::FPRoundingMode::RTN;
+ return std::numeric_limits<unsigned>::max();
+}
+
void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
IRBuilder<> &B) {
// TODO: extend the list of functions with known result types
@@ -1826,8 +1905,9 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
Function *CalledF = CI->getCalledFunction();
std::string DemangledName =
getOclOrSpirvBuiltinDemangledName(CalledF->getName());
+ std::string Postfix;
if (DemangledName.length() > 0)
- DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName);
+ DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName, &Postfix);
auto ResIt = ResTypeWellKnown.find(DemangledName);
if (ResIt != ResTypeWellKnown.end()) {
IsKnown = true;
@@ -1839,6 +1919,19 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
break;
}
}
+ // check if a floating rounding mode info is present
+ StringRef S = Postfix;
+ SmallVector<StringRef, 8> Parts;
+ S.split(Parts, "_", -1, false);
+ if (Parts.size() > 1) {
+ // Convert the info about rounding mode into a decoration record.
+ unsigned RoundingModeDeco = roundingModeMDToDecorationConst(Parts[1]);
+ if (RoundingModeDeco != std::numeric_limits<unsigned>::max())
+ createRoundingModeDecoration(CI, RoundingModeDeco, B);
+ // Check if the SaturatedConversion info is present.
+ if (Parts[1] == "sat")
+ createSaturatedConversionDecoration(CI, B);
+ }
}
}
@@ -2264,6 +2357,9 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
// already, and force it to be i8 if not
if (Postpone && !GR->findAssignPtrTypeInstr(I))
insertAssignPtrTypeIntrs(I, B, true);
+
+ if (auto *FPI = dyn_cast<ConstrainedFPIntrinsic>(I))
+ useRoundingMode(FPI, B);
}
// Pass backward: use instructions results to specify/update/cast operands
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 5f72a41ddb8647..3e913646d57c80 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -126,14 +126,14 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
Width = adjustOpTypeIntWidth(Width);
const SPIRVSubtarget &ST =
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
- if (ST.canUseExtension(
- SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
- MIRBuilder.buildInstr(SPIRV::OpExtension)
- .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
- MIRBuilder.buildInstr(SPIRV::OpCapability)
- .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
- }
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+ if (ST.canUseExtension(
+ SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
+ MIRBuilder.buildInstr(SPIRV::OpExtension)
+ .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
+ MIRBuilder.buildInstr(SPIRV::OpCapability)
+ .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
+ }
return MIRBuilder.buildInstr(SPIRV::OpTypeInt)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index d95803fea56a58..1bc35c6e57a4f6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -491,16 +491,20 @@ def OpFNegate: UnOpTyped<"OpFNegate", 127, fID, fneg>;
def OpFNegateV: UnOpTyped<"OpFNegate", 127, vfID, fneg>;
defm OpIAdd: BinOpTypedGen<"OpIAdd", 128, add, 0, 1>;
defm OpFAdd: BinOpTypedGen<"OpFAdd", 129, fadd, 1, 1>;
+defm OpStrictFAdd: BinOpTypedGen<"OpFAdd", 129, strict_fadd, 1, 1>;
defm OpISub: BinOpTypedGen<"OpISub", 130, sub, 0, 1>;
defm OpFSub: BinOpTypedGen<"OpFSub", 131, fsub, 1, 1>;
+defm OpStrictFSub: BinOpTypedGen<"OpFSub", 131, strict_fsub, 1, 1>;
defm OpIMul: BinOpTypedGen<"OpIMul", 132, mul, 0, 1>;
defm OpFMul: BinOpTypedGen<"OpFMul", 133, fmul, 1, 1>;
+defm OpStrictFMul: BinOpTypedGen<"OpFMul", 133, strict_fmul, 1, 1>;
defm OpUDiv: BinOpTypedGen<"OpUDiv", 134, udiv, 0, 1>;
defm OpSDiv: BinOpTypedGen<"OpSDiv", 135, sdiv, 0, 1>;
defm OpFDiv: BinOpTypedGen<"OpFDiv", 136, fdiv, 1, 1>;
+defm OpStrictFDiv: BinOpTypedGen<"OpFDiv", 136, strict_fdiv, 1, 1>;
defm OpUMod: BinOpTypedGen<"OpUMod", 137, urem, 0, 1>;
defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 0, 1>;
@@ -508,6 +512,8 @@ defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 0, 1>;
def OpSMod: BinOp<"OpSMod", 139>;
defm OpFRem: BinOpTypedGen<"OpFRem", 140, frem, 1, 1>;
+defm OpStrictFRem: BinOpTypedGen<"OpFRem", 140, strict_frem, 1, 1>;
+
def OpFMod: BinOp<"OpFMod", 141>;
def OpVectorTimesScalar: BinOp<"OpVectorTimesScalar", 142>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b64030508cfc11..856caf2074fba4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -61,6 +61,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
/// We need to keep track of the number we give to anonymous global values to
/// generate the same name every time when this is needed.
mutable DenseMap<const GlobalValue *, unsigned> UnnamedGlobalIDs;
+ SmallPtrSet<MachineInstr *, 8> DeadMIs;
public:
SPIRVInstructionSelector(const SPIRVTargetMachine &TM,
@@ -382,6 +383,24 @@ static bool isImm(const MachineOperand &MO, MachineRegisterInfo *MRI);
// Defined in SPIRVLegalizerInfo.cpp.
extern bool isTypeFoldingSupported(unsigned Opcode);
+bool isDead(const MachineInstr &MI, const MachineRegisterInfo &MRI) {
+ for (const auto &MO : MI.all_defs()) {
+ Register Reg = MO.getReg();
+ if (Reg.isPhysical() || !MRI.use_nodbg_empty(Reg))
+ return false;
+ }
+ if (MI.getOpcode() == TargetOpcode::LOCAL_ESCAPE || MI.isFakeUse() ||
+ MI.isLifetimeMarker())
+ return false;
+ if (MI.isPHI())
+ return true;
+ if (MI.mayStore() || MI.isCall() ||
+ (MI.mayLoad() && MI.hasOrderedMemoryRef()) || MI.isPosition() ||
+ MI.isDebugInstr() || MI.isTerminator() || MI.isJumpTableDebugInfo())
+ return false;
+ return true;
+}
+
bool SPIRVInstructionSelector::select(MachineInstr &I) {
resetVRegsType(*I.getParent()->getParent());
@@ -404,8 +423,11 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
}
});
assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT);
- if (Res)
+ if (Res) {
+ if (!isTriviallyDead(*Def, *MRI) && isDead(*Def, *MRI))
+ DeadMIs.insert(Def);
return Res;
+ }
}
MRI->setRegClass(SrcReg, MRI->getRegClass(DstReg));
MRI->replaceRegWith(SrcReg, DstReg);
@@ -418,6 +440,15 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
return constrainSelectedInstRegOperands(I, TII, TRI, RBI);
}
+ if (DeadMIs.contains(&I)) {
+ // if the instruction has been already made dead by folding it away
+ // erase it
+ LLVM_DEBUG(dbgs() << "Instruction is folded and dead.\n");
+ salvageDebugInfo(*MRI, I);
+ I.eraseFromParent();
+ return true;
+ }
+
if (I.getNumOperands() != I.getNumExplicitOperands()) {
LLVM_DEBUG(errs() << "Generic instr has unexpected implicit operands\n");
return false;
@@ -557,9 +588,13 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_UCMP:
return selectSUCmp(ResVReg, ResType, I, false);
+ case TargetOpcode::G_STRICT_FMA:
case TargetOpcode::G_FMA:
return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);
+ case TargetOpcode::G_STRICT_FLDEXP:
+ return selectExtInst(ResVReg, ResType, I, CL::ldexp);
+
case TargetOpcode::G_FPOW:
return selectExtInst(ResVReg, ResType, I, CL::pow, GL::Pow);
case TargetOpcode::G_FPOWI:
@@ -618,6 +653,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_FTANH:
return selectExtInst(ResVReg, ResType, I, CL::tanh, GL::Tanh);
+ case TargetOpcode::G_STRICT_FSQRT:
case TargetOpcode::G_FSQRT:
return selectExtInst(ResVReg, ResType, I, CL::sqrt, GL::Sqrt);
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 7230e0e6b9fca1..b22027cd2cb931 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -24,19 +24,25 @@ using namespace llvm;
using namespace llvm::LegalizeActions;
using namespace llvm::LegalityPredicates;
+// clang-format off
static const std::set<unsigned> TypeFoldingSupportingOpcs = {
TargetOpcode::G_ADD,
TargetOpcode::G_FADD,
+ TargetOpcode::G_STRICT_FADD,
TargetOpcode::G_SUB,
TargetOpcode::G_FSUB,
+ TargetOpcode::G_STRICT_FSUB,
TargetOpcode::G_MUL,
TargetOpcode::G_FMUL,
+ TargetOpcode::G_STRICT_FMUL,
TargetOpcode::G_SDIV,
TargetOpcode::G_UDIV,
TargetOpcode::G_FDIV,
+ TargetOpcode::G_STRICT_FDIV,
TargetOpcode::G_SREM,
TargetOpcode::G_UREM,
TargetOpcode::G_FREM,
+ TargetOpcode::G_STRICT_FREM,
TargetOpcode::G_FNEG,
TargetOpcode::G_CONSTANT,
TargetOpcode::G_FCONSTANT,
@@ -49,6 +55,7 @@ static const std::set<unsigned> TypeFoldingSupportingOpcs = {
TargetOpcode::G_SELECT,
TargetOpcode::G_EXTRACT_VECTOR_ELT,
};
+// clang-format on
bool isTypeFoldingSupported(unsigned Opcode) {
return TypeFoldingSupportingOpcs.count(Opcode) > 0;
@@ -219,7 +226,11 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
.legalFor(allIntScalarsAndVectors)
.legalIf(extendedScalarsAndVectors);
- getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/119862
More information about the llvm-commits mailing list