[flang-commits] [flang] 7ff0dc4 - [mlir][OpenMP] Add iterator support to depend clause (#189090)

via flang-commits flang-commits at lists.llvm.org
Tue Mar 31 09:11:17 PDT 2026


Author: Chi-Chun, Chen
Date: 2026-03-31T11:11:08-05:00
New Revision: 7ff0dc4b9fdee840de0901e969ea11880a28d433

URL: https://github.com/llvm/llvm-project/commit/7ff0dc4b9fdee840de0901e969ea11880a28d433
DIFF: https://github.com/llvm/llvm-project/commit/7ff0dc4b9fdee840de0901e969ea11880a28d433.diff

LOG: [mlir][OpenMP] Add iterator support to depend clause (#189090)

Extend the depend clause to support `!omp.iterated<Ty>` handles
alongside plain depend vars, so the IR can represent both forms.

Assisted by Copilot

This is part of feature work for
https://github.com/llvm/llvm-project/issues/188061

Added: 
    

Modified: 
    flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
    flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
    mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir
    mlir/test/Target/LLVMIR/openmp-todo.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
index 472d6a9f08a6e..475ed35cac9fa 100644
--- a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
+++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
@@ -309,6 +309,8 @@ class FunctionFilteringPass
       // Variables unused by the device.
       targetOp.getDependVarsMutable().clear();
       targetOp.setDependKindsAttr(nullptr);
+      targetOp.getDependIteratedMutable().clear();
+      targetOp.setDependIteratedKindsAttr(nullptr);
       targetOp.getDeviceMutable().clear();
       targetOp.getIfExprMutable().clear();
 

diff  --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 2c7980064500f..8a08f67006c0a 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -760,6 +760,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
       rewriter, targetOp.getLoc(), targetOp.getAllocateVars(),
       targetOp.getAllocatorVars(), targetOp.getBareAttr(),
       targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
       targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
       targetOp.getHostEvalVars(), targetOp.getIfExpr(),
       targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
@@ -1480,6 +1481,7 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
       rewriter, targetOp.getLoc(), targetOp.getAllocateVars(),
       targetOp.getAllocatorVars(), targetOp.getBareAttr(),
       targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
       targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars,
       targetOp.getIfExpr(), targetOp.getInReductionVars(),
       targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
@@ -1570,6 +1572,7 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
       rewriter, targetOp.getLoc(), targetOp.getAllocateVars(),
       targetOp.getAllocatorVars(), targetOp.getBareAttr(),
       targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
       targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
       isolatedHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(),
       targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
@@ -1650,6 +1653,7 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
       rewriter, targetOp.getLoc(), targetOp.getAllocateVars(),
       targetOp.getAllocatorVars(), targetOp.getBareAttr(),
       targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
       targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars,
       targetOp.getIfExpr(), targetOp.getInReductionVars(),
       targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index b5a047dc613ff..f24efd0d4fc42 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -316,20 +316,26 @@ class OpenMP_DependClauseSkip<
     bit description = false, bit extraClassDeclaration = false
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
-  let arguments = (ins
-    OptionalAttr<TaskDependArrayAttr>:$depend_kinds,
-    Variadic<OpenMP_PointerLikeType>:$depend_vars
-  );
+  let arguments = (ins OptionalAttr<TaskDependArrayAttr>:$depend_kinds,
+      Variadic<OpenMP_PointerLikeType>:$depend_vars,
+      OptionalAttr<TaskDependArrayAttr>:$depend_iterated_kinds,
+      Variadic<OpenMP_IteratedType>:$depend_iterated);
 
   let optAssemblyFormat = [{
     `depend` `(`
-      custom<DependVarList>($depend_vars, type($depend_vars), $depend_kinds) `)`
+      custom<DependVarList>($depend_vars, type($depend_vars), $depend_kinds,
+                            $depend_iterated, type($depend_iterated),
+                            $depend_iterated_kinds) `)`
   }];
 
   let description = [{
     The `depend_kinds` and `depend_vars` arguments are variadic lists of values
     that specify the dependencies of this particular task in relation to other
     tasks.
+
+    The `depend_iterated_kinds` and `depend_iterated` arguments are variadic
+    lists of iterator-produced handles (from `omp.iterator`) that specify
+    dependencies expanded at runtime via an iterator modifier.
   }];
 }
 

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 88c8ab4f6f949..6a2fd44841572 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1665,9 +1665,9 @@ def OrderedRegionOp : OpenMP_Op<"ordered.region", clauses = [
 // 2.17.5 taskwait Construct
 //===----------------------------------------------------------------------===//
 
-def TaskwaitOp : OpenMP_Op<"taskwait", clauses = [
-    OpenMP_DependClause, OpenMP_NowaitClause
-  ]> {
+def TaskwaitOp
+    : OpenMP_Op<"taskwait", traits = [AttrSizedOperandSegments],
+                clauses = [OpenMP_DependClause, OpenMP_NowaitClause]> {
   let summary = "taskwait construct";
   let description = [{
     The taskwait construct specifies a wait on the completion of child tasks

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 04418ee39be54..82fbc909f5275 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1712,50 +1712,77 @@ verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars,
 /// depend-entry-list ::= depend-entry
 ///                     | depend-entry-list `,` depend-entry
 /// depend-entry ::= depend-kind `->` ssa-id `:` type
-static ParseResult
-parseDependVarList(OpAsmParser &parser,
-                   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dependVars,
-                   SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
+///                | depend-kind `->` ssa-id `:` iterated-type
+static ParseResult parseDependVarList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dependVars,
+    SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &iteratedVars,
+    SmallVectorImpl<Type> &iteratedTypes, ArrayAttr &iteratedKinds) {
   SmallVector<ClauseTaskDependAttr> kindsVec;
+  SmallVector<ClauseTaskDependAttr> iterKindsVec;
   if (failed(parser.parseCommaSeparatedList([&]() {
         StringRef keyword;
+        OpAsmParser::UnresolvedOperand operand;
+        Type ty;
         if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
-            parser.parseOperand(dependVars.emplace_back()) ||
-            parser.parseColonType(dependTypes.emplace_back()))
+            parser.parseOperand(operand) || parser.parseColonType(ty))
           return failure();
-        if (std::optional<ClauseTaskDepend> keywordDepend =
-                (symbolizeClauseTaskDepend(keyword)))
-          kindsVec.emplace_back(
-              ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
-        else
+        std::optional<ClauseTaskDepend> keywordDepend =
+            symbolizeClauseTaskDepend(keyword);
+        if (!keywordDepend)
           return failure();
+        auto kindAttr =
+            ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend);
+        if (llvm::isa<mlir::omp::IteratedType>(ty)) {
+          iteratedVars.push_back(operand);
+          iteratedTypes.push_back(ty);
+          iterKindsVec.push_back(kindAttr);
+        } else {
+          dependVars.push_back(operand);
+          dependTypes.push_back(ty);
+          kindsVec.push_back(kindAttr);
+        }
         return success();
       })))
     return failure();
   SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
   dependKinds = ArrayAttr::get(parser.getContext(), kinds);
+  SmallVector<Attribute> iterKinds(iterKindsVec.begin(), iterKindsVec.end());
+  iteratedKinds = ArrayAttr::get(parser.getContext(), iterKinds);
   return success();
 }
 
 /// Print Depend clause
 static void printDependVarList(OpAsmPrinter &p, Operation *op,
                                OperandRange dependVars, TypeRange dependTypes,
-                               std::optional<ArrayAttr> dependKinds) {
-
-  for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
-    if (i != 0)
-      p << ", ";
-    p << stringifyClauseTaskDepend(
-             llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
-                 .getValue())
-      << " -> " << dependVars[i] << " : " << dependTypes[i];
-  }
+                               std::optional<ArrayAttr> dependKinds,
+                               OperandRange iteratedVars,
+                               TypeRange iteratedTypes,
+                               std::optional<ArrayAttr> iteratedKinds) {
+  bool first = true;
+  auto printEntries = [&](OperandRange vars, TypeRange types,
+                          std::optional<ArrayAttr> kinds) {
+    for (unsigned i = 0, e = vars.size(); i < e; ++i) {
+      if (!first)
+        p << ", ";
+      p << stringifyClauseTaskDepend(
+               llvm::cast<mlir::omp::ClauseTaskDependAttr>((*kinds)[i])
+                   .getValue())
+        << " -> " << vars[i] << " : " << types[i];
+      first = false;
+    }
+  };
+  printEntries(dependVars, dependTypes, dependKinds);
+  printEntries(iteratedVars, iteratedTypes, iteratedKinds);
 }
 
 /// Verifies Depend clause
 static LogicalResult verifyDependVarList(Operation *op,
                                          std::optional<ArrayAttr> dependKinds,
-                                         OperandRange dependVars) {
+                                         OperandRange dependVars,
+                                         std::optional<ArrayAttr> iteratedKinds,
+                                         OperandRange iteratedVars) {
   if (!dependVars.empty()) {
     if (!dependKinds || dependKinds->size() != dependVars.size())
       return op->emitOpError() << "expected as many depend values"
@@ -1763,7 +1790,15 @@ static LogicalResult verifyDependVarList(Operation *op,
   } else {
     if (dependKinds && !dependKinds->empty())
       return op->emitOpError() << "unexpected depend values";
-    return success();
+  }
+
+  if (!iteratedVars.empty()) {
+    if (!iteratedKinds || iteratedKinds->size() != iteratedVars.size())
+      return op->emitOpError() << "expected as many depend iterated values"
+                                  " as depend iterated variables";
+  } else {
+    if (iteratedKinds && !iteratedKinds->empty())
+      return op->emitOpError() << "unexpected depend iterated values";
   }
 
   return success();
@@ -2266,15 +2301,17 @@ void TargetEnterDataOp::build(
     OpBuilder &builder, OperationState &state,
     const TargetEnterExitUpdateDataOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  TargetEnterDataOp::build(builder, state,
-                           makeArrayAttr(ctx, clauses.dependKinds),
-                           clauses.dependVars, clauses.device, clauses.ifExpr,
-                           clauses.mapVars, clauses.nowait);
+  TargetEnterDataOp::build(
+      builder, state, makeArrayAttr(ctx, clauses.dependKinds),
+      clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds),
+      clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
+      clauses.nowait);
 }
 
 LogicalResult TargetEnterDataOp::verify() {
   LogicalResult verifyDependVars =
-      verifyDependVarList(*this, getDependKinds(), getDependVars());
+      verifyDependVarList(*this, getDependKinds(), getDependVars(),
+                          getDependIteratedKinds(), getDependIterated());
   return failed(verifyDependVars) ? verifyDependVars
                                   : verifyMapClause(*this, getMapVars());
 }
@@ -2286,15 +2323,17 @@ LogicalResult TargetEnterDataOp::verify() {
 void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
                              const TargetEnterExitUpdateDataOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  TargetExitDataOp::build(builder, state,
-                          makeArrayAttr(ctx, clauses.dependKinds),
-                          clauses.dependVars, clauses.device, clauses.ifExpr,
-                          clauses.mapVars, clauses.nowait);
+  TargetExitDataOp::build(
+      builder, state, makeArrayAttr(ctx, clauses.dependKinds),
+      clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds),
+      clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
+      clauses.nowait);
 }
 
 LogicalResult TargetExitDataOp::verify() {
   LogicalResult verifyDependVars =
-      verifyDependVarList(*this, getDependKinds(), getDependVars());
+      verifyDependVarList(*this, getDependKinds(), getDependVars(),
+                          getDependIteratedKinds(), getDependIterated());
   return failed(verifyDependVars) ? verifyDependVars
                                   : verifyMapClause(*this, getMapVars());
 }
@@ -2307,13 +2346,16 @@ void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
                            const TargetEnterExitUpdateDataOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
   TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
-                        clauses.dependVars, clauses.device, clauses.ifExpr,
+                        clauses.dependVars,
+                        makeArrayAttr(ctx, clauses.dependIteratedKinds),
+                        clauses.dependIterated, clauses.device, clauses.ifExpr,
                         clauses.mapVars, clauses.nowait);
 }
 
 LogicalResult TargetUpdateOp::verify() {
   LogicalResult verifyDependVars =
-      verifyDependVarList(*this, getDependKinds(), getDependVars());
+      verifyDependVarList(*this, getDependKinds(), getDependVars(),
+                          getDependIteratedKinds(), getDependIterated());
   return failed(verifyDependVars) ? verifyDependVars
                                   : verifyMapClause(*this, getMapVars());
 }
@@ -2327,20 +2369,24 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
   MLIRContext *ctx = builder.getContext();
   // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
   // inReductionByref, inReductionSyms.
-  TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
-                  clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
-                  clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
-                  clauses.hostEvalVars, clauses.ifExpr,
-                  /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
-                  /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
-                  clauses.mapVars, clauses.nowait, clauses.privateVars,
-                  makeArrayAttr(ctx, clauses.privateSyms),
-                  clauses.privateNeedsBarrier, clauses.threadLimitVars,
-                  /*private_maps=*/nullptr);
+  TargetOp::build(
+      builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.bare,
+      makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
+      makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
+      clauses.device, clauses.hasDeviceAddrVars, clauses.hostEvalVars,
+      clauses.ifExpr,
+      /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
+      /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars,
+      clauses.nowait, clauses.privateVars,
+      makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+      clauses.threadLimitVars,
+      /*private_maps=*/nullptr);
 }
 
 LogicalResult TargetOp::verify() {
-  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
+  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars(),
+                                 getDependIteratedKinds(),
+                                 getDependIterated())))
     return failure();
 
   if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
