[Mlir-commits] [mlir] 1ce752b - [mlir] support reductions in SCF to OpenMP conversion
Alex Zinenko
llvmlistbot at llvm.org
Thu Sep 9 04:04:37 PDT 2021
Author: Alex Zinenko
Date: 2021-09-09T13:04:27+02:00
New Revision: 1ce752b7415bef59200c32f64fcfac795813b11e
URL: https://github.com/llvm/llvm-project/commit/1ce752b7415bef59200c32f64fcfac795813b11e
DIFF: https://github.com/llvm/llvm-project/commit/1ce752b7415bef59200c32f64fcfac795813b11e.diff
LOG: [mlir] support reductions in SCF to OpenMP conversion
OpenMP reductions need a neutral element, so we match some known reduction
kinds (integer add/mul/or/and/xor, float add/mul, integer and float min/max) to
define the neutral element and, when it can be expressed using atomicrmw, the
atomic version (everything except float mul and float min/max). The SCF-to-OpenMP pass becomes a
module pass because it now needs to introduce new symbols for reduction
declarations in the module.
Reviewed By: chelini
Differential Revision: https://reviews.llvm.org/D107549
Added:
mlir/test/Conversion/SCFToOpenMP/reductions.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt
mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 24b2cf1dc422b..30940fcc891f9 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -355,11 +355,11 @@ def ConvertPDLToPDLInterp : Pass<"convert-pdl-to-pdl-interp", "ModuleOp"> {
// SCFToOpenMP
//===----------------------------------------------------------------------===//
-def ConvertSCFToOpenMP : FunctionPass<"convert-scf-to-openmp"> {
+def ConvertSCFToOpenMP : Pass<"convert-scf-to-openmp", "ModuleOp"> {
let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
"constructs.";
let constructor = "mlir::createConvertSCFToOpenMPPass()";
- let dependentDialects = ["omp::OpenMPDialect"];
+ let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
index 349c4e1efc833..4000bc1df46b2 100644
--- a/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
+++ b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
@@ -12,11 +12,11 @@
#include <memory>
namespace mlir {
-class FuncOp;
+class ModuleOp;
template <typename T>
class OperationPass;
-std::unique_ptr<OperationPass<FuncOp>> createConvertSCFToOpenMPPass();
+std::unique_ptr<OperationPass<ModuleOp>> createConvertSCFToOpenMPPass();
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 08396a345c6e6..05d406d09fc67 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -317,12 +317,12 @@ def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> {
The optional $device parameter specifies the device number for the target region.
The optional $thread_limit specifies the limit on the number of threads
-
+
The optional $nowait elliminates the implicit barrier so the parent task can make progress
even if the target task is not yet completed.
-
+
TODO: private, map, is_device_ptr, firstprivate, depend, defaultmap, in_reduction
-
+
}];
let arguments = (ins Optional<I1>:$if_expr,
diff --git a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt
index 1ef4b74da6dd7..1a75a3549f979 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt
+++ b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt
@@ -11,7 +11,9 @@ add_mlir_conversion_library(MLIRSCFToOpenMP
Core
LINK_LIBS PUBLIC
+ MLIRLLVMIR
MLIROpenMP
MLIRSCF
+ MLIRStandard
MLIRTransforms
)
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index a7d4a99c9d5b5..9c6fc6fed9139 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -13,26 +13,311 @@
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
#include "../PassDetail.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
+/// Matches a block containing a "simple" reduction. The expected shape of the
+/// block is as follows.
+///
+/// ^bb(%arg0, %arg1):
+/// %0 = OpTy(%arg0, %arg1)
+/// scf.reduce.return %0
+template <typename... OpTy>
+static bool matchSimpleReduction(Block &block) {
+ if (block.empty() || llvm::hasSingleElement(block) ||
+ std::next(block.begin(), 2) != block.end())
+ return false;
+ return isa<OpTy...>(block.front()) &&
+ isa<scf::ReduceReturnOp>(block.back()) &&
+ block.front().getOperands() == block.getArguments() &&
+ block.back().getOperand(0) == block.front().getResult(0);
+}
+
+/// Matches a block containing a select-based min/max reduction. The types of
+/// select and compare operations are provided as template arguments. The
+/// comparison predicates suitable for min and max are provided as function
+/// arguments. If a reduction is matched, `isMin` will be set if the reduction
+/// computes the minimum and unset if it computes the maximum, otherwise it
+/// remains unmodified. The expected shape of the block is as follows.
+///
+/// ^bb(%arg0, %arg1):
+/// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1)
+/// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here.
+/// scf.reduce.return %1
+template <
+ typename CompareOpTy, typename SelectOpTy,
+ typename Predicate = decltype(std::declval<CompareOpTy>().predicate())>
+static bool
+matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates,
+ ArrayRef<Predicate> greaterThanPredicates, bool &isMin) {
+ static_assert(llvm::is_one_of<SelectOpTy, SelectOp, LLVM::SelectOp>::value,
+ "only std and llvm select ops are supported");
+
+ // Expect exactly three operations in the block.
+ if (block.empty() || llvm::hasSingleElement(block) ||
+ std::next(block.begin(), 2) == block.end() ||
+ std::next(block.begin(), 3) != block.end())
+ return false;
+
+ // Check op kinds.
+ auto compare = dyn_cast<CompareOpTy>(block.front());
+ auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
+ auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back());
+ if (!compare || !select || !terminator)
+ return false;
+
+ // Block arguments must be compared.
+ if (compare->getOperands() != block.getArguments())
+ return false;
+
+ // Detect whether the comparison is less-than or greater-than, otherwise bail.
+ bool isLess;
+ if (llvm::find(lessThanPredicates, compare.predicate()) !=
+ lessThanPredicates.end()) {
+ isLess = true;
+ } else if (llvm::find(greaterThanPredicates, compare.predicate()) !=
+ greaterThanPredicates.end()) {
+ isLess = false;
+ } else {
+ return false;
+ }
+
+ if (select.condition() != compare.getResult())
+ return false;
+
+ // Detect if the operands are swapped between cmpf and select. Match the
+ // comparison type with the requested type or with the opposite of the
+ // requested type if the operands are swapped. Use generic accessors because
+ // std and LLVM versions of select have different operand names but identical
+ // positions.
+ constexpr unsigned kTrueValue = 1;
+ constexpr unsigned kFalseValue = 2;
+ bool sameOperands = select.getOperand(kTrueValue) == compare.lhs() &&
+ select.getOperand(kFalseValue) == compare.rhs();
+ bool swappedOperands = select.getOperand(kTrueValue) == compare.rhs() &&
+ select.getOperand(kFalseValue) == compare.lhs();
+ if (!sameOperands && !swappedOperands)
+ return false;
+
+ if (select.getResult() != terminator.result())
+ return false;
+
+ // The reduction is a min if it uses less-than predicates with same operands
+ // or greater-than predicates with swapped operands. Similarly for max.
+ isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
+ return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
+}
+
+/// Returns the float semantics for the given float type.
+static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
+ if (type.isF16())
+ return llvm::APFloat::IEEEhalf();
+ if (type.isF32())
+ return llvm::APFloat::IEEEsingle();
+ if (type.isF64())
+ return llvm::APFloat::IEEEdouble();
+ if (type.isF128())
+ return llvm::APFloat::IEEEquad();
+ if (type.isBF16())
+ return llvm::APFloat::BFloat();
+ if (type.isF80())
+ return llvm::APFloat::x87DoubleExtended();
+ llvm_unreachable("unknown float type");
+}
+
+/// Returns an attribute with the minimum (if `min` is set) or the maximum value
+/// (otherwise) for the given float type.
+static Attribute minMaxValueForFloat(Type type, bool min) {
+ auto fltType = type.cast<FloatType>();
+ return FloatAttr::get(
+ type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
+}
+
+/// Returns an attribute with the signed integer minimum (if `min` is set) or
+/// the maximum value (otherwise) for the given integer type, regardless of its
+/// signedness semantics (only the width is considered).
+static Attribute minMaxValueForSignedInt(Type type, bool min) {
+ auto intType = type.cast<IntegerType>();
+ unsigned bitwidth = intType.getWidth();
+ return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
+ : llvm::APInt::getSignedMaxValue(bitwidth));
+}
+
+/// Returns an attribute with the unsigned integer minimum (if `min` is set) or
+/// the maximum value (otherwise) for the given integer type, regardless of its
+/// signedness semantics (only the width is considered).
+static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
+ auto intType = type.cast<IntegerType>();
+ unsigned bitwidth = intType.getWidth();
+ return IntegerAttr::get(type, min ? llvm::APInt::getNullValue(bitwidth)
+ : llvm::APInt::getAllOnesValue(bitwidth));
+}
+
+/// Creates an OpenMP reduction declaration and inserts it into the provided
+/// symbol table. The declaration has a constant initializer with the neutral
+/// value `initValue`, and the reduction combiner carried over from `reduce`.
+static omp::ReductionDeclareOp createDecl(PatternRewriter &builder,
+ SymbolTable &symbolTable,
+ scf::ReduceOp reduce,
+ Attribute initValue) {
+ OpBuilder::InsertionGuard guard(builder);
+ auto decl = builder.create<omp::ReductionDeclareOp>(
+ reduce.getLoc(), "__scf_reduction", reduce.operand().getType());
+ symbolTable.insert(decl);
+
+ Type type = reduce.operand().getType();
+ builder.createBlock(&decl.initializerRegion(), decl.initializerRegion().end(),
+ {type});
+ builder.setInsertionPointToEnd(&decl.initializerRegion().back());
+ Value init =
+ builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue);
+ builder.create<omp::YieldOp>(reduce.getLoc(), init);
+
+ Operation *terminator = &reduce.getRegion().front().back();
+ assert(isa<scf::ReduceReturnOp>(terminator) &&
+ "expected reduce op to be terminated by redure return");
+ builder.setInsertionPoint(terminator);
+ builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
+ terminator->getOperands());
+ builder.inlineRegionBefore(reduce.getRegion(), decl.reductionRegion(),
+ decl.reductionRegion().end());
+ return decl;
+}
+
+/// Adds an atomic reduction combiner to the given OpenMP reduction declaration
+/// using llvm.atomicrmw of the given kind.
+static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
+ LLVM::AtomicBinOp atomicKind,
+ omp::ReductionDeclareOp decl,
+ scf::ReduceOp reduce) {
+ OpBuilder::InsertionGuard guard(builder);
+ Type type = reduce.operand().getType();
+ Type ptrType = LLVM::LLVMPointerType::get(type);
+ builder.createBlock(&decl.atomicReductionRegion(),
+ decl.atomicReductionRegion().end(), {ptrType, ptrType});
+ Block *atomicBlock = &decl.atomicReductionRegion().back();
+ builder.setInsertionPointToEnd(atomicBlock);
+ Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(),
+ atomicBlock->getArgument(1));
+ builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), type, atomicKind,
+ atomicBlock->getArgument(0), loaded,
+ LLVM::AtomicOrdering::monotonic);
+ builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>());
+ return decl;
+}
+
+/// Creates an OpenMP reduction declaration that corresponds to the given SCF
+/// reduction and returns it. Recognizes common reductions in order to identify
+/// the neutral value, necessary for the OpenMP declaration. If the reduction
+/// cannot be recognized, returns null.
+static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
+ scf::ReduceOp reduce) {
+ Operation *container = SymbolTable::getNearestSymbolTable(reduce);
+ SymbolTable symbolTable(container);
+
+ // Insert reduction declarations in the symbol-table ancestor before the
+ // ancestor of the current insertion point.
+ Operation *insertionPoint = reduce;
+ while (insertionPoint->getParentOp() != container)
+ insertionPoint = insertionPoint->getParentOp();
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPoint(insertionPoint);
+
+ assert(llvm::hasSingleElement(reduce.getRegion()) &&
+ "expected reduction region to have a single element");
+
+ // Match simple binary reductions that can be expressed with atomicrmw.
+ Type type = reduce.operand().getType();
+ Block &reduction = reduce.getRegion().front();
+ if (matchSimpleReduction<AddFOp, LLVM::FAddOp>(reduction)) {
+ omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
+ builder.getFloatAttr(type, 0.0));
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce);
+ }
+ if (matchSimpleReduction<AddIOp, LLVM::AddOp>(reduction)) {
+ omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
+ builder.getIntegerAttr(type, 0));
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce);
+ }
+ if (matchSimpleReduction<OrOp, LLVM::OrOp>(reduction)) {
+ omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
+ builder.getIntegerAttr(type, 0));
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce);
+ }
+ if (matchSimpleReduction<XOrOp, LLVM::XOrOp>(reduction)) {
+ omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
+ builder.getIntegerAttr(type, 0));
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce);
+ }
+ if (matchSimpleReduction<AndOp, LLVM::AndOp>(reduction)) {
+ omp::ReductionDeclareOp decl = createDecl(
+ builder, symbolTable, reduce,
+ builder.getIntegerAttr(
+ type, llvm::APInt::getAllOnesValue(type.getIntOrFloatBitWidth())));
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce);
+ }
+
+ // Match simple binary reductions that cannot be expressed with atomicrmw.
+ // TODO: add atomic region using cmpxchg (which needs atomic load to be
+ // available as an op).
+ if (matchSimpleReduction<MulFOp, LLVM::FMulOp>(reduction)) {
+ return createDecl(builder, symbolTable, reduce,
+ builder.getFloatAttr(type, 1.0));
+ }
+
+ // Match select-based min/max reductions.
+ bool isMin;
+ if (matchSelectReduction<CmpFOp, SelectOp>(
+ reduction, {CmpFPredicate::OLT, CmpFPredicate::OLE},
+ {CmpFPredicate::OGT, CmpFPredicate::OGE}, isMin) ||
+ matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
+ reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
+ {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
+ return createDecl(builder, symbolTable, reduce,
+ minMaxValueForFloat(type, !isMin));
+ }
+ if (matchSelectReduction<CmpIOp, SelectOp>(
+ reduction, {CmpIPredicate::slt, CmpIPredicate::sle},
+ {CmpIPredicate::sgt, CmpIPredicate::sge}, isMin) ||
+ matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
+ reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
+ {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
+ omp::ReductionDeclareOp decl = createDecl(
+ builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin));
+ return addAtomicRMW(builder,
+ isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
+ decl, reduce);
+ }
+ if (matchSelectReduction<CmpIOp, SelectOp>(
+ reduction, {CmpIPredicate::ult, CmpIPredicate::ule},
+ {CmpIPredicate::ugt, CmpIPredicate::uge}, isMin) ||
+ matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
+ reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
+ {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
+ omp::ReductionDeclareOp decl = createDecl(
+ builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin));
+ return addAtomicRMW(
+ builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
+ decl, reduce);
+ }
+
+ return nullptr;
+}
+
namespace {
-/// Converts SCF parallel operation into an OpenMP workshare loop construct.
struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
PatternRewriter &rewriter) const override {
- // TODO: add support for reductions when OpenMP loops have them.
- if (parallelOp.getNumResults() != 0)
- return rewriter.notifyMatchFailure(
- parallelOp,
- "OpenMP dialect does not yet support loops with reductions");
-
// Replace SCF yield with OpenMP yield.
{
OpBuilder::InsertionGuard guard(rewriter);
@@ -43,47 +328,118 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
parallelOp.getBody()->getTerminator(), ValueRange());
}
- // Replace the loop.
- auto omp = rewriter.create<omp::ParallelOp>(parallelOp.getLoc());
- Block *block = rewriter.createBlock(&omp.getRegion());
- rewriter.setInsertionPointToStart(block);
- auto loop = rewriter.create<omp::WsLoopOp>(
- parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(),
- parallelOp.step());
- rewriter.inlineRegionBefore(parallelOp.region(), loop.region(),
- loop.region().begin());
- rewriter.create<omp::TerminatorOp>(parallelOp.getLoc());
-
- rewriter.eraseOp(parallelOp);
+ // Declare reductions.
+ // TODO: consider checking if there is already a compatible reduction
+ // declaration and use it instead of redeclaring.
+ SmallVector<Attribute> reductionDeclSymbols;
+ for (auto reduce : parallelOp.getOps<scf::ReduceOp>()) {
+ omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce);
+ if (!decl)
+ return failure();
+ reductionDeclSymbols.push_back(
+ SymbolRefAttr::get(rewriter.getContext(), decl.sym_name()));
+ }
+
+ // Allocate reduction variables. Make sure we don't overflow the stack
+ // with local `alloca`s by saving and restoring the stack pointer.
+ Location loc = parallelOp.getLoc();
+ Value one = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1));
+ SmallVector<Value> reductionVariables;
+ reductionVariables.reserve(parallelOp.getNumReductions());
+ Value token = rewriter.create<LLVM::StackSaveOp>(
+ loc, LLVM::LLVMPointerType::get(rewriter.getIntegerType(8)));
+ for (Value init : parallelOp.initVals()) {
+ assert((LLVM::isCompatibleType(init.getType()) ||
+ init.getType().isa<LLVM::PointerElementTypeInterface>()) &&
+ "cannot create a reduction variable if the type is not an LLVM "
+ "pointer element");
+ Value storage = rewriter.create<LLVM::AllocaOp>(
+ loc, LLVM::LLVMPointerType::get(init.getType()), one, 0);
+ rewriter.create<LLVM::StoreOp>(loc, init, storage);
+ reductionVariables.push_back(storage);
+ }
+
+ // Replace the reduction operations contained in this loop. Must be done
+ // here rather than in a separate pattern to have access to the list of
+ // reduction variables.
+ for (auto pair :
+ llvm::zip(parallelOp.getOps<scf::ReduceOp>(), reductionVariables)) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ scf::ReduceOp reduceOp = std::get<0>(pair);
+ rewriter.setInsertionPoint(reduceOp);
+ rewriter.replaceOpWithNewOp<omp::ReductionOp>(
+ reduceOp, reduceOp.operand(), std::get<1>(pair));
+ }
+
+ // Create the parallel wrapper.
+ auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.createBlock(&ompParallel.region());
+
+ // Replace SCF yield with OpenMP yield.
+ {
+ OpBuilder::InsertionGuard innerGuard(rewriter);
+ rewriter.setInsertionPointToEnd(parallelOp.getBody());
+ assert(llvm::hasSingleElement(parallelOp.region()) &&
+ "expected scf.parallel to have one block");
+ rewriter.replaceOpWithNewOp<omp::YieldOp>(
+ parallelOp.getBody()->getTerminator(), ValueRange());
+ }
+
+ // Replace the loop.
+ auto loop = rewriter.create<omp::WsLoopOp>(
+ parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(),
+ parallelOp.step());
+ rewriter.create<omp::TerminatorOp>(loc);
+
+ rewriter.inlineRegionBefore(parallelOp.region(), loop.region(),
+ loop.region().begin());
+ if (!reductionVariables.empty()) {
+ loop.reductionsAttr(
+ ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
+ loop.reduction_varsMutable().append(reductionVariables);
+ }
+ }
+
+ // Load loop results.
+ SmallVector<Value> results;
+ results.reserve(reductionVariables.size());
+ for (Value variable : reductionVariables) {
+ Value res = rewriter.create<LLVM::LoadOp>(loc, variable);
+ results.push_back(res);
+ }
+ rewriter.replaceOp(parallelOp, results);
+
+ rewriter.create<LLVM::StackRestoreOp>(loc, token);
return success();
}
};
/// Applies the conversion patterns in the given function.
-static LogicalResult applyPatterns(FuncOp func) {
- ConversionTarget target(*func.getContext());
- target.addIllegalOp<scf::ParallelOp>();
- target.addDynamicallyLegalOp<scf::YieldOp>(
- [](scf::YieldOp op) { return !isa<scf::ParallelOp>(op->getParentOp()); });
- target.addLegalDialect<omp::OpenMPDialect>();
-
- RewritePatternSet patterns(func.getContext());
- patterns.add<ParallelOpLowering>(func.getContext());
+static LogicalResult applyPatterns(ModuleOp module) {
+ ConversionTarget target(*module.getContext());
+ target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
+ target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect>();
+
+ RewritePatternSet patterns(module.getContext());
+ patterns.add<ParallelOpLowering>(module.getContext());
FrozenRewritePatternSet frozen(std::move(patterns));
- return applyPartialConversion(func, target, frozen);
+ return applyPartialConversion(module, target, frozen);
}
/// A pass converting SCF operations to OpenMP operations.
struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
/// Pass entry point.
- void runOnFunction() override {
- if (failed(applyPatterns(getFunction())))
+ void runOnOperation() override {
+ if (failed(applyPatterns(getOperation())))
signalPassFailure();
}
};
} // end namespace
-std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertSCFToOpenMPPass() {
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToOpenMPPass() {
return std::make_unique<SCFToOpenMPPass>();
}
diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
new file mode 100644
index 0000000000000..bbc7d61a33a4b
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
@@ -0,0 +1,194 @@
+// RUN: mlir-opt -convert-scf-to-openmp -split-input-file %s | FileCheck %s
+
+// CHECK: omp.reduction.declare @[[$REDF:.*]] : f32
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(0.000000e+00 : f32)
+// CHECK: omp.yield(%[[INIT]] : f32)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+// CHECK: %[[RES:.*]] = addf %[[ARG0]], %[[ARG1]]
+// CHECK: omp.yield(%[[RES]] : f32)
+
+// CHECK: atomic
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr<f32>, %[[ARG1:.*]]: !llvm.ptr<f32>):
+// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]]
+// CHECK: llvm.atomicrmw fadd %[[ARG0]], %[[RHS]] monotonic
+
+// CHECK-LABEL: @reduction1
+func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
+ %arg3 : index, %arg4 : index) {
+ // CHECK: %[[CST:.*]] = constant 0.0
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1
+ // CHECK: llvm.intr.stacksave
+ // CHECK: %[[BUF:.*]] = llvm.alloca %[[ONE]] x f32
+ // CHECK: llvm.store %[[CST]], %[[BUF]]
+ %step = constant 1 : index
+ %zero = constant 0.0 : f32
+ // CHECK: omp.parallel
+ // CHECK: omp.wsloop
+ // CHECK-SAME: reduction(@[[$REDF]] -> %[[BUF]]
+ scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+ step (%arg4, %step) init (%zero) -> (f32) {
+ // CHECK: %[[CST_INNER:.*]] = constant 1.0
+ %one = constant 1.0 : f32
+ // CHECK: omp.reduction %[[CST_INNER]], %[[BUF]]
+ scf.reduce(%one) : f32 {
+ ^bb0(%lhs : f32, %rhs: f32):
+ %res = addf %lhs, %rhs : f32
+ scf.reduce.return %res : f32
+ }
+ // CHECK: omp.yield
+ }
+ // CHECK: omp.terminator
+ // CHECK: llvm.load %[[BUF]]
+ // CHECK: llvm.intr.stackrestore
+ return
+}
+
+// -----
+
+// Only check the declaration here, the rest is same as above.
+// CHECK: omp.reduction.declare @{{.*}} : f32
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(1.000000e+00 : f32)
+// CHECK: omp.yield(%[[INIT]] : f32)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+// CHECK: %[[RES:.*]] = mulf %[[ARG0]], %[[ARG1]]
+// CHECK: omp.yield(%[[RES]] : f32)
+
+// CHECK-NOT: atomic
+
+// CHECK-LABEL: @reduction2
+func @reduction2(%arg0 : index, %arg1 : index, %arg2 : index,
+ %arg3 : index, %arg4 : index) {
+ %step = constant 1 : index
+ %zero = constant 0.0 : f32
+ scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+ step (%arg4, %step) init (%zero) -> (f32) {
+ %one = constant 1.0 : f32
+ scf.reduce(%one) : f32 {
+ ^bb0(%lhs : f32, %rhs: f32):
+ %res = mulf %lhs, %rhs : f32
+ scf.reduce.return %res : f32
+ }
+ }
+ return
+}
+
+// -----
+
+// Only check the declaration here, the rest is same as above.
+// CHECK: omp.reduction.declare @{{.*}} : f32
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(-3.4
+// CHECK: omp.yield(%[[INIT]] : f32)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+// CHECK: %[[CMP:.*]] = cmpf oge, %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = select %[[CMP]], %[[ARG0]], %[[ARG1]]
+// CHECK: omp.yield(%[[RES]] : f32)
+
+// CHECK-NOT: atomic
+
+// CHECK-LABEL: @reduction3
+func @reduction3(%arg0 : index, %arg1 : index, %arg2 : index,
+ %arg3 : index, %arg4 : index) {
+ %step = constant 1 : index
+ %zero = constant 0.0 : f32
+ scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+ step (%arg4, %step) init (%zero) -> (f32) {
+ %one = constant 1.0 : f32
+ scf.reduce(%one) : f32 {
+ ^bb0(%lhs : f32, %rhs: f32):
+ %cmp = cmpf oge, %lhs, %rhs : f32
+ %res = select %cmp, %lhs, %rhs : f32
+ scf.reduce.return %res : f32
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK: omp.reduction.declare @[[$REDF1:.*]] : f32
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(-3.4
+// CHECK: omp.yield(%[[INIT]] : f32)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+// CHECK: %[[CMP:.*]] = cmpf oge, %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = select %[[CMP]], %[[ARG0]], %[[ARG1]]
+// CHECK: omp.yield(%[[RES]] : f32)
+
+// CHECK-NOT: atomic
+
+// CHECK: omp.reduction.declare @[[$REDF2:.*]] : i64
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant
+// CHECK: omp.yield(%[[INIT]] : i64)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64)
+// CHECK: %[[CMP:.*]] = cmpi slt, %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = select %[[CMP]], %[[ARG1]], %[[ARG0]]
+// CHECK: omp.yield(%[[RES]] : i64)
+
+// CHECK: atomic
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr<i64>, %[[ARG1:.*]]: !llvm.ptr<i64>):
+// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]]
+// CHECK: llvm.atomicrmw max %[[ARG0]], %[[RHS]] monotonic
+
+// CHECK-LABEL: @reduction4
+func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
+ %arg3 : index, %arg4 : index) -> (f32, i64) {
+ %step = constant 1 : index
+ // CHECK: %[[ZERO:.*]] = constant 0.0
+ %zero = constant 0.0 : f32
+ // CHECK: %[[IONE:.*]] = constant 1
+ %ione = constant 1 : i64
+ // CHECK: %[[BUF1:.*]] = llvm.alloca %{{.*}} x f32
+ // CHECK: llvm.store %[[ZERO]], %[[BUF1]]
+ // CHECK: %[[BUF2:.*]] = llvm.alloca %{{.*}} x i64
+ // CHECK: llvm.store %[[IONE]], %[[BUF2]]
+
+ // CHECK: omp.parallel
+ // CHECK: omp.wsloop
+ // CHECK-SAME: reduction(@[[$REDF1]] -> %[[BUF1]]
+ // CHECK-SAME: @[[$REDF2]] -> %[[BUF2]]
+ %res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+ step (%arg4, %step) init (%zero, %ione) -> (f32, i64) {
+ %one = constant 1.0 : f32
+ // CHECK: omp.reduction %{{.*}}, %[[BUF1]]
+ scf.reduce(%one) : f32 {
+ ^bb0(%lhs : f32, %rhs: f32):
+ %cmp = cmpf oge, %lhs, %rhs : f32
+ %res = select %cmp, %lhs, %rhs : f32
+ scf.reduce.return %res : f32
+ }
+ // CHECK: fptosi
+ %1 = fptosi %one : f32 to i64
+ // CHECK: omp.reduction %{{.*}}, %[[BUF2]]
+ scf.reduce(%1) : i64 {
+ ^bb1(%lhs: i64, %rhs: i64):
+ %cmp = cmpi slt, %lhs, %rhs : i64
+ %res = select %cmp, %rhs, %lhs : i64
+ scf.reduce.return %res : i64
+ }
+ // CHECK: omp.yield
+ }
+ // CHECK: omp.terminator
+ // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]]
+ // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]]
+ // CHECK: return %[[RES1]], %[[RES2]]
+ return %res#0, %res#1 : f32, i64
+}
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index 44059a27b3295..1507f927b9f00 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -27,7 +27,6 @@ func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
// CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> ()
"test.payload"(%i, %j) : (index, index) -> ()
- // CHECK: omp.yield
// CHECK: }
}
// CHECK: omp.yield
@@ -38,6 +37,7 @@ func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
return
}
+// CHECK-LABEL: @adjacent_loops
func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
// CHECK: omp.parallel {
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 225703d0f3a3e..f486701c1dfcd 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4316,9 +4316,11 @@ cc_library(
deps = [
":ConversionPassIncGen",
":IR",
+ ":LLVMDialect",
":OpenMPDialect",
":Pass",
":SCFDialect",
+ ":StandardOps",
":Support",
":Transforms",
],
More information about the Mlir-commits
mailing list