[flang-commits] [flang] d4e9ba5 - [mlir][OpenMP] Standardise representation of reduction clause (#96215)

via flang-commits flang-commits at lists.llvm.org
Thu Jun 27 04:06:27 PDT 2024


Author: Tom Eccles
Date: 2024-06-27T12:06:22+01:00
New Revision: d4e9ba59d6a2e334c983fa79f43b167d0583772b

URL: https://github.com/llvm/llvm-project/commit/d4e9ba59d6a2e334c983fa79f43b167d0583772b
DIFF: https://github.com/llvm/llvm-project/commit/d4e9ba59d6a2e334c983fa79f43b167d0583772b.diff

LOG: [mlir][OpenMP] Standardise representation of reduction clause (#96215)

Now all operations with a reduction clause have an array of bools
controlling whether each reduction variable should be passed by
reference or value.

This was already supported for Wsloop and Parallel. The operations newly
modified here currently have no flang lowering or translation to LLVM IR,
so no further changes are needed.

It isn't possible to exercise the new verifier check in
mlir/test/Dialect/OpenMP/invalid.mlir: the reduction-list parser records
exactly one byref flag per reduction variable, so there is no way of parsing
an operation that has an incorrect number of byref attributes. The
verifier check exists to pick up buggy operation builders or in-place
modification of an operation.

Added: 
    

Modified: 
    flang/lib/Lower/OpenMP/ClauseProcessor.cpp
    mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 0b4ecaac9d73c..f78cd0f9df1a1 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1027,7 +1027,8 @@ bool ClauseProcessor::processReduction(
 
         // Copy local lists into the output.
         llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
-        llvm::copy(reduceVarByRef, std::back_inserter(result.reduceVarByRef));
+        llvm::copy(reduceVarByRef,
+                   std::back_inserter(result.reductionVarsByRef));
         llvm::copy(reductionDeclSymbols,
                    std::back_inserter(result.reductionDeclSymbols));
 

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index c604531fd8bcc..386f9f3dcb689 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -94,6 +94,7 @@ struct IfClauseOps {
 
 struct InReductionClauseOps {
   llvm::SmallVector<Value> inReductionVars;
+  llvm::SmallVector<bool> inReductionVarsByRef;
   llvm::SmallVector<Attribute> inReductionDeclSymbols;
 };
 
@@ -178,7 +179,7 @@ struct ProcBindClauseOps {
 
 struct ReductionClauseOps {
   llvm::SmallVector<Value> reductionVars;
-  llvm::SmallVector<bool> reduceVarByRef;
+  llvm::SmallVector<bool> reductionVarsByRef;
   llvm::SmallVector<Attribute> reductionDeclSymbols;
 };
 
@@ -199,6 +200,7 @@ struct SimdlenClauseOps {
 
 struct TaskReductionClauseOps {
   llvm::SmallVector<Value> taskReductionVars;
+  llvm::SmallVector<bool> taskReductionVarsByRef;
   llvm::SmallVector<Attribute> taskReductionDeclSymbols;
 };
 

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 437159383e504..bba8a29a5599a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -250,6 +250,7 @@ def TeamsOp : OpenMP_Op<"teams", [
                        Variadic<AnyType>:$allocate_vars,
                        Variadic<AnyType>:$allocators_vars,
                        Variadic<OpenMP_PointerLikeType>:$reduction_vars,
+                       OptionalAttr<DenseBoolArrayAttr>:$reduction_vars_byref,
                        OptionalAttr<SymbolRefArrayAttr>:$reductions);
 
   let regions = (region AnyRegion:$region);
@@ -266,8 +267,8 @@ def TeamsOp : OpenMP_Op<"teams", [
     | `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
     | `reduction` `(`
         custom<ReductionVarList>(
-          $reduction_vars, type($reduction_vars), $reductions
-        ) `)`
+          $reduction_vars, type($reduction_vars), $reduction_vars_byref,
+          $reductions ) `)`
     | `allocate` `(`
         custom<AllocateAndAllocator>(
           $allocate_vars, type($allocate_vars),
@@ -310,7 +311,9 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
     by the accumulator it uses and accumulators must not be repeated in the same
     reduction. The reduction declaration specifies how to combine the values
     from each section into the final value, which is available in the
-    accumulator after all the sections complete.
+    accumulator after all the sections complete. True values in
+    reduction_vars_byref indicate that the reduction variable should be passed
+    by reference.
 
     The $allocators_vars and $allocate_vars parameters are a variadic list of values
     that specify the memory allocator to be used to obtain storage for private values.
@@ -319,6 +322,7 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
     implicit barrier at the end of the construct.
   }];
   let arguments = (ins Variadic<OpenMP_PointerLikeType>:$reduction_vars,
+                       OptionalAttr<DenseBoolArrayAttr>:$reduction_vars_byref,
                        OptionalAttr<SymbolRefArrayAttr>:$reductions,
                        Variadic<AnyType>:$allocate_vars,
                        Variadic<AnyType>:$allocators_vars,
@@ -333,7 +337,8 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
   let assemblyFormat = [{
     oilist( `reduction` `(`
               custom<ReductionVarList>(
-                $reduction_vars, type($reduction_vars), $reductions
+                $reduction_vars, type($reduction_vars), $reduction_vars_byref,
+                $reductions
               ) `)`
           | `allocate` `(`
               custom<AllocateAndAllocator>(
@@ -793,6 +798,8 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
 
     The `in_reduction` clause specifies that this particular task (among all the
     tasks in current taskgroup, if any) participates in a reduction.
+    `in_reduction_vars_byref` indicates whether each reduction variable should
+    be passed by value or by reference.
 
     The `priority` clause is a hint for the priority of the generated task.
     The `priority` is a non-negative integer expression that provides a hint for
@@ -818,6 +825,7 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
                        UnitAttr:$untied,
                        UnitAttr:$mergeable,
                        Variadic<OpenMP_PointerLikeType>:$in_reduction_vars,
+                       OptionalAttr<DenseBoolArrayAttr>:$in_reduction_vars_byref,
                        OptionalAttr<SymbolRefArrayAttr>:$in_reductions,
                        Optional<I32>:$priority,
                        OptionalAttr<TaskDependArrayAttr>:$depends,
@@ -835,7 +843,8 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
           |`mergeable` $mergeable
           |`in_reduction` `(`
               custom<ReductionVarList>(
-                $in_reduction_vars, type($in_reduction_vars), $in_reductions
+                $in_reduction_vars, type($in_reduction_vars),
+                $in_reduction_vars_byref, $in_reductions
               ) `)`
           |`priority` `(` $priority `)`
           |`allocate` `(`
@@ -962,8 +971,10 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
                        UnitAttr:$untied,
                        UnitAttr:$mergeable,
                        Variadic<OpenMP_PointerLikeType>:$in_reduction_vars,
+                       OptionalAttr<DenseBoolArrayAttr>:$in_reduction_vars_byref,
                        OptionalAttr<SymbolRefArrayAttr>:$in_reductions,
                        Variadic<OpenMP_PointerLikeType>:$reduction_vars,
+                       OptionalAttr<DenseBoolArrayAttr>:$reduction_vars_byref,
                        OptionalAttr<SymbolRefArrayAttr>:$reductions,
                        Optional<IntLikeType>:$priority,
                        Variadic<AnyType>:$allocate_vars,
@@ -985,11 +996,13 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
           |`mergeable` $mergeable
           |`in_reduction` `(`
               custom<ReductionVarList>(
-                $in_reduction_vars, type($in_reduction_vars), $in_reductions
+                $in_reduction_vars, type($in_reduction_vars),
+                $in_reduction_vars_byref, $in_reductions
               ) `)`
           |`reduction` `(`
               custom<ReductionVarList>(
-                $reduction_vars, type($reduction_vars), $reductions
+                $reduction_vars, type($reduction_vars), $reduction_vars_byref,
+                $reductions
               ) `)`
           |`priority` `(` $priority `:` type($priority) `)`
           |`allocate` `(`
@@ -1040,6 +1053,7 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", [AttrSizedOperandSegments,
   }];
 
   let arguments = (ins Variadic<OpenMP_PointerLikeType>:$task_reduction_vars,
+                       OptionalAttr<DenseBoolArrayAttr>:$task_reduction_vars_byref,
                        OptionalAttr<SymbolRefArrayAttr>:$task_reductions,
                        Variadic<AnyType>:$allocate_vars,
                        Variadic<AnyType>:$allocators_vars);
@@ -1053,7 +1067,8 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", [AttrSizedOperandSegments,
   let assemblyFormat = [{
     oilist(`task_reduction` `(`
               custom<ReductionVarList>(
-                $task_reduction_vars, type($task_reduction_vars), $task_reductions
+                $task_reduction_vars, type($task_reduction_vars),
+                $task_reduction_vars_byref, $task_reductions
               ) `)`
           |`allocate` `(`
               custom<AllocateAndAllocator>(

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index fbad80a2480bf..c0be9e919d2fd 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -48,6 +48,11 @@ static ArrayAttr makeArrayAttr(MLIRContext *context,
   return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
 }
 
+static DenseBoolArrayAttr
+makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) {
+  return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
+}
+
 namespace {
 struct MemRefPointerLikeModel
     : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
@@ -499,7 +504,7 @@ static ParseResult parseClauseWithRegionArgs(
             return success();
           })))
     return failure();
-  isByRef = DenseBoolArrayAttr::get(parser.getContext(), isByRefVec);
+  isByRef = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
 
   auto *argsBegin = regionPrivateArgs.begin();
   MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
@@ -591,7 +596,7 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
     mlir::SmallVector<bool> isByRefVec;
     isByRefVec.resize(privateVarTypes.size(), false);
     DenseBoolArrayAttr isByRef =
-        DenseBoolArrayAttr::get(op->getContext(), isByRefVec);
+        makeDenseBoolArrayAttr(op->getContext(), isByRefVec);
 
     printClauseWithRegionArgs(p, op, argsSubrange, "private",
                               privateVarOperands, privateVarTypes, isByRef,
@@ -607,18 +612,22 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
 static ParseResult
 parseReductionVarList(OpAsmParser &parser,
                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
-                      SmallVectorImpl<Type> &types,
+                      SmallVectorImpl<Type> &types, DenseBoolArrayAttr &isByRef,
                       ArrayAttr &redcuctionSymbols) {
   SmallVector<SymbolRefAttr> reductionVec;
+  SmallVector<bool> isByRefVec;
   if (failed(parser.parseCommaSeparatedList([&]() {
+        ParseResult optionalByref = parser.parseOptionalKeyword("byref");
         if (parser.parseAttribute(reductionVec.emplace_back()) ||
             parser.parseArrow() ||
             parser.parseOperand(operands.emplace_back()) ||
             parser.parseColonType(types.emplace_back()))
           return failure();
+        isByRefVec.push_back(optionalByref.succeeded());
         return success();
       })))
     return failure();
+  isByRef = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
   SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
   redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
   return success();
@@ -628,11 +637,21 @@ parseReductionVarList(OpAsmParser &parser,
 static void printReductionVarList(OpAsmPrinter &p, Operation *op,
                                   OperandRange reductionVars,
                                   TypeRange reductionTypes,
+                                  std::optional<DenseBoolArrayAttr> isByRef,
                                   std::optional<ArrayAttr> reductions) {
-  for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
+  auto getByRef = [&](unsigned i) -> const char * {
+    if (!isByRef || !*isByRef)
+      return "";
+    assert(isByRef->empty() || i < isByRef->size());
+    if (!isByRef->empty() && (*isByRef)[i])
+      return "byref ";
+    return "";
+  };
+
+  for (unsigned i = 0, e = reductionVars.size(); i < e; ++i) {
     if (i != 0)
       p << ", ";
-    p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
+    p << getByRef(i) << (*reductions)[i] << " -> " << reductionVars[i] << " : "
       << reductionVars[i].getType();
   }
 }
@@ -641,16 +660,12 @@ static void printReductionVarList(OpAsmPrinter &p, Operation *op,
 static LogicalResult
 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductions,
                        OperandRange reductionVars,
-                       std::optional<ArrayRef<bool>> byRef = std::nullopt) {
+                       std::optional<ArrayRef<bool>> byRef) {
   if (!reductionVars.empty()) {
     if (!reductions || reductions->size() != reductionVars.size())
       return op->emitOpError()
              << "expected as many reduction symbol references "
                 "as reduction variables";
-    if (mlir::isa<omp::WsloopOp, omp::ParallelOp>(op))
-      assert(byRef);
-    else
-      assert(!byRef); // TODO: support byref reductions on other operations
     if (byRef && byRef->size() != reductionVars.size())
       return op->emitError() << "expected as many reduction variable by "
                                 "reference attributes as reduction variables";
@@ -1492,7 +1507,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
   ParallelOp::build(builder, state, clauses.ifVar, clauses.numThreadsVar,
                     clauses.allocateVars, clauses.allocatorVars,
                     clauses.reductionVars,
-                    DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef),
+                    makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
                     makeArrayAttr(ctx, clauses.reductionDeclSymbols),
                     clauses.procBindKindAttr, clauses.privateVars,
                     makeArrayAttr(ctx, clauses.privatizers));
@@ -1590,6 +1605,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
                  clauses.numTeamsUpperVar, clauses.ifVar,
                  clauses.threadLimitVar, clauses.allocateVars,
                  clauses.allocatorVars, clauses.reductionVars,
+                 makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
                  makeArrayAttr(ctx, clauses.reductionDeclSymbols));
 }
 
@@ -1621,7 +1637,8 @@ LogicalResult TeamsOp::verify() {
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
-  return verifyReductionVarList(*this, getReductions(), getReductionVars());
+  return verifyReductionVarList(*this, getReductions(), getReductionVars(),
+                                getReductionVarsByref());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1633,6 +1650,7 @@ void SectionsOp::build(OpBuilder &builder, OperationState &state,
   MLIRContext *ctx = builder.getContext();
   // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
   SectionsOp::build(builder, state, clauses.reductionVars,
+                    makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
                     makeArrayAttr(ctx, clauses.reductionDeclSymbols),
                     clauses.allocateVars, clauses.allocatorVars,
                     clauses.nowaitAttr);
@@ -1643,7 +1661,8 @@ LogicalResult SectionsOp::verify() {
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
-  return verifyReductionVarList(*this, getReductions(), getReductionVars());
+  return verifyReductionVarList(*this, getReductions(), getReductionVars(),
+                                getReductionVarsByref());
 }
 
 LogicalResult SectionsOp::verifyRegions() {
@@ -1733,7 +1752,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
   // privatizers.
   WsloopOp::build(builder, state, clauses.linearVars, clauses.linearStepVars,
                   clauses.reductionVars,
-                  DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef),
+                  makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
                   makeArrayAttr(ctx, clauses.reductionDeclSymbols),
                   clauses.scheduleValAttr, clauses.scheduleChunkVar,
                   clauses.scheduleModAttr, clauses.scheduleSimdAttr,
@@ -1934,6 +1953,7 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
   TaskOp::build(
       builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
       clauses.mergeableAttr, clauses.inReductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.inReductionVarsByRef),
       makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.priorityVar,
       makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
       clauses.allocateVars, clauses.allocatorVars);
@@ -1945,7 +1965,8 @@ LogicalResult TaskOp::verify() {
   return failed(verifyDependVars)
              ? verifyDependVars
              : verifyReductionVarList(*this, getInReductions(),
-                                      getInReductionVars());
+                                      getInReductionVars(),
+                                      getInReductionVarsByref());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1955,14 +1976,17 @@ LogicalResult TaskOp::verify() {
 void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
                         const TaskgroupClauseOps &clauses) {
   MLIRContext *ctx = builder.getContext();
-  TaskgroupOp::build(builder, state, clauses.taskReductionVars,
-                     makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
-                     clauses.allocateVars, clauses.allocatorVars);
+  TaskgroupOp::build(
+      builder, state, clauses.taskReductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.taskReductionVarsByRef),
+      makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
+      clauses.allocateVars, clauses.allocatorVars);
 }
 
 LogicalResult TaskgroupOp::verify() {
   return verifyReductionVarList(*this, getTaskReductions(),
-                                getTaskReductionVars());
+                                getTaskReductionVars(),
+                                getTaskReductionVarsByref());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1976,7 +2000,9 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
   TaskloopOp::build(
       builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
       clauses.mergeableAttr, clauses.inReductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.inReductionVarsByRef),
       makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
       makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar,
       clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar,
       clauses.numTasksVar, clauses.nogroupAttr);
@@ -1994,10 +2020,11 @@ LogicalResult TaskloopOp::verify() {
   if (getAllocateVars().size() != getAllocatorsVars().size())
     return emitError(
         "expected equal sizes for allocate and allocator variables");
-  if (failed(
-          verifyReductionVarList(*this, getReductions(), getReductionVars())) ||
+  if (failed(verifyReductionVarList(*this, getReductions(), getReductionVars(),
+                                    getReductionVarsByref())) ||
       failed(verifyReductionVarList(*this, getInReductions(),
-                                    getInReductionVars())))
+                                    getInReductionVars(),
+                                    getInReductionVarsByref())))
     return failure();
 
   if (!getReductionVars().empty() && getNogroup())

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 2edcbefa3df02..56e65213d147b 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1047,6 +1047,14 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
     omp.terminator
   }
 
+  // Test reduction byref
+  // CHECK: omp.teams reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr) {
+  omp.teams reduction(byref @add_f32 -> %0 : !llvm.ptr) {
+    %1 = arith.constant 2.0 : f32
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+
   // Test allocate.
   // CHECK: omp.teams allocate(%{{.+}} : memref<i32> -> %{{.+}} : memref<i32>)
   omp.teams allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
@@ -1078,6 +1086,27 @@ func.func @sections_reduction() {
   return
 }
 
+// CHECK-LABEL: func @sections_reduction_byref
+func.func @sections_reduction_byref() {
+  %c1 = arith.constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
+  // CHECK: omp.sections reduction(byref @add_f32 -> {{.+}} : !llvm.ptr)
+  omp.sections reduction(byref @add_f32 -> %0 : !llvm.ptr) {
+    // CHECK: omp.section
+    omp.section {
+      %1 = arith.constant 2.0 : f32
+      omp.terminator
+    }
+    // CHECK: omp.section
+    omp.section {
+      %1 = arith.constant 3.0 : f32
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
 // CHECK: omp.declare_reduction
 // CHECK-LABEL: @add2_f32
 omp.declare_reduction @add2_f32 : f32
@@ -2010,6 +2039,15 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr
     omp.terminator
   }
 
+  // Checking `in_reduction` clause (mixed) byref
+  // CHECK: omp.task in_reduction(byref @add_f32 -> %[[redn_var1]] : !llvm.ptr, @add_f32 -> %[[redn_var2]] : !llvm.ptr) {
+  omp.task in_reduction(byref @add_f32 -> %0 : !llvm.ptr, @add_f32 -> %1 : !llvm.ptr) {
+    // CHECK: "test.foo"() : () -> ()
+    "test.foo"() : () -> ()
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+
   // Checking priority clause
   // CHECK: omp.task priority(%[[i32_var]]) {
   omp.task priority(%i32_var) {
@@ -2031,8 +2069,8 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr
   // Checking multiple clauses
   // CHECK: omp.task if(%[[bool_var]]) final(%[[bool_var]]) untied
   omp.task if(%bool_var) final(%bool_var) untied
-      // CHECK-SAME: in_reduction(@add_f32 -> %[[redn_var1]] : !llvm.ptr, @add_f32 -> %[[redn_var2]] : !llvm.ptr)
-      in_reduction(@add_f32 -> %0 : !llvm.ptr, @add_f32 -> %1 : !llvm.ptr)
+      // CHECK-SAME: in_reduction(@add_f32 -> %[[redn_var1]] : !llvm.ptr, byref @add_f32 -> %[[redn_var2]] : !llvm.ptr)
+      in_reduction(@add_f32 -> %0 : !llvm.ptr, byref @add_f32 -> %1 : !llvm.ptr)
       // CHECK-SAME: priority(%[[i32_var]])
       priority(%i32_var)
       // CHECK-SAME: allocate(%[[data_var]] : memref<i32> -> %[[data_var]] : memref<i32>)
@@ -2287,8 +2325,26 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
     }
   }
 
-  // CHECK: omp.taskloop reduction(@add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr) {
-  omp.taskloop reduction(@add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr) {
+  // Checking byref attribute for in_reduction
+  // CHECK: omp.taskloop in_reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr) {
+  omp.taskloop in_reduction(byref @add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr) {
+    omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+      // CHECK: omp.yield
+      omp.yield
+    }
+  }
+
+  // CHECK: omp.taskloop reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr) {
+  omp.taskloop reduction(byref @add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr) {
+    omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+      // CHECK: omp.yield
+      omp.yield
+    }
+  }
+
+  // check byref attribute for reduction
+  // CHECK: omp.taskloop reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr, byref @add_f32 -> %{{.+}} : !llvm.ptr) {
+  omp.taskloop reduction(byref @add_f32 -> %testf32 : !llvm.ptr, byref @add_f32 -> %testf32_2 : !llvm.ptr) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
       // CHECK: omp.yield
       omp.yield


        


More information about the flang-commits mailing list