[Mlir-commits] [mlir] c282d55 - [mlir] add support for reductions in OpenMP WsLoopOp

Alex Zinenko llvmlistbot at llvm.org
Fri Jul 9 08:54:27 PDT 2021


Author: Alex Zinenko
Date: 2021-07-09T17:54:20+02:00
New Revision: c282d55a38577e076b48cd7a8113e5eb0a2039cd

URL: https://github.com/llvm/llvm-project/commit/c282d55a38577e076b48cd7a8113e5eb0a2039cd
DIFF: https://github.com/llvm/llvm-project/commit/c282d55a38577e076b48cd7a8113e5eb0a2039cd.diff

LOG: [mlir] add support for reductions in OpenMP WsLoopOp

Use a modeling similar to SCF ParallelOp to support arbitrary parallel
reductions. The two main differences are: (1) reductions are named and declared
beforehand similarly to functions using a special op that provides the neutral
element, the reduction code and optionally the atomic reduction code; (2)
reductions go through memory instead because this is closer to the OpenMP
semantics.

See https://llvm.discourse.group/t/rfc-openmp-reduction-support/3367.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D105358

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
    mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/CMakeLists.txt
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir
    mlir/test/Target/LLVMIR/openmp-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
index 90614993cacf6..a0aa8ba8ab030 100644
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -9,6 +9,8 @@ mlir_tablegen(OpenMPOps.h.inc -gen-op-decls)
 mlir_tablegen(OpenMPOps.cpp.inc -gen-op-defs)
 mlir_tablegen(OpenMPOpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(OpenMPOpsEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(OpenMPTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(OpenMPTypeInterfaces.cpp.inc -gen-type-interface-defs)
 add_mlir_doc(OpenMPOps OpenMPDialect Dialects/ -gen-dialect-doc)
 add_public_tablegen_target(MLIROpenMPOpsIncGen)
 add_dependencies(OpenMPDialectDocGen omp_common_td)

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
index 8f79c4af1ad80..7e916f7fec7c0 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
@@ -13,14 +13,16 @@
 #ifndef MLIR_DIALECT_OPENMP_OPENMPDIALECT_H_
 #define MLIR_DIALECT_OPENMP_OPENMPDIALECT_H_
 
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc"
 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc"
+#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOps.h.inc"

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 8eaaf971f3fdb..6cab36ddc570a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -17,12 +17,14 @@
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
-include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Dialect/OpenMP/OmpCommon.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 
 def OpenMP_Dialect : Dialect {
   let name = "omp";
   let cppNamespace = "::mlir::omp";
+  let dependentDialects = ["::mlir::LLVM::LLVMDialect"];
 }
 
 class OpenMP_Op<string mnemonic, list<OpTrait> traits = []> :
@@ -31,6 +33,27 @@ class OpenMP_Op<string mnemonic, list<OpTrait> traits = []> :
 // Type which can be constraint accepting standard integers and indices.
 def IntLikeType : AnyTypeOf<[AnyInteger, Index]>;
 
+def OpenMP_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
+  let cppNamespace = "::mlir::omp";
+
+  let description = [{
+    An interface for pointer-like types suitable to contain a value that OpenMP
+    specification refers to as variable.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*description=*/"Returns the pointee type.",
+      /*retTy=*/"::mlir::Type",
+      /*methodName=*/"getElementType"
+    >,
+  ];
+}
+
+def OpenMP_PointerLikeType : Type<
+  CPred<"$_self.isa<::mlir::omp::PointerLikeType>()">,
+  "OpenMP-compatible variable type", "::mlir::omp::PointerLikeType">;
+
 //===----------------------------------------------------------------------===//
 // 2.6 parallel Construct
 //===----------------------------------------------------------------------===//
@@ -146,6 +169,18 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
     that the `linear_vars` and `linear_step_vars` variadic lists should contain
     the same number of elements.
 
+    Reductions can be performed in a workshare loop by specifying reduction
+    accumulator variables in `reduction_vars` and symbols referring to reduction
+    declarations in the `reductions` attribute. Each reduction is identified
+    by the accumulator it uses and accumulators must not be repeated in the same
+    reduction. The `omp.reduction` operation accepts the accumulator and a
+    partial value which is considered to be produced by the current loop
+    iteration for the given reduction. If multiple values are produced for the
+    same accumulator, i.e. there are multiple `omp.reduction`s, the last value
+    is taken. The reduction declaration specifies how to combine the values from
+    each iteration into the final value, which is available in the accumulator
+    after the loop completes.
+
     The optional `schedule_val` attribute specifies the loop schedule for this
     loop, determining how the loop is distributed across the parallel threads.
     The optional `schedule_chunk_var` associated with this determines further
@@ -173,6 +208,9 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
              Variadic<AnyType>:$lastprivate_vars,
              Variadic<AnyType>:$linear_vars,
              Variadic<AnyType>:$linear_step_vars,
+             Variadic<OpenMP_PointerLikeType>:$reduction_vars,
+             OptionalAttr<TypedArrayAttrBase<SymbolRefAttr,
+                            "array of symbol references">>:$reductions,
              OptionalAttr<ScheduleKind>:$schedule_val,
              Optional<AnyType>:$schedule_chunk_var,
              Confined<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$collapse_val,
@@ -191,11 +229,11 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
                "ValueRange":$upperBound, "ValueRange":$step,
                "ValueRange":$privateVars, "ValueRange":$firstprivateVars,
                "ValueRange":$lastprivate_vars, "ValueRange":$linear_vars,
-               "ValueRange":$linear_step_vars, "StringAttr":$schedule_val,
-               "Value":$schedule_chunk_var, "IntegerAttr":$collapse_val,
-               "UnitAttr":$nowait, "IntegerAttr":$ordered_val,
-               "StringAttr":$order_val, "UnitAttr":$inclusive, CArg<"bool",
-               "true">:$buildBody)>,
+               "ValueRange":$linear_step_vars, "ValueRange":$reduction_vars,
+               "StringAttr":$schedule_val, "Value":$schedule_chunk_var,
+               "IntegerAttr":$collapse_val, "UnitAttr":$nowait,
+               "IntegerAttr":$ordered_val, "StringAttr":$order_val,
+               "UnitAttr":$inclusive, CArg<"bool", "true">:$buildBody)>,
     OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
                CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
   ];
