[Mlir-commits] [mlir] 72d013d - [mlir] OpenMP-to-LLVM: properly set outer alloca insertion point
    Alex Zinenko 
    llvmlistbot at llvm.org
       
    Mon May 10 01:05:00 PDT 2021
    
    
  
Author: Alex Zinenko
Date: 2021-05-10T10:04:52+02:00
New Revision: 72d013dd73f4b59eb421d7dbbfd0b2bccbb6fc7b
URL: https://github.com/llvm/llvm-project/commit/72d013dd73f4b59eb421d7dbbfd0b2bccbb6fc7b
DIFF: https://github.com/llvm/llvm-project/commit/72d013dd73f4b59eb421d7dbbfd0b2bccbb6fc7b.diff
LOG: [mlir] OpenMP-to-LLVM: properly set outer alloca insertion point
Previously, the OpenMP to LLVM IR conversion was setting the alloca insertion
point to the same position as the main computation when converting OpenMP
`parallel` operations. This is problematic if, for example, the `parallel`
operation is placed inside a loop: it would keep allocating on the stack on each
iteration, eventually leading to a stack overflow.
Reviewed By: kiranchandramohan
Differential Revision: https://reviews.llvm.org/D101307
Added: 
    
Modified: 
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Target/LLVMIR/openmp-llvm.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 137961443af33..bd765879c4ac3 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -176,6 +176,82 @@ class ModuleTranslation {
   /// it if it does not exist.
   llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
 
+  /// Common (non-template) base class for ModuleTranslation stack frames.
+  class StackFrame {
+  public:
+    virtual ~StackFrame() {}
+    TypeID getTypeID() const { return typeID; }
+
+  protected:
+    explicit StackFrame(TypeID typeID) : typeID(typeID) {}
+
+  private:
+    const TypeID typeID;
+    virtual void anchor();
+  };
+
+  /// Concrete CRTP base class for ModuleTranslation stack frames. When
+  /// translating operations with regions, users of ModuleTranslation can store
+  /// state on ModuleTranslation stack before entering the region and inspect
+  /// it when converting operations nested within that region. Users are
+  /// expected to derive this class and put any relevant information into fields
+  /// of the derived class. The usual isa/dyn_cast functionality is available
+  /// for instances of derived classes.
+  template <typename Derived>
+  class StackFrameBase : public StackFrame {
+  public:
+    explicit StackFrameBase() : StackFrame(TypeID::get<Derived>()) {}
+  };
+
+  /// Creates a stack frame of type `T` on ModuleTranslation stack. `T` must
+  /// be derived from `StackFrameBase<T>` and constructible from the provided
+  /// arguments. Doing this before entering the region of the op being
+  /// translated makes the frame available when translating ops within that
+  /// region.
+  template <typename T, typename... Args>
+  void stackPush(Args &&... args) {
+    static_assert(
+        std::is_base_of<StackFrame, T>::value,
+        "can only push instances of StackFrame on ModuleTranslation stack");
+    stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
+  }
+
+  /// Pops the last element from the ModuleTranslation stack.
+  void stackPop() { stack.pop_back(); }
+
+  /// Calls `callback` for every ModuleTranslation stack frame of type `T`
+  /// starting from the top of the stack; stops early if `callback` interrupts.
+  template <typename T>
+  WalkResult
+  stackWalk(llvm::function_ref<WalkResult(const T &)> callback) const {
+    static_assert(std::is_base_of<StackFrame, T>::value,
+                  "expected T derived from StackFrame");
+    if (!callback)
+      return WalkResult::skip();
+    for (const std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
+      if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
+        WalkResult result = callback(*ptr);
+        if (result.wasInterrupted())
+          return result;
+      }
+    }
+    return WalkResult::advance();
+  }
+
+  /// RAII object calling stackPush/stackPop on construction/destruction.
+  template <typename T>
+  struct SaveStack {
+    template <typename... Args>
+    explicit SaveStack(ModuleTranslation &m, Args &&...args)
+        : moduleTranslation(m) {
+      moduleTranslation.stackPush<T>(std::forward<Args>(args)...);
+    }
+    ~SaveStack() { moduleTranslation.stackPop(); }
+
+  private:
+    ModuleTranslation &moduleTranslation;
+  };
+
 private:
   ModuleTranslation(Operation *module,
                     std::unique_ptr<llvm::Module> llvmModule);
