[llvm] [SCEV] Add initial support for ptrtoaddr. (PR #158032)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 14 14:52:25 PDT 2025


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/158032

>From 8f742069375cf4ef87dde3662d8a82c7cd2c60ef Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 14 Oct 2025 15:06:55 +0100
Subject: [PATCH] [SCEV] Add initial support for ptrtoaddr.

Add support for PtrToAddr to SCEV and use it to compute pointer difference when computing backedge taken counts.

This patch retains SCEVPtrToInt, which is mostly used for expressions coming directly from IR.

PtrToAddr more closely matches the semantics of most uses in SCEV. See #156978 for some related discussion.

The patch still retains SCEVPtrToInt, as it is used for cases where a SCEV expression is used in IR by a ptrtoint, e.g. https://github.com/llvm/llvm-project/blob/main/llvm/test/Transforms/IndVarSimplify/pr59633.ll. Some (or perhaps all — still need to check) of these cases should probably just be SCEVUnknown.
---
 llvm/include/llvm/Analysis/ScalarEvolution.h  |   9 +-
 .../llvm/Analysis/ScalarEvolutionDivision.h   |   1 +
 .../Analysis/ScalarEvolutionExpressions.h     |  28 +++-
 .../Utils/ScalarEvolutionExpander.h           |   2 +
 llvm/lib/Analysis/ScalarEvolution.cpp         | 121 +++++++++++++-----
 .../Utils/ScalarEvolutionExpander.cpp         |  11 ++
 .../Analysis/ScalarEvolution/ptrtoaddr.ll     |  42 +++---
 .../LoopVectorize/expand-ptrtoaddr.ll         |  47 +++++++
 8 files changed, 207 insertions(+), 54 deletions(-)
 create mode 100644 llvm/test/Transforms/LoopVectorize/expand-ptrtoaddr.ll

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index e5a6c8cc0a6aa..5d7e6417ebd15 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -563,8 +563,13 @@ class ScalarEvolution {
   LLVM_ABI const SCEV *getConstant(ConstantInt *V);
   LLVM_ABI const SCEV *getConstant(const APInt &Val);
   LLVM_ABI const SCEV *getConstant(Type *Ty, uint64_t V, bool isSigned = false);
-  LLVM_ABI const SCEV *getLosslessPtrToIntExpr(const SCEV *Op,
-                                               unsigned Depth = 0);
+  LLVM_ABI const SCEV *getLosslessPtrToIntOrAddrExpr(SCEVTypes Kind,
+                                                     const SCEV *Op,
+                                                     unsigned Depth = 0);
+  LLVM_ABI const SCEV *getLosslessPtrToIntExpr(const SCEV *Op);
+  LLVM_ABI const SCEV *getLosslessPtrToAddrExpr(const SCEV *Op);
+
+  LLVM_ABI const SCEV *getPtrToAddrExpr(const SCEV *Op, Type *Ty);
   LLVM_ABI const SCEV *getPtrToIntExpr(const SCEV *Op, Type *Ty);
   LLVM_ABI const SCEV *getTruncateExpr(const SCEV *Op, Type *Ty,
                                        unsigned Depth = 0);
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h b/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h
index 7c78487fea8c0..d3df7e346b4a5 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h
@@ -46,6 +46,7 @@ struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {
 
   // Except in the trivial case described above, we do not know how to divide
   // Expr by Denominator for the following functions with empty implementation.
+  void visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Numerator) {}
   void visitPtrToIntExpr(const SCEVPtrToIntExpr *Numerator) {}
   void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {}
   void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {}
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
index 13b9e1b812942..6ae7b1b08773a 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -53,6 +53,7 @@ enum SCEVTypes : unsigned short {
   scSMinExpr,
   scSequentialUMinExpr,
   scPtrToInt,
+  scPtrToAddr,
   scUnknown,
   scCouldNotCompute
 };
@@ -121,8 +122,9 @@ class SCEVCastExpr : public SCEV {
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const SCEV *S) {
-    return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
-           S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
+    return S->getSCEVType() == scPtrToAddr || S->getSCEVType() == scPtrToInt ||
+           S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
+           S->getSCEVType() == scSignExtend;
   }
 };
 
@@ -138,6 +140,18 @@ class SCEVPtrToIntExpr : public SCEVCastExpr {
   static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
 };
 
+/// This class represents a cast from a pointer to its address (an integer of
+/// the pointer's index width), without capturing the provenance of the pointer.
+class SCEVPtrToAddrExpr : public SCEVCastExpr {
+  friend class ScalarEvolution;
+
+  SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
+
+public:
+  /// Methods for support type inquiry through isa, cast, and dyn_cast:
+  static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToAddr; }
+};
+
 /// This is the base class for unary integral cast operator classes.
 class SCEVIntegralCastExpr : public SCEVCastExpr {
 protected:
@@ -615,6 +629,8 @@ template <typename SC, typename RetVal = void> struct SCEVVisitor {
       return ((SC *)this)->visitConstant((const SCEVConstant *)S);
     case scVScale:
       return ((SC *)this)->visitVScale((const SCEVVScale *)S);
+    case scPtrToAddr:
+      return ((SC *)this)->visitPtrToAddrExpr((const SCEVPtrToAddrExpr *)S);
     case scPtrToInt:
       return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
     case scTruncate:
@@ -685,6 +701,7 @@ template <typename SV> class SCEVTraversal {
       case scVScale:
       case scUnknown:
         continue;
+      case scPtrToAddr:
       case scPtrToInt:
       case scTruncate:
       case scZeroExtend:
@@ -774,6 +791,13 @@ class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
 
   const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; }
 
+  const SCEV *visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) {
+    const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
+    return Operand == Expr->getOperand()
+               ? Expr
+               : SE.getPtrToAddrExpr(Operand, Expr->getType());
+  }
+
   const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
     return Operand == Expr->getOperand()
diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
index 310118078695c..c986f78db5c19 100644
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -498,6 +498,8 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
 
   Value *visitVScale(const SCEVVScale *S);
 
+  Value *visitPtrToAddrExpr(const SCEVPtrToAddrExpr *S);
+
   Value *visitPtrToIntExpr(const SCEVPtrToIntExpr *S);
 
   Value *visitTruncateExpr(const SCEVTruncateExpr *S);
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index a64b93d541943..5e789323a04be 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -277,11 +277,13 @@ void SCEV::print(raw_ostream &OS) const {
   case scVScale:
     OS << "vscale";
     return;
+  case scPtrToAddr:
   case scPtrToInt: {
-    const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
-    const SCEV *Op = PtrToInt->getOperand();
-    OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
-       << *PtrToInt->getType() << ")";
+    const SCEVCastExpr *PtrCast = cast<SCEVCastExpr>(this);
+    const SCEV *Op = PtrCast->getOperand();
+    StringRef OpS = getSCEVType() == scPtrToAddr ? "addr" : "int";
+    OS << "(ptrto" << OpS << " " << *Op->getType() << " " << *Op << " to "
+       << *PtrCast->getType() << ")";
     return;
   }
   case scTruncate: {
@@ -386,6 +388,7 @@ Type *SCEV::getType() const {
     return cast<SCEVConstant>(this)->getType();
   case scVScale:
     return cast<SCEVVScale>(this)->getType();
+  case scPtrToAddr:
   case scPtrToInt:
   case scTruncate:
   case scZeroExtend:
@@ -420,6 +423,7 @@ ArrayRef<const SCEV *> SCEV::operands() const {
   case scVScale:
   case scUnknown:
     return {};
+  case scPtrToAddr:
   case scPtrToInt:
   case scTruncate:
   case scZeroExtend:
@@ -512,6 +516,13 @@ SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                            const SCEV *op, Type *ty)
     : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
 
+SCEVPtrToAddrExpr::SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID,
+                                     const SCEV *Op, Type *ITy)
+    : SCEVCastExpr(ID, scPtrToAddr, Op, ITy) {
+  assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
+         "Must be a non-bit-width-changing pointer-to-address cast!");
+}
+
 SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
                                    Type *ITy)
     : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
@@ -724,6 +735,7 @@ CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
   case scTruncate:
   case scZeroExtend:
   case scSignExtend:
+  case scPtrToAddr:
   case scPtrToInt:
   case scAddExpr:
   case scMulExpr:
@@ -1004,10 +1016,11 @@ SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
 //                    SCEV Expression folder implementations
 //===----------------------------------------------------------------------===//
 
-const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
-                                                     unsigned Depth) {
+const SCEV *ScalarEvolution::getLosslessPtrToIntOrAddrExpr(SCEVTypes Kind,
+                                                           const SCEV *Op,
+                                                           unsigned Depth) {
   assert(Depth <= 1 &&
-         "getLosslessPtrToIntExpr() should self-recurse at most once.");
+         "getLosslessPtrToIntOrAddrExpr() should self-recurse at most once.");
 
   // We could be called with an integer-typed operands during SCEV rewrites.
   // Since the operand is an integer already, just perform zext/trunc/self cast.
@@ -1052,35 +1065,45 @@ const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
     // Create an explicit cast node.
     // We can reuse the existing insert position since if we get here,
     // we won't have made any changes which would invalidate it.
-    SCEV *S = new (SCEVAllocator)
-        SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
+    SCEV *S;
+    if (Kind == scPtrToInt) {
+      S = new (SCEVAllocator)
+          SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
+    } else {
+      S = new (SCEVAllocator)
+          SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
+    }
     UniqueSCEVs.InsertNode(S, IP);
     registerUser(S, Op);
     return S;
   }
 
