[flang-commits] [flang] [mlir] [mlir][OpenMP] - Transform target offloading directives for easier translation to LLVMIR (PR #83966)

Pranav Bhandarkar via flang-commits flang-commits at lists.llvm.org
Mon Mar 4 23:22:42 PST 2024


https://github.com/bhandarkar-pranav updated https://github.com/llvm/llvm-project/pull/83966

>From 8b5d7e6ec5b2c342bebd5c8a67f88fbc25e5f95c Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 20 Feb 2024 12:51:45 -0600
Subject: [PATCH 01/10] Checkpoint commit - Add buildDependData

---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 54 +++++++++++--------
 1 file changed, 31 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index fd1de274da60e8..1b6fa5ebd83158 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -685,7 +685,36 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
       ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars, llvmCPFuncs));
   return bodyGenStatus;
 }
-
+/// Append one DependData to `dds` per (kind, variable) pair of the depend clause.
+template <typename T>
+static void buildDependData(T taskOrTargetop,
+                            SmallVector<llvm::OpenMPIRBuilder::DependData> &dds,
+                            LLVM::ModuleTranslation &moduleTranslation) {
+  std::optional<ArrayAttr> depends = taskOrTargetop.getDepends();
+  // The kinds attribute is optional: never dereference it when it is unset.
+  if (taskOrTargetop.getDependVars().empty() || !depends)
+    return;
+  const OperandRange &dependVars = taskOrTargetop.getDependVars();
+  for (auto dep : llvm::zip(dependVars, depends->getValue())) {
+    llvm::omp::RTLDependenceKindTy type;
+    switch (
+        cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
+    case mlir::omp::ClauseTaskDepend::taskdependin:
+      type = llvm::omp::RTLDependenceKindTy::DepIn;
+      break;
+      // The OpenMP runtime requires that the codegen for 'depend' clause for
+      // 'out' dependency kind must be the same as codegen for 'depend' clause
+      // with 'inout' dependency.
+    case mlir::omp::ClauseTaskDepend::taskdependout:
+    case mlir::omp::ClauseTaskDepend::taskdependinout:
+      type = llvm::omp::RTLDependenceKindTy::DepInOut;
+      break;
+    }
+    llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
+    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
+    dds.emplace_back(dd);
+  }
+}
 // Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
 static LogicalResult
 convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
@@ -748,28 +777,7 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   };
 
   SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
-  if (!taskOp.getDependVars().empty() && taskOp.getDepends()) {
-    for (auto dep :
-         llvm::zip(taskOp.getDependVars(), taskOp.getDepends()->getValue())) {
-      llvm::omp::RTLDependenceKindTy type;
-      switch (
-          cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
-      case mlir::omp::ClauseTaskDepend::taskdependin:
-        type = llvm::omp::RTLDependenceKindTy::DepIn;
-        break;
-      // The OpenMP runtime requires that the codegen for 'depend' clause for
-      // 'out' dependency kind must be the same as codegen for 'depend' clause
-      // with 'inout' dependency.
-      case mlir::omp::ClauseTaskDepend::taskdependout:
-      case mlir::omp::ClauseTaskDepend::taskdependinout:
-        type = llvm::omp::RTLDependenceKindTy::DepInOut;
-        break;
-      };
-      llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
-      llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
-      dds.emplace_back(dd);
-    }
-  }
+  buildDependData(taskOp, dds, moduleTranslation);
 
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);

>From 89cb7832daa6c8a9c42122b76769d339f7b0154e Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Fri, 1 Mar 2024 00:28:34 -0600
Subject: [PATCH 02/10] checkpoint commit - pass openmp-task-based-target set
 up

