[llvm] 372a6be - [SandboxIR] Implement CallBase and CallInst (#100218)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jul 26 11:13:45 PDT 2024
Author: vporpo
Date: 2024-07-26T11:13:41-07:00
New Revision: 372a6beac65851abe6d8016df397f5cea32ffe9b
URL: https://github.com/llvm/llvm-project/commit/372a6beac65851abe6d8016df397f5cea32ffe9b
DIFF: https://github.com/llvm/llvm-project/commit/372a6beac65851abe6d8016df397f5cea32ffe9b.diff
LOG: [SandboxIR] Implement CallBase and CallInst (#100218)
This patch adds the `CallBase` SandboxIR class and its subclass:
`CallInst`. Both are mirrors of `llvm::CallBase` and `llvm::CallInst`
respectively. Since `llvm::CallBase` contains a large number of member
functions, this patch implements only some of them.
The `CallBase` unit tests uncovered an issue with the class hierarchy,
where `sandboxir::Function` was not a subclass of `sandboxir::Constant`,
so this was fixed.
Testing tracking of the `CallBase` setters showed that
`sandboxir::Use::set()` was not being tracked. So this is also part of
this patch.
Added:
Modified:
llvm/include/llvm/SandboxIR/SandboxIR.h
llvm/include/llvm/SandboxIR/SandboxIRValues.def
llvm/include/llvm/SandboxIR/Use.h
llvm/lib/SandboxIR/SandboxIR.cpp
llvm/unittests/SandboxIR/SandboxIRTest.cpp
llvm/unittests/SandboxIR/TrackerTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 6c04c92e3e70e..2678ee0f4f90a 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -18,13 +18,19 @@
//
// namespace sandboxir {
//
-// +- Argument +- BinaryOperator
-// | |
-// Value -+- BasicBlock +- BranchInst
-// | |
-// +- Function +- Constant +- CastInst
-// | | |
-// +- User ------+- Instruction -+- CallInst
+// Value -+- Argument
+// |
+// +- BasicBlock
+// |
+// +- User ------+- Constant ------ Function
+// |
+// +- Instruction -+- BinaryOperator
+// |
+// +- BranchInst
+// |
+// +- CastInst
+// |
+// +- CallBase ----- CallInst
// |
// +- CmpInst
// |
@@ -82,6 +88,8 @@ class ReturnInst;
class StoreInst;
class User;
class Value;
+class CallBase;
+class CallInst;
/// Iterator for the `Use` edges of a User's operands.
/// \Returns the operand `Use` when dereferenced.
@@ -103,12 +111,20 @@ class OperandUseIterator {
OperandUseIterator() = default;
value_type operator*() const;
OperandUseIterator &operator++();
+  /// Post-increment: advances this iterator and returns its previous value.
+  OperandUseIterator operator++(int) {
+    OperandUseIterator Prev(*this);
+    ++*this;
+    return Prev;
+  }
bool operator==(const OperandUseIterator &Other) const {
return Use == Other.Use;
}
bool operator!=(const OperandUseIterator &Other) const {
return !(*this == Other);
}
+ OperandUseIterator operator+(unsigned Num) const;
+ OperandUseIterator operator-(unsigned Num) const;
+ int operator-(const OperandUseIterator &Other) const;
};
/// Iterator for the `Use` edges of a Value's users.
@@ -135,6 +151,7 @@ class UserUseIterator {
bool operator!=(const UserUseIterator &Other) const {
return !(*this == Other);
}
+  /// \Returns the `Use` edge this iterator currently points to.
+  const sandboxir::Use &getUse() const {
+    return Use;
+  }
};
/// A SandboxIR Value has users. This is the base class.
@@ -184,6 +201,8 @@ class Value {
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
friend class ReturnInst; // For getting `Val`.
+ friend class CallBase; // For getting `Val`.
+ friend class CallInst; // For getting `Val`.
/// All values point to the context.
Context &Ctx;
@@ -417,7 +436,10 @@ class User : public Value {
class Constant : public sandboxir::User {
Constant(llvm::Constant *C, sandboxir::Context &SBCtx)
: sandboxir::User(ClassID::Constant, C, SBCtx) {}
- friend class Context; // For constructor.
+ /// Constructor for subclasses (e.g. Function) that need their own ClassID.
+ Constant(ClassID ID, llvm::Constant *C, sandboxir::Context &SBCtx)
+ : sandboxir::User(ID, C, SBCtx) {}
+ friend class Function; // For constructor
+ friend class Context; // For constructor.
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
@@ -435,7 +457,7 @@ class Constant : public sandboxir::User {
return getUseOperandNoDefault(Use);
}
#ifndef NDEBUG
- void verify() const final {
+ void verify() const override {
assert(isa<llvm::Constant>(Val) && "Expected Constant!");
}
friend raw_ostream &operator<<(raw_ostream &OS,
@@ -518,6 +540,7 @@ class Instruction : public sandboxir::User {
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
friend class ReturnInst; // For getTopmostLLVMInstruction().
+ friend class CallInst; // For getTopmostLLVMInstruction().
/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
@@ -835,6 +858,177 @@ class ReturnInst final : public Instruction {
#endif
};
+class CallBase : public Instruction {
+  CallBase(ClassID ID, Opcode Opc, llvm::Instruction *I, Context &Ctx)
+      : Instruction(ID, Opc, I, Ctx) {}
+  friend class CallInst; // For constructor.
+
+public:
+  /// A CallBase is any of the call-like instructions: call, invoke, callbr.
+  static bool classof(const Value *From) {
+    auto Opc = From->getSubclassID();
+    return Opc == Instruction::ClassID::Call ||
+           Opc == Instruction::ClassID::Invoke ||
+           Opc == Instruction::ClassID::CallBr;
+  }
+
+  /// \Returns the LLVM IR function type of this call.
+  FunctionType *getFunctionType() const {
+    return cast<llvm::CallBase>(Val)->getFunctionType();
+  }
+
+  /// The data operands are the call arguments followed by the operands of all
+  /// operand bundles. Their count mirrors the underlying llvm::CallBase.
+  op_iterator data_operands_begin() { return op_begin(); }
+  const_op_iterator data_operands_begin() const {
+    return const_cast<CallBase *>(this)->data_operands_begin();
+  }
+  op_iterator data_operands_end() {
+    auto *LLVMCB = cast<llvm::CallBase>(Val);
+    auto Dist = LLVMCB->data_operands_end() - LLVMCB->data_operands_begin();
+    return op_begin() + Dist;
+  }
+  const_op_iterator data_operands_end() const {
+    // Delegate, like data_operands_begin() const, to keep one implementation.
+    return const_cast<CallBase *>(this)->data_operands_end();
+  }
+  iterator_range<op_iterator> data_ops() {
+    return make_range(data_operands_begin(), data_operands_end());
+  }
+  iterator_range<const_op_iterator> data_ops() const {
+    return make_range(data_operands_begin(), data_operands_end());
+  }
+  bool data_operands_empty() const {
+    return data_operands_end() == data_operands_begin();
+  }
+  unsigned data_operands_size() const {
+    return data_operands_end() - data_operands_begin();
+  }
+  /// \Returns true if \p U is one of this call's data operands. \p U must be a
+  /// use of this instruction.
+  bool isDataOperand(Use U) const {
+    assert(this == U.getUser() &&
+           "Only valid to query with a use of this instruction!");
+    return cast<llvm::CallBase>(Val)->isDataOperand(U.LLVMUse);
+  }
+  /// \Returns the position of data operand \p U among the data operands.
+  unsigned getDataOperandNo(Use U) const {
+    assert(isDataOperand(U) && "Data operand # out of range!");
+    return cast<llvm::CallBase>(Val)->getDataOperandNo(U.LLVMUse);
+  }
+
+  /// \Returns the total number of operands (not operand bundles) used by
+  /// every operand bundle in this call.
+  unsigned getNumTotalBundleOperands() const {
+    return cast<llvm::CallBase>(Val)->getNumTotalBundleOperands();
+  }
+
+  /// The call arguments: the data operands minus the trailing bundle operands.
+  op_iterator arg_begin() { return op_begin(); }
+  const_op_iterator arg_begin() const { return op_begin(); }
+  op_iterator arg_end() {
+    return data_operands_end() - getNumTotalBundleOperands();
+  }
+  const_op_iterator arg_end() const {
+    return const_cast<CallBase *>(this)->arg_end();
+  }
+  iterator_range<op_iterator> args() {
+    return make_range(arg_begin(), arg_end());
+  }
+  iterator_range<const_op_iterator> args() const {
+    return make_range(arg_begin(), arg_end());
+  }
+  bool arg_empty() const { return arg_end() == arg_begin(); }
+  unsigned arg_size() const { return arg_end() - arg_begin(); }
+
+  /// \Returns the \p OpIdx'th call argument.
+  Value *getArgOperand(unsigned OpIdx) const {
+    assert(OpIdx < arg_size() && "Out of bounds!");
+    return getOperand(OpIdx);
+  }
+  /// Replaces the \p OpIdx'th call argument with \p NewOp.
+  void setArgOperand(unsigned OpIdx, Value *NewOp) {
+    assert(OpIdx < arg_size() && "Out of bounds!");
+    setOperand(OpIdx, NewOp);
+  }
+
+  /// \Returns the use of the \p Idx'th call argument.
+  Use getArgOperandUse(unsigned Idx) const {
+    assert(Idx < arg_size() && "Out of bounds!");
+    return getOperandUse(Idx);
+  }
+  Use getArgOperandUse(unsigned Idx) {
+    return static_cast<const CallBase *>(this)->getArgOperandUse(Idx);
+  }
+
+  bool isArgOperand(Use U) const {
+    return cast<llvm::CallBase>(Val)->isArgOperand(U.LLVMUse);
+  }
+  unsigned getArgOperandNo(Use U) const {
+    return cast<llvm::CallBase>(Val)->getArgOperandNo(U.LLVMUse);
+  }
+  /// \Returns true if \p V is passed as one of the call arguments.
+  bool hasArgument(const Value *V) const { return is_contained(args(), V); }
+
+  Value *getCalledOperand() const;
+  Use getCalledOperandUse() const;
+
+  /// \Returns the called Function, or null if the callee is not a Function
+  /// (e.g. an indirect call).
+  Function *getCalledFunction() const;
+  bool isIndirectCall() const {
+    return cast<llvm::CallBase>(Val)->isIndirectCall();
+  }
+  bool isCallee(Use U) const {
+    return cast<llvm::CallBase>(Val)->isCallee(U.LLVMUse);
+  }
+  /// \Returns the Function that contains this call.
+  Function *getCaller();
+  const Function *getCaller() const {
+    return const_cast<CallBase *>(this)->getCaller();
+  }
+  bool isMustTailCall() const {
+    return cast<llvm::CallBase>(Val)->isMustTailCall();
+  }
+  bool isTailCall() const { return cast<llvm::CallBase>(Val)->isTailCall(); }
+  Intrinsic::ID getIntrinsicID() const {
+    return cast<llvm::CallBase>(Val)->getIntrinsicID();
+  }
+  /// Sets the callee to \p V. The update goes through Use::set(), so it is
+  /// recorded when the tracker is tracking.
+  void setCalledOperand(Value *V) { getCalledOperandUse().set(V); }
+  /// Sets the callee to \p F and the call's function type to F's type.
+  /// NOTE(review): only the callee operand update is tracked. Reverting
+  /// presumably leaves the new function type in place; TODO confirm.
+  void setCalledFunction(Function *F);
+  CallingConv::ID getCallingConv() const {
+    return cast<llvm::CallBase>(Val)->getCallingConv();
+  }
+  bool isInlineAsm() const { return cast<llvm::CallBase>(Val)->isInlineAsm(); }
+};
+
+class CallInst final : public CallBase {
+  /// Use CallInst::create() instead of calling the constructor directly.
+  CallInst(llvm::Instruction *I, Context &Ctx)
+      : CallBase(ClassID::Call, Opcode::Call, I, Ctx) {}
+  friend class Context; // For accessing the constructor in create*()
+  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+    return getOperandUseDefault(OpIdx, Verify);
+  }
+  SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
+    return {cast<llvm::Instruction>(Val)};
+  }
+
+public:
+  /// Creates a call to \p Func with function type \p FTy and arguments
+  /// \p Args. The call is inserted before \p WhereIt in \p WhereBB, or appended
+  /// to \p WhereBB when \p WhereIt is WhereBB->end().
+  static CallInst *create(FunctionType *FTy, Value *Func,
+                          ArrayRef<Value *> Args, BBIterator WhereIt,
+                          BasicBlock *WhereBB, Context &Ctx,
+                          const Twine &NameStr = "");
+  /// Creates a call inserted right before \p InsertBefore.
+  static CallInst *create(FunctionType *FTy, Value *Func,
+                          ArrayRef<Value *> Args, Instruction *InsertBefore,
+                          Context &Ctx, const Twine &NameStr = "");
+  /// Creates a call appended at the end of \p InsertAtEnd.
+  static CallInst *create(FunctionType *FTy, Value *Func,
+                          ArrayRef<Value *> Args, BasicBlock *InsertAtEnd,
+                          Context &Ctx, const Twine &NameStr = "");
+
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::Call;
+  }
+  unsigned getUseOperandNo(const Use &Use) const final {
+    return getUseOperandNoDefault(Use);
+  }
+  /// A CallInst maps to exactly one LLVM IR instruction.
+  unsigned getNumOfIRInstrs() const final { return 1u; }
+#ifndef NDEBUG
+  void verify() const final {
+    assert(isa<llvm::CallInst>(Val) && "Expected CallInst!");
+  }
+  void dump(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const override;
+#endif
+};
+
/// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
/// an OpaqueInstr.
class OpaqueInst : public sandboxir::Instruction {
@@ -983,6 +1177,8 @@ class Context {
friend StoreInst; // For createStoreInst()
ReturnInst *createReturnInst(llvm::ReturnInst *I);
friend ReturnInst; // For createReturnInst()
+ CallInst *createCallInst(llvm::CallInst *I);
+ friend CallInst; // For createCallInst()
public:
Context(LLVMContext &LLVMCtx)
@@ -1010,7 +1206,7 @@ class Context {
size_t getNumValues() const { return LLVMValueToValueMap.size(); }
};
-class Function : public sandboxir::Value {
+class Function : public Constant {
/// Helper for mapped_iterator.
struct LLVMBBToBB {
Context &Ctx;
@@ -1021,7 +1217,7 @@ class Function : public sandboxir::Value {
};
/// Use Context::createFunction() instead.
Function(llvm::Function *F, sandboxir::Context &Ctx)
- : sandboxir::Value(ClassID::Function, F, Ctx) {}
+ : Constant(ClassID::Function, F, Ctx) {}
friend class Context; // For constructor.
public:
@@ -1047,6 +1243,9 @@ class Function : public sandboxir::Value {
LLVMBBToBB BBGetter(Ctx);
return iterator(cast<llvm::Function>(Val)->end(), BBGetter);
}
+  /// \Returns the LLVM IR type of this function.
+  FunctionType *getFunctionType() const {
+    auto *LLVMF = cast<llvm::Function>(Val);
+    return LLVMF->getFunctionType();
+  }
#ifndef NDEBUG
void verify() const final {
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index f3d616774b3fd..5f6fc84fc2e07 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -30,6 +30,9 @@ DEF_INSTR(Br, OP(Br), BranchInst)
DEF_INSTR(Load, OP(Load), LoadInst)
DEF_INSTR(Store, OP(Store), StoreInst)
DEF_INSTR(Ret, OP(Ret), ReturnInst)
+DEF_INSTR(Call, OP(Call), CallInst)
+DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
+DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
#ifdef DEF_VALUE
#undef DEF_VALUE
diff --git a/llvm/include/llvm/SandboxIR/Use.h b/llvm/include/llvm/SandboxIR/Use.h
index 03cbfe6cb0446..d30eb90594294 100644
--- a/llvm/include/llvm/SandboxIR/Use.h
+++ b/llvm/include/llvm/SandboxIR/Use.h
@@ -21,6 +21,7 @@ namespace llvm::sandboxir {
class Context;
class Value;
class User;
+class CallBase;
/// Represents a Def-use/Use-def edge in SandboxIR.
/// NOTE: Unlike llvm::Use, this is not an integral part of the use-def chains.
@@ -40,6 +41,7 @@ class Use {
friend class User; // For constructor
friend class OperandUseIterator; // For constructor
friend class UserUseIterator; // For accessing members
+ friend class CallBase; // For LLVMUse
public:
operator Value *() const { return get(); }
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index ceadb34f53eaf..da482765c7d11 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -16,7 +16,12 @@ using namespace llvm::sandboxir;
Value *Use::get() const { return Ctx->getValue(LLVMUse->get()); }
-void Use::set(Value *V) { LLVMUse->set(V->Val); }
+/// Points this use at \p V. When the tracker is tracking, the change is first
+/// recorded as a UseSet so it can be reverted.
+void Use::set(Value *V) {
+  if (auto &Tracker = Ctx->getTracker(); Tracker.isTracking())
+    Tracker.track(std::make_unique<UseSet>(*this, Tracker));
+  LLVMUse->set(V->Val);
+}
unsigned Use::getOperandNo() const { return Usr->getUseOperandNo(*this); }
@@ -84,6 +89,25 @@ UserUseIterator &UserUseIterator::operator++() {
return *this;
}
+/// \Returns an iterator \p Num operands after this one. Landing exactly one
+/// past the last operand yields the end iterator.
+OperandUseIterator OperandUseIterator::operator+(unsigned Num) const {
+  // `Verify=true` presumably restricts the index to existing operands, which
+  // makes `op_begin() + getNumOperands()` assert. Build the Use unverified and
+  // enforce the (inclusive) end bound here instead.
+  unsigned NewOpNo = Use.getOperandNo() + Num;
+  assert(NewOpNo <= Use.getUser()->getNumOperands() && "Out of bounds!");
+  sandboxir::Use U =
+      Use.getUser()->getOperandUseInternal(NewOpNo, /*Verify=*/false);
+  return OperandUseIterator(U);
+}
+
+/// \Returns an iterator \p Num operands before this one.
+/// NOTE(review): this relies on Use::getOperandNo(), which may not be valid for
+/// an end iterator. Confirm before using `op_end() - N`.
+OperandUseIterator OperandUseIterator::operator-(unsigned Num) const {
+  assert(Use.getOperandNo() >= Num && "Out of bounds!");
+  sandboxir::Use U = Use.getUser()->getOperandUseInternal(
+      Use.getOperandNo() - Num, /*Verify=*/true);
+  return OperandUseIterator(U);
+}
+
+/// \Returns the difference between the operand numbers of this iterator and
+/// \p Other, which must iterate over the same user's operands.
+int OperandUseIterator::operator-(const OperandUseIterator &Other) const {
+  assert(Use.getUser() == Other.Use.getUser() &&
+         "Iterators point to operands of different users!");
+  int ThisOpNo = Use.getOperandNo();
+  int OtherOpNo = Other.Use.getOperandNo();
+  return ThisOpNo - OtherOpNo;
+}
+
Value::Value(ClassID SubclassID, llvm::Value *Val, Context &Ctx)
: SubclassID(SubclassID), Val(Val), Ctx(Ctx) {
#ifndef NDEBUG
@@ -713,6 +737,78 @@ void ReturnInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}
+#endif // NDEBUG
+
+Value *CallBase::getCalledOperand() const {
+  return Ctx.getValue(cast<llvm::CallBase>(Val)->getCalledOperand());
+}
+
+Use CallBase::getCalledOperandUse() const {
+  // The callee's llvm::Use belongs to this very call. Use `this` as the user
+  // directly instead of looking it up in the Context's value map.
+  llvm::Use *LLVMUse = &cast<llvm::CallBase>(Val)->getCalledOperandUse();
+  return Use(LLVMUse, const_cast<CallBase *>(this), Ctx);
+}
+
+Function *CallBase::getCalledFunction() const {
+  // Null when the callee is not a Function (e.g. indirect calls).
+  return cast_or_null<Function>(
+      Ctx.getValue(cast<llvm::CallBase>(Val)->getCalledFunction()));
+}
+Function *CallBase::getCaller() {
+  return cast<Function>(Ctx.getValue(cast<llvm::CallBase>(Val)->getCaller()));
+}
+
+void CallBase::setCalledFunction(Function *F) {
+  // F's function type is private, so we rely on `setCalledFunction()` to update
+  // it. But even though we are calling `setCalledFunction()` we also need to
+  // track this change at the SandboxIR level, which is why we call
+  // `setCalledOperand()` here.
+  // Note: This may break if `setCalledFunction()` early returns if `F`
+  // is already set, but we do have a unit test for it.
+  // NOTE(review): only the callee operand change is recorded. The function
+  // type update below is applied directly to the LLVM instruction, so a revert
+  // presumably does not restore the previous function type. TODO confirm.
+  setCalledOperand(F);
+  cast<llvm::CallBase>(Val)->setCalledFunction(F->getFunctionType(),
+                                               cast<llvm::Function>(F->Val));
+}
+
+/// Builds the LLVM call with the Context's IRBuilder and wraps it.
+CallInst *CallInst::create(FunctionType *FTy, Value *Func,
+                           ArrayRef<Value *> Args, BasicBlock::iterator WhereIt,
+                           BasicBlock *WhereBB, Context &Ctx,
+                           const Twine &NameStr) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  // Append to WhereBB when WhereIt is its end(), otherwise insert before the
+  // topmost LLVM instruction of the instruction WhereIt points to.
+  if (WhereIt == WhereBB->end())
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  else
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+  // Lower each SandboxIR argument to the LLVM value it wraps.
+  SmallVector<llvm::Value *> LLVMArgs(llvm::map_range(
+      Args, [](Value *Arg) -> llvm::Value * { return Arg->Val; }));
+  llvm::CallInst *NewLLVMCall =
+      Builder.CreateCall(FTy, Func->Val, LLVMArgs, NameStr);
+  return Ctx.createCallInst(NewLLVMCall);
+}
+
+CallInst *CallInst::create(FunctionType *FTy, Value *Func,
+                           ArrayRef<Value *> Args, Instruction *InsertBefore,
+                           Context &Ctx, const Twine &NameStr) {
+  BasicBlock *ParentBB = InsertBefore->getParent();
+  return create(FTy, Func, Args, /*WhereIt=*/InsertBefore->getIterator(),
+                /*WhereBB=*/ParentBB, Ctx, NameStr);
+}
+
+CallInst *CallInst::create(FunctionType *FTy, Value *Func,
+                           ArrayRef<Value *> Args, BasicBlock *InsertAtEnd,
+                           Context &Ctx, const Twine &NameStr) {
+  return create(FTy, Func, Args, /*WhereIt=*/InsertAtEnd->end(),
+                /*WhereBB=*/InsertAtEnd, Ctx, NameStr);
+}
+
+#ifndef NDEBUG
+/// Prints only the common instruction prefix and suffix.
+void CallInst::dump(raw_ostream &OS) const {
+  dumpCommonPrefix(OS);
+  dumpCommonSuffix(OS);
+}
+
+void CallInst::dump() const {
+  raw_ostream &OS = dbgs();
+  dump(OS);
+  OS << "\n";
+}
void OpaqueInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
@@ -819,7 +915,10 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
return It->second.get();
if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
- It->second = std::unique_ptr<Constant>(new Constant(C, *this));
+ if (auto *F = dyn_cast<llvm::Function>(LLVMV))
+ It->second = std::unique_ptr<Function>(new Function(F, *this));
+ else
+ It->second = std::unique_ptr<Constant>(new Constant(C, *this));
auto *NewC = It->second.get();
for (llvm::Value *COp : C->operands())
getOrCreateValueInternal(COp, C);
@@ -864,6 +963,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this));
return It->second.get();
}
+ case llvm::Instruction::Call: {
+ auto *LLVMCall = cast<llvm::CallInst>(LLVMV);
+ It->second = std::unique_ptr<CallInst>(new CallInst(LLVMCall, *this));
+ return It->second.get();
+ }
default:
break;
}
@@ -907,6 +1011,11 @@ ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) {
return cast<ReturnInst>(registerValue(std::move(NewPtr)));
}
+/// Wraps \p I in a new CallInst and registers it with this Context.
+CallInst *Context::createCallInst(llvm::CallInst *I) {
+  std::unique_ptr<CallInst> NewCI(new CallInst(I, *this));
+  auto *Registered = registerValue(std::move(NewCI));
+  return cast<CallInst>(Registered);
+}
+
Value *Context::getValue(llvm::Value *V) const {
auto It = LLVMValueToValueMap.find(V);
if (It != LLVMValueToValueMap.end())
@@ -917,13 +1026,13 @@ Value *Context::getValue(llvm::Value *V) const {
Function *Context::createFunction(llvm::Function *F) {
assert(getValue(F) == nullptr && "Already exists!");
auto NewFPtr = std::unique_ptr<Function>(new Function(F, *this));
+ // Register the Function before creating its body. Instructions inside it may
+ // refer to this very function (e.g. a recursive call), and
+ // getOrCreateValueInternal() creates a new Function for any llvm::Function
+ // that is not registered yet.
+ auto *SBF = cast<Function>(registerValue(std::move(NewFPtr)));
// Create arguments.
for (auto &Arg : F->args())
getOrCreateArgument(&Arg);
// Create BBs.
for (auto &BB : *F)
createBasicBlock(&BB);
- auto *SBF = cast<Function>(registerValue(std::move(NewFPtr)));
return SBF;
}
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index c600103fe10c6..05ec42c952eb6 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -90,7 +90,7 @@ define void @foo(i32 %v1) {
EXPECT_FALSE(isa<sandboxir::Instruction>(Const0));
EXPECT_TRUE(isa<sandboxir::Instruction>(OpaqueI));
- EXPECT_FALSE(isa<sandboxir::User>(F));
+ EXPECT_TRUE(isa<sandboxir::User>(F));
EXPECT_FALSE(isa<sandboxir::User>(Arg0));
EXPECT_FALSE(isa<sandboxir::User>(BB));
EXPECT_TRUE(isa<sandboxir::User>(AddI));
@@ -180,8 +180,8 @@ define i32 @foo(i32 %v0, i32 %v1) {
BS << "\n";
I0->getOperandUse(0).dump(BS);
EXPECT_EQ(Buff, R"IR(
-Def: i32 %v0 ; SB1. (Argument)
-User: %add0 = add i32 %v0, %v1 ; SB4. (Opaque)
+Def: i32 %v0 ; SB2. (Argument)
+User: %add0 = add i32 %v0, %v1 ; SB5. (Opaque)
OperandNo: 0
)IR");
#endif // NDEBUG
@@ -398,10 +398,10 @@ define void @foo(i32 %arg0, i32 %arg1) {
EXPECT_EQ(Buff, R"IR(
void @foo(i32 %arg0, i32 %arg1) {
bb0:
- br label %bb1 ; SB3. (Br)
+ br label %bb1 ; SB4. (Br)
bb1:
- ret void ; SB5. (Ret)
+ ret void ; SB6. (Ret)
}
)IR");
}
@@ -466,7 +466,7 @@ define void @foo(i32 %v1) {
BB0.dump(BS);
EXPECT_EQ(Buff, R"IR(
bb0:
- br label %bb1 ; SB2. (Br)
+ br label %bb1 ; SB3. (Br)
)IR");
}
#endif // NDEBUG
@@ -836,3 +836,203 @@ define i8 @foo(i8 %val) {
sandboxir::ReturnInst::create(Val, /*InsertAtEnd=*/BB, Ctx));
EXPECT_EQ(NewRet4->getReturnValue(), Val);
}
+
+// Checks CallBase's getters against llvm::CallBase, then exercises its setters.
+TEST_F(SandboxIRTest, CallBase) {
+  parseIR(C, R"IR(
+declare void @bar1(i8)
+declare void @bar2()
+declare void @bar3()
+declare void @variadic(ptr, ...)
+
+define i8 @foo(i8 %arg0, i32 %arg1, ptr %indirectFoo) {
+  %call = call i8 @foo(i8 %arg0, i32 %arg1)
+  call void @bar1(i8 %arg0)
+  call void @bar2()
+  call void %indirectFoo()
+  call void @bar2() noreturn
+  tail call fastcc void @bar2()
+  call void (ptr, ...) @variadic(ptr %indirectFoo, i32 1)
+  ret i8 %call
+}
+)IR");
+  llvm::Function &LLVMF = *M->getFunction("foo");
+  unsigned ArgIdx = 0;
+  llvm::Argument *LLVMArg0 = LLVMF.getArg(ArgIdx++);
+  llvm::Argument *LLVMArg1 = LLVMF.getArg(ArgIdx++);
+  llvm::BasicBlock *LLVMBB = &*LLVMF.begin();
+  SmallVector<llvm::CallBase *, 8> LLVMCalls;
+  auto LLVMIt = LLVMBB->begin();
+  while (isa<llvm::CallBase>(&*LLVMIt))
+    LLVMCalls.push_back(cast<llvm::CallBase>(&*LLVMIt++));
+
+  sandboxir::Context Ctx(C);
+  sandboxir::Function &F = *Ctx.createFunction(&LLVMF);
+
+  for (llvm::CallBase *LLVMCall : LLVMCalls) {
+    // Check classof(Instruction *).
+    auto *Call = cast<sandboxir::CallBase>(Ctx.getValue(LLVMCall));
+    // Check classof(Value *).
+    EXPECT_TRUE(isa<sandboxir::CallBase>((sandboxir::Value *)Call));
+    // Check getFunctionType().
+    EXPECT_EQ(Call->getFunctionType(), LLVMCall->getFunctionType());
+    // Check data_ops().
+    EXPECT_EQ(range_size(Call->data_ops()), range_size(LLVMCall->data_ops()));
+    auto DataOpIt = Call->data_operands_begin();
+    for (llvm::Use &LLVMUse : LLVMCall->data_ops()) {
+      Value *LLVMOp = LLVMUse.get();
+      sandboxir::Use Use = *DataOpIt++;
+      EXPECT_EQ(Ctx.getValue(LLVMOp), Use.get());
+      // Check isDataOperand().
+      EXPECT_EQ(Call->isDataOperand(Use), LLVMCall->isDataOperand(&LLVMUse));
+      // Check getDataOperandNo().
+      EXPECT_EQ(Call->getDataOperandNo(Use),
+                LLVMCall->getDataOperandNo(&LLVMUse));
+      // Check isArgOperand().
+      EXPECT_EQ(Call->isArgOperand(Use), LLVMCall->isArgOperand(&LLVMUse));
+      // Check isCallee().
+      EXPECT_EQ(Call->isCallee(Use), LLVMCall->isCallee(&LLVMUse));
+    }
+    // Check data_operands_empty().
+    EXPECT_EQ(Call->data_operands_empty(), LLVMCall->data_operands_empty());
+    // Check data_operands_size().
+    EXPECT_EQ(Call->data_operands_size(), LLVMCall->data_operands_size());
+    // Check getNumTotalBundleOperands().
+    EXPECT_EQ(Call->getNumTotalBundleOperands(),
+              LLVMCall->getNumTotalBundleOperands());
+    // Check args().
+    EXPECT_EQ(range_size(Call->args()), range_size(LLVMCall->args()));
+    auto ArgIt = Call->arg_begin();
+    for (llvm::Use &LLVMUse : LLVMCall->args()) {
+      Value *LLVMArg = LLVMUse.get();
+      sandboxir::Use Use = *ArgIt++;
+      EXPECT_EQ(Ctx.getValue(LLVMArg), Use.get());
+    }
+    // Check arg_empty().
+    EXPECT_EQ(Call->arg_empty(), LLVMCall->arg_empty());
+    // Check arg_size().
+    EXPECT_EQ(Call->arg_size(), LLVMCall->arg_size());
+    for (unsigned ArgIdx = 0, E = Call->arg_size(); ArgIdx != E; ++ArgIdx) {
+      // Check getArgOperand().
+      EXPECT_EQ(Call->getArgOperand(ArgIdx),
+                Ctx.getValue(LLVMCall->getArgOperand(ArgIdx)));
+      // Check getArgOperandUse().
+      sandboxir::Use Use = Call->getArgOperandUse(ArgIdx);
+      llvm::Use &LLVMUse = LLVMCall->getArgOperandUse(ArgIdx);
+      EXPECT_EQ(Use.get(), Ctx.getValue(LLVMUse.get()));
+      // Check getArgOperandNo().
+      EXPECT_EQ(Call->getArgOperandNo(Use),
+                LLVMCall->getArgOperandNo(&LLVMUse));
+    }
+    // Check hasArgument().
+    SmallVector<llvm::Value *> TestArgs(
+        {LLVMArg0, LLVMArg1, &LLVMF, LLVMBB, LLVMCall});
+    for (llvm::Value *LLVMV : TestArgs) {
+      sandboxir::Value *V = Ctx.getValue(LLVMV);
+      EXPECT_EQ(Call->hasArgument(V), LLVMCall->hasArgument(LLVMV));
+    }
+    // Check getCalledOperand().
+    EXPECT_EQ(Call->getCalledOperand(),
+              Ctx.getValue(LLVMCall->getCalledOperand()));
+    // Check getCalledOperandUse().
+    EXPECT_EQ(Call->getCalledOperandUse().get(),
+              Ctx.getValue(LLVMCall->getCalledOperandUse()));
+    // Check getCalledFunction().
+    if (LLVMCall->getCalledFunction() == nullptr)
+      EXPECT_EQ(Call->getCalledFunction(), nullptr);
+    else
+      EXPECT_EQ(Call->getCalledFunction(),
+                cast<sandboxir::Function>(
+                    Ctx.getValue(LLVMCall->getCalledFunction())));
+    // Check isIndirectCall().
+    EXPECT_EQ(Call->isIndirectCall(), LLVMCall->isIndirectCall());
+    // Check getCaller().
+    EXPECT_EQ(Call->getCaller(), Ctx.getValue(LLVMCall->getCaller()));
+    // Check isMustTailCall().
+    EXPECT_EQ(Call->isMustTailCall(), LLVMCall->isMustTailCall());
+    // Check isTailCall().
+    EXPECT_EQ(Call->isTailCall(), LLVMCall->isTailCall());
+    // Check getIntrinsicID().
+    EXPECT_EQ(Call->getIntrinsicID(), LLVMCall->getIntrinsicID());
+    // Check getCallingConv().
+    EXPECT_EQ(Call->getCallingConv(), LLVMCall->getCallingConv());
+    // Check isInlineAsm().
+    EXPECT_EQ(Call->isInlineAsm(), LLVMCall->isInlineAsm());
+  }
+
+  auto *Arg0 = F.getArg(0);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *Call0 = cast<sandboxir::CallBase>(&*It++); // i8 @foo(i8, i32)
+  auto *Call1 = cast<sandboxir::CallBase>(&*It++); // void @bar1(i8)
+  auto *Call2 = cast<sandboxir::CallBase>(&*It++); // void @bar2()
+  // Check setArgOperand(). Replace @bar1's i8 argument with Call0's i8 result
+  // so that the IR stays well-typed, then restore it.
+  Call1->setArgOperand(0, Call0);
+  EXPECT_EQ(Call1->getArgOperand(0), Call0);
+  Call1->setArgOperand(0, Arg0);
+  EXPECT_EQ(Call1->getArgOperand(0), Arg0);
+
+  auto *Bar3F = Ctx.createFunction(M->getFunction("bar3"));
+
+  // Check setCalledOperand(). @bar2 and @bar3 have the same type.
+  auto *SvOp = Call2->getCalledOperand();
+  Call2->setCalledOperand(Bar3F);
+  EXPECT_EQ(Call2->getCalledOperand(), Bar3F);
+  Call2->setCalledOperand(SvOp);
+  EXPECT_EQ(Call2->getCalledOperand(), SvOp);
+  // Check setCalledFunction().
+  Call2->setCalledFunction(Bar3F);
+  EXPECT_EQ(Call2->getCalledFunction(), Bar3F);
+}
+
+// Checks CallInst basics and all three CallInst::create() overloads.
+TEST_F(SandboxIRTest, CallInst) {
+  parseIR(C, R"IR(
+define i8 @foo(i8 %arg) {
+  %call = call i8 @foo(i8 %arg)
+  ret i8 %call
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *Arg0 = F.getArg(0);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *Call = cast<sandboxir::CallInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+  EXPECT_EQ(Call->getNumOperands(), 2u);
+  EXPECT_EQ(Ret->getOpcode(), sandboxir::Instruction::Opcode::Ret);
+  FunctionType *FTy = F.getFunctionType();
+  SmallVector<sandboxir::Value *, 1> Args;
+  Args.push_back(Arg0);
+  {
+    // Check create() WhereIt.
+    auto *NewCall = sandboxir::CallInst::create(
+        FTy, &F, Args, /*WhereIt=*/Ret->getIterator(), BB, Ctx);
+    EXPECT_EQ(NewCall->getNextNode(), Ret);
+    EXPECT_EQ(NewCall->getCalledFunction(), &F);
+    EXPECT_EQ(range_size(NewCall->args()), 1u);
+    EXPECT_EQ(NewCall->getArgOperand(0), Arg0);
+  }
+  {
+    // Check create() InsertBefore.
+    auto *NewCall =
+        sandboxir::CallInst::create(FTy, &F, Args, /*InsertBefore=*/Ret, Ctx);
+    EXPECT_EQ(NewCall->getNextNode(), Ret);
+    EXPECT_EQ(NewCall->getCalledFunction(), &F);
+    EXPECT_EQ(range_size(NewCall->args()), 1u);
+    EXPECT_EQ(NewCall->getArgOperand(0), Arg0);
+  }
+  {
+    // Check create() InsertAtEnd.
+    auto *NewCall =
+        sandboxir::CallInst::create(FTy, &F, Args, /*InsertAtEnd=*/BB, Ctx);
+    EXPECT_EQ(NewCall->getPrevNode(), Ret);
+    EXPECT_EQ(NewCall->getCalledFunction(), &F);
+    EXPECT_EQ(range_size(NewCall->args()), 1u);
+    EXPECT_EQ(NewCall->getArgOperand(0), Arg0);
+  }
+}
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index dd9dcd543236e..5111d5f38798f 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -69,6 +69,34 @@ define void @foo(ptr %ptr) {
EXPECT_EQ(Ld->getOperand(0), Gep0);
}
+// Checks that Use::set() is tracked and undone by Context::revert().
+TEST_F(TrackerTest, SetUse) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %arg) {
+  %ld = load i8, ptr %ptr
+  %add = add i8 %ld, %arg
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(&LLVMF);
+  // Use the i8 argument (%arg) so the replaced `add` operand stays i8.
+  auto *ArgI8 = F->getArg(1);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *Ld = &*It++;
+  auto *Add = &*It++;
+
+  Ctx.save();
+  sandboxir::Use Use = Add->getOperandUse(0);
+  Use.set(ArgI8);
+  EXPECT_EQ(Add->getOperand(0), ArgI8);
+  Ctx.revert();
+  EXPECT_EQ(Add->getOperand(0), Ld);
+}
+
TEST_F(TrackerTest, SwapOperands) {
parseIR(C, R"IR(
define void @foo(i1 %cond) {
@@ -413,3 +441,50 @@ define i32 @foo(i32 %arg) {
EXPECT_EQ(&*It++, Ret);
EXPECT_EQ(It, BB->end());
}
+
+// Checks that the CallBase setters are recorded by the tracker and undone by
+// Context::revert().
+TEST_F(TrackerTest, CallBaseSetters) {
+ parseIR(C, R"IR(
+declare void @bar1(i8)
+declare void @bar2(i8)
+
+define void @foo(i8 %arg0, i8 %arg1) {
+ call void @bar1(i8 %arg0)
+ ret void
+}
+)IR");
+ Function &LLVMF = *M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+
+ auto *F = Ctx.createFunction(&LLVMF);
+ unsigned ArgIdx = 0;
+ auto *Arg0 = F->getArg(ArgIdx++);
+ auto *Arg1 = F->getArg(ArgIdx++);
+ auto *BB = &*F->begin();
+ auto It = BB->begin();
+ auto *Call = cast<sandboxir::CallBase>(&*It++);
+ [[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+ // Check setArgOperand().
+ Ctx.save();
+ Call->setArgOperand(0, Arg1);
+ EXPECT_EQ(Call->getArgOperand(0), Arg1);
+ Ctx.revert();
+ EXPECT_EQ(Call->getArgOperand(0), Arg0);
+
+ auto *Bar1F = Call->getCalledFunction();
+ auto *Bar2F = Ctx.createFunction(M->getFunction("bar2"));
+
+ // Check setCalledOperand().
+ Ctx.save();
+ Call->setCalledOperand(Bar2F);
+ EXPECT_EQ(Call->getCalledOperand(), Bar2F);
+ Ctx.revert();
+ EXPECT_EQ(Call->getCalledOperand(), Bar1F);
+
+ // Check setCalledFunction().
+ // NOTE(review): @bar1 and @bar2 have the same type, so this does not check
+ // that reverting restores the call's function type; consider a callee with a
+ // different signature.
+ Ctx.save();
+ Call->setCalledFunction(Bar2F);
+ EXPECT_EQ(Call->getCalledFunction(), Bar2F);
+ Ctx.revert();
+ EXPECT_EQ(Call->getCalledFunction(), Bar1F);
+}
More information about the llvm-commits
mailing list