[llvm] [SPIRV] Add support for pointers to functions with aggregate args/returns as global variables / constant initialisers (PR #169595)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 25 18:17:16 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Alex Voicu (AlexVlx)

<details>
<summary>Changes</summary>

This patch does two things:

1. It extends the aggregate argument/return replacement transform to work on indirect calls and pointers to functions. The change is somewhat spread out because the original function type has to be retrieved in several places. In general, we should rethink/rework the entire infrastructure around aggregate argument/return handling, using an opaque target-specific type rather than i32;
2. It enables global variables of pointer-to-function type and, more specifically, global variables of an aggregate type (arrays/structures) with pointer-to-function elements.

This also exposes some issues in how we handle pointers to functions and lower indirect function calls, primarily around not using the program address space. These will be addressed in a subsequent patch, as they require somewhat more intrusive surgery, possibly including changes to the data layout.

---

Patch is 27.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169595.diff


8 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+28-55) 
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+19-6) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+5-3) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp (+93-13) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+64) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+5) 
- (added) llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll (+108) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index dd57b74d79a5e..6b5c602d4ac93 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -131,46 +131,6 @@ fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
   return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
 }
 
-// This code restores function args/retvalue types for composite cases
-// because the final types should still be aggregate whereas they're i32
-// during the translation to cope with aggregate flattening etc.
-static FunctionType *getOriginalFunctionType(const Function &F) {
-  auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
-  if (NamedMD == nullptr)
-    return F.getFunctionType();
-
-  Type *RetTy = F.getFunctionType()->getReturnType();
-  SmallVector<Type *, 4> ArgTypes;
-  for (auto &Arg : F.args())
-    ArgTypes.push_back(Arg.getType());
-
-  auto ThisFuncMDIt =
-      std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
-        return isa<MDString>(N->getOperand(0)) &&
-               cast<MDString>(N->getOperand(0))->getString() == F.getName();
-      });
-  if (ThisFuncMDIt != NamedMD->op_end()) {
-    auto *ThisFuncMD = *ThisFuncMDIt;
-    for (unsigned I = 1; I != ThisFuncMD->getNumOperands(); ++I) {
-      MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(I));
-      assert(MD && "MDNode operand is expected");
-      ConstantInt *Const = getConstInt(MD, 0);
-      if (Const) {
-        auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
-        assert(CMeta && "ConstantAsMetadata operand is expected");
-        assert(Const->getSExtValue() >= -1);
-        // Currently -1 indicates return value, greater values mean
-        // argument numbers.
-        if (Const->getSExtValue() == -1)
-          RetTy = CMeta->getType();
-        else
-          ArgTypes[Const->getSExtValue()] = CMeta->getType();
-      }
-    }
-  }
-
-  return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
-}
 
 static SPIRV::AccessQualifier::AccessQualifier
 getArgAccessQual(const Function &F, unsigned ArgIdx) {
@@ -204,7 +164,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
   SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
       getArgAccessQual(F, ArgIdx);
 
-  Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
+  Type *OriginalArgType =
+      SPIRV::getOriginalFunctionType(F)->getParamType(ArgIdx);
 
   // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
   // be legally reassigned later).
@@ -421,7 +382,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   auto MRI = MIRBuilder.getMRI();
   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
   MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass);
