[llvm] [SPIR-V] Add store legalization for ptrcast (PR #135369)

Nathan Gauër via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 16 04:47:38 PDT 2025


https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/135369

>From c77cfe26f614de8d4a31d905e443ff5980408768 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Thu, 10 Apr 2025 13:12:55 +0200
Subject: [PATCH 1/2] [SPIR-V] Add store legalization for ptrcast

This commit adds handling for an spv.ptrcast result being
used in a store instruction, modifying the store to operate on
the source type.
---
 .../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 104 +++++++++++++++++
 .../pointers/getelementptr-downcast-struct.ll |  20 ++++
 .../pointers/getelementptr-downcast-vector.ll | 110 ++++++++++++++++++
 .../CodeGen/SPIRV/pointers/store-struct.ll    |  66 +++++++++++
 4 files changed, 300 insertions(+)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/store-struct.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 5ba4fbb02560d..f3f1558265d4a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -150,6 +150,95 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     DeadInstructions.push_back(LI);
   }
 
+  // Creates an spv_insertelt instruction (equivalent to llvm's insertelement).
+  Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,
+                           unsigned Index) {
+    Type *Int32Ty = Type::getInt32Ty(B.getContext());
+    SmallVector<Type *, 4> Types = {Vector->getType(), Vector->getType(),
+                                    Element->getType(), Int32Ty};
+    SmallVector<Value *> Args = {Vector, Element, B.getInt32(Index)};
+    Instruction *NewI =
+        B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
+    buildAssignType(B, Vector->getType(), NewI);
+    return NewI;
+  }
+
+  // Creates an spv_extractelt instruction (equivalent to llvm's
+  // extractelement).
+  Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,
+                            unsigned Index) {
+    Type *Int32Ty = Type::getInt32Ty(B.getContext());
+    SmallVector<Type *, 3> Types = {ElementType, Vector->getType(), Int32Ty};
+    SmallVector<Value *> Args = {Vector, B.getInt32(Index)};
+    Instruction *NewI =
+        B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
+    buildAssignType(B, ElementType, NewI);
+    return NewI;
+  }
+
+  // Stores the given Src vector operand into the Dst vector, adjusting the size
+  // if required.
+  Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,
+                               Align Alignment) {
+    FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
+    FixedVectorType *DstType =
+        cast<FixedVectorType>(GR->findDeducedElementType(Dst));
+    assert(DstType->getNumElements() >= SrcType->getNumElements());
+
+    LoadInst *LI = B.CreateLoad(DstType, Dst);
+    LI->setAlignment(Alignment);
+    Value *OldValues = LI;
+    buildAssignType(B, OldValues->getType(), OldValues);
+    Value *NewValues = Src;
+
+    for (unsigned I = 0; I < SrcType->getNumElements(); ++I) {
+      Value *Element =
+          makeExtractElement(B, SrcType->getElementType(), NewValues, I);
+      OldValues = makeInsertElement(B, OldValues, Element, I);
+    }
+
+    StoreInst *SI = B.CreateStore(OldValues, Dst);
+    SI->setAlignment(Alignment);
+    return SI;
+  }
+
+  // Stores the given Src value into the first entry of the Dst aggregate.
+  Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
+                                    Align Alignment) {
+    SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
+    SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst,
+                                 B.getInt32(0), B.getInt32(0)};
+    auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+    GR->buildAssignPtr(B, Src->getType(), GEP);
+    StoreInst *SI = B.CreateStore(Src, GEP);
+    SI->setAlignment(Alignment);
+    return SI;
+  }
+
+  // Transforms a store instruction (or SPV intrinsic) using a ptrcast as
+  // operand into a valid logical SPIR-V store with no ptrcast.
+  void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
+                      Value *Dst, Align Alignment) {
+    Type *ToTy = GR->findDeducedElementType(Dst);
+    Type *FromTy = Src->getType();
+
+    auto *SVT = dyn_cast<FixedVectorType>(FromTy);
+    auto *DST = dyn_cast<StructType>(ToTy);
+    auto *DVT = dyn_cast<FixedVectorType>(ToTy);
+
+    B.SetInsertPoint(BadStore);
+    if (DST && DST->getTypeAtIndex(0u) == FromTy)
+      storeToFirstValueAggregate(B, Src, Dst, Alignment);
+    else if (DVT && SVT)
+      storeVectorFromVector(B, Src, Dst, Alignment);
+    else if (DVT && !SVT && FromTy == DVT->getElementType())
+      storeToFirstValueAggregate(B, Src, Dst, Alignment);
+    else
+      llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
+
+    DeadInstructions.push_back(BadStore);
+  }
+
   void legalizePointerCast(IntrinsicInst *II) {
     Value *CastedOperand = II;
     Value *OriginalOperand = II->getOperand(0);
@@ -165,6 +254,12 @@ class SPIRVLegalizePointerCast : public FunctionPass {
         continue;
       }
 
+      if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
+        transformStore(B, SI, SI->getValueOperand(), OriginalOperand,
+                       SI->getAlign());
+        continue;
+      }
+
       if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(User)) {
         if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {
           DeadInstructions.push_back(Intrin);
@@ -176,6 +271,15 @@ class SPIRVLegalizePointerCast : public FunctionPass {
                                  /* DeleteOld= */ false);
           continue;
         }
+
+        if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {
+          Align Alignment;
+          if (ConstantInt *C = dyn_cast<ConstantInt>(Intrin->getOperand(3)))
+            Alignment = Align(C->getZExtValue());
+          transformStore(B, Intrin, Intrin->getArgOperand(0), OriginalOperand,
+                         Alignment);
+          continue;
+        }
       }
 
       llvm_unreachable("Unsupported ptrcast user. Please fix.");
