[flang-commits] [flang] [mlir] [Flang][MLIR][OpenMP] Improve use_device_* handling (PR #137198)

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Thu Apr 24 08:42:10 PDT 2025


https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/137198

This patch updates the MLIR op verifiers of operations whose arguments must always be defined by an `omp.map.info` operation, so that they check this requirement.

It also modifies Flang lowering for `use_device_{addr, ptr}`, as well as the custom MLIR printer and parser for these clauses, to support initializing their map type to `OMP_MAP_RETURN_PARAM` and representing this flag in MLIR as `return_param`. This internal mapping flag is what is eventually used for variables passed via these clauses into the target region when translating to LLVM IR, so making it explicit in Flang and MLIR removes an inconsistency in the current representation.

>From de00b779219bad0a0164fecfc62e8d110df84e97 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 11 Apr 2025 13:30:38 +0100
Subject: [PATCH] [Flang][MLIR][OpenMP] Improve use_device_* handling

This patch updates the MLIR op verifiers of operations whose arguments must
always be defined by an `omp.map.info` operation, so that they check this
requirement.

It also modifies Flang lowering for `use_device_{addr, ptr}`, as well as the
custom MLIR printer and parser for these clauses, to support initializing their
map type to `OMP_MAP_RETURN_PARAM` and representing this flag in MLIR as
`return_param`. This internal mapping flag is what is eventually used for
variables passed via these clauses into the target region when translating to
LLVM IR, so making it explicit in Flang and MLIR removes an inconsistency in
the current representation.
---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    |  6 +--
 flang/lib/Lower/OpenMP/Utils.cpp              |  4 +-
 .../Fir/convert-to-llvm-openmp-and-fir.fir    |  5 +-
 flang/test/Lower/OpenMP/target.f90            |  2 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 47 +++++++++++++++----
 mlir/test/Dialect/OpenMP/ops.mlir             | 10 ++--
 6 files changed, 55 insertions(+), 19 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index cce6dc32bad4b..d03020707ef99 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1340,8 +1340,7 @@ bool ClauseProcessor::processUseDeviceAddr(
           const parser::CharBlock &source) {
         mlir::Location location = converter.genLocation(source);
         llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
         processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
                           parentMemberIndices, result.useDeviceAddrVars,
                           useDeviceSyms);
@@ -1362,8 +1361,7 @@ bool ClauseProcessor::processUseDevicePtr(
           const parser::CharBlock &source) {
         mlir::Location location = converter.genLocation(source);
         llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
         processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
                           parentMemberIndices, result.useDevicePtrVars,
                           useDeviceSyms);
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 3f4cfb8c11a9d..bdeed8f6d213e 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -398,7 +398,7 @@ mlir::Value createParentSymAndGenIntermediateMaps(
               interimBounds, treatIndexAsSection);
         }
 
-        // Remove all map TO, FROM and TOFROM bits, from the intermediate
+        // Remove all map TO, FROM and RETURN_PARAM bits, from the intermediate
         // allocatable maps, we simply wish to alloc or release them. It may be
         // safer to just pass OMP_MAP_NONE as the map type, but we may still
         // need some of the other map types the mapped member utilises, so for
