[polly] [Polly] Data flow reduction detection to cover more cases (PR #84901)

Michael Kruse via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 25 08:21:41 PDT 2024


================
@@ -2571,47 +2581,208 @@ bool checkCandidatePairAccesses(MemoryAccess *LoadMA, MemoryAccess *StoreMA,
     AllAccsRel = AllAccsRel.intersect_domain(Domain);
     isl::set AllAccs = AllAccsRel.range();
     Valid = !hasIntersectingAccesses(AllAccs, LoadMA, StoreMA, Domain, MemAccs);
-
     POLLY_DEBUG(dbgs() << " == The accessed memory is " << (Valid ? "not " : "")
                        << "accessed by other instructions!\n");
   }
+
   return Valid;
 }
 
+// Perform a data flow analysis on the current scop statement to propagate the
+// uses of loaded values. Then check and mark the memory accesses which are
+// part of reduction like chains.
+//
+// NOTE: This assumes independent scop statements and breaks otherwise.
 void ScopBuilder::checkForReductions(ScopStmt &Stmt) {
-  SmallVector<MemoryAccess *, 2> Loads;
-  SmallVector<std::pair<MemoryAccess *, MemoryAccess *>, 4> Candidates;
+  // During the data flow analysis we use the State variable to keep track of
+  // the used "load-instructions" for each instruction in the scop statement.
+  // This includes the LLVM-IR of the load and the "number of uses" (or the
+  // number of paths in the operand tree which end in this load).
+  using StatePairTy = std::pair<unsigned, MemoryAccess::ReductionType>;
+  using FlowInSetTy = MapVector<const LoadInst *, StatePairTy>;
+  using StateTy = MapVector<const Instruction *, FlowInSetTy>;
+  StateTy State;
+
+  // Invalid loads are loads which have uses we can't track properly in the
+  // state map. This includes loads which:
+  //   o do not form a reduction when they flow into a memory location:
+  //     (e.g., A[i] = B[i] * 3 and  A[i] = A[i] * A[i] + A[i])
+  //   o are used by a non binary operator or one which is not commutative
+  //     and associative (e.g., A[i] = A[i] % 3)
+  //   o might change the control flow            (e.g., if (A[i]))
+  //   o are used in indirect memory accesses     (e.g., A[B[i]])
+  //   o are used outside the current scop statement
+  SmallPtrSet<const Instruction *, 8> InvalidLoads;
+  SmallVector<BasicBlock *, 8> ScopBlocks;
+  BasicBlock *BB = Stmt.getBasicBlock();
+  if (BB)
+    ScopBlocks.push_back(BB);
+  else
+    for (BasicBlock *Block : Stmt.getRegion()->blocks())
+      ScopBlocks.push_back(Block);
+  // Run the data flow analysis for all values in the scop statement
+  for (BasicBlock *Block : ScopBlocks) {
+    for (Instruction &Inst : *Block) {
+      if ((Stmt.getParent())->getStmtFor(&Inst) != &Stmt)
+        continue;
+      bool UsedOutsideStmt = any_of(Inst.users(), [&Stmt](User *U) {
+        return (Stmt.getParent())->getStmtFor(cast<Instruction>(U)) != &Stmt;
+      });
+      // Treat loads and stores specially.
+      if (auto *Load = dyn_cast<LoadInst>(&Inst)) {
+        // Invalidate all loads which feed into the address of this load.
+        if (auto *Ptr = dyn_cast<Instruction>(Load->getPointerOperand())) {
+          const auto &It = State.find(Ptr);
+          if (It != State.end())
+            for (const auto &FlowInSetElem : It->second)
+              InvalidLoads.insert(FlowInSetElem.first);
+        }
 
-  // First collect candidate load-store reduction chains by iterating over all
-  // stores and collecting possible reduction loads.
-  for (MemoryAccess *StoreMA : Stmt) {
-    if (StoreMA->isRead())
-      continue;
+        // If this load is used outside this stmt, invalidate it.
+        if (UsedOutsideStmt)
+          InvalidLoads.insert(Load);
+
+        // And indicate that this load uses itself once but without specifying
+        // any reduction operator.
+        State[Load].insert(
+            std::make_pair(Load, std::make_pair(1, MemoryAccess::RT_BOTTOM)));
+        continue;
+      }
+
+      if (auto *Store = dyn_cast<StoreInst>(&Inst)) {
+        // Invalidate all loads which feed into the address of this store.
+        if (const Instruction *Ptr =
+                dyn_cast<Instruction>(Store->getPointerOperand())) {
+          const auto &It = State.find(Ptr);
+          if (It != State.end())
+            for (const auto &FlowInSetElem : It->second)
+              InvalidLoads.insert(FlowInSetElem.first);
+        }
 
-    Loads.clear();
-    collectCandidateReductionLoads(StoreMA, Loads);
-    for (MemoryAccess *LoadMA : Loads)
-      Candidates.push_back(std::make_pair(LoadMA, StoreMA));
+        // Propagate the uses of the value operand to the store
+        if (auto *ValueInst = dyn_cast<Instruction>(Store->getValueOperand()))
+          State.insert(std::make_pair(Store, State[ValueInst]));
+        continue;
+      }
+
+      // Non load and store instructions are either binary operators or they
+      // will invalidate all used loads.
+      auto *BinOp = dyn_cast<BinaryOperator>(&Inst);
+      auto CurRedType = getReductionType(BinOp);
----------------
Meinersbur wrote:

```suggestion
      ReductionType CurRedType = getReductionType(BinOp);
```

https://github.com/llvm/llvm-project/pull/84901


More information about the llvm-commits mailing list