[Mlir-commits] [mlir] c282d55 - [mlir] add support for reductions in OpenMP WsLoopOp
Alex Zinenko
llvmlistbot at llvm.org
Fri Jul 9 08:54:27 PDT 2021
Author: Alex Zinenko
Date: 2021-07-09T17:54:20+02:00
New Revision: c282d55a38577e076b48cd7a8113e5eb0a2039cd
URL: https://github.com/llvm/llvm-project/commit/c282d55a38577e076b48cd7a8113e5eb0a2039cd
DIFF: https://github.com/llvm/llvm-project/commit/c282d55a38577e076b48cd7a8113e5eb0a2039cd.diff
LOG: [mlir] add support for reductions in OpenMP WsLoopOp
Use a modeling similar to SCF ParallelOp to support arbitrary parallel
reductions. The two main differences are: (1) reductions are named and declared
beforehand similarly to functions using a special op that provides the neutral
element, the reduction code and optionally the atomic reduction code; (2)
reductions go through memory instead because this is closer to the OpenMP
semantics.
See https://llvm.discourse.group/t/rfc-openmp-reduction-support/3367.
Reviewed By: kiranchandramohan
Differential Revision: https://reviews.llvm.org/D105358
Added:
Modified:
mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/lib/Dialect/OpenMP/CMakeLists.txt
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
mlir/test/Dialect/OpenMP/invalid.mlir
mlir/test/Dialect/OpenMP/ops.mlir
mlir/test/Target/LLVMIR/openmp-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
index 90614993cacf6..a0aa8ba8ab030 100644
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -9,6 +9,8 @@ mlir_tablegen(OpenMPOps.h.inc -gen-op-decls)
mlir_tablegen(OpenMPOps.cpp.inc -gen-op-defs)
mlir_tablegen(OpenMPOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpenMPOpsEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(OpenMPTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(OpenMPTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_mlir_doc(OpenMPOps OpenMPDialect Dialects/ -gen-dialect-doc)
add_public_tablegen_target(MLIROpenMPOpsIncGen)
add_dependencies(OpenMPDialectDocGen omp_common_td)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
index 8f79c4af1ad80..7e916f7fec7c0 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
@@ -13,14 +13,16 @@
#ifndef MLIR_DIALECT_OPENMP_OPENMPDIALECT_H_
#define MLIR_DIALECT_OPENMP_OPENMPDIALECT_H_
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc"
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc"
+#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOps.h.inc"
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 8eaaf971f3fdb..6cab36ddc570a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -17,12 +17,14 @@
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
-include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/IR/SymbolInterfaces.td"
include "mlir/Dialect/OpenMP/OmpCommon.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
def OpenMP_Dialect : Dialect {
let name = "omp";
let cppNamespace = "::mlir::omp";
+ let dependentDialects = ["::mlir::LLVM::LLVMDialect"];
}
class OpenMP_Op<string mnemonic, list<OpTrait> traits = []> :
@@ -31,6 +33,27 @@ class OpenMP_Op<string mnemonic, list<OpTrait> traits = []> :
// Type which can be constraint accepting standard integers and indices.
def IntLikeType : AnyTypeOf<[AnyInteger, Index]>;
+def OpenMP_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
+ let cppNamespace = "::mlir::omp";
+
+ let description = [{
+ An interface for pointer-like types suitable to contain a value that OpenMP
+ specification refers to as variable.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*description=*/"Returns the pointee type.",
+ /*retTy=*/"::mlir::Type",
+ /*methodName=*/"getElementType"
+ >,
+ ];
+}
+
+def OpenMP_PointerLikeType : Type<
+ CPred<"$_self.isa<::mlir::omp::PointerLikeType>()">,
+ "OpenMP-compatible variable type", "::mlir::omp::PointerLikeType">;
+
//===----------------------------------------------------------------------===//
// 2.6 parallel Construct
//===----------------------------------------------------------------------===//
@@ -146,6 +169,18 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
that the `linear_vars` and `linear_step_vars` variadic lists should contain
the same number of elements.
+ Reductions can be performed in a workshare loop by specifying reduction
+ accumulator variables in `reduction_vars` and symbols referring to reduction
+ declarations in the `reductions` attribute. Each reduction is identified
+ by the accumulator it uses and accumulators must not be repeated in the same
+ reduction. The `omp.reduction` operation accepts the accumulator and a
+ partial value which is considered to be produced by the current loop
+ iteration for the given reduction. If multiple values are produced for the
+ same accumulator, i.e. there are multiple `omp.reduction`s, the last value
+ is taken. The reduction declaration specifies how to combine the values from
+ each iteration into the final value, which is available in the accumulator
+ after the loop completes.
+
The optional `schedule_val` attribute specifies the loop schedule for this
loop, determining how the loop is distributed across the parallel threads.
The optional `schedule_chunk_var` associated with this determines further
@@ -173,6 +208,9 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
Variadic<AnyType>:$lastprivate_vars,
Variadic<AnyType>:$linear_vars,
Variadic<AnyType>:$linear_step_vars,
+ Variadic<OpenMP_PointerLikeType>:$reduction_vars,
+ OptionalAttr<TypedArrayAttrBase<SymbolRefAttr,
+ "array of symbol references">>:$reductions,
OptionalAttr<ScheduleKind>:$schedule_val,
Optional<AnyType>:$schedule_chunk_var,
Confined<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$collapse_val,
@@ -191,11 +229,11 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
"ValueRange":$upperBound, "ValueRange":$step,
"ValueRange":$privateVars, "ValueRange":$firstprivateVars,
"ValueRange":$lastprivate_vars, "ValueRange":$linear_vars,
- "ValueRange":$linear_step_vars, "StringAttr":$schedule_val,
- "Value":$schedule_chunk_var, "IntegerAttr":$collapse_val,
- "UnitAttr":$nowait, "IntegerAttr":$ordered_val,
- "StringAttr":$order_val, "UnitAttr":$inclusive, CArg<"bool",
- "true">:$buildBody)>,
+ "ValueRange":$linear_step_vars, "ValueRange":$reduction_vars,
+ "StringAttr":$schedule_val, "Value":$schedule_chunk_var,
+ "IntegerAttr":$collapse_val, "UnitAttr":$nowait,
+ "IntegerAttr":$ordered_val, "StringAttr":$order_val,
+ "UnitAttr":$inclusive, CArg<"bool", "true">:$buildBody)>,
OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
];
@@ -205,13 +243,18 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
let extraClassDeclaration = [{
/// Returns the number of loops in the workshape loop nest.
unsigned getNumLoops() { return lowerBound().size(); }
+
+ /// Returns the number of reduction variables.
+ unsigned getNumReductionVars() { return reduction_vars().size(); }
}];
let parser = [{ return parseWsLoopOp(parser, result); }];
let printer = [{ return printWsLoopOp(p, *this); }];
+ let verifier = [{ return ::verifyWsLoopOp(*this); }];
}
-def YieldOp : OpenMP_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
- HasParent<"WsLoopOp">]> {
+def YieldOp : OpenMP_Op<"yield",
+ [NoSideEffect, ReturnLike, Terminator,
+ ParentOneOf<["WsLoopOp", "ReductionDeclareOp"]>]> {
let summary = "loop yield and termination operation";
let description = [{
"omp.yield" yields SSA values from the OpenMP dialect op region and
@@ -334,4 +377,78 @@ def TaskwaitOp : OpenMP_Op<"taskwait"> {
let assemblyFormat = "attr-dict";
}
+//===----------------------------------------------------------------------===//
+// 2.19.5.7 declare reduction Directive
+//===----------------------------------------------------------------------===//
+
+def ReductionDeclareOp : OpenMP_Op<"reduction.declare", [Symbol]> {
+ let summary = "declares a reduction kind";
+
+ let description = [{
+ Declares an OpenMP reduction kind. This requires two mandatory and one
+ optional region.
+
+ 1. The initializer region specifies how to initialize the thread-local
+ reduction value. This is usually the neutral element of the reduction.
+ For convenience, the region has an argument that contains the value
+ of the reduction accumulator at the start of the reduction. It is
+ expected to `omp.yield` the new value on all control flow paths.
+ 2. The reduction region specifies how to combine two values into one, i.e.
+ the reduction operator. It accepts the two values as arguments and is
+ expected to `omp.yield` the combined value on all control flow paths.
+ 3. The atomic reduction region is optional and specifies how two values
+ can be combined atomically given local accumulator variables. It is
+ expected to store the combined value in the first accumulator variable.
+
+ Note that the MLIR type system does not allow for type-polymorphic
+ reductions. Separate reduction declarations should be created for
+ different element and accumulator types.
+ }];
+
+ let arguments = (ins SymbolNameAttr:$sym_name,
+ TypeAttr:$type);
+
+ let regions = (region AnyRegion:$initializerRegion,
+ AnyRegion:$reductionRegion,
+ AnyRegion:$atomicReductionRegion);
+ let verifier = "return ::verifyReductionDeclareOp(*this);";
+
+ let assemblyFormat = "$sym_name `:` $type attr-dict-with-keyword "
+ "`init` $initializerRegion "
+ "`combiner` $reductionRegion "
+ "custom<AtomicReductionRegion>($atomicReductionRegion)";
+
+ let extraClassDeclaration = [{
+ PointerLikeType getAccumulatorType() {
+ if (atomicReductionRegion().empty())
+ return {};
+
+ return atomicReductionRegion().front().getArgument(0).getType();
+ }
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// 2.19.5.4 reduction clause
+//===----------------------------------------------------------------------===//
+
+def ReductionOp : OpenMP_Op<"reduction", [
+ TypesMatchWith<"value types matches accumulator element type",
+ "accumulator", "operand",
+ "$_self.cast<::mlir::omp::PointerLikeType>().getElementType()">
+ ]> {
+ let summary = "reduction construct";
+ let description = [{
+ Indicates the value that is produced by the current reduction-participating
+ entity for a reduction requested in some ancestor. The reduction is
+ identified by the accumulator, but the value of the accumulator may not be
+ updated immediately.
+ }];
+
+ let arguments= (ins AnyType:$operand, OpenMP_PointerLikeType:$accumulator);
+ let assemblyFormat =
+ "$operand `,` $accumulator attr-dict `:` type($accumulator)";
+ let verifier = "return ::verifyReductionOp(*this);";
+}
+
#endif // OPENMP_OPS
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 23f8cd4058f7b..6fd75c611542f 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -9,4 +9,5 @@ add_mlir_dialect_library(MLIROpenMP
LINK_LIBS PUBLIC
MLIRIR
+ MLIRLLVMIR
)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 30a138e6a5a27..b5abdc7426ac5 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpImplementation.h"
@@ -24,15 +25,31 @@
#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
+#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
using namespace mlir;
using namespace mlir::omp;
+namespace {
+/// Model for pointer-like types that already provide a `getElementType` method.
+template <typename T>
+struct PointerLikeModel
+ : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
+ Type getElementType(Type pointer) const {
+ return pointer.cast<T>().getElementType();
+ }
+};
+} // end namespace
+
void OpenMPDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
>();
+
+ LLVM::LLVMPointerType::attachInterface<
+ PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
+ MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}
//===----------------------------------------------------------------------===//
@@ -439,6 +456,27 @@ parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
return success();
}
+/// reduction-init ::= `reduction` `(` reduction-entry-list `)`
+/// reduction-entry-list ::= reduction-entry
+/// | reduction-entry-list `,` reduction-entry
+/// reduction-entry ::= symbol-ref `->` ssa-id `:` type
+static ParseResult
+parseReductionVarList(OpAsmParser &parser,
+ SmallVectorImpl<SymbolRefAttr> &symbols,
+ SmallVectorImpl<OpAsmParser::OperandType> &operands,
+ SmallVectorImpl<Type> &types) {
+ if (failed(parser.parseLParen()))
+ return failure();
+
+ do {
+ if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
+ parser.parseOperand(operands.emplace_back()) ||
+ parser.parseColonType(types.emplace_back()))
+ return failure();
+ } while (succeeded(parser.parseOptionalComma()));
+ return parser.parseRParen();
+}
+
/// Parses an OpenMP Workshare Loop operation
///
/// operation ::= `omp.wsloop` loop-control clause-list
@@ -503,9 +541,11 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType> linears;
SmallVector<Type> linearTypes;
SmallVector<OpAsmParser::OperandType> linearSteps;
+ SmallVector<SymbolRefAttr> reductionSymbols;
+ SmallVector<OpAsmParser::OperandType> reductionVars;
+ SmallVector<Type> reductionVarTypes;
SmallString<8> schedule;
Optional<OpAsmParser::OperandType> scheduleChunkSize;
- std::array<int, 9> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0};
const StringRef opName = result.name.getStringRef();
StringRef keyword;
@@ -519,8 +559,10 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
lastprivateClausePos,
linearClausePos,
linearStepPos,
+ reductionVarPos,
scheduleClausePos,
};
+ std::array<int, 10> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0, 0};
while (succeeded(parser.parseOptionalKeyword(&keyword))) {
if (keyword == "private") {
@@ -592,6 +634,13 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
} else if (keyword == "inclusive") {
auto attr = UnitAttr::get(parser.getBuilder().getContext());
result.addAttribute("inclusive", attr);
+ } else if (keyword == "reduction") {
+ if (segments[reductionVarPos])
+ return allowedOnce(parser, "reduction", opName);
+ if (failed(parseReductionVarList(parser, reductionSymbols, reductionVars,
+ reductionVarTypes)))
+ return failure();
+ segments[reductionVarPos] = reductionVars.size();
}
}
@@ -619,6 +668,17 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
linearSteps[0].location, result.operands);
}
+ if (segments[reductionVarPos]) {
+ if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
+ parser.getNameLoc(), result.operands))) {
+ return failure();
+ }
+ SmallVector<Attribute> reductions(reductionSymbols.begin(),
+ reductionSymbols.end());
+ result.addAttribute("reductions",
+ parser.getBuilder().getArrayAttr(reductions));
+ }
+
if (!schedule.empty()) {
schedule[0] = llvm::toUpper(schedule[0]);
auto attr = parser.getBuilder().getStringAttr(schedule);
@@ -635,7 +695,8 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
// Now parse the body.
Region *body = result.addRegion();
SmallVector<Type> ivTypes(numIVs, loopVarType);
- if (parser.parseRegion(*body, ivs, ivTypes))
+ SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
+ if (parser.parseRegion(*body, blockArgs, ivTypes))
return failure();
return success();
}
@@ -694,6 +755,17 @@ static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
p << " ordered(" << ordered << ")";
}
+ if (!op.reduction_vars().empty()) {
+ p << " reduction(";
+ for (unsigned i = 0, e = op.getNumReductionVars(); i < e; ++i) {
+ if (i != 0)
+ p << ", ";
+ p << (*op.reductions())[i] << " -> " << op.reduction_vars()[i] << " : "
+ << op.reduction_vars()[i].getType();
+ }
+ p << ")";
+ }
+
if (op.inclusive()) {
p << " inclusive";
}
@@ -701,6 +773,86 @@ static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
}
+//===----------------------------------------------------------------------===//
+// ReductionOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
+ Region &region) {
+ if (parser.parseOptionalKeyword("atomic"))
+ return success();
+ return parser.parseRegion(region);
+}
+
+static void printAtomicReductionRegion(OpAsmPrinter &printer,
+ ReductionDeclareOp op, Region &region) {
+ if (region.empty())
+ return;
+ printer << "atomic ";
+ printer.printRegion(region);
+}
+
+static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
+ if (op.initializerRegion().empty())
+ return op.emitOpError() << "expects non-empty initializer region";
+ Block &initializerEntryBlock = op.initializerRegion().front();
+ if (initializerEntryBlock.getNumArguments() != 1 ||
+ initializerEntryBlock.getArgument(0).getType() != op.type()) {
+ return op.emitOpError() << "expects initializer region with one argument "
+ "of the reduction type";
+ }
+
+ for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
+ if (yieldOp.results().size() != 1 ||
+ yieldOp.results().getTypes()[0] != op.type())
+ return op.emitOpError() << "expects initializer region to yield a value "
+ "of the reduction type";
+ }
+
+ if (op.reductionRegion().empty())
+ return op.emitOpError() << "expects non-empty reduction region";
+ Block &reductionEntryBlock = op.reductionRegion().front();
+ if (reductionEntryBlock.getNumArguments() != 2 ||
+ reductionEntryBlock.getArgumentTypes()[0] !=
+ reductionEntryBlock.getArgumentTypes()[1] ||
+ reductionEntryBlock.getArgumentTypes()[0] != op.type())
+ return op.emitOpError() << "expects reduction region with two arguments of "
+ "the reduction type";
+ for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
+ if (yieldOp.results().size() != 1 ||
+ yieldOp.results().getTypes()[0] != op.type())
+ return op.emitOpError() << "expects reduction region to yield a value "
+ "of the reduction type";
+ }
+
+ if (op.atomicReductionRegion().empty())
+ return success();
+
+ Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
+ if (atomicReductionEntryBlock.getNumArguments() != 2 ||
+ atomicReductionEntryBlock.getArgumentTypes()[0] !=
+ atomicReductionEntryBlock.getArgumentTypes()[1])
+ return op.emitOpError() << "expects atomic reduction region with two "
+ "arguments of the same type";
+ auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
+ .dyn_cast<PointerLikeType>();
+ if (!ptrType || ptrType.getElementType() != op.type())
+ return op.emitOpError() << "expects atomic reduction region arguments to "
+ "be accumulators containing the reduction type";
+ return success();
+}
+
+static LogicalResult verifyReductionOp(ReductionOp op) {
+ // TODO: generalize this to an op interface when there is more than one op
+ // that supports reductions.
+ auto container = op->getParentOfType<WsLoopOp>();
+ for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
+ if (container.reduction_vars()[i] == op.accumulator())
+ return success();
+
+ return op.emitOpError() << "the accumulator is not used by the parent";
+}
+
//===----------------------------------------------------------------------===//
// WsLoopOp
//===----------------------------------------------------------------------===//
@@ -712,8 +864,8 @@ void WsLoopOp::build(OpBuilder &builder, OperationState &state,
/*private_vars=*/ValueRange(),
/*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
/*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
- /*schedule_val=*/nullptr, /*schedule_chunk_var=*/nullptr,
- /*collapse_val=*/nullptr,
+ /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
+ /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
/*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
/*inclusive=*/nullptr, /*buildBody=*/false);
state.addAttributes(attributes);
@@ -724,7 +876,7 @@ void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
state.addOperands(operands);
state.addAttributes(attributes);
(void)state.addRegion();
- assert(resultTypes.size() == 0u && "mismatched number of return types");
+ assert(resultTypes.empty() && "mismatched number of return types");
state.addTypes(resultTypes);
}
@@ -733,10 +885,11 @@ void WsLoopOp::build(OpBuilder &builder, OperationState &result,
ValueRange upperBounds, ValueRange steps,
ValueRange privateVars, ValueRange firstprivateVars,
ValueRange lastprivateVars, ValueRange linearVars,
- ValueRange linearStepVars, StringAttr scheduleVal,
- Value scheduleChunkVar, IntegerAttr collapseVal,
- UnitAttr nowait, IntegerAttr orderedVal,
- StringAttr orderVal, UnitAttr inclusive, bool buildBody) {
+ ValueRange linearStepVars, ValueRange reductionVars,
+ StringAttr scheduleVal, Value scheduleChunkVar,
+ IntegerAttr collapseVal, UnitAttr nowait,
+ IntegerAttr orderedVal, StringAttr orderVal,
+ UnitAttr inclusive, bool buildBody) {
result.addOperands(lowerBounds);
result.addOperands(upperBounds);
result.addOperands(steps);
@@ -770,6 +923,7 @@ void WsLoopOp::build(OpBuilder &builder, OperationState &result,
static_cast<int32_t>(lastprivateVars.size()),
static_cast<int32_t>(linearVars.size()),
static_cast<int32_t>(linearStepVars.size()),
+ static_cast<int32_t>(reductionVars.size()),
static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
Region *bodyRegion = result.addRegion();
@@ -781,5 +935,44 @@ void WsLoopOp::build(OpBuilder &builder, OperationState &result,
}
}
+static LogicalResult verifyWsLoopOp(WsLoopOp op) {
+ if (op.getNumReductionVars() != 0) {
+ if (!op.reductions() ||
+ op.reductions()->size() != op.getNumReductionVars()) {
+ return op.emitOpError() << "expected as many reduction symbol references "
+ "as reduction variables";
+ }
+ } else {
+ if (op.reductions())
+ return op.emitOpError() << "unexpected reduction symbol references";
+ return success();
+ }
+
+ DenseSet<Value> accumulators;
+ for (auto args : llvm::zip(op.reduction_vars(), *op.reductions())) {
+ Value accum = std::get<0>(args);
+ if (!accumulators.insert(accum).second) {
+ return op.emitOpError() << "accumulator variable used more than once";
+ }
+ Type varType = accum.getType().cast<PointerLikeType>();
+ auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
+ auto decl =
+ SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
+ if (!decl) {
+ return op.emitOpError() << "expected symbol reference " << symbolRef
+ << " to point to a reduction declaration";
+ }
+
+ if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) {
+ return op.emitOpError()
+ << "expected accumulator (" << varType
+ << ") to be the same type as reduction declaration ("
+ << decl.getAccumulatorType() << ")";
+ }
+ }
+
+ return success();
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index e0bb0134a14af..e9a46cb6dff10 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -40,7 +40,7 @@ func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: inde
// CHECK: "test.payload"(%[[ARG6]], %[[ARG7]]) : (i64, i64) -> ()
"test.payload"(%arg6, %arg7) : (index, index) -> ()
omp.yield
- }) {operand_segment_sizes = dense<[2, 2, 2, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (index, index, index, index, index, index) -> ()
+ }) {operand_segment_sizes = dense<[2, 2, 2, 0, 0, 0, 0, 0, 0, 0]> : vector<10xi32>} : (index, index, index, index, index, index) -> ()
omp.terminator
}
return
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 88f61e7f79169..4c85025c65a9d 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -87,3 +87,209 @@ func @proc_bind_once() {
return
}
+
+// -----
+
+// expected-error @below {{op expects initializer region with one argument of the reduction type}}
+omp.reduction.declare @add_f32 : f64
+init {
+^bb0(%arg: f32):
+ %0 = constant 0.0 : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+ %1 = addf %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+
+// -----
+
+// expected-error @below {{expects initializer region to yield a value of the reduction type}}
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+ %0 = constant 0.0 : f64
+ omp.yield (%0 : f64)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+ %1 = addf %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+
+// -----
+
+// expected-error @below {{expects reduction region with two arguments of the reduction type}}
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+ %0 = constant 0.0 : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f64, %arg1: f64):
+ %1 = addf %arg0, %arg1 : f64
+ omp.yield (%1 : f64)
+}
+
+// -----
+
+// expected-error @below {{expects reduction region to yield a value of the reduction type}}
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+ %0 = constant 0.0 : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+ %1 = addf %arg0, %arg1 : f32
+ %2 = fpext %1 : f32 to f64
+ omp.yield (%2 : f64)
+}
+
+// -----
+
+// expected-error @below {{expects atomic reduction region with two arguments of the same type}}
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+ %0 = constant 0.0 : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+ %1 = addf %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg0: memref<f32>, %arg1: memref<f64>):
+ omp.yield
+}
+
+// -----
+
+// expected-error @below {{expects atomic reduction region arguments to be accumulators containing the reduction type}}
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+ %0 = constant 0.0 : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+ %1 = addf %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg0: memref<f64>, %arg1: memref<f64>):
+ omp.yield
+}
+
+// -----
+
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+ %0 = constant 0.0 : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+ %1 = addf %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+
+func @foo(%lb : index, %ub : index, %step : index) {
+ %c1 = constant 1 : i32
+ %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+ %1 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+
+ omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step)
+ reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
+ %2 = constant 2.0 : f32
+ // expected-error @below {{accumulator is not used by the parent}}
+ omp.reduction %2, %1 : !llvm.ptr<f32>
+ omp.yield
+ }
+ return
+}
+
+// -----
+
+func @foo(%lb : index, %ub : index, %step : index) {
+ %c1 = constant 1 : i32
+ %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+ %1 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+
+ // expected-error @below {{expected symbol reference @foo to point to a reduction declaration}}
+ omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step)
+ reduction(@foo -> %0 : !llvm.ptr<f32>) {
+ %2 = constant 2.0 : f32
+ omp.reduction %2, %1 : !llvm.ptr<f32>
+ omp.yield
+ }
+ return
+}
+
+// -----
+
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+ %0 = constant 0.0 : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+ %1 = addf %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+
+func @foo(%lb : index, %ub : index, %step : index) {
+ %c1 = constant 1 : i32
+ %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+
+ // expected-error @below {{accumulator variable used more than once}}
+ omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step)
+ reduction(@add_f32 -> %0 : !llvm.ptr<f32>, @add_f32 -> %0 : !llvm.ptr<f32>) {
+ %2 = constant 2.0 : f32
+ omp.reduction %2, %0 : !llvm.ptr<f32>
+ omp.yield
+ }
+ return
+}
+
+// -----
+
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+ %0 = constant 0.0 : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+ %1 = addf %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg2: !llvm.ptr<f32>, %arg3: !llvm.ptr<f32>):
+ %2 = llvm.load %arg3 : !llvm.ptr<f32>
+ llvm.atomicrmw fadd %arg2, %2 monotonic : f32
+ omp.yield
+}
+
+func @foo(%lb : index, %ub : index, %step : index, %mem : memref<1xf32>) {
+ %c1 = constant 1 : i32
+
+ // expected-error @below {{expected accumulator ('memref<1xf32>') to be the same type as reduction declaration ('!llvm.ptr<f32>')}}
+ omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step)
+ reduction(@add_f32 -> %mem : memref<1xf32>) {
+ %2 = constant 2.0 : f32
+ omp.reduction %2, %mem : memref<1xf32>
+ omp.yield
+ }
+ return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index f23ddf9df0f73..35ac6b30593b2 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -133,35 +133,35 @@ func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memref<i32
"omp.wsloop" (%lb, %ub, %step, %data_var, %data_var) ({
^bb0(%iv: index):
omp.yield
- }) {operand_segment_sizes = dense<[1,1,1,2,0,0,0,0,0]> : vector<9xi32>, collapse_val = 2, ordered_val = 1} :
+ }) {operand_segment_sizes = dense<[1,1,1,2,0,0,0,0,0,0]> : vector<10xi32>, collapse_val = 2, ordered_val = 1} :
(index, index, index, memref<i32>, memref<i32>) -> ()
// CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) linear(%{{.*}} = %{{.*}} : memref<i32>) schedule(static)
"omp.wsloop" (%lb, %ub, %step, %data_var, %linear_var) ({
^bb0(%iv: index):
omp.yield
- }) {operand_segment_sizes = dense<[1,1,1,0,0,0,1,1,0]> : vector<9xi32>, schedule_val = "Static"} :
+ }) {operand_segment_sizes = dense<[1,1,1,0,0,0,1,1,0,0]> : vector<10xi32>, schedule_val = "Static"} :
(index, index, index, memref<i32>, i32) -> ()
// CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) linear(%{{.*}} = %{{.*}} : memref<i32>, %{{.*}} = %{{.*}} : memref<i32>) schedule(static)
"omp.wsloop" (%lb, %ub, %step, %data_var, %data_var, %linear_var, %linear_var) ({
^bb0(%iv: index):
omp.yield
- }) {operand_segment_sizes = dense<[1,1,1,0,0,0,2,2,0]> : vector<9xi32>, schedule_val = "Static"} :
+ }) {operand_segment_sizes = dense<[1,1,1,0,0,0,2,2,0,0]> : vector<10xi32>, schedule_val = "Static"} :
(index, index, index, memref<i32>, memref<i32>, i32, i32) -> ()
// CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) lastprivate(%{{.*}} : memref<i32>) linear(%{{.*}} = %{{.*}} : memref<i32>) schedule(dynamic = %{{.*}}) collapse(3) ordered(2)
"omp.wsloop" (%lb, %ub, %step, %data_var, %data_var, %data_var, %data_var, %linear_var, %chunk_var) ({
^bb0(%iv: index):
omp.yield
- }) {operand_segment_sizes = dense<[1,1,1,1,1,1,1,1,1]> : vector<9xi32>, schedule_val = "Dynamic", collapse_val = 3, ordered_val = 2} :
+ }) {operand_segment_sizes = dense<[1,1,1,1,1,1,1,1,0,1]> : vector<10xi32>, schedule_val = "Dynamic", collapse_val = 3, ordered_val = 2} :
(index, index, index, memref<i32>, memref<i32>, memref<i32>, memref<i32>, i32, i32) -> ()
// CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) private(%{{.*}} : memref<i32>) schedule(auto) nowait
"omp.wsloop" (%lb, %ub, %step, %data_var) ({
^bb0(%iv: index):
omp.yield
- }) {operand_segment_sizes = dense<[1,1,1,1,0,0,0,0,0]> : vector<9xi32>, nowait, schedule_val = "Auto"} :
+ }) {operand_segment_sizes = dense<[1,1,1,1,0,0,0,0,0,0]> : vector<10xi32>, nowait, schedule_val = "Auto"} :
(index, index, index, memref<i32>) -> ()
return
@@ -294,3 +294,78 @@ func @omp_target(%if_cond : i1, %device : si32, %num_threads : si32) -> () {
return
}
+
+// CHECK: omp.reduction.declare
+// CHECK-LABEL: @add_f32
+// CHECK: : f32
+// CHECK: init
+// CHECK: ^{{.+}}(%{{.+}}: f32):
+// CHECK: omp.yield
+// CHECK: combiner
+// CHECK: ^{{.+}}(%{{.+}}: f32, %{{.+}}: f32):
+// CHECK: omp.yield
+// CHECK: atomic
+// CHECK: ^{{.+}}(%{{.+}}: !llvm.ptr<f32>, %{{.+}}: !llvm.ptr<f32>):
+// CHECK: omp.yield
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+ %0 = constant 0.0 : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+ %1 = addf %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg2: !llvm.ptr<f32>, %arg3: !llvm.ptr<f32>):
+ %2 = llvm.load %arg3 : !llvm.ptr<f32>
+ llvm.atomicrmw fadd %arg2, %2 monotonic : f32
+ omp.yield
+}
+
+func @reduction(%lb : index, %ub : index, %step : index) {
+ %c1 = constant 1 : i32
+ %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+ // CHECK: reduction(@add_f32 -> %{{.+}} : !llvm.ptr<f32>)
+ omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step)
+ reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
+ %1 = constant 2.0 : f32
+ // CHECK: omp.reduction %{{.+}}, %{{.+}}
+ omp.reduction %1, %0 : !llvm.ptr<f32>
+ omp.yield
+ }
+ return
+}
+
+// CHECK: omp.reduction.declare
+// CHECK-LABEL: @add2_f32
+omp.reduction.declare @add2_f32 : f32
+// CHECK: init
+init {
+^bb0(%arg: f32):
+ %0 = constant 0.0 : f32
+ omp.yield (%0 : f32)
+}
+// CHECK: combiner
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+ %1 = addf %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+// CHECK-NOT: atomic
+
+func @reduction2(%lb : index, %ub : index, %step : index) {
+ %0 = memref.alloca() : memref<1xf32>
+ // CHECK: reduction
+ omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step)
+ reduction(@add2_f32 -> %0 : memref<1xf32>) {
+ %1 = constant 2.0 : f32
+ // CHECK: omp.reduction
+ omp.reduction %1, %0 : memref<1xf32>
+ omp.yield
+ }
+ return
+}
+
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index e9d472d2e602e..e53ec47370eb8 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -379,7 +379,7 @@ llvm.func @wsloop_simple(%arg0: !llvm.ptr<f32>) {
llvm.store %3, %4 : !llvm.ptr<f32>
omp.yield
// CHECK: call void @__kmpc_for_static_fini(%struct.ident_t* @[[$wsloop_loc_struct]],
- }) {operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (i64, i64, i64) -> ()
+ }) {operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0, 0]> : vector<10xi32>} : (i64, i64, i64) -> ()
omp.terminator
}
llvm.return
@@ -397,7 +397,7 @@ llvm.func @wsloop_inclusive_1(%arg0: !llvm.ptr<f32>) {
%4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
llvm.store %3, %4 : !llvm.ptr<f32>
omp.yield
- }) {operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (i64, i64, i64) -> ()
+ }) {operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0, 0]> : vector<10xi32>} : (i64, i64, i64) -> ()
llvm.return
}
@@ -413,7 +413,7 @@ llvm.func @wsloop_inclusive_2(%arg0: !llvm.ptr<f32>) {
%4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
llvm.store %3, %4 : !llvm.ptr<f32>
omp.yield
- }) {inclusive, operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (i64, i64, i64) -> ()
+ }) {inclusive, operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0, 0]> : vector<10xi32>} : (i64, i64, i64) -> ()
llvm.return
}
More information about the Mlir-commits
mailing list