[llvm] [SPIR-V] Improve type inference in SPIR-V Backend for opaque pointers (PR #86283)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 22 06:31:34 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR improves type inference in the SPIR-V Backend for opaque pointers, accounting for the case where a chain of function calls makes it possible to deduce formal parameter types from actual arguments. The attached test demonstrates the case.


---
Full diff: https://github.com/llvm/llvm-project/pull/86283.diff


4 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+77-46) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+8-4) 
- (added) llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll (+52) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 458af9229ed7b1..5828db6669ff18 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -92,6 +92,9 @@ class SPIRVEmitIntrinsics
   void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
   void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
   void processParamTypes(Function *F, IRBuilder<> &B);
+  Type *deduceFunParamType(Function *F, unsigned OpIdx);
+  Type *deduceFunParamType(Function *F, unsigned OpIdx,
+                           std::unordered_set<Function *> &FVisited);
 
 public:
   static char ID;
@@ -169,6 +172,10 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
 static Type *deduceElementTypeHelper(Value *I,
                                      std::unordered_set<Value *> &Visited,
                                      DenseMap<Value *, Type *> &DeducedElTys) {
+  // allow to pass nullptr as an argument
+  if (!I)
+    return nullptr;
+
   // maybe already known
   auto It = DeducedElTys.find(I);
   if (It != DeducedElTys.end())
@@ -182,15 +189,20 @@ static Type *deduceElementTypeHelper(Value *I,
   // fallback value in case when we fail to deduce a type
   Type *Ty = nullptr;
   // look for known basic patterns of type inference
-  if (auto *Ref = dyn_cast<AllocaInst>(I))
+  if (auto *Ref = dyn_cast<AllocaInst>(I)) {
     Ty = Ref->getAllocatedType();
-  else if (auto *Ref = dyn_cast<GetElementPtrInst>(I))
+  } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
     Ty = Ref->getResultElementType();
-  else if (auto *Ref = dyn_cast<GlobalValue>(I))
+  } else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
     Ty = Ref->getValueType();
-  else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I))
+  } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
     Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
                                  DeducedElTys);
+  } else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
+    if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy();
+        isPointerTy(Src) && isPointerTy(Dest))
+      Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited, DeducedElTys);
+  }
 
   // remember the found relationship
   if (Ty)
@@ -795,61 +807,80 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
   }
 }
 
