[flang-commits] [flang] fcbf00f - [mlir][OpenMP] Added ReductionClauseInterface
Shraiysh Vaishay via flang-commits
flang-commits at lists.llvm.org
Mon Mar 28 01:54:39 PDT 2022
Author: Shraiysh Vaishay
Date: 2022-03-28T14:24:28+05:30
New Revision: fcbf00f098b234c205b1ee22b982e7e575d75f14
URL: https://github.com/llvm/llvm-project/commit/fcbf00f098b234c205b1ee22b982e7e575d75f14
DIFF: https://github.com/llvm/llvm-project/commit/fcbf00f098b234c205b1ee22b982e7e575d75f14.diff
LOG: [mlir][OpenMP] Added ReductionClauseInterface
This patch adds the ReductionClauseInterface and also adds reduction
support for the `omp.parallel` operation.
Reviewed By: kiranchandramohan
Differential Revision: https://reviews.llvm.org/D122402
Added:
Modified:
flang/lib/Lower/OpenMP.cpp
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/test/Dialect/OpenMP/ops.mlir
Removed:
################################################################################
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index bdecfeee968bd..b3b8fc3211956 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -205,7 +205,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
// Create and insert the operation.
auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
- ValueRange(), ValueRange(),
+ /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
+ /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
procBindClauseOperand.dyn_cast_or_null<omp::ClauseProcBindKindAttr>());
// Handle attribute based clauses.
for (const auto &clause : parallelOpClauseList.v) {
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index efc471833d1c7..0bc267c2124ca 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -66,7 +66,7 @@ def OpenMP_PointerLikeType : Type<
def ParallelOp : OpenMP_Op<"parallel", [
AutomaticAllocationScope, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>,
- RecursiveSideEffects]> {
+ RecursiveSideEffects, ReductionClauseInterface]> {
let summary = "parallel construct";
let description = [{
The parallel construct includes a region of code which is to be executed
@@ -83,6 +83,18 @@ def ParallelOp : OpenMP_Op<"parallel", [
The $allocators_vars and $allocate_vars parameters are a variadic list of values
that specify the memory allocator to be used to obtain storage for private values.
+ Reductions can be performed in a parallel construct by specifying reduction
+ accumulator variables in `reduction_vars` and symbols referring to reduction
+ declarations in the `reductions` attribute. Each reduction is identified
+ by the accumulator it uses and accumulators must not be repeated in the same
+ reduction. The `omp.reduction` operation accepts the accumulator and a
+ partial value which is considered to be produced by the thread for the
+ given reduction. If multiple values are produced for the same accumulator,
+ i.e. there are multiple `omp.reduction`s, the last value is taken. The
+ reduction declaration specifies how to combine the values from each thread
+ into the final value, which is available in the accumulator after all the
+ threads complete.
+
The optional $proc_bind_val attribute controls the thread affinity for the execution
of the parallel region.
}];
@@ -91,6 +103,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
Optional<AnyType>:$num_threads_var,
Variadic<AnyType>:$allocate_vars,
Variadic<AnyType>:$allocators_vars,
+ Variadic<OpenMP_PointerLikeType>:$reduction_vars,
+ OptionalAttr<SymbolRefArrayAttr>:$reductions,
OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
let regions = (region AnyRegion:$region);
@@ -99,7 +113,11 @@ def ParallelOp : OpenMP_Op<"parallel", [
OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
];
let assemblyFormat = [{
- oilist( `if` `(` $if_expr_var `:` type($if_expr_var) `)`
+ oilist( `reduction` `(`
+ custom<ReductionVarList>(
+ $reduction_vars, type($reduction_vars), $reductions
+ ) `)`
+ | `if` `(` $if_expr_var `:` type($if_expr_var) `)`
| `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
| `allocate` `(`
custom<AllocateAndAllocator>(
@@ -110,6 +128,12 @@ def ParallelOp : OpenMP_Op<"parallel", [
) $region attr-dict
}];
let hasVerifier = 1;
+ let extraClassDeclaration = [{
+ // TODO: remove this once emitAccessorPrefix is set to
+ // kEmitAccessorPrefix_Prefixed for the dialect.
+ /// Returns the reduction variables
+ operand_range getReductionVars() { return reduction_vars(); }
+ }];
}
def TerminatorOp : OpenMP_Op<"terminator", [Terminator]> {
@@ -156,7 +180,8 @@ def SectionOp : OpenMP_Op<"section", [HasParent<"SectionsOp">]> {
let assemblyFormat = "$region attr-dict";
}
-def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> {
+def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
+ ReductionClauseInterface]> {
let summary = "sections construct";
let description = [{
The sections construct is a non-iterative worksharing construct that
@@ -207,6 +232,13 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> {
let hasVerifier = 1;
let hasRegionVerifier = 1;
+
+ let extraClassDeclaration = [{
+ // TODO: remove this once emitAccessorPrefix is set to
+ // kEmitAccessorPrefix_Prefixed for the dialect.
+ /// Returns the reduction variables
+ operand_range getReductionVars() { return reduction_vars(); }
+ }];
}
//===----------------------------------------------------------------------===//
@@ -247,7 +279,7 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
- RecursiveSideEffects]> {
+ RecursiveSideEffects, ReductionClauseInterface]> {
let summary = "workshare loop construct";
let description = [{
The workshare loop construct specifies that the iterations of the loop(s)
@@ -338,6 +370,11 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
/// Returns the number of reduction variables.
unsigned getNumReductionVars() { return reduction_vars().size(); }
+
+ // TODO: remove this once emitAccessorPrefix is set to
+ // kEmitAccessorPrefix_Prefixed for the dialect.
+ /// Returns the reduction variables
+ operand_range getReductionVars() { return reduction_vars(); }
}];
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index a99e496e11b07..83180e03455e2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -31,4 +31,18 @@ def OutlineableOpenMPOpInterface : OpInterface<"OutlineableOpenMPOpInterface"> {
];
}
+def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> {
+ let description = [{
+ OpenMP operations that support reduction clause have this interface.
+ }];
+
+ let cppNamespace = "::mlir::omp";
+
+ let methods = [
+ InterfaceMethod<
+ "Get reduction vars", "::mlir::Operation::operand_range",
+ "getReductionVars">,
+ ];
+}
+
#endif // OpenMP_OPS_INTERFACES
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c43dea42aebc0..176c27b132164 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -27,6 +27,7 @@
#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
+#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
using namespace mlir;
@@ -58,19 +59,6 @@ void OpenMPDialect::initialize() {
MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}
-//===----------------------------------------------------------------------===//
-// ParallelOp
-//===----------------------------------------------------------------------===//
-
-void ParallelOp::build(OpBuilder &builder, OperationState &state,
- ArrayRef<NamedAttribute> attributes) {
- ParallelOp::build(
- builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
- /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
- /*proc_bind_val=*/nullptr);
- state.addAttributes(attributes);
-}
-
//===----------------------------------------------------------------------===//
// Parser and printer for Allocate Clause
//===----------------------------------------------------------------------===//
@@ -142,13 +130,6 @@ void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
p << stringifyEnum(attr.getValue());
}
-LogicalResult ParallelOp::verify() {
- if (allocate_vars().size() != allocators_vars().size())
- return emitError(
- "expected equal sizes for allocate and allocator variables");
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Parser and printer for Linear Clause
//===----------------------------------------------------------------------===//
@@ -469,6 +450,27 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
return success();
}
+//===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+void ParallelOp::build(OpBuilder &builder, OperationState &state,
+ ArrayRef<NamedAttribute> attributes) {
+ ParallelOp::build(
+ builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
+ /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
+ /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
+ /*proc_bind_val=*/nullptr);
+ state.addAttributes(attributes);
+}
+
+LogicalResult ParallelOp::verify() {
+ if (allocate_vars().size() != allocators_vars().size())
+ return emitError(
+ "expected equal sizes for allocate and allocator variables");
+ return verifyReductionVarList(*this, reductions(), reduction_vars());
+}
+
//===----------------------------------------------------------------------===//
// Verifier for SectionsOp
//===----------------------------------------------------------------------===//
@@ -709,13 +711,17 @@ LogicalResult ReductionDeclareOp::verifyRegions() {
}
LogicalResult ReductionOp::verify() {
- // TODO: generalize this to an op interface when there is more than one op
- // that supports reductions.
- auto container = (*this)->getParentOfType<WsLoopOp>();
- for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
- if (container.reduction_vars()[i] == accumulator())
- return success();
-
+ auto *op = (*this)->getParentWithTrait<ReductionClauseInterface::Trait>();
+ if (!op)
+ return emitOpError() << "must be used within an operation supporting "
+ "reduction clause interface";
+ while (op) {
+ for (const auto &var :
+ cast<ReductionClauseInterface>(op).getReductionVars())
+ if (var == accumulator())
+ return success();
+ op = op->getParentWithTrait<ReductionClauseInterface::Trait>();
+ }
return emitOpError() << "the accumulator is not used by the parent";
}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 60ef9667b36ea..9149feb7a92eb 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -59,7 +59,7 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
// CHECK: omp.parallel num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel"(%num_threads, %data_var, %data_var) ({
omp.terminator
- }) {operand_segment_sizes = dense<[0,1,1,1]>: vector<4xi32>} : (si32, memref<i32>, memref<i32>) -> ()
+ }) {operand_segment_sizes = dense<[0,1,1,1,0]> : vector<5xi32>} : (si32, memref<i32>, memref<i32>) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -68,22 +68,22 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
// CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel"(%if_cond, %data_var, %data_var) ({
omp.terminator
- }) {operand_segment_sizes = dense<[1,0,1,1]> : vector<4xi32>} : (i1, memref<i32>, memref<i32>) -> ()
+ }) {operand_segment_sizes = dense<[1,0,1,1,0]> : vector<5xi32>} : (i1, memref<i32>, memref<i32>) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
- }) {operand_segment_sizes = dense<[1,1,0,0]> : vector<4xi32>} : (i1, si32) -> ()
+ }) {operand_segment_sizes = dense<[1,1,0,0,0]> : vector<5xi32>} : (i1, si32) -> ()
omp.terminator
- }) {operand_segment_sizes = dense<[1,1,1,1]> : vector<4xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
+ }) {operand_segment_sizes = dense<[1,1,1,1,0]> : vector<5xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
- }) {operand_segment_sizes = dense<[0,0,1,1]> : vector<4xi32>} : (memref<i32>, memref<i32>) -> ()
+ }) {operand_segment_sizes = dense<[0,0,1,1,0]> : vector<5xi32>} : (memref<i32>, memref<i32>) -> ()
return
}
@@ -407,7 +407,8 @@ atomic {
omp.yield
}
-func @reduction(%lb : index, %ub : index, %step : index) {
+// CHECK-LABEL: func @wsloop_reduction
+func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
%c1 = arith.constant 1 : i32
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
// CHECK: reduction(@add_f32 -> %{{.+}} : !llvm.ptr<f32>)
@@ -421,6 +422,65 @@ func @reduction(%lb : index, %ub : index, %step : index) {
return
}
+// CHECK-LABEL: func @parallel_reduction
+func @parallel_reduction() {
+ %c1 = arith.constant 1 : i32
+ %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+ // CHECK: omp.parallel reduction(@add_f32 -> {{.+}} : !llvm.ptr<f32>)
+ omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
+ %1 = arith.constant 2.0 : f32
+ // CHECK: omp.reduction %{{.+}}, %{{.+}}
+ omp.reduction %1, %0 : !llvm.ptr<f32>
+ omp.terminator
+ }
+ return
+}
+
+// CHECK: func @parallel_wsloop_reduction
+func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) {
+ %c1 = arith.constant 1 : i32
+ %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+ // CHECK: omp.parallel reduction(@add_f32 -> %{{.+}} : !llvm.ptr<f32>) {
+ omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
+ // CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}})
+ omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+ %1 = arith.constant 2.0 : f32
+ // CHECK: omp.reduction %{{.+}}, %{{.+}} : !llvm.ptr<f32>
+ omp.reduction %1, %0 : !llvm.ptr<f32>
+ // CHECK: omp.yield
+ omp.yield
+ }
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func @sections_reduction
+func @sections_reduction() {
+ %c1 = arith.constant 1 : i32
+ %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+ // CHECK: omp.sections reduction(@add_f32 -> {{.+}} : !llvm.ptr<f32>)
+ omp.sections reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
+ // CHECK: omp.section
+ omp.section {
+ %1 = arith.constant 2.0 : f32
+ // CHECK: omp.reduction %{{.+}}, %{{.+}}
+ omp.reduction %1, %0 : !llvm.ptr<f32>
+ omp.terminator
+ }
+ // CHECK: omp.section
+ omp.section {
+ %1 = arith.constant 3.0 : f32
+ // CHECK: omp.reduction %{{.+}}, %{{.+}}
+ omp.reduction %1, %0 : !llvm.ptr<f32>
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
// CHECK: omp.reduction.declare
// CHECK-LABEL: @add2_f32
omp.reduction.declare @add2_f32 : f32
@@ -438,9 +498,10 @@ combiner {
}
// CHECK-NOT: atomic
-func @reduction2(%lb : index, %ub : index, %step : index) {
+// CHECK-LABEL: func @wsloop_reduction2
+func @wsloop_reduction2(%lb : index, %ub : index, %step : index) {
%0 = memref.alloca() : memref<1xf32>
- // CHECK: reduction
+ // CHECK: omp.wsloop reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
omp.wsloop reduction(@add2_f32 -> %0 : memref<1xf32>)
for (%iv) : index = (%lb) to (%ub) step (%step) {
%1 = arith.constant 2.0 : f32
@@ -451,6 +512,61 @@ func @reduction2(%lb : index, %ub : index, %step : index) {
return
}
+// CHECK-LABEL: func @parallel_reduction2
+func @parallel_reduction2() {
+ %0 = memref.alloca() : memref<1xf32>
+ // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
+ omp.parallel reduction(@add2_f32 -> %0 : memref<1xf32>) {
+ %1 = arith.constant 2.0 : f32
+ // CHECK: omp.reduction
+ omp.reduction %1, %0 : memref<1xf32>
+ omp.terminator
+ }
+ return
+}
+
+// CHECK: func @parallel_wsloop_reduction2
+func @parallel_wsloop_reduction2(%lb : index, %ub : index, %step : index) {
+ %c1 = arith.constant 1 : i32
+ %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+ // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : !llvm.ptr<f32>) {
+ omp.parallel reduction(@add2_f32 -> %0 : !llvm.ptr<f32>) {
+ // CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}})
+ omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+ %1 = arith.constant 2.0 : f32
+ // CHECK: omp.reduction %{{.+}}, %{{.+}} : !llvm.ptr<f32>
+ omp.reduction %1, %0 : !llvm.ptr<f32>
+ // CHECK: omp.yield
+ omp.yield
+ }
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func @sections_reduction2
+func @sections_reduction2() {
+ %0 = memref.alloca() : memref<1xf32>
+ // CHECK: omp.sections reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
+ omp.sections reduction(@add2_f32 -> %0 : memref<1xf32>) {
+ omp.section {
+ %1 = arith.constant 2.0 : f32
+ // CHECK: omp.reduction
+ omp.reduction %1, %0 : memref<1xf32>
+ omp.terminator
+ }
+ omp.section {
+ %1 = arith.constant 2.0 : f32
+ // CHECK: omp.reduction
+ omp.reduction %1, %0 : memref<1xf32>
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
// CHECK: omp.critical.declare @mutex1 hint(uncontended)
omp.critical.declare @mutex1 hint(uncontended)
// CHECK: omp.critical.declare @mutex2 hint(contended)
More information about the flang-commits
mailing list