[Mlir-commits] [flang] [mlir] [mlir][OpenMP] Annotate `private` vars with `map_idx` when needed (PR #116770)

Kareem Ergawy llvmlistbot at llvm.org
Tue Nov 19 00:50:52 PST 2024


https://github.com/ergawy created https://github.com/llvm/llvm-project/pull/116770

This PR extends the MLIR representation of `omp.target` ops by adding a `map_idx` annotation to `private` vars. The annotation stores the zero-based index of the map info operand (among the op's `map_entries`) that corresponds to the private var. If the variable does not have a map operand, the `map_idx` annotation is either not present at all or its value is `-1`.

>From e218871c9b66ed588887a647cb0efcc6b0b40997 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Mon, 18 Nov 2024 23:12:31 -0600
Subject: [PATCH] [mlir][OpenMP] Annotate `private` vars with `map_idx` when
 needed

This PR extends the MLIR representation of `omp.target` ops by adding a
`map_idx` annotation to `private` vars. The annotation stores the
zero-based index of the map info operand (among the op's `map_entries`)
that corresponds to the private var. If the variable does not have a map
operand, the `map_idx` annotation is either not present at all or its
value is `-1`.
---
 .../OpenMP/MapsForPrivatizedSymbols.cpp       |  15 ++-
 .../target-private-multiple-variables.f90     |  12 +--
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |   6 +-
 mlir/include/mlir/Dialect/OpenMP/Utils.h      |  19 ++++
 mlir/lib/Dialect/OpenMP/CMakeLists.txt        |   1 +
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 100 ++++++++++++------
 mlir/lib/Dialect/OpenMP/IR/Utils.cpp          |  22 ++++
 mlir/test/Dialect/OpenMP/ops.mlir             |  24 +++++
 8 files changed, 158 insertions(+), 41 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/OpenMP/Utils.h
 create mode 100644 mlir/lib/Dialect/OpenMP/IR/Utils.cpp

diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
index 289e648eed8546..6e537300dfb7f1 100644
--- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -1,5 +1,4 @@
-//===- MapsForPrivatizedSymbols.cpp
-//-----------------------------------------===//
+//===- MapsForPrivatizedSymbols.cpp ---------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -28,8 +27,10 @@
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/Utils.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
@@ -124,6 +125,8 @@ class MapsForPrivatizedSymbolsPass
       if (targetOp.getPrivateVars().empty())
         return;
       OperandRange privVars = targetOp.getPrivateVars();
+      llvm::SmallVector<int64_t> privVarMapIdx;
+
       std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
       SmallVector<omp::MapInfoOp, 4> mapInfoOps;
       for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
@@ -133,17 +136,25 @@ class MapsForPrivatizedSymbolsPass
             SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
                 targetOp, privatizerName);
         if (!privatizerNeedsMap(privatizer)) {
+          privVarMapIdx.push_back(-1);
           continue;
         }
+
+        privVarMapIdx.push_back(targetOp.getMapVars().size() +
+                                mapInfoOps.size());
+
         builder.setInsertionPoint(targetOp);
         Location loc = targetOp.getLoc();
         omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
         mapInfoOps.push_back(mapInfoOp);
+
         LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
         LLVM_DEBUG(mapInfoOp.dump());
       }
       if (!mapInfoOps.empty()) {
         mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps});
