[llvm] r299299 - NewGVN: Clean up the GVNExpression memory hierarchy and restructure hash computation a bit so we don't have to redefine it for loads, stores, and calls

Daniel Berlin via llvm-commits llvm-commits at lists.llvm.org
Sat Apr 1 02:44:29 PDT 2017


Author: dannyb
Date: Sat Apr  1 04:44:29 2017
New Revision: 299299

URL: http://llvm.org/viewvc/llvm-project?rev=299299&view=rev
Log:
NewGVN: Clean up the GVNExpression memory hierarchy and restructure hash computation a bit so we don't have to redefine it for loads, stores, and calls

Modified:
    llvm/trunk/include/llvm/Transforms/Scalar/GVNExpression.h
    llvm/trunk/lib/Transforms/Scalar/NewGVN.cpp

Modified: llvm/trunk/include/llvm/Transforms/Scalar/GVNExpression.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Transforms/Scalar/GVNExpression.h?rev=299299&r1=299298&r2=299299&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Transforms/Scalar/GVNExpression.h (original)
+++ llvm/trunk/include/llvm/Transforms/Scalar/GVNExpression.h Sat Apr  1 04:44:29 2017
@@ -43,11 +43,13 @@ enum ExpressionType {
   ET_Unknown,
   ET_BasicStart,
   ET_Basic,
-  ET_Call,
   ET_AggregateValue,
   ET_Phi,
+  ET_MemoryStart,
+  ET_Call,
   ET_Load,
   ET_Store,
+  ET_MemoryEnd,
   ET_BasicEnd
 };
 
@@ -72,8 +74,6 @@ public:
     if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
       return true;
     // Compare the expression type for anything but load and store.
-    // For load and store we set the opcode to zero.
-    // This is needed for load coercion.
     if (getExpressionType() != ET_Load && getExpressionType() != ET_Store &&
         getExpressionType() != Other.getExpressionType())
       return false;
@@ -87,9 +87,8 @@ public:
   void setOpcode(unsigned opcode) { Opcode = opcode; }
   ExpressionType getExpressionType() const { return EType; }
 
-  virtual hash_code getHashValue() const {
-    return hash_combine(getExpressionType(), getOpcode());
-  }
+  // We deliberately leave the expression type out of the hash value.
+  virtual hash_code getHashValue() const { return getOpcode(); }
 
   //
   // Debugging support
@@ -106,7 +105,10 @@ public:
     OS << "}";
   }
 
-  void dump() const { print(dbgs()); }
+  LLVM_DUMP_METHOD void dump() const {
+    print(dbgs());
+    dbgs() << "\n";
+  }
 };
 
 inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) {
@@ -200,7 +202,7 @@ public:
   }
 
   hash_code getHashValue() const override {
-    return hash_combine(getExpressionType(), getOpcode(), ValueType,
+    return hash_combine(this->Expression::getHashValue(), ValueType,
                         hash_combine_range(op_begin(), op_end()));
   }
 
@@ -241,32 +243,53 @@ public:
   op_inserter &operator++(int) { return *this; }
 };
 
-class CallExpression final : public BasicExpression {
+class MemoryExpression : public BasicExpression {
 private:
-  CallInst *Call;
-  MemoryAccess *DefiningAccess;
+  const MemoryAccess *MemoryLeader;
 
 public:
-  CallExpression(unsigned NumOperands, CallInst *C, MemoryAccess *DA)
-      : BasicExpression(NumOperands, ET_Call), Call(C), DefiningAccess(DA) {}
-  CallExpression() = delete;
-  CallExpression(const CallExpression &) = delete;
-  CallExpression &operator=(const CallExpression &) = delete;
-  ~CallExpression() override;
+  MemoryExpression(unsigned NumOperands, enum ExpressionType EType,
+                   const MemoryAccess *MemoryLeader)
+      : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader){};
 
+  MemoryExpression() = delete;
+  MemoryExpression(const MemoryExpression &) = delete;
+  MemoryExpression &operator=(const MemoryExpression &) = delete;
   static bool classof(const Expression *EB) {
-    return EB->getExpressionType() == ET_Call;
+    return EB->getExpressionType() > ET_MemoryStart &&
+           EB->getExpressionType() < ET_MemoryEnd;
+  }
+  hash_code getHashValue() const override {
+    return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader);
   }
 
   bool equals(const Expression &Other) const override {
     if (!this->BasicExpression::equals(Other))
       return false;
-    const auto &OE = cast<CallExpression>(Other);
-    return DefiningAccess == OE.DefiningAccess;
+    const MemoryExpression &OtherMCE = cast<MemoryExpression>(Other);
+
+    return MemoryLeader == OtherMCE.MemoryLeader;
   }
 
