[Mlir-commits] [mlir] a24ed7d - [mlir][OpenMP] add attribute for privatization barrier (#140089)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 22 07:24:06 PDT 2025


Author: Tom Eccles
Date: 2025-05-22T15:24:02+01:00
New Revision: a24ed7d4775d119029bc8539c54fba03dba4366f

URL: https://github.com/llvm/llvm-project/commit/a24ed7d4775d119029bc8539c54fba03dba4366f
DIFF: https://github.com/llvm/llvm-project/commit/a24ed7d4775d119029bc8539c54fba03dba4366f.diff

LOG: [mlir][OpenMP] add attribute for privatization barrier (#140089)

A barrier is needed at the end of initialization/copying of private
variables if any of those variables is lastprivate. This ensures that
all firstprivate variables receive the original value of the variable
before the lastprivate clause overwrites it.

Previously this barrier was added by the flang frontend, but there is not
a reliable way to put the barrier in the correct place for delayed
privatization, and the OpenMP dialect could some day have other users.
It is important that there are safe ways to use the constructs available
in the dialect.

lastprivate is currently not modelled in the OpenMP dialect, and so
there is no way to reliably determine whether there were lastprivate
variables. Therefore the frontend will have to provide this information
through this new attribute.

Part of a series of patches to fix
https://github.com/llvm/llvm-project/issues/136357

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index f8e880ea43b75..16c14ef085d6d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1102,7 +1102,10 @@ class OpenMP_PrivateClauseSkip<
 
   let arguments = (ins
     Variadic<AnyType>:$private_vars,
-    OptionalAttr<SymbolRefArrayAttr>:$private_syms
+    OptionalAttr<SymbolRefArrayAttr>:$private_syms,
+    // Set this attribute if a barrier is needed after initialization and
+    // copying of lastprivate variables.
+    UnitAttr:$private_needs_barrier
   );
 
   // TODO: Add description.

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 5a79fbf77a268..036c6a6e350a8 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -213,8 +213,8 @@ def ParallelOp : OpenMP_Op<"parallel", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -258,8 +258,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -317,8 +317,8 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -350,7 +350,7 @@ def SingleOp : OpenMP_Op<"single", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateRegion>($region, $private_vars, type($private_vars),
-        $private_syms) attr-dict
+        $private_syms, $private_needs_barrier) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -505,8 +505,8 @@ def LoopOp : OpenMP_Op<"loop", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let builders = [
@@ -557,8 +557,8 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -611,8 +611,8 @@ def SimdOp : OpenMP_Op<"simd", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
-        $reduction_syms) attr-dict
+        $private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
+        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -690,7 +690,7 @@ def DistributeOp : OpenMP_Op<"distribute", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateRegion>($region, $private_vars, type($private_vars),
-        $private_syms) attr-dict
+        $private_syms, $private_needs_barrier) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -740,7 +740,7 @@ def TaskOp
     custom<InReductionPrivateRegion>(
         $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $private_vars,
-        type($private_vars), $private_syms) attr-dict
+        type($private_vars), $private_syms, $private_needs_barrier) attr-dict
   }];
 
   let hasVerifier = 1;
@@ -816,8 +816,9 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [
     custom<InReductionPrivateReductionRegion>(
         $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $private_vars,
-        type($private_vars), $private_syms, $reduction_mod, $reduction_vars,
-        type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
+        type($private_vars), $private_syms, $private_needs_barrier,
+        $reduction_mod, $reduction_vars, type($reduction_vars),
+        $reduction_byref, $reduction_syms) attr-dict
   }];
 
   let extraClassDeclaration = [{
@@ -1324,7 +1325,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
         $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
         type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
         $map_vars, type($map_vars), $private_vars, type($private_vars),
-        $private_syms, $private_maps) attr-dict
+        $private_syms, $private_needs_barrier, $private_maps) attr-dict
   }];
 
   let hasVerifier = 1;