@@ -3270,21 +3316,23 @@ LogicalResult DeclareReductionOp::verifyRegions() {
 void TaskOp::build(OpBuilder &builder, OperationState &state,
                    const TaskOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  TaskOp::build(builder, state, clauses.iterated, clauses.affinityVars,
-                clauses.allocateVars, clauses.allocatorVars,
-                makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
-                clauses.final, clauses.ifExpr, clauses.inReductionVars,
-                makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
-                makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
-                clauses.priority, /*private_vars=*/clauses.privateVars,
-                /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
-                clauses.privateNeedsBarrier, clauses.untied,
-                clauses.eventHandle);
+  TaskOp::build(
+      builder, state, clauses.iterated, clauses.affinityVars,
+      clauses.allocateVars, clauses.allocatorVars,
+      makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
+      makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
+      clauses.final, clauses.ifExpr, clauses.inReductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
+      makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
+      clauses.priority, /*private_vars=*/clauses.privateVars,
+      /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
+      clauses.privateNeedsBarrier, clauses.untied, clauses.eventHandle);
 }
 
 LogicalResult TaskOp::verify() {
   LogicalResult verifyDependVars =
-      verifyDependVarList(*this, getDependKinds(), getDependVars());
+      verifyDependVarList(*this, getDependKinds(), getDependVars(),
+                          getDependIteratedKinds(), getDependIterated());
   return failed(verifyDependVars)
              ? verifyDependVars
              : verifyReductionVarList(*this, getInReductionSyms(),
@@ -4142,7 +4190,8 @@ void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
                        const TaskwaitOperands &clauses) {
   // TODO Store clauses in op: dependKinds, dependVars, nowait.
   TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
-                    /*depend_vars=*/{}, /*nowait=*/nullptr);
+                    /*depend_vars=*/{}, /*depend_iterated_kinds=*/nullptr,
+                    /*depend_iterated=*/{}, /*nowait=*/nullptr);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index aa8a9b9b1f7f6..281235a0af461 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -357,6 +357,11 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (!op.getDependVars().empty() || op.getDependKinds())
       result = todo("depend");
   };
