[Mlir-commits] [mlir] 66900b3 - [mlir] Use dialect interfaces to translate OpenMP dialect to LLVM IR

Alex Zinenko llvmlistbot at llvm.org
Fri Feb 12 09:37:54 PST 2021


Author: Alex Zinenko
Date: 2021-02-12T18:37:47+01:00
New Revision: 66900b3eae96b295cc7eb9468680085028f35daa

URL: https://github.com/llvm/llvm-project/commit/66900b3eae96b295cc7eb9468680085028f35daa
DIFF: https://github.com/llvm/llvm-project/commit/66900b3eae96b295cc7eb9468680085028f35daa.diff

LOG: [mlir] Use dialect interfaces to translate OpenMP dialect to LLVM IR

Migrate the translation of the OpenMP dialect operations to LLVM IR to the new
dialect-based mechanism.

Depends On D96503

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96504

Added: 
    mlir/include/mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Modified: 
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Target/CMakeLists.txt
    mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
    mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h
new file mode 100644
index 000000000000..7d9eeea9462e
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h
@@ -0,0 +1,37 @@
+//===- OpenMPToLLVMIRTranslation.h - OpenMP Dialect to LLVM IR --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the dialect interface for translating the OpenMP dialect
+// to LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_OPENMP_OPENMPTOLLVMIRTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_OPENMP_OPENMPTOLLVMIRTRANSLATION_H
+
+#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
+
+namespace mlir {
+
+/// Implementation of the dialect interface that converts operations belonging
+/// to the OpenMP dialect to LLVM IR. It is attached to the OpenMP dialect
+/// through the DialectRegistry so that ModuleTranslation can dispatch OpenMP
+/// operations to it.
+class OpenMPDialectLLVMIRTranslationInterface
+    : public LLVMTranslationDialectInterface {
+public:
+  // Reuse the base-class constructors.
+  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Translates the given operation to LLVM IR using the provided IR builder
+  /// and saving the state in `moduleTranslation`. Returns failure if the
+  /// operation cannot be translated.
+  LogicalResult
+  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const final;
+};
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_OPENMP_OPENMPTOLLVMIRTRANSLATION_H

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index b15fcc304a14..03b7f5336461 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -22,6 +22,7 @@
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
 #include "mlir/Target/LLVMIR/TypeTranslation.h"
 
+#include "llvm/ADT/SetVector.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
@@ -160,6 +161,29 @@ class ModuleTranslation {
     return globalsMapping.lookup(op);
   }
 
+  /// Provides the OpenMP IR builder for the LLVM IR module under construction.
+  /// The builder is created and initialized on the first call; every later
+  /// call returns that same instance (held in `ompBuilder`).
+  llvm::OpenMPIRBuilder *getOpenMPBuilder() {
+    if (ompBuilder)
+      return ompBuilder.get();
+
+    // First request: bind a fresh builder to the module and initialize it.
+    ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule);
+    ompBuilder->initialize();
+    return ompBuilder.get();
+  }
+
+  /// Translates the given location.
+  const llvm::DILocation *translateLoc(Location loc, llvm::DILocalScope *scope);
+
+  /// Translates the contents of the given block to LLVM IR using this
+  /// translator. The LLVM IR basic block corresponding to the given block is
+  /// expected to exist in the mapping of this translator. Uses `builder` to
+  /// translate the IR, leaving it at the end of the block. If `ignoreArguments`
+  /// is set, does not produce PHI nodes for the block arguments. Otherwise, the
+  /// PHI nodes are constructed for block arguments but are _not_ connected to
+  /// the predecessors that may not exist yet.
+  LogicalResult convertBlock(Block &bb, bool ignoreArguments,
+                             llvm::IRBuilder<> &builder);
+
 protected:
   /// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
   /// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an