-  hash_code getHashValue() const override {
-    return hash_combine(this->BasicExpression::getHashValue(), DefiningAccess);
+  const MemoryAccess *getMemoryLeader() const { return MemoryLeader; }
+  void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; }
+};
+
+class CallExpression final : public MemoryExpression {
+private:
+  CallInst *Call;
+
+public:
+  CallExpression(unsigned NumOperands, CallInst *C,
+                 const MemoryAccess *MemoryLeader)
+      : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {}
+  CallExpression() = delete;
+  CallExpression(const CallExpression &) = delete;
+  CallExpression &operator=(const CallExpression &) = delete;
+  ~CallExpression() override;
+
+  static bool classof(const Expression *EB) {
+    return EB->getExpressionType() == ET_Call;
   }
 
   //
@@ -276,22 +299,23 @@ public:
     if (PrintEType)
       OS << "ExpressionTypeCall, ";
     this->BasicExpression::printInternal(OS, false);
-    OS << " represents call at " << Call;
+    OS << " represents call at ";
+    Call->printAsOperand(OS);
   }
 };
 
-class LoadExpression final : public BasicExpression {
+class LoadExpression final : public MemoryExpression {
 private:
   LoadInst *Load;
-  MemoryAccess *DefiningAccess;
   unsigned Alignment;
 
 public:
-  LoadExpression(unsigned NumOperands, LoadInst *L, MemoryAccess *DA)
-      : LoadExpression(ET_Load, NumOperands, L, DA) {}
+  LoadExpression(unsigned NumOperands, LoadInst *L,
+                 const MemoryAccess *MemoryLeader)
+      : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {}
   LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L,
-                 MemoryAccess *DA)
-      : BasicExpression(NumOperands, EType), Load(L), DefiningAccess(DA) {
+                 const MemoryAccess *MemoryLeader)
+      : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) {
     Alignment = L ? L->getAlignment() : 0;
   }
   LoadExpression() = delete;
@@ -306,18 +330,11 @@ public:
   LoadInst *getLoadInst() const { return Load; }
   void setLoadInst(LoadInst *L) { Load = L; }
 
-  MemoryAccess *getDefiningAccess() const { return DefiningAccess; }
-  void setDefiningAccess(MemoryAccess *MA) { DefiningAccess = MA; }
   unsigned getAlignment() const { return Alignment; }
   void setAlignment(unsigned Align) { Alignment = Align; }
 
   bool equals(const Expression &Other) const override;
 
-  hash_code getHashValue() const override {
-    return hash_combine(getOpcode(), getType(), DefiningAccess,
-                        hash_combine_range(op_begin(), op_end()));
-  }
-
   //
   // Debugging support
   //
@@ -325,22 +342,22 @@ public:
     if (PrintEType)
       OS << "ExpressionTypeLoad, ";
     this->BasicExpression::printInternal(OS, false);
-    OS << " represents Load at " << Load;
-    OS << " with DefiningAccess " << *DefiningAccess;
+    OS << " represents Load at ";
+    Load->printAsOperand(OS);
+    OS << " with MemoryLeader " << *getMemoryLeader();
   }
 };
 
-class StoreExpression final : public BasicExpression {
+class StoreExpression final : public MemoryExpression {
 private:
   StoreInst *Store;
   Value *StoredValue;
-  MemoryAccess *DefiningAccess;
 
 public:
   StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue,
-                  MemoryAccess *DA)
-      : BasicExpression(NumOperands, ET_Store), Store(S),
-        StoredValue(StoredValue), DefiningAccess(DA) {}
+                  const MemoryAccess *MemoryLeader)
+      : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S),
+        StoredValue(StoredValue) {}
   StoreExpression() = delete;
   StoreExpression(const StoreExpression &) = delete;
   StoreExpression &operator=(const StoreExpression &) = delete;
