[llvm] [SPIRV] Implement type deduction and reference to function declarations for indirect calls using SPV_INTEL_function_pointers (PR #111159)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 11 12:29:44 PDT 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/111159

>From 156c51ebf9d51fe3fdda9e9f7c8f2b564c495fb1 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 4 Oct 2024 06:54:34 -0700
Subject: [PATCH 1/6] fix indirect calls for function pointers

---
 llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp     | 10 +++
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   |  8 +++
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 53 +++++++++++++++-
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          | 14 +++++
 llvm/lib/Target/SPIRV/SPIRVUtils.h            |  3 +
 .../fp-simple-hierarchy.ll                    | 63 +++++++++++++++++++
 6 files changed, 149 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
index 55b41627802096..b078b22c7057ef 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -600,6 +600,16 @@ void SPIRVAsmPrinter::outputModuleSections() {
 }
 
 bool SPIRVAsmPrinter::doInitialization(Module &M) {
+  // Discard the internal service function
+  for (Function &F : M) {
+    if (!F.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid())
+      continue;
+    getAnalysis<MachineModuleInfoWrapperPass>()
+        .getMMI()
+        .deleteMachineFunctionFor(F);
+    break;
+  }
+
   ModuleSectionsEmitted = false;
   // We need to call the parent's one explicitly.
   return AsmPrinter::doInitialization(M);
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 27a9cb0ba9b8c0..59256e81951d73 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -36,6 +36,10 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
                                     const Value *Val, ArrayRef<Register> VRegs,
                                     FunctionLoweringInfo &FLI,
                                     Register SwiftErrorVReg) const {
+  // Discard the internal service function
+  if (FLI.Fn && FLI.Fn->getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid())
+    return true;
+
   // Maybe run postponed production of types for function pointers
   if (IndirectCalls.size() > 0) {
     produceIndirectPtrTypes(MIRBuilder);
@@ -280,6 +284,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
                                              const Function &F,
                                              ArrayRef<ArrayRef<Register>> VRegs,
                                              FunctionLoweringInfo &FLI) const {
+  // Discard the internal service function
+  if (F.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid())
+    return true;
+
   assert(GR && "Must initialize the SPIRV type registry before lowering args.");
   GR->setCurrentFunc(MIRBuilder.getMF());
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 370df24bc7af9e..43e70ae2032dfe 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -147,6 +147,10 @@ class SPIRVEmitIntrinsics
   void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy,
                             CallInst *AssignCI);
 
+  bool runOnFunction(Function &F);
+  bool postprocessTypes();
+  bool processFunctionPointers(Module &M);
+
 public:
   static char ID;
   SPIRVEmitIntrinsics() : ModulePass(ID) {
@@ -173,8 +177,6 @@ class SPIRVEmitIntrinsics
   StringRef getPassName() const override { return "SPIRV emit intrinsics"; }
 
   bool runOnModule(Module &M) override;
-  bool runOnFunction(Function &F);
-  bool postprocessTypes();
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     ModulePass::getAnalysisUsage(AU);
@@ -1825,10 +1827,57 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
   }
 
   Changed |= postprocessTypes();
+  Changed |= processFunctionPointers(M);
 
   return Changed;
 }
 
+bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
+  bool IsExt = false;
+  SmallVector<Function*> Worklist;
+  for (auto &F : M) {
+    if (!IsExt) {
+      if (!TM->getSubtarget<SPIRVSubtarget>(F).canUseExtension(
+              SPIRV::Extension::SPV_INTEL_function_pointers))
+        return false;
+      IsExt = true;
+    }
+    if (!F.isDeclaration() || F.isIntrinsic())
+      continue;
+    for (User *U : F.users()) {
+      CallInst *CI = dyn_cast<CallInst>(U);
+      if (!CI || CI->getCalledFunction() != &F) {
+        Worklist.push_back(&F);
+        break;
+      }
+    }
+  }
+  if (Worklist.empty())
+    return false;
+
+  std::string ServiceFunName = SPIRV_BACKEND_SERVICE_FUN_NAME;
+  if (!getVacantFunctionName(M, ServiceFunName))
+    report_fatal_error(
+        "cannot allocate a name for the internal service function");
+  LLVMContext &Ctx = M.getContext();
+  Function *SF =
+      Function::Create(FunctionType::get(Type::getVoidTy(Ctx), {}, false),
+                       GlobalValue::PrivateLinkage, ServiceFunName, M);
+  SF->addFnAttr(SPIRV_BACKEND_SERVICE_FUN_NAME, "");
+  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", SF);
+  IRBuilder<> IRB(BB);
+
+  for (Function *F : Worklist) {
+    SmallVector<Value *> Args;
+    for (const auto &Arg : F->args())
+      Args.push_back(PoisonValue::get(Arg.getType()));
+    IRB.CreateCall(F, Args);
+  }
+  IRB.CreateRetVoid();
+
+  return true;
+}
+
 ModulePass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
   return new SPIRVEmitIntrinsics(TM);
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index d204a8ac7975d8..dff33b16b9cfcf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -598,4 +598,18 @@ MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg) {
   return MaybeDef;
 }
 
+bool getVacantFunctionName(Module &M, std::string &Name) {
+  // It's a bit of paranoia, but still we don't want to have even a chance that
+  // the loop will work for too long.
+  constexpr unsigned MaxIters = 1024;
+  for (unsigned I = 0; I < MaxIters; ++I) {
+    std::string OrdName = Name + Twine(I).str();
+    if (!M.getFunction(OrdName)) {
+      Name = OrdName;
+      return true;
+    }
+  }
+  return false;
+}
+
 } // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index f7e8a827c2767f..83e717e6ea58fd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -341,5 +341,8 @@ inline const Type *unifyPtrType(const Type *Ty) {
 
 MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg);
 
