[flang-commits] [flang] 335b399 - [flang] Do concurrent locality specifiers
V Donaldson via flang-commits
flang-commits at lists.llvm.org
Tue Aug 8 10:10:24 PDT 2023
Author: V Donaldson
Date: 2023-08-08T10:09:38-07:00
New Revision: 335b3990ef9115e3b20eb9dfa32393a7fdfde4e3
URL: https://github.com/llvm/llvm-project/commit/335b3990ef9115e3b20eb9dfa32393a7fdfde4e3
DIFF: https://github.com/llvm/llvm-project/commit/335b3990ef9115e3b20eb9dfa32393a7fdfde4e3.diff
LOG: [flang] Do concurrent locality specifiers
Added:
Modified:
flang/include/flang/Lower/SymbolMap.h
flang/include/flang/Semantics/symbol.h
flang/lib/Lower/Bridge.cpp
flang/lib/Lower/ConvertExpr.cpp
flang/lib/Lower/SymbolMap.cpp
flang/test/Lower/loops.f90
Removed:
################################################################################
diff --git a/flang/include/flang/Lower/SymbolMap.h b/flang/include/flang/Lower/SymbolMap.h
index dc36a672f8c15f..a55e4b133fe0a8 100644
--- a/flang/include/flang/Lower/SymbolMap.h
+++ b/flang/include/flang/Lower/SymbolMap.h
@@ -316,10 +316,10 @@ class SymMap {
}
private:
- /// Add `symbol` to the current map and bind a `box`.
+ /// Bind `box` to `symRef` in the symbol map.
void makeSym(semantics::SymbolRef symRef, const SymbolBox &box,
bool force = false) {
- const auto *sym = &symRef.get().GetUltimate();
+ auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
if (force)
symbolMapStack.back().erase(sym);
assert(box && "cannot add an undefined symbol box");
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index 333f63b2c2842b..93ed272149f307 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -707,6 +707,9 @@ class Symbol {
},
details_);
}
+ bool HasLocalLocality() const {
+ return test(Flag::LocalityLocal) || test(Flag::LocalityLocalInit);
+ }
bool operator==(const Symbol &that) const { return this == &that; }
bool operator!=(const Symbol &that) const { return !(*this == that); }
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 5e52ca69079aac..0b1d07d9bc0fb2 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -95,6 +95,11 @@ struct IncrementLoopInfo {
return fir::unwrapRefType(loopVariable.getType());
}
+ bool hasLocalitySpecs() const {
+ return !localSymList.empty() || !localInitSymList.empty() ||
+ !sharedSymList.empty();
+ }
+
// Data members common to both structured and unstructured loops.
const Fortran::semantics::Symbol &loopVariableSym;
const Fortran::lower::SomeExpr *lowerExpr;
@@ -102,6 +107,7 @@ struct IncrementLoopInfo {
const Fortran::lower::SomeExpr *stepExpr;
const Fortran::lower::SomeExpr *maskExpr = nullptr;
bool isUnordered; // do concurrent, forall
+ llvm::SmallVector<const Fortran::semantics::Symbol *> localSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> sharedSymList;
mlir::Value loopVariable = nullptr;
@@ -1514,6 +1520,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
info.maskExpr = Fortran::semantics::GetExpr(
std::get<std::optional<Fortran::parser::ScalarLogicalExpr>>(header.t));
for (const Fortran::parser::LocalitySpec &x : localityList) {
+ if (const auto *localList =
+ std::get_if<Fortran::parser::LocalitySpec::Local>(&x.u))
+ for (const Fortran::parser::Name &x : localList->v)
+ info.localSymList.push_back(x.symbol);
if (const auto *localInitList =
std::get_if<Fortran::parser::LocalitySpec::LocalInit>(&x.u))
for (const Fortran::parser::Name &x : localInitList->v)
@@ -1522,12 +1532,38 @@ class FirConverter : public Fortran::lower::AbstractConverter {
std::get_if<Fortran::parser::LocalitySpec::Shared>(&x.u))
for (const Fortran::parser::Name &x : sharedList->v)
info.sharedSymList.push_back(x.symbol);
- if (std::get_if<Fortran::parser::LocalitySpec::Local>(&x.u))
- TODO(toLocation(), "do concurrent locality specs not implemented");
}
return incrementLoopNestInfo;
}
+ /// Create DO CONCURRENT construct symbol bindings and generate LOCAL_INIT
+ /// assignments.
+ void handleLocalitySpecs(const IncrementLoopInfo &info) {
+ Fortran::semantics::SemanticsContext &semanticsContext =
+ bridge.getSemanticsContext();
+ for (const Fortran::semantics::Symbol *sym : info.localSymList)
+ createHostAssociateVarClone(*sym);
+ for (const Fortran::semantics::Symbol *sym : info.localInitSymList) {
+ createHostAssociateVarClone(*sym);
+ const auto *hostDetails =
+ sym->detailsIf<Fortran::semantics::HostAssocDetails>();
+ assert(hostDetails && "missing locality spec host symbol");
+ const Fortran::semantics::Symbol *hostSym = &hostDetails->symbol();
+ Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
+ Fortran::evaluate::Assignment assign{
+ ea.Designate(Fortran::evaluate::DataRef{*sym}).value(),
+ ea.Designate(Fortran::evaluate::DataRef{*hostSym}).value()};
+ if (Fortran::semantics::IsPointer(*sym))
+ assign.u = Fortran::evaluate::Assignment::BoundsSpec{};
+ genAssignment(assign);
+ }
+ for (const Fortran::semantics::Symbol *sym : info.sharedSymList) {
+ const auto *hostDetails =
+ sym->detailsIf<Fortran::semantics::HostAssocDetails>();
+ copySymbolBinding(hostDetails->symbol(), *sym);
+ }
+ }
+
/// Generate FIR for a DO construct. There are six variants:
/// - unstructured infinite and while loops
/// - structured and unstructured increment loops
@@ -1656,25 +1692,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
return builder->createRealConstant(loc, controlType, 1u);
return builder->createIntegerConstant(loc, controlType, 1); // step
};
- auto handleLocalitySpec = [&](IncrementLoopInfo &info) {
- // Generate Local Init Assignments
- for (const Fortran::semantics::Symbol *sym : info.localInitSymList) {
- const auto *hostDetails =
- sym->detailsIf<Fortran::semantics::HostAssocDetails>();
- assert(hostDetails && "missing local_init variable host variable");
- const Fortran::semantics::Symbol &hostSym = hostDetails->symbol();
- (void)hostSym;
- TODO(loc, "do concurrent locality specs not implemented");
- }
- // Handle shared locality spec
- for (const Fortran::semantics::Symbol *sym : info.sharedSymList) {
- const auto *hostDetails =
- sym->detailsIf<Fortran::semantics::HostAssocDetails>();
- assert(hostDetails && "missing shared variable host variable");
- const Fortran::semantics::Symbol &hostSym = hostDetails->symbol();
- copySymbolBinding(hostSym, *sym);
- }
- };
for (IncrementLoopInfo &info : incrementLoopNestInfo) {
info.loopVariable =
genLoopVariableAddress(loc, info.loopVariableSym, info.isUnordered);
@@ -1714,7 +1731,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/*withElseRegion=*/false);
builder->setInsertionPointToStart(&ifOp.getThenRegion().front());
}
- handleLocalitySpec(info);
+ if (info.hasLocalitySpecs())
+ handleLocalitySpecs(info);
continue;
}
@@ -1771,10 +1789,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
if (&info != &incrementLoopNestInfo.back()) // not innermost
startBlock(info.bodyBlock); // preheader block of enclosed dimension
}
- if (!info.localInitSymList.empty()) {
+ if (info.hasLocalitySpecs()) {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
builder->setInsertionPointToStart(info.bodyBlock);
- handleLocalitySpec(info);
+ handleLocalitySpecs(info);
builder->restoreInsertionPoint(insertPt);
}
}
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index e811b45e6c74f4..ea952198251c9b 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -590,13 +590,15 @@ absentBoxToUnallocatedBox(fir::FirOpBuilder &builder, mlir::Location loc,
// associations.
template <typename A>
const Fortran::semantics::Symbol &getFirstSym(const A &obj) {
- return obj.GetFirstSymbol().GetUltimate();
+ const Fortran::semantics::Symbol &sym = obj.GetFirstSymbol();
+ return sym.HasLocalLocality() ? sym : sym.GetUltimate();
}
// Helper to get the ultimate last symbol.
template <typename A>
const Fortran::semantics::Symbol &getLastSym(const A &obj) {
- return obj.GetLastSymbol().GetUltimate();
+ const Fortran::semantics::Symbol &sym = obj.GetLastSymbol();
+ return sym.HasLocalLocality() ? sym : sym.GetUltimate();
}
// Return true if TRANSPOSE should be lowered without a runtime call.
diff --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp
index 2d9c16346cac3f..ce078a7fbde651 100644
--- a/flang/lib/Lower/SymbolMap.cpp
+++ b/flang/lib/Lower/SymbolMap.cpp
@@ -35,10 +35,10 @@ void Fortran::lower::SymMap::addSymbol(Fortran::semantics::SymbolRef sym,
Fortran::lower::SymbolBox
Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) {
- Fortran::semantics::SymbolRef sym = symRef.get().GetUltimate();
+ auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend();
jmap != jend; ++jmap) {
- auto iter = jmap->find(&*sym);
+ auto iter = jmap->find(sym);
if (iter != jmap->end())
return iter->second;
}
@@ -47,8 +47,9 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) {
Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol(
Fortran::semantics::SymbolRef symRef) {
+ auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
auto &map = symbolMapStack.back();
- auto iter = map.find(&symRef.get().GetUltimate());
+ auto iter = map.find(sym);
if (iter != map.end())
return iter->second;
return SymbolBox::None{};
@@ -59,14 +60,14 @@ Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol(
/// host-association in OpenMP code.
Fortran::lower::SymbolBox Fortran::lower::SymMap::lookupOneLevelUpSymbol(
Fortran::semantics::SymbolRef symRef) {
- Fortran::semantics::SymbolRef sym = symRef.get().GetUltimate();
+ auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
auto jmap = symbolMapStack.rbegin();
auto jend = symbolMapStack.rend();
if (jmap == jend)
return SymbolBox::None{};
// Skip one level in symbol map stack.
for (++jmap; jmap != jend; ++jmap) {
- auto iter = jmap->find(&*sym);
+ auto iter = jmap->find(sym);
if (iter != jmap->end())
return iter->second;
}
diff --git a/flang/test/Lower/loops.f90 b/flang/test/Lower/loops.f90
index febae0e8a0dd66..12385cbb4c186f 100644
--- a/flang/test/Lower/loops.f90
+++ b/flang/test/Lower/loops.f90
@@ -91,6 +91,71 @@ subroutine loop_test
print '(" F:",X,F3.1,A,I2)', x, ' -', xsum
end subroutine loop_test
+! CHECK-LABEL: c.func @_QPlis
+subroutine lis(n)
+ ! CHECK-DAG: fir.alloca i32 {bindc_name = "m"}
+ ! CHECK-DAG: fir.alloca i32 {bindc_name = "j"}
+ ! CHECK-DAG: fir.alloca i32 {bindc_name = "i"}
+ ! CHECK-DAG: fir.alloca i8 {bindc_name = "i"}
+ ! CHECK-DAG: fir.alloca i32 {bindc_name = "j", uniq_name = "_QFlisEj"}
+ ! CHECK-DAG: fir.alloca i32 {bindc_name = "k", uniq_name = "_QFlisEk"}
+ ! CHECK-DAG: fir.alloca !fir.box<!fir.ptr<!fir.array<?x?x?xi32>>> {bindc_name = "p", uniq_name = "_QFlisEp"}
+ ! CHECK-DAG: fir.alloca !fir.array<?x?x?xi32>, %{{.*}}, %{{.*}}, %{{.*}} {bindc_name = "a", fir.target, uniq_name = "_QFlisEa"}
+ ! CHECK-DAG: fir.alloca !fir.array<?x?xi32>, %{{.*}}, %{{.*}} {bindc_name = "r", uniq_name = "_QFlisEr"}
+ ! CHECK-DAG: fir.alloca !fir.array<?x?xi32>, %{{.*}}, %{{.*}} {bindc_name = "s", uniq_name = "_QFlisEs"}
+ ! CHECK-DAG: fir.alloca !fir.array<?x?xi32>, %{{.*}}, %{{.*}} {bindc_name = "t", uniq_name = "_QFlisEt"}
+ integer, target :: a(n,n,n) ! operand via p
+ integer :: r(n,n) ! result, unspecified locality
+ integer :: s(n,n) ! shared locality
+ integer :: t(n,n) ! local locality
+ integer, pointer :: p(:,:,:) ! local_init locality
+
+ p => a
+ ! CHECK: fir.do_loop %arg1 = %c0{{.*}} to %{{.*}} step %c1{{.*}} unordered iter_args(%arg2 = %{{.*}}) -> (!fir.array<?x?xi32>) {
+ ! CHECK: fir.do_loop %arg3 = %c0{{.*}} to %{{.*}} step %c1{{.*}} unordered iter_args(%arg4 = %arg2) -> (!fir.array<?x?xi32>) {
+ ! CHECK: }
+ ! CHECK: }
+ r = 0
+
+ ! CHECK: fir.do_loop %arg1 = %{{.*}} to %{{.*}} step %{{.*}} unordered {
+ ! CHECK: fir.do_loop %arg2 = %{{.*}} to %{{.*}} step %c1{{.*}} iter_args(%arg3 = %{{.*}}) -> (index, i32) {
+ ! CHECK: }
+ ! CHECK: }
+ do concurrent (integer(kind=1)::i=n:1:-1)
+ do j = 1,n
+ a(i,j,:) = 2*(i+j)
+ s(i,j) = -i-j
+ enddo
+ enddo
+
+ ! CHECK: fir.do_loop %arg1 = %{{.*}} to %{{.*}} step %c1{{.*}} unordered {
+ ! CHECK: fir.do_loop %arg2 = %{{.*}} to %{{.*}} step %c1{{.*}} unordered {
+ ! CHECK: fir.if %{{.*}} {
+ ! CHECK: %[[V_95:[0-9]+]] = fir.alloca !fir.array<?x?xi32>, %{{.*}}, %{{.*}} {bindc_name = "t", pinned, uniq_name = "_QFlisEt"}
+ ! CHECK: %[[V_96:[0-9]+]] = fir.alloca !fir.box<!fir.ptr<!fir.array<?x?x?xi32>>> {bindc_name = "p", pinned, uniq_name = "_QFlisEp"}
+ ! CHECK: fir.store %{{.*}} to %[[V_96]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?x?xi32>>>>
+ ! CHECK: fir.do_loop %arg3 = %{{.*}} to %{{.*}} step %c1{{.*}} iter_args(%arg4 = %{{.*}}) -> (index, i32) {
+ ! CHECK: fir.do_loop %arg5 = %{{.*}} to %{{.*}} step %c1{{.*}} unordered {
+ ! CHECK: fir.load %[[V_96]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?x?xi32>>>>
+ ! CHECK: fir.convert %[[V_95]] : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.ref<!fir.array<?xi32>>
+ ! CHECK: }
+ ! CHECK: }
+ ! CHECK: fir.convert %[[V_95]] : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.ref<!fir.array<?xi32>>
+ ! CHECK: }
+ ! CHECK: }
+ ! CHECK: }
+ do concurrent (i=1:n,j=1:n,i.ne.j) local(t) local_init(p) shared(s)
+ do k=1,n
+ do concurrent (m=1:n)
+ t(k,m) = p(k,m,k)
+ enddo
+ enddo
+ r(i,j) = t(i,j) + s(i,j)
+ enddo
+
+ print*, sum(r) ! n=6 -> 210
+end
+
! CHECK-LABEL: print_nothing
subroutine print_nothing(k1, k2)
if (k1 > 0) then
@@ -105,5 +170,6 @@ subroutine print_nothing(k1, k2)
end
call loop_test
+ call lis(6)
call print_nothing(2, 2)
end
More information about the flang-commits mailing list