@@ -205,13 +243,18 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
   let extraClassDeclaration = [{
     /// Returns the number of loops in the workshape loop nest.
     unsigned getNumLoops() { return lowerBound().size(); }
+
+    /// Returns the number of reduction variables.
+    unsigned getNumReductionVars() { return reduction_vars().size(); }
   }];
   let parser = [{ return parseWsLoopOp(parser, result); }];
   let printer = [{ return printWsLoopOp(p, *this); }];
+  let verifier = [{ return ::verifyWsLoopOp(*this); }];
 }
 
-def YieldOp : OpenMP_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
-                                  HasParent<"WsLoopOp">]> {
+def YieldOp : OpenMP_Op<"yield",
+    [NoSideEffect, ReturnLike, Terminator,
+     ParentOneOf<["WsLoopOp", "ReductionDeclareOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "omp.yield" yields SSA values from the OpenMP dialect op region and
@@ -334,4 +377,78 @@ def TaskwaitOp : OpenMP_Op<"taskwait"> {
   let assemblyFormat = "attr-dict";
 }
 
+//===----------------------------------------------------------------------===//
+// 2.19.5.7 declare reduction Directive
+//===----------------------------------------------------------------------===//
+
+def ReductionDeclareOp : OpenMP_Op<"reduction.declare", [Symbol]> {
+  let summary = "declares a reduction kind";
+
+  let description = [{
+    Declares an OpenMP reduction kind. This requires two mandatory and one
+    optional region.
+
+      1. The initializer region specifies how to initialize the thread-local
+         reduction value. This is usually the neutral element of the reduction.
+         For convenience, the region has an argument that contains the value
+         of the reduction accumulator at the start of the reduction. It is
+         expected to `omp.yield` the new value on all control flow paths.
+      2. The reduction region specifies how to combine two values into one, i.e.
+         the reduction operator. It accepts the two values as arguments and is
+         expected to `omp.yield` the combined value on all control flow paths.
+      3. The atomic reduction region is optional and specifies how two values
+         can be combined atomically given local accumulator variables. It is
+         expected to store the combined value in the first accumulator variable.
+
+    Note that the MLIR type system does not allow for type-polymorphic
+    reductions. Separate reduction declarations should be created for different
+    element and accumulator types.
+  }];
+
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       TypeAttr:$type);
+
+  let regions = (region AnyRegion:$initializerRegion,
+                        AnyRegion:$reductionRegion,
+                        AnyRegion:$atomicReductionRegion);
+  let verifier = "return ::verifyReductionDeclareOp(*this);";
+
+  let assemblyFormat = "$sym_name `:` $type attr-dict-with-keyword "
+                       "`init` $initializerRegion "
+                       "`combiner` $reductionRegion "
+                       "custom<AtomicReductionRegion>($atomicReductionRegion)";
+
+  let extraClassDeclaration = [{
+    PointerLikeType getAccumulatorType() {
+      if (atomicReductionRegion().empty())
+        return {};
+
+      return atomicReductionRegion().front().getArgument(0).getType();
+    }
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// 2.19.5.4 reduction clause
+//===----------------------------------------------------------------------===//
+
+def ReductionOp : OpenMP_Op<"reduction", [
+    TypesMatchWith<"value types matches accumulator element type",
+                   "accumulator", "operand",
+                 "$_self.cast<::mlir::omp::PointerLikeType>().getElementType()">
+  ]> {
+  let summary = "reduction construct";
+  let description = [{
+    Indicates the value that is produced by the current reduction-participating
+    entity for a reduction requested in some ancestor. The reduction is
+    identified by the accumulator, but the value of the accumulator may not be
+    updated immediately.
+  }];
+
+  let arguments= (ins AnyType:$operand, OpenMP_PointerLikeType:$accumulator);
+  let assemblyFormat =
+    "$operand `,` $accumulator attr-dict `:` type($accumulator)";
+  let verifier = "return ::verifyReductionOp(*this);";
+}
+
 #endif // OPENMP_OPS

diff  --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 23f8cd4058f7b..6fd75c611542f 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -9,4 +9,5 @@ add_mlir_dialect_library(MLIROpenMP
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRLLVMIR
   )

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 30a138e6a5a27..b5abdc7426ac5 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/OpImplementation.h"
@@ -24,15 +25,31 @@
 
 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
+#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
 
 using namespace mlir;
 using namespace mlir::omp;
 
+namespace {
+/// Model for pointer-like types that already provide a `getElementType` method.
+template <typename T>
+struct PointerLikeModel
+    : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
+  Type getElementType(Type pointer) const {
+    return pointer.cast<T>().getElementType();
+  }
+};
+} // end namespace
+
 void OpenMPDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
       >();
+
+  LLVM::LLVMPointerType::attachInterface<
+      PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
+  MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
 }
 
 //===----------------------------------------------------------------------===//
@@ -439,6 +456,27 @@ parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
   return success();
 }
 
+/// reduction-init ::= `reduction` `(` reduction-entry-list `)`
+/// reduction-entry-list ::= reduction-entry
+///                        | reduction-entry-list `,` reduction-entry
+/// reduction-entry ::= symbol-ref `->` ssa-id `:` type
+static ParseResult
+parseReductionVarList(OpAsmParser &parser,
+                      SmallVectorImpl<SymbolRefAttr> &symbols,
+                      SmallVectorImpl<OpAsmParser::OperandType> &operands,
+                      SmallVectorImpl<Type> &types) {
+  if (failed(parser.parseLParen()))
+    return failure();
+
+  do {
+    if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
+        parser.parseOperand(operands.emplace_back()) ||
+        parser.parseColonType(types.emplace_back()))
+      return failure();
+  } while (succeeded(parser.parseOptionalComma()));
+  return parser.parseRParen();
+}
+
 /// Parses an OpenMP Workshare Loop operation
 ///
 /// operation ::= `omp.wsloop` loop-control clause-list
@@ -503,9 +541,11 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::OperandType> linears;
   SmallVector<Type> linearTypes;
   SmallVector<OpAsmParser::OperandType> linearSteps;
+  SmallVector<SymbolRefAttr> reductionSymbols;
+  SmallVector<OpAsmParser::OperandType> reductionVars;
+  SmallVector<Type> reductionVarTypes;
   SmallString<8> schedule;
   Optional<OpAsmParser::OperandType> scheduleChunkSize;
-  std::array<int, 9> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0};
 
   const StringRef opName = result.name.getStringRef();
   StringRef keyword;