---
 .../flang/Optimizer/Builder/FIRBuilder.h      |  5 ++
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 19 +++++++-
 .../mlir/Dialect/OpenMP/CMakeLists.txt        |  8 ++++
 mlir/include/mlir/Dialect/OpenMP/Passes.h     | 35 ++++++++++++++
 mlir/include/mlir/Dialect/OpenMP/Passes.td    | 23 ++++++++++
 mlir/include/mlir/InitAllPasses.h             |  2 +
 mlir/lib/CAPI/Dialect/CMakeLists.txt          |  1 +
 mlir/lib/Dialect/OpenMP/CMakeLists.txt        | 19 +-------
 mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt     | 17 +++++++
 .../Dialect/OpenMP/Transforms/CMakeLists.txt  | 15 ++++++
 .../Transforms/OpenMPTaskBasedTarget.cpp      | 46 +++++++++++++++++++
 11 files changed, 171 insertions(+), 19 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/OpenMP/Passes.h
 create mode 100644 mlir/include/mlir/Dialect/OpenMP/Passes.td
 create mode 100644 mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp

diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index bd9b67b14b966c..715faa11e5a12a 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -22,6 +22,7 @@
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "llvm/ADT/DenseMap.h"
@@ -499,6 +500,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
     if (previous.isSet())
       return;
     setCommonAttributes(op);
+    mlir::omp::TargetOp targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(op);
+    if (targetOp)
+      llvm::errs() << "Inserted operation\n";
+    // op->dump();
   }
 
 private:
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 90fc1f80f57ab7..6a827ad1b60ea9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1143,7 +1143,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
     }
   };
   Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);
-
+  llvm::errs() << "generating targetOp\n";
   auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
       currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
       dependTypeOperands.empty()
@@ -1151,9 +1151,22 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
           : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
                                  dependTypeOperands),
       dependOperands, nowaitAttr, mapOperands);
-
+  mlir::Operation *targetOperation = targetOp.getOperation();
+  llvm::errs() << "NumRegions held by targetOperation = "
+               << targetOperation->getNumRegions();
+  llvm::errs() << "Number of blocks in Region 0 = "
+               << targetOperation->getRegion(0).getBlocks().size();
+  llvm::errs() << "Block that the targetOperation is in\n";
+  llvm::errs() << "generating body of TargetOp\n";
   genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
                     mapSymLocs, mapSymbols, currentLocation);
+  llvm::errs() << "NumRegions held by targetOperation = "
+               << targetOperation->getNumRegions();
+  llvm::errs() << "Number of blocks in Region 0 = "
+               << targetOperation->getRegion(0).getBlocks().size();
+  llvm::errs() << "Block that the targetOperation is in\n";
+  //  converter.getFirOpBuilder().getInsertionBlock()->getParent()->dump();
+  targetOperation->getBlock()->dump();
 
   return targetOp;
 }
@@ -1797,8 +1810,10 @@ genOMP(Fortran::lower::AbstractConverter &converter,
               beginClauseList);
     break;
   case llvm::omp::Directive::OMPD_task:
+    //    converter.getFirOpBuilder().getInsertionBlock()->dump();
     genTaskOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
               beginClauseList);
