[Mlir-commits] [flang] [mlir] [flang][OpenMP] - Add `MapInfoOp` instances for target private variables when needed (PR #109862)

Pranav Bhandarkar llvmlistbot at llvm.org
Tue Sep 24 13:44:08 PDT 2024


https://github.com/bhandarkar-pranav created https://github.com/llvm/llvm-project/pull/109862

This PR adds an OpenMP dialect related pass for FIR/HLFIR which creates `MapInfoOp` instances for certain privatized symbols. For example, if an allocatable variable is used in a private clause attached to an `omp.target` op, then the allocatable variable's descriptor will be needed on the device (e.g. GPU). This descriptor needs to be separately mapped onto the device. This pass creates the necessary `omp.map.info` ops for this.

>From 1d22d1d30cb82e0df4a40c3b29569660ff3a9e53 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Wed, 18 Sep 2024 23:50:02 -0500
Subject: [PATCH 1/5] Add MapsForPrivatizedSymbolsPass

---
 .../include/flang/Optimizer/OpenMP/Passes.td  |  13 ++
 flang/include/flang/Tools/CLOptions.inc       |   1 +
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |   1 +
 .../OpenMP/MapsForPrivatizedSymbols.cpp       | 119 ++++++++++++++++++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |   2 +-
 5 files changed, 135 insertions(+), 1 deletion(-)
 create mode 100644 flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 1c0ce08f5b4838..de2c561c2d3f37 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -22,6 +22,19 @@ def MapInfoFinalizationPass
   let dependentDialects = ["mlir::omp::OpenMPDialect"];
 }
 
