[Mlir-commits] [mlir] ee17955 - [MLIR][OpenMP] Add OMP Mapper field to MapInfoOp (#120994)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 18 09:27:51 PST 2025


Author: Akash Banerjee
Date: 2025-02-18T17:27:48Z
New Revision: ee17955dfe454f78bef41e47ef7ce255fa2563e9

URL: https://github.com/llvm/llvm-project/commit/ee17955dfe454f78bef41e47ef7ce255fa2563e9
DIFF: https://github.com/llvm/llvm-project/commit/ee17955dfe454f78bef41e47ef7ce255fa2563e9.diff

LOG: [MLIR][OpenMP] Add OMP Mapper field to MapInfoOp (#120994)

This patch adds the optional `mapper_id` attribute to the `omp.map.info` op, so that a map entry can reference a user-defined mapper.

Depends on #117046.

Added: 
    

Modified: 
    flang/lib/Lower/OpenMP/Utils.cpp
    flang/lib/Lower/OpenMP/Utils.h
    flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
    flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 35722fa7d1b12..fa1975dac789b 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -125,7 +125,7 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
                 llvm::ArrayRef<mlir::Value> members,
                 mlir::ArrayAttr membersIndex, uint64_t mapType,
                 mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
-                bool partialMap) {
+                bool partialMap, mlir::FlatSymbolRefAttr mapperId) {
   if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
     baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
     retTy = baseAddr.getType();
@@ -144,6 +144,7 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
   mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
       loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds,
       builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
+      mapperId,
       builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
       builder.getStringAttr(name), builder.getBoolAttr(partialMap));
   return op;

diff  --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index f2e378443e5f2..3943eb633b04e 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -116,7 +116,8 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
                 llvm::ArrayRef<mlir::Value> members,
                 mlir::ArrayAttr membersIndex, uint64_t mapType,
                 mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
-                bool partialMap = false);
+                bool partialMap = false,
+                mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr());
 
 void insertChildMapInfoIntoParent(
     Fortran::lower::AbstractConverter &converter,

diff  --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index e7c1d1d9d560f..beea7543e54b3 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -184,6 +184,7 @@ class MapInfoFinalizationPass
         /*members=*/mlir::SmallVector<mlir::Value>{},
         /*membersIndex=*/mlir::ArrayAttr{}, bounds,
         builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
+        /*mapperId*/ mlir::FlatSymbolRefAttr(),
         builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
             mlir::omp::VariableCaptureKind::ByRef),
         /*name=*/builder.getStringAttr(""),
@@ -329,7 +330,8 @@ class MapInfoFinalizationPass
             builder.getIntegerAttr(
                 builder.getIntegerType(64, false),
                 getDescriptorMapType(op.getMapType().value_or(0), target)),
-            op.getMapCaptureTypeAttr(), op.getNameAttr(),
+            /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getMapCaptureTypeAttr(),
+            op.getNameAttr(),
             /*partial_map=*/builder.getBoolAttr(false));
     op.replaceAllUsesWith(newDescParentMapOp.getResult());
     op->erase();
@@ -623,6 +625,7 @@ class MapInfoFinalizationPass
                   /*members=*/mlir::ValueRange{},
                   /*members_index=*/mlir::ArrayAttr{},
                   /*bounds=*/bounds, op.getMapTypeAttr(),
+                  /*mapperId*/ mlir::FlatSymbolRefAttr(),
                   builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
                       mlir::omp::VariableCaptureKind::ByRef),
                   builder.getStringAttr(op.getNameAttr().strref() + "." +

diff  --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
index 963ae863c1fc5..97ea463a3c495 100644
--- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -91,6 +91,7 @@ class MapsForPrivatizedSymbolsPass
         /*bounds=*/ValueRange{},
         builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
                                mapTypeTo),
+        /*mapperId*/ mlir::FlatSymbolRefAttr(),
         builder.getAttr<omp::VariableCaptureKindAttr>(
             omp::VariableCaptureKind::ByRef),
         StringAttr(), builder.getBoolAttr(false));

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index e1eef30c0dd06..2d8e022190f62 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1023,6 +1023,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
                        OptionalAttr<IndexListArrayAttr>:$members_index,
                        Variadic<OpenMP_MapBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
                        OptionalAttr<UI64Attr>:$map_type,
