[llvm] [SPIR-V] Fix some GEP legalization (PR #150943)

Nathan Gauër via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 29 05:54:12 PDT 2025


https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/150943

>From 9a220c057aefa4d2ed62c768d92ea4c8bf25e8dd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Mon, 28 Jul 2025 14:08:31 +0200
Subject: [PATCH 1/3] [SPIR-V] Fix some GEP legalization

Pointers and GEP are untyped. SPIR-V requires structured OpAccessChain.
This means the backend will have to determine a good way to retrieve the
structured access from an untyped GEP. This is not a trivial problem,
and needs to be addressed to have a robust compiler.

The issue is other workstreams rely on the access chain deduction to
work. So we have 2 options:
 - pause all dependent work until we have a good chain deduction.
 - submit this limited fix so we can work on both this and other
   features in parallel.

Choice we want to make is #2: submitting this **knowing** this is not a
**good** fix. It only increases the number of patterns we can work with,
thus allowing others to continue working on other parts of the backend.

This patch as-is has many limitations:
 - It cannot robustly determine the depth of the structured access
   from a GEP. Fixing this would require looking ahead at the full
   GEP chain.
 - It cannot always figure out the correct access indices, especially
   with dynamic indices. This will require frontend collaboration.

Because we know this is a temporary hack, this patch only impacts
the logical SPIR-V target. Physical SPIR-V, which can rely on pointer
casts, remains on the old method.

Related to #145002
---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 152 +++++++++++++++++-
 .../SPIRV/hlsl-resources/StructuredBuffer.ll  |   2 +-
 .../CodeGen/SPIRV/llvm-intrinsics/lifetime.ll |   4 +-
 .../CodeGen/SPIRV/logical-struct-access.ll    |   3 +-
 .../pointers/structured-buffer-access.ll      |  75 +++++++++
 5 files changed, 230 insertions(+), 6 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/structured-buffer-access.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 3c631ce1fec02..c7d0a00bbe87d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -194,6 +194,31 @@ class SPIRVEmitIntrinsics
 
   void useRoundingMode(ConstrainedFPIntrinsic *FPI, IRBuilder<> &B);
 
+  // Tries to walk the type accessed by the given GEP instruction.
+  // For each nested type access, one of the 2 callbacks is called:
+  //  - OnStaticIndex when the index is a known constant value.
+  //  - OnDynamnicIndexing when the index is a non-constant value.
+  // Return true if an error occured during the walk, false otherwise.
+  bool walkLogicalAccessChain(
+      GetElementPtrInst &GEP,
+      const std::function<void(Type *, uint64_t)> &OnStaticIndexing,
+      const std::function<void(Type *, Value *)> &OnDynamicIndexing);
+
+  // Returns the type accessed using the given GEP instruction by relying
+  // on the GEP type.
+  // FIXME: GEP types are not supposed to be used to retrieve the pointed
+  // type. This must be fixed.
+  Type *getGEPType(GetElementPtrInst *GEP);
+
+  // Returns the type accessed using the given GEP instruction by walking
+  // the source type using the GEP indices.
+  // FIXME: without help from the frontend, this method cannot reliably retrieve
+  // the stored type, nor can robustly determine the depth of the type
+  // we are accessing.
+  Type *getGEPTypeLogical(GetElementPtrInst *GEP);
+
+  Instruction *buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP);
+
 public:
   static char ID;
   SPIRVEmitIntrinsics(SPIRVTargetMachine *TM = nullptr)
@@ -246,6 +271,17 @@ bool expectIgnoredInIRTranslation(const Instruction *I) {
   }
 }
 
+// Returns the source pointer from `I` ignoring intermediate ptrcast.
+Value *getPointerRoot(Value *I) {
+  if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+    if (II->getIntrinsicID() == Intrinsic::spv_ptrcast) {
+      Value *V = II->getArgOperand(0);
+      return getPointerRoot(V);
+    }
+  }
+  return I;
+}
+
 } // namespace
 
 char SPIRVEmitIntrinsics::ID = 0;
