[llvm] [SPIR-V] Add store legalization for ptrcast (PR #135369)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 11 06:58:17 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Nathan Gauër (Keenuts)

<details>
<summary>Changes</summary>

This commit adds handling for an spv.ptrcast result being used in a store instruction, modifying the store to operate on the source type.

---
Full diff: https://github.com/llvm/llvm-project/pull/135369.diff


4 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp (+104) 
- (modified) llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll (+20) 
- (modified) llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll (+110) 
- (added) llvm/test/CodeGen/SPIRV/pointers/store-struct.ll (+66) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 5ba4fbb02560d..f3f1558265d4a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -150,6 +150,95 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     DeadInstructions.push_back(LI);
   }
 
+  // Creates an spv_insertelt instruction (equivalent to llvm's insertelement).
+  Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,
+                           unsigned Index) {
+    Type *Int32Ty = Type::getInt32Ty(B.getContext());
+    SmallVector<Type *, 4> Types = {Vector->getType(), Vector->getType(),
+                                    Element->getType(), Int32Ty};
+    SmallVector<Value *> Args = {Vector, Element, B.getInt32(Index)};
+    Instruction *NewI =
+        B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
+    buildAssignType(B, Vector->getType(), NewI);
+    return NewI;
+  }
+
+  // Creates an spv_extractelt instruction (equivalent to llvm's
+  // extractelement).
+  Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,
+                            unsigned Index) {
+    Type *Int32Ty = Type::getInt32Ty(B.getContext());
+    SmallVector<Type *, 3> Types = {ElementType, Vector->getType(), Int32Ty};
+    SmallVector<Value *> Args = {Vector, B.getInt32(Index)};
+    Instruction *NewI =
+        B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
+    buildAssignType(B, ElementType, NewI);
+    return NewI;
+  }
+
+  // Stores the given Src vector operand into the Dst vector, adjusting the size
+  // if required.
+  Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,
+                               Align Alignment) {
+    FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
+    FixedVectorType *DstType =
+        cast<FixedVectorType>(GR->findDeducedElementType(Dst));
+    assert(DstType->getNumElements() >= SrcType->getNumElements());
+
+    LoadInst *LI = B.CreateLoad(DstType, Dst);
+    LI->setAlignment(Alignment);
+    Value *OldValues = LI;
+    buildAssignType(B, OldValues->getType(), OldValues);
+    Value *NewValues = Src;
+
+    for (unsigned I = 0; I < SrcType->getNumElements(); ++I) {
+      Value *Element =
+          makeExtractElement(B, SrcType->getElementType(), NewValues, I);
+      OldValues = makeInsertElement(B, OldValues, Element, I);
+    }
+
+    StoreInst *SI = B.CreateStore(OldValues, Dst);
+    SI->setAlignment(Alignment);
+    return SI;
+  }
+
+  // Stores the given Src value into the first entry of the Dst aggregate.
+  Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
+                                    Align Alignment) {
+    SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
+    SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst,
+                                 B.getInt32(0), B.getInt32(0)};
+    auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+    GR->buildAssignPtr(B, Src->getType(), GEP);
+    StoreInst *SI = B.CreateStore(Src, GEP);
+    SI->setAlignment(Alignment);
+    return SI;
+  }
+
+  // Transforms a store instruction (or SPV intrinsic) using a ptrcast as
+  // operand into a valid logical SPIR-V store with no ptrcast.
+  void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
+                      Value *Dst, Align Alignment) {
+    Type *ToTy = GR->findDeducedElementType(Dst);
+    Type *FromTy = Src->getType();
+
+    auto *SVT = dyn_cast<FixedVectorType>(FromTy);
+    auto *DST = dyn_cast<StructType>(ToTy);
+    auto *DVT = dyn_cast<FixedVectorType>(ToTy);
+
+    B.SetInsertPoint(BadStore);
+    if (DST && DST->getTypeAtIndex(0u) == FromTy)
+      storeToFirstValueAggregate(B, Src, Dst, Alignment);
+    else if (DVT && SVT)
+      storeVectorFromVector(B, Src, Dst, Alignment);
+    else if (DVT && !SVT && FromTy == DVT->getElementType())
+      storeToFirstValueAggregate(B, Src, Dst, Alignment);
+    else
+      llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
+
+    DeadInstructions.push_back(BadStore);
+  }
+
   void legalizePointerCast(IntrinsicInst *II) {
     Value *CastedOperand = II;
     Value *OriginalOperand = II->getOperand(0);
@@ -165,6 +254,12 @@ class SPIRVLegalizePointerCast : public FunctionPass {
         continue;
       }
 
+      if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
+        transformStore(B, SI, SI->getValueOperand(), OriginalOperand,
+                       SI->getAlign());
+        continue;
+      }
+
       if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(User)) {
         if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {
           DeadInstructions.push_back(Intrin);
@@ -176,6 +271,15 @@ class SPIRVLegalizePointerCast : public FunctionPass {
                                  /* DeleteOld= */ false);
           continue;
         }
+
+        if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {
+          Align Alignment;
+          if (ConstantInt *C = dyn_cast<ConstantInt>(Intrin->getOperand(3)))
+            Alignment = Align(C->getZExtValue());
+          transformStore(B, Intrin, Intrin->getArgOperand(0), OriginalOperand,
+                         Alignment);
+          continue;
+        }
       }
 
       llvm_unreachable("Unsupported ptrcast user. Please fix.");
