[llvm] ec29660 - [SandboxIR] Implement UnaryOperator (#104509)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 15 16:55:05 PDT 2024


Author: vporpo
Date: 2024-08-15T16:55:01-07:00
New Revision: ec29660c44e5e73d3b78f4884f9178036563fb25

URL: https://github.com/llvm/llvm-project/commit/ec29660c44e5e73d3b78f4884f9178036563fb25
DIFF: https://github.com/llvm/llvm-project/commit/ec29660c44e5e73d3b78f4884f9178036563fb25.diff

LOG: [SandboxIR] Implement UnaryOperator (#104509)

This patch implements sandboxir::UnaryOperator mirroring
llvm::UnaryOperator.

Added: 
    

Modified: 
    llvm/include/llvm/SandboxIR/SandboxIR.h
    llvm/include/llvm/SandboxIR/SandboxIRValues.def
    llvm/lib/SandboxIR/SandboxIR.cpp
    llvm/unittests/SandboxIR/SandboxIRTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index b1af769f29af54..423dad854a91cb 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -130,6 +130,7 @@ class CastInst;
 class PtrToIntInst;
 class BitCastInst;
 class AllocaInst;
+class UnaryOperator;
 class BinaryOperator;
 class AtomicCmpXchgInst;
 
@@ -250,6 +251,7 @@ class Value {
   friend class InvokeInst;         // For getting `Val`.
   friend class CallBrInst;         // For getting `Val`.
   friend class GetElementPtrInst;  // For getting `Val`.
+  friend class UnaryOperator;      // For getting `Val`.
   friend class BinaryOperator;     // For getting `Val`.
   friend class AtomicCmpXchgInst;  // For getting `Val`.
   friend class AllocaInst;         // For getting `Val`.
@@ -632,6 +634,7 @@ class Instruction : public sandboxir::User {
   friend class InvokeInst;         // For getTopmostLLVMInstruction().
   friend class CallBrInst;         // For getTopmostLLVMInstruction().
   friend class GetElementPtrInst;  // For getTopmostLLVMInstruction().
+  friend class UnaryOperator;      // For getTopmostLLVMInstruction().
   friend class BinaryOperator;     // For getTopmostLLVMInstruction().
   friend class AtomicCmpXchgInst;  // For getTopmostLLVMInstruction().
   friend class AllocaInst;         // For getTopmostLLVMInstruction().
@@ -1435,6 +1438,47 @@ class GetElementPtrInst final
   // TODO: Add missing member functions.
 };
 
+/// Sandbox IR counterpart of llvm::UnaryOperator. The only unary opcode it
+/// currently maps is FNeg.
+class UnaryOperator : public UnaryInstruction {
+public:
+  /// Creates `Op OpV` before \p WhereIt in \p WhereBB, or appends it when
+  /// \p WhereIt is WhereBB->end(). The result may be a folded constant rather
+  /// than a new UnaryOperator.
+  static Value *create(Instruction::Opcode Op, Value *OpV, BBIterator WhereIt,
+                       BasicBlock *WhereBB, Context &Ctx,
+                       const Twine &Name = "");
+  /// Creates `Op OpV` immediately before \p InsertBefore.
+  static Value *create(Instruction::Opcode Op, Value *OpV,
+                       Instruction *InsertBefore, Context &Ctx,
+                       const Twine &Name = "");
+  /// Creates `Op OpV` at the end of \p InsertAtEnd.
+  static Value *create(Instruction::Opcode Op, Value *OpV,
+                       BasicBlock *InsertAtEnd, Context &Ctx,
+                       const Twine &Name = "");
+  /// Same as create(), but afterwards copies the IR flags of \p CopyFrom onto
+  /// the new instruction. Nothing is copied when the result was folded.
+  static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                      Value *CopyFrom, BBIterator WhereIt,
+                                      BasicBlock *WhereBB, Context &Ctx,
+                                      const Twine &Name = "");
+  static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                      Value *CopyFrom,
+                                      Instruction *InsertBefore, Context &Ctx,
+                                      const Twine &Name = "");
+  static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                      Value *CopyFrom, BasicBlock *InsertAtEnd,
+                                      Context &Ctx, const Twine &Name = "");
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::UnOp;
+  }
+
+private:
+  friend Context; // for constructor.
+
+  /// Translates an LLVM unary opcode into the matching sandboxir opcode.
+  static Opcode getUnaryOpcode(llvm::Instruction::UnaryOps LLVMOp) {
+    switch (LLVMOp) {
+    case llvm::Instruction::FNeg:
+      return Opcode::FNeg;
+    case llvm::Instruction::UnaryOpsEnd:
+      llvm_unreachable("Bad UnOp!");
+    }
+    llvm_unreachable("Unhandled UnOp!");
+  }
+
+  /// Wraps the existing LLVM instruction \p UO; only Context constructs these.
+  UnaryOperator(llvm::UnaryOperator *UO, Context &Ctx)
+      : UnaryInstruction(ClassID::UnOp, getUnaryOpcode(UO->getOpcode()), UO,
+                         Ctx) {}
+};
+
 class BinaryOperator : public SingleLLVMInstructionImpl<llvm::BinaryOperator> {
   static Opcode getBinOpOpcode(llvm::Instruction::BinaryOps BinOp) {
     switch (BinOp) {
@@ -1959,6 +2003,8 @@ class Context {
   friend CallBrInst; // For createCallBrInst()
   GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
   friend GetElementPtrInst; // For createGetElementPtrInst()
+  UnaryOperator *createUnaryOperator(llvm::UnaryOperator *I);
+  friend UnaryOperator; // For createUnaryOperator()
   BinaryOperator *createBinaryOperator(llvm::BinaryOperator *I);
   friend BinaryOperator; // For createBinaryOperator()
   AtomicCmpXchgInst *createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I);

diff  --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 43441ba26ec2cb..7332316c85026c 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -45,7 +45,10 @@ DEF_INSTR(Call,          OP(Call),          CallInst)
 DEF_INSTR(Invoke,        OP(Invoke),        InvokeInst)
 DEF_INSTR(CallBr,        OP(CallBr),        CallBrInst)
 DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
-DEF_INSTR(BinaryOperator, OPCODES( \
+DEF_INSTR(UnOp,          OPCODES( \
+                         OP(FNeg) \
+                         ),                 UnaryOperator)
+DEF_INSTR(BinaryOperator, OPCODES(\
                          OP(Add)  \
                          OP(FAdd) \
                          OP(Sub)  \

diff  --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 52df53892b4506..67262bdd0dea99 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1219,6 +1219,71 @@ static llvm::Instruction::CastOps getLLVMCastOp(Instruction::Opcode Opc) {
   }
 }
 
+/// \Returns the llvm::Instruction::UnaryOps value matching the sandboxir
+/// opcode \p Opc. Any opcode other than FNeg is a programming error.
+static llvm::Instruction::UnaryOps getLLVMUnaryOp(Instruction::Opcode Opc) {
+  if (Opc == Instruction::Opcode::FNeg)
+    return static_cast<llvm::Instruction::UnaryOps>(llvm::Instruction::FNeg);
+  llvm_unreachable("Not a unary op!");
+}
+
+/// Creates `Op OpV` before \p WhereIt in \p WhereBB, or at the end of
+/// \p WhereBB when \p WhereIt == WhereBB->end(). \Returns either the new
+/// sandboxir::UnaryOperator or, if the IRBuilder folded the operation, the
+/// sandboxir wrapper of the resulting constant.
+Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
+                             BBIterator WhereIt, BasicBlock *WhereBB,
+                             Context &Ctx, const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  // Position the shared LLVM builder: append to the block, or insert before
+  // the LLVM instruction that backs the sandboxir instruction at WhereIt.
+  if (WhereIt == WhereBB->end())
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  else
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+  auto *LLVMResult = Builder.CreateUnOp(getLLVMUnaryOp(Op), OpV->Val, Name);
+  // Anything that is not a new UnaryOperator must be a folded constant.
+  auto *LLVMUnOp = dyn_cast<llvm::UnaryOperator>(LLVMResult);
+  if (LLVMUnOp == nullptr) {
+    assert(isa<llvm::Constant>(LLVMResult) && "Expected constant");
+    return Ctx.getOrCreateConstant(cast<llvm::Constant>(LLVMResult));
+  }
+  return Ctx.createUnaryOperator(LLVMUnOp);
+}
+
+/// Creates `Op OpV` immediately before \p InsertBefore, inside its parent
+/// block. Forwards to the (WhereIt, WhereBB) overload.
+Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
+                             Instruction *InsertBefore, Context &Ctx,
+                             const Twine &Name) {
+  auto Pos = InsertBefore->getIterator();
+  auto *ParentBB = InsertBefore->getParent();
+  return create(Op, OpV, Pos, ParentBB, Ctx, Name);
+}
+
+/// Creates `Op OpV` and appends it at the end of \p InsertAtEnd. Forwards to
+/// the (WhereIt, WhereBB) overload, so the result may be a folded constant
+/// instead of a new instruction.
+Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
+                             BasicBlock *InsertAtEnd, Context &Ctx,
+                             const Twine &Name) {
+  return create(Op, OpV, InsertAtEnd->end(), InsertAtEnd, Ctx, Name);
+}
+
+/// Like create(), but also copies the IR flags of \p CopyFrom onto the newly
+/// created LLVM instruction. If the operation was folded to a constant, no
+/// flags are copied and the constant is returned as-is.
+Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                            Value *CopyFrom, BBIterator WhereIt,
+                                            BasicBlock *WhereBB, Context &Ctx,
+                                            const Twine &Name) {
+  Value *Result = create(Op, OpV, WhereIt, WhereBB, Ctx, Name);
+  auto *LLVMUnOp = dyn_cast<llvm::UnaryOperator>(Result->Val);
+  if (LLVMUnOp != nullptr)
+    LLVMUnOp->copyIRFlags(CopyFrom->Val);
+  return Result;
+}
+
+/// Flag-copying create() that inserts immediately before \p InsertBefore.
+Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                            Value *CopyFrom,
+                                            Instruction *InsertBefore,
+                                            Context &Ctx, const Twine &Name) {
+  auto Pos = InsertBefore->getIterator();
+  auto *ParentBB = InsertBefore->getParent();
+  return createWithCopiedFlags(Op, OpV, CopyFrom, Pos, ParentBB, Ctx, Name);
+}
+
+/// Flag-copying create() that appends at the end of \p InsertAtEnd.
+Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                            Value *CopyFrom,
+                                            BasicBlock *InsertAtEnd,
+                                            Context &Ctx, const Twine &Name) {
+  auto EndIt = InsertAtEnd->end();
+  return createWithCopiedFlags(Op, OpV, CopyFrom, EndIt, InsertAtEnd, Ctx,
+                               Name);
+}
+
 /// \Returns the LLVM opcode that corresponds to \p Opc.
 static llvm::Instruction::BinaryOps getLLVMBinaryOp(Instruction::Opcode Opc) {
   switch (Opc) {
@@ -1729,6 +1794,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
         new GetElementPtrInst(LLVMGEP, *this));
     return It->second.get();
   }
