[polly] r214489 - Annotate the IslAst with broken reductions
Johannes Doerfert
jdoerfert at codeaurora.org
Fri Aug 1 01:17:19 PDT 2014
Author: jdoerfert
Date: Fri Aug 1 03:17:19 2014
New Revision: 214489
URL: http://llvm.org/viewvc/llvm-project?rev=214489&view=rev
Log:
Annotate the IslAst with broken reductions
+ Split all reduction dependences and map them to the causing memory accesses.
+ Print the types & base addresses of broken reductions for each "reduction
parallel" marked loop (OpenMP style).
+ Add 3 test cases that show how reductions are now represented in the isl AST.
The mapping "(ast) loops -> broken reductions" is also needed to find the
memory accesses we need to privatize in a loop.
Modified:
polly/trunk/include/polly/CodeGen/IslAst.h
polly/trunk/include/polly/Dependences.h
polly/trunk/lib/Analysis/Dependences.cpp
polly/trunk/lib/CodeGen/IslAst.cpp
Modified: polly/trunk/include/polly/CodeGen/IslAst.h
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/CodeGen/IslAst.h?rev=214489&r1=214488&r2=214489&view=diff
==============================================================================
--- polly/trunk/include/polly/CodeGen/IslAst.h (original)
+++ polly/trunk/include/polly/CodeGen/IslAst.h Fri Aug 1 03:17:19 2014
@@ -34,15 +34,19 @@ class raw_ostream;
struct isl_ast_node;
struct isl_ast_expr;
struct isl_ast_build;
+struct isl_union_map;
struct isl_pw_multi_aff;
namespace polly {
class Scop;
class IslAst;
+class MemoryAccess;
class IslAstInfo : public ScopPass {
public:
- /// @brief Payload information used to annoate an ast node.
+ using MemoryAccessSet = SmallPtrSet<MemoryAccess *, 4>;
+
+ /// @brief Payload information used to annotate an AST node.
struct IslAstUserPayload {
/// @brief Construct and initialize the payload.
IslAstUserPayload()
@@ -67,6 +71,9 @@ public:
/// @brief The build environment at the time this node was constructed.
isl_ast_build *Build;
+
+ /// @brief Set of accesses which break reduction dependences.
+ MemoryAccessSet BrokenReductions;
};
private:
@@ -119,6 +126,9 @@ public:
/// @brief Get the nodes schedule or a nullptr if not available.
static __isl_give isl_union_map *getSchedule(__isl_keep isl_ast_node *Node);
+ /// @brief Get the nodes broken reductions or a nullptr if not available.
+ static MemoryAccessSet *getBrokenReductions(__isl_keep isl_ast_node *Node);
+
///}
virtual void getAnalysisUsage(AnalysisUsage &AU) const;
Modified: polly/trunk/include/polly/Dependences.h
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/Dependences.h?rev=214489&r1=214488&r2=214489&view=diff
==============================================================================
--- polly/trunk/include/polly/Dependences.h (original)
+++ polly/trunk/include/polly/Dependences.h Fri Aug 1 03:17:19 2014
@@ -40,6 +40,7 @@ namespace polly {
class Scop;
class ScopStmt;
+class MemoryAccess;
class Dependences : public ScopPass {
public:
@@ -105,6 +106,16 @@ public:
/// @brief Report if valid dependences are available.
bool hasValidDependences();
+ /// @brief Return the reduction dependences caused by @p MA.
+ ///
+ /// @return The reduction dependences caused by @p MA or nullptr if None.
+ __isl_give isl_map *getReductionDependences(MemoryAccess *MA);
+
+ /// @brief Return the reduction dependences mapped by the causing @p MA.
+ const DenseMap<MemoryAccess *, isl_map *> &getReductionDependences() const {
+ return ReductionDependences;
+ }
+
bool runOnScop(Scop &S);
void printScop(raw_ostream &OS) const;
virtual void releaseMemory();
@@ -122,6 +133,9 @@ private:
/// @brief The (reverse) transitive closure of reduction dependences
isl_union_map *TC_RED = nullptr;
+ /// @brief Map from memory accesses to their reduction dependences.
+ DenseMap<MemoryAccess *, isl_map *> ReductionDependences;
+
/// @brief Collect information about the SCoP.
void collectInfo(Scop &S, isl_union_map **Read, isl_union_map **Write,
isl_union_map **MayWrite, isl_union_map **AccessSchedule,
@@ -132,6 +146,10 @@ private:
/// @brief Calculate the dependences for a certain SCoP.
void calculateDependences(Scop &S);
+
+ /// @brief Set the reduction dependences for @p MA to @p Deps.
+ void setReductionDependences(MemoryAccess *MA, __isl_take isl_map *Deps);
+
};
} // End polly namespace.
Modified: polly/trunk/lib/Analysis/Dependences.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Analysis/Dependences.cpp?rev=214489&r1=214488&r2=214489&view=diff
==============================================================================
--- polly/trunk/lib/Analysis/Dependences.cpp (original)
+++ polly/trunk/lib/Analysis/Dependences.cpp Fri Aug 1 03:17:19 2014
@@ -355,6 +355,42 @@ void Dependences::calculateDependences(S
DEBUG(dbgs() << "Final Wrapped Dependences:\n"; printScop(dbgs());
dbgs() << "\n");
+ // RED_SIN is used to collect all reduction dependences again after we
+ // split them according to the causing memory accesses. The current assumption
+ // is that our method of splitting will not have any leftovers. In the end
+ // we validate this assumption until we have more confidence in this method.
+ isl_union_map *RED_SIN = isl_union_map_empty(isl_union_map_get_space(RAW));
+
+ // For each reduction like memory access, check if there are reduction
+ // dependences with the access relation of the memory access as a domain
+ // (wrapped space!). If so these dependences are caused by this memory access.
+ // We then move this portion of reduction dependences back to the statement ->
+ // statement space and add a mapping from the memory access to these
+ // dependences.
+ for (ScopStmt *Stmt : S) {
+ for (MemoryAccess *MA : *Stmt) {
+ if (!MA->isReductionLike())
+ continue;
+
+ isl_set *AccDomW = isl_map_wrap(MA->getAccessRelation());
+ isl_union_map *AccRedDepU = isl_union_map_intersect_domain(
+ isl_union_map_copy(TC_RED), isl_union_set_from_set(AccDomW));
+ if (isl_union_map_is_empty(AccRedDepU) && !isl_union_map_free(AccRedDepU))
+ continue;
+
+ isl_map *AccRedDep = isl_map_from_union_map(AccRedDepU);
+ RED_SIN = isl_union_map_add_map(RED_SIN, isl_map_copy(AccRedDep));
+ AccRedDep = isl_map_zip(AccRedDep);
+ AccRedDep = isl_set_unwrap(isl_map_domain(AccRedDep));
+ setReductionDependences(MA, AccRedDep);
+ }
+ }
+
+ assert(isl_union_map_is_equal(RED_SIN, TC_RED) &&
+ "Intersecting the reduction dependence domain with the wrapped access "
+ "relation is not enough, we need to loosen the access relation also");
+ isl_union_map_free(RED_SIN);
+
RAW = isl_union_map_zip(RAW);
WAW = isl_union_map_zip(WAW);
WAR = isl_union_map_zip(WAR);
@@ -506,6 +542,10 @@ void Dependences::releaseMemory() {
isl_union_map_free(TC_RED);
RED = RAW = WAR = WAW = TC_RED = nullptr;
+
+ for (auto &ReductionDeps : ReductionDependences)
+ isl_map_free(ReductionDeps.second);
+ ReductionDependences.clear();
}
isl_union_map *Dependences::getDependences(int Kinds) {
@@ -537,6 +577,16 @@ bool Dependences::hasValidDependences()
return (RAW != nullptr) && (WAR != nullptr) && (WAW != nullptr);
}
+/// @brief Return a copy of the reduction dependences caused by @p MA.
+///
+/// DenseMap::lookup is used instead of operator[] so that querying an access
+/// without reduction dependences does not insert a nullptr entry into
+/// ReductionDependences. Such an entry would mutate the analysis from a getter
+/// and make a later setReductionDependences(MA, ...) fail its "set twice"
+/// assertion.
+///
+/// @return A new isl_map owned by the caller, or nullptr if @p MA has none.
+isl_map *Dependences::getReductionDependences(MemoryAccess *MA) {
+  return isl_map_copy(ReductionDependences.lookup(MA));
+}
+
+/// @brief Record @p Deps (ownership is taken) as the reduction dependences
+/// caused by @p MA.
+void Dependences::setReductionDependences(MemoryAccess *MA, isl_map *Deps) {
+  // Each access may be registered at most once. In builds without asserts a
+  // second registration would overwrite the stored pointer without freeing it.
+  assert(!ReductionDependences.count(MA) && "Reduction dependences set twice!");
+  ReductionDependences[MA] = Deps;
+}
+
+/// @brief Declare only the analysis usage of the ScopPass base class.
void Dependences::getAnalysisUsage(AnalysisUsage &AU) const {
ScopPass::getAnalysisUsage(AU);
}
Modified: polly/trunk/lib/CodeGen/IslAst.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/CodeGen/IslAst.cpp?rev=214489&r1=214488&r2=214489&view=diff
==============================================================================
--- polly/trunk/lib/CodeGen/IslAst.cpp (original)
+++ polly/trunk/lib/CodeGen/IslAst.cpp Fri Aug 1 03:17:19 2014
@@ -108,23 +108,48 @@ static isl_printer *printLine(__isl_take
return isl_printer_end_line(Printer);
}
+/// @brief Return all broken reductions as a string of clauses (OpenMP style).
+///
+/// Only write accesses contribute. Their base address names are grouped by
+/// reduction type and rendered as " reduction (<op> : a, b, ...)" clauses.
+/// There is one clause per type, ordered by the MemoryAccess::ReductionType
+/// key of a std::map.
+///
+/// @return The concatenated clauses, or "" if @p Node carries no broken
+///         reductions. Returned by plain (non-const) value so callers can move
+///         from the result.
+static std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) {
+  const IslAstInfo::MemoryAccessSet *BrokenReductions =
+      IslAstInfo::getBrokenReductions(Node);
+  if (!BrokenReductions || BrokenReductions->empty())
+    return "";
+
+  // Map each type of reduction to a comma separated list of the base addresses.
+  // Every name is prefixed with ", "; the first prefix is dropped when the
+  // clause is printed below.
+  // NOTE(review): names appear in the iteration order of the SmallPtrSet.
+  // That order is derived from insertion order / pointer values and may differ
+  // between runs. Consider sorting if the printed AST must be deterministic.
+  std::map<MemoryAccess::ReductionType, std::string> Clauses;
+  for (MemoryAccess *MA : *BrokenReductions)
+    if (MA->isWrite())
+      Clauses[MA->getReductionType()] +=
+          ", " + MA->getBaseAddr()->getName().str();
+
+  // Print the reductions sorted by type. Each type yields a clause like:
+  //   reduction (+ : sum0, sum1, sum2)
+  std::string Str;
+  for (const auto &ReductionClause : Clauses) {
+    Str += " reduction (";
+    Str += MemoryAccess::getReductionOperatorStr(ReductionClause.first);
+    Str += " : " + ReductionClause.second.substr(2) + ")";
+  }
+
+  return Str;
+}
+
/// @brief Callback executed for each for node in the ast in order to print it.
static isl_printer *cbPrintFor(__isl_take isl_printer *Printer,
__isl_take isl_ast_print_options *Options,
__isl_keep isl_ast_node *Node, void *) {
- if (IslAstInfo::isInnermostParallel(Node) &&
- !IslAstInfo::isReductionParallel(Node))
- Printer = printLine(Printer, "#pragma simd");
-
- if (IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node))
- Printer = printLine(Printer, "#pragma simd reduction");
-
- if (IslAstInfo::isOutermostParallel(Node) &&
- !IslAstInfo::isReductionParallel(Node))
- Printer = printLine(Printer, "#pragma omp parallel for");
- if (!IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node))
- Printer = printLine(Printer, "#pragma omp parallel for reduction");
+  // Reduction clauses (possibly empty), appended to whichever pragma is printed.
+  const std::string Clauses = getBrokenReductionsStr(Node);
+
+  // Innermost parallel loops get a SIMD pragma.
+  if (IslAstInfo::isInnermostParallel(Node))
+    Printer = printLine(Printer, std::string("#pragma simd") + Clauses);
+
+  // Outermost parallel loops get an OpenMP parallel-for pragma.
+  if (IslAstInfo::isOutermostParallel(Node))
+    Printer =
+        printLine(Printer, std::string("#pragma omp parallel for") + Clauses);
return isl_ast_node_for_print(Node, Printer, Options);
}
@@ -141,7 +166,7 @@ static isl_printer *cbPrintFor(__isl_tak
/// (or non-zero) dependence distance on the dimension in question.
static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build,
Dependences *D,
- bool &IsReductionParallel) {
+ IslAstUserPayload *NodeInfo) {
if (!D->hasValidDependences())
return false;
@@ -153,7 +178,20 @@ static bool astScheduleDimIsParallel(__i
isl_union_map *RedDeps = D->getDependences(Dependences::TYPE_TC_RED);
if (!D->isParallel(Schedule, RedDeps))
- IsReductionParallel = true;
+ NodeInfo->IsReductionParallel = true;
+
+ if (!NodeInfo->IsReductionParallel && !isl_union_map_free(Schedule))
+ return true;
+
+ // Annotate reduction parallel nodes with the memory accesses which caused the
+ // reduction dependences parallel execution of the node conflicts with.
+ for (const auto &MaRedPair : D->getReductionDependences()) {
+ if (!MaRedPair.second)
+ continue;
+ RedDeps = isl_union_map_from_map(isl_map_copy(MaRedPair.second));
+ if (!D->isParallel(Schedule, RedDeps))
+ NodeInfo->BrokenReductions.insert(MaRedPair.first);
+ }
isl_union_map_free(Schedule);
return true;
@@ -177,8 +215,7 @@ static __isl_give isl_id *astBuildBefore
// Test for parallelism only if we are not already inside a parallel loop
if (!BuildInfo->InParallelFor)
BuildInfo->InParallelFor = Payload->IsOutermostParallel =
- astScheduleDimIsParallel(Build, BuildInfo->Deps,
- Payload->IsReductionParallel);
+ astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload);
return Id;
}
@@ -206,13 +243,13 @@ astBuildAfterFor(__isl_take isl_ast_node
// Innermost loops that are surrounded by parallel loops have not yet been
// tested for parallelism. Test them here to ensure we check all innermost
// loops for parallelism.
- if (Payload->IsInnermost && BuildInfo->InParallelFor)
+ if (Payload->IsInnermost && BuildInfo->InParallelFor) {
if (Payload->IsOutermostParallel)
Payload->IsInnermostParallel = true;
else
- Payload->IsInnermostParallel = astScheduleDimIsParallel(
- Build, BuildInfo->Deps, Payload->IsReductionParallel);
- else if (Payload->IsOutermostParallel)
+ Payload->IsInnermostParallel =
+ astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload);
+ } else if (Payload->IsOutermostParallel)
BuildInfo->InParallelFor = false;
isl_id_free(Id);
@@ -370,6 +407,12 @@ isl_union_map *IslAstInfo::getSchedule(_
return Payload ? isl_ast_build_get_schedule(Payload->Build) : nullptr;
}
+/// Returns a pointer into the node's payload, or nullptr when there is none.
+IslAstInfo::MemoryAccessSet *
+IslAstInfo::getBrokenReductions(__isl_keep isl_ast_node *Node) {
+  if (IslAstUserPayload *Payload = getNodePayload(Node))
+    return &Payload->BrokenReductions;
+  return nullptr;
+}
+
void IslAstInfo::printScop(raw_ostream &OS) const {
isl_ast_print_options *Options;
isl_ast_node *RootNode = getAst();
More information about the llvm-commits
mailing list