[llvm] [SCEVAA] Allowing to subtract two inttoptrs with different pointer bases (PR #91453)

via llvm-commits llvm-commits at lists.llvm.org
Wed May 8 03:27:26 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: None (csstormq)

<details>
<summary>Changes</summary>

In our out-of-tree target, all memory accesses are to locations allocated at compile time using `inttoptr`. We depend on `scev-aa` to break dependencies among these memory locations. However, `ScalarEvolution::getMinusSCEV()` forbids subtracting pointers with different pointer bases. That is, the function always returns `SCEVCouldNotCompute` for such pointers, which does not meet our expectations.

To solve this problem, we allow subtracting two `inttoptr`s regardless of whether their pointer bases are the same. For the example shown in this pull request, `SCEVAAResult::alias()` will return `NoAlias` rather than `MayAlias`.

---
Full diff: https://github.com/llvm/llvm-project/pull/91453.diff


4 Files Affected:

- (modified) llvm/include/llvm/Analysis/ScalarEvolution.h (+25) 
- (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+93-5) 
- (modified) llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp (+14-8) 
- (modified) llvm/test/Analysis/ScalarEvolution/scev-aa.ll (+26) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 5828cc156cc78..694d2dee9e875 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -721,6 +721,11 @@ class ScalarEvolution {
   const SCEV *getTruncateOrSignExtend(const SCEV *V, Type *Ty,
                                       unsigned Depth = 0);
 
+  /// Return a SCEV corresponding to a conversion of the input value to the
+  /// specified type.  If the type must be extended, it is any extended.
+  const SCEV *getTruncateOrAnyExtend(const SCEV *V, Type *Ty,
+                                     unsigned Depth = 0);
+
   /// Return a SCEV corresponding to a conversion of the input value to the
   /// specified type.  If the type must be extended, it is zero extended.  The
   /// conversion must not be narrowing.
@@ -754,6 +759,26 @@ class ScalarEvolution {
   const SCEV *getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
                                          bool Sequential = false);
 
+  /// Promote the operands to the wider of the types using any-extension, and
+  /// then perform a addrec operation with them.
+  const SCEV *
+  getAddRecExprFromMismatchedTypes(const SmallVectorImpl<const SCEV *> &Ops,
+                                   const Loop *L, SCEV::NoWrapFlags Flags);
+
+  /// Promote the operands to the wider of the types using any-extension, and
+  /// then perform a add operation with them.
+  const SCEV *
+  getAddExprFromMismatchedTypes(const SmallVectorImpl<const SCEV *> &Ops,
+                                SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
+                                unsigned Depth = 0);
+  const SCEV *
+  getAddExprFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS,
+                                SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
+                                unsigned Depth = 0) {
+    SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
+    return getAddExprFromMismatchedTypes(Ops, Flags, Depth);
+  }
+
   /// Transitively follow the chain of pointer-type operands until reaching a
   /// SCEV that does not have a single pointer operand. This returns a
   /// SCEVUnknown pointer for well-formed pointer-type expressions, but corner
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 93f885c5d5ad8..60bba380f66d6 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -4650,7 +4650,8 @@ const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
     Ops[0] = removePointerBase(Ops[0]);
     // Don't try to transfer nowrap flags for now. We could in some cases
     // (for example, if pointer operand of the AddRec is a SCEVUnknown).
-    return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
+    return getAddRecExprFromMismatchedTypes(Ops, AddRec->getLoop(),
+                                            SCEV::FlagAnyWrap);
   }
   if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
     // The base of an Add is the pointer operand.
@@ -4665,12 +4666,31 @@ const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
     *PtrOp = removePointerBase(*PtrOp);
     // Don't try to transfer nowrap flags for now. We could in some cases
     // (for example, if the pointer operand of the Add is a SCEVUnknown).
-    return getAddExpr(Ops);
+    return getAddExprFromMismatchedTypes(Ops);
   }
+
+  if (auto *Unknown = dyn_cast<SCEVUnknown>(P)) {
+    if (auto *O = dyn_cast<Operator>(Unknown->getValue())) {
+      if (O->getOpcode() == Instruction::IntToPtr) {
+        Value *Op0 = O->getOperand(0);
+        if (isa<ConstantInt>(Op0))
+          return getConstant(dyn_cast<ConstantInt>(Op0));
+        return getSCEV(Op0);
+      }
+    }
+  }
+
   // Any other expression must be a pointer base.
   return getZero(P->getType());
 }
 