+  case llvm::Instruction::FNeg: {
+    auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(LLVMV);
+    It->second = std::unique_ptr<UnaryOperator>(
+        new UnaryOperator(LLVMUnaryOperator, *this));
+    return It->second.get();
+  }
   case llvm::Instruction::Add:
   case llvm::Instruction::FAdd:
   case llvm::Instruction::Sub:
@@ -1875,6 +1946,10 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
       std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
   return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
 }
+/// Wraps the LLVM unary operator \p I in a new sandboxir::UnaryOperator,
+/// registers it with this Context and returns it.
+UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
+  // The UnaryOperator constructor is private (Context is a friend), so the
+  // object is built with `new` here rather than via std::make_unique.
+  std::unique_ptr<UnaryOperator> NewUnOp(new UnaryOperator(I, *this));
+  auto *Registered = registerValue(std::move(NewUnOp));
+  return cast<UnaryOperator>(Registered);
+}
 BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
   auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
   return cast<BinaryOperator>(registerValue(std::move(NewPtr)));

diff  --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index f5e555ba73287b..3df335985aa705 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -1620,6 +1620,132 @@ define void @foo(i32 %arg, float %farg) {
   EXPECT_FALSE(FAdd->getFastMathFlags() != LLVMFAdd->getFastMathFlags());
 }
 
