[llvm] matrix mul and transpose (PR #172050)

via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 12 09:29:21 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Steven Perron (s-perron)

<details>
<summary>Changes</summary>

- **[SPIR-V] Legalize vector arithmetic and intrinsics for large vectors**
- **Add s64 to all scalars.**
- **Fix alignment in stores generated in legalize pointer cast.**
- **Set all def to default type in post legalization.**
- **Update tests, and add tests for 6-element vectors**
- **Handle G_STRICT_FMA**
- **Scalarize rem and div with illegal vector size that is not a power of 2.**
- **[SPIRV] Implement lowering for llvm.matrix.transpose and llvm.matrix.multiply**


---

Patch is 89.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172050.diff


11 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVCombine.td (+16-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVCombinerHelper.cpp (+189) 
- (modified) llvm/lib/Target/SPIRV/SPIRVCombinerHelper.h (+21) 
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp (+5-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+104-40) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp (+100-68) 
- (added) llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll (+218) 
- (added) llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll (+224) 
- (added) llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic.ll (+299) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-multiply.ll (+168) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-transpose.ll (+124) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVCombine.td b/llvm/lib/Target/SPIRV/SPIRVCombine.td
index 991a5de1c4e83..7d69465de4ffb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCombine.td
+++ b/llvm/lib/Target/SPIRV/SPIRVCombine.td
@@ -22,8 +22,22 @@ def vector_select_to_faceforward_lowering : GICombineRule <
   (apply [{ Helper.applySPIRVFaceForward(*${root}); }])
 >;
 
+def matrix_transpose_lowering
+    : GICombineRule<(defs root:$root),
+                    (match (wip_match_opcode G_INTRINSIC):$root,
+                        [{ return Helper.matchMatrixTranspose(*${root}); }]),
+                    (apply [{ Helper.applyMatrixTranspose(*${root}); }])>;
+
+def matrix_multiply_lowering
+    : GICombineRule<(defs root:$root),
+                    (match (wip_match_opcode G_INTRINSIC):$root,
+                        [{ return Helper.matchMatrixMultiply(*${root}); }]),
+                    (apply [{ Helper.applyMatrixMultiply(*${root}); }])>;
+
 def SPIRVPreLegalizerCombiner
     : GICombiner<"SPIRVPreLegalizerCombinerImpl",
-                       [vector_length_sub_to_distance_lowering, vector_select_to_faceforward_lowering]> {
-    let CombineAllMethodName = "tryCombineAllImpl";
+                 [vector_length_sub_to_distance_lowering,
+                  vector_select_to_faceforward_lowering,
+                  matrix_transpose_lowering, matrix_multiply_lowering]> {
+  let CombineAllMethodName = "tryCombineAllImpl";
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.cpp b/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.cpp
index fad2b676fee04..693b74c1e06d7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.cpp
@@ -7,9 +7,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRVCombinerHelper.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVUtils.h"
 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/IR/LLVMContext.h" // Explicitly include for LLVMContext
 #include "llvm/Target/TargetMachine.h"
 
 using namespace llvm;
@@ -209,3 +213,188 @@ void SPIRVCombinerHelper::applySPIRVFaceForward(MachineInstr &MI) const {
   GR->invalidateMachineInstr(FalseInstr);
   FalseInstr->eraseFromParent();
 }
+
+bool SPIRVCombinerHelper::matchMatrixTranspose(MachineInstr &MI) const {
+  return MI.getOpcode() == TargetOpcode::G_INTRINSIC &&
+         cast<GIntrinsic>(MI).getIntrinsicID() == Intrinsic::matrix_transpose;
+}
+
+void SPIRVCombinerHelper::applyMatrixTranspose(MachineInstr &MI) const {
+  Register ResReg = MI.getOperand(0).getReg();
+  Register InReg = MI.getOperand(2).getReg();
+  uint32_t Rows = MI.getOperand(3).getImm();
+  uint32_t Cols = MI.getOperand(4).getImm();
+
+  Builder.setInstrAndDebugLoc(MI);
+
+  if (Rows == 1 && Cols == 1) {
+    Builder.buildCopy(ResReg, InReg);
+    MI.eraseFromParent();
+    return;
+  }
+
+  SmallVector<int, 16> Mask;
+  for (uint32_t K = 0; K < Rows * Cols; ++K) {
+    uint32_t R = K / Cols;
+    uint32_t C = K % Cols;
+    Mask.push_back(C * Rows + R);
+  }
+
+  Builder.buildShuffleVector(ResReg, InReg, InReg, Mask);
+  MI.eraseFromParent();
+}
+
+bool SPIRVCombinerHelper::matchMatrixMultiply(MachineInstr &MI) const {
+  return MI.getOpcode() == TargetOpcode::G_INTRINSIC &&
+         cast<GIntrinsic>(MI).getIntrinsicID() == Intrinsic::matrix_multiply;
+}
+
+SmallVector<Register, 4>
+SPIRVCombinerHelper::extractColumns(Register MatrixReg, uint32_t NumberOfCols,
+                                    SPIRVType *SpvColType,
+                                    SPIRVGlobalRegistry *GR) const {
+  // If the matrix is a single column, return that single column.
+  if (NumberOfCols == 1)
+    return {MatrixReg};
+
+  SmallVector<Register, 4> Cols;
+  LLT ColTy = GR->getRegType(SpvColType);
+  for (uint32_t J = 0; J < NumberOfCols; ++J)
+    Cols.push_back(MRI.createGenericVirtualRegister(ColTy));
+  Builder.buildUnmerge(Cols, MatrixReg);
+  for (Register R : Cols) {
+    setRegClassType(R, SpvColType, GR, &MRI, Builder.getMF());
+  }
+  return Cols;
+}
+
+SmallVector<Register, 4>
+SPIRVCombinerHelper::extractRows(Register MatrixReg, uint32_t NumRows,
+                                 uint32_t NumCols, SPIRVType *SpvRowType,
+                                 SPIRVGlobalRegistry *GR) const {
+  SmallVector<Register, 4> Rows;
+  LLT VecTy = GR->getRegType(SpvRowType);
+
+  // If there is only one column, then each row is a scalar that needs
+  // to be extracted.
+  if (NumCols == 1) {
+    assert(SpvRowType->getOpcode() != SPIRV::OpTypeVector);
+    for (uint32_t I = 0; I < NumRows; ++I)
+      Rows.push_back(MRI.createGenericVirtualRegister(VecTy));
+    Builder.buildUnmerge(Rows, MatrixReg);
+    for (Register R : Rows) {
+      setRegClassType(R, SpvRowType, GR, &MRI, Builder.getMF());
+    }
+    return Rows;
+  }
+
+  // If the matrix is a single row return that row.
+  if (NumRows == 1) {
+    return {MatrixReg};
+  }
+
+  for (uint32_t I = 0; I < NumRows; ++I) {
+    SmallVector<int, 4> Mask;
+    for (uint32_t k = 0; k < NumCols; ++k)
+      Mask.push_back(k * NumRows + I);
+    Rows.push_back(Builder.buildShuffleVector(VecTy, MatrixReg, MatrixReg, Mask)
+                       .getReg(0));
+  }
+  for (Register R : Rows) {
+    setRegClassType(R, SpvRowType, GR, &MRI, Builder.getMF());
+  }
+  return Rows;
+}
+
+Register SPIRVCombinerHelper::computeDotProduct(Register RowA, Register ColB,
+                                                SPIRVType *SpvVecType,
+                                                SPIRVGlobalRegistry *GR) const {
+  bool IsVectorOp = SpvVecType->getOpcode() == SPIRV::OpTypeVector;
+  SPIRVType *SpvScalarType = GR->getScalarOrVectorComponentType(SpvVecType);
+  bool IsFloatOp = SpvScalarType->getOpcode() == SPIRV::OpTypeFloat;
+  LLT VecTy = GR->getRegType(SpvVecType);
+
+  Register DotRes;
+  if (IsVectorOp) {
+    LLT ScalarTy = VecTy.getElementType();
+    Intrinsic::SPVIntrinsics DotIntrinsic =
+        (IsFloatOp ? Intrinsic::spv_fdot : Intrinsic::spv_udot);
+    DotRes = Builder.buildIntrinsic(DotIntrinsic, {ScalarTy})
+                 .addUse(RowA)
+                 .addUse(ColB)
+                 .getReg(0);
+  } else {
+    if (IsFloatOp)
+      DotRes = Builder.buildFMul(VecTy, RowA, ColB).getReg(0);
+    else
+      DotRes = Builder.buildMul(VecTy, RowA, ColB).getReg(0);
+  }
+  setRegClassType(DotRes, SpvScalarType, GR, &MRI, Builder.getMF());
+  return DotRes;
+}
+
+SmallVector<Register, 16>
+SPIRVCombinerHelper::computeDotProducts(const SmallVector<Register, 4> &RowsA,
+                                        const SmallVector<Register, 4> &ColsB,
+                                        SPIRVType *SpvVecType,
+                                        SPIRVGlobalRegistry *GR) const {
+  SmallVector<Register, 16> ResultScalars;
+  for (uint32_t J = 0; J < ColsB.size(); ++J) {
+    for (uint32_t I = 0; I < RowsA.size(); ++I) {
+      ResultScalars.push_back(
+          computeDotProduct(RowsA[I], ColsB[J], SpvVecType, GR));
+    }
+  }
+  return ResultScalars;
+}
+
+SPIRVType *
+SPIRVCombinerHelper::getDotProductVectorType(Register ResReg, uint32_t K,
+                                             SPIRVGlobalRegistry *GR) const {
+  // Loop over all non-debug uses of ResReg
+  Type *ScalarResType = nullptr;
+  for (auto &UseMI : MRI.use_instructions(ResReg)) {
+    if (UseMI.getOpcode() != TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
+      continue;
+
+    if (!isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type))
+      continue;
+
+    Type *Ty = getMDOperandAsType(UseMI.getOperand(2).getMetadata(), 0);
+    if (Ty->isVectorTy())
+      ScalarResType = cast<VectorType>(Ty)->getElementType();
+    else
+      ScalarResType = Ty;
+    assert(ScalarResType->isIntegerTy() || ScalarResType->isFloatingPointTy());
+    break;
+  }
+  Type *VecType =
+      (K > 1 ? FixedVectorType::get(ScalarResType, K) : ScalarResType);
+  return GR->getOrCreateSPIRVType(VecType, Builder,
+                                  SPIRV::AccessQualifier::None, false);
+}
+
+void SPIRVCombinerHelper::applyMatrixMultiply(MachineInstr &MI) const {
+  Register ResReg = MI.getOperand(0).getReg();
+  Register AReg = MI.getOperand(2).getReg();
+  Register BReg = MI.getOperand(3).getReg();
+  uint32_t NumRowsA = MI.getOperand(4).getImm();
+  uint32_t NumColsA = MI.getOperand(5).getImm();
+  uint32_t NumColsB = MI.getOperand(6).getImm();
+
+  Builder.setInstrAndDebugLoc(MI);
+
+  SPIRVGlobalRegistry *GR =
+      MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();
+
+  SPIRVType *SpvVecType = getDotProductVectorType(ResReg, NumColsA, GR);
+  SmallVector<Register, 4> ColsB =
+      extractColumns(BReg, NumColsB, SpvVecType, GR);
+  SmallVector<Register, 4> RowsA =
+      extractRows(AReg, NumRowsA, NumColsA, SpvVecType, GR);
+  SmallVector<Register, 16> ResultScalars =
+      computeDotProducts(RowsA, ColsB, SpvVecType, GR);
+
+  Builder.buildBuildVector(ResReg, ResultScalars);
+  MI.eraseFromParent();
+}
\ No newline at end of file
diff --git a/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.h b/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.h
index 3118cdc744b8f..b6b3b36f03ade 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.h
+++ b/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.h
@@ -33,6 +33,27 @@ class SPIRVCombinerHelper : public CombinerHelper {
   void applySPIRVDistance(MachineInstr &MI) const;
   bool matchSelectToFaceForward(MachineInstr &MI) const;
   void applySPIRVFaceForward(MachineInstr &MI) const;
+  bool matchMatrixTranspose(MachineInstr &MI) const;
+  void applyMatrixTranspose(MachineInstr &MI) const;
+  bool matchMatrixMultiply(MachineInstr &MI) const;
+  void applyMatrixMultiply(MachineInstr &MI) const;
+
+private:
+  SPIRVType *getDotProductVectorType(Register ResReg, uint32_t K,
+                                     SPIRVGlobalRegistry *GR) const;
+  SmallVector<Register, 4> extractColumns(Register BReg, uint32_t N,
+                                          SPIRVType *SpvVecType,
+                                          SPIRVGlobalRegistry *GR) const;
+  SmallVector<Register, 4> extractRows(Register AReg, uint32_t NumRows,
+                                       uint32_t NumCols, SPIRVType *SpvRowType,
+                                       SPIRVGlobalRegistry *GR) const;
+  SmallVector<Register, 16>
+  computeDotProducts(const SmallVector<Register, 4> &RowsA,
+                     const SmallVector<Register, 4> &ColsB,
+                     SPIRVType *SpvVecType, SPIRVGlobalRegistry *GR) const;
+  Register computeDotProduct(Register RowA, Register ColB,
+                             SPIRVType *SpvVecType,
+                             SPIRVGlobalRegistry *GR) const;
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 81c7596530ee2..a794f3e9c5363 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -168,6 +168,9 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     assert(VecTy->getElementType() == ArrTy->getElementType() &&
            "Element types of array and vector must be the same.");
 
+    const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
+    uint64_t ElemSize = DL.getTypeAllocSize(ArrTy->getElementType());
+
     for (unsigned i = 0; i < VecTy->getNumElements(); ++i) {
       // Create a GEP to access the i-th element of the array.
       SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
@@ -190,7 +193,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
       buildAssignType(B, VecTy->getElementType(), Element);
 
       Types = {Element->getType(), ElementPtr->getType()};
-      Args = {Element, ElementPtr, B.getInt16(2), B.getInt8(Alignment.value())};
+      Align NewAlign = commonAlignment(Alignment, i * ElemSize);
+      Args = {Element, ElementPtr, B.getInt16(2), B.getInt8(NewAlign.value())};
       B.CreateIntrinsic(Intrinsic::spv_store, {Types}, {Args});
     }
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 30703ee40be06..a7626c3d43c34 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -114,6 +114,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                            v3s1, v3s8, v3s16, v3s32, v3s64,
                            v4s1, v4s8, v4s16, v4s32, v4s64};
 
+  auto allScalars = {s1, s8, s16, s32, s64};
+
   auto allScalarsAndVectors = {
       s1,    s8,    s16,   s32,   s64,    s128,   v2s1,  v2s8,
       v2s16, v2s32, v2s64, v3s1,  v3s8,   v3s16,  v3s32, v3s64,
@@ -171,12 +173,48 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   // non-shader contexts, vector sizes of 8 and 16 are also permitted, but
   // arbitrary sizes (e.g., 6 or 11) are not.
   uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
+  LLVM_DEBUG(dbgs() << "MaxVectorSize: " << MaxVectorSize << "\n");
 
   for (auto Opc : getTypeFoldingSupportedOpcodes()) {
-    if (Opc != G_EXTRACT_VECTOR_ELT)
-      getActionDefinitionsBuilder(Opc).custom();
+    switch (Opc) {
+    case G_EXTRACT_VECTOR_ELT:
+    case G_UREM:
+    case G_SREM:
+    case G_UDIV:
+    case G_SDIV:
+    case G_FREM:
+      break;
+    default:
+      getActionDefinitionsBuilder(Opc)
+          .customFor(allScalars)
+          .customFor(allowedVectorTypes)
+          .moreElementsToNextPow2(0)
+          .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                           LegalizeMutations::changeElementCountTo(
+                               0, ElementCount::getFixed(MaxVectorSize)))
+          .custom();
+      break;
+    }
   }
 
+  getActionDefinitionsBuilder({G_UREM, G_SREM, G_SDIV, G_UDIV, G_FREM})
+      .customFor(allScalars)
+      .customFor(allowedVectorTypes)
+      .scalarizeIf(numElementsNotPow2(0), 0)
+      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(
+                           0, ElementCount::getFixed(MaxVectorSize)))
+      .custom();
+
+  getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
+      .legalFor(allScalars)
+      .legalFor(allowedVectorTypes)
+      .moreElementsToNextPow2(0)
+      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(
+                           0, ElementCount::getFixed(MaxVectorSize)))
+      .alwaysLegal();
+
   getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();
 
   getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
@@ -184,8 +222,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       .moreElementsToNextPow2(0)
       .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
       .moreElementsToNextPow2(1)
-      .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
-      .alwaysLegal();
+      .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize));
 
   getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
       .moreElementsToNextPow2(1)
