[flang-commits] [flang] 335b399 - [flang] Do concurrent locality specifiers

V Donaldson via flang-commits flang-commits at lists.llvm.org
Tue Aug 8 10:10:24 PDT 2023


Author: V Donaldson
Date: 2023-08-08T10:09:38-07:00
New Revision: 335b3990ef9115e3b20eb9dfa32393a7fdfde4e3

URL: https://github.com/llvm/llvm-project/commit/335b3990ef9115e3b20eb9dfa32393a7fdfde4e3
DIFF: https://github.com/llvm/llvm-project/commit/335b3990ef9115e3b20eb9dfa32393a7fdfde4e3.diff

LOG: [flang] Do concurrent locality specifiers

Added: 
    

Modified: 
    flang/include/flang/Lower/SymbolMap.h
    flang/include/flang/Semantics/symbol.h
    flang/lib/Lower/Bridge.cpp
    flang/lib/Lower/ConvertExpr.cpp
    flang/lib/Lower/SymbolMap.cpp
    flang/test/Lower/loops.f90

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Lower/SymbolMap.h b/flang/include/flang/Lower/SymbolMap.h
index dc36a672f8c15f..a55e4b133fe0a8 100644
--- a/flang/include/flang/Lower/SymbolMap.h
+++ b/flang/include/flang/Lower/SymbolMap.h
@@ -316,10 +316,10 @@ class SymMap {
   }
 
 private:
-  /// Add `symbol` to the current map and bind a `box`.
+  /// Bind `box` to `symRef` in the symbol map.
   void makeSym(semantics::SymbolRef symRef, const SymbolBox &box,
                bool force = false) {
-    const auto *sym = &symRef.get().GetUltimate();
+    auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
     if (force)
       symbolMapStack.back().erase(sym);
     assert(box && "cannot add an undefined symbol box");

diff  --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index 333f63b2c2842b..93ed272149f307 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -707,6 +707,9 @@ class Symbol {
         },
         details_);
   }
+  bool HasLocalLocality() const {
+    return test(Flag::LocalityLocal) || test(Flag::LocalityLocalInit);
+  }
 
   bool operator==(const Symbol &that) const { return this == &that; }
   bool operator!=(const Symbol &that) const { return !(*this == that); }

diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 5e52ca69079aac..0b1d07d9bc0fb2 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -95,6 +95,11 @@ struct IncrementLoopInfo {
     return fir::unwrapRefType(loopVariable.getType());
   }
 
+  bool hasLocalitySpecs() const {
+    return !localSymList.empty() || !localInitSymList.empty() ||
+           !sharedSymList.empty();
+  }
+
   // Data members common to both structured and unstructured loops.
   const Fortran::semantics::Symbol &loopVariableSym;
   const Fortran::lower::SomeExpr *lowerExpr;
@@ -102,6 +107,7 @@ struct IncrementLoopInfo {
   const Fortran::lower::SomeExpr *stepExpr;
   const Fortran::lower::SomeExpr *maskExpr = nullptr;
   bool isUnordered; // do concurrent, forall
+  llvm::SmallVector<const Fortran::semantics::Symbol *> localSymList;
   llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
   llvm::SmallVector<const Fortran::semantics::Symbol *> sharedSymList;
   mlir::Value loopVariable = nullptr;
@@ -1514,6 +1520,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     info.maskExpr = Fortran::semantics::GetExpr(
         std::get<std::optional<Fortran::parser::ScalarLogicalExpr>>(header.t));
     for (const Fortran::parser::LocalitySpec &x : localityList) {
+      if (const auto *localList =
+              std::get_if<Fortran::parser::LocalitySpec::Local>(&x.u))
+        for (const Fortran::parser::Name &x : localList->v)
+          info.localSymList.push_back(x.symbol);
       if (const auto *localInitList =
               std::get_if<Fortran::parser::LocalitySpec::LocalInit>(&x.u))
         for (const Fortran::parser::Name &x : localInitList->v)
@@ -1522,12 +1532,38 @@ class FirConverter : public Fortran::lower::AbstractConverter {
               std::get_if<Fortran::parser::LocalitySpec::Shared>(&x.u))
         for (const Fortran::parser::Name &x : sharedList->v)
           info.sharedSymList.push_back(x.symbol);
-      if (std::get_if<Fortran::parser::LocalitySpec::Local>(&x.u))
-        TODO(toLocation(), "do concurrent locality specs not implemented");
     }
     return incrementLoopNestInfo;
   }
 
