[llvm] Redesign Straight-Line Strength Reduction (SLSR) (PR #162930)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 10 17:52:11 PDT 2025
================
@@ -561,118 +1067,164 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
DL->getIndexSizeInBits(GEP->getAddressSpace())) {
// Skip factoring if TruncatedArrayIdx is wider than the pointer size,
// because TruncatedArrayIdx is implicitly truncated to the pointer size.
- factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize, GEP);
+ allocateCandidatesAndFindBasis(Candidate::GEP, BaseExpr, ElementSizeIdx,
+ TruncatedArrayIdx, GEP);
}
IndexExprs[I - 1] = OrigIndexExpr;
}
}
-// A helper function that unifies the bitwidth of A and B.
-static void unifyBitWidth(APInt &A, APInt &B) {
- if (A.getBitWidth() < B.getBitWidth())
- A = A.sext(B.getBitWidth());
- else if (A.getBitWidth() > B.getBitWidth())
- B = B.sext(A.getBitWidth());
-}
-
Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis,
const Candidate &C,
IRBuilder<> &Builder,
const DataLayout *DL) {
- APInt Idx = C.Index->getValue(), BasisIdx = Basis.Index->getValue();
- unifyBitWidth(Idx, BasisIdx);
- APInt IndexOffset = Idx - BasisIdx;
+ auto CreateMul = [&](Value *LHS, Value *RHS) {
+ if (isa<ConstantInt>(RHS)) {
+ APInt ConstRHS = cast<ConstantInt>(RHS)->getValue();
+ IntegerType *DeltaType =
+ IntegerType::get(C.Ins->getContext(), ConstRHS.getBitWidth());
+ if (ConstRHS.isPowerOf2()) {
+ ConstantInt *Exponent =
+ ConstantInt::get(DeltaType, ConstRHS.logBase2());
+ return Builder.CreateShl(LHS, Exponent);
+ }
+ if (ConstRHS.isNegatedPowerOf2()) {
+ ConstantInt *Exponent =
+ ConstantInt::get(DeltaType, (-ConstRHS).logBase2());
+ return Builder.CreateNeg(Builder.CreateShl(LHS, Exponent));
+ }
+ }
- // Compute Bump = C - Basis = (i' - i) * S.
- // Common case 1: if (i' - i) is 1, Bump = S.
- if (IndexOffset == 1)
- return C.Stride;
- // Common case 2: if (i' - i) is -1, Bump = -S.
- if (IndexOffset.isAllOnes())
- return Builder.CreateNeg(C.Stride);
+ return Builder.CreateMul(LHS, RHS);
+ };
- // Otherwise, Bump = (i' - i) * sext/trunc(S). Note that (i' - i) and S may
- // have different bit widths.
- IntegerType *DeltaType =
- IntegerType::get(Basis.Ins->getContext(), IndexOffset.getBitWidth());
- Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, DeltaType);
- if (IndexOffset.isPowerOf2()) {
- // If (i' - i) is a power of 2, Bump = sext/trunc(S) << log(i' - i).
- ConstantInt *Exponent = ConstantInt::get(DeltaType, IndexOffset.logBase2());
- return Builder.CreateShl(ExtendedStride, Exponent);
- }
- if (IndexOffset.isNegatedPowerOf2()) {
- // If (i - i') is a power of 2, Bump = -sext/trunc(S) << log(i' - i).
- ConstantInt *Exponent =
- ConstantInt::get(DeltaType, (-IndexOffset).logBase2());
- return Builder.CreateNeg(Builder.CreateShl(ExtendedStride, Exponent));
+ if (C.DeltaKind == Candidate::IndexDelta) {
+ APInt IndexOffset = cast<ConstantInt>(C.Delta)->getValue();
+ // IndexDelta
+ // X = B + i * S
+ // Y = B + i' * S
+ // = B + i * S + (i' - i) * S
+ // = X + Delta * S
+ // Bump = (i' - i) * S
+
+ // If Delta is 0, C is fully redundant with C.Basis;
+ // just replace C.Ins with Basis.Ins.
+ if (IndexOffset.isZero())
+ return nullptr;
+
+ // Compute Bump = C - Basis = (i' - i) * S.
+ // Common case 1: if (i' - i) is 1, Bump = S.
+ if (IndexOffset == 1)
+ return C.Stride;
+ // Common case 2: if (i' - i) is -1, Bump = -S.
+ if (IndexOffset.isAllOnes())
+ return Builder.CreateNeg(C.Stride);
+
+ IntegerType *DeltaType =
+ IntegerType::get(Basis.Ins->getContext(), IndexOffset.getBitWidth());
+ Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, DeltaType);
+
+ return CreateMul(ExtendedStride, C.Delta);
+ } else {
+ assert(C.DeltaKind == Candidate::StrideDelta ||
+ C.DeltaKind == Candidate::BaseDelta);
+ assert(C.CandidateKind != Candidate::Mul);
+ // StrideDelta
+ // X = B + i * S
+ // Y = B + i * S'
+ // = B + i * (S + Delta)
+ // = B + i * S + i * Delta
+ // = X + i * StrideDelta
+ // Bump = i * (S' - S)
+ //
+ // BaseDelta
+ // X = B + i * S
+ // Y = B' + i * S
+ // = (B + Delta) + i * S
+ // = X + BaseDelta
+ // Bump = (B' - B).
+ Value *Bump = C.Delta;
+ if (C.DeltaKind == Candidate::StrideDelta) {
+ // If C is a GEP, sign-extend or truncate StrideDelta to the GEP's index
+ // type before computing StrideDelta * Index, matching the original GEP.
+ if (C.CandidateKind == Candidate::GEP) {
+ auto *GEP = cast<GetElementPtrInst>(C.Ins);
+ Type *NewScalarIndexTy =
+ DL->getIndexType(GEP->getPointerOperandType()->getScalarType());
+ Bump = Builder.CreateSExtOrTrunc(Bump, NewScalarIndexTy);
+ }
+ if (!C.Index->isOne()) {
+ Value *ExtendedIndex =
+ Builder.CreateSExtOrTrunc(C.Index, Bump->getType());
+ Bump = CreateMul(Bump, ExtendedIndex);
+ }
+ }
+ return Bump;
}
- Constant *Delta = ConstantInt::get(DeltaType, IndexOffset);
- return Builder.CreateMul(ExtendedStride, Delta);
}
void StraightLineStrengthReduce::rewriteCandidateWithBasis(
const Candidate &C, const Candidate &Basis) {
if (!DebugCounter::shouldExecute(StraightLineStrengthReduceCounter))
return;
- assert(C.CandidateKind == Basis.CandidateKind && C.Base == Basis.Base &&
- C.Stride == Basis.Stride);
- // We run rewriteCandidateWithBasis on all candidates in a post-order, so the
- // basis of a candidate cannot be unlinked before the candidate.
- assert(Basis.Ins->getParent() != nullptr && "the basis is unlinked");
-
- // An instruction can correspond to multiple candidates. Therefore, instead of
- // simply deleting an instruction when we rewrite it, we mark its parent as
- // nullptr (i.e. unlink it) so that we can skip the candidates whose
- // instruction is already rewritten.
- if (!C.Ins->getParent())
- return;
+ // At most one of Base, Index, and Stride may differ from Basis;
+ // the other two must be the same.
+ assert(C.Delta && C.CandidateKind == Basis.CandidateKind &&
+ ((C.Base == Basis.Base && C.StrideSCEV == Basis.StrideSCEV &&
+ C.DeltaKind == Candidate::IndexDelta) ||
+ (C.Base == Basis.Base && C.Index == Basis.Index &&
+ C.DeltaKind == Candidate::StrideDelta) ||
+ (C.StrideSCEV == Basis.StrideSCEV && C.Index == Basis.Index &&
+ C.DeltaKind == Candidate::BaseDelta)));
----------------
arsenm wrote:
Can you move this to operator==
https://github.com/llvm/llvm-project/pull/162930
More information about the llvm-commits
mailing list