[llvm] ffc5581 - [SPIRV] Add support for pointers to functions with aggregate args/returns as global variables / constant initialisers (#169595)

via llvm-commits llvm-commits at lists.llvm.org
Sun Dec 7 12:17:30 PST 2025


Author: Alex Voicu
Date: 2025-12-07T20:17:25Z
New Revision: ffc55815ef3208a3802513c33ac66a122d2fb680

URL: https://github.com/llvm/llvm-project/commit/ffc55815ef3208a3802513c33ac66a122d2fb680
DIFF: https://github.com/llvm/llvm-project/commit/ffc55815ef3208a3802513c33ac66a122d2fb680.diff

LOG: [SPIRV] Add support for pointers to functions with aggregate args/returns as global variables / constant initialisers (#169595)

This patch does two things:

1. it extends the aggregate arg / ret replacement transform to work on
indirect calls / pointers to function. It is somewhat spread out as
retrieving the original function type is needed in a few places. In
general, we should rethink / rework the entire infrastructure around
aggregate arg/ret handling, using an opaque target specific type rather
than i32;
2. it enables global variables of pointer to function type, and, more
specifically, global variables of an aggregate type (arrays / structures)
with pointer to function elements.

This also exposes some issues in how we handle pointers to function and
lowering indirect function calls, primarily around not using the program
address space. These will be handled in a subsequent patch as they'll
require somewhat more intrusive surgery, possibly involving modifying
the data layout.

Added: 
    llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll

Modified: 
    llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
    llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
    llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
    llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
    llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
    llvm/lib/Target/SPIRV/SPIRVUtils.cpp
    llvm/lib/Target/SPIRV/SPIRVUtils.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 206e13e8c0346..c514debb63b30 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -131,47 +131,6 @@ fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
   return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
 }
 
-// This code restores function args/retvalue types for composite cases
-// because the final types should still be aggregate whereas they're i32
-// during the translation to cope with aggregate flattening etc.
-static FunctionType *getOriginalFunctionType(const Function &F) {
-  auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
-  if (NamedMD == nullptr)
-    return F.getFunctionType();
-
-  Type *RetTy = F.getFunctionType()->getReturnType();
-  SmallVector<Type *, 4> ArgTypes;
-  for (auto &Arg : F.args())
-    ArgTypes.push_back(Arg.getType());
-
-  auto ThisFuncMDIt =
-      std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
-        return isa<MDString>(N->getOperand(0)) &&
-               cast<MDString>(N->getOperand(0))->getString() == F.getName();
-      });
-  if (ThisFuncMDIt != NamedMD->op_end()) {
-    auto *ThisFuncMD = *ThisFuncMDIt;
-    for (unsigned I = 1; I != ThisFuncMD->getNumOperands(); ++I) {
-      MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(I));
-      assert(MD && "MDNode operand is expected");
-      ConstantInt *Const = getConstInt(MD, 0);
-      if (Const) {
-        auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
-        assert(CMeta && "ConstantAsMetadata operand is expected");
-        assert(Const->getSExtValue() >= -1);
-        // Currently -1 indicates return value, greater values mean
-        // argument numbers.
-        if (Const->getSExtValue() == -1)
-          RetTy = CMeta->getType();
-        else
-          ArgTypes[Const->getSExtValue()] = CMeta->getType();
-      }
-    }
-  }
-
-  return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
-}
-
 static SPIRV::AccessQualifier::AccessQualifier
 getArgAccessQual(const Function &F, unsigned ArgIdx) {
   if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
@@ -204,7 +163,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
   SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
       getArgAccessQual(F, ArgIdx);
 
-  Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
+  Type *OriginalArgType =
+      SPIRV::getOriginalFunctionType(F)->getParamType(ArgIdx);
 
   // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
   // be legally reassigned later).
@@ -429,7 +389,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   auto MRI = MIRBuilder.getMRI();
   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
   MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass);
