[llvm] [SPIRV] Add support for pointers to functions with aggregate args/returns as global variables / constant initialisers (PR #169595)

Alex Voicu via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 26 07:38:18 PST 2025


https://github.com/AlexVlx updated https://github.com/llvm/llvm-project/pull/169595

>From 66977900b4efdc5d0ce646e3d5095d573b359b58 Mon Sep 17 00:00:00 2001
From: Alex Voicu <alexandru.voicu at amd.com>
Date: Wed, 26 Nov 2025 01:53:25 +0000
Subject: [PATCH 1/7] Add support for function pointers as globals / global
 aggregate elements. Start reworking indirect function calls.

---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   |  83 +++++---------
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp |  25 +++-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |   2 +
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   |   8 +-
 .../Target/SPIRV/SPIRVPrepareFunctions.cpp    | 106 ++++++++++++++---
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          |  64 +++++++++++
 llvm/lib/Target/SPIRV/SPIRVUtils.h            |   5 +
 .../fun-with-aggregate-arg-in-const-init.ll   | 108 ++++++++++++++++++
 8 files changed, 324 insertions(+), 77 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index dd57b74d79a5e..6b5c602d4ac93 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -131,46 +131,6 @@ fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
   return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
 }
 
-// This code restores function args/retvalue types for composite cases
-// because the final types should still be aggregate whereas they're i32
-// during the translation to cope with aggregate flattening etc.
-static FunctionType *getOriginalFunctionType(const Function &F) {
-  auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
-  if (NamedMD == nullptr)
-    return F.getFunctionType();
-
-  Type *RetTy = F.getFunctionType()->getReturnType();
-  SmallVector<Type *, 4> ArgTypes;
-  for (auto &Arg : F.args())
-    ArgTypes.push_back(Arg.getType());
-
-  auto ThisFuncMDIt =
-      std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
-        return isa<MDString>(N->getOperand(0)) &&
-               cast<MDString>(N->getOperand(0))->getString() == F.getName();
-      });
-  if (ThisFuncMDIt != NamedMD->op_end()) {
-    auto *ThisFuncMD = *ThisFuncMDIt;
-    for (unsigned I = 1; I != ThisFuncMD->getNumOperands(); ++I) {
-      MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(I));
-      assert(MD && "MDNode operand is expected");
-      ConstantInt *Const = getConstInt(MD, 0);
-      if (Const) {
-        auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
-        assert(CMeta && "ConstantAsMetadata operand is expected");
-        assert(Const->getSExtValue() >= -1);
-        // Currently -1 indicates return value, greater values mean
-        // argument numbers.
-        if (Const->getSExtValue() == -1)
-          RetTy = CMeta->getType();
-        else
-          ArgTypes[Const->getSExtValue()] = CMeta->getType();
-      }
-    }
-  }
-
-  return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
-}
 
 static SPIRV::AccessQualifier::AccessQualifier
 getArgAccessQual(const Function &F, unsigned ArgIdx) {
@@ -204,7 +164,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
   SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
       getArgAccessQual(F, ArgIdx);
 
-  Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
+  Type *OriginalArgType =
+      SPIRV::getOriginalFunctionType(F)->getParamType(ArgIdx);
 
   // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
   // be legally reassigned later).
@@ -421,7 +382,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   auto MRI = MIRBuilder.getMRI();
   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
   MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass);
