[flang-commits] [flang] [Flang][OpenMP][NFC] Track Objects for BlockArgs (PR #197442)

via flang-commits flang-commits at lists.llvm.org
Wed May 13 06:29:53 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Jack Styles (Stylie777)

<details>
<summary>Changes</summary>

When lowering a BlockArg in OpenMP, the symbol is currently tracked. However, this can cause issues further down the line, because information relating to an expression may be lost. For example, an ArrayElement will be represented by its symbol, which in this case is the full array. This is not ideal, as it is just the ArrayElement that is intended to be represented.

Now, the object is tracked instead of the Symbol. For cases where the symbol is required, an appropriate API is available to retrieve this information. This change makes it possible to better handle the lowering of expressions such as Array Elements.

Assisted-by: Codex

---

Patch is 73.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/197442.diff


6 Files Affected:

- (modified) flang/include/flang/Support/OpenMP-utils.h (+29-7) 
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+57-47) 
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+30-30) 
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+215-186) 
- (modified) flang/lib/Lower/OpenMP/Utils.cpp (+6-6) 
- (modified) flang/lib/Lower/OpenMP/Utils.h (+1-1) 


``````````diff
diff --git a/flang/include/flang/Support/OpenMP-utils.h b/flang/include/flang/Support/OpenMP-utils.h
index 6d9db2b682c50..47d5ab2c023c0 100644
--- a/flang/include/flang/Support/OpenMP-utils.h
+++ b/flang/include/flang/Support/OpenMP-utils.h
@@ -9,25 +9,35 @@
 #ifndef FORTRAN_SUPPORT_OPENMP_UTILS_H_
 #define FORTRAN_SUPPORT_OPENMP_UTILS_H_
 
+#include "flang/Lower/OpenMP/Clauses.h"
 #include "flang/Semantics/symbol.h"
 
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Value.h"
 
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
 
 namespace Fortran::common::openmp {
 /// Structure holding the information needed to create and bind entry block
 /// arguments associated to a single clause.
 struct EntryBlockArgsEntry {
-  llvm::ArrayRef<const Fortran::semantics::Symbol *> syms;
+  llvm::SmallVector<Fortran::lower::omp::Object> objects;
   llvm::ArrayRef<mlir::Value> vars;
 
   bool isValid() const {
-    // This check allows specifying a smaller number of symbols than values
+    // This check allows specifying a smaller number of objects than values
     // because in some case cases a single symbol generates multiple block
     // arguments.
-    return syms.size() <= vars.size();
+    return objects.size() <= vars.size();
+  }
+
+  llvm::SmallVector<const Fortran::semantics::Symbol *> getSyms() const {
+    llvm::SmallVector<const Fortran::semantics::Symbol *> syms;
+    for (const Fortran::lower::omp::Object &object : objects) {
+      syms.push_back(object.sym());
+    }
+    return syms;
   }
 };
 
@@ -50,10 +60,22 @@ struct EntryBlockArgs {
         useDeviceAddr.isValid() && useDevicePtr.isValid();
   }
 
-  auto getSyms() const {
-    return llvm::concat<const semantics::Symbol *const>(hasDeviceAddr.syms,
-        inReduction.syms, map.syms, priv.syms, reduction.syms,
-        taskReduction.syms, useDeviceAddr.syms, useDevicePtr.syms);
+  llvm::SmallVector<const semantics::Symbol *> getSyms() const {
+    llvm::SmallVector<const semantics::Symbol *> syms;
+    auto appendSyms = [&syms](const EntryBlockArgsEntry &entry) {
+      for (const Fortran::lower::omp::Object &object : entry.objects) {
+        syms.push_back(object.sym());
+      }
+    };
+    appendSyms(hasDeviceAddr);
+    appendSyms(inReduction);
+    appendSyms(map);
+    appendSyms(priv);
+    appendSyms(reduction);
+    appendSyms(taskReduction);
+    appendSyms(useDeviceAddr);
+    appendSyms(useDevicePtr);
+    return syms;
   }
 
   auto getVars() const {
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 5f5b4fe77f701..6bbabf7e38cdb 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1493,7 +1493,7 @@ bool ClauseProcessor::processGrainsize(
 
 bool ClauseProcessor::processHasDeviceAddr(
     lower::StatementContext &stmtCtx, mlir::omp::HasDeviceAddrClauseOps &result,
-    llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceSyms) const {
+    llvm::SmallVectorImpl<Object> &hasDeviceObjects) const {
   // For HAS_DEVICE_ADDR objects, implicitly map the top-level entities.
   // Their address (or the whole descriptor, if the entity had one) will be
   // passed to the target region.
@@ -1513,11 +1513,11 @@ bool ClauseProcessor::processHasDeviceAddr(
                         });
         processMapObjects(stmtCtx, location, baseObjects, mapTypeBits,
                           parentMemberIndices, result.hasDeviceAddrVars,
-                          hasDeviceSyms);
+                          hasDeviceObjects);
       });
 
   insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
-                               result.hasDeviceAddrVars, hasDeviceSyms);
+                               result.hasDeviceAddrVars, hasDeviceObjects);
   return clauseFound;
 }
 
