[llvm] [SandboxIR] Implement AtomicRMWInst (PR #104529)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 15 17:57:07 PDT 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/104529

This patch implements sandboxir::AtomicRMWInst mirroring llvm::AtomicRMWInst.

>From e060c46264c381d650f0666b8a75cf68290378d1 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 12 Aug 2024 12:18:44 -0700
Subject: [PATCH] [SandboxIR] Implement AtomicRMWInst

This patch implements sandboxir::AtomicRMWInst mirroring llvm::AtomicRMWInst.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h       |  76 ++++++++
 .../llvm/SandboxIR/SandboxIRValues.def        |   1 +
 llvm/lib/SandboxIR/SandboxIR.cpp              |  78 +++++++++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 163 ++++++++++++++++++
 llvm/unittests/SandboxIR/TrackerTest.cpp      |  55 ++++++
 5 files changed, 373 insertions(+)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 423dad854a91cb..75c9d5017054b4 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -132,6 +132,7 @@ class BitCastInst;
 class AllocaInst;
 class UnaryOperator;
 class BinaryOperator;
+class AtomicRMWInst;
 class AtomicCmpXchgInst;
 
 /// Iterator for the `Use` edges of a User's operands.
@@ -253,6 +254,7 @@ class Value {
   friend class GetElementPtrInst;  // For getting `Val`.
   friend class UnaryOperator;      // For getting `Val`.
   friend class BinaryOperator;     // For getting `Val`.
+  friend class AtomicRMWInst;      // For getting `Val`.
   friend class AtomicCmpXchgInst;  // For getting `Val`.
   friend class AllocaInst;         // For getting `Val`.
   friend class CastInst;           // For getting `Val`.
@@ -636,6 +638,7 @@ class Instruction : public sandboxir::User {
   friend class GetElementPtrInst;  // For getTopmostLLVMInstruction().
   friend class UnaryOperator;      // For getTopmostLLVMInstruction().
   friend class BinaryOperator;     // For getTopmostLLVMInstruction().
+  friend class AtomicRMWInst;      // For getTopmostLLVMInstruction().
   friend class AtomicCmpXchgInst;  // For getTopmostLLVMInstruction().
   friend class AllocaInst;         // For getTopmostLLVMInstruction().
   friend class CastInst;           // For getTopmostLLVMInstruction().
@@ -1559,6 +1562,84 @@ class BinaryOperator : public SingleLLVMInstructionImpl<llvm::BinaryOperator> {
   void swapOperands() { swapOperandsInternal(0, 1); }
 };
 
+/// Mirrors llvm::AtomicRMWInst. Every setter is recorded by the Tracker, so
+/// changes made after Context::save() can be undone with Context::revert().
+class AtomicRMWInst : public SingleLLVMInstructionImpl<llvm::AtomicRMWInst> {
+  AtomicRMWInst(llvm::AtomicRMWInst *Atomic, Context &Ctx)
+      : SingleLLVMInstructionImpl(ClassID::AtomicRMW,
+                                  Instruction::Opcode::AtomicRMW, Atomic, Ctx) {
+  }
+  friend class Context; // For constructor.
+
+public:
+  using BinOp = llvm::AtomicRMWInst::BinOp;
+  BinOp getOperation() const {
+    return cast<llvm::AtomicRMWInst>(Val)->getOperation();
+  }
+  static StringRef getOperationName(BinOp Op) {
+    return llvm::AtomicRMWInst::getOperationName(Op);
+  }
+  /// \Returns true if \p Op is a floating-point operation, e.g. FAdd.
+  static bool isFPOperation(BinOp Op) {
+    return llvm::AtomicRMWInst::isFPOperation(Op);
+  }
+  void setOperation(BinOp Op);
+  Align getAlign() const { return cast<llvm::AtomicRMWInst>(Val)->getAlign(); }
+  void setAlignment(Align Align);
+  bool isVolatile() const {
+    return cast<llvm::AtomicRMWInst>(Val)->isVolatile();
+  }
+  void setVolatile(bool V);
+  AtomicOrdering getOrdering() const {
+    return cast<llvm::AtomicRMWInst>(Val)->getOrdering();
+  }
+  void setOrdering(AtomicOrdering Ordering);
+  SyncScope::ID getSyncScopeID() const {
+    return cast<llvm::AtomicRMWInst>(Val)->getSyncScopeID();
+  }
+  void setSyncScopeID(SyncScope::ID SSID);
+  /// \Returns the sandbox IR value of the pointer operand.
+  Value *getPointerOperand();
+  const Value *getPointerOperand() const {
+    return const_cast<AtomicRMWInst *>(this)->getPointerOperand();
+  }
+  /// \Returns the sandbox IR value of the value operand.
+  Value *getValOperand();
+  const Value *getValOperand() const {
+    return const_cast<AtomicRMWInst *>(this)->getValOperand();
+  }
+  unsigned getPointerAddressSpace() const {
+    return cast<llvm::AtomicRMWInst>(Val)->getPointerAddressSpace();
+  }
+  bool isFloatingPointOperation() const {
+    return cast<llvm::AtomicRMWInst>(Val)->isFloatingPointOperation();
+  }
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::AtomicRMW;
+  }
+
+  /// Creates a new atomicrmw and inserts it before \p WhereIt, or at the end
+  /// of \p WhereBB if \p WhereIt is WhereBB->end().
+  static AtomicRMWInst *create(BinOp Op, Value *Ptr, Value *Val,
+                               MaybeAlign Align, AtomicOrdering Ordering,
+                               BBIterator WhereIt, BasicBlock *WhereBB,
+                               Context &Ctx,
+                               SyncScope::ID SSID = SyncScope::System,
+                               const Twine &Name = "");
+  /// Creates a new atomicrmw and inserts it before \p InsertBefore.
+  static AtomicRMWInst *create(BinOp Op, Value *Ptr, Value *Val,
+                               MaybeAlign Align, AtomicOrdering Ordering,
+                               Instruction *InsertBefore, Context &Ctx,
+                               SyncScope::ID SSID = SyncScope::System,
+                               const Twine &Name = "");
+  /// Creates a new atomicrmw and appends it to the end of \p InsertAtEnd.
+  static AtomicRMWInst *create(BinOp Op, Value *Ptr, Value *Val,
+                               MaybeAlign Align, AtomicOrdering Ordering,
+                               BasicBlock *InsertAtEnd, Context &Ctx,
+                               SyncScope::ID SSID = SyncScope::System,
+                               const Twine &Name = "");
+};
+
 class AtomicCmpXchgInst
     : public SingleLLVMInstructionImpl<llvm::AtomicCmpXchgInst> {
   AtomicCmpXchgInst(llvm::AtomicCmpXchgInst *Atomic, Context &Ctx)
@@ -2007,6 +2088,8 @@ class Context {
   friend UnaryOperator; // For createUnaryOperator()
   BinaryOperator *createBinaryOperator(llvm::BinaryOperator *I);
   friend BinaryOperator; // For createBinaryOperator()
+  AtomicRMWInst *createAtomicRMWInst(llvm::AtomicRMWInst *I);
+  friend AtomicRMWInst; // For createAtomicRMWInst()
   AtomicCmpXchgInst *createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I);
   friend AtomicCmpXchgInst; // For createAtomicCmpXchgInst()
   AllocaInst *createAllocaInst(llvm::AllocaInst *I);
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 7332316c85026c..81a916be21e4a3 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -68,6 +68,7 @@ DEF_INSTR(BinaryOperator, OPCODES(\
                           OP(Or)   \
                           OP(Xor)  \
                           ),                 BinaryOperator)
+DEF_INSTR(AtomicRMW,     OP(AtomicRMW),     AtomicRMWInst)
 DEF_INSTR(AtomicCmpXchg, OP(AtomicCmpXchg), AtomicCmpXchgInst)
 DEF_INSTR(Alloca,         OP(Alloca),         AllocaInst)
 DEF_INSTR(Cast,   OPCODES(\
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 67262bdd0dea99..96c5af807d1001 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1385,6 +1385,83 @@ Value *BinaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *LHS,
                                InsertAtEnd, Ctx, Name);
 }
 
+void AtomicRMWInst::setOperation(BinOp Op) {
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetter<&AtomicRMWInst::getOperation,
+                                       &AtomicRMWInst::setOperation>>(this);
+  cast<llvm::AtomicRMWInst>(Val)->setOperation(Op);
+}
+
+void AtomicRMWInst::setAlignment(Align Align) {
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetter<&AtomicRMWInst::getAlign,
+                                       &AtomicRMWInst::setAlignment>>(this);
+  cast<llvm::AtomicRMWInst>(Val)->setAlignment(Align);
+}
+
+void AtomicRMWInst::setVolatile(bool V) {
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetter<&AtomicRMWInst::isVolatile,
+                                       &AtomicRMWInst::setVolatile>>(this);
+  cast<llvm::AtomicRMWInst>(Val)->setVolatile(V);
+}
+
+void AtomicRMWInst::setOrdering(AtomicOrdering Ordering) {
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetter<&AtomicRMWInst::getOrdering,
+                                       &AtomicRMWInst::setOrdering>>(this);
+  cast<llvm::AtomicRMWInst>(Val)->setOrdering(Ordering);
+}
+
+void AtomicRMWInst::setSyncScopeID(SyncScope::ID SSID) {
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetter<&AtomicRMWInst::getSyncScopeID,
+                                       &AtomicRMWInst::setSyncScopeID>>(this);
+  cast<llvm::AtomicRMWInst>(Val)->setSyncScopeID(SSID);
+}
+
+Value *AtomicRMWInst::getPointerOperand() {
+  return Ctx.getValue(cast<llvm::AtomicRMWInst>(Val)->getPointerOperand());
+}
+
+Value *AtomicRMWInst::getValOperand() {
+  return Ctx.getValue(cast<llvm::AtomicRMWInst>(Val)->getValOperand());
+}
+
+AtomicRMWInst *AtomicRMWInst::create(BinOp Op, Value *Ptr, Value *Val,
+                                     MaybeAlign Align, AtomicOrdering Ordering,
+                                     BBIterator WhereIt, BasicBlock *WhereBB,
+                                     Context &Ctx, SyncScope::ID SSID,
+                                     const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  // Insert at the end of WhereBB, or before *WhereIt's topmost LLVM instr.
+  if (WhereIt == WhereBB->end())
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  else
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+  // IRBuilder::CreateAtomicRMW() takes no name, so it is set afterwards.
+  auto *LLVMAtomicRMW =
+      Builder.CreateAtomicRMW(Op, Ptr->Val, Val->Val, Align, Ordering, SSID);
+  LLVMAtomicRMW->setName(Name);
+  return Ctx.createAtomicRMWInst(LLVMAtomicRMW);
+}
+
+AtomicRMWInst *AtomicRMWInst::create(BinOp Op, Value *Ptr, Value *Val,
+                                     MaybeAlign Align, AtomicOrdering Ordering,
+                                     Instruction *InsertBefore, Context &Ctx,
+                                     SyncScope::ID SSID, const Twine &Name) {
+  return create(Op, Ptr, Val, Align, Ordering, InsertBefore->getIterator(),
+                InsertBefore->getParent(), Ctx, SSID, Name);
+}
+
+AtomicRMWInst *AtomicRMWInst::create(BinOp Op, Value *Ptr, Value *Val,
+                                     MaybeAlign Align, AtomicOrdering Ordering,
+                                     BasicBlock *InsertAtEnd, Context &Ctx,
+                                     SyncScope::ID SSID, const Twine &Name) {
+  return create(Op, Ptr, Val, Align, Ordering, InsertAtEnd->end(), InsertAtEnd,
+                Ctx, SSID, Name);
+}
+
 void AtomicCmpXchgInst::setSyncScopeID(SyncScope::ID SSID) {
   Ctx.getTracker()
       .emplaceIfTracking<GenericSetter<&AtomicCmpXchgInst::getSyncScopeID,
@@ -1823,6 +1891,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
         new BinaryOperator(LLVMBinaryOperator, *this));
     return It->second.get();
   }
+  case llvm::Instruction::AtomicRMW: {
+    auto *LLVMAtomicRMW = cast<llvm::AtomicRMWInst>(LLVMV);
+    It->second =
+        std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(LLVMAtomicRMW, *this));
+    return It->second.get();
+  }
   case llvm::Instruction::AtomicCmpXchg: {
     auto *LLVMAtomicCmpXchg = cast<llvm::AtomicCmpXchgInst>(LLVMV);
     It->second = std::unique_ptr<AtomicCmpXchgInst>(
@@ -1954,6 +2028,10 @@ BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
   auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
   return cast<BinaryOperator>(registerValue(std::move(NewPtr)));
 }
+AtomicRMWInst *Context::createAtomicRMWInst(llvm::AtomicRMWInst *I) {
+  std::unique_ptr<AtomicRMWInst> Wrapper(new AtomicRMWInst(I, *this));
+  return cast<AtomicRMWInst>(registerValue(std::move(Wrapper)));
+}
 AtomicCmpXchgInst *
 Context::createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I) {
   auto NewPtr =
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 3df335985aa705..e4563ab8f07ba6 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -1932,6 +1932,169 @@ define void @foo(i8 %arg0, i8 %arg1, float %farg0, float %farg1) {
   }
 }
 
+TEST_F(SandboxIRTest, AtomicRMWInst) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %arg) {
+  %atomicrmw = atomicrmw add ptr %ptr, i8 %arg acquire, align 128
+  ret void
+}
+)IR");
+  llvm::Function &LLVMF = *M->getFunction("foo");
+  llvm::BasicBlock *LLVMBB = &*LLVMF.begin();
+  auto LLVMIt = LLVMBB->begin();
+  auto *LLVMRMW = cast<llvm::AtomicRMWInst>(&*LLVMIt++);
+
+  sandboxir::Context Ctx(C);
+  sandboxir::Function *F = Ctx.createFunction(&LLVMF);
+  auto *Ptr = F->getArg(0);
+  auto *Arg = F->getArg(1);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *RMW = cast<sandboxir::AtomicRMWInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  // Check getOperationName().
+  EXPECT_EQ(
+      sandboxir::AtomicRMWInst::getOperationName(
+          sandboxir::AtomicRMWInst::BinOp::Add),
+      llvm::AtomicRMWInst::getOperationName(llvm::AtomicRMWInst::BinOp::Add));
+  // Check isFPOperation().
+  EXPECT_EQ(
+      sandboxir::AtomicRMWInst::isFPOperation(
+          sandboxir::AtomicRMWInst::BinOp::Add),
+      llvm::AtomicRMWInst::isFPOperation(llvm::AtomicRMWInst::BinOp::Add));
+  EXPECT_FALSE(sandboxir::AtomicRMWInst::isFPOperation(
+      sandboxir::AtomicRMWInst::BinOp::Add));
+  EXPECT_TRUE(sandboxir::AtomicRMWInst::isFPOperation(
+      sandboxir::AtomicRMWInst::BinOp::FAdd));
+  // Check setOperation(), getOperation().
+  EXPECT_EQ(RMW->getOperation(), LLVMRMW->getOperation());
+  RMW->setOperation(sandboxir::AtomicRMWInst::BinOp::Sub);
+  EXPECT_EQ(RMW->getOperation(), sandboxir::AtomicRMWInst::BinOp::Sub);
+  RMW->setOperation(sandboxir::AtomicRMWInst::BinOp::Add);
+  // Check getAlign(), setAlignment().
+  EXPECT_EQ(RMW->getAlign(), LLVMRMW->getAlign());
+  auto OrigAlign = RMW->getAlign();
+  Align NewAlign(256);
+  EXPECT_NE(NewAlign, OrigAlign);
+  RMW->setAlignment(NewAlign);
+  EXPECT_EQ(RMW->getAlign(), NewAlign);
+  RMW->setAlignment(OrigAlign);
+  EXPECT_EQ(RMW->getAlign(), OrigAlign);
+  // Check isVolatile(), setVolatile().
+  EXPECT_EQ(RMW->isVolatile(), LLVMRMW->isVolatile());
+  bool OrigV = RMW->isVolatile();
+  bool NewV = true;
+  EXPECT_NE(NewV, OrigV);
+  RMW->setVolatile(NewV);
+  EXPECT_EQ(RMW->isVolatile(), NewV);
+  RMW->setVolatile(OrigV);
+  EXPECT_EQ(RMW->isVolatile(), OrigV);
+  // Check getOrdering(), setOrdering().
+  EXPECT_EQ(RMW->getOrdering(), LLVMRMW->getOrdering());
+  auto OldOrdering = RMW->getOrdering();
+  auto NewOrdering = AtomicOrdering::Monotonic;
+  EXPECT_NE(NewOrdering, OldOrdering);
+  RMW->setOrdering(NewOrdering);
+  EXPECT_EQ(RMW->getOrdering(), NewOrdering);
+  RMW->setOrdering(OldOrdering);
+  EXPECT_EQ(RMW->getOrdering(), OldOrdering);
+  // Check getSyncScopeID(), setSyncScopeID().
+  EXPECT_EQ(RMW->getSyncScopeID(), LLVMRMW->getSyncScopeID());
+  auto OrigSSID = RMW->getSyncScopeID();
+  SyncScope::ID NewSSID = SyncScope::SingleThread;
+  EXPECT_NE(NewSSID, OrigSSID);
+  RMW->setSyncScopeID(NewSSID);
+  EXPECT_EQ(RMW->getSyncScopeID(), NewSSID);
+  RMW->setSyncScopeID(OrigSSID);
+  EXPECT_EQ(RMW->getSyncScopeID(), OrigSSID);
+  // Check getPointerOperand().
+  EXPECT_EQ(RMW->getPointerOperand(),
+            Ctx.getValue(LLVMRMW->getPointerOperand()));
+  // Check getValOperand().
+  EXPECT_EQ(RMW->getValOperand(), Ctx.getValue(LLVMRMW->getValOperand()));
+  // Check getPointerAddressSpace().
+  EXPECT_EQ(RMW->getPointerAddressSpace(), LLVMRMW->getPointerAddressSpace());
+  // Check isFloatingPointOperation().
+  EXPECT_EQ(RMW->isFloatingPointOperation(),
+            LLVMRMW->isFloatingPointOperation());
+
+  Align Align(1024);
+  auto Ordering = AtomicOrdering::Acquire;
+  SyncScope::ID SSID = SyncScope::SingleThread; // Non-default on purpose.
+  {
+    // Check create() WhereIt, WhereBB.
+    auto *NewI =
+        cast<sandboxir::AtomicRMWInst>(sandboxir::AtomicRMWInst::create(
+            sandboxir::AtomicRMWInst::BinOp::Sub, Ptr, Arg, Align, Ordering,
+            /*WhereIt=*/Ret->getIterator(),
+            /*WhereBB=*/Ret->getParent(), Ctx, SSID, "NewAtomicRMW1"));
+    // Check getOpcode().
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::AtomicRMW);
+    // Check getAlign().
+    EXPECT_EQ(NewI->getAlign(), Align);
+    // Check getOrdering(), getSyncScopeID().
+    EXPECT_EQ(NewI->getOrdering(), Ordering);
+    EXPECT_EQ(NewI->getSyncScopeID(), SSID);
+    // Check instr position.
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+    // Check getPointerOperand(), getValOperand().
+    EXPECT_EQ(NewI->getPointerOperand(), Ptr);
+    EXPECT_EQ(NewI->getValOperand(), Arg);
+#ifndef NDEBUG
+    // Check getName().
+    EXPECT_EQ(NewI->getName(), "NewAtomicRMW1");
+#endif // NDEBUG
+  }
+  {
+    // Check create() InsertBefore.
+    auto *NewI =
+        cast<sandboxir::AtomicRMWInst>(sandboxir::AtomicRMWInst::create(
+            sandboxir::AtomicRMWInst::BinOp::Sub, Ptr, Arg, Align, Ordering,
+            /*InsertBefore=*/Ret, Ctx, SSID, "NewAtomicRMW2"));
+    // Check getOpcode().
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::AtomicRMW);
+    // Check getAlign().
+    EXPECT_EQ(NewI->getAlign(), Align);
+    // Check getOrdering(), getSyncScopeID().
+    EXPECT_EQ(NewI->getOrdering(), Ordering);
+    EXPECT_EQ(NewI->getSyncScopeID(), SSID);
+    // Check instr position.
+    EXPECT_EQ(NewI->getNextNode(), Ret);
+    // Check getPointerOperand(), getValOperand().
+    EXPECT_EQ(NewI->getPointerOperand(), Ptr);
+    EXPECT_EQ(NewI->getValOperand(), Arg);
+#ifndef NDEBUG
+    // Check getName().
+    EXPECT_EQ(NewI->getName(), "NewAtomicRMW2");
+#endif // NDEBUG
+  }
+  {
+    // Check create() InsertAtEnd.
+    auto *NewI =
+        cast<sandboxir::AtomicRMWInst>(sandboxir::AtomicRMWInst::create(
+            sandboxir::AtomicRMWInst::BinOp::Sub, Ptr, Arg, Align, Ordering,
+            /*InsertAtEnd=*/BB, Ctx, SSID, "NewAtomicRMW3"));
+    // Check getOpcode().
+    EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::AtomicRMW);
+    // Check getAlign().
+    EXPECT_EQ(NewI->getAlign(), Align);
+    // Check getOrdering(), getSyncScopeID().
+    EXPECT_EQ(NewI->getOrdering(), Ordering);
+    EXPECT_EQ(NewI->getSyncScopeID(), SSID);
+    // Check instr position.
+    EXPECT_EQ(NewI->getParent(), BB);
+    EXPECT_EQ(NewI->getNextNode(), nullptr);
+    // Check getPointerOperand(), getValOperand().
+    EXPECT_EQ(NewI->getPointerOperand(), Ptr);
+    EXPECT_EQ(NewI->getValOperand(), Arg);
+#ifndef NDEBUG
+    // Check getName().
+    EXPECT_EQ(NewI->getName(), "NewAtomicRMW3");
+#endif // NDEBUG
+  }
+}
+
 TEST_F(SandboxIRTest, AtomicCmpXchgInst) {
   parseIR(C, R"IR(
 define void @foo(ptr %ptr, i8 %cmp, i8 %new) {
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index c1f23c95cbfaed..380c90e7f0f65f 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -644,6 +644,61 @@ define void @foo(i8 %arg) {
   EXPECT_EQ(Invoke->getSuccessor(1), ExceptionBB);
 }
 
+TEST_F(TrackerTest, AtomicRMWSetters) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %arg) {
+  %atomicrmw = atomicrmw add ptr %ptr, i8 %arg acquire, align 128
+  ret void
+}
+)IR");
+  llvm::Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  sandboxir::Function *F = Ctx.createFunction(&LLVMF);
+  sandboxir::BasicBlock *BB = &*F->begin();
+  // The atomicrmw is the first instruction of the entry block.
+  auto *RMW = cast<sandboxir::AtomicRMWInst>(&*BB->begin());
+
+  // setAlignment() must be undone by revert().
+  Ctx.save();
+  const Align AlignBefore = RMW->getAlign();
+  const Align AlignAfter(1024);
+  EXPECT_NE(AlignAfter, AlignBefore);
+  RMW->setAlignment(AlignAfter);
+  EXPECT_EQ(RMW->getAlign(), AlignAfter);
+  Ctx.revert();
+  EXPECT_EQ(RMW->getAlign(), AlignBefore);
+
+  // setVolatile() must be undone by revert().
+  Ctx.save();
+  const bool VolatileBefore = RMW->isVolatile();
+  const bool VolatileAfter = true;
+  EXPECT_NE(VolatileAfter, VolatileBefore);
+  RMW->setVolatile(VolatileAfter);
+  EXPECT_EQ(RMW->isVolatile(), VolatileAfter);
+  Ctx.revert();
+  EXPECT_EQ(RMW->isVolatile(), VolatileBefore);
+
+  // setOrdering() must be undone by revert().
+  Ctx.save();
+  const AtomicOrdering OrderingBefore = RMW->getOrdering();
+  const AtomicOrdering OrderingAfter = AtomicOrdering::SequentiallyConsistent;
+  EXPECT_NE(OrderingAfter, OrderingBefore);
+  RMW->setOrdering(OrderingAfter);
+  EXPECT_EQ(RMW->getOrdering(), OrderingAfter);
+  Ctx.revert();
+  EXPECT_EQ(RMW->getOrdering(), OrderingBefore);
+
+  // setSyncScopeID() must be undone by revert().
+  Ctx.save();
+  const SyncScope::ID SSIDBefore = RMW->getSyncScopeID();
+  const SyncScope::ID SSIDAfter = SyncScope::SingleThread;
+  EXPECT_NE(SSIDAfter, SSIDBefore);
+  RMW->setSyncScopeID(SSIDAfter);
+  EXPECT_EQ(RMW->getSyncScopeID(), SSIDAfter);
+  Ctx.revert();
+  EXPECT_EQ(RMW->getSyncScopeID(), SSIDBefore);
+}
+
 TEST_F(TrackerTest, AtomicCmpXchgSetters) {
   parseIR(C, R"IR(
 define void @foo(ptr %ptr, i8 %cmp, i8 %new) {



More information about the llvm-commits mailing list