[llvm] [SLP]Emit actual bitwidth for analyzed MinBitwidth nodes, NFCI. (PR #71536)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 10 07:25:22 PST 2023


https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/71536

From c4dd78482e8ac04327c0cd4d0e8a35d89d6612f3 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Mon, 23 Oct 2023 05:43:15 -0700
Subject: [PATCH] [SLP]Emit actual bitwidth for analyzed MinBitwidth nodes,
 NFCI.

SLP includes an analysis of the minimum bitwidth at which the actual
integer operations can be emitted. This reduces register pressure and
improves performance. Currently only the cost model uses this analysis,
and the actual transformation is left to the instruction combiner
(InstCombine). It is better to do it directly in SLP: this reduces
compile time and fixes cost model issues.
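
For context, a minimal illustration (an assumed example, not taken from
this patch) of the kind of narrowing the MinBitwidth analysis enables
when applied directly at vectorization time:

  ; i32 arithmetic whose inputs are zero-extended i8 loads and whose
  ; results only need 8 bits can be emitted at the narrow width, with
  ; explicit sext/zext casts added only at external uses.
  define void @narrow_add(ptr %src, ptr %dst) {
    ; scalar form of one lane (before vectorization):
    ;   %x = load i8, ptr %src
    ;   %e = zext i8 %x to i32
    ;   %a = add i32 %e, 1
    ;   %t = trunc i32 %a to i8
    ;   store i8 %t, ptr %dst
    ; vector form with the minimum bitwidth applied during emission:
    %v = load <4 x i8>, ptr %src, align 1
    %a = add <4 x i8> %v, <i8 1, i8 1, i8 1, i8 1>
    store <4 x i8> %a, ptr %dst, align 1
    ret void
  }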
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 309 +++++++++++++-----
 .../SLPVectorizer/AArch64/trunc-insertion.ll  |  31 +-
 .../Transforms/SLPVectorizer/X86/partail.ll   |  20 +-
 .../X86/root-trunc-extract-reuse.ll           |  10 +-
 4 files changed, 257 insertions(+), 113 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index bb233ed7d6c77ce..715b6ef6dbca70f 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7892,6 +7892,26 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       continue;
     UsedScalars.set(I);
   }
+  auto GetCastContextHint = [&](Value *V) {
+    if (const TreeEntry *OpTE = getTreeEntry(V)) {
+      if (OpTE->State == TreeEntry::ScatterVectorize)
+        return TTI::CastContextHint::GatherScatter;
+      if (OpTE->State == TreeEntry::Vectorize &&
+          OpTE->getOpcode() == Instruction::Load && !OpTE->isAltShuffle()) {
+        if (OpTE->ReorderIndices.empty())
+          return TTI::CastContextHint::Normal;
+        SmallVector<int> Mask;
+        inversePermutation(OpTE->ReorderIndices, Mask);
+        if (ShuffleVectorInst::isReverseMask(Mask, Mask.size()))
+          return TTI::CastContextHint::Reversed;
+      }
+    } else {
+      InstructionsState SrcState = getSameOpcode(E->getOperand(0), *TLI);
+      if (SrcState.getOpcode() == Instruction::Load && !SrcState.isAltShuffle())
+        return TTI::CastContextHint::GatherScatter;
+    }
+    return TTI::CastContextHint::None;
+  };
   auto GetCostDiff =
       [=](function_ref<InstructionCost(unsigned)> ScalarEltCost,
           function_ref<InstructionCost(InstructionCost)> VectorCost) {
@@ -7911,6 +7931,39 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         }
 
         InstructionCost VecCost = VectorCost(CommonCost);
+        // Check if the current node must be resized, if the parent node is not
+        // resized.
+        if (!UnaryInstruction::isCast(E->getOpcode()) && E->Idx != 0) {
+          const EdgeInfo &EI = E->UserTreeIndices.front();
+          if ((EI.UserTE->getOpcode() != Instruction::Select ||
+               EI.EdgeIdx != 0) &&
+              It != MinBWs.end()) {
+            auto UserBWIt = MinBWs.find(EI.UserTE->Scalars.front());
+            Type *UserScalarTy =
+                EI.UserTE->getOperand(EI.EdgeIdx).front()->getType();
+            if (UserBWIt != MinBWs.end())
+              UserScalarTy = IntegerType::get(ScalarTy->getContext(),
+                                              UserBWIt->second.first);
+            if (ScalarTy != UserScalarTy) {
+              unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
+              unsigned SrcBWSz = DL->getTypeSizeInBits(UserScalarTy);
+              unsigned VecOpcode;
+              auto *SrcVecTy =
+                  FixedVectorType::get(UserScalarTy, E->getVectorFactor());
+              if (BWSz > SrcBWSz)
+                VecOpcode = Instruction::Trunc;
+              else
+                VecOpcode =
+                    It->second.second ? Instruction::SExt : Instruction::ZExt;
+              TTI::CastContextHint CCH = GetCastContextHint(VL0);
+              VecCost += TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH,
+                                               CostKind);
+              ScalarCost +=
+                  Sz * TTI->getCastInstrCost(VecOpcode, ScalarTy, UserScalarTy,
+                                             CCH, CostKind);
+            }
+          }
+        }
         LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecCost - CommonCost,
                                  ScalarCost, "Calculated costs for Tree"));
         return VecCost - ScalarCost;
@@ -8182,6 +8235,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     Type *SrcScalarTy = VL0->getOperand(0)->getType();
     auto *SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size());
     unsigned Opcode = ShuffleOrOp;
+    unsigned VecOpcode = Opcode;
     if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() &&
         (SrcIt != MinBWs.end() || It != MinBWs.end())) {
       // Check if the values are candidates to demote.
@@ -8193,46 +8247,36 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       }
       unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
       if (BWSz == SrcBWSz) {
-        Opcode = Instruction::BitCast;
+        VecOpcode = Instruction::BitCast;
       } else if (BWSz < SrcBWSz) {
-        Opcode = Instruction::Trunc;
+        VecOpcode = Instruction::Trunc;
       } else if (It != MinBWs.end()) {
         assert(BWSz > SrcBWSz && "Invalid cast!");
-        Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+        VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
       }
     }
-    auto GetScalarCost = [&](unsigned Idx) {
+    auto GetScalarCost = [&](unsigned Idx) -> InstructionCost {
+      // Do not count cost here if minimum bitwidth is in effect and it is just
+      // a bitcast (here it is just a noop).
+      if (VecOpcode != Opcode && VecOpcode == Instruction::BitCast)
+        return TTI::TCC_Free;
       auto *VI = VL0->getOpcode() == Opcode
                      ? cast<Instruction>(UniqueValues[Idx])
                      : nullptr;
-      return TTI->getCastInstrCost(Opcode, ScalarTy, SrcScalarTy,
+      return TTI->getCastInstrCost(Opcode, VL0->getType(),
+                                   VL0->getOperand(0)->getType(),
                                    TTI::getCastContextHint(VI), CostKind, VI);
     };
-    TTI::CastContextHint CCH = TTI::CastContextHint::None;
-    if (const TreeEntry *OpTE = getTreeEntry(VL0->getOperand(0))) {
-      if (OpTE->State == TreeEntry::ScatterVectorize) {
-        CCH = TTI::CastContextHint::GatherScatter;
-      } else if (OpTE->State == TreeEntry::Vectorize &&
-                 OpTE->getOpcode() == Instruction::Load &&
-                 !OpTE->isAltShuffle()) {
-        if (OpTE->ReorderIndices.empty()) {
-          CCH = TTI::CastContextHint::Normal;
-        } else {
-          SmallVector<int> Mask;
-          inversePermutation(OpTE->ReorderIndices, Mask);
-          if (ShuffleVectorInst::isReverseMask(Mask, Mask.size()))
-            CCH = TTI::CastContextHint::Reversed;
-        }
-      }
-    } else {
-      InstructionsState SrcState = getSameOpcode(E->getOperand(0), *TLI);
-      if (SrcState.getOpcode() == Instruction::Load && !SrcState.isAltShuffle())
-        CCH = TTI::CastContextHint::GatherScatter;
-    }
     auto GetVectorCost = [=](InstructionCost CommonCost) {
+      // Do not count cost here if minimum bitwidth is in effect and it is just
+      // a bitcast (here it is just a noop).
+      if (VecOpcode != Opcode && VecOpcode == Instruction::BitCast)
+        return CommonCost;
       auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
+      TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0));
       return CommonCost +
-             TTI->getCastInstrCost(Opcode, VecTy, SrcVecTy, CCH, CostKind, VI);
+             TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind,
+                                   VecOpcode == Opcode ? VI : nullptr);
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }
@@ -8966,6 +9010,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
   SmallVector<std::pair<Value *, const TreeEntry *>> FirstUsers;
   SmallVector<APInt> DemandedElts;
   SmallDenseSet<Value *, 4> UsedInserts;
+  DenseSet<Value *> VectorCasts;
   for (ExternalUser &EU : ExternalUses) {
     // We only add extract cost once for the same scalar.
     if (!isa_and_nonnull<InsertElementInst>(EU.User) &&
@@ -9034,6 +9079,28 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
             FirstUsers.emplace_back(VU, ScalarTE);
             DemandedElts.push_back(APInt::getZero(FTy->getNumElements()));
             VecId = FirstUsers.size() - 1;
+            auto It = MinBWs.find(EU.Scalar);
+            if (It != MinBWs.end() && VectorCasts.insert(EU.Scalar).second) {
+              unsigned BWSz = It->second.second;
+              unsigned SrcBWSz = DL->getTypeSizeInBits(FTy->getElementType());
+              unsigned VecOpcode;
+              if (BWSz < SrcBWSz)
+                VecOpcode = Instruction::Trunc;
+              else
+                VecOpcode =
+                    It->second.second ? Instruction::SExt : Instruction::ZExt;
+              TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+              InstructionCost C = TTI->getCastInstrCost(
+                  VecOpcode, FTy,
+                  FixedVectorType::get(
+                      IntegerType::get(FTy->getContext(), It->second.first),
+                      FTy->getNumElements()),
+                  TTI::CastContextHint::None, CostKind);
+              LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C
+                                << " for extending externally used vector with "
+                                   "non-equal minimum bitwidth.\n");
+              Cost += C;
+            }
           } else {
             if (isFirstInsertElement(VU, cast<InsertElementInst>(It->first)))
               It->first = VU;
@@ -9069,6 +9136,21 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
                                              CostKind, EU.Lane);
     }
   }
+  // Add reduced value cost, if resized.
+  if (!VectorizedVals.empty()) {
+    auto BWIt = MinBWs.find(VectorizableTree.front()->Scalars.front());
+    if (BWIt != MinBWs.end()) {
+      Type *DstTy = BWIt->first->getType();
+      unsigned OriginalSz = DL->getTypeSizeInBits(DstTy);
+      unsigned Opcode = Instruction::Trunc;
+      if (OriginalSz < BWIt->second.first)
+        Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt;
+      Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first);
+      Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy,
+                                    TTI::CastContextHint::None,
+                                    TTI::TCK_RecipThroughput);
+    }
+  }
 
   InstructionCost SpillCost = getSpillCost();
   Cost += SpillCost + ExtractCost;
@@ -9274,6 +9356,11 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
       Instruction &LastBundleInst = getLastInstructionInBundle(VTE);
       if (&LastBundleInst == TEInsertPt || !CheckOrdering(&LastBundleInst))
         continue;
+      auto It = MinBWs.find(VTE->Scalars.front());
+      // If the vectorized node is demoted, do not match.
+      if (It != MinBWs.end() &&
+          It->second.first != DL->getTypeSizeInBits(V->getType()))
+        continue;
       VToTEs.insert(VTE);
     }
     if (VToTEs.empty())
@@ -10830,7 +10917,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
     return Vec;
   }
 
-  auto FinalShuffle = [&](Value *V, const TreeEntry *E) {
+  auto FinalShuffle = [&](Value *V, const TreeEntry *E, VectorType *VecTy,
+                          bool IsSigned) {
+    if (V->getType() != VecTy)
+      V = Builder.CreateIntCast(V, VecTy, IsSigned);
     ShuffleInstructionBuilder ShuffleBuilder(Builder, *this);
     if (E->getOpcode() == Instruction::Store) {
       ArrayRef<int> Mask =
@@ -10857,6 +10947,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
     ScalarTy = Store->getValueOperand()->getType();
   else if (auto *IE = dyn_cast<InsertElementInst>(VL0))
     ScalarTy = IE->getOperand(1)->getType();
+  bool IsSigned = false;
+  auto It = MinBWs.find(E->Scalars.front());
+  if (It != MinBWs.end()) {
+    ScalarTy = IntegerType::get(F->getContext(), It->second.first);
+    IsSigned = It->second.second;
+  }
   auto *VecTy = FixedVectorType::get(ScalarTy, E->Scalars.size());
   switch (ShuffleOrOp) {
     case Instruction::PHI: {
@@ -10880,7 +10976,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
                                PH->getParent()->getFirstInsertionPt());
         Builder.SetCurrentDebugLocation(PH->getDebugLoc());
 
-        V = FinalShuffle(V, E);
+        V = FinalShuffle(V, E, VecTy, IsSigned);
 
         E->VectorizedValue = V;
         if (PostponedPHIs)
@@ -10913,6 +11009,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         Builder.SetInsertPoint(IBB->getTerminator());
         Builder.SetCurrentDebugLocation(PH->getDebugLoc());
         Value *Vec = vectorizeOperand(E, i, /*PostponedPHIs=*/true);
+        if (VecTy != Vec->getType()) {
+          assert(It != MinBWs.end() && "Expected item in MinBWs.");
+          Vec = Builder.CreateIntCast(Vec, VecTy, It->second.second);
+        }
         NewPhi->addIncoming(Vec, IBB);
       }
 
@@ -10924,7 +11024,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
     case Instruction::ExtractElement: {
       Value *V = E->getSingleOperand(0);
       setInsertPointAfterBundle(E);
-      V = FinalShuffle(V, E);
+      V = FinalShuffle(V, E, VecTy, IsSigned);
       E->VectorizedValue = V;
       return V;
     }
@@ -10934,7 +11034,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       Value *Ptr = LI->getPointerOperand();
       LoadInst *V = Builder.CreateAlignedLoad(VecTy, Ptr, LI->getAlign());
       Value *NewV = propagateMetadata(V, E->Scalars);
-      NewV = FinalShuffle(NewV, E);
+      NewV = FinalShuffle(NewV, E, VecTy, IsSigned);
       E->VectorizedValue = NewV;
       return NewV;
     }
@@ -10942,6 +11042,19 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       assert(E->ReuseShuffleIndices.empty() && "All inserts should be unique");
       Builder.SetInsertPoint(cast<Instruction>(E->Scalars.back()));
       Value *V = vectorizeOperand(E, 1, PostponedPHIs);
+      ArrayRef<Value *> Op = E->getOperand(1);
+      Type *ScalarTy = Op.front()->getType();
+      if (cast<VectorType>(V->getType())->getElementType() != ScalarTy) {
+        assert(ScalarTy->isIntegerTy() && "Expected item in MinBWs.");
+        std::pair<unsigned, bool> Res = MinBWs.lookup(Op.front());
+        assert(Res.first > 0 && "Expected item in MinBWs.");
+        V = Builder.CreateIntCast(
+            V,
+            FixedVectorType::get(
+                ScalarTy,
+                cast<FixedVectorType>(V->getType())->getNumElements()),
+            Res.second);
+      }
 
       // Create InsertVector shuffle if necessary
       auto *FirstInsert = cast<Instruction>(*find_if(E->Scalars, [E](Value *V) {
@@ -11107,7 +11220,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
 
       auto *CI = cast<CastInst>(VL0);
       Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy);
-      V = FinalShuffle(V, E);
+      V = FinalShuffle(V, E, VecTy, IsSigned);
 
       E->VectorizedValue = V;
       ++NumVectorInstructions;
@@ -11127,11 +11240,22 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
         return E->VectorizedValue;
       }
+      if (L->getType() != R->getType()) {
+        assert(It != MinBWs.end() && "Expected item in MinBWs.");
+        if (L == R) {
+          R = L = Builder.CreateIntCast(L, VecTy, IsSigned);
+        } else {
+          L = Builder.CreateIntCast(L, VecTy, IsSigned);
+          R = Builder.CreateIntCast(R, VecTy, IsSigned);
+        }
+      }
 
       CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate();
       Value *V = Builder.CreateCmp(P0, L, R);
       propagateIRFlags(V, E->Scalars, VL0);
-      V = FinalShuffle(V, E);
+      // Do not cast for cmps.
+      VecTy = cast<FixedVectorType>(V->getType());
+      V = FinalShuffle(V, E, VecTy, IsSigned);
 
       E->VectorizedValue = V;
       ++NumVectorInstructions;
@@ -11155,9 +11279,18 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
         return E->VectorizedValue;
       }
+      if (True->getType() != False->getType()) {
+        assert(It != MinBWs.end() && "Expected item in MinBWs.");
+        if (True == False) {
+          True = False = Builder.CreateIntCast(True, VecTy, IsSigned);
+        } else {
+          True = Builder.CreateIntCast(True, VecTy, IsSigned);
+          False = Builder.CreateIntCast(False, VecTy, IsSigned);
+        }
+      }
 
       Value *V = Builder.CreateSelect(Cond, True, False);
-      V = FinalShuffle(V, E);
+      V = FinalShuffle(V, E, VecTy, IsSigned);
 
       E->VectorizedValue = V;
       ++NumVectorInstructions;
@@ -11179,7 +11312,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       if (auto *I = dyn_cast<Instruction>(V))
         V = propagateMetadata(I, E->Scalars);
 
-      V = FinalShuffle(V, E);
+      V = FinalShuffle(V, E, VecTy, IsSigned);
 
       E->VectorizedValue = V;
       ++NumVectorInstructions;
@@ -11216,6 +11349,15 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
         return E->VectorizedValue;
       }
+      if (LHS->getType() != RHS->getType()) {
+        assert(It != MinBWs.end() && "Expected item in MinBWs.");
+        if (LHS == RHS) {
+          RHS = LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned);
+        } else {
+          LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned);
+          RHS = Builder.CreateIntCast(RHS, VecTy, IsSigned);
+        }
+      }
 
       Value *V = Builder.CreateBinOp(
           static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS,
@@ -11224,7 +11366,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       if (auto *I = dyn_cast<Instruction>(V))
         V = propagateMetadata(I, E->Scalars);
 
-      V = FinalShuffle(V, E);
+      V = FinalShuffle(V, E, VecTy, IsSigned);
 
       E->VectorizedValue = V;
       ++NumVectorInstructions;
@@ -11270,7 +11412,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       }
       Value *V = propagateMetadata(NewLI, E->Scalars);
 
-      V = FinalShuffle(V, E);
+      V = FinalShuffle(V, E, VecTy, IsSigned);
       E->VectorizedValue = V;
       ++NumVectorInstructions;
       return V;
@@ -11281,7 +11423,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       setInsertPointAfterBundle(E);
 
       Value *VecValue = vectorizeOperand(E, 0, PostponedPHIs);
-      VecValue = FinalShuffle(VecValue, E);
+      VecValue = FinalShuffle(VecValue, E, VecTy, IsSigned);
 
       Value *Ptr = SI->getPointerOperand();
       StoreInst *ST =
@@ -11334,7 +11476,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         V = propagateMetadata(I, GEPs);
       }
 
-      V = FinalShuffle(V, E);
+      V = FinalShuffle(V, E, VecTy, IsSigned);
 
       E->VectorizedValue = V;
       ++NumVectorInstructions;
@@ -11414,7 +11556,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       }
 
       propagateIRFlags(V, E->Scalars, VL0);
-      V = FinalShuffle(V, E);
+      V = FinalShuffle(V, E, VecTy, IsSigned);
 
       E->VectorizedValue = V;
       ++NumVectorInstructions;
@@ -11446,6 +11588,15 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
         return E->VectorizedValue;
       }
+      if (LHS && RHS && LHS->getType() != RHS->getType()) {
+        assert(It != MinBWs.end() && "Expected item in MinBWs.");
+        if (LHS == RHS) {
+          RHS = LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned);
+        } else {
+          LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned);
+          RHS = Builder.CreateIntCast(RHS, VecTy, IsSigned);
+        }
+      }
 
       Value *V0, *V1;
       if (Instruction::isBinaryOp(E->getOpcode())) {
@@ -11496,6 +11647,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         CSEBlocks.insert(I->getParent());
       }
 
+      if (V->getType() != VecTy && !isa<CmpInst>(VL0))
+        V = Builder.CreateIntCast(
+            V, FixedVectorType::get(ScalarTy, E->getVectorFactor()), IsSigned);
       E->VectorizedValue = V;
       ++NumVectorInstructions;
 
@@ -11543,8 +11697,7 @@ Value *BoUpSLP::vectorizeTree(
     Builder.SetInsertPoint(&F->getEntryBlock(), F->getEntryBlock().begin());
 
   // Postpone emission of PHIs operands to avoid cyclic dependencies issues.
-  auto *VectorRoot =
-      vectorizeTree(VectorizableTree[0].get(), /*PostponedPHIs=*/true);
+  (void)vectorizeTree(VectorizableTree[0].get(), /*PostponedPHIs=*/true);
   for (const std::unique_ptr<TreeEntry> &TE : VectorizableTree)
     if (TE->State == TreeEntry::Vectorize &&
         TE->getOpcode() == Instruction::PHI && !TE->isAltShuffle() &&
@@ -11604,28 +11757,6 @@ Value *BoUpSLP::vectorizeTree(
     eraseInstruction(PrevVec);
   }
 
-  // If the vectorized tree can be rewritten in a smaller type, we truncate the
-  // vectorized root. InstCombine will then rewrite the entire expression. We
-  // sign extend the extracted values below.
-  auto *ScalarRoot = VectorizableTree[0]->Scalars[0];
-  auto It = MinBWs.find(ScalarRoot);
-  if (It != MinBWs.end()) {
-    if (auto *I = dyn_cast<Instruction>(VectorRoot)) {
-      // If current instr is a phi and not the last phi, insert it after the
-      // last phi node.
-      if (isa<PHINode>(I))
-        Builder.SetInsertPoint(I->getParent(),
-                               I->getParent()->getFirstInsertionPt());
-      else
-        Builder.SetInsertPoint(&*++BasicBlock::iterator(I));
-    }
-    auto BundleWidth = VectorizableTree[0]->Scalars.size();
-    auto *MinTy = IntegerType::get(F->getContext(), It->second.first);
-    auto *VecTy = FixedVectorType::get(MinTy, BundleWidth);
-    auto *Trunc = Builder.CreateTrunc(VectorRoot, VecTy);
-    VectorizableTree[0]->VectorizedValue = Trunc;
-  }
-
   LLVM_DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size()
                     << " values .\n");
 
@@ -11636,6 +11767,7 @@ Value *BoUpSLP::vectorizeTree(
   // basic block. Only one extractelement per block should be emitted.
   DenseMap<Value *, DenseMap<BasicBlock *, Instruction *>> ScalarToEEs;
   SmallDenseSet<Value *, 4> UsedInserts;
+  DenseMap<Value *, Value *> VectorCasts;
   // Extract all of the elements with the external uses.
   for (const auto &ExternalUse : ExternalUses) {
     Value *Scalar = ExternalUse.Scalar;
@@ -11694,12 +11826,10 @@ Value *BoUpSLP::vectorizeTree(
         }
         // If necessary, sign-extend or zero-extend ScalarRoot
         // to the larger type.
-        auto BWIt = MinBWs.find(ScalarRoot);
-        if (BWIt == MinBWs.end())
-          return Ex;
-        if (BWIt->second.second)
-          return Builder.CreateSExt(Ex, Scalar->getType());
-        return Builder.CreateZExt(Ex, Scalar->getType());
+        if (Scalar->getType() != Ex->getType())
+          return Builder.CreateIntCast(Ex, Scalar->getType(),
+                                       MinBWs.find(Scalar)->second.second);
+        return Ex;
       }
       assert(isa<FixedVectorType>(Scalar->getType()) &&
              isa<InsertElementInst>(Scalar) &&
@@ -11738,12 +11868,24 @@ Value *BoUpSLP::vectorizeTree(
         if (auto *FTy = dyn_cast<FixedVectorType>(User->getType())) {
           if (!UsedInserts.insert(VU).second)
             continue;
+          // Need to use original vector, if the root is truncated.
+          auto BWIt = MinBWs.find(Scalar);
+          if (BWIt != MinBWs.end() && Vec->getType() != VU->getType()) {
+            auto VecIt = VectorCasts.find(Scalar);
+            if (VecIt == VectorCasts.end()) {
+              IRBuilder<>::InsertPointGuard Guard(Builder);
+              if (auto *IVec = dyn_cast<Instruction>(Vec))
+                Builder.SetInsertPoint(IVec->getNextNonDebugInstruction());
+              Vec = Builder.CreateIntCast(Vec, VU->getType(),
+                                          BWIt->second.second);
+              VectorCasts.try_emplace(Scalar, Vec);
+            } else {
+              Vec = VecIt->second;
+            }
+          }
+
           std::optional<unsigned> InsertIdx = getInsertIndex(VU);
           if (InsertIdx) {
-            // Need to use original vector, if the root is truncated.
-            if (MinBWs.contains(Scalar) &&
-                VectorizableTree[0]->VectorizedValue == Vec)
-              Vec = VectorRoot;
             auto *It =
                 find_if(ShuffledInserts, [VU](const ShuffledInsertData &Data) {
                   // Checks if 2 insertelements are from the same buildvector.
@@ -12888,11 +13030,8 @@ void BoUpSLP::computeMinimumValueSizes() {
     return;
 
   // If the expression is not rooted by a store, these roots should have
-  // external uses. We will rely on InstCombine to rewrite the expression in
-  // the narrower type. However, InstCombine only rewrites single-use values.
-  // This means that if a tree entry other than a root is used externally, it
-  // must have multiple uses and InstCombine will not rewrite it. The code
-  // below ensures that only the roots are used externally.
+  // external uses.
+  // TODO: investigate if this can be relaxed.
   SmallPtrSet<Value *, 32> Expr(TreeRoot.begin(), TreeRoot.end());
   for (auto &EU : ExternalUses)
     if (!Expr.erase(EU.Scalar))
@@ -12924,7 +13063,7 @@ void BoUpSLP::computeMinimumValueSizes() {
   // The maximum bit width required to represent all the values that can be
   // demoted without loss of precision. It would be safe to truncate the roots
   // of the expression to this width.
-  auto MaxBitWidth = 8u;
+  auto MaxBitWidth = 1u;
 
   // We first check if all the bits of the roots are demanded. If they're not,
   // we can truncate the roots to this narrower type.
@@ -14561,6 +14700,16 @@ class HorizontalReduction {
 
         Value *ReducedSubTree =
             emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI);
+        if (ReducedSubTree->getType() != VL.front()->getType()) {
+          ReducedSubTree = Builder.CreateIntCast(
+              ReducedSubTree, VL.front()->getType(), any_of(VL, [&](Value *R) {
+                KnownBits Known = computeKnownBits(
+                    R, cast<Instruction>(ReductionOps.front().front())
+                           ->getModule()
+                           ->getDataLayout());
+                return !Known.isNonNegative();
+              }));
+        }
 
         // Improved analysis for add/fadd/xor reductions with same scale factor
         // for all operands of reductions. We can emit scalar ops for them
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/trunc-insertion.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/trunc-insertion.ll
index 56f252bf640408e..0eda2cbc862ff04 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/trunc-insertion.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/trunc-insertion.ll
@@ -8,34 +8,31 @@ define dso_local void @l() local_unnamed_addr {
 ; CHECK-NEXT:  bb:
 ; CHECK-NEXT:    br label [[BB1:%.*]]
 ; CHECK:       bb1:
-; CHECK-NEXT:    [[TMP0:%.*]] = phi <2 x i16> [ undef, [[BB:%.*]] ], [ [[TMP11:%.*]], [[BB25:%.*]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = phi <2 x i16> [ undef, [[BB:%.*]] ], [ [[TMP9:%.*]], [[BB25:%.*]] ]
 ; CHECK-NEXT:    br i1 undef, label [[BB3:%.*]], label [[BB11:%.*]]
 ; CHECK:       bb3:
 ; CHECK-NEXT:    [[I4:%.*]] = zext i1 undef to i32
 ; CHECK-NEXT:    [[TMP1:%.*]] = xor <2 x i16> [[TMP0]], undef
 ; CHECK-NEXT:    [[TMP2:%.*]] = icmp ugt <2 x i16> [[TMP1]], <i16 8, i16 8>
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <2 x i1> [[TMP2]] to <2 x i32>
 ; CHECK-NEXT:    br label [[BB25]]
 ; CHECK:       bb11:
 ; CHECK-NEXT:    [[I12:%.*]] = zext i1 undef to i32
-; CHECK-NEXT:    [[TMP4:%.*]] = xor <2 x i16> [[TMP0]], undef
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <2 x i16> [[TMP4]] to <2 x i64>
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp ule <2 x i64> undef, [[TMP5]]
-; CHECK-NEXT:    [[TMP7:%.*]] = zext <2 x i1> [[TMP6]] to <2 x i32>
-; CHECK-NEXT:    [[TMP8:%.*]] = icmp ult <2 x i32> undef, [[TMP7]]
-; CHECK-NEXT:    [[TMP9:%.*]] = zext <2 x i1> [[TMP8]] to <2 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = xor <2 x i16> [[TMP0]], undef
+; CHECK-NEXT:    [[TMP4:%.*]] = sext <2 x i16> [[TMP3]] to <2 x i64>
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp ule <2 x i64> undef, [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = zext <2 x i1> [[TMP5]] to <2 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = icmp ult <2 x i32> undef, [[TMP6]]
 ; CHECK-NEXT:    br label [[BB25]]
 ; CHECK:       bb25:
 ; CHECK-NEXT:    [[I28:%.*]] = phi i32 [ [[I12]], [[BB11]] ], [ [[I4]], [[BB3]] ]
-; CHECK-NEXT:    [[TMP10:%.*]] = phi <2 x i32> [ [[TMP9]], [[BB11]] ], [ [[TMP3]], [[BB3]] ]
-; CHECK-NEXT:    [[TMP11]] = phi <2 x i16> [ [[TMP4]], [[BB11]] ], [ [[TMP1]], [[BB3]] ]
-; CHECK-NEXT:    [[TMP12:%.*]] = trunc <2 x i32> [[TMP10]] to <2 x i8>
-; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <2 x i8> [[TMP12]], i32 0
-; CHECK-NEXT:    [[TMP14:%.*]] = zext i8 [[TMP13]] to i32
-; CHECK-NEXT:    [[I31:%.*]] = and i32 undef, [[TMP14]]
-; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <2 x i8> [[TMP12]], i32 1
-; CHECK-NEXT:    [[TMP16:%.*]] = zext i8 [[TMP15]] to i32
-; CHECK-NEXT:    [[I32:%.*]] = and i32 [[I31]], [[TMP16]]
+; CHECK-NEXT:    [[TMP8:%.*]] = phi <2 x i1> [ [[TMP7]], [[BB11]] ], [ [[TMP2]], [[BB3]] ]
+; CHECK-NEXT:    [[TMP9]] = phi <2 x i16> [ [[TMP3]], [[BB11]] ], [ [[TMP1]], [[BB3]] ]
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x i1> [[TMP8]], i32 0
+; CHECK-NEXT:    [[TMP11:%.*]] = zext i1 [[TMP10]] to i32
+; CHECK-NEXT:    [[I31:%.*]] = and i32 undef, [[TMP11]]
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x i1> [[TMP8]], i32 1
+; CHECK-NEXT:    [[TMP13:%.*]] = zext i1 [[TMP12]] to i32
+; CHECK-NEXT:    [[I32:%.*]] = and i32 [[I31]], [[TMP13]]
 ; CHECK-NEXT:    [[I33:%.*]] = and i32 [[I32]], [[I28]]
 ; CHECK-NEXT:    br i1 undef, label [[BB34:%.*]], label [[BB1]]
 ; CHECK:       bb34:
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/partail.ll b/llvm/test/Transforms/SLPVectorizer/X86/partail.ll
index b9747b6ae8c89a3..40ca0150d8e744d 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/partail.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/partail.ll
@@ -23,20 +23,18 @@ define void @get_block(i32 %y_pos) local_unnamed_addr #0 {
 ; CHECK-NEXT:    [[TMP7:%.*]] = select <4 x i1> [[TMP3]], <4 x i32> [[TMP6]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP8:%.*]] = icmp slt <4 x i32> [[TMP7]], undef
 ; CHECK-NEXT:    [[TMP9:%.*]] = select <4 x i1> [[TMP8]], <4 x i32> [[TMP7]], <4 x i32> undef
-; CHECK-NEXT:    [[TMP10:%.*]] = sext <4 x i32> [[TMP9]] to <4 x i64>
-; CHECK-NEXT:    [[TMP11:%.*]] = trunc <4 x i64> [[TMP10]] to <4 x i32>
-; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <4 x i32> [[TMP11]], i32 0
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x i32> [[TMP9]], i32 0
+; CHECK-NEXT:    [[TMP11:%.*]] = sext i32 [[TMP10]] to i64
+; CHECK-NEXT:    [[ARRAYIDX31:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <4 x i32> [[TMP9]], i32 1
 ; CHECK-NEXT:    [[TMP13:%.*]] = sext i32 [[TMP12]] to i64
-; CHECK-NEXT:    [[ARRAYIDX31:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP13]]
-; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <4 x i32> [[TMP11]], i32 1
+; CHECK-NEXT:    [[ARRAYIDX31_1:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP13]]
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <4 x i32> [[TMP9]], i32 2
 ; CHECK-NEXT:    [[TMP15:%.*]] = sext i32 [[TMP14]] to i64
-; CHECK-NEXT:    [[ARRAYIDX31_1:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP15]]
-; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <4 x i32> [[TMP11]], i32 2
+; CHECK-NEXT:    [[ARRAYIDX31_2:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP15]]
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <4 x i32> [[TMP9]], i32 3
 ; CHECK-NEXT:    [[TMP17:%.*]] = sext i32 [[TMP16]] to i64
-; CHECK-NEXT:    [[ARRAYIDX31_2:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP17]]
-; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <4 x i32> [[TMP11]], i32 3
-; CHECK-NEXT:    [[TMP19:%.*]] = sext i32 [[TMP18]] to i64
-; CHECK-NEXT:    [[ARRAYIDX31_3:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP19]]
+; CHECK-NEXT:    [[ARRAYIDX31_3:%.*]] = getelementptr inbounds ptr, ptr undef, i64 [[TMP17]]
 ; CHECK-NEXT:    unreachable
 ;
 entry:
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/root-trunc-extract-reuse.ll b/llvm/test/Transforms/SLPVectorizer/X86/root-trunc-extract-reuse.ll
index 87dd2cfd2004430..f48528e502b8cf1 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/root-trunc-extract-reuse.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/root-trunc-extract-reuse.ll
@@ -8,12 +8,12 @@ define i1 @test() {
 ; CHECK:       then:
 ; CHECK-NEXT:    br label [[ELSE]]
 ; CHECK:       else:
-; CHECK-NEXT:    [[TMP0:%.*]] = phi <2 x i32> [ zeroinitializer, [[THEN]] ], [ zeroinitializer, [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i32> [[TMP0]] to <2 x i8>
-; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x i8> [[TMP1]], i32 0
-; CHECK-NEXT:    [[TMP3:%.*]] = zext i8 [[TMP2]] to i32
+; CHECK-NEXT:    [[TMP0:%.*]] = phi <2 x i1> [ zeroinitializer, [[THEN]] ], [ zeroinitializer, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <2 x i1> [[TMP0]] to <2 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x i1> [[TMP0]], i32 0
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
 ; CHECK-NEXT:    [[BF_CAST162:%.*]] = and i32 [[TMP3]], 0
-; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <2 x i32> zeroinitializer, <2 x i32> [[TMP0]], <2 x i32> <i32 3, i32 1>
+; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <2 x i32> zeroinitializer, <2 x i32> [[TMP1]], <2 x i32> <i32 3, i32 1>
 ; CHECK-NEXT:    [[T13:%.*]] = and <2 x i32> [[TMP4]], zeroinitializer
 ; CHECK-NEXT:    br label [[ELSE1:%.*]]
 ; CHECK:       else1:


