[llvm] b765fdd - [SLP]Try to keep scalars, used in phi nodes, if phi nodes from same block are vectorized.

via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 21 12:23:50 PDT 2024


Author: Alexey Bataev
Date: 2024-08-21T15:23:47-04:00
New Revision: b765fdd997be9ff0afb6de87077cd53d5f3d349c

URL: https://github.com/llvm/llvm-project/commit/b765fdd997be9ff0afb6de87077cd53d5f3d349c
DIFF: https://github.com/llvm/llvm-project/commit/b765fdd997be9ff0afb6de87077cd53d5f3d349c.diff

LOG: [SLP]Try to keep scalars, used in phi nodes, if phi nodes from same block are vectorized.

Before vectorizing the PHI nodes, the compiler sorts them by the opcodes
of their operands. If a scalar is replaced by an extractelement during
vectorization, this sorting is broken, which prevents some further
vectorization attempts. This patch tries to improve this by doing extra
analysis of the scalars and keeping a scalar if it is found to be used by
another (external) PHI node in the same block.

Reviewers: RKSimon

Reviewed By: RKSimon

Pull Request: https://github.com/llvm/llvm-project/pull/103923

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/X86/phi.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 848e0de20e7b6c..8f70a43465b8ac 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -10930,8 +10930,31 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
       if (CanBeUsedAsScalar) {
         InstructionCost ScalarCost = TTI->getInstructionCost(Inst, CostKind);
         bool KeepScalar = ScalarCost <= ExtraCost;
-        if (KeepScalar && ScalarCost != TTI::TCC_Free &&
-            ExtraCost - ScalarCost <= TTI::TCC_Basic) {
+        // Try to keep original scalar if the user is the phi node from the same
+        // block as the root phis, currently vectorized. It allows to keep
+        // better ordering info of PHIs, being vectorized currently.
+        bool IsProfitablePHIUser =
+            (KeepScalar || (ScalarCost - ExtraCost <= TTI::TCC_Basic &&
+                            VectorizableTree.front()->Scalars.size() > 2)) &&
+            VectorizableTree.front()->getOpcode() == Instruction::PHI &&
+            !Inst->hasNUsesOrMore(UsesLimit) &&
+            none_of(Inst->users(),
+                    [&](User *U) {
+                      auto *PHIUser = dyn_cast<PHINode>(U);
+                      return (!PHIUser ||
+                              PHIUser->getParent() !=
+                                  cast<Instruction>(
+                                      VectorizableTree.front()->getMainOp())
+                                      ->getParent()) &&
+                             !getTreeEntry(U);
+                    }) &&
+            count_if(Entry->Scalars, [&](Value *V) {
+              return ValueToExtUses->contains(V);
+            }) <= 2;
+        if (IsProfitablePHIUser) {
+          KeepScalar = true;
+        } else if (KeepScalar && ScalarCost != TTI::TCC_Free &&
+                   ExtraCost - ScalarCost <= TTI::TCC_Basic) {
           unsigned ScalarUsesCount = count_if(Entry->Scalars, [&](Value *V) {
             return ValueToExtUses->contains(V);
           });

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/phi.ll b/llvm/test/Transforms/SLPVectorizer/X86/phi.ll
index 495a503311ab9e..96151e0bd6c418 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/phi.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/phi.ll
@@ -136,42 +136,41 @@ for.end:                                          ; preds = %for.body
 define float @foo3(ptr nocapture readonly %A) #0 {
 ; CHECK-LABEL: @foo3(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = load float, ptr [[A:%.*]], align 4
-; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds float, ptr [[A]], i64 1
+; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds float, ptr [[A:%.*]], i64 1
+; CHECK-NEXT:    [[TMP0:%.*]] = load <2 x float>, ptr [[A]], align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x float>, ptr [[ARRAYIDX1]], align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x float> [[TMP1]], <4 x float> poison, <2 x i32> <i32 poison, i32 0>
-; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x float> [[TMP2]], float [[TMP0]], i32 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x float> [[TMP0]], i32 0
 ; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[R_052:%.*]] = phi float [ [[TMP0]], [[ENTRY]] ], [ [[ADD6:%.*]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[TMP4:%.*]] = phi <4 x float> [ [[TMP1]], [[ENTRY]] ], [ [[TMP13:%.*]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[TMP5:%.*]] = phi <2 x float> [ [[TMP3]], [[ENTRY]] ], [ [[TMP9:%.*]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x float> [[TMP5]], i32 0
-; CHECK-NEXT:    [[MUL:%.*]] = fmul float [[TMP6]], 7.000000e+00
+; CHECK-NEXT:    [[R_052:%.*]] = phi float [ [[TMP2]], [[ENTRY]] ], [ [[ADD6:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = phi <4 x float> [ [[TMP1]], [[ENTRY]] ], [ [[TMP12:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[TMP4:%.*]] = phi <2 x float> [ [[TMP0]], [[ENTRY]] ], [ [[TMP8:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x float> [[TMP4]], i32 0
+; CHECK-NEXT:    [[MUL:%.*]] = fmul float [[TMP5]], 7.000000e+00
 ; CHECK-NEXT:    [[ADD6]] = fadd float [[R_052]], [[MUL]]
-; CHECK-NEXT:    [[TMP7:%.*]] = add nsw i64 [[INDVARS_IV]], 2
-; CHECK-NEXT:    [[ARRAYIDX14:%.*]] = getelementptr inbounds float, ptr [[A]], i64 [[TMP7]]
-; CHECK-NEXT:    [[TMP8:%.*]] = load float, ptr [[ARRAYIDX14]], align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = add nsw i64 [[INDVARS_IV]], 2
+; CHECK-NEXT:    [[ARRAYIDX14:%.*]] = getelementptr inbounds float, ptr [[A]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP7:%.*]] = load float, ptr [[ARRAYIDX14]], align 4
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 3
 ; CHECK-NEXT:    [[ARRAYIDX19:%.*]] = getelementptr inbounds float, ptr [[A]], i64 [[INDVARS_IV_NEXT]]
-; CHECK-NEXT:    [[TMP9]] = load <2 x float>, ptr [[ARRAYIDX19]], align 4
-; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x float> [[TMP5]], <2 x float> [[TMP9]], <4 x i32> <i32 1, i32 poison, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x float> [[TMP10]], float [[TMP8]], i32 1
-; CHECK-NEXT:    [[TMP12:%.*]] = fmul <4 x float> [[TMP11]], <float 8.000000e+00, float 9.000000e+00, float 1.000000e+01, float 1.100000e+01>
-; CHECK-NEXT:    [[TMP13]] = fadd <4 x float> [[TMP4]], [[TMP12]]
-; CHECK-NEXT:    [[TMP14:%.*]] = trunc i64 [[INDVARS_IV_NEXT]] to i32
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[TMP14]], 121
+; CHECK-NEXT:    [[TMP8]] = load <2 x float>, ptr [[ARRAYIDX19]], align 4
+; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x float> [[TMP4]], <2 x float> [[TMP8]], <4 x i32> <i32 1, i32 poison, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <4 x float> [[TMP9]], float [[TMP7]], i32 1
+; CHECK-NEXT:    [[TMP11:%.*]] = fmul <4 x float> [[TMP10]], <float 8.000000e+00, float 9.000000e+00, float 1.000000e+01, float 1.100000e+01>
+; CHECK-NEXT:    [[TMP12]] = fadd <4 x float> [[TMP3]], [[TMP11]]
+; CHECK-NEXT:    [[TMP13:%.*]] = trunc i64 [[INDVARS_IV_NEXT]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[TMP13]], 121
 ; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END:%.*]]
 ; CHECK:       for.end:
-; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <4 x float> [[TMP13]], i32 0
-; CHECK-NEXT:    [[ADD28:%.*]] = fadd float [[ADD6]], [[TMP15]]
-; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <4 x float> [[TMP13]], i32 1
-; CHECK-NEXT:    [[ADD29:%.*]] = fadd float [[ADD28]], [[TMP16]]
-; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <4 x float> [[TMP13]], i32 2
-; CHECK-NEXT:    [[ADD30:%.*]] = fadd float [[ADD29]], [[TMP17]]
-; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <4 x float> [[TMP13]], i32 3
-; CHECK-NEXT:    [[ADD31:%.*]] = fadd float [[ADD30]], [[TMP18]]
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <4 x float> [[TMP12]], i32 0
+; CHECK-NEXT:    [[ADD28:%.*]] = fadd float [[ADD6]], [[TMP14]]
+; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <4 x float> [[TMP12]], i32 1
+; CHECK-NEXT:    [[ADD29:%.*]] = fadd float [[ADD28]], [[TMP15]]
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <4 x float> [[TMP12]], i32 2
+; CHECK-NEXT:    [[ADD30:%.*]] = fadd float [[ADD29]], [[TMP16]]
+; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <4 x float> [[TMP12]], i32 3
+; CHECK-NEXT:    [[ADD31:%.*]] = fadd float [[ADD30]], [[TMP17]]
 ; CHECK-NEXT:    ret float [[ADD31]]
 ;
 entry:


        


More information about the llvm-commits mailing list