[flang-commits] [flang] Adding parsing and semantic check support for omp masked (PR #91432)

Anchu Rajendran S via flang-commits flang-commits at lists.llvm.org
Wed May 8 10:38:39 PDT 2024


https://github.com/anchuraj updated https://github.com/llvm/llvm-project/pull/91432

>From 491113877dc26832855608f2ab1b7568687e16e5 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Tue, 7 May 2024 21:32:25 -0500
Subject: [PATCH 1/2] Adding parsing and semantic check support for omp masked

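The OpenMP 5.2 MASKED directive specifies a code region that is executed
only by the thread selected by the programmer. The FILTER clause of the
directive specifies that thread's id; without it, the primary (master)
thread executes the region. This change adds parsing and unparsing support
for the directive, its FILTER clause, and its combined forms (MASKED
TASKLOOP [SIMD] and PARALLEL MASKED [TASKLOOP [SIMD]]).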
---
 flang/lib/Parser/openmp-parsers.cpp         | 19 +++--
 flang/lib/Parser/unparse.cpp                | 18 ++++
 flang/lib/Semantics/resolve-directives.cpp  | 61 +++++++++++---
 flang/test/Parser/OpenMP/masked-unparse.f90 | 92 +++++++++++++++++++++
 flang/test/Semantics/OpenMP/masked.f90      | 13 +++
 5 files changed, 186 insertions(+), 17 deletions(-)
 create mode 100644 flang/test/Parser/OpenMP/masked-unparse.f90
 create mode 100644 flang/test/Semantics/OpenMP/masked.f90

diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index 48f213794247d..e470bf7856607 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -220,12 +220,11 @@ TYPE_PARSER(construct<OmpAlignedClause>(
 
 // 2.9.5 ORDER ([order-modifier :]concurrent)
 TYPE_PARSER(construct<OmpOrderModifier>(
-    "REPRODUCIBLE" >> pure(OmpOrderModifier::Kind::Reproducible)) ||
+                "REPRODUCIBLE" >> pure(OmpOrderModifier::Kind::Reproducible)) ||
     construct<OmpOrderModifier>(
-    "UNCONSTRAINED" >> pure(OmpOrderModifier::Kind::Unconstrained)))
+        "UNCONSTRAINED" >> pure(OmpOrderModifier::Kind::Unconstrained)))
 
-TYPE_PARSER(construct<OmpOrderClause>(
-    maybe(Parser<OmpOrderModifier>{} / ":"),
+TYPE_PARSER(construct<OmpOrderClause>(maybe(Parser<OmpOrderModifier>{} / ":"),
     "CONCURRENT" >> pure(OmpOrderClause::Type::Concurrent)))
 
 TYPE_PARSER(
@@ -266,6 +265,8 @@ TYPE_PARSER(
         construct<OmpClause>(construct<OmpClause::DynamicAllocators>()) ||
     "ENTER" >> construct<OmpClause>(construct<OmpClause::Enter>(
                    parenthesized(Parser<OmpObjectList>{}))) ||
+    "FILTER" >> construct<OmpClause>(construct<OmpClause::Filter>(
+                    parenthesized(scalarIntExpr))) ||
     "FINAL" >> construct<OmpClause>(construct<OmpClause::Final>(
                    parenthesized(scalarLogicalExpr))) ||
     "FULL" >> construct<OmpClause>(construct<OmpClause::Full>()) ||
@@ -486,9 +487,17 @@ TYPE_PARSER(
     endOfLine)
 
 // Directives enclosing structured-block
-TYPE_PARSER(construct<OmpBlockDirective>(first(
+TYPE_PARSER(construct<OmpBlockDirective>(first("MASKED TASKLOOP SIMD" >>
+        pure(llvm::omp::Directive::OMPD_masked_taskloop_simd),
+    "MASKED TASKLOOP" >> pure(llvm::omp::Directive::OMPD_masked_taskloop),
+    "MASKED" >> pure(llvm::omp::Directive::OMPD_masked),
     "MASTER" >> pure(llvm::omp::Directive::OMPD_master),
     "ORDERED" >> pure(llvm::omp::Directive::OMPD_ordered),
+    "PARALLEL MASKED TASKLOOP SIMD" >>
+        pure(llvm::omp::Directive::OMPD_parallel_masked_taskloop_simd),
+    "PARALLEL MASKED TASKLOOP" >>
+        pure(llvm::omp::Directive::OMPD_parallel_masked_taskloop),
+    "PARALLEL MASKED" >> pure(llvm::omp::Directive::OMPD_parallel_masked),
     "PARALLEL WORKSHARE" >> pure(llvm::omp::Directive::OMPD_parallel_workshare),
     "PARALLEL" >> pure(llvm::omp::Directive::OMPD_parallel),
     "SINGLE" >> pure(llvm::omp::Directive::OMPD_single),
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index 3398b395f198f..7b6b083cf3e62 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -2283,12 +2283,30 @@ class UnparseVisitor {
   }
   void Unparse(const OmpBlockDirective &x) {
     switch (x.v) {
+    case llvm::omp::Directive::OMPD_masked_taskloop_simd:
+      Word("MASKED TASKLOOP SIMD");
+      break;
+    case llvm::omp::Directive::OMPD_masked_taskloop:
+      Word("MASKED TASKLOOP");
+      break;
+    case llvm::omp::Directive::OMPD_masked:
+      Word("MASKED");
+      break;
     case llvm::omp::Directive::OMPD_master:
       Word("MASTER");
       break;
     case llvm::omp::Directive::OMPD_ordered:
       Word("ORDERED ");
       break;
+    case llvm::omp::Directive::OMPD_parallel_masked_taskloop_simd:
+      Word("PARALLEL MASKED TASKLOOP SIMD");
+      break;
+    case llvm::omp::Directive::OMPD_parallel_masked_taskloop:
+      Word("PARALLEL MASKED TASKLOOP");
+      break;
+    case llvm::omp::Directive::OMPD_parallel_masked:
+      Word("PARALLEL MASKED");
+      break;
     case llvm::omp::Directive::OMPD_parallel_workshare:
       Word("PARALLEL WORKSHARE ");
       break;
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 2add2056f658d..b11fa3174277d 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -19,8 +19,10 @@
 #include "flang/Parser/parse-tree.h"
 #include "flang/Parser/tools.h"
 #include "flang/Semantics/expression.h"
+#include <cstdint>
 #include <list>
 #include <map>
+#include <optional>
 #include <sstream>
 
 template <typename T>
@@ -50,6 +52,7 @@ template <typename T> class DirectiveAttributeVisitor {
     Symbol::Flag defaultDSA{Symbol::Flag::AccShared}; // TODOACC
     std::map<const Symbol *, Symbol::Flag> objectWithDSA;
     bool withinConstruct{false};
+    std::optional<int64_t> maskedTId;
     std::int64_t associatedLoopLevel{0};
   };
 
@@ -90,6 +93,9 @@ template <typename T> class DirectiveAttributeVisitor {
   void SetContextAssociatedLoopLevel(std::int64_t level) {
     GetContext().associatedLoopLevel = level;
   }
+  void SetMaskedTId(std::optional<std::int64_t> tid) {
+    GetContext().maskedTId = tid;
+  }
   Symbol &MakeAssocSymbol(const SourceName &name, Symbol &prev, Scope &scope) {
     const auto pair{scope.try_emplace(name, Attrs{}, HostAssocDetails{prev})};
     return *pair.first->second;
@@ -646,6 +652,7 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
 
 private:
   std::int64_t GetAssociatedLoopLevelFromClauses(const parser::OmpClauseList &);
+  std::optional<int64_t> GetMaskedTId(const parser::OmpClauseList &);
 
   Symbol::Flags dataSharingAttributeFlags{Symbol::Flag::OmpShared,
       Symbol::Flag::OmpPrivate, Symbol::Flag::OmpFirstPrivate,
@@ -1105,18 +1112,18 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCCombinedConstruct &x) {
 static bool IsLastNameArray(const parser::Designator &designator) {
   const auto &name{GetLastName(designator)};
   const evaluate::DataRef dataRef{*(name.symbol)};
-  return common::visit(
-      common::visitors{
-          [](const evaluate::SymbolRef &ref) {
-            return ref->Rank() > 0 ||
-                ref->GetType()->category() == DeclTypeSpec::Numeric;
-          },
-          [](const evaluate::ArrayRef &aref) {
-            return aref.base().IsSymbol() ||
-                aref.base().GetComponent().base().Rank() == 0;
-          },
-          [](const auto &) { return false; },
-      },
+  return common::visit(common::visitors{
+                           [](const evaluate::SymbolRef &ref) {
+                             return ref->Rank() > 0 ||
+                                 ref->GetType()->category() ==
+                                 DeclTypeSpec::Numeric;
+                           },
+                           [](const evaluate::ArrayRef &aref) {
+                             return aref.base().IsSymbol() ||
+                                 aref.base().GetComponent().base().Rank() == 0;
+                           },
+                           [](const auto &) { return false; },
+                       },
       dataRef.u);
 }
 
@@ -1498,11 +1505,35 @@ void AccAttributeVisitor::CheckMultipleAppearances(
     AddDataSharingAttributeObject(*target);
   }
 }
+
+// Thread id from FILTER; 0 (primary) if absent, nullopt if not constant.
+std::optional<std::int64_t> OmpAttributeVisitor::GetMaskedTId(
+    const parser::OmpClauseList &clauseList) {
+  for (const auto &clause : clauseList.v) {
+    if (const auto *filterClause{
+            std::get_if<parser::OmpClause::Filter>(&clause.u)}) {
+      // Not a compile-time constant: the masked thread id is unknown here.
+      return EvaluateInt64(context_, filterClause->v);
+    }
+  }
+  // No FILTER clause: the masked thread is the primary (master) thread.
+  return 0;
+}
 
 bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) {
   const auto &beginBlockDir{std::get<parser::OmpBeginBlockDirective>(x.t)};
   const auto &beginDir{std::get<parser::OmpBlockDirective>(beginBlockDir.t)};
+  const auto &clauseList{std::get<parser::OmpClauseList>(beginBlockDir.t)};
   switch (beginDir.v) {
+  case llvm::omp::Directive::OMPD_masked_taskloop_simd:
+  case llvm::omp::Directive::OMPD_masked_taskloop:
+  case llvm::omp::Directive::OMPD_masked:
+  case llvm::omp::Directive::OMPD_parallel_masked_taskloop_simd:
+  case llvm::omp::Directive::OMPD_parallel_masked_taskloop:
+  case llvm::omp::Directive::OMPD_parallel_masked:
+    PushContext(beginDir.source, beginDir.v);
+    SetMaskedTId(GetMaskedTId(clauseList));
+    break;
   case llvm::omp::Directive::OMPD_master:
   case llvm::omp::Directive::OMPD_ordered:
   case llvm::omp::Directive::OMPD_parallel:
@@ -1532,6 +1563,12 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) {
   const auto &beginBlockDir{std::get<parser::OmpBeginBlockDirective>(x.t)};
   const auto &beginDir{std::get<parser::OmpBlockDirective>(beginBlockDir.t)};
   switch (beginDir.v) {
+  case llvm::omp::Directive::OMPD_masked_taskloop_simd:
+  case llvm::omp::Directive::OMPD_masked_taskloop:
+  case llvm::omp::Directive::OMPD_masked:
+  case llvm::omp::Directive::OMPD_parallel_masked_taskloop_simd:
+  case llvm::omp::Directive::OMPD_parallel_masked_taskloop:
+  case llvm::omp::Directive::OMPD_parallel_masked:
   case llvm::omp::Directive::OMPD_parallel:
   case llvm::omp::Directive::OMPD_single:
   case llvm::omp::Directive::OMPD_target:
diff --git a/flang/test/Parser/OpenMP/masked-unparse.f90 b/flang/test/Parser/OpenMP/masked-unparse.f90
new file mode 100644
index 0000000000000..96ccc3be238c0
--- /dev/null
+++ b/flang/test/Parser/OpenMP/masked-unparse.f90
@@ -0,0 +1,92 @@
+! RUN: %flang_fc1 -fdebug-unparse -fopenmp %s | FileCheck --ignore-case %s
+! RUN: %flang_fc1 -fdebug-dump-parse-tree -fopenmp %s | FileCheck --check-prefix="PARSE-TREE" %s
+
+! Check for parsing of masked directive with filter clause.
+
+
+subroutine test_masked()
+  integer :: c = 1
+  !PARSE-TREE: OmpBeginBlockDirective
+  !PARSE-TREE-NEXT: OmpBlockDirective -> llvm::omp::Directive = masked
+  !CHECK: !$omp masked
+  !$omp masked
+  c = c + 1
+  !$omp end masked
+  !PARSE-TREE: OmpBeginBlockDirective
+  !PARSE-TREE-NEXT: OmpBlockDirective -> llvm::omp::Directive = masked
+  !PARSE-TREE-NEXT: OmpClauseList -> OmpClause -> Filter -> Scalar -> Integer -> Expr = '1_4'
+  !PARSE-TREE-NEXT: LiteralConstant -> IntLiteralConstant = '1'
+  !CHECK: !$omp masked filter(1_4)
+  !$omp masked filter(1)
+  c = c + 2
+  !$omp end masked
+end subroutine
+
+subroutine test_masked_taskloop_simd()
+  integer :: i, j = 1
+  !PARSE-TREE: OmpBeginBlockDirective
+  !PARSE-TREE-NEXT: OmpBlockDirective -> llvm::omp::Directive = masked taskloop simd
+  !CHECK: !$omp masked taskloop simd
+  !$omp masked taskloop simd
+  do i=1,10
+   j = j + 1
+  end do
+  !$omp end masked taskloop simd
+end subroutine
+
+subroutine test_masked_taskloop
+  integer :: i, j = 1
+  !PARSE-TREE: OmpBeginBlockDirective
+  !PARSE-TREE-NEXT: OmpBlockDirective -> llvm::omp::Directive = masked taskloop
+  !PARSE-TREE-NEXT: OmpClauseList -> OmpClause -> Filter -> Scalar -> Integer -> Expr = '2_4'
+  !PARSE-TREE-NEXT: LiteralConstant -> IntLiteralConstant = '2'
+  !CHECK: !$omp masked taskloop filter(2_4)
+  !$omp masked taskloop filter(2)
+  do i=1,10
+   j = j + 1
+  end do
+  !$omp end masked taskloop
+end subroutine
+
+subroutine test_parallel_masked
+  integer, parameter :: i = 1, j = 1
+  integer :: c = 2
+  !PARSE-TREE: OmpBeginBlockDirective
+  !PARSE-TREE-NEXT: OmpBlockDirective -> llvm::omp::Directive = parallel masked
+  !PARSE-TREE-NEXT: OmpClauseList -> OmpClause -> Filter -> Scalar -> Integer -> Expr = '2_4'
+  !PARSE-TREE-NEXT: Add
+  !PARSE-TREE-NEXT: Expr = '1_4'
+  !PARSE-TREE-NEXT: Designator -> DataRef -> Name = 'i'
+  !PARSE-TREE-NEXT: Expr = '1_4'
+  !PARSE-TREE-NEXT: Designator -> DataRef -> Name = 'j'
+  !CHECK: !$omp parallel masked filter(2_4)
+  !$omp parallel masked filter(i+j)
+  c = c + 2
+  !$omp end parallel masked
+end subroutine
+
+subroutine test_parallel_masked_taskloop_simd
+  integer :: i, j = 1
+  !PARSE-TREE: OmpBeginBlockDirective
+  !PARSE-TREE-NEXT: OmpBlockDirective -> llvm::omp::Directive = parallel masked taskloop simd
+  !CHECK: !$omp parallel masked taskloop simd
+  !$omp parallel masked taskloop simd
+  do i=1,10
+   j = j + 1
+  end do
+  !$omp end parallel masked taskloop simd
+end subroutine
+
+subroutine test_parallel_masked_taskloop
+  integer :: i, j = 1
+  !PARSE-TREE: OmpBeginBlockDirective
+  !PARSE-TREE-NEXT: OmpBlockDirective -> llvm::omp::Directive = parallel masked taskloop
+  !PARSE-TREE-NEXT: OmpClauseList -> OmpClause -> Filter -> Scalar -> Integer -> Expr = '2_4'
+  !PARSE-TREE-NEXT: LiteralConstant -> IntLiteralConstant = '2'
+  !CHECK: !$omp parallel masked taskloop filter(2_4)
+  !$omp parallel masked taskloop filter(2)
+  do i=1,10
+   j = j + 1
+  end do
+  !$omp end parallel masked taskloop
+end subroutine
diff --git a/flang/test/Semantics/OpenMP/masked.f90 b/flang/test/Semantics/OpenMP/masked.f90
new file mode 100644
index 0000000000000..36e22ee0be8c5
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/masked.f90
@@ -0,0 +1,13 @@
+! RUN: %python %S/../test_errors.py %s %flang_fc1 -fopenmp
+
+subroutine test_masked()
+  integer :: c = 1
+  !ERROR: At most one FILTER clause can appear on the MASKED directive
+  !$omp masked filter(1) filter(2)
+  c = c + 1
+  !$omp end masked
+  !ERROR: NOWAIT clause is not allowed on the MASKED directive
+  !$omp masked nowait
+  c = c + 2
+  !$omp end masked
+end subroutine

>From a069fdbbdab5a005225d4d9595bd53910c14c310 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Wed, 8 May 2024 12:37:26 -0500
Subject: [PATCH 2/2] R2: Removing the unrelated formatting changes

---
 flang/lib/Parser/openmp-parsers.cpp        |  7 ++++---
 flang/lib/Semantics/resolve-directives.cpp | 24 +++++++++++-----------
 2 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index e470bf7856607..b1900993f1f42 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -220,11 +220,12 @@ TYPE_PARSER(construct<OmpAlignedClause>(
 
 // 2.9.5 ORDER ([order-modifier :]concurrent)
 TYPE_PARSER(construct<OmpOrderModifier>(
-                "REPRODUCIBLE" >> pure(OmpOrderModifier::Kind::Reproducible)) ||
+    "REPRODUCIBLE" >> pure(OmpOrderModifier::Kind::Reproducible)) ||
     construct<OmpOrderModifier>(
-        "UNCONSTRAINED" >> pure(OmpOrderModifier::Kind::Unconstrained)))
+    "UNCONSTRAINED" >> pure(OmpOrderModifier::Kind::Unconstrained)))
 
-TYPE_PARSER(construct<OmpOrderClause>(maybe(Parser<OmpOrderModifier>{} / ":"),
+TYPE_PARSER(construct<OmpOrderClause>(
+    maybe(Parser<OmpOrderModifier>{} / ":"),
     "CONCURRENT" >> pure(OmpOrderClause::Type::Concurrent)))
 
 TYPE_PARSER(
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index b11fa3174277d..9301b1247795d 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1112,18 +1112,18 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCCombinedConstruct &x) {
 static bool IsLastNameArray(const parser::Designator &designator) {
   const auto &name{GetLastName(designator)};
   const evaluate::DataRef dataRef{*(name.symbol)};
-  return common::visit(common::visitors{
-                           [](const evaluate::SymbolRef &ref) {
-                             return ref->Rank() > 0 ||
-                                 ref->GetType()->category() ==
-                                 DeclTypeSpec::Numeric;
-                           },
-                           [](const evaluate::ArrayRef &aref) {
-                             return aref.base().IsSymbol() ||
-                                 aref.base().GetComponent().base().Rank() == 0;
-                           },
-                           [](const auto &) { return false; },
-                       },
+  return common::visit(
+      common::visitors{
+          [](const evaluate::SymbolRef &ref) {
+            return ref->Rank() > 0 ||
+                ref->GetType()->category() == DeclTypeSpec::Numeric;
+          },
+          [](const evaluate::ArrayRef &aref) {
+            return aref.base().IsSymbol() ||
+                aref.base().GetComponent().base().Rank() == 0;
+          },
+          [](const auto &) { return false; },
+      },
       dataRef.u);
 }
 



More information about the flang-commits mailing list