[llvm] 9f72fab - [SPIRV] Fix vector bitcast check in LegalizePointerCast (#164997)

via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 31 07:12:30 PDT 2025


Author: Steven Perron
Date: 2025-10-31T10:12:26-04:00
New Revision: 9f72fab4901d4e2b49c8b1974e1d75d73728e416

URL: https://github.com/llvm/llvm-project/commit/9f72fab4901d4e2b49c8b1974e1d75d73728e416
DIFF: https://github.com/llvm/llvm-project/commit/9f72fab4901d4e2b49c8b1974e1d75d73728e416.diff

LOG: [SPIRV] Fix vector bitcast check in LegalizePointerCast (#164997)

The previous check for vector bitcasts in `loadVectorFromVector` only
compared the number of elements, which is insufficient when the element
types differ. This can lead to incorrect assumptions about the validity
of the cast.

This commit replaces the element count check with a comparison of the
total size of the vectors in bits. This ensures that the bitcast is
only performed between vectors of the same size, preventing potential
miscompilations.

Part of https://github.com/llvm/llvm-project/issues/153091

Added: 
    

Modified: 
    llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
    llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
    llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 6e444c98de8da..65dffc7908b78 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -73,16 +73,23 @@ class SPIRVLegalizePointerCast : public FunctionPass {
   // Returns the loaded value.
   Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
                               FixedVectorType *TargetType, Value *Source) {
-    assert(TargetType->getNumElements() <= SourceType->getNumElements());
     LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
     buildAssignType(B, SourceType, NewLoad);
     Value *AssignValue = NewLoad;
     if (TargetType->getElementType() != SourceType->getElementType()) {
+      const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
+      [[maybe_unused]] TypeSize TargetTypeSize =
+          DL.getTypeSizeInBits(TargetType);
+      [[maybe_unused]] TypeSize SourceTypeSize =
+          DL.getTypeSizeInBits(SourceType);
+      assert(TargetTypeSize == SourceTypeSize);
       AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
                                       {TargetType, SourceType}, {NewLoad});
       buildAssignType(B, TargetType, AssignValue);
+      return AssignValue;
     }
 
+    assert(TargetType->getNumElements() < SourceType->getNumElements());
     SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
     for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
       Mask[I] = I;

diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
index ed67344842b11..4817e7450ac2e 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
@@ -16,7 +16,6 @@
 define void @case1() local_unnamed_addr {
   ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
   ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
-  ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
   %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str)
   %2 = tail call target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.2)
   %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)
@@ -29,8 +28,7 @@ define void @case1() local_unnamed_addr {
 define void @case2() local_unnamed_addr {
   ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
   ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
-  ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
-  ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#VEC_SHUFFLE]] %[[#UNDEF_INT4]] 0 1 2
+  ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#CAST_LOAD]] %[[#UNDEF_INT4]] 0 1 2
   %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str)
   %2 = tail call target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v3i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.3)
   %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)

diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
index 84913283f6868..a1ec2cd1cfdd2 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
@@ -26,3 +26,25 @@ entry:
   store <4 x i32> %6, ptr addrspace(11) %7, align 16
   ret void
 }
+
+; This tests a load from a pointer that has been bitcast between vector types
+; which share the same total bit-width but have different numbers of elements.
+; Tests that legalize-pointer-casts works correctly by moving the bitcast to
+; the element that was loaded.
+
+define void @main2() local_unnamed_addr #0 {
+entry:
+; CHECK:  %[[LOAD:[0-9]+]] = OpLoad %[[#v2_double]] {{.*}}
+; CHECK:  %[[BITCAST1:[0-9]+]] = OpBitcast %[[#v4_uint]] %[[LOAD]]
+; CHECK:  %[[BITCAST2:[0-9]+]] = OpBitcast %[[#v2_double]] %[[BITCAST1]]
+; CHECK: OpStore {{%[0-9]+}} %[[BITCAST2]] {{.*}}
+
+  %0 = tail call target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2f64_12_1t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str.2)
+  %2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 0)
+  %3 = load <4 x i32>, ptr addrspace(11) %2
+  %4 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 1)
+  store <4 x i32> %3, ptr addrspace(11) %4
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }


        


More information about the llvm-commits mailing list