[llvm] [SPIRV] Support non-constant indices for vector insert/extract (PR #172514)

Steven Perron via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 7 06:04:52 PST 2026


https://github.com/s-perron updated https://github.com/llvm/llvm-project/pull/172514

>From 242454d33aaaecbfce0f0bd36ae5a023bd0fd558 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Mon, 15 Dec 2025 10:34:34 -0500
Subject: [PATCH 1/4] [SPIRV] Support non-constant indices for vector
 insert/extract

This patch updates the legalization of spv_insertelt and spv_extractelt to
handle non-constant (dynamic) indices. When a dynamic index is encountered, the
vector is spilled to the stack, and the element is accessed via OpAccessChain
(lowered from spv_gep).

This patch also adds custom legalization for G_STORE to scalarize vector stores
and refines the legalization rules for G_LOAD, G_STORE, and G_BUILD_VECTOR.

Fixes https://github.com/llvm/llvm-project/issues/170534
---
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  | 215 +++++++++++++++++-
 llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp  |  43 +++-
 .../SPIRV/legalization/load-store-global.ll   |  96 +++-----
 .../spv-extractelt-legalization.ll            |  66 ++++++
 .../SPIRV/legalization/vector-arithmetic-6.ll |  99 ++++----
 .../SPIRV/llvm-intrinsics/matrix-multiply.ll  |  19 +-
 .../SPIRV/llvm-intrinsics/matrix-transpose.ll |  48 ++--
 7 files changed, 415 insertions(+), 171 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/legalization/spv-extractelt-legalization.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 590182731b002..169e34bfffce6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -243,10 +243,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   // Illegal G_UNMERGE_VALUES instructions should be handled
   // during the combine phase.
   getActionDefinitionsBuilder(G_BUILD_VECTOR)
-      .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
-      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
-                       LegalizeMutations::changeElementCountTo(
-                           0, ElementCount::getFixed(MaxVectorSize)));
+      .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize));
 
   // When entering the legalizer, there should be no G_BITCAST instructions.
   // They should all be calls to the `spv_bitcast` intrinsic. The call to
@@ -307,9 +304,14 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                                   all(typeIsNot(0, p9), typeIs(1, p9))))
       .legalForCartesianProduct(allPtrs, allPtrs);
 
+  // TODO: Decide whether bad scalar sizes like s5 should be legalized here
+  // instead of being handled in the instruction selector.
   getActionDefinitionsBuilder({G_LOAD, G_STORE})
       .unsupportedIf(typeIs(1, p9))
-      .legalIf(typeInSet(1, allPtrs));
+      .legalForCartesianProduct(allowedVectorTypes, allPtrs)
+      .legalForCartesianProduct(allPtrs, allPtrs)
+      .legalIf(isScalar(0))
+      .custom();
 
   getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS,
                                G_BITREVERSE, G_SADDSAT, G_UADDSAT, G_SSUBSAT,
@@ -532,6 +534,59 @@ static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
   return ConvReg;
 }
 
+static bool legalizeLoad(LegalizerHelper &Helper, MachineInstr &MI,
+                         SPIRVGlobalRegistry *GR) {
+  return true;
+}
+
+static bool legalizeStore(LegalizerHelper &Helper, MachineInstr &MI,
+                          SPIRVGlobalRegistry *GR) {
+  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
+  Register ValReg = MI.getOperand(0).getReg();
+  Register PtrReg = MI.getOperand(1).getReg();
+  LLT ValTy = MRI.getType(ValReg);
+
+  assert(ValTy.isVector() && "Expected vector store");
+
+  SmallVector<Register, 8> SplitRegs;
+  LLT EltTy = ValTy.getElementType();
+  unsigned NumElts = ValTy.getNumElements();
+
+  for (unsigned i = 0; i < NumElts; ++i)
+    SplitRegs.push_back(MRI.createGenericVirtualRegister(EltTy));
+
+  MIRBuilder.buildUnmerge(SplitRegs, ValReg);
+
+  LLT PtrTy = MRI.getType(PtrReg);
+  auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);
+
+  for (unsigned i = 0; i < NumElts; ++i) {
+    auto Idx = MIRBuilder.buildConstant(LLT::scalar(32), i);
+    Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
+
+    MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
+        .addImm(1) // InBounds
+        .addUse(PtrReg)
+        .addUse(Zero.getReg(0))
+        .addUse(Idx.getReg(0));
+
+    MachinePointerInfo EltPtrInfo;
+    Align EltAlign = Align(1);
+    if (!MI.memoperands_empty()) {
+      MachineMemOperand *MMO = *MI.memoperands_begin();
+      EltPtrInfo =
+          MMO->getPointerInfo().getWithOffset(i * EltTy.getSizeInBytes());
+      EltAlign = commonAlignment(MMO->getAlign(), i * EltTy.getSizeInBytes());
+    }
+
+    MIRBuilder.buildStore(SplitRegs[i], EltPtr, EltPtrInfo, EltAlign);
+  }
+
+  MI.eraseFromParent();
+  return true;
+}
+
 bool SPIRVLegalizerInfo::legalizeCustom(
     LegalizerHelper &Helper, MachineInstr &MI,
     LostDebugLocObserver &LocObserver) const {
@@ -572,6 +627,10 @@ bool SPIRVLegalizerInfo::legalizeCustom(
     }
     return true;
   }
+  case TargetOpcode::G_LOAD:
+    return legalizeLoad(Helper, MI, GR);
+  case TargetOpcode::G_STORE:
+    return legalizeStore(Helper, MI, GR);
   }
 }
 
@@ -614,11 +673,61 @@ bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
     LLT DstTy = MRI.getType(DstReg);
 
     if (needsVectorLegalization(DstTy, ST)) {
+      Register DstReg = MI.getOperand(0).getReg();
       Register SrcReg = MI.getOperand(2).getReg();
       Register ValReg = MI.getOperand(3).getReg();
-      Register IdxReg = MI.getOperand(4).getReg();
-      MIRBuilder.buildInsertVectorElement(DstReg, SrcReg, ValReg, IdxReg);
+      LLT SrcTy = MRI.getType(SrcReg);
+      MachineOperand &IdxOperand = MI.getOperand(4);
+
+      if (getImm(IdxOperand, &MRI)) {
+        uint64_t IdxVal = foldImm(IdxOperand, &MRI);
+        if (IdxVal < SrcTy.getNumElements()) {
+          SmallVector<Register, 8> Regs;
+          SPIRVType *ElementType = GR->getScalarOrVectorComponentType(
+              GR->getSPIRVTypeForVReg(DstReg));
+          LLT ElementLLTTy = GR->getRegType(ElementType);
+          for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
+            Register Reg = MRI.createGenericVirtualRegister(ElementLLTTy);
+            MRI.setRegClass(Reg, GR->getRegClass(ElementType));
+            GR->assignSPIRVTypeToVReg(ElementType, Reg, *MI.getMF());
+            Regs.push_back(Reg);
+          }
+          MIRBuilder.buildUnmerge(Regs, SrcReg);
+          Regs[IdxVal] = ValReg;
+          MIRBuilder.buildBuildVector(DstReg, Regs);
+          MI.eraseFromParent();
+          return true;
+        }
+      }
+
+      LLT EltTy = SrcTy.getElementType();
+      Align VecAlign = Helper.getStackTemporaryAlignment(SrcTy);
+
+      MachinePointerInfo PtrInfo;
+      auto StackTemp = Helper.createStackTemporary(
+          TypeSize::getFixed(SrcTy.getSizeInBytes()), VecAlign, PtrInfo);
+
+      MIRBuilder.buildStore(SrcReg, StackTemp, PtrInfo, VecAlign);
+
+      Register IdxReg = IdxOperand.getReg();
+      LLT PtrTy = MRI.getType(StackTemp.getReg(0));
+      Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
+      auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);
+
+      MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
+          .addImm(1) // InBounds
+          .addUse(StackTemp.getReg(0))
+          .addUse(Zero.getReg(0))
+          .addUse(IdxReg);
+
+      MachinePointerInfo EltPtrInfo =
+          MachinePointerInfo(PtrTy.getAddressSpace());
+      Align EltAlign = Helper.getStackTemporaryAlignment(EltTy);
+      MIRBuilder.buildStore(ValReg, EltPtr, EltPtrInfo, EltAlign);
+
+      MIRBuilder.buildLoad(DstReg, StackTemp, PtrInfo, VecAlign);
       MI.eraseFromParent();