diff  --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 233739e1d6d91..71786e856c6db 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -450,6 +450,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* num_threads = */ numThreadsVar,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
+        /* private_needs_barrier = */ nullptr,
         /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
         /* reduction_mod = */ nullptr,
         /* reduction_vars = */ llvm::SmallVector<Value>{},

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index deff86d5c5ecb..e94d570b57122 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -581,11 +581,14 @@ struct PrivateParseArgs {
   llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
   llvm::SmallVectorImpl<Type> &types;
   ArrayAttr &syms;
+  UnitAttr &needsBarrier;
   DenseI64ArrayAttr *mapIndices;
   PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
                    SmallVectorImpl<Type> &types, ArrayAttr &syms,
+                   UnitAttr &needsBarrier,
                    DenseI64ArrayAttr *mapIndices = nullptr)
-      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
+      : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
+        mapIndices(mapIndices) {}
 };
 
 struct ReductionParseArgs {
@@ -613,6 +616,10 @@ struct AllRegionParseArgs {
 };
 } // namespace
 
+static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
+  return "private_barrier";
+}
+
 static ParseResult parseClauseWithRegionArgs(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
@@ -620,7 +627,8 @@ static ParseResult parseClauseWithRegionArgs(
     SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
     ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
     DenseBoolArrayAttr *byref = nullptr,
-    ReductionModifierAttr *modifier = nullptr) {
+    ReductionModifierAttr *modifier = nullptr,
+    UnitAttr *needsBarrier = nullptr) {
   SmallVector<SymbolRefAttr> symbolVec;
   SmallVector<int64_t> mapIndicesVec;
   SmallVector<bool> isByRefVec;
@@ -688,6 +696,12 @@ static ParseResult parseClauseWithRegionArgs(
   if (parser.parseRParen())
     return failure();
 
+  if (needsBarrier) {
+    if (parser.parseOptionalKeyword(getPrivateNeedsBarrierSpelling())
+            .succeeded())
+      *needsBarrier = mlir::UnitAttr::get(parser.getContext());
+  }
+
   auto *argsBegin = regionPrivateArgs.begin();
   MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
                                argsBegin + regionArgOffset + types.size());
@@ -735,7 +749,8 @@ static ParseResult parseBlockArgClause(
 
     if (failed(parseClauseWithRegionArgs(
             parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
-            &privateArgs->syms, privateArgs->mapIndices)))
+            &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
+            /*modifier=*/nullptr, &privateArgs->needsBarrier)))
       return failure();
   }
   return success();
