[flang-commits] [flang] fcbf00f - [mlir][OpenMP] Added ReductionClauseInterface

Shraiysh Vaishay via flang-commits flang-commits at lists.llvm.org
Mon Mar 28 01:54:39 PDT 2022


Author: Shraiysh Vaishay
Date: 2022-03-28T14:24:28+05:30
New Revision: fcbf00f098b234c205b1ee22b982e7e575d75f14

URL: https://github.com/llvm/llvm-project/commit/fcbf00f098b234c205b1ee22b982e7e575d75f14
DIFF: https://github.com/llvm/llvm-project/commit/fcbf00f098b234c205b1ee22b982e7e575d75f14.diff

LOG: [mlir][OpenMP] Added ReductionClauseInterface

This patch adds the ReductionClauseInterface and also adds reduction
support for `omp.parallel` operation.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D122402

Added: 
    

Modified: 
    flang/lib/Lower/OpenMP.cpp
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index bdecfeee968bd..b3b8fc3211956 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -205,7 +205,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
     // Create and insert the operation.
     auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
         currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
-        ValueRange(), ValueRange(),
+        /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
+        /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
         procBindClauseOperand.dyn_cast_or_null<omp::ClauseProcBindKindAttr>());
     // Handle attribute based clauses.
     for (const auto &clause : parallelOpClauseList.v) {

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index efc471833d1c7..0bc267c2124ca 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -66,7 +66,7 @@ def OpenMP_PointerLikeType : Type<
 def ParallelOp : OpenMP_Op<"parallel", [
                  AutomaticAllocationScope, AttrSizedOperandSegments,
                  DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>,
-                 RecursiveSideEffects]> {
+                 RecursiveSideEffects, ReductionClauseInterface]> {
   let summary = "parallel construct";
   let description = [{
     The parallel construct includes a region of code which is to be executed
@@ -83,6 +83,18 @@ def ParallelOp : OpenMP_Op<"parallel", [
     The $allocators_vars and $allocate_vars parameters are a variadic list of values
     that specify the memory allocator to be used to obtain storage for private values.
 
+    Reductions can be performed in a parallel construct by specifying reduction
+    accumulator variables in `reduction_vars` and symbols referring to reduction
+    declarations in the `reductions` attribute. Each reduction is identified
+    by the accumulator it uses, and accumulators must not be repeated within the
+    same reduction clause. The `omp.reduction` operation accepts the accumulator
+    and a partial value, which is treated as the value produced by the thread
+    for that reduction. If multiple values are produced for the same
+    accumulator, i.e. there are multiple `omp.reduction`s, the last value is
+    taken. The reduction declaration specifies how to combine the values from
+    each thread into the final value, which is available in the accumulator
+    after all the threads complete.
+
     The optional $proc_bind_val attribute controls the thread affinity for the execution
     of the parallel region.
   }];
@@ -91,6 +103,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
              Optional<AnyType>:$num_threads_var,
              Variadic<AnyType>:$allocate_vars,
              Variadic<AnyType>:$allocators_vars,
+             Variadic<OpenMP_PointerLikeType>:$reduction_vars,
+             OptionalAttr<SymbolRefArrayAttr>:$reductions,
              OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
 
   let regions = (region AnyRegion:$region);
@@ -99,7 +113,11 @@ def ParallelOp : OpenMP_Op<"parallel", [
     OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
   ];
   let assemblyFormat = [{
-    oilist( `if` `(` $if_expr_var `:` type($if_expr_var) `)`
+    oilist( `reduction` `(`
+              custom<ReductionVarList>(
+                $reduction_vars, type($reduction_vars), $reductions
+              ) `)`
+          | `if` `(` $if_expr_var `:` type($if_expr_var) `)`
           | `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
           | `allocate` `(`
               custom<AllocateAndAllocator>(
@@ -110,6 +128,12 @@ def ParallelOp : OpenMP_Op<"parallel", [
     ) $region attr-dict
   }];
   let hasVerifier = 1;
+  let extraClassDeclaration = [{
+    // TODO: remove this once emitAccessorPrefix is set to
+    // kEmitAccessorPrefix_Prefixed for the dialect.
+    /// Returns the reduction variables
+    operand_range getReductionVars() { return reduction_vars(); }
+  }];
 }
 
 def TerminatorOp : OpenMP_Op<"terminator", [Terminator]> {
@@ -156,7 +180,8 @@ def SectionOp : OpenMP_Op<"section", [HasParent<"SectionsOp">]> {
   let assemblyFormat = "$region attr-dict";
 }
 
-def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> {
+def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
+                           ReductionClauseInterface]> {
   let summary = "sections construct";
   let description = [{
     The sections construct is a non-iterative worksharing construct that
@@ -207,6 +232,13 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> {
 
   let hasVerifier = 1;
   let hasRegionVerifier = 1;
+
+  let extraClassDeclaration = [{
+    // TODO: remove this once emitAccessorPrefix is set to
+    // kEmitAccessorPrefix_Prefixed for the dialect.
+    /// Returns the reduction variables
+    operand_range getReductionVars() { return reduction_vars(); }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -247,7 +279,7 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
 
 def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
                          AllTypesMatch<["lowerBound", "upperBound", "step"]>,
-                         RecursiveSideEffects]> {
+                         RecursiveSideEffects, ReductionClauseInterface]> {
   let summary = "workshare loop construct";
   let description = [{
     The workshare loop construct specifies that the iterations of the loop(s)
@@ -338,6 +370,11 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
 
     /// Returns the number of reduction variables.
     unsigned getNumReductionVars() { return reduction_vars().size(); }
+
+    // TODO: remove this once emitAccessorPrefix is set to
+    // kEmitAccessorPrefix_Prefixed for the dialect.
+    /// Returns the reduction variables
+    operand_range getReductionVars() { return reduction_vars(); }
   }];
   let hasCustomAssemblyFormat = 1;
   let assemblyFormat = [{

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index a99e496e11b07..83180e03455e2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -31,4 +31,18 @@ def OutlineableOpenMPOpInterface : OpInterface<"OutlineableOpenMPOpInterface"> {
   ];
 }
 
+def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> {
+  let description = [{
+    OpenMP operations supporting the reduction clause implement this interface.
+  }];
+
+  let cppNamespace = "::mlir::omp";
+
+  let methods = [
+    InterfaceMethod<
+      "Get reduction vars", "::mlir::Operation::operand_range",
+      "getReductionVars">,  // Ops define this in extraClassDeclaration.
+  ];
+}
+
 #endif // OpenMP_OPS_INTERFACES

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c43dea42aebc0..176c27b132164 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -27,6 +27,7 @@
 
 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
+#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
 
 using namespace mlir;
@@ -58,19 +59,6 @@ void OpenMPDialect::initialize() {
   MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
 }
 
-//===----------------------------------------------------------------------===//
-// ParallelOp
-//===----------------------------------------------------------------------===//
-
-void ParallelOp::build(OpBuilder &builder, OperationState &state,
-                       ArrayRef<NamedAttribute> attributes) {
-  ParallelOp::build(
-      builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
-      /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
-      /*proc_bind_val=*/nullptr);
-  state.addAttributes(attributes);
-}
-
 //===----------------------------------------------------------------------===//
 // Parser and printer for Allocate Clause
 //===----------------------------------------------------------------------===//
@@ -142,13 +130,6 @@ void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
   p << stringifyEnum(attr.getValue());
 }
 
-LogicalResult ParallelOp::verify() {
-  if (allocate_vars().size() != allocators_vars().size())
-    return emitError(
-        "expected equal sizes for allocate and allocator variables");
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Parser and printer for Linear Clause
 //===----------------------------------------------------------------------===//
@@ -469,6 +450,27 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+void ParallelOp::build(OpBuilder &builder, OperationState &state,
+                       ArrayRef<NamedAttribute> attributes) {
+  ParallelOp::build(  // Start from an op with every clause left empty.
+      builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
+      /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
+      /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
+      /*proc_bind_val=*/nullptr);
+  state.addAttributes(attributes);  // Then append caller attributes.
+}
+
+LogicalResult ParallelOp::verify() {  // Reduction checks delegated below.
+  if (allocate_vars().size() != allocators_vars().size())  // Must pair 1:1.
+    return emitError(
+        "expected equal sizes for allocate and allocator variables");
+  return verifyReductionVarList(*this, reductions(), reduction_vars());
+}
+
 //===----------------------------------------------------------------------===//
 // Verifier for SectionsOp
 //===----------------------------------------------------------------------===//
@@ -709,13 +711,17 @@ LogicalResult ReductionDeclareOp::verifyRegions() {
 }
 
 LogicalResult ReductionOp::verify() {
-  // TODO: generalize this to an op interface when there is more than one op
-  // that supports reductions.
-  auto container = (*this)->getParentOfType<WsLoopOp>();
-  for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
-    if (container.reduction_vars()[i] == accumulator())
-      return success();
-
+  auto *op = (*this)->getParentWithTrait<ReductionClauseInterface::Trait>();
+  if (!op)  // No enclosing op implements ReductionClauseInterface.
+    return emitOpError() << "must be used within an operation supporting "
+                            "reduction clause interface";
+  while (op) {  // Walk enclosing reduction ops, innermost first.
+    for (const auto &var :
+         cast<ReductionClauseInterface>(op).getReductionVars())
+      if (var == accumulator())  // Accumulator is declared by this op.
+        return success();
+    op = op->getParentWithTrait<ReductionClauseInterface::Trait>();
+  }
   return emitOpError() << "the accumulator is not used by the parent";
 }
 

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 60ef9667b36ea..9149feb7a92eb 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -59,7 +59,7 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
   // CHECK: omp.parallel num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
     "omp.parallel"(%num_threads, %data_var, %data_var) ({
       omp.terminator
-    }) {operand_segment_sizes = dense<[0,1,1,1]>: vector<4xi32>} : (si32, memref<i32>, memref<i32>) -> ()
+    }) {operand_segment_sizes = dense<[0,1,1,1,0]> : vector<5xi32>} : (si32, memref<i32>, memref<i32>) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
@@ -68,22 +68,22 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
   // CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
     "omp.parallel"(%if_cond, %data_var, %data_var) ({
       omp.terminator
-    }) {operand_segment_sizes = dense<[1,0,1,1]> : vector<4xi32>} : (i1, memref<i32>, memref<i32>) -> ()
+    }) {operand_segment_sizes = dense<[1,0,1,1,0]> : vector<5xi32>} : (i1, memref<i32>, memref<i32>) -> ()
 
   // test without allocate
   // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32)
     "omp.parallel"(%if_cond, %num_threads) ({
       omp.terminator
-    }) {operand_segment_sizes = dense<[1,1,0,0]> : vector<4xi32>} : (i1, si32) -> ()
+    }) {operand_segment_sizes = dense<[1,1,0,0,0]> : vector<5xi32>} : (i1, si32) -> ()
 
     omp.terminator
-  }) {operand_segment_sizes = dense<[1,1,1,1]> : vector<4xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
+  }) {operand_segment_sizes = dense<[1,1,1,1,0]> : vector<5xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
 
   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
   "omp.parallel" (%data_var, %data_var) ({
     omp.terminator
-  }) {operand_segment_sizes = dense<[0,0,1,1]> : vector<4xi32>} : (memref<i32>, memref<i32>) -> ()
+  }) {operand_segment_sizes = dense<[0,0,1,1,0]> : vector<5xi32>} : (memref<i32>, memref<i32>) -> ()
 
   return
 }
@@ -407,7 +407,8 @@ atomic {
   omp.yield
 }
 
-func @reduction(%lb : index, %ub : index, %step : index) {
+// CHECK-LABEL: func @wsloop_reduction
+func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
   %c1 = arith.constant 1 : i32
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
   // CHECK: reduction(@add_f32 -> %{{.+}} : !llvm.ptr<f32>)
@@ -421,6 +422,65 @@ func @reduction(%lb : index, %ub : index, %step : index) {
   return
 }
 
+// CHECK-LABEL: func @parallel_reduction
+func @parallel_reduction() {
+  %c1 = arith.constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+  // CHECK: omp.parallel reduction(@add_f32 -> %{{.+}} : !llvm.ptr<f32>)
+  omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
+    %1 = arith.constant 2.0 : f32
+    // CHECK: omp.reduction %{{.+}}, %{{.+}} : !llvm.ptr<f32>
+    omp.reduction %1, %0 : !llvm.ptr<f32>
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @parallel_wsloop_reduction
+func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) {
+  %c1 = arith.constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+  // CHECK: omp.parallel reduction(@add_f32 -> %{{.+}} : !llvm.ptr<f32>) {
+  omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
+    // CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}})
+    omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+      %1 = arith.constant 2.0 : f32
+      // CHECK: omp.reduction %{{.+}}, %{{.+}} : !llvm.ptr<f32>
+      omp.reduction %1, %0 : !llvm.ptr<f32>
+      // CHECK: omp.yield
+      omp.yield
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @sections_reduction
+func @sections_reduction() {
+  %c1 = arith.constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+  // CHECK: omp.sections reduction(@add_f32 -> %{{.+}} : !llvm.ptr<f32>)
+  omp.sections reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
+    // CHECK: omp.section
+    omp.section {
+      %1 = arith.constant 2.0 : f32
+      // CHECK: omp.reduction %{{.+}}, %{{.+}} : !llvm.ptr<f32>
+      omp.reduction %1, %0 : !llvm.ptr<f32>
+      omp.terminator
+    }
+    // CHECK: omp.section
+    omp.section {
+      %1 = arith.constant 3.0 : f32
+      // CHECK: omp.reduction %{{.+}}, %{{.+}} : !llvm.ptr<f32>
+      omp.reduction %1, %0 : !llvm.ptr<f32>
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
 // CHECK: omp.reduction.declare
 // CHECK-LABEL: @add2_f32
 omp.reduction.declare @add2_f32 : f32
@@ -438,9 +498,10 @@ combiner {
 }
 // CHECK-NOT: atomic
 
-func @reduction2(%lb : index, %ub : index, %step : index) {
+// CHECK-LABEL: func @wsloop_reduction2
+func @wsloop_reduction2(%lb : index, %ub : index, %step : index) {
   %0 = memref.alloca() : memref<1xf32>
-  // CHECK: reduction
+  // CHECK: omp.wsloop reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
   omp.wsloop reduction(@add2_f32 -> %0 : memref<1xf32>)
   for (%iv) : index = (%lb) to (%ub) step (%step) {
     %1 = arith.constant 2.0 : f32
@@ -451,6 +512,61 @@ func @reduction2(%lb : index, %ub : index, %step : index) {
   return
 }
 
+// CHECK-LABEL: func @parallel_reduction2
+func @parallel_reduction2() {
+  %0 = memref.alloca() : memref<1xf32>
+  // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
+  omp.parallel reduction(@add2_f32 -> %0 : memref<1xf32>) {
+    %1 = arith.constant 2.0 : f32
+    // CHECK: omp.reduction %{{.+}}, %{{.+}} : memref<1xf32>
+    omp.reduction %1, %0 : memref<1xf32>
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @parallel_wsloop_reduction2
+func @parallel_wsloop_reduction2(%lb : index, %ub : index, %step : index) {
+  %c1 = arith.constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+  // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : !llvm.ptr<f32>) {
+  omp.parallel reduction(@add2_f32 -> %0 : !llvm.ptr<f32>) {
+    // CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}})
+    omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+      %1 = arith.constant 2.0 : f32
+      // CHECK: omp.reduction %{{.+}}, %{{.+}} : !llvm.ptr<f32>
+      omp.reduction %1, %0 : !llvm.ptr<f32>
+      // CHECK: omp.yield
+      omp.yield
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @sections_reduction2
+func @sections_reduction2() {
+  %0 = memref.alloca() : memref<1xf32>
+  // CHECK: omp.sections reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
+  omp.sections reduction(@add2_f32 -> %0 : memref<1xf32>) {
+    omp.section {
+      %1 = arith.constant 2.0 : f32
+      // CHECK: omp.reduction %{{.+}}, %{{.+}} : memref<1xf32>
+      omp.reduction %1, %0 : memref<1xf32>
+      omp.terminator
+    }
+    omp.section {
+      %1 = arith.constant 2.0 : f32
+      // CHECK: omp.reduction %{{.+}}, %{{.+}} : memref<1xf32>
+      omp.reduction %1, %0 : memref<1xf32>
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
 // CHECK: omp.critical.declare @mutex1 hint(uncontended)
 omp.critical.declare @mutex1 hint(uncontended)
 // CHECK: omp.critical.declare @mutex2 hint(contended)


        


More information about the flang-commits mailing list