[llvm] 03783f1 - [SLP] sort candidates to increase chance of optimal compare reduction

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 17 05:49:34 PDT 2020


Author: Sanjay Patel
Date: 2020-09-17T08:49:27-04:00
New Revision: 03783f19dc78fc45fd987f892c314578b5e52d78

URL: https://github.com/llvm/llvm-project/commit/03783f19dc78fc45fd987f892c314578b5e52d78
DIFF: https://github.com/llvm/llvm-project/commit/03783f19dc78fc45fd987f892c314578b5e52d78.diff

LOG: [SLP] sort candidates to increase chance of optimal compare reduction

This is one (small) part of improving PR41312:
https://llvm.org/PR41312

As shown there and in the smaller tests here, if one member of the reduction
values does not match the others (for example, it uses a different compare
predicate), we want to push it to the end so that the matching members are
brought forward and grouped together.

In the regression tests, we have 5 candidates for the 4 slots of the reduction.
If the one "wrong" compare is grouped with the others, it prevents forming the
ideal v4i1 compare reduction.

Differential Revision: https://reviews.llvm.org/D87772

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 3d19e867b6c2..c487301177c1 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6838,9 +6838,37 @@ class HorizontalReduction {
     for (ReductionOpsType &RdxOp : ReductionOps)
       IgnoreList.append(RdxOp.begin(), RdxOp.end());
 
+    unsigned ReduxWidth = PowerOf2Floor(NumReducedVals);
+    if (NumReducedVals > ReduxWidth) {
+      // In the loop below, we are building a tree based on a window of
+      // 'ReduxWidth' values.
+      // If the operands of those values have common traits (compare predicate,
+      // constant operand, etc), then we want to group those together to
+      // minimize the cost of the reduction.
+
+      // TODO: This should be extended to count common operands for
+      //       compares and binops.
+
+      // Step 1: Count the number of times each compare predicate occurs.
+      SmallDenseMap<unsigned, unsigned> PredCountMap;
+      for (Value *RdxVal : ReducedVals) {
+        CmpInst::Predicate Pred;
+        if (match(RdxVal, m_Cmp(Pred, m_Value(), m_Value())))
+          ++PredCountMap[Pred];
+      }
+      // Step 2: Sort the values so the most common predicates come first.
+      stable_sort(ReducedVals, [&PredCountMap](Value *A, Value *B) {
+        CmpInst::Predicate PredA, PredB;
+        if (match(A, m_Cmp(PredA, m_Value(), m_Value())) &&
+            match(B, m_Cmp(PredB, m_Value(), m_Value()))) {
+          return PredCountMap[PredA] > PredCountMap[PredB];
+        }
+        return false;
+      });
+    }
+
     Value *VectorizedTree = nullptr;
     unsigned i = 0;
-    unsigned ReduxWidth = PowerOf2Floor(NumReducedVals);
     while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) {
       ArrayRef<Value *> VL = makeArrayRef(&ReducedVals[i], ReduxWidth);
       V.buildTree(VL, ExternallyUsedValues, IgnoreList);

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll b/llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll
index daa96bfa84ae..b0971dd80450 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll
@@ -81,20 +81,12 @@ declare i32 @printf(i8* nocapture, ...)
 
 define float @merge_anyof_v4f32_wrong_first(<4 x float> %x) {
 ; CHECK-LABEL: @merge_anyof_v4f32_wrong_first(
-; CHECK-NEXT:    [[X0:%.*]] = extractelement <4 x float> [[X:%.*]], i32 0
-; CHECK-NEXT:    [[X1:%.*]] = extractelement <4 x float> [[X]], i32 1
-; CHECK-NEXT:    [[X2:%.*]] = extractelement <4 x float> [[X]], i32 2
-; CHECK-NEXT:    [[X3:%.*]] = extractelement <4 x float> [[X]], i32 3
-; CHECK-NEXT:    [[CMP3WRONG:%.*]] = fcmp olt float [[X3]], 4.200000e+01
-; CHECK-NEXT:    [[CMP0:%.*]] = fcmp ogt float [[X0]], 1.000000e+00
-; CHECK-NEXT:    [[CMP1:%.*]] = fcmp ogt float [[X1]], 1.000000e+00
-; CHECK-NEXT:    [[CMP2:%.*]] = fcmp ogt float [[X2]], 1.000000e+00
-; CHECK-NEXT:    [[CMP3:%.*]] = fcmp ogt float [[X3]], 1.000000e+00
-; CHECK-NEXT:    [[OR03:%.*]] = or i1 [[CMP0]], [[CMP3WRONG]]
-; CHECK-NEXT:    [[OR031:%.*]] = or i1 [[OR03]], [[CMP1]]
-; CHECK-NEXT:    [[OR0312:%.*]] = or i1 [[OR031]], [[CMP2]]
-; CHECK-NEXT:    [[OR03123:%.*]] = or i1 [[OR0312]], [[CMP3]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[OR03123]], float -1.000000e+00, float 1.000000e+00
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> [[X:%.*]], i32 3
+; CHECK-NEXT:    [[CMP3WRONG:%.*]] = fcmp olt float [[TMP1]], 4.200000e+01
+; CHECK-NEXT:    [[TMP2:%.*]] = fcmp ogt <4 x float> [[X]], <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>
+; CHECK-NEXT:    [[TMP3:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP4]], float -1.000000e+00, float 1.000000e+00
 ; CHECK-NEXT:    ret float [[R]]
 ;
   %x0 = extractelement <4 x float> %x, i32 0
@@ -143,20 +135,12 @@ define float @merge_anyof_v4f32_wrong_last(<4 x float> %x) {
 
 define i32 @merge_anyof_v4i32_wrong_middle(<4 x i32> %x) {
 ; CHECK-LABEL: @merge_anyof_v4i32_wrong_middle(
-; CHECK-NEXT:    [[X0:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 0
-; CHECK-NEXT:    [[X1:%.*]] = extractelement <4 x i32> [[X]], i32 1
-; CHECK-NEXT:    [[X2:%.*]] = extractelement <4 x i32> [[X]], i32 2
-; CHECK-NEXT:    [[X3:%.*]] = extractelement <4 x i32> [[X]], i32 3
-; CHECK-NEXT:    [[CMP3WRONG:%.*]] = icmp slt i32 [[X3]], 42
-; CHECK-NEXT:    [[CMP0:%.*]] = icmp sgt i32 [[X0]], 1
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp sgt i32 [[X1]], 1
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i32 [[X2]], 1
-; CHECK-NEXT:    [[CMP3:%.*]] = icmp sgt i32 [[X3]], 1
-; CHECK-NEXT:    [[OR03:%.*]] = or i1 [[CMP0]], [[CMP3]]
-; CHECK-NEXT:    [[OR033:%.*]] = or i1 [[OR03]], [[CMP3WRONG]]
-; CHECK-NEXT:    [[OR0332:%.*]] = or i1 [[OR033]], [[CMP2]]
-; CHECK-NEXT:    [[OR03321:%.*]] = or i1 [[OR0332]], [[CMP1]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[OR03321]], i32 -1, i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 3
+; CHECK-NEXT:    [[CMP3WRONG:%.*]] = icmp slt i32 [[TMP1]], 42
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp sgt <4 x i32> [[X]], <i32 1, i32 1, i32 1, i32 1>
+; CHECK-NEXT:    [[TMP3:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP4]], i32 -1, i32 1
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %x0 = extractelement <4 x i32> %x, i32 0
@@ -176,29 +160,18 @@ define i32 @merge_anyof_v4i32_wrong_middle(<4 x i32> %x) {
   ret i32 %r
 }
 
+; Operand/predicate swapping allows forming a reduction, but the
+; ideal reduction groups all of the original 'sgt' ops together.
+
 define i32 @merge_anyof_v4i32_wrong_middle_better_rdx(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: @merge_anyof_v4i32_wrong_middle_better_rdx(
-; CHECK-NEXT:    [[X0:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 0
-; CHECK-NEXT:    [[X1:%.*]] = extractelement <4 x i32> [[X]], i32 1
-; CHECK-NEXT:    [[X2:%.*]] = extractelement <4 x i32> [[X]], i32 2
-; CHECK-NEXT:    [[X3:%.*]] = extractelement <4 x i32> [[X]], i32 3
-; CHECK-NEXT:    [[Y0:%.*]] = extractelement <4 x i32> [[Y:%.*]], i32 0
-; CHECK-NEXT:    [[Y1:%.*]] = extractelement <4 x i32> [[Y]], i32 1
-; CHECK-NEXT:    [[Y2:%.*]] = extractelement <4 x i32> [[Y]], i32 2
-; CHECK-NEXT:    [[Y3:%.*]] = extractelement <4 x i32> [[Y]], i32 3
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp sgt i32 [[X1]], [[Y1]]
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x i32> undef, i32 [[X0]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <4 x i32> [[TMP1]], i32 [[X3]], i32 1
-; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x i32> [[TMP2]], i32 [[Y3]], i32 2
-; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x i32> [[TMP3]], i32 [[X2]], i32 3
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x i32> undef, i32 [[Y0]], i32 0
-; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <4 x i32> [[TMP5]], i32 [[Y3]], i32 1
-; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x i32> [[TMP6]], i32 [[X3]], i32 2
-; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <4 x i32> [[TMP7]], i32 [[Y2]], i32 3
-; CHECK-NEXT:    [[TMP9:%.*]] = icmp sgt <4 x i32> [[TMP4]], [[TMP8]]
-; CHECK-NEXT:    [[TMP10:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP9]])
-; CHECK-NEXT:    [[TMP11:%.*]] = or i1 [[TMP10]], [[CMP1]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP11]], i32 -1, i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x i32> [[Y:%.*]], i32 3
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 3
+; CHECK-NEXT:    [[CMP3WRONG:%.*]] = icmp slt i32 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp sgt <4 x i32> [[X]], [[Y]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP3]])
+; CHECK-NEXT:    [[TMP5:%.*]] = or i1 [[TMP4]], [[CMP3WRONG]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP5]], i32 -1, i32 1
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %x0 = extractelement <4 x i32> %x, i32 0


        


More information about the llvm-commits mailing list