@@ -824,7 +839,7 @@ static ParseResult parseTargetOpRegion(
     SmallVectorImpl<Type> &mapTypes,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
-    DenseI64ArrayAttr &privateMaps) {
+    UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
   AllRegionParseArgs args;
   args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
@@ -832,7 +847,7 @@ static ParseResult parseTargetOpRegion(
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
-                           &privateMaps);
+                           privateNeedsBarrier, &privateMaps);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -842,11 +857,13 @@ static ParseResult parseInReductionPrivateRegion(
     SmallVectorImpl<Type> &inReductionTypes,
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
-    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    UnitAttr &privateNeedsBarrier) {
   AllRegionParseArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -857,14 +874,15 @@ static ParseResult parseInReductionPrivateReductionRegion(
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
-    ReductionModifierAttr &reductionMod,
+    UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
     ArrayAttr &reductionSyms) {
   AllRegionParseArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms, &reductionMod);
   return parseBlockArgRegion(parser, region, args);
@@ -873,9 +891,11 @@ static ParseResult parseInReductionPrivateReductionRegion(
 static ParseResult parsePrivateRegion(
     OpAsmParser &parser, Region &region,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
-    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    UnitAttr &privateNeedsBarrier) {
   AllRegionParseArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -883,12 +903,13 @@ static ParseResult parsePrivateReductionRegion(
     OpAsmParser &parser, Region &region,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
-    ReductionModifierAttr &reductionMod,
+    UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
     ArrayAttr &reductionSyms) {
   AllRegionParseArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms, &reductionMod);
   return parseBlockArgRegion(parser, region, args);
@@ -931,10 +952,12 @@ struct PrivatePrintArgs {
   ValueRange vars;
   TypeRange types;
   ArrayAttr syms;
+  UnitAttr needsBarrier;
   DenseI64ArrayAttr mapIndices;
   PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
-                   DenseI64ArrayAttr mapIndices)
-      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
+                   UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
+      : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
+        mapIndices(mapIndices) {}
 };
 struct ReductionPrintArgs {
   ValueRange vars;
@@ -964,7 +987,7 @@ static void printClauseWithRegionArgs(
     ValueRange argsSubrange, ValueRange operands, TypeRange types,
     ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
     DenseBoolArrayAttr byref = nullptr,
-    ReductionModifierAttr modifier = nullptr) {
+    ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
   if (argsSubrange.empty())
     return;
 
@@ -1006,6 +1029,9 @@ static void printClauseWithRegionArgs(
   p << " : ";
   llvm::interleaveComma(types, p);
   p << ") ";
+
+  if (needsBarrier)
+    p << getPrivateNeedsBarrierSpelling() << " ";
 }
 
 static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
@@ -1020,9 +1046,10 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
                                 StringRef clauseName, ValueRange argsSubrange,
                                 std::optional<PrivatePrintArgs> privateArgs) {
   if (privateArgs)
-    printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
-                              privateArgs->vars, privateArgs->types,
-                              privateArgs->syms, privateArgs->mapIndices);
+    printClauseWithRegionArgs(
+        p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
+        privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
+        /*modifier=*/nullptr, privateArgs->needsBarrier);
 }
 
 static void
@@ -1068,23 +1095,23 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
 
 // These parseXyz functions correspond to the custom<Xyz> definitions
 // in the .td file(s).
-static void
-printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region,
-                    ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
-                    ValueRange hostEvalVars, TypeRange hostEvalTypes,
-                    ValueRange inReductionVars, TypeRange inReductionTypes,
-                    DenseBoolArrayAttr inReductionByref,
-                    ArrayAttr inReductionSyms, ValueRange mapVars,
-                    TypeRange mapTypes, ValueRange privateVars,
-                    TypeRange privateTypes, ArrayAttr privateSyms,
-                    DenseI64ArrayAttr privateMaps) {
+static void printTargetOpRegion(
+    OpAsmPrinter &p, Operation *op, Region &region,
+    ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
+    ValueRange hostEvalVars, TypeRange hostEvalTypes,
+    ValueRange inReductionVars, TypeRange inReductionTypes,
+    DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
+    ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
+    TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
+    DenseI64ArrayAttr privateMaps) {
   AllRegionPrintArgs args;
   args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier, privateMaps);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -1092,11 +1119,12 @@ static void printInReductionPrivateRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
-    ArrayAttr privateSyms) {
+    ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier,
                            /*mapIndices=*/nullptr);
   printBlockArgRegion(p, op, region, args);
 }
@@ -1105,13 +1133,15 @@ static void printInReductionPrivateReductionRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
-    ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
-    ValueRange reductionVars, TypeRange reductionTypes,
-    DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
+    ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
+    ReductionModifierAttr reductionMod, ValueRange reductionVars,
+    TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
+    ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier,
                            /*mapIndices=*/nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms, reductionMod);
@@ -1120,21 +1150,24 @@ static void printInReductionPrivateReductionRegion(
 
 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
                                ValueRange privateVars, TypeRange privateTypes,
-                               ArrayAttr privateSyms) {
+                               ArrayAttr privateSyms,
+                               UnitAttr privateNeedsBarrier) {
   AllRegionPrintArgs args;
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier,
                            /*mapIndices=*/nullptr);
   printBlockArgRegion(p, op, region, args);
 }
 
 static void printPrivateReductionRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
-    TypeRange privateTypes, ArrayAttr privateSyms,
+    TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
     ReductionModifierAttr reductionMod, ValueRange reductionVars,
     TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
     ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
+                           privateNeedsBarrier,
                            /*mapIndices=*/nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms, reductionMod);
@@ -1916,7 +1949,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
                   /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
-                  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
+                  makeArrayAttr(ctx, clauses.privateSyms),
+                  clauses.privateNeedsBarrier, clauses.threadLimit,
                   /*private_maps=*/nullptr);
 }
 
@@ -2180,7 +2214,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
   ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
                     /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
-                    /*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
+                    /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
+                    /*proc_bind_kind=*/nullptr,
                     /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
                     /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
   state.addAttributes(attributes);
@@ -2192,8 +2227,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
   ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                     clauses.ifExpr, clauses.numThreads, clauses.privateVars,
                     makeArrayAttr(ctx, clauses.privateSyms),