+def MapsForPrivatizedSymbolsPass
+    : Pass<"omp-map-for-privatized-symbols", "mlir::func::FuncOp"> {
+  let summary = "Creates MapInfoOp instances for privatized symbols when needed";
+  let description = [{
+    Adds omp.map.info operations for privatized symbols on omp.target ops
+    In certain situations, such as when an allocatable is privatized, its
+    descriptor is needed in the alloc region of the privatizer. This results
+    in the use of the descriptor inside the target region. As such, the
+    descriptor then needs to be mapped. This pass adds such MapInfoOp operations.
+  }];
+  let dependentDialects = ["mlir::omp::OpenMPDialect"];
+}
+
 def MarkDeclareTargetPass
     : Pass<"omp-mark-declare-target", "mlir::ModuleOp"> {
   let summary = "Marks all functions called by an OpenMP declare target function as declare target";
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 04b7f0ba370b86..4b21dd6917d39f 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -368,6 +368,7 @@ inline void createHLFIRToFIRPassPipeline(
 inline void createOpenMPFIRPassPipeline(
     mlir::PassManager &pm, bool isTargetDevice) {
   pm.addPass(flangomp::createMapInfoFinalizationPass());
+  pm.addPass(flangomp::createMapsForPrivatizedSymbolsPass());
   pm.addPass(flangomp::createMarkDeclareTargetPass());
   if (isTargetDevice)
     pm.addPass(flangomp::createFunctionFilteringPass());
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 92051634f0378b..035d0d5ca46c76 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
 
 add_flang_library(FlangOpenMPTransforms
   FunctionFiltering.cpp
+  MapsForPrivatizedSymbols.cpp
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
 
diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
new file mode 100644
index 00000000000000..44c7a7a9302ee5
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -0,0 +1,119 @@
+//===- MapsForPrivatizedSymbols.cpp
+//-----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include <type_traits>
+
+namespace flangomp {
+#define GEN_PASS_DEF_MAPSFORPRIVATIZEDSYMBOLSPASS
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+using namespace mlir;
+namespace {
+class MapsForPrivatizedSymbolsPass
+    : public flangomp::impl::MapsForPrivatizedSymbolsPassBase<
+          MapsForPrivatizedSymbolsPass> {
+
+  bool privatizerNeedsMap(omp::PrivateClauseOp &privatizer) {
+    Region &allocRegion = privatizer.getAllocRegion();
+    Value blockArg0 = allocRegion.getArgument(0);
+    if (blockArg0.use_empty())
+      return false;
+    return true;
+  }
+  void dumpPrivatizerInfo(omp::PrivateClauseOp &privatizer,
+                          mlir::Value privVar) {
+    llvm::errs() << "Found a privatizer:\n";
+    privatizer.dump();
+    llvm::errs() << "\n";
+
+    llvm::errs() << "$type = " << privatizer.getType() << "\n";
+    llvm::errs() << "privVar = ";
+    privVar.dump();
+    llvm::errs() << "\n";
+
+    llvm::errs() << "privVar.getDefiningOp() = ";
+    privVar.getDefiningOp()->dump();
+    llvm::errs() << "\n";
+    llvm::errs() << "\n";
+  }
+  omp::MapInfoOp createMapInfo(mlir::Location loc, mlir::Value var,
+                               OpBuilder &builder) {
+    //    llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
+    //    llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
+    uint64_t mapTypeTo = static_cast<
+        std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+    return builder.create<omp::MapInfoOp>(
+        loc, var.getType(), var,
+        mlir::TypeAttr::get(fir::unwrapRefType(var.getType())),
+        /*varPtrPtr=*/mlir::Value{},
+        /*members=*/mlir::SmallVector<mlir::Value>{},
+        /*member_index=*/mlir::DenseIntElementsAttr{},
+        /*bounds=*/mlir::ValueRange{},
+        builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
+                               mapTypeTo),
+        builder.getAttr<omp::VariableCaptureKindAttr>(
+            omp::VariableCaptureKind::ByRef),
+        mlir::StringAttr(), builder.getBoolAttr(false));
+  }
+  void addMapInfoOp(omp::TargetOp targetOp, omp::MapInfoOp mapInfoOp) {
+    mlir::Location loc = targetOp.getLoc();
+    targetOp.getMapVarsMutable().append(mlir::ValueRange{mapInfoOp});
+    size_t numMapVars = targetOp.getMapVars().size();
+    targetOp.getRegion().insertArgument(numMapVars - 1, mapInfoOp.getType(),
+                                        loc);
+  }
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    OpBuilder builder(context);
+    getOperation()->walk([&](omp::TargetOp targetOp) {
+      llvm::errs() << "MapsForPrivatizedSymbolsPass::TargetOp is \n";
+      targetOp.dump();
+      llvm::errs() << "\n";
+
+      if (targetOp.getPrivateVars().empty())
+        return;
+
+      OperandRange privVars = targetOp.getPrivateVars();
+      std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
+
+      for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
+
+        SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
+        omp::PrivateClauseOp privatizer =
+            SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
+                targetOp, privatizerName);
+
+        assert(mlir::isa<fir::ReferenceType>(privVar.getType()) &&
+               "Privatized variable should be a reference.");
+        if (!privatizerNeedsMap(privatizer)) {
+          return;
+        }
+        llvm::errs() << "Privatizer NEEDS a map\n";
+        builder.setInsertionPoint(targetOp);
+        dumpPrivatizerInfo(privatizer, privVar);
+
+        mlir::Location loc = targetOp.getLoc();
+        omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
+        addMapInfoOp(targetOp, mapInfoOp);
+        llvm::errs() << __FUNCTION__ << "MapInfoOp is \n";
+        mapInfoOp.dump();
+        llvm::errs() << "\n";
+      }
+    });
+  }
+};
+} // namespace
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 9d2123a2e9bf52..edc3679a16e37d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -880,7 +880,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
        objects (e.g. derived types or classes), indicates the bounds to be copied
        of the variable. When it's an array slice it is in rank order where rank 0
        is the inner-most dimension.
-    - 'map_clauses': OpenMP map type for this map capture, for example: from, to and
+    - 'map_type': OpenMP map type for this map capture, for example: from, to and
        always. It's a bitfield composed of the OpenMP runtime flags stored in
        OpenMPOffloadMappingFlags.
     - 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla

>From 2b42abbc200d1d92dbd1f6ad2951db9fa83de1c3 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Sun, 22 Sep 2024 00:38:21 -0500
Subject: [PATCH 2/5] clean up and add one test

---
 .../OpenMP/MapsForPrivatizedSymbols.cpp       | 34 ++------
 .../omp-maps-for-privatized-symbols.fir       | 84 +++++++++++++++++++
 2 files changed, 90 insertions(+), 28 deletions(-)
 create mode 100644 flang/test/Transforms/omp-maps-for-privatized-symbols.fir

diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
index 44c7a7a9302ee5..494b4ebc6c44fa 100644
--- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -14,8 +14,11 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Support/Debug.h"
 #include <type_traits>
 
+#define DEBUG_TYPE "omp-maps-for-privatized-symbols"
+
 namespace flangomp {
 #define GEN_PASS_DEF_MAPSFORPRIVATIZEDSYMBOLSPASS
 #include "flang/Optimizer/OpenMP/Passes.h.inc"
@@ -33,29 +36,12 @@ class MapsForPrivatizedSymbolsPass
       return false;
     return true;
   }
-  void dumpPrivatizerInfo(omp::PrivateClauseOp &privatizer,
-                          mlir::Value privVar) {
-    llvm::errs() << "Found a privatizer:\n";
-    privatizer.dump();
-    llvm::errs() << "\n";
-
-    llvm::errs() << "$type = " << privatizer.getType() << "\n";
-    llvm::errs() << "privVar = ";
-    privVar.dump();
-    llvm::errs() << "\n";
-
-    llvm::errs() << "privVar.getDefiningOp() = ";
-    privVar.getDefiningOp()->dump();
-    llvm::errs() << "\n";
-    llvm::errs() << "\n";
-  }
   omp::MapInfoOp createMapInfo(mlir::Location loc, mlir::Value var,
                                OpBuilder &builder) {
-    //    llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
-    //    llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
     uint64_t mapTypeTo = static_cast<
         std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
         llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+
     return builder.create<omp::MapInfoOp>(
         loc, var.getType(), var,
         mlir::TypeAttr::get(fir::unwrapRefType(var.getType())),
@@ -80,10 +66,6 @@ class MapsForPrivatizedSymbolsPass
     MLIRContext *context = &getContext();
     OpBuilder builder(context);
     getOperation()->walk([&](omp::TargetOp targetOp) {
-      llvm::errs() << "MapsForPrivatizedSymbolsPass::TargetOp is \n";
-      targetOp.dump();
-      llvm::errs() << "\n";
-
       if (targetOp.getPrivateVars().empty())
         return;
 
@@ -102,16 +84,12 @@ class MapsForPrivatizedSymbolsPass
         if (!privatizerNeedsMap(privatizer)) {
           return;
         }
-        llvm::errs() << "Privatizer NEEDS a map\n";
         builder.setInsertionPoint(targetOp);
-        dumpPrivatizerInfo(privatizer, privVar);
-
         mlir::Location loc = targetOp.getLoc();
         omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
         addMapInfoOp(targetOp, mapInfoOp);
-        llvm::errs() << __FUNCTION__ << "MapInfoOp is \n";
-        mapInfoOp.dump();
-        llvm::errs() << "\n";
+        LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
+        LLVM_DEBUG(mapInfoOp.dump());
       }
     });
   }
diff --git a/flang/test/Transforms/omp-maps-for-privatized-symbols.fir b/flang/test/Transforms/omp-maps-for-privatized-symbols.fir
new file mode 100644
index 00000000000000..cd7ee2463238ee
--- /dev/null
+++ b/flang/test/Transforms/omp-maps-for-privatized-symbols.fir
@@ -0,0 +1,84 @@
+// RUN: fir-opt --split-input-file --omp-maps-for-privatized-symbols %s | FileCheck %s
+module attributes {omp.is_target_device = false} {
+  omp.private {type = private} @_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 : !fir.ref<!fir.box<!fir.heap<i32>>> alloc {
+  ^bb0(%arg0: !fir.ref<!fir.box<!fir.heap<i32>>>):
+    %0 = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "simple_var", pinned, uniq_name = "_QFtarget_simpleEsimple_var"}
+    %1 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
+    %2 = fir.box_addr %1 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
+    %3 = fir.convert %2 : (!fir.heap<i32>) -> i64
+    %c0_i64 = arith.constant 0 : i64
+    %4 = arith.cmpi ne, %3, %c0_i64 : i64
+    fir.if %4 {
+      %6 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
+      %7 = fir.box_addr %6 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
+      %8 = fir.allocmem i32 {fir.must_be_heap = true, uniq_name = "_QFtarget_simpleEsimple_var.alloc"}
+      %9 = fir.embox %8 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+      fir.store %9 to %0 : !fir.ref<!fir.box<!fir.heap<i32>>>
+    } else {
+      %6 = fir.zero_bits !fir.heap<i32>
+      %7 = fir.embox %6 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+      fir.store %7 to %0 : !fir.ref<!fir.box<!fir.heap<i32>>>
+    }
+    %5:2 = hlfir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
+    omp.yield(%5#0 : !fir.ref<!fir.box<!fir.heap<i32>>>)
+  } dealloc {
+  ^bb0(%arg0: !fir.ref<!fir.box<!fir.heap<i32>>>):
+    %0 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
+    %1 = fir.box_addr %0 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
+    %2 = fir.convert %1 : (!fir.heap<i32>) -> i64
+    %c0_i64 = arith.constant 0 : i64
+    %3 = arith.cmpi ne, %2, %c0_i64 : i64
+    fir.if %3 {
+      %false = arith.constant false
+      %4 = fir.absent !fir.box<none>
+      %c70 = arith.constant 70 : index
+      %c10_i32 = arith.constant 10 : i32
+      %6 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
+      %7 = fir.box_addr %6 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
+      fir.freemem %7 : !fir.heap<i32>
+      %8 = fir.zero_bits !fir.heap<i32>
+      %9 = fir.embox %8 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+      fir.store %9 to %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
+    }
+    omp.yield
+  }
+  func.func @_QPtarget_simple() {
+    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"}
+    %1:2 = hlfir.declare %0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+    %2 = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"}
+    %3 = fir.zero_bits !fir.heap<i32>
+    %4 = fir.embox %3 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+    fir.store %4 to %2 : !fir.ref<!fir.box<!fir.heap<i32>>>
+    %5:2 = hlfir.declare %2 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
+    %c2_i32 = arith.constant 2 : i32
+    hlfir.assign %c2_i32 to %1#0 : i32, !fir.ref<i32>
+    %6 = omp.map.info var_ptr(%1#1 : !fir.ref<i32>, i32) map_clauses(to) capture(ByRef) -> !fir.ref<i32> {name = "a"}
+    omp.target map_entries(%6 -> %arg0 : !fir.ref<i32>) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %5#0 -> %arg1 : !fir.ref<!fir.box<!fir.heap<i32>>>) {
+    ^bb0(%arg0: !fir.ref<i32>, %arg1: !fir.ref<!fir.box<!fir.heap<i32>>>):
+      %11:2 = hlfir.declare %arg0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+      %12:2 = hlfir.declare %arg1 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
+      %c10_i32 = arith.constant 10 : i32
+      %13 = fir.load %11#0 : !fir.ref<i32>
+      %14 = arith.addi %c10_i32, %13 : i32
+      hlfir.assign %14 to %12#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<i32>>>
+      omp.terminator
+    }
+    %7 = fir.load %5#1 : !fir.ref<!fir.box<!fir.heap<i32>>>
+    %8 = fir.box_addr %7 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
+    %9 = fir.convert %8 : (!fir.heap<i32>) -> i64
+    %c0_i64 = arith.constant 0 : i64
+    %10 = arith.cmpi ne, %9, %c0_i64 : i64
+    fir.if %10 {
+      %11 = fir.load %5#1 : !fir.ref<!fir.box<!fir.heap<i32>>>
+      %12 = fir.box_addr %11 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
+      fir.freemem %12 : !fir.heap<i32>
+      %13 = fir.zero_bits !fir.heap<i32>
+      %14 = fir.embox %13 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+      fir.store %14 to %5#1 : !fir.ref<!fir.box<!fir.heap<i32>>>
+    }
+    return
+  }
+}
+// CHECK: %[[MAP0:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<i32>, i32) map_clauses(to) capture(ByRef) -> !fir.ref<i32> {name = "a"}
+// CHECK: %[[MAP1:.*]] =  omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.box<!fir.heap<i32>>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.box<!fir.heap<i32>>>    
+// CHECK: omp.target map_entries(%[[MAP0]] -> %arg0, %[[MAP1]] -> %arg1 : !fir.ref<i32>, !fir.ref<!fir.box<!fir.heap<i32>>>)

>From 06af4004027bee9ad474d14ef6c4b2754aef5465 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Mon, 23 Sep 2024 16:59:12 -0500
Subject: [PATCH 3/5] fix failures seen in
 target-private-multiple-variables.f90

---
 .../OpenMP/MapsForPrivatizedSymbols.cpp       | 64 ++++++++++++++++---
 1 file changed, 54 insertions(+), 10 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
index 494b4ebc6c44fa..d236b3d23d0cb0 100644
--- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -7,6 +7,7 @@
 //
 //===----------------------------------------------------------------------===//
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -41,10 +42,27 @@ class MapsForPrivatizedSymbolsPass
     uint64_t mapTypeTo = static_cast<
         std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
         llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
-
+    Operation *definingOp = var.getDefiningOp();
+    auto declOp = llvm::dyn_cast_or_null<hlfir::DeclareOp>(definingOp);
+    assert(declOp &&
+           "Expected defining Op of privatized var to be hlfir.declare");
+    Value varPtr = declOp.getOriginalBase();
+    Value varBase = declOp.getBase();
+    llvm::errs() << "varPtr = ";
+    varPtr.dump();
+    llvm::errs() << " type -> ";
+    varPtr.getType().dump();
+    llvm::errs() << "\n";
+    llvm::errs() << "varBase = ";
+    varBase.dump();
+    llvm::errs() << " type -> ";
+    varBase.getType().dump();
+    llvm::errs() << "\n";
     return builder.create<omp::MapInfoOp>(
-        loc, var.getType(), var,
-        mlir::TypeAttr::get(fir::unwrapRefType(var.getType())),
+        loc, varPtr.getType(), varPtr,
+        mlir::TypeAttr::get(
+            llvm::cast<mlir::omp::PointerLikeType>(varPtr.getType())
+                .getElementType()),
         /*varPtrPtr=*/mlir::Value{},
         /*members=*/mlir::SmallVector<mlir::Value>{},
         /*member_index=*/mlir::DenseIntElementsAttr{},
@@ -57,41 +75,67 @@ class MapsForPrivatizedSymbolsPass
   }
   void addMapInfoOp(omp::TargetOp targetOp, omp::MapInfoOp mapInfoOp) {
     mlir::Location loc = targetOp.getLoc();
+    llvm::errs() << "Adding mapInfoOp -> ";
+    mapInfoOp.dump();
+    llvm::errs() << "\n";
+
     targetOp.getMapVarsMutable().append(mlir::ValueRange{mapInfoOp});
     size_t numMapVars = targetOp.getMapVars().size();
     targetOp.getRegion().insertArgument(numMapVars - 1, mapInfoOp.getType(),
                                         loc);
   }
+  void addMapInfoOps(omp::TargetOp targetOp,
+                     llvm::SmallVectorImpl<omp::MapInfoOp> &mapInfoOps) {
+    for (auto mapInfoOp : mapInfoOps)
+      addMapInfoOp(targetOp, mapInfoOp);
+  }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     OpBuilder builder(context);
+    llvm::DenseMap<Operation *, llvm::SmallVector<omp::MapInfoOp, 4>>
+        mapInfoOpsForTarget;
     getOperation()->walk([&](omp::TargetOp targetOp) {
       if (targetOp.getPrivateVars().empty())
         return;
-
+      llvm::errs() << "Func is \n";
+      targetOp.getOperation()->getParentOp()->getParentOp()->dump();
+      llvm::errs() << "\n";
       OperandRange privVars = targetOp.getPrivateVars();
       std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
-
+      SmallVector<omp::MapInfoOp, 4> mapInfoOps;
       for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
 
         SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
         omp::PrivateClauseOp privatizer =
             SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
                 targetOp, privatizerName);
-
-        assert(mlir::isa<fir::ReferenceType>(privVar.getType()) &&
-               "Privatized variable should be a reference.");
+        llvm::errs() << "privVar = ";
+        privVar.dump();
+        llvm::errs() << "\n";
+        llvm::errs() << "privVar.getType() = ";
+        privVar.getType().dump();
+        llvm::errs() << "\n";
+        // assert(mlir::isa<fir::ReferenceType>(privVar.getType()) &&
+        //        "Privatized variable should be a reference.");
         if (!privatizerNeedsMap(privatizer)) {
-          return;
+          continue;
         }
         builder.setInsertionPoint(targetOp);
         mlir::Location loc = targetOp.getLoc();
         omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
-        addMapInfoOp(targetOp, mapInfoOp);
+        mapInfoOps.push_back(mapInfoOp);
         LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
         LLVM_DEBUG(mapInfoOp.dump());
       }
+      if (!mapInfoOps.empty()) {
+        mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps});
+      }
     });
