[polly] r214489 - Annotate the IslAst with broken reductions

Johannes Doerfert jdoerfert at codeaurora.org
Fri Aug 1 01:17:19 PDT 2014


Author: jdoerfert
Date: Fri Aug  1 03:17:19 2014
New Revision: 214489

URL: http://llvm.org/viewvc/llvm-project?rev=214489&view=rev
Log:
Annotate the IslAst with broken reductions

  + Split all reduction dependences and map them to the causing memory accesses.
  + Print the types & base addresses of broken reductions for each "reduction
    parallel" marked loop (OpenMP style).
  + 3 test cases to show how reductions are now represented in the isl ast.

  The mapping "(ast) loops -> broken reductions" is also needed to find the
  memory accesses we need to privatize in a loop.


Modified:
    polly/trunk/include/polly/CodeGen/IslAst.h
    polly/trunk/include/polly/Dependences.h
    polly/trunk/lib/Analysis/Dependences.cpp
    polly/trunk/lib/CodeGen/IslAst.cpp

Modified: polly/trunk/include/polly/CodeGen/IslAst.h
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/CodeGen/IslAst.h?rev=214489&r1=214488&r2=214489&view=diff
==============================================================================
--- polly/trunk/include/polly/CodeGen/IslAst.h (original)
+++ polly/trunk/include/polly/CodeGen/IslAst.h Fri Aug  1 03:17:19 2014
@@ -34,15 +34,19 @@ class raw_ostream;
 struct isl_ast_node;
 struct isl_ast_expr;
 struct isl_ast_build;
