[flang-commits] [flang] [flang] Handle OPTIONAL polymorphic captured in internal procedures (PR #82042)

via flang-commits flang-commits at lists.llvm.org
Fri Feb 16 13:23:34 PST 2024


https://github.com/jeanPerier created https://github.com/llvm/llvm-project/pull/82042

The current code was doing an unconditional `fir.store %optional_box to %host_link`, which caused a crash when %optional_box is absent because it attempts to copy a descriptor from a null address.

Add code to conditionally do the copy at runtime.

The polymorphic array case with non-default lower bounds can be handled, with a few modifications, by the existing array case that already deals with descriptor arguments, so just use that.

>From 6d57e963f11843967a3bb385932aa6857d2cfc44 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Fri, 16 Feb 2024 12:46:54 -0800
Subject: [PATCH] [flang] Handle OPTIONAL polymorphic captured in internal
 procedure

The current code was doing an unconditional
`fir.store %optional_box to %host_link` which caused a crash when
%optional_box is absent because it is attempting to copy a descriptor
from a null address.

Add code to conditionally do the copy at runtime.
The polymorphic array case with non-default lower bounds can be handled
by the array case that already handles descriptor arguments; use that.
---
 flang/lib/Lower/HostAssociations.cpp          | 64 +++++++++++----
 flang/lib/Optimizer/Builder/MutableBox.cpp    |  2 +-
 .../HLFIR/internal-procedures-polymorphic.f90 | 81 +++++++++++++++++++
 3 files changed, 131 insertions(+), 16 deletions(-)
 create mode 100644 flang/test/Lower/HLFIR/internal-procedures-polymorphic.f90

diff --git a/flang/lib/Lower/HostAssociations.cpp b/flang/lib/Lower/HostAssociations.cpp
index a62f7a7e99b6ff..44cc0e74e3b52a 100644
--- a/flang/lib/Lower/HostAssociations.cpp
+++ b/flang/lib/Lower/HostAssociations.cpp
@@ -247,9 +247,11 @@ class CapturedCharacterScalars
   }
 };
 
-/// Class defining how polymorphic entities are captured in internal procedures.
-/// Polymorphic entities are always boxed as a fir.class box.
-class CapturedPolymorphic : public CapturedSymbols<CapturedPolymorphic> {
+/// Class defining how polymorphic scalar entities are captured in internal
+/// procedures. Polymorphic entities are always boxed as a fir.class box.
+/// Polymorphic arrays are handled directly by CapturedArrays.
+class CapturedPolymorphicScalar
+    : public CapturedSymbols<CapturedPolymorphicScalar> {
 public:
   static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
                             const Fortran::semantics::Symbol &sym) {
@@ -257,19 +259,50 @@ class CapturedPolymorphic : public CapturedSymbols<CapturedPolymorphic> {
   }
   static void instantiateHostTuple(const InstantiateHostTuple &args,
                                    Fortran::lower::AbstractConverter &converter,
-                                   const Fortran::semantics::Symbol &) {
+                                   const Fortran::semantics::Symbol &sym) {
     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+    mlir::Location loc = args.loc;
     mlir::Type typeInTuple = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
     assert(typeInTuple && "addrInTuple must be an address");
     mlir::Value castBox = builder.createConvert(args.loc, typeInTuple,
                                                 fir::getBase(args.hostValue));
-    builder.create<fir::StoreOp>(args.loc, castBox, args.addrInTuple);
+    if (Fortran::semantics::IsOptional(sym)) { // Do not read an absent box.
+      auto isPresent =
+          builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), castBox);
+      builder.genIfThenElse(loc, isPresent)
+          .genThen([&]() { // Present: copy the host descriptor.
+            builder.create<fir::StoreOp>(loc, castBox, args.addrInTuple);
+          })
+          .genElse([&]() { // Absent: store a box with a null base address.
+            mlir::Value null = fir::factory::createUnallocatedBox(
+                builder, loc, typeInTuple,
+                /*nonDeferredParams=*/mlir::ValueRange{});
+            builder.create<fir::StoreOp>(loc, null, args.addrInTuple);
+          })
+          .end();
+    } else {
+      builder.create<fir::StoreOp>(loc, castBox, args.addrInTuple);
+    }
   }
   static void getFromTuple(const GetFromTuple &args,
                            Fortran::lower::AbstractConverter &converter,
                            const Fortran::semantics::Symbol &sym,
                            const Fortran::lower::BoxAnalyzer &ba) {
-    bindCapturedSymbol(sym, args.valueInTuple, converter, args.symMap);
+    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+    mlir::Location loc = args.loc;
+    mlir::Value box = args.valueInTuple;
+    if (Fortran::semantics::IsOptional(sym)) { // Null base addr => absent.
+      auto boxTy = box.getType().cast<fir::BaseBoxType>();
+      auto eleTy = boxTy.getEleTy();
+      if (!fir::isa_ref_type(eleTy))
+        eleTy = builder.getRefType(eleTy);
+      auto addr = builder.create<fir::BoxAddrOp>(loc, eleTy, box);
+      mlir::Value isPresent = builder.genIsNotNullAddr(loc, addr);
+      auto absentBox = builder.create<fir::AbsentOp>(loc, boxTy);
+      box = // Use fir.absent so that fir.is_present works on it.
+          builder.create<mlir::arith::SelectOp>(loc, isPresent, box, absentBox);
+    }
+    bindCapturedSymbol(sym, box, converter, args.symMap);
   }
 };
 
