[llvm] [SPIRV][Matrix] Legalize store of matrix to array of vector memory layout (PR #188139)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 26 13:55:37 PDT 2026
https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/188139
>From f0277666cb9aaca1eabfb25e2e9e6c1f5d30cbfd Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 23 Mar 2026 17:28:45 -0400
Subject: [PATCH 1/2] [SPIRV][Matrix] Legalize store of matrix to array of
vector memory layout
fixes #188131
fixes #188130
This change addresses the stylistic changes @bogners requested in https://github.com/llvm/llvm-project/pull/186215/
It also adds `storeMatrixArrayFromVector` to
SPIRVLegalizePointerCast.cpp, which is used when a vector store targets the
matrix-as-array-of-vectors memory layout.
The changes to storeArrayFromVector are cleanup.
Finally, the change to SPIRVGlobalRegistry.cpp addresses #188130 by giving
arrays of vec3 the same stride as arrays of vec4, so we don't need --scalar-block-layout.
---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 6 +
.../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 95 ++++++++++------
.../store-array-of-vectors-to-vector.ll | 107 ++++++++++++++++++
3 files changed, 171 insertions(+), 37 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/pointers/store-array-of-vectors-to-vector.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 73fca3ee18bce..c9c54b989f273 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -2280,6 +2280,12 @@ void SPIRVGlobalRegistry::addStructOffsetDecorations(
void SPIRVGlobalRegistry::addArrayStrideDecorations(
Register Reg, Type *ElementType, MachineIRBuilder &MIRBuilder) {
uint32_t SizeInBytes = DataLayout().getTypeSizeInBits(ElementType) / 8;
+ // Vulkan gives a vec3 the same alignment as a vec4, so consecutive array
+ // elements must be a full four scalar components apart.
+ auto *VecTy = dyn_cast<FixedVectorType>(ElementType);
+ if (VecTy && VecTy->getNumElements() == 3) {
+ SizeInBytes = (VecTy->getScalarSizeInBits() / 8) * 4;
+ }
buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::ArrayStride,
{SizeInBytes});
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 925b1b00336b5..3e3829a3c52bf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -172,7 +172,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
Value *NewVector = PoisonValue::get(TargetType);
buildAssignType(B, TargetType, NewVector);
- for (unsigned I = 0; I < TargetType->getNumElements(); ++I) {
+ for (unsigned I = 0, E = TargetType->getNumElements(); I < E; ++I) {
Value *Index = B.getInt32(I);
SmallVector<Type *, 4> Types = {TargetType, TargetType,
TargetType->getElementType(),
@@ -194,15 +194,13 @@ class SPIRVLegalizePointerCast : public FunctionPass {
// Load each element of the array.
SmallVector<Value *, 4> LoadedElements;
SmallVector<Type *, 2> Types = {Source->getType(), Source->getType()};
- for (unsigned I = 0; I < TargetType->getNumElements(); ++I) {
+ for (unsigned I = 0, E = TargetType->getNumElements(); I < E; ++I) {
unsigned ArrayIndex = I / ScalarsPerArrayElement;
unsigned ElementIndexInArrayElem = I % ScalarsPerArrayElement;
// Create a GEP to access the i-th element of the array.
- SmallVector<Value *, 4> Args;
- Args.push_back(B.getInt1(/*Inbounds=*/false));
- Args.push_back(Source);
- Args.push_back(B.getInt32(0));
- Args.push_back(ConstantInt::get(B.getInt32Ty(), ArrayIndex));
+ std::array<Value *, 4> Args = {
+ B.getInt1(/*Inbounds=*/false), Source, B.getInt32(0),
+ ConstantInt::get(B.getInt32Ty(), ArrayIndex)};
auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
GR->buildAssignPtr(B, ArrElemVecTy, ElementPtr);
Value *LoadVec = B.CreateLoad(ArrElemVecTy, ElementPtr);
@@ -218,13 +216,11 @@ class SPIRVLegalizePointerCast : public FunctionPass {
// Load each element of the array.
SmallVector<Value *, 4> LoadedElements;
SmallVector<Type *, 2> Types = {Source->getType(), Source->getType()};
- for (unsigned I = 0; I < TargetType->getNumElements(); ++I) {
+ for (unsigned I = 0, E = TargetType->getNumElements(); I < E; ++I) {
// Create a GEP to access the i-th element of the array.
- SmallVector<Value *, 4> Args;
- Args.push_back(B.getInt1(/*Inbounds=*/false));
- Args.push_back(Source);
- Args.push_back(B.getInt32(0));
- Args.push_back(ConstantInt::get(B.getInt32Ty(), I));
+ std::array<Value *, 4> Args = {B.getInt1(/*Inbounds=*/false), Source,
+ B.getInt32(0),
+ ConstantInt::get(B.getInt32Ty(), I)};
auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
GR->buildAssignPtr(B, TargetType->getElementType(), ElementPtr);
@@ -236,44 +232,65 @@ class SPIRVLegalizePointerCast : public FunctionPass {
return buildVectorFromLoadedElements(B, TargetType, LoadedElements);
}
+ void storeMatrixArrayFromVector(IRBuilder<> &B, Value *SrcVector,
+ Value *DstArrayPtr, ArrayType *ArrTy,
+ Align Alignment) {
+ auto *SrcVecTy = cast<FixedVectorType>(SrcVector->getType());
+ auto *ArrElemVecTy = cast<FixedVectorType>(ArrTy->getElementType());
+ const uint64_t ElemStride = B.GetInsertBlock()->getModule()->getDataLayout().getTypeAllocSize(ArrElemVecTy);
+ Type *ElemTy = ArrElemVecTy->getElementType();
+ unsigned ScalarsPerArrayElement = ArrElemVecTy->getNumElements();
+ unsigned SrcNumElements = SrcVecTy->getNumElements();
+
+ SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
+ DstArrayPtr->getType()};
+
+ for (unsigned I = 0; I < SrcNumElements; I += ScalarsPerArrayElement) {
+ unsigned ArrayIndex = I / ScalarsPerArrayElement;
+ // Create a GEP to access the array element.
+ std::array<Value *, 4> Args = {
+ B.getInt1(/*Inbounds=*/false), DstArrayPtr, B.getInt32(0),
+ ConstantInt::get(B.getInt32Ty(), ArrayIndex)};
+ auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+ GR->buildAssignPtr(B, ArrElemVecTy, ElementPtr);
+ // Extract scalar elements from the source vector for this array slot.
+ SmallVector<Value *, 4> Elements;
+ for (unsigned J = 0; J < ScalarsPerArrayElement; ++J)
+ Elements.push_back(makeExtractElement(B, ElemTy, SrcVector, I + J));
+ // Rebuild the slot vector and store it. Slot N is N * ElemStride bytes past
+ // DstArrayPtr, so only slot 0 is guaranteed the full store alignment.
+ Value *Vec = buildVectorFromLoadedElements(B, ArrElemVecTy, Elements);
+ StoreInst *SI = B.CreateStore(Vec, ElementPtr);
+ SI->setAlignment(commonAlignment(Alignment, ArrayIndex * ElemStride));
+ }
+ }
+
// Stores elements from a vector into an array.
void storeArrayFromVector(IRBuilder<> &B, Value *SrcVector,
Value *DstArrayPtr, ArrayType *ArrTy,
Align Alignment) {
auto *VecTy = cast<FixedVectorType>(SrcVector->getType());
+ Type *ElemTy = ArrTy->getElementType();
+ const uint64_t ElemSize = B.GetInsertBlock()->getModule()->getDataLayout().getTypeAllocSize(ElemTy);
// Ensure the element types of the array and vector are the same.
- assert(VecTy->getElementType() == ArrTy->getElementType() &&
+ assert(VecTy->getElementType() == ElemTy &&
"Element types of array and vector must be the same.");
- const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
- uint64_t ElemSize = DL.getTypeAllocSize(ArrTy->getElementType());
+ SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
+ DstArrayPtr->getType()};
- for (unsigned i = 0; i < VecTy->getNumElements(); ++i) {
+ for (unsigned I = 0; I < VecTy->getNumElements(); ++I) {
// Create a GEP to access the i-th element of the array.
- SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
- DstArrayPtr->getType()};
- SmallVector<Value *, 4> Args;
- Args.push_back(B.getInt1(false));
- Args.push_back(DstArrayPtr);
- Args.push_back(B.getInt32(0));
- Args.push_back(ConstantInt::get(B.getInt32Ty(), i));
+ std::array<Value *, 4> Args = {B.getInt1(/*Inbounds=*/false), DstArrayPtr,
+ B.getInt32(0),
+ ConstantInt::get(B.getInt32Ty(), I)};
auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
- GR->buildAssignPtr(B, ArrTy->getElementType(), ElementPtr);
+ GR->buildAssignPtr(B, ElemTy, ElementPtr);
// Extract the element from the vector and store it.
- Value *Index = B.getInt32(i);
- SmallVector<Type *, 3> EltTypes = {VecTy->getElementType(), VecTy,
- Index->getType()};
- SmallVector<Value *, 2> EltArgs = {SrcVector, Index};
- Value *Element =
- B.CreateIntrinsic(Intrinsic::spv_extractelt, {EltTypes}, {EltArgs});
- buildAssignType(B, VecTy->getElementType(), Element);
-
- Types = {Element->getType(), ElementPtr->getType()};
- Align NewAlign = commonAlignment(Alignment, i * ElemSize);
- Args = {Element, ElementPtr, B.getInt16(2), B.getInt32(NewAlign.value())};
- B.CreateIntrinsic(Intrinsic::spv_store, {Types}, {Args});
+ StoreInst *SI = B.CreateStore(makeExtractElement(B, ElemTy, SrcVector, I), ElementPtr);
+ SI->setAlignment(commonAlignment(Alignment, I * ElemSize)); // Element I sits I * ElemSize bytes past DstArrayPtr.
}
}
@@ -441,6 +458,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
auto *D_AT = dyn_cast<ArrayType>(ToTy);
+ auto *D_MAT =
+ D_AT ? dyn_cast<FixedVectorType>(D_AT->getElementType()) : nullptr;
B.SetInsertPoint(BadStore);
if (isTypeFirstElementAggregate(FromTy, ToTy))
@@ -451,6 +470,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
else if (D_AT && S_VT && S_VT->getElementType() == D_AT->getElementType())
storeArrayFromVector(B, Src, Dst, D_AT, Alignment);
+ else if (D_MAT && S_VT && D_MAT->getElementType() == S_VT->getElementType())
+ storeMatrixArrayFromVector(B, Src, Dst, D_AT, Alignment);
else
llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
diff --git a/llvm/test/CodeGen/SPIRV/pointers/store-array-of-vectors-to-vector.ll b/llvm/test/CodeGen/SPIRV/pointers/store-array-of-vectors-to-vector.ll
new file mode 100644
index 0000000000000..73585beb75a85
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/store-array-of-vectors-to-vector.ll
@@ -0,0 +1,107 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan %s -stop-after=spirv-legalize-bitcast -o - | FileCheck %s --check-prefix=IRCHECK
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val --target-env vulkan1.3 %}
+
+; CHECK-DAG: [[FLOAT:%[0-9]+]] = OpTypeFloat 32
+; CHECK-DAG: [[VEC3FLOAT:%[0-9]+]] = OpTypeVector [[FLOAT]] 3
+; CHECK-DAG: [[PTR_VEC3:%[0-9]+]] = OpTypePointer StorageBuffer [[VEC3FLOAT]]
+; CHECK-DAG: [[UINT:%[0-9]+]] = OpTypeInt 32 0
+; CHECK-DAG: [[UINT2:%[0-9]+]] = OpConstant [[UINT]] 2
+; CHECK-DAG: [[ARRAY2VEC3:%[0-9]+]] = OpTypeArray [[VEC3FLOAT]] [[UINT2]]
+; CHECK-DAG: [[PTR_ARRAY2VEC3:%[0-9]+]] = OpTypePointer StorageBuffer [[ARRAY2VEC3]]
+; CHECK-DAG: [[UINT0:%[0-9]+]] = OpConstant [[UINT]] 0
+; CHECK-DAG: [[UINT1:%[0-9]+]] = OpConstant [[UINT]] 1
+; CHECK-DAG: [[UNDEF_VEC3:%[0-9]+]] = OpUndef [[VEC3FLOAT]]
+
+; Load from input[0][0] (first vector)
+; CHECK: [[IN_AC_ARR:%[0-9]+]] = OpAccessChain [[PTR_ARRAY2VEC3]] %{{[0-9]+}} [[UINT0]] [[UINT0]]
+; CHECK-NEXT: [[IN_AC_VEC0:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[IN_AC_ARR]] [[UINT0]]
+; CHECK-NEXT: [[LOAD0:%[0-9]+]] = OpLoad [[VEC3FLOAT]] [[IN_AC_VEC0]]
+; CHECK-NEXT: [[EX_0_0:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[LOAD0]] 0
+; CHECK-NEXT: [[IN_AC_VEC0_2:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[IN_AC_ARR]] [[UINT0]]
+; CHECK-NEXT: [[LOAD1:%[0-9]+]] = OpLoad [[VEC3FLOAT]] [[IN_AC_VEC0_2]]
+; CHECK-NEXT: [[EX_0_1:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[LOAD1]] 1
+; CHECK-NEXT: [[IN_AC_VEC0_3:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[IN_AC_ARR]] [[UINT0]]
+; CHECK-NEXT: [[LOAD2:%[0-9]+]] = OpLoad [[VEC3FLOAT]] [[IN_AC_VEC0_3]]
+; CHECK-NEXT: [[EX_0_2:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[LOAD2]] 2
+
+; Load from input[0][1] (second vector)
+; CHECK-NEXT: [[IN_AC_VEC1:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[IN_AC_ARR]] [[UINT1]]
+; CHECK-NEXT: [[LOAD3:%[0-9]+]] = OpLoad [[VEC3FLOAT]] [[IN_AC_VEC1]]
+; CHECK-NEXT: [[EX_1_0:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[LOAD3]] 0
+; CHECK-NEXT: [[IN_AC_VEC1_2:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[IN_AC_ARR]] [[UINT1]]
+; CHECK-NEXT: [[LOAD4:%[0-9]+]] = OpLoad [[VEC3FLOAT]] [[IN_AC_VEC1_2]]
+; CHECK-NEXT: [[EX_1_1:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[LOAD4]] 1
+; CHECK-NEXT: [[IN_AC_VEC1_3:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[IN_AC_ARR]] [[UINT1]]
+; CHECK-NEXT: [[LOAD5:%[0-9]+]] = OpLoad [[VEC3FLOAT]] [[IN_AC_VEC1_3]]
+; CHECK-NEXT: [[EX_1_2:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[LOAD5]] 2
+
+; Store to output[0][0] (first vector)
+; CHECK-NEXT: [[OUT_AC_ARR:%[0-9]+]] = OpAccessChain [[PTR_ARRAY2VEC3]] %{{[0-9]+}} [[UINT0]] [[UINT0]]
+; CHECK-NEXT: [[OUT_AC_VEC0:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[OUT_AC_ARR]] [[UINT0]]
+; CHECK-NEXT: [[INS_0_0:%[0-9]+]] = OpCompositeInsert [[VEC3FLOAT]] [[EX_0_0]] [[UNDEF_VEC3]] 0
+; CHECK-NEXT: [[INS_0_1:%[0-9]+]] = OpCompositeInsert [[VEC3FLOAT]] [[EX_0_1]] [[INS_0_0]] 1
+; CHECK-NEXT: [[INS_0_2:%[0-9]+]] = OpCompositeInsert [[VEC3FLOAT]] [[EX_0_2]] [[INS_0_1]] 2
+; CHECK-NEXT: OpStore [[OUT_AC_VEC0]] [[INS_0_2]]
+
+; Store to output[0][1] (second vector)
+; CHECK-NEXT: [[OUT_AC_VEC1:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[OUT_AC_ARR]] [[UINT1]]
+; CHECK-NEXT: [[INS_1_0:%[0-9]+]] = OpCompositeInsert [[VEC3FLOAT]] [[EX_1_0]] [[UNDEF_VEC3]] 0
+; CHECK-NEXT: [[INS_1_1:%[0-9]+]] = OpCompositeInsert [[VEC3FLOAT]] [[EX_1_1]] [[INS_1_0]] 1
+; CHECK-NEXT: [[INS_1_2:%[0-9]+]] = OpCompositeInsert [[VEC3FLOAT]] [[EX_1_2]] [[INS_1_1]] 2
+; CHECK-NEXT: OpStore [[OUT_AC_VEC1]] [[INS_1_2]]
+; CHECK-NEXT: OpReturn
+
+; IRCHECK-LABEL: define void @main
+
+; Load: GEP to input row 0, load <3 x float>, extract elements
+; IRCHECK: [[IN_PTR:%[0-9]+]] = {{.*}}call {{.*}} ptr addrspace(11) @llvm.spv.resource.getpointer
+; IRCHECK: [[GEP_ROW0:%[0-9]+]] = call ptr addrspace(11) (i1, ptr addrspace(11), ...) @llvm.spv.gep.p11.p11(i1 false, ptr addrspace(11) [[IN_PTR]], i32 0, i32 0)
+; IRCHECK: [[VEC_ROW0:%[0-9]+]] = load <3 x float>, ptr addrspace(11) [[GEP_ROW0]]
+; IRCHECK: [[E00:%[0-9]+]] = call float @llvm.spv.extractelt.f32.v3f32.i32(<3 x float> [[VEC_ROW0]], i32 0)
+; IRCHECK: [[E01:%[0-9]+]] = call float @llvm.spv.extractelt.f32.v3f32.i32(<3 x float> {{%[0-9]+}}, i32 1)
+; IRCHECK: [[E02:%[0-9]+]] = call float @llvm.spv.extractelt.f32.v3f32.i32(<3 x float> {{%[0-9]+}}, i32 2)
+
+; Load: GEP to input row 1, load <3 x float>, extract elements
+; IRCHECK: [[GEP_ROW1:%[0-9]+]] = call ptr addrspace(11) (i1, ptr addrspace(11), ...) @llvm.spv.gep.p11.p11(i1 false, ptr addrspace(11) [[IN_PTR]], i32 0, i32 1)
+; IRCHECK: [[VEC_ROW1:%[0-9]+]] = load <3 x float>, ptr addrspace(11) [[GEP_ROW1]]
+; IRCHECK: [[E10:%[0-9]+]] = call float @llvm.spv.extractelt.f32.v3f32.i32(<3 x float> [[VEC_ROW1]], i32 0)
+; IRCHECK: [[E11:%[0-9]+]] = call float @llvm.spv.extractelt.f32.v3f32.i32(<3 x float> {{%[0-9]+}}, i32 1)
+; IRCHECK: [[E12:%[0-9]+]] = call float @llvm.spv.extractelt.f32.v3f32.i32(<3 x float> {{%[0-9]+}}, i32 2)
+
+; Store: GEP to output row 0, extract from <6 x float>, insert into <3 x float>, store
+; IRCHECK: [[OUT_PTR:%[0-9]+]] = {{.*}}call {{.*}} ptr addrspace(11) @llvm.spv.resource.getpointer
+; IRCHECK: [[OUT_GEP_ROW0:%[0-9]+]] = call ptr addrspace(11) (i1, ptr addrspace(11), ...) @llvm.spv.gep.p11.p11(i1 false, ptr addrspace(11) [[OUT_PTR]], i32 0, i32 0)
+; IRCHECK: call float @llvm.spv.extractelt.f32.v6f32.i32(<6 x float> {{%[0-9]+}}, i32 0)
+; IRCHECK: call float @llvm.spv.extractelt.f32.v6f32.i32(<6 x float> {{%[0-9]+}}, i32 1)
+; IRCHECK: call float @llvm.spv.extractelt.f32.v6f32.i32(<6 x float> {{%[0-9]+}}, i32 2)
+; IRCHECK: call <3 x float> @llvm.spv.insertelt.v3f32.v3f32.f32.i32(<3 x float> poison, float {{%[0-9]+}}, i32 0)
+; IRCHECK: call <3 x float> @llvm.spv.insertelt.v3f32.v3f32.f32.i32(<3 x float> {{%[0-9]+}}, float {{%[0-9]+}}, i32 1)
+; IRCHECK: [[STORE_VEC0:%[0-9]+]] = call <3 x float> @llvm.spv.insertelt.v3f32.v3f32.f32.i32(<3 x float> {{%[0-9]+}}, float {{%[0-9]+}}, i32 2)
+; IRCHECK: store <3 x float> [[STORE_VEC0]], ptr addrspace(11) [[OUT_GEP_ROW0]]
+
+; Store: GEP to output row 1, extract from <6 x float>, insert into <3 x float>, store
+; IRCHECK: [[OUT_GEP_ROW1:%[0-9]+]] = call ptr addrspace(11) (i1, ptr addrspace(11), ...) @llvm.spv.gep.p11.p11(i1 false, ptr addrspace(11) [[OUT_PTR]], i32 0, i32 1)
+; IRCHECK: call float @llvm.spv.extractelt.f32.v6f32.i32(<6 x float> {{%[0-9]+}}, i32 3)
+; IRCHECK: call float @llvm.spv.extractelt.f32.v6f32.i32(<6 x float> {{%[0-9]+}}, i32 4)
+; IRCHECK: call float @llvm.spv.extractelt.f32.v6f32.i32(<6 x float> {{%[0-9]+}}, i32 5)
+; IRCHECK: call <3 x float> @llvm.spv.insertelt.v3f32.v3f32.f32.i32(<3 x float> poison, float {{%[0-9]+}}, i32 0)
+; IRCHECK: call <3 x float> @llvm.spv.insertelt.v3f32.v3f32.f32.i32(<3 x float> {{%[0-9]+}}, float {{%[0-9]+}}, i32 1)
+; IRCHECK: [[STORE_VEC1:%[0-9]+]] = call <3 x float> @llvm.spv.insertelt.v3f32.v3f32.f32.i32(<3 x float> {{%[0-9]+}}, float {{%[0-9]+}}, i32 2)
+; IRCHECK: store <3 x float> [[STORE_VEC1]], ptr addrspace(11) [[OUT_GEP_ROW1]]
+
+@.str = private unnamed_addr constant [4 x i8] c"InA\00", align 1
+@.str.2 = private unnamed_addr constant [5 x i8] c"OutA\00", align 1
+
+define void @main() local_unnamed_addr #0 {
+entry:
+ %0 = tail call target("spirv.VulkanBuffer", [0 x [2 x <3 x float>]], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str)
+ %1 = tail call target("spirv.VulkanBuffer", [0 x [2 x <3 x float>]], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str.2)
+ %2 = tail call noundef align 4 dereferenceable(24) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer(target("spirv.VulkanBuffer", [0 x [2 x <3 x float>]], 12, 0) %0, i32 0)
+ %3 = load <6 x float>, ptr addrspace(11) %2, align 4
+ %4 = tail call noundef align 4 dereferenceable(24) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer(target("spirv.VulkanBuffer", [0 x [2 x <3 x float>]], 12, 1) %1, i32 0)
+ store <6 x float> %3, ptr addrspace(11) %4, align 4
+ ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
>From dd88444a635a4c87cb5b8897120a6bf310355248 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 26 Mar 2026 16:51:36 -0400
Subject: [PATCH 2/2] address pr comments
---
.../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 21 +++++++++++--------
1 file changed, 12 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 3e3829a3c52bf..93886f552fe2f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -149,7 +149,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
// which should be the load being legalized. Returns the loaded value.
Value *loadFirstValueFromAggregate(IRBuilder<> &B, Type *ElementType,
Value *Source, LoadInst *BadLoad) {
- SmallVector<Type *, 2> Types = {BadLoad->getPointerOperandType(),
+ std::array<Type *, 2> Types = {BadLoad->getPointerOperandType(),
Source->getType()};
SmallVector<Value *, 8> Args{/* isInBounds= */ B.getInt1(false), Source};
@@ -193,7 +193,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
unsigned ScalarsPerArrayElement = ArrElemVecTy->getNumElements();
// Load each element of the array.
SmallVector<Value *, 4> LoadedElements;
- SmallVector<Type *, 2> Types = {Source->getType(), Source->getType()};
+ std::array<Type *, 2> Types = {Source->getType(), Source->getType()};
for (unsigned I = 0, E = TargetType->getNumElements(); I < E; ++I) {
unsigned ArrayIndex = I / ScalarsPerArrayElement;
unsigned ElementIndexInArrayElem = I % ScalarsPerArrayElement;
@@ -215,7 +215,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
Value *Source) {
// Load each element of the array.
SmallVector<Value *, 4> LoadedElements;
- SmallVector<Type *, 2> Types = {Source->getType(), Source->getType()};
+ std::array<Type *, 2> Types = {Source->getType(), Source->getType()};
for (unsigned I = 0, E = TargetType->getNumElements(); I < E; ++I) {
// Create a GEP to access the i-th element of the array.
std::array<Value *, 4> Args = {B.getInt1(/*Inbounds=*/false), Source,
@@ -232,6 +232,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
return buildVectorFromLoadedElements(B, TargetType, LoadedElements);
}
+ // Stores elements from a vector into a matrix (an array of vectors).
void storeMatrixArrayFromVector(IRBuilder<> &B, Value *SrcVector,
Value *DstArrayPtr, ArrayType *ArrTy,
Align Alignment) {
@@ -240,8 +241,11 @@ class SPIRVLegalizePointerCast : public FunctionPass {
Type *ElemTy = ArrElemVecTy->getElementType();
unsigned ScalarsPerArrayElement = ArrElemVecTy->getNumElements();
unsigned SrcNumElements = SrcVecTy->getNumElements();
+ assert(
+ SrcNumElements % ScalarsPerArrayElement == 0 &&
+ "Source vector size must be a multiple of array element vector size");
- SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
+ std::array<Type *, 2> Types = {DstArrayPtr->getType(),
DstArrayPtr->getType()};
for (unsigned I = 0; I < SrcNumElements; I += ScalarsPerArrayElement) {
@@ -275,11 +279,10 @@ class SPIRVLegalizePointerCast : public FunctionPass {
// Ensure the element types of the array and vector are the same.
assert(VecTy->getElementType() == ElemTy &&
"Element types of array and vector must be the same.");
+ std::array<Type *, 2> Types = {DstArrayPtr->getType(),
+ DstArrayPtr->getType()};
- SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
- DstArrayPtr->getType()};
-
- for (unsigned I = 0; I < VecTy->getNumElements(); ++I) {
+ for (unsigned I = 0, E = VecTy->getNumElements(); I < E; ++I) {
// Create a GEP to access the i-th element of the array.
std::array<Value *, 4> Args = {B.getInt1(/*Inbounds=*/false), DstArrayPtr,
B.getInt32(0),
@@ -426,7 +429,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
// Stores the given Src value into the first entry of the Dst aggregate.
Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
Type *DstPointeeType, Align Alignment) {
- SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
+ std::array<Type *, 2> Types = {Dst->getType(), Dst->getType()};
SmallVector<Value *, 8> Args{/* isInBounds= */ B.getInt1(true), Dst};
buildGEPIndexChain(B, Src->getType(), DstPointeeType, Args);
auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
More information about the llvm-commits
mailing list