-  FunctionType *FTy = getOriginalFunctionType(F);
+  FunctionType *FTy = SPIRV::getOriginalFunctionType(F);
   Type *FRetTy = FTy->getReturnType();
   if (isUntypedPointerTy(FRetTy)) {
     if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
@@ -506,10 +467,15 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
 // - add a topological sort of IndirectCalls to ensure the best types knowledge
 // - we may need to fix function formal parameter types if they are opaque
 //   pointers used as function pointers in these indirect calls
+// - defaulting to StorageClass::Function in the absence of the
+//   SPV_INTEL_function_pointers extension seems wrong, as that might not be
+//   able to hold a full width pointer to function, and it also does not model
+//   the semantics of a pointer to function in a generic fashion.
 void SPIRVCallLowering::produceIndirectPtrTypes(
     MachineIRBuilder &MIRBuilder) const {
   // Create indirect call data types if any
   MachineFunction &MF = MIRBuilder.getMF();
+  const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
   for (auto const &IC : IndirectCalls) {
     SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(
         IC.RetTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
@@ -527,8 +493,11 @@ void SPIRVCallLowering::produceIndirectPtrTypes(
     SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
         FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
     // SPIR-V pointer to function type:
-    SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
-        SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
+    auto SC = ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)
+        ? SPIRV::StorageClass::CodeSectionINTEL
+        : SPIRV::StorageClass::Function;
+    SPIRVType *IndirectFuncPtrTy =
+        GR->getOrCreateSPIRVPointerType(SpirvFuncTy, MIRBuilder, SC);
     // Correct the Callee type
     GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
   }
@@ -556,12 +525,12 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     // TODO: support constexpr casts and indirect calls.
     if (CF == nullptr)
       return false;
-    if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
-      OrigRetTy = FTy->getReturnType();
-      if (isUntypedPointerTy(OrigRetTy)) {
-        if (auto *DerivedRetTy = GR->findReturnType(CF))
-          OrigRetTy = DerivedRetTy;
-      }
+
+    FunctionType *FTy = SPIRV::getOriginalFunctionType(*CF);
+    OrigRetTy = FTy->getReturnType();
+    if (isUntypedPointerTy(OrigRetTy)) {
+      if (auto *DerivedRetTy = GR->findReturnType(CF))
+        OrigRetTy = DerivedRetTy;
     }
   }
 
@@ -683,11 +652,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     if (CalleeReg.isValid()) {
       SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
       IndirectCall.Callee = CalleeReg;
-      IndirectCall.RetTy = OrigRetTy;
-      for (const auto &Arg : Info.OrigArgs) {
-        assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
-        IndirectCall.ArgTys.push_back(Arg.Ty);
-        IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
+      FunctionType *FTy = SPIRV::getOriginalFunctionType(*Info.CB);
+      IndirectCall.RetTy = OrigRetTy = FTy->getReturnType();
+      assert(FTy->getNumParams() == Info.OrigArgs.size() &&
+             "Function types mismatch");
+      for (unsigned I = 0; I != Info.OrigArgs.size(); ++I) {
+        assert(Info.OrigArgs[I].Regs.size() == 1 &&
+               "Call arg has multiple VRegs");
+        IndirectCall.ArgTys.push_back(FTy->getParamType(I));
+        IndirectCall.ArgRegs.push_back(Info.OrigArgs[I].Regs[0]);
       }
       IndirectCalls.push_back(IndirectCall);
     }
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 8e14fb03127fc..b37e6d8ce4ea3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -358,7 +358,11 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
 
 static void emitAssignName(Instruction *I, IRBuilder<> &B) {
   if (!I->hasName() || I->getType()->isAggregateType() ||
-      expectIgnoredInIRTranslation(I))
+      expectIgnoredInIRTranslation(I) ||
+      // TODO: this is a temporary workaround meant to prevent inserting
+      //       internal noise into the generated binary; remove once we rework
+      //       the entire aggregate removal machinery.
+      I->getName().starts_with("spv.mutated_callsite"))
     return;
   reportFatalOnTokenType(I);
   setInsertPointAfterDef(B, I);
@@ -759,10 +763,15 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
     if (Type *ElemTy = getPointeeType(KnownTy))
       maybeAssignPtrType(Ty, I, ElemTy, UnknownElemTypeI8);
   } else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
-    Ty = deduceElementTypeByValueDeep(
-        Ref->getValueType(),
-        Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited,
-        UnknownElemTypeI8);
+    if (auto *Fn = dyn_cast<Function>(Ref)) {
+      Ty = SPIRV::getOriginalFunctionType(*Fn);
+      GR->addDeducedElementType(I, Ty);
+    } else {
+      Ty = deduceElementTypeByValueDeep(
+          Ref->getValueType(),
+          Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited,
+          UnknownElemTypeI8);
+    }
   } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
     Type *RefTy = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
                                           UnknownElemTypeI8);
@@ -1062,9 +1071,10 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
   if (!Op || !isPointerTy(Op->getType()))
     return;
   Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
-  FunctionType *FTy = CI->getFunctionType();
+  FunctionType *FTy = SPIRV::getOriginalFunctionType(*CI);
   bool IsNewFTy = false, IsIncomplete = false;
   SmallVector<Type *, 4> ArgTys;
+  unsigned ParmIdx = 0;
   for (Value *Arg : CI->args()) {
     Type *ArgTy = Arg->getType();
     if (ArgTy->isPointerTy()) {
@@ -1076,8 +1086,11 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
       } else {
         IsIncomplete = true;
       }
+    } else {
+      ArgTy = FTy->getFunctionParamType(ParmIdx);
     }
     ArgTys.push_back(ArgTy);
+    ++ParmIdx;
   }
   Type *RetTy = FTy->getReturnType();
   if (CI->getType()->isPointerTy()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 09c77f0cfd4f5..16f3260bf4ffc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -214,6 +214,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
       if (Value *GlobalElem =
               Global->getNumOperands() > 0 ? Global->getOperand(0) : nullptr)
         ElementTy = findDeducedCompositeType(GlobalElem);
+      else if (const Function *Fn = dyn_cast<Function>(Global))
+        ElementTy = SPIRV::getOriginalFunctionType(*Fn);
     }
     return ElementTy ? ElementTy : Global->getValueType();
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 0f4b3d59b904a..b273599596a35 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -257,9 +257,11 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
       Register Def = MI.getOperand(0).getReg();
       Register Source = MI.getOperand(2).getReg();
       Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
-      SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
-          ElemTy, MI,
-          addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
+      auto SC = isa<FunctionType>(ElemTy)
+          ? SPIRV::StorageClass::CodeSectionINTEL
+          : addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST);
+      SPIRVType *AssignedPtrType =
+          GR->getOrCreateSPIRVPointerType(ElemTy, MI, SC);
 
       // If the ptrcast would be redundant, replace all uses with the source
       // register.
diff --git a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
index be88f334d2171..8fd261cfff25b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
@@ -26,6 +26,7 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/IntrinsicLowering.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
@@ -41,6 +42,7 @@ class SPIRVPrepareFunctions : public ModulePass {
   const SPIRVTargetMachine &TM;
   bool substituteIntrinsicCalls(Function *F);
   Function *removeAggregateTypesFromSignature(Function *F);
+  bool removeAggregateTypesFromCalls(Function *F);
 
 public:
   static char ID;
@@ -469,6 +471,23 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
   return Changed;
 }
 
+static void addFunctionTypeMutation(
+    NamedMDNode *NMD,
+    SmallVector<std::pair<int, Type *>> ChangedTys, StringRef Name) {
+
+    LLVMContext &Ctx = NMD->getParent()->getContext();
+    Type *I32Ty = IntegerType::getInt32Ty(Ctx);
+
+    SmallVector<Metadata *> MDArgs;
+    MDArgs.push_back(MDString::get(Ctx, Name));
+    transform(ChangedTys, std::back_inserter(MDArgs), [=, &Ctx](auto &&CTy) {
+      return MDNode::get(
+          Ctx,
+          {ConstantAsMetadata::get(ConstantInt::get(I32Ty, CTy.first, true)),
+           ValueAsMetadata::get(Constant::getNullValue(CTy.second))});
+    });
+    NMD->addOperand(MDNode::get(Ctx, MDArgs));
+}
 // Returns F if aggregate argument/return types are not present or cloned F
 // function with the types replaced by i32 types. The change in types is
 // noted in 'spv.cloned_funcs' metadata for later restoration.
