[flang-commits] [flang] [OpenMP][MLIR] Add "IsolatedFromAbove" and "OutlineableOpenMPOpInterface" trait to omp.target (PR #67318)
Akash Banerjee via flang-commits
flang-commits at lists.llvm.org
Wed Sep 27 07:20:08 PDT 2023
https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/67318
>From ec8ff64ca76380cc7db0160b9909f6cd6110d592 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Fri, 22 Sep 2023 18:12:06 +0100
Subject: [PATCH 1/2] [OpenMP][MLIR] Add "IsolatedFromAbove" and
"OutlineableOpenMPOpInterface" trait to omp.target
This patch adds the PFT lowering changes required for adding the IsolatedFromAbove and OutlineableOpenMPOpInterface traits to omp.target.
Key Changes:
- Add IsolatedFromAbove and OutlineableOpenMPOpInterface traits to target op in MLIR.
- The main reason for this change is to prevent CSE and other similar optimisations from crossing region boundaries for target operations. The Discourse thread linked below covers the discussion around this issue.
- Move implicit operand capturing to the PFT lowering stage.
- Update related tests.
Related discussion: https://discourse.llvm.org/t/rfc-prevent-cse-from-removing-expressions-inside-some-non-isolatedfromabove-operation-regions/73150
---
flang/lib/Lower/OpenMP.cpp | 127 +++++++++++++++---
.../Fir/convert-to-llvm-openmp-and-fir.fir | 6 +-
flang/test/Lower/OpenMP/FIR/array-bounds.f90 | 119 ++++++----------
flang/test/Lower/OpenMP/FIR/location.f90 | 2 +-
flang/test/Lower/OpenMP/FIR/target.f90 | 36 ++---
flang/test/Lower/OpenMP/location.f90 | 2 +-
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 8 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 12 ++
8 files changed, 195 insertions(+), 117 deletions(-)
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 5f5e968eaaa6414..dafd2ae7d13f9e1 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -538,7 +538,11 @@ class ClauseProcessor {
const llvm::omp::Directive &directive,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands) const;
+ llvm::SmallVectorImpl<mlir::Value> &mapOperands,
+ llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
+ llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ *mapSymbols = nullptr) const;
bool processReduction(
mlir::Location currentLocation,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
@@ -1665,7 +1669,7 @@ static mlir::omp::MapInfoOp
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value baseAddr, std::stringstream &name,
mlir::SmallVector<mlir::Value> bounds, uint64_t mapType,
- mlir::omp::VariableCaptureKind mapCaptureType, bool implicit,
+ mlir::omp::VariableCaptureKind mapCaptureType,
mlir::Type retTy) {
mlir::Value varPtrPtr;
if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
@@ -1676,7 +1680,6 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::omp::MapInfoOp op =
builder.create<mlir::omp::MapInfoOp>(loc, retTy, baseAddr);
op.setNameAttr(builder.getStringAttr(name.str()));
- op.setImplicit(implicit);
op.setMapType(mapType);
op.setMapCaptureType(mapCaptureType);
@@ -1695,7 +1698,11 @@ bool ClauseProcessor::processMap(
mlir::Location currentLocation, const llvm::omp::Directive &directive,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands) const {
+ llvm::SmallVectorImpl<mlir::Value> &mapOperands,
+ llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
+ llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
+ const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
return findRepeatableClause<ClauseTy::Map>(
[&](const ClauseTy::Map *mapClause,
@@ -1755,13 +1762,20 @@ bool ClauseProcessor::processMap(
// Explicit map captures are captured ByRef by default,
// optimisation passes may alter this to ByCopy or other capture
// types to optimise
- mapOperands.push_back(createMapInfoOp(
+ mlir::Value mapOp = createMapInfoOp(
firOpBuilder, clauseLocation, baseAddr, asFortran, bounds,
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapTypeBits),
- mlir::omp::VariableCaptureKind::ByRef, false,
- baseAddr.getType()));
+ mlir::omp::VariableCaptureKind::ByRef, baseAddr.getType());
+
+ mapOperands.push_back(mapOp);
+ if (mapSymTypes)
+ mapSymTypes->push_back(baseAddr.getType());
+ if (mapSymLocs)
+ mapSymLocs->push_back(baseAddr.getLoc());
+ if (mapSymbols)
+ mapSymbols->push_back(getOmpObjectSymbol(ompObject));
}
});
}
@@ -2142,7 +2156,7 @@ static void createBodyOfOp(
}
}
-static void createBodyOfTargetDataOp(
+static void genBodyOfTargetDataOp(
Fortran::lower::AbstractConverter &converter, mlir::omp::DataOp &dataOp,
const llvm::SmallVector<mlir::Type> &useDeviceTypes,
const llvm::SmallVector<mlir::Location> &useDeviceLocs,
@@ -2356,8 +2370,8 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
auto dataOp = converter.getFirOpBuilder().create<mlir::omp::DataOp>(
currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
deviceAddrOperands, mapOperands);
- createBodyOfTargetDataOp(converter, dataOp, useDeviceTypes, useDeviceLocs,
- useDeviceSymbols, currentLocation);
+ genBodyOfTargetDataOp(converter, dataOp, useDeviceTypes, useDeviceLocs,
+ useDeviceSymbols, currentLocation);
return dataOp;
}
@@ -2400,6 +2414,52 @@ genEnterExitDataOp(Fortran::lower::AbstractConverter &converter,
deviceOperand, nowaitAttr, mapOperands);
}
+static void genBodyOfTargetOp(
+ Fortran::lower::AbstractConverter &converter, mlir::omp::TargetOp &targetOp,
+ const llvm::SmallVector<mlir::Type> &mapSymTypes,
+ const llvm::SmallVector<mlir::Location> &mapSymLocs,
+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &mapSymbols,
+ const mlir::Location &currentLocation) {
+ assert(mapSymTypes.size() == mapSymLocs.size() &&
+ mapSymTypes.size() == mapSymbols.size());
+
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ mlir::Region &region = targetOp.getRegion();
+
+ firOpBuilder.createBlock(&region, {}, mapSymTypes, mapSymLocs);
+ firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
+ firOpBuilder.setInsertionPointToStart(&region.front());
+
+ unsigned argIndex = 0;
+ for (const Fortran::semantics::Symbol *sym : mapSymbols) {
+ const mlir::BlockArgument &arg = region.front().getArgument(argIndex);
+ fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
+ mlir::Value val = fir::getBase(arg);
+ extVal.match(
+ [&](const fir::BoxValue &v) {
+ converter.bindSymbol(*sym, fir::BoxValue(val, v.getLBounds(),
+ v.getExplicitParameters(),
+ v.getExplicitExtents()));
+ },
+ [&](const fir::MutableBoxValue &v) {
+ converter.bindSymbol(*sym,
+ fir::MutableBoxValue(val, v.getLBounds(),
+ v.getMutableProperties()));
+ },
+ [&](const fir::ArrayBoxValue &v) {
+ converter.bindSymbol(*sym, fir::ArrayBoxValue(val, v.getExtents(),
+ v.getLBounds(),
+ v.getSourceBox()));
+ },
+ [&](const fir::UnboxedValue &v) { converter.bindSymbol(*sym, val); },
+ [&](const auto &) {
+ TODO(converter.getCurrentLocation(),
+ "target map clause operand unsupported type");
+ });
+ argIndex++;
+ }
+}
+
static mlir::omp::TargetOp
genTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
@@ -2411,6 +2471,9 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
mlir::UnitAttr nowaitAttr;
llvm::SmallVector<mlir::Value> mapOperands;
+ llvm::SmallVector<mlir::Type> mapSymTypes;
+ llvm::SmallVector<mlir::Location> mapSymLocs;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
ClauseProcessor cp(converter, clauseList);
cp.processIf(stmtCtx,
@@ -2420,7 +2483,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
cp.processThreadLimit(stmtCtx, threadLimitOperand);
cp.processNowait(nowaitAttr);
cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
- mapOperands);
+ mapOperands, &mapSymTypes, &mapSymLocs, &mapSymbols);
cp.processTODO<Fortran::parser::OmpClause::Private,
Fortran::parser::OmpClause::Depend,
Fortran::parser::OmpClause::Firstprivate,
@@ -2433,10 +2496,44 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::parser::OmpClause::Defaultmap>(
currentLocation, llvm::omp::Directive::OMPD_target);
- return genOpWithBody<mlir::omp::TargetOp>(
- converter, eval, currentLocation, outerCombined, &clauseList,
- ifClauseOperand, deviceOperand, threadLimitOperand, nowaitAttr,
- mapOperands);
+ auto captureImplicitMap = [&](const Fortran::semantics::Symbol &sym) {
+ if (llvm::find(mapSymbols, &sym) == mapSymbols.end()) {
+ mlir::Value baseOp = converter.getSymbolAddress(sym);
+ if (!baseOp)
+ if (const auto *details = sym.template detailsIf<
+ Fortran::semantics::HostAssocDetails>()) {
+ baseOp = converter.getSymbolAddress(details->symbol());
+ converter.copySymbolBinding(details->symbol(), sym);
+ }
+
+ if (baseOp) {
+ llvm::SmallVector<mlir::Value> bounds;
+ std::stringstream name;
+ name << sym.name().ToString();
+ mlir::Value mapOp = createMapInfoOp(
+ converter.getFirOpBuilder(), baseOp.getLoc(), baseOp, name, bounds,
+ static_cast<
+ std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL |
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
+ mlir::omp::VariableCaptureKind::ByCopy, baseOp.getType());
+ mapOperands.push_back(mapOp);
+ mapSymTypes.push_back(baseOp.getType());
+ mapSymLocs.push_back(baseOp.getLoc());
+ mapSymbols.push_back(&sym);
+ }
+ }
+ };
+ Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);
+
+ auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
+ currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
+ nowaitAttr, mapOperands);
+
+ genBodyOfTargetOp(converter, targetOp, mapSymTypes, mapSymLocs, mapSymbols,
+ currentLocation);
+
+ return targetOp;
}
static mlir::omp::TeamsOp
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index ba1a15bf773d2ef..045230480916b84 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -434,11 +434,12 @@ func.func @_QPomp_target() {
%2 = omp.bounds lower_bound(%c0 : index) upper_bound(%1 : index) extent(%c512 : index) stride(%c1 : index) start_idx(%c1 : index)
%3 = omp.map_info var_ptr(%0 : !fir.ref<!fir.array<512xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%2) -> !fir.ref<!fir.array<512xi32>> {name = "a"}
omp.target thread_limit(%c64_i32 : i32) map_entries(%3 : !fir.ref<!fir.array<512xi32>>) {
+ ^bb0(%arg0: !fir.ref<!fir.array<512xi32>>):
%c10_i32 = arith.constant 10 : i32
%c1_i64 = arith.constant 1 : i64
%c1_i64_0 = arith.constant 1 : i64
%4 = arith.subi %c1_i64, %c1_i64_0 : i64
- %5 = fir.coordinate_of %0, %4 : (!fir.ref<!fir.array<512xi32>>, i64) -> !fir.ref<i32>
+ %5 = fir.coordinate_of %arg0, %4 : (!fir.ref<!fir.array<512xi32>>, i64) -> !fir.ref<i32>
fir.store %c10_i32 to %5 : !fir.ref<i32>
omp.terminator
}
@@ -456,11 +457,12 @@ func.func @_QPomp_target() {
// CHECK: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[LOWER]] : i64) upper_bound(%[[UPPER]] : i64) extent(%[[EXTENT]] : i64) stride(%[[STRIDE]] : i64) start_idx(%[[STRIDE]] : i64)
// CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%2 : !llvm.ptr<array<512 x i32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !llvm.ptr<array<512 x i32>> {name = "a"}
// CHECK: omp.target thread_limit(%[[VAL_2]] : i32) map_entries(%[[MAP]] : !llvm.ptr<array<512 x i32>>) {
+// CHECK: ^bb0(%[[ARG_0:.*]]: !llvm.ptr<array<512 x i32>>):
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(10 : i32) : i32
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_1]][0, %[[VAL_6]]] : (!llvm.ptr<array<512 x i32>>, i64) -> !llvm.ptr<i32>
+// CHECK: %[[VAL_7:.*]] = llvm.getelementptr %[[ARG_0]][0, %[[VAL_6]]] : (!llvm.ptr<array<512 x i32>>, i64) -> !llvm.ptr<i32>
// CHECK: llvm.store %[[VAL_3]], %[[VAL_7]] : !llvm.ptr<i32>
// CHECK: omp.terminator
// CHECK: }
diff --git a/flang/test/Lower/OpenMP/FIR/array-bounds.f90 b/flang/test/Lower/OpenMP/FIR/array-bounds.f90
index 697ff44176b3f10..01102d3b698dbe4 100644
--- a/flang/test/Lower/OpenMP/FIR/array-bounds.f90
+++ b/flang/test/Lower/OpenMP/FIR/array-bounds.f90
@@ -1,37 +1,22 @@
-!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes HOST
-!RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefixes DEVICE
+!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes=HOST,ALL
+!RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefixes=DEVICE,ALL
-!DEVICE-LABEL: func.func @_QPread_write_section_omp_outline_0(
-!DEVICE-SAME: %[[ARG0:.*]]: !fir.ref<i32>, %[[ARG1:.*]]: !fir.ref<!fir.array<10xi32>>, %[[ARG2:.*]]: !fir.ref<!fir.array<10xi32>>) attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>, omp.outline_parent_name = "_QPread_write_section"} {
-!DEVICE: %[[C1:.*]] = arith.constant 1 : index
-!DEVICE: %[[C2:.*]] = arith.constant 4 : index
-!DEVICE: %[[C3:.*]] = arith.constant 1 : index
-!DEVICE: %[[C4:.*]] = arith.constant 1 : index
-!DEVICE: %[[BOUNDS0:.*]] = omp.bounds lower_bound(%[[C1]] : index) upper_bound(%[[C2]] : index) stride(%[[C4]] : index) start_idx(%[[C4]] : index)
-!DEVICE: %[[MAP0:.*]] = omp.map_info var_ptr(%[[ARG1]] : !fir.ref<!fir.array<10xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS0]]) -> !fir.ref<!fir.array<10xi32>> {name = "sp_read(2:5)"}
-!DEVICE: %[[C5:.*]] = arith.constant 1 : index
-!DEVICE: %[[C6:.*]] = arith.constant 4 : index
-!DEVICE: %[[C7:.*]] = arith.constant 1 : index
-!DEVICE: %[[C8:.*]] = arith.constant 1 : index
-!DEVICE: %[[BOUNDS1:.*]] = omp.bounds lower_bound(%[[C5]] : index) upper_bound(%[[C6]] : index) stride(%[[C8]] : index) start_idx(%[[C8]] : index)
-!DEVICE: %[[MAP1:.*]] = omp.map_info var_ptr(%[[ARG2]] : !fir.ref<!fir.array<10xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS1]]) -> !fir.ref<!fir.array<10xi32>> {name = "sp_write(2:5)"}
-!DEVICE: omp.target map_entries(%[[MAP0]], %[[MAP1]] : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>) {
-
-!HOST-LABEL: func.func @_QPread_write_section() {
-!HOST: %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFread_write_sectionEi"}
-!HOST: %[[READ:.*]] = fir.address_of(@_QFread_write_sectionEsp_read) : !fir.ref<!fir.array<10xi32>>
-!HOST: %[[WRITE:.*]] = fir.address_of(@_QFread_write_sectionEsp_write) : !fir.ref<!fir.array<10xi32>>
-!HOST: %[[C1:.*]] = arith.constant 1 : index
-!HOST: %[[C2:.*]] = arith.constant 1 : index
-!HOST: %[[C3:.*]] = arith.constant 4 : index
-!HOST: %[[BOUNDS0:.*]] = omp.bounds lower_bound(%[[C2]] : index) upper_bound(%[[C3]] : index) stride(%[[C1]] : index) start_idx(%[[C1]] : index)
-!HOST: %[[MAP0:.*]] = omp.map_info var_ptr(%[[READ]] : !fir.ref<!fir.array<10xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS0]]) -> !fir.ref<!fir.array<10xi32>> {name = "sp_read(2:5)"}
-!HOST: %[[C4:.*]] = arith.constant 1 : index
-!HOST: %[[C5:.*]] = arith.constant 1 : index
-!HOST: %[[C6:.*]] = arith.constant 4 : index
-!HOST: %[[BOUNDS1:.*]] = omp.bounds lower_bound(%[[C5]] : index) upper_bound(%[[C6]] : index) stride(%[[C4]] : index) start_idx(%[[C4]] : index)
-!HOST: %[[MAP1:.*]] = omp.map_info var_ptr(%[[WRITE]] : !fir.ref<!fir.array<10xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS1]]) -> !fir.ref<!fir.array<10xi32>> {name = "sp_write(2:5)"}
-!HOST: omp.target map_entries(%[[MAP0]], %[[MAP1]] : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>) {
+!ALL-LABEL: func.func @_QPread_write_section(
+!ALL: %[[ITER:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFread_write_sectionEi"}
+!ALL: %[[READ:.*]] = fir.address_of(@_QFread_write_sectionEsp_read) : !fir.ref<!fir.array<10xi32>>
+!ALL: %[[WRITE:.*]] = fir.address_of(@_QFread_write_sectionEsp_write) : !fir.ref<!fir.array<10xi32>>
+!ALL: %[[C1:.*]] = arith.constant 1 : index
+!ALL: %[[C2:.*]] = arith.constant 1 : index
+!ALL: %[[C3:.*]] = arith.constant 4 : index
+!ALL: %[[BOUNDS0:.*]] = omp.bounds lower_bound(%[[C2]] : index) upper_bound(%[[C3]] : index) stride(%[[C1]] : index) start_idx(%[[C1]] : index)
+!ALL: %[[MAP0:.*]] = omp.map_info var_ptr(%[[READ]] : !fir.ref<!fir.array<10xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS0]]) -> !fir.ref<!fir.array<10xi32>> {name = "sp_read(2:5)"}
+!ALL: %[[C4:.*]] = arith.constant 1 : index
+!ALL: %[[C5:.*]] = arith.constant 1 : index
+!ALL: %[[C6:.*]] = arith.constant 4 : index
+!ALL: %[[BOUNDS1:.*]] = omp.bounds lower_bound(%[[C5]] : index) upper_bound(%[[C6]] : index) stride(%[[C4]] : index) start_idx(%[[C4]] : index)
+!ALL: %[[MAP1:.*]] = omp.map_info var_ptr(%[[WRITE]] : !fir.ref<!fir.array<10xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS1]]) -> !fir.ref<!fir.array<10xi32>> {name = "sp_write(2:5)"}
+!ALL: %[[MAP2:.*]] = omp.map_info var_ptr(%[[ITER]] : !fir.ref<i32>) map_clauses(literal, implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
+!ALL: omp.target map_entries(%[[MAP0]], %[[MAP1]], %[[MAP2]] : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>, !fir.ref<i32>) {
subroutine read_write_section()
integer :: sp_read(10) = (/1,2,3,4,5,6,7,8,9,10/)
@@ -44,33 +29,22 @@ subroutine read_write_section()
!$omp end target
end subroutine read_write_section
-
module assumed_array_routines
- contains
-!DEVICE-LABEL: func.func @_QMassumed_array_routinesPassumed_shape_array_omp_outline_0(
-!DEVICE-SAME: %[[ARG0:.*]]: !fir.ref<i32>, %[[ARG1:.*]]: !fir.box<!fir.array<?xi32>>) attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>, omp.outline_parent_name = "_QMassumed_array_routinesPassumed_shape_array"} {
-!DEVICE: %[[C0:.*]] = arith.constant 1 : index
-!DEVICE: %[[C1:.*]] = arith.constant 4 : index
-!DEVICE: %[[C2:.*]] = arith.constant 0 : index
-!DEVICE: %[[C3:.*]]:3 = fir.box_dims %[[ARG1]], %[[C2]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
-!DEVICE: %[[C4:.*]] = arith.constant 1 : index
-!DEVICE: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C0]] : index) upper_bound(%[[C1]] : index) stride(%[[C3]]#2 : index) start_idx(%[[C4]] : index) {stride_in_bytes = true}
-!DEVICE: %[[ARGADDR:.*]] = fir.box_addr %[[ARG1]] : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
-!DEVICE: %[[MAP:.*]] = omp.map_info var_ptr(%[[ARGADDR]] : !fir.ref<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
-!DEVICE: omp.target map_entries(%[[MAP]] : !fir.ref<!fir.array<?xi32>>) {
+contains
+!ALL-LABEL: func.func @_QMassumed_array_routinesPassumed_shape_array(
+!ALL-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
+!ALL: %[[ALLOCA:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMassumed_array_routinesFassumed_shape_arrayEi"}
+!ALL: %[[C0:.*]] = arith.constant 1 : index
+!ALL: %[[C1:.*]] = arith.constant 0 : index
+!ALL: %[[C2:.*]]:3 = fir.box_dims %arg0, %[[C1]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+!ALL: %[[C3:.*]] = arith.constant 1 : index
+!ALL: %[[C4:.*]] = arith.constant 4 : index
+!ALL: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C3]] : index) upper_bound(%[[C4]] : index) stride(%[[C2]]#2 : index) start_idx(%[[C0]] : index) {stride_in_bytes = true}
+!ALL: %[[ADDROF:.*]] = fir.box_addr %arg0 : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
+!ALL: %[[MAP:.*]] = omp.map_info var_ptr(%[[ADDROF]] : !fir.ref<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
+!ALL: %[[MAP2:.*]] = omp.map_info var_ptr(%[[ALLOCA]] : !fir.ref<i32>) map_clauses(literal, implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
+!ALL: omp.target map_entries(%[[MAP]], %[[MAP2]] : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>) {
-!HOST-LABEL: func.func @_QMassumed_array_routinesPassumed_shape_array(
-!HOST-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
-!HOST: %[[ALLOCA:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMassumed_array_routinesFassumed_shape_arrayEi"}
-!HOST: %[[C0:.*]] = arith.constant 1 : index
-!HOST: %[[C1:.*]] = arith.constant 0 : index
-!HOST: %[[C2:.*]]:3 = fir.box_dims %arg0, %[[C1]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
-!HOST: %[[C3:.*]] = arith.constant 1 : index
-!HOST: %[[C4:.*]] = arith.constant 4 : index
-!HOST: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C3]] : index) upper_bound(%[[C4]] : index) stride(%[[C2]]#2 : index) start_idx(%[[C0]] : index) {stride_in_bytes = true}
-!HOST: %[[ADDROF:.*]] = fir.box_addr %arg0 : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
-!HOST: %[[MAP:.*]] = omp.map_info var_ptr(%[[ADDROF]] : !fir.ref<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
-!HOST: omp.target map_entries(%[[MAP]] : !fir.ref<!fir.array<?xi32>>) {
subroutine assumed_shape_array(arr_read_write)
integer, intent(inout) :: arr_read_write(:)
@@ -81,25 +55,17 @@ subroutine assumed_shape_array(arr_read_write)
!$omp end target
end subroutine assumed_shape_array
-!DEVICE-LABEL: func.func @_QMassumed_array_routinesPassumed_size_array_omp_outline_0(
-!DEVICE-SAME: %[[ARG0:.*]]: !fir.ref<i32>, %[[ARG1:.*]]: !fir.ref<!fir.array<?xi32>>) attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>, omp.outline_parent_name = "_QMassumed_array_routinesPassumed_size_array"} {
-!DEVICE: %[[C0:.*]] = arith.constant 1 : index
-!DEVICE: %[[C1:.*]] = arith.constant 4 : index
-!DEVICE: %[[C2:.*]] = arith.constant 1 : index
-!DEVICE: %[[C3:.*]] = arith.constant 1 : index
-!DEVICE: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C0]] : index) upper_bound(%[[C1]] : index) stride(%[[C3]] : index) start_idx(%[[C3]] : index)
-!DEVICE: %[[MAP:.*]] = omp.map_info var_ptr(%[[ARG1]] : !fir.ref<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
-!DEVICE: omp.target map_entries(%[[MAP]] : !fir.ref<!fir.array<?xi32>>) {
+!ALL-LABEL: func.func @_QMassumed_array_routinesPassumed_size_array(
+!ALL-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
+!ALL: %[[ALLOCA:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMassumed_array_routinesFassumed_size_arrayEi"}
+!ALL: %[[C0:.*]] = arith.constant 1 : index
+!ALL: %[[C1:.*]] = arith.constant 1 : index
+!ALL: %[[C2:.*]] = arith.constant 4 : index
+!ALL: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C1]] : index) upper_bound(%[[C2]] : index) stride(%[[C0]] : index) start_idx(%[[C0]] : index)
+!ALL: %[[MAP:.*]] = omp.map_info var_ptr(%[[ARG0]] : !fir.ref<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
+!ALL: %[[MAP2:.*]] = omp.map_info var_ptr(%[[ALLOCA]] : !fir.ref<i32>) map_clauses(literal, implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
+!ALL: omp.target map_entries(%[[MAP]], %[[MAP2]] : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>) {
-!HOST-LABEL: func.func @_QMassumed_array_routinesPassumed_size_array(
-!HOST-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
-!HOST: %[[ALLOCA:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMassumed_array_routinesFassumed_size_arrayEi"}
-!HOST: %[[C0:.*]] = arith.constant 1 : index
-!HOST: %[[C1:.*]] = arith.constant 1 : index
-!HOST: %[[C2:.*]] = arith.constant 4 : index
-!HOST: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C1]] : index) upper_bound(%[[C2]] : index) stride(%[[C0]] : index) start_idx(%[[C0]] : index)
-!HOST: %[[MAP:.*]] = omp.map_info var_ptr(%[[ARG0]] : !fir.ref<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
-!HOST: omp.target map_entries(%[[MAP]] : !fir.ref<!fir.array<?xi32>>) {
subroutine assumed_size_array(arr_read_write)
integer, intent(inout) :: arr_read_write(*)
@@ -111,6 +77,7 @@ subroutine assumed_size_array(arr_read_write)
end subroutine assumed_size_array
end module assumed_array_routines
+!DEVICE-NOT:func.func @_QPcall_assumed_shape_and_size_array() {
!HOST-LABEL:func.func @_QPcall_assumed_shape_and_size_array() {
!HOST:%{{.*}} = arith.constant 20 : index
diff --git a/flang/test/Lower/OpenMP/FIR/location.f90 b/flang/test/Lower/OpenMP/FIR/location.f90
index 0e36e09b19e1942..84bbd1605179262 100644
--- a/flang/test/Lower/OpenMP/FIR/location.f90
+++ b/flang/test/Lower/OpenMP/FIR/location.f90
@@ -17,7 +17,7 @@ subroutine sub_parallel()
!CHECK-LABEL: sub_target
subroutine sub_target()
print *, x
-!CHECK: omp.target {
+!CHECK: omp.target
!$omp target
print *, x
!CHECK: omp.terminator loc(#[[TAR_LOC:.*]])
diff --git a/flang/test/Lower/OpenMP/FIR/target.f90 b/flang/test/Lower/OpenMP/FIR/target.f90
index 9b1fb5c15ac1d2d..749ddff523500c5 100644
--- a/flang/test/Lower/OpenMP/FIR/target.f90
+++ b/flang/test/Lower/OpenMP/FIR/target.f90
@@ -190,12 +190,13 @@ subroutine omp_target
!CHECK: %[[BOUNDS:.*]] = omp.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[VAL_0]] : !fir.ref<!fir.array<1024xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) {
+ !CHECK: ^bb0(%[[ARG_0:.*]]: !fir.ref<!fir.array<1024xi32>>):
!$omp target map(tofrom: a)
!CHECK: %[[VAL_1:.*]] = arith.constant 10 : i32
!CHECK: %[[VAL_2:.*]] = arith.constant 1 : i64
!CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64
!CHECK: %[[VAL_4:.*]] = arith.subi %[[VAL_2]], %[[VAL_3]] : i64
- !CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.array<1024xi32>>, i64) -> !fir.ref<i32>
+ !CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG_0]], %[[VAL_4]] : (!fir.ref<!fir.array<1024xi32>>, i64) -> !fir.ref<i32>
!CHECK: fir.store %[[VAL_1]] to %[[VAL_5]] : !fir.ref<i32>
a(1) = 10
!CHECK: omp.terminator
@@ -213,6 +214,7 @@ subroutine omp_target_thread_limit
!CHECK: %[[VAL_1:.*]] = arith.constant 64 : i32
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}}) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "a"}
!CHECK: omp.target thread_limit(%[[VAL_1]] : i32) map_entries(%[[MAP]] : !fir.ref<i32>) {
+ !CHECK: ^bb0(%{{.*}}: !fir.ref<i32>):
!$omp target map(tofrom: a) thread_limit(64)
a = 10
!CHECK: omp.terminator
@@ -274,23 +276,25 @@ subroutine omp_target_parallel_do
!CHECK: %[[C0:.*]] = arith.constant 0 : index
!CHECK: %[[SUB:.*]] = arith.subi %[[C1024]], %[[C1]] : index
!CHECK: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C0]] : index) upper_bound(%[[SUB]] : index) extent(%[[C1024]] : index) stride(%[[C1]] : index) start_idx(%[[C1]] : index)
- !CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[VAL_0]] : !fir.ref<!fir.array<1024xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
- !CHECK: omp.target map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) {
+ !CHECK: %[[MAP1:.*]] = omp.map_info var_ptr(%[[VAL_0]] : !fir.ref<!fir.array<1024xi32>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+ !CHECK: %[[MAP2:.*]] = omp.map_info var_ptr(%[[VAL_1]] : !fir.ref<i32>) map_clauses(literal, implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
+ !CHECK: omp.target map_entries(%[[MAP1]], %[[MAP2]] : !fir.ref<!fir.array<1024xi32>>, !fir.ref<i32>) {
+ !CHECK: ^bb0(%[[VAL_2:.*]]: !fir.ref<!fir.array<1024xi32>>, %[[VAL_3:.*]]: !fir.ref<i32>):
!CHECK-NEXT: omp.parallel
!$omp target parallel do map(tofrom: a)
- !CHECK: %[[VAL_2:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
- !CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
- !CHECK: %[[VAL_4:.*]] = arith.constant 1024 : i32
+ !CHECK: %[[VAL_4:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
!CHECK: %[[VAL_5:.*]] = arith.constant 1 : i32
- !CHECK: omp.wsloop for (%[[VAL_6:.*]]) : i32 = (%[[VAL_3]]) to (%[[VAL_4]]) inclusive step (%[[VAL_5]]) {
- !CHECK: fir.store %[[VAL_6]] to %[[VAL_2]] : !fir.ref<i32>
- !CHECK: %[[VAL_7:.*]] = arith.constant 10 : i32
- !CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_2]] : !fir.ref<i32>
- !CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (i32) -> i64
- !CHECK: %[[VAL_10:.*]] = arith.constant 1 : i64
- !CHECK: %[[VAL_11:.*]] = arith.subi %[[VAL_9]], %[[VAL_10]] : i64
- !CHECK: %[[VAL_12:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_11]] : (!fir.ref<!fir.array<1024xi32>>, i64) -> !fir.ref<i32>
- !CHECK: fir.store %[[VAL_7]] to %[[VAL_12]] : !fir.ref<i32>
+ !CHECK: %[[VAL_6:.*]] = arith.constant 1024 : i32
+ !CHECK: %[[VAL_7:.*]] = arith.constant 1 : i32
+ !CHECK: omp.wsloop for (%[[VAL_8:.*]]) : i32 = (%[[VAL_5]]) to (%[[VAL_6]]) inclusive step (%[[VAL_7]]) {
+ !CHECK: fir.store %[[VAL_8]] to %[[VAL_4]] : !fir.ref<i32>
+ !CHECK: %[[VAL_9:.*]] = arith.constant 10 : i32
+ !CHECK: %[[VAL_10:.*]] = fir.load %[[VAL_4]] : !fir.ref<i32>
+ !CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (i32) -> i64
+ !CHECK: %[[VAL_12:.*]] = arith.constant 1 : i64
+ !CHECK: %[[VAL_13:.*]] = arith.subi %[[VAL_11]], %[[VAL_12]] : i64
+ !CHECK: %[[VAL_14:.*]] = fir.coordinate_of %[[VAL_2]], %[[VAL_13]] : (!fir.ref<!fir.array<1024xi32>>, i64) -> !fir.ref<i32>
+ !CHECK: fir.store %[[VAL_9]] to %[[VAL_14]] : !fir.ref<i32>
do i = 1, 1024
a(i) = 10
end do
@@ -301,4 +305,4 @@ subroutine omp_target_parallel_do
!CHECK: omp.terminator
!CHECK: }
!$omp end target parallel do
-end subroutine omp_target_parallel_do
+ end subroutine omp_target_parallel_do
diff --git a/flang/test/Lower/OpenMP/location.f90 b/flang/test/Lower/OpenMP/location.f90
index c87bf038e967215..953b7362f5586a9 100644
--- a/flang/test/Lower/OpenMP/location.f90
+++ b/flang/test/Lower/OpenMP/location.f90
@@ -17,7 +17,7 @@ subroutine sub_parallel()
!CHECK-LABEL: sub_target
subroutine sub_target()
print *, x
-!CHECK: omp.target {
+!CHECK: omp.target
!$omp target
print *, x
!CHECK: omp.terminator loc(#[[TAR_LOC:.*]])
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index bcd60d8046c8925..248b92a08649c88 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1149,7 +1149,6 @@ def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> {
Variadic<DataBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
OptionalAttr<UI64Attr>:$map_type,
OptionalAttr<VariableCaptureKindAttr>:$map_capture_type,
- DefaultValuedAttr<BoolAttr, "false">:$implicit,
OptionalAttr<StrAttr>:$name);
let results = (outs OpenMP_PointerLikeType:$omp_ptr);
@@ -1177,7 +1176,7 @@ def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> {
```
=>
```mlir
- omp.map_info var_ptr(%index_ssa) map_type(to) map_capture_type(ByRef) implicit(false)
+ omp.map_info var_ptr(%index_ssa) map_type(to) map_capture_type(ByRef)
name(index)
```
@@ -1189,9 +1188,6 @@ def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> {
objects (e.g. derived types or classes), indicates the bounds to be copied
of the variable. When it's an array slice it is in rank order where rank 0
is the inner-most dimension.
- - `implicit`: indicates where the map item has been specified explicitly in a
- map clause or captured implicitly by being used in a target region with no
- map or other data mapping construct.
- 'map_clauses': OpenMP map type for this map capture, for example: from, to and
always. It's a bitfield composed of the OpenMP runtime flags stored in
OpenMPOffloadMappingFlags.
@@ -1375,7 +1371,7 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
// 2.14.5 target construct
//===----------------------------------------------------------------------===//
-def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> {
+def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterface, AttrSizedOperandSegments]> {
let summary = "target construct";
let description = [{
The target construct includes a region of code which is to be executed
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 2bf9355ed62676b..bb4fc191432957c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -693,6 +693,12 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
if (mapTypeMod == "always")
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
+ if (mapTypeMod == "literal")
+ mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
+
+ if (mapTypeMod == "implicit")
+ mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+
if (mapTypeMod == "close")
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
@@ -740,6 +746,12 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
if (mapTypeToBitFlag(mapTypeBits,
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
mapTypeStrs.push_back("always");
+ if (mapTypeToBitFlag(mapTypeBits,
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL))
+ mapTypeStrs.push_back("literal");
+ if (mapTypeToBitFlag(mapTypeBits,
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
+ mapTypeStrs.push_back("implicit");
if (mapTypeToBitFlag(mapTypeBits,
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
mapTypeStrs.push_back("close");
>From b0772b76e77263bcee2d2ee75c323239e0799299 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Fri, 22 Sep 2023 18:13:46 +0100
Subject: [PATCH 2/2] [OpenMP][MLIR] Add "IsolatedFromAbove" and
"OutlineableOpenMPOpInterface" trait to omp.target
This patch adds the MLIR translation changes required for adding the IsolatedFromAbove and OutlineableOpenMPOpInterface traits to omp.target. It links the newly added block arguments to their corresponding LLVM values.
---
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 27 ++++++++++++++-----
.../OpenMPToLLVM/convert-to-llvmir.mlir | 14 ++++++----
mlir/test/Dialect/OpenMP/canonicalize.mlir | 5 ++--
.../omptarget-declare-target-llvm-device.mlir | 11 ++++----
.../LLVMIR/omptarget-region-device-llvm.mlir | 7 ++---
.../omptarget-region-llvm-target-device.mlir | 7 ++---
.../Target/LLVMIR/omptarget-region-llvm.mlir | 7 ++---
.../omptarget-region-parallel-llvm.mlir | 7 ++---
8 files changed, 54 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8f7f1963b3e5a4f..f796ace1ab29b9e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2011,6 +2011,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
auto targetOp = cast<omp::TargetOp>(opInst);
auto &targetRegion = targetOp.getRegion();
+ DataLayout DL = DataLayout(opInst.getParentOfType<ModuleOp>());
+ SmallVector<Value> mapOperands = targetOp.getMapOperands();
// This function filters out kernel data that will not show up as kernel
// input arguments to the generated kernel function but will still need
@@ -2018,7 +2020,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
// (declare target). It also prepares some data used for generating the
// kernel and populating the associated OpenMP runtime data structures.
auto getKernelArguments =
- [&](const llvm::SetVector<Value> &operandSet,
+ [&](const llvm::SmallVectorImpl<Value> &operandSet,
llvm::SmallVectorImpl<llvm::Value *> &llvmInputs) {
for (Value operand : operandSet) {
if (!getRefPtrIfDeclareTarget(operand, moduleTranslation))
@@ -2026,11 +2028,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
}
};
- llvm::SetVector<Value> operandSet;
- getUsedValuesDefinedAbove(targetRegion, operandSet);
+ llvm::SmallVector<Value> mapVals;
+ for (Value mapOp : mapOperands) {
+ auto mapInfoOp =
+ mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp());
+ mapVals.push_back(mapInfoOp.getVarPtr());
+ }
llvm::SmallVector<llvm::Value *> inputs;
- getKernelArguments(operandSet, inputs);
+ getKernelArguments(mapVals, inputs);
LogicalResult bodyGenStatus = success();
@@ -2038,6 +2044,16 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
auto bodyCB = [&](InsertPointTy allocaIP,
InsertPointTy codeGenIP) -> InsertPointTy {
builder.restoreIP(codeGenIP);
+ unsigned argIndex = 0;
+ for (auto &mapOp : mapOperands) {
+ auto mapInfoOp =
+ mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp());
+ llvm::Value *mapOpValue =
+ moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
+ const auto &arg = targetRegion.front().getArgument(argIndex);
+ moduleTranslation.mapValue(arg, mapOpValue);
+ argIndex++;
+ }
llvm::BasicBlock *exitBlock = convertOmpOpRegions(
targetRegion, "omp.target", builder, moduleTranslation, bodyGenStatus);
builder.SetInsertPoint(exitBlock);
@@ -2065,9 +2081,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
- DataLayout DL = DataLayout(opInst.getParentOfType<ModuleOp>());
- SmallVector<Value> mapOperands = targetOp.getMapOperands();
-
auto getMapTypes = [](mlir::OperandRange mapOperands,
mlir::MLIRContext *ctx) {
SmallVector<mlir::Attribute> mapTypes;
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index 1df27dd9957e594..b290c69a14cfba2 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -244,10 +244,12 @@ llvm.func @_QPomp_target_data_region(%a : !llvm.ptr<array<1024 x i32>>, %i : !ll
// CHECK: %[[ARG_0:.*]]: !llvm.ptr<array<1024 x i32>>,
// CHECK: %[[ARG_1:.*]]: !llvm.ptr<i32>) {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(64 : i32) : i32
-// CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[ARG_0]] : !llvm.ptr<array<1024 x i32>>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<array<1024 x i32>> {name = ""}
-// CHECK: omp.target thread_limit(%[[VAL_0]] : i32) map_entries(%[[MAP]] : !llvm.ptr<array<1024 x i32>>) {
+// CHECK: %[[MAP1:.*]] = omp.map_info var_ptr(%[[ARG_0]] : !llvm.ptr<array<1024 x i32>>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<array<1024 x i32>> {name = ""}
+// CHECK: %[[MAP2:.*]] = omp.map_info var_ptr(%[[ARG_1]] : !llvm.ptr<i32>) map_clauses(literal, implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr<i32> {name = ""}
+// CHECK: omp.target thread_limit(%[[VAL_0]] : i32) map_entries(%[[MAP1]], %[[MAP2]] : !llvm.ptr<array<1024 x i32>>, !llvm.ptr<i32>) {
+// CHECK: ^bb0(%[[BB_ARG0:.*]]: !llvm.ptr<array<1024 x i32>>, %[[BB_ARG1:.*]]: !llvm.ptr<i32>):
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(10 : i32) : i32
-// CHECK: llvm.store %[[VAL_1]], %[[ARG_1]] : !llvm.ptr<i32>
+// CHECK: llvm.store %[[VAL_1]], %[[BB_ARG1]] : !llvm.ptr<i32>
// CHECK: omp.terminator
// CHECK: }
// CHECK: llvm.return
@@ -256,9 +258,11 @@ llvm.func @_QPomp_target_data_region(%a : !llvm.ptr<array<1024 x i32>>, %i : !ll
llvm.func @_QPomp_target(%a : !llvm.ptr<array<1024 x i32>>, %i : !llvm.ptr<i32>) {
%0 = llvm.mlir.constant(64 : i32) : i32
%1 = omp.map_info var_ptr(%a : !llvm.ptr<array<1024 x i32>>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<array<1024 x i32>> {name = ""}
- omp.target thread_limit(%0 : i32) map_entries(%1 : !llvm.ptr<array<1024 x i32>>) {
+ %3 = omp.map_info var_ptr(%i : !llvm.ptr<i32>) map_clauses(literal, implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr<i32> {name = ""}
+ omp.target thread_limit(%0 : i32) map_entries(%1, %3 : !llvm.ptr<array<1024 x i32>>, !llvm.ptr<i32>) {
+ ^bb0(%arg0: !llvm.ptr<array<1024 x i32>>, %arg1: !llvm.ptr<i32>):
%2 = llvm.mlir.constant(10 : i32) : i32
- llvm.store %2, %i : !llvm.ptr<i32>
+ llvm.store %2, %arg1 : !llvm.ptr<i32>
omp.terminator
}
llvm.return
diff --git a/mlir/test/Dialect/OpenMP/canonicalize.mlir b/mlir/test/Dialect/OpenMP/canonicalize.mlir
index 68f5bacb1def178..4ecbb027ba47f24 100644
--- a/mlir/test/Dialect/OpenMP/canonicalize.mlir
+++ b/mlir/test/Dialect/OpenMP/canonicalize.mlir
@@ -131,8 +131,9 @@ func.func private @foo() -> ()
func.func @constant_hoisting_target(%x : !llvm.ptr<i32>) {
omp.target {
+ ^bb0(%arg0: !llvm.ptr<i32>):
%c1 = arith.constant 10 : i32
- llvm.store %c1, %x : i32, !llvm.ptr<i32>
+ llvm.store %c1, %arg0 : i32, !llvm.ptr<i32>
omp.terminator
}
return
@@ -141,4 +142,4 @@ func.func @constant_hoisting_target(%x : !llvm.ptr<i32>) {
// CHECK-LABEL: func.func @constant_hoisting_target
// CHECK-NOT: arith.constant
// CHECK: omp.target
-// CHECK-NEXT: arith.constant
+// CHECK: arith.constant
diff --git a/mlir/test/Target/LLVMIR/omptarget-declare-target-llvm-device.mlir b/mlir/test/Target/LLVMIR/omptarget-declare-target-llvm-device.mlir
index f279772bd6c0969..4bf5af8a0cc0d6a 100644
--- a/mlir/test/Target/LLVMIR/omptarget-declare-target-llvm-device.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-declare-target-llvm-device.mlir
@@ -1,10 +1,10 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
// This test the generation of additional load operations for declare target link variables
-// inside of target op regions when lowering to IR for device. Unfortunately as the host file is not
+// inside of target op regions when lowering to IR for device. Unfortunately as the host file is not
// passed as a module attribute, we miss out on the metadata and entryinfo.
//
-// Unfortunately, only so much can be tested as the device side is dependent on a *.bc
+// Unfortunately, only so much can be tested as the device side is dependent on a *.bc
// file created by the host and appended as an attribute to the module.
module attributes {omp.is_target_device = true} {
@@ -13,18 +13,19 @@ module attributes {omp.is_target_device = true} {
%0 = llvm.mlir.constant(0 : i32) : i32
llvm.return %0 : i32
}
-
+
llvm.func @_QQmain() attributes {} {
%0 = llvm.mlir.addressof @_QMtest_0Esp : !llvm.ptr<i32>
-
+
// CHECK-DAG: omp.target: ; preds = %user_code.entry
// CHECK-DAG: %1 = load ptr, ptr @_QMtest_0Esp_decl_tgt_ref_ptr, align 8
// CHECK-DAG: store i32 1, ptr %1, align 4
// CHECK-DAG: br label %omp.region.cont
%map = omp.map_info var_ptr(%0 : !llvm.ptr<i32>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
omp.target map_entries(%map : !llvm.ptr<i32>) {
+ ^bb0(%arg0: !llvm.ptr<i32>):
%1 = llvm.mlir.constant(1 : i32) : i32
- llvm.store %1, %0 : !llvm.ptr<i32>
+ llvm.store %1, %arg0 : !llvm.ptr<i32>
omp.terminator
}
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
index cf70469e7484f64..dbc5f2d3475d01c 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
@@ -16,10 +16,11 @@ module attributes {omp.is_target_device = true} {
%map2 = omp.map_info var_ptr(%5 : !llvm.ptr<i32>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
%map3 = omp.map_info var_ptr(%7 : !llvm.ptr<i32>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
omp.target map_entries(%map1, %map2, %map3 : !llvm.ptr<i32>, !llvm.ptr<i32>, !llvm.ptr<i32>) {
- %8 = llvm.load %3 : !llvm.ptr<i32>
- %9 = llvm.load %5 : !llvm.ptr<i32>
+ ^bb0(%arg0: !llvm.ptr<i32>, %arg1: !llvm.ptr<i32>, %arg2: !llvm.ptr<i32>):
+ %8 = llvm.load %arg0 : !llvm.ptr<i32>
+ %9 = llvm.load %arg1 : !llvm.ptr<i32>
%10 = llvm.add %8, %9 : i32
- llvm.store %10, %7 : !llvm.ptr<i32>
+ llvm.store %10, %arg2 : !llvm.ptr<i32>
omp.terminator
}
llvm.return
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir b/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir
index 01a3bd556294f3e..bea3e481cae3496 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir
@@ -3,10 +3,11 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
module attributes {omp.is_target_device = true} {
- llvm.func @writeindex_omp_outline_0_(%arg0: !llvm.ptr<i32>, %arg1: !llvm.ptr<i32>) attributes {omp.outline_parent_name = "writeindex_"} {
- %0 = omp.map_info var_ptr(%arg0 : !llvm.ptr<i32>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
- %1 = omp.map_info var_ptr(%arg1 : !llvm.ptr<i32>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
+ llvm.func @writeindex_omp_outline_0_(%val0: !llvm.ptr<i32>, %val1: !llvm.ptr<i32>) attributes {omp.outline_parent_name = "writeindex_"} {
+ %0 = omp.map_info var_ptr(%val0 : !llvm.ptr<i32>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
+ %1 = omp.map_info var_ptr(%val1 : !llvm.ptr<i32>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
omp.target map_entries(%0, %1 : !llvm.ptr<i32>, !llvm.ptr<i32>) {
+ ^bb0(%arg0: !llvm.ptr<i32>, %arg1: !llvm.ptr<i32>):
%2 = llvm.mlir.constant(20 : i32) : i32
%3 = llvm.mlir.constant(10 : i32) : i32
llvm.store %3, %arg0 : !llvm.ptr<i32>
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir
index 51506386d83782d..8d11a684f7139c2 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir
@@ -16,10 +16,11 @@ module attributes {omp.is_target_device = false} {
%map2 = omp.map_info var_ptr(%5 : !llvm.ptr<i32>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
%map3 = omp.map_info var_ptr(%7 : !llvm.ptr<i32>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
omp.target map_entries(%map1, %map2, %map3 : !llvm.ptr<i32>, !llvm.ptr<i32>, !llvm.ptr<i32>) {
- %8 = llvm.load %3 : !llvm.ptr<i32>
- %9 = llvm.load %5 : !llvm.ptr<i32>
+ ^bb0(%arg0: !llvm.ptr<i32>, %arg1: !llvm.ptr<i32>, %arg2: !llvm.ptr<i32>):
+ %8 = llvm.load %arg0 : !llvm.ptr<i32>
+ %9 = llvm.load %arg1 : !llvm.ptr<i32>
%10 = llvm.add %8, %9 : i32
- llvm.store %10, %7 : !llvm.ptr<i32>
+ llvm.store %10, %arg2 : !llvm.ptr<i32>
omp.terminator
}
llvm.return
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir
index f0bd37ca36e93b4..a0c57ddc8148aed 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir
@@ -16,11 +16,12 @@ module attributes {omp.is_target_device = false} {
%map2 = omp.map_info var_ptr(%5 : !llvm.ptr<i32>) map_clauses(tofrommap_info) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
%map3 = omp.map_info var_ptr(%7 : !llvm.ptr<i32>) map_clauses(tofrommap_info) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
omp.target map_entries( %map1, %map2, %map3 : !llvm.ptr<i32>, !llvm.ptr<i32>, !llvm.ptr<i32>) {
+ ^bb0(%arg0: !llvm.ptr<i32>, %arg1: !llvm.ptr<i32>, %arg2: !llvm.ptr<i32>):
omp.parallel {
- %8 = llvm.load %3 : !llvm.ptr<i32>
- %9 = llvm.load %5 : !llvm.ptr<i32>
+ %8 = llvm.load %arg0 : !llvm.ptr<i32>
+ %9 = llvm.load %arg1 : !llvm.ptr<i32>
%10 = llvm.add %8, %9 : i32
- llvm.store %10, %7 : !llvm.ptr<i32>
+ llvm.store %10, %arg2 : !llvm.ptr<i32>
omp.terminator
}
omp.terminator
More information about the flang-commits
mailing list