+      return true;
     }
     return true;
   } else if (IntrinsicID == Intrinsic::spv_extractelt) {
@@ -627,9 +736,97 @@ bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
 
     if (needsVectorLegalization(SrcTy, ST)) {
       Register DstReg = MI.getOperand(0).getReg();
-      Register IdxReg = MI.getOperand(3).getReg();
-      MIRBuilder.buildExtractVectorElement(DstReg, SrcReg, IdxReg);
+      MachineOperand &IdxOperand = MI.getOperand(3);
+
+      if (getImm(IdxOperand, &MRI)) {
+        uint64_t IdxVal = foldImm(IdxOperand, &MRI);
+        if (IdxVal < SrcTy.getNumElements()) {
+          LLT DstTy = MRI.getType(DstReg);
+          SmallVector<Register, 8> Regs;
+          SPIRVType *DstSpvTy = GR->getSPIRVTypeForVReg(DstReg);
+          for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
+            if (I == IdxVal) {
+              Regs.push_back(DstReg);
+            } else {
+              Register Reg = MRI.createGenericVirtualRegister(DstTy);
+              MRI.setRegClass(Reg, GR->getRegClass(DstSpvTy));
+              GR->assignSPIRVTypeToVReg(DstSpvTy, Reg, *MI.getMF());
+              Regs.push_back(Reg);
+            }
+          }
+          MIRBuilder.buildUnmerge(Regs, SrcReg);
+          MI.eraseFromParent();
+          return true;
+        }
+      }
+
+      LLT EltTy = SrcTy.getElementType();
+      Align VecAlign = Helper.getStackTemporaryAlignment(SrcTy);
+
+      MachinePointerInfo PtrInfo;
+      auto StackTemp = Helper.createStackTemporary(
+          TypeSize::getFixed(SrcTy.getSizeInBytes()), VecAlign, PtrInfo);
+
+      // Set the type of StackTemp to a pointer to an array of the element type.
+      SPIRVType *SpvSrcTy = GR->getSPIRVTypeForVReg(SrcReg);
+      SPIRVType *EltSpvTy = GR->getScalarOrVectorComponentType(SpvSrcTy);
+      const Type *LLVMEltTy = GR->getTypeForSPIRVType(EltSpvTy);
+      const Type *LLVMArrTy =
+          ArrayType::get(const_cast<Type *>(LLVMEltTy), SrcTy.getNumElements());
+      SPIRVType *ArrSpvTy = GR->getOrCreateSPIRVType(
+          LLVMArrTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
+      SPIRVType *PtrToArrSpvTy = GR->getOrCreateSPIRVPointerType(
+          ArrSpvTy, MIRBuilder, SPIRV::StorageClass::Function);
+      setRegClassType(StackTemp.getReg(0), PtrToArrSpvTy, GR, &MRI,
+                      MIRBuilder.getMF());
+
+      // Store the vector elements one by one.
+      SmallVector<Register, 8> Regs;
+      for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
+        Register Reg = MRI.createGenericVirtualRegister(EltTy);
+        MRI.setRegClass(Reg, GR->getRegClass(EltSpvTy));
+        GR->assignSPIRVTypeToVReg(EltSpvTy, Reg, *MI.getMF());
+        Regs.push_back(Reg);
+      }
+      MIRBuilder.buildUnmerge(Regs, SrcReg);
+
+      auto ZeroNew = MIRBuilder.buildConstant(LLT::scalar(32), 0);
+      LLT PtrTyNew = MRI.getType(StackTemp.getReg(0));
+
+      for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
+        auto Idx = MIRBuilder.buildConstant(LLT::scalar(32), I);
+        Register EltPtr = MRI.createGenericVirtualRegister(PtrTyNew);
+        MIRBuilder
+            .buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
+            .addImm(1) // InBounds
+            .addUse(StackTemp.getReg(0))
+            .addUse(ZeroNew.getReg(0))
+            .addUse(Idx.getReg(0));
+
+        MachinePointerInfo EltPtrInfo =
+            PtrInfo.getWithOffset(I * EltTy.getSizeInBytes());
+        Align EltAlign = commonAlignment(VecAlign, I * EltTy.getSizeInBytes());
+        MIRBuilder.buildStore(Regs[I], EltPtr, EltPtrInfo, EltAlign);
+      }
+
+      Register IdxReg = IdxOperand.getReg();
+      LLT PtrTy = MRI.getType(StackTemp.getReg(0));
+      Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
+      auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);
+
+      MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
+          .addImm(1) // InBounds
+          .addUse(StackTemp.getReg(0))
+          .addUse(Zero.getReg(0))
+          .addUse(IdxReg);
+
+      MachinePointerInfo EltPtrInfo =
+          MachinePointerInfo(PtrTy.getAddressSpace());
+      Align EltAlign = Helper.getStackTemporaryAlignment(EltTy);
+      MIRBuilder.buildLoad(DstReg, EltPtr, EltPtrInfo, EltAlign);
+
       MI.eraseFromParent();
+      return true;
     }
     return true;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index 5f52f60da37e1..5b4ddc267c9b8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -17,6 +17,7 @@
 #include "SPIRVSubtarget.h"
 #include "SPIRVUtils.h"
 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 #include "llvm/Support/Debug.h"
 #include <stack>
@@ -107,6 +108,37 @@ static SPIRVType *deduceTypeFromResultRegister(MachineInstr *Use,
   return nullptr;
 }
 
