[flang-commits] [flang] 179db7e - [MLIR][OpenMP] Add support for depend clause

Prabhdeep Singh Soni via flang-commits flang-commits at lists.llvm.org
Tue Feb 14 11:18:47 PST 2023


Author: Prabhdeep Singh Soni
Date: 2023-02-14T14:18:16-05:00
New Revision: 179db7efe567ed76e36b6c4d69605b426d8f70ca

URL: https://github.com/llvm/llvm-project/commit/179db7efe567ed76e36b6c4d69605b426d8f70ca
DIFF: https://github.com/llvm/llvm-project/commit/179db7efe567ed76e36b6c4d69605b426d8f70ca.diff

LOG: [MLIR][OpenMP] Add support for depend clause

This patch adds support for the OpenMP 4.0 depend clause (in, out,
inout) of the task construct to the definition of the OpenMP MLIR
dialect and translation from MLIR to LLVM IR using OMPIRBuilder.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D142730

Added: 
    

Modified: 
    flang/lib/Lower/OpenMP.cpp
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir
    mlir/test/Target/LLVMIR/openmp-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 38af6204ec9e7..5899ddbdcae6f 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -974,8 +974,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
     auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
         currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr,
         mergeableAttr, /*in_reduction_vars=*/ValueRange(),
-        /*in_reductions=*/nullptr, priorityClauseOperand, allocateOperands,
-        allocatorOperands);
+        /*in_reductions=*/nullptr, priorityClauseOperand, /*depends=*/nullptr,
+        /*depend_vars=*/ValueRange(), allocateOperands, allocatorOperands);
     createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList);
   } else if (blockDirective.v == llvm::omp::OMPD_taskgroup) {
     // TODO: Add task_reduction support

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 238bcfdc12d62..d494e89b1274a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -478,6 +478,26 @@ def YieldOp : OpenMP_Op<"yield",
 // 2.10.1 task Construct
 //===----------------------------------------------------------------------===//
 
+def ClauseTaskDependIn    : I32EnumAttrCase<"taskdependin",    0>;
+def ClauseTaskDependOut   : I32EnumAttrCase<"taskdependout",   1>;
+def ClauseTaskDependInOut : I32EnumAttrCase<"taskdependinout", 2>;
+
+// Dependence kinds accepted by the `depend` clause of `omp.task`.
+def ClauseTaskDepend : I32EnumAttr<"ClauseTaskDepend", "task depend clause",
+    [ClauseTaskDependIn, ClauseTaskDependOut, ClauseTaskDependInOut]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::omp";
+}
+def ClauseTaskDependAttr :
+  EnumAttr<OpenMP_Dialect, ClauseTaskDepend, "clause_task_depend"> {
+  let assemblyFormat = "`(` $value `)`";
+}
+// Array of dependence kinds, one per `depend_vars` operand of `omp.task`.
+def TaskDependArrayAttr :
+  TypedArrayAttrBase<ClauseTaskDependAttr, "clause_task_depend array attr"> {
+  let constBuilderCall = ?;
+}
+
 def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
                        OutlineableOpenMPOpInterface, AutomaticAllocationScope,
                        ReductionClauseInterface]> {
@@ -518,6 +538,10 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
     default priority-value when no priority clause is specified should be
     assumed to be zero (the lowest priority).
 
+    The `depends` and `depend_vars` arguments specify this task's dependencies
+    on other tasks: `depend_vars` is a variadic list of pointer-like values and
+    `depends` holds the corresponding dependence kinds (in, out, inout).
+
     The `allocators_vars` and `allocate_vars` arguments are a variadic list of
     values that specify the memory allocator to be used to obtain storage for
     private values.
@@ -532,6 +556,8 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
                        Variadic<OpenMP_PointerLikeType>:$in_reduction_vars,
                        OptionalAttr<SymbolRefArrayAttr>:$in_reductions,
                        Optional<I32>:$priority,
+                       OptionalAttr<TaskDependArrayAttr>:$depends,
+                       Variadic<OpenMP_PointerLikeType>:$depend_vars,
                        Variadic<AnyType>:$allocate_vars,
                        Variadic<AnyType>:$allocators_vars);
   let regions = (region AnyRegion:$region);
@@ -550,6 +576,10 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
                 $allocate_vars, type($allocate_vars),
                 $allocators_vars, type($allocators_vars)
               ) `)`
