[flang-commits] [flang] [flang] Accept CLASS(*) array in EOSHIFT (PR #116114)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Wed Nov 13 16:39:42 PST 2024


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/116114

>From 9327701f6d0d880ecea69ef49df819fd2b64aff0 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Wed, 13 Nov 2024 14:30:47 -0800
Subject: [PATCH] [flang] Accept CLASS(*) array in EOSHIFT

The intrinsic processing code wasn't allowing the ARRAY= argument
to the EOSHIFT intrinsic function to be CLASS(*).  That case
seems to conform to the standard, although only one compiler could
actually handle it, so allow for it.

Fixes https://github.com/llvm/llvm-project/issues/115923.
---
 flang/lib/Evaluate/intrinsics.cpp  | 18 +++++++-----------
 flang/runtime/transformational.cpp | 14 +++++++-------
 flang/test/Evaluate/bug115923.f90  | 22 ++++++++++++++++++++++
 3 files changed, 36 insertions(+), 18 deletions(-)
 create mode 100644 flang/test/Evaluate/bug115923.f90

diff --git a/flang/lib/Evaluate/intrinsics.cpp b/flang/lib/Evaluate/intrinsics.cpp
index aa44967817722e..a094ad55294635 100644
--- a/flang/lib/Evaluate/intrinsics.cpp
+++ b/flang/lib/Evaluate/intrinsics.cpp
@@ -168,8 +168,6 @@ static constexpr TypePattern SameCharNoLen{CharType, KindCode::sameKind};
 static constexpr TypePattern SameLogical{LogicalType, KindCode::same};
 static constexpr TypePattern SameRelatable{RelatableType, KindCode::same};
 static constexpr TypePattern SameIntrinsic{IntrinsicType, KindCode::same};
-static constexpr TypePattern SameDerivedType{
-    CategorySet{TypeCategory::Derived}, KindCode::same};
 static constexpr TypePattern SameType{AnyType, KindCode::same};
 
 // Match some kind of some INTEGER or REAL type(s); when argument types
@@ -438,6 +436,12 @@ static const IntrinsicInterface genericIntrinsicFunction[]{
             {"shift", AnyInt}},
         SameInt},
     {"dshiftr", {{"i", BOZ}, {"j", SameInt}, {"shift", AnyInt}}, SameInt},
