[flang-commits] [flang] [mlir] [Flang][MLIR][OpenMP] Improve use_device_* handling (PR #137198)

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Thu May 15 03:23:40 PDT 2025


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/137198

>From 59bc1c921519967d71837a0833023a6dbccf9045 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 11 Apr 2025 13:30:38 +0100
Subject: [PATCH] [Flang][MLIR][OpenMP] Improve use_device_* handling

This patch updates the verifiers of MLIR operations whose arguments must always
be defined by an `omp.map.info` operation, so that they check this requirement.

It also modifies Flang lowering for `use_device_{addr, ptr}`, as well as the
custom MLIR parser and printer for map clauses, so that the map type of these
clauses is initialized to `OMP_MAP_RETURN_PARAM` and is represented in MLIR as
`return_param`. This internal mapping flag is what is eventually used for
variables passed via these clauses into the target data region when translating
to LLVM IR, so making it explicit in Flang and MLIR removes an inconsistency in
the current representation.
---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    |  6 +--
 flang/lib/Lower/OpenMP/Utils.cpp              |  8 ++--
 .../Fir/convert-to-llvm-openmp-and-fir.fir    |  5 +-
 flang/test/Lower/OpenMP/target.f90            |  2 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 47 +++++++++++++++----
 mlir/test/Dialect/OpenMP/ops.mlir             | 10 ++--
 6 files changed, 57 insertions(+), 21 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index f4876256a378f..02454543d0a60 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1407,8 +1407,7 @@ bool ClauseProcessor::processUseDeviceAddr(
           const parser::CharBlock &source) {
         mlir::Location location = converter.genLocation(source);
         llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
         processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
                           parentMemberIndices, result.useDeviceAddrVars,
                           useDeviceSyms);
@@ -1429,8 +1428,7 @@ bool ClauseProcessor::processUseDevicePtr(
           const parser::CharBlock &source) {
         mlir::Location location = converter.genLocation(source);
         llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
         processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
                           parentMemberIndices, result.useDevicePtrVars,
                           useDeviceSyms);
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 3f4cfb8c11a9d..173dceb07b193 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -398,14 +398,16 @@ mlir::Value createParentSymAndGenIntermediateMaps(
               interimBounds, treatIndexAsSection);
         }
 
-        // Remove all map TO, FROM and TOFROM bits, from the intermediate
-        // allocatable maps, we simply wish to alloc or release them. It may be
-        // safer to just pass OMP_MAP_NONE as the map type, but we may still
+        // Remove all map-type bits (e.g. TO, FROM, etc.) from the intermediate
+        // allocatable maps, as we simply wish to alloc or release them. It may
+        // be safer to just pass OMP_MAP_NONE as the map type, but we may still
         // need some of the other map types the mapped member utilises, so for
         // now it's good to keep an eye on this.
         llvm::omp::OpenMPOffloadMappingFlags interimMapType = mapTypeBits;
         interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
         interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+        interimMapType &=
+            ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
 
         // Create a map for the intermediate member and insert it and it's
         // indices into the parentMemberIndices list to track it.
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 8019ecf7f6a05..b13921f822b4d 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -423,14 +423,15 @@ func.func @_QPopenmp_target_data_region() {
 
 func.func @_QPomp_target_data_empty() {
   %0 = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_data_emptyEa"}
-  omp.target_data use_device_addr(%0 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
+  %1 = omp.map.info var_ptr(%0 : !fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>) map_clauses(return_param) capture(ByRef) -> !fir.ref<!fir.array<1024xi32>> {name = ""}
+  omp.target_data use_device_addr(%1 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
     omp.terminator
   }
   return
 }
 
 // CHECK-LABEL:   llvm.func @_QPomp_target_data_empty
-// CHECK: omp.target_data   use_device_addr(%1 -> %{{.*}} : !llvm.ptr) {
+// CHECK: omp.target_data   use_device_addr(%{{.*}} -> %{{.*}} : !llvm.ptr) {
 // CHECK: }
 
 // -----
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 4815e6564fc7e..f04aacc63fc2b 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -544,7 +544,7 @@ subroutine omp_target_device_addr
    !CHECK: %[[VAL_0_DECL:.*]]:2 = hlfir.declare %[[VAL_0]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
    !CHECK: %[[MAP_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
    !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(to) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
-   !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
+   !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(return_param) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
    !CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(to) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
    !CHECK: omp.target_data map_entries(%[[MAP]], %[[MAP_MEMBERS]] : {{.*}}) use_device_addr(%[[DEV_ADDR]] -> %[[ARG_0:.*]], %[[DEV_ADDR_MEMBERS]] -> %[[ARG_1:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.llvm_ptr<!fir.ref<i32>>) {
    !$omp target data map(tofrom: a) use_device_addr(a)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 2bf7aaa46db11..deff86d5c5ecb 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1521,6 +1521,9 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
     if (mapTypeMod == "delete")
       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
 
+    if (mapTypeMod == "return_param")
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+
     return success();
   };
 
@@ -1583,6 +1586,12 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
     emitAllocRelease = false;
     mapTypeStrs.push_back("delete");
   }
+  if (mapTypeToBitFlag(
+          mapTypeBits,
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
+    emitAllocRelease = false;
+    mapTypeStrs.push_back("return_param");
+  }
   if (emitAllocRelease)
     mapTypeStrs.push_back("exit_release_or_enter_alloc");
 
@@ -1777,6 +1786,17 @@ static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
 // MapInfoOp
 //===----------------------------------------------------------------------===//
 
+static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
+                                              StringRef clauseName,
+                                              OperandRange vars) {
+  for (Value var : vars)
+    if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
+      return op->emitOpError()
+             << "'" << clauseName
+             << "' arguments must be defined by 'omp.map.info' ops";
+  return success();
+}
+
 LogicalResult MapInfoOp::verify() {
   if (getMapperId() &&
       !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
@@ -1784,6 +1804,9 @@ LogicalResult MapInfoOp::verify() {
     return emitError("invalid mapper id");
   }
 
+  if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
+    return failure();
+
   return success();
 }
 
@@ -1805,6 +1828,15 @@ LogicalResult TargetDataOp::verify() {
                        "At least one of map, use_device_ptr_vars, or "
                        "use_device_addr_vars operand must be present");
   }
+
+  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
+                                      getUseDevicePtrVars())))
+    return failure();
+
+  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
+                                      getUseDeviceAddrVars())))
+    return failure();
+
   return verifyMapClause(*this, getMapVars());
 }
 