@@ -555,7 +591,97 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
   Ty = RefTy;
 }
 
-Type *getGEPType(GetElementPtrInst *Ref) {
+bool SPIRVEmitIntrinsics::walkLogicalAccessChain(
+    GetElementPtrInst &GEP,
+    const std::function<void(Type *, uint64_t)> &OnStaticIndexing,
+    const std::function<void(Type *, Value *)> &OnDynamicIndexing) {
+  auto &DL = CurrF->getDataLayout();
+  Value *Src = getPointerRoot(GEP.getPointerOperand());
+  Type *CurType = deduceElementType(Src, true);
+
+  for (Value *V : GEP.indices()) {
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
+      uint64_t Offset = CI->getZExtValue();
+
+      do {
+        if (ArrayType *AT = dyn_cast<ArrayType>(CurType)) {
+          uint32_t EltTypeSize = DL.getTypeSizeInBits(AT->getElementType()) / 8;
+          assert(Offset < AT->getNumElements() * EltTypeSize);
+          uint64_t Index = Offset / EltTypeSize;
+          Offset = Offset - (Index * EltTypeSize);
+          CurType = AT->getElementType();
+          OnStaticIndexing(CurType, Index);
+        } else if (StructType *ST = dyn_cast<StructType>(CurType)) {
+          uint32_t StructSize = DL.getTypeSizeInBits(ST) / 8;
+          assert(Offset < StructSize);
+          const auto &STL = DL.getStructLayout(ST);
+          unsigned Element = STL->getElementContainingOffset(Offset);
+          Offset -= STL->getElementOffset(Element);
+          CurType = ST->getElementType(Element);
+          OnStaticIndexing(CurType, Element);
+        } else {
+          // Vector type indexing should not use GEP.
+          // So if we have an index left, something is wrong. Giving up.
+          return true;
+        }
+      } while (Offset > 0);
+
+    } else if (ArrayType *AT = dyn_cast<ArrayType>(CurType)) {
+      // Index is not constant. Either we have an array and accept it, or we
+      // give up.
+      CurType = AT->getElementType();
+      OnDynamicIndexing(CurType, V);
+    } else
+      return true;
+  }
+
+  return false;
+}
+
+Instruction *
+SPIRVEmitIntrinsics::buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP) {
+  IRBuilder<> B(GEP.getParent());
+
+  std::vector<Value *> Indices;
+  Indices.push_back(ConstantInt::get(
+      IntegerType::getInt32Ty(CurrF->getContext()), 0, /* Signed= */ false));
+  walkLogicalAccessChain(
+      GEP,
+      [&Indices, &B](Type *EltType, uint64_t Index) {
+        Indices.push_back(
+            ConstantInt::get(B.getInt64Ty(), Index, /* Signed= */ false));
+      },
+      [&Indices](Type *EltType, Value *Index) { Indices.push_back(Index); });
+
+  B.SetInsertPoint(&GEP);
+  SmallVector<Type *, 2> Types = {GEP.getType(), GEP.getOperand(0)->getType()};
+  SmallVector<Value *, 4> Args;
+  Args.push_back(B.getInt1(GEP.isInBounds()));
+  Args.push_back(GEP.getOperand(0));
+  llvm::append_range(Args, Indices);
+  auto *NewI = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+  replaceAllUsesWithAndErase(B, &GEP, NewI);
+  return NewI;
+}
+
+Type *SPIRVEmitIntrinsics::getGEPTypeLogical(GetElementPtrInst *GEP) {
+
+  Type *CurType = GEP->getResultElementType();
+
+  bool Interrupted = walkLogicalAccessChain(
+      *GEP, [&CurType](Type *EltType, uint64_t Index) { CurType = EltType; },
+      [&CurType](Type *EltType, Value *Index) { CurType = EltType; });
+
+  return Interrupted ? GEP->getResultElementType() : CurType;
+}
+
+Type *SPIRVEmitIntrinsics::getGEPType(GetElementPtrInst *Ref) {
+  if (Ref->getSourceElementType() ==
+          IntegerType::getInt8Ty(CurrF->getContext()) &&
+      TM->getSubtargetImpl()->isLogicalSPIRV()) {
+    return getGEPTypeLogical(Ref);
+  }
+
   Type *Ty = nullptr;
   // TODO: not sure if GetElementPtrInst::getTypeAtIndex() does anything
   // useful here
@@ -1395,6 +1521,13 @@ Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
 }
 
 Instruction *SPIRVEmitIntrinsics::visitGetElementPtrInst(GetElementPtrInst &I) {
+  if (I.getSourceElementType() == IntegerType::getInt8Ty(CurrF->getContext()) &&
+      TM->getSubtargetImpl()->isLogicalSPIRV()) {
+    Instruction *Result = buildLogicalAccessChainFromGEP(I);
+    if (Result)
+      return Result;
+  }
+
   IRBuilder<> B(I.getParent());
   B.SetInsertPoint(&I);
   SmallVector<Type *, 2> Types = {I.getType(), I.getOperand(0)->getType()};
@@ -1588,7 +1721,22 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
   }
   if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) {
     Value *Pointer = GEPI->getPointerOperand();
-    Type *OpTy = GEPI->getSourceElementType();
+    Type *OpTy = nullptr;
+
+    // Knowing the accessed type is mandatory for logical SPIR-V. Sadly,
+    // the GEP source element type should not be used for this purpose, and
+    // the alternative type-scavenging method is not working.
+    // Physical SPIR-V can work around this, but not logical, hence still
+    // try to rely on the broken type scavenging for logical.
+    if (TM->getSubtargetImpl()->isLogicalSPIRV()) {
+      Value *Src = getPointerRoot(Pointer);
+      OpTy = GR->findDeducedElementType(Src);
+    }
+
+    // In all cases, fall back to the GEP type if type scavenging failed.
+    if (!OpTy)
+      OpTy = GEPI->getSourceElementType();
+
     replacePointerOperandWithPtrCast(I, Pointer, OpTy, 0, B);
     if (isNestedPointer(OpTy))
       insertTodoType(Pointer);
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll
index e47685cd38a2a..878a59cb76fe2 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll
@@ -1,4 +1,4 @@
-; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv1.6-vulkan1.3-library %s -o - | FileCheck %s
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv1.6-vulkan1.3-library %s -o - -print-after-all | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-vulkan1.3-library %s -o - -filetype=obj | spirv-val %}
 
 target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
