[llvm] [SPIRV] Add Intermediate cast when Vector From/To types are of different element size and type (PR #182166)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 18 14:09:36 PST 2026


https://github.com/farzonl created https://github.com/llvm/llvm-project/pull/182166


fixes https://github.com/llvm/llvm-project/issues/177838

- Replaced `assert(TargetTypeSize == SourceTypeSize)` with a conditional: when the sizes differ, compute an intermediate `BitcastType` that has the target's element type but the source's total bit width (e.g. loading `<2 x i32>` through a `<4 x float>` pointer now bitcasts the loaded value to `<4 x i32>` first).
- Instead of unconditionally returning after the bitcast, return early only if `BitcastType == TargetType` (the same-size case); otherwise fall through to the shuffle logic, which extracts the leading target elements from the intermediate vector.
- Updated the final assert to use `AssignValue->getType()`, since that may now be the intermediate type rather than `SourceType`.

From 8bb8724ec291fb48797f6ca059cf1034a0c93989 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 18 Feb 2026 17:05:03 -0500
Subject: [PATCH] [SPIRV] Add Intermediate cast when Vector From/To types are
 of different element size and type

fixes https://github.com/llvm/llvm-project/issues/177838

- Replaced `assert(TargetTypeSize == SourceTypeSize)` with a conditional: when the sizes differ, compute an intermediate `BitcastType` that has the target's element type but the source's total bit width (e.g. loading `<2 x i32>` through a `<4 x float>` pointer now bitcasts the loaded value to `<4 x i32>` first).
- Instead of unconditionally returning after the bitcast, return early only if `BitcastType == TargetType` (the same-size case); otherwise fall through to the shuffle logic, which extracts the leading target elements from the intermediate vector.
- Updated the final assert to use `AssignValue->getType()`, since that may now be the intermediate type rather than `SourceType`.
---
 .../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 25 ++++++++++++-------
 .../CodeGen/SPIRV/pointers/ptrcast-bitcast.ll | 21 ++++++++++++++++
 2 files changed, 37 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 1050a1a1a10c2..4191ff0d29447 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -78,18 +78,25 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     Value *AssignValue = NewLoad;
     if (TargetType->getElementType() != SourceType->getElementType()) {
       const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
-      [[maybe_unused]] TypeSize TargetTypeSize =
-          DL.getTypeSizeInBits(TargetType);
-      [[maybe_unused]] TypeSize SourceTypeSize =
-          DL.getTypeSizeInBits(SourceType);
-      assert(TargetTypeSize == SourceTypeSize);
+      TypeSize TargetTypeSize = DL.getTypeSizeInBits(TargetType);
+      TypeSize SourceTypeSize = DL.getTypeSizeInBits(SourceType);
+      // Bitcast to a same-bitwidth vector with TargetType's element type.
+      // When sizes differ, use an intermediate type to preserve bitwidth.
+      auto *BitcastType = TargetType;
+      if (TargetTypeSize != SourceTypeSize) {
+        unsigned ElemBits = TargetType->getElementType()->getScalarSizeInBits();
+        BitcastType = FixedVectorType::get(TargetType->getElementType(),
+                                           SourceTypeSize / ElemBits);
+      }
       AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
-                                      {TargetType, SourceType}, {NewLoad});
-      buildAssignType(B, TargetType, AssignValue);
-      return AssignValue;
+                                      {BitcastType, SourceType}, {NewLoad});
+      buildAssignType(B, BitcastType, AssignValue);
+      if (BitcastType == TargetType)
+        return AssignValue;
     }
 
-    assert(TargetType->getNumElements() < SourceType->getNumElements());
+    assert(TargetType->getNumElements() <
+           cast<FixedVectorType>(AssignValue->getType())->getNumElements());
     SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
     for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
       Mask[I] = I;
diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
index a1ec2cd1cfdd2..d735bc93d1075 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
@@ -1,10 +1,12 @@
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s --match-full-lines
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
 
+; CHECK-DAG:       %[[#float:]] = OpTypeFloat 32
 ; CHECK-DAG:        %[[#uint:]] = OpTypeInt 32 0
 ; CHECK-DAG:     %[[#v2_uint:]] = OpTypeVector %[[#uint]] 2
 ; CHECK-DAG:      %[[#double:]] = OpTypeFloat 64
 ; CHECK-DAG:   %[[#v2_double:]] = OpTypeVector %[[#double]] 2
+; CHECK-DAG:    %[[#v4_float:]] = OpTypeVector %[[#float]] 4
 ; CHECK-DAG:     %[[#v4_uint:]] = OpTypeVector %[[#uint]] 4
 @.str = private unnamed_addr constant [3 x i8] c"In\00", align 1
 @.str.2 = private unnamed_addr constant [4 x i8] c"Out\00", align 1
@@ -47,4 +49,23 @@ entry:
   ret void
 }
 
+@.str.3 = private unnamed_addr constant [3 x i8] c"In\00", align 1 ; name passed to handlefrombinding for @main3's <4 x float> buffer
+@.str.4 = private unnamed_addr constant [4 x i8] c"Out\00", align 1 ; name passed to handlefrombinding for @main3's <2 x i32> buffer
+
+define void @main3() local_unnamed_addr #0 {
+entry: ; <2 x i32> loaded through a pointer into a [0 x <4 x float>] buffer; expected lowering: OpLoad <4 x float>, OpBitcast to <4 x i32>, OpVectorShuffle lanes 0 1 into <2 x i32>
+; CHECK:       %[[LOAD3:[0-9]+]] = OpLoad %[[#v4_float]] {{.*}}
+; CHECK-NEXT:  %[[BITCAST3:[0-9]+]] = OpBitcast %[[#v4_uint]] %[[LOAD3]]
+; CHECK-NEXT:  %[[SHUFFLE3:[0-9]+]] = OpVectorShuffle %[[#v2_uint]] %[[BITCAST3]] %[[BITCAST3]] 0 1
+; CHECK:       OpStore {{%[0-9]+}} %[[SHUFFLE3]] {{.*}}
+
+  %0 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str.3)
+  %1 = tail call target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2i32_12_0t(i32 0, i32 1, i32 1, i32 0, ptr nonnull @.str.4)
+  %2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %0, i32 0)
+  %3 = load <2 x i32>, ptr addrspace(11) %2, align 16
+  %4 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %1, i32 0)
+  store <2 x i32> %3, ptr addrspace(11) %4, align 8
+  ret void
+}
+
 attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }



More information about the llvm-commits mailing list