+#define SPIRV_BACKEND_SERVICE_FUN_NAME "__spirv_backend_service_fun"
+bool getVacantFunctionName(Module &M, std::string &Name);
+
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll
new file mode 100644
index 00000000000000..5141259f63bdd7
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll
@@ -0,0 +1,63 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: OpFunction
+
+%classid = type { %arrayid }
+%arrayid = type { [1 x i64] }
+%struct.obj_storage_t = type { %storage }
+%storage = type { [8 x i8] }
+
+@_ZTV12IncrementBy8 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy89incrementEPi to ptr addrspace(4))] }, align 8
+@_ZTV13BaseIncrement = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN13BaseIncrement9incrementEPi to ptr addrspace(4))] }, align 8
+@_ZTV12IncrementBy4 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy49incrementEPi to ptr addrspace(4))] }, align 8
+@_ZTV12IncrementBy2 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy29incrementEPi to ptr addrspace(4))] }, align 8
+
+define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef align 8 %_arg_StorageAcc, ptr noundef byval(%classid) align 8 %_arg_StorageAcc3, i32 noundef %_arg_TestCase, ptr addrspace(1) noundef align 4 %_arg_DataAcc) {
+entry:
+  %0 = load i64, ptr %_arg_StorageAcc3, align 8
+  %add.ptr.i = getelementptr inbounds %struct.obj_storage_t, ptr addrspace(1) %_arg_StorageAcc, i64 %0
+  %arrayidx.ascast.i = addrspacecast ptr addrspace(1) %add.ptr.i to ptr addrspace(4)
+  %cmp.i = icmp ugt i32 %_arg_TestCase, 3
+  br i1 %cmp.i, label %entry.critedge, label %if.end.1
+
+entry.critedge: ; preds = %entry
+  %vtable.i.pre = load ptr addrspace(4), ptr addrspace(4) null, align 8
+  br label %exit
+
+if.end.1:                                         ; preds = %entry
+  switch i32 %_arg_TestCase, label %if.end.5 [
+    i32 0, label %if.end.2
+    i32 1, label %if.end.3
+    i32 2, label %if.end.4
+  ]
+
+if.end.5:                                 ; preds = %if.end.1
+  store ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy8, i64 16), ptr addrspace(1) %add.ptr.i, align 8
+  br label %exit
+
+if.end.4:                                   ; preds = %if.end.1
+  store ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy4, i64 16), ptr addrspace(1) %add.ptr.i, align 8
+  br label %exit
+
+if.end.3:                                     ; preds = %if.end.1
+  store ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy2, i64 16), ptr addrspace(1) %add.ptr.i, align 8
+  br label %exit
+
+if.end.2:                                       ; preds = %if.end.1
+  store ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV13BaseIncrement, i64 16), ptr addrspace(1) %add.ptr.i, align 8
+  br label %exit
+
+exit: ; preds = %if.end.2, %if.end.3, %if.end.4, %if.end.5, %entry.critedge
+  %vtable.i = phi ptr addrspace(4) [ %vtable.i.pre, %entry.critedge ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy8, i64 16) to i64) to ptr addrspace(4)), %if.end.5 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy4, i64 16) to i64) to ptr addrspace(4)), %if.end.4 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy2, i64 16) to i64) to ptr addrspace(4)), %if.end.3 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV13BaseIncrement, i64 16) to i64) to ptr addrspace(4)), %if.end.2 ]
+  %retval.0.i = phi ptr addrspace(4) [ null, %entry.critedge ], [ %arrayidx.ascast.i, %if.end.5 ], [ %arrayidx.ascast.i, %if.end.4 ], [ %arrayidx.ascast.i, %if.end.3 ], [ %arrayidx.ascast.i, %if.end.2 ]
+  %1 = addrspacecast ptr addrspace(1) %_arg_DataAcc to ptr addrspace(4)
+  %2 = load ptr addrspace(4), ptr addrspace(4) %vtable.i, align 8
+  tail call spir_func addrspace(4) void %2(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %retval.0.i, ptr addrspace(4) noundef %1)
+  ret void
+}
+
+declare dso_local spir_func void @_ZN13BaseIncrement9incrementEPi(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8), ptr addrspace(4) noundef)
+declare dso_local spir_func void @_ZN12IncrementBy29incrementEPi(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8), ptr addrspace(4) noundef)
+declare dso_local spir_func void @_ZN12IncrementBy49incrementEPi(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8), ptr addrspace(4) noundef)
+declare dso_local spir_func void @_ZN12IncrementBy89incrementEPi(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8), ptr addrspace(4) noundef)

>From fb8928d9fad3f3154a7da615bcfaeb88d2690e4d Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 4 Oct 2024 11:21:08 -0700
Subject: [PATCH 2/6] do not emit anything if it's an internal service function

---
 llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp      | 18 +++++++-----------
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp    | 17 +++++++++++++++--
 .../fp-simple-hierarchy.ll                     |  2 +-
 3 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
index b078b22c7057ef..c0795146e9b923 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -149,6 +149,10 @@ void SPIRVAsmPrinter::outputOpFunctionEnd() {
 
 // Emit OpFunctionEnd at the end of MF and clear BBNumToRegMap.
 void SPIRVAsmPrinter::emitFunctionBodyEnd() {
+  // Do not emit anything if it's an internal service function.
+  if (MF->getFunction().getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid())
+    return;
+
   outputOpFunctionEnd();
   MAI->BBNumToRegMap.clear();
 }
@@ -162,7 +166,9 @@ void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
 }
 
 void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
-  assert(!MBB.empty() && "MBB is empty!");
+  // Do not emit anything if it's an internal service function.
+  if (MBB.empty())
+    return;
 
   // If it's the first MBB in MF, it has OpFunction and OpFunctionParameter, so
   // OpLabel should be output after them.
