[flang-commits] [flang] 84b9ae6 - [flang]Add support for do concurrent

Mats Petersson via flang-commits flang-commits at lists.llvm.org
Mon Jun 13 04:29:12 PDT 2022


Author: Mats Petersson
Date: 2022-06-13T12:28:49+01:00
New Revision: 84b9ae662419ce97b3cb13879be431f6a0c9eaa4

URL: https://github.com/llvm/llvm-project/commit/84b9ae662419ce97b3cb13879be431f6a0c9eaa4
DIFF: https://github.com/llvm/llvm-project/commit/84b9ae662419ce97b3cb13879be431f6a0c9eaa4.diff

LOG: [flang]Add support for do concurrent

[flang]Add support for do concurrent

Upstreaming from fir-dev on https://github.com/flang-compiler/f18-llvm-project

Support for concurrent execution in do-loops.

A selection of tests is also added.

Co-authored-by: V Donaldson <vdonaldson at nvidia.com>

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D127240

Added: 
    flang/test/Lower/loops.f90

Modified: 
    flang/lib/Lower/Bridge.cpp
    flang/test/Lower/OpenMP/omp-unstructured.f90

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index eb58c4ee06346..6caea151c20e8 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -87,7 +87,10 @@ struct IncrementLoopInfo {
   const Fortran::lower::SomeExpr *lowerExpr;
   const Fortran::lower::SomeExpr *upperExpr;
   const Fortran::lower::SomeExpr *stepExpr;
+  const Fortran::lower::SomeExpr *maskExpr = nullptr;
   bool isUnordered; // do concurrent, forall
+  llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
+  llvm::SmallVector<const Fortran::semantics::Symbol *> sharedSymList;
   mlir::Value loopVariable = nullptr;
   mlir::Value stepValue = nullptr; // possible uses in multiple blocks
 
@@ -98,6 +101,7 @@ struct IncrementLoopInfo {
   bool hasRealControl = false;
   mlir::Value tripVariable = nullptr;
   mlir::Block *headerBlock = nullptr; // loop entry and test block
+  mlir::Block *maskBlock = nullptr;   // concurrent loop mask block
   mlir::Block *bodyBlock = nullptr;   // first loop body block
   mlir::Block *exitBlock = nullptr;   // loop exit target block
 };
@@ -636,9 +640,24 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   }
 
   /// Generate the address of loop variable \p sym.
+  /// If \p sym is not mapped yet, allocate local storage for it.
   mlir::Value genLoopVariableAddress(mlir::Location loc,
-                                     const Fortran::semantics::Symbol &sym) {
-    assert(lookupSymbol(sym) && "loop control variable must already be in map");
+                                     const Fortran::semantics::Symbol &sym,
+                                     bool isUnordered) {
+    if (isUnordered || sym.has<Fortran::semantics::HostAssocDetails>() ||
+        sym.has<Fortran::semantics::UseDetails>()) {
+      if (!shallowLookupSymbol(sym)) {
+        // Do concurrent loop variables are not mapped yet since they are local
+        // to the Do concurrent scope (same for OpenMP loops).
+        auto newVal = builder->createTemporary(loc, genType(sym),
+                                               toStringRef(sym.name()));
+        bindIfNewSymbol(sym, newVal);
+        return newVal;
+      }
+    }
+    auto entry = lookupSymbol(sym);
+    (void)entry;
+    assert(entry && "loop control variable must already be in map");
     Fortran::lower::StatementContext stmtCtx;
     return fir::getBase(
         genExprAddr(Fortran::evaluate::AsGenericExpr(sym).value(), stmtCtx));
@@ -973,6 +992,34 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     builder->create<fir::SelectOp>(loc, selectExpr, indexList, blockList);
   }
 
+  /// Collect DO CONCURRENT or FORALL loop control information.
+  IncrementLoopNestInfo getConcurrentControl(
+      const Fortran::parser::ConcurrentHeader &header,
+      const std::list<Fortran::parser::LocalitySpec> &localityList = {}) {
+    IncrementLoopNestInfo incrementLoopNestInfo;
+    for (const Fortran::parser::ConcurrentControl &control :
+         std::get<std::list<Fortran::parser::ConcurrentControl>>(header.t))
+      incrementLoopNestInfo.emplace_back(
+          *std::get<0>(control.t).symbol, std::get<1>(control.t),
+          std::get<2>(control.t), std::get<3>(control.t), /*isUnordered=*/true);
+    IncrementLoopInfo &info = incrementLoopNestInfo.back();
+    info.maskExpr = Fortran::semantics::GetExpr(
+        std::get<std::optional<Fortran::parser::ScalarLogicalExpr>>(header.t));
+    for (const Fortran::parser::LocalitySpec &x : localityList) {
+      if (const auto *localInitList =
+              std::get_if<Fortran::parser::LocalitySpec::LocalInit>(&x.u))
+        for (const Fortran::parser::Name &x : localInitList->v)
+          info.localInitSymList.push_back(x.symbol);
+      if (const auto *sharedList =
+              std::get_if<Fortran::parser::LocalitySpec::Shared>(&x.u))
+        for (const Fortran::parser::Name &x : sharedList->v)
+          info.sharedSymList.push_back(x.symbol);
+      if (std::get_if<Fortran::parser::LocalitySpec::Local>(&x.u))
+        TODO(toLocation(), "do concurrent locality specs not implemented");
+    }
+    return incrementLoopNestInfo;
+  }
+
   /// Generate FIR for a DO construct.  There are six variants:
   ///  - unstructured infinite and while loops
   ///  - structured and unstructured increment loops
