[Mlir-commits] [mlir] d9067dc - Lowering of OpenMP Parallel operation to LLVM IR 1/n
Kiran Chandramohan
llvmlistbot at llvm.org
Mon Jul 13 15:59:53 PDT 2020
Author: Kiran Chandramohan
Date: 2020-07-13T23:55:45+01:00
New Revision: d9067dca7ba7cda97a86ec22106e06ffc700ecbf
URL: https://github.com/llvm/llvm-project/commit/d9067dca7ba7cda97a86ec22106e06ffc700ecbf
DIFF: https://github.com/llvm/llvm-project/commit/d9067dca7ba7cda97a86ec22106e06ffc700ecbf.diff
LOG: Lowering of OpenMP Parallel operation to LLVM IR 1/n
This patch introduces lowering of the OpenMP parallel operation to LLVM
IR using the OpenMPIRBuilder.
Functions topologicalSort and connectPHINodes are generalised so that
they also work on the regions of operations, not only on LLVM functions.
connectPHINodes is also made static.
Lowering works for a parallel region with multiple blocks. Clauses and
arguments of the OpenMP operation are not handled.
Reviewed By: rriddle, anchu-rajendran
Differential Revision: https://reviews.llvm.org/D81660
Added:
Modified:
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Target/openmp-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 3be6c97322b5..642282f8af18 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -24,7 +24,6 @@ def OpenMP_Dialect : Dialect {
class OpenMP_Op<string mnemonic, list<OpTrait> traits = []> :
Op<OpenMP_Dialect, mnemonic, traits>;
-
//===----------------------------------------------------------------------===//
// 2.6 parallel Construct
//===----------------------------------------------------------------------===//
@@ -81,8 +80,8 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> {
of the parallel region.
}];
- let arguments = (ins Optional<I1>:$if_expr_var,
- Optional<AnyInteger>:$num_threads_var,
+ let arguments = (ins Optional<AnyType>:$if_expr_var,
+ Optional<AnyType>:$num_threads_var,
OptionalAttr<ClauseDefault>:$default_val,
Variadic<AnyType>:$private_vars,
Variadic<AnyType>:$firstprivate_vars,
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 3a701018beb5..e44ae976e0dd 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -87,6 +87,8 @@ class ModuleTranslation {
llvm::IRBuilder<> &builder);
virtual LogicalResult convertOmpOperation(Operation &op,
llvm::IRBuilder<> &builder);
+ virtual LogicalResult convertOmpParallel(Operation &op,
+ llvm::IRBuilder<> &builder);
static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);
/// A helper to look up remapped operands in the value remapping table.
@@ -100,7 +102,6 @@ class ModuleTranslation {
LogicalResult convertFunctions();
LogicalResult convertGlobals();
LogicalResult convertOneFunction(LLVMFuncOp func);
- void connectPHINodes(LLVMFuncOp func);
LogicalResult convertBlock(Block &bb, bool ignoreArguments);
llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 657aa84afe1c..0defea6bbbb9 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -25,11 +25,13 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
using namespace mlir;
@@ -304,7 +306,160 @@ ModuleTranslation::ModuleTranslation(Operation *module,
assert(satisfiesLLVMModule(mlirModule) &&
"mlirModule should honor LLVM's module semantics.");
}
-ModuleTranslation::~ModuleTranslation() {}
+ModuleTranslation::~ModuleTranslation() {
+ if (ompBuilder)
+ ompBuilder->finalize();
+}
+
+/// Get the SSA value passed to the current block from the terminator operation
+/// of its predecessor.
+static Value getPHISourceValue(Block *current, Block *pred,
+ unsigned numArguments, unsigned index) {
+ Operation &terminator = *pred->getTerminator();
+ if (isa<LLVM::BrOp>(terminator))
+ return terminator.getOperand(index);
+
+ // For conditional branches, we need to check if the current block is reached
+ // through the "true" or the "false" branch and take the relevant operands.
+ auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator);
+ assert(condBranchOp &&
+ "only branch operations can be terminators of a block that "
+ "has successors");
+ assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
+ "successors with arguments in LLVM conditional branches must be "
+ "different blocks");
+
+ return condBranchOp.getSuccessor(0) == current
+ ? condBranchOp.trueDestOperands()[index]
+ : condBranchOp.falseDestOperands()[index];
+}
+
+/// Connect the PHI nodes to the results of preceding blocks.
+template <typename T>
+static void
+connectPHINodes(T &func, const DenseMap<Value, llvm::Value *> &valueMapping,
+ const DenseMap<Block *, llvm::BasicBlock *> &blockMapping) {
+ // Skip the first block, it cannot be branched to and its arguments correspond
+ // to the arguments of the LLVM function.
+ for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
+ Block *bb = &*it;
+ llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
+ auto phis = llvmBB->phis();
+ auto numArguments = bb->getNumArguments();
+ assert(numArguments == std::distance(phis.begin(), phis.end()));
+ for (auto &numberedPhiNode : llvm::enumerate(phis)) {
+ auto &phiNode = numberedPhiNode.value();
+ unsigned index = numberedPhiNode.index();
+ for (auto *pred : bb->getPredecessors()) {
+ phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
+ bb, pred, numArguments, index)),
+ blockMapping.lookup(pred));
+ }
+ }
+ }
+}
+
+// TODO: implement an iterative version
+static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
+ blocks.insert(b);
+ for (Block *bb : b->getSuccessors()) {
+ if (blocks.count(bb) == 0)
+ topologicalSortImpl(blocks, bb);
+ }
+}
+
+/// Sort function blocks topologically.
+template <typename T>
+static llvm::SetVector<Block *> topologicalSort(T &f) {
+ // For each blocks that has not been visited yet (i.e. that has no
+ // predecessors), add it to the list and traverse its successors in DFS
+ // preorder.
+ llvm::SetVector<Block *> blocks;
+ for (Block &b : f) {
+ if (blocks.count(&b) == 0)
+ topologicalSortImpl(blocks, &b);
+ }
+ assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
+
+ return blocks;
+}
+
+/// Convert the OpenMP parallel Operation to LLVM IR.
+LogicalResult
+ModuleTranslation::convertOmpParallel(Operation &opInst,
+ llvm::IRBuilder<> &builder) {
+ using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+
+ auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
+ llvm::BasicBlock &continuationIP) {
+ llvm::LLVMContext &llvmContext = llvmModule->getContext();
+
+ llvm::BasicBlock *codeGenIPBB = codeGenIP.getBlock();
+ llvm::Instruction *codeGenIPBBTI = codeGenIPBB->getTerminator();
+
+ builder.SetInsertPoint(codeGenIPBB);
+
+ for (auto &region : opInst.getRegions()) {
+ for (auto &bb : region) {
+ auto *llvmBB = llvm::BasicBlock::Create(
+ llvmContext, "omp.par.region", codeGenIP.getBlock()->getParent());
+ blockMapping[&bb] = llvmBB;
+ }
+
+ // Then, convert blocks one by one in topological order to ensure
+ // defs are converted before uses.
+ llvm::SetVector<Block *> blocks = topologicalSort(region);
+ for (auto indexedBB : llvm::enumerate(blocks)) {
+ Block *bb = indexedBB.value();
+ llvm::BasicBlock *curLLVMBB = blockMapping[bb];
+ if (bb->isEntryBlock())
+ codeGenIPBBTI->setSuccessor(0, curLLVMBB);
+
+ // TODO: Error not returned up the hierarchy
+ if (failed(
+ convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
+ return;
+
+ // If this block has the terminator then add a jump to
+ // continuation bb
+ for (auto &op : *bb) {
+ if (isa<omp::TerminatorOp>(op)) {
+ builder.SetInsertPoint(curLLVMBB);
+ builder.CreateBr(&continuationIP);
+ }
+ }
+ }
+ // Finally, after all blocks have been traversed and values mapped,
+ // connect the PHI nodes to the results of preceding blocks.
+ connectPHINodes(region, valueMapping, blockMapping);
+ }
+ };
+
+ // TODO: Perform appropriate actions according to the data-sharing
+ // attribute (shared, private, firstprivate, ...) of variables.
+ // Currently defaults to shared.
+ auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
+ llvm::Value &vPtr,
+ llvm::Value *&replacementValue) -> InsertPointTy {
+ replacementValue = &vPtr;
+
+ return codeGenIP;
+ };
+
+ // TODO: Perform finalization actions for variables. This has to be
+ // called for variables which have destructors/finalizers.
+ auto finiCB = [&](InsertPointTy codeGenIP) {};
+
+ // TODO: The various operands of parallel operation are not handled.
+ // Parallel operation is created with some default options for now.
+ llvm::Value *ifCond = nullptr;
+ llvm::Value *numThreads = nullptr;
+ bool isCancellable = false;
+ builder.restoreIP(ompBuilder->CreateParallel(
+ builder, bodyGenCB, privCB, finiCB, ifCond, numThreads,
+ llvm::omp::OMP_PROC_BIND_default, isCancellable));
+ return success();
+}
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
/// (including OpenMP runtime calls).
@@ -340,6 +495,9 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
ompBuilder->CreateFlush(builder.saveIP());
return success();
})
+ .Case([&](omp::TerminatorOp) { return success(); })
+ .Case(
+ [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); })
.Default([&](Operation *inst) {
return inst->emitError("unsupported OpenMP operation: ")
<< inst->getName();
@@ -556,75 +714,6 @@ LogicalResult ModuleTranslation::convertGlobals() {
return success();
}
-/// Get the SSA value passed to the current block from the terminator operation
-/// of its predecessor.
-static Value getPHISourceValue(Block *current, Block *pred,
- unsigned numArguments, unsigned index) {
- auto &terminator = *pred->getTerminator();
- if (isa<LLVM::BrOp>(terminator)) {
- return terminator.getOperand(index);
- }
-
- // For conditional branches, we need to check if the current block is reached
- // through the "true" or the "false" branch and take the relevant operands.
- auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator);
- assert(condBranchOp &&
- "only branch operations can be terminators of a block that "
- "has successors");
- assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
- "successors with arguments in LLVM conditional branches must be "
- "different blocks");
-
- return condBranchOp.getSuccessor(0) == current
- ? condBranchOp.trueDestOperands()[index]
- : condBranchOp.falseDestOperands()[index];
-}
-
-void ModuleTranslation::connectPHINodes(LLVMFuncOp func) {
- // Skip the first block, it cannot be branched to and its arguments correspond
- // to the arguments of the LLVM function.
- for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
- Block *bb = &*it;
- llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
- auto phis = llvmBB->phis();
- auto numArguments = bb->getNumArguments();
- assert(numArguments == std::distance(phis.begin(), phis.end()));
- for (auto &numberedPhiNode : llvm::enumerate(phis)) {
- auto &phiNode = numberedPhiNode.value();
- unsigned index = numberedPhiNode.index();
- for (auto *pred : bb->getPredecessors()) {
- phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
- bb, pred, numArguments, index)),
- blockMapping.lookup(pred));
- }
- }
- }
-}
-
-// TODO: implement an iterative version
-static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
- blocks.insert(b);
- for (Block *bb : b->getSuccessors()) {
- if (blocks.count(bb) == 0)
- topologicalSortImpl(blocks, bb);
- }
-}
-
-/// Sort function blocks topologically.
-static llvm::SetVector<Block *> topologicalSort(LLVMFuncOp f) {
- // For each blocks that has not been visited yet (i.e. that has no
- // predecessors), add it to the list and traverse its successors in DFS
- // preorder.
- llvm::SetVector<Block *> blocks;
- for (Block &b : f) {
- if (blocks.count(&b) == 0)
- topologicalSortImpl(blocks, &b);
- }
- assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
-
- return blocks;
-}
-
/// Attempts to add an attribute identified by `key`, optionally with the given
/// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
/// attribute has a kind known to LLVM IR, create the attribute of this kind,
@@ -772,7 +861,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
// Finally, after all blocks have been traversed and values mapped, connect
// the PHI nodes to the results of preceding blocks.
- connectPHINodes(func);
+ connectPHINodes(func, valueMapping, blockMapping);
return success();
}
diff --git a/mlir/test/Target/openmp-llvm.mlir b/mlir/test/Target/openmp-llvm.mlir
index ddfc2a4cf786..c8acd8022b2b 100644
--- a/mlir/test/Target/openmp-llvm.mlir
+++ b/mlir/test/Target/openmp-llvm.mlir
@@ -32,3 +32,49 @@ llvm.func @test_flush_construct(%arg0: !llvm.i32) {
// CHECK-NEXT: ret void
llvm.return
}
+
+// CHECK-LABEL: define void @test_omp_parallel_1()
+llvm.func @test_omp_parallel_1() -> () {
+ // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_1:.*]] to {{.*}}
+ omp.parallel {
+ omp.barrier
+ omp.terminator
+ }
+
+ llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_1]]
+ // CHECK: call void @__kmpc_barrier
+
+llvm.func @body(!llvm.i64)
+
+// CHECK-LABEL: define void @test_omp_parallel_2()
+llvm.func @test_omp_parallel_2() -> () {
+ // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_2:.*]] to {{.*}}
+ omp.parallel {
+ ^bb0:
+ %0 = llvm.mlir.constant(1 : index) : !llvm.i64
+ %1 = llvm.mlir.constant(42 : index) : !llvm.i64
+ llvm.call @body(%0) : (!llvm.i64) -> ()
+ llvm.call @body(%1) : (!llvm.i64) -> ()
+ llvm.br ^bb1
+
+ ^bb1:
+ %2 = llvm.add %0, %1 : !llvm.i64
+ llvm.call @body(%2) : (!llvm.i64) -> ()
+ omp.terminator
+ }
+ llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_2]]
+ // CHECK-LABEL: omp.par.region:
+ // CHECK: br label %omp.par.region1
+ // CHECK-LABEL: omp.par.region1:
+ // CHECK: call void @body(i64 1)
+ // CHECK: call void @body(i64 42)
+ // CHECK: br label %omp.par.region2
+ // CHECK-LABEL: omp.par.region2:
+ // CHECK: call void @body(i64 43)
+ // CHECK: br label %omp.par.pre_finalize
More information about the Mlir-commits
mailing list