[flang-commits] [flang] [mlir] [flang][OpenMP] - Add `MapInfoOp` instances for target private variables when needed (PR #109862)
via flang-commits
flang-commits at lists.llvm.org
Tue Sep 24 13:44:44 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp
Author: Pranav Bhandarkar (bhandarkar-pranav)
<details>
<summary>Changes</summary>
This PR adds an OpenMP dialect related pass for FIR/HLFIR which creates `MapInfoOp` instances for certain privatized symbols. For example, if an allocatable variable is used in a private clause attached to an `omp.target` op, then the allocatable variable's descriptor will be needed on the device (e.g. GPU). This descriptor needs to be separately mapped onto the device. This pass creates the necessary `omp.map.info` ops for this.
---
Full diff: https://github.com/llvm/llvm-project/pull/109862.diff
6 Files Affected:
- (modified) flang/include/flang/Optimizer/OpenMP/Passes.td (+13)
- (modified) flang/include/flang/Tools/CLOptions.inc (+1)
- (modified) flang/lib/Optimizer/OpenMP/CMakeLists.txt (+1)
- (added) flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp (+131)
- (added) flang/test/Transforms/omp-maps-for-privatized-symbols.fir (+84)
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+1-1)
``````````diff
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 1c0ce08f5b4838..c070bc22ff20cc 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -22,6 +22,19 @@ def MapInfoFinalizationPass
let dependentDialects = ["mlir::omp::OpenMPDialect"];
}
+def MapsForPrivatizedSymbolsPass
+ : Pass<"omp-maps-for-privatized-symbols", "mlir::func::FuncOp"> {
+ let summary = "Creates MapInfoOp instances for privatized symbols when needed";
+ let description = [{
+ Adds omp.map.info operations for privatized symbols on omp.target ops.
+ In certain situations, such as when an allocatable is privatized, its
+ descriptor is needed in the alloc region of the privatizer. This results
+ in the use of the descriptor inside the target region. As such, the
+ descriptor then needs to be mapped. This pass adds such MapInfoOp operations.
+ }];
+ let dependentDialects = ["mlir::omp::OpenMPDialect"];
+}
+
def MarkDeclareTargetPass
: Pass<"omp-mark-declare-target", "mlir::ModuleOp"> {
let summary = "Marks all functions called by an OpenMP declare target function as declare target";
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 04b7f0ba370b86..4b21dd6917d39f 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -368,6 +368,7 @@ inline void createHLFIRToFIRPassPipeline(
inline void createOpenMPFIRPassPipeline(
mlir::PassManager &pm, bool isTargetDevice) {
pm.addPass(flangomp::createMapInfoFinalizationPass());
+ pm.addPass(flangomp::createMapsForPrivatizedSymbolsPass());
pm.addPass(flangomp::createMarkDeclareTargetPass());
if (isTargetDevice)
pm.addPass(flangomp::createFunctionFilteringPass());
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 92051634f0378b..035d0d5ca46c76 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
add_flang_library(FlangOpenMPTransforms
FunctionFiltering.cpp
+ MapsForPrivatizedSymbols.cpp
MapInfoFinalization.cpp
MarkDeclareTarget.cpp
diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
new file mode 100644
index 00000000000000..4943181e2d3acd
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -0,0 +1,131 @@
+//===- MapsForPrivatizedSymbols.cpp
+//-----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+/// \file
+/// An OpenMP dialect related pass for FIR/HLFIR which creates MapInfoOp
+/// instances for certain privatized symbols.
+/// For example, if an allocatable variable is used in a private clause attached
+/// to an omp.target op, then the allocatable variable's descriptor will be
+/// needed on the device (e.g. GPU). This descriptor needs to be separately
+/// mapped onto the device. This pass creates the necessary omp.map.info ops for
+/// this.
+//===----------------------------------------------------------------------===//
+// TODO:
+// 1. Before adding omp.map.info, check whether we already have an
+// omp.map.info for the variable in question.
+// 2. Generalize this for more than just omp.target ops.
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Support/Debug.h"
+#include <type_traits>
+
+#define DEBUG_TYPE "omp-maps-for-privatized-symbols"
+
+namespace flangomp {
+#define GEN_PASS_DEF_MAPSFORPRIVATIZEDSYMBOLSPASS
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+using namespace mlir;
+namespace {
+class MapsForPrivatizedSymbolsPass
+ : public flangomp::impl::MapsForPrivatizedSymbolsPassBase<
+ MapsForPrivatizedSymbolsPass> {
+
+ bool privatizerNeedsMap(omp::PrivateClauseOp &privatizer) {
+ Region &allocRegion = privatizer.getAllocRegion();
+ Value blockArg0 = allocRegion.getArgument(0);
+ if (blockArg0.use_empty())
+ return false;
+ return true;
+ }
+ omp::MapInfoOp createMapInfo(Location loc, Value var, OpBuilder &builder) {
+ uint64_t mapTypeTo = static_cast<
+ std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+ Operation *definingOp = var.getDefiningOp();
+ auto declOp = llvm::dyn_cast_or_null<hlfir::DeclareOp>(definingOp);
+ assert(declOp &&
+ "Expected defining Op of privatized var to be hlfir.declare");
+ Value varPtr = declOp.getOriginalBase();
+
+ return builder.create<omp::MapInfoOp>(
+ loc, varPtr.getType(), varPtr,
+ TypeAttr::get(llvm::cast<omp::PointerLikeType>(varPtr.getType())
+ .getElementType()),
+ /*varPtrPtr=*/Value{},
+ /*members=*/SmallVector<Value>{},
+ /*member_index=*/DenseIntElementsAttr{},
+ /*bounds=*/ValueRange{},
+ builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
+ mapTypeTo),
+ builder.getAttr<omp::VariableCaptureKindAttr>(
+ omp::VariableCaptureKind::ByRef),
+ StringAttr(), builder.getBoolAttr(false));
+ }
+ void addMapInfoOp(omp::TargetOp targetOp, omp::MapInfoOp mapInfoOp) {
+ Location loc = targetOp.getLoc();
+ targetOp.getMapVarsMutable().append(ValueRange{mapInfoOp});
+ size_t numMapVars = targetOp.getMapVars().size();
+ targetOp.getRegion().insertArgument(numMapVars - 1, mapInfoOp.getType(),
+ loc);
+ }
+ void addMapInfoOps(omp::TargetOp targetOp,
+ llvm::SmallVectorImpl<omp::MapInfoOp> &mapInfoOps) {
+ for (auto mapInfoOp : mapInfoOps)
+ addMapInfoOp(targetOp, mapInfoOp);
+ }
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ OpBuilder builder(context);
+ llvm::DenseMap<Operation *, llvm::SmallVector<omp::MapInfoOp, 4>>
+ mapInfoOpsForTarget;
+ getOperation()->walk([&](omp::TargetOp targetOp) {
+ if (targetOp.getPrivateVars().empty())
+ return;
+ OperandRange privVars = targetOp.getPrivateVars();
+ std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
+ SmallVector<omp::MapInfoOp, 4> mapInfoOps;
+ for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
+
+ SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
+ omp::PrivateClauseOp privatizer =
+ SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
+ targetOp, privatizerName);
+ if (!privatizerNeedsMap(privatizer)) {
+ continue;
+ }
+ builder.setInsertionPoint(targetOp);
+ Location loc = targetOp.getLoc();
+ omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
+ mapInfoOps.push_back(mapInfoOp);
+ LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
+ LLVM_DEBUG(mapInfoOp.dump());
+ }
+ if (!mapInfoOps.empty()) {
+ mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps});
+ }
+ });
+ if (!mapInfoOpsForTarget.empty()) {
+ for (auto &[targetOp, mapInfoOps] : mapInfoOpsForTarget) {
+ addMapInfoOps(static_cast<omp::TargetOp>(targetOp), mapInfoOps);
+ }
+ }
+ }
+};
+} // namespace
diff --git a/flang/test/Transforms/omp-maps-for-privatized-symbols.fir b/flang/test/Transforms/omp-maps-for-privatized-symbols.fir
new file mode 100644
index 00000000000000..cd7ee2463238ee
--- /dev/null
+++ b/flang/test/Transforms/omp-maps-for-privatized-symbols.fir
@@ -0,0 +1,84 @@
+// RUN: fir-opt --split-input-file --omp-maps-for-privatized-symbols %s | FileCheck %s
+module attributes {omp.is_target_device = false} {
+ omp.private {type = private} @_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 : !fir.ref<!fir.box<!fir.heap<i32>>> alloc {
+ ^bb0(%arg0: !fir.ref<!fir.box<!fir.heap<i32>>>):
+ %0 = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "simple_var", pinned, uniq_name = "_QFtarget_simpleEsimple_var"}
+ %1 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
+ %2 = fir.box_addr %1 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
+ %3 = fir.convert %2 : (!fir.heap<i32>) -> i64
+ %c0_i64 = arith.constant 0 : i64
+ %4 = arith.cmpi ne, %3, %c0_i64 : i64
+ fir.if %4 {
+ %6 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
+ %7 = fir.box_addr %6 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
+ %8 = fir.allocmem i32 {fir.must_be_heap = true, uniq_name = "_QFtarget_simpleEsimple_var.alloc"}
+ %9 = fir.embox %8 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+ fir.store %9 to %0 : !fir.ref<!fir.box<!fir.heap<i32>>>
+ } else {
+ %6 = fir.zero_bits !fir.heap<i32>
+ %7 = fir.embox %6 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+ fir.store %7 to %0 : !fir.ref<!fir.box<!fir.heap<i32>>>
+ }
+ %5:2 = hlfir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
+ omp.yield(%5#0 : !fir.ref<!fir.box<!fir.heap<i32>>>)
+ } dealloc {
+ ^bb0(%arg0: !fir.ref<!fir.box<!fir.heap<i32>>>):
+ %0 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
+ %1 = fir.box_addr %0 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
+ %2 = fir.convert %1 : (!fir.heap<i32>) -> i64
+ %c0_i64 = arith.constant 0 : i64
+ %3 = arith.cmpi ne, %2, %c0_i64 : i64
+ fir.if %3 {
+ %false = arith.constant false
+ %4 = fir.absent !fir.box<none>
+ %c70 = arith.constant 70 : index
+ %c10_i32 = arith.constant 10 : i32
+ %6 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
+ %7 = fir.box_addr %6 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
+ fir.freemem %7 : !fir.heap<i32>
+ %8 = fir.zero_bits !fir.heap<i32>
+ %9 = fir.embox %8 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+ fir.store %9 to %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
+ }
+ omp.yield
+ }
+ func.func @_QPtarget_simple() {
+ %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"}
+ %1:2 = hlfir.declare %0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %2 = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"}
+ %3 = fir.zero_bits !fir.heap<i32>
+ %4 = fir.embox %3 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+ fir.store %4 to %2 : !fir.ref<!fir.box<!fir.heap<i32>>>
+ %5:2 = hlfir.declare %2 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
+ %c2_i32 = arith.constant 2 : i32
+ hlfir.assign %c2_i32 to %1#0 : i32, !fir.ref<i32>
+ %6 = omp.map.info var_ptr(%1#1 : !fir.ref<i32>, i32) map_clauses(to) capture(ByRef) -> !fir.ref<i32> {name = "a"}
+ omp.target map_entries(%6 -> %arg0 : !fir.ref<i32>) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %5#0 -> %arg1 : !fir.ref<!fir.box<!fir.heap<i32>>>) {
+ ^bb0(%arg0: !fir.ref<i32>, %arg1: !fir.ref<!fir.box<!fir.heap<i32>>>):
+ %11:2 = hlfir.declare %arg0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %12:2 = hlfir.declare %arg1 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
+ %c10_i32 = arith.constant 10 : i32
+ %13 = fir.load %11#0 : !fir.ref<i32>
+ %14 = arith.addi %c10_i32, %13 : i32
+ hlfir.assign %14 to %12#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<i32>>>
+ omp.terminator
+ }
+ %7 = fir.load %5#1 : !fir.ref<!fir.box<!fir.heap<i32>>>
+ %8 = fir.box_addr %7 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
+ %9 = fir.convert %8 : (!fir.heap<i32>) -> i64
+ %c0_i64 = arith.constant 0 : i64
+ %10 = arith.cmpi ne, %9, %c0_i64 : i64
+ fir.if %10 {
+ %11 = fir.load %5#1 : !fir.ref<!fir.box<!fir.heap<i32>>>
+ %12 = fir.box_addr %11 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
+ fir.freemem %12 : !fir.heap<i32>
+ %13 = fir.zero_bits !fir.heap<i32>
+ %14 = fir.embox %13 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+ fir.store %14 to %5#1 : !fir.ref<!fir.box<!fir.heap<i32>>>
+ }
+ return
+ }
+}
+// CHECK: %[[MAP0:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<i32>, i32) map_clauses(to) capture(ByRef) -> !fir.ref<i32> {name = "a"}
+// CHECK: %[[MAP1:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.box<!fir.heap<i32>>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.box<!fir.heap<i32>>>
+// CHECK: omp.target map_entries(%[[MAP0]] -> %arg0, %[[MAP1]] -> %arg1 : !fir.ref<i32>, !fir.ref<!fir.box<!fir.heap<i32>>>)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 9d2123a2e9bf52..edc3679a16e37d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -880,7 +880,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
objects (e.g. derived types or classes), indicates the bounds to be copied
of the variable. When it's an array slice it is in rank order where rank 0
is the inner-most dimension.
- - 'map_clauses': OpenMP map type for this map capture, for example: from, to and
+ - 'map_type': OpenMP map type for this map capture, for example: from, to and
always. It's a bitfield composed of the OpenMP runtime flags stored in
OpenMPOffloadMappingFlags.
- 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla
``````````
</details>
https://github.com/llvm/llvm-project/pull/109862
More information about the flang-commits
mailing list