[polly] [Polly] Data flow reduction detection to cover more cases (PR #84901)
Michael Kruse via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 22 08:22:17 PDT 2024
================
@@ -2568,49 +2579,196 @@ bool checkCandidatePairAccesses(MemoryAccess *LoadMA, MemoryAccess *StoreMA,
// Finally, check if they are no other instructions accessing this memory
isl::map AllAccsRel = LoadAccs.unite(StoreAccs);
AllAccsRel = AllAccsRel.intersect_domain(Domain);
+
isl::set AllAccs = AllAccsRel.range();
+
Valid = !hasIntersectingAccesses(AllAccs, LoadMA, StoreMA, Domain, MemAccs);
LLVM_DEBUG(dbgs() << " == The accessed memory is " << (Valid ? "not " : "")
<< "accessed by other instructions!\n");
}
+
return Valid;
}
-void ScopBuilder::checkForReductions(ScopStmt &Stmt) {
- SmallVector<MemoryAccess *, 2> Loads;
- SmallVector<std::pair<MemoryAccess *, MemoryAccess *>, 4> Candidates;
+/// Perform a data flow analysis on the current basic block to propagate the
+/// uses of loaded values. Then check and mark the memory accesses which are
+/// part of reduction like chains.
+///
+/// NOTE: This assumes independent blocks and breaks otherwise.
+void ScopBuilder::checkForReductions(ScopStmt &Stmt, BasicBlock *Block) {
+ // During the data flow anaylis we use the State variable to keep track of
+ // the used "load-instructions" for each instruction in the basic block.
+ // This includes the LLVM-IR of the load and the "number of uses" (or the
+ // number of paths in the operand tree which end in this load).
+ using StatePairTy = std::pair<unsigned, MemoryAccess::ReductionType>;
+ using FlowInSetTy = MapVector<const LoadInst *, StatePairTy>;
+ using StateTy = MapVector<const Instruction *, FlowInSetTy>;
+ StateTy State;
+
+ // Invalid loads are loads which have uses we can't track properly in the
+ // state map. This includes loads which:
+ // o do not form a reduction when they flow into a memory location:
+ // (e.g., A[i] = B[i] * 3 and A[i] = A[i] * A[i] + A[i])
+ // o are used by a non binary operator or one which is not commutative
+ // and associative (e.g., A[i] = A[i] % 3)
+ // o might change the control flow (e.g., if (A[i]))
+ // o are used in indirect memory accesses (e.g., A[B[i]])
+ // o are used outside the current basic block
+ SmallPtrSet<const Instruction *, 8> InvalidLoads;
+
+ // Run the data flow analysis for all values in the basic block
+ for (Instruction &Inst : *Block) {
+ bool UsedOutsideBlock = any_of(Inst.users(), [Block](User *U) {
+ return cast<Instruction>(U)->getParent() != Block;
+ });
+
+ // Treat loads and stores special
+ if (auto *Load = dyn_cast<LoadInst>(&Inst)) {
+ // Invalidate all loads used which feed into the address of this load.
+ if (auto *Ptr = dyn_cast<Instruction>(Load->getPointerOperand())) {
+ const auto &It = State.find(Ptr);
+ if (It != State.end())
+ for (const auto &FlowInSetElem : It->second)
+ InvalidLoads.insert(FlowInSetElem.first);
+ }
+
+ // If this load is used outside this block, invalidate it.
+ if (UsedOutsideBlock)
+ InvalidLoads.insert(Load);
- // First collect candidate load-store reduction chains by iterating over all
- // stores and collecting possible reduction loads.
- for (MemoryAccess *StoreMA : Stmt) {
- if (StoreMA->isRead())
+ // And indicate that this load uses itself once but without specifying
+ // any reduction operator.
+ State[Load].insert(
+ std::make_pair(Load, std::make_pair(1, MemoryAccess::RT_BOTTOM)));
continue;
+ }
- Loads.clear();
- collectCandidateReductionLoads(StoreMA, Loads);
- for (MemoryAccess *LoadMA : Loads)
- Candidates.push_back(std::make_pair(LoadMA, StoreMA));
- }
+ if (auto *Store = dyn_cast<StoreInst>(&Inst)) {
+ // Invalidate all loads which feed into the address of this store.
+ if (const Instruction *Ptr =
+ dyn_cast<Instruction>(Store->getPointerOperand())) {
+ const auto &It = State.find(Ptr);
+ if (It != State.end())
+ for (const auto &FlowInSetElem : It->second)
+ InvalidLoads.insert(FlowInSetElem.first);
+ }
- // Then check each possible candidate pair.
- for (const auto &CandidatePair : Candidates) {
- MemoryAccess *LoadMA = CandidatePair.first;
- MemoryAccess *StoreMA = CandidatePair.second;
- bool Valid = checkCandidatePairAccesses(LoadMA, StoreMA, Stmt.getDomain(),
- Stmt.MemAccs);
- if (!Valid)
+ // Propagate the uses of the value operand to the store
+ if (auto *ValueInst = dyn_cast<Instruction>(Store->getValueOperand()))
+ State.insert(std::make_pair(Store, State[ValueInst]));
continue;
+ }
+
+ // Non load and store instructions are either binary operators or they will
+ // invalidate all used loads.
+ auto *BinOp = dyn_cast<BinaryOperator>(&Inst);
+ auto CurRedType = getReductionType(BinOp);
+ LLVM_DEBUG(dbgs() << "CurInst: " << Inst << " RT: " << CurRedType << "\n");
+
+ // Iterate over all operands and propagate their input loads to instruction.
+ FlowInSetTy &InstInFlowSet = State[&Inst];
+ for (Use &Op : Inst.operands()) {
+ auto *OpInst = dyn_cast<Instruction>(Op);
+ if (!OpInst)
+ continue;
+
+ LLVM_DEBUG(dbgs().indent(4) << "Op Inst: " << *OpInst << "\n");
+ const StateTy::iterator &OpInFlowSetIt = State.find(OpInst);
+ if (OpInFlowSetIt == State.end())
+ continue;
+
+ // Iterate over all the input loads of the operand and combine them
+ // with the input loads of current instruction.
+ FlowInSetTy &OpInFlowSet = OpInFlowSetIt->second;
+ for (auto &OpInFlowPair : OpInFlowSet) {
+ unsigned OpFlowIn = OpInFlowPair.second.first;
+ unsigned InstFlowIn = InstInFlowSet[OpInFlowPair.first].first;
+
+ auto OpRedType = OpInFlowPair.second.second;
+ auto InstRedType = InstInFlowSet[OpInFlowPair.first].second;
+
+ auto NewRedType = combineReductionType(OpRedType, CurRedType);
+ if (InstFlowIn)
+ NewRedType = combineReductionType(NewRedType, InstRedType);
+
+ LLVM_DEBUG(dbgs().indent(8) << "OpRedType: " << OpRedType << "\n");
+ LLVM_DEBUG(dbgs().indent(8) << "NewRedType: " << NewRedType << "\n");
+ InstInFlowSet[OpInFlowPair.first] =
+ std::make_pair(OpFlowIn + InstFlowIn, NewRedType);
+ }
+ }
+
+ // If this operation is used outside the block, invalidate all the loads
+ // which feed into it.
+ if (UsedOutsideBlock)
+ for (const auto &FlowInSetElem : InstInFlowSet)
+ InvalidLoads.insert(FlowInSetElem.first);
+ }
+
+ // All used loads are propagated through the whole basic block; now try to
+ // find valid reduction like candidate pairs. These load-store pairs fulfill
+ // all reduction like properties with regards to only this load-store chain.
+ // We later have to check if the loaded value was invalidated by an
+ // instruction not in that chain.
+ using MemAccPair = std::pair<MemoryAccess *, MemoryAccess *>;
+ DenseMap<MemAccPair, MemoryAccess::ReductionType> ValidCandidates;
+ DominatorTree *DT = Stmt.getParent()->getDT();
+
+ // Iterate over all write memory accesses and check the loads flowing into
+ // it for reduction candidate pairs.
+ for (MemoryAccess *WriteMA : Stmt.MemAccs) {
+ if (WriteMA->isRead())
+ continue;
+ StoreInst *St = dyn_cast<StoreInst>(WriteMA->getAccessInstruction());
+ if (!St || St->isVolatile())
+ continue;
+
+ FlowInSetTy &MaInFlowSet = State[WriteMA->getAccessInstruction()];
+ bool Valid = false;
----------------
Meinersbur wrote:
```suggestion
```
Only used inside the loop
https://github.com/llvm/llvm-project/pull/84901
More information about the llvm-commits mailing list