[Mlir-commits] [mlir] [MLIR][OpenMP] Normalize lowering of omp.loop_nest (PR #127217)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 14 07:51:58 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-openmp
Author: Sergio Afonso (skatrak)
<details>
<summary>Changes</summary>
This patch refactors the translation of `omp.loop_nest` operations into LLVM IR so that it is handled similarly to other operations. Before this change, the responsibility of translating the loop nest fell to each loop wrapper, causing code duplication. This patch centralizes the handling of the loop. One consequence of this was fixing an issue with the lowering of non-inclusive `omp.simd` loops.
As a result, it is now expected that the handling of composite constructs is performed collaboratively among translating functions for each operation involved. At the moment, only `do/for simd` is supported by ignoring SIMD information, and this behavior is preserved.
The translation of loop wrapper operations needs access to the `llvm::CanonicalLoopInfo` loop information structure in order to apply transformations to it. This is now created in the nested call to `convertOmpLoopNest`, so it needs to be passed up to all associated loop wrapper translation functions. This is done via the creation of an `OpenMPLoopInfoStackFrame` within `convertOmpLoopNest` and its removal after its outermost associated loop wrapper has been translated.
---
Patch is 36.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127217.diff
3 Files Affected:
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+303-292)
- (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/openmp-simd-private.mlir (+6-3)
``````````diff
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 51a3cbdbb5e7f..a5ff3eff6439f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -75,6 +75,19 @@ class OpenMPAllocaStackFrame
llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
};
+/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
+/// collapsed canonical loop information corresponding to an \c omp.loop_nest
+/// operation.
+class OpenMPLoopInfoStackFrame
+ : public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
+
+ explicit OpenMPLoopInfoStackFrame(llvm::CanonicalLoopInfo *loopInfo)
+ : loopInfo(loopInfo) {}
+ llvm::CanonicalLoopInfo *loopInfo;
+};
+
/// Custom error class to signal translation errors that don't need reporting,
/// since encountering them will have already triggered relevant error messages.
///
@@ -372,6 +385,20 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
&funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
}
+/// Find the loop information structure for the loop nest being translated. It
+/// will not return a value unless called from the translation function for
+/// a loop wrapper operation after successfully translating its body.
+static std::optional<llvm::CanonicalLoopInfo *>
+findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
+ std::optional<llvm::CanonicalLoopInfo *> loopInfo;
+ moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+ [&](const OpenMPLoopInfoStackFrame &frame) {
+ loopInfo = frame.loopInfo;
+ return WalkResult::interrupt();
+ });
+ return loopInfo;
+}
+
/// Converts the given region that appears within an OpenMP dialect operation to
/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
/// region, and a branch from any block with an successor-less OpenMP terminator
@@ -381,6 +408,8 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
+ bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
+
llvm::BasicBlock *continuationBlock =
splitBB(builder, true, "omp.region.cont");
llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
@@ -397,30 +426,34 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
// Terminators (namely YieldOp) may be forwarding values to the region that
// need to be available in the continuation block. Collect the types of these
- // operands in preparation of creating PHI nodes.
+ // operands in preparation of creating PHI nodes. This is skipped for loop
+ // wrapper operations, for which we know in advance they have no terminators.
SmallVector<llvm::Type *> continuationBlockPHITypes;
- bool operandsProcessed = false;
unsigned numYields = 0;
- for (Block &bb : region.getBlocks()) {
- if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
- if (!operandsProcessed) {
- for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
- continuationBlockPHITypes.push_back(
- moduleTranslation.convertType(yield->getOperand(i).getType()));
- }
- operandsProcessed = true;
- } else {
- assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
- "mismatching number of values yielded from the region");
- for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
- llvm::Type *operandType =
- moduleTranslation.convertType(yield->getOperand(i).getType());
- (void)operandType;
- assert(continuationBlockPHITypes[i] == operandType &&
- "values of mismatching types yielded from the region");
+
+ if (!isLoopWrapper) {
+ bool operandsProcessed = false;
+ for (Block &bb : region.getBlocks()) {
+ if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
+ if (!operandsProcessed) {
+ for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+ continuationBlockPHITypes.push_back(
+ moduleTranslation.convertType(yield->getOperand(i).getType()));
+ }
+ operandsProcessed = true;
+ } else {
+ assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
+ "mismatching number of values yielded from the region");
+ for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+ llvm::Type *operandType =
+ moduleTranslation.convertType(yield->getOperand(i).getType());
+ (void)operandType;
+ assert(continuationBlockPHITypes[i] == operandType &&
+ "values of mismatching types yielded from the region");
+ }
}
+ numYields++;
}
- numYields++;
}
}
@@ -458,6 +491,13 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
return llvm::make_error<PreviouslyReportedError>();
+ // Create a direct branch here for loop wrappers to prevent their lack of a
+ // terminator from causing a crash below.
+ if (isLoopWrapper) {
+ builder.CreateBr(continuationBlock);
+ continue;
+ }
+
// Special handling for `omp.yield` and `omp.terminator` (we may have more
// than one): they return the control to the parent OpenMP dialect operation
// so replace them with the branch to the continuation block. We handle this
@@ -509,7 +549,7 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
/// This must be called after block arguments of parent wrappers have already
/// been mapped to LLVM IR values.
static LogicalResult
-convertIgnoredWrapper(omp::LoopWrapperInterface &opInst,
+convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
LLVM::ModuleTranslation &moduleTranslation) {
// Map block arguments directly to the LLVM value associated to the
// corresponding operand. This is semantically equivalent to this wrapper not
@@ -531,34 +571,10 @@ convertIgnoredWrapper(omp::LoopWrapperInterface &opInst,
return success();
})
.Default([&](Operation *op) {
- return op->emitError() << "cannot ignore nested wrapper";
+ return op->emitError() << "cannot ignore wrapper";
});
}
-/// Helper function to call \c convertIgnoredWrapper() for all wrappers of the
-/// given \c loopOp nested inside of \c parentOp. This has the effect of mapping
-/// entry block arguments defined by these operations to outside values.
-///
-/// It must be called after block arguments of \c parentOp have already been
-/// mapped themselves.
-static LogicalResult
-convertIgnoredWrappers(omp::LoopNestOp loopOp,
- omp::LoopWrapperInterface parentOp,
- LLVM::ModuleTranslation &moduleTranslation) {
- SmallVector<omp::LoopWrapperInterface> wrappers;
- loopOp.gatherWrappers(wrappers);
-
- // Process wrappers nested inside of `parentOp` from outermost to innermost.
- for (auto it =
- std::next(std::find(wrappers.rbegin(), wrappers.rend(), parentOp));
- it != wrappers.rend(); ++it) {
- if (failed(convertIgnoredWrapper(*it, moduleTranslation)))
- return failure();
- }
-
- return success();
-}
-
/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -1876,6 +1892,7 @@ convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
static LogicalResult
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
auto wsloopOp = cast<omp::WsloopOp>(opInst);
if (failed(checkImplementationStatus(opInst)))
return failure();
@@ -1956,90 +1973,25 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
reductionVariableMap, isByRef, deferredStores)))
return failure();
- // TODO: Replace this with proper composite translation support.
- // Currently, all nested wrappers are ignored, so 'do/for simd' will be
- // treated the same as a standalone 'do/for'. This is allowed by the spec,
- // since it's equivalent to always using a SIMD length of 1.
- if (failed(convertIgnoredWrappers(loopOp, wsloopOp, moduleTranslation)))
- return failure();
-
- // Set up the source location value for OpenMP runtime.
- llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
-
- // Generator of the canonical loop body.
- SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
- SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
- auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
- llvm::Value *iv) -> llvm::Error {
- // Make sure further conversions know about the induction variable.
- moduleTranslation.mapValue(
- loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
-
- // Capture the body insertion point for use in nested loops. BodyIP of the
- // CanonicalLoopInfo always points to the beginning of the entry block of
- // the body.
- bodyInsertPoints.push_back(ip);
-
- if (loopInfos.size() != loopOp.getNumLoops() - 1)
- return llvm::Error::success();
-
- // Convert the body of the loop.
- builder.restoreIP(ip);
- return convertOmpOpRegions(loopOp.getRegion(), "omp.wsloop.region", builder,
- moduleTranslation)
- .takeError();
- };
-
- // Delegate actual loop construction to the OpenMP IRBuilder.
- // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
- // loop, i.e. it has a positive step, uses signed integer semantics.
- // Reconsider this code when the nested loop operation clearly supports more
- // cases.
- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
- for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
- llvm::Value *lowerBound =
- moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
- llvm::Value *upperBound =
- moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
- llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
-
- // Make sure loop trip count are emitted in the preheader of the outermost
- // loop at the latest so that they are all available for the new collapsed
- // loop will be created below.
- llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
- llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
- if (i != 0) {
- loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back());
- computeIP = loopInfos.front()->getPreheaderIP();
- }
-
- llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
- ompBuilder->createCanonicalLoop(
- loc, bodyGen, lowerBound, upperBound, step,
- /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
-
- if (failed(handleError(loopResult, *loopOp)))
- return failure();
-
- loopInfos.push_back(*loopResult);
- }
-
- // Collapse loops. Store the insertion point because LoopInfos may get
- // invalidated.
- llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
- llvm::CanonicalLoopInfo *loopInfo =
- ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
-
- allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
-
// TODO: Handle doacross loops when the ordered clause has a parameter.
bool isOrdered = wsloopOp.getOrdered().has_value();
std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
bool isSimd = wsloopOp.getScheduleSimd();
+ bool loopNeedsBarrier = !wsloopOp.getNowait();
+
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
+ wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
+
+ if (failed(handleError(regionBlock, opInst)))
+ return failure();
+
+ builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+ llvm::CanonicalLoopInfo *loopInfo = *findCurrentLoopInfo(moduleTranslation);
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
ompBuilder->applyWorkshareLoop(
- ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(),
+ ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
convertToScheduleKind(schedule), chunk, isSimd,
scheduleMod == omp::ScheduleModifier::monotonic,
scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered);
@@ -2047,12 +1999,6 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(wsloopIP, opInst)))
return failure();
- // Continue building IR after the loop. Note that the LoopInfo returned by
- // `collapseLoops` points inside the outermost loop and is intended for
- // potential further loop transformations. Use the insertion point stored
- // before collapsing loops instead.
- builder.restoreIP(afterIP);
-
// Process the reductions if required.
if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
allocaIP, reductionDecls,
@@ -2261,8 +2207,20 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
static LogicalResult
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
auto simdOp = cast<omp::SimdOp>(opInst);
- auto loopOp = cast<omp::LoopNestOp>(simdOp.getWrappedLoop());
+
+ // TODO: Replace this with proper composite translation support.
+ // Currently, simd information on composite constructs is ignored, so e.g.
+ // 'do/for simd' will be treated the same as a standalone 'do/for'. This is
+ // allowed by the spec, since it's equivalent to using a SIMD length of 1.
+ if (simdOp.isComposite()) {
+ if (failed(convertIgnoredWrapper(simdOp, moduleTranslation)))
+ return failure();
+
+ return inlineConvertOmpRegions(simdOp.getRegion(), "omp.simd.region",
+ builder, moduleTranslation);
+ }
if (failed(checkImplementationStatus(opInst)))
return failure();
@@ -2295,6 +2253,61 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
.failed())
return failure();
+ llvm::ConstantInt *simdlen = nullptr;
+ if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
+ simdlen = builder.getInt64(simdlenVar.value());
+
+ llvm::ConstantInt *safelen = nullptr;
+ if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
+ safelen = builder.getInt64(safelenVar.value());
+
+ llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
+ llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
+ llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
+ std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
+ mlir::OperandRange operands = simdOp.getAlignedVars();
+ for (size_t i = 0; i < operands.size(); ++i) {
+ llvm::Value *alignment = nullptr;
+ llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
+ llvm::Type *ty = llvmVal->getType();
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
+ alignment = builder.getInt64(intAttr.getInt());
+ assert(ty->isPointerTy() && "Invalid type for aligned variable");
+ assert(alignment && "Invalid alignment value");
+ auto curInsert = builder.saveIP();
+ builder.SetInsertPoint(sourceBlock);
+ llvmVal = builder.CreateLoad(ty, llvmVal);
+ builder.restoreIP(curInsert);
+ alignedVars[llvmVal] = alignment;
+ }
+ }
+
+ llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
+ simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
+
+ if (failed(handleError(regionBlock, opInst)))
+ return failure();
+
+ builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+ llvm::CanonicalLoopInfo *loopInfo = *findCurrentLoopInfo(moduleTranslation);
+ ompBuilder->applySimd(loopInfo, alignedVars,
+ simdOp.getIfExpr()
+ ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+ : nullptr,
+ order, simdlen, safelen);
+
+ return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
+ llvmPrivateVars, privateDecls);
+}
+
+/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
+static LogicalResult
+convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ auto loopOp = cast<omp::LoopNestOp>(opInst);
+
+ // Set up the source location value for OpenMP runtime.
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Generator of the canonical loop body.
@@ -2316,9 +2329,13 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
// Convert the body of the loop.
builder.restoreIP(ip);
- return convertOmpOpRegions(loopOp.getRegion(), "omp.simd.region", builder,
- moduleTranslation)
- .takeError();
+ llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
+ loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
+ if (!regionBlock)
+ return regionBlock.takeError();
+
+ builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+ return llvm::Error::success();
};
// Delegate actual loop construction to the OpenMP IRBuilder.
@@ -2326,7 +2343,6 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
// loop, i.e. it has a positive step, uses signed integer semantics.
// Reconsider this code when the nested loop operation clearly supports more
// cases.
- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
llvm::Value *lowerBound =
moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
@@ -2348,7 +2364,7 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
ompBuilder->createCanonicalLoop(
loc, bodyGen, lowerBound, upperBound, step,
- /*IsSigned=*/true, /*InclusiveStop=*/true, computeIP);
+ /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
if (failed(handleError(loopResult, *loopOp)))
return failure();
@@ -2356,49 +2372,23 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
loopInfos.push_back(*loopResult);
}
- // Collapse loops.
- llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
- llvm::CanonicalLoopInfo *loopInfo =
- ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
-
- llvm::ConstantInt *simdlen = nullptr;
- if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
- simdlen = builder.getInt64(simdlenVar.value());
+ // Collapse loops. Store the insertion point because LoopInfos may get
+ // invalidated.
+ llvm::OpenMPIRBuilder::InsertPointTy afterIP =
+ loopInfos.front()->getAfterIP();
- llvm::ConstantInt *safelen = nullptr;
- if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
- safelen = builder.getInt64(safelenVar.value());
-
- llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
- llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
- llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
- std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
- mlir::OperandRange...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/127217
More information about the Mlir-commits
mailing list