[llvm] 250b467 - [SLP][NFC]Simplify common analysis of instructions in BoUpSLP::collectValuesToDemote by outlining common code, NFC.

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 3 06:49:10 PDT 2024


Author: Alexey Bataev
Date: 2024-04-03T06:45:42-07:00
New Revision: 250b467f7c8f06350a64d1a17e3ac7e3e390d4b1

URL: https://github.com/llvm/llvm-project/commit/250b467f7c8f06350a64d1a17e3ac7e3e390d4b1
DIFF: https://github.com/llvm/llvm-project/commit/250b467f7c8f06350a64d1a17e3ac7e3e390d4b1.diff

LOG: [SLP][NFC]Simplify common analysis of instructions in BoUpSLP::collectValuesToDemote by outlining common code, NFC.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index db052cec69f75d..cb55992051ebf0 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -14097,25 +14097,52 @@ bool BoUpSLP::collectValuesToDemote(
         }
         return false;
       };
-  bool NeedToExit = false;
+  auto TryProcessInstruction =
+      [&](Instruction *I, const TreeEntry &ITE, unsigned &BitWidth,
+          ArrayRef<Value *> Operands = std::nullopt,
+          function_ref<bool(unsigned, unsigned)> Checker = {}) {
+        if (Operands.empty()) {
+          if (!IsTruncRoot)
+            MaxDepthLevel = 1;
+          (void)IsPotentiallyTruncated(V, BitWidth);
+        } else {
+          // Several vectorized uses? Check if we can truncate it, otherwise -
+          // exit.
+          if (ITE.UserTreeIndices.size() > 1 &&
+              !IsPotentiallyTruncated(I, BitWidth))
+            return false;
+          bool NeedToExit = false;
+          if (Checker && !AttemptCheckBitwidth(Checker, NeedToExit))
+            return false;
+          if (NeedToExit)
+            return true;
+          if (!ProcessOperands(Operands, NeedToExit))
+            return false;
+          if (NeedToExit)
+            return true;
+        }
+
+        ++MaxDepthLevel;
+        // Gather demoted constant operands.
+        for (unsigned Idx : seq<unsigned>(Start, End))
+          if (isa<Constant>(I->getOperand(Idx)))
+            DemotedConsts.try_emplace(I).first->getSecond().push_back(Idx);
+        // Record the value that we can demote.
+        ToDemote.push_back(V);
+        return IsProfitableToDemote;
+      };
   switch (I->getOpcode()) {
 
   // We can always demote truncations and extensions. Since truncations can
   // seed additional demotion, we save the truncated value.
   case Instruction::Trunc:
-    if (!IsTruncRoot)
-      MaxDepthLevel = 1;
     if (IsProfitableToDemoteRoot)
       IsProfitableToDemote = true;
-    (void)IsPotentiallyTruncated(V, BitWidth);
-    break;
+    return TryProcessInstruction(I, *ITE, BitWidth);
   case Instruction::ZExt:
   case Instruction::SExt:
-    if (!IsTruncRoot)
-      MaxDepthLevel = 1;
     IsProfitableToDemote = true;
-    (void)IsPotentiallyTruncated(V, BitWidth);
-    break;
+    return TryProcessInstruction(I, *ITE, BitWidth);
 
   // We can demote certain binary operations if we can demote both of their
   // operands.
@@ -14125,140 +14152,83 @@ bool BoUpSLP::collectValuesToDemote(
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor: {
-    if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
-      return false;
-    if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
-      return false;
-    break;
+    return TryProcessInstruction(I, *ITE, BitWidth,
+                                 {I->getOperand(0), I->getOperand(1)});
   }
   case Instruction::Shl: {
-    // Several vectorized uses? Check if we can truncate it, otherwise - exit.
-    if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
-      return false;
     // If we are truncating the result of this SHL, and if it's a shift of an
     // inrange amount, we can always perform a SHL in a smaller type.
-    if (!AttemptCheckBitwidth(
-            [&](unsigned BitWidth, unsigned) {
-              KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
-              return AmtKnownBits.getMaxValue().ult(BitWidth);
-            },
-            NeedToExit))
-      return false;
-    if (NeedToExit)
-      return true;
-    if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
-      return false;
-    break;
+    auto ShlChecker = [&](unsigned BitWidth, unsigned) {
+      KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+      return AmtKnownBits.getMaxValue().ult(BitWidth);
+    };
+    return TryProcessInstruction(
+        I, *ITE, BitWidth, {I->getOperand(0), I->getOperand(1)}, ShlChecker);
   }
   case Instruction::LShr: {
-    // Several vectorized uses? Check if we can truncate it, otherwise - exit.
-    if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
-      return false;
     // If this is a truncate of a logical shr, we can truncate it to a smaller
     // lshr iff we know that the bits we would otherwise be shifting in are
     // already zeros.
-    if (!AttemptCheckBitwidth(
-            [&](unsigned BitWidth, unsigned OrigBitWidth) {
-              KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
-              APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
-              return AmtKnownBits.getMaxValue().ult(BitWidth) &&
-                     MaskedValueIsZero(I->getOperand(0), ShiftedBits,
-                                       SimplifyQuery(*DL));
-            },
-            NeedToExit))
-      return false;
-    if (NeedToExit)
-      return true;
-    if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
-      return false;
-    break;
+    auto LShrChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) {
+      KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+      APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+      return AmtKnownBits.getMaxValue().ult(BitWidth) &&
+             MaskedValueIsZero(I->getOperand(0), ShiftedBits,
+                               SimplifyQuery(*DL));
+    };
+    return TryProcessInstruction(
+        I, *ITE, BitWidth, {I->getOperand(0), I->getOperand(1)}, LShrChecker);
   }
   case Instruction::AShr: {
-    // Several vectorized uses? Check if we can truncate it, otherwise - exit.
-    if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
-      return false;
     // If this is a truncate of an arithmetic shr, we can truncate it to a
     // smaller ashr iff we know that all the bits from the sign bit of the
     // original type and the sign bit of the truncate type are similar.
-    if (!AttemptCheckBitwidth(
-            [&](unsigned BitWidth, unsigned OrigBitWidth) {
-              KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
-              unsigned ShiftedBits = OrigBitWidth - BitWidth;
-              return AmtKnownBits.getMaxValue().ult(BitWidth) &&
-                     ShiftedBits < ComputeNumSignBits(I->getOperand(0), *DL, 0,
-                                                      AC, nullptr, DT);
-            },
-            NeedToExit))
-      return false;
-    if (NeedToExit)
-      return true;
-    if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
-      return false;
-    break;
+    auto AShrChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) {
+      KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+      unsigned ShiftedBits = OrigBitWidth - BitWidth;
+      return AmtKnownBits.getMaxValue().ult(BitWidth) &&
+             ShiftedBits <
+                 ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, nullptr, DT);
+    };
+    return TryProcessInstruction(
+        I, *ITE, BitWidth, {I->getOperand(0), I->getOperand(1)}, AShrChecker);
   }
   case Instruction::UDiv:
   case Instruction::URem: {
-    if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
-      return false;
     // UDiv and URem can be truncated if all the truncated bits are zero.
-    if (!AttemptCheckBitwidth(
-            [&](unsigned BitWidth, unsigned OrigBitWidth) {
-              assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!");
-              APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
-              return MaskedValueIsZero(I->getOperand(0), Mask,
-                                       SimplifyQuery(*DL)) &&
-                     MaskedValueIsZero(I->getOperand(1), Mask,
-                                       SimplifyQuery(*DL));
-            },
-            NeedToExit))
-      return false;
-    if (NeedToExit)
-      return true;
-    if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
-      return false;
-    break;
+    auto Checker = [&](unsigned BitWidth, unsigned OrigBitWidth) {
+      assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!");
+      APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+      return MaskedValueIsZero(I->getOperand(0), Mask, SimplifyQuery(*DL)) &&
+             MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL));
+    };
+    return TryProcessInstruction(I, *ITE, BitWidth,
+                                 {I->getOperand(0), I->getOperand(1)}, Checker);
   }
 
   // We can demote selects if we can demote their true and false values.
   case Instruction::Select: {
-    if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
-      return false;
     Start = 1;
     auto *SI = cast<SelectInst>(I);
-    if (!ProcessOperands({SI->getTrueValue(), SI->getFalseValue()}, NeedToExit))
-      return false;
-    break;
+    return TryProcessInstruction(I, *ITE, BitWidth,
+                                 {SI->getTrueValue(), SI->getFalseValue()});
   }
 
   // We can demote phis if we can demote all their incoming operands. Note that
   // we don't need to worry about cycles since we ensure single use above.
   case Instruction::PHI: {
     PHINode *PN = cast<PHINode>(I);
-    if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
-      return false;
     SmallVector<Value *> Ops(PN->incoming_values().begin(),
                              PN->incoming_values().end());
-    if (!ProcessOperands(Ops, NeedToExit))
-      return false;
-    break;
+    return TryProcessInstruction(I, *ITE, BitWidth, Ops);
   }
 
   // Otherwise, conservatively give up.
   default:
-    MaxDepthLevel = 1;
-    return FinalAnalysis();
+    break;
   }
-  if (NeedToExit)
-    return true;
-
-  ++MaxDepthLevel;
-  // Gather demoted constant operands.
-  for (unsigned Idx : seq<unsigned>(Start, End))
-    if (isa<Constant>(I->getOperand(Idx)))
-      DemotedConsts.try_emplace(I).first->getSecond().push_back(Idx);
-  // Record the value that we can demote.
-  ToDemote.push_back(V);
-  return IsProfitableToDemote;
+  MaxDepthLevel = 1;
+  return FinalAnalysis();
 }
 
 void BoUpSLP::computeMinimumValueSizes() {
@@ -14309,7 +14279,8 @@ void BoUpSLP::computeMinimumValueSizes() {
   DenseMap<Instruction *, SmallVector<unsigned>> DemotedConsts;
   auto ComputeMaxBitWidth = [&](ArrayRef<Value *> TreeRoot, unsigned VF,
                                 bool IsTopRoot, bool IsProfitableToDemoteRoot,
-                                unsigned Opcode, unsigned Limit, bool IsTruncRoot) {
+                                unsigned Opcode, unsigned Limit,
+                                bool IsTruncRoot) {
     ToDemote.clear();
     auto *TreeRootIT = dyn_cast<IntegerType>(TreeRoot[0]->getType());
     if (!TreeRootIT || !Opcode)


        


More information about the llvm-commits mailing list