+    {"eoshift",
+        {{"array", SameType, Rank::array},
+            {"shift", AnyInt, Rank::dimRemovedOrScalar},
+            // BOUNDARY= is not optional for non-intrinsic types
+            {"boundary", SameType, Rank::dimRemovedOrScalar}, OptionalDIM},
+        SameType, Rank::conformable, IntrinsicClass::transformationalFunction},
     {"eoshift",
         {{"array", SameIntrinsic, Rank::array},
             {"shift", AnyInt, Rank::dimRemovedOrScalar},
@@ -446,14 +450,6 @@ static const IntrinsicInterface genericIntrinsicFunction[]{
             OptionalDIM},
         SameIntrinsic, Rank::conformable,
         IntrinsicClass::transformationalFunction},
-    {"eoshift",
-        {{"array", SameDerivedType, Rank::array},
-            {"shift", AnyInt, Rank::dimRemovedOrScalar},
-            // BOUNDARY= is not optional for derived types
-            {"boundary", SameDerivedType, Rank::dimRemovedOrScalar},
-            OptionalDIM},
-        SameDerivedType, Rank::conformable,
-        IntrinsicClass::transformationalFunction},
     {"epsilon",
         {{"x", SameReal, Rank::anyOrAssumedRank, Optionality::required,
             common::Intent::In, {ArgFlag::canBeMoldNull}}},
@@ -1943,7 +1939,7 @@ std::optional<SpecificCall> IntrinsicInterface::Match(
       if (!sameArg) {
         sameArg = arg;
       }
-      argOk = type->IsTkLenCompatibleWith(sameArg->GetType().value());
+      argOk = sameArg->GetType().value().IsTkLenCompatibleWith(*type);
       break;
     case KindCode::sameKind:
       if (!sameArg) {
diff --git a/flang/runtime/transformational.cpp b/flang/runtime/transformational.cpp
index b65502933b862f..ab303bdef9b1d1 100644
--- a/flang/runtime/transformational.cpp
+++ b/flang/runtime/transformational.cpp
@@ -46,7 +46,7 @@ class ShiftControl {
           lb_[k++] = shiftDim.LowerBound();
           if (shiftDim.Extent() != source.GetDimension(j).Extent()) {
             terminator_.Crash("%s: on dimension %d, SHIFT= has extent %jd but "
-                              "SOURCE= has extent %jd",
+                              "ARRAY= has extent %jd",
                 which, k, static_cast<std::intmax_t>(shiftDim.Extent()),
                 static_cast<std::intmax_t>(source.GetDimension(j).Extent()));
           }
@@ -460,7 +460,7 @@ void RTDEF(Cshift)(Descriptor &result, const Descriptor &source,
   RUNTIME_CHECK(terminator, rank > 1);
   if (dim < 1 || dim > rank) {
     terminator.Crash(
-        "CSHIFT: DIM=%d must be >= 1 and <= SOURCE= rank %d", dim, rank);
+        "CSHIFT: DIM=%d must be >= 1 and <= ARRAY= rank %d", dim, rank);
   }
   ShiftControl shiftControl{shift, terminator, dim};
   shiftControl.Init(source, "CSHIFT");
@@ -527,7 +527,7 @@ void RTDEF(Eoshift)(Descriptor &result, const Descriptor &source,
   RUNTIME_CHECK(terminator, rank > 1);
   if (dim < 1 || dim > rank) {
     terminator.Crash(
-        "EOSHIFT: DIM=%d must be >= 1 and <= SOURCE= rank %d", dim, rank);
+        "EOSHIFT: DIM=%d must be >= 1 and <= ARRAY= rank %d", dim, rank);
   }
   std::size_t elementLen{
       AllocateResult(result, source, rank, extent, terminator, "EOSHIFT")};
@@ -538,7 +538,7 @@ void RTDEF(Eoshift)(Descriptor &result, const Descriptor &source,
     RUNTIME_CHECK(terminator, boundary->type() == source.type());
     if (boundary->ElementBytes() != elementLen) {
       terminator.Crash("EOSHIFT: BOUNDARY= has element byte length %zd, but "
-                       "SOURCE= has length %zd",
+                       "ARRAY= has length %zd",
           boundary->ElementBytes(), elementLen);
     }
     if (boundaryRank > 0) {
@@ -547,7 +547,7 @@ void RTDEF(Eoshift)(Descriptor &result, const Descriptor &source,
         if (j != dim - 1) {
           if (boundary->GetDimension(k).Extent() != extent[j]) {
             terminator.Crash("EOSHIFT: BOUNDARY= has extent %jd on dimension "
-                             "%d but must conform with extent %jd of SOURCE=",
+                             "%d but must conform with extent %jd of ARRAY=",
                 static_cast<std::intmax_t>(boundary->GetDimension(k).Extent()),
                 k + 1, static_cast<std::intmax_t>(extent[j]));
           }
@@ -611,7 +611,7 @@ void RTDEF(EoshiftVector)(Descriptor &result, const Descriptor &source,
     RUNTIME_CHECK(terminator, boundary->type() == source.type());
     if (boundary->ElementBytes() != elementLen) {
       terminator.Crash("EOSHIFT: BOUNDARY= has element byte length %zd but "
-                       "SOURCE= has length %zd",
+                       "ARRAY= has length %zd",
           boundary->ElementBytes(), elementLen);
     }
   }
@@ -658,7 +658,7 @@ void RTDEF(Pack)(Descriptor &result, const Descriptor &source,
     RUNTIME_CHECK(terminator, vector->rank() == 1);
     RUNTIME_CHECK(terminator, source.type() == vector->type());
     if (source.ElementBytes() != vector->ElementBytes()) {
-      terminator.Crash("PACK: SOURCE= has element byte length %zd, but VECTOR= "
+      terminator.Crash("PACK: ARRAY= has element byte length %zd, but VECTOR= "
                        "has length %zd",
           source.ElementBytes(), vector->ElementBytes());
     }
diff --git a/flang/test/Evaluate/bug115923.f90 b/flang/test/Evaluate/bug115923.f90
new file mode 100644
index 00000000000000..5d2da806bbd36b
--- /dev/null
+++ b/flang/test/Evaluate/bug115923.f90
@@ -0,0 +1,22 @@
+! RUN: %flang_fc1 -fsyntax-only -pedantic %s 2>&1 | FileCheck --allow-empty %s
+! Ensure that EOSHIFT's ARRAY= argument and result can be CLASS(*).
+! CHECK-NOT: error:
+! CHECK: warning: Source of TRANSFER is polymorphic
+! CHECK: warning: Mold of TRANSFER is polymorphic
+program p
+  type base
+    integer j
+  end type
+  type, extends(base) :: extended
+    integer k
+  end type
+  class(base), allocatable :: polyArray(:,:,:)
+  class(*), allocatable :: unlimited(:)
+  allocate(polyArray, source=reshape([(extended(n,n-1),n=1,8)],[2,2,2]))
+  allocate(unlimited, source=[(base(9),n=1,16)])
+  select type (x => eoshift(transfer(polyArray, unlimited), -4, base(-1)))
+    type is (base); print *, 'base', x
+    type is (extended); print *, 'extended?', x
+    class default; print *, 'class default??'
+  end select
+end



More information about the flang-commits mailing list