[llvm] [SandboxIR] Implement SelectInst (PR #99996)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 22 16:35:38 PDT 2024
https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/99996
>From 388553b3061657b7db39673338faf97084095e12 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 22 Jul 2024 15:18:02 -0700
Subject: [PATCH 1/2] [SandboxIR] Implement SelectInst
This patch implements sandboxir::SelectInst, which mirrors llvm::SelectInst.
---
llvm/include/llvm/SandboxIR/SandboxIR.h | 51 ++++++++++++++++++
.../llvm/SandboxIR/SandboxIRValues.def | 1 +
llvm/lib/SandboxIR/SandboxIR.cpp | 54 +++++++++++++++++++
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 35 ++++++++++++
4 files changed, 141 insertions(+)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index cd77897ccbb94..a15fb5b8a7030 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -177,6 +177,7 @@ class Value {
friend class Context; // For getting `Val`.
friend class User; // For getting `Val`.
friend class Use; // For getting `Val`.
+ friend class SelectInst; // For getting `Val`.
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
friend class ReturnInst; // For getting `Val`.
@@ -499,6 +500,7 @@ class Instruction : public sandboxir::User {
/// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
/// returns its topmost LLVM IR instruction.
llvm::Instruction *getTopmostLLVMInstruction() const;
+ friend class SelectInst; // For getTopmostLLVMInstruction().
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
friend class ReturnInst; // For getTopmostLLVMInstruction().
@@ -566,6 +568,49 @@ class Instruction : public sandboxir::User {
#endif
};
+/// SandboxIR counterpart of llvm::SelectInst. Each SelectInst wraps exactly
+/// one LLVM IR instruction (see getNumOfIRInstrs()).
+class SelectInst final : public Instruction {
+  /// Use SelectInst::create() instead of calling the constructor.
+  SelectInst(llvm::SelectInst *SI, Context &Ctx)
+      : Instruction(ClassID::Select, Opcode::Select, SI, Ctx) {}
+  friend Context; // for SelectInst()
+  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+    return getOperandUseDefault(OpIdx, Verify);
+  }
+  SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
+    return {cast<llvm::Instruction>(Val)};
+  }
+
+public:
+  unsigned getUseOperandNo(const Use &Use) const final {
+    return getUseOperandNoDefault(Use);
+  }
+  unsigned getNumOfIRInstrs() const final { return 1u; }
+  /// Creates `select Cond, True, False` right before \p InsertBefore.
+  /// \Returns the new SelectInst. If the LLVM IR builder constant-folds the
+  /// select, it returns the resulting Constant instead, and no instruction is
+  /// inserted.
+  static Value *create(Value *Cond, Value *True, Value *False,
+                       Instruction *InsertBefore, Context &Ctx,
+                       const Twine &Name = "");
+  /// Creates `select Cond, True, False` at the end of \p InsertAtEnd.
+  /// \Returns the new SelectInst. If the LLVM IR builder constant-folds the
+  /// select, it returns the resulting Constant instead, and no instruction is
+  /// inserted.
+  static Value *create(Value *Cond, Value *True, Value *False,
+                       BasicBlock *InsertAtEnd, Context &Ctx,
+                       const Twine &Name = "");
+  /// \Returns the condition operand (operand 0).
+  Value *getCondition() { return getOperand(0); }
+  /// \Returns the value selected when the condition is true (operand 1).
+  Value *getTrueValue() { return getOperand(1); }
+  /// \Returns the value selected when the condition is false (operand 2).
+  Value *getFalseValue() { return getOperand(2); }
+
+  void setCondition(Value *New) { setOperand(0, New); }
+  void setTrueValue(Value *New) { setOperand(1, New); }
+  void setFalseValue(Value *New) { setOperand(2, New); }
+  /// Swaps the true and false operands.
+  /// NOTE(review): this edits the LLVM IR directly through
+  /// llvm::SelectInst::swapValues() instead of going through setOperand().
+  void swapValues() { cast<llvm::SelectInst>(Val)->swapValues(); }
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From);
+#ifndef NDEBUG
+  void verify() const final {
+    assert(isa<llvm::SelectInst>(Val) && "Expected SelectInst!");
+  }
+  void dump(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const override;
+#endif
+};
+
class LoadInst final : public Instruction {
/// Use LoadInst::create() instead of calling the constructor.
LoadInst(llvm::LoadInst *LI, Context &Ctx)
@@ -803,6 +848,10 @@ class Context {
Value *getOrCreateValue(llvm::Value *LLVMV) {
return getOrCreateValueInternal(LLVMV, 0);
}
+  /// Get or create a sandboxir::Constant from an existing LLVM IR \p LLVMC.
+  /// The lookup/creation is delegated to getOrCreateValueInternal() with no
+  /// llvm::User (second argument 0). The result goes through cast<>, so in
+  /// assertion-enabled builds this checks that \p LLVMC maps to a
+  /// sandboxir::Constant.
+  Constant *getOrCreateConstant(llvm::Constant *LLVMC) {
+    return cast<Constant>(getOrCreateValueInternal(LLVMC, 0));
+  }
/// Create a sandboxir::BasicBlock for an existing LLVM IR \p BB. This will
/// also create all contents of the block.
BasicBlock *createBasicBlock(llvm::BasicBlock *BB);
@@ -812,6 +861,8 @@ class Context {
IRBuilder<ConstantFolder> LLVMIRBuilder;
auto &getLLVMIRBuilder() { return LLVMIRBuilder; }
+ SelectInst *createSelectInst(llvm::SelectInst *SI);
+ friend SelectInst; // For createSelectInst()
LoadInst *createLoadInst(llvm::LoadInst *LI);
friend LoadInst; // For createLoadInst()
StoreInst *createStoreInst(llvm::StoreInst *SI);
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index b2f88741af8d9..efa9155755587 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -25,6 +25,7 @@ DEF_USER(Constant, Constant)
#endif
// ClassID, Opcode(s), Class
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
+DEF_INSTR(Select, OP(Select), SelectInst)
DEF_INSTR(Load, OP(Load), LoadInst)
DEF_INSTR(Store, OP(Store), StoreInst)
DEF_INSTR(Ret, OP(Ret), ReturnInst)
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 4cf45fa87693a..b3f444b6f2bc9 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -455,6 +455,50 @@ void Instruction::dump() const {
}
#endif // NDEBUG
+/// Emits an LLVM IR select before the topmost LLVM instruction of
+/// \p InsertBefore and wraps the result in SandboxIR.
+Value *SelectInst::create(Value *Cond, Value *True, Value *False,
+                          Instruction *InsertBefore, Context &Ctx,
+                          const Twine &Name) {
+  llvm::Instruction *LLVMBefore = InsertBefore->getTopmostLLVMInstruction();
+  auto &IRB = Ctx.getLLVMIRBuilder();
+  IRB.SetInsertPoint(LLVMBefore);
+  llvm::Value *Result =
+      IRB.CreateSelect(Cond->Val, True->Val, False->Val, Name);
+  // The builder is an IRBuilder<ConstantFolder>, so constant operands can
+  // fold the select away. In that case the result is a constant.
+  auto *ResultSel = dyn_cast<llvm::SelectInst>(Result);
+  if (!ResultSel) {
+    assert(isa<llvm::Constant>(Result) && "Expected constant");
+    return Ctx.getOrCreateConstant(cast<llvm::Constant>(Result));
+  }
+  return Ctx.createSelectInst(ResultSel);
+}
+
+/// Emits an LLVM IR select at the end of the LLVM block backing
+/// \p InsertAtEnd and wraps the result in SandboxIR.
+Value *SelectInst::create(Value *Cond, Value *True, Value *False,
+                          BasicBlock *InsertAtEnd, Context &Ctx,
+                          const Twine &Name) {
+  auto *LLVMBB = cast<llvm::BasicBlock>(InsertAtEnd->Val);
+  auto &IRB = Ctx.getLLVMIRBuilder();
+  IRB.SetInsertPoint(LLVMBB);
+  llvm::Value *Result =
+      IRB.CreateSelect(Cond->Val, True->Val, False->Val, Name);
+  // The builder is an IRBuilder<ConstantFolder>, so constant operands can
+  // fold the select away. In that case the result is a constant.
+  auto *ResultSel = dyn_cast<llvm::SelectInst>(Result);
+  if (!ResultSel) {
+    assert(isa<llvm::Constant>(Result) && "Expected constant");
+    return Ctx.getOrCreateConstant(cast<llvm::Constant>(Result));
+  }
+  return Ctx.createSelectInst(ResultSel);
+}
+
+/// isa/dyn_cast support: \Returns true if \p From is a sandboxir::SelectInst.
+bool SelectInst::classof(const Value *From) {
+  const auto ID = From->getSubclassID();
+  return ID == ClassID::Select;
+}
+
+#ifndef NDEBUG
+/// Prints this instruction using only the common prefix and suffix;
+/// SelectInst prints nothing of its own in between.
+void SelectInst::dump(raw_ostream &OS) const {
+  dumpCommonPrefix(OS);
+  dumpCommonSuffix(OS);
+}
+
+/// Prints this instruction to dbgs(), followed by a newline.
+void SelectInst::dump() const {
+  raw_ostream &DbgOS = dbgs();
+  dump(DbgOS);
+  DbgOS << "\n";
+}
+#endif // NDEBUG
+
LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
@@ -700,6 +744,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
assert(isa<llvm::Instruction>(LLVMV) && "Expected Instruction");
switch (cast<llvm::Instruction>(LLVMV)->getOpcode()) {
+ case llvm::Instruction::Select: {
+ auto *LLVMSel = cast<llvm::SelectInst>(LLVMV);
+ It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
+ return It->second.get();
+ }
case llvm::Instruction::Load: {
auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
@@ -733,6 +782,11 @@ BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
return BB;
}
+/// Wraps the existing LLVM IR select \p SI in a new sandboxir::SelectInst
+/// and registers it with this Context.
+SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
+  // SelectInst's constructor is private (Context is a friend), so
+  // std::make_unique cannot be used here.
+  std::unique_ptr<SelectInst> Owned(new SelectInst(SI, *this));
+  auto *Registered = registerValue(std::move(Owned));
+  return cast<SelectInst>(Registered);
+}
+
LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
return cast<LoadInst>(registerValue(std::move(NewPtr)));
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index b0d6ae85950d7..0bfd2586e2ca3 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -561,6 +561,41 @@ define void @foo(i8 %v1) {
EXPECT_EQ(I0->getNextNode(), Ret);
}
+// Checks the SelectInst operand getters and setters on a select parsed from
+// LLVM IR.
+TEST_F(SandboxIRTest, SelectInst) {
+  parseIR(C, R"IR(
+define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
+  %sel = select i1 %c0, i8 %v0, i8 %v1
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  sandboxir::Function *F = Ctx.createFunction(LLVMF);
+  // %c1 is not used by the IR; it only serves as the replacement condition
+  // for setCondition() below.
+  auto *Cond0 = F->getArg(0);
+  auto *V0 = F->getArg(1);
+  auto *V1 = F->getArg(2);
+  auto *Cond1 = F->getArg(3);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *Select = cast<sandboxir::SelectInst>(&*It++);
+
+  // Check getCondition().
+  EXPECT_EQ(Select->getCondition(), Cond0);
+  // Check getTrueValue().
+  EXPECT_EQ(Select->getTrueValue(), V0);
+  // Check getFalseValue().
+  EXPECT_EQ(Select->getFalseValue(), V1);
+  // Check setCondition().
+  Select->setCondition(Cond1);
+  EXPECT_EQ(Select->getCondition(), Cond1);
+  // Check setTrueValue().
+  Select->setTrueValue(V1);
+  EXPECT_EQ(Select->getTrueValue(), V1);
+  // Check setFalseValue().
+  Select->setFalseValue(V0);
+  EXPECT_EQ(Select->getFalseValue(), V0);
+}
+
TEST_F(SandboxIRTest, LoadInst) {
parseIR(C, R"IR(
define void @foo(ptr %arg0, ptr %arg1) {
>From 7d785c61c16258c06783c6302b3cc36b1eb60f1c Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 22 Jul 2024 16:35:06 -0700
Subject: [PATCH 2/2] fixup! [SandboxIR] Implement SelectInst
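Add tests to exercise SelectInst::create(), including the case where the select is constant-folded, and add Constant::createInt() so that the tests can build constant operands.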
---
llvm/include/llvm/SandboxIR/SandboxIR.h | 3 ++
llvm/lib/SandboxIR/SandboxIR.cpp | 8 ++++++
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 33 ++++++++++++++++++++++
3 files changed, 44 insertions(+)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index a15fb5b8a7030..4ab4c1002ffbf 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -412,6 +412,8 @@ class Constant : public sandboxir::User {
}
public:
+  /// \Returns the sandboxir::Constant for
+  /// llvm::ConstantInt::get(\p Ty, \p V, \p IsSigned), obtained through
+  /// \p Ctx. \p Ty is presumably expected to be an integer (or
+  /// integer-vector) type, as llvm::ConstantInt::get() requires.
+  static Constant *createInt(Type *Ty, uint64_t V, Context &Ctx,
+                             bool IsSigned = false);
/// For isa/dyn_cast.
static bool classof(const sandboxir::Value *From) {
return From->getSubclassID() == ClassID::Constant ||
@@ -852,6 +854,7 @@ class Context {
Constant *getOrCreateConstant(llvm::Constant *LLVMC) {
return cast<Constant>(getOrCreateValueInternal(LLVMC, 0));
}
+ friend class Constant; // For getOrCreateConstant().
/// Create a sandboxir::BasicBlock for an existing LLVM IR \p BB. This will
/// also create all contents of the block.
BasicBlock *createBasicBlock(llvm::BasicBlock *BB);
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index b3f444b6f2bc9..96fd22c3934bd 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -636,7 +636,15 @@ void OpaqueInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}
+#endif // NDEBUG
+
+/// Builds the LLVM integer constant with llvm::ConstantInt::get() and
+/// returns its sandboxir::Constant wrapper from \p Ctx.
+Constant *Constant::createInt(Type *Ty, uint64_t V, Context &Ctx,
+                              bool IsSigned) {
+  return Ctx.getOrCreateConstant(llvm::ConstantInt::get(Ty, V, IsSigned));
+}
+#ifndef NDEBUG
void Constant::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 0bfd2586e2ca3..ba90b4f811f8e 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -578,6 +578,7 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
auto *BB = &*F->begin();
auto It = BB->begin();
auto *Select = cast<sandboxir::SelectInst>(&*It++);
+ auto *Ret = &*It++;
// Check getCondition().
EXPECT_EQ(Select->getCondition(), Cond0);
@@ -594,6 +595,38 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
// Check setFalseValue().
Select->setFalseValue(V0);
EXPECT_EQ(Select->getFalseValue(), V0);
+
+  {
+    // Check SelectInst::create() InsertBefore.
+    auto *NewSel = cast<sandboxir::SelectInst>(sandboxir::SelectInst::create(
+        Cond0, V0, V1, /*InsertBefore=*/Ret, Ctx));
+    EXPECT_EQ(NewSel->getCondition(), Cond0);
+    EXPECT_EQ(NewSel->getTrueValue(), V0);
+    EXPECT_EQ(NewSel->getFalseValue(), V1);
+    EXPECT_EQ(NewSel->getNextNode(), Ret);
+  }
+  {
+    // Check SelectInst::create() InsertAtEnd.
+    // Note: this appends the select after the `ret` terminator, which is not
+    // valid IR. It is done here only to check the insertion position.
+    auto *NewSel = cast<sandboxir::SelectInst>(
+        sandboxir::SelectInst::create(Cond0, V0, V1, /*InsertAtEnd=*/BB, Ctx));
+    EXPECT_EQ(NewSel->getCondition(), Cond0);
+    EXPECT_EQ(NewSel->getTrueValue(), V0);
+    EXPECT_EQ(NewSel->getFalseValue(), V1);
+    EXPECT_EQ(NewSel->getPrevNode(), Ret);
+  }
+  {
+    // Check SelectInst::create() Folded.
+    // All operands are constants, so the IRBuilder<ConstantFolder> folds the
+    // select and create() returns the selected constant instead of a new
+    // SelectInst.
+    // - The data operands are i8 so that 42 is representable (an i1 cannot
+    //   hold 42).
+    // - The true and false values are distinct, so the test also checks that
+    //   the false operand is the one selected.
+    auto *False =
+        sandboxir::Constant::createInt(llvm::Type::getInt1Ty(C), 0, Ctx,
+                                       /*IsSigned=*/false);
+    auto *FortyOne =
+        sandboxir::Constant::createInt(llvm::Type::getInt8Ty(C), 41, Ctx,
+                                       /*IsSigned=*/false);
+    auto *FortyTwo =
+        sandboxir::Constant::createInt(llvm::Type::getInt8Ty(C), 42, Ctx,
+                                       /*IsSigned=*/false);
+    auto *NewSel =
+        sandboxir::SelectInst::create(False, FortyOne, FortyTwo, Ret, Ctx);
+    EXPECT_TRUE(isa<sandboxir::Constant>(NewSel));
+    EXPECT_EQ(NewSel, FortyTwo);
+  }
}
TEST_F(SandboxIRTest, LoadInst) {
More information about the llvm-commits
mailing list