@@ -503,7 +522,8 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
   FunctionType *NewFTy =
       FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
   Function *NewF =
-      Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
+      Function::Create(NewFTy, F->getLinkage(), F->getAddressSpace(),
+                       F->getName(), F->getParent());
 
   ValueToValueMapTy VMap;
   auto NewFArgIt = NewF->arg_begin();
@@ -518,22 +538,17 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
                     Returns);
   NewF->takeName(F);
 
-  NamedMDNode *FuncMD =
-      F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
-  SmallVector<Metadata *, 2> MDArgs;
-  MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
-  for (auto &ChangedTyP : ChangedTypes)
-    MDArgs.push_back(MDNode::get(
-        B.getContext(),
-        {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
-         ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
-  MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
-  FuncMD->addOperand(ThisFuncMD);
+  addFunctionTypeMutation(
+      NewF->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs"),
+      std::move(ChangedTypes), NewF->getName());
 
   for (auto *U : make_early_inc_range(F->users())) {
     if (auto *CI = dyn_cast<CallInst>(U))
       CI->mutateFunctionType(NewF->getFunctionType());
-    U->replaceUsesOfWith(F, NewF);
+    if (auto *C = dyn_cast<Constant>(U))
+      C->handleOperandChange(F, NewF);
+    else
+      U->replaceUsesOfWith(F, NewF);
   }
 
   // register the mutation
@@ -543,11 +558,76 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
   return NewF;
 }
 
+// Mutates indirect callsites iff aggregate argument/return types are present,
+// replacing those types with i32. The original types are recorded in
+// 'spv.mutated_callsites' metadata, keyed by callsite name, for restoration.
+bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) {
+  if (F->isDeclaration() || F->isIntrinsic())
+    return false;
+  // Collect indirect calls whose return or argument types are aggregates.
+  SmallVector<std::pair<CallBase *, FunctionType *>> Calls;
+  for (auto &&BB : *F) {
+    for (auto &&I : BB) {
+      if (auto *CB = dyn_cast<CallBase>(&I)) {
+        if (!CB->getCalledOperand() || CB->getCalledFunction())
+          continue;
+        if (CB->getType()->isAggregateType() ||
+            any_of(CB->args(),
+                  [](auto &&Arg) { return Arg->getType()->isAggregateType(); }))
+          Calls.emplace_back(CB, nullptr);
+      }
+    }
+  }
+
+  if (Calls.empty())
+    return false;
+
+  IRBuilder<> B(F->getContext());
+  // First pass: compute the i32-substituted type and record the originals.
+  for (auto &&[CB, NewFnTy] : Calls) {
+    SmallVector<std::pair<int, Type *>> ChangedTypes;
+    SmallVector<Type *> NewArgTypes;
+
+    if (CB->getType()->isAggregateType())
+      ChangedTypes.emplace_back(-1, CB->getType());
+
+    Type *RetTy = ChangedTypes.empty() ? CB->getType() : B.getInt32Ty();
+    for (auto &&Arg : CB->args()) {
+      if (Arg->getType()->isAggregateType()) {
+        NewArgTypes.push_back(B.getInt32Ty());
+        ChangedTypes.emplace_back(Arg.getOperandNo(), Arg->getType());
+      } else {
+        NewArgTypes.push_back(Arg->getType());
+      }
+    }
+    NewFnTy = FunctionType::get(RetTy, NewArgTypes,
+                                CB->getFunctionType()->isVarArg());
+    // NOTE(review): setName() asserts on void-typed calls; confirm handling.
+    if (!CB->hasName())
+      CB->setName("spv.mutated_callsite");
+    // NOTE(review): names are per-function but the metadata is module-wide.
+    addFunctionTypeMutation(
+      F->getParent()->getOrInsertNamedMetadata("spv.mutated_callsites"),
+      std::move(ChangedTypes),
+      CB->getName());
+  }
+  // Second pass: retype the calls, noting changed return types as mutated.
+  for (auto &&[CB, NewFTy] : Calls) {
+    if (NewFTy->getReturnType() != CB->getType())
+      TM.getSubtarget<SPIRVSubtarget>(*F).getSPIRVGlobalRegistry()->addMutated(
+          CB, CB->getType());
+    CB->mutateFunctionType(NewFTy);
+  }
+
+  return true;
+}
+
 bool SPIRVPrepareFunctions::runOnModule(Module &M) {
   bool Changed = false;
   for (Function &F : M) {
     Changed |= substituteIntrinsicCalls(&F);
     Changed |= sortBlocks(F);
+    Changed |= removeAggregateTypesFromCalls(&F);
   }
 
   std::vector<Function *> FuncsWorklist;
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 8f2fc01da476f..5b6a36de6c526 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -28,6 +28,70 @@
 #include <vector>
 
 namespace llvm {
+namespace SPIRV {
+// This code restores function args/retvalue types for composite cases
+// because the final types should still be aggregate whereas they're i32
+// during the translation to cope with aggregate flattening etc.
+// TODO: should these just return nullptr when there's no metadata?
+static FunctionType *extractFunctionTypeFromMetadata(NamedMDNode *NMD,
+                                                     FunctionType *FTy,
+                                                     StringRef Name) {
+  if (!NMD)
+    return FTy;
+
+  constexpr auto getConstInt = [](MDNode *MD, unsigned OpId) -> ConstantInt * {
+    if (MD->getNumOperands() <= OpId)
+      return nullptr;
+    if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(OpId)))
+      return dyn_cast<ConstantInt>(CMeta->getValue());
+    return nullptr;
+  };
+
+  // Each entry is { !"name", !{i32 Idx, <null of original type>}, ... }.
+  auto It = find_if(NMD->operands(), [Name](MDNode *N) {
+    Metadata *Key = N->getNumOperands() ? N->getOperand(0).get() : nullptr;
+    auto *MDS = dyn_cast_or_null<MDString>(Key);
+    return MDS && MDS->getString() == Name;
+  });
+  if (It == NMD->op_end())
+    return FTy;
+
+  Type *RetTy = FTy->getReturnType();
+  SmallVector<Type *, 4> PTys(FTy->params());
+
+  for (unsigned I = 1; I != (*It)->getNumOperands(); ++I) {
+    MDNode *MD = dyn_cast<MDNode>((*It)->getOperand(I));
+    assert(MD && "MDNode operand is expected");
+
+    if (auto *Const = getConstInt(MD, 0)) {
+      auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
+      assert(CMeta && "ConstantAsMetadata operand is expected");
+      int64_t Idx = Const->getSExtValue(); // -1 = return, else param index.
+      assert(Idx >= -1 && Idx < static_cast<int64_t>(PTys.size()) &&
+             "Mutation index out of range for the function type");
+      if (Idx == -1)
+        RetTy = CMeta->getType();
+      else
+        PTys[Idx] = CMeta->getType();
+    }
+  }
+
+  return FunctionType::get(RetTy, PTys, FTy->isVarArg());
+}
+
+FunctionType *getOriginalFunctionType(const Function &F) {
+  return extractFunctionTypeFromMetadata(
+      F.getParent()->getNamedMetadata("spv.cloned_funcs"),
+      F.getFunctionType(), F.getName());
+}
+
+FunctionType *getOriginalFunctionType(const CallBase &CB) {
+  return extractFunctionTypeFromMetadata(
+      CB.getParent()
+        ->getParent()->getParent()->getNamedMetadata("spv.mutated_callsites"),
+      CB.getFunctionType(), CB.getName());
+}
+} // Namespace SPIRV
 
 // The following functions are used to add these string literals as a series of
 // 32-bit integer operands with the correct format, and unpack them if necessary
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 99d9d403ea70c..3da77f1d6c1ec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -159,6 +159,11 @@ struct FPFastMathDefaultInfoVector
   }
 };
 
