[llvm] f0660a9 - [Local] collectBitParts - bail out if we find more than one root input value.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sat May 15 06:01:33 PDT 2021


Author: Simon Pilgrim
Date: 2021-05-15T13:58:42+01:00
New Revision: f0660a977e6822ae68e4c88641d1909b555f0e05

URL: https://github.com/llvm/llvm-project/commit/f0660a977e6822ae68e4c88641d1909b555f0e05
DIFF: https://github.com/llvm/llvm-project/commit/f0660a977e6822ae68e4c88641d1909b555f0e05.diff

LOG: [Local] collectBitParts - bail out if we find more than one root input value.

All of our uses of collectBitParts revolve around matching down to an operation with a single root value - I don't think we intend to change that (and much of collectBitParts already assumes it).

The binop cases (OR/FSHL/FSHR) already check that the providers are the same, but without an early bail-out we would still waste time collecting through unary ops before getting to them.

Added: 
    

Modified: 
    llvm/lib/Transforms/Utils/Local.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 74bcd0c14827..2b42b50d5b12 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -2879,7 +2879,8 @@ struct BitPart {
 /// does not invalidate internal references (std::map instead of DenseMap).
 static const Optional<BitPart> &
 collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
-                std::map<Value *, Optional<BitPart>> &BPS, int Depth) {
+                std::map<Value *, Optional<BitPart>> &BPS, int Depth,
+                bool &FoundRoot) {
   auto I = BPS.find(V);
   if (I != BPS.end())
     return I->second;
@@ -2904,13 +2905,13 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
     // If this is an or instruction, it may be an inner node of the bswap.
     if (match(V, m_Or(m_Value(X), m_Value(Y)))) {
       // Check we have both sources and they are from the same provider.
-      const auto &A =
-          collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+      const auto &A = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+                                      Depth + 1, FoundRoot);
       if (!A || !A->Provider)
         return Result;
 
-      const auto &B =
-          collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+      const auto &B = collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS,
+                                      Depth + 1, FoundRoot);
       if (!B || A->Provider != B->Provider)
         return Result;
 
@@ -2943,8 +2944,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
       if (!MatchBitReversals && (BitShift.getZExtValue() % 8) != 0)
         return Result;
 
-      const auto &Res =
-          collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+      const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+                                        Depth + 1, FoundRoot);
       if (!Res)
         return Result;
       Result = Res;
@@ -2973,8 +2974,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
       if (!MatchBitReversals && (NumMaskedBits % 8) != 0)
         return Result;
 
-      const auto &Res =
-          collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+      const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+                                        Depth + 1, FoundRoot);
       if (!Res)
         return Result;
       Result = Res;
@@ -2988,8 +2989,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
 
     // If this is a zext instruction zero extend the result.
     if (match(V, m_ZExt(m_Value(X)))) {
-      const auto &Res =
-          collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+      const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+                                        Depth + 1, FoundRoot);
       if (!Res)
         return Result;
 
@@ -3004,8 +3005,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
 
     // If this is a truncate instruction, extract the lower bits.
     if (match(V, m_Trunc(m_Value(X)))) {
-      const auto &Res =
-          collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+      const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+                                        Depth + 1, FoundRoot);
       if (!Res)
         return Result;
 
@@ -3018,8 +3019,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
     // BITREVERSE - most likely due to us previous matching a partial
     // bitreverse.
     if (match(V, m_BitReverse(m_Value(X)))) {
-      const auto &Res =
-          collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+      const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+                                        Depth + 1, FoundRoot);
       if (!Res)
         return Result;
 
@@ -3031,8 +3032,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
 
     // BSWAP - most likely due to us previous matching a partial bswap.
     if (match(V, m_BSwap(m_Value(X)))) {
-      const auto &Res =
-          collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+      const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+                                        Depth + 1, FoundRoot);
       if (!Res)
         return Result;
 
@@ -3063,13 +3064,13 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
         return Result;
 
       // Check we have both sources and they are from the same provider.
-      const auto &LHS =
-          collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+      const auto &LHS = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+                                        Depth + 1, FoundRoot);
       if (!LHS || !LHS->Provider)
-          return Result;
+        return Result;
 
-      const auto &RHS =
-          collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+      const auto &RHS = collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS,
+                                        Depth + 1, FoundRoot);
       if (!RHS || LHS->Provider != RHS->Provider)
         return Result;
 
@@ -3083,8 +3084,14 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
     }
   }
 
-  // Okay, we got to something that isn't a shift, 'or' or 'and'.  This must be
-  // the input value to the bswap/bitreverse.
+  // If we've already found a root input value then we're never going to merge
+  // these back together.
+  if (FoundRoot)
+    return Result;
+
+  // Okay, we got to something that isn't a shift, 'or', 'and', etc. This must
+  // be the root input value to the bswap/bitreverse.
+  FoundRoot = true;
   Result = BitPart(V, BitWidth);
   for (unsigned BitIdx = 0; BitIdx < BitWidth; ++BitIdx)
     Result->Provenance[BitIdx] = BitIdx;
@@ -3126,8 +3133,10 @@ bool llvm::recognizeBSwapOrBitReverseIdiom(
       DemandedTy = Trunc->getType();
 
   // Try to find all the pieces corresponding to the bswap.
+  bool FoundRoot = false;
   std::map<Value *, Optional<BitPart>> BPS;
-  auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS, 0);
+  const auto &Res =
+      collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS, 0, FoundRoot);
   if (!Res)
     return false;
   ArrayRef<int8_t> BitProvenance = Res->Provenance;


        


More information about the llvm-commits mailing list