[flang-commits] [flang] 308fc3f - [flang] Lower select case statement

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Fri Mar 18 07:41:36 PDT 2022


Author: Valentin Clement
Date: 2022-03-18T15:41:29+01:00
New Revision: 308fc3f27797ce2b0dc01970d6fe2c6c9e1f55c7

URL: https://github.com/llvm/llvm-project/commit/308fc3f27797ce2b0dc01970d6fe2c6c9e1f55c7
DIFF: https://github.com/llvm/llvm-project/commit/308fc3f27797ce2b0dc01970d6fe2c6c9e1f55c7.diff

LOG: [flang] Lower select case statement

This patch adds lowering for the `select case`
statement.

This patch is part of the upstreaming effort from the fir-dev branch.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D122007

Co-authored-by: Jean Perier <jperier at nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>
Co-authored-by: V Donaldson <vdonaldson at nvidia.com>

Added: 
    flang/test/Lower/select-case-statement.f90

Modified: 
    flang/lib/Lower/Bridge.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 3f354b7868a3b..a4185a47318c7 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -29,6 +29,7 @@
 #include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Optimizer/Builder/Character.h"
 #include "flang/Optimizer/Builder/MutableBox.h"
+#include "flang/Optimizer/Builder/Runtime/Character.h"
 #include "flang/Optimizer/Builder/Runtime/Ragged.h"
 #include "flang/Optimizer/Dialect/FIRAttr.h"
 #include "flang/Optimizer/Support/FIRContext.h"
@@ -811,6 +812,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
            cat == Fortran::common::TypeCategory::Complex ||
            cat == Fortran::common::TypeCategory::Logical;
   }
+  static bool isLogicalCategory(Fortran::common::TypeCategory category) {
+    return Fortran::common::TypeCategory::Logical == category;
+  }
   bool isCharacterCategory(Fortran::common::TypeCategory cat) {
     return cat == Fortran::common::TypeCategory::Character;
   }
@@ -818,6 +822,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     return cat == Fortran::common::TypeCategory::Derived;
   }
 
+  /// Create a new block placed immediately before \p block and return it.
+  /// The builder's insertion point is left unchanged.
+  mlir::Block *insertBlock(mlir::Block *block) {
+    // createBlock() moves the insertion point; the guard restores it on exit.
+    mlir::OpBuilder::InsertionGuard guard(*builder);
+    return builder->createBlock(block);
+  }
+
   mlir::Block *blockOfLabel(Fortran::lower::pft::Evaluation &eval,
                             Fortran::parser::Label label) {
     const Fortran::lower::pft::LabelEvalMap &labelEvaluationMap =
@@ -1399,7 +1411,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   }
 
   void genFIR(const Fortran::parser::CaseConstruct &) {
-    TODO(toLocation(), "CaseConstruct lowering");
+    for (auto &nestedEval : getEval().getNestedEvaluations())
+      genFIR(nestedEval);
   }
 
   template <typename A>
@@ -1630,8 +1643,170 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     TODO(toLocation(), "OpenMPDeclarativeConstruct lowering");
   }
 
