[flang-commits] [flang] [mlir][OpenMP] - Transform target offloading directives with dependencies during PFT to MLIR conversion (PR #85130)

Pranav Bhandarkar via flang-commits flang-commits at lists.llvm.org
Wed Mar 13 13:58:24 PDT 2024


https://github.com/bhandarkar-pranav created https://github.com/llvm/llvm-project/pull/85130

This PR changes PFT to MLIR lowering for the following directives when they have the `depend` clause on them.
```
!$ omp target depend(..)
!$ omp target enter data depend(..)
!$ omp target update depend(..)
!$ omp target exit data depend(..)
```
With this PR, lowering now involves the creation of an `omp.task` operation that encloses the `omp.target` operation that is otherwise generated for the target construct. In addition, the depend clause from the target is moved to the enclosing new `omp.task`. The new `omp.task` is a mergeable task.

```
!$ omp target map(..) depend(in:a)
   b = a
```
is transformed to the following MLIR

```
omp.task mergeable depend(in:a) {
  omp.target map(..) {
    //MLIR for b = a;
  }
  omp.terminator
}
```
>[!NOTE]
>This PR is an alternative to https://github.com/llvm/llvm-project/pull/83966. Unlike that PR, it does not require an entire pass over the IR, and the changes required for translating the `depend` clause on target constructs into LLVM IR are entirely contained in this one change. By comparison, the approach of https://github.com/llvm/llvm-project/pull/83966 would need that pass to be added to the pass pipeline in flang.
Further, all `depend`-clause-related operands and attributes of `omp.target` can now be removed from MLIR (patch coming soon after this PR is merged).

>From 1da3208d9ea66cb8b475ddefb0db759a802363a7 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Wed, 13 Mar 2024 00:26:53 -0500
Subject: [PATCH] [mlir][OpenMP] - Transform target offloading directives with
 dependencies during PFT to MLIR conversion

This patch changes PFT to MLIR lowering for the following directives
when they have the `depend` clause on them.

!$ omp target depend(..)
!$ omp target enter data depend(..)
!$ omp target update depend(..)
!$ omp target exit data depend(..)

Lowering now involves the creation of an `omp.task` operation that encloses
the `omp.target` operation that is otherwise generated for the target construct.
In addition, the depend clause from the target is moved to the enclosing
new `omp.task`. The new `omp.task` is a mergeable task.

```
!$ omp target map(..) depend(in:a)
   b = a
```
is transformed to the following MLIR

```
omp.task mergeable depend(in:a) {
  omp.target map(..) {
    //MLIR for b = a;
  }
  omp.terminator
}
```
---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 37 ++++++++++++++++++++++
 flang/lib/Lower/OpenMP/ClauseProcessor.h   |  7 ++++
 flang/lib/Lower/OpenMP/OpenMP.cpp          | 20 ++++--------
 flang/test/Lower/OpenMP/target.f90         | 13 +++++---
 4 files changed, 59 insertions(+), 18 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index a41f8312a28c9e..8961268fd63a17 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -808,6 +808,43 @@ bool ClauseProcessor::processDepend(
       });
 }
 
+bool ClauseProcessor::processTargetDepend(
+    mlir::Location currentLocation) const {
+  llvm::SmallVector<mlir::Attribute> dependTypeOperands;
+  llvm::SmallVector<mlir::Value> dependOperands;
+
+  processDepend(dependTypeOperands, dependOperands);
+  if (dependTypeOperands.empty())
+    return false;
+
+  // A non-empty 'dependTypeOperands' means the depend clause was
+  // used, so we create an omp.task operation that will enclose
+  // the operation generated for the target construct. This new
+  // omp.task is a mergeable task that carries the depend clause;
+  // the depend clause is dropped from the operation generated
+  // for the original target construct.
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  // Create the new omp.task op. As per the OpenMP Spec a target directive
+  // creates a mergeable 'target task'. The insertion point is left at the
+  // start of the task body (before its terminator) for the enclosed op.
+  mlir::omp::TaskOp taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
+      currentLocation, /*if_expr*/ mlir::Value(),
+      /*final_expr*/ mlir::Value(), /*untied*/ mlir::UnitAttr(),
+      /*mergeable*/ firOpBuilder.getUnitAttr(),
+      /*in_reduction_vars*/ mlir::ValueRange(), /*in_reductions*/ nullptr,
+      /*priority*/ mlir::Value(),
+      mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                           dependTypeOperands),
+      dependOperands, /*allocate_vars*/ mlir::ValueRange(),
+      /*allocators_vars*/ mlir::ValueRange());
+
+  firOpBuilder.createBlock(&taskOp.getRegion());
+  firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
+  firOpBuilder.setInsertionPointToStart(&taskOp.getRegion().front());
+  return true;
+}
+
 bool ClauseProcessor::processIf(
     Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
     mlir::Value &result) const {
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 11aff0be25053d..140c7a8d5cefe8 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -101,6 +101,13 @@ class ClauseProcessor {
       llvm::SmallVectorImpl<mlir::Attribute> &copyPrivateFuncs) const;
   bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
                      llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
+
+  // Special case of processDepend for the depend clause on target ops
+  // (TargetOp, EnterDataOp, ExitDataOp, UpdateDataOp). If the clause is
+  // present, creates an enclosing mergeable omp.task that carries the depend
+  // operands, leaves the insertion point inside the task so the target op is
+  // generated there, and returns true; otherwise returns false.
+  bool processTargetDepend(mlir::Location currentLocation) const;
   bool
   processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
   bool
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 25bb4d9cff5d16..011e7d75245061 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -933,7 +933,7 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
   ClauseProcessor cp(converter, semaCtx, clauseList);
   cp.processIf(directiveName, ifClauseOperand);
   cp.processDevice(stmtCtx, deviceOperand);
-  cp.processDepend(dependTypeOperands, dependOperands);
+  cp.processTargetDepend(currentLocation);
   cp.processNowait(nowaitAttr);
 
   if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
@@ -946,13 +946,9 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
     cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
   }
 
