[llvm] Add support for SPIR-V extension: SPV_INTEL_function_pointers (PR #80759)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 6 04:49:11 PST 2024
https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/80759
>From 9ed0da2d9f209eb80706aa5ba0c5bfcb9b83d03d Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 5 Feb 2024 14:55:27 -0800
Subject: [PATCH 1/5] add initial support for SPIR-V extension:
SPV_INTEL_function_pointers
---
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 208 +++++++++++++++---
llvm/lib/Target/SPIRV/SPIRVCallLowering.h | 28 +++
.../Target/SPIRV/SPIRVDuplicatesTracker.cpp | 9 +-
.../lib/Target/SPIRV/SPIRVDuplicatesTracker.h | 28 +++
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 19 ++
llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp | 1 +
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 10 +
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 27 +++
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 42 ++++
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h | 2 +
llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp | 5 +-
.../lib/Target/SPIRV/SPIRVSymbolicOperands.td | 6 +
.../SPV_INTEL_function_pointers/fp_const.ll | 34 +++
.../fp_two_calls.ll | 34 +++
14 files changed, 425 insertions(+), 28 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 97b25147ffb34..9fe517a996e91 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -34,6 +34,10 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
const Value *Val, ArrayRef<Register> VRegs,
FunctionLoweringInfo &FLI,
Register SwiftErrorVReg) const {
+ // Maybe run postponed production of OpFunction/OpFunctionParameter's
+ if (FormalArgs.F != nullptr)
+ FormalArgs.produceFunArgsInstructions(MIRBuilder, GR, IndirectCalls);
+
// Currently all return types should use a single register.
// TODO: handle the case of multiple registers.
if (VRegs.size() > 1)
@@ -217,6 +221,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
// Assign types and names to all args, and store their types for later.
FunctionType *FTy = getOriginalFunctionType(F);
SmallVector<SPIRVType *, 4> ArgTypeVRegs;
+ bool HasOpaquePtrArg = false;
if (VRegs.size() > 0) {
unsigned i = 0;
for (const auto &Arg : F.args()) {
@@ -231,6 +236,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
if (Arg.hasName())
buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
if (Arg.getType()->isPointerTy()) {
+ HasOpaquePtrArg = true;
auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
if (DerefBytes != 0)
buildOpDecorate(VRegs[i][0], MIRBuilder,
@@ -292,33 +298,61 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
}
}
- // Generate a SPIR-V type for the function.
+ // If there is support of indirect calls and there are opaque pointer formal
+ // arguments, there is a chance to specify opaque ptr types later (after the
+ // function's body is processed) by information about the indirect call. To
+ // support this case we may postpone generation of some SPIR-V types, and
+ // OpFunction and OpFunctionParameter's. Otherwise we generate all SPIR-V
+ // types related to the function along with instructions.
+ const auto *ST =
+ static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+ bool hasFunctionPointers =
+ ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
+ bool PostponeOpFunction = HasOpaquePtrArg && hasFunctionPointers;
+
auto MRI = MIRBuilder.getMRI();
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
if (F.isDeclaration())
GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
- SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
- FTy, RetTy, ArgTypeVRegs, MIRBuilder);
-
- // Build the OpTypeFunction declaring it.
+ SPIRVType *FuncTy = PostponeOpFunction
+ ? nullptr
+ : GR->getOrCreateOpTypeFunctionWithArgs(
+ FTy, RetTy, ArgTypeVRegs, MIRBuilder);
uint32_t FuncControl = getFunctionControl(F);
- MIRBuilder.buildInstr(SPIRV::OpFunction)
- .addDef(FuncVReg)
- .addUse(GR->getSPIRVTypeID(RetTy))
- .addImm(FuncControl)
- .addUse(GR->getSPIRVTypeID(FuncTy));
+ if (PostponeOpFunction) {
+ FormalArgs.F = &F;
+ FormalArgs.KeepMBB = &(MIRBuilder.getMBB());
+ FormalArgs.KeepInsertPt = MIRBuilder.getInsertPt();
+ FormalArgs.FuncVReg = FuncVReg;
+ FormalArgs.RetTy = RetTy;
+ FormalArgs.FuncControl = FuncControl;
+ FormalArgs.OrigFTy = FTy;
+ FormalArgs.ArgTypeVRegs = ArgTypeVRegs;
+ } else {
+ const MachineInstrBuilder &MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
+ .addDef(FuncVReg)
+ .addUse(GR->getSPIRVTypeID(RetTy))
+ .addImm(FuncControl)
+ .addUse(GR->getSPIRVTypeID(FuncTy));
+ const MachineOperand *DefOpFunction = &MB.getInstr()->getOperand(0);
+ GR->recordFunctionDefinition(&F, DefOpFunction);
+ }
// Add OpFunctionParameters.
int i = 0;
for (const auto &Arg : F.args()) {
assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
- MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
- .addDef(VRegs[i][0])
- .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
+ if (PostponeOpFunction) {
+ FormalArgs.ArgVRegs.push_back(VRegs[i][0]);
+ } else {
+ MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
+ .addDef(VRegs[i][0])
+ .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
+ }
if (F.isDeclaration())
GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
i++;
@@ -343,9 +377,106 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
{static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
}
+ // Handle function pointers decoration
+ if (hasFunctionPointers) {
+ if (F.hasFnAttribute("referenced-indirectly")) {
+ assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
+ "Unexpected 'referenced-indirectly' attribute of the kernel "
+ "function");
+ buildOpDecorate(FuncVReg, MIRBuilder,
+ SPIRV::Decoration::ReferencedIndirectlyINTEL, {});
+ }
+ }
+
return true;
}
+// Emit the OpFunction/OpFunctionParameter instructions whose production was
+// postponed by lowerFormalArguments(). Information collected about indirect
+// calls in the function body is first used to assign SPIR-V types to the
+// callees and call arguments, which may refine the (opaque) pointer types of
+// the parent function's parameters.
+//
+// After emission the postponed state (including IndirectCalls) is reset.
+// lowerReturn() only calls this while F is non-null, so any further call for
+// the same function becomes a no-op instead of emitting a second OpFunction,
+// and ArgVRegs (filled with push_back) does not leak into the next function.
+void SPIRVCallLowering::SPIRVFunFormalArgs::produceFunArgsInstructions(
+    MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR,
+    SmallVector<SPIRVCallLowering::SPIRVIndirectCall> &IndirectCalls) {
+  assert(ArgVRegs.size() == ArgTypeVRegs.size() &&
+         "Mismatched postponed formal argument registers and types");
+  // Save the current insertion point; it is restored before returning.
+  MachineBasicBlock &NextKeepMBB = MIRBuilder.getMBB();
+  MachineBasicBlock::iterator NextKeepInsertPt = MIRBuilder.getInsertPt();
+  // Emit at the insertion point recorded by lowerFormalArguments().
+  MIRBuilder.setInsertPt(*KeepMBB, KeepInsertPt);
+
+  bool IsTypeUpd = false;
+  if (!IndirectCalls.empty()) {
+    // TODO: add a topological sort of IndirectCalls
+    // Create indirect call data types if any
+    MachineFunction &MF = MIRBuilder.getMF();
+    for (auto const &IC : IndirectCalls) {
+      SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder);
+      SmallVector<SPIRVType *, 4> SpirvArgTypes;
+      for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
+        SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder);
+        SpirvArgTypes.push_back(SPIRVTy);
+        // Only assign a type if the argument register has none yet.
+        if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i]))
+          GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF);
+      }
+      // SPIR-V function type:
+      FunctionType *FTy =
+          FunctionType::get(const_cast<Type *>(IC.RetTy), IC.ArgTys, false);
+      SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
+          FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
+      // SPIR-V pointer to function type:
+      SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
+          SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
+      // Correct the callee type: a pointer to the function type built above.
+      GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
+    }
+
+    // Check whether the types of the function parameters became known (or
+    // changed) as a result of the indirect calls analysis.
+    for (size_t i = 0; i < ArgVRegs.size(); ++i) {
+      SPIRVType *ArgTy = GR->getSPIRVTypeForVReg(ArgVRegs[i]);
+      if (ArgTy && ArgTypeVRegs[i] != ArgTy) {
+        ArgTypeVRegs[i] = ArgTy;
+        IsTypeUpd = true;
+      }
+    }
+  }
+
+  // If parameter types were updated, build a new LLVM function type instead
+  // of the stored one.
+  // TODO: (maybe) OrigFTy, allocated in getOriginalFunctionType(F), may be
+  // overwritten and is not used (tracked?) anywhere.
+  FunctionType *UpdateFTy = OrigFTy;
+  if (IsTypeUpd) {
+    SmallVector<Type *, 4> ArgTys;
+    for (SPIRVType *ArgSpirvTy : ArgTypeVRegs) {
+      const Type *Ty = GR->getTypeForSPIRVType(ArgSpirvTy);
+      ArgTys.push_back(const_cast<Type *>(Ty));
+    }
+    // Argument types were specified, we must update function type
+    UpdateFTy = FunctionType::get(F->getReturnType(), ArgTys,
+                                  F->getFunctionType()->isVarArg());
+  }
+  // Create SPIR-V function type
+  SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
+      UpdateFTy, RetTy, ArgTypeVRegs, MIRBuilder);
+
+  // Emit OpFunction
+  const MachineInstrBuilder &MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
+                                      .addDef(FuncVReg)
+                                      .addUse(GR->getSPIRVTypeID(RetTy))
+                                      .addImm(FuncControl)
+                                      .addUse(GR->getSPIRVTypeID(FuncTy));
+  GR->recordFunctionDefinition(F, &MB.getInstr()->getOperand(0));
+
+  // Emit OpFunctionParameter's
+  for (size_t i = 0; i < ArgVRegs.size(); ++i) {
+    MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
+        .addDef(ArgVRegs[i])
+        .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
+  }
+
+  // Restore insertion point
+  MIRBuilder.setInsertPt(NextKeepMBB, NextKeepInsertPt);
+
+  // The postponed instructions are emitted: reset the pending state so that
+  // nothing is emitted twice and nothing carries over to the next function.
+  F = nullptr;
+  KeepMBB = nullptr;
+  RetTy = nullptr;
+  OrigFTy = nullptr;
+  ArgTypeVRegs.clear();
+  ArgVRegs.clear();
+  IndirectCalls.clear();
+}
+
bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
CallLoweringInfo &Info) const {
// Currently call returns should have single vregs.
@@ -356,45 +487,44 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
GR->setCurrentFunc(MF);
FunctionType *FTy = nullptr;
const Function *CF = nullptr;
+ std::string DemangledName;
+ const Type *OrigRetTy = Info.OrigRet.Ty;
// Emit a regular OpFunctionCall. If it's an externally declared function,
// be sure to emit its type and function declaration here. It will be hoisted
// globally later.
if (Info.Callee.isGlobal()) {
+ std::string FuncName = Info.Callee.getGlobal()->getName().str();
+ DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
// TODO: support constexpr casts and indirect calls.
if (CF == nullptr)
return false;
- FTy = getOriginalFunctionType(*CF);
+ if ((FTy = getOriginalFunctionType(*CF)) != nullptr)
+ OrigRetTy = FTy->getReturnType();
}
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
Register ResVReg =
Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
- std::string FuncName = Info.Callee.getGlobal()->getName().str();
- std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
// TODO: check that it's OCL builtin, then apply OpenCL_std.
if (!DemangledName.empty() && CF && CF->isDeclaration() &&
ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
- const Type *OrigRetTy = Info.OrigRet.Ty;
- if (FTy)
- OrigRetTy = FTy->getReturnType();
SmallVector<Register, 8> ArgVRegs;
for (auto Arg : Info.OrigArgs) {
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
ArgVRegs.push_back(Arg.Regs[0]);
SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
- GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF());
+ GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
}
if (auto Res = SPIRV::lowerBuiltin(
DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
ResVReg, OrigRetTy, ArgVRegs, GR))
return *Res;
}
- if (CF && CF->isDeclaration() &&
- !GR->find(CF, &MIRBuilder.getMF()).isValid()) {
+ if (CF && CF->isDeclaration() && !GR->find(CF, &MF).isValid()) {
// Emit the type info and forward function declaration to the first MBB
// to ensure VReg definition dependencies are valid across all MBBs.
MachineIRBuilder FirstBlockBuilder;
@@ -416,14 +546,40 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
}
+ unsigned CallOp;
+ if (Info.CB->isIndirectCall()) {
+ if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers))
+ report_fatal_error("An indirect call is encountered but SPIR-V without "
+ "extensions does not support it",
+ false);
+ // Set instruction operation according to SPV_INTEL_function_pointers
+ CallOp = SPIRV::OpFunctionPointerCallINTEL;
+ // Collect information about the indirect call to support possible
+ // specification of opaque ptr types of parent function's parameters
+ Register CalleeReg = Info.Callee.getReg();
+ if (CalleeReg.isValid()) {
+ SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
+ IndirectCall.Callee = CalleeReg;
+ IndirectCall.RetTy = OrigRetTy;
+ for (const auto &Arg : Info.OrigArgs) {
+ assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
+ IndirectCall.ArgTys.push_back(Arg.Ty);
+ IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
+ }
+ IndirectCalls.push_back(IndirectCall);
+ }
+ } else {
+ // Emit a regular OpFunctionCall
+ CallOp = SPIRV::OpFunctionCall;
+ }
+
// Make sure there's a valid return reg, even for functions returning void.
if (!ResVReg.isValid())
ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
- SPIRVType *RetType =
- GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);
+ SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);
- // Emit the OpFunctionCall and its args.
- auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
+ // Emit the call instruction and its args.
+ auto MIB = MIRBuilder.buildInstr(CallOp)
.addDef(ResVReg)
.addUse(GR->getSPIRVTypeID(RetType))
.add(Info.Callee);
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.h b/llvm/lib/Target/SPIRV/SPIRVCallLowering.h
index c2d6ad82d507d..680db7ca7b1be 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.h
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.h
@@ -26,6 +26,34 @@ class SPIRVCallLowering : public CallLowering {
// Used to create and assign function, argument, and return type information.
SPIRVGlobalRegistry *GR;
+  // Information about one indirect call, collected while lowering the
+  // function body. It is used when the postponed OpFunction and
+  // OpFunctionParameter instructions are finally produced, to specify the
+  // types of the callee and of the call arguments (and, through them, of the
+  // parent function's opaque pointer parameters).
+  struct SPIRVIndirectCall {
+    const Type *RetTy = nullptr;
+    SmallVector<Type *> ArgTys;
+    SmallVector<Register> ArgRegs;
+    Register Callee;
+  };
+  // Operands of OpFunction/OpFunctionParameter whose production was
+  // postponed by lowerFormalArguments(), plus the insertion point where they
+  // must be emitted. F is non-null while emission is pending.
+  struct SPIRVFunFormalArgs {
+    const Function *F = nullptr;
+    // The insertion point.
+    MachineBasicBlock *KeepMBB = nullptr;
+    MachineBasicBlock::iterator KeepInsertPt;
+    // OpFunction and OpFunctionParameter operands.
+    Register FuncVReg;
+    SPIRVType *RetTy = nullptr;
+    // Initialized like the other members so that a default-constructed
+    // object never holds an indeterminate value.
+    uint32_t FuncControl = 0;
+    FunctionType *OrigFTy = nullptr;
+    SmallVector<SPIRVType *, 4> ArgTypeVRegs;
+    SmallVector<Register> ArgVRegs;
+
+    void produceFunArgsInstructions(
+        MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR,
+        SmallVector<SPIRVCallLowering::SPIRVIndirectCall> &IndirectCalls);
+  };
+  // Postponed-emission state for the function being lowered, and the
+  // indirect calls collected from its body. mutable because they are filled
+  // from const lowering hooks such as lowerCall() and lowerReturn().
+  mutable SPIRVFunFormalArgs FormalArgs;
+  mutable SmallVector<SPIRVIndirectCall> IndirectCalls;
+
public:
SPIRVCallLowering(const SPIRVTargetLowering &TLI, SPIRVGlobalRegistry *GR);
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
index cbe1a53fd7568..d82fb2df4539a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
@@ -54,7 +54,14 @@ void SPIRVGeneralDuplicatesTracker::buildDepsGraph(
MachineOperand &Op = MI->getOperand(i);
if (!Op.isReg())
continue;
- MachineOperand *RegOp = &MRI.getVRegDef(Op.getReg())->getOperand(0);
+ MachineInstr *VRegDef = MRI.getVRegDef(Op.getReg());
+ // References to a function via function pointers generate virtual
+ // registers without a definition. We are able to resolve this
+ // reference using Globar Register info into an OpFunction instruction
+ // but do not expect to find it in Reg2Entry.
+ if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL && i == 2)
+ continue;
+ MachineOperand *RegOp = &VRegDef->getOperand(0);
assert((MI->getOpcode() == SPIRV::OpVariable && i == 3) ||
Reg2Entry.count(RegOp));
if (Reg2Entry.count(RegOp))
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index 96cc621791e97..65bc651116962 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -264,6 +264,11 @@ class SPIRVGeneralDuplicatesTracker {
SPIRVDuplicatesTracker<Argument> AT;
SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
+ // map a Function to its definition (as a machine instruction operand)
+ DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
+ // map function pointer (as a machine instruction operand) to the used Function
+ DenseMap<const MachineOperand *, const Function *> InstrToFunction;
+
// NOTE: using MOs instead of regs to get rid of MF dependency to be able
// to use flat data structure.
// NOTE: replacing DenseMap with MapVector doesn't affect overall correctness
@@ -280,6 +285,29 @@ class SPIRVGeneralDuplicatesTracker {
void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph,
MachineModuleInfo *MMI);
+  // Map a machine operand that represents a use of a function via a function
+  // pointer to the machine operand that represents the function definition.
+  // Return nullptr if either mapping is missing: there is no context here for
+  // a good diagnostic message, so handling a missing reference is left to the
+  // caller.
+  const MachineOperand *
+  getFunctionDefinitionByUse(const MachineOperand *Use) const {
+    auto ResF = InstrToFunction.find(Use);
+    if (ResF == InstrToFunction.end())
+      return nullptr;
+    auto ResReg = FunctionToInstr.find(ResF->second);
+    return ResReg == FunctionToInstr.end() ? nullptr : ResReg->second;
+  }
+  // Map a function pointer (as a machine instruction operand) to the used
+  // Function.
+  void recordFunctionPointer(const MachineOperand *MO, const Function *F) {
+    InstrToFunction[MO] = F;
+  }
+  // Map a Function to its definition (as a machine instruction operand).
+  void recordFunctionDefinition(const Function *F, const MachineOperand *MO) {
+    FunctionToInstr[F] = MO;
+  }
+  // Return true if any function pointer use was recorded, i.e. any
+  // OpConstantFunctionPointerINTEL was generated.
+  bool hasConstFunPtr() const { return !InstrToFunction.empty(); }
+
void add(const Type *Ty, const MachineFunction *MF, Register R) {
TT.add(Ty, MF, R);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index f3280928c25df..3ccc224ad8a4c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -101,6 +101,25 @@ class SPIRVGlobalRegistry {
DT.buildDepsGraph(Graph, MMI);
}
+ // Map a function pointer (as a machine instruction operand) to the used
+ // Function. Forwards to the duplicates tracker.
+ void recordFunctionPointer(const MachineOperand *MO, const Function *F) {
+ DT.recordFunctionPointer(MO, F);
+ }
+ // Map a Function to its definition (as a machine instruction operand).
+ // Forwards to the duplicates tracker.
+ void recordFunctionDefinition(const Function *F, const MachineOperand *MO) {
+ DT.recordFunctionDefinition(F, MO);
+ }
+ // Map a machine operand that represents a use of a function via function
+ // pointer to a machine operand that represents the function definition.
+ // Return nullptr when no definition is known, because we have no context for
+ // a good diagnostic message in case of unexpectedly missing references.
+ const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) {
+ return DT.getFunctionDefinitionByUse(Use);
+ }
+ // Return true if any OpConstantFunctionPointerINTEL were generated
+ // (i.e. any function pointer was recorded).
+ bool hasConstFunPtr() { return DT.hasConstFunPtr(); }
+
// Get or create a SPIR-V type corresponding the given LLVM IR type,
// and map it to the given VReg by creating an ASSIGN_TYPE instruction.
SPIRVType *assignTypeToVReg(const Type *Type, Register VReg,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
index 42317453a2370..e3f76419f1313 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
@@ -40,6 +40,7 @@ bool SPIRVInstrInfo::isConstantInstr(const MachineInstr &MI) const {
case SPIRV::OpSpecConstantComposite:
case SPIRV::OpSpecConstantOp:
case SPIRV::OpUndef:
+ case SPIRV::OpConstantFunctionPointerINTEL:
return true;
default:
return false;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index da033ba32624c..3683fe9ec1648 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -761,3 +761,13 @@ def OpGroupNonUniformBitwiseXor: OpGroupNUGroup<"BitwiseXor", 361>;
def OpGroupNonUniformLogicalAnd: OpGroupNUGroup<"LogicalAnd", 362>;
def OpGroupNonUniformLogicalOr: OpGroupNUGroup<"LogicalOr", 363>;
def OpGroupNonUniformLogicalXor: OpGroupNUGroup<"LogicalXor", 364>;
+
+// 3.49.7. Constant-Creation Instructions
+
+// - SPV_INTEL_function_pointers
+// Constant pointer of type $ty to the function $fun. The $fun operand is
+// resolved to the OpFunction's global register during module analysis.
+def OpConstantFunctionPointerINTEL: Op<5600, (outs ID:$res), (ins TYPE:$ty, ID:$fun), "$res = OpConstantFunctionPointerINTEL $ty $fun">;
+
+// 3.49.9. Function Instructions
+
+// - SPV_INTEL_function_pointers
+// Call through the function pointer $funPtr; $ty is the result type and the
+// call arguments follow as variable operands.
+def OpFunctionPointerCallINTEL: Op<5601, (outs ID:$res), (ins TYPE:$ty, ID:$funPtr, variable_ops), "$res = OpFunctionPointerCallINTEL $ty $funPtr">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 8c1dfc5e626db..f7cf3e1936e33 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1534,6 +1534,12 @@ bool SPIRVInstructionSelector::selectGlobalValue(
GlobalIdent = GV->getGlobalIdentifier();
}
+ // Behaviour of functions as operands depends on availability of the
+ // corresponding extension (SPV_INTEL_function_pointers):
+ // - If there is an extension to operate with functions as operands:
+ // We create a proper constant operand and evaluate a correct type for a
+ // function pointer.
+ // - Without the required extension:
// We have functions as operands in tests with blocks of instruction e.g. in
// transcoding/global_block.ll. These operands are not used and should be
// substituted by zero constants. Their type is expected to be always
@@ -1545,6 +1551,27 @@ bool SPIRVInstructionSelector::selectGlobalValue(
if (!NewReg.isValid()) {
Register NewReg = ResVReg;
GR.add(ConstVal, GR.CurMF, NewReg);
+ const Function *GVFun =
+ STI.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)
+ ? dyn_cast<Function>(GV)
+ : nullptr;
+ if (GVFun) {
+ // References to a function via function pointers generate virtual
+ // registers without a definition. We will resolve it later, during
+ // module analysis stage.
+ MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+ Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+ MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
+ const MachineInstrBuilder &MB =
+ BuildMI(BB, I, I.getDebugLoc(),
+ TII.get(SPIRV::OpConstantFunctionPointerINTEL))
+ .addDef(NewReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(FuncVReg);
+ // mapping the function pointer to the used Function
+ GR.recordFunctionPointer(&MB.getInstr()->getOperand(2), GVFun);
+ return MB.constrainAllUses(TII, TRI, RBI);
+ }
return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
.addDef(NewReg)
.addUse(GR.getSPIRVTypeID(ResType))
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 370da046984f9..6f648c8aa3e88 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -291,6 +291,32 @@ void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
}
}
+// A function used as a value is emitted as OpConstantFunctionPointerINTEL whose
+// function operand is a virtual register with no defining instruction. Walk the
+// type/constant/variable section and resolve every such operand to the global
+// register of the referenced OpFunction, using the Global Registry mapping.
+void SPIRVModuleAnalysis::collectFuncPtrs() {
+  for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars]) {
+    if (MI->getOpcode() != SPIRV::OpConstantFunctionPointerINTEL)
+      continue;
+    collectFuncPtrs(MI);
+  }
+}
+
+// Alias the function operand (operand #2) of one
+// OpConstantFunctionPointerINTEL to the global register of the OpFunction it
+// refers to. Operands without a recorded definition are left untouched.
+void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) {
+  const MachineOperand *FnUse = &MI->getOperand(2);
+  const MachineOperand *FnDef = GR->getFunctionDefinitionByUse(FnUse);
+  if (!FnDef)
+    return;
+
+  const MachineInstr *FnDefMI = FnDef->getParent();
+  assert(FnDefMI->getOpcode() == SPIRV::OpFunction &&
+         "Constant function pointer must refer to function definition");
+  Register GlobalReg = MAI.getRegisterAlias(FnDefMI->getMF(), FnDef->getReg());
+  assert(GlobalReg.isValid() &&
+         "Function definition must refer to a global register");
+  MAI.setRegisterAlias(MI->getMF(), FnUse->getReg(), GlobalReg);
+}
+
using InstrSignature = SmallVector<size_t>;
using InstrTraces = std::set<InstrSignature>;
@@ -915,6 +941,18 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
}
break;
+ case SPIRV::OpConstantFunctionPointerINTEL:
+ if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
+ Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
+ }
+ break;
+ case SPIRV::OpFunctionPointerCallINTEL:
+ if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
+ Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
+ }
+ break;
default:
break;
}
@@ -1073,6 +1111,10 @@ bool SPIRVModuleAnalysis::runOnModule(Module &M) {
// Number rest of registers from N+1 onwards.
numberRegistersGlobally(M);
+ // Update references to OpFunction instructions to use Global Registers
+ if (GR->hasConstFunPtr())
+ collectFuncPtrs();
+
// Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
processOtherInstrs(M);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h
index d0b8027edd420..b05526b06e7da 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h
@@ -224,6 +224,8 @@ struct SPIRVModuleAnalysis : public ModulePass {
void collectFuncNames(MachineInstr &MI, const Function *F);
void processOtherInstrs(const Module &M);
void numberRegistersGlobally(const Module &M);
+ void collectFuncPtrs();
+ void collectFuncPtrs(MachineInstr *MI);
const SPIRVSubtarget *ST;
SPIRVGlobalRegistry *GR;
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index cf6dfb127cdeb..5974e37201fad 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -48,7 +48,10 @@ cl::list<SPIRV::Extension::Extension> Extensions(
clEnumValN(SPIRV::Extension::SPV_KHR_bit_instructions,
"SPV_KHR_bit_instructions",
"This enables bit instructions to be used by SPIR-V modules "
- "without requiring the Shader capability")));
+ "without requiring the Shader capability"),
+ clEnumValN(SPIRV::Extension::SPV_INTEL_function_pointers,
+ "SPV_INTEL_function_pointers",
+ "Allows translation of function pointers")));
// Compare version numbers, but allow 0 to mean unspecified.
static bool isAtLeastVer(uint32_t Target, uint32_t VerToCompareTo) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index ac92ee4a0756a..674e432203eed 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -295,6 +295,7 @@ defm SPV_INTEL_usm_storage_classes : ExtensionOperand<100>;
defm SPV_INTEL_fpga_latency_control : ExtensionOperand<101>;
defm SPV_INTEL_fpga_argument_interfaces : ExtensionOperand<102>;
defm SPV_INTEL_optnone : ExtensionOperand<103>;
+defm SPV_INTEL_function_pointers : ExtensionOperand<104>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -452,6 +453,8 @@ defm ArbitraryPrecisionIntegersINTEL : CapabilityOperand<5844, 0, 0, [SPV_INTEL_
defm OptNoneINTEL : CapabilityOperand<6094, 0, 0, [SPV_INTEL_optnone], []>;
defm BitInstructions : CapabilityOperand<6025, 0, 0, [SPV_KHR_bit_instructions], []>;
defm ExpectAssumeKHR : CapabilityOperand<5629, 0, 0, [SPV_KHR_expect_assume], []>;
+defm FunctionPointersINTEL : CapabilityOperand<5603, 0, 0, [SPV_INTEL_function_pointers], []>;
+defm IndirectReferencesINTEL : CapabilityOperand<5604, 0, 0, [SPV_INTEL_function_pointers], []>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
@@ -688,6 +691,7 @@ defm HitAttributeNV : StorageClassOperand<5339, [RayTracingNV]>;
defm IncomingRayPayloadNV : StorageClassOperand<5342, [RayTracingNV]>;
defm ShaderRecordBufferNV : StorageClassOperand<5343, [RayTracingNV]>;
defm PhysicalStorageBufferEXT : StorageClassOperand<5349, [PhysicalStorageBufferAddressesEXT]>;
+defm CodeSectionINTEL : StorageClassOperand<5605, [FunctionPointersINTEL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Dim enum values and at the same time
@@ -1179,6 +1183,8 @@ defm CountBuffer : DecorationOperand<5634, 0, 0, [], []>;
defm UserSemantic : DecorationOperand<5635, 0, 0, [], []>;
defm RestrictPointerEXT : DecorationOperand<5355, 0, 0, [], [PhysicalStorageBufferAddressesEXT]>;
defm AliasedPointerEXT : DecorationOperand<5356, 0, 0, [], [PhysicalStorageBufferAddressesEXT]>;
+defm ReferencedIndirectlyINTEL : DecorationOperand<5602, 0, 0, [], [IndirectReferencesINTEL]>;
+defm ArgumentAttributeINTEL : DecorationOperand<6409, 0, 0, [], [FunctionPointersINTEL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define BuiltIn enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll
new file mode 100644
index 0000000000000..0bd1b5d776a94
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll
@@ -0,0 +1,34 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_function_pointers %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpCapability Int8
+; CHECK-DAG: OpCapability FunctionPointersINTEL
+; CHECK-DAG: OpCapability Int64
+; CHECK: OpExtension "SPV_INTEL_function_pointers"
+; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
+; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid
+; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0
+; CHECK-DAG: %[[TyFunFp:.*]] = OpTypeFunction %[[TyVoid]] %[[TyInt64]]
+; CHECK-DAG: %[[ConstInt64:.*]] = OpConstant %[[TyInt64]] 42
+; CHECK-DAG: %[[TyPtrFunFp:.*]] = OpTypePointer Function %[[TyFunFp]]
+; CHECK-DAG: %[[ConstFunFp:.*]] = OpConstantFunctionPointerINTEL %[[TyPtrFunFp]] %[[DefFunFp:.*]]
+; CHECK: %[[FunPtr1:.*]] = OpBitcast %[[#]] %[[ConstFunFp]]
+; CHECK: %[[FunPtr2:.*]] = OpLoad %[[#]] %[[FunPtr1]]
+; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[FunPtr2]] %[[ConstInt64]]
+; CHECK: OpReturn
+; CHECK: OpFunctionEnd
+; CHECK: %[[DefFunFp]] = OpFunction %[[TyVoid]] None %[[TyFunFp]]
+
+target triple = "spir64-unknown-unknown"
+
+; @test uses the address of @foo as a value and calls through a pointer loaded
+; from it. With SPV_INTEL_function_pointers enabled (see the RUN line), the
+; patterns above expect the use of @foo to become
+; OpConstantFunctionPointerINTEL and the call to become
+; OpFunctionPointerCallINTEL.
+; NOTE(review): the call site returns i64 while @foo returns void; the patterns
+; follow the call-site type (TyInt64) -- confirm this mismatch is intended.
+define spir_kernel void @test() {
+entry:
+ %0 = load ptr, ptr @foo
+ %1 = call i64 %0(i64 42)
+ ret void
+}
+
+define void @foo(i64 %a) {
+entry:
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
new file mode 100644
index 0000000000000..33f176b7325d5
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
@@ -0,0 +1,34 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_function_pointers %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpCapability Int8
+; CHECK-DAG: OpCapability FunctionPointersINTEL
+; CHECK-DAG: OpCapability Int64
+; CHECK: OpExtension "SPV_INTEL_function_pointers"
+; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
+; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid
+; CHECK-DAG: %[[TyFloat32:.*]] = OpTypeFloat 32
+; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0
+; CHECK-DAG: %[[TyPtrInt8:.*]] = OpTypePointer Function %[[TyInt8]]
+; CHECK-DAG: %[[TyFunFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrInt8]]
+; CHECK-DAG: %[[TyFunBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrInt8]] %[[TyPtrInt8]]
+; CHECK-DAG: %[[TyPtrFunFp:.*]] = OpTypePointer Function %[[TyFunFp]]
+; CHECK-DAG: %[[TyPtrFunBar:.*]] = OpTypePointer Function %[[TyFunBar]]
+; CHECK: %[[TyFunTest:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrFunFp]] %[[TyPtrInt8]] %[[TyPtrFunBar]]
+; CHECK: %[[FunTest:.*]] = OpFunction %[[TyVoid]] None %[[TyFunTest]]
+; CHECK: %[[ArgFp:.*]] = OpFunctionParameter %[[TyPtrFunFp]]
+; CHECK: %[[ArgData:.*]] = OpFunctionParameter %[[TyPtrInt8]]
+; CHECK: %[[ArgBar:.*]] = OpFunctionParameter %[[TyPtrFunBar]]
+; CHECK: OpFunctionPointerCallINTEL %[[TyFloat32]] %[[ArgFp]] %[[ArgBar]]
+; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[ArgBar]] %[[ArgFp]] %[[ArgData]]
+; CHECK: OpReturn
+; CHECK: OpFunctionEnd
+
+target triple = "spir64-unknown-unknown"
+
+define spir_kernel void @test(ptr %fp, ptr %data, ptr %bar) {
+entry:
+ %0 = call spir_func float %fp(ptr %bar)
+ %1 = call spir_func i64 %bar(ptr %fp, ptr %data)
+ ret void
+}
>From 3fb82d2a1f6de390f8b379301fbb91b166934ad0 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 5 Feb 2024 15:20:11 -0800
Subject: [PATCH 2/5] apply clang-format
---
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 4 ++--
llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h | 5 +++--
2 files changed, 5 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 9fe517a996e91..a31180f847f2b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -381,8 +381,8 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
if (hasFunctionPointers) {
if (F.hasFnAttribute("referenced-indirectly")) {
assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
- "Unexpected 'referenced-indirectly' attribute of the kernel "
- "function");
+ "Unexpected 'referenced-indirectly' attribute of the kernel "
+ "function");
buildOpDecorate(FuncVReg, MIRBuilder,
SPIRV::Decoration::ReferencedIndirectlyINTEL, {});
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index 65bc651116962..dabd73367cd6a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -266,7 +266,8 @@ class SPIRVGeneralDuplicatesTracker {
// map a Function to its definition (as a machine instruction operand)
DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
- // map function pointer (as a machine instruction operand) to the used Function
+ // map function pointer (as a machine instruction operand) to the used
+ // Function
DenseMap<const MachineOperand *, const Function *> InstrToFunction;
// NOTE: using MOs instead of regs to get rid of MF dependency to be able
@@ -289,7 +290,7 @@ class SPIRVGeneralDuplicatesTracker {
// pointer to a machine operand that represents the function definition.
// Return either the register or invalid value, because we have no context for
// a good diagnostic message in case of unexpectedly missing references.
- const MachineOperand* getFunctionDefinitionByUse(const MachineOperand *Use) {
+ const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) {
auto ResF = InstrToFunction.find(Use);
if (ResF == InstrToFunction.end())
return nullptr;
>From 43e587791406b83585d761fc81d6e6a7aee794a8 Mon Sep 17 00:00:00 2001
From: paperchalice <liujunchang97 at outlook.com>
Date: Tue, 6 Feb 2024 17:56:56 +0800
Subject: [PATCH 3/5] [CodeGen] Port DeadMachineInstructionElim to new pass
manager (#80582)
A simple enough pass that we can use it to test standard instrumentations
in the future.
---
.../llvm/CodeGen/DeadMachineInstructionElim.h | 25 +++++++
llvm/include/llvm/Passes/CodeGenPassBuilder.h | 1 +
.../llvm/Passes/MachinePassRegistry.def | 2 +-
.../CodeGen/DeadMachineInstructionElim.cpp | 66 ++++++++++++-------
llvm/lib/Passes/PassBuilder.cpp | 1 +
llvm/test/CodeGen/AArch64/elim-dead-mi.mir | 1 +
6 files changed, 71 insertions(+), 25 deletions(-)
create mode 100644 llvm/include/llvm/CodeGen/DeadMachineInstructionElim.h
diff --git a/llvm/include/llvm/CodeGen/DeadMachineInstructionElim.h b/llvm/include/llvm/CodeGen/DeadMachineInstructionElim.h
new file mode 100644
index 0000000000000..b9fe7cfccf9a3
--- /dev/null
+++ b/llvm/include/llvm/CodeGen/DeadMachineInstructionElim.h
@@ -0,0 +1,25 @@
+//===- llvm/CodeGen/DeadMachineInstructionElim.h ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_DEADMACHINEINSTRUCTIONELIM_H
+#define LLVM_CODEGEN_DEADMACHINEINSTRUCTIONELIM_H
+
+#include "llvm/CodeGen/MachinePassManager.h"
+
+namespace llvm {
+
+class DeadMachineInstructionElimPass
+ : public MachinePassInfoMixin<DeadMachineInstructionElimPass> {
+public:
+ PreservedAnalyses run(MachineFunction &MF,
+ MachineFunctionAnalysisManager &MFAM);
+};
+
+} // namespace llvm
+
+#endif // LLVM_CODEGEN_DEADMACHINEINSTRUCTIONELIM_H
diff --git a/llvm/include/llvm/Passes/CodeGenPassBuilder.h b/llvm/include/llvm/Passes/CodeGenPassBuilder.h
index 40cc0c046531a..fa6dbd4a49730 100644
--- a/llvm/include/llvm/Passes/CodeGenPassBuilder.h
+++ b/llvm/include/llvm/Passes/CodeGenPassBuilder.h
@@ -26,6 +26,7 @@
#include "llvm/CodeGen/AssignmentTrackingAnalysis.h"
#include "llvm/CodeGen/CallBrPrepare.h"
#include "llvm/CodeGen/CodeGenPrepare.h"
+#include "llvm/CodeGen/DeadMachineInstructionElim.h"
#include "llvm/CodeGen/DwarfEHPrepare.h"
#include "llvm/CodeGen/ExpandMemCmp.h"
#include "llvm/CodeGen/ExpandReductions.h"
diff --git a/llvm/include/llvm/Passes/MachinePassRegistry.def b/llvm/include/llvm/Passes/MachinePassRegistry.def
index 5c3d2659fdfb7..d8972080beeb0 100644
--- a/llvm/include/llvm/Passes/MachinePassRegistry.def
+++ b/llvm/include/llvm/Passes/MachinePassRegistry.def
@@ -123,6 +123,7 @@ MACHINE_FUNCTION_ANALYSIS("pass-instrumentation", PassInstrumentationAnalysis(PI
#ifndef MACHINE_FUNCTION_PASS
#define MACHINE_FUNCTION_PASS(NAME, CREATE_PASS)
#endif
+MACHINE_FUNCTION_PASS("dead-mi-elimination", DeadMachineInstructionElimPass())
// MACHINE_FUNCTION_PASS("free-machine-function", FreeMachineFunctionPass())
MACHINE_FUNCTION_PASS("no-op-machine-function", NoOpMachineFunctionPass())
MACHINE_FUNCTION_PASS("print", PrintMIRPass())
@@ -160,7 +161,6 @@ DUMMY_MACHINE_FUNCTION_PASS("break-false-deps", BreakFalseDepsPass)
DUMMY_MACHINE_FUNCTION_PASS("cfguard-longjmp", CFGuardLongjmpPass)
DUMMY_MACHINE_FUNCTION_PASS("cfi-fixup", CFIFixupPass)
DUMMY_MACHINE_FUNCTION_PASS("cfi-instr-inserter", CFIInstrInserterPass)
-DUMMY_MACHINE_FUNCTION_PASS("dead-mi-elimination", DeadMachineInstructionElimPass)
DUMMY_MACHINE_FUNCTION_PASS("detect-dead-lanes", DetectDeadLanesPass)
DUMMY_MACHINE_FUNCTION_PASS("dot-machine-cfg", MachineCFGPrinter)
DUMMY_MACHINE_FUNCTION_PASS("early-ifcvt", EarlyIfConverterPass)
diff --git a/llvm/lib/CodeGen/DeadMachineInstructionElim.cpp b/llvm/lib/CodeGen/DeadMachineInstructionElim.cpp
index 6a7de3b241fee..facc01452d2f1 100644
--- a/llvm/lib/CodeGen/DeadMachineInstructionElim.cpp
+++ b/llvm/lib/CodeGen/DeadMachineInstructionElim.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/CodeGen/DeadMachineInstructionElim.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/LiveRegUnits.h"
@@ -28,37 +29,57 @@ using namespace llvm;
STATISTIC(NumDeletes, "Number of dead instructions deleted");
namespace {
- class DeadMachineInstructionElim : public MachineFunctionPass {
- bool runOnMachineFunction(MachineFunction &MF) override;
+class DeadMachineInstructionElimImpl {
+ const MachineRegisterInfo *MRI = nullptr;
+ const TargetInstrInfo *TII = nullptr;
+ LiveRegUnits LivePhysRegs;
- const MachineRegisterInfo *MRI = nullptr;
- const TargetInstrInfo *TII = nullptr;
- LiveRegUnits LivePhysRegs;
+public:
+ bool runImpl(MachineFunction &MF);
- public:
- static char ID; // Pass identification, replacement for typeid
- DeadMachineInstructionElim() : MachineFunctionPass(ID) {
- initializeDeadMachineInstructionElimPass(*PassRegistry::getPassRegistry());
- }
+private:
+ bool isDead(const MachineInstr *MI) const;
+ bool eliminateDeadMI(MachineFunction &MF);
+};
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- MachineFunctionPass::getAnalysisUsage(AU);
- }
+class DeadMachineInstructionElim : public MachineFunctionPass {
+public:
+ static char ID; // Pass identification, replacement for typeid
- private:
- bool isDead(const MachineInstr *MI) const;
+ DeadMachineInstructionElim() : MachineFunctionPass(ID) {
+ initializeDeadMachineInstructionElimPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnMachineFunction(MachineFunction &MF) override {
+ if (skipFunction(MF.getFunction()))
+ return false;
+ return DeadMachineInstructionElimImpl().runImpl(MF);
+ }
- bool eliminateDeadMI(MachineFunction &MF);
- };
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesCFG();
+ MachineFunctionPass::getAnalysisUsage(AU);
+ }
+};
+} // namespace
+
+PreservedAnalyses
+DeadMachineInstructionElimPass::run(MachineFunction &MF,
+ MachineFunctionAnalysisManager &) {
+ if (!DeadMachineInstructionElimImpl().runImpl(MF))
+ return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
+
char DeadMachineInstructionElim::ID = 0;
char &llvm::DeadMachineInstructionElimID = DeadMachineInstructionElim::ID;
INITIALIZE_PASS(DeadMachineInstructionElim, DEBUG_TYPE,
"Remove dead machine instructions", false, false)
-bool DeadMachineInstructionElim::isDead(const MachineInstr *MI) const {
+bool DeadMachineInstructionElimImpl::isDead(const MachineInstr *MI) const {
// Technically speaking inline asm without side effects and no defs can still
// be deleted. But there is so much bad inline asm code out there, we should
// let them be.
@@ -102,10 +123,7 @@ bool DeadMachineInstructionElim::isDead(const MachineInstr *MI) const {
return true;
}
-bool DeadMachineInstructionElim::runOnMachineFunction(MachineFunction &MF) {
- if (skipFunction(MF.getFunction()))
- return false;
-
+bool DeadMachineInstructionElimImpl::runImpl(MachineFunction &MF) {
MRI = &MF.getRegInfo();
const TargetSubtargetInfo &ST = MF.getSubtarget();
@@ -118,7 +136,7 @@ bool DeadMachineInstructionElim::runOnMachineFunction(MachineFunction &MF) {
return AnyChanges;
}
-bool DeadMachineInstructionElim::eliminateDeadMI(MachineFunction &MF) {
+bool DeadMachineInstructionElimImpl::eliminateDeadMI(MachineFunction &MF) {
bool AnyChanges = false;
// Loop over all instructions in all blocks, from bottom to top, so that it's
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 89947711d4bfe..7c306c4a21daf 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -76,6 +76,7 @@
#include "llvm/CodeGen/BasicBlockSectionsProfileReader.h"
#include "llvm/CodeGen/CallBrPrepare.h"
#include "llvm/CodeGen/CodeGenPrepare.h"
+#include "llvm/CodeGen/DeadMachineInstructionElim.h"
#include "llvm/CodeGen/DwarfEHPrepare.h"
#include "llvm/CodeGen/ExpandLargeDivRem.h"
#include "llvm/CodeGen/ExpandLargeFpConvert.h"
diff --git a/llvm/test/CodeGen/AArch64/elim-dead-mi.mir b/llvm/test/CodeGen/AArch64/elim-dead-mi.mir
index 0542b46f2e393..9612f3269f162 100644
--- a/llvm/test/CodeGen/AArch64/elim-dead-mi.mir
+++ b/llvm/test/CodeGen/AArch64/elim-dead-mi.mir
@@ -1,5 +1,6 @@
# RUN: llc -mtriple=aarch64 -o - %s \
# RUN: -run-pass dead-mi-elimination | FileCheck %s
+# RUN: llc -mtriple=aarch64 -o - %s -p dead-mi-elimination | FileCheck %s
--- |
@c = internal unnamed_addr global [3 x i8] zeroinitializer, align 4
@d = common dso_local local_unnamed_addr global i32 0, align 4
>From d87b481d050c1d53efc3e842376ca206bcd86d70 Mon Sep 17 00:00:00 2001
From: AtariDreams <83477269+AtariDreams at users.noreply.github.com>
Date: Tue, 6 Feb 2024 05:00:35 -0500
Subject: [PATCH 4/5] [Transforms] Expand optimizeTan to fold more inverse trig
pairs (#77799)
optimizeTan has been renamed to optimizeTrigInversionPairs as a result.
Unfortunately, it is not mathematically true that every inverse pair folds to x.
For example, asin(sin(x)) does not fold to x when x lies outside [-pi/2, pi/2].
---
.../llvm/Transforms/Utils/SimplifyLibCalls.h | 2 +-
.../lib/Transforms/Utils/SimplifyLibCalls.cpp | 59 ++++++--
.../Transforms/InstCombine/tan-nofastmath.ll | 17 ---
llvm/test/Transforms/InstCombine/tan.ll | 23 ---
llvm/test/Transforms/InstCombine/trig.ll | 140 ++++++++++++++++++
5 files changed, 185 insertions(+), 56 deletions(-)
delete mode 100644 llvm/test/Transforms/InstCombine/tan-nofastmath.ll
delete mode 100644 llvm/test/Transforms/InstCombine/tan.ll
create mode 100644 llvm/test/Transforms/InstCombine/trig.ll
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index 1aad0b2988451..1b6b525b19cae 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -203,7 +203,7 @@ class LibCallSimplifier {
Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
- Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
+ Value *optimizeTrigInversionPairs(CallInst *CI, IRBuilderBase &B);
// Wrapper for all floating point library call optimizations
Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
IRBuilderBase &B);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index f79549f79389a..26a34aa99e1b8 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2681,13 +2681,16 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
return copyFlags(*CI, FabsCall);
}
-// TODO: Generalize to handle any trig function and its inverse.
-Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
+Value *LibCallSimplifier::optimizeTrigInversionPairs(CallInst *CI,
+ IRBuilderBase &B) {
Module *M = CI->getModule();
Function *Callee = CI->getCalledFunction();
Value *Ret = nullptr;
StringRef Name = Callee->getName();
- if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name))
+ if (UnsafeFPShrink &&
+ (Name == "tan" || Name == "atanh" || Name == "sinh" || Name == "cosh" ||
+ Name == "asinh") &&
+ hasFloatVersion(M, Name))
Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
Value *Op1 = CI->getArgOperand(0);
@@ -2700,16 +2703,34 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
return Ret;
// tan(atan(x)) -> x
- // tanf(atanf(x)) -> x
- // tanl(atanl(x)) -> x
+ // atanh(tanh(x)) -> x
+ // sinh(asinh(x)) -> x
+ // asinh(sinh(x)) -> x
+ // cosh(acosh(x)) -> x
LibFunc Func;
Function *F = OpC->getCalledFunction();
if (F && TLI->getLibFunc(F->getName(), Func) &&
- isLibFuncEmittable(M, TLI, Func) &&
- ((Func == LibFunc_atan && Callee->getName() == "tan") ||
- (Func == LibFunc_atanf && Callee->getName() == "tanf") ||
- (Func == LibFunc_atanl && Callee->getName() == "tanl")))
- Ret = OpC->getArgOperand(0);
+ isLibFuncEmittable(M, TLI, Func)) {
+ LibFunc inverseFunc = llvm::StringSwitch<LibFunc>(Callee->getName())
+ .Case("tan", LibFunc_atan)
+ .Case("atanh", LibFunc_tanh)
+ .Case("sinh", LibFunc_asinh)
+ .Case("cosh", LibFunc_acosh)
+ .Case("tanf", LibFunc_atanf)
+ .Case("atanhf", LibFunc_tanhf)
+ .Case("sinhf", LibFunc_asinhf)
+ .Case("coshf", LibFunc_acoshf)
+ .Case("tanl", LibFunc_atanl)
+ .Case("atanhl", LibFunc_tanhl)
+ .Case("sinhl", LibFunc_asinhl)
+ .Case("coshl", LibFunc_acoshl)
+ .Case("asinh", LibFunc_sinh)
+ .Case("asinhf", LibFunc_sinhf)
+ .Case("asinhl", LibFunc_sinhl)
+ .Default(NumLibFuncs); // Used as error value
+ if (Func == inverseFunc)
+ Ret = OpC->getArgOperand(0);
+ }
return Ret;
}
@@ -3702,7 +3723,19 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
case LibFunc_tan:
case LibFunc_tanf:
case LibFunc_tanl:
- return optimizeTan(CI, Builder);
+ case LibFunc_sinh:
+ case LibFunc_sinhf:
+ case LibFunc_sinhl:
+ case LibFunc_asinh:
+ case LibFunc_asinhf:
+ case LibFunc_asinhl:
+ case LibFunc_cosh:
+ case LibFunc_coshf:
+ case LibFunc_coshl:
+ case LibFunc_atanh:
+ case LibFunc_atanhf:
+ case LibFunc_atanhl:
+ return optimizeTrigInversionPairs(CI, Builder);
case LibFunc_ceil:
return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
case LibFunc_floor:
@@ -3720,17 +3753,13 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
case LibFunc_acos:
case LibFunc_acosh:
case LibFunc_asin:
- case LibFunc_asinh:
case LibFunc_atan:
- case LibFunc_atanh:
case LibFunc_cbrt:
- case LibFunc_cosh:
case LibFunc_exp:
case LibFunc_exp10:
case LibFunc_expm1:
case LibFunc_cos:
case LibFunc_sin:
- case LibFunc_sinh:
case LibFunc_tanh:
if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
diff --git a/llvm/test/Transforms/InstCombine/tan-nofastmath.ll b/llvm/test/Transforms/InstCombine/tan-nofastmath.ll
deleted file mode 100644
index 514ff4e40d618..0000000000000
--- a/llvm/test/Transforms/InstCombine/tan-nofastmath.ll
+++ /dev/null
@@ -1,17 +0,0 @@
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
-
-define float @mytan(float %x) {
-entry:
- %call = call float @atanf(float %x)
- %call1 = call float @tanf(float %call)
- ret float %call1
-}
-
-; CHECK-LABEL: define float @mytan(
-; CHECK: %call = call float @atanf(float %x)
-; CHECK-NEXT: %call1 = call float @tanf(float %call)
-; CHECK-NEXT: ret float %call1
-; CHECK-NEXT: }
-
-declare float @tanf(float)
-declare float @atanf(float)
diff --git a/llvm/test/Transforms/InstCombine/tan.ll b/llvm/test/Transforms/InstCombine/tan.ll
deleted file mode 100644
index 49f6e00e6d9ba..0000000000000
--- a/llvm/test/Transforms/InstCombine/tan.ll
+++ /dev/null
@@ -1,23 +0,0 @@
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
-
-define float @mytan(float %x) {
- %call = call fast float @atanf(float %x)
- %call1 = call fast float @tanf(float %call)
- ret float %call1
-}
-
-; CHECK-LABEL: define float @mytan(
-; CHECK: ret float %x
-
-define float @test2(ptr %fptr) {
- %call1 = call fast float %fptr()
- %tan = call fast float @tanf(float %call1)
- ret float %tan
-}
-
-; CHECK-LABEL: @test2
-; CHECK: tanf
-
-declare float @tanf(float)
-declare float @atanf(float)
-
diff --git a/llvm/test/Transforms/InstCombine/trig.ll b/llvm/test/Transforms/InstCombine/trig.ll
new file mode 100644
index 0000000000000..5dda1524396d4
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/trig.ll
@@ -0,0 +1,140 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define float @tanAtanInverseFast(float %x) {
+; CHECK-LABEL: define float @tanAtanInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call fast float @atanf(float [[X]])
+; CHECK-NEXT: ret float [[X]]
+;
+ %call = call fast float @atanf(float %x)
+ %call1 = call fast float @tanf(float %call)
+ ret float %call1
+}
+
+define float @atanhTanhInverseFast(float %x) {
+; CHECK-LABEL: define float @atanhTanhInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call fast float @tanhf(float [[X]])
+; CHECK-NEXT: ret float [[X]]
+;
+ %call = call fast float @tanhf(float %x)
+ %call1 = call fast float @atanhf(float %call)
+ ret float %call1
+}
+
+define float @sinhAsinhInverseFast(float %x) {
+; CHECK-LABEL: define float @sinhAsinhInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call fast float @asinhf(float [[X]])
+; CHECK-NEXT: ret float [[X]]
+;
+ %call = call fast float @asinhf(float %x)
+ %call1 = call fast float @sinhf(float %call)
+ ret float %call1
+}
+
+define float @asinhSinhInverseFast(float %x) {
+; CHECK-LABEL: define float @asinhSinhInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call fast float @sinhf(float [[X]])
+; CHECK-NEXT: ret float [[X]]
+;
+ %call = call fast float @sinhf(float %x)
+ %call1 = call fast float @asinhf(float %call)
+ ret float %call1
+}
+
+define float @coshAcoshInverseFast(float %x) {
+; CHECK-LABEL: define float @coshAcoshInverseFast(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call fast float @acoshf(float [[X]])
+; CHECK-NEXT: ret float [[X]]
+;
+ %call = call fast float @acoshf(float %x)
+ %call1 = call fast float @coshf(float %call)
+ ret float %call1
+}
+
+define float @indirectTanCall(ptr %fptr) {
+; CHECK-LABEL: define float @indirectTanCall(
+; CHECK-SAME: ptr [[FPTR:%.*]]) {
+; CHECK-NEXT: [[CALL1:%.*]] = call fast float [[FPTR]]()
+; CHECK-NEXT: [[TAN:%.*]] = call fast float @tanf(float [[CALL1]])
+; CHECK-NEXT: ret float [[TAN]]
+;
+ %call1 = call fast float %fptr()
+ %tan = call fast float @tanf(float %call1)
+ ret float %tan
+}
+
+; No fast-math.
+
+define float @tanAtanInverse(float %x) {
+; CHECK-LABEL: define float @tanAtanInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call float @atanf(float [[X]])
+; CHECK-NEXT: [[CALL1:%.*]] = call float @tanf(float [[CALL]])
+; CHECK-NEXT: ret float [[CALL1]]
+;
+ %call = call float @atanf(float %x)
+ %call1 = call float @tanf(float %call)
+ ret float %call1
+}
+
+define float @atanhTanhInverse(float %x) {
+; CHECK-LABEL: define float @atanhTanhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call float @tanhf(float [[X]])
+; CHECK-NEXT: [[CALL1:%.*]] = call float @atanhf(float [[CALL]])
+; CHECK-NEXT: ret float [[CALL1]]
+;
+ %call = call float @tanhf(float %x)
+ %call1 = call float @atanhf(float %call)
+ ret float %call1
+}
+
+define float @sinhAsinhInverse(float %x) {
+; CHECK-LABEL: define float @sinhAsinhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call float @asinhf(float [[X]])
+; CHECK-NEXT: [[CALL1:%.*]] = call float @sinhf(float [[CALL]])
+; CHECK-NEXT: ret float [[CALL1]]
+;
+ %call = call float @asinhf(float %x)
+ %call1 = call float @sinhf(float %call)
+ ret float %call1
+}
+
+define float @asinhSinhInverse(float %x) {
+; CHECK-LABEL: define float @asinhSinhInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call float @sinhf(float [[X]])
+; CHECK-NEXT: [[CALL1:%.*]] = call float @asinhf(float [[CALL]])
+; CHECK-NEXT: ret float [[CALL1]]
+;
+ %call = call float @sinhf(float %x)
+ %call1 = call float @asinhf(float %call)
+ ret float %call1
+}
+
+define float @coshAcoshInverse(float %x) {
+; CHECK-LABEL: define float @coshAcoshInverse(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call float @acoshf(float [[X]])
+; CHECK-NEXT: [[CALL1:%.*]] = call float @coshf(float [[CALL]])
+; CHECK-NEXT: ret float [[CALL1]]
+;
+ %call = call float @acoshf(float %x)
+ %call1 = call float @coshf(float %call)
+ ret float %call1
+}
+
+declare float @asinhf(float)
+declare float @sinhf(float)
+declare float @acoshf(float)
+declare float @coshf(float)
+declare float @tanhf(float)
+declare float @atanhf(float)
+declare float @tanf(float)
+declare float @atanf(float)
>From 62afb3060830843b9ed95f933ef7ea9e62aa232b Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 6 Feb 2024 04:48:55 -0800
Subject: [PATCH 5/5] simplify Op->Fun->FunDefOp register
---
.../lib/Target/SPIRV/SPIRVDuplicatesTracker.h | 29 ------------------
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 30 ++++++++++++-------
2 files changed, 20 insertions(+), 39 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index dabd73367cd6a..96cc621791e97 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -264,12 +264,6 @@ class SPIRVGeneralDuplicatesTracker {
SPIRVDuplicatesTracker<Argument> AT;
SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
- // map a Function to its definition (as a machine instruction operand)
- DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
- // map function pointer (as a machine instruction operand) to the used
- // Function
- DenseMap<const MachineOperand *, const Function *> InstrToFunction;
-
// NOTE: using MOs instead of regs to get rid of MF dependency to be able
// to use flat data structure.
// NOTE: replacing DenseMap with MapVector doesn't affect overall correctness
@@ -286,29 +280,6 @@ class SPIRVGeneralDuplicatesTracker {
void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph,
MachineModuleInfo *MMI);
- // Map a machine operand that represents a use of a function via function
- // pointer to a machine operand that represents the function definition.
- // Return either the register or invalid value, because we have no context for
- // a good diagnostic message in case of unexpectedly missing references.
- const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) {
- auto ResF = InstrToFunction.find(Use);
- if (ResF == InstrToFunction.end())
- return nullptr;
- auto ResReg = FunctionToInstr.find(ResF->second);
- return ResReg == FunctionToInstr.end() ? nullptr : ResReg->second;
- }
- // map function pointer (as a machine instruction operand) to the used
- // Function
- void recordFunctionPointer(const MachineOperand *MO, const Function *F) {
- InstrToFunction[MO] = F;
- }
- // map a Function to its definition (as a machine instruction)
- void recordFunctionDefinition(const Function *F, const MachineOperand *MO) {
- FunctionToInstr[F] = MO;
- }
- // Return true if any OpConstantFunctionPointerINTEL were generated
- bool hasConstFunPtr() { return !InstrToFunction.empty(); }
-
void add(const Type *Ty, const MachineFunction *MF, Register R) {
TT.add(Ty, MF, R);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 3ccc224ad8a4c..792a00786f0aa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -38,6 +38,12 @@ class SPIRVGlobalRegistry {
DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
+ // map a Function to its definition (as a machine instruction operand)
+ DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
+ // map function pointer (as a machine instruction operand) to the used
+ // Function
+ DenseMap<const MachineOperand *, const Function *> InstrToFunction;
+
// Look for an equivalent of the newType in the map. Return the equivalent
// if it's found, otherwise insert newType to the map and return the type.
const MachineInstr *checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
@@ -101,24 +107,28 @@ class SPIRVGlobalRegistry {
DT.buildDepsGraph(Graph, MMI);
}
+ // Map a machine operand that represents a use of a function via function
+ // pointer to a machine operand that represents the function definition.
+ // Return either the definition operand or nullptr, because we have no context
+ // for a good diagnostic message in case of unexpectedly missing references.
+ const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) {
+ auto ResF = InstrToFunction.find(Use);
+ if (ResF == InstrToFunction.end())
+ return nullptr;
+ auto ResReg = FunctionToInstr.find(ResF->second);
+ return ResReg == FunctionToInstr.end() ? nullptr : ResReg->second;
+ }
// map function pointer (as a machine instruction operand) to the used
// Function
void recordFunctionPointer(const MachineOperand *MO, const Function *F) {
- DT.recordFunctionPointer(MO, F);
+ InstrToFunction[MO] = F;
}
// map a Function to its definition (as a machine instruction)
void recordFunctionDefinition(const Function *F, const MachineOperand *MO) {
- DT.recordFunctionDefinition(F, MO);
- }
- // Map a machine operand that represents a use of a function via function
- // pointer to a machine operand that represents the function definition.
- // Return either the register or invalid value, because we have no context for
- // a good diagnostic message in case of unexpectedly missing references.
- const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) {
- return DT.getFunctionDefinitionByUse(Use);
+ FunctionToInstr[F] = MO;
}
// Return true if any OpConstantFunctionPointerINTEL were generated
- bool hasConstFunPtr() { return DT.hasConstFunPtr(); }
+ bool hasConstFunPtr() { return !InstrToFunction.empty(); }
// Get or create a SPIR-V type corresponding the given LLVM IR type,
// and map it to the given VReg by creating an ASSIGN_TYPE instruction.
More information about the llvm-commits
mailing list