[flang] [llvm] [Flang][OpenMP] Prevent re-composition of composite constructs (PR #102613)

Sergio Afonso via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 19 02:33:42 PDT 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/102613

>From 32b2169da3bb80f3df79bf8060ab07b5a177019f Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 9 Aug 2024 12:58:27 +0100
Subject: [PATCH] [Flang][OpenMP] Prevent re-composition of composite
 constructs

After decomposition of OpenMP compound constructs and assignment of applicable
clauses to each leaf construct, composite constructs are then combined again
into a single element in the construct queue. This helped later lowering stages
easily identify composite constructs.

However, as a result of the re-composition stage, the same list of clauses is
used to produce all MLIR operations corresponding to each leaf of the original
composite construct. This defeats the existing logic that introduces implicit
clauses and decides which leaf construct(s) each clause applies to.

This patch removes construct re-composition logic and updates Flang lowering to
be able to identify composite constructs from a list of leaf constructs. As a
result, the right set of clauses is produced for each operation representing a
leaf of a composite construct.
---
 flang/lib/Lower/OpenMP/Decomposer.cpp         |  60 ++-
 flang/lib/Lower/OpenMP/Decomposer.h           |  10 +-
 flang/lib/Lower/OpenMP/OpenMP.cpp             |  96 ++--
 .../Lower/OpenMP/Todo/omp-do-simd-linear.f90  |   2 +-
 .../Lower/OpenMP/default-clause-byref.f90     |   4 +-
 flang/test/Lower/OpenMP/default-clause.f90    |   4 +-
 .../Frontend/OpenMP/ConstructCompositionT.h   | 425 ------------------
 7 files changed, 105 insertions(+), 496 deletions(-)
 delete mode 100644 llvm/include/llvm/Frontend/OpenMP/ConstructCompositionT.h

diff --git a/flang/lib/Lower/OpenMP/Decomposer.cpp b/flang/lib/Lower/OpenMP/Decomposer.cpp
index dfd85897469e28..33568bf96b5dfb 100644
--- a/flang/lib/Lower/OpenMP/Decomposer.cpp
+++ b/flang/lib/Lower/OpenMP/Decomposer.cpp
@@ -22,7 +22,6 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Frontend/OpenMP/ClauseT.h"
-#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
 #include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
 #include "llvm/Frontend/OpenMP/OMP.h"
 #include "llvm/Support/raw_ostream.h"
@@ -68,12 +67,6 @@ struct ConstructDecomposition {
 };
 } // namespace
 