-  return firOpBuilder.create<OpTy>(
-      currentLocation, ifClauseOperand, deviceOperand,
-      dependTypeOperands.empty()
-          ? nullptr
-          : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
-                                 dependTypeOperands),
-      dependOperands, nowaitAttr, mapOperands);
+  return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
+                                   deviceOperand, nullptr, mlir::ValueRange(),
+                                   nowaitAttr, mapOperands);
 }
 
 // This functions creates a block for the body of the targetOp's region. It adds
@@ -1130,7 +1126,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
                ifClauseOperand);
   cp.processDevice(stmtCtx, deviceOperand);
   cp.processThreadLimit(stmtCtx, threadLimitOperand);
-  cp.processDepend(dependTypeOperands, dependOperands);
+  cp.processTargetDepend(currentLocation);
   cp.processNowait(nowaitAttr);
   cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
                 &mapSymLocs, &mapSymbols);
@@ -1232,11 +1228,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
 
   auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
       currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
-      dependTypeOperands.empty()
-          ? nullptr
-          : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
-                                 dependTypeOperands),
-      dependOperands, nowaitAttr, mapOperands);
+      nullptr, mlir::ValueRange(), nowaitAttr, mapOperands);
 
   genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
                     mapSymLocs, mapSymbols, currentLocation);
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 030533e1a04553..10c928dda19b43 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -27,9 +27,10 @@ subroutine omp_target_enter_depend
    !$omp task depend(out: a)
    call foo(a)
    !$omp end task
+   !CHECK: omp.task mergeable depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
    !CHECK: %[[BOUNDS:.*]] = omp.bounds   lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
    !CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}})   map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: omp.target_enter_data   map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+   !CHECK: omp.target_enter_data   map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
    !$omp target enter data map(to: a) depend(in: a)
     return
 end subroutine omp_target_enter_depend
@@ -166,9 +167,10 @@ subroutine omp_target_exit_depend
    !$omp task depend(out: a)
    call foo(a)
    !$omp end task
+   !CHECK: omp.task mergeable depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
    !CHECK: %[[BOUNDS:.*]] = omp.bounds   lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
    !CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}})   map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: omp.target_exit_data   map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+   !CHECK: omp.target_exit_data   map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
    !$omp target exit data map(from: a) depend(out: a)
 end subroutine omp_target_exit_depend
 
@@ -187,9 +189,10 @@ subroutine omp_target_update_depend
    call foo(a)
    !$omp end task
 
+   !CHECK: omp.task mergeable depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
    !CHECK: %[[BOUNDS:.*]] = omp.bounds
    !CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: omp.target_update_data motion_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+   !CHECK: omp.target_update_data motion_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
    !$omp target update to(a) depend(in:a)
 end subroutine omp_target_update_depend
 
@@ -367,12 +370,14 @@ subroutine omp_target_depend
    !$omp task depend(out: a)
    call foo(a)
    !$omp end task
+
+   !CHECK: omp.task mergeable depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
    !CHECK: %[[STRIDE_A:.*]] = arith.constant 1 : index
    !CHECK: %[[LBOUND_A:.*]] = arith.constant 0 : index
    !CHECK: %[[UBOUND_A:.*]] = arith.subi %c1024, %c1 : index
    !CHECK: %[[BOUNDS_A:.*]] = omp.bounds lower_bound(%[[LBOUND_A]] : index) upper_bound(%[[UBOUND_A]] : index) extent(%[[EXTENT_A]] : index) stride(%[[STRIDE_A]] : index) start_idx(%[[STRIDE_A]] : index)
    !CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_A]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: omp.target map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+   !CHECK: omp.target map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>)
    !$omp target map(tofrom: a) depend(in: a)
       a(1) = 10
       !CHECK: omp.terminator



More information about the flang-commits mailing list