diff --git a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/lifetime.ll b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/lifetime.ll
index 085f8b3bc44c0..9d07b63b49a52 100644
--- a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/lifetime.ll
+++ b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/lifetime.ll
@@ -33,7 +33,7 @@ define spir_func void @foo(ptr noundef byval(%tprange) align 8 %_arg_UserRange)
   %RoundedRangeKernel = alloca %tprange, align 8
   call void @llvm.lifetime.start.p0(i64 72, ptr nonnull %RoundedRangeKernel)
   call void @llvm.memcpy.p0.p0.i64(ptr align 8 %RoundedRangeKernel, ptr align 8 %_arg_UserRange, i64 16, i1 false)
-  %KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 16
+  %KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 8
   call void @llvm.lifetime.end.p0(i64 72, ptr nonnull %RoundedRangeKernel)
   ret void
 }
@@ -55,7 +55,7 @@ define spir_func void @bar(ptr noundef byval(%tprange) align 8 %_arg_UserRange)
   %RoundedRangeKernel = alloca %tprange, align 8
   call void @llvm.lifetime.start.p0(i64 -1, ptr nonnull %RoundedRangeKernel)
   call void @llvm.memcpy.p0.p0.i64(ptr align 8 %RoundedRangeKernel, ptr align 8 %_arg_UserRange, i64 16, i1 false)
-  %KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 16
+  %KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 8
   call void @llvm.lifetime.end.p0(i64 -1, ptr nonnull %RoundedRangeKernel)
   ret void
 }
