[flang-commits] [flang] 0538bfe - [flang] Moving common polymorphic code into utility files

via flang-commits flang-commits at lists.llvm.org
Wed Mar 8 08:41:00 PST 2023


Author: Renaud-K
Date: 2023-03-08T08:23:21-08:00
New Revision: 0538bfe7744ad9a1a4b1ffe5aa5c6466f88aac8f

URL: https://github.com/llvm/llvm-project/commit/0538bfe7744ad9a1a4b1ffe5aa5c6466f88aac8f
DIFF: https://github.com/llvm/llvm-project/commit/0538bfe7744ad9a1a4b1ffe5aa5c6466f88aac8f.diff

LOG: [flang] Moving common polymorphic code into utility files
Differential revision: https://reviews.llvm.org/D145530

Added: 
    flang/lib/Optimizer/Support/Utils.cpp

Modified: 
    flang/include/flang/Optimizer/Dialect/FIRType.h
    flang/include/flang/Optimizer/Support/Utils.h
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    flang/lib/Optimizer/Dialect/FIRType.cpp
    flang/lib/Optimizer/Support/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 3a0254ab4de8e..83486ff614101 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -261,6 +261,24 @@ inline fir::SequenceType unwrapUntilSeqType(mlir::Type t) {
   }
 }
 
+/// Unwrap the referential and sequential outer types (if any) of the element
+/// type of `boxTy`; return it as a fir::RecordType, or null if it is not one.
+inline fir::RecordType unwrapIfDerived(fir::BaseBoxType boxTy) {
+  return fir::unwrapSequenceType(fir::unwrapRefType(boxTy.getEleTy()))
+      .dyn_cast<fir::RecordType>();
+}
+
+/// True iff `boxTy` wraps a fir::RecordType having at least one LEN param.
+inline bool isDerivedTypeWithLenParams(fir::BaseBoxType boxTy) {
+  fir::RecordType derivedTy = unwrapIfDerived(boxTy);
+  return derivedTy ? derivedTy.getNumLenParams() != 0 : false;
+}
+
+/// True iff the unwrapped element type of `boxTy` is a fir::RecordType.
+inline bool isDerivedType(fir::BaseBoxType boxTy) {
+  return bool(unwrapIfDerived(boxTy));
+}
+
 #ifndef NDEBUG
 // !fir.ptr<X> and !fir.heap<X> where X is !fir.ptr, !fir.heap, or !fir.ref
 // is undefined and disallowed.
@@ -300,6 +318,13 @@ bool isPolymorphicType(mlir::Type ty);
 /// value.
 bool isUnlimitedPolymorphicType(mlir::Type ty);
 
+/// Return true iff `boxTy` wraps a record type or an unlimited polymorphic
+/// entity. This is type-based only: it is also true for unlimited polymorphic
+/// entities with an intrinsic dynamic type, which may not have an addendum.
+inline bool boxHasAddendum(fir::BaseBoxType boxTy) {
+  return isDerivedType(boxTy) || fir::isUnlimitedPolymorphicType(boxTy);
+}
+
 /// Return the inner type of the given type.
 mlir::Type unwrapInnerType(mlir::Type ty);
 

diff  --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h
index d25a4005f4017..7d06f56274eac 100644
--- a/flang/include/flang/Optimizer/Support/Utils.h
+++ b/flang/include/flang/Optimizer/Support/Utils.h
@@ -15,8 +15,12 @@
 
 #include "flang/Common/default-kinds.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringRef.h"
 
 namespace fir {
 /// Return the integer value of a arith::ConstantOp.
@@ -24,6 +28,11 @@ inline std::int64_t toInt(mlir::arith::ConstantOp cop) {
   return cop.getValue().cast<mlir::IntegerAttr>().getValue().getSExtValue();
 }
 
+/// Reconstruct binding tables (table name -> method -> index) for dispatch.
+using BindingTable = llvm::DenseMap<llvm::StringRef, unsigned>;
+using BindingTables = llvm::DenseMap<llvm::StringRef, BindingTable>;
+void buildBindingTables(BindingTables &bindingTables, mlir::ModuleOp mod);
+
 // Translate front-end KINDs for use in the IR and code gen.
 inline std::vector<fir::KindTy>
 fromDefaultKinds(const Fortran::common::IntrinsicTypeDefaultKinds &defKinds) {

diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 75f95f213c87a..51df08bffad33 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -16,8 +16,10 @@
 #include "flang/ISO_Fortran_binding.h"
 #include "flang/Optimizer/Dialect/FIRAttr.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Support/InternalNames.h"
 #include "flang/Optimizer/Support/TypeCode.h"
+#include "flang/Optimizer/Support/Utils.h"
 #include "flang/Semantics/runtime-type-info.h"
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
@@ -50,9 +52,6 @@ namespace fir {
 // fir::LLVMTypeConverter for converting to LLVM IR dialect types.
 #include "TypeConverter.h"
 
-using BindingTable = llvm::DenseMap<llvm::StringRef, unsigned>;
-using BindingTables = llvm::DenseMap<llvm::StringRef, BindingTable>;
-
 // TODO: This should really be recovered from the specified target.
 static constexpr unsigned defaultAlign = 8;
 
@@ -106,7 +105,7 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
 public:
   explicit FIROpConversion(fir::LLVMTypeConverter &lowering,
                            const fir::FIRToLLVMPassOptions &options,
-                           const BindingTables &bindingTables)
+                           const fir::BindingTables &bindingTables)
       : mlir::ConvertOpToLLVMPattern<FromOp>(lowering), options(options),
         bindingTables(bindingTables) {}
 
@@ -359,7 +358,7 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
   }
 
   const fir::FIRToLLVMPassOptions &options;
-  const BindingTables &bindingTables;
+  const fir::BindingTables &bindingTables;
 };
 
 /// FIR conversion pattern template