-  FunctionType *FTy = getOriginalFunctionType(F);
+  FunctionType *FTy = SPIRV::getOriginalFunctionType(F);
   Type *FRetTy = FTy->getReturnType();
   if (isUntypedPointerTy(FRetTy)) {
     if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
@@ -506,10 +467,15 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
 // - add a topological sort of IndirectCalls to ensure the best types knowledge
 // - we may need to fix function formal parameter types if they are opaque
 //   pointers used as function pointers in these indirect calls
+// - defaulting to StorageClass::Function in the absence of the
+//   SPV_INTEL_function_pointers extension seems wrong, as that might not be
+//   able to hold a full width pointer to function, and it also does not model
+//   the semantics of a pointer to function in a generic fashion.
 void SPIRVCallLowering::produceIndirectPtrTypes(
     MachineIRBuilder &MIRBuilder) const {
   // Create indirect call data types if any
   MachineFunction &MF = MIRBuilder.getMF();
+  const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
   for (auto const &IC : IndirectCalls) {
     SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(
         IC.RetTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
@@ -527,8 +493,11 @@ void SPIRVCallLowering::produceIndirectPtrTypes(
     SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
         FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
     // SPIR-V pointer to function type:
-    SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
-        SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
+    auto SC = ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)
+        ? SPIRV::StorageClass::CodeSectionINTEL
+        : SPIRV::StorageClass::Function;
+    SPIRVType *IndirectFuncPtrTy =
+        GR->getOrCreateSPIRVPointerType(SpirvFuncTy, MIRBuilder, SC);
     // Correct the Callee type
     GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
   }
@@ -556,12 +525,12 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     // TODO: support constexpr casts and indirect calls.
     if (CF == nullptr)
       return false;
-    if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
-      OrigRetTy = FTy->getReturnType();
-      if (isUntypedPointerTy(OrigRetTy)) {
-        if (auto *DerivedRetTy = GR->findReturnType(CF))
-          OrigRetTy = DerivedRetTy;
-      }
+
+    FunctionType *FTy = SPIRV::getOriginalFunctionType(*CF);
+    OrigRetTy = FTy->getReturnType();
+    if (isUntypedPointerTy(OrigRetTy)) {
+      if (auto *DerivedRetTy = GR->findReturnType(CF))
+        OrigRetTy = DerivedRetTy;
     }
   }
 
@@ -683,11 +652,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     if (CalleeReg.isValid()) {
       SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
       IndirectCall.Callee = CalleeReg;
-      IndirectCall.RetTy = OrigRetTy;
-      for (const auto &Arg : Info.OrigArgs) {
-        assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
-        IndirectCall.ArgTys.push_back(Arg.Ty);
-        IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
+      FunctionType *FTy = SPIRV::getOriginalFunctionType(*Info.CB);
+      IndirectCall.RetTy = OrigRetTy = FTy->getReturnType();
+      assert(FTy->getNumParams() == Info.OrigArgs.size() &&
+             "Function types mismatch");
+      for (unsigned I = 0; I != Info.OrigArgs.size(); ++I) {
+        assert(Info.OrigArgs[I].Regs.size() == 1 &&
+               "Call arg has multiple VRegs");
+        IndirectCall.ArgTys.push_back(FTy->getParamType(I));
+        IndirectCall.ArgRegs.push_back(Info.OrigArgs[I].Regs[0]);
       }
       IndirectCalls.push_back(IndirectCall);
     }
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 8e14fb03127fc..b37e6d8ce4ea3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -358,7 +358,11 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
 
 static void emitAssignName(Instruction *I, IRBuilder<> &B) {
   if (!I->hasName() || I->getType()->isAggregateType() ||
-      expectIgnoredInIRTranslation(I))
+      expectIgnoredInIRTranslation(I) ||
+      // TODO: this is a temporary workaround meant to prevent inserting
+      //       internal noise into the generated binary; remove once we rework
+      //       the entire aggregate removal machinery.
+      I->getName().starts_with("spv.mutated_callsite"))
     return;
   reportFatalOnTokenType(I);
   setInsertPointAfterDef(B, I);
@@ -759,10 +763,15 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
     if (Type *ElemTy = getPointeeType(KnownTy))
       maybeAssignPtrType(Ty, I, ElemTy, UnknownElemTypeI8);
   } else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
-    Ty = deduceElementTypeByValueDeep(
-        Ref->getValueType(),
-        Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited,
-        UnknownElemTypeI8);
+    if (auto *Fn = dyn_cast<Function>(Ref)) {
+      Ty = SPIRV::getOriginalFunctionType(*Fn);
+      GR->addDeducedElementType(I, Ty);
+    } else {
+      Ty = deduceElementTypeByValueDeep(
+          Ref->getValueType(),
+          Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited,
+          UnknownElemTypeI8);
+    }
   } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
     Type *RefTy = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
                                           UnknownElemTypeI8);
@@ -1062,9 +1071,10 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
   if (!Op || !isPointerTy(Op->getType()))
     return;
   Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
-  FunctionType *FTy = CI->getFunctionType();
+  FunctionType *FTy = SPIRV::getOriginalFunctionType(*CI);
   bool IsNewFTy = false, IsIncomplete = false;
   SmallVector<Type *, 4> ArgTys;
+  unsigned ParmIdx = 0;
   for (Value *Arg : CI->args()) {
     Type *ArgTy = Arg->getType();
     if (ArgTy->isPointerTy()) {
@@ -1076,8 +1086,11 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
       } else {
         IsIncomplete = true;
       }
+    } else {
+      ArgTy = FTy->getFunctionParamType(ParmIdx);
     }
     ArgTys.push_back(ArgTy);
+    ++ParmIdx;
   }
   Type *RetTy = FTy->getReturnType();
   if (CI->getType()->isPointerTy()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 09c77f0cfd4f5..16f3260bf4ffc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -214,6 +214,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
       if (Value *GlobalElem =
               Global->getNumOperands() > 0 ? Global->getOperand(0) : nullptr)
         ElementTy = findDeducedCompositeType(GlobalElem);
+      else if (const Function *Fn = dyn_cast<Function>(Global))
+        ElementTy = SPIRV::getOriginalFunctionType(*Fn);
     }
     return ElementTy ? ElementTy : Global->getValueType();
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 0f4b3d59b904a..b273599596a35 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -257,9 +257,11 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
       Register Def = MI.getOperand(0).getReg();
       Register Source = MI.getOperand(2).getReg();
       Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