+// This code restores function args/retvalue types for composite cases
+// because the final types should still be aggregate whereas they're i32
+// during the translation to cope with aggregate flattening etc.
+FunctionType *getOriginalFunctionType(const Function &F);
+FunctionType *getOriginalFunctionType(const CallBase &CB);
 } // namespace SPIRV
 
 // Add the given string as a series of integer operand, inserting null
diff --git a/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll b/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll
new file mode 100644
index 0000000000000..97a0cdf56dc5c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll
@@ -0,0 +1,108 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --extra_scrub --tool spirv-val --include-generated-funcs --version 6
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: OpCapability Kernel
+; CHECK-DAG: OpCapability FunctionPointersINTEL
+; CHECK-DAG: OpExtension "SPV_INTEL_function_pointers"
+; CHECK-DAG: OpName %[[#fArray:]] "array"
+; CHECK-DAG: OpName %[[#fStruct:]] "struct"
+
+; CHECK-DAG: %[[#Int8Ty:]] = OpTypeInt 8 0
+; CHECK: %[[#GlobalInt8PtrTy:]] = OpTypePointer CrossWorkgroup %[[#Int8Ty]]
+; CHECK: %[[#VoidTy:]] = OpTypeVoid
+; CHECK: %[[#TestFnTy:]] = OpTypeFunction %[[#VoidTy]] %[[#GlobalInt8PtrTy]]
+; CHECK: %[[#F16Ty:]] = OpTypeFloat 16
+; CHECK: %[[#t_halfTy:]] = OpTypeStruct %[[#F16Ty]]
+; CHECK: %[[#FnTy:]] = OpTypeFunction %[[#t_halfTy]] %[[#GlobalInt8PtrTy]] %[[#t_halfTy]]
+; CHECK: %[[#IntelFnPtrTy:]] = OpTypePointer CodeSectionINTEL %[[#FnTy]]
+; CHECK: %[[#Int8PtrTy:]] = OpTypePointer Function %[[#Int8Ty]]
+; CHECK: %[[#Int32Ty:]] = OpTypeInt 32 0
+; CHECK: %[[#I32Const3:]] = OpConstant %[[#Int32Ty]] 3
+; CHECK: %[[#FnArrTy:]] = OpTypeArray %[[#Int8PtrTy]] %[[#I32Const3]]
+; CHECK: %[[#GlobalFnArrPtrTy:]] = OpTypePointer CrossWorkgroup %[[#FnArrTy]]
+; CHECK: %[[#GlobalFnPtrTy:]] = OpTypePointer CrossWorkgroup %[[#FnTy]]
+; CHECK: %[[#FnPtrTy:]] = OpTypePointer Function %[[#FnTy]]
+; CHECK: %[[#StructWithPfnTy:]] = OpTypeStruct %[[#FnPtrTy]] %[[#FnPtrTy]] %[[#FnPtrTy]]
+; CHECK: %[[#ArrayOfPfnTy:]] = OpTypeArray %[[#FnPtrTy]] %[[#I32Const3]]
+; CHECK: %[[#Int64Ty:]] = OpTypeInt 64 0
+; CHECK: %[[#GlobalStructWithPfnPtrTy:]] = OpTypePointer CrossWorkgroup %[[#StructWithPfnTy]]
+; CHECK: %[[#GlobalArrOfPfnPtrTy:]] = OpTypePointer CrossWorkgroup %[[#ArrayOfPfnTy]]
+; CHECK: %[[#I64Const2:]] = OpConstant %[[#Int64Ty]] 2
+; CHECK: %[[#I64Const1:]] = OpConstant %[[#Int64Ty]] 1
+; CHECK: %[[#I64Const0:]] = OpConstantNull %[[#Int64Ty]]
+; CHECK: %[[#f0Pfn:]] = OpConstantFunctionPointerINTEL %[[#IntelFnPtrTy]] %[[#]]
+; CHECK: %[[#f1Pfn:]] = OpConstantFunctionPointerINTEL %[[#IntelFnPtrTy]] %[[#]]
+; CHECK: %[[#f2Pfn:]] = OpConstantFunctionPointerINTEL %[[#IntelFnPtrTy]] %[[#]]
+; CHECK: %[[#f0Cast:]] = OpSpecConstantOp %[[#FnPtrTy]] Bitcast %[[#f0Pfn]]
+; CHECK: %[[#f1Cast:]] = OpSpecConstantOp %[[#FnPtrTy]] Bitcast %[[#f1Pfn]]
+; CHECK: %[[#f2Cast:]] = OpSpecConstantOp %[[#FnPtrTy]] Bitcast %[[#f2Pfn]]
+; CHECK: %[[#fnptrTy:]] = OpConstantComposite %[[#ArrayOfPfnTy]] %[[#f0Cast]] %[[#f1Cast]] %[[#f2Cast]]
+; CHECK: %[[#fnptr:]] = OpVariable %[[#GlobalArrOfPfnPtrTy]] CrossWorkgroup %[[#fnptrTy]]
+; CHECK: %[[#fnstructTy:]] = OpConstantComposite %[[#StructWithPfnTy]] %[[#f0Cast]] %[[#f1Cast]] %[[#f2Cast]]
+; CHECK: %[[#fnstruct:]] = OpVariable %[[#GlobalStructWithPfnPtrTy]] CrossWorkgroup %[[#fnstructTy]]
+; CHECK-DAG: %[[#GlobalInt8PtrPtrTy:]] = OpTypePointer CrossWorkgroup %[[#Int8PtrTy]]
+; CHECK: %[[#StructWithPtrTy:]] = OpTypeStruct %[[#Int8PtrTy]] %[[#Int8PtrTy]] %[[#Int8PtrTy]]
+; CHECK: %[[#GlobalStructWithPtrPtrTy:]] = OpTypePointer CrossWorkgroup %[[#StructWithPtrTy]]
+; CHECK: %[[#I32Const2:]] = OpConstant %[[#Int32Ty]] 2
+; CHECK: %[[#I32Const1:]] = OpConstant %[[#Int32Ty]] 1
+; CHECK: %[[#I32Const0:]] = OpConstantNull %[[#Int32Ty]]
+; CHECK: %[[#GlobalFnPtrPtrTy:]] = OpTypePointer CrossWorkgroup %[[#FnPtrTy]]
+%t_half = type { half }
+%struct.anon = type { ptr, ptr, ptr }
+
+declare spir_func %t_half @f0(ptr addrspace(1) %a, %t_half %b)
+declare spir_func %t_half @f1(ptr addrspace(1) %a, %t_half %b)
+declare spir_func %t_half @f2(ptr addrspace(1) %a, %t_half %b)
+
+@fnptr = addrspace(1) constant [3 x ptr] [ptr @f0, ptr @f1, ptr @f2]
+@fnstruct = addrspace(1) constant %struct.anon { ptr @f0, ptr @f1, ptr @f2 }, align 8
+
+; CHECK-DAG: %[[#fArray]] = OpFunction %[[#VoidTy]] None %[[#TestFnTy]]
+; CHECK-DAG: %[[#fnptrCast:]] = OpBitcast %[[#GlobalFnArrPtrTy]] %[[#fnptr]]
+; CHECK: %[[#f0GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalFnArrPtrTy]] %[[#fnptrCast]] %[[#I64Const0]]
+; CHECK: %[[#f0GEPCast:]] = OpBitcast %[[#GlobalFnPtrTy]] %[[#f0GEP]]
+; CHECK: %[[#f1GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalFnArrPtrTy]] %[[#fnptrCast]] %[[#I64Const1]]
+; CHECK: %[[#f1GEPCast:]] = OpBitcast %[[#GlobalFnPtrTy]] %[[#f1GEP]]
+; CHECK: %[[#f2GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalFnArrPtrTy]] %[[#fnptrCast]] %[[#I64Const2]]
+; CHECK: %[[#f2GEPCast:]] = OpBitcast %[[#GlobalFnPtrTy]] %[[#f2GEP]]
+; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f0GEPCast]]
+; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f1GEPCast]]
+; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f2GEPCast]]
+define spir_func void @array(ptr addrspace(1) %p) {
+entry:
+  %f = getelementptr inbounds [3 x ptr], ptr addrspace(1) @fnptr, i64 0
+  %g = getelementptr inbounds [3 x ptr], ptr addrspace(1) @fnptr, i64 1
+  %h = getelementptr inbounds [3 x ptr], ptr addrspace(1) @fnptr, i64 2
+  %0 = call spir_func addrspace(1) %t_half %f(ptr addrspace(1) %p, %t_half undef)
+  %1 = call spir_func addrspace(1) %t_half %g(ptr addrspace(1) %p, %t_half %0)
+  %2 = call spir_func addrspace(1) %t_half %h(ptr addrspace(1) %p, %t_half %1)
+  ; NOTE(review): callees are the GEP addresses themselves; @struct loads them.
+  ret void
+}
+
+; CHECK-DAG: %[[#fStruct]] = OpFunction %[[#VoidTy]] None %[[#TestFnTy]]
+; CHECK-DAG: %[[#fnStructCast0:]] = OpBitcast %[[#GlobalInt8PtrPtrTy]] %[[#fnstruct]]
+; CHECK: %[[#fnStructCast1:]] = OpBitcast %[[#GlobalFnPtrPtrTy]] %[[#fnStructCast0]]
+; CHECK: %[[#f0Load:]] = OpLoad %[[#FnPtrTy]] %[[#fnStructCast1]]
+; CHECK: %[[#fnStructCast2:]] = OpBitcast %[[#GlobalStructWithPtrPtrTy]] %[[#fnstruct]]
+; CHECK: %[[#f1GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalInt8PtrPtrTy]] %[[#fnStructCast2]] %[[#I32Const0]] %[[#I32Const1]]
+; CHECK: %[[#f1GEPCast:]] = OpBitcast %[[#GlobalFnPtrPtrTy]] %[[#f1GEP]]
+; CHECK: %[[#f1Load:]] = OpLoad %[[#FnPtrTy]] %[[#f1GEPCast]]
+; CHECK: %[[#f2GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalInt8PtrPtrTy]] %[[#fnStructCast2]] %[[#I32Const0]] %[[#I32Const2]]
+; CHECK: %[[#f2GEPCast:]] = OpBitcast %[[#GlobalFnPtrPtrTy]] %[[#f2GEP]]
+; CHECK: %[[#f2Load:]] = OpLoad %[[#FnPtrTy]] %[[#f2GEPCast]]
+; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f0Load]]
+; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f1Load]]
+; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f2Load]]
+define spir_func void @struct(ptr addrspace(1) %p) {
+entry:
+  %f = load ptr, ptr addrspace(1) @fnstruct
+  %g = load ptr, ptr addrspace(1) getelementptr inbounds (%struct.anon, ptr addrspace(1) @fnstruct, i32 0, i32 1)
+  %h = load ptr, ptr addrspace(1) getelementptr inbounds (%struct.anon, ptr addrspace(1) @fnstruct, i32 0, i32 2)
+  %0 = call spir_func noundef %t_half %f(ptr addrspace(1) %p, %t_half undef)
+  %1 = call spir_func noundef %t_half %g(ptr addrspace(1) %p, %t_half %0)
+  %2 = call spir_func noundef %t_half %h(ptr addrspace(1) %p, %t_half %1)
+
+  ret void
+}

>From c1e9fc1916f561994e66e97c449f9e55f63e14fe Mon Sep 17 00:00:00 2001
From: Alex Voicu <alexandru.voicu at amd.com>
Date: Wed, 26 Nov 2025 14:59:17 +0000
Subject: [PATCH 2/7] Fix formatting.

---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   |  5 +--
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   |  7 ++--
 .../Target/SPIRV/SPIRVPrepareFunctions.cpp    | 42 +++++++++----------
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          |  8 ++--
 4 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 6b5c602d4ac93..b61b9e5d4fc7f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -131,7 +131,6 @@ fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
   return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
 }
 
-
 static SPIRV::AccessQualifier::AccessQualifier
 getArgAccessQual(const Function &F, unsigned ArgIdx) {
   if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
@@ -494,8 +493,8 @@ void SPIRVCallLowering::produceIndirectPtrTypes(
         FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
     // SPIR-V pointer to function type:
     auto SC = ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)
-        ? SPIRV::StorageClass::CodeSectionINTEL
-        : SPIRV::StorageClass::Function;
+                  ? SPIRV::StorageClass::CodeSectionINTEL
+                  : SPIRV::StorageClass::Function;
     SPIRVType *IndirectFuncPtrTy =
         GR->getOrCreateSPIRVPointerType(SpirvFuncTy, MIRBuilder, SC);
     // Correct the Callee type
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index b273599596a35..f3c886aafc131 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -257,9 +257,10 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
       Register Def = MI.getOperand(0).getReg();
       Register Source = MI.getOperand(2).getReg();
       Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
-      auto SC = isa<FunctionType>(ElemTy)
-          ? SPIRV::StorageClass::CodeSectionINTEL
-          : addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST);
+      auto SC =
+          isa<FunctionType>(ElemTy)
+              ? SPIRV::StorageClass::CodeSectionINTEL
+              : addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST);
       SPIRVType *AssignedPtrType =
           GR->getOrCreateSPIRVPointerType(ElemTy, MI, SC);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
index 8fd261cfff25b..405269337e911 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
@@ -471,22 +471,22 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
   return Changed;
 }
 
-static void addFunctionTypeMutation(
-    NamedMDNode *NMD,
-    SmallVector<std::pair<int, Type *>> ChangedTys, StringRef Name) {
-
-    LLVMContext &Ctx = NMD->getParent()->getContext();
-    Type *I32Ty = IntegerType::getInt32Ty(Ctx);
-
-    SmallVector<Metadata *> MDArgs;
-    MDArgs.push_back(MDString::get(Ctx, Name));
-    transform(ChangedTys, std::back_inserter(MDArgs), [=, &Ctx](auto &&CTy) {
-      return MDNode::get(
-          Ctx,
-          {ConstantAsMetadata::get(ConstantInt::get(I32Ty, CTy.first, true)),
-           ValueAsMetadata::get(Constant::getNullValue(CTy.second))});
-    });
-    NMD->addOperand(MDNode::get(Ctx, MDArgs));
+static void
+addFunctionTypeMutation(NamedMDNode *NMD,
+                        SmallVector<std::pair<int, Type *>> ChangedTys,
+                        StringRef Name) {
+
+  LLVMContext &Ctx = NMD->getParent()->getContext();
+  Type *I32Ty = IntegerType::getInt32Ty(Ctx);
+
+  SmallVector<Metadata *> MDArgs;
+  MDArgs.push_back(MDString::get(Ctx, Name));
+  transform(ChangedTys, std::back_inserter(MDArgs), [=, &Ctx](auto &&CTy) {
+    return MDNode::get(
+        Ctx, {ConstantAsMetadata::get(ConstantInt::get(I32Ty, CTy.first, true)),
+              ValueAsMetadata::get(Constant::getNullValue(CTy.second))});
+  });
+  NMD->addOperand(MDNode::get(Ctx, MDArgs));
 }
 // Returns F if aggregate argument/return types are not present or cloned F
 // function with the types replaced by i32 types. The change in types is
@@ -572,8 +572,9 @@ bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) {
         if (!CB->getCalledOperand() || CB->getCalledFunction())
           continue;
         if (CB->getType()->isAggregateType() ||
-            any_of(CB->args(),
-                  [](auto &&Arg) { return Arg->getType()->isAggregateType(); }))
+            any_of(CB->args(), [](auto &&Arg) {
+              return Arg->getType()->isAggregateType();
+            }))
           Calls.emplace_back(CB, nullptr);
       }
     }
@@ -607,9 +608,8 @@ bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) {
       CB->setName("spv.mutated_callsite");
 
     addFunctionTypeMutation(
-      F->getParent()->getOrInsertNamedMetadata("spv.mutated_callsites"),
-      std::move(ChangedTypes),
-      CB->getName());
+        F->getParent()->getOrInsertNamedMetadata("spv.mutated_callsites"),
+        std::move(ChangedTypes), CB->getName());
   }
 
   for (auto &&[CB, NewFTy] : Calls) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 5b6a36de6c526..c23ef9cc6923e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -81,14 +81,14 @@ static FunctionType *extractFunctionTypeFromMetadata(NamedMDNode *NMD,
 
 FunctionType *getOriginalFunctionType(const Function &F) {
   return extractFunctionTypeFromMetadata(
-      F.getParent()->getNamedMetadata("spv.cloned_funcs"),
-      F.getFunctionType(), F.getName());
+      F.getParent()->getNamedMetadata("spv.cloned_funcs"), F.getFunctionType(),
+      F.getName());
 }
 
 FunctionType *getOriginalFunctionType(const CallBase &CB) {
   return extractFunctionTypeFromMetadata(
-      CB.getParent()
-        ->getParent()->getParent()->getNamedMetadata("spv.mutated_callsites"),
+      CB.getParent()->getParent()->getParent()->getNamedMetadata(
+          "spv.mutated_callsites"),
       CB.getFunctionType(), CB.getName());
 }
 } // Namespace SPIRV

>From 211f7f124a4dd2d87945c69394f73e47b3d09c91 Mon Sep 17 00:00:00 2001
From: Alex Voicu <alexandru.voicu at amd.com>
Date: Wed, 26 Nov 2025 15:04:01 +0000
Subject: [PATCH 3/7] Fix test: correct malformed FileCheck variable, disable spirv-val RUN line for now.

---
 .../SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll   | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll b/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll
index 97a0cdf56dc5c..e71e34bea41bd 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll
@@ -1,6 +1,5 @@
-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --extra_scrub --tool spirv-val --include-generated-funcs --version 6
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - -filetype=obj | spirv-val %}
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK: OpCapability Kernel
 ; CHECK-DAG: OpCapability FunctionPointersINTEL
@@ -89,7 +88,7 @@ entry:
 ; CHECK: %[[#f1GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalInt8PtrPtrTy]] %[[#fnStructCast2]] %[[#I32Const0]] %[[#I32Const1]]
 ; CHECK: %[[#f1GEPCast:]] = OpBitcast %[[#GlobalFnPtrPtrTy]] %[[#f1GEP]]
 ; CHECK: %[[#f1Load:]] = OpLoad %[[#FnPtrTy]] %[[#f1GEPCast]]
-; CHECK: %[[f2GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalInt8PtrPtrTy]] %[[#fnStructCast2]] %[[#I32Const0]] %[[#I32Const2]]
+; CHECK: %[[#f2GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalInt8PtrPtrTy]] %[[#fnStructCast2]] %[[#I32Const0]] %[[#I32Const2]]
 ; CHECK: %[[#f2GEPCast:]] = OpBitcast %[[#GlobalFnPtrPtrTy]] %[[#f2GEP]]
 ; CHECK: %[[#f2Load:]] = OpLoad %[[#FnPtrTy]] %[[#f2GEPCast]]
 ; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f0Load]]

>From ac3bfbcef6d3b64759b28a6dc189c7ceb2dedfa7 Mon Sep 17 00:00:00 2001
From: Alex Voicu <alexandru.voicu at amd.com>
Date: Wed, 26 Nov 2025 15:05:15 +0000
Subject: [PATCH 4/7] Use `poison` instead of `undef`.

---
 .../SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll    | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll b/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll
index e71e34bea41bd..ec3fd41f7de9e 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll
@@ -73,7 +73,7 @@ entry:
   %f = getelementptr inbounds [3 x ptr], ptr addrspace(1) @fnptr, i64 0
   %g = getelementptr inbounds [3 x ptr], ptr addrspace(1) @fnptr, i64 1
   %h = getelementptr inbounds [3 x ptr], ptr addrspace(1) @fnptr, i64 2
-  %0 = call spir_func addrspace(1) %t_half %f(ptr addrspace(1) %p, %t_half undef)
+  %0 = call spir_func addrspace(1) %t_half %f(ptr addrspace(1) %p, %t_half poison)
   %1 = call spir_func addrspace(1) %t_half %g(ptr addrspace(1) %p, %t_half %0)
   %2 = call spir_func addrspace(1) %t_half %h(ptr addrspace(1) %p, %t_half %1)
 
@@ -99,7 +99,7 @@ entry:
   %f = load ptr, ptr addrspace(1) @fnstruct
   %g = load ptr, ptr addrspace(1) getelementptr inbounds (%struct.anon, ptr addrspace(1) @fnstruct, i32 0, i32 1)
   %h = load ptr, ptr addrspace(1) getelementptr inbounds (%struct.anon, ptr addrspace(1) @fnstruct, i32 0, i32 2)
-  %0 = call spir_func noundef %t_half %f(ptr addrspace(1) %p, %t_half undef)
+  %0 = call spir_func noundef %t_half %f(ptr addrspace(1) %p, %t_half poison)
   %1 = call spir_func noundef %t_half %g(ptr addrspace(1) %p, %t_half %0)
   %2 = call spir_func noundef %t_half %h(ptr addrspace(1) %p, %t_half %1)
 

>From 2682f906aba36f018a8e063820f1d102b2eac12b Mon Sep 17 00:00:00 2001
From: Alex Voicu <alexandru.voicu at amd.com>
Date: Wed, 26 Nov 2025 15:19:57 +0000
Subject: [PATCH 5/7] Adopt review suggestions.

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp |  4 +---
 .../Target/SPIRV/SPIRVPrepareFunctions.cpp    | 21 +++++++++----------
 2 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index b37e6d8ce4ea3..7469d7c585077 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -1074,8 +1074,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
   FunctionType *FTy = SPIRV::getOriginalFunctionType(*CI);
   bool IsNewFTy = false, IsIncomplete = false;
   SmallVector<Type *, 4> ArgTys;
-  unsigned ParmIdx = 0;
-  for (Value *Arg : CI->args()) {
+  for (auto &&[ParmIdx, Arg] : llvm::enumerate(CI->args())) {
     Type *ArgTy = Arg->getType();
     if (ArgTy->isPointerTy()) {
       if (Type *ElemTy = GR->findDeducedElementType(Arg)) {
@@ -1090,7 +1089,6 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
       ArgTy = FTy->getFunctionParamType(ParmIdx);
     }
     ArgTys.push_back(ArgTy);
-    ++ParmIdx;
   }
   Type *RetTy = FTy->getReturnType();
   if (CI->getType()->isPointerTy()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
index 405269337e911..309995c58167a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
@@ -26,6 +26,7 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/IntrinsicLowering.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
@@ -566,17 +567,15 @@ bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) {
     return false;
 
   SmallVector<std::pair<CallBase *, FunctionType *>> Calls;
-  for (auto &&BB : *F) {
-    for (auto &&I : BB) {
-      if (auto *CB = dyn_cast<CallBase>(&I)) {
-        if (!CB->getCalledOperand() || CB->getCalledFunction())
-          continue;
-        if (CB->getType()->isAggregateType() ||
-            any_of(CB->args(), [](auto &&Arg) {
-              return Arg->getType()->isAggregateType();
-            }))
-          Calls.emplace_back(CB, nullptr);
-      }
+  for (auto &&I : instructions(F)) {
+    if (auto *CB = dyn_cast<CallBase>(&I)) {
+      if (!CB->getCalledOperand() || CB->getCalledFunction())
+        continue;
+      if (CB->getType()->isAggregateType() ||
+          any_of(CB->args(), [](auto &&Arg) {
+            return Arg->getType()->isAggregateType();
+          }))
+        Calls.emplace_back(CB, nullptr);
     }
   }
 

>From 7897636ca21b718e18f31a020eb8f5ee792f83ea Mon Sep 17 00:00:00 2001
From: Alex Voicu <alexandru.voicu at amd.com>
Date: Wed, 26 Nov 2025 15:35:06 +0000
Subject: [PATCH 6/7] Adopt review suggestion.

---
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index c23ef9cc6923e..5757e2a382b85 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -87,8 +87,7 @@ FunctionType *getOriginalFunctionType(const Function &F) {
 
 FunctionType *getOriginalFunctionType(const CallBase &CB) {
   return extractFunctionTypeFromMetadata(
-      CB.getParent()->getParent()->getParent()->getNamedMetadata(
-          "spv.mutated_callsites"),
+      CB.getModule()->getNamedMetadata("spv.mutated_callsites"),
       CB.getFunctionType(), CB.getName());
 }
 } // Namespace SPIRV

>From 197e0f51b0ddf67e2739cfae8575a4b599dffbcf Mon Sep 17 00:00:00 2001
From: Alex Voicu <alexandru.voicu at amd.com>
Date: Wed, 26 Nov 2025 15:38:02 +0000
Subject: [PATCH 7/7] Adopt review suggestion.

---
 llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
index 309995c58167a..97c1f76f46131 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
@@ -588,10 +588,12 @@ bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) {
     SmallVector<std::pair<int, Type *>> ChangedTypes;
     SmallVector<Type *> NewArgTypes;
 
-    if (CB->getType()->isAggregateType())
-      ChangedTypes.emplace_back(-1, CB->getType());
+    Type* RetTy = CB->getType();
+    if (RetTy->isAggregateType()) {
+      ChangedTypes.emplace_back(-1, RetTy);
+      RetTy = B.getInt32Ty();
+    }
 
-    Type *RetTy = ChangedTypes.empty() ? CB->getType() : B.getInt32Ty();
     for (auto &&Arg : CB->args()) {
       if (Arg->getType()->isAggregateType()) {
         NewArgTypes.push_back(B.getInt32Ty());



More information about the llvm-commits mailing list