[llvm] [SandboxIR] Implement UnaryOperator (PR #104509)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 15 14:14:30 PDT 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/104509

This patch implements sandboxir::UnaryOperator mirroring llvm::UnaryOperator.

>From 6a8383630adc103876362d287b24a2854cb27790 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 12 Aug 2024 11:17:46 -0700
Subject: [PATCH] [SandboxIR] Implement UnaryOperator

This patch implements sandboxir::UnaryOperator mirroring llvm::UnaryOperator.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h       |  46 +++++++
 .../llvm/SandboxIR/SandboxIRValues.def        |   5 +-
 llvm/lib/SandboxIR/SandboxIR.cpp              |  75 +++++++++++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 126 ++++++++++++++++++
 4 files changed, 251 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index b1af769f29af54..423dad854a91cb 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -130,6 +130,7 @@ class CastInst;
 class PtrToIntInst;
 class BitCastInst;
 class AllocaInst;
+class UnaryOperator;
 class BinaryOperator;
 class AtomicCmpXchgInst;
 
@@ -250,6 +251,7 @@ class Value {
   friend class InvokeInst;         // For getting `Val`.
   friend class CallBrInst;         // For getting `Val`.
   friend class GetElementPtrInst;  // For getting `Val`.
+  friend class UnaryOperator;      // For getting `Val`.
   friend class BinaryOperator;     // For getting `Val`.
   friend class AtomicCmpXchgInst;  // For getting `Val`.
   friend class AllocaInst;         // For getting `Val`.
@@ -632,6 +634,7 @@ class Instruction : public sandboxir::User {
   friend class InvokeInst;         // For getTopmostLLVMInstruction().
   friend class CallBrInst;         // For getTopmostLLVMInstruction().
   friend class GetElementPtrInst;  // For getTopmostLLVMInstruction().
+  friend class UnaryOperator;      // For getTopmostLLVMInstruction().
   friend class BinaryOperator;     // For getTopmostLLVMInstruction().
   friend class AtomicCmpXchgInst;  // For getTopmostLLVMInstruction().
   friend class AllocaInst;         // For getTopmostLLVMInstruction().
@@ -1435,6 +1438,47 @@ class GetElementPtrInst final
   // TODO: Add missing member functions.
 };
 
+class UnaryOperator : public UnaryInstruction {
+  static Opcode getUnaryOpcode(llvm::Instruction::UnaryOps UnOp) {
+    switch (UnOp) {
+    case llvm::Instruction::FNeg:
+      return Opcode::FNeg;
+    case llvm::Instruction::UnaryOpsEnd:
+      llvm_unreachable("Bad UnOp!");
+    }
+    llvm_unreachable("Unhandled UnOp!");
+  }
+  UnaryOperator(llvm::UnaryOperator *UO, Context &Ctx)
+      : UnaryInstruction(ClassID::UnOp, getUnaryOpcode(UO->getOpcode()), UO,
+                         Ctx) {}
+  friend Context; // for constructor.
+public:
+  static Value *create(Instruction::Opcode Op, Value *OpV, BBIterator WhereIt,
+                       BasicBlock *WhereBB, Context &Ctx,
+                       const Twine &Name = "");
+  static Value *create(Instruction::Opcode Op, Value *OpV,
+                       Instruction *InsertBefore, Context &Ctx,
+                       const Twine &Name = "");
+  static Value *create(Instruction::Opcode Op, Value *OpV,
+                       BasicBlock *InsertAtEnd, Context &Ctx,
+                       const Twine &Name = "");
+  static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                      Value *CopyFrom, BBIterator WhereIt,
+                                      BasicBlock *WhereBB, Context &Ctx,
+                                      const Twine &Name = "");
+  static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                      Value *CopyFrom,
+                                      Instruction *InsertBefore, Context &Ctx,
+                                      const Twine &Name = "");
+  static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                      Value *CopyFrom, BasicBlock *InsertAtEnd,
+                                      Context &Ctx, const Twine &Name = "");
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::UnOp;
+  }
+};
+
 class BinaryOperator : public SingleLLVMInstructionImpl<llvm::BinaryOperator> {
   static Opcode getBinOpOpcode(llvm::Instruction::BinaryOps BinOp) {
     switch (BinOp) {
@@ -1959,6 +2003,8 @@ class Context {
   friend CallBrInst; // For createCallBrInst()
   GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
   friend GetElementPtrInst; // For createGetElementPtrInst()
+  UnaryOperator *createUnaryOperator(llvm::UnaryOperator *I);
+  friend UnaryOperator; // For createUnaryOperator()
   BinaryOperator *createBinaryOperator(llvm::BinaryOperator *I);
   friend BinaryOperator; // For createBinaryOperator()
   AtomicCmpXchgInst *createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I);
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 43441ba26ec2cb..7332316c85026c 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -45,7 +45,10 @@ DEF_INSTR(Call,          OP(Call),          CallInst)
 DEF_INSTR(Invoke,        OP(Invoke),        InvokeInst)
 DEF_INSTR(CallBr,        OP(CallBr),        CallBrInst)
 DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
-DEF_INSTR(BinaryOperator, OPCODES( \
+DEF_INSTR(UnOp,          OPCODES( \
+                         OP(FNeg) \
+                         ),                 UnaryOperator)
+DEF_INSTR(BinaryOperator, OPCODES(\
                          OP(Add)  \
                          OP(FAdd) \
                          OP(Sub)  \
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 52df53892b4506..67262bdd0dea99 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1219,6 +1219,71 @@ static llvm::Instruction::CastOps getLLVMCastOp(Instruction::Opcode Opc) {
   }
 }
 
+/// \Returns the LLVM opcode that corresponds to \p Opc.
+static llvm::Instruction::UnaryOps getLLVMUnaryOp(Instruction::Opcode Opc) {
+  switch (Opc) {
+  case Instruction::Opcode::FNeg:
+    return static_cast<llvm::Instruction::UnaryOps>(llvm::Instruction::FNeg);
+  default:
+    llvm_unreachable("Not a unary op!");
+  }
+}
+
+Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
+                             BBIterator WhereIt, BasicBlock *WhereBB,
+                             Context &Ctx, const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  if (WhereIt == WhereBB->end())
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  else
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+  auto *NewLLVMV = Builder.CreateUnOp(getLLVMUnaryOp(Op), OpV->Val, Name);
+  if (auto *NewUnOpV = dyn_cast<llvm::UnaryOperator>(NewLLVMV)) {
+    return Ctx.createUnaryOperator(NewUnOpV);
+  }
+  assert(isa<llvm::Constant>(NewLLVMV) && "Expected constant");
+  return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewLLVMV));
+}
+
+Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
+                             Instruction *InsertBefore, Context &Ctx,
+                             const Twine &Name) {
+  return create(Op, OpV, InsertBefore->getIterator(), InsertBefore->getParent(),
+                Ctx, Name);
+}
+
+Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
+                             BasicBlock *InsertAtEnd, Context &Ctx,
+                             const Twine &Name) {
+  return create(Op, OpV, InsertAtEnd->end(), InsertAtEnd, Ctx, Name);
+}
+
+Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                            Value *CopyFrom, BBIterator WhereIt,
+                                            BasicBlock *WhereBB, Context &Ctx,
+                                            const Twine &Name) {
+  auto *NewV = create(Op, OpV, WhereIt, WhereBB, Ctx, Name);
+  if (auto *UnI = dyn_cast<llvm::UnaryOperator>(NewV->Val))
+    UnI->copyIRFlags(CopyFrom->Val);
+  return NewV;
+}
+
+Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                            Value *CopyFrom,
+                                            Instruction *InsertBefore,
+                                            Context &Ctx, const Twine &Name) {
+  return createWithCopiedFlags(Op, OpV, CopyFrom, InsertBefore->getIterator(),
+                               InsertBefore->getParent(), Ctx, Name);
+}
+
+Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                            Value *CopyFrom,
+                                            BasicBlock *InsertAtEnd,
+                                            Context &Ctx, const Twine &Name) {
+  return createWithCopiedFlags(Op, OpV, CopyFrom, InsertAtEnd->end(),
+                               InsertAtEnd, Ctx, Name);
+}
+
 /// \Returns the LLVM opcode that corresponds to \p Opc.
 static llvm::Instruction::BinaryOps getLLVMBinaryOp(Instruction::Opcode Opc) {
   switch (Opc) {
@@ -1729,6 +1794,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
         new GetElementPtrInst(LLVMGEP, *this));
     return It->second.get();
   }
