[llvm] [SPIRV] Add legalization for long vectors (PR #169665)

Steven Perron via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 1 08:39:49 PST 2025


https://github.com/s-perron updated https://github.com/llvm/llvm-project/pull/169665

>From 92c6074a8368558df07e1ba38e69935b80910db5 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Wed, 26 Nov 2025 09:41:23 -0500
Subject: [PATCH 1/3] [SPIRV] Add legalization for long vectors

This patch introduces the necessary infrastructure to legalize vector
operations on vectors that are longer than what the SPIR-V target
supports. For instance, shaders only support vectors up to 4 elements.

The legalization is done by splitting the long vectors into smaller
vectors of a legal size.

Specifically, this patch does the following:
- Introduces `vectorElementCountIsGreaterThan` and
  `vectorElementCountIsLessThanOrEqualTo` legality predicates.
- Adds legalization rules for `G_SHUFFLE_VECTOR`, `G_EXTRACT_VECTOR_ELT`,
  `G_BUILD_VECTOR`, `G_CONCAT_VECTORS`, `G_SPLAT_VECTOR`, and
  `G_UNMERGE_VALUES`.
- Handles `G_BITCAST` of long vectors by converting them to
  `@llvm.spv.bitcast` intrinsics which are then legalized.
- Updates `selectUnmergeValues` to handle extraction of both scalars
  and vectors from a larger vector, using `OpCompositeExtract` and
  `OpVectorShuffle` respectively.

Fixes: https://github.com/llvm/llvm-project/pull/165444
---
 .../llvm/CodeGen/GlobalISel/LegalizerInfo.h   |  10 ++
 .../CodeGen/GlobalISel/LegalityPredicates.cpp |  20 +++
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  50 ++++--
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  | 170 ++++++++++++++++--
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h    |   4 +
 .../vector-legalization-kernel.ll             |  69 +++++++
 6 files changed, 299 insertions(+), 24 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index 51318c9c2736d..9324bab3fe656 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -314,6 +314,16 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
 LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx,
                                                    unsigned Size);
 
+/// True iff the specified type index is a vector with a number of elements
+/// that's greater than the given size.
+LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx,
+                                                           unsigned Size);
+
+/// True iff the specified type index is a vector with a number of elements
+/// that's less than or equal to the given size.
+LLVM_ABI LegalityPredicate
+vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, unsigned Size);
+
 /// True iff the specified type index is a scalar or a vector with an element
 /// type that's wider than the given size.
 LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx,
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
index 30c2d089c3121..5e7cd5fd5d9ad 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
@@ -155,6 +155,26 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
   };
 }
 
+LegalityPredicate
+LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx,
+                                                    unsigned Size) {
+
+  return [=](const LegalityQuery &Query) {
+    const LLT QueryTy = Query.Types[TypeIdx];
+    return QueryTy.isFixedVector() && QueryTy.getNumElements() > Size;
+  };
+}
+
+LegalityPredicate
+LegalityPredicates::vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx,
+                                                          unsigned Size) {
+
+  return [=](const LegalityQuery &Query) {
+    const LLT QueryTy = Query.Types[TypeIdx];
+    return QueryTy.isFixedVector() && QueryTy.getNumElements() <= Size;
+  };
+}
+
 LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx,
                                                            unsigned Size) {
   return [=](const LegalityQuery &Query) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 2c27289e759eb..a2e29366dc4cc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1781,33 +1781,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
   unsigned ArgI = I.getNumOperands() - 1;
   Register SrcReg =
       I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
-  SPIRVType *DefType =
+  SPIRVType *SrcType =
       SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
-  if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
+  if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector)
     report_fatal_error(
         "cannot select G_UNMERGE_VALUES with a non-vector argument");
 
   SPIRVType *ScalarType =
-      GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+      GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
   MachineBasicBlock &BB = *I.getParent();
   bool Res = false;
