[polly] 82fbc5d - [Polly] Partially refactoring of IslAstInfo and IslNodeBuilder to use isl++. NFC.

Michael Kruse via llvm-commits llvm-commits at lists.llvm.org
Sat Apr 10 19:28:42 PDT 2021


Author: patacca
Date: 2021-04-10T21:28:02-05:00
New Revision: 82fbc5d45b0c2fc9050d1d5e335e35afb4ab2611

URL: https://github.com/llvm/llvm-project/commit/82fbc5d45b0c2fc9050d1d5e335e35afb4ab2611
DIFF: https://github.com/llvm/llvm-project/commit/82fbc5d45b0c2fc9050d1d5e335e35afb4ab2611.diff

LOG: [Polly] Partially refactoring of IslAstInfo and IslNodeBuilder to use isl++. NFC.

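Polly uses algorithms from the Integer Set Library (isl), a library written in C. Its C API requires manual ownership management (__isl_give/__isl_take/__isl_keep annotations and explicit *_free calls), which does not fit the style of the rest of LLVM, written in C++. This change moves more of Polly's code onto isl++, isl's C++ bindings, which manage object lifetimes automatically.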

Changes made:
 - Refactor the following methods of class IslAstInfo to take and return isl++ types:
   - isParallel(), isExecutedInParallel(), isReductionParallel(), getSchedule(), getMinimalDependenceDistance(), getBrokenReductions()
 - Refactor the following methods of class IslNodeBuilder:
   - getReferencesInSubtree(), getScheduleForAstNode()
 - Refactor the function getBrokenReductionsStr()
 - Fix the mismatch between the declaration of getScheduleForAstNode() (which took its argument as __isl_take) and its definition (which used __isl_keep)

Reviewed By: Meinersbur

Differential Revision: https://reviews.llvm.org/D99971

Added: 
    

Modified: 
    polly/include/polly/CodeGen/IslAst.h
    polly/include/polly/CodeGen/IslNodeBuilder.h
    polly/lib/CodeGen/IslAst.cpp
    polly/lib/CodeGen/IslNodeBuilder.cpp

Removed: 
    


################################################################################
diff  --git a/polly/include/polly/CodeGen/IslAst.h b/polly/include/polly/CodeGen/IslAst.h
index 8b5b847d08213..4aa80e91ce684 100644
--- a/polly/include/polly/CodeGen/IslAst.h
+++ b/polly/include/polly/CodeGen/IslAst.h
@@ -142,7 +142,7 @@ class IslAstInfo {
   static bool isInnermost(const isl::ast_node &Node);
 
   /// Is this loop a parallel loop?
-  static bool isParallel(__isl_keep isl_ast_node *Node);
+  static bool isParallel(const isl::ast_node &Node);
 
   /// Is this loop an outermost parallel loop?
   static bool isOutermostParallel(const isl::ast_node &Node);
@@ -151,20 +151,19 @@ class IslAstInfo {
   static bool isInnermostParallel(const isl::ast_node &Node);
 
   /// Is this loop a reduction parallel loop?
-  static bool isReductionParallel(__isl_keep isl_ast_node *Node);
+  static bool isReductionParallel(const isl::ast_node &Node);
 
   /// Will the loop be run as thread parallel?
-  static bool isExecutedInParallel(__isl_keep isl_ast_node *Node);
+  static bool isExecutedInParallel(const isl::ast_node &Node);
 
   /// Get the nodes schedule or a nullptr if not available.
-  static __isl_give isl_union_map *getSchedule(__isl_keep isl_ast_node *Node);
+  static isl::union_map getSchedule(const isl::ast_node &Node);
 
   /// Get minimal dependence distance or nullptr if not available.
-  static __isl_give isl_pw_aff *
-  getMinimalDependenceDistance(__isl_keep isl_ast_node *Node);
+  static isl::pw_aff getMinimalDependenceDistance(const isl::ast_node &Node);
 
   /// Get the nodes broken reductions or a nullptr if not available.
-  static MemoryAccessSet *getBrokenReductions(__isl_keep isl_ast_node *Node);
+  static MemoryAccessSet *getBrokenReductions(const isl::ast_node &Node);
 
   /// Get the nodes build context or a nullptr if not available.
   static __isl_give isl_ast_build *getBuild(__isl_keep isl_ast_node *Node);

diff  --git a/polly/include/polly/CodeGen/IslNodeBuilder.h b/polly/include/polly/CodeGen/IslNodeBuilder.h
index 3177762c88021..bb729b8611473 100644
--- a/polly/include/polly/CodeGen/IslNodeBuilder.h
+++ b/polly/include/polly/CodeGen/IslNodeBuilder.h
@@ -248,7 +248,7 @@ class IslNodeBuilder {
   ///               this subtree.
   /// @param Loops  A vector that will be filled with the Loops referenced in
   ///               this subtree.
-  void getReferencesInSubtree(__isl_keep isl_ast_node *For,
+  void getReferencesInSubtree(const isl::ast_node &For,
                               SetVector<Value *> &Values,
                               SetVector<const Loop *> &Loops);
 
@@ -398,8 +398,7 @@ class IslNodeBuilder {
   ///         below this ast node to the scheduling vectors used to enumerate
   ///         them.
   ///
-  virtual __isl_give isl_union_map *
-  getScheduleForAstNode(__isl_take isl_ast_node *Node);
+  virtual isl::union_map getScheduleForAstNode(const isl::ast_node &Node);
 
 private:
   /// Create code for a copy statement.

diff  --git a/polly/lib/CodeGen/IslAst.cpp b/polly/lib/CodeGen/IslAst.cpp
index 013e9b31a10df..9a70754205d85 100644
--- a/polly/lib/CodeGen/IslAst.cpp
+++ b/polly/lib/CodeGen/IslAst.cpp
@@ -140,7 +140,7 @@ static isl_printer *printLine(__isl_take isl_printer *Printer,
 }
 
 /// Return all broken reductions as a string of clauses (OpenMP style).
-static const std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) {
+static const std::string getBrokenReductionsStr(const isl::ast_node &Node) {
   IslAstInfo::MemoryAccessSet *BrokenReductions;
   std::string str;
 
@@ -171,25 +171,26 @@ static const std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) {
 static isl_printer *cbPrintFor(__isl_take isl_printer *Printer,
                                __isl_take isl_ast_print_options *Options,
                                __isl_keep isl_ast_node *Node, void *) {
-  isl_pw_aff *DD = IslAstInfo::getMinimalDependenceDistance(Node);
-  const std::string BrokenReductionsStr = getBrokenReductionsStr(Node);
+  isl::pw_aff DD =
+      IslAstInfo::getMinimalDependenceDistance(isl::manage_copy(Node));
+  const std::string BrokenReductionsStr =
+      getBrokenReductionsStr(isl::manage_copy(Node));
   const std::string KnownParallelStr = "#pragma known-parallel";
   const std::string DepDisPragmaStr = "#pragma minimal dependence distance: ";
   const std::string SimdPragmaStr = "#pragma simd";
   const std::string OmpPragmaStr = "#pragma omp parallel for";
 
-  if (DD)
-    Printer = printLine(Printer, DepDisPragmaStr, DD);
+  if (!DD.is_null())
+    Printer = printLine(Printer, DepDisPragmaStr, DD.get());
 
   if (IslAstInfo::isInnermostParallel(isl::manage_copy(Node)))
     Printer = printLine(Printer, SimdPragmaStr + BrokenReductionsStr);
 
-  if (IslAstInfo::isExecutedInParallel(Node))
+  if (IslAstInfo::isExecutedInParallel(isl::manage_copy(Node)))
     Printer = printLine(Printer, OmpPragmaStr);
   else if (IslAstInfo::isOutermostParallel(isl::manage_copy(Node)))
     Printer = printLine(Printer, KnownParallelStr + BrokenReductionsStr);
 
-  isl_pw_aff_free(DD);
   return isl_ast_node_for_print(Node, Printer, Options);
 }
 
@@ -472,15 +473,15 @@ static void walkAstForStatistics(__isl_keep isl_ast_node *Ast) {
         switch (isl_ast_node_get_type(Node)) {
         case isl_ast_node_for:
           NumForLoops++;
-          if (IslAstInfo::isParallel(Node))
+          if (IslAstInfo::isParallel(isl::manage_copy(Node)))
             NumParallel++;
           if (IslAstInfo::isInnermostParallel(isl::manage_copy(Node)))
             NumInnermostParallel++;
           if (IslAstInfo::isOutermostParallel(isl::manage_copy(Node)))
             NumOutermostParallel++;
-          if (IslAstInfo::isReductionParallel(Node))
+          if (IslAstInfo::isReductionParallel(isl::manage_copy(Node)))
             NumReductionParallel++;
-          if (IslAstInfo::isExecutedInParallel(Node))
+          if (IslAstInfo::isExecutedInParallel(isl::manage_copy(Node)))
             NumExecutedInParallel++;
           break;
 
@@ -593,9 +594,9 @@ bool IslAstInfo::isInnermost(const isl::ast_node &Node) {
   return Payload && Payload->IsInnermost;
 }
 
-bool IslAstInfo::isParallel(__isl_keep isl_ast_node *Node) {
-  return IslAstInfo::isInnermostParallel(isl::manage_copy(Node)) ||
-         IslAstInfo::isOutermostParallel(isl::manage_copy(Node));
+bool IslAstInfo::isParallel(const isl::ast_node &Node) {
+  return IslAstInfo::isInnermostParallel(Node) ||
+         IslAstInfo::isOutermostParallel(Node);
 }
 
 bool IslAstInfo::isInnermostParallel(const isl::ast_node &Node) {
@@ -608,12 +609,12 @@ bool IslAstInfo::isOutermostParallel(const isl::ast_node &Node) {
   return Payload && Payload->IsOutermostParallel;
 }
 
-bool IslAstInfo::isReductionParallel(__isl_keep isl_ast_node *Node) {
-  IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node));
+bool IslAstInfo::isReductionParallel(const isl::ast_node &Node) {
+  IslAstUserPayload *Payload = getNodePayload(Node);
   return Payload && Payload->IsReductionParallel;
 }
 
-bool IslAstInfo::isExecutedInParallel(__isl_keep isl_ast_node *Node) {
+bool IslAstInfo::isExecutedInParallel(const isl::ast_node &Node) {
   if (!PollyParallel)
     return false;
 
@@ -626,28 +627,30 @@ bool IslAstInfo::isExecutedInParallel(__isl_keep isl_ast_node *Node) {
   //       executed. This can possibly require run-time checks, which again
   //       raises the question of both run-time check overhead and code size
   //       costs.
-  if (!PollyParallelForce && isInnermost(isl::manage_copy(Node)))
+  if (!PollyParallelForce && isInnermost(Node))
     return false;
 
-  return isOutermostParallel(isl::manage_copy(Node)) &&
-         !isReductionParallel(Node);
+  return isOutermostParallel(Node) && !isReductionParallel(Node);
 }
 
-__isl_give isl_union_map *
-IslAstInfo::getSchedule(__isl_keep isl_ast_node *Node) {
-  IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node));
-  return Payload ? isl_ast_build_get_schedule(Payload->Build) : nullptr;
+isl::union_map IslAstInfo::getSchedule(const isl::ast_node &Node) {
+  IslAstUserPayload *Payload = getNodePayload(Node);
+  if (!Payload)
+    return nullptr;
+
+  isl::ast_build Build = isl::manage_copy(Payload->Build);
+  return Build.get_schedule();
 }
 
-__isl_give isl_pw_aff *
-IslAstInfo::getMinimalDependenceDistance(__isl_keep isl_ast_node *Node) {
-  IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node));
-  return Payload ? Payload->MinimalDependenceDistance.copy() : nullptr;
+isl::pw_aff
+IslAstInfo::getMinimalDependenceDistance(const isl::ast_node &Node) {
+  IslAstUserPayload *Payload = getNodePayload(Node);
+  return Payload ? Payload->MinimalDependenceDistance : nullptr;
 }
 
 IslAstInfo::MemoryAccessSet *
-IslAstInfo::getBrokenReductions(__isl_keep isl_ast_node *Node) {
-  IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node));
+IslAstInfo::getBrokenReductions(const isl::ast_node &Node) {
+  IslAstUserPayload *Payload = getNodePayload(Node);
   return Payload ? &Payload->BrokenReductions : nullptr;
 }
 

diff  --git a/polly/lib/CodeGen/IslNodeBuilder.cpp b/polly/lib/CodeGen/IslNodeBuilder.cpp
index fd0d50ac22335..e9675c9a56831 100644
--- a/polly/lib/CodeGen/IslNodeBuilder.cpp
+++ b/polly/lib/CodeGen/IslNodeBuilder.cpp
@@ -300,12 +300,12 @@ addReferencesFromStmtUnionSet(isl::union_set USet,
     addReferencesFromStmtSet(Set, &References);
 }
 
-__isl_give isl_union_map *
-IslNodeBuilder::getScheduleForAstNode(__isl_keep isl_ast_node *For) {
-  return IslAstInfo::getSchedule(For);
+isl::union_map
+IslNodeBuilder::getScheduleForAstNode(const isl::ast_node &Node) {
+  return IslAstInfo::getSchedule(Node);
 }
 
-void IslNodeBuilder::getReferencesInSubtree(__isl_keep isl_ast_node *For,
+void IslNodeBuilder::getReferencesInSubtree(const isl::ast_node &For,
                                             SetVector<Value *> &Values,
                                             SetVector<const Loop *> &Loops) {
   SetVector<const SCEV *> SCEVs;
@@ -319,8 +319,7 @@ void IslNodeBuilder::getReferencesInSubtree(__isl_keep isl_ast_node *For,
   for (const auto &I : OutsideLoopIterations)
     Values.insert(cast<SCEVUnknown>(I.second)->getValue());
 
-  isl::union_set Schedule =
-      isl::manage(isl_union_map_domain(getScheduleForAstNode(For)));
+  isl::union_set Schedule = getScheduleForAstNode(For).domain();
   addReferencesFromStmtUnionSet(Schedule, References);
 
   for (const SCEV *Expr : SCEVs) {
@@ -476,22 +475,22 @@ void IslNodeBuilder::createForVector(__isl_take isl_ast_node *For,
   for (int i = 1; i < VectorWidth; i++)
     IVS[i] = Builder.CreateAdd(IVS[i - 1], ValueInc, "p_vector_iv");
 
-  isl_union_map *Schedule = getScheduleForAstNode(For);
-  assert(Schedule && "For statement annotation does not contain its schedule");
+  isl::union_map Schedule = getScheduleForAstNode(isl::manage_copy(For));
+  assert(!Schedule.is_null() &&
+         "For statement annotation does not contain its schedule");
 
   IDToValue[IteratorID] = ValueLB;
 
   switch (isl_ast_node_get_type(Body)) {
   case isl_ast_node_user:
-    createUserVector(Body, IVS, isl_id_copy(IteratorID),
-                     isl_union_map_copy(Schedule));
+    createUserVector(Body, IVS, isl_id_copy(IteratorID), Schedule.copy());
     break;
   case isl_ast_node_block: {
     isl_ast_node_list *List = isl_ast_node_block_get_children(Body);
 
     for (int i = 0; i < isl_ast_node_list_n_ast_node(List); ++i)
       createUserVector(isl_ast_node_list_get_ast_node(List, i), IVS,
-                       isl_id_copy(IteratorID), isl_union_map_copy(Schedule));
+                       isl_id_copy(IteratorID), Schedule.copy());
 
     isl_ast_node_free(Body);
     isl_ast_node_list_free(List);
@@ -504,7 +503,6 @@ void IslNodeBuilder::createForVector(__isl_take isl_ast_node *For,
 
   IDToValue.erase(IDToValue.find(IteratorID));
   isl_id_free(IteratorID);
-  isl_union_map_free(Schedule);
 
   isl_ast_node_free(For);
   isl_ast_expr_free(Iterator);
@@ -685,7 +683,7 @@ void IslNodeBuilder::createForParallel(__isl_take isl_ast_node *For) {
   SetVector<Value *> SubtreeValues;
   SetVector<const Loop *> Loops;
 
-  getReferencesInSubtree(For, SubtreeValues, Loops);
+  getReferencesInSubtree(isl::manage_copy(For), SubtreeValues, Loops);
 
   // Create for all loops we depend on values that contain the current loop
   // iteration. These values are necessary to generate code for SCEVs that
@@ -783,7 +781,7 @@ void IslNodeBuilder::createFor(__isl_take isl_ast_node *For) {
   bool Vector = PollyVectorizerChoice == VECTORIZER_POLLY;
 
   if (Vector && IslAstInfo::isInnermostParallel(isl::manage_copy(For)) &&
-      !IslAstInfo::isReductionParallel(For)) {
+      !IslAstInfo::isReductionParallel(isl::manage_copy(For))) {
     int VectorWidth = getNumberOfIterations(isl::manage_copy(For));
     if (1 < VectorWidth && VectorWidth <= 16 && !hasPartialAccesses(For)) {
       createForVector(For, VectorWidth);
@@ -791,12 +789,12 @@ void IslNodeBuilder::createFor(__isl_take isl_ast_node *For) {
     }
   }
 
-  if (IslAstInfo::isExecutedInParallel(For)) {
+  if (IslAstInfo::isExecutedInParallel(isl::manage_copy(For))) {
     createForParallel(For);
     return;
   }
-  bool Parallel =
-      (IslAstInfo::isParallel(For) && !IslAstInfo::isReductionParallel(For));
+  bool Parallel = (IslAstInfo::isParallel(isl::manage_copy(For)) &&
+                   !IslAstInfo::isReductionParallel(isl::manage_copy(For)));
   createForSequential(isl::manage(For), Parallel);
 }
 


        


More information about the llvm-commits mailing list