+  /// Create DO CONCURRENT construct symbol bindings and generate LOCAL_INIT
+  /// assignments.
+  void handleLocalitySpecs(const IncrementLoopInfo &info) {
+    Fortran::semantics::SemanticsContext &semanticsContext =
+        bridge.getSemanticsContext();
+    for (const Fortran::semantics::Symbol *sym : info.localSymList)
+      createHostAssociateVarClone(*sym);
+    for (const Fortran::semantics::Symbol *sym : info.localInitSymList) {
+      createHostAssociateVarClone(*sym);
+      const auto *hostDetails =
+          sym->detailsIf<Fortran::semantics::HostAssocDetails>();
+      assert(hostDetails && "missing locality spec host symbol");
+      const Fortran::semantics::Symbol *hostSym = &hostDetails->symbol();
+      Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
+      Fortran::evaluate::Assignment assign{
+          ea.Designate(Fortran::evaluate::DataRef{*sym}).value(),
+          ea.Designate(Fortran::evaluate::DataRef{*hostSym}).value()};
+      if (Fortran::semantics::IsPointer(*sym))
+        assign.u = Fortran::evaluate::Assignment::BoundsSpec{};
+      genAssignment(assign);
+    }
+    for (const Fortran::semantics::Symbol *sym : info.sharedSymList) {
+      const auto *hostDetails =
+          sym->detailsIf<Fortran::semantics::HostAssocDetails>();
+      copySymbolBinding(hostDetails->symbol(), *sym);
+    }
+  }
+
   /// Generate FIR for a DO construct. There are six variants:
   ///  - unstructured infinite and while loops
   ///  - structured and unstructured increment loops
@@ -1656,25 +1692,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         return builder->createRealConstant(loc, controlType, 1u);
       return builder->createIntegerConstant(loc, controlType, 1); // step
     };
-    auto handleLocalitySpec = [&](IncrementLoopInfo &info) {
-      // Generate Local Init Assignments
-      for (const Fortran::semantics::Symbol *sym : info.localInitSymList) {
-        const auto *hostDetails =
-            sym->detailsIf<Fortran::semantics::HostAssocDetails>();
-        assert(hostDetails && "missing local_init variable host variable");
-        const Fortran::semantics::Symbol &hostSym = hostDetails->symbol();
-        (void)hostSym;
-        TODO(loc, "do concurrent locality specs not implemented");
-      }
-      // Handle shared locality spec
-      for (const Fortran::semantics::Symbol *sym : info.sharedSymList) {
-        const auto *hostDetails =
-            sym->detailsIf<Fortran::semantics::HostAssocDetails>();
-        assert(hostDetails && "missing shared variable host variable");
-        const Fortran::semantics::Symbol &hostSym = hostDetails->symbol();
-        copySymbolBinding(hostSym, *sym);
-      }
-    };
     for (IncrementLoopInfo &info : incrementLoopNestInfo) {
       info.loopVariable =
           genLoopVariableAddress(loc, info.loopVariableSym, info.isUnordered);
@@ -1714,7 +1731,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                                                  /*withElseRegion=*/false);
           builder->setInsertionPointToStart(&ifOp.getThenRegion().front());
         }
-        handleLocalitySpec(info);
+        if (info.hasLocalitySpecs())
+          handleLocalitySpecs(info);
         continue;
       }
 
@@ -1771,10 +1789,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         if (&info != &incrementLoopNestInfo.back()) // not innermost
           startBlock(info.bodyBlock); // preheader block of enclosed dimension
       }
-      if (!info.localInitSymList.empty()) {
+      if (info.hasLocalitySpecs()) {
         mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
         builder->setInsertionPointToStart(info.bodyBlock);
-        handleLocalitySpec(info);
+        handleLocalitySpecs(info);
         builder->restoreInsertionPoint(insertPt);
       }
     }

