[llvm] [SPIRV] Handle ptrcast between array and vector types (PR #166418)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 4 11:27:06 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Steven Perron (s-perron)
<details>
<summary>Changes</summary>
This commit adds support for legalizing pointer casts between array and vector types within the SPIRV backend.
This is necessary to handle cases where a vector is loaded from or stored to an array, which can occur with HLSL matrix types.
The following changes are included:
- Added `loadVectorFromArray` to load a vector from an array.
- Added `storeArrayFromVector` to store a vector to an array.
- Added the test case `llvm/test/CodeGen/SPIRV/pointers/load-store-vec-from-array.ll` to verify the functionality.
---
Full diff: https://github.com/llvm/llvm-project/pull/166418.diff
2 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp (+84)
- (added) llvm/test/CodeGen/SPIRV/pointers/load-store-vec-from-array.ll (+54)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 65dffc7908b78..87c0c8c5a7437 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -116,6 +116,85 @@ class SPIRVLegalizePointerCast : public FunctionPass {
return LI;
}
+ // Loads elements from an array and constructs a vector.
+ Value *loadVectorFromArray(IRBuilder<> &B, FixedVectorType *TargetType,
+ ArrayType *SourceType, Value *Source) {
+ // Ensure the element types of the array and vector are the same.
+ assert(TargetType->getElementType() == SourceType->getElementType() &&
+ "Element types of array and vector must be the same.");
+
+ // Load each element of the array.
+ SmallVector<Value *, 4> LoadedElements;
+ for (unsigned i = 0; i < TargetType->getNumElements(); ++i) {
+ // Create a GEP to access the i-th element of the array.
+ SmallVector<Type *, 2> Types = {Source->getType(), Source->getType()};
+ SmallVector<Value *, 4> Args;
+ Args.push_back(B.getInt1(true));
+ Args.push_back(Source);
+ Args.push_back(B.getInt32(0));
+ Args.push_back(ConstantInt::get(B.getInt32Ty(), i));
+ auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+ GR->buildAssignPtr(B, TargetType->getElementType(), ElementPtr);
+
+ // Load the value from the element pointer.
+ Value *Load = B.CreateLoad(TargetType->getElementType(), ElementPtr);
+ buildAssignType(B, TargetType->getElementType(), Load);
+ LoadedElements.push_back(Load);
+ }
+
+ // Build the vector from the loaded elements.
+ Value *NewVector = UndefValue::get(TargetType);
+ buildAssignType(B, TargetType, NewVector);
+
+ for (unsigned i = 0; i < TargetType->getNumElements(); ++i) {
+ Value *Index = B.getInt32(i);
+ SmallVector<Type *, 4> Types = {TargetType, TargetType,
+ TargetType->getElementType(),
+ Index->getType()};
+ SmallVector<Value *> Args = {NewVector, LoadedElements[i], Index};
+ NewVector = B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
+ buildAssignType(B, TargetType, NewVector);
+ }
+ return NewVector;
+ }
+
+ // Stores elements from a vector into an array.
+ void storeArrayFromVector(IRBuilder<> &B, Value *SrcVector,
+ Value *DstArrayPtr, ArrayType *ArrTy,
+ Align Alignment) {
+ auto *VecTy = cast<FixedVectorType>(SrcVector->getType());
+
+ // Ensure the element types of the array and vector are the same.
+ assert(VecTy->getElementType() == ArrTy->getElementType() &&
+ "Element types of array and vector must be the same.");
+
+ for (unsigned i = 0; i < VecTy->getNumElements(); ++i) {
+ // Create a GEP to access the i-th element of the array.
+ SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
+ DstArrayPtr->getType()};
+ SmallVector<Value *, 4> Args;
+ Args.push_back(B.getInt1(true));
+ Args.push_back(DstArrayPtr);
+ Args.push_back(B.getInt32(0));
+ Args.push_back(ConstantInt::get(B.getInt32Ty(), i));
+ auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+ GR->buildAssignPtr(B, ArrTy->getElementType(), ElementPtr);
+
+ // Extract the element from the vector and store it.
+ Value *Index = B.getInt32(i);
+ SmallVector<Type *, 3> EltTypes = {VecTy->getElementType(), VecTy,
+ Index->getType()};
+ SmallVector<Value *, 2> EltArgs = {SrcVector, Index};
+ Value *Element =
+ B.CreateIntrinsic(Intrinsic::spv_extractelt, {EltTypes}, {EltArgs});
+ buildAssignType(B, VecTy->getElementType(), Element);
+
+ Types = {Element->getType(), ElementPtr->getType()};
+ Args = {Element, ElementPtr, B.getInt16(2), B.getInt8(Alignment.value())};
+ B.CreateIntrinsic(Intrinsic::spv_store, {Types}, {Args});
+ }
+ }
+
// Replaces the load instruction to get rid of the ptrcast used as source
// operand.
void transformLoad(IRBuilder<> &B, LoadInst *LI, Value *CastedOperand,
@@ -154,6 +233,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
// - float v = s.m;
else if (SST && SST->getTypeAtIndex(0u) == ToTy)
Output = loadFirstValueFromAggregate(B, ToTy, OriginalOperand, LI);
+ else if (SAT && DVT && SAT->getElementType() == DVT->getElementType())
+ Output = loadVectorFromArray(B, DVT, SAT, OriginalOperand);
else
llvm_unreachable("Unimplemented implicit down-cast from load.");
@@ -288,6 +369,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
auto *D_ST = dyn_cast<StructType>(ToTy);
auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
+ auto *D_AT = dyn_cast<ArrayType>(ToTy);
B.SetInsertPoint(BadStore);
if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST))
@@ -296,6 +378,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
storeVectorFromVector(B, Src, Dst, Alignment);
else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
+ else if (D_AT && S_VT && S_VT->getElementType() == D_AT->getElementType())
+ storeArrayFromVector(B, Src, Dst, D_AT, Alignment);
else
llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
diff --git a/llvm/test/CodeGen/SPIRV/pointers/load-store-vec-from-array.ll b/llvm/test/CodeGen/SPIRV/pointers/load-store-vec-from-array.ll
new file mode 100644
index 0000000000000..2a6d799c5bff6
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/load-store-vec-from-array.ll
@@ -0,0 +1,54 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: [[FLOAT:%[0-9]+]] = OpTypeFloat 32
+; CHECK-DAG: [[VEC4FLOAT:%[0-9]+]] = OpTypeVector [[FLOAT]] 4
+; CHECK-DAG: [[UINT_TYPE:%[0-9]+]] = OpTypeInt 32 0
+; CHECK-DAG: [[UINT4:%[0-9]+]] = OpConstant [[UINT_TYPE]] 4
+; CHECK-DAG: [[ARRAY4FLOAT:%[0-9]+]] = OpTypeArray [[FLOAT]] [[UINT4]]
+; CHECK-DAG: [[PTR_ARRAY4FLOAT:%[0-9]+]] = OpTypePointer Private [[ARRAY4FLOAT]]
+; CHECK-DAG: [[G_IN:%[0-9]+]] = OpVariable [[PTR_ARRAY4FLOAT]] Private
+; CHECK-DAG: [[G_OUT:%[0-9]+]] = OpVariable [[PTR_ARRAY4FLOAT]] Private
+; CHECK-DAG: [[UINT0:%[0-9]+]] = OpConstant [[UINT_TYPE]] 0
+; CHECK-DAG: [[UINT1:%[0-9]+]] = OpConstant [[UINT_TYPE]] 1
+; CHECK-DAG: [[UINT2:%[0-9]+]] = OpConstant [[UINT_TYPE]] 2
+; CHECK-DAG: [[UINT3:%[0-9]+]] = OpConstant [[UINT_TYPE]] 3
+; CHECK-DAG: [[PTR_FLOAT:%[0-9]+]] = OpTypePointer Private [[FLOAT]]
+; CHECK-DAG: [[UNDEF_VEC:%[0-9]+]] = OpUndef [[VEC4FLOAT]]
+
+@G_in = internal addrspace(10) global [4 x float] zeroinitializer
+@G_out = internal addrspace(10) global [4 x float] zeroinitializer
+
+define spir_func void @main() {
+entry:
+; CHECK: [[GEP0:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT0]]
+; CHECK-NEXT: [[LOAD0:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP0]]
+; CHECK-NEXT: [[GEP1:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT1]]
+; CHECK-NEXT: [[LOAD1:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP1]]
+; CHECK-NEXT: [[GEP2:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT2]]
+; CHECK-NEXT: [[LOAD2:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP2]]
+; CHECK-NEXT: [[GEP3:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT3]]
+; CHECK-NEXT: [[LOAD3:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP3]]
+; CHECK-NEXT: [[VEC_INSERT0:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD0]] [[UNDEF_VEC]] 0
+; CHECK-NEXT: [[VEC_INSERT1:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD1]] [[VEC_INSERT0]] 1
+; CHECK-NEXT: [[VEC_INSERT2:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD2]] [[VEC_INSERT1]] 2
+; CHECK-NEXT: [[VEC:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD3]] [[VEC_INSERT2]] 3
+ %0 = load <4 x float>, ptr addrspace(10) @G_in, align 64
+
+; CHECK-NEXT: [[GEP_OUT0:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT0]]
+; CHECK-NEXT: [[VEC_EXTRACT0:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 0
+; CHECK-NEXT: OpStore [[GEP_OUT0]] [[VEC_EXTRACT0]]
+; CHECK-NEXT: [[GEP_OUT1:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT1]]
+; CHECK-NEXT: [[VEC_EXTRACT1:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 1
+; CHECK-NEXT: OpStore [[GEP_OUT1]] [[VEC_EXTRACT1]]
+; CHECK-NEXT: [[GEP_OUT2:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT2]]
+; CHECK-NEXT: [[VEC_EXTRACT2:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 2
+; CHECK-NEXT: OpStore [[GEP_OUT2]] [[VEC_EXTRACT2]]
+; CHECK-NEXT: [[GEP_OUT3:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT3]]
+; CHECK-NEXT: [[VEC_EXTRACT3:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 3
+; CHECK-NEXT: OpStore [[GEP_OUT3]] [[VEC_EXTRACT3]]
+ store <4 x float> %0, ptr addrspace(10) @G_out, align 64
+
+; CHECK-NEXT: OpReturn
+ ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/166418
More information about the llvm-commits
mailing list