@@ -194,6 +231,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                            1, ElementCount::getFixed(MaxVectorSize)))
       .custom();
 
+  getActionDefinitionsBuilder(G_INSERT_VECTOR_ELT)
+      .moreElementsToNextPow2(0)
+      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(
+                           0, ElementCount::getFixed(MaxVectorSize)))
+      .custom();
+
   // Illegal G_UNMERGE_VALUES instructions should be handled
   // during the combine phase.
   getActionDefinitionsBuilder(G_BUILD_VECTOR)
@@ -217,14 +261,12 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
       .custom();
 
+  // If the result is still illegal, the combiner should be able to remove it.
   getActionDefinitionsBuilder(G_CONCAT_VECTORS)
-      .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
-      .moreElementsToNextPow2(0)
-      .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
-      .alwaysLegal();
+      .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes);
 
   getActionDefinitionsBuilder(G_SPLAT_VECTOR)
-      .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
+      .legalFor(allowedVectorTypes)
       .moreElementsToNextPow2(0)
       .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                        LegalizeMutations::changeElementSizeTo(0, MaxVectorSize))
@@ -273,9 +315,6 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       .legalFor(allIntScalarsAndVectors)
       .legalIf(extendedScalarsAndVectors);
 
-  getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
-      .legalFor(allFloatScalarsAndVectors);
-
   getActionDefinitionsBuilder(G_STRICT_FLDEXP)
       .legalForCartesianProduct(allFloatScalarsAndVectors, allIntScalars);
 
