[llvm] 8fe59ba - [rs4gc] track the original value in the state use for base pointer rewriting
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Sat Mar 6 08:47:46 PST 2021
Author: Philip Reames
Date: 2021-03-06T08:46:15-08:00
New Revision: 8fe59ba51e687939b412b95529ead22c2fba2443
URL: https://github.com/llvm/llvm-project/commit/8fe59ba51e687939b412b95529ead22c2fba2443
DIFF: https://github.com/llvm/llvm-project/commit/8fe59ba51e687939b412b95529ead22c2fba2443.diff
LOG: [rs4gc] track the original value in the state use for base pointer rewriting
I'd originally intended to build on this for another purpose and have decided not to, but at a minimum, the stronger asserts are useful.
Added:
Modified:
llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
index ad1c61c80881..fd92695fbc03 100644
--- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
+++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
@@ -676,22 +676,30 @@ namespace {
/// the base of this BDV.
class BDVState {
public:
- enum Status {
+ enum StatusTy {
// Starting state of lattice
Unknown,
- // Some specific base value
+ // Some specific base value -- does *not* mean that instruction
+ // propagates the base of the object
+ // ex: gep %arg, 16 -> %arg is the base value
Base,
// Need to insert a node to represent a merge.
Conflict
};
- BDVState() {}
- explicit BDVState(Status Status, Value *BaseValue = nullptr)
- : Status(Status), BaseValue(BaseValue) {
+ BDVState() { // Must never run: reaching it means a States lookup missed.
+ llvm_unreachable("missing state in map");
+ }
+
+ explicit BDVState(Value *OriginalValue) // Unknown state for OriginalValue.
+ : OriginalValue(OriginalValue) {}
+ explicit BDVState(Value *OriginalValue, StatusTy Status, Value *BaseValue = nullptr)
+ : OriginalValue(OriginalValue), Status(Status), BaseValue(BaseValue) {
assert(Status != Base || BaseValue);
}
- Status getStatus() const { return Status; }
+ StatusTy getStatus() const { return Status; }
+ Value *getOriginalValue() const { return OriginalValue; }
Value *getBaseValue() const { return BaseValue; }
bool isBase() const { return getStatus() == Base; }
@@ -699,7 +707,8 @@ class BDVState {
bool isConflict() const { return getStatus() == Conflict; }
bool operator==(const BDVState &Other) const {
- return BaseValue == Other.BaseValue && Status == Other.Status;
+ return OriginalValue == Other.OriginalValue && BaseValue == Other.BaseValue &&
+ Status == Other.Status;
}
bool operator!=(const BDVState &other) const { return !(*this == other); }
@@ -722,12 +731,14 @@ class BDVState {
OS << "C";
break;
}
- OS << " (" << getBaseValue() << " - "
- << (getBaseValue() ? getBaseValue()->getName() : "nullptr") << "): ";
+ OS << " (base " << getBaseValue() << " - "
+ << (getBaseValue() ? getBaseValue()->getName() : "nullptr") << ")"
+ << " for " << OriginalValue->getName() << ":";
}
private:
- Status Status = Unknown;
+ AssertingVH<Value> OriginalValue; // value this state was created for (its States key)
+ StatusTy Status = Unknown;
AssertingVH<Value> BaseValue = nullptr; // Non-null only if Status == Base.
};
@@ -740,39 +751,40 @@ static raw_ostream &operator<<(raw_ostream &OS, const BDVState &State) {
}
#endif
-static BDVState meetBDVStateImpl(const BDVState &LHS, const BDVState &RHS) {
- switch (LHS.getStatus()) {
+static BDVState::StatusTy meet(const BDVState::StatusTy &LHS, // status-only meet
+ const BDVState::StatusTy &RHS) {
+ switch (LHS) {
case BDVState::Unknown:
return RHS;
-
case BDVState::Base:
- assert(LHS.getBaseValue() && "can't be null");
- if (RHS.isUnknown())
- return LHS;
-
- if (RHS.isBase()) {
- if (LHS.getBaseValue() == RHS.getBaseValue()) {
- assert(LHS == RHS && "equality broken!");
- return LHS;
- }
- return BDVState(BDVState::Conflict);
- }
- assert(RHS.isConflict() && "only three states!");
- return BDVState(BDVState::Conflict);
-
+ switch (RHS) {
+ case BDVState::Unknown:
+ case BDVState::Base: // differing base values are handled by meetBDVState
+ return BDVState::Base;
+ case BDVState::Conflict:
+ return BDVState::Conflict;
+ }
+ llvm_unreachable("covered switch");
case BDVState::Conflict:
- return LHS;
+ return BDVState::Conflict;
}
- llvm_unreachable("only three states!");
+ llvm_unreachable("covered switch");
}
// Values of type BDVState form a lattice, and this function implements the meet
// operation.
static BDVState meetBDVState(const BDVState &LHS, const BDVState &RHS) {
- BDVState Result = meetBDVStateImpl(LHS, RHS);
- assert(Result == meetBDVStateImpl(RHS, LHS) &&
- "Math is wrong: meet does not commute!");
- return Result;
+ auto NewStatus = meet(LHS.getStatus(), RHS.getStatus());
+ assert(NewStatus == meet(RHS.getStatus(), LHS.getStatus()) && "Math is wrong: meet does not commute!");
+
+ Value *BaseValue = LHS.getStatus() == BDVState::Base ?
+ LHS.getBaseValue() : RHS.getBaseValue();
+ if (LHS.getStatus() == BDVState::Base && RHS.getStatus() == BDVState::Base &&
+ LHS.getBaseValue() != RHS.getBaseValue())
+ NewStatus = BDVState::Conflict;
+ if (NewStatus == BDVState::Conflict) // a Conflict result never inherits an operand's base
+ BaseValue = nullptr;
+ return BDVState(LHS.getOriginalValue(), NewStatus, BaseValue);
}
/// For a given value or instruction, figure out what base ptr its derived from.
@@ -822,12 +834,18 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
// below. This is important for deterministic compilation.
MapVector<Value *, BDVState> States;
+ auto VerifyStates = [&]() { // Checks every entry's key equals its state's OriginalValue.
+ for (auto &Entry : States) {
+ assert(Entry.first == Entry.second.getOriginalValue());
+ }
+ }; // NOTE(review): only called under #ifndef NDEBUG; may be flagged unused in release builds -- confirm.
+
// Recursively fill in all base defining values reachable from the initial
// one for which we don't already know a definite base value for
/* scope */ {
SmallVector<Value*, 16> Worklist;
Worklist.push_back(Def);
- States.insert({Def, BDVState()});
+ States.insert({Def, BDVState(Def)}); // starts Unknown, tagged with its own key
while (!Worklist.empty()) {
Value *Current = Worklist.pop_back_val();
assert(!isOriginalBaseResult(Current) && "why did it get added?");
@@ -843,7 +861,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
return;
assert(isExpectedBDVType(Base) && "the only non-base values "
"we see should be base defining values");
- if (States.insert(std::make_pair(Base, BDVState())).second)
+ if (States.insert(std::make_pair(Base, BDVState(Base))).second) // Unknown, tagged with its key
Worklist.push_back(Base);
};
if (PHINode *PN = dyn_cast<PHINode>(Current)) {
@@ -868,6 +886,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
}
#ifndef NDEBUG
+ VerifyStates();
LLVM_DEBUG(dbgs() << "States after initialization:\n");
for (auto Pair : States) {
LLVM_DEBUG(dbgs() << " " << Pair.second << " for " << *Pair.first << "\n");
@@ -878,7 +897,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
// base state for known bases and expect to find a cached state otherwise.
auto GetStateForBDV = [&](Value *BaseValue, Value *Input) {
if (isKnownBaseResult(BaseValue) && areBothVectorOrScalar(BaseValue, Input))
- return BDVState(BDVState::Base, BaseValue);
+ return BDVState(BaseValue, BDVState::Base, BaseValue); // transient state, not stored in States
auto I = States.find(BaseValue);
assert(I != States.end() && "lookup failed!");
return I->second;
@@ -910,7 +929,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
return GetStateForBDV(BDV, V);
};
- BDVState NewState;
+ BDVState NewState(BDV); // Unknown start; meetBDVState keeps the LHS OriginalValue (BDV)
if (SelectInst *SI = dyn_cast<SelectInst>(BDV)) {
NewState = meetBDVState(NewState, getStateForInput(SI->getTrueValue()));
NewState =
@@ -948,6 +967,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
}
#ifndef NDEBUG
+ VerifyStates();
LLVM_DEBUG(dbgs() << "States after meet iteration:\n");
for (auto Pair : States) {
LLVM_DEBUG(dbgs() << " " << Pair.second << " for " << *Pair.first << "\n");
@@ -981,17 +1001,21 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
auto *BaseInst = ExtractElementInst::Create(
State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE);
BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
- States[I] = BDVState(BDVState::Base, BaseInst);
+ States[I] = BDVState(I, BDVState::Base, BaseInst); // keep OriginalValue == key
} else if (!isa<VectorType>(I->getType())) {
// We need to handle cases that have a vector base but the instruction is
// a scalar type (these could be phis or selects or any instruction that
// are of scalar type, but the base can be a vector type). We
// conservatively set this as conflict. Setting the base value for these
// conflicts is handled in the next loop which traverses States.
- States[I] = BDVState(BDVState::Conflict);
+ States[I] = BDVState(I, BDVState::Conflict); // keep OriginalValue == key
}
}
+#ifndef NDEBUG
+ VerifyStates();
+#endif
+
// Insert Phis for all conflicts
// TODO: adjust naming patterns to avoid this order of iteration dependency
for (auto Pair : States) {
@@ -1048,9 +1072,13 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
Instruction *BaseInst = MakeBaseInstPlaceholder(I);
// Add metadata marking this as a base value
BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
- States[I] = BDVState(BDVState::Conflict, BaseInst);
+ States[I] = BDVState(I, BDVState::Conflict, BaseInst); // Conflict now carries its placeholder base
}
+#ifndef NDEBUG
+ VerifyStates();
+#endif
+
// Returns a instruction which produces the base pointer for a given
// instruction. The instruction is assumed to be an input to one of the BDVs
// seen in the inference algorithm above. As such, we must either already
@@ -1171,6 +1199,10 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
}
}
+#ifndef NDEBUG
+ VerifyStates();
+#endif
+
// Cache all of our results so we can cheaply reuse them
// NOTE: This is actually two caches: one of the base defining value
// relation and one of the base pointer relation! FIXME
More information about the llvm-commits
mailing list