@@ -170,19 +194,6 @@ class ModuleTranslation {
 
   virtual LogicalResult convertOperation(Operation &op,
                                          llvm::IRBuilder<> &builder);
-  virtual LogicalResult convertOmpOperation(Operation &op,
-                                            llvm::IRBuilder<> &builder);
-  virtual LogicalResult convertOmpParallel(Operation &op,
-                                           llvm::IRBuilder<> &builder);
-  virtual LogicalResult convertOmpMaster(Operation &op,
-                                         llvm::IRBuilder<> &builder);
-  void convertOmpOpRegions(Region &region, StringRef blockName,
-                           llvm::BasicBlock &sourceBlock,
-                           llvm::BasicBlock &continuationBlock,
-                           llvm::IRBuilder<> &builder,
-                           LogicalResult &bodyGenStatus);
-  virtual LogicalResult convertOmpWsLoop(Operation &opInst,
-                                         llvm::IRBuilder<> &builder);
 
   static std::unique_ptr<llvm::Module>
   prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
@@ -196,8 +207,6 @@ class ModuleTranslation {
   LogicalResult convertFunctions();
   LogicalResult convertGlobals();
   LogicalResult convertOneFunction(LLVMFuncOp func);
-  LogicalResult convertBlock(Block &bb, bool ignoreArguments,
-                             llvm::IRBuilder<> &builder);
 
   /// Original and translated module.
   Operation *mlirModule;
@@ -232,6 +241,16 @@ class ModuleTranslation {
   DenseMap<Operation *, llvm::Instruction *> branchMapping;
 };
 
+namespace detail {
+/// For all blocks in the region that were converted to LLVM IR using the given
+/// ModuleTranslation, connect the PHI nodes of the corresponding LLVM IR blocks
+/// to the results of preceding blocks. The entry block of the region is
+/// skipped because it cannot be branched to.
+void connectPHINodes(Region &region, const ModuleTranslation &state);
+
+/// Get a topologically sorted list of blocks of the given region. The entry
+/// block comes first, and every block of the region is expected to appear in
+/// the result (the implementation asserts this).
+llvm::SetVector<Block *> getTopologicallySortedBlocks(Region &region);
+} // namespace detail
+
 } // namespace LLVM
 } // namespace mlir
 

diff  --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt
index 72555ac7876b..e951ffade6aa 100644
--- a/mlir/lib/Target/CMakeLists.txt
+++ b/mlir/lib/Target/CMakeLists.txt
@@ -57,6 +57,7 @@ add_mlir_translation_library(MLIRTargetLLVMIR
 
   LINK_LIBS PUBLIC
   MLIRLLVMToLLVMIRTranslation
+  MLIROpenMPToLLVMIRTranslation
   MLIRTargetLLVMIRModuleTranslation
   )
 

diff  --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
index bf8e248804dd..6b30748cc79b 100644
--- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "mlir/Translation.h"
 