@@ -1541,26 +1541,32 @@ bool ClauseProcessor::processIf(
 }
 
 template <typename T>
-void collectReductionSyms(
-    const T &reduction,
-    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
-  const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
-  for (const Object &object : objectList) {
-    const semantics::Symbol *symbol = object.sym();
-    reductionSyms.push_back(symbol);
-  }
+void collectReductionObjects(const T &reduction,
+                             llvm::SmallVectorImpl<Object> &reductionObjects) {
+  const omp::ObjectList &objectList{std::get<omp::ObjectList>(reduction.t)};
+  llvm::copy(objectList, std::back_inserter(reductionObjects));
+}
+
+static llvm::SmallVector<const semantics::Symbol *>
+getObjectsSyms(llvm::ArrayRef<Object> objects) {
+  llvm::SmallVector<const semantics::Symbol *> syms;
+  for (const Object &object : objects)
+    syms.push_back(object.sym());
+  return syms;
 }
 
 bool ClauseProcessor::processInReduction(
     mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
-    llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
+    llvm::SmallVectorImpl<Object> &outReductionObjects) const {
   return findRepeatableClause<omp::clause::InReduction>(
       [&](const omp::clause::InReduction &clause, const parser::CharBlock &) {
         llvm::SmallVector<mlir::Value> inReductionVars;
         llvm::SmallVector<bool> inReduceVarByRef;
         llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
-        llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
-        collectReductionSyms(clause, inReductionSyms);
+        llvm::SmallVector<Object> inReductionObjects;
+        collectReductionObjects(clause, inReductionObjects);
+        llvm::SmallVector<const semantics::Symbol *> inReductionSyms =
+            getObjectsSyms(inReductionObjects);
 
         ReductionProcessor rp;
         if (!rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
@@ -1576,13 +1582,13 @@ bool ClauseProcessor::processInReduction(
                    std::back_inserter(result.inReductionByref));
         llvm::copy(inReductionDeclSymbols,
                    std::back_inserter(result.inReductionSyms));
-        llvm::copy(inReductionSyms, std::back_inserter(outReductionSyms));
+        llvm::copy(inReductionObjects, std::back_inserter(outReductionObjects));
       });
 }
 
 bool ClauseProcessor::processIsDevicePtr(
     lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
-    llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
+    llvm::SmallVectorImpl<Object> &isDeviceObjects) const {
   std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
   bool clauseFound = findRepeatableClause<omp::clause::IsDevicePtr>(
       [&](const omp::clause::IsDevicePtr &clause,
@@ -1595,11 +1601,11 @@ bool ClauseProcessor::processIsDevicePtr(
             mlir::omp::ClauseMapFlags::to;
         processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
                           parentMemberIndices, result.isDevicePtrVars,
-                          isDeviceSyms);
+                          isDeviceObjects);
       });
 
   insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
-                               result.isDevicePtrVars, isDeviceSyms);
+                               result.isDevicePtrVars, isDeviceObjects);
   return clauseFound;
 }
 