+TEST_F(SandboxIRTest, UnaryOperator) {
+  parseIR(C, R"IR(
+define void @foo(float %arg0) {
+  %fneg = fneg float %arg0
+  %copyfrom = fadd reassoc float %arg0, 42.0
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *Arg0 = F.getArg(0);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *OrigFNeg = cast<sandboxir::UnaryOperator>(&*It++);
+  auto *CopyFrom = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *Ret = &*It++;
+  EXPECT_EQ(OrigFNeg->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+  EXPECT_EQ(OrigFNeg->getOperand(0), Arg0);
+
+  // Properties every freshly created `fneg %arg0` must have.
+  auto CheckNewFNeg = [Arg0](sandboxir::UnaryOperator *NewI,
+                             const char *ExpectedName) {
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+    EXPECT_EQ(NewI->getOperand(0), Arg0);
+#ifndef NDEBUG
+    EXPECT_EQ(NewI->getName(), ExpectedName);
+#endif // NDEBUG
+  };
+  // The constant operand used to trigger constant folding.
+  auto *FortyTwo = CopyFrom->getOperand(1);
+
+  {
+    // create() at an explicit (WhereIt, WhereBB) position.
+    auto *NewI =
+        cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
+            sandboxir::Instruction::Opcode::FNeg, Arg0,
+            /*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
+            "New1"));
+    CheckNewFNeg(NewI, "New1");
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+  }
+  {
+    // create() before a given instruction.
+    auto *NewI =
+        cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
+            sandboxir::Instruction::Opcode::FNeg, Arg0,
+            /*InsertBefore=*/Ret, Ctx, "New2"));
+    CheckNewFNeg(NewI, "New2");
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+  }
+  {
+    // create() appending to the end of a block.
+    auto *NewI =
+        cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
+            sandboxir::Instruction::Opcode::FNeg, Arg0,
+            /*InsertAtEnd=*/BB, Ctx, "New3"));
+    CheckNewFNeg(NewI, "New3");
+    EXPECT_EQ(NewI->getParent(), BB);
+    EXPECT_EQ(NewI->getNextNode(), nullptr);
+  }
+  {
+    // create() on a constant operand folds to a constant.
+    auto *NewV = sandboxir::UnaryOperator::create(
+        sandboxir::Instruction::Opcode::FNeg, FortyTwo,
+        /*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
+        "Folded");
+    EXPECT_TRUE(isa<sandboxir::Constant>(NewV));
+  }
+
+  {
+    // createWithCopiedFlags() at an explicit (WhereIt, WhereBB) position.
+    auto *NewI = cast<sandboxir::UnaryOperator>(
+        sandboxir::UnaryOperator::createWithCopiedFlags(
+            sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
+            /*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
+            "NewCopyFrom1"));
+    EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
+    CheckNewFNeg(NewI, "NewCopyFrom1");
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+  }
+  {
+    // createWithCopiedFlags() before a given instruction.
+    auto *NewI = cast<sandboxir::UnaryOperator>(
+        sandboxir::UnaryOperator::createWithCopiedFlags(
+            sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
+            /*InsertBefore=*/Ret, Ctx, "NewCopyFrom2"));
+    EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
+    CheckNewFNeg(NewI, "NewCopyFrom2");
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+  }
+  {
+    // createWithCopiedFlags() appending to the end of a block.
+    auto *NewI = cast<sandboxir::UnaryOperator>(
+        sandboxir::UnaryOperator::createWithCopiedFlags(
+            sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
+            /*InsertAtEnd=*/BB, Ctx, "NewCopyFrom3"));
+    EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
+    CheckNewFNeg(NewI, "NewCopyFrom3");
+    EXPECT_EQ(NewI->getParent(), BB);
+    EXPECT_EQ(NewI->getNextNode(), nullptr);
+  }
+  {
+    // createWithCopiedFlags() on a constant operand folds to a constant.
+    auto *NewV = sandboxir::UnaryOperator::createWithCopiedFlags(
+        sandboxir::Instruction::Opcode::FNeg, FortyTwo, CopyFrom,
+        /*InsertAtEnd=*/BB, Ctx, "Folded");
+    EXPECT_TRUE(isa<sandboxir::Constant>(NewV));
+  }
+}
+
 TEST_F(SandboxIRTest, BinaryOperator) {
   parseIR(C, R"IR(
 define void @foo(i8 %arg0, i8 %arg1, float %farg0, float %farg1) {


        


More information about the llvm-commits mailing list