[llvm] [SPIRV] Legalize long vectors in GlobalISel (PR #164634)

Steven Perron via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 22 07:53:51 PDT 2025


https://github.com/s-perron updated https://github.com/llvm/llvm-project/pull/164634

>From 50f19ff74bd0ec15e0fc7e489a5bb380d9fd81fc Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 7 Oct 2025 13:26:47 -0400
Subject: [PATCH] [SPIRV] Legalize long vectors in GlobalISel

This commit introduces support for legalizing long vectors (vectors with more
than 4 elements) in the SPIR-V backend using GlobalISel. This is primarily
for shader compilation where the GLSL_std_450 instruction set is available.

The main changes include:

- Adding legalization rules for vector operations (G_BUILD_VECTOR,
  G_SHUFFLE_VECTOR, G_EXTRACT_VECTOR_ELT, G_BITCAST, G_CONCAT_VECTORS)
  to split vectors with more than 4 elements into smaller vectors.
- Enhancing the SPIRVPostLegalizer with a worklist-based approach to
  correctly process instructions and types generated during legalization.
- Lowering G_EXTRACT_VECTOR_ELT to a spv_extractelt intrinsic.
- Refining the handling of G_BITCAST to legalize non-pointer bitcasts.
- Marking SPIR-V type and constant instructions as pure to aid optimization.
---
 .../llvm/CodeGen/GlobalISel/LegalizerInfo.h   |   5 +
 .../CodeGen/GlobalISel/LegalityPredicates.cpp |  10 +
 llvm/lib/Target/SPIRV/SPIRVInstrFormats.td    |   5 +
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       | 179 +++++++----
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  50 ++-
 .../Target/SPIRV/SPIRVLegalizePointerCast.cpp |   5 +-
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  |  55 +++-
 llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp  | 303 +++++++++++++++---
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   |  26 +-
 9 files changed, 493 insertions(+), 145 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index 51318c9c2736d..7cce0ae5359b6 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -314,6 +314,11 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
 LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx,
                                                    unsigned Size);
 
+/// True iff the specified type index is a fixed-length vector whose number of
+/// elements is greater than the given size.
+LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx,
+                                                           unsigned Size);
+
 /// True iff the specified type index is a scalar or a vector with an element
 /// type that's wider than the given size.
 LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx,
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
index 30c2d089c3121..757a1fdba7fbe 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
@@ -155,6 +155,16 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
   };
 }
 
+LegalityPredicate
+LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx,
+                                                    unsigned Size) {
+
+  return [=](const LegalityQuery &Query) {
+    const LLT QueryTy = Query.Types[TypeIdx];
+    return QueryTy.isFixedVector() && QueryTy.getNumElements() > Size;
+  };
+}
+
 LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx,
                                                            unsigned Size) {
   return [=](const LegalityQuery &Query) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
index 2fde2b0bc0b1f..f93240dc35993 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
@@ -25,6 +25,11 @@ class Op<bits<16> Opcode, dag outs, dag ins, string asmstr, list<dag> pattern =
   let Pattern = pattern;
 }
 
