[llvm] [SPIRV] Add Intermediate cast when Vector From/To types are of different element size and type (PR #182166)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 19 11:48:40 PST 2026


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/182166

>From 8bb8724ec291fb48797f6ca059cf1034a0c93989 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 18 Feb 2026 17:05:03 -0500
Subject: [PATCH 1/2] [SPIRV] Add Intermediate cast when Vector From/To types
 are of different element size and type

fixes https://github.com/llvm/llvm-project/issues/177838

- Replaced assert(TargetTypeSize == SourceTypeSize) with a conditional: when sizes differ, compute a BitcastType with the target's element type but sized to match the source's total bitwidth
- Instead of unconditionally returning after bitcast, only return early if BitcastType == TargetType (same-size case), otherwise fall through to the shuffle logic
- Updated the final assert to use AssignValue->getType() since that may now be the intermediate type rather than SourceType
---
 .../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 25 ++++++++++++-------
 .../CodeGen/SPIRV/pointers/ptrcast-bitcast.ll | 21 ++++++++++++++++
 2 files changed, 37 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 1050a1a1a10c2..4191ff0d29447 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -78,18 +78,25 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     Value *AssignValue = NewLoad;
     if (TargetType->getElementType() != SourceType->getElementType()) {
       const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
-      [[maybe_unused]] TypeSize TargetTypeSize =
-          DL.getTypeSizeInBits(TargetType);
-      [[maybe_unused]] TypeSize SourceTypeSize =
-          DL.getTypeSizeInBits(SourceType);
-      assert(TargetTypeSize == SourceTypeSize);
+      TypeSize TargetTypeSize = DL.getTypeSizeInBits(TargetType);
+      TypeSize SourceTypeSize = DL.getTypeSizeInBits(SourceType);
+      // Bitcast to a same-bitwidth vector with TargetType's element type.
+      // When sizes differ, use an intermediate type to preserve bitwidth.
+      auto *BitcastType = TargetType;
+      if (TargetTypeSize != SourceTypeSize) {
+        unsigned ElemBits = TargetType->getElementType()->getScalarSizeInBits();
+        BitcastType = FixedVectorType::get(TargetType->getElementType(),
+                                           SourceTypeSize / ElemBits);
+      }
       AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
-                                      {TargetType, SourceType}, {NewLoad});
-      buildAssignType(B, TargetType, AssignValue);
-      return AssignValue;
+                                      {BitcastType, SourceType}, {NewLoad});
+      buildAssignType(B, BitcastType, AssignValue);
+      if (BitcastType == TargetType)
+        return AssignValue;
     }
 
-    assert(TargetType->getNumElements() < SourceType->getNumElements());
+    assert(TargetType->getNumElements() <
+           cast<FixedVectorType>(AssignValue->getType())->getNumElements());
     SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
     for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
       Mask[I] = I;
diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
index a1ec2cd1cfdd2..d735bc93d1075 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
@@ -1,10 +1,12 @@
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s --match-full-lines
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
 