+          |`depend` `(`
+              custom<DependVarList>(
+                $depend_vars, type($depend_vars), $depends
+              ) `)`
     ) $region attr-dict
   }];
   let extraClassDeclaration = [{

diff  --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index c12a2e9bd3348..822a1abd0b282 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -113,10 +113,10 @@ struct LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
 
 void mlir::configureOpenMPToLLVMConversionLegality(
     ConversionTarget &target, LLVMTypeConverter &typeConverter) {
-  target.addDynamicallyLegalOp<mlir::omp::CriticalOp, mlir::omp::ParallelOp,
-                               mlir::omp::WsLoopOp, mlir::omp::SimdLoopOp,
-                               mlir::omp::MasterOp, mlir::omp::SectionsOp,
-                               mlir::omp::SingleOp>([&](Operation *op) {
+  target.addDynamicallyLegalOp<
+      mlir::omp::CriticalOp, mlir::omp::ParallelOp, mlir::omp::WsLoopOp,
+      mlir::omp::SimdLoopOp, mlir::omp::MasterOp, mlir::omp::SectionsOp,
+      mlir::omp::SingleOp, mlir::omp::TaskOp>([&](Operation *op) {
     return typeConverter.isLegal(&op->getRegion(0)) &&
            typeConverter.isLegal(op->getOperandTypes()) &&
            typeConverter.isLegal(op->getResultTypes());
@@ -142,6 +142,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
       RegionOpConversion<omp::MasterOp>, RegionOpConversion<omp::ParallelOp>,
       RegionOpConversion<omp::WsLoopOp>, RegionOpConversion<omp::SectionsOp>,
       RegionOpConversion<omp::SimdLoopOp>, RegionOpConversion<omp::SingleOp>,
+      RegionOpConversion<omp::TaskOp>,
       RegionLessOpWithVarOperandsConversion<omp::AtomicReadOp>,
       RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
       RegionLessOpWithVarOperandsConversion<omp::FlushOp>,

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index aec35691a258c..79f339d26ad1e 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -464,6 +464,69 @@ static LogicalResult verifyReductionVarList(Operation *op,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Parser, printer and verifier for DependVarList
+//===----------------------------------------------------------------------===//
+
+/// Parses the entries of a `depend` clause; one kind is recorded per operand.
+/// depend-entry-list ::= depend-entry
+///                     | depend-entry-list `,` depend-entry
+/// depend-entry ::= depend-kind `->` ssa-id `:` type
+static ParseResult
+parseDependVarList(OpAsmParser &parser,
+                   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
+                   SmallVectorImpl<Type> &types, ArrayAttr &dependsArray) {
+  SmallVector<Attribute> kinds;
+  if (failed(parser.parseCommaSeparatedList([&]() -> ParseResult {
+        SMLoc kindLoc = parser.getCurrentLocation();
+        StringRef keyword;
+        if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
+            parser.parseOperand(operands.emplace_back()) ||
+            parser.parseColonType(types.emplace_back()))
+          return failure();
+        auto kind = symbolizeClauseTaskDepend(keyword);
+        if (!kind)
+          return parser.emitError(kindLoc)
+                 << "unknown depend kind '" << keyword << "'";
+        kinds.push_back(ClauseTaskDependAttr::get(parser.getContext(), *kind));
+        return success();
+      })))
+    return failure();
+  dependsArray = ArrayAttr::get(parser.getContext(), kinds);
+  return success();
+}
+
+/// Prints depend entries as `kind -> operand : type`, comma-separated.
+static void printDependVarList(OpAsmPrinter &p, Operation *op,
+                               OperandRange dependVars, TypeRange dependTypes,
+                               std::optional<ArrayAttr> depends) {
+  if (!depends)
+    return;
+  llvm::interleaveComma(
+      llvm::zip(*depends, dependVars, dependTypes), p, [&](auto entry) {
+        auto kind = std::get<0>(entry).template cast<ClauseTaskDependAttr>();
+        p << stringifyClauseTaskDepend(kind.getValue()) << " -> "
+          << std::get<1>(entry) << " : " << std::get<2>(entry);
+      });
+}
+
+/// Verifies the depend clause: `depends` must hold exactly one kind per depend
+/// variable, and must be absent when there are no depend variables.
+static LogicalResult verifyDependVarList(Operation *op,
+                                         std::optional<ArrayAttr> depends,
+                                         OperandRange dependVars) {
+  if (dependVars.empty()) {
+    if (depends)
+      return op->emitOpError() << "unexpected depend values";
+    return success();
+  }
+  // Kinds and variables are paired by position, so their counts must match.
+  if (!depends || depends->size() != dependVars.size())
+    return op->emitOpError() << "expected as many depend values"
+                                " as depend variables";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Parser, printer and verifier for Synchronization Hint (2.17.12)
 //===----------------------------------------------------------------------===//
@@ -958,7 +1021,12 @@ LogicalResult ReductionOp::verify() {
 // TaskOp
 //===----------------------------------------------------------------------===//
 LogicalResult TaskOp::verify() {
-  return verifyReductionVarList(*this, getInReductions(), getInReductionVars());
+  // Verify the depend clause first and stop at the first failure.
+  if (failed(verifyDependVarList(*this, getDepends(), getDependVars())))
+    return failure();
+  // Only then verify the in_reduction clause.
+  return verifyReductionVarList(*this, getInReductions(),
+                                getInReductionVars());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 6cf4a3c4e9c74..2cfdaa3f8730a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -693,11 +693,37 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
     convertOmpOpRegions(taskOp.getRegion(), "omp.task.region", builder,
                         moduleTranslation, bodyGenStatus);
   };
+
+  SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
+  if (!taskOp.getDependVars().empty() && taskOp.getDepends()) {
+    for (auto dep :
+         llvm::zip(taskOp.getDependVars(), taskOp.getDepends()->getValue())) {
+      llvm::omp::RTLDependenceKindTy type;
+      switch (
+          std::get<1>(dep).cast<mlir::omp::ClauseTaskDependAttr>().getValue()) {
+      case mlir::omp::ClauseTaskDepend::taskdependin:
+        type = llvm::omp::RTLDependenceKindTy::DepIn;
+        break;
+      // The OpenMP runtime requires that the codegen for 'depend' clause for
+      // 'out' dependency kind must be the same as codegen for 'depend' clause
+      // with 'inout' dependency.
+      case mlir::omp::ClauseTaskDepend::taskdependout:
+      case mlir::omp::ClauseTaskDepend::taskdependinout:
+        type = llvm::omp::RTLDependenceKindTy::DepInOut;
+        break;
+      };
+      llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
+      llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
+      dds.emplace_back(dd);
+    }
+  }
+
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTask(
-      ompLoc, allocaIP, bodyCB, !taskOp.getUntied()));
+      ompLoc, allocaIP, bodyCB, !taskOp.getUntied(), /*Final*/ nullptr,
+      /*IfCondition*/ nullptr, dds));
   return bodyGenStatus;
 }
 

diff  --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index b8e4ff36801a1..354c67912377b 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -148,6 +148,23 @@ func.func @simdloop_block_arg(%val : i32, %ub : i32, %i : index) {
 
 // -----
 
+// CHECK-LABEL: @task_depend
+// CHECK: (%[[DEP:.*]]: !llvm.ptr<i32>) {
+// CHECK: omp.task depend(taskdependin -> %[[DEP]] : !llvm.ptr<i32>) {
+// CHECK:   omp.terminator
+// CHECK: }
+// CHECK: llvm.return
+// CHECK: }
+
+func.func @task_depend(%dep: !llvm.ptr<i32>) {
+  omp.task depend(taskdependin -> %dep : !llvm.ptr<i32>) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @_QPomp_target_data
 // CHECK: (%[[ARG0:.*]]: !llvm.ptr<i32>, %[[ARG1:.*]]: !llvm.ptr<i32>, %[[ARG2:.*]]: !llvm.ptr<i32>, %[[ARG3:.*]]: !llvm.ptr<i32>)
 // CHECK:         omp.target_enter_data   map((to -> %[[ARG0]] : !llvm.ptr<i32>), (to -> %[[ARG1]] : !llvm.ptr<i32>), (always, alloc -> %[[ARG2]] : !llvm.ptr<i32>))

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 5619da6c9ec25..bdc935dd229a9 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1238,6 +1238,16 @@ func.func @omp_single(%data_var : memref<i32>) -> () {
 
 // -----
 
+func.func @omp_task_depend(%data_var: memref<i32>) {
+  // expected-error @below {{op expected as many depend values as depend variables}}
+  "omp.task"(%data_var) ({
+    "omp.terminator"() : () -> ()
+  }) {depends = [], operand_segment_sizes = array<i32: 0, 0, 0, 0, 1, 0, 0>} : (memref<i32>) -> ()
+  "func.return"() : () -> ()
+}
+
+// -----
+
 func.func @omp_task(%ptr: !llvm.ptr<f32>) {
   // expected-error @below {{op expected symbol reference @add_f32 to point to a reduction declaration}}
   omp.task in_reduction(@add_f32 -> %ptr : !llvm.ptr<f32>) {

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 358a63947a5ef..411a4c722b6f0 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1566,6 +1566,19 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr
   return
 }
 
+// CHECK-LABEL: @omp_task_depend
+// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
+func.func @omp_task_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
+  // CHECK: omp.task depend(taskdependin -> %arg0 : memref<i32>, taskdependout -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+  omp.task depend(taskdependin -> %arg0 : memref<i32>, taskdependout -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+    // CHECK: "test.foo"() : () -> ()
+    "test.foo"() : () -> ()
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
 func.func @omp_threadprivate() {
   %0 = arith.constant 1 : i32
   %1 = arith.constant 2 : i32

diff  --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index f96a5ff2768c6..f9bf47f388af0 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2237,6 +2237,55 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
 // CHECK:   ret void
 
 
+// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) {
+// CHECK:   call void @[[outlined_fn]]()
+// CHECK:   ret i32 0
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: define void @omp_task_with_deps
+// CHECK-SAME: (ptr %[[zaddr:.+]])
+// CHECK:  %[[dep_arr_addr:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
+// CHECK:  %[[dep_arr_addr_0:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[dep_arr_addr]], i64 0, i64 0
+// CHECK:  %[[dep_arr_addr_0_val:.+]] = getelementptr inbounds %struct.kmp_dep_info, ptr %[[dep_arr_addr_0]], i32 0, i32 0
+// CHECK:  %[[dep_arr_addr_0_val_int:.+]] = ptrtoint ptr %[[zaddr]] to i64
+// CHECK:  store i64 %[[dep_arr_addr_0_val_int]], ptr %[[dep_arr_addr_0_val]], align 4
+// CHECK:  %[[dep_arr_addr_0_size:.+]] = getelementptr inbounds %struct.kmp_dep_info, ptr %[[dep_arr_addr_0]], i32 0, i32 1
+// CHECK:  store i64 8, ptr %[[dep_arr_addr_0_size]], align 4
+// CHECK:  %[[dep_arr_addr_0_kind:.+]] = getelementptr inbounds %struct.kmp_dep_info, ptr %[[dep_arr_addr_0]], i32 0, i32 2
+// CHECK:  store i8 1, ptr %[[dep_arr_addr_0_kind]], align 1
+llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
+  // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
+  // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
+  // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0,
+  // CHECK-SAME:  i64 0, ptr @[[wrapper_fn:.+]])
+  // CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]], {{.*}})
+  omp.task depend(taskdependin -> %zaddr : !llvm.ptr<i32>) {
+    %n = llvm.mlir.constant(1 : i64) : i64
+    %valaddr = llvm.alloca %n x i32 : (i64) -> !llvm.ptr<i32>
+    %val = llvm.load %valaddr : !llvm.ptr<i32>
+    %double = llvm.add %val, %val : i32
+    llvm.store %double, %valaddr : !llvm.ptr<i32>
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK: define internal void @[[outlined_fn:.+]]()
+// CHECK: task.alloca{{.*}}:
+// CHECK:   br label %[[task_body:[^, ]+]]
+// CHECK: [[task_body]]:
+// CHECK:   br label %[[task_region:[^, ]+]]
+// CHECK: [[task_region]]:
+// CHECK:   %[[alloca:.+]] = alloca i32, i64 1
+// CHECK:   %[[val:.+]] = load i32, ptr %[[alloca]]
+// CHECK:   %[[newval:.+]] = add i32 %[[val]], %[[val]]
+// CHECK:   store i32 %[[newval]], ptr %{{[^, ]+}}
+// CHECK:   br label %[[exit_stub:[^, ]+]]
+// CHECK: [[exit_stub]]:
+// CHECK:   ret void
+
 // CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) {
 // CHECK:   call void @[[outlined_fn]]()
 // CHECK:   ret i32 0


        


More information about the flang-commits mailing list