[Mlir-commits] [mlir] 1ce752b - [mlir] support reductions in SCF to OpenMP conversion

Alex Zinenko llvmlistbot at llvm.org
Thu Sep 9 04:04:37 PDT 2021


Author: Alex Zinenko
Date: 2021-09-09T13:04:27+02:00
New Revision: 1ce752b7415bef59200c32f64fcfac795813b11e

URL: https://github.com/llvm/llvm-project/commit/1ce752b7415bef59200c32f64fcfac795813b11e
DIFF: https://github.com/llvm/llvm-project/commit/1ce752b7415bef59200c32f64fcfac795813b11e.diff

LOG: [mlir] support reductions in SCF to OpenMP conversion

OpenMP reductions need a neutral element, so we match some known reduction
kinds (integer add/mul/or/and/xor, float add/mul, integer and float min/max) to
define the neutral element and, when it can be expressed using atomicrmw, the
atomic version of the combiner (available for everything except multiplication
and float min/max). The SCF-to-OpenMP pass becomes a module pass because it now
needs to introduce new symbols for reduction declarations in the module.

Reviewed By: chelini

Differential Revision: https://reviews.llvm.org/D107549

Added: 
    mlir/test/Conversion/SCFToOpenMP/reductions.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt
    mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
    mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 24b2cf1dc422b..30940fcc891f9 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -355,11 +355,11 @@ def ConvertPDLToPDLInterp : Pass<"convert-pdl-to-pdl-interp", "ModuleOp"> {
 // SCFToOpenMP
 //===----------------------------------------------------------------------===//
 
-def ConvertSCFToOpenMP : FunctionPass<"convert-scf-to-openmp"> {
+// Runs on the module (not on individual functions) because reductions are
+// lowered to reduction declarations, which are new symbols inserted into the
+// module.
+def ConvertSCFToOpenMP : Pass<"convert-scf-to-openmp", "ModuleOp"> {
   let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
                 "constructs.";
   let constructor = "mlir::createConvertSCFToOpenMPPass()";
-  let dependentDialects = ["omp::OpenMPDialect"];
+  // The lowering also creates LLVM dialect ops (allocas, loads/stores,
+  // constants, stack save/restore and atomicrmw).
+  let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect"];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
index 349c4e1efc833..4000bc1df46b2 100644
--- a/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
+++ b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
@@ -12,11 +12,11 @@
 #include <memory>
 
 namespace mlir {
-class FuncOp;
+class ModuleOp;
 template <typename T>
 class OperationPass;
 
+/// Creates a pass that converts scf.parallel operations, including their
+/// reductions, into OpenMP parallel + workshare loop operations. The pass runs
+/// on a module because it inserts reduction declaration symbols into it.
-std::unique_ptr<OperationPass<FuncOp>> createConvertSCFToOpenMPPass();
+std::unique_ptr<OperationPass<ModuleOp>> createConvertSCFToOpenMPPass();
 
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 08396a345c6e6..05d406d09fc67 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -317,12 +317,12 @@ def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> {
     The optional $device parameter specifies the device number for the target region.
 
     The optional $thread_limit specifies the limit on the number of threads
-    
+
     The optional $nowait elliminates the implicit barrier so the parent task can make progress
     even if the target task is not yet completed.
-    
+
     TODO:  private, map, is_device_ptr, firstprivate, depend, defaultmap, in_reduction
-    
+
   }];
 
   let arguments = (ins Optional<I1>:$if_expr,

diff  --git a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt
index 1ef4b74da6dd7..1a75a3549f979 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt
+++ b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt
@@ -11,7 +11,9 @@ add_mlir_conversion_library(MLIRSCFToOpenMP
   Core
 
   LINK_LIBS PUBLIC
+  MLIRLLVMIR
   MLIROpenMP
   MLIRSCF
+  MLIRStandard
   MLIRTransforms
   )

diff  --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index a7d4a99c9d5b5..9c6fc6fed9139 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -13,26 +13,311 @@
 
 #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
 #include "../PassDetail.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 
+/// Returns true when `block` consists of exactly two operations. The first is
+/// a binary operation of one of the kinds listed in `OpTy...`, applied to the
+/// two block arguments in their original order. The second is an
+/// `scf.reduce.return` returning the result of that operation. Accepted shape:
+///
+///   ^bb(%arg0, %arg1):
+///     %0 = OpTy(%arg0, %arg1)
+///     scf.reduce.return %0
+template <typename... OpTy>
+static bool matchSimpleReduction(Block &block) {
+  // Reject blocks that do not hold exactly two operations.
+  bool hasTwoOps = !block.empty() && !llvm::hasSingleElement(block) &&
+                   std::next(block.begin(), 2) == block.end();
+  if (!hasTwoOps)
+    return false;
+
+  Operation &combiner = block.front();
+  Operation &terminator = block.back();
+  if (!isa<OpTy...>(combiner))
+    return false;
+  if (!isa<scf::ReduceReturnOp>(terminator))
+    return false;
+
+  // The combiner must consume the block arguments in order, and its result
+  // must be the value returned by the terminator.
+  if (combiner.getOperands() != block.getArguments())
+    return false;
+  return terminator.getOperand(0) == combiner.getResult(0);
+}
+
+/// Matches a block containing a select-based min/max reduction. The types of
+/// the select and compare operations are provided as template arguments. The
+/// comparison predicates that denote "less than" and "greater than" are
+/// provided as function arguments. If a reduction is matched, `isMin` is set
+/// when the reduction computes the minimum and unset when it computes the
+/// maximum; otherwise `isMin` remains unmodified. Expected block shape:
+///
+///   ^bb(%arg0, %arg1):
+///     %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1)
+///     %1 = SelectOpTy(%0, %arg0, %arg1)  // %arg0, %arg1 may be swapped here.
+///     scf.reduce.return %1
+template <
+    typename CompareOpTy, typename SelectOpTy,
+    typename Predicate = decltype(std::declval<CompareOpTy>().predicate())>
+static bool
+matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates,
+                     ArrayRef<Predicate> greaterThanPredicates, bool &isMin) {
+  static_assert(llvm::is_one_of<SelectOpTy, SelectOp, LLVM::SelectOp>::value,
+                "only std and llvm select ops are supported");
+
+  // Expect exactly three operations in the block.
+  if (block.empty() || llvm::hasSingleElement(block) ||
+      std::next(block.begin(), 2) == block.end() ||
+      std::next(block.begin(), 3) != block.end())
+    return false;
+
+  // Check op kinds.
+  auto compare = dyn_cast<CompareOpTy>(block.front());
+  auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
+  auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back());
+  if (!compare || !select || !terminator)
+    return false;
+
+  // Block arguments must be compared.
+  if (compare->getOperands() != block.getArguments())
+    return false;
+
+  // Detect whether the comparison is less-than or greater-than, otherwise bail.
+  bool isLess;
+  if (llvm::is_contained(lessThanPredicates, compare.predicate())) {
+    isLess = true;
+  } else if (llvm::is_contained(greaterThanPredicates, compare.predicate())) {
+    isLess = false;
+  } else {
+    return false;
+  }
+
+  if (select.condition() != compare.getResult())
+    return false;
+
+  // Detect if the operands are swapped between the comparison and the select.
+  // Use generic accessors because std and LLVM versions of select have
+  // different operand names but identical positions.
+  constexpr unsigned kTrueValue = 1;
+  constexpr unsigned kFalseValue = 2;
+  bool sameOperands = select.getOperand(kTrueValue) == compare.lhs() &&
+                      select.getOperand(kFalseValue) == compare.rhs();
+  bool swappedOperands = select.getOperand(kTrueValue) == compare.rhs() &&
+                         select.getOperand(kFalseValue) == compare.lhs();
+  if (!sameOperands && !swappedOperands)
+    return false;
+
+  if (select.getResult() != terminator.result())
+    return false;
+
+  // The reduction is a min if it uses less-than predicates with same operands
+  // or greater-than predicates with swapped operands, and a max otherwise.
+  // At this point at least one of `sameOperands` / `swappedOperands` holds, so
+  // every remaining combination is a valid min or max reduction.
+  isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
+  return true;
+}
+
+/// Returns the float semantics for the given float type. This delegates to
+/// `FloatType::getFloatSemantics`, so every builtin float type is handled
+/// without keeping a local copy of the type-to-semantics mapping in sync.
+static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
+  return type.getFloatSemantics();
+}
+
+/// Builds a float attribute of type `type` holding the largest finite value of
+/// that type, negated when `min` is set. The result is the most negative
+/// finite value when `min` is true and the most positive one otherwise.
+static Attribute minMaxValueForFloat(Type type, bool min) {
+  const llvm::fltSemantics &semantics =
+      fltSemanticsForType(type.cast<FloatType>());
+  llvm::APFloat extreme =
+      llvm::APFloat::getLargest(semantics, /*Negative=*/min);
+  return FloatAttr::get(type, extreme);
+}
+
+/// Builds an integer attribute of type `type` holding either the smallest
+/// (when `min` is set) or the largest (otherwise) two's-complement signed
+/// value of the type's bit width. The signedness semantics of the integer
+/// type are ignored; only its width matters.
+static Attribute minMaxValueForSignedInt(Type type, bool min) {
+  unsigned width = type.cast<IntegerType>().getWidth();
+  llvm::APInt bound = min ? llvm::APInt::getSignedMinValue(width)
+                          : llvm::APInt::getSignedMaxValue(width);
+  return IntegerAttr::get(type, bound);
+}
+
+/// Builds an integer attribute of type `type` holding either the smallest
+/// unsigned value (zero, when `min` is set) or the largest one (all bits set,
+/// otherwise) of the type's bit width. The integer type's signedness semantics
+/// are ignored.
+static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
+  unsigned width = type.cast<IntegerType>().getWidth();
+  if (min)
+    return IntegerAttr::get(type, llvm::APInt::getNullValue(width));
+  return IntegerAttr::get(type, llvm::APInt::getAllOnesValue(width));
+}
+
+/// Creates an OpenMP reduction declaration at the builder's current insertion
+/// point and registers it in `symbolTable` under the name "__scf_reduction".
+/// The symbol table may rename it to avoid collisions.
+///
+/// The declaration has a constant initializer yielding the neutral value
+/// `initValue`. Its combiner is the region of `reduce`, moved into the
+/// declaration with its `scf.reduce.return` terminator replaced by
+/// `omp.yield`. The builder's insertion point is restored on return.
+static omp::ReductionDeclareOp createDecl(PatternRewriter &builder,
+                                          SymbolTable &symbolTable,
+                                          scf::ReduceOp reduce,
+                                          Attribute initValue) {
+  OpBuilder::InsertionGuard guard(builder);
+  Type type = reduce.operand().getType();
+  auto decl = builder.create<omp::ReductionDeclareOp>(reduce.getLoc(),
+                                                      "__scf_reduction", type);
+  symbolTable.insert(decl);
+
+  // Initializer: yields the neutral element of the reduction.
+  builder.createBlock(&decl.initializerRegion(), decl.initializerRegion().end(),
+                      {type});
+  builder.setInsertionPointToEnd(&decl.initializerRegion().back());
+  Value init =
+      builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue);
+  builder.create<omp::YieldOp>(reduce.getLoc(), init);
+
+  // Combiner: reuse the scf.reduce body, terminated by omp.yield instead.
+  Operation *terminator = &reduce.getRegion().front().back();
+  assert(isa<scf::ReduceReturnOp>(terminator) &&
+         "expected reduce op to be terminated by reduce return");
+  builder.setInsertionPoint(terminator);
+  builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
+                                           terminator->getOperands());
+  builder.inlineRegionBefore(reduce.getRegion(), decl.reductionRegion(),
+                             decl.reductionRegion().end());
+  return decl;
+}
+
+/// Populates the atomic reduction region of `decl` with a combiner built on
+/// `llvm.atomicrmw` of kind `atomicKind`. The region takes two pointers to the
+/// reduction type. The value behind the second pointer is loaded and
+/// atomically combined, with monotonic ordering, into the location behind the
+/// first pointer. Returns `decl`; the builder's insertion point is restored on
+/// return.
+static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
+                                            LLVM::AtomicBinOp atomicKind,
+                                            omp::ReductionDeclareOp decl,
+                                            scf::ReduceOp reduce) {
+  OpBuilder::InsertionGuard guard(builder);
+  Location loc = reduce.getLoc();
+  Type elementType = reduce.operand().getType();
+  Type pointerType = LLVM::LLVMPointerType::get(elementType);
+
+  Region &atomicRegion = decl.atomicReductionRegion();
+  builder.createBlock(&atomicRegion, atomicRegion.end(),
+                      {pointerType, pointerType});
+  Block *body = &atomicRegion.back();
+  builder.setInsertionPointToEnd(body);
+
+  Value accumulator = body->getArgument(0);
+  Value partial = body->getArgument(1);
+  Value partialValue = builder.create<LLVM::LoadOp>(loc, partial);
+  builder.create<LLVM::AtomicRMWOp>(loc, elementType, atomicKind, accumulator,
+                                    partialValue,
+                                    LLVM::AtomicOrdering::monotonic);
+  builder.create<omp::YieldOp>(loc, ArrayRef<Value>());
+  return decl;
+}
+
+/// Creates an OpenMP reduction declaration that corresponds to the given SCF
+/// reduction and returns it. Common reductions are recognized in order to
+/// identify the neutral value, which the OpenMP declaration requires. When the
+/// combiner can be expressed with `llvm.atomicrmw`, an atomic combiner is also
+/// emitted.
+///
+/// The declaration is placed in the closest symbol-table ancestor of `reduce`,
+/// right before the ancestor of `reduce` that is directly nested in that
+/// symbol table. On success, the body of `reduce` is moved into the
+/// declaration. If the reduction is not recognized, returns null without
+/// creating any operation.
+static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
+                                                scf::ReduceOp reduce) {
+  Operation *container = SymbolTable::getNearestSymbolTable(reduce);
+  SymbolTable symbolTable(container);
+
+  // Insert reduction declarations in the symbol-table ancestor, before the
+  // ancestor of the reduce op that is immediately nested in it.
+  Operation *insertionPoint = reduce;
+  while (insertionPoint->getParentOp() != container)
+    insertionPoint = insertionPoint->getParentOp();
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPoint(insertionPoint);
+
+  assert(llvm::hasSingleElement(reduce.getRegion()) &&
+         "expected reduction region to have a single element");
+
+  // Match simple binary reductions that can be expressed with atomicrmw.
+  Type type = reduce.operand().getType();
+  Block &reduction = reduce.getRegion().front();
+  if (matchSimpleReduction<AddFOp, LLVM::FAddOp>(reduction)) {
+    omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
+                                              builder.getFloatAttr(type, 0.0));
+    return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce);
+  }
+  if (matchSimpleReduction<AddIOp, LLVM::AddOp>(reduction)) {
+    omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
+                                              builder.getIntegerAttr(type, 0));
+    return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce);
+  }
+  if (matchSimpleReduction<OrOp, LLVM::OrOp>(reduction)) {
+    omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
+                                              builder.getIntegerAttr(type, 0));
+    return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce);
+  }
+  if (matchSimpleReduction<XOrOp, LLVM::XOrOp>(reduction)) {
+    omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
+                                              builder.getIntegerAttr(type, 0));
+    return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce);
+  }
+  if (matchSimpleReduction<AndOp, LLVM::AndOp>(reduction)) {
+    omp::ReductionDeclareOp decl = createDecl(
+        builder, symbolTable, reduce,
+        builder.getIntegerAttr(
+            type, llvm::APInt::getAllOnesValue(type.getIntOrFloatBitWidth())));
+    return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce);
+  }
+
+  // Match simple binary reductions that cannot be expressed with atomicrmw.
+  // TODO: add atomic region using cmpxchg (which needs atomic load to be
+  // available as an op).
+  if (matchSimpleReduction<MulFOp, LLVM::FMulOp>(reduction)) {
+    return createDecl(builder, symbolTable, reduce,
+                      builder.getFloatAttr(type, 1.0));
+  }
+  if (matchSimpleReduction<MulIOp, LLVM::MulOp>(reduction)) {
+    return createDecl(builder, symbolTable, reduce,
+                      builder.getIntegerAttr(type, 1));
+  }
+
+  // Match select-based min/max reductions. The neutral element of a min
+  // reduction is the largest value of the type and vice versa, hence the
+  // `!isMin` passed to the helpers below.
+  bool isMin;
+  if (matchSelectReduction<CmpFOp, SelectOp>(
+          reduction, {CmpFPredicate::OLT, CmpFPredicate::OLE},
+          {CmpFPredicate::OGT, CmpFPredicate::OGE}, isMin) ||
+      matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
+          reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
+          {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
+    return createDecl(builder, symbolTable, reduce,
+                      minMaxValueForFloat(type, !isMin));
+  }
+  if (matchSelectReduction<CmpIOp, SelectOp>(
+          reduction, {CmpIPredicate::slt, CmpIPredicate::sle},
+          {CmpIPredicate::sgt, CmpIPredicate::sge}, isMin) ||
+      matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
+          reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
+          {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
+    omp::ReductionDeclareOp decl = createDecl(
+        builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin));
+    return addAtomicRMW(builder,
+                        isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
+                        decl, reduce);
+  }
+  if (matchSelectReduction<CmpIOp, SelectOp>(
+          reduction, {CmpIPredicate::ult, CmpIPredicate::ule},
+          {CmpIPredicate::ugt, CmpIPredicate::uge}, isMin) ||
+      matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
+          reduction, {LLVM::ICmpPredicate::ult, LLVM::ICmpPredicate::ule},
+          {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
+    omp::ReductionDeclareOp decl = createDecl(
+        builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin));
+    return addAtomicRMW(
+        builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
+        decl, reduce);
+  }
+
+  return nullptr;
+}
+
 namespace {
 
-/// Converts SCF parallel operation into an OpenMP workshare loop construct.
+/// Converts an SCF parallel operation, including its reductions, into an
+/// OpenMP parallel region containing a workshare loop construct.
 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
   using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
                                 PatternRewriter &rewriter) const override {
-    // TODO: add support for reductions when OpenMP loops have them.
-    if (parallelOp.getNumResults() != 0)
-      return rewriter.notifyMatchFailure(
-          parallelOp,
-          "OpenMP dialect does not yet support loops with reductions");
-
     // Replace SCF yield with OpenMP yield.
     {
       OpBuilder::InsertionGuard guard(rewriter);
@@ -43,47 +328,118 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
           parallelOp.getBody()->getTerminator(), ValueRange());
     }
 