+; CHECK-DAG:       %[[#float:]] = OpTypeFloat 32
 ; CHECK-DAG:        %[[#uint:]] = OpTypeInt 32 0
 ; CHECK-DAG:     %[[#v2_uint:]] = OpTypeVector %[[#uint]] 2
 ; CHECK-DAG:      %[[#double:]] = OpTypeFloat 64
 ; CHECK-DAG:   %[[#v2_double:]] = OpTypeVector %[[#double]] 2
+; CHECK-DAG:    %[[#v4_float:]] = OpTypeVector %[[#float]] 4
 ; CHECK-DAG:     %[[#v4_uint:]] = OpTypeVector %[[#uint]] 4
 @.str = private unnamed_addr constant [3 x i8] c"In\00", align 1
 @.str.2 = private unnamed_addr constant [4 x i8] c"Out\00", align 1
@@ -47,4 +49,23 @@ entry:
   ret void
 }
 
+@.str.3 = private unnamed_addr constant [3 x i8] c"In\00", align 1
+@.str.4 = private unnamed_addr constant [4 x i8] c"Out\00", align 1
+
+define void @main3() local_unnamed_addr #0 {
+entry:
+; CHECK:       %[[LOAD3:[0-9]+]] = OpLoad %[[#v4_float]] {{.*}}
+; CHECK-NEXT:  %[[BITCAST3:[0-9]+]] = OpBitcast %[[#v4_uint]] %[[LOAD3]]
+; CHECK-NEXT:  %[[SHUFFLE3:[0-9]+]] = OpVectorShuffle %[[#v2_uint]] %[[BITCAST3]] %[[BITCAST3]] 0 1
+; CHECK:       OpStore {{%[0-9]+}} %[[SHUFFLE3]] {{.*}}
+
+  %0 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str.3)
+  %1 = tail call target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2i32_12_0t(i32 0, i32 1, i32 1, i32 0, ptr nonnull @.str.4)
+  %2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %0, i32 0)
+  %3 = load <2 x i32>, ptr addrspace(11) %2, align 16
+  %4 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %1, i32 0)
+  store <2 x i32> %3, ptr addrspace(11) %4, align 8
+  ret void
+}
+
 attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

>From bdce99e75535523f2f8380d7f3612a494192e2e9 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 19 Feb 2026 14:48:23 -0500
Subject: [PATCH 2/2] address pr comments, handle the case where bit width is
 not evenly divisible

---
 .../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 27 +++++++++++++++--
 .../CodeGen/SPIRV/pointers/ptrcast-bitcast.ll | 30 ++++++++++++++++---
 2 files changed, 51 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 4191ff0d29447..fdd1fed350630 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -85,8 +85,31 @@ class SPIRVLegalizePointerCast : public FunctionPass {
       auto *BitcastType = TargetType;
       if (TargetTypeSize != SourceTypeSize) {
         unsigned ElemBits = TargetType->getElementType()->getScalarSizeInBits();
-        BitcastType = FixedVectorType::get(TargetType->getElementType(),
-                                           SourceTypeSize / ElemBits);
+        if (SourceTypeSize % ElemBits == 0) {
+          BitcastType = FixedVectorType::get(TargetType->getElementType(),
+                                             SourceTypeSize / ElemBits);
+        } else {
+          // Source total bits aren't evenly divisible by target element bits.
+          // Resize source (extract or pad) to match target bit width using
+          // source element type, then bitcast to target.
+          unsigned SourceElemBits =
+              SourceType->getElementType()->getScalarSizeInBits();
+          assert(TargetTypeSize % SourceElemBits == 0 &&
+                 "Target size must be a multiple of source element size");
+          unsigned NumNeeded = TargetTypeSize / SourceElemBits;
+          unsigned NumSource = SourceType->getNumElements();
+          auto *ResizedType =
+              FixedVectorType::get(SourceType->getElementType(), NumNeeded);
+          SmallVector<int> Mask(NumNeeded);
+          for (unsigned I = 0; I < NumNeeded; ++I)
+            Mask[I] = (I < NumSource) ? static_cast<int>(I) : -1;
+          Value *Resized = B.CreateShuffleVector(NewLoad, NewLoad, Mask);
+          buildAssignType(B, ResizedType, Resized);
+          AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
+                                          {TargetType, ResizedType}, {Resized});
+          buildAssignType(B, TargetType, AssignValue);
+          return AssignValue;
+        }
       }
       AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
                                       {BitcastType, SourceType}, {NewLoad});
diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
index d735bc93d1075..7149a5cba5f68 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
@@ -49,8 +49,8 @@ entry:
   ret void
 }
 
-@.str.3 = private unnamed_addr constant [3 x i8] c"In\00", align 1
-@.str.4 = private unnamed_addr constant [4 x i8] c"Out\00", align 1
+@.str.3 = private unnamed_addr constant [4 x i8] c"In2\00", align 1
+@.str.4 = private unnamed_addr constant [5 x i8] c"Out2\00", align 1
 
 define void @main3() local_unnamed_addr #0 {
 entry:
@@ -59,8 +59,8 @@ entry:
 ; CHECK-NEXT:  %[[SHUFFLE3:[0-9]+]] = OpVectorShuffle %[[#v2_uint]] %[[BITCAST3]] %[[BITCAST3]] 0 1
 ; CHECK:       OpStore {{%[0-9]+}} %[[SHUFFLE3]] {{.*}}
 
-  %0 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str.3)
-  %1 = tail call target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2i32_12_0t(i32 0, i32 1, i32 1, i32 0, ptr nonnull @.str.4)
+  %0 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 1, i32 0, i32 1, i32 0, ptr nonnull @.str.3)
+  %1 = tail call target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2i32_12_0t(i32 1, i32 1, i32 1, i32 0, ptr nonnull @.str.4)
   %2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %0, i32 0)
   %3 = load <2 x i32>, ptr addrspace(11) %2, align 16
   %4 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %1, i32 0)
@@ -68,4 +68,26 @@ entry:
   ret void
 }
 
+; Tests loading a vector where the source total bit width is not evenly
+; divisible by the target element bit width.
+
+@.str.in = private unnamed_addr constant [4 x i8] c"In3\00", align 1
+@.str.out = private unnamed_addr constant [5 x i8] c"Out3\00", align 1
+
+define void @main4() local_unnamed_addr #0 {
+entry:
+; CHECK:       %[[LOAD:[0-9]+]] = OpLoad %[[#v3_uint]] {{.*}}
+; CHECK-NEXT:  %[[SHUFFLE:[0-9]+]] = OpVectorShuffle %[[#v4_uint]] %[[LOAD]] %[[LOAD]] 0 1 2 0xFFFFFFFF{{.*}}
+; CHECK-NEXT:  %[[BITCAST:[0-9]+]] = OpBitcast %[[#v2_ulong]] %[[SHUFFLE]]{{.*}}
+; CHECK:       OpStore {{%[0-9]+}} %[[BITCAST]] {{.*}}
+
+  %0 = tail call target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v3i32_12_0t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str.in)
+  %1 = tail call target("spirv.VulkanBuffer", [0 x <2 x i64>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v1i64_12_0t(i32 0, i32 1, i32 1, i32 0, ptr nonnull @.str.out)
+  %2 = tail call noundef align 16 dereferenceable(12) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v3i32_12_0t(target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 0) %0, i32 0)
+  %3 = load <2 x i64>, ptr addrspace(11) %2, align 16
+  %4 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v1i64_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i64>], 12, 0) %1, i32 0)
+  store <2 x i64> %3, ptr addrspace(11) %4, align 8
+  ret void
+}
+
 attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }



More information about the llvm-commits mailing list