-  FunctionType *FTy = getOriginalFunctionType(F);
+  FunctionType *FTy = SPIRV::getOriginalFunctionType(F);
   Type *FRetTy = FTy->getReturnType();
   if (isUntypedPointerTy(FRetTy)) {
     if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
@@ -514,10 +474,15 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
 // - add a topological sort of IndirectCalls to ensure the best types knowledge
 // - we may need to fix function formal parameter types if they are opaque
 //   pointers used as function pointers in these indirect calls
+// - defaulting to StorageClass::Function in the absence of the
+//   SPV_INTEL_function_pointers extension seems wrong, as that might not be
+//   able to hold a full width pointer to function, and it also does not model
+//   the semantics of a pointer to function in a generic fashion.
 void SPIRVCallLowering::produceIndirectPtrTypes(
     MachineIRBuilder &MIRBuilder) const {
   // Create indirect call data types if any
   MachineFunction &MF = MIRBuilder.getMF();
+  const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
   for (auto const &IC : IndirectCalls) {
     SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(
         IC.RetTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
@@ -535,8 +500,11 @@ void SPIRVCallLowering::produceIndirectPtrTypes(
     SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
         FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
     // SPIR-V pointer to function type:
-    SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
-        SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
+    auto SC = ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)
+                  ? SPIRV::StorageClass::CodeSectionINTEL
+                  : SPIRV::StorageClass::Function;
+    SPIRVType *IndirectFuncPtrTy =
+        GR->getOrCreateSPIRVPointerType(SpirvFuncTy, MIRBuilder, SC);
     // Correct the Callee type
     GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
   }
@@ -564,12 +532,12 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     // TODO: support constexpr casts and indirect calls.
     if (CF == nullptr)
       return false;
-    if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
-      OrigRetTy = FTy->getReturnType();
-      if (isUntypedPointerTy(OrigRetTy)) {
-        if (auto *DerivedRetTy = GR->findReturnType(CF))
-          OrigRetTy = DerivedRetTy;
-      }
+
+    FunctionType *FTy = SPIRV::getOriginalFunctionType(*CF);
+    OrigRetTy = FTy->getReturnType();
+    if (isUntypedPointerTy(OrigRetTy)) {
+      if (auto *DerivedRetTy = GR->findReturnType(CF))
+        OrigRetTy = DerivedRetTy;
     }
   }
 
@@ -691,11 +659,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     if (CalleeReg.isValid()) {
       SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
       IndirectCall.Callee = CalleeReg;
-      IndirectCall.RetTy = OrigRetTy;
-      for (const auto &Arg : Info.OrigArgs) {
-        assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
-        IndirectCall.ArgTys.push_back(Arg.Ty);
-        IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
+      FunctionType *FTy = SPIRV::getOriginalFunctionType(*Info.CB);
+      IndirectCall.RetTy = OrigRetTy = FTy->getReturnType();
+      assert(FTy->getNumParams() == Info.OrigArgs.size() &&
+             "Function types mismatch");
+      for (unsigned I = 0; I != Info.OrigArgs.size(); ++I) {
+        assert(Info.OrigArgs[I].Regs.size() == 1 &&
+               "Call arg has multiple VRegs");
+        IndirectCall.ArgTys.push_back(FTy->getParamType(I));
+        IndirectCall.ArgRegs.push_back(Info.OrigArgs[I].Regs[0]);
       }
       IndirectCalls.push_back(IndirectCall);
     }

diff  --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index eea49bfdaf04b..f9e564b9dc1f6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -360,6 +360,17 @@ static void emitAssignName(Instruction *I, IRBuilder<> &B) {
   if (!I->hasName() || I->getType()->isAggregateType() ||
       expectIgnoredInIRTranslation(I))
     return;
+
+  if (isa<CallBase>(I)) {
+    // TODO: this is a temporary workaround meant to prevent inserting internal
+    //       noise into the generated binary; remove once we rework the entire
+    //       aggregate removal machinery.
+    StringRef Name = I->getName();
+    if (Name.starts_with("spv.mutated_callsite"))
+      return;
+    if (Name.starts_with("spv.named_mutated_callsite"))
+      I->setName(Name.substr(Name.rfind('.') + 1));
+  }
   reportFatalOnTokenType(I);
   setInsertPointAfterDef(B, I);
   LLVMContext &Ctx = I->getContext();
@@ -759,10 +770,15 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
     if (Type *ElemTy = getPointeeType(KnownTy))
       maybeAssignPtrType(Ty, I, ElemTy, UnknownElemTypeI8);
   } else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
-    Ty = deduceElementTypeByValueDeep(
-        Ref->getValueType(),
-        Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited,
-        UnknownElemTypeI8);
+    if (auto *Fn = dyn_cast<Function>(Ref)) {
+      Ty = SPIRV::getOriginalFunctionType(*Fn);
+      GR->addDeducedElementType(I, Ty);
+    } else {
+      Ty = deduceElementTypeByValueDeep(
+          Ref->getValueType(),
+          Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited,
+          UnknownElemTypeI8);
+    }
   } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
     Type *RefTy = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
                                           UnknownElemTypeI8);
@@ -1063,10 +1079,10 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
   if (!Op || !isPointerTy(Op->getType()))
     return;
   Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
-  FunctionType *FTy = CI->getFunctionType();
+  FunctionType *FTy = SPIRV::getOriginalFunctionType(*CI);
   bool IsNewFTy = false, IsIncomplete = false;
   SmallVector<Type *, 4> ArgTys;
-  for (Value *Arg : CI->args()) {
+  for (auto &&[ParmIdx, Arg] : llvm::enumerate(CI->args())) {
     Type *ArgTy = Arg->getType();
     if (ArgTy->isPointerTy()) {
       if (Type *ElemTy = GR->findDeducedElementType(Arg)) {
@@ -1077,6 +1093,8 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
       } else {
         IsIncomplete = true;
       }
+    } else {
+      ArgTy = FTy->getFunctionParamType(ParmIdx);
     }
     ArgTys.push_back(ArgTy);
   }

diff  --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index eac6b4dc1de8a..8865c618a2594 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -214,6 +214,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
       if (Value *GlobalElem =
               Global->getNumOperands() > 0 ? Global->getOperand(0) : nullptr)
         ElementTy = findDeducedCompositeType(GlobalElem);
+      else if (const Function *Fn = dyn_cast<Function>(Global))
+        ElementTy = SPIRV::getOriginalFunctionType(*Fn);
     }
     return ElementTy ? ElementTy : Global->getValueType();
   }