+        targetOp.setPrivateMapsAttr(mlir::omp::utils::makeI64ArrayAttr(
+            privVarMapIdx, targetOp.getContext()));
       }
     });
     if (!mapInfoOpsForTarget.empty()) {
diff --git a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
index b0c76ff3845f83..602e98975e9dc5 100644
--- a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
+++ b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
@@ -171,12 +171,12 @@ end subroutine target_allocatable
 ! CHECK_SAME        %[[CHAR_VAR_DESC_MAP]] -> %[[MAPPED_ARG3:.[^,]+]] :
 ! CHECK-SAME        !fir.ref<i32>, !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.array<?xf32>>>, !fir.ref<!fir.boxchar<1>>)
 ! CHECK-SAME:     private(
-! CHECK-SAME:       @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]],
-! CHECK-SAME:       @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]],
-! CHECK-SAME:       @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]],
-! CHECK-SAME:       @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]],
-! CHECK-SAME:       @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]],
-! CHECK-SAME:       @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] :
+! CHECK-SAME:       @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]] [map_idx=1],
+! CHECK-SAME:       @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME:       @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME:       @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]] [map_idx=2],
+! CHECK-SAME:       @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME:       @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] [map_idx=3] :
 ! CHECK-SAME:       !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<f32>, !fir.ref<i64>, !fir.box<!fir.array<?xf32>>, !fir.ref<complex<f32>>, !fir.boxchar<1>) {
 ! CHECK-NOT:      fir.alloca
 ! CHECK:          hlfir.declare %[[ALLOC_ARG]]
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 156e6eb371b85d..31ecbea8e0c211 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1227,6 +1227,9 @@ def TargetOp : OpenMP_Op<"target", traits = [
     a device, if it is 0 then the target region is executed on the host device.
   }] # clausesDescription;
 
+  let arguments = !con(clausesArgs,
+                       (ins OptionalAttr<I64ArrayAttr>:$private_maps));
+
   let builders = [
     OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
   ];
@@ -1239,7 +1242,8 @@ def TargetOp : OpenMP_Op<"target", traits = [
     custom<InReductionMapPrivateRegion>(
         $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
-        $private_vars, type($private_vars), $private_syms) attr-dict
+        $private_vars, type($private_vars), $private_syms, $private_maps)
+        attr-dict
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/OpenMP/Utils.h b/mlir/include/mlir/Dialect/OpenMP/Utils.h
new file mode 100644
index 00000000000000..f79e10b1e5ab38
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Utils.h
@@ -0,0 +1,19 @@
+//===- Utils.h - Utils for the OpenMP MLIR Dialect --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_UTILS_H_
+#define MLIR_DIALECT_OPENMP_UTILS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+
+namespace mlir::omp::utils {
+mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
+                                 mlir::MLIRContext *context);
+} // namespace mlir::omp::utils
+
+#endif // MLIR_DIALECT_OPENMP_UTILS_H_
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 57a6d3445c151c..809bd1306563bd 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIROpenMPDialect
   IR/OpenMPDialect.cpp
+  IR/Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 94e71e089d4b18..4a13272b8f4a83 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
+#include "mlir/Dialect/OpenMP/Utils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -487,9 +488,11 @@ struct PrivateParseArgs {
   llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
   llvm::SmallVectorImpl<Type> &types;
   ArrayAttr &syms;
+  ArrayAttr *mapIndices;
   PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
-                   SmallVectorImpl<Type> &types, ArrayAttr &syms)
-      : vars(vars), types(types), syms(syms) {}
+                   SmallVectorImpl<Type> &types, ArrayAttr &syms,
+                   ArrayAttr *mapIndices=nullptr)
+      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
 };
 struct ReductionParseArgs {
   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
@@ -517,8 +520,10 @@ static ParseResult parseClauseWithRegionArgs(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
     SmallVectorImpl<Type> &types,
     SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
-    ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr) {
+    ArrayAttr *symbols = nullptr, ArrayAttr *mapIndices = nullptr,
+    DenseBoolArrayAttr *byref = nullptr) {
   SmallVector<SymbolRefAttr> symbolVec;
+  SmallVector<int64_t> mapIndicesVec;
   SmallVector<bool> isByRefVec;
   unsigned regionArgOffset = regionPrivateArgs.size();
 
@@ -538,6 +543,16 @@ static ParseResult parseClauseWithRegionArgs(
             parser.parseArgument(regionPrivateArgs.emplace_back()))
           return failure();
 
+        if (mapIndices) {
+          if (parser.parseOptionalLSquare().succeeded()) {
+            if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
+                parser.parseInteger(mapIndicesVec.emplace_back()) ||
+                parser.parseRSquare())
+              return failure();
+          } else
+            mapIndicesVec.push_back(-1);
+        }
+
         return success();
       }))
     return failure();
