[flang-commits] [flang] [flang] Prefer non-elemental to elemental defined operator resolution (PR #124941)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Wed Jan 29 08:33:07 PST 2025


https://github.com/klausler created https://github.com/llvm/llvm-project/pull/124941

A non-elemental specific procedure must take precedence over an elemental specific procedure in defined operator generic resolution.

Fixes https://github.com/llvm/llvm-project/issues/124777.

From b6eb0068476374d1f099164fd7cc4adec982902b Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Wed, 29 Jan 2025 08:27:03 -0800
Subject: [PATCH] [flang] Prefer non-elemental to elemental defined operator
 resolution

A non-elemental specific procedure must take precedence over an
elemental specific procedure in defined operator generic
resolution.

Fixes https://github.com/llvm/llvm-project/issues/124777.
---
 flang/include/flang/Semantics/expression.h |  3 +-
 flang/lib/Semantics/expression.cpp         | 83 ++++++++++++----------
 flang/test/Semantics/bug12477.f90          | 26 +++++++
 3 files changed, 73 insertions(+), 39 deletions(-)
 create mode 100644 flang/test/Semantics/bug12477.f90

diff --git a/flang/include/flang/Semantics/expression.h b/flang/include/flang/Semantics/expression.h
index bb1674a9f88778..b95ceebc5e8e47 100644
--- a/flang/include/flang/Semantics/expression.h
+++ b/flang/include/flang/Semantics/expression.h
@@ -348,7 +348,8 @@ class ExpressionAnalyzer {
   bool CheckDataRef(const DataRef &); // ditto
   std::optional<Expr<SubscriptInteger>> GetSubstringBound(
       const std::optional<parser::ScalarIntExpr> &);
-  MaybeExpr AnalyzeDefinedOp(const parser::Name &, ActualArguments &&);
+  MaybeExpr AnalyzeDefinedOp(
+      const parser::Name &, ActualArguments &&, const Symbol *&);
   MaybeExpr FixMisparsedSubstring(const parser::Designator &);
 
   struct CalleeAndArguments {
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 3ec6f385ceb86e..73753cfa6fe9bc 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2834,13 +2834,12 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
   // Check for generic or explicit INTRINSIC of the same name in outer scopes.
   // See 15.5.5.2 for details.
   if (!symbol.owner().IsGlobal() && !symbol.owner().IsDerivedType()) {
-    for (const std::string &n : GetAllNames(context_, symbol.name())) {
-      if (const Symbol *outer{symbol.owner().parent().FindSymbol(n)}) {
-        auto pair{ResolveGeneric(*outer, actuals, adjustActuals, isSubroutine,
-            mightBeStructureConstructor)};
-        if (pair.first) {
-          return pair;
-        }
+    if (const Symbol *
+        outer{symbol.owner().parent().FindSymbol(symbol.name())}) {
+      auto pair{ResolveGeneric(*outer, actuals, adjustActuals, isSubroutine,
+          mightBeStructureConstructor)};
+      if (pair.first) {
+        return pair;
       }
     }
   }
@@ -3635,13 +3634,13 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::Concat &x) {
 // The Name represents a user-defined intrinsic operator.
 // If the actuals match one of the specific procedures, return a function ref.
 // Otherwise report the error in messages.
-MaybeExpr ExpressionAnalyzer::AnalyzeDefinedOp(
-    const parser::Name &name, ActualArguments &&actuals) {
+MaybeExpr ExpressionAnalyzer::AnalyzeDefinedOp(const parser::Name &name,
+    ActualArguments &&actuals, const Symbol *&symbol) {
   if (auto callee{GetCalleeAndArguments(name, std::move(actuals))}) {
-    CHECK(std::holds_alternative<ProcedureDesignator>(callee->u));
-    return MakeFunctionRef(name.source,
-        std::move(std::get<ProcedureDesignator>(callee->u)),
-        std::move(callee->arguments));
+    auto &proc{std::get<evaluate::ProcedureDesignator>(callee->u)};
+    symbol = proc.GetSymbol();
+    return MakeFunctionRef(
+        name.source, std::move(proc), std::move(callee->arguments));
   } else {
     return std::nullopt;
   }
@@ -4453,38 +4452,45 @@ MaybeExpr ArgumentAnalyzer::TryDefinedOp(
     parser::Messages buffer;
     auto restorer{context_.GetContextualMessages().SetMessages(buffer)};
     const auto &scope{context_.context().FindScope(source_)};
-    if (Symbol *symbol{scope.FindSymbol(oprName)}) {
+
+    auto FoundOne{[&](MaybeExpr &&thisResult, const Symbol &generic,
+                      const Symbol *resolution) {
       anyPossibilities = true;
-      parser::Name name{symbol->name(), symbol};
-      if (!fatalErrors_) {
-        result = context_.AnalyzeDefinedOp(name, GetActuals());
-      }
-      if (result) {
-        inaccessible = CheckAccessibleSymbol(scope, *symbol);
-        if (inaccessible) {
-          result.reset();
+      if (thisResult) {
+        if (auto thisInaccessible{CheckAccessibleSymbol(scope, generic)}) {
+          inaccessible = thisInaccessible;
         } else {
-          hit.push_back(symbol);
-          hitBuffer = std::move(buffer);
+          bool isElemental{IsElementalProcedure(DEREF(resolution))};
+          bool hitsAreNonElemental{
+              !hit.empty() && !IsElementalProcedure(DEREF(hit[0]))};
+          if (isElemental && hitsAreNonElemental) {
+            // ignore elemental resolutions in favor of a non-elemental one
+          } else {
+            if (!isElemental && !hitsAreNonElemental) {
+              hit.clear();
+            }
+            result = std::move(thisResult);
+            hit.push_back(resolution);
+            hitBuffer = std::move(buffer);
+          }
         }
       }
+    }};
+
+    if (Symbol * generic{scope.FindSymbol(oprName)}; generic && !fatalErrors_) {
+      parser::Name name{generic->name(), generic};
+      const Symbol *resultSymbol{nullptr};
+      MaybeExpr possibleResult{context_.AnalyzeDefinedOp(
+          name, ActualArguments{actuals_}, resultSymbol)};
+      FoundOne(std::move(possibleResult), *generic, resultSymbol);
     }
     for (std::size_t passIndex{0}; passIndex < actuals_.size(); ++passIndex) {
       buffer.clear();
       const Symbol *generic{nullptr};
-      if (const Symbol *binding{
-              FindBoundOp(oprName, passIndex, generic, false)}) {
-        anyPossibilities = true;
-        if (MaybeExpr thisResult{TryBoundOp(*binding, passIndex)}) {
-          if (auto thisInaccessible{
-                  CheckAccessibleSymbol(scope, DEREF(generic))}) {
-            inaccessible = thisInaccessible;
-          } else {
-            result = std::move(thisResult);
-            hit.push_back(binding);
-            hitBuffer = std::move(buffer);
-          }
-        }
+      if (const Symbol *
+          binding{FindBoundOp(
+              oprName, passIndex, generic, /*isSubroutine=*/false)}) {
+        FoundOne(TryBoundOp(*binding, passIndex), DEREF(generic), binding);
       }
     }
   }
@@ -4655,7 +4661,8 @@ std::optional<ProcedureRef> ArgumentAnalyzer::GetDefinedAssignmentProc() {
     }
     for (std::size_t i{0}; !proc && i < actuals_.size(); ++i) {
       const Symbol *generic{nullptr};
-      if (const Symbol *binding{FindBoundOp(oprName, i, generic, true)}) {
+      if (const Symbol *
+          binding{FindBoundOp(oprName, i, generic, /*isSubroutine=*/true)}) {
         if (CheckAccessibleSymbol(scope, DEREF(generic))) {
           // ignore inaccessible type-bound ASSIGNMENT(=) generic
         } else if (const Symbol *
diff --git a/flang/test/Semantics/bug12477.f90 b/flang/test/Semantics/bug12477.f90
new file mode 100644
index 00000000000000..52d079e3b26bc3
--- /dev/null
+++ b/flang/test/Semantics/bug12477.f90
@@ -0,0 +1,26 @@
+!RUN: %flang_fc1 -fsyntax-only %s  2>&1 | FileCheck %s --allow-empty
+!CHECK-NOT: error:
+module m
+  type t
+   contains
+    procedure nonelemental
+    generic :: operator(+) => nonelemental
+  end type
+  interface operator(+)
+    procedure elemental
+  end interface
+ contains
+  type(t) elemental function elemental (a, b)
+    class(t), intent(in) :: a, b
+    elemental = t()
+  end
+  type(t) function nonelemental (a, b)
+    class(t), intent(in) :: a, b(:)
+    nonelemental = t()
+  end
+end
+program main
+  use m
+  type(t) x, y(1)
+  x = x + y ! ok
+end



More information about the flang-commits mailing list