+    //    converter.getFirOpBuilder().getInsertionBlock()->dump();
     break;
   case llvm::omp::Directive::OMPD_taskgroup:
     genTaskGroupOp(converter, semaCtx, eval, /*genNested=*/true,
diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
index 419e24a7335361..51ab0f23cd00d5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -23,3 +23,11 @@ mlir_tablegen(OpenMPTypeInterfaces.h.inc -gen-type-interface-decls)
 mlir_tablegen(OpenMPTypeInterfaces.cpp.inc -gen-type-interface-defs)
 add_public_tablegen_target(MLIROpenMPTypeInterfacesIncGen)
 add_dependencies(mlir-generic-headers MLIROpenMPTypeInterfacesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name OpenMP)
+mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix OpenMP)
+mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix OpenMP)
+add_public_tablegen_target(MLIROpenMPPassIncGen)
+
+add_mlir_doc(Passes OpenMPPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/OpenMP/Passes.h b/mlir/include/mlir/Dialect/OpenMP/Passes.h
new file mode 100644
index 00000000000000..2167c95055d31f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Passes.h
@@ -0,0 +1,35 @@
+//===- Passes.h - OpenMP passes entry points --------------------*- C++ -*-===//
+//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_PASSES_H
+#define MLIR_DIALECT_OPENMP_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+std::unique_ptr<Pass> createOpenMPTaskBasedTargetPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+namespace omp {
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/OpenMP/Passes.h.inc"
+
+} // namespace omp
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/OpenMP/Passes.td b/mlir/include/mlir/Dialect/OpenMP/Passes.td
new file mode 100644
index 00000000000000..57d73384856ec5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Passes.td
@@ -0,0 +1,23 @@
+//===-- Passes.td - OpenMP pass definition file -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_PASSES
+#define MLIR_DIALECT_OPENMP_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def OpenMPTaskBasedTarget : Pass<"openmp-task-based-target", "func::FuncOp"> {
+  let summary = "Nest certain instances of mlir::omp::TargetOp inside mlir::omp::TaskOp";
+
+  let constructor = "mlir::createOpenMPTaskBasedTargetPass()";
+
+  let description = [{ First pass attempt}];
+
+  let dependentDialects = ["omp::OpenMPDialect"];
+}
+#endif
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 5d90c197a6cced..902ab8f4c4fd1b 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -36,6 +36,7 @@
 #include "mlir/Dialect/Mesh/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Transforms/Passes.h"
 #include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/OpenMP/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -82,6 +83,7 @@ inline void registerAllPasses() {
   memref::registerMemRefPasses();
   mesh::registerMeshPasses();
   ml_program::registerMLProgramPasses();
+  omp::registerOpenMPPasses();
   registerSCFPasses();
   registerShapePasses();
   spirv::registerSPIRVPasses();
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
index 58b8739043f9df..439f0093cc3d26 100644
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -223,6 +223,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIOpenMP
   LINK_LIBS PUBLIC
   MLIRCAPIIR
   MLIROpenMPDialect
+  MLIROpenMPTransforms
 )
 
 add_mlir_upstream_c_api_library(MLIRCAPIPDL
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 40b4837484a136..9f57627c321fb0 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,17 +1,2 @@
-add_mlir_dialect_library(MLIROpenMPDialect
-  IR/OpenMPDialect.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
-
-  DEPENDS
-  MLIROpenMPOpsIncGen
-  MLIROpenMPOpsInterfacesIncGen
-  MLIROpenMPTypeInterfacesIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRLLVMDialect
-  MLIRFuncDialect
-  MLIROpenACCMPCommon
-  )
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..def53387255653
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIROpenMPDialect
+  OpenMPDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
+
+  DEPENDS
+  MLIROpenMPOpsIncGen
+  MLIROpenMPOpsInterfacesIncGen
+  MLIROpenMPTypeInterfacesIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  MLIRFuncDialect
+  MLIROpenACCMPCommon
+  )
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..06a1187df6d63a
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIROpenMPTransforms
+  OpenMPTaskBasedTarget.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
+
+  DEPENDS
+  MLIROpenMPPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIROpenMPDialect
+  MLIRFuncDialect
+  MLIRIR
+  MLIRPass
+)
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
new file mode 100644
index 00000000000000..343d830189686f
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
@@ -0,0 +1,46 @@
+//===- OpenMPTaskBasedTarget.cpp - OpenMPTaskBasedTargetPass --------------===//
+//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements scf.parallel to scf.for + async.execute conversion pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_OPENMPTASKBASEDTARGET
+#include "mlir/Dialect/OpenMP/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::omp;
+
+#define DEBUG_TYPE "openmp-task-based-target"
+
+namespace {
+
+struct OpenMPTaskBasedTargetPass
+    : public impl::OpenMPTaskBasedTargetBase<OpenMPTaskBasedTargetPass> {
+
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void OpenMPTaskBasedTargetPass::runOnOperation() {
+  Operation *op = getOperation();
+
+  op->dump();
+}
+std::unique_ptr<Pass> mlir::createOpenMPTaskBasedTargetPass() {
+  return std::make_unique<OpenMPTaskBasedTargetPass>();
+}

>From 17c68a313729ced73e2c9f2a0ec9bf72fc5b51dd Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Sat, 2 Mar 2024 07:29:35 -0600
Subject: [PATCH 03/10] Add patterns that match but do not rewrite anything
 yet

---
 .../Dialect/OpenMP/Transforms/CMakeLists.txt  |  1 +
 .../Transforms/OpenMPTaskBasedTarget.cpp      | 27 +++++++++++++++++--
 2 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
index 06a1187df6d63a..1a64b5268e0839 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIROpenMPTransforms
   MLIRFuncDialect
   MLIRIR
   MLIRPass
+  MLIRTransforms
 )
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
index 343d830189686f..3c37848c36872f 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
@@ -15,6 +15,8 @@
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_OPENMPTASKBASEDTARGET
@@ -33,13 +35,34 @@ struct OpenMPTaskBasedTargetPass
 
   void runOnOperation() override;
 };
