[llvm] 7ff5770 - [SLP] allow forming 2-way reduction patterns

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 7 04:10:57 PST 2019


Author: Sanjay Patel
Date: 2019-11-07T06:08:42-05:00
New Revision: 7ff57705ba196ce649d6034614b3b9df57e1f84f

URL: https://github.com/llvm/llvm-project/commit/7ff57705ba196ce649d6034614b3b9df57e1f84f
DIFF: https://github.com/llvm/llvm-project/commit/7ff57705ba196ce649d6034614b3b9df57e1f84f.diff

LOG: [SLP] allow forming 2-way reduction patterns

We have a vector compare reduction problem seen in PR39665 comment 2:
https://bugs.llvm.org/show_bug.cgi?id=39665#c2

Or slightly reduced here:

define i1 @cmp2(<2 x double> %a0) {
  %a = fcmp ogt <2 x double> %a0, <double 1.0, double 1.0>
  %b = extractelement <2 x i1> %a, i32 0
  %c = extractelement <2 x i1> %a, i32 1
  %d = and i1 %b, %c
  ret i1 %d
}

SLP would not attempt to turn this into a vector reduction because there is an
artificial lower limit on that transform. We cannot completely remove that limit
without inducing regressions though, so this patch just hacks an extra attempt at
creating a 2-way reduction onto the end of the analysis.

As shown in the test file, we are still not getting some of the motivating cases,
so follow-on patches will be needed to solve those cases.

Differential Revision: https://reviews.llvm.org/D59710

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Feature/weak_constant.ll
    llvm/test/Transforms/SLPVectorizer/X86/reduction2.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h b/llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h