@@ -600,16 +606,6 @@ void SPIRVAsmPrinter::outputModuleSections() {
 }
 
 bool SPIRVAsmPrinter::doInitialization(Module &M) {
-  // Discard the internal service function
-  for (Function &F : M) {
-    if (!F.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid())
-      continue;
-    getAnalysis<MachineModuleInfoWrapperPass>()
-        .getMMI()
-        .deleteMachineFunctionFor(F);
-    break;
-  }
-
   ModuleSectionsEmitted = false;
   // We need to call the parent's one explicitly.
   return AsmPrinter::doInitialization(M);
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 59256e81951d73..f10de1d2104125 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -36,8 +36,11 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
                                     const Value *Val, ArrayRef<Register> VRegs,
                                     FunctionLoweringInfo &FLI,
                                     Register SwiftErrorVReg) const {
-  // Discard the internal service function
-  if (FLI.Fn && FLI.Fn->getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid())
+  // Ignore if called from the internal service function
+  if (MIRBuilder.getMF()
+          .getFunction()
+          .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
+          .isValid())
     return true;
 
   // Maybe run postponed production of types for function pointers
@@ -497,6 +500,16 @@ void SPIRVCallLowering::produceIndirectPtrTypes(
 
 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
                                   CallLoweringInfo &Info) const {
+  // Ignore if called from the internal service function
+  if (MIRBuilder.getMF()
+          .getFunction()
+          .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
+          .isValid()) {
+    // insert a no-op
+    MIRBuilder.buildTrap();
+    return true;
+  }
+
   // Currently call returns should have single vregs.
   // TODO: handle the case of multiple registers.
   if (Info.OrigRet.Regs.size() > 1)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll
index 5141259f63bdd7..0178e1192d7ea7 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll
@@ -1,4 +1,4 @@
-; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
 ; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK: OpFunction

>From 91b9adda55a9aa10c1a186fe38a93a6bde5236c2 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 4 Oct 2024 11:46:24 -0700
Subject: [PATCH 3/6] code format

---
 llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp     |  4 +-
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 92 +++++++++----------
 2 files changed, 49 insertions(+), 47 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
index c0795146e9b923..1b85b72bc690ed 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -150,7 +150,9 @@ void SPIRVAsmPrinter::outputOpFunctionEnd() {
 // Emit OpFunctionEnd at the end of MF and clear BBNumToRegMap.
 void SPIRVAsmPrinter::emitFunctionBodyEnd() {
   // Do not emit anything if it's an internal service function.
-  if (MF->getFunction().getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid())
+  if (MF->getFunction()
+          .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
+          .isValid())
     return;
 
   outputOpFunctionEnd();
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 43e70ae2032dfe..e9dfdde24ff3ba 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -1673,6 +1673,52 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
   }
 }
 
+bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
+  bool IsExt = false;
+  SmallVector<Function *> Worklist;
+  for (auto &F : M) {
+    if (!IsExt) {
+      if (!TM->getSubtarget<SPIRVSubtarget>(F).canUseExtension(
+              SPIRV::Extension::SPV_INTEL_function_pointers))
+        return false;
+      IsExt = true;
+    }
+    if (!F.isDeclaration() || F.isIntrinsic())
+      continue;
+    for (User *U : F.users()) {
+      CallInst *CI = dyn_cast<CallInst>(U);
+      if (!CI || CI->getCalledFunction() != &F) {
+        Worklist.push_back(&F);
+        break;
+      }
+    }
+  }
+  if (Worklist.empty())
+    return false;
+
+  std::string ServiceFunName = SPIRV_BACKEND_SERVICE_FUN_NAME;
+  if (!getVacantFunctionName(M, ServiceFunName))
+    report_fatal_error(
+        "cannot allocate a name for the internal service function");
+  LLVMContext &Ctx = M.getContext();
+  Function *SF =
+      Function::Create(FunctionType::get(Type::getVoidTy(Ctx), {}, false),
+                       GlobalValue::PrivateLinkage, ServiceFunName, M);
+  SF->addFnAttr(SPIRV_BACKEND_SERVICE_FUN_NAME, "");
+  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", SF);
+  IRBuilder<> IRB(BB);
+
+  for (Function *F : Worklist) {
+    SmallVector<Value *> Args;
+    for (const auto &Arg : F->args())
+      Args.push_back(PoisonValue::get(Arg.getType()));
+    IRB.CreateCall(F, Args);
+  }
+  IRB.CreateRetVoid();
+
+  return true;
+}
+
 bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   if (Func.isDeclaration())
     return false;
@@ -1832,52 +1878,6 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
   return Changed;
 }
 
-bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
-  bool IsExt = false;
-  SmallVector<Function*> Worklist;
-  for (auto &F : M) {
-    if (!IsExt) {
-      if (!TM->getSubtarget<SPIRVSubtarget>(F).canUseExtension(
-              SPIRV::Extension::SPV_INTEL_function_pointers))
-        return false;
-      IsExt = true;
-    }
-    if (!F.isDeclaration() || F.isIntrinsic())
-      continue;
-    for (User *U : F.users()) {
-      CallInst *CI = dyn_cast<CallInst>(U);
-      if (!CI || CI->getCalledFunction() != &F) {
-        Worklist.push_back(&F);
-        break;
-      }
-    }
-  }
-  if (Worklist.empty())
-    return false;
-
-  std::string ServiceFunName = SPIRV_BACKEND_SERVICE_FUN_NAME;
-  if (!getVacantFunctionName(M, ServiceFunName))
-    report_fatal_error(
-        "cannot allocate a name for the internal service function");
-  LLVMContext &Ctx = M.getContext();
-  Function *SF =
-      Function::Create(FunctionType::get(Type::getVoidTy(Ctx), {}, false),
-                       GlobalValue::PrivateLinkage, ServiceFunName, M);
-  SF->addFnAttr(SPIRV_BACKEND_SERVICE_FUN_NAME, "");
-  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", SF);
-  IRBuilder<> IRB(BB);
-
-  for (Function *F : Worklist) {
-    SmallVector<Value *> Args;
-    for (const auto &Arg : F->args())
-      Args.push_back(PoisonValue::get(Arg.getType()));
-    IRB.CreateCall(F, Args);
-  }
-  IRB.CreateRetVoid();
-
-  return true;
-}
-
 ModulePass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
   return new SPIRVEmitIntrinsics(TM);
 }

>From 7f79653836985507cd724775aa95f092414ead0e Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 11 Oct 2024 02:51:30 -0700
Subject: [PATCH 4/6] improve type deduction for phi and call base

---
 llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp     |  16 +-
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   |  20 +--
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 141 +++++++++++-------
 3 files changed, 110 insertions(+), 67 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