-  assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
-                       "non-SCEVUnknown's.");
+  assert(Depth == 0 &&
+         "getLosslessPtrToIntOrAddrExpr() should not self-recurse for "
+         "non-SCEVUnknown's.");
 
   // Otherwise, we've got some expression that is more complex than just a
-  // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
-  // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
-  // only, and the expressions must otherwise be integer-typed.
-  // So sink the cast down to the SCEVUnknown's.
+  // single SCEVUnknown. But we don't want to have a SCEVPtrTo(Int|Addr)Expr of
+  // an arbitrary expression, we want to have SCEVPtrTo(Int|Addr)Expr of an
+  // SCEVUnknown only, and the expressions must otherwise be integer-typed. So
+  // sink the cast down to the SCEVUnknown's.
 
-  /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
+  /// The SCEVPtrToIntOrAddrSinkingRewriter takes a scalar evolution expression,
   /// which computes a pointer-typed value, and rewrites the whole expression
   /// tree so that *all* the computations are done on integers, and the only
   /// pointer-typed operands in the expression are SCEVUnknown.
-  class SCEVPtrToIntSinkingRewriter
-      : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
-    using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
+  class SCEVPtrToIntOrAddrSinkingRewriter
+      : public SCEVRewriteVisitor<SCEVPtrToIntOrAddrSinkingRewriter> {
+    using Base = SCEVRewriteVisitor<SCEVPtrToIntOrAddrSinkingRewriter>;
+    const SCEVTypes Kind;
 
   public:
-    SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
+    SCEVPtrToIntOrAddrSinkingRewriter(SCEVTypes Kind, ScalarEvolution &SE)
+        : SCEVRewriteVisitor(SE), Kind(Kind) {}
 
-    static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
-      SCEVPtrToIntSinkingRewriter Rewriter(SE);
+    static const SCEV *rewrite(const SCEV *Scev, SCEVTypes Kind,
+                               ScalarEvolution &SE) {
+      SCEVPtrToIntOrAddrSinkingRewriter Rewriter(Kind, SE);
       return Rewriter.visit(Scev);
     }
 
@@ -1116,18 +1139,37 @@ const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
       assert(Expr->getType()->isPointerTy() &&
              "Should only reach pointer-typed SCEVUnknown's.");
-      return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
+      return SE.getLosslessPtrToIntOrAddrExpr(Kind, Expr, /*Depth=*/1);
     }
   };
 
   // And actually perform the cast sinking.
