[clang] [NFC][Clang][OpenMP] Add helper functions/utils for finding/comparing attach base-ptrs. (PR #155625)
via cfe-commits
cfe-commits at lists.llvm.org
Wed Aug 27 07:16:47 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang-codegen
Author: Abhinav Gaba (abhinavgaba)
<details>
<summary>Changes</summary>
These have been pulled out of the codegen PR #153683 to reduce the size of that PR.
---
Patch is 25.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155625.diff
5 Files Affected:
- (modified) clang/include/clang/AST/OpenMPClause.h (+95)
- (modified) clang/include/clang/Basic/OpenMPKinds.h (+8)
- (modified) clang/lib/AST/OpenMPClause.cpp (+68)
- (modified) clang/lib/Basic/OpenMPKinds.cpp (+5)
- (modified) clang/lib/CodeGen/CGOpenMPRuntime.cpp (+358)
``````````diff
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 1118d3e062e68..9627e99a306b4 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -5815,6 +5815,12 @@ class OMPClauseMappableExprCommon {
ValueDecl *getAssociatedDeclaration() const {
return AssociatedDeclaration;
}
+
+    /// Two components compare equal when both their associated-expression
+    /// pointer/flag pair (AssociatedExpressionNonContiguousPr) and their
+    /// associated declaration are identical.
+    bool operator==(const MappableComponent &Other) const {
+      if (AssociatedDeclaration != Other.AssociatedDeclaration)
+        return false;
+      return AssociatedExpressionNonContiguousPr ==
+             Other.AssociatedExpressionNonContiguousPr;
+    }
};
// List of components of an expression. This first one is the whole
@@ -5828,6 +5834,95 @@ class OMPClauseMappableExprCommon {
using MappableExprComponentLists = SmallVector<MappableExprComponentList, 8>;
using MappableExprComponentListsRef = ArrayRef<MappableExprComponentList>;
+  // Hash function to allow usage as DenseMap keys. Combines the associated
+  // expression, the associated declaration and the non-contiguous flag.
+  friend llvm::hash_code hash_value(const MappableComponent &MC) {
+    auto *AssocExpr = MC.getAssociatedExpression();
+    auto *AssocDecl = MC.getAssociatedDeclaration();
+    const auto NonContiguous = MC.isNonContiguous();
+    return llvm::hash_combine(AssocExpr, AssocDecl, NonContiguous);
+  }
+
+public:
+ /// Get the type of an element of a ComponentList Expr \p Exp.
+ ///
+ /// For something like the following:
+ /// ```c
+ /// int *p, **pp;
+ /// ```
+ /// The types for the following Exprs would be:
+ /// Expr | Type
+ /// ---------|-----------
+ /// p | int *
+ /// *p | int
+ /// p[0] | int
+ /// p[0:1] | int
+ /// pp | int **
+ /// pp[0] | int *
+ /// pp[0:1] | int *
+ /// The returned type is canonical and has any reference stripped.
+ /// Note: this assumes that if \p Exp is an array-section, it is contiguous.
+ /// \p Exp must not be an array-shaping expression (this is asserted).
+ static QualType getComponentExprElementType(const Expr *Exp);
+
+ /// Find the attach pointer expression from a list of mappable expression
+ /// components.
+ ///
+ /// Starting from the second component (the first one is the whole
+ /// list-item expression), this function traverses the component list to
+ /// find the first expression, other than a binary operator like `p + 1`,
+ /// whose element type is a pointer type. That expression is the attach
+ /// base pointer expr for the current component-list.
+ ///
+ /// For example, given the following:
+ ///
+ /// ```c
+ /// struct S {
+ /// int a;
+ /// int b[10];
+ /// int c[10][10];
+ /// int *p;
+ /// int **pp;
+ /// };
+ /// S s, *ps, **pps, *(pas[10]), ***ppps;
+ /// int i;
+ /// ```
+ ///
+ /// The base-pointers for the following map operands would be:
+ /// map list-item | attach base-pointer | attach base-pointer
+ /// | for directives except | target_update (if
+ /// | target_update | different)
+ /// ----------------|-----------------------|---------------------
+ /// s | N/A |
+ /// s.a | N/A |
+ /// s.p | N/A |
+ /// ps | N/A |
+ /// ps->p | ps |
+ /// ps[1] | ps |
+ /// *(ps + 1) | ps |
+ /// (ps + 1)[1] | ps |
+ /// ps[1:10] | ps |
+ /// ps->b[10] | ps |
+ /// ps->p[10] | ps->p |
+ /// ps->c[1][2] | ps |
+ /// ps->c[1:2][2] | (error diagnostic) | N/A, TODO: ps
+ /// ps->c[1:1][2] | ps | N/A, TODO: ps
+ /// pps[1][2] | pps[1] |
+ /// pps[1:1][2] | pps[1:1] | N/A, TODO: pps[1:1]
+ /// pps[1:i][2] | pps[1:i] | N/A, TODO: pps[1:i]
+ /// pps[1:2][2] | (error diagnostic) | N/A
+ /// pps[1]->p | pps[1] |
+ /// pps[1]->p[10] | pps[1]->p |
+ /// pas[1] | N/A |
+ /// pas[1][2] | pas[1] |
+ /// ppps[1][2] | ppps[1] |
+ /// ppps[1][2][3] | ppps[1][2] |
+ /// ppps[1][2:1][3] | ppps[1][2:1] | N/A, TODO: ppps[1][2:1]
+ /// ppps[1][2:2][3] | (error diagnostic) | N/A
+ /// Returns a pair of the attach pointer expression and its depth in the
+ /// component list. The depth is the 1-based position of the attach pointer
+ /// counted from the last (innermost/base) component; e.g. for `ps->p[10]`
+ /// (components `ps->p[10]`, `ps->p`, `ps`) the result is {`ps->p`, 2}.
+ /// Returns {nullptr, std::nullopt} when there are fewer than two
+ /// components, when no such pointer expression is found, or, for
+ /// target_update, when the list-item is non-contiguous.
+ /// TODO: This may need to be updated to handle ref_ptr/ptee cases for byref
+ /// map operands.
+ /// TODO: Handle cases for target-update, where the list-item is a
+ /// non-contiguous array-section that still has a base-pointer.
+ static std::pair<const Expr *, std::optional<size_t>>
+ findAttachPtrExpr(MappableExprComponentListRef Components,
+ OpenMPDirectiveKind CurDirKind);
+
protected:
// Return the total number of elements in a list of component lists.
static unsigned
diff --git a/clang/include/clang/Basic/OpenMPKinds.h b/clang/include/clang/Basic/OpenMPKinds.h
index f40db4c13c55a..e37887e8b86ba 100644
--- a/clang/include/clang/Basic/OpenMPKinds.h
+++ b/clang/include/clang/Basic/OpenMPKinds.h
@@ -301,6 +301,14 @@ bool isOpenMPTargetExecutionDirective(OpenMPDirectiveKind DKind);
/// otherwise - false.
bool isOpenMPTargetDataManagementDirective(OpenMPDirectiveKind DKind);
+/// Checks if the specified directive is a map-entering target directive.
+/// \param DKind Specified directive.
+/// \return true - the directive is a map-entering target directive, i.e.
+/// 'omp target data', 'omp target enter data', or any target execution
+/// directive such as 'omp target' or 'omp target parallel' (this excludes
+/// 'omp target exit data' and 'omp target update'); otherwise - false.
+bool isOpenMPTargetMapEnteringDirective(OpenMPDirectiveKind DKind);
+
/// Checks if the specified composite/combined directive constitutes a teams
/// directive in the outermost nest. For example
/// 'omp teams distribute' or 'omp teams distribute parallel for'.
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 588b0dcc6d7b8..eff897a1a33b2 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -15,6 +15,7 @@
#include "clang/AST/Attr.h"
#include "clang/AST/Decl.h"
#include "clang/AST/DeclOpenMP.h"
+#include "clang/AST/ExprOpenMP.h"
#include "clang/Basic/LLVM.h"
#include "clang/Basic/OpenMPKinds.h"
#include "clang/Basic/TargetInfo.h"
@@ -1156,6 +1157,73 @@ unsigned OMPClauseMappableExprCommon::getUniqueDeclarationsTotalNumber(
return UniqueDecls.size();
}
+QualType
+OMPClauseMappableExprCommon::getComponentExprElementType(const Expr *Exp) {
+  assert(!isa<OMPArrayShapingExpr>(Exp) &&
+         "Cannot get element-type from array-shaping expr.");
+
+  QualType ElemTy;
+  if (const auto *Section = dyn_cast<ArraySectionExpr>(Exp)) {
+    // An array-section's own type does not describe one element, so derive
+    // it from the original type of the section's base: the element type for
+    // arrays, the pointee type for pointers.
+    QualType BaseTy = ArraySectionExpr::getBaseOriginalType(Section->getBase());
+    const ArrayType *ArrTy = BaseTy->getAsArrayTypeUnsafe();
+    ElemTy = ArrTy ? ArrTy->getElementType() : BaseTy->getPointeeType();
+  } else {
+    // Every other component (subscripts, derefs, member accesses, ...)
+    // already has the element type as its own type.
+    ElemTy = Exp->getType();
+  }
+
+  return ElemTy.getNonReferenceType().getCanonicalType();
+}
+
+std::pair<const Expr *, std::optional<size_t>>
+OMPClauseMappableExprCommon::findAttachPtrExpr(
+    MappableExprComponentListRef Components, OpenMPDirectiveKind CurDirKind) {
+  const std::pair<const Expr *, std::optional<size_t>> NoAttachPtr{
+      nullptr, std::nullopt};
+
+  // A single component (e.g. "map(p)") has nothing underneath it that could
+  // act as a base-pointer.
+  const size_t NumComponents = Components.size();
+  if (NumComponents < 2)
+    return NoAttachPtr;
+
+  // Non-contiguity only matters for target_update; on other constructs
+  // array-sections are assumed contiguous even when that cannot be proven at
+  // compile-time (e.g. a[1:x][2]).
+  if (CurDirKind == OMPD_target_update && Components.back().isNonContiguous())
+    return NoAttachPtr;
+
+  // Skip the whole-expression component at index 0 and peel off one
+  // component at a time; the first pointer-typed Expr that is not a binary
+  // operator (e.g. the `p + 10` in `*(p + 10)`) is the attach base-pointer.
+  for (size_t Idx = 1; Idx != NumComponents; ++Idx) {
+    const Expr *Candidate = Components[Idx].getAssociatedExpression();
+    if (!Candidate)
+      return NoAttachPtr;
+    if (isa<BinaryOperator>(Candidate))
+      continue;
+    if (getComponentExprElementType(Candidate)->isPointerType())
+      return {Candidate, NumComponents - Idx};
+  }
+
+  return NoAttachPtr;
+}
+
OMPMapClause *OMPMapClause::Create(
const ASTContext &C, const OMPVarListLocTy &Locs, ArrayRef<Expr *> Vars,
ArrayRef<ValueDecl *> Declarations,
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index 220b31b0f19bc..2f2a5b66e4ca5 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -650,6 +650,11 @@ bool clang::isOpenMPTargetDataManagementDirective(OpenMPDirectiveKind DKind) {
DKind == OMPD_target_exit_data || DKind == OMPD_target_update;
}
+bool clang::isOpenMPTargetMapEnteringDirective(OpenMPDirectiveKind DKind) {
+  // Every target execution directive maps its list items on entry.
+  if (isOpenMPTargetExecutionDirective(DKind))
+    return true;
+  // Among the standalone data-management directives, only these two map on
+  // entry ('target exit data' and 'target update' do not).
+  return DKind == OMPD_target_enter_data || DKind == OMPD_target_data;
+}
+
bool clang::isOpenMPNestingTeamsDirective(OpenMPDirectiveKind DKind) {
if (DKind == OMPD_teams)
return true;
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index f98339d472fa9..d592c29a412a9 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6765,12 +6765,256 @@ llvm::Value *CGOpenMPRuntime::emitNumThreadsForTargetDirective(
namespace {
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
+/// Orders two expressions by the SourceLocation returned from getExprLoc().
+///
+/// \returns true iff the expr-loc of \p LHS compares less than the expr-loc
+/// of \p RHS. Both expressions must be non-null and must have valid
+/// expr-locations; both requirements are checked with assertions.
+static bool compareExprLocs(const Expr *LHS, const Expr *RHS) {
+  assert(LHS && "LHS expression cannot be null");
+  assert(RHS && "RHS expression cannot be null");
+
+  const SourceLocation LeftLoc = LHS->getExprLoc();
+  assert(LeftLoc.isValid() &&
+         "LHS expression must have valid source location");
+  const SourceLocation RightLoc = RHS->getExprLoc();
+  assert(RightLoc.isValid() &&
+         "RHS expression must have valid source location");
+
+  // Raw location ordering gives a deterministic tie-break.
+  return LeftLoc < RightLoc;
+}
+
// Utility to handle information from clauses associated with a given
// construct that use mappable expressions (e.g. 'map' clause, 'to' clause).
// It provides a convenient interface to obtain the information and generate
// code for that information.
class MappableExprsHandler {
public:
+  /// Custom comparator for attach-pointer expressions that compares them by
+  /// complexity (i.e. their component-depth) first, then by their expr-locs if
+  /// they are semantically different.
+  ///
+  /// Depths are looked up in Handler->AttachPtrComponentDepthMap; an
+  /// expression without an entry (or with a std::nullopt depth) is treated as
+  /// having the lowest complexity and sorts first.
+  ///
+  /// NOTE(review): the location-based tie-break is only a strict weak ordering
+  /// if no semantically-different expression lies, location-wise, between two
+  /// semantically-equal ones. With A == B (semantically) and
+  /// loc(A) < loc(C) < loc(B), we get A < C and C < B but not A < B. Callers
+  /// using this in ordered containers should make sure that cannot happen
+  /// (e.g. by canonicalizing semantically-equal expressions first).
+  struct AttachPtrExprComparator {
+    const MappableExprsHandler *Handler;
+    // Cache of previous equality comparison results, stored under both
+    // argument orders since equality is symmetric.
+    mutable llvm::DenseMap<std::pair<const Expr *, const Expr *>, bool>
+        CachedEqualityComparisons;
+
+    AttachPtrExprComparator(const MappableExprsHandler *H) : Handler(H) {}
+
+    // Return true iff LHS is "less than" RHS.
+    bool operator()(const Expr *LHS, const Expr *RHS) const {
+      if (LHS == RHS)
+        return false;
+
+      // First, compare by complexity (depth)
+      auto ItLHS = Handler->AttachPtrComponentDepthMap.find(LHS);
+      auto ItRHS = Handler->AttachPtrComponentDepthMap.find(RHS);
+
+      std::optional<size_t> DepthLHS =
+          (ItLHS != Handler->AttachPtrComponentDepthMap.end()) ? ItLHS->second
+                                                               : std::nullopt;
+      std::optional<size_t> DepthRHS =
+          (ItRHS != Handler->AttachPtrComponentDepthMap.end()) ? ItRHS->second
+                                                               : std::nullopt;
+
+      // std::nullopt (no attach pointer) has lowest complexity
+      if (!DepthLHS.has_value() && !DepthRHS.has_value()) {
+        // Both have same complexity, now check semantic equality
+        if (areEqual(LHS, RHS))
+          return false;
+        // Different semantically, compare by location
+        return compareExprLocs(LHS, RHS);
+      }
+      if (!DepthLHS.has_value())
+        return true; // LHS has lower complexity
+      if (!DepthRHS.has_value())
+        return false; // RHS has lower complexity
+
+      // Both have values, compare by depth (lower depth = lower complexity)
+      if (DepthLHS.value() != DepthRHS.value())
+        return DepthLHS.value() < DepthRHS.value();
+
+      // Same complexity, now check semantic equality
+      if (areEqual(LHS, RHS))
+        return false;
+      // Different semantically, compare by location
+      return compareExprLocs(LHS, RHS);
+    }
+
+  public:
+    /// Return true if \p LHS and \p RHS are semantically equal. Uses pre-cached
+    /// results, if available, otherwise does a recursive semantic comparison.
+    bool areEqual(const Expr *LHS, const Expr *RHS) const {
+      // Check cache first for faster lookup
+      auto CachedResultIt = CachedEqualityComparisons.find({LHS, RHS});
+      if (CachedResultIt != CachedEqualityComparisons.end())
+        return CachedResultIt->second;
+
+      bool ComparisonResult = areSemanticallyEqual(LHS, RHS);
+
+      // Cache the result for future lookups (both orders since semantic
+      // equality is commutative)
+      CachedEqualityComparisons[{LHS, RHS}] = ComparisonResult;
+      CachedEqualityComparisons[{RHS, LHS}] = ComparisonResult;
+      return ComparisonResult;
+    }
+
+  private:
+    /// Helper function to compare attach-pointer expressions semantically.
+    /// This function handles various expression types that can be part of an
+    /// attach-pointer. It is conservative: when in doubt it reports the
+    /// expressions as different.
+    /// TODO: Not urgent, but we should ideally return true when comparing
+    /// `p[10]`, `*(p + 10)`, `*(p + 5 + 5)`, `p[10:1]` etc.
+    bool areSemanticallyEqual(const Expr *LHS, const Expr *RHS) const {
+      if (LHS == RHS)
+        return true;
+
+      // If only one is null, they aren't equal
+      if (!LHS || !RHS)
+        return false;
+
+      ASTContext &Ctx = Handler->CGF.getContext();
+      // Strip away parentheses and no-op casts to get to the core expression
+      LHS = LHS->IgnoreParenNoopCasts(Ctx);
+      RHS = RHS->IgnoreParenNoopCasts(Ctx);
+
+      // Direct pointer comparison of the underlying expressions
+      if (LHS == RHS)
+        return true;
+
+      // Check if the expression classes match
+      if (LHS->getStmtClass() != RHS->getStmtClass())
+        return false;
+
+      // Expressions of different types cannot denote the same attach pointer
+      // (or the same sub-expression of one). Without this check:
+      //  - `(char)i` and `(short)i` (same cast-kind and operand) would compare
+      //    equal although their values differ;
+      //  - `(char *)p + 1` and `(int *)p + 1` would compare equal because the
+      //    same-width pointer casts are stripped by IgnoreParenNoopCasts above;
+      //  - IntegerLiterals such as `1` and `1L` would reach APInt::operator==,
+      //    which asserts that both operands have the same bit-width.
+      if (!Ctx.hasSameUnqualifiedType(LHS->getType(), RHS->getType()))
+        return false;
+
+      // Handle DeclRefExpr (variable references)
+      if (const auto *LD = dyn_cast<DeclRefExpr>(LHS)) {
+        const auto *RD = dyn_cast<DeclRefExpr>(RHS);
+        if (!RD)
+          return false;
+        return LD->getDecl()->getCanonicalDecl() ==
+               RD->getDecl()->getCanonicalDecl();
+      }
+
+      // Handle ArraySubscriptExpr (array indexing like a[i])
+      if (const auto *LA = dyn_cast<ArraySubscriptExpr>(LHS)) {
+        const auto *RA = dyn_cast<ArraySubscriptExpr>(RHS);
+        if (!RA)
+          return false;
+        return areSemanticallyEqual(LA->getBase(), RA->getBase()) &&
+               areSemanticallyEqual(LA->getIdx(), RA->getIdx());
+      }
+
+      // Handle MemberExpr (member access like s.m or p->m)
+      if (const auto *LM = dyn_cast<MemberExpr>(LHS)) {
+        const auto *RM = dyn_cast<MemberExpr>(RHS);
+        if (!RM)
+          return false;
+        if (LM->getMemberDecl()->getCanonicalDecl() !=
+            RM->getMemberDecl()->getCanonicalDecl())
+          return false;
+        return areSemanticallyEqual(LM->getBase(), RM->getBase());
+      }
+
+      // Handle UnaryOperator (unary operations like *p, &x, etc.)
+      if (const auto *LU = dyn_cast<UnaryOperator>(LHS)) {
+        const auto *RU = dyn_cast<UnaryOperator>(RHS);
+        if (!RU)
+          return false;
+        if (LU->getOpcode() != RU->getOpcode())
+          return false;
+        return areSemanticallyEqual(LU->getSubExpr(), RU->getSubExpr());
+      }
+
+      // Handle BinaryOperator (binary operations like p + offset)
+      if (const auto *LB = dyn_cast<BinaryOperator>(LHS)) {
+        const auto *RB = dyn_cast<BinaryOperator>(RHS);
+        if (!RB)
+          return false;
+        if (LB->getOpcode() != RB->getOpcode())
+          return false;
+        return areSemanticallyEqual(LB->getLHS(), RB->getLHS()) &&
+               areSemanticallyEqual(LB->getRHS(), RB->getRHS());
+      }
+
+      // Handle ArraySectionExpr (array sections like a[0:1])
+      // Attach pointers should not contain array-sections, but currently we
+      // don't emit an error.
+      if (const auto *LAS = dyn_cast<ArraySectionExpr>(LHS)) {
+        const auto *RAS = dyn_cast<ArraySectionExpr>(RHS);
+        if (!RAS)
+          return false;
+        if (LAS->isOMPArraySection() != RAS->isOMPArraySection())
+          return false;
+        // Only OpenMP array-sections carry a stride (getStride() must not be
+        // called on OpenACC sub-arrays); a[0:4:1] and a[0:4:2] select
+        // different elements.
+        if (LAS->isOMPArraySection() &&
+            !areSemanticallyEqual(LAS->getStride(), RAS->getStride()))
+          return false;
+        return areSemanticallyEqual(LAS->getBase(), RAS->getBase()) &&
+               areSemanticallyEqual(LAS->getLowerBound(),
+                                    RAS->getLowerBound()) &&
+               areSemanticallyEqual(LAS->getLength(), RAS->getLength());
+      }
+
+      // Handle CastExpr (explicit casts). The destination types were already
+      // checked to match above.
+      if (const auto *LC = dyn_cast<CastExpr>(LHS)) {
+        const auto *RC = dyn_cast<CastExpr>(RHS);
+        if (!RC)
+          return false;
+        if (LC->getCastKind() != RC->getCastKind())
+          return false;
+        return areSemanticallyEqual(LC->getSubExpr(), RC->getSubExpr());
+      }
+
+      // Handle CXXThisExpr (this pointer)
+      if (isa<CXXThisExpr>(LHS) && isa<CXXThisExpr>(RHS))
+        return true;
+
+      // Handle IntegerLiteral (integer constants). isSameValue tolerates
+      // differing bit-widths, unlike APInt::operator==.
+      if (const auto *LI = dyn_cast<IntegerLiteral>(LHS)) {
+        const auto *RI = dyn_cast<IntegerLiteral>(RHS);
+        if (!RI)
+          return false;
+        return llvm::APInt::isSameValue(LI->getValue(), RI->getValue());
+      }
+
+      // Handle CharacterLiteral (character constants)
+      if (const auto *LC = dyn_cast<CharacterLiteral>(LHS)) {
+        const auto *RC = dyn_cast<CharacterLiteral>(RHS);
+        if (!RC)
+          return false;
+        return LC->getValue() == RC->getValue();
+      }
+
+      // Handle FloatingLiteral (floating point constants)
+      if (const auto *LF = dyn_cast<FloatingLiteral>(LHS)) {
+        const auto *RF = dyn_cast<FloatingLiteral>(RHS);
+        if (!RF)
+          return false;
+        // Use bitwise comparison for floating point literals
+        return LF->getValue().bitwiseIsEqual(RF->getValue());
+      }
+
+      // Handle StringLiteral (string constants). getString() only supports
+      // 1-byte character literals, so compare the raw bytes instead; the type
+      // check above already guarantees the same character type and length.
+      if (const auto *LS = dyn_cast<StringLiteral>(LHS)) {
+        const auto *RS = dyn_cast<StringLiteral>(RHS);
+        if (!RS)
+          return false;
+        return LS->getBytes() == RS->getBytes();
+      }
+
+      // Handle CXXNullPtrLiteralExpr (nullptr)
+      if (isa<CXXNullPtrLiteralExpr>(LHS) && isa<CXXNullPtrLiteralExpr>(RHS))
+        return true;
+
+      // Handle CXXBoolLiteralExpr (true/false)
+      if (const auto *LB = dyn_cast<CXXBoolLiteralExpr>(LHS)) {
+        const auto *RB = dyn_cast<CXXBoolLiteralExpr>(RHS);
+        if (!RB)
+          return false;
+        return LB->getValue() == RB->getValue();
+      }
+
+      // Fallback for other forms - use the existing comparison method
+      return Expr::isSameComparisonOperand(LHS, RHS);
+    }
+  };
+
/// Get the offset of the OMP_MAP_MEMBER_OF field.
static unsigned getFlagMemberOffset() {
unsigned Offset = 0;
@@ -6846,8 +7090,42 @@ class MappableExprsHandler {
Address LB = Address::invalid();
bool IsArraySection = false;
bool HasCompleteRecord = false;
+ // ATTACH information for delaye...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/155625
More information about the cfe-commits
mailing list