[flang-commits] [flang] Adding Changes for invoking Masked Operation (PR #98423)

via flang-commits flang-commits at lists.llvm.org
Wed Jul 10 18:33:29 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Anchu Rajendran S (anchuraj)

<details>
<summary>Changes</summary>

This change adds lowering support for the OpenMP `masked` directive. The `masked` construct was introduced in the OpenMP 5.2 standard and allows a parallel region to be executed only by the threads specified by the programmer. This is achieved with the `filter` clause, which specifies the ID of the thread that is expected to execute the region.

Other related PRs: 
- [Fortran Parsing and Semantic Support](https://github.com/llvm/llvm-project/pull/91432) - Merged
- [MLIR Support](https://github.com/llvm/llvm-project/pull/96022/files) - Merged
- [Lowering Support](https://github.com/llvm/llvm-project/pull/98401) - Under Review

---
Full diff: https://github.com/llvm/llvm-project/pull/98423.diff


5 Files Affected:

- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+10) 
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+2) 
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+28-1) 
- (removed) flang/test/Lower/OpenMP/Todo/masked-directive.f90 (-13) 
- (added) flang/test/Lower/OpenMP/masked.f90 (+25) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index d507e58b164dd..f1e049c15c5c3 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -332,6 +332,16 @@ bool ClauseProcessor::processDistSchedule(
   return false;
 }
 
+bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
+                                    mlir::omp::FilterClauseOps &result) const {
+  if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
+    result.filteredThreadIdVar =
+        fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+    return true;
+  }
+  return false;
+}
+
 bool ClauseProcessor::processFinal(lower::StatementContext &stmtCtx,
                                    mlir::omp::FinalClauseOps &result) const {
   const parser::CharBlock *source = nullptr;
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 43795d5c25399..a8c0f26759565 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -63,6 +63,8 @@ class ClauseProcessor {
   bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
   bool processDistSchedule(lower::StatementContext &stmtCtx,
                            mlir::omp::DistScheduleClauseOps &result) const;
+  bool processFilter(lower::StatementContext &stmtCtx,
+                     mlir::omp::FilterClauseOps &result) const;
   bool processFinal(lower::StatementContext &stmtCtx,
                     mlir::omp::FinalClauseOps &result) const;
   bool processHasDeviceAddr(
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index dffdb834d4e66..63740835a131a 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1067,6 +1067,15 @@ genLoopNestClauses(lower::AbstractConverter &converter,
   clauseOps.loopInclusiveAttr = converter.getFirOpBuilder().getUnitAttr();
 }
 
+static void genMaskedClauses(lower::AbstractConverter &converter,
+                             semantics::SemanticsContext &semaCtx,
+                             lower::StatementContext &stmtCtx,
+                             const List<Clause> &clauses, mlir::Location loc,
+                             mlir::omp::MaskedClauseOps &clauseOps) {
+  ClauseProcessor cp(converter, semaCtx, clauses);
+  cp.processFilter(stmtCtx, clauseOps);
+}
+
 static void
 genOrderedRegionClauses(lower::AbstractConverter &converter,
                         semantics::SemanticsContext &semaCtx,
@@ -1375,6 +1384,21 @@ genLoopNestOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 }
 
+static mlir::omp::MaskedOp
+genMaskedOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+            semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+            mlir::Location loc, const ConstructQueue &queue,
+            ConstructQueue::iterator item) {
+  lower::StatementContext stmtCtx;
+  mlir::omp::MaskedClauseOps clauseOps;
+  genMaskedClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
+
+  return genOpWithBody<mlir::omp::MaskedOp>(
+      OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+                        llvm::omp::Directive::OMPD_masked),
+      queue, item, clauseOps);
+}
+
 static mlir::omp::MasterOp
 genMasterOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
             semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
@@ -2106,9 +2130,11 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                     *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_loop:
-  case llvm::omp::Directive::OMPD_masked:
     TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
     break;
+  case llvm::omp::Directive::OMPD_masked:
+    genMaskedOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    break;
   case llvm::omp::Directive::OMPD_master:
     genMasterOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
@@ -2444,6 +2470,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
         !std::holds_alternative<clause::Copyprivate>(clause.u) &&
         !std::holds_alternative<clause::Default>(clause.u) &&
         !std::holds_alternative<clause::Depend>(clause.u) &&
+        !std::holds_alternative<clause::Filter>(clause.u) &&
         !std::holds_alternative<clause::Final>(clause.u) &&
         !std::holds_alternative<clause::Firstprivate>(clause.u) &&
         !std::holds_alternative<clause::HasDeviceAddr>(clause.u) &&
diff --git a/flang/test/Lower/OpenMP/Todo/masked-directive.f90 b/flang/test/Lower/OpenMP/Todo/masked-directive.f90
deleted file mode 100644
index 77767715af522..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/masked-directive.f90
+++ /dev/null
@@ -1,13 +0,0 @@
-! This test checks lowering of OpenMP masked Directive.
-
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Unhandled directive masked
-subroutine test_masked()
-  integer :: c = 1
-  !$omp masked
-  c = c + 1
-  !$omp end masked
-end subroutine
-
diff --git a/flang/test/Lower/OpenMP/masked.f90 b/flang/test/Lower/OpenMP/masked.f90
new file mode 100644
index 0000000000000..0d67c08d2d9f4
--- /dev/null
+++ b/flang/test/Lower/OpenMP/masked.f90
@@ -0,0 +1,25 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+!CHECK-LABEL: func @_QPomp_masked
+subroutine omp_masked(threadId)
+integer :: threadId
+
+!CHECK: omp.masked  {
+!$omp masked
+
+    !CHECK: fir.call @_QPmasked() {{.*}}: () -> ()
+    call masked()
+
+!CHECK: omp.terminator
+!$omp end masked
+
+!CHECK: omp.masked filter({{.*}})  {
+!$omp masked filter(threadId)
+
+    !CHECK: fir.call @_QPmaskedwithfilter() {{.*}}: () -> ()
+    call maskedWithFilter()
+
+!CHECK: omp.terminator
+!$omp end masked
+end subroutine omp_masked
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/98423


More information about the flang-commits mailing list