[llvm] Add support for SPIR-V extension: SPV_INTEL_function_pointers (PR #80759)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 6 05:54:27 PST 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/80759

>From d90e91238283c4acb03c5b8d52f7fb2380df000a Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 5 Feb 2024 14:55:27 -0800
Subject: [PATCH 1/4] add initial support for SPIR-V extension:
 SPV_INTEL_function_pointers

---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   | 208 +++++++++++++++---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.h     |  28 +++
 .../Target/SPIRV/SPIRVDuplicatesTracker.cpp   |   9 +-
 .../lib/Target/SPIRV/SPIRVDuplicatesTracker.h |  28 +++
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  19 ++
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp      |   1 +
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |  10 +
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  27 +++
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  42 ++++
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h   |   2 +
 llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp      |   5 +-
 .../lib/Target/SPIRV/SPIRVSymbolicOperands.td |   6 +
 .../SPV_INTEL_function_pointers/fp_const.ll   |  34 +++
 .../fp_two_calls.ll                           |  34 +++
 14 files changed, 425 insertions(+), 28 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 97b25147ffb34..9fe517a996e91 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -34,6 +34,10 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
                                      const Value *Val, ArrayRef<Register> VRegs,
                                      FunctionLoweringInfo &FLI,
                                      Register SwiftErrorVReg) const {
+  // Maybe run postponed production of OpFunction/OpFunctionParameter's
+  if (FormalArgs.F != nullptr)
+    FormalArgs.produceFunArgsInstructions(MIRBuilder, GR, IndirectCalls);
+
   // Currently all return types should use a single register.
   // TODO: handle the case of multiple registers.
   if (VRegs.size() > 1)
@@ -217,6 +221,12 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   // Assign types and names to all args, and store their types for later.
   FunctionType *FTy = getOriginalFunctionType(F);
   SmallVector<SPIRVType *, 4> ArgTypeVRegs;
+  // Drop state left over from previously lowered function definitions.
+  if (!F.isDeclaration()) {
+    FormalArgs = SPIRVFunFormalArgs();
+    IndirectCalls.clear();
+  }
+  bool HasOpaquePtrArg = false;
   if (VRegs.size() > 0) {
     unsigned i = 0;
     for (const auto &Arg : F.args()) {
@@ -231,6 +241,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
       if (Arg.hasName())
         buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
       if (Arg.getType()->isPointerTy()) {
+        HasOpaquePtrArg = true;
         auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
         if (DerefBytes != 0)
           buildOpDecorate(VRegs[i][0], MIRBuilder,
@@ -292,33 +303,65 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
     }
   }
 
-  // Generate a SPIR-V type for the function.
+  // If there is support of indirect calls and there are opaque pointer formal
+  // arguments, there is a chance to specify opaque ptr types later (after the
+  // function's body is processed) by information about the indirect call. To
+  // support this case we may postpone generation of some SPIR-V types, and
+  // OpFunction and OpFunctionParameter's. Otherwise we generate all SPIR-V
+  // types related to the function along with instructions.
+  // Declarations are never postponed: they have no body to analyze, and they
+  // are lowered from inside lowerCall() of another function, so postponing
+  // them would overwrite that function's pending FormalArgs.
+  const auto *ST =
+      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+  bool hasFunctionPointers =
+      ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
+  bool PostponeOpFunction =
+      HasOpaquePtrArg && hasFunctionPointers && !F.isDeclaration();
+
   auto MRI = MIRBuilder.getMRI();
   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
   MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
   if (F.isDeclaration())
     GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
   SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
-  SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
-      FTy, RetTy, ArgTypeVRegs, MIRBuilder);
-
-  // Build the OpTypeFunction declaring it.
+  SPIRVType *FuncTy = PostponeOpFunction
+                          ? nullptr
+                          : GR->getOrCreateOpTypeFunctionWithArgs(
+                                FTy, RetTy, ArgTypeVRegs, MIRBuilder);
   uint32_t FuncControl = getFunctionControl(F);
 
-  MIRBuilder.buildInstr(SPIRV::OpFunction)
-      .addDef(FuncVReg)
-      .addUse(GR->getSPIRVTypeID(RetTy))
-      .addImm(FuncControl)
-      .addUse(GR->getSPIRVTypeID(FuncTy));
+  if (PostponeOpFunction) {
+    FormalArgs.F = &F;
+    FormalArgs.KeepMBB = &(MIRBuilder.getMBB());
+    FormalArgs.KeepInsertPt = MIRBuilder.getInsertPt();
+    FormalArgs.FuncVReg = FuncVReg;
+    FormalArgs.RetTy = RetTy;
+    FormalArgs.FuncControl = FuncControl;
+    FormalArgs.OrigFTy = FTy;
+    FormalArgs.ArgTypeVRegs = ArgTypeVRegs;
+  } else {
+    MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
+                                 .addDef(FuncVReg)
+                                 .addUse(GR->getSPIRVTypeID(RetTy))
+                                 .addImm(FuncControl)
+                                 .addUse(GR->getSPIRVTypeID(FuncTy));
+    const MachineOperand *DefOpFunction = &MB.getInstr()->getOperand(0);
+    GR->recordFunctionDefinition(&F, DefOpFunction);
+  }
 
   // Add OpFunctionParameters.
   int i = 0;
   for (const auto &Arg : F.args()) {
     assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
     MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
-    MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
-        .addDef(VRegs[i][0])
-        .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
+    if (PostponeOpFunction) {
+      FormalArgs.ArgVRegs.push_back(VRegs[i][0]);
+    } else {
+      MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
+          .addDef(VRegs[i][0])
+          .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
+    }
     if (F.isDeclaration())
       GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
     i++;
@@ -343,9 +386,117 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
                     {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
   }
 
+  // Handle function pointers decoration
+  if (hasFunctionPointers) {
+    if (F.hasFnAttribute("referenced-indirectly")) {
+      assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
+              "Unexpected 'referenced-indirectly' attribute of the kernel "
+              "function");
+      buildOpDecorate(FuncVReg, MIRBuilder,
+                      SPIRV::Decoration::ReferencedIndirectlyINTEL, {});
+    }
+  }
+
   return true;
 }
 
+// Emit the postponed OpFunction/OpFunctionParameter's of the function being
+// lowered. Information about indirect calls, collected while lowering the
+// function's body, is used to specify opaque pointer types of the function's
+// parameters. The consumed state is reset afterwards, so that a later
+// lowerReturn() does not emit OpFunction twice.
+void SPIRVCallLowering::SPIRVFunFormalArgs::produceFunArgsInstructions(
+    MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR,
+    SmallVector<SPIRVCallLowering::SPIRVIndirectCall> &IndirectCalls) {
+  // Store current insertion point
+  MachineBasicBlock &NextKeepMBB = MIRBuilder.getMBB();
+  MachineBasicBlock::iterator NextKeepInsertPt = MIRBuilder.getInsertPt();
+  // Set a new insertion point
+  MIRBuilder.setInsertPt(*KeepMBB, KeepInsertPt);
+
+  bool IsTypeUpd = false;
+  if (IndirectCalls.size() > 0) {
+    // TODO: add a topological sort of IndirectCalls
+    // Create indirect call data types if any
+    MachineFunction &MF = MIRBuilder.getMF();
+    for (auto const &IC : IndirectCalls) {
+      SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder);
+      SmallVector<SPIRVType *, 4> SpirvArgTypes;
+      for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
+        SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder);
+        SpirvArgTypes.push_back(SPIRVTy);
+        if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i]))
+          GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF);
+      }
+      // SPIR-V function type:
+      FunctionType *FTy =
+          FunctionType::get(const_cast<Type *>(IC.RetTy), IC.ArgTys, false);
+      SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
+          FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
+      // SPIR-V pointer to function type:
+      SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
+          SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
+      // Correct the Callee type
+      GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
+    }
+
+    // Check if our knowledge about a type of the function parameter is updated
+    // as a result of indirect calls analysis
+    for (size_t i = 0; i < ArgVRegs.size(); ++i) {
+      SPIRVType *ArgTy = GR->getSPIRVTypeForVReg(ArgVRegs[i]);
+      if (ArgTy && ArgTypeVRegs[i] != ArgTy) {
+        ArgTypeVRegs[i] = ArgTy;
+        IsTypeUpd = true;
+      }
+    }
+  }
+
+  // If we have update about function parameter types, create a new function
+  // type instead of the stored
+  // TODO: (maybe) allocated in getOriginalFunctionType(F) this->OrigFTy may be
+  // overwritten and is not used (tracked?) anywhere
+  FunctionType *UpdateFTy = OrigFTy;
+  if (IsTypeUpd) {
+    SmallVector<Type *, 4> ArgTys;
+    for (size_t i = 0; i < ArgTypeVRegs.size(); ++i) {
+      const Type *Ty = GR->getTypeForSPIRVType(ArgTypeVRegs[i]);
+      ArgTys.push_back(const_cast<Type *>(Ty));
+    }
+    // Argument types were specified, we must update function type
+    UpdateFTy = FunctionType::get(F->getReturnType(), ArgTys,
+                                  F->getFunctionType()->isVarArg());
+  }
+  // Create SPIR-V function type
+  SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
+      UpdateFTy, RetTy, ArgTypeVRegs, MIRBuilder);
+
+  // Emit OpFunction
+  MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
+                               .addDef(FuncVReg)
+                               .addUse(GR->getSPIRVTypeID(RetTy))
+                               .addImm(FuncControl)
+                               .addUse(GR->getSPIRVTypeID(FuncTy));
+  GR->recordFunctionDefinition(F, &MB.getInstr()->getOperand(0));
+
+  // Emit OpFunctionParameter's
+  for (size_t i = 0; i < ArgVRegs.size(); ++i) {
+    MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
+        .addDef(ArgVRegs[i])
+        .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
+  }
+
+  // Restore insertion point
+  MIRBuilder.setInsertPt(NextKeepMBB, NextKeepInsertPt);
+
+  // Reset the consumed state: a later lowerReturn() must not emit OpFunction
+  // again, and nothing may leak into the next function.
+  F = nullptr;
+  KeepMBB = nullptr;
+  ArgTypeVRegs.clear();
+  ArgVRegs.clear();
+  IndirectCalls.clear();
+}
+
 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
                                   CallLoweringInfo &Info) const {
   // Currently call returns should have single vregs.
@@ -356,45 +507,44 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   GR->setCurrentFunc(MF);
   FunctionType *FTy = nullptr;
   const Function *CF = nullptr;
+  std::string DemangledName;
+  const Type *OrigRetTy = Info.OrigRet.Ty;
 
   // Emit a regular OpFunctionCall. If it's an externally declared function,
   // be sure to emit its type and function declaration here. It will be hoisted
   // globally later.
   if (Info.Callee.isGlobal()) {
+    std::string FuncName = Info.Callee.getGlobal()->getName().str();
+    DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
     CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
     // TODO: support constexpr casts and indirect calls.
     if (CF == nullptr)
       return false;
-    FTy = getOriginalFunctionType(*CF);
+    if ((FTy = getOriginalFunctionType(*CF)) != nullptr)
+      OrigRetTy = FTy->getReturnType();
   }
 
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   Register ResVReg =
       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
-  std::string FuncName = Info.Callee.getGlobal()->getName().str();
-  std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
   const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
   // TODO: check that it's OCL builtin, then apply OpenCL_std.
   if (!DemangledName.empty() && CF && CF->isDeclaration() &&
       ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
-    const Type *OrigRetTy = Info.OrigRet.Ty;
-    if (FTy)
-      OrigRetTy = FTy->getReturnType();
     SmallVector<Register, 8> ArgVRegs;
     for (auto Arg : Info.OrigArgs) {
       assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
       ArgVRegs.push_back(Arg.Regs[0]);
       SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
       if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
-        GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF());
+        GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
     }
     if (auto Res = SPIRV::lowerBuiltin(
             DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
             ResVReg, OrigRetTy, ArgVRegs, GR))
       return *Res;
   }
