[llvm-branch-commits] [llvm] [OpenMPIRBuilder] Split calculation of canonical loop trip count, NFC (PR #127820)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Feb 24 04:55:53 PST 2025


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/127820

>From aa6182cc887136d44c8fd180a702f62b381f9b5c Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 18 Feb 2025 14:19:30 +0000
Subject: [PATCH] [OpenMPIRBuilder] Split calculation of canonical loop trip
 count, NFC

This patch splits off the calculation of canonical loop trip counts from the
creation of canonical loops. This makes it possible to reuse this logic to, for
instance, populate the `__tgt_target_kernel` runtime call for SPMD kernels.

The new method is used to simplify one of the existing OpenMPIRBuilder unit tests.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       | 38 +++++++++++++++----
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 27 ++++++++-----
 .../Frontend/OpenMPIRBuilderTest.cpp          | 16 ++------
 3 files changed, 52 insertions(+), 29 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 5fc2bd6b78fc5..80b4aa2bd2855 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -728,13 +728,12 @@ class OpenMPIRBuilder {
                       LoopBodyGenCallbackTy BodyGenCB, Value *TripCount,
                       const Twine &Name = "loop");
 
-  /// Generator for the control flow structure of an OpenMP canonical loop.
+  /// Calculate the trip count of a canonical loop.
   ///
-  /// Instead of a logical iteration space, this allows specifying user-defined
-  /// loop counter values using increment, upper- and lower bounds. To
-  /// disambiguate the terminology when counting downwards, instead of lower
-  /// bounds we use \p Start for the loop counter value in the first body
-  /// iteration.
+  /// This allows specifying user-defined loop counter values using increment,
+  /// upper- and lower bounds. To disambiguate the terminology when counting
+  /// downwards, instead of lower bounds we use \p Start for the loop counter
+  /// value in the first body iteration.
   ///
   /// Consider the following limitations:
   ///
@@ -758,7 +757,32 @@ class OpenMPIRBuilder {
   ///
   ///      for (int i = 0; i < 42; i -= 1u)
   ///
-  //
+  /// \param Loc       The insert and source location description.
+  /// \param Start     Value of the loop counter for the first iteration.
+  /// \param Stop      Loop counter values past this will stop the loop.
+  /// \param Step      Loop counter increment after each iteration; negative
+  ///                  means counting down.
+  /// \param IsSigned  Whether Start, Stop and Step are signed integers.
+  /// \param InclusiveStop Whether \p Stop itself is a valid value for the loop
+  ///                      counter.
+  /// \param Name      Base name used to derive instruction names.
+  ///
+  /// \returns The value holding the calculated trip count.
+  Value *calculateCanonicalLoopTripCount(const LocationDescription &Loc,
+                                         Value *Start, Value *Stop, Value *Step,
+                                         bool IsSigned, bool InclusiveStop,
+                                         const Twine &Name = "loop");
+
+  /// Generator for the control flow structure of an OpenMP canonical loop.
+  ///
+  /// Instead of a logical iteration space, this allows specifying user-defined
+  /// loop counter values using increment, upper- and lower bounds. To
+  /// disambiguate the terminology when counting downwards, instead of lower
+  /// bounds we use \p Start for the loop counter value in the first iteration.
+  ///
+  /// It calls calculateCanonicalLoopTripCount() for trip count calculations,
+  /// so limitations of that method apply here as well.
+  ///
   /// \param Loc       The insert and source location description.
   /// \param BodyGenCB Callback that will generate the loop body code.
   /// \param Start     Value of the loop counter for the first iterations.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 9619dc71f3012..dd27c4c870e27 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4059,10 +4059,9 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
   return CL;
 }
 
-Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
-    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
-    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
-    InsertPointTy ComputeIP, const Twine &Name) {
+Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
+    const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
+    bool IsSigned, bool InclusiveStop, const Twine &Name) {
 
   // Consider the following difficulties (assuming 8-bit signed integers):
   //  * Adding \p Step to the loop counter which passes \p Stop may overflow:
@@ -4075,9 +4074,7 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
   assert(IndVarTy == Stop->getType() && "Stop type mismatch");
   assert(IndVarTy == Step->getType() && "Step type mismatch");
 
-  LocationDescription ComputeLoc =
-      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
-  updateToLocation(ComputeLoc);
+  updateToLocation(Loc);
 
   ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
   ConstantInt *One = ConstantInt::get(IndVarTy, 1);
@@ -4117,8 +4114,20 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
     Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
     CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
   }
-  Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
-                                          "omp_" + Name + ".tripcount");
+
+  return Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
+                              "omp_" + Name + ".tripcount");
+}
+
+Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
+    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+    InsertPointTy ComputeIP, const Twine &Name) {
+  LocationDescription ComputeLoc =
+      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
+
+  Value *TripCount = calculateCanonicalLoopTripCount(
+      ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
 
   auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
     Builder.restoreIP(CodeGenIP);
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index a1ea7849d7c0c..27c0e0bf80255 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -1441,8 +1441,7 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
   EXPECT_EQ(&Loop->getAfter()->front(), RetInst);
 }
 
-TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
-  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+TEST_F(OpenMPIRBuilderTest, CanonicalLoopTripCount) {
   OpenMPIRBuilder OMPBuilder(*M);
   OMPBuilder.initialize();
   IRBuilder<> Builder(BB);
@@ -1458,17 +1457,8 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
     Value *StartVal = ConstantInt::get(LCTy, Start);
     Value *StopVal = ConstantInt::get(LCTy, Stop);
     Value *StepVal = ConstantInt::get(LCTy, Step);
-    auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) {
-      return Error::success();
-    };
-    ASSERT_EXPECTED_INIT_RETURN(
-        CanonicalLoopInfo *, Loop,
-        OMPBuilder.createCanonicalLoop(Loc, LoopBodyGenCB, StartVal, StopVal,
-                                       StepVal, IsSigned, InclusiveStop),
-        -1);
-    Loop->assertOK();
-    Builder.restoreIP(Loop->getAfterIP());
-    Value *TripCount = Loop->getTripCount();
+    Value *TripCount = OMPBuilder.calculateCanonicalLoopTripCount(
+        Loc, StartVal, StopVal, StepVal, IsSigned, InclusiveStop);
     return cast<ConstantInt>(TripCount)->getValue().getZExtValue();
   };
 



More information about the llvm-branch-commits mailing list