+static bool isIntToPtr(const SCEV *V) {
+  if (auto *Unknown = dyn_cast<SCEVUnknown>(V))
+    if (auto *Op = dyn_cast<Operator>(Unknown->getValue()))
+      return Op->getOpcode() == Instruction::IntToPtr;
+  return false;
+}
+
 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
                                           SCEV::NoWrapFlags Flags,
                                           unsigned Depth) {
@@ -4678,12 +4698,15 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
   if (LHS == RHS)
     return getZero(LHS->getType());
 
-  // If we subtract two pointers with different pointer bases, bail.
+  // If we subtract two pointers except inttoptrs with different pointer bases,
+  // bail.
   // Eventually, we're going to add an assertion to getMulExpr that we
   // can't multiply by a pointer.
   if (RHS->getType()->isPointerTy()) {
+    const SCEV *LBase = getPointerBase(LHS);
+    const SCEV *RBase = getPointerBase(RHS);
     if (!LHS->getType()->isPointerTy() ||
-        getPointerBase(LHS) != getPointerBase(RHS))
+        (LBase != RBase && (!isIntToPtr(LBase) || !isIntToPtr(RBase))))
       return getCouldNotCompute();
     LHS = removePointerBase(LHS);
     RHS = removePointerBase(RHS);
@@ -4718,7 +4741,8 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
   // larger scope than intended.
   auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
 
-  return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
+  return getAddExprFromMismatchedTypes(LHS, getNegativeSCEV(RHS, NegFlags),
+                                       AddFlags, Depth);
 }
 
 const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
@@ -4745,6 +4769,18 @@ const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
   return getSignExtendExpr(V, Ty, Depth);
 }
 
+const SCEV *ScalarEvolution::getTruncateOrAnyExtend(const SCEV *V, Type *Ty,
+                                                    unsigned Depth) {
+  Type *SrcTy = V->getType();
+  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+         "Cannot truncate or any extend with non-integer arguments!");
+  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
+    return V; // No conversion
+  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
+    return getTruncateExpr(V, Ty, Depth);
+  return getAnyExtendExpr(V, Ty);
+}
+
 const SCEV *
 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
   Type *SrcTy = V->getType();
@@ -4839,6 +4875,58 @@ ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
   return getUMinExpr(PromotedOps, Sequential);
 }
 
+const SCEV *ScalarEvolution::getAddRecExprFromMismatchedTypes(
+    const SmallVectorImpl<const SCEV *> &Ops, const Loop *L,
+    SCEV::NoWrapFlags Flags) {
+  assert(!Ops.empty() && "At least one operand must be!");
+  // Trivial case.
+  if (Ops.size() == 1)
+    return Ops[0];
+
+  // Find the max type first.
+  Type *MaxType = nullptr;
+  for (const auto *S : Ops)
+    if (MaxType)
+      MaxType = getWiderType(MaxType, S->getType());
+    else
+      MaxType = S->getType();
+  assert(MaxType && "Failed to find maximum type!");
+
+  // Extend all ops to max type.
+  SmallVector<const SCEV *, 2> PromotedOps;
+  PromotedOps.reserve(Ops.size());
+  for (const auto *S : Ops)
+    PromotedOps.push_back(getNoopOrAnyExtend(S, MaxType));
+
+  return getAddRecExpr(PromotedOps, L, Flags);
+}
+
+const SCEV *ScalarEvolution::getAddExprFromMismatchedTypes(
+    const SmallVectorImpl<const SCEV *> &Ops, SCEV::NoWrapFlags Flags,
+    unsigned Depth) {
+  assert(!Ops.empty() && "At least one operand must be!");
+  // Trivial case.
+  if (Ops.size() == 1)
+    return Ops[0];
+
+  // Find the max type first.
+  Type *MaxType = nullptr;
+  for (const auto *S : Ops)
+    if (MaxType)
+      MaxType = getWiderType(MaxType, S->getType());
+    else
+      MaxType = S->getType();
+  assert(MaxType && "Failed to find maximum type!");
+
+  // Extend all ops to max type.
+  SmallVector<const SCEV *, 2> PromotedOps;
+  PromotedOps.reserve(Ops.size());
+  for (const auto *S : Ops)
+    PromotedOps.push_back(getNoopOrAnyExtend(S, MaxType));
+
+  return getAddExpr(PromotedOps, Flags, Depth);
+}
+
 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
   // A pointer operand may evaluate to a nonpointer expression, such as null.
   if (!V->getType()->isPointerTy())
