[llvm] [SPIRV] Fix vector bitcast check in LegalizePointerCast (PR #164997)
    via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Fri Oct 24 08:13:59 PDT 2025
    
    
  
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Steven Perron (s-perron)
<details>
<summary>Changes</summary>
The previous check for vector bitcasts in `loadVectorFromVector` only
compared the number of elements, which is insufficient when the element
types differ. This can lead to incorrect assumptions about the validity
of the cast.
This commit replaces the element count check with a comparison of the
total size of the vectors in bits. This ensures that the bitcast is
only performed between vectors of the same size, preventing potential
miscompilations.
---
Full diff: https://github.com/llvm/llvm-project/pull/164997.diff
3 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp (+8-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll (+1-3) 
- (modified) llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll (+22) 
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 28a1690ef0be1..a692c24363310 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -73,16 +73,23 @@ class SPIRVLegalizePointerCast : public FunctionPass {
   // Returns the loaded value.
   Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
                               FixedVectorType *TargetType, Value *Source) {
-    assert(TargetType->getNumElements() <= SourceType->getNumElements());
     LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
     buildAssignType(B, SourceType, NewLoad);
     Value *AssignValue = NewLoad;
     if (TargetType->getElementType() != SourceType->getElementType()) {
+      const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
+      [[maybe_unused]] TypeSize TargetTypeSize =
+          DL.getTypeSizeInBits(TargetType);
+      [[maybe_unused]] TypeSize SourceTypeSize =
+          DL.getTypeSizeInBits(SourceType);
+      assert(TargetTypeSize == SourceTypeSize);
       AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
                                       {TargetType, SourceType}, {NewLoad});
       buildAssignType(B, TargetType, AssignValue);
+      return AssignValue;
     }
 
+    assert(TargetType->getNumElements() < SourceType->getNumElements());
     SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
     for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
       Mask[I] = I;
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
index ed67344842b11..4817e7450ac2e 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
@@ -16,7 +16,6 @@
 define void @case1() local_unnamed_addr {
   ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
   ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
-  ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
   %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str)
   %2 = tail call target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.2)
   %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)
@@ -29,8 +28,7 @@ define void @case1() local_unnamed_addr {
 define void @case2() local_unnamed_addr {
   ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
   ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
-  ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
-  ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#VEC_SHUFFLE]] %[[#UNDEF_INT4]] 0 1 2
+  ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#CAST_LOAD]] %[[#UNDEF_INT4]] 0 1 2
   %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str)
   %2 = tail call target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v3i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.3)
   %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)
diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
index 84913283f6868..a1ec2cd1cfdd2 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
@@ -26,3 +26,25 @@ entry:
   store <4 x i32> %6, ptr addrspace(11) %7, align 16
   ret void
 }
+
+; This tests a load from a pointer that has been bitcast between vector types
+; which share the same total bit-width but have different numbers of elements.
+; Tests that legalize-pointer-casts works correctly by moving the bitcast to
+; the element that was loaded.
+
+define void @main2() local_unnamed_addr #0 {
+entry:
+; CHECK:  %[[LOAD:[0-9]+]] = OpLoad %[[#v2_double]] {{.*}}
+; CHECK:  %[[BITCAST1:[0-9]+]] = OpBitcast %[[#v4_uint]] %[[LOAD]]
+; CHECK:  %[[BITCAST2:[0-9]+]] = OpBitcast %[[#v2_double]] %[[BITCAST1]]
+; CHECK: OpStore {{%[0-9]+}} %[[BITCAST2]] {{.*}}
+
+  %0 = tail call target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2f64_12_1t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str.2)
+  %2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 0)
+  %3 = load <4 x i32>, ptr addrspace(11) %2
+  %4 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 1)
+  store <4 x i32> %3, ptr addrspace(11) %4
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
``````````
</details>
https://github.com/llvm/llvm-project/pull/164997
    
    
More information about the llvm-commits mailing list