[llvm] [SPIRV] Add support for pointers to functions with aggregate args/returns as global variables / constant initialisers (PR #169595)
Alex Voicu via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 26 07:29:34 PST 2025
================
@@ -28,6 +28,70 @@
#include <vector>
namespace llvm {
+namespace SPIRV {
+// This code restores function args/retvalue types for composite cases
+// because the final types should still be aggregate whereas they're i32
+// during the translation to cope with aggregate flattening etc.
+// TODO: should these just return nullptr when there's no metadata?
+static FunctionType *extractFunctionTypeFromMetadata(NamedMDNode *NMD,
+ FunctionType *FTy,
+ StringRef Name) {
+ if (!NMD)
+ return FTy;
+
+ constexpr auto getConstInt = [](MDNode *MD, unsigned OpId) -> ConstantInt * {
+ if (MD->getNumOperands() <= OpId)
+ return nullptr;
+ if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(OpId)))
+ return dyn_cast<ConstantInt>(CMeta->getValue());
+ return nullptr;
+ };
+
+ auto It = find_if(NMD->operands(), [Name](MDNode *N) {
+ if (auto *MDS = dyn_cast_or_null<MDString>(N->getOperand(0)))
+ return MDS->getString() == Name;
+ return false;
+ });
+
+ if (It == NMD->op_end())
+ return FTy;
+
+ Type *RetTy = FTy->getReturnType();
+ SmallVector<Type *, 4> PTys(FTy->params());
+
+ for (unsigned I = 1; I != (*It)->getNumOperands(); ++I) {
+ MDNode *MD = dyn_cast<MDNode>((*It)->getOperand(I));
+ assert(MD && "MDNode operand is expected");
+
+ if (auto *Const = getConstInt(MD, 0)) {
+ auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
+ assert(CMeta && "ConstantAsMetadata operand is expected");
+ assert(Const->getSExtValue() >= -1);
+ // Currently -1 indicates return value, greater values mean
+ // argument numbers.
+ if (Const->getSExtValue() == -1)
+ RetTy = CMeta->getType();
+ else
+ PTys[Const->getSExtValue()] = CMeta->getType();
+ }
+ }
+
+ return FunctionType::get(RetTy, PTys, FTy->isVarArg());
+}
+
+FunctionType *getOriginalFunctionType(const Function &F) {
+ return extractFunctionTypeFromMetadata(
+ F.getParent()->getNamedMetadata("spv.cloned_funcs"),
+ F.getFunctionType(), F.getName());
+}
+
+FunctionType *getOriginalFunctionType(const CallBase &CB) {
+ return extractFunctionTypeFromMetadata(
+ CB.getParent()
+ ->getParent()->getParent()->getNamedMetadata("spv.mutated_callsites"),
----------------
AlexVlx wrote:
Done.
https://github.com/llvm/llvm-project/pull/169595
More information about the llvm-commits mailing list