diff --git a/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp b/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp
index af8232b03f1ed..c4e13989e17a2 100644
--- a/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp
@@ -67,10 +67,13 @@ AliasResult SCEVAAResult::alias(const MemoryLocation &LocA,
     // Test whether the difference is known to be great enough that memory of
     // the given sizes don't overlap. This assumes that ASizeInt and BSizeInt
     // are non-zero, which is special-cased above.
-    if (!isa<SCEVCouldNotCompute>(BA) &&
-        ASizeInt.ule(SE.getUnsignedRange(BA).getUnsignedMin()) &&
-        (-BSizeInt).uge(SE.getUnsignedRange(BA).getUnsignedMax()))
-      return AliasResult::NoAlias;
+    if (!isa<SCEVCouldNotCompute>(BA)) {
+      if (SE.isSCEVable(BA->getType()))
+        BA = SE.getTruncateOrAnyExtend(BA, AS->getType());
+      if (ASizeInt.ule(SE.getUnsignedRange(BA).getUnsignedMin()) &&
+          (-BSizeInt).uge(SE.getUnsignedRange(BA).getUnsignedMax()))
+        return AliasResult::NoAlias;
+    }
 
     // Folding the subtraction while preserving range information can be tricky
     // (because of INT_MIN, etc.); if the prior test failed, swap AS and BS
@@ -82,10 +85,13 @@ AliasResult SCEVAAResult::alias(const MemoryLocation &LocA,
     // Test whether the difference is known to be great enough that memory of
     // the given sizes don't overlap. This assumes that ASizeInt and BSizeInt
     // are non-zero, which is special-cased above.
-    if (!isa<SCEVCouldNotCompute>(AB) &&
-        BSizeInt.ule(SE.getUnsignedRange(AB).getUnsignedMin()) &&
-        (-ASizeInt).uge(SE.getUnsignedRange(AB).getUnsignedMax()))
-      return AliasResult::NoAlias;
+    if (!isa<SCEVCouldNotCompute>(AB)) {
+      if (SE.isSCEVable(AB->getType()))
+        AB = SE.getTruncateOrAnyExtend(AB, AS->getType());
+      if (BSizeInt.ule(SE.getUnsignedRange(AB).getUnsignedMin()) &&
+          (-ASizeInt).uge(SE.getUnsignedRange(AB).getUnsignedMax()))
+        return AliasResult::NoAlias;
+    }
   }
 
   // If ScalarEvolution can find an underlying object, form a new query.
diff --git a/llvm/test/Analysis/ScalarEvolution/scev-aa.ll b/llvm/test/Analysis/ScalarEvolution/scev-aa.ll
index a81baa73a93bd..5610833e9c474 100644
--- a/llvm/test/Analysis/ScalarEvolution/scev-aa.ll
+++ b/llvm/test/Analysis/ScalarEvolution/scev-aa.ll
@@ -340,3 +340,29 @@ for.latch:
 for.end:
   ret void
 }
+
+; CHECK-LABEL: Function: test_different_pointer_bases_of_inttoptr: 2 pointers, 0 call sites
+; CHECK:   NoAlias:	<16 x i8>* %tmp5, <16 x i8>* %tmp7
+
+define void @test_different_pointer_bases_of_inttoptr() {
+entry:
+  br label %for.body
+
+for.body:
+  %tmp = phi i32 [ %next, %for.body ], [ 1, %entry ]
+  %tmp1 = shl nsw i32 %tmp, 1
+  %tmp2 = add nuw nsw i32 %tmp1, %tmp1
+  %tmp3 = mul nsw i32 %tmp2, 1408
+  %tmp4 = add nsw i32 %tmp3, 1408
+  %tmp5 = getelementptr inbounds i8, ptr inttoptr (i32 1024 to ptr), i32 %tmp1
+  %tmp6 = load <16 x i8>, ptr %tmp5, align 1
+  %tmp7 = getelementptr inbounds i8, ptr inttoptr (i32 4096 to ptr), i32 %tmp4
+  store <16 x i8> %tmp6, ptr %tmp7, align 1
+
+  %next = add i32 %tmp, 2
+  %exitcond = icmp slt i32 %next, 10
+  br i1 %exitcond, label %for.body, label %for.end
+
+for.end:
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/91453


More information about the llvm-commits mailing list