[llvm-branch-commits] [mlir] 10164a2 - [mlir] Refactor translation of OpenMP dialect ops to LLVM IR

Alex Zinenko via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 7 04:38:21 PST 2021


Author: Alex Zinenko
Date: 2021-01-07T13:33:50+01:00
New Revision: 10164a2e50b4d7064bd02e7403aae6dd319cdd64

URL: https://github.com/llvm/llvm-project/commit/10164a2e50b4d7064bd02e7403aae6dd319cdd64
DIFF: https://github.com/llvm/llvm-project/commit/10164a2e50b4d7064bd02e7403aae6dd319cdd64.diff

LOG: [mlir] Refactor translation of OpenMP dialect ops to LLVM IR

The original implementation of the OpenMP dialect to LLVM IR translation has
been relying on a stack of insertion points for delayed insertion of branch
instructions that correspond to terminator ops. This is intrusive to
ModuleTranslation and makes the translation non-local. A recent addition of the
WsLoop translation exercised another approach where the parent op is
responsible for converting terminators of all blocks in its regions. Use this
approach for other OpenMP dialect operations with regions, remove the stack and
deduplicate the code for converting such regions.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D94086

Added: 
    

Modified: 
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 7691adfeef14..4a1871cac4dc 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -93,11 +93,11 @@ class ModuleTranslation {
                                            llvm::IRBuilder<> &builder);
   virtual LogicalResult convertOmpMaster(Operation &op,
                                          llvm::IRBuilder<> &builder);