@@ -571,6 +586,9 @@ static ParseResult parseClauseWithRegionArgs(
     *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
   }
 
+  if (!mapIndicesVec.empty())
+    *mapIndices = utils::makeI64ArrayAttr(mapIndicesVec, parser.getContext());
+
   if (byref)
     *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
 
@@ -595,14 +613,14 @@ static ParseResult parseBlockArgClause(
 static ParseResult parseBlockArgClause(
     OpAsmParser &parser,
     llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
-    StringRef keyword, std::optional<PrivateParseArgs> reductionArgs) {
+    StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
   if (succeeded(parser.parseOptionalKeyword(keyword))) {
-    if (!reductionArgs)
+    if (!privateArgs)
       return failure();
 
-    if (failed(parseClauseWithRegionArgs(parser, reductionArgs->vars,
-                                         reductionArgs->types, entryBlockArgs,
-                                         &reductionArgs->syms)))
+    if (failed(parseClauseWithRegionArgs(
+            parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
+            &privateArgs->syms, privateArgs->mapIndices)))
       return failure();
   }
   return success();
@@ -618,7 +636,8 @@ static ParseResult parseBlockArgClause(
 
     if (failed(parseClauseWithRegionArgs(
             parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
-            &reductionArgs->syms, &reductionArgs->byref)))
+            &reductionArgs->syms, /*mapIndices=*/nullptr,
+            &reductionArgs->byref)))
       return failure();
   }
   return success();
@@ -674,12 +693,13 @@ static ParseResult parseInReductionMapPrivateRegion(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
     SmallVectorImpl<Type> &mapTypes,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
-    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    ArrayAttr &privateMaps) {
   AllRegionParseArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, &privateMaps);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -776,8 +796,10 @@ struct PrivatePrintArgs {
   ValueRange vars;
   TypeRange types;
   ArrayAttr syms;
-  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms)
-      : vars(vars), types(types), syms(syms) {}
+  ArrayAttr mapIndices;
+  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
+                   ArrayAttr mapIndices)
+      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
 };
 struct ReductionPrintArgs {
   ValueRange vars;
@@ -804,6 +826,7 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
                                       ValueRange argsSubrange,
                                       ValueRange operands, TypeRange types,
                                       ArrayAttr symbols = nullptr,
+                                      ArrayAttr mapIndices = nullptr,
                                       DenseBoolArrayAttr byref = nullptr) {
   if (argsSubrange.empty())
     return;
@@ -815,21 +838,31 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
     symbols = ArrayAttr::get(ctx, values);
   }
 
+  if (!mapIndices) {
+    llvm::SmallVector<Attribute> values(operands.size(), nullptr);
+    mapIndices = ArrayAttr::get(ctx, values);
+  }
+
   if (!byref) {
     mlir::SmallVector<bool> values(operands.size(), false);
     byref = DenseBoolArrayAttr::get(ctx, values);
   }
 
-  llvm::interleaveComma(
-      llvm::zip_equal(operands, argsSubrange, symbols, byref.asArrayRef()), p,
-      [&p](auto t) {
-        auto [op, arg, sym, isByRef] = t;
-        if (isByRef)
-          p << "byref ";
-        if (sym)
-          p << sym << " ";
-        p << op << " -> " << arg;
-      });
+  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
+                                        mapIndices, byref.asArrayRef()),
+                        p, [&p](auto t) {
+                          auto [op, arg, sym, map, isByRef] = t;
+                          if (isByRef)
+                            p << "byref ";
+                          if (sym)
+                            p << sym << " ";
+
+                          p << op << " -> " << arg;
+
+                          if (map)
+                            p << " [map_idx="
+                              << llvm::cast<IntegerAttr>(map).getInt() << "]";
+                        });
   p << " : ";
   llvm::interleaveComma(types, p);
   p << ") ";