-  if (CF && CF->isDeclaration() &&
-      !GR->find(CF, &MIRBuilder.getMF()).isValid()) {
+  if (CF && CF->isDeclaration() && !GR->find(CF, &MF).isValid()) {
     // Emit the type info and forward function declaration to the first MBB
     // to ensure VReg definition dependencies are valid across all MBBs.
     MachineIRBuilder FirstBlockBuilder;
@@ -416,14 +566,40 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
   }
 
+  unsigned CallOp;
+  if (Info.CB && Info.CB->isIndirectCall()) {
+    if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers))
+      report_fatal_error("An indirect call is encountered but SPIR-V without "
+                         "extensions does not support it",
+                         false);
+    // Set instruction operation according to SPV_INTEL_function_pointers
+    CallOp = SPIRV::OpFunctionPointerCallINTEL;
+    // Collect information about the indirect call to support possible
+    // specification of opaque ptr types of parent function's parameters
+    Register CalleeReg = Info.Callee.getReg();
+    if (CalleeReg.isValid()) {
+      SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
+      IndirectCall.Callee = CalleeReg;
+      IndirectCall.RetTy = OrigRetTy;
+      for (const auto &Arg : Info.OrigArgs) {
+        assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
+        IndirectCall.ArgTys.push_back(Arg.Ty);
+        IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
+      }
+      IndirectCalls.push_back(IndirectCall);
+    }
+  } else {
+    // Emit a regular OpFunctionCall
+    CallOp = SPIRV::OpFunctionCall;
+  }
+
   // Make sure there's a valid return reg, even for functions returning void.
   if (!ResVReg.isValid())
     ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