-  void convertOmpOpRegions(Region &region,
+  void convertOmpOpRegions(Region &region, StringRef blockName,
                            DenseMap<Value, llvm::Value *> &valueMapping,
                            DenseMap<Block *, llvm::BasicBlock *> &blockMapping,
-                           llvm::Instruction *codeGenIPBBTI,
-                           llvm::BasicBlock &continuationIP,
+                           llvm::BasicBlock &sourceBlock,
+                           llvm::BasicBlock &continuationBlock,
                            llvm::IRBuilder<> &builder,
                            LogicalResult &bodyGenStatus);
   virtual LogicalResult convertOmpWsLoop(Operation &opInst,
@@ -121,7 +121,8 @@ class ModuleTranslation {
   LogicalResult convertFunctions();
   LogicalResult convertGlobals();
   LogicalResult convertOneFunction(LLVMFuncOp func);
-  LogicalResult convertBlock(Block &bb, bool ignoreArguments);
+  LogicalResult convertBlock(Block &bb, bool ignoreArguments,
+                             llvm::IRBuilder<> &builder);
 
   llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
                                   Location loc);
@@ -134,14 +135,11 @@ class ModuleTranslation {
 
   /// Builder for LLVM IR generation of OpenMP constructs.
   std::unique_ptr<llvm::OpenMPIRBuilder> ompBuilder;
+
   /// Precomputed pointer to OpenMP dialect. Note this can be nullptr if the
   /// OpenMP dialect hasn't been loaded (it is always loaded if there are OpenMP
   /// operations in the module though).
   const Dialect *ompDialect;
-  /// Stack which stores the target block to which a branch a must be added when
-  /// a terminator is seen. A stack is required to handle nested OpenMP parallel
-  /// regions.
-  SmallVector<llvm::BasicBlock *, 4> ompContinuationIPStack;
 
   /// Mappings between llvm.mlir.global definitions and corresponding globals.
   DenseMap<Operation *, llvm::GlobalValue *> globalsMapping;

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index da9c734fbd80..5ffb11e76a93 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -413,24 +413,11 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
 
   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                        llvm::BasicBlock &continuationIP) {
-    llvm::LLVMContext &llvmContext = llvmModule->getContext();
-
-    llvm::BasicBlock *codeGenIPBB = codeGenIP.getBlock();
-    llvm::Instruction *codeGenIPBBTI = codeGenIPBB->getTerminator();
-    ompContinuationIPStack.push_back(&continuationIP);
-
-    // ParallelOp has only `1` region associated with it.
+    // ParallelOp has only one region associated with it.
     auto &region = cast<omp::ParallelOp>(opInst).getRegion();
-    for (auto &bb : region) {
-      auto *llvmBB = llvm::BasicBlock::Create(
-          llvmContext, "omp.par.region", codeGenIP.getBlock()->getParent());
-      blockMapping[&bb] = llvmBB;
-    }
-
-    convertOmpOpRegions(region, valueMapping, blockMapping, codeGenIPBBTI,
-                        continuationIP, builder, bodyGenStatus);
-    ompContinuationIPStack.pop_back();
-
+    convertOmpOpRegions(region, "omp.par.region", valueMapping, blockMapping,
+                        *codeGenIP.getBlock(), continuationIP, builder,
+                        bodyGenStatus);
   };
 
   // TODO: Perform appropriate actions according to the data-sharing
@@ -472,29 +459,50 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
 }
 
 void ModuleTranslation::convertOmpOpRegions(
-    Region &region, DenseMap<Value, llvm::Value *> &valueMapping,
+    Region &region, StringRef blockName,
+    DenseMap<Value, llvm::Value *> &valueMapping,
     DenseMap<Block *, llvm::BasicBlock *> &blockMapping,
-    llvm::Instruction *codeGenIPBBTI, llvm::BasicBlock &continuationIP,
+    llvm::BasicBlock &sourceBlock, llvm::BasicBlock &continuationBlock,
     llvm::IRBuilder<> &builder, LogicalResult &bodyGenStatus) {
+  llvm::LLVMContext &llvmContext = builder.getContext();
+  for (Block &bb : region) {
+    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
+        llvmContext, blockName, builder.GetInsertBlock()->getParent());
+    blockMapping[&bb] = llvmBB;
+  }
+
+  llvm::Instruction *sourceTerminator = sourceBlock.getTerminator();
+
   // Convert blocks one by one in topological order to ensure
   // defs are converted before uses.
   llvm::SetVector<Block *> blocks = topologicalSort(region);
-  for (auto indexedBB : llvm::enumerate(blocks)) {
-    Block *bb = indexedBB.value();
-    llvm::BasicBlock *curLLVMBB = blockMapping[bb];
+  for (Block *bb : blocks) {
+    llvm::BasicBlock *llvmBB = blockMapping[bb];
+    // Retarget the branch of the entry block to the entry block of the
+    // converted region (regions are single-entry).
     if (bb->isEntryBlock()) {
-      assert(codeGenIPBBTI->getNumSuccessors() == 1 &&
-             "OpenMPIRBuilder provided entry block has multiple successors");
-      assert(codeGenIPBBTI->getSuccessor(0) == &continuationIP &&
-             "ContinuationIP is not the successor of OpenMPIRBuilder "
-             "provided entry block");
-      codeGenIPBBTI->setSuccessor(0, curLLVMBB);
+      assert(sourceTerminator->getNumSuccessors() == 1 &&
+             "provided entry block has multiple successors");
+      assert(sourceTerminator->getSuccessor(0) == &continuationBlock &&
+             "ContinuationBlock is not the successor of the entry block");
+      sourceTerminator->setSuccessor(0, llvmBB);
     }
 
-    if (failed(convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0))) {
+    llvm::IRBuilder<>::InsertPointGuard guard(builder);
+    if (failed(convertBlock(*bb, bb->isEntryBlock(), builder))) {
       bodyGenStatus = failure();
       return;
     }
+
+    // Special handling for `omp.yield` and `omp.terminator` (we may have more
+    // than one): they return the control to the parent OpenMP dialect operation
+    // so replace them with the branch to the continuation block. We handle this
+    // here to avoid relying on inter-function communication through the
+    // ModuleTranslation class to set up the correct insertion point. This is
+    // also consistent with MLIR's idiom of handling special region terminators
+    // in the same code that handles the region-owning operation.
+    if (isa<omp::TerminatorOp, omp::YieldOp>(bb->getTerminator()))
+      builder.CreateBr(&continuationBlock);
   }
   // Finally, after all blocks have been traversed and values mapped,
   // connect the PHI nodes to the results of preceding blocks.
@@ -510,22 +518,11 @@ LogicalResult ModuleTranslation::convertOmpMaster(Operation &opInst,
 
   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                        llvm::BasicBlock &continuationIP) {
-    llvm::LLVMContext &llvmContext = llvmModule->getContext();
-
-    llvm::BasicBlock *codeGenIPBB = codeGenIP.getBlock();
-    llvm::Instruction *codeGenIPBBTI = codeGenIPBB->getTerminator();
-    ompContinuationIPStack.push_back(&continuationIP);
-
-    // MasterOp has only `1` region associated with it.
+    // MasterOp has only one region associated with it.
     auto &region = cast<omp::MasterOp>(opInst).getRegion();
-    for (auto &bb : region) {
-      auto *llvmBB = llvm::BasicBlock::Create(
-          llvmContext, "omp.master.region", codeGenIP.getBlock()->getParent());
-      blockMapping[&bb] = llvmBB;
-    }
-    convertOmpOpRegions(region, valueMapping, blockMapping, codeGenIPBBTI,
-                        continuationIP, builder, bodyGenStatus);
-    ompContinuationIPStack.pop_back();
+    convertOmpOpRegions(region, "omp.master.region", valueMapping, blockMapping,
+                        *codeGenIP.getBlock(), continuationIP, builder,
+                        bodyGenStatus);
   };
 
   // TODO: Perform finalization actions for variables. This has to be
@@ -553,9 +550,6 @@ LogicalResult ModuleTranslation::convertOmpWsLoop(Operation &opInst,
     return opInst.emitOpError(
         "only static (default) loop schedule is currently supported");
 
-  llvm::Function *func = builder.GetInsertBlock()->getParent();
-  llvm::LLVMContext &llvmContext = llvmModule->getContext();
-
   // Find the loop configuration.
   llvm::Value *lowerBound = valueMapping.lookup(loop.lowerBound()[0]);
   llvm::Value *upperBound = valueMapping.lookup(loop.upperBound()[0]);
@@ -589,44 +583,9 @@ LogicalResult ModuleTranslation::convertOmpWsLoop(Operation &opInst,
         entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
 
     // Convert the body of the loop.
-    Region &region = loop.region();
-    for (Block &bb : region) {
-      llvm::BasicBlock *llvmBB =
-          llvm::BasicBlock::Create(llvmContext, "omp.wsloop.region", func);
-      blockMapping[&bb] = llvmBB;
-
-      // Retarget the branch of the entry block to the entry block of the
-      // converted region (regions are single-entry).
-      if (bb.isEntryBlock()) {
-        auto *branch = cast<llvm::BranchInst>(entryBlock->getTerminator());
-        branch->setSuccessor(0, llvmBB);
-      }
-    }
-
-    // Block conversion creates a new IRBuilder every time so need not bother
-    // about maintaining the insertion point.
-    llvm::SetVector<Block *> blocks = topologicalSort(region);
-    for (Block *bb : blocks) {
-      if (failed(convertBlock(*bb, bb->isEntryBlock()))) {
-        bodyGenStatus = failure();
-        return;
-      }
-
-      // Special handling for `omp.yield` terminators (we may have more than
-      // one): they return the control to the parent WsLoop operation so replace
-      // them with the branch to the exit block. We handle this here to avoid
-      // relying inter-function communication through the ModuleTranslation
-      // class to set up the correct insertion point. This is also consistent
-      // with MLIR's idiom of handling special region terminators in the same
-      // code that handles the region-owning operation.
-      if (isa<omp::YieldOp>(bb->getTerminator())) {
-        llvm::BasicBlock *llvmBB = blockMapping[bb];
-        builder.SetInsertPoint(llvmBB, llvmBB->end());
-        builder.CreateBr(exitBlock);
-      }
-    }
-
-    connectPHINodes(region, valueMapping, blockMapping, branchMapping);
+    convertOmpOpRegions(loop.region(), "omp.wsloop.region", valueMapping,
+                        blockMapping, *entryBlock, *exitBlock, builder,
+                        bodyGenStatus);
   };
 
   // Delegate actual loop construction to the OpenMP IRBuilder.
@@ -690,18 +649,15 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
         ompBuilder->createFlush(builder.saveIP());
         return success();
       })
-      .Case([&](omp::TerminatorOp) {
-        builder.CreateBr(ompContinuationIPStack.back());
-        return success();
-      })
       .Case(
           [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); })
       .Case([&](omp::MasterOp) { return convertOmpMaster(opInst, builder); })
       .Case([&](omp::WsLoopOp) { return convertOmpWsLoop(opInst, builder); })