+  unsigned CurrentIndex = 0;
   for (unsigned i = 0; i < I.getNumDefs(); ++i) {
     Register ResVReg = I.getOperand(i).getReg();
     SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
     if (!ResType) {
-      // There was no "assign type" actions, let's fix this now
-      ResType = ScalarType;
+      LLT ResLLT = MRI->getType(ResVReg);
+      assert(ResLLT.isValid());
+      if (ResLLT.isVector()) {
+        ResType = GR.getOrCreateSPIRVVectorType(
+            ScalarType, ResLLT.getNumElements(), I, TII);
+      } else {
+        ResType = ScalarType;
+      }
       MRI->setRegClass(ResVReg, GR.getRegClass(ResType));
-      MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
       GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
     }
-    auto MIB =
-        BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
-            .addDef(ResVReg)
-            .addUse(GR.getSPIRVTypeID(ResType))
-            .addUse(SrcReg)
-            .addImm(static_cast<int64_t>(i));
-    Res |= MIB.constrainAllUses(TII, TRI, RBI);
+
+    if (ResType->getOpcode() == SPIRV::OpTypeVector) {
+      Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII);
+      auto MIB =
+          BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle))
+              .addDef(ResVReg)
+              .addUse(GR.getSPIRVTypeID(ResType))
+              .addUse(SrcReg)
+              .addUse(UndefReg);
+      unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType);
+      for (unsigned j = 0; j < NumElements; ++j) {
+        MIB.addImm(CurrentIndex + j);
+      }
+      CurrentIndex += NumElements;
+      Res |= MIB.constrainAllUses(TII, TRI, RBI);
+    } else {
+      auto MIB =
+          BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+              .addDef(ResVReg)
+              .addUse(GR.getSPIRVTypeID(ResType))
+              .addUse(SrcReg)
+              .addImm(CurrentIndex);
+      CurrentIndex++;
+      Res |= MIB.constrainAllUses(TII, TRI, RBI);
+    }
   }
   return Res;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 53074ea3b2597..c9c663ee3309e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -14,16 +14,22 @@
 #include "SPIRV.h"
 #include "SPIRVGlobalRegistry.h"
 #include "SPIRVSubtarget.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 #include "llvm/CodeGen/MachineInstr.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 
 using namespace llvm;
 using namespace llvm::LegalizeActions;
 using namespace llvm::LegalityPredicates;
 
+#define DEBUG_TYPE "spirv-legalizer"
+
 LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
   return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
     const LLT Ty = Query.Types[TypeIdx];
@@ -101,6 +107,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                      v4s64, v8s1,   v8s8,   v8s16, v8s32, v8s64, v16s1,
                      v16s8, v16s16, v16s32, v16s64};
 
+  auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64,
+                           v3s1, v3s8, v3s16, v3s32, v3s64,
+                           v4s1, v4s8, v4s16, v4s32, v4s64};
+
   auto allScalarsAndVectors = {
       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
       v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
@@ -126,6 +136,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
 
   auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12};
 
+  auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors;
+
   bool IsExtendedInts =
       ST.canUseExtension(
           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
@@ -148,14 +160,65 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
         return IsExtendedInts && Ty.isValid();
       };
 
-  for (auto Opc : getTypeFoldingSupportedOpcodes())
-    getActionDefinitionsBuilder(Opc).custom();
+  uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
 
-  getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
+  for (auto Opc : getTypeFoldingSupportedOpcodes()) {
+    if (Opc != G_EXTRACT_VECTOR_ELT)
+      getActionDefinitionsBuilder(Opc).custom();
+  }
 
