[llvm] 6ec169d - [SandboxIR] Implement BinaryOperator (#104121)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 15 13:28:11 PDT 2024


Author: vporpo
Date: 2024-08-15T13:28:08-07:00
New Revision: 6ec169d3501124770c3301dab8156c4640346c40

URL: https://github.com/llvm/llvm-project/commit/6ec169d3501124770c3301dab8156c4640346c40
DIFF: https://github.com/llvm/llvm-project/commit/6ec169d3501124770c3301dab8156c4640346c40.diff

LOG: [SandboxIR] Implement BinaryOperator (#104121)

This patch implements sandboxir::BinaryOperator mirroring
llvm::BinaryOperator.

Added: 
    

Modified: 
    llvm/include/llvm/SandboxIR/SandboxIR.h
    llvm/include/llvm/SandboxIR/SandboxIRValues.def
    llvm/lib/SandboxIR/SandboxIR.cpp
    llvm/unittests/SandboxIR/SandboxIRTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index a2e2a32e9c01eb..b1af769f29af54 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -130,6 +130,7 @@ class CastInst;
 class PtrToIntInst;
 class BitCastInst;
 class AllocaInst;
+class BinaryOperator;
 class AtomicCmpXchgInst;
 
 /// Iterator for the `Use` edges of a User's operands.
@@ -249,6 +250,7 @@ class Value {
   friend class InvokeInst;         // For getting `Val`.
   friend class CallBrInst;         // For getting `Val`.
   friend class GetElementPtrInst;  // For getting `Val`.
+  friend class BinaryOperator;     // For getting `Val`.
   friend class AtomicCmpXchgInst;  // For getting `Val`.
   friend class AllocaInst;         // For getting `Val`.
   friend class CastInst;           // For getting `Val`.
@@ -630,6 +632,7 @@ class Instruction : public sandboxir::User {
   friend class InvokeInst;         // For getTopmostLLVMInstruction().
   friend class CallBrInst;         // For getTopmostLLVMInstruction().
   friend class GetElementPtrInst;  // For getTopmostLLVMInstruction().
+  friend class BinaryOperator;     // For getTopmostLLVMInstruction().
   friend class AtomicCmpXchgInst;  // For getTopmostLLVMInstruction().
   friend class AllocaInst;         // For getTopmostLLVMInstruction().
   friend class CastInst;           // For getTopmostLLVMInstruction().
@@ -1432,6 +1435,86 @@ class GetElementPtrInst final
   // TODO: Add missing member functions.
 };
 
+class BinaryOperator : public SingleLLVMInstructionImpl<llvm::BinaryOperator> {
+  static Opcode getBinOpOpcode(llvm::Instruction::BinaryOps BinOp) {
+    switch (BinOp) {
+    case llvm::Instruction::Add:
+      return Opcode::Add;
+    case llvm::Instruction::FAdd:
+      return Opcode::FAdd;
+    case llvm::Instruction::Sub:
+      return Opcode::Sub;
+    case llvm::Instruction::FSub:
+      return Opcode::FSub;
+    case llvm::Instruction::Mul:
+      return Opcode::Mul;
+    case llvm::Instruction::FMul:
+      return Opcode::FMul;
+    case llvm::Instruction::UDiv:
+      return Opcode::UDiv;
+    case llvm::Instruction::SDiv:
+      return Opcode::SDiv;
+    case llvm::Instruction::FDiv:
+      return Opcode::FDiv;
+    case llvm::Instruction::URem:
+      return Opcode::URem;
+    case llvm::Instruction::SRem:
+      return Opcode::SRem;
+    case llvm::Instruction::FRem:
+      return Opcode::FRem;
+    case llvm::Instruction::Shl:
+      return Opcode::Shl;
+    case llvm::Instruction::LShr:
+      return Opcode::LShr;
+    case llvm::Instruction::AShr:
+      return Opcode::AShr;
+    case llvm::Instruction::And:
+      return Opcode::And;
+    case llvm::Instruction::Or:
+      return Opcode::Or;
+    case llvm::Instruction::Xor:
+      return Opcode::Xor;
+    case llvm::Instruction::BinaryOpsEnd:
+      llvm_unreachable("Bad BinOp!");
+    }
+    llvm_unreachable("Unhandled BinOp!");
+  }
+  BinaryOperator(llvm::BinaryOperator *BinOp, Context &Ctx)
+      : SingleLLVMInstructionImpl(ClassID::BinaryOperator,
+                                  getBinOpOpcode(BinOp->getOpcode()), BinOp,
+                                  Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  static Value *create(Instruction::Opcode Op, Value *LHS, Value *RHS,
+                       BBIterator WhereIt, BasicBlock *WhereBB, Context &Ctx,
+                       const Twine &Name = "");
+  static Value *create(Instruction::Opcode Op, Value *LHS, Value *RHS,
+                       Instruction *InsertBefore, Context &Ctx,
+                       const Twine &Name = "");
+  static Value *create(Instruction::Opcode Op, Value *LHS, Value *RHS,
+                       BasicBlock *InsertAtEnd, Context &Ctx,
+                       const Twine &Name = "");
+
+  static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *LHS,
+                                      Value *RHS, Value *CopyFrom,
+                                      BBIterator WhereIt, BasicBlock *WhereBB,
+                                      Context &Ctx, const Twine &Name = "");
+  static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *LHS,
+                                      Value *RHS, Value *CopyFrom,
+                                      Instruction *InsertBefore, Context &Ctx,
+                                      const Twine &Name = "");
+  static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *LHS,
+                                      Value *RHS, Value *CopyFrom,
+                                      BasicBlock *InsertAtEnd, Context &Ctx,
+                                      const Twine &Name = "");
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::BinaryOperator;
+  }
+  void swapOperands() { swapOperandsInternal(0, 1); }
+};
+
 class AtomicCmpXchgInst
     : public SingleLLVMInstructionImpl<llvm::AtomicCmpXchgInst> {
   AtomicCmpXchgInst(llvm::AtomicCmpXchgInst *Atomic, Context &Ctx)
@@ -1876,6 +1959,8 @@ class Context {
   friend CallBrInst; // For createCallBrInst()
   GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
   friend GetElementPtrInst; // For createGetElementPtrInst()
+  BinaryOperator *createBinaryOperator(llvm::BinaryOperator *I);
+  friend BinaryOperator; // For createBinaryOperator()
   AtomicCmpXchgInst *createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I);
   friend AtomicCmpXchgInst; // For createAtomicCmpXchgInst()
   AllocaInst *createAllocaInst(llvm::AllocaInst *I);

diff  --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 114dc8505ecd6f..43441ba26ec2cb 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -32,19 +32,39 @@ DEF_USER(Constant, Constant)
 #define OPCODES(...)
 #endif
 // clang-format off
-//        ClassID,        Opcode(s),          Class
-DEF_INSTR(Opaque,         OP(Opaque),         OpaqueInst)
+//       ClassID,        Opcode(s),         Class
+DEF_INSTR(Opaque,        OP(Opaque),        OpaqueInst)
 DEF_INSTR(ExtractElement, OP(ExtractElement), ExtractElementInst)
-DEF_INSTR(InsertElement,  OP(InsertElement),  InsertElementInst)
-DEF_INSTR(Select,         OP(Select),         SelectInst)
-DEF_INSTR(Br,             OP(Br),             BranchInst)
-DEF_INSTR(Load,           OP(Load),           LoadInst)
-DEF_INSTR(Store,          OP(Store),          StoreInst)
-DEF_INSTR(Ret,            OP(Ret),            ReturnInst)
-DEF_INSTR(Call,           OP(Call),           CallInst)
-DEF_INSTR(Invoke,         OP(Invoke),         InvokeInst)
-DEF_INSTR(CallBr,         OP(CallBr),         CallBrInst)
-DEF_INSTR(GetElementPtr,  OP(GetElementPtr),  GetElementPtrInst)
+DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst)
+DEF_INSTR(Select,        OP(Select),        SelectInst)
+DEF_INSTR(Br,            OP(Br),            BranchInst)
+DEF_INSTR(Load,          OP(Load),          LoadInst)
+DEF_INSTR(Store,         OP(Store),         StoreInst)
+DEF_INSTR(Ret,           OP(Ret),           ReturnInst)
+DEF_INSTR(Call,          OP(Call),          CallInst)
+DEF_INSTR(Invoke,        OP(Invoke),        InvokeInst)
+DEF_INSTR(CallBr,        OP(CallBr),        CallBrInst)
+DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
+DEF_INSTR(BinaryOperator, OPCODES( \
+                         OP(Add)  \
+                         OP(FAdd) \
+                         OP(Sub)  \
+                         OP(FSub) \
+                         OP(Mul)  \
+                         OP(FMul) \
+                         OP(UDiv) \
+                         OP(SDiv) \
+                         OP(FDiv) \
+                         OP(URem) \
+                         OP(SRem) \
+                         OP(FRem) \
+                         OP(Shl)  \
+                         OP(LShr) \
+                         OP(AShr) \
+                         OP(And)  \
+                         OP(Or)   \
+                         OP(Xor)  \
+                         ),                 BinaryOperator)
 DEF_INSTR(AtomicCmpXchg, OP(AtomicCmpXchg), AtomicCmpXchgInst)
 DEF_INSTR(Alloca,         OP(Alloca),         AllocaInst)
 DEF_INSTR(Cast,   OPCODES(\

diff  --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 3c7dbb70c9f4b7..52df53892b4506 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1219,6 +1219,107 @@ static llvm::Instruction::CastOps getLLVMCastOp(Instruction::Opcode Opc) {
   }
 }
 
+/// \Returns the LLVM opcode that corresponds to \p Opc.
+static llvm::Instruction::BinaryOps getLLVMBinaryOp(Instruction::Opcode Opc) {
+  switch (Opc) {
+  case Instruction::Opcode::Add:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::Add);
+  case Instruction::Opcode::FAdd:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::FAdd);
+  case Instruction::Opcode::Sub:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::Sub);
+  case Instruction::Opcode::FSub:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::FSub);
+  case Instruction::Opcode::Mul:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::Mul);
+  case Instruction::Opcode::FMul:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::FMul);
+  case Instruction::Opcode::UDiv:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::UDiv);
+  case Instruction::Opcode::SDiv:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::SDiv);
+  case Instruction::Opcode::FDiv:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::FDiv);
+  case Instruction::Opcode::URem:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::URem);
+  case Instruction::Opcode::SRem:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::SRem);
+  case Instruction::Opcode::FRem:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::FRem);
+  case Instruction::Opcode::Shl:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::Shl);
+  case Instruction::Opcode::LShr:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::LShr);
+  case Instruction::Opcode::AShr:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::AShr);
+  case Instruction::Opcode::And:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::And);
+  case Instruction::Opcode::Or:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::Or);
+  case Instruction::Opcode::Xor:
+    return static_cast<llvm::Instruction::BinaryOps>(llvm::Instruction::Xor);
+  default:
+    llvm_unreachable("Not a binary op!");
+  }
+}
+Value *BinaryOperator::create(Instruction::Opcode Op, Value *LHS, Value *RHS,
+                              BBIterator WhereIt, BasicBlock *WhereBB,
+                              Context &Ctx, const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  if (WhereIt == WhereBB->end())
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  else
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+  llvm::Value *NewV =
+      Builder.CreateBinOp(getLLVMBinaryOp(Op), LHS->Val, RHS->Val, Name);
+  if (auto *NewBinOp = dyn_cast<llvm::BinaryOperator>(NewV))
+    return Ctx.createBinaryOperator(NewBinOp);
+  assert(isa<llvm::Constant>(NewV) && "Expected constant");
+  return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
+}
+
+Value *BinaryOperator::create(Instruction::Opcode Op, Value *LHS, Value *RHS,
+                              Instruction *InsertBefore, Context &Ctx,
+                              const Twine &Name) {
+  return create(Op, LHS, RHS, InsertBefore->getIterator(),
+                InsertBefore->getParent(), Ctx, Name);
+}
+
+Value *BinaryOperator::create(Instruction::Opcode Op, Value *LHS, Value *RHS,
+                              BasicBlock *InsertAtEnd, Context &Ctx,
+                              const Twine &Name) {
+  return create(Op, LHS, RHS, InsertAtEnd->end(), InsertAtEnd, Ctx, Name);
+}
+
+Value *BinaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *LHS,
+                                             Value *RHS, Value *CopyFrom,
+                                             BBIterator WhereIt,
+                                             BasicBlock *WhereBB, Context &Ctx,
+                                             const Twine &Name) {
+
+  Value *NewV = create(Op, LHS, RHS, WhereIt, WhereBB, Ctx, Name);
+  if (auto *NewBO = dyn_cast<BinaryOperator>(NewV))
+    cast<llvm::BinaryOperator>(NewBO->Val)->copyIRFlags(CopyFrom->Val);
+  return NewV;
+}
+
+Value *BinaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *LHS,
+                                             Value *RHS, Value *CopyFrom,
+                                             Instruction *InsertBefore,
+                                             Context &Ctx, const Twine &Name) {
+  return createWithCopiedFlags(Op, LHS, RHS, CopyFrom,
+                               InsertBefore->getIterator(),
+                               InsertBefore->getParent(), Ctx, Name);
+}
+
+Value *BinaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *LHS,
+                                             Value *RHS, Value *CopyFrom,
+                                             BasicBlock *InsertAtEnd,
+                                             Context &Ctx, const Twine &Name) {
+  return createWithCopiedFlags(Op, LHS, RHS, CopyFrom, InsertAtEnd->end(),
+                               InsertAtEnd, Ctx, Name);
+}
+
 void AtomicCmpXchgInst::setSyncScopeID(SyncScope::ID SSID) {
   Ctx.getTracker()
       .emplaceIfTracking<GenericSetter<&AtomicCmpXchgInst::getSyncScopeID,
@@ -1628,6 +1729,29 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
         new GetElementPtrInst(LLVMGEP, *this));
     return It->second.get();
   }
