[flang-commits] [flang] [mlir] [mlir][OpenMP] - Transform target offloading directives for easier translation to LLVMIR (PR #83966)
Pranav Bhandarkar via flang-commits
flang-commits at lists.llvm.org
Wed May 1 22:00:48 PDT 2024
https://github.com/bhandarkar-pranav updated https://github.com/llvm/llvm-project/pull/83966
>From 25c5dda574e6df26d286f8aaa7218e5f1402e0e0 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 20 Feb 2024 12:51:45 -0600
Subject: [PATCH 01/13] Checkpoint commit - Add buildDependData
---
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 54 +++++++++++--------
1 file changed, 31 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f87f89d8c636b..16baf3a3a47494 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -685,7 +685,36 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars, llvmCPFuncs));
return bodyGenStatus;
}
-
+template <typename T>
+static void buildDependData(T taskOrTargetop,
+ SmallVector<llvm::OpenMPIRBuilder::DependData> &dds,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ // std::optional<ArrayAttr> depends,
+ // OperandRange &dependVars,
+ if (taskOrTargetop.getDependVars().empty())
+ return;
+ std::optional<ArrayAttr> depends = taskOrTargetop.getDepends();
+ const OperandRange &dependVars = taskOrTargetop.getDependVars();
+ for (auto dep : llvm::zip(dependVars, depends->getValue())) {
+ llvm::omp::RTLDependenceKindTy type;
+ switch (
+ cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
+ case mlir::omp::ClauseTaskDepend::taskdependin:
+ type = llvm::omp::RTLDependenceKindTy::DepIn;
+ break;
+ // The OpenMP runtime requires that the codegen for 'depend' clause for
+ // 'out' dependency kind must be the same as codegen for 'depend' clause
+ // with 'inout' dependency.
+ case mlir::omp::ClauseTaskDepend::taskdependout:
+ case mlir::omp::ClauseTaskDepend::taskdependinout:
+ type = llvm::omp::RTLDependenceKindTy::DepInOut;
+ break;
+ };
+ llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
+ llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
+ dds.emplace_back(dd);
+ }
+}
// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
static LogicalResult
convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
@@ -748,28 +777,7 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
};
SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
- if (!taskOp.getDependVars().empty() && taskOp.getDepends()) {
- for (auto dep :
- llvm::zip(taskOp.getDependVars(), taskOp.getDepends()->getValue())) {
- llvm::omp::RTLDependenceKindTy type;
- switch (
- cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
- case mlir::omp::ClauseTaskDepend::taskdependin:
- type = llvm::omp::RTLDependenceKindTy::DepIn;
- break;
- // The OpenMP runtime requires that the codegen for 'depend' clause for
- // 'out' dependency kind must be the same as codegen for 'depend' clause
- // with 'inout' dependency.
- case mlir::omp::ClauseTaskDepend::taskdependout:
- case mlir::omp::ClauseTaskDepend::taskdependinout:
- type = llvm::omp::RTLDependenceKindTy::DepInOut;
- break;
- };
- llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
- llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
- dds.emplace_back(dd);
- }
- }
+ buildDependData(taskOp, dds, moduleTranslation);
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
>From 99819c70d52b5dd4233456ccd05f024b749264bf Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Fri, 1 Mar 2024 00:28:34 -0600
Subject: [PATCH 02/13] checkpoint commit - pass openmp-task-based-target set
up
---
.../flang/Optimizer/Builder/FIRBuilder.h | 5 ++
.../mlir/Dialect/OpenMP/CMakeLists.txt | 8 ++++
mlir/include/mlir/Dialect/OpenMP/Passes.h | 35 ++++++++++++++
mlir/include/mlir/Dialect/OpenMP/Passes.td | 23 ++++++++++
mlir/include/mlir/InitAllPasses.h | 2 +
mlir/lib/CAPI/Dialect/CMakeLists.txt | 1 +
mlir/lib/Dialect/OpenMP/CMakeLists.txt | 20 +-------
mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt | 18 ++++++++
.../Dialect/OpenMP/Transforms/CMakeLists.txt | 15 ++++++
.../Transforms/OpenMPTaskBasedTarget.cpp | 46 +++++++++++++++++++
10 files changed, 155 insertions(+), 18 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/OpenMP/Passes.h
create mode 100644 mlir/include/mlir/Dialect/OpenMP/Passes.td
create mode 100644 mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index e4c954159f71be..28c24ee77fb245 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -22,6 +22,7 @@
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/DenseMap.h"
@@ -512,6 +513,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
if (previous.isSet())
return;
setCommonAttributes(op);
+ mlir::omp::TargetOp targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(op);
+ if (targetOp)
+ llvm::errs() << "Inserted operation\n";
+ // op->dump();
}
private:
diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
index 419e24a7335361..51ab0f23cd00d5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -23,3 +23,11 @@ mlir_tablegen(OpenMPTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(OpenMPTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIROpenMPTypeInterfacesIncGen)
add_dependencies(mlir-generic-headers MLIROpenMPTypeInterfacesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name OpenMP)
+mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix OpenMP)
+mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix OpenMP)
+add_public_tablegen_target(MLIROpenMPPassIncGen)
+
+add_mlir_doc(Passes OpenMPPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/OpenMP/Passes.h b/mlir/include/mlir/Dialect/OpenMP/Passes.h
new file mode 100644
index 00000000000000..2167c95055d31f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Passes.h
@@ -0,0 +1,35 @@
+//===- Passes.h - OpenMP passes entry points -----------------------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_PASSES_H
+#define MLIR_DIALECT_OPENMP_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+std::unique_ptr<Pass> createOpenMPTaskBasedTargetPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+namespace omp {
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/OpenMP/Passes.h.inc"
+
+} // namespace omp
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/OpenMP/Passes.td b/mlir/include/mlir/Dialect/OpenMP/Passes.td
new file mode 100644
index 00000000000000..57d73384856ec5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Passes.td
@@ -0,0 +1,23 @@
+//===-- Passes.td - OpenMP pass definition file -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_PASSES
+#define MLIR_DIALECT_OPENMP_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def OpenMPTaskBasedTarget : Pass<"openmp-task-based-target", "func::FuncOp"> {
+ let summary = "Nest certain instances of mlir::omp::TargetOp inside mlir::omp::TaskOp";
+
+ let constructor = "mlir::createOpenMPTaskBasedTargetPass()";
+
+ let description = [{ First pass attempt}];
+
+ let dependentDialects = ["omp::OpenMPDialect"];
+}
+#endif
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 90406f555b0f47..da4db972432f39 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -36,6 +36,7 @@
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/OpenMP/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -83,6 +84,7 @@ inline void registerAllPasses() {
memref::registerMemRefPasses();
mesh::registerMeshPasses();
ml_program::registerMLProgramPasses();
+ omp::registerOpenMPPasses();
registerSCFPasses();
registerShapePasses();
spirv::registerSPIRVPasses();
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
index 58b8739043f9df..439f0093cc3d26 100644
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -223,6 +223,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIOpenMP
LINK_LIBS PUBLIC
MLIRCAPIIR
MLIROpenMPDialect
+ MLIROpenMPTransforms
)
add_mlir_upstream_c_api_library(MLIRCAPIPDL
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 57a6d3445c151c..9f57627c321fb0 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,18 +1,2 @@
-add_mlir_dialect_library(MLIROpenMPDialect
- IR/OpenMPDialect.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
-
- DEPENDS
- omp_gen
- MLIROpenMPOpsIncGen
- MLIROpenMPOpsInterfacesIncGen
- MLIROpenMPTypeInterfacesIncGen
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRLLVMDialect
- MLIRFuncDialect
- MLIROpenACCMPCommon
- )
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..05923032d90770
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_dialect_library(MLIROpenMPDialect
+ OpenMPDialect.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
+
+ DEPENDS
+ omp_gen
+ MLIROpenMPOpsIncGen
+ MLIROpenMPOpsInterfacesIncGen
+ MLIROpenMPTypeInterfacesIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMDialect
+ MLIRFuncDialect
+ MLIROpenACCMPCommon
+ )
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..06a1187df6d63a
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIROpenMPTransforms
+ OpenMPTaskBasedTarget.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
+
+ DEPENDS
+ MLIROpenMPPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIROpenMPDialect
+ MLIRFuncDialect
+ MLIRIR
+ MLIRPass
+)
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
new file mode 100644
index 00000000000000..343d830189686f
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
@@ -0,0 +1,46 @@
+//===- OpenMPTaskBasedTarget.cpp - Implementation of OpenMPTaskBasedTargetPass
+//---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements scf.parallel to scf.for + async.execute conversion pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_OPENMPTASKBASEDTARGET
+#include "mlir/Dialect/OpenMP/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::omp;
+
+#define DEBUG_TYPE "openmp-task-based-target"
+
+namespace {
+
+struct OpenMPTaskBasedTargetPass
+ : public impl::OpenMPTaskBasedTargetBase<OpenMPTaskBasedTargetPass> {
+
+ void runOnOperation() override;
+};
+
+} // namespace
+
+void OpenMPTaskBasedTargetPass::runOnOperation() {
+ Operation *op = getOperation();
+
+ op->dump();
+}
+std::unique_ptr<Pass> mlir::createOpenMPTaskBasedTargetPass() {
+ return std::make_unique<OpenMPTaskBasedTargetPass>();
+}
>From 3b9714011d6e833834f6701717266d474c5f1367 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Sat, 2 Mar 2024 07:29:35 -0600
Subject: [PATCH 03/13] Add patterns that match but do not rewrite anything
yet
---
.../Dialect/OpenMP/Transforms/CMakeLists.txt | 1 +
.../Transforms/OpenMPTaskBasedTarget.cpp | 27 +++++++++++++++++--
2 files changed, 26 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
index 06a1187df6d63a..1a64b5268e0839 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIROpenMPTransforms
MLIRFuncDialect
MLIRIR
MLIRPass
+ MLIRTransforms
)
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
index 343d830189686f..3c37848c36872f 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
@@ -15,6 +15,8 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
namespace mlir {
#define GEN_PASS_DEF_OPENMPTASKBASEDTARGET
@@ -33,13 +35,34 @@ struct OpenMPTaskBasedTargetPass
void runOnOperation() override;
};
-
+template <typename OpTy>
+class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ if (op.getDependVars().empty()) {
+ return rewriter.notifyMatchFailure(op, "depend clause not found on op");
+ }
+ return success();
+ }
+};
} // namespace
+static void
+populateOmpTaskBasedTargetRewritePatterns(RewritePatternSet &patterns) {
+ patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetOp>>(
+ patterns.getContext());
+}
void OpenMPTaskBasedTargetPass::runOnOperation() {
Operation *op = getOperation();
+ LLVM_DEBUG(llvm::dbgs() << "Running on the following operation\n");
+ // LLVM_DEBUG(llvm::dbgs() << op->dump());
- op->dump();
+ RewritePatternSet patterns(op->getContext());
+ populateOmpTaskBasedTargetRewritePatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ signalPassFailure();
}
std::unique_ptr<Pass> mlir::createOpenMPTaskBasedTargetPass() {
return std::make_unique<OpenMPTaskBasedTargetPass>();
>From 39939978abac2a95f6d6cc60a37f7ec01bf5684a Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Mon, 4 Mar 2024 21:30:01 -0600
Subject: [PATCH 04/13] Working for a simple testcase
---
.../Transforms/OpenMPTaskBasedTarget.cpp | 34 +++++++++++++++++--
.../Dialect/OpenMP/task-based-target.mlir | 13 +++++++
2 files changed, 45 insertions(+), 2 deletions(-)
create mode 100644 mlir/test/Dialect/OpenMP/task-based-target.mlir
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
index 3c37848c36872f..ec173e05e36f98 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
@@ -41,9 +41,40 @@ class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
+
+ // Only match a target op with a 'depend' clause on it.
if (op.getDependVars().empty()) {
return rewriter.notifyMatchFailure(op, "depend clause not found on op");
}
+
+ // Step 1: Create a new task op and tack on the dependency from the 'depend'
+ // clause on it.
+ omp::TaskOp taskOp = rewriter.create<omp::TaskOp>(
+ op.getLoc(), /*if_expr*/ Value(),
+ /*final_expr*/ Value(),
+ /*untied*/ UnitAttr(),
+ /*mergeable*/ UnitAttr(),
+ /*in_reduction_vars*/ ValueRange(),
+ /*in_reductions*/ nullptr,
+ /*priority*/ Value(), op.getDepends().value(), op.getDependVars(),
+ /*allocate_vars*/ ValueRange(),
+ /*allocate_vars*/ ValueRange());
+ Block *block = rewriter.createBlock(&taskOp.getRegion());
+ rewriter.setInsertionPointToEnd(block);
+ // Step 2: Clone and put the entire target op inside the newly created
+ // task's region.
+ Operation *clonedTargetOperation = rewriter.clone(*op.getOperation());
+ rewriter.create<mlir::omp::TerminatorOp>(op.getLoc());
+
+ // Step 3: Remove the dependency information from the clone target op.
+ omp::TargetOp clonedTargetOp =
+ llvm::dyn_cast<omp::TargetOp>(clonedTargetOperation);
+ if (clonedTargetOp) {
+ clonedTargetOp.removeDependsAttr();
+ clonedTargetOp.getDependVarsMutable().clear();
+ }
+ // Step 4: Erase the original target op
+ rewriter.eraseOp(op.getOperation());
return success();
}
};
@@ -56,11 +87,10 @@ populateOmpTaskBasedTargetRewritePatterns(RewritePatternSet &patterns) {
void OpenMPTaskBasedTargetPass::runOnOperation() {
Operation *op = getOperation();
- LLVM_DEBUG(llvm::dbgs() << "Running on the following operation\n");
- // LLVM_DEBUG(llvm::dbgs() << op->dump());
RewritePatternSet patterns(op->getContext());
populateOmpTaskBasedTargetRewritePatterns(patterns);
+
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Dialect/OpenMP/task-based-target.mlir b/mlir/test/Dialect/OpenMP/task-based-target.mlir
new file mode 100644
index 00000000000000..d285fe2f9abfe3
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/task-based-target.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -openmp-task-based-target -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @omp_target_depend
+// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
+func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
+ // CHECK: omp.task depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+ // CHECK: omp.target {
+ omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+ // CHECK: omp.terminator
+ omp.terminator
+ } {operandSegmentSizes = array<i32: 0,0,0,3,0>}
+ return
+}
>From 167bb19390f2d2ba74d0dc6c028c6c27f0d5728d Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 5 Mar 2024 00:19:22 -0600
Subject: [PATCH 05/13] Extend openmp-task-based-target transformation to target
enter/update/exit data as well
---
.../Transforms/OpenMPTaskBasedTarget.cpp | 8 +--
.../Dialect/OpenMP/task-based-target.mlir | 54 +++++++++++++++++++
2 files changed, 59 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
index ec173e05e36f98..cbb8546404a9ab 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
@@ -67,8 +67,7 @@ class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
rewriter.create<mlir::omp::TerminatorOp>(op.getLoc());
// Step 3: Remove the dependency information from the clone target op.
- omp::TargetOp clonedTargetOp =
- llvm::dyn_cast<omp::TargetOp>(clonedTargetOperation);
+ OpTy clonedTargetOp = llvm::dyn_cast<OpTy>(clonedTargetOperation);
if (clonedTargetOp) {
clonedTargetOp.removeDependsAttr();
clonedTargetOp.getDependVarsMutable().clear();
@@ -81,7 +80,10 @@ class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
} // namespace
static void
populateOmpTaskBasedTargetRewritePatterns(RewritePatternSet &patterns) {
- patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetOp>>(
+ patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetOp>,
+ OmpTaskBasedTargetRewritePattern<omp::EnterDataOp>,
+ OmpTaskBasedTargetRewritePattern<omp::UpdateDataOp>,
+ OmpTaskBasedTargetRewritePattern<omp::ExitDataOp>>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/OpenMP/task-based-target.mlir b/mlir/test/Dialect/OpenMP/task-based-target.mlir
index d285fe2f9abfe3..26cc493047e19b 100644
--- a/mlir/test/Dialect/OpenMP/task-based-target.mlir
+++ b/mlir/test/Dialect/OpenMP/task-based-target.mlir
@@ -11,3 +11,57 @@ func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
} {operandSegmentSizes = array<i32: 0,0,0,3,0>}
return
}
+// CHECK-LABEL: func @omp_target_enter_update_exit_data_depend
+// CHECK-SAME:([[ARG0:%.*]]: memref<?xi32>, [[ARG1:%.*]]: memref<?xi32>, [[ARG2:%.*]]: memref<?xi32>) {
+func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memref<?xi32>, %c: memref<?xi32>) {
+// CHECK-NEXT: [[MAP0:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP1:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP2:%.*]] = omp.map_info
+ %map_a = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+ %map_b = omp.map_info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ %map_c = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>
+
+ // Do some work on the host that writes to 'a'
+ omp.task depend(taskdependout -> %a : memref<?xi32>) {
+ "test.foo"(%a) : (memref<?xi32>) -> ()
+ omp.terminator
+ }
+
+ // Then map that over to the target
+ // CHECK: omp.task depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+ // CHECK: omp.target_enter_data nowait map_entries([[MAP0]], [[MAP2]] : memref<?xi32>, memref<?xi32>)
+ omp.target_enter_data nowait map_entries(%map_a, %map_c: memref<?xi32>, memref<?xi32>) depend(taskdependin -> %a: memref<?xi32>)
+
+ // Compute 'b' on the target and copy it back
+ // CHECK: omp.target map_entries([[MAP1]] -> {{%.*}} : memref<?xi32>) {
+ omp.target map_entries(%map_b -> %arg0 : memref<?xi32>) {
+ ^bb0(%arg0: memref<?xi32>) :
+ "test.foo"(%arg0) : (memref<?xi32>) -> ()
+ omp.terminator
+ }
+
+ // Update 'a' on the host using 'b'
+ omp.task depend(taskdependout -> %a: memref<?xi32>){
+ "test.bar"(%a, %b) : (memref<?xi32>, memref<?xi32>) -> ()
+ }
+
+ // Copy the updated 'a' onto the target
+ // CHECK: omp.task depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+ // CHECK: omp.target_update_data nowait motion_entries([[MAP0]] : memref<?xi32>)
+ omp.target_update_data motion_entries(%map_a : memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait
+
+ // Compute 'c' on the target and copy it back
+ // CHECK:[[MAP3:%.*]] = omp.map_info var_ptr([[ARG2]] : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ %map_c_from = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ // CHECK: omp.task depend(taskdependout -> [[ARG2]] : memref<?xi32>)
+ // CHECK: omp.target map_entries([[MAP0]] -> {{%.*}}, [[MAP3]] -> {{%.*}} : memref<?xi32>, memref<?xi32>) {
+ omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
+ ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+ "test.foobar"() : ()->()
+ omp.terminator
+ }
+ // CHECK: omp.task depend(taskdependin -> [[ARG2]] : memref<?xi32>) {
+ // CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>)
+ omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
+ return
+}
>From 5789bfb612c504902afb3c247f1d0160755d4b58 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 5 Mar 2024 00:35:39 -0600
Subject: [PATCH 06/13] Remove some debug prints
---
flang/include/flang/Optimizer/Builder/FIRBuilder.h | 5 -----
1 file changed, 5 deletions(-)
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index 28c24ee77fb245..e4c954159f71be 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -22,7 +22,6 @@
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
-#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/DenseMap.h"
@@ -513,10 +512,6 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
if (previous.isSet())
return;
setCommonAttributes(op);
- mlir::omp::TargetOp targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(op);
- if (targetOp)
- llvm::errs() << "Inserted operation\n";
- // op->dump();
}
private:
>From e60969a208ec662bb916edf52423b6860d1a9ad9 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 5 Mar 2024 00:43:59 -0600
Subject: [PATCH 07/13] Fix top level comment in OpenMPTaskBasedTarget.cpp
---
.../Transforms/OpenMPTaskBasedTarget.cpp | 21 ++++++++++++++++++-
1 file changed, 20 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
index cbb8546404a9ab..1e3977eb33e754 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
@@ -7,7 +7,26 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements scf.parallel to scf.for + async.execute conversion pass.
+// This file implements a pass that transforms certain omp.target.
+// Specifically, an omp.target op that has the depend clause on it is
+// transformed into an omp.task clause with the same depend clause on it.
+// The original omp.target loses its depend clause and is contained in
+// the new task region.
+//
+// omp.target depend(..) {
+// omp.terminator
+//
+// }
+//
+// =>
+//
+// omp.task depend(..) {
+// omp.target {
+// omp.terminator
+// }
+// omp.terminator
+// }
+//
//
//===----------------------------------------------------------------------===//
>From 23b1ea06fdd17fe637b5cdf445b69296026d85fc Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 5 Mar 2024 01:08:47 -0600
Subject: [PATCH 08/13] Add a description for openmp-task-based-target pass
---
mlir/include/mlir/Dialect/OpenMP/Passes.td | 32 +++++++++++++++++++++-
1 file changed, 31 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/Passes.td b/mlir/include/mlir/Dialect/OpenMP/Passes.td
index 57d73384856ec5..d4ff138c3c93c4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenMP/Passes.td
@@ -16,7 +16,37 @@ def OpenMPTaskBasedTarget : Pass<"openmp-task-based-target", "func::FuncOp"> {
let constructor = "mlir::createOpenMPTaskBasedTargetPass()";
- let description = [{ First pass attempt}];
+ let description = [{
+ This pass transforms `omp.target`, `omp.target_enter_data`,
+ `omp.target_update_data` and `omp.target_exit_data` whenever these operations
+ have the `depend` clause on them.
+
+ These operations are transformed by enclosing them inside a new `omp.task`
+ operation. The `depend` clause related arguments are moved to the new `omp.task`
+ operation from the original 'target' operation.
+
+ Example:
+ Input:
+ ```mlir
+ omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
+ ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+ "test.foobar"() : ()->()
+ omp.terminator
+ }
+ ```
+ Output:
+ ```mlir
+ omp.task depend(taskdependout -> %c : memref<?xi32>) {
+ omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) {
+ ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+ "test.foobar"() : ()->()
+ omp.terminator
+ }
+ }
+ ```
+ The intent is to make it easier to translate to LLVMIR by avoiding the
+ creation of such tasks in the OMPIRBuilder
+ }];
let dependentDialects = ["omp::OpenMPDialect"];
}
>From cee6f362ee78363bd6a3b6d4e876b4f53bf8c881 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 5 Mar 2024 01:11:13 -0600
Subject: [PATCH 09/13] Fix the description of openmp-task-based-target pass
---
mlir/include/mlir/Dialect/OpenMP/Passes.td | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/Passes.td b/mlir/include/mlir/Dialect/OpenMP/Passes.td
index d4ff138c3c93c4..4863999db24c56 100644
--- a/mlir/include/mlir/Dialect/OpenMP/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenMP/Passes.td
@@ -42,10 +42,11 @@ def OpenMPTaskBasedTarget : Pass<"openmp-task-based-target", "func::FuncOp"> {
"test.foobar"() : ()->()
omp.terminator
}
+ omp.terminator
}
```
The intent is to make it easier to translate to LLVMIR by avoiding the
- creation of such tasks in the OMPIRBuilder
+ creation of such tasks in the OMPIRBuilder.
}];
let dependentDialects = ["omp::OpenMPDialect"];
>From 8cc7a8989d1d0cbe96c890ce677cd579a78b8acf Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 5 Mar 2024 01:22:25 -0600
Subject: [PATCH 10/13] Roll back a now unnecessary change in
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
---
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 54 ++++++++-----------
1 file changed, 23 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 16baf3a3a47494..81f7563d1eb63b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -685,36 +685,7 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars, llvmCPFuncs));
return bodyGenStatus;
}
-template <typename T>
-static void buildDependData(T taskOrTargetop,
- SmallVector<llvm::OpenMPIRBuilder::DependData> &dds,
- LLVM::ModuleTranslation &moduleTranslation) {
- // std::optional<ArrayAttr> depends,
- // OperandRange &dependVars,
- if (taskOrTargetop.getDependVars().empty())
- return;
- std::optional<ArrayAttr> depends = taskOrTargetop.getDepends();
- const OperandRange &dependVars = taskOrTargetop.getDependVars();
- for (auto dep : llvm::zip(dependVars, depends->getValue())) {
- llvm::omp::RTLDependenceKindTy type;
- switch (
- cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
- case mlir::omp::ClauseTaskDepend::taskdependin:
- type = llvm::omp::RTLDependenceKindTy::DepIn;
- break;
- // The OpenMP runtime requires that the codegen for 'depend' clause for
- // 'out' dependency kind must be the same as codegen for 'depend' clause
- // with 'inout' dependency.
- case mlir::omp::ClauseTaskDepend::taskdependout:
- case mlir::omp::ClauseTaskDepend::taskdependinout:
- type = llvm::omp::RTLDependenceKindTy::DepInOut;
- break;
- };
- llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
- llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
- dds.emplace_back(dd);
- }
-}
+
// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
static LogicalResult
convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
@@ -777,7 +748,28 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
};
SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
- buildDependData(taskOp, dds, moduleTranslation);
+ if (!taskOp.getDependVars().empty() && taskOp.getDepends()) {
+ for (auto dep :
+ llvm::zip(taskOp.getDependVars(), taskOp.getDepends()->getValue())) {
+ llvm::omp::RTLDependenceKindTy type;
+ switch (
+ cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
+ case mlir::omp::ClauseTaskDepend::taskdependin:
+ type = llvm::omp::RTLDependenceKindTy::DepIn;
+ break;
+ // The OpenMP runtime requires that the codegen for 'depend' clause for
+ // 'out' dependency kind must be the same as codegen for 'depend' clause
+ // with 'inout' dependency.
+ case mlir::omp::ClauseTaskDepend::taskdependout:
+ case mlir::omp::ClauseTaskDepend::taskdependinout:
+ type = llvm::omp::RTLDependenceKindTy::DepInOut;
+ break;
+ };
+ llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
+ llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
+ dds.emplace_back(dd);
+ }
+ }
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
>From 0f71c22bc5357050aaa4ae39fd9378f588460d35 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 5 Mar 2024 01:23:42 -0600
Subject: [PATCH 11/13] Roll back one more unintended change in
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
---
.../LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 81f7563d1eb63b..9f87f89d8c636b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -757,9 +757,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
case mlir::omp::ClauseTaskDepend::taskdependin:
type = llvm::omp::RTLDependenceKindTy::DepIn;
break;
- // The OpenMP runtime requires that the codegen for 'depend' clause for
- // 'out' dependency kind must be the same as codegen for 'depend' clause
- // with 'inout' dependency.
+ // The OpenMP runtime requires that the codegen for 'depend' clause for
+ // 'out' dependency kind must be the same as codegen for 'depend' clause
+ // with 'inout' dependency.
case mlir::omp::ClauseTaskDepend::taskdependout:
case mlir::omp::ClauseTaskDepend::taskdependinout:
type = llvm::omp::RTLDependenceKindTy::DepInOut;
>From 65c75541864e9c8bab5790e99c8a838dd59ec877 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Wed, 1 May 2024 23:58:05 -0500
Subject: [PATCH 12/13] Update to fix build and tests after rebasing on main.
Also add the if clause to the newly generated omp.task op that
encloses the omp.target op.
---
.../Dialect/OpenMP/Transforms/CMakeLists.txt | 1 +
.../Transforms/OpenMPTaskBasedTarget.cpp | 15 +++++++---
.../Dialect/OpenMP/task-based-target.mlir | 30 +++++++++----------
3 files changed, 27 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
index 1a64b5268e0839..8bde30ba74b2a3 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIROpenMPTransforms
LINK_LIBS PUBLIC
MLIROpenMPDialect
+ MLIRArithDialect
MLIRFuncDialect
MLIRIR
MLIRPass
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
index 1e3977eb33e754..9f5b3f6dc7bc91 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
@@ -33,6 +33,7 @@
#include "mlir/Dialect/OpenMP/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
@@ -68,8 +69,14 @@ class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
// Step 1: Create a new task op and tack on the dependency from the 'depend'
// clause on it.
+ Type i1Ty = rewriter.getI1Type();
+ // mlir::BoolAttr T = rewriter.getBoolAttr(true);
+ // mlir::BoolAttr F = rewriter.getBoolAttr(false);
omp::TaskOp taskOp = rewriter.create<omp::TaskOp>(
- op.getLoc(), /*if_expr*/ Value(),
+ op.getLoc(),
+ /*if_expr*/ op.getNowait() ?
+ rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1))
+ : rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)),
/*final_expr*/ Value(),
/*untied*/ UnitAttr(),
/*mergeable*/ UnitAttr(),
@@ -100,9 +107,9 @@ class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
static void
populateOmpTaskBasedTargetRewritePatterns(RewritePatternSet &patterns) {
patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetOp>,
- OmpTaskBasedTargetRewritePattern<omp::EnterDataOp>,
- OmpTaskBasedTargetRewritePattern<omp::UpdateDataOp>,
- OmpTaskBasedTargetRewritePattern<omp::ExitDataOp>>(
+ OmpTaskBasedTargetRewritePattern<omp::TargetEnterDataOp>,
+ OmpTaskBasedTargetRewritePattern<omp::TargetUpdateOp>,
+ OmpTaskBasedTargetRewritePattern<omp::TargetExitDataOp>>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/OpenMP/task-based-target.mlir b/mlir/test/Dialect/OpenMP/task-based-target.mlir
index 26cc493047e19b..bc6a12ce2ae74f 100644
--- a/mlir/test/Dialect/OpenMP/task-based-target.mlir
+++ b/mlir/test/Dialect/OpenMP/task-based-target.mlir
@@ -3,7 +3,7 @@
// CHECK-LABEL: @omp_target_depend
// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
- // CHECK: omp.task depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+ // CHECK: omp.task if(%false) depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
// CHECK: omp.target {
omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
// CHECK: omp.terminator
@@ -14,12 +14,12 @@ func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
// CHECK-LABEL: func @omp_target_enter_update_exit_data_depend
// CHECK-SAME:([[ARG0:%.*]]: memref<?xi32>, [[ARG1:%.*]]: memref<?xi32>, [[ARG2:%.*]]: memref<?xi32>) {
func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memref<?xi32>, %c: memref<?xi32>) {
-// CHECK-NEXT: [[MAP0:%.*]] = omp.map_info
-// CHECK-NEXT: [[MAP1:%.*]] = omp.map_info
-// CHECK-NEXT: [[MAP2:%.*]] = omp.map_info
- %map_a = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
- %map_b = omp.map_info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
- %map_c = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>
+// CHECK: [[MAP0:%.*]] = omp.map.info
+// CHECK-NEXT: [[MAP1:%.*]] = omp.map.info
+// CHECK-NEXT: [[MAP2:%.*]] = omp.map.info
+ %map_a = omp.map.info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+ %map_b = omp.map.info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ %map_c = omp.map.info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>
// Do some work on the host that writes to 'a'
omp.task depend(taskdependout -> %a : memref<?xi32>) {
@@ -28,7 +28,7 @@ func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memre
}
// Then map that over to the target
- // CHECK: omp.task depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+ // CHECK: omp.task if(%true) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
// CHECK: omp.target_enter_data nowait map_entries([[MAP0]], [[MAP2]] : memref<?xi32>, memref<?xi32>)
omp.target_enter_data nowait map_entries(%map_a, %map_c: memref<?xi32>, memref<?xi32>) depend(taskdependin -> %a: memref<?xi32>)
@@ -46,21 +46,21 @@ func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memre
}
// Copy the updated 'a' onto the target
- // CHECK: omp.task depend(taskdependin -> [[ARG0]] : memref<?xi32>)
- // CHECK: omp.target_update_data nowait motion_entries([[MAP0]] : memref<?xi32>)
- omp.target_update_data motion_entries(%map_a : memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait
+ // CHECK: omp.task if(%true) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+ // CHECK: omp.target_update nowait motion_entries([[MAP0]] : memref<?xi32>)
+ omp.target_update motion_entries(%map_a : memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait
// Compute 'c' on the target and copy it back
- // CHECK:[[MAP3:%.*]] = omp.map_info var_ptr([[ARG2]] : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
- %map_c_from = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
- // CHECK: omp.task depend(taskdependout -> [[ARG2]] : memref<?xi32>)
+ // CHECK:[[MAP3:%.*]] = omp.map.info var_ptr([[ARG2]] : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ %map_c_from = omp.map.info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ // CHECK: omp.task if(%false) depend(taskdependout -> [[ARG2]] : memref<?xi32>)
// CHECK: omp.target map_entries([[MAP0]] -> {{%.*}}, [[MAP3]] -> {{%.*}} : memref<?xi32>, memref<?xi32>) {
omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
"test.foobar"() : ()->()
omp.terminator
}
- // CHECK: omp.task depend(taskdependin -> [[ARG2]] : memref<?xi32>) {
+ // CHECK: omp.task if(%false) depend(taskdependin -> [[ARG2]] : memref<?xi32>) {
// CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>)
omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
return
>From 47a09c8320ee9dedc9e3a818470ef95100eed2b7 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Thu, 2 May 2024 00:00:29 -0500
Subject: [PATCH 13/13] fix formatting issues
---
.../OpenMP/Transforms/OpenMPTaskBasedTarget.cpp | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
index 9f5b3f6dc7bc91..29310d6093b99d 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
@@ -32,8 +32,8 @@
#include "mlir/Dialect/OpenMP/Passes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
@@ -74,9 +74,11 @@ class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
// mlir::BoolAttr F = rewriter.getBoolAttr(false);
omp::TaskOp taskOp = rewriter.create<omp::TaskOp>(
op.getLoc(),
- /*if_expr*/ op.getNowait() ?
- rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1))
- : rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)),
+ /*if_expr*/ op.getNowait()
+ ? rewriter.create<mlir::arith::ConstantOp>(
+ op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1))
+ : rewriter.create<mlir::arith::ConstantOp>(
+ op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)),
/*final_expr*/ Value(),
/*untied*/ UnitAttr(),
/*mergeable*/ UnitAttr(),
More information about the flang-commits
mailing list