diff  --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 1e77b79ec496d..acc726717743d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -257,9 +257,12 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
       Register Def = MI.getOperand(0).getReg();
       Register Source = MI.getOperand(2).getReg();
       Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
-      SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
-          ElemTy, MI,
-          addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
+      auto SC =
+          isa<FunctionType>(ElemTy)
+              ? SPIRV::StorageClass::CodeSectionINTEL
+              : addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST);
+      SPIRVType *AssignedPtrType =
+          GR->getOrCreateSPIRVPointerType(ElemTy, MI, SC);
 
       // If the ptrcast would be redundant, replace all uses with the source
       // register.

diff  --git a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
index be88f334d2171..fdd0af871e03e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
@@ -26,6 +26,8 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/IntrinsicLowering.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
@@ -41,6 +43,7 @@ class SPIRVPrepareFunctions : public ModulePass {
   const SPIRVTargetMachine &TM;
   bool substituteIntrinsicCalls(Function *F);
   Function *removeAggregateTypesFromSignature(Function *F);
+  bool removeAggregateTypesFromCalls(Function *F);
 
 public:
   static char ID;
@@ -469,6 +472,23 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
   return Changed;
 }
 
+// Appends to NMD one operand recording that the function or call site named
+// Name had some of its types replaced by i32. Each entry of ChangedTys pairs
+// an index (-1 for the return value, otherwise the parameter number) with the
+// original type. The operand has the form
+//   !{!"Name", !{i32 Idx, <null value of original type>}, ...}
+static void
+addFunctionTypeMutation(NamedMDNode *NMD,
+                        ArrayRef<std::pair<int, Type *>> ChangedTys,
+                        StringRef Name) {
+  LLVMContext &Ctx = NMD->getParent()->getContext();
+  Type *I32Ty = IntegerType::getInt32Ty(Ctx);
+
+  SmallVector<Metadata *> MDArgs;
+  MDArgs.reserve(ChangedTys.size() + 1);
+  MDArgs.push_back(MDString::get(Ctx, Name));
+  for (const auto &[Idx, OrigTy] : ChangedTys)
+    MDArgs.push_back(MDNode::get(
+        Ctx, {ConstantAsMetadata::get(ConstantInt::get(I32Ty, Idx, true)),
+              ValueAsMetadata::get(Constant::getNullValue(OrigTy))}));
+  NMD->addOperand(MDNode::get(Ctx, MDArgs));
+}
+
 // Returns F if aggregate argument/return types are not present or cloned F
 // function with the types replaced by i32 types. The change in types is
 // noted in 'spv.cloned_funcs' metadata for later restoration.
@@ -503,7 +523,8 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
   FunctionType *NewFTy =
       FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
   Function *NewF =
-      Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
+      Function::Create(NewFTy, F->getLinkage(), F->getAddressSpace(),
+                       F->getName(), F->getParent());
 
   ValueToValueMapTy VMap;
   auto NewFArgIt = NewF->arg_begin();
@@ -518,22 +539,18 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
                     Returns);
   NewF->takeName(F);
 