index 32ccc8a46380..d06de6cd4e3d 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h
@@ -113,9 +113,12 @@ struct SLPVectorizerPass : public PassInfoMixin<SLPVectorizerPass> {
 
   /// Try to find horizontal reduction or otherwise vectorize a chain of binary
   /// operators.
+  /// \p Try2WayRdx specializes the analysis to only attempt a 2-element
+  /// reduction.
   bool vectorizeRootInstruction(PHINode *P, Value *V, BasicBlock *BB,
                                 slpvectorizer::BoUpSLP &R,
-                                TargetTransformInfo *TTI);
+                                TargetTransformInfo *TTI,
+                                bool Try2WayRdx = false);
 
   /// Try to vectorize trees that start at insertvalue instructions.
   bool vectorizeInsertValueInst(InsertValueInst *IVI, BasicBlock *BB,

diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index c0c1a451e951..33e388a21c54 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6397,7 +6397,7 @@ class HorizontalReduction {
 
   /// Attempt to vectorize the tree found by
   /// matchAssociativeReduction.
-  bool tryToReduce(BoUpSLP &V, TargetTransformInfo *TTI) {
+  bool tryToReduce(BoUpSLP &V, TargetTransformInfo *TTI, bool Try2WayRdx) {
     if (ReducedVals.empty())
       return false;
 
@@ -6405,11 +6405,14 @@ class HorizontalReduction {
     // to a nearby power-of-2. Can safely generate oversized
     // vectors and rely on the backend to split them to legal sizes.
     unsigned NumReducedVals = ReducedVals.size();
-    if (NumReducedVals < 4)
+    if (Try2WayRdx && NumReducedVals != 2)
+      return false;
+    unsigned MinRdxVals = Try2WayRdx ? 2 : 4;
+    if (NumReducedVals < MinRdxVals)
       return false;
 
     unsigned ReduxWidth = PowerOf2Floor(NumReducedVals);
-
+    unsigned MinRdxWidth = Log2_32(MinRdxVals);
     Value *VectorizedTree = nullptr;
 
     // FIXME: Fast-math-flags should be set based on the instructions in the
@@ -6433,7 +6436,7 @@ class HorizontalReduction {
     SmallVector<Value *, 16> IgnoreList;
     for (auto &V : ReductionOps)
       IgnoreList.append(V.begin(), V.end());
-    while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) {
+    while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > MinRdxWidth) {
       auto VL = makeArrayRef(&ReducedVals[i], ReduxWidth);
       V.buildTree(VL, ExternallyUsedValues, IgnoreList);
       Optional<ArrayRef<unsigned>> Order = V.bestOrder();
@@ -6759,7 +6762,7 @@ static Value *getReductionValue(const DominatorTree *DT, PHINode *P,
 /// performed.
 static bool tryToVectorizeHorReductionOrInstOperands(
     PHINode *P, Instruction *Root, BasicBlock *BB, BoUpSLP &R,
-    TargetTransformInfo *TTI,
+    TargetTransformInfo *TTI, bool Try2WayRdx,
     const function_ref<bool(Instruction *, BoUpSLP &)> Vectorize) {
   if (!ShouldVectorizeHor)
     return false;
@@ -6790,7 +6793,7 @@ static bool tryToVectorizeHorReductionOrInstOperands(
     if (BI || SI) {
       HorizontalReduction HorRdx;
       if (HorRdx.matchAssociativeReduction(P, Inst)) {
-        if (HorRdx.tryToReduce(R, TTI)) {
+        if (HorRdx.tryToReduce(R, TTI, Try2WayRdx)) {
           Res = true;
           // Set P to nullptr to avoid re-analysis of phi node in
           // matchAssociativeReduction function unless this is the root node.
@@ -6833,7 +6836,8 @@ static bool tryToVectorizeHorReductionOrInstOperands(
 
 bool SLPVectorizerPass::vectorizeRootInstruction(PHINode *P, Value *V,
                                                  BasicBlock *BB, BoUpSLP &R,
-                                                 TargetTransformInfo *TTI) {
+                                                 TargetTransformInfo *TTI,
+                                                 bool Try2WayRdx) {
   if (!V)
     return false;
   auto *I = dyn_cast<Instruction>(V);
@@ -6846,7 +6850,7 @@ bool SLPVectorizerPass::vectorizeRootInstruction(PHINode *P, Value *V,
   auto &&ExtraVectorization = [this](Instruction *I, BoUpSLP &R) -> bool {
     return tryToVectorize(I, R);
   };
-  return tryToVectorizeHorReductionOrInstOperands(P, I, BB, R, TTI,
+  return tryToVectorizeHorReductionOrInstOperands(P, I, BB, R, TTI, Try2WayRdx,
                                                   ExtraVectorization);
 }
 
@@ -7042,6 +7046,23 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
       PostProcessInstructions.push_back(&*it);
   }
 
+  // Make a final attempt to match a 2-way reduction if nothing else worked.
+  // We do not try this above because it may interfere with other vectorization
+  // attempts.
+  // TODO: The constraints are copied from the above call to
+  //       vectorizeRootInstruction(), but that might be too restrictive?
+  BasicBlock::iterator LastInst = --BB->end();
+  if (!Changed && LastInst->use_empty() &&
+      (LastInst->getType()->isVoidTy() || isa<CallInst>(LastInst) ||
+       isa<InvokeInst>(LastInst))) {
+    if (ShouldStartVectorizeHorAtStore || !isa<StoreInst>(LastInst)) {
+      for (auto *V : LastInst->operand_values()) {
+        Changed |= vectorizeRootInstruction(nullptr, V, BB, R, TTI,
+                                            /* Try2WayRdx */ true);
+      }
+    }
+  }
+
   return Changed;
 }
 

diff  --git a/llvm/test/Feature/weak_constant.ll b/llvm/test/Feature/weak_constant.ll
index 4ac2e7e7d689..9a2ea126ebc2 100644
--- a/llvm/test/Feature/weak_constant.ll
+++ b/llvm/test/Feature/weak_constant.ll
@@ -1,5 +1,5 @@
 ; RUN: opt < %s -O3 -S > %t
-; RUN:   grep undef %t | count 1
+; RUN:   grep undef %t | count 2
 ; RUN:   grep 5 %t | count 1
 ; RUN:   grep 7 %t | count 1
 ; RUN:   grep 9 %t | count 1

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/reduction2.ll b/llvm/test/Transforms/SLPVectorizer/X86/reduction2.ll
index b5f433549274..fef9a8e50cdf 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reduction2.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reduction2.ll
@@ -54,10 +54,10 @@ define double @foo(double* nocapture %D) {
 define i1 @two_wide_fcmp_reduction(<2 x double> %a0) {
 ; CHECK-LABEL: @two_wide_fcmp_reduction(
 ; CHECK-NEXT:    [[A:%.*]] = fcmp ogt <2 x double> [[A0:%.*]], <double 1.000000e+00, double 1.000000e+00>
-; CHECK-NEXT:    [[B:%.*]] = extractelement <2 x i1> [[A]], i32 0
-; CHECK-NEXT:    [[C:%.*]] = extractelement <2 x i1> [[A]], i32 1
-; CHECK-NEXT:    [[D:%.*]] = and i1 [[B]], [[C]]
-; CHECK-NEXT:    ret i1 [[D]]
+; CHECK-NEXT:    [[RDX_SHUF:%.*]] = shufflevector <2 x i1> [[A]], <2 x i1> undef, <2 x i32> <i32 1, i32 undef>
+; CHECK-NEXT:    [[BIN_RDX:%.*]] = and <2 x i1> [[A]], [[RDX_SHUF]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i1> [[BIN_RDX]], i32 0
+; CHECK-NEXT:    ret i1 [[TMP1]]
 ;
   %a = fcmp ogt <2 x double> %a0, <double 1.0, double 1.0>
   %b = extractelement <2 x i1> %a, i32 0
@@ -96,12 +96,11 @@ define i1 @fcmp_lt_gt(double %a, double %b, double %c) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x double> undef, double [[MUL]], i32 0
 ; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x double> [[TMP5]], double [[MUL]], i32 1
 ; CHECK-NEXT:    [[TMP7:%.*]] = fdiv <2 x double> [[TMP4]], [[TMP6]]
-; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[TMP7]], i32 1
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp olt double [[TMP8]], 0x3EB0C6F7A0B5ED8D
-; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <2 x double> [[TMP7]], i32 0
-; CHECK-NEXT:    [[CMP4:%.*]] = fcmp olt double [[TMP9]], 0x3EB0C6F7A0B5ED8D
-; CHECK-NEXT:    [[OR_COND:%.*]] = and i1 [[CMP]], [[CMP4]]
-; CHECK-NEXT:    br i1 [[OR_COND]], label [[CLEANUP:%.*]], label [[LOR_LHS_FALSE:%.*]]
+; CHECK-NEXT:    [[TMP8:%.*]] = fcmp olt <2 x double> [[TMP7]], <double 0x3EB0C6F7A0B5ED8D, double 0x3EB0C6F7A0B5ED8D>
+; CHECK-NEXT:    [[RDX_SHUF:%.*]] = shufflevector <2 x i1> [[TMP8]], <2 x i1> undef, <2 x i32> <i32 1, i32 undef>
+; CHECK-NEXT:    [[BIN_RDX:%.*]] = and <2 x i1> [[TMP8]], [[RDX_SHUF]]
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <2 x i1> [[BIN_RDX]], i32 0
+; CHECK-NEXT:    br i1 [[TMP9]], label [[CLEANUP:%.*]], label [[LOR_LHS_FALSE:%.*]]
 ; CHECK:       lor.lhs.false:
 ; CHECK-NEXT:    [[TMP10:%.*]] = fcmp ule <2 x double> [[TMP7]], <double 1.000000e+00, double 1.000000e+00>
 ; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <2 x i1> [[TMP10]], i32 0


        


More information about the llvm-commits mailing list