diff --git a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
index b0a68a30e29be..35e5880881e5c 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
@@ -45,3 +45,23 @@ entry:
   %val = load i32, ptr addrspace(10) %ptr
   ret i32 %val
 }
+
+define spir_func void @foos(i64 noundef %index) local_unnamed_addr {
+; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
+entry:
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global1]] %[[#uint_0]] %[[#index]]
+  %ptr = getelementptr inbounds %S1, ptr addrspace(10) @global1, i64 0, i32 0, i64 %index
+; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
+  store i32 0, ptr addrspace(10) %ptr
+  ret void
+}
+
+define spir_func void @bars(i64 noundef %index) local_unnamed_addr {
+; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
+entry:
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global2]] %[[#uint_0]] %[[#uint_0]] %[[#index]] %[[#uint_1]]
+  %ptr = getelementptr inbounds %S2, ptr addrspace(10) @global2, i64 0, i32 0, i32 0, i64 %index, i32 1
+; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
+  store i32 0, ptr addrspace(10) %ptr
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
index d4131fa8a2658..be9e2a23365cc 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
@@ -5,9 +5,13 @@
 ; CHECK-DAG: %[[#uint_pp:]] = OpTypePointer Private %[[#uint]]
 ; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
 ; CHECK-DAG:  %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG:  %[[#uint_1:]] = OpConstant %[[#uint]] 1
+; CHECK-DAG:  %[[#uint_2:]] = OpConstant %[[#uint]] 2
 ; CHECK-DAG:      %[[#v2:]] = OpTypeVector %[[#uint]] 2
 ; CHECK-DAG:      %[[#v3:]] = OpTypeVector %[[#uint]] 3
 ; CHECK-DAG:      %[[#v4:]] = OpTypeVector %[[#uint]] 4
+; CHECK-DAG:   %[[#v2_01:]] = OpConstantComposite %[[#v2]] %[[#uint_0]] %[[#uint_1]]
+; CHECK-DAG:  %[[#v3_012:]] = OpConstantComposite %[[#v3]] %[[#uint_0]] %[[#uint_1]] %[[#uint_2]]
 ; CHECK-DAG:   %[[#v4_pp:]] = OpTypePointer Private %[[#v4]]
 ; CHECK-DAG:   %[[#v4_fp:]] = OpTypePointer Function %[[#v4]]
 
@@ -108,3 +112,109 @@ define internal spir_func i32 @bazBounds(ptr %a) {
   ret i32 %2
 ; CHECK: OpReturnValue %[[#val]]
 }
+
+define internal spir_func void @foos(ptr addrspace(10) %a) {
+
+  %1 = getelementptr inbounds <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]]  = OpInBoundsAccessChain %[[#v4_pp]] %[[#]]
+
+  store <3 x i32> <i32 0, i32 1, i32 2>, ptr addrspace(10) %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK:    %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK:    %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK:    %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16
+
+  ret void
+}
+
+define internal spir_func void @foosDefault(ptr %a) {
+
+  %1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]]  = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]
+
+  store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK:    %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK:    %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK:    %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16
+
+  ret void
+}
+
+define internal spir_func void @foosBounds(ptr %a) {
+
+  %1 = getelementptr <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]]  = OpAccessChain %[[#v4_fp]] %[[#]]
+
+  store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 64
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 64
+; CHECK:    %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK:    %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK:    %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 64
+
+  ret void
+}
+
+define internal spir_func void @bars(ptr addrspace(10) %a) {
+
+  %1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]]  = OpAccessChain %[[#v4_pp]] %[[#]]
+
+  store <2 x i32> <i32 0, i32 1>, ptr addrspace(10) %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK:    %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK:    %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK: OpStore %[[#ptr]] %[[#out2]] Aligned 1
+
+  ret void
+}
+
+define internal spir_func void @bazs(ptr addrspace(10) %a) {
+
+  %1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]]  = OpAccessChain %[[#v4_pp]] %[[#]]
+
+  store i32 0, ptr addrspace(10) %1, align 32
+; CHECK:  %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 32
+
+  ret void
+}
+
+define internal spir_func void @bazsDefault(ptr %a) {
+
+  %1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]]  = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]
+
+  store i32 0, ptr %1, align 16
+; CHECK:  %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16
+
+  ret void
+}
+
+define internal spir_func void @bazsBounds(ptr %a) {
+
+  %1 = getelementptr <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]]  = OpAccessChain %[[#v4_fp]] %[[#]]
+
+  store i32 0, ptr %1, align 16
+; CHECK:  %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16
+
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll b/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
new file mode 100644
index 0000000000000..7d2c1093f0a71
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
@@ -0,0 +1,66 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG:     %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG:    %[[#float:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_fp:]] = OpTypePointer Function %[[#float]]
+; CHECK-DAG: %[[#float_pp:]] = OpTypePointer Private %[[#float]]
+; CHECK-DAG:  %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
+; CHECK-DAG:   %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG:   %[[#float_0:]] = OpConstant %[[#float]] 0
+; CHECK-DAG:       %[[#sf:]] = OpTypeStruct %[[#float]]
+; CHECK-DAG:       %[[#su:]] = OpTypeStruct %[[#uint]]
+; CHECK-DAG:       %[[#sfuf:]] = OpTypeStruct %[[#float]] %[[#uint]] %[[#float]]
+; CHECK-DAG:    %[[#sf_fp:]] = OpTypePointer Function %[[#sf]]
+; CHECK-DAG:    %[[#su_fp:]] = OpTypePointer Function %[[#su]]
+; CHECK-DAG:    %[[#sfuf_fp:]] = OpTypePointer Function %[[#sfuf]]
+; CHECK-DAG:    %[[#sfuf_pp:]] = OpTypePointer Private %[[#sfuf]]
+
+%struct.SF = type { float }
+%struct.SU = type { i32 }
+%struct.SFUF = type { float, i32, float }
+
+@gsfuf = external addrspace(10) global %struct.SFUF
+; CHECK: %[[#gsfuf:]] = OpVariable %[[#sfuf_pp]] Private
+
+define internal spir_func void @foo() {
+  %1 = alloca %struct.SF, align 4
+; CHECK: %[[#var:]]  = OpVariable %[[#sf_fp]] Function
+
+  store float 0.0, ptr %1, align 4
+; CHECK: %[[#tmp:]]  = OpInBoundsAccessChain %[[#float_fp]] %[[#var]] %[[#uint_0]]
+; CHECK:               OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+  ret void
+}
+
+define internal spir_func void @bar() {
+  %1 = alloca %struct.SU, align 4
+; CHECK: %[[#var:]]  = OpVariable %[[#su_fp]] Function
+
+  store i32 0, ptr %1, align 4
+; CHECK: %[[#tmp:]]  = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]]
+; CHECK:               OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
+
+  ret void
+}
+
+define internal spir_func void @baz() {
+  %1 = alloca %struct.SFUF, align 4
+; CHECK: %[[#var:]]  = OpVariable %[[#sfuf_fp]] Function
+
+  store float 0.0, ptr %1, align 4
+; CHECK: %[[#tmp:]]  = OpInBoundsAccessChain %[[#float_fp]] %[[#var]] %[[#uint_0]]
+; CHECK:               OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+  ret void
+}
+
+define internal spir_func void @biz() {
+  store float 0.0, ptr addrspace(10) @gsfuf, align 4
+; CHECK: %[[#tmp:]]  = OpInBoundsAccessChain %[[#float_pp]] %[[#gsfuf]] %[[#uint_0]]
+; CHECK:               OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+  ret void
+}
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/135369


More information about the llvm-commits mailing list