index 1b85b72bc690ed..8210e20ce5b10e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -78,6 +78,11 @@ class SPIRVAsmPrinter : public AsmPrinter {
   void outputExecutionMode(const Module &M);
   void outputAnnotations(const Module &M);
   void outputModuleSections();
+  bool isHidden() {
+    return MF->getFunction()
+        .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
+        .isValid();
+  }
 
   void emitInstruction(const MachineInstr *MI) override;
   void emitFunctionEntryLabel() override {}
@@ -131,7 +136,7 @@ void SPIRVAsmPrinter::emitFunctionHeader() {
   TII = ST->getInstrInfo();
   const Function &F = MF->getFunction();
 
-  if (isVerbose()) {
+  if (isVerbose() && !isHidden()) {
     OutStreamer->getCommentOS()
         << "-- Begin function "
         << GlobalValue::dropLLVMManglingEscape(F.getName()) << '\n';
@@ -150,16 +155,17 @@ void SPIRVAsmPrinter::outputOpFunctionEnd() {
 // Emit OpFunctionEnd at the end of MF and clear BBNumToRegMap.
 void SPIRVAsmPrinter::emitFunctionBodyEnd() {
   // Do not emit anything if it's an internal service function.
-  if (MF->getFunction()
-          .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
-          .isValid())
+  if (isHidden())
     return;
-
   outputOpFunctionEnd();
   MAI->BBNumToRegMap.clear();
 }
 
 void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
+  // Do not emit anything if it's an internal service function.
+  if (isHidden())
+    return;
+
   MCInst LabelInst;
   LabelInst.setOpcode(SPIRV::OpLabel);
   LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB)));
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index f10de1d2104125..f8ce02a13c0f67 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -500,16 +500,6 @@ void SPIRVCallLowering::produceIndirectPtrTypes(
 
 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
                                   CallLoweringInfo &Info) const {
-  // Ignore if called from the internal service function
-  if (MIRBuilder.getMF()
-          .getFunction()
-          .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
-          .isValid()) {
-    // insert a no-op
-    MIRBuilder.buildTrap();
-    return true;
-  }
-
   // Currently call returns should have single vregs.
   // TODO: handle the case of multiple registers.
   if (Info.OrigRet.Regs.size() > 1)
@@ -597,6 +587,16 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
   }
 
+  // Ignore the call if it's called from the internal service function
+  if (MIRBuilder.getMF()
+          .getFunction()
+          .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
+          .isValid()) {
+    // insert a no-op
+    MIRBuilder.buildTrap();
+    return true;
+  }
+
   unsigned CallOp;
   if (Info.CB->isIndirectCall()) {
     if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers))
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e9dfdde24ff3ba..4ac06cc19f03dc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -386,7 +386,8 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep(
 // Traverse User instructions to deduce an element pointer type of the operand.
 Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep(
     Value *Op, std::unordered_set<Value *> &Visited, bool UnknownElemTypeI8) {
-  if (!Op || !isPointerTy(Op->getType()))
+  if (!Op || !isPointerTy(Op->getType()) || isa<ConstantPointerNull>(Op) ||
+      isa<UndefValue>(Op))
     return nullptr;
 
   if (auto ElemTy = getPointeeType(Op->getType()))
@@ -483,12 +484,25 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
     if (isPointerTy(Op->getType()))
       Ty = deduceElementTypeHelper(Op, Visited, UnknownElemTypeI8);
   } else if (auto *Ref = dyn_cast<PHINode>(I)) {
-    for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
+    Type *BestTy = nullptr;
+    unsigned MaxN = 1;
+    DenseMap<Type *, unsigned> PhiTys;
+    for (int i = Ref->getNumIncomingValues() - 1; i >= 0; --i) {
       Ty = deduceElementTypeByUsersDeep(Ref->getIncomingValue(i), Visited,
                                         UnknownElemTypeI8);
-      if (Ty)
-        break;
+      if (!Ty)
+        continue;
+      auto It = PhiTys.try_emplace(Ty, 1);
+      if (!It.second) {
+        ++It.first->second;
+        if (It.first->second > MaxN) {
+          MaxN = It.first->second;
+          BestTy = Ty;
+        }
+      }
     }
+    if (BestTy)
+      Ty = BestTy;
   } else if (auto *Ref = dyn_cast<SelectInst>(I)) {
     for (Value *Op : {Ref->getTrueValue(), Ref->getFalseValue()}) {
       Ty = deduceElementTypeByUsersDeep(Op, Visited, UnknownElemTypeI8);
@@ -644,6 +658,62 @@ static inline Type *getAtomicElemTy(SPIRVGlobalRegistry *GR, Instruction *I,
   return nullptr;
 }
 
+// Try to deduce element type for a call base. Returns false if this is an
+// indirect function invocation, and true otherwise.
+static bool deduceOperandElementTypeCalledFunction(
+    SPIRVGlobalRegistry *GR, Instruction *I,
+    SPIRV::InstructionSet::InstructionSet InstrSet, CallInst *CI,
+    SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy) {
+  Function *CalledF = CI->getCalledFunction();
+  if (!CalledF)
+    return false;
+  std::string DemangledName =
+      getOclOrSpirvBuiltinDemangledName(CalledF->getName());
+  if (DemangledName.length() > 0 &&
+      !StringRef(DemangledName).starts_with("llvm.")) {
+    auto [Grp, Opcode, ExtNo] =
+        SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
+    if (Opcode == SPIRV::OpGroupAsyncCopy) {
+      for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2; ++i) {
+        Value *Op = CI->getArgOperand(i);
+        if (!isPointerTy(Op->getType()))
+          continue;
+        ++PtrCnt;
+        if (Type *ElemTy = GR->findDeducedElementType(Op))
+          KnownElemTy = ElemTy; // src will rewrite dest if both are defined
+        Ops.push_back(std::make_pair(Op, i));
+      }
+    } else if (Grp == SPIRV::Atomic || Grp == SPIRV::AtomicFloating) {
+      if (CI->arg_size() < 2)
+        return true;
+      Value *Op = CI->getArgOperand(0);
+      if (!isPointerTy(Op->getType()))
+        return true;
+      switch (Opcode) {
+      case SPIRV::OpAtomicLoad:
+      case SPIRV::OpAtomicCompareExchangeWeak:
+      case SPIRV::OpAtomicCompareExchange:
+      case SPIRV::OpAtomicExchange:
+      case SPIRV::OpAtomicIAdd:
+      case SPIRV::OpAtomicISub:
+      case SPIRV::OpAtomicOr:
+      case SPIRV::OpAtomicXor:
+      case SPIRV::OpAtomicAnd:
+      case SPIRV::OpAtomicUMin:
+      case SPIRV::OpAtomicUMax:
+      case SPIRV::OpAtomicSMin:
+      case SPIRV::OpAtomicSMax: {
+        KnownElemTy = getAtomicElemTy(GR, I, Op);
+        if (!KnownElemTy)
+          return true;
+        Ops.push_back(std::make_pair(Op, 0));
+      } break;
+      }
+    }
+  }
+  return true;
+}
+
 // If the Instruction has Pointer operands with unresolved types, this function
 // tries to deduce them. If the Instruction has Pointer operands with known
 // types which differ from expected, this function tries to insert a bitcast to
@@ -749,53 +819,17 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
       KnownElemTy = ElemTy1;
       Ops.push_back(std::make_pair(Op0, 0));
     }
-  } else if (auto *CI = dyn_cast<CallInst>(I)) {
-    if (Function *CalledF = CI->getCalledFunction()) {
-      std::string DemangledName =
-          getOclOrSpirvBuiltinDemangledName(CalledF->getName());
-      if (DemangledName.length() > 0 &&
-          !StringRef(DemangledName).starts_with("llvm.")) {
-        auto [Grp, Opcode, ExtNo] =
-            SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
-        if (Opcode == SPIRV::OpGroupAsyncCopy) {
-          for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
-               ++i) {
-            Value *Op = CI->getArgOperand(i);
-            if (!isPointerTy(Op->getType()))
-              continue;
-            ++PtrCnt;
-            if (Type *ElemTy = GR->findDeducedElementType(Op))
-              KnownElemTy = ElemTy; // src will rewrite dest if both are defined
-            Ops.push_back(std::make_pair(Op, i));
-          }
-        } else if (Grp == SPIRV::Atomic || Grp == SPIRV::AtomicFloating) {
-          if (CI->arg_size() < 2)
-            return;
-          Value *Op = CI->getArgOperand(0);
-          if (!isPointerTy(Op->getType()))
-            return;
-          switch (Opcode) {
-          case SPIRV::OpAtomicLoad:
-          case SPIRV::OpAtomicCompareExchangeWeak:
-          case SPIRV::OpAtomicCompareExchange:
-          case SPIRV::OpAtomicExchange:
-          case SPIRV::OpAtomicIAdd:
-          case SPIRV::OpAtomicISub:
-          case SPIRV::OpAtomicOr:
-          case SPIRV::OpAtomicXor:
-          case SPIRV::OpAtomicAnd:
-          case SPIRV::OpAtomicUMin:
-          case SPIRV::OpAtomicUMax:
-          case SPIRV::OpAtomicSMin:
-          case SPIRV::OpAtomicSMax: {
-            KnownElemTy = getAtomicElemTy(GR, I, Op);
-            if (!KnownElemTy)
-              return;
-            Ops.push_back(std::make_pair(Op, 0));
-          } break;
-          }
-        }
-      }
+  } else if (CallInst *CI = dyn_cast<CallInst>(I)) {
+    if (!CI->isIndirectCall()) {
+      deduceOperandElementTypeCalledFunction(GR, I, InstrSet, CI, Ops,
+                                             KnownElemTy);
+    } else if (TM->getSubtarget<SPIRVSubtarget>(*F).canUseExtension(
+                   SPIRV::Extension::SPV_INTEL_function_pointers)) {
+      Value *Op = CI->getCalledOperand();
+      if (!Op || !isPointerTy(Op->getType()))
+        return;
+      Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
+      KnownElemTy = CI->getFunctionType();
     }
   }
 
@@ -846,7 +880,10 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
                                       B.getInt32(getPointerAddressSpace(OpTy))};
       CallInst *PtrCastI =
           B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
