[llvm] f0660a9 - [Local] collectBitParts - bail out if we find more than one root input value.
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Sat May 15 06:01:33 PDT 2021
Author: Simon Pilgrim
Date: 2021-05-15T13:58:42+01:00
New Revision: f0660a977e6822ae68e4c88641d1909b555f0e05
URL: https://github.com/llvm/llvm-project/commit/f0660a977e6822ae68e4c88641d1909b555f0e05
DIFF: https://github.com/llvm/llvm-project/commit/f0660a977e6822ae68e4c88641d1909b555f0e05.diff
LOG: [Local] collectBitParts - bail out if we find more than one root input value.
All current uses of collectBitParts revolve around matching down to an operation with a single root value. I don't think we intend to change that (and much of collectBitParts already assumes it).
The binop cases (OR/FSHL/FSHR) already check that both operands come from the same provider, but without an early bail-out we would still waste time collecting through unary ops before reaching them.
Added:
Modified:
llvm/lib/Transforms/Utils/Local.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 74bcd0c14827..2b42b50d5b12 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -2879,7 +2879,8 @@ struct BitPart {
/// does not invalidate internal references (std::map instead of DenseMap).
static const Optional<BitPart> &
collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
- std::map<Value *, Optional<BitPart>> &BPS, int Depth) {
+ std::map<Value *, Optional<BitPart>> &BPS, int Depth,
+ bool &FoundRoot) {
auto I = BPS.find(V);
if (I != BPS.end())
return I->second;
@@ -2904,13 +2905,13 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
// If this is an or instruction, it may be an inner node of the bswap.
if (match(V, m_Or(m_Value(X), m_Value(Y)))) {
// Check we have both sources and they are from the same provider.
- const auto &A =
- collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+ const auto &A = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+ Depth + 1, FoundRoot);
if (!A || !A->Provider)
return Result;
- const auto &B =
- collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+ const auto &B = collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS,
+ Depth + 1, FoundRoot);
if (!B || A->Provider != B->Provider)
return Result;
@@ -2943,8 +2944,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
if (!MatchBitReversals && (BitShift.getZExtValue() % 8) != 0)
return Result;
- const auto &Res =
- collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+ const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+ Depth + 1, FoundRoot);
if (!Res)
return Result;
Result = Res;
@@ -2973,8 +2974,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
if (!MatchBitReversals && (NumMaskedBits % 8) != 0)
return Result;
- const auto &Res =
- collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+ const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+ Depth + 1, FoundRoot);
if (!Res)
return Result;
Result = Res;
@@ -2988,8 +2989,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
// If this is a zext instruction zero extend the result.
if (match(V, m_ZExt(m_Value(X)))) {
- const auto &Res =
- collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+ const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+ Depth + 1, FoundRoot);
if (!Res)
return Result;
@@ -3004,8 +3005,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
// If this is a truncate instruction, extract the lower bits.
if (match(V, m_Trunc(m_Value(X)))) {
- const auto &Res =
- collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+ const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+ Depth + 1, FoundRoot);
if (!Res)
return Result;
@@ -3018,8 +3019,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
// BITREVERSE - most likely due to us previous matching a partial
// bitreverse.
if (match(V, m_BitReverse(m_Value(X)))) {
- const auto &Res =
- collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+ const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+ Depth + 1, FoundRoot);
if (!Res)
return Result;
@@ -3031,8 +3032,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
// BSWAP - most likely due to us previous matching a partial bswap.
if (match(V, m_BSwap(m_Value(X)))) {
- const auto &Res =
- collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+ const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+ Depth + 1, FoundRoot);
if (!Res)
return Result;
@@ -3063,13 +3064,13 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
return Result;
// Check we have both sources and they are from the same provider.
- const auto &LHS =
- collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+ const auto &LHS = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
+ Depth + 1, FoundRoot);
if (!LHS || !LHS->Provider)
- return Result;
+ return Result;
- const auto &RHS =
- collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+ const auto &RHS = collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS,
+ Depth + 1, FoundRoot);
if (!RHS || LHS->Provider != RHS->Provider)
return Result;
@@ -3083,8 +3084,14 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
}
}
- // Okay, we got to something that isn't a shift, 'or' or 'and'. This must be
- // the input value to the bswap/bitreverse.
+ // If we've already found a root input value then we're never going to merge
+ // these back together.
+ if (FoundRoot)
+ return Result;
+
+ // Okay, we got to something that isn't a shift, 'or', 'and', etc. This must
+ // be the root input value to the bswap/bitreverse.
+ FoundRoot = true;
Result = BitPart(V, BitWidth);
for (unsigned BitIdx = 0; BitIdx < BitWidth; ++BitIdx)
Result->Provenance[BitIdx] = BitIdx;
@@ -3126,8 +3133,10 @@ bool llvm::recognizeBSwapOrBitReverseIdiom(
DemandedTy = Trunc->getType();
// Try to find all the pieces corresponding to the bswap.
+ bool FoundRoot = false;
std::map<Value *, Optional<BitPart>> BPS;
- auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS, 0);
+ const auto &Res =
+ collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS, 0, FoundRoot);
if (!Res)
return false;
ArrayRef<int8_t> BitProvenance = Res->Provenance;
More information about the llvm-commits
mailing list