@@ -1889,16 +1921,15 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
 }
 
 LogicalResult TargetOp::verify() {
-  LogicalResult verifyDependVars =
-      verifyDependVarList(*this, getDependKinds(), getDependVars());
-
-  if (failed(verifyDependVars))
-    return verifyDependVars;
+  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
+    return failure();
 
-  LogicalResult verifyMapVars = verifyMapClause(*this, getMapVars());
+  if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
+                                      getHasDeviceAddrVars())))
+    return failure();
 
-  if (failed(verifyMapVars))
-    return verifyMapVars;
+  if (failed(verifyMapClause(*this, getMapVars())))
+    return failure();
 
   return verifyPrivateVarsMapping(*this);
 }
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index b7e16b7ec35e2..a9e4af035dbd7 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -802,10 +802,14 @@ func.func @omp_target_data (%if_cond : i1, %device : si32, %device_ptr: memref<i
     %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>)   map_clauses(always, from) capture(ByRef) -> memref<?xi32> {name = ""}
     omp.target_data if(%if_cond) device(%device : si32) map_entries(%mapv1 : memref<?xi32>){}
 
-    // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
-    // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref<?xi32>) use_device_addr(%[[VAL_3:.*]] -> %{{.*}} : memref<?xi32>) use_device_ptr(%[[VAL_4:.*]] -> %{{.*}} : memref<i32>)
+    // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%{{.*}} : memref<?xi32>, tensor<?xi32>)   map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
+    // CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr(%{{.*}} : memref<?xi32>, tensor<?xi32>)   map_clauses(return_param) capture(ByRef) -> memref<?xi32> {name = ""}
+    // CHECK: %[[DEV_PTR:.*]] = omp.map.info var_ptr(%{{.*}} : memref<i32>, tensor<i32>)   map_clauses(return_param) capture(ByRef) -> memref<i32> {name = ""}
+    // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref<?xi32>) use_device_addr(%[[DEV_ADDR]] -> %{{.*}} : memref<?xi32>) use_device_ptr(%[[DEV_PTR]] -> %{{.*}} : memref<i32>)
     %mapv2 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>)   map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
-    omp.target_data map_entries(%mapv2 : memref<?xi32>) use_device_addr(%device_addr -> %arg0 : memref<?xi32>) use_device_ptr(%device_ptr -> %arg1 : memref<i32>) {
+    %device_addrv1 = omp.map.info var_ptr(%device_addr : memref<?xi32>, tensor<?xi32>) map_clauses(return_param) capture(ByRef) -> memref<?xi32> {name = ""}
+    %device_ptrv1 = omp.map.info var_ptr(%device_ptr : memref<i32>, tensor<i32>) map_clauses(return_param) capture(ByRef) -> memref<i32> {name = ""}
+    omp.target_data map_entries(%mapv2 : memref<?xi32>) use_device_addr(%device_addrv1 -> %arg0 : memref<?xi32>) use_device_ptr(%device_ptrv1 -> %arg1 : memref<i32>) {
       omp.terminator
     }
 



More information about the flang-commits mailing list