-  NamedMDNode *FuncMD =
-      F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
-  SmallVector<Metadata *, 2> MDArgs;
-  MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
-  for (auto &ChangedTyP : ChangedTypes)
-    MDArgs.push_back(MDNode::get(
-        B.getContext(),
-        {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
-         ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
-  MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
-  FuncMD->addOperand(ThisFuncMD);
+  addFunctionTypeMutation(
+      NewF->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs"),
+      std::move(ChangedTypes), NewF->getName());
 
   for (auto *U : make_early_inc_range(F->users())) {
-    if (auto *CI = dyn_cast<CallInst>(U))
+    if (CallInst *CI;
+        (CI = dyn_cast<CallInst>(U)) && CI->getCalledFunction() == F)
       CI->mutateFunctionType(NewF->getFunctionType());
-    U->replaceUsesOfWith(F, NewF);
+    if (auto *C = dyn_cast<Constant>(U))
+      C->handleOperandChange(F, NewF);
+    else
+      U->replaceUsesOfWith(F, NewF);
   }
 
   // register the mutation
@@ -543,11 +560,78 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
   return NewF;
 }
 
+// Mutates indirect call sites in F whose callee type has aggregate argument or
+// return types, replacing those types with i32. The original types are noted
+// in the 'spv.mutated_callsites' named metadata, keyed by the call site's
+// (renamed) value name, for later restoration by
+// SPIRV::getOriginalFunctionType(const CallBase &). Direct calls are not
+// touched here. Returns true if any call site was mutated.
+bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) {
+  if (F->isDeclaration() || F->isIntrinsic())
+    return false;
+
+  // Work from the call's FunctionType rather than from its actual arguments:
+  // for a variadic call the trailing arguments are not part of the callee type,
+  // so they must neither trigger the mutation nor become fixed parameters.
+  auto HasAggregateInSignature = [](FunctionType *FTy) {
+    return FTy->getReturnType()->isAggregateType() ||
+           any_of(FTy->params(),
+                  [](Type *ParamTy) { return ParamTy->isAggregateType(); });
+  };
+
+  SmallVector<CallBase *> Calls;
+  for (Instruction &I : instructions(F)) {
+    auto *CB = dyn_cast<CallBase>(&I);
+    // Direct calls have a called function and are skipped here. Inline
+    // assembly is left untouched: an asm call with several outputs returns a
+    // struct by design, and its type must keep matching its constraints.
+    if (!CB || CB->getCalledFunction() || CB->isInlineAsm())
+      continue;
+    if (HasAggregateInSignature(CB->getFunctionType()))
+      Calls.push_back(CB);
+  }
+
+  if (Calls.empty())
+    return false;
+
+  Type *I32Ty = Type::getInt32Ty(F->getContext());
+  NamedMDNode *MutatedCallsMD =
+      F->getParent()->getOrInsertNamedMetadata("spv.mutated_callsites");
+  auto *GR = TM.getSubtarget<SPIRVSubtarget>(*F).getSPIRVGlobalRegistry();
+
+  for (CallBase *CB : Calls) {
+    FunctionType *OldFTy = CB->getFunctionType();
+    SmallVector<std::pair<int, Type *>> ChangedTypes;
+
+    Type *RetTy = OldFTy->getReturnType();
+    if (RetTy->isAggregateType()) {
+      ChangedTypes.emplace_back(-1, RetTy);
+      RetTy = I32Ty;
+    }
+
+    SmallVector<Type *> NewParamTypes;
+    NewParamTypes.reserve(OldFTy->getNumParams());
+    for (auto &&[Idx, ParamTy] : enumerate(OldFTy->params())) {
+      if (ParamTy->isAggregateType()) {
+        ChangedTypes.emplace_back(static_cast<int>(Idx), ParamTy);
+        NewParamTypes.push_back(I32Ty);
+      } else {
+        NewParamTypes.push_back(ParamTy);
+      }
+    }
+    FunctionType *NewFTy =
+        FunctionType::get(RetTy, NewParamTypes, OldFTy->isVarArg());
+
+    // The value name is the key of this call site's metadata entry.
+    // NOTE(review): emitAssignName (SPIRVEmitIntrinsics) renames
+    // 'spv.named_mutated_callsite.*' instructions to the text after the last
+    // '.', after which a name-keyed lookup no longer finds this entry, and
+    // original names that contain '.' are truncated - verify.
+    if (!CB->hasName())
+      CB->setName("spv.mutated_callsite." + F->getName());
+    else
+      CB->setName("spv.named_mutated_callsite." + F->getName() + "." +
+                  CB->getName());
+
+    addFunctionTypeMutation(MutatedCallsMD, std::move(ChangedTypes),
+                            CB->getName());
+
+    // Remember the original result type before mutateFunctionType changes it.
+    if (NewFTy->getReturnType() != CB->getType())
+      GR->addMutated(CB, CB->getType());
+    CB->mutateFunctionType(NewFTy);
+  }
+
+  return true;
+}
+
 bool SPIRVPrepareFunctions::runOnModule(Module &M) {
   bool Changed = false;
   for (Function &F : M) {
     Changed |= substituteIntrinsicCalls(&F);
     Changed |= sortBlocks(F);
+    Changed |= removeAggregateTypesFromCalls(&F);
   }
 
   std::vector<Function *> FuncsWorklist;

diff  --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 7fdb0fafa3719..d4dd897647cfc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -28,6 +28,69 @@
 #include <vector>
 
 namespace llvm {
+namespace SPIRV {
+// This code restores function args/retvalue types for composite cases
+// because the final types should still be aggregate whereas they're i32
+// during the translation to cope with aggregate flattening etc.
+//
+// NMD holds one operand per mutated function or call site, of the form
+//   !{!"<Name>", !{i32 <Idx>, <null value of original type>}, ...}
+// where Idx == -1 denotes the return value and Idx >= 0 a parameter index.
+// Returns FTy with the recorded types substituted back in, or FTy itself when
+// NMD is null or has no entry for Name.
+// TODO: should these just return nullptr when there's no metadata?
+static FunctionType *extractFunctionTypeFromMetadata(NamedMDNode *NMD,
+                                                     FunctionType *FTy,
+                                                     StringRef Name) {
+  if (!NMD)
+    return FTy;
+
+  // Returns operand OpId of MD as a ConstantInt, or null if it is absent or
+  // not an integer constant.
+  auto GetConstInt = [](MDNode *MD, unsigned OpId) -> ConstantInt * {
+    if (MD->getNumOperands() <= OpId)
+      return nullptr;
+    if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(OpId)))
+      return dyn_cast<ConstantInt>(CMeta->getValue());
+    return nullptr;
+  };
+
+  // Operand-less nodes cannot name anything; skip them instead of reading a
+  // non-existent operand 0.
+  auto It = find_if(NMD->operands(), [Name](MDNode *N) {
+    if (N->getNumOperands() == 0)
+      return false;
+    if (auto *MDS = dyn_cast_or_null<MDString>(N->getOperand(0)))
+      return MDS->getString() == Name;
+    return false;
+  });
+
+  if (It == NMD->op_end())
+    return FTy;
+
+  Type *RetTy = FTy->getReturnType();
+  SmallVector<Type *, 4> PTys(FTy->params());
+
+  MDNode *Entry = *It;
+  for (unsigned I = 1, E = Entry->getNumOperands(); I != E; ++I) {
+    MDNode *MD = dyn_cast<MDNode>(Entry->getOperand(I));
+    assert(MD && "MDNode operand is expected");
+
+    ConstantInt *Const = GetConstInt(MD, 0);
+    if (!Const)
+      continue;
+
+    assert(MD->getNumOperands() > 1 && "Original type operand is expected");
+    auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
+    assert(CMeta && "ConstantAsMetadata operand is expected");
+
+    // Currently -1 indicates return value, greater values mean
+    // argument numbers.
+    int64_t Idx = Const->getSExtValue();
+    assert(Idx >= -1 && "Invalid mutated type index");
+    if (Idx == -1) {
+      RetTy = CMeta->getType();
+      continue;
+    }
+    assert(static_cast<uint64_t>(Idx) < PTys.size() &&
+           "Mutated type index exceeds the number of parameters");
+    PTys[Idx] = CMeta->getType();
+  }
+
+  return FunctionType::get(RetTy, PTys, FTy->isVarArg());
+}
+
+// Returns F's type with the aggregate argument/return types that were replaced
+// by i32 restored from the "spv.cloned_funcs" metadata (looked up by F's name).
+FunctionType *getOriginalFunctionType(const Function &F) {
+  return extractFunctionTypeFromMetadata(
+      F.getParent()->getNamedMetadata("spv.cloned_funcs"), F.getFunctionType(),
+      F.getName());
+}
+
+// Returns CB's function type with the aggregate argument/return types that
+// were replaced by i32 restored from the "spv.mutated_callsites" metadata.
+// The entry is looked up by CB's current value name, so it is only found while
+// CB keeps the name it had when the mutation was recorded.
+FunctionType *getOriginalFunctionType(const CallBase &CB) {
+  return extractFunctionTypeFromMetadata(
+      CB.getModule()->getNamedMetadata("spv.mutated_callsites"),
+      CB.getFunctionType(), CB.getName());
+}
+} // namespace SPIRV
 
 // The following functions are used to add these string literals as a series of
 // 32-bit integer operands with the correct format, and unpack them if necessary

diff  --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 45e211a1e5d2a..6cda16e3b4d54 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -159,6 +159,11 @@ struct FPFastMathDefaultInfoVector
   }
 };
 
