[llvm] [SPIRV] Handle ptrcast between array and vector types (PR #166418)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 4 11:27:06 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Steven Perron (s-perron)

<details>
<summary>Changes</summary>

This commit adds support for legalizing pointer casts between array and vector types within the SPIRV backend.

This is necessary to handle cases where a vector is loaded from or stored to an array, which can occur with HLSL matrix types.

The following changes are included:
- Added `loadVectorFromArray` to load a vector from an array.
- Added `storeArrayFromVector` to store a vector to an array.
- Added the `load-store-vec-from-array.ll` test case to verify the functionality.


---
Full diff: https://github.com/llvm/llvm-project/pull/166418.diff


2 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp (+84) 
- (added) llvm/test/CodeGen/SPIRV/pointers/load-store-vec-from-array.ll (+54) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 65dffc7908b78..87c0c8c5a7437 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -116,6 +116,85 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     return LI;
   }
 
+  // Loads elements from an array and constructs a vector.
+  Value *loadVectorFromArray(IRBuilder<> &B, FixedVectorType *TargetType,
+                             ArrayType *SourceType, Value *Source) {
+    // Ensure the element types of the array and vector are the same.
+    assert(TargetType->getElementType() == SourceType->getElementType() &&
+           "Element types of array and vector must be the same.");
+
+    // Load each element of the array.
+    SmallVector<Value *, 4> LoadedElements;
+    for (unsigned i = 0; i < TargetType->getNumElements(); ++i) {
+      // Create a GEP to access the i-th element of the array.
+      SmallVector<Type *, 2> Types = {Source->getType(), Source->getType()};
+      SmallVector<Value *, 4> Args;
+      Args.push_back(B.getInt1(true));
+      Args.push_back(Source);
+      Args.push_back(B.getInt32(0));
+      Args.push_back(ConstantInt::get(B.getInt32Ty(), i));
+      auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+      GR->buildAssignPtr(B, TargetType->getElementType(), ElementPtr);
+
+      // Load the value from the element pointer.
+      Value *Load = B.CreateLoad(TargetType->getElementType(), ElementPtr);
+      buildAssignType(B, TargetType->getElementType(), Load);
+      LoadedElements.push_back(Load);
+    }
+
+    // Build the vector from the loaded elements.
+    Value *NewVector = UndefValue::get(TargetType);
+    buildAssignType(B, TargetType, NewVector);
+
+    for (unsigned i = 0; i < TargetType->getNumElements(); ++i) {
+      Value *Index = B.getInt32(i);
+      SmallVector<Type *, 4> Types = {TargetType, TargetType,
+                                      TargetType->getElementType(),
+                                      Index->getType()};
+      SmallVector<Value *> Args = {NewVector, LoadedElements[i], Index};
+      NewVector = B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
+      buildAssignType(B, TargetType, NewVector);
+    }
+    return NewVector;
+  }
+
+  // Stores elements from a vector into an array.
+  void storeArrayFromVector(IRBuilder<> &B, Value *SrcVector,
+                            Value *DstArrayPtr, ArrayType *ArrTy,
+                            Align Alignment) {
+    auto *VecTy = cast<FixedVectorType>(SrcVector->getType());
+
+    // Ensure the element types of the array and vector are the same.
+    assert(VecTy->getElementType() == ArrTy->getElementType() &&
+           "Element types of array and vector must be the same.");
+
+    for (unsigned i = 0; i < VecTy->getNumElements(); ++i) {
+      // Create a GEP to access the i-th element of the array.
+      SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
+                                      DstArrayPtr->getType()};
+      SmallVector<Value *, 4> Args;
+      Args.push_back(B.getInt1(true));
+      Args.push_back(DstArrayPtr);
+      Args.push_back(B.getInt32(0));
+      Args.push_back(ConstantInt::get(B.getInt32Ty(), i));
+      auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+      GR->buildAssignPtr(B, ArrTy->getElementType(), ElementPtr);
+
+      // Extract the element from the vector and store it.
+      Value *Index = B.getInt32(i);
+      SmallVector<Type *, 3> EltTypes = {VecTy->getElementType(), VecTy,
+                                         Index->getType()};
+      SmallVector<Value *, 2> EltArgs = {SrcVector, Index};
+      Value *Element =
+          B.CreateIntrinsic(Intrinsic::spv_extractelt, {EltTypes}, {EltArgs});
+      buildAssignType(B, VecTy->getElementType(), Element);
+
+      Types = {Element->getType(), ElementPtr->getType()};
+      Args = {Element, ElementPtr, B.getInt16(2), B.getInt8(Alignment.value())};
+      B.CreateIntrinsic(Intrinsic::spv_store, {Types}, {Args});
+    }
+  }
+
   // Replaces the load instruction to get rid of the ptrcast used as source
   // operand.
   void transformLoad(IRBuilder<> &B, LoadInst *LI, Value *CastedOperand,
@@ -154,6 +233,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     // - float v = s.m;
     else if (SST && SST->getTypeAtIndex(0u) == ToTy)
       Output = loadFirstValueFromAggregate(B, ToTy, OriginalOperand, LI);
+    else if (SAT && DVT && SAT->getElementType() == DVT->getElementType())
+      Output = loadVectorFromArray(B, DVT, SAT, OriginalOperand);
     else
       llvm_unreachable("Unimplemented implicit down-cast from load.");
 
@@ -288,6 +369,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
     auto *D_ST = dyn_cast<StructType>(ToTy);
     auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
+    auto *D_AT = dyn_cast<ArrayType>(ToTy);
 
     B.SetInsertPoint(BadStore);
     if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST))
