[Mlir-commits] [mlir] a71a0d6 - [OpenMP][MLIR] Fix for nested parallel regions

Kiran Chandramohan llvmlistbot at llvm.org
Mon Oct 19 00:49:52 PDT 2020


Author: Kiran Chandramohan
Date: 2020-10-19T08:45:50+01:00
New Revision: a71a0d6d219a8d460e6aadadc47dbb04955f1375

URL: https://github.com/llvm/llvm-project/commit/a71a0d6d219a8d460e6aadadc47dbb04955f1375
DIFF: https://github.com/llvm/llvm-project/commit/a71a0d6d219a8d460e6aadadc47dbb04955f1375.diff

LOG: [OpenMP][MLIR] Fix for nested parallel regions

Usage of nested parallel regions was not working correctly and led
to assertion failures. The fix contains the following changes:
1) Don't set the insertion point in the body callback.
2) Save the continuation IP in a stack and set the branch to
continuationIP at the terminator.

Reviewed By: SouraVX, jdoerfert, ftynse

Differential Revision: https://reviews.llvm.org/D88720

Added: 
    

Modified: 
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Target/openmp-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index d28c7968c5ae..7fc66f18d53c 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -127,6 +127,10 @@ class ModuleTranslation {
   /// OpenMP dialect hasn't been loaded (it is always loaded if there are OpenMP
   /// operations in the module though).
   const Dialect *ompDialect;
+  /// Stack which stores the target block to which a branch must be added when
+  /// a terminator is seen. A stack is required to handle nested OpenMP parallel
+  /// regions.
+  SmallVector<llvm::BasicBlock *, 4> ompContinuationIPStack;
 
   /// Mappings between llvm.mlir.global definitions and corresponding globals.
   DenseMap<Operation *, llvm::GlobalValue *> globalsMapping;

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 23f5698b80a2..3d5e091e58e6 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -391,8 +391,8 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
 
     llvm::BasicBlock *codeGenIPBB = codeGenIP.getBlock();
     llvm::Instruction *codeGenIPBBTI = codeGenIPBB->getTerminator();
+    ompContinuationIPStack.push_back(&continuationIP);
 
-    builder.SetInsertPoint(codeGenIPBB);
     // ParallelOp has only `1` region associated with it.
     auto &region = cast<omp::ParallelOp>(opInst).getRegion();
     for (auto &bb : region) {
@@ -407,22 +407,22 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
     for (auto indexedBB : llvm::enumerate(blocks)) {
       Block *bb = indexedBB.value();
       llvm::BasicBlock *curLLVMBB = blockMapping[bb];
-      if (bb->isEntryBlock())
+      if (bb->isEntryBlock()) {
+        assert(codeGenIPBBTI->getNumSuccessors() == 1 &&
+               "OpenMPIRBuilder provided entry block has multiple successors");
+        assert(codeGenIPBBTI->getSuccessor(0) == &continuationIP &&
+               "ContinuationIP is not the successor of OpenMPIRBuilder "
+               "provided entry block");
         codeGenIPBBTI->setSuccessor(0, curLLVMBB);
+      }
 
       // TODO: Error not returned up the hierarchy
       if (failed(convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
         return;
-
-      // If this block has the terminator then add a jump to
-      // continuation bb
-      for (auto &op : *bb) {
-        if (isa<omp::TerminatorOp>(op)) {
-          builder.SetInsertPoint(curLLVMBB);
-          builder.CreateBr(&continuationIP);
-        }
-      }
     }
+
+    ompContinuationIPStack.pop_back();
+
     // Finally, after all blocks have been traversed and values mapped,
     // connect the PHI nodes to the results of preceding blocks.
     connectPHINodes(region, valueMapping, blockMapping);
@@ -498,7 +498,10 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
         ompBuilder->CreateFlush(builder.saveIP());
         return success();
       })
-      .Case([&](omp::TerminatorOp) { return success(); })
+      .Case([&](omp::TerminatorOp) {
+        builder.CreateBr(ompContinuationIPStack.back());
+        return success();
+      })
       .Case(
           [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); })
       .Default([&](Operation *inst) {

diff  --git a/mlir/test/Target/openmp-llvm.mlir b/mlir/test/Target/openmp-llvm.mlir
index 518fa6fce2d4..4b3de0b052e5 100644
--- a/mlir/test/Target/openmp-llvm.mlir
+++ b/mlir/test/Target/openmp-llvm.mlir
@@ -214,3 +214,53 @@ llvm.func @test_omp_parallel_3() -> () {
 // CHECK: define internal void @[[OMP_OUTLINED_FN_3_3]]
 // CHECK: define internal void @[[OMP_OUTLINED_FN_3_2]]
 // CHECK: define internal void @[[OMP_OUTLINED_FN_3_1]]
+
+// CHECK-LABEL: define void @test_omp_parallel_4()
+llvm.func @test_omp_parallel_4() -> () {
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_4_1:.*]] to
+// CHECK: define internal void @[[OMP_OUTLINED_FN_4_1]]
+// CHECK: call void @__kmpc_barrier
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_4_1_1:.*]] to
+// CHECK: call void @__kmpc_barrier
+  omp.parallel {
+    omp.barrier
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_4_1_1]]
+// CHECK: call void @__kmpc_barrier
+    omp.parallel {
+      omp.barrier
+      omp.terminator
+    }
+
+    omp.barrier
+    omp.terminator
+  }
+  llvm.return
+}
+
+llvm.func @test_omp_parallel_5() -> () {
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_5_1:.*]] to
+// CHECK: define internal void @[[OMP_OUTLINED_FN_5_1]]
+// CHECK: call void @__kmpc_barrier
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_5_1_1:.*]] to
+// CHECK: call void @__kmpc_barrier
+  omp.parallel {
+    omp.barrier
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_5_1_1]]
+    omp.parallel {
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_5_1_1_1:.*]] to
+// CHECK: define internal void @[[OMP_OUTLINED_FN_5_1_1_1]]
+// CHECK: call void @__kmpc_barrier
+      omp.parallel {
+        omp.barrier
+        omp.terminator
+      }
+      omp.terminator
+    }
+
+    omp.barrier
+    omp.terminator
+  }
+  llvm.return
+}


        


More information about the Mlir-commits mailing list