[flang-commits] [flang] 7c9404c - [flang][OpenMP] Add frontend support for ompx_bare clause (#111106)

via flang-commits flang-commits at lists.llvm.org
Fri Dec 13 04:44:52 PST 2024


Author: Ivan R. Ivanov
Date: 2024-12-13T21:44:43+09:00
New Revision: 7c9404c279cfa13e24a043e6357cc85bd12f55f1

URL: https://github.com/llvm/llvm-project/commit/7c9404c279cfa13e24a043e6357cc85bd12f55f1
DIFF: https://github.com/llvm/llvm-project/commit/7c9404c279cfa13e24a043e6357cc85bd12f55f1.diff

LOG: [flang][OpenMP] Add frontend support for ompx_bare clause (#111106)

Added: 
    flang/test/Lower/OpenMP/KernelLanguage/bare-clause.f90
    flang/test/Semantics/OpenMP/ompx-bare.f90

Modified: 
    clang/lib/Parse/ParseOpenMP.cpp
    flang/lib/Lower/OpenMP/ClauseProcessor.cpp
    flang/lib/Lower/OpenMP/ClauseProcessor.h
    flang/lib/Lower/OpenMP/OpenMP.cpp
    flang/lib/Parser/openmp-parsers.cpp
    flang/lib/Semantics/check-omp-structure.cpp
    llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h
    llvm/include/llvm/Frontend/OpenMP/OMP.td
    mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index b91928063169ee..b4e973bc84a7b0 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -3474,6 +3474,16 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
     Clause = ParseOpenMPOMPXAttributesClause(WrongDirective);
     break;
   case OMPC_ompx_bare:
+    if (DKind == llvm::omp::Directive::OMPD_target) {
+      // Flang splits the combined directives which requires OMPD_target to be
+      // marked as accepting the `ompx_bare` clause in `OMP.td`. Thus, we need
+      // to explicitly check whether this clause is applied to an `omp target`
+      // without `teams` and emit an error.
+      Diag(Tok, diag::err_omp_unexpected_clause)
+          << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind);
+      ErrorFound = true;
+      WrongDirective = true;
+    }
     if (WrongDirective)
       Diag(Tok, diag::note_ompx_bare_clause)
           << getOpenMPClauseName(CKind) << "target teams";

diff  --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 48c559a78b9bc4..3c9831120351ee 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -220,6 +220,10 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
 // ClauseProcessor unique clauses
 //===----------------------------------------------------------------------===//
 
+bool ClauseProcessor::processBare(mlir::omp::BareClauseOps &result) const {
+  return markClauseOccurrence<omp::clause::OmpxBare>(result.bare);
+}
+
 bool ClauseProcessor::processBind(mlir::omp::BindClauseOps &result) const {
   if (auto *clause = findUniqueClause<omp::clause::Bind>()) {
     fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

diff  --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index e0fe917c50e8f8..3942c54e6e935d 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -53,6 +53,7 @@ class ClauseProcessor {
       : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
 
   // 'Unique' clauses: They can appear at most once in the clause list.
+  bool processBare(mlir::omp::BareClauseOps &result) const;
   bool processBind(mlir::omp::BindClauseOps &result) const;
   bool
   processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,

diff  --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c167d347b43159..f30d2687682c8d 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1184,6 +1184,7 @@ static void genTargetClauses(
     llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
     llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
   ClauseProcessor cp(converter, semaCtx, clauses);
+  cp.processBare(clauseOps);
   cp.processDepend(clauseOps);
   cp.processDevice(stmtCtx, clauseOps);
   cp.processHasDeviceAddr(clauseOps, hasDeviceAddrSyms);
@@ -2860,6 +2861,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
         !std::holds_alternative<clause::Nowait>(clause.u) &&
         !std::holds_alternative<clause::NumTeams>(clause.u) &&
         !std::holds_alternative<clause::NumThreads>(clause.u) &&
+        !std::holds_alternative<clause::OmpxBare>(clause.u) &&
         !std::holds_alternative<clause::Priority>(clause.u) &&
         !std::holds_alternative<clause::Private>(clause.u) &&
         !std::holds_alternative<clause::ProcBind>(clause.u) &&

diff  --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index 7d10de8c60977f..791fee3507b441 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -657,6 +657,7 @@ TYPE_PARSER(
                        parenthesized(scalarIntExpr))) ||
     "NUM_THREADS" >> construct<OmpClause>(construct<OmpClause::NumThreads>(
                          parenthesized(scalarIntExpr))) ||
+    "OMPX_BARE" >> construct<OmpClause>(construct<OmpClause::OmpxBare>()) ||
     "ORDER" >> construct<OmpClause>(construct<OmpClause::Order>(
                    parenthesized(Parser<OmpOrderClause>{}))) ||
     "ORDERED" >> construct<OmpClause>(construct<OmpClause::Ordered>(

diff  --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index d63f7a5aea3ab6..3b2033de45546f 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -2867,7 +2867,6 @@ CHECK_SIMPLE_CLAUSE(Align, OMPC_align)
 CHECK_SIMPLE_CLAUSE(Compare, OMPC_compare)
 CHECK_SIMPLE_CLAUSE(CancellationConstructType, OMPC_cancellation_construct_type)
 CHECK_SIMPLE_CLAUSE(OmpxAttribute, OMPC_ompx_attribute)
-CHECK_SIMPLE_CLAUSE(OmpxBare, OMPC_ompx_bare)
 CHECK_SIMPLE_CLAUSE(Weak, OMPC_weak)
 
 CHECK_REQ_SCALAR_INT_CLAUSE(NumTeams, OMPC_num_teams)
@@ -4395,6 +4394,17 @@ void OmpStructureChecker::Enter(const parser::OmpClause::To &x) {
   }
 }
 
+void OmpStructureChecker::Enter(const parser::OmpClause::OmpxBare &x) {
+  // Don't call CheckAllowedClause, because it allows "ompx_bare" on
+  // a non-combined "target" directive (for reasons of splitting combined
+  // directives). In source code it's only allowed on "target teams".
+  if (GetContext().directive != llvm::omp::Directive::OMPD_target_teams) {
+    context_.Say(GetContext().clauseSource,
+        "%s clause is only allowed on combined TARGET TEAMS"_err_en_US,
+        parser::ToUpperCaseLetters(getClauseName(llvm::omp::OMPC_ompx_bare)));
+  }
+}
+
 llvm::StringRef OmpStructureChecker::getClauseName(llvm::omp::Clause clause) {
   return llvm::omp::getOpenMPClauseName(clause);
 }

diff  --git a/flang/test/Lower/OpenMP/KernelLanguage/bare-clause.f90 b/flang/test/Lower/OpenMP/KernelLanguage/bare-clause.f90
new file mode 100644
index 00000000000000..1445c4fa225d2e
--- /dev/null
+++ b/flang/test/Lower/OpenMP/KernelLanguage/bare-clause.f90
@@ -0,0 +1,10 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %s -o - | FileCheck %s
+
+program test
+    integer :: tmp
+    !$omp target teams ompx_bare num_teams(42) thread_limit(43)
+    tmp = 1
+    !$omp end target teams
+end program
+
+! CHECK: omp.target ompx_bare

diff  --git a/flang/test/Semantics/OpenMP/ompx-bare.f90 b/flang/test/Semantics/OpenMP/ompx-bare.f90
new file mode 100644
index 00000000000000..21a603e9a826bf
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/ompx-bare.f90
@@ -0,0 +1,30 @@
+!RUN: %python %S/../test_errors.py %s %flang -fopenmp -fopenmp-version=51
+
+subroutine test1
+!ERROR: OMPX_BARE clause is only allowed on combined TARGET TEAMS
+  !$omp target ompx_bare
+  !$omp end target
+end
+
+subroutine test2
+  !$omp target
+!ERROR: OMPX_BARE clause is only allowed on combined TARGET TEAMS
+  !$omp teams ompx_bare
+  !$omp end teams
+  !$omp end target
+end
+
+subroutine test3
+  integer i
+!ERROR: OMPX_BARE clause is only allowed on combined TARGET TEAMS
+  !$omp target teams distribute ompx_bare
+  do i = 0, 10
+  end do
+  !$omp end target teams distribute
+end
+
+subroutine test4
+!No errors
+  !$omp target teams ompx_bare
+  !$omp end target teams
+end

diff  --git a/llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h b/llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h
index 4bdfa1cf4c1490..20fb581ee631a6 100644
--- a/llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h
+++ b/llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h
@@ -239,6 +239,8 @@ struct ConstructDecompositionT {
   bool
   applyClause(const tomp::clause::OmpxAttributeT<TypeTy, IdTy, ExprTy> &clause,
               const ClauseTy *);
+  bool applyClause(const tomp::clause::OmpxBareT<TypeTy, IdTy, ExprTy> &clause,
+                   const ClauseTy *);
 
   uint32_t version;
   llvm::omp::Directive construct;
@@ -1103,6 +1105,13 @@ bool ConstructDecompositionT<C, H>::applyClause(
   return applyToOutermost(node);
 }
 
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyClause(
+    const tomp::clause::OmpxBareT<TypeTy, IdTy, ExprTy> &clause,
+    const ClauseTy *node) {
+  return applyToOutermost(node);
+}
+
 template <typename C, typename H>
 bool ConstructDecompositionT<C, H>::applyClause(
     const tomp::clause::OmpxAttributeT<TypeTy, IdTy, ExprTy> &clause,

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 4f23a6792d6344..6d04ee21ab508a 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -1018,6 +1018,7 @@ def OMP_Target : Directive<"target"> {
     VersionedClause<OMPC_Device>,
     VersionedClause<OMPC_If>,
     VersionedClause<OMPC_NoWait>,
+    VersionedClause<OMPC_OMPX_Bare>,
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
     VersionedClause<OMPC_ThreadLimit, 51>,
   ];

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 077d6602628aa0..98d2e80ed2d81d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -84,6 +84,32 @@ class OpenMP_AllocateClauseSkip<
 
 def OpenMP_AllocateClause : OpenMP_AllocateClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// LLVM OpenMP extension `ompx_bare` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_BareClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+  let arguments = (ins
+    UnitAttr:$bare
+  );
+
+  let optAssemblyFormat = [{
+    `ompx_bare` $bare
+  }];
+
+  let description = [{
+    `ompx_bare` allows `omp target teams` to be executed on a GPU with an
+    explicit number of teams and threads. This clause also allows the teams and
+    threads sizes to have up to 3 dimensions.
+  }];
+}
+
+def OpenMP_BareClause : OpenMP_BareClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [16.1, 16.2] `cancel-directive-name` clause set
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 89c7ed46ff5004..65aa260a80cc01 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1223,10 +1223,11 @@ def TargetOp : OpenMP_Op<"target", traits = [
     OutlineableOpenMPOpInterface
   ], clauses = [
     // TODO: Complete clause list (defaultmap, uses_allocators).
-    OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
-    OpenMP_HasDeviceAddrClause, OpenMP_IfClause, OpenMP_InReductionClause,
-    OpenMP_IsDevicePtrClause, OpenMP_MapClauseSkip<assemblyFormat = true>,
-    OpenMP_NowaitClause, OpenMP_PrivateClause, OpenMP_ThreadLimitClause
+    OpenMP_AllocateClause, OpenMP_BareClause, OpenMP_DependClause,
+    OpenMP_DeviceClause, OpenMP_HasDeviceAddrClause, OpenMP_IfClause,
+    OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
+    OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
+    OpenMP_PrivateClause, OpenMP_ThreadLimitClause,
   ], singleRegion = true> {
   let summary = "target construct";
   let description = [{

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e20530be07b2f9..3d62b3218869ea 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1709,13 +1709,13 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
   // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
   // inReductionByref, inReductionSyms.
   TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
-                  makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
-                  clauses.device, clauses.hasDeviceAddrVars, clauses.ifExpr,
-                  /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
-                  /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
-                  clauses.mapVars, clauses.nowait, clauses.privateVars,
-                  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
-                  /*private_maps=*/nullptr);
+                  clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
+                  clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
+                  clauses.ifExpr, /*in_reduction_vars=*/{},
+                  /*in_reduction_byref=*/nullptr, /*in_reduction_syms=*/nullptr,
+                  clauses.isDevicePtrVars, clauses.mapVars, clauses.nowait,
+                  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
+                  clauses.threadLimit, /*private_maps=*/nullptr);
 }
 
 LogicalResult TargetOp::verify() {

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index ff8606ed6b3f9e..060113c4123241 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -158,6 +158,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
       result = todo("allocate");
   };
+  auto checkBare = [&todo](auto op, LogicalResult &result) {
+    if (op.getBare())
+      result = todo("ompx_bare");
+  };
   auto checkDepend = [&todo](auto op, LogicalResult &result) {
     if (!op.getDependVars().empty() || op.getDependKinds())
       result = todo("depend");
@@ -283,6 +287,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
           [&](auto op) { checkDepend(op, result); })
       .Case([&](omp::TargetOp op) {
         checkAllocate(op, result);
+        checkBare(op, result);
         checkDevice(op, result);
         checkHasDeviceAddr(op, result);
         checkIf(op, result);


        


More information about the flang-commits mailing list