+class PureOp<bits<16> Opcode, dag outs, dag ins, string asmstr,
+             list<dag> pattern = []> : Op<Opcode, outs, ins, asmstr, pattern> {
+  let hasSideEffects = 0;
+}
+
 class UnknownOp<dag outs, dag ins, string asmstr, list<dag> pattern = []>
     : Op<0, outs, ins, asmstr, pattern> {
   let isPseudo = 1;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index a61351eba03f8..799a82c96b0f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -163,52 +163,74 @@ def OpExecutionModeId: Op<331, (outs), (ins ID:$entry, ExecutionMode:$mode, vari
 
 // 3.42.6 Type-Declaration Instructions
 
-def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
-def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
-def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
-                  "$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
-                  "$type = OpTypeFloat $width">;
-def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
-                  "$type = OpTypeVector $compType $compCount">;
-def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
-                  "$type = OpTypeMatrix $colType $colCount">;
-def OpTypeImage: Op<25, (outs TYPE:$res), (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth,
-      i32imm:$arrayed, i32imm:$MS, i32imm:$sampled, ImageFormat:$imFormat, variable_ops),
-                  "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS $sampled $imFormat">;
-def OpTypeSampler: Op<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">;
-def OpTypeSampledImage: Op<27, (outs TYPE:$res), (ins TYPE:$imageType),
-                  "$res = OpTypeSampledImage $imageType">;
-def OpTypeArray: Op<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length),
-                  "$type = OpTypeArray $elementType $length">;
-def OpTypeRuntimeArray: Op<29, (outs TYPE:$type), (ins TYPE:$elementType),
-                  "$type = OpTypeRuntimeArray $elementType">;
-def OpTypeStruct: Op<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">;
-def OpTypeStructContinuedINTEL: Op<6090, (outs), (ins variable_ops),
-                  "OpTypeStructContinuedINTEL">;
-def OpTypeOpaque: Op<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops),
-                  "$res = OpTypeOpaque $name">;
-def OpTypePointer: Op<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type),
-                  "$res = OpTypePointer $storage $type">;
-def OpTypeFunction: Op<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops),
-                  "$funcType = OpTypeFunction $returnType">;
-def OpTypeEvent: Op<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">;
-def OpTypeDeviceEvent: Op<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">;
-def OpTypeReserveId: Op<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">;
-def OpTypeQueue: Op<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">;
-def OpTypePipe: Op<38, (outs TYPE:$res), (ins AccessQualifier:$a), "$res = OpTypePipe $a">;
-def OpTypeForwardPointer: Op<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass),
-                  "OpTypeForwardPointer $ptrType $storageClass">;
-def OpTypePipeStorage: Op<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">;
-def OpTypeNamedBarrier: Op<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">;
-def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins),
-                  "$res = OpTypeAccelerationStructureNV">;
-def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res),
-                  (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
-                  "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
-def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res),
-                  (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
-                  "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">;
+def OpTypeVoid : PureOp<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
+def OpTypeBool : PureOp<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
+def OpTypeInt
+    : PureOp<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
+             "$type = OpTypeInt $width $signedness">;
+def OpTypeFloat
+    : PureOp<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
+             "$type = OpTypeFloat $width">;
+def OpTypeVector
+    : PureOp<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
+             "$type = OpTypeVector $compType $compCount">;
+def OpTypeMatrix
+    : PureOp<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
+             "$type = OpTypeMatrix $colType $colCount">;
+def OpTypeImage : PureOp<25, (outs TYPE:$res),
+                         (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth,
+                             i32imm:$arrayed, i32imm:$MS, i32imm:$sampled,
+                             ImageFormat:$imFormat, variable_ops),
+                         "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS "
+                         "$sampled $imFormat">;
+def OpTypeSampler : PureOp<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">;
+def OpTypeSampledImage : PureOp<27, (outs TYPE:$res), (ins TYPE:$imageType),
+                                "$res = OpTypeSampledImage $imageType">;
+def OpTypeArray
+    : PureOp<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length),
+             "$type = OpTypeArray $elementType $length">;
+def OpTypeRuntimeArray : PureOp<29, (outs TYPE:$type), (ins TYPE:$elementType),
+                                "$type = OpTypeRuntimeArray $elementType">;
+def OpTypeStruct
+    : PureOp<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">;
+def OpTypeStructContinuedINTEL
+    : PureOp<6090, (outs), (ins variable_ops), "OpTypeStructContinuedINTEL">;
+def OpTypeOpaque
+    : PureOp<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops),
+             "$res = OpTypeOpaque $name">;
+def OpTypePointer
+    : PureOp<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type),
+             "$res = OpTypePointer $storage $type">;
+def OpTypeFunction
+    : PureOp<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops),
+             "$funcType = OpTypeFunction $returnType">;
+def OpTypeEvent : PureOp<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">;
+def OpTypeDeviceEvent
+    : PureOp<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">;
+def OpTypeReserveId
+    : PureOp<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">;
+def OpTypeQueue : PureOp<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">;
+def OpTypePipe : PureOp<38, (outs TYPE:$res), (ins AccessQualifier:$a),
+                        "$res = OpTypePipe $a">;
+def OpTypeForwardPointer
+    : PureOp<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass),
+             "OpTypeForwardPointer $ptrType $storageClass">;
+def OpTypePipeStorage
+    : PureOp<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">;
+def OpTypeNamedBarrier
+    : PureOp<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">;
+def OpTypeAccelerationStructureNV
+    : PureOp<5341, (outs TYPE:$res), (ins),
+             "$res = OpTypeAccelerationStructureNV">;
+def OpTypeCooperativeMatrixNV
+    : PureOp<5358, (outs TYPE:$res),
+             (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
+             "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
+def OpTypeCooperativeMatrixKHR
+    : PureOp<4456, (outs TYPE:$res),
+             (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
+             "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols "
+             "$use">;
 
 // 3.42.7 Constant-Creation Instructions
 
@@ -222,31 +244,46 @@ defm OpConstant: IntFPImm<43, "OpConstant">;
 
 def ConstPseudoTrue: IntImmLeaf<i64, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 1; }]>;
 def ConstPseudoFalse: IntImmLeaf<i64, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 0; }]>;