+                       OptionalAttr<FlatSymbolRefAttr>:$mapper_id,
                        OptionalAttr<VariableCaptureKindAttr>:$map_capture_type,
                        OptionalAttr<StrAttr>:$name,
                        DefaultValuedAttr<BoolAttr, "false">:$partial_map);
@@ -1076,6 +1077,8 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
     - 'map_type': OpenMP map type for this map capture, for example: from, to and
        always. It's a bitfield composed of the OpenMP runtime flags stored in
        OpenMPOffloadMappingFlags.
+    - 'mapper_id': OpenMP mapper map-type modifier for this map capture. It is a symbol
+       reference naming the user-defined mapper to be used for mapping.
     - 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla
        this can affect how the variable is lowered.
     - `name`: Holds the name of variable as specified in user clause (including bounds).
@@ -1087,6 +1090,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
     `var_ptr` `(` $var_ptr `:` type($var_ptr) `,` $var_type `)`
     oilist(
         `var_ptr_ptr` `(` $var_ptr_ptr `:` type($var_ptr_ptr) `)`
+      | `mapper` `(` $mapper_id `)`
       | `map_clauses` `(` custom<MapClause>($map_type) `)`
       | `capture` `(` custom<CaptureType>($map_capture_type) `)`
       | `members` `(` $members `:` custom<MembersIndex>($members_index) `:` type($members) `)`

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 60cf4045344a6..62e1c4c3ed3b1 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1632,7 +1632,13 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
 
         to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
       }
-    } else {
+
+      if (mapInfoOp.getMapperId() &&
+          !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
+              mapInfoOp, mapInfoOp.getMapperIdAttr())) {
+        return emitError(op->getLoc(), "invalid mapper id");
+      }
+    } else if (!isa<DeclareMapperInfoOp>(op)) {
       emitError(op->getLoc(), "map argument is not a map entry operation");
     }
   }

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index a2d34a88d2ebe..d7f468bed3d3d 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2849,3 +2849,13 @@ func.func @missing_workshare(%idx : index) {
   ^bb0(%arg0: !llvm.ptr):
     omp.terminator
   }
+
+// -----
+llvm.func @invalid_mapper(%0 : !llvm.ptr) {
+  %1 = omp.map.info var_ptr(%0 : !llvm.ptr, !llvm.struct<"my_type", (i32)>) mapper(@my_mapper) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
+  // expected-error @below {{invalid mapper id}}
+  omp.target_data map_entries(%1 : !llvm.ptr) {
+    omp.terminator
+  }
+  llvm.return
+}

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index f00d3d3426631..e318afbebbf0c 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2546,13 +2546,13 @@ func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
   // CHECK: %[[C_12:.*]] = llvm.mlir.constant(2 : index) : i64
   // CHECK: %[[C_13:.*]] = llvm.mlir.constant(2 : index) : i64
   // CHECK: %[[BOUNDS1:.*]] = omp.map.bounds   lower_bound(%[[C_11]] : i64) upper_bound(%[[C_10]] : i64) stride(%[[C_12]] : i64) start_idx(%[[C_13]] : i64)
-  // CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.array<10 x i32>)   map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%[[BOUNDS1]]) -> !llvm.ptr {name = ""}
+  // CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.array<10 x i32>)   mapper(@my_mapper) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%[[BOUNDS1]]) -> !llvm.ptr {name = ""}
     %6 = llvm.mlir.constant(9 : index) : i64
     %7 = llvm.mlir.constant(1 : index) : i64
     %8 = llvm.mlir.constant(2 : index) : i64
     %9 = llvm.mlir.constant(2 : index) : i64
     %10 = omp.map.bounds   lower_bound(%7 : i64) upper_bound(%6 : i64) stride(%8 : i64) start_idx(%9 : i64)
-    %mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.array<10 x i32>)   map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%10) -> !llvm.ptr {name = ""}
+    %mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.array<10 x i32>)   mapper(@my_mapper) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%10) -> !llvm.ptr {name = ""}
 
     // CHECK: omp.target map_entries(%[[MAP0]] -> {{.*}}, %[[MAP1]] -> {{.*}} : !llvm.ptr, !llvm.ptr)
     omp.target map_entries(%mapv1 -> %arg2, %mapv2 -> %arg3 : !llvm.ptr, !llvm.ptr) {


        


More information about the Mlir-commits mailing list