-      I->setOperand(OpIt.second, PtrCastI);
+      if (OpIt.second == std::numeric_limits<unsigned>::max())
+        cast<CallInst>(I)->setCalledOperand(PtrCastI); // indirect call target
+      else
+        I->setOperand(OpIt.second, PtrCastI);
       buildAssignPtr(B, KnownElemTy, PtrCastI);
     }
   }

>From bd9bceaa8fdb9dfe0c8c21376f984333aec6dc36 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 11 Oct 2024 09:11:07 -0700
Subject: [PATCH 5/6] update type inference for function pointers and update
 test cases

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 107 ++++++++++++++----
 .../fp-simple-hierarchy.ll                    |  49 ++++++--
 .../SPV_INTEL_function_pointers/fp_const.ll   |  37 +++---
 .../CodeGen/SPIRV/instructions/select-phi.ll  |   8 +-
 4 files changed, 149 insertions(+), 52 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 4ac06cc19f03dc..8b7e9c48de6c75 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -69,6 +69,7 @@ class SPIRVEmitIntrinsics
   SPIRVGlobalRegistry *GR = nullptr;
   Function *F = nullptr;
   bool TrackConstants = true;
+  bool HaveFunPtrs = false; // SPV_INTEL_function_pointers is usable
   DenseMap<Instruction *, Constant *> AggrConsts;
   DenseMap<Instruction *, Type *> AggrConstTypes;
   DenseSet<Instruction *> AggrStores;
@@ -714,6 +715,43 @@ static bool deduceOperandElementTypeCalledFunction(
   return true;
 }
 
+// Try to deduce the element type of the called operand of the indirect call
+// CI (I is expected to be CI itself). A pointer-typed called operand is
+// queued in Ops with the sentinel index std::numeric_limits<unsigned>::max(),
+// and KnownElemTy is set to CI's FunctionType or, if any pointer argument or
+// pointer result already has a deduced element type, to a copy of it in which
+// those pointers are replaced by TypedPointerType. Nothing is recorded if the
+// called operand is missing or is not a pointer.
+static void deduceOperandElementTypeFunctionPointer(
+    SPIRVGlobalRegistry *GR, Instruction *I, CallInst *CI,
+    SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy) {
+  Value *Op = CI->getCalledOperand();
+  if (!Op || !isPointerTy(Op->getType()))
+    return;
+  Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
+  FunctionType *FTy = CI->getFunctionType();
+  bool IsNewFTy = false;
+  SmallVector<Type *, 4> ArgTys;
+  for (Value *Arg : CI->args()) {
+    Type *ArgTy = Arg->getType();
+    if (ArgTy->isPointerTy())
+      if (Type *ElemTy = GR->findDeducedElementType(Arg)) {
+        IsNewFTy = true;
+        ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy));
+      }
+    ArgTys.push_back(ArgTy);
+  }
+  Type *RetTy = FTy->getReturnType();
+  if (I->getType()->isPointerTy())
+    if (Type *ElemTy = GR->findDeducedElementType(I)) {
+      IsNewFTy = true;
+      RetTy =
+          TypedPointerType::get(ElemTy, getPointerAddressSpace(I->getType()));
+    }
+  KnownElemTy =
+      IsNewFTy ? FunctionType::get(RetTy, ArgTys, FTy->isVarArg()) : FTy;
+}
+
 // If the Instruction has Pointer operands with unresolved types, this function
 // tries to deduce them. If the Instruction has Pointer operands with known
 // types which differ from expected, this function tries to insert a bitcast to
@@ -820,17 +852,11 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
       Ops.push_back(std::make_pair(Op0, 0));
     }
   } else if (CallInst *CI = dyn_cast<CallInst>(I)) {
-    if (!CI->isIndirectCall()) {
+    if (!CI->isIndirectCall())
       deduceOperandElementTypeCalledFunction(GR, I, InstrSet, CI, Ops,
                                              KnownElemTy);
-    } else if (TM->getSubtarget<SPIRVSubtarget>(*F).canUseExtension(
-                   SPIRV::Extension::SPV_INTEL_function_pointers)) {
-      Value *Op = CI->getCalledOperand();
-      if (!Op || !isPointerTy(Op->getType()))
-        return;
-      Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
-      KnownElemTy = CI->getFunctionType();
-    }
+    else if (HaveFunPtrs) // SPV_INTEL_function_pointers is usable
+      deduceOperandElementTypeFunctionPointer(GR, I, CI, Ops, KnownElemTy);
   }
 
   // There is no enough info to deduce types or all is valid.