-def OpConstantTrue: Op<41, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantTrue $src_ty",
-                      [(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>;
-def OpConstantFalse: Op<42, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantFalse $src_ty",
-                      [(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>;
-
-def OpConstantComposite: Op<44, (outs ID:$res), (ins TYPE:$type, variable_ops),
-                  "$res = OpConstantComposite $type">;
-def OpConstantCompositeContinuedINTEL: Op<6091, (outs), (ins variable_ops),
-                  "OpConstantCompositeContinuedINTEL">;
-
-def OpConstantSampler: Op<45, (outs ID:$res),
-                  (ins TYPE:$t, SamplerAddressingMode:$s, i32imm:$p, SamplerFilterMode:$f),
-                  "$res = OpConstantSampler $t $s $p $f">;
-def OpConstantNull: Op<46, (outs ID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantNull $src_ty">;
-
-def OpSpecConstantTrue: Op<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">;
-def OpSpecConstantFalse: Op<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">;
-def OpSpecConstant: Op<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops),
-                  "$res = OpSpecConstant $type $imm">;
-def OpSpecConstantComposite: Op<51, (outs ID:$res), (ins TYPE:$type, variable_ops),
-                  "$res = OpSpecConstantComposite $type">;
-def OpSpecConstantCompositeContinuedINTEL: Op<6092, (outs), (ins variable_ops),
-                  "OpSpecConstantCompositeContinuedINTEL">;
-def OpSpecConstantOp: Op<52, (outs ID:$res), (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops),
-                  "$res = OpSpecConstantOp $t $c $o">;
+def OpConstantTrue
+    : PureOp<41, (outs iID:$dst), (ins TYPE:$src_ty),
+             "$dst = OpConstantTrue $src_ty",
+             [(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>;
+def OpConstantFalse
+    : PureOp<42, (outs iID:$dst), (ins TYPE:$src_ty),
+             "$dst = OpConstantFalse $src_ty",
+             [(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>;
+
+def OpConstantComposite
+    : PureOp<44, (outs ID:$res), (ins TYPE:$type, variable_ops),
+             "$res = OpConstantComposite $type">;
+def OpConstantCompositeContinuedINTEL
+    : PureOp<6091, (outs), (ins variable_ops),
+             "OpConstantCompositeContinuedINTEL">;
+
+def OpConstantSampler : PureOp<45, (outs ID:$res),
+                               (ins TYPE:$t, SamplerAddressingMode:$s,
+                                   i32imm:$p, SamplerFilterMode:$f),
+                               "$res = OpConstantSampler $t $s $p $f">;
+def OpConstantNull : PureOp<46, (outs ID:$dst), (ins TYPE:$src_ty),
+                            "$dst = OpConstantNull $src_ty">;
+
+def OpSpecConstantTrue
+    : PureOp<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">;
+def OpSpecConstantFalse
+    : PureOp<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">;
+def OpSpecConstant
+    : PureOp<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops),
+             "$res = OpSpecConstant $type $imm">;
+def OpSpecConstantComposite
+    : PureOp<51, (outs ID:$res), (ins TYPE:$type, variable_ops),
+             "$res = OpSpecConstantComposite $type">;
+def OpSpecConstantCompositeContinuedINTEL
+    : PureOp<6092, (outs), (ins variable_ops),
+             "OpSpecConstantCompositeContinuedINTEL">;
+def OpSpecConstantOp
+    : PureOp<52, (outs ID:$res),
+             (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops),
+             "$res = OpSpecConstantOp $t $c $o">;
 
 // 3.42.8 Memory Instructions
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 021353ab716f7..23fe0b5a15041 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1526,33 +1526,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
   unsigned ArgI = I.getNumOperands() - 1;
   Register SrcReg =
       I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
-  SPIRVType *DefType =
+  SPIRVType *SrcType =
       SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
-  if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
+  if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector)
     report_fatal_error(
         "cannot select G_UNMERGE_VALUES with a non-vector argument");
 
   SPIRVType *ScalarType =
-      GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+      GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
   MachineBasicBlock &BB = *I.getParent();
   bool Res = false;
+  unsigned CurrentIndex = 0;
   for (unsigned i = 0; i < I.getNumDefs(); ++i) {
     Register ResVReg = I.getOperand(i).getReg();
     SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
     if (!ResType) {
-      // There was no "assign type" actions, let's fix this now
-      ResType = ScalarType;
+      LLT ResLLT = MRI->getType(ResVReg);
+      assert(ResLLT.isValid());
+      if (ResLLT.isVector()) {
+        ResType = GR.getOrCreateSPIRVVectorType(
+            ScalarType, ResLLT.getNumElements(), I, TII);
+      } else {
+        ResType = ScalarType;
+      }
       MRI->setRegClass(ResVReg, GR.getRegClass(ResType));
-      MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
       GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
     }
-    auto MIB =
-        BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
-            .addDef(ResVReg)
-            .addUse(GR.getSPIRVTypeID(ResType))
-            .addUse(SrcReg)
-            .addImm(static_cast<int64_t>(i));
-    Res |= MIB.constrainAllUses(TII, TRI, RBI);
+
+    if (ResType->getOpcode() == SPIRV::OpTypeVector) {
+      Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII);
+      auto MIB =
+          BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle))
+              .addDef(ResVReg)
+              .addUse(GR.getSPIRVTypeID(ResType))
+              .addUse(SrcReg)
+              .addUse(UndefReg);
+      unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType);
+      for (unsigned j = 0; j < NumElements; ++j) {
+        MIB.addImm(CurrentIndex + j);
+      }
+      CurrentIndex += NumElements;
+      Res |= MIB.constrainAllUses(TII, TRI, RBI);
+    } else {
+      auto MIB =
+          BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+              .addDef(ResVReg)
+              .addUse(GR.getSPIRVTypeID(ResType))
+              .addUse(SrcReg)
+              .addImm(CurrentIndex);
+      CurrentIndex++;
+      Res |= MIB.constrainAllUses(TII, TRI, RBI);
+    }
   }
   return Res;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 28a1690ef0be1..61de82afad389 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -73,7 +73,10 @@ class SPIRVLegalizePointerCast : public FunctionPass {
   // Returns the loaded value.
   Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
                               FixedVectorType *TargetType, Value *Source) {
-    assert(TargetType->getNumElements() <= SourceType->getNumElements());
+    const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
+    [[maybe_unused]] TypeSize TargetTypeSize = DL.getTypeSizeInBits(TargetType);
+    [[maybe_unused]] TypeSize SourceTypeSize = DL.getTypeSizeInBits(SourceType);
+    assert(TargetTypeSize <= SourceTypeSize);
     LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
     buildAssignType(B, SourceType, NewLoad);
     Value *AssignValue = NewLoad;
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 53074ea3b2597..bedb42752d241 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -19,6 +19,7 @@
 #include "llvm/CodeGen/MachineInstr.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
 
 using namespace llvm;
 using namespace llvm::LegalizeActions;
@@ -101,6 +102,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                      v4s64, v8s1,   v8s8,   v8s16, v8s32, v8s64, v16s1,
                      v16s8, v16s16, v16s32, v16s64};
 
+  auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64,
+                           v3s1, v3s8, v3s16, v3s32, v3s64,
+                           v4s1, v4s8, v4s16, v4s32, v4s64};
+
+  [[maybe_unused]] auto allNonShaderVectors = { // Unused; see kernel TODO.
+      v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
+
   auto allScalarsAndVectors = {
       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
       v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
@@ -148,15 +156,46 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
         return IsExtendedInts && Ty.isValid();
       };
 
-  for (auto Opc : getTypeFoldingSupportedOpcodes())
-    getActionDefinitionsBuilder(Opc).custom();
+  // TODO: So far we only legalize vectors for shaders.
+  // We still need to legalize them for kernels, where vector
+  // sizes of 8 and 16 are also allowed.
 
-  getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
+  for (auto Opc : getTypeFoldingSupportedOpcodes()) {
+    if (Opc != G_EXTRACT_VECTOR_ELT)
+      getActionDefinitionsBuilder(Opc).custom();
+  }
 
-  // TODO: add proper rules for vectors legalization.
-  getActionDefinitionsBuilder(
-      {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
-      .alwaysLegal();
+  if (ST.canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450)) {
+    getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
+        .lowerIf(vectorElementCountIsGreaterThan(0, 4))
+        .lowerIf(vectorElementCountIsGreaterThan(1, 4))
+        .legalFor(allShaderVectors);
+    getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
+        .moreElementsToNextPow2(1)
+        .fewerElementsIf(vectorElementCountIsGreaterThan(1, 4),
+                         LegalizeMutations::changeElementCountTo(
+                             1, ElementCount::getFixed(4)))
+        .custom();
+    getActionDefinitionsBuilder(G_BUILD_VECTOR)
+        .legalFor(allShaderVectors)
+        .fewerElementsIf(vectorElementCountIsGreaterThan(0, 4),
+                         LegalizeMutations::changeElementCountTo(
+                             0, ElementCount::getFixed(4)));
+    getActionDefinitionsBuilder(G_BITCAST)
+        .moreElementsToNextPow2(0)
+        .moreElementsToNextPow2(1)
+        .fewerElementsIf(vectorElementCountIsGreaterThan(0, 4),
+                         LegalizeMutations::changeElementCountTo(
+                             0, ElementCount::getFixed(4)))
+        .lowerIf(vectorElementCountIsGreaterThan(1, 4))
+        .alwaysLegal();
+    getActionDefinitionsBuilder(G_CONCAT_VECTORS)
+        .legalFor(allShaderVectors).lower();
+  } else { // G_EXTRACT_VECTOR_ELT was skipped by the type-folding loop above.
+    getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT).custom();
+    getActionDefinitionsBuilder(
+        {G_SHUFFLE_VECTOR, G_BUILD_VECTOR, G_SPLAT_VECTOR}).alwaysLegal();
+  }
 
   // Vector Reduction Operations
   getActionDefinitionsBuilder(
@@ -287,6 +326,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   // Pointer-handling.
   getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
 
+  getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs);
+
   // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
   getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index d17528dd882bf..94d7417868750 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -17,7 +17,8 @@
 #include "SPIRV.h"
 #include "SPIRVSubtarget.h"
 #include "SPIRVUtils.h"
-#include "llvm/IR/Attributes.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Support/Debug.h"
 #include <stack>
 
 #define DEBUG_TYPE "spirv-postlegalizer"
@@ -45,6 +46,10 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
 
 static bool mayBeInserted(unsigned Opcode) {
   switch (Opcode) {
+  case TargetOpcode::G_CONSTANT:
+  case TargetOpcode::G_UNMERGE_VALUES:
+  case TargetOpcode::G_EXTRACT_VECTOR_ELT:
+  case TargetOpcode::G_BITCAST:
   case TargetOpcode::G_SMAX:
   case TargetOpcode::G_UMAX:
   case TargetOpcode::G_SMIN:
@@ -53,70 +58,230 @@ static bool mayBeInserted(unsigned Opcode) {
   case TargetOpcode::G_FMINIMUM:
   case TargetOpcode::G_FMAXNUM:
   case TargetOpcode::G_FMAXIMUM:
+  case TargetOpcode::G_IMPLICIT_DEF:
+  case TargetOpcode::G_BUILD_VECTOR:
     return true;
   default:
     return isTypeFoldingSupported(Opcode);
   }
 }
 
-static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
-                             MachineIRBuilder MIB) {
+static bool processInstr(MachineInstr *I, MachineFunction &MF,
+                         SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB) {
   MachineRegisterInfo &MRI = MF.getRegInfo();
+  const unsigned Opcode = I->getOpcode();
+  Register ResVReg = I->getOperand(0).getReg();
+  SPIRVType *ResType = nullptr;
+  bool Handled = false;
 
-  for (MachineBasicBlock &MBB : MF) {
-    for (MachineInstr &I : MBB) {
-      const unsigned Opcode = I.getOpcode();
-      if (Opcode == TargetOpcode::G_UNMERGE_VALUES) {
-        unsigned ArgI = I.getNumOperands() - 1;
-        Register SrcReg = I.getOperand(ArgI).isReg()
-                              ? I.getOperand(ArgI).getReg()
-                              : Register(0);
-        SPIRVType *DefType =
-            SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr;
-        if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
-          report_fatal_error(
-              "cannot select G_UNMERGE_VALUES with a non-vector argument");
+  switch (Opcode) {
+  case TargetOpcode::G_CONSTANT: {
+    const LLT &Ty = MRI.getType(ResVReg);
+    unsigned BitWidth = Ty.getScalarSizeInBits();
+    ResType = GR->getOrCreateSPIRVIntegerType(BitWidth, MIB);
+    Handled = true;
+    break;
+  }
+  case TargetOpcode::G_UNMERGE_VALUES: {
+    Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg();
+    if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
+      if (DefType->getOpcode() == SPIRV::OpTypeVector) {
         SPIRVType *ScalarType =
             GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
-        for (unsigned i = 0; i < I.getNumDefs(); ++i) {
-          Register ResVReg = I.getOperand(i).getReg();
-          SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
-          if (!ResType) {
-            // There was no "assign type" actions, let's fix this now
-            ResType = ScalarType;
-            setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
+        for (unsigned i = 0; i < I->getNumDefs(); ++i) {
+          Register DefReg = I->getOperand(i).getReg();
+          if (!GR->getSPIRVTypeForVReg(DefReg)) {
+            LLT DefLLT = MRI.getType(DefReg);
+            SPIRVType *ResType;
+            if (DefLLT.isVector()) {
+              const SPIRVInstrInfo *TII =
+                  MF.getSubtarget<SPIRVSubtarget>().getInstrInfo();
+              ResType = GR->getOrCreateSPIRVVectorType(
+                  ScalarType, DefLLT.getNumElements(), *I, *TII);
+            } else {
+              ResType = ScalarType;
+            }
+            setRegClassType(DefReg, ResType, GR, &MRI, MF);
           }
         }
-      } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
-                 I.getNumOperands() > 1 && I.getOperand(1).isReg()) {
-        // Legalizer may have added a new instructions and introduced new
-        // registers, we must decorate them as if they were introduced in a
-        // non-automatic way
-        Register ResVReg = I.getOperand(0).getReg();
-        // Check if the register defined by the instruction is newly generated
-        // or already processed
-        // Check if we have type defined for operands of the new instruction
-        bool IsKnownReg = MRI.getRegClassOrNull(ResVReg);
-        SPIRVType *ResVType = GR->getSPIRVTypeForVReg(
-            IsKnownReg ? ResVReg : I.getOperand(1).getReg());
-        if (!ResVType)
-          continue;
-        // Set type & class
-        if (!IsKnownReg)
-          setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true);
-        // If this is a simple operation that is to be reduced by TableGen
-        // definition we must apply some of pre-legalizer rules here
-        if (isTypeFoldingSupported(Opcode)) {
-          processInstr(I, MIB, MRI, GR, GR->getSPIRVTypeForVReg(ResVReg));
-          if (IsKnownReg && MRI.hasOneUse(ResVReg)) {
-            MachineInstr &UseMI = *MRI.use_instr_begin(ResVReg);
-            if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE)
-              continue;
+        Handled = true;
+      }
+    }
+    break;
+  }
+  case TargetOpcode::G_EXTRACT_VECTOR_ELT: {
+    LLVM_DEBUG(dbgs() << "Processing G_EXTRACT_VECTOR_ELT: " << *I);
+    Register VecReg = I->getOperand(1).getReg();
+    if (SPIRVType *VecType = GR->getSPIRVTypeForVReg(VecReg)) {
+      LLVM_DEBUG(dbgs() << "  Found vector type: " << *VecType << "\n");
+      // LLVM_DEBUG above already prints VecType; avoid MachineInstr::dump(),
+      // which is undefined in release builds without LLVM_ENABLE_DUMP.
+      assert(VecType->getOpcode() == SPIRV::OpTypeVector &&
+             "G_EXTRACT_VECTOR_ELT source operand must have a vector type");
+      ResType = GR->getScalarOrVectorComponentType(VecType);
+      Handled = true;
+    } else {
+      LLVM_DEBUG(dbgs() << "  Vector operand " << VecReg
+                        << " has no type. Looking at uses of " << ResVReg
+                        << ".\n");
+      // If not handled yet, then check if it is used in a G_BUILD_VECTOR.
+      // If so get the type from there.
+      for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+        LLVM_DEBUG(dbgs() << "  Use: " << Use);
+        if (Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
+          LLVM_DEBUG(dbgs() << "    Use is G_BUILD_VECTOR.\n");
+          Register BuildVecResReg = Use.getOperand(0).getReg();
+          if (SPIRVType *BuildVecType =
+                  GR->getSPIRVTypeForVReg(BuildVecResReg)) {
+            LLVM_DEBUG(dbgs() << "    Found G_BUILD_VECTOR result type: "
+                              << *BuildVecType << "\n");
+            ResType = GR->getScalarOrVectorComponentType(BuildVecType);
+            Handled = true;
+            break;
+          } else {
+            LLVM_DEBUG(dbgs() << "    G_BUILD_VECTOR result " << BuildVecResReg
+                              << " has no type yet.\n");
           }
-          insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI);
         }
       }
     }
+    if (!Handled) {
+      LLVM_DEBUG(
+          dbgs() << "  Could not determine type for G_EXTRACT_VECTOR_ELT.\n");
+    }
+    break;
+  }
+  case TargetOpcode::G_BUILD_VECTOR: {
+    // First check if any of the operands have a type.
+    for (unsigned i = 1; i < I->getNumOperands(); ++i) {
+      if (SPIRVType *OpType =
+              GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) {
+        const LLT &ResLLT = MRI.getType(ResVReg);
+        ResType = GR->getOrCreateSPIRVVectorType(
+            OpType, ResLLT.getNumElements(), MIB, false);
+        Handled = true;
+        break;
+      }
+    }
+    if (Handled) {
+      break;
+    }
+    // If that did not work, then check the uses.
+    for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+      if (Use.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
+        Register ExtractResReg = Use.getOperand(0).getReg();
+        if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) {
+          const LLT &ResLLT = MRI.getType(ResVReg);
+          ResType = GR->getOrCreateSPIRVVectorType(
+              ScalarType, ResLLT.getNumElements(), MIB, false);
+          Handled = true;
+          break;
+        }
+      }
+    }
+    break;
+  }
+  case TargetOpcode::G_IMPLICIT_DEF: {
+    for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+      const unsigned UseOpc = Use.getOpcode();
+      assert(UseOpc == TargetOpcode::G_BUILD_VECTOR ||
+             UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+      // The use may be unprocessed; infer the type from its register operands.
+      for (unsigned i = 1; i < Use.getNumOperands(); ++i) {
+        if (!Use.getOperand(i).isReg()) // e.g. G_SHUFFLE_VECTOR's mask.
+          continue;
+        if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg())) {
+          ResType = Type;
+          Handled = true;
+          break;
+        }
+      }
+      if (Handled)
+        break;
+    }
+    break;
+  }
+  case TargetOpcode::G_BITCAST: {
+    for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+      const unsigned UseOpc = Use.getOpcode();
+      assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
+             UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+      Register UseResultReg = Use.getOperand(0).getReg();
+      if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) {
+        SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType);
+        const LLT &BitcastLLT = MRI.getType(ResVReg);
+        if (BitcastLLT.isVector()) {
+          ResType = GR->getOrCreateSPIRVVectorType(
+              ScalarType, BitcastLLT.getNumElements(), MIB, false);
+        } else {
+          ResType = ScalarType;
+        }
+        Handled = true;
+        break;
+      }
+    }
+    break;
+  }
+  default:
+    if (I->getNumDefs() == 1 && I->getNumOperands() > 1 &&
+        I->getOperand(1).isReg()) {
+      if (SPIRVType *OpType =
+              GR->getSPIRVTypeForVReg(I->getOperand(1).getReg())) {
+        ResType = OpType;
+        Handled = true;
+      }
+    }
+    break;
+  }
+
+  if (Handled && ResType) {
+    LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType << "\n");
+    GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF);
+  }
+  return Handled;
+}
+
+static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
+                             MachineIRBuilder MIB) {
+  SmallVector<MachineInstr *, 8> Worklist;
+
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &I : MBB) {
+      if (I.getNumDefs() > 0 &&
+          !GR->getSPIRVTypeForVReg(I.getOperand(0).getReg()) &&
+          mayBeInserted(I.getOpcode())) {
+        Worklist.push_back(&I);
+      }
+    }
+  }
+
+  if (Worklist.empty()) {
+    return;
+  }
+
+  LLVM_DEBUG(dbgs() << "Initial worklist:\n";
+             for (auto *I : Worklist) { I->dump(); });
+
+  bool Changed = true;
+  while (Changed) {
+    Changed = false;
+    SmallVector<MachineInstr *, 8> NextWorklist;
+
+    for (MachineInstr *I : Worklist) {
+      if (processInstr(I, MF, GR, MIB)) {
+        Changed = true;
+      } else {
+        NextWorklist.push_back(I);
+      }
+    }
+    Worklist = NextWorklist;
+    LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n");
+  }
+
+  if (!Worklist.empty()) {
+    LLVM_DEBUG(dbgs() << "Remaining worklist:\n";
+               for (auto *I : Worklist) { I->dump(); });
+    assert(Worklist.empty() && "Worklist is not empty");
   }
 }
 