@@ -993,7 +992,7 @@ struct DispatchOpConversion : public FIROpConversion<fir::DispatchOp> {
              << "cannot find binding table for " << recordType.getName();
 
     // Lookup for the binding.
-    const BindingTable &bindingTable = bindingsIter->second;
+    const fir::BindingTable &bindingTable = bindingsIter->second;
     auto bindingIter = bindingTable.find(dispatch.getMethod());
     if (bindingIter == bindingTable.end())
       return emitError(loc)
@@ -1336,22 +1335,6 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
     return CFI_attribute_other;
   }
 
-  static fir::RecordType unwrapIfDerived(fir::BaseBoxType boxTy) {
-    return fir::unwrapSequenceType(fir::dyn_cast_ptrOrBoxEleTy(boxTy))
-        .template dyn_cast<fir::RecordType>();
-  }
-  static bool isDerivedTypeWithLenParams(fir::BaseBoxType boxTy) {
-    auto recTy = unwrapIfDerived(boxTy);
-    return recTy && recTy.getNumLenParams() > 0;
-  }
-  static bool isDerivedType(fir::BaseBoxType boxTy) {
-    return static_cast<bool>(unwrapIfDerived(boxTy));
-  }
-  static bool hasAddendum(fir::BaseBoxType boxTy) {
-    return static_cast<bool>(unwrapIfDerived(boxTy)) ||
-           fir::isUnlimitedPolymorphicType(boxTy);
-  }
-
   // Get the element size and CFI type code of the boxed value.
   std::tuple<mlir::Value, mlir::Value> getSizeAndTypeCode(
       mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
@@ -1571,7 +1554,7 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
     descriptor =
         insertField(rewriter, loc, descriptor, {kAttributePosInBox},
                     this->genI32Constant(loc, rewriter, getCFIAttr(boxTy)));
-    const bool hasAddendum = isDerivedType(boxTy) || isUnlimitedPolymorphic;
+    const bool hasAddendum = fir::boxHasAddendum(boxTy);
     descriptor =
         insertField(rewriter, loc, descriptor, {kF18AddendumPosInBox},
                     this->genI32Constant(loc, rewriter, hasAddendum ? 1 : 0));
@@ -1591,8 +1574,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
                 loc, ::getVoidPtrType(mod.getContext()));
           }
         } else {
-          typeDesc =
-              getTypeDescriptor(mod, rewriter, loc, unwrapIfDerived(boxTy));
+          typeDesc = getTypeDescriptor(mod, rewriter, loc,
+                                       fir::unwrapIfDerived(boxTy));
         }
       }
       if (typeDesc)
@@ -1674,7 +1657,7 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
       // TODO: For initial box that are unlimited polymorphic entities, this
       // code must be made conditional because unlimited polymorphic entities
       // with intrinsic type spec does not have addendum.
-      if (hasAddendum(inputBoxTy))
+      if (fir::boxHasAddendum(inputBoxTy))
         typeDesc = this->loadTypeDescAddress(loc, box.getBox().getType(),
                                              loweredBox, rewriter);
     }
@@ -1826,7 +1809,7 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
         /*rank=*/0, /*lenParams=*/operands.drop_front(1), sourceBox,
         sourceBoxType);
     dest = insertBaseAddress(rewriter, embox.getLoc(), dest, operands[0]);
-    if (isDerivedTypeWithLenParams(boxTy)) {
+    if (fir::isDerivedTypeWithLenParams(boxTy)) {
       TODO(embox.getLoc(),
            "fir.embox codegen of derived with length parameters");
       return mlir::failure();
@@ -2010,7 +1993,7 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
                              fieldIndices, substringOffset);
     }
     dest = insertBaseAddress(rewriter, loc, dest, base);