-  const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
+  const SCEV *IntOp =
+      SCEVPtrToIntOrAddrSinkingRewriter::rewrite(Op, Kind, *this);
   assert(IntOp->getType()->isIntegerTy() &&
          "We must have succeeded in sinking the cast, "
          "and ending up with an integer-typed expression!");
   return IntOp;
 }
 
+const SCEV *ScalarEvolution::getLosslessPtrToAddrExpr(const SCEV *Op) {
+  return getLosslessPtrToIntOrAddrExpr(scPtrToAddr, Op);
+}
+
+const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op) {
+  return getLosslessPtrToIntOrAddrExpr(scPtrToInt, Op);
+}
+
+const SCEV *ScalarEvolution::getPtrToAddrExpr(const SCEV *Op, Type *Ty) {
+  assert(Ty->isIntegerTy() && "Target type must be an integer type!");
+
+  const SCEV *IntOp = getLosslessPtrToAddrExpr(Op);
+  if (isa<SCEVCouldNotCompute>(IntOp))
+    return IntOp;
+
+  return getTruncateOrZeroExtend(IntOp, Ty);
+}
+
 const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
   assert(Ty->isIntegerTy() && "Target type must be an integer type!");
 
@@ -4076,6 +4118,8 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final
 
   RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
 
+  RetVal visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) { return Expr; }
+
   RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
 
   RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
