[llvm] [SandboxIR] Implement AtomicCmpXchgInst (PR #102710)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 13 09:56:25 PDT 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/102710

>From 37888675818e5f6cb96887a6cd0ccf08b21796fc Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 1 Aug 2024 15:45:23 -0700
Subject: [PATCH] [SandboxIR] Implement AtomicCmpXchgInst

This patch implements sandboxir::AtomicCmpXchgInst which mirrors llvm::AtomicCmpXchgInst.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h       |  90 ++++++++
 .../llvm/SandboxIR/SandboxIRValues.def        |   1 +
 llvm/lib/SandboxIR/SandboxIR.cpp              | 110 ++++++++++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 192 ++++++++++++++++++
 llvm/unittests/SandboxIR/TrackerTest.cpp      |  75 +++++++
 5 files changed, 468 insertions(+)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index c160520788d873..a6adb448ff0b19 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -130,6 +130,7 @@ class CastInst;
 class PtrToIntInst;
 class BitCastInst;
 class AllocaInst;
+class AtomicCmpXchgInst;
 
 /// Iterator for the `Use` edges of a User's operands.
 /// \Returns the operand `Use` when dereferenced.
@@ -248,6 +249,7 @@ class Value {
   friend class InvokeInst;         // For getting `Val`.
   friend class CallBrInst;         // For getting `Val`.
   friend class GetElementPtrInst;  // For getting `Val`.
+  friend class AtomicCmpXchgInst;  // For getting `Val`.
   friend class AllocaInst;         // For getting `Val`.
   friend class CastInst;           // For getting `Val`.
   friend class PHINode;            // For getting `Val`.
@@ -628,6 +630,7 @@ class Instruction : public sandboxir::User {
   friend class InvokeInst;         // For getTopmostLLVMInstruction().
   friend class CallBrInst;         // For getTopmostLLVMInstruction().
   friend class GetElementPtrInst;  // For getTopmostLLVMInstruction().
+  friend class AtomicCmpXchgInst;  // For getTopmostLLVMInstruction().
   friend class AllocaInst;         // For getTopmostLLVMInstruction().
   friend class CastInst;           // For getTopmostLLVMInstruction().
   friend class PHINode;            // For getTopmostLLVMInstruction().
@@ -1337,6 +1340,91 @@ class GetElementPtrInst final
   // TODO: Add missing member functions.
 };
 
+class AtomicCmpXchgInst final
+    : public SingleLLVMInstructionImpl<llvm::AtomicCmpXchgInst> {
+  AtomicCmpXchgInst(llvm::AtomicCmpXchgInst *Atomic, Context &Ctx)
+      : SingleLLVMInstructionImpl(ClassID::AtomicCmpXchg,
+                                  Instruction::Opcode::AtomicCmpXchg, Atomic,
+                                  Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  /// Return the alignment of the memory accessed by this cmpxchg.
+  Align getAlign() const {
+    return cast<llvm::AtomicCmpXchgInst>(Val)->getAlign();
+  }
+  void setAlignment(Align Align);
+  /// Return true if this is a cmpxchg from a volatile memory location.
+  bool isVolatile() const {
+    return cast<llvm::AtomicCmpXchgInst>(Val)->isVolatile();
+  }
+  /// Specify whether this is a volatile cmpxchg.
+  void setVolatile(bool V);
+  /// Return true if this cmpxchg may spuriously fail.
+  bool isWeak() const { return cast<llvm::AtomicCmpXchgInst>(Val)->isWeak(); }
+  void setWeak(bool IsWeak);
+  static bool isValidSuccessOrdering(AtomicOrdering Ordering) {
+    return llvm::AtomicCmpXchgInst::isValidSuccessOrdering(Ordering);
+  }
+  static bool isValidFailureOrdering(AtomicOrdering Ordering) {
+    return llvm::AtomicCmpXchgInst::isValidFailureOrdering(Ordering);
+  }
+  AtomicOrdering getSuccessOrdering() const {
+    return cast<llvm::AtomicCmpXchgInst>(Val)->getSuccessOrdering();
+  }
+  void setSuccessOrdering(AtomicOrdering Ordering);
+  AtomicOrdering getFailureOrdering() const {
+    return cast<llvm::AtomicCmpXchgInst>(Val)->getFailureOrdering();
+  }
+  void setFailureOrdering(AtomicOrdering Ordering);
+  AtomicOrdering getMergedOrdering() const {
+    return cast<llvm::AtomicCmpXchgInst>(Val)->getMergedOrdering();
+  }
+  SyncScope::ID getSyncScopeID() const {
+    return cast<llvm::AtomicCmpXchgInst>(Val)->getSyncScopeID();
+  }
+  void setSyncScopeID(SyncScope::ID SSID);
+  Value *getPointerOperand();
+  const Value *getPointerOperand() const {
+    return const_cast<AtomicCmpXchgInst *>(this)->getPointerOperand();
+  }
+
+  Value *getCompareOperand();
+  const Value *getCompareOperand() const {
+    return const_cast<AtomicCmpXchgInst *>(this)->getCompareOperand();
+  }
+
+  Value *getNewValOperand();
+  const Value *getNewValOperand() const {
+    return const_cast<AtomicCmpXchgInst *>(this)->getNewValOperand();
+  }
+
+  /// Returns the address space of the pointer operand.
+  unsigned getPointerAddressSpace() const {
+    return cast<llvm::AtomicCmpXchgInst>(Val)->getPointerAddressSpace();
+  }
+
+  static AtomicCmpXchgInst *
+  create(Value *Ptr, Value *Cmp, Value *New, MaybeAlign Align,
+         AtomicOrdering SuccessOrdering, AtomicOrdering FailureOrdering,
+         BBIterator WhereIt, BasicBlock *WhereBB, Context &Ctx,
+         SyncScope::ID SSID = SyncScope::System, const Twine &Name = "");
+  static AtomicCmpXchgInst *
+  create(Value *Ptr, Value *Cmp, Value *New, MaybeAlign Align,
+         AtomicOrdering SuccessOrdering, AtomicOrdering FailureOrdering,
+         Instruction *InsertBefore, Context &Ctx,
+         SyncScope::ID SSID = SyncScope::System, const Twine &Name = "");
+  static AtomicCmpXchgInst *
+  create(Value *Ptr, Value *Cmp, Value *New, MaybeAlign Align,
+         AtomicOrdering SuccessOrdering, AtomicOrdering FailureOrdering,
+         BasicBlock *InsertAtEnd, Context &Ctx,
+         SyncScope::ID SSID = SyncScope::System, const Twine &Name = "");
+
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::AtomicCmpXchg;
+  }
+};
+
 class AllocaInst final : public UnaryInstruction {
   AllocaInst(llvm::AllocaInst *AI, Context &Ctx)
       : UnaryInstruction(ClassID::Alloca, Instruction::Opcode::Alloca, AI,
@@ -1696,6 +1784,8 @@ class Context {
   friend CallBrInst; // For createCallBrInst()
   GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
   friend GetElementPtrInst; // For createGetElementPtrInst()
+  AtomicCmpXchgInst *createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I);
+  friend AtomicCmpXchgInst; // For createAtomicCmpXchgInst()
   AllocaInst *createAllocaInst(llvm::AllocaInst *I);
   friend AllocaInst; // For createAllocaInst()
   CastInst *createCastInst(llvm::CastInst *I);
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 11f4f2e74712f4..114dc8505ecd6f 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -45,6 +45,7 @@ DEF_INSTR(Call,           OP(Call),           CallInst)
 DEF_INSTR(Invoke,         OP(Invoke),         InvokeInst)
 DEF_INSTR(CallBr,         OP(CallBr),         CallBrInst)
 DEF_INSTR(GetElementPtr,  OP(GetElementPtr),  GetElementPtrInst)
+DEF_INSTR(AtomicCmpXchg,  OP(AtomicCmpXchg),  AtomicCmpXchgInst)
 DEF_INSTR(Alloca,         OP(Alloca),         AllocaInst)
 DEF_INSTR(Cast,   OPCODES(\
                           OP(ZExt)          \
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 445f56b14e83b5..80809b23e34b2c 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1122,6 +1122,104 @@ static llvm::Instruction::CastOps getLLVMCastOp(Instruction::Opcode Opc) {
   }
 }
 
+void AtomicCmpXchgInst::setAlignment(Align Align) {
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetter<&AtomicCmpXchgInst::getAlign,
+                                       &AtomicCmpXchgInst::setAlignment>>(this);
+  cast<llvm::AtomicCmpXchgInst>(Val)->setAlignment(Align);
+}
+
+void AtomicCmpXchgInst::setVolatile(bool V) {
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetter<&AtomicCmpXchgInst::isVolatile,
+                                       &AtomicCmpXchgInst::setVolatile>>(this);
+  cast<llvm::AtomicCmpXchgInst>(Val)->setVolatile(V);
+}
+
+void AtomicCmpXchgInst::setWeak(bool IsWeak) {
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetter<&AtomicCmpXchgInst::isWeak,
+                                       &AtomicCmpXchgInst::setWeak>>(this);
+  cast<llvm::AtomicCmpXchgInst>(Val)->setWeak(IsWeak);
+}
+
+void AtomicCmpXchgInst::setSuccessOrdering(AtomicOrdering Ordering) {
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetter<&AtomicCmpXchgInst::getSuccessOrdering,
+                                       &AtomicCmpXchgInst::setSuccessOrdering>>(
+          this);
+  cast<llvm::AtomicCmpXchgInst>(Val)->setSuccessOrdering(Ordering);
+}
+
+void AtomicCmpXchgInst::setFailureOrdering(AtomicOrdering Ordering) {
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetter<&AtomicCmpXchgInst::getFailureOrdering,
+                                       &AtomicCmpXchgInst::setFailureOrdering>>(
+          this);
+  cast<llvm::AtomicCmpXchgInst>(Val)->setFailureOrdering(Ordering);
+}
+
+void AtomicCmpXchgInst::setSyncScopeID(SyncScope::ID SSID) {
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetter<&AtomicCmpXchgInst::getSyncScopeID,
+                                       &AtomicCmpXchgInst::setSyncScopeID>>(
+          this);
+  cast<llvm::AtomicCmpXchgInst>(Val)->setSyncScopeID(SSID);
+}
+
+Value *AtomicCmpXchgInst::getPointerOperand() {
+  return Ctx.getValue(cast<llvm::AtomicCmpXchgInst>(Val)->getPointerOperand());
+}
+
+Value *AtomicCmpXchgInst::getCompareOperand() {
+  return Ctx.getValue(cast<llvm::AtomicCmpXchgInst>(Val)->getCompareOperand());
+}
+
+Value *AtomicCmpXchgInst::getNewValOperand() {
+  return Ctx.getValue(cast<llvm::AtomicCmpXchgInst>(Val)->getNewValOperand());
+}
+
+AtomicCmpXchgInst *
+AtomicCmpXchgInst::create(Value *Ptr, Value *Cmp, Value *New, MaybeAlign Align,
+                          AtomicOrdering SuccessOrdering,
+                          AtomicOrdering FailureOrdering, BBIterator WhereIt,
+                          BasicBlock *WhereBB, Context &Ctx, SyncScope::ID SSID,
+                          const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  if (WhereIt == WhereBB->end())
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  else
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+  auto *NewLLVMCmpXchg =
+      Builder.CreateAtomicCmpXchg(Ptr->Val, Cmp->Val, New->Val, Align,
+                                  SuccessOrdering, FailureOrdering, SSID);
+  NewLLVMCmpXchg->setName(Name);
+  return Ctx.createAtomicCmpXchgInst(NewLLVMCmpXchg);
+}
+
+AtomicCmpXchgInst *AtomicCmpXchgInst::create(Value *Ptr, Value *Cmp, Value *New,
+                                             MaybeAlign Align,
+                                             AtomicOrdering SuccessOrdering,
+                                             AtomicOrdering FailureOrdering,
+                                             Instruction *InsertBefore,
+                                             Context &Ctx, SyncScope::ID SSID,
+                                             const Twine &Name) {
+  return create(Ptr, Cmp, New, Align, SuccessOrdering, FailureOrdering,
+                InsertBefore->getIterator(), InsertBefore->getParent(), Ctx,
+                SSID, Name);
+}
+
+AtomicCmpXchgInst *AtomicCmpXchgInst::create(Value *Ptr, Value *Cmp, Value *New,
+                                             MaybeAlign Align,
+                                             AtomicOrdering SuccessOrdering,
+                                             AtomicOrdering FailureOrdering,
+                                             BasicBlock *InsertAtEnd,
+                                             Context &Ctx, SyncScope::ID SSID,
+                                             const Twine &Name) {
+  return create(Ptr, Cmp, New, Align, SuccessOrdering, FailureOrdering,
+                InsertAtEnd->end(), InsertAtEnd, Ctx, SSID, Name);
+}
+
 AllocaInst *AllocaInst::create(Type *Ty, unsigned AddrSpace, BBIterator WhereIt,
                                BasicBlock *WhereBB, Context &Ctx,
                                Value *ArraySize, const Twine &Name) {
@@ -1433,6 +1531,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
         new GetElementPtrInst(LLVMGEP, *this));
     return It->second.get();
   }
+  case llvm::Instruction::AtomicCmpXchg: {
+    auto *LLVMAtomicCmpXchg = cast<llvm::AtomicCmpXchgInst>(LLVMV);
+    It->second = std::unique_ptr<AtomicCmpXchgInst>(
+        new AtomicCmpXchgInst(LLVMAtomicCmpXchg, *this));
+    return It->second.get();
+  }
   case llvm::Instruction::Alloca: {
     auto *LLVMAlloca = cast<llvm::AllocaInst>(LLVMV);
     It->second = std::unique_ptr<AllocaInst>(new AllocaInst(LLVMAlloca, *this));
@@ -1550,6 +1654,12 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
       std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
   return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
 }
+AtomicCmpXchgInst *
+Context::createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I) {
+  std::unique_ptr<AtomicCmpXchgInst> Owned(new AtomicCmpXchgInst(I, *this));
+  Value *Registered = registerValue(std::move(Owned));
+  return cast<AtomicCmpXchgInst>(Registered);
+}
 AllocaInst *Context::createAllocaInst(llvm::AllocaInst *I) {
   auto NewPtr = std::unique_ptr<AllocaInst>(new AllocaInst(I, *this));
   return cast<AllocaInst>(registerValue(std::move(NewPtr)));
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 6e1a8f691141fc..caf306922847ed 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -1569,6 +1569,198 @@ define void @foo(ptr %ptr, <2 x ptr> %ptrs) {
   EXPECT_EQ(NewGEP2->getNextNode(), nullptr);
 }
 
+TEST_F(SandboxIRTest, AtomicCmpXchgInst) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %cmp, i8 %new) {
+  %cmpxchg = cmpxchg ptr %ptr, i8 %cmp, i8 %new monotonic monotonic, align 128
+  ret void
+}
+)IR");
+  llvm::Function &LLVMF = *M->getFunction("foo");
+  llvm::BasicBlock *LLVMBB = &*LLVMF.begin();
+  auto LLVMIt = LLVMBB->begin();
+  auto *LLVMCmpXchg = cast<llvm::AtomicCmpXchgInst>(&*LLVMIt++);
+
+  sandboxir::Context Ctx(C);
+  sandboxir::Function *F = Ctx.createFunction(&LLVMF);
+  auto *Ptr = F->getArg(0);
+  auto *Cmp = F->getArg(1);
+  auto *New = F->getArg(2);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *CmpXchg = cast<sandboxir::AtomicCmpXchgInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  // Check getAlign(), setAlignment().
+  EXPECT_EQ(CmpXchg->getAlign(), LLVMCmpXchg->getAlign());
+  auto OrigAlign = CmpXchg->getAlign();
+  Align NewAlign(256);
+  EXPECT_NE(NewAlign, OrigAlign);
+  CmpXchg->setAlignment(NewAlign);
+  EXPECT_EQ(CmpXchg->getAlign(), NewAlign);
+  CmpXchg->setAlignment(OrigAlign);
+  EXPECT_EQ(CmpXchg->getAlign(), OrigAlign);
+  // Check isVolatile(), setVolatile().
+  EXPECT_EQ(CmpXchg->isVolatile(), LLVMCmpXchg->isVolatile());
+  bool OrigV = CmpXchg->isVolatile();
+  bool NewV = true;
+  EXPECT_NE(NewV, OrigV);
+  CmpXchg->setVolatile(NewV);
+  EXPECT_EQ(CmpXchg->isVolatile(), NewV);
+  CmpXchg->setVolatile(OrigV);
+  EXPECT_EQ(CmpXchg->isVolatile(), OrigV);
+  // Check isWeak(), setWeak().
+  EXPECT_EQ(CmpXchg->isWeak(), LLVMCmpXchg->isWeak());
+  bool OrigWeak = CmpXchg->isWeak();
+  bool NewWeak = true;
+  EXPECT_NE(NewWeak, OrigWeak);
+  CmpXchg->setWeak(NewWeak);
+  EXPECT_EQ(CmpXchg->isWeak(), NewWeak);
+  CmpXchg->setWeak(OrigWeak);
+  EXPECT_EQ(CmpXchg->isWeak(), OrigWeak);
+  // Check isValidSuccessOrdering(), isValidFailureOrdering().
+  SmallVector<AtomicOrdering> AllOrderings(
+      {AtomicOrdering::NotAtomic, AtomicOrdering::Unordered,
+       AtomicOrdering::Monotonic, AtomicOrdering::Acquire,
+       AtomicOrdering::Release, AtomicOrdering::AcquireRelease,
+       AtomicOrdering::SequentiallyConsistent});
+  for (auto Ordering : AllOrderings) {
+    EXPECT_EQ(sandboxir::AtomicCmpXchgInst::isValidSuccessOrdering(Ordering),
+              llvm::AtomicCmpXchgInst::isValidSuccessOrdering(Ordering));
+    EXPECT_EQ(sandboxir::AtomicCmpXchgInst::isValidFailureOrdering(Ordering),
+              llvm::AtomicCmpXchgInst::isValidFailureOrdering(Ordering));
+  }
+  // Check getSuccessOrdering(), setSuccessOrdering().
+  EXPECT_EQ(CmpXchg->getSuccessOrdering(), LLVMCmpXchg->getSuccessOrdering());
+  auto OldSuccOrdering = CmpXchg->getSuccessOrdering();
+  auto NewSuccOrdering = AtomicOrdering::Acquire;
+  EXPECT_NE(NewSuccOrdering, OldSuccOrdering);
+  CmpXchg->setSuccessOrdering(NewSuccOrdering);
+  EXPECT_EQ(CmpXchg->getSuccessOrdering(), NewSuccOrdering);
+  CmpXchg->setSuccessOrdering(OldSuccOrdering);
+  EXPECT_EQ(CmpXchg->getSuccessOrdering(), OldSuccOrdering);
+  // Check getFailureOrdering(), setFailureOrdering().
+  EXPECT_EQ(CmpXchg->getFailureOrdering(), LLVMCmpXchg->getFailureOrdering());
+  auto OldFailOrdering = CmpXchg->getFailureOrdering();
+  auto NewFailOrdering = AtomicOrdering::Acquire;
+  EXPECT_NE(NewFailOrdering, OldFailOrdering);
+  CmpXchg->setFailureOrdering(NewFailOrdering);
+  EXPECT_EQ(CmpXchg->getFailureOrdering(), NewFailOrdering);
+  CmpXchg->setFailureOrdering(OldFailOrdering);
+  EXPECT_EQ(CmpXchg->getFailureOrdering(), OldFailOrdering);
+  // Check getMergedOrdering().
+  EXPECT_EQ(CmpXchg->getMergedOrdering(), LLVMCmpXchg->getMergedOrdering());
+  // Check getSyncScopeID(), setSyncScopeID().
+  EXPECT_EQ(CmpXchg->getSyncScopeID(), LLVMCmpXchg->getSyncScopeID());
+  auto OrigSSID = CmpXchg->getSyncScopeID();
+  SyncScope::ID NewSSID = SyncScope::SingleThread;
+  EXPECT_NE(NewSSID, OrigSSID);
+  CmpXchg->setSyncScopeID(NewSSID);
+  EXPECT_EQ(CmpXchg->getSyncScopeID(), NewSSID);
+  CmpXchg->setSyncScopeID(OrigSSID);
+  EXPECT_EQ(CmpXchg->getSyncScopeID(), OrigSSID);
+  // Check getPointerOperand().
+  EXPECT_EQ(CmpXchg->getPointerOperand(),
+            Ctx.getValue(LLVMCmpXchg->getPointerOperand()));
+  // Check getCompareOperand().
+  EXPECT_EQ(CmpXchg->getCompareOperand(),
+            Ctx.getValue(LLVMCmpXchg->getCompareOperand()));
+  // Check getNewValOperand().
+  EXPECT_EQ(CmpXchg->getNewValOperand(),
+            Ctx.getValue(LLVMCmpXchg->getNewValOperand()));
+  // Check getPointerAddressSpace().
+  EXPECT_EQ(CmpXchg->getPointerAddressSpace(),
+            LLVMCmpXchg->getPointerAddressSpace());
+
+  Align CreateAlign(1024);
+  auto SuccOrdering = AtomicOrdering::Acquire;
+  auto FailOrdering = AtomicOrdering::Monotonic;
+  SyncScope::ID SSID = SyncScope::SingleThread;
+  {
+    // Check create() WhereIt, WhereBB.
+    auto *NewI = sandboxir::AtomicCmpXchgInst::create(
+        Ptr, Cmp, New, CreateAlign, SuccOrdering, FailOrdering,
+        /*WhereIt=*/Ret->getIterator(),
+        /*WhereBB=*/Ret->getParent(), Ctx, SSID, "NewAtomicCmpXchg1");
+    // Check getOpcode().
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::AtomicCmpXchg);
+    // Check getAlign().
+    EXPECT_EQ(NewI->getAlign(), CreateAlign);
+    // Check getSuccessOrdering(), getFailureOrdering().
+    EXPECT_EQ(NewI->getSuccessOrdering(), SuccOrdering);
+    EXPECT_EQ(NewI->getFailureOrdering(), FailOrdering);
+    // Check getSyncScopeID().
+    EXPECT_EQ(NewI->getSyncScopeID(), SSID);
+    // Check instr position.
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+    // Check getPointerOperand().
+    EXPECT_EQ(NewI->getPointerOperand(), Ptr);
+    // Check getCompareOperand().
+    EXPECT_EQ(NewI->getCompareOperand(), Cmp);
+    // Check getNewValOperand().
+    EXPECT_EQ(NewI->getNewValOperand(), New);
+#ifndef NDEBUG
+    // Check getName().
+    EXPECT_EQ(NewI->getName(), "NewAtomicCmpXchg1");
+#endif // NDEBUG
+  }
+  {
+    // Check create() InsertBefore.
+    auto *NewI = sandboxir::AtomicCmpXchgInst::create(
+        Ptr, Cmp, New, CreateAlign, SuccOrdering, FailOrdering,
+        /*InsertBefore=*/Ret, Ctx, SSID, "NewAtomicCmpXchg2");
+    // Check getOpcode().
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::AtomicCmpXchg);
+    // Check getAlign().
+    EXPECT_EQ(NewI->getAlign(), CreateAlign);
+    // Check getSuccessOrdering(), getFailureOrdering().
+    EXPECT_EQ(NewI->getSuccessOrdering(), SuccOrdering);
+    EXPECT_EQ(NewI->getFailureOrdering(), FailOrdering);
+    // Check getSyncScopeID().
+    EXPECT_EQ(NewI->getSyncScopeID(), SSID);
+    // Check instr position.
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+    // Check getPointerOperand().
+    EXPECT_EQ(NewI->getPointerOperand(), Ptr);
+    // Check getCompareOperand().
+    EXPECT_EQ(NewI->getCompareOperand(), Cmp);
+    // Check getNewValOperand().
+    EXPECT_EQ(NewI->getNewValOperand(), New);
+#ifndef NDEBUG
+    // Check getName().
+    EXPECT_EQ(NewI->getName(), "NewAtomicCmpXchg2");
+#endif // NDEBUG
+  }
+  {
+    // Check create() InsertAtEnd.
+    auto *NewI = sandboxir::AtomicCmpXchgInst::create(
+        Ptr, Cmp, New, CreateAlign, SuccOrdering, FailOrdering,
+        /*InsertAtEnd=*/BB, Ctx, SSID, "NewAtomicCmpXchg3");
+    // Check getOpcode().
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::AtomicCmpXchg);
+    // Check getAlign().
+    EXPECT_EQ(NewI->getAlign(), CreateAlign);
+    // Check getSuccessOrdering(), getFailureOrdering().
+    EXPECT_EQ(NewI->getSuccessOrdering(), SuccOrdering);
+    EXPECT_EQ(NewI->getFailureOrdering(), FailOrdering);
+    // Check getSyncScopeID().
+    EXPECT_EQ(NewI->getSyncScopeID(), SSID);
+    // Check instr position.
+    EXPECT_EQ(NewI->getParent(), BB);
+    EXPECT_EQ(NewI->getNextNode(), nullptr);
+    // Check getPointerOperand().
+    EXPECT_EQ(NewI->getPointerOperand(), Ptr);
+    // Check getCompareOperand().
+    EXPECT_EQ(NewI->getCompareOperand(), Cmp);
+    // Check getNewValOperand().
+    EXPECT_EQ(NewI->getNewValOperand(), New);
+#ifndef NDEBUG
+    // Check getName().
+    EXPECT_EQ(NewI->getName(), "NewAtomicCmpXchg3");
+#endif // NDEBUG
+  }
+}
+
 TEST_F(SandboxIRTest, AllocaInst) {
   parseIR(C, R"IR(
 define void @foo() {
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index a8cf41a177d1a9..ba1d16de7d2120 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -644,6 +644,81 @@ define void @foo(i8 %arg) {
   EXPECT_EQ(Invoke->getSuccessor(1), ExceptionBB);
 }
 
+TEST_F(TrackerTest, AtomicCmpXchgSetters) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %cmp, i8 %new) {
+  %cmpxchg = cmpxchg ptr %ptr, i8 %cmp, i8 %new monotonic monotonic, align 128
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  sandboxir::Function *F = Ctx.createFunction(&LLVMF);
+  sandboxir::BasicBlock *BB = &*F->begin();
+  auto FirstIt = BB->begin();
+  auto *CmpXchg = cast<sandboxir::AtomicCmpXchgInst>(&*FirstIt);
+
+  // Check setAlignment().
+  Ctx.save();
+  Align AlignBefore = CmpXchg->getAlign();
+  Align AlignAfter(1024);
+  EXPECT_NE(AlignAfter, AlignBefore);
+  CmpXchg->setAlignment(AlignAfter);
+  EXPECT_EQ(CmpXchg->getAlign(), AlignAfter);
+  Ctx.revert();
+  EXPECT_EQ(CmpXchg->getAlign(), AlignBefore);
+
+  // Check setVolatile().
+  Ctx.save();
+  bool VolatileBefore = CmpXchg->isVolatile();
+  bool VolatileAfter = true;
+  EXPECT_NE(VolatileAfter, VolatileBefore);
+  CmpXchg->setVolatile(VolatileAfter);
+  EXPECT_EQ(CmpXchg->isVolatile(), VolatileAfter);
+  Ctx.revert();
+  EXPECT_EQ(CmpXchg->isVolatile(), VolatileBefore);
+
+  // Check setWeak().
+  Ctx.save();
+  bool WeakBefore = CmpXchg->isWeak();
+  bool WeakAfter = true;
+  EXPECT_NE(WeakAfter, WeakBefore);
+  CmpXchg->setWeak(WeakAfter);
+  EXPECT_EQ(CmpXchg->isWeak(), WeakAfter);
+  Ctx.revert();
+  EXPECT_EQ(CmpXchg->isWeak(), WeakBefore);
+
+  // Check setSuccessOrdering().
+  Ctx.save();
+  AtomicOrdering SuccessBefore = CmpXchg->getSuccessOrdering();
+  AtomicOrdering SuccessAfter = AtomicOrdering::SequentiallyConsistent;
+  EXPECT_NE(SuccessAfter, SuccessBefore);
+  CmpXchg->setSuccessOrdering(SuccessAfter);
+  EXPECT_EQ(CmpXchg->getSuccessOrdering(), SuccessAfter);
+  Ctx.revert();
+  EXPECT_EQ(CmpXchg->getSuccessOrdering(), SuccessBefore);
+
+  // Check setFailureOrdering().
+  Ctx.save();
+  AtomicOrdering FailureBefore = CmpXchg->getFailureOrdering();
+  AtomicOrdering FailureAfter = AtomicOrdering::SequentiallyConsistent;
+  EXPECT_NE(FailureAfter, FailureBefore);
+  CmpXchg->setFailureOrdering(FailureAfter);
+  EXPECT_EQ(CmpXchg->getFailureOrdering(), FailureAfter);
+  Ctx.revert();
+  EXPECT_EQ(CmpXchg->getFailureOrdering(), FailureBefore);
+
+  // Check setSyncScopeID().
+  Ctx.save();
+  SyncScope::ID SSIDBefore = CmpXchg->getSyncScopeID();
+  SyncScope::ID SSIDAfter = SyncScope::SingleThread;
+  EXPECT_NE(SSIDAfter, SSIDBefore);
+  CmpXchg->setSyncScopeID(SSIDAfter);
+  EXPECT_EQ(CmpXchg->getSyncScopeID(), SSIDAfter);
+  Ctx.revert();
+  EXPECT_EQ(CmpXchg->getSyncScopeID(), SSIDBefore);
+}
+
 TEST_F(TrackerTest, AllocaInstSetters) {
   parseIR(C, R"IR(
 define void @foo(i8 %arg) {



More information about the llvm-commits mailing list