@@ -70,6 +71,8 @@ void registerToLLVMIRTranslation() {
       },
       [](DialectRegistry &registry) {
         registry.insert<omp::OpenMPDialect>();
+        registry.addDialectInterface<omp::OpenMPDialect,
+                                     OpenMPDialectLLVMIRTranslationInterface>();
         registerLLVMDialectTranslation(registry);
       });
 }

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index 39d31dc9b5e9..9ab260874b7a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(LLVMIR)
+add_subdirectory(OpenMP)

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt
new file mode 100644
index 000000000000..6cc36f47f075
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_translation_library(MLIROpenMPToLLVMIRTranslation
+  OpenMPToLLVMIRTranslation.cpp
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMIR
+  MLIROpenMP
+  MLIRSupport
+  MLIRTargetLLVMIRModuleTranslation
+  )

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
new file mode 100644
index 000000000000..361bdae7c97f
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -0,0 +1,309 @@
+//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR OpenMP dialect and LLVM
+// IR.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
+#include "llvm/IR/IRBuilder.h"
+
+using namespace mlir;
+
+/// Converts the given region that appears within an OpenMP dialect operation to
+/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
+/// region, and a branch from any block with a successor-less OpenMP terminator
+/// (`omp.terminator` or `omp.yield`) to `continuationBlock`.
+///
+/// One LLVM basic block named `blockName` is created per region block, in the
+/// function that contains the builder's current insertion block. The
+/// terminator of `sourceBlock` must have exactly one successor, namely
+/// `continuationBlock` (checked by assertions); it is retargeted to the
+/// converted entry block. If a block fails to convert, `bodyGenStatus` is set
+/// to failure and the function returns early; otherwise `bodyGenStatus` is left
+/// unchanged.
+static void convertOmpOpRegions(Region &region, StringRef blockName,
+                                llvm::BasicBlock &sourceBlock,
+                                llvm::BasicBlock &continuationBlock,
+                                llvm::IRBuilderBase &builder,
+                                LLVM::ModuleTranslation &moduleTranslation,
+                                LogicalResult &bodyGenStatus) {
+  llvm::LLVMContext &llvmContext = builder.getContext();
+  // Create and map all LLVM blocks up front so that branches between region
+  // blocks can be resolved regardless of conversion order.
+  for (Block &bb : region) {
+    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
+        llvmContext, blockName, builder.GetInsertBlock()->getParent());
+    moduleTranslation.mapBlock(&bb, llvmBB);
+  }
+
+  llvm::Instruction *sourceTerminator = sourceBlock.getTerminator();
+
+  // Convert blocks one by one in topological order to ensure
+  // defs are converted before uses.
+  llvm::SetVector<Block *> blocks =
+      LLVM::detail::getTopologicallySortedBlocks(region);
+  for (Block *bb : blocks) {
+    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
+    // Retarget the branch of the entry block to the entry block of the
+    // converted region (regions are single-entry).
+    if (bb->isEntryBlock()) {
+      assert(sourceTerminator->getNumSuccessors() == 1 &&
+             "provided entry block has multiple successors");
+      assert(sourceTerminator->getSuccessor(0) == &continuationBlock &&
+             "ContinuationBlock is not the successor of the entry block");
+      sourceTerminator->setSuccessor(0, llvmBB);
+    }
+
+    // Restore the builder's insertion point once this block is done.
+    llvm::IRBuilder<>::InsertPointGuard guard(builder);
+    // The entry block is converted with `ignoreArguments` set, so no PHI nodes
+    // are created for its arguments; values for them must already be mapped
+    // by the caller (convertOmpWsLoop maps the induction variable this way).
+    if (failed(moduleTranslation.convertBlock(
+            *bb, bb->isEntryBlock(),
+            // TODO: this downcast should be removed after all of
+            // ModuleTranslation migrated to using IRBuilderBase &; the cast is
+            // safe in practice because the builder always comes from
+            // ModuleTranslation itself that only uses this subclass.
+            static_cast<llvm::IRBuilder<> &>(builder)))) {
+      bodyGenStatus = failure();
+      return;
+    }
+
+    // Special handling for `omp.yield` and `omp.terminator` (we may have more
+    // than one): they return the control to the parent OpenMP dialect operation
+    // so replace them with the branch to the continuation block. We handle this
+    // here to avoid relying on inter-function communication through the
+    // ModuleTranslation class to set up the correct insertion point. This is
+    // also consistent with MLIR's idiom of handling special region terminators
+    // in the same code that handles the region-owning operation.
+    if (isa<omp::TerminatorOp, omp::YieldOp>(bb->getTerminator()))
+      builder.CreateBr(&continuationBlock);
+  }
+  // Finally, after all blocks have been traversed and values mapped,
+  // connect the PHI nodes to the results of preceding blocks.
+  LLVM::detail::connectPHINodes(region, moduleTranslation);
+}
+
+/// Converts the OpenMP parallel operation to LLVM IR.
+///
+/// The region is translated by `bodyGenCB`, which OpenMPIRBuilder invokes
+/// while `createParallel` builds the construct. Errors from the region
+/// conversion are therefore only known after that call, and are returned
+/// from this function.
+static LogicalResult
+convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) {
+  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+  auto parallelOp = cast<omp::ParallelOp>(opInst);
+  // TODO: support error propagation in OpenMPIRBuilder and use it instead of
+  // relying on captured variables.
+  LogicalResult bodyGenStatus = success();
+
+  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
+                       llvm::BasicBlock &continuationBlock) {
+    // ParallelOp has only one region associated with it.
+    auto &region = parallelOp.getRegion();
+    convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(),
+                        continuationBlock, builder, moduleTranslation,
+                        bodyGenStatus);
+  };
+
+  // TODO: Perform appropriate actions according to the data-sharing
+  // attribute (shared, private, firstprivate, ...) of variables.
+  // Currently defaults to shared.
+  auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
+                    llvm::Value &, llvm::Value &vPtr,
+                    llvm::Value *&replacementValue) -> InsertPointTy {
+    replacementValue = &vPtr;
+
+    return codeGenIP;
+  };
+
+  // TODO: Perform finalization actions for variables. This has to be
+  // called for variables which have destructors/finalizers.
+  auto finiCB = [&](InsertPointTy codeGenIP) {};
+
+  llvm::Value *ifCond = nullptr;
+  if (auto ifExprVar = parallelOp.if_expr_var())
+    ifCond = moduleTranslation.lookupValue(ifExprVar);
+  llvm::Value *numThreads = nullptr;
+  if (auto numThreadsVar = parallelOp.num_threads_var())
+    numThreads = moduleTranslation.lookupValue(numThreadsVar);
+  llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default;
+  if (auto bind = parallelOp.proc_bind_val())
+    pbKind = llvm::omp::getProcBindKind(bind.getValue());
+  // TODO: Is the Parallel construct cancellable?
+  bool isCancellable = false;
+  // TODO: Determine the actual alloca insertion point, e.g., the function
+  // entry or the alloca insertion point as provided by the body callback
+  // above.
+  llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(
+      builder.saveIP(), builder.getCurrentDebugLocation());
+  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel(
+      ompLoc, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads, pbKind,
+      isCancellable));
+  // The region was converted inside createParallel (via bodyGenCB), so its
+  // status can only be checked now; checking before the call always succeeds.
+  return bodyGenStatus;
+}
+
+/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
+///
+/// The region is translated by `bodyGenCB`, which is invoked from
+/// `createMaster`; any failure recorded there is returned once the construct
+/// has been built.
+static LogicalResult
+convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
+                 LLVM::ModuleTranslation &moduleTranslation) {
+  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+  // TODO: support error propagation in OpenMPIRBuilder and use it instead of
+  // relying on captured variables.
+  LogicalResult bodyGenStatus = success();
+
+  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
+                       llvm::BasicBlock &continuationBlock) {
+    // MasterOp has only one region associated with it.
+    auto &region = cast<omp::MasterOp>(opInst).getRegion();
+    convertOmpOpRegions(region, "omp.master.region", *codeGenIP.getBlock(),
+                        continuationBlock, builder, moduleTranslation,
+                        bodyGenStatus);
+  };
+
+  // TODO: Perform finalization actions for variables. This has to be
+  // called for variables which have destructors/finalizers.
+  auto finiCB = [&](InsertPointTy codeGenIP) {};
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(
+      builder.saveIP(), builder.getCurrentDebugLocation());
+  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createMaster(
+      ompLoc, bodyGenCB, finiCB));
+  // Propagate failures from converting the region inside bodyGenCB instead of
+  // unconditionally reporting success.
+  return bodyGenStatus;
+}
+
+/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
+///
+/// Only a single (non-collapsed) loop with the static (default) schedule is
+/// supported; other forms are rejected with an error. The loop is first built
+/// as a canonical loop whose body is the converted region, then turned into a
+/// statically scheduled workshare loop.
+static LogicalResult
+convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
+                 LLVM::ModuleTranslation &moduleTranslation) {
+  auto loop = cast<omp::WsLoopOp>(opInst);
+  // TODO: this should be in the op verifier instead.
+  if (loop.lowerBound().empty())
+    return opInst.emitOpError("expected at least one loop bound");
+
+  if (loop.getNumLoops() != 1)
+    return opInst.emitOpError("collapsed loops not yet supported");
+
+  if (loop.schedule_val().hasValue() &&
+      omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue()) !=
+          omp::ClauseScheduleKind::Static)
+    return opInst.emitOpError(
+        "only static (default) loop schedule is currently supported");
+
+  // Find the loop configuration.
+  llvm::Value *lowerBound = moduleTranslation.lookupValue(loop.lowerBound()[0]);
+  llvm::Value *upperBound = moduleTranslation.lookupValue(loop.upperBound()[0]);
+  llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]);
+  llvm::Type *ivType = step->getType();
+  // Without an explicit chunk size, use a chunk of one iteration.
+  llvm::Value *chunk =
+      loop.schedule_chunk_var()
+          ? moduleTranslation.lookupValue(loop.schedule_chunk_var())
+          : llvm::ConstantInt::get(ivType, 1);
+
+  // Set up the source location value for OpenMP runtime.
+  llvm::DISubprogram *subprogram =
+      builder.GetInsertBlock()->getParent()->getSubprogram();
+  const llvm::DILocation *diLoc =
+      moduleTranslation.translateLoc(opInst.getLoc(), subprogram);
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
+                                                    llvm::DebugLoc(diLoc));
+
+  // Generator of the canonical loop body. Produces an SESE region of basic
+  // blocks.
+  // TODO: support error propagation in OpenMPIRBuilder and use it instead of
+  // relying on captured variables.
+  LogicalResult bodyGenStatus = success();
+  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
+    llvm::IRBuilder<>::InsertPointGuard guard(builder);
+
+    // Make sure further conversions know about the induction variable.
+    moduleTranslation.mapValue(loop.getRegion().front().getArgument(0), iv);
+
+    llvm::BasicBlock *entryBlock = ip.getBlock();
+    llvm::BasicBlock *exitBlock =
+        entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
+
+    // Convert the body of the loop.
+    convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
+                        *exitBlock, builder, moduleTranslation, bodyGenStatus);
+  };
+
+  // Delegate actual loop construction to the OpenMP IRBuilder.
+  // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
+  // i.e. it has a positive step, uses signed integer semantics. Reconsider
+  // this code when WsLoop clearly supports more cases.
+  llvm::BasicBlock *insertBlock = builder.GetInsertBlock();
+  llvm::CanonicalLoopInfo *loopInfo =
+      moduleTranslation.getOpenMPBuilder()->createCanonicalLoop(
+          ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
+          /*InclusiveStop=*/loop.inclusive());
+  if (failed(bodyGenStatus))
+    return failure();
+
+  // TODO: get the alloca insertion point from the parallel operation builder.
+  // If we insert them at the top of the current function, they will be passed
+  // as extra arguments into the function the parallel operation builder
+  // outlines. Put them at the start of the current block for now.
+  llvm::OpenMPIRBuilder::InsertPointTy allocaIP(
+      insertBlock, insertBlock->getFirstInsertionPt());
+  // A trailing barrier is requested unless the loop is marked `nowait`.
+  loopInfo = moduleTranslation.getOpenMPBuilder()->createStaticWorkshareLoop(
+      ompLoc, loopInfo, allocaIP, !loop.nowait(), chunk);
+
+  // Continue building IR after the loop.
+  builder.restoreIP(loopInfo->getAfterIP());
+  return success();
+}
+
+/// Entry point of the OpenMP dialect translation interface: emits LLVM IR
+/// (including OpenMP runtime calls) for `op` at the current insertion point of
+/// `builder`. Operations with no translation are reported as errors.
+LogicalResult mlir::OpenMPDialectLLVMIRTranslationInterface::convertOperation(
+    Operation *op, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation) const {
+  // Fetched eagerly, so the module's OpenMP builder exists once any OpenMP
+  // operation has been visited.
+  llvm::OpenMPIRBuilder *ompIRBuilder = moduleTranslation.getOpenMPBuilder();
+
+  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
+      // Standalone directives: one OpenMPIRBuilder call at the current point.
+      .Case<omp::BarrierOp>([&](auto) {
+        ompIRBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
+        return success();
+      })
+      .Case<omp::TaskwaitOp>([&](auto) {
+        ompIRBuilder->createTaskwait(builder.saveIP());
+        return success();
+      })
+      .Case<omp::TaskyieldOp>([&](auto) {
+        ompIRBuilder->createTaskyield(builder.saveIP());
+        return success();
+      })
+      .Case<omp::FlushOp>([&](auto) {
+        // The runtime entry point (__kmpc_flush) accepts no variable list, and
+        // the OpenMP standard permits implementing a flush with a list by
+        // ignoring the list. The operands are therefore dropped and a plain
+        // flush is emitted.
+        ompIRBuilder->createFlush(builder.saveIP());
+        return success();
+      })
+      // Region-holding operations are lowered by dedicated helpers.
+      .Case<omp::ParallelOp>([&](auto) {
+        return convertOmpParallel(*op, builder, moduleTranslation);
+      })
+      .Case<omp::MasterOp>([&](auto) {
+        return convertOmpMaster(*op, builder, moduleTranslation);
+      })
+      .Case<omp::WsLoopOp>([&](auto) {
+        return convertOmpWsLoop(*op, builder, moduleTranslation);
+      })
+      // Region terminators emit nothing here: the helper that converted the
+      // parent operation already wired the control flow out of the region.
+      .Case<omp::YieldOp, omp::TerminatorOp>([](auto terminatorOp) {
+        assert(terminatorOp->getNumOperands() == 0 &&
+               "unexpected OpenMP terminator with operands");
+        return success();
+      })
+      .Default([&](Operation *unsupportedOp) {
+        return unsupportedOp->emitError("unsupported OpenMP operation: ")
+               << unsupportedOp->getName();
+      });
+}

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 6728511084dd..f563c67270e6 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -239,11 +239,12 @@ static Value getPHISourceValue(Block *current, Block *pred,
 }
 
 /// Connect the PHI nodes to the results of preceding blocks.