@@ -849,7 +882,7 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
   if (privateArgs)
     printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
                               privateArgs->vars, privateArgs->types,
-                              privateArgs->syms);
+                              privateArgs->syms, privateArgs->mapIndices);
 }
 
 static void
@@ -859,7 +892,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
   if (reductionArgs)
     printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
                               reductionArgs->vars, reductionArgs->types,
-                              reductionArgs->syms, reductionArgs->byref);
+                              reductionArgs->syms, nullptr,
+                              reductionArgs->byref);
 }
 
 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
@@ -891,12 +925,13 @@ static void printInReductionMapPrivateRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
-    ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
+    ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
+    ArrayAttr privateMaps) {
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -908,7 +943,7 @@ static void printInReductionPrivateRegion(
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -921,7 +956,7 @@ static void printInReductionPrivateReductionRegion(
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms);
   printBlockArgRegion(p, op, region, args);
@@ -931,7 +966,7 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
                                ValueRange privateVars, TypeRange privateTypes,
                                ArrayAttr privateSyms) {
   AllRegionPrintArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -941,7 +976,7 @@ static void printPrivateReductionRegion(
     TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
     ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms);
   printBlockArgRegion(p, op, region, args);
@@ -1656,7 +1691,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
                   /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
-                  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit);
+                  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
+                  /*private_maps=*/nullptr);
 }
 
 LogicalResult TargetOp::verify() {
diff --git a/mlir/lib/Dialect/OpenMP/IR/Utils.cpp b/mlir/lib/Dialect/OpenMP/IR/Utils.cpp
new file mode 100644
index 00000000000000..64e84817c818b2
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/IR/Utils.cpp
@@ -0,0 +1,22 @@
+//===- Utils.cpp - Utils for the OpenMP MLIR Dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Utils.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+namespace mlir::omp::utils {
+mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
+                                 mlir::MLIRContext *context) {
+  llvm::SmallVector<mlir::Attribute, 4> attrs;
+  attrs.reserve(values.size());
+  for (auto &v : values)
+    attrs.push_back(mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64),
+                                           mlir::APInt(64, v)));
+  return mlir::ArrayAttr::get(context, attrs);
+}
+} // namespace mlir::omp::utils
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index c25a6ef4b4849b..94c63dd8e9aa0e 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2750,6 +2750,30 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
   return
 }
 
+// CHECK-LABEL: omp_target_private_with_map_idx
+func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
+  %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
+  %mapv2 = omp.map.info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // CHECK: omp.target
+
+  // CHECK-SAME: map_entries(
+  // CHECK-SAME:   %{{[^[:space:]]+}} -> %[[MAP1_ARG:[^[:space:]]+]],
+  // CHECK-SAME:   %{{[^[:space:]]+}} -> %[[MAP2_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   : memref<?xi32>, memref<?xi32>
+  // CHECK-SAME: )
+
+  // CHECK-SAME: private(
+  // CHECK-SAME:   @x.privatizer %{{[^[:space:]]+}} -> %[[PRIV_ARG:[^[:space:]]+]] [map_idx=1]
+  // CHECK-SAME:   : !llvm.ptr
+  // CHECK-SAME: )
+  omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) private(@x.privatizer %priv_var -> %priv_arg [map_idx=1] : !llvm.ptr) {
+    omp.terminator
+  }
+
+  return
+}
+
 // CHECK-LABEL: omp_loop
 func.func @omp_loop(%lb : index, %ub : index, %step : index) {
   // CHECK: omp.loop {



More information about the Mlir-commits mailing list