@@ -519,8 +559,10 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
     lastprivateClausePos,
     linearClausePos,
     linearStepPos,
+    reductionVarPos,
     scheduleClausePos,
   };
+  std::array<int, 10> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0, 0};
 
   while (succeeded(parser.parseOptionalKeyword(&keyword))) {
     if (keyword == "private") {
@@ -592,6 +634,13 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
     } else if (keyword == "inclusive") {
       auto attr = UnitAttr::get(parser.getBuilder().getContext());
       result.addAttribute("inclusive", attr);
+    } else if (keyword == "reduction") {
+      if (segments[reductionVarPos])
+        return allowedOnce(parser, "reduction", opName);
+      if (failed(parseReductionVarList(parser, reductionSymbols, reductionVars,
+                                       reductionVarTypes)))
+        return failure();
+      segments[reductionVarPos] = reductionVars.size();
     }
   }
 
@@ -619,6 +668,17 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
                            linearSteps[0].location, result.operands);
   }
 
+  if (segments[reductionVarPos]) {
+    if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
+                                      parser.getNameLoc(), result.operands))) {
+      return failure();
+    }
+    SmallVector<Attribute> reductions(reductionSymbols.begin(),
+                                      reductionSymbols.end());
+    result.addAttribute("reductions",
+                        parser.getBuilder().getArrayAttr(reductions));
+  }
+
   if (!schedule.empty()) {
     schedule[0] = llvm::toUpper(schedule[0]);
     auto attr = parser.getBuilder().getStringAttr(schedule);
@@ -635,7 +695,8 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
   // Now parse the body.
   Region *body = result.addRegion();
   SmallVector<Type> ivTypes(numIVs, loopVarType);
-  if (parser.parseRegion(*body, ivs, ivTypes))
+  SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
+  if (parser.parseRegion(*body, blockArgs, ivTypes))
     return failure();
   return success();
 }