@@ -351,27 +368,18 @@ public:
   }
 
   StoreInst *getStoreInst() const { return Store; }
-  MemoryAccess *getDefiningAccess() const { return DefiningAccess; }
   Value *getStoredValue() const { return StoredValue; }
 
   bool equals(const Expression &Other) const override;
 
-  hash_code getHashValue() const override {
-    // This deliberately does not include the stored value we compare it as part
-    // of equals, and only against other stores.
-    return hash_combine(getOpcode(), getType(), DefiningAccess,
-                        hash_combine_range(op_begin(), op_end()));
-  }
-
-  //
   // Debugging support
   //
   void printInternal(raw_ostream &OS, bool PrintEType) const override {
     if (PrintEType)
       OS << "ExpressionTypeStore, ";
     this->BasicExpression::printInternal(OS, false);
-    OS << " represents Store at " << Store;
-    OS << " with DefiningAccess " << *DefiningAccess;
+    OS << " represents Store  " << *Store;
+    OS << " with MemoryLeader " << *getMemoryLeader();
   }
 };
 
@@ -527,8 +535,8 @@ public:
   }
 
   hash_code getHashValue() const override {
-    return hash_combine(getExpressionType(), VariableValue->getType(),
-                        VariableValue);
+    return hash_combine(this->Expression::getHashValue(),
+                        VariableValue->getType(), VariableValue);
   }
 
   //
@@ -566,8 +574,8 @@ public:
   }
 
   hash_code getHashValue() const override {
-    return hash_combine(getExpressionType(), ConstantValue->getType(),
-                        ConstantValue);
+    return hash_combine(this->Expression::getHashValue(),
+                        ConstantValue->getType(), ConstantValue);
   }
 
   //
@@ -604,7 +612,7 @@ public:
   }
 
   hash_code getHashValue() const override {
-    return hash_combine(getExpressionType(), Inst);
+    return hash_combine(this->Expression::getHashValue(), Inst);
   }
 
   //

Modified: llvm/trunk/lib/Transforms/Scalar/NewGVN.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Scalar/NewGVN.cpp?rev=299299&r1=299298&r2=299299&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Scalar/NewGVN.cpp (original)
+++ llvm/trunk/lib/Transforms/Scalar/NewGVN.cpp Sat Apr  1 04:44:29 2017
@@ -429,17 +429,9 @@ private:
 
 template <typename T>
 static bool equalsLoadStoreHelper(const T &LHS, const Expression &RHS) {
-  if ((!isa<LoadExpression>(RHS) && !isa<StoreExpression>(RHS)) ||
-      !LHS.BasicExpression::equals(RHS)) {
+  if (!isa<LoadExpression>(RHS) && !isa<StoreExpression>(RHS))
     return false;
-  } else if (const auto *L = dyn_cast<LoadExpression>(&RHS)) {
-    if (LHS.getDefiningAccess() != L->getDefiningAccess())
-      return false;
-  } else if (const auto *S = dyn_cast<StoreExpression>(&RHS)) {
-    if (LHS.getDefiningAccess() != S->getDefiningAccess())
-      return false;
-  }
-  return true;
+  return LHS.MemoryExpression::equals(RHS);
 }
 
 bool LoadExpression::equals(const Expression &Other) const {
@@ -447,13 +439,13 @@ bool LoadExpression::equals(const Expres
 }
 
 bool StoreExpression::equals(const Expression &Other) const {
-  bool Result = equalsLoadStoreHelper(*this, Other);
+  if (!equalsLoadStoreHelper(*this, Other))
+    return false;
   // Make sure that store vs store includes the value operand.
-  if (Result)
-    if (const auto *S = dyn_cast<StoreExpression>(&Other))
-      if (getStoredValue() != S->getStoredValue())
-        return false;
-  return Result;
+  if (const auto *S = dyn_cast<StoreExpression>(&Other))
+    if (getStoredValue() != S->getStoredValue())
+      return false;
+  return true;
 }
 
 #ifndef NDEBUG




More information about the llvm-commits mailing list