[llvm] [SPIR-V] Add store legalization for ptrcast (PR #135369)
Nathan Gauër via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 16 04:47:38 PDT 2025
https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/135369
>From c77cfe26f614de8d4a31d905e443ff5980408768 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Thu, 10 Apr 2025 13:12:55 +0200
Subject: [PATCH 1/2] [SPIR-V] Add store legalization for ptrcast
This commit adds handling for an spv.ptrcast result being
used in a store instruction, modifying the store to operate on
the source type.
---
.../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 104 +++++++++++++++++
.../pointers/getelementptr-downcast-struct.ll | 20 ++++
.../pointers/getelementptr-downcast-vector.ll | 110 ++++++++++++++++++
.../CodeGen/SPIRV/pointers/store-struct.ll | 66 +++++++++++
4 files changed, 300 insertions(+)
create mode 100644 llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 5ba4fbb02560d..f3f1558265d4a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -150,6 +150,95 @@ class SPIRVLegalizePointerCast : public FunctionPass {
DeadInstructions.push_back(LI);
}
+ // Creates an spv_insertelt instruction (equivalent to llvm's insertelement).
+ Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,
+ unsigned Index) {
+ Type *Int32Ty = Type::getInt32Ty(B.getContext());
+ SmallVector<Type *, 4> Types = {Vector->getType(), Vector->getType(),
+ Element->getType(), Int32Ty};
+ SmallVector<Value *> Args = {Vector, Element, B.getInt32(Index)};
+ Instruction *NewI =
+ B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
+ buildAssignType(B, Vector->getType(), NewI);
+ return NewI;
+ }
+
+ // Creates an spv_extractelt instruction (equivalent to llvm's
+ // extractelement).
+ Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,
+ unsigned Index) {
+ Type *Int32Ty = Type::getInt32Ty(B.getContext());
+ SmallVector<Type *, 3> Types = {ElementType, Vector->getType(), Int32Ty};
+ SmallVector<Value *> Args = {Vector, B.getInt32(Index)};
+ Instruction *NewI =
+ B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
+ buildAssignType(B, ElementType, NewI);
+ return NewI;
+ }
+
+ // Stores the given Src vector operand into the Dst vector, adjusting the size
+ // if required.
+ Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,
+ Align Alignment) {
+ FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
+ FixedVectorType *DstType =
+ cast<FixedVectorType>(GR->findDeducedElementType(Dst));
+ assert(DstType->getNumElements() >= SrcType->getNumElements());
+
+ LoadInst *LI = B.CreateLoad(DstType, Dst);
+ LI->setAlignment(Alignment);
+ Value *OldValues = LI;
+ buildAssignType(B, OldValues->getType(), OldValues);
+ Value *NewValues = Src;
+
+ for (unsigned I = 0; I < SrcType->getNumElements(); ++I) {
+ Value *Element =
+ makeExtractElement(B, SrcType->getElementType(), NewValues, I);
+ OldValues = makeInsertElement(B, OldValues, Element, I);
+ }
+
+ StoreInst *SI = B.CreateStore(OldValues, Dst);
+ SI->setAlignment(Alignment);
+ return SI;
+ }
+
+ // Stores the given Src value into the first entry of the Dst aggregate.
+ Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
+ Align Alignment) {
+ SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
+ SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst,
+ B.getInt32(0), B.getInt32(0)};
+ auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+ GR->buildAssignPtr(B, Src->getType(), GEP);
+ StoreInst *SI = B.CreateStore(Src, GEP);
+ SI->setAlignment(Alignment);
+ return SI;
+ }
+
+ // Transforms a store instruction (or SPV intrinsic) using a ptrcast as
+ // operand into a valid logical SPIR-V store with no ptrcast.
+ void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
+ Value *Dst, Align Alignment) {
+ Type *ToTy = GR->findDeducedElementType(Dst);
+ Type *FromTy = Src->getType();
+
+ auto *SVT = dyn_cast<FixedVectorType>(FromTy);
+ auto *DST = dyn_cast<StructType>(ToTy);
+ auto *DVT = dyn_cast<FixedVectorType>(ToTy);
+
+ B.SetInsertPoint(BadStore);
+ if (DST && DST->getTypeAtIndex(0u) == FromTy)
+ storeToFirstValueAggregate(B, Src, Dst, Alignment);
+ else if (DVT && SVT)
+ storeVectorFromVector(B, Src, Dst, Alignment);
+ else if (DVT && !SVT && FromTy == DVT->getElementType())
+ storeToFirstValueAggregate(B, Src, Dst, Alignment);
+ else
+ llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
+
+ DeadInstructions.push_back(BadStore);
+ }
+
void legalizePointerCast(IntrinsicInst *II) {
Value *CastedOperand = II;
Value *OriginalOperand = II->getOperand(0);
@@ -165,6 +254,12 @@ class SPIRVLegalizePointerCast : public FunctionPass {
continue;
}
+ if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
+ transformStore(B, SI, SI->getValueOperand(), OriginalOperand,
+ SI->getAlign());
+ continue;
+ }
+
if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(User)) {
if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {
DeadInstructions.push_back(Intrin);
@@ -176,6 +271,15 @@ class SPIRVLegalizePointerCast : public FunctionPass {
/* DeleteOld= */ false);
continue;
}
+
+ if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {
+ Align Alignment;
+ if (ConstantInt *C = dyn_cast<ConstantInt>(Intrin->getOperand(3)))
+ Alignment = Align(C->getZExtValue());
+ transformStore(B, Intrin, Intrin->getArgOperand(0), OriginalOperand,
+ Alignment);
+ continue;
+ }
}
llvm_unreachable("Unsupported ptrcast user. Please fix.");
diff --git a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
index b0a68a30e29be..35e5880881e5c 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
@@ -45,3 +45,23 @@ entry:
%val = load i32, ptr addrspace(10) %ptr
ret i32 %val
}
+
+define spir_func void @foos(i64 noundef %index) local_unnamed_addr {
+; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
+entry:
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global1]] %[[#uint_0]] %[[#index]]
+ %ptr = getelementptr inbounds %S1, ptr addrspace(10) @global1, i64 0, i32 0, i64 %index
+; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
+ store i32 0, ptr addrspace(10) %ptr
+ ret void
+}
+
+define spir_func void @bars(i64 noundef %index) local_unnamed_addr {
+; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
+entry:
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global2]] %[[#uint_0]] %[[#uint_0]] %[[#index]] %[[#uint_1]]
+ %ptr = getelementptr inbounds %S2, ptr addrspace(10) @global2, i64 0, i32 0, i32 0, i64 %index, i32 1
+; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
+ store i32 0, ptr addrspace(10) %ptr
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
index d4131fa8a2658..be9e2a23365cc 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
@@ -5,9 +5,13 @@
; CHECK-DAG: %[[#uint_pp:]] = OpTypePointer Private %[[#uint]]
; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG: %[[#uint_1:]] = OpConstant %[[#uint]] 1
+; CHECK-DAG: %[[#uint_2:]] = OpConstant %[[#uint]] 2
; CHECK-DAG: %[[#v2:]] = OpTypeVector %[[#uint]] 2
; CHECK-DAG: %[[#v3:]] = OpTypeVector %[[#uint]] 3
; CHECK-DAG: %[[#v4:]] = OpTypeVector %[[#uint]] 4
+; CHECK-DAG: %[[#v2_01:]] = OpConstantComposite %[[#v2]] %[[#uint_0]] %[[#uint_1]]
+; CHECK-DAG: %[[#v3_012:]] = OpConstantComposite %[[#v3]] %[[#uint_0]] %[[#uint_1]] %[[#uint_2]]
; CHECK-DAG: %[[#v4_pp:]] = OpTypePointer Private %[[#v4]]
; CHECK-DAG: %[[#v4_fp:]] = OpTypePointer Function %[[#v4]]
@@ -108,3 +112,109 @@ define internal spir_func i32 @bazBounds(ptr %a) {
ret i32 %2
; CHECK: OpReturnValue %[[#val]]
}
+
+define internal spir_func void @foos(ptr addrspace(10) %a) {
+
+ %1 = getelementptr inbounds <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_pp]] %[[#]]
+
+ store <3 x i32> <i32 0, i32 1, i32 2>, ptr addrspace(10) %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16
+
+ ret void
+}
+
+define internal spir_func void @foosDefault(ptr %a) {
+
+ %1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]
+
+ store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16
+
+ ret void
+}
+
+define internal spir_func void @foosBounds(ptr %a) {
+
+ %1 = getelementptr <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_fp]] %[[#]]
+
+ store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 64
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 64
+; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 64
+
+ ret void
+}
+
+define internal spir_func void @bars(ptr addrspace(10) %a) {
+
+ %1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_pp]] %[[#]]
+
+ store <2 x i32> <i32 0, i32 1>, ptr addrspace(10) %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK: OpStore %[[#ptr]] %[[#out2]] Aligned 16
+
+ ret void
+}
+
+define internal spir_func void @bazs(ptr addrspace(10) %a) {
+
+ %1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_pp]] %[[#]]
+
+ store i32 0, ptr addrspace(10) %1, align 32
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 32
+
+ ret void
+}
+
+define internal spir_func void @bazsDefault(ptr %a) {
+
+ %1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]
+
+ store i32 0, ptr %1, align 16
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16
+
+ ret void
+}
+
+define internal spir_func void @bazsBounds(ptr %a) {
+
+ %1 = getelementptr <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_fp]] %[[#]]
+
+ store i32 0, ptr %1, align 16
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16
+
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll b/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
new file mode 100644
index 0000000000000..7d2c1093f0a71
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
@@ -0,0 +1,66 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_fp:]] = OpTypePointer Function %[[#float]]
+; CHECK-DAG: %[[#float_pp:]] = OpTypePointer Private %[[#float]]
+; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
+; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG: %[[#float_0:]] = OpConstant %[[#float]] 0
+; CHECK-DAG: %[[#sf:]] = OpTypeStruct %[[#float]]
+; CHECK-DAG: %[[#su:]] = OpTypeStruct %[[#uint]]
+; CHECK-DAG: %[[#sfuf:]] = OpTypeStruct %[[#float]] %[[#uint]] %[[#float]]
+; CHECK-DAG: %[[#sf_fp:]] = OpTypePointer Function %[[#sf]]
+; CHECK-DAG: %[[#su_fp:]] = OpTypePointer Function %[[#su]]
+; CHECK-DAG: %[[#sfuf_fp:]] = OpTypePointer Function %[[#sfuf]]
+; CHECK-DAG: %[[#sfuf_pp:]] = OpTypePointer Private %[[#sfuf]]
+
+%struct.SF = type { float }
+%struct.SU = type { i32 }
+%struct.SFUF = type { float, i32, float }
+
+@gsfuf = external addrspace(10) global %struct.SFUF
+; CHECK: %[[#gsfuf:]] = OpVariable %[[#sfuf_pp]] Private
+
+define internal spir_func void @foo() {
+ %1 = alloca %struct.SF, align 4
+; CHECK: %[[#var:]] = OpVariable %[[#sf_fp]] Function
+
+ store float 0.0, ptr %1, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#float_fp]] %[[#var]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+ ret void
+}
+
+define internal spir_func void @bar() {
+ %1 = alloca %struct.SU, align 4
+; CHECK: %[[#var:]] = OpVariable %[[#su_fp]] Function
+
+ store i32 0, ptr %1, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
+
+ ret void
+}
+
+define internal spir_func void @baz() {
+ %1 = alloca %struct.SFUF, align 4
+; CHECK: %[[#var:]] = OpVariable %[[#sfuf_fp]] Function
+
+ store float 0.0, ptr %1, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#float_fp]] %[[#var]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+ ret void
+}
+
+define internal spir_func void @biz() {
+ store float 0.0, ptr addrspace(10) @gsfuf, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#float_pp]] %[[#gsfuf]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+ ret void
+}
+
>From 414e576aa203aa0f464948161a8af05c2567a4ed Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Wed, 16 Apr 2025 13:46:39 +0200
Subject: [PATCH 2/2] support nested types
---
.../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 51 +++++++++---
.../CodeGen/SPIRV/pointers/store-struct.ll | 78 +++++++++++++++----
2 files changed, 103 insertions(+), 26 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index f3f1558265d4a..e55a6be18c944 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -202,12 +202,29 @@ class SPIRVLegalizePointerCast : public FunctionPass {
return SI;
}
+ void buildGEPIndexChain(IRBuilder<> &B, Type *Search, Type *Aggregate,
+ SmallVectorImpl<Value *> &Indices) {
+ Indices.push_back(B.getInt32(0));
+
+ if (Search == Aggregate)
+ return;
+
+ if (auto *ST = dyn_cast<StructType>(Aggregate))
+ buildGEPIndexChain(B, Search, ST->getTypeAtIndex(0u), Indices);
+ else if (auto *AT = dyn_cast<ArrayType>(Aggregate))
+ buildGEPIndexChain(B, Search, AT->getElementType(), Indices);
+ else if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
+ buildGEPIndexChain(B, Search, VT->getElementType(), Indices);
+ else
+ llvm_unreachable("Bad access chain?");
+ }
+
// Stores the given Src value into the first entry of the Dst aggregate.
Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
- Align Alignment) {
+ Type *DstPointeeType, Align Alignment) {
SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
- SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst,
- B.getInt32(0), B.getInt32(0)};
+ SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst};
+ buildGEPIndexChain(B, Src->getType(), DstPointeeType, Args);
auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
GR->buildAssignPtr(B, Src->getType(), GEP);
StoreInst *SI = B.CreateStore(Src, GEP);
@@ -215,6 +232,18 @@ class SPIRVLegalizePointerCast : public FunctionPass {
return SI;
}
+ bool isTypeFirstElementAggregate(Type *Search, Type *Aggregate) {
+ if (Search == Aggregate)
+ return true;
+ if (auto *ST = dyn_cast<StructType>(Aggregate))
+ return isTypeFirstElementAggregate(Search, ST->getTypeAtIndex(0u));
+ if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
+ return isTypeFirstElementAggregate(Search, VT->getElementType());
+ if (auto *AT = dyn_cast<ArrayType>(Aggregate))
+ return isTypeFirstElementAggregate(Search, AT->getElementType());
+ return false;
+ }
+
// Transforms a store instruction (or SPV intrinsic) using a ptrcast as
// operand into a valid logical SPIR-V store with no ptrcast.
void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
@@ -222,17 +251,17 @@ class SPIRVLegalizePointerCast : public FunctionPass {
Type *ToTy = GR->findDeducedElementType(Dst);
Type *FromTy = Src->getType();
- auto *SVT = dyn_cast<FixedVectorType>(FromTy);
- auto *DST = dyn_cast<StructType>(ToTy);
- auto *DVT = dyn_cast<FixedVectorType>(ToTy);
+ auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
+ auto *D_ST = dyn_cast<StructType>(ToTy);
+ auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
B.SetInsertPoint(BadStore);
- if (DST && DST->getTypeAtIndex(0u) == FromTy)
- storeToFirstValueAggregate(B, Src, Dst, Alignment);
- else if (DVT && SVT)
+ if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST))
+ storeToFirstValueAggregate(B, Src, Dst, D_ST, Alignment);
+ else if (D_VT && S_VT)
storeVectorFromVector(B, Src, Dst, Alignment);
- else if (DVT && !SVT && FromTy == DVT->getElementType())
- storeToFirstValueAggregate(B, Src, Dst, Alignment);
+ else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
+ storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
else
llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
diff --git a/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll b/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
index 7d2c1093f0a71..1ff53614990d5 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
@@ -1,27 +1,43 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
-; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
-; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
-; CHECK-DAG: %[[#float_fp:]] = OpTypePointer Function %[[#float]]
-; CHECK-DAG: %[[#float_pp:]] = OpTypePointer Private %[[#float]]
-; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
-; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
-; CHECK-DAG: %[[#float_0:]] = OpConstant %[[#float]] 0
-; CHECK-DAG: %[[#sf:]] = OpTypeStruct %[[#float]]
-; CHECK-DAG: %[[#su:]] = OpTypeStruct %[[#uint]]
-; CHECK-DAG: %[[#sfuf:]] = OpTypeStruct %[[#float]] %[[#uint]] %[[#float]]
-; CHECK-DAG: %[[#sf_fp:]] = OpTypePointer Function %[[#sf]]
-; CHECK-DAG: %[[#su_fp:]] = OpTypePointer Function %[[#su]]
-; CHECK-DAG: %[[#sfuf_fp:]] = OpTypePointer Function %[[#sfuf]]
-; CHECK-DAG: %[[#sfuf_pp:]] = OpTypePointer Private %[[#sfuf]]
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_fp:]] = OpTypePointer Function %[[#float]]
+; CHECK-DAG: %[[#float_pp:]] = OpTypePointer Private %[[#float]]
+; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
+; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG: %[[#uint_4:]] = OpConstant %[[#uint]] 4
+; CHECK-DAG: %[[#float_0:]] = OpConstant %[[#float]] 0
+; CHECK-DAG: %[[#sf:]] = OpTypeStruct %[[#float]]
+; CHECK-DAG: %[[#su:]] = OpTypeStruct %[[#uint]]
+; CHECK-DAG: %[[#ssu:]] = OpTypeStruct %[[#su]]
+; CHECK-DAG: %[[#sfuf:]] = OpTypeStruct %[[#float]] %[[#uint]] %[[#float]]
+; CHECK-DAG: %[[#uint4:]] = OpTypeVector %[[#uint]] 4
+; CHECK-DAG: %[[#sv:]] = OpTypeStruct %[[#uint4]]
+; CHECK-DAG: %[[#ssv:]] = OpTypeStruct %[[#sv]]
+; CHECK-DAG: %[[#assv:]] = OpTypeArray %[[#ssv]] %[[#uint_4]]
+; CHECK-DAG: %[[#sassv:]] = OpTypeStruct %[[#assv]]
+; CHECK-DAG: %[[#ssassv:]] = OpTypeStruct %[[#sassv]]
+; CHECK-DAG: %[[#sf_fp:]] = OpTypePointer Function %[[#sf]]
+; CHECK-DAG: %[[#su_fp:]] = OpTypePointer Function %[[#su]]
+; CHECK-DAG: %[[#ssu_fp:]] = OpTypePointer Function %[[#ssu]]
+; CHECK-DAG: %[[#ssv_fp:]] = OpTypePointer Function %[[#ssv]]
+; CHECK-DAG: %[[#ssassv_fp:]] = OpTypePointer Function %[[#ssassv]]
+; CHECK-DAG: %[[#sfuf_fp:]] = OpTypePointer Function %[[#sfuf]]
+; CHECK-DAG: %[[#sfuf_pp:]] = OpTypePointer Private %[[#sfuf]]
%struct.SF = type { float }
%struct.SU = type { i32 }
%struct.SFUF = type { float, i32, float }
+%struct.SSU = type { %struct.SU }
+%struct.SV = type { <4 x i32> }
+%struct.SSV = type { %struct.SV }
+%struct.SASSV = type { [4 x %struct.SSV] }
+%struct.SSASSV = type { %struct.SASSV }
@gsfuf = external addrspace(10) global %struct.SFUF
-; CHECK: %[[#gsfuf:]] = OpVariable %[[#sfuf_pp]] Private
+; CHECK-DAG: %[[#gsfuf:]] = OpVariable %[[#sfuf_pp]] Private
define internal spir_func void @foo() {
%1 = alloca %struct.SF, align 4
@@ -64,3 +80,35 @@ define internal spir_func void @biz() {
ret void
}
+define internal spir_func void @nested_store() {
+ %1 = alloca %struct.SSU, align 4
+; CHECK: %[[#var:]] = OpVariable %[[#ssu_fp]] Function
+
+ store i32 0, ptr %1, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
+
+ ret void
+}
+
+define internal spir_func void @nested_store_vector() {
+ %1 = alloca %struct.SSV, align 4
+; CHECK: %[[#var:]] = OpVariable %[[#ssv_fp]] Function
+
+ store i32 0, ptr %1, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
+
+ ret void
+}
+
+define internal spir_func void @nested_array_vector() {
+ %1 = alloca %struct.SSASSV, align 4
+; CHECK: %[[#var:]] = OpVariable %[[#ssassv_fp]] Function
+
+ store i32 0, ptr %1, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
+
+ ret void
+}
More information about the llvm-commits
mailing list