[llvm] [SandboxIR] Implement CmpInst, FCmpInst, and ICmpInst (PR #106301)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 3 17:37:28 PDT 2024


https://github.com/Sterling-Augustine updated https://github.com/llvm/llvm-project/pull/106301

>From e91c6ce666144c102e889f51c13fa13bacead8a9 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Thu, 8 Aug 2024 23:08:25 +0000
Subject: [PATCH 1/7] [SandboxIR] Implement CmpInst, FCmpInst, and ICmpInst

---
 llvm/include/llvm/SandboxIR/SandboxIR.h       | 212 +++++++++++++++++-
 .../llvm/SandboxIR/SandboxIRValues.def        |   3 +
 llvm/include/llvm/SandboxIR/Tracker.h         |  28 +++
 llvm/lib/SandboxIR/SandboxIR.cpp              |  62 +++++
 llvm/lib/SandboxIR/Tracker.cpp                |  18 ++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 127 +++++++++++
 llvm/unittests/SandboxIR/TrackerTest.cpp      |  43 ++++
 7 files changed, 488 insertions(+), 5 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index b7bdf9acd2ef45..b47b6ee512471e 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -54,11 +54,17 @@
 //                                      |                   |
 //                                      |                   +- ZExtInst
 //                                      |
-//                                      +- CallBase -----------+- CallBrInst
-//                                      |                      |
-//                                      +- CmpInst             +- CallInst
-//                                      |                      |
-//                                      +- ExtractElementInst  +- InvokeInst
+//                                      +- CallBase --------+- CallBrInst
+//                                      |                   |
+//                                      |                   +- CallInst
+//                                      |                   |
+//                                      |                   +- InvokeInst
+//                                      |
+//                                      +- CmpInst ---------+- ICmpInst
+//                                      |                   |
+//                                      |                   +- FCmpInst
+//                                      |
+//                                      +- ExtractElementInst
 //                                      |
 //                                      +- GetElementPtrInst
 //                                      |
@@ -150,6 +156,9 @@ class BinaryOperator;
 class PossiblyDisjointInst;
 class AtomicRMWInst;
 class AtomicCmpXchgInst;
+class CmpInst;
+class ICmpInst;
+class FCmpInst;
 
 /// Iterator for the `Use` edges of a User's operands.
 /// \Returns the operand `Use` when dereferenced.
@@ -294,6 +303,9 @@ class Value {
   friend class PHINode;               // For getting `Val`.
   friend class UnreachableInst;       // For getting `Val`.
   friend class CatchSwitchAddHandler; // For `Val`.
+  friend class CmpInst;               // For getting `Val`.
+  friend class ICmpInst;              // For getting `Val`.
+  friend class FCmpInst;              // For getting `Val`.
 
   /// All values point to the context.
   Context &Ctx;
@@ -730,6 +742,9 @@ class Instruction : public sandboxir::User {
   friend class CastInst;           // For getTopmostLLVMInstruction().
   friend class PHINode;            // For getTopmostLLVMInstruction().
   friend class UnreachableInst;    // For getTopmostLLVMInstruction().
+  friend class CmpInst;            // For getTopmostLLVMInstruction().
+  friend class ICmpInst;           // For getTopmostLLVMInstruction().
+  friend class FCmpInst;           // For getTopmostLLVMInstruction().
 
   /// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
   /// order.
@@ -2974,6 +2989,187 @@ class PHINode final : public SingleLLVMInstructionImpl<llvm::PHINode> {
   //                         uint32_t ToIdx = 0)
 };
 
+// Declares a static forwarder that passes a single Predicate argument to the
+// same-named static function of LLVMValType (the wrapped llvm class).
+#define WRAP_STATIC_PREDICATE(FunctionName)                                    \
+  static auto FunctionName(Predicate P) { return LLVMValType::FunctionName(P); }
+// Declares a const, parameterless member forwarder that calls the same-named
+// member on the wrapped instruction, cast to LLVMValType.
+#define WRAP_MEMBER(FunctionName)                                              \
+  auto FunctionName() const { return cast<LLVMValType>(Val)->FunctionName(); }
+// Declares both forwarders -- a common idiom in the CmpInst classes.
+#define WRAP_BOTH(FunctionName)                                                \
+  WRAP_STATIC_PREDICATE(FunctionName)                                          \
+  WRAP_MEMBER(FunctionName)
+
+/// Base class for ICmpInst and FCmpInst that wraps an llvm::CmpInst. Most
+/// predicate queries below are thin forwarders generated by the WRAP_* macros,
+/// which resolve the callee through the class-local LLVMValType alias.
+class CmpInst : public Instruction {
+protected:
+  using LLVMValType = llvm::CmpInst;
+  /// Use Context::createCmpInst(). Don't call the
+  /// constructor directly.
+  CmpInst(llvm::CmpInst *CI, Context &Ctx, ClassID Id = ClassID::Cmp,
+          Opcode Opc = Opcode::Cmp)
+      : Instruction(Id, Opc, CI, Ctx) {}
+  friend Context; // for CmpInst()
+  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+    return getOperandUseDefault(OpIdx, Verify);
+  }
+  SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
+    return {cast<llvm::Instruction>(Val)};
+  }
+
+public:
+  using Predicate = llvm::CmpInst::Predicate;
+  using PredicateField = llvm::CmpInst::PredicateField;
+  using OtherOps = llvm::Instruction::OtherOps;
+
+  unsigned getUseOperandNo(const Use &Use) const final {
+    return getUseOperandNoDefault(Use);
+  }
+  unsigned getNumOfIRInstrs() const final { return 1u; }
+  static CmpInst *create(OtherOps Op, Predicate Pred, Value *S1, Value *S2,
+                         const Twine &Name = "",
+                         InsertPosition InsertBefore = nullptr);
+  static CmpInst *CreateWithCopiedFlags(OtherOps Op, Predicate Pred, Value *S1,
+                                        Value *S2,
+                                        const Instruction *FlagsSource,
+                                        const Twine &Name = "",
+                                        InsertPosition InsertBefore = nullptr);
+  OtherOps getOpcode() const {
+    return static_cast<OtherOps>(Instruction::getOpcode());
+  }
+  void setPredicate(Predicate P);
+  void swapOperands();
+
+  WRAP_MEMBER(getPredicate);
+  WRAP_BOTH(isFPPredicate);
+  WRAP_BOTH(isIntPredicate);
+  WRAP_STATIC_PREDICATE(getPredicateName);
+  WRAP_BOTH(getInversePredicate);
+  WRAP_BOTH(getOrderedPredicate);
+  WRAP_BOTH(getUnorderedPredicate);
+  WRAP_BOTH(getSwappedPredicate);
+  WRAP_BOTH(isStrictPredicate);
+  WRAP_BOTH(isNonStrictPredicate);
+  WRAP_BOTH(getStrictPredicate);
+  WRAP_BOTH(getNonStrictPredicate);
+  WRAP_BOTH(getFlippedStrictnessPredicate);
+  WRAP_MEMBER(isCommutative);
+  WRAP_BOTH(isEquality);
+  WRAP_BOTH(isRelational);
+  WRAP_BOTH(isSigned);
+  WRAP_BOTH(getSignedPredicate);
+  WRAP_BOTH(getUnsignedPredicate);
+  WRAP_BOTH(getFlippedSignednessPredicate);
+  WRAP_BOTH(isTrueWhenEqual);
+  WRAP_BOTH(isFalseWhenEqual);
+  WRAP_BOTH(isUnsigned);
+  WRAP_STATIC_PREDICATE(isOrdered);
+  WRAP_STATIC_PREDICATE(isUnordered);
+
+  static bool isImpliedTrueByMatchingCmp(Predicate Pred1, Predicate Pred2) {
+    return llvm::CmpInst::isImpliedTrueByMatchingCmp(Pred1, Pred2);
+  }
+  static bool isImpliedFalseByMatchingCmp(Predicate Pred1, Predicate Pred2) {
+    return llvm::CmpInst::isImpliedFalseByMatchingCmp(Pred1, Pred2);
+  }
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast:
+  static bool classof(const Instruction *From) {
+    return isa<ICmpInst>(From) || isa<FCmpInst>(From);
+  }
+  static bool classof(const Value *From) {
+    return isa<Instruction>(From) && classof(cast<Instruction>(From));
+  }
+  /// Create a result type for fcmp/icmp
+  static Type *makeCmpResultType(Type *opnd_type) {
+    if (VectorType *vt = dyn_cast<VectorType>(opnd_type)) {
+      return VectorType::get(Type::getInt1Ty(opnd_type->getContext()),
+                             vt->getElementCount());
+    }
+    return Type::getInt1Ty(opnd_type->getContext());
+  }
+
+#ifndef NDEBUG
+  void verify() const final {
+    assert(isa<LLVMValType>(Val) && "Expected CmpInst!");
+  }
+  void dumpOS(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
+class ICmpInst : public CmpInst {
+  /// Use Context::createICmpInst(). Don't call the
+  /// constructor directly.
+  ICmpInst(llvm::ICmpInst *CI, Context &Ctx)
+      : CmpInst(CI, Ctx, ClassID::ICmp, Opcode::ICmp) {}
+  friend class Context; // For constructor.
+  using LLVMValType = llvm::ICmpInst; // Used by the WRAP_* macros.
+
+public:
+  void swapOperands(); // Also swaps the predicate. Tracked for revert.
+
+  WRAP_BOTH(getSignedPredicate);
+  WRAP_BOTH(getUnsignedPredicate);
+  WRAP_BOTH(isEquality);
+  WRAP_MEMBER(isCommutative);
+  WRAP_MEMBER(isRelational);
+  WRAP_STATIC_PREDICATE(isGT);
+  WRAP_STATIC_PREDICATE(isLT);
+  WRAP_STATIC_PREDICATE(isGE);
+  WRAP_STATIC_PREDICATE(isLE);
+
+  static auto predicates() { return llvm::ICmpInst::predicates(); }
+  static bool compare(const APInt &LHS, const APInt &RHS,
+                      ICmpInst::Predicate Pred) {
+    return llvm::ICmpInst::compare(LHS, RHS, Pred);
+  }
+
+  static bool classof(const Instruction *From) {
+    return From->getSubclassID() == ClassID::ICmp;
+  }
+  static bool classof(const Value *From) {
+    return isa<Instruction>(From) && classof(cast<Instruction>(From));
+  }
+};
+
+class FCmpInst : public CmpInst {
+  /// Use Context::createFCmpInst(). Don't call the
+  /// constructor directly.
+  FCmpInst(llvm::FCmpInst *CI, Context &Ctx)
+      : CmpInst(CI, Ctx, ClassID::FCmp, Opcode::FCmp) {}
+  friend class Context; // For constructor.
+  using LLVMValType = llvm::FCmpInst; // Used by the WRAP_* macros.
+
+public:
+  void swapOperands(); // Also swaps the predicate. Tracked for revert.
+
+  WRAP_BOTH(isEquality);
+  WRAP_MEMBER(isCommutative);
+  WRAP_MEMBER(isRelational);
+
+  static auto predicates() { return llvm::FCmpInst::predicates(); }
+  static bool compare(const APFloat &LHS, const APFloat &RHS,
+                      FCmpInst::Predicate Pred) {
+    return llvm::FCmpInst::compare(LHS, RHS, Pred);
+  }
+
+  static bool classof(const Instruction *From) {
+    return From->getSubclassID() == ClassID::FCmp;
+  }
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::FCmp;
+  }
+};
+
+#undef WRAP_STATIC_PREDICATE
+#undef WRAP_MEMBER
+#undef WRAP_BOTH
+
 /// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
 /// an OpaqueInstr.
 class OpaqueInst : public SingleLLVMInstructionImpl<llvm::Instruction> {
@@ -3101,6 +3297,12 @@ class Context {
   friend PHINode; // For createPHINode()
   UnreachableInst *createUnreachableInst(llvm::UnreachableInst *UI);
   friend UnreachableInst; // For createUnreachableInst()
+  CmpInst *createCmpInst(llvm::CmpInst *I);
+  friend CmpInst; // For createCmpInst()
+  ICmpInst *createICmpInst(llvm::ICmpInst *I);
+  friend ICmpInst; // For createICmpInst()
+  FCmpInst *createFCmpInst(llvm::FCmpInst *I);
+  friend FCmpInst; // For createFCmpInst()
 
 public:
   Context(LLVMContext &LLVMCtx)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 00c1a6333c8ec4..ac0b942a7cf67e 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -107,6 +107,9 @@ DEF_INSTR(Cast,   OPCODES(\
                           ),                  CastInst)
 DEF_INSTR(PHI,            OP(PHI),            PHINode)
 DEF_INSTR(Unreachable,    OP(Unreachable),    UnreachableInst)
+DEF_INSTR(Cmp,            OP(Cmp),           CmpInst)
+DEF_INSTR(ICmp,           OP(ICmp),          FCmpInst)
+DEF_INSTR(FCmp,           OP(FCmp),          ICmpInst)
 
 // clang-format on
 #ifdef DEF_VALUE
diff --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h
index c8a9e99a34341d..d3d0bf3dfe7803 100644
--- a/llvm/include/llvm/SandboxIR/Tracker.h
+++ b/llvm/include/llvm/SandboxIR/Tracker.h
@@ -63,6 +63,7 @@ class CatchSwitchInst;
 class SwitchInst;
 class ConstantInt;
 class ShuffleVectorInst;
+class CmpInst;
 
 /// The base class for IR Change classes.
 class IRChangeBase {
@@ -130,6 +131,35 @@ class PHIAddIncoming : public IRChangeBase {
 #endif
 };
 
+/// Tracks CmpInst::setPredicate(); revert() restores the saved predicate.
+class CmpSetPredicate : public IRChangeBase {
+  CmpInst *Cmp;
+  llvm::CmpInst::Predicate OldP;
+
+public:
+  explicit CmpSetPredicate(CmpInst *Cmp);
+  void revert(Tracker &Tracker) final;
+  void accept() final {}
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const final { OS << "CmpSetPredicate"; }
+  LLVM_DUMP_METHOD void dump() const final;
+#endif
+};
+
+/// Tracks CmpInst::swapOperands(); revert() calls swapOperands() again.
+class CmpSwapOperands : public IRChangeBase {
+  CmpInst *Cmp;
+
+public:
+  explicit CmpSwapOperands(CmpInst *Cmp);
+  void revert(Tracker &Tracker) final;
+  void accept() final {}
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const final { OS << "CmpSwapOperands"; }
+  LLVM_DUMP_METHOD void dump() const final;
+#endif
+};
+
 /// Tracks swapping a Use with another Use.
 class UseSwap : public IRChangeBase {
   Use ThisUse;
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index b75424909f0835..29a92aa82572b8 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2462,6 +2462,16 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
     It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
     return It->second.get();
   }
+  case llvm::Instruction::ICmp: {
+    auto *LLVMICmp = cast<llvm::ICmpInst>(LLVMV);
+    It->second = std::unique_ptr<ICmpInst>(new ICmpInst(LLVMICmp, *this));
+    return It->second.get();
+  }
+  case llvm::Instruction::FCmp: {
+    auto *LLVMFCmp = cast<llvm::FCmpInst>(LLVMV);
+    It->second = std::unique_ptr<FCmpInst>(new FCmpInst(LLVMFCmp, *this));
+    return It->second.get();
+  }
   case llvm::Instruction::Unreachable: {
     auto *LLVMUnreachable = cast<llvm::UnreachableInst>(LLVMV);
     It->second = std::unique_ptr<UnreachableInst>(
@@ -2637,6 +2647,58 @@ PHINode *Context::createPHINode(llvm::PHINode *I) {
   auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
   return cast<PHINode>(registerValue(std::move(NewPtr)));
 }
+CmpInst *Context::createCmpInst(llvm::CmpInst *I) {
+  auto NewPtr = std::unique_ptr<CmpInst>(new CmpInst(I, *this));
+  return cast<CmpInst>(registerValue(std::move(NewPtr)));
+}
+ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) {
+  auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this));
+  return cast<ICmpInst>(registerValue(std::move(NewPtr)));
+}
+FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
+  auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this));
+  return cast<FCmpInst>(registerValue(std::move(NewPtr)));
+}
+
+void CmpInst::setPredicate(Predicate P) {
+  auto &Tracker = Ctx.getTracker();
+  if (Tracker.isTracking())
+    Tracker.track(std::make_unique<CmpSetPredicate>(this));
+  cast<llvm::CmpInst>(Val)->setPredicate(P);
+}
+
+void CmpInst::swapOperands() {
+  if (ICmpInst *IC = dyn_cast<ICmpInst>(this)) // Forward to the subclass.
+    IC->swapOperands();
+  else
+    cast<FCmpInst>(this)->swapOperands();
+}
+
+void ICmpInst::swapOperands() {
+  auto &Tracker = Ctx.getTracker();
+  if (Tracker.isTracking())
+    Tracker.track(std::make_unique<CmpSwapOperands>(this));
+  cast<llvm::ICmpInst>(Val)->swapOperands(); // Also swaps the predicate.
+}
+
+void FCmpInst::swapOperands() {
+  auto &Tracker = Ctx.getTracker();
+  if (Tracker.isTracking())
+    Tracker.track(std::make_unique<CmpSwapOperands>(this));
+  cast<llvm::FCmpInst>(Val)->swapOperands(); // Also swaps the predicate.
+}
+
+#ifndef NDEBUG
+void CmpInst::dumpOS(raw_ostream &OS) const {
+  dumpCommonPrefix(OS);
+  dumpCommonSuffix(OS);
+}
+
+void CmpInst::dump() const {
+  dumpOS(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
 
 Value *Context::getValue(llvm::Value *V) const {
   auto It = LLVMValueToValueMap.find(V);
diff --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp
index 953d4bd51353a9..9a824d5211a407 100644
--- a/llvm/lib/SandboxIR/Tracker.cpp
+++ b/llvm/lib/SandboxIR/Tracker.cpp
@@ -248,6 +248,24 @@ void ShuffleVectorSetMask::dump() const {
 }
 #endif
 
+CmpSetPredicate::CmpSetPredicate(CmpInst *Cmp)
+    : IRChangeBase(), Cmp(Cmp), OldP(Cmp->getPredicate()) {}
+void CmpSetPredicate::revert(Tracker &Tracker) { Cmp->setPredicate(OldP); }
+CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {}
+void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); }
+
+// The dump() methods are only declared when NDEBUG is not defined.
+#ifndef NDEBUG
+void CmpSetPredicate::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+void CmpSwapOperands::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
 void Tracker::save() { State = TrackerState::Record; }
 
 void Tracker::revert() {
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index bc3fddf9e163dc..5bee722d2b953d 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -4188,6 +4188,133 @@ define void @foo(i32 %arg) {
   EXPECT_EQ(NewPHI->getNumIncomingValues(), PHI->getNumIncomingValues());
 }
 
+void checkSwapOperands(sandboxir::Context &Ctx, sandboxir::CmpInst *SBCmp,
+                       llvm::CmpInst *LLVMCmp) {
+  auto OrigOp0 = SBCmp->getOperand(0);
+  auto OrigOp1 = SBCmp->getOperand(1);
+  EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp0);
+  EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp1);
+  // Exercises the CmpInst::swapOperands() dispatch and the subclass versions.
+  SBCmp->swapOperands();
+  EXPECT_NE(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp0);
+  EXPECT_NE(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp1);
+  EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp0);
+  EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp1);
+  // Undo it to keep the rest of the test consistent
+  SBCmp->swapOperands();
+}
+
+void checkCommonPredicates(sandboxir::CmpInst *SBCmp, llvm::CmpInst *LLVMCmp) {
+  // Check proper creation
+  auto SBPred = SBCmp->getPredicate();
+  auto LLVMPred = LLVMCmp->getPredicate();
+  EXPECT_EQ(SBPred, LLVMPred);
+  // Check setPredicate, with a predicate of the same kind (int vs. FP).
+  SBCmp->setPredicate(SBCmp->getInversePredicate());
+  EXPECT_EQ(LLVMCmp->getPredicate(), LLVMCmp->getInversePredicate(LLVMPred));
+  SBCmp->setPredicate(SBPred);
+  EXPECT_EQ(LLVMCmp->getPredicate(), SBPred);
+  // Ensure the accessors properly forward to the underlying implementation
+  EXPECT_STREQ(sandboxir::CmpInst::getPredicateName(SBPred).data(),
+               llvm::CmpInst::getPredicateName(LLVMPred).data());
+  EXPECT_EQ(SBCmp->isFPPredicate(), LLVMCmp->isFPPredicate());
+  EXPECT_EQ(SBCmp->isIntPredicate(), LLVMCmp->isIntPredicate());
+  EXPECT_EQ(SBCmp->getInversePredicate(), LLVMCmp->getInversePredicate());
+  EXPECT_EQ(SBCmp->getOrderedPredicate(), LLVMCmp->getOrderedPredicate());
+  EXPECT_EQ(SBCmp->getUnorderedPredicate(), LLVMCmp->getUnorderedPredicate());
+  EXPECT_EQ(SBCmp->getSwappedPredicate(), LLVMCmp->getSwappedPredicate());
+  EXPECT_EQ(SBCmp->isStrictPredicate(), LLVMCmp->isStrictPredicate());
+  EXPECT_EQ(SBCmp->isNonStrictPredicate(), LLVMCmp->isNonStrictPredicate());
+  EXPECT_EQ(SBCmp->isRelational(), LLVMCmp->isRelational());
+  // Flipping strictness is only defined for strict/non-strict predicates.
+  if (SBCmp->isStrictPredicate() || SBCmp->isNonStrictPredicate()) {
+    EXPECT_EQ(SBCmp->getFlippedStrictnessPredicate(),
+              LLVMCmp->getFlippedStrictnessPredicate());
+  }
+  EXPECT_EQ(SBCmp->isCommutative(), LLVMCmp->isCommutative());
+  EXPECT_EQ(SBCmp->isTrueWhenEqual(), LLVMCmp->isTrueWhenEqual());
+  EXPECT_EQ(SBCmp->isFalseWhenEqual(), LLVMCmp->isFalseWhenEqual());
+  EXPECT_EQ(sandboxir::CmpInst::isOrdered(SBPred),
+            llvm::CmpInst::isOrdered(LLVMPred));
+  EXPECT_EQ(sandboxir::CmpInst::isUnordered(SBPred),
+            llvm::CmpInst::isUnordered(LLVMPred));
+}
+
+TEST_F(SandboxIRTest, ICmpInst) {
+  SCOPED_TRACE("SandboxIRTest sandboxir::ICmpInst tests");
+  parseIR(C, R"IR(
+define void @foo(i32 %i0, i32 %i1) {
+bb:
+  %ine  = icmp ne i32 %i0, %i1
+  %iugt = icmp ugt i32 %i0, %i1
+  %iuge = icmp uge i32 %i0, %i1
+  %iult = icmp ult i32 %i0, %i1
+  %iule = icmp ule i32 %i0, %i1
+  %isgt = icmp sgt i32 %i0, %i1
+  %isle = icmp sle i32 %i0, %i1
+  %ieq  = icmp eq i32 %i0, %i1
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  [[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
+
+  auto *LLVMBB = getBasicBlockByName(LLVMF, "bb");
+  auto LLVMIt = LLVMBB->begin();
+  auto *BB = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB));
+  auto It = BB->begin();
+  // Check classof()
+  while (auto *ICmp = dyn_cast<sandboxir::ICmpInst>(&*It++)) {
+    auto *LLVMICmp = cast<llvm::ICmpInst>(&*LLVMIt++);
+    checkSwapOperands(Ctx, ICmp, LLVMICmp);
+    checkCommonPredicates(ICmp, LLVMICmp);
+    EXPECT_EQ(ICmp->isSigned(), LLVMICmp->isSigned());
+    EXPECT_EQ(ICmp->isUnsigned(), LLVMICmp->isUnsigned());
+    EXPECT_EQ(ICmp->getSignedPredicate(), LLVMICmp->getSignedPredicate());
+    EXPECT_EQ(ICmp->getUnsignedPredicate(), LLVMICmp->getUnsignedPredicate());
+  }
+}
+
+TEST_F(SandboxIRTest, FCmpInst) {
+  SCOPED_TRACE("SandboxIRTest sandboxir::FCmpInst tests");
+  parseIR(C, R"IR(
+define void @foo(float %f0, float %f1) {
+bb:
+  %ffalse = fcmp false float %f0, %f1
+  %foeq = fcmp oeq float %f0, %f1
+  %fogt = fcmp ogt float %f0, %f1
+  %folt = fcmp olt float %f0, %f1
+  %fole = fcmp ole float %f0, %f1
+  %fone = fcmp one float %f0, %f1
+  %ford = fcmp ord float %f0, %f1
+  %funo = fcmp uno float %f0, %f1
+  %fueq = fcmp ueq float %f0, %f1
+  %fugt = fcmp ugt float %f0, %f1
+  %fuge = fcmp uge float %f0, %f1
+  %fult = fcmp ult float %f0, %f1
+  %fule = fcmp ule float %f0, %f1
+  %fune = fcmp une float %f0, %f1
+  %ftrue = fcmp true float %f0, %f1
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  [[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
+
+  auto *LLVMBB = getBasicBlockByName(LLVMF, "bb");
+  auto LLVMIt = LLVMBB->begin();
+  auto *BB = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB));
+  auto It = BB->begin();
+  // Check classof()
+  while (auto *FCmp = dyn_cast<sandboxir::FCmpInst>(&*It++)) {
+    auto *LLVMFCmp = cast<llvm::FCmpInst>(&*LLVMIt++);
+    checkSwapOperands(Ctx, FCmp, LLVMFCmp);
+    checkCommonPredicates(FCmp, LLVMFCmp);
+  }
+}
+
 TEST_F(SandboxIRTest, UnreachableInst) {
   parseIR(C, R"IR(
 define void @foo() {
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index ca6effb727bf37..b4905c2231e3c8 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -1439,6 +1439,49 @@ define void @foo(i8 %arg0, i8 %arg1, i8 %arg2) {
   EXPECT_EQ(PHI->getIncomingValue(1), Arg1);
 }
 
+void checkCmpInst(sandboxir::Context &Ctx, sandboxir::CmpInst *Cmp) {
+  Ctx.save();
+  auto OrigP = Cmp->getPredicate();
+  auto NewP = Cmp->getSwappedPredicate();
+  Cmp->setPredicate(NewP);
+  EXPECT_EQ(Cmp->getPredicate(), NewP);
+  Ctx.revert();
+  EXPECT_EQ(Cmp->getPredicate(), OrigP);
+
+  Ctx.save();
+  auto OrigOp0 = Cmp->getOperand(0);
+  auto OrigOp1 = Cmp->getOperand(1);
+  Cmp->swapOperands();
+  EXPECT_EQ(Cmp->getPredicate(), NewP); // The predicate is swapped too.
+  EXPECT_EQ(Cmp->getOperand(0), OrigOp1);
+  EXPECT_EQ(Cmp->getOperand(1), OrigOp0);
+  Ctx.revert();
+  EXPECT_EQ(Cmp->getPredicate(), OrigP);
+  EXPECT_EQ(Cmp->getOperand(0), OrigOp0);
+  EXPECT_EQ(Cmp->getOperand(1), OrigOp1);
+}
+
+TEST_F(TrackerTest, CmpInst) {
+  SCOPED_TRACE("TrackerTest sandboxir::CmpInst tests");
+  parseIR(C, R"IR(
+define void @foo(i64 %i0, i64 %i1, float %f0, float %f1) {
+  %fogt = fcmp ogt float %f0, %f1
+  %iuge = icmp uge i64 %i0, %i1
+
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *FCmp = cast<sandboxir::FCmpInst>(&*It++);
+  checkCmpInst(Ctx, FCmp);
+  auto *ICmp = cast<sandboxir::ICmpInst>(&*It++);
+  checkCmpInst(Ctx, ICmp);
+}
+
 TEST_F(TrackerTest, SetVolatile) {
   parseIR(C, R"IR(
 define void @foo(ptr %arg0, i8 %val) {

>From 3bf39eba11917fe6e174b546b16f9f9236191d32 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Wed, 28 Aug 2024 23:22:21 +0000
Subject: [PATCH 2/7] Address review comments: drop ClassID::Cmp, add CmpInst::create()

---
 llvm/include/llvm/SandboxIR/SandboxIR.h       | 49 +++++++------------
 .../llvm/SandboxIR/SandboxIRValues.def        |  1 -
 llvm/lib/SandboxIR/SandboxIR.cpp              | 27 ++++++++--
 llvm/unittests/SandboxIR/TrackerTest.cpp      |  4 +-
 4 files changed, 43 insertions(+), 38 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 55f0d0489e44ed..e2a3cf1902189b 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -312,8 +312,6 @@ class Value {
   friend class UnreachableInst;       // For getting `Val`.
   friend class CatchSwitchAddHandler; // For `Val`.
   friend class CmpInst;               // For getting `Val`.
-  friend class ICmpInst;              // For getting `Val`.
-  friend class FCmpInst;              // For getting `Val`.
 
   /// All values point to the context.
   Context &Ctx;
@@ -3139,10 +3137,8 @@ class PHINode final : public SingleLLVMInstructionImpl<llvm::PHINode> {
 class CmpInst : public Instruction {
 protected:
   using LLVMValType = llvm::CmpInst;
-  /// Use Context::createCmpInst(). Don't call the
-  /// constructor directly.
-  CmpInst(llvm::CmpInst *CI, Context &Ctx, ClassID Id = ClassID::Cmp,
-          Opcode Opc = Opcode::Cmp)
+  /// Only called by the ICmpInst/FCmpInst constructors; don't call directly.
+  CmpInst(llvm::CmpInst *CI, Context &Ctx, ClassID Id, Opcode Opc)
       : Instruction(Id, Opc, CI, Ctx) {}
   friend Context; // for CmpInst()
   Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
@@ -3165,13 +3161,13 @@ class CmpInst : public Instruction {
   }
   unsigned getNumOfIRInstrs() const final { return 1u; }
   static CmpInst *create(OtherOps Op, Predicate Pred, Value *S1, Value *S2,
-                         const Twine &Name = "",
-                         InsertPosition InsertBefore = nullptr);
-  static CmpInst *CreateWithCopiedFlags(OtherOps Op, Predicate Pred, Value *S1,
+                         Context &Ctx, const Twine &Name = "",
+                         Instruction *InsertBefore = nullptr);
+  static CmpInst *createWithCopiedFlags(OtherOps Op, Predicate Pred, Value *S1,
                                         Value *S2,
                                         const Instruction *FlagsSource,
-                                        const Twine &Name = "",
-                                        InsertPosition InsertBefore = nullptr);
+                                        Context &Ctx, const Twine &Name = "",
+                                        Instruction *InsertBefore = nullptr);
   OtherOps getOpcode() const {
     return static_cast<OtherOps>(Instruction::getOpcode());
   }
@@ -3211,20 +3207,19 @@ class CmpInst : public Instruction {
     return llvm::CmpInst::isImpliedFalseByMatchingCmp(Pred1, Pred2);
   }
 
-  /// Methods for support type inquiry through isa, cast, and dyn_cast:
-  static bool classof(const Instruction *From) {
-    return isa<ICmpInst>(From) || isa<FCmpInst>(From);
-  }
+  /// Method to support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Value *From) {
-    return isa<Instruction>(From) && classof(cast<Instruction>(From));
+    return From->getSubclassID() == ClassID::ICmp ||
+           From->getSubclassID() == ClassID::FCmp;
   }
+
   /// Create a result type for fcmp/icmp
-  static Type *makeCmpResultType(Type *opnd_type) {
-    if (VectorType *vt = dyn_cast<VectorType>(opnd_type)) {
-      return VectorType::get(Type::getInt1Ty(opnd_type->getContext()),
+  static Type *makeCmpResultType(Type *OpndType) {
+    if (VectorType *vt = dyn_cast<VectorType>(OpndType)) {
+      return VectorType::get(Type::getInt1Ty(OpndType->getContext()),
                              vt->getElementCount());
     }
-    return Type::getInt1Ty(opnd_type->getContext());
+    return Type::getInt1Ty(OpndType->getContext());
   }
 
 #ifndef NDEBUG
@@ -3237,8 +3232,7 @@ class CmpInst : public Instruction {
 };
 
 class ICmpInst : public CmpInst {
-  /// Use Context::createICmpInst(). Don't call the
-  /// constructor directly.
+  /// Use Context::createICmpInst(). Don't call the constructor directly.
   ICmpInst(llvm::ICmpInst *CI, Context &Ctx)
       : CmpInst(CI, Ctx, ClassID::ICmp, Opcode::ICmp) {}
   friend class Context; // For constructor.
@@ -3263,17 +3257,13 @@ class ICmpInst : public CmpInst {
     return llvm::ICmpInst::compare(LHS, RHS, Pred);
   }
 
-  static bool classof(const Instruction *From) {
-    return From->getSubclassID() == ClassID::ICmp;
-  }
   static bool classof(const Value *From) {
-    return isa<Instruction>(From) && classof(cast<Instruction>(From));
+    return From->getSubclassID() == ClassID::ICmp;
   }
 };
 
 class FCmpInst : public CmpInst {
-  /// Use Context::createFCmpInst(). Don't call the
-  /// constructor directly.
+  /// Use Context::createFCmpInst(). Don't call the constructor directly.
   FCmpInst(llvm::FCmpInst *CI, Context &Ctx)
       : CmpInst(CI, Ctx, ClassID::FCmp, Opcode::FCmp) {}
   friend class Context; // For constructor.
@@ -3292,9 +3282,6 @@ class FCmpInst : public CmpInst {
     return llvm::FCmpInst::compare(LHS, RHS, Pred);
   }
 
-  static bool classof(const Instruction *From) {
-    return From->getSubclassID() == ClassID::FCmp;
-  }
   static bool classof(const Value *From) {
     return From->getSubclassID() == ClassID::FCmp;
   }
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 4addbd049f1af0..0712498d7c894b 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -109,7 +109,6 @@ DEF_INSTR(Cast,   OPCODES(\
                           ),                  CastInst)
 DEF_INSTR(PHI,            OP(PHI),            PHINode)
 DEF_INSTR(Unreachable,    OP(Unreachable),    UnreachableInst)
-DEF_INSTR(Cmp,            OP(Cmp),           CmpInst)
-DEF_INSTR(ICmp,           OP(ICmp),          FCmpInst)
-DEF_INSTR(FCmp,           OP(FCmp),          ICmpInst)
+DEF_INSTR(ICmp,           OP(ICmp),           ICmpInst)
+DEF_INSTR(FCmp,           OP(FCmp),           FCmpInst)
 
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 69ed354bd928c0..20a8a300b289cd 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2701,10 +2701,6 @@ PHINode *Context::createPHINode(llvm::PHINode *I) {
   auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
   return cast<PHINode>(registerValue(std::move(NewPtr)));
 }
-CmpInst *Context::createCmpInst(llvm::CmpInst *I) {
-  auto NewPtr = std::unique_ptr<CmpInst>(new CmpInst(I, *this));
-  return cast<CmpInst>(registerValue(std::move(NewPtr)));
-}
 ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) {
   auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this));
   return cast<ICmpInst>(registerValue(std::move(NewPtr)));
@@ -2714,6 +2710,29 @@ FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
   return cast<FCmpInst>(registerValue(std::move(NewPtr)));
 }
 
+CmpInst *CmpInst::create(OtherOps Op, Predicate P, Value *S1, Value *S2,
+                         Context &Ctx, const Twine &Name,
+                         Instruction *InsertBefore) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  if (InsertBefore)
+    Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
+  auto *LLVMI =
+      cast<llvm::CmpInst>(Builder.CreateCmp(P, S1->Val, S2->Val, Name));
+  if (llvm::ICmpInst *IC = dyn_cast<llvm::ICmpInst>(LLVMI))
+    return Ctx.createICmpInst(IC);
+  else
+    return Ctx.createFCmpInst(cast<llvm::FCmpInst>(IC));
+}
+
+CmpInst *CmpInst::createWithCopiedFlags(OtherOps Op, Predicate P, Value *S1,
+                                        Value *S2, const Instruction *F,
+                                        Context &Ctx, const Twine &Name,
+                                        Instruction *InsertBefore) {
+  CmpInst *Inst = create(Op, P, S1, S2, Ctx, Name, InsertBefore);
+  cast<llvm::CmpInst>(Inst->Val)->copyIRFlags(F->Val);
+  return Inst;
+}
+
 void CmpInst::setPredicate(Predicate P) {
   auto &Tracker = Ctx.getTracker();
   if (Tracker.isTracking())
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index b4905c2231e3c8..b192e0f4bd9ca6 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -1476,9 +1476,9 @@ define void @foo(i64 %i0, i64 %i1, float %f0, float %f1) {
   auto &F = *Ctx.createFunction(&LLVMF);
   auto *BB = &*F.begin();
   auto It = BB->begin();
-  auto *FCmp = cast<sandboxir::FCmpInst>(&*It++);
+  auto *FCmp = cast<sandboxir::CmpInst>(&*It++);
   checkCmpInst(Ctx, FCmp);
-  auto *ICmp = cast<sandboxir::ICmpInst>(&*It++);
+  auto *ICmp = cast<sandboxir::CmpInst>(&*It++);
   checkCmpInst(Ctx, ICmp);
 }
 

>From 42b5d75b9a70a811df235e651ab231a5856afd01 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Thu, 29 Aug 2024 21:51:26 +0000
Subject: [PATCH 3/7] Address comments

---
 llvm/lib/SandboxIR/SandboxIR.cpp           | 13 +++++-----
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 28 +++++++++++++++++++++-
 2 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 20a8a300b289cd..70207267fe04b0 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2716,12 +2716,13 @@ CmpInst *CmpInst::create(OtherOps Op, Predicate P, Value *S1, Value *S2,
   auto &Builder = Ctx.getLLVMIRBuilder();
   if (InsertBefore)
     Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
-  auto *LLVMI =
-      cast<llvm::CmpInst>(Builder.CreateCmp(P, S1->Val, S2->Val, Name));
-  if (llvm::ICmpInst *IC = dyn_cast<llvm::ICmpInst>(LLVMI))
-    return Ctx.createICmpInst(IC);
-  else
-    return Ctx.createFCmpInst(cast<llvm::FCmpInst>(IC));
+  auto *LLVMI = Builder.CreateCmp(P, S1->Val, S2->Val, Name);
+  if (Op == OtherOps::ICmp)
+    return Ctx.createICmpInst(cast<llvm::ICmpInst>(LLVMI));
+  else {
+    assert(Op == OtherOps::FCmp);
+    return Ctx.createFCmpInst(cast<llvm::FCmpInst>(LLVMI));
+  }
 }
 
 CmpInst *CmpInst::createWithCopiedFlags(OtherOps Op, Predicate P, Value *S1,
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 65e6e93a97781a..bbccb042013d6a 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -4447,7 +4447,7 @@ void checkCommonPredicates(sandboxir::CmpInst *SBCmp, llvm::CmpInst *LLVMCmp) {
 TEST_F(SandboxIRTest, ICmpInst) {
   SCOPED_TRACE("SandboxIRTest sandboxir::ICmpInst tests");
   parseIR(C, R"IR(
-define void @foo(float %f0, float %f1, i32 %i0, i32 %i1) {
+define void @foo(i32 %i0, i32 %i1) {
  bb:
   %ine  = icmp ne i32 %i0, %i1
   %iugt = icmp ugt i32 %i0, %i1
@@ -4478,6 +4478,10 @@ define void @foo(float %f0, float %f1, i32 %i0, i32 %i1) {
     EXPECT_EQ(ICmp->getSignedPredicate(), LLVMICmp->getSignedPredicate());
     EXPECT_EQ(ICmp->getUnsignedPredicate(), LLVMICmp->getUnsignedPredicate());
   }
+  auto *NewCmp = sandboxir::CmpInst::create(
+      CmpInst::OtherOps::ICmp, llvm::CmpInst::ICMP_ULE, F.getArg(0),
+      F.getArg(1), Ctx, "", &*BB->begin());
+  EXPECT_EQ(NewCmp, &*BB->begin());
 }
 
 TEST_F(SandboxIRTest, FCmpInst) {
@@ -4501,6 +4505,9 @@ define void @foo(float %f0, float %f1) {
   %fune = fcmp une float %f0, %f1
   %ftrue = fcmp true float %f0, %f1
   ret void
+bb1:
+  %copyfrom = fadd reassoc float %f0, 42.0
+  ret void
 }
 )IR");
   Function &LLVMF = *M->getFunction("foo");
@@ -4517,6 +4524,25 @@ define void @foo(float %f0, float %f1) {
     checkSwapOperands(Ctx, FCmp, LLVMFCmp);
     checkCommonPredicates(FCmp, LLVMFCmp);
   }
+
+  auto *LLVMBB1 = getBasicBlockByName(LLVMF, "bb1");
+  auto *BB1 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB1));
+  auto It1 = BB1->begin();
+  auto *CopyFrom = &*It1++;
+  CopyFrom->setFastMathFlags(FastMathFlags::getFast());
+
+  // create with default flags
+  auto *NewFCmp = sandboxir::CmpInst::create(
+      CmpInst::OtherOps::FCmp, llvm::CmpInst::FCMP_ONE, F.getArg(0),
+      F.getArg(1), Ctx, "", &*It1);
+  FastMathFlags DefaultFMF = NewFCmp->getFastMathFlags();
+  EXPECT_TRUE(CopyFrom->getFastMathFlags() != DefaultFMF);
+  // create with copied flags
+  auto *NewFCmpFlags = sandboxir::CmpInst::createWithCopiedFlags(
+      CmpInst::OtherOps::FCmp, llvm::CmpInst::FCMP_ONE, F.getArg(0),
+      F.getArg(1), CopyFrom, Ctx, "", &*It1);
+  EXPECT_FALSE(NewFCmpFlags->getFastMathFlags() !=
+               CopyFrom->getFastMathFlags());
 }
 
 TEST_F(SandboxIRTest, UnreachableInst) {

>From 17ca863e4120ed18bdfe33499ed67244a43bd455 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Fri, 30 Aug 2024 17:57:41 +0000
Subject: [PATCH 4/7] Address more comments.

---
 llvm/include/llvm/SandboxIR/SandboxIR.h    | 34 ++--------
 llvm/include/llvm/SandboxIR/Tracker.h      | 15 -----
 llvm/lib/SandboxIR/SandboxIR.cpp           | 37 ++++-------
 llvm/lib/SandboxIR/Tracker.cpp             | 10 ---
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 76 +++++++++++-----------
 5 files changed, 59 insertions(+), 113 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index e2a3cf1902189b..1555654b4c079e 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -751,8 +751,6 @@ class Instruction : public sandboxir::User {
   friend class PHINode;            // For getTopmostLLVMInstruction().
   friend class UnreachableInst;    // For getTopmostLLVMInstruction().
   friend class CmpInst;            // For getTopmostLLVMInstruction().
-  friend class ICmpInst;           // For getTopmostLLVMInstruction().
-  friend class FCmpInst;           // For getTopmostLLVMInstruction().
 
   /// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
   /// order.
@@ -909,6 +907,7 @@ template <typename LLVMT> class SingleLLVMInstructionImpl : public Instruction {
   friend class UnaryInstruction;
   friend class CallBase;
   friend class FuncletPadInst;
+  friend class CmpInst;
 
   Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
     return getOperandUseDefault(OpIdx, Verify);
@@ -3128,49 +3127,33 @@ class PHINode final : public SingleLLVMInstructionImpl<llvm::PHINode> {
 // Wraps a member function that takes no parameters
 // LLVMValType should be the type of the wrapped class
 #define WRAP_MEMBER(FunctionName)                                              \
-  auto FunctionName() { return cast<LLVMValType>(Val)->FunctionName(); }
+  auto FunctionName() const { return cast<LLVMValType>(Val)->FunctionName(); }
 // Wraps both--a common idiom in the CmpInst classes
 #define WRAP_BOTH(FunctionName)                                                \
   WRAP_STATIC_PREDICATE(FunctionName)                                          \
   WRAP_MEMBER(FunctionName)
 
-class CmpInst : public Instruction {
+class CmpInst : public SingleLLVMInstructionImpl<llvm::CmpInst> {
 protected:
   using LLVMValType = llvm::CmpInst;
   /// Use Context::createCmpInst(). Don't call the constructor directly.
   CmpInst(llvm::CmpInst *CI, Context &Ctx, ClassID Id, Opcode Opc)
-      : Instruction(Id, Opc, CI, Ctx) {}
+      : SingleLLVMInstructionImpl(Id, Opc, CI, Ctx) {}
   friend Context; // for CmpInst()
-  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
-    return getOperandUseDefault(OpIdx, Verify);
-  }
-  SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
-    return {cast<llvm::Instruction>(Val)};
-  }
   static Value *createCommon(Value *Cond, Value *True, Value *False,
                              const Twine &Name, IRBuilder<> &Builder,
                              Context &Ctx);
 
 public:
   using Predicate = llvm::CmpInst::Predicate;
-  using PredicateField = llvm::CmpInst::PredicateField;
-  using OtherOps = llvm::Instruction::OtherOps;
 
-  unsigned getUseOperandNo(const Use &Use) const final {
-    return getUseOperandNoDefault(Use);
-  }
-  unsigned getNumOfIRInstrs() const final { return 1u; }
-  static CmpInst *create(OtherOps Op, Predicate Pred, Value *S1, Value *S2,
-                         Context &Ctx, const Twine &Name = "",
+  static CmpInst *create(Predicate Pred, Value *S1, Value *S2, Context &Ctx,
+                         const Twine &Name = "",
                          Instruction *InsertBefore = nullptr);
-  static CmpInst *createWithCopiedFlags(OtherOps Op, Predicate Pred, Value *S1,
-                                        Value *S2,
+  static CmpInst *createWithCopiedFlags(Predicate Pred, Value *S1, Value *S2,
                                         const Instruction *FlagsSource,
                                         Context &Ctx, const Twine &Name = "",
                                         Instruction *InsertBefore = nullptr);
-  OtherOps getOpcode() const {
-    return static_cast<OtherOps>(Instruction::getOpcode());
-  }
   void setPredicate(Predicate P);
   void swapOperands();
 
@@ -3223,9 +3206,6 @@ class CmpInst : public Instruction {
   }
 
 #ifndef NDEBUG
-  void verify() const final {
-    assert(isa<LLVMValType>(Val) && "Expected CmpInst!");
-  }
   void dumpOS(raw_ostream &OS) const override;
   LLVM_DUMP_METHOD void dump() const;
 #endif
diff --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h
index d3d0bf3dfe7803..5fc43db82bd707 100644
--- a/llvm/include/llvm/SandboxIR/Tracker.h
+++ b/llvm/include/llvm/SandboxIR/Tracker.h
@@ -64,7 +64,6 @@ class SwitchInst;
 class ConstantInt;
 class ShuffleVectorInst;
 class CmpInst;
-
 /// The base class for IR Change classes.
 class IRChangeBase {
 protected:
@@ -131,20 +130,6 @@ class PHIAddIncoming : public IRChangeBase {
 #endif
 };
 
-class CmpSetPredicate : public IRChangeBase {
-  CmpInst *Cmp;
-  llvm::CmpInst::Predicate OldP;
-
-public:
-  CmpSetPredicate(CmpInst *Cmp);
-  void revert(Tracker &Tracker) final;
-  void accept() final {}
-#ifndef NDEBUG
-  void dump(raw_ostream &OS) const final { OS << "CmpSetPredicate"; }
-  LLVM_DUMP_METHOD void dump() const final;
-#endif
-};
-
 class CmpSwapOperands : public IRChangeBase {
   CmpInst *Cmp;
 
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 70207267fe04b0..29e2ffcc405b05 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2710,34 +2710,29 @@ FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
   return cast<FCmpInst>(registerValue(std::move(NewPtr)));
 }
 
-CmpInst *CmpInst::create(OtherOps Op, Predicate P, Value *S1, Value *S2,
-                         Context &Ctx, const Twine &Name,
-                         Instruction *InsertBefore) {
+CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2, Context &Ctx,
+                         const Twine &Name, Instruction *InsertBefore) {
   auto &Builder = Ctx.getLLVMIRBuilder();
-  if (InsertBefore)
-    Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
+  Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
   auto *LLVMI = Builder.CreateCmp(P, S1->Val, S2->Val, Name);
-  if (Op == OtherOps::ICmp)
+  if (dyn_cast<llvm::ICmpInst>(LLVMI))
     return Ctx.createICmpInst(cast<llvm::ICmpInst>(LLVMI));
-  else {
-    assert(Op == OtherOps::FCmp);
-    return Ctx.createFCmpInst(cast<llvm::FCmpInst>(LLVMI));
-  }
+  return Ctx.createFCmpInst(cast<llvm::FCmpInst>(LLVMI));
 }
 
-CmpInst *CmpInst::createWithCopiedFlags(OtherOps Op, Predicate P, Value *S1,
-                                        Value *S2, const Instruction *F,
-                                        Context &Ctx, const Twine &Name,
+CmpInst *CmpInst::createWithCopiedFlags(Predicate P, Value *S1, Value *S2,
+                                        const Instruction *F, Context &Ctx,
+                                        const Twine &Name,
                                         Instruction *InsertBefore) {
-  CmpInst *Inst = create(Op, P, S1, S2, Ctx, Name, InsertBefore);
+  CmpInst *Inst = create(P, S1, S2, Ctx, Name, InsertBefore);
   cast<llvm::CmpInst>(Inst->Val)->copyIRFlags(F->Val);
   return Inst;
 }
 
 void CmpInst::setPredicate(Predicate P) {
-  auto &Tracker = Ctx.getTracker();
-  if (Tracker.isTracking())
-    Tracker.track(std::make_unique<CmpSetPredicate>(this));
+  Ctx.getTracker()
+      .emplaceIfTracking<
+          GenericSetter<&CmpInst::getPredicate, &CmpInst::setPredicate>>(this);
   cast<llvm::CmpInst>(Val)->setPredicate(P);
 }
 
@@ -2749,16 +2744,12 @@ void CmpInst::swapOperands() {
 }
 
 void ICmpInst::swapOperands() {
-  auto &Tracker = Ctx.getTracker();
-  if (Tracker.isTracking())
-    Tracker.track(std::make_unique<CmpSwapOperands>(this));
+  Ctx.getTracker().emplaceIfTracking<CmpSwapOperands>(this);
   cast<llvm::ICmpInst>(Val)->swapOperands();
 }
 
 void FCmpInst::swapOperands() {
-  auto &Tracker = Ctx.getTracker();
-  if (Tracker.isTracking())
-    Tracker.track(std::make_unique<CmpSwapOperands>(this));
+  Ctx.getTracker().emplaceIfTracking<CmpSwapOperands>(this);
   cast<llvm::FCmpInst>(Val)->swapOperands();
 }
 
diff --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp
index 9a824d5211a407..c6eb9fc68a4b11 100644
--- a/llvm/lib/SandboxIR/Tracker.cpp
+++ b/llvm/lib/SandboxIR/Tracker.cpp
@@ -248,16 +248,6 @@ void ShuffleVectorSetMask::dump() const {
 }
 #endif
 
-CmpSetPredicate::CmpSetPredicate(CmpInst *Cmp)
-    : IRChangeBase(), Cmp(Cmp), OldP(Cmp->getPredicate()) {}
-
-void CmpSetPredicate::revert(Tracker &Tracker) { Cmp->setPredicate(OldP); }
-
-void CmpSetPredicate::dump() const {
-  dump(dbgs());
-  dbgs() << "\n";
-}
-
 CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {}
 
 void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); }
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index bbccb042013d6a..2fec35b1d5962b 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -4392,55 +4392,56 @@ define void @foo(i32 %arg) {
   EXPECT_EQ(NewPHI->getNumIncomingValues(), PHI->getNumIncomingValues());
 }
 
-void checkSwapOperands(sandboxir::Context &Ctx, llvm::sandboxir::CmpInst *SBCmp,
-                       llvm::CmpInst *LLVMCmp) {
-  auto OrigOp0 = SBCmp->getOperand(0);
-  auto OrigOp1 = SBCmp->getOperand(1);
+static void checkSwapOperands(sandboxir::Context &Ctx,
+                              llvm::sandboxir::CmpInst *Cmp,
+                              llvm::CmpInst *LLVMCmp) {
+  auto OrigOp0 = Cmp->getOperand(0);
+  auto OrigOp1 = Cmp->getOperand(1);
   EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp0);
   EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp1);
   // This checks the dispatch mechanism in CmpInst, as well as
   // the specific implementations.
-  SBCmp->swapOperands();
-  EXPECT_NE(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp0);
-  EXPECT_NE(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp1);
+  Cmp->swapOperands();
   EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp0);
   EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp1);
   // Undo it to keep the rest of the test consistent
-  SBCmp->swapOperands();
+  Cmp->swapOperands();
 }
 
-void checkCommonPredicates(sandboxir::CmpInst *SBCmp, llvm::CmpInst *LLVMCmp) {
+static void checkCommonPredicates(sandboxir::CmpInst *Cmp,
+                                  llvm::CmpInst *LLVMCmp) {
   // Check proper creation
-  auto SBPred = SBCmp->getPredicate();
+  auto Pred = Cmp->getPredicate();
   auto LLVMPred = LLVMCmp->getPredicate();
-  EXPECT_EQ(SBPred, LLVMPred);
+  EXPECT_EQ(Pred, LLVMPred);
   // Check setPredicate
-  SBCmp->setPredicate(llvm::CmpInst::FCMP_FALSE);
+  Cmp->setPredicate(llvm::CmpInst::FCMP_FALSE);
+  EXPECT_EQ(Cmp->getPredicate(), llvm::CmpInst::FCMP_FALSE);
   EXPECT_EQ(LLVMCmp->getPredicate(), llvm::CmpInst::FCMP_FALSE);
-  SBCmp->setPredicate(SBPred);
-  EXPECT_EQ(LLVMCmp->getPredicate(), SBPred);
+  Cmp->setPredicate(Pred);
+  EXPECT_EQ(LLVMCmp->getPredicate(), Pred);
   // Ensure the accessors properly forward to the underlying implementation
-  EXPECT_STREQ(sandboxir::CmpInst::getPredicateName(SBPred).data(),
+  EXPECT_STREQ(sandboxir::CmpInst::getPredicateName(Pred).data(),
                llvm::CmpInst::getPredicateName(LLVMPred).data());
-  EXPECT_EQ(SBCmp->isFPPredicate(), LLVMCmp->isFPPredicate());
-  EXPECT_EQ(SBCmp->isIntPredicate(), LLVMCmp->isIntPredicate());
-  EXPECT_EQ(SBCmp->getInversePredicate(), LLVMCmp->getInversePredicate());
-  EXPECT_EQ(SBCmp->getOrderedPredicate(), LLVMCmp->getOrderedPredicate());
-  EXPECT_EQ(SBCmp->getUnorderedPredicate(), LLVMCmp->getUnorderedPredicate());
-  EXPECT_EQ(SBCmp->getSwappedPredicate(), LLVMCmp->getSwappedPredicate());
-  EXPECT_EQ(SBCmp->isStrictPredicate(), LLVMCmp->isStrictPredicate());
-  EXPECT_EQ(SBCmp->isNonStrictPredicate(), LLVMCmp->isNonStrictPredicate());
-  EXPECT_EQ(SBCmp->isRelational(), LLVMCmp->isRelational());
-  if (SBCmp->isRelational()) {
-    EXPECT_EQ(SBCmp->getFlippedStrictnessPredicate(),
+  EXPECT_EQ(Cmp->isFPPredicate(), LLVMCmp->isFPPredicate());
+  EXPECT_EQ(Cmp->isIntPredicate(), LLVMCmp->isIntPredicate());
+  EXPECT_EQ(Cmp->getInversePredicate(), LLVMCmp->getInversePredicate());
+  EXPECT_EQ(Cmp->getOrderedPredicate(), LLVMCmp->getOrderedPredicate());
+  EXPECT_EQ(Cmp->getUnorderedPredicate(), LLVMCmp->getUnorderedPredicate());
+  EXPECT_EQ(Cmp->getSwappedPredicate(), LLVMCmp->getSwappedPredicate());
+  EXPECT_EQ(Cmp->isStrictPredicate(), LLVMCmp->isStrictPredicate());
+  EXPECT_EQ(Cmp->isNonStrictPredicate(), LLVMCmp->isNonStrictPredicate());
+  EXPECT_EQ(Cmp->isRelational(), LLVMCmp->isRelational());
+  if (Cmp->isRelational()) {
+    EXPECT_EQ(Cmp->getFlippedStrictnessPredicate(),
               LLVMCmp->getFlippedStrictnessPredicate());
   }
-  EXPECT_EQ(SBCmp->isCommutative(), LLVMCmp->isCommutative());
-  EXPECT_EQ(SBCmp->isTrueWhenEqual(), LLVMCmp->isTrueWhenEqual());
-  EXPECT_EQ(SBCmp->isFalseWhenEqual(), LLVMCmp->isFalseWhenEqual());
-  EXPECT_EQ(sandboxir::CmpInst::isOrdered(SBPred),
+  EXPECT_EQ(Cmp->isCommutative(), LLVMCmp->isCommutative());
+  EXPECT_EQ(Cmp->isTrueWhenEqual(), LLVMCmp->isTrueWhenEqual());
+  EXPECT_EQ(Cmp->isFalseWhenEqual(), LLVMCmp->isFalseWhenEqual());
+  EXPECT_EQ(sandboxir::CmpInst::isOrdered(Pred),
             llvm::CmpInst::isOrdered(LLVMPred));
-  EXPECT_EQ(sandboxir::CmpInst::isUnordered(SBPred),
+  EXPECT_EQ(sandboxir::CmpInst::isUnordered(Pred),
             llvm::CmpInst::isUnordered(LLVMPred));
 }
 
@@ -4478,9 +4479,9 @@ define void @foo(i32 %i0, i32 %i1) {
     EXPECT_EQ(ICmp->getSignedPredicate(), LLVMICmp->getSignedPredicate());
     EXPECT_EQ(ICmp->getUnsignedPredicate(), LLVMICmp->getUnsignedPredicate());
   }
-  auto *NewCmp = sandboxir::CmpInst::create(
-      CmpInst::OtherOps::ICmp, llvm::CmpInst::ICMP_ULE, F.getArg(0),
-      F.getArg(1), Ctx, "", &*BB->begin());
+  auto *NewCmp =
+      sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, F.getArg(0),
+                                 F.getArg(1), Ctx, "", &*BB->begin());
   EXPECT_EQ(NewCmp, &*BB->begin());
 }
 
@@ -4533,14 +4534,13 @@ define void @foo(float %f0, float %f1) {
 
   // create with default flags
   auto *NewFCmp = sandboxir::CmpInst::create(
-      CmpInst::OtherOps::FCmp, llvm::CmpInst::FCMP_ONE, F.getArg(0),
-      F.getArg(1), Ctx, "", &*It1);
+      llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), Ctx, "", &*It1);
   FastMathFlags DefaultFMF = NewFCmp->getFastMathFlags();
   EXPECT_TRUE(CopyFrom->getFastMathFlags() != DefaultFMF);
   // create with copied flags
   auto *NewFCmpFlags = sandboxir::CmpInst::createWithCopiedFlags(
-      CmpInst::OtherOps::FCmp, llvm::CmpInst::FCMP_ONE, F.getArg(0),
-      F.getArg(1), CopyFrom, Ctx, "", &*It1);
+      llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), CopyFrom, Ctx, "",
+      &*It1);
   EXPECT_FALSE(NewFCmpFlags->getFastMathFlags() !=
                CopyFrom->getFastMathFlags());
 }

>From 6508dcfffe7c2ef0e2ef0058b012217686a0778a Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Fri, 30 Aug 2024 18:59:56 +0000
Subject: [PATCH 5/7] Address more comments.

---
 llvm/include/llvm/SandboxIR/SandboxIR.h    | 10 +++++-----
 llvm/lib/SandboxIR/SandboxIR.cpp           | 13 +++++++------
 llvm/unittests/SandboxIR/SandboxIRTest.cpp |  8 ++++----
 3 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 1555654b4c079e..b8c238aa9adec9 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -3147,13 +3147,13 @@ class CmpInst : public SingleLLVMInstructionImpl<llvm::CmpInst> {
 public:
   using Predicate = llvm::CmpInst::Predicate;
 
-  static CmpInst *create(Predicate Pred, Value *S1, Value *S2, Context &Ctx,
-                         const Twine &Name = "",
-                         Instruction *InsertBefore = nullptr);
+  static CmpInst *create(Predicate Pred, Value *S1, Value *S2,
+                         Instruction *InsertBefore, Context &Ctx,
+                         const Twine &Name = "");
   static CmpInst *createWithCopiedFlags(Predicate Pred, Value *S1, Value *S2,
                                         const Instruction *FlagsSource,
-                                        Context &Ctx, const Twine &Name = "",
-                                        Instruction *InsertBefore = nullptr);
+                                        Instruction *InsertBefore, Context &Ctx,
+                                        const Twine &Name = "");
   void setPredicate(Predicate P);
   void swapOperands();
 
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 29e2ffcc405b05..f66b7a3eaa4684 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2710,8 +2710,9 @@ FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
   return cast<FCmpInst>(registerValue(std::move(NewPtr)));
 }
 
-CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2, Context &Ctx,
-                         const Twine &Name, Instruction *InsertBefore) {
+CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2,
+                         Instruction *InsertBefore, Context &Ctx,
+                         const Twine &Name) {
   auto &Builder = Ctx.getLLVMIRBuilder();
   Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
   auto *LLVMI = Builder.CreateCmp(P, S1->Val, S2->Val, Name);
@@ -2721,10 +2722,10 @@ CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2, Context &Ctx,
 }
 
 CmpInst *CmpInst::createWithCopiedFlags(Predicate P, Value *S1, Value *S2,
-                                        const Instruction *F, Context &Ctx,
-                                        const Twine &Name,
-                                        Instruction *InsertBefore) {
-  CmpInst *Inst = create(P, S1, S2, Ctx, Name, InsertBefore);
+                                        const Instruction *F,
+                                        Instruction *InsertBefore, Context &Ctx,
+                                        const Twine &Name) {
+  CmpInst *Inst = create(P, S1, S2, InsertBefore, Ctx, Name);
   cast<llvm::CmpInst>(Inst->Val)->copyIRFlags(F->Val);
   return Inst;
 }
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 2fec35b1d5962b..4406415db94355 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -4481,7 +4481,7 @@ define void @foo(i32 %i0, i32 %i1) {
   }
   auto *NewCmp =
       sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, F.getArg(0),
-                                 F.getArg(1), Ctx, "", &*BB->begin());
+                                 F.getArg(1), &*BB->begin(), Ctx, "");
   EXPECT_EQ(NewCmp, &*BB->begin());
 }
 
@@ -4534,13 +4534,13 @@ define void @foo(float %f0, float %f1) {
 
   // create with default flags
   auto *NewFCmp = sandboxir::CmpInst::create(
-      llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), Ctx, "", &*It1);
+      llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), &*It1, Ctx, "");
   FastMathFlags DefaultFMF = NewFCmp->getFastMathFlags();
   EXPECT_TRUE(CopyFrom->getFastMathFlags() != DefaultFMF);
   // create with copied flags
   auto *NewFCmpFlags = sandboxir::CmpInst::createWithCopiedFlags(
-      llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), CopyFrom, Ctx, "",
-      &*It1);
+      llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), CopyFrom, &*It1, Ctx,
+      "");
   EXPECT_FALSE(NewFCmpFlags->getFastMathFlags() !=
                CopyFrom->getFastMathFlags());
 }

>From 4efbddd93159e5e3f3709ea2df9df78225837906 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Tue, 3 Sep 2024 16:59:38 -0700
Subject: [PATCH 6/7] Better support for makeCmpResultType

---
 llvm/include/llvm/SandboxIR/SandboxIR.h    | 10 +++-------
 llvm/include/llvm/SandboxIR/Type.h         |  3 +++
 llvm/lib/SandboxIR/SandboxIR.cpp           | 13 +++++++++++--
 llvm/unittests/SandboxIR/SandboxIRTest.cpp |  5 +++++
 4 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 11f462ef5e86e8..1233b57cede7e8 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -3267,13 +3267,7 @@ class CmpInst : public SingleLLVMInstructionImpl<llvm::CmpInst> {
   }
 
   /// Create a result type for fcmp/icmp
-  static Type *makeCmpResultType(Type *OpndType) {
-    if (VectorType *vt = dyn_cast<VectorType>(OpndType)) {
-      return VectorType::get(Type::getInt1Ty(OpndType->getContext()),
-                             vt->getElementCount());
-    }
-    return Type::getInt1Ty(OpndType->getContext());
-  }
+  static Type *makeCmpResultType(Type *OpndType);
 
 #ifndef NDEBUG
   void dumpOS(raw_ostream &OS) const override;
@@ -3361,6 +3355,8 @@ class Context {
   LLVMContext &LLVMCtx;
   friend class Type;        // For LLVMCtx.
   friend class PointerType; // For LLVMCtx.
+  friend class CmpInst; // For LLVMCtx. TODO: cleanup when sandboxir::VectorType
+                        // is complete
   Tracker IRTracker;
 
   /// Maps LLVM Value to the corresponding sandboxir::Value. Owns all
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 89e787f5f5d4b2..1f16ad02b30cec 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -41,6 +41,9 @@ class Type {
   friend class Function;     // For LLVMTy.
   friend class CallBase;     // For LLVMTy.
   friend class ConstantInt;  // For LLVMTy.
+  friend class CmpInst; // For LLVMTy. TODO: Cleanup after sandboxir::VectorType
+                        // is more complete.
+
   // Friend all instruction classes because `create()` functions use LLVMTy.
 #define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
 #define DEF_CONST(ID, CLASS) friend class CLASS;
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 94af9046c3a9bd..e5ff88f92c42f1 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2826,7 +2826,6 @@ FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
   auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this));
   return cast<FCmpInst>(registerValue(std::move(NewPtr)));
 }
-
 CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2,
                          Instruction *InsertBefore, Context &Ctx,
                          const Twine &Name) {
@@ -2837,7 +2836,6 @@ CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2,
     return Ctx.createICmpInst(cast<llvm::ICmpInst>(LLVMI));
   return Ctx.createFCmpInst(cast<llvm::FCmpInst>(LLVMI));
 }
-
 CmpInst *CmpInst::createWithCopiedFlags(Predicate P, Value *S1, Value *S2,
                                         const Instruction *F,
                                         Instruction *InsertBefore, Context &Ctx,
@@ -2847,6 +2845,17 @@ CmpInst *CmpInst::createWithCopiedFlags(Predicate P, Value *S1, Value *S2,
   return Inst;
 }
 
+Type *CmpInst::makeCmpResultType(Type *OpndType) {
+  if (auto *VT = dyn_cast<VectorType>(OpndType)) {
+    // TODO: Clean up when we have more complete support for
+    // sandboxir::VectorType
+    return OpndType->getContext().getType(llvm::VectorType::get(
+        llvm::Type::getInt1Ty(OpndType->getContext().LLVMCtx),
+        cast<llvm::VectorType>(VT->LLVMTy)->getElementCount()));
+  }
+  return Type::getInt1Ty(OpndType->getContext());
+}
+
 void CmpInst::setPredicate(Predicate P) {
   Ctx.getTracker()
       .emplaceIfTracking<
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 5d93eccaac4e71..21b99b8a50dbe5 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -4656,6 +4656,11 @@ define void @foo(i32 %i0, i32 %i1) {
       sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, F.getArg(0),
                                  F.getArg(1), &*BB->begin(), Ctx, "");
   EXPECT_EQ(NewCmp, &*BB->begin());
+  // TODO: Improve this test when sandboxir::VectorType is more completely
+  // implemented.
+  sandboxir::Type *RT =
+      sandboxir::CmpInst::makeCmpResultType(F.getArg(0)->getType());
+  EXPECT_TRUE(RT->isIntegerTy(1)); // Only one bit in a single comparison
 }
 
 TEST_F(SandboxIRTest, FCmpInst) {

>From 1a8a4fb79d681d6dc1900b236507408e4952ae9b Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Tue, 3 Sep 2024 17:30:59 -0700
Subject: [PATCH 7/7] Address more comments

---
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 26 +++++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 21b99b8a50dbe5..2925a2889395de 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -4577,6 +4577,8 @@ static void checkSwapOperands(sandboxir::Context &Ctx,
   Cmp->swapOperands();
   EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp0);
   EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp1);
+  EXPECT_EQ(Cmp->getOperand(0), OrigOp1);
+  EXPECT_EQ(Cmp->getOperand(1), OrigOp0);
   // Undo it to keep the rest of the test consistent
   Cmp->swapOperands();
 }
@@ -4654,8 +4656,14 @@ define void @foo(i32 %i0, i32 %i1) {
   }
   auto *NewCmp =
       sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, F.getArg(0),
-                                 F.getArg(1), &*BB->begin(), Ctx, "");
+                                 F.getArg(1), &*BB->begin(), Ctx, "NewCmp");
   EXPECT_EQ(NewCmp, &*BB->begin());
+  EXPECT_EQ(NewCmp->getPredicate(), llvm::CmpInst::ICMP_ULE);
+  EXPECT_EQ(NewCmp->getOperand(0), F.getArg(0));
+  EXPECT_EQ(NewCmp->getOperand(1), F.getArg(1));
+#ifndef NDEBUG
+  EXPECT_EQ(NewCmp->getName(), "NewCmp");
+#endif // NDEBUG
   // TODO: Improve this test when sandboxir::VectorType is more completely
   // implemented.
   sandboxir::Type *RT =
@@ -4712,15 +4720,27 @@ define void @foo(float %f0, float %f1) {
 
   // create with default flags
   auto *NewFCmp = sandboxir::CmpInst::create(
-      llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), &*It1, Ctx, "");
+      llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), &*It1, Ctx, "NewFCmp");
+  EXPECT_EQ(NewFCmp->getPredicate(), llvm::CmpInst::FCMP_ONE);
+  EXPECT_EQ(NewFCmp->getOperand(0), F.getArg(0));
+  EXPECT_EQ(NewFCmp->getOperand(1), F.getArg(1));
+#ifndef NDEBUG
+  EXPECT_EQ(NewFCmp->getName(), "NewFCmp");
+#endif // NDEBUG
   FastMathFlags DefaultFMF = NewFCmp->getFastMathFlags();
   EXPECT_TRUE(CopyFrom->getFastMathFlags() != DefaultFMF);
   // create with copied flags
   auto *NewFCmpFlags = sandboxir::CmpInst::createWithCopiedFlags(
       llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), CopyFrom, &*It1, Ctx,
-      "");
+      "NewFCmpFlags");
   EXPECT_FALSE(NewFCmpFlags->getFastMathFlags() !=
                CopyFrom->getFastMathFlags());
+  EXPECT_EQ(NewFCmpFlags->getPredicate(), llvm::CmpInst::FCMP_ONE);
+  EXPECT_EQ(NewFCmpFlags->getOperand(0), F.getArg(0));
+  EXPECT_EQ(NewFCmpFlags->getOperand(1), F.getArg(1));
+#ifndef NDEBUG
+  EXPECT_EQ(NewFCmpFlags->getName(), "NewFCmpFlags");
+#endif // NDEBUG
 }
 
 TEST_F(SandboxIRTest, UnreachableInst) {



More information about the llvm-commits mailing list