@@ -1029,7 +1076,34 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         info.exitBlock = exitBlock;
       }
     } else {
-      TODO(toLocation(), "infinite/unstructured loop/concurrent loop");
+      const auto *concurrent =
+          std::get_if<Fortran::parser::LoopControl::Concurrent>(
+              &loopControl->u);
+      assert(concurrent && "invalid DO loop variant");
+      incrementLoopNestInfo = getConcurrentControl(
+          std::get<Fortran::parser::ConcurrentHeader>(concurrent->t),
+          std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent->t));
+      if (unstructuredContext) {
+        maybeStartBlock(preheaderBlock);
+        for (IncrementLoopInfo &info : incrementLoopNestInfo) {
+          // The original loop body provides the body and latch blocks of the
+          // innermost dimension.  The (first) body block of a non-innermost
+          // dimension is the preheader block of the immediately enclosed
+          // dimension.  The latch block of a non-innermost dimension is the
+          // exit block of the immediately enclosed dimension.
+          auto createNextExitBlock = [&]() {
+            // Create unstructured loop exit blocks, outermost to innermost.
+            return exitBlock = insertBlock(exitBlock);
+          };
+          bool isInnermost = &info == &incrementLoopNestInfo.back();
+          bool isOutermost = &info == &incrementLoopNestInfo.front();
+          info.headerBlock = isOutermost ? headerBlock : createNextBeginBlock();
+          info.bodyBlock = isInnermost ? bodyBlock : createNextBeginBlock();
+          info.exitBlock = isOutermost ? exitBlock : createNextExitBlock();
+          if (info.maskExpr)
+            info.maskBlock = createNextBeginBlock();
+        }
+      }
     }
 
     // Increment loop begin code.  (Infinite/while code was already generated.)
@@ -1065,8 +1139,28 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         return builder->createRealConstant(loc, controlType, 1u);
       return builder->createIntegerConstant(loc, controlType, 1); // step
     };
+    auto handleLocalitySpec = [&](IncrementLoopInfo &info) {
+      // Generate Local Init Assignments
+      for (const Fortran::semantics::Symbol *sym : info.localInitSymList) {
+        const auto *hostDetails =
+            sym->detailsIf<Fortran::semantics::HostAssocDetails>();
+        assert(hostDetails && "missing local_init variable host variable");
+        const Fortran::semantics::Symbol &hostSym = hostDetails->symbol();
+        (void)hostSym;
+        TODO(loc, "do concurrent locality specs not implemented");
+      }
+      // Handle shared locality spec
+      for (const Fortran::semantics::Symbol *sym : info.sharedSymList) {
+        const auto *hostDetails =
+            sym->detailsIf<Fortran::semantics::HostAssocDetails>();
+        assert(hostDetails && "missing shared variable host variable");
+        const Fortran::semantics::Symbol &hostSym = hostDetails->symbol();
+        copySymbolBinding(hostSym, *sym);
+      }
+    };
     for (IncrementLoopInfo &info : incrementLoopNestInfo) {
-      info.loopVariable = genLoopVariableAddress(loc, info.loopVariableSym);
+      info.loopVariable =
+          genLoopVariableAddress(loc, info.loopVariableSym, info.isUnordered);
       mlir::Value lowerValue = genControlValue(info.lowerExpr, info);
       mlir::Value upperValue = genControlValue(info.upperExpr, info);
       info.stepValue = genControlValue(info.stepExpr, info);
@@ -1081,8 +1175,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         mlir::Value value = builder->createConvert(
             loc, info.getLoopVariableType(), info.doLoop.getInductionVar());
         builder->create<fir::StoreOp>(loc, value, info.loopVariable);
-        // TODO: Mask expr
-        // TODO: handle Locality Spec
+        if (info.maskExpr) {
+          Fortran::lower::StatementContext stmtCtx;
+          mlir::Value maskCond = createFIRExpr(loc, info.maskExpr, stmtCtx);
+          stmtCtx.finalize();
+          mlir::Value maskCondCast =
+              builder->createConvert(loc, builder->getI1Type(), maskCond);
+          auto ifOp = builder->create<fir::IfOp>(loc, maskCondCast,
+                                                 /*withElseRegion=*/false);
+          builder->setInsertionPointToStart(&ifOp.getThenRegion().front());
+        }
+        handleLocalitySpec(info);
         continue;
       }
 