@@ -296,6 +378,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
       storeVectorFromVector(B, Src, Dst, Alignment);
     else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
       storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
+    else if (D_AT && S_VT && S_VT->getElementType() == D_AT->getElementType())
+      storeArrayFromVector(B, Src, Dst, D_AT, Alignment);
     else
       llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
 
diff --git a/llvm/test/CodeGen/SPIRV/pointers/load-store-vec-from-array.ll b/llvm/test/CodeGen/SPIRV/pointers/load-store-vec-from-array.ll
new file mode 100644
index 0000000000000..2a6d799c5bff6
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/load-store-vec-from-array.ll
@@ -0,0 +1,54 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: [[FLOAT:%[0-9]+]] = OpTypeFloat 32
+; CHECK-DAG: [[VEC4FLOAT:%[0-9]+]] = OpTypeVector [[FLOAT]] 4
+; CHECK-DAG: [[UINT_TYPE:%[0-9]+]] = OpTypeInt 32 0
+; CHECK-DAG: [[UINT4:%[0-9]+]] = OpConstant [[UINT_TYPE]] 4
+; CHECK-DAG: [[ARRAY4FLOAT:%[0-9]+]] = OpTypeArray [[FLOAT]] [[UINT4]]
+; CHECK-DAG: [[PTR_ARRAY4FLOAT:%[0-9]+]] = OpTypePointer Private [[ARRAY4FLOAT]]
+; CHECK-DAG: [[G_IN:%[0-9]+]] = OpVariable [[PTR_ARRAY4FLOAT]] Private
+; CHECK-DAG: [[G_OUT:%[0-9]+]] = OpVariable [[PTR_ARRAY4FLOAT]] Private
+; CHECK-DAG: [[UINT0:%[0-9]+]] = OpConstant [[UINT_TYPE]] 0
+; CHECK-DAG: [[UINT1:%[0-9]+]] = OpConstant [[UINT_TYPE]] 1
+; CHECK-DAG: [[UINT2:%[0-9]+]] = OpConstant [[UINT_TYPE]] 2
+; CHECK-DAG: [[UINT3:%[0-9]+]] = OpConstant [[UINT_TYPE]] 3
+; CHECK-DAG: [[PTR_FLOAT:%[0-9]+]] = OpTypePointer Private [[FLOAT]]
+; CHECK-DAG: [[UNDEF_VEC:%[0-9]+]] = OpUndef [[VEC4FLOAT]]
+
+@G_in = internal addrspace(10) global [4 x float] zeroinitializer
+@G_out = internal addrspace(10) global [4 x float] zeroinitializer
+
+define spir_func void @main() {
+entry:
+; CHECK:      [[GEP0:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT0]]
+; CHECK-NEXT: [[LOAD0:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP0]]
+; CHECK-NEXT: [[GEP1:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT1]]
+; CHECK-NEXT: [[LOAD1:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP1]]
+; CHECK-NEXT: [[GEP2:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT2]]
+; CHECK-NEXT: [[LOAD2:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP2]]
+; CHECK-NEXT: [[GEP3:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT3]]
+; CHECK-NEXT: [[LOAD3:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP3]]
+; CHECK-NEXT: [[VEC_INSERT0:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD0]] [[UNDEF_VEC]] 0
+; CHECK-NEXT: [[VEC_INSERT1:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD1]] [[VEC_INSERT0]] 1
+; CHECK-NEXT: [[VEC_INSERT2:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD2]] [[VEC_INSERT1]] 2
+; CHECK-NEXT: [[VEC:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD3]] [[VEC_INSERT2]] 3
+  %0 = load <4 x float>, ptr addrspace(10) @G_in, align 64
+
+; CHECK-NEXT: [[GEP_OUT0:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT0]]
+; CHECK-NEXT: [[VEC_EXTRACT0:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 0
+; CHECK-NEXT: OpStore [[GEP_OUT0]] [[VEC_EXTRACT0]]
+; CHECK-NEXT: [[GEP_OUT1:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT1]]
+; CHECK-NEXT: [[VEC_EXTRACT1:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 1
+; CHECK-NEXT: OpStore [[GEP_OUT1]] [[VEC_EXTRACT1]]
+; CHECK-NEXT: [[GEP_OUT2:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT2]]
+; CHECK-NEXT: [[VEC_EXTRACT2:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 2
+; CHECK-NEXT: OpStore [[GEP_OUT2]] [[VEC_EXTRACT2]]
+; CHECK-NEXT: [[GEP_OUT3:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT3]]
+; CHECK-NEXT: [[VEC_EXTRACT3:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 3
+; CHECK-NEXT: OpStore [[GEP_OUT3]] [[VEC_EXTRACT3]]
+  store <4 x float> %0, ptr addrspace(10) @G_out, align 64
+
+; CHECK-NEXT: OpReturn
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/166418


More information about the llvm-commits mailing list