-template <typename T>
-static void connectPHINodes(T &func, const ModuleTranslation &state) {
+void mlir::LLVM::detail::connectPHINodes(Region &region,
+                                         const ModuleTranslation &state) {
   // Skip the first block, it cannot be branched to and its arguments correspond
   // to the arguments of the LLVM function.
-  for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
+  for (auto it = std::next(region.begin()), eit = region.end(); it != eit;
+       ++it) {
     Block *bb = &*it;
     llvm::BasicBlock *llvmBB = state.lookupBlock(bb);
     auto phis = llvmBB->phis();
@@ -270,294 +271,32 @@ static void connectPHINodes(T &func, const ModuleTranslation &state) {
 }
 
 /// Sort function blocks topologically.
-template <typename T>
-static llvm::SetVector<Block *> topologicalSort(T &f) {
+llvm::SetVector<Block *>
+mlir::LLVM::detail::getTopologicallySortedBlocks(Region &region) {
   // For each block that has not been visited yet (i.e. that has no
   // predecessors), add it to the list as well as its successors.
+  // SetVector keeps first-insertion order and ignores blocks that are already
+  // present, so each block keeps the position from the traversal that first
+  // reached it. The entry block seeds the first traversal and comes first.
   llvm::SetVector<Block *> blocks;
-  for (Block &b : f) {
+  for (Block &b : region) {
     if (blocks.count(&b) == 0) {
+      // Append the reverse post-order of the blocks reachable from `b`.
       llvm::ReversePostOrderTraversal<Block *> traversal(&b);
       blocks.insert(traversal.begin(), traversal.end());
     }
   }
-  assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
+  assert(blocks.size() == region.getBlocks().size() &&
+         "some blocks are not sorted");
 
   return blocks;
 }
 
-/// Convert the OpenMP parallel Operation to LLVM IR.
-LogicalResult
-ModuleTranslation::convertOmpParallel(Operation &opInst,
-                                      llvm::IRBuilder<> &builder) {
-  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
-  // TODO: support error propagation in OpenMPIRBuilder and use it instead of
-  // relying on captured variables.
-  LogicalResult bodyGenStatus = success();
-
-  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
-                       llvm::BasicBlock &continuationBlock) {
-    // ParallelOp has only one region associated with it.
-    auto &region = cast<omp::ParallelOp>(opInst).getRegion();
-    convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(),
-                        continuationBlock, builder, bodyGenStatus);
-  };
-
-  // TODO: Perform appropriate actions according to the data-sharing
-  // attribute (shared, private, firstprivate, ...) of variables.
-  // Currently defaults to shared.
-  auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
-                    llvm::Value &, llvm::Value &vPtr,
-                    llvm::Value *&replacementValue) -> InsertPointTy {
-    replacementValue = &vPtr;
-
-    return codeGenIP;
-  };
-
-  // TODO: Perform finalization actions for variables. This has to be
-  // called for variables which have destructors/finalizers.
-  auto finiCB = [&](InsertPointTy codeGenIP) {};
-
-  llvm::Value *ifCond = nullptr;
-  if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var())
-    ifCond = lookupValue(ifExprVar);
-  llvm::Value *numThreads = nullptr;
-  if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var())
-    numThreads = lookupValue(numThreadsVar);
-  llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default;
-  if (auto bind = cast<omp::ParallelOp>(opInst).proc_bind_val())
-    pbKind = llvm::omp::getProcBindKind(bind.getValue());
-  // TODO: Is the Parallel construct cancellable?
-  bool isCancellable = false;
-  // TODO: Determine the actual alloca insertion point, e.g., the function
-  // entry or the alloca insertion point as provided by the body callback
-  // above.
-  llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
-  if (failed(bodyGenStatus))
-    return failure();
-  builder.restoreIP(
-      ompBuilder->createParallel(builder, allocaIP, bodyGenCB, privCB, finiCB,
-                                 ifCond, numThreads, pbKind, isCancellable));
-  return success();
-}
-
-void ModuleTranslation::convertOmpOpRegions(
-    Region &region, StringRef blockName,
-    llvm::BasicBlock &sourceBlock, llvm::BasicBlock &continuationBlock,
-    llvm::IRBuilder<> &builder, LogicalResult &bodyGenStatus) {
-  llvm::LLVMContext &llvmContext = builder.getContext();
-  for (Block &bb : region) {
-    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
-        llvmContext, blockName, builder.GetInsertBlock()->getParent());
-    mapBlock(&bb, llvmBB);
-  }
-
-  llvm::Instruction *sourceTerminator = sourceBlock.getTerminator();
-
-  // Convert blocks one by one in topological order to ensure
-  // defs are converted before uses.
-  llvm::SetVector<Block *> blocks = topologicalSort(region);
-  for (Block *bb : blocks) {
-    llvm::BasicBlock *llvmBB = lookupBlock(bb);
-    // Retarget the branch of the entry block to the entry block of the
-    // converted region (regions are single-entry).
-    if (bb->isEntryBlock()) {
-      assert(sourceTerminator->getNumSuccessors() == 1 &&
-             "provided entry block has multiple successors");
-      assert(sourceTerminator->getSuccessor(0) == &continuationBlock &&
-             "ContinuationBlock is not the successor of the entry block");
-      sourceTerminator->setSuccessor(0, llvmBB);
-    }
-
-    llvm::IRBuilder<>::InsertPointGuard guard(builder);
-    if (failed(convertBlock(*bb, bb->isEntryBlock(), builder))) {
-      bodyGenStatus = failure();
-      return;
-    }
-
-    // Special handling for `omp.yield` and `omp.terminator` (we may have more
-    // than one): they return the control to the parent OpenMP dialect operation
-    // so replace them with the branch to the continuation block. We handle this
-    // here to avoid relying inter-function communication through the
-    // ModuleTranslation class to set up the correct insertion point. This is
-    // also consistent with MLIR's idiom of handling special region terminators
-    // in the same code that handles the region-owning operation.
-    if (isa<omp::TerminatorOp, omp::YieldOp>(bb->getTerminator()))
-      builder.CreateBr(&continuationBlock);
-  }
-  // Finally, after all blocks have been traversed and values mapped,
-  // connect the PHI nodes to the results of preceding blocks.
-  connectPHINodes(region, *this);
-}
-
-LogicalResult ModuleTranslation::convertOmpMaster(Operation &opInst,
-                                                  llvm::IRBuilder<> &builder) {
-  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
-  // TODO: support error propagation in OpenMPIRBuilder and use it instead of
-  // relying on captured variables.
-  LogicalResult bodyGenStatus = success();
-
-  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
-                       llvm::BasicBlock &continuationBlock) {
-    // MasterOp has only one region associated with it.
-    auto &region = cast<omp::MasterOp>(opInst).getRegion();
-    convertOmpOpRegions(region, "omp.master.region", *codeGenIP.getBlock(),
-                        continuationBlock, builder, bodyGenStatus);
-  };
-
-  // TODO: Perform finalization actions for variables. This has to be
-  // called for variables which have destructors/finalizers.
-  auto finiCB = [&](InsertPointTy codeGenIP) {};
-
-  builder.restoreIP(ompBuilder->createMaster(builder, bodyGenCB, finiCB));
-  return success();
-}
-
-/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
-LogicalResult ModuleTranslation::convertOmpWsLoop(Operation &opInst,
-                                                  llvm::IRBuilder<> &builder) {
-  auto loop = cast<omp::WsLoopOp>(opInst);
-  // TODO: this should be in the op verifier instead.
-  if (loop.lowerBound().empty())
-    return failure();
-
-  if (loop.getNumLoops() != 1)
-    return opInst.emitOpError("collapsed loops not yet supported");
-
-  if (loop.schedule_val().hasValue() &&
-      omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue()) !=
-          omp::ClauseScheduleKind::Static)
-    return opInst.emitOpError(
-        "only static (default) loop schedule is currently supported");
-
-  // Find the loop configuration.
-  llvm::Value *lowerBound = lookupValue(loop.lowerBound()[0]);
-  llvm::Value *upperBound = lookupValue(loop.upperBound()[0]);
-  llvm::Value *step = lookupValue(loop.step()[0]);
-  llvm::Type *ivType = step->getType();
-  llvm::Value *chunk = loop.schedule_chunk_var()
-                           ? lookupValue(loop.schedule_chunk_var())
-                           : llvm::ConstantInt::get(ivType, 1);
-
-  // Set up the source location value for OpenMP runtime.
-  llvm::DISubprogram *subprogram =
-      builder.GetInsertBlock()->getParent()->getSubprogram();
-  const llvm::DILocation *diLoc =
-      debugTranslation->translateLoc(opInst.getLoc(), subprogram);
-  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
-                                                    llvm::DebugLoc(diLoc));
-
-  // Generator of the canonical loop body. Produces an SESE region of basic
-  // blocks.
-  // TODO: support error propagation in OpenMPIRBuilder and use it instead of
-  // relying on captured variables.
-  LogicalResult bodyGenStatus = success();
-  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
-    llvm::IRBuilder<>::InsertPointGuard guard(builder);
-
-    // Make sure further conversions know about the induction variable.
-    mapValue(loop.getRegion().front().getArgument(0), iv);
-
-    llvm::BasicBlock *entryBlock = ip.getBlock();
-    llvm::BasicBlock *exitBlock =
-        entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
-
-    // Convert the body of the loop.
-    convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
-                        *exitBlock, builder, bodyGenStatus);
-  };
-
-  // Delegate actual loop construction to the OpenMP IRBuilder.
-  // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
-  // i.e. it has a positive step, uses signed integer semantics. Reconsider
-  // this code when WsLoop clearly supports more cases.
-  llvm::BasicBlock *insertBlock = builder.GetInsertBlock();
-  llvm::CanonicalLoopInfo *loopInfo = ompBuilder->createCanonicalLoop(
-      ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
-      /*InclusiveStop=*/loop.inclusive());
-  if (failed(bodyGenStatus))
-    return failure();
-
-  // TODO: get the alloca insertion point from the parallel operation builder.
-  // If we insert the at the top of the current function, they will be passed as
-  // extra arguments into the function the parallel operation builder outlines.
-  // Put them at the start of the current block for now.
-  llvm::OpenMPIRBuilder::InsertPointTy allocaIP(
-      insertBlock, insertBlock->getFirstInsertionPt());
-  loopInfo = ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP,
-                                                   !loop.nowait(), chunk);
-
-  // Continue building IR after the loop.
-  builder.restoreIP(loopInfo->getAfterIP());
-  return success();
-}
-
-/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
-/// (including OpenMP runtime calls).
-LogicalResult
-ModuleTranslation::convertOmpOperation(Operation &opInst,
-                                       llvm::IRBuilder<> &builder) {
-  if (!ompBuilder) {
-    ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule);
-    ompBuilder->initialize();
-  }
-  return llvm::TypeSwitch<Operation *, LogicalResult>(&opInst)
-      .Case([&](omp::BarrierOp) {
-        ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
-        return success();
-      })
-      .Case([&](omp::TaskwaitOp) {
-        ompBuilder->createTaskwait(builder.saveIP());
-        return success();
-      })
-      .Case([&](omp::TaskyieldOp) {
-        ompBuilder->createTaskyield(builder.saveIP());
-        return success();
-      })
-      .Case([&](omp::FlushOp) {
-        // No support in Openmp runtime function (__kmpc_flush) to accept
-        // the argument list.
-        // OpenMP standard states the following:
-        //  "An implementation may implement a flush with a list by ignoring
-        //   the list, and treating it the same as a flush without a list."
-        //
-        // The argument list is discarded so that, flush with a list is treated
-        // same as a flush without a list.
-        ompBuilder->createFlush(builder.saveIP());
-        return success();
-      })
-      .Case(
-          [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); })
-      .Case([&](omp::MasterOp) { return convertOmpMaster(opInst, builder); })
-      .Case([&](omp::WsLoopOp) { return convertOmpWsLoop(opInst, builder); })
-      .Case<omp::YieldOp, omp::TerminatorOp>([](auto op) {
-        // `yield` and `terminator` can be just omitted. The block structure was
-        // created in the function that handles their parent operation.
-        assert(op->getNumOperands() == 0 &&
-               "unexpected OpenMP terminator with operands");
-        return success();
-      })
-      .Default([&](Operation *inst) {
-        return inst->emitError("unsupported OpenMP operation: ")
-               << inst->getName();
-      });
-}
-
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`.  LLVM IR Builder does not have a generic interface so
 /// this has to be a long chain of `if`s calling different functions with a
 /// different number of arguments.
 LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
                                                   llvm::IRBuilder<> &builder) {
-
-  // TODO(zinenko): this should be the "main" conversion here, remove the
-  // dispatch below.
   if (succeeded(iface.convertOperation(&opInst, builder, *this)))
     return success();
 
-  if (ompDialect && opInst.getDialect() == ompDialect)
-    return convertOmpOperation(opInst, builder);
-
   return opInst.emitError("unsupported or non-LLVM operation: ")
          << opInst.getName();
 }
@@ -812,7 +551,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
 
   // Then, convert blocks one by one in topological order to ensure defs are
   // converted before uses.
-  auto blocks = topologicalSort(func);
+  auto blocks = detail::getTopologicallySortedBlocks(func.getBody());
   for (Block *bb : blocks) {
     llvm::IRBuilder<> builder(llvmContext);
     if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))
@@ -821,7 +560,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
 
   // Finally, after all blocks have been traversed and values mapped, connect
   // the PHI nodes to the results of preceding blocks.
-  connectPHINodes(func, *this);
+  detail::connectPHINodes(func.getBody(), *this);
   return success();
 }
 
@@ -881,6 +620,11 @@ ModuleTranslation::lookupValues(ValueRange values) {
   return remapped;
 }
 
+const llvm::DILocation *
+ModuleTranslation::translateLoc(Location loc, llvm::DILocalScope *scope) {
+  return debugTranslation->translateLoc(loc, scope);
+}
+
 std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(
     Operation *m, llvm::LLVMContext &llvmContext, StringRef name) {
   m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();

diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
index 4389039b95de..4ce48782672f 100644
--- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
+++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
@@ -18,6 +18,7 @@
 #include "mlir/ExecutionEngine/OptUtils.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Target/LLVMIR.h"
+#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
 
 #include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/TargetSelect.h"
@@ -32,6 +33,8 @@ int main(int argc, char **argv) {
   mlir::DialectRegistry registry;
   registry.insert<mlir::LLVM::LLVMDialect, mlir::omp::OpenMPDialect>();
   mlir::registerLLVMDialectTranslation(registry);
+  registry.addDialectInterface<mlir::omp::OpenMPDialect,
+                               mlir::OpenMPDialectLLVMIRTranslationInterface>();
 
   return mlir::JitRunnerMain(argc, argv, registry);
 }


        


More information about the Mlir-commits mailing list