[llvm-branch-commits] [llvm] [OpenMPIRBuilder] Split calculation of canonical loop trip count, NFC (PR #127820)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Feb 21 02:08:24 PST 2025
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/127820
>From 082d8e12a622e2315dd4503ce460f9a0e6f29007 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 18 Feb 2025 14:19:30 +0000
Subject: [PATCH] [OpenMPIRBuilder] Split calculation of canonical loop trip
count, NFC
This patch splits off the calculation of canonical loop trip counts from the
creation of canonical loops. This makes it possible to reuse this logic to, for
instance, populate the `__tgt_target_kernel` runtime call for SPMD kernels.
The new function is also used to simplify one of the existing OpenMPIRBuilder tests.
---
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 38 +++++++++++++++----
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 27 ++++++++-----
.../Frontend/OpenMPIRBuilderTest.cpp | 16 ++------
3 files changed, 52 insertions(+), 29 deletions(-)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 9ad85413acd34..207ca7fb05f62 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -728,13 +728,12 @@ class OpenMPIRBuilder {
LoopBodyGenCallbackTy BodyGenCB, Value *TripCount,
const Twine &Name = "loop");
- /// Generator for the control flow structure of an OpenMP canonical loop.
+ /// Calculate the trip count of a canonical loop.
///
- /// Instead of a logical iteration space, this allows specifying user-defined
- /// loop counter values using increment, upper- and lower bounds. To
- /// disambiguate the terminology when counting downwards, instead of lower
- /// bounds we use \p Start for the loop counter value in the first body
- /// iteration.
+ /// This allows specifying user-defined loop counter values using increment,
+ /// upper- and lower bounds. To disambiguate the terminology when counting
+ /// downwards, instead of lower bounds we use \p Start for the loop counter
+ /// value in the first body iteration.
///
/// Consider the following limitations:
///
@@ -758,7 +757,32 @@ class OpenMPIRBuilder {
///
/// for (int i = 0; i < 42; i -= 1u)
///
- //
+ /// \param Loc The insert and source location description.
+ /// \param Start Value of the loop counter for the first iteration.
+ /// \param Stop Loop counter values past this will stop the loop.
+ /// \param Step Loop counter increment after each iteration; negative
+ /// means counting down.
+ /// \param IsSigned Whether Start, Stop and Step are signed integers.
+ /// \param InclusiveStop Whether \p Stop itself is a valid value for the loop
+ /// counter.
+ /// \param Name Base name used to derive instruction names.
+ ///
+ /// \returns The value holding the calculated trip count.
+ Value *calculateCanonicalLoopTripCount(const LocationDescription &Loc,
+ Value *Start, Value *Stop, Value *Step,
+ bool IsSigned, bool InclusiveStop,
+ const Twine &Name = "loop");
+
+ /// Generator for the control flow structure of an OpenMP canonical loop.
+ ///
+ /// Instead of a logical iteration space, this allows specifying user-defined
+ /// loop counter values using increment, upper- and lower bounds. To
+ /// disambiguate the terminology when counting downwards, \p Start denotes
+ /// the loop counter value in the first body iteration (not a lower bound).
+ ///
+ /// It calls calculateCanonicalLoopTripCount() for trip count calculations,
+ /// so the limitations of that method apply here as well.
+ ///
/// \param Loc The insert and source location description.
/// \param BodyGenCB Callback that will generate the loop body code.
/// \param Start Value of the loop counter for the first iterations.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 7788897fc0795..eee6e3e54d615 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4059,10 +4059,9 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
return CL;
}
-Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
- const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
- Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
- InsertPointTy ComputeIP, const Twine &Name) {
+Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
+ const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
+ bool IsSigned, bool InclusiveStop, const Twine &Name) {
// Consider the following difficulties (assuming 8-bit signed integers):
// * Adding \p Step to the loop counter which passes \p Stop may overflow:
@@ -4075,9 +4074,7 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
assert(IndVarTy == Stop->getType() && "Stop type mismatch");
assert(IndVarTy == Step->getType() && "Step type mismatch");
- LocationDescription ComputeLoc =
- ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
- updateToLocation(ComputeLoc);
+ updateToLocation(Loc);
ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
ConstantInt *One = ConstantInt::get(IndVarTy, 1);
@@ -4117,8 +4114,20 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
}
- Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
- "omp_" + Name + ".tripcount");
+
+ return Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
+ "omp_" + Name + ".tripcount");
+}
+
+Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
+ const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+ Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+ InsertPointTy ComputeIP, const Twine &Name) {
+ LocationDescription ComputeLoc =
+ ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
+
+ Value *TripCount = calculateCanonicalLoopTripCount(
+ ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
Builder.restoreIP(CodeGenIP);
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 83c8f7e932b2b..9f4946a32d9b1 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -1441,8 +1441,7 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
EXPECT_EQ(&Loop->getAfter()->front(), RetInst);
}
-TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
- using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+TEST_F(OpenMPIRBuilderTest, CanonicalLoopTripCount) {
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
IRBuilder<> Builder(BB);
@@ -1458,17 +1457,8 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
Value *StartVal = ConstantInt::get(LCTy, Start);
Value *StopVal = ConstantInt::get(LCTy, Stop);
Value *StepVal = ConstantInt::get(LCTy, Step);
- auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) {
- return Error::success();
- };
- ASSERT_EXPECTED_INIT_RETURN(
- CanonicalLoopInfo *, Loop,
- OMPBuilder.createCanonicalLoop(Loc, LoopBodyGenCB, StartVal, StopVal,
- StepVal, IsSigned, InclusiveStop),
- -1);
- Loop->assertOK();
- Builder.restoreIP(Loop->getAfterIP());
- Value *TripCount = Loop->getTripCount();
+ Value *TripCount = OMPBuilder.calculateCanonicalLoopTripCount(
+ Loc, StartVal, StopVal, StepVal, IsSigned, InclusiveStop);
return cast<ConstantInt>(TripCount)->getValue().getZExtValue();
};
More information about the llvm-branch-commits
mailing list