+static SPIRVType *deducePointerTypeFromResultRegister(MachineInstr *Use,
+                                                      Register UseRegister,
+                                                      SPIRVGlobalRegistry *GR,
+                                                      MachineIRBuilder &MIB) {
+  assert(Use->getOpcode() == TargetOpcode::G_LOAD ||
+         Use->getOpcode() == TargetOpcode::G_STORE);
+
+  Register ValueReg = Use->getOperand(0).getReg();
+  SPIRVType *ValueType = GR->getSPIRVTypeForVReg(ValueReg);
+  if (!ValueType)
+    return nullptr;
+
+  return GR->getOrCreateSPIRVPointerType(ValueType, MIB,
+                                         SPIRV::StorageClass::Function);
+}
+
+static SPIRVType *deduceTypeFromPointerOperand(MachineInstr *Use,
+                                               Register UseRegister,
+                                               SPIRVGlobalRegistry *GR,
+                                               MachineIRBuilder &MIB) {
+  assert(Use->getOpcode() == TargetOpcode::G_LOAD ||
+         Use->getOpcode() == TargetOpcode::G_STORE);
+
+  Register PtrReg = Use->getOperand(1).getReg();
+  SPIRVType *PtrType = GR->getSPIRVTypeForVReg(PtrReg);
+  if (!PtrType)
+    return nullptr;
+
+  return GR->getPointeeType(PtrType);
+}
+
 static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
                                      SPIRVGlobalRegistry *GR,
                                      MachineIRBuilder &MIB) {
@@ -135,6 +167,13 @@ static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
     case TargetOpcode::G_STRICT_FMA:
       ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
       break;
+    case TargetOpcode::G_LOAD:
+    case TargetOpcode::G_STORE:
+      if (Reg == Use.getOperand(1).getReg())
+        ResType = deducePointerTypeFromResultRegister(&Use, Reg, GR, MIB);
+      else
+        ResType = deduceTypeFromPointerOperand(&Use, Reg, GR, MIB);
+      break;
     case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
     case TargetOpcode::G_INTRINSIC: {
       auto IntrinsicID = cast<GIntrinsic>(Use).getIntrinsicID();
@@ -323,6 +362,7 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
 
   for (auto *I : Worklist) {
     MachineIRBuilder MIB(*I);
+    LLVM_DEBUG(dbgs() << "Assigning default type to results in " << *I);
     for (unsigned Idx = 0; Idx < I->getNumDefs(); ++Idx) {
       Register ResVReg = I->getOperand(Idx).getReg();
       if (GR->getSPIRVTypeForVReg(ResVReg))
@@ -337,9 +377,6 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
       } else {
         ResType = GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
       }
-      LLVM_DEBUG(dbgs() << "Could not determine type for " << ResVReg
-                        << ", defaulting to " << *ResType << "\n");
-
       setRegClassType(ResVReg, ResType, GR, &MRI, MF, true);
     }
   }
diff --git a/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
index 19b39ff59809a..39e72fb9b1f51 100644
--- a/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
+++ b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
@@ -67,72 +67,40 @@ entry:
 ; CHECK-DAG: %[[#VAL14:]] = OpLoad %[[#int]] %[[#PTR14]] Aligned 4
 ; CHECK-DAG: %[[#PTR15:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C15]]
 ; CHECK-DAG: %[[#VAL15:]] = OpLoad %[[#int]] %[[#PTR15]] Aligned 4
-; CHECK-DAG: %[[#INS0:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL0]] %[[#UNDEF:]] 0
-; CHECK-DAG: %[[#INS1:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL1]] %[[#INS0]] 1
-; CHECK-DAG: %[[#INS2:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL2]] %[[#INS1]] 2
-; CHECK-DAG: %[[#INS3:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL3]] %[[#INS2]] 3
-; CHECK-DAG: %[[#INS4:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL4]] %[[#UNDEF]] 0
-; CHECK-DAG: %[[#INS5:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL5]] %[[#INS4]] 1
-; CHECK-DAG: %[[#INS6:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL6]] %[[#INS5]] 2
-; CHECK-DAG: %[[#INS7:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL7]] %[[#INS6]] 3
-; CHECK-DAG: %[[#INS8:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL8]] %[[#UNDEF]] 0
-; CHECK-DAG: %[[#INS9:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL9]] %[[#INS8]] 1
-; CHECK-DAG: %[[#INS10:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL10]] %[[#INS9]] 2
-; CHECK-DAG: %[[#INS11:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL11]] %[[#INS10]] 3
-; CHECK-DAG: %[[#INS12:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL12]] %[[#UNDEF]] 0
-; CHECK-DAG: %[[#INS13:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL13]] %[[#INS12]] 1
-; CHECK-DAG: %[[#INS14:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL14]] %[[#INS13]] 2
-; CHECK-DAG: %[[#INS15:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL15]] %[[#INS14]] 3
   %0 = load <16 x i32>, ptr addrspace(10) @G_16, align 64
  
-; CHECK-DAG: %[[#PTR0_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C0]]
-; CHECK-DAG: %[[#VAL0_S:]] = OpCompositeExtract %[[#int]] %[[#INS3]] 0
-; CHECK-DAG: OpStore %[[#PTR0_S]] %[[#VAL0_S]] Aligned 64
-; CHECK-DAG: %[[#PTR1_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C1]]
-; CHECK-DAG: %[[#VAL1_S:]] = OpCompositeExtract %[[#int]] %[[#INS3]] 1
-; CHECK-DAG: OpStore %[[#PTR1_S]] %[[#VAL1_S]] Aligned 4
-; CHECK-DAG: %[[#PTR2_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C2]]
-; CHECK-DAG: %[[#VAL2_S:]] = OpCompositeExtract %[[#int]] %[[#INS3]] 2
-; CHECK-DAG: OpStore %[[#PTR2_S]] %[[#VAL2_S]] Aligned 8
-; CHECK-DAG: %[[#PTR3_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C3]]
-; CHECK-DAG: %[[#VAL3_S:]] = OpCompositeExtract %[[#int]] %[[#INS3]] 3
-; CHECK-DAG: OpStore %[[#PTR3_S]] %[[#VAL3_S]] Aligned 4
-; CHECK-DAG: %[[#PTR4_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C4]]
-; CHECK-DAG: %[[#VAL4_S:]] = OpCompositeExtract %[[#int]] %[[#INS7]] 0
-; CHECK-DAG: OpStore %[[#PTR4_S]] %[[#VAL4_S]] Aligned 16
-; CHECK-DAG: %[[#PTR5_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C5]]
-; CHECK-DAG: %[[#VAL5_S:]] = OpCompositeExtract %[[#int]] %[[#INS7]] 1
-; CHECK-DAG: OpStore %[[#PTR5_S]] %[[#VAL5_S]] Aligned 4
-; CHECK-DAG: %[[#PTR6_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C6]]
-; CHECK-DAG: %[[#VAL6_S:]] = OpCompositeExtract %[[#int]] %[[#INS7]] 2
-; CHECK-DAG: OpStore %[[#PTR6_S]] %[[#VAL6_S]] Aligned 8
-; CHECK-DAG: %[[#PTR7_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C7]]
-; CHECK-DAG: %[[#VAL7_S:]] = OpCompositeExtract %[[#int]] %[[#INS7]] 3
-; CHECK-DAG: OpStore %[[#PTR7_S]] %[[#VAL7_S]] Aligned 4
-; CHECK-DAG: %[[#PTR8_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C8]]
-; CHECK-DAG: %[[#VAL8_S:]] = OpCompositeExtract %[[#int]] %[[#INS11]] 0
-; CHECK-DAG: OpStore %[[#PTR8_S]] %[[#VAL8_S]] Aligned 32
-; CHECK-DAG: %[[#PTR9_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C9]]
-; CHECK-DAG: %[[#VAL9_S:]] = OpCompositeExtract %[[#int]] %[[#INS11]] 1
-; CHECK-DAG: OpStore %[[#PTR9_S]] %[[#VAL9_S]] Aligned 4
-; CHECK-DAG: %[[#PTR10_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C10]]
-; CHECK-DAG: %[[#VAL10_S:]] = OpCompositeExtract %[[#int]] %[[#INS11]] 2
-; CHECK-DAG: OpStore %[[#PTR10_S]] %[[#VAL10_S]] Aligned 8
-; CHECK-DAG: %[[#PTR11_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C11]]
-; CHECK-DAG: %[[#VAL11_S:]] = OpCompositeExtract %[[#int]] %[[#INS11]] 3
-; CHECK-DAG: OpStore %[[#PTR11_S]] %[[#VAL11_S]] Aligned 4
-; CHECK-DAG: %[[#PTR12_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C12]]
-; CHECK-DAG: %[[#VAL12_S:]] = OpCompositeExtract %[[#int]] %[[#INS15]] 0
-; CHECK-DAG: OpStore %[[#PTR12_S]] %[[#VAL12_S]] Aligned 16
-; CHECK-DAG: %[[#PTR13_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C13]]
-; CHECK-DAG: %[[#VAL13_S:]] = OpCompositeExtract %[[#int]] %[[#INS15]] 1
-; CHECK-DAG: OpStore %[[#PTR13_S]] %[[#VAL13_S]] Aligned 4
-; CHECK-DAG: %[[#PTR14_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C14]]
-; CHECK-DAG: %[[#VAL14_S:]] = OpCompositeExtract %[[#int]] %[[#INS15]] 2
-; CHECK-DAG: OpStore %[[#PTR14_S]] %[[#VAL14_S]] Aligned 8
-; CHECK-DAG: %[[#PTR15_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C15]]
-; CHECK-DAG: %[[#VAL15_S:]] = OpCompositeExtract %[[#int]] %[[#INS15]] 3
-; CHECK-DAG: OpStore %[[#PTR15_S]] %[[#VAL15_S]] Aligned 4
+; CHECK: %[[#PTR0_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C0]]
+; CHECK: OpStore %[[#PTR0_S]] %[[#VAL0]] Aligned 64
+; CHECK: %[[#PTR1_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C1]]
+; CHECK: OpStore %[[#PTR1_S]] %[[#VAL1]] Aligned 4
+; CHECK: %[[#PTR2_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C2]]
+; CHECK: OpStore %[[#PTR2_S]] %[[#VAL2]] Aligned 8
+; CHECK: %[[#PTR3_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C3]]
+; CHECK: OpStore %[[#PTR3_S]] %[[#VAL3]] Aligned 4
+; CHECK: %[[#PTR4_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C4]]
+; CHECK: OpStore %[[#PTR4_S]] %[[#VAL4]] Aligned 16
+; CHECK: %[[#PTR5_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C5]]
+; CHECK: OpStore %[[#PTR5_S]] %[[#VAL5]] Aligned 4
+; CHECK: %[[#PTR6_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C6]]
+; CHECK: OpStore %[[#PTR6_S]] %[[#VAL6]] Aligned 8
+; CHECK: %[[#PTR7_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C7]]
+; CHECK: OpStore %[[#PTR7_S]] %[[#VAL7]] Aligned 4
+; CHECK: %[[#PTR8_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C8]]
+; CHECK: OpStore %[[#PTR8_S]] %[[#VAL8]] Aligned 32
+; CHECK: %[[#PTR9_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C9]]
+; CHECK: OpStore %[[#PTR9_S]] %[[#VAL9]] Aligned 4
+; CHECK: %[[#PTR10_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C10]]
+; CHECK: OpStore %[[#PTR10_S]] %[[#VAL10]] Aligned 8
+; CHECK: %[[#PTR11_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C11]]
+; CHECK: OpStore %[[#PTR11_S]] %[[#VAL11]] Aligned 4
+; CHECK: %[[#PTR12_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C12]]
+; CHECK: OpStore %[[#PTR12_S]] %[[#VAL12]] Aligned 16
+; CHECK: %[[#PTR13_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C13]]
+; CHECK: OpStore %[[#PTR13_S]] %[[#VAL13]] Aligned 4
+; CHECK: %[[#PTR14_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C14]]
+; CHECK: OpStore %[[#PTR14_S]] %[[#VAL14]] Aligned 8
+; CHECK: %[[#PTR15_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C15]]
+; CHECK: OpStore %[[#PTR15_S]] %[[#VAL15]] Aligned 4
   store <16 x i32> %0, ptr addrspace(10) @G_16, align 64
   ret void
 }
diff --git a/llvm/test/CodeGen/SPIRV/legalization/spv-extractelt-legalization.ll b/llvm/test/CodeGen/SPIRV/legalization/spv-extractelt-legalization.ll
new file mode 100644
index 0000000000000..3188f8b31aac5
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/legalization/spv-extractelt-legalization.ll
@@ -0,0 +1,66 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s
+
+; CHECK-DAG: %[[#Int:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#Const0:]] = OpConstant %[[#Int]] 0
+; CHECK-DAG: %[[#Const1:]] = OpConstant %[[#Int]] 1
+; CHECK-DAG: %[[#Const2:]] = OpConstant %[[#Int]] 2
+; CHECK-DAG: %[[#Const3:]] = OpConstant %[[#Int]] 3
+; CHECK-DAG: %[[#Const4:]] = OpConstant %[[#Int]] 4
+; CHECK-DAG: %[[#Const5:]] = OpConstant %[[#Int]] 5
+; CHECK-DAG: %[[#Const10:]] = OpConstant %[[#Int]] 10
+; CHECK-DAG: %[[#Const20:]] = OpConstant %[[#Int]] 20
+; CHECK-DAG: %[[#Const30:]] = OpConstant %[[#Int]] 30
+; CHECK-DAG: %[[#Const40:]] = OpConstant %[[#Int]] 40
+; CHECK-DAG: %[[#Const50:]] = OpConstant %[[#Int]] 50
+; CHECK-DAG: %[[#Const60:]] = OpConstant %[[#Int]] 60
+; CHECK-DAG: %[[#Arr:]] = OpTypeArray %[[#Int]] %[[#]]
+; CHECK-DAG: %[[#PtrArr:]] = OpTypePointer Function %[[#Arr]]
+
+ at G = addrspace(1) global i32 0, align 4
+
+define void @main() #0 {
+entry:
+; CHECK: %[[#Var:]] = OpVariable %[[#PtrArr]] Function
+
+; CHECK: %[[#Idx:]] = OpLoad %[[#Int]]
+  %idx = load i32, ptr addrspace(1) @G, align 4
+
+
+; CHECK: %[[#PtrElt0:]] = OpInBoundsAccessChain %[[#]] %[[#Var]] %[[#Const0]]
+; CHECK: OpStore %[[#PtrElt0]] %[[#Const10]]
+  %vec = insertelement <6 x i32> poison, i32 10, i64 0
+
+; CHECK: %[[#PtrElt1:]] = OpInBoundsAccessChain %[[#]] %[[#Var]] %[[#Const1]]
+; CHECK: OpStore %[[#PtrElt1]] %[[#Const20]]
+  %vec2 = insertelement <6 x i32> %vec, i32 20, i64 1
+
+; CHECK: %[[#PtrElt2:]] = OpInBoundsAccessChain %[[#]] %[[#Var]] %[[#Const2]]
+; CHECK: OpStore %[[#PtrElt2]] %[[#Const30]]
+  %vec3 = insertelement <6 x i32> %vec2, i32 30, i64 2
+
+; CHECK: %[[#PtrElt3:]] = OpInBoundsAccessChain %[[#]] %[[#Var]] %[[#Const3]]
+; CHECK: OpStore %[[#PtrElt3]] %[[#Const40]]
+  %vec4 = insertelement <6 x i32> %vec3, i32 40, i64 3
+
+; CHECK: %[[#PtrElt4:]] = OpInBoundsAccessChain %[[#]] %[[#Var]] %[[#Const4]]
+; CHECK: OpStore %[[#PtrElt4]] %[[#Const50]]
+  %vec5 = insertelement <6 x i32> %vec4, i32 50, i64 4
+
+; CHECK: %[[#PtrElt5:]] = OpInBoundsAccessChain %[[#]] %[[#Var]] %[[#Const5]]
+; CHECK: OpStore %[[#PtrElt5]] %[[#Const60]]
+  %vec6 = insertelement <6 x i32> %vec5, i32 60, i64 5
+
+; CHECK: %[[#Ptr:]] = OpInBoundsAccessChain %[[#]] %[[#Var]] %[[#Idx]]
+; CHECK: %[[#Ld:]] = OpLoad %[[#Int]] %[[#Ptr]]
+  %res = extractelement <6 x i32> %vec6, i32 %idx
+  
+; CHECK: OpStore {{.*}} %[[#Ld]]
+  store i32 %res, ptr addrspace(1) @G, align 4
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+
+
+
diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll
index d1cbfd4811c30..028495b10bee2 100644
--- a/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll
+++ b/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll
@@ -42,18 +42,19 @@ entry:
   ; CHECK: %[[#Sub2:]] = OpFSub %[[#v4f32]] %[[#Add2]]
   %13 = fsub reassoc nnan ninf nsz arcp afn <6 x float> %11, %9
 
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 0
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 1
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 2
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 3
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub2]] 0
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub2]] 1
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT0:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 0
+  ; CHECK: %[[#EXTRACT1:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 1
+  ; CHECK: %[[#EXTRACT2:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 2
+  ; CHECK: %[[#EXTRACT3:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 3
+  ; CHECK: %[[#EXTRACT4:]] = OpCompositeExtract %[[#float]] %[[#Sub2]] 0
+  ; CHECK: %[[#EXTRACT5:]] = OpCompositeExtract %[[#float]] %[[#Sub2]] 1
+
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT0]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT1]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT2]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT3]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT4]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT5]]
   
   %14 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f2, i32 0, i32 0
   store <6 x float> %13, ptr addrspace(10) %14, align 4
@@ -119,24 +120,12 @@ entry:
   ; CHECK: %[[#UMod6:]] = OpUMod %[[#int]] %[[#SRem6]]
   %12 = urem <6 x i32> %11, splat (i32 3)
 
-  ; CHECK: %[[#Construct1:]] = OpCompositeConstruct %[[#v4i32]] %[[#UMod1]] %[[#UMod2]] %[[#UMod3]] %[[#UMod4]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#Construct1]] 0
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#Construct2:]] = OpCompositeConstruct %[[#v4i32]] %[[#UMod1]] %[[#UMod2]] %[[#UMod3]] %[[#UMod4]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#Construct2]] 1
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#Construct3:]] = OpCompositeConstruct %[[#v4i32]] %[[#UMod1]] %[[#UMod2]] %[[#UMod3]] %[[#UMod4]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#Construct3]] 2
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#Construct4:]] = OpCompositeConstruct %[[#v4i32]] %[[#UMod1]] %[[#UMod2]] %[[#UMod3]] %[[#UMod4]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#Construct4]] 3
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#Construct5:]] = OpCompositeConstruct %[[#v4i32]] %[[#UMod5]] %[[#UMod6]] %[[#UNDEF]] %[[#UNDEF]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#Construct5]] 0
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#Construct6:]] = OpCompositeConstruct %[[#v4i32]] %[[#UMod5]] %[[#UMod6]] %[[#UNDEF]] %[[#UNDEF]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#Construct6]] 1
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: OpStore {{.*}} %[[#UMod1]]
+  ; CHECK: OpStore {{.*}} %[[#UMod2]]
+  ; CHECK: OpStore {{.*}} %[[#UMod3]]
+  ; CHECK: OpStore {{.*}} %[[#UMod4]]
+  ; CHECK: OpStore {{.*}} %[[#UMod5]]
+  ; CHECK: OpStore {{.*}} %[[#UMod6]]
 
   %13 = getelementptr [4 x [6 x i32] ], ptr addrspace(10) @i2, i32 0, i32 0
   store <6 x i32> %12, ptr addrspace(10) %13, align 4
@@ -168,18 +157,19 @@ entry:
   ; CHECK: %[[#Fma2:]] = OpExtInst %[[#v4f32]] {{.*}} Fma
   %8 = call reassoc nnan ninf nsz arcp afn <6 x float> @llvm.fma.v6f32(<6 x float> %5, <6 x float> %6, <6 x float> %7)
 
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 0
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 1
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 2
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 3
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 0
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 1
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT0:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 0
+  ; CHECK: %[[#EXTRACT1:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 1
+  ; CHECK: %[[#EXTRACT2:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 2
+  ; CHECK: %[[#EXTRACT3:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 3
+  ; CHECK: %[[#EXTRACT4:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 0
+  ; CHECK: %[[#EXTRACT5:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 1
+
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT0]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT1]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT2]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT3]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT4]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT5]]
 
   %9 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f2, i32 0, i32 0
   store <6 x float> %8, ptr addrspace(10) %9, align 4
@@ -201,18 +191,19 @@ entry:
   ; CHECK: %[[#Fma2:]] = OpExtInst %[[#v4f32]] {{.*}} Fma
   %8 = call <6 x float> @llvm.experimental.constrained.fma.v6f32(<6 x float> %3, <6 x float> %5, <6 x float> %7, metadata !"round.dynamic", metadata !"fpexcept.strict")
 
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 0
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 1
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 2
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 3
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 0
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 1
-  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT0:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 0
+  ; CHECK: %[[#EXTRACT1:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 1
+  ; CHECK: %[[#EXTRACT2:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 2
+  ; CHECK: %[[#EXTRACT3:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 3
+  ; CHECK: %[[#EXTRACT4:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 0
+  ; CHECK: %[[#EXTRACT5:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 1
+
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT0]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT1]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT2]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT3]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT4]]
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT5]]
 
   %9 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f2, i32 0, i32 0
   store <6 x float> %8, ptr addrspace(10) %9, align 4
diff --git a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-multiply.ll b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-multiply.ll
index 4f8dfd0494009..cb41c8a1ce2f7 100644
--- a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-multiply.ll
+++ b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-multiply.ll
@@ -88,18 +88,15 @@ define internal void @test_matrix_multiply_i32_2x2_2x2() {
 
 ; Test Matrix Multiply 2x3 * 3x2 float (Result 2x2 float)
 ; CHECK-LABEL: ; -- Begin function test_matrix_multiply_f32_2x3_3x2
-; CHECK-DAG:   %[[B:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]]
-; CHECK-DAG:   %[[A:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]]
+; CHECK:       %[[Col0B:[0-9]+]] = OpCompositeConstruct %[[V3F32_ID]] {{.*}} {{.*}} {{.*}}
+; CHECK:       %[[Col1B:[0-9]+]] = OpCompositeConstruct %[[V3F32_ID]] {{.*}} {{.*}} {{.*}}
+; CHECK:       %[[Row0A:[0-9]+]] = OpCompositeConstruct %[[V3F32_ID]] {{.*}} {{.*}} {{.*}}
+; CHECK:       %[[Row1A:[0-9]+]] = OpCompositeConstruct %[[V3F32_ID]] {{.*}} {{.*}} {{.*}}
 ;
-; CHECK-DAG:   %[[B_Col0:[0-9]+]] = OpCompositeConstruct %[[V3F32_ID]]
-; CHECK-DAG:   %[[B_Col1:[0-9]+]] = OpCompositeConstruct %[[V3F32_ID]]
-; CHECK-DAG:   %[[A_Row0:[0-9]+]] = OpCompositeConstruct %[[V3F32_ID]]
-; CHECK-DAG:   %[[A_Row1:[0-9]+]] = OpCompositeConstruct %[[V3F32_ID]]
-;
-; CHECK-DAG:   %[[C00:[0-9]+]] = OpDot %[[Float_ID]] %[[A_Row0]] %[[B_Col0]]
-; CHECK-DAG:   %[[C10:[0-9]+]] = OpDot %[[Float_ID]] %[[A_Row1]] %[[B_Col0]]
-; CHECK-DAG:   %[[C01:[0-9]+]] = OpDot %[[Float_ID]] %[[A_Row0]] %[[B_Col1]]
-; CHECK-DAG:   %[[C11:[0-9]+]] = OpDot %[[Float_ID]] %[[A_Row1]] %[[B_Col1]]
+; CHECK-DAG:   %[[C00:[0-9]+]] = OpDot %[[Float_ID]] %[[Row0A]] %[[Col0B]]
+; CHECK-DAG:   %[[C10:[0-9]+]] = OpDot %[[Float_ID]] %[[Row1A]] %[[Col0B]]
+; CHECK-DAG:   %[[C01:[0-9]+]] = OpDot %[[Float_ID]] %[[Row0A]] %[[Col1B]]
+; CHECK-DAG:   %[[C11:[0-9]+]] = OpDot %[[Float_ID]] %[[Row1A]] %[[Col1B]]
 ; CHECK:       OpCompositeConstruct %[[V4F32_ID]] %[[C00]] %[[C10]] %[[C01]] %[[C11]]
 define internal void @test_matrix_multiply_f32_2x3_3x2() {
   %1 = load <6 x float>, ptr addrspace(10) @private_v6f32
diff --git a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-transpose.ll b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-transpose.ll
index 3474fecae9957..3106d5d55ef77 100644
--- a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-transpose.ll
+++ b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-transpose.ll
@@ -38,48 +38,36 @@ define internal void @test_transpose_f32_2x3() {
 ; CHECK: %[[Load6:[0-9]+]] = OpLoad %[[Float_ID]] %[[AccessChain6]]
 ;
 ; -- Construct intermediate vectors
-; CHECK: %[[CompositeInsert1:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] %[[Load1]] %[[undef_V4F32_ID:[0-9]+]] 0
-; CHECK: %[[CompositeInsert2:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] %[[Load2]] %[[CompositeInsert1]] 1
-; CHECK: %[[CompositeInsert3:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] %[[Load3]] %[[CompositeInsert2]] 2
-; CHECK: %[[CompositeInsert4:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] %[[Load4]] %[[CompositeInsert3]] 3
-; CHECK: %[[CompositeInsert5:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] %[[Load5]] %[[undef_V4F32_ID]] 0
-; CHECK: %[[CompositeInsert6:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] %[[Load6]] %[[CompositeInsert5]] 1
+; CHECK: %[[Construct1:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Load1]] %[[Load2]] %[[Load3]] %[[Load4]]
+; CHECK: %[[Construct2:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Load1]] %[[Load2]] %[[Load3]] %[[Load4]]
+; CHECK: %[[Construct3:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Load5]] %[[Load6]] {{.*}} {{.*}}
+; CHECK: %[[Construct4:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Load1]] %[[Load2]] %[[Load3]] %[[Load4]]
+; CHECK: %[[Construct5:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Load1]] %[[Load2]] %[[Load3]] %[[Load4]]
+; CHECK: %[[Construct6:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Load5]] %[[Load6]] {{.*}} {{.*}}
   %1 = load <6 x float>, ptr addrspace(10) @private_v6f32
 
 ; -- Extract elements for transposition
-; CHECK: %[[Extract1:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeInsert4]] 0
-; CHECK: %[[Extract2:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeInsert4]] 2
-; CHECK: %[[Extract3:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeInsert6]] 0
-; CHECK: %[[Extract4:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeInsert4]] 1
-; CHECK: %[[Extract5:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeInsert4]] 3
-; CHECK: %[[Extract6:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeInsert6]] 1
+; CHECK: %[[Extract1:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[Construct1]] 0
+; CHECK: %[[Extract2:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[Construct2]] 2
+; CHECK: %[[Extract3:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[Construct3]] 0
+; CHECK: %[[Extract4:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[Construct4]] 1
+; CHECK: %[[Extract5:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[Construct5]] 3
+; CHECK: %[[Extract6:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[Construct6]] 1
   %2 = call <6 x float> @llvm.matrix.transpose.v6f32.i32(<6 x float> %1, i32 2, i32 3)
 
 ; -- Store output 3x2 matrix elements
 ; CHECK: %[[AccessChain7:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_0]]
-; CHECK: %[[CompositeConstruct1:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Extract1]] %[[Extract2]] %[[Extract3]] %[[Extract4]]
-; CHECK: %[[Extract7:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeConstruct1]] 0
-; CHECK: OpStore %[[AccessChain7]] %[[Extract7]]
+; CHECK: OpStore %[[AccessChain7]] %[[Extract1]]
 ; CHECK: %[[AccessChain8:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_1]]
-; CHECK: %[[CompositeConstruct2:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Extract1]] %[[Extract2]] %[[Extract3]] %[[Extract4]]
-; CHECK: %[[Extract8:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeConstruct2]] 1
-; CHECK: OpStore %[[AccessChain8]] %[[Extract8]]
+; CHECK: OpStore %[[AccessChain8]] %[[Extract2]]
 ; CHECK: %[[AccessChain9:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_2]]
-; CHECK: %[[CompositeConstruct3:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Extract1]] %[[Extract2]] %[[Extract3]] %[[Extract4]]
-; CHECK: %[[Extract9:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeConstruct3]] 2
-; CHECK: OpStore %[[AccessChain9]] %[[Extract9]]
+; CHECK: OpStore %[[AccessChain9]] %[[Extract3]]
 ; CHECK: %[[AccessChain10:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_3]]
-; CHECK: %[[CompositeConstruct4:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Extract1]] %[[Extract2]] %[[Extract3]] %[[Extract4]]
-; CHECK: %[[Extract10:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeConstruct4]] 3
-; CHECK: OpStore %[[AccessChain10]] %[[Extract10]]
+; CHECK: OpStore %[[AccessChain10]] %[[Extract4]]
 ; CHECK: %[[AccessChain11:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_4]]
-; CHECK: %[[CompositeConstruct5:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Extract5]] %[[Extract6]] %[[undef_Float_ID:[0-9]+]] %[[undef_Float_ID]]
-; CHECK: %[[Extract11:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeConstruct5]] 0
-; CHECK: OpStore %[[AccessChain11]] %[[Extract11]]
+; CHECK: OpStore %[[AccessChain11]] %[[Extract5]]
 ; CHECK: %[[AccessChain12:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_5]]
-; CHECK: %[[CompositeConstruct6:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Extract5]] %[[Extract6]] %[[undef_Float_ID]] %[[undef_Float_ID]]
-; CHECK: %[[Extract12:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeConstruct6]] 1
-; CHECK: OpStore %[[AccessChain12]] %[[Extract12]]
+; CHECK: OpStore %[[AccessChain12]] %[[Extract6]]
   store <6 x float> %2, ptr addrspace(10) @private_v6f32
   ret void
 }

>From 1a3b8ba2dc408d46ca4a6fc90d3e4cede382aefc Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Mon, 5 Jan 2026 10:28:06 -0500
Subject: [PATCH 2/4] Refactor to simplify code.

---
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 299 +++++++++----------
 1 file changed, 140 insertions(+), 159 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 169e34bfffce6..01756cd51cfb1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -14,6 +14,7 @@
 #include "SPIRV.h"
 #include "SPIRVGlobalRegistry.h"
 #include "SPIRVSubtarget.h"
+#include "SPIRVUtils.h"
 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
@@ -643,196 +644,176 @@ static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) {
          NumElements > MaxVectorSize;
 }
 
-bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
-                                           MachineInstr &MI) const {
-  LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
+static bool legalizeSpvBitcast(LegalizerHelper &Helper, MachineInstr &MI,
+                               SPIRVGlobalRegistry *GR) {
+  LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n");
+  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
+  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
+  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
 
+  Register DstReg = MI.getOperand(0).getReg();
+  Register SrcReg = MI.getOperand(2).getReg();
+  LLT DstTy = MRI.getType(DstReg);
+  LLT SrcTy = MRI.getType(SrcReg);
+
+  // If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
+  // allow using the generic legalization rules.
+  if (needsVectorLegalization(DstTy, ST) ||
+      needsVectorLegalization(SrcTy, ST)) {
+    LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
+    MIRBuilder.buildBitcast(DstReg, SrcReg);
+    MI.eraseFromParent();
+  }
+  return true;
+}
+
+static bool legalizeSpvInsertElt(LegalizerHelper &Helper, MachineInstr &MI,
+                                 SPIRVGlobalRegistry *GR) {
   MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
   MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
   const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
 
-  auto IntrinsicID = cast<GIntrinsic>(MI).getIntrinsicID();
-  if (IntrinsicID == Intrinsic::spv_bitcast) {
-    LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n");
-    Register DstReg = MI.getOperand(0).getReg();
+  Register DstReg = MI.getOperand(0).getReg();
+  LLT DstTy = MRI.getType(DstReg);
+
+  if (needsVectorLegalization(DstTy, ST)) {
     Register SrcReg = MI.getOperand(2).getReg();
-    LLT DstTy = MRI.getType(DstReg);
+    Register ValReg = MI.getOperand(3).getReg();
     LLT SrcTy = MRI.getType(SrcReg);
-
-    // If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
-    // allow using the generic legalization rules.
-    if (needsVectorLegalization(DstTy, ST) ||
-        needsVectorLegalization(SrcTy, ST)) {
-      LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
-      MIRBuilder.buildBitcast(DstReg, SrcReg);
-      MI.eraseFromParent();
-    }
-    return true;
-  } else if (IntrinsicID == Intrinsic::spv_insertelt) {
-    Register DstReg = MI.getOperand(0).getReg();
-    LLT DstTy = MRI.getType(DstReg);
-
-    if (needsVectorLegalization(DstTy, ST)) {
-      Register DstReg = MI.getOperand(0).getReg();
-      Register SrcReg = MI.getOperand(2).getReg();
-      Register ValReg = MI.getOperand(3).getReg();
-      LLT SrcTy = MRI.getType(SrcReg);
-      MachineOperand &IdxOperand = MI.getOperand(4);
-
-      if (getImm(IdxOperand, &MRI)) {
-        uint64_t IdxVal = foldImm(IdxOperand, &MRI);
-        if (IdxVal < SrcTy.getNumElements()) {
-          SmallVector<Register, 8> Regs;
-          SPIRVType *ElementType = GR->getScalarOrVectorComponentType(
-              GR->getSPIRVTypeForVReg(DstReg));
-          LLT ElementLLTTy = GR->getRegType(ElementType);
-          for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
-            Register Reg = MRI.createGenericVirtualRegister(ElementLLTTy);
-            MRI.setRegClass(Reg, GR->getRegClass(ElementType));
-            GR->assignSPIRVTypeToVReg(ElementType, Reg, *MI.getMF());
-            Regs.push_back(Reg);
-          }
-          MIRBuilder.buildUnmerge(Regs, SrcReg);
-          Regs[IdxVal] = ValReg;
-          MIRBuilder.buildBuildVector(DstReg, Regs);
-          MI.eraseFromParent();
-          return true;
+    MachineOperand &IdxOperand = MI.getOperand(4);
+
+    if (getImm(IdxOperand, &MRI)) {
+      uint64_t IdxVal = foldImm(IdxOperand, &MRI);
+      if (IdxVal < SrcTy.getNumElements()) {
+        SmallVector<Register, 8> Regs;
+        SPIRVType *ElementType =
+            GR->getScalarOrVectorComponentType(GR->getSPIRVTypeForVReg(DstReg));
+        LLT ElementLLTTy = GR->getRegType(ElementType);
+        for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
+          Register Reg = MRI.createGenericVirtualRegister(ElementLLTTy);
+          MRI.setRegClass(Reg, GR->getRegClass(ElementType));
+          GR->assignSPIRVTypeToVReg(ElementType, Reg, *MI.getMF());
+          Regs.push_back(Reg);
         }
+        MIRBuilder.buildUnmerge(Regs, SrcReg);
+        Regs[IdxVal] = ValReg;
+        MIRBuilder.buildBuildVector(DstReg, Regs);
+        MI.eraseFromParent();
+        return true;
       }
+    }
 
-      LLT EltTy = SrcTy.getElementType();
-      Align VecAlign = Helper.getStackTemporaryAlignment(SrcTy);
+    LLT EltTy = SrcTy.getElementType();
+    Align VecAlign = Helper.getStackTemporaryAlignment(SrcTy);
 
-      MachinePointerInfo PtrInfo;
-      auto StackTemp = Helper.createStackTemporary(
-          TypeSize::getFixed(SrcTy.getSizeInBytes()), VecAlign, PtrInfo);
+    MachinePointerInfo PtrInfo;
+    auto StackTemp = Helper.createStackTemporary(
+        TypeSize::getFixed(SrcTy.getSizeInBytes()), VecAlign, PtrInfo);
 
-      MIRBuilder.buildStore(SrcReg, StackTemp, PtrInfo, VecAlign);
+    MIRBuilder.buildStore(SrcReg, StackTemp, PtrInfo, VecAlign);
 
-      Register IdxReg = IdxOperand.getReg();
-      LLT PtrTy = MRI.getType(StackTemp.getReg(0));
-      Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
-      auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);
+    Register IdxReg = IdxOperand.getReg();
+    LLT PtrTy = MRI.getType(StackTemp.getReg(0));
+    Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
+    auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);
 
-      MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
-          .addImm(1) // InBounds
-          .addUse(StackTemp.getReg(0))
-          .addUse(Zero.getReg(0))
-          .addUse(IdxReg);
+    MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
+        .addImm(1) // InBounds
+        .addUse(StackTemp.getReg(0))
+        .addUse(Zero.getReg(0))
+        .addUse(IdxReg);
 
-      MachinePointerInfo EltPtrInfo =
-          MachinePointerInfo(PtrTy.getAddressSpace());
-      Align EltAlign = Helper.getStackTemporaryAlignment(EltTy);
-      MIRBuilder.buildStore(ValReg, EltPtr, EltPtrInfo, EltAlign);
+    MachinePointerInfo EltPtrInfo = MachinePointerInfo(PtrTy.getAddressSpace());
+    Align EltAlign = Helper.getStackTemporaryAlignment(EltTy);
+    MIRBuilder.buildStore(ValReg, EltPtr, EltPtrInfo, EltAlign);
 
-      MIRBuilder.buildLoad(DstReg, StackTemp, PtrInfo, VecAlign);
-      MI.eraseFromParent();
-      return true;
-    }
+    MIRBuilder.buildLoad(DstReg, StackTemp, PtrInfo, VecAlign);
+    MI.eraseFromParent();
     return true;
-  } else if (IntrinsicID == Intrinsic::spv_extractelt) {
-    Register SrcReg = MI.getOperand(2).getReg();
-    LLT SrcTy = MRI.getType(SrcReg);
+  }
+  return true;
+}
+
+static bool legalizeSpvExtractElt(LegalizerHelper &Helper, MachineInstr &MI,
+                                  SPIRVGlobalRegistry *GR) {
+  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
+  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
+  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
 
-    if (needsVectorLegalization(SrcTy, ST)) {
-      Register DstReg = MI.getOperand(0).getReg();
-      MachineOperand &IdxOperand = MI.getOperand(3);
-
-      if (getImm(IdxOperand, &MRI)) {
-        uint64_t IdxVal = foldImm(IdxOperand, &MRI);
-        if (IdxVal < SrcTy.getNumElements()) {
-          LLT DstTy = MRI.getType(DstReg);
-          SmallVector<Register, 8> Regs;
-          SPIRVType *DstSpvTy = GR->getSPIRVTypeForVReg(DstReg);
-          for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
-            if (I == IdxVal) {
-              Regs.push_back(DstReg);
-            } else {
-              Register Reg = MRI.createGenericVirtualRegister(DstTy);
-              MRI.setRegClass(Reg, GR->getRegClass(DstSpvTy));
-              GR->assignSPIRVTypeToVReg(DstSpvTy, Reg, *MI.getMF());
-              Regs.push_back(Reg);
-            }
+  Register SrcReg = MI.getOperand(2).getReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+
+  if (needsVectorLegalization(SrcTy, ST)) {
+    Register DstReg = MI.getOperand(0).getReg();
+    MachineOperand &IdxOperand = MI.getOperand(3);
+
+    if (getImm(IdxOperand, &MRI)) {
+      uint64_t IdxVal = foldImm(IdxOperand, &MRI);
+      if (IdxVal < SrcTy.getNumElements()) {
+        LLT DstTy = MRI.getType(DstReg);
+        SmallVector<Register, 8> Regs;
+        SPIRVType *DstSpvTy = GR->getSPIRVTypeForVReg(DstReg);
+        for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
+          if (I == IdxVal) {
+            Regs.push_back(DstReg);
+          } else {
+            Register Reg = MRI.createGenericVirtualRegister(DstTy);
+            MRI.setRegClass(Reg, GR->getRegClass(DstSpvTy));
+            GR->assignSPIRVTypeToVReg(DstSpvTy, Reg, *MI.getMF());
+            Regs.push_back(Reg);
           }
-          MIRBuilder.buildUnmerge(Regs, SrcReg);
-          MI.eraseFromParent();
-          return true;
         }
+        MIRBuilder.buildUnmerge(Regs, SrcReg);
+        MI.eraseFromParent();
+        return true;
       }
+    }
 
-      LLT EltTy = SrcTy.getElementType();
-      Align VecAlign = Helper.getStackTemporaryAlignment(SrcTy);
-
-      MachinePointerInfo PtrInfo;
-      auto StackTemp = Helper.createStackTemporary(
-          TypeSize::getFixed(SrcTy.getSizeInBytes()), VecAlign, PtrInfo);
-
-      // Set the type of StackTemp to a pointer to an array of the element type.
-      SPIRVType *SpvSrcTy = GR->getSPIRVTypeForVReg(SrcReg);
-      SPIRVType *EltSpvTy = GR->getScalarOrVectorComponentType(SpvSrcTy);
-      const Type *LLVMEltTy = GR->getTypeForSPIRVType(EltSpvTy);
-      const Type *LLVMArrTy =
-          ArrayType::get(const_cast<Type *>(LLVMEltTy), SrcTy.getNumElements());
-      SPIRVType *ArrSpvTy = GR->getOrCreateSPIRVType(
-          LLVMArrTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
-      SPIRVType *PtrToArrSpvTy = GR->getOrCreateSPIRVPointerType(
-          ArrSpvTy, MIRBuilder, SPIRV::StorageClass::Function);
-      setRegClassType(StackTemp.getReg(0), PtrToArrSpvTy, GR, &MRI,
-                      MIRBuilder.getMF());
-
-      // Store the vector elements one by one.
-      SmallVector<Register, 8> Regs;
-      for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
-        Register Reg = MRI.createGenericVirtualRegister(EltTy);
-        MRI.setRegClass(Reg, GR->getRegClass(EltSpvTy));
-        GR->assignSPIRVTypeToVReg(EltSpvTy, Reg, *MI.getMF());
-        Regs.push_back(Reg);
-      }
-      MIRBuilder.buildUnmerge(Regs, SrcReg);
-
-      auto ZeroNew = MIRBuilder.buildConstant(LLT::scalar(32), 0);
-      LLT PtrTyNew = MRI.getType(StackTemp.getReg(0));
-
-      for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
-        auto Idx = MIRBuilder.buildConstant(LLT::scalar(32), I);
-        Register EltPtr = MRI.createGenericVirtualRegister(PtrTyNew);
-        MIRBuilder
-            .buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
-            .addImm(1) // InBounds
-            .addUse(StackTemp.getReg(0))
-            .addUse(ZeroNew.getReg(0))
-            .addUse(Idx.getReg(0));
-
-        MachinePointerInfo EltPtrInfo =
-            PtrInfo.getWithOffset(I * EltTy.getSizeInBytes());
-        Align EltAlign = commonAlignment(VecAlign, I * EltTy.getSizeInBytes());
-        MIRBuilder.buildStore(Regs[I], EltPtr, EltPtrInfo, EltAlign);
-      }
+    LLT EltTy = SrcTy.getElementType();
+    Align VecAlign = Helper.getStackTemporaryAlignment(SrcTy);
+
+    MachinePointerInfo PtrInfo;
+    auto StackTemp = Helper.createStackTemporary(
+        TypeSize::getFixed(SrcTy.getSizeInBytes()), VecAlign, PtrInfo);
 
-      Register IdxReg = IdxOperand.getReg();
-      LLT PtrTy = MRI.getType(StackTemp.getReg(0));
-      Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
-      auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);
+    MIRBuilder.buildStore(SrcReg, StackTemp, PtrInfo, VecAlign);
 
-      MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
-          .addImm(1) // InBounds
-          .addUse(StackTemp.getReg(0))
-          .addUse(Zero.getReg(0))
-          .addUse(IdxReg);
+    Register IdxReg = IdxOperand.getReg();
+    LLT PtrTy = MRI.getType(StackTemp.getReg(0));
+    Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
+    auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);
+
+    MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
+        .addImm(1) // InBounds
+        .addUse(StackTemp.getReg(0))
+        .addUse(Zero.getReg(0))
+        .addUse(IdxReg);
 
-      MachinePointerInfo EltPtrInfo =
-          MachinePointerInfo(PtrTy.getAddressSpace());
-      Align EltAlign = Helper.getStackTemporaryAlignment(EltTy);
-      MIRBuilder.buildLoad(DstReg, EltPtr, EltPtrInfo, EltAlign);
+    MachinePointerInfo EltPtrInfo = MachinePointerInfo(PtrTy.getAddressSpace());
+    Align EltAlign = Helper.getStackTemporaryAlignment(EltTy);
+    MIRBuilder.buildLoad(DstReg, EltPtr, EltPtrInfo, EltAlign);
 
-      MI.eraseFromParent();
-      return true;
-    }
+    MI.eraseFromParent();
     return true;
   }
   return true;
 }
 
+bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
+                                           MachineInstr &MI) const {
+  LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
+  auto IntrinsicID = cast<GIntrinsic>(MI).getIntrinsicID();
+  switch (IntrinsicID) {
+  case Intrinsic::spv_bitcast:
+    return legalizeSpvBitcast(Helper, MI, GR);
+  case Intrinsic::spv_insertelt:
+    return legalizeSpvInsertElt(Helper, MI, GR);
+  case Intrinsic::spv_extractelt:
+    return legalizeSpvExtractElt(Helper, MI, GR);
+  }
+  return true;
+}
+
 bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper,
                                          MachineInstr &MI) const {
   // Once the G_BITCAST is using vectors that are allowed, we turn it back into

>From 7d7ee765062ce52dc76af8a7db96d039b1620f1a Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 6 Jan 2026 12:25:35 -0500
Subject: [PATCH 3/4] Fix G_FRAME_INDEX type assignment in spv_extractelt
 legalization

Restore the logic to assign a SPIR-V pointer-to-array type to the stack
temporary created during spv_extractelt legalization. This was accidentally
removed in the previous refactor and is required for legality.
---
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 01756cd51cfb1..f755733cfe9c6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -776,6 +776,21 @@ static bool legalizeSpvExtractElt(LegalizerHelper &Helper, MachineInstr &MI,
     auto StackTemp = Helper.createStackTemporary(
         TypeSize::getFixed(SrcTy.getSizeInBytes()), VecAlign, PtrInfo);
 
+    // Set the type of StackTemp to a pointer to an array of the element type.
+    SPIRVType *SpvSrcTy = GR->getSPIRVTypeForVReg(SrcReg);
+    SPIRVType *EltSpvTy = GR->getScalarOrVectorComponentType(SpvSrcTy);
+    const Type *LLVMEltTy = GR->getTypeForSPIRVType(EltSpvTy);
+    const Type *LLVMArrTy =
+        ArrayType::get(const_cast<Type *>(LLVMEltTy), SrcTy.getNumElements());
+    SPIRVType *ArrSpvTy = GR->getOrCreateSPIRVType(
+        LLVMArrTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
+    SPIRVType *PtrToArrSpvTy = GR->getOrCreateSPIRVPointerType(
+        ArrSpvTy, MIRBuilder, SPIRV::StorageClass::Function);
+
+    Register StackReg = StackTemp.getReg(0);
+    MRI.setRegClass(StackReg, GR->getRegClass(PtrToArrSpvTy));
+    GR->assignSPIRVTypeToVReg(PtrToArrSpvTy, StackReg, *MI.getMF());
+
     MIRBuilder.buildStore(SrcReg, StackTemp, PtrInfo, VecAlign);
 
     Register IdxReg = IdxOperand.getReg();

>From adc59340e8fa86294812165d21c2583b9393e5c3 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Wed, 7 Jan 2026 09:04:29 -0500
Subject: [PATCH 4/4] Remove empty function.

---
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index f755733cfe9c6..f6587ba068c0e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -535,11 +535,6 @@ static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
   return ConvReg;
 }
 
-static bool legalizeLoad(LegalizerHelper &Helper, MachineInstr &MI,
-                         SPIRVGlobalRegistry *GR) {
-  return true;
-}
-
 static bool legalizeStore(LegalizerHelper &Helper, MachineInstr &MI,
                           SPIRVGlobalRegistry *GR) {
   MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
@@ -628,8 +623,6 @@ bool SPIRVLegalizerInfo::legalizeCustom(
     }
     return true;
   }
-  case TargetOpcode::G_LOAD:
-    return legalizeLoad(Helper, MI, GR);
   case TargetOpcode::G_STORE:
     return legalizeStore(Helper, MI, GR);
   }



More information about the llvm-commits mailing list