[llvm] ce14d1b - [SLP]Do not reorder reduction nodes.

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 26 07:44:52 PDT 2021


Author: Alexey Bataev
Date: 2021-10-26T07:41:24-07:00
New Revision: ce14d1b690d886ff6022a5cc0f2e40cc8ecb9a46

URL: https://github.com/llvm/llvm-project/commit/ce14d1b690d886ff6022a5cc0f2e40cc8ecb9a46
DIFF: https://github.com/llvm/llvm-project/commit/ce14d1b690d886ff6022a5cc0f2e40cc8ecb9a46.diff

LOG: [SLP]Do not reorder reduction nodes.

The final reduction nodes should not be reordered, because the order does not
matter for reductions. Also, it might be profitable to vectorize smaller
reduction trees, since the reduction cost may compensate for the small tree cost.

Part of D111574

Differential Revision: https://reviews.llvm.org/D112467

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll
    llvm/test/Transforms/SLPVectorizer/X86/horizontal-list.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 8495bf8759e0..4f0cb272509e 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -781,7 +781,7 @@ class BoUpSLP {
  /// operands. Plus, even the leaf nodes have different orders, it allows to
   /// sink reordering in the graph closer to the root node and merge it later
   /// during analysis.
-  void reorderBottomToTop();
+  void reorderBottomToTop(bool IgnoreReorder = false);
 
   /// \return The vector element size in bits to use when vectorizing the
   /// expression tree ending at \p V. If V is a store, the size is the width of
@@ -824,7 +824,7 @@ class BoUpSLP {
 
   /// \returns True if the VectorizableTree is both tiny and not fully
   /// vectorizable. We do not vectorize such trees.
-  bool isTreeTinyAndNotFullyVectorizable() const;
+  bool isTreeTinyAndNotFullyVectorizable(bool ForReduction = false) const;
 
   /// Assume that a legal-sized 'or'-reduction of shifted/zexted loaded values
   /// can be load combined in the backend. Load combining may not be allowed in
@@ -1620,7 +1620,7 @@ class BoUpSLP {
 
   /// \returns whether the VectorizableTree is fully vectorizable and will
   /// be beneficial even the tree height is tiny.
-  bool isFullyVectorizableTinyTree() const;
+  bool isFullyVectorizableTinyTree(bool ForReduction) const;
 
   /// Reorder commutative or alt operands to get better probability of
   /// generating vectorized code.
@@ -2820,7 +2820,7 @@ void BoUpSLP::reorderTopToBottom() {
   }
 }
 
-void BoUpSLP::reorderBottomToTop() {
+void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
   SetVector<TreeEntry *> OrderedEntries;
   DenseMap<const TreeEntry *, OrdersType> GathersToOrders;
   // Find all reorderable leaf nodes with the given VF.
@@ -2950,7 +2950,8 @@ void BoUpSLP::reorderBottomToTop() {
       SmallPtrSet<const TreeEntry *, 4> VisitedOps;
       for (const auto &Op : Data.second) {
         TreeEntry *OpTE = Op.second;
-        if (!OpTE->ReuseShuffleIndices.empty())
+        if (!OpTE->ReuseShuffleIndices.empty() ||
+            (IgnoreReorder && OpTE == VectorizableTree.front().get()))
           continue;
         const auto &Order = [OpTE, &GathersToOrders]() -> const OrdersType & {
           if (OpTE->State == TreeEntry::NeedToGather)
@@ -3061,6 +3062,10 @@ void BoUpSLP::reorderBottomToTop() {
       }
     }
   }
+  // If the reordering is unnecessary, just remove the reorder.
+  if (IgnoreReorder && !VectorizableTree.front()->ReorderIndices.empty() &&
+      VectorizableTree.front()->ReuseShuffleIndices.empty())
+    VectorizableTree.front()->ReorderIndices.clear();
 }
 
 void BoUpSLP::buildExternalUses(
@@ -4894,13 +4899,29 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
   }
 }
 
-bool BoUpSLP::isFullyVectorizableTinyTree() const {
+bool BoUpSLP::isFullyVectorizableTinyTree(bool ForReduction) const {
   LLVM_DEBUG(dbgs() << "SLP: Check whether the tree with height "
                     << VectorizableTree.size() << " is fully vectorizable .\n");
 
+  auto &&AreVectorizableGathers = [this](const TreeEntry *TE, unsigned Limit) {
+    SmallVector<int> Mask;
+    return TE->State == TreeEntry::NeedToGather &&
+           !any_of(TE->Scalars,
+                   [this](Value *V) { return EphValues.contains(V); }) &&
+           (allConstant(TE->Scalars) || isSplat(TE->Scalars) ||
+            TE->Scalars.size() < Limit ||
+            (TE->getOpcode() == Instruction::ExtractElement &&
+             isFixedVectorShuffle(TE->Scalars, Mask)));
+  };
+
   // We only handle trees of heights 1 and 2.
   if (VectorizableTree.size() == 1 &&
-      VectorizableTree[0]->State == TreeEntry::Vectorize)
+      (VectorizableTree[0]->State == TreeEntry::Vectorize ||
+       (ForReduction &&
+        AreVectorizableGathers(VectorizableTree[0].get(),
+                               VectorizableTree[0]->Scalars.size()) &&
+        (VectorizableTree[0]->Scalars.size() > 2 ||
+         VectorizableTree[0]->ReuseShuffleIndices.size() > 2))))
     return true;
 
   if (VectorizableTree.size() != 2)
@@ -4912,19 +4933,14 @@ bool BoUpSLP::isFullyVectorizableTinyTree() const {
   // or they are extractelements, which form shuffle.
   SmallVector<int> Mask;
   if (VectorizableTree[0]->State == TreeEntry::Vectorize &&
-      (allConstant(VectorizableTree[1]->Scalars) ||
-       isSplat(VectorizableTree[1]->Scalars) ||
-       (VectorizableTree[1]->State == TreeEntry::NeedToGather &&
-        VectorizableTree[1]->Scalars.size() <
-            VectorizableTree[0]->Scalars.size()) ||
-       (VectorizableTree[1]->State == TreeEntry::NeedToGather &&
-        VectorizableTree[1]->getOpcode() == Instruction::ExtractElement &&
-        isFixedVectorShuffle(VectorizableTree[1]->Scalars, Mask))))
+      AreVectorizableGathers(VectorizableTree[1].get(),
+                             VectorizableTree[0]->Scalars.size()))
     return true;
 
   // Gathering cost would be too much for tiny trees.
   if (VectorizableTree[0]->State == TreeEntry::NeedToGather ||
-      VectorizableTree[1]->State == TreeEntry::NeedToGather)
+      (VectorizableTree[1]->State == TreeEntry::NeedToGather &&
+       VectorizableTree[0]->State != TreeEntry::ScatterVectorize))
     return false;
 
   return true;
@@ -4993,7 +5009,7 @@ bool BoUpSLP::isLoadCombineCandidate() const {
   return true;
 }
 
-bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() const {
+bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const {
   // No need to vectorize inserts of gathered values.
   if (VectorizableTree.size() == 2 &&
       isa<InsertElementInst>(VectorizableTree[0]->Scalars[0]) &&
@@ -5007,7 +5023,7 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() const {
 
   // If we have a tiny tree (a tree whose size is less than MinTreeSize), we
   // can vectorize it if we can prove it fully vectorizable.
-  if (isFullyVectorizableTinyTree())
+  if (isFullyVectorizableTinyTree(ForReduction))
     return false;
 
   assert(VectorizableTree.empty()
@@ -5769,7 +5785,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
     VF = E->ReuseShuffleIndices.size();
   ShuffleInstructionBuilder ShuffleBuilder(Builder, VF);
   if (E->State == TreeEntry::NeedToGather) {
-    setInsertPointAfterBundle(E);
+    if (E->getMainOp())
+      setInsertPointAfterBundle(E);
     Value *Vec;
     SmallVector<int> Mask;
     SmallVector<const TreeEntry *> Entries;
@@ -8447,12 +8464,12 @@ class HorizontalReduction {
     while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) {
       ArrayRef<Value *> VL(&ReducedVals[i], ReduxWidth);
       V.buildTree(VL, IgnoreList);
-      if (V.isTreeTinyAndNotFullyVectorizable())
+      if (V.isTreeTinyAndNotFullyVectorizable(/*ForReduction=*/true))
         break;
       if (V.isLoadCombineReductionCandidate(RdxKind))
         break;
       V.reorderTopToBottom();
-      V.reorderBottomToTop();
+      V.reorderBottomToTop(/*IgnoreReorder=*/true);
       V.buildExternalUses(ExternallyUsedValues);
 
       // For a poison-safe boolean logic reduction, do not replace select
@@ -8630,6 +8647,7 @@ class HorizontalReduction {
     assert(isPowerOf2_32(ReduxWidth) &&
            "We only handle power-of-two reductions for now");
 
+    ++NumVectorInstructions;
     return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind,
                                        ReductionOps.back());
   }
@@ -8889,15 +8907,15 @@ static bool tryToVectorizeHorReductionOrInstOperands(
           continue;
         }
       }
+      // Set P to nullptr to avoid re-analysis of phi node in
+      // matchAssociativeReduction function unless this is the root node.
+      P = nullptr;
+      // Do not try to vectorize CmpInst operands, this is done separately.
+      // Final attempt for binop args vectorization should happen after the loop
+      // to try to find reductions.
+      if (!isa<CmpInst>(Inst))
+        PostponedInsts.push_back(Inst);
     }
-    // Set P to nullptr to avoid re-analysis of phi node in
-    // matchAssociativeReduction function unless this is the root node.
-    P = nullptr;
-    // Do not try to vectorize CmpInst operands, this is done separately.
-    // Final attempt for binop args vectorization should happen after the loop
-    // to try to find reductions.
-    if (!isa<CmpInst>(Inst))
-      PostponedInsts.push_back(Inst);
 
     // Try to vectorize operands.
     // Continue analysis for the instruction from the same basic block only to

diff --git a/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll b/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll
index e31e2a96fb9f..6406e9ca572c 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll
@@ -41,14 +41,9 @@ define i32 @ext_ext_partial_add_reduction_v4i32(<4 x i32> %x) {
 
 define i32 @ext_ext_partial_add_reduction_and_extra_add_v4i32(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: @ext_ext_partial_add_reduction_and_extra_add_v4i32(
-; CHECK-NEXT:    [[SHIFT:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> poison, <4 x i32> <i32 2, i32 undef, i32 undef, i32 undef>
-; CHECK-NEXT:    [[TMP1:%.*]] = add <4 x i32> [[SHIFT]], [[Y:%.*]]
-; CHECK-NEXT:    [[SHIFT1:%.*]] = shufflevector <4 x i32> [[Y]], <4 x i32> poison, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
-; CHECK-NEXT:    [[TMP2:%.*]] = add <4 x i32> [[TMP1]], [[SHIFT1]]
-; CHECK-NEXT:    [[SHIFT2:%.*]] = shufflevector <4 x i32> [[Y]], <4 x i32> poison, <4 x i32> <i32 2, i32 undef, i32 undef, i32 undef>
-; CHECK-NEXT:    [[TMP3:%.*]] = add <4 x i32> [[TMP2]], [[SHIFT2]]
-; CHECK-NEXT:    [[X2Y210:%.*]] = extractelement <4 x i32> [[TMP3]], i32 0
-; CHECK-NEXT:    ret i32 [[X2Y210]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> [[Y:%.*]], <4 x i32> <i32 4, i32 2, i32 5, i32 6>
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %y0 = extractelement <4 x i32> %y, i32 0
   %y1 = extractelement <4 x i32> %y, i32 1

diff --git a/llvm/test/Transforms/SLPVectorizer/X86/horizontal-list.ll b/llvm/test/Transforms/SLPVectorizer/X86/horizontal-list.ll
index caa960eeeb45..58e40177f54e 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/horizontal-list.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/horizontal-list.ll
@@ -17,23 +17,25 @@ define float @baz() {
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <2 x float>, <2 x float>* bitcast ([20 x float]* @arr1 to <2 x float>*), align 16
 ; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast <2 x float> [[TMP2]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x float> [[TMP3]], i32 0
-; CHECK-NEXT:    [[ADD:%.*]] = fadd fast float [[TMP4]], [[CONV]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x float> [[TMP3]], i32 1
-; CHECK-NEXT:    [[ADD_1:%.*]] = fadd fast float [[TMP5]], [[ADD]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = load <2 x float>, <2 x float>* bitcast (float* getelementptr inbounds ([20 x float], [20 x float]* @arr, i64 0, i64 2) to <2 x float>*), align 8
 ; CHECK-NEXT:    [[TMP7:%.*]] = load <2 x float>, <2 x float>* bitcast (float* getelementptr inbounds ([20 x float], [20 x float]* @arr1, i64 0, i64 2) to <2 x float>*), align 8
 ; CHECK-NEXT:    [[TMP8:%.*]] = fmul fast <2 x float> [[TMP7]], [[TMP6]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <2 x float> [[TMP8]], i32 0
-; CHECK-NEXT:    [[ADD_2:%.*]] = fadd fast float [[TMP9]], [[ADD_1]]
 ; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x float> [[TMP8]], i32 1
-; CHECK-NEXT:    [[ADD_3:%.*]] = fadd fast float [[TMP10]], [[ADD_2]]
-; CHECK-NEXT:    [[ADD7:%.*]] = fadd fast float [[ADD_3]], [[CONV]]
-; CHECK-NEXT:    [[ADD19:%.*]] = fadd fast float [[TMP4]], [[ADD7]]
-; CHECK-NEXT:    [[ADD19_1:%.*]] = fadd fast float [[TMP5]], [[ADD19]]
-; CHECK-NEXT:    [[ADD19_2:%.*]] = fadd fast float [[TMP9]], [[ADD19_1]]
-; CHECK-NEXT:    [[ADD19_3:%.*]] = fadd fast float [[TMP10]], [[ADD19_2]]
-; CHECK-NEXT:    store float [[ADD19_3]], float* @res, align 4
-; CHECK-NEXT:    ret float [[ADD19_3]]
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <8 x float> poison, float [[TMP10]], i32 0
+; CHECK-NEXT:    [[TMP12:%.*]] = insertelement <8 x float> [[TMP11]], float [[TMP9]], i32 1
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <8 x float> [[TMP12]], float [[TMP5]], i32 2
+; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <8 x float> [[TMP13]], float [[TMP4]], i32 3
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <8 x float> [[TMP14]], float [[TMP10]], i32 4
+; CHECK-NEXT:    [[TMP16:%.*]] = insertelement <8 x float> [[TMP15]], float [[TMP9]], i32 5
+; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <8 x float> [[TMP16]], float [[TMP5]], i32 6
+; CHECK-NEXT:    [[TMP18:%.*]] = insertelement <8 x float> [[TMP17]], float [[TMP4]], i32 7
+; CHECK-NEXT:    [[TMP19:%.*]] = call fast float @llvm.vector.reduce.fadd.v8f32(float -0.000000e+00, <8 x float> [[TMP18]])
+; CHECK-NEXT:    [[OP_EXTRA:%.*]] = fadd fast float [[TMP19]], [[CONV]]
+; CHECK-NEXT:    [[OP_EXTRA1:%.*]] = fadd fast float [[OP_EXTRA]], [[CONV]]
+; CHECK-NEXT:    store float [[OP_EXTRA1]], float* @res, align 4
+; CHECK-NEXT:    ret float [[OP_EXTRA1]]
 ;
 ; THRESHOLD-LABEL: @baz(
 ; THRESHOLD-NEXT:  entry:
@@ -44,23 +46,25 @@ define float @baz() {
 ; THRESHOLD-NEXT:    [[TMP2:%.*]] = load <2 x float>, <2 x float>* bitcast ([20 x float]* @arr1 to <2 x float>*), align 16
 ; THRESHOLD-NEXT:    [[TMP3:%.*]] = fmul fast <2 x float> [[TMP2]], [[TMP1]]
 ; THRESHOLD-NEXT:    [[TMP4:%.*]] = extractelement <2 x float> [[TMP3]], i32 0
-; THRESHOLD-NEXT:    [[ADD:%.*]] = fadd fast float [[TMP4]], [[CONV]]
 ; THRESHOLD-NEXT:    [[TMP5:%.*]] = extractelement <2 x float> [[TMP3]], i32 1
-; THRESHOLD-NEXT:    [[ADD_1:%.*]] = fadd fast float [[TMP5]], [[ADD]]
 ; THRESHOLD-NEXT:    [[TMP6:%.*]] = load <2 x float>, <2 x float>* bitcast (float* getelementptr inbounds ([20 x float], [20 x float]* @arr, i64 0, i64 2) to <2 x float>*), align 8
 ; THRESHOLD-NEXT:    [[TMP7:%.*]] = load <2 x float>, <2 x float>* bitcast (float* getelementptr inbounds ([20 x float], [20 x float]* @arr1, i64 0, i64 2) to <2 x float>*), align 8
 ; THRESHOLD-NEXT:    [[TMP8:%.*]] = fmul fast <2 x float> [[TMP7]], [[TMP6]]
 ; THRESHOLD-NEXT:    [[TMP9:%.*]] = extractelement <2 x float> [[TMP8]], i32 0
-; THRESHOLD-NEXT:    [[ADD_2:%.*]] = fadd fast float [[TMP9]], [[ADD_1]]
 ; THRESHOLD-NEXT:    [[TMP10:%.*]] = extractelement <2 x float> [[TMP8]], i32 1
-; THRESHOLD-NEXT:    [[ADD_3:%.*]] = fadd fast float [[TMP10]], [[ADD_2]]
-; THRESHOLD-NEXT:    [[ADD7:%.*]] = fadd fast float [[ADD_3]], [[CONV]]
-; THRESHOLD-NEXT:    [[ADD19:%.*]] = fadd fast float [[TMP4]], [[ADD7]]
-; THRESHOLD-NEXT:    [[ADD19_1:%.*]] = fadd fast float [[TMP5]], [[ADD19]]
-; THRESHOLD-NEXT:    [[ADD19_2:%.*]] = fadd fast float [[TMP9]], [[ADD19_1]]
-; THRESHOLD-NEXT:    [[ADD19_3:%.*]] = fadd fast float [[TMP10]], [[ADD19_2]]
-; THRESHOLD-NEXT:    store float [[ADD19_3]], float* @res, align 4
-; THRESHOLD-NEXT:    ret float [[ADD19_3]]
+; THRESHOLD-NEXT:    [[TMP11:%.*]] = insertelement <8 x float> poison, float [[TMP10]], i32 0
+; THRESHOLD-NEXT:    [[TMP12:%.*]] = insertelement <8 x float> [[TMP11]], float [[TMP9]], i32 1
+; THRESHOLD-NEXT:    [[TMP13:%.*]] = insertelement <8 x float> [[TMP12]], float [[TMP5]], i32 2
+; THRESHOLD-NEXT:    [[TMP14:%.*]] = insertelement <8 x float> [[TMP13]], float [[TMP4]], i32 3
+; THRESHOLD-NEXT:    [[TMP15:%.*]] = insertelement <8 x float> [[TMP14]], float [[TMP10]], i32 4
+; THRESHOLD-NEXT:    [[TMP16:%.*]] = insertelement <8 x float> [[TMP15]], float [[TMP9]], i32 5
+; THRESHOLD-NEXT:    [[TMP17:%.*]] = insertelement <8 x float> [[TMP16]], float [[TMP5]], i32 6
+; THRESHOLD-NEXT:    [[TMP18:%.*]] = insertelement <8 x float> [[TMP17]], float [[TMP4]], i32 7
+; THRESHOLD-NEXT:    [[TMP19:%.*]] = call fast float @llvm.vector.reduce.fadd.v8f32(float -0.000000e+00, <8 x float> [[TMP18]])
+; THRESHOLD-NEXT:    [[OP_EXTRA:%.*]] = fadd fast float [[TMP19]], [[CONV]]
+; THRESHOLD-NEXT:    [[OP_EXTRA1:%.*]] = fadd fast float [[OP_EXTRA]], [[CONV]]
+; THRESHOLD-NEXT:    store float [[OP_EXTRA1]], float* @res, align 4
+; THRESHOLD-NEXT:    ret float [[OP_EXTRA1]]
 ;
 entry:
   %0 = load i32, i32* @n, align 4


        


More information about the llvm-commits mailing list