-
+template <typename OpTy>
+class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getDependVars().empty()) {
+      return rewriter.notifyMatchFailure(op, "depend clause not found on op");
+    }
+    return success();
+  }
+};
 } // namespace
+static void
+populateOmpTaskBasedTargetRewritePatterns(RewritePatternSet &patterns) {
+  patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetOp>>(
+      patterns.getContext());
+}
 
 void OpenMPTaskBasedTargetPass::runOnOperation() {
   Operation *op = getOperation();
+  LLVM_DEBUG(llvm::dbgs() << "Running on the following operation\n");
+  //  LLVM_DEBUG(llvm::dbgs() << op->dump());
 
-  op->dump();
+  RewritePatternSet patterns(op->getContext());
+  populateOmpTaskBasedTargetRewritePatterns(patterns);
+  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+    signalPassFailure();
 }
 std::unique_ptr<Pass> mlir::createOpenMPTaskBasedTargetPass() {
   return std::make_unique<OpenMPTaskBasedTargetPass>();

>From 0c695eb8e49acf2f61a65515017366aa8a7b46b3 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Mon, 4 Mar 2024 21:30:01 -0600
Subject: [PATCH 04/10] Working for a simple testcase

---
 .../Transforms/OpenMPTaskBasedTarget.cpp      | 34 +++++++++++++++++--
 .../Dialect/OpenMP/task-based-target.mlir     | 13 +++++++
 2 files changed, 45 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Dialect/OpenMP/task-based-target.mlir

diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
index 3c37848c36872f..ec173e05e36f98 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
@@ -41,9 +41,40 @@ class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
+
+    // Only match a target op with a 'depend' clause on it.
     if (op.getDependVars().empty()) {
       return rewriter.notifyMatchFailure(op, "depend clause not found on op");
     }
+
+    // Step 1: Create a new task op and tack on the dependency from the 'depend'
+    // clause on it.
+    omp::TaskOp taskOp = rewriter.create<omp::TaskOp>(
+        op.getLoc(), /*if_expr*/ Value(),
+        /*final_expr*/ Value(),
+        /*untied*/ UnitAttr(),
+        /*mergeable*/ UnitAttr(),
+        /*in_reduction_vars*/ ValueRange(),
+        /*in_reductions*/ nullptr,
+        /*priority*/ Value(), op.getDepends().value(), op.getDependVars(),
+        /*allocate_vars*/ ValueRange(),
+        /*allocators_vars*/ ValueRange());
+    Block *block = rewriter.createBlock(&taskOp.getRegion());
+    rewriter.setInsertionPointToEnd(block);
+    // Step 2: Clone and put the entire target op inside the newly created
+    // task's region.
+    Operation *clonedTargetOperation = rewriter.clone(*op.getOperation());
+    rewriter.create<mlir::omp::TerminatorOp>(op.getLoc());
+
+    // Step 3: Remove the dependency information from the clone target op.
+    omp::TargetOp clonedTargetOp =
+        llvm::dyn_cast<omp::TargetOp>(clonedTargetOperation);
+    if (clonedTargetOp) {
+      clonedTargetOp.removeDependsAttr();
+      clonedTargetOp.getDependVarsMutable().clear();
+    }
+    // Step 4: Erase the original target op
+    rewriter.eraseOp(op.getOperation());
     return success();
   }
 };