-void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
-  DenseMap<unsigned, Argument *> Args;
-  unsigned i = 0;
-  for (Argument &Arg : F->args()) {
-    if (isUntypedPointerTy(Arg.getType()) &&
-        DeducedElTys.find(&Arg) == DeducedElTys.end() &&
-        !HasPointeeTypeAttr(&Arg))
-      Args[i] = &Arg;
-    i++;
-  }
-  if (Args.size() == 0)
-    return;
+Type *SPIRVEmitIntrinsics::deduceFunParamType(Function *F, unsigned OpIdx) {
+  std::unordered_set<Function *> FVisited;
+  return deduceFunParamType(F, OpIdx, FVisited);
+}
+
+Type *SPIRVEmitIntrinsics::deduceFunParamType(
+    Function *F, unsigned OpIdx, std::unordered_set<Function *> &FVisited) {
+  // maybe a cycle
+  if (FVisited.find(F) != FVisited.end())
+    return nullptr;
+  FVisited.insert(F);
 
-  // Args contains opaque pointers without element type definition
-  B.SetInsertPointPastAllocas(F);
   std::unordered_set<Value *> Visited;
+  SmallVector<std::pair<Function *, unsigned>> Lookup;
+  // search in function's call sites
   for (User *U : F->users()) {
     CallInst *CI = dyn_cast<CallInst>(U);
-    if (!CI)
+    if (!CI || OpIdx >= CI->arg_size())
       continue;
-    for (unsigned OpIdx = 0; OpIdx < CI->arg_size() && Args.size() > 0;
-         OpIdx++) {
-      auto It = Args.find(OpIdx);
-      Argument *Arg = It == Args.end() ? nullptr : It->second;
-      if (!Arg)
-        continue;
-      Value *OpArg = CI->getArgOperand(OpIdx);
-      if (!isPointerTy(OpArg->getType()))
+    Value *OpArg = CI->getArgOperand(OpIdx);
+    if (!isPointerTy(OpArg->getType()))
+      continue;
+    // maybe we already know operand's element type
+    if (auto It = DeducedElTys.find(OpArg); It != DeducedElTys.end())
+      return It->second;
+    // search in actual parameter's users
+    for (User *OpU : OpArg->users()) {
+      Instruction *Inst = dyn_cast<Instruction>(OpU);
+      if (!Inst || Inst == CI)
         continue;
-      // maybe we already know the operand's element type
-      auto DeducedIt = DeducedElTys.find(OpArg);
-      Type *ElemTy =
-          DeducedIt == DeducedElTys.end() ? nullptr : DeducedIt->second;
-      if (!ElemTy) {
-        for (User *OpU : OpArg->users()) {
-          if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
-            Visited.clear();
-            ElemTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
-            if (ElemTy)
-              break;
-          }
-        }
+      Visited.clear();
+      if (Type *Ty = deduceElementTypeHelper(Inst, Visited, DeducedElTys))
+        return Ty;
+    }
+    // check if it's a formal parameter of the outer function
+    if (!CI->getParent() || !CI->getParent()->getParent())
+      continue;
+    Function *OuterF = CI->getParent()->getParent();
+    if (FVisited.find(OuterF) != FVisited.end())
+      continue;
+    for (unsigned i = 0; i < OuterF->arg_size(); ++i) {
+      if (OuterF->getArg(i) == OpArg) {
+        Lookup.push_back(std::make_pair(OuterF, i));
+        break;
       }
-      if (ElemTy) {
-        unsigned AddressSpace = getPointerAddressSpace(Arg->getType());
+    }
+  }
+
+  // search in function parameters
+  for (auto &Pair : Lookup) {
+    if (Type *Ty = deduceFunParamType(Pair.first, Pair.second, FVisited))
+      return Ty;
+  }
+
+  return nullptr;
+}
+
+void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+  B.SetInsertPointPastAllocas(F);
+  DenseMap<Argument *, Type *> Args;
+  for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
+    Argument *Arg = F->getArg(OpIdx);
+    if (isUntypedPointerTy(Arg->getType()) &&
+        DeducedElTys.find(Arg) == DeducedElTys.end() &&
+        !HasPointeeTypeAttr(Arg)) {
+      if (Type *ElemTy = deduceFunParamType(F, OpIdx)) {
         CallInst *AssignPtrTyCI = buildIntrWithMD(
             Intrinsic::spv_assign_ptr_type, {Arg->getType()},
-            Constant::getNullValue(ElemTy), Arg, {B.getInt32(AddressSpace)}, B);
+            Constant::getNullValue(ElemTy), Arg,
+            {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
         DeducedElTys[AssignPtrTyCI] = ElemTy;
         DeducedElTys[Arg] = ElemTy;
-        Args.erase(It);
       }
     }
-    if (Args.size() == 0)
-      break;
   }
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 42f8397a3023b1..ee52163a5d127f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -479,6 +479,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
     GVar = M->getGlobalVariable(Name);
     if (GVar == nullptr) {
       const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
+      // Module takes ownership of the global var.
       GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
                                 GlobalValue::ExternalLinkage, nullptr,
                                 Twine(Name));
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 5bb8f6084f9671..39228e2196b3af 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -499,6 +499,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
     assert(I.getOperand(1).isReg() && I.getOperand(2).isReg());
     Register GV = I.getOperand(1).getReg();
     MachineRegisterInfo::def_instr_iterator II = MRI->def_instr_begin(GV);
+    (void)II;
     assert(((*II).getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
             (*II).getOpcode() == TargetOpcode::COPY ||
             (*II).getOpcode() == SPIRV::OpVariable) &&
@@ -771,10 +772,13 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
     SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
         ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
     // TODO: check if we have such GV, add init, use buildGlobalVariable.
-    Type *LLVMArrTy = ArrayType::get(
-        IntegerType::get(GR.CurMF->getFunction().getContext(), 8), Num);
-    GlobalVariable *GV =
-        new GlobalVariable(LLVMArrTy, true, GlobalValue::InternalLinkage);
+    Function &CurFunction = GR.CurMF->getFunction();
+    Type *LLVMArrTy =
+        ArrayType::get(IntegerType::get(CurFunction.getContext(), 8), Num);
+    // Module takes ownership of the global var.
+    GlobalVariable *GV = new GlobalVariable(*CurFunction.getParent(), LLVMArrTy,
+                                            true, GlobalValue::InternalLinkage,
+                                            Constant::getNullValue(LLVMArrTy));
     Register VarReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
     GR.add(GV, GR.CurMF, VarReg);
 
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
new file mode 100644
index 00000000000000..703f1e22a0321a
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
@@ -0,0 +1,52 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: OpName %[[ArgCum:.*]] "_arg_cum"
+; CHECK-SPIRV-DAG: OpName %[[FunTest:.*]] "test"
+; CHECK-SPIRV-DAG: OpName %[[Addr:.*]] "addr"
+; CHECK-SPIRV-DAG: OpName %[[StubObj:.*]] "stub_object"
+; CHECK-SPIRV-DAG: OpName %[[MemOrder:.*]] "mem_order"
+; CHECK-SPIRV-DAG: OpName %[[FooStub:.*]] "foo_stub"
+; CHECK-SPIRV-DAG: OpName %[[FooObj:.*]] "foo_object"
+; CHECK-SPIRV-DAG: OpName %[[FooMemOrder:.*]] "mem_order"
+; CHECK-SPIRV-DAG: OpName %[[FooFunc:.*]] "foo"
+; CHECK-SPIRV-DAG: %[[TyLong:.*]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[TyVoid:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[TyPtrLong:.*]] = OpTypePointer CrossWorkgroup %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[TyFunPtrLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrLong]]
+; CHECK-SPIRV-DAG: %[[TyGenPtrLong:.*]] = OpTypePointer Generic %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[TyFunGenPtrLongLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyGenPtrLong]] %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[Const3:.*]] = OpConstant %[[TyLong]] 3
+; CHECK-SPIRV: %[[FunTest]] = OpFunction %[[TyVoid]] None %[[TyFunPtrLong]]
+; CHECK-SPIRV: %[[ArgCum]] = OpFunctionParameter %[[TyPtrLong]]
+; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooFunc]] %[[Addr]] %[[Const3]]
+; CHECK-SPIRV: %[[FooStub]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
+; CHECK-SPIRV: %[[StubObj]] = OpFunctionParameter %[[TyGenPtrLong]]
+; CHECK-SPIRV: %[[MemOrder]] = OpFunctionParameter %[[TyLong]]
+; CHECK-SPIRV: %[[FooFunc]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
+; CHECK-SPIRV: %[[FooObj]] = OpFunctionParameter %[[TyGenPtrLong]]
+; CHECK-SPIRV: %[[FooMemOrder]] = OpFunctionParameter %[[TyLong]]
+; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooStub]] %[[FooObj]] %[[FooMemOrder]]
+
+define spir_kernel void @test(ptr addrspace(1) noundef align 4 %_arg_cum) {
+entry:
+  %lptr = getelementptr inbounds i32, ptr addrspace(1) %_arg_cum, i64 1
+  %addr = addrspacecast ptr addrspace(1) %lptr to ptr addrspace(4)
+  %object = bitcast ptr addrspace(4) %addr to ptr addrspace(4)
+  call spir_func void @foo(ptr addrspace(4) %object, i32 3)
+  ret void
+}
+
+define void @foo_stub(ptr addrspace(4) noundef %stub_object, i32 noundef %mem_order) {
+entry:
+  %object.addr = alloca ptr addrspace(4)
+  %object.addr.ascast = addrspacecast ptr %object.addr to ptr addrspace(4)
+  store ptr addrspace(4) %stub_object, ptr addrspace(4) %object.addr.ascast
+  ret void
+}
+
+define void @foo(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order) {
+  tail call void @foo_stub(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order)
+  ret void
+}
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/86283


More information about the llvm-commits mailing list