@@ -694,6 +755,17 @@ static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
     p << " ordered(" << ordered << ")";
   }
 
+  if (!op.reduction_vars().empty()) {
+    p << " reduction(";
+    for (unsigned i = 0, e = op.getNumReductionVars(); i < e; ++i) {
+      if (i != 0)
+        p << ", ";
+      p << (*op.reductions())[i] << " -> " << op.reduction_vars()[i] << " : "
+        << op.reduction_vars()[i].getType();
+    }
+    p << ")";
+  }
+
   if (op.inclusive()) {
     p << " inclusive";
   }
@@ -701,6 +773,86 @@ static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
 }
 
+//===----------------------------------------------------------------------===//
+// ReductionOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
+                                              Region &region) {
+  if (parser.parseOptionalKeyword("atomic"))
+    return success();
+  return parser.parseRegion(region);
+}
+
+static void printAtomicReductionRegion(OpAsmPrinter &printer,
+                                       ReductionDeclareOp op, Region &region) {
+  if (region.empty())
+    return;
+  printer << "atomic ";
+  printer.printRegion(region);
+}
+
+static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
+  if (op.initializerRegion().empty())
+    return op.emitOpError() << "expects non-empty initializer region";
+  Block &initializerEntryBlock = op.initializerRegion().front();
+  if (initializerEntryBlock.getNumArguments() != 1 ||
+      initializerEntryBlock.getArgument(0).getType() != op.type()) {
+    return op.emitOpError() << "expects initializer region with one argument "
+                               "of the reduction type";
+  }
+
+  for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
+    if (yieldOp.results().size() != 1 ||
+        yieldOp.results().getTypes()[0] != op.type())
+      return op.emitOpError() << "expects initializer region to yield a value "
+                                 "of the reduction type";
+  }
+
+  if (op.reductionRegion().empty())
+    return op.emitOpError() << "expects non-empty reduction region";
+  Block &reductionEntryBlock = op.reductionRegion().front();
+  if (reductionEntryBlock.getNumArguments() != 2 ||
+      reductionEntryBlock.getArgumentTypes()[0] !=
+          reductionEntryBlock.getArgumentTypes()[1] ||
+      reductionEntryBlock.getArgumentTypes()[0] != op.type())
+    return op.emitOpError() << "expects reduction region with two arguments of "
+                               "the reduction type";
+  for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
+    if (yieldOp.results().size() != 1 ||
+        yieldOp.results().getTypes()[0] != op.type())
+      return op.emitOpError() << "expects reduction region to yield a value "
+                                 "of the reduction type";
+  }
+
+  if (op.atomicReductionRegion().empty())
+    return success();
+
+  Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
+  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
+      atomicReductionEntryBlock.getArgumentTypes()[0] !=
+          atomicReductionEntryBlock.getArgumentTypes()[1])
+    return op.emitOpError() << "expects atomic reduction region with two "
+                               "arguments of the same type";
+  auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
+                     .dyn_cast<PointerLikeType>();
+  if (!ptrType || ptrType.getElementType() != op.type())
+    return op.emitOpError() << "expects atomic reduction region arguments to "
+                               "be accumulators containing the reduction type";
+  return success();
+}
+
+static LogicalResult verifyReductionOp(ReductionOp op) {
+  // TODO: generalize this to an op interface when there is more than one op
+  // that supports reductions.
+  auto container = op->getParentOfType<WsLoopOp>();
+  for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
+    if (container.reduction_vars()[i] == op.accumulator())
+      return success();
+
+  return op.emitOpError() << "the accumulator is not used by the parent";
+}
+
 //===----------------------------------------------------------------------===//
 // WsLoopOp
 //===----------------------------------------------------------------------===//
