[flang-commits] [flang] [flang] Lower sequence associated argument passed by descriptor (PR #85696)

via flang-commits flang-commits at lists.llvm.org
Mon Mar 18 13:41:25 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-flang-semantics

Author: None (jeanPerier)

<details>
<summary>Changes</summary>

The current lowering did not handle sequence-associated arguments passed by descriptor. This case is special because sequence association implies that the actual and dummy arguments need not agree in rank and shape.
Usually, arguments that can be sequence associated are passed by raw address, and the shape mismatch is transparent. But there are three cases of explicit-shape and assumed-size arrays passed by descriptor:
 - polymorphic arguments
 - BIND(C) assumed-length character arguments (F'2023 18.3.7 (5))
 - length-parameterized derived types (TBD)

The callee side expects a descriptor containing the dummy rank and shape, which was not what the caller provided. This patch fixes that by evaluating the dummy shape on the caller side using the interface (which has to be available when arguments are passed by descriptor).
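
As a rough illustration of the remapping idea, here is a toy model in plain C++. It is only a sketch: `ToyDescriptor`, `elementCount`, and `remapToDummyShape` are hypothetical names invented for this example and are not flang or FIR APIs. It assumes the actual argument is contiguous and ignores assumed-size dummies.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-in for a descriptor: a base address plus per-dimension extents.
// This is NOT flang's runtime descriptor layout (assumption for illustration).
struct ToyDescriptor {
  const double *base;
  std::vector<std::int64_t> extents;
};

// Number of elements described; an empty extent list models a scalar.
std::int64_t elementCount(const ToyDescriptor &d) {
  std::int64_t n = 1;
  for (std::int64_t e : d.extents)
    n *= e;
  return n;
}

// Caller-side remap: keep the actual argument's base address, but describe
// the storage with the dummy's extents taken from the callee interface.
ToyDescriptor remapToDummyShape(const ToyDescriptor &actual,
                                const std::vector<std::int64_t> &dummyExtents) {
  ToyDescriptor remapped{actual.base, dummyExtents};
  // The dummy must not extend past the actual's element sequence.
  assert(elementCount(remapped) <= elementCount(actual));
  return remapped;
}
```

For instance, an actual argument declared with shape (3, 4) and passed to a dummy declared with shape (10) would be described to the callee as a rank-1 array of 10 elements that starts at the same address.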

---

Patch is 54.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85696.diff


7 Files Affected:

- (modified) flang/include/flang/Evaluate/characteristics.h (+8) 
- (modified) flang/include/flang/Lower/CallInterface.h (+35-5) 
- (modified) flang/include/flang/Lower/ConvertVariable.h (+10-3) 
- (modified) flang/lib/Lower/CallInterface.cpp (+110-24) 
- (modified) flang/lib/Lower/ConvertCall.cpp (+148-10) 
- (modified) flang/lib/Lower/ConvertVariable.cpp (+22-7) 
- (added) flang/test/Lower/HLFIR/call-sequence-associated-descriptors.f90 (+309) 


``````````diff
diff --git a/flang/include/flang/Evaluate/characteristics.h b/flang/include/flang/Evaluate/characteristics.h
index f2f37866ecde86..82c31c0c404301 100644
--- a/flang/include/flang/Evaluate/characteristics.h
+++ b/flang/include/flang/Evaluate/characteristics.h
@@ -177,6 +177,14 @@ class TypeAndShape {
   int corank() const { return corank_; }
 
   int Rank() const { return GetRank(shape_); }
+
+  // Can sequence association apply to this argument?
+  bool CanBeSequenceAssociated() const {
+    constexpr Attrs notAssumedOrExplicitShape{
+        ~Attrs{Attr::AssumedSize, Attr::Coarray}};
+    return Rank() > 0 && (attrs() & notAssumedOrExplicitShape).none();
+  }
+
   bool IsCompatibleWith(parser::ContextualMessages &, const TypeAndShape &that,
       const char *thisIs = "pointer", const char *thatIs = "target",
       bool omitShapeConformanceCheck = false,
diff --git a/flang/include/flang/Lower/CallInterface.h b/flang/include/flang/Lower/CallInterface.h
index e77ac4e179ba86..fbffeb8d4938c8 100644
--- a/flang/include/flang/Lower/CallInterface.h
+++ b/flang/include/flang/Lower/CallInterface.h
@@ -174,6 +174,12 @@ class CallInterface {
     /// May the dummy argument require INTENT(OUT) finalization
     /// on entry to the invoked procedure? Provides conservative answer.
     bool mayRequireIntentoutFinalization() const;
+    /// Is the dummy argument an explicit-shape or assumed-size array that
+    /// must be passed by descriptor? Sequence association imply the actual
+    /// argument shape/rank may differ with the dummy shape/rank (see F'2023
+    /// section 15.5.2.12), so care is needed when creating the descriptor
+    /// for the dummy argument.
+    bool isSequenceAssociatedDescriptor() const;
     /// How entity is passed by.
     PassEntityBy passBy;
     /// What is the entity (SymbolRef for callee/ActualArgument* for caller)
@@ -273,8 +279,6 @@ class CallerInterface : public CallInterface<CallerInterface> {
     actualInputs.resize(getNumFIRArguments());
   }
 
-  using ExprVisitor = std::function<void(evaluate::Expr<evaluate::SomeType>)>;
-
   /// CRTP callbacks
   bool hasAlternateReturns() const;
   std::string getMangledName() const;
@@ -312,12 +316,21 @@ class CallerInterface : public CallInterface<CallerInterface> {
   /// procedure.
   const Fortran::semantics::Symbol *getProcedureSymbol() const;
 
+  /// Return the dummy argument symbol if this is a call to a user
+  /// defined procedure with explicit interface. Returns nullptr if there
+  /// is no user defined explicit interface.
+  const Fortran::semantics::Symbol *
+  getDummySymbol(const PassedEntity &entity) const;
+
   /// Helpers to place the lowered arguments at the right place once they
   /// have been lowered.
   void placeInput(const PassedEntity &passedEntity, mlir::Value arg);
   void placeAddressAndLengthInput(const PassedEntity &passedEntity,
                                   mlir::Value addr, mlir::Value len);
 
+  /// Get lowered argument FIR argument given the Fortran argument.
+  mlir::Value getInput(const PassedEntity &passedEntity);
+
   /// If this is a call to a procedure pointer or dummy, returns the related
   /// procedure designator. Nullptr otherwise.
   const Fortran::evaluate::ProcedureDesignator *getIfIndirectCall() const;
@@ -333,13 +346,27 @@ class CallerInterface : public CallInterface<CallerInterface> {
   /// the result specification expressions (extents and lengths) ? If needed,
   /// this mapping must be done after argument lowering, and before the call
   /// itself.
-  bool mustMapInterfaceSymbols() const;
+  bool mustMapInterfaceSymbolsForResult() const;
+  /// Does the caller must map function interface symbols in order to evaluate
+  /// the specification expressions of a given dummy argument?
+  bool mustMapInterfaceSymbolsForDummyArgument(const PassedEntity &) const;
+
+  /// Visitor for specification expression. Boolean indicate the specification
+  /// expression is for the last extent of an assumed size array.
+  using ExprVisitor =
+      std::function<void(evaluate::Expr<evaluate::SomeType>, bool)>;
 
   /// Walk the result non-deferred extent specification expressions.
-  void walkResultExtents(ExprVisitor) const;
+  void walkResultExtents(const ExprVisitor &) const;
 
   /// Walk the result non-deferred length specification expressions.
-  void walkResultLengths(ExprVisitor) const;
+  void walkResultLengths(const ExprVisitor &) const;
+  /// Walk non-deferred extent specification expressions of a dummy argument.
+  void walkDummyArgumentExtents(const PassedEntity &,
+                                const ExprVisitor &) const;
+  /// Walk non-deferred length specification expressions of a dummy argument.
+  void walkDummyArgumentLengths(const PassedEntity &,
+                                const ExprVisitor &) const;
 
   /// Get the mlir::Value that is passed as argument \p sym of the function
   /// being called. The arguments must have been placed before calling this
@@ -355,6 +382,9 @@ class CallerInterface : public CallInterface<CallerInterface> {
   /// returns the storage type.
   mlir::Type getResultStorageType() const;
 
+  /// Return FIR type of argument.
+  mlir::Type getDummyArgumentType(const PassedEntity &) const;
+
   // Copy of base implementation.
   static constexpr bool hasHostAssociated() { return false; }
   mlir::Type getHostAssociatedTy() const {
diff --git a/flang/include/flang/Lower/ConvertVariable.h b/flang/include/flang/Lower/ConvertVariable.h
index b13bb412f0f3e7..ab30e317d1d9d4 100644
--- a/flang/include/flang/Lower/ConvertVariable.h
+++ b/flang/include/flang/Lower/ConvertVariable.h
@@ -93,9 +93,16 @@ void mapSymbolAttributes(AbstractConverter &, const semantics::SymbolRef &,
 /// Instantiate the variables that appear in the specification expressions
 /// of the result of a function call. The instantiated variables are added
 /// to \p symMap.
-void mapCallInterfaceSymbols(AbstractConverter &,
-                             const Fortran::lower::CallerInterface &caller,
-                             SymMap &symMap);
+void mapCallInterfaceSymbolsForResult(
+    AbstractConverter &, const Fortran::lower::CallerInterface &caller,
+    SymMap &symMap);
+
+/// Instantiate the variables that appear in the specification expressions
+/// of a dummy argument of a procedure call. The instantiated variables are
+/// added to \p symMap.
+void mapCallInterfaceSymbolsForDummyArgument(
+    AbstractConverter &, const Fortran::lower::CallerInterface &caller,
+    SymMap &symMap, const Fortran::semantics::Symbol &dummySymbol);
 
 // TODO: consider saving the initial expression symbol dependence analysis in
 // in the PFT variable and dealing with the dependent symbols instantiation in
diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index 6b71aabf7fdc89..2f95d53c383b9a 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -310,21 +310,22 @@ bool Fortran::lower::CallerInterface::verifyActualInputs() const {
   return true;
 }
 
-void Fortran::lower::CallerInterface::walkResultLengths(
-    ExprVisitor visitor) const {
-  assert(characteristic && "characteristic was not computed");
-  const Fortran::evaluate::characteristics::FunctionResult &result =
-      characteristic->functionResult.value();
-  const Fortran::evaluate::characteristics::TypeAndShape *typeAndShape =
-      result.GetTypeAndShape();
-  assert(typeAndShape && "no result type");
-  Fortran::evaluate::DynamicType dynamicType = typeAndShape->type();
-  // Visit result length specification expressions that are explicit.
+mlir::Value
+Fortran::lower::CallerInterface::getInput(const PassedEntity &passedEntity) {
+  return actualInputs[passedEntity.firArgument];
+}
+
+static void walkLengths(
+    const Fortran::evaluate::characteristics::TypeAndShape &typeAndShape,
+    const Fortran::lower::CallerInterface::ExprVisitor &visitor,
+    Fortran::lower::AbstractConverter &converter) {
+  Fortran::evaluate::DynamicType dynamicType = typeAndShape.type();
+  // Visit length specification expressions that are explicit.
   if (dynamicType.category() == Fortran::common::TypeCategory::Character) {
     if (std::optional<Fortran::evaluate::ExtentExpr> length =
             dynamicType.GetCharLength())
-      visitor(toEvExpr(*length));
-  } else if (dynamicType.category() == common::TypeCategory::Derived &&
+      visitor(toEvExpr(*length), /*assumedSize=*/false);
+  } else if (dynamicType.category() == Fortran::common::TypeCategory::Derived &&
              !dynamicType.IsUnlimitedPolymorphic()) {
     const Fortran::semantics::DerivedTypeSpec &derivedTypeSpec =
         dynamicType.GetDerivedTypeSpec();
@@ -334,11 +335,33 @@ void Fortran::lower::CallerInterface::walkResultLengths(
   }
 }
 
+void Fortran::lower::CallerInterface::walkResultLengths(
+    const ExprVisitor &visitor) const {
+  assert(characteristic && "characteristic was not computed");
+  const Fortran::evaluate::characteristics::FunctionResult &result =
+      characteristic->functionResult.value();
+  const Fortran::evaluate::characteristics::TypeAndShape *typeAndShape =
+      result.GetTypeAndShape();
+  assert(typeAndShape && "no result type");
+  return walkLengths(*typeAndShape, visitor, converter);
+}
+
+void Fortran::lower::CallerInterface::walkDummyArgumentLengths(
+    const PassedEntity &passedEntity, const ExprVisitor &visitor) const {
+  if (!passedEntity.characteristics)
+    return;
+  if (const auto *dummy =
+          std::get_if<Fortran::evaluate::characteristics::DummyDataObject>(
+              &passedEntity.characteristics->u))
+    walkLengths(dummy->type, visitor, converter);
+}
+
 // Compute extent expr from shapeSpec of an explicit shape.
-// TODO: Allow evaluate shape analysis to work in a mode where it disregards
-// the non-constant aspects when building the shape to avoid having this here.
 static Fortran::evaluate::ExtentExpr
 getExtentExpr(const Fortran::semantics::ShapeSpec &shapeSpec) {
+  if (shapeSpec.ubound().isStar())
+    // F'2023 18.5.3 point 5.
+    return Fortran::evaluate::ExtentExpr{-1};
   const auto &ubound = shapeSpec.ubound().GetExplicit();
   const auto &lbound = shapeSpec.lbound().GetExplicit();
   assert(lbound && ubound && "shape must be explicit");
@@ -346,20 +369,27 @@ getExtentExpr(const Fortran::semantics::ShapeSpec &shapeSpec) {
          Fortran::evaluate::ExtentExpr{1};
 }
 
+static void
+walkExtents(const Fortran::semantics::Symbol &symbol,
+            const Fortran::lower::CallerInterface::ExprVisitor &visitor) {
+  if (const auto *objectDetails =
+          symbol.detailsIf<Fortran::semantics::ObjectEntityDetails>())
+    if (objectDetails->shape().IsExplicitShape() ||
+        Fortran::semantics::IsAssumedSizeArray(symbol))
+      for (const Fortran::semantics::ShapeSpec &shapeSpec :
+           objectDetails->shape())
+        visitor(Fortran::evaluate::AsGenericExpr(getExtentExpr(shapeSpec)),
+                /*assumedSize=*/shapeSpec.ubound().isStar());
+}
+
 void Fortran::lower::CallerInterface::walkResultExtents(
-    ExprVisitor visitor) const {
+    const ExprVisitor &visitor) const {
   // Walk directly the result symbol shape (the characteristic shape may contain
   // descriptor inquiries to it that would fail to lower on the caller side).
   const Fortran::semantics::SubprogramDetails *interfaceDetails =
       getInterfaceDetails();
   if (interfaceDetails) {
-    const Fortran::semantics::Symbol &result = interfaceDetails->result();
-    if (const auto *objectDetails =
-            result.detailsIf<Fortran::semantics::ObjectEntityDetails>())
-      if (objectDetails->shape().IsExplicitShape())
-        for (const Fortran::semantics::ShapeSpec &shapeSpec :
-             objectDetails->shape())
-          visitor(Fortran::evaluate::AsGenericExpr(getExtentExpr(shapeSpec)));
+    walkExtents(interfaceDetails->result(), visitor);
   } else {
     if (procRef.Rank() != 0)
       fir::emitFatalError(
@@ -368,7 +398,18 @@ void Fortran::lower::CallerInterface::walkResultExtents(
   }
 }
 
-bool Fortran::lower::CallerInterface::mustMapInterfaceSymbols() const {
+void Fortran::lower::CallerInterface::walkDummyArgumentExtents(
+    const PassedEntity &passedEntity, const ExprVisitor &visitor) const {
+  const Fortran::semantics::SubprogramDetails *interfaceDetails =
+      getInterfaceDetails();
+  if (!interfaceDetails)
+    return;
+  const Fortran::semantics::Symbol *dummy = getDummySymbol(passedEntity);
+  assert(dummy && "dummy symbol was not set");
+  walkExtents(*dummy, visitor);
+}
+
+bool Fortran::lower::CallerInterface::mustMapInterfaceSymbolsForResult() const {
   assert(characteristic && "characteristic was not computed");
   const std::optional<Fortran::evaluate::characteristics::FunctionResult>
       &result = characteristic->functionResult;
@@ -376,7 +417,7 @@ bool Fortran::lower::CallerInterface::mustMapInterfaceSymbols() const {
       !getInterfaceDetails() || result->IsProcedurePointer())
     return false;
   bool allResultSpecExprConstant = true;
-  auto visitor = [&](const Fortran::lower::SomeExpr &e) {
+  auto visitor = [&](const Fortran::lower::SomeExpr &e, bool) {
     allResultSpecExprConstant &= Fortran::evaluate::IsConstantExpr(e);
   };
   walkResultLengths(visitor);
@@ -384,6 +425,17 @@ bool Fortran::lower::CallerInterface::mustMapInterfaceSymbols() const {
   return !allResultSpecExprConstant;
 }
 
+bool Fortran::lower::CallerInterface::mustMapInterfaceSymbolsForDummyArgument(
+    const PassedEntity &arg) const {
+  bool allResultSpecExprConstant = true;
+  auto visitor = [&](const Fortran::lower::SomeExpr &e, bool) {
+    allResultSpecExprConstant &= Fortran::evaluate::IsConstantExpr(e);
+  };
+  walkDummyArgumentLengths(arg, visitor);
+  walkDummyArgumentExtents(arg, visitor);
+  return !allResultSpecExprConstant;
+}
+
 mlir::Value Fortran::lower::CallerInterface::getArgumentValue(
     const semantics::Symbol &sym) const {
   mlir::Location loc = converter.getCurrentLocation();
@@ -401,6 +453,24 @@ mlir::Value Fortran::lower::CallerInterface::getArgumentValue(
   return actualInputs[mlirArgIndex];
 }
 
+const Fortran::semantics::Symbol *
+Fortran::lower::CallerInterface::getDummySymbol(
+    const PassedEntity &passedEntity) const {
+  const Fortran::semantics::SubprogramDetails *ifaceDetails =
+      getInterfaceDetails();
+  if (!ifaceDetails)
+    return nullptr;
+  std::size_t argPosition = 0;
+  for (const auto &arg : getPassedArguments()) {
+    if (&arg == &passedEntity)
+      break;
+    ++argPosition;
+  }
+  if (argPosition >= ifaceDetails->dummyArgs().size())
+    return nullptr;
+  return ifaceDetails->dummyArgs()[argPosition];
+}
+
 mlir::Type Fortran::lower::CallerInterface::getResultStorageType() const {
   if (passedResult)
     return fir::dyn_cast_ptrEleTy(inputs[passedResult->firArgument].type);
@@ -408,6 +478,11 @@ mlir::Type Fortran::lower::CallerInterface::getResultStorageType() const {
   return outputs[0].type;
 }
 
+mlir::Type Fortran::lower::CallerInterface::getDummyArgumentType(
+    const PassedEntity &passedEntity) const {
+  return inputs[passedEntity.firArgument].type;
+}
+
 const Fortran::semantics::Symbol &
 Fortran::lower::CallerInterface::getResultSymbol() const {
   mlir::Location loc = converter.getCurrentLocation();
@@ -1387,6 +1462,17 @@ bool Fortran::lower::CallInterface<
   return Fortran::semantics::IsFinalizable(*derived);
 }
 
+template <typename T>
+bool Fortran::lower::CallInterface<
+    T>::PassedEntity::isSequenceAssociatedDescriptor() const {
+  if (!characteristics || passBy != PassEntityBy::Box)
+    return false;
+  const auto *dummy =
+      std::get_if<Fortran::evaluate::characteristics::DummyDataObject>(
+          &characteristics->u);
+  return dummy && dummy->type.CanBeSequenceAssociated();
+}
+
 template <typename T>
 void Fortran::lower::CallInterface<T>::determineInterface(
     bool isImplicit,
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 95569337a06e90..6eba243c237cf2 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -164,6 +164,125 @@ static mlir::Value readDim3Value(fir::FirOpBuilder &builder, mlir::Location loc,
   return hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{designate});
 }
 
+static mlir::Value remapActualToDummyDescriptor(
+    mlir::Location loc, Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::SymMap &symMap,
+    const Fortran::lower::CallerInterface::PassedEntity &arg,
+    Fortran::lower::CallerInterface &caller, bool isBindcCall) {
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  mlir::IndexType idxTy = builder.getIndexType();
+  mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
+  Fortran::lower::StatementContext localStmtCtx;
+  auto lowerSpecExpr = [&](const auto &expr,
+                           bool isAssumedSizeExtent) -> mlir::Value {
+    mlir::Value convertExpr = builder.createConvert(
+        loc, idxTy, fir::getBase(converter.genExprValue(expr, localStmtCtx)));
+    if (isAssumedSizeExtent)
+      return convertExpr;
+    return fir::factory::genMaxWithZero(builder, loc, convertExpr);
+  };
+  bool mapSymbols = caller.mustMapInterfaceSymbolsForDummyArgument(arg);
+  if (mapSymbols) {
+    symMap.pushScope();
+    const Fortran::semantics::Symbol *sym = caller.getDummySymbol(arg);
+    assert(sym && "call must have explicit interface to map interface symbols");
+    Fortran::lower::mapCallInterfaceSymbolsForDummyArgument(converter, caller,
+                                                            symMap, *sym);
+  }
+  llvm::SmallVector<mlir::Value> extents;
+  llvm::SmallVector<mlir::Value> lengths;
+  mlir::Type dummyBoxType = caller.getDummyArgumentType(arg);
+  mlir::Type dummyBaseType = fir::unwrapPassByRefType(dummyBoxType);
+  if (dummyBaseType.isa<fir::SequenceType>())
+    caller.walkDummyArgumentExtents(
+        arg, [&](const Fortran::lower::SomeExpr &e, bool isAssumedSizeExtent) {
+          extents.emplace_back(lowerSpecExpr(e, isAssumedSizeExtent));
+        });
+  mlir::Value shape;
+  if (!extents.empty()) {
+    if (isBindcCall) {
+      // Preserve zero lower bounds (see F'2023 18.5.3).
+      llvm::SmallVector<mlir::Value> lowerBounds(extents.size(), zero);
+      shape = builder.genShape(loc, lowerBounds, extents);
+    } else {
+      shape = builder.genShape(loc, extents);
+    }
+  }
+
+  hlfir::Entity explicitArgument = hlfir::Entity{caller.getInput(arg)};
+  mlir::Type dummyElementType = fir::unwrapSequenceType(dummyBaseType);
+  if (auto recType = llvm::dyn_cast<fir::RecordType>(dummyElementType))
+    if (recType.getNumLenParams() > 0)
+      TODO(loc, "sequence association of length parameterized derived type "
+                "dummy arguments");
+  if (fir::isa_char(dummyElementType))
+    lengths.emplace_back(hlfir::genCharLength(loc, builder, explicitArgument));
+  mlir::Value baseAddr =
+      hlfir::genVariableRawAddress(loc, builder, explicitArgument);
+  baseAddr = builder.createConvert(loc, fir::ReferenceType::get(dummyBaseType),
+                                   baseAddr);
+  mlir::Value mold;
+  if (fir::isPolymorphicType(dummyBoxType))
+    mold = explicitArgument;
+  mlir::Value remapped =
+      builder.create<fir::EmboxOp>(loc, dummyBoxType, baseAddr, shape,
+                                   /*slice=*/mlir::Value{}, lengths, mold);
+  if (mapSymbols)
+    symMap.popScope();
+  return remapped;
+}
+
+/// Create a descriptor for sequenced associated descriptor that are passed
+/// by descriptor. Sequence association (F'2023 15.5.2.12) implies that the
+/// dummy shape and rank need to not be the same as the actual argument. This
+/// helper creates a descriptor based on the dummy shape and rank (sequence
+/// association can only happen with explicit and assumed-size array) so that it
+/// is safe to assume the rank of the incoming descriptor inside the callee.
+/// This helper must be called once all the actual arguments have been lowered
+/// and placed inside "caller". Copy-in/copy-out must already have been
+/// generated if needed using the actual argument shape (the dummy shape may be
+/// assumed-size).
+static void remapActualToDummyDescriptors(
+    mlir::Location loc, Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::SymMap &symMap,
+    const Fortran::lower::Prepa...
[truncated]

``````````
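
The extent rule visible in the diff above (`getExtentExpr` together with the `lowerSpecExpr` lambda) can be sketched on plain integers. This is a simplified model, not the patch's code: `ToyBounds` and `dummyExtent` are hypothetical names, and the real implementation builds Fortran expressions and FIR values rather than computing integers directly.

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for one dimension of a dummy's declared bounds.
// An empty upper bound models the `*` of an assumed-size array.
struct ToyBounds {
  std::int64_t lower;
  std::optional<std::int64_t> upper;
};

std::int64_t dummyExtent(const ToyBounds &b) {
  if (!b.upper)
    return -1; // assumed-size last dimension: the extent is reported as -1
  // Explicit shape: ubound - lbound + 1, clamped so that an empty
  // dimension has extent 0 rather than a negative value.
  return std::max<std::int64_t>(*b.upper - b.lower + 1, 0);
}
```

The -1 extent is left unclamped on purpose, mirroring the `isAssumedSizeExtent` flag in the patch, which skips the max-with-zero step for the last extent of an assumed-size dummy.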

</details>


https://github.com/llvm/llvm-project/pull/85696

