[flang-commits] [flang] d21580c - [MLIR][OpenMP]Add Flang lowering support for device_ptr and device_addr clauses
Akash Banerjee via flang-commits
flang-commits at lists.llvm.org
Thu Jun 22 07:52:42 PDT 2023
Author: Akash Banerjee
Date: 2023-06-22T15:52:33+01:00
New Revision: d21580c3065740ac32f3713d737f302ed7f0ac94
URL: https://github.com/llvm/llvm-project/commit/d21580c3065740ac32f3713d737f302ed7f0ac94
DIFF: https://github.com/llvm/llvm-project/commit/d21580c3065740ac32f3713d737f302ed7f0ac94.diff
LOG: [MLIR][OpenMP]Add Flang lowering support for device_ptr and device_addr clauses
Add lowering support for the use_device_ptr and use_device_addr clauses for the Target Data directive.
Depends on D152822
Differential Revision: https://reviews.llvm.org/D152824
Added:
Modified:
flang/lib/Lower/OpenMP.cpp
flang/test/Lower/OpenMP/target.f90
Removed:
################################################################################
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 141912a18e29e..42fad0d483d19 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -723,6 +723,48 @@ createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
}
}
+static void createBodyOfTargetOp(
+ Fortran::lower::AbstractConverter &converter, mlir::omp::DataOp &dataOp,
+ const llvm::SmallVector<mlir::Type> &useDeviceTypes,
+ const llvm::SmallVector<mlir::Location> &useDeviceLocs,
+ const SmallVector<const Fortran::semantics::Symbol *> &useDeviceSymbols,
+ const mlir::Location &currentLocation) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ mlir::Region &region = dataOp.getRegion();
+
+ firOpBuilder.createBlock(&region, {}, useDeviceTypes, useDeviceLocs);
+ firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
+ firOpBuilder.setInsertionPointToStart(&region.front());
+
+ unsigned argIndex = 0;
+ for (auto *sym : useDeviceSymbols) {
+ const mlir::BlockArgument &arg = region.front().getArgument(argIndex);
+ mlir::Value val = fir::getBase(arg);
+ fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
+ if (auto refType = val.getType().dyn_cast<fir::ReferenceType>()) {
+ if (fir::isa_builtin_cptr_type(refType.getElementType())) {
+ converter.bindSymbol(*sym, val);
+ } else {
+ extVal.match(
+ [&](const fir::MutableBoxValue &mbv) {
+ converter.bindSymbol(
+ *sym,
+ fir::MutableBoxValue(
+ val, fir::factory::getNonDeferredLenParams(extVal), {}));
+ },
+ [&](const auto &) {
+ TODO(converter.getCurrentLocation(),
+ "use_device clause operand unsupported type");
+ });
+ }
+ } else {
+ TODO(converter.getCurrentLocation(),
+ "use_device clause operand unsupported type");
+ }
+ argIndex++;
+ }
+}
+
static void createTargetOp(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpClauseList &opClauseList,
const llvm::omp::Directive &directive,
@@ -732,13 +774,24 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
mlir::Value ifClauseOperand, deviceOperand, threadLmtOperand;
mlir::UnitAttr nowaitAttr;
- llvm::SmallVector<mlir::Value> useDevicePtrOperand, useDeviceAddrOperand,
- mapOperands;
+ llvm::SmallVector<mlir::Value> mapOperands, devicePtrOperands,
+ deviceAddrOperands;
llvm::SmallVector<mlir::IntegerAttr> mapTypes;
+ llvm::SmallVector<mlir::Type> useDeviceTypes;
+ llvm::SmallVector<mlir::Location> useDeviceLocs;
+ SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
+
+ /// Check for unsupported map operand types.
+ auto checkType = [](auto currentLocation, mlir::Type type) {
+ if (auto refType = type.dyn_cast<fir::ReferenceType>())
+ type = refType.getElementType();
+ if (auto boxType = type.dyn_cast_or_null<fir::BoxType>())
+ if (!boxType.getElementType().isa<fir::PointerType>())
+ TODO(currentLocation, "OMPD_target_data MapOperand BoxType");
+ };
- auto addMapClause = [&firOpBuilder, &converter, &mapOperands,
- &mapTypes](const auto &mapClause,
- mlir::Location &currentLocation) {
+ auto addMapClause = [&](const auto &mapClause,
+ mlir::Location &currentLocation) {
auto mapType = std::get<Fortran::parser::OmpMapType::Type>(
std::get<std::optional<Fortran::parser::OmpMapType>>(mapClause->v.t)
->t);
@@ -793,18 +846,25 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
converter, mapOperand);
for (mlir::Value mapOp : mapOperand) {
- /// Check for unsupported map operand types.
- mlir::Type checkType = mapOp.getType();
- if (auto refType = checkType.dyn_cast<fir::ReferenceType>())
- checkType = refType.getElementType();
- if (checkType.isa<fir::BoxType>())
- TODO(currentLocation, "OMPD_target_data MapOperand BoxType");
-
+ checkType(mapOp.getLoc(), mapOp.getType());
mapOperands.push_back(mapOp);
mapTypes.push_back(mapTypeAttr);
}
};
+ auto addUseDeviceClause = [&](const auto &useDeviceClause, auto &operands) {
+ genObjectList(useDeviceClause, converter, operands);
+ for (auto &operand : operands) {
+ checkType(operand.getLoc(), operand.getType());
+ useDeviceTypes.push_back(operand.getType());
+ useDeviceLocs.push_back(operand.getLoc());
+ }
+ for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) {
+ Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
+ useDeviceSymbols.push_back(sym);
+ }
+ };
+
for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
mlir::Location currentLocation = converter.genLocation(clause.source);
if (const auto &ifClause =
@@ -825,12 +885,6 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
deviceOperand =
fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx));
}
- } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
- &clause.u)) {
- TODO(currentLocation, "OMPD_target Use Device Ptr");
- } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
- &clause.u)) {
- TODO(currentLocation, "OMPD_target Use Device Addr");
} else if (const auto &threadLmtClause =
std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
&clause.u)) {
@@ -838,6 +892,14 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
*Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx));
} else if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) {
nowaitAttr = firOpBuilder.getUnitAttr();
+ } else if (const auto &devPtrClause =
+ std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
+ &clause.u)) {
+ addUseDeviceClause(devPtrClause->v, devicePtrOperands);
+ } else if (const auto &devAddrClause =
+ std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
+ &clause.u)) {
+ addUseDeviceClause(devAddrClause->v, deviceAddrOperands);
} else if (const auto &mapClause =
std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
addMapClause(mapClause, currentLocation);
@@ -859,9 +921,10 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
createBodyOfOp(targetOp, converter, currentLocation, *eval, &opClauseList);
} else if (directive == llvm::omp::Directive::OMPD_target_data) {
auto dataOp = firOpBuilder.create<omp::DataOp>(
- currentLocation, ifClauseOperand, deviceOperand, useDevicePtrOperand,
- useDeviceAddrOperand, mapOperands, mapTypesArrayAttr);
- createBodyOfOp(dataOp, converter, currentLocation, *eval, &opClauseList);
+ currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
+ deviceAddrOperands, mapOperands, mapTypesArrayAttr);
+ createBodyOfTargetOp(converter, dataOp, useDeviceTypes, useDeviceLocs,
+ useDeviceSymbols, currentLocation);
} else if (directive == llvm::omp::Directive::OMPD_target_enter_data) {
firOpBuilder.create<omp::EnterDataOp>(currentLocation, ifClauseOperand,
deviceOperand, nowaitAttr,
@@ -1157,7 +1220,17 @@ genOMP(Fortran::lower::AbstractConverter &converter,
continue;
} else if (std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
// Map clause is exclusive to Target Data directives. It is handled
- // as part of the DataOp creation.
+ // as part of the TargetOp creation.
+ continue;
+ } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
+ &clause.u)) {
+ // UseDevicePtr clause is exclusive to Target Data directives. It is
+ // handled as part of the TargetOp creation.
+ continue;
+ } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
+ &clause.u)) {
+ // UseDeviceAddr clause is exclusive to Target Data directives. It is
+ // handled as part of the TargetOp creation.
continue;
} else if (std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
&clause.u)) {
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 0e8574a821940..44b2b7b53531e 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -162,3 +162,39 @@ subroutine omp_target_thread_limit
!$omp end target
!CHECK: }
end subroutine omp_target_thread_limit
+
+!===============================================================================
+! Target `use_device_ptr` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device_ptr() {
+subroutine omp_target_device_ptr
+ use iso_c_binding, only : c_ptr, c_loc
+ type(c_ptr) :: a
+ integer, target :: b
+ !CHECK: omp.target_data map((tofrom -> %[[VAL_0:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>)) use_device_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>)
+ !$omp target data map(tofrom: a) use_device_ptr(a)
+ !CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>):
+ !CHECK: {{.*}} = fir.coordinate_of %[[VAL_1:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
+ a = c_loc(b)
+ !CHECK: omp.terminator
+ !$omp end target data
+ !CHECK: }
+end subroutine omp_target_device_ptr
+
+ !===============================================================================
+ ! Target `use_device_addr` clause
+ !===============================================================================
+
+ !CHECK-LABEL: func.func @_QPomp_target_device_addr() {
+ subroutine omp_target_device_addr
+ integer, pointer :: a
+ !CHECK: omp.target_data map((tofrom -> %[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>)) use_device_addr(%[[VAL_0]] : !fir.ref<!fir.box<!fir.ptr<i32>>>)
+ !$omp target data map(tofrom: a) use_device_addr(a)
+ !CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>):
+ !CHECK: {{.*}} = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
+ a = 10
+ !CHECK: omp.terminator
+ !$omp end target data
+ !CHECK: }
+end subroutine omp_target_device_addr
More information about the flang-commits
mailing list