[llvm] 1a547a9 - [OMPIRBuilder] Add support for atomic compare

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 2 12:39:20 PST 2022


Author: Shilei Tian
Date: 2022-02-02T15:39:14-05:00
New Revision: 1a547a94c1afcd05bba2e36831d26a8aa871fa06

URL: https://github.com/llvm/llvm-project/commit/1a547a94c1afcd05bba2e36831d26a8aa871fa06
DIFF: https://github.com/llvm/llvm-project/commit/1a547a94c1afcd05bba2e36831d26a8aa871fa06.diff

LOG: [OMPIRBuilder] Add support for atomic compare

This patch adds support for `atomic compare` in `OMPIRBuilder`.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D118547

Added: 
    

Modified: 
    llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
    llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
    llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
index bee90281e086b..84954594f8bce 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
@@ -116,6 +116,9 @@ enum class AddressSpace : unsigned {
 /// \note This needs to be kept in sync with interop.h enum kmp_interop_type_t.:
 enum class OMPInteropType { Unknown, Target, TargetSync };
 
+/// Atomic compare ordops supported by OpenMP: EQ (==), MAX (>) and MIN (<).
+enum class OMPAtomicCompareOp : unsigned { EQ = 0, MIN = 1, MAX = 2 };
+
 } // end namespace omp
 
 } // end namespace llvm

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index f60debe8411c4..18a88ecf1388c 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1198,7 +1198,7 @@ class OpenMPIRBuilder {
       const function_ref<Value *(Value *XOld, IRBuilder<> &IRB)>;
 
 private:
-  enum AtomicKind { Read, Write, Update, Capture };
+  enum AtomicKind { Read, Write, Update, Capture, Compare }; ///< Picks flush rule.
 
   /// Determine whether to emit flush or not
   ///
@@ -1344,6 +1344,39 @@ class OpenMPIRBuilder {
                       AtomicUpdateCallbackTy &UpdateOp, bool UpdateExpr,
                       bool IsPostfixUpdate, bool IsXBinopExpr);
 
+  /// Emit atomic compare for the following constructs (scalar types only):
+  /// cond-update-atomic:
+  /// x = x ordop expr ? expr : x;
+  /// x = expr ordop x ? expr : x;
+  /// x = x == e ? d : x;
+  /// x = e == x ? d : x; (this one is not in the spec)
+  /// cond-update-stmt:
+  /// if (x ordop expr) { x = expr; }
+  /// if (expr ordop x) { x = expr; }
+  /// if (x == e) { x = d; }
+  /// if (e == x) { x = d; } (this one is not in the spec)
+  ///
+  /// \param Loc          The insert and source location description.
+  /// \param X            The target atomic pointer to be updated.
+  /// \param E            The expected value ('e') for forms that use an
+  ///                     equality comparison or an expression ('expr') for
+  ///                     forms that use 'ordop' (logically an atomic maximum or
+  ///                     minimum; only integer \p X is supported there).
+  /// \param D            The desired value for forms that use an equality
+  ///                     comparison. For forms that use 'ordop', it should be
+  ///                     \p nullptr.
+  /// \param AO           Atomic ordering of the generated atomic instructions.
+  /// \param Op           Atomic compare operation: EQ (==), MAX (>) or MIN (<).
+  /// \param IsXBinopExpr True if the conditional statement is in the form where
+  ///                     x is on LHS. It only matters for < or >.
+  ///
+  /// \return Insertion point after generated atomic compare IR.
+  InsertPointTy createAtomicCompare(const LocationDescription &Loc,
+                                    AtomicOpValue &X, Value *E, Value *D,
+                                    AtomicOrdering AO,
+                                    omp::OMPAtomicCompareOp Op,
+                                    bool IsXBinopExpr);
+
   /// Create the control flow structure of a canonical OpenMP loop.
   ///
   /// The emitted loop will be disconnected, i.e. no edge to the loop's

diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 99001269e1f8c..3a39e7ac9ff99 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -3171,6 +3171,7 @@ bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
     }
     break;
   case Write:
