[clang] [clang-tools-extra] [llvm] [polly] Reduction series : Refactor reduction detection code (PR #72343)

Karthika Devi C via cfe-commits cfe-commits at lists.llvm.org
Tue Nov 14 20:03:55 PST 2023


https://github.com/kartcq created https://github.com/llvm/llvm-project/pull/72343

None

>From 3a8c2b0b7485ea6ee3d45b87132554a2b812aa50 Mon Sep 17 00:00:00 2001
From: kartcq <quic_kartc at quicinc.com>
Date: Mon, 6 Nov 2023 03:42:09 -0800
Subject: [PATCH] [polly][NFC] Refactor reduction detection code for modularity

This patch factors the memory-access checks out of the base reduction
detection algorithm into separate helper functions. It is the first patch
in the reduction series and is meant to reduce the diff of later patches.
---
 polly/lib/Analysis/ScopBuilder.cpp | 78 ++++++++++++++++++------------
 1 file changed, 48 insertions(+), 30 deletions(-)

diff --git a/polly/lib/Analysis/ScopBuilder.cpp b/polly/lib/Analysis/ScopBuilder.cpp
index c34413812d9464e..0af0f6915b14585 100644
--- a/polly/lib/Analysis/ScopBuilder.cpp
+++ b/polly/lib/Analysis/ScopBuilder.cpp
@@ -2510,6 +2510,48 @@ static MemoryAccess::ReductionType getReductionType(const BinaryOperator *BinOp,
   }
 }
 
+///  True if @p AllAccs overlaps any access in @p MemAccs other than @p LoadMA
+///  and @p StoreMA.
+bool hasIntersectingAccesses(isl::set AllAccs, MemoryAccess *LoadMA,
+                             MemoryAccess *StoreMA, isl::set Domain,
+                             SmallVector<MemoryAccess *, 8> &MemAccs) {
+  bool HasIntersectingAccs = false;
+  for (MemoryAccess *MA : MemAccs) {
+    if (MA == LoadMA || MA == StoreMA)
+      continue;
+
+    isl::map AccRel = MA->getAccessRelation().intersect_domain(Domain);
+    isl::set Accs = AccRel.range();
+
+    if (AllAccs.has_equal_space(Accs)) {
+      isl::set OverlapAccs = Accs.intersect(AllAccs);
+      bool DoesIntersect = !OverlapAccs.is_empty();
+      HasIntersectingAccs |= DoesIntersect;
+    }
+  }
+  return HasIntersectingAccs;
+}
+
+///  Test if the accesses of @p LoadMA and @p StoreMA can form a reduction
+bool checkCandidatePairAccesses(MemoryAccess *LoadMA, MemoryAccess *StoreMA,
+                                isl::set Domain,
+                                SmallVector<MemoryAccess *, 8> &MemAccs) {
+  isl::map LoadAccs = LoadMA->getAccessRelation();
+  isl::map StoreAccs = StoreMA->getAccessRelation();
+
+  // Skip those with obviously unequal base addresses.
+  bool Valid = LoadAccs.has_equal_space(StoreAccs);
+
+  // And check the remaining ones for overlap with other memory accesses.
+  if (Valid) {
+    isl::map AllAccsRel = LoadAccs.unite(StoreAccs);
+    AllAccsRel = AllAccsRel.intersect_domain(Domain);
+    isl::set AllAccs = AllAccsRel.range();
+    Valid = !hasIntersectingAccesses(AllAccs, LoadMA, StoreMA, Domain, MemAccs);
+  }
+  return Valid;
+}
+
 void ScopBuilder::checkForReductions(ScopStmt &Stmt) {
   SmallVector<MemoryAccess *, 2> Loads;
   SmallVector<std::pair<MemoryAccess *, MemoryAccess *>, 4> Candidates;
@@ -2528,34 +2570,10 @@ void ScopBuilder::checkForReductions(ScopStmt &Stmt) {
 
   // Then check each possible candidate pair.
   for (const auto &CandidatePair : Candidates) {
-    bool Valid = true;
-    isl::map LoadAccs = CandidatePair.first->getAccessRelation();
-    isl::map StoreAccs = CandidatePair.second->getAccessRelation();
-
-    // Skip those with obviously unequal base addresses.
-    if (!LoadAccs.has_equal_space(StoreAccs)) {
-      continue;
-    }
-
-    // And check if the remaining for overlap with other memory accesses.
-    isl::map AllAccsRel = LoadAccs.unite(StoreAccs);
-    AllAccsRel = AllAccsRel.intersect_domain(Stmt.getDomain());
-    isl::set AllAccs = AllAccsRel.range();
-
-    for (MemoryAccess *MA : Stmt) {
-      if (MA == CandidatePair.first || MA == CandidatePair.second)
-        continue;
-
-      isl::map AccRel =
-          MA->getAccessRelation().intersect_domain(Stmt.getDomain());
-      isl::set Accs = AccRel.range();
-
-      if (AllAccs.has_equal_space(Accs)) {
-        isl::set OverlapAccs = Accs.intersect(AllAccs);
-        Valid = Valid && OverlapAccs.is_empty();
-      }
-    }
-
+    MemoryAccess *LoadMA = CandidatePair.first;
+    MemoryAccess *StoreMA = CandidatePair.second;
+    bool Valid = checkCandidatePairAccesses(LoadMA, StoreMA, Stmt.getDomain(),
+                                            Stmt.MemAccs);
     if (!Valid)
       continue;
 
@@ -2566,8 +2584,8 @@ void ScopBuilder::checkForReductions(ScopStmt &Stmt) {
 
     // If no overlapping access was found we mark the load and store as
     // reduction like.
-    CandidatePair.first->markAsReductionLike(RT);
-    CandidatePair.second->markAsReductionLike(RT);
+    LoadMA->markAsReductionLike(RT);
+    StoreMA->markAsReductionLike(RT);
   }
 }
 



More information about the cfe-commits mailing list