-  SPIRVType *RetType =
-      GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);
+  SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);
 
-  // Emit the OpFunctionCall and its args.
-  auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
+  // Emit the call instruction and its args.
+  auto MIB = MIRBuilder.buildInstr(CallOp)
                  .addDef(ResVReg)
                  .addUse(GR->getSPIRVTypeID(RetType))
                  .add(Info.Callee);
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.h b/llvm/lib/Target/SPIRV/SPIRVCallLowering.h
index c2d6ad82d507d..680db7ca7b1be 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.h
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.h
@@ -26,6 +26,38 @@ class SPIRVCallLowering : public CallLowering {
   // Used to create and assign function, argument, and return type information.
   SPIRVGlobalRegistry *GR;
 
+  // Information about an indirect call, collected by lowerCall(). When
+  // production of OpFunction and OpFunctionParameter is postponed, it is used
+  // to specify opaque pointer types of the parent function's parameters.
+  struct SPIRVIndirectCall {
+    const Type *RetTy = nullptr;
+    SmallVector<Type *> ArgTys;
+    SmallVector<Register> ArgRegs;
+    Register Callee;
+  };
+  // State of a function whose OpFunction and OpFunctionParameter production
+  // is postponed until lowerReturn(). F == nullptr means nothing is pending.
+  struct SPIRVFunFormalArgs {
+    const Function *F = nullptr;
+    // the insertion point
+    MachineBasicBlock *KeepMBB = nullptr;
+    MachineBasicBlock::iterator KeepInsertPt;
+    // OpFunction and OpFunctionParameter operands
+    Register FuncVReg;
+    SPIRVType *RetTy = nullptr;
+    uint32_t FuncControl = 0;
+    FunctionType *OrigFTy = nullptr;
+    SmallVector<SPIRVType *, 4> ArgTypeVRegs;
+    SmallVector<Register> ArgVRegs;
+
+    void produceFunArgsInstructions(
+        MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR,
+        SmallVector<SPIRVCallLowering::SPIRVIndirectCall> &IndirectCalls);
+  };
+  // Mutable because the CallLowering hooks that update them are const.
+  mutable SPIRVFunFormalArgs FormalArgs;
+  mutable SmallVector<SPIRVIndirectCall> IndirectCalls;
+
 public:
   SPIRVCallLowering(const SPIRVTargetLowering &TLI, SPIRVGlobalRegistry *GR);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
index cbe1a53fd7568..d82fb2df4539a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
@@ -54,7 +54,14 @@ void SPIRVGeneralDuplicatesTracker::buildDepsGraph(
         MachineOperand &Op = MI->getOperand(i);
         if (!Op.isReg())
           continue;
-        MachineOperand *RegOp = &MRI.getVRegDef(Op.getReg())->getOperand(0);
+        // References to a function via function pointers generate virtual
+        // registers without a definition. We are able to resolve this
+        // reference using Global Registry info into an OpFunction instruction
+        // but do not expect to find it in Reg2Entry.
+        if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL && i == 2)
+          continue;
+        MachineInstr *VRegDef = MRI.getVRegDef(Op.getReg());
+        MachineOperand *RegOp = &VRegDef->getOperand(0);
         assert((MI->getOpcode() == SPIRV::OpVariable && i == 3) ||
                Reg2Entry.count(RegOp));
         if (Reg2Entry.count(RegOp))
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index 96cc621791e97..65bc651116962 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -264,6 +264,11 @@ class SPIRVGeneralDuplicatesTracker {
   SPIRVDuplicatesTracker<Argument> AT;
   SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
 
+  // map a Function to its definition (as a machine instruction operand)
+  DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
+  // map function pointer (as a machine instruction operand) to the used Function
+  DenseMap<const MachineOperand *, const Function *> InstrToFunction;
+
   // NOTE: using MOs instead of regs to get rid of MF dependency to be able
   // to use flat data structure.
   // NOTE: replacing DenseMap with MapVector doesn't affect overall correctness
@@ -280,6 +285,30 @@ class SPIRVGeneralDuplicatesTracker {
   void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph,
                       MachineModuleInfo *MMI);
 
+  // Map a machine operand that represents a use of a function via function
+  // pointer to a machine operand that represents the function definition.
+  // Return the definition operand, or nullptr if it is unknown: there is no
+  // context here for a good diagnostic message about a missing reference.
+  const MachineOperand *
+  getFunctionDefinitionByUse(const MachineOperand *Use) const {
+    auto ResF = InstrToFunction.find(Use);
+    if (ResF == InstrToFunction.end())
+      return nullptr;
+    auto ResDef = FunctionToInstr.find(ResF->second);
+    return ResDef == FunctionToInstr.end() ? nullptr : ResDef->second;
+  }
+  // map function pointer (as a machine instruction operand) to the used
+  // Function
+  void recordFunctionPointer(const MachineOperand *MO, const Function *F) {
+    InstrToFunction[MO] = F;
+  }
+  // map a Function to its definition (as a machine instruction operand)
+  void recordFunctionDefinition(const Function *F, const MachineOperand *MO) {
+    FunctionToInstr[F] = MO;
+  }
+  // Return true if any OpConstantFunctionPointerINTEL were generated
+  bool hasConstFunPtr() const { return !InstrToFunction.empty(); }
+
   void add(const Type *Ty, const MachineFunction *MF, Register R) {
     TT.add(Ty, MF, R);
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index f3280928c25df..3ccc224ad8a4c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -101,6 +101,25 @@ class SPIRVGlobalRegistry {
     DT.buildDepsGraph(Graph, MMI);
   }
 
+  // Map a function pointer (as a machine instruction operand) to the used
+  // Function.
+  void recordFunctionPointer(const MachineOperand *MO, const Function *F) {
+    DT.recordFunctionPointer(MO, F);
+  }
+  // Map a Function to its definition (as a machine instruction operand)
+  void recordFunctionDefinition(const Function *F, const MachineOperand *MO) {
+    DT.recordFunctionDefinition(F, MO);
+  }
+  // Map a machine operand that represents a use of a function via function
+  // pointer to a machine operand that represents the function definition.
+  // Return the definition operand, or nullptr if it is unknown: there is no
+  // context here for a good diagnostic message about a missing reference.
+  const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) {
+    return DT.getFunctionDefinitionByUse(Use);
+  }
+  // Return true if any OpConstantFunctionPointerINTEL has been recorded.
+  bool hasConstFunPtr() { return DT.hasConstFunPtr(); }
+
   // Get or create a SPIR-V type corresponding the given LLVM IR type,
   // and map it to the given VReg by creating an ASSIGN_TYPE instruction.
   SPIRVType *assignTypeToVReg(const Type *Type, Register VReg,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
index 42317453a2370..e3f76419f1313 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
@@ -40,6 +40,7 @@ bool SPIRVInstrInfo::isConstantInstr(const MachineInstr &MI) const {
   case SPIRV::OpSpecConstantComposite:
   case SPIRV::OpSpecConstantOp:
   case SPIRV::OpUndef:
+  case SPIRV::OpConstantFunctionPointerINTEL: // SPV_INTEL_function_pointers
     return true;
   default:
     return false;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index da033ba32624c..3683fe9ec1648 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -761,3 +761,13 @@ def OpGroupNonUniformBitwiseXor: OpGroupNUGroup<"BitwiseXor", 361>;
 def OpGroupNonUniformLogicalAnd: OpGroupNUGroup<"LogicalAnd", 362>;
 def OpGroupNonUniformLogicalOr: OpGroupNUGroup<"LogicalOr", 363>;
 def OpGroupNonUniformLogicalXor: OpGroupNUGroup<"LogicalXor", 364>;
+
+// 3.49.7. Constant-Creation Instructions
+
+//  - SPV_INTEL_function_pointers
+def OpConstantFunctionPointerINTEL: Op<5600, (outs ID:$res), (ins TYPE:$ty, ID:$fun), "$res = OpConstantFunctionPointerINTEL $ty $fun">;
+
+// 3.49.9. Function Instructions
+
+//  - SPV_INTEL_function_pointers
+def OpFunctionPointerCallINTEL: Op<5601, (outs ID:$res), (ins TYPE:$ty, ID:$funPtr, variable_ops), "$res = OpFunctionPointerCallINTEL $ty $funPtr">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 8c1dfc5e626db..f7cf3e1936e33 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1534,6 +1534,12 @@ bool SPIRVInstructionSelector::selectGlobalValue(
     GlobalIdent = GV->getGlobalIdentifier();
   }
 
+  // Behaviour of functions as operands depends on availability of the
+  // corresponding extension (SPV_INTEL_function_pointers):
+  // - If there is an extension to operate with functions as operands:
+  // We create a proper constant operand and evaluate a correct type for a
+  // function pointer.
+  // - Without the required extension:
   // We have functions as operands in tests with blocks of instruction e.g. in
   // transcoding/global_block.ll. These operands are not used and should be
   // substituted by zero constants. Their type is expected to be always
@@ -1545,6 +1551,27 @@ bool SPIRVInstructionSelector::selectGlobalValue(
     if (!NewReg.isValid()) {
       Register NewReg = ResVReg;
       GR.add(ConstVal, GR.CurMF, NewReg);
+      const Function *GVFun =
+          STI.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)
+              ? dyn_cast<Function>(GV)
+              : nullptr;
+      if (GVFun) {
+        // References to a function via function pointers generate virtual
+        // registers without a definition. We will resolve it later, during
+        // module analysis stage.
+        MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+        Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+        MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
+        MachineInstrBuilder MB =
+            BuildMI(BB, I, I.getDebugLoc(),
+                    TII.get(SPIRV::OpConstantFunctionPointerINTEL))
+                .addDef(NewReg)
+                .addUse(GR.getSPIRVTypeID(ResType))
+                .addUse(FuncVReg);
+        // Map the function pointer operand to the referenced Function
+        GR.recordFunctionPointer(&MB.getInstr()->getOperand(2), GVFun);
+        return MB.constrainAllUses(TII, TRI, RBI);
+      }
       return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
           .addDef(NewReg)
           .addUse(GR.getSPIRVTypeID(ResType))
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 370da046984f9..6f648c8aa3e88 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -291,6 +291,32 @@ void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
   }
 }
 
+// References to a function via function pointers generate virtual
+// registers without a definition. We are able to resolve this
+// reference using Global Registry info into an OpFunction instruction
+// and replace dummy operands by the corresponding global register references.
+void SPIRVModuleAnalysis::collectFuncPtrs() {
+  for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars])
+    if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL)
+      collectFuncPtrs(MI);
+}
+
+void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) {
+  const MachineOperand *FunUse = &MI->getOperand(2);
+  if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) {
+    const MachineInstr *FunDefMI = FunDef->getParent();
+    assert(FunDefMI->getOpcode() == SPIRV::OpFunction &&
+           "Constant function pointer must refer to function definition");
+    Register FunDefReg = FunDef->getReg();
+    Register GlobalFunDefReg =
+        MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg);
+    assert(GlobalFunDefReg.isValid() &&
+           "Function definition must refer to a global register");
+    Register FunPtrReg = FunUse->getReg();
+    MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg);
+  }
+}
+
 using InstrSignature = SmallVector<size_t>;
 using InstrTraces = std::set<InstrSignature>;
 
