[flang-commits] [flang] 84b9ae6 - [flang]Add support for do concurrent
Mats Petersson via flang-commits
flang-commits at lists.llvm.org
Mon Jun 13 04:29:12 PDT 2022
Author: Mats Petersson
Date: 2022-06-13T12:28:49+01:00
New Revision: 84b9ae662419ce97b3cb13879be431f6a0c9eaa4
URL: https://github.com/llvm/llvm-project/commit/84b9ae662419ce97b3cb13879be431f6a0c9eaa4
DIFF: https://github.com/llvm/llvm-project/commit/84b9ae662419ce97b3cb13879be431f6a0c9eaa4.diff
LOG: [flang]Add support for do concurrent
[flang]Add support for do concurrent
Upstreaming from fir-dev on https://github.com/flang-compiler/f18-llvm-project
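Adds lowering support for DO CONCURRENT loops: loop control, mask expressions, and the SHARED locality spec (LOCAL and LOCAL_INIT locality specs are still reported as not yet implemented).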
A selection of tests is also added.
Co-authored-by: V Donaldson <vdonaldson at nvidia.com>
Reviewed By: kiranchandramohan
Differential Revision: https://reviews.llvm.org/D127240
Added:
flang/test/Lower/loops.f90
Modified:
flang/lib/Lower/Bridge.cpp
flang/test/Lower/OpenMP/omp-unstructured.f90
Removed:
################################################################################
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index eb58c4ee06346..6caea151c20e8 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -87,7 +87,10 @@ struct IncrementLoopInfo {
const Fortran::lower::SomeExpr *lowerExpr;
const Fortran::lower::SomeExpr *upperExpr;
const Fortran::lower::SomeExpr *stepExpr;
+ const Fortran::lower::SomeExpr *maskExpr = nullptr;
bool isUnordered; // do concurrent, forall
+ llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> sharedSymList;
mlir::Value loopVariable = nullptr;
mlir::Value stepValue = nullptr; // possible uses in multiple blocks
@@ -98,6 +101,7 @@ struct IncrementLoopInfo {
bool hasRealControl = false;
mlir::Value tripVariable = nullptr;
mlir::Block *headerBlock = nullptr; // loop entry and test block
+ mlir::Block *maskBlock = nullptr; // concurrent loop mask block
mlir::Block *bodyBlock = nullptr; // first loop body block
mlir::Block *exitBlock = nullptr; // loop exit target block
};
@@ -636,9 +640,24 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
/// Generate the address of loop variable \p sym.
+ /// If \p sym is not mapped yet, allocate local storage for it.
mlir::Value genLoopVariableAddress(mlir::Location loc,
- const Fortran::semantics::Symbol &sym) {
- assert(lookupSymbol(sym) && "loop control variable must already be in map");
+ const Fortran::semantics::Symbol &sym,
+ bool isUnordered) {
+ if (isUnordered || sym.has<Fortran::semantics::HostAssocDetails>() ||
+ sym.has<Fortran::semantics::UseDetails>()) {
+ if (!shallowLookupSymbol(sym)) {
+ // Do concurrent loop variables are not mapped yet since they are local
+ // to the Do concurrent scope (same for OpenMP loops).
+ auto newVal = builder->createTemporary(loc, genType(sym),
+ toStringRef(sym.name()));
+ bindIfNewSymbol(sym, newVal);
+ return newVal;
+ }
+ }
+ auto entry = lookupSymbol(sym);
+ (void)entry;
+ assert(entry && "loop control variable must already be in map");
Fortran::lower::StatementContext stmtCtx;
return fir::getBase(
genExprAddr(Fortran::evaluate::AsGenericExpr(sym).value(), stmtCtx));
@@ -973,6 +992,34 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder->create<fir::SelectOp>(loc, selectExpr, indexList, blockList);
}
+ /// Collect DO CONCURRENT or FORALL loop control information.
+ IncrementLoopNestInfo getConcurrentControl(
+ const Fortran::parser::ConcurrentHeader &header,
+ const std::list<Fortran::parser::LocalitySpec> &localityList = {}) {
+ IncrementLoopNestInfo incrementLoopNestInfo;
+ for (const Fortran::parser::ConcurrentControl &control :
+ std::get<std::list<Fortran::parser::ConcurrentControl>>(header.t))
+ incrementLoopNestInfo.emplace_back(
+ *std::get<0>(control.t).symbol, std::get<1>(control.t),
+ std::get<2>(control.t), std::get<3>(control.t), /*isUnordered=*/true);
+ IncrementLoopInfo &info = incrementLoopNestInfo.back();
+ info.maskExpr = Fortran::semantics::GetExpr(
+ std::get<std::optional<Fortran::parser::ScalarLogicalExpr>>(header.t));
+ for (const Fortran::parser::LocalitySpec &x : localityList) {
+ if (const auto *localInitList =
+ std::get_if<Fortran::parser::LocalitySpec::LocalInit>(&x.u))
+ for (const Fortran::parser::Name &x : localInitList->v)
+ info.localInitSymList.push_back(x.symbol);
+ if (const auto *sharedList =
+ std::get_if<Fortran::parser::LocalitySpec::Shared>(&x.u))
+ for (const Fortran::parser::Name &x : sharedList->v)
+ info.sharedSymList.push_back(x.symbol);
+ if (std::get_if<Fortran::parser::LocalitySpec::Local>(&x.u))
+ TODO(toLocation(), "do concurrent locality specs not implemented");
+ }
+ return incrementLoopNestInfo;
+ }
+
/// Generate FIR for a DO construct. There are six variants:
/// - unstructured infinite and while loops
/// - structured and unstructured increment loops
@@ -1029,7 +1076,34 @@ class FirConverter : public Fortran::lower::AbstractConverter {
info.exitBlock = exitBlock;
}
} else {
- TODO(toLocation(), "infinite/unstructured loop/concurrent loop");
+ const auto *concurrent =
+ std::get_if<Fortran::parser::LoopControl::Concurrent>(
+ &loopControl->u);
+ assert(concurrent && "invalid DO loop variant");
+ incrementLoopNestInfo = getConcurrentControl(
+ std::get<Fortran::parser::ConcurrentHeader>(concurrent->t),
+ std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent->t));
+ if (unstructuredContext) {
+ maybeStartBlock(preheaderBlock);
+ for (IncrementLoopInfo &info : incrementLoopNestInfo) {
+ // The original loop body provides the body and latch blocks of the
+ // innermost dimension. The (first) body block of a non-innermost
+ // dimension is the preheader block of the immediately enclosed
+ // dimension. The latch block of a non-innermost dimension is the
+ // exit block of the immediately enclosed dimension.
+ auto createNextExitBlock = [&]() {
+ // Create unstructured loop exit blocks, outermost to innermost.
+ return exitBlock = insertBlock(exitBlock);
+ };
+ bool isInnermost = &info == &incrementLoopNestInfo.back();
+ bool isOutermost = &info == &incrementLoopNestInfo.front();
+ info.headerBlock = isOutermost ? headerBlock : createNextBeginBlock();
+ info.bodyBlock = isInnermost ? bodyBlock : createNextBeginBlock();
+ info.exitBlock = isOutermost ? exitBlock : createNextExitBlock();
+ if (info.maskExpr)
+ info.maskBlock = createNextBeginBlock();
+ }
+ }
}
// Increment loop begin code. (Infinite/while code was already generated.)
@@ -1065,8 +1139,28 @@ class FirConverter : public Fortran::lower::AbstractConverter {
return builder->createRealConstant(loc, controlType, 1u);
return builder->createIntegerConstant(loc, controlType, 1); // step
};
+ auto handleLocalitySpec = [&](IncrementLoopInfo &info) {
+ // Generate Local Init Assignments
+ for (const Fortran::semantics::Symbol *sym : info.localInitSymList) {
+ const auto *hostDetails =
+ sym->detailsIf<Fortran::semantics::HostAssocDetails>();
+ assert(hostDetails && "missing local_init variable host variable");
+ const Fortran::semantics::Symbol &hostSym = hostDetails->symbol();
+ (void)hostSym;
+ TODO(loc, "do concurrent locality specs not implemented");
+ }
+ // Handle shared locality spec
+ for (const Fortran::semantics::Symbol *sym : info.sharedSymList) {
+ const auto *hostDetails =
+ sym->detailsIf<Fortran::semantics::HostAssocDetails>();
+ assert(hostDetails && "missing shared variable host variable");
+ const Fortran::semantics::Symbol &hostSym = hostDetails->symbol();
+ copySymbolBinding(hostSym, *sym);
+ }
+ };
for (IncrementLoopInfo &info : incrementLoopNestInfo) {
- info.loopVariable = genLoopVariableAddress(loc, info.loopVariableSym);
+ info.loopVariable =
+ genLoopVariableAddress(loc, info.loopVariableSym, info.isUnordered);
mlir::Value lowerValue = genControlValue(info.lowerExpr, info);
mlir::Value upperValue = genControlValue(info.upperExpr, info);
info.stepValue = genControlValue(info.stepExpr, info);
@@ -1081,8 +1175,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::Value value = builder->createConvert(
loc, info.getLoopVariableType(), info.doLoop.getInductionVar());
builder->create<fir::StoreOp>(loc, value, info.loopVariable);
- // TODO: Mask expr
- // TODO: handle Locality Spec
+ if (info.maskExpr) {
+ Fortran::lower::StatementContext stmtCtx;
+ mlir::Value maskCond = createFIRExpr(loc, info.maskExpr, stmtCtx);
+ stmtCtx.finalize();
+ mlir::Value maskCondCast =
+ builder->createConvert(loc, builder->getI1Type(), maskCond);
+ auto ifOp = builder->create<fir::IfOp>(loc, maskCondCast,
+ /*withElseRegion=*/false);
+ builder->setInsertionPointToStart(&ifOp.getThenRegion().front());
+ }
+ handleLocalitySpec(info);
continue;
}
@@ -1119,16 +1222,33 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder->create<fir::StoreOp>(loc, lowerValue, info.loopVariable);
// Unstructured loop header - generate loop condition and mask.
+ // Note - Currently there is no way to tag a loop as a concurrent loop.
startBlock(info.headerBlock);
tripCount = builder->create<fir::LoadOp>(loc, info.tripVariable);
mlir::Value zero =
builder->createIntegerConstant(loc, tripCount.getType(), 0);
auto cond = builder->create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::sgt, tripCount, zero);
- // TODO: mask expression
- genFIRConditionalBranch(cond, info.bodyBlock, info.exitBlock);
- if (&info != &incrementLoopNestInfo.back()) // not innermost
- startBlock(info.bodyBlock); // preheader block of enclosed dimension
+ if (info.maskExpr) {
+ genFIRConditionalBranch(cond, info.maskBlock, info.exitBlock);
+ startBlock(info.maskBlock);
+ mlir::Block *latchBlock = getEval().getLastNestedEvaluation().block;
+ assert(latchBlock && "missing masked concurrent loop latch block");
+ Fortran::lower::StatementContext stmtCtx;
+ mlir::Value maskCond = createFIRExpr(loc, info.maskExpr, stmtCtx);
+ stmtCtx.finalize();
+ genFIRConditionalBranch(maskCond, info.bodyBlock, latchBlock);
+ } else {
+ genFIRConditionalBranch(cond, info.bodyBlock, info.exitBlock);
+ if (&info != &incrementLoopNestInfo.back()) // not innermost
+ startBlock(info.bodyBlock); // preheader block of enclosed dimension
+ }
+ if (!info.localInitSymList.empty()) {
+ mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
+ builder->setInsertionPointToStart(info.bodyBlock);
+ handleLocalitySpec(info);
+ builder->restoreInsertionPoint(insertPt);
+ }
}
}
diff --git a/flang/test/Lower/OpenMP/omp-unstructured.f90 b/flang/test/Lower/OpenMP/omp-unstructured.f90
index 28ed87169bbc6..e7d48bb269349 100644
--- a/flang/test/Lower/OpenMP/omp-unstructured.f90
+++ b/flang/test/Lower/OpenMP/omp-unstructured.f90
@@ -59,8 +59,8 @@ subroutine ss2(n) ! unstructured OpenMP construct; loop exit inside construct
end
! CHECK-LABEL: func @_QPss3{{.*}} {
-! CHECK: %[[ALLOCA_K:.*]] = fir.alloca i32 {bindc_name = "k", {{.*}}}
! CHECK: omp.parallel {
+! CHECK: %[[ALLOCA_K:.*]] = fir.alloca i32 {bindc_name = "k", pinned}
! CHECK: %[[ALLOCA_1:.*]] = fir.alloca i32 {{{.*}}, pinned}
! CHECK: %[[ALLOCA_2:.*]] = fir.alloca i32 {{{.*}}, pinned}
! CHECK: br ^bb1
diff --git a/flang/test/Lower/loops.f90 b/flang/test/Lower/loops.f90
new file mode 100644
index 0000000000000..2a95c69d6e225
--- /dev/null
+++ b/flang/test/Lower/loops.f90
@@ -0,0 +1,109 @@
+! RUN: bbc -emit-fir -o - %s | FileCheck %s
+
+! CHECK-LABEL: loop_test
+subroutine loop_test
+ ! CHECK: %[[VAL_2:.*]] = fir.alloca i16 {bindc_name = "i"}
+ ! CHECK: %[[VAL_3:.*]] = fir.alloca i16 {bindc_name = "i"}
+ ! CHECK: %[[VAL_4:.*]] = fir.alloca i16 {bindc_name = "i"}
+ ! CHECK: %[[VAL_5:.*]] = fir.alloca i8 {bindc_name = "k"}
+ ! CHECK: %[[VAL_6:.*]] = fir.alloca i8 {bindc_name = "j"}
+ ! CHECK: %[[VAL_7:.*]] = fir.alloca i8 {bindc_name = "i"}
+ ! CHECK: %[[VAL_8:.*]] = fir.alloca i32 {bindc_name = "k"}
+ ! CHECK: %[[VAL_9:.*]] = fir.alloca i32 {bindc_name = "j"}
+ ! CHECK: %[[VAL_10:.*]] = fir.alloca i32 {bindc_name = "i"}
+ ! CHECK: %[[VAL_11:.*]] = fir.alloca !fir.array<5x5x5xi32> {bindc_name = "a", uniq_name = "_QFloop_testEa"}
+ ! CHECK: %[[VAL_12:.*]] = fir.alloca i32 {bindc_name = "asum", uniq_name = "_QFloop_testEasum"}
+ ! CHECK: %[[VAL_13:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFloop_testEi"}
+ ! CHECK: %[[VAL_14:.*]] = fir.alloca i32 {bindc_name = "j", uniq_name = "_QFloop_testEj"}
+ ! CHECK: %[[VAL_15:.*]] = fir.alloca i32 {bindc_name = "k", uniq_name = "_QFloop_testEk"}
+ ! CHECK: %[[VAL_16:.*]] = fir.alloca f32 {bindc_name = "x", uniq_name = "_QFloop_testEx"}
+ ! CHECK: %[[VAL_17:.*]] = fir.alloca i32 {bindc_name = "xsum", uniq_name = "_QFloop_testExsum"}
+
+ integer(4) :: a(5,5,5), i, j, k, asum, xsum
+
+ i = 100
+ j = 200
+ k = 300
+
+ ! CHECK-COUNT-3: fir.do_loop {{.*}} unordered
+ do concurrent (i=1:5, j=1:5, k=1:5) ! shared(a)
+ ! CHECK: fir.coordinate_of
+ a(i,j,k) = 0
+ enddo
+ ! CHECK: fir.call @_FortranAioBeginExternalListOutput
+ print*, 'A:', i, j, k
+
+ ! CHECK-COUNT-3: fir.do_loop {{.*}} unordered
+ ! CHECK: fir.if
+ do concurrent (integer(1)::i=1:5, j=1:5, k=1:5, i.ne.j .and. k.ne.3) shared(a)
+ ! CHECK-COUNT-2: fir.coordinate_of
+ a(i,j,k) = a(i,j,k) + 1
+ enddo
+
+ ! CHECK-COUNT-3: fir.do_loop {{[^un]*}} -> index
+ asum = 0
+ do i=1,5
+ do j=1,5
+ do k=1,5
+ ! CHECK: fir.coordinate_of
+ asum = asum + a(i,j,k)
+ enddo
+ enddo
+ enddo
+ ! CHECK: fir.call @_FortranAioBeginExternalListOutput
+ print*, 'B:', i, j, k, '-', asum
+
+ ! CHECK: fir.do_loop {{.*}} unordered
+ ! CHECK-COUNT-2: fir.if
+ do concurrent (integer(2)::i=1:5, i.ne.3)
+ if (i.eq.2 .or. i.eq.4) goto 5 ! fir.if
+ ! CHECK: fir.call @_FortranAioBeginExternalListOutput
+ print*, 'C:', i
+ 5 continue
+ enddo
+
+ ! CHECK: fir.do_loop {{.*}} unordered
+ ! CHECK-COUNT-2: fir.if
+ do concurrent (integer(2)::i=1:5, i.ne.3)
+ if (i.eq.2 .or. i.eq.4) then ! fir.if
+ goto 6
+ endif
+ ! CHECK: fir.call @_FortranAioBeginExternalListOutput
+ print*, 'D:', i
+ 6 continue
+ enddo
+
+ ! CHECK-NOT: fir.do_loop
+ ! CHECK-NOT: fir.if
+ do concurrent (integer(2)::i=1:5, i.ne.3)
+ goto (7, 7) i+1
+ ! CHECK: fir.call @_FortranAioBeginExternalListOutput
+ print*, 'E:', i
+ 7 continue
+ enddo
+
+ xsum = 0.0
+ ! CHECK-NOT: fir.do_loop
+ do x = 1.5, 3.5, 0.3
+ xsum = xsum + 1
+ enddo
+ ! CHECK: fir.call @_FortranAioBeginExternalFormattedOutput
+ print '(" F:",X,F3.1,A,I2)', x, ' -', xsum
+end subroutine loop_test
+
+! CHECK-LABEL: print_nothing
+subroutine print_nothing(k1, k2)
+ if (k1 > 0) then
+ ! CHECK: br [[header:\^bb[0-9]+]]
+ ! CHECK: [[header]]
+ do while (k1 > k2)
+ print*, k1, k2 ! no output
+ k2 = k2 + 1
+ ! CHECK: br [[header]]
+ end do
+ end if
+end
+
+ call loop_test
+ call print_nothing(2, 2)
+end
More information about the flang-commits
mailing list