@@ -56,11 +87,10 @@ populateOmpTaskBasedTargetRewritePatterns(RewritePatternSet &patterns) {
 
 void OpenMPTaskBasedTargetPass::runOnOperation() {
   Operation *op = getOperation();
-  LLVM_DEBUG(llvm::dbgs() << "Running on the following operation\n");
-  //  LLVM_DEBUG(llvm::dbgs() << op->dump());
 
   RewritePatternSet patterns(op->getContext());
   populateOmpTaskBasedTargetRewritePatterns(patterns);
+
   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Dialect/OpenMP/task-based-target.mlir b/mlir/test/Dialect/OpenMP/task-based-target.mlir
new file mode 100644
index 00000000000000..d285fe2f9abfe3
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/task-based-target.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -openmp-task-based-target -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @omp_target_depend
+// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
+func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
+  // CHECK: omp.task depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+  // CHECK: omp.target {
+  omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+    // CHECK: omp.terminator
+    omp.terminator
+  } {operandSegmentSizes = array<i32: 0,0,0,3,0>}
+  return
+}

>From 7b474fcae6e254fb3232668e3cb75790fc02fc76 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 5 Mar 2024 00:19:22 -0600
Subject: [PATCH 05/10] Extend openmp-task-based-target transformation to target
 enter/update/exit data as well

---
 .../Transforms/OpenMPTaskBasedTarget.cpp      |  8 +--
 .../Dialect/OpenMP/task-based-target.mlir     | 54 +++++++++++++++++++
 2 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
index ec173e05e36f98..cbb8546404a9ab 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
@@ -67,8 +67,7 @@ class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
     rewriter.create<mlir::omp::TerminatorOp>(op.getLoc());
 
     // Step 3: Remove the dependency information from the clone target op.
