[llvm] 840da2e - [SandboxIR] Implement CmpInst, FCmpInst, and ICmpInst (#106301)

via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 4 10:54:21 PDT 2024


Author: Sterling-Augustine
Date: 2024-09-04T10:54:17-07:00
New Revision: 840da2e8ba7e0f77938adfc6f6d315137542a1b8

URL: https://github.com/llvm/llvm-project/commit/840da2e8ba7e0f77938adfc6f6d315137542a1b8
DIFF: https://github.com/llvm/llvm-project/commit/840da2e8ba7e0f77938adfc6f6d315137542a1b8.diff

LOG: [SandboxIR] Implement CmpInst, FCmpInst, and ICmpInst (#106301)

As in the description.

Not sure whether the "WRAP_XXX" macros add value, but they do save some
boilerplate. Maybe there is a better way.

Added: 
    

Modified: 
    llvm/include/llvm/SandboxIR/SandboxIR.h
    llvm/include/llvm/SandboxIR/SandboxIRValues.def
    llvm/include/llvm/SandboxIR/Tracker.h
    llvm/include/llvm/SandboxIR/Type.h
    llvm/lib/SandboxIR/SandboxIR.cpp
    llvm/lib/SandboxIR/Tracker.cpp
    llvm/unittests/SandboxIR/SandboxIRTest.cpp
    llvm/unittests/SandboxIR/TrackerTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 5c2d58c1b99dc3..89e963498426dc 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -54,11 +54,17 @@
 //                                      |                   |
 //                                      |                   +- ZExtInst
 //                                      |
-//                                      +- CallBase -----------+- CallBrInst
-//                                      |                      |
-//                                      +- CmpInst             +- CallInst
-//                                      |                      |
-//                                      +- ExtractElementInst  +- InvokeInst
+//                                      +- CallBase --------+- CallBrInst
+//                                      |                   |
+//                                      |                   +- CallInst
+//                                      |                   |
+//                                      |                   +- InvokeInst
+//                                      |
+//                                      +- CmpInst ---------+- ICmpInst
+//                                      |                   |
+//                                      |                   +- FCmpInst
+//                                      |
+//                                      +- ExtractElementInst
 //                                      |
 //                                      +- GetElementPtrInst
 //                                      |
@@ -158,6 +164,9 @@ class BinaryOperator;
 class PossiblyDisjointInst;
 class AtomicRMWInst;
 class AtomicCmpXchgInst;
+class CmpInst;
+class ICmpInst;
+class FCmpInst;
 
 /// Iterator for the `Use` edges of a User's operands.
 /// \Returns the operand `Use` when dereferenced.
@@ -304,6 +313,7 @@ class Value {
   friend class PHINode;               // For getting `Val`.
   friend class UnreachableInst;       // For getting `Val`.
   friend class CatchSwitchAddHandler; // For `Val`.
+  friend class CmpInst;               // For getting `Val`.
   friend class ConstantArray;         // For `Val`.
   friend class ConstantStruct;        // For `Val`.
 
@@ -1076,6 +1086,7 @@ class Instruction : public sandboxir::User {
   friend class CastInst;           // For getTopmostLLVMInstruction().
   friend class PHINode;            // For getTopmostLLVMInstruction().
   friend class UnreachableInst;    // For getTopmostLLVMInstruction().
+  friend class CmpInst;            // For getTopmostLLVMInstruction().
 
   /// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
   /// order.
@@ -1232,6 +1243,7 @@ template <typename LLVMT> class SingleLLVMInstructionImpl : public Instruction {
   friend class UnaryInstruction;
   friend class CallBase;
   friend class FuncletPadInst;
+  friend class CmpInst;
 
   Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
     return getOperandUseDefault(OpIdx, Verify);
@@ -3425,6 +3437,151 @@ class PHINode final : public SingleLLVMInstructionImpl<llvm::PHINode> {
   //                         uint32_t ToIdx = 0)
 };
 
+// Wraps a static function that takes a single Predicate parameter
+// LLVMValType should be the type of the wrapped class
+#define WRAP_STATIC_PREDICATE(FunctionName)                                    \
+  static auto FunctionName(Predicate P) { return LLVMValType::FunctionName(P); }
+// Wraps a member function that takes no parameters
+// LLVMValType should be the type of the wrapped class
+#define WRAP_MEMBER(FunctionName)                                              \
+  auto FunctionName() const { return cast<LLVMValType>(Val)->FunctionName(); }
+// Wraps both--a common idiom in the CmpInst classes
+#define WRAP_BOTH(FunctionName)                                                \
+  WRAP_STATIC_PREDICATE(FunctionName)                                          \
+  WRAP_MEMBER(FunctionName)
+
+class CmpInst : public SingleLLVMInstructionImpl<llvm::CmpInst> {
+protected:
+  using LLVMValType = llvm::CmpInst;
+  /// Use Context::createCmpInst(). Don't call the constructor directly.
+  CmpInst(llvm::CmpInst *CI, Context &Ctx, ClassID Id, Opcode Opc)
+      : SingleLLVMInstructionImpl(Id, Opc, CI, Ctx) {}
+  friend Context; // for CmpInst()
+  static Value *createCommon(Value *Cond, Value *True, Value *False,
+                             const Twine &Name, IRBuilder<> &Builder,
+                             Context &Ctx);
+
+public:
+  using Predicate = llvm::CmpInst::Predicate;
+
+  static CmpInst *create(Predicate Pred, Value *S1, Value *S2,
+                         Instruction *InsertBefore, Context &Ctx,
+                         const Twine &Name = "");
+  static CmpInst *createWithCopiedFlags(Predicate Pred, Value *S1, Value *S2,
+                                        const Instruction *FlagsSource,
+                                        Instruction *InsertBefore, Context &Ctx,
+                                        const Twine &Name = "");
+  void setPredicate(Predicate P);
+  void swapOperands();
+
+  WRAP_MEMBER(getPredicate);
+  WRAP_BOTH(isFPPredicate);
+  WRAP_BOTH(isIntPredicate);
+  WRAP_STATIC_PREDICATE(getPredicateName);
+  WRAP_BOTH(getInversePredicate);
+  WRAP_BOTH(getOrderedPredicate);
+  WRAP_BOTH(getUnorderedPredicate);
+  WRAP_BOTH(getSwappedPredicate);
+  WRAP_BOTH(isStrictPredicate);
+  WRAP_BOTH(isNonStrictPredicate);
+  WRAP_BOTH(getStrictPredicate);
+  WRAP_BOTH(getNonStrictPredicate);
+  WRAP_BOTH(getFlippedStrictnessPredicate);
+  WRAP_MEMBER(isCommutative);
+  WRAP_BOTH(isEquality);
+  WRAP_BOTH(isRelational);
+  WRAP_BOTH(isSigned);
+  WRAP_BOTH(getSignedPredicate);
+  WRAP_BOTH(getUnsignedPredicate);
+  WRAP_BOTH(getFlippedSignednessPredicate);
+  WRAP_BOTH(isTrueWhenEqual);
+  WRAP_BOTH(isFalseWhenEqual);
+  WRAP_BOTH(isUnsigned);
+  WRAP_STATIC_PREDICATE(isOrdered);
+  WRAP_STATIC_PREDICATE(isUnordered);
+
+  static bool isImpliedTrueByMatchingCmp(Predicate Pred1, Predicate Pred2) {
+    return llvm::CmpInst::isImpliedTrueByMatchingCmp(Pred1, Pred2);
+  }
+  static bool isImpliedFalseByMatchingCmp(Predicate Pred1, Predicate Pred2) {
+    return llvm::CmpInst::isImpliedFalseByMatchingCmp(Pred1, Pred2);
+  }
+
+  /// Method for support type inquiry through isa, cast, and dyn_cast:
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::ICmp ||
+           From->getSubclassID() == ClassID::FCmp;
+  }
+
+  /// Create a result type for fcmp/icmp
+  static Type *makeCmpResultType(Type *OpndType);
+
+#ifndef NDEBUG
+  void dumpOS(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
+class ICmpInst : public CmpInst {
+  /// Use Context::createICmpInst(). Don't call the constructor directly.
+  ICmpInst(llvm::ICmpInst *CI, Context &Ctx)
+      : CmpInst(CI, Ctx, ClassID::ICmp, Opcode::ICmp) {}
+  friend class Context; // For constructor.
+  using LLVMValType = llvm::ICmpInst;
+
+public:
+  void swapOperands();
+
+  WRAP_BOTH(getSignedPredicate);
+  WRAP_BOTH(getUnsignedPredicate);
+  WRAP_BOTH(isEquality);
+  WRAP_MEMBER(isCommutative);
+  WRAP_MEMBER(isRelational);
+  WRAP_STATIC_PREDICATE(isGT);
+  WRAP_STATIC_PREDICATE(isLT);
+  WRAP_STATIC_PREDICATE(isGE);
+  WRAP_STATIC_PREDICATE(isLE);
+
+  static auto predicates() { return llvm::ICmpInst::predicates(); }
+  static bool compare(const APInt &LHS, const APInt &RHS,
+                      ICmpInst::Predicate Pred) {
+    return llvm::ICmpInst::compare(LHS, RHS, Pred);
+  }
+
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::ICmp;
+  }
+};
+
+class FCmpInst : public CmpInst {
+  /// Use Context::createFCmpInst(). Don't call the constructor directly.
+  FCmpInst(llvm::FCmpInst *CI, Context &Ctx)
+      : CmpInst(CI, Ctx, ClassID::FCmp, Opcode::FCmp) {}
+  friend class Context; // For constructor.
+  using LLVMValType = llvm::FCmpInst;
+
+public:
+  void swapOperands();
+
+  WRAP_BOTH(isEquality);
+  WRAP_MEMBER(isCommutative);
+  WRAP_MEMBER(isRelational);
+
+  static auto predicates() { return llvm::FCmpInst::predicates(); }
+  static bool compare(const APFloat &LHS, const APFloat &RHS,
+                      FCmpInst::Predicate Pred) {
+    return llvm::FCmpInst::compare(LHS, RHS, Pred);
+  }
+
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::FCmp;
+  }
+};
+
+#undef WRAP_STATIC_PREDICATE
+#undef WRAP_MEMBER
+#undef WRAP_BOTH
+
 /// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
 /// an OpaqueInstr.
 class OpaqueInst : public SingleLLVMInstructionImpl<llvm::Instruction> {
@@ -3445,6 +3602,8 @@ class Context {
   LLVMContext &LLVMCtx;
   friend class Type;        // For LLVMCtx.
   friend class PointerType; // For LLVMCtx.
+  friend class CmpInst; // For LLVMCtx. TODO: cleanup when sandboxir::VectorType
+                        // is complete
   friend class IntegerType; // For LLVMCtx.
   friend class StructType;  // For LLVMCtx.
   Tracker IRTracker;
@@ -3572,6 +3731,12 @@ class Context {
   friend PHINode; // For createPHINode()
   UnreachableInst *createUnreachableInst(llvm::UnreachableInst *UI);
   friend UnreachableInst; // For createUnreachableInst()
+  CmpInst *createCmpInst(llvm::CmpInst *I);
+  friend CmpInst; // For createCmpInst()
+  ICmpInst *createICmpInst(llvm::ICmpInst *I);
+  friend ICmpInst; // For createICmpInst()
+  FCmpInst *createFCmpInst(llvm::FCmpInst *I);
+  friend FCmpInst; // For createFCmpInst()
 
 public:
   Context(LLVMContext &LLVMCtx)

diff  --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index d2031bbdcfb543..f320f61934efa6 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -113,6 +113,8 @@ DEF_INSTR(Cast,   OPCODES(\
                           ),                  CastInst)
 DEF_INSTR(PHI,            OP(PHI),            PHINode)
 DEF_INSTR(Unreachable,    OP(Unreachable),    UnreachableInst)
+DEF_INSTR(ICmp,           OP(ICmp),          ICmpInst)
+DEF_INSTR(FCmp,           OP(FCmp),          FCmpInst)
 
 // clang-format on
 #ifdef DEF_VALUE

diff  --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h
index c8a9e99a34341d..5fc43db82bd707 100644
--- a/llvm/include/llvm/SandboxIR/Tracker.h
+++ b/llvm/include/llvm/SandboxIR/Tracker.h
@@ -63,7 +63,7 @@ class CatchSwitchInst;
 class SwitchInst;
 class ConstantInt;
 class ShuffleVectorInst;
-
+class CmpInst;
 /// The base class for IR Change classes.
 class IRChangeBase {
 protected:
@@ -130,6 +130,19 @@ class PHIAddIncoming : public IRChangeBase {
 #endif
 };
 
+class CmpSwapOperands : public IRChangeBase {
+  CmpInst *Cmp;
+
+public:
+  CmpSwapOperands(CmpInst *Cmp);
+  void revert(Tracker &Tracker) final;
+  void accept() final {}
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const final { OS << "CmpSwapOperands"; }
+  LLVM_DUMP_METHOD void dump() const final;
+#endif
+};
+
 /// Tracks swapping a Use with another Use.
 class UseSwap : public IRChangeBase {
   Use ThisUse;

diff  --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 2f9b94b8d71751..61721ca836321c 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -50,6 +50,8 @@ class Type {
   friend class ConstantArray;  // For LLVMTy.
   friend class ConstantStruct; // For LLVMTy.
   friend class ConstantVector; // For LLVMTy.
+  friend class CmpInst; // For LLVMTy. TODO: Cleanup after sandboxir::VectorType
+                        // is more complete.
 
   // Friend all instruction classes because `create()` functions use LLVMTy.
 #define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;

diff  --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index e8d081e6b17e74..c0e5837209213f 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2735,6 +2735,16 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
     It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
     return It->second.get();
   }
+  case llvm::Instruction::ICmp: {
+    auto *LLVMICmp = cast<llvm::ICmpInst>(LLVMV);
+    It->second = std::unique_ptr<ICmpInst>(new ICmpInst(LLVMICmp, *this));
+    return It->second.get();
+  }
+  case llvm::Instruction::FCmp: {
+    auto *LLVMFCmp = cast<llvm::FCmpInst>(LLVMV);
+    It->second = std::unique_ptr<FCmpInst>(new FCmpInst(LLVMFCmp, *this));
+    return It->second.get();
+  }
   case llvm::Instruction::Unreachable: {
     auto *LLVMUnreachable = cast<llvm::UnreachableInst>(LLVMV);
     It->second = std::unique_ptr<UnreachableInst>(
@@ -2922,6 +2932,79 @@ PHINode *Context::createPHINode(llvm::PHINode *I) {
   auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
   return cast<PHINode>(registerValue(std::move(NewPtr)));
 }
+ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) {
+  auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this));
+  return cast<ICmpInst>(registerValue(std::move(NewPtr)));
+}
+FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
+  auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this));
+  return cast<FCmpInst>(registerValue(std::move(NewPtr)));
+}
+CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2,
+                         Instruction *InsertBefore, Context &Ctx,
+                         const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
+  auto *LLVMI = Builder.CreateCmp(P, S1->Val, S2->Val, Name);
+  if (auto *LLVMICmp = dyn_cast<llvm::ICmpInst>(LLVMI))
+    return Ctx.createICmpInst(LLVMICmp);
+  return Ctx.createFCmpInst(cast<llvm::FCmpInst>(LLVMI));
+}
+CmpInst *CmpInst::createWithCopiedFlags(Predicate P, Value *S1, Value *S2,
+                                        const Instruction *F,
+                                        Instruction *InsertBefore, Context &Ctx,
+                                        const Twine &Name) {
+  CmpInst *Inst = create(P, S1, S2, InsertBefore, Ctx, Name);
+  cast<llvm::CmpInst>(Inst->Val)->copyIRFlags(F->Val);
+  return Inst;
+}
+
+Type *CmpInst::makeCmpResultType(Type *OpndType) {
+  if (auto *VT = dyn_cast<VectorType>(OpndType)) {
+    // TODO: Cleanup when we have more complete support for
+    // sandboxir::VectorType
+    return OpndType->getContext().getType(llvm::VectorType::get(
+        llvm::Type::getInt1Ty(OpndType->getContext().LLVMCtx),
+        cast<llvm::VectorType>(VT->LLVMTy)->getElementCount()));
+  }
+  return Type::getInt1Ty(OpndType->getContext());
+}
+
+void CmpInst::setPredicate(Predicate P) {
+  Ctx.getTracker()
+      .emplaceIfTracking<
+          GenericSetter<&CmpInst::getPredicate, &CmpInst::setPredicate>>(this);
+  cast<llvm::CmpInst>(Val)->setPredicate(P);
+}
+
+void CmpInst::swapOperands() {
+  if (ICmpInst *IC = dyn_cast<ICmpInst>(this))
+    IC->swapOperands();
+  else
+    cast<FCmpInst>(this)->swapOperands();
+}
+
+void ICmpInst::swapOperands() {
+  Ctx.getTracker().emplaceIfTracking<CmpSwapOperands>(this);
+  cast<llvm::ICmpInst>(Val)->swapOperands();
+}
+
+void FCmpInst::swapOperands() {
+  Ctx.getTracker().emplaceIfTracking<CmpSwapOperands>(this);
+  cast<llvm::FCmpInst>(Val)->swapOperands();
+}
+
+#ifndef NDEBUG
+void CmpInst::dumpOS(raw_ostream &OS) const {
+  dumpCommonPrefix(OS);
+  dumpCommonSuffix(OS);
+}
+
+void CmpInst::dump() const {
+  dumpOS(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
 
 Value *Context::getValue(llvm::Value *V) const {
   auto It = LLVMValueToValueMap.find(V);

diff  --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp
index 953d4bd51353a9..c6eb9fc68a4b11 100644
--- a/llvm/lib/SandboxIR/Tracker.cpp
+++ b/llvm/lib/SandboxIR/Tracker.cpp
@@ -248,6 +248,19 @@ void ShuffleVectorSetMask::dump() const {
 }
 #endif
 
+CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {}
+
+// Swapping the operands a second time undoes the tracked swap.
+void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); }
+
+#ifndef NDEBUG
+// Only declared in builds without NDEBUG, so the definition is guarded too.
+void CmpSwapOperands::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
 void Tracker::save() { State = TrackerState::Record; }
 
 void Tracker::revert() {

diff  --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index ca2a183e532683..d1c5690ccad5b4 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -4800,6 +4800,184 @@ define void @foo(i32 %arg) {
   EXPECT_EQ(NewPHI->getNumIncomingValues(), PHI->getNumIncomingValues());
 }
 
+static void checkSwapOperands(sandboxir::Context &Ctx,
+                              llvm::sandboxir::CmpInst *Cmp,
+                              llvm::CmpInst *LLVMCmp) {
+  auto OrigOp0 = Cmp->getOperand(0);
+  auto OrigOp1 = Cmp->getOperand(1);
+  EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp0);
+  EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp1);
+  // This checks the dispatch mechanism in CmpInst, as well as
+  // the specific implementations.
+  Cmp->swapOperands();
+  EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(1)), OrigOp0);
+  EXPECT_EQ(Ctx.getValue(LLVMCmp->getOperand(0)), OrigOp1);
+  EXPECT_EQ(Cmp->getOperand(0), OrigOp1);
+  EXPECT_EQ(Cmp->getOperand(1), OrigOp0);
+  // Undo it to keep the rest of the test consistent
+  Cmp->swapOperands();
+}
+
+static void checkCommonPredicates(sandboxir::CmpInst *Cmp,
+                                  llvm::CmpInst *LLVMCmp) {
+  // Check proper creation
+  auto Pred = Cmp->getPredicate();
+  auto LLVMPred = LLVMCmp->getPredicate();
+  EXPECT_EQ(Pred, LLVMPred);
+  // Check setPredicate
+  Cmp->setPredicate(llvm::CmpInst::FCMP_FALSE);
+  EXPECT_EQ(Cmp->getPredicate(), llvm::CmpInst::FCMP_FALSE);
+  EXPECT_EQ(LLVMCmp->getPredicate(), llvm::CmpInst::FCMP_FALSE);
+  Cmp->setPredicate(Pred);
+  EXPECT_EQ(LLVMCmp->getPredicate(), Pred);
+  // Ensure the accessors properly forward to the underlying implementation
+  EXPECT_STREQ(sandboxir::CmpInst::getPredicateName(Pred).data(),
+               llvm::CmpInst::getPredicateName(LLVMPred).data());
+  EXPECT_EQ(Cmp->isFPPredicate(), LLVMCmp->isFPPredicate());
+  EXPECT_EQ(Cmp->isIntPredicate(), LLVMCmp->isIntPredicate());
+  EXPECT_EQ(Cmp->getInversePredicate(), LLVMCmp->getInversePredicate());
+  EXPECT_EQ(Cmp->getOrderedPredicate(), LLVMCmp->getOrderedPredicate());
+  EXPECT_EQ(Cmp->getUnorderedPredicate(), LLVMCmp->getUnorderedPredicate());
+  EXPECT_EQ(Cmp->getSwappedPredicate(), LLVMCmp->getSwappedPredicate());
+  EXPECT_EQ(Cmp->isStrictPredicate(), LLVMCmp->isStrictPredicate());
+  EXPECT_EQ(Cmp->isNonStrictPredicate(), LLVMCmp->isNonStrictPredicate());
+  EXPECT_EQ(Cmp->isRelational(), LLVMCmp->isRelational());
+  // fcmp false/true/ord/uno are relational but have no flipped strictness.
+  if (Cmp->isStrictPredicate() || Cmp->isNonStrictPredicate())
+    EXPECT_EQ(Cmp->getFlippedStrictnessPredicate(),
+              LLVMCmp->getFlippedStrictnessPredicate());
+  EXPECT_EQ(Cmp->isCommutative(), LLVMCmp->isCommutative());
+  EXPECT_EQ(Cmp->isTrueWhenEqual(), LLVMCmp->isTrueWhenEqual());
+  EXPECT_EQ(Cmp->isFalseWhenEqual(), LLVMCmp->isFalseWhenEqual());
+  EXPECT_EQ(sandboxir::CmpInst::isOrdered(Pred),
+            llvm::CmpInst::isOrdered(LLVMPred));
+  EXPECT_EQ(sandboxir::CmpInst::isUnordered(Pred),
+            llvm::CmpInst::isUnordered(LLVMPred));
+}
+
+TEST_F(SandboxIRTest, ICmpInst) {
+  SCOPED_TRACE("SandboxIRTest sandboxir::ICmpInst tests");
+  parseIR(C, R"IR(
+define void @foo(i32 %i0, i32 %i1) {
+ bb:
+  %ine  = icmp ne i32 %i0, %i1
+  %iugt = icmp ugt i32 %i0, %i1
+  %iuge = icmp uge i32 %i0, %i1
+  %iult = icmp ult i32 %i0, %i1
+  %iule = icmp ule i32 %i0, %i1
+  %isgt = icmp sgt i32 %i0, %i1
+  %isle = icmp sle i32 %i0, %i1
+  %ieg  = icmp eq i32 %i0, %i1
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  [[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
+
+  auto *LLVMBB = getBasicBlockByName(LLVMF, "bb");
+  auto LLVMIt = LLVMBB->begin();
+  auto *BB = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB));
+  auto It = BB->begin();
+  // Check classof()
+  while (auto *ICmp = dyn_cast<sandboxir::ICmpInst>(&*It++)) {
+    auto *LLVMICmp = cast<llvm::ICmpInst>(&*LLVMIt++);
+    checkSwapOperands(Ctx, ICmp, LLVMICmp);
+    checkCommonPredicates(ICmp, LLVMICmp);
+    EXPECT_EQ(ICmp->isSigned(), LLVMICmp->isSigned());
+    EXPECT_EQ(ICmp->isUnsigned(), LLVMICmp->isUnsigned());
+    EXPECT_EQ(ICmp->getSignedPredicate(), LLVMICmp->getSignedPredicate());
+    EXPECT_EQ(ICmp->getUnsignedPredicate(), LLVMICmp->getUnsignedPredicate());
+  }
+  auto *NewCmp =
+      sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, F.getArg(0),
+                                 F.getArg(1), &*BB->begin(), Ctx, "NewCmp");
+  EXPECT_EQ(NewCmp, &*BB->begin());
+  EXPECT_EQ(NewCmp->getPredicate(), llvm::CmpInst::ICMP_ULE);
+  EXPECT_EQ(NewCmp->getOperand(0), F.getArg(0));
+  EXPECT_EQ(NewCmp->getOperand(1), F.getArg(1));
+#ifndef NDEBUG
+  EXPECT_EQ(NewCmp->getName(), "NewCmp");
+#endif // NDEBUG
+  // TODO: Improve this test when sandboxir::VectorType is more completely
+  // implemented.
+  sandboxir::Type *RT =
+      sandboxir::CmpInst::makeCmpResultType(F.getArg(0)->getType());
+  EXPECT_TRUE(RT->isIntegerTy(1)); // Only one bit in a single comparison
+}
+
+TEST_F(SandboxIRTest, FCmpInst) {
+  SCOPED_TRACE("SandboxIRTest sandboxir::FCmpInst tests");
+  parseIR(C, R"IR(
+define void @foo(float %f0, float %f1) {
+bb:
+  %ffalse = fcmp false float %f0, %f1
+  %foeq = fcmp oeq float %f0, %f1
+  %fogt = fcmp ogt float %f0, %f1
+  %folt = fcmp olt float %f0, %f1
+  %fole = fcmp ole float %f0, %f1
+  %fone = fcmp one float %f0, %f1
+  %ford = fcmp ord float %f0, %f1
+  %funo = fcmp uno float %f0, %f1
+  %fueq = fcmp ueq float %f0, %f1
+  %fugt = fcmp ugt float %f0, %f1
+  %fuge = fcmp uge float %f0, %f1
+  %fult = fcmp ult float %f0, %f1
+  %fule = fcmp ule float %f0, %f1
+  %fune = fcmp une float %f0, %f1
+  %ftrue = fcmp true float %f0, %f1
+  ret void
+bb1:
+  %copyfrom = fadd reassoc float %f0, 42.0
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  [[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
+
+  auto *LLVMBB = getBasicBlockByName(LLVMF, "bb");
+  auto LLVMIt = LLVMBB->begin();
+  auto *BB = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB));
+  auto It = BB->begin();
+  // Check classof()
+  while (auto *FCmp = dyn_cast<sandboxir::FCmpInst>(&*It++)) {
+    auto *LLVMFCmp = cast<llvm::FCmpInst>(&*LLVMIt++);
+    checkSwapOperands(Ctx, FCmp, LLVMFCmp);
+    checkCommonPredicates(FCmp, LLVMFCmp);
+  }
+
+  auto *LLVMBB1 = getBasicBlockByName(LLVMF, "bb1");
+  auto *BB1 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB1));
+  auto It1 = BB1->begin();
+  auto *CopyFrom = &*It1++;
+  CopyFrom->setFastMathFlags(FastMathFlags::getFast());
+
+  // create with default flags
+  auto *NewFCmp = sandboxir::CmpInst::create(
+      llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), &*It1, Ctx, "NewFCmp");
+  EXPECT_EQ(NewFCmp->getPredicate(), llvm::CmpInst::FCMP_ONE);
+  EXPECT_EQ(NewFCmp->getOperand(0), F.getArg(0));
+  EXPECT_EQ(NewFCmp->getOperand(1), F.getArg(1));
+#ifndef NDEBUG
+  EXPECT_EQ(NewFCmp->getName(), "NewFCmp");
+#endif // NDEBUG
+  FastMathFlags DefaultFMF = NewFCmp->getFastMathFlags();
+  EXPECT_TRUE(CopyFrom->getFastMathFlags() != DefaultFMF);
+  // create with copied flags
+  auto *NewFCmpFlags = sandboxir::CmpInst::createWithCopiedFlags(
+      llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), CopyFrom, &*It1, Ctx,
+      "NewFCmpFlags");
+  EXPECT_FALSE(NewFCmpFlags->getFastMathFlags() !=
+               CopyFrom->getFastMathFlags());
+  EXPECT_EQ(NewFCmpFlags->getPredicate(), llvm::CmpInst::FCMP_ONE);
+  EXPECT_EQ(NewFCmpFlags->getOperand(0), F.getArg(0));
+  EXPECT_EQ(NewFCmpFlags->getOperand(1), F.getArg(1));
+#ifndef NDEBUG
+  EXPECT_EQ(NewFCmpFlags->getName(), "NewFCmpFlags");
+#endif // NDEBUG
+}
+
 TEST_F(SandboxIRTest, UnreachableInst) {
   parseIR(C, R"IR(
 define void @foo() {

diff  --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index fe29452a8aea20..a1f39fe958e351 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -1452,6 +1452,49 @@ define void @foo(i8 %arg0, i8 %arg1, i8 %arg2) {
   EXPECT_EQ(PHI->getIncomingValue(1), Arg1);
 }
 
+void checkCmpInst(sandboxir::Context &Ctx, sandboxir::CmpInst *Cmp) {
+  Ctx.save();
+  auto OrigP = Cmp->getPredicate();
+  auto NewP = Cmp->getSwappedPredicate();
+  Cmp->setPredicate(NewP);
+  EXPECT_EQ(Cmp->getPredicate(), NewP);
+  Ctx.revert();
+  EXPECT_EQ(Cmp->getPredicate(), OrigP);
+
+  Ctx.save();
+  auto OrigOp0 = Cmp->getOperand(0);
+  auto OrigOp1 = Cmp->getOperand(1);
+  Cmp->swapOperands();
+  EXPECT_EQ(Cmp->getPredicate(), NewP);
+  EXPECT_EQ(Cmp->getOperand(0), OrigOp1);
+  EXPECT_EQ(Cmp->getOperand(1), OrigOp0);
+  Ctx.revert();
+  EXPECT_EQ(Cmp->getPredicate(), OrigP);
+  EXPECT_EQ(Cmp->getOperand(0), OrigOp0);
+  EXPECT_EQ(Cmp->getOperand(1), OrigOp1);
+}
+
+TEST_F(TrackerTest, CmpInst) {
+  SCOPED_TRACE("TrackerTest sandboxir::CmpInst tests");
+  parseIR(C, R"IR(
+define void @foo(i64 %i0, i64 %i1, float %f0, float %f1) {
+  %foeq = fcmp ogt float %f0, %f1
+  %ioeq = icmp uge i64 %i0, %i1
+
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *FCmp = cast<sandboxir::CmpInst>(&*It++);
+  checkCmpInst(Ctx, FCmp);
+  auto *ICmp = cast<sandboxir::CmpInst>(&*It++);
+  checkCmpInst(Ctx, ICmp);
+}
+
 TEST_F(TrackerTest, SetVolatile) {
   parseIR(C, R"IR(
 define void @foo(ptr %arg0, i8 %val) {


        


More information about the llvm-commits mailing list