@@ -1119,16 +1222,33 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       builder->create<fir::StoreOp>(loc, lowerValue, info.loopVariable);
 
       // Unstructured loop header - generate loop condition and mask.
+      // Note - Currently there is no way to tag a loop as a concurrent loop.
       startBlock(info.headerBlock);
       tripCount = builder->create<fir::LoadOp>(loc, info.tripVariable);
       mlir::Value zero =
           builder->createIntegerConstant(loc, tripCount.getType(), 0);
       auto cond = builder->create<mlir::arith::CmpIOp>(
           loc, mlir::arith::CmpIPredicate::sgt, tripCount, zero);
-      // TODO: mask expression
-      genFIRConditionalBranch(cond, info.bodyBlock, info.exitBlock);
-      if (&info != &incrementLoopNestInfo.back()) // not innermost
-        startBlock(info.bodyBlock); // preheader block of enclosed dimension
+      if (info.maskExpr) {
+        genFIRConditionalBranch(cond, info.maskBlock, info.exitBlock);
+        startBlock(info.maskBlock);
+        mlir::Block *latchBlock = getEval().getLastNestedEvaluation().block;
+        assert(latchBlock && "missing masked concurrent loop latch block");
+        Fortran::lower::StatementContext stmtCtx;
+        mlir::Value maskCond = createFIRExpr(loc, info.maskExpr, stmtCtx);
+        stmtCtx.finalize();
+        genFIRConditionalBranch(maskCond, info.bodyBlock, latchBlock);
+      } else {
+        genFIRConditionalBranch(cond, info.bodyBlock, info.exitBlock);
+        if (&info != &incrementLoopNestInfo.back()) // not innermost
+          startBlock(info.bodyBlock); // preheader block of enclosed dimension
+      }
+      if (!info.localInitSymList.empty()) {
+        mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
+        builder->setInsertionPointToStart(info.bodyBlock);
+        handleLocalitySpec(info);
+        builder->restoreInsertionPoint(insertPt);
+      }
     }
   }
 

