[llvm] Redesign Straight-Line Strength Reduction (SLSR) (PR #162930)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 10 17:52:11 PDT 2025


================
@@ -561,118 +1067,164 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
             DL->getIndexSizeInBits(GEP->getAddressSpace())) {
       // Skip factoring if TruncatedArrayIdx is wider than the pointer size,
       // because TruncatedArrayIdx is implicitly truncated to the pointer size.
-      factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize, GEP);
+      allocateCandidatesAndFindBasis(Candidate::GEP, BaseExpr, ElementSizeIdx,
+                                     TruncatedArrayIdx, GEP);
     }
 
     IndexExprs[I - 1] = OrigIndexExpr;
   }
 }
 
-// A helper function that unifies the bitwidth of A and B.
-static void unifyBitWidth(APInt &A, APInt &B) {
-  if (A.getBitWidth() < B.getBitWidth())
-    A = A.sext(B.getBitWidth());
-  else if (A.getBitWidth() > B.getBitWidth())
-    B = B.sext(A.getBitWidth());
-}
-
 Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis,
                                             const Candidate &C,
                                             IRBuilder<> &Builder,
                                             const DataLayout *DL) {
-  APInt Idx = C.Index->getValue(), BasisIdx = Basis.Index->getValue();
-  unifyBitWidth(Idx, BasisIdx);
-  APInt IndexOffset = Idx - BasisIdx;
+  auto CreateMul = [&](Value *LHS, Value *RHS) {
+    if (isa<ConstantInt>(RHS)) {
+      APInt ConstRHS = cast<ConstantInt>(RHS)->getValue();
+      IntegerType *DeltaType =
+          IntegerType::get(C.Ins->getContext(), ConstRHS.getBitWidth());
+      if (ConstRHS.isPowerOf2()) {
+        ConstantInt *Exponent =
+            ConstantInt::get(DeltaType, ConstRHS.logBase2());
+        return Builder.CreateShl(LHS, Exponent);
+      }
+      if (ConstRHS.isNegatedPowerOf2()) {
+        ConstantInt *Exponent =
+            ConstantInt::get(DeltaType, (-ConstRHS).logBase2());
+        return Builder.CreateNeg(Builder.CreateShl(LHS, Exponent));
+      }
+    }
 
-  // Compute Bump = C - Basis = (i' - i) * S.
-  // Common case 1: if (i' - i) is 1, Bump = S.
-  if (IndexOffset == 1)
-    return C.Stride;
-  // Common case 2: if (i' - i) is -1, Bump = -S.
-  if (IndexOffset.isAllOnes())
-    return Builder.CreateNeg(C.Stride);
+    return Builder.CreateMul(LHS, RHS);
+  };
 
-  // Otherwise, Bump = (i' - i) * sext/trunc(S). Note that (i' - i) and S may
-  // have different bit widths.
-  IntegerType *DeltaType =
-      IntegerType::get(Basis.Ins->getContext(), IndexOffset.getBitWidth());
-  Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, DeltaType);
-  if (IndexOffset.isPowerOf2()) {
-    // If (i' - i) is a power of 2, Bump = sext/trunc(S) << log(i' - i).
-    ConstantInt *Exponent = ConstantInt::get(DeltaType, IndexOffset.logBase2());
-    return Builder.CreateShl(ExtendedStride, Exponent);
-  }
-  if (IndexOffset.isNegatedPowerOf2()) {
-    // If (i - i') is a power of 2, Bump = -sext/trunc(S) << log(i' - i).
-    ConstantInt *Exponent =
-        ConstantInt::get(DeltaType, (-IndexOffset).logBase2());
-    return Builder.CreateNeg(Builder.CreateShl(ExtendedStride, Exponent));
+  if (C.DeltaKind == Candidate::IndexDelta) {
+    APInt IndexOffset = cast<ConstantInt>(C.Delta)->getValue();
+    // IndexDelta: C and Basis differ only in the index (i vs. i').
+    // X = B + i * S
+    // Y = B + i' * S
+    //   = B + i * S + (i' - i) * S
+    //   = X + Delta * S
+    // Bump = (i' - i) * S
+
+    // If Delta is 0, C is fully redundant with its basis: return nullptr so
+    // that C.Ins can be replaced directly with Basis.Ins (no bump needed).
+    if (IndexOffset.isZero())
+      return nullptr;
+
+    // Compute Bump = C - Basis = (i' - i) * S.
+    // Common case 1: if (i' - i) is 1, Bump = S.
+    if (IndexOffset == 1)
+      return C.Stride;
+    // Common case 2: if (i' - i) is -1, Bump = -S.
+    if (IndexOffset.isAllOnes())
+      return Builder.CreateNeg(C.Stride);
+
+    IntegerType *DeltaType =
+        IntegerType::get(Basis.Ins->getContext(), IndexOffset.getBitWidth());
+    Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, DeltaType);
+
+    return CreateMul(ExtendedStride, C.Delta);
+  } else {
+    assert(C.DeltaKind == Candidate::StrideDelta ||
+           C.DeltaKind == Candidate::BaseDelta);
+    assert(C.CandidateKind != Candidate::Mul);
+    // StrideDelta: C and Basis differ only in the stride (S vs. S').
+    // X = B + i * S
+    // Y = B + i * S'
+    //   = B + i * (S + Delta)
+    //   = B + i * S + i * Delta
+    //   = X + i * Delta
+    // Bump = i * (S' - S)
+    //
+    // BaseDelta: C and Basis differ only in the base (B vs. B').
+    // X = B  + i * S
+    // Y = B' + i * S
+    //   = (B + Delta) + i * S
+    //   = X + Delta
+    // Bump = (B' - B).
+    Value *Bump = C.Delta;
+    if (C.DeltaKind == Candidate::StrideDelta) {
+      // If C is a GEP, first sext/trunc Delta to the GEP's index type so that
+      // Delta * Index has the same semantics as the original GEP arithmetic.
+      if (C.CandidateKind == Candidate::GEP) {
+        auto *GEP = cast<GetElementPtrInst>(C.Ins);
+        Type *NewScalarIndexTy =
+            DL->getIndexType(GEP->getPointerOperandType()->getScalarType());
+        Bump = Builder.CreateSExtOrTrunc(Bump, NewScalarIndexTy);
+      }
+      if (!C.Index->isOne()) {
+        Value *ExtendedIndex =
+            Builder.CreateSExtOrTrunc(C.Index, Bump->getType());
+        Bump = CreateMul(Bump, ExtendedIndex);
+      }
+    }
+    return Bump;
   }