-static UnitConstruct mergeConstructs(uint32_t version,
-                                     llvm::ArrayRef<UnitConstruct> units) {
-  tomp::ConstructCompositionT compose(version, units);
-  return compose.merged;
-}
-
 namespace Fortran::lower::omp {
 LLVM_DUMP_METHOD llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                                                const UnitConstruct &uc) {
@@ -90,38 +83,37 @@ ConstructQueue buildConstructQueue(
     Fortran::lower::pft::Evaluation &eval, const parser::CharBlock &source,
     llvm::omp::Directive compound, const List<Clause> &clauses) {
 
-  List<UnitConstruct> constructs;
-
   ConstructDecomposition decompose(modOp, semaCtx, eval, compound, clauses);
   assert(!decompose.output.empty() && "Construct decomposition failed");
 
-  llvm::SmallVector<llvm::omp::Directive> loweringUnits;
-  std::ignore =
-      llvm::omp::getLeafOrCompositeConstructs(compound, loweringUnits);
-  uint32_t version = getOpenMPVersionAttribute(modOp);
-
-  int leafIndex = 0;
-  for (llvm::omp::Directive dir_id : loweringUnits) {
-    llvm::ArrayRef<llvm::omp::Directive> leafsOrSelf =
-        llvm::omp::getLeafConstructsOrSelf(dir_id);
-    size_t numLeafs = leafsOrSelf.size();
-
-    llvm::ArrayRef<UnitConstruct> toMerge{&decompose.output[leafIndex],
-                                          numLeafs};
-    auto &uc = constructs.emplace_back(mergeConstructs(version, toMerge));
-
-    if (!transferLocations(clauses, uc.clauses)) {
-      // If some clauses are left without source information, use the
-      // directive's source.
-      for (auto &clause : uc.clauses) {
-        if (clause.source.empty())
-          clause.source = source;
-      }
-    }
-    leafIndex += numLeafs;
+  for (UnitConstruct &uc : decompose.output) {
+    assert(getLeafConstructs(uc.id).empty() && "unexpected compound directive");
+    // If some clauses are left without source information, use the directive's
+    // source.
+    for (auto &clause : uc.clauses)
+      if (clause.source.empty())
+        clause.source = source;
+  }
+
+  return decompose.output;
+}
+
+bool matchLeafSequence(ConstructQueue::const_iterator item,
+                       const ConstructQueue &queue,
+                       llvm::omp::Directive directive) {
+  llvm::ArrayRef<llvm::omp::Directive> leafDirs =
+      llvm::omp::getLeafConstructsOrSelf(directive);
+
+  for (auto [dir, leaf] :
+       llvm::zip_longest(leafDirs, llvm::make_range(item, queue.end()))) {
+    if (!dir.has_value() || !leaf.has_value())
+      return false;
+
+    if (*dir != leaf->id)
+      return false;
   }
 
-  return constructs;
+  return true;
 }
 
 bool isLastItemInQueue(ConstructQueue::const_iterator item,
diff --git a/flang/lib/Lower/OpenMP/Decomposer.h b/flang/lib/Lower/OpenMP/Decomposer.h
index e85956ffe1a231..e3291b7c59e216 100644
--- a/flang/lib/Lower/OpenMP/Decomposer.h
+++ b/flang/lib/Lower/OpenMP/Decomposer.h
@@ -10,7 +10,6 @@
 
 #include "Clauses.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
 #include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
 #include "llvm/Frontend/OpenMP/OMP.h"
 #include "llvm/Support/Compiler.h"
@@ -49,6 +48,15 @@ ConstructQueue buildConstructQueue(mlir::ModuleOp modOp,
 
 bool isLastItemInQueue(ConstructQueue::const_iterator item,
                        const ConstructQueue &queue);
+
+/// Check whether the leaf constructs composing the given \c directive match,
+/// one-to-one and in order, the range of constructs starting at \c item and
+/// ending at the end of the \c queue. If \c directive is not a compound
+/// directive, this checks that \c item matches that directive and is the only
+/// element before the end of the \c queue.
+bool matchLeafSequence(ConstructQueue::const_iterator item,
+                       const ConstructQueue &queue,
+                       llvm::omp::Directive directive);
 } // namespace Fortran::lower::omp
 
 #endif // FORTRAN_LOWER_OPENMP_DECOMPOSER_H
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 64b581e8910d07..d614db8b68ef65 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2044,6 +2044,7 @@ static void genCompositeDistributeParallelDoSimd(
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
     mlir::Location loc, const ConstructQueue &queue,
     ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
+  assert(std::distance(item, queue.end()) == 4 && "Invalid leaf constructs");
   TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD");
 }
 
@@ -2054,17 +2055,23 @@ static void genCompositeDistributeSimd(
     ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
   lower::StatementContext stmtCtx;
 
+  assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
+  ConstructQueue::const_iterator distributeItem = item;
+  ConstructQueue::const_iterator simdItem = std::next(distributeItem);
+
   // Clause processing.
   mlir::omp::DistributeOperands distributeClauseOps;
-  genDistributeClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
-                       distributeClauseOps);
+  genDistributeClauses(converter, semaCtx, stmtCtx, distributeItem->clauses,
+                       loc, distributeClauseOps);
 
   mlir::omp::SimdOperands simdClauseOps;
-  genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps);
+  genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);
 
+  // Pass the innermost leaf construct's clauses because that's where COLLAPSE
+  // is placed by construct decomposition.
   mlir::omp::LoopNestOperands loopNestClauseOps;
   llvm::SmallVector<const semantics::Symbol *> iv;
-  genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc,
+  genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc,
                      loopNestClauseOps, iv);
 
   // Operation creation.
@@ -2086,7 +2093,7 @@ static void genCompositeDistributeSimd(
       llvm::concat<mlir::BlockArgument>(distributeOp.getRegion().getArguments(),
                                         simdOp.getRegion().getArguments()));
 
-  genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
+  genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem,
                 loopNestClauseOps, iv, /*wrapperSyms=*/{}, wrapperArgs,
                 llvm::omp::Directive::OMPD_distribute_simd, dsp);
 }