+  case Compare:
   case Update:
     if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
         AO == AtomicOrdering::SequentiallyConsistent) {
@@ -3472,6 +3473,68 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
   return Builder.saveIP();
 }
 
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
+    const LocationDescription &Loc, AtomicOpValue &X, Value *E, Value *D,
+    AtomicOrdering AO, OMPAtomicCompareOp Op, bool IsXBinopExpr) {
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+
+  assert(X.Var->getType()->isPointerTy() &&
+         "OMP atomic expects a pointer to target memory");
+  assert((X.ElemTy->isFloatingPointTy() || X.ElemTy->isIntegerTy() ||
+          X.ElemTy->isPointerTy()) &&
+         "OMP atomic compare expected a scalar type");
+
+  if (Op == OMPAtomicCompareOp::EQ) {
+    assert(D && "OMP atomic compare with == expects a desired value");
+    // cmpxchg only accepts integer or pointer operands, so a floating-point X
+    // is exchanged through a same-width integer: cast the address, E and D.
+    Value *XAddr = X.Var;
+    if (X.ElemTy->isFloatingPointTy()) {
+      unsigned Addrspace = X.Var->getType()->getPointerAddressSpace();
+      IntegerType *IntCastTy =
+          IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
+      XAddr = Builder.CreateBitCast(X.Var, IntCastTy->getPointerTo(Addrspace));
+      E = Builder.CreateBitCast(E, IntCastTy);
+      D = Builder.CreateBitCast(D, IntCastTy);
+    }
+    AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
+    // We don't need the result for now.
+    (void)Builder.CreateAtomicCmpXchg(XAddr, E, D, MaybeAlign(), AO, Failure);
+  } else {
+    assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
+           "Op should be either max or min at this point");
+    assert(X.ElemTy->isIntegerTy() &&
+           "max and min operators only support integer type");
+
+    // ordop names the comparison, not the result. With x on the LHS,
+    // 'x = x > expr ? expr : x' keeps the smaller value, so MAX ('>') lowers
+    // to atomicrmw min and MIN ('<') to max; with x on the RHS the names match.
+    AtomicRMWInst::BinOp NewOp;
+    if (IsXBinopExpr) {
+      if (X.IsSigned)
+        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
+                                              : AtomicRMWInst::Max;
+      else
+        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
+                                              : AtomicRMWInst::UMax;
+    } else {
+      if (X.IsSigned)
+        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
+                                              : AtomicRMWInst::Min;
+      else
+        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
+                                              : AtomicRMWInst::UMin;
+    }
+    // We don't need the result for now.
+    (void)Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
+  }
+
+  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);
+
+  return Builder.saveIP();
+}
+
 GlobalVariable *
 OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
                                        std::string VarName) {

diff  --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index e0c91d636f9c0..7c8ea485925bc 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -3031,6 +3031,63 @@ TEST_F(OpenMPIRBuilderTest, OMPAtomicCapture) {
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, OMPAtomicCompare) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  LLVMContext &Ctx = M->getContext();
+  IntegerType *Int32 = Type::getInt32Ty(Ctx);
+  AllocaInst *XVal = Builder.CreateAlloca(Int32);
+  XVal->setName("x");
+  StoreInst *Init =
+      Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 0U), XVal);
+
+  OpenMPIRBuilder::AtomicOpValue XSigned = {XVal, Int32, true, false};
+  OpenMPIRBuilder::AtomicOpValue XUnsigned = {XVal, Int32, false, false};
+  AtomicOrdering AO = AtomicOrdering::Monotonic;
+  ConstantInt *Expr = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
+  // D must differ from Expr, otherwise swapped cmpxchg operands go unnoticed.
+  ConstantInt *D = ConstantInt::get(Type::getInt32Ty(Ctx), 2U);
+  OMPAtomicCompareOp OpMax = OMPAtomicCompareOp::MAX;
+  OMPAtomicCompareOp OpEQ = OMPAtomicCompareOp::EQ;
+
+  Builder.restoreIP(OMPBuilder.createAtomicCompare(Builder, XSigned, Expr,
+                                                   nullptr, AO, OpMax, true));
+  Builder.restoreIP(OMPBuilder.createAtomicCompare(Builder, XUnsigned, Expr,
+                                                   nullptr, AO, OpMax, false));
+  Builder.restoreIP(OMPBuilder.createAtomicCompare(Builder, XSigned, Expr, D,
+                                                   AO, OpEQ, true));
+
+  BasicBlock *EntryBB = BB;
+  EXPECT_EQ(EntryBB->getParent()->size(), 1U);
+  EXPECT_EQ(EntryBB->size(), 5U);
+
+  AtomicRMWInst *ARWM1 = dyn_cast<AtomicRMWInst>(Init->getNextNode());
+  ASSERT_NE(ARWM1, nullptr);
+  EXPECT_EQ(ARWM1->getPointerOperand(), XVal);
+  EXPECT_EQ(ARWM1->getValOperand(), Expr);
+  EXPECT_EQ(ARWM1->getOperation(), AtomicRMWInst::Min);
+
+  AtomicRMWInst *ARWM2 = dyn_cast<AtomicRMWInst>(ARWM1->getNextNode());
+  ASSERT_NE(ARWM2, nullptr);
+  EXPECT_EQ(ARWM2->getPointerOperand(), XVal);
+  EXPECT_EQ(ARWM2->getValOperand(), Expr);
+  EXPECT_EQ(ARWM2->getOperation(), AtomicRMWInst::UMax);
+
+  AtomicCmpXchgInst *AXCHG = dyn_cast<AtomicCmpXchgInst>(ARWM2->getNextNode());
+  ASSERT_NE(AXCHG, nullptr);
+  EXPECT_EQ(AXCHG->getPointerOperand(), XVal);
+  EXPECT_EQ(AXCHG->getCompareOperand(), Expr);
+  EXPECT_EQ(AXCHG->getNewValOperand(), D);
+  EXPECT_EQ(AXCHG->getSuccessOrdering(), AO);
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
 /// Returns the single instruction of InstTy type in BB that uses the value V.
 /// If there is more than one such instruction, returns null.
 template <typename InstTy>


        


More information about the llvm-commits mailing list