-                    clauses.procBindKind, clauses.reductionMod,
-                    clauses.reductionVars,
+                    clauses.privateNeedsBarrier, clauses.procBindKind,
+                    clauses.reductionMod, clauses.reductionVars,
                     makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                     makeArrayAttr(ctx, clauses.reductionSyms));
 }
@@ -2297,11 +2332,12 @@ static bool opInGlobalImplicitParallelRegion(Operation *op) {
 void TeamsOp::build(OpBuilder &builder, OperationState &state,
                     const TeamsOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  // TODO Store clauses in op: privateVars, privateSyms.
+  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
   TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                  clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
                  /*private_vars=*/{}, /*private_syms=*/nullptr,
-                 clauses.reductionMod, clauses.reductionVars,
+                 /*private_needs_barrier=*/nullptr, clauses.reductionMod,
+                 clauses.reductionVars,
                  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                  makeArrayAttr(ctx, clauses.reductionSyms),
                  clauses.threadLimit);
@@ -2358,11 +2394,11 @@ OperandRange SectionOp::getReductionVars() {
 void SectionsOp::build(OpBuilder &builder, OperationState &state,
                        const SectionsOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  // TODO Store clauses in op: privateVars, privateSyms.
+  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
   SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                     clauses.nowait, /*private_vars=*/{},
-                    /*private_syms=*/nullptr, clauses.reductionMod,
-                    clauses.reductionVars,
+                    /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
+                    clauses.reductionMod, clauses.reductionVars,
                     makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                     makeArrayAttr(ctx, clauses.reductionSyms));
 }
@@ -2394,11 +2430,12 @@ LogicalResult SectionsOp::verifyRegions() {
 void SingleOp::build(OpBuilder &builder, OperationState &state,
                      const SingleOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  // TODO Store clauses in op: privateVars, privateSyms.
+  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
   SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                   clauses.copyprivateVars,
                   makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
-                  /*private_vars=*/{}, /*private_syms=*/nullptr);
+                  /*private_vars=*/{}, /*private_syms=*/nullptr,
+                  /*private_needs_barrier=*/nullptr);
 }
 
 LogicalResult SingleOp::verify() {
@@ -2474,8 +2511,9 @@ void LoopOp::build(OpBuilder &builder, OperationState &state,
   MLIRContext *ctx = builder.getContext();
 
   LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
-                makeArrayAttr(ctx, clauses.privateSyms), clauses.order,
-                clauses.orderMod, clauses.reductionMod, clauses.reductionVars,
+                makeArrayAttr(ctx, clauses.privateSyms),
+                clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
+                clauses.reductionMod, clauses.reductionVars,
                 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                 makeArrayAttr(ctx, clauses.reductionSyms));
 }
@@ -2503,6 +2541,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
         /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
         /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
+        /*private_needs_barrier=*/false,
         /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
         /*reduction_byref=*/nullptr,
         /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
