[llvm] [SandboxIR] Implement CallBase and CallInst (PR #100218)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 25 14:34:01 PDT 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/100218

>From 23b9e5eae0aef8d25bc5f9854d3b451a59f19081 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 22 Jul 2024 12:22:11 -0700
Subject: [PATCH] [SandboxIR] Implement CallBase and CallInst

This patch adds the `CallBase` SandboxIR class and its subclass: `CallInst`.
Both are mirrors of `llvm::CallBase` and `llvm::CallInst` respectively.
Since `llvm::CallBase` contains a large number of member functions,
this patch implements only some of them.

The `CallBase` unit tests uncovered an issue with the class hierarchy, where
`sandboxir::Function` was not a subclass of `sandboxir::Constant`, so this was
fixed.

Testing the tracking of the `CallBase` setters showed that
`sandboxir::Use::set()` was not being tracked, so this patch fixes that too.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h       | 215 +++++++++++++++++-
 .../llvm/SandboxIR/SandboxIRValues.def        |   3 +
 llvm/include/llvm/SandboxIR/Use.h             |   2 +
 llvm/lib/SandboxIR/SandboxIR.cpp              | 115 +++++++++-
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 212 ++++++++++++++++-
 llvm/unittests/SandboxIR/TrackerTest.cpp      |  75 ++++++
 6 files changed, 603 insertions(+), 19 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 6c04c92e3e70e..73fc6e6524c58 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -18,13 +18,15 @@
 //
 // namespace sandboxir {
 //
-//        +- Argument                   +- BinaryOperator
-//        |                             |
-// Value -+- BasicBlock                 +- BranchInst
-//        |                             |
-//        +- Function   +- Constant     +- CastInst
+//                      +- Constant ------ Function
+//                      |
+//        +- Argument   |               +- BinaryOperator
 //        |             |               |
-//        +- User ------+- Instruction -+- CallInst
+// Value -+- BasicBlock |               +- BranchInst
+//        |             |               |
+//        |             |               +- CastInst
+//        |             |               |
+//        +- User ------+- Instruction -+- CallBase ----- CallInst
 //                                      |
 //                                      +- CmpInst
 //                                      |
@@ -82,6 +84,8 @@ class ReturnInst;
 class StoreInst;
 class User;
 class Value;
+class CallBase;
+class CallInst;
 
 /// Iterator for the `Use` edges of a User's operands.
 /// \Returns the operand `Use` when dereferenced.
@@ -103,12 +107,20 @@ class OperandUseIterator {
   OperandUseIterator() = default;
   value_type operator*() const;
   OperandUseIterator &operator++();
+  OperandUseIterator operator++(int) {
+    OperandUseIterator Prev(*this); // Position before advancing.
+    ++*this;
+    return Prev;
+  }
   bool operator==(const OperandUseIterator &Other) const {
     return Use == Other.Use;
   }
   bool operator!=(const OperandUseIterator &Other) const {
     return !(*this == Other);
   }
+  OperandUseIterator operator+(unsigned Num) const;
+  OperandUseIterator operator-(unsigned Num) const;
+  int operator-(const OperandUseIterator &Other) const;
 };
 
 /// Iterator for the `Use` edges of a Value's users.
@@ -135,6 +147,7 @@ class UserUseIterator {
   bool operator!=(const UserUseIterator &Other) const {
     return !(*this == Other);
   }
+  const sandboxir::Use &getUse() const { return this->Use; }
 };
 
 /// A SandboxIR Value has users. This is the base class.
@@ -184,6 +197,8 @@ class Value {
   friend class LoadInst;   // For getting `Val`.
   friend class StoreInst;  // For getting `Val`.
   friend class ReturnInst; // For getting `Val`.
+  friend class CallBase;   // For getting `Val`.
+  friend class CallInst;   // For getting `Val`.
 
   /// All values point to the context.
   Context &Ctx;
@@ -417,7 +432,10 @@ class User : public Value {
 class Constant : public sandboxir::User {
   Constant(llvm::Constant *C, sandboxir::Context &SBCtx)
       : sandboxir::User(ClassID::Constant, C, SBCtx) {}
-  friend class Context; // For constructor.
+  Constant(ClassID ID, llvm::Constant *C, sandboxir::Context &SBCtx)
+      : sandboxir::User(ID, C, SBCtx) {}
+  friend class Function; // For constructor
+  friend class Context;  // For constructor.
   Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
     return getOperandUseDefault(OpIdx, Verify);
   }
@@ -435,7 +453,7 @@ class Constant : public sandboxir::User {
     return getUseOperandNoDefault(Use);
   }
 #ifndef NDEBUG
-  void verify() const final {
+  void verify() const override {
     assert(isa<llvm::Constant>(Val) && "Expected Constant!");
   }
   friend raw_ostream &operator<<(raw_ostream &OS,
@@ -518,6 +536,7 @@ class Instruction : public sandboxir::User {
   friend class LoadInst;   // For getTopmostLLVMInstruction().
   friend class StoreInst;  // For getTopmostLLVMInstruction().
   friend class ReturnInst; // For getTopmostLLVMInstruction().
+  friend class CallInst;   // For getTopmostLLVMInstruction().
 
   /// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
   /// order.
@@ -835,6 +854,177 @@ class ReturnInst final : public Instruction {
 #endif
 };
 
+class CallBase : public Instruction {
+  CallBase(ClassID ID, Opcode Opc, llvm::Instruction *I, Context &Ctx)
+      : Instruction(ID, Opc, I, Ctx) {}
+  friend class CallInst; // For constructor.
+
+public:
+  static bool classof(const Value *From) {
+    auto ID = From->getSubclassID();
+    return ID == Instruction::ClassID::Call ||
+           ID == Instruction::ClassID::Invoke ||
+           ID == Instruction::ClassID::CallBr;
+  }
+
+  FunctionType *getFunctionType() const {
+    return cast<llvm::CallBase>(Val)->getFunctionType();
+  }
+
+  op_iterator data_operands_begin() { return op_begin(); }
+  const_op_iterator data_operands_begin() const {
+    return const_cast<CallBase *>(this)->data_operands_begin();
+  }
+  op_iterator data_operands_end() {
+    auto *LLVMCB = cast<llvm::CallBase>(Val);
+    auto Dist = LLVMCB->data_operands_end() - LLVMCB->data_operands_begin();
+    return op_begin() + Dist;
+  }
+  const_op_iterator data_operands_end() const {
+    // Forward to the non-const overload (like data_operands_begin() const)
+    // instead of duplicating the operand-distance computation.
+    return const_cast<CallBase *>(this)->data_operands_end();
+  }
+  iterator_range<op_iterator> data_ops() {
+    return make_range(data_operands_begin(), data_operands_end());
+  }
+  iterator_range<const_op_iterator> data_ops() const {
+    return make_range(data_operands_begin(), data_operands_end());
+  }
+  bool data_operands_empty() const {
+    return data_operands_end() == data_operands_begin();
+  }
+  unsigned data_operands_size() const {
+    return std::distance(data_operands_begin(), data_operands_end());
+  }
+  bool isDataOperand(Use U) const {
+    assert(this == U.getUser() &&
+           "Only valid to query with a use of this instruction!");
+    return cast<llvm::CallBase>(Val)->isDataOperand(U.LLVMUse);
+  }
+  unsigned getDataOperandNo(Use U) const {
+    assert(isDataOperand(U) && "Data operand # out of range!");
+    return cast<llvm::CallBase>(Val)->getDataOperandNo(U.LLVMUse);
+  }
+
+  /// \Returns the total number of operands (not operand bundles) used by
+  /// all operand bundles of this call.
+  unsigned getNumTotalBundleOperands() const {
+    return cast<llvm::CallBase>(Val)->getNumTotalBundleOperands();
+  }
+
+  op_iterator arg_begin() { return op_begin(); }
+  const_op_iterator arg_begin() const { return op_begin(); }
+  op_iterator arg_end() {
+    return data_operands_end() - getNumTotalBundleOperands();
+  }
+  const_op_iterator arg_end() const {
+    return const_cast<CallBase *>(this)->arg_end();
+  }
+  iterator_range<op_iterator> args() {
+    return make_range(arg_begin(), arg_end());
+  }
+  iterator_range<const_op_iterator> args() const {
+    return make_range(arg_begin(), arg_end());
+  }
+  bool arg_empty() const { return arg_end() == arg_begin(); }
+  unsigned arg_size() const { return arg_end() - arg_begin(); }
+
+  Value *getArgOperand(unsigned OpIdx) const {
+    assert(OpIdx < arg_size() && "Out of bounds!");
+    return getOperand(OpIdx);
+  }
+  void setArgOperand(unsigned OpIdx, Value *NewOp) {
+    assert(OpIdx < arg_size() && "Out of bounds!");
+    setOperand(OpIdx, NewOp);
+  }
+
+  Use getArgOperandUse(unsigned Idx) const {
+    assert(Idx < arg_size() && "Out of bounds!");
+    return getOperandUse(Idx);
+  }
+  Use getArgOperandUse(unsigned Idx) {
+    assert(Idx < arg_size() && "Out of bounds!");
+    return getOperandUse(Idx);
+  }
+
+  bool isArgOperand(Use U) const {
+    return cast<llvm::CallBase>(Val)->isArgOperand(U.LLVMUse);
+  }
+  unsigned getArgOperandNo(Use U) const {
+    return cast<llvm::CallBase>(Val)->getArgOperandNo(U.LLVMUse);
+  }
+  bool hasArgument(const Value *V) const { return is_contained(args(), V); }
+
+  Value *getCalledOperand() const;
+  Use getCalledOperandUse() const;
+
+  Function *getCalledFunction() const;
+  bool isIndirectCall() const {
+    return cast<llvm::CallBase>(Val)->isIndirectCall();
+  }
+  bool isCallee(Use U) const {
+    return cast<llvm::CallBase>(Val)->isCallee(U.LLVMUse);
+  }
+  Function *getCaller();
+  const Function *getCaller() const {
+    return const_cast<CallBase *>(this)->getCaller();
+  }
+  bool isMustTailCall() const {
+    return cast<llvm::CallBase>(Val)->isMustTailCall();
+  }
+  bool isTailCall() const { return cast<llvm::CallBase>(Val)->isTailCall(); }
+  Intrinsic::ID getIntrinsicID() const {
+    return cast<llvm::CallBase>(Val)->getIntrinsicID();
+  }
+  void setCalledOperand(Value *V) { getCalledOperandUse().set(V); }
+  void setCalledFunction(Function *F);
+  CallingConv::ID getCallingConv() const {
+    return cast<llvm::CallBase>(Val)->getCallingConv();
+  }
+  bool isInlineAsm() const { return cast<llvm::CallBase>(Val)->isInlineAsm(); }
+};
+
+class CallInst final : public CallBase {
+  /// Use CallInst::create() instead of calling the constructor directly.
+  CallInst(llvm::Instruction *I, Context &Ctx)
+      : CallBase(ClassID::Call, Opcode::Call, I, Ctx) {}
+  friend class Context; // For accessing the constructor in create*()
+  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+    return getOperandUseDefault(OpIdx, Verify);
+  }
+  SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
+    return {cast<llvm::Instruction>(Val)};
+  }
+
+public:
+  static CallInst *create(FunctionType *FTy, Value *Func,
+                          ArrayRef<Value *> Args, BBIterator WhereIt,
+                          BasicBlock *WhereBB, Context &Ctx,
+                          const Twine &NameStr = "");
+  static CallInst *create(FunctionType *FTy, Value *Func,
+                          ArrayRef<Value *> Args, Instruction *InsertBefore,
+                          Context &Ctx, const Twine &NameStr = "");
+  static CallInst *create(FunctionType *FTy, Value *Func,
+                          ArrayRef<Value *> Args, BasicBlock *InsertAtEnd,
+                          Context &Ctx, const Twine &NameStr = "");
+
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::Call;
+  }
+  unsigned getUseOperandNo(const Use &Use) const final {
+    return getUseOperandNoDefault(Use);
+  }
+  unsigned getNumOfIRInstrs() const final { return 1u; }
+#ifndef NDEBUG
+  void verify() const final {
+    assert(isa<llvm::CallInst>(Val) && "Expected CallInst!");
+  }
+  void dump(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const override;
+#endif
+};
+
 /// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
 /// an OpaqueInstr.
 class OpaqueInst : public sandboxir::Instruction {
@@ -983,6 +1173,8 @@ class Context {
   friend StoreInst; // For createStoreInst()
   ReturnInst *createReturnInst(llvm::ReturnInst *I);
   friend ReturnInst; // For createReturnInst()
+  CallInst *createCallInst(llvm::CallInst *I);
+  friend CallInst; // For createCallInst()
 
 public:
   Context(LLVMContext &LLVMCtx)
@@ -1010,7 +1202,7 @@ class Context {
   size_t getNumValues() const { return LLVMValueToValueMap.size(); }
 };
 
-class Function : public sandboxir::Value {
+class Function : public Constant {
   /// Helper for mapped_iterator.
   struct LLVMBBToBB {
     Context &Ctx;
@@ -1021,7 +1213,7 @@ class Function : public sandboxir::Value {
   };
   /// Use Context::createFunction() instead.
   Function(llvm::Function *F, sandboxir::Context &Ctx)
-      : sandboxir::Value(ClassID::Function, F, Ctx) {}
+      : Constant(ClassID::Function, F, Ctx) {}
   friend class Context; // For constructor.
 
 public:
@@ -1047,6 +1239,9 @@ class Function : public sandboxir::Value {
     LLVMBBToBB BBGetter(Ctx);
     return iterator(cast<llvm::Function>(Val)->end(), BBGetter);
   }
+  FunctionType *getFunctionType() const {
+    return cast<llvm::Function>(this->Val)->getFunctionType();
+  }
 
 #ifndef NDEBUG
   void verify() const final {
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index f3d616774b3fd..5f6fc84fc2e07 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -30,6 +30,9 @@ DEF_INSTR(Br, OP(Br), BranchInst)
 DEF_INSTR(Load, OP(Load), LoadInst)
 DEF_INSTR(Store, OP(Store), StoreInst)
 DEF_INSTR(Ret, OP(Ret), ReturnInst)
+DEF_INSTR(Call, OP(Call), CallInst)
+DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
+DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
 
 #ifdef DEF_VALUE
 #undef DEF_VALUE
diff --git a/llvm/include/llvm/SandboxIR/Use.h b/llvm/include/llvm/SandboxIR/Use.h
index 03cbfe6cb0446..d30eb90594294 100644
--- a/llvm/include/llvm/SandboxIR/Use.h
+++ b/llvm/include/llvm/SandboxIR/Use.h
@@ -21,6 +21,7 @@ namespace llvm::sandboxir {
 class Context;
 class Value;
 class User;
+class CallBase;
 
 /// Represents a Def-use/Use-def edge in SandboxIR.
 /// NOTE: Unlike llvm::Use, this is not an integral part of the use-def chains.
@@ -40,6 +41,7 @@ class Use {
   friend class User;               // For constructor
   friend class OperandUseIterator; // For constructor
   friend class UserUseIterator;    // For accessing members
+  friend class CallBase;           // For LLVMUse
 
 public:
   operator Value *() const { return get(); }
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index ceadb34f53eaf..da482765c7d11 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -16,7 +16,12 @@ using namespace llvm::sandboxir;
 
 Value *Use::get() const { return Ctx->getValue(LLVMUse->get()); }
 
-void Use::set(Value *V) { LLVMUse->set(V->Val); }
+void Use::set(Value *V) {
+  // Log the change while this Use still holds its old value.
+  if (auto &Trk = Ctx->getTracker(); Trk.isTracking())
+    Trk.track(std::make_unique<UseSet>(*this, Trk));
+  LLVMUse->set(V->Val);
+}
 
 unsigned Use::getOperandNo() const { return Usr->getUseOperandNo(*this); }
 
@@ -84,6 +89,25 @@ UserUseIterator &UserUseIterator::operator++() {
   return *this;
 }
 
+OperandUseIterator OperandUseIterator::operator+(unsigned Num) const {
+  const unsigned TargetOpNo = Use.getOperandNo() + Num;
+  return OperandUseIterator(
+      Use.getUser()->getOperandUseInternal(TargetOpNo, /*Verify=*/true));
+}
+
+OperandUseIterator OperandUseIterator::operator-(unsigned Num) const {
+  const unsigned CurOpNo = Use.getOperandNo();
+  assert(CurOpNo >= Num && "Out of bounds!");
+  return OperandUseIterator(
+      Use.getUser()->getOperandUseInternal(CurOpNo - Num, /*Verify=*/true));
+}
+
+int OperandUseIterator::operator-(const OperandUseIterator &Other) const {
+  assert(Use.getUser() == Other.Use.getUser() && "Different users!");
+  return static_cast<int>(Use.getOperandNo()) -
+         static_cast<int>(Other.Use.getOperandNo());
+}
+
 Value::Value(ClassID SubclassID, llvm::Value *Val, Context &Ctx)
     : SubclassID(SubclassID), Val(Val), Ctx(Ctx) {
 #ifndef NDEBUG
@@ -713,6 +737,78 @@ void ReturnInst::dump() const {
   dump(dbgs());
   dbgs() << "\n";
 }
+#endif // NDEBUG
+
+Value *CallBase::getCalledOperand() const {
+  return Ctx.getValue(cast<llvm::CallBase>(Val)->getCalledOperand());
+}
+
+Use CallBase::getCalledOperandUse() const {
+  auto &LLVMCalleeUse = cast<llvm::CallBase>(Val)->getCalledOperandUse();
+  auto *Owner = cast<User>(Ctx.getValue(LLVMCalleeUse.getUser()));
+  return Use(&LLVMCalleeUse, Owner, Ctx);
+}
+
+Function *CallBase::getCalledFunction() const {
+  llvm::Function *LLVMCallee = cast<llvm::CallBase>(Val)->getCalledFunction();
+  return cast_or_null<Function>(Ctx.getValue(LLVMCallee));
+}
+
+Function *CallBase::getCaller() {
+  return cast<Function>(Ctx.getValue(cast<llvm::CallBase>(Val)->getCaller()));
+}
+
+void CallBase::setCalledFunction(Function *F) {
+  // Go through setCalledOperand() so the callee change is tracked, then let
+  // LLVM's setCalledFunction() update the call's function type, which we
+  // cannot set directly. This relies on the LLVM setter not early-returning
+  // when the callee is already `F`; the unit tests exercise this path.
+  setCalledOperand(F);
+  cast<llvm::CallBase>(Val)->setCalledFunction(F->getFunctionType(),
+                                               cast<llvm::Function>(F->Val));
+}
+
+CallInst *CallInst::create(FunctionType *FTy, Value *Func,
+                           ArrayRef<Value *> Args, BasicBlock::iterator WhereIt,
+                           BasicBlock *WhereBB, Context &Ctx,
+                           const Twine &NameStr) {
+  SmallVector<llvm::Value *> LLVMArgs;
+  LLVMArgs.reserve(Args.size());
+  for (Value *Arg : Args)
+    LLVMArgs.push_back(Arg->Val);
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  if (WhereIt == WhereBB->end()) // Append at the end of WhereBB.
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  else
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+  auto *NewLLVMCall = Builder.CreateCall(FTy, Func->Val, LLVMArgs, NameStr);
+  return Ctx.createCallInst(NewLLVMCall);
+}
+
+CallInst *CallInst::create(FunctionType *FTy, Value *Func,
+                           ArrayRef<Value *> Args, Instruction *InsertBefore,
+                           Context &Ctx, const Twine &NameStr) {
+  return create(FTy, Func, Args, InsertBefore->getIterator(),
+                InsertBefore->getParent(), Ctx, NameStr);
+}
+
+CallInst *CallInst::create(FunctionType *FTy, Value *Func,
+                           ArrayRef<Value *> Args, BasicBlock *InsertAtEnd,
+                           Context &Ctx, const Twine &NameStr) {
+  return create(FTy, Func, Args, /*WhereIt=*/InsertAtEnd->end(), InsertAtEnd,
+                Ctx, NameStr);
+}
+
+#ifndef NDEBUG
+void CallInst::dump(raw_ostream &OS) const {
+  this->dumpCommonPrefix(OS);
+  this->dumpCommonSuffix(OS);
+}
+
+void CallInst::dump() const {
+  this->dump(dbgs());
+  dbgs() << "\n";
+}
 
 void OpaqueInst::dump(raw_ostream &OS) const {
   dumpCommonPrefix(OS);
@@ -819,7 +915,10 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
     return It->second.get();
 
   if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
-    It->second = std::unique_ptr<Constant>(new Constant(C, *this));
+    if (auto *F = dyn_cast<llvm::Function>(LLVMV))
+      It->second = std::unique_ptr<Function>(new Function(F, *this));
+    else
+      It->second = std::unique_ptr<Constant>(new Constant(C, *this));
     auto *NewC = It->second.get();
     for (llvm::Value *COp : C->operands())
       getOrCreateValueInternal(COp, C);
@@ -864,6 +963,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
     It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this));
     return It->second.get();
   }
+  case llvm::Instruction::Call: {
+    auto *LLVMCall = cast<llvm::CallInst>(LLVMV);
+    It->second = std::unique_ptr<CallInst>(new CallInst(LLVMCall, *this));
+    return It->second.get();
+  }
   default:
     break;
   }
@@ -907,6 +1011,11 @@ ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) {
   return cast<ReturnInst>(registerValue(std::move(NewPtr)));
 }
 
+CallInst *Context::createCallInst(llvm::CallInst *I) {
+  std::unique_ptr<CallInst> NewCI(new CallInst(I, *this)); // Private ctor.
+  return cast<CallInst>(registerValue(std::move(NewCI)));
+}
+
 Value *Context::getValue(llvm::Value *V) const {
   auto It = LLVMValueToValueMap.find(V);
   if (It != LLVMValueToValueMap.end())
@@ -917,13 +1026,13 @@ Value *Context::getValue(llvm::Value *V) const {
 Function *Context::createFunction(llvm::Function *F) {
   assert(getValue(F) == nullptr && "Already exists!");
   auto NewFPtr = std::unique_ptr<Function>(new Function(F, *this));
+  auto *SBF = cast<Function>(registerValue(std::move(NewFPtr)));
   // Create arguments.
   for (auto &Arg : F->args())
     getOrCreateArgument(&Arg);
   // Create BBs.
   for (auto &BB : *F)
     createBasicBlock(&BB);
-  auto *SBF = cast<Function>(registerValue(std::move(NewFPtr)));
   return SBF;
 }
 
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index c600103fe10c6..05ec42c952eb6 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -90,7 +90,7 @@ define void @foo(i32 %v1) {
   EXPECT_FALSE(isa<sandboxir::Instruction>(Const0));
   EXPECT_TRUE(isa<sandboxir::Instruction>(OpaqueI));
 
-  EXPECT_FALSE(isa<sandboxir::User>(F));
+  EXPECT_TRUE(isa<sandboxir::User>(F));
   EXPECT_FALSE(isa<sandboxir::User>(Arg0));
   EXPECT_FALSE(isa<sandboxir::User>(BB));
   EXPECT_TRUE(isa<sandboxir::User>(AddI));
@@ -180,8 +180,8 @@ define i32 @foo(i32 %v0, i32 %v1) {
   BS << "\n";
   I0->getOperandUse(0).dump(BS);
   EXPECT_EQ(Buff, R"IR(
-Def:  i32 %v0 ; SB1. (Argument)
-User:   %add0 = add i32 %v0, %v1 ; SB4. (Opaque)
+Def:  i32 %v0 ; SB2. (Argument)
+User:   %add0 = add i32 %v0, %v1 ; SB5. (Opaque)
 OperandNo: 0
 )IR");
 #endif // NDEBUG
@@ -398,10 +398,10 @@ define void @foo(i32 %arg0, i32 %arg1) {
     EXPECT_EQ(Buff, R"IR(
 void @foo(i32 %arg0, i32 %arg1) {
 bb0:
-  br label %bb1 ; SB3. (Br)
+  br label %bb1 ; SB4. (Br)
 
 bb1:
-  ret void ; SB5. (Ret)
+  ret void ; SB6. (Ret)
 }
 )IR");
   }
@@ -466,7 +466,7 @@ define void @foo(i32 %v1) {
     BB0.dump(BS);
     EXPECT_EQ(Buff, R"IR(
 bb0:
-  br label %bb1 ; SB2. (Br)
+  br label %bb1 ; SB3. (Br)
 )IR");
   }
 #endif // NDEBUG
@@ -836,3 +836,203 @@ define i8 @foo(i8 %val) {
       sandboxir::ReturnInst::create(Val, /*InsertAtEnd=*/BB, Ctx));
   EXPECT_EQ(NewRet4->getReturnValue(), Val);
 }
+
+TEST_F(SandboxIRTest, CallBase) {
+  parseIR(C, R"IR(
+declare void @bar1(i8)
+declare void @bar2()
+declare void @bar3()
+declare void @variadic(ptr, ...)
+
+define i8 @foo(i8 %arg0, i32 %arg1, ptr %indirectFoo) {
+  %call = call i8 @foo(i8 %arg0, i32 %arg1)
+  call void @bar1(i8 %arg0)
+  call void @bar2()
+  call void %indirectFoo()
+  call void @bar2() noreturn
+  tail call fastcc void @bar2()
+  call void (ptr, ...) @variadic(ptr %indirectFoo, i32 1)
+  ret i8 %call
+}
+)IR");
+  llvm::Function &LLVMF = *M->getFunction("foo");
+  unsigned ArgIdx = 0;
+  llvm::Argument *LLVMArg0 = LLVMF.getArg(ArgIdx++);
+  llvm::Argument *LLVMArg1 = LLVMF.getArg(ArgIdx++);
+  llvm::BasicBlock *LLVMBB = &*LLVMF.begin();
+  SmallVector<llvm::CallBase *, 8> LLVMCalls;
+  auto LLVMIt = LLVMBB->begin();
+  while (isa<llvm::CallBase>(&*LLVMIt))
+    LLVMCalls.push_back(cast<llvm::CallBase>(&*LLVMIt++));
+
+  sandboxir::Context Ctx(C);
+  sandboxir::Function &F = *Ctx.createFunction(&LLVMF);
+
+  for (llvm::CallBase *LLVMCall : LLVMCalls) {
+    // Check classof() through cast<>.
+    auto *Call = cast<sandboxir::CallBase>(Ctx.getValue(LLVMCall));
+    // Check classof(Value *).
+    EXPECT_TRUE(isa<sandboxir::CallBase>((sandboxir::Value *)Call));
+    // Check getFunctionType().
+    EXPECT_EQ(Call->getFunctionType(), LLVMCall->getFunctionType());
+    // Check data_ops().
+    EXPECT_EQ(range_size(Call->data_ops()), range_size(LLVMCall->data_ops()));
+    auto DataOpIt = Call->data_operands_begin();
+    for (llvm::Use &LLVMUse : LLVMCall->data_ops()) {
+      Value *LLVMOp = LLVMUse.get();
+      sandboxir::Use Use = *DataOpIt++;
+      EXPECT_EQ(Ctx.getValue(LLVMOp), Use.get());
+      // Check isDataOperand().
+      EXPECT_EQ(Call->isDataOperand(Use), LLVMCall->isDataOperand(&LLVMUse));
+      // Check getDataOperandNo().
+      EXPECT_EQ(Call->getDataOperandNo(Use),
+                LLVMCall->getDataOperandNo(&LLVMUse));
+      // Check isArgOperand().
+      EXPECT_EQ(Call->isArgOperand(Use), LLVMCall->isArgOperand(&LLVMUse));
+      // Check isCallee().
+      EXPECT_EQ(Call->isCallee(Use), LLVMCall->isCallee(&LLVMUse));
+    }
+    // Check data_operands_empty().
+    EXPECT_EQ(Call->data_operands_empty(), LLVMCall->data_operands_empty());
+    // Check data_operands_size().
+    EXPECT_EQ(Call->data_operands_size(), LLVMCall->data_operands_size());
+    // Check getNumTotalBundleOperands().
+    EXPECT_EQ(Call->getNumTotalBundleOperands(),
+              LLVMCall->getNumTotalBundleOperands());
+    // Check args().
+    EXPECT_EQ(range_size(Call->args()), range_size(LLVMCall->args()));
+    auto ArgIt = Call->arg_begin();
+    for (llvm::Use &LLVMUse : LLVMCall->args()) {
+      Value *LLVMArg = LLVMUse.get();
+      sandboxir::Use Use = *ArgIt++;
+      EXPECT_EQ(Ctx.getValue(LLVMArg), Use.get());
+    }
+    // Check arg_empty().
+    EXPECT_EQ(Call->arg_empty(), LLVMCall->arg_empty());
+    // Check arg_size().
+    EXPECT_EQ(Call->arg_size(), LLVMCall->arg_size());
+    for (unsigned Idx = 0, E = Call->arg_size(); Idx != E; ++Idx) {
+      // Check getArgOperand().
+      EXPECT_EQ(Call->getArgOperand(Idx),
+                Ctx.getValue(LLVMCall->getArgOperand(Idx)));
+      // Check getArgOperandUse().
+      sandboxir::Use Use = Call->getArgOperandUse(Idx);
+      llvm::Use &LLVMUse = LLVMCall->getArgOperandUse(Idx);
+      EXPECT_EQ(Use.get(), Ctx.getValue(LLVMUse.get()));
+      // Check getArgOperandNo().
+      EXPECT_EQ(Call->getArgOperandNo(Use),
+                LLVMCall->getArgOperandNo(&LLVMUse));
+    }
+    // Check hasArgument().
+    SmallVector<llvm::Value *> TestArgs(
+        {LLVMArg0, LLVMArg1, &LLVMF, LLVMBB, LLVMCall});
+    for (llvm::Value *LLVMV : TestArgs) {
+      sandboxir::Value *V = Ctx.getValue(LLVMV);
+      EXPECT_EQ(Call->hasArgument(V), LLVMCall->hasArgument(LLVMV));
+    }
+    // Check getCalledOperand().
+    EXPECT_EQ(Call->getCalledOperand(),
+              Ctx.getValue(LLVMCall->getCalledOperand()));
+    // Check getCalledOperandUse().
+    EXPECT_EQ(Call->getCalledOperandUse().get(),
+              Ctx.getValue(LLVMCall->getCalledOperandUse()));
+    // Check getCalledFunction().
+    llvm::Function *LLVMCF = LLVMCall->getCalledFunction();
+    if (LLVMCF == nullptr) {
+      // No LLVM called function: SandboxIR must report none either.
+      EXPECT_EQ(Call->getCalledFunction(), nullptr);
+    } else {
+      // Otherwise it must be the SandboxIR Function mapped to the callee.
+      EXPECT_EQ(Call->getCalledFunction(),
+                cast<sandboxir::Function>(Ctx.getValue(LLVMCF)));
+    }
+    // Check isIndirectCall().
+    EXPECT_EQ(Call->isIndirectCall(), LLVMCall->isIndirectCall());
+    // Check getCaller().
+    EXPECT_EQ(Call->getCaller(), Ctx.getValue(LLVMCall->getCaller()));
+    // Check isMustTailCall().
+    EXPECT_EQ(Call->isMustTailCall(), LLVMCall->isMustTailCall());
+    // Check isTailCall().
+    EXPECT_EQ(Call->isTailCall(), LLVMCall->isTailCall());
+    // Check getIntrinsicID().
+    EXPECT_EQ(Call->getIntrinsicID(), LLVMCall->getIntrinsicID());
+    // Check getCallingConv().
+    EXPECT_EQ(Call->getCallingConv(), LLVMCall->getCallingConv());
+    // Check isInlineAsm().
+    EXPECT_EQ(Call->isInlineAsm(), LLVMCall->isInlineAsm());
+  }
+
+  auto *Arg0 = F.getArg(0);
+  auto *Arg1 = F.getArg(1);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *Call0 = cast<sandboxir::CallBase>(&*It++);
+  [[maybe_unused]] auto *Call1 = cast<sandboxir::CallBase>(&*It++);
+  auto *Call2 = cast<sandboxir::CallBase>(&*It++);
+  // Check setArgOperand().
+  Call0->setArgOperand(0, Arg1);
+  EXPECT_EQ(Call0->getArgOperand(0), Arg1);
+  Call0->setArgOperand(0, Arg0);
+  EXPECT_EQ(Call0->getArgOperand(0), Arg0);
+
+  auto *Bar3F = Ctx.createFunction(M->getFunction("bar3"));
+  // Check setCalledOperand() and restoring the original callee.
+  auto *SvOp = Call0->getCalledOperand();
+  Call0->setCalledOperand(Bar3F);
+  EXPECT_EQ(Call0->getCalledOperand(), Bar3F);
+  Call0->setCalledOperand(SvOp);
+  EXPECT_EQ(Call0->getCalledOperand(), SvOp);
+  // Check setCalledFunction().
+  Call2->setCalledFunction(Bar3F);
+  EXPECT_EQ(Call2->getCalledFunction(), Bar3F);
+}
+
+TEST_F(SandboxIRTest, CallInst) {
+  parseIR(C, R"IR(
+define i8 @foo(i8 %arg) {
+  %call = call i8 @foo(i8 %arg)
+  ret i8 %call
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  unsigned ArgIdx = 0;
+  auto *Arg0 = F.getArg(ArgIdx++);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *Call = cast<sandboxir::CallInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+  EXPECT_EQ(Call->getNumOperands(), 2u);
+  EXPECT_EQ(Ret->getOpcode(), sandboxir::Instruction::Opcode::Ret);
+  FunctionType *FTy = F.getFunctionType();
+  SmallVector<sandboxir::Value *, 1> Args;
+  Args.push_back(Arg0);
+  {
+    // Check create() WhereIt.
+    sandboxir::CallInst *NewCall = sandboxir::CallInst::create(
+        FTy, &F, Args, /*WhereIt=*/Ret->getIterator(), BB, Ctx);
+    EXPECT_EQ(NewCall->getNextNode(), Ret);
+    EXPECT_EQ(NewCall->getCalledFunction(), &F);
+    EXPECT_EQ(range_size(NewCall->args()), 1u);
+    EXPECT_EQ(NewCall->getArgOperand(0), Arg0);
+  }
+  {
+    // Check create() InsertBefore.
+    sandboxir::CallInst *NewCall =
+        sandboxir::CallInst::create(FTy, &F, Args, /*InsertBefore=*/Ret, Ctx);
+    EXPECT_EQ(NewCall->getNextNode(), Ret);
+    EXPECT_EQ(NewCall->getCalledFunction(), &F);
+    EXPECT_EQ(range_size(NewCall->args()), 1u);
+    EXPECT_EQ(NewCall->getArgOperand(0), Arg0);
+  }
+  {
+    // Check create() InsertAtEnd.
+    sandboxir::CallInst *NewCall =
+        sandboxir::CallInst::create(FTy, &F, Args, /*InsertAtEnd=*/BB, Ctx);
+    EXPECT_EQ(NewCall->getPrevNode(), Ret);
+    EXPECT_EQ(NewCall->getCalledFunction(), &F);
+    EXPECT_EQ(range_size(NewCall->args()), 1u);
+    EXPECT_EQ(NewCall->getArgOperand(0), Arg0);
+  }
+}
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index dd9dcd543236e..5111d5f38798f 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -69,6 +69,34 @@ define void @foo(ptr %ptr) {
   EXPECT_EQ(Ld->getOperand(0), Gep0);
 }
 
+TEST_F(TrackerTest, SetUse) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %arg) {
+  %ld = load i8, ptr %ptr
+  %add = add i8 %ld, %arg
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(&LLVMF);
+  unsigned ArgIdx = 0;
+  auto *Arg0 = F->getArg(ArgIdx++);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *Ld = &*It++;
+  auto *Add = &*It++;
+
+  // Check that Use::set() is tracked: Ctx.save() starts recording, and
+  // Ctx.revert() must restore the original operand.
+  Ctx.save();
+  sandboxir::Use Use = Add->getOperandUse(0);
+  Use.set(Arg0);
+  EXPECT_EQ(Add->getOperand(0), Arg0);
+  Ctx.revert();
+  EXPECT_EQ(Add->getOperand(0), Ld);
+}
+
 TEST_F(TrackerTest, SwapOperands) {
   parseIR(C, R"IR(
 define void @foo(i1 %cond) {
@@ -413,3 +441,50 @@ define i32 @foo(i32 %arg) {
   EXPECT_EQ(&*It++, Ret);
   EXPECT_EQ(It, BB->end());
 }
+
+TEST_F(TrackerTest, CallBaseSetters) {
+  parseIR(C, R"IR(
+declare void @bar1(i8)
+declare void @bar2(i8)
+
+define void @foo(i8 %arg0, i8 %arg1) {
+  call void @bar1(i8 %arg0)
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto *F = Ctx.createFunction(&LLVMF);
+  auto *Arg0 = F->getArg(0);
+  auto *Arg1 = F->getArg(1);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *CB = cast<sandboxir::CallBase>(&*It++);
+  [[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  // Every setter change below must be undone by Ctx.revert().
+  // setArgOperand():
+  Ctx.save();
+  CB->setArgOperand(0, Arg1);
+  EXPECT_EQ(CB->getArgOperand(0), Arg1);
+  Ctx.revert();
+  EXPECT_EQ(CB->getArgOperand(0), Arg0);
+
+  auto *OrigCallee = CB->getCalledFunction();
+  auto *NewCallee = Ctx.createFunction(M->getFunction("bar2"));
+
+  // setCalledOperand():
+  Ctx.save();
+  CB->setCalledOperand(NewCallee);
+  EXPECT_EQ(CB->getCalledOperand(), NewCallee);
+  Ctx.revert();
+  EXPECT_EQ(CB->getCalledOperand(), OrigCallee);
+
+  // setCalledFunction():
+  Ctx.save();
+  CB->setCalledFunction(NewCallee);
+  EXPECT_EQ(CB->getCalledFunction(), NewCallee);
+  Ctx.revert();
+  EXPECT_EQ(CB->getCalledFunction(), OrigCallee);
+}



More information about the llvm-commits mailing list