[llvm] [SPIRV][Matrix] Legalize store of matrix to array of vector memory layout (PR #188139)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 26 13:55:37 PDT 2026


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/188139

>From f0277666cb9aaca1eabfb25e2e9e6c1f5d30cbfd Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 23 Mar 2026 17:28:45 -0400
Subject: [PATCH 1/2] [SPIRV][Matrix] Legalize store of matrix to array of
 vector memory layout

fixes #188131
fixes #188130

This change addresses the stylistic changes @bogners requested in https://github.com/llvm/llvm-project/pull/186215/
It also adds `storeMatrixArrayFromVector` to
SPIRVLegalizePointerCast.cpp, which is used when we detect the matrix
(array of vectors) memory layout.
The changes to storeArrayFromVector are cleanup.
Finally, the changes to SPIRVGlobalRegistry.cpp address #188130 by giving
vec3 array elements the vec4 layout, so we don't need --scalar-block-layout.
---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |   6 +
 .../Target/SPIRV/SPIRVLegalizePointerCast.cpp |  95 ++++++++++------
 .../store-array-of-vectors-to-vector.ll       | 107 ++++++++++++++++++
 3 files changed, 171 insertions(+), 37 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/store-array-of-vectors-to-vector.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 73fca3ee18bce..c9c54b989f273 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -2280,6 +2280,12 @@ void SPIRVGlobalRegistry::addStructOffsetDecorations(
 void SPIRVGlobalRegistry::addArrayStrideDecorations(
     Register Reg, Type *ElementType, MachineIRBuilder &MIRBuilder) {
   uint32_t SizeInBytes = DataLayout().getTypeSizeInBits(ElementType) / 8;
+  // Vulkan gives vec3 the alignment of vec4, so pad the stride to 4 scalars.
+  if (auto *VecTy = dyn_cast<FixedVectorType>(ElementType);
+      VecTy && VecTy->getNumElements() == 3) {
+    uint32_t ScalarSize = VecTy->getScalarSizeInBits() / 8;
+    SizeInBytes = 4 * ScalarSize;
+  }
   buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::ArrayStride,
                   {SizeInBytes});
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 925b1b00336b5..3e3829a3c52bf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -172,7 +172,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     Value *NewVector = PoisonValue::get(TargetType);
     buildAssignType(B, TargetType, NewVector);
 
-    for (unsigned I = 0; I < TargetType->getNumElements(); ++I) {
+    for (unsigned I = 0, E = TargetType->getNumElements(); I < E; ++I) {
       Value *Index = B.getInt32(I);
       SmallVector<Type *, 4> Types = {TargetType, TargetType,
                                       TargetType->getElementType(),
@@ -194,15 +194,13 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     // Load each element of the array.
     SmallVector<Value *, 4> LoadedElements;
     SmallVector<Type *, 2> Types = {Source->getType(), Source->getType()};
-    for (unsigned I = 0; I < TargetType->getNumElements(); ++I) {
+    for (unsigned I = 0, E = TargetType->getNumElements(); I < E; ++I) {
       unsigned ArrayIndex = I / ScalarsPerArrayElement;
       unsigned ElementIndexInArrayElem = I % ScalarsPerArrayElement;
       // Create a GEP to access the i-th element of the array.
-      SmallVector<Value *, 4> Args;
-      Args.push_back(B.getInt1(/*Inbounds=*/false));
-      Args.push_back(Source);
-      Args.push_back(B.getInt32(0));
-      Args.push_back(ConstantInt::get(B.getInt32Ty(), ArrayIndex));
+      std::array<Value *, 4> Args = {
+          B.getInt1(/*Inbounds=*/false), Source, B.getInt32(0),
+          ConstantInt::get(B.getInt32Ty(), ArrayIndex)};
       auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
       GR->buildAssignPtr(B, ArrElemVecTy, ElementPtr);
       Value *LoadVec = B.CreateLoad(ArrElemVecTy, ElementPtr);
@@ -218,13 +216,11 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     // Load each element of the array.
     SmallVector<Value *, 4> LoadedElements;
     SmallVector<Type *, 2> Types = {Source->getType(), Source->getType()};
-    for (unsigned I = 0; I < TargetType->getNumElements(); ++I) {
+    for (unsigned I = 0, E = TargetType->getNumElements(); I < E; ++I) {
       // Create a GEP to access the i-th element of the array.
-      SmallVector<Value *, 4> Args;
-      Args.push_back(B.getInt1(/*Inbounds=*/false));
-      Args.push_back(Source);
-      Args.push_back(B.getInt32(0));
-      Args.push_back(ConstantInt::get(B.getInt32Ty(), I));
+      std::array<Value *, 4> Args = {B.getInt1(/*Inbounds=*/false), Source,
+                                     B.getInt32(0),
+                                     ConstantInt::get(B.getInt32Ty(), I)};
       auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
       GR->buildAssignPtr(B, TargetType->getElementType(), ElementPtr);
 
@@ -236,44 +232,73 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     return buildVectorFromLoadedElements(B, TargetType, LoadedElements);
   }
 
+  void storeMatrixArrayFromVector(IRBuilder<> &B, Value *SrcVector,
+                                  Value *DstArrayPtr, ArrayType *ArrTy,
+                                  Align Alignment) {
+    auto *SrcVecTy = cast<FixedVectorType>(SrcVector->getType());
+    auto *ArrElemVecTy = cast<FixedVectorType>(ArrTy->getElementType());
+    const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
+    uint64_t ArrElemSize = DL.getTypeAllocSize(ArrElemVecTy);
+    Type *ElemTy = ArrElemVecTy->getElementType();
+    unsigned ScalarsPerArrayElement = ArrElemVecTy->getNumElements();
+    unsigned SrcNumElements = SrcVecTy->getNumElements();
+
+    SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
+                                    DstArrayPtr->getType()};
+
+    for (unsigned I = 0; I < SrcNumElements; I += ScalarsPerArrayElement) {
+      unsigned ArrayIndex = I / ScalarsPerArrayElement;
+      // Create a GEP to access the array element.
+      std::array<Value *, 4> Args = {
+          B.getInt1(/*Inbounds=*/false), DstArrayPtr, B.getInt32(0),
+          ConstantInt::get(B.getInt32Ty(), ArrayIndex)};
+      auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+      GR->buildAssignPtr(B, ArrElemVecTy, ElementPtr);
+
+      // Extract scalar elements from the source vector for this array slot.
+      SmallVector<Value *, 4> Elements;
+      for (unsigned J = 0; J < ScalarsPerArrayElement; ++J)
+        Elements.push_back(makeExtractElement(B, ElemTy, SrcVector, I + J));
+
+      // Build a vector from the extracted elements and store it.
+      Value *Vec = buildVectorFromLoadedElements(B, ArrElemVecTy, Elements);
+      StoreInst *SI = B.CreateStore(Vec, ElementPtr);
+      // Slot ArrayIndex begins ArrayIndex * ArrElemSize bytes past the base,
+      // so it only keeps the alignment common to the base and that offset.
+      SI->setAlignment(commonAlignment(Alignment, ArrayIndex * ArrElemSize));
+    }
+  }
+
   // Stores elements from a vector into an array.
   void storeArrayFromVector(IRBuilder<> &B, Value *SrcVector,
                             Value *DstArrayPtr, ArrayType *ArrTy,
                             Align Alignment) {
     auto *VecTy = cast<FixedVectorType>(SrcVector->getType());
+    Type *ElemTy = ArrTy->getElementType();
+    const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
+    uint64_t ElemSize = DL.getTypeAllocSize(ElemTy);
 
     // Ensure the element types of the array and vector are the same.
-    assert(VecTy->getElementType() == ArrTy->getElementType() &&
+    assert(VecTy->getElementType() == ElemTy &&
            "Element types of array and vector must be the same.");
 
-    const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
-    uint64_t ElemSize = DL.getTypeAllocSize(ArrTy->getElementType());
+    SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
+                                    DstArrayPtr->getType()};
 
-    for (unsigned i = 0; i < VecTy->getNumElements(); ++i) {
+    for (unsigned I = 0; I < VecTy->getNumElements(); ++I) {
       // Create a GEP to access the i-th element of the array.
-      SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
-                                      DstArrayPtr->getType()};
-      SmallVector<Value *, 4> Args;
-      Args.push_back(B.getInt1(false));
-      Args.push_back(DstArrayPtr);
-      Args.push_back(B.getInt32(0));
-      Args.push_back(ConstantInt::get(B.getInt32Ty(), i));
+      std::array<Value *, 4> Args = {B.getInt1(/*Inbounds=*/false), DstArrayPtr,
+                                     B.getInt32(0),
+                                     ConstantInt::get(B.getInt32Ty(), I)};
       auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
-      GR->buildAssignPtr(B, ArrTy->getElementType(), ElementPtr);
+      GR->buildAssignPtr(B, ElemTy, ElementPtr);
 
       // Extract the element from the vector and store it.
-      Value *Index = B.getInt32(i);
-      SmallVector<Type *, 3> EltTypes = {VecTy->getElementType(), VecTy,
-                                         Index->getType()};
-      SmallVector<Value *, 2> EltArgs = {SrcVector, Index};
-      Value *Element =
-          B.CreateIntrinsic(Intrinsic::spv_extractelt, {EltTypes}, {EltArgs});
-      buildAssignType(B, VecTy->getElementType(), Element);
-
-      Types = {Element->getType(), ElementPtr->getType()};
-      Align NewAlign = commonAlignment(Alignment, i * ElemSize);
-      Args = {Element, ElementPtr, B.getInt16(2), B.getInt32(NewAlign.value())};
-      B.CreateIntrinsic(Intrinsic::spv_store, {Types}, {Args});
+      Value *Element = makeExtractElement(B, ElemTy, SrcVector, I);
+      StoreInst *SI = B.CreateStore(Element, ElementPtr);
+      // Element I begins I * ElemSize bytes past the base, so it only keeps
+      // the alignment common to the base and that offset.
+      SI->setAlignment(commonAlignment(Alignment, I * ElemSize));
     }
   }
 
@@ -441,6 +466,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
     auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
     auto *D_AT = dyn_cast<ArrayType>(ToTy);
+    auto *D_MAT =
+        D_AT ? dyn_cast<FixedVectorType>(D_AT->getElementType()) : nullptr;
 
     B.SetInsertPoint(BadStore);
     if (isTypeFirstElementAggregate(FromTy, ToTy))
@@ -451,6 +478,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
       storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
     else if (D_AT && S_VT && S_VT->getElementType() == D_AT->getElementType())
       storeArrayFromVector(B, Src, Dst, D_AT, Alignment);
+    else if (D_MAT && S_VT && D_MAT->getElementType() == S_VT->getElementType())
+      storeMatrixArrayFromVector(B, Src, Dst, D_AT, Alignment);
     else
       llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
 
diff --git a/llvm/test/CodeGen/SPIRV/pointers/store-array-of-vectors-to-vector.ll b/llvm/test/CodeGen/SPIRV/pointers/store-array-of-vectors-to-vector.ll
new file mode 100644
index 0000000000000..73585beb75a85
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/store-array-of-vectors-to-vector.ll
@@ -0,0 +1,107 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan %s -stop-after=spirv-legalize-bitcast -o - | FileCheck %s --check-prefix=IRCHECK
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val --target-env vulkan1.3 %}
+
+; CHECK-DAG: [[FLOAT:%[0-9]+]] = OpTypeFloat 32
+; CHECK-DAG: [[VEC3FLOAT:%[0-9]+]] = OpTypeVector [[FLOAT]] 3
+; CHECK-DAG: [[PTR_VEC3:%[0-9]+]] = OpTypePointer StorageBuffer [[VEC3FLOAT]]
+; CHECK-DAG: [[UINT:%[0-9]+]] = OpTypeInt 32 0
+; CHECK-DAG: [[UINT2:%[0-9]+]] = OpConstant [[UINT]] 2
+; CHECK-DAG: [[ARRAY2VEC3:%[0-9]+]] = OpTypeArray [[VEC3FLOAT]] [[UINT2]]
+; CHECK-DAG: [[PTR_ARRAY2VEC3:%[0-9]+]] = OpTypePointer StorageBuffer [[ARRAY2VEC3]]
+; CHECK-DAG: [[UINT0:%[0-9]+]] = OpConstant [[UINT]] 0
+; CHECK-DAG: [[UINT1:%[0-9]+]] = OpConstant [[UINT]] 1
+; CHECK-DAG: [[UNDEF_VEC3:%[0-9]+]] = OpUndef [[VEC3FLOAT]]
+
+; Load from input[0][0] (first vector)
+; CHECK:      [[IN_AC_ARR:%[0-9]+]] = OpAccessChain [[PTR_ARRAY2VEC3]] %{{[0-9]+}} [[UINT0]] [[UINT0]]
+; CHECK-NEXT: [[IN_AC_VEC0:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[IN_AC_ARR]] [[UINT0]]
+; CHECK-NEXT: [[LOAD0:%[0-9]+]] = OpLoad [[VEC3FLOAT]] [[IN_AC_VEC0]]
+; CHECK-NEXT: [[EX_0_0:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[LOAD0]] 0
+; CHECK-NEXT: [[IN_AC_VEC0_2:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[IN_AC_ARR]] [[UINT0]]
+; CHECK-NEXT: [[LOAD1:%[0-9]+]] = OpLoad [[VEC3FLOAT]] [[IN_AC_VEC0_2]]
+; CHECK-NEXT: [[EX_0_1:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[LOAD1]] 1
+; CHECK-NEXT: [[IN_AC_VEC0_3:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[IN_AC_ARR]] [[UINT0]]
+; CHECK-NEXT: [[LOAD2:%[0-9]+]] = OpLoad [[VEC3FLOAT]] [[IN_AC_VEC0_3]]
+; CHECK-NEXT: [[EX_0_2:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[LOAD2]] 2
+
+; Load from input[0][1] (second vector)
+; CHECK-NEXT: [[IN_AC_VEC1:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[IN_AC_ARR]] [[UINT1]]
+; CHECK-NEXT: [[LOAD3:%[0-9]+]] = OpLoad [[VEC3FLOAT]] [[IN_AC_VEC1]]
+; CHECK-NEXT: [[EX_1_0:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[LOAD3]] 0
+; CHECK-NEXT: [[IN_AC_VEC1_2:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[IN_AC_ARR]] [[UINT1]]
+; CHECK-NEXT: [[LOAD4:%[0-9]+]] = OpLoad [[VEC3FLOAT]] [[IN_AC_VEC1_2]]
+; CHECK-NEXT: [[EX_1_1:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[LOAD4]] 1
+; CHECK-NEXT: [[IN_AC_VEC1_3:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[IN_AC_ARR]] [[UINT1]]
+; CHECK-NEXT: [[LOAD5:%[0-9]+]] = OpLoad [[VEC3FLOAT]] [[IN_AC_VEC1_3]]
+; CHECK-NEXT: [[EX_1_2:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[LOAD5]] 2
+
+; Store to output[0][0] (first vector)
+; CHECK-NEXT: [[OUT_AC_ARR:%[0-9]+]] = OpAccessChain [[PTR_ARRAY2VEC3]] %{{[0-9]+}} [[UINT0]] [[UINT0]]
+; CHECK-NEXT: [[OUT_AC_VEC0:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[OUT_AC_ARR]] [[UINT0]]
+; CHECK-NEXT: [[INS_0_0:%[0-9]+]] = OpCompositeInsert [[VEC3FLOAT]] [[EX_0_0]] [[UNDEF_VEC3]] 0
+; CHECK-NEXT: [[INS_0_1:%[0-9]+]] = OpCompositeInsert [[VEC3FLOAT]] [[EX_0_1]] [[INS_0_0]] 1
+; CHECK-NEXT: [[INS_0_2:%[0-9]+]] = OpCompositeInsert [[VEC3FLOAT]] [[EX_0_2]] [[INS_0_1]] 2
+; CHECK-NEXT: OpStore [[OUT_AC_VEC0]] [[INS_0_2]]
+
+; Store to output[0][1] (second vector)
+; CHECK-NEXT: [[OUT_AC_VEC1:%[0-9]+]] = OpAccessChain [[PTR_VEC3]] [[OUT_AC_ARR]] [[UINT1]]
+; CHECK-NEXT: [[INS_1_0:%[0-9]+]] = OpCompositeInsert [[VEC3FLOAT]] [[EX_1_0]] [[UNDEF_VEC3]] 0
+; CHECK-NEXT: [[INS_1_1:%[0-9]+]] = OpCompositeInsert [[VEC3FLOAT]] [[EX_1_1]] [[INS_1_0]] 1
+; CHECK-NEXT: [[INS_1_2:%[0-9]+]] = OpCompositeInsert [[VEC3FLOAT]] [[EX_1_2]] [[INS_1_1]] 2
+; CHECK-NEXT: OpStore [[OUT_AC_VEC1]] [[INS_1_2]]
+; CHECK-NEXT: OpReturn
+
+; IRCHECK-LABEL: define void @main
+
+; Load: GEP to input row 0, load <3 x float>, extract elements
+; IRCHECK: [[IN_PTR:%[0-9]+]] = {{.*}}call {{.*}} ptr addrspace(11) @llvm.spv.resource.getpointer
+; IRCHECK: [[GEP_ROW0:%[0-9]+]] = call ptr addrspace(11) (i1, ptr addrspace(11), ...) @llvm.spv.gep.p11.p11(i1 false, ptr addrspace(11) [[IN_PTR]], i32 0, i32 0)
+; IRCHECK: [[VEC_ROW0:%[0-9]+]] = load <3 x float>, ptr addrspace(11) [[GEP_ROW0]]
+; IRCHECK: [[E00:%[0-9]+]] = call float @llvm.spv.extractelt.f32.v3f32.i32(<3 x float> [[VEC_ROW0]], i32 0)
+; IRCHECK: [[E01:%[0-9]+]] = call float @llvm.spv.extractelt.f32.v3f32.i32(<3 x float> {{%[0-9]+}}, i32 1)
+; IRCHECK: [[E02:%[0-9]+]] = call float @llvm.spv.extractelt.f32.v3f32.i32(<3 x float> {{%[0-9]+}}, i32 2)
+
+; Load: GEP to input row 1, load <3 x float>, extract elements
+; IRCHECK: [[GEP_ROW1:%[0-9]+]] = call ptr addrspace(11) (i1, ptr addrspace(11), ...) @llvm.spv.gep.p11.p11(i1 false, ptr addrspace(11) [[IN_PTR]], i32 0, i32 1)
+; IRCHECK: [[VEC_ROW1:%[0-9]+]] = load <3 x float>, ptr addrspace(11) [[GEP_ROW1]]
+; IRCHECK: [[E10:%[0-9]+]] = call float @llvm.spv.extractelt.f32.v3f32.i32(<3 x float> [[VEC_ROW1]], i32 0)
+; IRCHECK: [[E11:%[0-9]+]] = call float @llvm.spv.extractelt.f32.v3f32.i32(<3 x float> {{%[0-9]+}}, i32 1)
+; IRCHECK: [[E12:%[0-9]+]] = call float @llvm.spv.extractelt.f32.v3f32.i32(<3 x float> {{%[0-9]+}}, i32 2)
+
+; Store: GEP to output row 0, extract from <6 x float>, insert into <3 x float>, store
+; IRCHECK: [[OUT_PTR:%[0-9]+]] = {{.*}}call {{.*}} ptr addrspace(11) @llvm.spv.resource.getpointer
+; IRCHECK: [[OUT_GEP_ROW0:%[0-9]+]] = call ptr addrspace(11) (i1, ptr addrspace(11), ...) @llvm.spv.gep.p11.p11(i1 false, ptr addrspace(11) [[OUT_PTR]], i32 0, i32 0)
+; IRCHECK: call float @llvm.spv.extractelt.f32.v6f32.i32(<6 x float> {{%[0-9]+}}, i32 0)
+; IRCHECK: call float @llvm.spv.extractelt.f32.v6f32.i32(<6 x float> {{%[0-9]+}}, i32 1)
+; IRCHECK: call float @llvm.spv.extractelt.f32.v6f32.i32(<6 x float> {{%[0-9]+}}, i32 2)
+; IRCHECK: call <3 x float> @llvm.spv.insertelt.v3f32.v3f32.f32.i32(<3 x float> poison, float {{%[0-9]+}}, i32 0)
+; IRCHECK: call <3 x float> @llvm.spv.insertelt.v3f32.v3f32.f32.i32(<3 x float> {{%[0-9]+}}, float {{%[0-9]+}}, i32 1)
+; IRCHECK: [[STORE_VEC0:%[0-9]+]] = call <3 x float> @llvm.spv.insertelt.v3f32.v3f32.f32.i32(<3 x float> {{%[0-9]+}}, float {{%[0-9]+}}, i32 2)
+; IRCHECK: store <3 x float> [[STORE_VEC0]], ptr addrspace(11) [[OUT_GEP_ROW0]]
+
+; Store: GEP to output row 1, extract from <6 x float>, insert into <3 x float>, store
+; IRCHECK: [[OUT_GEP_ROW1:%[0-9]+]] = call ptr addrspace(11) (i1, ptr addrspace(11), ...) @llvm.spv.gep.p11.p11(i1 false, ptr addrspace(11) [[OUT_PTR]], i32 0, i32 1)
+; IRCHECK: call float @llvm.spv.extractelt.f32.v6f32.i32(<6 x float> {{%[0-9]+}}, i32 3)
+; IRCHECK: call float @llvm.spv.extractelt.f32.v6f32.i32(<6 x float> {{%[0-9]+}}, i32 4)
+; IRCHECK: call float @llvm.spv.extractelt.f32.v6f32.i32(<6 x float> {{%[0-9]+}}, i32 5)
+; IRCHECK: call <3 x float> @llvm.spv.insertelt.v3f32.v3f32.f32.i32(<3 x float> poison, float {{%[0-9]+}}, i32 0)
+; IRCHECK: call <3 x float> @llvm.spv.insertelt.v3f32.v3f32.f32.i32(<3 x float> {{%[0-9]+}}, float {{%[0-9]+}}, i32 1)
+; IRCHECK: [[STORE_VEC1:%[0-9]+]] = call <3 x float> @llvm.spv.insertelt.v3f32.v3f32.f32.i32(<3 x float> {{%[0-9]+}}, float {{%[0-9]+}}, i32 2)
+; IRCHECK: store <3 x float> [[STORE_VEC1]], ptr addrspace(11) [[OUT_GEP_ROW1]]
+
+@.str = private unnamed_addr constant [4 x i8] c"InA\00", align 1
+@.str.2 = private unnamed_addr constant [5 x i8] c"OutA\00", align 1
+
+define void @main() local_unnamed_addr #0 {
+entry:
+  %0 = tail call target("spirv.VulkanBuffer", [0 x [2 x <3 x float>]], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str)
+  %1 = tail call target("spirv.VulkanBuffer", [0 x [2 x <3 x float>]], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str.2)
+  %2 = tail call noundef align 4 dereferenceable(24) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer(target("spirv.VulkanBuffer", [0 x [2 x <3 x float>]], 12, 0) %0, i32 0)
+  %3 = load <6 x float>, ptr addrspace(11) %2, align 4
+  %4 = tail call noundef align 4 dereferenceable(24) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer(target("spirv.VulkanBuffer", [0 x [2 x <3 x float>]], 12, 1) %1, i32 0)
+  store <6 x float> %3, ptr addrspace(11) %4, align 4
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

>From dd88444a635a4c87cb5b8897120a6bf310355248 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 26 Mar 2026 16:51:36 -0400
Subject: [PATCH 2/2] Address PR comments

---
 .../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 21 +++++++++++--------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 3e3829a3c52bf..93886f552fe2f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -149,7 +149,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
   // which should be the load being legalized. Returns the loaded value.
   Value *loadFirstValueFromAggregate(IRBuilder<> &B, Type *ElementType,
                                      Value *Source, LoadInst *BadLoad) {
-    SmallVector<Type *, 2> Types = {BadLoad->getPointerOperandType(),
-                                    Source->getType()};
+    std::array<Type *, 2> Types = {BadLoad->getPointerOperandType(),
+                                   Source->getType()};
     SmallVector<Value *, 8> Args{/* isInBounds= */ B.getInt1(false), Source};
 
@@ -193,7 +193,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     unsigned ScalarsPerArrayElement = ArrElemVecTy->getNumElements();
     // Load each element of the array.
     SmallVector<Value *, 4> LoadedElements;
-    SmallVector<Type *, 2> Types = {Source->getType(), Source->getType()};
+    std::array<Type *, 2> Types = {Source->getType(), Source->getType()};
     for (unsigned I = 0, E = TargetType->getNumElements(); I < E; ++I) {
       unsigned ArrayIndex = I / ScalarsPerArrayElement;
       unsigned ElementIndexInArrayElem = I % ScalarsPerArrayElement;
@@ -215,7 +215,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
                              Value *Source) {
     // Load each element of the array.
     SmallVector<Value *, 4> LoadedElements;
-    SmallVector<Type *, 2> Types = {Source->getType(), Source->getType()};
+    std::array<Type *, 2> Types = {Source->getType(), Source->getType()};
     for (unsigned I = 0, E = TargetType->getNumElements(); I < E; ++I) {
       // Create a GEP to access the i-th element of the array.
       std::array<Value *, 4> Args = {B.getInt1(/*Inbounds=*/false), Source,
@@ -232,6 +232,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     return buildVectorFromLoadedElements(B, TargetType, LoadedElements);
   }
 
+  // Stores elements from a vector into a matrix (an array of vectors).
   void storeMatrixArrayFromVector(IRBuilder<> &B, Value *SrcVector,
                                   Value *DstArrayPtr, ArrayType *ArrTy,
                                   Align Alignment) {
@@ -240,8 +241,11 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     Type *ElemTy = ArrElemVecTy->getElementType();
     unsigned ScalarsPerArrayElement = ArrElemVecTy->getNumElements();
     unsigned SrcNumElements = SrcVecTy->getNumElements();
+    assert(
+        SrcNumElements % ScalarsPerArrayElement == 0 &&
+        "Source vector size must be a multiple of array element vector size");
 
-    SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
-                                    DstArrayPtr->getType()};
+    std::array<Type *, 2> Types = {DstArrayPtr->getType(),
+                                   DstArrayPtr->getType()};
 
     for (unsigned I = 0; I < SrcNumElements; I += ScalarsPerArrayElement) {
@@ -275,11 +279,10 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     // Ensure the element types of the array and vector are the same.
     assert(VecTy->getElementType() == ElemTy &&
            "Element types of array and vector must be the same.");
+    std::array<Type *, 2> Types = {DstArrayPtr->getType(),
+                                   DstArrayPtr->getType()};
 
-    SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
-                                    DstArrayPtr->getType()};
-
-    for (unsigned I = 0; I < VecTy->getNumElements(); ++I) {
+    for (unsigned I = 0, E = VecTy->getNumElements(); I < E; ++I) {
       // Create a GEP to access the i-th element of the array.
       std::array<Value *, 4> Args = {B.getInt1(/*Inbounds=*/false), DstArrayPtr,
                                      B.getInt32(0),
@@ -426,7 +429,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
   // Stores the given Src value into the first entry of the Dst aggregate.
   Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
                                     Type *DstPointeeType, Align Alignment) {
-    SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
+    std::array<Type *, 2> Types = {Dst->getType(), Dst->getType()};
     SmallVector<Value *, 8> Args{/* isInBounds= */ B.getInt1(true), Dst};
     buildGEPIndexChain(B, Src->getType(), DstPointeeType, Args);
     auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});



More information about the llvm-commits mailing list