[llvm] Redesign Straight-Line Strength Reduction (SLSR) (PR #162930)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 10 17:52:11 PDT 2025


================
@@ -269,17 +590,284 @@ FunctionPass *llvm::createStraightLineStrengthReducePass() {
   return new StraightLineStrengthReduceLegacyPass();
 }
 
-bool StraightLineStrengthReduce::isBasisFor(const Candidate &Basis,
-                                            const Candidate &C) {
-  return (Basis.Ins != C.Ins && // skip the same instruction
-          // They must have the same type too. Basis.Base == C.Base
-          // doesn't guarantee their types are the same (PR23975).
-          Basis.Ins->getType() == C.Ins->getType() &&
-          // Basis must dominate C in order to rewrite C with respect to Basis.
-          DT->dominates(Basis.Ins->getParent(), C.Ins->getParent()) &&
-          // They share the same base, stride, and candidate kind.
-          Basis.Base == C.Base && Basis.Stride == C.Stride &&
-          Basis.CandidateKind == C.CandidateKind);
+// Sign-extend whichever of A and B is narrower so that both APInts end up
+// with the same bit width. When the widths already agree, neither is touched.
+static void unifyBitWidth(APInt &A, APInt &B) {
+  const unsigned WidthA = A.getBitWidth();
+  const unsigned WidthB = B.getBitWidth();
+  if (WidthA == WidthB)
+    return;
+  if (WidthA < WidthB)
+    A = A.sext(WidthB);
+  else
+    B = B.sext(WidthA);
+}
+
+// Return C's constant index minus Basis's constant index as an integer
+// constant. Both indices are first sign-extended to a common bit width (see
+// unifyBitWidth), and the resulting constant's integer type has that width.
+Constant *StraightLineStrengthReduce::getIndexDelta(Candidate &C,
+                                                    Candidate &Basis) {
+  APInt CandIndex = C.Index->getValue();
+  APInt BaseIndex = Basis.Index->getValue();
+  unifyBitWidth(CandIndex, BaseIndex);
+  const APInt Offset = CandIndex - BaseIndex;
+  auto &Ctx = C.Ins->getContext();
+  return ConstantInt::get(IntegerType::get(Ctx, Offset.getBitWidth()), Offset);
+}
+
+// Return true if Basis can pair with C under delta kind K. The following
+// must all hold:
+//  * the component compared by K has the same type in both candidates
+//    (stride SCEV for StrideDelta, base for BaseDelta; IndexDelta adds no
+//    type check, and any other K is rejected);
+//  * Basis is a different instruction from C;
+//  * both candidates have the same candidate kind.
+bool StraightLineStrengthReduce::isSimilar(Candidate &C, Candidate &Basis,
+                                           Candidate::DKind K) {
+  bool TypesMatch;
+  if (K == Candidate::StrideDelta)
+    TypesMatch = C.StrideSCEV->getType() == Basis.StrideSCEV->getType();
+  else if (K == Candidate::BaseDelta)
+    TypesMatch = C.Base->getType() == Basis.Base->getType();
+  else
+    TypesMatch = K == Candidate::IndexDelta;
+
+  if (!TypesMatch)
+    return false;
+  // A candidate can never serve as its own basis.
+  if (Basis.Ins == C.Ins)
+    return false;
+  return Basis.CandidateKind == C.CandidateKind;
+}
+
+void StraightLineStrengthReduce::setBasisAndDeltaFor(Candidate &C) {
+  auto SearchFrom = [this, &C](const CandidateDictTy::BBToCandsTy &BBToCands,
+                               auto IsTarget) -> bool {
+    // Search dominating candidates by walking the immediate-dominator chain
+    // from the candidate's defining block upward. Visiting blocks in this
+    // order ensures we prefer the closest dominating basis.
+    const BasicBlock *BB = C.Ins->getParent();
+    while (BB) {
+      auto It = BBToCands.find(BB);
+      if (It != BBToCands.end())
+        for (Candidate *Basis : reverse(It->second))
+          if (IsTarget(Basis))
+            return true;
+
+      const DomTreeNode *Node = DT->getNode(BB);
+      if (!Node)
+        break;
+      Node = Node->getIDom();
+      BB = Node ? Node->getBlock() : nullptr;
+    }
+    return false;
+  };
+
+  // Priority:
+  // Constant Delta from Index > Constant Delta from Base >
+  // Constant Delta from Stride > Variable Delta from Base or Stride
+  // TODO: Change the priority to align with the cost model.
+
+  // First, look for a constant index-diff basis
+  if (const auto *IndexDeltaCandidates =
+          CandidateDict.getCandidatesWithDeltaKind(C, Candidate::IndexDelta)) {
+    bool FoundConstDelta =
+        SearchFrom(*IndexDeltaCandidates, [&DT = DT, &C](Candidate *Basis) {
+          if (isSimilar(C, *Basis, Candidate::IndexDelta)) {
+            assert(DT->dominates(Basis->Ins, C.Ins));
+            auto *Delta = getIndexDelta(C, *Basis);
+            if (!C.isProfitableRewrite(Delta, Candidate::IndexDelta))
+              return false;
+            C.Basis = Basis;
+            C.DeltaKind = Candidate::IndexDelta;
+            C.Delta = Delta;
+            LLVM_DEBUG(dbgs() << "Found delta from Index " << *C.Delta << "\n");
+            return true;
+          }
+          return false;
+        });
+    if (FoundConstDelta)
+      return;
+  }
+
+  // No constant-index-diff basis found. look for the best possible base-diff
+  // or stride-diff basis
+  // Base/Stride diffs not supported for form (B + i) * S
+  if (C.CandidateKind == Candidate::Mul)
+    return;
+
+  auto For = [this, &C](Candidate::DKind K) {
+    // return true if find a Basis with constant delta and stop searching,
+    // return false if did not find a Basis or the delta is not a constant
+    // and continue searching for a Basis with constant delta
+    return [K, this, &C](Candidate *Basis) -> bool {
+      if (!isSimilar(C, *Basis, K))
+        return false;
+
+      assert(DT->dominates(Basis->Ins, C.Ins));
+      const SCEV *BasisPart =
+          (K == Candidate::BaseDelta) ? Basis->Base : Basis->StrideSCEV;
+      const SCEV *CandPart =
+          (K == Candidate::BaseDelta) ? C.Base : C.StrideSCEV;
+      const SCEV *Diff = SE->getMinusSCEV(CandPart, BasisPart);
+      Value *AvailableVal = getNearestValueOfSCEV(Diff, C.Ins);
+      if (!AvailableVal)
+        return false;
+
+      // Record delta if none has been found yet, or the new delta is
+      // a constant that is better than the existing delta.
+      if (!C.Delta || isa<ConstantInt>(AvailableVal)) {
+        C.Delta = AvailableVal;
+        C.Basis = Basis;
+        C.DeltaKind = K;
+      }
+      return isa<ConstantInt>(C.Delta);
+    };
+  };
+
+  if (const auto *BaseDeltaCandidates =
+          CandidateDict.getCandidatesWithDeltaKind(C, Candidate::BaseDelta)) {
+    if (SearchFrom(*BaseDeltaCandidates, For(Candidate::BaseDelta))) {
+      LLVM_DEBUG(dbgs() << "Found delta from Base: " << *C.Delta << "\n");
+      return;
+    }
+  }
+
+  if (const auto *StrideDeltaCandidates =
+          CandidateDict.getCandidatesWithDeltaKind(C, Candidate::StrideDelta)) {
+    if (SearchFrom(*StrideDeltaCandidates, For(Candidate::StrideDelta))) {
+      LLVM_DEBUG(dbgs() << "Found delta from Stride: " << *C.Delta << "\n");
+      return;
+    }
+  }
+
+  // If we did not find a constant delta, we might have found a variable delta
+  if (C.Delta) {
+    LLVM_DEBUG(dbgs() << "Found delta from ";
+               if (C.DeltaKind == Candidate::BaseDelta) dbgs() << "Base: ";
+               else dbgs() << "Stride: "; dbgs() << *C.Delta << "\n");
----------------
arsenm wrote:

```suggestion
    // Braces make the macro argument an ordinary compound statement so
    // clang-format can lay it out; every statement inside, including the
    // last one, needs its own terminating ';'.
    LLVM_DEBUG({
      dbgs() << "Found delta from ";
      if (C.DeltaKind == Candidate::BaseDelta)
        dbgs() << "Base: ";
      else
        dbgs() << "Stride: ";
      dbgs() << *C.Delta << '\n';
    });
```
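Wrapping the macro argument in `{}` lets clang-format lay it out like ordinary statements. Note that the final statement inside the braces then needs its own trailing `;`.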

https://github.com/llvm/llvm-project/pull/162930


More information about the llvm-commits mailing list