[Mlir-commits] [mlir] [mlir][OpenMP] add attribute for privatization barrier (PR #140089)

Tom Eccles llvmlistbot at llvm.org
Thu May 15 09:05:59 PDT 2025


https://github.com/tblah created https://github.com/llvm/llvm-project/pull/140089

A barrier is needed at the end of initialization/copying of private variables if any of those variables is lastprivate. This ensures that all firstprivate variables receive the original value of the variable before the lastprivate clause overwrites it.

Previously this barrier was added by the flang frontend, but there is not a reliable way to put the barrier in the correct place for delayed privatization, and the OpenMP dialect could some day have other users. It is important that there are safe ways to use the constructs available in the dialect.

lastprivate is currently not modelled in the OpenMP dialect, and so there is no way to reliably determine whether there were lastprivate variables. Therefore the frontend will have to provide this information through this new attribute.

Part of a series of patches to fix
https://github.com/llvm/llvm-project/issues/136357

>From 5e2e1a76d93812fdb1ba6730c819b40aaeb99cb0 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 14 May 2025 16:49:32 +0000
Subject: [PATCH] [mlir][OpenMP] add attribute for privatization barrier

A barrier is needed at the end of initialization/copying of private
variables if any of those variables is lastprivate. This ensures that
all firstprivate variables receive the original value of the variable
before the lastprivate clause overwrites it.

Previously this barrier was added by the flang frontend, but there is not
a reliable way to put the barrier in the correct place for delayed
privatization, and the OpenMP dialect could some day have other users.
It is important that there are safe ways to use the constructs available
in the dialect.

lastprivate is currently not modelled in the OpenMP dialect, and so
there is no way to reliably determine whether there were lastprivate
variables. Therefore the frontend will have to provide this information
through this new attribute.

Part of a series of patches to fix
https://github.com/llvm/llvm-project/issues/136357
---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |   5 +-
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  37 ++--
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    |   1 +
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 196 +++++++++++-------
 mlir/test/Dialect/OpenMP/ops.mlir             |  17 ++
 5 files changed, 159 insertions(+), 97 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index f8e880ea43b75..16c14ef085d6d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1102,7 +1102,10 @@ class OpenMP_PrivateClauseSkip<
 
   let arguments = (ins
     Variadic<AnyType>:$private_vars,
-    OptionalAttr<SymbolRefArrayAttr>:$private_syms
+    OptionalAttr<SymbolRefArrayAttr>:$private_syms,
+    // Set this attribute if a barrier is needed after initialization and
+    // copying of lastprivate variables.
+    UnitAttr:$private_needs_barrier
   );
 
   // TODO: Add description.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 5a79fbf77a268..036c6a6e350a8 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -213,8 +213,8 @@ def ParallelOp : OpenMP_Op<"parallel", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -258,8 +258,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -317,8 +317,8 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -350,7 +350,7 @@ def SingleOp : OpenMP_Op<"single", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateRegion>($region, $private_vars, type($private_vars),
-        $private_syms) attr-dict
+        $private_syms, $private_needs_barrier) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -505,8 +505,8 @@ def LoopOp : OpenMP_Op<"loop", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let builders = [
@@ -557,8 +557,8 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -611,8 +611,8 @@ def SimdOp : OpenMP_Op<"simd", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -690,7 +690,7 @@ def DistributeOp : OpenMP_Op<"distribute", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateRegion>($region, $private_vars, type($private_vars),
-        $private_syms) attr-dict
+        $private_syms, $private_needs_barrier) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -740,7 +740,7 @@ def TaskOp
     custom<InReductionPrivateRegion>(
         $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $private_vars,
-        type($private_vars), $private_syms) attr-dict
+        type($private_vars), $private_syms, $private_needs_barrier) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -816,8 +816,9 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [
     custom<InReductionPrivateReductionRegion>(
         $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $private_vars,
-        type($private_vars), $private_syms, $reduction_mod, $reduction_vars,
-        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
+        type($private_vars), $private_syms, $private_needs_barrier,
+        $reduction_mod, $reduction_vars, type($reduction_vars),
+        $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let extraClassDeclaration = [{
@@ -1324,7 +1325,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
         $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
         type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
         $map_vars, type($map_vars), $private_vars, type($private_vars),
-        $private_syms, $private_maps) attr-dict
+        $private_syms, $private_needs_barrier, $private_maps) attr-dict
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 233739e1d6d91..71786e856c6db 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -450,6 +450,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* num_threads = */ numThreadsVar,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
+        /* private_needs_barrier = */ nullptr,
         /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
         /* reduction_mod = */ nullptr,
         /* reduction_vars = */ llvm::SmallVector<Value>{},
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 2bf7aaa46db11..57a54f21fe9de 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -581,11 +581,14 @@ struct PrivateParseArgs {
   llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
   llvm::SmallVectorImpl<Type> &types;
   ArrayAttr &syms;
+  UnitAttr &needsBarrier;
   DenseI64ArrayAttr *mapIndices;
   PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
                    SmallVectorImpl<Type> &types, ArrayAttr &syms,
+                   UnitAttr &needsBarrier,
                    DenseI64ArrayAttr *mapIndices = nullptr)
-      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
+      : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
+        mapIndices(mapIndices) {}
 };
 
 struct ReductionParseArgs {
@@ -613,6 +616,10 @@ struct AllRegionParseArgs {
 };
 } // namespace
 
+static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
+  return "private_barrier";
+}
+
 static ParseResult parseClauseWithRegionArgs(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
@@ -620,7 +627,8 @@ static ParseResult parseClauseWithRegionArgs(
     SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
     ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
     DenseBoolArrayAttr *byref = nullptr,
-    ReductionModifierAttr *modifier = nullptr) {
+    ReductionModifierAttr *modifier = nullptr,
+    UnitAttr *needsBarrier = nullptr) {
   SmallVector<SymbolRefAttr> symbolVec;
   SmallVector<int64_t> mapIndicesVec;
   SmallVector<bool> isByRefVec;
@@ -688,6 +696,12 @@ static ParseResult parseClauseWithRegionArgs(
   if (parser.parseRParen())
     return failure();
 
+  if (needsBarrier) {
+    if (parser.parseOptionalKeyword(getPrivateNeedsBarrierSpelling())
+            .succeeded())
+      *needsBarrier = mlir::UnitAttr::get(parser.getContext());
+  }
+
   auto *argsBegin = regionPrivateArgs.begin();
   MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
                                argsBegin + regionArgOffset + types.size());
@@ -735,7 +749,8 @@ static ParseResult parseBlockArgClause(
 
     if (failed(parseClauseWithRegionArgs(
             parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
-            &privateArgs->syms, privateArgs->mapIndices)))
+            &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
+            /*modifier=*/nullptr, &privateArgs->needsBarrier)))
       return failure();
   }
   return success();
