[flang-commits] [flang] ff761f2 - [flang] Move fir.select_type into the PolymorphicOpConversion pass

via flang-commits flang-commits at lists.llvm.org
Wed Mar 1 11:34:12 PST 2023


Author: Renaud-K
Date: 2023-03-01T11:33:31-08:00
New Revision: ff761f2ce49ae25cdd46459b44abe6ed78ff64bc

URL: https://github.com/llvm/llvm-project/commit/ff761f2ce49ae25cdd46459b44abe6ed78ff64bc
DIFF: https://github.com/llvm/llvm-project/commit/ff761f2ce49ae25cdd46459b44abe6ed78ff64bc.diff

LOG: [flang] Move fir.select_type into the PolymorphicOpConversion pass
https://reviews.llvm.org/D144921

Added: 
    flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp

Modified: 
    flang/include/flang/Optimizer/Transforms/Passes.h
    flang/include/flang/Optimizer/Transforms/Passes.td
    flang/include/flang/Tools/CLOptions.inc
    flang/lib/Optimizer/Transforms/CMakeLists.txt
    flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
    flang/test/Driver/bbc-mlir-pass-pipeline.f90
    flang/test/Driver/mlir-pass-pipeline.f90
    flang/test/Fir/basic-program.fir
    flang/test/Lower/select-type.f90

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 55f000f067b5..8af14f8013ab 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -42,6 +42,7 @@ namespace fir {
 #define GEN_PASS_DECL_MEMORYALLOCATIONOPT
 #define GEN_PASS_DECL_SIMPLIFYREGIONLITE
 #define GEN_PASS_DECL_ALGEBRAICSIMPLIFICATION
+#define GEN_PASS_DECL_POLYMORPHICOPCONVERSION
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 
 std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
@@ -68,6 +69,7 @@ std::unique_ptr<mlir::Pass> createSimplifyRegionLitePass();
 std::unique_ptr<mlir::Pass> createAlgebraicSimplificationPass();
 std::unique_ptr<mlir::Pass>
 createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config);
+std::unique_ptr<mlir::Pass> createPolymorphicOpConversionPass();
 
 // declarative passes
 #define GEN_PASS_REGISTRATION

diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index dcec17f622f6..b8ad0243b6af 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -271,4 +271,18 @@ def AlgebraicSimplification : Pass<"flang-algebraic-simplification"> {
   let constructor = "::fir::createAlgebraicSimplificationPass()";
 }
 
+def PolymorphicOpConversion : Pass<"fir-polymorphic-op", "::mlir::func::FuncOp"> {
+  let summary =
+    "Simplify operations on polymorphic types";
+  let description = [{
+    This pass breaks up the lowering of operations on polymorphic types by 
+    introducing an intermediate FIR level that simplifies code geneation. 
+  }];
+  let constructor = "::fir::createPolymorphicOpConversionPass()";
+  let dependentDialects = [
+    "fir::FIROpsDialect", "mlir::func::FuncDialect"
+  ];
+}
+
+  
 #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES

diff  --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 1015b9f86dc7..932468613871 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -199,6 +199,9 @@ inline void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
   pm.addPass(fir::createSimplifyRegionLitePass());
   pm.addPass(mlir::createCSEPass());
 
+  // Polymorphic types
+  pm.addPass(fir::createPolymorphicOpConversionPass());
+
   // convert control flow to CFG form
   fir::addCfgConversionPass(pm);
   pm.addPass(mlir::createConvertSCFToCFPass());

