[llvm] [SPIRV] Preserve implicit bitcast (PR #151041)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 30 08:07:12 PDT 2025


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/151041

>From 344e7a1389d4d62ad42791b8a0aa6c33163bd8ff Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 28 Jul 2025 16:07:05 -0400
Subject: [PATCH 1/3] [SPIRV] Preserve implicit bitcast

---
 .../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 12 +++++-----
 .../hlsl-resources/issue-146942-ptr-cast.ll   | 24 +++++++++++++++++++
 2 files changed, 30 insertions(+), 6 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 5cda6a07352d5..7e9c9c62cd889 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -74,17 +74,17 @@ class SPIRVLegalizePointerCast : public FunctionPass {
   // Returns the loaded value.
   Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
                               FixedVectorType *TargetType, Value *Source) {
-    // We expect the codegen to avoid doing implicit bitcast from a load.
-    assert(TargetType->getElementType() == SourceType->getElementType());
-    assert(TargetType->getNumElements() < SourceType->getNumElements());
-
+    assert(TargetType->getNumElements() <= SourceType->getNumElements());
     LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
-    buildAssignType(B, SourceType, NewLoad);
+    Value *AssignType = NewLoad;
+    if (TargetType->getElementType() != SourceType->getElementType())
+      AssignType = B.CreateBitCast(NewLoad, TargetType);
+    buildAssignType(B, SourceType, AssignType);
 
     SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
     for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
       Mask[I] = I;
-    Value *Output = B.CreateShuffleVector(NewLoad, NewLoad, Mask);
+    Value *Output = B.CreateShuffleVector(AssignType, AssignType, Mask);
     buildAssignType(B, TargetType, Output);
     return Output;
   }
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
new file mode 100644
index 0000000000000..df1ef616df301
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
@@ -0,0 +1,24 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+ at .str = private unnamed_addr constant [4 x i8] c"In3\00", align 1
+ at .str.2 = private unnamed_addr constant [5 x i8] c"Out3\00", align 1
+
+; Function Attrs: mustprogress nofree noinline norecurse nosync nounwind willreturn memory(readwrite, inaccessiblemem: none)
+define void @main() local_unnamed_addr #0 {
+  ; CHECK: %[[#INT32:]] = OpTypeInt 32 0
+  ; CHECK: %[[#INT4:]] = OpTypeVector %[[#INT32]] 4
+  ; CHECK: %[[#FLOAT:]] = OpTypeFloat 32
+  ; CHECK: %[[#FLOAT4:]] = OpTypeVector %[[#FLOAT]] 4
+  ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
+  ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#BUFFER_LOAD]] %[[#BUFFER_LOAD]] 0 1 2 3
+  %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, i1 false, ptr nonnull @.str)
+  %2 = tail call target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4i32_12_1t(i32 0, i32 5, i32 1, i32 0, i1 false, ptr nonnull @.str.2)
+  %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)
+  %4 = load <4 x i32>, ptr addrspace(11) %3, align 16
+  %5 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4i32_12_1t(target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) %2, i32 0)
+  store <4 x i32> %4, ptr addrspace(11) %5, align 16
+  ret void
+}
+
+attributes #0 = { mustprogress nofree noinline norecurse nosync nounwind willreturn memory(readwrite, inaccessiblemem: none) "approx-func-fp-math"="true" "frame-pointer"="all" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }

>From 52cc44e5edb2b55285e59a8708e3ad702b6d5f9b Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Tue, 29 Jul 2025 11:58:50 -0400
Subject: [PATCH 2/3] The LLVM IR bitcast added in
 `SPIRVLegalizePointerCast.cpp` is not translated into MIR. Instead, do what
 `SPIRVEmitIntrinsics::visitBitCastInst` does and emit `spv_bitcast` instead
 of a `BitCastInst`.

---
 .../lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp | 15 ++++++++++-----
 .../SPIRV/hlsl-resources/issue-146942-ptr-cast.ll |  3 ++-
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 7e9c9c62cd889..ce25c7e054af0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -76,15 +76,18 @@ class SPIRVLegalizePointerCast : public FunctionPass {
                               FixedVectorType *TargetType, Value *Source) {
     assert(TargetType->getNumElements() <= SourceType->getNumElements());
     LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
-    Value *AssignType = NewLoad;
-    if (TargetType->getElementType() != SourceType->getElementType())
-      AssignType = B.CreateBitCast(NewLoad, TargetType);
-    buildAssignType(B, SourceType, AssignType);
+    buildAssignType(B, SourceType, NewLoad);
+    Value *AssignValue = NewLoad;
+    if (TargetType->getElementType() != SourceType->getElementType()) {
+      AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
+                                      {TargetType, SourceType}, {NewLoad});
+      buildAssignType(B, TargetType, AssignValue);
+    }
 
     SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
     for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
       Mask[I] = I;
-    Value *Output = B.CreateShuffleVector(AssignType, AssignType, Mask);
+    Value *Output = B.CreateShuffleVector(AssignValue, AssignValue, Mask);
     buildAssignType(B, TargetType, Output);
     return Output;
   }
@@ -136,7 +139,9 @@ class SPIRVLegalizePointerCast : public FunctionPass {
                                            OriginalOperand, LI);
     }
     // Destination is a smaller vector than source.
+    // Or the destination has a different element type than the source.
     // - float3 v3 = vector4;
+    // - float4 v4 = int4;
     else if (SVT && DVT)
       Output = loadVectorFromVector(B, SVT, DVT, OriginalOperand);
     // Destination is the scalar type stored at the start of an aggregate.
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
index df1ef616df301..6cb75bce32f69 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
@@ -11,7 +11,8 @@ define void @main() local_unnamed_addr #0 {
   ; CHECK: %[[#FLOAT:]] = OpTypeFloat 32
   ; CHECK: %[[#FLOAT4:]] = OpTypeVector %[[#FLOAT]] 4
   ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
-  ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#BUFFER_LOAD]] %[[#BUFFER_LOAD]] 0 1 2 3
+  ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
+  ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
   %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, i1 false, ptr nonnull @.str)
   %2 = tail call target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4i32_12_1t(i32 0, i32 5, i32 1, i32 0, i1 false, ptr nonnull @.str.2)
   %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)

>From d06d08c308c30cd1bf5bdecf4d389de52db27b34 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 30 Jul 2025 11:05:59 -0400
Subject: [PATCH 3/3] Address PR comments; add a truncation test case

---
 .../hlsl-resources/issue-146942-ptr-cast.ll   | 33 ++++++++++++++-----
 1 file changed, 25 insertions(+), 8 deletions(-)

diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
index 6cb75bce32f69..b2333e642340c 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
@@ -2,14 +2,18 @@
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
 
 @.str = private unnamed_addr constant [4 x i8] c"In3\00", align 1
- at .str.2 = private unnamed_addr constant [5 x i8] c"Out3\00", align 1
+ at .str.2 = private unnamed_addr constant [5 x i8] c"Out4\00", align 1
+ at .str.3 = private unnamed_addr constant [5 x i8] c"Out3\00", align 1
 
-; Function Attrs: mustprogress nofree noinline norecurse nosync nounwind willreturn memory(readwrite, inaccessiblemem: none)
-define void @main() local_unnamed_addr #0 {
-  ; CHECK: %[[#INT32:]] = OpTypeInt 32 0
-  ; CHECK: %[[#INT4:]] = OpTypeVector %[[#INT32]] 4
-  ; CHECK: %[[#FLOAT:]] = OpTypeFloat 32
-  ; CHECK: %[[#FLOAT4:]] = OpTypeVector %[[#FLOAT]] 4
+
+; CHECK-DAG: %[[#INT32:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#INT4:]] = OpTypeVector %[[#INT32]] 4
+; CHECK-DAG: %[[#FLOAT:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#FLOAT4:]] = OpTypeVector %[[#FLOAT]] 4
+; CHECK-DAG: %[[#INT3:]] = OpTypeVector %[[#INT32]] 3
+; CHECK-DAG: %[[#UNDEF_INT4:]] = OpUndef %[[#INT4]]
+
+define void @case1() local_unnamed_addr {
   ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
   ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
   ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
@@ -22,4 +26,17 @@ define void @main() local_unnamed_addr #0 {
   ret void
 }
 
-attributes #0 = { mustprogress nofree noinline norecurse nosync nounwind willreturn memory(readwrite, inaccessiblemem: none) "approx-func-fp-math"="true" "frame-pointer"="all" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+define void @case2() local_unnamed_addr {
+  ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
+  ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
+  ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
+  ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#VEC_SHUFFLE]] %[[#UNDEF_INT4]] 0 1 2
+  %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, i1 false, ptr nonnull @.str)
+  %2 = tail call target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v3i32_12_1t(i32 0, i32 5, i32 1, i32 0, i1 false, ptr nonnull @.str.3)
+  %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)
+  %4 = load <4 x i32>, ptr addrspace(11) %3, align 16
+  %5 = shufflevector <4 x i32> %4, <4 x i32> poison, <3 x i32> <i32 0, i32 1, i32 2>
+  %6 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v3i32_12_1t(target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 1) %2, i32 0)
+  store <3 x i32> %5, ptr addrspace(11) %6, align 16
+  ret void
+}



More information about the llvm-commits mailing list