[llvm] [SPIR-V] Add store legalization for ptrcast (PR #135369)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 11 06:58:17 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Nathan Gauër (Keenuts)
<details>
<summary>Changes</summary>
This commit adds handling for an spv.ptrcast result being used in a store instruction, modifying the store to operate on the source type.
---
Full diff: https://github.com/llvm/llvm-project/pull/135369.diff
4 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp (+104)
- (modified) llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll (+20)
- (modified) llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll (+110)
- (added) llvm/test/CodeGen/SPIRV/pointers/store-struct.ll (+66)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 5ba4fbb02560d..f3f1558265d4a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -150,6 +150,95 @@ class SPIRVLegalizePointerCast : public FunctionPass {
DeadInstructions.push_back(LI);
}
+ // Creates an spv_insertelt instruction (equivalent to llvm's insertelement).
+ Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,
+ unsigned Index) {
+ Type *Int32Ty = Type::getInt32Ty(B.getContext());
+ SmallVector<Type *, 4> Types = {Vector->getType(), Vector->getType(),
+ Element->getType(), Int32Ty};
+ SmallVector<Value *> Args = {Vector, Element, B.getInt32(Index)};
+ Instruction *NewI =
+ B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
+ buildAssignType(B, Vector->getType(), NewI);
+ return NewI;
+ }
+
+ // Creates an spv_extractelt instruction (equivalent to llvm's
+ // extractelement).
+ Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,
+ unsigned Index) {
+ Type *Int32Ty = Type::getInt32Ty(B.getContext());
+ SmallVector<Type *, 3> Types = {ElementType, Vector->getType(), Int32Ty};
+ SmallVector<Value *> Args = {Vector, B.getInt32(Index)};
+ Instruction *NewI =
+ B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
+ buildAssignType(B, ElementType, NewI);
+ return NewI;
+ }
+
+ // Stores the given Src vector operand into the Dst vector, adjusting the size
+ // if required.
+ Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,
+ Align Alignment) {
+ FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
+ FixedVectorType *DstType =
+ cast<FixedVectorType>(GR->findDeducedElementType(Dst));
+ assert(DstType->getNumElements() >= SrcType->getNumElements());
+
+ LoadInst *LI = B.CreateLoad(DstType, Dst);
+ LI->setAlignment(Alignment);
+ Value *OldValues = LI;
+ buildAssignType(B, OldValues->getType(), OldValues);
+ Value *NewValues = Src;
+
+ for (unsigned I = 0; I < SrcType->getNumElements(); ++I) {
+ Value *Element =
+ makeExtractElement(B, SrcType->getElementType(), NewValues, I);
+ OldValues = makeInsertElement(B, OldValues, Element, I);
+ }
+
+ StoreInst *SI = B.CreateStore(OldValues, Dst);
+ SI->setAlignment(Alignment);
+ return SI;
+ }
+
+ // Stores the given Src value into the first entry of the Dst aggregate.
+ Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
+ Align Alignment) {
+ SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
+ SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst,
+ B.getInt32(0), B.getInt32(0)};
+ auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+ GR->buildAssignPtr(B, Src->getType(), GEP);
+ StoreInst *SI = B.CreateStore(Src, GEP);
+ SI->setAlignment(Alignment);
+ return SI;
+ }
+
+ // Transforms a store instruction (or SPV intrinsic) using a ptrcast as
+ // operand into a valid logical SPIR-V store with no ptrcast.
+ void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
+ Value *Dst, Align Alignment) {
+ Type *ToTy = GR->findDeducedElementType(Dst);
+ Type *FromTy = Src->getType();
+
+ auto *SVT = dyn_cast<FixedVectorType>(FromTy);
+ auto *DST = dyn_cast<StructType>(ToTy);
+ auto *DVT = dyn_cast<FixedVectorType>(ToTy);
+
+ B.SetInsertPoint(BadStore);
+ if (DST && DST->getTypeAtIndex(0u) == FromTy)
+ storeToFirstValueAggregate(B, Src, Dst, Alignment);
+ else if (DVT && SVT)
+ storeVectorFromVector(B, Src, Dst, Alignment);
+ else if (DVT && !SVT && FromTy == DVT->getElementType())
+ storeToFirstValueAggregate(B, Src, Dst, Alignment);
+ else
+ llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
+
+ DeadInstructions.push_back(BadStore);
+ }
+
void legalizePointerCast(IntrinsicInst *II) {
Value *CastedOperand = II;
Value *OriginalOperand = II->getOperand(0);
@@ -165,6 +254,12 @@ class SPIRVLegalizePointerCast : public FunctionPass {
continue;
}
+ if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
+ transformStore(B, SI, SI->getValueOperand(), OriginalOperand,
+ SI->getAlign());
+ continue;
+ }
+
if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(User)) {
if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {
DeadInstructions.push_back(Intrin);
@@ -176,6 +271,15 @@ class SPIRVLegalizePointerCast : public FunctionPass {
/* DeleteOld= */ false);
continue;
}
+
+ if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {
+ Align Alignment;
+ if (ConstantInt *C = dyn_cast<ConstantInt>(Intrin->getOperand(3)))
+ Alignment = Align(C->getZExtValue());
+ transformStore(B, Intrin, Intrin->getArgOperand(0), OriginalOperand,
+ Alignment);
+ continue;
+ }
}
llvm_unreachable("Unsupported ptrcast user. Please fix.");
diff --git a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
index b0a68a30e29be..35e5880881e5c 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
@@ -45,3 +45,23 @@ entry:
%val = load i32, ptr addrspace(10) %ptr
ret i32 %val
}
+
+define spir_func void @foos(i64 noundef %index) local_unnamed_addr {
+; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
+entry:
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global1]] %[[#uint_0]] %[[#index]]
+ %ptr = getelementptr inbounds %S1, ptr addrspace(10) @global1, i64 0, i32 0, i64 %index
+; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
+ store i32 0, ptr addrspace(10) %ptr
+ ret void
+}
+
+define spir_func void @bars(i64 noundef %index) local_unnamed_addr {
+; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
+entry:
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global2]] %[[#uint_0]] %[[#uint_0]] %[[#index]] %[[#uint_1]]
+ %ptr = getelementptr inbounds %S2, ptr addrspace(10) @global2, i64 0, i32 0, i32 0, i64 %index, i32 1
+; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
+ store i32 0, ptr addrspace(10) %ptr
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
index d4131fa8a2658..be9e2a23365cc 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
@@ -5,9 +5,13 @@
; CHECK-DAG: %[[#uint_pp:]] = OpTypePointer Private %[[#uint]]
; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG: %[[#uint_1:]] = OpConstant %[[#uint]] 1
+; CHECK-DAG: %[[#uint_2:]] = OpConstant %[[#uint]] 2
; CHECK-DAG: %[[#v2:]] = OpTypeVector %[[#uint]] 2
; CHECK-DAG: %[[#v3:]] = OpTypeVector %[[#uint]] 3
; CHECK-DAG: %[[#v4:]] = OpTypeVector %[[#uint]] 4
+; CHECK-DAG: %[[#v2_01:]] = OpConstantComposite %[[#v2]] %[[#uint_0]] %[[#uint_1]]
+; CHECK-DAG: %[[#v3_012:]] = OpConstantComposite %[[#v3]] %[[#uint_0]] %[[#uint_1]] %[[#uint_2]]
; CHECK-DAG: %[[#v4_pp:]] = OpTypePointer Private %[[#v4]]
; CHECK-DAG: %[[#v4_fp:]] = OpTypePointer Function %[[#v4]]
@@ -108,3 +112,109 @@ define internal spir_func i32 @bazBounds(ptr %a) {
ret i32 %2
; CHECK: OpReturnValue %[[#val]]
}
+
+define internal spir_func void @foos(ptr addrspace(10) %a) {
+
+ %1 = getelementptr inbounds <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_pp]] %[[#]]
+
+ store <3 x i32> <i32 0, i32 1, i32 2>, ptr addrspace(10) %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16
+
+ ret void
+}
+
+define internal spir_func void @foosDefault(ptr %a) {
+
+ %1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]
+
+ store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16
+
+ ret void
+}
+
+define internal spir_func void @foosBounds(ptr %a) {
+
+ %1 = getelementptr <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_fp]] %[[#]]
+
+ store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 64
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 64
+; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 64
+
+ ret void
+}
+
+define internal spir_func void @bars(ptr addrspace(10) %a) {
+
+ %1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_pp]] %[[#]]
+
+ store <2 x i32> <i32 0, i32 1>, ptr addrspace(10) %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK: OpStore %[[#ptr]] %[[#out2]] Aligned 1
+
+ ret void
+}
+
+define internal spir_func void @bazs(ptr addrspace(10) %a) {
+
+ %1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_pp]] %[[#]]
+
+ store i32 0, ptr addrspace(10) %1, align 32
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 32
+
+ ret void
+}
+
+define internal spir_func void @bazsDefault(ptr %a) {
+
+ %1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]
+
+ store i32 0, ptr %1, align 16
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16
+
+ ret void
+}
+
+define internal spir_func void @bazsBounds(ptr %a) {
+
+ %1 = getelementptr <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_fp]] %[[#]]
+
+ store i32 0, ptr %1, align 16
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16
+
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll b/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
new file mode 100644
index 0000000000000..7d2c1093f0a71
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
@@ -0,0 +1,66 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_fp:]] = OpTypePointer Function %[[#float]]
+; CHECK-DAG: %[[#float_pp:]] = OpTypePointer Private %[[#float]]
+; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
+; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG: %[[#float_0:]] = OpConstant %[[#float]] 0
+; CHECK-DAG: %[[#sf:]] = OpTypeStruct %[[#float]]
+; CHECK-DAG: %[[#su:]] = OpTypeStruct %[[#uint]]
+; CHECK-DAG: %[[#sfuf:]] = OpTypeStruct %[[#float]] %[[#uint]] %[[#float]]
+; CHECK-DAG: %[[#sf_fp:]] = OpTypePointer Function %[[#sf]]
+; CHECK-DAG: %[[#su_fp:]] = OpTypePointer Function %[[#su]]
+; CHECK-DAG: %[[#sfuf_fp:]] = OpTypePointer Function %[[#sfuf]]
+; CHECK-DAG: %[[#sfuf_pp:]] = OpTypePointer Private %[[#sfuf]]
+
+%struct.SF = type { float }
+%struct.SU = type { i32 }
+%struct.SFUF = type { float, i32, float }
+
+@gsfuf = external addrspace(10) global %struct.SFUF
+; CHECK: %[[#gsfuf:]] = OpVariable %[[#sfuf_pp]] Private
+
+define internal spir_func void @foo() {
+ %1 = alloca %struct.SF, align 4
+; CHECK: %[[#var:]] = OpVariable %[[#sf_fp]] Function
+
+ store float 0.0, ptr %1, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#float_fp]] %[[#var]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+ ret void
+}
+
+define internal spir_func void @bar() {
+ %1 = alloca %struct.SU, align 4
+; CHECK: %[[#var:]] = OpVariable %[[#su_fp]] Function
+
+ store i32 0, ptr %1, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
+
+ ret void
+}
+
+define internal spir_func void @baz() {
+ %1 = alloca %struct.SFUF, align 4
+; CHECK: %[[#var:]] = OpVariable %[[#sfuf_fp]] Function
+
+ store float 0.0, ptr %1, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#float_fp]] %[[#var]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+ ret void
+}
+
+define internal spir_func void @biz() {
+ store float 0.0, ptr addrspace(10) @gsfuf, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#float_pp]] %[[#gsfuf]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+ ret void
+}
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/135369
More information about the llvm-commits
mailing list