diff --git a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
index b0a68a30e29be..35e5880881e5c 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
@@ -45,3 +45,23 @@ entry:
   %val = load i32, ptr addrspace(10) %ptr
   ret i32 %val
 }
+
+define spir_func void @foos(i64 noundef %index) local_unnamed_addr {
+; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
+entry:
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global1]] %[[#uint_0]] %[[#index]]
+  %ptr = getelementptr inbounds %S1, ptr addrspace(10) @global1, i64 0, i32 0, i64 %index
+; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
+  store i32 0, ptr addrspace(10) %ptr
+  ret void
+}
+
+define spir_func void @bars(i64 noundef %index) local_unnamed_addr {
+; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
+entry:
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global2]] %[[#uint_0]] %[[#uint_0]] %[[#index]] %[[#uint_1]]
+  %ptr = getelementptr inbounds %S2, ptr addrspace(10) @global2, i64 0, i32 0, i32 0, i64 %index, i32 1
+; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
+  store i32 0, ptr addrspace(10) %ptr
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
index d4131fa8a2658..be9e2a23365cc 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
@@ -5,9 +5,13 @@
 ; CHECK-DAG: %[[#uint_pp:]] = OpTypePointer Private %[[#uint]]
 ; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
 ; CHECK-DAG:  %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG:  %[[#uint_1:]] = OpConstant %[[#uint]] 1
+; CHECK-DAG:  %[[#uint_2:]] = OpConstant %[[#uint]] 2
 ; CHECK-DAG:      %[[#v2:]] = OpTypeVector %[[#uint]] 2
 ; CHECK-DAG:      %[[#v3:]] = OpTypeVector %[[#uint]] 3
 ; CHECK-DAG:      %[[#v4:]] = OpTypeVector %[[#uint]] 4
+; CHECK-DAG:   %[[#v2_01:]] = OpConstantComposite %[[#v2]] %[[#uint_0]] %[[#uint_1]]
+; CHECK-DAG:  %[[#v3_012:]] = OpConstantComposite %[[#v3]] %[[#uint_0]] %[[#uint_1]] %[[#uint_2]]
 ; CHECK-DAG:   %[[#v4_pp:]] = OpTypePointer Private %[[#v4]]
 ; CHECK-DAG:   %[[#v4_fp:]] = OpTypePointer Function %[[#v4]]
 
@@ -108,3 +112,109 @@ define internal spir_func i32 @bazBounds(ptr %a) {
   ret i32 %2
 ; CHECK: OpReturnValue %[[#val]]
 }
+
+define internal spir_func void @foos(ptr addrspace(10) %a) {
+
+  %1 = getelementptr inbounds <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]]  = OpInBoundsAccessChain %[[#v4_pp]] %[[#]]
+
+  store <3 x i32> <i32 0, i32 1, i32 2>, ptr addrspace(10) %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK:    %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK:    %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK:    %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16
+
+  ret void
+}
+
+define internal spir_func void @foosDefault(ptr %a) {
+
+  %1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]]  = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]
+
+  store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK:    %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK:    %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK:    %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16
+
+  ret void
+}
+
+define internal spir_func void @foosBounds(ptr %a) {
+
+  %1 = getelementptr <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]]  = OpAccessChain %[[#v4_fp]] %[[#]]
+
+  store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 64
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 64
+; CHECK:    %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK:    %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK:    %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 64
+
+  ret void
+}
+
+define internal spir_func void @bars(ptr addrspace(10) %a) {
+
+  %1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]]  = OpAccessChain %[[#v4_pp]] %[[#]]
+
+  store <2 x i32> <i32 0, i32 1>, ptr addrspace(10) %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK:    %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK:    %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK: OpStore %[[#ptr]] %[[#out2]] Aligned 16
+
+  ret void
+}
+
+define internal spir_func void @bazs(ptr addrspace(10) %a) {
+
+  %1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]]  = OpAccessChain %[[#v4_pp]] %[[#]]
+
+  store i32 0, ptr addrspace(10) %1, align 32
+; CHECK:  %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 32
+
+  ret void
+}
+
+define internal spir_func void @bazsDefault(ptr %a) {
+
+  %1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]]  = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]
+
+  store i32 0, ptr %1, align 16
+; CHECK:  %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16
+
+  ret void
+}
+
+define internal spir_func void @bazsBounds(ptr %a) {
+
+  %1 = getelementptr <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]]  = OpAccessChain %[[#v4_fp]] %[[#]]
+
+  store i32 0, ptr %1, align 16
+; CHECK:  %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16
+
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll b/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
new file mode 100644
index 0000000000000..7d2c1093f0a71
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
@@ -0,0 +1,66 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG:     %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG:    %[[#float:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_fp:]] = OpTypePointer Function %[[#float]]
+; CHECK-DAG: %[[#float_pp:]] = OpTypePointer Private %[[#float]]
+; CHECK-DAG:  %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
+; CHECK-DAG:   %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG:   %[[#float_0:]] = OpConstant %[[#float]] 0
+; CHECK-DAG:       %[[#sf:]] = OpTypeStruct %[[#float]]
+; CHECK-DAG:       %[[#su:]] = OpTypeStruct %[[#uint]]
+; CHECK-DAG:       %[[#sfuf:]] = OpTypeStruct %[[#float]] %[[#uint]] %[[#float]]
+; CHECK-DAG:    %[[#sf_fp:]] = OpTypePointer Function %[[#sf]]
+; CHECK-DAG:    %[[#su_fp:]] = OpTypePointer Function %[[#su]]
+; CHECK-DAG:    %[[#sfuf_fp:]] = OpTypePointer Function %[[#sfuf]]
+; CHECK-DAG:    %[[#sfuf_pp:]] = OpTypePointer Private %[[#sfuf]]
+
+%struct.SF = type { float }
+%struct.SU = type { i32 }
+%struct.SFUF = type { float, i32, float }
+
+@gsfuf = external addrspace(10) global %struct.SFUF
+; CHECK: %[[#gsfuf:]] = OpVariable %[[#sfuf_pp]] Private
+
+define internal spir_func void @foo() {
+  %1 = alloca %struct.SF, align 4
+; CHECK: %[[#var:]]  = OpVariable %[[#sf_fp]] Function
+
+  store float 0.0, ptr %1, align 4
+; CHECK: %[[#tmp:]]  = OpInBoundsAccessChain %[[#float_fp]] %[[#var]] %[[#uint_0]]
+; CHECK:               OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+  ret void
+}
+
+define internal spir_func void @bar() {
+  %1 = alloca %struct.SU, align 4
+; CHECK: %[[#var:]]  = OpVariable %[[#su_fp]] Function
+
+  store i32 0, ptr %1, align 4
+; CHECK: %[[#tmp:]]  = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]]
+; CHECK:               OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
+
+  ret void
+}
+
+define internal spir_func void @baz() {
+  %1 = alloca %struct.SFUF, align 4
+; CHECK: %[[#var:]]  = OpVariable %[[#sfuf_fp]] Function
+
+  store float 0.0, ptr %1, align 4
+; CHECK: %[[#tmp:]]  = OpInBoundsAccessChain %[[#float_fp]] %[[#var]] %[[#uint_0]]
+; CHECK:               OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+  ret void
+}
+
+define internal spir_func void @biz() {
+  store float 0.0, ptr addrspace(10) @gsfuf, align 4
+; CHECK: %[[#tmp:]]  = OpInBoundsAccessChain %[[#float_pp]] %[[#gsfuf]] %[[#uint_0]]
+; CHECK:               OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+  ret void
+}
+

>From 414e576aa203aa0f464948161a8af05c2567a4ed Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Wed, 16 Apr 2025 13:46:39 +0200
Subject: [PATCH 2/2] support nested types

---
 .../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 51 +++++++++---
 .../CodeGen/SPIRV/pointers/store-struct.ll    | 78 +++++++++++++++----
 2 files changed, 103 insertions(+), 26 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index f3f1558265d4a..e55a6be18c944 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -202,12 +202,29 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     return SI;
   }
 
+  void buildGEPIndexChain(IRBuilder<> &B, Type *Search, Type *Aggregate,
+                          SmallVectorImpl<Value *> &Indices) {
+    Indices.push_back(B.getInt32(0));
+
+    if (Search == Aggregate)
+      return;
+
+    if (auto *ST = dyn_cast<StructType>(Aggregate))
+      buildGEPIndexChain(B, Search, ST->getTypeAtIndex(0u), Indices);
+    else if (auto *AT = dyn_cast<ArrayType>(Aggregate))
+      buildGEPIndexChain(B, Search, AT->getElementType(), Indices);
+    else if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
+      buildGEPIndexChain(B, Search, VT->getElementType(), Indices);
+    else
+      llvm_unreachable("Bad access chain?");
+  }
+
   // Stores the given Src value into the first entry of the Dst aggregate.
   Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
-                                    Align Alignment) {
+                                    Type *DstPointeeType, Align Alignment) {
     SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
-    SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst,
-                                 B.getInt32(0), B.getInt32(0)};
+    SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst};
+    buildGEPIndexChain(B, Src->getType(), DstPointeeType, Args);
     auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
     GR->buildAssignPtr(B, Src->getType(), GEP);
     StoreInst *SI = B.CreateStore(Src, GEP);
@@ -215,6 +232,18 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     return SI;
   }
 
+  bool isTypeFirstElementAggregate(Type *Search, Type *Aggregate) {
+    if (Search == Aggregate)
+      return true;
+    if (auto *ST = dyn_cast<StructType>(Aggregate))
+      return isTypeFirstElementAggregate(Search, ST->getTypeAtIndex(0u));
+    if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
+      return isTypeFirstElementAggregate(Search, VT->getElementType());
+    if (auto *AT = dyn_cast<ArrayType>(Aggregate))
+      return isTypeFirstElementAggregate(Search, AT->getElementType());
+    return false;
+  }
+
   // Transforms a store instruction (or SPV intrinsic) using a ptrcast as
   // operand into a valid logical SPIR-V store with no ptrcast.
   void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
@@ -222,17 +251,17 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     Type *ToTy = GR->findDeducedElementType(Dst);
     Type *FromTy = Src->getType();
 
-    auto *SVT = dyn_cast<FixedVectorType>(FromTy);
-    auto *DST = dyn_cast<StructType>(ToTy);
-    auto *DVT = dyn_cast<FixedVectorType>(ToTy);
+    auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
+    auto *D_ST = dyn_cast<StructType>(ToTy);
+    auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
 
     B.SetInsertPoint(BadStore);
-    if (DST && DST->getTypeAtIndex(0u) == FromTy)
-      storeToFirstValueAggregate(B, Src, Dst, Alignment);
-    else if (DVT && SVT)
+    if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST))
+      storeToFirstValueAggregate(B, Src, Dst, D_ST, Alignment);
+    else if (D_VT && S_VT)
       storeVectorFromVector(B, Src, Dst, Alignment);
-    else if (DVT && !SVT && FromTy == DVT->getElementType())
-      storeToFirstValueAggregate(B, Src, Dst, Alignment);
+    else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
+      storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
     else
       llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
 
diff --git a/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll b/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
index 7d2c1093f0a71..1ff53614990d5 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
@@ -1,27 +1,43 @@
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
 
-; CHECK-DAG:     %[[#uint:]] = OpTypeInt 32 0
-; CHECK-DAG:    %[[#float:]] = OpTypeFloat 32
-; CHECK-DAG: %[[#float_fp:]] = OpTypePointer Function %[[#float]]
-; CHECK-DAG: %[[#float_pp:]] = OpTypePointer Private %[[#float]]
-; CHECK-DAG:  %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
-; CHECK-DAG:   %[[#uint_0:]] = OpConstant %[[#uint]] 0
-; CHECK-DAG:   %[[#float_0:]] = OpConstant %[[#float]] 0
-; CHECK-DAG:       %[[#sf:]] = OpTypeStruct %[[#float]]
-; CHECK-DAG:       %[[#su:]] = OpTypeStruct %[[#uint]]
-; CHECK-DAG:       %[[#sfuf:]] = OpTypeStruct %[[#float]] %[[#uint]] %[[#float]]
-; CHECK-DAG:    %[[#sf_fp:]] = OpTypePointer Function %[[#sf]]
-; CHECK-DAG:    %[[#su_fp:]] = OpTypePointer Function %[[#su]]
-; CHECK-DAG:    %[[#sfuf_fp:]] = OpTypePointer Function %[[#sfuf]]
-; CHECK-DAG:    %[[#sfuf_pp:]] = OpTypePointer Private %[[#sfuf]]
+; CHECK-DAG:      %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG:     %[[#float:]] = OpTypeFloat 32
+; CHECK-DAG:  %[[#float_fp:]] = OpTypePointer Function %[[#float]]
+; CHECK-DAG:  %[[#float_pp:]] = OpTypePointer Private %[[#float]]
+; CHECK-DAG:   %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
+; CHECK-DAG:    %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG:    %[[#uint_4:]] = OpConstant %[[#uint]] 4
+; CHECK-DAG:    %[[#float_0:]] = OpConstant %[[#float]] 0
+; CHECK-DAG:        %[[#sf:]] = OpTypeStruct %[[#float]]
+; CHECK-DAG:        %[[#su:]] = OpTypeStruct %[[#uint]]
+; CHECK-DAG:       %[[#ssu:]] = OpTypeStruct %[[#su]]
+; CHECK-DAG:        %[[#sfuf:]] = OpTypeStruct %[[#float]] %[[#uint]] %[[#float]]
+; CHECK-DAG:        %[[#uint4:]] = OpTypeVector %[[#uint]] 4
+; CHECK-DAG:        %[[#sv:]] = OpTypeStruct %[[#uint4]]
+; CHECK-DAG:        %[[#ssv:]] = OpTypeStruct %[[#sv]]
+; CHECK-DAG:        %[[#assv:]] = OpTypeArray %[[#ssv]] %[[#uint_4]]
+; CHECK-DAG:        %[[#sassv:]] = OpTypeStruct %[[#assv]]
+; CHECK-DAG:        %[[#ssassv:]] = OpTypeStruct %[[#sassv]]
+; CHECK-DAG:     %[[#sf_fp:]] = OpTypePointer Function %[[#sf]]
+; CHECK-DAG:     %[[#su_fp:]] = OpTypePointer Function %[[#su]]
+; CHECK-DAG:    %[[#ssu_fp:]] = OpTypePointer Function %[[#ssu]]
+; CHECK-DAG:    %[[#ssv_fp:]] = OpTypePointer Function %[[#ssv]]
+; CHECK-DAG: %[[#ssassv_fp:]] = OpTypePointer Function %[[#ssassv]]
+; CHECK-DAG:   %[[#sfuf_fp:]] = OpTypePointer Function %[[#sfuf]]
+; CHECK-DAG:   %[[#sfuf_pp:]] = OpTypePointer Private %[[#sfuf]]
 
 %struct.SF = type { float }
 %struct.SU = type { i32 }
 %struct.SFUF = type { float, i32, float }
+%struct.SSU = type { %struct.SU }
+%struct.SV = type { <4 x i32> }
+%struct.SSV = type { %struct.SV }
+%struct.SASSV = type { [4 x %struct.SSV] }
+%struct.SSASSV = type { %struct.SASSV }
 
 @gsfuf = external addrspace(10) global %struct.SFUF
-; CHECK: %[[#gsfuf:]] = OpVariable %[[#sfuf_pp]] Private
+; CHECK-DAG: %[[#gsfuf:]] = OpVariable %[[#sfuf_pp]] Private
 
 define internal spir_func void @foo() {
   %1 = alloca %struct.SF, align 4
@@ -64,3 +80,35 @@ define internal spir_func void @biz() {
   ret void
 }
 
+define internal spir_func void @nested_store() {
+  %1 = alloca %struct.SSU, align 4
+; CHECK: %[[#var:]]  = OpVariable %[[#ssu_fp]] Function
+
+  store i32 0, ptr %1, align 4
+; CHECK: %[[#tmp:]]  = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]] %[[#uint_0]]
+; CHECK:               OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
+
+  ret void
+}
+
+define internal spir_func void @nested_store_vector() {
+  %1 = alloca %struct.SSV, align 4
+; CHECK: %[[#var:]]  = OpVariable %[[#ssv_fp]] Function
+
+  store i32 0, ptr %1, align 4
+; CHECK: %[[#tmp:]]  = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]]
+; CHECK:               OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
+
+  ret void
+}
+
+define internal spir_func void @nested_array_vector() {
+  %1 = alloca %struct.SSASSV, align 4
+; CHECK: %[[#var:]]  = OpVariable %[[#ssassv_fp]] Function
+
+  store i32 0, ptr %1, align 4
+; CHECK: %[[#tmp:]]  = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]]
+; CHECK:               OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
+
+  ret void
+}



More information about the llvm-commits mailing list