+// Restore the original (pre-mutation) type of a function or an indirect call
+// site. Aggregate argument/return types are replaced by i32 during translation
+// to cope with aggregate flattening; the replaced types are recorded in the
+// "spv.cloned_funcs" (functions) and "spv.mutated_callsites" (call sites)
+// named metadata, keyed by the value's name. If no matching entry exists, the
+// current function type is returned unchanged.
+FunctionType *getOriginalFunctionType(const Function &F);
+FunctionType *getOriginalFunctionType(const CallBase &CB);
 } // namespace SPIRV
 
 // Add the given string as a series of integer operand, inserting null

diff  --git a/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll b/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll
new file mode 100644
index 0000000000000..ec3fd41f7de9e
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll
@@ -0,0 +1,107 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: OpCapability Kernel
+; CHECK-DAG: OpCapability FunctionPointersINTEL
+; CHECK-DAG: OpExtension "SPV_INTEL_function_pointers"
+; CHECK-DAG: OpName %[[#fArray:]] "array"
+; CHECK-DAG: OpName %[[#fStruct:]] "struct"
+; Declared callees are captured by name so that the function pointer constants
+; below do not depend on result-ID numbering.
+; CHECK-DAG: OpName %[[#f0Fn:]] "f0"
+; CHECK-DAG: OpName %[[#f1Fn:]] "f1"
+; CHECK-DAG: OpName %[[#f2Fn:]] "f2"
+
+; CHECK-DAG: %[[#Int8Ty:]] = OpTypeInt 8 0
+; CHECK: %[[#GlobalInt8PtrTy:]] = OpTypePointer CrossWorkgroup %[[#Int8Ty]]
+; CHECK: %[[#VoidTy:]] = OpTypeVoid
+; CHECK: %[[#TestFnTy:]] = OpTypeFunction %[[#VoidTy]] %[[#GlobalInt8PtrTy]]
+; CHECK: %[[#F16Ty:]] = OpTypeFloat 16
+; CHECK: %[[#t_halfTy:]] = OpTypeStruct %[[#F16Ty]]
+; CHECK: %[[#FnTy:]] = OpTypeFunction %[[#t_halfTy]] %[[#GlobalInt8PtrTy]] %[[#t_halfTy]]
+; CHECK: %[[#IntelFnPtrTy:]] = OpTypePointer CodeSectionINTEL %[[#FnTy]]
+; CHECK: %[[#Int8PtrTy:]] = OpTypePointer Function %[[#Int8Ty]]
+; CHECK: %[[#Int32Ty:]] = OpTypeInt 32 0
+; CHECK: %[[#I32Const3:]] = OpConstant %[[#Int32Ty]] 3
+; CHECK: %[[#FnArrTy:]] = OpTypeArray %[[#Int8PtrTy]] %[[#I32Const3]]
+; CHECK: %[[#GlobalFnArrPtrTy:]] = OpTypePointer CrossWorkgroup %[[#FnArrTy]]
+; CHECK: %[[#GlobalFnPtrTy:]] = OpTypePointer CrossWorkgroup %[[#FnTy]]
+; CHECK: %[[#FnPtrTy:]] = OpTypePointer Function %[[#FnTy]]
+; CHECK: %[[#StructWithPfnTy:]] = OpTypeStruct %[[#FnPtrTy]] %[[#FnPtrTy]] %[[#FnPtrTy]]
+; CHECK: %[[#ArrayOfPfnTy:]] = OpTypeArray %[[#FnPtrTy]] %[[#I32Const3]]
+; CHECK: %[[#Int64Ty:]] = OpTypeInt 64 0
+; CHECK: %[[#GlobalStructWithPfnPtrTy:]] = OpTypePointer CrossWorkgroup %[[#StructWithPfnTy]]
+; CHECK: %[[#GlobalArrOfPfnPtrTy:]] = OpTypePointer CrossWorkgroup %[[#ArrayOfPfnTy]]
+; CHECK: %[[#I64Const2:]] = OpConstant %[[#Int64Ty]] 2
+; CHECK: %[[#I64Const1:]] = OpConstant %[[#Int64Ty]] 1
+; CHECK: %[[#I64Const0:]] = OpConstantNull %[[#Int64Ty]]
+; CHECK: %[[#f0Pfn:]] = OpConstantFunctionPointerINTEL %[[#IntelFnPtrTy]] %[[#f0Fn]]
+; CHECK: %[[#f1Pfn:]] = OpConstantFunctionPointerINTEL %[[#IntelFnPtrTy]] %[[#f1Fn]]
+; CHECK: %[[#f2Pfn:]] = OpConstantFunctionPointerINTEL %[[#IntelFnPtrTy]] %[[#f2Fn]]
+; CHECK: %[[#f0Cast:]] = OpSpecConstantOp %[[#FnPtrTy]] Bitcast %[[#f0Pfn]]
+; CHECK: %[[#f1Cast:]] = OpSpecConstantOp %[[#FnPtrTy]] Bitcast %[[#f1Pfn]]
+; CHECK: %[[#f2Cast:]] = OpSpecConstantOp %[[#FnPtrTy]] Bitcast %[[#f2Pfn]]
+; CHECK: %[[#fnptrTy:]] = OpConstantComposite %[[#ArrayOfPfnTy]] %[[#f0Cast]] %[[#f1Cast]] %[[#f2Cast]]
+; CHECK: %[[#fnptr:]] = OpVariable %[[#GlobalArrOfPfnPtrTy]] CrossWorkgroup %[[#fnptrTy]]
+; CHECK: %[[#fnstructTy:]] = OpConstantComposite %[[#StructWithPfnTy]] %[[#f0Cast]] %[[#f1Cast]] %[[#f2Cast]]
+; CHECK: %[[#fnstruct:]] = OpVariable %[[#GlobalStructWithPfnPtrTy]] CrossWorkgroup %[[#fnstructTy]]
+; CHECK-DAG: %[[#GlobalInt8PtrPtrTy:]] = OpTypePointer CrossWorkgroup %[[#Int8PtrTy]]
+; CHECK: %[[#StructWithPtrTy:]] = OpTypeStruct %[[#Int8PtrTy]] %[[#Int8PtrTy]] %[[#Int8PtrTy]]
+; CHECK: %[[#GlobalStructWithPtrPtrTy:]] = OpTypePointer CrossWorkgroup %[[#StructWithPtrTy]]
+; CHECK: %[[#I32Const2:]] = OpConstant %[[#Int32Ty]] 2
+; CHECK: %[[#I32Const1:]] = OpConstant %[[#Int32Ty]] 1
+; CHECK: %[[#I32Const0:]] = OpConstantNull %[[#Int32Ty]]
+; CHECK: %[[#GlobalFnPtrPtrTy:]] = OpTypePointer CrossWorkgroup %[[#FnPtrTy]]
+%t_half = type { half }
+%struct.anon = type { ptr, ptr, ptr }
+
+; Callees with an aggregate (%t_half) argument and return type; in this file
+; they are only referenced from the constant initialisers of @fnptr and
+; @fnstruct.
+declare spir_func %t_half @f0(ptr addrspace(1) %a, %t_half %b)
+declare spir_func %t_half @f1(ptr addrspace(1) %a, %t_half %b)
+declare spir_func %t_half @f2(ptr addrspace(1) %a, %t_half %b)
+
+@fnptr = addrspace(1) constant [3 x ptr] [ptr @f0, ptr @f1, ptr @f2]
+@fnstruct = addrspace(1) constant %struct.anon { ptr @f0, ptr @f1, ptr @f2 }, align 8
+
+; @array: indirect calls through pointers derived from @fnptr, a constant
+; array of pointers to functions with an aggregate (%t_half) argument and
+; return type.
+; NOTE(review): a single-index GEP on [3 x ptr] steps by whole arrays (so %g
+; and %h point past the end of @fnptr), and each GEP result is called directly
+; rather than a pointer loaded from it; confirm this is intended, as it only
+; exercises the typing of the call target.
+; CHECK-DAG: %[[#fArray]] = OpFunction %[[#VoidTy]] None %[[#TestFnTy]]
+; CHECK-DAG: %[[#fnptrCast:]] = OpBitcast %[[#GlobalFnArrPtrTy]] %[[#fnptr]]
+; CHECK: %[[#f0GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalFnArrPtrTy]] %[[#fnptrCast]] %[[#I64Const0]]
+; CHECK: %[[#f0GEPCast:]] = OpBitcast %[[#GlobalFnPtrTy]] %[[#f0GEP]]
+; CHECK: %[[#f1GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalFnArrPtrTy]] %[[#fnptrCast]] %[[#I64Const1]]
+; CHECK: %[[#f1GEPCast:]] = OpBitcast %[[#GlobalFnPtrTy]] %[[#f1GEP]]
+; CHECK: %[[#f2GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalFnArrPtrTy]] %[[#fnptrCast]] %[[#I64Const2]]
+; CHECK: %[[#f2GEPCast:]] = OpBitcast %[[#GlobalFnPtrTy]] %[[#f2GEP]]
+; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f0GEPCast]]
+; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f1GEPCast]]
+; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f2GEPCast]]
+define spir_func void @array(ptr addrspace(1) %p) {
+entry:
+  %f = getelementptr inbounds [3 x ptr], ptr addrspace(1) @fnptr, i64 0
+  %g = getelementptr inbounds [3 x ptr], ptr addrspace(1) @fnptr, i64 1
+  %h = getelementptr inbounds [3 x ptr], ptr addrspace(1) @fnptr, i64 2
+  %0 = call spir_func addrspace(1) %t_half %f(ptr addrspace(1) %p, %t_half poison)
+  %1 = call spir_func addrspace(1) %t_half %g(ptr addrspace(1) %p, %t_half %0)
+  %2 = call spir_func addrspace(1) %t_half %h(ptr addrspace(1) %p, %t_half %1)
+
+  ret void
+}
+
+; @struct: indirect calls through function pointers loaded from the fields of
+; @fnstruct, a constant struct of pointers to functions with an aggregate
+; (%t_half) argument and return type.
+; CHECK-DAG: %[[#fStruct]] = OpFunction %[[#VoidTy]] None %[[#TestFnTy]]
+; CHECK-DAG: %[[#fnStructCast0:]] = OpBitcast %[[#GlobalInt8PtrPtrTy]] %[[#fnstruct]]
+; CHECK: %[[#fnStructCast1:]] = OpBitcast %[[#GlobalFnPtrPtrTy]] %[[#fnStructCast0]]
+; CHECK: %[[#f0Load:]] = OpLoad %[[#FnPtrTy]] %[[#fnStructCast1]]
+; CHECK: %[[#fnStructCast2:]] = OpBitcast %[[#GlobalStructWithPtrPtrTy]] %[[#fnstruct]]
+; CHECK: %[[#f1GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalInt8PtrPtrTy]] %[[#fnStructCast2]] %[[#I32Const0]] %[[#I32Const1]]
+; CHECK: %[[#f1GEPCast:]] = OpBitcast %[[#GlobalFnPtrPtrTy]] %[[#f1GEP]]
+; CHECK: %[[#f1Load:]] = OpLoad %[[#FnPtrTy]] %[[#f1GEPCast]]
+; CHECK: %[[#f2GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalInt8PtrPtrTy]] %[[#fnStructCast2]] %[[#I32Const0]] %[[#I32Const2]]
+; CHECK: %[[#f2GEPCast:]] = OpBitcast %[[#GlobalFnPtrPtrTy]] %[[#f2GEP]]
+; CHECK: %[[#f2Load:]] = OpLoad %[[#FnPtrTy]] %[[#f2GEPCast]]
+; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f0Load]]
+; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f1Load]]
+; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f2Load]]
+define spir_func void @struct(ptr addrspace(1) %p) {
+entry:
+  %f = load ptr, ptr addrspace(1) @fnstruct
+  %g = load ptr, ptr addrspace(1) getelementptr inbounds (%struct.anon, ptr addrspace(1) @fnstruct, i32 0, i32 1)
+  %h = load ptr, ptr addrspace(1) getelementptr inbounds (%struct.anon, ptr addrspace(1) @fnstruct, i32 0, i32 2)
+  %0 = call spir_func noundef %t_half %f(ptr addrspace(1) %p, %t_half poison)
+  %1 = call spir_func noundef %t_half %g(ptr addrspace(1) %p, %t_half %0)
+  %2 = call spir_func noundef %t_half %h(ptr addrspace(1) %p, %t_half %1)
+
+  ret void
+}


        


More information about the llvm-commits mailing list