[llvm-branch-commits] [flang] [flang][OpenMP] Convert DataSharingProcessor to omp::Clause (PR #81629)

Krzysztof Parzyszek via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Feb 23 06:45:00 PST 2024


https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/81629

>From 158901f346ae6ff3f010b2c96c01129b578291ee Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Wed, 21 Feb 2024 17:10:30 -0600
Subject: [PATCH] [flang][OpenMP] Convert DataSharingProcessor to omp::Clause

---
 .../lib/Lower/OpenMP/DataSharingProcessor.cpp | 277 +++++++++---------
 flang/lib/Lower/OpenMP/DataSharingProcessor.h |  13 +-
 flang/lib/Lower/OpenMP/OpenMP.cpp             |   6 +-
 3 files changed, 146 insertions(+), 150 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 136bda0b582ee3..3b687d03adf371 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -79,30 +79,28 @@ void DataSharingProcessor::copyLastPrivateSymbol(
 }
 
 void DataSharingProcessor::collectOmpObjectListSymbol(
-    const Fortran::parser::OmpObjectList &ompObjectList,
+    const omp::ObjectList &objects,
     llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet) {
-  for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
-    Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
+  for (const omp::Object &object : objects) {
+    Fortran::semantics::Symbol *sym = object.id();
     symbolSet.insert(sym);
   }
 }
 
 void DataSharingProcessor::collectSymbolsForPrivatization() {
   bool hasCollapse = false;
-  for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
+  for (const omp::Clause &clause : clauses) {
     if (const auto &privateClause =
-            std::get_if<Fortran::parser::OmpClause::Private>(&clause.u)) {
+            std::get_if<omp::clause::Private>(&clause.u)) {
       collectOmpObjectListSymbol(privateClause->v, privatizedSymbols);
     } else if (const auto &firstPrivateClause =
-                   std::get_if<Fortran::parser::OmpClause::Firstprivate>(
-                       &clause.u)) {
+                   std::get_if<omp::clause::Firstprivate>(&clause.u)) {
       collectOmpObjectListSymbol(firstPrivateClause->v, privatizedSymbols);
     } else if (const auto &lastPrivateClause =
-                   std::get_if<Fortran::parser::OmpClause::Lastprivate>(
-                       &clause.u)) {
+                   std::get_if<omp::clause::Lastprivate>(&clause.u)) {
       collectOmpObjectListSymbol(lastPrivateClause->v, privatizedSymbols);
       hasLastPrivateOp = true;
-    } else if (std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) {
+    } else if (std::get_if<omp::clause::Collapse>(&clause.u)) {
       hasCollapse = true;
     }
   }
@@ -135,138 +133,135 @@ void DataSharingProcessor::insertBarrier() {
 void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
   bool cmpCreated = false;
   mlir::OpBuilder::InsertPoint localInsPt = firOpBuilder.saveInsertionPoint();
-  for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
-    if (std::get_if<Fortran::parser::OmpClause::Lastprivate>(&clause.u)) {
-      // TODO: Add lastprivate support for simd construct
-      if (mlir::isa<mlir::omp::SectionOp>(op)) {
-        if (&eval == &eval.parentConstruct->getLastNestedEvaluation()) {
-          // For `omp.sections`, lastprivatized variables occur in
-          // lexically final `omp.section` operation. The following FIR
-          // shall be generated for the same:
-          //
-          // omp.sections lastprivate(...) {
-          //  omp.section {...}
-          //  omp.section {...}
-          //  omp.section {
-          //      fir.allocate for `private`/`firstprivate`
-          //      <More operations here>
-          //      fir.if %true {
-          //          ^%lpv_update_blk
-          //      }
-          //  }
-          // }
-          //
-          // To keep code consistency while handling privatization
-          // through this control flow, add a `fir.if` operation
-          // that always evaluates to true, in order to create
-          // a dedicated sub-region in `omp.section` where
-          // lastprivate FIR can reside. Later canonicalizations
-          // will optimize away this operation.
-          if (!eval.lowerAsUnstructured()) {
-            auto ifOp = firOpBuilder.create<fir::IfOp>(
-                op->getLoc(),
-                firOpBuilder.createIntegerConstant(
-                    op->getLoc(), firOpBuilder.getIntegerType(1), 0x1),
-                /*else*/ false);
-            firOpBuilder.setInsertionPointToStart(
-                &ifOp.getThenRegion().front());
-
-            const Fortran::parser::OpenMPConstruct *parentOmpConstruct =
-                eval.parentConstruct->getIf<Fortran::parser::OpenMPConstruct>();
-            assert(parentOmpConstruct &&
-                   "Expected a valid enclosing OpenMP construct");
-            const Fortran::parser::OpenMPSectionsConstruct *sectionsConstruct =
-                std::get_if<Fortran::parser::OpenMPSectionsConstruct>(
-                    &parentOmpConstruct->u);
-            assert(sectionsConstruct &&
-                   "Expected an enclosing omp.sections construct");
-            const Fortran::parser::OmpClauseList &sectionsEndClauseList =
-                std::get<Fortran::parser::OmpClauseList>(
-                    std::get<Fortran::parser::OmpEndSectionsDirective>(
-                        sectionsConstruct->t)
-                        .t);
-            for (const Fortran::parser::OmpClause &otherClause :
-                 sectionsEndClauseList.v)
-              if (std::get_if<Fortran::parser::OmpClause::Nowait>(
-                      &otherClause.u))
-                // Emit implicit barrier to synchronize threads and avoid data
-                // races on post-update of lastprivate variables when `nowait`
-                // clause is present.
-                firOpBuilder.create<mlir::omp::BarrierOp>(
-                    converter.getCurrentLocation());
-            firOpBuilder.setInsertionPointToStart(
-                &ifOp.getThenRegion().front());
-            lastPrivIP = firOpBuilder.saveInsertionPoint();
-            firOpBuilder.setInsertionPoint(ifOp);
-            insPt = firOpBuilder.saveInsertionPoint();
-          } else {
-            // Lastprivate operation is inserted at the end
-            // of the lexically last section in the sections
-            // construct
-            mlir::OpBuilder::InsertPoint unstructuredSectionsIP =
-                firOpBuilder.saveInsertionPoint();
-            mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
-            firOpBuilder.setInsertionPoint(lastOper);
-            lastPrivIP = firOpBuilder.saveInsertionPoint();
-            firOpBuilder.restoreInsertionPoint(unstructuredSectionsIP);
-          }
-        }
-      } else if (mlir::isa<mlir::omp::WsLoopOp>(op)) {
-        // Update the original variable just before exiting the worksharing
-        // loop. Conversion as follows:
+  for (const omp::Clause &clause : clauses) {
+    if (clause.id != llvm::omp::OMPC_lastprivate)
+      continue;
+    // TODO: Add lastprivate support for simd construct
+    if (mlir::isa<mlir::omp::SectionOp>(op)) {
+      if (&eval == &eval.parentConstruct->getLastNestedEvaluation()) {
+        // For `omp.sections`, lastprivatized variables occur in
+        // lexically final `omp.section` operation. The following FIR
+        // shall be generated for the same:
         //
-        //                       omp.wsloop {
-        // omp.wsloop {            ...
-        //    ...                  store
-        //    store       ===>     %v = arith.addi %iv, %step
-        //    omp.yield            %cmp = %step < 0 ? %v < %ub : %v > %ub
-        // }                       fir.if %cmp {
-        //                           fir.store %v to %loopIV
-        //                           ^%lpv_update_blk:
-        //                         }
-        //                         omp.yield
-        //                       }
+        // omp.sections lastprivate(...) {
+        //  omp.section {...}
+        //  omp.section {...}
+        //  omp.section {
+        //      fir.allocate for `private`/`firstprivate`
+        //      <More operations here>
+        //      fir.if %true {
+        //          ^%lpv_update_blk
+        //      }
+        //  }
+        // }
         //
+        // To keep code consistency while handling privatization
+        // through this control flow, add a `fir.if` operation
+        // that always evaluates to true, in order to create
+        // a dedicated sub-region in `omp.section` where
+        // lastprivate FIR can reside. Later canonicalizations
+        // will optimize away this operation.
+        if (!eval.lowerAsUnstructured()) {
+          auto ifOp = firOpBuilder.create<fir::IfOp>(
+              op->getLoc(),
+              firOpBuilder.createIntegerConstant(
+                  op->getLoc(), firOpBuilder.getIntegerType(1), 0x1),
+              /*else*/ false);
+          firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
 
-        // Only generate the compare once in presence of multiple LastPrivate
-        // clauses.
-        if (cmpCreated)
-          continue;
-        cmpCreated = true;
+          const Fortran::parser::OpenMPConstruct *parentOmpConstruct =
+              eval.parentConstruct->getIf<Fortran::parser::OpenMPConstruct>();
+          assert(parentOmpConstruct &&
+                 "Expected a valid enclosing OpenMP construct");
+          const Fortran::parser::OpenMPSectionsConstruct *sectionsConstruct =
+              std::get_if<Fortran::parser::OpenMPSectionsConstruct>(
+                  &parentOmpConstruct->u);
+          assert(sectionsConstruct &&
+                 "Expected an enclosing omp.sections construct");
+          const Fortran::parser::OmpClauseList &sectionsEndClauseList =
+              std::get<Fortran::parser::OmpClauseList>(
+                  std::get<Fortran::parser::OmpEndSectionsDirective>(
+                      sectionsConstruct->t)
+                      .t);
+          for (const Fortran::parser::OmpClause &otherClause :
+               sectionsEndClauseList.v)
+            if (std::get_if<Fortran::parser::OmpClause::Nowait>(&otherClause.u))
+              // Emit implicit barrier to synchronize threads and avoid data
+              // races on post-update of lastprivate variables when `nowait`
+              // clause is present.
+              firOpBuilder.create<mlir::omp::BarrierOp>(
+                  converter.getCurrentLocation());
+          firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+          lastPrivIP = firOpBuilder.saveInsertionPoint();
+          firOpBuilder.setInsertionPoint(ifOp);
+          insPt = firOpBuilder.saveInsertionPoint();
+        } else {
+          // Lastprivate operation is inserted at the end
+          // of the lexically last section in the sections
+          // construct
+          mlir::OpBuilder::InsertPoint unstructuredSectionsIP =
+              firOpBuilder.saveInsertionPoint();
+          mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
+          firOpBuilder.setInsertionPoint(lastOper);
+          lastPrivIP = firOpBuilder.saveInsertionPoint();
+          firOpBuilder.restoreInsertionPoint(unstructuredSectionsIP);
+        }
+      }
+    } else if (mlir::isa<mlir::omp::WsLoopOp>(op)) {
+      // Update the original variable just before exiting the worksharing
+      // loop. Conversion as follows:
+      //
+      //                       omp.wsloop {
+      // omp.wsloop {            ...
+      //    ...                  store
+      //    store       ===>     %v = arith.addi %iv, %step
+      //    omp.yield            %cmp = %step < 0 ? %v < %ub : %v > %ub
+      // }                       fir.if %cmp {
+      //                           fir.store %v to %loopIV
+      //                           ^%lpv_update_blk:
+      //                         }
+      //                         omp.yield
+      //                       }
+      //
 
-        mlir::Location loc = op->getLoc();
-        mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
-        firOpBuilder.setInsertionPoint(lastOper);
+      // Only generate the compare once in presence of multiple LastPrivate
+      // clauses.
+      if (cmpCreated)
+        continue;
+      cmpCreated = true;
 
-        mlir::Value iv = op->getRegion(0).front().getArguments()[0];
-        mlir::Value ub =
-            mlir::dyn_cast<mlir::omp::WsLoopOp>(op).getUpperBound()[0];
-        mlir::Value step = mlir::dyn_cast<mlir::omp::WsLoopOp>(op).getStep()[0];
+      mlir::Location loc = op->getLoc();
+      mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
+      firOpBuilder.setInsertionPoint(lastOper);
 
-        // v = iv + step
-        // cmp = step < 0 ? v < ub : v > ub
-        mlir::Value v = firOpBuilder.create<mlir::arith::AddIOp>(loc, iv, step);
-        mlir::Value zero =
-            firOpBuilder.createIntegerConstant(loc, step.getType(), 0);
-        mlir::Value negativeStep = firOpBuilder.create<mlir::arith::CmpIOp>(
-            loc, mlir::arith::CmpIPredicate::slt, step, zero);
-        mlir::Value vLT = firOpBuilder.create<mlir::arith::CmpIOp>(
-            loc, mlir::arith::CmpIPredicate::slt, v, ub);
-        mlir::Value vGT = firOpBuilder.create<mlir::arith::CmpIOp>(
-            loc, mlir::arith::CmpIPredicate::sgt, v, ub);
-        mlir::Value cmpOp = firOpBuilder.create<mlir::arith::SelectOp>(
-            loc, negativeStep, vLT, vGT);
+      mlir::Value iv = op->getRegion(0).front().getArguments()[0];
+      mlir::Value ub =
+          mlir::dyn_cast<mlir::omp::WsLoopOp>(op).getUpperBound()[0];
+      mlir::Value step = mlir::dyn_cast<mlir::omp::WsLoopOp>(op).getStep()[0];
 
-        auto ifOp = firOpBuilder.create<fir::IfOp>(loc, cmpOp, /*else*/ false);
-        firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-        assert(loopIV && "loopIV was not set");
-        firOpBuilder.create<fir::StoreOp>(op->getLoc(), v, loopIV);
-        lastPrivIP = firOpBuilder.saveInsertionPoint();
-      } else {
-        TODO(converter.getCurrentLocation(),
-             "lastprivate clause in constructs other than "
-             "simd/worksharing-loop");
-      }
+      // v = iv + step
+      // cmp = step < 0 ? v < ub : v > ub
+      mlir::Value v = firOpBuilder.create<mlir::arith::AddIOp>(loc, iv, step);
+      mlir::Value zero =
+          firOpBuilder.createIntegerConstant(loc, step.getType(), 0);
+      mlir::Value negativeStep = firOpBuilder.create<mlir::arith::CmpIOp>(
+          loc, mlir::arith::CmpIPredicate::slt, step, zero);
+      mlir::Value vLT = firOpBuilder.create<mlir::arith::CmpIOp>(
+          loc, mlir::arith::CmpIPredicate::slt, v, ub);
+      mlir::Value vGT = firOpBuilder.create<mlir::arith::CmpIOp>(
+          loc, mlir::arith::CmpIPredicate::sgt, v, ub);
+      mlir::Value cmpOp = firOpBuilder.create<mlir::arith::SelectOp>(
+          loc, negativeStep, vLT, vGT);
+
+      auto ifOp = firOpBuilder.create<fir::IfOp>(loc, cmpOp, /*else*/ false);
+      firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      assert(loopIV && "loopIV was not set");
+      firOpBuilder.create<fir::StoreOp>(op->getLoc(), v, loopIV);
+      lastPrivIP = firOpBuilder.saveInsertionPoint();
+    } else {
+      TODO(converter.getCurrentLocation(),
+           "lastprivate clause in constructs other than "
+           "simd/worksharing-loop");
     }
   }
   firOpBuilder.restoreInsertionPoint(localInsPt);
@@ -290,14 +285,12 @@ void DataSharingProcessor::collectSymbols(
 }
 
 void DataSharingProcessor::collectDefaultSymbols() {
-  for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
-    if (const auto &defaultClause =
-            std::get_if<Fortran::parser::OmpClause::Default>(&clause.u)) {
-      if (defaultClause->v.v ==
-          Fortran::parser::OmpDefaultClause::Type::Private)
+  for (const omp::Clause &clause : clauses) {
+    if (const auto *defaultClause =
+            std::get_if<omp::clause::Default>(&clause.u)) {
+      if (defaultClause->v == omp::clause::Default::Type::Private)
         collectSymbols(Fortran::semantics::Symbol::Flag::OmpPrivate);
-      else if (defaultClause->v.v ==
-               Fortran::parser::OmpDefaultClause::Type::Firstprivate)
+      else if (defaultClause->v == omp::clause::Default::Type::Firstprivate)
         collectSymbols(Fortran::semantics::Symbol::Flag::OmpFirstPrivate);
     }
   }
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
index 10c0a30c09c391..a82d6064abf85d 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
@@ -12,6 +12,7 @@
 #ifndef FORTRAN_LOWER_DATASHARINGPROCESSOR_H
 #define FORTRAN_LOWER_DATASHARINGPROCESSOR_H
 
+#include "Clauses.h"
 #include "flang/Lower/AbstractConverter.h"
 #include "flang/Lower/OpenMP.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
@@ -33,14 +34,15 @@ class DataSharingProcessor {
   llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInNestedRegions;
   llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInParentRegions;
   Fortran::lower::AbstractConverter &converter;
+  Fortran::semantics::SemanticsContext &semaCtx;
   fir::FirOpBuilder &firOpBuilder;
-  const Fortran::parser::OmpClauseList &opClauseList;
+  omp::List<omp::Clause> clauses;
   Fortran::lower::pft::Evaluation &eval;
 
   bool needBarrier();
   void collectSymbols(Fortran::semantics::Symbol::Flag flag);
   void collectOmpObjectListSymbol(
-      const Fortran::parser::OmpObjectList &ompObjectList,
+      const omp::ObjectList &objects,
       llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet);
   void collectSymbolsForPrivatization();
   void insertBarrier();
@@ -57,11 +59,12 @@ class DataSharingProcessor {
 
 public:
   DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
+                       Fortran::semantics::SemanticsContext &semaCtx,
                        const Fortran::parser::OmpClauseList &opClauseList,
                        Fortran::lower::pft::Evaluation &eval)
-      : hasLastPrivateOp(false), converter(converter),
-        firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList),
-        eval(eval) {}
+      : hasLastPrivateOp(false), converter(converter), semaCtx(semaCtx),
+        firOpBuilder(converter.getFirOpBuilder()),
+        clauses(omp::makeList(opClauseList, semaCtx)), eval(eval) {}
   // Privatisation is split into two steps.
   // Step1 performs cloning of all privatisation clauses and copying for
   // firstprivates. Step1 is performed at the place where process/processStep1
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index a07aba28bd76f6..e20c39887c59e5 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -372,7 +372,7 @@ static void createBodyOfOp(Op &op, OpWithBodyGenInfo &info) {
   std::optional<DataSharingProcessor> tempDsp;
   if (privatize) {
     if (!info.dsp) {
-      tempDsp.emplace(info.converter, *info.clauses, info.eval);
+      tempDsp.emplace(info.converter, info.semaCtx, *info.clauses, info.eval);
       tempDsp->processStep1();
     }
   }
@@ -1383,7 +1383,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
                const Fortran::parser::OmpClauseList &loopOpClauseList,
                mlir::Location loc) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  DataSharingProcessor dsp(converter, loopOpClauseList, eval);
+  DataSharingProcessor dsp(converter, semaCtx, loopOpClauseList, eval);
   dsp.processStep1();
 
   Fortran::lower::StatementContext stmtCtx;
@@ -1439,7 +1439,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
                          const Fortran::parser::OmpClauseList *endClauseList,
                          mlir::Location loc) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  DataSharingProcessor dsp(converter, beginClauseList, eval);
+  DataSharingProcessor dsp(converter, semaCtx, beginClauseList, eval);
   dsp.processStep1();
 
   Fortran::lower::StatementContext stmtCtx;



More information about the llvm-branch-commits mailing list