@@ -1709,9 +1715,8 @@ void ClauseProcessor::processMapObjects(
     const omp::ObjectList &objects, mlir::omp::ClauseMapFlags mapTypeBits,
     std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
     llvm::SmallVectorImpl<mlir::Value> &mapVars,
-    llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
-    llvm::StringRef mapperIdNameRef, bool isMotionModifier,
-    llvm::omp::Directive directive) const {
+    llvm::SmallVectorImpl<Object> &mapObjects, llvm::StringRef mapperIdNameRef,
+    bool isMotionModifier, llvm::omp::Directive directive) const {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
   auto getSymbolDerivedType = [](const semantics::Symbol &symbol)
@@ -1902,7 +1907,7 @@ void ClauseProcessor::processMapObjects(
           object, mapOp, semaCtx);
     } else {
       mapVars.push_back(mapOp);
-      mapSyms.push_back(object.sym());
+      mapObjects.push_back(object);
     }
   }
 }
@@ -1936,13 +1941,13 @@ getMapperIdentifier(lower::AbstractConverter &converter,
 bool ClauseProcessor::processMap(
     mlir::Location currentLocation, lower::StatementContext &stmtCtx,
     mlir::omp::MapClauseOps &result, llvm::omp::Directive directive,
-    llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms) const {
-  // We always require tracking of symbols, even if the caller does not,
-  // so we create an optionally used local set of symbols when the mapSyms
+    llvm::SmallVectorImpl<Object> *mapObjects) const {
+  // We always require tracking of objects, even if the caller does not,
+  // so we create an optionally used local set of objects when the mapObjects
   // argument is not present.
-  llvm::SmallVector<const semantics::Symbol *> localMapSyms;
-  llvm::SmallVectorImpl<const semantics::Symbol *> *ptrMapSyms =
-      mapSyms ? mapSyms : &localMapSyms;
+  llvm::SmallVector<Object> localMapObjects;
+  llvm::SmallVectorImpl<Object> *ptrMapObjects =
+      mapObjects ? mapObjects : &localMapObjects;
   std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
 
   auto process = [&](const omp::clause::Map &clause,
@@ -2005,13 +2010,13 @@ bool ClauseProcessor::processMap(
     }
     processMapObjects(stmtCtx, clauseLocation,
                       std::get<omp::ObjectList>(clause.t), mapTypeBits,
-                      parentMemberIndices, result.mapVars, *ptrMapSyms,
+                      parentMemberIndices, result.mapVars, *ptrMapObjects,
                       mapperIdName, /*isMotionModifier=*/false, directive);
   };
 
   bool clauseFound = findRepeatableClause<omp::clause::Map>(process);
   insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
-                               result.mapVars, *ptrMapSyms);
+                               result.mapVars, *ptrMapObjects);
 
   return clauseFound;
 }
@@ -2019,7 +2024,7 @@ bool ClauseProcessor::processMap(
 bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
                                            mlir::omp::MapClauseOps &result) {
   std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
-  llvm::SmallVector<const semantics::Symbol *> mapSymbols;
+  llvm::SmallVector<Object> mapObjects;
 
   auto callbackFn = [&](const auto &clause, const parser::CharBlock &source) {
     mlir::Location clauseLocation = converter.genLocation(source);
@@ -2040,7 +2045,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
     }
 
     processMapObjects(stmtCtx, clauseLocation, objects, mapTypeBits,
-                      parentMemberIndices, result.mapVars, mapSymbols,
+                      parentMemberIndices, result.mapVars, mapObjects,
                       mapperIdName, /*isMotionModifier=*/true);
   };
 