@@ -233,6 +309,10 @@ class ModuleTranslation {
   /// metadata. The metadata is attached to Latch block branches with this
   /// attribute.
   DenseMap<Attribute, llvm::MDNode *> loopOptionsMetadataMapping;
+
+  /// Stack of user-specified state elements, useful when translating operations
+  /// with regions.
+  SmallVector<std::unique_ptr<StackFrame>> stack;
 };
 
 namespace detail {
@@ -270,4 +350,14 @@ llvm::Value *createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
 } // namespace LLVM
 } // namespace mlir
 
+namespace llvm {
+template <typename T>
+struct isa_impl<T, ::mlir::LLVM::ModuleTranslation::StackFrame> {
+  static inline bool
+  doit(const ::mlir::LLVM::ModuleTranslation::StackFrame &frame) {
+    return frame.getTypeID() == ::mlir::TypeID::get<T>();
+  }
+};
+} // namespace llvm
+
 #endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3cd201ad08a5c..6236092c0dae3 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -23,6 +23,42 @@
 
 using namespace mlir;
 
+namespace {
+/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
+/// insertion points for allocas.
+class OpenMPAllocaStackFrame
+    : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
+public:
+  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
+      : allocaInsertPoint(allocaIP) {}
+  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
+};
+} // namespace
+
+/// Find the insertion point for allocas given the current insertion point for
+/// normal operations in the builder.
+static llvm::OpenMPIRBuilder::InsertPointTy
+findAllocaInsertPoint(llvm::IRBuilderBase &builder,
+                      const LLVM::ModuleTranslation &moduleTranslation) {
+  // If there is an alloca insertion point on the stack, i.e. we are in a nested
+  // operation and a specific point was provided by some surrounding operation,
+  // use it.
+  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
+  WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
+      [&](const OpenMPAllocaStackFrame &frame) {
+        allocaInsertPoint = frame.allocaInsertPoint;
+        return WalkResult::interrupt();
+      });
+  if (walkResult.wasInterrupted())
+    return allocaInsertPoint;
+
+  // Otherwise, insert into the entry block of the surrounding function.
+  llvm::BasicBlock &funcEntryBlock =
+      builder.GetInsertBlock()->getParent()->getEntryBlock();
+  return llvm::OpenMPIRBuilder::InsertPointTy(
+      &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
+}
+
 /// Converts the given region that appears within an OpenMP dialect operation to
 /// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
 /// region, and a branch from any block with an successor-less OpenMP terminator
@@ -91,6 +127,11 @@ convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder,
 
   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                        llvm::BasicBlock &continuationBlock) {
+    // Save the alloca insertion point on the ModuleTranslation stack for use
+    // in nested regions.
+    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
+        moduleTranslation, allocaIP);
+
     // ParallelOp has only one region associated with it.
     auto &region = cast<omp::ParallelOp>(opInst).getRegion();
     convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(),
@@ -124,18 +165,14 @@ convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder,
     pbKind = llvm::omp::getProcBindKind(bind.getValue());
   // TODO: Is the Parallel construct cancellable?
   bool isCancellable = false;
-  // TODO: Determine the actual alloca insertion point, e.g., the function
-  // entry or the alloca insertion point as provided by the body callback
-  // above.
-  llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
-  if (failed(bodyGenStatus))
-    return failure();
+
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
       builder.saveIP(), builder.getCurrentDebugLocation());
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel(
-      ompLoc, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads, pbKind,
-      isCancellable));
-  return success();
+      ompLoc, findAllocaInsertPoint(builder, moduleTranslation), bodyGenCB,
+      privCB, finiCB, ifCond, numThreads, pbKind, isCancellable));
+
+  return bodyGenStatus;
 }
 
 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
@@ -233,7 +270,6 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
   // i.e. it has a positive step, uses signed integer semantics. Reconsider
   // this code when WsLoop clearly supports more cases.
-  llvm::BasicBlock *insertBlock = builder.GetInsertBlock();
   llvm::CanonicalLoopInfo *loopInfo =
       moduleTranslation.getOpenMPBuilder()->createCanonicalLoop(
           ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
@@ -241,12 +277,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(bodyGenStatus))
     return failure();
 