@@ -824,7 +839,7 @@ static ParseResult parseTargetOpRegion(
     SmallVectorImpl<Type> &mapTypes,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
-    DenseI64ArrayAttr &privateMaps) {
+    UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
   AllRegionParseArgs args;
   args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
@@ -832,7 +847,7 @@ static ParseResult parseTargetOpRegion(
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
-                           &privateMaps);
+                           privateNeedsBarrier, &privateMaps);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -842,11 +857,13 @@ static ParseResult parseInReductionPrivateRegion(
     SmallVectorImpl<Type> &inReductionTypes,
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
-    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    UnitAttr &privateNeedsBarrier) {
   AllRegionParseArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -857,14 +874,15 @@ static ParseResult parseInReductionPrivateReductionRegion(
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
-    ReductionModifierAttr &reductionMod,
+    UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
     ArrayAttr &reductionSyms) {
   AllRegionParseArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms, &reductionMod);
   return parseBlockArgRegion(parser, region, args);
@@ -873,9 +891,11 @@ static ParseResult parseInReductionPrivateReductionRegion(
 static ParseResult parsePrivateRegion(
     OpAsmParser &parser, Region &region,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
-    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    UnitAttr &privateNeedsBarrier) {
   AllRegionParseArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -883,12 +903,13 @@ static ParseResult parsePrivateReductionRegion(
     OpAsmParser &parser, Region &region,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
-    ReductionModifierAttr &reductionMod,
+    UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
     ArrayAttr &reductionSyms) {
   AllRegionParseArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms, &reductionMod);
   return parseBlockArgRegion(parser, region, args);
@@ -931,10 +952,12 @@ struct PrivatePrintArgs {
   ValueRange vars;
   TypeRange types;
   ArrayAttr syms;
+  UnitAttr needsBarrier;
   DenseI64ArrayAttr mapIndices;
   PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
-                   DenseI64ArrayAttr mapIndices)
-      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
+                   UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
+      : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
+        mapIndices(mapIndices) {}
 };
 struct ReductionPrintArgs {
   ValueRange vars;
@@ -964,7 +987,7 @@ static void printClauseWithRegionArgs(
     ValueRange argsSubrange, ValueRange operands, TypeRange types,
     ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
     DenseBoolArrayAttr byref = nullptr,
-    ReductionModifierAttr modifier = nullptr) {
+    ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
   if (argsSubrange.empty())
     return;
 
@@ -1006,6 +1029,9 @@ static void printClauseWithRegionArgs(
   p << " : ";
   llvm::interleaveComma(types, p);
   p << ") ";
+
+  if (needsBarrier)
+    p << getPrivateNeedsBarrierSpelling() << " ";
 }
 
 static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
@@ -1020,9 +1046,10 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
                                 StringRef clauseName, ValueRange argsSubrange,
                                 std::optional<PrivatePrintArgs> privateArgs) {
   if (privateArgs)
-    printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
-                              privateArgs->vars, privateArgs->types,
-                              privateArgs->syms, privateArgs->mapIndices);
+    printClauseWithRegionArgs(
+        p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
+        privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
+        /*modifier=*/nullptr, privateArgs->needsBarrier);
 }
 
 static void
@@ -1068,23 +1095,23 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
 
 // These parseXyz functions correspond to the custom<Xyz> definitions
 // in the .td file(s).
-static void
-printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region,
-                    ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
-                    ValueRange hostEvalVars, TypeRange hostEvalTypes,
-                    ValueRange inReductionVars, TypeRange inReductionTypes,
-                    DenseBoolArrayAttr inReductionByref,
-                    ArrayAttr inReductionSyms, ValueRange mapVars,
-                    TypeRange mapTypes, ValueRange privateVars,
-                    TypeRange privateTypes, ArrayAttr privateSyms,
-                    DenseI64ArrayAttr privateMaps) {
+static void printTargetOpRegion(
+    OpAsmPrinter &p, Operation *op, Region &region,
+    ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
+    ValueRange hostEvalVars, TypeRange hostEvalTypes,
+    ValueRange inReductionVars, TypeRange inReductionTypes,
+    DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
+    ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
+    TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
+    DenseI64ArrayAttr privateMaps) {
   AllRegionPrintArgs args;
   args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier, privateMaps);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -1092,11 +1119,12 @@ static void printInReductionPrivateRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
-    ArrayAttr privateSyms) {
+    ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier,
                            /*mapIndices=*/nullptr);
   printBlockArgRegion(p, op, region, args);
 }