@@ -2049,7 +2054,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
       findRepeatableClause<omp::clause::From>(callbackFn) || clauseFound;
 
   insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
-                               result.mapVars, mapSymbols);
+                               result.mapVars, mapObjects);
 
   return clauseFound;
 }
@@ -2068,7 +2073,7 @@ bool ClauseProcessor::processNontemporal(
 
 bool ClauseProcessor::processReduction(
     mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
-    llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms,
+    llvm::SmallVectorImpl<Object> &outReductionObjects,
     llvm::DenseMap<const semantics::Symbol *, mlir::Value> *reductionVarCache)
     const {
   return findRepeatableClause<omp::clause::Reduction>(
@@ -2076,8 +2081,10 @@ bool ClauseProcessor::processReduction(
         llvm::SmallVector<mlir::Value> reductionVars;
         llvm::SmallVector<bool> reduceVarByRef;
         llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
-        llvm::SmallVector<const semantics::Symbol *> reductionSyms;
-        collectReductionSyms(clause, reductionSyms);
+        llvm::SmallVector<Object> reductionObjects;
+        collectReductionObjects(clause, reductionObjects);
+        llvm::SmallVector<const semantics::Symbol *> reductionSyms =
+            getObjectsSyms(reductionObjects);
 
         auto mod = std::get<std::optional<ReductionModifier>>(clause.t);
         if (mod.has_value()) {
@@ -2101,20 +2108,22 @@ bool ClauseProcessor::processReduction(
         llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
         llvm::copy(reductionDeclSymbols,
                    std::back_inserter(result.reductionSyms));
-        llvm::copy(reductionSyms, std::back_inserter(outReductionSyms));
+        llvm::copy(reductionObjects, std::back_inserter(outReductionObjects));
       });
 }
 
 bool ClauseProcessor::processTaskReduction(
     mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
-    llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
+    llvm::SmallVectorImpl<Object> &outReductionObjects) const {
   return findRepeatableClause<omp::clause::TaskReduction>(
       [&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) {
         llvm::SmallVector<mlir::Value> taskReductionVars;
         llvm::SmallVector<bool> taskReduceVarByRef;
         llvm::SmallVector<mlir::Attribute> taskReductionDeclSymbols;
-        llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
-        collectReductionSyms(clause, taskReductionSyms);
+        llvm::SmallVector<Object> taskReductionObjects;
+        collectReductionObjects(clause, taskReductionObjects);
+        llvm::SmallVector<const semantics::Symbol *> taskReductionSyms =
+            getObjectsSyms(taskReductionObjects);
 
         ReductionProcessor rp;
         if (!rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
@@ -2130,7 +2139,8 @@ bool ClauseProcessor::processTaskReduction(
                    std::back_inserter(result.taskReductionByref));
         llvm::copy(taskReductionDeclSymbols,
                    std::back_inserter(result.taskReductionSyms));
-        llvm::copy(taskReductionSyms, std::back_inserter(outReductionSyms));
+        llvm::copy(taskReductionObjects,
+                   std::back_inserter(outReductionObjects));
       });
 }
 
@@ -2161,7 +2171,7 @@ bool ClauseProcessor::processEnter(
 
 bool ClauseProcessor::processUseDeviceAddr(
     lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result,
-    llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
+    llvm::SmallVectorImpl<Object> &useDeviceObjects) const {
   std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
   bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
       [&](const omp::clause::UseDeviceAddr &clause,
@@ -2171,17 +2181,17 @@ bool ClauseProcessor::processUseDeviceAddr(
             mlir::omp::ClauseMapFlags::return_param;
         processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
                           parentMemberIndices, result.useDeviceAddrVars,
-                          useDeviceSyms);
+                          useDeviceObjects);
       });
 
   insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
-                               result.useDeviceAddrVars, useDeviceSyms);
+                               result.useDeviceAddrVars, useDeviceObjects);
   return clauseFound;
 }
 
 bool ClauseProcessor::processUseDevicePtr(
     lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result,
-    llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
+    llvm::SmallVectorImpl<Object> &useDeviceObjects) const {
   std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
 
   bool clauseFound = findRepeatableClause<omp::clause::UseDevicePtr>(
@@ -2192,11 +2202,11 @@ bool ClauseProcessor::processUseDevicePtr(
             mlir::omp::ClauseMapFlags::return_param;
         processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
                           parentMemberIndices, result.useDevicePtrVars,
-                          useDeviceSyms);
+                          useDeviceObjects);
       });
 
   insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
