[llvm] [VPlan] Extend getSCEVForVPV, use to compute VPReplicateRecipe cost. (PR #161276)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 3 06:45:17 PDT 2025


================
@@ -86,6 +87,85 @@ const SCEV *vputils::getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE) {
   return TypeSwitch<const VPRecipeBase *, const SCEV *>(V->getDefiningRecipe())
       .Case<VPExpandSCEVRecipe>(
           [](const VPExpandSCEVRecipe *R) { return R->getSCEV(); })
+      .Case<VPCanonicalIVPHIRecipe>([&SE, L](const VPCanonicalIVPHIRecipe *R) {
+        if (!L)
+          return SE.getCouldNotCompute();
+        const SCEV *Start = getSCEVExprForVPValue(R->getOperand(0), SE, L);
+        return SE.getAddRecExpr(Start, SE.getOne(Start->getType()), L,
+                                SCEV::FlagAnyWrap);
+      })
+      .Case<VPDerivedIVRecipe>([&SE, L](const VPDerivedIVRecipe *R) {
+        const SCEV *Start = getSCEVExprForVPValue(R->getOperand(0), SE, L);
+        const SCEV *IV = getSCEVExprForVPValue(R->getOperand(1), SE, L);
+        const SCEV *Scale = getSCEVExprForVPValue(R->getOperand(2), SE, L);
+        if (any_of(ArrayRef({Start, IV, Scale}), IsaPred<SCEVCouldNotCompute>))
+          return SE.getCouldNotCompute();
+
+        return SE.getAddExpr(SE.getTruncateOrSignExtend(Start, IV->getType()),
+                             SE.getMulExpr(IV, SE.getTruncateOrSignExtend(
+                                                   Scale, IV->getType())));
+      })
+      .Case<VPScalarIVStepsRecipe>([&SE, L](const VPScalarIVStepsRecipe *R) {
+        return getSCEVExprForVPValue(R->getOperand(0), SE, L);
+      })
+      .Case<VPReplicateRecipe>([&SE, L](const VPReplicateRecipe *R) {
+        if (R->getOpcode() != Instruction::GetElementPtr)
+          return SE.getCouldNotCompute();
+
+        const SCEV *Base = getSCEVExprForVPValue(R->getOperand(0), SE, L);
+        if (isa<SCEVCouldNotCompute>(Base))
+          return SE.getCouldNotCompute();
+
+        Type *IntIdxTy = SE.getEffectiveSCEVType(Base->getType());
+        Type *CurTy = IntIdxTy;
+        bool FirstIter = true;
+        SmallVector<const SCEV *, 4> Offsets;
+        for (VPValue *Index : drop_begin(R->operands())) {
+          const SCEV *IndexExpr = getSCEVExprForVPValue(Index, SE, L);
+          if (isa<SCEVCouldNotCompute>(IndexExpr))
+            return SE.getCouldNotCompute();
+          // Compute the (potentially symbolic) offset in bytes for this index.
+          if (StructType *STy = dyn_cast<StructType>(CurTy)) {
+            // For a struct, add the member offset.
+            ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
+            unsigned FieldNo = Index->getZExtValue();
+            const SCEV *FieldOffset =
+                SE.getOffsetOfExpr(IntIdxTy, STy, FieldNo);
+            Offsets.push_back(FieldOffset);
+
+            // Update CurTy to the type of the field at Index.
+            CurTy = STy->getTypeAtIndex(Index);
+          } else {
+            // Update CurTy to its element type.
+            if (FirstIter) {
+              CurTy = cast<GetElementPtrInst>(R->getUnderlyingInstr())
+                          ->getSourceElementType();
+              FirstIter = false;
+            } else {
+              CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
+            }
+            // For an array, add the element offset, explicitly scaled.
+            const SCEV *ElementSize = SE.getSizeOfExpr(IntIdxTy, CurTy);
+            // Getelementptr indices are signed.
+            IndexExpr = SE.getTruncateOrSignExtend(IndexExpr, IntIdxTy);
+
+            // Multiply the index by the element size to compute the element
+            // offset.
+            const SCEV *LocalOffset = SE.getMulExpr(IndexExpr, ElementSize);
+            Offsets.push_back(LocalOffset);
+          }
+        }
+        // Handle degenerate case of GEP without offsets.
+        if (Offsets.empty())
+          return Base;
+
+        // Add the offsets together, assuming nsw if inbounds.
+        const SCEV *Offset = SE.getAddExpr(Offsets);
+        // Add the base address and the offset. We cannot use the nsw flag, as
+        // the base address is unsigned. However, if we know that the offset is
+        // non-negative, we can use nuw.
+        return SE.getAddExpr(Base, Offset);
+      })
----------------
artagnon wrote:

I'm confused: this looks like a copy of ScalarEvolution::getGEPExpr. Why not reuse it?

https://github.com/llvm/llvm-project/pull/161276


More information about the llvm-commits mailing list