[flang-commits] [flang] [flang][OpenMP] Add `sym()` member function to omp::Object (PR #94493)

Krzysztof Parzyszek via flang-commits flang-commits at lists.llvm.org
Wed Jun 5 08:59:50 PDT 2024


https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/94493

The object identity requires more than just a `Symbol`. Don't use `id()` to get the Symbol associated with the object, because its return type will need to change. Instead, use `sym()`, which is added for that purpose.

>From 828beb6622142c1d175a61a2664398fc91d36192 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Tue, 4 Jun 2024 15:57:38 -0500
Subject: [PATCH] [flang][OpenMP] Add `sym()` member function to omp::Object

The object identity requires more than just a `Symbol`. Don't use `id()`
to get the Symbol associated with the object, because its return type
will need to change. Instead, use `sym()`, which is added for that purpose.
---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp      | 16 ++++++++--------
 flang/lib/Lower/OpenMP/ClauseProcessor.h        |  8 ++++----
 flang/lib/Lower/OpenMP/Clauses.h                |  7 ++++++-
 flang/lib/Lower/OpenMP/DataSharingProcessor.cpp |  2 +-
 flang/lib/Lower/OpenMP/OpenMP.cpp               |  2 +-
 flang/lib/Lower/OpenMP/ReductionProcessor.cpp   |  8 ++++----
 flang/lib/Lower/OpenMP/Utils.cpp                |  6 +++---
 7 files changed, 27 insertions(+), 22 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 68619f699ebb2..d289f2fdfab26 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -175,7 +175,7 @@ static void addUseDeviceClause(
     useDeviceLocs.push_back(operand.getLoc());
   }
   for (const omp::Object &object : objects)
-    useDeviceSyms.push_back(object.id());
+    useDeviceSyms.push_back(object.sym());
 }
 
 static void convertLoopBounds(lower::AbstractConverter &converter,
@@ -525,7 +525,7 @@ bool ClauseProcessor::processCopyin() const {
   bool hasCopyin = findRepeatableClause<omp::clause::Copyin>(
       [&](const omp::clause::Copyin &clause, const parser::CharBlock &) {
         for (const omp::Object &object : clause.v) {
-          semantics::Symbol *sym = object.id();
+          semantics::Symbol *sym = object.sym();
           assert(sym && "Expecting symbol");
           if (const auto *commonDetails =
                   sym->detailsIf<semantics::CommonBlockDetails>()) {
@@ -698,7 +698,7 @@ bool ClauseProcessor::processCopyprivate(
   bool hasCopyPrivate = findRepeatableClause<clause::Copyprivate>(
       [&](const clause::Copyprivate &clause, const parser::CharBlock &) {
         for (const Object &object : clause.v) {
-          semantics::Symbol *sym = object.id();
+          semantics::Symbol *sym = object.sym();
           if (const auto *commonDetails =
                   sym->detailsIf<semantics::CommonBlockDetails>()) {
             for (const auto &mem : commonDetails->objects())
@@ -739,7 +739,7 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
                  "array sections not supported for task depend");
           }
 
-          semantics::Symbol *sym = object.id();
+          semantics::Symbol *sym = object.sym();
           const mlir::Value variable = converter.getSymbolAddress(*sym);
           result.dependVars.push_back(variable);
         }
@@ -870,11 +870,11 @@ bool ClauseProcessor::processMap(
           lower::AddrAndBoundsInfo info =
               lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
                                                     mlir::omp::MapBoundsType>(
-                  converter, firOpBuilder, semaCtx, stmtCtx, *object.id(),
+                  converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
                   object.ref(), clauseLocation, asFortran, bounds,
                   treatIndexAsSection);
 
-          auto origSymbol = converter.getSymbolAddress(*object.id());
+          auto origSymbol = converter.getSymbolAddress(*object.sym());
           mlir::Value symAddr = info.addr;
           if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
             symAddr = origSymbol;
@@ -894,12 +894,12 @@ bool ClauseProcessor::processMap(
                   mapTypeBits),
               mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
 
-          if (object.id()->owner().IsDerivedType()) {
+          if (object.sym()->owner().IsDerivedType()) {
             addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
                                         semaCtx);
           } else {
             result.mapVars.push_back(mapOp);
-            ptrMapSyms->push_back(object.id());
+            ptrMapSyms->push_back(object.sym());
             if (mapSymTypes)
               mapSymTypes->push_back(symAddr.getType());
             if (mapSymLocs)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 4d3d4448e8f03..28f26697c1f50 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -205,11 +205,11 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
           lower::AddrAndBoundsInfo info =
               lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
                                                     mlir::omp::MapBoundsType>(
-                  converter, firOpBuilder, semaCtx, stmtCtx, *object.id(),
+                  converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
                   object.ref(), clauseLocation, asFortran, bounds,
                   treatIndexAsSection);
 
-          auto origSymbol = converter.getSymbolAddress(*object.id());
+          auto origSymbol = converter.getSymbolAddress(*object.sym());
           mlir::Value symAddr = info.addr;
           if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
             symAddr = origSymbol;
@@ -226,12 +226,12 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
                   mapTypeBits),
               mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
 
-          if (object.id()->owner().IsDerivedType()) {
+          if (object.sym()->owner().IsDerivedType()) {
             addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
                                         semaCtx);
           } else {
             result.mapVars.push_back(mapOp);
-            mapSymbols.push_back(object.id());
+            mapSymbols.push_back(object.sym());
           }
         }
       });
diff --git a/flang/lib/Lower/OpenMP/Clauses.h b/flang/lib/Lower/OpenMP/Clauses.h
index 5391b134e979d..f7cd0ea83ad12 100644
--- a/flang/lib/Lower/OpenMP/Clauses.h
+++ b/flang/lib/Lower/OpenMP/Clauses.h
@@ -21,6 +21,10 @@
 #include <type_traits>
 #include <utility>
 
+namespace Fortran::semantics {
+class Symbol;
+}
+
 namespace Fortran::lower::omp {
 using namespace Fortran;
 using SomeExpr = semantics::SomeExpr;
@@ -45,7 +49,8 @@ struct ObjectT<Fortran::lower::omp::IdTy, Fortran::lower::omp::ExprTy> {
   using IdTy = Fortran::lower::omp::IdTy;
   using ExprTy = Fortran::lower::omp::ExprTy;
 
-  const IdTy &id() const { return symbol; }
+  IdTy id() const { return symbol; }
+  Fortran::semantics::Symbol *sym() const { return symbol; }
   const std::optional<ExprTy> &ref() const { return designator; }
 
   IdTy symbol;
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 557a9685024c5..b206040c237c5 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -139,7 +139,7 @@ void DataSharingProcessor::collectOmpObjectListSymbol(
     const omp::ObjectList &objects,
     llvm::SetVector<const semantics::Symbol *> &symbolSet) {
   for (const omp::Object &object : objects)
-    symbolSet.insert(object.id());
+    symbolSet.insert(object.sym());
 }
 
 void DataSharingProcessor::collectSymbolsForPrivatization() {
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index af9e2af24619b..f84440d95ec11 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1434,7 +1434,7 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       mlir::OpBuilder::InsertPoint insp = builder.saveInsertionPoint();
       const auto &objList = std::get<ObjectList>(lastp->t);
       for (const Object &object : objList) {
-        semantics::Symbol *sym = object.id();
+        semantics::Symbol *sym = object.sym();
         converter.copyHostAssociateVar(*sym, &insp);
       }
     }
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index 1a63e316ef068..60e933f5bc1f7 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -37,7 +37,7 @@ namespace omp {
 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
     const omp::clause::ProcedureDesignator &pd) {
   auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
-                     getRealName(pd.v.id()).ToString())
+                     getRealName(pd.v.sym()).ToString())
                      .Case("max", ReductionIdentifier::MAX)
                      .Case("min", ReductionIdentifier::MIN)
                      .Case("iand", ReductionIdentifier::IAND)
@@ -72,7 +72,7 @@ ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
 
 bool ReductionProcessor::supportedIntrinsicProcReduction(
     const omp::clause::ProcedureDesignator &pd) {
-  semantics::Symbol *sym = pd.v.id();
+  semantics::Symbol *sym = pd.v.sym();
   if (!sym->GetUltimate().attrs().test(semantics::Attr::INTRINSIC))
     return false;
   auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
@@ -707,7 +707,7 @@ void ReductionProcessor::addDeclareReduction(
   // should happen byref
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   for (const Object &object : objectList) {
-    const semantics::Symbol *symbol = object.id();
+    const semantics::Symbol *symbol = object.sym();
     if (reductionSymbols)
       reductionSymbols->push_back(symbol);
     mlir::Value symVal = converter.getSymbolAddress(*symbol);
@@ -825,7 +825,7 @@ ReductionProcessor::getRealName(const semantics::Symbol *symbol) {
 
 const semantics::SourceName
 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
-  return getRealName(pd.v.id());
+  return getRealName(pd.v.sym());
 }
 
 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 4d665e6dd34c5..eff915f569f27 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -55,7 +55,7 @@ void genObjectList(const ObjectList &objects,
                    lower::AbstractConverter &converter,
                    llvm::SmallVectorImpl<mlir::Value> &operands) {
   for (const Object &object : objects) {
-    const semantics::Symbol *sym = object.id();
+    const semantics::Symbol *sym = object.sym();
     assert(sym && "Expected Symbol");
     if (mlir::Value variable = converter.getSymbolAddress(*sym)) {
       operands.push_back(variable);
@@ -107,7 +107,7 @@ void gatherFuncAndVarSyms(
     const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
   for (const Object &object : objects)
-    symbolAndClause.emplace_back(clause, *object.id());
+    symbolAndClause.emplace_back(clause, *object.sym());
 }
 
 mlir::omp::MapInfoOp
@@ -175,7 +175,7 @@ generateMemberPlacementIndices(const Object &object,
                                semantics::SemanticsContext &semaCtx) {
   auto compObj = getComponentObject(object, semaCtx);
   while (compObj) {
-    indices.push_back(getComponentPlacementInParent(compObj->id()));
+    indices.push_back(getComponentPlacementInParent(compObj->sym()));
     compObj =
         getComponentObject(getBaseObject(compObj.value(), semaCtx), semaCtx);
   }



More information about the flang-commits mailing list