@@ -159,6 +324,46 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
 
   processNewInstrs(MF, GR, MIB);
 
+  // TODO: Move this into its own function.
+  SmallVector<MachineInstr *, 8> ExtractInstrs;
+  SmallVector<MachineInstr *, 8> BitcastInstrs;
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      if (MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
+        ExtractInstrs.push_back(&MI);
+      } else if (MI.getOpcode() == TargetOpcode::G_BITCAST) {
+        BitcastInstrs.push_back(&MI);
+      }
+    }
+  }
+
+  for (MachineInstr *MI : ExtractInstrs) {
+    MachineIRBuilder MIB(*MI);
+    Register Dst = MI->getOperand(0).getReg();
+    Register Vec = MI->getOperand(1).getReg();
+    Register Idx = MI->getOperand(2).getReg();
+
+    auto Intr =
+        MIB.buildIntrinsic(Intrinsic::spv_extractelt, Dst, false, false);
+    Intr.addUse(Vec);
+    Intr.addUse(Idx);
+
+    MI->eraseFromParent();
+  }
+
+  for (MachineInstr *MI : BitcastInstrs) {
+    MachineIRBuilder MIB(*MI);
+    Register Dst = MI->getOperand(0).getReg();
+    Register Src = MI->getOperand(1).getReg();
+    SPIRVType *DstType = GR->getSPIRVTypeForVReg(Dst);
+    assert(DstType && "Destination of G_BITCAST must have a type");
+    MIB.buildInstr(SPIRV::OpBitcast)
+        .addDef(Dst)
+        .addUse(GR->getSPIRVTypeID(DstType))
+        .addUse(Src);
+    MI->eraseFromParent();
+  }
+
   return true;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index db6f2d61e8f29..b4b9771b0b50a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -192,6 +192,10 @@ static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
         .addUse(OpReg);
 }
 