-  Constant *Delta = ConstantInt::get(DeltaType, IndexOffset);
-  return Builder.CreateMul(ExtendedStride, Delta);
 }
 
 void StraightLineStrengthReduce::rewriteCandidateWithBasis(
     const Candidate &C, const Candidate &Basis) {
   if (!DebugCounter::shouldExecute(StraightLineStrengthReduceCounter))
     return;
 
-  assert(C.CandidateKind == Basis.CandidateKind && C.Base == Basis.Base &&
-         C.Stride == Basis.Stride);
-  // We run rewriteCandidateWithBasis on all candidates in a post-order, so the
-  // basis of a candidate cannot be unlinked before the candidate.
-  assert(Basis.Ins->getParent() != nullptr && "the basis is unlinked");
-
-  // An instruction can correspond to multiple candidates. Therefore, instead of
-  // simply deleting an instruction when we rewrite it, we mark its parent as
-  // nullptr (i.e. unlink it) so that we can skip the candidates whose
-  // instruction is already rewritten.
-  if (!C.Ins->getParent())
-    return;
+  // If one of Base, Index, and Stride are different,
+  // other parts must be the same
+  assert(C.Delta && C.CandidateKind == Basis.CandidateKind &&
+         ((C.Base == Basis.Base && C.StrideSCEV == Basis.StrideSCEV &&
+           C.DeltaKind == Candidate::IndexDelta) ||
+          (C.Base == Basis.Base && C.Index == Basis.Index &&
+           C.DeltaKind == Candidate::StrideDelta) ||
+          (C.StrideSCEV == Basis.StrideSCEV && C.Index == Basis.Index &&
+           C.DeltaKind == Candidate::BaseDelta)));
----------------
arsenm wrote:

Can you move this to operator==?

https://github.com/llvm/llvm-project/pull/162930


More information about the llvm-commits mailing list