-    // Replace the loop.
-    auto omp = rewriter.create<omp::ParallelOp>(parallelOp.getLoc());
-    Block *block = rewriter.createBlock(&omp.getRegion());
-    rewriter.setInsertionPointToStart(block);
-    auto loop = rewriter.create<omp::WsLoopOp>(
-        parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(),
-        parallelOp.step());
-    rewriter.inlineRegionBefore(parallelOp.region(), loop.region(),
-                                loop.region().begin());
-    rewriter.create<omp::TerminatorOp>(parallelOp.getLoc());
-
-    rewriter.eraseOp(parallelOp);
+    // Declare reductions.
+    // TODO: consider checking if there is already a compatible reduction
+    // declaration and use it instead of redeclaring.
+    SmallVector<Attribute> reductionDeclSymbols;
+    for (auto reduce : parallelOp.getOps<scf::ReduceOp>()) {
+      omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce);
+      // NOTE(review): IR has already been modified at this point (the yield
+      // replacement above, and declarations created for earlier reductions).
+      // Returning failure here relies on the conversion driver rolling those
+      // changes back; confirm, or match all reductions before mutating.
+      if (!decl)
+        return failure();
+      reductionDeclSymbols.push_back(
+          SymbolRefAttr::get(rewriter.getContext(), decl.sym_name()));
+    }
+
+    // Allocate reduction variables. Make sure we don't overflow the stack
+    // with local `alloca`s by saving and restoring the stack pointer.
+    Location loc = parallelOp.getLoc();
+    Value one = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1));
+    SmallVector<Value> reductionVariables;
+    reductionVariables.reserve(parallelOp.getNumReductions());
+    Value token = rewriter.create<LLVM::StackSaveOp>(
+        loc, LLVM::LLVMPointerType::get(rewriter.getIntegerType(8)));
+    for (Value init : parallelOp.initVals()) {
+      assert((LLVM::isCompatibleType(init.getType()) ||
+              init.getType().isa<LLVM::PointerElementTypeInterface>()) &&
+             "cannot create a reduction variable if the type is not an LLVM "
+             "pointer element");
+      Value storage = rewriter.create<LLVM::AllocaOp>(
+          loc, LLVM::LLVMPointerType::get(init.getType()), one, 0);
+      rewriter.create<LLVM::StoreOp>(loc, init, storage);
+      reductionVariables.push_back(storage);
+    }
+
+    // Replace the reduction operations contained in this loop. Must be done
+    // here rather than in a separate pattern to have access to the list of
+    // reduction variables.
+    for (auto pair :
+         llvm::zip(parallelOp.getOps<scf::ReduceOp>(), reductionVariables)) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      scf::ReduceOp reduceOp = std::get<0>(pair);
+      rewriter.setInsertionPoint(reduceOp);
+      rewriter.replaceOpWithNewOp<omp::ReductionOp>(
+          reduceOp, reduceOp.operand(), std::get<1>(pair));
+    }
+
+    // Create the parallel wrapper.
+    auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.createBlock(&ompParallel.region());
+
+      // Replace SCF yield with OpenMP yield.
+      // NOTE(review): the terminator was already replaced with omp.yield at
+      // the start of this function; this second replacement looks redundant.
+      // Confirm and remove one of the two.
+      {
+        OpBuilder::InsertionGuard innerGuard(rewriter);
+        rewriter.setInsertionPointToEnd(parallelOp.getBody());
+        assert(llvm::hasSingleElement(parallelOp.region()) &&
+               "expected scf.parallel to have one block");
+        rewriter.replaceOpWithNewOp<omp::YieldOp>(
+            parallelOp.getBody()->getTerminator(), ValueRange());
+      }
+
+      // Replace the loop.
+      auto loop = rewriter.create<omp::WsLoopOp>(
+          parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(),
+          parallelOp.step());
+      rewriter.create<omp::TerminatorOp>(loc);
+
+      rewriter.inlineRegionBefore(parallelOp.region(), loop.region(),
+                                  loop.region().begin());
+      // Attach the reduction declarations and their variables to the loop.
+      if (!reductionVariables.empty()) {
+        loop.reductionsAttr(
+            ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
+        loop.reduction_varsMutable().append(reductionVariables);
+      }
+    }
+
+    // Load loop results.
+    SmallVector<Value> results;
+    results.reserve(reductionVariables.size());
+    for (Value variable : reductionVariables) {
+      Value res = rewriter.create<LLVM::LoadOp>(loc, variable);
+      results.push_back(res);
+    }
+    rewriter.replaceOp(parallelOp, results);
+
+    // Restore the stack pointer only after the results have been loaded from
+    // the stack-allocated reduction variables.
+    rewriter.create<LLVM::StackRestoreOp>(loc, token);
     return success();
   }
 };
 
