[Mlir-commits] [mlir] 72d013d - [mlir] OpenMP-to-LLVM: properly set outer alloca insertion point
Alex Zinenko
llvmlistbot at llvm.org
Mon May 10 01:05:00 PDT 2021
Author: Alex Zinenko
Date: 2021-05-10T10:04:52+02:00
New Revision: 72d013dd73f4b59eb421d7dbbfd0b2bccbb6fc7b
URL: https://github.com/llvm/llvm-project/commit/72d013dd73f4b59eb421d7dbbfd0b2bccbb6fc7b
DIFF: https://github.com/llvm/llvm-project/commit/72d013dd73f4b59eb421d7dbbfd0b2bccbb6fc7b.diff
LOG: [mlir] OpenMP-to-LLVM: properly set outer alloca insertion point
Previously, the OpenMP to LLVM IR conversion was setting the alloca insertion
point to the same position as the main computation when converting OpenMP
`parallel` operations. This is problematic if, for example, the `parallel`
operation is placed inside a loop: it would keep allocating on the stack on
each iteration, leading to stack overflow.
Reviewed By: kiranchandramohan
Differential Revision: https://reviews.llvm.org/D101307
Added:
Modified:
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Target/LLVMIR/openmp-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 137961443af33..bd765879c4ac3 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -176,6 +176,82 @@ class ModuleTranslation {
/// it if it does not exist.
llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
+ /// Common base class for ModuleTranslation stack frames.
+ class StackFrame {
+ public:
+ virtual ~StackFrame() {}
+ TypeID getTypeID() const { return typeID; }
+
+ protected:
+ explicit StackFrame(TypeID typeID) : typeID(typeID) {}
+
+ private:
+ const TypeID typeID;
+ virtual void anchor();
+ };
+
+ /// Concrete CRTP base class for ModuleTranslation stack frames. When
+ /// translating operations with regions, users of ModuleTranslation can store
+ /// state on ModuleTranslation stack before entering the region and inspect
+ /// it when converting operations nested within that region. Users are
+ /// expected to derive this class and put any relevant information into fields
+ /// of the derived class. The usual isa/dyn_cast functionality is available
+ /// for instances of derived classes.
+ template <typename Derived>
+ class StackFrameBase : public StackFrame {
+ public:
+ explicit StackFrameBase() : StackFrame(TypeID::get<Derived>()) {}
+ };
+
+ /// Creates a stack frame of type `T` on ModuleTranslation stack. `T` must
+ /// be derived from `StackFrameBase<T>` and constructible from the provided
+ /// arguments. Doing this before entering the region of the op being
+ /// translated makes the frame available when translating ops within that
+ /// region.
+ template <typename T, typename... Args>
+ void stackPush(Args &&... args) {
+ static_assert(
+ std::is_base_of<StackFrame, T>::value,
+ "can only push instances of StackFrame on ModuleTranslation stack");
+ stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
+ }
+
+ /// Pops the last element from the ModuleTranslation stack.
+ void stackPop() { stack.pop_back(); }
+
+ /// Calls `callback` for every ModuleTranslation stack frame of type `T`
+ /// starting from the top of the stack.
+ template <typename T>
+ WalkResult
+ stackWalk(llvm::function_ref<WalkResult(const T &)> callback) const {
+ static_assert(std::is_base_of<StackFrame, T>::value,
+ "expected T derived from StackFrame");
+ if (!callback)
+ return WalkResult::skip();
+ for (const std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
+ if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
+ WalkResult result = callback(*ptr);
+ if (result.wasInterrupted())
+ return result;
+ }
+ }
+ return WalkResult::advance();
+ }
+
+ /// RAII object calling stackPush/stackPop on construction/destruction.
+ template <typename T>
+ struct SaveStack {
+ template <typename... Args>
+ explicit SaveStack(ModuleTranslation &m, Args &&...args)
+ : moduleTranslation(m) {
+ moduleTranslation.stackPush<T>(std::forward<Args>(args)...);
+ }
+ ~SaveStack() { moduleTranslation.stackPop(); }
+
+ private:
+ ModuleTranslation &moduleTranslation;
+ };
+
private:
ModuleTranslation(Operation *module,
std::unique_ptr<llvm::Module> llvmModule);
@@ -233,6 +309,10 @@ class ModuleTranslation {
/// metadata. The metadata is attached to Latch block branches with this
/// attribute.
DenseMap<Attribute, llvm::MDNode *> loopOptionsMetadataMapping;
+
+ /// Stack of user-specified state elements, useful when translating operations
+ /// with regions.
+ SmallVector<std::unique_ptr<StackFrame>> stack;
};
namespace detail {
@@ -270,4 +350,14 @@ llvm::Value *createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
} // namespace LLVM
} // namespace mlir
+namespace llvm {
+template <typename T>
+struct isa_impl<T, ::mlir::LLVM::ModuleTranslation::StackFrame> {
+ static inline bool
+ doit(const ::mlir::LLVM::ModuleTranslation::StackFrame &frame) {
+ return frame.getTypeID() == ::mlir::TypeID::get<T>();
+ }
+};
+} // namespace llvm
+
#endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3cd201ad08a5c..6236092c0dae3 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -23,6 +23,42 @@
using namespace mlir;
+namespace {
+/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
+/// insertion points for allocas.
+class OpenMPAllocaStackFrame
+ : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
+public:
+ explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
+ : allocaInsertPoint(allocaIP) {}
+ llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
+};
+} // namespace
+
+/// Find the insertion point for allocas given the current insertion point for
+/// normal operations in the builder.
+static llvm::OpenMPIRBuilder::InsertPointTy
+findAllocaInsertPoint(llvm::IRBuilderBase &builder,
+ const LLVM::ModuleTranslation &moduleTranslation) {
+ // If there is an alloca insertion point on stack, i.e. we are in a nested
+ // operation and a specific point was provided by some surrounding operation,
+ // use it.
+ llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
+ WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
+ [&](const OpenMPAllocaStackFrame &frame) {
+ allocaInsertPoint = frame.allocaInsertPoint;
+ return WalkResult::interrupt();
+ });
+ if (walkResult.wasInterrupted())
+ return allocaInsertPoint;
+
+ // Otherwise, insert to the entry block of the surrounding function.
+ llvm::BasicBlock &funcEntryBlock =
+ builder.GetInsertBlock()->getParent()->getEntryBlock();
+ return llvm::OpenMPIRBuilder::InsertPointTy(
+ &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
+}
+
/// Converts the given region that appears within an OpenMP dialect operation to
/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
/// region, and a branch from any block with an successor-less OpenMP terminator
@@ -91,6 +127,11 @@ convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder,
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
llvm::BasicBlock &continuationBlock) {
+ // Save the alloca insertion point on ModuleTranslation stack for use in
+ // nested regions.
+ LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
+ moduleTranslation, allocaIP);
+
// ParallelOp has only one region associated with it.
auto ®ion = cast<omp::ParallelOp>(opInst).getRegion();
convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(),
@@ -124,18 +165,14 @@ convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder,
pbKind = llvm::omp::getProcBindKind(bind.getValue());
// TODO: Is the Parallel construct cancellable?
bool isCancellable = false;
- // TODO: Determine the actual alloca insertion point, e.g., the function
- // entry or the alloca insertion point as provided by the body callback
- // above.
- llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
- if (failed(bodyGenStatus))
- return failure();
+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(
builder.saveIP(), builder.getCurrentDebugLocation());
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel(
- ompLoc, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads, pbKind,
- isCancellable));
- return success();
+ ompLoc, findAllocaInsertPoint(builder, moduleTranslation), bodyGenCB,
+ privCB, finiCB, ifCond, numThreads, pbKind, isCancellable));
+
+ return bodyGenStatus;
}
/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
@@ -233,7 +270,6 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
// TODO: this currently assumes WsLoop is semantically similar to SCF loop,
// i.e. it has a positive step, uses signed integer semantics. Reconsider
// this code when WsLoop clearly supports more cases.
- llvm::BasicBlock *insertBlock = builder.GetInsertBlock();
llvm::CanonicalLoopInfo *loopInfo =
moduleTranslation.getOpenMPBuilder()->createCanonicalLoop(
ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
@@ -241,12 +277,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(bodyGenStatus))
return failure();
- // TODO: get the alloca insertion point from the parallel operation builder.
- // If we insert the at the top of the current function, they will be passed as
- // extra arguments into the function the parallel operation builder outlines.
- // Put them at the start of the current block for now.
- llvm::OpenMPIRBuilder::InsertPointTy allocaIP(
- insertBlock, insertBlock->getFirstInsertionPt());
+ llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
+ findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::InsertPointTy afterIP;
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
if (isStatic) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index fb5319546adc3..7aa2ffda248e2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -755,6 +755,8 @@ ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
return llvmModule->getOrInsertNamedMetadata(name);
}
+void ModuleTranslation::StackFrame::anchor() {}
+
static std::unique_ptr<llvm::Module>
prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
StringRef name) {
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 71f0277744f37..383ced4529534 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -151,6 +151,13 @@ llvm.func @test_omp_parallel_num_threads_3() -> () {
// CHECK: define void @test_omp_parallel_if_1(i32 %[[IF_VAR_1:.*]])
llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () {
+// Check that the allocas are emitted by the OpenMPIRBuilder at the top of the
+// function, before the condition. Allocas are only emitted by the builder when
+// the `if` clause is present. We match specific SSA value names since LLVM
+// actually produces those names.
+// CHECK: %tid.addr{{.*}} = alloca i32
+// CHECK: %zero.addr{{.*}} = alloca i32
+
// CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0
%0 = llvm.mlir.constant(0 : index) : i32
%1 = llvm.icmp "slt" %arg0, %0 : i32
@@ -184,6 +191,60 @@ llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () {
// CHECK: define internal void @[[OMP_OUTLINED_FN_IF_1]]
// CHECK: call void @__kmpc_barrier
+// -----
+
+// CHECK-LABEL: @test_nested_alloca_ip
+llvm.func @test_nested_alloca_ip(%arg0: i32) -> () {
+
+ // Check that the allocas are emitted by the OpenMPIRBuilder at the top of
+ // the function, before the condition. Allocas are only emitted by the
+ // builder when the `if` clause is present. We match specific SSA value names
+ // since LLVM actually produces those names and ensure they come before the
+ // "icmp" that is the first operation we emit.
+ // CHECK: %tid.addr{{.*}} = alloca i32
+ // CHECK: %zero.addr{{.*}} = alloca i32
+ // CHECK: icmp slt i32 %{{.*}}, 0
+ %0 = llvm.mlir.constant(0 : index) : i32
+ %1 = llvm.icmp "slt" %arg0, %0 : i32
+
+ omp.parallel if(%1 : i1) {
+ // The "parallel" operation will be outlined, check that the function is
+ // produced. Inside that function, further allocas should be placed before
+ // another "icmp".
+ // CHECK: define
+ // CHECK: %tid.addr{{.*}} = alloca i32
+ // CHECK: %zero.addr{{.*}} = alloca i32
+ // CHECK: icmp slt i32 %{{.*}}, 1
+ %2 = llvm.mlir.constant(1 : index) : i32
+ %3 = llvm.icmp "slt" %arg0, %2 : i32
+
+ omp.parallel if(%3 : i1) {
+ // One more nesting level.
+ // CHECK: define
+ // CHECK: %tid.addr{{.*}} = alloca i32
+ // CHECK: %zero.addr{{.*}} = alloca i32
+ // CHECK: icmp slt i32 %{{.*}}, 2
+
+ %4 = llvm.mlir.constant(2 : index) : i32
+ %5 = llvm.icmp "slt" %arg0, %4 : i32
+
+ omp.parallel if(%5 : i1) {
+ omp.barrier
+ omp.terminator
+ }
+
+ omp.barrier
+ omp.terminator
+ }
+ omp.barrier
+ omp.terminator
+ }
+
+ llvm.return
+}
+
+// -----
+
// CHECK-LABEL: define void @test_omp_parallel_3()
llvm.func @test_omp_parallel_3() -> () {
// CHECK: [[OMP_THREAD_3_1:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
More information about the Mlir-commits
mailing list