[Mlir-commits] [mlir] [mlir][acc] Refactor and expand OpenACC utils (PR #115119)
Razvan Lupusoru
llvmlistbot at llvm.org
Tue Nov 5 21:45:00 PST 2024
https://github.com/razvanlupusoru created https://github.com/llvm/llvm-project/pull/115119
This PR accomplishes the following:
- OpenACC utilities are moved from the dialect to their own library to be consistent with the way Utils are done in other dialects.
- Adds matching setters for several of the utilities that already had getters.
- Ensures that OpenACCUtils depends on OpenACCDialect (and not vice versa) by inlining, in the dialect, the few utilities it previously called.
- Adds unit tests for the utilities.
>From 1d816056c03b3d0d533a6c7e7c4bac75b80b4c4c Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Tue, 5 Nov 2024 21:36:05 -0800
Subject: [PATCH] [mlir][acc] Refactor and expand OpenACC utils
This PR accomplishes the following:
- OpenACC utilities are moved from the dialect to their own library
to be consistent with the way Utils are done in other dialects.
- Adds matching setters for several of the utilities that already had
getters.
- Ensures that OpenACCUtils depends on OpenACCDialect (and not vice
versa) by inlining, in the dialect, the few utilities it previously called.
- Adds unit tests for the utilities.
---
mlir/include/mlir/Dialect/OpenACC/OpenACC.h | 53 --
.../mlir/Dialect/OpenACC/Utils/OpenACCUtils.h | 117 +++++
mlir/lib/Dialect/OpenACC/CMakeLists.txt | 1 +
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 145 +-----
.../Dialect/OpenACC/Transforms/CMakeLists.txt | 1 +
.../OpenACC/Transforms/LegalizeDataValues.cpp | 5 +-
mlir/lib/Dialect/OpenACC/Utils/CMakeLists.txt | 20 +
.../Dialect/OpenACC/Utils/OpenACCUtils.cpp | 240 +++++++++
mlir/unittests/Dialect/OpenACC/CMakeLists.txt | 2 +
.../Dialect/OpenACC/OpenACCUtilsTest.cpp | 457 ++++++++++++++++++
10 files changed, 858 insertions(+), 183 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/OpenACC/Utils/OpenACCUtils.h
create mode 100644 mlir/lib/Dialect/OpenACC/Utils/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
create mode 100644 mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index cda07d6a913649..4d399f2e0ed19c 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -83,59 +83,6 @@ namespace acc {
/// combined and the final mapping value would be 5 (4 | 1).
enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4 };
-/// Used to obtain the `varPtr` from a data clause operation.
-/// Returns empty value if not a data clause operation or is a data exit
-/// operation with no `varPtr`.
-mlir::Value getVarPtr(mlir::Operation *accDataClauseOp);
-
-/// Used to obtain the `accPtr` from a data clause operation.
-/// When a data entry operation, it obtains its result `accPtr` value.
-/// If a data exit operation, it obtains its operand `accPtr` value.
-/// Returns empty value if not a data clause operation.
-mlir::Value getAccPtr(mlir::Operation *accDataClauseOp);
-
-/// Used to obtain the `varPtrPtr` from a data clause operation.
-/// Returns empty value if not a data clause operation.
-mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp);
-
-/// Used to obtain `bounds` from an acc data clause operation.
-/// Returns an empty vector if there are no bounds.
-mlir::SmallVector<mlir::Value> getBounds(mlir::Operation *accDataClauseOp);
-
-/// Used to obtain `async` operands from an acc data clause operation.
-/// Returns an empty vector if there are no such operands.
-mlir::SmallVector<mlir::Value>
-getAsyncOperands(mlir::Operation *accDataClauseOp);
-
-/// Returns an array of acc:DeviceTypeAttr attributes attached to
-/// an acc data clause operation, that correspond to the device types
-/// associated with the async clauses with an async-value.
-mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp);
-
-/// Returns an array of acc:DeviceTypeAttr attributes attached to
-/// an acc data clause operation, that correspond to the device types
-/// associated with the async clauses without an async-value.
-mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp);
-
-/// Used to obtain the `name` from an acc operation.
-std::optional<llvm::StringRef> getVarName(mlir::Operation *accOp);
-
-/// Used to obtain the `dataClause` from a data entry operation.
-/// Returns empty optional if not a data entry operation.
-std::optional<mlir::acc::DataClause>
-getDataClause(mlir::Operation *accDataEntryOp);
-
-/// Used to find out whether data operation is implicit.
-/// Returns false if not a data operation or if it is a data operation without
-/// implicit flag.
-bool getImplicitFlag(mlir::Operation *accDataEntryOp);
-
-/// Used to get an immutable range iterating over the data operands.
-mlir::ValueRange getDataOperands(mlir::Operation *accOp);
-
-/// Used to get a mutable range iterating over the data operands.
-mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp);
-
/// Used to obtain the attribute name for declare.
static constexpr StringLiteral getDeclareAttrName() {
return StringLiteral("acc.declare");
diff --git a/mlir/include/mlir/Dialect/OpenACC/Utils/OpenACCUtils.h b/mlir/include/mlir/Dialect/OpenACC/Utils/OpenACCUtils.h
new file mode 100644
index 00000000000000..e1d01aa2512c0c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenACC/Utils/OpenACCUtils.h
@@ -0,0 +1,117 @@
+//===- OpenACCUtils.h - OpenACC Utilities -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENACC_UTILS_OPENACCUTILS_H_
+#define MLIR_DIALECT_OPENACC_UTILS_OPENACCUTILS_H_
+
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+
+namespace mlir {
+namespace acc {
+/// Used to obtain the `varPtr` from a data clause operation.
+/// Returns empty value if not a data clause operation or is a data exit
+/// operation with no `varPtr`.
+mlir::Value getVarPtr(mlir::Operation *accDataClauseOp);
+
+/// Used to set the `varPtr` of a data clause operation. Returns true if it
+/// was set, and false for ops without `varPtr`: non data clause ops and exit
+/// ops other than CopyoutOp and UpdateHostOp.
+bool setVarPtr(mlir::Operation *accDataClauseOp, mlir::Value varPtr);
+
+/// Used to obtain the `accPtr` from a data clause operation.
+/// When a data entry operation, it obtains its result `accPtr` value.
+/// If a data exit operation, it obtains its operand `accPtr` value.
+/// Returns empty value if not a data clause operation.
+mlir::Value getAccPtr(mlir::Operation *accDataClauseOp);
+
+/// Used to set the `accPtr` for a data exit operation.
+/// Returns true if it was set successfully and false if it is not a data exit
+/// operation (data entry operations have their result as `accPtr` which
+/// cannot be changed).
+bool setAccPtr(mlir::Operation *accDataClauseOp, mlir::Value accPtr);
+
+/// Used to obtain the `varPtrPtr` from a data clause operation.
+/// Returns empty value if not a data entry operation.
+mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp);
+
+/// Used to set the `varPtrPtr` for a data clause operation.
+/// Returns true if set successfully, and false if the operation is not a
+/// data entry operation.
+bool setVarPtrPtr(mlir::Operation *accDataClauseOp, mlir::Value varPtrPtr);
+
+/// Used to obtain `bounds` from an acc data clause operation.
+/// Returns an empty vector if there are no bounds.
+mlir::SmallVector<mlir::Value> getBounds(mlir::Operation *accDataClauseOp);
+
+/// Used to set `bounds` for an acc data clause operation. It completely
+/// replaces all bounds operands with the new list.
+/// Returns false if new bounds were not set (such as when the argument is not
+/// an acc data clause operation).
+bool setBounds(mlir::Operation *accDataClauseOp,
+               mlir::SmallVector<mlir::Value> &bounds);
+/// Same as above, but replaces all bounds with the single value `bound`.
+bool setBounds(mlir::Operation *accDataClauseOp, mlir::Value bound);
+
+/// Used to obtain the `dataClause` from a data clause operation.
+/// Returns empty optional if not a data clause operation.
+std::optional<mlir::acc::DataClause>
+getDataClause(mlir::Operation *accDataClauseOp);
+
+/// Used to set the `dataClause` on a data clause operation.
+/// Returns true if successfully set and false otherwise.
+bool setDataClause(mlir::Operation *accDataClauseOp,
+                   mlir::acc::DataClause dataClause);
+
+/// Used to find out whether this data operation uses structured runtime
+/// counters. Returns false if not a data operation or if it is a data operation
+/// without the structured flag set.
+bool getStructuredFlag(mlir::Operation *accDataClauseOp);
+
+/// Used to mark a data clause operation as structured (`structured` is true)
+/// or dynamic (`structured` is false).
+/// Returns true if successfully set and false otherwise.
+bool setStructuredFlag(mlir::Operation *accDataClauseOp, bool structured);
+
+/// Used to find out whether data operation is implicit.
+/// Returns false if not a data operation or if it is a data operation without
+/// implicit flag.
+bool getImplicitFlag(mlir::Operation *accDataClauseOp);
+
+/// Used to mark a data clause operation as implicit (`implicit` is true) or
+/// explicit (`implicit` is false).
+/// Returns true if successfully set and false otherwise.
+bool setImplicitFlag(mlir::Operation *accDataClauseOp, bool implicit);
+
+/// Used to obtain the `name` of a data clause op (nullopt for other ops).
+std::optional<llvm::StringRef> getVarName(mlir::Operation *accDataClauseOp);
+
+/// Used to obtain `async` operands from an acc data clause operation.
+/// Returns an empty vector if there are no such operands.
+mlir::SmallVector<mlir::Value>
+getAsyncOperands(mlir::Operation *accDataClauseOp);
+
+/// Returns an array of acc::DeviceTypeAttr attributes attached to
+/// an acc data clause operation, that correspond to the device types
+/// associated with the async clauses with an async-value.
+mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp);
+
+/// Returns an array of acc::DeviceTypeAttr attributes attached to
+/// an acc data clause operation, that correspond to the device types
+/// associated with the async clauses without an async-value.
+mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp);
+
+/// Used to get an immutable range over a construct op's data operands.
+mlir::ValueRange getDataOperands(mlir::Operation *accOp);
+
+/// Used to get a mutable range iterating over the data operands.
+mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp);
+
+} // namespace acc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_OPENACC_UTILS_OPENACCUTILS_H_
diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
index 9f57627c321fb0..31167e6af908b9 100644
--- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 280260e0485bb5..1cc67629f97412 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -2379,11 +2379,21 @@ checkDeclareOperands(Op &op, const mlir::ValueRange &operands,
"expect valid declare data entry operation or acc.getdeviceptr "
"as defining op");
- mlir::Value varPtr{getVarPtr(operand.getDefiningOp())};
+ mlir::Value varPtr{
+ llvm::TypeSwitch<mlir::Operation *, mlir::Value>(
+ operand.getDefiningOp())
+ .Case<ACC_DATA_ENTRY_OPS>(
+ [&](auto entry) { return entry.getVarPtr(); })
+ .Default([&](mlir::Operation *) { return mlir::Value(); })};
assert(varPtr && "declare operands can only be data entry operations which "
"must have varPtr");
std::optional<mlir::acc::DataClause> dataClauseOptional{
- getDataClause(operand.getDefiningOp())};
+ llvm::TypeSwitch<mlir::Operation *,
+ std::optional<mlir::acc::DataClause>>(
+ operand.getDefiningOp())
+ .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
+ [&](auto entry) { return entry.getDataClause(); })
+ .Default([&](mlir::Operation *) { return std::nullopt; })};
assert(dataClauseOptional.has_value() &&
"declare operands can only be data entry operations which must have "
"dataClause");
@@ -2409,8 +2419,13 @@ checkDeclareOperands(Op &op, const mlir::ValueRange &operands,
// since implicit data action may be inserted to do actions like updating
// device copy, in which case the variable is not necessarily implicitly
// declare'd.
+ bool operandOpImplicitFlag{
+ llvm::TypeSwitch<mlir::Operation *, bool>(operand.getDefiningOp())
+ .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
+ [&](auto entry) { return entry.getImplicit(); })
+ .Default([&](mlir::Operation *) { return false; })};
if (declAttr.getImplicit() &&
- declAttr.getImplicit() != acc::getImplicitFlag(operand.getDefiningOp()))
+ declAttr.getImplicit() != operandOpImplicitFlag)
return op.emitError(
"implicitness must match between declare op and flag on variable");
}
@@ -2868,127 +2883,3 @@ LogicalResult acc::WaitOp::verify() {
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
-
-//===----------------------------------------------------------------------===//
-// acc dialect utilities
-//===----------------------------------------------------------------------===//
-
-mlir::Value mlir::acc::getVarPtr(mlir::Operation *accDataClauseOp) {
- auto varPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
- .Case<ACC_DATA_ENTRY_OPS>(
- [&](auto entry) { return entry.getVarPtr(); })
- .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
- [&](auto exit) { return exit.getVarPtr(); })
- .Default([&](mlir::Operation *) { return mlir::Value(); })};
- return varPtr;
-}
-
-mlir::Value mlir::acc::getAccPtr(mlir::Operation *accDataClauseOp) {
- auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
- .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
- [&](auto dataClause) { return dataClause.getAccPtr(); })
- .Default([&](mlir::Operation *) { return mlir::Value(); })};
- return accPtr;
-}
-
-mlir::Value mlir::acc::getVarPtrPtr(mlir::Operation *accDataClauseOp) {
- auto varPtrPtr{
- llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
- .Case<ACC_DATA_ENTRY_OPS>(
- [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
- .Default([&](mlir::Operation *) { return mlir::Value(); })};
- return varPtrPtr;
-}
-
-mlir::SmallVector<mlir::Value>
-mlir::acc::getBounds(mlir::Operation *accDataClauseOp) {
- mlir::SmallVector<mlir::Value> bounds{
- llvm::TypeSwitch<mlir::Operation *, mlir::SmallVector<mlir::Value>>(
- accDataClauseOp)
- .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
- return mlir::SmallVector<mlir::Value>(
- dataClause.getBounds().begin(), dataClause.getBounds().end());
- })
- .Default([&](mlir::Operation *) {
- return mlir::SmallVector<mlir::Value, 0>();
- })};
- return bounds;
-}
-
-mlir::SmallVector<mlir::Value>
-mlir::acc::getAsyncOperands(mlir::Operation *accDataClauseOp) {
- return llvm::TypeSwitch<mlir::Operation *, mlir::SmallVector<mlir::Value>>(
- accDataClauseOp)
- .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
- return mlir::SmallVector<mlir::Value>(
- dataClause.getAsyncOperands().begin(),
- dataClause.getAsyncOperands().end());
- })
- .Default([&](mlir::Operation *) {
- return mlir::SmallVector<mlir::Value, 0>();
- });
-}
-
-mlir::ArrayAttr
-mlir::acc::getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp) {
- return llvm::TypeSwitch<mlir::Operation *, mlir::ArrayAttr>(accDataClauseOp)
- .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
- return dataClause.getAsyncOperandsDeviceTypeAttr();
- })
- .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
-}
-
-mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
- return llvm::TypeSwitch<mlir::Operation *, mlir::ArrayAttr>(accDataClauseOp)
- .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
- [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
- .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
-}
-
-std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
- auto name{
- llvm::TypeSwitch<mlir::Operation *, std::optional<llvm::StringRef>>(accOp)
- .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
- .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
- return {};
- })};
- return name;
-}
-
-std::optional<mlir::acc::DataClause>
-mlir::acc::getDataClause(mlir::Operation *accDataEntryOp) {
- auto dataClause{
- llvm::TypeSwitch<mlir::Operation *, std::optional<mlir::acc::DataClause>>(
- accDataEntryOp)
- .Case<ACC_DATA_ENTRY_OPS>(
- [&](auto entry) { return entry.getDataClause(); })
- .Default([&](mlir::Operation *) { return std::nullopt; })};
- return dataClause;
-}
-
-bool mlir::acc::getImplicitFlag(mlir::Operation *accDataEntryOp) {
- auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
- .Case<ACC_DATA_ENTRY_OPS>(
- [&](auto entry) { return entry.getImplicit(); })
- .Default([&](mlir::Operation *) { return false; })};
- return implicit;
-}
-
-mlir::ValueRange mlir::acc::getDataOperands(mlir::Operation *accOp) {
- auto dataOperands{
- llvm::TypeSwitch<mlir::Operation *, mlir::ValueRange>(accOp)
- .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>(
- [&](auto entry) { return entry.getDataClauseOperands(); })
- .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
- return dataOperands;
-}
-
-mlir::MutableOperandRange
-mlir::acc::getMutableDataOperands(mlir::Operation *accOp) {
- auto dataOperands{
- llvm::TypeSwitch<mlir::Operation *, mlir::MutableOperandRange>(accOp)
- .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>(
- [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
- .Default([&](mlir::Operation *) { return nullptr; })};
- return dataOperands;
-}
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index 7d934956089a5a..42fcc138c5ffb9 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIROpenACCTransforms
LINK_LIBS PUBLIC
MLIROpenACCDialect
+ MLIROpenACCUtils
MLIRFuncDialect
MLIRIR
MLIRPass
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index 026b309ce4969d..028f53d05b1725 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -6,12 +6,11 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
-
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenACC/Utils/OpenACCUtils.h"
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/ErrorHandling.h"
namespace mlir {
diff --git a/mlir/lib/Dialect/OpenACC/Utils/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Utils/CMakeLists.txt
new file mode 100644
index 00000000000000..83a843b65168ba
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Utils/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_dialect_library(MLIROpenACCUtils
+ OpenACCUtils.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
+
+ DEPENDS
+ MLIROpenACCPassIncGen
+ MLIROpenACCOpsIncGen
+ MLIROpenACCEnumsIncGen
+ MLIROpenACCAttributesIncGen
+ MLIROpenACCMPOpsInterfacesIncGen
+ MLIROpenACCOpsInterfacesIncGen
+ MLIROpenACCTypeInterfacesIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIROpenACCDialect
+ MLIRSupport
+)
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
new file mode 100644
index 00000000000000..2fa4d70fd1f3f7
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -0,0 +1,240 @@
+//===- OpenACCUtils.cpp ---------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Utils/OpenACCUtils.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+mlir::Value mlir::acc::getVarPtr(mlir::Operation *accDataClauseOp) {
+  // `varPtr` is read from data entry ops and from the copyout/update-host exit
+  // ops; any other op yields a null value.
+  return llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
+          [](auto op) { return op.getVarPtr(); })
+      .Default([](mlir::Operation *) { return mlir::Value(); });
+}
+
+bool mlir::acc::setVarPtr(mlir::Operation *accDataClauseOp,
+                          mlir::Value varPtr) {
+  // Only ops that expose a `varPtr` operand can have it rewritten.
+  auto assignVarPtr = [&](auto op) {
+    op.getVarPtrMutable().assign(varPtr);
+    return true;
+  };
+  return llvm::TypeSwitch<mlir::Operation *, bool>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
+          assignVarPtr)
+      .Default([](mlir::Operation *) { return false; });
+}
+
+mlir::Value mlir::acc::getAccPtr(mlir::Operation *accDataClauseOp) {
+  // Entry ops produce `accPtr` as a result while exit ops take it as an
+  // operand; the accessor name is the same for both.
+  return llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
+          [](auto op) { return op.getAccPtr(); })
+      .Default([](mlir::Operation *) { return mlir::Value(); });
+}
+
+bool mlir::acc::setAccPtr(mlir::Operation *accDataClauseOp,
+                          mlir::Value accPtr) {
+  using Switch = llvm::TypeSwitch<mlir::Operation *, bool>;
+  return Switch(accDataClauseOp)
+      // Entry ops define `accPtr` as their result, which cannot be replaced on
+      // an existing operation, so report failure.
+      .Case<ACC_DATA_ENTRY_OPS>([](auto) { return false; })
+      // Exit ops take `accPtr` as an operand, which can simply be reassigned.
+      .Case<ACC_DATA_EXIT_OPS>([&](auto exitOp) {
+        exitOp.getAccPtrMutable().assign(accPtr);
+        return true;
+      })
+      .Default([](mlir::Operation *) { return false; });
+}
+
+mlir::Value mlir::acc::getVarPtrPtr(mlir::Operation *accDataClauseOp) {
+  // Only data entry ops are queried for `varPtrPtr`; anything else yields a
+  // null value.
+  return llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS>([](auto op) { return op.getVarPtrPtr(); })
+      .Default([](mlir::Operation *) { return mlir::Value(); });
+}
+
+bool mlir::acc::setVarPtrPtr(mlir::Operation *accDataClauseOp,
+                             mlir::Value varPtrPtr) {
+  // Only data entry ops are updated; any other op is left untouched.
+  return llvm::TypeSwitch<mlir::Operation *, bool>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS>([&](auto entryOp) {
+        entryOp.getVarPtrPtrMutable().assign(varPtrPtr);
+        return true;
+      })
+      .Default([](mlir::Operation *) { return false; });
+}
+
+mlir::SmallVector<mlir::Value>
+mlir::acc::getBounds(mlir::Operation *accDataClauseOp) {
+  using ValueVec = mlir::SmallVector<mlir::Value>;
+  return llvm::TypeSwitch<mlir::Operation *, ValueVec>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([](auto op) {
+        // Copy the operand range into caller-owned storage.
+        auto boundsRange = op.getBounds();
+        return ValueVec(boundsRange.begin(), boundsRange.end());
+      })
+      .Default([](mlir::Operation *) { return ValueVec(); });
+}
+
+bool mlir::acc::setBounds(mlir::Operation *accDataClauseOp,
+                          mlir::SmallVector<mlir::Value> &bounds) {
+  // The new list replaces (rather than extends) the existing bounds operands.
+  return llvm::TypeSwitch<mlir::Operation *, bool>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto op) {
+        op.getBoundsMutable().assign(bounds);
+        return true;
+      })
+      .Default([](mlir::Operation *) { return false; });
+}
+
+bool mlir::acc::setBounds(mlir::Operation *accDataClauseOp, mlir::Value bound) {
+  // Delegate to the list overload with a one-element list.
+  mlir::SmallVector<mlir::Value> singleBound{bound};
+  return setBounds(accDataClauseOp, singleBound);
+}
+
+std::optional<llvm::StringRef>
+mlir::acc::getVarName(mlir::Operation *accDataClauseOp) {
+  using OptName = std::optional<llvm::StringRef>;
+  // Non data clause ops have no name to report.
+  return llvm::TypeSwitch<mlir::Operation *, OptName>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
+          [](auto op) { return op.getName(); })
+      .Default([](mlir::Operation *) { return OptName(); });
+}
+
+std::optional<mlir::acc::DataClause>
+mlir::acc::getDataClause(mlir::Operation *accDataClauseOp) {
+  using OptClause = std::optional<mlir::acc::DataClause>;
+  // Only data clause ops carry a `dataClause`; everything else is empty.
+  return llvm::TypeSwitch<mlir::Operation *, OptClause>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
+          [](auto op) { return op.getDataClause(); })
+      .Default([](mlir::Operation *) { return OptClause(); });
+}
+
+bool mlir::acc::setDataClause(mlir::Operation *accDataClauseOp,
+                              mlir::acc::DataClause dataClause) {
+  auto updateClause = [&](auto op) {
+    op.setDataClause(dataClause);
+    return true;
+  };
+  return llvm::TypeSwitch<mlir::Operation *, bool>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(updateClause)
+      .Default([](mlir::Operation *) { return false; });
+}
+
+bool mlir::acc::getStructuredFlag(mlir::Operation *accDataClauseOp) {
+  // Non data clause ops are reported as not structured.
+  return llvm::TypeSwitch<mlir::Operation *, bool>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
+          [](auto op) { return op.getStructured(); })
+      .Default([](mlir::Operation *) { return false; });
+}
+
+bool mlir::acc::setStructuredFlag(mlir::Operation *accDataClauseOp,
+                                  bool structured) {
+  return llvm::TypeSwitch<mlir::Operation *, bool>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto op) {
+        op.setStructured(structured);
+        return true;
+      })
+      // Leave non data clause ops untouched and report failure.
+      .Default([](mlir::Operation *) { return false; });
+}
+
+bool mlir::acc::getImplicitFlag(mlir::Operation *accDataClauseOp) {
+  // Non data clause ops are reported as explicit (not implicit).
+  return llvm::TypeSwitch<mlir::Operation *, bool>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
+          [](auto op) { return op.getImplicit(); })
+      .Default([](mlir::Operation *) { return false; });
+}
+
+bool mlir::acc::setImplicitFlag(mlir::Operation *accDataClauseOp,
+                                bool implicit) {
+  auto updateImplicit = [&](auto op) {
+    op.setImplicit(implicit);
+    return true;
+  };
+  return llvm::TypeSwitch<mlir::Operation *, bool>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(updateImplicit)
+      .Default([](mlir::Operation *) { return false; });
+}
+
+mlir::SmallVector<mlir::Value>
+mlir::acc::getAsyncOperands(mlir::Operation *accDataClauseOp) {
+  using ValueVec = mlir::SmallVector<mlir::Value>;
+  return llvm::TypeSwitch<mlir::Operation *, ValueVec>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([](auto op) {
+        // Copy the async operand range into caller-owned storage.
+        auto asyncRange = op.getAsyncOperands();
+        return ValueVec(asyncRange.begin(), asyncRange.end());
+      })
+      .Default([](mlir::Operation *) { return ValueVec(); });
+}
+
+mlir::ArrayAttr
+mlir::acc::getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp) {
+  // Non data clause ops yield a null attribute.
+  return llvm::TypeSwitch<mlir::Operation *, mlir::ArrayAttr>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
+          [](auto op) { return op.getAsyncOperandsDeviceTypeAttr(); })
+      .Default([](mlir::Operation *) { return mlir::ArrayAttr(); });
+}
+
+mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
+  // Non data clause ops yield a null attribute.
+  return llvm::TypeSwitch<mlir::Operation *, mlir::ArrayAttr>(accDataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
+          [](auto op) { return op.getAsyncOnlyAttr(); })
+      .Default([](mlir::Operation *) { return mlir::ArrayAttr(); });
+}
+
+mlir::ValueRange mlir::acc::getDataOperands(mlir::Operation *accOp) {
+  // Ops other than compute/data constructs yield an empty range.
+  return llvm::TypeSwitch<mlir::Operation *, mlir::ValueRange>(accOp)
+      .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>(
+          [](auto constructOp) { return constructOp.getDataClauseOperands(); })
+      .Default([](mlir::Operation *) { return mlir::ValueRange(); });
+}
+
+mlir::MutableOperandRange
+mlir::acc::getMutableDataOperands(mlir::Operation *accOp) {
+  return llvm::TypeSwitch<mlir::Operation *, mlir::MutableOperandRange>(accOp)
+      .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>([&](auto accConstructOp) {
+        return accConstructOp.getDataClauseOperandsMutable();
+      })
+      // Returning `nullptr` here would implicitly build a MutableOperandRange
+      // through its `Operation *owner` constructor, which reads
+      // `owner->getNumOperands()` and would dereference null. Return an empty
+      // range anchored on the queried op instead, mirroring getDataOperands.
+      .Default([](mlir::Operation *op) {
+        return mlir::MutableOperandRange(op, /*start=*/0, /*length=*/0);
+      });
+}
diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
index 5133d7fc38296c..c1ba546f7d0698 100644
--- a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
@@ -1,8 +1,10 @@
add_mlir_unittest(MLIROpenACCTests
OpenACCOpsTest.cpp
+ OpenACCUtilsTest.cpp
)
target_link_libraries(MLIROpenACCTests
PRIVATE
MLIRIR
MLIROpenACCDialect
+ MLIROpenACCUtils
)
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
new file mode 100644
index 00000000000000..2f3be7c9106a79
--- /dev/null
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
@@ -0,0 +1,457 @@
+//===- OpenACCUtilsTest.cpp - Unit tests for OpenACC utils ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Utils/OpenACCUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::acc;
+
+ /// Test fixture owning an MLIRContext with the OpenACC, arith and memref
+ /// dialects loaded, plus a builder and an unknown location for creating ops.
+ class OpenACCUtilsTest : public ::testing::Test {
+ protected:
+   OpenACCUtilsTest() {
+     context.loadDialect<acc::OpenACCDialect, arith::ArithDialect,
+                         memref::MemRefDialect>();
+   }
+
+   // Declaration order matters: `b` and `loc` are initialized from `context`.
+   MLIRContext context;
+   OpBuilder b{&context};
+   Location loc{UnknownLoc::get(&context)};
+ };
+
+ /// Verifies that acc::getVarPtr / acc::setVarPtr agree with the varPtr
+ /// accessor of `Op`, before and after the varPtr operand is replaced by a
+ /// value of a different (f64) memref type.
+ template <typename Op>
+ void testDataOpVarPtr(OpBuilder &b, MLIRContext &context, Location loc) {
+   OwningOpRef<memref::AllocaOp> origVar =
+       b.create<memref::AllocaOp>(loc, MemRefType::get({}, b.getI32Type()));
+   OwningOpRef<GetDevicePtrOp> devPtr =
+       b.create<GetDevicePtrOp>(loc, origVar->getResult(), true, true);
+   OwningOpRef<memref::AllocaOp> newVar =
+       b.create<memref::AllocaOp>(loc, MemRefType::get({}, b.getF64Type()));
+
+   // Copyout and update-host ops are built from a device pointer in addition
+   // to the host variable; every other op under test takes only the variable.
+   constexpr bool needsAccPtr =
+       std::is_same<Op, CopyoutOp>() || std::is_same<Op, UpdateHostOp>();
+   OwningOpRef<Op> op;
+   if constexpr (needsAccPtr)
+     op = b.create<Op>(loc, /*accPtr=*/devPtr->getResult(),
+                       origVar->getResult(),
+                       /*structured=*/true, /*implicit=*/true);
+   else
+     op = b.create<Op>(loc, origVar->getResult(),
+                       /*structured=*/true, /*implicit=*/true);
+
+   Operation *rawOp = op.get();
+   EXPECT_EQ(getVarPtr(rawOp), origVar->getResult());
+   EXPECT_EQ(getVarPtr(rawOp), op->getVarPtr());
+
+   setVarPtr(rawOp, newVar->getResult());
+   EXPECT_EQ(getVarPtr(rawOp), newVar->getResult());
+   EXPECT_EQ(getVarPtr(rawOp), op->getVarPtr());
+ }
+
+ /// Runs testDataOpVarPtr for each op type in `Ops`, in left-to-right order.
+ template <typename... Ops>
+ static void testDataOpVarPtrForEach(OpBuilder &b, MLIRContext &context,
+                                     Location loc) {
+   (testDataOpVarPtr<Ops>(b, context, loc), ...);
+ }
+
+ TEST_F(OpenACCUtilsTest, dataOpVarPtr) {
+   testDataOpVarPtrForEach<
+       PrivateOp, FirstprivateOp, ReductionOp, DevicePtrOp, PresentOp, CopyinOp,
+       CreateOp, NoCreateOp, AttachOp, GetDevicePtrOp, UpdateDeviceOp,
+       UseDeviceOp, DeclareDeviceResidentOp, DeclareLinkOp, CacheOp, CopyoutOp,
+       UpdateHostOp>(b, context, loc);
+ }
+
+ /// Verifies acc::getAccPtr / acc::setAccPtr for `Op`. Ops that consume a
+ /// device pointer operand must round-trip it through the setter; the
+ /// remaining ops must report their own result as the accPtr.
+ template <typename Op>
+ void testDataOpAccPtr(OpBuilder &b, MLIRContext &context, Location loc) {
+   OwningOpRef<memref::AllocaOp> hostVar =
+       b.create<memref::AllocaOp>(loc, MemRefType::get({}, b.getI32Type()));
+   OwningOpRef<GetDevicePtrOp> devPtr =
+       b.create<GetDevicePtrOp>(loc, hostVar->getResult(), true, true);
+   OwningOpRef<memref::AllocaOp> hostVar2 =
+       b.create<memref::AllocaOp>(loc, MemRefType::get({}, b.getF64Type()));
+   OwningOpRef<GetDevicePtrOp> devPtr2 =
+       b.create<GetDevicePtrOp>(loc, hostVar2->getResult(), true, true);
+
+   // Shared checks for ops whose accPtr is an operand rather than a result.
+   // Generic so it is only instantiated where it is actually called.
+   auto checkAccPtrOperand = [&](auto &dataOp) {
+     EXPECT_EQ(dataOp->getAccPtr(), getAccPtr(dataOp.get()));
+     EXPECT_EQ(dataOp->getAccPtr(), devPtr->getResult());
+     setAccPtr(dataOp.get(), devPtr2->getResult());
+     EXPECT_EQ(dataOp->getAccPtr(), getAccPtr(dataOp.get()));
+     EXPECT_EQ(dataOp->getAccPtr(), devPtr2->getResult());
+   };
+
+   OwningOpRef<Op> op;
+   if constexpr (std::is_same<Op, CopyoutOp>() ||
+                 std::is_same<Op, UpdateHostOp>()) {
+     op = b.create<Op>(loc, /*accPtr=*/devPtr->getResult(),
+                       hostVar->getResult(),
+                       /*structured=*/true, /*implicit=*/true);
+     checkAccPtrOperand(op);
+   } else if constexpr (std::is_same<Op, DeleteOp>() ||
+                        std::is_same<Op, DetachOp>()) {
+     op = b.create<Op>(loc, /*accPtr=*/devPtr->getResult(),
+                       /*structured=*/true, /*implicit=*/true);
+     checkAccPtrOperand(op);
+   } else {
+     op = b.create<Op>(loc, hostVar->getResult(),
+                       /*structured=*/true, /*implicit=*/true);
+     EXPECT_EQ(op->getAccPtr(), op->getResult());
+     EXPECT_EQ(op->getAccPtr(), getAccPtr(op.get()));
+   }
+ }
+
+ /// Runs testDataOpAccPtr for each op type in `Ops`, in left-to-right order.
+ template <typename... Ops>
+ static void testDataOpAccPtrForEach(OpBuilder &b, MLIRContext &context,
+                                     Location loc) {
+   (testDataOpAccPtr<Ops>(b, context, loc), ...);
+ }
+
+ TEST_F(OpenACCUtilsTest, dataOpAccPtr) {
+   testDataOpAccPtrForEach<
+       PrivateOp, FirstprivateOp, ReductionOp, DevicePtrOp, PresentOp, CopyinOp,
+       CreateOp, NoCreateOp, AttachOp, GetDevicePtrOp, UpdateDeviceOp,
+       UseDeviceOp, DeclareDeviceResidentOp, DeclareLinkOp, CacheOp, CopyoutOp,
+       UpdateHostOp, DeleteOp, DetachOp>(b, context, loc);
+ }
+
+ /// Verifies acc::getVarPtrPtr / acc::setVarPtrPtr against the varPtrPtr
+ /// accessor of `Op`: the operand starts out absent and, once set, must report
+ /// the new value.
+ template <typename Op>
+ void testDataOpVarPtrPtr(OpBuilder &b, MLIRContext &context, Location loc) {
+   MemRefType i32MemrefTy = MemRefType::get({}, b.getI32Type());
+   OwningOpRef<memref::AllocaOp> var =
+       b.create<memref::AllocaOp>(loc, i32MemrefTy);
+   // A memref whose element type is the i32 memref above.
+   OwningOpRef<memref::AllocaOp> varPtrPtr =
+       b.create<memref::AllocaOp>(loc, MemRefType::get({}, i32MemrefTy));
+   OwningOpRef<Op> op = b.create<Op>(loc, var->getResult(),
+                                     /*structured=*/true, /*implicit=*/true);
+
+   // No varPtrPtr operand was supplied at construction.
+   EXPECT_EQ(getVarPtrPtr(op.get()), op->getVarPtrPtr());
+   EXPECT_EQ(Value(), op->getVarPtrPtr());
+
+   setVarPtrPtr(op.get(), varPtrPtr->getResult());
+   EXPECT_EQ(getVarPtrPtr(op.get()), op->getVarPtrPtr());
+   EXPECT_EQ(varPtrPtr->getResult(), op->getVarPtrPtr());
+ }
+
+ TEST_F(OpenACCUtilsTest, dataOpVarPtrPtr) {
+   // Exercise the varPtrPtr getter/setter (testDataOpVarPtrPtr), not the
+   // varPtr one, for every op type built from a host variable.
+   testDataOpVarPtrPtr<PrivateOp>(b, context, loc);
+   testDataOpVarPtrPtr<FirstprivateOp>(b, context, loc);
+   testDataOpVarPtrPtr<ReductionOp>(b, context, loc);
+   testDataOpVarPtrPtr<DevicePtrOp>(b, context, loc);
+   testDataOpVarPtrPtr<PresentOp>(b, context, loc);
+   testDataOpVarPtrPtr<CopyinOp>(b, context, loc);
+   testDataOpVarPtrPtr<CreateOp>(b, context, loc);
+   testDataOpVarPtrPtr<NoCreateOp>(b, context, loc);
+   testDataOpVarPtrPtr<AttachOp>(b, context, loc);
+   testDataOpVarPtrPtr<GetDevicePtrOp>(b, context, loc);
+   testDataOpVarPtrPtr<UpdateDeviceOp>(b, context, loc);
+   testDataOpVarPtrPtr<UseDeviceOp>(b, context, loc);
+   testDataOpVarPtrPtr<DeclareDeviceResidentOp>(b, context, loc);
+   testDataOpVarPtrPtr<DeclareLinkOp>(b, context, loc);
+   testDataOpVarPtrPtr<CacheOp>(b, context, loc);
+ }
+
+ /// Verifies acc::getBounds / acc::setBounds against the bounds accessor of
+ /// `Op`. After setBounds the op must hold exactly the single DataBoundsOp
+ /// result that was set; without the explicit size check the element-wise
+ /// loop would pass vacuously if setBounds left both ranges empty.
+ template <typename Op>
+ void testDataOpBounds(OpBuilder &b, MLIRContext &context, Location loc) {
+   auto memrefTy = MemRefType::get({}, b.getI32Type());
+   OwningOpRef<memref::AllocaOp> varPtrOp =
+       b.create<memref::AllocaOp>(loc, memrefTy);
+   OwningOpRef<GetDevicePtrOp> accPtrOp =
+       b.create<GetDevicePtrOp>(loc, varPtrOp->getResult(), true, true);
+   OwningOpRef<arith::ConstantIndexOp> extent =
+       b.create<arith::ConstantIndexOp>(loc, 1);
+   OwningOpRef<DataBoundsOp> bounds =
+       b.create<DataBoundsOp>(loc, extent->getResult());
+
+   OwningOpRef<Op> op;
+   if constexpr (std::is_same<Op, CopyoutOp>() ||
+                 std::is_same<Op, UpdateHostOp>()) {
+     op = b.create<Op>(loc, /*accPtr=*/accPtrOp->getResult(),
+                       varPtrOp->getResult(),
+                       /*structured=*/true, /*implicit=*/true);
+   } else if constexpr (std::is_same<Op, DeleteOp>() ||
+                        std::is_same<Op, DetachOp>()) {
+     op = b.create<Op>(loc, /*accPtr=*/accPtrOp->getResult(),
+                       /*structured=*/true, /*implicit=*/true);
+   } else {
+     op = b.create<Op>(loc, varPtrOp->getResult(),
+                       /*structured=*/true, /*implicit=*/true);
+   }
+
+   EXPECT_EQ(op->getBounds().size(), getBounds(op.get()).size());
+   for (auto [bound1, bound2] :
+        llvm::zip(op->getBounds(), getBounds(op.get()))) {
+     EXPECT_EQ(bound1, bound2);
+   }
+
+   setBounds(op.get(), bounds->getResult());
+   // Exactly one bound must now be present, seen identically by both getters.
+   ASSERT_EQ(getBounds(op.get()).size(), 1u);
+   EXPECT_EQ(op->getBounds().size(), getBounds(op.get()).size());
+   for (auto [bound1, bound2] :
+        llvm::zip(op->getBounds(), getBounds(op.get()))) {
+     EXPECT_EQ(bound1, bound2);
+     EXPECT_EQ(bound1, bounds->getResult());
+   }
+ }
+
+ /// Runs testDataOpBounds for each op type in `Ops`, in left-to-right order.
+ template <typename... Ops>
+ static void testDataOpBoundsForEach(OpBuilder &b, MLIRContext &context,
+                                     Location loc) {
+   (testDataOpBounds<Ops>(b, context, loc), ...);
+ }
+
+ TEST_F(OpenACCUtilsTest, dataOpBounds) {
+   testDataOpBoundsForEach<
+       PrivateOp, FirstprivateOp, ReductionOp, DevicePtrOp, PresentOp, CopyinOp,
+       CreateOp, NoCreateOp, AttachOp, GetDevicePtrOp, UpdateDeviceOp,
+       UseDeviceOp, DeclareDeviceResidentOp, DeclareLinkOp, CacheOp, CopyoutOp,
+       UpdateHostOp, DeleteOp, DetachOp>(b, context, loc);
+ }
+
+ /// Verifies that the variable name passed at construction is visible both
+ /// through the op's `name` attribute and through acc::getVarName.
+ template <typename Op>
+ void testDataOpName(OpBuilder &b, MLIRContext &context, Location loc) {
+   OwningOpRef<memref::AllocaOp> var =
+       b.create<memref::AllocaOp>(loc, MemRefType::get({}, b.getI32Type()));
+   OwningOpRef<GetDevicePtrOp> accPtr =
+       b.create<GetDevicePtrOp>(loc, var->getResult(), true, true);
+
+   // Pick the builder overload matching how `Op` is constructed.
+   auto createNamedOp = [&]() -> Op {
+     if constexpr (std::is_same<Op, CopyoutOp>() ||
+                   std::is_same<Op, UpdateHostOp>())
+       return b.create<Op>(loc, /*accPtr=*/accPtr->getResult(),
+                           var->getResult(), /*structured=*/true,
+                           /*implicit=*/true, "varName");
+     else if constexpr (std::is_same<Op, DeleteOp>() ||
+                        std::is_same<Op, DetachOp>())
+       return b.create<Op>(loc, /*accPtr=*/accPtr->getResult(),
+                           /*structured=*/true, /*implicit=*/true, "varName");
+     else
+       return b.create<Op>(loc, var->getResult(), /*structured=*/true,
+                           /*implicit=*/true, "varName");
+   };
+   OwningOpRef<Op> op = createNamedOp();
+
+   EXPECT_EQ(op->getNameAttr().str(), "varName");
+   EXPECT_EQ(getVarName(op.get()), "varName");
+ }
+
+ /// Runs testDataOpName for each op type in `Ops`, in left-to-right order.
+ template <typename... Ops>
+ static void testDataOpNameForEach(OpBuilder &b, MLIRContext &context,
+                                   Location loc) {
+   (testDataOpName<Ops>(b, context, loc), ...);
+ }
+
+ TEST_F(OpenACCUtilsTest, dataOpName) {
+   testDataOpNameForEach<
+       PrivateOp, FirstprivateOp, ReductionOp, DevicePtrOp, PresentOp, CopyinOp,
+       CreateOp, NoCreateOp, AttachOp, GetDevicePtrOp, UpdateDeviceOp,
+       UseDeviceOp, DeclareDeviceResidentOp, DeclareLinkOp, CacheOp, CopyoutOp,
+       UpdateHostOp, DeleteOp, DetachOp>(b, context, loc);
+ }
+
+ /// Verifies acc::getStructuredFlag / acc::setStructuredFlag against the
+ /// `structured` accessor of `Op`: the op is built structured, then the flag
+ /// is cleared through the utility.
+ template <typename Op>
+ void testDataOpStructured(OpBuilder &b, MLIRContext &context, Location loc) {
+   OwningOpRef<memref::AllocaOp> hostVar =
+       b.create<memref::AllocaOp>(loc, MemRefType::get({}, b.getI32Type()));
+   OwningOpRef<GetDevicePtrOp> devPtr =
+       b.create<GetDevicePtrOp>(loc, hostVar->getResult(), true, true);
+
+   auto build = [&]() -> Op {
+     if constexpr (std::is_same<Op, CopyoutOp>() ||
+                   std::is_same<Op, UpdateHostOp>())
+       return b.create<Op>(loc, /*accPtr=*/devPtr->getResult(),
+                           hostVar->getResult(),
+                           /*structured=*/true, /*implicit=*/true);
+     else if constexpr (std::is_same<Op, DeleteOp>() ||
+                        std::is_same<Op, DetachOp>())
+       return b.create<Op>(loc, /*accPtr=*/devPtr->getResult(),
+                           /*structured=*/true, /*implicit=*/true);
+     else
+       return b.create<Op>(loc, hostVar->getResult(),
+                           /*structured=*/true, /*implicit=*/true);
+   };
+   OwningOpRef<Op> op = build();
+
+   EXPECT_TRUE(op->getStructured());
+   EXPECT_EQ(getStructuredFlag(op.get()), op->getStructured());
+   setStructuredFlag(op.get(), false);
+   EXPECT_FALSE(op->getStructured());
+   EXPECT_EQ(getStructuredFlag(op.get()), op->getStructured());
+ }
+
+ /// Runs testDataOpStructured for each op type in `Ops`, in left-to-right
+ /// order.
+ template <typename... Ops>
+ static void testDataOpStructuredForEach(OpBuilder &b, MLIRContext &context,
+                                         Location loc) {
+   (testDataOpStructured<Ops>(b, context, loc), ...);
+ }
+
+ TEST_F(OpenACCUtilsTest, dataOpStructured) {
+   testDataOpStructuredForEach<
+       PrivateOp, FirstprivateOp, ReductionOp, DevicePtrOp, PresentOp, CopyinOp,
+       CreateOp, NoCreateOp, AttachOp, GetDevicePtrOp, UpdateDeviceOp,
+       UseDeviceOp, DeclareDeviceResidentOp, DeclareLinkOp, CacheOp, CopyoutOp,
+       UpdateHostOp, DeleteOp, DetachOp>(b, context, loc);
+ }
+
+ /// Verifies acc::getImplicitFlag / acc::setImplicitFlag against the
+ /// `implicit` accessor of `Op`: the op is built implicit, then the flag is
+ /// cleared through the utility.
+ template <typename Op>
+ void testDataOpImplicit(OpBuilder &b, MLIRContext &context, Location loc) {
+   OwningOpRef<memref::AllocaOp> hostVar =
+       b.create<memref::AllocaOp>(loc, MemRefType::get({}, b.getI32Type()));
+   OwningOpRef<GetDevicePtrOp> devPtr =
+       b.create<GetDevicePtrOp>(loc, hostVar->getResult(), true, true);
+
+   auto build = [&]() -> Op {
+     if constexpr (std::is_same<Op, CopyoutOp>() ||
+                   std::is_same<Op, UpdateHostOp>())
+       return b.create<Op>(loc, /*accPtr=*/devPtr->getResult(),
+                           hostVar->getResult(),
+                           /*structured=*/true, /*implicit=*/true);
+     else if constexpr (std::is_same<Op, DeleteOp>() ||
+                        std::is_same<Op, DetachOp>())
+       return b.create<Op>(loc, /*accPtr=*/devPtr->getResult(),
+                           /*structured=*/true, /*implicit=*/true);
+     else
+       return b.create<Op>(loc, hostVar->getResult(),
+                           /*structured=*/true, /*implicit=*/true);
+   };
+   OwningOpRef<Op> op = build();
+
+   EXPECT_TRUE(op->getImplicit());
+   EXPECT_EQ(getImplicitFlag(op.get()), op->getImplicit());
+   setImplicitFlag(op.get(), false);
+   EXPECT_FALSE(op->getImplicit());
+   EXPECT_EQ(getImplicitFlag(op.get()), op->getImplicit());
+ }
+
+ /// Runs testDataOpImplicit for each op type in `Ops`, in left-to-right order.
+ template <typename... Ops>
+ static void testDataOpImplicitForEach(OpBuilder &b, MLIRContext &context,
+                                       Location loc) {
+   (testDataOpImplicit<Ops>(b, context, loc), ...);
+ }
+
+ TEST_F(OpenACCUtilsTest, dataOpImplicit) {
+   testDataOpImplicitForEach<
+       PrivateOp, FirstprivateOp, ReductionOp, DevicePtrOp, PresentOp, CopyinOp,
+       CreateOp, NoCreateOp, AttachOp, GetDevicePtrOp, UpdateDeviceOp,
+       UseDeviceOp, DeclareDeviceResidentOp, DeclareLinkOp, CacheOp, CopyoutOp,
+       UpdateHostOp, DeleteOp, DetachOp>(b, context, loc);
+ }
+
+ /// Verifies acc::getDataClause / acc::setDataClause for `Op`. `dataClause` is
+ /// the clause the op is expected to carry right after construction.
+ template <typename Op>
+ void testDataOpDataClause(OpBuilder &b, MLIRContext &context, Location loc,
+                           DataClause dataClause) {
+   auto memrefTy = MemRefType::get({}, b.getI32Type());
+   OwningOpRef<memref::AllocaOp> varPtrOp =
+       b.create<memref::AllocaOp>(loc, memrefTy);
+   OwningOpRef<GetDevicePtrOp> accPtrOp =
+       b.create<GetDevicePtrOp>(loc, varPtrOp->getResult(), true, true);
+
+   OwningOpRef<Op> op;
+   if constexpr (std::is_same<Op, CopyoutOp>() ||
+                 std::is_same<Op, UpdateHostOp>()) {
+     op = b.create<Op>(loc, /*accPtr=*/accPtrOp->getResult(),
+                       varPtrOp->getResult(),
+                       /*structured=*/true, /*implicit=*/true);
+   } else if constexpr (std::is_same<Op, DeleteOp>() ||
+                        std::is_same<Op, DetachOp>()) {
+     op = b.create<Op>(loc, /*accPtr=*/accPtrOp->getResult(),
+                       /*structured=*/true, /*implicit=*/true);
+   } else {
+     op = b.create<Op>(loc, varPtrOp->getResult(),
+                       /*structured=*/true, /*implicit=*/true);
+   }
+
+   // Fail cleanly instead of calling value() on an empty optional.
+   ASSERT_TRUE(getDataClause(op.get()).has_value());
+   EXPECT_EQ(op->getDataClause(), getDataClause(op.get()).value());
+   EXPECT_EQ(op->getDataClause(), dataClause);
+
+   // Switch to a clause that differs from the initial one; otherwise the
+   // setter check is vacuous for ops that already start out as
+   // acc_getdeviceptr (e.g. GetDevicePtrOp).
+   const DataClause newClause = dataClause == DataClause::acc_getdeviceptr
+                                    ? DataClause::acc_copyin
+                                    : DataClause::acc_getdeviceptr;
+   setDataClause(op.get(), newClause);
+   ASSERT_TRUE(getDataClause(op.get()).has_value());
+   EXPECT_EQ(op->getDataClause(), getDataClause(op.get()).value());
+   EXPECT_EQ(op->getDataClause(), newClause);
+ }
+
+ TEST_F(OpenACCUtilsTest, dataOpDataClause) {
+   // The clause argument is the one each op must carry after construction.
+   using DC = DataClause;
+
+   // Ops built from the host variable alone.
+   testDataOpDataClause<PrivateOp>(b, context, loc, DC::acc_private);
+   testDataOpDataClause<FirstprivateOp>(b, context, loc, DC::acc_firstprivate);
+   testDataOpDataClause<ReductionOp>(b, context, loc, DC::acc_reduction);
+   testDataOpDataClause<DevicePtrOp>(b, context, loc, DC::acc_deviceptr);
+   testDataOpDataClause<PresentOp>(b, context, loc, DC::acc_present);
+   testDataOpDataClause<CopyinOp>(b, context, loc, DC::acc_copyin);
+   testDataOpDataClause<CreateOp>(b, context, loc, DC::acc_create);
+   testDataOpDataClause<NoCreateOp>(b, context, loc, DC::acc_no_create);
+   testDataOpDataClause<AttachOp>(b, context, loc, DC::acc_attach);
+   testDataOpDataClause<GetDevicePtrOp>(b, context, loc, DC::acc_getdeviceptr);
+   testDataOpDataClause<UpdateDeviceOp>(b, context, loc, DC::acc_update_device);
+   testDataOpDataClause<UseDeviceOp>(b, context, loc, DC::acc_use_device);
+   testDataOpDataClause<DeclareDeviceResidentOp>(
+       b, context, loc, DC::acc_declare_device_resident);
+   testDataOpDataClause<DeclareLinkOp>(b, context, loc, DC::acc_declare_link);
+   testDataOpDataClause<CacheOp>(b, context, loc, DC::acc_cache);
+
+   // Ops that also (or only) take the device pointer as an operand.
+   testDataOpDataClause<CopyoutOp>(b, context, loc, DC::acc_copyout);
+   testDataOpDataClause<UpdateHostOp>(b, context, loc, DC::acc_update_host);
+   testDataOpDataClause<DeleteOp>(b, context, loc, DC::acc_delete);
+   testDataOpDataClause<DetachOp>(b, context, loc, DC::acc_detach);
+ }
More information about the Mlir-commits
mailing list