-/// Applies the conversion patterns in the given function.
+/// Applies the conversion patterns in the given module.
-static LogicalResult applyPatterns(FuncOp func) {
-  ConversionTarget target(*func.getContext());
-  target.addIllegalOp<scf::ParallelOp>();
-  target.addDynamicallyLegalOp<scf::YieldOp>(
-      [](scf::YieldOp op) { return !isa<scf::ParallelOp>(op->getParentOp()); });
-  target.addLegalDialect<omp::OpenMPDialect>();
-
-  RewritePatternSet patterns(func.getContext());
-  patterns.add<ParallelOpLowering>(func.getContext());
+static LogicalResult applyPatterns(ModuleOp module) {
+  ConversionTarget target(*module.getContext());
+  target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
+  target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect>();
+
+  RewritePatternSet patterns(module.getContext());
+  patterns.add<ParallelOpLowering>(module.getContext());
   FrozenRewritePatternSet frozen(std::move(patterns));
-  return applyPartialConversion(func, target, frozen);
+  return applyPartialConversion(module, target, frozen);
 }
 
 /// A pass converting SCF operations to OpenMP operations.
 struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
   /// Pass entry point.
-  void runOnFunction() override {
-    if (failed(applyPatterns(getFunction())))
+  void runOnOperation() override {
+    if (failed(applyPatterns(getOperation())))
       signalPassFailure();
   }
 };
 
 } // end namespace
 