@@ -1105,13 +1133,15 @@ static void printInReductionPrivateReductionRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
-    ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
-    ValueRange reductionVars, TypeRange reductionTypes,
-    DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
+    ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
+    ReductionModifierAttr reductionMod, ValueRange reductionVars,
+    TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
+    ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier,
                            /*mapIndices=*/nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms, reductionMod);
@@ -1120,21 +1150,24 @@ static void printInReductionPrivateReductionRegion(
 
 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
                                ValueRange privateVars, TypeRange privateTypes,
-                               ArrayAttr privateSyms) {
+                               ArrayAttr privateSyms,
+                               UnitAttr privateNeedsBarrier) {
   AllRegionPrintArgs args;
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier,
                            /*mapIndices=*/nullptr);
   printBlockArgRegion(p, op, region, args);
 }
 
 static void printPrivateReductionRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
-    TypeRange privateTypes, ArrayAttr privateSyms,
+    TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
     ReductionModifierAttr reductionMod, ValueRange reductionVars,
     TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
     ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier,
                            /*mapIndices=*/nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms, reductionMod);
@@ -1884,7 +1917,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
                   /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
-                  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
+                  makeArrayAttr(ctx, clauses.privateSyms),
+                  clauses.privateNeedsBarrier, clauses.threadLimit,
                   /*private_maps=*/nullptr);
 }
 
@@ -2149,7 +2183,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
   ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
                     /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
-                    /*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
+                    /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
+                    /*proc_bind_kind=*/nullptr,
                     /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
                     /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
   state.addAttributes(attributes);
