[llvm] ec29660 - [SandboxIR] Implement UnaryOperator (#104509)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 15 16:55:05 PDT 2024
Author: vporpo
Date: 2024-08-15T16:55:01-07:00
New Revision: ec29660c44e5e73d3b78f4884f9178036563fb25
URL: https://github.com/llvm/llvm-project/commit/ec29660c44e5e73d3b78f4884f9178036563fb25
DIFF: https://github.com/llvm/llvm-project/commit/ec29660c44e5e73d3b78f4884f9178036563fb25.diff
LOG: [SandboxIR] Implement UnaryOperator (#104509)
This patch implements sandboxir::UnaryOperator mirroring
llvm::UnaryOperator.
Added:
Modified:
llvm/include/llvm/SandboxIR/SandboxIR.h
llvm/include/llvm/SandboxIR/SandboxIRValues.def
llvm/lib/SandboxIR/SandboxIR.cpp
llvm/unittests/SandboxIR/SandboxIRTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index b1af769f29af54..423dad854a91cb 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -130,6 +130,7 @@ class CastInst;
class PtrToIntInst;
class BitCastInst;
class AllocaInst;
+class UnaryOperator;
class BinaryOperator;
class AtomicCmpXchgInst;
@@ -250,6 +251,7 @@ class Value {
friend class InvokeInst; // For getting `Val`.
friend class CallBrInst; // For getting `Val`.
friend class GetElementPtrInst; // For getting `Val`.
+ friend class UnaryOperator; // For getting `Val`.
friend class BinaryOperator; // For getting `Val`.
friend class AtomicCmpXchgInst; // For getting `Val`.
friend class AllocaInst; // For getting `Val`.
@@ -632,6 +634,7 @@ class Instruction : public sandboxir::User {
friend class InvokeInst; // For getTopmostLLVMInstruction().
friend class CallBrInst; // For getTopmostLLVMInstruction().
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
+ friend class UnaryOperator; // For getTopmostLLVMInstruction().
friend class BinaryOperator; // For getTopmostLLVMInstruction().
friend class AtomicCmpXchgInst; // For getTopmostLLVMInstruction().
friend class AllocaInst; // For getTopmostLLVMInstruction().
@@ -1435,6 +1438,47 @@ class GetElementPtrInst final
// TODO: Add missing member functions.
};
+/// SandboxIR counterpart of llvm::UnaryOperator. FNeg is the only unary
+/// opcode that is currently mapped.
+class UnaryOperator : public UnaryInstruction {
+  /// \Returns the SandboxIR opcode matching the LLVM unary opcode \p LLVMOp.
+  static Opcode getUnaryOpcode(llvm::Instruction::UnaryOps LLVMOp) {
+    switch (LLVMOp) {
+    case llvm::Instruction::FNeg:
+      return Opcode::FNeg;
+    case llvm::Instruction::UnaryOpsEnd:
+      // Not a real opcode; reaching this is a bug.
+      llvm_unreachable("Bad UnOp!");
+    }
+    llvm_unreachable("Unhandled UnOp!");
+  }
+  UnaryOperator(llvm::UnaryOperator *LLVMUO, Context &Ctx)
+      : UnaryInstruction(ClassID::UnOp, getUnaryOpcode(LLVMUO->getOpcode()),
+                         LLVMUO, Ctx) {}
+  friend Context; // for constructor.
+
+public:
+  /// Creates the unary operation \p Op on \p OpV, inserted before \p WhereIt
+  /// in \p WhereBB (or at the end of \p WhereBB when \p WhereIt is end()).
+  /// \Returns the new UnaryOperator, or a Constant when the LLVM IR builder
+  /// folds the operation; hence the Value return type.
+  static Value *create(Instruction::Opcode Op, Value *OpV, BBIterator WhereIt,
+                       BasicBlock *WhereBB, Context &Ctx,
+                       const Twine &Name = "");
+  /// Same as above, inserting right before \p InsertBefore.
+  static Value *create(Instruction::Opcode Op, Value *OpV,
+                       Instruction *InsertBefore, Context &Ctx,
+                       const Twine &Name = "");
+  /// Same as above, appending at the end of \p InsertAtEnd.
+  static Value *create(Instruction::Opcode Op, Value *OpV,
+                       BasicBlock *InsertAtEnd, Context &Ctx,
+                       const Twine &Name = "");
+  /// Like create(), but when a new instruction is produced (not a folded
+  /// constant) its IR flags are copied from \p CopyFrom.
+  static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                      Value *CopyFrom, BBIterator WhereIt,
+                                      BasicBlock *WhereBB, Context &Ctx,
+                                      const Twine &Name = "");
+  /// Same as above, inserting right before \p InsertBefore.
+  static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                      Value *CopyFrom,
+                                      Instruction *InsertBefore, Context &Ctx,
+                                      const Twine &Name = "");
+  /// Same as above, appending at the end of \p InsertAtEnd.
+  static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                      Value *CopyFrom, BasicBlock *InsertAtEnd,
+                                      Context &Ctx, const Twine &Name = "");
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::UnOp;
+  }
+};
+
class BinaryOperator : public SingleLLVMInstructionImpl<llvm::BinaryOperator> {
static Opcode getBinOpOpcode(llvm::Instruction::BinaryOps BinOp) {
switch (BinOp) {
@@ -1959,6 +2003,8 @@ class Context {
friend CallBrInst; // For createCallBrInst()
GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
friend GetElementPtrInst; // For createGetElementPtrInst()
+ UnaryOperator *createUnaryOperator(llvm::UnaryOperator *I);
+ friend UnaryOperator; // For createUnaryOperator()
BinaryOperator *createBinaryOperator(llvm::BinaryOperator *I);
friend BinaryOperator; // For createBinaryOperator()
AtomicCmpXchgInst *createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I);
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 43441ba26ec2cb..7332316c85026c 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -45,7 +45,10 @@ DEF_INSTR(Call, OP(Call), CallInst)
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
-DEF_INSTR(BinaryOperator, OPCODES( \
+DEF_INSTR(UnOp, OPCODES( \
+ OP(FNeg) \
+ ), UnaryOperator)
+DEF_INSTR(BinaryOperator, OPCODES(\
OP(Add) \
OP(FAdd) \
OP(Sub) \
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 52df53892b4506..67262bdd0dea99 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1219,6 +1219,71 @@ static llvm::Instruction::CastOps getLLVMCastOp(Instruction::Opcode Opc) {
}
}
+/// \Returns the llvm::Instruction::UnaryOps value matching the SandboxIR
+/// opcode \p Opc. FNeg is the only unary opcode; any other value is a bug.
+static llvm::Instruction::UnaryOps getLLVMUnaryOp(Instruction::Opcode Opc) {
+  if (Opc != Instruction::Opcode::FNeg)
+    llvm_unreachable("Not a unary op!");
+  return llvm::Instruction::FNeg;
+}
+
+Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
+                             BBIterator WhereIt, BasicBlock *WhereBB,
+                             Context &Ctx, const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  // Position the LLVM builder before the instruction at WhereIt, or at the
+  // end of WhereBB when WhereIt is the end iterator.
+  if (WhereIt != WhereBB->end())
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+  else
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  llvm::Value *LLVMResult =
+      Builder.CreateUnOp(getLLVMUnaryOp(Op), OpV->Val, Name);
+  // The builder may fold the operation instead of emitting an instruction.
+  if (auto *LLVMUO = dyn_cast<llvm::UnaryOperator>(LLVMResult))
+    return Ctx.createUnaryOperator(LLVMUO);
+  assert(isa<llvm::Constant>(LLVMResult) && "Expected constant");
+  return Ctx.getOrCreateConstant(cast<llvm::Constant>(LLVMResult));
+}
+
+Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
+                             Instruction *InsertBefore, Context &Ctx,
+                             const Twine &Name) {
+  // Insert right before InsertBefore, inside its parent block.
+  auto *ParentBB = InsertBefore->getParent();
+  auto WhereIt = InsertBefore->getIterator();
+  return create(Op, OpV, WhereIt, ParentBB, Ctx, Name);
+}
+
+/// Appends the new unary operation at the end of \p InsertAtEnd. The
+/// parameter name matches the header declaration; the instruction is placed
+/// at end(), not "after" any particular instruction.
+Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
+                             BasicBlock *InsertAtEnd, Context &Ctx,
+                             const Twine &Name) {
+  return create(Op, OpV, InsertAtEnd->end(), InsertAtEnd, Ctx, Name);
+}
+
+Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                            Value *CopyFrom, BBIterator WhereIt,
+                                            BasicBlock *WhereBB, Context &Ctx,
+                                            const Twine &Name) {
+  Value *Result = create(Op, OpV, WhereIt, WhereBB, Ctx, Name);
+  // When create() folded to a constant there is no instruction to copy the
+  // flags onto, so only a real llvm::UnaryOperator is updated.
+  auto *LLVMUO = dyn_cast<llvm::UnaryOperator>(Result->Val);
+  if (LLVMUO != nullptr)
+    LLVMUO->copyIRFlags(CopyFrom->Val);
+  return Result;
+}
+
+Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                            Value *CopyFrom,
+                                            Instruction *InsertBefore,
+                                            Context &Ctx, const Twine &Name) {
+  // Insert right before InsertBefore, inside its parent block.
+  auto *ParentBB = InsertBefore->getParent();
+  auto WhereIt = InsertBefore->getIterator();
+  return createWithCopiedFlags(Op, OpV, CopyFrom, WhereIt, ParentBB, Ctx,
+                               Name);
+}
+
+Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
+                                            Value *CopyFrom,
+                                            BasicBlock *InsertAtEnd,
+                                            Context &Ctx, const Twine &Name) {
+  // Appending to the block is expressed as inserting at its end() iterator.
+  auto WhereIt = InsertAtEnd->end();
+  return createWithCopiedFlags(Op, OpV, CopyFrom, WhereIt, InsertAtEnd, Ctx,
+                               Name);
+}
+
/// \Returns the LLVM opcode that corresponds to \p Opc.
static llvm::Instruction::BinaryOps getLLVMBinaryOp(Instruction::Opcode Opc) {
switch (Opc) {
@@ -1729,6 +1794,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
new GetElementPtrInst(LLVMGEP, *this));
return It->second.get();
}
+ case llvm::Instruction::FNeg: {
+ auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(LLVMV);
+ It->second = std::unique_ptr<UnaryOperator>(
+ new UnaryOperator(LLVMUnaryOperator, *this));
+ return It->second.get();
+ }
case llvm::Instruction::Add:
case llvm::Instruction::FAdd:
case llvm::Instruction::Sub:
@@ -1875,6 +1946,10 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
}
+/// Wraps \p I in a new sandboxir::UnaryOperator and hands the owning pointer
+/// to registerValue().
+UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
+  // The constructor is private with Context as a friend, so construct with
+  // `new` here rather than through std::make_unique.
+  std::unique_ptr<UnaryOperator> Owned(new UnaryOperator(I, *this));
+  auto *Registered = registerValue(std::move(Owned));
+  return cast<UnaryOperator>(Registered);
+}
BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
return cast<BinaryOperator>(registerValue(std::move(NewPtr)));
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index f5e555ba73287b..3df335985aa705 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -1620,6 +1620,132 @@ define void @foo(i32 %arg, float %farg) {
EXPECT_FALSE(FAdd->getFastMathFlags() != LLVMFAdd->getFastMathFlags());
}
+// Exercises sandboxir::UnaryOperator: the opcode/operand accessors, every
+// create() and createWithCopiedFlags() overload (insertion point and name),
+// copying of the `reassoc` flag from %copyfrom, and the case where the
+// operand is a constant, so the result is a sandboxir::Constant rather than
+// a new instruction.
+TEST_F(SandboxIRTest, UnaryOperator) {
+ parseIR(C, R"IR(
+define void @foo(float %arg0) {
+ %fneg = fneg float %arg0
+ %copyfrom = fadd reassoc float %arg0, 42.0
+ ret void
+}
+)IR");
+ Function &LLVMF = *M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+
+ auto &F = *Ctx.createFunction(&LLVMF);
+ auto *Arg0 = F.getArg(0);
+ auto *BB = &*F.begin();
+ // Walk the entry block in order: %fneg, %copyfrom, ret.
+ auto It = BB->begin();
+ auto *I = cast<sandboxir::UnaryOperator>(&*It++);
+ auto *CopyFrom = cast<sandboxir::BinaryOperator>(&*It++);
+ auto *Ret = &*It++;
+ EXPECT_EQ(I->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+ EXPECT_EQ(I->getOperand(0), Arg0);
+
+ // NOTE: the sub-blocks below all insert into BB, so their order matters
+ // for the getNextNode() expectations.
+ {
+ // Check create() WhereIt, WhereBB.
+ auto *NewI =
+ cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
+ sandboxir::Instruction::Opcode::FNeg, Arg0,
+ /*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
+ "New1"));
+ EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+ EXPECT_EQ(NewI->getOperand(0), Arg0);
+#ifndef NDEBUG
+ EXPECT_EQ(NewI->getName(), "New1");
+#endif // NDEBUG
+ EXPECT_EQ(NewI->getNextNode(), Ret);
+ }
+ {
+ // Check create() InsertBefore.
+ auto *NewI =
+ cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
+ sandboxir::Instruction::Opcode::FNeg, Arg0,
+ /*InsertBefore=*/Ret, Ctx, "New2"));
+ EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+ EXPECT_EQ(NewI->getOperand(0), Arg0);
+#ifndef NDEBUG
+ EXPECT_EQ(NewI->getName(), "New2");
+#endif // NDEBUG
+ EXPECT_EQ(NewI->getNextNode(), Ret);
+ }
+ {
+ // Check create() InsertAtEnd.
+ auto *NewI =
+ cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
+ sandboxir::Instruction::Opcode::FNeg, Arg0,
+ /*InsertAtEnd=*/BB, Ctx, "New3"));
+ EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+ EXPECT_EQ(NewI->getOperand(0), Arg0);
+#ifndef NDEBUG
+ EXPECT_EQ(NewI->getName(), "New3");
+#endif // NDEBUG
+ // Appended at the very end of BB (after `ret`), so nothing follows it.
+ EXPECT_EQ(NewI->getParent(), BB);
+ EXPECT_EQ(NewI->getNextNode(), nullptr);
+ }
+ {
+ // Check create() when it gets folded.
+ // Operand 1 of %copyfrom is the constant 42.0, so the fneg is folded.
+ auto *FortyTwo = CopyFrom->getOperand(1);
+ auto *NewV = sandboxir::UnaryOperator::create(
+ sandboxir::Instruction::Opcode::FNeg, FortyTwo,
+ /*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
+ "Folded");
+ EXPECT_TRUE(isa<sandboxir::Constant>(NewV));
+ }
+
+ {
+ // Check createWithCopiedFlags() WhereIt, WhereBB.
+ // %copyfrom carries `reassoc`, so the new instruction must have it too.
+ auto *NewI = cast<sandboxir::UnaryOperator>(
+ sandboxir::UnaryOperator::createWithCopiedFlags(
+ sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
+ /*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
+ "NewCopyFrom1"));
+ EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
+ EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+ EXPECT_EQ(NewI->getOperand(0), Arg0);
+#ifndef NDEBUG
+ EXPECT_EQ(NewI->getName(), "NewCopyFrom1");
+#endif // NDEBUG
+ EXPECT_EQ(NewI->getNextNode(), Ret);
+ }
+ {
+ // Check createWithCopiedFlags() InsertBefore.
+ auto *NewI = cast<sandboxir::UnaryOperator>(
+ sandboxir::UnaryOperator::createWithCopiedFlags(
+ sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
+ /*InsertBefore=*/Ret, Ctx, "NewCopyFrom2"));
+ EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
+ EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+ EXPECT_EQ(NewI->getOperand(0), Arg0);
+#ifndef NDEBUG
+ EXPECT_EQ(NewI->getName(), "NewCopyFrom2");
+#endif // NDEBUG
+ EXPECT_EQ(NewI->getNextNode(), Ret);
+ }
+ {
+ // Check createWithCopiedFlags() InsertAtEnd.
+ auto *NewI = cast<sandboxir::UnaryOperator>(
+ sandboxir::UnaryOperator::createWithCopiedFlags(
+ sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
+ /*InsertAtEnd=*/BB, Ctx, "NewCopyFrom3"));
+ EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
+ EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
+ EXPECT_EQ(NewI->getOperand(0), Arg0);
+#ifndef NDEBUG
+ EXPECT_EQ(NewI->getName(), "NewCopyFrom3");
+#endif // NDEBUG
+ // Appended at the very end of BB, so nothing follows it.
+ EXPECT_EQ(NewI->getParent(), BB);
+ EXPECT_EQ(NewI->getNextNode(), nullptr);
+ }
+ {
+ // Check createWithCopiedFlags() when it gets folded.
+ // A folded result is a Constant, so there is no instruction to copy flags to.
+ auto *FortyTwo = CopyFrom->getOperand(1);
+ auto *NewV = sandboxir::UnaryOperator::createWithCopiedFlags(
+ sandboxir::Instruction::Opcode::FNeg, FortyTwo, CopyFrom,
+ /*InsertAtEnd=*/BB, Ctx, "Folded");
+ EXPECT_TRUE(isa<sandboxir::Constant>(NewV));
+ }
+}
+
TEST_F(SandboxIRTest, BinaryOperator) {
parseIR(C, R"IR(
define void @foo(i8 %arg0, i8 %arg1, float %farg0, float %farg1) {
More information about the llvm-commits
mailing list