+  case llvm::Instruction::FNeg: {
+    auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(LLVMV);
+    It->second = std::unique_ptr<UnaryOperator>(
+        new UnaryOperator(LLVMUnaryOperator, *this));
+    return It->second.get();
+  }
   case llvm::Instruction::Add:
   case llvm::Instruction::FAdd:
   case llvm::Instruction::Sub:
@@ -1875,6 +1946,10 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
       std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
   return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
 }
+UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
+  auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this));
+  return cast<UnaryOperator>(registerValue(std::move(NewPtr)));
+}
 BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
   auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
   return cast<BinaryOperator>(registerValue(std::move(NewPtr)));
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index f5e555ba73287b..3df335985aa705 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -1620,6 +1620,132 @@ define void @foo(i32 %arg, float %farg) {
   EXPECT_FALSE(FAdd->getFastMathFlags() != LLVMFAdd->getFastMathFlags());
 }
 
+TEST_F(SandboxIRTest, UnaryOperator) {
+  parseIR(C, R"IR(
+define void @foo(float %arg0) {
+  %fneg = fneg float %arg0
+  %copyfrom = fadd reassoc float %arg0, 42.0
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *Arg0 = F.getArg(0);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *I = cast<sandboxir::UnaryOperator>(&*It++);
+  auto *CopyFrom = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *Ret = &*It++;
+  EXPECT_EQ(I->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+  EXPECT_EQ(I->getOperand(0), Arg0);
+
+  {
+    // Check create() WhereIt, WhereBB.
+    auto *NewI =
+        cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
+            sandboxir::Instruction::Opcode::FNeg, Arg0,
+            /*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
+            "New1"));
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+    EXPECT_EQ(NewI->getOperand(0), Arg0);
+#ifndef NDEBUG
+    EXPECT_EQ(NewI->getName(), "New1");
+#endif // NDEBUG
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+  }
+  {
+    // Check create() InsertBefore.
+    auto *NewI =
+        cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
+            sandboxir::Instruction::Opcode::FNeg, Arg0,
+            /*InsertBefore=*/Ret, Ctx, "New2"));
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+    EXPECT_EQ(NewI->getOperand(0), Arg0);
+#ifndef NDEBUG
+    EXPECT_EQ(NewI->getName(), "New2");
+#endif // NDEBUG
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+  }
+  {
+    // Check create() InsertAtEnd.
+    auto *NewI =
+        cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
+            sandboxir::Instruction::Opcode::FNeg, Arg0,
+            /*InsertAtEnd=*/BB, Ctx, "New3"));
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+    EXPECT_EQ(NewI->getOperand(0), Arg0);
+#ifndef NDEBUG
+    EXPECT_EQ(NewI->getName(), "New3");
+#endif // NDEBUG
+    EXPECT_EQ(NewI->getParent(), BB);
+    EXPECT_EQ(NewI->getNextNode(), nullptr);
+  }
+  {
+    // Check create() when it gets folded.
+    auto *FortyTwo = CopyFrom->getOperand(1);
+    auto *NewV = sandboxir::UnaryOperator::create(
+        sandboxir::Instruction::Opcode::FNeg, FortyTwo,
+        /*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
+        "Folded");
+    EXPECT_TRUE(isa<sandboxir::Constant>(NewV));
+  }
+
+  {
+    // Check createWithCopiedFlags() WhereIt, WhereBB.
+    auto *NewI = cast<sandboxir::UnaryOperator>(
+        sandboxir::UnaryOperator::createWithCopiedFlags(
+            sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
+            /*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
+            "NewCopyFrom1"));
+    EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+    EXPECT_EQ(NewI->getOperand(0), Arg0);
+#ifndef NDEBUG
+    EXPECT_EQ(NewI->getName(), "NewCopyFrom1");
+#endif // NDEBUG
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+  }
+  {
+    // Check createWithCopiedFlags() InsertBefore.
+    auto *NewI = cast<sandboxir::UnaryOperator>(
+        sandboxir::UnaryOperator::createWithCopiedFlags(
+            sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
+            /*InsertBefore=*/Ret, Ctx, "NewCopyFrom2"));
+    EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+    EXPECT_EQ(NewI->getOperand(0), Arg0);
+#ifndef NDEBUG
+    EXPECT_EQ(NewI->getName(), "NewCopyFrom2");
+#endif // NDEBUG
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+  }
+  {
+    // Check createWithCopiedFlags() InsertAtEnd.
+    auto *NewI = cast<sandboxir::UnaryOperator>(
+        sandboxir::UnaryOperator::createWithCopiedFlags(
+            sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
+            /*InsertAtEnd=*/BB, Ctx, "NewCopyFrom3"));
+    EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+    EXPECT_EQ(NewI->getOperand(0), Arg0);
+#ifndef NDEBUG
+    EXPECT_EQ(NewI->getName(), "NewCopyFrom3");
+#endif // NDEBUG
+    EXPECT_EQ(NewI->getParent(), BB);
+    EXPECT_EQ(NewI->getNextNode(), nullptr);
+  }
+  {
+    // Check createWithCopiedFlags() when it gets folded.
+    auto *FortyTwo = CopyFrom->getOperand(1);
+    auto *NewV = sandboxir::UnaryOperator::createWithCopiedFlags(
+        sandboxir::Instruction::Opcode::FNeg, FortyTwo, CopyFrom,
+        /*InsertAtEnd=*/BB, Ctx, "Folded");
+    EXPECT_TRUE(isa<sandboxir::Constant>(NewV));
+  }
+}
+
 TEST_F(SandboxIRTest, BinaryOperator) {
   parseIR(C, R"IR(
 define void @foo(i8 %arg0, i8 %arg1, float %farg0, float %farg1) {



More information about the llvm-commits mailing list