@@ -2161,8 +2196,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
   ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                     clauses.ifExpr, clauses.numThreads, clauses.privateVars,
                     makeArrayAttr(ctx, clauses.privateSyms),
-                    clauses.procBindKind, clauses.reductionMod,
-                    clauses.reductionVars,
+                    clauses.privateNeedsBarrier, clauses.procBindKind,
+                    clauses.reductionMod, clauses.reductionVars,
                     makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                     makeArrayAttr(ctx, clauses.reductionSyms));
 }
@@ -2266,11 +2301,12 @@ static bool opInGlobalImplicitParallelRegion(Operation *op) {
 void TeamsOp::build(OpBuilder &builder, OperationState &state,
                     const TeamsOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  // TODO Store clauses in op: privateVars, privateSyms.
+  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
   TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                  clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
                  /*private_vars=*/{}, /*private_syms=*/nullptr,
-                 clauses.reductionMod, clauses.reductionVars,
+                 /*private_needs_barrier=*/nullptr, clauses.reductionMod,
+                 clauses.reductionVars,
                  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                  makeArrayAttr(ctx, clauses.reductionSyms),
                  clauses.threadLimit);
@@ -2327,11 +2363,11 @@ OperandRange SectionOp::getReductionVars() {
 void SectionsOp::build(OpBuilder &builder, OperationState &state,
                        const SectionsOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  // TODO Store clauses in op: privateVars, privateSyms.
+  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
   SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                     clauses.nowait, /*private_vars=*/{},
-                    /*private_syms=*/nullptr, clauses.reductionMod,
-                    clauses.reductionVars,
+                    /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
+                    clauses.reductionMod, clauses.reductionVars,
                     makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                     makeArrayAttr(ctx, clauses.reductionSyms));
 }
@@ -2363,11 +2399,12 @@ LogicalResult SectionsOp::verifyRegions() {
 void SingleOp::build(OpBuilder &builder, OperationState &state,
                      const SingleOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  // TODO Store clauses in op: privateVars, privateSyms.
+  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
   SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                   clauses.copyprivateVars,
                   makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
-                  /*private_vars=*/{}, /*private_syms=*/nullptr);
+                  /*private_vars=*/{}, /*private_syms=*/nullptr,
+                  /*private_needs_barrier=*/nullptr);
 }
 
 LogicalResult SingleOp::verify() {
@@ -2443,8 +2480,9 @@ void LoopOp::build(OpBuilder &builder, OperationState &state,
   MLIRContext *ctx = builder.getContext();
 
   LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
-                makeArrayAttr(ctx, clauses.privateSyms), clauses.order,
-                clauses.orderMod, clauses.reductionMod, clauses.reductionVars,
+                makeArrayAttr(ctx, clauses.privateSyms),
+                clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
+                clauses.reductionMod, clauses.reductionVars,
                 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                 makeArrayAttr(ctx, clauses.reductionSyms));
 }
@@ -2472,6 +2510,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
         /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
         /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
+        /*private_needs_barrier=*/false,
         /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
         /*reduction_byref=*/nullptr,
         /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
