[llvm] 81d9ed6 - [SLP] Do extra analysis in minbitwidth if some checks return false.

Alexey Bataev via llvm-commits <llvm-commits at lists.llvm.org>
Wed Mar 20 05:52:47 PDT 2024


Author: Alexey Bataev
Date: 2024-03-20T05:48:55-07:00
New Revision: 81d9ed605b126d163f2f1cc226c639f1dfdc5224

URL: https://github.com/llvm/llvm-project/commit/81d9ed605b126d163f2f1cc226c639f1dfdc5224
DIFF: https://github.com/llvm/llvm-project/commit/81d9ed605b126d163f2f1cc226c639f1dfdc5224.diff

LOG: [SLP] Do extra analysis in minbitwidth if some checks return false.

The instruction itself can still be a good candidate for minbitwidth
casting, even if the check for one of its operands returns false.

Reviewers: RKSimon

Reviewed By: RKSimon

Pull Request: https://github.com/llvm/llvm-project/pull/84363
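
As a rough illustration of the change (a simplified sketch, not the committed
code): the patch restructures BoUpSLP::collectValuesToDemote() around a new
FinalAnalysis fallback, so a node can still be accepted for demotion when a
structural check or the analysis of one of its operands fails. In the sketch
below, the Node struct and its fields are hypothetical stand-ins for
llvm::Value/Instruction and BoUpSLP::TreeEntry; only the control-flow change
is shown, not the real per-opcode handling.

#include <algorithm>
#include <vector>

// Hypothetical stand-in for the vectorizer's per-value state; the real
// code works on llvm::Value/Instruction and BoUpSLP::TreeEntry.
struct Node {
  bool InTree = false;          // value has a vectorized tree entry
  unsigned NumTreeUsers = 1;    // TreeEntry::UserTreeIndices.size()
  unsigned OrigBitWidth = 32;   // bit width of the scalar type
  unsigned NeededBits = 8;      // bits the value actually needs
  std::vector<Node *> Operands;
};

// Simplified shape of the reworked BoUpSLP::collectValuesToDemote().
bool collectValuesToDemote(Node *V, unsigned &BitWidth,
                           std::vector<Node *> &ToDemote,
                           bool IsProfitableToDemote) {
  auto IsPotentiallyTruncated = [&](Node *N) {
    unsigned Bits = std::max(BitWidth, N->NeededBits);
    if (Bits == 0 || N->OrigBitWidth < Bits * 2)
      return false;
    BitWidth = Bits;
    return true;
  };
  // New in this patch: a fallback run when a check fails, so the node
  // itself can still be accepted for demotion on its own merits.
  auto FinalAnalysis = [&](Node *Entry = nullptr) {
    if (!IsProfitableToDemote)
      return false;
    return (Entry && Entry->NumTreeUsers > 1) || IsPotentiallyTruncated(V);
  };
  if (!V->InTree)
    return FinalAnalysis(); // previously: return false;
  if (V->NumTreeUsers > 1 && !IsPotentiallyTruncated(V))
    return false;
  // New in this patch: a failing operand no longer rejects the node
  // outright; if demotion still looks profitable, keep the node but stop
  // descending further.
  bool NeedToExit = false;
  for (Node *Op : V->Operands) {
    if (!collectValuesToDemote(Op, BitWidth, ToDemote,
                               IsProfitableToDemote)) {
      if (!IsProfitableToDemote || !FinalAnalysis(V))
        return false;
      NeedToExit = true;
    }
  }
  if (NeedToExit)
    return true;
  ToDemote.push_back(V);
  return true;
}

The effect shows up in the AArch64 test update below: the sub/select chain is
now demoted to <8 x i16> with explicit sext back to i32 around the reduction,
and the reported reduction cost improves from -36 to -41.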

Added: 
    llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-node-but-not-operands.ll

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index e3711fe63efe3e..b0cb3cd47dd793 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -12225,7 +12225,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         return E->VectorizedValue;
       }
       if (True->getType() != VecTy || False->getType() != VecTy) {
-        assert((getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
+        assert((It != MinBWs.end() ||
+                getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
                 getOperandEntry(E, 2)->State == TreeEntry::NeedToGather ||
                 MinBWs.contains(getOperandEntry(E, 1)) ||
                 MinBWs.contains(getOperandEntry(E, 2))) &&
@@ -12297,7 +12298,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         return E->VectorizedValue;
       }
       if (LHS->getType() != VecTy || RHS->getType() != VecTy) {
-        assert((getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
+        assert((It != MinBWs.end() ||
+                getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
                 getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
                 MinBWs.contains(getOperandEntry(E, 0)) ||
                 MinBWs.contains(getOperandEntry(E, 1))) &&
@@ -12543,7 +12545,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
           ((Instruction::isBinaryOp(E->getOpcode()) &&
             (LHS->getType() != VecTy || RHS->getType() != VecTy)) ||
            (isa<CmpInst>(VL0) && LHS->getType() != RHS->getType()))) {
-        assert((getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
+        assert((It != MinBWs.end() ||
+                getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
                 getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
                 MinBWs.contains(getOperandEntry(E, 0)) ||
                 MinBWs.contains(getOperandEntry(E, 1))) &&
@@ -12559,9 +12562,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
           else
             CastTy = LHS->getType();
         }
-        if (LHS->getType() != VecTy)
+        if (LHS->getType() != CastTy)
           LHS = Builder.CreateIntCast(LHS, CastTy, GetOperandSignedness(0));
-        if (RHS->getType() != VecTy)
+        if (RHS->getType() != CastTy)
           RHS = Builder.CreateIntCast(RHS, CastTy, GetOperandSignedness(1));
       }
 
@@ -13988,15 +13991,6 @@ bool BoUpSLP::collectValuesToDemote(
   // If the value is not a vectorized instruction in the expression and not used
   // by the insertelement instruction and not used in multiple vector nodes, it
   // cannot be demoted.
-  // TODO: improve handling of gathered values and others.
-  auto *I = dyn_cast<Instruction>(V);
-  const TreeEntry *ITE = I ? getTreeEntry(I) : nullptr;
-  if (!ITE || !Visited.insert(I).second || MultiNodeScalars.contains(I) ||
-      all_of(I->users(), [&](User *U) {
-        return isa<InsertElementInst>(U) && !getTreeEntry(U);
-      }))
-    return false;
-
   auto IsPotentiallyTruncated = [&](Value *V, unsigned &BitWidth) -> bool {
     if (MultiNodeScalars.contains(V))
       return false;
@@ -14011,8 +14005,44 @@ bool BoUpSLP::collectValuesToDemote(
     BitWidth = std::max(BitWidth, BitWidth1);
     return BitWidth > 0 && OrigBitWidth >= (BitWidth * 2);
   };
+  auto FinalAnalysis = [&](const TreeEntry *ITE = nullptr) {
+    if (!IsProfitableToDemote)
+      return false;
+    return (ITE && ITE->UserTreeIndices.size() > 1) ||
+           IsPotentiallyTruncated(V, BitWidth);
+  };
+  // TODO: improve handling of gathered values and others.
+  auto *I = dyn_cast<Instruction>(V);
+  const TreeEntry *ITE = I ? getTreeEntry(I) : nullptr;
+  if (!ITE || !Visited.insert(I).second || MultiNodeScalars.contains(I) ||
+      all_of(I->users(), [&](User *U) {
+        return isa<InsertElementInst>(U) && !getTreeEntry(U);
+      }))
+    return FinalAnalysis();
+
   unsigned Start = 0;
   unsigned End = I->getNumOperands();
+
+  auto ProcessOperands = [&](ArrayRef<Value *> Operands, bool &NeedToExit) {
+    NeedToExit = false;
+    unsigned InitLevel = MaxDepthLevel;
+    for (Value *IncValue : Operands) {
+      unsigned Level = InitLevel;
+      if (!collectValuesToDemote(IncValue, IsProfitableToDemoteRoot, BitWidth,
+                                 ToDemote, DemotedConsts, Visited, Level,
+                                 IsProfitableToDemote, IsTruncRoot)) {
+        if (!IsProfitableToDemote)
+          return false;
+        NeedToExit = true;
+        if (!FinalAnalysis(ITE))
+          return false;
+        continue;
+      }
+      MaxDepthLevel = std::max(MaxDepthLevel, Level);
+    }
+    return true;
+  };
+  bool NeedToExit = false;
   switch (I->getOpcode()) {
 
   // We can always demote truncations and extensions. Since truncations can
@@ -14038,35 +14068,21 @@ bool BoUpSLP::collectValuesToDemote(
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor: {
-    unsigned Level1 = MaxDepthLevel, Level2 = MaxDepthLevel;
-    if ((ITE->UserTreeIndices.size() > 1 &&
-         !IsPotentiallyTruncated(I, BitWidth)) ||
-        !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
-                               BitWidth, ToDemote, DemotedConsts, Visited,
-                               Level1, IsProfitableToDemote, IsTruncRoot) ||
-        !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
-                               BitWidth, ToDemote, DemotedConsts, Visited,
-                               Level2, IsProfitableToDemote, IsTruncRoot))
+    if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
+      return false;
+    if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
       return false;
-    MaxDepthLevel = std::max(Level1, Level2);
     break;
   }
 
   // We can demote selects if we can demote their true and false values.
   case Instruction::Select: {
+    if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
+      return false;
     Start = 1;
-    unsigned Level1 = MaxDepthLevel, Level2 = MaxDepthLevel;
-    SelectInst *SI = cast<SelectInst>(I);
-    if ((ITE->UserTreeIndices.size() > 1 &&
-         !IsPotentiallyTruncated(I, BitWidth)) ||
-        !collectValuesToDemote(SI->getTrueValue(), IsProfitableToDemoteRoot,
-                               BitWidth, ToDemote, DemotedConsts, Visited,
-                               Level1, IsProfitableToDemote, IsTruncRoot) ||
-        !collectValuesToDemote(SI->getFalseValue(), IsProfitableToDemoteRoot,
-                               BitWidth, ToDemote, DemotedConsts, Visited,
-                               Level2, IsProfitableToDemote, IsTruncRoot))
+    auto *SI = cast<SelectInst>(I);
+    if (!ProcessOperands({SI->getTrueValue(), SI->getFalseValue()}, NeedToExit))
       return false;
-    MaxDepthLevel = std::max(Level1, Level2);
     break;
   }
 
@@ -14076,23 +14092,20 @@ bool BoUpSLP::collectValuesToDemote(
     PHINode *PN = cast<PHINode>(I);
     if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
       return false;
-    unsigned InitLevel = MaxDepthLevel;
-    for (Value *IncValue : PN->incoming_values()) {
-      unsigned Level = InitLevel;
-      if (!collectValuesToDemote(IncValue, IsProfitableToDemoteRoot, BitWidth,
-                                 ToDemote, DemotedConsts, Visited, Level,
-                                 IsProfitableToDemote, IsTruncRoot))
-        return false;
-      MaxDepthLevel = std::max(MaxDepthLevel, Level);
-    }
+    SmallVector<Value *> Ops(PN->incoming_values().begin(),
+                             PN->incoming_values().end());
+    if (!ProcessOperands(Ops, NeedToExit))
+      return false;
     break;
   }
 
   // Otherwise, conservatively give up.
   default:
     MaxDepthLevel = 1;
-    return IsProfitableToDemote && IsPotentiallyTruncated(I, BitWidth);
+    return FinalAnalysis();
   }
+  if (NeedToExit)
+    return true;
 
   ++MaxDepthLevel;
   // Gather demoted constant operands.
@@ -14131,6 +14144,7 @@ void BoUpSLP::computeMinimumValueSizes() {
 
   // The first value node for store/insertelement is sext/zext/trunc? Skip it,
   // resize to the final type.
+  bool IsTruncRoot = false;
   bool IsProfitableToDemoteRoot = !IsStoreOrInsertElt;
   if (NodeIdx != 0 &&
       VectorizableTree[NodeIdx]->State == TreeEntry::Vectorize &&
@@ -14138,8 +14152,9 @@ void BoUpSLP::computeMinimumValueSizes() {
        VectorizableTree[NodeIdx]->getOpcode() == Instruction::SExt ||
        VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc)) {
     assert(IsStoreOrInsertElt && "Expected store/insertelement seeded graph.");
-    ++NodeIdx;
+    IsTruncRoot = VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc;
     IsProfitableToDemoteRoot = true;
+    ++NodeIdx;
   }
 
   // Analyzed in reduction already and not profitable - exit.
@@ -14271,7 +14286,6 @@ void BoUpSLP::computeMinimumValueSizes() {
     ReductionBitWidth = bit_ceil(ReductionBitWidth);
   }
   bool IsTopRoot = NodeIdx == 0;
-  bool IsTruncRoot = false;
   while (NodeIdx < VectorizableTree.size() &&
          VectorizableTree[NodeIdx]->State == TreeEntry::Vectorize &&
          VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc) {

diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll
index 1986b51ec94828..7c5f9847db1f41 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll
@@ -228,7 +228,7 @@ for.end:                                          ; preds = %for.end.loopexit, %
 ; YAML-NEXT: Function:        test_unrolled_select
 ; YAML-NEXT: Args:
 ; YAML-NEXT:   - String:          'Vectorized horizontal reduction with cost '
-; YAML-NEXT:   - Cost:            '-36'
+; YAML-NEXT:   - Cost:            '-41'
 ; YAML-NEXT:   - String:          ' and with tree size '
 ; YAML-NEXT:   - TreeSize:        '10'
 
@@ -246,15 +246,17 @@ define i32 @test_unrolled_select(ptr noalias nocapture readonly %blk1, ptr noali
 ; CHECK-NEXT:    [[P2_045:%.*]] = phi ptr [ [[BLK2:%.*]], [[FOR_BODY_LR_PH]] ], [ [[ADD_PTR88:%.*]], [[IF_END_86]] ]
 ; CHECK-NEXT:    [[P1_044:%.*]] = phi ptr [ [[BLK1:%.*]], [[FOR_BODY_LR_PH]] ], [ [[ADD_PTR:%.*]], [[IF_END_86]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = load <8 x i8>, ptr [[P1_044]], align 1
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <8 x i8> [[TMP0]] to <8 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <8 x i8> [[TMP0]] to <8 x i16>
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <8 x i8>, ptr [[P2_045]], align 1
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i8> [[TMP2]] to <8 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = sub nsw <8 x i32> [[TMP1]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = icmp slt <8 x i32> [[TMP4]], zeroinitializer
-; CHECK-NEXT:    [[TMP6:%.*]] = sub nsw <8 x i32> zeroinitializer, [[TMP4]]
-; CHECK-NEXT:    [[TMP7:%.*]] = select <8 x i1> [[TMP5]], <8 x i32> [[TMP6]], <8 x i32> [[TMP4]]
-; CHECK-NEXT:    [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP7]])
-; CHECK-NEXT:    [[OP_RDX]] = add i32 [[TMP8]], [[S_047]]
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i8> [[TMP2]] to <8 x i16>
+; CHECK-NEXT:    [[TMP4:%.*]] = sub <8 x i16> [[TMP1]], [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[TMP4]] to <8 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp slt <8 x i32> [[TMP5]], zeroinitializer
+; CHECK-NEXT:    [[TMP7:%.*]] = sub <8 x i16> zeroinitializer, [[TMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = select <8 x i1> [[TMP6]], <8 x i16> [[TMP7]], <8 x i16> [[TMP4]]
+; CHECK-NEXT:    [[TMP9:%.*]] = sext <8 x i16> [[TMP8]] to <8 x i32>
+; CHECK-NEXT:    [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP9]])
+; CHECK-NEXT:    [[OP_RDX]] = add i32 [[TMP10]], [[S_047]]
 ; CHECK-NEXT:    [[CMP83:%.*]] = icmp slt i32 [[OP_RDX]], [[LIM:%.*]]
 ; CHECK-NEXT:    br i1 [[CMP83]], label [[IF_END_86]], label [[FOR_END_LOOPEXIT:%.*]]
 ; CHECK:       if.end.86:

diff --git a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-node-but-not-operands.ll b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-node-but-not-operands.ll
new file mode 100644
index 00000000000000..fcccaf202a91a1
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-node-but-not-operands.ll
@@ -0,0 +1,61 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt --passes=slp-vectorizer -S -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s
+
+define void @test() {
+; CHECK-LABEL: define void @test() {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    store <8 x i16> zeroinitializer, ptr null, align 2
+; CHECK-NEXT:    ret void
+;
+entry:
+  %arrayidx8 = getelementptr i8, ptr null, i64 2
+  %shr10 = ashr i32 0, 0
+  %shr19 = lshr i32 0, 0
+  %sub20 = or i32 %shr19, %shr10
+  %xor21 = xor i32 %sub20, 0
+  %conv22 = trunc i32 %xor21 to i16
+  store i16 %conv22, ptr %arrayidx8, align 2
+  %arrayidx28 = getelementptr i8, ptr null, i64 4
+  %shr34 = lshr i32 0, 0
+  %sub35 = or i32 %shr34, %shr10
+  %xor36 = xor i32 %sub35, 0
+  %conv37 = trunc i32 %xor36 to i16
+  store i16 %conv37, ptr %arrayidx28, align 2
+  %arrayidx43 = getelementptr i8, ptr null, i64 6
+  %shr49 = lshr i32 0, 0
+  %sub50 = or i32 %shr49, %shr10
+  %xor51 = xor i32 %sub50, 0
+  %conv52 = trunc i32 %xor51 to i16
+  store i16 %conv52, ptr %arrayidx43, align 2
+  %arrayidx.1 = getelementptr i8, ptr null, i64 8
+  %shr.1 = lshr i32 0, 0
+  %xor2.1 = xor i32 %shr.1, %shr10
+  %sub3.1 = or i32 %xor2.1, 0
+  %conv4.1 = trunc i32 %sub3.1 to i16
+  store i16 %conv4.1, ptr %arrayidx.1, align 2
+  %arrayidx8.1 = getelementptr i8, ptr null, i64 10
+  %shr10.1 = ashr i32 0, 0
+  %shr19.1 = lshr i32 0, 0
+  %sub20.1 = or i32 %shr19.1, %shr10.1
+  %xor21.1 = xor i32 %sub20.1, 0
+  %conv22.1 = trunc i32 %xor21.1 to i16
+  store i16 %conv22.1, ptr %arrayidx8.1, align 2
+  %arrayidx28.1 = getelementptr i8, ptr null, i64 12
+  %shr34.1 = lshr i32 0, 0
+  %sub35.1 = or i32 %shr34.1, %shr10.1
+  %xor36.1 = xor i32 %sub35.1, 0
+  %conv37.1 = trunc i32 %xor36.1 to i16
+  store i16 %conv37.1, ptr %arrayidx28.1, align 2
+  %arrayidx43.1 = getelementptr i8, ptr null, i64 14
+  %shr49.1 = lshr i32 0, 0
+  %sub50.1 = or i32 %shr49.1, %shr10.1
+  %xor51.1 = xor i32 %sub50.1, 0
+  %conv52.1 = trunc i32 %xor51.1 to i16
+  store i16 %conv52.1, ptr %arrayidx43.1, align 2
+  %shr.2 = lshr i32 0, 0
+  %xor2.2 = xor i32 %shr.2, %shr10.1
+  %sub3.2 = or i32 %xor2.2, 0
+  %conv4.2 = trunc i32 %sub3.2 to i16
+  store i16 %conv4.2, ptr null, align 2
+  ret void
+}