@@ -1710,23 +1736,65 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
   }
 }
 
+// Returns the function type to use as the element type of a pointer to F:
+// F's own FunctionType, or a copy of it in which each pointer argument with an
+// already deduced element type is replaced by a TypedPointerType. The return
+// type is always kept as declared.
+// NOTE(review): deduceOperandElementTypeFunctionPointer() also applies a
+// deduced pointer result type; confirm the difference here is intended.
+static FunctionType *getFunctionPointerElemType(Function *F,
+                                                SPIRVGlobalRegistry *GR) {
+  FunctionType *FTy = F->getFunctionType();
+  bool IsNewFTy = false;
+  SmallVector<Type *, 4> ArgTys;
+  for (Argument &Arg : F->args()) {
+    Type *ArgTy = Arg.getType();
+    if (ArgTy->isPointerTy())
+      if (Type *ElemTy = GR->findDeducedElementType(&Arg)) {
+        IsNewFTy = true;
+        ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy));
+      }
+    ArgTys.push_back(ArgTy);
+  }
+  return IsNewFTy
+             ? FunctionType::get(FTy->getReturnType(), ArgTys, FTy->isVarArg())
+             : FTy;
+}
+
 bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
-  bool IsExt = false;
   SmallVector<Function *> Worklist;
   for (auto &F : M) {
-    if (!IsExt) {
-      if (!TM->getSubtarget<SPIRVSubtarget>(F).canUseExtension(
-              SPIRV::Extension::SPV_INTEL_function_pointers))
-        return false;
-      IsExt = true;
-    }
-    if (!F.isDeclaration() || F.isIntrinsic())
+    if (F.isIntrinsic())
       continue;
-    for (User *U : F.users()) {
-      CallInst *CI = dyn_cast<CallInst>(U);
-      if (!CI || CI->getCalledFunction() != &F) {
-        Worklist.push_back(&F);
-        break;
+    if (F.isDeclaration()) {
+      // Add declarations that are used other than as the direct callee of a
+      // call instruction to Worklist.
+      for (User *U : F.users()) {
+        CallInst *CI = dyn_cast<CallInst>(U);
+        if (!CI || CI->getCalledFunction() != &F) {
+          Worklist.push_back(&F);
+          break;
+        }
+      }
+    } else {
+      // For a used definition, hand its function-pointer element type (the
+      // deduced one, or one built from its deduced argument types) to
+      // updateAssignType() for the first spv_assign_ptr_type or spv_ptrcast
+      // user that has three arguments and takes F as operand 0.
+      if (F.user_empty())
+        continue;
+      Type *FPElemTy = GR->findDeducedElementType(&F);
+      if (!FPElemTy)
+        FPElemTy = getFunctionPointerElemType(&F, GR);
+      for (User *U : F.users()) {
+        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
+        if (!II || II->arg_size() != 3 || II->getOperand(0) != &F)
+          continue;
+        if (II->getIntrinsicID() == Intrinsic::spv_assign_ptr_type ||
+            II->getIntrinsicID() == Intrinsic::spv_ptrcast) {
+          updateAssignType(II, &F, PoisonValue::get(FPElemTy));
+          break;
+        }
       }
     }
   }
@@ -1765,6 +1821,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   InstrSet = ST.isOpenCLEnv() ? SPIRV::InstructionSet::OpenCL_std
                               : SPIRV::InstructionSet::GLSL_std_450;
 
+  if (!F) // F is null only before the first function; it is set below.
+    HaveFunPtrs =
+        ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
+
   F = &Func;
   IRBuilder<> B(Func.getContext());
   AggrConsts.clear();
@@ -1910,7 +1970,8 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
   }
 
   Changed |= postprocessTypes();
-  Changed |= processFunctionPointers(M);
+  if (HaveFunPtrs) // SPV_INTEL_function_pointers is usable
+    Changed |= processFunctionPointers(M);
 
   return Changed;
 }
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll
index 0178e1192d7ea7..d5a8fb3e7baafa 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll
@@ -1,22 +1,47 @@
-; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
 ; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
-; CHECK: OpFunction
+; CHECK-DAG: OpName %[[I9:.*]] "_ZN13BaseIncrement9incrementEPi"
+; CHECK-DAG: OpName %[[I29:.*]] "_ZN12IncrementBy29incrementEPi"
+; CHECK-DAG: OpName %[[I49:.*]] "_ZN12IncrementBy49incrementEPi"
+; CHECK-DAG: OpName %[[I89:.*]] "_ZN12IncrementBy89incrementEPi"
 
-%classid = type { %arrayid }
-%arrayid = type { [1 x i64] }
-%struct.obj_storage_t = type { %storage }
-%storage = type { [8 x i8] }
+; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid
+; CHECK-DAG: %[[TyArr:.*]] = OpTypeArray
+; CHECK-DAG: %[[TyStruct1:.*]] = OpTypeStruct %[[TyArr]]
+; CHECK-DAG: %[[TyStruct2:.*]] = OpTypeStruct %[[TyStruct1]]
+; CHECK-DAG: %[[TyPtrStruct2:.*]] = OpTypePointer Generic %[[TyStruct2]]
+; CHECK-DAG: %[[TyFun:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrStruct2]] %[[#]]
+; CHECK-DAG: %[[TyPtrFun:.*]] = OpTypePointer Generic %[[TyFun]]
+; CHECK-DAG: %[[TyPtrPtrFun:.*]] = OpTypePointer Generic %[[TyPtrFun]]
+
+; CHECK: %[[I9]] = OpFunction
+; CHECK: %[[I29]] = OpFunction
+; CHECK: %[[I49]] = OpFunction
+; CHECK: %[[I89]] = OpFunction
+
+; CHECK: %[[Arg1:.*]] = OpPhi %[[TyPtrStruct2]]
+; CHECK: %[[VTbl:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[#]]
+; CHECK: %[[FP:.*]] = OpLoad %[[TyPtrFun]] %[[VTbl]]
+; CHECK: %[[#]] = OpFunctionPointerCallINTEL %[[TyVoid]] %[[FP]] %[[Arg1]] %[[#]]
+
+%"cls::id" = type { %"cls::detail::array" }
+%"cls::detail::array" = type { [1 x i64] }
+%struct.obj_storage_t = type { %"struct.aligned_storage<BaseIncrement, IncrementBy2, IncrementBy4, IncrementBy8>::type" }
+%"struct.aligned_storage<BaseIncrement, IncrementBy2, IncrementBy4, IncrementBy8>::type" = type { [8 x i8] }
 
 @_ZTV12IncrementBy8 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy89incrementEPi to ptr addrspace(4))] }, align 8
 @_ZTV13BaseIncrement = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN13BaseIncrement9incrementEPi to ptr addrspace(4))] }, align 8
 @_ZTV12IncrementBy4 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy49incrementEPi to ptr addrspace(4))] }, align 8
 @_ZTV12IncrementBy2 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy29incrementEPi to ptr addrspace(4))] }, align 8
