[llvm] [SPIRV] Add Intermediate cast when Vector From/To types are of different element size and type (PR #182166)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 18 14:09:36 PST 2026
https://github.com/farzonl created https://github.com/llvm/llvm-project/pull/182166
fixes https://github.com/llvm/llvm-project/issues/177838
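- Replaced assert(TargetTypeSize == SourceTypeSize) with a conditional: when the total sizes differ, compute an intermediate BitcastType that uses the target's element type but has SourceTypeSize / (target element bits) elements, so it matches the source's total bitwidth. This path still assumes the source bitwidth is a multiple of the target element width and that the target vector is narrower than the source (the later assert requires it).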
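- Instead of unconditionally returning after the bitcast, return early only if BitcastType == TargetType (the same-size case); otherwise fall through to the shuffle logic, which extracts the leading TargetType->getNumElements() elements (mask 0..N-1) of the intermediate vector.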
- Updated the final assert to use AssignValue->getType() since that may now be the intermediate type rather than SourceType
>From 8bb8724ec291fb48797f6ca059cf1034a0c93989 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 18 Feb 2026 17:05:03 -0500
Subject: [PATCH] [SPIRV] Add Intermediate cast when Vector From/To types are
of different element size and type
fixes https://github.com/llvm/llvm-project/issues/177838
- Replaced assert(TargetTypeSize == SourceTypeSize) with a conditional: when sizes differ, compute a BitcastType with the target's element type but sized to match the source's total bitwidth
- Instead of unconditionally returning after bitcast, only return early if BitcastType == TargetType (same-size case), otherwise fall through to the shuffle logic
- Updated the final assert to use AssignValue->getType() since that may now be the intermediate type rather than SourceType
---
.../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 25 ++++++++++++-------
.../CodeGen/SPIRV/pointers/ptrcast-bitcast.ll | 21 ++++++++++++++++
2 files changed, 37 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 1050a1a1a10c2..4191ff0d29447 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -78,18 +78,25 @@ class SPIRVLegalizePointerCast : public FunctionPass {
Value *AssignValue = NewLoad;
if (TargetType->getElementType() != SourceType->getElementType()) {
const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
- [[maybe_unused]] TypeSize TargetTypeSize =
- DL.getTypeSizeInBits(TargetType);
- [[maybe_unused]] TypeSize SourceTypeSize =
- DL.getTypeSizeInBits(SourceType);
- assert(TargetTypeSize == SourceTypeSize);
+ TypeSize TargetTypeSize = DL.getTypeSizeInBits(TargetType);
+ TypeSize SourceTypeSize = DL.getTypeSizeInBits(SourceType);
+ // Bitcast to a same-bitwidth vector with TargetType's element type.
+ // When sizes differ, use an intermediate type to preserve bitwidth.
+ auto *BitcastType = TargetType;
+ if (TargetTypeSize != SourceTypeSize) {
+ unsigned ElemBits = TargetType->getElementType()->getScalarSizeInBits();
+ BitcastType = FixedVectorType::get(TargetType->getElementType(),
+ SourceTypeSize / ElemBits);
+ }
AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
- {TargetType, SourceType}, {NewLoad});
- buildAssignType(B, TargetType, AssignValue);
- return AssignValue;
+ {BitcastType, SourceType}, {NewLoad});
+ buildAssignType(B, BitcastType, AssignValue);
+ if (BitcastType == TargetType)
+ return AssignValue;
}
- assert(TargetType->getNumElements() < SourceType->getNumElements());
+ assert(TargetType->getNumElements() <
+ cast<FixedVectorType>(AssignValue->getType())->getNumElements());
SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
Mask[I] = I;
diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
index a1ec2cd1cfdd2..d735bc93d1075 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
@@ -1,10 +1,12 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s --match-full-lines
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#v2_uint:]] = OpTypeVector %[[#uint]] 2
; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
; CHECK-DAG: %[[#v2_double:]] = OpTypeVector %[[#double]] 2
+; CHECK-DAG: %[[#v4_float:]] = OpTypeVector %[[#float]] 4
; CHECK-DAG: %[[#v4_uint:]] = OpTypeVector %[[#uint]] 4
@.str = private unnamed_addr constant [3 x i8] c"In\00", align 1
@.str.2 = private unnamed_addr constant [4 x i8] c"Out\00", align 1
@@ -47,4 +49,23 @@ entry:
ret void
}
+@.str.3 = private unnamed_addr constant [3 x i8] c"In\00", align 1 ; name passed to main3's <4 x float> buffer binding
+@.str.4 = private unnamed_addr constant [4 x i8] c"Out\00", align 1 ; name passed to main3's <2 x i32> buffer binding
+
+define void @main3() local_unnamed_addr #0 {
+entry:
+; CHECK: %[[LOAD3:[0-9]+]] = OpLoad %[[#v4_float]] {{.*}}
+; CHECK-NEXT: %[[BITCAST3:[0-9]+]] = OpBitcast %[[#v4_uint]] %[[LOAD3]]
+; CHECK-NEXT: %[[SHUFFLE3:[0-9]+]] = OpVectorShuffle %[[#v2_uint]] %[[BITCAST3]] %[[BITCAST3]] 0 1
+; CHECK: OpStore {{%[0-9]+}} %[[SHUFFLE3]] {{.*}}
+; A <2 x i32> is loaded through a pointer into a <4 x float> buffer: the checks above expect the whole <4 x float> to be loaded, bitcast to <4 x i32>, then shuffled down to <2 x i32>.
+ %0 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str.3)
+ %1 = tail call target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2i32_12_0t(i32 0, i32 1, i32 1, i32 0, ptr nonnull @.str.4)
+ %2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %0, i32 0)
+ %3 = load <2 x i32>, ptr addrspace(11) %2, align 16
+ %4 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %1, i32 0)
+ store <2 x i32> %3, ptr addrspace(11) %4, align 8
+ ret void
+}
+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
More information about the llvm-commits
mailing list