[llvm] [SandboxIR] Implement SelectInst (PR #99996)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 22 16:35:38 PDT 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/99996

>From 388553b3061657b7db39673338faf97084095e12 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 22 Jul 2024 15:18:02 -0700
Subject: [PATCH 1/2] [SandboxIR] Implement SelectInst

This patch implements sandboxir::SelectInst which mirrors llvm::SelectInst.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h       | 51 ++++++++++++++++++
 .../llvm/SandboxIR/SandboxIRValues.def        |  1 +
 llvm/lib/SandboxIR/SandboxIR.cpp              | 54 +++++++++++++++++++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 35 ++++++++++++
 4 files changed, 141 insertions(+)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index cd77897ccbb94..a15fb5b8a7030 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -177,6 +177,7 @@ class Value {
   friend class Context;    // For getting `Val`.
   friend class User;       // For getting `Val`.
   friend class Use;        // For getting `Val`.
+  friend class SelectInst; // For getting `Val`.
   friend class LoadInst;   // For getting `Val`.
   friend class StoreInst;  // For getting `Val`.
   friend class ReturnInst; // For getting `Val`.
@@ -499,6 +500,7 @@ class Instruction : public sandboxir::User {
   /// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
   /// returns its topmost LLVM IR instruction.
   llvm::Instruction *getTopmostLLVMInstruction() const;
+  friend class SelectInst; // For getTopmostLLVMInstruction().
   friend class LoadInst;   // For getTopmostLLVMInstruction().
   friend class StoreInst;  // For getTopmostLLVMInstruction().
   friend class ReturnInst; // For getTopmostLLVMInstruction().
@@ -566,6 +568,49 @@ class Instruction : public sandboxir::User {
 #endif
 };
 
+class SelectInst final : public Instruction {
+  /// Use SelectInst::create() instead of calling the constructor.
+  SelectInst(llvm::SelectInst *SI, Context &Ctx)
+      : Instruction(ClassID::Select, Opcode::Select, SI, Ctx) {}
+  friend Context; // for SelectInst()
+  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+    return getOperandUseDefault(OpIdx, Verify);
+  }
+  SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
+    return {cast<llvm::Instruction>(Val)};
+  }
+
+public:
+  unsigned getUseOperandNo(const Use &Use) const final {
+    return getUseOperandNoDefault(Use);
+  }
+  unsigned getNumOfIRInstrs() const final { return 1u; }
+  /// \returns a SelectInst, or a Constant if the IRBuilder folds the select.
+  static Value *create(Value *Cond, Value *True, Value *False,
+                       Instruction *InsertBefore, Context &Ctx,
+                       const Twine &Name = "");
+  static Value *create(Value *Cond, Value *True, Value *False,
+                       BasicBlock *InsertAtEnd, Context &Ctx,
+                       const Twine &Name = "");
+  Value *getCondition() { return getOperand(0); }
+  Value *getTrueValue() { return getOperand(1); }
+  Value *getFalseValue() { return getOperand(2); }
+
+  void setCondition(Value *New) { setOperand(0, New); }
+  void setTrueValue(Value *New) { setOperand(1, New); }
+  void setFalseValue(Value *New) { setOperand(2, New); }
+  void swapValues() { cast<llvm::SelectInst>(Val)->swapValues(); }
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From);
+#ifndef NDEBUG
+  void verify() const final {
+    assert(isa<llvm::SelectInst>(Val) && "Expected SelectInst!");
+  }
+  void dump(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const override;
+#endif
+};
+
 class LoadInst final : public Instruction {
   /// Use LoadInst::create() instead of calling the constructor.
   LoadInst(llvm::LoadInst *LI, Context &Ctx)
@@ -803,6 +848,10 @@ class Context {
   Value *getOrCreateValue(llvm::Value *LLVMV) {
     return getOrCreateValueInternal(LLVMV, 0);
   }
+  /// \returns the sandboxir::Constant wrapping \p LLVMC, creating it if needed.
+  Constant *getOrCreateConstant(llvm::Constant *LLVMC) {
+    return cast<Constant>(getOrCreateValueInternal(LLVMC, 0));
+  }
   /// Create a sandboxir::BasicBlock for an existing LLVM IR \p BB. This will
   /// also create all contents of the block.
   BasicBlock *createBasicBlock(llvm::BasicBlock *BB);
@@ -812,6 +861,8 @@ class Context {
   IRBuilder<ConstantFolder> LLVMIRBuilder;
   auto &getLLVMIRBuilder() { return LLVMIRBuilder; }
 
+  SelectInst *createSelectInst(llvm::SelectInst *SI);
+  friend SelectInst; // For createSelectInst()
   LoadInst *createLoadInst(llvm::LoadInst *LI);
   friend LoadInst; // For createLoadInst()
   StoreInst *createStoreInst(llvm::StoreInst *SI);
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index b2f88741af8d9..efa9155755587 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -25,6 +25,7 @@ DEF_USER(Constant, Constant)
 #endif
 //       ClassID, Opcode(s),  Class
 DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
+DEF_INSTR(Select, OP(Select), SelectInst)
 DEF_INSTR(Load, OP(Load), LoadInst)
 DEF_INSTR(Store, OP(Store), StoreInst)
 DEF_INSTR(Ret, OP(Ret), ReturnInst)
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 4cf45fa87693a..b3f444b6f2bc9 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -455,6 +455,50 @@ void Instruction::dump() const {
 }
 #endif // NDEBUG
 
+Value *SelectInst::create(Value *Cond, Value *True, Value *False,
+                          Instruction *InsertBefore, Context &Ctx,
+                          const Twine &Name) {
+  // Emit the select right before InsertBefore's topmost LLVM IR instruction.
+  auto &IRB = Ctx.getLLVMIRBuilder();
+  IRB.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
+  // IRB's ConstantFolder may fold the select, so NewV may be a Constant.
+  llvm::Value *NewV = IRB.CreateSelect(Cond->Val, True->Val, False->Val, Name);
+  if (auto *NewSI = dyn_cast<llvm::SelectInst>(NewV))
+    return Ctx.createSelectInst(NewSI);
+  assert(isa<llvm::Constant>(NewV) && "Expected constant");
+  return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
+}
+
+Value *SelectInst::create(Value *Cond, Value *True, Value *False,
+                          BasicBlock *InsertAtEnd, Context &Ctx,
+                          const Twine &Name) {
+  // Append the select at the end of InsertAtEnd's underlying LLVM IR block.
+  auto &IRB = Ctx.getLLVMIRBuilder();
+  IRB.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
+  // IRB's ConstantFolder may fold the select, so NewV may be a Constant.
+  llvm::Value *NewV = IRB.CreateSelect(Cond->Val, True->Val, False->Val, Name);
+  if (auto *NewSI = dyn_cast<llvm::SelectInst>(NewV))
+    return Ctx.createSelectInst(NewSI);
+  assert(isa<llvm::Constant>(NewV) && "Expected constant");
+  return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
+}
+
+bool SelectInst::classof(const Value *From) {
+  return ClassID::Select == From->getSubclassID();
+}
+
+#ifndef NDEBUG
+void SelectInst::dump(raw_ostream &OS) const {
+  dumpCommonPrefix(OS);
+  dumpCommonSuffix(OS);
+}
+void SelectInst::dump() const {
+  raw_ostream &OS = dbgs();
+  dump(OS);
+  OS << "\n";
+}
+#endif // NDEBUG
+
 LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
                            Instruction *InsertBefore, Context &Ctx,
                            const Twine &Name) {
@@ -700,6 +744,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
   assert(isa<llvm::Instruction>(LLVMV) && "Expected Instruction");
 
   switch (cast<llvm::Instruction>(LLVMV)->getOpcode()) {
+  case llvm::Instruction::Select: {
+    auto *LLVMSel = cast<llvm::SelectInst>(LLVMV);
+    It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
+    return It->second.get();
+  }
   case llvm::Instruction::Load: {
     auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
     It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
@@ -733,6 +782,11 @@ BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
   return BB;
 }
 
+SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
+  std::unique_ptr<SelectInst> NewSel(new SelectInst(SI, *this));
+  return cast<SelectInst>(registerValue(std::move(NewSel)));
+}
+
 LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
   auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
   return cast<LoadInst>(registerValue(std::move(NewPtr)));
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index b0d6ae85950d7..0bfd2586e2ca3 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -561,6 +561,41 @@ define void @foo(i8 %v1) {
   EXPECT_EQ(I0->getNextNode(), Ret);
 }
 
+TEST_F(SandboxIRTest, SelectInst) {
+  parseIR(C, R"IR(
+define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
+  %sel = select i1 %c0, i8 %v0, i8 %v1
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  sandboxir::Function *F = Ctx.createFunction(LLVMF);
+  auto *Cond0 = F->getArg(0);
+  auto *V0 = F->getArg(1);
+  auto *V1 = F->getArg(2);
+  auto *Cond1 = F->getArg(3);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *Select = cast<sandboxir::SelectInst>(&*It++);
+
+  // Check getCondition().
+  EXPECT_EQ(Select->getCondition(), Cond0);
+  // Check getTrueValue().
+  EXPECT_EQ(Select->getTrueValue(), V0);
+  // Check getFalseValue().
+  EXPECT_EQ(Select->getFalseValue(), V1);
+  // Check setCondition() using %c1, which the IR otherwise leaves unused.
+  Select->setCondition(Cond1);
+  EXPECT_EQ(Select->getCondition(), Cond1);
+  // Check setTrueValue(); afterwards both true and false values are %v1.
+  Select->setTrueValue(V1);
+  EXPECT_EQ(Select->getTrueValue(), V1);
+  // Check setFalseValue().
+  Select->setFalseValue(V0);
+  EXPECT_EQ(Select->getFalseValue(), V0);
+}
+
 TEST_F(SandboxIRTest, LoadInst) {
   parseIR(C, R"IR(
 define void @foo(ptr %arg0, ptr %arg1) {

>From 7d785c61c16258c06783c6302b3cc36b1eb60f1c Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 22 Jul 2024 16:35:06 -0700
Subject: [PATCH 2/2] fixup! [SandboxIR] Implement SelectInst

Add tests to exercise SelectInst::create(), including the folded case.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h    |  3 ++
 llvm/lib/SandboxIR/SandboxIR.cpp           |  8 ++++++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 33 ++++++++++++++++++++++
 3 files changed, 44 insertions(+)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index a15fb5b8a7030..4ab4c1002ffbf 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -412,6 +412,8 @@ class Constant : public sandboxir::User {
   }
 
 public:
+  static Constant *createInt(Type *Ty, uint64_t V, Context &Ctx,
+                             bool IsSigned = false);
   /// For isa/dyn_cast.
   static bool classof(const sandboxir::Value *From) {
     return From->getSubclassID() == ClassID::Constant ||
@@ -852,6 +854,7 @@ class Context {
   Constant *getOrCreateConstant(llvm::Constant *LLVMC) {
     return cast<Constant>(getOrCreateValueInternal(LLVMC, 0));
   }
+  friend class Constant; // For getOrCreateConstant().
   /// Create a sandboxir::BasicBlock for an existing LLVM IR \p BB. This will
   /// also create all contents of the block.
   BasicBlock *createBasicBlock(llvm::BasicBlock *BB);
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index b3f444b6f2bc9..96fd22c3934bd 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -636,7 +636,15 @@ void OpaqueInst::dump() const {
   dump(dbgs());
   dbgs() << "\n";
 }
+#endif // NDEBUG
+
+Constant *Constant::createInt(Type *Ty, uint64_t V, Context &Ctx,
+                              bool IsSigned) {
+  // getOrCreateConstant() reuses an existing wrapper if there is one.
+  return Ctx.getOrCreateConstant(llvm::ConstantInt::get(Ty, V, IsSigned));
+}
 
+#ifndef NDEBUG
 void Constant::dump(raw_ostream &OS) const {
   dumpCommonPrefix(OS);
   dumpCommonSuffix(OS);
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 0bfd2586e2ca3..ba90b4f811f8e 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -578,6 +578,7 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
   auto *BB = &*F->begin();
   auto It = BB->begin();
   auto *Select = cast<sandboxir::SelectInst>(&*It++);
+  auto *Ret = &*It++;
 
   // Check getCondition().
   EXPECT_EQ(Select->getCondition(), Cond0);
@@ -594,6 +595,38 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
   // Check setFalseValue().
   Select->setFalseValue(V0);
   EXPECT_EQ(Select->getFalseValue(), V0);
+
+  {
+    // Check SelectInst::create() InsertBefore.
+    auto *NewSel = cast<sandboxir::SelectInst>(sandboxir::SelectInst::create(
+        Cond0, V0, V1, /*InsertBefore=*/Ret, Ctx));
+    EXPECT_EQ(NewSel->getCondition(), Cond0);
+    EXPECT_EQ(NewSel->getTrueValue(), V0);
+    EXPECT_EQ(NewSel->getFalseValue(), V1);
+    EXPECT_EQ(NewSel->getNextNode(), Ret);
+  }
+  {
+    // Check SelectInst::create() InsertAtEnd.
+    auto *NewSel = cast<sandboxir::SelectInst>(
+        sandboxir::SelectInst::create(Cond0, V0, V1, /*InsertAtEnd=*/BB, Ctx));
+    EXPECT_EQ(NewSel->getCondition(), Cond0);
+    EXPECT_EQ(NewSel->getTrueValue(), V0);
+    EXPECT_EQ(NewSel->getFalseValue(), V1);
+    EXPECT_EQ(NewSel->getPrevNode(), Ret);
+  }
+  {
+    // Check SelectInst::create() Folded. The arms are distinct i8 constants
+    // so the test verifies that a false condition folds to the false arm.
+    auto *False =
+        sandboxir::Constant::createInt(llvm::Type::getInt1Ty(C), 0, Ctx);
+    auto *Int8Ty = llvm::Type::getInt8Ty(C);
+    auto *FortyTwo = sandboxir::Constant::createInt(Int8Ty, 42, Ctx);
+    auto *Seven = sandboxir::Constant::createInt(Int8Ty, 7, Ctx);
+    auto *NewSel =
+        sandboxir::SelectInst::create(False, FortyTwo, Seven, Ret, Ctx);
+    EXPECT_TRUE(isa<sandboxir::Constant>(NewSel));
+    EXPECT_EQ(NewSel, Seven);
+  }
 }
 
 TEST_F(SandboxIRTest, LoadInst) {



More information about the llvm-commits mailing list