[flang-commits] [flang] 308fc3f - [flang] Lower select case statement
Valentin Clement via flang-commits
flang-commits at lists.llvm.org
Fri Mar 18 07:41:36 PDT 2022
Author: Valentin Clement
Date: 2022-03-18T15:41:29+01:00
New Revision: 308fc3f27797ce2b0dc01970d6fe2c6c9e1f55c7
URL: https://github.com/llvm/llvm-project/commit/308fc3f27797ce2b0dc01970d6fe2c6c9e1f55c7
DIFF: https://github.com/llvm/llvm-project/commit/308fc3f27797ce2b0dc01970d6fe2c6c9e1f55c7.diff
LOG: [flang] Lower select case statement
This patch adds lowering for the `select case`
statement.
This patch is part of the upstreaming effort from fir-dev branch.
Reviewed By: jeanPerier
Differential Revision: https://reviews.llvm.org/D122007
Co-authored-by: Jean Perier <jperier at nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>
Co-authored-by: V Donaldson <vdonaldson at nvidia.com>
Added:
flang/test/Lower/select-case-statement.f90
Modified:
flang/lib/Lower/Bridge.cpp
Removed:
################################################################################
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 3f354b7868a3b..a4185a47318c7 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -29,6 +29,7 @@
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/MutableBox.h"
+#include "flang/Optimizer/Builder/Runtime/Character.h"
#include "flang/Optimizer/Builder/Runtime/Ragged.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Support/FIRContext.h"
@@ -811,6 +812,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
cat == Fortran::common::TypeCategory::Complex ||
cat == Fortran::common::TypeCategory::Logical;
}
+ /// Return true if \p cat is the LOGICAL type category.
+ static bool isLogicalCategory(Fortran::common::TypeCategory cat) {
+ return cat == Fortran::common::TypeCategory::Logical;
+ }
+ /// Return true if \p cat is the CHARACTER type category.
bool isCharacterCategory(Fortran::common::TypeCategory cat) {
return cat == Fortran::common::TypeCategory::Character;
}
@@ -818,6 +822,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
return cat == Fortran::common::TypeCategory::Derived;
}
+ /// Create a new, empty block immediately before \p block and return it.
+ /// The builder's insertion point is saved before and restored after the
+ /// creation, so the caller's insertion point is unchanged on return.
+ mlir::Block *insertBlock(mlir::Block *block) {
+ mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
+ mlir::Block *newBlock = builder->createBlock(block);
+ builder->restoreInsertionPoint(insertPt);
+ return newBlock;
+ }
+
mlir::Block *blockOfLabel(Fortran::lower::pft::Evaluation &eval,
Fortran::parser::Label label) {
const Fortran::lower::pft::LabelEvalMap &labelEvaluationMap =
@@ -1399,7 +1411,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
+ /// Generate FIR for a SELECT CASE construct by lowering each of its nested
+ /// evaluations in order. The SelectCaseStmt evaluation emits the case
+ /// dispatch, while CaseStmt and EndSelectStmt lower to nothing. Presumably
+ /// the nested evaluations are those statements plus the case bodies.
+ /// TODO(review): confirm against the PFT builder.
void genFIR(const Fortran::parser::CaseConstruct &) {
- TODO(toLocation(), "CaseConstruct lowering");
+ for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
+ genFIR(e);
}
template <typename A>
@@ -1630,8 +1643,170 @@ class FirConverter : public Fortran::lower::AbstractConverter {
TODO(toLocation(), "OpenMPDeclarativeConstruct lowering");
}
- void genFIR(const Fortran::parser::SelectCaseStmt &) {
- TODO(toLocation(), "SelectCaseStmt lowering");
+ /// Generate FIR for a SELECT CASE statement.
+ /// The selector type may be CHARACTER, INTEGER, or LOGICAL.
+ ///
+ /// Every CASE value or range of the enclosing construct contributes one
+ /// interval attribute and one target block, plus the bound value(s) that
+ /// attribute needs. The default entry (a unit attribute) is appended last.
+ /// Without a CASE DEFAULT, the default target is the construct exit block.
+ /// An INTEGER selector in a structured context becomes a single
+ /// fir.select_case op. Otherwise (LOGICAL, CHARACTER, or unstructured
+ /// lowering), an explicit chain of comparisons and conditional branches is
+ /// generated.
+ void genFIR(const Fortran::parser::SelectCaseStmt &stmt) {
+ Fortran::lower::pft::Evaluation &eval = getEval();
+ MLIRContext *context = builder->getContext();
+ mlir::Location loc = toLocation();
+ // Collects cleanups for temporaries created while evaluating the selector
+ // and case values. It is finalized once the selector is no longer needed
+ // (see the two finalize() calls below).
+ Fortran::lower::StatementContext stmtCtx;
+ const Fortran::lower::SomeExpr *expr = Fortran::semantics::GetExpr(
+ std::get<Fortran::parser::Scalar<Fortran::parser::Expr>>(stmt.t));
+ bool isCharSelector = isCharacterCategory(expr->GetType()->category());
+ bool isLogicalSelector = isLogicalCategory(expr->GetType()->category());
+ // Lower a CHARACTER expression to a boxed (address, length) value.
+ // Any non-character result is a fatal error.
+ auto charValue = [&](const Fortran::lower::SomeExpr *expr) {
+ fir::ExtendedValue exv = genExprAddr(*expr, stmtCtx, &loc);
+ return exv.match(
+ [&](const fir::CharBoxValue &cbv) {
+ return fir::factory::CharacterExprHelper{*builder, loc}
+ .createEmboxChar(cbv.getAddr(), cbv.getLen());
+ },
+ [&](auto) {
+ fir::emitFatalError(loc, "not a character");
+ return mlir::Value{};
+ });
+ };
+ mlir::Value selector;
+ if (isCharSelector) {
+ selector = charValue(expr);
+ } else {
+ selector = createFIRExpr(loc, expr, stmtCtx);
+ // LOGICAL selectors are compared as i1 values.
+ if (isLogicalSelector)
+ selector = builder->createConvert(loc, builder->getI1Type(), selector);
+ }
+ mlir::Type selectType = selector.getType();
+ // attrList and blockList hold one entry per CASE value/range, with the
+ // default entry appended last. valueList is flat and holds the bound values
+ // those entries consume in order: one for a point, lower or upper bound;
+ // two for a closed interval; none for the default.
+ llvm::SmallVector<mlir::Attribute> attrList;
+ llvm::SmallVector<mlir::Value> valueList;
+ llvm::SmallVector<mlir::Block *> blockList;
+ // Without a CASE DEFAULT, unmatched selectors branch to the construct exit.
+ mlir::Block *defaultBlock = eval.parentConstruct->constructExit->block;
+ using CaseValue = Fortran::parser::Scalar<Fortran::parser::ConstantExpr>;
+ // Lower one case value to an SSA value matching the selector: a boxed
+ // character, an i1 for LOGICAL, or an integer constant built from the
+ // value of the constant expression.
+ auto addValue = [&](const CaseValue &caseValue) {
+ const Fortran::lower::SomeExpr *expr =
+ Fortran::semantics::GetExpr(caseValue.thing);
+ if (isCharSelector)
+ valueList.push_back(charValue(expr));
+ else if (isLogicalSelector)
+ valueList.push_back(builder->createConvert(
+ loc, selectType, createFIRExpr(toLocation(), expr, stmtCtx)));
+ else
+ valueList.push_back(builder->createIntegerConstant(
+ loc, selectType, *Fortran::evaluate::ToInt64(*expr)));
+ };
+ // Walk the CASE statements, which are chained through controlSuccessor
+ // starting from this SELECT CASE statement.
+ for (Fortran::lower::pft::Evaluation *e = eval.controlSuccessor; e;
+ e = e->controlSuccessor) {
+ // NOTE(review): getIf presumably yields nullptr for a non-CaseStmt. The
+ // result is dereferenced below without a check, so this assumes every
+ // evaluation on the chain is a CaseStmt.
+ const auto &caseStmt = e->getIf<Fortran::parser::CaseStmt>();
+ assert(e->block && "missing CaseStmt block");
+ const auto &caseSelector =
+ std::get<Fortran::parser::CaseSelector>(caseStmt->t);
+ const auto *caseValueRangeList =
+ std::get_if<std::list<Fortran::parser::CaseValueRange>>(
+ &caseSelector.u);
+ if (!caseValueRangeList) {
+ // CASE DEFAULT: remember its block. It is appended as the last entry
+ // below, wherever it appears in the source.
+ defaultBlock = e->block;
+ continue;
+ }
+ for (const Fortran::parser::CaseValueRange &caseValueRange :
+ *caseValueRangeList) {
+ blockList.push_back(e->block);
+ if (const auto *caseValue = std::get_if<CaseValue>(&caseValueRange.u)) {
+ attrList.push_back(fir::PointIntervalAttr::get(context));
+ addValue(*caseValue);
+ continue;
+ }
+ // A range is closed (lower:upper), lower-bounded (lower:), or
+ // upper-bounded (:upper).
+ const auto &caseRange =
+ std::get<Fortran::parser::CaseValueRange::Range>(caseValueRange.u);
+ if (caseRange.lower && caseRange.upper) {
+ attrList.push_back(fir::ClosedIntervalAttr::get(context));
+ addValue(*caseRange.lower);
+ addValue(*caseRange.upper);
+ } else if (caseRange.lower) {
+ attrList.push_back(fir::LowerBoundAttr::get(context));
+ addValue(*caseRange.lower);
+ } else {
+ attrList.push_back(fir::UpperBoundAttr::get(context));
+ addValue(*caseRange.upper);
+ }
+ }
+ }
+ // Skip a logical default block that can never be referenced. Two LOGICAL
+ // entries cover both .false. and .true. (presumably semantics rejects
+ // duplicate values), so a CASE DEFAULT is unreachable. The default entry
+ // is pointed at the construct exit instead.
+ if (isLogicalSelector && attrList.size() == 2)
+ defaultBlock = eval.parentConstruct->constructExit->block;
+ attrList.push_back(mlir::UnitAttr::get(context));
+ blockList.push_back(defaultBlock);
+
+ // Generate a fir::SelectCaseOp.
+ // Explicit branch code is better for the LOGICAL type. The CHARACTER type
+ // does not yet have downstream support, and also uses explicit branch code.
+ // The -no-structured-fir option can be used to force generation of INTEGER
+ // type branch code.
+ if (!isLogicalSelector && !isCharSelector && eval.lowerAsStructured()) {
+ // Numeric selector is a ssa register, all temps that may have
+ // been generated while evaluating it can be cleaned-up before the
+ // fir.select_case.
+ stmtCtx.finalize();
+ builder->create<fir::SelectCaseOp>(loc, selector, attrList, valueList,
+ blockList);
+ return;
+ }
+
+ // Generate a sequence of case value comparisons and branches.
+ // Entries are tested in list order. A failed test falls through to a fresh
+ // block, inserted before the current case block, where the next test is
+ // emitted. The final (unit) entry branches unconditionally to the default.
+ auto caseValue = valueList.begin();
+ auto caseBlock = blockList.begin();
+ for (mlir::Attribute attr : attrList) {
+ if (attr.isa<mlir::UnitAttr>()) {
+ genFIRBranch(*caseBlock++);
+ break;
+ }
+ // Compare the selector with rhs using pred. Non-character selectors use
+ // arith.cmpi directly. Character selectors go through
+ // fir::runtime::genCharCompare; the tests show the runtime result being
+ // compared against 0 with the same predicate.
+ auto genCond = [&](mlir::Value rhs,
+ mlir::arith::CmpIPredicate pred) -> mlir::Value {
+ if (!isCharSelector)
+ return builder->create<mlir::arith::CmpIOp>(loc, pred, selector, rhs);
+ fir::factory::CharacterExprHelper charHelper{*builder, loc};
+ std::pair<mlir::Value, mlir::Value> lhsVal =
+ charHelper.createUnboxChar(selector);
+ mlir::Value &lhsAddr = lhsVal.first;
+ mlir::Value &lhsLen = lhsVal.second;
+ std::pair<mlir::Value, mlir::Value> rhsVal =
+ charHelper.createUnboxChar(rhs);
+ mlir::Value &rhsAddr = rhsVal.first;
+ mlir::Value &rhsLen = rhsVal.second;
+ return fir::runtime::genCharCompare(*builder, loc, pred, lhsAddr,
+ lhsLen, rhsAddr, rhsLen);
+ };
+ mlir::Block *newBlock = insertBlock(*caseBlock);
+ // Closed interval lower:upper:
+ // - test selector >= lower; on success, test selector <= upper in newBlock;
+ // - either failure continues with the next entry in newBlock2.
+ if (attr.isa<fir::ClosedIntervalAttr>()) {
+ mlir::Block *newBlock2 = insertBlock(*caseBlock);
+ mlir::Value cond =
+ genCond(*caseValue++, mlir::arith::CmpIPredicate::sge);
+ genFIRConditionalBranch(cond, newBlock, newBlock2);
+ builder->setInsertionPointToEnd(newBlock);
+ mlir::Value cond2 =
+ genCond(*caseValue++, mlir::arith::CmpIPredicate::sle);
+ genFIRConditionalBranch(cond2, *caseBlock++, newBlock2);
+ builder->setInsertionPointToEnd(newBlock2);
+ continue;
+ }
+ // Point, lower-bound or upper-bound entry: a single comparison, with
+ // failure continuing in newBlock.
+ mlir::arith::CmpIPredicate pred;
+ if (attr.isa<fir::PointIntervalAttr>()) {
+ pred = mlir::arith::CmpIPredicate::eq;
+ } else if (attr.isa<fir::LowerBoundAttr>()) {
+ pred = mlir::arith::CmpIPredicate::sge;
+ } else {
+ assert(attr.isa<fir::UpperBoundAttr>() && "unexpected predicate");
+ pred = mlir::arith::CmpIPredicate::sle;
+ }
+ mlir::Value cond = genCond(*caseValue++, pred);
+ genFIRConditionalBranch(cond, *caseBlock++, newBlock);
+ builder->setInsertionPointToEnd(newBlock);
+ }
+ assert(caseValue == valueList.end() && caseBlock == blockList.end() &&
+ "select case list mismatch");
+ // Clean-up the selector at the end of the construct if it is a temporary
+ // (which is possible with characters).
+ mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
+ builder->setInsertionPointToEnd(eval.parentConstruct->constructExit->block);
+ stmtCtx.finalize();
+ builder->restoreInsertionPoint(insertPt);
}
fir::ExtendedValue
@@ -2115,10 +2290,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
genFIRBranch(getEval().controlSuccessor->block);
}
- void genFIR(const Fortran::parser::CaseStmt &) {
- TODO(toLocation(), "CaseStmt lowering");
- }
-
void genFIR(const Fortran::parser::ElseIfStmt &) {
TODO(toLocation(), "ElseIfStmt lowering");
}
@@ -2135,16 +2306,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
TODO(toLocation(), "EndMpSubprogramStmt lowering");
}
- void genFIR(const Fortran::parser::EndSelectStmt &) {
- TODO(toLocation(), "EndSelectStmt lowering");
- }
-
// Nop statements - No code, or code is generated at the construct level.
void genFIR(const Fortran::parser::AssociateStmt &) {} // nop
+ // CaseStmt: the branches to each CASE block are emitted by the
+ // SelectCaseStmt lowering.
+ void genFIR(const Fortran::parser::CaseStmt &) {} // nop
void genFIR(const Fortran::parser::ContinueStmt &) {} // nop
void genFIR(const Fortran::parser::EndAssociateStmt &) {} // nop
void genFIR(const Fortran::parser::EndFunctionStmt &) {} // nop
void genFIR(const Fortran::parser::EndIfStmt &) {} // nop
+ // EndSelectStmt: the construct exit block is targeted by the SelectCaseStmt
+ // lowering (default branch and selector cleanup), not by this statement.
+ void genFIR(const Fortran::parser::EndSelectStmt &) {} // nop
void genFIR(const Fortran::parser::EndSubroutineStmt &) {} // nop
void genFIR(const Fortran::parser::EntryStmt &) {} // nop
@@ -2168,6 +2337,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
TODO(toLocation(), "NamelistStmt lowering");
}
+ /// Generate FIR for the Evaluation `eval`.
void genFIR(Fortran::lower::pft::Evaluation &eval,
bool unstructuredContext = true) {
if (unstructuredContext) {
@@ -2181,6 +2351,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {
setCurrentEval(eval);
setCurrentPosition(eval.position);
eval.visit([&](const auto &stmt) { genFIR(stmt); });
+
+ if (unstructuredContext && blockIsUnterminated()) {
+ // Exit from an unstructured IF or SELECT construct block.
+ Fortran::lower::pft::Evaluation *successor{};
+ if (eval.isActionStmt())
+ successor = eval.controlSuccessor;
+ else if (eval.isConstruct() &&
+ eval.getLastNestedEvaluation()
+ .lexicalSuccessor->isIntermediateConstructStmt())
+ successor = eval.constructExit;
+ if (successor && successor->block)
+ genFIRBranch(successor->block);
+ }
}
//===--------------------------------------------------------------------===//
diff --git a/flang/test/Lower/select-case-statement.f90 b/flang/test/Lower/select-case-statement.f90
new file mode 100644
index 0000000000000..2efbcc036dd67
--- /dev/null
+++ b/flang/test/Lower/select-case-statement.f90
@@ -0,0 +1,211 @@
+! RUN: bbc -emit-fir -o - %s | FileCheck %s
+
+ ! CHECK-LABEL: sinteger
+ function sinteger(n)
+ ! INTEGER selector in a structured construct: all entries are expected in
+ ! one fir.select_case op, in source order, with CASE DEFAULT last.
+ integer sinteger
+ nn = -88
+ ! CHECK: fir.select_case {{.*}} : i32
+ ! CHECK-SAME: upper, %c1
+ ! CHECK-SAME: point, %c2
+ ! CHECK-SAME: point, %c3
+ ! CHECK-SAME: interval, %c4{{.*}} %c5
+ ! CHECK-SAME: point, %c6
+ ! CHECK-SAME: point, %c7
+ ! CHECK-SAME: interval, %c8{{.*}} %c15
+ ! CHECK-SAME: lower, %c21
+ ! CHECK-SAME: unit
+ select case(n)
+ case (:1)
+ nn = 1
+ case (2)
+ nn = 2
+ ! This default sits mid-construct but is emitted as the final unit entry.
+ case default
+ nn = 0
+ case (3)
+ nn = 3
+ ! The upper bound 5+1-1 is a constant expression; it appears as 5.
+ case (4:5+1-1)
+ nn = 4
+ case (6)
+ nn = 6
+ case (7,8:15,21:)
+ nn = 7
+ end select
+ sinteger = nn
+ end
+
+ ! CHECK-LABEL: slogical
+ subroutine slogical(L)
+ ! LOGICAL selectors use explicit compare-and-branch code instead of
+ ! fir.select_case. Each CASE value is tested with an equality comparison.
+ logical :: L
+ n1 = 0
+ n2 = 0
+ n3 = 0
+ n4 = 0
+ n5 = 0
+ n6 = 0
+ n7 = 0
+ n8 = 0
+
+ ! No CASE at all: only the implicit default branch to the exit remains.
+ select case (L)
+ end select
+
+ select case (L)
+ ! CHECK: cmpi eq, {{.*}} %false
+ ! CHECK: cond_br
+ case (.false.)
+ n2 = 1
+ end select
+
+ select case (L)
+ ! CHECK: cmpi eq, {{.*}} %true
+ ! CHECK: cond_br
+ case (.true.)
+ n3 = 2
+ end select
+
+ ! Only CASE DEFAULT: no comparison is needed.
+ select case (L)
+ case default
+ n4 = 3
+ end select
+
+ select case (L)
+ ! CHECK: cmpi eq, {{.*}} %false
+ ! CHECK: cond_br
+ case (.false.)
+ n5 = 1
+ ! CHECK: cmpi eq, {{.*}} %true
+ ! CHECK: cond_br
+ case (.true.)
+ n5 = 2
+ end select
+
+ select case (L)
+ ! CHECK: cmpi eq, {{.*}} %false
+ ! CHECK: cond_br
+ case (.false.)
+ n6 = 1
+ case default
+ n6 = 3
+ end select
+
+ select case (L)
+ ! CHECK: cmpi eq, {{.*}} %true
+ ! CHECK: cond_br
+ case (.true.)
+ n7 = 2
+ case default
+ n7 = 3
+ end select
+
+ ! Both .false. and .true. are listed, so CASE DEFAULT is unreachable and
+ ! its body is expected to be absent from the output.
+ select case (L)
+ ! CHECK: cmpi eq, {{.*}} %false
+ ! CHECK: cond_br
+ case (.false.)
+ n8 = 1
+ ! CHECK: cmpi eq, {{.*}} %true
+ ! CHECK: cond_br
+ case (.true.)
+ n8 = 2
+ ! CHECK-NOT: constant 888
+ case default ! dead
+ n8 = 888
+ end select
+
+ print*, n1, n2, n3, n4, n5, n6, n7, n8
+ end
+
+ ! CHECK-LABEL: scharacter
+ subroutine scharacter(c)
+ ! CHARACTER selectors use explicit compare-and-branch code. Each test calls
+ ! the CharacterCompareScalar1 runtime routine and compares its result with
+ ! 0; a closed range (e.g. 'ff':'ffff') needs two such tests.
+ character(*) :: c
+ nn = 0
+ select case (c)
+ case default
+ nn = -1
+ ! CHECK: CharacterCompareScalar1
+ ! CHECK-NEXT: constant 0
+ ! CHECK-NEXT: cmpi sle, {{.*}} %c0
+ ! CHECK-NEXT: cond_br
+ case (:'d')
+ nn = 10
+ ! CHECK: CharacterCompareScalar1
+ ! CHECK-NEXT: constant 0
+ ! CHECK-NEXT: cmpi sge, {{.*}} %c0
+ ! CHECK-NEXT: cond_br
+ ! CHECK: CharacterCompareScalar1
+ ! CHECK-NEXT: constant 0
+ ! CHECK-NEXT: cmpi sle, {{.*}} %c0
+ ! CHECK-NEXT: cond_br
+ case ('ff':'ffff')
+ nn = 20
+ ! CHECK: CharacterCompareScalar1
+ ! CHECK-NEXT: constant 0
+ ! CHECK-NEXT: cmpi eq, {{.*}} %c0
+ ! CHECK-NEXT: cond_br
+ case ('m')
+ nn = 30
+ ! CHECK: CharacterCompareScalar1
+ ! CHECK-NEXT: constant 0
+ ! CHECK-NEXT: cmpi eq, {{.*}} %c0
+ ! CHECK-NEXT: cond_br
+ case ('qq')
+ nn = 40
+ ! CHECK: CharacterCompareScalar1
+ ! CHECK-NEXT: constant 0
+ ! CHECK-NEXT: cmpi sge, {{.*}} %c0
+ ! CHECK-NEXT: cond_br
+ case ('x':)
+ nn = 50
+ end select
+ print*, nn
+ end
+
+ ! CHECK-LABEL: func @_QPtest_char_temp_selector
+ subroutine test_char_temp_selector()
+ ! Test that a character selector that is a temporary is deallocated only
+ ! after it has been used in all of the SELECT CASE comparisons, i.e. the
+ ! freemem is expected in the construct exit block.
+ interface
+ function gen_char_temp_selector()
+ character(:), allocatable :: gen_char_temp_selector
+ end function
+ end interface
+ select case (gen_char_temp_selector())
+ case ('case1')
+ call foo1()
+ case ('case2')
+ call foo2()
+ case ('case3')
+ call foo3()
+ case default
+ call foo_default()
+ end select
+ ! CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>> {bindc_name = ".result"}
+ ! CHECK: %[[VAL_1:.*]] = fir.call @_QPgen_char_temp_selector() : () -> !fir.box<!fir.heap<!fir.char<1,?>>>
+ ! CHECK: fir.save_result %[[VAL_1]] to %[[VAL_0]] : !fir.box<!fir.heap<!fir.char<1,?>>>, !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
+ ! CHECK: cond_br %{{.*}}, ^bb2, ^bb1
+ ! CHECK: ^bb1:
+ ! CHECK: cond_br %{{.*}}, ^bb4, ^bb3
+ ! CHECK: ^bb2:
+ ! CHECK: fir.call @_QPfoo1() : () -> ()
+ ! CHECK: br ^bb8
+ ! CHECK: ^bb3:
+ ! CHECK: cond_br %{{.*}}, ^bb6, ^bb5
+ ! CHECK: ^bb4:
+ ! CHECK: fir.call @_QPfoo2() : () -> ()
+ ! CHECK: br ^bb8
+ ! CHECK: ^bb5:
+ ! CHECK: br ^bb7
+ ! CHECK: ^bb6:
+ ! CHECK: fir.call @_QPfoo3() : () -> ()
+ ! CHECK: br ^bb8
+ ! CHECK: ^bb7:
+ ! CHECK: fir.call @_QPfoo_default() : () -> ()
+ ! CHECK: br ^bb8
+ ! CHECK: ^bb8:
+ ! CHECK: %[[VAL_36:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
+ ! CHECK: %[[VAL_37:.*]] = fir.box_addr %[[VAL_36]] : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
+ ! CHECK: %[[VAL_38:.*]] = fir.convert %[[VAL_37]] : (!fir.heap<!fir.char<1,?>>) -> i64
+ ! CHECK: %[[VAL_39:.*]] = arith.constant 0 : i64
+ ! CHECK: %[[VAL_40:.*]] = arith.cmpi ne, %[[VAL_38]], %[[VAL_39]] : i64
+ ! CHECK: fir.if %[[VAL_40]] {
+ ! CHECK: fir.freemem %[[VAL_37]]
+ ! CHECK: }
+ end subroutine
More information about the flang-commits
mailing list