@@ -2100,19 +2107,25 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
                                DataSharingProcessor &dsp) {
   lower::StatementContext stmtCtx;
 
+  assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
+  ConstructQueue::const_iterator doItem = item;
+  ConstructQueue::const_iterator simdItem = std::next(doItem);
+
   // Clause processing.
   mlir::omp::WsloopOperands wsloopClauseOps;
   llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
   llvm::SmallVector<mlir::Type> wsloopReductionTypes;
-  genWsloopClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
+  genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
                    wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);
 
   mlir::omp::SimdOperands simdClauseOps;
-  genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps);
+  genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);
 
+  // Pass the innermost leaf construct's clauses because that's where COLLAPSE
+  // is placed by construct decomposition.
   mlir::omp::LoopNestOperands loopNestClauseOps;
   llvm::SmallVector<const semantics::Symbol *> iv;
-  genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc,
+  genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc,
                      loopNestClauseOps, iv);
 
   // Operation creation.
@@ -2133,7 +2146,7 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
   auto wrapperArgs = llvm::to_vector(llvm::concat<mlir::BlockArgument>(
       wsloopOp.getRegion().getArguments(), simdOp.getRegion().getArguments()));
 
-  genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
+  genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem,
                 loopNestClauseOps, iv, wsloopReductionSyms, wrapperArgs,
                 llvm::omp::Directive::OMPD_do_simd, dsp);
 }
@@ -2143,6 +2156,7 @@ static void genCompositeTaskloopSimd(
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
     mlir::Location loc, const ConstructQueue &queue,
     ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
+  assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
   TODO(loc, "Composite TASKLOOP SIMD");
 }
 