+  auto checkDependIteratorModifier = [&todo](auto op, LogicalResult &result) {
+    if (!op.getDependIterated().empty() ||
+        (op.getDependIteratedKinds() && !op.getDependIteratedKinds()->empty()))
+      result = todo("depend with iterator modifier");
+  };
   auto checkHint = [](auto op, LogicalResult &) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
@@ -429,6 +434,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::TaskOp op) {
         checkAllocate(op, result);
+        checkDependIteratorModifier(op, result);
         checkInReduction(op, result);
       })
       .Case([&](omp::TaskgroupOp op) {
@@ -463,6 +469,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       .Case([&](omp::TargetOp op) {
         checkAllocate(op, result);
         checkBare(op, result);
+        checkDependIteratorModifier(op, result);
         checkInReduction(op, result);
         checkThreadLimit(op, result);
       })

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 4879ea754bf75..db5d1b60c5697 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1716,6 +1716,26 @@ func.func @omp_task_depend(%data_var: memref<i32>) {
 
 // -----
 
+func.func @omp_task_depend_iterated_no_vars(%data_var: memref<i32>) {
+  // expected-error @below {{op unexpected depend iterated values}}
+    "omp.task"() ({
+      "omp.terminator"() : () -> ()
+    }) {depend_iterated_kinds = [#omp<clause_task_depend(taskdependin)>], operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : () -> ()
+   "func.return"() : () -> ()
+}
+
+// -----
+
+func.func @omp_task_depend_iterated_mismatch(%it: !omp.iterated<!llvm.ptr>) {
+  // expected-error @below {{op expected as many depend iterated values as depend iterated variables}}
+    "omp.task"(%it) ({
+      "omp.terminator"() : () -> ()
+    }) {depend_iterated_kinds = [], operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>} : (!omp.iterated<!llvm.ptr>) -> ()
+   "func.return"() : () -> ()
+}
+
+// -----
+
 func.func @omp_task(%ptr: !llvm.ptr) {
   // expected-error @below {{op expected symbol reference @add_f32 to point to a reduction declaration}}
   omp.task in_reduction(@add_f32 %ptr -> %arg0 : !llvm.ptr) {
@@ -2274,7 +2294,7 @@ func.func @omp_target_enter_data(%map1: memref<?xi32>) {
 func.func @omp_target_enter_data_depend(%a: memref<?xi32>) {
   %0 = omp.map.info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
   // expected-error @below {{op expected as many depend values as depend variables}}
-  omp.target_enter_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 1, 0, 0, 0>}
+  omp.target_enter_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 1, 0, 0, 0, 0>}
   return
 }
 
@@ -2292,7 +2312,7 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {
 func.func @omp_target_exit_data_depend(%a: memref<?xi32>) {
   %0 = omp.map.info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
   // expected-error @below {{op expected as many depend values as depend variables}}
-  omp.target_exit_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 1, 0, 0, 0>}
+  omp.target_exit_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 1, 0, 0, 0, 0>}
   return
 }
 
@@ -2373,7 +2393,7 @@ llvm.mlir.global internal @_QFsubEx() : i32
 func.func @omp_target_update_data_depend(%a: memref<?xi32>) {
   %0 = omp.map.info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
   // expected-error @below {{op expected as many depend values as depend variables}}
-  omp.target_update map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 1, 0, 0, 0>}
+  omp.target_update map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 1, 0, 0, 0, 0>}
   return
 }
 
@@ -2475,7 +2495,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
   // expected-error @below {{op expected as many depend values as depend variables}}
     "omp.target"(%data_var) ({
       "omp.terminator"() : () -> ()
-    }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
+    }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
    "func.return"() : () -> ()
 }
 

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index d924d479eba90..b0554eba459f8 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -871,7 +871,7 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %devic
     "omp.target"(%device, %if_cond, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -2269,6 +2269,39 @@ func.func @omp_task_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
   return
 }
 
+// CHECK-LABEL: func.func @omp_task_depend_iterated
+func.func @omp_task_depend_iterated(%lb : index, %ub : index, %step : index,
+                                     %addr : !llvm.ptr) -> () {
+  // CHECK: %[[IT:.*]] = omp.iterator(%[[IV:.*]]: index) = (%{{.*}} to %{{.*}} step %{{.*}}) {
+  // CHECK:   omp.yield(%{{.*}} : !llvm.ptr)
+  // CHECK: } -> !omp.iterated<!llvm.ptr>
+  // CHECK: omp.task depend(taskdependin -> %[[IT]] : !omp.iterated<!llvm.ptr>) {
+  %it = omp.iterator(%iv: index) = (%lb to %ub step %step) {
+    omp.yield(%addr : !llvm.ptr)
+  } -> !omp.iterated<!llvm.ptr>
+
+  omp.task depend(taskdependin -> %it : !omp.iterated<!llvm.ptr>) {
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @omp_task_depend_iterated_mixed
+func.func @omp_task_depend_iterated_mixed(%lb : index, %ub : index, %step : index,
+                                           %addr : !llvm.ptr,
+                                           %plain : memref<i32>) -> () {
+  // CHECK: %[[IT:.*]] = omp.iterator
+  // CHECK: omp.task depend(taskdependout -> %{{.*}} : memref<i32>, taskdependin -> %[[IT]] : !omp.iterated<!llvm.ptr>) {
+  %it = omp.iterator(%iv: index) = (%lb to %ub step %step) {
+    omp.yield(%addr : !llvm.ptr)
+  } -> !omp.iterated<!llvm.ptr>
+
+  omp.task depend(taskdependout -> %plain : memref<i32>, taskdependin -> %it : !omp.iterated<!llvm.ptr>) {
+    omp.terminator
+  }
+  return
+}
+
 
 // CHECK-LABEL: @omp_target_depend
 // CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
@@ -2277,7 +2310,7 @@ func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
   omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
     // CHECK: omp.terminator
     omp.terminator
-  } {operandSegmentSizes = array<i32: 0,0,0,3,0,0,0,0>}
+  } {operandSegmentSizes = array<i32: 0,0,3,0,0,0,0,0,0,0,0,0,0>}
   return
 }
 

diff  --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 8fb66cb4dd0eb..ea7ec3cfc3bdb 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -190,6 +190,36 @@ llvm.func @target_in_reduction(%x : !llvm.ptr) {
 
 // -----
 
+llvm.func @task_depend_iterator_modifier(%lb : i64, %ub : i64, %step : i64,
+                                 %addr : !llvm.ptr) {
+  %it = omp.iterator(%iv: i64) = (%lb to %ub step %step) {
+    omp.yield(%addr : !llvm.ptr)
+  } -> !omp.iterated<!llvm.ptr>
+  // expected-error@below {{not yet implemented: Unhandled clause depend with iterator modifier in omp.task operation}}
+  // expected-error@below {{LLVM Translation failed for operation: omp.task}}
+  omp.task depend(taskdependin -> %it : !omp.iterated<!llvm.ptr>) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
+llvm.func @target_depend_iterator_modifier(%lb : i64, %ub : i64, %step : i64,
+                                   %addr : !llvm.ptr) {
+  %it = omp.iterator(%iv: i64) = (%lb to %ub step %step) {
+    omp.yield(%addr : !llvm.ptr)
+  } -> !omp.iterated<!llvm.ptr>
+  // expected-error@below {{not yet implemented: Unhandled clause depend with iterator modifier in omp.target operation}}
+  // expected-error@below {{LLVM Translation failed for operation: omp.target}}
+  omp.target depend(taskdependin -> %it : !omp.iterated<!llvm.ptr>) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
 llvm.func @target_enter_data_depend(%x: !llvm.ptr) {
   // expected-error@below {{not yet implemented: Unhandled clause depend in omp.target_enter_data operation}}
   // expected-error@below {{LLVM Translation failed for operation: omp.target_enter_data}}


        


More information about the flang-commits mailing list