diff  --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 737011bde37c..2db16de2187a 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_flang_library(FIRTransforms
   AlgebraicSimplification.cpp
   SimplifyIntrinsics.cpp
   AddDebugFoundation.cpp
+  PolymorphicOpConversion.cpp
 
   DEPENDS
   FIRBuilder

diff  --git a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
index d334866939e0..d13b488cfda5 100644
--- a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
+++ b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
@@ -22,7 +22,6 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/Support/CommandLine.h"
-#include <mutex>
 
 namespace fir {
 #define GEN_PASS_DEF_CFGCONVERSION
@@ -308,278 +307,20 @@ class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> {
   }
 };
 
-/// SelectTypeOp converted to an if-then-else chain
-///
-/// This lowers the test conditions to calls into the runtime.
-class CfgSelectTypeConv : public OpConversionPattern<fir::SelectTypeOp> {
-public:
-  using OpConversionPattern<fir::SelectTypeOp>::OpConversionPattern;
-
-  CfgSelectTypeConv(mlir::MLIRContext *ctx, std::mutex *moduleMutex)
-      : mlir::OpConversionPattern<fir::SelectTypeOp>(ctx),
-        moduleMutex(moduleMutex) {}
-
-  mlir::LogicalResult
-  matchAndRewrite(fir::SelectTypeOp selectType, OpAdaptor adaptor,
-                  mlir::ConversionPatternRewriter &rewriter) const override {
-    auto operands = adaptor.getOperands();
-    auto typeGuards = selectType.getCases();
-    unsigned typeGuardNum = typeGuards.size();
-    auto selector = selectType.getSelector();
-    auto loc = selectType.getLoc();
-    auto mod = selectType.getOperation()->getParentOfType<mlir::ModuleOp>();
-    fir::KindMapping kindMap = fir::getKindMapping(mod);
-
-    // Order type guards so the condition and branches are done to respect the
-    // Execution of SELECT TYPE construct as described in the Fortran 2018
-    // standard 11.1.11.2 point 4.
-    // 1. If a TYPE IS type guard statement matches the selector, the block
-    //    following that statement is executed.
-    // 2. Otherwise, if exactly one CLASS IS type guard statement matches the
-    //    selector, the block following that statement is executed.
-    // 3. Otherwise, if several CLASS IS type guard statements match the
-    //    selector, one of these statements will inevitably specify a type that
-    //    is an extension of all the types specified in the others; the block
-    //    following that statement is executed.
-    // 4. Otherwise, if there is a CLASS DEFAULT type guard statement, the block
-    //    following that statement is executed.
-    // 5. Otherwise, no block is executed.
-
-    llvm::SmallVector<unsigned> orderedTypeGuards;
-    llvm::SmallVector<unsigned> orderedClassIsGuards;
-    unsigned defaultGuard = typeGuardNum - 1;
-
-    // The following loop go through the type guards in the fir.select_type
-    // operation and sort them into two lists.
-    // - All the TYPE IS type guard are added in order to the orderedTypeGuards
-    //   list. This list is used at the end to generate the if-then-else ladder.
-    // - CLASS IS type guard are added in a separate list. If a CLASS IS type
-    //   guard type extends a type already present, the type guard is inserted
-    //   before in the list to respect point 3. above. Otherwise it is just
-    //   added in order at the end.
-    for (unsigned t = 0; t < typeGuardNum; ++t) {
-      if (auto a = typeGuards[t].dyn_cast<fir::ExactTypeAttr>()) {
-        orderedTypeGuards.push_back(t);
-        continue;
-      }
-
-      if (auto a = typeGuards[t].dyn_cast<fir::SubclassAttr>()) {
-        if (auto recTy = a.getType().dyn_cast<fir::RecordType>()) {
-          auto dt = mod.lookupSymbol<fir::DispatchTableOp>(recTy.getName());
-          assert(dt && "dispatch table not found");
-          llvm::SmallSet<llvm::StringRef, 4> ancestors =
-              collectAncestors(dt, mod);
-          if (!ancestors.empty()) {
-            auto it = orderedClassIsGuards.begin();
-            while (it != orderedClassIsGuards.end()) {
-              fir::SubclassAttr sAttr =
-                  typeGuards[*it].dyn_cast<fir::SubclassAttr>();
-              if (auto ty = sAttr.getType().dyn_cast<fir::RecordType>()) {
-                if (ancestors.contains(ty.getName()))
-                  break;
-              }
-              ++it;
-            }
-            if (it != orderedClassIsGuards.end()) {
-              // Parent type is present so place it before.
-              orderedClassIsGuards.insert(it, t);
-              continue;
-            }
-          }
-        }
-        orderedClassIsGuards.push_back(t);
-      }
-    }
-    orderedTypeGuards.append(orderedClassIsGuards);
-    orderedTypeGuards.push_back(defaultGuard);
-    assert(orderedTypeGuards.size() == typeGuardNum &&
-           "ordered type guard size doesn't match number of type guards");
-
-    for (unsigned idx : orderedTypeGuards) {
-      auto *dest = selectType.getSuccessor(idx);
-      std::optional<mlir::ValueRange> destOps =
-          selectType.getSuccessorOperands(operands, idx);
-      if (typeGuards[idx].dyn_cast<mlir::UnitAttr>())
-        rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(selectType, dest);
-      else if (mlir::failed(genTypeLadderStep(loc, selector, typeGuards[idx],
-                                              dest, destOps, mod, rewriter,
-                                              kindMap)))
-        return mlir::failure();
-    }
-    return mlir::success();
-  }
-
-  llvm::SmallSet<llvm::StringRef, 4>
-  collectAncestors(fir::DispatchTableOp dt, mlir::ModuleOp mod) const {
-    llvm::SmallSet<llvm::StringRef, 4> ancestors;
-    if (!dt.getParent().has_value())
-      return ancestors;
-    while (dt.getParent().has_value()) {
-      ancestors.insert(*dt.getParent());
-      dt = mod.lookupSymbol<fir::DispatchTableOp>(*dt.getParent());
-    }
-    return ancestors;
-  }
-
-  // Generate comparison of type descriptor addresses.
-  mlir::Value genTypeDescCompare(mlir::Location loc, mlir::Value selector,
-                                 mlir::Type ty, mlir::ModuleOp mod,
-                                 mlir::PatternRewriter &rewriter) const {
-    assert(ty.isa<fir::RecordType>() && "expect fir.record type");
-    fir::RecordType recTy = ty.dyn_cast<fir::RecordType>();
-    std::string typeDescName =
-        fir::NameUniquer::getTypeDescriptorName(recTy.getName());
-    auto typeDescGlobal = mod.lookupSymbol<fir::GlobalOp>(typeDescName);
-    if (!typeDescGlobal)
-      return {};
-    auto typeDescAddr = rewriter.create<fir::AddrOfOp>(
-        loc, fir::ReferenceType::get(typeDescGlobal.getType()),
-        typeDescGlobal.getSymbol());
-    auto intPtrTy = rewriter.getIndexType();
-    mlir::Type tdescType =
-        fir::TypeDescType::get(mlir::NoneType::get(rewriter.getContext()));
-    mlir::Value selectorTdescAddr =
-        rewriter.create<fir::BoxTypeDescOp>(loc, tdescType, selector);
-    auto typeDescInt =
-        rewriter.create<fir::ConvertOp>(loc, intPtrTy, typeDescAddr);
-    auto selectorTdescInt =
-        rewriter.create<fir::ConvertOp>(loc, intPtrTy, selectorTdescAddr);
-    return rewriter.create<mlir::arith::CmpIOp>(
-        loc, mlir::arith::CmpIPredicate::eq, typeDescInt, selectorTdescInt);
-  }
-
-  static int getTypeCode(mlir::Type ty, fir::KindMapping &kindMap) {
-    if (auto intTy = ty.dyn_cast<mlir::IntegerType>())
-      return fir::integerBitsToTypeCode(intTy.getWidth());
-    if (auto floatTy = ty.dyn_cast<mlir::FloatType>())
-      return fir::realBitsToTypeCode(floatTy.getWidth());
-    if (auto logicalTy = ty.dyn_cast<fir::LogicalType>())
-      return fir::logicalBitsToTypeCode(
-          kindMap.getLogicalBitsize(logicalTy.getFKind()));
-    if (fir::isa_complex(ty)) {
-      if (auto cmplxTy = ty.dyn_cast<mlir::ComplexType>())
-        return fir::complexBitsToTypeCode(
-            cmplxTy.getElementType().cast<mlir::FloatType>().getWidth());
-      auto cmplxTy = ty.cast<fir::ComplexType>();
-      return fir::complexBitsToTypeCode(
-          kindMap.getRealBitsize(cmplxTy.getFKind()));
-    }
-    if (auto charTy = ty.dyn_cast<fir::CharacterType>())
-      return fir::characterBitsToTypeCode(
-          kindMap.getCharacterBitsize(charTy.getFKind()));
-    return 0;
-  }
-
-  mlir::LogicalResult genTypeLadderStep(mlir::Location loc,
-                                        mlir::Value selector,
-                                        mlir::Attribute attr, mlir::Block *dest,
-                                        std::optional<mlir::ValueRange> destOps,
-                                        mlir::ModuleOp mod,
-                                        mlir::PatternRewriter &rewriter,
-                                        fir::KindMapping &kindMap) const {
-    mlir::Value cmp;
-    // TYPE IS type guard comparison are all done inlined.
-    if (auto a = attr.dyn_cast<fir::ExactTypeAttr>()) {
-      if (fir::isa_trivial(a.getType()) ||
-          a.getType().isa<fir::CharacterType>()) {
-        // For type guard statement with Intrinsic type spec the type code of
-        // the descriptor is compared.
-        int code = getTypeCode(a.getType(), kindMap);
-        if (code == 0)
-          return mlir::emitError(loc)
-                 << "type code unavailable for " << a.getType();
-        mlir::Value typeCode = rewriter.create<mlir::arith::ConstantOp>(
-            loc, rewriter.getI8IntegerAttr(code));
-        mlir::Value selectorTypeCode = rewriter.create<fir::BoxTypeCodeOp>(
-            loc, rewriter.getI8Type(), selector);
-        cmp = rewriter.create<mlir::arith::CmpIOp>(
-            loc, mlir::arith::CmpIPredicate::eq, selectorTypeCode, typeCode);
-      } else {
-        // Flang inline the kind parameter in the type descriptor so we can
-        // directly check if the type descriptor addresses are identical for
-        // the TYPE IS type guard statement.
-        mlir::Value res =
-            genTypeDescCompare(loc, selector, a.getType(), mod, rewriter);
-        if (!res)
-          return mlir::failure();
-        cmp = res;
-      }
-      // CLASS IS type guard statement is done with a runtime call.
-    } else if (auto a = attr.dyn_cast<fir::SubclassAttr>()) {
-      // Retrieve the type descriptor from the type guard statement record type.
-      assert(a.getType().isa<fir::RecordType>() && "expect fir.record type");
-      fir::RecordType recTy = a.getType().dyn_cast<fir::RecordType>();
-      std::string typeDescName =
-          fir::NameUniquer::getTypeDescriptorName(recTy.getName());
-      auto typeDescGlobal = mod.lookupSymbol<fir::GlobalOp>(typeDescName);
-      auto typeDescAddr = rewriter.create<fir::AddrOfOp>(
-          loc, fir::ReferenceType::get(typeDescGlobal.getType()),
-          typeDescGlobal.getSymbol());
-      mlir::Type typeDescTy = ReferenceType::get(rewriter.getNoneType());
-      mlir::Value typeDesc =
-          rewriter.create<ConvertOp>(loc, typeDescTy, typeDescAddr);
-
-      // Prepare the selector descriptor for the runtime call.
-      mlir::Type descNoneTy = fir::BoxType::get(rewriter.getNoneType());
-      mlir::Value descSelector =
-          rewriter.create<ConvertOp>(loc, descNoneTy, selector);
-
-      // Generate runtime call.
-      llvm::StringRef fctName = RTNAME_STRING(ClassIs);
-      mlir::func::FuncOp callee;
-      {
-        // Since conversion is done in parallel for each fir.select_type
-        // operation, the runtime function insertion must be threadsafe.
-        std::lock_guard<std::mutex> lock(*moduleMutex);
-        callee =
-            fir::createFuncOp(rewriter.getUnknownLoc(), mod, fctName,
-                              rewriter.getFunctionType({descNoneTy, typeDescTy},
-                                                       rewriter.getI1Type()));
-      }
-      cmp = rewriter
-                .create<fir::CallOp>(loc, callee,
-                                     mlir::ValueRange{descSelector, typeDesc})
-                .getResult(0);
-    }
-
-    auto *thisBlock = rewriter.getInsertionBlock();
-    auto *newBlock =
-        rewriter.createBlock(dest->getParent(), mlir::Region::iterator(dest));
-    rewriter.setInsertionPointToEnd(thisBlock);
-    if (destOps.has_value())
-      rewriter.create<mlir::cf::CondBranchOp>(loc, cmp, dest, destOps.value(),
-                                              newBlock, std::nullopt);
-    else
-      rewriter.create<mlir::cf::CondBranchOp>(loc, cmp, dest, newBlock);
-    rewriter.setInsertionPointToEnd(newBlock);
-    return mlir::success();
-  }
-
-private:
-  // Mutex used to guard insertion of mlir::func::FuncOp in the module.
-  std::mutex *moduleMutex;
-};
-
 /// Convert FIR structured control flow ops to CFG ops.
 class CfgConversion : public fir::impl::CFGConversionBase<CfgConversion> {
 public:
-  mlir::LogicalResult initialize(mlir::MLIRContext *ctx) override {
-    moduleMutex = new std::mutex();
-    return mlir::success();
-  }
-
   void runOnOperation() override {
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
     patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
         context, forceLoopToExecuteOnce);
-    patterns.insert<CfgSelectTypeConv>(context, moduleMutex);
     mlir::ConversionTarget target(*context);
     target.addLegalDialect<mlir::AffineDialect, mlir::cf::ControlFlowDialect,
                            FIROpsDialect, mlir::func::FuncDialect>();
 
     // apply the patterns
-    target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp, SelectTypeOp>();
+    target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                   std::move(patterns)))) {
@@ -588,9 +329,6 @@ class CfgConversion : public fir::impl::CFGConversionBase<CfgConversion> {
       signalPassFailure();
     }
   }