@@ -915,6 +941,13 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
     }
     break;
+  case SPIRV::OpConstantFunctionPointerINTEL:
+  case SPIRV::OpFunctionPointerCallINTEL:
+    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
+      Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
+    }
+    break;
   default:
     break;
   }
@@ -1073,6 +1106,10 @@ bool SPIRVModuleAnalysis::runOnModule(Module &M) {
   // Number rest of registers from N+1 onwards.
   numberRegistersGlobally(M);
 
+  // Update references to OpFunction instructions to use Global Registers
+  if (GR->hasConstFunPtr())
+    collectFuncPtrs();
+
   // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
   processOtherInstrs(M);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h
index d0b8027edd420..b05526b06e7da 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h
@@ -224,6 +224,10 @@ struct SPIRVModuleAnalysis : public ModulePass {
   void collectFuncNames(MachineInstr &MI, const Function *F);
   void processOtherInstrs(const Module &M);
   void numberRegistersGlobally(const Module &M);
+  // Resolve the function operands of OpConstantFunctionPointerINTEL
+  // instructions to the global registers of the referenced OpFunction's.
+  void collectFuncPtrs();
+  void collectFuncPtrs(MachineInstr *MI);
 
   const SPIRVSubtarget *ST;
   SPIRVGlobalRegistry *GR;
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index cf6dfb127cdeb..5974e37201fad 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -48,7 +48,11 @@ cl::list<SPIRV::Extension::Extension> Extensions(
         clEnumValN(SPIRV::Extension::SPV_KHR_bit_instructions,
                    "SPV_KHR_bit_instructions",
                    "This enables bit instructions to be used by SPIR-V modules "
-                   "without requiring the Shader capability")));
+                   "without requiring the Shader capability"),
+        // Required for function pointer constants and indirect calls.
+        clEnumValN(SPIRV::Extension::SPV_INTEL_function_pointers,
+                   "SPV_INTEL_function_pointers",
+                   "Allows translation of function pointers")));
 
 // Compare version numbers, but allow 0 to mean unspecified.
 static bool isAtLeastVer(uint32_t Target, uint32_t VerToCompareTo) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index ac92ee4a0756a..674e432203eed 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -295,6 +295,7 @@ defm SPV_INTEL_usm_storage_classes : ExtensionOperand<100>;
 defm SPV_INTEL_fpga_latency_control : ExtensionOperand<101>;
 defm SPV_INTEL_fpga_argument_interfaces : ExtensionOperand<102>;
 defm SPV_INTEL_optnone : ExtensionOperand<103>;
