[flang-commits] [flang] 6393d2e - [flang] Create fir.dispatch_table and fir.dt_entry operations

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Thu Nov 17 01:53:50 PST 2022


Author: Valentin Clement
Date: 2022-11-17T10:53:43+01:00
New Revision: 6393d2ea24fb458c353f8d453ab5f20663875cb1

URL: https://github.com/llvm/llvm-project/commit/6393d2ea24fb458c353f8d453ab5f20663875cb1
DIFF: https://github.com/llvm/llvm-project/commit/6393d2ea24fb458c353f8d453ab5f20663875cb1.diff

LOG: [flang] Create fir.dispatch_table and fir.dt_entry operations

Create the fir.dispatch_table operation based on semantics
information. The fir.dispatch_table will be used for static devirtualization
as well as for fir.select_type conversion.

Depends on D138129

Reviewed By: jeanPerier, PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D138131

Added: 
    flang/test/Lower/dispatch-table.f90

Modified: 
    flang/include/flang/Lower/AbstractConverter.h
    flang/include/flang/Optimizer/Builder/FIRBuilder.h
    flang/include/flang/Optimizer/Dialect/FIROps.td
    flang/lib/Lower/Bridge.cpp
    flang/lib/Lower/ConvertType.cpp
    flang/lib/Optimizer/Builder/FIRBuilder.cpp
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    flang/lib/Optimizer/Dialect/FIROps.cpp
    flang/test/Lower/polymorphic-types.f90

Removed: 
    flang/test/Fir/Todo/dispatch_table.fir


################################################################################
diff  --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index a84e34f4bd8a4..358a6193438bc 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -204,6 +204,10 @@ class AbstractConverter {
   virtual void registerRuntimeTypeInfo(mlir::Location loc,
                                        SymbolRef typeInfoSym) = 0;
 
+  virtual void registerDispatchTableInfo(
+      mlir::Location loc,
+      const Fortran::semantics::DerivedTypeSpec *typeSpec) = 0;
+
   //===--------------------------------------------------------------------===//
   // Locations
   //===--------------------------------------------------------------------===//

diff  --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index a28ada96ecf7a..f6b795515ecc2 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -212,6 +212,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
                         bodyBuilder, linkage);
   }
 
+  /// Create a fir::DispatchTableOp operation.
+  fir::DispatchTableOp createDispatchTableOp(mlir::Location loc,
+                                             llvm::StringRef name,
+                                             llvm::StringRef parentName);
+
   /// Convert a StringRef string into a fir::StringLitOp.
   fir::StringLitOp createStringLitOp(mlir::Location loc,
                                      llvm::StringRef string);

diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8154b1b7d2b93..3e1ae7183956d 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2804,26 +2804,30 @@ def fir_DispatchTableOp : fir_Op<"dispatch_table",
     ```
   }];
 
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    OptionalAttr<StrAttr>:$parent
+  );
+
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 
-  let regions = (region SizedRegion<1>:$region);
+  let regions = (region AnyRegion:$region);
 
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins "llvm::StringRef":$name, "mlir::Type":$type,
-      CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs),
-    [{
-      $_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
-                           $_builder.getStringAttr(name));
-      $_state.addAttributes(attrs);
-    }]>
+      "llvm::StringRef":$parent,
+      CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs)>
   ];
 
   let extraClassDeclaration = [{
     /// Append a dispatch table entry to the table.
     void appendTableEntry(mlir::Operation *op);
 
+    static constexpr llvm::StringRef getParentAttrNameStr() { return "parent"; }
+    static constexpr llvm::StringRef getExtendsKeyword() { return "extends"; }
+
     mlir::Block &getBlock() {
       return getRegion().front();
     }

diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 15cb297896c90..5d8be0390ca87 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -46,11 +46,13 @@
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "flang/Parser/parse-tree.h"
 #include "flang/Runtime/iostat.h"
+#include "flang/Semantics/runtime-type-info.h"
 #include "flang/Semantics/tools.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -193,6 +195,67 @@ class RuntimeTypeInfoConverter {
   llvm::SmallSetVector<Fortran::semantics::SymbolRef, 64> seen;
 };
 
+class DispatchTableConverter {
+  struct DispatchTableInfo {
+    const Fortran::semantics::DerivedTypeSpec *typeSpec;
+    mlir::Location loc;
+  };
+
+public:
+  void registerTypeSpec(mlir::Location loc,
+                        const Fortran::semantics::DerivedTypeSpec *typeSpec) {
+    assert(typeSpec && "type spec is null");
+    std::string dtName = Fortran::lower::mangle::mangleName(*typeSpec);
+    if (seen.contains(dtName) || dtName.find("__fortran") != std::string::npos)
+      return;
+    seen.insert(dtName);
+    registeredDispatchTableInfo.emplace_back(DispatchTableInfo{typeSpec, loc});
+  }
+
+  void createDispatchTableOps(Fortran::lower::AbstractConverter &converter) {
+    for (const DispatchTableInfo &info : registeredDispatchTableInfo) {
+      std::string dtName = Fortran::lower::mangle::mangleName(*info.typeSpec);
+      const Fortran::semantics::DerivedTypeSpec *parent =
+          Fortran::evaluate::GetParentTypeSpec(*info.typeSpec);
+      fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+      fir::DispatchTableOp dt = builder.createDispatchTableOp(
+          info.loc, dtName,
+          parent ? Fortran::lower::mangle::mangleName(*parent) : "");
+      auto insertPt = builder.saveInsertionPoint();
+
+      std::vector<const Fortran::semantics::Symbol *> bindings =
+          Fortran::semantics::CollectBindings(*info.typeSpec->scope());
+
+      if (!bindings.empty())
+        builder.createBlock(&dt.getRegion());
+
+      for (const Fortran::semantics::Symbol *binding : bindings) {
+        const auto *details =
+            binding->detailsIf<Fortran::semantics::ProcBindingDetails>();
+        std::string bindingName =
+            Fortran::lower::mangle::mangleName(details->symbol());
+        builder.create<fir::DTEntryOp>(
+            info.loc,
+            mlir::StringAttr::get(builder.getContext(),
+                                  binding->name().ToString()),
+            mlir::SymbolRefAttr::get(builder.getContext(), bindingName));
+      }
+      if (!bindings.empty())
+        builder.create<fir::FirEndOp>(info.loc);
+      builder.restoreInsertionPoint(insertPt);
+    }
+    registeredDispatchTableInfo.clear();
+  }
+
+private:
+  /// Store the semantic DerivedTypeSpec that will be required to generate the
+  /// dispatch table.
+  llvm::SmallVector<DispatchTableInfo> registeredDispatchTableInfo;
+
+  /// Track processed type specs to avoid multiple creation.
+  llvm::StringSet<> seen;
+};
+
 using IncrementLoopNestInfo = llvm::SmallVector<IncrementLoopInfo, 8>;
 } // namespace
 
@@ -270,6 +333,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     createGlobalOutsideOfFunctionLowering(
         [&]() { runtimeTypeInfoConverter.createTypeInfoGlobals(*this); });
 
+    /// Create the dispatch tables for derived types.
+    createGlobalOutsideOfFunctionLowering(
+        [&]() { dispatchTableConverter.createDispatchTableOps(*this); });
+
     // Create the list of any environment defaults for the runtime to set. The
     // runtime default list is only created if there is a main program to ensure
     // it only happens once and to provide consistent results if multiple files
@@ -745,6 +812,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     runtimeTypeInfoConverter.registerTypeInfoSymbol(*this, loc, typeInfoSym);
   }
 
+  void registerDispatchTableInfo(
+      mlir::Location loc,
+      const Fortran::semantics::DerivedTypeSpec *typeSpec) override final {
+    dispatchTableConverter.registerTypeSpec(loc, typeSpec);
+  }
+
 private:
   FirConverter() = delete;
   FirConverter(const FirConverter &) = delete;
@@ -3591,6 +3664,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   Fortran::lower::SymMap localSymbols;
   Fortran::parser::CharBlock currentPosition;
   RuntimeTypeInfoConverter runtimeTypeInfoConverter;
+  DispatchTableConverter dispatchTableConverter;
 
   /// WHERE statement/construct mask expression stack.
   Fortran::lower::ImplicitIterSpace implicitIterSpace;

diff  --git a/flang/lib/Lower/ConvertType.cpp b/flang/lib/Lower/ConvertType.cpp
index e9a2e339e7876..6f86cbfee79d6 100644
--- a/flang/lib/Lower/ConvertType.cpp
+++ b/flang/lib/Lower/ConvertType.cpp
@@ -340,6 +340,8 @@ struct TypeBuilder {
     }
     LLVM_DEBUG(llvm::dbgs() << "derived type: " << rec << '\n');
 
+    converter.registerDispatchTableInfo(loc, &tySpec);
+
     // Generate the type descriptor object if any
     if (const Fortran::semantics::Scope *derivedScope =
             tySpec.scope() ? tySpec.scope() : tySpec.typeSymbol().scope())

diff  --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 14eca83dc9944..f40953edb344e 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -271,6 +271,18 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
   return glob;
 }
 
+fir::DispatchTableOp fir::FirOpBuilder::createDispatchTableOp(
+    mlir::Location loc, llvm::StringRef name, llvm::StringRef parentName) {
+  auto module = getModule();
+  auto insertPt = saveInsertionPoint();
+  if (auto dt = module.lookupSymbol<fir::DispatchTableOp>(name))
+    return dt;
+  setInsertionPoint(module.getBody(), module.getBody()->end());
+  auto dt = create<fir::DispatchTableOp>(loc, name, mlir::Type{}, parentName);
+  restoreInsertionPoint(insertPt);
+  return dt;
+}
+
 mlir::Value
 fir::FirOpBuilder::convertWithSemantics(mlir::Location loc, mlir::Type toTy,
                                         mlir::Value val,

diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index f8a46a9d9555e..f5e035d470eb7 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -1051,30 +1051,30 @@ struct DispatchOpConversion : public FIROpConversion<fir::DispatchOp> {
   }
 };
 
-/// Lower `fir.dispatch_table` operation. The dispatch table for a Fortran
-/// derived type.
+/// `fir.dispatch_table` operation has no specific CodeGen. The operation is
+/// only used to carry information during FIR to FIR passes.
 struct DispatchTableOpConversion
     : public FIROpConversion<fir::DispatchTableOp> {
   using FIROpConversion::FIROpConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(fir::DispatchTableOp dispTab, OpAdaptor adaptor,
+  matchAndRewrite(fir::DispatchTableOp op, OpAdaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    TODO(dispTab.getLoc(), "fir.dispatch_table codegen");
-    return mlir::failure();
+    rewriter.eraseOp(op);
+    return mlir::success();
   }
 };
 
-/// Lower `fir.dt_entry` operation. An entry in a dispatch table; binds a
-/// method-name to a function.
+/// `fir.dt_entry` operation has no specific CodeGen. The operation is only used
+/// to carry information during FIR to FIR passes.
 struct DTEntryOpConversion : public FIROpConversion<fir::DTEntryOp> {
   using FIROpConversion::FIROpConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(fir::DTEntryOp dtEnt, OpAdaptor adaptor,
+  matchAndRewrite(fir::DTEntryOp op, OpAdaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    TODO(dtEnt.getLoc(), "fir.dt_entry codegen");
-    return mlir::failure();
+    rewriter.eraseOp(op);
+    return mlir::success();
   }
 };
 

diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index a64edcc147580..53b21b3fb0d3f 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1092,14 +1092,19 @@ void fir::DispatchTableOp::appendTableEntry(mlir::Operation *op) {
 mlir::ParseResult fir::DispatchTableOp::parse(mlir::OpAsmParser &parser,
                                               mlir::OperationState &result) {
   // Parse the name as a symbol reference attribute.
-  mlir::SymbolRefAttr nameAttr;
-  if (parser.parseAttribute(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
-                            result.attributes))
+  mlir::StringAttr nameAttr;
+  if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
+                             result.attributes))
     return mlir::failure();
 
-  // Convert the parsed name attr into a string attr.
-  result.attributes.set(mlir::SymbolTable::getSymbolAttrName(),
-                        nameAttr.getRootReference());
+  if (!failed(parser.parseOptionalKeyword(getExtendsKeyword()))) {
+    mlir::StringAttr parent;
+    if (parser.parseLParen() ||
+        parser.parseAttribute(parent, getParentAttrNameStr(),
+                              result.attributes) ||
+        parser.parseRParen())
+      return mlir::failure();
+  }
 
   // Parse the optional table body.
   mlir::Region *body = result.addRegion();
@@ -1113,11 +1118,11 @@ mlir::ParseResult fir::DispatchTableOp::parse(mlir::OpAsmParser &parser,
 }
 
 void fir::DispatchTableOp::print(mlir::OpAsmPrinter &p) {
-  auto tableName = getOperation()
-                       ->getAttrOfType<mlir::StringAttr>(
-                           mlir::SymbolTable::getSymbolAttrName())
-                       .getValue();
-  p << " @" << tableName;
+  p << ' ';
+  p.printSymbolName(getSymName());
+  if (getParent())
+    p << ' ' << getExtendsKeyword() << '('
+      << (*this)->getAttr(getParentAttrNameStr()) << ')';
 
   mlir::Region &body = getOperation()->getRegion(0);
   if (!body.empty()) {
@@ -1128,12 +1133,29 @@ void fir::DispatchTableOp::print(mlir::OpAsmPrinter &p) {
 }
 
 mlir::LogicalResult fir::DispatchTableOp::verify() {
+  if (getRegion().empty())
+    return mlir::success();
   for (auto &op : getBlock())
     if (!mlir::isa<fir::DTEntryOp, fir::FirEndOp>(op))
       return op.emitOpError("dispatch table must contain dt_entry");
   return mlir::success();
 }
 
+void fir::DispatchTableOp::build(mlir::OpBuilder &builder,
+                                 mlir::OperationState &result,
+                                 llvm::StringRef name, mlir::Type type,
+                                 llvm::StringRef parent,
+                                 llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+  result.addRegion();
+  result.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
+                      builder.getStringAttr(name));
+  if (!parent.empty())
+    result.addAttribute(getParentAttrNameStr(), builder.getStringAttr(parent));
+  // result.addAttribute(getSymbolAttrNameStr(),
+  //                     mlir::SymbolRefAttr::get(builder.getContext(), name));
+  result.addAttributes(attrs);
+}
+
 //===----------------------------------------------------------------------===//
 // EmboxOp
 //===----------------------------------------------------------------------===//

diff  --git a/flang/test/Fir/Todo/dispatch_table.fir b/flang/test/Fir/Todo/dispatch_table.fir
deleted file mode 100644
index 6aa4811e52a6d..0000000000000
--- a/flang/test/Fir/Todo/dispatch_table.fir
+++ /dev/null
@@ -1,9 +0,0 @@
-// RUN: %not_todo_cmd fir-opt --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s 2>&1 | FileCheck %s
-
-// Test fir.dispatch_table conversion to llvm.
-// Not implemented yet.
-
-// CHECK: not yet implemented: fir.dispatch_table codegen
-fir.dispatch_table @dispatch_tbl {
-  fir.dt_entry "method", @method_impl
-}

diff  --git a/flang/test/Lower/dispatch-table.f90 b/flang/test/Lower/dispatch-table.f90
new file mode 100644
index 0000000000000..7b5a5ec4a39e6
--- /dev/null
+++ b/flang/test/Lower/dispatch-table.f90
@@ -0,0 +1,75 @@
+! RUN: bbc -polymorphic-type -emit-fir %s -o - | FileCheck %s
+
+! Tests the generation of fir.dispatch_table operations.
+
+module polymorphic_types
+  type p1
+    integer :: a
+    integer :: b
+  contains
+    procedure :: proc1 => proc1_p1
+    procedure :: aproc
+    procedure :: zproc
+  end type
+
+  type, extends(p1) :: p2
+    integer :: c
+  contains
+    procedure :: proc1 => proc1_p2
+    procedure :: aproc2
+  end type
+
+  type, extends(p2) :: p3
+    integer :: d
+  contains
+    procedure :: aproc3
+  end type
+contains
+
+
+  subroutine proc1_p1(p)
+    class(p1) :: p
+  end subroutine
+
+  subroutine aproc(p)
+    class(p1) :: p
+  end subroutine
+
+  subroutine zproc(p)
+    class(p1) :: p
+  end subroutine
+
+  subroutine proc1_p2(p)
+    class(p2) :: p
+  end subroutine
+
+  subroutine aproc2(p)
+    class(p2) :: p
+  end subroutine
+
+  subroutine aproc3(p)
+    class(p3) :: p
+  end subroutine
+
+end module
+
+! CHECK-LABEL: fir.dispatch_table @_QMpolymorphic_typesTp1 {
+! CHECK:         fir.dt_entry "aproc", @_QMpolymorphic_typesPaproc
+! CHECK:         fir.dt_entry "proc1", @_QMpolymorphic_typesPproc1_p1
+! CHECK:         fir.dt_entry "zproc", @_QMpolymorphic_typesPzproc
+! CHECK:       }
+
+! CHECK-LABEL: fir.dispatch_table @_QMpolymorphic_typesTp2 extends("_QMpolymorphic_typesTp1") {
+! CHECK:         fir.dt_entry "aproc", @_QMpolymorphic_typesPaproc
+! CHECK:         fir.dt_entry "proc1", @_QMpolymorphic_typesPproc1_p2
+! CHECK:         fir.dt_entry "zproc", @_QMpolymorphic_typesPzproc
+! CHECK:         fir.dt_entry "aproc2", @_QMpolymorphic_typesPaproc2
+! CHECK:       }
+
+! CHECK-LABEL: fir.dispatch_table @_QMpolymorphic_typesTp3 extends("_QMpolymorphic_typesTp2") {
+! CHECK:         fir.dt_entry "aproc", @_QMpolymorphic_typesPaproc
+! CHECK:         fir.dt_entry "proc1", @_QMpolymorphic_typesPproc1_p2
+! CHECK:         fir.dt_entry "zproc", @_QMpolymorphic_typesPzproc
+! CHECK:         fir.dt_entry "aproc2", @_QMpolymorphic_typesPaproc2
+! CHECK:         fir.dt_entry "aproc3", @_QMpolymorphic_typesPaproc3
+! CHECK:       }

diff  --git a/flang/test/Lower/polymorphic-types.f90 b/flang/test/Lower/polymorphic-types.f90
index d0354f95ca4a5..1c9284a309e28 100644
--- a/flang/test/Lower/polymorphic-types.f90
+++ b/flang/test/Lower/polymorphic-types.f90
@@ -180,4 +180,5 @@ end subroutine assumed_type_dummy_array
 
   ! CHECK-LABEL: func.func @assumed_type_dummy_array(
   ! CHECK-SAME: %{{.*}}: !fir.box<!fir.array<?xnone>>
+
 end module


        


More information about the flang-commits mailing list