+@__spirv_BuiltInWorkgroupId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
+@__spirv_BuiltInGlobalLinearId = external dso_local local_unnamed_addr addrspace(1) constant i64, align 8
+@__spirv_BuiltInWorkgroupSize = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
 
-define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef align 8 %_arg_StorageAcc, ptr noundef byval(%classid) align 8 %_arg_StorageAcc3, i32 noundef %_arg_TestCase, ptr addrspace(1) noundef align 4 %_arg_DataAcc) {
+define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef align 8 %_arg_StorageAcc, ptr noundef byval(%"cls::id") align 8 %_arg_StorageAcc3, i32 noundef %_arg_TestCase, ptr addrspace(1) noundef align 4 %_arg_DataAcc) {
 entry:
-  %0 = load i64, ptr %_arg_StorageAcc3, align 8
-  %add.ptr.i = getelementptr inbounds %struct.obj_storage_t, ptr addrspace(1) %_arg_StorageAcc, i64 %0
+  %r0 = load i64, ptr %_arg_StorageAcc3, align 8
+  %add.ptr.i = getelementptr inbounds %struct.obj_storage_t, ptr addrspace(1) %_arg_StorageAcc, i64 %r0
   %arrayidx.ascast.i = addrspacecast ptr addrspace(1) %add.ptr.i to ptr addrspace(4)
   %cmp.i = icmp ugt i32 %_arg_TestCase, 3
   br i1 %cmp.i, label %entry.critedge, label %if.end.1
@@ -51,9 +76,9 @@ if.end.2:                                       ; preds = %if.end.1
 exit: ; preds = %if.end.2, %if.end.3, %if.end.4, %if.end.5, %entry.critedge
   %vtable.i = phi ptr addrspace(4) [ %vtable.i.pre, %entry.critedge ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy8, i64 16) to i64) to ptr addrspace(4)), %if.end.5 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy4, i64 16) to i64) to ptr addrspace(4)), %if.end.4 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy2, i64 16) to i64) to ptr addrspace(4)), %if.end.3 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV13BaseIncrement, i64 16) to i64) to ptr addrspace(4)), %if.end.2 ]
   %retval.0.i = phi ptr addrspace(4) [ null, %entry.critedge ], [ %arrayidx.ascast.i, %if.end.5 ], [ %arrayidx.ascast.i, %if.end.4 ], [ %arrayidx.ascast.i, %if.end.3 ], [ %arrayidx.ascast.i, %if.end.2 ]
-  %1 = addrspacecast ptr addrspace(1) %_arg_DataAcc to ptr addrspace(4)
-  %2 = load ptr addrspace(4), ptr addrspace(4) %vtable.i, align 8
-  tail call spir_func addrspace(4) void %2(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %retval.0.i, ptr addrspace(4) noundef %1)
+  %r1 = addrspacecast ptr addrspace(1) %_arg_DataAcc to ptr addrspace(4)
+  %r2 = load ptr addrspace(4), ptr addrspace(4) %vtable.i, align 8 ; function pointer loaded from the vtable address selected by %vtable.i
+  tail call spir_func addrspace(4) void %r2(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %retval.0.i, ptr addrspace(4) noundef %r1)
   ret void
 }
 
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll
index 5f073e95cb68f2..b4faba9a4eb8e3 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll
@@ -5,30 +5,39 @@
 ; CHECK-DAG: OpCapability FunctionPointersINTEL
 ; CHECK-DAG: OpCapability Int64
 ; CHECK: OpExtension "SPV_INTEL_function_pointers"
-; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
+
 ; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid
 ; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0