diff --git a/llvm/test/CodeGen/SPIRV/logical-struct-access.ll b/llvm/test/CodeGen/SPIRV/logical-struct-access.ll
index a1ff1e0752e03..66337b1ba2b37 100644
--- a/llvm/test/CodeGen/SPIRV/logical-struct-access.ll
+++ b/llvm/test/CodeGen/SPIRV/logical-struct-access.ll
@@ -1,4 +1,5 @@
-; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: llc -O0 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -print-after-all | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-DAG: [[uint:%[0-9]+]] = OpTypeInt 32 0
 
diff --git a/llvm/test/CodeGen/SPIRV/pointers/structured-buffer-access.ll b/llvm/test/CodeGen/SPIRV/pointers/structured-buffer-access.ll
new file mode 100644
index 0000000000000..8e6b5a68fb769
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/structured-buffer-access.ll
@@ -0,0 +1,75 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
+
+; struct S1 {
+;   int4 i;
+;   float4 f;
+; };
+; struct S2 {
+;   float4 f;
+;   int4 i;
+; };
+;
+; StructuredBuffer<S1> In : register(t1);
+; RWStructuredBuffer<S2> Out : register(u0);
+;
+; [numthreads(1,1,1)]
+; void main(uint GI : SV_GroupIndex) {
+;   Out[GI].f = In[GI].f;
+;   Out[GI].i = In[GI].i;
+; }
+
+%struct.S1 = type { <4 x i32>, <4 x float> }
+%struct.S2 = type { <4 x float>, <4 x i32> }
+
+ at .str = private unnamed_addr constant [3 x i8] c"In\00", align 1
+ at .str.2 = private unnamed_addr constant [4 x i8] c"Out\00", align 1
+
+define void @main() local_unnamed_addr #0 {
+; CHECK-LABEL: main
+; CHECK:       %43 = OpFunction %2 None %3 ; -- Begin function main
+; CHECK-NEXT:    %1 = OpLabel
+; CHECK-NEXT:    %44 = OpVariable %28 Function %38
+; CHECK-NEXT:    %45 = OpVariable %27 Function %39
+; CHECK-NEXT:    %46 = OpCopyObject %19 %40
+; CHECK-NEXT:    %47 = OpCopyObject %16 %41
+; CHECK-NEXT:    %48 = OpLoad %4 %42
+; CHECK-NEXT:    %49 = OpAccessChain %13 %46 %29 %48
+; CHECK-NEXT:    %50 = OpInBoundsAccessChain %9 %49 %31
+; CHECK-NEXT:    %51 = OpLoad %8 %50 Aligned 1
+; CHECK-NEXT:    %52 = OpAccessChain %11 %47 %29 %48
+; CHECK-NEXT:    %53 = OpInBoundsAccessChain %9 %52 %29
+; CHECK-NEXT:    OpStore %53 %51 Aligned 1
+; CHECK-NEXT:    %54 = OpAccessChain %6 %49 %29
+; CHECK-NEXT:    %55 = OpLoad %5 %54 Aligned 1
+; CHECK-NEXT:    %56 = OpInBoundsAccessChain %6 %52 %31
+; CHECK-NEXT:    OpStore %56 %55 Aligned 1
+; CHECK-NEXT:    OpReturn
+; CHECK-NEXT:    OpFunctionEnd
+entry:
+  %0 = tail call target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32 0, i32 1, i32 1, i32 0, i1 false, ptr nonnull @.str)
+  %1 = tail call target("spirv.VulkanBuffer", [0 x %struct.S2], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S2s_12_1t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr nonnull @.str.2)
+  %2 = tail call i32 @llvm.spv.flattened.thread.id.in.group()
+  %3 = tail call noundef align 1 dereferenceable(32) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) %0, i32 %2)
+  %f.i = getelementptr inbounds nuw i8, ptr addrspace(11) %3, i64 16
+  %4 = load <4 x float>, ptr addrspace(11) %f.i, align 1
+  %5 = tail call noundef align 1 dereferenceable(32) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S2s_12_1t(target("spirv.VulkanBuffer", [0 x %struct.S2], 12, 1) %1, i32 %2)
+  store <4 x float> %4, ptr addrspace(11) %5, align 1
+  %6 = load <4 x i32>, ptr addrspace(11) %3, align 1
+  %i6.i = getelementptr inbounds nuw i8, ptr addrspace(11) %5, i64 16
+  store <4 x i32> %6, ptr addrspace(11) %i6.i, align 1
+  ret void
+}
+
+declare i32 @llvm.spv.flattened.thread.id.in.group()
+
+declare target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32, i32, i32, i32, i1, ptr)
+
+declare target("spirv.VulkanBuffer", [0 x %struct.S2], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S2s_12_1t(i32, i32, i32, i32, i1, ptr)
+
+declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S2s_12_1t(target("spirv.VulkanBuffer", [0 x %struct.S2], 12, 1), i32)
+
+declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0), i32)
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

