[flang-commits] [flang] da7c77b - [flang] Handle lowering arguments in subroutine and function

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Wed Feb 16 11:28:15 PST 2022


Author: Valentin Clement
Date: 2022-02-16T20:28:07+01:00
New Revision: da7c77b82c217592cc14f5b5a3c6a9e6741896af

URL: https://github.com/llvm/llvm-project/commit/da7c77b82c217592cc14f5b5a3c6a9e6741896af
DIFF: https://github.com/llvm/llvm-project/commit/da7c77b82c217592cc14f5b5a3c6a9e6741896af.diff

LOG: [flang] Handle lowering arguments in subroutine and function

This patch adds the infrastructure needed to lower
arguments in functions and subroutines.

This patch is part of the upstreaming effort from the fir-dev branch.

Reviewed By: PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D119957

Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>
Co-authored-by: Jean Perier <jperier at nvidia.com>

Added: 
    flang/test/Lower/arguments.f90

Modified: 
    flang/include/flang/Lower/CallInterface.h
    flang/lib/Lower/Bridge.cpp
    flang/lib/Lower/CallInterface.cpp
    flang/lib/Lower/ConvertVariable.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Lower/CallInterface.h b/flang/include/flang/Lower/CallInterface.h
index a8f08ac4528d2..896fde850e7ab 100644
--- a/flang/include/flang/Lower/CallInterface.h
+++ b/flang/include/flang/Lower/CallInterface.h
@@ -85,6 +85,26 @@ class CallInterface {
   friend CallInterfaceImpl<T>;
 
 public:
+  /// Enum the different ways an entity can be passed-by
+  enum class PassEntityBy {
+    BaseAddress,
+    BoxChar,
+    // passing a read-only descriptor
+    Box,
+    // passing a writable descriptor
+    MutableBox,
+    AddressAndLength,
+    /// Value means passed by value at the mlir level, it is not necessarily
+    /// implied by Fortran Value attribute.
+    Value,
+    /// ValueAttribute means dummy has the the Fortran VALUE attribute.
+    BaseAddressValueAttribute,
+    CharBoxValueAttribute, // BoxChar with VALUE
+    // Passing a character procedure as a <procedure address, result length>
+    // tuple.
+    CharProcTuple
+  };
+
   /// Different properties of an entity that can be passed/returned.
   /// One-to-One mapping with PassEntityBy but for
   /// PassEntityBy::AddressAndLength that has two properties.
@@ -105,8 +125,10 @@ class CallInterface {
   /// FirPlaceHolder are place holders for the mlir inputs and outputs that are
   /// created during the first pass before the mlir::FuncOp is created.
   struct FirPlaceHolder {
-    FirPlaceHolder(mlir::Type t, int passedPosition, Property p)
-        : type{t}, passedEntityPosition{passedPosition}, property{p} {}
+    FirPlaceHolder(mlir::Type t, int passedPosition, Property p,
+                   llvm::ArrayRef<mlir::NamedAttribute> attrs)
+        : type{t}, passedEntityPosition{passedPosition}, property{p},
+          attributes{attrs.begin(), attrs.end()} {}
     /// Type for this input/output
     mlir::Type type;
     /// Position of related passedEntity in passedArguments.
@@ -116,8 +138,41 @@ class CallInterface {
     /// Indicate property of the entity passedEntityPosition that must be passed
     /// through this argument.
     Property property;
+    /// MLIR attributes for this argument
+    llvm::SmallVector<mlir::NamedAttribute> attributes;
   };
 
+  /// PassedEntity is what is provided back to the CallInterface user.
+  /// It describe how the entity is plugged in the interface
+  struct PassedEntity {
+    /// Is the dummy argument optional ?
+    bool isOptional() const;
+    /// Can the argument be modified by the callee ?
+    bool mayBeModifiedByCall() const;
+    /// Can the argument be read by the callee ?
+    bool mayBeReadByCall() const;
+    /// How entity is passed by.
+    PassEntityBy passBy;
+    /// What is the entity (SymbolRef for callee/ActualArgument* for caller)
+    /// What is the related mlir::FuncOp argument(s) (mlir::Value for callee /
+    /// index for the caller).
+    FortranEntity entity;
+    FirValue firArgument;
+    FirValue firLength; /* only for AddressAndLength */
+
+    /// Pointer to the argument characteristics. Nullptr for results.
+    const Fortran::evaluate::characteristics::DummyArgument *characteristics =
+        nullptr;
+  };
+
+  /// Return a container of Symbol/ActualArgument* and how they must
+  /// be plugged with the mlir::FuncOp.
+  llvm::ArrayRef<PassedEntity> getPassedArguments() const {
+    return passedArguments;
+  }
+  /// In case the result must be passed by the caller, indicate how.
+  /// nullopt if the result is not passed by the caller.
+  std::optional<PassedEntity> getPassedResult() const { return passedResult; }
   /// Returns the mlir function type
   mlir::FunctionType genFunctionType();
 
@@ -134,9 +189,16 @@ class CallInterface {
   /// Entry point to be called by child ctor to analyze the signature and
   /// create/find the mlir::FuncOp. Child needs to be initialized first.
   void declare();
+  /// Second pass entry point, once the mlir::FuncOp is created.
+  /// Nothing is done if it was already called.
+  void mapPassedEntities();
+  void mapBackInputToPassedEntity(const FirPlaceHolder &, FirValue);
 
   llvm::SmallVector<FirPlaceHolder> outputs;
+  llvm::SmallVector<FirPlaceHolder> inputs;
   mlir::FuncOp func;
+  llvm::SmallVector<PassedEntity> passedArguments;
+  std::optional<PassedEntity> passedResult;
 
   Fortran::lower::AbstractConverter &converter;
   /// Store characteristic once created, it is required for further information
@@ -165,6 +227,10 @@ class CalleeInterface : public CallInterface<CalleeInterface> {
   Fortran::evaluate::characteristics::Procedure characterize() const;
   bool isMainProgram() const;
 
+  Fortran::lower::pft::FunctionLikeUnit &getCallDescription() const {
+    return funit;
+  }
+
   /// On the callee side it does not matter whether the procedure is
   /// called through pointers or not.
   bool isIndirectCall() const { return false; }

diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 6e7f56c50ada5..cfb326c3af483 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -227,6 +227,59 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     localSymbols.clear();
   }
 
+  /// Map mlir function block arguments to the corresponding Fortran dummy
+  /// variables. When the result is passed as a hidden argument, the Fortran
+  /// result is also mapped. The symbol map is used to hold this mapping.
+  void mapDummiesAndResults(Fortran::lower::pft::FunctionLikeUnit &funit,
+                            const Fortran::lower::CalleeInterface &callee) {
+    assert(builder && "require a builder object at this point");
+    using PassBy = Fortran::lower::CalleeInterface::PassEntityBy;
+    auto mapPassedEntity = [&](const auto arg) -> void {
+      if (arg.passBy == PassBy::AddressAndLength) {
+        // // TODO: now that fir call has some attributes regarding character
+        // // return, PassBy::AddressAndLength should be retired.
+        // mlir::Location loc = toLocation();
+        // fir::factory::CharacterExprHelper charHelp{*builder, loc};
+        // mlir::Value box =
+        //     charHelp.createEmboxChar(arg.firArgument, arg.firLength);
+        // addSymbol(arg.entity->get(), box);
+      } else {
+        if (arg.entity.has_value()) {
+          addSymbol(arg.entity->get(), arg.firArgument);
+        } else {
+          // assert(funit.parentHasHostAssoc());
+          // funit.parentHostAssoc().internalProcedureBindings(*this,
+          //                                                   localSymbols);
+        }
+      }
+    };
+    for (const Fortran::lower::CalleeInterface::PassedEntity &arg :
+         callee.getPassedArguments())
+      mapPassedEntity(arg);
+
+    // Allocate local skeleton instances of dummies from other entry points.
+    // Most of these locals will not survive into final generated code, but
+    // some will.  It is illegal to reference them at run time if they do.
+    for (const Fortran::semantics::Symbol *arg :
+         funit.nonUniversalDummyArguments) {
+      if (lookupSymbol(*arg))
+        continue;
+      mlir::Type type = genType(*arg);
+      // TODO: Account for VALUE arguments (and possibly other variants).
+      type = builder->getRefType(type);
+      addSymbol(*arg, builder->create<fir::UndefOp>(toLocation(), type));
+    }
+    if (std::optional<Fortran::lower::CalleeInterface::PassedEntity>
+            passedResult = callee.getPassedResult()) {
+      mapPassedEntity(*passedResult);
+      // FIXME: need to make sure things are OK here. addSymbol may not be OK
+      if (funit.primaryResult &&
+          passedResult->entity->get() != *funit.primaryResult)
+        addSymbol(*funit.primaryResult,
+                  getSymbolAddress(passedResult->entity->get()));
+    }
+  }
+
   /// Instantiate variable \p var and add it to the symbol map.
   /// See ConvertVariable.cpp.
   void instantiateVar(const Fortran::lower::pft::Variable &var) {
@@ -243,6 +296,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     assert(builder && "FirOpBuilder did not instantiate");
     builder->setInsertionPointToStart(&func.front());
 
+    mapDummiesAndResults(funit, callee);
+
     for (const Fortran::lower::pft::Variable &var :
          funit.getOrderedSymbolTable()) {
       const Fortran::semantics::Symbol &sym = var.getSymbol();
@@ -319,6 +374,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     return {};
   }
 
+  /// Add the symbol to the local map and return `true`. If the symbol is
+  /// already in the map and \p forced is `false`, the map is not updated.
+  /// Instead the value `false` is returned.
+  bool addSymbol(const Fortran::semantics::SymbolRef sym, mlir::Value val,
+                 bool forced = false) {
+    if (!forced && lookupSymbol(sym))
+      return false;
+    localSymbols.addSymbol(sym, val, forced);
+    return true;
+  }
+
   void genFIRBranch(mlir::Block *targetBlock) {
     assert(targetBlock && "missing unconditional target block");
     builder->create<cf::BranchOp>(toLocation(), targetBlock);

diff  --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index 8bf110cf2daf7..93c8f02bc7039 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -77,6 +77,7 @@ mlir::FuncOp Fortran::lower::CalleeInterface::addEntryBlockAndMapArguments() {
   // On the callee side, directly map the mlir::value argument of
   // the function block to the Fortran symbols.
   func.addEntryBlock();
+  mapPassedEntities();
   return func;
 }
 
@@ -122,10 +123,58 @@ void Fortran::lower::CallInterface<T>::declare() {
       func = fir::FirOpBuilder::createFunction(loc, module, name, ty);
       if (const Fortran::semantics::Symbol *sym = side().getProcedureSymbol())
         addSymbolAttribute(func, *sym, converter.getMLIRContext());
+      for (const auto &placeHolder : llvm::enumerate(inputs))
+        if (!placeHolder.value().attributes.empty())
+          func.setArgAttrs(placeHolder.index(), placeHolder.value().attributes);
     }
   }
 }
 
+/// Once the signature has been analyzed and the mlir::FuncOp was built/found,
+/// map the fir inputs to Fortran entities (the symbols or expressions).
+template <typename T>
+void Fortran::lower::CallInterface<T>::mapPassedEntities() {
+  // map back fir inputs to passed entities
+  if constexpr (std::is_same_v<T, Fortran::lower::CalleeInterface>) {
+    assert(inputs.size() == func.front().getArguments().size() &&
+           "function previously created with different number of arguments");
+    for (auto [fst, snd] : llvm::zip(inputs, func.front().getArguments()))
+      mapBackInputToPassedEntity(fst, snd);
+  } else {
+    // On the caller side, map the index of the mlir argument position
+    // to Fortran ActualArguments.
+    int firPosition = 0;
+    for (const FirPlaceHolder &placeHolder : inputs)
+      mapBackInputToPassedEntity(placeHolder, firPosition++);
+  }
+}
+
+template <typename T>
+void Fortran::lower::CallInterface<T>::mapBackInputToPassedEntity(
+    const FirPlaceHolder &placeHolder, FirValue firValue) {
+  PassedEntity &passedEntity =
+      placeHolder.passedEntityPosition == FirPlaceHolder::resultEntityPosition
+          ? passedResult.value()
+          : passedArguments[placeHolder.passedEntityPosition];
+  if (placeHolder.property == Property::CharLength)
+    passedEntity.firLength = firValue;
+  else
+    passedEntity.firArgument = firValue;
+}
+
+static const std::vector<Fortran::semantics::Symbol *> &
+getEntityContainer(Fortran::lower::pft::FunctionLikeUnit &funit) {
+  return funit.getSubprogramSymbol()
+      .get<Fortran::semantics::SubprogramDetails>()
+      .dummyArgs();
+}
+
+static const Fortran::semantics::Symbol &
+getDataObjectEntity(const Fortran::semantics::Symbol *arg) {
+  assert(arg && "expect symbol for data object entity");
+  return *arg;
+}
+
 //===----------------------------------------------------------------------===//
 // CallInterface implementation: this part is common to both caller and caller
 // sides.
@@ -136,9 +185,14 @@ void Fortran::lower::CallInterface<T>::declare() {
 template <typename T>
 class Fortran::lower::CallInterfaceImpl {
   using CallInterface = Fortran::lower::CallInterface<T>;
+  using PassEntityBy = typename CallInterface::PassEntityBy;
+  using PassedEntity = typename CallInterface::PassedEntity;
+  using FortranEntity = typename CallInterface::FortranEntity;
   using FirPlaceHolder = typename CallInterface::FirPlaceHolder;
   using Property = typename CallInterface::Property;
   using TypeAndShape = Fortran::evaluate::characteristics::TypeAndShape;
+  using DummyCharacteristics =
+      Fortran::evaluate::characteristics::DummyArgument;
 
 public:
   CallInterfaceImpl(CallInterface &i)
@@ -153,6 +207,24 @@ class Fortran::lower::CallInterfaceImpl {
     else if (interface.side().hasAlternateReturns())
       addFirResult(mlir::IndexType::get(&mlirContext),
                    FirPlaceHolder::resultEntityPosition, Property::Value);
+    // Handle arguments
+    const auto &argumentEntities =
+        getEntityContainer(interface.side().getCallDescription());
+    for (auto pair : llvm::zip(procedure.dummyArguments, argumentEntities)) {
+      const Fortran::evaluate::characteristics::DummyArgument
+          &argCharacteristics = std::get<0>(pair);
+      std::visit(
+          Fortran::common::visitors{
+              [&](const auto &dummy) {
+                const auto &entity = getDataObjectEntity(std::get<1>(pair));
+                handleImplicitDummy(&argCharacteristics, dummy, entity);
+              },
+              [&](const Fortran::evaluate::characteristics::AlternateReturn &) {
+                // nothing to do
+              },
+          },
+          argCharacteristics.u);
+    }
   }
 
   void buildExplicitInterface(
@@ -248,9 +320,78 @@ class Fortran::lower::CallInterfaceImpl {
         getConverter().getFoldingContext(), std::move(expr)));
   }
 
-  void addFirResult(mlir::Type type, int entityPosition, Property p) {
-    interface.outputs.emplace_back(FirPlaceHolder{type, entityPosition, p});
+  /// Return a vector with an attribute with the name of the argument if this
+  /// is a callee interface and the name is available. Otherwise, just return
+  /// an empty vector.
+  llvm::SmallVector<mlir::NamedAttribute>
+  dummyNameAttr(const FortranEntity &entity) {
+    if constexpr (std::is_same_v<FortranEntity,
+                                 std::optional<Fortran::common::Reference<
+                                     const Fortran::semantics::Symbol>>>) {
+      if (entity.has_value()) {
+        const Fortran::semantics::Symbol *argument = &*entity.value();
+        // "fir.bindc_name" is used for arguments for the sake of consistency
+        // with other attributes carrying surface syntax names in FIR.
+        return {mlir::NamedAttribute(
+            mlir::StringAttr::get(&mlirContext, "fir.bindc_name"),
+            mlir::StringAttr::get(&mlirContext,
+                                  toStringRef(argument->name())))};
+      }
+    }
+    return {};
+  }
+
+  void handleImplicitDummy(
+      const DummyCharacteristics *characteristics,
+      const Fortran::evaluate::characteristics::DummyDataObject &obj,
+      const FortranEntity &entity) {
+    Fortran::evaluate::DynamicType dynamicType = obj.type.type();
+    if (dynamicType.category() == Fortran::common::TypeCategory::Character) {
+      mlir::Type boxCharTy =
+          fir::BoxCharType::get(&mlirContext, dynamicType.kind());
+      addFirOperand(boxCharTy, nextPassedArgPosition(), Property::BoxChar,
+                    dummyNameAttr(entity));
+      addPassedArg(PassEntityBy::BoxChar, entity, characteristics);
+    } else {
+      // non-PDT derived type allowed in implicit interface.
+      Fortran::common::TypeCategory cat = dynamicType.category();
+      mlir::Type type = getConverter().genType(cat, dynamicType.kind());
+      fir::SequenceType::Shape bounds = getBounds(obj.type.shape());
+      if (!bounds.empty())
+        type = fir::SequenceType::get(bounds, type);
+      mlir::Type refType = fir::ReferenceType::get(type);
+      addFirOperand(refType, nextPassedArgPosition(), Property::BaseAddress,
+                    dummyNameAttr(entity));
+      addPassedArg(PassEntityBy::BaseAddress, entity, characteristics);
+    }
+  }
+
+  void handleImplicitDummy(
+      const DummyCharacteristics *characteristics,
+      const Fortran::evaluate::characteristics::DummyProcedure &proc,
+      const FortranEntity &entity) {
+    TODO(interface.converter.getCurrentLocation(),
+         "handleImlicitDummy DummyProcedure");
+  }
+
+  void
+  addFirOperand(mlir::Type type, int entityPosition, Property p,
+                llvm::ArrayRef<mlir::NamedAttribute> attributes = llvm::None) {
+    interface.inputs.emplace_back(
+        FirPlaceHolder{type, entityPosition, p, attributes});
+  }
+  void
+  addFirResult(mlir::Type type, int entityPosition, Property p,
+               llvm::ArrayRef<mlir::NamedAttribute> attributes = llvm::None) {
+    interface.outputs.emplace_back(
+        FirPlaceHolder{type, entityPosition, p, attributes});
+  }
+  void addPassedArg(PassEntityBy p, FortranEntity entity,
+                    const DummyCharacteristics *characteristics) {
+    interface.passedArguments.emplace_back(
+        PassedEntity{p, entity, {}, {}, characteristics});
   }
+  int nextPassedArgPosition() { return interface.passedArguments.size(); }
 
   Fortran::lower::AbstractConverter &getConverter() {
     return interface.converter;
@@ -273,9 +414,13 @@ void Fortran::lower::CallInterface<T>::determineInterface(
 template <typename T>
 mlir::FunctionType Fortran::lower::CallInterface<T>::genFunctionType() {
   llvm::SmallVector<mlir::Type> returnTys;
+  llvm::SmallVector<mlir::Type> inputTys;
   for (const FirPlaceHolder &placeHolder : outputs)
     returnTys.emplace_back(placeHolder.type);
-  return mlir::FunctionType::get(&converter.getMLIRContext(), {}, returnTys);
+  for (const FirPlaceHolder &placeHolder : inputs)
+    inputTys.emplace_back(placeHolder.type);
+  return mlir::FunctionType::get(&converter.getMLIRContext(), inputTys,
+                                 returnTys);
 }
 
 template class Fortran::lower::CallInterface<Fortran::lower::CalleeInterface>;

diff  --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index c207f60437695..bd347362fbc92 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -66,9 +66,25 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
                              Fortran::lower::SymMap &symMap) {
   assert(!var.isAlias());
   const Fortran::semantics::Symbol &sym = var.getSymbol();
+  const bool isDummy = Fortran::semantics::IsDummy(sym);
+  const bool isResult = Fortran::semantics::IsFunctionResult(sym);
   if (symMap.lookupSymbol(sym))
     return;
+
   const mlir::Location loc = converter.genLocation(sym.name());
+  if (isDummy) {
+    // This is an argument.
+    if (!symMap.lookupSymbol(sym))
+      mlir::emitError(loc, "symbol \"")
+          << toStringRef(sym.name()) << "\" must already be in map";
+    return;
+  } else if (isResult) {
+    // Some Fortran results may be passed by argument (e.g. derived
+    // types)
+    if (symMap.lookupSymbol(sym))
+      return;
+  }
+  // Otherwise, it's a local variable or function result.
   mlir::Value local = createNewLocal(converter, loc, var, {});
   symMap.addSymbol(sym, local);
 }

diff  --git a/flang/test/Lower/arguments.f90 b/flang/test/Lower/arguments.f90
new file mode 100644
index 0000000000000..e4515101be843
--- /dev/null
+++ b/flang/test/Lower/arguments.f90
@@ -0,0 +1,48 @@
+! RUN: bbc %s -o "-" -emit-fir | FileCheck %s
+
+subroutine sub1(a, b)
+  integer, intent(in) :: a
+  logical :: b
+end
+
+! Check that arguments are correctly set and no local allocation is happening.
+! CHECK-LABEL: func @_QPsub1(
+! CHECK-SAME:    %{{.*}}: !fir.ref<i32> {fir.bindc_name = "a"}, %{{.*}}: !fir.ref<!fir.logical<4>> {fir.bindc_name = "b"})
+! CHECK-NOT:     fir.alloc
+! CHECK:         return
+
+subroutine sub2(i)
+  integer :: i(2, 5)
+end
+
+! CHECK-LABEL: func @_QPsub2(
+! CHECK-SAME: %{{.*}}: !fir.ref<!fir.array<2x5xi32>>{{.*}})
+
+subroutine sub3(i)
+  real :: i(2)
+end
+
+! CHECK-LABEL: func @_QPsub3(
+! CHECK-SAME: %{{.*}}: !fir.ref<!fir.array<2xf32>>{{.*}})
+
+integer function fct1(a, b)
+  integer, intent(in) :: a
+  logical :: b
+end
+
+! CHECK-LABEL: func @_QPfct1(
+! CHECK-SAME:    %{{.*}}: !fir.ref<i32> {fir.bindc_name = "a"}, %{{.*}}: !fir.ref<!fir.logical<4>> {fir.bindc_name = "b"}) -> i32
+
+real function fct2(i)
+  integer :: i(2, 5)
+end
+
+! CHECK-LABEL: func @_QPfct2(
+! CHECK-SAME:    %{{.*}}: !fir.ref<!fir.array<2x5xi32>> {fir.bindc_name = "i"}) -> f32
+
+function fct3(i)
+  real :: i(2)
+end
+
+! CHECK-LABEL: func @_QPfct3(
+! CHECK-SAME:    %{{.*}}: !fir.ref<!fir.array<2xf32>> {fir.bindc_name = "i"}) -> f32


        


More information about the flang-commits mailing list