-; CHECK-DAG: %[[TyFunFp:.*]] = OpTypeFunction %[[TyVoid]] %[[TyInt64]]
-; CHECK-DAG: %[[ConstInt64:.*]] = OpConstant %[[TyInt64]] 42
-; CHECK-DAG: %[[TyPtrFunFp:.*]] = OpTypePointer Function %[[TyFunFp]]
-; CHECK-DAG: %[[ConstFunFp:.*]] = OpConstantFunctionPointerINTEL %[[TyPtrFunFp]] %[[DefFunFp:.*]]
-; CHECK: %[[FunPtr1:.*]] = OpBitcast %[[#]] %[[ConstFunFp]]
-; CHECK: %[[FunPtr2:.*]] = OpLoad %[[#]] %[[FunPtr1]]
-; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[FunPtr2]] %[[ConstInt64]]
-; CHECK: OpReturn
+; CHECK-DAG: %[[TyFun:.*]] = OpTypeFunction %[[TyInt64]] %[[TyInt64]]
+; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
+; CHECK-DAG: %[[TyPtrFun:.*]] = OpTypePointer Function %[[TyFun]]
+; CHECK-DAG: %[[ConstFunFp:.*]] = OpConstantFunctionPointerINTEL %[[TyPtrFun]] %[[DefFunFp:.*]]
+; CHECK-DAG: %[[TyPtrPtrFun:.*]] = OpTypePointer Function %[[TyPtrFun]]
+; CHECK-DAG: %[[TyPtrInt8:.*]] = OpTypePointer Function %[[TyInt8]]
+; CHECK-DAG: %[[TyPtrPtrInt8:.*]] = OpTypePointer Function %[[TyPtrInt8]]
+; CHECK: OpFunction
+; CHECK: %[[Var:.*]] = OpVariable %[[TyPtrPtrInt8]] Function
+; CHECK: %[[SAddr:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[Var]]
+; CHECK: OpStore %[[SAddr]] %[[ConstFunFp]]
+; CHECK: %[[LAddr:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[Var]]
+; CHECK: %[[FP:.*]] = OpLoad %[[TyPtrFun]] %[[LAddr]]
+; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[FP]] %[[#]]
 ; CHECK: OpFunctionEnd
-; CHECK: %[[DefFunFp]] = OpFunction %[[TyVoid]] None %[[TyFunFp]]
+
+; CHECK: %[[DefFunFp]] = OpFunction %[[TyInt64]] None %[[TyFun]]
 
 target triple = "spir64-unknown-unknown"
 
 define spir_kernel void @test() {
 entry:
-  %0 = load ptr, ptr @foo
-  %1 = call i64 %0(i64 42)
+  %fp = alloca ptr
+  store ptr @foo, ptr %fp
+  %tocall = load ptr, ptr %fp
+  %res = call i64 %tocall(i64 42)
   ret void
 }
 
-define void @foo(i64 %a) {
+define i64 @foo(i64 %a) {
 entry:
-  ret void
+  ret i64 %a
 }
diff --git a/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
index 3828fe89e60aec..16be7cd3b8db62 100644
--- a/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
+++ b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
@@ -1,3 +1,6 @@
+; This test checks how a phi node whose incoming values have different types
+; selects its result type: most operands are i8* here, so the phi becomes i8*.
+
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
 ; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
 
@@ -15,14 +18,13 @@
 
 ; CHECK: %[[Branch1:.*]] = OpLabel
 ; CHECK: %[[Res1:.*]] = OpVariable %[[StructPtr]] Function
+; CHECK: %[[Res1Casted:.*]] = OpBitcast %[[CharPtr]] %[[Res1]]
 ; CHECK: OpBranchConditional %[[#]] %[[#]] %[[Branch2:.*]]
 ; CHECK: %[[Res2:.*]] = OpInBoundsPtrAccessChain %[[CharPtr]] %[[#]] %[[#]]
-; CHECK: %[[Res2Casted:.*]] = OpBitcast %[[StructPtr]] %[[Res2]]
 ; CHECK: OpBranchConditional %[[#]] %[[#]] %[[BranchSelect:.*]]
 ; CHECK: %[[SelectRes:.*]] = OpSelect %[[CharPtr]] %[[#]] %[[#]] %[[#]]
-; CHECK: %[[SelectResCasted:.*]] = OpBitcast %[[StructPtr]] %[[SelectRes]]
 ; CHECK: OpLabel
-; CHECK: OpPhi %[[StructPtr]] %[[Res1]] %[[Branch1]] %[[Res2Casted]] %[[Branch2]] %[[SelectResCasted]] %[[BranchSelect]]
+; CHECK: OpPhi %[[CharPtr]] %[[Res1Casted]] %[[Branch1]] %[[Res2]] %[[Branch2]] %[[SelectRes]] %[[BranchSelect]]
 
 %struct = type { %array }
 %array = type { [1 x i64] }

>From db9f3670e2759e18f6070930a237521db2b340c0 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 11 Oct 2024 12:07:26 -0700
Subject: [PATCH 6/6] update fp_two_calls.ll test case

---
 .../fp_two_calls.ll                           | 31 ++++++++++---------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
index c5a2918f92c29e..eb7b1dffaee501 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
@@ -5,27 +5,30 @@
 ; CHECK-DAG: OpCapability FunctionPointersINTEL
 ; CHECK-DAG: OpCapability Int64
 ; CHECK: OpExtension "SPV_INTEL_function_pointers"
-; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
+; %fp and %bar are called indirectly, so they get function pointer types.
+; CHECK-DAG: OpName %[[fp:.*]] "fp"
+; CHECK-DAG: OpName %[[data:.*]] "data"
+; CHECK-DAG: OpName %[[bar:.*]] "bar"
+; CHECK-DAG: OpName %[[test:.*]] "test"
 ; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid
 ; CHECK-DAG: %[[TyFloat32:.*]] = OpTypeFloat 32
+; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
 ; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0
 ; CHECK-DAG: %[[TyPtrInt8:.*]] = OpTypePointer Function %[[TyInt8]]
-; CHECK-DAG: %[[TyFunFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrInt8]]
-; CHECK-DAG: %[[TyFunBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrInt8]] %[[TyPtrInt8]]
-; CHECK-DAG: %[[TyPtrFunFp:.*]] = OpTypePointer Function %[[TyFunFp]]
-; CHECK-DAG: %[[TyPtrFunBar:.*]] = OpTypePointer Function %[[TyFunBar]]
-; CHECK-DAG: %[[TyFunTest:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrInt8]] %[[TyPtrInt8]] %[[TyPtrInt8]]
-; CHECK: %[[FunTest:.*]] = OpFunction %[[TyVoid]] None %[[TyFunTest]]
-; CHECK: %[[ArgFp:.*]] = OpFunctionParameter %[[TyPtrInt8]]
-; CHECK: %[[ArgData:.*]] = OpFunctionParameter %[[TyPtrInt8]]
-; CHECK: %[[ArgBar:.*]] = OpFunctionParameter %[[TyPtrInt8]]
-; CHECK: OpFunctionPointerCallINTEL %[[TyFloat32]] %[[ArgFp]] %[[ArgBar]]
-; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[ArgBar]] %[[ArgFp]] %[[ArgData]]
+; CHECK-DAG: %[[TyFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrInt8]]
+; CHECK-DAG: %[[TyPtrFp:.*]] = OpTypePointer Function %[[TyFp]]
+; CHECK-DAG: %[[TyBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrFp]] %[[TyPtrInt8]]
+; CHECK-DAG: %[[TyPtrBar:.*]] = OpTypePointer Function %[[TyBar]]
+; CHECK-DAG: %[[TyTest:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrFp]] %[[TyPtrInt8]] %[[TyPtrBar]]
+; CHECK: %[[test]] = OpFunction %[[TyVoid]] None %[[TyTest]]
+; CHECK: %[[fp]] = OpFunctionParameter %[[TyPtrFp]]
+; CHECK: %[[data]] = OpFunctionParameter %[[TyPtrInt8]]
+; CHECK: %[[bar]] = OpFunctionParameter %[[TyPtrBar]]
+; CHECK: OpFunctionPointerCallINTEL %[[TyFloat32]] %[[fp]] %[[bar]]
+; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[bar]] %[[fp]] %[[data]]
 ; CHECK: OpReturn
 ; CHECK: OpFunctionEnd
 
-target triple = "spir64-unknown-unknown"
-
 define spir_kernel void @test(ptr %fp, ptr %data, ptr %bar) {
 entry:
   %0 = call spir_func float %fp(ptr %bar)



More information about the llvm-commits mailing list