-  void genFIR(const Fortran::parser::SelectCaseStmt &) {
-    TODO(toLocation(), "SelectCaseStmt lowering");
+  /// Generate FIR for a SELECT CASE statement with a CHARACTER, INTEGER, or
+  /// LOGICAL selector (fir.select_case for structured INTEGER, else branches).
+  void genFIR(const Fortran::parser::SelectCaseStmt &stmt) {
+    Fortran::lower::pft::Evaluation &eval = getEval();
+    MLIRContext *context = builder->getContext();
+    mlir::Location loc = toLocation();
+    Fortran::lower::StatementContext stmtCtx;
+    const Fortran::lower::SomeExpr *expr = Fortran::semantics::GetExpr(
+        std::get<Fortran::parser::Scalar<Fortran::parser::Expr>>(stmt.t));
+    bool isCharSelector = isCharacterCategory(expr->GetType()->category());
+    bool isLogicalSelector = isLogicalCategory(expr->GetType()->category());
+    auto charValue = [&](const Fortran::lower::SomeExpr *expr) {
+      fir::ExtendedValue exv = genExprAddr(*expr, stmtCtx, &loc);
+      return exv.match(
+          [&](const fir::CharBoxValue &cbv) {
+            return fir::factory::CharacterExprHelper{*builder, loc}
+                .createEmboxChar(cbv.getAddr(), cbv.getLen());
+          },
+          [&](auto) {
+            fir::emitFatalError(loc, "not a character");
+            return mlir::Value{};
+          });
+    };
+    mlir::Value selector;
+    if (isCharSelector) {
+      selector = charValue(expr);
+    } else {
+      selector = createFIRExpr(loc, expr, stmtCtx);
+      if (isLogicalSelector)
+        selector = builder->createConvert(loc, builder->getI1Type(), selector);
+    }
+    mlir::Type selectType = selector.getType();
+    llvm::SmallVector<mlir::Attribute> attrList;
+    llvm::SmallVector<mlir::Value> valueList;
+    llvm::SmallVector<mlir::Block *> blockList;
+    mlir::Block *defaultBlock = eval.parentConstruct->constructExit->block;
+    using CaseValue = Fortran::parser::Scalar<Fortran::parser::ConstantExpr>;
+    auto addValue = [&](const CaseValue &caseValue) {
+      const Fortran::lower::SomeExpr *expr =
+          Fortran::semantics::GetExpr(caseValue.thing);
+      if (isCharSelector)
+        valueList.push_back(charValue(expr));
+      else if (isLogicalSelector)
+        valueList.push_back(builder->createConvert(
+            loc, selectType, createFIRExpr(loc, expr, stmtCtx)));
+      else
+        valueList.push_back(builder->createIntegerConstant(
+            loc, selectType, *Fortran::evaluate::ToInt64(*expr)));
+    };
+    for (Fortran::lower::pft::Evaluation *e = eval.controlSuccessor; e;
+         e = e->controlSuccessor) {
+      const auto *caseStmt = e->getIf<Fortran::parser::CaseStmt>();
+      assert(caseStmt && e->block && "missing CaseStmt or CaseStmt block");
+      const auto &caseSelector =
+          std::get<Fortran::parser::CaseSelector>(caseStmt->t);
+      const auto *caseValueRangeList =
+          std::get_if<std::list<Fortran::parser::CaseValueRange>>(
+              &caseSelector.u);
+      if (!caseValueRangeList) {
+        defaultBlock = e->block;
+        continue;
+      }
+      for (const Fortran::parser::CaseValueRange &caseValueRange :
+           *caseValueRangeList) {
+        blockList.push_back(e->block);
+        if (const auto *caseValue = std::get_if<CaseValue>(&caseValueRange.u)) {
+          attrList.push_back(fir::PointIntervalAttr::get(context));
+          addValue(*caseValue);
+          continue;
+        }
+        const auto &caseRange =
+            std::get<Fortran::parser::CaseValueRange::Range>(caseValueRange.u);
+        if (caseRange.lower && caseRange.upper) {
+          attrList.push_back(fir::ClosedIntervalAttr::get(context));
+          addValue(*caseRange.lower);
+          addValue(*caseRange.upper);
+        } else if (caseRange.lower) {
+          attrList.push_back(fir::LowerBoundAttr::get(context));
+          addValue(*caseRange.lower);
+        } else {
+          attrList.push_back(fir::UpperBoundAttr::get(context));
+          addValue(*caseRange.upper);
+        }
+      }
+    }
+    // Two LOGICAL case values cover .false. and .true.; the default is dead.
+    if (isLogicalSelector && attrList.size() == 2)
+      defaultBlock = eval.parentConstruct->constructExit->block;
+    attrList.push_back(mlir::UnitAttr::get(context));
+    blockList.push_back(defaultBlock);
+
+    // Generate a fir::SelectCaseOp.
+    // Explicit branch code is better for the LOGICAL type.  The CHARACTER type
+    // does not yet have downstream support, and also uses explicit branch code.
+    // The -no-structured-fir option can be used to force generation of INTEGER
+    // type branch code.
+    if (!isLogicalSelector && !isCharSelector && eval.lowerAsStructured()) {
+      // An INTEGER selector is an SSA value, so all temps that may have
+      // been generated while evaluating it can be cleaned up before the
+      // fir.select_case.
+      stmtCtx.finalize();
+      builder->create<fir::SelectCaseOp>(loc, selector, attrList, valueList,
+                                         blockList);
+      return;
+    }
+
+    // Generate a sequence of case value comparisons and branches.
+    auto caseValue = valueList.begin();
+    auto caseBlock = blockList.begin();
+    for (mlir::Attribute attr : attrList) {
+      if (attr.isa<mlir::UnitAttr>()) {
+        genFIRBranch(*caseBlock++);
+        break;
+      }
+      auto genCond = [&](mlir::Value rhs,
+                         mlir::arith::CmpIPredicate pred) -> mlir::Value {
+        if (!isCharSelector)
+          return builder->create<mlir::arith::CmpIOp>(loc, pred, selector, rhs);
+        fir::factory::CharacterExprHelper charHelper{*builder, loc};
+        std::pair<mlir::Value, mlir::Value> lhsVal =
+            charHelper.createUnboxChar(selector);
+        mlir::Value &lhsAddr = lhsVal.first;
+        mlir::Value &lhsLen = lhsVal.second;
+        std::pair<mlir::Value, mlir::Value> rhsVal =
+            charHelper.createUnboxChar(rhs);
+        mlir::Value &rhsAddr = rhsVal.first;
+        mlir::Value &rhsLen = rhsVal.second;
+        return fir::runtime::genCharCompare(*builder, loc, pred, lhsAddr,
+                                            lhsLen, rhsAddr, rhsLen);
+      };
+      mlir::Block *newBlock = insertBlock(*caseBlock);
+      if (attr.isa<fir::ClosedIntervalAttr>()) {
+        mlir::Block *newBlock2 = insertBlock(*caseBlock);
+        mlir::Value cond =
+            genCond(*caseValue++, mlir::arith::CmpIPredicate::sge);
+        genFIRConditionalBranch(cond, newBlock, newBlock2);
+        builder->setInsertionPointToEnd(newBlock);
+        mlir::Value cond2 =
+            genCond(*caseValue++, mlir::arith::CmpIPredicate::sle);
+        genFIRConditionalBranch(cond2, *caseBlock++, newBlock2);
+        builder->setInsertionPointToEnd(newBlock2);
+        continue;
+      }
+      mlir::arith::CmpIPredicate pred;
+      if (attr.isa<fir::PointIntervalAttr>()) {
+        pred = mlir::arith::CmpIPredicate::eq;
+      } else if (attr.isa<fir::LowerBoundAttr>()) {
+        pred = mlir::arith::CmpIPredicate::sge;
+      } else {
+        assert(attr.isa<fir::UpperBoundAttr>() && "unexpected predicate");
+        pred = mlir::arith::CmpIPredicate::sle;
+      }
+      mlir::Value cond = genCond(*caseValue++, pred);
+      genFIRConditionalBranch(cond, *caseBlock++, newBlock);
+      builder->setInsertionPointToEnd(newBlock);
+    }
+    assert(caseValue == valueList.end() && caseBlock == blockList.end() &&
+           "select case list mismatch");
+    // Clean-up the selector at the end of the construct if it is a temporary
+    // (which is possible with characters).
+    mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
+    builder->setInsertionPointToEnd(eval.parentConstruct->constructExit->block);
+    stmtCtx.finalize();
+    builder->restoreInsertionPoint(insertPt);
   }
 
   fir::ExtendedValue
@@ -2115,10 +2290,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     genFIRBranch(getEval().controlSuccessor->block);
   }
 