diff  --git a/flang/test/Lower/OpenMP/omp-unstructured.f90 b/flang/test/Lower/OpenMP/omp-unstructured.f90
index 28ed87169bbc6..e7d48bb269349 100644
--- a/flang/test/Lower/OpenMP/omp-unstructured.f90
+++ b/flang/test/Lower/OpenMP/omp-unstructured.f90
@@ -59,8 +59,8 @@ subroutine ss2(n) ! unstructured OpenMP construct; loop exit inside construct
 end
 
 ! CHECK-LABEL: func @_QPss3{{.*}} {
-! CHECK:     %[[ALLOCA_K:.*]] = fir.alloca i32 {bindc_name = "k", {{.*}}}
 ! CHECK:   omp.parallel {
+! CHECK:     %[[ALLOCA_K:.*]] = fir.alloca i32 {bindc_name = "k", pinned}
 ! CHECK:     %[[ALLOCA_1:.*]] = fir.alloca i32 {{{.*}}, pinned}
 ! CHECK:     %[[ALLOCA_2:.*]] = fir.alloca i32 {{{.*}}, pinned}
 ! CHECK:     br ^bb1

diff  --git a/flang/test/Lower/loops.f90 b/flang/test/Lower/loops.f90
new file mode 100644
index 0000000000000..2a95c69d6e225
--- /dev/null
+++ b/flang/test/Lower/loops.f90
@@ -0,0 +1,109 @@
+! RUN: bbc -emit-fir -o - %s | FileCheck %s
+
+! CHECK-LABEL: loop_test
+subroutine loop_test
+  ! CHECK: %[[VAL_2:.*]] = fir.alloca i16 {bindc_name = "i"}
+  ! CHECK: %[[VAL_3:.*]] = fir.alloca i16 {bindc_name = "i"}
+  ! CHECK: %[[VAL_4:.*]] = fir.alloca i16 {bindc_name = "i"}
+  ! CHECK: %[[VAL_5:.*]] = fir.alloca i8 {bindc_name = "k"}
+  ! CHECK: %[[VAL_6:.*]] = fir.alloca i8 {bindc_name = "j"}
+  ! CHECK: %[[VAL_7:.*]] = fir.alloca i8 {bindc_name = "i"}
+  ! CHECK: %[[VAL_8:.*]] = fir.alloca i32 {bindc_name = "k"}
+  ! CHECK: %[[VAL_9:.*]] = fir.alloca i32 {bindc_name = "j"}
+  ! CHECK: %[[VAL_10:.*]] = fir.alloca i32 {bindc_name = "i"}
+  ! CHECK: %[[VAL_11:.*]] = fir.alloca !fir.array<5x5x5xi32> {bindc_name = "a", uniq_name = "_QFloop_testEa"}
+  ! CHECK: %[[VAL_12:.*]] = fir.alloca i32 {bindc_name = "asum", uniq_name = "_QFloop_testEasum"}
+  ! CHECK: %[[VAL_13:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFloop_testEi"}
+  ! CHECK: %[[VAL_14:.*]] = fir.alloca i32 {bindc_name = "j", uniq_name = "_QFloop_testEj"}
+  ! CHECK: %[[VAL_15:.*]] = fir.alloca i32 {bindc_name = "k", uniq_name = "_QFloop_testEk"}
+  ! CHECK: %[[VAL_16:.*]] = fir.alloca f32 {bindc_name = "x", uniq_name = "_QFloop_testEx"}
+  ! CHECK: %[[VAL_17:.*]] = fir.alloca i32 {bindc_name = "xsum", uniq_name = "_QFloop_testExsum"}
+
+  integer(4) :: a(5,5,5), i, j, k, asum, xsum
+
+  i = 100
+  j = 200
+  k = 300
+
+  ! CHECK-COUNT-3: fir.do_loop {{.*}} unordered
+  do concurrent (i=1:5, j=1:5, k=1:5) ! shared(a)
+    ! CHECK: fir.coordinate_of
+    a(i,j,k) = 0
+  enddo
+  ! CHECK: fir.call @_FortranAioBeginExternalListOutput
+  print*, 'A:', i, j, k
+
+  ! CHECK-COUNT-3: fir.do_loop {{.*}} unordered
+  ! CHECK: fir.if
+  do concurrent (integer(1)::i=1:5, j=1:5, k=1:5, i.ne.j .and. k.ne.3) shared(a)
+    ! CHECK-COUNT-2: fir.coordinate_of
+    a(i,j,k) = a(i,j,k) + 1
+  enddo
+
+  ! CHECK-COUNT-3: fir.do_loop {{[^un]*}} -> index
+  asum = 0
+  do i=1,5
+    do j=1,5
+      do k=1,5
+        ! CHECK: fir.coordinate_of
+        asum = asum + a(i,j,k)
+      enddo
+    enddo
+  enddo
+  ! CHECK: fir.call @_FortranAioBeginExternalListOutput
+  print*, 'B:', i, j, k, '-', asum
+
+  ! CHECK: fir.do_loop {{.*}} unordered
+  ! CHECK-COUNT-2: fir.if
+  do concurrent (integer(2)::i=1:5, i.ne.3)
+    if (i.eq.2 .or. i.eq.4) goto 5 ! fir.if
+    ! CHECK: fir.call @_FortranAioBeginExternalListOutput
+    print*, 'C:', i
+  5 continue
+  enddo
+
+  ! CHECK: fir.do_loop {{.*}} unordered
+  ! CHECK-COUNT-2: fir.if
+  do concurrent (integer(2)::i=1:5, i.ne.3)
+    if (i.eq.2 .or. i.eq.4) then ! fir.if
+      goto 6
+    endif
+    ! CHECK: fir.call @_FortranAioBeginExternalListOutput
+    print*, 'D:', i
+  6 continue
+  enddo
+
+  ! CHECK-NOT: fir.do_loop
+  ! CHECK-NOT: fir.if
+  do concurrent (integer(2)::i=1:5, i.ne.3)
+    goto (7, 7) i+1
+    ! CHECK: fir.call @_FortranAioBeginExternalListOutput
+    print*, 'E:', i
+  7 continue
+  enddo
+
+  xsum = 0.0
+  ! CHECK-NOT: fir.do_loop
+  do x = 1.5, 3.5, 0.3
+    xsum = xsum + 1
+  enddo
+  ! CHECK: fir.call @_FortranAioBeginExternalFormattedOutput
+  print '(" F:",X,F3.1,A,I2)', x, ' -', xsum
+end subroutine loop_test
+
+! CHECK-LABEL: print_nothing
+subroutine print_nothing(k1, k2)
+  if (k1 > 0) then
+    ! CHECK: br [[header:\^bb[0-9]+]]
+    ! CHECK: [[header]]
+    do while (k1 > k2)
+      print*, k1, k2 ! no output
+      k2 = k2 + 1
+      ! CHECK: br [[header]]
+    end do
+  end if
+end
+
+  call loop_test
+  call print_nothing(2, 2)
+end


        


More information about the flang-commits mailing list