[polly] [Polly] Data flow reduction detection to cover more cases (PR #84901)

Michael Kruse via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 22 08:22:17 PDT 2024


================
@@ -2568,49 +2579,196 @@ bool checkCandidatePairAccesses(MemoryAccess *LoadMA, MemoryAccess *StoreMA,
     // Finally, check if there are no other instructions accessing this memory
     isl::map AllAccsRel = LoadAccs.unite(StoreAccs);
     AllAccsRel = AllAccsRel.intersect_domain(Domain);
+
     isl::set AllAccs = AllAccsRel.range();
+
     Valid = !hasIntersectingAccesses(AllAccs, LoadMA, StoreMA, Domain, MemAccs);
 
     LLVM_DEBUG(dbgs() << " == The accessed memory is " << (Valid ? "not " : "")
                       << "accessed by other instructions!\n");
   }
+
   return Valid;
 }
 
-void ScopBuilder::checkForReductions(ScopStmt &Stmt) {
-  SmallVector<MemoryAccess *, 2> Loads;
-  SmallVector<std::pair<MemoryAccess *, MemoryAccess *>, 4> Candidates;
+/// Perform a data flow analysis on the current basic block to propagate the
+/// uses of loaded values. Then check and mark the memory accesses which are
+/// part of reduction like chains.
+///
+/// NOTE: This assumes independent blocks and breaks otherwise.
+void ScopBuilder::checkForReductions(ScopStmt &Stmt, BasicBlock *Block) {
+  // During the data flow analysis we use the State variable to keep track of
+  // the used "load-instructions" for each instruction in the basic block.
+  // This includes the LLVM-IR of the load and the "number of uses" (or the
+  // number of paths in the operand tree which end in this load).
+  using StatePairTy = std::pair<unsigned, MemoryAccess::ReductionType>;
+  using FlowInSetTy = MapVector<const LoadInst *, StatePairTy>;
+  using StateTy = MapVector<const Instruction *, FlowInSetTy>;
+  StateTy State;
+
+  // Invalid loads are loads which have uses we can't track properly in the
+  // state map. This includes loads which:
+  //   o do not form a reduction when they flow into a memory location:
+  //     (e.g., A[i] = B[i] * 3 and  A[i] = A[i] * A[i] + A[i])
+  //   o are used by a non binary operator or one which is not commutative
+  //     and associative (e.g., A[i] = A[i] % 3)
+  //   o might change the control flow            (e.g., if (A[i]))
+  //   o are used in indirect memory accesses     (e.g., A[B[i]])
+  //   o are used outside the current basic block
+  SmallPtrSet<const Instruction *, 8> InvalidLoads;
+
+  // Run the data flow analysis for all values in the basic block
+  for (Instruction &Inst : *Block) {
+    bool UsedOutsideBlock = any_of(Inst.users(), [Block](User *U) {
+      return cast<Instruction>(U)->getParent() != Block;
+    });
+
+    // Treat loads and stores specially
+    if (auto *Load = dyn_cast<LoadInst>(&Inst)) {
+      // Invalidate all loads which feed into the address of this load.
+      if (auto *Ptr = dyn_cast<Instruction>(Load->getPointerOperand())) {
+        const auto &It = State.find(Ptr);
+        if (It != State.end())
+          for (const auto &FlowInSetElem : It->second)
+            InvalidLoads.insert(FlowInSetElem.first);
+      }
+
+      // If this load is used outside this block, invalidate it.
+      if (UsedOutsideBlock)
+        InvalidLoads.insert(Load);
 
-  // First collect candidate load-store reduction chains by iterating over all
-  // stores and collecting possible reduction loads.
-  for (MemoryAccess *StoreMA : Stmt) {
-    if (StoreMA->isRead())
+      // And indicate that this load uses itself once but without specifying
+      // any reduction operator.
+      State[Load].insert(
+          std::make_pair(Load, std::make_pair(1, MemoryAccess::RT_BOTTOM)));
       continue;
+    }
 
-    Loads.clear();
-    collectCandidateReductionLoads(StoreMA, Loads);
-    for (MemoryAccess *LoadMA : Loads)
-      Candidates.push_back(std::make_pair(LoadMA, StoreMA));
-  }
+    if (auto *Store = dyn_cast<StoreInst>(&Inst)) {
+      // Invalidate all loads which feed into the address of this store.
+      if (const Instruction *Ptr =
+              dyn_cast<Instruction>(Store->getPointerOperand())) {
+        const auto &It = State.find(Ptr);
+        if (It != State.end())
+          for (const auto &FlowInSetElem : It->second)
+            InvalidLoads.insert(FlowInSetElem.first);
+      }
 
-  // Then check each possible candidate pair.
-  for (const auto &CandidatePair : Candidates) {
-    MemoryAccess *LoadMA = CandidatePair.first;
-    MemoryAccess *StoreMA = CandidatePair.second;
-    bool Valid = checkCandidatePairAccesses(LoadMA, StoreMA, Stmt.getDomain(),
-                                            Stmt.MemAccs);
-    if (!Valid)
+      // Propagate the uses of the value operand to the store
+      if (auto *ValueInst = dyn_cast<Instruction>(Store->getValueOperand()))
+        State.insert(std::make_pair(Store, State[ValueInst]));
       continue;
+    }
+
+    // Non load and store instructions are either binary operators or they will
+    // invalidate all used loads.
+    auto *BinOp = dyn_cast<BinaryOperator>(&Inst);
+    auto CurRedType = getReductionType(BinOp);
+    LLVM_DEBUG(dbgs() << "CurInst: " << Inst << " RT: " << CurRedType << "\n");
+
+    // Iterate over all operands and propagate their input loads to instruction.
+    FlowInSetTy &InstInFlowSet = State[&Inst];
+    for (Use &Op : Inst.operands()) {
+      auto *OpInst = dyn_cast<Instruction>(Op);
+      if (!OpInst)
+        continue;
+
+      LLVM_DEBUG(dbgs().indent(4) << "Op Inst: " << *OpInst << "\n");
+      const StateTy::iterator &OpInFlowSetIt = State.find(OpInst);
+      if (OpInFlowSetIt == State.end())
+        continue;
+
+      // Iterate over all the input loads of the operand and combine them
+      // with the input loads of current instruction.
+      FlowInSetTy &OpInFlowSet = OpInFlowSetIt->second;
+      for (auto &OpInFlowPair : OpInFlowSet) {
+        unsigned OpFlowIn = OpInFlowPair.second.first;
+        unsigned InstFlowIn = InstInFlowSet[OpInFlowPair.first].first;
+
+        auto OpRedType = OpInFlowPair.second.second;
+        auto InstRedType = InstInFlowSet[OpInFlowPair.first].second;
+
+        auto NewRedType = combineReductionType(OpRedType, CurRedType);
+        if (InstFlowIn)
+          NewRedType = combineReductionType(NewRedType, InstRedType);
+
+        LLVM_DEBUG(dbgs().indent(8) << "OpRedType: " << OpRedType << "\n");
+        LLVM_DEBUG(dbgs().indent(8) << "NewRedType: " << NewRedType << "\n");
+        InstInFlowSet[OpInFlowPair.first] =
+            std::make_pair(OpFlowIn + InstFlowIn, NewRedType);
+      }
+    }
+
+    // If this operation is used outside the block, invalidate all the loads
+    // which feed into it.
+    if (UsedOutsideBlock)
+      for (const auto &FlowInSetElem : InstInFlowSet)
+        InvalidLoads.insert(FlowInSetElem.first);
+  }
+
+  // All used loads are propagated through the whole basic block; now try to
+  // find valid reduction like candidate pairs. These load-store pairs fulfill
+  // all reduction like properties with regards to only this load-store chain.
+  // We later have to check if the loaded value was invalidated by an
+  // instruction not in that chain.
+  using MemAccPair = std::pair<MemoryAccess *, MemoryAccess *>;
+  DenseMap<MemAccPair, MemoryAccess::ReductionType> ValidCandidates;
+  DominatorTree *DT = Stmt.getParent()->getDT();
+
+  // Iterate over all write memory accesses and check the loads flowing into
+  // it for reduction candidate pairs.
+  for (MemoryAccess *WriteMA : Stmt.MemAccs) {
+    if (WriteMA->isRead())
+      continue;
+    StoreInst *St = dyn_cast<StoreInst>(WriteMA->getAccessInstruction());
+    if (!St || St->isVolatile())
+      continue;
+
+    FlowInSetTy &MaInFlowSet = State[WriteMA->getAccessInstruction()];
+    bool Valid = false;
+
+    for (auto &MaInFlowSetElem : MaInFlowSet) {
+      MemoryAccess *ReadMA = &Stmt.getArrayAccessFor(MaInFlowSetElem.first);
+      assert(ReadMA && "Couldn't find memory access for incoming load!");
 
-    const LoadInst *Load =
-        dyn_cast<const LoadInst>(CandidatePair.first->getAccessInstruction());
-    MemoryAccess::ReductionType RT =
-        getReductionType(dyn_cast<BinaryOperator>(Load->user_back()), Load);
+      LLVM_DEBUG(dbgs() << "'" << *ReadMA->getAccessInstruction()
+                        << "'\n\tflows into\n'"
+                        << *WriteMA->getAccessInstruction() << "'\n\t #"
+                        << MaInFlowSetElem.second.first << " times & RT: "
+                        << MaInFlowSetElem.second.second << "\n");
 
-    // If no overlapping access was found we mark the load and store as
-    // reduction like.
-    LoadMA->markAsReductionLike(RT);
-    StoreMA->markAsReductionLike(RT);
+      MemoryAccess::ReductionType RT = MaInFlowSetElem.second.second;
+      unsigned NumAllowableInFlow = 1;
+
+      // We allow the load to flow in exactly once for binary reductions
+      Valid = (MaInFlowSetElem.second.first == NumAllowableInFlow);
----------------
Meinersbur wrote:

```suggestion
      bool Valid = (MaInFlowSetElem.second.first == NumAllowableInFlow);
```

https://github.com/llvm/llvm-project/pull/84901


More information about the llvm-commits mailing list