diff  --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index e811b45e6c74f4..ea952198251c9b 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -590,13 +590,15 @@ absentBoxToUnallocatedBox(fir::FirOpBuilder &builder, mlir::Location loc,
 // associations.
 template <typename A>
 const Fortran::semantics::Symbol &getFirstSym(const A &obj) {
-  return obj.GetFirstSymbol().GetUltimate();
+  const Fortran::semantics::Symbol &sym = obj.GetFirstSymbol();
+  return sym.HasLocalLocality() ? sym : sym.GetUltimate();
 }
 
 // Helper to get the ultimate last symbol.
 template <typename A>
 const Fortran::semantics::Symbol &getLastSym(const A &obj) {
-  return obj.GetLastSymbol().GetUltimate();
+  const Fortran::semantics::Symbol &sym = obj.GetLastSymbol();
+  return sym.HasLocalLocality() ? sym : sym.GetUltimate();
 }
 
 // Return true if TRANSPOSE should be lowered without a runtime call.

diff  --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp
index 2d9c16346cac3f..ce078a7fbde651 100644
--- a/flang/lib/Lower/SymbolMap.cpp
+++ b/flang/lib/Lower/SymbolMap.cpp
@@ -35,10 +35,10 @@ void Fortran::lower::SymMap::addSymbol(Fortran::semantics::SymbolRef sym,
 
 Fortran::lower::SymbolBox
 Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) {
-  Fortran::semantics::SymbolRef sym = symRef.get().GetUltimate();
+  auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
   for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend();
        jmap != jend; ++jmap) {
-    auto iter = jmap->find(&*sym);
+    auto iter = jmap->find(sym);
     if (iter != jmap->end())
       return iter->second;
   }
@@ -47,8 +47,9 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) {
 
 Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol(
     Fortran::semantics::SymbolRef symRef) {
+  auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
   auto &map = symbolMapStack.back();
-  auto iter = map.find(&symRef.get().GetUltimate());
+  auto iter = map.find(sym);
   if (iter != map.end())
     return iter->second;
   return SymbolBox::None{};
@@ -59,14 +60,14 @@ Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol(
 /// host-association in OpenMP code.
 Fortran::lower::SymbolBox Fortran::lower::SymMap::lookupOneLevelUpSymbol(
     Fortran::semantics::SymbolRef symRef) {
-  Fortran::semantics::SymbolRef sym = symRef.get().GetUltimate();
+  auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
   auto jmap = symbolMapStack.rbegin();
   auto jend = symbolMapStack.rend();
   if (jmap == jend)
     return SymbolBox::None{};
   // Skip one level in symbol map stack.
   for (++jmap; jmap != jend; ++jmap) {
-    auto iter = jmap->find(&*sym);
+    auto iter = jmap->find(sym);
     if (iter != jmap->end())
       return iter->second;
   }

diff  --git a/flang/test/Lower/loops.f90 b/flang/test/Lower/loops.f90
index febae0e8a0dd66..12385cbb4c186f 100644
--- a/flang/test/Lower/loops.f90
+++ b/flang/test/Lower/loops.f90
@@ -91,6 +91,71 @@ subroutine loop_test
   print '(" F:",X,F3.1,A,I2)', x, ' -', xsum
 end subroutine loop_test
 
+! CHECK-LABEL: c.func @_QPlis
+subroutine lis(n)
+  ! CHECK-DAG: fir.alloca i32 {bindc_name = "m"}
+  ! CHECK-DAG: fir.alloca i32 {bindc_name = "j"}
+  ! CHECK-DAG: fir.alloca i32 {bindc_name = "i"}
+  ! CHECK-DAG: fir.alloca i8 {bindc_name = "i"}
+  ! CHECK-DAG: fir.alloca i32 {bindc_name = "j", uniq_name = "_QFlisEj"}
+  ! CHECK-DAG: fir.alloca i32 {bindc_name = "k", uniq_name = "_QFlisEk"}
+  ! CHECK-DAG: fir.alloca !fir.box<!fir.ptr<!fir.array<?x?x?xi32>>> {bindc_name = "p", uniq_name = "_QFlisEp"}
+  ! CHECK-DAG: fir.alloca !fir.array<?x?x?xi32>, %{{.*}}, %{{.*}}, %{{.*}} {bindc_name = "a", fir.target, uniq_name = "_QFlisEa"}
+  ! CHECK-DAG: fir.alloca !fir.array<?x?xi32>, %{{.*}}, %{{.*}} {bindc_name = "r", uniq_name = "_QFlisEr"}
+  ! CHECK-DAG: fir.alloca !fir.array<?x?xi32>, %{{.*}}, %{{.*}} {bindc_name = "s", uniq_name = "_QFlisEs"}
+  ! CHECK-DAG: fir.alloca !fir.array<?x?xi32>, %{{.*}}, %{{.*}} {bindc_name = "t", uniq_name = "_QFlisEt"}
+  integer, target    :: a(n,n,n) ! operand via p
+  integer            :: r(n,n)   ! result, unspecified locality
+  integer            :: s(n,n)   ! shared locality
+  integer            :: t(n,n)   ! local locality
+  integer, pointer   :: p(:,:,:) ! local_init locality
+
+  p => a
+  ! CHECK:     fir.do_loop %arg1 = %c0{{.*}} to %{{.*}} step %c1{{.*}} unordered iter_args(%arg2 = %{{.*}}) -> (!fir.array<?x?xi32>) {
+  ! CHECK:       fir.do_loop %arg3 = %c0{{.*}} to %{{.*}} step %c1{{.*}} unordered iter_args(%arg4 = %arg2) -> (!fir.array<?x?xi32>) {
+  ! CHECK:       }
+  ! CHECK:     }
+  r = 0
+
+  ! CHECK:     fir.do_loop %arg1 = %{{.*}} to %{{.*}} step %{{.*}} unordered {
+  ! CHECK:       fir.do_loop %arg2 = %{{.*}} to %{{.*}} step %c1{{.*}} iter_args(%arg3 = %{{.*}}) -> (index, i32) {
+  ! CHECK:       }
+  ! CHECK:     }
+  do concurrent (integer(kind=1)::i=n:1:-1)
+    do j = 1,n
+      a(i,j,:) = 2*(i+j)
+      s(i,j)   = -i-j
+    enddo
+  enddo
+
+  ! CHECK:     fir.do_loop %arg1 = %{{.*}} to %{{.*}} step %c1{{.*}} unordered {
+  ! CHECK:       fir.do_loop %arg2 = %{{.*}} to %{{.*}} step %c1{{.*}} unordered {
+  ! CHECK:         fir.if %{{.*}} {
+  ! CHECK:           %[[V_95:[0-9]+]] = fir.alloca !fir.array<?x?xi32>, %{{.*}}, %{{.*}} {bindc_name = "t", pinned, uniq_name = "_QFlisEt"}
+  ! CHECK:           %[[V_96:[0-9]+]] = fir.alloca !fir.box<!fir.ptr<!fir.array<?x?x?xi32>>> {bindc_name = "p", pinned, uniq_name = "_QFlisEp"}
+  ! CHECK:           fir.store %{{.*}} to %[[V_96]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?x?xi32>>>>
+  ! CHECK:           fir.do_loop %arg3 = %{{.*}} to %{{.*}} step %c1{{.*}} iter_args(%arg4 = %{{.*}}) -> (index, i32) {
+  ! CHECK:             fir.do_loop %arg5 = %{{.*}} to %{{.*}} step %c1{{.*}} unordered {
+  ! CHECK:               fir.load %[[V_96]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?x?xi32>>>>
+  ! CHECK:               fir.convert %[[V_95]] : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.ref<!fir.array<?xi32>>
+  ! CHECK:             }
+  ! CHECK:           }
+  ! CHECK:           fir.convert %[[V_95]] : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.ref<!fir.array<?xi32>>
+  ! CHECK:         }
+  ! CHECK:       }
+  ! CHECK:     }
+  do concurrent (i=1:n,j=1:n,i.ne.j) local(t) local_init(p) shared(s)
+    do k=1,n
+      do concurrent (m=1:n)
+        t(k,m) = p(k,m,k)
+      enddo
+    enddo
+    r(i,j) = t(i,j) + s(i,j)
+  enddo
+
+  print*, sum(r) ! n=6 -> 210
+end
+
 ! CHECK-LABEL: print_nothing
 subroutine print_nothing(k1, k2)
   if (k1 > 0) then
@@ -105,5 +170,6 @@ subroutine print_nothing(k1, k2)
 end
 
   call loop_test
+  call lis(6)
   call print_nothing(2, 2)
 end


        


More information about the flang-commits mailing list