@@ -2483,18 +2522,17 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
 void WsloopOp::build(OpBuilder &builder, OperationState &state,
                      const WsloopOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  // TODO: Store clauses in op: allocateVars, allocatorVars, privateVars,
-  // privateSyms.
-  WsloopOp::build(builder, state,
-                  /*allocate_vars=*/{}, /*allocator_vars=*/{},
-                  clauses.linearVars, clauses.linearStepVars, clauses.nowait,
-                  clauses.order, clauses.orderMod, clauses.ordered,
-                  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
-                  clauses.reductionMod, clauses.reductionVars,
-                  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-                  makeArrayAttr(ctx, clauses.reductionSyms),
-                  clauses.scheduleKind, clauses.scheduleChunk,
-                  clauses.scheduleMod, clauses.scheduleSimd);
+  // TODO: Store clauses in op: allocateVars, allocatorVars
+  WsloopOp::build(
+      builder, state,
+      /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
+      clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
+      clauses.ordered, clauses.privateVars,
+      makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+      clauses.reductionMod, clauses.reductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+      makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
+      clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
 }
 
 LogicalResult WsloopOp::verify() {
@@ -2534,14 +2572,14 @@ LogicalResult WsloopOp::verifyRegions() {
 void SimdOp::build(OpBuilder &builder, OperationState &state,
                    const SimdOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  // TODO Store clauses in op: linearVars, linearStepVars, privateVars,
-  // privateSyms.
+  // TODO Store clauses in op: linearVars, linearStepVars
   SimdOp::build(builder, state, clauses.alignedVars,
                 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
                 /*linear_vars=*/{}, /*linear_step_vars=*/{},
                 clauses.nontemporalVars, clauses.order, clauses.orderMod,
                 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
-                clauses.reductionMod, clauses.reductionVars,
+                clauses.privateNeedsBarrier, clauses.reductionMod,
+                clauses.reductionVars,
                 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
                 clauses.simdlen);
@@ -2591,7 +2629,8 @@ void DistributeOp::build(OpBuilder &builder, OperationState &state,
                       clauses.allocatorVars, clauses.distScheduleStatic,
                       clauses.distScheduleChunkSize, clauses.order,
                       clauses.orderMod, clauses.privateVars,
-                      makeArrayAttr(builder.getContext(), clauses.privateSyms));
+                      makeArrayAttr(builder.getContext(), clauses.privateSyms),
+                      clauses.privateNeedsBarrier);
 }
 
 LogicalResult DistributeOp::verify() {
@@ -2747,7 +2786,8 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
                 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
                 clauses.priority, /*private_vars=*/clauses.privateVars,
                 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
-                clauses.untied, clauses.eventHandle);
+                clauses.privateNeedsBarrier, clauses.untied,
+                clauses.eventHandle);
 }
 
 LogicalResult TaskOp::verify() {
@@ -2786,18 +2826,18 @@ LogicalResult TaskgroupOp::verify() {
 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
                        const TaskloopOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  TaskloopOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                    clauses.final, clauses.grainsizeMod, clauses.grainsize,
-                    clauses.ifExpr, clauses.inReductionVars,
-                    makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
-                    makeArrayAttr(ctx, clauses.inReductionSyms),
-                    clauses.mergeable, clauses.nogroup, clauses.numTasksMod,
-                    clauses.numTasks, clauses.priority,
-                    /*private_vars=*/clauses.privateVars,
-                    /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
-                    clauses.reductionMod, clauses.reductionVars,
-                    makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-                    makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
+  TaskloopOp::build(
+      builder, state, clauses.allocateVars, clauses.allocatorVars,
+      clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
+      clauses.inReductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
+      makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
+      clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
+      /*private_vars=*/clauses.privateVars,
+      /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
+      clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+      makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
 }
 
 LogicalResult TaskloopOp::verify() {
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index b7e16b7ec35e2..3eef3799c4b45 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2872,6 +2872,23 @@ func.func @parallel_op_privatizers(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
   return
 }
 
+// CHECK-LABEL: parallel_op_privatizers_barrier
+// CHECK-SAME: (%[[ARG0:[^[:space:]]+]]: !llvm.ptr, %[[ARG1:[^[:space:]]+]]: !llvm.ptr)
+func.func @parallel_op_privatizers_barrier(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  // CHECK: omp.parallel private(
+  // CHECK-SAME: @x.privatizer %[[ARG0]] -> %[[ARG0_PRIV:[^[:space:]]+]],
+  // CHECK-SAME: @y.privatizer %[[ARG1]] -> %[[ARG1_PRIV:[^[:space:]]+]] : !llvm.ptr, !llvm.ptr)
+  // CHECK-SAME: private_barrier
+  omp.parallel private(@x.privatizer %arg0 -> %arg2, @y.privatizer %arg1 -> %arg3 : !llvm.ptr, !llvm.ptr) private_barrier {
+    // CHECK: llvm.load %[[ARG0_PRIV]]
+    %0 = llvm.load %arg2 : !llvm.ptr -> i32
+    // CHECK: llvm.load %[[ARG1_PRIV]]
+    %1 = llvm.load %arg3 : !llvm.ptr -> i32
+    omp.terminator
+  }
+  return
+}
+
 // CHECK-LABEL: omp.private {type = private} @a.privatizer : !llvm.ptr init {
 omp.private {type = private} @a.privatizer : !llvm.ptr init {
 // CHECK: ^bb0(%{{.*}}: {{.*}}, %{{.*}}: {{.*}}):



More information about the Mlir-commits mailing list