-    omp::TargetOp clonedTargetOp =
-        llvm::dyn_cast<omp::TargetOp>(clonedTargetOperation);
+    OpTy clonedTargetOp = llvm::dyn_cast<OpTy>(clonedTargetOperation);
     if (clonedTargetOp) {
       clonedTargetOp.removeDependsAttr();
       clonedTargetOp.getDependVarsMutable().clear();
@@ -81,7 +80,10 @@ class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
 } // namespace
 static void
 populateOmpTaskBasedTargetRewritePatterns(RewritePatternSet &patterns) {
-  patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetOp>>(
+  patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetOp>,
+               OmpTaskBasedTargetRewritePattern<omp::EnterDataOp>,
+               OmpTaskBasedTargetRewritePattern<omp::UpdateDataOp>,
+               OmpTaskBasedTargetRewritePattern<omp::ExitDataOp>>(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/OpenMP/task-based-target.mlir b/mlir/test/Dialect/OpenMP/task-based-target.mlir
index d285fe2f9abfe3..26cc493047e19b 100644
--- a/mlir/test/Dialect/OpenMP/task-based-target.mlir
+++ b/mlir/test/Dialect/OpenMP/task-based-target.mlir
@@ -11,3 +11,57 @@ func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
   } {operandSegmentSizes = array<i32: 0,0,0,3,0>}
   return
 }
+// CHECK-LABEL: func @omp_target_enter_update_exit_data_depend
+// CHECK-SAME:([[ARG0:%.*]]: memref<?xi32>, [[ARG1:%.*]]: memref<?xi32>, [[ARG2:%.*]]: memref<?xi32>) {
+func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memref<?xi32>, %c: memref<?xi32>) {
+// CHECK-NEXT: [[MAP0:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP1:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP2:%.*]] = omp.map_info
+  %map_a = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+  %map_b = omp.map_info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+  %map_c = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>
+
+  // Do some work on the host that writes to 'a'
+  omp.task depend(taskdependout -> %a : memref<?xi32>) {
+    "test.foo"(%a) : (memref<?xi32>) -> ()
+    omp.terminator
+  }
+
+  // Then map that over to the target
+  // CHECK: omp.task depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+  // CHECK: omp.target_enter_data nowait map_entries([[MAP0]], [[MAP2]] : memref<?xi32>, memref<?xi32>)
+  omp.target_enter_data nowait map_entries(%map_a, %map_c: memref<?xi32>, memref<?xi32>) depend(taskdependin ->  %a: memref<?xi32>)
+
+  // Compute 'b' on the target and copy it back
+  // CHECK: omp.target map_entries([[MAP1]] -> {{%.*}} : memref<?xi32>) {
+  omp.target map_entries(%map_b -> %arg0 : memref<?xi32>) {
+    ^bb0(%arg0: memref<?xi32>) :
+      "test.foo"(%arg0) : (memref<?xi32>) -> ()
+      omp.terminator
+  }
+
+  // Update 'a' on the host using 'b'
+  omp.task depend(taskdependout -> %a: memref<?xi32>){
+    "test.bar"(%a, %b) : (memref<?xi32>, memref<?xi32>) -> ()
+  }
+
+  // Copy the updated 'a' onto the target
+  // CHECK: omp.task depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+  // CHECK: omp.target_update_data nowait motion_entries([[MAP0]] : memref<?xi32>)
+  omp.target_update_data motion_entries(%map_a :  memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait
+
+  // Compute 'c' on the target and copy it back
+  // CHECK:[[MAP3:%.*]] = omp.map_info var_ptr([[ARG2]] : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+  %map_c_from = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+  // CHECK: omp.task depend(taskdependout -> [[ARG2]] : memref<?xi32>)
+  // CHECK: omp.target map_entries([[MAP0]] -> {{%.*}}, [[MAP3]] -> {{%.*}} : memref<?xi32>, memref<?xi32>) {
+  omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
+  ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+    "test.foobar"() : ()->()
+    omp.terminator
+  }
+  // CHECK: omp.task depend(taskdependin -> [[ARG2]] : memref<?xi32>) {
+  // CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>)
+  omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
+  return
+}

>From 0bd1e80aa412972c07200ada6e7a11cd0f16df0e Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 5 Mar 2024 00:35:39 -0600
Subject: [PATCH 06/10] Remove some debug prints

---
 .../flang/Optimizer/Builder/FIRBuilder.h      |  5 -----
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 19 ++-----------------
 2 files changed, 2 insertions(+), 22 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index 715faa11e5a12a..bd9b67b14b966c 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -22,7 +22,6 @@
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
-#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "llvm/ADT/DenseMap.h"
@@ -500,10 +499,6 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
     if (previous.isSet())
       return;
     setCommonAttributes(op);
-    mlir::omp::TargetOp targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(op);
-    if (targetOp)
-      llvm::errs() << "Inserted operation\n";
-    // op->dump();
   }
 
 private:
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 6a827ad1b60ea9..90fc1f80f57ab7 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1143,7 +1143,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
     }
   };
   Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);
-  llvm::errs() << "generating targetOp\n";
+
   auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
       currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
       dependTypeOperands.empty()
@@ -1151,22 +1151,9 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
           : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
                                  dependTypeOperands),
       dependOperands, nowaitAttr, mapOperands);
-  mlir::Operation *targetOperation = targetOp.getOperation();
-  llvm::errs() << "NumRegions held by targetOperation = "
-               << targetOperation->getNumRegions();
-  llvm::errs() << "Number of blocks in Region 0 = "
-               << targetOperation->getRegion(0).getBlocks().size();
-  llvm::errs() << "Block that the targetOperation is in\n";
-  llvm::errs() << "generating body of TargetOp\n";
+
   genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
                     mapSymLocs, mapSymbols, currentLocation);