-      SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
-          ElemTy, MI,
-          addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
+      auto SC = isa<FunctionType>(ElemTy)
+          ? SPIRV::StorageClass::CodeSectionINTEL
+          : addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST);
+      SPIRVType *AssignedPtrType =
+          GR->getOrCreateSPIRVPointerType(ElemTy, MI, SC);
 
       // If the ptrcast would be redundant, replace all uses with the source
       // register.
diff --git a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
index be88f334d2171..8fd261cfff25b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
@@ -26,6 +26,7 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/IntrinsicLowering.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
@@ -41,6 +42,7 @@ class SPIRVPrepareFunctions : public ModulePass {
   const SPIRVTargetMachine &TM;
   bool substituteIntrinsicCalls(Function *F);
   Function *removeAggregateTypesFromSignature(Function *F);
+  bool removeAggregateTypesFromCalls(Function *F);
 
 public:
   static char ID;
@@ -469,6 +471,23 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
   return Changed;
 }
 
+static void addFunctionTypeMutation(
+    NamedMDNode *NMD,
+    SmallVector<std::pair<int, Type *>> ChangedTys, StringRef Name) {
+
+    LLVMContext &Ctx = NMD->getParent()->getContext();
+    Type *I32Ty = IntegerType::getInt32Ty(Ctx);
+
+    SmallVector<Metadata *> MDArgs;
+    MDArgs.push_back(MDString::get(Ctx, Name));
+    transform(ChangedTys, std::back_inserter(MDArgs), [=, &Ctx](auto &&CTy) {
+      return MDNode::get(
+          Ctx,
+          {ConstantAsMetadata::get(ConstantInt::get(I32Ty, CTy.first, true)),
+           ValueAsMetadata::get(Constant::getNullValue(CTy.second))});
+    });
+    NMD->addOperand(MDNode::get(Ctx, MDArgs));
+}
 // Returns F if aggregate argument/return types are not present or cloned F
 // function with the types replaced by i32 types. The change in types is
 // noted in 'spv.cloned_funcs' metadata for later restoration.
@@ -503,7 +522,8 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
   FunctionType *NewFTy =
       FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
   Function *NewF =
-      Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
+      Function::Create(NewFTy, F->getLinkage(), F->getAddressSpace(),
+                       F->getName(), F->getParent());
 
   ValueToValueMapTy VMap;
   auto NewFArgIt = NewF->arg_begin();
@@ -518,22 +538,17 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
                     Returns);
   NewF->takeName(F);
 
