[llvm] [SPIRV] Add Intermediate cast when Vector From/To types are of different element size and type (PR #182166)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 19 11:48:40 PST 2026
https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/182166
>From 8bb8724ec291fb48797f6ca059cf1034a0c93989 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 18 Feb 2026 17:05:03 -0500
Subject: [PATCH 1/2] [SPIRV] Add Intermediate cast when Vector From/To types
are of different element size and type
fixes https://github.com/llvm/llvm-project/issues/177838
- Replaced assert(TargetTypeSize == SourceTypeSize) with a conditional: when sizes differ, compute a BitcastType with the target's element type but sized to match the source's total bitwidth
- Instead of unconditionally returning after the bitcast, return early only if BitcastType == TargetType (the same-size case); otherwise, fall through to the shuffle logic
- Updated the final assert to use AssignValue->getType() since that may now be the intermediate type rather than SourceType
---
.../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 25 ++++++++++++-------
.../CodeGen/SPIRV/pointers/ptrcast-bitcast.ll | 21 ++++++++++++++++
2 files changed, 37 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 1050a1a1a10c2..4191ff0d29447 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -78,18 +78,25 @@ class SPIRVLegalizePointerCast : public FunctionPass {
Value *AssignValue = NewLoad;
if (TargetType->getElementType() != SourceType->getElementType()) {
const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
- [[maybe_unused]] TypeSize TargetTypeSize =
- DL.getTypeSizeInBits(TargetType);
- [[maybe_unused]] TypeSize SourceTypeSize =
- DL.getTypeSizeInBits(SourceType);
- assert(TargetTypeSize == SourceTypeSize);
+ TypeSize TargetTypeSize = DL.getTypeSizeInBits(TargetType);
+ TypeSize SourceTypeSize = DL.getTypeSizeInBits(SourceType);
+ // Bitcast to a same-bitwidth vector with TargetType's element type.
+ // When sizes differ, use an intermediate type to preserve bitwidth.
+ auto *BitcastType = TargetType;
+ if (TargetTypeSize != SourceTypeSize) {
+ unsigned ElemBits = TargetType->getElementType()->getScalarSizeInBits();
+ BitcastType = FixedVectorType::get(TargetType->getElementType(),
+ SourceTypeSize / ElemBits);
+ }
AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
- {TargetType, SourceType}, {NewLoad});
- buildAssignType(B, TargetType, AssignValue);
- return AssignValue;
+ {BitcastType, SourceType}, {NewLoad});
+ buildAssignType(B, BitcastType, AssignValue);
+ if (BitcastType == TargetType)
+ return AssignValue;
}
- assert(TargetType->getNumElements() < SourceType->getNumElements());
+ assert(TargetType->getNumElements() <
+ cast<FixedVectorType>(AssignValue->getType())->getNumElements());
SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
Mask[I] = I;
diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
index a1ec2cd1cfdd2..d735bc93d1075 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
@@ -1,10 +1,12 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s --match-full-lines
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#v2_uint:]] = OpTypeVector %[[#uint]] 2
; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
; CHECK-DAG: %[[#v2_double:]] = OpTypeVector %[[#double]] 2
+; CHECK-DAG: %[[#v4_float:]] = OpTypeVector %[[#float]] 4
; CHECK-DAG: %[[#v4_uint:]] = OpTypeVector %[[#uint]] 4
@.str = private unnamed_addr constant [3 x i8] c"In\00", align 1
@.str.2 = private unnamed_addr constant [4 x i8] c"Out\00", align 1
@@ -47,4 +49,23 @@ entry:
ret void
}
+ at .str.3 = private unnamed_addr constant [3 x i8] c"In\00", align 1
+ at .str.4 = private unnamed_addr constant [4 x i8] c"Out\00", align 1
+; Loads <2 x i32> (64 bits) through a pointer into a buffer of <4 x float> (128 bits): expect a full <4 x float> load, an OpBitcast to <4 x i32>, then an OpVectorShuffle keeping lanes 0 and 1.
+define void @main3() local_unnamed_addr #0 {
+entry:
+; CHECK: %[[LOAD3:[0-9]+]] = OpLoad %[[#v4_float]] {{.*}}
+; CHECK-NEXT: %[[BITCAST3:[0-9]+]] = OpBitcast %[[#v4_uint]] %[[LOAD3]]
+; CHECK-NEXT: %[[SHUFFLE3:[0-9]+]] = OpVectorShuffle %[[#v2_uint]] %[[BITCAST3]] %[[BITCAST3]] 0 1
+; CHECK: OpStore {{%[0-9]+}} %[[SHUFFLE3]] {{.*}}
+
+ %0 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str.3)
+ %1 = tail call target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2i32_12_0t(i32 0, i32 1, i32 1, i32 0, ptr nonnull @.str.4)
+ %2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %0, i32 0)
+ %3 = load <2 x i32>, ptr addrspace(11) %2, align 16
+ %4 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %1, i32 0)
+ store <2 x i32> %3, ptr addrspace(11) %4, align 8
+ ret void
+}
+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
>From bdce99e75535523f2f8380d7f3612a494192e2e9 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 19 Feb 2026 14:48:23 -0500
Subject: [PATCH 2/2] Address PR comments; handle source bit widths that are
not evenly divisible by the target element bit width
---
.../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 27 +++++++++++++++--
.../CodeGen/SPIRV/pointers/ptrcast-bitcast.ll | 30 ++++++++++++++++---
2 files changed, 51 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 4191ff0d29447..fdd1fed350630 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -85,8 +85,31 @@ class SPIRVLegalizePointerCast : public FunctionPass {
auto *BitcastType = TargetType;
if (TargetTypeSize != SourceTypeSize) {
unsigned ElemBits = TargetType->getElementType()->getScalarSizeInBits();
- BitcastType = FixedVectorType::get(TargetType->getElementType(),
- SourceTypeSize / ElemBits);
+ if (SourceTypeSize % ElemBits == 0) {
+ BitcastType = FixedVectorType::get(TargetType->getElementType(),
+ SourceTypeSize / ElemBits);
+ } else {
+ // Source total bits aren't evenly divisible by target element bits.
+ // Resize source (extract or pad) to match target bit width using
+ // source element type, then bitcast to target.
+ unsigned SourceElemBits =
+ SourceType->getElementType()->getScalarSizeInBits();
+ assert(TargetTypeSize % SourceElemBits == 0 &&
+ "Target size must be a multiple of source element size");
+ unsigned NumNeeded = TargetTypeSize / SourceElemBits;
+ unsigned NumSource = SourceType->getNumElements();
+ auto *ResizedType =
+ FixedVectorType::get(SourceType->getElementType(), NumNeeded);
+ SmallVector<int> Mask(NumNeeded);
+ for (unsigned I = 0; I < NumNeeded; ++I)
+ Mask[I] = (I < NumSource) ? static_cast<int>(I) : -1;
+ Value *Resized = B.CreateShuffleVector(NewLoad, NewLoad, Mask);
+ buildAssignType(B, ResizedType, Resized);
+ AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
+ {TargetType, ResizedType}, {Resized});
+ buildAssignType(B, TargetType, AssignValue);
+ return AssignValue;
+ }
}
AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
{BitcastType, SourceType}, {NewLoad});
diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
index d735bc93d1075..7149a5cba5f68 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
@@ -49,8 +49,8 @@ entry:
ret void
}
- at .str.3 = private unnamed_addr constant [3 x i8] c"In\00", align 1
- at .str.4 = private unnamed_addr constant [4 x i8] c"Out\00", align 1
+ at .str.3 = private unnamed_addr constant [4 x i8] c"In2\00", align 1
+ at .str.4 = private unnamed_addr constant [5 x i8] c"Out2\00", align 1
define void @main3() local_unnamed_addr #0 {
entry:
@@ -59,8 +59,8 @@ entry:
; CHECK-NEXT: %[[SHUFFLE3:[0-9]+]] = OpVectorShuffle %[[#v2_uint]] %[[BITCAST3]] %[[BITCAST3]] 0 1
; CHECK: OpStore {{%[0-9]+}} %[[SHUFFLE3]] {{.*}}
- %0 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str.3)
- %1 = tail call target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2i32_12_0t(i32 0, i32 1, i32 1, i32 0, ptr nonnull @.str.4)
+ %0 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 1, i32 0, i32 1, i32 0, ptr nonnull @.str.3)
+ %1 = tail call target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2i32_12_0t(i32 1, i32 1, i32 1, i32 0, ptr nonnull @.str.4)
%2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %0, i32 0)
%3 = load <2 x i32>, ptr addrspace(11) %2, align 16
%4 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %1, i32 0)
@@ -68,4 +68,26 @@ entry:
ret void
}
+; Tests loading a vector whose source total bit width (<3 x i32>, 96 bits) is
+; not evenly divisible by the target element bit width (i64, 64 bits).
+
+@.str.in = private unnamed_addr constant [4 x i8] c"In3\00", align 1
+@.str.out = private unnamed_addr constant [5 x i8] c"Out3\00", align 1
+
+define void @main4() local_unnamed_addr #0 {
+entry:
+; CHECK: %[[LOAD:[0-9]+]] = OpLoad %[[#v3_uint]] {{.*}}
+; CHECK-NEXT: %[[SHUFFLE:[0-9]+]] = OpVectorShuffle %[[#v4_uint]] %[[LOAD]] %[[LOAD]] 0 1 2 0xFFFFFFFF{{.*}}
+; CHECK-NEXT: %[[BITCAST:[0-9]+]] = OpBitcast %[[#v2_ulong]] %[[SHUFFLE]]{{.*}}
+; CHECK: OpStore {{%[0-9]+}} %[[BITCAST]] {{.*}}
+; NOTE(review): v3_uint and v2_ulong are not among the type captures at the top of this file in this series; confirm they are defined, otherwise FileCheck reports an undefined variable.
+ %0 = tail call target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v3i32_12_0t(i32 2, i32 0, i32 1, i32 0, ptr nonnull @.str.in)
+ %1 = tail call target("spirv.VulkanBuffer", [0 x <2 x i64>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2i64_12_0t(i32 2, i32 1, i32 1, i32 0, ptr nonnull @.str.out)
+ %2 = tail call noundef align 16 dereferenceable(12) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v3i32_12_0t(target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 0) %0, i32 0)
+ %3 = load <2 x i64>, ptr addrspace(11) %2, align 16
+ %4 = tail call noundef align 8 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i64_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i64>], 12, 0) %1, i32 0)
+ store <2 x i64> %3, ptr addrspace(11) %4, align 8
+ ret void
+}
+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
More information about the llvm-commits
mailing list