@@ -712,8 +864,8 @@ void WsLoopOp::build(OpBuilder &builder, OperationState &state,
         /*private_vars=*/ValueRange(),
         /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
-        /*schedule_val=*/nullptr, /*schedule_chunk_var=*/nullptr,
-        /*collapse_val=*/nullptr,
+        /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
+        /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
         /*inclusive=*/nullptr, /*buildBody=*/false);
   state.addAttributes(attributes);
@@ -724,7 +876,7 @@ void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
   state.addOperands(operands);
   state.addAttributes(attributes);
   (void)state.addRegion();
-  assert(resultTypes.size() == 0u && "mismatched number of return types");
+  assert(resultTypes.empty() && "mismatched number of return types");
   state.addTypes(resultTypes);
 }
 
@@ -733,10 +885,11 @@ void WsLoopOp::build(OpBuilder &builder, OperationState &result,
                      ValueRange upperBounds, ValueRange steps,
                      ValueRange privateVars, ValueRange firstprivateVars,
                      ValueRange lastprivateVars, ValueRange linearVars,
-                     ValueRange linearStepVars, StringAttr scheduleVal,
-                     Value scheduleChunkVar, IntegerAttr collapseVal,
-                     UnitAttr nowait, IntegerAttr orderedVal,
-                     StringAttr orderVal, UnitAttr inclusive, bool buildBody) {
+                     ValueRange linearStepVars, ValueRange reductionVars,
+                     StringAttr scheduleVal, Value scheduleChunkVar,
+                     IntegerAttr collapseVal, UnitAttr nowait,
+                     IntegerAttr orderedVal, StringAttr orderVal,
+                     UnitAttr inclusive, bool buildBody) {
   result.addOperands(lowerBounds);
   result.addOperands(upperBounds);
   result.addOperands(steps);
@@ -770,6 +923,7 @@ void WsLoopOp::build(OpBuilder &builder, OperationState &result,
            static_cast<int32_t>(lastprivateVars.size()),
            static_cast<int32_t>(linearVars.size()),
            static_cast<int32_t>(linearStepVars.size()),
+           static_cast<int32_t>(reductionVars.size()),
            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
 
   Region *bodyRegion = result.addRegion();
@@ -781,5 +935,44 @@ void WsLoopOp::build(OpBuilder &builder, OperationState &result,
   }
 }
 
+static LogicalResult verifyWsLoopOp(WsLoopOp op) {
+  if (op.getNumReductionVars() != 0) {
+    if (!op.reductions() ||
+        op.reductions()->size() != op.getNumReductionVars()) {
+      return op.emitOpError() << "expected as many reduction symbol references "
+                                 "as reduction variables";
+    }
+  } else {
+    if (op.reductions())
+      return op.emitOpError() << "unexpected reduction symbol references";
+    return success();
+  }
+
+  DenseSet<Value> accumulators;
+  for (auto args : llvm::zip(op.reduction_vars(), *op.reductions())) {
+    Value accum = std::get<0>(args);
+    if (!accumulators.insert(accum).second) {
+      return op.emitOpError() << "accumulator variable used more than once";
+    }
+    Type varType = accum.getType().cast<PointerLikeType>();
+    auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
+    auto decl =
+        SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
+    if (!decl) {
+      return op.emitOpError() << "expected symbol reference " << symbolRef
+                              << " to point to a reduction declaration";
+    }
+
+    if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) {
+      return op.emitOpError()
+             << "expected accumulator (" << varType
+             << ") to be the same type as reduction declaration ("
+             << decl.getAccumulatorType() << ")";
+    }
+  }
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"

diff  --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index e0bb0134a14af..e9a46cb6dff10 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -40,7 +40,7 @@ func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: inde
       // CHECK: "test.payload"(%[[ARG6]], %[[ARG7]]) : (i64, i64) -> ()
       "test.payload"(%arg6, %arg7) : (index, index) -> ()
       omp.yield
-    }) {operand_segment_sizes = dense<[2, 2, 2, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (index, index, index, index, index, index) -> ()
+    }) {operand_segment_sizes = dense<[2, 2, 2, 0, 0, 0, 0, 0, 0, 0]> : vector<10xi32>} : (index, index, index, index, index, index) -> ()
     omp.terminator
   }
   return

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 88f61e7f79169..4c85025c65a9d 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -87,3 +87,209 @@ func @proc_bind_once() {
 
   return
 }
