[flang-commits] [flang] 2507314 - [OpenMP][mlir] Add DynGroupPrivateClause in omp dialect (#153562)

via flang-commits flang-commits at lists.llvm.org
Mon May 4 03:11:44 PDT 2026


Author: Chaitanya
Date: 2026-05-04T15:41:37+05:30
New Revision: 2507314720b881cc5822b3659e07e036d562099b

URL: https://github.com/llvm/llvm-project/commit/2507314720b881cc5822b3659e07e036d562099b
DIFF: https://github.com/llvm/llvm-project/commit/2507314720b881cc5822b3659e07e036d562099b.diff

LOG: [OpenMP][mlir] Add DynGroupPrivateClause in omp dialect (#153562)

- The `dyn_groupprivate` clause allows group-private memory to be allocated
dynamically in OpenMP parallel regions, specifically for the `target` and
`teams` directives.
- This clause enables runtime-sized private memory allocation and is
applicable to the target and teams ops.

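This PR enables the dyn_groupprivate clause in the OpenMP MLIR dialect and
adds it to the Teams and Target ops. It also includes the parser, printer and
verifier for the clause. Translation to LLVM IR is implemented for
`omp.target`; `omp.teams` reports the clause as not yet implemented.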

Added: 
    mlir/test/Target/LLVMIR/openmp-target-dyn-groupprivate.mlir

Modified: 
    flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
    flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
    mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
    mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir
    mlir/test/Target/LLVMIR/openmp-todo.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
index 475ed35cac9fa..b78b33f31b142 100644
--- a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
+++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
@@ -313,6 +313,11 @@ class FunctionFilteringPass
       targetOp.setDependIteratedKindsAttr(nullptr);
       targetOp.getDeviceMutable().clear();
       targetOp.getIfExprMutable().clear();
+      // TargetOp::verify rejects dyn_groupprivate modifiers that are not
+      // accompanied by a size operand, so drop them together with the size.
+      targetOp.getDynGroupprivateSizeMutable().clear();
+      targetOp.setDynGroupprivateAccessGroupAttr(nullptr);
+      targetOp.setDynGroupprivateFallbackAttr(nullptr);
 
       // TODO: Clear some of these operands rather than rewriting them,
       // depending on whether they are needed by device codegen once support for

diff  --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 8a08f67006c0a..1d98018ffbd05 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -761,7 +761,9 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
       targetOp.getAllocatorVars(), targetOp.getBareAttr(),
       targetOp.getDependKindsAttr(), targetOp.getDependVars(),
       targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
-      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getDevice(), targetOp.getDynGroupprivateAccessGroupAttr(),
+      targetOp.getDynGroupprivateFallbackAttr(),
+      targetOp.getDynGroupprivateSize(), targetOp.getHasDeviceAddrVars(),
       targetOp.getHostEvalVars(), targetOp.getIfExpr(),
       targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
       targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
@@ -1482,8 +1484,10 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
       targetOp.getAllocatorVars(), targetOp.getBareAttr(),
       targetOp.getDependKindsAttr(), targetOp.getDependVars(),
       targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
-      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars,
-      targetOp.getIfExpr(), targetOp.getInReductionVars(),
+      targetOp.getDevice(), targetOp.getDynGroupprivateAccessGroupAttr(),
+      targetOp.getDynGroupprivateFallbackAttr(),
+      targetOp.getDynGroupprivateSize(), targetOp.getHasDeviceAddrVars(),
+      preHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(),
       targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
       targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(),
       targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
@@ -1573,7 +1577,9 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
       targetOp.getAllocatorVars(), targetOp.getBareAttr(),
       targetOp.getDependKindsAttr(), targetOp.getDependVars(),
       targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
-      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getDevice(), targetOp.getDynGroupprivateAccessGroupAttr(),
+      targetOp.getDynGroupprivateFallbackAttr(),
+      targetOp.getDynGroupprivateSize(), targetOp.getHasDeviceAddrVars(),
       isolatedHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(),
       targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
       targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
@@ -1654,8 +1660,10 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
       targetOp.getAllocatorVars(), targetOp.getBareAttr(),
       targetOp.getDependKindsAttr(), targetOp.getDependVars(),
       targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
-      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars,
-      targetOp.getIfExpr(), targetOp.getInReductionVars(),
+      targetOp.getDevice(), targetOp.getDynGroupprivateAccessGroupAttr(),
+      targetOp.getDynGroupprivateFallbackAttr(),
+      targetOp.getDynGroupprivateSize(), targetOp.getHasDeviceAddrVars(),
+      postHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(),
       targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
       targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
       targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 6270e05b77780..2a2b695e77c0e 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1845,4 +1845,43 @@ class OpenMP_UniformClauseSkip<
 
 def OpenMP_UniformClause : OpenMP_UniformClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// V6.1 `dyn_groupprivate` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_DynGroupprivateClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+
+  let arguments = (ins
+    OptionalAttr<AccessGroupModifierAttr>:$dyn_groupprivate_access_group,
+    OptionalAttr<FallbackModifierAttr>:$dyn_groupprivate_fallback,
+    Optional<AnyInteger>:$dyn_groupprivate_size
+  );
+
+  let description = [{
+    The `dyn_groupprivate_access_group` attribute specifies the access group
+    modifier for the dynamically allocated group-private memory. The
+    `dyn_groupprivate_fallback` attribute specifies the fallback behavior when
+    allocation fails. The `dyn_groupprivate_size` operand specifies the size in
+    bytes to allocate.
+
+    The modifiers and the size operand may appear in any order, e.g.
+    `dyn_groupprivate(cgroup, fallback(null), %size : i32)`. Specifying a
+    modifier without a size operand is rejected by the verifier.
+  }];
+
+  let optAssemblyFormat = [{
+    `dyn_groupprivate` `(`
+      custom<DynGroupprivateClause>($dyn_groupprivate_access_group,
+      $dyn_groupprivate_fallback,
+      $dyn_groupprivate_size, type($dyn_groupprivate_size))
+    `)`
+  }];
+}
+
+def OpenMP_DynGroupprivateClause : OpenMP_DynGroupprivateClauseSkip<>;
+
 #endif // OPENMP_CLAUSES

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 6c649f4243f13..adea931c5fa60 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -337,4 +337,42 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr<VariableCaptureKind,
   let assemblyFormat = "`(` $value `)`";
 }
 
+//===----------------------------------------------------------------------===//
+// access_group_modifier enum (access-group modifier of `dyn_groupprivate`).
+//===----------------------------------------------------------------------===//
+// Currently `cgroup` is the only defined access group.
+def AccessGroupCGroup : I32EnumAttrCase<"cgroup", 0>;
+
+def AccessGroupModifier : OpenMP_I32EnumAttr<
+    "AccessGroupModifier",
+    "access group modifier", [
+      AccessGroupCGroup
+    ]>;
+
+def AccessGroupModifierAttr : OpenMP_EnumAttr<AccessGroupModifier,
+                                            "access_group_modifier"> {
+  let assemblyFormat = "`(` $value `)`";
+}
+
+//===----------------------------------------------------------------------===//
+// fallback_modifier enum (fallback modifier of `dyn_groupprivate`).
+//===----------------------------------------------------------------------===//
+// Selects what happens when the dyn_groupprivate allocation fails.
+def FallbackAbort : I32EnumAttrCase<"abort", 0>;
+def FallbackNull : I32EnumAttrCase<"null", 1>;
+def FallbackDefaultMem : I32EnumAttrCase<"default_mem", 2>;
+
+def FallbackModifier : OpenMP_I32EnumAttr<
+    "FallbackModifier",
+    "fallback modifier", [
+      FallbackAbort,
+      FallbackNull,
+      FallbackDefaultMem
+    ]>;
+
+def FallbackModifierAttr : OpenMP_EnumAttr<FallbackModifier,
+                                            "fallback_modifier"> {
+  let assemblyFormat = "`(` $value `)`";
+}
+
 #endif // OPENMP_ENUMS

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 7741542d3329e..c8c233477c174 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -240,8 +240,9 @@ def TerminatorOp : OpenMP_Op<"terminator", [Terminator, Pure]> {
 def TeamsOp : OpenMP_Op<"teams", traits = [
     AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
   ], clauses = [
-    OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
-    OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
+    OpenMP_AllocateClause, OpenMP_DynGroupprivateClause, OpenMP_IfClause,
+    OpenMP_NumTeamsClause, OpenMP_PrivateClause, OpenMP_ReductionClause,
+    OpenMP_ThreadLimitClause
   ], singleRegion = true> {
   let summary = "teams construct";
   let description = [{
@@ -1579,7 +1580,8 @@ def TargetOp : OpenMP_Op<"target", traits = [
   ], clauses = [
     // TODO: Complete clause list (defaultmap, uses_allocators).
     OpenMP_AllocateClause, OpenMP_BareClause, OpenMP_DependClause,
-    OpenMP_DeviceClause, OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause,
+    OpenMP_DeviceClause, OpenMP_DynGroupprivateClause,
+    OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause,
     OpenMP_IfClause, OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
     OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
     OpenMP_PrivateClause, OpenMP_ThreadLimitClause

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c6683d1c23e09..40ccff7405799 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -926,6 +926,126 @@ static void printHeapAllocClause(OpAsmPrinter &p, Operation *op,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Parser, printer and verifier for the dyn_groupprivate clause
+//===----------------------------------------------------------------------===//
+
+/// Verifies the dyn_groupprivate clause of `op`: the access group and
+/// fallback modifiers are only allowed when a size operand is also present.
+static LogicalResult
+verifyDynGroupprivateClause(Operation *op, AccessGroupModifierAttr accessGroup,
+                            FallbackModifierAttr fallback,
+                            Value dynGroupprivateSize) {
+  if (!dynGroupprivateSize && (accessGroup || fallback))
+    return op->emitOpError("dyn_groupprivate modifiers require a size operand");
+
+  return success();
+}
+
+/// Parses the comma-separated body of `dyn_groupprivate(...)`:
+///   [cgroup] [fallback(abort|null|default_mem)] [%size : type]
+/// The entries may appear in any order, but each one at most once.
+static ParseResult parseDynGroupprivateClause(
+    OpAsmParser &parser, AccessGroupModifierAttr &accessGroupAttr,
+    FallbackModifierAttr &fallbackAttr,
+    std::optional<OpAsmParser::UnresolvedOperand> &dynGroupprivateSize,
+    Type &sizeType) {
+  bool parsedAccessGroup = false;
+  bool parsedFallback = false;
+  bool parsedSize = false;
+
+  return parser.parseCommaSeparatedList([&]() -> ParseResult {
+    // Access group modifier.
+    if (succeeded(parser.parseOptionalKeyword("cgroup"))) {
+      if (parsedAccessGroup)
+        return parser.emitError(parser.getCurrentLocation(),
+                                "duplicate access group modifier");
+      accessGroupAttr = AccessGroupModifierAttr::get(
+          parser.getContext(), AccessGroupModifier::cgroup);
+      parsedAccessGroup = true;
+      return success();
+    }
+
+    // Fallback modifier. The parseOptional* variants are used so that each
+    // syntax error yields exactly one clause-specific diagnostic rather than
+    // the generic parser error followed by a second, custom one.
+    if (succeeded(parser.parseOptionalKeyword("fallback"))) {
+      if (parsedFallback)
+        return parser.emitError(parser.getCurrentLocation(),
+                                "duplicate fallback modifier");
+      if (failed(parser.parseOptionalLParen()))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected '(' after 'fallback'");
+      llvm::SMLoc kindLoc = parser.getCurrentLocation();
+      llvm::StringRef fbKind;
+      if (failed(parser.parseOptionalKeyword(&fbKind)))
+        return parser.emitError(
+            kindLoc, "expected fallback modifier (abort/null/default_mem)");
+      // Use the generated symbolizer so the accepted spellings always match
+      // the FallbackModifier enum cases.
+      std::optional<FallbackModifier> fbEnum =
+          symbolizeFallbackModifier(fbKind);
+      if (!fbEnum)
+        return parser.emitError(kindLoc,
+                                "invalid fallback modifier '" + fbKind + "'");
+      if (failed(parser.parseOptionalRParen()))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected ')' after fallback modifier");
+      fallbackAttr = FallbackModifierAttr::get(parser.getContext(), *fbEnum);
+      parsedFallback = true;
+      return success();
+    }
+
+    // Size operand and its type. parseOperand already reports "expected SSA
+    // operand" when no operand is present; the clause-specific error follows.
+    OpAsmParser::UnresolvedOperand operand;
+    if (failed(parser.parseOperand(operand)))
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected dyn_groupprivate_size operand");
+    if (parsedSize)
+      return parser.emitError(parser.getCurrentLocation(),
+                              "duplicate size operand");
+    dynGroupprivateSize = operand;
+    parsedSize = true;
+    if (failed(parser.parseOptionalColon()))
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected ':' and type after size operand");
+    // parseType emits its own diagnostic on failure.
+    return parser.parseType(sizeType);
+  });
+}
+
+/// Prints the body of `dyn_groupprivate(...)` in a fixed order (access
+/// group, fallback, size); the parser accepts the entries in any order.
+static void printDynGroupprivateClause(OpAsmPrinter &printer, Operation *op,
+                                       AccessGroupModifierAttr accessGroupAttr,
+                                       FallbackModifierAttr fallbackAttr,
+                                       Value dynGroupprivateSize,
+                                       Type sizeType) {
+  bool needsComma = false;
+  auto printSeparator = [&]() {
+    if (needsComma)
+      printer << ", ";
+    needsComma = true;
+  };
+
+  if (accessGroupAttr) {
+    printSeparator();
+    printer << stringifyAccessGroupModifier(accessGroupAttr.getValue());
+  }
+
+  if (fallbackAttr) {
+    printSeparator();
+    printer << "fallback(" << stringifyFallbackModifier(fallbackAttr.getValue())
+            << ")";
+  }
+
+  if (dynGroupprivateSize) {
+    printSeparator();
+    printer << dynGroupprivateSize << " : " << sizeType;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Parsers for operations including clauses that define entry block arguments.
 //===----------------------------------------------------------------------===//
@@ -2426,8 +2538,9 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
       builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.bare,
       makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
       makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
-      clauses.device, clauses.hasDeviceAddrVars, clauses.hostEvalVars,
-      clauses.ifExpr,
+      clauses.device, clauses.dynGroupprivateAccessGroup,
+      clauses.dynGroupprivateFallback, clauses.dynGroupprivateSize,
+      clauses.hasDeviceAddrVars, clauses.hostEvalVars, clauses.ifExpr,
       /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
       /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars,
       clauses.nowait, clauses.privateVars,
@@ -2449,6 +2562,11 @@ LogicalResult TargetOp::verify() {
   if (failed(verifyMapClause(*this, getMapVars())))
     return failure();
 
+  if (failed(verifyDynGroupprivateClause(
+          *this, getDynGroupprivateAccessGroupAttr(),
+          getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
+    return failure();
+
   return verifyPrivateVarsMapping(*this);
 }
 
@@ -2834,8 +2952,9 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
   // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
   TeamsOp::build(
       builder, state, clauses.allocateVars, clauses.allocatorVars,
-      clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpperVars,
-      /*private_vars=*/{}, /*private_syms=*/nullptr,
+      clauses.dynGroupprivateAccessGroup, clauses.dynGroupprivateFallback,
+      clauses.dynGroupprivateSize, clauses.ifExpr, clauses.numTeamsLower,
+      clauses.numTeamsUpperVars, /*private_vars=*/{}, /*private_syms=*/nullptr,
       /*private_needs_barrier=*/nullptr, clauses.reductionMod,
       clauses.reductionVars,
       makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
@@ -2882,6 +3001,11 @@ LogicalResult TeamsOp::verify() {
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
+  if (failed(verifyDynGroupprivateClause(
+          op, getDynGroupprivateAccessGroupAttr(),
+          getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
+    return failure();
+
   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                 getReductionByref());
 }

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 846cc584c8843..bce0e66a4ea36 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -389,6 +389,11 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       result = todo("thread_limit with multi-dimensional values");
   };
 
+  auto checkDynGroupprivate = [&todo](auto op, LogicalResult &result) {
+    if (op.getDynGroupprivateSize())
+      result = todo("dyn_groupprivate");
+  };
+
   LogicalResult result = success();
   llvm::TypeSwitch<Operation &>(op)
       .Case([&](omp::DistributeOp op) {
@@ -413,6 +418,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkPrivate(op, result);
         checkNumTeams(op, result);
         checkThreadLimit(op, result);
+        checkDynGroupprivate(op, result);
       })
       .Case([&](omp::TaskOp op) {
         checkAllocate(op, result);
@@ -7263,6 +7269,26 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   }
 }
 
+/// Translates the optional `fallback` modifier of a dyn_groupprivate clause
+/// into the OpenMPIRBuilder fallback kind; no modifier means DefaultMem.
+static llvm::omp::OMPDynGroupprivateFallbackType
+getDynGroupprivateFallbackType(omp::FallbackModifierAttr fallbackAttr) {
+  using FallbackType = llvm::omp::OMPDynGroupprivateFallbackType;
+
+  if (!fallbackAttr)
+    return FallbackType::DefaultMem;
+
+  switch (fallbackAttr.getValue()) {
+  case omp::FallbackModifier::abort:
+    return FallbackType::Abort;
+  case omp::FallbackModifier::null:
+    return FallbackType::Null;
+  case omp::FallbackModifier::default_mem:
+    return FallbackType::DefaultMem;
+  }
+  llvm_unreachable("unexpected dyn_groupprivate fallback type");
+}
+
 static LogicalResult
 convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
@@ -7575,12 +7597,23 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   if (Value targetIfCond = targetOp.getIfExpr())
     ifCond = moduleTranslation.lookupValue(targetIfCond);
 
+  Value dynGroupPrivateSize = targetOp.getDynGroupprivateSize();
+  llvm::Value *dynSizeVal = nullptr;
+  if (dynGroupPrivateSize) {
+    dynSizeVal = moduleTranslation.lookupValue(dynGroupPrivateSize);
+    dynSizeVal = builder.CreateIntCast(dynSizeVal, builder.getInt32Ty(),
+                                       /*isSigned=*/false);
+  }
+
+  llvm::omp::OMPDynGroupprivateFallbackType fallbackType =
+      getDynGroupprivateFallbackType(targetOp.getDynGroupprivateFallbackAttr());
+
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
           ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), deallocBlocks,
           info, entryInfo, defaultAttrs, runtimeAttrs, ifCond, kernelInput,
           genMapInfoCB, bodyCB, argAccessorCB, customMapperCB, dds,
-          targetOp.getNowait());
+          targetOp.getNowait(), dynSizeVal, fallbackType);
 
   if (failed(handleError(afterIP, opInst)))
     return failure();

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 1a3bd678621b4..04725a69c8559 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) {
     // expected-error @below {{expected equal sizes for allocate and allocator variables}}
     "omp.teams" (%data_var) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
+    }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
     omp.terminator
   }
   return
@@ -1451,7 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) {
     // expected-error @below {{expected exactly one num_teams upper bound when lower bound is specified}}
     "omp.teams" (%lb) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0>} : (i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> ()
     omp.terminator
   }
   return
@@ -1468,7 +1468,7 @@ func.func @omp_teams_num_teams_multidim_with_bounds() {
     // expected-error @below {{expected exactly one num_teams upper bound when lower bound is specified}}
     "omp.teams" (%lb, %v0, %v1) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,0,1,2,0,0,0>} : (i32, i32, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,0,0,1,2,0,0,0>} : (i32, i32, i32) -> ()
     omp.terminator
   }
   return
@@ -1489,6 +1489,67 @@ func.func @omp_teams_num_teams2(%lb : i32, %ub : i16) {
 
 // -----
 
+func.func @test_teams_dyn_groupprivate_errors_1(%dyn_size: i32) {
+  // expected-error @below {{duplicate access group modifier}}
+  omp.teams dyn_groupprivate(cgroup, cgroup, %dyn_size : i32) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+// A fallback modifier may be given at most once.
+func.func @test_teams_dyn_groupprivate_errors_2(%dyn_size: i32) {
+  // expected-error @below {{duplicate fallback modifier}}
+  omp.teams dyn_groupprivate(fallback(null), fallback(abort), %dyn_size : i32) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+// Only abort, null and default_mem are valid fallback modifiers.
+func.func @test_teams_dyn_groupprivate_errors_3(%dyn_size: i32) {
+  // expected-error @below {{invalid fallback modifier 'no'}}
+  omp.teams dyn_groupprivate(fallback(no), %dyn_size : i32) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+// Modifiers without a size operand are rejected by the op verifier.
+func.func @test_teams_dyn_groupprivate_errors_4(%dyn_size: i32) {
+  // expected-error @below {{'omp.teams' op dyn_groupprivate modifiers require a size operand}}
+  omp.teams dyn_groupprivate(fallback(null)) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+// An empty clause body produces both a generic and a clause-specific error.
+func.func @test_teams_dyn_groupprivate_errors_5() {
+  // expected-error @below {{expected dyn_groupprivate_size operand}}
+  // expected-error @below {{expected SSA operand}}
+  omp.teams dyn_groupprivate() {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+// The size operand may be given at most once.
+func.func @test_teams_dyn_groupprivate_errors_6(%s1: i32, %s2: i32) {
+  // expected-error @below {{duplicate size operand}}
+  omp.teams dyn_groupprivate(%s1 : i32, %s2 : i32) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
 func.func @omp_sections(%data_var : memref<i32>) -> () {
   // expected-error @below {{expected equal sizes for allocate and allocator variables}}
   "omp.sections" (%data_var) ({
@@ -2534,7 +2595,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
   // expected-error @below {{op expected as many depend values as depend variables}}
     "omp.target"(%data_var) ({
       "omp.terminator"() : () -> ()
-    }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
+    }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
    "func.return"() : () -> ()
 }
 

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 6ece0e95cdce0..826e36e3f7b19 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -871,7 +871,7 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %devic
     "omp.target"(%device, %if_cond, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -1124,7 +1124,8 @@ func.func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) {
 
 // CHECK-LABEL: omp_teams
 func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
-                     %data_var : memref<i32>, %ub64 : i64, %ub16 : i16) -> () {
+                     %data_var : memref<i32>, %ub64 : i64, %ub16 : i16,
+                     %dyn_size : i32) -> () {
   // Test nesting inside of omp.target
   omp.target {
     // CHECK: omp.teams
@@ -1221,6 +1222,13 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
     omp.terminator
   }
 
+  // Test dyn_groupprivate
+  // CHECK: omp.teams dyn_groupprivate(cgroup, fallback(null), %{{.+}} : i32)
+  omp.teams dyn_groupprivate(cgroup, fallback(null), %dyn_size : i32) {
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+
   return
 }
 
@@ -2310,7 +2318,33 @@ func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
   omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
     // CHECK: omp.terminator
     omp.terminator
-  } {operandSegmentSizes = array<i32: 0,0,3,0,0,0,0,0,0,0,0,0,0>}
+  } {operandSegmentSizes = array<i32: 0,0,3,0,0,0,0,0,0,0,0,0,0,0>}
+  return
+}
+// dyn_groupprivate on omp.target; modifiers are reprinted in a fixed order.
+// CHECK-LABEL: @omp_target_dyn_groupprivate
+func.func @omp_target_dyn_groupprivate(%dyn_size: i32, %large_size: i64) {
+  // CHECK: omp.target dyn_groupprivate(%{{.*}} : i32)
+  omp.target dyn_groupprivate(%dyn_size : i32) {
+    omp.terminator
+  }
+  // CHECK: omp.target dyn_groupprivate(cgroup, %{{.*}} : i64)
+  omp.target dyn_groupprivate(cgroup, %large_size : i64) {
+    omp.terminator
+  }
+  // CHECK: omp.target dyn_groupprivate(cgroup, fallback(abort), %{{.*}} : i32)
+  omp.target dyn_groupprivate(cgroup, fallback(abort), %dyn_size : i32) {
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  // CHECK: omp.target dyn_groupprivate(cgroup, fallback(null), %{{.*}} : i32)
+  omp.target dyn_groupprivate(fallback(null), cgroup, %dyn_size : i32) {
+    omp.terminator
+  }
+  // CHECK: omp.target dyn_groupprivate(%{{.*}} : i64)
+  omp.target dyn_groupprivate(%large_size : i64) {
+    omp.terminator
+  }
   return
 }
 

diff  --git a/mlir/test/Target/LLVMIR/openmp-target-dyn-groupprivate.mlir b/mlir/test/Target/LLVMIR/openmp-target-dyn-groupprivate.mlir
new file mode 100644
index 0000000000000..e08ecfbb13959
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-dyn-groupprivate.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// Test that dyn_groupprivate size is correctly cast to i32 when a
+// different integer type is used, matching the uint32_t DynCGroupMem
+// field in __tgt_kernel_arguments.
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+  llvm.func @target_dyn_groupprivate_i64(%size : i64) {
+    omp.target dyn_groupprivate(%size : i64) {
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  llvm.func @target_dyn_groupprivate_i16(%size : i16) {
+    omp.target dyn_groupprivate(%size : i16) {
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  llvm.func @target_dyn_groupprivate_i32(%size : i32) {
+    omp.target dyn_groupprivate(%size : i32) {
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// CHECK-LABEL: define void @target_dyn_groupprivate_i64
+// CHECK: %[[TRUNC:.*]] = trunc i64 %{{.*}} to i32
+// CHECK: %[[GEP:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %{{.*}}, i32 0, i32 12
+// CHECK: store i32 %[[TRUNC]], ptr %[[GEP]], align 4
+
+// CHECK-LABEL: define void @target_dyn_groupprivate_i16
+// CHECK: %[[ZEXT:.*]] = zext i16 %{{.*}} to i32
+// CHECK: %[[GEP2:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %{{.*}}, i32 0, i32 12
+// CHECK: store i32 %[[ZEXT]], ptr %[[GEP2]], align 4
+
+// CHECK-LABEL: define void @target_dyn_groupprivate_i32
+// CHECK-NOT: trunc
+// CHECK-NOT: zext
+// CHECK: %[[GEP3:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %{{.*}}, i32 0, i32 12
+// CHECK: store i32 %{{.*}}, ptr %[[GEP3]], align 4

diff  --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 3e521fb4f9263..295ba54dbfb38 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -460,6 +460,17 @@ llvm.func @teams_thread_limit_multi_dim(%lb : i32, %ub : i32) {
 
 // -----
 
+llvm.func @teams_dyn_groupprivate(%dyn_size : i32) {
+  // expected-error @below {{not yet implemented: Unhandled clause dyn_groupprivate in omp.teams operation}}
+  // expected-error @below {{LLVM Translation failed for operation: omp.teams}}
+  omp.teams dyn_groupprivate(%dyn_size : i32) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
 llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
   // expected-error @below {{not yet implemented: Unhandled clause allocate in omp.wsloop operation}}
   // expected-error @below {{LLVM Translation failed for operation: omp.wsloop}}


        


More information about the flang-commits mailing list