>From 4f47d9cbcf835cd05ed81f41669a99818ec45794 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 29 Jul 2025 11:16:53 +0200
Subject: [PATCH 2/3] revert cmdline change

---
 llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll
index 878a59cb76fe2..e47685cd38a2a 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll
@@ -1,4 +1,4 @@
-; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv1.6-vulkan1.3-library %s -o - -print-after-all | FileCheck %s
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv1.6-vulkan1.3-library %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-vulkan1.3-library %s -o - -filetype=obj | spirv-val %}
 
 target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"

>From 58e246723e348e3ea571a8f6bd4ce8ec921f83a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 29 Jul 2025 13:23:26 +0200
Subject: [PATCH 3/3] pr-feedback: simplify code and add more testing.

Fixes the index inconsistencies
---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 108 +++++++++++-------
 ...ructured-buffer-access-constant-index-1.ll |  46 ++++++++
 ...ructured-buffer-access-constant-index-2.ll |  54 +++++++++
 3 files changed, 168 insertions(+), 40 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/structured-buffer-access-constant-index-1.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/structured-buffer-access-constant-index-2.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index c7d0a00bbe87d..6081ab04dbfff 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -196,13 +196,24 @@ class SPIRVEmitIntrinsics
 
   // Tries to walk the type accessed by the given GEP instruction.
   // For each nested type access, one of the 2 callbacks is called:
-  //  - OnStaticIndex when the index is a known constant value.
+  //  - OnLiteralIndexing when the index is a known constant value.
+  //    Parameters:
+  //      PointedType: the pointed type resulting of this indexing.
+  //        If the parent type is an array, this is the index in the array.
+  //        If the parent type is a struct, this is the field index.
+  //      Index: index of the element in the parent type.
   //  - OnDynamnicIndexing when the index is a non-constant value.
+  //    This callback is only called when indexing into an array.
+  //    Parameters:
+  //      ElementType: the type of the elements stored in the parent array.
+  //      Offset: the Value* containing the byte offset into the array.
   // Return true if an error occured during the walk, false otherwise.
   bool walkLogicalAccessChain(
       GetElementPtrInst &GEP,
-      const std::function<void(Type *, uint64_t)> &OnStaticIndexing,
-      const std::function<void(Type *, Value *)> &OnDynamicIndexing);
+      const std::function<void(Type *PointedType, uint64_t Index)>
+          &OnLiteralIndexing,
+      const std::function<void(Type *ElementType, Value *Offset)>
+          &OnDynamicIndexing);
 
   // Returns the type accessed using the given GEP instruction by relying
   // on the GEP type.