+  case llvm::Instruction::Add:
+  case llvm::Instruction::FAdd:
+  case llvm::Instruction::Sub:
+  case llvm::Instruction::FSub:
+  case llvm::Instruction::Mul:
+  case llvm::Instruction::FMul:
+  case llvm::Instruction::UDiv:
+  case llvm::Instruction::SDiv:
+  case llvm::Instruction::FDiv:
+  case llvm::Instruction::URem:
+  case llvm::Instruction::SRem:
+  case llvm::Instruction::FRem:
+  case llvm::Instruction::Shl:
+  case llvm::Instruction::LShr:
+  case llvm::Instruction::AShr:
+  case llvm::Instruction::And:
+  case llvm::Instruction::Or:
+  case llvm::Instruction::Xor: {
+    auto *LLVMBinaryOperator = cast<llvm::BinaryOperator>(LLVMV);
+    It->second = std::unique_ptr<BinaryOperator>(
+        new BinaryOperator(LLVMBinaryOperator, *this));
+    return It->second.get();
+  }
   case llvm::Instruction::AtomicCmpXchg: {
     auto *LLVMAtomicCmpXchg = cast<llvm::AtomicCmpXchgInst>(LLVMV);
     It->second = std::unique_ptr<AtomicCmpXchgInst>(
@@ -1751,6 +1875,10 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
       std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
   return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
 }
+BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
+  auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
+  return cast<BinaryOperator>(registerValue(std::move(NewPtr)));
+}
 AtomicCmpXchgInst *
 Context::createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I) {
   auto NewPtr =

diff  --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 1e4679ed6e802e..f5e555ba73287b 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -54,7 +54,6 @@ define void @foo(i32 %v1) {
   sandboxir::Argument *Arg0 = F->getArg(0);
   sandboxir::BasicBlock *BB = &*F->begin();
   sandboxir::Instruction *AddI = &*BB->begin();
-  sandboxir::OpaqueInst *OpaqueI = cast<sandboxir::OpaqueInst>(AddI);
   sandboxir::Constant *Const0 = cast<sandboxir::Constant>(Ctx.getValue(LLVMC));
 
   EXPECT_TRUE(isa<sandboxir::Function>(F));
@@ -62,42 +61,36 @@ define void @foo(i32 %v1) {
   EXPECT_FALSE(isa<sandboxir::Function>(BB));
   EXPECT_FALSE(isa<sandboxir::Function>(AddI));
   EXPECT_FALSE(isa<sandboxir::Function>(Const0));
-  EXPECT_FALSE(isa<sandboxir::Function>(OpaqueI));
 
   EXPECT_FALSE(isa<sandboxir::Argument>(F));
   EXPECT_TRUE(isa<sandboxir::Argument>(Arg0));
   EXPECT_FALSE(isa<sandboxir::Argument>(BB));
   EXPECT_FALSE(isa<sandboxir::Argument>(AddI));
   EXPECT_FALSE(isa<sandboxir::Argument>(Const0));
-  EXPECT_FALSE(isa<sandboxir::Argument>(OpaqueI));
 
   EXPECT_TRUE(isa<sandboxir::Constant>(F));
   EXPECT_FALSE(isa<sandboxir::Constant>(Arg0));
   EXPECT_FALSE(isa<sandboxir::Constant>(BB));
   EXPECT_FALSE(isa<sandboxir::Constant>(AddI));
   EXPECT_TRUE(isa<sandboxir::Constant>(Const0));
-  EXPECT_FALSE(isa<sandboxir::Constant>(OpaqueI));
 
   EXPECT_FALSE(isa<sandboxir::OpaqueInst>(F));
   EXPECT_FALSE(isa<sandboxir::OpaqueInst>(Arg0));
   EXPECT_FALSE(isa<sandboxir::OpaqueInst>(BB));
-  EXPECT_TRUE(isa<sandboxir::OpaqueInst>(AddI));
+  EXPECT_FALSE(isa<sandboxir::OpaqueInst>(AddI));
   EXPECT_FALSE(isa<sandboxir::OpaqueInst>(Const0));
-  EXPECT_TRUE(isa<sandboxir::OpaqueInst>(OpaqueI));
 
   EXPECT_FALSE(isa<sandboxir::Instruction>(F));
   EXPECT_FALSE(isa<sandboxir::Instruction>(Arg0));
   EXPECT_FALSE(isa<sandboxir::Instruction>(BB));
   EXPECT_TRUE(isa<sandboxir::Instruction>(AddI));
   EXPECT_FALSE(isa<sandboxir::Instruction>(Const0));
-  EXPECT_TRUE(isa<sandboxir::Instruction>(OpaqueI));
 
   EXPECT_TRUE(isa<sandboxir::User>(F));
   EXPECT_FALSE(isa<sandboxir::User>(Arg0));
   EXPECT_FALSE(isa<sandboxir::User>(BB));
   EXPECT_TRUE(isa<sandboxir::User>(AddI));
   EXPECT_TRUE(isa<sandboxir::User>(Const0));
-  EXPECT_TRUE(isa<sandboxir::User>(OpaqueI));
 
 #ifndef NDEBUG
   std::string Buff;
@@ -107,7 +100,6 @@ define void @foo(i32 %v1) {
   BB->dumpOS(BS);
   AddI->dumpOS(BS);
   Const0->dumpOS(BS);
-  OpaqueI->dumpOS(BS);
 #endif
 }
 
@@ -183,7 +175,7 @@ define i32 @foo(i32 %v0, i32 %v1) {
   I0->getOperandUse(0).dumpOS(BS);
   EXPECT_EQ(Buff, R"IR(
 Def:  i32 %v0 ; SB2. (Argument)
-User:   %add0 = add i32 %v0, %v1 ; SB5. (Opaque)
+User:   %add0 = add i32 %v0, %v1 ; SB5. (BinaryOperator)
 OperandNo: 0
 )IR");
 #endif // NDEBUG
@@ -508,8 +500,8 @@ define void @foo(i8 %v1) {
   EXPECT_EQ(Ret->getIterator(), std::next(BB->begin(), 2));
 
   // Check getOpcode().
-  EXPECT_EQ(I0->getOpcode(), sandboxir::Instruction::Opcode::Opaque);
-  EXPECT_EQ(I1->getOpcode(), sandboxir::Instruction::Opcode::Opaque);
+  EXPECT_EQ(I0->getOpcode(), sandboxir::Instruction::Opcode::Add);
+  EXPECT_EQ(I1->getOpcode(), sandboxir::Instruction::Opcode::Sub);
   EXPECT_EQ(Ret->getOpcode(), sandboxir::Instruction::Opcode::Ret);
 
   // Check moveBefore(I).
@@ -1628,6 +1620,192 @@ define void @foo(i32 %arg, float %farg) {
   EXPECT_FALSE(FAdd->getFastMathFlags() != LLVMFAdd->getFastMathFlags());
 }
 
+TEST_F(SandboxIRTest, BinaryOperator) {
+  parseIR(C, R"IR(
+define void @foo(i8 %arg0, i8 %arg1, float %farg0, float %farg1) {
+  %add = add i8 %arg0, %arg1
+  %fadd = fadd float %farg0, %farg1
+  %sub = sub i8 %arg0, %arg1
+  %fsub = fsub float %farg0, %farg1
+  %mul = mul i8 %arg0, %arg1
+  %fmul = fmul float %farg0, %farg1
+  %udiv = udiv i8 %arg0, %arg1
+  %sdiv = sdiv i8 %arg0, %arg1
+  %fdiv = fdiv float %farg0, %farg1
+  %urem = urem i8 %arg0, %arg1
+  %srem = srem i8 %arg0, %arg1
+  %frem = frem float %farg0, %farg1
+  %shl = shl i8 %arg0, %arg1
+  %lshr = lshr i8 %arg0, %arg1
+  %ashr = ashr i8 %arg0, %arg1
+  %and = and i8 %arg0, %arg1
+  %or = or i8 %arg0, %arg1
+  %xor = xor i8 %arg0, %arg1
+
+  %copyfrom = add nsw i8 %arg0, %arg1
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *Arg0 = F.getArg(0);
+  auto *Arg1 = F.getArg(1);
+  auto *FArg0 = F.getArg(2);
+  auto *FArg1 = F.getArg(3);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+
+#define CHECK_IBINOP(OPCODE)                                                   \
+  {                                                                            \
+    auto *I = cast<sandboxir::BinaryOperator>(&*It++);                         \
+    EXPECT_EQ(I->getOpcode(), OPCODE);                                         \
+    EXPECT_EQ(I->getOperand(0), Arg0);                                         \
+    EXPECT_EQ(I->getOperand(1), Arg1);                                         \
+  }
+#define CHECK_FBINOP(OPCODE)                                                   \
+  {                                                                            \
+    auto *I = cast<sandboxir::BinaryOperator>(&*It++);                         \
+    EXPECT_EQ(I->getOpcode(), OPCODE);                                         \
+    EXPECT_EQ(I->getOperand(0), FArg0);                                        \
+    EXPECT_EQ(I->getOperand(1), FArg1);                                        \
+  }
+
+  CHECK_IBINOP(sandboxir::Instruction::Opcode::Add);
+  CHECK_FBINOP(sandboxir::Instruction::Opcode::FAdd);
+  CHECK_IBINOP(sandboxir::Instruction::Opcode::Sub);
+  CHECK_FBINOP(sandboxir::Instruction::Opcode::FSub);
+  CHECK_IBINOP(sandboxir::Instruction::Opcode::Mul);
+  CHECK_FBINOP(sandboxir::Instruction::Opcode::FMul);
+  CHECK_IBINOP(sandboxir::Instruction::Opcode::UDiv);
+  CHECK_IBINOP(sandboxir::Instruction::Opcode::SDiv);
+  CHECK_FBINOP(sandboxir::Instruction::Opcode::FDiv);
+  CHECK_IBINOP(sandboxir::Instruction::Opcode::URem);
+  CHECK_IBINOP(sandboxir::Instruction::Opcode::SRem);
+  CHECK_FBINOP(sandboxir::Instruction::Opcode::FRem);
+  CHECK_IBINOP(sandboxir::Instruction::Opcode::Shl);
+  CHECK_IBINOP(sandboxir::Instruction::Opcode::LShr);
+  CHECK_IBINOP(sandboxir::Instruction::Opcode::AShr);
+  CHECK_IBINOP(sandboxir::Instruction::Opcode::And);
+  CHECK_IBINOP(sandboxir::Instruction::Opcode::Or);
+  CHECK_IBINOP(sandboxir::Instruction::Opcode::Xor);
+
+  auto *CopyFrom = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  {
+    // Check create() WhereIt, WhereBB.
+    auto *NewI =
+        cast<sandboxir::BinaryOperator>(sandboxir::BinaryOperator::create(
+            sandboxir::Instruction::Opcode::Add, Arg0, Arg1,
+            /*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
+            "New1"));
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::Add);
+    EXPECT_EQ(NewI->getOperand(0), Arg0);
+    EXPECT_EQ(NewI->getOperand(1), Arg1);
+#ifndef NDEBUG
+    EXPECT_EQ(NewI->getName(), "New1");
+#endif // NDEBUG
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+  }
+  {
+    // Check create() InsertBefore.
+    auto *NewI =
+        cast<sandboxir::BinaryOperator>(sandboxir::BinaryOperator::create(
+            sandboxir::Instruction::Opcode::Add, Arg0, Arg1,
+            /*InsertBefore=*/Ret, Ctx, "New2"));
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::Add);
+    EXPECT_EQ(NewI->getOperand(0), Arg0);
+    EXPECT_EQ(NewI->getOperand(1), Arg1);
+#ifndef NDEBUG
+    EXPECT_EQ(NewI->getName(), "New2");
+#endif // NDEBUG
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+  }
+  {
+    // Check create() InsertAtEnd.
+    auto *NewI =
+        cast<sandboxir::BinaryOperator>(sandboxir::BinaryOperator::create(
+            sandboxir::Instruction::Opcode::Add, Arg0, Arg1,
+            /*InsertAtEnd=*/BB, Ctx, "New3"));
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::Add);
+    EXPECT_EQ(NewI->getOperand(0), Arg0);
+    EXPECT_EQ(NewI->getOperand(1), Arg1);
+#ifndef NDEBUG
+    EXPECT_EQ(NewI->getName(), "New3");
+#endif // NDEBUG
+    EXPECT_EQ(NewI->getNextNode(), nullptr);
+    EXPECT_EQ(NewI->getParent(), BB);
+  }
+  {
+    // Check create() when it gets folded.
+    auto *FortyTwo =
+        sandboxir::Constant::createInt(Type::getInt32Ty(C), 42, Ctx);
+    auto *NewV = sandboxir::BinaryOperator::create(
+        sandboxir::Instruction::Opcode::Add, FortyTwo, FortyTwo,
+        /*InsertBefore=*/Ret, Ctx, "Folded");
+    EXPECT_TRUE(isa<sandboxir::Constant>(NewV));
+  }
+
+  {
+    // Check createWithCopiedFlags() WhereIt, WhereBB.
+    auto *NewI = cast<sandboxir::BinaryOperator>(
+        sandboxir::BinaryOperator::createWithCopiedFlags(
+            sandboxir::Instruction::Opcode::Add, Arg0, Arg1, CopyFrom,
+            /*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
+            "NewNSW1"));
+    EXPECT_EQ(NewI->hasNoSignedWrap(), CopyFrom->hasNoSignedWrap());
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::Add);
+    EXPECT_EQ(NewI->getOperand(0), Arg0);
+    EXPECT_EQ(NewI->getOperand(1), Arg1);
+#ifndef NDEBUG
+    EXPECT_EQ(NewI->getName(), "NewNSW1");
+#endif // NDEBUG
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+  }
+  {
+    // Check createWithCopiedFlags() InsertBefore.
+    auto *NewI = cast<sandboxir::BinaryOperator>(
+        sandboxir::BinaryOperator::createWithCopiedFlags(
+            sandboxir::Instruction::Opcode::Add, Arg0, Arg1, CopyFrom,
+            /*InsertBefore=*/Ret, Ctx, "NewNSW2"));
+    EXPECT_EQ(NewI->hasNoSignedWrap(), CopyFrom->hasNoSignedWrap());
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::Add);
+    EXPECT_EQ(NewI->getOperand(0), Arg0);
+    EXPECT_EQ(NewI->getOperand(1), Arg1);
+#ifndef NDEBUG
+    EXPECT_EQ(NewI->getName(), "NewNSW2");
+#endif // NDEBUG
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+  }
+  {
+    // Check createWithCopiedFlags() InsertAtEnd.
+    auto *NewI = cast<sandboxir::BinaryOperator>(
+        sandboxir::BinaryOperator::createWithCopiedFlags(
+            sandboxir::Instruction::Opcode::Add, Arg0, Arg1, CopyFrom,
+            /*InsertAtEnd=*/BB, Ctx, "NewNSW3"));
+    EXPECT_EQ(NewI->hasNoSignedWrap(), CopyFrom->hasNoSignedWrap());
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::Add);
+    EXPECT_EQ(NewI->getOperand(0), Arg0);
+    EXPECT_EQ(NewI->getOperand(1), Arg1);
+#ifndef NDEBUG
+    EXPECT_EQ(NewI->getName(), "NewNSW3");
+#endif // NDEBUG
+    EXPECT_EQ(NewI->getParent(), BB);
+    EXPECT_EQ(NewI->getNextNode(), nullptr);
+  }
+  {
+    // Check createWithCopiedFlags() when it gets folded.
+    auto *FortyTwo =
+        sandboxir::Constant::createInt(Type::getInt32Ty(C), 42, Ctx);
+    auto *NewV = sandboxir::BinaryOperator::createWithCopiedFlags(
+        sandboxir::Instruction::Opcode::Add, FortyTwo, FortyTwo, CopyFrom,
+        /*InsertBefore=*/Ret, Ctx, "Folded");
+    EXPECT_TRUE(isa<sandboxir::Constant>(NewV));
+  }
+}
+
 TEST_F(SandboxIRTest, AtomicCmpXchgInst) {
   parseIR(C, R"IR(
 define void @foo(ptr %ptr, i8 %cmp, i8 %new) {


        


More information about the llvm-commits mailing list