-  // TODO: get the alloca insertion point from the parallel operation builder.
-  // If we insert the at the top of the current function, they will be passed as
-  // extra arguments into the function the parallel operation builder outlines.
-  // Put them at the start of the current block for now.
-  llvm::OpenMPIRBuilder::InsertPointTy allocaIP(
-      insertBlock, insertBlock->getFirstInsertionPt());
+  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
+      findAllocaInsertPoint(builder, moduleTranslation);
   llvm::OpenMPIRBuilder::InsertPointTy afterIP;
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   if (isStatic) {
diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index fb5319546adc3..7aa2ffda248e2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -755,6 +755,8 @@ ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
   return llvmModule->getOrInsertNamedMetadata(name);
 }
 
+void ModuleTranslation::StackFrame::anchor() {}
+
 static std::unique_ptr<llvm::Module>
 prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
                   StringRef name) {
diff  --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 71f0277744f37..383ced4529534 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -151,6 +151,13 @@ llvm.func @test_omp_parallel_num_threads_3() -> () {
 // CHECK: define void @test_omp_parallel_if_1(i32 %[[IF_VAR_1:.*]])
 llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () {
 
+// Check that the allocas are emitted by the OpenMPIRBuilder at the top of the
+// function, before the condition. Allocas are only emitted by the builder when
+// the `if` clause is present. We match specific SSA value names since LLVM
+// actually produces those names.
+// CHECK: %tid.addr{{.*}} = alloca i32
+// CHECK: %zero.addr{{.*}} = alloca i32
+
 // CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0
   %0 = llvm.mlir.constant(0 : index) : i32
   %1 = llvm.icmp "slt" %arg0, %0 : i32
@@ -184,6 +191,60 @@ llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () {
 // CHECK: define internal void @[[OMP_OUTLINED_FN_IF_1]]
   // CHECK: call void @__kmpc_barrier
 
+// -----
+
+// CHECK-LABEL: @test_nested_alloca_ip
+llvm.func @test_nested_alloca_ip(%arg0: i32) -> () {
+
+  // Check that the allocas are emitted by the OpenMPIRBuilder at the top of
+  // the function, before the condition. Allocas are only emitted by the
+  // builder when the `if` clause is present. We match specific SSA value names
+  // since LLVM actually produces those names and ensure they come before the
+  // "icmp" that is the first operation we emit.
+  // CHECK: %tid.addr{{.*}} = alloca i32
+  // CHECK: %zero.addr{{.*}} = alloca i32
+  // CHECK: icmp slt i32 %{{.*}}, 0
+  %0 = llvm.mlir.constant(0 : index) : i32
+  %1 = llvm.icmp "slt" %arg0, %0 : i32
+
+  omp.parallel if(%1 : i1) {
+    // The "parallel" operation will be outlined, check that the function is
+    // produced. Inside that function, further allocas should be placed before
+    // another "icmp".
+    // CHECK: define
+    // CHECK: %tid.addr{{.*}} = alloca i32
+    // CHECK: %zero.addr{{.*}} = alloca i32
+    // CHECK: icmp slt i32 %{{.*}}, 1
+    %2 = llvm.mlir.constant(1 : index) : i32
+    %3 = llvm.icmp "slt" %arg0, %2 : i32
+
+    omp.parallel if(%3 : i1) {
+      // One more nesting level.
+      // CHECK: define
+      // CHECK: %tid.addr{{.*}} = alloca i32
+      // CHECK: %zero.addr{{.*}} = alloca i32
+      // CHECK: icmp slt i32 %{{.*}}, 2
+
+      %4 = llvm.mlir.constant(2 : index) : i32
+      %5 = llvm.icmp "slt" %arg0, %4 : i32
+
+      omp.parallel if(%5 : i1) {
+        omp.barrier
+        omp.terminator
+      }
+
+      omp.barrier
+      omp.terminator
+    }
+    omp.barrier
+    omp.terminator
+  }
+
+  llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: define void @test_omp_parallel_3()
 llvm.func @test_omp_parallel_3() -> () {
   // CHECK: [[OMP_THREAD_3_1:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
        
    
    
More information about the Mlir-commits
mailing list