@@ -593,54 +604,64 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
 
 bool SPIRVEmitIntrinsics::walkLogicalAccessChain(
     GetElementPtrInst &GEP,
-    const std::function<void(Type *, uint64_t)> &OnStaticIndexing,
+    const std::function<void(Type *, uint64_t)> &OnLiteralIndexing,
     const std::function<void(Type *, Value *)> &OnDynamicIndexing) {
+  // We only rewrite i8* GEP. Other should be left as-is.
+  // Observation so-far is i8* GEP always have a single index. Making sure
+  // that's the case.
+  assert(GEP.getSourceElementType() ==
+         IntegerType::getInt8Ty(CurrF->getContext()));
+  assert(GEP.getNumIndices() == 1);
+
   auto &DL = CurrF->getDataLayout();
   Value *Src = getPointerRoot(GEP.getPointerOperand());
   Type *CurType = deduceElementType(Src, true);
 
-  for (Value *V : GEP.indices()) {
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
-      uint64_t Offset = CI->getZExtValue();
-
-      do {
-        if (ArrayType *AT = dyn_cast<ArrayType>(CurType)) {
-          uint32_t EltTypeSize = DL.getTypeSizeInBits(AT->getElementType()) / 8;
-          assert(Offset < AT->getNumElements() * EltTypeSize);
-          uint64_t Index = Offset / EltTypeSize;
-          Offset = Offset - (Index * EltTypeSize);
-          CurType = AT->getElementType();
-          OnStaticIndexing(CurType, Index);
-        } else if (StructType *ST = dyn_cast<StructType>(CurType)) {
-          uint32_t StructSize = DL.getTypeSizeInBits(ST) / 8;
-          assert(Offset < StructSize);
-          const auto &STL = DL.getStructLayout(ST);
-          unsigned Element = STL->getElementContainingOffset(Offset);
-          Offset -= STL->getElementOffset(Element);
-          CurType = ST->getElementType(Element);
-          OnStaticIndexing(CurType, Element);
-        } else {
-          // Vector type indexing should not use GEP.
-          // So if we have an index left, something is wrong. Giving up.
-          return true;
-        }
-      } while (Offset > 0);
-
-    } else if (ArrayType *AT = dyn_cast<ArrayType>(CurType)) {
-      // Index is not constant. Either we have an array and accept it, or we
-      // give up.
+  Value *Operand = *GEP.idx_begin();
+  ConstantInt *CI = dyn_cast<ConstantInt>(Operand);
+  if (!CI) {
+    ArrayType *AT = dyn_cast<ArrayType>(CurType);
+    // Operand is not constant. Either we have an array and accept it, or we
+    // give up.
+    if (AT)
+      OnDynamicIndexing(AT->getElementType(), Operand);
+    return AT == nullptr;
+  }
+
+  assert(CI);
+  uint64_t Offset = CI->getZExtValue();
+
+  do {
+    if (ArrayType *AT = dyn_cast<ArrayType>(CurType)) {
+      uint32_t EltTypeSize = DL.getTypeSizeInBits(AT->getElementType()) / 8;
+      assert(Offset < AT->getNumElements() * EltTypeSize);
+      uint64_t Index = Offset / EltTypeSize;
+      Offset = Offset - (Index * EltTypeSize);
       CurType = AT->getElementType();
-      OnDynamicIndexing(CurType, V);
-    } else
+      OnLiteralIndexing(CurType, Index);
+    } else if (StructType *ST = dyn_cast<StructType>(CurType)) {
+      uint32_t StructSize = DL.getTypeSizeInBits(ST) / 8;
+      assert(Offset < StructSize);
+      const auto &STL = DL.getStructLayout(ST);
+      unsigned Element = STL->getElementContainingOffset(Offset);
+      Offset -= STL->getElementOffset(Element);
+      CurType = ST->getElementType(Element);
+      OnLiteralIndexing(CurType, Element);
+    } else {
+      // Vector type indexing should not use GEP.
+      // So if we have an index left, something is wrong. Giving up.
       return true;
-  }
+    }
+  } while (Offset > 0);
 
   return false;
 }
 
 Instruction *
 SPIRVEmitIntrinsics::buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP) {
+  auto &DL = CurrF->getDataLayout();
   IRBuilder<> B(GEP.getParent());
+  B.SetInsertPoint(&GEP);
 
   std::vector<Value *> Indices;
   Indices.push_back(ConstantInt::get(
@@ -651,9 +672,14 @@ SPIRVEmitIntrinsics::buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP) {
         Indices.push_back(
             ConstantInt::get(B.getInt64Ty(), Index, /* Signed= */ false));
       },
-      [&Indices](Type *EltType, Value *Index) { Indices.push_back(Index); });
+      [&Indices, &B, &DL](Type *EltType, Value *Offset) {
+        uint32_t EltTypeSize = DL.getTypeSizeInBits(EltType) / 8;
+        Value *Index = B.CreateUDiv(
+            Offset, ConstantInt::get(Offset->getType(), EltTypeSize,
+                                     /* Signed= */ false));
+        Indices.push_back(Index);
+      });
 