+// TODO: See if the comment below needs to be more precise. This is a problem
+// for more than just pointers: a bitcast between two types that map to the
+// same LLT, for example from a float to an int, will also cause a problem.
+
 // We do instruction selections early instead of calling MIB.buildBitcast()
 // generating the general op code G_BITCAST. When MachineVerifier validates
 // G_BITCAST we see a check of a kind: if Source Type is equal to Destination
@@ -212,15 +216,29 @@ static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
 // in https://github.com/llvm/llvm-project/pull/110270 for even more context.
 static void selectOpBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                              MachineIRBuilder MIB) {
+  // TODO: This change will still cause failures for the machine verifier.
+  // This is done so that the G_BITCAST legalization rules can be used.
+  // We will have to add rules to legalize the intrinsic. Note that we cannot
+  // lower only the bitcasts the verifier will complain about: casting a long
+  // vector of floats to a long vector of ints must be legalized, but it is
+  // also an invalid G_BITCAST.
+  //
+  // Could we use G_COPY for cases where the LLTs are the same? Then lowering
+  // the G_COPY could produce either an OpBitcast or an OpCopyObject, depending
+  // on the source and result types.
+
+  MachineRegisterInfo &MRI = MF.getRegInfo();
   SmallVector<MachineInstr *, 16> ToErase;
   for (MachineBasicBlock &MBB : MF) {
     for (MachineInstr &MI : MBB) {
       if (MI.getOpcode() != TargetOpcode::G_BITCAST)
         continue;
-      MIB.setInsertPt(*MI.getParent(), MI);
-      buildOpBitcast(GR, MIB, MI.getOperand(0).getReg(),
-                     MI.getOperand(1).getReg());
-      ToErase.push_back(&MI);
+      Register DstReg = MI.getOperand(0).getReg();
+      if (MRI.getType(DstReg).isPointer()) {
+        MIB.setInsertPt(*MI.getParent(), MI);
+        buildOpBitcast(GR, MIB, DstReg, MI.getOperand(1).getReg());
+        ToErase.push_back(&MI);
+      }
     }
   }
   for (MachineInstr *MI : ToErase) {



More information about the llvm-commits mailing list