[flang-commits] [flang] 1e55ec6 - [flang] SELECT CASE constructs with character selectors that require a temp
Valentin Clement via flang-commits
flang-commits at lists.llvm.org
Thu Jun 30 00:04:35 PDT 2022
Author: Valentin Clement
Date: 2022-06-30T09:04:27+02:00
New Revision: 1e55ec6666fa687b1a86bdaa95ea814557855fd1
URL: https://github.com/llvm/llvm-project/commit/1e55ec6666fa687b1a86bdaa95ea814557855fd1
DIFF: https://github.com/llvm/llvm-project/commit/1e55ec6666fa687b1a86bdaa95ea814557855fd1.diff
LOG: [flang] SELECT CASE constructs with character selectors that require a temp
Here is a character SELECT CASE construct that requires a temp to hold the
result of the TRIM intrinsic call:
```
module m
character(len=6) :: s
contains
subroutine sc
n = 0
if (lge(s,'00')) then
select case(trim(s))
case('11')
n = 1
case default
continue
case('22')
n = 2
case('33')
n = 3
case('44':'55','66':'77','88':)
n = 4
end select
end if
print*, n
end subroutine
end module m
```
This SELECT CASE construct is implemented as an IF/ELSE-IF/ELSE comparison
sequence. The temp must be retained until some comparison is successful;
at that point the temp may be freed. To support this, generalize statement
context processing to allow multiple finalize calls (so that the same cleanup
code can be generated at more than one point in the comparison sequence),
such that the program always executes exactly one freemem call.
This patch is part of the upstreaming effort from fir-dev branch.
Reviewed By: klausler, vdonaldson
Differential Revision: https://reviews.llvm.org/D128852
Co-authored-by: V Donaldson <vdonaldson at nvidia.com>
Added:
Modified:
flang/include/flang/Lower/StatementContext.h
flang/lib/Lower/Bridge.cpp
flang/lib/Lower/ConvertExpr.cpp
flang/lib/Lower/IO.cpp
flang/test/Lower/select-case-statement.f90
Removed:
################################################################################
diff --git a/flang/include/flang/Lower/StatementContext.h b/flang/include/flang/Lower/StatementContext.h
index 58cb9e9271596..69ceeaebfbbc8 100644
--- a/flang/include/flang/Lower/StatementContext.h
+++ b/flang/include/flang/Lower/StatementContext.h
@@ -35,7 +35,7 @@ class StatementContext {
~StatementContext() {
if (!cufs.empty())
- finalize(/*popScope=*/true);
+ finalizeAndPop();
assert(cufs.empty() && "invalid StatementContext destructor call");
}
@@ -61,15 +61,29 @@ class StatementContext {
}
}
- /// Make cleanup calls. Pop or reset the stack top list.
- void finalize(bool popScope = false) {
+ /// Make cleanup calls. Retain the stack top list for a repeat call.
+ void finalizeAndKeep() {
assert(!cufs.empty() && "invalid finalize statement context");
if (cufs.back())
(*cufs.back())();
- if (popScope)
- cufs.pop_back();
- else
- cufs.back().reset();
+ }
+
+ /// Make cleanup calls. Pop the stack top list.
+ void finalizeAndPop() {
+ finalizeAndKeep();
+ cufs.pop_back();
+ }
+
+ /// Make cleanup calls. Clear the stack top list.
+ void finalize() {
+ finalizeAndKeep();
+ cufs.back().reset();
+ }
+
+ bool workListIsEmpty() const {
+ return cufs.empty() || llvm::all_of(cufs, [](auto &opt) -> bool {
+ return !opt.hasValue();
+ });
}
private:
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index b7d180ed73207..d3bc95ab7e050 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1749,8 +1749,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Generate a sequence of case value comparisons and branches.
auto caseValue = valueList.begin();
auto caseBlock = blockList.begin();
- for (mlir::Attribute attr : attrList) {
- if (attr.isa<mlir::UnitAttr>()) {
+ bool skipFinalization = false;
+ for (const auto attr : llvm::enumerate(attrList)) {
+ if (attr.value().isa<mlir::UnitAttr>()) {
+ if (attrList.size() == 1)
+ stmtCtx.finalize();
genFIRBranch(*caseBlock++);
break;
}
@@ -1767,16 +1770,30 @@ class FirConverter : public Fortran::lower::AbstractConverter {
charHelper.createUnboxChar(rhs);
mlir::Value &rhsAddr = rhsVal.first;
mlir::Value &rhsLen = rhsVal.second;
- return fir::runtime::genCharCompare(*builder, loc, pred, lhsAddr,
- lhsLen, rhsAddr, rhsLen);
+ mlir::Value result = fir::runtime::genCharCompare(
+ *builder, loc, pred, lhsAddr, lhsLen, rhsAddr, rhsLen);
+ if (stmtCtx.workListIsEmpty() || skipFinalization)
+ return result;
+ if (attr.index() == attrList.size() - 2) {
+ stmtCtx.finalize();
+ return result;
+ }
+ fir::IfOp ifOp = builder->create<fir::IfOp>(loc, result,
+ /*withElseRegion=*/false);
+ builder->setInsertionPointToStart(&ifOp.getThenRegion().front());
+ stmtCtx.finalizeAndKeep();
+ builder->setInsertionPointAfter(ifOp);
+ return result;
};
mlir::Block *newBlock = insertBlock(*caseBlock);
- if (attr.isa<fir::ClosedIntervalAttr>()) {
+ if (attr.value().isa<fir::ClosedIntervalAttr>()) {
mlir::Block *newBlock2 = insertBlock(*caseBlock);
+ skipFinalization = true;
mlir::Value cond =
genCond(*caseValue++, mlir::arith::CmpIPredicate::sge);
genFIRConditionalBranch(cond, newBlock, newBlock2);
builder->setInsertionPointToEnd(newBlock);
+ skipFinalization = false;
mlir::Value cond2 =
genCond(*caseValue++, mlir::arith::CmpIPredicate::sle);
genFIRConditionalBranch(cond2, *caseBlock++, newBlock2);
@@ -1784,12 +1801,13 @@ class FirConverter : public Fortran::lower::AbstractConverter {
continue;
}
mlir::arith::CmpIPredicate pred;
- if (attr.isa<fir::PointIntervalAttr>()) {
+ if (attr.value().isa<fir::PointIntervalAttr>()) {
pred = mlir::arith::CmpIPredicate::eq;
- } else if (attr.isa<fir::LowerBoundAttr>()) {
+ } else if (attr.value().isa<fir::LowerBoundAttr>()) {
pred = mlir::arith::CmpIPredicate::sge;
} else {
- assert(attr.isa<fir::UpperBoundAttr>() && "unexpected predicate");
+ assert(attr.value().isa<fir::UpperBoundAttr>() &&
+ "unexpected predicate");
pred = mlir::arith::CmpIPredicate::sle;
}
mlir::Value cond = genCond(*caseValue++, pred);
@@ -1798,12 +1816,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
assert(caseValue == valueList.end() && caseBlock == blockList.end() &&
"select case list mismatch");
- // Clean-up the selector at the end of the construct if it is a temporary
- // (which is possible with characters).
- mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
- builder->setInsertionPointToEnd(eval.parentConstruct->constructExit->block);
- stmtCtx.finalize();
- builder->restoreInsertionPoint(insertPt);
+ assert(stmtCtx.workListIsEmpty() && "statement context must be empty");
}
fir::ExtendedValue
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index 5771a89015ac3..e0f77736208d1 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -3813,7 +3813,7 @@ class ArrayExprLowering {
// be needed afterwards.
stmtCtx.pushScope();
[[maybe_unused]] ExtValue loopRes = lowerArrayExpression(expr);
- stmtCtx.finalize(/*popScope=*/true);
+ stmtCtx.finalizeAndPop();
assert(fir::getBase(loopRes));
}
@@ -4719,7 +4719,7 @@ class ArrayExprLowering {
/// fir::ResultOp at the end of the innermost loop.
void finalizeElementCtx() {
if (elementCtx) {
- stmtCtx.finalize(/*popScope=*/true);
+ stmtCtx.finalizeAndPop();
elementCtx = false;
}
}
@@ -6433,7 +6433,7 @@ class ArrayExprLowering {
builder.create<fir::StoreOp>(loc, castLen, charLen.value());
}
}
- stmtCtx.finalize(/*popScope=*/true);
+ stmtCtx.finalizeAndPop();
builder.create<fir::ResultOp>(loc, mem);
builder.restoreInsertionPoint(insPt);
diff --git a/flang/lib/Lower/IO.cpp b/flang/lib/Lower/IO.cpp
index 849288b47b24f..b3bc3a92edde3 100644
--- a/flang/lib/Lower/IO.cpp
+++ b/flang/lib/Lower/IO.cpp
@@ -196,7 +196,7 @@ static mlir::Value genEndIO(Fortran::lower::AbstractConverter &converter,
mlir::ValueRange{cookie});
mlir::Value iostat = call.getResult(0);
if (csi.bigUnitIfOp) {
- stmtCtx.finalize(/*popScope=*/true);
+ stmtCtx.finalizeAndPop();
builder.create<fir::ResultOp>(loc, iostat);
builder.setInsertionPointAfter(csi.bigUnitIfOp);
iostat = csi.bigUnitIfOp.getResult(0);
diff --git a/flang/test/Lower/select-case-statement.f90 b/flang/test/Lower/select-case-statement.f90
index 2efbcc036dd67..e188f886007b4 100644
--- a/flang/test/Lower/select-case-statement.f90
+++ b/flang/test/Lower/select-case-statement.f90
@@ -158,54 +158,188 @@ subroutine scharacter(c)
print*, nn
end
- ! CHECK-LABEL: func @_QPtest_char_temp_selector
- subroutine test_char_temp_selector()
- ! Test that character selector that are temps are deallocated
- ! only after they have been used in the select case comparisons.
- interface
- function gen_char_temp_selector()
- character(:), allocatable :: gen_char_temp_selector
- end function
- end interface
- select case (gen_char_temp_selector())
- case ('case1')
- call foo1()
- case ('case2')
- call foo2()
- case ('case3')
- call foo3()
+ ! CHECK-LABEL: func @_QPscharacter1
+ subroutine scharacter1(s)
+ ! CHECK-DAG: %[[V_0:[0-9]+]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>>
+ character(len=3) :: s
+ ! CHECK-DAG: %[[V_1:[0-9]+]] = fir.alloca i32 {bindc_name = "n", uniq_name = "_QFscharacter1En"}
+ ! CHECK: fir.store %c0{{.*}} to %[[V_1]] : !fir.ref<i32>
+ n = 0
+
+ ! CHECK: %[[V_8:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+ ! CHECK: %[[V_9:[0-9]+]] = arith.cmpi sge, %[[V_8]], %c0{{.*}} : i32
+ ! CHECK: cond_br %[[V_9]], ^bb1, ^bb15
+ ! CHECK: ^bb1: // pred: ^bb0
+ if (lge(s,'00')) then
+
+ ! CHECK: %[[V_18:[0-9]+]] = fir.load %[[V_0]] : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
+ ! CHECK: %[[V_20:[0-9]+]] = fir.box_addr %[[V_18]] : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
+ ! CHECK: %[[V_42:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+ ! CHECK: %[[V_43:[0-9]+]] = arith.cmpi eq, %[[V_42]], %c0{{.*}} : i32
+ ! CHECK: fir.if %[[V_43]] {
+ ! CHECK: fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+ ! CHECK: }
+ ! CHECK: cond_br %[[V_43]], ^bb3, ^bb2
+ ! CHECK: ^bb2: // pred: ^bb1
+ select case(trim(s))
+ case('11')
+ n = 1
+
+ case default
+ continue
+
+ ! CHECK: %[[V_48:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+ ! CHECK: %[[V_49:[0-9]+]] = arith.cmpi eq, %[[V_48]], %c0{{.*}} : i32
+ ! CHECK: fir.if %[[V_49]] {
+ ! CHECK: fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+ ! CHECK: }
+ ! CHECK: cond_br %[[V_49]], ^bb6, ^bb5
+ ! CHECK: ^bb3: // pred: ^bb1
+ ! CHECK: fir.store %c1{{.*}} to %[[V_1]] : !fir.ref<i32>
+ ! CHECK: ^bb4: // pred: ^bb13
+ ! CHECK: ^bb5: // pred: ^bb2
+ case('22')
+ n = 2
+
+ ! CHECK: %[[V_54:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+ ! CHECK: %[[V_55:[0-9]+]] = arith.cmpi eq, %[[V_54]], %c0{{.*}} : i32
+ ! CHECK: fir.if %[[V_55]] {
+ ! CHECK: fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+ ! CHECK: }
+ ! CHECK: cond_br %[[V_55]], ^bb8, ^bb7
+ ! CHECK: ^bb6: // pred: ^bb2
+ ! CHECK: fir.store %c2{{.*}} to %[[V_1]] : !fir.ref<i32>
+ ! CHECK: ^bb7: // pred: ^bb5
+ case('33')
+ n = 3
+
+ case('44':'55','66':'77','88':)
+ n = 4
+ ! CHECK: %[[V_60:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+ ! CHECK: %[[V_61:[0-9]+]] = arith.cmpi sge, %[[V_60]], %c0{{.*}} : i32
+ ! CHECK: cond_br %[[V_61]], ^bb9, ^bb10
+ ! CHECK: ^bb8: // pred: ^bb5
+ ! CHECK: fir.store %c3{{.*}} to %[[V_1]] : !fir.ref<i32>
+ ! CHECK: ^bb9: // pred: ^bb7
+ ! CHECK: %[[V_66:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+ ! CHECK: %[[V_67:[0-9]+]] = arith.cmpi sle, %[[V_66]], %c0{{.*}} : i32
+ ! CHECK: fir.if %[[V_67]] {
+ ! CHECK: fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+ ! CHECK: }
+ ! CHECK: cond_br %[[V_67]], ^bb14, ^bb10
+ ! CHECK: ^bb10: // 2 preds: ^bb7, ^bb9
+ ! CHECK: %[[V_72:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+ ! CHECK: %[[V_73:[0-9]+]] = arith.cmpi sge, %[[V_72]], %c0{{.*}} : i32
+ ! CHECK: cond_br %[[V_73]], ^bb11, ^bb12
+ ! CHECK: ^bb11: // pred: ^bb10
+ ! CHECK: %[[V_78:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+ ! CHECK: %[[V_79:[0-9]+]] = arith.cmpi sle, %[[V_78]], %c0{{.*}} : i32
+ ! CHECK: fir.if %[[V_79]] {
+ ! CHECK: fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+ ! CHECK: }
+ ! CHECK: ^bb12: // 2 preds: ^bb10, ^bb11
+ ! CHECK: %[[V_84:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+ ! CHECK: %[[V_85:[0-9]+]] = arith.cmpi sge, %[[V_84]], %c0{{.*}} : i32
+ ! CHECK: fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+ ! CHECK: cond_br %[[V_85]], ^bb14, ^bb13
+ ! CHECK: ^bb13: // pred: ^bb12
+ ! CHECK: ^bb14: // 3 preds: ^bb9, ^bb11, ^bb12
+ ! CHECK: fir.store %c4{{.*}} to %[[V_1]] : !fir.ref<i32>
+ ! CHECK: ^bb15: // 6 preds: ^bb0, ^bb3, ^bb4, ^bb6, ^bb8, ^bb14
+ end select
+ end if
+ ! CHECK: %[[V_89:[0-9]+]] = fir.load %[[V_1]] : !fir.ref<i32>
+ print*, n
+ end subroutine
+
+
+ ! CHECK-LABEL: func @_QPscharacter2
+ subroutine scharacter2(s)
+ ! CHECK-DAG: %[[V_0:[0-9]+]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>>
+ ! CHECK: %[[V_1:[0-9]+]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>>
+ character(len=3) :: s
+ n = 0
+
+ ! CHECK: %[[V_12:[0-9]+]] = fir.load %[[V_1]] : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
+ ! CHECK: %[[V_13:[0-9]+]] = fir.box_addr %[[V_12]] : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
+ ! CHECK: fir.freemem %[[V_13]] : !fir.heap<!fir.char<1,?>>
+ ! CHECK: br ^bb1
+ ! CHECK: ^bb1: // pred: ^bb0
+ ! CHECK: br ^bb2
+ n = -10
+ select case(trim(s))
case default
- call foo_default()
+ n = 9
+ end select
+ print*, n
+
+ ! CHECK: ^bb2: // pred: ^bb1
+ ! CHECK: %[[V_28:[0-9]+]] = fir.load %[[V_0]] : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
+ ! CHECK: %[[V_29:[0-9]+]] = fir.box_addr %[[V_28]] : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
+ ! CHECK: fir.freemem %[[V_29]] : !fir.heap<!fir.char<1,?>>
+ ! CHECK: br ^bb3
+ ! CHECK: ^bb3: // pred: ^bb2
+ n = -2
+ select case(trim(s))
end select
- ! CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>> {bindc_name = ".result"}
- ! CHECK: %[[VAL_1:.*]] = fir.call @_QPgen_char_temp_selector() : () -> !fir.box<!fir.heap<!fir.char<1,?>>>
- ! CHECK: fir.save_result %[[VAL_1]] to %[[VAL_0]] : !fir.box<!fir.heap<!fir.char<1,?>>>, !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
- ! CHECK: cond_br %{{.*}}, ^bb2, ^bb1
- ! CHECK: ^bb1:
- ! CHECK: cond_br %{{.*}}, ^bb4, ^bb3
- ! CHECK: ^bb2:
- ! CHECK: fir.call @_QPfoo1() : () -> ()
- ! CHECK: br ^bb8
- ! CHECK: ^bb3:
- ! CHECK: cond_br %{{.*}}, ^bb6, ^bb5
- ! CHECK: ^bb4:
- ! CHECK: fir.call @_QPfoo2() : () -> ()
- ! CHECK: br ^bb8
- ! CHECK: ^bb5:
- ! CHECK: br ^bb7
- ! CHECK: ^bb6:
- ! CHECK: fir.call @_QPfoo3() : () -> ()
- ! CHECK: br ^bb8
- ! CHECK: ^bb7:
- ! CHECK: fir.call @_QPfoo_default() : () -> ()
- ! CHECK: br ^bb8
- ! CHECK: ^bb8:
- ! CHECK: %[[VAL_36:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
- ! CHECK: %[[VAL_37:.*]] = fir.box_addr %[[VAL_36]] : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
- ! CHECK: %[[VAL_38:.*]] = fir.convert %[[VAL_37]] : (!fir.heap<!fir.char<1,?>>) -> i64
- ! CHECK: %[[VAL_39:.*]] = arith.constant 0 : i64
- ! CHECK: %[[VAL_40:.*]] = arith.cmpi ne, %[[VAL_38]], %[[VAL_39]] : i64
- ! CHECK: fir.if %[[VAL_40]] {
- ! CHECK: fir.freemem %[[VAL_37]]
- ! CHECK: }
+ print*, n
end subroutine
+
+ ! CHECK-LABEL: main
+ program p
+ integer sinteger, v(10)
+
+ n = -10
+ do j = 1, 4
+ do k = 1, 10
+ n = n + 1
+ v(k) = sinteger(n)
+ enddo
+ ! expected output: 1 1 1 1 1 1 1 1 1 1
+ ! 1 2 3 4 4 6 7 7 7 7
+ ! 7 7 7 7 7 0 0 0 0 0
+ ! 7 7 7 7 7 7 7 7 7 7
+ print*, v
+ enddo
+
+ print*
+ call slogical(.false.) ! expected output: 0 1 0 3 1 1 3 1
+ call slogical(.true.) ! expected output: 0 0 2 3 2 3 2 2
+
+ print*
+ call scharacter('aa') ! expected output: 10
+ call scharacter('d') ! expected output: 10
+ call scharacter('f') ! expected output: -1
+ call scharacter('ff') ! expected output: 20
+ call scharacter('fff') ! expected output: 20
+ call scharacter('ffff') ! expected output: 20
+ call scharacter('fffff') ! expected output: -1
+ call scharacter('jj') ! expected output: -1
+ call scharacter('m') ! expected output: 30
+ call scharacter('q') ! expected output: -1
+ call scharacter('qq') ! expected output: 40
+ call scharacter('qqq') ! expected output: -1
+ call scharacter('vv') ! expected output: -1
+ call scharacter('xx') ! expected output: 50
+ call scharacter('zz') ! expected output: 50
+
+ print*
+ call scharacter1('99 ') ! expected output: 4
+ call scharacter1('88 ') ! expected output: 4
+ call scharacter1('77 ') ! expected output: 4
+ call scharacter1('66 ') ! expected output: 4
+ call scharacter1('55 ') ! expected output: 4
+ call scharacter1('44 ') ! expected output: 4
+ call scharacter1('33 ') ! expected output: 3
+ call scharacter1('22 ') ! expected output: 2
+ call scharacter1('11 ') ! expected output: 1
+ call scharacter1('00 ') ! expected output: 0
+ call scharacter1('. ') ! expected output: 0
+ call scharacter1(' ') ! expected output: 0
+
+ print*
+ call scharacter2('99 ') ! expected output: 9 -2
+ call scharacter2('22 ') ! expected output: 9 -2
+ call scharacter2('. ') ! expected output: 9 -2
+ call scharacter2(' ') ! expected output: 9 -2
+ end
More information about the flang-commits
mailing list