-  B.SetInsertPoint(&GEP);
   SmallVector<Type *, 2> Types = {GEP.getType(), GEP.getOperand(0)->getType()};
   SmallVector<Value *, 4> Args;
   Args.push_back(B.getInt1(GEP.isInBounds()));
@@ -1728,7 +1754,9 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
     // the alternative type-scavenging method is not working.
     // Physical SPIR-V can work around this, but not logical, hence still
     // try to rely on the broken type scavenging for logical.
-    if (TM->getSubtargetImpl()->isLogicalSPIRV()) {
+    bool IsRewrittenGEP =
+        GEPI->getSourceElementType() == IntegerType::getInt8Ty(I->getContext());
+    if (IsRewrittenGEP && TM->getSubtargetImpl()->isLogicalSPIRV()) {
       Value *Src = getPointerRoot(Pointer);
       OpTy = GR->findDeducedElementType(Src);
     }
diff --git a/llvm/test/CodeGen/SPIRV/pointers/structured-buffer-access-constant-index-1.ll b/llvm/test/CodeGen/SPIRV/pointers/structured-buffer-access-constant-index-1.ll
new file mode 100644
index 0000000000000..26dc60e738927
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/structured-buffer-access-constant-index-1.ll
@@ -0,0 +1,46 @@
+; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
+
+%struct.S1 = type { <4 x i32>, [10 x <4 x float>], <4 x float> }
+%struct.S2 = type { <4 x float>, <4 x i32> }
+
+ at .str = private unnamed_addr constant [3 x i8] c"In\00", align 1
+
+define <4 x float> @main() {
+entry:
+  %0 = tail call target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32 0, i32 1, i32 1, i32 0, i1 false, ptr nonnull @.str)
+  %3 = tail call noundef align 1 dereferenceable(192) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) %0, i32 0)
+
+; CHECK-DAG:  %[[#ulong:]] = OpTypeInt 64 0
+; CHECK-DAG:  %[[#ulong_1:]] = OpConstant %[[#ulong]] 1
+; CHECK-DAG:  %[[#ulong_3:]] = OpConstant %[[#ulong]] 3
+
+; CHECK-DAG:  %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG:  %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG:  %[[#uint_10:]] = OpConstant %[[#uint]] 10
+
+; CHECK-DAG:  %[[#float:]] = OpTypeFloat 32
+; CHECK-DAG:  %[[#v4f:]] = OpTypeVector %[[#float]] 4
+; CHECK-DAG:  %[[#arr_v4f:]] = OpTypeArray %[[#v4f]] %[[#uint_10]]
+; CHECK-DAG:  %[[#S1:]] = OpTypeStruct %[[#]] %[[#arr_v4f]] %[[#]]
+; CHECK-DAG:  %[[#sb_S1:]] = OpTypePointer StorageBuffer %[[#S1]]
+; CHECK-DAG:  %[[#sb_v4f:]] = OpTypePointer StorageBuffer %[[#v4f]]
+
+; CHECK:      %[[#tmp:]] = OpAccessChain %[[#sb_S1]] %[[#]] %[[#uint_0]] %[[#uint_0]]
+; CHECK:      %[[#ptr:]] = OpInBoundsAccessChain %[[#sb_v4f]] %[[#tmp]] %[[#ulong_1]] %[[#ulong_3]]
+; This rewritten GEP combined all constant indices into a single value.
+; We should make sure the correct indices are retrieved.
+  %arrayidx.i = getelementptr inbounds nuw i8, ptr addrspace(11) %3, i64 64
+
+; CHECK:  OpLoad %[[#v4f]] %[[#ptr]]
+  %4 = load <4 x float>, ptr addrspace(11) %arrayidx.i, align 1
+
+  ret <4 x float> %4
+}
+
+declare i32 @llvm.spv.flattened.thread.id.in.group()
+declare target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32, i32, i32, i32, i1, ptr)
+declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0), i32)
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
diff --git a/llvm/test/CodeGen/SPIRV/pointers/structured-buffer-access-constant-index-2.ll b/llvm/test/CodeGen/SPIRV/pointers/structured-buffer-access-constant-index-2.ll
new file mode 100644
index 0000000000000..a6efb38a23d05
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/structured-buffer-access-constant-index-2.ll
@@ -0,0 +1,54 @@
+; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
+
+%struct.S1 = type { <4 x i32>, [10 x <4 x float>], <4 x float> }
+%struct.S2 = type { <4 x float>, <4 x i32> }
+
+ at .str = private unnamed_addr constant [3 x i8] c"In\00", align 1
+
+define <4 x float> @main(i32 %index) {
+entry:
+  %0 = tail call target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32 0, i32 1, i32 1, i32 0, i1 false, ptr nonnull @.str)
+  %3 = tail call noundef align 1 dereferenceable(192) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) %0, i32 0)
+
+; CHECK-DAG:  %[[#ulong:]] = OpTypeInt 64 0
+; CHECK-DAG:  %[[#ulong_1:]] = OpConstant %[[#ulong]] 1
+
+; CHECK-DAG:  %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG:  %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG:  %[[#uint_10:]] = OpConstant %[[#uint]] 10
+; CHECK-DAG:  %[[#uint_16:]] = OpConstant %[[#uint]] 16
+
+; CHECK-DAG:  %[[#float:]] = OpTypeFloat 32
+; CHECK-DAG:  %[[#v4f:]] = OpTypeVector %[[#float]] 4
+; CHECK-DAG:  %[[#arr_v4f:]] = OpTypeArray %[[#v4f]] %[[#uint_10]]
+; CHECK-DAG:  %[[#S1:]] = OpTypeStruct %[[#]] %[[#arr_v4f]] %[[#]]
+; CHECK-DAG:  %[[#sb_S1:]] = OpTypePointer StorageBuffer %[[#S1]]
+; CHECK-DAG:  %[[#sb_arr_v4f:]] = OpTypePointer StorageBuffer %[[#arr_v4f]]
+; CHECK-DAG:  %[[#sb_v4f:]] = OpTypePointer StorageBuffer %[[#v4f]]
+
+; CHECK:      %[[#a:]] = OpAccessChain %[[#sb_S1]] %[[#]] %[[#uint_0]] %[[#uint_0]]
+; CHECK:      %[[#b:]] = OpInBoundsAccessChain %[[#sb_arr_v4f]] %[[#a]] %[[#ulong_1]]
+  %4 = getelementptr inbounds nuw i8, ptr addrspace(11) %3, i64 16
+
+; CHECK:      %[[#offset:]] = OpIMul %[[#]] %[[#]] %[[#uint_16]]
+; Offset is computed in bytes. Make sure we reconvert it back to an index.
+  %offset = mul i32 %index, 16
+
+; CHECK:      %[[#index:]] = OpUDiv %[[#]] %[[#offset]] %[[#uint_16]]
+; CHECK:      %[[#c:]] = OpInBoundsAccessChain %[[#sb_v4f]] %[[#b]] %[[#index]]
+  %5 = getelementptr inbounds nuw i8, ptr addrspace(11) %4, i32 %offset
+
+; CHECK:  OpLoad %[[#v4f]] %[[#c]]
+  %6 = load <4 x float>, ptr addrspace(11) %5, align 1
+
+  ret <4 x float> %6
+}
+
+declare i32 @llvm.spv.flattened.thread.id.in.group()
+declare target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32, i32, i32, i32, i1, ptr)
+declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0), i32)
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+



More information about the llvm-commits mailing list