-  void genFIR(const Fortran::parser::CaseStmt &) {
-    TODO(toLocation(), "CaseStmt lowering");
-  }
-
   void genFIR(const Fortran::parser::ElseIfStmt &) {
     TODO(toLocation(), "ElseIfStmt lowering");
   }
@@ -2135,16 +2306,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     TODO(toLocation(), "EndMpSubprogramStmt lowering");
   }
 
-  void genFIR(const Fortran::parser::EndSelectStmt &) {
-    TODO(toLocation(), "EndSelectStmt lowering");
-  }
-
   // Nop statements - No code, or code is generated at the construct level.
   void genFIR(const Fortran::parser::AssociateStmt &) {}     // nop
+  void genFIR(const Fortran::parser::CaseStmt &) {}          // nop
   void genFIR(const Fortran::parser::ContinueStmt &) {}      // nop
   void genFIR(const Fortran::parser::EndAssociateStmt &) {}  // nop
   void genFIR(const Fortran::parser::EndFunctionStmt &) {}   // nop
   void genFIR(const Fortran::parser::EndIfStmt &) {}         // nop
+  void genFIR(const Fortran::parser::EndSelectStmt &) {}     // nop
   void genFIR(const Fortran::parser::EndSubroutineStmt &) {} // nop
   void genFIR(const Fortran::parser::EntryStmt &) {}         // nop
 
@@ -2168,6 +2337,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     TODO(toLocation(), "NamelistStmt lowering");
   }
 