+defm SPV_INTEL_function_pointers : ExtensionOperand<104>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -452,6 +453,8 @@ defm ArbitraryPrecisionIntegersINTEL : CapabilityOperand<5844, 0, 0, [SPV_INTEL_
 defm OptNoneINTEL : CapabilityOperand<6094, 0, 0, [SPV_INTEL_optnone], []>;
 defm BitInstructions : CapabilityOperand<6025, 0, 0, [SPV_KHR_bit_instructions], []>;
 defm ExpectAssumeKHR : CapabilityOperand<5629, 0, 0, [SPV_KHR_expect_assume], []>;
+defm FunctionPointersINTEL : CapabilityOperand<5603, 0, 0, [SPV_INTEL_function_pointers], []>;
+defm IndirectReferencesINTEL : CapabilityOperand<5604, 0, 0, [SPV_INTEL_function_pointers], []>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
@@ -688,6 +691,7 @@ defm HitAttributeNV : StorageClassOperand<5339, [RayTracingNV]>;
 defm IncomingRayPayloadNV : StorageClassOperand<5342, [RayTracingNV]>;
 defm ShaderRecordBufferNV : StorageClassOperand<5343, [RayTracingNV]>;
 defm PhysicalStorageBufferEXT : StorageClassOperand<5349, [PhysicalStorageBufferAddressesEXT]>;
+defm CodeSectionINTEL : StorageClassOperand<5605, [FunctionPointersINTEL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Dim enum values and at the same time
@@ -1179,6 +1183,8 @@ defm CountBuffer : DecorationOperand<5634, 0, 0, [], []>;
 defm UserSemantic : DecorationOperand<5635, 0, 0, [], []>;
 defm RestrictPointerEXT : DecorationOperand<5355, 0, 0, [], [PhysicalStorageBufferAddressesEXT]>;
 defm AliasedPointerEXT : DecorationOperand<5356, 0, 0, [], [PhysicalStorageBufferAddressesEXT]>;
+defm ReferencedIndirectlyINTEL : DecorationOperand<5602, 0, 0, [], [IndirectReferencesINTEL]>;
+defm ArgumentAttributeINTEL : DecorationOperand<6409, 0, 0, [], [FunctionPointersINTEL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define BuiltIn enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll
new file mode 100644
index 0000000000000..0bd1b5d776a94
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll
@@ -0,0 +1,34 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_function_pointers %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-extensions=SPV_INTEL_function_pointers %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpCapability Int8
+; CHECK-DAG: OpCapability FunctionPointersINTEL
+; CHECK-DAG: OpCapability Int64
+; CHECK: OpExtension "SPV_INTEL_function_pointers"
+; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
+; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid
+; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0
+; CHECK-DAG: %[[TyFunFp:.*]] = OpTypeFunction %[[TyVoid]] %[[TyInt64]]
+; CHECK-DAG: %[[ConstInt64:.*]] = OpConstant %[[TyInt64]] 42
+; CHECK-DAG: %[[TyPtrFunFp:.*]] = OpTypePointer Function %[[TyFunFp]]
+; CHECK-DAG: %[[ConstFunFp:.*]] = OpConstantFunctionPointerINTEL %[[TyPtrFunFp]] %[[DefFunFp:.*]]
+; CHECK: %[[FunPtr1:.*]] = OpBitcast %[[#]] %[[ConstFunFp]]
+; CHECK: %[[FunPtr2:.*]] = OpLoad %[[#]] %[[FunPtr1]]
+; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[FunPtr2]] %[[ConstInt64]]
+; CHECK: OpReturn
+; CHECK: OpFunctionEnd
+; CHECK: %[[DefFunFp]] = OpFunction %[[TyVoid]] None %[[TyFunFp]]
+
+target triple = "spir64-unknown-unknown"
+
+define spir_kernel void @test() {
+entry:
+  %0 = load ptr, ptr @foo  ; load through @foo's own address (the function pointer constant)
+  %1 = call i64 %0(i64 42)  ; indirect call typed i64(i64); NOTE(review): @foo is defined as void(i64)
+  ret void
+}
+
+define void @foo(i64 %a) {
+entry:
+  ret void  ; empty body: @foo is only referenced via the pointer loaded in @test
+}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
new file mode 100644
index 0000000000000..33f176b7325d5
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
@@ -0,0 +1,34 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_function_pointers %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-extensions=SPV_INTEL_function_pointers %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpCapability Int8
+; CHECK-DAG: OpCapability FunctionPointersINTEL
+; CHECK-DAG: OpCapability Int64
+; CHECK: OpExtension "SPV_INTEL_function_pointers"
+; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
+; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid
+; CHECK-DAG: %[[TyFloat32:.*]] = OpTypeFloat 32
+; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0
+; CHECK-DAG: %[[TyPtrInt8:.*]] = OpTypePointer Function %[[TyInt8]]
+; CHECK-DAG: %[[TyFunFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrInt8]]
+; CHECK-DAG: %[[TyFunBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrInt8]] %[[TyPtrInt8]]
+; CHECK-DAG: %[[TyPtrFunFp:.*]] = OpTypePointer Function %[[TyFunFp]]
+; CHECK-DAG: %[[TyPtrFunBar:.*]] = OpTypePointer Function %[[TyFunBar]]
+; CHECK: %[[TyFunTest:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrFunFp]] %[[TyPtrInt8]] %[[TyPtrFunBar]]
+; CHECK: %[[FunTest:.*]] = OpFunction %[[TyVoid]] None %[[TyFunTest]]
+; CHECK: %[[ArgFp:.*]] = OpFunctionParameter %[[TyPtrFunFp]]
+; CHECK: %[[ArgData:.*]] = OpFunctionParameter %[[TyPtrInt8]]
+; CHECK: %[[ArgBar:.*]] = OpFunctionParameter %[[TyPtrFunBar]]
+; CHECK: OpFunctionPointerCallINTEL %[[TyFloat32]] %[[ArgFp]] %[[ArgBar]]
+; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[ArgBar]] %[[ArgFp]] %[[ArgData]]
+; CHECK: OpReturn
+; CHECK: OpFunctionEnd
+
+target triple = "spir64-unknown-unknown"
+
+define spir_kernel void @test(ptr %fp, ptr %data, ptr %bar) {
+entry:
+  %0 = call spir_func float %fp(ptr %bar)  ; indirect call through parameter %fp, typed float(ptr)
+  %1 = call spir_func i64 %bar(ptr %fp, ptr %data)  ; indirect call through parameter %bar, typed i64(ptr, ptr)
+  ret void
+}

>From 11d00eee895156e17264f68807231303db7fffa9 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 5 Feb 2024 15:20:11 -0800
Subject: [PATCH 2/4] apply clang-format

---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp    | 4 ++--
 llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h | 5 +++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 9fe517a996e91..a31180f847f2b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -381,8 +381,8 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   if (hasFunctionPointers) {
     if (F.hasFnAttribute("referenced-indirectly")) {
       assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
-              "Unexpected 'referenced-indirectly' attribute of the kernel "
-              "function");
+             "Unexpected 'referenced-indirectly' attribute of the kernel "
+             "function");
       buildOpDecorate(FuncVReg, MIRBuilder,
                       SPIRV::Decoration::ReferencedIndirectlyINTEL, {});
     }
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index 65bc651116962..dabd73367cd6a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -266,7 +266,8 @@ class SPIRVGeneralDuplicatesTracker {
 
   // map a Function to its definition (as a machine instruction operand)
   DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
-  // map function pointer (as a machine instruction operand) to the used Function
+  // map function pointer (as a machine instruction operand) to the used
+  // Function
   DenseMap<const MachineOperand *, const Function *> InstrToFunction;
 
   // NOTE: using MOs instead of regs to get rid of MF dependency to be able
@@ -289,7 +290,7 @@ class SPIRVGeneralDuplicatesTracker {
   // pointer to a machine operand that represents the function definition.
   // Return either the register or invalid value, because we have no context for
   // a good diagnostic message in case of unexpectedly missing references.
-  const MachineOperand* getFunctionDefinitionByUse(const MachineOperand *Use) {
+  const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) {
     auto ResF = InstrToFunction.find(Use);
     if (ResF == InstrToFunction.end())
       return nullptr;

>From 66a7079d90063156abc63b6efe6685ddc2d0addd Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 6 Feb 2024 04:48:55 -0800
Subject: [PATCH 3/4] simplify Op->Fun->FunDefOp register

---
 .../lib/Target/SPIRV/SPIRVDuplicatesTracker.h | 29 ------------------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   | 30 ++++++++++++-------
 2 files changed, 20 insertions(+), 39 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index dabd73367cd6a..96cc621791e97 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -264,12 +264,6 @@ class SPIRVGeneralDuplicatesTracker {
   SPIRVDuplicatesTracker<Argument> AT;
   SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
 
-  // map a Function to its definition (as a machine instruction operand)
-  DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
-  // map function pointer (as a machine instruction operand) to the used
-  // Function
-  DenseMap<const MachineOperand *, const Function *> InstrToFunction;
-
   // NOTE: using MOs instead of regs to get rid of MF dependency to be able
   // to use flat data structure.
   // NOTE: replacing DenseMap with MapVector doesn't affect overall correctness
@@ -286,29 +280,6 @@ class SPIRVGeneralDuplicatesTracker {
   void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph,
                       MachineModuleInfo *MMI);
 
-  // Map a machine operand that represents a use of a function via function
-  // pointer to a machine operand that represents the function definition.
-  // Return either the register or invalid value, because we have no context for
-  // a good diagnostic message in case of unexpectedly missing references.
-  const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) {
-    auto ResF = InstrToFunction.find(Use);
-    if (ResF == InstrToFunction.end())
-      return nullptr;
-    auto ResReg = FunctionToInstr.find(ResF->second);
-    return ResReg == FunctionToInstr.end() ? nullptr : ResReg->second;
-  }
-  // map function pointer (as a machine instruction operand) to the used
-  // Function
-  void recordFunctionPointer(const MachineOperand *MO, const Function *F) {
-    InstrToFunction[MO] = F;
-  }
-  // map a Function to its definition (as a machine instruction)
-  void recordFunctionDefinition(const Function *F, const MachineOperand *MO) {
-    FunctionToInstr[F] = MO;
-  }
-  // Return true if any OpConstantFunctionPointerINTEL were generated
-  bool hasConstFunPtr() { return !InstrToFunction.empty(); }
-
   void add(const Type *Ty, const MachineFunction *MF, Register R) {
     TT.add(Ty, MF, R);
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 3ccc224ad8a4c..792a00786f0aa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -38,6 +38,12 @@ class SPIRVGlobalRegistry {
 
   DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
 
+  // map a Function to its definition (as a machine instruction operand)
+  DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
+  // map function pointer (as a machine instruction operand) to the used
+  // Function
+  DenseMap<const MachineOperand *, const Function *> InstrToFunction;
+
   // Look for an equivalent of the newType in the map. Return the equivalent
   // if it's found, otherwise insert newType to the map and return the type.
   const MachineInstr *checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
@@ -101,24 +107,28 @@ class SPIRVGlobalRegistry {
     DT.buildDepsGraph(Graph, MMI);
   }
 
+  // Map a machine operand that represents a use of a function via function
+  // pointer to a machine operand that represents the function definition.
+  // Return the defining operand or nullptr, because we have no context for
+  // a good diagnostic message in case of unexpectedly missing references.
+  const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) {
+    auto ResF = InstrToFunction.find(Use);
+    if (ResF == InstrToFunction.end())
+      return nullptr;
+    auto ResReg = FunctionToInstr.find(ResF->second);
+    return ResReg == FunctionToInstr.end() ? nullptr : ResReg->second;
+  }
   // map function pointer (as a machine instruction operand) to the used
   // Function
   void recordFunctionPointer(const MachineOperand *MO, const Function *F) {
-    DT.recordFunctionPointer(MO, F);
+    InstrToFunction[MO] = F;
   }
   // map a Function to its definition (as a machine instruction)
   void recordFunctionDefinition(const Function *F, const MachineOperand *MO) {
-    DT.recordFunctionDefinition(F, MO);
-  }
-  // Map a machine operand that represents a use of a function via function
-  // pointer to a machine operand that represents the function definition.
-  // Return either the register or invalid value, because we have no context for
-  // a good diagnostic message in case of unexpectedly missing references.
-  const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) {
-    return DT.getFunctionDefinitionByUse(Use);
+    FunctionToInstr[F] = MO;
   }
   // Return true if any OpConstantFunctionPointerINTEL were generated
-  bool hasConstFunPtr() { return DT.hasConstFunPtr(); }
+  bool hasConstFunPtr() { return !InstrToFunction.empty(); }
 
   // Get or create a SPIR-V type corresponding the given LLVM IR type,
   // and map it to the given VReg by creating an ASSIGN_TYPE instruction.

>From 63da2447febd219a56141d7b6f4447ba297e023d Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 6 Feb 2024 05:54:15 -0800
Subject: [PATCH 4/4] fix MachineInstrBuilder usage

---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   | 23 +++++++++----------
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  2 +-
 2 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index a31180f847f2b..42deba3b330e8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -332,13 +332,12 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
     FormalArgs.OrigFTy = FTy;
     FormalArgs.ArgTypeVRegs = ArgTypeVRegs;
   } else {
-    const MachineInstrBuilder &MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
-                                        .addDef(FuncVReg)
-                                        .addUse(GR->getSPIRVTypeID(RetTy))
-                                        .addImm(FuncControl)
-                                        .addUse(GR->getSPIRVTypeID(FuncTy));
-    const MachineOperand *DefOpFunction = &MB.getInstr()->getOperand(0);
-    GR->recordFunctionDefinition(&F, DefOpFunction);
+    MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
+                                 .addDef(FuncVReg)
+                                 .addUse(GR->getSPIRVTypeID(RetTy))
+                                 .addImm(FuncControl)
+                                 .addUse(GR->getSPIRVTypeID(FuncTy));
+    GR->recordFunctionDefinition(&F, &MB.getInstr()->getOperand(0));
   }
 
   // Add OpFunctionParameters.
@@ -459,11 +458,11 @@ void SPIRVCallLowering::SPIRVFunFormalArgs::produceFunArgsInstructions(
       UpdateFTy, RetTy, ArgTypeVRegs, MIRBuilder);
 
   // Emit OpFunction
-  const MachineInstrBuilder &MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
-                                      .addDef(FuncVReg)
-                                      .addUse(GR->getSPIRVTypeID(RetTy))
-                                      .addImm(FuncControl)
-                                      .addUse(GR->getSPIRVTypeID(FuncTy));
+  MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
+                               .addDef(FuncVReg)
+                               .addUse(GR->getSPIRVTypeID(RetTy))
+                               .addImm(FuncControl)
+                               .addUse(GR->getSPIRVTypeID(FuncTy));
   GR->recordFunctionDefinition(F, &MB.getInstr()->getOperand(0));
 
   // Emit OpFunctionParameter's
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f7cf3e1936e33..52eeb8a523e6f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1562,7 +1562,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
         MachineRegisterInfo *MRI = MIRBuilder.getMRI();
         Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
         MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
-        const MachineInstrBuilder &MB =
+        MachineInstrBuilder MB =
             BuildMI(BB, I, I.getDebugLoc(),
                     TII.get(SPIRV::OpConstantFunctionPointerINTEL))
                 .addDef(NewReg)



More information about the llvm-commits mailing list