@@ -461,6 +500,23 @@ static bool legalizeExtractVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
   return true;
 }
 
+static bool legalizeInsertVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
+                                    SPIRVGlobalRegistry *GR) {
+  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
+  Register DstReg = MI.getOperand(0).getReg();
+  Register SrcReg = MI.getOperand(1).getReg();
+  Register ValReg = MI.getOperand(2).getReg();
+  Register IdxReg = MI.getOperand(3).getReg();
+
+  MIRBuilder
+      .buildIntrinsic(Intrinsic::spv_insertelt, ArrayRef<Register>{DstReg})
+      .addUse(SrcReg)
+      .addUse(ValReg)
+      .addUse(IdxReg);
+  MI.eraseFromParent();
+  return true;
+}
+
 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
                                 LegalizerHelper &Helper,
                                 MachineRegisterInfo &MRI,
@@ -486,6 +542,8 @@ bool SPIRVLegalizerInfo::legalizeCustom(
     return legalizeBitcast(Helper, MI);
   case TargetOpcode::G_EXTRACT_VECTOR_ELT:
     return legalizeExtractVectorElt(Helper, MI, GR);
+  case TargetOpcode::G_INSERT_VECTOR_ELT:
+    return legalizeInsertVectorElt(Helper, MI, GR);
   case TargetOpcode::G_INTRINSIC:
   case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
     return legalizeIntrinsic(Helper, MI);
@@ -515,6 +573,15 @@ bool SPIRVLegalizerInfo::legalizeCustom(
   }
 }
 
