[Mlir-commits] [mlir] [MLIR][OpenMP] Prevent loop wrapper translation crashes (PR #115475)
Sergio Afonso
llvmlistbot at llvm.org
Fri Nov 8 04:59:40 PST 2024
https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/115475
This patch updates the `convertOmpOpRegions` translation function so that calling it on a loop wrapper region does not crash the compiler due to the region's lack of terminator operations.
This problem is not currently triggered, because no existing code path passes the region of a loop wrapper to that function. That may have to change in order to support translating composite constructs to LLVM IR.
>From a18ac6599fb602c8b0abb17b9887a39c51d84e8e Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 8 Nov 2024 12:46:34 +0000
Subject: [PATCH] [MLIR][OpenMP] Prevent loop wrapper translation crashes
This patch updates the `convertOmpOpRegions` translation function so that
calling it on a loop wrapper region does not crash the compiler due to the
region's lack of terminator operations.
This problem is not currently triggered, because no existing code path passes
the region of a loop wrapper to that function. That may have to change in
order to support translating composite constructs to LLVM IR.
---
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 53 ++++++++++++-------
1 file changed, 33 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index da11ee9960e1f9..b507fa656d601f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -391,6 +391,8 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
+ bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
+
llvm::BasicBlock *continuationBlock =
splitBB(builder, true, "omp.region.cont");
llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
@@ -407,30 +409,34 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
// Terminators (namely YieldOp) may be forwarding values to the region that
// need to be available in the continuation block. Collect the types of these
- // operands in preparation of creating PHI nodes.
+ // operands in preparation of creating PHI nodes. This is skipped for loop
+ // wrapper operations, for which we know in advance they have no terminators.
SmallVector<llvm::Type *> continuationBlockPHITypes;
- bool operandsProcessed = false;
unsigned numYields = 0;
- for (Block &bb : region.getBlocks()) {
- if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
- if (!operandsProcessed) {
- for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
- continuationBlockPHITypes.push_back(
- moduleTranslation.convertType(yield->getOperand(i).getType()));
- }
- operandsProcessed = true;
- } else {
- assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
- "mismatching number of values yielded from the region");
- for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
- llvm::Type *operandType =
- moduleTranslation.convertType(yield->getOperand(i).getType());
- (void)operandType;
- assert(continuationBlockPHITypes[i] == operandType &&
- "values of mismatching types yielded from the region");
+
+ if (!isLoopWrapper) {
+ bool operandsProcessed = false;
+ for (Block &bb : region.getBlocks()) {
+ if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
+ if (!operandsProcessed) {
+ for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+ continuationBlockPHITypes.push_back(
+ moduleTranslation.convertType(yield->getOperand(i).getType()));
+ }
+ operandsProcessed = true;
+ } else {
+ assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
+ "mismatching number of values yielded from the region");
+ for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+ llvm::Type *operandType =
+ moduleTranslation.convertType(yield->getOperand(i).getType());
+ (void)operandType;
+ assert(continuationBlockPHITypes[i] == operandType &&
+ "values of mismatching types yielded from the region");
+ }
}
+ numYields++;
}
- numYields++;
}
}
@@ -468,6 +474,13 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
return llvm::make_error<PreviouslyReportedError>();
+ // Create a direct branch here for loop wrappers to prevent their lack of a
+ // terminator from causing a crash below.
+ if (isLoopWrapper) {
+ builder.CreateBr(continuationBlock);
+ continue;
+ }
+
// Special handling for `omp.yield` and `omp.terminator` (we may have more
// than one): they return the control to the parent OpenMP dialect operation
// so replace them with the branch to the continuation block. We handle this
More information about the Mlir-commits
mailing list