@@ -4126,6 +4170,7 @@ static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
   case scTruncate:
   case scZeroExtend:
   case scSignExtend:
+  case scPtrToAddr:
   case scPtrToInt:
   case scAddExpr:
   case scMulExpr:
@@ -6361,8 +6406,9 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
   switch (S->getSCEVType()) {
   case scConstant:
     return cast<SCEVConstant>(S)->getAPInt();
+  case scPtrToAddr:
   case scPtrToInt:
-    return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand(), CtxI);
+    return getConstantMultiple(cast<SCEVCastExpr>(S)->getOperand(), CtxI);
   case scUDivExpr:
   case scVScale:
     return APInt(BitWidth, 1);
@@ -6639,6 +6685,7 @@ ScalarEvolution::getRangeRefIter(const SCEV *S,
     case scTruncate:
     case scZeroExtend:
     case scSignExtend:
+    case scPtrToAddr:
     case scPtrToInt:
     case scAddExpr:
     case scMulExpr:
@@ -6766,10 +6813,11 @@ const ConstantRange &ScalarEvolution::getRangeRef(
         SExt, SignHint,
         ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
   }
+  case scPtrToAddr:
   case scPtrToInt: {
-    const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
-    ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
-    return setRange(PtrToInt, SignHint, X);
+    const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
+    ConstantRange X = getRangeRef(Cast->getOperand(), SignHint, Depth + 1);
+    return setRange(Cast, SignHint, X);
   }
   case scAddExpr: {
     const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
@@ -7662,6 +7710,7 @@ ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
   case Instruction::Trunc:
   case Instruction::ZExt:
   case Instruction::SExt:
+  case Instruction::PtrToAddr:
   case Instruction::PtrToInt:
     Ops.push_back(U->getOperand(0));
     return nullptr;
@@ -8134,13 +8183,16 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       return getSCEV(U->getOperand(0));
     break;
 
+  case Instruction::PtrToAddr:
   case Instruction::PtrToInt: {
     // Pointer to integer cast is straight-forward, so do model it.
     const SCEV *Op = getSCEV(U->getOperand(0));
     Type *DstIntTy = U->getType();
     // But only if effective SCEV (integer) type is wide enough to represent
     // all possible pointer values.
-    const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
+    const SCEV *IntOp = U->getOpcode() == Instruction::PtrToInt
+                            ? getPtrToIntExpr(Op, DstIntTy)
+                            : getPtrToAddrExpr(Op, DstIntTy);
     if (isa<SCEVCouldNotCompute>(IntOp))
       return getUnknown(V);
     return IntOp;
@@ -9942,6 +9994,13 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
     return cast<SCEVConstant>(V)->getValue();
   case scUnknown:
     return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
+  case scPtrToAddr: {
+    const SCEVPtrToAddrExpr *P2A = cast<SCEVPtrToAddrExpr>(V);
+    if (Constant *CastOp = BuildConstantFromSCEV(P2A->getOperand()))
+      return ConstantExpr::getPtrToAddr(CastOp, P2A->getType());
+
+    return nullptr;
+  }
   case scPtrToInt: {
     const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
     if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
@@ -10000,6 +10059,7 @@ ScalarEvolution::getWithOperands(const SCEV *S,
   case scTruncate:
   case scZeroExtend:
   case scSignExtend:
+  case scPtrToAddr:
   case scPtrToInt:
     return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
   case scAddRecExpr: {
@@ -10084,6 +10144,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
   case scTruncate:
   case scZeroExtend:
   case scSignExtend:
+  case scPtrToAddr:
   case scPtrToInt:
   case scAddExpr:
   case scMulExpr:
@@ -14187,6 +14248,7 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
   case scTruncate:
   case scZeroExtend:
   case scSignExtend:
+  case scPtrToAddr:
   case scPtrToInt:
   case scAddExpr:
   case scMulExpr:
@@ -14268,6 +14330,7 @@ ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
   case scTruncate:
   case scZeroExtend:
   case scSignExtend:
+  case scPtrToAddr:
   case scPtrToInt:
   case scAddExpr:
   case scMulExpr:
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 9035e58a707c4..6f08eabdd90de 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -457,6 +457,7 @@ const Loop *SCEVExpander::getRelevantLoop(const SCEV *S) {
   case scTruncate:
   case scZeroExtend:
   case scSignExtend:
+  case scPtrToAddr:
   case scPtrToInt:
   case scAddExpr:
   case scMulExpr:
@@ -1433,6 +1434,12 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
   return expand(T);
 }
 
+Value *SCEVExpander::visitPtrToAddrExpr(const SCEVPtrToAddrExpr *S) {
+  Value *V = expand(S->getOperand());
+  return ReuseOrCreateCast(V, S->getType(), CastInst::PtrToAddr,
+                           GetOptimalInsertionPointForCastOf(V));
+}
+
 Value *SCEVExpander::visitPtrToIntExpr(const SCEVPtrToIntExpr *S) {
   Value *V = expand(S->getOperand());
   return ReuseOrCreateCast(V, S->getType(), CastInst::PtrToInt,
@@ -1957,6 +1964,9 @@ template<typename T> static InstructionCost costAndCollectOperands(
   case scConstant:
   case scVScale:
     return 0;
+  case scPtrToAddr:
+    Cost = CastCost(Instruction::PtrToAddr);
+    break;
   case scPtrToInt:
     Cost = CastCost(Instruction::PtrToInt);
     break;
@@ -2080,6 +2090,7 @@ bool SCEVExpander::isHighCostExpansionHelper(
     return Cost > Budget;
   }
   case scTruncate:
+  case scPtrToAddr:
   case scPtrToInt:
   case scZeroExtend:
   case scSignExtend: {
diff --git a/llvm/test/Analysis/ScalarEvolution/ptrtoaddr.ll b/llvm/test/Analysis/ScalarEvolution/ptrtoaddr.ll
index ebab9f0f20323..e553e8efe6784 100644
--- a/llvm/test/Analysis/ScalarEvolution/ptrtoaddr.ll
+++ b/llvm/test/Analysis/ScalarEvolution/ptrtoaddr.ll
@@ -8,25 +8,25 @@ define void @ptrtoaddr(ptr %in, ptr %out0, ptr %out1, ptr %out2, ptr %out3) {
 ; X64-LABEL: 'ptrtoaddr'
 ; X64-NEXT:  Classifying expressions for: @ptrtoaddr
 ; X64-NEXT:    %p0 = ptrtoaddr ptr %in to i64
-; X64-NEXT:    --> %p0 U: full-set S: full-set
+; X64-NEXT:    --> (ptrtoaddr ptr %in to i64) U: full-set S: full-set
 ; X64-NEXT:    %p1 = ptrtoaddr ptr %in to i32
-; X64-NEXT:    --> %p1 U: full-set S: full-set
+; X64-NEXT:    --> (trunc i64 (ptrtoaddr ptr %in to i64) to i32) U: full-set S: full-set
 ; X64-NEXT:    %p2 = ptrtoaddr ptr %in to i16
-; X64-NEXT:    --> %p2 U: full-set S: full-set
+; X64-NEXT:    --> (trunc i64 (ptrtoaddr ptr %in to i64) to i16) U: full-set S: full-set
 ; X64-NEXT:    %p3 = ptrtoaddr ptr %in to i128
-; X64-NEXT:    --> %p3 U: full-set S: full-set
+; X64-NEXT:    --> (zext i64 (ptrtoaddr ptr %in to i64) to i128) U: [0,18446744073709551616) S: [0,18446744073709551616)
 ; X64-NEXT:  Determining loop execution counts for: @ptrtoaddr
 ;
 ; X32-LABEL: 'ptrtoaddr'
 ; X32-NEXT:  Classifying expressions for: @ptrtoaddr
 ; X32-NEXT:    %p0 = ptrtoaddr ptr %in to i64
-; X32-NEXT:    --> %p0 U: full-set S: full-set
+; X32-NEXT:    --> (zext i32 (ptrtoaddr ptr %in to i32) to i64) U: [0,4294967296) S: [0,4294967296)
 ; X32-NEXT:    %p1 = ptrtoaddr ptr %in to i32
-; X32-NEXT:    --> %p1 U: full-set S: full-set
+; X32-NEXT:    --> (ptrtoaddr ptr %in to i32) U: full-set S: full-set
 ; X32-NEXT:    %p2 = ptrtoaddr ptr %in to i16
-; X32-NEXT:    --> %p2 U: full-set S: full-set
+; X32-NEXT:    --> (trunc i32 (ptrtoaddr ptr %in to i32) to i16) U: full-set S: full-set
 ; X32-NEXT:    %p3 = ptrtoaddr ptr %in to i128
-; X32-NEXT:    --> %p3 U: full-set S: full-set
+; X32-NEXT:    --> (zext i32 (ptrtoaddr ptr %in to i32) to i128) U: [0,4294967296) S: [0,4294967296)
 ; X32-NEXT:  Determining loop execution counts for: @ptrtoaddr
 ;
   %p0 = ptrtoaddr ptr %in to i64
@@ -44,25 +44,25 @@ define void @ptrtoaddr_as1(ptr addrspace(1) %in, ptr %out0, ptr %out1, ptr %out2
 ; X64-LABEL: 'ptrtoaddr_as1'
 ; X64-NEXT:  Classifying expressions for: @ptrtoaddr_as1
 ; X64-NEXT:    %p0 = ptrtoaddr ptr addrspace(1) %in to i64
-; X64-NEXT:    --> %p0 U: full-set S: full-set
+; X64-NEXT:    --> (ptrtoaddr ptr addrspace(1) %in to i64) U: full-set S: full-set
 ; X64-NEXT:    %p1 = ptrtoaddr ptr addrspace(1) %in to i32
-; X64-NEXT:    --> %p1 U: full-set S: full-set
+; X64-NEXT:    --> (trunc i64 (ptrtoaddr ptr addrspace(1) %in to i64) to i32) U: full-set S: full-set
 ; X64-NEXT:    %p2 = ptrtoaddr ptr addrspace(1) %in to i16
-; X64-NEXT:    --> %p2 U: full-set S: full-set
+; X64-NEXT:    --> (trunc i64 (ptrtoaddr ptr addrspace(1) %in to i64) to i16) U: full-set S: full-set
 ; X64-NEXT:    %p3 = ptrtoaddr ptr addrspace(1) %in to i128
-; X64-NEXT:    --> %p3 U: full-set S: full-set
+; X64-NEXT:    --> (zext i64 (ptrtoaddr ptr addrspace(1) %in to i64) to i128) U: [0,18446744073709551616) S: [0,18446744073709551616)
 ; X64-NEXT:  Determining loop execution counts for: @ptrtoaddr_as1
 ;
 ; X32-LABEL: 'ptrtoaddr_as1'
 ; X32-NEXT:  Classifying expressions for: @ptrtoaddr_as1
 ; X32-NEXT:    %p0 = ptrtoaddr ptr addrspace(1) %in to i64
-; X32-NEXT:    --> %p0 U: full-set S: full-set
+; X32-NEXT:    --> (zext i32 (ptrtoaddr ptr addrspace(1) %in to i32) to i64) U: [0,4294967296) S: [0,4294967296)
 ; X32-NEXT:    %p1 = ptrtoaddr ptr addrspace(1) %in to i32
-; X32-NEXT:    --> %p1 U: full-set S: full-set
+; X32-NEXT:    --> (ptrtoaddr ptr addrspace(1) %in to i32) U: full-set S: full-set
 ; X32-NEXT:    %p2 = ptrtoaddr ptr addrspace(1) %in to i16
-; X32-NEXT:    --> %p2 U: full-set S: full-set
+; X32-NEXT:    --> (trunc i32 (ptrtoaddr ptr addrspace(1) %in to i32) to i16) U: full-set S: full-set
 ; X32-NEXT:    %p3 = ptrtoaddr ptr addrspace(1) %in to i128
-; X32-NEXT:    --> %p3 U: full-set S: full-set
+; X32-NEXT:    --> (zext i32 (ptrtoaddr ptr addrspace(1) %in to i32) to i128) U: [0,4294967296) S: [0,4294967296)
 ; X32-NEXT:  Determining loop execution counts for: @ptrtoaddr_as1
 ;
   %p0 = ptrtoaddr ptr addrspace(1) %in to i64
@@ -82,7 +82,7 @@ define void @ptrtoaddr_of_bitcast(ptr %in, ptr %out0) {
 ; X64-NEXT:    %in_casted = bitcast ptr %in to ptr
 ; X64-NEXT:    --> %in U: full-set S: full-set
 ; X64-NEXT:    %p0 = ptrtoaddr ptr %in_casted to i64
-; X64-NEXT:    --> %p0 U: full-set S: full-set
+; X64-NEXT:    --> (ptrtoaddr ptr %in to i64) U: full-set S: full-set
 ; X64-NEXT:  Determining loop execution counts for: @ptrtoaddr_of_bitcast
 ;
 ; X32-LABEL: 'ptrtoaddr_of_bitcast'
@@ -90,7 +90,7 @@ define void @ptrtoaddr_of_bitcast(ptr %in, ptr %out0) {
 ; X32-NEXT:    %in_casted = bitcast ptr %in to ptr
 ; X32-NEXT:    --> %in U: full-set S: full-set
 ; X32-NEXT:    %p0 = ptrtoaddr ptr %in_casted to i64
-; X32-NEXT:    --> %p0 U: full-set S: full-set
+; X32-NEXT:    --> (zext i32 (ptrtoaddr ptr %in to i32) to i64) U: [0,4294967296) S: [0,4294967296)
 ; X32-NEXT:  Determining loop execution counts for: @ptrtoaddr_of_bitcast
 ;
   %in_casted = bitcast ptr %in to ptr
@@ -103,7 +103,7 @@ define void @ptrtoaddr_of_nullptr(ptr %out0) {
 ; ALL-LABEL: 'ptrtoaddr_of_nullptr'
 ; ALL-NEXT:  Classifying expressions for: @ptrtoaddr_of_nullptr
 ; ALL-NEXT:    %p0 = ptrtoaddr ptr null to i64
-; ALL-NEXT:    --> %p0 U: full-set S: full-set
+; ALL-NEXT:    --> 0 U: [0,1) S: [0,1)
 ; ALL-NEXT:  Determining loop execution counts for: @ptrtoaddr_of_nullptr
 ;
   %p0 = ptrtoaddr ptr null to i64
@@ -117,7 +117,7 @@ define void @ptrtoaddr_of_gep(ptr %in, ptr %out0) {
 ; X64-NEXT:    %in_adj = getelementptr inbounds i8, ptr %in, i64 42
 ; X64-NEXT:    --> (42 + %in) U: full-set S: full-set
 ; X64-NEXT:    %p0 = ptrtoaddr ptr %in_adj to i64
-; X64-NEXT:    --> %p0 U: full-set S: full-set
+; X64-NEXT:    --> (42 + (ptrtoaddr ptr %in to i64)) U: full-set S: full-set
 ; X64-NEXT:  Determining loop execution counts for: @ptrtoaddr_of_gep
 ;
 ; X32-LABEL: 'ptrtoaddr_of_gep'
@@ -125,7 +125,7 @@ define void @ptrtoaddr_of_gep(ptr %in, ptr %out0) {
 ; X32-NEXT:    %in_adj = getelementptr inbounds i8, ptr %in, i64 42
 ; X32-NEXT:    --> (42 + %in) U: full-set S: full-set
 ; X32-NEXT:    %p0 = ptrtoaddr ptr %in_adj to i64
-; X32-NEXT:    --> %p0 U: full-set S: full-set
+; X32-NEXT:    --> (zext i32 (42 + (ptrtoaddr ptr %in to i32)) to i64) U: [0,4294967296) S: [0,4294967296)
 ; X32-NEXT:  Determining loop execution counts for: @ptrtoaddr_of_gep
 ;
   %in_adj = getelementptr inbounds i8, ptr %in, i64 42
diff --git a/llvm/test/Transforms/LoopVectorize/expand-ptrtoaddr.ll b/llvm/test/Transforms/LoopVectorize/expand-ptrtoaddr.ll
new file mode 100644
index 0000000000000..e925e8422c3fd
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/expand-ptrtoaddr.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals none --filter-out-after "scalar.ph:" --version 6
+; RUN: opt -p loop-vectorize -force-vector-width=4 -S %s | FileCheck %s
+
+define void @test_ptrtoaddr_tripcount(ptr %start, ptr %end) {
+; CHECK-LABEL: define void @test_ptrtoaddr_tripcount(
+; CHECK-SAME: ptr [[START:%.*]], ptr [[END:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[START_ADDR:%.*]] = ptrtoaddr ptr [[START]] to i64
+; CHECK-NEXT:    [[END_ADDR:%.*]] = ptrtoaddr ptr [[END]] to i64
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[END_ADDR]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sub i64 [[TMP0]], [[START_ADDR]]
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP1]], 4
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
+; CHECK:       [[VECTOR_PH]]:
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP1]], 4
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP1]], [[N_MOD_VF]]
+; CHECK-NEXT:    [[TMP2:%.*]] = add i64 [[START_ADDR]], [[N_VEC]]
+; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
+; CHECK:       [[VECTOR_BODY]]:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = add i64 [[START_ADDR]], [[INDEX]]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, ptr [[START]], i64 [[OFFSET_IDX]]
+; CHECK-NEXT:    store <4 x i8> zeroinitializer, ptr [[TMP3]], align 1
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP4]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       [[MIDDLE_BLOCK]]:
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], [[EXIT:label %.*]], label %[[SCALAR_PH]]
+; CHECK:       [[SCALAR_PH]]:
+;
+entry:
+  %start.addr = ptrtoaddr ptr %start to i64
+  %end.addr = ptrtoaddr ptr %end to i64
+  br label %loop
+
+loop:
+  %iv = phi i64 [ %start.addr, %entry ], [ %iv.next, %loop ]
+  %gep = getelementptr inbounds i8, ptr %start, i64 %iv
+  store i8 0, ptr %gep
+  %iv.next = add i64 %iv, 1
+  %cmp = icmp ne i64 %iv, %end.addr
+  br i1 %cmp, label %loop, label %exit
+
+exit:
+  ret void
+}



More information about the llvm-commits mailing list