+
+// -----
+
+// expected-error @below {{op expects initializer region with one argument of the reduction type}}
+omp.reduction.declare @add_f32 : f64
+init {
+^bb0(%arg: f32):
+  %0 = constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+
+// -----
+
+// expected-error @below {{expects initializer region to yield a value of the reduction type}}
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = constant 0.0 : f64
+  omp.yield (%0 : f64)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+
+// -----
+
+// expected-error @below {{expects reduction region with two arguments of the reduction type}}
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f64, %arg1: f64):
+  %1 = addf %arg0, %arg1 : f64
+  omp.yield (%1 : f64)
+}
+
+// -----
+
+// expected-error @below {{expects reduction region to yield a value of the reduction type}}
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = addf %arg0, %arg1 : f32
+  %2 = fpext %1 : f32 to f64
+  omp.yield (%2 : f64)
+}
+
+// -----
+
+// expected-error @below {{expects atomic reduction region with two arguments of the same type}}
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg0: memref<f32>, %arg1: memref<f64>):
+  omp.yield
+}
+
+// -----
+
+// expected-error @below {{expects atomic reduction region arguments to be accumulators containing the reduction type}}
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg0: memref<f64>, %arg1: memref<f64>):
+  omp.yield
+}
+
+// -----
+
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+
+func @foo(%lb : index, %ub : index, %step : index) {
+  %c1 = constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+  %1 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step)
+  reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
+    %2 = constant 2.0 : f32
+    // expected-error @below {{accumulator is not used by the parent}}
+    omp.reduction %2, %1 : !llvm.ptr<f32>
+    omp.yield
+  }
+  return
+}
+
+// -----
+
+func @foo(%lb : index, %ub : index, %step : index) {
+  %c1 = constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+  %1 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+
+  // expected-error @below {{expected symbol reference @foo to point to a reduction declaration}}
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step)
+  reduction(@foo -> %0 : !llvm.ptr<f32>) {
+    %2 = constant 2.0 : f32
+    omp.reduction %2, %1 : !llvm.ptr<f32>
+    omp.yield
+  }
+  return
+}
+
+// -----
+
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+
+func @foo(%lb : index, %ub : index, %step : index) {
+  %c1 = constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+
+  // expected-error @below {{accumulator variable used more than once}}
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step)
+  reduction(@add_f32 -> %0 : !llvm.ptr<f32>, @add_f32 -> %0 : !llvm.ptr<f32>) {
+    %2 = constant 2.0 : f32
+    omp.reduction %2, %0 : !llvm.ptr<f32>
+    omp.yield
+  }
+  return
+}
+
+// -----
+
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg2: !llvm.ptr<f32>, %arg3: !llvm.ptr<f32>):
+  %2 = llvm.load %arg3 : !llvm.ptr<f32>
+  llvm.atomicrmw fadd %arg2, %2 monotonic : f32
+  omp.yield
+}
+
+func @foo(%lb : index, %ub : index, %step : index, %mem : memref<1xf32>) {
+  %c1 = constant 1 : i32
+
+  // expected-error @below {{expected accumulator ('memref<1xf32>') to be the same type as reduction declaration ('!llvm.ptr<f32>')}}
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step)
+  reduction(@add_f32 -> %mem : memref<1xf32>) {
+    %2 = constant 2.0 : f32
+    omp.reduction %2, %mem : memref<1xf32>
+    omp.yield
+  }
+  return
+}

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index f23ddf9df0f73..35ac6b30593b2 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -133,35 +133,35 @@ func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memref<i32
   "omp.wsloop" (%lb, %ub, %step, %data_var, %data_var) ({
     ^bb0(%iv: index):
       omp.yield
-  }) {operand_segment_sizes = dense<[1,1,1,2,0,0,0,0,0]> : vector<9xi32>, collapse_val = 2, ordered_val = 1} :
+  }) {operand_segment_sizes = dense<[1,1,1,2,0,0,0,0,0,0]> : vector<10xi32>, collapse_val = 2, ordered_val = 1} :
     (index, index, index, memref<i32>, memref<i32>) -> ()
 
   // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) linear(%{{.*}} = %{{.*}} : memref<i32>) schedule(static)
   "omp.wsloop" (%lb, %ub, %step, %data_var, %linear_var) ({
     ^bb0(%iv: index):
       omp.yield
-  }) {operand_segment_sizes = dense<[1,1,1,0,0,0,1,1,0]> : vector<9xi32>, schedule_val = "Static"} :
+  }) {operand_segment_sizes = dense<[1,1,1,0,0,0,1,1,0,0]> : vector<10xi32>, schedule_val = "Static"} :
     (index, index, index, memref<i32>, i32) -> ()
 
   // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) linear(%{{.*}} = %{{.*}} : memref<i32>, %{{.*}} = %{{.*}} : memref<i32>) schedule(static)
   "omp.wsloop" (%lb, %ub, %step, %data_var, %data_var, %linear_var, %linear_var) ({
     ^bb0(%iv: index):
       omp.yield
-  }) {operand_segment_sizes = dense<[1,1,1,0,0,0,2,2,0]> : vector<9xi32>, schedule_val = "Static"} :
+  }) {operand_segment_sizes = dense<[1,1,1,0,0,0,2,2,0,0]> : vector<10xi32>, schedule_val = "Static"} :
     (index, index, index, memref<i32>, memref<i32>, i32, i32) -> ()
 
   // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) lastprivate(%{{.*}} : memref<i32>) linear(%{{.*}} = %{{.*}} : memref<i32>) schedule(dynamic = %{{.*}}) collapse(3) ordered(2)
   "omp.wsloop" (%lb, %ub, %step, %data_var, %data_var, %data_var, %data_var, %linear_var, %chunk_var) ({
     ^bb0(%iv: index):
       omp.yield
-  }) {operand_segment_sizes = dense<[1,1,1,1,1,1,1,1,1]> : vector<9xi32>, schedule_val = "Dynamic", collapse_val = 3, ordered_val = 2} :
+  }) {operand_segment_sizes = dense<[1,1,1,1,1,1,1,1,0,1]> : vector<10xi32>, schedule_val = "Dynamic", collapse_val = 3, ordered_val = 2} :
     (index, index, index, memref<i32>, memref<i32>, memref<i32>, memref<i32>, i32, i32) -> ()
 
   // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) private(%{{.*}} : memref<i32>) schedule(auto) nowait
   "omp.wsloop" (%lb, %ub, %step, %data_var) ({
     ^bb0(%iv: index):
       omp.yield
-  }) {operand_segment_sizes = dense<[1,1,1,1,0,0,0,0,0]> : vector<9xi32>, nowait, schedule_val = "Auto"} :
+  }) {operand_segment_sizes = dense<[1,1,1,1,0,0,0,0,0,0]> : vector<10xi32>, nowait, schedule_val = "Auto"} :
     (index, index, index, memref<i32>) -> ()
 
   return
