[Mlir-commits] [mlir] [MLIR][OpenMP] Prevent loop wrapper translation crashes (PR #115475)

Sergio Afonso llvmlistbot at llvm.org
Fri Nov 8 04:59:40 PST 2024


https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/115475

This patch updates the `convertOmpOpRegions` translation function so that calling it on a loop wrapper region no longer crashes the compiler due to the region's lack of terminator operations.

This problem is currently not triggered because there are no cases for which the region of a loop wrapper is passed to that function. This might have to change in order to support composite construct translation to LLVM IR.

>From a18ac6599fb602c8b0abb17b9887a39c51d84e8e Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 8 Nov 2024 12:46:34 +0000
Subject: [PATCH] [MLIR][OpenMP] Prevent loop wrapper translation crashes

This patch updates the `convertOmpOpRegions` translation function so that
calling it on a loop wrapper region no longer crashes the compiler due to the
region's lack of terminator operations.

This problem is currently not triggered because there are no cases for which
the region of a loop wrapper is passed to that function. This might have to
change in order to support composite construct translation to LLVM IR.
---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 53 ++++++++++++-------
 1 file changed, 33 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index da11ee9960e1f9..b507fa656d601f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -391,6 +391,8 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
     Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
     LLVM::ModuleTranslation &moduleTranslation,
     SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
+  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
+
   llvm::BasicBlock *continuationBlock =
       splitBB(builder, true, "omp.region.cont");
   llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
@@ -407,30 +409,34 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
 
   // Terminators (namely YieldOp) may be forwarding values to the region that
   // need to be available in the continuation block. Collect the types of these
-  // operands in preparation of creating PHI nodes.
+  // operands in preparation of creating PHI nodes. This is skipped for loop
+  // wrapper operations, for which we know in advance they have no terminators.
   SmallVector<llvm::Type *> continuationBlockPHITypes;
-  bool operandsProcessed = false;
   unsigned numYields = 0;
-  for (Block &bb : region.getBlocks()) {
-    if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
-      if (!operandsProcessed) {
-        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
-          continuationBlockPHITypes.push_back(
-              moduleTranslation.convertType(yield->getOperand(i).getType()));
-        }
-        operandsProcessed = true;
-      } else {
-        assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
-               "mismatching number of values yielded from the region");
-        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
-          llvm::Type *operandType =
-              moduleTranslation.convertType(yield->getOperand(i).getType());
-          (void)operandType;
-          assert(continuationBlockPHITypes[i] == operandType &&
-                 "values of mismatching types yielded from the region");
+
+  if (!isLoopWrapper) {
+    bool operandsProcessed = false;
+    for (Block &bb : region.getBlocks()) {
+      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
+        if (!operandsProcessed) {
+          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+            continuationBlockPHITypes.push_back(
+                moduleTranslation.convertType(yield->getOperand(i).getType()));
+          }
+          operandsProcessed = true;
+        } else {
+          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
+                 "mismatching number of values yielded from the region");
+          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+            llvm::Type *operandType =
+                moduleTranslation.convertType(yield->getOperand(i).getType());
+            (void)operandType;
+            assert(continuationBlockPHITypes[i] == operandType &&
+                   "values of mismatching types yielded from the region");
+          }
         }
+        numYields++;
       }
-      numYields++;
     }
   }
 
@@ -468,6 +474,13 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
             moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
       return llvm::make_error<PreviouslyReportedError>();
 
+    // Create a direct branch here for loop wrappers to prevent their lack of a
+    // terminator from causing a crash below.
+    if (isLoopWrapper) {
+      builder.CreateBr(continuationBlock);
+      continue;
+    }
+
     // Special handling for `omp.yield` and `omp.terminator` (we may have more
     // than one): they return the control to the parent OpenMP dialect operation
     // so replace them with the branch to the continuation block. We handle this



More information about the Mlir-commits mailing list