-  // TODO: add proper rules for vectors legalization.
-  getActionDefinitionsBuilder(
-      {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
+  getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();
+
+  getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
+      .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
+      .moreElementsToNextPow2(0)
+      .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
+      .moreElementsToNextPow2(1)
+      .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
+      .alwaysLegal();
+
+  getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
+      .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize))
+      .moreElementsToNextPow2(1)
+      .fewerElementsIf(vectorElementCountIsGreaterThan(1, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(
+                           1, ElementCount::getFixed(MaxVectorSize)))
+      .custom();
+
+  // Illegal G_UNMERGE_VALUES instructions should be handled
+  // during the combine phase.
+  getActionDefinitionsBuilder(G_BUILD_VECTOR)
+      .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
+      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(
+                           0, ElementCount::getFixed(MaxVectorSize)));
+
+  // When entering the legalizer, there should be no G_BITCAST instructions.
+  // They should all be calls to the `spv_bitcast` intrinsic. The call to
+  // the intrinsic will be converted to a G_BITCAST during legalization if
+  // the vectors are not legal. After using the rules to legalize a G_BITCAST,
+  // we turn it back into a call to the intrinsic with a custom rule to avoid
+  // potential machine verifier failures.
+  getActionDefinitionsBuilder(G_BITCAST)
+      .moreElementsToNextPow2(0)
+      .moreElementsToNextPow2(1)
+      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(
+                           0, ElementCount::getFixed(MaxVectorSize)))
+      .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
+      .custom();
+
+  getActionDefinitionsBuilder(G_CONCAT_VECTORS)
+      .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
+      .moreElementsToNextPow2(0)
+      .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
+      .alwaysLegal();
+
+  getActionDefinitionsBuilder(G_SPLAT_VECTOR)
+      .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
+      .moreElementsToNextPow2(0)
+      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(0, ElementCount::getFixed(MaxVectorSize)))
       .alwaysLegal();
 
   // Vector Reduction Operations
@@ -164,7 +227,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
        G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
        G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
        G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
-      .legalFor(allVectors)
+      .legalFor(allowedVectorTypes)
       .scalarize(1)
       .lower();
 
@@ -172,9 +235,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       .scalarize(2)
       .lower();
 
-  // Merge/Unmerge
-  // TODO: add proper legalization rules.
-  getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
+  // Illegal G_UNMERGE_VALUES instructions should be handled
+  // during the combine phase.
+  getActionDefinitionsBuilder(G_UNMERGE_VALUES)
+      .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize));
 
   getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
       .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));
@@ -228,7 +292,14 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       all(typeInSet(0, allPtrsScalarsAndVectors),
           typeInSet(1, allPtrsScalarsAndVectors)));
 
-  getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
+  getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE})
+      .legalFor({s1})
+      .legalFor(allFloatAndIntScalarsAndPtrs)
+      .legalFor(allowedVectorTypes)
+      .moreElementsToNextPow2(0)
+      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(
+                           0, ElementCount::getFixed(MaxVectorSize)));
 
   getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
 
@@ -287,6 +358,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   // Pointer-handling.
   getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
 
+  getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs);
+
   // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
   getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
 
@@ -374,6 +447,11 @@ bool SPIRVLegalizerInfo::legalizeCustom(
   default:
     // TODO: implement legalization for other opcodes.
     return true;
+  case TargetOpcode::G_BITCAST:
+    return legalizeBitcast(Helper, MI);
+  case TargetOpcode::G_INTRINSIC:
+  case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
+    return legalizeIntrinsic(Helper, MI);
   case TargetOpcode::G_IS_FPCLASS:
     return legalizeIsFPClass(Helper, MI, LocObserver);
   case TargetOpcode::G_ICMP: {
@@ -400,6 +478,76 @@ bool SPIRVLegalizerInfo::legalizeCustom(
   }
 }
 
+bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
+                                           MachineInstr &MI) const {
+  LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
+
+  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
+  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
+  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
+
+  auto IntrinsicID = cast<GIntrinsic>(MI).getIntrinsicID();
+  if (IntrinsicID == Intrinsic::spv_bitcast) {
+    LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n");
+    Register DstReg = MI.getOperand(0).getReg();
+    Register SrcReg = MI.getOperand(2).getReg();
+    LLT DstTy = MRI.getType(DstReg);
+    LLT SrcTy = MRI.getType(SrcReg);
+
+    int32_t MaxVectorSize = ST.isShader() ? 4 : 16;
+
+    bool DstNeedsLegalization = false;
+    bool SrcNeedsLegalization = false;
+
+    if (DstTy.isVector()) {
+      if (DstTy.getNumElements() > 4 &&
+          !isPowerOf2_32(DstTy.getNumElements())) {
+        DstNeedsLegalization = true;
+      }
+
+      if (DstTy.getNumElements() > MaxVectorSize) {
+        DstNeedsLegalization = true;
+      }
+    }
+
+    if (SrcTy.isVector()) {
+      if (SrcTy.getNumElements() > 4 &&
+          !isPowerOf2_32(SrcTy.getNumElements())) {
+        SrcNeedsLegalization = true;
+      }
+
+      if (SrcTy.getNumElements() > MaxVectorSize) {
+        SrcNeedsLegalization = true;
+      }
+    }
+
+    // If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
+    // allow using the generic legalization rules.
+    if (DstNeedsLegalization || SrcNeedsLegalization) {
+      LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
+      MIRBuilder.buildBitcast(DstReg, SrcReg);
+      MI.eraseFromParent();
+    }
+    return true;
+  }
+  return true;
+}
+
+bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper,
+                                         MachineInstr &MI) const {
+  // Once the G_BITCAST is using vectors that are allowed, we turn it back into
+  // an spv_bitcast to avoid verifier problems when the register types are the
+  // same for the source and the result. Note that the SPIR-V types associated
+  // with the bitcast can be different even if the register types are the same.
+  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
+  Register DstReg = MI.getOperand(0).getReg();
+  Register SrcReg = MI.getOperand(1).getReg();
+  SmallVector<Register, 1> DstRegs = {DstReg};
+  MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, DstRegs).addUse(SrcReg);
+  MI.eraseFromParent();
+  return true;
+}
+
 // Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
 // to ensure that all instructions created during the lowering have SPIR-V types
 // assigned to them.
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
index eeefa4239c778..86e7e711caa60 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
@@ -29,11 +29,15 @@ class SPIRVLegalizerInfo : public LegalizerInfo {
 public:
   bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI,
                       LostDebugLocObserver &LocObserver) const override;