+/// Creates the module pass converting scf.parallel operations, including
+/// their reductions, into OpenMP operations.
-std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertSCFToOpenMPPass() {
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToOpenMPPass() {
   return std::make_unique<SCFToOpenMPPass>();
 }

diff  --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
new file mode 100644
index 0000000000000..bbc7d61a33a4b
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
@@ -0,0 +1,194 @@
+// RUN: mlir-opt -convert-scf-to-openmp -split-input-file %s | FileCheck %s
+
+// Float addition reduction: lowered to a reduction declaration with a zero
+// initializer, the original combiner and an atomicrmw fadd atomic region. The
+// reduction variable is allocated on the stack between stacksave and
+// stackrestore.
+// CHECK: omp.reduction.declare @[[$REDF:.*]] : f32
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(0.000000e+00 : f32)
+// CHECK: omp.yield(%[[INIT]] : f32)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+// CHECK: %[[RES:.*]] = addf %[[ARG0]], %[[ARG1]]
+// CHECK: omp.yield(%[[RES]] : f32)
+
+// CHECK: atomic
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr<f32>, %[[ARG1:.*]]: !llvm.ptr<f32>):
+// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]]
+// CHECK: llvm.atomicrmw fadd %[[ARG0]], %[[RHS]] monotonic
+
+// CHECK-LABEL: @reduction1
+func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
+                 %arg3 : index, %arg4 : index) {
+  // CHECK: %[[CST:.*]] = constant 0.0
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1
+  // CHECK: llvm.intr.stacksave
+  // CHECK: %[[BUF:.*]] = llvm.alloca %[[ONE]] x f32
+  // CHECK: llvm.store %[[CST]], %[[BUF]]
+  %step = constant 1 : index
+  %zero = constant 0.0 : f32
+  // CHECK: omp.parallel
+  // CHECK: omp.wsloop
+  // CHECK-SAME: reduction(@[[$REDF]] -> %[[BUF]]
+  scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                            step (%arg4, %step) init (%zero) -> (f32) {
+    // CHECK: %[[CST_INNER:.*]] = constant 1.0
+    %one = constant 1.0 : f32
+    // CHECK: omp.reduction %[[CST_INNER]], %[[BUF]]
+    scf.reduce(%one) : f32 {
+    ^bb0(%lhs : f32, %rhs: f32):
+      %res = addf %lhs, %rhs : f32
+      scf.reduce.return %res : f32
+    }
+    // CHECK: omp.yield
+  }
+  // CHECK: omp.terminator
+  // CHECK: llvm.load %[[BUF]]
+  // CHECK: llvm.intr.stackrestore
+  return
+}
+
+// -----
+
+// Float multiplication reduction: initialized with 1.0 and without an atomic
+// region, since atomicrmw cannot express it.
+// Only check the declaration here, the rest is same as above.
+// CHECK: omp.reduction.declare @{{.*}} : f32
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(1.000000e+00 : f32)
+// CHECK: omp.yield(%[[INIT]] : f32)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+// CHECK: %[[RES:.*]] = mulf %[[ARG0]], %[[ARG1]]
+// CHECK: omp.yield(%[[RES]] : f32)
+
+// CHECK-NOT: atomic
+
+// CHECK-LABEL: @reduction2
+func @reduction2(%arg0 : index, %arg1 : index, %arg2 : index,
+                 %arg3 : index, %arg4 : index) {
+  %step = constant 1 : index
+  %zero = constant 0.0 : f32
+  scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                            step (%arg4, %step) init (%zero) -> (f32) {
+    %one = constant 1.0 : f32
+    scf.reduce(%one) : f32 {
+    ^bb0(%lhs : f32, %rhs: f32):
+      %res = mulf %lhs, %rhs : f32
+      scf.reduce.return %res : f32
+    }
+  }
+  return
+}
+
+// -----
+
+// Select-based float maximum (oge, select operands in comparison order):
+// initialized with the most negative finite f32, no atomic region.
+// Only check the declaration here, the rest is same as above.
+// CHECK: omp.reduction.declare @{{.*}} : f32
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(-3.4
+// CHECK: omp.yield(%[[INIT]] : f32)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+// CHECK: %[[CMP:.*]] = cmpf oge, %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = select %[[CMP]], %[[ARG0]], %[[ARG1]]
+// CHECK: omp.yield(%[[RES]] : f32)
+
+// CHECK-NOT: atomic
+
+// CHECK-LABEL: @reduction3
+func @reduction3(%arg0 : index, %arg1 : index, %arg2 : index,
+                 %arg3 : index, %arg4 : index) {
+  %step = constant 1 : index
+  %zero = constant 0.0 : f32
+  scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                            step (%arg4, %step) init (%zero) -> (f32) {
+    %one = constant 1.0 : f32
+    scf.reduce(%one) : f32 {
+    ^bb0(%lhs : f32, %rhs: f32):
+      %cmp = cmpf oge, %lhs, %rhs : f32
+      %res = select %cmp, %lhs, %rhs : f32
+      scf.reduce.return %res : f32
+    }
+  }
+  return
+}
+
+// -----
+
+// Two reductions in one loop: a float maximum (no atomic region) and a signed
+// integer maximum written as slt with swapped select operands (atomic max).
+// CHECK: omp.reduction.declare @[[$REDF1:.*]] : f32
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(-3.4
+// CHECK: omp.yield(%[[INIT]] : f32)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+// CHECK: %[[CMP:.*]] = cmpf oge, %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = select %[[CMP]], %[[ARG0]], %[[ARG1]]
+// CHECK: omp.yield(%[[RES]] : f32)
+
+// CHECK-NOT: atomic
+
+// CHECK: omp.reduction.declare @[[$REDF2:.*]] : i64
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant
+// CHECK: omp.yield(%[[INIT]] : i64)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64)
+// CHECK: %[[CMP:.*]] = cmpi slt, %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = select %[[CMP]], %[[ARG1]], %[[ARG0]]
+// CHECK: omp.yield(%[[RES]] : i64)
+
+// CHECK: atomic
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr<i64>, %[[ARG1:.*]]: !llvm.ptr<i64>):
+// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]]
+// CHECK: llvm.atomicrmw max %[[ARG0]], %[[RHS]] monotonic
+
+// CHECK-LABEL: @reduction4
+func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
+                 %arg3 : index, %arg4 : index) -> (f32, i64) {
+  %step = constant 1 : index
+  // CHECK: %[[ZERO:.*]] = constant 0.0
+  %zero = constant 0.0 : f32
+  // CHECK: %[[IONE:.*]] = constant 1
+  %ione = constant 1 : i64
+  // CHECK: %[[BUF1:.*]] = llvm.alloca %{{.*}} x f32
+  // CHECK: llvm.store %[[ZERO]], %[[BUF1]]
+  // CHECK: %[[BUF2:.*]] = llvm.alloca %{{.*}} x i64
+  // CHECK: llvm.store %[[IONE]], %[[BUF2]]
+
+  // CHECK: omp.parallel
+  // CHECK: omp.wsloop
+  // CHECK-SAME: reduction(@[[$REDF1]] -> %[[BUF1]]
+  // CHECK-SAME:           @[[$REDF2]] -> %[[BUF2]]
+  %res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                        step (%arg4, %step) init (%zero, %ione) -> (f32, i64) {
+    %one = constant 1.0 : f32
+    // CHECK: omp.reduction %{{.*}}, %[[BUF1]]
+    scf.reduce(%one) : f32 {
+    ^bb0(%lhs : f32, %rhs: f32):
+      %cmp = cmpf oge, %lhs, %rhs : f32
+      %res = select %cmp, %lhs, %rhs : f32
+      scf.reduce.return %res : f32
+    }
+    // CHECK: fptosi
+    %1 = fptosi %one : f32 to i64
+    // CHECK: omp.reduction %{{.*}}, %[[BUF2]]
+    scf.reduce(%1) : i64 {
+    ^bb1(%lhs: i64, %rhs: i64):
+      %cmp = cmpi slt, %lhs, %rhs : i64
+      %res = select %cmp, %rhs, %lhs : i64
+      scf.reduce.return %res : i64
+    }
+    // CHECK: omp.yield
+  }
+  // CHECK: omp.terminator
+  // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]]
+  // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]]
+  // CHECK: return %[[RES1]], %[[RES2]]
+  return %res#0, %res#1 : f32, i64
+}
+
+// -----
+
+// Unsigned integer maximum expressed with LLVM dialect operations: "ugt" with
+// select operands in comparison order must be recognized as a maximum. It is
+// initialized with zero and combined atomically with umax.
+// CHECK: omp.reduction.declare @[[$REDU:.*]] : i64
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(0 : i64)
+// CHECK: omp.yield(%[[INIT]] : i64)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64)
+// CHECK: %[[CMP:.*]] = llvm.icmp "ugt" %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = llvm.select %[[CMP]], %[[ARG0]], %[[ARG1]]
+// CHECK: omp.yield(%[[RES]] : i64)
+
+// CHECK: atomic
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr<i64>, %[[ARG1:.*]]: !llvm.ptr<i64>):
+// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]]
+// CHECK: llvm.atomicrmw umax %[[ARG0]], %[[RHS]] monotonic
+
+// CHECK-LABEL: @reduction5
+func @reduction5(%arg0 : index, %arg1 : index, %arg2 : index,
+                 %arg3 : index, %arg4 : index) -> i64 {
+  %step = constant 1 : index
+  %zero = constant 0 : i64
+  // CHECK: omp.wsloop
+  // CHECK-SAME: reduction(@[[$REDU]] -> %{{.*}}
+  %res = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                      step (%arg4, %step) init (%zero) -> (i64) {
+    %one = constant 1 : i64
+    scf.reduce(%one) : i64 {
+    ^bb0(%lhs : i64, %rhs: i64):
+      %cmp = llvm.icmp "ugt" %lhs, %rhs : i64
+      %max = llvm.select %cmp, %lhs, %rhs : i1, i64
+      scf.reduce.return %max : i64
+    }
+  }
+  return %res : i64
+}

diff  --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index 44059a27b3295..1507f927b9f00 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -27,7 +27,6 @@ func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
     scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
       // CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> ()
       "test.payload"(%i, %j) : (index, index) -> ()
-      // CHECK:   omp.yield
       // CHECK: }
     }
     // CHECK:   omp.yield
@@ -38,6 +37,7 @@ func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
   return
 }
 
+// CHECK-LABEL: @adjacent_loops
 func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
                      %arg3: index, %arg4: index, %arg5: index) {
   // CHECK: omp.parallel {

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 225703d0f3a3e..f486701c1dfcd 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4316,9 +4316,11 @@ cc_library(
     deps = [
         ":ConversionPassIncGen",
         ":IR",
+        ":LLVMDialect",
         ":OpenMPDialect",
         ":Pass",
         ":SCFDialect",
+        ":StandardOps",
         ":Support",
         ":Transforms",
     ],


        


More information about the Mlir-commits mailing list