+    if (!mapInfoOpsForTarget.empty()) {
+      for (auto &[targetOp, mapInfoOps] : mapInfoOpsForTarget) {
+        addMapInfoOps(static_cast<omp::TargetOp>(targetOp), mapInfoOps);
+      }
+    }
   }
 };
 } // namespace

>From 23cf6fc9b71d51091467082946f7d710e0eadfe3 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 24 Sep 2024 10:20:51 -0500
Subject: [PATCH 4/5] fix flag omp-maps-for-privatized-symbols

---
 flang/include/flang/Optimizer/OpenMP/Passes.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index de2c561c2d3f37..c070bc22ff20cc 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -23,7 +23,7 @@ def MapInfoFinalizationPass
 }
 
 def MapsForPrivatizedSymbolsPass
-    : Pass<"omp-map-for-privatized-symbols", "mlir::func::FuncOp"> {
+    : Pass<"omp-maps-for-privatized-symbols", "mlir::func::FuncOp"> {
   let summary = "Creates MapInfoOp instances for privatized symbols when needed";
   let description = [{
     Adds omp.map.info operations for privatized symbols on omp.target ops

>From 9b08164558dd50f88c2ae589a6181d84cb377b37 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Tue, 24 Sep 2024 15:38:57 -0500
Subject: [PATCH 5/5] clean up MapsForPrivatizedSymbols.cpp

---
 .../OpenMP/MapsForPrivatizedSymbols.cpp       | 68 ++++++++-----------
 1 file changed, 29 insertions(+), 39 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
index d236b3d23d0cb0..4943181e2d3acd 100644
--- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -6,6 +6,23 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+/// \file
+/// An OpenMP dialect related pass for FIR/HLFIR which creates MapInfoOp
+/// instances for certain privatized symbols.
+/// For example, if an allocatable variable is used in a private clause attached
+/// to an omp.target op, then the allocatable variable's descriptor will be
+/// needed on the device (e.g. GPU). This descriptor needs to be separately
+/// mapped onto the device. This pass creates the necessary omp.map.info ops for
+/// this.
+//===----------------------------------------------------------------------===//
+// TODO:
+// 1. Before adding omp.map.info, check whether an omp.map.info already
+//    exists for the variable in question.
+// 2. Generalize this for more than just omp.target ops.
+//===----------------------------------------------------------------------===//
+
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
@@ -37,8 +54,7 @@ class MapsForPrivatizedSymbolsPass
       return false;
     return true;
   }
-  omp::MapInfoOp createMapInfo(mlir::Location loc, mlir::Value var,
-                               OpBuilder &builder) {
+  omp::MapInfoOp createMapInfo(Location loc, Value var, OpBuilder &builder) {
     uint64_t mapTypeTo = static_cast<
         std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
         llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
@@ -47,39 +63,24 @@ class MapsForPrivatizedSymbolsPass
     assert(declOp &&
            "Expected defining Op of privatized var to be hlfir.declare");
     Value varPtr = declOp.getOriginalBase();
-    Value varBase = declOp.getBase();
-    llvm::errs() << "varPtr = ";
-    varPtr.dump();
-    llvm::errs() << " type -> ";
-    varPtr.getType().dump();
-    llvm::errs() << "\n";
-    llvm::errs() << "varBase = ";
-    varBase.dump();
-    llvm::errs() << " type -> ";
-    varBase.getType().dump();
-    llvm::errs() << "\n";
+
     return builder.create<omp::MapInfoOp>(
         loc, varPtr.getType(), varPtr,
-        mlir::TypeAttr::get(
-            llvm::cast<mlir::omp::PointerLikeType>(varPtr.getType())
-                .getElementType()),
-        /*varPtrPtr=*/mlir::Value{},
-        /*members=*/mlir::SmallVector<mlir::Value>{},
-        /*member_index=*/mlir::DenseIntElementsAttr{},
-        /*bounds=*/mlir::ValueRange{},
+        TypeAttr::get(llvm::cast<omp::PointerLikeType>(varPtr.getType())
+                          .getElementType()),
+        /*varPtrPtr=*/Value{},
+        /*members=*/SmallVector<Value>{},
+        /*member_index=*/DenseIntElementsAttr{},
+        /*bounds=*/ValueRange{},
         builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
                                mapTypeTo),
         builder.getAttr<omp::VariableCaptureKindAttr>(
             omp::VariableCaptureKind::ByRef),
-        mlir::StringAttr(), builder.getBoolAttr(false));
+        StringAttr(), builder.getBoolAttr(false));
   }
   void addMapInfoOp(omp::TargetOp targetOp, omp::MapInfoOp mapInfoOp) {
-    mlir::Location loc = targetOp.getLoc();
-    llvm::errs() << "Adding mapInfoOp -> ";
-    mapInfoOp.dump();
-    llvm::errs() << "\n";
-
-    targetOp.getMapVarsMutable().append(mlir::ValueRange{mapInfoOp});
+    Location loc = targetOp.getLoc();
+    targetOp.getMapVarsMutable().append(ValueRange{mapInfoOp});
     size_t numMapVars = targetOp.getMapVars().size();
     targetOp.getRegion().insertArgument(numMapVars - 1, mapInfoOp.getType(),
                                         loc);
@@ -97,9 +98,6 @@ class MapsForPrivatizedSymbolsPass
     getOperation()->walk([&](omp::TargetOp targetOp) {
       if (targetOp.getPrivateVars().empty())
         return;
-      llvm::errs() << "Func is \n";
-      targetOp.getOperation()->getParentOp()->getParentOp()->dump();
-      llvm::errs() << "\n";
       OperandRange privVars = targetOp.getPrivateVars();
       std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
       SmallVector<omp::MapInfoOp, 4> mapInfoOps;
@@ -109,19 +107,11 @@ class MapsForPrivatizedSymbolsPass
         omp::PrivateClauseOp privatizer =
             SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
                 targetOp, privatizerName);
-        llvm::errs() << "privVar = ";
-        privVar.dump();
-        llvm::errs() << "\n";
-        llvm::errs() << "privVar.getType() = ";
-        privVar.getType().dump();
-        llvm::errs() << "\n";
-        // assert(mlir::isa<fir::ReferenceType>(privVar.getType()) &&
-        //        "Privatized variable should be a reference.");
         if (!privatizerNeedsMap(privatizer)) {
           continue;
         }
         builder.setInsertionPoint(targetOp);
-        mlir::Location loc = targetOp.getLoc();
+        Location loc = targetOp.getLoc();
         omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
         mapInfoOps.push_back(mapInfoOp);
         LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");



More information about the Mlir-commits mailing list