+struct isl_union_map;
 struct isl_pw_multi_aff;
 
 namespace polly {
 class Scop;
 class IslAst;
+class MemoryAccess;
 
 class IslAstInfo : public ScopPass {
 public:
-  /// @brief Payload information used to annoate an ast node.
+  using MemoryAccessSet = SmallPtrSet<MemoryAccess *, 4>;
+
+  /// @brief Payload information used to annotate an AST node.
   struct IslAstUserPayload {
     /// @brief Construct and initialize the payload.
     IslAstUserPayload()
@@ -67,6 +71,9 @@ public:
 
     /// @brief The build environment at the time this node was constructed.
     isl_ast_build *Build;
+
+    /// @brief Set of accesses which break reduction dependences.
+    MemoryAccessSet BrokenReductions;
   };
 
 private:
@@ -119,6 +126,9 @@ public:
   /// @brief Get the nodes schedule or a nullptr if not available.
   static __isl_give isl_union_map *getSchedule(__isl_keep isl_ast_node *Node);
 
+  /// @brief Get the node's broken reductions or a nullptr if not available.
+  static MemoryAccessSet *getBrokenReductions(__isl_keep isl_ast_node *Node);
+
   ///}
 
   virtual void getAnalysisUsage(AnalysisUsage &AU) const;

Modified: polly/trunk/include/polly/Dependences.h
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/Dependences.h?rev=214489&r1=214488&r2=214489&view=diff
==============================================================================
--- polly/trunk/include/polly/Dependences.h (original)
+++ polly/trunk/include/polly/Dependences.h Fri Aug  1 03:17:19 2014
@@ -40,6 +40,7 @@ namespace polly {
 
 class Scop;
 class ScopStmt;
+class MemoryAccess;
 
 class Dependences : public ScopPass {
 public:
@@ -105,6 +106,16 @@ public:
   /// @brief Report if valid dependences are available.
   bool hasValidDependences();
 
+  /// @brief Return the reduction dependences caused by @p MA.
+  ///
+  /// @return The reduction dependences caused by @p MA or nullptr if none.
+  __isl_give isl_map *getReductionDependences(MemoryAccess *MA);
+
+  /// @brief Return all reduction dependences, keyed by the causing access.
+  const DenseMap<MemoryAccess *, isl_map *> &getReductionDependences() const {
+    return ReductionDependences;
+  }
+
   bool runOnScop(Scop &S);
   void printScop(raw_ostream &OS) const;
   virtual void releaseMemory();
@@ -122,6 +133,9 @@ private:
   /// @brief The (reverse) transitive closure of reduction dependences
   isl_union_map *TC_RED = nullptr;
 
+  /// @brief Map from memory accesses to their reduction dependences.
+  DenseMap<MemoryAccess *, isl_map *> ReductionDependences;
+
   /// @brief Collect information about the SCoP.
   void collectInfo(Scop &S, isl_union_map **Read, isl_union_map **Write,
                    isl_union_map **MayWrite, isl_union_map **AccessSchedule,
@@ -132,6 +146,10 @@ private:
 
   /// @brief Calculate the dependences for a certain SCoP.
   void calculateDependences(Scop &S);
+
+  /// @brief Set the reduction dependences for @p MA to @p Deps.
+  void setReductionDependences(MemoryAccess *MA, __isl_take isl_map *Deps);
+
 };
 
 } // End polly namespace.

Modified: polly/trunk/lib/Analysis/Dependences.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Analysis/Dependences.cpp?rev=214489&r1=214488&r2=214489&view=diff
==============================================================================
--- polly/trunk/lib/Analysis/Dependences.cpp (original)
+++ polly/trunk/lib/Analysis/Dependences.cpp Fri Aug  1 03:17:19 2014
@@ -355,6 +355,42 @@ void Dependences::calculateDependences(S
   DEBUG(dbgs() << "Final Wrapped Dependences:\n"; printScop(dbgs());
         dbgs() << "\n");
 
+  // RED_SIN is used to collect all reduction dependences again after we
+  // split them according to the causing memory accesses. The current assumption
+  // is that our method of splitting will not have any leftovers. In the end
+  // we validate this assumption until we have more confidence in this method.
+  isl_union_map *RED_SIN = isl_union_map_empty(isl_union_map_get_space(RAW));
+
+  // For each reduction like memory access, check if there are reduction
+  // dependences with the access relation of the memory access as a domain
+  // (wrapped space!). If so these dependences are caused by this memory access.
+  // We then move this portion of reduction dependences back to the statement ->
+  // statement space and add a mapping from the memory access to these
+  // dependences.
+  for (ScopStmt *Stmt : S) {
+    for (MemoryAccess *MA : *Stmt) {
+      if (!MA->isReductionLike())
+        continue;
+
+      isl_set *AccDomW = isl_map_wrap(MA->getAccessRelation());
+      isl_union_map *AccRedDepU = isl_union_map_intersect_domain(
+          isl_union_map_copy(TC_RED), isl_union_set_from_set(AccDomW));
+      if (isl_union_map_is_empty(AccRedDepU) && !isl_union_map_free(AccRedDepU))
+        continue;
+
+      isl_map *AccRedDep = isl_map_from_union_map(AccRedDepU);
+      RED_SIN = isl_union_map_add_map(RED_SIN, isl_map_copy(AccRedDep));
+      AccRedDep = isl_map_zip(AccRedDep);
+      AccRedDep = isl_set_unwrap(isl_map_domain(AccRedDep));
+      setReductionDependences(MA, AccRedDep);
+    }
+  }
+
+  assert(isl_union_map_is_equal(RED_SIN, TC_RED) &&
+         "Intersecting the reduction dependence domain with the wrapped access "
+         "relation is not enough, we need to loosen the access relation also");
+  isl_union_map_free(RED_SIN);
+
   RAW = isl_union_map_zip(RAW);
   WAW = isl_union_map_zip(WAW);
   WAR = isl_union_map_zip(WAR);
@@ -506,6 +542,10 @@ void Dependences::releaseMemory() {
   isl_union_map_free(TC_RED);
 
   RED = RAW = WAR = WAW = TC_RED = nullptr;
+
+  for (auto &ReductionDeps : ReductionDependences)
+    isl_map_free(ReductionDeps.second);
+  ReductionDependences.clear();
 }
 
 isl_union_map *Dependences::getDependences(int Kinds) {
@@ -537,6 +577,16 @@ bool Dependences::hasValidDependences()
   return (RAW != nullptr) && (WAR != nullptr) && (WAW != nullptr);
 }
 
+isl_map *Dependences::getReductionDependences(MemoryAccess *MA) {
+  return isl_map_copy(ReductionDependences[MA]);
+}
+
+void Dependences::setReductionDependences(MemoryAccess *MA, isl_map *D) {
+  assert(ReductionDependences.count(MA) == 0 &&
+         "Reduction dependences set twice!");
+  ReductionDependences[MA] = D;
+}
+
 void Dependences::getAnalysisUsage(AnalysisUsage &AU) const {
   ScopPass::getAnalysisUsage(AU);
 }

Modified: polly/trunk/lib/CodeGen/IslAst.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/CodeGen/IslAst.cpp?rev=214489&r1=214488&r2=214489&view=diff
==============================================================================
--- polly/trunk/lib/CodeGen/IslAst.cpp (original)
+++ polly/trunk/lib/CodeGen/IslAst.cpp Fri Aug  1 03:17:19 2014
@@ -108,23 +108,48 @@ static isl_printer *printLine(__isl_take
   return isl_printer_end_line(Printer);
 }
 
+/// @brief Return all broken reductions as a string of clauses (OpenMP style).
+static const std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) {
+  IslAstInfo::MemoryAccessSet *BrokenReductions;
+  std::string str;
+
+  BrokenReductions = IslAstInfo::getBrokenReductions(Node);
+  if (!BrokenReductions || BrokenReductions->empty())
+    return "";
+
+  // Map each type of reduction to a comma separated list of the base addresses.
+  std::map<MemoryAccess::ReductionType, std::string> Clauses;
+  for (MemoryAccess *MA : *BrokenReductions)
+    if (MA->isWrite())
+      Clauses[MA->getReductionType()] +=
+          ", " + MA->getBaseAddr()->getName().str();
+
+  // Now print the reductions sorted by type. Each type will cause a clause
+  // like:  reduction (+ : sum0, sum1, sum2)
+  for (const auto &ReductionClause : Clauses) {
+    str += " reduction (";
+    str += MemoryAccess::getReductionOperatorStr(ReductionClause.first);
+    // Remove the first two symbols (", ") to make the output look pretty.
+    str += " : " + ReductionClause.second.substr(2) + ")";
+  }
+
+  return str;
+}
+
 /// @brief Callback executed for each for node in the ast in order to print it.
 static isl_printer *cbPrintFor(__isl_take isl_printer *Printer,
                                __isl_take isl_ast_print_options *Options,
                                __isl_keep isl_ast_node *Node, void *) {
-  if (IslAstInfo::isInnermostParallel(Node) &&
-      !IslAstInfo::isReductionParallel(Node))
-    Printer = printLine(Printer, "#pragma simd");
-
-  if (IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node))
-    Printer = printLine(Printer, "#pragma simd reduction");
-
-  if (IslAstInfo::isOutermostParallel(Node) &&
-      !IslAstInfo::isReductionParallel(Node))
-    Printer = printLine(Printer, "#pragma omp parallel for");
 
-  if (!IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node))
-    Printer = printLine(Printer, "#pragma omp parallel for reduction");
+  const std::string BrokenReductionsStr = getBrokenReductionsStr(Node);
+  const std::string SimdPragmaStr = "#pragma simd";
+  const std::string OmpPragmaStr = "#pragma omp parallel for";
+
+  if (IslAstInfo::isInnermostParallel(Node))
+    Printer = printLine(Printer, SimdPragmaStr + BrokenReductionsStr);
+
+  if (IslAstInfo::isOutermostParallel(Node))
+    Printer = printLine(Printer, OmpPragmaStr + BrokenReductionsStr);
 
   return isl_ast_node_for_print(Node, Printer, Options);
 }
@@ -141,7 +166,7 @@ static isl_printer *cbPrintFor(__isl_tak
 /// (or non-zero) dependence distance on the dimension in question.
 static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build,
                                      Dependences *D,
-                                     bool &IsReductionParallel) {
+                                     IslAstUserPayload *NodeInfo) {
   if (!D->hasValidDependences())
     return false;
 
@@ -153,7 +178,20 @@ static bool astScheduleDimIsParallel(__i
 
   isl_union_map *RedDeps = D->getDependences(Dependences::TYPE_TC_RED);
   if (!D->isParallel(Schedule, RedDeps))
-    IsReductionParallel = true;
+    NodeInfo->IsReductionParallel = true;
+
+  if (!NodeInfo->IsReductionParallel && !isl_union_map_free(Schedule))
+    return true;
+
+  // Annotate reduction parallel nodes with the memory accesses whose reduction
+  // dependences conflict with a parallel execution of the node.
+  for (const auto &MaRedPair : D->getReductionDependences()) {
+    if (!MaRedPair.second)
+      continue;
+    RedDeps = isl_union_map_from_map(isl_map_copy(MaRedPair.second));
+    if (!D->isParallel(Schedule, RedDeps))
+      NodeInfo->BrokenReductions.insert(MaRedPair.first);
+  }
 
   isl_union_map_free(Schedule);
   return true;
@@ -177,8 +215,7 @@ static __isl_give isl_id *astBuildBefore
   // Test for parallelism only if we are not already inside a parallel loop
   if (!BuildInfo->InParallelFor)
     BuildInfo->InParallelFor = Payload->IsOutermostParallel =
-        astScheduleDimIsParallel(Build, BuildInfo->Deps,
-                                 Payload->IsReductionParallel);
+        astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload);
 
   return Id;
 }
@@ -206,13 +243,13 @@ astBuildAfterFor(__isl_take isl_ast_node
   // Innermost loops that are surrounded by parallel loops have not yet been
   // tested for parallelism. Test them here to ensure we check all innermost
   // loops for parallelism.
-  if (Payload->IsInnermost && BuildInfo->InParallelFor)
+  if (Payload->IsInnermost && BuildInfo->InParallelFor) {
     if (Payload->IsOutermostParallel)
       Payload->IsInnermostParallel = true;
     else
-      Payload->IsInnermostParallel = astScheduleDimIsParallel(
-          Build, BuildInfo->Deps, Payload->IsReductionParallel);
-  else if (Payload->IsOutermostParallel)
+      Payload->IsInnermostParallel =
+          astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload);
+  } else if (Payload->IsOutermostParallel)
     BuildInfo->InParallelFor = false;
 
   isl_id_free(Id);
@@ -370,6 +407,12 @@ isl_union_map *IslAstInfo::getSchedule(_
   return Payload ? isl_ast_build_get_schedule(Payload->Build) : nullptr;
 }
 
+IslAstInfo::MemoryAccessSet *
+IslAstInfo::getBrokenReductions(__isl_keep isl_ast_node *Node) {
+  IslAstUserPayload *Payload = getNodePayload(Node);
+  return Payload ? &Payload->BrokenReductions : nullptr;
+}
+
 void IslAstInfo::printScop(raw_ostream &OS) const {
   isl_ast_print_options *Options;
   isl_ast_node *RootNode = getAst();





More information about the llvm-commits mailing list