[llvm] [SPIR-V] Add pass to fixup global variable AS (PR #124591)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 31 03:21:41 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-backend-spir-v

Author: Nathan Gauër (Keenuts)

<details>
<summary>Changes</summary>

SPIR-V has storage classes, which behave similarly to LLVM address spaces.
At the HLSL language level, this concept is not well defined, so HLSL accepts some operations that have no SPIR-V equivalent. For example:

```
static int global;

int& get_ref(int x, int &local) {
  return x ? global : local;
}

void main() {
  int local;
  int& ref = get_ref(0, local);
}
```

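In this example, the function `get_ref` cannot be emitted in SPIR-V because its return value is a pointer whose storage class is either Function (for `local`) or Private (for `global`), depending on `x`. A SPIR-V pointer type carries exactly one storage class.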
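To make the typing problem concrete, here is a minimal C++ model, assuming only that each pointer type carries exactly one storage class. `StorageClass` and `selectResultClass` are hypothetical names, not LLVM or SPIR-V APIs. The conditional in `get_ref` mixes a Private pointer and a Function pointer, so it has no valid result type:

```cpp
#include <optional>

// Toy model (not an LLVM or SPIR-V API): every pointer type is tagged with
// exactly one storage class.
enum class StorageClass { Function, Private };

// Returns the storage class of `cond ? a : b` when both operands agree, or
// std::nullopt when no single pointer type can describe the result.
std::optional<StorageClass> selectResultClass(StorageClass a, StorageClass b) {
  if (a == b)
    return a;
  return std::nullopt;
}
```

Once `local` is moved to the Private storage class as well, both operands agree and the conditional becomes well-typed.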

One solution is to force-inline the function and then fix up the pointer address spaces, which are then lowered to SPIR-V storage classes. However, this is impossible when building a SPIR-V library.

The real solution is to use a single address space for both globals and locals, meaning all variables must become local, or vice versa.
Moving all variables to the local scope is not possible, once again because of the library target.
The remaining option is to move all variables to the global scope. This is possible only because Vulkan disallows static recursion.

This commit adds a new backend pass to handle these fixups. It works at the IR level and moves all local variables to the global scope. It also rewrites all pointers to the default address space (AS 0) into `ptr addrspace(10)`, the address space mapped to the SPIR-V `Private` storage class.
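At the module level, the transformation can be pictured with a toy model. This is a sketch over a drastically simplified IR: `GlobalVar`, `Func`, `Module`, and `fixAddressSpaces` are hypothetical names, not LLVM classes, and operand rewriting is not modeled. Address space 10 stands for `Private`, as in the patch. Unlike the patch, which names each new global `<function>.local` and leaves name uniquing to LLVM, the model appends an index itself:

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy IR, not LLVM's API. AS 0 is LLVM's default address space;
// AS 10 plays the role of the SPIR-V `Private` storage class.
constexpr unsigned DefaultAS = 0;
constexpr unsigned PrivateAS = 10;

struct GlobalVar {
  std::string Name;
  unsigned AddrSpace;
  bool Internal;
};

struct Func {
  std::string Name;
  std::vector<std::string> Allocas; // names of stack (local) variables
};

struct Module {
  std::vector<GlobalVar> Globals;
  std::vector<Func> Functions;
};

// Moves every global in the default AS to `Private`, then replaces each local
// with an internal `Private` global. Globals in other address spaces (e.g.
// resources) are left alone.
void fixAddressSpaces(Module &M) {
  for (GlobalVar &GV : M.Globals)
    if (GV.AddrSpace == DefaultAS)
      GV.AddrSpace = PrivateAS;

  for (Func &F : M.Functions) {
    for (std::size_t I = 0; I < F.Allocas.size(); ++I) {
      std::string Name = F.Name + ".local";
      if (I != 0)
        Name += "." + std::to_string(I); // keep names unique in this model
      M.Globals.push_back({Name, PrivateAS, /*Internal=*/true});
    }
    F.Allocas.clear(); // no stack variables remain
  }
}
```

Leaving other address spaces untouched mirrors the early return in `rewriteGlobalVariable` for globals that are not in address space 0.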

Function address spaces are left untouched: there are no indirect jumps, hence no pointers to code, only pointers to data.

---

Patch is 52.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/124591.diff


24 Files Affected:

- (modified) llvm/lib/Target/SPIRV/CMakeLists.txt (+1) 
- (modified) llvm/lib/Target/SPIRV/SPIRV.h (+1) 
- (added) llvm/lib/Target/SPIRV/SPIRVFixAddressSpace.cpp (+593) 
- (modified) llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp (+8) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll (+2-2) 
- (modified) llvm/test/CodeGen/SPIRV/logical-access-chain.ll (+11-4) 
- (modified) llvm/test/CodeGen/SPIRV/logical-struct-access.ll (+8-4) 
- (added) llvm/test/CodeGen/SPIRV/passes/SPIRVFixAddressSpace-call.ll (+31) 
- (added) llvm/test/CodeGen/SPIRV/passes/SPIRVFixAddressSpace-initialization.ll (+25) 
- (added) llvm/test/CodeGen/SPIRV/passes/SPIRVFixAddressSpace-object.ll (+60) 
- (added) llvm/test/CodeGen/SPIRV/passes/SPIRVFixAddressSpace-ptr-ptr.ll (+31) 
- (added) llvm/test/CodeGen/SPIRV/passes/SPIRVFixAddressSpace-simple-local.ll (+42) 
- (added) llvm/test/CodeGen/SPIRV/passes/SPIRVFixAddressSpace-simple.ll (+30) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/basic-phi.ll (+1-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.cond-op.ll (+23-15) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.for.plain.ll (+7-7) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/condition-linear.ll (+2-2) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/logical-or.ll (+5-4) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/loop-continue-split.ll (+5-4) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll (+3-2) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll (+2-2) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll (+3-3) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/return-early.ll (+2-2) 
- (modified) llvm/test/CodeGen/SPIRV/transcoding/GlobalFunAnnotate.ll (+7-1) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index efdd8c8d24fbd5..2f6870195581f8 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -34,6 +34,7 @@ add_llvm_target(SPIRVCodeGen
   SPIRVMetadata.cpp
   SPIRVModuleAnalysis.cpp
   SPIRVStructurizer.cpp
+  SPIRVFixAddressSpace.cpp
   SPIRVPreLegalizer.cpp
   SPIRVPreLegalizerCombiner.cpp
   SPIRVPostLegalizer.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 6d00a046ff7caa..2353d04b3df082 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -20,6 +20,7 @@ class InstructionSelector;
 class RegisterBankInfo;
 
 ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
+ModulePass *createSPIRVFixAddressSpacePass();
 FunctionPass *createSPIRVStructurizerPass();
 FunctionPass *createSPIRVMergeRegionExitTargetsPass();
 FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
diff --git a/llvm/lib/Target/SPIRV/SPIRVFixAddressSpace.cpp b/llvm/lib/Target/SPIRV/SPIRVFixAddressSpace.cpp
new file mode 100644
index 00000000000000..80bc60d917ab93
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVFixAddressSpace.cpp
@@ -0,0 +1,593 @@
+//===-- SPIRVFixAddressSpace.cpp ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This pass is Vulkan-specific, as Logical SPIR-V doesn't support pointer
+// casts.
+//
+// In LLVM IR, global and local variables are by default in the same address
+// space. In SPIR-V, global and local variables live in different address
+// spaces (storage classes in SPIR-V parlance).
+//
+// This means the following function cannot be lowered to SPIR-V:
+//   static int global = 0;
+//   int& return_ref(int& ref, bool select) {
+//    return select ? ref : global;
+//   }
+//
+// A solution is to force inline, but this would prevent us from emitting SPIR-V
+// libraries. Another solution is to move all globals to local variables, but
+// this also blocks libraries. The last solution is to replace all local
+// variables with global variables. This is possible because Vulkan SPIR-V
+// completely forbids static recursion.
+//
+// This pass replaces every alloca instruction with a new global variable.
+// In addition, it places all such variables in the `Private` address space.
+//
+// After this pass, no variable or pointer should reference the default address
+// space.
+//
+// Note:
+//  LLVM IR has address spaces for functions, but SPIR-V doesn't. In addition,
+//  Vulkan disallows function pointers and indirect jumps, meaning we can never
+//  have a pointer storing a function address. For this reason, functions are
+//  left in the default address space, but all pointer operands to the default
+//  AS are rewritten to point to the AS `Private`. This kind of blind rewrite
+//  simplifies the code, but can only work with those assumptions.
+//===----------------------------------------------------------------------===//
+
+#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/CodeGen/IntrinsicLowering.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/NoFolder.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/PassRegistry.h"
+#include "llvm/Transforms/Utils.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
+#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
+#include <queue>
+#include <stack>
+#include <type_traits>
+#include <unordered_set>
+
+using namespace llvm;
+using namespace SPIRV;
+
+namespace llvm {
+void initializeSPIRVFixAddressSpacePass(PassRegistry &);
+} // namespace llvm
+
+namespace {
+
+constexpr unsigned GlobalAddressSpace =
+    storageClassToAddressSpace(SPIRV::StorageClass::Private);
+
+// Returns true if the given type or any subtype contains a pointer to the
+// default address space.
+bool typeRequiresConversion(Type *T) {
+  if (T->isPointerTy())
+    return T->getPointerAddressSpace() == 0;
+
+  if (T->isArrayTy())
+    return typeRequiresConversion(T->getArrayElementType());
+  if (T->isStructTy()) {
+    for (unsigned I = 0; I < T->getStructNumElements(); ++I)
+      if (typeRequiresConversion(T->getStructElementType(I)))
+        return true;
+    return false;
+  }
+  if (FunctionType *FT = dyn_cast<FunctionType>(T)) {
+    if (typeRequiresConversion(FT->getReturnType()))
+      return true;
+    for (Type *TP : FT->params())
+      if (typeRequiresConversion(TP))
+        return true;
+    return false;
+  }
+
+  return false;
+}
+
+class SPIRVFixAddressSpace : public InstVisitor<SPIRVFixAddressSpace>,
+                             public ModulePass {
+
+  // Types are supposed to be unique. When using Type::get, there is a lookup to
+  // only create types on demand. StructType::get only allows creating literal
+  // structs, meaning we would lose the type name. This forces us to use
+  // StructType::create, which doesn't deduplicate. This means we must bring
+  // our own old-type/new-type map to prevent creating distinct types when we
+  // shouldn't.
+  std::unordered_map<Type *, Type *> ConvertedTypes;
+
+private:
+  // If the passed type or subtype contains a pointer to AS 0, get or create a
+  // new type with all pointers changed to address space `Private`. The
+  // functions below are overloads depending on the input type.
+  PointerType *convertPointerType(PointerType *PT);
+  ArrayType *convertArrayType(ArrayType *AT);
+  StructType *convertStructType(StructType *ST);
+  VectorType *convertVectorType(VectorType *VT);
+  FunctionType *convertFunctionType(FunctionType *FT);
+
+  // Get or create a new type if `T` or any of its subtypes is a pointer to
+  // address space 0. All pointers in the returned type point to the address
+  // space `Private`. If `T` was already converted once, the cached converted
+  // type is returned and no additional type is created.
+  Type *convertType(Type *T);
+
+  // If the type of `C`, or any of its subtypes, contains a pointer to address
+  // space 0, returns a new constant with a fixed type.
+  Constant *convertConstant(Constant *C);
+
+  // See `convertConstant`. Those functions are overloads to handle specific
+  // constant types.
+  Constant *convertConstantAggregate(ConstantAggregate *CA);
+  ConstantData *convertConstantData(ConstantData *CD);
+
+  // If the passed global variable is in the default address space, replace it
+  // with a global in the `Private` address space (=SPIR-V storage class). Does
+  // not modify globals in a different address space (resources for ex). Returns
+  // true if the global was replaced.
+  bool rewriteGlobalVariable(Module &M, GlobalVariable *GV);
+
+  // Modifies F by replacing each alloca with a global variable.
+  // This function requires the alloca allocation size to be static:
+  //  - Vulkan doesn't support VLA in local variables. (See
+  //  VUID-StandaloneSpirv-OpTypeRuntimeArray-04680).
+  //  - HLSL doesn't allow VLA.
+  // Returns true if the function was modified.
+  bool replaceAlloca(Function &F);
+
+  // Mutate the types and operands of `F` to make sure no referenced type has a
+  // pointer to the default address space. This function does not propagate the
+  // type changes, hence if not used carefully, this could generate invalid IR.
+  // Returns true if the function was modified.
+  bool blindlyMutateTypes(Function &F);
+
+  // Modifies any GEP instruction in the given function to only use
+  // ptr addrspace(10) instead of pointers to the default address space.
+  // Returns true if the function was modified.
+  bool rewriteGEP(Function &F);
+
+  // Checks all instructions in F and makes sure no pointer to the default
+  // address space or local variable remains. This function assumes all other
+  // functions/globals undergo the same treatment. Not calling this function on
+  // all the module functions could yield invalid IR. Returns true if the
+  // function has been modified.
+  bool fixInstructions(Function &F);
+
+  // Checks if any type/subtype in the return value or a parameter is a ptr to
+  // the default address space. If such pointer is found, recreate the function
+  // replacing it with a `ptr addrspace(10)`. This function replaces all uses of
+  // the function return value/argument/declaration with the new version, but
+  // does not propagate changes further. If the remaining instructions are not
+  // cleaned up, this can produce invalid IR. Returns true if the function has
+  // been replaced.
+  bool rewriteFunctionParameters(Module &M, Function *F);
+
+  // Fixes all functions in the given module.
+  // If the module is modified, new global variables may have been added.
+  // This function only modifies global variables referenced by at least one
+  // function.
+  // Returns true if the module was modified.
+  bool fixFunctions(Module &M);
+
+  // Fixes all the globals in the given module, even if not referenced by any
+  // function.
+  bool fixGlobals(Module &M);
+
+public:
+  static char ID;
+
+  SPIRVFixAddressSpace() : ModulePass(ID) {
+    initializeSPIRVFixAddressSpacePass(*PassRegistry::getPassRegistry());
+  }
+
+  virtual bool runOnModule(Module &M) override;
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    ModulePass::getAnalysisUsage(AU);
+  }
+};
+
+PointerType *SPIRVFixAddressSpace::convertPointerType(PointerType *PT) {
+  if (PT->getPointerAddressSpace() != 0)
+    return PT;
+
+  auto It = ConvertedTypes.find(PT);
+  if (It != ConvertedTypes.end())
+    return cast<PointerType>(It->second);
+
+  PointerType *NewType = PointerType::get(PT->getContext(), GlobalAddressSpace);
+  ConvertedTypes.emplace(PT, NewType);
+  return NewType;
+}
+
+ArrayType *SPIRVFixAddressSpace::convertArrayType(ArrayType *AT) {
+  if (!typeRequiresConversion(AT))
+    return AT;
+
+  auto It = ConvertedTypes.find(AT);
+  if (It != ConvertedTypes.end())
+    return cast<ArrayType>(It->second);
+
+  Type *ElementType = convertType(AT->getElementType());
+  ArrayType *NewType = ArrayType::get(ElementType, AT->getNumElements());
+  ConvertedTypes.emplace(AT, NewType);
+  return NewType;
+}
+
+StructType *SPIRVFixAddressSpace::convertStructType(StructType *ST) {
+  if (!typeRequiresConversion(ST))
+    return ST;
+  std::vector<Type *> Elements;
+  Elements.resize(ST->getNumElements());
+  for (unsigned I = 0; I < Elements.size(); ++I)
+    Elements[I] = convertType(ST->getElementType(I));
+
+  auto It = ConvertedTypes.find(ST);
+  if (It != ConvertedTypes.end())
+    return cast<StructType>(It->second);
+
+  if (!ST->hasName())
+    return StructType::get(ST->getContext(), Elements);
+
+  std::string OldName = ST->getName().str();
+  ST->setName(OldName + ".old");
+
+  StructType *NewType = StructType::create(ST->getContext(), Elements, OldName);
+  ConvertedTypes.emplace(ST, NewType);
+  return NewType;
+}
+
+VectorType *SPIRVFixAddressSpace::convertVectorType(VectorType *VT) {
+  if (!typeRequiresConversion(VT))
+    return VT;
+
+  auto It = ConvertedTypes.find(VT);
+  if (It != ConvertedTypes.end())
+    return cast<VectorType>(It->second);
+
+  Type *ElementType = convertType(VT->getElementType());
+  VectorType *NewType = VectorType::get(ElementType, VT->getElementCount());
+  ConvertedTypes.emplace(VT, NewType);
+  return NewType;
+}
+
+FunctionType *SPIRVFixAddressSpace::convertFunctionType(FunctionType *FT) {
+  if (!typeRequiresConversion(FT))
+    return FT;
+
+  auto It = ConvertedTypes.find(FT);
+  if (It != ConvertedTypes.end())
+    return cast<FunctionType>(It->second);
+
+  Type *ReturnType = convertType(FT->getReturnType());
+  std::vector<Type *> Params;
+  Params.reserve(FT->getNumParams());
+  for (Type *P : FT->params())
+    Params.push_back(convertType(P));
+
+  FunctionType *NewType = FunctionType::get(ReturnType, Params, FT->isVarArg());
+  ConvertedTypes.emplace(FT, NewType);
+  return NewType;
+}
+
+// Replace all references of the default address space in `T` with
+// the `Private` SPIR-V address space, recreating the type if required.
+// Returns the new type if recreated, `T` otherwise.
+Type *SPIRVFixAddressSpace::convertType(Type *T) {
+  if (PointerType *PT = dyn_cast<PointerType>(T))
+    return convertPointerType(PT);
+
+  if (ArrayType *AT = dyn_cast<ArrayType>(T))
+    return convertArrayType(AT);
+
+  if (VectorType *VT = dyn_cast<VectorType>(T))
+    return convertVectorType(VT);
+
+  if (StructType *ST = dyn_cast<StructType>(T))
+    return convertStructType(ST);
+
+  if (FunctionType *FT = dyn_cast<FunctionType>(T))
+    return convertFunctionType(FT);
+
+  if (isa<TargetExtType>(T))
+    return T;
+
+  if (T == Type::getTokenTy(T->getContext()))
+    return T;
+
+  if (T == Type::getLabelTy(T->getContext()))
+    return T;
+
+  // TypedPointerType: not implemented on purpose.
+
+  // Make sure pointers & vectors are handled above.
+  // All other single-value types don't require address space conversion.
+  assert(!T->isPointerTy() && !T->isVectorTy());
+  if (T->isSingleValueType())
+    return T;
+
+  llvm_unreachable("Unsupported type for address space fixup.");
+}
+
+Constant *
+SPIRVFixAddressSpace::convertConstantAggregate(ConstantAggregate *CA) {
+  Type *NewType = convertType(CA->getType());
+  std::vector<Constant *> Elements;
+  Elements.resize(CA->getNumOperands());
+  for (unsigned I = 0; I < CA->getNumOperands(); ++I)
+    Elements[I] = convertConstant(cast<Constant>(CA->getOperand(I)));
+
+  if (isa<ConstantArray>(CA))
+    return ConstantArray::get(cast<ArrayType>(NewType), Elements);
+  else if (isa<ConstantStruct>(CA))
+    return ConstantStruct::get(cast<StructType>(NewType), Elements);
+  return ConstantVector::get(Elements);
+}
+
+ConstantData *SPIRVFixAddressSpace::convertConstantData(ConstantData *CD) {
+  if (!typeRequiresConversion(CD->getType()))
+    return CD;
+
+  if (ConstantPointerNull *CPN = dyn_cast<ConstantPointerNull>(CD))
+    return ConstantPointerNull::get(convertPointerType(CPN->getType()));
+  report_fatal_error("Unsupported ConstantData type.");
+}
+
+// Replace all references of the default address space in `C` with the
+// SPIR-V private address space, recreating the constant, and/or modifying the
+// type if required.
+Constant *SPIRVFixAddressSpace::convertConstant(Constant *C) {
+  if (!typeRequiresConversion(C->getType()))
+    return C;
+
+  if (ConstantAggregate *CA = dyn_cast<ConstantAggregate>(C))
+    return convertConstantAggregate(CA);
+  if (ConstantData *CD = dyn_cast<ConstantData>(C))
+    return convertConstantData(CD);
+  llvm_unreachable("Unsupported constant type.");
+}
+
+bool SPIRVFixAddressSpace::rewriteGlobalVariable(Module &M,
+                                                 GlobalVariable *GV) {
+  if (GV->getAddressSpace() != 0)
+    return false;
+
+  Type *NewType = GV->getValueType();
+  if (typeRequiresConversion(GV->getValueType()))
+    NewType = convertType(NewType);
+
+  std::string OldName = GV->getName().str();
+  GV->setName(OldName + ".dead");
+  GlobalVariable *NewGV = new GlobalVariable(
+      M, NewType,
+      /* isConstant= */ false, GV->getLinkage(),
+      convertConstant(GV->getInitializer()), OldName,
+      /* insertBefore= */ GV, GV->getThreadLocalMode(), GlobalAddressSpace);
+
+  std::vector<User *> ToFix(GV->user_begin(), GV->user_end());
+  for (auto *User : ToFix) {
+    if (Constant *C = dyn_cast<Constant>(User))
+      C->handleOperandChange(GV, NewGV);
+    else
+      User->replaceUsesOfWith(GV, NewGV);
+  }
+  M.eraseGlobalVariable(GV);
+  return true;
+}
+
+bool SPIRVFixAddressSpace::replaceAlloca(Function &F) {
+  std::unordered_set<Instruction *> DeadInstructions;
+
+  for (auto &BB : F) {
+    for (auto &I : BB) {
+      AllocaInst *AI = dyn_cast<AllocaInst>(&I);
+      if (!AI)
+        continue;
+
+      // Vulkan doesn't support VLA in local variables. (See
+      // VUID-StandaloneSpirv-OpTypeRuntimeArray-04680). HLSL doesn't allow VLA,
+      // meaning we should not encounter this for now, but if another frontend
+      // is used, we may hit this case.
+      assert(isa<ConstantInt>(AI->getArraySize()));
+
+      Type *NewType = convertType(AI->getAllocatedType());
+      GlobalVariable *NewGV = new GlobalVariable(
+          *F.getParent(), NewType,
+          /* isConstant= */ false, GlobalValue::LinkageTypes::InternalLinkage,
+          Constant::getNullValue(NewType), F.getName() + ".local",
+          /* insertBefore= */ nullptr,
+          GlobalValue::ThreadLocalMode::NotThreadLocal, GlobalAddressSpace);
+
+      std::vector<User *> ToFix(AI->user_begin(), AI->user_end());
+      for (auto *User : ToFix)
+        User->replaceUsesOfWith(AI, NewGV);
+      DeadInstructions.insert(AI);
+    }
+  }
+
+  for (auto *I : DeadInstructions)
+    I->eraseFromParent();
+  return DeadInstructions.size() != 0;
+}
+
+bool SPIRVFixAddressSpace::blindlyMutateTypes(Function &F) {
+  bool Modified = false;
+
+  for (auto &BB : F) {
+    for (auto &I : BB) {
+      for (auto &Op : I.operands()) {
+        if (isa<Function>(Op.get()) || isa<BlockAddress>(Op.get()))
+          continue;
+
+        Type *NewType = convertType(Op->getType());
+        Modified |= NewType != Op->getType();
+        Op->mutateType(NewType);
+      }
+
+      if (typeRequiresConversion(I.getType())) {
+        Type *NewType = convertType(I.getType());
+        I.mutateType(NewType);
+        Modified = true;
+      }
+    }
+  }
+
+  return Modified;
+}
+
+bool SPIRVFixAddressSpace::rewriteGEP(Function &F) {
+  std::unordered_set<Instruction *> DeadInstructions;
+
+  for (auto &BB : F) {
+    for (auto &I : BB) {
+      auto *GEP = dyn_cast<GetElementPtrInst>(&I);
+      if (!GEP)
+        continue;
+
+      Type *SourceType = convertType(GEP->getSourceElementType());
+
+      IRBuilder<NoFolder> B(GEP->getParent(), NoFolder());
+      B.SetInsertPoint(GEP);
+      std::vector<Value *> Indices(GEP->idx_begin(), GEP->idx_end());
+      auto *NewInstr =
+          B.CreateGEP(SourceType, GEP->getPointerOperand(), Indices,
+                      GEP->getName(), GEP->getNoWrapFlags());
+      GEP->replaceAllUsesWith(NewInstr);
+      DeadInstructions.insert(GEP);
+    }
+  }
+
+  for (auto *I : DeadInstructions)
+    I->eraseFromParent();
+  return DeadInstructions.size() != 0;
+}
+
+bool SPIRVFixAddressSpace::fixInstructions(Function &F) {
+  bool Modified = false;
+
+  Modified |= replaceAlloca(F);
+  Modified |= blindlyMutateTypes(F);
+  Modified |= rewriteGEP(F);
+
+  return Modified;
+}
+
+bool SPIRVFixAddressSpace::rewriteFunctionParameters(Module &M, Function *F) {
+  if (F->isDeclaration())
+    return false;
+
+  FunctionType *NewType = convertFunctionType(F->getFunctionType());
+  if (NewType == F->getFunctionType())
+    return false;
+
+  std::string OldName = F->getName().str();
+  F->setName(OldName + ".dead");
+  Function *NewFunction = Function::Create(NewType, F->getLinkage(),
+                                           /* AddressSpace= */ 0, OldName);
+  NewFunction->copyAttributesFrom(F);
+  NewFunction->copyMetadata(F, 0);
+  NewFunction->setIsNewDbgInfoFormat(F->IsNewDbgInfoFormat);
+  M.getFunctionList().insert(F->getIterator(), NewFunction);
+
+  std::vector<User *> ToFix(F->user_begin(), F->user_end());
+  for (auto *User : ToFix) {
+    User->replaceUsesOfWith(F, NewFunction);
+    CallBase *CB = dyn_cast<CallBase>(User);
+    if (!CB)
+      continue;
+    CB->mutateFunctionType(NewType);
+  }
+
+  for (size_t I = 0; I < NewFunction->arg_size(); ++I) {
+    Argument *OldArgument = F->getArg(I);
+    Argument *NewArgument = NewFunction->getArg(I);
+    NewArgument->set...
[truncated]

``````````
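The central mechanism in the diff is the `ConvertedTypes` map: because `StructType::create` does not deduplicate, converting the same named struct twice would otherwise yield two distinct types. Below is a minimal sketch of that memoized conversion. It uses a hypothetical toy type representation (`Ty`, `TypeConverter`) rather than LLVM's classes, and only models integers, opaque pointers, arrays, and structs:

```cpp
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy type representation, not LLVM's `Type` hierarchy.
struct Ty {
  enum Kind { Int, Ptr, Array, Struct } K;
  unsigned AddrSpace = 0;        // Ptr only; pointers are opaque, as in LLVM
  std::vector<const Ty *> Elems; // Array (single element type) or Struct
  unsigned Count = 0;            // Array only
  std::string Name;              // Struct only; empty for literal structs
};

// Rewrites types so that no pointer targets address space 0, caching each
// result like the patch's `ConvertedTypes` map.
class TypeConverter {
  std::vector<std::unique_ptr<Ty>> Storage; // owns every type created here
  std::map<const Ty *, const Ty *> Cache;   // original -> converted

public:
  static constexpr unsigned PrivateAS = 10;

  // Always allocates a new, distinct type (like StructType::create).
  const Ty *create(Ty T) {
    Storage.push_back(std::make_unique<Ty>(std::move(T)));
    return Storage.back().get();
  }

  static bool needsConversion(const Ty *T) {
    if (T->K == Ty::Ptr)
      return T->AddrSpace == 0;
    for (const Ty *E : T->Elems)
      if (needsConversion(E))
        return true;
    return false;
  }

  const Ty *convert(const Ty *T) {
    if (!needsConversion(T))
      return T; // nothing to rewrite: reuse the original type
    auto It = Cache.find(T);
    if (It != Cache.end())
      return It->second; // already converted once: reuse that result
    Ty New = *T;
    if (New.K == Ty::Ptr)
      New.AddrSpace = PrivateAS;
    for (const Ty *&E : New.Elems)
      E = convert(E);
    const Ty *Result = create(std::move(New));
    Cache.emplace(T, Result);
    return Result;
  }
};
```

This sketch looks up the cache before converting element types. The patch's `convertStructType` converts the elements first and then checks the cache; the result is the same, since each element conversion is itself cached, but the lookup-first order skips redundant work on a cache hit.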

</details>


https://github.com/llvm/llvm-project/pull/124591