+  bool legalizeIntrinsic(LegalizerHelper &Helper,
+                         MachineInstr &MI) const override;
+
   SPIRVLegalizerInfo(const SPIRVSubtarget &ST);
 
 private:
   bool legalizeIsFPClass(LegalizerHelper &Helper, MachineInstr &MI,
                          LostDebugLocObserver &LocObserver) const;
+  bool legalizeBitcast(LegalizerHelper &Helper, MachineInstr &MI) const;
 };
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll
new file mode 100644
index 0000000000000..4fe6f217dd40f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll
@@ -0,0 +1,69 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpName %[[#test_int32_double_conversion:]] "test_int32_double_conversion"
+; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#v8i32:]] = OpTypeVector %[[#int]] 8
+; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4
+; CHECK-DAG: %[[#ptr_func_v8i32:]] = OpTypePointer Function %[[#v8i32]]
+
+; CHECK-DAG: OpName %[[#test_v3f64_conversion:]] "test_v3f64_conversion"
+; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#v3f64:]] = OpTypeVector %[[#double]] 3
+; CHECK-DAG: %[[#ptr_func_v3f64:]] = OpTypePointer Function %[[#v3f64]]
+; CHECK-DAG: %[[#v4f64:]] = OpTypeVector %[[#double]] 4
+
+define spir_kernel void @test_int32_double_conversion(ptr %G_vec) {
+; CHECK: %[[#test_int32_double_conversion]] = OpFunction
+; CHECK: %[[#param:]] = OpFunctionParameter %[[#ptr_func_v8i32]]
+entry:
+  ; CHECK: %[[#LOAD:]] = OpLoad %[[#v8i32]] %[[#param]]
+  ; CHECK: %[[#SHUF1:]] = OpVectorShuffle %[[#v4i32]] %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 0 2 4 6
+  ; CHECK: %[[#SHUF2:]] = OpVectorShuffle %[[#v4i32]] %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 1 3 5 7
+  ; CHECK: %[[#SHUF3:]] = OpVectorShuffle %[[#v8i32]] %[[#SHUF1]] %[[#SHUF2]] 0 4 1 5 2 6 3 7
+  ; CHECK: OpStore %[[#param]] %[[#SHUF3]]
+
+  %0 = load <8 x i32>, ptr %G_vec
+  %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %3 = shufflevector <4 x i32> %1, <4 x i32> %2, <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
+  store <8 x i32> %3, ptr %G_vec
+  ret void
+}
+
+define spir_kernel void @test_v3f64_conversion(ptr %G_vec) {
+; CHECK: %[[#test_v3f64_conversion:]] = OpFunction
+; CHECK: %[[#param_v3f64:]] = OpFunctionParameter %[[#ptr_func_v3f64]]
+entry:
+  ; CHECK: %[[#LOAD:]] = OpLoad %[[#v3f64]] %[[#param_v3f64]]
+  %0 = load <3 x double>, ptr %G_vec
+
+  ; The 6-element vector is not legal. It gets expanded to 8.
+  ; CHECK: %[[#EXTRACT1:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 0
+  ; CHECK: %[[#EXTRACT2:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 1
+  ; CHECK: %[[#EXTRACT3:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 2
+  ; CHECK: %[[#CONSTRUCT1:]] = OpCompositeConstruct %[[#v4f64]] %[[#EXTRACT1]] %[[#EXTRACT2]] %[[#EXTRACT3]] %{{[a-zA-Z0-9_]+}}
+  ; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#v8i32]] %[[#CONSTRUCT1]]
+  %1 = bitcast <3 x double> %0 to <6 x i32>
+
+  ; CHECK: %[[#SHUFFLE1:]] = OpVectorShuffle %[[#v8i32]] %[[#BITCAST1]] %{{[a-zA-Z0-9_]+}} 0 2 4 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF
+  %2 = shufflevector <6 x i32> %1, <6 x i32> poison, <3 x i32> <i32 0, i32 2, i32 4>
+
+  ; CHECK: %[[#SHUFFLE2:]] = OpVectorShuffle %[[#v8i32]] %[[#BITCAST1]] %{{[a-zA-Z0-9_]+}} 1 3 5 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF
+  %3 = shufflevector <6 x i32> %1, <6 x i32> poison, <3 x i32> <i32 1, i32 3, i32 5>
+
+  ; CHECK: %[[#SHUFFLE3:]] = OpVectorShuffle %[[#v8i32]] %[[#SHUFFLE1]] %[[#SHUFFLE2]] 0 8 1 9 2 10 0xFFFFFFFF 0xFFFFFFFF
+  %4 = shufflevector <3 x i32> %2, <3 x i32> %3, <6 x i32> <i32 0, i32 3, i32 1, i32 4, i32 2, i32 5>
+
+  ; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#v4f64]] %[[#SHUFFLE3]]
+  ; CHECK: %[[#EXTRACT10:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 0
+  ; CHECK: %[[#EXTRACT11:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 1
+  ; CHECK: %[[#EXTRACT12:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 2
+  ; CHECK: %[[#CONSTRUCT3:]] = OpCompositeConstruct %[[#v3f64]] %[[#EXTRACT10]] %[[#EXTRACT11]] %[[#EXTRACT12]]
+  %5 = bitcast <6 x i32> %4 to <3 x double>
+  
+  ; CHECK: OpStore %[[#param_v3f64]] %[[#CONSTRUCT3]]
+  store <3 x double> %5, ptr %G_vec
+  ret void
+}
+

>From e17a42e12a5447338cec1a2da0a2efc40c51cd90 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Mon, 1 Dec 2025 11:28:20 -0500
Subject: [PATCH 2/3] Add missing test. Rewrite generated extract
 instruction to spv intrinsic.

---
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  |  18 ++-
 .../vector-legalization-shader.ll             | 133 ++++++++++++++++++
 2 files changed, 150 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/legalization/vector-legalization-shader.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index c9c663ee3309e..b46c61a2ed953 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -178,7 +178,6 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       .alwaysLegal();
 
   getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
-      .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize))
       .moreElementsToNextPow2(1)
       .fewerElementsIf(vectorElementCountIsGreaterThan(1, MaxVectorSize),
                        LegalizeMutations::changeElementCountTo(
@@ -426,6 +425,21 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   verify(*ST.getInstrInfo());
 }
 
+static bool legalizeExtractVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
+                                     SPIRVGlobalRegistry *GR) {
+  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
+  Register DstReg = MI.getOperand(0).getReg();
+  Register SrcReg = MI.getOperand(1).getReg();
+  Register IdxReg = MI.getOperand(2).getReg();
+
+  MIRBuilder
+      .buildIntrinsic(Intrinsic::spv_extractelt, ArrayRef<Register>{DstReg})
+      .addUse(SrcReg)
+      .addUse(IdxReg);
+  MI.eraseFromParent();
+  return true;
+}
+
 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
                                 LegalizerHelper &Helper,
                                 MachineRegisterInfo &MRI,
@@ -449,6 +463,8 @@ bool SPIRVLegalizerInfo::legalizeCustom(
     return true;
   case TargetOpcode::G_BITCAST:
     return legalizeBitcast(Helper, MI);
+  case TargetOpcode::G_EXTRACT_VECTOR_ELT:
+    return legalizeExtractVectorElt(Helper, MI, GR);
   case TargetOpcode::G_INTRINSIC:
   case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
     return legalizeIntrinsic(Helper, MI);
diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-shader.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-shader.ll
new file mode 100644
index 0000000000000..438d7ae21283a
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-shader.ll
@@ -0,0 +1,133 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val --target-env vulkan1.3 %}
+
+; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#v4int:]] = OpTypeVector %[[#int]] 4
+; CHECK-DAG: %[[#v4double:]] = OpTypeVector %[[#double]] 4
+; CHECK-DAG: %[[#v2int:]] = OpTypeVector %[[#int]] 2
+; CHECK-DAG: %[[#v2double:]] = OpTypeVector %[[#double]] 2
+; CHECK-DAG: %[[#v3int:]] = OpTypeVector %[[#int]] 3
+; CHECK-DAG: %[[#v3double:]] = OpTypeVector %[[#double]] 3
+; CHECK-DAG: %[[#ptr_v4double:]] = OpTypePointer Private %[[#v4double]]
+; CHECK-DAG: %[[#ptr_v4int:]] = OpTypePointer Private %[[#v4int]]
+; CHECK-DAG: %[[#ptr_v3double:]] = OpTypePointer Private %[[#v3double]]
+; CHECK-DAG: %[[#ptr_v3int:]] = OpTypePointer Private %[[#v3int]]
+; CHECK-DAG: %[[#GVec4:]] = OpVariable %[[#ptr_v4double]] Private
+; CHECK-DAG: %[[#Lows:]] = OpVariable %[[#ptr_v4int]] Private
+; CHECK-DAG: %[[#Highs:]] = OpVariable %[[#ptr_v4int]] Private
+; CHECK-DAG: %[[#GVec3:]] = OpVariable %[[#ptr_v3double]] Private
+; CHECK-DAG: %[[#Lows3:]] = OpVariable %[[#ptr_v3int]] Private
+; CHECK-DAG: %[[#Highs3:]] = OpVariable %[[#ptr_v3int]] Private
+
+@GVec4 = internal addrspace(10) global <4 x double> zeroinitializer
+@Lows = internal addrspace(10) global <4 x i32> zeroinitializer
+@Highs = internal addrspace(10) global <4 x i32> zeroinitializer
+@GVec3 = internal addrspace(10) global <3 x double> zeroinitializer
+@Lows3 = internal addrspace(10) global <3 x i32> zeroinitializer
+@Highs3 = internal addrspace(10) global <3 x i32> zeroinitializer
+
+; Test splitting a vector of size 8.
+define internal void @test_split() {
+entry:
+  ; CHECK: %[[#load_v4double:]] = OpLoad %[[#v4double]] %[[#GVec4]]
+  ; CHECK: %[[#v2double_01:]] = OpVectorShuffle %[[#v2double]] %[[#load_v4double]] %{{[a-zA-Z0-9_]+}} 0 1
+  ; CHECK: %[[#v2double_23:]] = OpVectorShuffle %[[#v2double]] %[[#load_v4double]] %{{[a-zA-Z0-9_]+}} 2 3
+  ; CHECK: %[[#v4int_01:]] = OpBitcast %[[#v4int]] %[[#v2double_01]]
+  ; CHECK: %[[#v4int_23:]] = OpBitcast %[[#v4int]] %[[#v2double_23]]
+  %0 = load <8 x i32>, ptr addrspace(10) @GVec4, align 32
+
+  ; CHECK: %[[#l0:]] = OpCompositeExtract %[[#int]] %[[#v4int_01]] 0
+  ; CHECK: %[[#l1:]] = OpCompositeExtract %[[#int]] %[[#v4int_01]] 2
+  ; CHECK: %[[#l2:]] = OpCompositeExtract %[[#int]] %[[#v4int_23]] 0
+  ; CHECK: %[[#l3:]] = OpCompositeExtract %[[#int]] %[[#v4int_23]] 2
+  ; CHECK: %[[#res_low:]] = OpCompositeConstruct %[[#v4int]] %[[#l0]] %[[#l1]] %[[#l2]] %[[#l3]]
+  %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  
+  ; CHECK: %[[#h0:]] = OpCompositeExtract %[[#int]] %[[#v4int_01]] 1
+  ; CHECK: %[[#h1:]] = OpCompositeExtract %[[#int]] %[[#v4int_01]] 3
+  ; CHECK: %[[#h2:]] = OpCompositeExtract %[[#int]] %[[#v4int_23]] 1
+  ; CHECK: %[[#h3:]] = OpCompositeExtract %[[#int]] %[[#v4int_23]] 3
+  ; CHECK: %[[#res_high:]] = OpCompositeConstruct %[[#v4int]] %[[#h0]] %[[#h1]] %[[#h2]] %[[#h3]]
+  %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+
+  store <4 x i32> %1, ptr addrspace(10) @Lows, align 16
+  store <4 x i32> %2, ptr addrspace(10) @Highs, align 16
+  ret void
+}
+
+define internal void @test_recombine() {
+entry:
+  ; CHECK: %[[#l:]] = OpLoad %[[#v4int]] %[[#Lows]]
+  %0 = load <4 x i32>, ptr addrspace(10) @Lows, align 16
+  ; CHECK: %[[#h:]] = OpLoad %[[#v4int]] %[[#Highs]]
+  %1 = load <4 x i32>, ptr addrspace(10) @Highs, align 16
+
+  ; CHECK-DAG: %[[#l0:]] = OpCompositeExtract %[[#int]] %[[#l]] 0
+  ; CHECK-DAG: %[[#l1:]] = OpCompositeExtract %[[#int]] %[[#l]] 1
+  ; CHECK-DAG: %[[#l2:]] = OpCompositeExtract %[[#int]] %[[#l]] 2
+  ; CHECK-DAG: %[[#l3:]] = OpCompositeExtract %[[#int]] %[[#l]] 3
+  ; CHECK-DAG: %[[#h0:]] = OpCompositeExtract %[[#int]] %[[#h]] 0
+  ; CHECK-DAG: %[[#h1:]] = OpCompositeExtract %[[#int]] %[[#h]] 1
+  ; CHECK-DAG: %[[#h2:]] = OpCompositeExtract %[[#int]] %[[#h]] 2
+  ; CHECK-DAG: %[[#h3:]] = OpCompositeExtract %[[#int]] %[[#h]] 3
+  ; CHECK-DAG: %[[#v2i0:]] = OpCompositeConstruct %[[#v2int]] %[[#l0]] %[[#h0]]
+  ; CHECK-DAG: %[[#d0:]] = OpBitcast %[[#double]] %[[#v2i0]]
+  ; CHECK-DAG: %[[#v2i1:]] = OpCompositeConstruct %[[#v2int]] %[[#l1]] %[[#h1]]
+  ; CHECK-DAG: %[[#d1:]] = OpBitcast %[[#double]] %[[#v2i1]]
+  ; CHECK-DAG: %[[#v2i2:]] = OpCompositeConstruct %[[#v2int]] %[[#l2]] %[[#h2]]
+  ; CHECK-DAG: %[[#d2:]] = OpBitcast %[[#double]] %[[#v2i2]]
+  ; CHECK-DAG: %[[#v2i3:]] = OpCompositeConstruct %[[#v2int]] %[[#l3]] %[[#h3]]
+  ; CHECK-DAG: %[[#d3:]] = OpBitcast %[[#double]] %[[#v2i3]]
+  ; CHECK-DAG: %[[#res:]] = OpCompositeConstruct %[[#v4double]] %[[#d0]] %[[#d1]] %[[#d2]] %[[#d3]]
+  %2 = shufflevector <4 x i32> %0, <4 x i32> %1, <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
+
+  ; CHECK: OpStore %[[#GVec4]] %[[#res]]
+  store <8 x i32> %2, ptr addrspace(10) @GVec4, align 32
+  ret void
+}
+
+; Test splitting a vector of size 6. It must be expanded to 8, and then split.
+define internal void @test_bitcast_expand() {
+entry:
+  ; CHECK: %[[#load:]] = OpLoad %[[#v3double]] %[[#GVec3]]
+  %0 = load <3 x double>, ptr addrspace(10) @GVec3, align 32
+
+  ; CHECK: %[[#d0:]] = OpCompositeExtract %[[#double]] %[[#load]] 0
+  ; CHECK: %[[#d1:]] = OpCompositeExtract %[[#double]] %[[#load]] 1
+  ; CHECK: %[[#d2:]] = OpCompositeExtract %[[#double]] %[[#load]] 2
+  ; CHECK: %[[#v2d0:]] = OpCompositeConstruct %[[#v2double]] %[[#d0]] %[[#d1]]
+  ; CHECK: %[[#v2d1:]] = OpCompositeConstruct %[[#v2double]] %[[#d2]] %[[#]]
+  ; CHECK: %[[#v4i0:]] = OpBitcast %[[#v4int]] %[[#v2d0]]
+  ; CHECK: %[[#v4i1:]] = OpBitcast %[[#v4int]] %[[#v2d1]]
+  %1 = bitcast <3 x double> %0 to <6 x i32>
+  
+  ; CHECK: %[[#l0:]] = OpCompositeExtract %[[#int]] %[[#v4i0]] 0
+  ; CHECK: %[[#l1:]] = OpCompositeExtract %[[#int]] %[[#v4i0]] 2
+  ; CHECK: %[[#l2:]] = OpCompositeExtract %[[#int]] %[[#v4i1]] 0
+  ; CHECK: %[[#res_low:]] = OpCompositeConstruct %[[#v3int]] %[[#l0]] %[[#l1]] %[[#l2]]
+  %2 = shufflevector <6 x i32> %1, <6 x i32> poison, <3 x i32> <i32 0, i32 2, i32 4>
+  
+  ; CHECK: %[[#h0:]] = OpCompositeExtract %[[#int]] %[[#v4i0]] 1
+  ; CHECK: %[[#h1:]] = OpCompositeExtract %[[#int]] %[[#v4i0]] 3
+  ; CHECK: %[[#h2:]] = OpCompositeExtract %[[#int]] %[[#v4i1]] 1
+  ; CHECK: %[[#res_high:]] = OpCompositeConstruct %[[#v3int]] %[[#h0]] %[[#h1]] %[[#h2]]
+  %3 = shufflevector <6 x i32> %1, <6 x i32> poison, <3 x i32> <i32 1, i32 3, i32 5>
+
+  ; CHECK: OpStore %[[#Lows3]] %[[#res_low]]
+  store <3 x i32> %2, ptr addrspace(10) @Lows3, align 16
+
+  ; CHECK: OpStore %[[#Highs3]] %[[#res_high]]
+  store <3 x i32> %3, ptr addrspace(10) @Highs3, align 16
+  ret void
+}
+
+define void @main() local_unnamed_addr #0 {
+entry:
+  call void @test_split()
+  call void @test_recombine()
+  call void @test_bitcast_expand()
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

>From fba15c44ccc1058cf2d63f5ab53b5468e59077c4 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Mon, 1 Dec 2025 11:39:36 -0500
Subject: [PATCH 3/3] Add comment explaining the vector size restrictions.

---
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index b46c61a2ed953..32739d7e5e87f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -160,6 +160,12 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
         return IsExtendedInts && Ty.isValid();
       };
 
+  // The universal validation rules in the SPIR-V specification state that
+  // vector sizes are typically limited to 2, 3, or 4. However, larger vector
+  // sizes (8 and 16) are enabled when the Kernel capability is present. For
+  // shader execution models, vector sizes are strictly limited to 4. In
+  // non-shader contexts, vector sizes of 8 and 16 are also permitted, but
+  // arbitrary sizes (e.g., 6 or 11) are not.
   uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
 
   for (auto Opc : getTypeFoldingSupportedOpcodes()) {



More information about the llvm-commits mailing list