-                               result.useDevicePtrVars, useDeviceSyms);
+                               result.useDevicePtrVars, useDeviceObjects);
   return clauseFound;
 }
 
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index e138b4df30b71..1fc221b721ebf 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -87,10 +87,10 @@ class ClauseProcessor {
                     mlir::omp::FinalClauseOps &result) const;
   bool processGrainsize(lower::StatementContext &stmtCtx,
                         mlir::omp::GrainsizeClauseOps &result) const;
-  bool processHasDeviceAddr(
-      lower::StatementContext &stmtCtx,
-      mlir::omp::HasDeviceAddrClauseOps &result,
-      llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceSyms) const;
+  bool
+  processHasDeviceAddr(lower::StatementContext &stmtCtx,
+                       mlir::omp::HasDeviceAddrClauseOps &result,
+                       llvm::SmallVectorImpl<Object> &hasDeviceObjects) const;
   bool processHint(mlir::omp::HintClauseOps &result) const;
   bool processInbranch(mlir::omp::InbranchClauseOps &result) const;
   bool processInclusive(mlir::Location currentLocation,
@@ -141,47 +141,47 @@ class ClauseProcessor {
   processEnter(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const;
   bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
                  mlir::omp::IfClauseOps &result) const;
-  bool processInReduction(
-      mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
-      llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
-  bool processIsDevicePtr(
-      lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
-      llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
+  bool
+  processInReduction(mlir::Location currentLocation,
+                     mlir::omp::InReductionClauseOps &result,
+                     llvm::SmallVectorImpl<Object> &outReductionObjects) const;
+  bool processIsDevicePtr(lower::StatementContext &stmtCtx,
+                          mlir::omp::IsDevicePtrClauseOps &result,
+                          llvm::SmallVectorImpl<Object> &isDeviceObjects) const;
   bool processLinear(mlir::omp::LinearClauseOps &result,
                      bool isDeclareSimd = false) const;
   bool
   processLink(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const;
 
   // This method is used to process a map clause.
-  // The optional parameter mapSyms is used to store the original Fortran symbol
-  // for the map operands. It may be used later on to create the block_arguments
-  // for some of the directives that require it.
+  // The optional parameter mapObjects is used to store the original Fortran
+  // objects for the map operands. It may be used later on to create the
+  // block_arguments for some of the directives that require it.
   bool processMap(mlir::Location currentLocation,
                   lower::StatementContext &stmtCtx,
                   mlir::omp::MapClauseOps &result,
                   llvm::omp::Directive directive = llvm::omp::OMPD_unknown,
-                  llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms =
-                      nullptr) const;
+                  llvm::SmallVectorImpl<Object> *mapObjects = nullptr) const;
   bool processMotionClauses(lower::StatementContext &stmtCtx,
                             mlir::omp::MapClauseOps &result);
   bool processNontemporal(mlir::omp::NontemporalClauseOps &result) const;
-  bool processReduction(
-      mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
-      llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms,
-      llvm::DenseMap<const semantics::Symbol *, mlir::Value>
-          *reductio...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/197442


More information about the flang-commits mailing list