[llvm] f2f3050 - Revert "[SLP]Emit actual bitwidth for analyzed MinBitwidth nodes, NFCI."

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 14 09:46:06 PST 2023


Author: Alexey Bataev
Date: 2023-11-14T09:45:54-08:00
New Revision: f2f3050476544e7a96ae5c3075427bb045b97187

URL: https://github.com/llvm/llvm-project/commit/f2f3050476544e7a96ae5c3075427bb045b97187
DIFF: https://github.com/llvm/llvm-project/commit/f2f3050476544e7a96ae5c3075427bb045b97187.diff

LOG: Revert "[SLP]Emit actual bitwidth for analyzed MinBitwidth nodes, NFCI."

This reverts commit f6ae50f710d02d8553d28192a1f048b2a9e1fc4d to fix
a crash revealed in internal testing.
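
For context on what the revert restores: instead of emitting casts at the
analyzed bitwidth for every demoted node, the re-added scheme truncates only
the vectorized root and re-extends the extracted lanes, leaving the rest to
InstCombine. A minimal hypothetical IR sketch of that pattern (names and
widths are illustrative, mirroring the updated test CHECK lines below):

    define i32 @narrow_demo(<2 x i32> %a, <2 x i32> %b) {
      ; Hypothetical: all lanes are known to fit in i8, so the root of the
      ; vectorized tree is truncated once...
      %root = add <2 x i32> %a, %b
      %narrow = trunc <2 x i32> %root to <2 x i8>
      ; ...and each externally used lane is extended back at its use site
      ; (zext here; sext when signed bits are required).
      %lane0 = extractelement <2 x i8> %narrow, i32 0
      %use0 = zext i8 %lane0 to i32
      ret i32 %use0
    }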

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/AArch64/trunc-insertion.ll
    llvm/test/Transforms/SLPVectorizer/X86/partail.ll
    llvm/test/Transforms/SLPVectorizer/X86/root-trunc-extract-reuse.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f0b5eba6c7b891b..bf380d073e635bb 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7895,26 +7895,6 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       continue;
     UsedScalars.set(I);
   }
-  auto GetCastContextHint = [&](Value *V) {
-    if (const TreeEntry *OpTE = getTreeEntry(V)) {
-      if (OpTE->State == TreeEntry::ScatterVectorize)
-        return TTI::CastContextHint::GatherScatter;
-      if (OpTE->State == TreeEntry::Vectorize &&
-          OpTE->getOpcode() == Instruction::Load && !OpTE->isAltShuffle()) {
-        if (OpTE->ReorderIndices.empty())
-          return TTI::CastContextHint::Normal;
-        SmallVector<int> Mask;
-        inversePermutation(OpTE->ReorderIndices, Mask);
-        if (ShuffleVectorInst::isReverseMask(Mask, Mask.size()))
-          return TTI::CastContextHint::Reversed;
-      }
-    } else {
-      InstructionsState SrcState = getSameOpcode(E->getOperand(0), *TLI);
-      if (SrcState.getOpcode() == Instruction::Load && !SrcState.isAltShuffle())
-        return TTI::CastContextHint::GatherScatter;
-    }
-    return TTI::CastContextHint::None;
-  };
   auto GetCostDiff =
       [=](function_ref<InstructionCost(unsigned)> ScalarEltCost,
           function_ref<InstructionCost(InstructionCost)> VectorCost) {
@@ -7934,39 +7914,6 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         }
 
         InstructionCost VecCost = VectorCost(CommonCost);
-        // Check if the current node must be resized, if the parent node is not
-        // resized.
-        if (!UnaryInstruction::isCast(E->getOpcode()) && E->Idx != 0) {
-          const EdgeInfo &EI = E->UserTreeIndices.front();
-          if ((EI.UserTE->getOpcode() != Instruction::Select ||
-               EI.EdgeIdx != 0) &&
-              It != MinBWs.end()) {
-            auto UserBWIt = MinBWs.find(EI.UserTE->Scalars.front());
-            Type *UserScalarTy =
-                EI.UserTE->getOperand(EI.EdgeIdx).front()->getType();
-            if (UserBWIt != MinBWs.end())
-              UserScalarTy = IntegerType::get(ScalarTy->getContext(),
-                                              UserBWIt->second.first);
-            if (ScalarTy != UserScalarTy) {
-              unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
-              unsigned SrcBWSz = DL->getTypeSizeInBits(UserScalarTy);
-              unsigned VecOpcode;
-              auto *SrcVecTy =
-                  FixedVectorType::get(UserScalarTy, E->getVectorFactor());
-              if (BWSz > SrcBWSz)
-                VecOpcode = Instruction::Trunc;
-              else
-                VecOpcode =
-                    It->second.second ? Instruction::SExt : Instruction::ZExt;
-              TTI::CastContextHint CCH = GetCastContextHint(VL0);
-              VecCost += TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH,
-                                               CostKind);
-              ScalarCost +=
-                  Sz * TTI->getCastInstrCost(VecOpcode, ScalarTy, UserScalarTy,
-                                             CCH, CostKind);
-            }
-          }
-        }
         LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecCost - CommonCost,
                                  ScalarCost, "Calculated costs for Tree"));
         return VecCost - ScalarCost;
@@ -8238,7 +8185,6 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     Type *SrcScalarTy = VL0->getOperand(0)->getType();
     auto *SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size());
     unsigned Opcode = ShuffleOrOp;
-    unsigned VecOpcode = Opcode;
     if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() &&
         (SrcIt != MinBWs.end() || It != MinBWs.end())) {
       // Check if the values are candidates to demote.
@@ -8250,36 +8196,46 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       }
       unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
       if (BWSz == SrcBWSz) {
-        VecOpcode = Instruction::BitCast;
+        Opcode = Instruction::BitCast;
       } else if (BWSz < SrcBWSz) {
-        VecOpcode = Instruction::Trunc;
+        Opcode = Instruction::Trunc;
       } else if (It != MinBWs.end()) {
         assert(BWSz > SrcBWSz && "Invalid cast!");
-        VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+        Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
       }
     }
-    auto GetScalarCost = [&](unsigned Idx) -> InstructionCost {
-      // Do not count cost here if minimum bitwidth is in effect and it is just
-      // a bitcast (here it is just a noop).
-      if (VecOpcode != Opcode && VecOpcode == Instruction::BitCast)
-        return TTI::TCC_Free;
+    auto GetScalarCost = [&](unsigned Idx) {
       auto *VI = VL0->getOpcode() == Opcode
                      ? cast<Instruction>(UniqueValues[Idx])
                      : nullptr;
-      return TTI->getCastInstrCost(Opcode, VL0->getType(),
-                                   VL0->getOperand(0)->getType(),
+      return TTI->getCastInstrCost(Opcode, ScalarTy, SrcScalarTy,
                                    TTI::getCastContextHint(VI), CostKind, VI);
     };
+    TTI::CastContextHint CCH = TTI::CastContextHint::None;
+    if (const TreeEntry *OpTE = getTreeEntry(VL0->getOperand(0))) {
+      if (OpTE->State == TreeEntry::ScatterVectorize) {
+        CCH = TTI::CastContextHint::GatherScatter;
+      } else if (OpTE->State == TreeEntry::Vectorize &&
+                 OpTE->getOpcode() == Instruction::Load &&
+                 !OpTE->isAltShuffle()) {
+        if (OpTE->ReorderIndices.empty()) {
+          CCH = TTI::CastContextHint::Normal;
+        } else {
+          SmallVector<int> Mask;
+          inversePermutation(OpTE->ReorderIndices, Mask);
+          if (ShuffleVectorInst::isReverseMask(Mask, Mask.size()))
+            CCH = TTI::CastContextHint::Reversed;
+        }
+      }
+    } else {
+      InstructionsState SrcState = getSameOpcode(E->getOperand(0), *TLI);
+      if (SrcState.getOpcode() == Instruction::Load && !SrcState.isAltShuffle())
+        CCH = TTI::CastContextHint::GatherScatter;
+    }
     auto GetVectorCost = [=](InstructionCost CommonCost) {
-      // Do not count cost here if minimum bitwidth is in effect and it is just
-      // a bitcast (here it is just a noop).
-      if (VecOpcode != Opcode && VecOpcode == Instruction::BitCast)
-        return CommonCost;
       auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
-      TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0));
       return CommonCost +
-             TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind,
-                                   VecOpcode == Opcode ? VI : nullptr);
+             TTI->getCastInstrCost(Opcode, VecTy, SrcVecTy, CCH, CostKind, VI);
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }
@@ -9030,7 +8986,6 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
   SmallVector<std::pair<Value *, const TreeEntry *>> FirstUsers;
   SmallVector<APInt> DemandedElts;
   SmallDenseSet<Value *, 4> UsedInserts;
-  DenseSet<Value *> VectorCasts;
   for (ExternalUser &EU : ExternalUses) {
     // We only add extract cost once for the same scalar.
     if (!isa_and_nonnull<InsertElementInst>(EU.User) &&
@@ -9099,28 +9054,6 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
             FirstUsers.emplace_back(VU, ScalarTE);
             DemandedElts.push_back(APInt::getZero(FTy->getNumElements()));
             VecId = FirstUsers.size() - 1;
-            auto It = MinBWs.find(EU.Scalar);
-            if (It != MinBWs.end() && VectorCasts.insert(EU.Scalar).second) {
-              unsigned BWSz = It->second.second;
-              unsigned SrcBWSz = DL->getTypeSizeInBits(FTy->getElementType());
-              unsigned VecOpcode;
-              if (BWSz < SrcBWSz)
-                VecOpcode = Instruction::Trunc;
-              else
-                VecOpcode =
-                    It->second.second ? Instruction::SExt : Instruction::ZExt;
-              TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
-              InstructionCost C = TTI->getCastInstrCost(
-                  VecOpcode, FTy,
-                  FixedVectorType::get(
-                      IntegerType::get(FTy->getContext(), It->second.first),
-                      FTy->getNumElements()),
-                  TTI::CastContextHint::None, CostKind);
-              LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C
-                                << " for extending externally used vector with "
-                                   "non-equal minimum bitwidth.\n");
-              Cost += C;
-            }
           } else {
             if (isFirstInsertElement(VU, cast<InsertElementInst>(It->first)))
               It->first = VU;
@@ -9156,21 +9089,6 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
                                              CostKind, EU.Lane);
     }
   }
-  // Add reduced value cost, if resized.
-  if (!VectorizedVals.empty()) {
-    auto BWIt = MinBWs.find(VectorizableTree.front()->Scalars.front());
-    if (BWIt != MinBWs.end()) {
-      Type *DstTy = BWIt->first->getType();
-      unsigned OriginalSz = DL->getTypeSizeInBits(DstTy);
-      unsigned Opcode = Instruction::Trunc;
-      if (OriginalSz < BWIt->second.first)
-        Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt;
-      Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first);
-      Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy,
-                                    TTI::CastContextHint::None,
-                                    TTI::TCK_RecipThroughput);
-    }
-  }
 
   InstructionCost SpillCost = getSpillCost();
   Cost += SpillCost + ExtractCost;
@@ -9376,11 +9294,6 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
       Instruction &LastBundleInst = getLastInstructionInBundle(VTE);
       if (&LastBundleInst == TEInsertPt || !CheckOrdering(&LastBundleInst))
         continue;
-      auto It = MinBWs.find(VTE->Scalars.front());
-      // If vectorize node is demoted - do not match.
-      if (It != MinBWs.end() &&
-          It->second.first != DL->getTypeSizeInBits(V->getType()))
-        continue;
       VToTEs.insert(VTE);
     }
     if (VToTEs.empty())
@@ -10937,10 +10850,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
     return Vec;
   }
 
-  auto FinalShuffle = [&](Value *V, const TreeEntry *E, VectorType *VecTy,
-                          bool IsSigned) {
-    if (V->getType() != VecTy)
-      V = Builder.CreateIntCast(V, VecTy, IsSigned);
+  auto FinalShuffle = [&](Value *V, const TreeEntry *E) {
     ShuffleInstructionBuilder ShuffleBuilder(Builder, *this);
     if (E->getOpcode() == Instruction::Store) {
       ArrayRef<int> Mask =
@@ -10967,12 +10877,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
     ScalarTy = Store->getValueOperand()->getType();
   else if (auto *IE = dyn_cast<InsertElementInst>(VL0))
     ScalarTy = IE->getOperand(1)->getType();
-  bool IsSigned = false;
-  auto It = MinBWs.find(E->Scalars.front());
-  if (It != MinBWs.end()) {
-    ScalarTy = IntegerType::get(F->getContext(), It->second.first);
-    IsSigned = It->second.second;
-  }
   auto *VecTy = FixedVectorType::get(ScalarTy, E->Scalars.size());
   switch (ShuffleOrOp) {
     case Instruction::PHI: {
@@ -10996,7 +10900,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
                                PH->getParent()->getFirstInsertionPt());
         Builder.SetCurrentDebugLocation(PH->getDebugLoc());
 
-        V = FinalShuffle(V, E, VecTy, IsSigned);
+        V = FinalShuffle(V, E);
 
         E->VectorizedValue = V;
         if (PostponedPHIs)
@@ -11029,10 +10933,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         Builder.SetInsertPoint(IBB->getTerminator());
         Builder.SetCurrentDebugLocation(PH->getDebugLoc());
         Value *Vec = vectorizeOperand(E, i, /*PostponedPHIs=*/true);
-        if (VecTy != Vec->getType()) {
-          assert(It != MinBWs.end() && "Expected item in MinBWs.");
-          Vec = Builder.CreateIntCast(Vec, VecTy, It->second.second);
-        }
         NewPhi->addIncoming(Vec, IBB);
       }
 
@@ -11044,7 +10944,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
     case Instruction::ExtractElement: {
       Value *V = E->getSingleOperand(0);
       setInsertPointAfterBundle(E);
-      V = FinalShuffle(V, E, VecTy, IsSigned);
+      V = FinalShuffle(V, E);
       E->VectorizedValue = V;
       return V;
     }
@@ -11054,7 +10954,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       Value *Ptr = LI->getPointerOperand();
       LoadInst *V = Builder.CreateAlignedLoad(VecTy, Ptr, LI->getAlign());
       Value *NewV = propagateMetadata(V, E->Scalars);
-      NewV = FinalShuffle(NewV, E, VecTy, IsSigned);
+      NewV = FinalShuffle(NewV, E);
       E->VectorizedValue = NewV;
       return NewV;
     }
@@ -11062,19 +10962,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       assert(E->ReuseShuffleIndices.empty() && "All inserts should be unique");
       Builder.SetInsertPoint(cast<Instruction>(E->Scalars.back()));
       Value *V = vectorizeOperand(E, 1, PostponedPHIs);
-      ArrayRef<Value *> Op = E->getOperand(1);
-      Type *ScalarTy = Op.front()->getType();
-      if (cast<VectorType>(V->getType())->getElementType() != ScalarTy) {
-        assert(ScalarTy->isIntegerTy() && "Expected item in MinBWs.");
-        std::pair<unsigned, bool> Res = MinBWs.lookup(Op.front());
-        assert(Res.first > 0 && "Expected item in MinBWs.");
-        V = Builder.CreateIntCast(
-            V,
-            FixedVectorType::get(
-                ScalarTy,
-                cast<FixedVectorType>(V->getType())->getNumElements()),
-            Res.second);
-      }
 
       // Create InsertVector shuffle if necessary
       auto *FirstInsert = cast<Instruction>(*find_if(E->Scalars, [E](Value *V) {
@@ -11240,7 +11127,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
 
       auto *CI = cast<CastInst>(VL0);
       Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy);
-      V = FinalShuffle(V, E, VecTy, IsSigned);
+      V = FinalShuffle(V, E);
 
       E->VectorizedValue = V;
       ++NumVectorInstructions;
@@ -11260,22 +11147,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
         return E->VectorizedValue;
       }
-      if (L->getType() != R->getType()) {
-        assert(It != MinBWs.end() && "Expected item in MinBWs.");
-        if (L == R) {
-          R = L = Builder.CreateIntCast(L, VecTy, IsSigned);
-        } else {
-          L = Builder.CreateIntCast(L, VecTy, IsSigned);
-          R = Builder.CreateIntCast(R, VecTy, IsSigned);
-        }
-      }
 
       CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate();
       Value *V = Builder.CreateCmp(P0, L, R);
       propagateIRFlags(V, E->Scalars, VL0);
-      // Do not cast for cmps.
-      VecTy = cast<FixedVectorType>(V->getType());
-      V = FinalShuffle(V, E, VecTy, IsSigned);
+      V = FinalShuffle(V, E);
 
       E->VectorizedValue = V;
       ++NumVectorInstructions;
@@ -11299,18 +11175,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
         return E->VectorizedValue;
       }
-      if (True->getType() != False->getType()) {
-        assert(It != MinBWs.end() && "Expected item in MinBWs.");
-        if (True == False) {
-          True = False = Builder.CreateIntCast(True, VecTy, IsSigned);
-        } else {
-          True = Builder.CreateIntCast(True, VecTy, IsSigned);
-          False = Builder.CreateIntCast(False, VecTy, IsSigned);
-        }
-      }
 
       Value *V = Builder.CreateSelect(Cond, True, False);
-      V = FinalShuffle(V, E, VecTy, IsSigned);
+      V = FinalShuffle(V, E);
 
       E->VectorizedValue = V;
       ++NumVectorInstructions;
@@ -11332,7 +11199,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       if (auto *I = dyn_cast<Instruction>(V))
         V = propagateMetadata(I, E->Scalars);
 
-      V = FinalShuffle(V, E, VecTy, IsSigned);
+      V = FinalShuffle(V, E);
 
       E->VectorizedValue = V;
       ++NumVectorInstructions;
@@ -11369,15 +11236,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
         return E->VectorizedValue;
       }
-      if (LHS->getType() != RHS->getType()) {
-        assert(It != MinBWs.end() && "Expected item in MinBWs.");
-        if (LHS == RHS) {
-          RHS = LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned);
-        } else {
-          LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned);
-          RHS = Builder.CreateIntCast(RHS, VecTy, IsSigned);
-        }
-      }
 
       Value *V = Builder.CreateBinOp(
           static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS,
@@ -11386,7 +11244,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       if (auto *I = dyn_cast<Instruction>(V))
         V = propagateMetadata(I, E->Scalars);
 
-      V = FinalShuffle(V, E, VecTy, IsSigned);
+      V = FinalShuffle(V, E);
 
       E->VectorizedValue = V;
       ++NumVectorInstructions;
@@ -11432,7 +11290,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       }
       Value *V = propagateMetadata(NewLI, E->Scalars);
 
-      V = FinalShuffle(V, E, VecTy, IsSigned);
+      V = FinalShuffle(V, E);
       E->VectorizedValue = V;
       ++NumVectorInstructions;
       return V;
@@ -11443,7 +11301,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       setInsertPointAfterBundle(E);
 
       Value *VecValue = vectorizeOperand(E, 0, PostponedPHIs);
-      VecValue = FinalShuffle(VecValue, E, VecTy, IsSigned);
+      VecValue = FinalShuffle(VecValue, E);
 
       Value *Ptr = SI->getPointerOperand();
       StoreInst *ST =
@@ -11496,7 +11354,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         V = propagateMetadata(I, GEPs);
       }
 
-      V = FinalShuffle(V, E, VecTy, IsSigned);
+      V = FinalShuffle(V, E);
 
       E->VectorizedValue = V;
       ++NumVectorInstructions;
@@ -11576,7 +11434,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       }
 
       propagateIRFlags(V, E->Scalars, VL0);
-      V = FinalShuffle(V, E, VecTy, IsSigned);
+      V = FinalShuffle(V, E);
 
       E->VectorizedValue = V;
       ++NumVectorInstructions;
@@ -11608,15 +11466,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
         return E->VectorizedValue;
       }
-      if (LHS && RHS && LHS->getType() != RHS->getType()) {
-        assert(It != MinBWs.end() && "Expected item in MinBWs.");
-        if (LHS == RHS) {
-          RHS = LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned);
-        } else {
-          LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned);
-          RHS = Builder.CreateIntCast(RHS, VecTy, IsSigned);
-        }
-      }
 
       Value *V0, *V1;
       if (Instruction::isBinaryOp(E->getOpcode())) {
@@ -11667,9 +11516,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         CSEBlocks.insert(I->getParent());
       }
 
-      if (V->getType() != VecTy && !isa<CmpInst>(VL0))
-        V = Builder.CreateIntCast(
-            V, FixedVectorType::get(ScalarTy, E->getVectorFactor()), IsSigned);
       E->VectorizedValue = V;
       ++NumVectorInstructions;
 
@@ -11717,7 +11563,8 @@ Value *BoUpSLP::vectorizeTree(
     Builder.SetInsertPoint(&F->getEntryBlock(), F->getEntryBlock().begin());
 
   // Postpone emission of PHIs operands to avoid cyclic dependencies issues.
-  (void)vectorizeTree(VectorizableTree[0].get(), /*PostponedPHIs=*/true);
+  auto *VectorRoot =
+      vectorizeTree(VectorizableTree[0].get(), /*PostponedPHIs=*/true);
   for (const std::unique_ptr<TreeEntry> &TE : VectorizableTree)
     if (TE->State == TreeEntry::Vectorize &&
         TE->getOpcode() == Instruction::PHI && !TE->isAltShuffle() &&
@@ -11777,6 +11624,28 @@ Value *BoUpSLP::vectorizeTree(
     eraseInstruction(PrevVec);
   }
 
+  // If the vectorized tree can be rewritten in a smaller type, we truncate the
+  // vectorized root. InstCombine will then rewrite the entire expression. We
+  // sign extend the extracted values below.
+  auto *ScalarRoot = VectorizableTree[0]->Scalars[0];
+  auto It = MinBWs.find(ScalarRoot);
+  if (It != MinBWs.end()) {
+    if (auto *I = dyn_cast<Instruction>(VectorRoot)) {
+      // If current instr is a phi and not the last phi, insert it after the
+      // last phi node.
+      if (isa<PHINode>(I))
+        Builder.SetInsertPoint(I->getParent(),
+                               I->getParent()->getFirstInsertionPt());
+      else
+        Builder.SetInsertPoint(&*++BasicBlock::iterator(I));
+    }
+    auto BundleWidth = VectorizableTree[0]->Scalars.size();
+    auto *MinTy = IntegerType::get(F->getContext(), It->second.first);
+    auto *VecTy = FixedVectorType::get(MinTy, BundleWidth);
+    auto *Trunc = Builder.CreateTrunc(VectorRoot, VecTy);
+    VectorizableTree[0]->VectorizedValue = Trunc;
+  }
+
   LLVM_DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size()
                     << " values .\n");
 
@@ -11787,7 +11656,6 @@ Value *BoUpSLP::vectorizeTree(
   // basic block. Only one extractelement per block should be emitted.
   DenseMap<Value *, DenseMap<BasicBlock *, Instruction *>> ScalarToEEs;
   SmallDenseSet<Value *, 4> UsedInserts;
-  DenseMap<Value *, Value *> VectorCasts;
   // Extract all of the elements with the external uses.
   for (const auto &ExternalUse : ExternalUses) {
     Value *Scalar = ExternalUse.Scalar;
@@ -11846,10 +11714,12 @@ Value *BoUpSLP::vectorizeTree(
         }
         // If necessary, sign-extend or zero-extend ScalarRoot
         // to the larger type.
-        if (Scalar->getType() != Ex->getType())
-          return Builder.CreateIntCast(Ex, Scalar->getType(),
-                                       MinBWs.find(Scalar)->second.second);
-        return Ex;
+        auto BWIt = MinBWs.find(ScalarRoot);
+        if (BWIt == MinBWs.end())
+          return Ex;
+        if (BWIt->second.second)
+          return Builder.CreateSExt(Ex, Scalar->getType());
+        return Builder.CreateZExt(Ex, Scalar->getType());
       }
       assert(isa<FixedVectorType>(Scalar->getType()) &&
              isa<InsertElementInst>(Scalar) &&
@@ -11888,24 +11758,12 @@ Value *BoUpSLP::vectorizeTree(
         if (auto *FTy = dyn_cast<FixedVectorType>(User->getType())) {
           if (!UsedInserts.insert(VU).second)
             continue;
-          // Need to use original vector, if the root is truncated.
-          auto BWIt = MinBWs.find(Scalar);
-          if (BWIt != MinBWs.end() && Vec->getType() != VU->getType()) {
-            auto VecIt = VectorCasts.find(Scalar);
-            if (VecIt == VectorCasts.end()) {
-              IRBuilder<>::InsertPointGuard Guard(Builder);
-              if (auto *IVec = dyn_cast<Instruction>(Vec))
-                Builder.SetInsertPoint(IVec->getNextNonDebugInstruction());
-              Vec = Builder.CreateIntCast(Vec, VU->getType(),
-                                          BWIt->second.second);
-              VectorCasts.try_emplace(Scalar, Vec);
-            } else {
-              Vec = VecIt->second;
-            }
-          }
-
           std::optional<unsigned> InsertIdx = getInsertIndex(VU);
           if (InsertIdx) {
+            // Need to use original vector, if the root is truncated.
+            if (MinBWs.contains(Scalar) &&
+                VectorizableTree[0]->VectorizedValue == Vec)
+              Vec = VectorRoot;
             auto *It =
                 find_if(ShuffledInserts, [VU](const ShuffledInsertData &Data) {
                   // Checks if 2 insertelements are from the same buildvector.
@@ -13050,8 +12908,11 @@ void BoUpSLP::computeMinimumValueSizes() {
     return;
 
   // If the expression is not rooted by a store, these roots should have
-  // external uses.
-  // TOSO: investigate if this can be relaxed.
+  // external uses. We will rely on InstCombine to rewrite the expression in
+  // the narrower type. However, InstCombine only rewrites single-use values.
+  // This means that if a tree entry other than a root is used externally, it
+  // must have multiple uses and InstCombine will not rewrite it. The code
+  // below ensures that only the roots are used externally.
   SmallPtrSet<Value *, 32> Expr(TreeRoot.begin(), TreeRoot.end());
   for (auto &EU : ExternalUses)
     if (!Expr.erase(EU.Scalar))
@@ -13083,7 +12944,7 @@ void BoUpSLP::computeMinimumValueSizes() {
   // The maximum bit width required to represent all the values that can be
   // demoted without loss of precision. It would be safe to truncate the roots
   // of the expression to this width.
-  auto MaxBitWidth = 1u;
+  auto MaxBitWidth = 8u;
 
   // We first check if all the bits of the roots are demanded. If they're not,
   // we can truncate the roots to this narrower type.
@@ -14720,16 +14581,6 @@ class HorizontalReduction {
 
         Value *ReducedSubTree =
             emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI);
-        if (ReducedSubTree->getType() != VL.front()->getType()) {
-          ReducedSubTree = Builder.CreateIntCast(
-              ReducedSubTree, VL.front()->getType(), any_of(VL, [&](Value *R) {
-                KnownBits Known = computeKnownBits(
-                    R, cast<Instruction>(ReductionOps.front().front())
-                           ->getModule()
-                           ->getDataLayout());
-                return !Known.isNonNegative();
-              }));
-        }
 
         // Improved analysis for add/fadd/xor reductions with same scale factor
         // for all operands of reductions. We can emit scalar ops for them
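
One detail worth calling out in the hunks above: the cast-context
classification (CastContextHint::Normal/Reversed/GatherScatter) keys off how
the cast's source load was vectorized. A hypothetical standalone IR pattern
for the Reversed case, a trunc fed by a reverse-shuffled vector load:

    define <4 x i16> @rev(ptr %p) {
      ; A vectorized load whose lanes are consumed in reverse order; the
      ; reorder indices invert to a reverse mask, so the restored code
      ; queries the trunc cost with TTI::CastContextHint::Reversed.
      %v = load <4 x i32>, ptr %p
      %r = shufflevector <4 x i32> %v, <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
      %t = trunc <4 x i32> %r to <4 x i16>
      ret <4 x i16> %t
    }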

diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/trunc-insertion.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/trunc-insertion.ll
index 0eda2cbc862ff04..56f252bf640408e 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/trunc-insertion.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/trunc-insertion.ll
@@ -8,31 +8,34 @@ define dso_local void @l() local_unnamed_addr {
 ; CHECK-NEXT:  bb:
 ; CHECK-NEXT:    br label [[BB1:%.*]]
 ; CHECK:       bb1:
-; CHECK-NEXT:    [[TMP0:%.*]] = phi <2 x i16> [ undef, [[BB:%.*]] ], [ [[TMP9:%.*]], [[BB25:%.*]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = phi <2 x i16> [ undef, [[BB:%.*]] ], [ [[TMP11:%.*]], [[BB25:%.*]] ]
 ; CHECK-NEXT:    br i1 undef, label [[BB3:%.*]], label [[BB11:%.*]]
 ; CHECK:       bb3:
 ; CHECK-NEXT:    [[I4:%.*]] = zext i1 undef to i32
 ; CHECK-NEXT:    [[TMP1:%.*]] = xor <2 x i16> [[TMP0]], undef
 ; CHECK-NEXT:    [[TMP2:%.*]] = icmp ugt <2 x i16> [[TMP1]], <i16 8, i16 8>
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <2 x i1> [[TMP2]] to <2 x i32>
 ; CHECK-NEXT:    br label [[BB25]]
 ; CHECK:       bb11:
 ; CHECK-NEXT:    [[I12:%.*]] = zext i1 undef to i32
-; CHECK-NEXT:    [[TMP3:%.*]] = xor <2 x i16> [[TMP0]], undef
-; CHECK-NEXT:    [[TMP4:%.*]] = sext <2 x i16> [[TMP3]] to <2 x i64>
-; CHECK-NEXT:    [[TMP5:%.*]] = icmp ule <2 x i64> undef, [[TMP4]]
-; CHECK-NEXT:    [[TMP6:%.*]] = zext <2 x i1> [[TMP5]] to <2 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = icmp ult <2 x i32> undef, [[TMP6]]
+; CHECK-NEXT:    [[TMP4:%.*]] = xor <2 x i16> [[TMP0]], undef
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <2 x i16> [[TMP4]] to <2 x i64>
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp ule <2 x i64> undef, [[TMP5]]
+; CHECK-NEXT:    [[TMP7:%.*]] = zext <2 x i1> [[TMP6]] to <2 x i32>
+; CHECK-NEXT:    [[TMP8:%.*]] = icmp ult <2 x i32> undef, [[TMP7]]
+; CHECK-NEXT:    [[TMP9:%.*]] = zext <2 x i1> [[TMP8]] to <2 x i32>
 ; CHECK-NEXT:    br label [[BB25]]
 ; CHECK:       bb25:
 ; CHECK-NEXT:    [[I28:%.*]] = phi i32 [ [[I12]], [[BB11]] ], [ [[I4]], [[BB3]] ]
-; CHECK-NEXT:    [[TMP8:%.*]] = phi <2 x i1> [ [[TMP7]], [[BB11]] ], [ [[TMP2]], [[BB3]] ]
-; CHECK-NEXT:    [[TMP9]] = phi <2 x i16> [ [[TMP3]], [[BB11]] ], [ [[TMP1]], [[BB3]] ]
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x i1> [[TMP8]], i32 0
-; CHECK-NEXT:    [[TMP11:%.*]] = zext i1 [[TMP10]] to i32
-; CHECK-NEXT:    [[I31:%.*]] = and i32 undef, [[TMP11]]
-; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x i1> [[TMP8]], i32 1
-; CHECK-NEXT:    [[TMP13:%.*]] = zext i1 [[TMP12]] to i32
-; CHECK-NEXT:    [[I32:%.*]] = and i32 [[I31]], [[TMP13]]
+; CHECK-NEXT:    [[TMP10:%.*]] = phi <2 x i32> [ [[TMP9]], [[BB11]] ], [ [[TMP3]], [[BB3]] ]
+; CHECK-NEXT:    [[TMP11]] = phi <2 x i16> [ [[TMP4]], [[BB11]] ], [ [[TMP1]], [[BB3]] ]
+; CHECK-NEXT:    [[TMP12:%.*]] = trunc <2 x i32> [[TMP10]] to <2 x i8>
+; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <2 x i8> [[TMP12]], i32 0
+; CHECK-NEXT:    [[TMP14:%.*]] = zext i8 [[TMP13]] to i32
+; CHECK-NEXT:    [[I31:%.*]] = and i32 undef, [[TMP14]]
+; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <2 x i8> [[TMP12]], i32 1
+; CHECK-NEXT:    [[TMP16:%.*]] = zext i8 [[TMP15]] to i32
+; CHECK-NEXT:    [[I32:%.*]] = and i32 [[I31]], [[TMP16]]
 ; CHECK-NEXT:    [[I33:%.*]] = and i32 [[I32]], [[I28]]
 ; CHECK-NEXT:    br i1 undef, label [[BB34:%.*]], label [[BB1]]
 ; CHECK:       bb34:

diff --git a/llvm/test/Transforms/SLPVectorizer/X86/partail.ll b/llvm/test/Transforms/SLPVectorizer/X86/partail.ll
index 40ca0150d8e744d..b9747b6ae8c89a3 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/partail.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/partail.ll
@@ -23,18 +23,20 @@ define void @get_block(i32 %y_pos) local_unnamed_addr #0 {
 ; CHECK-NEXT:    [[TMP7:%.*]] = select <4 x i1> [[TMP3]], <4 x i32> [[TMP6]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP8:%.*]] = icmp slt <4 x i32> [[TMP7]], undef
 ; CHECK-NEXT:    [[TMP9:%.*]] = select <4 x i1> [[TMP8]], <4 x i32> [[TMP7]], <4 x i32> undef
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x i32> [[TMP9]], i32 0
-; CHECK-NEXT:    [[TMP11:%.*]] = sext i32 [[TMP10]] to i64
-; CHECK-NEXT:    [[ARRAYIDX31:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP11]]
-; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <4 x i32> [[TMP9]], i32 1
+; CHECK-NEXT:    [[TMP10:%.*]] = sext <4 x i32> [[TMP9]] to <4 x i64>
+; CHECK-NEXT:    [[TMP11:%.*]] = trunc <4 x i64> [[TMP10]] to <4 x i32>
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <4 x i32> [[TMP11]], i32 0
 ; CHECK-NEXT:    [[TMP13:%.*]] = sext i32 [[TMP12]] to i64
-; CHECK-NEXT:    [[ARRAYIDX31_1:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP13]]
-; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <4 x i32> [[TMP9]], i32 2
+; CHECK-NEXT:    [[ARRAYIDX31:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP13]]
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <4 x i32> [[TMP11]], i32 1
 ; CHECK-NEXT:    [[TMP15:%.*]] = sext i32 [[TMP14]] to i64
-; CHECK-NEXT:    [[ARRAYIDX31_2:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP15]]
-; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <4 x i32> [[TMP9]], i32 3
+; CHECK-NEXT:    [[ARRAYIDX31_1:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP15]]
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <4 x i32> [[TMP11]], i32 2
 ; CHECK-NEXT:    [[TMP17:%.*]] = sext i32 [[TMP16]] to i64
-; CHECK-NEXT:    [[ARRAYIDX31_3:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP17]]
+; CHECK-NEXT:    [[ARRAYIDX31_2:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP17]]
+; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <4 x i32> [[TMP11]], i32 3
+; CHECK-NEXT:    [[TMP19:%.*]] = sext i32 [[TMP18]] to i64
+; CHECK-NEXT:    [[ARRAYIDX31_3:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP19]]
 ; CHECK-NEXT:    unreachable
 ;
 entry:

diff --git a/llvm/test/Transforms/SLPVectorizer/X86/root-trunc-extract-reuse.ll b/llvm/test/Transforms/SLPVectorizer/X86/root-trunc-extract-reuse.ll
index f48528e502b8cf1..87dd2cfd2004430 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/root-trunc-extract-reuse.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/root-trunc-extract-reuse.ll
@@ -8,12 +8,12 @@ define i1 @test() {
 ; CHECK:       then:
 ; CHECK-NEXT:    br label [[ELSE]]
 ; CHECK:       else:
-; CHECK-NEXT:    [[TMP0:%.*]] = phi <2 x i1> [ zeroinitializer, [[THEN]] ], [ zeroinitializer, [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <2 x i1> [[TMP0]] to <2 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x i1> [[TMP0]], i32 0
-; CHECK-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    [[TMP0:%.*]] = phi <2 x i32> [ zeroinitializer, [[THEN]] ], [ zeroinitializer, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i32> [[TMP0]] to <2 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x i8> [[TMP1]], i32 0
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i8 [[TMP2]] to i32
 ; CHECK-NEXT:    [[BF_CAST162:%.*]] = and i32 [[TMP3]], 0
-; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <2 x i32> zeroinitializer, <2 x i32> [[TMP1]], <2 x i32> <i32 3, i32 1>
+; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <2 x i32> zeroinitializer, <2 x i32> [[TMP0]], <2 x i32> <i32 3, i32 1>
 ; CHECK-NEXT:    [[T13:%.*]] = and <2 x i32> [[TMP4]], zeroinitializer
 ; CHECK-NEXT:    br label [[ELSE1:%.*]]
 ; CHECK:       else1:
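
The CHECK lines in these tests appear to be auto-generated; assuming a local
build of opt, they can be regenerated with the standard utility (the build
path here is illustrative):

    python3 llvm/utils/update_test_checks.py --opt-binary=build/bin/opt \
        llvm/test/Transforms/SLPVectorizer/X86/root-trunc-extract-reuse.ll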