@@ -406,6 +406,8 @@ mlir::Value createParentSymAndGenIntermediateMaps(
         llvm::omp::OpenMPOffloadMappingFlags interimMapType = mapTypeBits;
         interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
         interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+        interimMapType &=
+            ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
 
         // Create a map for the intermediate member and insert it and it's
         // indices into the parentMemberIndices list to track it.
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 8019ecf7f6a05..b13921f822b4d 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -423,14 +423,15 @@ func.func @_QPopenmp_target_data_region() {
 
 func.func @_QPomp_target_data_empty() {
   %0 = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_data_emptyEa"}
-  omp.target_data use_device_addr(%0 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
+  %1 = omp.map.info var_ptr(%0 : !fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>) map_clauses(return_param) capture(ByRef) -> !fir.ref<!fir.array<1024xi32>> {name = ""}
+  omp.target_data use_device_addr(%1 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
     omp.terminator
   }
   return
 }
 
 // CHECK-LABEL:   llvm.func @_QPomp_target_data_empty
-// CHECK: omp.target_data   use_device_addr(%1 -> %{{.*}} : !llvm.ptr) {
+// CHECK: omp.target_data   use_device_addr(%{{.*}} -> %{{.*}} : !llvm.ptr) {
 // CHECK: }
 
 // -----
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 4815e6564fc7e..f04aacc63fc2b 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -544,7 +544,7 @@ subroutine omp_target_device_addr
    !CHECK: %[[VAL_0_DECL:.*]]:2 = hlfir.declare %[[VAL_0]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
    !CHECK: %[[MAP_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
    !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(to) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
-   !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
+   !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(return_param) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
    !CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(to) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
    !CHECK: omp.target_data map_entries(%[[MAP]], %[[MAP_MEMBERS]] : {{.*}}) use_device_addr(%[[DEV_ADDR]] -> %[[ARG_0:.*]], %[[DEV_ADDR_MEMBERS]] -> %[[ARG_1:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.llvm_ptr<!fir.ref<i32>>) {
    !$omp target data map(tofrom: a) use_device_addr(a)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index dd701da507fc6..a81f9a63a8ebb 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1520,6 +1520,9 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
     if (mapTypeMod == "delete")
       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
 
+    if (mapTypeMod == "return_param")
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+
     return success();
   };
 
@@ -1582,6 +1585,12 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
     emitAllocRelease = false;
     mapTypeStrs.push_back("delete");
   }
+  if (mapTypeToBitFlag(
+          mapTypeBits,
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
+    emitAllocRelease = false;
+    mapTypeStrs.push_back("return_param");
+  }
   if (emitAllocRelease)
     mapTypeStrs.push_back("exit_release_or_enter_alloc");
 
@@ -1776,6 +1785,17 @@ static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
 // MapInfoOp
 //===----------------------------------------------------------------------===//
 
+static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
+                                              StringRef clauseName,
+                                              OperandRange vars) {
+  for (Value var : vars)
+    if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
+      return op->emitOpError()
+             << "'" << clauseName
+             << "' arguments must be defined by 'omp.map.info' ops";
+  return success();
+}
+
 LogicalResult MapInfoOp::verify() {
   if (getMapperId() &&
       !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
@@ -1783,6 +1803,9 @@ LogicalResult MapInfoOp::verify() {
     return emitError("invalid mapper id");
   }
 
+  if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
+    return failure();
+
   return success();
 }
 
@@ -1804,6 +1827,15 @@ LogicalResult TargetDataOp::verify() {
                        "At least one of map, use_device_ptr_vars, or "
                        "use_device_addr_vars operand must be present");
   }
+
+  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
+                                      getUseDevicePtrVars())))
+    return failure();
+
+  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
+                                      getUseDeviceAddrVars())))
+    return failure();
+
   return verifyMapClause(*this, getMapVars());
 }
 
@@ -1888,16 +1920,15 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
 }
 
 LogicalResult TargetOp::verify() {
-  LogicalResult verifyDependVars =
-      verifyDependVarList(*this, getDependKinds(), getDependVars());
-
-  if (failed(verifyDependVars))
-    return verifyDependVars;
+  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
+    return failure();
 
-  LogicalResult verifyMapVars = verifyMapClause(*this, getMapVars());
+  if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
+                                      getHasDeviceAddrVars())))
+    return failure();
 
-  if (failed(verifyMapVars))
-    return verifyMapVars;
+  if (failed(verifyMapClause(*this, getMapVars())))
+    return failure();
 
   return verifyPrivateVarsMapping(*this);
 }
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index d5e2bfa5d3949..3cecc2188aabd 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -802,10 +802,14 @@ func.func @omp_target_data (%if_cond : i1, %device : si32, %device_ptr: memref<i
     %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>)   map_clauses(always, from) capture(ByRef) -> memref<?xi32> {name = ""}
     omp.target_data if(%if_cond) device(%device : si32) map_entries(%mapv1 : memref<?xi32>){}
 
-    // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
-    // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref<?xi32>) use_device_addr(%[[VAL_3:.*]] -> %{{.*}} : memref<?xi32>) use_device_ptr(%[[VAL_4:.*]] -> %{{.*}} : memref<i32>)
+    // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%{{.*}} : memref<?xi32>, tensor<?xi32>)   map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
+    // CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr(%{{.*}} : memref<?xi32>, tensor<?xi32>)   map_clauses(return_param) capture(ByRef) -> memref<?xi32> {name = ""}
+    // CHECK: %[[DEV_PTR:.*]] = omp.map.info var_ptr(%{{.*}} : memref<i32>, tensor<i32>)   map_clauses(return_param) capture(ByRef) -> memref<i32> {name = ""}
+    // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref<?xi32>) use_device_addr(%[[DEV_ADDR]] -> %{{.*}} : memref<?xi32>) use_device_ptr(%[[DEV_PTR]] -> %{{.*}} : memref<i32>)
     %mapv2 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>)   map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
-    omp.target_data map_entries(%mapv2 : memref<?xi32>) use_device_addr(%device_addr -> %arg0 : memref<?xi32>) use_device_ptr(%device_ptr -> %arg1 : memref<i32>) {
+    %device_addrv1 = omp.map.info var_ptr(%device_addr : memref<?xi32>, tensor<?xi32>) map_clauses(return_param) capture(ByRef) -> memref<?xi32> {name = ""}
+    %device_ptrv1 = omp.map.info var_ptr(%device_ptr : memref<i32>, tensor<i32>) map_clauses(return_param) capture(ByRef) -> memref<i32> {name = ""}
+    omp.target_data map_entries(%mapv2 : memref<?xi32>) use_device_addr(%device_addrv1 -> %arg0 : memref<?xi32>) use_device_ptr(%device_ptrv1 -> %arg1 : memref<i32>) {
       omp.terminator
     }
 



More information about the flang-commits mailing list