@@ -2150,6 +2164,36 @@ static void genCompositeTaskloopSimd(
 // Dispatch
 //===----------------------------------------------------------------------===//
 
+static bool genOMPCompositeDispatch(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    mlir::Location loc, const ConstructQueue &queue,
+    ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
+  using llvm::omp::Directive;
+  using lower::omp::matchLeafSequence;
+
+  if (matchLeafSequence(item, queue, Directive::OMPD_distribute_parallel_do))
+    genCompositeDistributeParallelDo(converter, symTable, semaCtx, eval, loc,
+                                     queue, item, dsp);
+  else if (matchLeafSequence(item, queue,
+                             Directive::OMPD_distribute_parallel_do_simd))
+    genCompositeDistributeParallelDoSimd(converter, symTable, semaCtx, eval,
+                                         loc, queue, item, dsp);
+  else if (matchLeafSequence(item, queue, Directive::OMPD_distribute_simd))
+    genCompositeDistributeSimd(converter, symTable, semaCtx, eval, loc, queue,
+                               item, dsp);
+  else if (matchLeafSequence(item, queue, Directive::OMPD_do_simd))
+    genCompositeDoSimd(converter, symTable, semaCtx, eval, loc, queue, item,
+                       dsp);
+  else if (matchLeafSequence(item, queue, Directive::OMPD_taskloop_simd))
+    genCompositeTaskloopSimd(converter, symTable, semaCtx, eval, loc, queue,
+                             item, dsp);
+  else
+    return false;
+
+  return true;
+}
+
 static void genOMPDispatch(lower::AbstractConverter &converter,
                            lower::SymMap &symTable,
                            semantics::SemanticsContext &semaCtx,
@@ -2163,10 +2207,18 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                   llvm::omp::Association::Loop;
   if (loopLeaf) {
     symTable.pushScope();
+    // TODO: Use one DataSharingProcessor for each leaf of a composite
+    // construct.
     loopDsp.emplace(converter, semaCtx, item->clauses, eval,
                     /*shouldCollectPreDeterminedSymbols=*/true,
                     /*useDelayedPrivatization=*/false, &symTable);
     loopDsp->processStep1();
+
+    if (genOMPCompositeDispatch(converter, symTable, semaCtx, eval, loc, queue,
+                                item, *loopDsp)) {
+      symTable.popScope();
+      return;
+    }
   }
 
   switch (llvm::omp::Directive dir = item->id) {
@@ -2262,29 +2314,11 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     // that use this construct, add a single construct for now.
     genSingleOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
-
-  // Composite constructs
-  case llvm::omp::Directive::OMPD_distribute_parallel_do:
-    genCompositeDistributeParallelDo(converter, symTable, semaCtx, eval, loc,
-                                     queue, item, *loopDsp);
-    break;
-  case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
-    genCompositeDistributeParallelDoSimd(converter, symTable, semaCtx, eval,
-                                         loc, queue, item, *loopDsp);
-    break;
-  case llvm::omp::Directive::OMPD_distribute_simd:
-    genCompositeDistributeSimd(converter, symTable, semaCtx, eval, loc, queue,
-                               item, *loopDsp);
-    break;
-  case llvm::omp::Directive::OMPD_do_simd:
-    genCompositeDoSimd(converter, symTable, semaCtx, eval, loc, queue, item,
-                       *loopDsp);
-    break;
-  case llvm::omp::Directive::OMPD_taskloop_simd:
-    genCompositeTaskloopSimd(converter, symTable, semaCtx, eval, loc, queue,
-                             item, *loopDsp);
-    break;
   default:
+    // Combined and composite constructs should have been split into a sequence
+    // of leaf constructs when building the construct queue.
+    assert(llvm::omp::isLeafConstruct(dir) &&
+           "Unexpected compound construct.");
     break;
   }
 
diff --git a/flang/test/Lower/OpenMP/Todo/omp-do-simd-linear.f90 b/flang/test/Lower/OpenMP/Todo/omp-do-simd-linear.f90
index 2f5366c2a5b368..4caf12a0169c42 100644
--- a/flang/test/Lower/OpenMP/Todo/omp-do-simd-linear.f90
+++ b/flang/test/Lower/OpenMP/Todo/omp-do-simd-linear.f90
@@ -4,7 +4,7 @@
 ! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
 subroutine testDoSimdLinear(int_array)
         integer :: int_array(*)
-!CHECK: not yet implemented: Unhandled clause LINEAR in DO construct
+!CHECK: not yet implemented: Unhandled clause LINEAR in SIMD construct
 !$omp do simd linear(int_array)
         do index_ = 1, 10
         end do
diff --git a/flang/test/Lower/OpenMP/default-clause-byref.f90 b/flang/test/Lower/OpenMP/default-clause-byref.f90
index d9f0eff4e6fde1..626ba3335a8c10 100644
--- a/flang/test/Lower/OpenMP/default-clause-byref.f90
+++ b/flang/test/Lower/OpenMP/default-clause-byref.f90
@@ -197,9 +197,9 @@ subroutine nested_default_clause_tests
 !CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFnested_default_clause_testsEz"}
 !CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
-!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: omp.parallel private({{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
 !CHECK: %[[PRIVATE_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: %[[PRIVATE_K_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_K]] {uniq_name = "_QFnested_default_clause_testsEk"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_X:.*]] : {{.*}}) {
diff --git a/flang/test/Lower/OpenMP/default-clause.f90 b/flang/test/Lower/OpenMP/default-clause.f90
index 775ce9ac801934..fefb5fcc4239e6 100644
--- a/flang/test/Lower/OpenMP/default-clause.f90
+++ b/flang/test/Lower/OpenMP/default-clause.f90
@@ -134,9 +134,9 @@ end program default_clause_lowering
 !CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFnested_default_clause_test1Ey"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFnested_default_clause_test1Ez"}
 !CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {uniq_name = "_QFnested_default_clause_test1Ez"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
-!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_test1Ey"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: omp.parallel private({{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
 !CHECK: %[[PRIVATE_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_X]] {uniq_name = "_QFnested_default_clause_test1Ex"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_test1Ey"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_test1Ez"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: %[[PRIVATE_K_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_K]] {uniq_name = "_QFnested_default_clause_test1Ek"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_X:.*]] : {{.*}}) {
diff --git a/llvm/include/llvm/Frontend/OpenMP/ConstructCompositionT.h b/llvm/include/llvm/Frontend/OpenMP/ConstructCompositionT.h
deleted file mode 100644
index b3a02cd5312170..00000000000000
--- a/llvm/include/llvm/Frontend/OpenMP/ConstructCompositionT.h
+++ /dev/null
@@ -1,425 +0,0 @@
-//===- ConstructCompositionT.h -- Composing compound constructs -----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-// Given a list of leaf construct, each with a set of clauses, generate the
-// compound construct whose leaf constructs are the given list, and whose clause
-// list is the merged lists of individual leaf clauses.
-//
-// *** At the moment it assumes that the individual constructs and their clauses
-// *** are a subset of those created by splitting a valid compound construct.
-//===----------------------------------------------------------------------===//
-#ifndef LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H
-#define LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H
-
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/BitVector.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Frontend/OpenMP/ClauseT.h"
-#include "llvm/Frontend/OpenMP/OMP.h"
-
-#include <iterator>
-#include <optional>
-#include <tuple>
-#include <unordered_map>
-#include <unordered_set>
-#include <utility>
-
-namespace tomp {
-template <typename ClauseType> struct ConstructCompositionT {
-  using ClauseTy = ClauseType;
-
-  using TypeTy = typename ClauseTy::TypeTy;
-  using IdTy = typename ClauseTy::IdTy;
-  using ExprTy = typename ClauseTy::ExprTy;
-
-  ConstructCompositionT(uint32_t version,
-                        llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> leafs);
-
-  DirectiveWithClauses<ClauseTy> merged;
-
-private:
-  // Use an ordered container, since we beed to maintain the order in which
-  // clauses are added to it. This is to avoid non-deterministic output.
-  using ClauseSet = ListT<ClauseTy>;
-
-  enum class Presence {
-    All,  // Clause is preesnt on all leaf constructs that allow it.
-    Some, // Clause is present on some, but not on all constructs.
-    None, // Clause is absent on all constructs.
-  };
-
-  template <typename S>
-  ClauseTy makeClause(llvm::omp::Clause clauseId, S &&specific) {
-    return typename ClauseTy::BaseT{clauseId, std::move(specific)};
-  }
-
-  llvm::omp::Directive
-  makeCompound(llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> parts);
-
-  Presence checkPresence(llvm::omp::Clause clauseId);
-
-  // There are clauses that need special handling:
-  // 1. "if": the "directive-name-modifier" on the merged clause may need
-  // to be set appropriately.
-  // 2. "reduction": implies "privateness" of all objects (incompatible
-  // with "shared"); there are rules for merging modifiers
-  void mergeIf();
-  void mergeReduction();
-  void mergeDSA();
-
-  uint32_t version;
-  llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> leafs;
-
-  // clause id -> set of leaf constructs that contain it
-  std::unordered_map<llvm::omp::Clause, llvm::BitVector> clausePresence;
-  // clause id -> set of instances of that clause
-  std::unordered_map<llvm::omp::Clause, ClauseSet> clauseSets;
-};
-
-template <typename ClauseTy>
-ConstructCompositionT(uint32_t, llvm::ArrayRef<DirectiveWithClauses<ClauseTy>>)
-    -> ConstructCompositionT<ClauseTy>;
-
-template <typename C>
-ConstructCompositionT<C>::ConstructCompositionT(
-    uint32_t version, llvm::ArrayRef<DirectiveWithClauses<C>> leafs)
-    : version(version), leafs(leafs) {
-  // Merge the list of constructs with clauses into a compound construct
-  // with a single list of clauses.
-  // The intended use of this function is in splitting compound constructs,
-  // while preserving composite constituent constructs:
-  // Step 1: split compound construct into leaf constructs.
-  // Step 2: identify composite sub-construct, and merge the constituent leafs.
-  //
-  // *** At the moment it assumes that the individual constructs and their
-  // *** clauses are a subset of those created by splitting a valid compound
-  // *** construct.
-  //
-  // 1. Deduplicate clauses
-  //    - exact duplicates: e.g. shared(x) shared(x) -> shared(x)
-  //    - special cases of clauses differing in modifier:
-  //      (a) reduction: inscan + (none|default) = inscan
-  //      (b) reduction: task + (none|default) = task
-  //      (c) combine repeated "if" clauses if possible
-  // 2. Merge DSA clauses: e.g. private(x) private(y) -> private(x, y).
-  // 3. Resolve potential DSA conflicts (typically due to implied clauses).
-
-  if (leafs.empty())
-    return;
-
-  merged.id = makeCompound(leafs);
-
-  // Populate the two maps:
-  for (const auto &[index, leaf] : llvm::enumerate(leafs)) {
-    for (const auto &clause : leaf.clauses) {
-      // Update clausePresence.
-      auto &pset = clausePresence[clause.id];
-      if (pset.size() < leafs.size())
-        pset.resize(leafs.size());
-      pset.set(index);
-      // Update clauseSets.
-      ClauseSet &cset = clauseSets[clause.id];
-      if (!llvm::is_contained(cset, clause))
-        cset.push_back(clause);
-    }
-  }
-
-  mergeIf();
-  mergeReduction();
-  mergeDSA();
-
-  // For the rest of the clauses, just copy them.
-  for (auto &[id, clauses] : clauseSets) {
-    // Skip clauses we've already dealt with.
-    switch (id) {
-    case llvm::omp::Clause::OMPC_if:
-    case llvm::omp::Clause::OMPC_reduction:
-    case llvm::omp::Clause::OMPC_shared:
-    case llvm::omp::Clause::OMPC_private:
-    case llvm::omp::Clause::OMPC_firstprivate:
-    case llvm::omp::Clause::OMPC_lastprivate:
-      continue;
-    default:
-      break;
-    }
-    llvm::append_range(merged.clauses, clauses);
-  }
-}
-
-template <typename C>
-llvm::omp::Directive ConstructCompositionT<C>::makeCompound(
-    llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> parts) {
-  llvm::SmallVector<llvm::omp::Directive> dirIds;
-  llvm::transform(parts, std::back_inserter(dirIds),
-                  [](auto &&dwc) { return dwc.id; });
-
-  return llvm::omp::getCompoundConstruct(dirIds);
-}
-
-template <typename C>
-auto ConstructCompositionT<C>::checkPresence(llvm::omp::Clause clauseId)
-    -> Presence {
-  auto found = clausePresence.find(clauseId);
-  if (found == clausePresence.end())
-    return Presence::None;
-
-  bool OnAll = true, OnNone = true;
-  for (const auto &[index, leaf] : llvm::enumerate(leafs)) {
-    if (!llvm::omp::isAllowedClauseForDirective(leaf.id, clauseId, version))
-      continue;
-
-    if (found->second.test(index))
-      OnNone = false;
-    else
-      OnAll = false;
-  }
-
-  if (OnNone)
-    return Presence::None;
-  if (OnAll)
-    return Presence::All;
-  return Presence::Some;
-}
-
-template <typename C> void ConstructCompositionT<C>::mergeIf() {
-  using IfTy = tomp::clause::IfT<TypeTy, IdTy, ExprTy>;
-  // Deal with the "if" clauses. If it's on all leafs that allow it, then it
-  // will apply to the compound construct. Otherwise it will apply to the
-  // single (assumed) leaf construct.
-  // This assumes that the "if" clauses have the same expression.
-  Presence presence = checkPresence(llvm::omp::Clause::OMPC_if);
-  if (presence == Presence::None)
-    return;
-
-  const ClauseTy &some = *clauseSets[llvm::omp::Clause::OMPC_if].begin();
-  const auto &someIf = std::get<IfTy>(some.u);
-
-  if (presence == Presence::All) {
-    // Create "if" without "directive-name-modifier".
-    merged.clauses.emplace_back(
-        makeClause(llvm::omp::Clause::OMPC_if,
-                   IfTy{{/*DirectiveNameModifier=*/std::nullopt,
-                         /*IfExpression=*/std::get<typename IfTy::IfExpression>(
-                             someIf.t)}}));
-  } else {
-    // Find out where it's present and create "if" with the corresponding
-    // "directive-name-modifier".
-    int Idx = clausePresence[llvm::omp::Clause::OMPC_if].find_first();
-    assert(Idx >= 0);
-    merged.clauses.emplace_back(
-        makeClause(llvm::omp::Clause::OMPC_if,
-                   IfTy{{/*DirectiveNameModifier=*/leafs[Idx].id,
-                         /*IfExpression=*/std::get<typename IfTy::IfExpression>(
-                             someIf.t)}}));
-  }
-}
-
-template <typename C> void ConstructCompositionT<C>::mergeReduction() {
-  Presence presence = checkPresence(llvm::omp::Clause::OMPC_reduction);
-  if (presence == Presence::None)
-    return;
-
-  using ReductionTy = tomp::clause::ReductionT<TypeTy, IdTy, ExprTy>;
-  using ModifierTy = typename ReductionTy::ReductionModifier;
-  using IdentifiersTy = typename ReductionTy::ReductionIdentifiers;
-  using ListTy = typename ReductionTy::List;
-  // There are exceptions on which constructs "reduction" may appear
-  // (specifically "parallel", and "teams"). Assume that if "reduction"
-  // is present, it can be applied to the compound construct.
-
-  // What's left is to see if there are any modifiers present. Again,
-  // assume that there are no conflicting modifiers.
-  // There can be, however, multiple reductions on different objects.
-  auto equal = [](const ClauseTy &red1, const ClauseTy &red2) {
-    // Extract actual reductions.
-    const auto r1 = std::get<ReductionTy>(red1.u);
-    const auto r2 = std::get<ReductionTy>(red2.u);
-    // Compare everything except modifiers.
-    if (std::get<IdentifiersTy>(r1.t) != std::get<IdentifiersTy>(r2.t))
-      return false;
-    if (std::get<ListTy>(r1.t) != std::get<ListTy>(r2.t))
-      return false;
-    return true;
-  };
-
-  auto getModifier = [](const ClauseTy &clause) {
-    const ReductionTy &red = std::get<ReductionTy>(clause.u);
-    return std::get<std::optional<ModifierTy>>(red.t);
-  };
-
-  const ClauseSet &reductions = clauseSets[llvm::omp::Clause::OMPC_reduction];
-  std::unordered_set<const ClauseTy *> visited;
-  while (reductions.size() != visited.size()) {
-    typename ClauseSet::const_iterator first;
-
-    // Find first non-visited reduction.
-    for (first = reductions.begin(); first != reductions.end(); ++first) {
-      if (visited.count(&*first))
-        continue;
-      visited.insert(&*first);
-      break;
-    }
-
-    std::optional<ModifierTy> modifier = getModifier(*first);
-
-    // Visit all other reductions that are "equal" (with respect to the
-    // definition above) to "first". Collect modifiers.
-    for (auto iter = std::next(first); iter != reductions.end(); ++iter) {
-      if (!equal(*first, *iter))
-        continue;
-      visited.insert(&*iter);
-      if (!modifier || *modifier == ModifierTy::Default)
-        modifier = getModifier(*iter);
-    }
-
-    const auto &firstRed = std::get<ReductionTy>(first->u);
-    merged.clauses.emplace_back(makeClause(
-        llvm::omp::Clause::OMPC_reduction,
-        ReductionTy{
-            {/*ReductionModifier=*/modifier,
-             /*ReductionIdentifiers=*/std::get<IdentifiersTy>(firstRed.t),
-             /*List=*/std::get<ListTy>(firstRed.t)}}));
-  }
-}
-
-template <typename C> void ConstructCompositionT<C>::mergeDSA() {
-  using ObjectTy = tomp::type::ObjectT<IdTy, ExprTy>;
-
-  // Resolve data-sharing attributes.
-  enum DSA : int {
-    None = 0,
-    Shared = 1 << 0,
-    Private = 1 << 1,
-    FirstPrivate = 1 << 2,
-    LastPrivate = 1 << 3,
-    LastPrivateConditional = 1 << 4,
-  };
-
-  // Use ordered containers to avoid non-deterministic output.
-  llvm::SmallVector<std::pair<ObjectTy, int>, 8> objectDsa;
-
-  auto getDsa = [&](const ObjectTy &object) -> std::pair<ObjectTy, int> & {
-    auto found = llvm::find_if(objectDsa, [&](std::pair<ObjectTy, int> &p) {
-      return p.first.id() == object.id();
-    });
-    if (found != objectDsa.end())
-      return *found;
-    return objectDsa.emplace_back(object, DSA::None);
-  };
-
-  using SharedTy = tomp::clause::SharedT<TypeTy, IdTy, ExprTy>;
-  using PrivateTy = tomp::clause::PrivateT<TypeTy, IdTy, ExprTy>;
-  using FirstprivateTy = tomp::clause::FirstprivateT<TypeTy, IdTy, ExprTy>;
-  using LastprivateTy = tomp::clause::LastprivateT<TypeTy, IdTy, ExprTy>;
-
-  // Visit clauses that affect DSA.
-  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_shared]) {
-    for (auto &object : std::get<SharedTy>(clause.u).v)
-      getDsa(object).second |= DSA::Shared;
-  }
-
-  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_private]) {
-    for (auto &object : std::get<PrivateTy>(clause.u).v)
-      getDsa(object).second |= DSA::Private;
-  }
-
-  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_firstprivate]) {
-    for (auto &object : std::get<FirstprivateTy>(clause.u).v)
-      getDsa(object).second |= DSA::FirstPrivate;
-  }
-
-  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_lastprivate]) {
-    using ModifierTy = typename LastprivateTy::LastprivateModifier;
-    using ListTy = typename LastprivateTy::List;
-    const auto &lastp = std::get<LastprivateTy>(clause.u);
-    for (auto &object : std::get<ListTy>(lastp.t)) {
-      auto &mod = std::get<std::optional<ModifierTy>>(lastp.t);
-      if (mod && *mod == ModifierTy::Conditional) {
-        getDsa(object).second |= DSA::LastPrivateConditional;
-      } else {
-        getDsa(object).second |= DSA::LastPrivate;
-      }
-    }
-  }
-
-  // Check other privatizing clauses as well, clear "shared" if set.
-  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_in_reduction]) {
-    using InReductionTy = tomp::clause::InReductionT<TypeTy, IdTy, ExprTy>;
-    using ListTy = typename InReductionTy::List;
-    for (auto &object : std::get<ListTy>(std::get<InReductionTy>(clause.u).t))
-      getDsa(object).second &= ~DSA::Shared;
-  }
-  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_linear]) {
-    using LinearTy = tomp::clause::LinearT<TypeTy, IdTy, ExprTy>;
-    using ListTy = typename LinearTy::List;
-    for (auto &object : std::get<ListTy>(std::get<LinearTy>(clause.u).t))
-      getDsa(object).second &= ~DSA::Shared;
-  }
-  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_reduction]) {
-    using ReductionTy = tomp::clause::ReductionT<TypeTy, IdTy, ExprTy>;
-    using ListTy = typename ReductionTy::List;
-    for (auto &object : std::get<ListTy>(std::get<ReductionTy>(clause.u).t))
-      getDsa(object).second &= ~DSA::Shared;
-  }
-  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_task_reduction]) {
-    using TaskReductionTy = tomp::clause::TaskReductionT<TypeTy, IdTy, ExprTy>;
-    using ListTy = typename TaskReductionTy::List;
-    for (auto &object : std::get<ListTy>(std::get<TaskReductionTy>(clause.u).t))
-      getDsa(object).second &= ~DSA::Shared;
-  }
-
-  tomp::ListT<ObjectTy> privateObj, sharedObj, firstpObj, lastpObj, lastpcObj;
-  for (auto &[object, dsa] : objectDsa) {
-    if (dsa &
-        (DSA::FirstPrivate | DSA::LastPrivate | DSA::LastPrivateConditional)) {
-      if (dsa & DSA::FirstPrivate)
-        firstpObj.push_back(object); // no else
-      if (dsa & DSA::LastPrivateConditional)
-        lastpcObj.push_back(object);
-      else if (dsa & DSA::LastPrivate)
-        lastpObj.push_back(object);
-    } else if (dsa & DSA::Private) {
-      privateObj.push_back(object);
-    } else if (dsa & DSA::Shared) {
-      sharedObj.push_back(object);
-    }
-  }
-
-  // Materialize each clause.
-  if (!privateObj.empty()) {
-    merged.clauses.emplace_back(
-        makeClause(llvm::omp::Clause::OMPC_private,
-                   PrivateTy{/*List=*/std::move(privateObj)}));
-  }
-  if (!sharedObj.empty()) {
-    merged.clauses.emplace_back(
-        makeClause(llvm::omp::Clause::OMPC_shared,
-                   SharedTy{/*List=*/std::move(sharedObj)}));
-  }
-  if (!firstpObj.empty()) {
-    merged.clauses.emplace_back(
-        makeClause(llvm::omp::Clause::OMPC_firstprivate,
-                   FirstprivateTy{/*List=*/std::move(firstpObj)}));
-  }
-  if (!lastpObj.empty()) {
-    merged.clauses.emplace_back(
-        makeClause(llvm::omp::Clause::OMPC_lastprivate,
-                   LastprivateTy{{/*LastprivateModifier=*/std::nullopt,
-                                  /*List=*/std::move(lastpObj)}}));
-  }
-  if (!lastpcObj.empty()) {
-    auto conditional = LastprivateTy::LastprivateModifier::Conditional;
-    merged.clauses.emplace_back(
-        makeClause(llvm::omp::Clause::OMPC_lastprivate,
-                   LastprivateTy{{/*LastprivateModifier=*/conditional,
-                                  /*List=*/std::move(lastpcObj)}}));
-  }
-}
-} // namespace tomp
-
-#endif // LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H



More information about the llvm-commits mailing list