[flang-commits] [flang] [flang] Add general symbol dependence collection utility (PR #144618)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Tue Jun 17 16:33:40 PDT 2025


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/144618

>From 8e0a688414ea6c52fc92bb9f25257b0ce062026f Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Tue, 17 Jun 2025 11:51:02 -0700
Subject: [PATCH] [flang] Add general symbol dependence collection utility

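Replace HarvestSymbolsNeededFromOtherModules() in mod-file.cpp with a
general utility function, CollectAllDependences(), in Semantics.  This
new code will find other uses in further rework of hermetic module file
generation as the means by which the necessary subsets of symbols in
dependency modules are collected.  In support of this, a DerivedTypeSpec
now retains the symbol by which its type was named (which may be a
generic interface of the same name) as its original type symbol.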
---
 .../flang/Semantics/symbol-set-closure.h      |  34 ++++
 flang/lib/Semantics/CMakeLists.txt            |   1 +
 flang/lib/Semantics/mod-file.cpp              |  68 +------
 flang/lib/Semantics/resolve-names.cpp         |  10 +-
 flang/lib/Semantics/symbol-set-closure.cpp    | 185 ++++++++++++++++++
 flang/lib/Semantics/tools.cpp                 |   6 +-
 flang/lib/Semantics/type.cpp                  |  12 +-
 7 files changed, 244 insertions(+), 72 deletions(-)
 create mode 100644 flang/include/flang/Semantics/symbol-set-closure.h
 create mode 100644 flang/lib/Semantics/symbol-set-closure.cpp

diff --git a/flang/include/flang/Semantics/symbol-set-closure.h b/flang/include/flang/Semantics/symbol-set-closure.h
new file mode 100644
index 0000000000000..507ec851b2898
--- /dev/null
+++ b/flang/include/flang/Semantics/symbol-set-closure.h
@@ -0,0 +1,54 @@
+//===-- include/flang/Semantics/symbol-set-closure.h ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_SEMANTICS_SYMBOLS_SET_CLOSURE_H_
+#define FORTRAN_SEMANTICS_SYMBOLS_SET_CLOSURE_H_
+
+#include "flang/Semantics/symbol.h"
+
+namespace Fortran::semantics {
+
+// For a set or scope of symbols, computes the transitive closure of their
+// dependences due to their types, bounds, specific procedures, interfaces,
+// initialization, storage association (COMMON blocks), NAMELIST groups, &c.
+// Host associations are always followed.  Does not include dependences from
+// subprogram definitions, only their interfaces (dummy arguments and
+// function results).  The original symbols (the members of the original
+// set, or the symbols of the scope) are part of the result only when
+// IncludeOriginalSymbols is set.  Flags may be combined with '|'.
+enum DependenceCollectionFlags {
+  NoDependenceCollectionFlags = 0,
+  // Also return the original symbols themselves.
+  IncludeOriginalSymbols = 1 << 0,
+  // Follow a USE-associated name to the symbol that it denotes.
+  FollowUseAssociations = 1 << 1,
+  // Collect every specific procedure of a generic, not only the one
+  // returned by GenericDetails::specific().
+  IncludeSpecificsOfGenerics = 1 << 2,
+  // Also collect expression symbols owned by a derived type's scope (e.g.,
+  // components); by default, only derived types among them are collected.
+  IncludeComponentsInExprs = 1 << 3,
+};
+
+// Combines flags; the built-in '|' on two enumerators yields an 'int',
+// which does not implicitly convert back to DependenceCollectionFlags.
+constexpr DependenceCollectionFlags operator|(
+    DependenceCollectionFlags x, DependenceCollectionFlags y) {
+  return static_cast<DependenceCollectionFlags>(
+      static_cast<int>(x) | static_cast<int>(y));
+}
+
+// Closure over the members of a set of symbols.
+UnorderedSymbolSet CollectAllDependences(const UnorderedSymbolSet &,
+    enum DependenceCollectionFlags = NoDependenceCollectionFlags);
+// Closure over the symbols in a scope.
+UnorderedSymbolSet CollectAllDependences(const Scope &,
+    enum DependenceCollectionFlags = NoDependenceCollectionFlags);
+
+} // namespace Fortran::semantics
+#endif // FORTRAN_SEMANTICS_SYMBOLS_SET_CLOSURE_H_
diff --git a/flang/lib/Semantics/CMakeLists.txt b/flang/lib/Semantics/CMakeLists.txt
index 18c89587843a9..c1be83b0c744c 100644
--- a/flang/lib/Semantics/CMakeLists.txt
+++ b/flang/lib/Semantics/CMakeLists.txt
@@ -46,6 +46,7 @@ add_flang_library(FortranSemantics
   scope.cpp
   semantics.cpp
   symbol.cpp
+  symbol-set-closure.cpp
   tools.cpp
   type.cpp
   unparse-with-symbols.cpp
diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp
index 82c8536902eb2..05a4ee2ea21e3 100644
--- a/flang/lib/Semantics/mod-file.cpp
+++ b/flang/lib/Semantics/mod-file.cpp
@@ -15,6 +15,7 @@
 #include "flang/Parser/unparse.h"
 #include "flang/Semantics/scope.h"
 #include "flang/Semantics/semantics.h"
+#include "flang/Semantics/symbol-set-closure.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 #include "llvm/Support/FileSystem.h"
@@ -223,71 +224,12 @@ std::string ModFileWriter::GetAsString(const Symbol &symbol) {
 // Collect symbols from constant and specification expressions that are being
 // referenced directly from other modules; they may require new USE
 // associations.
-static void HarvestSymbolsNeededFromOtherModules(
-    SourceOrderedSymbolSet &, const Scope &);
-static void HarvestSymbolsNeededFromOtherModules(
-    SourceOrderedSymbolSet &set, const Symbol &symbol, const Scope &scope) {
-  auto HarvestBound{[&](const Bound &bound) {
-    if (const auto &expr{bound.GetExplicit()}) {
-      for (SymbolRef ref : evaluate::CollectSymbols(*expr)) {
-        set.emplace(*ref);
-      }
-    }
-  }};
-  auto HarvestShapeSpec{[&](const ShapeSpec &shapeSpec) {
-    HarvestBound(shapeSpec.lbound());
-    HarvestBound(shapeSpec.ubound());
-  }};
-  auto HarvestArraySpec{[&](const ArraySpec &arraySpec) {
-    for (const auto &shapeSpec : arraySpec) {
-      HarvestShapeSpec(shapeSpec);
-    }
-  }};
-
-  if (symbol.has<DerivedTypeDetails>()) {
-    if (symbol.scope()) {
-      HarvestSymbolsNeededFromOtherModules(set, *symbol.scope());
-    }
-  } else if (const auto &generic{symbol.detailsIf<GenericDetails>()};
-             generic && generic->derivedType()) {
-    const Symbol &dtSym{*generic->derivedType()};
-    if (dtSym.has<DerivedTypeDetails>()) {
-      if (dtSym.scope()) {
-        HarvestSymbolsNeededFromOtherModules(set, *dtSym.scope());
-      }
-    } else {
-      CHECK(dtSym.has<UseDetails>() || dtSym.has<UseErrorDetails>());
-    }
-  } else if (const auto *object{symbol.detailsIf<ObjectEntityDetails>()}) {
-    HarvestArraySpec(object->shape());
-    HarvestArraySpec(object->coshape());
-    if (IsNamedConstant(symbol) || scope.IsDerivedType()) {
-      if (object->init()) {
-        for (SymbolRef ref : evaluate::CollectSymbols(*object->init())) {
-          set.emplace(*ref);
-        }
-      }
-    }
-  } else if (const auto *proc{symbol.detailsIf<ProcEntityDetails>()}) {
-    if (proc->init() && *proc->init() && scope.IsDerivedType()) {
-      set.emplace(**proc->init());
-    }
-  } else if (const auto *subp{symbol.detailsIf<SubprogramDetails>()}) {
-    for (const Symbol *dummy : subp->dummyArgs()) {
-      if (dummy) {
-        HarvestSymbolsNeededFromOtherModules(set, *dummy, scope);
-      }
-    }
-    if (subp->isFunction()) {
-      HarvestSymbolsNeededFromOtherModules(set, subp->result(), scope);
-    }
-  }
-}
-
 static void HarvestSymbolsNeededFromOtherModules(
     SourceOrderedSymbolSet &set, const Scope &scope) {
-  for (const auto &[_, symbol] : scope) {
-    HarvestSymbolsNeededFromOtherModules(set, *symbol, scope);
+  // Gather the transitive dependences of the symbols in 'scope' (types,
+  // bounds, interfaces, initializers, &c.) with the default flags.
+  for (const Symbol &symbol : CollectAllDependences(scope)) {
+    set.insert(symbol);
   }
 }
 
@@ -369,7 +309,7 @@ void ModFileWriter::PutSymbols(
   PrepareRenamings(scope);
   SourceOrderedSymbolSet modules;
   CollectSymbols(scope, sorted, uses, modules);
-  // Write module files for dependencies first so that their
+  // Write module files for dependences first so that their
   // hashes are known.
   for (const Symbol &mod : modules) {
     if (hermeticModules) {
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index f66918e5c140e..f6cbe49f56543 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -7416,7 +7416,10 @@ void DeclarationVisitor::SetType(
 std::optional<DerivedTypeSpec> DeclarationVisitor::ResolveDerivedType(
     const parser::Name &name) {
   Scope &outer{NonDerivedTypeScope()};
-  Symbol *symbol{FindSymbol(outer, name)};
+  // 'original' is the symbol naming the type (its ultimate may be a generic);
+  // the returned DerivedTypeSpec is built from it rather than from 'symbol'.
+  Symbol *original{FindSymbol(outer, name)};
+  Symbol *symbol{original};
   Symbol *ultimate{symbol ? &symbol->GetUltimate() : nullptr};
   auto *generic{ultimate ? ultimate->detailsIf<GenericDetails>() : nullptr};
   if (generic) {
@@ -7429,11 +7432,12 @@ std::optional<DerivedTypeSpec> DeclarationVisitor::ResolveDerivedType(
       (generic && &ultimate->owner() == &outer)) {
     if (allowForwardReferenceToDerivedType()) {
       if (!symbol) {
-        symbol = &MakeSymbol(outer, name.source, Attrs{});
+        symbol = original = &MakeSymbol(outer, name.source, Attrs{});
         Resolve(name, *symbol);
       } else if (generic) {
         // forward ref to type with later homonymous generic
-        symbol = &outer.MakeSymbol(name.source, Attrs{}, UnknownDetails{});
+        symbol = original =
+            &outer.MakeSymbol(name.source, Attrs{}, UnknownDetails{});
         generic->set_derivedType(*symbol);
         name.symbol = symbol;
       }
@@ -7453,7 +7457,7 @@ std::optional<DerivedTypeSpec> DeclarationVisitor::ResolveDerivedType(
   if (CheckUseError(name)) {
     return std::nullopt;
   } else if (symbol->GetUltimate().has<DerivedTypeDetails>()) {
-    return DerivedTypeSpec{name.source, *symbol};
+    return DerivedTypeSpec{name.source, *original};
   } else {
     Say(name, "'%s' is not a derived type"_err_en_US);
     return std::nullopt;
diff --git a/flang/lib/Semantics/symbol-set-closure.cpp b/flang/lib/Semantics/symbol-set-closure.cpp
new file mode 100644
index 0000000000000..f8fbf74d803ef
--- /dev/null
+++ b/flang/lib/Semantics/symbol-set-closure.cpp
@@ -0,0 +1,213 @@
+//===-- lib/Semantics/symbol-set-closure.cpp ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Semantics/symbol-set-closure.h"
+#include "flang/Common/idioms.h"
+#include "flang/Common/visit.h"
+
+namespace Fortran::semantics {
+
+namespace {
+
+// Accumulates the symbols directly referenced by the entities passed to its
+// call operators (types, bounds, expressions, scopes, &c.).  It performs a
+// single step of dependence collection; CollectAllDependences() below
+// computes the transitive closure.
+class Collector {
+public:
+  explicit Collector(enum DependenceCollectionFlags flags) : flags_{flags} {}
+
+  // Moves the accumulated symbols out of this collector.
+  UnorderedSymbolSet Collected() { return std::move(set_); }
+
+  void operator()(const Symbol &x) { set_.insert(x); }
+
+  void operator()(SymbolRef x) { (*this)(*x); }
+  // Empty optionals and null pointers contribute nothing.
+  template <typename A> void operator()(const std::optional<A> &x) {
+    if (x) {
+      (*this)(*x);
+    }
+  }
+  template <typename A> void operator()(const A *x) {
+    if (x) {
+      (*this)(*x);
+    }
+  }
+
+  void operator()(const UnorderedSymbolSet &x) {
+    for (const Symbol &symbol : x) {
+      (*this)(symbol);
+    }
+  }
+  void operator()(const SourceOrderedSymbolSet &x) {
+    for (const Symbol &symbol : x) {
+      (*this)(symbol);
+    }
+  }
+
+  // Each symbol of a scope (e.g., the components of a derived type).
+  void operator()(const Scope &x) {
+    for (const auto &[_, ref] : x) {
+      (*this)(*ref);
+    }
+  }
+
+  // Symbols owned by a derived type's scope (e.g., components) are skipped
+  // unless they are derived types or IncludeComponentsInExprs is set.
+  template <typename T> void operator()(const evaluate::Expr<T> &x) {
+    UnorderedSymbolSet exprSyms{evaluate::CollectSymbols(x)};
+    for (const Symbol &sym : exprSyms) {
+      if (!sym.owner().IsDerivedType() || sym.has<DerivedTypeDetails>() ||
+          (flags_ & IncludeComponentsInExprs)) {
+        (*this)(sym);
+      }
+    }
+  }
+
+  // A CHARACTER length expression, or a derived type specification.
+  void operator()(const DeclTypeSpec &type) {
+    if (type.category() == DeclTypeSpec::Category::Character) {
+      (*this)(type.characterTypeSpec().length());
+    } else {
+      (*this)(type.AsDerived());
+    }
+  }
+
+  // The type symbol as originally named, which may differ from the derived
+  // type symbol itself, and the values of the type parameters.
+  void operator()(const DerivedTypeSpec &type) {
+    (*this)(type.originalTypeSymbol());
+    for (const auto &[_, value] : type.parameters()) {
+      (*this)(value);
+    }
+  }
+
+  void operator()(const ParamValue &x) { (*this)(x.GetExplicit()); }
+
+  void operator()(const Bound &x) { (*this)(x.GetExplicit()); }
+
+  void operator()(const ShapeSpec &x) {
+    (*this)(x.lbound());
+    (*this)(x.ubound());
+  }
+
+  void operator()(const ArraySpec &x) {
+    for (const ShapeSpec &shapeSpec : x) {
+      (*this)(shapeSpec);
+    }
+  }
+
+private:
+  UnorderedSymbolSet set_;
+  enum DependenceCollectionFlags flags_ { NoDependenceCollectionFlags };
+};
+
+} // namespace
+
+// The symbols of the scope form the basis of the closure.
+UnorderedSymbolSet CollectAllDependences(
+    const Scope &scope, enum DependenceCollectionFlags flags) {
+  UnorderedSymbolSet basis;
+  for (const auto &[_, symbol] : scope) {
+    basis.insert(*symbol);
+  }
+  return CollectAllDependences(basis, flags);
+}
+
+// Computes the closure with a worklist: each round collects the direct
+// dependences of the symbols found by the previous round, and the symbols
+// not seen before form the next round, so each is examined at most once.
+UnorderedSymbolSet CollectAllDependences(
+    const UnorderedSymbolSet &original, enum DependenceCollectionFlags flags) {
+  // The original symbols seed the worklist, but they belong in the result
+  // only when IncludeOriginalSymbols is set.
+  UnorderedSymbolSet result;
+  if (flags & IncludeOriginalSymbols) {
+    result = original;
+  }
+  UnorderedSymbolSet work{original};
+  while (!work.empty()) {
+    Collector collect{flags};
+    for (const Symbol &symbol : work) {
+      collect(symbol.GetType());
+      common::visit(
+          common::visitors{
+              [&collect, &symbol](const ObjectEntityDetails &x) {
+                collect(x.shape());
+                collect(x.coshape());
+                // Initializers are collected only for named constants and
+                // for components.
+                if (IsNamedConstant(symbol) || symbol.owner().IsDerivedType()) {
+                  collect(x.init());
+                }
+              },
+              [&collect, &symbol](const ProcEntityDetails &x) {
+                collect(x.rawProcInterface());
+                if (symbol.owner().IsDerivedType()) {
+                  collect(x.init());
+                }
+              },
+              [&collect](const ProcBindingDetails &x) { collect(x.symbol()); },
+              // Only a subprogram's interface: its dummy arguments and, for a
+              // function, its result; not the dependences of its body.
+              [&collect](const SubprogramDetails &x) {
+                for (const Symbol *dummy : x.dummyArgs()) {
+                  collect(dummy);
+                }
+                if (x.isFunction()) {
+                  collect(x.result());
+                }
+              },
+              [&collect, &symbol](
+                  const DerivedTypeDetails &) { collect(symbol.scope()); },
+              [&collect, flags](const GenericDetails &x) {
+                collect(x.derivedType());
+                collect(x.specific());
+                if (flags & IncludeSpecificsOfGenerics) {
+                  for (const Symbol &specific : x.specificProcs()) {
+                    collect(specific);
+                  }
+                }
+              },
+              [&collect](const NamelistDetails &x) {
+                for (const Symbol &object : x.objects()) {
+                  collect(object);
+                }
+              },
+              [&collect](const CommonBlockDetails &x) {
+                for (auto ref : x.objects()) {
+                  collect(*ref);
+                }
+              },
+              [&collect, flags](const UseDetails &x) {
+                if (flags & FollowUseAssociations) {
+                  collect(x.symbol());
+                }
+              },
+              [&collect](const HostAssocDetails &x) { collect(x.symbol()); },
+              [](const auto &) {},
+          },
+          symbol.details());
+    }
+    work.clear();
+    for (const Symbol &symbol : collect.Collected()) {
+      // Queue symbols not seen before; the original symbols were examined in
+      // the first round and are in 'result' only with IncludeOriginalSymbols.
+      if (result.find(symbol) == result.end() &&
+          ((flags & IncludeOriginalSymbols) ||
+              original.find(symbol) == original.end())) {
+        result.insert(symbol);
+        work.insert(symbol);
+      }
+    }
+  }
+  return result;
+}
+
+} // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp
index bf520d04a50cc..adaad89b0bcfd 100644
--- a/flang/lib/Semantics/tools.cpp
+++ b/flang/lib/Semantics/tools.cpp
@@ -6,15 +6,15 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "flang/Parser/tools.h"
+#include "flang/Semantics/tools.h"
 #include "flang/Common/indirection.h"
 #include "flang/Parser/dump-parse-tree.h"
 #include "flang/Parser/message.h"
 #include "flang/Parser/parse-tree.h"
+#include "flang/Parser/tools.h"
 #include "flang/Semantics/scope.h"
 #include "flang/Semantics/semantics.h"
 #include "flang/Semantics/symbol.h"
-#include "flang/Semantics/tools.h"
 #include "flang/Semantics/type.h"
 #include "flang/Support/Fortran.h"
 #include "llvm/ADT/StringSwitch.h"
@@ -2117,4 +2117,4 @@ bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x) {
     return false;
   }
 }
-} // namespace Fortran::semantics
\ No newline at end of file
+} // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/type.cpp b/flang/lib/Semantics/type.cpp
index 964a37e1c822b..4a56902524417 100644
--- a/flang/lib/Semantics/type.cpp
+++ b/flang/lib/Semantics/type.cpp
@@ -22,9 +22,23 @@
 
 namespace Fortran::semantics {
 
+// Resolves the symbol by which a derived type was named to the type's own
+// symbol: follows it to its ultimate symbol and, when that is a generic, to
+// the generic's derivedType() (which must exist).
+static const Symbol &ResolveOriginalTypeSymbol(const Symbol &original) {
+  const Symbol &ultimate{original.GetUltimate()};
+  if (const auto *generic{ultimate.detailsIf<GenericDetails>()}) {
+    CHECK(generic->derivedType() != nullptr);
+    return generic->derivedType()->GetUltimate();
+  } else {
+    return ultimate;
+  }
+}
+
 DerivedTypeSpec::DerivedTypeSpec(SourceName name, const Symbol &typeSymbol)
     : name_{name}, originalTypeSymbol_{typeSymbol},
-      typeSymbol_{typeSymbol.GetUltimate()} {
+      // originalTypeSymbol_ is the symbol as named; typeSymbol_ is the type.
+      typeSymbol_{ResolveOriginalTypeSymbol(typeSymbol)} {
   CHECK(typeSymbol_.has<DerivedTypeDetails>());
 }
 DerivedTypeSpec::DerivedTypeSpec(const DerivedTypeSpec &that) = default;



More information about the flang-commits mailing list