-    if (isDerivedTypeWithLenParams(boxTy))
+    if (fir::isDerivedTypeWithLenParams(boxTy))
       TODO(loc, "fir.embox codegen of derived with length parameters");
 
     mlir::Value result =
@@ -3670,7 +3653,7 @@ template <typename FromOp>
 struct MustBeDeadConversion : public FIROpConversion<FromOp> {
   explicit MustBeDeadConversion(fir::LLVMTypeConverter &lowering,
                                 const fir::FIRToLLVMPassOptions &options,
-                                const BindingTables &bindingTables)
+                                const fir::BindingTables &bindingTables)
       : FIROpConversion<FromOp>(lowering, options, bindingTables) {}
   using OpAdaptor = typename FromOp::Adaptor;
 
@@ -3781,24 +3764,8 @@ class FIRToLLVMLowering
     if (mlir::failed(runPipeline(mathConvertionPM, mod)))
       return signalPassFailure();
 
-    // Reconstruct binding tables for dynamic dispatch. The binding tables
-    // are defined in FIR from lowering as fir.dispatch_table operation.
-    // Go through each binding tables and store the procedure name
-    // and binding index for later use by the fir.dispatch conversion pattern.
-    BindingTables bindingTables;
-    for (auto dispatchTableOp : mod.getOps<fir::DispatchTableOp>()) {
-      unsigned bindingIdx = 0;
-      BindingTable bindings;
-      if (dispatchTableOp.getRegion().empty()) {
-        bindingTables[dispatchTableOp.getSymName()] = bindings;
-        continue;
-      }
-      for (auto dtEntry : dispatchTableOp.getBlock().getOps<fir::DTEntryOp>()) {
-        bindings[dtEntry.getMethod()] = bindingIdx;
-        ++bindingIdx;
-      }
-      bindingTables[dispatchTableOp.getSymName()] = bindings;
-    }
+    fir::BindingTables bindingTables;
+    fir::buildBindingTables(bindingTables, mod);
 
     auto *context = getModule().getContext();
     fir::LLVMTypeConverter typeConverter{getModule(),

diff  --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index decb93f3d55e3..17fe1179162d1 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -219,12 +219,8 @@ mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) {
   return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
       .Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
             fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
-      .Case<fir::BaseBoxType>([](auto p) {
-        auto eleTy = p.getEleTy();
-        if (auto ty = fir::dyn_cast_ptrEleTy(eleTy))
-          return ty;
-        return eleTy;
-      })
+      .Case<fir::BaseBoxType>(
+          [](auto p) { return unwrapRefType(p.getEleTy()); })
       .Default([](mlir::Type) { return mlir::Type{}; });
 }
 

diff  --git a/flang/lib/Optimizer/Support/CMakeLists.txt b/flang/lib/Optimizer/Support/CMakeLists.txt
index 6a1c004ac88a4..b878a1f86b5d1 100644
--- a/flang/lib/Optimizer/Support/CMakeLists.txt
+++ b/flang/lib/Optimizer/Support/CMakeLists.txt
@@ -5,6 +5,7 @@ add_flang_library(FIRSupport
   InitFIR.cpp
   InternalNames.cpp
   KindMapping.cpp
+  Utils.cpp
 
   DEPENDS
   FIROpsIncGen

diff  --git a/flang/lib/Optimizer/Support/Utils.cpp b/flang/lib/Optimizer/Support/Utils.cpp
new file mode 100644
index 0000000000000..e973bf4c7049b
--- /dev/null
+++ b/flang/lib/Optimizer/Support/Utils.cpp
@@ -0,0 +1,36 @@
+//===-- Utils.cpp ---------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Support/Utils.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+
+namespace fir {
+void buildBindingTables(BindingTables &bindingTables, mlir::ModuleOp mod) {
+  // The binding tables are defined in FIR from lowering as fir.dispatch_table
+  // operations. Go through each binding table and store the procedure name and
+  // binding index for later use by the fir.dispatch conversion pattern.
+  for (auto dispatchTableOp : mod.getOps<fir::DispatchTableOp>()) {
+    BindingTable bindings;
+    // A dispatch table with an empty region has no block (and no bindings);
+    // it still gets an entry, so every dispatch table symbol is present.
+    if (!dispatchTableOp.getRegion().empty()) {
+      // Bindings are numbered in the order the DTEntryOp entries appear.
+      unsigned bindingIdx = 0;
+      for (auto dtEntry : dispatchTableOp.getBlock().getOps<fir::DTEntryOp>())
+        bindings[dtEntry.getMethod()] = bindingIdx++;
+    }
+    // Move the finished table into place instead of copying the DenseMap;
+    // an existing entry with the same name is overwritten.
+    bindingTables[dispatchTableOp.getSymName()] = std::move(bindings);
+  }
+}
+} // namespace fir


        


More information about the flang-commits mailing list