-
-private:
-  std::mutex *moduleMutex;
 };
 } // namespace
 

diff  --git a/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp
new file mode 100644
index 000000000000..580e3ca2191a
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp
@@ -0,0 +1,346 @@
+//===-- PolymorphicOpConversion.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/Support/FIRContext.h"
+#include "flang/Optimizer/Support/InternalNames.h"
+#include "flang/Optimizer/Support/KindMapping.h"
+#include "flang/Optimizer/Support/TypeCode.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "flang/Runtime/derived-api.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/CommandLine.h"
+#include <mutex>
+
+namespace fir {
+#define GEN_PASS_DEF_POLYMORPHICOPCONVERSION
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace fir;
+using namespace mlir;
+
+namespace {
+
+/// SelectTypeOp converted to an if-then-else chain
+///
+/// This lowers the test conditions to calls into the runtime.
+class SelectTypeConv : public OpConversionPattern<fir::SelectTypeOp> {
+public:
+  using OpConversionPattern<fir::SelectTypeOp>::OpConversionPattern;
+
+  SelectTypeConv(mlir::MLIRContext *ctx, std::mutex *moduleMutex)
+      : mlir::OpConversionPattern<fir::SelectTypeOp>(ctx),
+        moduleMutex(moduleMutex) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::SelectTypeOp selectType, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override;
+
+private:
+  // Generate comparison of type descriptor addresses.
+  mlir::Value genTypeDescCompare(mlir::Location loc, mlir::Value selector,
+                                 mlir::Type ty, mlir::ModuleOp mod,
+                                 mlir::PatternRewriter &rewriter) const;
+
+  static int getTypeCode(mlir::Type ty, fir::KindMapping &kindMap);
+
+  mlir::LogicalResult genTypeLadderStep(mlir::Location loc,
+                                        mlir::Value selector,
+                                        mlir::Attribute attr, mlir::Block *dest,
+                                        std::optional<mlir::ValueRange> destOps,
+                                        mlir::ModuleOp mod,
+                                        mlir::PatternRewriter &rewriter,
+                                        fir::KindMapping &kindMap) const;
+
+  llvm::SmallSet<llvm::StringRef, 4> collectAncestors(fir::DispatchTableOp dt,
+                                                      mlir::ModuleOp mod) const;
+
+  // Mutex used to guard insertion of mlir::func::FuncOp in the module.
+  std::mutex *moduleMutex;
+};
+
+/// Convert FIR structured control flow ops to CFG ops.
+class PolymorphicOpConversion
+    : public fir::impl::PolymorphicOpConversionBase<PolymorphicOpConversion> {
+public:
+  mlir::LogicalResult initialize(mlir::MLIRContext *ctx) override {
+    moduleMutex = new std::mutex();
+    return mlir::success();
+  }
+
+  void runOnOperation() override {
+    auto *context = &getContext();
+    mlir::RewritePatternSet patterns(context);
+    patterns.insert<SelectTypeConv>(context, moduleMutex);
+    mlir::ConversionTarget target(*context);
+    target.addLegalDialect<mlir::AffineDialect, mlir::cf::ControlFlowDialect,
+                           FIROpsDialect, mlir::func::FuncDialect>();
+
+    // apply the patterns
+    target.addIllegalOp<SelectTypeOp>();
+    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+    if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
+                                                  std::move(patterns)))) {
+      mlir::emitError(mlir::UnknownLoc::get(context),
+                      "error in converting to CFG\n");
+      signalPassFailure();
+    }
+  }
+
+private:
+  std::mutex *moduleMutex;
+};
+} // namespace
+
+mlir::LogicalResult SelectTypeConv::matchAndRewrite(
+    fir::SelectTypeOp selectType, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  auto operands = adaptor.getOperands();
+  auto typeGuards = selectType.getCases();
+  unsigned typeGuardNum = typeGuards.size();
+  auto selector = selectType.getSelector();
+  auto loc = selectType.getLoc();
+  auto mod = selectType.getOperation()->getParentOfType<mlir::ModuleOp>();
+  fir::KindMapping kindMap = fir::getKindMapping(mod);
+
+  // Order type guards so the condition and branches are done to respect the
+  // Execution of SELECT TYPE construct as described in the Fortran 2018
+  // standard 11.1.11.2 point 4.
+  // 1. If a TYPE IS type guard statement matches the selector, the block
+  //    following that statement is executed.
+  // 2. Otherwise, if exactly one CLASS IS type guard statement matches the
+  //    selector, the block following that statement is executed.
+  // 3. Otherwise, if several CLASS IS type guard statements match the
+  //    selector, one of these statements will inevitably specify a type that
+  //    is an extension of all the types specified in the others; the block
+  //    following that statement is executed.
+  // 4. Otherwise, if there is a CLASS DEFAULT type guard statement, the block
+  //    following that statement is executed.
+  // 5. Otherwise, no block is executed.
+
+  llvm::SmallVector<unsigned> orderedTypeGuards;
+  llvm::SmallVector<unsigned> orderedClassIsGuards;
+  unsigned defaultGuard = typeGuardNum - 1;
+
+  // The following loop go through the type guards in the fir.select_type
+  // operation and sort them into two lists.
+  // - All the TYPE IS type guard are added in order to the orderedTypeGuards
+  //   list. This list is used at the end to generate the if-then-else ladder.
+  // - CLASS IS type guard are added in a separate list. If a CLASS IS type
+  //   guard type extends a type already present, the type guard is inserted
+  //   before in the list to respect point 3. above. Otherwise it is just
+  //   added in order at the end.
+  for (unsigned t = 0; t < typeGuardNum; ++t) {
+    if (auto a = typeGuards[t].dyn_cast<fir::ExactTypeAttr>()) {
+      orderedTypeGuards.push_back(t);
+      continue;
+    }
+
+    if (auto a = typeGuards[t].dyn_cast<fir::SubclassAttr>()) {
+      if (auto recTy = a.getType().dyn_cast<fir::RecordType>()) {
+        auto dt = mod.lookupSymbol<fir::DispatchTableOp>(recTy.getName());
+        assert(dt && "dispatch table not found");
+        llvm::SmallSet<llvm::StringRef, 4> ancestors =
+            collectAncestors(dt, mod);
+        if (!ancestors.empty()) {
+          auto it = orderedClassIsGuards.begin();
+          while (it != orderedClassIsGuards.end()) {
+            fir::SubclassAttr sAttr =
+                typeGuards[*it].dyn_cast<fir::SubclassAttr>();
+            if (auto ty = sAttr.getType().dyn_cast<fir::RecordType>()) {
+              if (ancestors.contains(ty.getName()))
+                break;
+            }
+            ++it;
+          }
+          if (it != orderedClassIsGuards.end()) {
+            // Parent type is present so place it before.
+            orderedClassIsGuards.insert(it, t);
+            continue;
+          }
+        }
+      }
+      orderedClassIsGuards.push_back(t);
+    }
+  }
+  orderedTypeGuards.append(orderedClassIsGuards);
+  orderedTypeGuards.push_back(defaultGuard);
+  assert(orderedTypeGuards.size() == typeGuardNum &&
+         "ordered type guard size doesn't match number of type guards");
+
+  for (unsigned idx : orderedTypeGuards) {
+    auto *dest = selectType.getSuccessor(idx);
+    std::optional<mlir::ValueRange> destOps =
+        selectType.getSuccessorOperands(operands, idx);
+    if (typeGuards[idx].dyn_cast<mlir::UnitAttr>())
+      rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(selectType, dest);
+    else if (mlir::failed(genTypeLadderStep(loc, selector, typeGuards[idx],
+                                            dest, destOps, mod, rewriter,
+                                            kindMap)))
+      return mlir::failure();
+  }
+  return mlir::success();
+}
+
+mlir::LogicalResult SelectTypeConv::genTypeLadderStep(
+    mlir::Location loc, mlir::Value selector, mlir::Attribute attr,
+    mlir::Block *dest, std::optional<mlir::ValueRange> destOps,
+    mlir::ModuleOp mod, mlir::PatternRewriter &rewriter,
+    fir::KindMapping &kindMap) const {
+  mlir::Value cmp;
+  // TYPE IS type guard comparison are all done inlined.
+  if (auto a = attr.dyn_cast<fir::ExactTypeAttr>()) {
+    if (fir::isa_trivial(a.getType()) ||
+        a.getType().isa<fir::CharacterType>()) {
+      // For type guard statement with Intrinsic type spec the type code of
+      // the descriptor is compared.
+      int code = getTypeCode(a.getType(), kindMap);
+      if (code == 0)
+        return mlir::emitError(loc)
+               << "type code unavailable for " << a.getType();
+      mlir::Value typeCode = rewriter.create<mlir::arith::ConstantOp>(
+          loc, rewriter.getI8IntegerAttr(code));
+      mlir::Value selectorTypeCode = rewriter.create<fir::BoxTypeCodeOp>(
+          loc, rewriter.getI8Type(), selector);
+      cmp = rewriter.create<mlir::arith::CmpIOp>(
+          loc, mlir::arith::CmpIPredicate::eq, selectorTypeCode, typeCode);
+    } else {
+      // Flang inline the kind parameter in the type descriptor so we can
+      // directly check if the type descriptor addresses are identical for
+      // the TYPE IS type guard statement.
+      mlir::Value res =
+          genTypeDescCompare(loc, selector, a.getType(), mod, rewriter);
+      if (!res)
+        return mlir::failure();
+      cmp = res;
+    }
+    // CLASS IS type guard statement is done with a runtime call.
+  } else if (auto a = attr.dyn_cast<fir::SubclassAttr>()) {
+    // Retrieve the type descriptor from the type guard statement record type.
+    assert(a.getType().isa<fir::RecordType>() && "expect fir.record type");
+    fir::RecordType recTy = a.getType().dyn_cast<fir::RecordType>();
+    std::string typeDescName =
+        fir::NameUniquer::getTypeDescriptorName(recTy.getName());
+    auto typeDescGlobal = mod.lookupSymbol<fir::GlobalOp>(typeDescName);
+    auto typeDescAddr = rewriter.create<fir::AddrOfOp>(
+        loc, fir::ReferenceType::get(typeDescGlobal.getType()),
+        typeDescGlobal.getSymbol());
+    mlir::Type typeDescTy = ReferenceType::get(rewriter.getNoneType());
+    mlir::Value typeDesc =
+        rewriter.create<ConvertOp>(loc, typeDescTy, typeDescAddr);
+
+    // Prepare the selector descriptor for the runtime call.
+    mlir::Type descNoneTy = fir::BoxType::get(rewriter.getNoneType());
+    mlir::Value descSelector =
+        rewriter.create<ConvertOp>(loc, descNoneTy, selector);
+
+    // Generate runtime call.
+    llvm::StringRef fctName = RTNAME_STRING(ClassIs);
+    mlir::func::FuncOp callee;
+    {
+      // Since conversion is done in parallel for each fir.select_type
+      // operation, the runtime function insertion must be threadsafe.
+      std::lock_guard<std::mutex> lock(*moduleMutex);
+      callee =
+          fir::createFuncOp(rewriter.getUnknownLoc(), mod, fctName,
+                            rewriter.getFunctionType({descNoneTy, typeDescTy},
+                                                     rewriter.getI1Type()));
+    }
+    cmp = rewriter
+              .create<fir::CallOp>(loc, callee,
+                                   mlir::ValueRange{descSelector, typeDesc})
+              .getResult(0);
+  }
+
+  auto *thisBlock = rewriter.getInsertionBlock();
+  auto *newBlock =
+      rewriter.createBlock(dest->getParent(), mlir::Region::iterator(dest));
+  rewriter.setInsertionPointToEnd(thisBlock);
+  if (destOps.has_value())
+    rewriter.create<mlir::cf::CondBranchOp>(loc, cmp, dest, destOps.value(),
+                                            newBlock, std::nullopt);
+  else
+    rewriter.create<mlir::cf::CondBranchOp>(loc, cmp, dest, newBlock);
+  rewriter.setInsertionPointToEnd(newBlock);
+  return mlir::success();
+}
+
+// Generate comparison of type descriptor addresses.
+mlir::Value
+SelectTypeConv::genTypeDescCompare(mlir::Location loc, mlir::Value selector,
+                                   mlir::Type ty, mlir::ModuleOp mod,
+                                   mlir::PatternRewriter &rewriter) const {
+  assert(ty.isa<fir::RecordType>() && "expect fir.record type");
+  fir::RecordType recTy = ty.dyn_cast<fir::RecordType>();
+  std::string typeDescName =
+      fir::NameUniquer::getTypeDescriptorName(recTy.getName());
+  auto typeDescGlobal = mod.lookupSymbol<fir::GlobalOp>(typeDescName);
+  if (!typeDescGlobal)
+    return {};
+  auto typeDescAddr = rewriter.create<fir::AddrOfOp>(
+      loc, fir::ReferenceType::get(typeDescGlobal.getType()),
+      typeDescGlobal.getSymbol());
+  auto intPtrTy = rewriter.getIndexType();
+  mlir::Type tdescType =
+      fir::TypeDescType::get(mlir::NoneType::get(rewriter.getContext()));
+  mlir::Value selectorTdescAddr =
+      rewriter.create<fir::BoxTypeDescOp>(loc, tdescType, selector);
+  auto typeDescInt =
+      rewriter.create<fir::ConvertOp>(loc, intPtrTy, typeDescAddr);
+  auto selectorTdescInt =
+      rewriter.create<fir::ConvertOp>(loc, intPtrTy, selectorTdescAddr);
+  return rewriter.create<mlir::arith::CmpIOp>(
+      loc, mlir::arith::CmpIPredicate::eq, typeDescInt, selectorTdescInt);
+}
+
+int SelectTypeConv::getTypeCode(mlir::Type ty, fir::KindMapping &kindMap) {
+  if (auto intTy = ty.dyn_cast<mlir::IntegerType>())
+    return fir::integerBitsToTypeCode(intTy.getWidth());
+  if (auto floatTy = ty.dyn_cast<mlir::FloatType>())
+    return fir::realBitsToTypeCode(floatTy.getWidth());
+  if (auto logicalTy = ty.dyn_cast<fir::LogicalType>())
+    return fir::logicalBitsToTypeCode(
+        kindMap.getLogicalBitsize(logicalTy.getFKind()));
+  if (fir::isa_complex(ty)) {
+    if (auto cmplxTy = ty.dyn_cast<mlir::ComplexType>())
+      return fir::complexBitsToTypeCode(
+          cmplxTy.getElementType().cast<mlir::FloatType>().getWidth());
+    auto cmplxTy = ty.cast<fir::ComplexType>();
+    return fir::complexBitsToTypeCode(
+        kindMap.getRealBitsize(cmplxTy.getFKind()));
+  }
+  if (auto charTy = ty.dyn_cast<fir::CharacterType>())
+    return fir::characterBitsToTypeCode(
+        kindMap.getCharacterBitsize(charTy.getFKind()));
+  return 0;
+}
+
+llvm::SmallSet<llvm::StringRef, 4>
+SelectTypeConv::collectAncestors(fir::DispatchTableOp dt,
+                                 mlir::ModuleOp mod) const {
+  llvm::SmallSet<llvm::StringRef, 4> ancestors;
+  if (!dt.getParent().has_value())
+    return ancestors;
+  while (dt.getParent().has_value()) {
+    ancestors.insert(*dt.getParent());
+    dt = mod.lookupSymbol<fir::DispatchTableOp>(*dt.getParent());
+  }
+  return ancestors;
+}
+
+std::unique_ptr<mlir::Pass> fir::createPolymorphicOpConversionPass() {
+  return std::make_unique<PolymorphicOpConversion>();
+}

