[llvm] eb97761 - [SandboxIR] sandboxir::Use operands (part 1) and uses (part 2) (#98251)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 10 15:06:49 PDT 2024


Author: vporpo
Date: 2024-07-10T15:06:46-07:00
New Revision: eb97761b4db4877e8c2507054d94a35154e2ba54

URL: https://github.com/llvm/llvm-project/commit/eb97761b4db4877e8c2507054d94a35154e2ba54
DIFF: https://github.com/llvm/llvm-project/commit/eb97761b4db4877e8c2507054d94a35154e2ba54.diff

LOG: [SandboxIR] sandboxir::Use operands (part 1) and uses (part 2) (#98251)

This PR adds the Use class and several operands-related functions to the
User class (part 1) and several uses-related functions to the Value
class (part 2).

Added: 
    

Modified: 
    llvm/include/llvm/SandboxIR/SandboxIR.h
    llvm/lib/SandboxIR/SandboxIR.cpp
    llvm/unittests/SandboxIR/SandboxIRTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index c84f25f6f5754..039c7ed078d7b 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -71,6 +71,96 @@ namespace sandboxir {
 class Function;
 class Context;
 class Instruction;
+class User;
+class Value;
+
+/// Represents a Def-use/Use-def edge in SandboxIR.
+/// NOTE: Unlike llvm::Use, this is not an integral part of the use-def chains.
+/// It is also not uniqued and is currently passed by value, so you can have
+/// more than one sandboxir::Use objects for the same use-def edge.
+class Use {
+  llvm::Use *LLVMUse;
+  User *Usr;
+  Context *Ctx;
+
+  /// Don't allow the user to create a sandboxir::Use directly.
+  Use(llvm::Use *LLVMUse, User *Usr, Context &Ctx)
+      : LLVMUse(LLVMUse), Usr(Usr), Ctx(&Ctx) {}
+  Use() : LLVMUse(nullptr), Ctx(nullptr) {}
+
+  friend class Value;              // For constructor
+  friend class User;               // For constructor
+  friend class OperandUseIterator; // For constructor
+  friend class UserUseIterator;    // For accessing members
+
+public:
+  operator Value *() const { return get(); }
+  Value *get() const;
+  class User *getUser() const { return Usr; }
+  unsigned getOperandNo() const;
+  Context *getContext() const { return Ctx; }
+  bool operator==(const Use &Other) const {
+    assert(Ctx == Other.Ctx && "Contexts 
diff er!");
+    return LLVMUse == Other.LLVMUse && Usr == Other.Usr;
+  }
+  bool operator!=(const Use &Other) const { return !(*this == Other); }
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const;
+  void dump() const;
+#endif // NDEBUG
+};
+
+/// Returns the operand edge when dereferenced.
+class OperandUseIterator {
+  // Qualified on purpose: an unqualified `Use Use;` changes the meaning of the
+  // name `Use` inside this class (ill-formed, no diagnostic required), which
+  // some GCC versions reject.
+  sandboxir::Use Use;
+  /// Don't let the user create a non-empty OperandUseIterator.
+  OperandUseIterator(const class Use &Use) : Use(Use) {}
+  friend class User;                                  // For constructor
+#define DEF_INSTR(ID, OPC, CLASS) friend class CLASS; // For constructor
+#include "llvm/SandboxIR/SandboxIRValues.def"
+
+public:
+  using difference_type = std::ptrdiff_t;
+  using value_type = sandboxir::Use;
+  using pointer = value_type *;
+  using reference = value_type &;
+  using iterator_category = std::input_iterator_tag;
+
+  OperandUseIterator() = default;
+  value_type operator*() const;
+  OperandUseIterator &operator++();
+  bool operator==(const OperandUseIterator &Other) const {
+    return Use == Other.Use;
+  }
+  bool operator!=(const OperandUseIterator &Other) const {
+    return !(*this == Other);
+  }
+};
+
+/// Returns user edge when dereferenced.
+class UserUseIterator {
+  // Qualified on purpose: an unqualified `Use Use;` changes the meaning of the
+  // name `Use` inside this class (ill-formed, no diagnostic required), which
+  // some GCC versions reject.
+  sandboxir::Use Use;
+  /// Don't let the user create a non-empty UserUseIterator.
+  UserUseIterator(const class Use &Use) : Use(Use) {}
+  friend class Value; // For constructor
+
+public:
+  using difference_type = std::ptrdiff_t;
+  using value_type = sandboxir::Use;
+  using pointer = value_type *;
+  using reference = value_type &;
+  using iterator_category = std::input_iterator_tag;
+
+  UserUseIterator() = default;
+  value_type operator*() const { return Use; }
+  UserUseIterator &operator++();
+  bool operator==(const UserUseIterator &Other) const {
+    return Use == Other.Use;
+  }
+  bool operator!=(const UserUseIterator &Other) const {
+    return !(*this == Other);
+  }
+};
 
 /// A SandboxIR Value has users. This is the base class.
 class Value {
@@ -123,9 +213,77 @@ class Value {
   virtual ~Value() = default;
   ClassID getSubclassID() const { return SubclassID; }
 
+  /// Iterators over the use edges of this value. The const and non-const
+  /// flavors are the same type.
+  using use_iterator = UserUseIterator;
+  using const_use_iterator = use_iterator;
+
+  use_iterator use_begin();
+  use_iterator use_end() {
+    // An edge with neither an LLVM use nor a user marks the end of the list.
+    return use_iterator(Use(nullptr, nullptr, Ctx));
+  }
+  const_use_iterator use_begin() const {
+    auto *Self = const_cast<Value *>(this);
+    return Self->use_begin();
+  }
+  const_use_iterator use_end() const {
+    auto *Self = const_cast<Value *>(this);
+    return Self->use_end();
+  }
+
+  /// \Returns the range [use_begin(), use_end()).
+  iterator_range<use_iterator> uses() {
+    return iterator_range<use_iterator>(use_begin(), use_end());
+  }
+  iterator_range<const_use_iterator> uses() const {
+    return iterator_range<const_use_iterator>(use_begin(), use_end());
+  }
+
+  /// Function object for mapped_iterator: maps a use edge to its user.
+  struct UseToUser {
+    User *operator()(const Use &U) const { return U.getUser(); }
+  };
+
+  /// Iterates over the users of this value, once per use edge, so a user that
+  /// uses this value twice is visited twice.
+  using user_iterator = mapped_iterator<sandboxir::UserUseIterator, UseToUser>;
+  using const_user_iterator = user_iterator;
+
+  user_iterator user_begin();
+  const_user_iterator user_begin() const {
+    auto *Self = const_cast<Value *>(this);
+    return Self->user_begin();
+  }
+  user_iterator user_end() { return user_iterator(use_end(), UseToUser()); }
+  const_user_iterator user_end() const {
+    auto *Self = const_cast<Value *>(this);
+    return Self->user_end();
+  }
+
+  /// \Returns the range [user_begin(), user_end()).
+  iterator_range<user_iterator> users() {
+    return iterator_range<user_iterator>(user_begin(), user_end());
+  }
+  iterator_range<const_user_iterator> users() const {
+    return iterator_range<const_user_iterator>(user_begin(), user_end());
+  }
+  /// \Returns the number of user edges (not necessarily to unique users).
+  /// WARNING: This is a linear-time operation.
+  unsigned getNumUses() const;
+  /// Return true if this value has N uses or more.
+  /// This is logically equivalent to getNumUses() >= N.
+  /// WARNING: This can be expensive, as it is linear to the number of users.
+  bool hasNUsesOrMore(unsigned Num) const {
+    // Every value has at least zero uses. Without this check a value with no
+    // uses never enters the loop and wrongly reports false for Num == 0.
+    if (Num == 0)
+      return true;
+    unsigned Cnt = 0;
+    for (auto It = use_begin(), ItE = use_end(); It != ItE; ++It) {
+      if (++Cnt >= Num)
+        return true;
+    }
+    return false;
+  }
+  /// Return true if this Value has exactly N uses.
+  /// Walks the use list, but gives up as soon as more than \p Num uses have
+  /// been seen.
+  bool hasNUses(unsigned Num) const {
+    unsigned Seen = 0;
+    auto UI = use_begin();
+    const auto UE = use_end();
+    while (UI != UE) {
+      if (++Seen > Num)
+        return false;
+      ++UI;
+    }
+    return Seen == Num;
+  }
+
   Type *getType() const { return Val->getType(); }
 
-  Context &getContext() const;
+  /// \Returns the Context this value was constructed with.
+  Context &getContext() const { return Ctx; }
 #ifndef NDEBUG
   /// Should crash if there is something wrong with the instruction.
   virtual void verify() const = 0;
@@ -174,9 +332,61 @@ class User : public Value {
 protected:
   User(ClassID ID, llvm::Value *V, Context &Ctx) : Value(ID, V, Ctx) {}
 
+  /// \Returns the Use edge that corresponds to \p OpIdx.
+  /// Note: This is the default implementation that works for instructions that
+  /// match the underlying LLVM instruction. All others should use a different
+  /// implementation.
+  Use getOperandUseDefault(unsigned OpIdx, bool Verify) const;
+  /// \Returns the Use edge for \p OpIdx. op_end() calls this with
+  /// OpIdx == getNumOperands() and Verify == false, so implementations must
+  /// also return the one-past-the-end edge in that case.
+  virtual Use getOperandUseInternal(unsigned OpIdx, bool Verify) const = 0;
+  friend class OperandUseIterator; // for getOperandUseInternal()
+
+  /// The default implementation works only for single-LLVMIR-instruction
+  /// Users and only if they match exactly the LLVM instruction.
+  unsigned getUseOperandNoDefault(const Use &Use) const {
+    return Use.LLVMUse->getOperandNo();
+  }
+  /// \Returns the operand index of \p Use.
+  virtual unsigned getUseOperandNo(const Use &Use) const = 0;
+  friend unsigned Use::getOperandNo() const; // For getUseOperandNo()
+
 public:
   /// For isa/dyn_cast.
   static bool classof(const Value *From);
+  /// Operand iterators dereference to sandboxir::Use edges. The const and
+  /// non-const flavors are the same type.
+  using op_iterator = OperandUseIterator;
+  using const_op_iterator = OperandUseIterator;
+  using op_range = iterator_range<op_iterator>;
+  using const_op_range = iterator_range<const_op_iterator>;
+
+  /// \Returns an iterator to the edge of operand 0.
+  virtual op_iterator op_begin() {
+    assert(isa<llvm::User>(Val) && "Expect User value!");
+    return op_iterator(getOperandUseInternal(0, /*Verify=*/false));
+  }
+  /// \Returns the end iterator, built from the edge for
+  /// OpIdx == getNumOperands() (unverified, so no bounds assertion fires).
+  virtual op_iterator op_end() {
+    assert(isa<llvm::User>(Val) && "Expect User value!");
+    return op_iterator(
+        getOperandUseInternal(getNumOperands(), /*Verify=*/false));
+  }
+  virtual const_op_iterator op_begin() const {
+    return const_cast<User *>(this)->op_begin();
+  }
+  virtual const_op_iterator op_end() const {
+    return const_cast<User *>(this)->op_end();
+  }
+
+  op_range operands() { return make_range<op_iterator>(op_begin(), op_end()); }
+  const_op_range operands() const {
+    return make_range<const_op_iterator>(op_begin(), op_end());
+  }
+  /// \Returns the value used as operand \p OpIdx.
+  Value *getOperand(unsigned OpIdx) const { return getOperandUse(OpIdx).get(); }
+  /// \Returns the operand edge for \p OpIdx, which must be a valid operand
+  /// index: this requests a verified lookup, and the default implementation
+  /// asserts OpIdx < getNumOperands(). op_end() does not go through here; it
+  /// asks getOperandUseInternal() for OpIdx == getNumOperands() unverified.
+  Use getOperandUse(unsigned OpIdx) const {
+    return getOperandUseInternal(OpIdx, /*Verify=*/true);
+  }
+  /// \Returns the number of operands of the wrapped LLVM value, or 0 if it is
+  /// not an llvm::User.
+  virtual unsigned getNumOperands() const {
+    return isa<llvm::User>(Val) ? cast<llvm::User>(Val)->getNumOperands() : 0;
+  }
+
 #ifndef NDEBUG
   void verify() const override {
     assert(isa<llvm::User>(Val) && "Expected User!");
@@ -195,6 +405,9 @@ class Constant : public sandboxir::User {
   Constant(llvm::Constant *C, sandboxir::Context &SBCtx)
       : sandboxir::User(ClassID::Constant, C, SBCtx) {}
   friend class Context; // For constructor.
+  /// Forwards to the default mapping: operand i of this Constant is operand i
+  /// of the wrapped llvm::Constant.
+  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+    return this->getOperandUseDefault(OpIdx, Verify);
+  }
 
 public:
   /// For isa/dyn_cast.
@@ -203,6 +416,9 @@ class Constant : public sandboxir::User {
            From->getSubclassID() == ClassID::Function;
   }
   sandboxir::Context &getParent() const { return getContext(); }
+  /// The operand index is taken directly from the wrapped llvm::Use.
+  unsigned getUseOperandNo(const Use &U) const final {
+    return this->getUseOperandNoDefault(U);
+  }
 #ifndef NDEBUG
   void verify() const final {
     assert(isa<llvm::Constant>(Val) && "Expected Constant!");
@@ -309,11 +525,17 @@ class OpaqueInst : public sandboxir::Instruction {
   OpaqueInst(ClassID SubclassID, llvm::Instruction *I, sandboxir::Context &Ctx)
       : sandboxir::Instruction(SubclassID, Opcode::Opaque, I, Ctx) {}
   friend class Context; // For constructor.
+  /// Forwards to the default mapping: operand i of this instruction is
+  /// operand i of the wrapped llvm::Instruction.
+  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+    return this->getOperandUseDefault(OpIdx, Verify);
+  }
 
 public:
   static bool classof(const sandboxir::Value *From) {
     return From->getSubclassID() == ClassID::Opaque;
   }
+  /// The operand index is taken directly from the wrapped llvm::Use.
+  unsigned getUseOperandNo(const Use &U) const final {
+    return this->getUseOperandNoDefault(U);
+  }
   unsigned getNumOfIRInstrs() const final { return 1u; }
 #ifndef NDEBUG
   void verify() const final {

diff  --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index f64b1145ebf43..160d807738a3c 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -7,12 +7,72 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/SandboxIR/SandboxIR.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/Support/Debug.h"
 #include <sstream>
 
 using namespace llvm::sandboxir;
 
+Value *Use::get() const {
+  llvm::Value *LLVMDef = LLVMUse->get();
+  return Ctx->getValue(LLVMDef);
+}
+
+unsigned Use::getOperandNo() const {
+  // The user decides how its operands map to indices; ask it.
+  return Usr->getUseOperandNo(*this);
+}
+
+#ifndef NDEBUG
+/// Prints the def, the user and the operand number of this edge, one per line.
+/// Missing pieces are printed as "NULL" / "N/A" instead of being dereferenced.
+void Use::dump(raw_ostream &OS) const {
+  Value *Def = nullptr;
+  if (LLVMUse == nullptr)
+    OS << "<null> LLVM Use! ";
+  else
+    Def = Ctx->getValue(LLVMUse->get());
+  OS << "Def:  ";
+  if (Def == nullptr)
+    OS << "NULL";
+  else
+    OS << *Def;
+  OS << "\n";
+
+  OS << "User: ";
+  if (Usr == nullptr)
+    OS << "NULL";
+  else
+    OS << *Usr;
+  OS << "\n";
+
+  OS << "OperandNo: ";
+  // getOperandNo() calls into the user, whose default implementation reads
+  // LLVMUse->getOperandNo(), so a null LLVM use must not reach it either.
+  if (Usr == nullptr || LLVMUse == nullptr)
+    OS << "N/A";
+  else
+    OS << getOperandNo();
+  OS << "\n";
+}
+
+void Use::dump() const { dump(dbgs()); }
+#endif // NDEBUG
+
+Use OperandUseIterator::operator*() const {
+  // sandboxir::Use is passed by value, so hand out a copy of the edge.
+  return Use;
+}
+
+OperandUseIterator &OperandUseIterator::operator++() {
+  assert(Use.LLVMUse != nullptr && "Already at end!");
+  User *Owner = Use.getUser();
+  const unsigned NextOpIdx = Use.getOperandNo() + 1;
+  // Unverified lookup: NextOpIdx may equal the operand count (the end edge).
+  Use = Owner->getOperandUseInternal(NextOpIdx, /*Verify=*/false);
+  return *this;
+}
+
+UserUseIterator &UserUseIterator::operator++() {
+  assert(Use.LLVMUse != nullptr && "Already at end!");
+  llvm::Use *Next = Use.LLVMUse->getNext();
+  Use.LLVMUse = Next;
+  // Past the last LLVM use: leave an edge with no user, like Value::use_end().
+  if (!Next) {
+    Use.Usr = nullptr;
+    return *this;
+  }
+  llvm::User *NextLLVMUser = Next->getUser();
+  Use.Usr = cast_or_null<sandboxir::User>(Use.Ctx->getValue(NextLLVMUser));
+  return *this;
+}
+
 Value::Value(ClassID SubclassID, llvm::Value *Val, Context &Ctx)
     : SubclassID(SubclassID), Val(Val), Ctx(Ctx) {
 #ifndef NDEBUG
@@ -20,6 +80,29 @@ Value::Value(ClassID SubclassID, llvm::Value *Val, Context &Ctx)
 #endif
 }
 
+Value::use_iterator Value::use_begin() {
+  auto FirstIt = Val->use_begin();
+  // No uses at all: begin is the empty end edge.
+  if (FirstIt == Val->use_end())
+    return use_iterator(Use(nullptr, nullptr, Ctx));
+  llvm::Use *FirstUse = &*FirstIt;
+  auto *FirstUser =
+      cast_or_null<sandboxir::User>(Ctx.getValue(FirstUse->getUser()));
+  return use_iterator(Use(FirstUse, FirstUser, Ctx));
+}
+
+Value::user_iterator Value::user_begin() {
+  // A user iterator is the use iterator with each edge mapped to its user,
+  // so start from the same first edge that use_begin() computes.
+  return user_iterator(use_begin(), UseToUser());
+}
+
+// Counts every use edge of the wrapped LLVM value, duplicates included.
+unsigned Value::getNumUses() const { return Val->getNumUses(); }
+
 #ifndef NDEBUG
 std::string Value::getName() const {
   std::stringstream SS;
@@ -71,6 +154,17 @@ void Argument::dump() const {
 }
 #endif // NDEBUG
 
+Use User::getOperandUseDefault(unsigned OpIdx, bool Verify) const {
+  const unsigned NumOps = getNumOperands();
+  assert((!Verify || OpIdx < NumOps) && "Out of bounds!");
+  assert(isa<llvm::User>(Val) && "Non-users have no operands!");
+  auto *LLVMUser = cast<llvm::User>(Val);
+  // OpIdx == NumOps denotes the one-past-the-end edge requested by op_end().
+  llvm::Use *LLVMUse = OpIdx == NumOps ? LLVMUser->op_end()
+                                       : &LLVMUser->getOperandUse(OpIdx);
+  return Use(LLVMUse, const_cast<User *>(this), Ctx);
+}
+
 bool User::classof(const Value *From) {
   switch (From->getSubclassID()) {
 #define DEF_VALUE(ID, CLASS)

diff  --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 161ee51432cd3..72e81bf640350 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -109,6 +109,126 @@ define void @foo(i32 %v1) {
 #endif
 }
 
+// Checks sandboxir::Use as produced by User::operands() and getOperandUse(),
+// and the use/user queries on Value, for an `add` whose only user is `ret`.
+TEST_F(SandboxIRTest, Use) {
+  parseIR(C, R"IR(
+define i32 @foo(i32 %v0, i32 %v1) {
+  %add0 = add i32 %v0, %v1
+  ret i32 %add0
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  BasicBlock *LLVMBB = &*LLVMF.begin();
+  auto LLVMBBIt = LLVMBB->begin();
+  Instruction *LLVMI0 = &*LLVMBBIt++;
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+  auto *Arg0 = F.getArg(0);
+  auto *Arg1 = F.getArg(1);
+  auto It = BB.begin();
+  auto *I0 = &*It++;
+  auto *Ret = &*It++;
+
+  // Walk the two operand edges of %add0 and cross-check each accessor.
+  SmallVector<sandboxir::Argument *> Args{Arg0, Arg1};
+  unsigned OpIdx = 0;
+  for (sandboxir::Use Use : I0->operands()) {
+    // Check Use.getOperandNo().
+    EXPECT_EQ(Use.getOperandNo(), OpIdx);
+    // Check Use.getUser().
+    EXPECT_EQ(Use.getUser(), I0);
+    // Check Use.getContext().
+    EXPECT_EQ(Use.getContext(), &Ctx);
+    // Check Use.get().
+    sandboxir::Value *Op = Use.get();
+    EXPECT_EQ(Op, Ctx.getValue(LLVMI0->getOperand(OpIdx)));
+    // Check Use.getUser().
+    EXPECT_EQ(Use.getUser(), I0);
+    // Check implicit cast to Value.
+    sandboxir::Value *Cast = Use;
+    EXPECT_EQ(Cast, Op);
+    // Check that Use points to the correct operand.
+    EXPECT_EQ(Op, Args[OpIdx]);
+    // Check getOperand().
+    EXPECT_EQ(Op, I0->getOperand(OpIdx));
+    // Check getOperandUse().
+    EXPECT_EQ(Use, I0->getOperandUse(OpIdx));
+    ++OpIdx;
+  }
+  EXPECT_EQ(OpIdx, 2u);
+
+  // Check Use.operator==() and Use.operator!=().
+  sandboxir::Use UseA = I0->getOperandUse(0);
+  sandboxir::Use UseB = I0->getOperandUse(0);
+  EXPECT_TRUE(UseA == UseB);
+  EXPECT_FALSE(UseA != UseB);
+
+  // Check getNumOperands().
+  EXPECT_EQ(I0->getNumOperands(), 2u);
+  EXPECT_EQ(Ret->getNumOperands(), 1u);
+
+  EXPECT_EQ(Ret->getOperand(0), I0);
+
+#ifndef NDEBUG
+  // Check Use.dump()
+  // NOTE(review): the "SBn." ids in the expected text presumably reflect the
+  // order in which Ctx created the values; confirm before reordering setup.
+  std::string Buff;
+  raw_string_ostream BS(Buff);
+  BS << "\n";
+  I0->getOperandUse(0).dump(BS);
+  EXPECT_EQ(Buff, R"IR(
+Def:  i32 %v0 ; SB1. (Argument)
+User:   %add0 = add i32 %v0, %v1 ; SB4. (Opaque)
+OperandNo: 0
+)IR");
+#endif // NDEBUG
+
+  // %add0 has exactly one use, in `ret`.
+  // Check Value.user_begin().
+  sandboxir::Value::user_iterator UIt = I0->user_begin();
+  sandboxir::User *U = *UIt;
+  EXPECT_EQ(U, Ret);
+  // Check Value.uses().
+  EXPECT_EQ(range_size(I0->uses()), 1u);
+  EXPECT_EQ((*I0->uses().begin()).getUser(), Ret);
+  // Check Value.users().
+  EXPECT_EQ(range_size(I0->users()), 1u);
+  EXPECT_EQ(*I0->users().begin(), Ret);
+  // Check Value.getNumUses().
+  EXPECT_EQ(I0->getNumUses(), 1u);
+  // Check Value.hasNUsesOrMore().
+  EXPECT_TRUE(I0->hasNUsesOrMore(0u));
+  EXPECT_TRUE(I0->hasNUsesOrMore(1u));
+  EXPECT_FALSE(I0->hasNUsesOrMore(2u));
+  // Check Value.hasNUses().
+  EXPECT_FALSE(I0->hasNUses(0u));
+  EXPECT_TRUE(I0->hasNUses(1u));
+  EXPECT_FALSE(I0->hasNUses(2u));
+}
+
+// Check that the operands/users are counted correctly when I2 uses I1 for
+// both of its operands:
+//  I1
+//  | |
+//  I2
+// NOTE: keep the diagram free of trailing backslashes; a `\` at the end of a
+// `//` line splices the next line into the comment (GCC -Wcomment).
+TEST_F(SandboxIRTest, DuplicateUses) {
+  parseIR(C, R"IR(
+define void @foo(i8 %v) {
+  %I1 = add i8 %v, %v
+  %I2 = add i8 %I1, %I1
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(&LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *I1 = &*It++;
+  auto *I2 = &*It++;
+  // I1 is used twice by the same user, and each edge counts separately.
+  EXPECT_EQ(range_size(I1->users()), 2u);
+  EXPECT_EQ(I1->getNumUses(), 2u);
+  EXPECT_TRUE(I1->hasNUses(2u));
+  EXPECT_EQ(range_size(I2->operands()), 2u);
+  EXPECT_EQ(I2->getNumOperands(), 2u);
+}
+
 TEST_F(SandboxIRTest, Function) {
   parseIR(C, R"IR(
 define void @foo(i32 %arg0, i32 %arg1) {


        


More information about the llvm-commits mailing list