-  NamedMDNode *FuncMD =
-      F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
-  SmallVector<Metadata *, 2> MDArgs;
-  MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
-  for (auto &ChangedTyP : ChangedTypes)
-    MDArgs.push_back(MDNode::get(
-        B.getContext(),
-        {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
-         ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
-  MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
-  FuncMD->addOperand(ThisFuncMD);
+  addFunctionTypeMutation(
+      NewF->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs"),
+      std::move(ChangedTypes), NewF->getName());
 
   for (auto *U : make_early_inc_range(F->users())) {
     if (auto *CI = dyn_cast<CallInst>(U))
       CI->mutateFunctionType(NewF->getFunctionType());
-    U->replaceUsesOfWith(F, NewF);
+    if (auto *C = dyn_cast<Constant>(U))
+      C->handleOperandChange(F, NewF);
+    else
+      U->replaceUsesOfWith(F, NewF);
   }
 
   // register the mutation
@@ -543,11 +558,76 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
   return NewF;
 }
 
+// Mutates indirect callsites if aggregate argument/return types are present,
+// replacing those types with i32. The change in types is recorded in
+// 'spv.mutated_callsites' metadata for later restoration.
+bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) {
+  if (F->isDeclaration() || F->isIntrinsic())
+    return false;
+
+  SmallVector<std::pair<CallBase *, FunctionType *>> Calls;
+  for (auto &&BB : *F) {
+    for (auto &&I : BB) {
+      if (auto *CB = dyn_cast<CallBase>(&I)) {
+        if (!CB->getCalledOperand() || CB->getCalledFunction())
+          continue;
+        if (CB->getType()->isAggregateType() ||
+            any_of(CB->args(),
+                  [](auto &&Arg) { return Arg->getType()->isAggregateType(); }))
+          Calls.emplace_back(CB, nullptr);
+      }
+    }
+  }
+
+  if (Calls.empty())
+    return false;
+
+  IRBuilder<> B(F->getContext());
+
+  for (auto &&[CB, NewFnTy] : Calls) {
+    SmallVector<std::pair<int, Type *>> ChangedTypes;
+    SmallVector<Type *> NewArgTypes;
+
+    if (CB->getType()->isAggregateType())
+      ChangedTypes.emplace_back(-1, CB->getType());
+
+    Type *RetTy = ChangedTypes.empty() ? CB->getType() : B.getInt32Ty();
+    for (auto &&Arg : CB->args()) {
+      if (Arg->getType()->isAggregateType()) {
+        NewArgTypes.push_back(B.getInt32Ty());
+        ChangedTypes.emplace_back(Arg.getOperandNo(), Arg->getType());
+      } else {
+        NewArgTypes.push_back(Arg->getType());
+      }
+    }
+    NewFnTy = FunctionType::get(RetTy, NewArgTypes,
+                                CB->getFunctionType()->isVarArg());
+
+    if (!CB->hasName())
+      CB->setName("spv.mutated_callsite");
+
+    addFunctionTypeMutation(
+      F->getParent()->getOrInsertNamedMetadata("spv.mutated_callsites"),
+      std::move(ChangedTypes),
+      CB->getName());
+  }
+
+  for (auto &&[CB, NewFTy] : Calls) {
+    if (NewFTy->getReturnType() != CB->getType())
+      TM.getSubtarget<SPIRVSubtarget>(*F).getSPIRVGlobalRegistry()->addMutated(
+          CB, CB->getType());
+    CB->mutateFunctionType(NewFTy);
+  }
+
+  return true;
+}
+
 bool SPIRVPrepareFunctions::runOnModule(Module &M) {
   bool Changed = false;
   for (Function &F : M) {
     Changed |= substituteIntrinsicCalls(&F);
     Changed |= sortBlocks(F);
+    Changed |= removeAggregateTypesFromCalls(&F);
   }
 
   std::vector<Function *> FuncsWorklist;
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 8f2fc01da476f..5b6a36de6c526 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -28,6 +28,70 @@
 #include <vector>
 
 namespace llvm {
+namespace SPIRV {
+// This code restores function args/retvalue types for composite cases
+// because the final types should still be aggregate whereas they're i32
+// during the translation to cope with aggregate flattening etc.
+// TODO: should these just return nullptr when there's no metadata?
+static FunctionType *extractFunctionTypeFromMetadata(NamedMDNode *NMD,
+                                                     FunctionType *FTy,
+                                                     StringRef Name) {
+  if (!NMD)
+    return FTy;
+
+  constexpr auto getConstInt = [](MDNode *MD, unsigned OpId) -> ConstantInt * {
+    if (MD->getNumOperands() <= OpId)
+      return nullptr;
+    if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(OpId)))
+      return dyn_cast<ConstantInt>(CMeta->getValue());
+    return nullptr;
+  };
+
+  auto It = find_if(NMD->operands(), [Name](MDNode *N) {
+    if (auto *MDS = dyn_cast_or_null<MDString>(N->getOperand(0)))
+      return MDS->getString() == Name;
+    return false;
+  });
+
+  if (It == NMD->op_end())
+    return FTy;
+
+  Type *RetTy = FTy->getReturnType();
+  SmallVector<Type *, 4> PTys(FTy->params());
+
+  for (unsigned I = 1; I != (*It)->getNumOperands(); ++I) {
+    MDNode *MD = dyn_cast<MDNode>((*It)->getOperand(I));
+    assert(MD && "MDNode operand is expected");
+
+    if (auto *Const = getConstInt(MD, 0)) {
+      auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
+      assert(CMeta && "ConstantAsMetadata operand is expected");
+      assert(Const->getSExtValue() >= -1);
+      // Currently -1 indicates return value, greater values mean
+      // argument numbers.
+      if (Const->getSExtValue() == -1)
+        RetTy = CMeta->getType();
+      else
+        PTys[Const->getSExtValue()] = CMeta->getType();
+    }
+  }
+
+  return FunctionType::get(RetTy, PTys, FTy->isVarArg());
+}
+
+FunctionType *getOriginalFunctionType(const Function &F) {
+  return extractFunctionTypeFromMetadata(
+      F.getParent()->getNamedMetadata("spv.cloned_funcs"),
+      F.getFunctionType(), F.getName());
+}
+
+FunctionType *getOriginalFunctionType(const CallBase &CB) {
+  return extractFunctionTypeFromMetadata(
+      CB.getParent()
+        ->getParent()->getParent()->getNamedMetadata("spv.mutated_callsites"),
+      CB.getFunctionType(), CB.getName());
+}
+} // Namespace SPIRV
 
 // The following functions are used to add these string literals as a series of
 // 32-bit integer operands with the correct format, and unpack them if necessary
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 99d9d403ea70c..3da77f1d6c1ec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -159,6 +159,11 @@ struct FPFastMathDefaultInfoVector
   }
 };
 
+// This code restores function args/retvalue types for composite cases
+// because the final types should still be aggregate whereas they're i32
+// during the translation to cope with a...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/169595


More information about the llvm-commits mailing list