@@ -2514,18 +2553,17 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
 void WsloopOp::build(OpBuilder &builder, OperationState &state,
                      const WsloopOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  // TODO: Store clauses in op: allocateVars, allocatorVars, privateVars,
-  // privateSyms.
-  WsloopOp::build(builder, state,
-                  /*allocate_vars=*/{}, /*allocator_vars=*/{},
-                  clauses.linearVars, clauses.linearStepVars, clauses.nowait,
-                  clauses.order, clauses.orderMod, clauses.ordered,
-                  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
-                  clauses.reductionMod, clauses.reductionVars,
-                  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-                  makeArrayAttr(ctx, clauses.reductionSyms),
-                  clauses.scheduleKind, clauses.scheduleChunk,
-                  clauses.scheduleMod, clauses.scheduleSimd);
+  // TODO: Store clauses in op: allocateVars, allocatorVars
+  WsloopOp::build(
+      builder, state,
+      /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
+      clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
+      clauses.ordered, clauses.privateVars,
+      makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+      clauses.reductionMod, clauses.reductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+      makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
+      clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
 }
 
 LogicalResult WsloopOp::verify() {
@@ -2565,14 +2603,14 @@ LogicalResult WsloopOp::verifyRegions() {
 void SimdOp::build(OpBuilder &builder, OperationState &state,
                    const SimdOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  // TODO Store clauses in op: linearVars, linearStepVars, privateVars,
-  // privateSyms.
+  // TODO Store clauses in op: linearVars, linearStepVars
   SimdOp::build(builder, state, clauses.alignedVars,
                 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
                 /*linear_vars=*/{}, /*linear_step_vars=*/{},
                 clauses.nontemporalVars, clauses.order, clauses.orderMod,
                 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
-                clauses.reductionMod, clauses.reductionVars,
+                clauses.privateNeedsBarrier, clauses.reductionMod,
+                clauses.reductionVars,
                 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
                 clauses.simdlen);
@@ -2622,7 +2660,8 @@ void DistributeOp::build(OpBuilder &builder, OperationState &state,
                       clauses.allocatorVars, clauses.distScheduleStatic,
                       clauses.distScheduleChunkSize, clauses.order,
                       clauses.orderMod, clauses.privateVars,
-                      makeArrayAttr(builder.getContext(), clauses.privateSyms));
+                      makeArrayAttr(builder.getContext(), clauses.privateSyms),
+                      clauses.privateNeedsBarrier);
 }
 
 LogicalResult DistributeOp::verify() {
@@ -2778,7 +2817,8 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
                 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
                 clauses.priority, /*private_vars=*/clauses.privateVars,
                 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
-                clauses.untied, clauses.eventHandle);
+                clauses.privateNeedsBarrier, clauses.untied,
+                clauses.eventHandle);
 }
 
 LogicalResult TaskOp::verify() {
@@ -2817,18 +2857,18 @@ LogicalResult TaskgroupOp::verify() {
 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
                        const TaskloopOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  TaskloopOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                    clauses.final, clauses.grainsizeMod, clauses.grainsize,
-                    clauses.ifExpr, clauses.inReductionVars,
-                    makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
-                    makeArrayAttr(ctx, clauses.inReductionSyms),
-                    clauses.mergeable, clauses.nogroup, clauses.numTasksMod,
-                    clauses.numTasks, clauses.priority,
-                    /*private_vars=*/clauses.privateVars,
-                    /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
-                    clauses.reductionMod, clauses.reductionVars,
-                    makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-                    makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
+  TaskloopOp::build(
+      builder, state, clauses.allocateVars, clauses.allocatorVars,
+      clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
+      clauses.inReductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
+      makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
+      clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
+      /*private_vars=*/clauses.privateVars,
+      /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
+      clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+      makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
 }
 
 LogicalResult TaskloopOp::verify() {

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index a9e4af035dbd7..47cfc5278a5d0 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2876,6 +2876,23 @@ func.func @parallel_op_privatizers(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
   return
 }
 
+// CHECK-LABEL: parallel_op_privatizers_barrier
+// CHECK-SAME: (%[[ARG0:[^[:space:]]+]]: !llvm.ptr, %[[ARG1:[^[:space:]]+]]: !llvm.ptr)
+func.func @parallel_op_privatizers_barrier(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  // CHECK: omp.parallel private(
+  // CHECK-SAME: @x.privatizer %[[ARG0]] -> %[[ARG0_PRIV:[^[:space:]]+]],
+  // CHECK-SAME: @y.privatizer %[[ARG1]] -> %[[ARG1_PRIV:[^[:space:]]+]] : !llvm.ptr, !llvm.ptr)
+  // CHECK-SAME: private_barrier
+  omp.parallel private(@x.privatizer %arg0 -> %arg2, @y.privatizer %arg1 -> %arg3 : !llvm.ptr, !llvm.ptr) private_barrier {
+    // CHECK: llvm.load %[[ARG0_PRIV]]
+    %0 = llvm.load %arg2 : !llvm.ptr -> i32
+    // CHECK: llvm.load %[[ARG1_PRIV]]
+    %1 = llvm.load %arg3 : !llvm.ptr -> i32
+    omp.terminator
+  }
+  return
+}
+
 // CHECK-LABEL: omp.private {type = private} @a.privatizer : !llvm.ptr init {
 omp.private {type = private} @a.privatizer : !llvm.ptr init {
 // CHECK: ^bb0(%{{.*}}: {{.*}}, %{{.*}}: {{.*}}):


        


More information about the Mlir-commits mailing list