diff  --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index c71c36a18cf3..243a620a9fd0 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -39,6 +39,7 @@
 ! CHECK-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 
 ! CHECK-NEXT: 'func.func' Pipeline
+! CHECK-NEXT:   PolymorphicOpConversion
 ! CHECK-NEXT:   CFGConversion
 
 ! CHECK-NEXT: SCFToControlFlow

diff  --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 9ccbe3195334..f569ddac8a39 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -42,6 +42,7 @@
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 
 ! ALL-NEXT: 'func.func' Pipeline
+! ALL-NEXT:   PolymorphicOpConversion
 ! ALL-NEXT:   CFGConversion
 
 ! ALL-NEXT: SCFToControlFlow

diff  --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index af7912ebcab2..78c1ab080db1 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -42,6 +42,7 @@ func.func @_QQmain() {
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 
 // PASSES-NEXT: 'func.func' Pipeline
+// PASSES-NEXT:   PolymorphicOpConversion
 // PASSES-NEXT:   CFGConversion
 
 // PASSES-NEXT: SCFToControlFlow

diff  --git a/flang/test/Lower/select-type.f90 b/flang/test/Lower/select-type.f90
index 3463cda4e9a9..ef4336fbd261 100644
--- a/flang/test/Lower/select-type.f90
+++ b/flang/test/Lower/select-type.f90
@@ -1,5 +1,5 @@
 ! RUN: bbc -polymorphic-type -emit-fir %s -o - | FileCheck %s
-! RUN: bbc -polymorphic-type -emit-fir %s -o - | fir-opt --cfg-conversion | FileCheck --check-prefix=CFG %s
+! RUN: bbc -polymorphic-type -emit-fir %s -o - | fir-opt --fir-polymorphic-op | FileCheck --check-prefix=CFG %s
 module select_type_lower_test
   type p1
     integer :: a


        


More information about the flang-commits mailing list