@@ -342,7 +375,12 @@ class CapturedArrays : public CapturedSymbols<CapturedArrays> {
   static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
                             const Fortran::semantics::Symbol &sym) {
     mlir::Type type = converter.genType(sym);
-    assert(type.isa<fir::SequenceType>() && "must be a sequence type");
+    bool isPolymorphic = Fortran::semantics::IsPolymorphic(sym);
+    assert((type.isa<fir::SequenceType>() ||
+            (isPolymorphic && type.isa<fir::ClassType>())) &&
+           "must be a sequence type or a polymorphic fir.class type");
+    if (isPolymorphic) // fir.class is already a box type.
+      return type;
     return fir::BoxType::get(type);
   }
 
@@ -410,13 +448,13 @@ class CapturedArrays : public CapturedSymbols<CapturedArrays> {
                          fir::factory::readBoxValue(builder, loc, boxValue),
                          converter, args.symMap);
     } else {
-      // Keep variable as a fir.box.
+      // Keep variable as a fir.box/fir.class.
       // If this is an optional that is absent, the fir.box needs to be an
       // AbsentOp result, otherwise it will not work properly with IsPresentOp
       // (absent boxes are null descriptor addresses, not descriptors containing
       // a null base address).
       if (Fortran::semantics::IsOptional(sym)) {
-        auto boxTy = box.getType().cast<fir::BoxType>();
+        auto boxTy = box.getType().cast<fir::BaseBoxType>();
         auto eleTy = boxTy.getEleTy();
         if (!fir::isa_ref_type(eleTy))
           eleTy = builder.getRefType(eleTy);
@@ -470,14 +508,10 @@ walkCaptureCategories(T visitor, Fortran::lower::AbstractConverter &converter,
   ba.analyze(sym);
   if (Fortran::semantics::IsAllocatableOrPointer(sym))
     return CapturedAllocatableAndPointer::visit(visitor, converter, sym, ba);
-  if (Fortran::semantics::IsPolymorphic(sym)) {
-    if (ba.isArray() && !ba.lboundIsAllOnes())
-      TODO(converter.genLocation(sym.name()),
-           "polymorphic array with non default lower bound");
-    return CapturedPolymorphic::visit(visitor, converter, sym, ba);
-  }
   if (ba.isArray())
     return CapturedArrays::visit(visitor, converter, sym, ba);
+  if (Fortran::semantics::IsPolymorphic(sym))
+    return CapturedPolymorphicScalar::visit(visitor, converter, sym, ba);
   if (ba.isChar())
     return CapturedCharacterScalars::visit(visitor, converter, sym, ba);
   assert(ba.isTrivial() && "must be trivial scalar");
diff --git a/flang/lib/Optimizer/Builder/MutableBox.cpp b/flang/lib/Optimizer/Builder/MutableBox.cpp
index 4d8860b60915c4..d4012e9c3d9d93 100644
--- a/flang/lib/Optimizer/Builder/MutableBox.cpp
+++ b/flang/lib/Optimizer/Builder/MutableBox.cpp
@@ -674,7 +674,7 @@ void fir::factory::disassociateMutableBox(fir::FirOpBuilder &builder,
     // 7.3.2.3 point 7. The dynamic type of a disassociated pointer is the
     // same as its declared type.
     auto boxTy = box.getBoxTy().dyn_cast<fir::BaseBoxType>();
-    auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(boxTy.getEleTy());
+    auto eleTy = fir::unwrapPassByRefType(boxTy.getEleTy());
     mlir::Type derivedType = fir::getDerivedType(eleTy);
     if (auto recTy = derivedType.dyn_cast<fir::RecordType>()) {
       fir::runtime::genNullifyDerivedType(builder, loc, box.getAddr(), recTy,
diff --git a/flang/test/Lower/HLFIR/internal-procedures-polymorphic.f90 b/flang/test/Lower/HLFIR/internal-procedures-polymorphic.f90
new file mode 100644
index 00000000000000..8645488290d715
--- /dev/null
+++ b/flang/test/Lower/HLFIR/internal-procedures-polymorphic.f90
@@ -0,0 +1,81 @@
+! Test lowering of internal procedure capturing OPTIONAL polymorphic
+! objects.
+! RUN: bbc -emit-hlfir --polymorphic-type -o - %s -I nw | FileCheck %s
+
+
+module captured_optional_polymorphic
+  type sometype
+  end type
+contains
+subroutine test(x, y)
+  class(sometype), optional :: x     ! OPTIONAL polymorphic scalar
+  class(sometype), optional :: y(2:) ! OPTIONAL polymorphic array, lbound 2
+  call internal()
+contains
+  subroutine internal() ! Captures x and y from the host.
+    if (present(x).and.present(y)) then
+      print *, same_type_as(x, y)
+    end if
+  end subroutine
+end
+end module
+
+! CHECK-LABEL:   func.func @_QMcaptured_optional_polymorphicPtest(
+! CHECK:           %[[VAL_2:.*]]:2 = hlfir.declare{{.*}}Ex
+! CHECK:           %[[VAL_3:.*]] = arith.constant 2 : i64
+! CHECK:           %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (i64) -> index
+! CHECK:           %[[VAL_5:.*]] = fir.shift %[[VAL_4]] : (index) -> !fir.shift<1>
+! CHECK:           %[[VAL_6:.*]]:2 = hlfir.declare{{.*}}Ey
+! CHECK:           %[[VAL_7:.*]] = fir.alloca tuple<!fir.class<!fir.type<_QMcaptured_optional_polymorphicTsometype>>, !fir.class<!fir.array<?x!fir.type<_QMcaptured_optional_polymorphicTsometype>>>>
+! CHECK:           %[[VAL_8:.*]] = arith.constant 0 : i32
+! CHECK:           %[[VAL_9:.*]] = fir.coordinate_of %[[VAL_7]], %[[VAL_8]]
+! CHECK:           %[[VAL_10:.*]] = fir.is_present %[[VAL_2]]#1 : (!fir.class<!fir.type<_QMcaptured_optional_polymorphicTsometype>>) -> i1
+! CHECK:           fir.if %[[VAL_10]] {
+! CHECK:             fir.store %[[VAL_2]]#1 to %[[VAL_9]] : !fir.ref<!fir.class<!fir.type<_QMcaptured_optional_polymorphicTsometype>>>
+! CHECK:           } else {
+! CHECK:             %[[VAL_11:.*]] = fir.zero_bits !fir.ref<!fir.type<_QMcaptured_optional_polymorphicTsometype>>
+! CHECK:             %[[VAL_12:.*]] = fir.embox %[[VAL_11]] : (!fir.ref<!fir.type<_QMcaptured_optional_polymorphicTsometype>>) -> !fir.class<!fir.type<_QMcaptured_optional_polymorphicTsometype>>
+! CHECK:             fir.store %[[VAL_12]] to %[[VAL_9]] : !fir.ref<!fir.class<!fir.type<_QMcaptured_optional_polymorphicTsometype>>>
+! CHECK:           }
+! CHECK:           %[[VAL_13:.*]] = arith.constant 1 : i32
+! CHECK:           %[[VAL_14:.*]] = fir.coordinate_of %[[VAL_7]], %[[VAL_13]]
+! CHECK:           %[[VAL_15:.*]] = fir.is_present %[[VAL_6]]#1 : (!fir.class<!fir.array<?x!fir.type<_QMcaptured_optional_polymorphicTsometype>>>) -> i1
+! CHECK:           fir.if %[[VAL_15]] {
+! CHECK:             %[[VAL_16:.*]] = fir.shift %[[VAL_4]] : (index) -> !fir.shift<1>
+! CHECK:             %[[VAL_17:.*]] = fir.rebox %[[VAL_6]]#1(%[[VAL_16]]) : (!fir.class<!fir.array<?x!fir.type<_QMcaptured_optional_polymorphicTsometype>>>, !fir.shift<1>) -> !fir.class<!fir.array<?x!fir.type<_QMcaptured_optional_polymorphicTsometype>>>
+! CHECK:             fir.store %[[VAL_17]] to %[[VAL_14]] : !fir.ref<!fir.class<!fir.array<?x!fir.type<_QMcaptured_optional_polymorphicTsometype>>>>
+! CHECK:           } else {
+! CHECK:             %[[VAL_18:.*]] = fir.type_desc !fir.type<_QMcaptured_optional_polymorphicTsometype>
+! CHECK:             %[[VAL_19:.*]] = fir.convert %[[VAL_14]] : (!fir.ref<!fir.class<!fir.array<?x!fir.type<_QMcaptured_optional_polymorphicTsometype>>>>) -> !fir.ref<!fir.box<none>>
+! CHECK:             %[[VAL_20:.*]] = fir.convert %[[VAL_18]] : (!fir.tdesc<!fir.type<_QMcaptured_optional_polymorphicTsometype>>) -> !fir.ref<none>
+! CHECK:             %[[VAL_21:.*]] = arith.constant 1 : i32
+! CHECK:             %[[VAL_22:.*]] = arith.constant 0 : i32
+! CHECK:             %[[VAL_23:.*]] = fir.call @_FortranAPointerNullifyDerived(%[[VAL_19]], %[[VAL_20]], %[[VAL_21]], %[[VAL_22]]) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.ref<none>, i32, i32) -> none
+! CHECK:           }
+! CHECK:           fir.call @_QMcaptured_optional_polymorphicFtestPinternal(%[[VAL_7]])
+
+! CHECK-LABEL: func.func{{.*}} @_QMcaptured_optional_polymorphicFtestPinternal(
+! CHECK-SAME:      %[[VAL_0:.*]]: !fir.ref<tuple<{{.*}}>>
+! CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i32
+! CHECK:           %[[VAL_2:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_1]]
+! CHECK:           %[[VAL_3:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.class<!fir.type<_QMcaptured_optional_polymorphicTsometype>>>
+! CHECK:           %[[VAL_4:.*]] = fir.box_addr %[[VAL_3]] : (!fir.class<!fir.type<_QMcaptured_optional_polymorphicTsometype>>) -> !fir.ref<!fir.type<_QMcaptured_optional_polymorphicTsometype>>
+! CHECK:           %[[VAL_5:.*]] = fir.convert %[[VAL_4]] : (!fir.ref<!fir.type<_QMcaptured_optional_polymorphicTsometype>>) -> i64
+! CHECK:           %[[VAL_6:.*]] = arith.constant 0 : i64
+! CHECK:           %[[VAL_7:.*]] = arith.cmpi ne, %[[VAL_5]], %[[VAL_6]] : i64
+! CHECK:           %[[VAL_8:.*]] = fir.absent !fir.class<!fir.type<_QMcaptured_optional_polymorphicTsometype>>
+! CHECK:           %[[VAL_9:.*]] = arith.select %[[VAL_7]], %[[VAL_3]], %[[VAL_8]] : !fir.class<!fir.type<_QMcaptured_optional_polymorphicTsometype>>
+! CHECK:           %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_9]] {fortran_attrs = #fir.var_attrs<optional, host_assoc>, {{.*}}Ex
+! CHECK:           %[[VAL_11:.*]] = arith.constant 1 : i32
+! CHECK:           %[[VAL_12:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_11]]
+! CHECK:           %[[VAL_13:.*]] = fir.load %[[VAL_12]] : !fir.ref<!fir.class<!fir.array<?x!fir.type<_QMcaptured_optional_polymorphicTsometype>>>>
+! CHECK:           %[[VAL_14:.*]] = arith.constant 0 : index
+! CHECK:           %[[VAL_15:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_14]]
+! CHECK:           %[[VAL_16:.*]] = fir.box_addr %[[VAL_13]]
+! CHECK:           %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (!fir.ref<!fir.array<?x!fir.type<_QMcaptured_optional_polymorphicTsometype>>>) -> i64
+! CHECK:           %[[VAL_18:.*]] = arith.constant 0 : i64
+! CHECK:           %[[VAL_19:.*]] = arith.cmpi ne, %[[VAL_17]], %[[VAL_18]] : i64
+! CHECK:           %[[VAL_20:.*]] = fir.absent !fir.class<!fir.array<?x!fir.type<_QMcaptured_optional_polymorphicTsometype>>>
+! CHECK:           %[[VAL_21:.*]] = arith.select %[[VAL_19]], %[[VAL_13]], %[[VAL_20]] : !fir.class<!fir.array<?x!fir.type<_QMcaptured_optional_polymorphicTsometype>>>
+! CHECK:           %[[VAL_22:.*]] = fir.shift %[[VAL_15]]#0 : (index) -> !fir.shift<1>
+! CHECK:           %[[VAL_23:.*]]:2 = hlfir.declare %[[VAL_21]](%[[VAL_22]]) {fortran_attrs = #fir.var_attrs<optional, host_assoc>, {{.*}}Ey



More information about the flang-commits mailing list