[llvm] 494f672 - [SPIR-V] Prevent type change of GEP results in type inference (#129250)

via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 28 11:55:21 PST 2025


Author: Vyacheslav Levytskyy
Date: 2025-02-28T20:55:14+01:00
New Revision: 494f67282f93f4a5c995434a3530a7a76f3aa63c

URL: https://github.com/llvm/llvm-project/commit/494f67282f93f4a5c995434a3530a7a76f3aa63c
DIFF: https://github.com/llvm/llvm-project/commit/494f67282f93f4a5c995434a3530a7a76f3aa63c.diff

LOG: [SPIR-V] Prevent type change of GEP results in type inference (#129250)

The following reproducer demonstrates an issue where type inference
assigns an invalid type to GEP results:

```
define spir_kernel void @foo(i1 %fl, i64 %idx, ptr addrspace(1) %dest, ptr addrspace(3) %src) {
  %p1 = getelementptr inbounds i8, ptr addrspace(1) %dest, i64 %idx
  %res = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %p1, ptr addrspace(3) %src, i64 128, i64 1, target("spirv.Event") zeroinitializer)
  ret void
}

declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))
```

Here `OpGroupAsyncCopy` expects `i32*` arguments. Because the GEP result
`%p1` is passed directly to `OpGroupAsyncCopy`, type inference assigns it
the pointee type expected by the call instead of the type produced by the
GEP itself (`i8`).

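This PR fixes the issue by preventing type inference from changing the
types of GEP results: GEP result types are fixed ahead of inference, and
where a call expects a different pointee type, a pointer cast of the GEP
result is inserted instead.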

Added: 
    llvm/test/CodeGen/SPIRV/pointers/ptr-access-chain-type.ll

Modified: 
    llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 5dfba8427258f..d6177058231d9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -646,6 +646,20 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
   Ty = RefTy;
 }
 
+// Returns the element type deduced for the pointer result of the GEP \p Ref:
+// if the source element type is a nested pointer, it is walked through the
+// GEP indices after the first; otherwise the GEP's result element type is used.
+static Type *getGEPType(GetElementPtrInst *Ref) {
+  if (!isNestedPointer(Ref->getSourceElementType()))
+    return Ref->getResultElementType();
+  // TODO: not sure if GetElementPtrInst::getTypeAtIndex() does anything
+  // useful here
+  Type *Ty = Ref->getSourceElementType();
+  for (Use &U : drop_begin(Ref->indices()))
+    Ty = GetElementPtrInst::getTypeAtIndex(Ty, U.get());
+  return Ty;
+}
+
 Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
     Value *I, std::unordered_set<Value *> &Visited, bool UnknownElemTypeI8,
     bool IgnoreKnownType) {
@@ -668,15 +682,7 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
   if (auto *Ref = dyn_cast<AllocaInst>(I)) {
     maybeAssignPtrType(Ty, I, Ref->getAllocatedType(), UnknownElemTypeI8);
   } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
-    // TODO: not sure if GetElementPtrInst::getTypeAtIndex() does anything
-    // useful here
-    if (isNestedPointer(Ref->getSourceElementType())) {
-      Ty = Ref->getSourceElementType();
-      for (Use &U : drop_begin(Ref->indices()))
-        Ty = GetElementPtrInst::getTypeAtIndex(Ty, U.get());
-    } else {
-      Ty = Ref->getResultElementType();
-    }
+    Ty = getGEPType(Ref);
   } else if (auto *Ref = dyn_cast<LoadInst>(I)) {
     Value *Op = Ref->getPointerOperand();
     Type *KnownTy = GR->findDeducedElementType(Op);
@@ -2307,6 +2313,7 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
 
 // Apply types parsed from demangled function declarations.
 void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
+  DenseMap<Function *, CallInst *> Ptrcasts;
   for (auto It : FDeclPtrTys) {
     Function *F = It.first;
     for (auto *U : F->users()) {
@@ -2326,6 +2333,9 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
             B.SetCurrentDebugLocation(DebugLoc());
             buildAssignPtr(B, ElemTy, Arg);
           }
+        } else if (isa<GetElementPtrInst>(Param)) {
+          replaceUsesOfWithSpvPtrcast(Param, normalizeType(ElemTy), CI,
+                                      Ptrcasts);
         } else if (isa<Instruction>(Param)) {
           GR->addDeducedElementType(Param, normalizeType(ElemTy));
           // insertAssignTypeIntrs() will complete buildAssignPtr()
@@ -2370,6 +2380,15 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   AggrConstTypes.clear();
   AggrStores.clear();
 
+  // fix GEP result types ahead of inference
+  for (auto &I : instructions(Func)) {
+    auto *Ref = dyn_cast<GetElementPtrInst>(&I);
+    if (!Ref || GR->findDeducedElementType(Ref))
+      continue;
+    if (Type *GepTy = getGEPType(Ref))
+      GR->addDeducedElementType(Ref, normalizeType(GepTy));
+  }
+
   processParamTypesByFunHeader(CurrF, B);
 
   // StoreInst's operand type can be changed during the next transformations,

diff  --git a/llvm/test/CodeGen/SPIRV/pointers/ptr-access-chain-type.ll b/llvm/test/CodeGen/SPIRV/pointers/ptr-access-chain-type.ll
new file mode 100644
index 0000000000000..d69959609c9dc
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/ptr-access-chain-type.ll
@@ -0,0 +1,29 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Check that type inference keeps the i8 element type of the GEP result %p1
+; and inserts an OpBitcast to the int pointer type expected by the demangled
+; OpGroupAsyncCopy signature, rather than retyping the GEP result itself.
+
+; CHECK-DAG: %[[#Char:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#Int:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#CharPtr:]] = OpTypePointer CrossWorkgroup %[[#Char]]
+; CHECK-DAG: %[[#IntPtr:]] = OpTypePointer CrossWorkgroup %[[#Int]]
+; CHECK-DAG: %[[#IntPtrWG:]] = OpTypePointer Workgroup %[[#Int]]
+; CHECK: OpFunction
+; CHECK: OpFunctionParameter
+; CHECK: %[[#Dest:]] = OpFunctionParameter %[[#CharPtr]]
+; CHECK: %[[#Src:]] = OpFunctionParameter %[[#IntPtrWG]]
+; CHECK: %[[#InDest:]] = OpInBoundsPtrAccessChain %[[#CharPtr]] %[[#Dest]] %[[#]]
+; CHECK: %[[#InDestCasted:]] = OpBitcast %[[#IntPtr]] %[[#InDest]]
+; CHECK: OpGroupAsyncCopy %[[#]] %[[#]] %[[#InDestCasted]] %[[#Src]] %[[#]] %[[#]] %[[#]]
+
+define spir_kernel void @foo(i64 %idx, ptr addrspace(1) %dest, ptr addrspace(3) %src) {
+  %p1 = getelementptr inbounds i8, ptr addrspace(1) %dest, i64 %idx
+  %res = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %p1, ptr addrspace(3) %src, i64 128, i64 1, target("spirv.Event") zeroinitializer)
+  ret void
+}
+
+; The mangling is important: the demangled parameter types (PU3AS1i and
+; PU3AS3Ki, i.e. int pointers) drive the pointer types checked above.
+declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))


        


More information about the llvm-commits mailing list