+  /// Generate FIR for the Evaluation `eval`.
   void genFIR(Fortran::lower::pft::Evaluation &eval,
               bool unstructuredContext = true) {
     if (unstructuredContext) {
@@ -2181,6 +2351,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     setCurrentEval(eval);
     setCurrentPosition(eval.position);
     eval.visit([&](const auto &stmt) { genFIR(stmt); });
+
+    if (unstructuredContext && blockIsUnterminated()) {
+      // Exit from an unstructured IF or SELECT construct block.
+      Fortran::lower::pft::Evaluation *successor{};
+      if (eval.isActionStmt())
+        successor = eval.controlSuccessor;
+      else if (eval.isConstruct() &&
+               eval.getLastNestedEvaluation()
+                   .lexicalSuccessor->isIntermediateConstructStmt())
+        successor = eval.constructExit;
+      if (successor && successor->block)
+        genFIRBranch(successor->block);
+    }
   }
 
   //===--------------------------------------------------------------------===//

diff  --git a/flang/test/Lower/select-case-statement.f90 b/flang/test/Lower/select-case-statement.f90
new file mode 100644
index 0000000000000..2efbcc036dd67
--- /dev/null
+++ b/flang/test/Lower/select-case-statement.f90
@@ -0,0 +1,211 @@
+! RUN: bbc -emit-fir -o - %s | FileCheck %s
+
+  ! CHECK-LABEL: sinteger
+  function sinteger(n)
+    integer sinteger
+    nn = -88
+    ! CHECK: fir.select_case {{.*}} : i32
+    ! CHECK-SAME: upper, %c1
+    ! CHECK-SAME: point, %c2
+    ! CHECK-SAME: point, %c3
+    ! CHECK-SAME: interval, %c4{{.*}} %c5
+    ! CHECK-SAME: point, %c6
+    ! CHECK-SAME: point, %c7
+    ! CHECK-SAME: interval, %c8{{.*}} %c15
+    ! CHECK-SAME: lower, %c21
+    ! CHECK-SAME: unit
+    select case(n) ! structured INTEGER selector: a single fir.select_case
+    case (:1)
+      nn = 1
+    case (2)
+      nn = 2
+    case default ! listed mid-way; still emitted as the last (unit) entry
+      nn = 0
+    case (3)
+      nn = 3
+    case (4:5+1-1) ! bounds are constant expressions: 5+1-1 folds to 5
+      nn = 4
+    case (6)
+      nn = 6
+    case (7,8:15,21:) ! one CASE with several value ranges, one target block
+      nn = 7
+    end select
+    sinteger = nn
+  end
+
+  ! CHECK-LABEL: slogical
+  subroutine slogical(L)
+    logical :: L
+    n1 = 0
+    n2 = 0
+    n3 = 0
+    n4 = 0
+    n5 = 0
+    n6 = 0
+    n7 = 0
+    n8 = 0
+
+    select case (L) ! no CASE blocks: just a branch to the construct exit
+    end select
+
+    select case (L)
+      ! CHECK: cmpi eq, {{.*}} %false
+      ! CHECK: cond_br
+      case (.false.)
+        n2 = 1
+    end select
+
+    select case (L)
+      ! CHECK: cmpi eq, {{.*}} %true
+      ! CHECK: cond_br
+      case (.true.)
+        n3 = 2
+    end select
+
+    select case (L) ! default only: no comparison is generated
+      case default
+        n4 = 3
+    end select
+
+    select case (L)
+      ! CHECK: cmpi eq, {{.*}} %false
+      ! CHECK: cond_br
+      case (.false.)
+        n5 = 1
+      ! CHECK: cmpi eq, {{.*}} %true
+      ! CHECK: cond_br
+      case (.true.)
+        n5 = 2
+    end select
+
+    select case (L)
+      ! CHECK: cmpi eq, {{.*}} %false
+      ! CHECK: cond_br
+      case (.false.)
+        n6 = 1
+      case default
+        n6 = 3
+    end select
+
+    select case (L)
+      ! CHECK: cmpi eq, {{.*}} %true
+      ! CHECK: cond_br
+      case (.true.)
+        n7 = 2
+      case default
+        n7 = 3
+    end select
+
+    select case (L)
+      ! CHECK: cmpi eq, {{.*}} %false
+      ! CHECK: cond_br
+      case (.false.)
+        n8 = 1
+      ! CHECK: cmpi eq, {{.*}} %true
+      ! CHECK: cond_br
+      case (.true.)
+        n8 = 2
+      ! CHECK-NOT: constant 888
+      case default ! dead: .false. and .true. are both covered above
+        n8 = 888
+    end select
+
+    print*, n1, n2, n3, n4, n5, n6, n7, n8
+  end
+
+  ! CHECK-LABEL: scharacter
+  subroutine scharacter(c)
+    character(*) :: c
+    nn = 0
+    select case (c) ! CHARACTER selector: comparisons go through the runtime
+      case default ! listed first, but tested only after all other cases
+        nn = -1
+      ! CHECK: CharacterCompareScalar1
+      ! CHECK-NEXT: constant 0
+      ! CHECK-NEXT: cmpi sle, {{.*}} %c0
+      ! CHECK-NEXT: cond_br
+      case (:'d')
+        nn = 10
+      ! CHECK: CharacterCompareScalar1
+      ! CHECK-NEXT: constant 0
+      ! CHECK-NEXT: cmpi sge, {{.*}} %c0
+      ! CHECK-NEXT: cond_br
+      ! CHECK: CharacterCompareScalar1
+      ! CHECK-NEXT: constant 0
+      ! CHECK-NEXT: cmpi sle, {{.*}} %c0
+      ! CHECK-NEXT: cond_br
+      case ('ff':'ffff')
+        nn = 20
+      ! CHECK: CharacterCompareScalar1
+      ! CHECK-NEXT: constant 0
+      ! CHECK-NEXT: cmpi eq, {{.*}} %c0
+      ! CHECK-NEXT: cond_br
+      case ('m')
+        nn = 30
+      ! CHECK: CharacterCompareScalar1
+      ! CHECK-NEXT: constant 0
+      ! CHECK-NEXT: cmpi eq, {{.*}} %c0
+      ! CHECK-NEXT: cond_br
+      case ('qq')
+        nn = 40
+      ! CHECK: CharacterCompareScalar1
+      ! CHECK-NEXT: constant 0
+      ! CHECK-NEXT: cmpi sge, {{.*}} %c0
+      ! CHECK-NEXT: cond_br
+      case ('x':)
+        nn = 50
+    end select
+    print*, nn
+  end
+
+  ! CHECK-LABEL: func @_QPtest_char_temp_selector
+  subroutine test_char_temp_selector()
+    ! Test that a character selector that is a temporary is deallocated
+    ! only after it has been used in all the select case comparisons.
+    interface
+      function gen_char_temp_selector()
+        character(:), allocatable :: gen_char_temp_selector
+      end function
+    end interface
+    select case (gen_char_temp_selector())
+    case ('case1')
+      call foo1()
+    case ('case2')
+      call foo2()
+    case ('case3')
+      call foo3()
+    case default
+      call foo_default()
+    end select
+    ! CHECK:   %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>> {bindc_name = ".result"}
+    ! CHECK:   %[[VAL_1:.*]] = fir.call @_QPgen_char_temp_selector() : () -> !fir.box<!fir.heap<!fir.char<1,?>>>
+    ! CHECK:   fir.save_result %[[VAL_1]] to %[[VAL_0]] : !fir.box<!fir.heap<!fir.char<1,?>>>, !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
+    ! CHECK:   cond_br %{{.*}}, ^bb2, ^bb1
+    ! CHECK: ^bb1:
+    ! CHECK:   cond_br %{{.*}}, ^bb4, ^bb3
+    ! CHECK: ^bb2:
+    ! CHECK:   fir.call @_QPfoo1() : () -> ()
+    ! CHECK:   br ^bb8
+    ! CHECK: ^bb3:
+    ! CHECK:   cond_br %{{.*}}, ^bb6, ^bb5
+    ! CHECK: ^bb4:
+    ! CHECK:   fir.call @_QPfoo2() : () -> ()
+    ! CHECK:   br ^bb8
+    ! CHECK: ^bb5:
+    ! CHECK:   br ^bb7
+    ! CHECK: ^bb6:
+    ! CHECK:   fir.call @_QPfoo3() : () -> ()
+    ! CHECK:   br ^bb8
+    ! CHECK: ^bb7:
+    ! CHECK:   fir.call @_QPfoo_default() : () -> ()
+    ! CHECK:   br ^bb8
+    ! CHECK: ^bb8:
+    ! CHECK:   %[[VAL_36:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
+    ! CHECK:   %[[VAL_37:.*]] = fir.box_addr %[[VAL_36]] : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
+    ! CHECK:   %[[VAL_38:.*]] = fir.convert %[[VAL_37]] : (!fir.heap<!fir.char<1,?>>) -> i64
+    ! CHECK:   %[[VAL_39:.*]] = arith.constant 0 : i64
+    ! CHECK:   %[[VAL_40:.*]] = arith.cmpi ne, %[[VAL_38]], %[[VAL_39]] : i64
+    ! CHECK:   fir.if %[[VAL_40]] {
+    ! CHECK:     fir.freemem %[[VAL_37]]
+    ! CHECK:   }
+  end subroutine


        


More information about the flang-commits mailing list