@@ -294,3 +294,78 @@ func @omp_target(%if_cond : i1, %device : si32,  %num_threads : si32) -> () {
 
     return
 }
+
+// CHECK: omp.reduction.declare
+// CHECK-LABEL: @add_f32
+// CHECK: : f32
+// CHECK: init
+// CHECK: ^{{.+}}(%{{.+}}: f32):
+// CHECK:   omp.yield
+// CHECK: combiner
+// CHECK: ^{{.+}}(%{{.+}}: f32, %{{.+}}: f32):
+// CHECK:   omp.yield
+// CHECK: atomic
+// CHECK: ^{{.+}}(%{{.+}}: !llvm.ptr<f32>, %{{.+}}: !llvm.ptr<f32>):
+// CHECK:  omp.yield
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg2: !llvm.ptr<f32>, %arg3: !llvm.ptr<f32>):
+  %2 = llvm.load %arg3 : !llvm.ptr<f32>
+  llvm.atomicrmw fadd %arg2, %2 monotonic : f32
+  omp.yield
+}
+
+func @reduction(%lb : index, %ub : index, %step : index) {
+  %c1 = constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+  // CHECK: reduction(@add_f32 -> %{{.+}} : !llvm.ptr<f32>)
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step)
+  reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
+    %1 = constant 2.0 : f32
+    // CHECK: omp.reduction %{{.+}}, %{{.+}}
+    omp.reduction %1, %0 : !llvm.ptr<f32>
+    omp.yield
+  }
+  return
+}
+
+// CHECK: omp.reduction.declare
+// CHECK-LABEL: @add2_f32
+omp.reduction.declare @add2_f32 : f32
+// CHECK: init
+init {
+^bb0(%arg: f32):
+  %0 = constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+// CHECK: combiner
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+// CHECK-NOT: atomic
+
+func @reduction2(%lb : index, %ub : index, %step : index) {
+  %0 = memref.alloca() : memref<1xf32>
+  // CHECK: reduction
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step)
+  reduction(@add2_f32 -> %0 : memref<1xf32>) {
+    %1 = constant 2.0 : f32
+    // CHECK: omp.reduction
+    omp.reduction %1, %0 : memref<1xf32>
+    omp.yield
+  }
+  return
+}
+

diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index e9d472d2e602e..e53ec47370eb8 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -379,7 +379,7 @@ llvm.func @wsloop_simple(%arg0: !llvm.ptr<f32>) {
       llvm.store %3, %4 : !llvm.ptr<f32>
       omp.yield
       // CHECK: call void @__kmpc_for_static_fini(%struct.ident_t* @[[$wsloop_loc_struct]],
-    }) {operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (i64, i64, i64) -> ()
+    }) {operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0, 0]> : vector<10xi32>} : (i64, i64, i64) -> ()
     omp.terminator
   }
   llvm.return
@@ -397,7 +397,7 @@ llvm.func @wsloop_inclusive_1(%arg0: !llvm.ptr<f32>) {
     %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
     llvm.store %3, %4 : !llvm.ptr<f32>
     omp.yield
-  }) {operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (i64, i64, i64) -> ()
+  }) {operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0, 0]> : vector<10xi32>} : (i64, i64, i64) -> ()
   llvm.return
 }
 
@@ -413,7 +413,7 @@ llvm.func @wsloop_inclusive_2(%arg0: !llvm.ptr<f32>) {
     %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
     llvm.store %3, %4 : !llvm.ptr<f32>
     omp.yield
-  }) {inclusive, operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (i64, i64, i64) -> ()
+  }) {inclusive, operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0, 0]> : vector<10xi32>} : (i64, i64, i64) -> ()
   llvm.return
 }
 


        


More information about the Mlir-commits mailing list