-      .Case([&](omp::YieldOp op) {
-        // Yields are loop terminators that can be just omitted. The loop
-        // structure was created in the function that handles WsLoopOp.
-        assert(op.getNumOperands() == 0 && "unexpected yield with operands");
+      .Case<omp::YieldOp, omp::TerminatorOp>([](auto op) {
+        // `yield` and `terminator` can be just omitted. The block structure was
+        // created in the function that handles their parent operation.
+        assert(op->getNumOperands() == 0 &&
+               "unexpected OpenMP terminator with operands");
         return success();
       })
       .Default([&](Operation *inst) {
@@ -911,9 +867,14 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
 
 /// Convert block to LLVM IR.  Unless `ignoreArguments` is set, emit PHI nodes
 /// to define values corresponding to the MLIR block arguments.  These nodes
-/// are not connected to the source basic blocks, which may not exist yet.
-LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
-  llvm::IRBuilder<> builder(blockMapping[&bb]);
+/// are not connected to the source basic blocks, which may not exist yet.  Uses
+/// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have
+/// been created for `bb` and included in the block mapping.  Inserts new
+/// instructions at the end of the block and leaves `builder` in a state
+/// suitable for further insertion into the end of the block.
+LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
+                                              llvm::IRBuilder<> &builder) {
+  builder.SetInsertPoint(blockMapping[&bb]);
   auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
 
   // Before traversing operations, make block arguments available through
@@ -1137,9 +1098,9 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   // Then, convert blocks one by one in topological order to ensure defs are
   // converted before uses.
   auto blocks = topologicalSort(func);
-  for (auto indexedBB : llvm::enumerate(blocks)) {
-    auto *bb = indexedBB.value();
-    if (failed(convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
+  for (Block *bb : blocks) {
+    llvm::IRBuilder<> builder(llvmContext);
+    if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))
       return failure();
   }
 


        


More information about the llvm-branch-commits mailing list