-  llvm::errs() << "NumRegions held by targetOperation = "
-               << targetOperation->getNumRegions();
-  llvm::errs() << "Number of blocks in Region 0 = "
-               << targetOperation->getRegion(0).getBlocks().size();
-  llvm::errs() << "Block that the targetOperation is in\n";
-  //  converter.getFirOpBuilder().getInsertionBlock()->getParent()->dump();
-  targetOperation->getBlock()->dump();
 
   return targetOp;
 }
@@ -1810,10 +1797,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
               beginClauseList);
     break;
   case llvm::omp::Directive::OMPD_task:
-    //    converter.getFirOpBuilder().getInsertionBlock()->dump();
     genTaskOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
               beginClauseList);
-    //    converter.getFirOpBuilder().getInsertionBlock()->dump();
     break;
   case llvm::omp::Directive::OMPD_taskgroup:
     genTaskGroupOp(converter, semaCtx, eval, /*genNested=*/true,

>From 419874127272edb3c334cacb4a6a1bbe7dfc88a5 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 5 Mar 2024 00:43:59 -0600
Subject: [PATCH 07/10] Fix top level comment in OpenMPTaskBasedTarget.cpp

---
 .../Transforms/OpenMPTaskBasedTarget.cpp      | 21 ++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
index cbb8546404a9ab..1e3977eb33e754 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
@@ -7,7 +7,26 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements scf.parallel to scf.for + async.execute conversion pass.
+// This file implements a pass that transforms certain omp.target ops.
+// Specifically, an omp.target op that has the depend clause on it is
+// transformed into an omp.task op with the same depend clause on it.
+// The original omp.target loses its depend clause and is contained in
+// the new task region.
+//
+// omp.target depend(..) {
+//  omp.terminator
+//
+// }
+//
+// =>
+//
+// omp.task depend(..) {
+//   omp.target {
+//     omp.terminator
+//   }
+//   omp.terminator
+// }
+//
 //
 //===----------------------------------------------------------------------===//
 

>From d353d9d76943ed1d076d08bfb0904de83a48566e Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 5 Mar 2024 01:08:47 -0600
Subject: [PATCH 08/10] Add a description for openmp-task-based-target pass

---
 mlir/include/mlir/Dialect/OpenMP/Passes.td | 32 +++++++++++++++++++++-
 1 file changed, 31 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/Passes.td b/mlir/include/mlir/Dialect/OpenMP/Passes.td
index 57d73384856ec5..d4ff138c3c93c4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenMP/Passes.td
@@ -16,7 +16,37 @@ def OpenMPTaskBasedTarget : Pass<"openmp-task-based-target", "func::FuncOp"> {
 
   let constructor = "mlir::createOpenMPTaskBasedTargetPass()";
 
-  let description = [{ First pass attempt}];
+  let description = [{
+  This pass transforms `omp.target`, `omp.target_enter_data`,
+  `omp.target_update_data` and `omp.target_exit_data` whenever these operations
+  have the `depend` clause on them.
+
+  These operations are transformed by enclosing them inside a new `omp.task`
+  operation. The `depend` clause related arguments are moved to the new `omp.task`
+  operation from the original 'target' operation.
+
+  Example:
+  Input:
+  ```mlir
+  omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
+    ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+        "test.foobar"() : ()->()
+         omp.terminator
+  }
+  ```
+  Output:
+  ```mlir
+  omp.task depend(taskdependout -> %c : memref<?xi32>) {
+    omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) {
+      ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+          "test.foobar"() : ()->()
+           omp.terminator
+    }
+  }
+  ```
+  The intent is to make it easier to translate to LLVMIR by avoiding the
+  creation of such tasks in the OMPIRBuilder
+  }];
 
   let dependentDialects = ["omp::OpenMPDialect"];
 }

>From 65397b62a0df3d61c24d4882c5fb7a4401f4a96d Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 5 Mar 2024 01:11:13 -0600
Subject: [PATCH 09/10] Fix the description of openmp-task-based-target pass

---
 mlir/include/mlir/Dialect/OpenMP/Passes.td | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/Passes.td b/mlir/include/mlir/Dialect/OpenMP/Passes.td
index d4ff138c3c93c4..4863999db24c56 100644
--- a/mlir/include/mlir/Dialect/OpenMP/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenMP/Passes.td
@@ -42,10 +42,11 @@ def OpenMPTaskBasedTarget : Pass<"openmp-task-based-target", "func::FuncOp"> {
           "test.foobar"() : ()->()
            omp.terminator
     }
+    omp.terminator
   }
   ```
   The intent is to make it easier to translate to LLVMIR by avoiding the
-  creation of such tasks in the OMPIRBuilder
+  creation of such tasks in the OMPIRBuilder.
   }];
 
   let dependentDialects = ["omp::OpenMPDialect"];

>From f3df411a6319cd9c759232da0fe9038c88a91c11 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 5 Mar 2024 01:22:25 -0600
Subject: [PATCH 10/10] Roll back a now unnecessary change in
 mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 54 ++++++++-----------
 1 file changed, 23 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1b6fa5ebd83158..c844af76817ca7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -685,36 +685,7 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
       ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars, llvmCPFuncs));
   return bodyGenStatus;
 }
-template <typename T>
-static void buildDependData(T taskOrTargetop,
-                            SmallVector<llvm::OpenMPIRBuilder::DependData> &dds,
-                            LLVM::ModuleTranslation &moduleTranslation) {
-  // std::optional<ArrayAttr> depends,
-  //                  OperandRange &dependVars,
-  if (taskOrTargetop.getDependVars().empty())
-    return;
-  std::optional<ArrayAttr> depends = taskOrTargetop.getDepends();
-  const OperandRange &dependVars = taskOrTargetop.getDependVars();
-  for (auto dep : llvm::zip(dependVars, depends->getValue())) {
-    llvm::omp::RTLDependenceKindTy type;
-    switch (
-        cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
-    case mlir::omp::ClauseTaskDepend::taskdependin:
-      type = llvm::omp::RTLDependenceKindTy::DepIn;
-      break;
-      // The OpenMP runtime requires that the codegen for 'depend' clause for
-      // 'out' dependency kind must be the same as codegen for 'depend' clause
-      // with 'inout' dependency.
-    case mlir::omp::ClauseTaskDepend::taskdependout:
-    case mlir::omp::ClauseTaskDepend::taskdependinout:
-      type = llvm::omp::RTLDependenceKindTy::DepInOut;
-      break;
-    };
-    llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
-    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
-    dds.emplace_back(dd);
-  }
-}
+
 // Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
 static LogicalResult
 convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
@@ -777,7 +748,28 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   };
 
   SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
-  buildDependData(taskOp, dds, moduleTranslation);
+  if (!taskOp.getDependVars().empty() && taskOp.getDepends()) {
+    for (auto dep :
+         llvm::zip(taskOp.getDependVars(), taskOp.getDepends()->getValue())) {
+      llvm::omp::RTLDependenceKindTy type;
+      switch (
+          cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
+      case mlir::omp::ClauseTaskDepend::taskdependin:
+        type = llvm::omp::RTLDependenceKindTy::DepIn;
+        break;
+        // The OpenMP runtime requires that the codegen for 'depend' clause for
+        // 'out' dependency kind must be the same as codegen for 'depend' clause
+        // with 'inout' dependency.
+      case mlir::omp::ClauseTaskDepend::taskdependout:
+      case mlir::omp::ClauseTaskDepend::taskdependinout:
+        type = llvm::omp::RTLDependenceKindTy::DepInOut;
+        break;
+      };
+      llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
+      llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
+      dds.emplace_back(dd);
+    }
+  }
 
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);



More information about the flang-commits mailing list