[llvm] [SandboxIR] Implement CmpInst, FCmpInst, and ICmpInst (PR #106301)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 3 17:43:39 PDT 2024
https://github.com/Sterling-Augustine updated https://github.com/llvm/llvm-project/pull/106301
>From e91c6ce666144c102e889f51c13fa13bacead8a9 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Thu, 8 Aug 2024 23:08:25 +0000
Subject: [PATCH 1/7] [SandboxIR] Implement CmpInst, FCmpInst, and ICmpInst
---
llvm/include/llvm/SandboxIR/SandboxIR.h | 212 +++++++++++++++++-
.../llvm/SandboxIR/SandboxIRValues.def | 3 +
llvm/include/llvm/SandboxIR/Tracker.h | 28 +++
llvm/lib/SandboxIR/SandboxIR.cpp | 62 +++++
llvm/lib/SandboxIR/Tracker.cpp | 18 ++
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 127 +++++++++++
llvm/unittests/SandboxIR/TrackerTest.cpp | 43 ++++
7 files changed, 488 insertions(+), 5 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index b7bdf9acd2ef45..b47b6ee512471e 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -54,11 +54,17 @@
// | |
// | +- ZExtInst
// |
-// +- CallBase -----------+- CallBrInst
-// | |
-// +- CmpInst +- CallInst
-// | |
-// +- ExtractElementInst +- InvokeInst
+// +- CallBase --------+- CallBrInst
+// | |
+// | +- CallInst
+// | |
+// | +- InvokeInst
+// |
+// +- CmpInst ---------+- ICmpInst
+// | |
+// | +- FCmpInst
+// |
+// +- ExtractElementInst
// |
// +- GetElementPtrInst
// |
@@ -150,6 +156,9 @@ class BinaryOperator;
class PossiblyDisjointInst;
class AtomicRMWInst;
class AtomicCmpXchgInst;
+class CmpInst;
+class ICmpInst;
+class FCmpInst;
/// Iterator for the `Use` edges of a User's operands.
/// \Returns the operand `Use` when dereferenced.
@@ -294,6 +303,9 @@ class Value {
friend class PHINode; // For getting `Val`.
friend class UnreachableInst; // For getting `Val`.
friend class CatchSwitchAddHandler; // For `Val`.
+ friend class CmpInst; // For getting `Val`.
+ friend class ICmpInst; // For getting `Val`.
+ friend class FCmpInst; // For getting `Val`.
/// All values point to the context.
Context &Ctx;
@@ -730,6 +742,9 @@ class Instruction : public sandboxir::User {
friend class CastInst; // For getTopmostLLVMInstruction().
friend class PHINode; // For getTopmostLLVMInstruction().
friend class UnreachableInst; // For getTopmostLLVMInstruction().
+ friend class CmpInst; // For getTopmostLLVMInstruction().
+ friend class ICmpInst; // For getTopmostLLVMInstruction().
+ friend class FCmpInst; // For getTopmostLLVMInstruction().
/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
@@ -2974,6 +2989,187 @@ class PHINode final : public SingleLLVMInstructionImpl<llvm::PHINode> {
// uint32_t ToIdx = 0)
};
+// Wraps a static function that takes a single Predicate parameter.
+// Expects LLVMValType to name the wrapped llvm class in the enclosing scope.
+#define WRAP_STATIC_PREDICATE(FunctionName) \
+ static auto FunctionName(Predicate P) { return LLVMValType::FunctionName(P); }
+// Wraps a const member function that takes no parameters. The wrapper is const
+// so it is usable on const sandboxir objects; it forwards through Val.
+#define WRAP_MEMBER(FunctionName) \
+ auto FunctionName() const { return cast<LLVMValType>(Val)->FunctionName(); }
+// Wraps both--a common idiom in the CmpInst classes
+#define WRAP_BOTH(FunctionName) \
+ WRAP_STATIC_PREDICATE(FunctionName) \
+ WRAP_MEMBER(FunctionName)
+
+class CmpInst : public Instruction {
+protected:
+ using LLVMValType = llvm::CmpInst;
+ /// Use Context::createCmpInst(). Don't call the
+ /// constructor directly.
+ CmpInst(llvm::CmpInst *CI, Context &Ctx, ClassID Id = ClassID::Cmp,
+ Opcode Opc = Opcode::Cmp)
+ : Instruction(Id, Opc, CI, Ctx) {}
+ friend Context; // for CmpInst()
+ Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+ return getOperandUseDefault(OpIdx, Verify);
+ }
+ SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
+ return {cast<llvm::Instruction>(Val)};
+ }
+ static Value *createCommon(Value *Cond, Value *True, Value *False,
+ const Twine &Name, IRBuilder<> &Builder,
+ Context &Ctx);
+
+public:
+ using Predicate = llvm::CmpInst::Predicate;
+ using PredicateField = llvm::CmpInst::PredicateField;
+ using OtherOps = llvm::Instruction::OtherOps;
+
+ unsigned getUseOperandNo(const Use &Use) const final {
+ return getUseOperandNoDefault(Use);
+ }
+ unsigned getNumOfIRInstrs() const final { return 1u; }
+ static CmpInst *create(OtherOps Op, Predicate Pred, Value *S1, Value *S2,
+ const Twine &Name = "",
+ InsertPosition InsertBefore = nullptr);
+ static CmpInst *CreateWithCopiedFlags(OtherOps Op, Predicate Pred, Value *S1,
+ Value *S2,
+ const Instruction *FlagsSource,
+ const Twine &Name = "",
+ InsertPosition InsertBefore = nullptr);
+ OtherOps getOpcode() const {
+ return static_cast<OtherOps>(Instruction::getOpcode());
+ }
+ void setPredicate(Predicate P);
+ void swapOperands();
+
+ WRAP_MEMBER(getPredicate);
+ WRAP_BOTH(isFPPredicate);
+ WRAP_BOTH(isIntPredicate);
+ WRAP_STATIC_PREDICATE(getPredicateName);
+ WRAP_BOTH(getInversePredicate);
+ WRAP_BOTH(getOrderedPredicate);
+ WRAP_BOTH(getUnorderedPredicate);
+ WRAP_BOTH(getSwappedPredicate);
+ WRAP_BOTH(isStrictPredicate);
+ WRAP_BOTH(isNonStrictPredicate);
+ WRAP_BOTH(getStrictPredicate);
+ WRAP_BOTH(getNonStrictPredicate);
+ WRAP_BOTH(getFlippedStrictnessPredicate);
+ WRAP_MEMBER(isCommutative);
+ WRAP_BOTH(isEquality);
+ WRAP_BOTH(isRelational);
+ WRAP_BOTH(isSigned);
+ WRAP_BOTH(getSignedPredicate);
+ WRAP_BOTH(getUnsignedPredicate);
+ WRAP_BOTH(getFlippedSignednessPredicate);
+ WRAP_BOTH(isTrueWhenEqual);
+ WRAP_BOTH(isFalseWhenEqual);
+ WRAP_BOTH(isUnsigned);
+ WRAP_STATIC_PREDICATE(isOrdered);
+ WRAP_STATIC_PREDICATE(isUnordered);
+
+ static bool isImpliedTrueByMatchingCmp(Predicate Pred1, Predicate Pred2) {
+ return llvm::CmpInst::isImpliedTrueByMatchingCmp(Pred1, Pred2);
+ }
+ static bool isImpliedFalseByMatchingCmp(Predicate Pred1, Predicate Pred2) {
+ return llvm::CmpInst::isImpliedFalseByMatchingCmp(Pred1, Pred2);
+ }
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const Instruction *From) {
+ return isa<ICmpInst>(From) || isa<FCmpInst>(From);
+ }
+ static bool classof(const Value *From) {
+ return isa<Instruction>(From) && classof(cast<Instruction>(From));
+ }
+ /// Create a result type for fcmp/icmp
+ static Type *makeCmpResultType(Type *opnd_type) {
+ if (VectorType *vt = dyn_cast<VectorType>(opnd_type)) {
+ return VectorType::get(Type::getInt1Ty(opnd_type->getContext()),
+ vt->getElementCount());
+ }
+ return Type::getInt1Ty(opnd_type->getContext());
+ }
+
+#ifndef NDEBUG
+ void verify() const final {
+ assert(isa<LLVMValType>(Val) && "Expected CmpInst!");
+ }
+ void dumpOS(raw_ostream &OS) const override;
+ LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
+class ICmpInst : public CmpInst {
+ /// Use Context::createICmpInst(). Don't call the
+ /// constructor directly.
+ ICmpInst(llvm::ICmpInst *CI, Context &Ctx)
+ : CmpInst(CI, Ctx, ClassID::ICmp, Opcode::ICmp) {}
+ friend class Context; // For constructor.
+ using LLVMValType = llvm::ICmpInst;
+
+public:
+ void swapOperands();
+
+ WRAP_BOTH(getSignedPredicate);
+ WRAP_BOTH(getUnsignedPredicate);
+ WRAP_BOTH(isEquality);
+ WRAP_MEMBER(isCommutative);
+ WRAP_MEMBER(isRelational);
+ WRAP_STATIC_PREDICATE(isGT);
+ WRAP_STATIC_PREDICATE(isLT);
+ WRAP_STATIC_PREDICATE(isGE);
+ WRAP_STATIC_PREDICATE(isLE);
+
+ static auto predicates() { return llvm::ICmpInst::predicates(); }
+ static bool compare(const APInt &LHS, const APInt &RHS,
+ ICmpInst::Predicate Pred) {
+ return llvm::ICmpInst::compare(LHS, RHS, Pred);
+ }
+
+ static bool classof(const Instruction *From) {
+ return From->getSubclassID() == ClassID::ICmp;
+ }
+ static bool classof(const Value *From) {
+ return isa<Instruction>(From) && classof(cast<Instruction>(From));
+ }
+};
+
+class FCmpInst : public CmpInst {
+ /// Use Context::createFCmpInst(). Don't call the
+ /// constructor directly.
+ FCmpInst(llvm::FCmpInst *CI, Context &Ctx)
+ : CmpInst(CI, Ctx, ClassID::FCmp, Opcode::FCmp) {}
+ friend class Context; // For constructor.
+ using LLVMValType = llvm::FCmpInst;
+
+public:
+ void swapOperands();
+
+ WRAP_BOTH(isEquality);
+ WRAP_MEMBER(isCommutative);
+ WRAP_MEMBER(isRelational);
+
+ static auto predicates() { return llvm::FCmpInst::predicates(); }
+ static bool compare(const APFloat &LHS, const APFloat &RHS,
+ FCmpInst::Predicate Pred) {
+ return llvm::FCmpInst::compare(LHS, RHS, Pred);
+ }
+
+ static bool classof(const Instruction *From) {
+ return From->getSubclassID() == ClassID::FCmp;
+ }
+ static bool classof(const Value *From) {
+ return From->getSubclassID() == ClassID::FCmp;
+ }
+};
+
+#undef WRAP_STATIC_PREDICATE
+#undef WRAP_MEMBER
+#undef WRAP_BOTH
+
/// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
/// an OpaqueInstr.
class OpaqueInst : public SingleLLVMInstructionImpl<llvm::Instruction> {
@@ -3101,6 +3297,12 @@ class Context {
friend PHINode; // For createPHINode()
UnreachableInst *createUnreachableInst(llvm::UnreachableInst *UI);
friend UnreachableInst; // For createUnreachableInst()
+ CmpInst *createCmpInst(llvm::CmpInst *I);
+ friend CmpInst; // For createCmpInst()
+ ICmpInst *createICmpInst(llvm::ICmpInst *I);
+ friend ICmpInst; // For createICmpInst()
+ FCmpInst *createFCmpInst(llvm::FCmpInst *I);
+ friend FCmpInst; // For createFCmpInst()
public:
Context(LLVMContext &LLVMCtx)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 00c1a6333c8ec4..ac0b942a7cf67e 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -107,6 +107,9 @@ DEF_INSTR(Cast, OPCODES(\
), CastInst)
DEF_INSTR(PHI, OP(PHI), PHINode)
DEF_INSTR(Unreachable, OP(Unreachable), UnreachableInst)
+DEF_INSTR(Cmp, OP(Cmp), CmpInst)
+DEF_INSTR(ICmp, OP(ICmp), FCmpInst)
+DEF_INSTR(FCmp, OP(FCmp), ICmpInst)
// clang-format on
#ifdef DEF_VALUE
diff --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h
index c8a9e99a34341d..d3d0bf3dfe7803 100644
--- a/llvm/include/llvm/SandboxIR/Tracker.h
+++ b/llvm/include/llvm/SandboxIR/Tracker.h
@@ -63,6 +63,7 @@ class CatchSwitchInst;
class SwitchInst;
class ConstantInt;
class ShuffleVectorInst;
+class CmpInst;
/// The base class for IR Change classes.
class IRChangeBase {
@@ -130,6 +131,33 @@ class PHIAddIncoming : public IRChangeBase {
#endif
};
+class CmpSetPredicate : public IRChangeBase {
+ CmpInst *Cmp;
+ llvm::CmpInst::Predicate OldP;
+
+public:
+ CmpSetPredicate(CmpInst *Cmp);
+ void revert(Tracker &Tracker) final;
+ void accept() final {}
+#ifndef NDEBUG
+ void dump(raw_ostream &OS) const final { OS << "CmpSetPredicate"; }
+ LLVM_DUMP_METHOD void dump() const final;
+#endif
+};
+
+class CmpSwapOperands : public IRChangeBase {
+ CmpInst *Cmp;
+
+public:
+ CmpSwapOperands(CmpInst *Cmp);
+ void revert(Tracker &Tracker) final;
+ void accept() final {}
+#ifndef NDEBUG
+ void dump(raw_ostream &OS) const final { OS << "CmpSwapOperands"; }
+ LLVM_DUMP_METHOD void dump() const final;
+#endif
+};
+
/// Tracks swapping a Use with another Use.
class UseSwap : public IRChangeBase {
Use ThisUse;
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index b75424909f0835..29a92aa82572b8 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2462,6 +2462,16 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
return It->second.get();
}
+ case llvm::Instruction::ICmp: {
+ auto *LLVMICmp = cast<llvm::ICmpInst>(LLVMV);
+ It->second = std::unique_ptr<ICmpInst>(new ICmpInst(LLVMICmp, *this));
+ return It->second.get();
+ }
+ case llvm::Instruction::FCmp: {
+ auto *LLVMFCmp = cast<llvm::FCmpInst>(LLVMV);
+ It->second = std::unique_ptr<FCmpInst>(new FCmpInst(LLVMFCmp, *this));
+ return It->second.get();
+ }
case llvm::Instruction::Unreachable: {
auto *LLVMUnreachable = cast<llvm::UnreachableInst>(LLVMV);
It->second = std::unique_ptr<UnreachableInst>(
@@ -2637,6 +2647,58 @@ PHINode *Context::createPHINode(llvm::PHINode *I) {
auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
return cast<PHINode>(registerValue(std::move(NewPtr)));
}
+CmpInst *Context::createCmpInst(llvm::CmpInst *I) {
+ auto NewPtr = std::unique_ptr<CmpInst>(new CmpInst(I, *this));
+ return cast<CmpInst>(registerValue(std::move(NewPtr)));
+}
+ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) {
+ auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this));
+ return cast<ICmpInst>(registerValue(std::move(NewPtr)));
+}
+FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
+ auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this));
+ return cast<FCmpInst>(registerValue(std::move(NewPtr)));
+}
+
+void CmpInst::setPredicate(Predicate P) {
+ auto &Tracker = Ctx.getTracker();
+ if (Tracker.isTracking())
+ Tracker.track(std::make_unique<CmpSetPredicate>(this));
+ cast<llvm::CmpInst>(Val)->setPredicate(P);
+}
+
+void CmpInst::swapOperands() {
+ if (ICmpInst *IC = dyn_cast<ICmpInst>(this))
+ IC->swapOperands();
+ else
+ cast<FCmpInst>(this)->swapOperands();
+}
+
+void ICmpInst::swapOperands() {
+ auto &Tracker = Ctx.getTracker();
+ if (Tracker.isTracking())
+ Tracker.track(std::make_unique<CmpSwapOperands>(this));
+ cast<llvm::ICmpInst>(Val)->swapOperands();
+}
+
+void FCmpInst::swapOperands() {
+ auto &Tracker = Ctx.getTracker();
+ if (Tracker.isTracking())
+ Tracker.track(std::make_unique<CmpSwapOperands>(this));
+ cast<llvm::FCmpInst>(Val)->swapOperands();
+}
+
+#ifndef NDEBUG
+void CmpInst::dumpOS(raw_ostream &OS) const {
+ dumpCommonPrefix(OS);
+ dumpCommonSuffix(OS);
+}
+
+void CmpInst::dump() const {
+ dumpOS(dbgs());
+ dbgs() << "\n";
+}
+#endif // NDEBUG
Value *Context::getValue(llvm::Value *V) const {
auto It = LLVMValueToValueMap.find(V);
diff --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp
index 953d4bd51353a9..9a824d5211a407 100644
--- a/llvm/lib/SandboxIR/Tracker.cpp
+++ b/llvm/lib/SandboxIR/Tracker.cpp
@@ -248,6 +248,24 @@ void ShuffleVectorSetMask::dump() const {
}
#endif
+CmpSetPredicate::CmpSetPredicate(CmpInst *Cmp)
+ : IRChangeBase(), Cmp(Cmp), OldP(Cmp->getPredicate()) {}
+void CmpSetPredicate::revert(Tracker &Tracker) { Cmp->setPredicate(OldP); }
+
+CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {}
+void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); }
+
+#ifndef NDEBUG
+void CmpSetPredicate::dump() const {
+ dump(dbgs());
+ dbgs() << "\n";
+}
+void CmpSwapOperands::dump() const {
+ dump(dbgs());
+ dbgs() << "\n";
+}
+#endif
+
void Tracker::save() { State = TrackerState::Record; }
void Tracker::revert() {
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index bc3fddf9e163dc..5bee722d2b953d 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -4188,6 +4188,133 @@ define void @foo(i32 %arg) {
EXPECT_EQ(NewPHI->getNumIncomingValues(), PHI->getNumIncomingValues());
}
+void checkSwapOperands(sandboxir::Context &Ctx, llvm::sandboxir::CmpInst *SBCmp,
+ llvm::CmpInst *LLVMCmp) {
+ auto OrigOp0 = SBCmp->getOperand(0);
+ auto OrigOp1 = SBCmp->getOperand(1);
+ EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp0);
+ EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp1);
+ // This checks the dispatch mechanism in CmpInst, as well as
+ // the specific implementations.
+ SBCmp->swapOperands();
+ EXPECT_NE(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp0);
+ EXPECT_NE(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp1);
+ EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp0);
+ EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp1);
+ // Undo it to keep the rest of the test consistent
+ SBCmp->swapOperands();
+}
+
+void checkCommonPredicates(sandboxir::CmpInst *SBCmp, llvm::CmpInst *LLVMCmp) {
+ // Check proper creation
+ auto SBPred = SBCmp->getPredicate();
+ auto LLVMPred = LLVMCmp->getPredicate();
+ EXPECT_EQ(SBPred, LLVMPred);
+ // Check setPredicate
+ SBCmp->setPredicate(llvm::CmpInst::FCMP_FALSE);
+ EXPECT_EQ(LLVMCmp->getPredicate(), llvm::CmpInst::FCMP_FALSE);
+ SBCmp->setPredicate(SBPred);
+ EXPECT_EQ(LLVMCmp->getPredicate(), SBPred);
+ // Ensure the accessors properly forward to the underlying implementation
+ EXPECT_STREQ(sandboxir::CmpInst::getPredicateName(SBPred).data(),
+ llvm::CmpInst::getPredicateName(LLVMPred).data());
+ EXPECT_EQ(SBCmp->isFPPredicate(), LLVMCmp->isFPPredicate());
+ EXPECT_EQ(SBCmp->isIntPredicate(), LLVMCmp->isIntPredicate());
+ EXPECT_EQ(SBCmp->getInversePredicate(), LLVMCmp->getInversePredicate());
+ EXPECT_EQ(SBCmp->getOrderedPredicate(), LLVMCmp->getOrderedPredicate());
+ EXPECT_EQ(SBCmp->getUnorderedPredicate(), LLVMCmp->getUnorderedPredicate());
+ EXPECT_EQ(SBCmp->getSwappedPredicate(), LLVMCmp->getSwappedPredicate());
+ EXPECT_EQ(SBCmp->isStrictPredicate(), LLVMCmp->isStrictPredicate());
+ EXPECT_EQ(SBCmp->isNonStrictPredicate(), LLVMCmp->isNonStrictPredicate());
+ EXPECT_EQ(SBCmp->isRelational(), LLVMCmp->isRelational());
+ if (SBCmp->isStrictPredicate() || SBCmp->isNonStrictPredicate()) {
+ EXPECT_EQ(SBCmp->getFlippedStrictnessPredicate(),
+ LLVMCmp->getFlippedStrictnessPredicate());
+ }
+ EXPECT_EQ(SBCmp->isCommutative(), LLVMCmp->isCommutative());
+ EXPECT_EQ(SBCmp->isTrueWhenEqual(), LLVMCmp->isTrueWhenEqual());
+ EXPECT_EQ(SBCmp->isFalseWhenEqual(), LLVMCmp->isFalseWhenEqual());
+ EXPECT_EQ(sandboxir::CmpInst::isOrdered(SBPred),
+ llvm::CmpInst::isOrdered(LLVMPred));
+ EXPECT_EQ(sandboxir::CmpInst::isUnordered(SBPred),
+ llvm::CmpInst::isUnordered(LLVMPred));
+}
+
+TEST_F(SandboxIRTest, ICmpInst) {
+ SCOPED_TRACE("SandboxIRTest sandboxir::ICmpInst tests");
+ parseIR(C, R"IR(
+define void @foo(float %f0, float %f1, i32 %i0, i32 %i1) {
+ bb:
+ %ine = icmp ne i32 %i0, %i1
+ %iugt = icmp ugt i32 %i0, %i1
+ %iuge = icmp uge i32 %i0, %i1
+ %iult = icmp ult i32 %i0, %i1
+ %iule = icmp ule i32 %i0, %i1
+ %isgt = icmp sgt i32 %i0, %i1
+ %isle = icmp sle i32 %i0, %i1
+ %ieq = icmp eq i32 %i0, %i1
+ ret void
+}
+)IR");
+ Function &LLVMF = *M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ [[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
+
+ auto *LLVMBB = getBasicBlockByName(LLVMF, "bb");
+ auto LLVMIt = LLVMBB->begin();
+ auto *BB = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB));
+ auto It = BB->begin();
+ // Check classof()
+ while (auto *ICmp = dyn_cast<sandboxir::ICmpInst>(&*It++)) {
+ auto *LLVMICmp = cast<llvm::ICmpInst>(&*LLVMIt++);
+ checkSwapOperands(Ctx, ICmp, LLVMICmp);
+ checkCommonPredicates(ICmp, LLVMICmp);
+ EXPECT_EQ(ICmp->isSigned(), LLVMICmp->isSigned());
+ EXPECT_EQ(ICmp->isUnsigned(), LLVMICmp->isUnsigned());
+ EXPECT_EQ(ICmp->getSignedPredicate(), LLVMICmp->getSignedPredicate());
+ EXPECT_EQ(ICmp->getUnsignedPredicate(), LLVMICmp->getUnsignedPredicate());
+ }
+}
+
+TEST_F(SandboxIRTest, FCmpInst) {
+ SCOPED_TRACE("SandboxIRTest sandboxir::FCmpInst tests");
+ parseIR(C, R"IR(
+define void @foo(float %f0, float %f1) {
+bb:
+ %ffalse = fcmp false float %f0, %f1
+ %foeq = fcmp oeq float %f0, %f1
+ %fogt = fcmp ogt float %f0, %f1
+ %folt = fcmp olt float %f0, %f1
+ %fole = fcmp ole float %f0, %f1
+ %fone = fcmp one float %f0, %f1
+ %ford = fcmp ord float %f0, %f1
+ %funo = fcmp uno float %f0, %f1
+ %fueq = fcmp ueq float %f0, %f1
+ %fugt = fcmp ugt float %f0, %f1
+ %fuge = fcmp uge float %f0, %f1
+ %fult = fcmp ult float %f0, %f1
+ %fule = fcmp ule float %f0, %f1
+ %fune = fcmp une float %f0, %f1
+ %ftrue = fcmp true float %f0, %f1
+ ret void
+}
+)IR");
+ Function &LLVMF = *M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ [[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
+
+ auto *LLVMBB = getBasicBlockByName(LLVMF, "bb");
+ auto LLVMIt = LLVMBB->begin();
+ auto *BB = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB));
+ auto It = BB->begin();
+ // Check classof()
+ while (auto *FCmp = dyn_cast<sandboxir::FCmpInst>(&*It++)) {
+ auto *LLVMFCmp = cast<llvm::FCmpInst>(&*LLVMIt++);
+ checkSwapOperands(Ctx, FCmp, LLVMFCmp);
+ checkCommonPredicates(FCmp, LLVMFCmp);
+ }
+}
+
TEST_F(SandboxIRTest, UnreachableInst) {
parseIR(C, R"IR(
define void @foo() {
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index ca6effb727bf37..b4905c2231e3c8 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -1439,6 +1439,49 @@ define void @foo(i8 %arg0, i8 %arg1, i8 %arg2) {
EXPECT_EQ(PHI->getIncomingValue(1), Arg1);
}
+void checkCmpInst(sandboxir::Context &Ctx, sandboxir::CmpInst *Cmp) {
+ Ctx.save();
+ auto OrigP = Cmp->getPredicate();
+ auto NewP = Cmp->getSwappedPredicate();
+ Cmp->setPredicate(NewP);
+ EXPECT_EQ(Cmp->getPredicate(), NewP);
+ Ctx.revert();
+ EXPECT_EQ(Cmp->getPredicate(), OrigP);
+
+ Ctx.save();
+ auto OrigOp0 = Cmp->getOperand(0);
+ auto OrigOp1 = Cmp->getOperand(1);
+ Cmp->swapOperands();
+ EXPECT_EQ(Cmp->getPredicate(), NewP);
+ EXPECT_EQ(Cmp->getOperand(0), OrigOp1);
+ EXPECT_EQ(Cmp->getOperand(1), OrigOp0);
+ Ctx.revert();
+ EXPECT_EQ(Cmp->getPredicate(), OrigP);
+ EXPECT_EQ(Cmp->getOperand(0), OrigOp0);
+ EXPECT_EQ(Cmp->getOperand(1), OrigOp1);
+}
+
+TEST_F(TrackerTest, CmpInst) {
+ SCOPED_TRACE("TrackerTest sandboxir::CmpInst tests");
+ parseIR(C, R"IR(
+define void @foo(i64 %i0, i64 %i1, float %f0, float %f1) {
+ %foeq = fcmp ogt float %f0, %f1
+ %ioeq = icmp uge i64 %i0, %i1
+
+ ret void
+}
+)IR");
+ Function &LLVMF = *M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto &F = *Ctx.createFunction(&LLVMF);
+ auto *BB = &*F.begin();
+ auto It = BB->begin();
+ auto *FCmp = cast<sandboxir::FCmpInst>(&*It++);
+ checkCmpInst(Ctx, FCmp);
+ auto *ICmp = cast<sandboxir::ICmpInst>(&*It++);
+ checkCmpInst(Ctx, ICmp);
+}
+
TEST_F(TrackerTest, SetVolatile) {
parseIR(C, R"IR(
define void @foo(ptr %arg0, i8 %val) {
>From 3bf39eba11917fe6e174b546b16f9f9236191d32 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Wed, 28 Aug 2024 23:22:21 +0000
Subject: [PATCH 2/7] Address comments
---
llvm/include/llvm/SandboxIR/SandboxIR.h | 49 +++++++------------
.../llvm/SandboxIR/SandboxIRValues.def | 1 -
llvm/lib/SandboxIR/SandboxIR.cpp | 27 ++++++++--
llvm/unittests/SandboxIR/TrackerTest.cpp | 4 +-
4 files changed, 43 insertions(+), 38 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 55f0d0489e44ed..e2a3cf1902189b 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -312,8 +312,6 @@ class Value {
friend class UnreachableInst; // For getting `Val`.
friend class CatchSwitchAddHandler; // For `Val`.
friend class CmpInst; // For getting `Val`.
- friend class ICmpInst; // For getting `Val`.
- friend class FCmpInst; // For getting `Val`.
/// All values point to the context.
Context &Ctx;
@@ -3139,10 +3137,8 @@ class PHINode final : public SingleLLVMInstructionImpl<llvm::PHINode> {
class CmpInst : public Instruction {
protected:
using LLVMValType = llvm::CmpInst;
- /// Use Context::createCmpInst(). Don't call the
- /// constructor directly.
- CmpInst(llvm::CmpInst *CI, Context &Ctx, ClassID Id = ClassID::Cmp,
- Opcode Opc = Opcode::Cmp)
+ /// Use Context::createCmpInst(). Don't call the constructor directly.
+ CmpInst(llvm::CmpInst *CI, Context &Ctx, ClassID Id, Opcode Opc)
: Instruction(Id, Opc, CI, Ctx) {}
friend Context; // for CmpInst()
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
@@ -3165,13 +3161,13 @@ class CmpInst : public Instruction {
}
unsigned getNumOfIRInstrs() const final { return 1u; }
static CmpInst *create(OtherOps Op, Predicate Pred, Value *S1, Value *S2,
- const Twine &Name = "",
- InsertPosition InsertBefore = nullptr);
- static CmpInst *CreateWithCopiedFlags(OtherOps Op, Predicate Pred, Value *S1,
+ Context &Ctx, const Twine &Name = "",
+ Instruction *InsertBefore = nullptr);
+ static CmpInst *createWithCopiedFlags(OtherOps Op, Predicate Pred, Value *S1,
Value *S2,
const Instruction *FlagsSource,
- const Twine &Name = "",
- InsertPosition InsertBefore = nullptr);
+ Context &Ctx, const Twine &Name = "",
+ Instruction *InsertBefore = nullptr);
OtherOps getOpcode() const {
return static_cast<OtherOps>(Instruction::getOpcode());
}
@@ -3211,20 +3207,19 @@ class CmpInst : public Instruction {
return llvm::CmpInst::isImpliedFalseByMatchingCmp(Pred1, Pred2);
}
- /// Methods for support type inquiry through isa, cast, and dyn_cast:
- static bool classof(const Instruction *From) {
- return isa<ICmpInst>(From) || isa<FCmpInst>(From);
- }
+ /// Method for support type inquiry through isa, cast, and dyn_cast:
static bool classof(const Value *From) {
- return isa<Instruction>(From) && classof(cast<Instruction>(From));
+ return From->getSubclassID() == ClassID::ICmp ||
+ From->getSubclassID() == ClassID::FCmp;
}
+
/// Create a result type for fcmp/icmp
- static Type *makeCmpResultType(Type *opnd_type) {
- if (VectorType *vt = dyn_cast<VectorType>(opnd_type)) {
- return VectorType::get(Type::getInt1Ty(opnd_type->getContext()),
+ static Type *makeCmpResultType(Type *OpndType) {
+ if (VectorType *vt = dyn_cast<VectorType>(OpndType)) {
+ return VectorType::get(Type::getInt1Ty(OpndType->getContext()),
vt->getElementCount());
}
- return Type::getInt1Ty(opnd_type->getContext());
+ return Type::getInt1Ty(OpndType->getContext());
}
#ifndef NDEBUG
@@ -3237,8 +3232,7 @@ class CmpInst : public Instruction {
};
class ICmpInst : public CmpInst {
- /// Use Context::createICmpInst(). Don't call the
- /// constructor directly.
+ /// Use Context::createICmpInst(). Don't call the constructor directly.
ICmpInst(llvm::ICmpInst *CI, Context &Ctx)
: CmpInst(CI, Ctx, ClassID::ICmp, Opcode::ICmp) {}
friend class Context; // For constructor.
@@ -3263,17 +3257,13 @@ class ICmpInst : public CmpInst {
return llvm::ICmpInst::compare(LHS, RHS, Pred);
}
- static bool classof(const Instruction *From) {
- return From->getSubclassID() == ClassID::ICmp;
- }
static bool classof(const Value *From) {
- return isa<Instruction>(From) && classof(cast<Instruction>(From));
+ return From->getSubclassID() == ClassID::ICmp;
}
};
class FCmpInst : public CmpInst {
- /// Use Context::createFCmpInst(). Don't call the
- /// constructor directly.
+ /// Use Context::createFCmpInst(). Don't call the constructor directly.
FCmpInst(llvm::FCmpInst *CI, Context &Ctx)
: CmpInst(CI, Ctx, ClassID::FCmp, Opcode::FCmp) {}
friend class Context; // For constructor.
@@ -3292,9 +3282,6 @@ class FCmpInst : public CmpInst {
return llvm::FCmpInst::compare(LHS, RHS, Pred);
}
- static bool classof(const Instruction *From) {
- return From->getSubclassID() == ClassID::FCmp;
- }
static bool classof(const Value *From) {
return From->getSubclassID() == ClassID::FCmp;
}
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 4addbd049f1af0..0712498d7c894b 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -109,7 +109,6 @@ DEF_INSTR(Cast, OPCODES(\
), CastInst)
DEF_INSTR(PHI, OP(PHI), PHINode)
DEF_INSTR(Unreachable, OP(Unreachable), UnreachableInst)
-DEF_INSTR(Cmp, OP(Cmp), CmpInst)
DEF_INSTR(ICmp, OP(ICmp), FCmpInst)
DEF_INSTR(FCmp, OP(FCmp), ICmpInst)
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 69ed354bd928c0..20a8a300b289cd 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2701,10 +2701,6 @@ PHINode *Context::createPHINode(llvm::PHINode *I) {
auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
return cast<PHINode>(registerValue(std::move(NewPtr)));
}
-CmpInst *Context::createCmpInst(llvm::CmpInst *I) {
- auto NewPtr = std::unique_ptr<CmpInst>(new CmpInst(I, *this));
- return cast<CmpInst>(registerValue(std::move(NewPtr)));
-}
ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) {
auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this));
return cast<ICmpInst>(registerValue(std::move(NewPtr)));
@@ -2714,6 +2710,29 @@ FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
return cast<FCmpInst>(registerValue(std::move(NewPtr)));
}
+CmpInst *CmpInst::create(OtherOps Op, Predicate P, Value *S1, Value *S2,
+ Context &Ctx, const Twine &Name,
+ Instruction *InsertBefore) {
+ auto &Builder = Ctx.getLLVMIRBuilder();
+ if (InsertBefore)
+ Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
+ auto *LLVMI =
+ cast<llvm::CmpInst>(Builder.CreateCmp(P, S1->Val, S2->Val, Name));
+ if (llvm::ICmpInst *IC = dyn_cast<llvm::ICmpInst>(LLVMI))
+ return Ctx.createICmpInst(IC);
+ else
+ return Ctx.createFCmpInst(cast<llvm::FCmpInst>(IC));
+}
+
+CmpInst *CmpInst::createWithCopiedFlags(OtherOps Op, Predicate P, Value *S1,
+ Value *S2, const Instruction *F,
+ Context &Ctx, const Twine &Name,
+ Instruction *InsertBefore) {
+ CmpInst *Inst = create(Op, P, S1, S2, Ctx, Name, InsertBefore);
+ cast<llvm::CmpInst>(Inst->Val)->copyIRFlags(F->Val);
+ return Inst;
+}
+
void CmpInst::setPredicate(Predicate P) {
auto &Tracker = Ctx.getTracker();
if (Tracker.isTracking())
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index b4905c2231e3c8..b192e0f4bd9ca6 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -1476,9 +1476,9 @@ define void @foo(i64 %i0, i64 %i1, float %f0, float %f1) {
auto &F = *Ctx.createFunction(&LLVMF);
auto *BB = &*F.begin();
auto It = BB->begin();
- auto *FCmp = cast<sandboxir::FCmpInst>(&*It++);
+ auto *FCmp = cast<sandboxir::CmpInst>(&*It++);
checkCmpInst(Ctx, FCmp);
- auto *ICmp = cast<sandboxir::ICmpInst>(&*It++);
+ auto *ICmp = cast<sandboxir::CmpInst>(&*It++);
checkCmpInst(Ctx, ICmp);
}
>From 42b5d75b9a70a811df235e651ab231a5856afd01 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Thu, 29 Aug 2024 21:51:26 +0000
Subject: [PATCH 3/7] Address comments
---
llvm/lib/SandboxIR/SandboxIR.cpp | 13 +++++-----
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 28 +++++++++++++++++++++-
2 files changed, 34 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 20a8a300b289cd..70207267fe04b0 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2716,12 +2716,13 @@ CmpInst *CmpInst::create(OtherOps Op, Predicate P, Value *S1, Value *S2,
auto &Builder = Ctx.getLLVMIRBuilder();
if (InsertBefore)
Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
- auto *LLVMI =
- cast<llvm::CmpInst>(Builder.CreateCmp(P, S1->Val, S2->Val, Name));
- if (llvm::ICmpInst *IC = dyn_cast<llvm::ICmpInst>(LLVMI))
- return Ctx.createICmpInst(IC);
- else
- return Ctx.createFCmpInst(cast<llvm::FCmpInst>(IC));
+ auto *LLVMI = Builder.CreateCmp(P, S1->Val, S2->Val, Name);
+ if (Op == OtherOps::ICmp)
+ return Ctx.createICmpInst(cast<llvm::ICmpInst>(LLVMI));
+ else {
+ assert(Op == OtherOps::FCmp);
+ return Ctx.createFCmpInst(cast<llvm::FCmpInst>(LLVMI));
+ }
}
CmpInst *CmpInst::createWithCopiedFlags(OtherOps Op, Predicate P, Value *S1,
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 65e6e93a97781a..bbccb042013d6a 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -4447,7 +4447,7 @@ void checkCommonPredicates(sandboxir::CmpInst *SBCmp, llvm::CmpInst *LLVMCmp) {
TEST_F(SandboxIRTest, ICmpInst) {
SCOPED_TRACE("SandboxIRTest sandboxir::ICmpInst tests");
parseIR(C, R"IR(
-define void @foo(float %f0, float %f1, i32 %i0, i32 %i1) {
+define void @foo(i32 %i0, i32 %i1) {
bb:
%ine = icmp ne i32 %i0, %i1
%iugt = icmp ugt i32 %i0, %i1
@@ -4478,6 +4478,10 @@ define void @foo(float %f0, float %f1, i32 %i0, i32 %i1) {
EXPECT_EQ(ICmp->getSignedPredicate(), LLVMICmp->getSignedPredicate());
EXPECT_EQ(ICmp->getUnsignedPredicate(), LLVMICmp->getUnsignedPredicate());
}
+ auto *NewCmp = sandboxir::CmpInst::create(
+ CmpInst::OtherOps::ICmp, llvm::CmpInst::ICMP_ULE, F.getArg(0),
+ F.getArg(1), Ctx, "", &*BB->begin());
+ EXPECT_EQ(NewCmp, &*BB->begin());
}
TEST_F(SandboxIRTest, FCmpInst) {
@@ -4501,6 +4505,9 @@ define void @foo(float %f0, float %f1) {
%fune = fcmp une float %f0, %f1
%ftrue = fcmp true float %f0, %f1
ret void
+bb1:
+ %copyfrom = fadd reassoc float %f0, 42.0
+ ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
@@ -4517,6 +4524,25 @@ define void @foo(float %f0, float %f1) {
checkSwapOperands(Ctx, FCmp, LLVMFCmp);
checkCommonPredicates(FCmp, LLVMFCmp);
}
+
+ auto *LLVMBB1 = getBasicBlockByName(LLVMF, "bb1");
+ auto *BB1 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB1));
+ auto It1 = BB1->begin();
+ auto *CopyFrom = &*It1++;
+ CopyFrom->setFastMathFlags(FastMathFlags::getFast());
+
+ // create with default flags
+ auto *NewFCmp = sandboxir::CmpInst::create(
+ CmpInst::OtherOps::FCmp, llvm::CmpInst::FCMP_ONE, F.getArg(0),
+ F.getArg(1), Ctx, "", &*It1);
+ FastMathFlags DefaultFMF = NewFCmp->getFastMathFlags();
+ EXPECT_TRUE(CopyFrom->getFastMathFlags() != DefaultFMF);
+ // create with copied flags
+ auto *NewFCmpFlags = sandboxir::CmpInst::createWithCopiedFlags(
+ CmpInst::OtherOps::FCmp, llvm::CmpInst::FCMP_ONE, F.getArg(0),
+ F.getArg(1), CopyFrom, Ctx, "", &*It1);
+ EXPECT_FALSE(NewFCmpFlags->getFastMathFlags() !=
+ CopyFrom->getFastMathFlags());
}
TEST_F(SandboxIRTest, UnreachableInst) {
>From 17ca863e4120ed18bdfe33499ed67244a43bd455 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Fri, 30 Aug 2024 17:57:41 +0000
Subject: [PATCH 4/7] Address more comments.
---
llvm/include/llvm/SandboxIR/SandboxIR.h | 34 ++--------
llvm/include/llvm/SandboxIR/Tracker.h | 15 -----
llvm/lib/SandboxIR/SandboxIR.cpp | 37 ++++-------
llvm/lib/SandboxIR/Tracker.cpp | 10 ---
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 76 +++++++++++-----------
5 files changed, 59 insertions(+), 113 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index e2a3cf1902189b..1555654b4c079e 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -751,8 +751,6 @@ class Instruction : public sandboxir::User {
friend class PHINode; // For getTopmostLLVMInstruction().
friend class UnreachableInst; // For getTopmostLLVMInstruction().
friend class CmpInst; // For getTopmostLLVMInstruction().
- friend class ICmpInst; // For getTopmostLLVMInstruction().
- friend class FCmpInst; // For getTopmostLLVMInstruction().
/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
@@ -909,6 +907,7 @@ template <typename LLVMT> class SingleLLVMInstructionImpl : public Instruction {
friend class UnaryInstruction;
friend class CallBase;
friend class FuncletPadInst;
+ friend class CmpInst;
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
@@ -3128,49 +3127,33 @@ class PHINode final : public SingleLLVMInstructionImpl<llvm::PHINode> {
// Wraps a member function that takes no parameters
// LLVMValType should be the type of the wrapped class
#define WRAP_MEMBER(FunctionName) \
- auto FunctionName() { return cast<LLVMValType>(Val)->FunctionName(); }
+ auto FunctionName() const { return cast<LLVMValType>(Val)->FunctionName(); }
// Wraps both--a common idiom in the CmpInst classes
#define WRAP_BOTH(FunctionName) \
WRAP_STATIC_PREDICATE(FunctionName) \
WRAP_MEMBER(FunctionName)
-class CmpInst : public Instruction {
+class CmpInst : public SingleLLVMInstructionImpl<llvm::CmpInst> {
protected:
using LLVMValType = llvm::CmpInst;
/// Use Context::createCmpInst(). Don't call the constructor directly.
CmpInst(llvm::CmpInst *CI, Context &Ctx, ClassID Id, Opcode Opc)
- : Instruction(Id, Opc, CI, Ctx) {}
+ : SingleLLVMInstructionImpl(Id, Opc, CI, Ctx) {}
friend Context; // for CmpInst()
- Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
- return getOperandUseDefault(OpIdx, Verify);
- }
- SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
- return {cast<llvm::Instruction>(Val)};
- }
static Value *createCommon(Value *Cond, Value *True, Value *False,
const Twine &Name, IRBuilder<> &Builder,
Context &Ctx);
public:
using Predicate = llvm::CmpInst::Predicate;
- using PredicateField = llvm::CmpInst::PredicateField;
- using OtherOps = llvm::Instruction::OtherOps;
- unsigned getUseOperandNo(const Use &Use) const final {
- return getUseOperandNoDefault(Use);
- }
- unsigned getNumOfIRInstrs() const final { return 1u; }
- static CmpInst *create(OtherOps Op, Predicate Pred, Value *S1, Value *S2,
- Context &Ctx, const Twine &Name = "",
+ static CmpInst *create(Predicate Pred, Value *S1, Value *S2, Context &Ctx,
+ const Twine &Name = "",
Instruction *InsertBefore = nullptr);
- static CmpInst *createWithCopiedFlags(OtherOps Op, Predicate Pred, Value *S1,
- Value *S2,
+ static CmpInst *createWithCopiedFlags(Predicate Pred, Value *S1, Value *S2,
const Instruction *FlagsSource,
Context &Ctx, const Twine &Name = "",
Instruction *InsertBefore = nullptr);
- OtherOps getOpcode() const {
- return static_cast<OtherOps>(Instruction::getOpcode());
- }
void setPredicate(Predicate P);
void swapOperands();
@@ -3223,9 +3206,6 @@ class CmpInst : public Instruction {
}
#ifndef NDEBUG
- void verify() const final {
- assert(isa<LLVMValType>(Val) && "Expected CmpInst!");
- }
void dumpOS(raw_ostream &OS) const override;
LLVM_DUMP_METHOD void dump() const;
#endif
diff --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h
index d3d0bf3dfe7803..5fc43db82bd707 100644
--- a/llvm/include/llvm/SandboxIR/Tracker.h
+++ b/llvm/include/llvm/SandboxIR/Tracker.h
@@ -64,7 +64,6 @@ class SwitchInst;
class ConstantInt;
class ShuffleVectorInst;
class CmpInst;
-
/// The base class for IR Change classes.
class IRChangeBase {
protected:
@@ -131,20 +130,6 @@ class PHIAddIncoming : public IRChangeBase {
#endif
};
-class CmpSetPredicate : public IRChangeBase {
- CmpInst *Cmp;
- llvm::CmpInst::Predicate OldP;
-
-public:
- CmpSetPredicate(CmpInst *Cmp);
- void revert(Tracker &Tracker) final;
- void accept() final {}
-#ifndef NDEBUG
- void dump(raw_ostream &OS) const final { OS << "CmpSetPredicate"; }
- LLVM_DUMP_METHOD void dump() const final;
-#endif
-};
-
class CmpSwapOperands : public IRChangeBase {
CmpInst *Cmp;
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 70207267fe04b0..29e2ffcc405b05 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2710,34 +2710,29 @@ FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
return cast<FCmpInst>(registerValue(std::move(NewPtr)));
}
-CmpInst *CmpInst::create(OtherOps Op, Predicate P, Value *S1, Value *S2,
- Context &Ctx, const Twine &Name,
- Instruction *InsertBefore) {
+CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2, Context &Ctx,
+ const Twine &Name, Instruction *InsertBefore) {
auto &Builder = Ctx.getLLVMIRBuilder();
- if (InsertBefore)
- Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
+ Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
auto *LLVMI = Builder.CreateCmp(P, S1->Val, S2->Val, Name);
- if (Op == OtherOps::ICmp)
+ if (dyn_cast<llvm::ICmpInst>(LLVMI))
return Ctx.createICmpInst(cast<llvm::ICmpInst>(LLVMI));
- else {
- assert(Op == OtherOps::FCmp);
- return Ctx.createFCmpInst(cast<llvm::FCmpInst>(LLVMI));
- }
+ return Ctx.createFCmpInst(cast<llvm::FCmpInst>(LLVMI));
}
-CmpInst *CmpInst::createWithCopiedFlags(OtherOps Op, Predicate P, Value *S1,
- Value *S2, const Instruction *F,
- Context &Ctx, const Twine &Name,
+CmpInst *CmpInst::createWithCopiedFlags(Predicate P, Value *S1, Value *S2,
+ const Instruction *F, Context &Ctx,
+ const Twine &Name,
Instruction *InsertBefore) {
- CmpInst *Inst = create(Op, P, S1, S2, Ctx, Name, InsertBefore);
+ CmpInst *Inst = create(P, S1, S2, Ctx, Name, InsertBefore);
cast<llvm::CmpInst>(Inst->Val)->copyIRFlags(F->Val);
return Inst;
}
void CmpInst::setPredicate(Predicate P) {
- auto &Tracker = Ctx.getTracker();
- if (Tracker.isTracking())
- Tracker.track(std::make_unique<CmpSetPredicate>(this));
+ Ctx.getTracker()
+ .emplaceIfTracking<
+ GenericSetter<&CmpInst::getPredicate, &CmpInst::setPredicate>>(this);
cast<llvm::CmpInst>(Val)->setPredicate(P);
}
@@ -2749,16 +2744,12 @@ void CmpInst::swapOperands() {
}
void ICmpInst::swapOperands() {
- auto &Tracker = Ctx.getTracker();
- if (Tracker.isTracking())
- Tracker.track(std::make_unique<CmpSwapOperands>(this));
+ Ctx.getTracker().emplaceIfTracking<CmpSwapOperands>(this);
cast<llvm::ICmpInst>(Val)->swapOperands();
}
void FCmpInst::swapOperands() {
- auto &Tracker = Ctx.getTracker();
- if (Tracker.isTracking())
- Tracker.track(std::make_unique<CmpSwapOperands>(this));
+ Ctx.getTracker().emplaceIfTracking<CmpSwapOperands>(this);
cast<llvm::FCmpInst>(Val)->swapOperands();
}
diff --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp
index 9a824d5211a407..c6eb9fc68a4b11 100644
--- a/llvm/lib/SandboxIR/Tracker.cpp
+++ b/llvm/lib/SandboxIR/Tracker.cpp
@@ -248,16 +248,6 @@ void ShuffleVectorSetMask::dump() const {
}
#endif
-CmpSetPredicate::CmpSetPredicate(CmpInst *Cmp)
- : IRChangeBase(), Cmp(Cmp), OldP(Cmp->getPredicate()) {}
-
-void CmpSetPredicate::revert(Tracker &Tracker) { Cmp->setPredicate(OldP); }
-
-void CmpSetPredicate::dump() const {
- dump(dbgs());
- dbgs() << "\n";
-}
-
CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {}
void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); }
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index bbccb042013d6a..2fec35b1d5962b 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -4392,55 +4392,56 @@ define void @foo(i32 %arg) {
EXPECT_EQ(NewPHI->getNumIncomingValues(), PHI->getNumIncomingValues());
}
-void checkSwapOperands(sandboxir::Context &Ctx, llvm::sandboxir::CmpInst *SBCmp,
- llvm::CmpInst *LLVMCmp) {
- auto OrigOp0 = SBCmp->getOperand(0);
- auto OrigOp1 = SBCmp->getOperand(1);
+static void checkSwapOperands(sandboxir::Context &Ctx,
+ llvm::sandboxir::CmpInst *Cmp,
+ llvm::CmpInst *LLVMCmp) {
+ auto OrigOp0 = Cmp->getOperand(0);
+ auto OrigOp1 = Cmp->getOperand(1);
EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp0);
EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp1);
// This checks the dispatch mechanism in CmpInst, as well as
// the specific implementations.
- SBCmp->swapOperands();
- EXPECT_NE(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp0);
- EXPECT_NE(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp1);
+ Cmp->swapOperands();
EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp0);
EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp1);
// Undo it to keep the rest of the test consistent
- SBCmp->swapOperands();
+ Cmp->swapOperands();
}
-void checkCommonPredicates(sandboxir::CmpInst *SBCmp, llvm::CmpInst *LLVMCmp) {
+static void checkCommonPredicates(sandboxir::CmpInst *Cmp,
+ llvm::CmpInst *LLVMCmp) {
// Check proper creation
- auto SBPred = SBCmp->getPredicate();
+ auto Pred = Cmp->getPredicate();
auto LLVMPred = LLVMCmp->getPredicate();
- EXPECT_EQ(SBPred, LLVMPred);
+ EXPECT_EQ(Pred, LLVMPred);
// Check setPredicate
- SBCmp->setPredicate(llvm::CmpInst::FCMP_FALSE);
+ Cmp->setPredicate(llvm::CmpInst::FCMP_FALSE);
+ EXPECT_EQ(Cmp->getPredicate(), llvm::CmpInst::FCMP_FALSE);
EXPECT_EQ(LLVMCmp->getPredicate(), llvm::CmpInst::FCMP_FALSE);
- SBCmp->setPredicate(SBPred);
- EXPECT_EQ(LLVMCmp->getPredicate(), SBPred);
+ Cmp->setPredicate(Pred);
+ EXPECT_EQ(LLVMCmp->getPredicate(), Pred);
// Ensure the accessors properly forward to the underlying implementation
- EXPECT_STREQ(sandboxir::CmpInst::getPredicateName(SBPred).data(),
+ EXPECT_STREQ(sandboxir::CmpInst::getPredicateName(Pred).data(),
llvm::CmpInst::getPredicateName(LLVMPred).data());
- EXPECT_EQ(SBCmp->isFPPredicate(), LLVMCmp->isFPPredicate());
- EXPECT_EQ(SBCmp->isIntPredicate(), LLVMCmp->isIntPredicate());
- EXPECT_EQ(SBCmp->getInversePredicate(), LLVMCmp->getInversePredicate());
- EXPECT_EQ(SBCmp->getOrderedPredicate(), LLVMCmp->getOrderedPredicate());
- EXPECT_EQ(SBCmp->getUnorderedPredicate(), LLVMCmp->getUnorderedPredicate());
- EXPECT_EQ(SBCmp->getSwappedPredicate(), LLVMCmp->getSwappedPredicate());
- EXPECT_EQ(SBCmp->isStrictPredicate(), LLVMCmp->isStrictPredicate());
- EXPECT_EQ(SBCmp->isNonStrictPredicate(), LLVMCmp->isNonStrictPredicate());
- EXPECT_EQ(SBCmp->isRelational(), LLVMCmp->isRelational());
- if (SBCmp->isRelational()) {
- EXPECT_EQ(SBCmp->getFlippedStrictnessPredicate(),
+ EXPECT_EQ(Cmp->isFPPredicate(), LLVMCmp->isFPPredicate());
+ EXPECT_EQ(Cmp->isIntPredicate(), LLVMCmp->isIntPredicate());
+ EXPECT_EQ(Cmp->getInversePredicate(), LLVMCmp->getInversePredicate());
+ EXPECT_EQ(Cmp->getOrderedPredicate(), LLVMCmp->getOrderedPredicate());
+ EXPECT_EQ(Cmp->getUnorderedPredicate(), LLVMCmp->getUnorderedPredicate());
+ EXPECT_EQ(Cmp->getSwappedPredicate(), LLVMCmp->getSwappedPredicate());
+ EXPECT_EQ(Cmp->isStrictPredicate(), LLVMCmp->isStrictPredicate());
+ EXPECT_EQ(Cmp->isNonStrictPredicate(), LLVMCmp->isNonStrictPredicate());
+ EXPECT_EQ(Cmp->isRelational(), LLVMCmp->isRelational());
+ if (Cmp->isRelational()) {
+ EXPECT_EQ(Cmp->getFlippedStrictnessPredicate(),
LLVMCmp->getFlippedStrictnessPredicate());
}
- EXPECT_EQ(SBCmp->isCommutative(), LLVMCmp->isCommutative());
- EXPECT_EQ(SBCmp->isTrueWhenEqual(), LLVMCmp->isTrueWhenEqual());
- EXPECT_EQ(SBCmp->isFalseWhenEqual(), LLVMCmp->isFalseWhenEqual());
- EXPECT_EQ(sandboxir::CmpInst::isOrdered(SBPred),
+ EXPECT_EQ(Cmp->isCommutative(), LLVMCmp->isCommutative());
+ EXPECT_EQ(Cmp->isTrueWhenEqual(), LLVMCmp->isTrueWhenEqual());
+ EXPECT_EQ(Cmp->isFalseWhenEqual(), LLVMCmp->isFalseWhenEqual());
+ EXPECT_EQ(sandboxir::CmpInst::isOrdered(Pred),
llvm::CmpInst::isOrdered(LLVMPred));
- EXPECT_EQ(sandboxir::CmpInst::isUnordered(SBPred),
+ EXPECT_EQ(sandboxir::CmpInst::isUnordered(Pred),
llvm::CmpInst::isUnordered(LLVMPred));
}
@@ -4478,9 +4479,9 @@ define void @foo(i32 %i0, i32 %i1) {
EXPECT_EQ(ICmp->getSignedPredicate(), LLVMICmp->getSignedPredicate());
EXPECT_EQ(ICmp->getUnsignedPredicate(), LLVMICmp->getUnsignedPredicate());
}
- auto *NewCmp = sandboxir::CmpInst::create(
- CmpInst::OtherOps::ICmp, llvm::CmpInst::ICMP_ULE, F.getArg(0),
- F.getArg(1), Ctx, "", &*BB->begin());
+ auto *NewCmp =
+ sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, F.getArg(0),
+ F.getArg(1), Ctx, "", &*BB->begin());
EXPECT_EQ(NewCmp, &*BB->begin());
}
@@ -4533,14 +4534,13 @@ define void @foo(float %f0, float %f1) {
// create with default flags
auto *NewFCmp = sandboxir::CmpInst::create(
- CmpInst::OtherOps::FCmp, llvm::CmpInst::FCMP_ONE, F.getArg(0),
- F.getArg(1), Ctx, "", &*It1);
+ llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), Ctx, "", &*It1);
FastMathFlags DefaultFMF = NewFCmp->getFastMathFlags();
EXPECT_TRUE(CopyFrom->getFastMathFlags() != DefaultFMF);
// create with copied flags
auto *NewFCmpFlags = sandboxir::CmpInst::createWithCopiedFlags(
- CmpInst::OtherOps::FCmp, llvm::CmpInst::FCMP_ONE, F.getArg(0),
- F.getArg(1), CopyFrom, Ctx, "", &*It1);
+ llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), CopyFrom, Ctx, "",
+ &*It1);
EXPECT_FALSE(NewFCmpFlags->getFastMathFlags() !=
CopyFrom->getFastMathFlags());
}
>From 6508dcfffe7c2ef0e2ef0058b012217686a0778a Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Fri, 30 Aug 2024 18:59:56 +0000
Subject: [PATCH 5/7] Address more comments.
---
llvm/include/llvm/SandboxIR/SandboxIR.h | 10 +++++-----
llvm/lib/SandboxIR/SandboxIR.cpp | 13 +++++++------
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 8 ++++----
3 files changed, 16 insertions(+), 15 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 1555654b4c079e..b8c238aa9adec9 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -3147,13 +3147,13 @@ class CmpInst : public SingleLLVMInstructionImpl<llvm::CmpInst> {
public:
using Predicate = llvm::CmpInst::Predicate;
- static CmpInst *create(Predicate Pred, Value *S1, Value *S2, Context &Ctx,
- const Twine &Name = "",
- Instruction *InsertBefore = nullptr);
+ static CmpInst *create(Predicate Pred, Value *S1, Value *S2,
+ Instruction *InsertBefore, Context &Ctx,
+ const Twine &Name = "");
static CmpInst *createWithCopiedFlags(Predicate Pred, Value *S1, Value *S2,
const Instruction *FlagsSource,
- Context &Ctx, const Twine &Name = "",
- Instruction *InsertBefore = nullptr);
+ Instruction *InsertBefore, Context &Ctx,
+ const Twine &Name = "");
void setPredicate(Predicate P);
void swapOperands();
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 29e2ffcc405b05..f66b7a3eaa4684 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2710,8 +2710,9 @@ FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
return cast<FCmpInst>(registerValue(std::move(NewPtr)));
}
-CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2, Context &Ctx,
- const Twine &Name, Instruction *InsertBefore) {
+CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2,
+ Instruction *InsertBefore, Context &Ctx,
+ const Twine &Name) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
auto *LLVMI = Builder.CreateCmp(P, S1->Val, S2->Val, Name);
@@ -2721,10 +2722,10 @@ CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2, Context &Ctx,
}
CmpInst *CmpInst::createWithCopiedFlags(Predicate P, Value *S1, Value *S2,
- const Instruction *F, Context &Ctx,
- const Twine &Name,
- Instruction *InsertBefore) {
- CmpInst *Inst = create(P, S1, S2, Ctx, Name, InsertBefore);
+ const Instruction *F,
+ Instruction *InsertBefore, Context &Ctx,
+ const Twine &Name) {
+ CmpInst *Inst = create(P, S1, S2, InsertBefore, Ctx, Name);
cast<llvm::CmpInst>(Inst->Val)->copyIRFlags(F->Val);
return Inst;
}
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 2fec35b1d5962b..4406415db94355 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -4481,7 +4481,7 @@ define void @foo(i32 %i0, i32 %i1) {
}
auto *NewCmp =
sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, F.getArg(0),
- F.getArg(1), Ctx, "", &*BB->begin());
+ F.getArg(1), &*BB->begin(), Ctx, "");
EXPECT_EQ(NewCmp, &*BB->begin());
}
@@ -4534,13 +4534,13 @@ define void @foo(float %f0, float %f1) {
// create with default flags
auto *NewFCmp = sandboxir::CmpInst::create(
- llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), Ctx, "", &*It1);
+ llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), &*It1, Ctx, "");
FastMathFlags DefaultFMF = NewFCmp->getFastMathFlags();
EXPECT_TRUE(CopyFrom->getFastMathFlags() != DefaultFMF);
// create with copied flags
auto *NewFCmpFlags = sandboxir::CmpInst::createWithCopiedFlags(
- llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), CopyFrom, Ctx, "",
- &*It1);
+ llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), CopyFrom, &*It1, Ctx,
+ "");
EXPECT_FALSE(NewFCmpFlags->getFastMathFlags() !=
CopyFrom->getFastMathFlags());
}
>From 4efbddd93159e5e3f3709ea2df9df78225837906 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Tue, 3 Sep 2024 16:59:38 -0700
Subject: [PATCH 6/7] Better support for makeCmpResultType
---
llvm/include/llvm/SandboxIR/SandboxIR.h | 10 +++-------
llvm/include/llvm/SandboxIR/Type.h | 3 +++
llvm/lib/SandboxIR/SandboxIR.cpp | 13 +++++++++++--
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 5 +++++
4 files changed, 22 insertions(+), 9 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 11f462ef5e86e8..1233b57cede7e8 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -3267,13 +3267,7 @@ class CmpInst : public SingleLLVMInstructionImpl<llvm::CmpInst> {
}
/// Create a result type for fcmp/icmp
- static Type *makeCmpResultType(Type *OpndType) {
- if (VectorType *vt = dyn_cast<VectorType>(OpndType)) {
- return VectorType::get(Type::getInt1Ty(OpndType->getContext()),
- vt->getElementCount());
- }
- return Type::getInt1Ty(OpndType->getContext());
- }
+ static Type *makeCmpResultType(Type *OpndType);
#ifndef NDEBUG
void dumpOS(raw_ostream &OS) const override;
@@ -3361,6 +3355,8 @@ class Context {
LLVMContext &LLVMCtx;
friend class Type; // For LLVMCtx.
friend class PointerType; // For LLVMCtx.
+ friend class CmpInst; // For LLVMCtx. TODO: Clean up when sandboxir::VectorType
+ // is complete.
Tracker IRTracker;
/// Maps LLVM Value to the corresponding sandboxir::Value. Owns all
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 89e787f5f5d4b2..1f16ad02b30cec 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -41,6 +41,9 @@ class Type {
friend class Function; // For LLVMTy.
friend class CallBase; // For LLVMTy.
friend class ConstantInt; // For LLVMTy.
+ friend class CmpInst; // For LLVMTy. TODO: Clean up after sandboxir::VectorType
+ // is more complete.
+
// Friend all instruction classes because `create()` functions use LLVMTy.
#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
#define DEF_CONST(ID, CLASS) friend class CLASS;
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 94af9046c3a9bd..e5ff88f92c42f1 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2826,7 +2826,6 @@ FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this));
return cast<FCmpInst>(registerValue(std::move(NewPtr)));
}
-
CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
@@ -2837,7 +2836,6 @@ CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2,
return Ctx.createICmpInst(cast<llvm::ICmpInst>(LLVMI));
return Ctx.createFCmpInst(cast<llvm::FCmpInst>(LLVMI));
}
-
CmpInst *CmpInst::createWithCopiedFlags(Predicate P, Value *S1, Value *S2,
const Instruction *F,
Instruction *InsertBefore, Context &Ctx,
@@ -2847,6 +2845,17 @@ CmpInst *CmpInst::createWithCopiedFlags(Predicate P, Value *S1, Value *S2,
return Inst;
}
+Type *CmpInst::makeCmpResultType(Type *OpndType) {
+ if (auto *VT = dyn_cast<VectorType>(OpndType)) {
+ // TODO: Clean up when we have more complete support for
+ // sandboxir::VectorType.
+ return OpndType->getContext().getType(llvm::VectorType::get(
+ llvm::Type::getInt1Ty(OpndType->getContext().LLVMCtx),
+ cast<llvm::VectorType>(VT->LLVMTy)->getElementCount()));
+ }
+ return Type::getInt1Ty(OpndType->getContext());
+}
+
void CmpInst::setPredicate(Predicate P) {
Ctx.getTracker()
.emplaceIfTracking<
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 5d93eccaac4e71..21b99b8a50dbe5 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -4656,6 +4656,11 @@ define void @foo(i32 %i0, i32 %i1) {
sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, F.getArg(0),
F.getArg(1), &*BB->begin(), Ctx, "");
EXPECT_EQ(NewCmp, &*BB->begin());
+ // TODO: Improve this test when sandboxir::VectorType is more completely
+ // implemented.
+ sandboxir::Type *RT =
+ sandboxir::CmpInst::makeCmpResultType(F.getArg(0)->getType());
+ EXPECT_TRUE(RT->isIntegerTy(1)); // Only one bit in a single comparison
}
TEST_F(SandboxIRTest, FCmpInst) {
>From 1a8a4fb79d681d6dc1900b236507408e4952ae9b Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Tue, 3 Sep 2024 17:30:59 -0700
Subject: [PATCH 7/7] Address more comments
---
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 26 +++++++++++++++++++---
1 file changed, 23 insertions(+), 3 deletions(-)
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 21b99b8a50dbe5..2925a2889395de 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -4577,6 +4577,8 @@ static void checkSwapOperands(sandboxir::Context &Ctx,
Cmp->swapOperands();
EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp0);
EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp1);
+ EXPECT_EQ(Cmp->getOperand(0), OrigOp1);
+ EXPECT_EQ(Cmp->getOperand(1), OrigOp0);
// Undo it to keep the rest of the test consistent
Cmp->swapOperands();
}
@@ -4654,8 +4656,14 @@ define void @foo(i32 %i0, i32 %i1) {
}
auto *NewCmp =
sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, F.getArg(0),
- F.getArg(1), &*BB->begin(), Ctx, "");
+ F.getArg(1), &*BB->begin(), Ctx, "NewCmp");
EXPECT_EQ(NewCmp, &*BB->begin());
+ EXPECT_EQ(NewCmp->getPredicate(), llvm::CmpInst::ICMP_ULE);
+ EXPECT_EQ(NewCmp->getOperand(0), F.getArg(0));
+ EXPECT_EQ(NewCmp->getOperand(1), F.getArg(1));
+#ifndef NDEBUG
+ EXPECT_EQ(NewCmp->getName(), "NewCmp");
+#endif // NDEBUG
// TODO: Improve this test when sandboxir::VectorType is more completely
// implemented.
sandboxir::Type *RT =
@@ -4712,15 +4720,27 @@ define void @foo(float %f0, float %f1) {
// create with default flags
auto *NewFCmp = sandboxir::CmpInst::create(
- llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), &*It1, Ctx, "");
+ llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), &*It1, Ctx, "NewFCmp");
+ EXPECT_EQ(NewFCmp->getPredicate(), llvm::CmpInst::FCMP_ONE);
+ EXPECT_EQ(NewFCmp->getOperand(0), F.getArg(0));
+ EXPECT_EQ(NewFCmp->getOperand(1), F.getArg(1));
+#ifndef NDEBUG
+ EXPECT_EQ(NewFCmp->getName(), "NewFCmp");
+#endif // NDEBUG
FastMathFlags DefaultFMF = NewFCmp->getFastMathFlags();
EXPECT_TRUE(CopyFrom->getFastMathFlags() != DefaultFMF);
// create with copied flags
auto *NewFCmpFlags = sandboxir::CmpInst::createWithCopiedFlags(
llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), CopyFrom, &*It1, Ctx,
- "");
+ "NewFCmpFlags");
EXPECT_FALSE(NewFCmpFlags->getFastMathFlags() !=
CopyFrom->getFastMathFlags());
+ EXPECT_EQ(NewFCmpFlags->getPredicate(), llvm::CmpInst::FCMP_ONE);
+ EXPECT_EQ(NewFCmpFlags->getOperand(0), F.getArg(0));
+ EXPECT_EQ(NewFCmpFlags->getOperand(1), F.getArg(1));
+#ifndef NDEBUG
+ EXPECT_EQ(NewFCmpFlags->getName(), "NewFCmpFlags");
+#endif // NDEBUG
}
TEST_F(SandboxIRTest, UnreachableInst) {
More information about the llvm-commits mailing list