[Mlir-commits] [mlir] [MLIR][OpenMP] Prevent loop wrapper translation crashes (PR #115475)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 8 05:00:15 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

<details>
<summary>Changes</summary>

This patch updates the `convertOmpOpRegions` translation function so that calling it on a loop wrapper region no longer crashes the compiler because of that region's lack of a terminator operation.

This problem is not currently triggered, because the region of a loop wrapper is never passed to that function. That might have to change in order to support translating composite constructs to LLVM IR.

---
Full diff: https://github.com/llvm/llvm-project/pull/115475.diff


1 File Affected:

- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+33-20) 


``````````diff
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index da11ee9960e1f9..b507fa656d601f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -391,6 +391,8 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
     Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
     LLVM::ModuleTranslation &moduleTranslation,
     SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
+  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
+
   llvm::BasicBlock *continuationBlock =
       splitBB(builder, true, "omp.region.cont");
   llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
@@ -407,30 +409,34 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
 
   // Terminators (namely YieldOp) may be forwarding values to the region that
   // need to be available in the continuation block. Collect the types of these
-  // operands in preparation of creating PHI nodes.
+  // operands in preparation of creating PHI nodes. This is skipped for loop
+  // wrapper operations, for which we know in advance they have no terminators.
   SmallVector<llvm::Type *> continuationBlockPHITypes;
-  bool operandsProcessed = false;
   unsigned numYields = 0;
-  for (Block &bb : region.getBlocks()) {
-    if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
-      if (!operandsProcessed) {
-        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
-          continuationBlockPHITypes.push_back(
-              moduleTranslation.convertType(yield->getOperand(i).getType()));
-        }
-        operandsProcessed = true;
-      } else {
-        assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
-               "mismatching number of values yielded from the region");
-        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
-          llvm::Type *operandType =
-              moduleTranslation.convertType(yield->getOperand(i).getType());
-          (void)operandType;
-          assert(continuationBlockPHITypes[i] == operandType &&
-                 "values of mismatching types yielded from the region");
+
+  if (!isLoopWrapper) {
+    bool operandsProcessed = false;
+    for (Block &bb : region.getBlocks()) {
+      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
+        if (!operandsProcessed) {
+          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+            continuationBlockPHITypes.push_back(
+                moduleTranslation.convertType(yield->getOperand(i).getType()));
+          }
+          operandsProcessed = true;
+        } else {
+          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
+                 "mismatching number of values yielded from the region");
+          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+            llvm::Type *operandType =
+                moduleTranslation.convertType(yield->getOperand(i).getType());
+            (void)operandType;
+            assert(continuationBlockPHITypes[i] == operandType &&
+                   "values of mismatching types yielded from the region");
+          }
         }
+        numYields++;
       }
-      numYields++;
     }
   }
 
@@ -468,6 +474,13 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
             moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
       return llvm::make_error<PreviouslyReportedError>();
 
+    // Create a direct branch here for loop wrappers to prevent their lack of a
+    // terminator from causing a crash below.
+    if (isLoopWrapper) {
+      builder.CreateBr(continuationBlock);
+      continue;
+    }
+
     // Special handling for `omp.yield` and `omp.terminator` (we may have more
     // than one): they return the control to the parent OpenMP dialect operation
     // so replace them with the branch to the continuation block. We handle this

``````````

</details>


https://github.com/llvm/llvm-project/pull/115475


More information about the Mlir-commits mailing list