+static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) {
+  if (!Ty.isVector())
+    return false;
+  unsigned NumElements = Ty.getNumElements();
+  unsigned MaxVectorSize = ST.isShader() ? 4 : 16;
+  return (NumElements > 4 && !isPowerOf2_32(NumElements)) ||
+         NumElements > MaxVectorSize;
+}
+
 bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
                                            MachineInstr &MI) const {
   LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
@@ -531,41 +598,38 @@ bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
     LLT DstTy = MRI.getType(DstReg);
     LLT SrcTy = MRI.getType(SrcReg);
 
-    int32_t MaxVectorSize = ST.isShader() ? 4 : 16;
-
-    bool DstNeedsLegalization = false;
-    bool SrcNeedsLegalization = false;
-
-    if (DstTy.isVector()) {
-      if (DstTy.getNumElements() > 4 &&
-          !isPowerOf2_32(DstTy.getNumElements())) {
-        DstNeedsLegalization = true;
-      }
-
-      if (DstTy.getNumElements() > MaxVectorSize) {
-        DstNeedsLegalization = true;
-      }
-    }
-
-    if (SrcTy.isVector()) {
-      if (SrcTy.getNumElements() > 4 &&
-          !isPowerOf2_32(SrcTy.getNumElements())) {
-        SrcNeedsLegalization = true;
-      }
-
-      if (SrcTy.getNumElements() > MaxVectorSize) {
-        SrcNeedsLegalization = true;
-      }
-    }
-
     // If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
     // allow using the generic legalization rules.
-    if (DstNeedsLegalization || SrcNeedsLegalization) {
+    if (needsVectorLegalization(DstTy, ST) ||
+        needsVectorLegalization(SrcTy, ST)) {
       LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/172050


More information about the llvm-commits mailing list