[llvm-branch-commits] [llvm] [OpenMPIRBuilder] Split calculation of canonical loop trip count, NFC (PR #127820)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Feb 19 07:51:19 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: Sergio Afonso (skatrak)

<details>
<summary>Changes</summary>

This patch splits off the calculation of canonical loop trip counts from the creation of canonical loops. This makes it possible to reuse this logic to, for instance, populate the `__tgt_target_kernel` runtime call for SPMD kernels.

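The new function is also used to simplify an existing OpenMPIRBuilder test (`CanonicalLoopBounds`, renamed to `CanonicalLoopTripCount`), which no longer needs to build a full canonical loop just to read back its trip count.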

---
Full diff: https://github.com/llvm/llvm-project/pull/127820.diff


3 Files Affected:

- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+31-7) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+18-9) 
- (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+3-13) 


``````````diff
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 9ad85413acd34..207ca7fb05f62 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -728,13 +728,12 @@ class OpenMPIRBuilder {
                       LoopBodyGenCallbackTy BodyGenCB, Value *TripCount,
                       const Twine &Name = "loop");
 
-  /// Generator for the control flow structure of an OpenMP canonical loop.
+  /// Calculate the trip count of a canonical loop.
   ///
-  /// Instead of a logical iteration space, this allows specifying user-defined
-  /// loop counter values using increment, upper- and lower bounds. To
-  /// disambiguate the terminology when counting downwards, instead of lower
-  /// bounds we use \p Start for the loop counter value in the first body
-  /// iteration.
+  /// This allows specifying user-defined loop counter values using increment,
+  /// upper- and lower bounds. To disambiguate the terminology when counting
+  /// downwards, instead of lower bounds we use \p Start for the loop counter
+  /// value in the first body iteration.
   ///
   /// Consider the following limitations:
   ///
@@ -758,7 +757,32 @@ class OpenMPIRBuilder {
   ///
   ///      for (int i = 0; i < 42; i -= 1u)
   ///
-  //
+  /// \param Loc       The insert and source location description.
+  /// \param Start     Value of the loop counter for the first iteration.
+  /// \param Stop      Loop counter values past this will stop the loop.
+  /// \param Step      Loop counter increment after each iteration; negative
+  ///                  means counting down.
+  /// \param IsSigned  Whether Start, Stop and Step are signed integers.
+  /// \param InclusiveStop Whether \p Stop itself is a valid value for the loop
+  ///                      counter.
+  /// \param Name      Base name used to derive instruction names.
+  ///
+  /// \returns The value holding the calculated trip count.
+  Value *calculateCanonicalLoopTripCount(const LocationDescription &Loc,
+                                         Value *Start, Value *Stop, Value *Step,
+                                         bool IsSigned, bool InclusiveStop,
+                                         const Twine &Name = "loop");
+
+  /// Generator for the control flow structure of an OpenMP canonical loop.
+  ///
+  /// Instead of a logical iteration space, this allows specifying user-defined
+  /// loop counter values using increment, upper- and lower bounds. To
+  /// disambiguate the terminology when counting downwards, instead of lower
+  /// bounds we use \p Start for the loop counter value in the first iteration.
+  ///
+  /// It calls calculateCanonicalLoopTripCount() for trip count calculations,
+  /// so limitations of that method apply here as well.
+  ///
   /// \param Loc       The insert and source location description.
   /// \param BodyGenCB Callback that will generate the loop body code.
   /// \param Start     Value of the loop counter for the first iterations.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 7788897fc0795..eee6e3e54d615 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4059,10 +4059,9 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
   return CL;
 }
 
-Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
-    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
-    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
-    InsertPointTy ComputeIP, const Twine &Name) {
+Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
+    const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
+    bool IsSigned, bool InclusiveStop, const Twine &Name) {
 
   // Consider the following difficulties (assuming 8-bit signed integers):
   //  * Adding \p Step to the loop counter which passes \p Stop may overflow:
@@ -4075,9 +4074,7 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
   assert(IndVarTy == Stop->getType() && "Stop type mismatch");
   assert(IndVarTy == Step->getType() && "Step type mismatch");
 
-  LocationDescription ComputeLoc =
-      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
-  updateToLocation(ComputeLoc);
+  updateToLocation(Loc);
 
   ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
   ConstantInt *One = ConstantInt::get(IndVarTy, 1);
@@ -4117,8 +4114,20 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
     Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
     CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
   }
-  Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
-                                          "omp_" + Name + ".tripcount");
+
+  return Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
+                              "omp_" + Name + ".tripcount");
+}
+
+Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
+    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+    InsertPointTy ComputeIP, const Twine &Name) {
+  LocationDescription ComputeLoc =
+      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
+
+  Value *TripCount = calculateCanonicalLoopTripCount(
+      ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
 
   auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
     Builder.restoreIP(CodeGenIP);
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 83c8f7e932b2b..9f4946a32d9b1 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -1441,8 +1441,7 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
   EXPECT_EQ(&Loop->getAfter()->front(), RetInst);
 }
 
-TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
-  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+TEST_F(OpenMPIRBuilderTest, CanonicalLoopTripCount) {
   OpenMPIRBuilder OMPBuilder(*M);
   OMPBuilder.initialize();
   IRBuilder<> Builder(BB);
@@ -1458,17 +1457,8 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
     Value *StartVal = ConstantInt::get(LCTy, Start);
     Value *StopVal = ConstantInt::get(LCTy, Stop);
     Value *StepVal = ConstantInt::get(LCTy, Step);
-    auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) {
-      return Error::success();
-    };
-    ASSERT_EXPECTED_INIT_RETURN(
-        CanonicalLoopInfo *, Loop,
-        OMPBuilder.createCanonicalLoop(Loc, LoopBodyGenCB, StartVal, StopVal,
-                                       StepVal, IsSigned, InclusiveStop),
-        -1);
-    Loop->assertOK();
-    Builder.restoreIP(Loop->getAfterIP());
-    Value *TripCount = Loop->getTripCount();
+    Value *TripCount = OMPBuilder.calculateCanonicalLoopTripCount(
+        Loc, StartVal, StopVal, StepVal, IsSigned, InclusiveStop);
     return cast<ConstantInt>(TripCount)->getValue().getZExtValue();
   };
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/127820


More information about the llvm-branch-commits mailing list