[llvm] [InferAS] Infer the address space of inttoptr (PR #173244)

Luo Yuanke via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 10 20:53:14 PST 2026


https://github.com/LuoYuanke updated https://github.com/llvm/llvm-project/pull/173244

>From 890dfe54e31ac2613a755af2e075a3470f968661 Mon Sep 17 00:00:00 2001
From: Yuanke Luo <ykluo at birentech.com>
Date: Sun, 21 Dec 2025 17:28:34 +0800
Subject: [PATCH 01/12] [InferAS] Infer the address space of inttoptr

Currently InferAddressSpaces checks that the bit value does not change
across a <ptrtoint, inttoptr> pair before inferring the address space.
However, this condition is too strict. Sometimes only the low bits of the
address change (e.g. for swizzling), and the address space is not changed.
Taking the code below as an example, we can transform
`%gp2 = inttoptr i64 %b to ptr` into
`%gp2 = inttoptr i64 %b to ptr addrspace(2)` and use addrspace(2)
for the following store instruction.
```
   %gp = addrspacecast ptr addrspace(2) %sp to ptr
   %a = ptrtoint ptr %gp to i64
   %b = xor i64 7, %a
   %gp2 = inttoptr i64 %b to ptr
   store i16 0, ptr %gp2, align 2
```
This patch tries to infer the unchanged leading bits of the address
and lets the target determine whether it is safe to perform the address
space cast for the inttoptr instruction.
---
 .../llvm/Analysis/TargetTransformInfo.h       |  14 ++
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   5 +
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   6 +
 .../Target/NVPTX/NVPTXTargetTransformInfo.h   |  16 +++
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 117 ++++++++++++----
 .../InferAddressSpaces/NVPTX/int2ptr.ll       | 125 ++++++++++++++++++
 6 files changed, 259 insertions(+), 24 deletions(-)
 create mode 100644 llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index ff91b24ff17e5..897972a0cfd5f 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1940,6 +1940,20 @@ class TargetTransformInfo {
   /// This should also apply to lowering for vector funnel shifts (rotates).
   LLVM_ABI bool isVectorShiftByScalarCheap(Type *Ty) const;
 
+  /// Return true if it is safe to cast integer to pointer with new address
+  /// space. The address of integer form may only change in the least
+  /// significant bit (e.g. within a page). In that case target can determine
+  /// if it is safe to cast the generic address space to the original address
+  /// space. For below example, we can replace `%gp2 = inttoptr i64 %b to ptr`
+  /// with `%gp2 = inttoptr i64 %b to ptr addrspace(2)`
+  ///   %gp = addrspacecast ptr addrspace(2) %sp to ptr
+  ///   %a = ptrtoint ptr %gp to i64
+  ///   %b = xor i64 7, %a
+  ///   %gp2 = inttoptr i64 %b to ptr
+  ///   store i16 0, ptr %gp2, align 2
+  LLVM_ABI bool isSafeToCastIntPtrWithAS(unsigned AddrUnchangedLeadingBit,
+                                         unsigned SrcAS, unsigned DstAS) const;
+
   struct VPLegalization {
     enum VPTransform {
       // keep the predicating parameter
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 07b3755924fd1..3599cdbeff0d3 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -1159,6 +1159,11 @@ class TargetTransformInfoImplBase {
 
   virtual bool isVectorShiftByScalarCheap(Type *Ty) const { return false; }
 
+  virtual bool isSafeToCastIntPtrWithAS(unsigned AddrUnchangedLeadingBit,
+                                        unsigned SrcAS, unsigned DstAS) const {
+    return false;
+  }
+
   virtual TargetTransformInfo::VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const {
     return TargetTransformInfo::VPLegalization(
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index b2b77da4914d6..a4e266f3c7b76 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1494,6 +1494,12 @@ bool TargetTransformInfo::isVectorShiftByScalarCheap(Type *Ty) const {
   return TTIImpl->isVectorShiftByScalarCheap(Ty);
 }
 
+bool TargetTransformInfo::isSafeToCastIntPtrWithAS(
+    unsigned AddrUnchangedLeadingBit, unsigned SrcAS, unsigned DstAS) const {
+  return TTIImpl->isSafeToCastIntPtrWithAS(AddrUnchangedLeadingBit, SrcAS,
+                                           DstAS);
+}
+
 unsigned
 TargetTransformInfo::getNumBytesToPadGlobalArray(unsigned Size,
                                                  Type *ArrayType) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index ae12a6ea3baa3..c5847ebe3fc0e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -180,6 +180,22 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
     }
   }
 
+  bool isSafeToCastIntPtrWithAS(unsigned AddrUnchangedLeadingBit,
+                                unsigned SrcAS, unsigned DstAS) const override {
+    if (SrcAS != llvm::ADDRESS_SPACE_GENERIC)
+      return false;
+    if (DstAS != llvm::ADDRESS_SPACE_GLOBAL &&
+        DstAS != llvm::ADDRESS_SPACE_SHARED)
+      return false;
+
+    // Address change within 4K size does not change the original address space
+    // and is safe to perform address cast form SrcAS to DstAS.
+    if (AddrUnchangedLeadingBit >= 52)
+      return true;
+
+    return false;
+  }
+
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                   Intrinsic::ID IID) const override;
 
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 03efc156df1e8..716391eeb1442 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -122,6 +122,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/KnownBits.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/Local.h"
@@ -191,6 +192,19 @@ class InferAddressSpacesImpl {
   /// Target specific address space which uses of should be replaced if
   /// possible.
   unsigned FlatAddrSpace = 0;
+  mutable DenseMap<const Value *, Value *> PtrIntCastPairs;
+
+  bool isSafeToCastPtrIntPair(const Operator *I2P, const DataLayout &DL) const;
+  bool isAddressExpression(const Value &V, const DataLayout &DL,
+                           const TargetTransformInfo *TTI) const;
+  Value *cloneConstantExprWithNewAddressSpace(
+      ConstantExpr *CE, unsigned NewAddrSpace,
+      const ValueToValueMapTy &ValueWithNewAddrSpace, const DataLayout *DL,
+      const TargetTransformInfo *TTI) const;
+
+  SmallVector<Value *, 2>
+  getPointerOperands(const Value &V, const DataLayout &DL,
+                     const TargetTransformInfo *TTI) const;
 
   // Try to update the address space of V. If V is updated, returns true and
   // false otherwise.
@@ -314,8 +328,9 @@ static bool isNoopPtrIntCastPair(const Operator *I2P, const DataLayout &DL,
 // TODO: Currently, we only consider:
 //   - arguments
 //   - phi, bitcast, addrspacecast, and getelementptr operators
-static bool isAddressExpression(const Value &V, const DataLayout &DL,
-                                const TargetTransformInfo *TTI) {
+bool InferAddressSpacesImpl::isAddressExpression(
+    const Value &V, const DataLayout &DL,
+    const TargetTransformInfo *TTI) const {
 
   if (const Argument *Arg = dyn_cast<Argument>(&V))
     return Arg->getType()->isPointerTy() &&
@@ -340,7 +355,7 @@ static bool isAddressExpression(const Value &V, const DataLayout &DL,
     return II && II->getIntrinsicID() == Intrinsic::ptrmask;
   }
   case Instruction::IntToPtr:
-    return isNoopPtrIntCastPair(Op, DL, TTI);
+    return isNoopPtrIntCastPair(Op, DL, TTI) || isSafeToCastPtrIntPair(Op, DL);
   default:
     // That value is an address expression if it has an assumed address space.
     return TTI->getAssumedAddrSpace(&V) != UninitializedAddressSpace;
@@ -350,9 +365,9 @@ static bool isAddressExpression(const Value &V, const DataLayout &DL,
 // Returns the pointer operands of V.
 //
 // Precondition: V is an address expression.
-static SmallVector<Value *, 2>
-getPointerOperands(const Value &V, const DataLayout &DL,
-                   const TargetTransformInfo *TTI) {
+SmallVector<Value *, 2> InferAddressSpacesImpl::getPointerOperands(
+    const Value &V, const DataLayout &DL,
+    const TargetTransformInfo *TTI) const {
   if (isa<Argument>(&V))
     return {};
 
@@ -375,15 +390,56 @@ getPointerOperands(const Value &V, const DataLayout &DL,
     return {II.getArgOperand(0)};
   }
   case Instruction::IntToPtr: {
-    assert(isNoopPtrIntCastPair(&Op, DL, TTI));
-    auto *P2I = cast<Operator>(Op.getOperand(0));
-    return {P2I->getOperand(0)};
+    if (isNoopPtrIntCastPair(&Op, DL, TTI)) {
+      auto *P2I = cast<Operator>(Op.getOperand(0));
+      return {P2I->getOperand(0)};
+    }
+    assert(isSafeToCastPtrIntPair(&Op, DL));
+    return {PtrIntCastPairs[&Op]};
   }
   default:
     llvm_unreachable("Unexpected instruction type.");
   }
 }
 
+bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
+    const Operator *I2P, const DataLayout &DL) const {
+  assert(I2P->getOpcode() == Instruction::IntToPtr);
+  if (PtrIntCastPairs.count(I2P))
+    return true;
+
+  if (I2P->getType()->isVectorTy())
+    return false;
+
+  auto *Xor = dyn_cast<Operator>(I2P->getOperand(0));
+  if (!Xor || Xor->getOpcode() != Instruction::Xor)
+    return false;
+
+  auto *LHS = Xor->getOperand(0);
+  auto *Mask = Xor->getOperand(1);
+  auto *P2I = dyn_cast<Operator>(LHS);
+  if (!P2I || P2I->getOpcode() != Instruction::PtrToInt)
+    std::swap(LHS, Mask);
+  P2I = dyn_cast<Operator>(LHS);
+  if (!P2I || P2I->getOpcode() != Instruction::PtrToInt)
+    return false;
+
+  auto *ASCast = dyn_cast<Operator>(P2I->getOperand(0));
+  if (!ASCast || ASCast->getOpcode() != Instruction::AddrSpaceCast)
+    return false;
+
+  KnownBits Known = computeKnownBits(Mask, DL, &AC, nullptr, DT);
+  unsigned SrcAS = I2P->getType()->getPointerAddressSpace();
+  unsigned DstAS = ASCast->getOperand(0)->getType()->getPointerAddressSpace();
+  unsigned AddrUnchangedLeadingBit = Known.Zero.countLeadingOnes();
+  if (TTI->isSafeToCastIntPtrWithAS(AddrUnchangedLeadingBit, SrcAS, DstAS)) {
+    PtrIntCastPairs[I2P] = P2I->getOperand(0);
+    return true;
+  }
+
+  return false;
+}
+
 bool InferAddressSpacesImpl::rewriteIntrinsicOperands(IntrinsicInst *II,
                                                       Value *OldV,
                                                       Value *NewV) const {
@@ -592,6 +648,8 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const {
     } else if (auto *I2P = dyn_cast<IntToPtrInst>(&I)) {
       if (isNoopPtrIntCastPair(cast<Operator>(I2P), *DL, TTI))
         PushPtrOperand(cast<Operator>(I2P->getOperand(0))->getOperand(0));
+      else if (isSafeToCastPtrIntPair(cast<Operator>(I2P), *DL))
+        PushPtrOperand(PtrIntCastPairs[I2P]);
     } else if (auto *RI = dyn_cast<ReturnInst>(&I)) {
       if (auto *RV = RI->getReturnValue();
           RV && RV->getType()->isPtrOrPtrVectorTy())
@@ -838,15 +896,20 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
     return SelectInst::Create(I->getOperand(0), NewPointerOperands[1],
                               NewPointerOperands[2], "", nullptr, I);
   case Instruction::IntToPtr: {
-    assert(isNoopPtrIntCastPair(cast<Operator>(I), *DL, TTI));
-    Value *Src = cast<Operator>(I->getOperand(0))->getOperand(0);
-    if (Src->getType() == NewPtrType)
-      return Src;
-
-    // If we had a no-op inttoptr/ptrtoint pair, we may still have inferred a
-    // source address space from a generic pointer source need to insert a cast
-    // back.
-    return new AddrSpaceCastInst(Src, NewPtrType);
+    if (isNoopPtrIntCastPair(cast<Operator>(I), *DL, TTI)) {
+      Value *Src = cast<Operator>(I->getOperand(0))->getOperand(0);
+      if (Src->getType() == NewPtrType)
+        return Src;
+
+      // If we had a no-op inttoptr/ptrtoint pair, we may still have inferred a
+      // source address space from a generic pointer source need to insert a
+      // cast back.
+      return new AddrSpaceCastInst(Src, NewPtrType);
+    }
+    assert(isSafeToCastPtrIntPair(cast<Operator>(I), *DL));
+    auto *Src = I->getOperand(0);
+    IntToPtrInst *NewI2P = new IntToPtrInst(Src, NewPtrType);
+    return NewI2P;
   }
   default:
     llvm_unreachable("Unexpected opcode");
@@ -856,10 +919,10 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
 // Similar to cloneInstructionWithNewAddressSpace, returns a clone of the
 // constant expression `CE` with its operands replaced as specified in
 // ValueWithNewAddrSpace.
-static Value *cloneConstantExprWithNewAddressSpace(
+Value *InferAddressSpacesImpl::cloneConstantExprWithNewAddressSpace(
     ConstantExpr *CE, unsigned NewAddrSpace,
     const ValueToValueMapTy &ValueWithNewAddrSpace, const DataLayout *DL,
-    const TargetTransformInfo *TTI) {
+    const TargetTransformInfo *TTI) const {
   Type *TargetType =
       CE->getType()->isPtrOrPtrVectorTy()
           ? getPtrOrVecOfPtrsWithNewAS(CE->getType(), NewAddrSpace)
@@ -881,10 +944,15 @@ static Value *cloneConstantExprWithNewAddressSpace(
   }
 
   if (CE->getOpcode() == Instruction::IntToPtr) {
-    assert(isNoopPtrIntCastPair(cast<Operator>(CE), *DL, TTI));
-    Constant *Src = cast<ConstantExpr>(CE->getOperand(0))->getOperand(0);
-    assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace);
-    return Src;
+    if (isNoopPtrIntCastPair(cast<Operator>(CE), *DL, TTI)) {
+      Constant *Src = cast<ConstantExpr>(CE->getOperand(0))->getOperand(0);
+      assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace);
+      return Src;
+    }
+    assert(isSafeToCastPtrIntPair(cast<Operator>(CE), *DL));
+    auto *Src = CE->getOperand(0);
+    return ConstantExpr::getIntToPtr(Src, TargetType);
+    // return ConstantExpr::getAddrSpaceCast(CE, TargetType);
   }
 
   // Computes the operands of the new constant expression.
@@ -990,6 +1058,7 @@ unsigned InferAddressSpacesImpl::joinAddressSpaces(unsigned AS1,
 bool InferAddressSpacesImpl::run(Function &CurFn) {
   F = &CurFn;
   DL = &F->getDataLayout();
+  PtrIntCastPairs.clear();
 
   if (AssumeDefaultIsFlatAddressSpace)
     FlatAddrSpace = 0;
diff --git a/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll b/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll
new file mode 100644
index 0000000000000..1b4a8114a0ca8
--- /dev/null
+++ b/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll
@@ -0,0 +1,125 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt -S -mtriple=nvptx64-nvidia-cuda -passes=infer-address-spaces %s | FileCheck %s
+
+define void @test_smem_fail(ptr addrspace(3) %sp) {
+; CHECK-LABEL: define void @test_smem_fail(
+; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {
+; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
+; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
+; CHECK-NEXT:    [[B:%.*]] = xor i64 8191, [[A]]
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr
+; CHECK-NEXT:    store i16 0, ptr [[GP2]], align 2
+; CHECK-NEXT:    ret void
+;
+  %gp = addrspacecast ptr addrspace(3) %sp to ptr
+  %a = ptrtoint ptr %gp to i64
+  %b = xor i64 8191, %a
+  %gp2 = inttoptr i64 %b to ptr
+  store i16 0, ptr %gp2, align 2
+  ret void
+}
+
+define void @test_smem(ptr addrspace(3) %sp) {
+; CHECK-LABEL: define void @test_smem(
+; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {
+; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
+; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
+; CHECK-NEXT:    [[B:%.*]] = xor i64 4095, [[A]]
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr addrspace(3)
+; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[GP2]], align 2
+; CHECK-NEXT:    ret void
+;
+  %gp = addrspacecast ptr addrspace(3) %sp to ptr
+  %a = ptrtoint ptr %gp to i64
+  %b = xor i64 4095, %a
+  %gp2 = inttoptr i64 %b to ptr
+  store i16 0, ptr %gp2, align 2
+  ret void
+}
+
+define void @test_gmem(ptr addrspace(1) %sp) {
+; CHECK-LABEL: define void @test_gmem(
+; CHECK-SAME: ptr addrspace(1) [[SP:%.*]]) {
+; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(1) [[SP]] to ptr
+; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
+; CHECK-NEXT:    [[B:%.*]] = xor i64 7, [[A]]
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr addrspace(1)
+; CHECK-NEXT:    store i16 0, ptr addrspace(1) [[GP2]], align 2
+; CHECK-NEXT:    ret void
+;
+  %gp = addrspacecast ptr addrspace(1) %sp to ptr
+  %a = ptrtoint ptr %gp to i64
+  %b = xor i64 7, %a
+  %gp2 = inttoptr i64 %b to ptr
+  store i16 0, ptr %gp2, align 2
+  ret void
+}
+
+define void @test_lmem(ptr addrspace(5) %sp) {
+; CHECK-LABEL: define void @test_lmem(
+; CHECK-SAME: ptr addrspace(5) [[SP:%.*]]) {
+; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(5) [[SP]] to ptr
+; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
+; CHECK-NEXT:    [[B:%.*]] = xor i64 7, [[A]]
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr
+; CHECK-NEXT:    store i16 0, ptr [[GP2]], align 2
+; CHECK-NEXT:    ret void
+;
+  %gp = addrspacecast ptr addrspace(5) %sp to ptr
+  %a = ptrtoint ptr %gp to i64
+  %b = xor i64 7, %a
+  %gp2 = inttoptr i64 %b to ptr
+  store i16 0, ptr %gp2, align 2
+  ret void
+}
+
+define void @test3(ptr addrspace(3) %sp) {
+; CHECK-LABEL: define void @test3(
+; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {
+; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
+; CHECK-NEXT:    [[T1:%.*]] = ptrtoint ptr [[GP]] to i64
+; CHECK-NEXT:    [[AND:%.*]] = lshr i64 [[T1]], 8
+; CHECK-NEXT:    [[SHR:%.*]] = and i64 [[AND]], 8
+; CHECK-NEXT:    [[AND1:%.*]] = lshr i64 [[T1]], 10
+; CHECK-NEXT:    [[SHR2:%.*]] = and i64 [[AND1]], 4
+; CHECK-NEXT:    [[OR:%.*]] = or i64 [[SHR]], [[SHR2]]
+; CHECK-NEXT:    [[AND3:%.*]] = lshr i64 [[T1]], 4
+; CHECK-NEXT:    [[SHR4:%.*]] = and i64 [[AND3]], 112
+; CHECK-NEXT:    [[OR5:%.*]] = or i64 [[OR]], [[SHR4]]
+; CHECK-NEXT:    [[XOR:%.*]] = xor i64 [[OR5]], [[T1]]
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[XOR]] to ptr addrspace(3)
+; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[GP2]], align 2
+; CHECK-NEXT:    ret void
+;
+  %gp = addrspacecast ptr addrspace(3) %sp to ptr
+  %t1 = ptrtoint ptr %gp to i64
+  %and = lshr i64 %t1, 8
+  %shr = and i64 %and, 8
+  %and1 = lshr i64 %t1, 10
+  %shr2 = and i64 %and1, 4
+  %or = or i64 %shr, %shr2
+  %and3 = lshr i64 %t1, 4
+  %shr4 = and i64 %and3, 112
+  %or5 = or i64 %or, %shr4
+  %xor = xor i64 %or5, %t1
+  %gp2 = inttoptr i64 %xor to ptr
+  store i16 0, ptr %gp2, align 2
+  ret void
+}
+
+ at g = addrspace(1) global i32 0, align 4
+
+define void @test_ce() {
+; CHECK-LABEL: define void @test_ce() {
+; CHECK-NEXT:    store i32 0, ptr inttoptr (i64 xor (i64 ptrtoint (ptr addrspacecast (ptr addrspace(1) @g to ptr) to i64), i64 7) to ptr), align 4
+; CHECK-NEXT:    ret void
+;
+  store i32 0, ptr inttoptr (i64
+  xor (i64
+  ptrtoint (ptr
+  addrspacecast (ptr addrspace(1) @g to ptr)
+  to i64),
+  i64 7)
+  to ptr)
+  ret void
+}

>From 142abfc58b5c560bfbb0fcb375de0fcc28997b5a Mon Sep 17 00:00:00 2001
From: Yuanke Luo <ykluo at birentech.com>
Date: Fri, 9 Jan 2026 16:54:34 +0800
Subject: [PATCH 02/12] Support more patterns of logic operations for the
 address computation.

---
 .../llvm/Analysis/TargetTransformInfo.h       |  2 +-
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  4 +-
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  8 +-
 .../Target/NVPTX/NVPTXTargetTransformInfo.h   |  6 +-
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 34 +++++--
 llvm/test/CodeGen/NVPTX/disjoint-or-addr.ll   |  2 +-
 .../InferAddressSpaces/NVPTX/int2ptr.ll       | 98 ++++++++++++++++++-
 7 files changed, 132 insertions(+), 22 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 897972a0cfd5f..d1d3cabfb89fa 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1951,7 +1951,7 @@ class TargetTransformInfo {
   ///   %b = xor i64 7, %a
   ///   %gp2 = inttoptr i64 %b to ptr
   ///   store i16 0, ptr %gp2, align 2
-  LLVM_ABI bool isSafeToCastIntPtrWithAS(unsigned AddrUnchangedLeadingBit,
+  LLVM_ABI bool isSafeToCastIntPtrWithAS(unsigned AddrChangedLSB,
                                          unsigned SrcAS, unsigned DstAS) const;
 
   struct VPLegalization {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 3599cdbeff0d3..0445a7f26f929 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -1159,8 +1159,8 @@ class TargetTransformInfoImplBase {
 
   virtual bool isVectorShiftByScalarCheap(Type *Ty) const { return false; }
 
-  virtual bool isSafeToCastIntPtrWithAS(unsigned AddrUnchangedLeadingBit,
-                                        unsigned SrcAS, unsigned DstAS) const {
+  virtual bool isSafeToCastIntPtrWithAS(unsigned AddrChangedLSB, unsigned SrcAS,
+                                        unsigned DstAS) const {
     return false;
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a4e266f3c7b76..77356a9f105d7 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1494,10 +1494,10 @@ bool TargetTransformInfo::isVectorShiftByScalarCheap(Type *Ty) const {
   return TTIImpl->isVectorShiftByScalarCheap(Ty);
 }
 
-bool TargetTransformInfo::isSafeToCastIntPtrWithAS(
-    unsigned AddrUnchangedLeadingBit, unsigned SrcAS, unsigned DstAS) const {
-  return TTIImpl->isSafeToCastIntPtrWithAS(AddrUnchangedLeadingBit, SrcAS,
-                                           DstAS);
+bool TargetTransformInfo::isSafeToCastIntPtrWithAS(unsigned AddrChangedLSB,
+                                                   unsigned SrcAS,
+                                                   unsigned DstAS) const {
+  return TTIImpl->isSafeToCastIntPtrWithAS(AddrChangedLSB, SrcAS, DstAS);
 }
 
 unsigned
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index c5847ebe3fc0e..8ab4a7ffa1b5b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -180,8 +180,8 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
     }
   }
 
-  bool isSafeToCastIntPtrWithAS(unsigned AddrUnchangedLeadingBit,
-                                unsigned SrcAS, unsigned DstAS) const override {
+  bool isSafeToCastIntPtrWithAS(unsigned AddrChangedLSB, unsigned SrcAS,
+                                unsigned DstAS) const override {
     if (SrcAS != llvm::ADDRESS_SPACE_GENERIC)
       return false;
     if (DstAS != llvm::ADDRESS_SPACE_GLOBAL &&
@@ -190,7 +190,7 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
 
     // Address change within 4K size does not change the original address space
     // and is safe to perform address cast form SrcAS to DstAS.
-    if (AddrUnchangedLeadingBit >= 52)
+    if (AddrChangedLSB <= 12)
       return true;
 
     return false;
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 716391eeb1442..5a8ba45b86069 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -402,6 +402,22 @@ SmallVector<Value *, 2> InferAddressSpacesImpl::getPointerOperands(
   }
 }
 
+static unsigned computeKnownChangedLSB(const Operator *LogicOp,
+                                       const Value *Mask, const DataLayout &DL,
+                                       AssumptionCache *AC,
+                                       const DominatorTree *DT) {
+  KnownBits Known = computeKnownBits(Mask, DL, AC, nullptr, DT);
+  switch (LogicOp->getOpcode()) {
+  case Instruction::Xor:
+  case Instruction::Or:
+    return Known.getBitWidth() - Known.countMinLeadingZeros();
+  case Instruction::And:
+    return Known.getBitWidth() - Known.countMinLeadingOnes();
+  default:
+    return -1;
+  }
+}
+
 bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
     const Operator *I2P, const DataLayout &DL) const {
   assert(I2P->getOpcode() == Instruction::IntToPtr);
@@ -411,12 +427,17 @@ bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
   if (I2P->getType()->isVectorTy())
     return false;
 
-  auto *Xor = dyn_cast<Operator>(I2P->getOperand(0));
-  if (!Xor || Xor->getOpcode() != Instruction::Xor)
+  auto *LogicOP = dyn_cast<Operator>(I2P->getOperand(0));
+  if (!LogicOP)
+    return false;
+
+  if (LogicOP->getOpcode() != Instruction::Xor &&
+      LogicOP->getOpcode() != Instruction::Or &&
+      LogicOP->getOpcode() != Instruction::And)
     return false;
 
-  auto *LHS = Xor->getOperand(0);
-  auto *Mask = Xor->getOperand(1);
+  auto *LHS = LogicOP->getOperand(0);
+  auto *Mask = LogicOP->getOperand(1);
   auto *P2I = dyn_cast<Operator>(LHS);
   if (!P2I || P2I->getOpcode() != Instruction::PtrToInt)
     std::swap(LHS, Mask);
@@ -428,11 +449,10 @@ bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
   if (!ASCast || ASCast->getOpcode() != Instruction::AddrSpaceCast)
     return false;
 
-  KnownBits Known = computeKnownBits(Mask, DL, &AC, nullptr, DT);
   unsigned SrcAS = I2P->getType()->getPointerAddressSpace();
   unsigned DstAS = ASCast->getOperand(0)->getType()->getPointerAddressSpace();
-  unsigned AddrUnchangedLeadingBit = Known.Zero.countLeadingOnes();
-  if (TTI->isSafeToCastIntPtrWithAS(AddrUnchangedLeadingBit, SrcAS, DstAS)) {
+  unsigned AddrChangedLSB = computeKnownChangedLSB(LogicOP, Mask, DL, &AC, DT);
+  if (TTI->isSafeToCastIntPtrWithAS(AddrChangedLSB, SrcAS, DstAS)) {
     PtrIntCastPairs[I2P] = P2I->getOperand(0);
     return true;
   }
diff --git a/llvm/test/CodeGen/NVPTX/disjoint-or-addr.ll b/llvm/test/CodeGen/NVPTX/disjoint-or-addr.ll
index a21261c768862..b9a24c722b5cc 100644
--- a/llvm/test/CodeGen/NVPTX/disjoint-or-addr.ll
+++ b/llvm/test/CodeGen/NVPTX/disjoint-or-addr.ll
@@ -14,7 +14,7 @@ define i32  @test_disjoint_or_addr(i16 %a) {
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    mov.b64 %rd1, a;
 ; CHECK-NEXT:    cvta.global.u64 %rd2, %rd1;
-; CHECK-NEXT:    ld.b32 %r1, [%rd2+8];
+; CHECK-NEXT:    ld.global.b32 %r1, [%rd2+8];
 ; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
 ; CHECK-NEXT:    ret;
   %a1 = ptrtoint ptr @a to i64
diff --git a/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll b/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll
index 1b4a8114a0ca8..10aa8188c9691 100644
--- a/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll
+++ b/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll
@@ -6,21 +6,21 @@ define void @test_smem_fail(ptr addrspace(3) %sp) {
 ; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {
 ; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
 ; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
-; CHECK-NEXT:    [[B:%.*]] = xor i64 8191, [[A]]
+; CHECK-NEXT:    [[B:%.*]] = xor i64 4096, [[A]]
 ; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr
 ; CHECK-NEXT:    store i16 0, ptr [[GP2]], align 2
 ; CHECK-NEXT:    ret void
 ;
   %gp = addrspacecast ptr addrspace(3) %sp to ptr
   %a = ptrtoint ptr %gp to i64
-  %b = xor i64 8191, %a
+  %b = xor i64 4096, %a
   %gp2 = inttoptr i64 %b to ptr
   store i16 0, ptr %gp2, align 2
   ret void
 }
 
-define void @test_smem(ptr addrspace(3) %sp) {
-; CHECK-LABEL: define void @test_smem(
+define void @test_xor_smem(ptr addrspace(3) %sp) {
+; CHECK-LABEL: define void @test_xor_smem(
 ; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {
 ; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
 ; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
@@ -37,6 +37,96 @@ define void @test_smem(ptr addrspace(3) %sp) {
   ret void
 }
 
+define void @test_xor_smem2(ptr addrspace(3) %sp) {
+; CHECK-LABEL: define void @test_xor_smem2(
+; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {
+; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
+; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
+; CHECK-NEXT:    [[B:%.*]] = xor i64 [[A]], 4095
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr addrspace(3)
+; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[GP2]], align 2
+; CHECK-NEXT:    ret void
+;
+  %gp = addrspacecast ptr addrspace(3) %sp to ptr
+  %a = ptrtoint ptr %gp to i64
+  %b = xor i64 %a, 4095
+  %gp2 = inttoptr i64 %b to ptr
+  store i16 0, ptr %gp2, align 2
+  ret void
+}
+
+define void @test_or_smem(ptr addrspace(3) %sp) {
+; CHECK-LABEL: define void @test_or_smem(
+; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {
+; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
+; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
+; CHECK-NEXT:    [[B:%.*]] = or i64 4095, [[A]]
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr addrspace(3)
+; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[GP2]], align 2
+; CHECK-NEXT:    ret void
+;
+  %gp = addrspacecast ptr addrspace(3) %sp to ptr
+  %a = ptrtoint ptr %gp to i64
+  %b = or i64 4095, %a
+  %gp2 = inttoptr i64 %b to ptr
+  store i16 0, ptr %gp2, align 2
+  ret void
+}
+
+define void @test_or_smem2(ptr addrspace(3) %sp) {
+; CHECK-LABEL: define void @test_or_smem2(
+; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {
+; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
+; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
+; CHECK-NEXT:    [[B:%.*]] = or i64 4096, [[A]]
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr
+; CHECK-NEXT:    store i16 0, ptr [[GP2]], align 2
+; CHECK-NEXT:    ret void
+;
+  %gp = addrspacecast ptr addrspace(3) %sp to ptr
+  %a = ptrtoint ptr %gp to i64
+  %b = or i64 4096, %a
+  %gp2 = inttoptr i64 %b to ptr
+  store i16 0, ptr %gp2, align 2
+  ret void
+}
+
+define void @test_and_smem(ptr addrspace(3) %sp) {
+; CHECK-LABEL: define void @test_and_smem(
+; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {
+; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
+; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
+; CHECK-NEXT:    [[B:%.*]] = and i64 [[A]], -4096
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr addrspace(3)
+; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[GP2]], align 2
+; CHECK-NEXT:    ret void
+;
+  %gp = addrspacecast ptr addrspace(3) %sp to ptr
+  %a = ptrtoint ptr %gp to i64
+  %b = and i64 %a, -4096
+  %gp2 = inttoptr i64 %b to ptr
+  store i16 0, ptr %gp2, align 2
+  ret void
+}
+
+define void @test_and_smem_fail(ptr addrspace(3) %sp) {
+; CHECK-LABEL: define void @test_and_smem_fail(
+; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {
+; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
+; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
+; CHECK-NEXT:    [[B:%.*]] = and i64 [[A]], -4097
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr
+; CHECK-NEXT:    store i16 0, ptr [[GP2]], align 2
+; CHECK-NEXT:    ret void
+;
+  %gp = addrspacecast ptr addrspace(3) %sp to ptr
+  %a = ptrtoint ptr %gp to i64
+  %b = and i64 %a, -4097
+  %gp2 = inttoptr i64 %b to ptr
+  store i16 0, ptr %gp2, align 2
+  ret void
+}
+
 define void @test_gmem(ptr addrspace(1) %sp) {
 ; CHECK-LABEL: define void @test_gmem(
 ; CHECK-SAME: ptr addrspace(1) [[SP:%.*]]) {

>From 21e3caff94dad1be8a416617c5e2b2016c66aaf4 Mon Sep 17 00:00:00 2001
From: Yuanke Luo <ykluo at birentech.com>
Date: Thu, 22 Jan 2026 11:30:44 +0800
Subject: [PATCH 03/12] [InferAS] revise the TTI interface

---
 .../llvm/Analysis/TargetTransformInfo.h       | 28 +++++++++----------
 .../llvm/Analysis/TargetTransformInfoImpl.h   | 10 +++----
 llvm/lib/Analysis/TargetTransformInfo.cpp     | 11 ++++----
 .../Target/NVPTX/NVPTXTargetTransformInfo.h   | 13 ++++-----
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 25 ++++++++++-------
 5 files changed, 44 insertions(+), 43 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index d1d3cabfb89fa..891515f734233 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -567,6 +567,20 @@ class TargetTransformInfo {
   LLVM_ABI KnownBits computeKnownBitsAddrSpaceCast(
       unsigned FromAS, unsigned ToAS, const KnownBits &FromPtrBits) const;
 
+  /// Return the preserved ptr bit mask that is safe to cast integer to pointer
+  /// with new address space. The address of integer form may only change in the
+  /// least significant bit (e.g. within a page). In that case target can
+  /// determine if it is safe to cast the generic address space to the original
+  /// address space. For below example, we can replace `%gp2 = inttoptr i64 %b
+  /// to ptr` with `%gp2 = inttoptr i64 %b to ptr addrspace(2)`
+  ///   %gp = addrspacecast ptr addrspace(2) %sp to ptr
+  ///   %a = ptrtoint ptr %gp to i64
+  ///   %b = xor i64 7, %a
+  ///   %gp2 = inttoptr i64 %b to ptr
+  ///   store i16 0, ptr %gp2, align 2
+  LLVM_ABI APInt getAddrSpaceCastPreservedPtrMask(unsigned SrcAS,
+                                                  unsigned DstAS) const;
+
   /// Return true if globals in this address space can have initializers other
   /// than `undef`.
   LLVM_ABI bool
@@ -1940,20 +1954,6 @@ class TargetTransformInfo {
   /// This should also apply to lowering for vector funnel shifts (rotates).
   LLVM_ABI bool isVectorShiftByScalarCheap(Type *Ty) const;
 
-  /// Return true if it is safe to cast integer to pointer with new address
-  /// space. The address of integer form may only change in the least
-  /// significant bit (e.g. within a page). In that case target can determine
-  /// if it is safe to cast the generic address space to the original address
-  /// space. For below example, we can replace `%gp2 = inttoptr i64 %b to ptr`
-  /// with `%gp2 = inttoptr i64 %b to ptr addrspace(2)`
-  ///   %gp = addrspacecast ptr addrspace(2) %sp to ptr
-  ///   %a = ptrtoint ptr %gp to i64
-  ///   %b = xor i64 7, %a
-  ///   %gp2 = inttoptr i64 %b to ptr
-  ///   store i16 0, ptr %gp2, align 2
-  LLVM_ABI bool isSafeToCastIntPtrWithAS(unsigned AddrChangedLSB,
-                                         unsigned SrcAS, unsigned DstAS) const;
-
   struct VPLegalization {
     enum VPTransform {
       // keep the predicating parameter
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 0445a7f26f929..ec2c850d4afbc 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -193,6 +193,11 @@ class TargetTransformInfoImplBase {
     return FromPtrBits.anyextOrTrunc(ToASBitSize);
   }
 
+  virtual APInt getAddrSpaceCastPreservedPtrMask(unsigned SrcAS,
+                                                 unsigned DstAS) const {
+    return APInt::getZero(DL.getPointerSizeInBits());
+  }
+
   virtual bool
   canHaveNonUndefGlobalInitializerInAddressSpace(unsigned AS) const {
     return AS == 0;
@@ -1159,11 +1164,6 @@ class TargetTransformInfoImplBase {
 
   virtual bool isVectorShiftByScalarCheap(Type *Ty) const { return false; }
 
-  virtual bool isSafeToCastIntPtrWithAS(unsigned AddrChangedLSB, unsigned SrcAS,
-                                        unsigned DstAS) const {
-    return false;
-  }
-
   virtual TargetTransformInfo::VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const {
     return TargetTransformInfo::VPLegalization(
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 77356a9f105d7..4075be4fd2b75 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -339,6 +339,11 @@ KnownBits TargetTransformInfo::computeKnownBitsAddrSpaceCast(
   return TTIImpl->computeKnownBitsAddrSpaceCast(FromAS, ToAS, FromPtrBits);
 }
 
+APInt TargetTransformInfo::getAddrSpaceCastPreservedPtrMask(
+    unsigned SrcAS, unsigned DstAS) const {
+  return TTIImpl->getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
+}
+
 bool TargetTransformInfo::canHaveNonUndefGlobalInitializerInAddressSpace(
     unsigned AS) const {
   return TTIImpl->canHaveNonUndefGlobalInitializerInAddressSpace(AS);
@@ -1494,12 +1499,6 @@ bool TargetTransformInfo::isVectorShiftByScalarCheap(Type *Ty) const {
   return TTIImpl->isVectorShiftByScalarCheap(Ty);
 }
 
-bool TargetTransformInfo::isSafeToCastIntPtrWithAS(unsigned AddrChangedLSB,
-                                                   unsigned SrcAS,
-                                                   unsigned DstAS) const {
-  return TTIImpl->isSafeToCastIntPtrWithAS(AddrChangedLSB, SrcAS, DstAS);
-}
-
 unsigned
 TargetTransformInfo::getNumBytesToPadGlobalArray(unsigned Size,
                                                  Type *ArrayType) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 8ab4a7ffa1b5b..5fda44aa0bde9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -180,20 +180,17 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
     }
   }
 
-  bool isSafeToCastIntPtrWithAS(unsigned AddrChangedLSB, unsigned SrcAS,
-                                unsigned DstAS) const override {
+  APInt getAddrSpaceCastPreservedPtrMask(unsigned SrcAS,
+                                         unsigned DstAS) const override {
     if (SrcAS != llvm::ADDRESS_SPACE_GENERIC)
-      return false;
+      return APInt::getZero(DL.getPointerSizeInBits());
     if (DstAS != llvm::ADDRESS_SPACE_GLOBAL &&
         DstAS != llvm::ADDRESS_SPACE_SHARED)
-      return false;
+      return APInt::getZero(DL.getPointerSizeInBits());
 
     // Address change within 4K size does not change the original address space
     // and is safe to perform address cast form SrcAS to DstAS.
-    if (AddrChangedLSB <= 12)
-      return true;
-
-    return false;
+    return {DL.getPointerSizeInBits(), 0xfff};
   }
 
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 5a8ba45b86069..765038d91782b 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -402,19 +402,21 @@ SmallVector<Value *, 2> InferAddressSpacesImpl::getPointerOperands(
   }
 }
 
-static unsigned computeKnownChangedLSB(const Operator *LogicOp,
-                                       const Value *Mask, const DataLayout &DL,
-                                       AssumptionCache *AC,
-                                       const DominatorTree *DT) {
+// Return mask. The 1 in mask indicate the bit is changed.
+// This helper function is to compute the max know changed bits for ptr1 and
+// ptr2 after the operation `ptr2 = ptr1 Op Mask`.
+static APInt computeMaxChangedPtrBits(const Operator *Op, const Value *Mask,
+                                      const DataLayout &DL, AssumptionCache *AC,
+                                      const DominatorTree *DT) {
   KnownBits Known = computeKnownBits(Mask, DL, AC, nullptr, DT);
-  switch (LogicOp->getOpcode()) {
+  switch (Op->getOpcode()) {
   case Instruction::Xor:
   case Instruction::Or:
-    return Known.getBitWidth() - Known.countMinLeadingZeros();
+    return ~Known.Zero;
   case Instruction::And:
-    return Known.getBitWidth() - Known.countMinLeadingOnes();
+    return ~Known.One;
   default:
-    return -1;
+    return APInt::getAllOnes(Known.getBitWidth());
   }
 }
 
@@ -451,8 +453,11 @@ bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
 
   unsigned SrcAS = I2P->getType()->getPointerAddressSpace();
   unsigned DstAS = ASCast->getOperand(0)->getType()->getPointerAddressSpace();
-  unsigned AddrChangedLSB = computeKnownChangedLSB(LogicOP, Mask, DL, &AC, DT);
-  if (TTI->isSafeToCastIntPtrWithAS(AddrChangedLSB, SrcAS, DstAS)) {
+  APInt PreservedPtrMask = TTI->getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
+  APInt ChangedPtrBits = computeMaxChangedPtrBits(LogicOP, Mask, DL, &AC, DT);
+  // Check if the address bits change is within the preserved mask. If the bits
+  // change is not preserved, it is not safe to perform address space cast.
+  if ((ChangedPtrBits & (~PreservedPtrMask)).isZero()) {
     PtrIntCastPairs[I2P] = P2I->getOperand(0);
     return true;
   }

>From 446ac64cf8cb4696c81e0a7703a00e826e306758 Mon Sep 17 00:00:00 2001
From: Yuanke Luo <ykluo at birentech.com>
Date: Thu, 22 Jan 2026 19:55:00 +0800
Subject: [PATCH 04/12] Address Matt's comments

---
 .../llvm/Analysis/TargetTransformInfo.h       |  4 ++--
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  6 ++---
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  5 +++--
 .../Target/NVPTX/NVPTXTargetTransformInfo.h   | 12 +++++-----
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 22 +++++++++++++------
 5 files changed, 30 insertions(+), 19 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 891515f734233..a8c749f789d8b 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -578,8 +578,8 @@ class TargetTransformInfo {
   ///   %b = xor i64 7, %a
   ///   %gp2 = inttoptr i64 %b to ptr
   ///   store i16 0, ptr %gp2, align 2
-  LLVM_ABI APInt getAddrSpaceCastPreservedPtrMask(unsigned SrcAS,
-                                                  unsigned DstAS) const;
+  LLVM_ABI std::optional<APInt>
+  getAddrSpaceCastPreservedPtrMask(unsigned SrcAS, unsigned DstAS) const;
 
   /// Return true if globals in this address space can have initializers other
   /// than `undef`.
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index ec2c850d4afbc..4d384327e8a87 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -193,9 +193,9 @@ class TargetTransformInfoImplBase {
     return FromPtrBits.anyextOrTrunc(ToASBitSize);
   }
 
-  virtual APInt getAddrSpaceCastPreservedPtrMask(unsigned SrcAS,
-                                                 unsigned DstAS) const {
-    return APInt::getZero(DL.getPointerSizeInBits());
+  virtual std::optional<APInt>
+  getAddrSpaceCastPreservedPtrMask(unsigned SrcAS, unsigned DstAS) const {
+    return std::nullopt;
   }
 
   virtual bool
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 4075be4fd2b75..5f65e0b8509f1 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -339,8 +339,9 @@ KnownBits TargetTransformInfo::computeKnownBitsAddrSpaceCast(
   return TTIImpl->computeKnownBitsAddrSpaceCast(FromAS, ToAS, FromPtrBits);
 }
 
-APInt TargetTransformInfo::getAddrSpaceCastPreservedPtrMask(
-    unsigned SrcAS, unsigned DstAS) const {
+std::optional<APInt>
+TargetTransformInfo::getAddrSpaceCastPreservedPtrMask(unsigned SrcAS,
+                                                      unsigned DstAS) const {
   return TTIImpl->getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 5fda44aa0bde9..e1f94fc2dd0ed 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -180,17 +180,19 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
     }
   }
 
-  APInt getAddrSpaceCastPreservedPtrMask(unsigned SrcAS,
-                                         unsigned DstAS) const override {
+  std::optional<APInt>
+  getAddrSpaceCastPreservedPtrMask(unsigned SrcAS,
+                                   unsigned DstAS) const override {
     if (SrcAS != llvm::ADDRESS_SPACE_GENERIC)
-      return APInt::getZero(DL.getPointerSizeInBits());
+      return std::nullopt;
     if (DstAS != llvm::ADDRESS_SPACE_GLOBAL &&
         DstAS != llvm::ADDRESS_SPACE_SHARED)
-      return APInt::getZero(DL.getPointerSizeInBits());
+      return std::nullopt;
 
     // Address change within 4K size does not change the original address space
     // and is safe to perform address cast form SrcAS to DstAS.
-    return {DL.getPointerSizeInBits(), 0xfff};
+    APInt PtrMask(DL.getPointerSizeInBits(llvm::ADDRESS_SPACE_GENERIC), 0xfff);
+    return PtrMask;
   }
 
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 765038d91782b..5d223a5aeb57c 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -130,6 +130,7 @@
 #include <cassert>
 #include <iterator>
 #include <limits>
+#include <optional>
 #include <utility>
 #include <vector>
 
@@ -429,7 +430,7 @@ bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
   if (I2P->getType()->isVectorTy())
     return false;
 
-  auto *LogicOP = dyn_cast<Operator>(I2P->getOperand(0));
+  Operator *LogicOP = dyn_cast<Operator>(I2P->getOperand(0));
   if (!LogicOP)
     return false;
 
@@ -438,26 +439,33 @@ bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
       LogicOP->getOpcode() != Instruction::And)
     return false;
 
-  auto *LHS = LogicOP->getOperand(0);
-  auto *Mask = LogicOP->getOperand(1);
-  auto *P2I = dyn_cast<Operator>(LHS);
+  Value *LHS = LogicOP->getOperand(0);
+  Value *Mask = LogicOP->getOperand(1);
+  Operator *P2I = dyn_cast<Operator>(LHS);
   if (!P2I || P2I->getOpcode() != Instruction::PtrToInt)
     std::swap(LHS, Mask);
   P2I = dyn_cast<Operator>(LHS);
   if (!P2I || P2I->getOpcode() != Instruction::PtrToInt)
     return false;
 
-  auto *ASCast = dyn_cast<Operator>(P2I->getOperand(0));
+  Operator *ASCast = dyn_cast<Operator>(P2I->getOperand(0));
   if (!ASCast || ASCast->getOpcode() != Instruction::AddrSpaceCast)
     return false;
 
   unsigned SrcAS = I2P->getType()->getPointerAddressSpace();
   unsigned DstAS = ASCast->getOperand(0)->getType()->getPointerAddressSpace();
-  APInt PreservedPtrMask = TTI->getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
+  std::optional<APInt> PreservedPtrMask =
+      TTI->getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
+  if (!PreservedPtrMask)
+    return false;
+
   APInt ChangedPtrBits = computeMaxChangedPtrBits(LogicOP, Mask, DL, &AC, DT);
   // Check if the address bits change is within the preserved mask. If the bits
   // change is not preserved, it is not safe to perform address space cast.
-  if ((ChangedPtrBits & (~PreservedPtrMask)).isZero()) {
+  if (ChangedPtrBits.getBitWidth() != PreservedPtrMask->getBitWidth())
+    return false;
+
+  if (ChangedPtrBits.isSubsetOf(*PreservedPtrMask)) {
     PtrIntCastPairs[I2P] = P2I->getOperand(0);
     return true;
   }

>From 26e6e78eedecb10f2191edd58c9033fed1408a15 Mon Sep 17 00:00:00 2001
From: Yuanke Luo <ykluo at birentech.com>
Date: Fri, 23 Jan 2026 10:25:00 +0800
Subject: [PATCH 05/12] Address comments.

---
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 34 +++++++------------
 .../InferAddressSpaces/NVPTX/int2ptr.ll       | 20 +++++++++++
 2 files changed, 33 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 5d223a5aeb57c..46a03c1bb7205 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -97,6 +97,7 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Argument.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
@@ -111,6 +112,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -137,6 +139,7 @@
 #define DEBUG_TYPE "infer-address-spaces"
 
 using namespace llvm;
+using namespace llvm::PatternMatch;
 
 static cl::opt<bool> AssumeDefaultIsFlatAddressSpace(
     "assume-default-is-flat-addrspace", cl::init(false), cl::ReallyHidden,
@@ -430,43 +433,32 @@ bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
   if (I2P->getType()->isVectorTy())
     return false;
 
-  Operator *LogicOP = dyn_cast<Operator>(I2P->getOperand(0));
-  if (!LogicOP)
+  Value *LogicalOp = I2P->getOperand(0);
+  Value *OldPtr, *Mask;
+  if (!match(LogicalOp,
+             m_c_BitwiseLogic(m_PtrToInt(m_Value(OldPtr)), m_Value(Mask))))
     return false;
 
-  if (LogicOP->getOpcode() != Instruction::Xor &&
-      LogicOP->getOpcode() != Instruction::Or &&
-      LogicOP->getOpcode() != Instruction::And)
-    return false;
-
-  Value *LHS = LogicOP->getOperand(0);
-  Value *Mask = LogicOP->getOperand(1);
-  Operator *P2I = dyn_cast<Operator>(LHS);
-  if (!P2I || P2I->getOpcode() != Instruction::PtrToInt)
-    std::swap(LHS, Mask);
-  P2I = dyn_cast<Operator>(LHS);
-  if (!P2I || P2I->getOpcode() != Instruction::PtrToInt)
-    return false;
-
-  Operator *ASCast = dyn_cast<Operator>(P2I->getOperand(0));
-  if (!ASCast || ASCast->getOpcode() != Instruction::AddrSpaceCast)
+  Operator *AsCast = dyn_cast<AddrSpaceCastOperator>(OldPtr);
+  if (!AsCast)
     return false;
 
   unsigned SrcAS = I2P->getType()->getPointerAddressSpace();
-  unsigned DstAS = ASCast->getOperand(0)->getType()->getPointerAddressSpace();
+  unsigned DstAS = AsCast->getOperand(0)->getType()->getPointerAddressSpace();
   std::optional<APInt> PreservedPtrMask =
       TTI->getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
   if (!PreservedPtrMask)
     return false;
 
-  APInt ChangedPtrBits = computeMaxChangedPtrBits(LogicOP, Mask, DL, &AC, DT);
+  APInt ChangedPtrBits =
+      computeMaxChangedPtrBits(cast<Operator>(LogicalOp), Mask, DL, &AC, DT);
   // Check if the address bits change is within the preserved mask. If the bits
   // change is not preserved, it is not safe to perform address space cast.
   if (ChangedPtrBits.getBitWidth() != PreservedPtrMask->getBitWidth())
     return false;
 
   if (ChangedPtrBits.isSubsetOf(*PreservedPtrMask)) {
-    PtrIntCastPairs[I2P] = P2I->getOperand(0);
+    PtrIntCastPairs[I2P] = OldPtr;
     return true;
   }
 
diff --git a/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll b/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll
index 10aa8188c9691..e4d694c7dfb5f 100644
--- a/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll
+++ b/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll
@@ -19,6 +19,26 @@ define void @test_smem_fail(ptr addrspace(3) %sp) {
   ret void
 }
 
+define void @test_smem_fail2(ptr addrspace(3) %sp) {
+; CHECK-LABEL: define void @test_smem_fail2(
+; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {
+; CHECK-NEXT:    [[GP1:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
+; CHECK-NEXT:    [[GP:%.*]] = getelementptr i8, ptr [[GP1]], i32 8
+; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
+; CHECK-NEXT:    [[B:%.*]] = xor i64 4095, [[A]]
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr
+; CHECK-NEXT:    store i16 0, ptr [[GP2]], align 2
+; CHECK-NEXT:    ret void
+;
+  %gp1 = addrspacecast ptr addrspace(3) %sp to ptr
+  %gp = getelementptr i8, ptr %gp1, i32 8
+  %a = ptrtoint ptr %gp to i64
+  %b = xor i64 4095, %a
+  %gp2 = inttoptr i64 %b to ptr
+  store i16 0, ptr %gp2, align 2
+  ret void
+}
+
 define void @test_xor_smem(ptr addrspace(3) %sp) {
 ; CHECK-LABEL: define void @test_xor_smem(
 ; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {

>From d6e34dd5aad5b2f2876458706408b343c4067f18 Mon Sep 17 00:00:00 2001
From: Yuanke Luo <ykluo at birentech.com>
Date: Sun, 25 Jan 2026 09:20:59 +0800
Subject: [PATCH 06/12] Revert "Address comments."

This reverts commit 26e6e78eedecb10f2191edd58c9033fed1408a15.
The logical operator pattern match doesn't work for constant
expressions, so revert the patch.
---
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 34 ++++++++++++-------
 .../InferAddressSpaces/NVPTX/int2ptr.ll       | 20 -----------
 2 files changed, 21 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 46a03c1bb7205..5d223a5aeb57c 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -97,7 +97,6 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
-#include "llvm/IR/Argument.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
@@ -112,7 +111,6 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
-#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -139,7 +137,6 @@
 #define DEBUG_TYPE "infer-address-spaces"
 
 using namespace llvm;
-using namespace llvm::PatternMatch;
 
 static cl::opt<bool> AssumeDefaultIsFlatAddressSpace(
     "assume-default-is-flat-addrspace", cl::init(false), cl::ReallyHidden,
@@ -433,32 +430,43 @@ bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
   if (I2P->getType()->isVectorTy())
     return false;
 
-  Value *LogicalOp = I2P->getOperand(0);
-  Value *OldPtr, *Mask;
-  if (!match(LogicalOp,
-             m_c_BitwiseLogic(m_PtrToInt(m_Value(OldPtr)), m_Value(Mask))))
+  Operator *LogicOP = dyn_cast<Operator>(I2P->getOperand(0));
+  if (!LogicOP)
     return false;
 
-  Operator *AsCast = dyn_cast<AddrSpaceCastOperator>(OldPtr);
-  if (!AsCast)
+  if (LogicOP->getOpcode() != Instruction::Xor &&
+      LogicOP->getOpcode() != Instruction::Or &&
+      LogicOP->getOpcode() != Instruction::And)
+    return false;
+
+  Value *LHS = LogicOP->getOperand(0);
+  Value *Mask = LogicOP->getOperand(1);
+  Operator *P2I = dyn_cast<Operator>(LHS);
+  if (!P2I || P2I->getOpcode() != Instruction::PtrToInt)
+    std::swap(LHS, Mask);
+  P2I = dyn_cast<Operator>(LHS);
+  if (!P2I || P2I->getOpcode() != Instruction::PtrToInt)
+    return false;
+
+  Operator *ASCast = dyn_cast<Operator>(P2I->getOperand(0));
+  if (!ASCast || ASCast->getOpcode() != Instruction::AddrSpaceCast)
     return false;
 
   unsigned SrcAS = I2P->getType()->getPointerAddressSpace();
-  unsigned DstAS = AsCast->getOperand(0)->getType()->getPointerAddressSpace();
+  unsigned DstAS = ASCast->getOperand(0)->getType()->getPointerAddressSpace();
   std::optional<APInt> PreservedPtrMask =
       TTI->getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
   if (!PreservedPtrMask)
     return false;
 
-  APInt ChangedPtrBits =
-      computeMaxChangedPtrBits(cast<Operator>(LogicalOp), Mask, DL, &AC, DT);
+  APInt ChangedPtrBits = computeMaxChangedPtrBits(LogicOP, Mask, DL, &AC, DT);
   // Check if the address bits change is within the preserved mask. If the bits
   // change is not preserved, it is not safe to perform address space cast.
   if (ChangedPtrBits.getBitWidth() != PreservedPtrMask->getBitWidth())
     return false;
 
   if (ChangedPtrBits.isSubsetOf(*PreservedPtrMask)) {
-    PtrIntCastPairs[I2P] = OldPtr;
+    PtrIntCastPairs[I2P] = P2I->getOperand(0);
     return true;
   }
 
diff --git a/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll b/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll
index e4d694c7dfb5f..10aa8188c9691 100644
--- a/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll
+++ b/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll
@@ -19,26 +19,6 @@ define void @test_smem_fail(ptr addrspace(3) %sp) {
   ret void
 }
 
-define void @test_smem_fail2(ptr addrspace(3) %sp) {
-; CHECK-LABEL: define void @test_smem_fail2(
-; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {
-; CHECK-NEXT:    [[GP1:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
-; CHECK-NEXT:    [[GP:%.*]] = getelementptr i8, ptr [[GP1]], i32 8
-; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
-; CHECK-NEXT:    [[B:%.*]] = xor i64 4095, [[A]]
-; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr
-; CHECK-NEXT:    store i16 0, ptr [[GP2]], align 2
-; CHECK-NEXT:    ret void
-;
-  %gp1 = addrspacecast ptr addrspace(3) %sp to ptr
-  %gp = getelementptr i8, ptr %gp1, i32 8
-  %a = ptrtoint ptr %gp to i64
-  %b = xor i64 4095, %a
-  %gp2 = inttoptr i64 %b to ptr
-  store i16 0, ptr %gp2, align 2
-  ret void
-}
-
 define void @test_xor_smem(ptr addrspace(3) %sp) {
 ; CHECK-LABEL: define void @test_xor_smem(
 ; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {

>From 76cb078acd30c852d99c448ac8f405a522f91579 Mon Sep 17 00:00:00 2001
From: Yuanke Luo <ykluo at birentech.com>
Date: Sun, 25 Jan 2026 10:03:06 +0800
Subject: [PATCH 07/12] [InferAS] Address the comments.

1. Revert the interface getAddrSpaceCastPreservedPtrMask to return APInt.
2. Emit a new addrspacecast instead of emitting an inttoptr with the new
   address space.
---
 .../llvm/Analysis/TargetTransformInfo.h       |  4 +-
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  6 +--
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  5 +-
 .../Target/NVPTX/NVPTXTargetTransformInfo.h   |  9 ++--
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 30 ++++++-----
 llvm/test/CodeGen/NVPTX/disjoint-or-addr.ll   |  6 ++-
 .../InferAddressSpaces/NVPTX/int2ptr.ll       | 50 ++++++++++++++-----
 7 files changed, 67 insertions(+), 43 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index a8c749f789d8b..891515f734233 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -578,8 +578,8 @@ class TargetTransformInfo {
   ///   %b = xor i64 7, %a
   ///   %gp2 = inttoptr i64 %b to ptr
   ///   store i16 0, ptr %gp2, align 2
-  LLVM_ABI std::optional<APInt>
-  getAddrSpaceCastPreservedPtrMask(unsigned SrcAS, unsigned DstAS) const;
+  LLVM_ABI APInt getAddrSpaceCastPreservedPtrMask(unsigned SrcAS,
+                                                  unsigned DstAS) const;
 
   /// Return true if globals in this address space can have initializers other
   /// than `undef`.
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 4d384327e8a87..f1637f0ee967d 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -193,9 +193,9 @@ class TargetTransformInfoImplBase {
     return FromPtrBits.anyextOrTrunc(ToASBitSize);
   }
 
-  virtual std::optional<APInt>
-  getAddrSpaceCastPreservedPtrMask(unsigned SrcAS, unsigned DstAS) const {
-    return std::nullopt;
+  virtual APInt getAddrSpaceCastPreservedPtrMask(unsigned SrcAS,
+                                                 unsigned DstAS) const {
+    return {DL.getPointerSizeInBits(SrcAS), 0};
   }
 
   virtual bool
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 5f65e0b8509f1..4075be4fd2b75 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -339,9 +339,8 @@ KnownBits TargetTransformInfo::computeKnownBitsAddrSpaceCast(
   return TTIImpl->computeKnownBitsAddrSpaceCast(FromAS, ToAS, FromPtrBits);
 }
 
-std::optional<APInt>
-TargetTransformInfo::getAddrSpaceCastPreservedPtrMask(unsigned SrcAS,
-                                                      unsigned DstAS) const {
+APInt TargetTransformInfo::getAddrSpaceCastPreservedPtrMask(
+    unsigned SrcAS, unsigned DstAS) const {
   return TTIImpl->getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index e1f94fc2dd0ed..4dbe5076a11bd 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -180,14 +180,13 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
     }
   }
 
-  std::optional<APInt>
-  getAddrSpaceCastPreservedPtrMask(unsigned SrcAS,
-                                   unsigned DstAS) const override {
+  APInt getAddrSpaceCastPreservedPtrMask(unsigned SrcAS,
+                                         unsigned DstAS) const override {
     if (SrcAS != llvm::ADDRESS_SPACE_GENERIC)
-      return std::nullopt;
+      return {DL.getPointerSizeInBits(llvm::ADDRESS_SPACE_GENERIC), 0};
     if (DstAS != llvm::ADDRESS_SPACE_GLOBAL &&
         DstAS != llvm::ADDRESS_SPACE_SHARED)
-      return std::nullopt;
+      return {DL.getPointerSizeInBits(llvm::ADDRESS_SPACE_GENERIC), 0};
 
     // Address change within 4K size does not change the original address space
     // and is safe to perform address cast form SrcAS to DstAS.
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 5d223a5aeb57c..b27b9e0eaa80d 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -424,6 +424,8 @@ static APInt computeMaxChangedPtrBits(const Operator *Op, const Value *Mask,
 bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
     const Operator *I2P, const DataLayout &DL) const {
   assert(I2P->getOpcode() == Instruction::IntToPtr);
+  // If I2P has been checked before and has the corresponding old pointer value,
+  // just return true.
   if (PtrIntCastPairs.count(I2P))
     return true;
 
@@ -454,19 +456,17 @@ bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
 
   unsigned SrcAS = I2P->getType()->getPointerAddressSpace();
   unsigned DstAS = ASCast->getOperand(0)->getType()->getPointerAddressSpace();
-  std::optional<APInt> PreservedPtrMask =
-      TTI->getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
-  if (!PreservedPtrMask)
-    return false;
-
+  APInt PreservedPtrMask = TTI->getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
   APInt ChangedPtrBits = computeMaxChangedPtrBits(LogicOP, Mask, DL, &AC, DT);
   // Check if the address bits change is within the preserved mask. If the bits
   // change is not preserved, it is not safe to perform address space cast.
-  if (ChangedPtrBits.getBitWidth() != PreservedPtrMask->getBitWidth())
-    return false;
-
-  if (ChangedPtrBits.isSubsetOf(*PreservedPtrMask)) {
-    PtrIntCastPairs[I2P] = P2I->getOperand(0);
+  // The following pattern is not safe to cast address space.
+  //   %1 = ptrtoint ptr addrspace(3) %sp to i32
+  //   %2 = zext i32 %1 to i64
+  //   %gp = inttoptr i64 %2 to ptr
+  assert(ChangedPtrBits.getBitWidth() == PreservedPtrMask.getBitWidth());
+  if (ChangedPtrBits.isSubsetOf(PreservedPtrMask)) {
+    PtrIntCastPairs.insert({I2P, P2I->getOperand(0)});
     return true;
   }
 
@@ -940,9 +940,9 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
       return new AddrSpaceCastInst(Src, NewPtrType);
     }
     assert(isSafeToCastPtrIntPair(cast<Operator>(I), *DL));
-    auto *Src = I->getOperand(0);
-    IntToPtrInst *NewI2P = new IntToPtrInst(Src, NewPtrType);
-    return NewI2P;
+    AddrSpaceCastInst *AsCast = new AddrSpaceCastInst(I, NewPtrType);
+    AsCast->insertAfter(I);
+    return AsCast;
   }
   default:
     llvm_unreachable("Unexpected opcode");
@@ -983,9 +983,7 @@ Value *InferAddressSpacesImpl::cloneConstantExprWithNewAddressSpace(
       return Src;
     }
     assert(isSafeToCastPtrIntPair(cast<Operator>(CE), *DL));
-    auto *Src = CE->getOperand(0);
-    return ConstantExpr::getIntToPtr(Src, TargetType);
-    // return ConstantExpr::getAddrSpaceCast(CE, TargetType);
+    return ConstantExpr::getAddrSpaceCast(CE, TargetType);
   }
 
   // Computes the operands of the new constant expression.
diff --git a/llvm/test/CodeGen/NVPTX/disjoint-or-addr.ll b/llvm/test/CodeGen/NVPTX/disjoint-or-addr.ll
index b9a24c722b5cc..02833ee861315 100644
--- a/llvm/test/CodeGen/NVPTX/disjoint-or-addr.ll
+++ b/llvm/test/CodeGen/NVPTX/disjoint-or-addr.ll
@@ -9,12 +9,14 @@ define i32  @test_disjoint_or_addr(i16 %a) {
 ; CHECK-LABEL: test_disjoint_or_addr(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b32 %r<2>;
-; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-NEXT:    .reg .b64 %rd<5>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    mov.b64 %rd1, a;
 ; CHECK-NEXT:    cvta.global.u64 %rd2, %rd1;
-; CHECK-NEXT:    ld.global.b32 %r1, [%rd2+8];
+; CHECK-NEXT:    or.b64 %rd3, %rd2, 8;
+; CHECK-NEXT:    cvta.to.global.u64 %rd4, %rd3;
+; CHECK-NEXT:    ld.global.b32 %r1, [%rd4];
 ; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
 ; CHECK-NEXT:    ret;
   %a1 = ptrtoint ptr @a to i64
diff --git a/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll b/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll
index 10aa8188c9691..90cdab3c09084 100644
--- a/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll
+++ b/llvm/test/Transforms/InferAddressSpaces/NVPTX/int2ptr.ll
@@ -19,14 +19,35 @@ define void @test_smem_fail(ptr addrspace(3) %sp) {
   ret void
 }
 
+define void @test_smem_fail2(ptr addrspace(3) %sp) {
+; CHECK-LABEL: define void @test_smem_fail2(
+; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {
+; CHECK-NEXT:    [[GP1:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
+; CHECK-NEXT:    [[GP:%.*]] = getelementptr i8, ptr [[GP1]], i32 8
+; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
+; CHECK-NEXT:    [[B:%.*]] = xor i64 4095, [[A]]
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr
+; CHECK-NEXT:    store i16 0, ptr [[GP2]], align 2
+; CHECK-NEXT:    ret void
+;
+  %gp1 = addrspacecast ptr addrspace(3) %sp to ptr
+  %gp = getelementptr i8, ptr %gp1, i32 8
+  %a = ptrtoint ptr %gp to i64
+  %b = xor i64 4095, %a
+  %gp2 = inttoptr i64 %b to ptr
+  store i16 0, ptr %gp2, align 2
+  ret void
+}
+
 define void @test_xor_smem(ptr addrspace(3) %sp) {
 ; CHECK-LABEL: define void @test_xor_smem(
 ; CHECK-SAME: ptr addrspace(3) [[SP:%.*]]) {
 ; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
 ; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
 ; CHECK-NEXT:    [[B:%.*]] = xor i64 4095, [[A]]
-; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr addrspace(3)
-; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[GP2]], align 2
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[GP2]] to ptr addrspace(3)
+; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[TMP1]], align 2
 ; CHECK-NEXT:    ret void
 ;
   %gp = addrspacecast ptr addrspace(3) %sp to ptr
@@ -43,8 +64,9 @@ define void @test_xor_smem2(ptr addrspace(3) %sp) {
 ; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
 ; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
 ; CHECK-NEXT:    [[B:%.*]] = xor i64 [[A]], 4095
-; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr addrspace(3)
-; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[GP2]], align 2
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[GP2]] to ptr addrspace(3)
+; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[TMP1]], align 2
 ; CHECK-NEXT:    ret void
 ;
   %gp = addrspacecast ptr addrspace(3) %sp to ptr
@@ -61,8 +83,9 @@ define void @test_or_smem(ptr addrspace(3) %sp) {
 ; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
 ; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
 ; CHECK-NEXT:    [[B:%.*]] = or i64 4095, [[A]]
-; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr addrspace(3)
-; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[GP2]], align 2
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[GP2]] to ptr addrspace(3)
+; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[TMP1]], align 2
 ; CHECK-NEXT:    ret void
 ;
   %gp = addrspacecast ptr addrspace(3) %sp to ptr
@@ -97,8 +120,9 @@ define void @test_and_smem(ptr addrspace(3) %sp) {
 ; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(3) [[SP]] to ptr
 ; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
 ; CHECK-NEXT:    [[B:%.*]] = and i64 [[A]], -4096
-; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr addrspace(3)
-; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[GP2]], align 2
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[GP2]] to ptr addrspace(3)
+; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[TMP1]], align 2
 ; CHECK-NEXT:    ret void
 ;
   %gp = addrspacecast ptr addrspace(3) %sp to ptr
@@ -133,8 +157,9 @@ define void @test_gmem(ptr addrspace(1) %sp) {
 ; CHECK-NEXT:    [[GP:%.*]] = addrspacecast ptr addrspace(1) [[SP]] to ptr
 ; CHECK-NEXT:    [[A:%.*]] = ptrtoint ptr [[GP]] to i64
 ; CHECK-NEXT:    [[B:%.*]] = xor i64 7, [[A]]
-; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr addrspace(1)
-; CHECK-NEXT:    store i16 0, ptr addrspace(1) [[GP2]], align 2
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[B]] to ptr
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[GP2]] to ptr addrspace(1)
+; CHECK-NEXT:    store i16 0, ptr addrspace(1) [[TMP1]], align 2
 ; CHECK-NEXT:    ret void
 ;
   %gp = addrspacecast ptr addrspace(1) %sp to ptr
@@ -177,8 +202,9 @@ define void @test3(ptr addrspace(3) %sp) {
 ; CHECK-NEXT:    [[SHR4:%.*]] = and i64 [[AND3]], 112
 ; CHECK-NEXT:    [[OR5:%.*]] = or i64 [[OR]], [[SHR4]]
 ; CHECK-NEXT:    [[XOR:%.*]] = xor i64 [[OR5]], [[T1]]
-; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[XOR]] to ptr addrspace(3)
-; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[GP2]], align 2
+; CHECK-NEXT:    [[GP2:%.*]] = inttoptr i64 [[XOR]] to ptr
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[GP2]] to ptr addrspace(3)
+; CHECK-NEXT:    store i16 0, ptr addrspace(3) [[TMP1]], align 2
 ; CHECK-NEXT:    ret void
 ;
   %gp = addrspacecast ptr addrspace(3) %sp to ptr

>From 8709d54dd22fac232bc3df8cea9fcaa4efc9c6e3 Mon Sep 17 00:00:00 2001
From: Yuanke Luo <ykluo at birentech.com>
Date: Sun, 25 Jan 2026 10:30:43 +0800
Subject: [PATCH 08/12] [InferAS] call BaseT method when it has nothing to do.

---
 llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 4dbe5076a11bd..ec4f3724df236 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -183,10 +183,10 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
   APInt getAddrSpaceCastPreservedPtrMask(unsigned SrcAS,
                                          unsigned DstAS) const override {
     if (SrcAS != llvm::ADDRESS_SPACE_GENERIC)
-      return {DL.getPointerSizeInBits(llvm::ADDRESS_SPACE_GENERIC), 0};
+      return BaseT::getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
     if (DstAS != llvm::ADDRESS_SPACE_GLOBAL &&
         DstAS != llvm::ADDRESS_SPACE_SHARED)
-      return {DL.getPointerSizeInBits(llvm::ADDRESS_SPACE_GENERIC), 0};
+      return BaseT::getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
 
     // Address change within 4K size does not change the original address space
     // and is safe to perform address cast form SrcAS to DstAS.

>From ea877e34385f71eed792c708d53ae8344a219f81 Mon Sep 17 00:00:00 2001
From: Yuanke Luo <ykluo at birentech.com>
Date: Sun, 25 Jan 2026 21:42:27 +0800
Subject: [PATCH 09/12] [InferAS] return early when PreservedPtrMask is zero.

---
 llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index b27b9e0eaa80d..e70f02e231557 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -457,6 +457,8 @@ bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
   unsigned SrcAS = I2P->getType()->getPointerAddressSpace();
   unsigned DstAS = ASCast->getOperand(0)->getType()->getPointerAddressSpace();
   APInt PreservedPtrMask = TTI->getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
+  if (PreservedPtrMask.isZero())
+    return false;
   APInt ChangedPtrBits = computeMaxChangedPtrBits(LogicOP, Mask, DL, &AC, DT);
   // Check if the address bits change is within the preserved mask. If the bits
   // change is not preserved, it is not safe to perform address space cast.

>From 87ade49f8cb75c81db158e94a0a69f40d93666bd Mon Sep 17 00:00:00 2001
From: Yuanke Luo <ykluo at birentech.com>
Date: Sat, 31 Jan 2026 13:37:55 +0800
Subject: [PATCH 10/12] [InferAS] Address the comments.

1. Use pattern match to detect inttoptr
2. Rename isSafeToCastPtrIntPair to isSafeToCastIntToPtrAddrSpace
3. Refine comments
---
 .../llvm/Analysis/TargetTransformInfo.h       |  3 +-
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 50 ++++++++-----------
 2 files changed, 24 insertions(+), 29 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 891515f734233..7aea10eda995f 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -568,7 +568,8 @@ class TargetTransformInfo {
       unsigned FromAS, unsigned ToAS, const KnownBits &FromPtrBits) const;
 
   /// Return the preserved ptr bit mask that is safe to cast integer to pointer
-  /// with new address space. The address of integer form may only change in the
+  /// with new address space. The returned APInt is as wide as a pointer in the
+  /// source address space. The address in integer form may only change in the
   /// least significant bit (e.g. within a page). In that case target can
   /// determine if it is safe to cast the generic address space to the original
   /// address space. For below example, we can replace `%gp2 = inttoptr i64 %b
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index e70f02e231557..54f3c0f13ceec 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -97,6 +97,7 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Argument.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
@@ -110,6 +111,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
@@ -137,6 +139,7 @@
 #define DEBUG_TYPE "infer-address-spaces"
 
 using namespace llvm;
+using namespace llvm::PatternMatch;
 
 static cl::opt<bool> AssumeDefaultIsFlatAddressSpace(
     "assume-default-is-flat-addrspace", cl::init(false), cl::ReallyHidden,
@@ -195,7 +198,8 @@ class InferAddressSpacesImpl {
   unsigned FlatAddrSpace = 0;
   mutable DenseMap<const Value *, Value *> PtrIntCastPairs;
 
-  bool isSafeToCastPtrIntPair(const Operator *I2P, const DataLayout &DL) const;
+  bool isSafeToCastIntToPtrAddrSpace(const Operator *I2P,
+                                     const DataLayout &DL) const;
   bool isAddressExpression(const Value &V, const DataLayout &DL,
                            const TargetTransformInfo *TTI) const;
   Value *cloneConstantExprWithNewAddressSpace(
@@ -356,7 +360,8 @@ bool InferAddressSpacesImpl::isAddressExpression(
     return II && II->getIntrinsicID() == Intrinsic::ptrmask;
   }
   case Instruction::IntToPtr:
-    return isNoopPtrIntCastPair(Op, DL, TTI) || isSafeToCastPtrIntPair(Op, DL);
+    return isNoopPtrIntCastPair(Op, DL, TTI) ||
+           isSafeToCastIntToPtrAddrSpace(Op, DL);
   default:
     // That value is an address expression if it has an assumed address space.
     return TTI->getAssumedAddrSpace(&V) != UninitializedAddressSpace;
@@ -395,7 +400,7 @@ SmallVector<Value *, 2> InferAddressSpacesImpl::getPointerOperands(
       auto *P2I = cast<Operator>(Op.getOperand(0));
       return {P2I->getOperand(0)};
     }
-    assert(isSafeToCastPtrIntPair(&Op, DL));
+    assert(isSafeToCastIntToPtrAddrSpace(&Op, DL));
     return {PtrIntCastPairs[&Op]};
   }
   default:
@@ -421,7 +426,7 @@ static APInt computeMaxChangedPtrBits(const Operator *Op, const Value *Mask,
   }
 }
 
-bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
+bool InferAddressSpacesImpl::isSafeToCastIntToPtrAddrSpace(
     const Operator *I2P, const DataLayout &DL) const {
   assert(I2P->getOpcode() == Instruction::IntToPtr);
   // If I2P has been checked before and has the corresponding old pointer value,
@@ -432,34 +437,23 @@ bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
   if (I2P->getType()->isVectorTy())
     return false;
 
-  Operator *LogicOP = dyn_cast<Operator>(I2P->getOperand(0));
-  if (!LogicOP)
+  Value *LogicalOp = I2P->getOperand(0);
+  Value *OldPtr, *Mask;
+  if (!match(LogicalOp,
+             m_c_BitwiseLogic(m_PtrToInt(m_Value(OldPtr)), m_Value(Mask))))
     return false;
 
-  if (LogicOP->getOpcode() != Instruction::Xor &&
-      LogicOP->getOpcode() != Instruction::Or &&
-      LogicOP->getOpcode() != Instruction::And)
-    return false;
-
-  Value *LHS = LogicOP->getOperand(0);
-  Value *Mask = LogicOP->getOperand(1);
-  Operator *P2I = dyn_cast<Operator>(LHS);
-  if (!P2I || P2I->getOpcode() != Instruction::PtrToInt)
-    std::swap(LHS, Mask);
-  P2I = dyn_cast<Operator>(LHS);
-  if (!P2I || P2I->getOpcode() != Instruction::PtrToInt)
-    return false;
-
-  Operator *ASCast = dyn_cast<Operator>(P2I->getOperand(0));
-  if (!ASCast || ASCast->getOpcode() != Instruction::AddrSpaceCast)
+  Operator *AsCast = dyn_cast<AddrSpaceCastOperator>(OldPtr);
+  if (!AsCast)
     return false;
 
   unsigned SrcAS = I2P->getType()->getPointerAddressSpace();
-  unsigned DstAS = ASCast->getOperand(0)->getType()->getPointerAddressSpace();
+  unsigned DstAS = AsCast->getOperand(0)->getType()->getPointerAddressSpace();
   APInt PreservedPtrMask = TTI->getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
   if (PreservedPtrMask.isZero())
     return false;
-  APInt ChangedPtrBits = computeMaxChangedPtrBits(LogicOP, Mask, DL, &AC, DT);
+  APInt ChangedPtrBits =
+      computeMaxChangedPtrBits(cast<Operator>(LogicalOp), Mask, DL, &AC, DT);
   // Check if the address bits change is within the preserved mask. If the bits
   // change is not preserved, it is not safe to perform address space cast.
   // The following pattern is not safe to cast address space.
@@ -468,7 +462,7 @@ bool InferAddressSpacesImpl::isSafeToCastPtrIntPair(
   //   %gp = inttoptr i64 %2 to ptr
   assert(ChangedPtrBits.getBitWidth() == PreservedPtrMask.getBitWidth());
   if (ChangedPtrBits.isSubsetOf(PreservedPtrMask)) {
-    PtrIntCastPairs.insert({I2P, P2I->getOperand(0)});
+    PtrIntCastPairs.insert({I2P, OldPtr});
     return true;
   }
 
@@ -683,7 +677,7 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const {
     } else if (auto *I2P = dyn_cast<IntToPtrInst>(&I)) {
       if (isNoopPtrIntCastPair(cast<Operator>(I2P), *DL, TTI))
         PushPtrOperand(cast<Operator>(I2P->getOperand(0))->getOperand(0));
-      else if (isSafeToCastPtrIntPair(cast<Operator>(I2P), *DL))
+      else if (isSafeToCastIntToPtrAddrSpace(cast<Operator>(I2P), *DL))
         PushPtrOperand(PtrIntCastPairs[I2P]);
     } else if (auto *RI = dyn_cast<ReturnInst>(&I)) {
       if (auto *RV = RI->getReturnValue();
@@ -941,7 +935,7 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
       // cast back.
       return new AddrSpaceCastInst(Src, NewPtrType);
     }
-    assert(isSafeToCastPtrIntPair(cast<Operator>(I), *DL));
+    assert(isSafeToCastIntToPtrAddrSpace(cast<Operator>(I), *DL));
     AddrSpaceCastInst *AsCast = new AddrSpaceCastInst(I, NewPtrType);
     AsCast->insertAfter(I);
     return AsCast;
@@ -984,7 +978,7 @@ Value *InferAddressSpacesImpl::cloneConstantExprWithNewAddressSpace(
       assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace);
       return Src;
     }
-    assert(isSafeToCastPtrIntPair(cast<Operator>(CE), *DL));
+    assert(isSafeToCastIntToPtrAddrSpace(cast<Operator>(CE), *DL));
     return ConstantExpr::getAddrSpaceCast(CE, TargetType);
   }
 

>From 16ba9f9bbb9b2ea48004726ad4a413c6e1a09f13 Mon Sep 17 00:00:00 2001
From: Yuanke Luo <ykluo at birentech.com>
Date: Sat, 31 Jan 2026 13:42:48 +0800
Subject: [PATCH 11/12] [InferAS] Fix clang format issues

---
 llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 54f3c0f13ceec..d22ecab433608 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -111,8 +111,8 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Operator.h"
-#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"

>From 8f4ff5cd5167a64fa808109b4aa45f2035049e27 Mon Sep 17 00:00:00 2001
From: Yuanke Luo <ykluo at birentech.com>
Date: Wed, 11 Feb 2026 12:47:18 +0800
Subject: [PATCH 12/12] Add getIntToPtrPointerOperand() API

---
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 48 ++++++++++---------
 1 file changed, 26 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index d22ecab433608..1e1b2db4828c1 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -198,8 +198,8 @@ class InferAddressSpacesImpl {
   unsigned FlatAddrSpace = 0;
   mutable DenseMap<const Value *, Value *> PtrIntCastPairs;
 
-  bool isSafeToCastIntToPtrAddrSpace(const Operator *I2P,
-                                     const DataLayout &DL) const;
+  Value *getIntToPtrPointerOperand(const Operator *I2P) const;
+  bool isSafeToCastIntToPtrAddrSpace(const Operator *I2P) const;
   bool isAddressExpression(const Value &V, const DataLayout &DL,
                            const TargetTransformInfo *TTI) const;
   Value *cloneConstantExprWithNewAddressSpace(
@@ -360,8 +360,7 @@ bool InferAddressSpacesImpl::isAddressExpression(
     return II && II->getIntrinsicID() == Intrinsic::ptrmask;
   }
   case Instruction::IntToPtr:
-    return isNoopPtrIntCastPair(Op, DL, TTI) ||
-           isSafeToCastIntToPtrAddrSpace(Op, DL);
+    return isNoopPtrIntCastPair(Op, DL, TTI) || getIntToPtrPointerOperand(Op);
   default:
     // That value is an address expression if it has an assumed address space.
     return TTI->getAssumedAddrSpace(&V) != UninitializedAddressSpace;
@@ -400,7 +399,7 @@ SmallVector<Value *, 2> InferAddressSpacesImpl::getPointerOperands(
       auto *P2I = cast<Operator>(Op.getOperand(0));
       return {P2I->getOperand(0)};
     }
-    assert(isSafeToCastIntToPtrAddrSpace(&Op, DL));
+    assert(isSafeToCastIntToPtrAddrSpace(&Op));
     return {PtrIntCastPairs[&Op]};
   }
   default:
@@ -427,33 +426,38 @@ static APInt computeMaxChangedPtrBits(const Operator *Op, const Value *Mask,
 }
 
 bool InferAddressSpacesImpl::isSafeToCastIntToPtrAddrSpace(
-    const Operator *I2P, const DataLayout &DL) const {
-  assert(I2P->getOpcode() == Instruction::IntToPtr);
-  // If I2P has been checked before and has the corresponding old pointer value,
-  // just return true.
-  if (PtrIntCastPairs.count(I2P))
-    return true;
+    const Operator *I2P) const {
+  return PtrIntCastPairs.contains(I2P);
+}
 
+Value *
+InferAddressSpacesImpl::getIntToPtrPointerOperand(const Operator *I2P) const {
+  assert(I2P->getOpcode() == Instruction::IntToPtr);
   if (I2P->getType()->isVectorTy())
-    return false;
+    return nullptr;
+
+  // If I2P has already been analyzed and its old pointer value is cached, just
+  // return that value.
+  if (PtrIntCastPairs.count(I2P))
+    return PtrIntCastPairs[I2P];
 
   Value *LogicalOp = I2P->getOperand(0);
   Value *OldPtr, *Mask;
   if (!match(LogicalOp,
              m_c_BitwiseLogic(m_PtrToInt(m_Value(OldPtr)), m_Value(Mask))))
-    return false;
+    return nullptr;
 
   Operator *AsCast = dyn_cast<AddrSpaceCastOperator>(OldPtr);
   if (!AsCast)
-    return false;
+    return nullptr;
 
   unsigned SrcAS = I2P->getType()->getPointerAddressSpace();
   unsigned DstAS = AsCast->getOperand(0)->getType()->getPointerAddressSpace();
   APInt PreservedPtrMask = TTI->getAddrSpaceCastPreservedPtrMask(SrcAS, DstAS);
   if (PreservedPtrMask.isZero())
-    return false;
+    return nullptr;
   APInt ChangedPtrBits =
-      computeMaxChangedPtrBits(cast<Operator>(LogicalOp), Mask, DL, &AC, DT);
+      computeMaxChangedPtrBits(cast<Operator>(LogicalOp), Mask, *DL, &AC, DT);
   // Check if the address bits change is within the preserved mask. If the bits
   // change is not preserved, it is not safe to perform address space cast.
   // The following pattern is not safe to cast address space.
@@ -463,10 +467,10 @@ bool InferAddressSpacesImpl::isSafeToCastIntToPtrAddrSpace(
   assert(ChangedPtrBits.getBitWidth() == PreservedPtrMask.getBitWidth());
   if (ChangedPtrBits.isSubsetOf(PreservedPtrMask)) {
     PtrIntCastPairs.insert({I2P, OldPtr});
-    return true;
+    return OldPtr;
   }
 
-  return false;
+  return nullptr;
 }
 
 bool InferAddressSpacesImpl::rewriteIntrinsicOperands(IntrinsicInst *II,
@@ -677,8 +681,8 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const {
     } else if (auto *I2P = dyn_cast<IntToPtrInst>(&I)) {
       if (isNoopPtrIntCastPair(cast<Operator>(I2P), *DL, TTI))
         PushPtrOperand(cast<Operator>(I2P->getOperand(0))->getOperand(0));
-      else if (isSafeToCastIntToPtrAddrSpace(cast<Operator>(I2P), *DL))
-        PushPtrOperand(PtrIntCastPairs[I2P]);
+      else if (auto *P2I = getIntToPtrPointerOperand(cast<Operator>(I2P)))
+        PushPtrOperand(P2I);
     } else if (auto *RI = dyn_cast<ReturnInst>(&I)) {
       if (auto *RV = RI->getReturnValue();
           RV && RV->getType()->isPtrOrPtrVectorTy())
@@ -935,7 +939,7 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
       // cast back.
       return new AddrSpaceCastInst(Src, NewPtrType);
     }
-    assert(isSafeToCastIntToPtrAddrSpace(cast<Operator>(I), *DL));
+    assert(isSafeToCastIntToPtrAddrSpace(cast<Operator>(I)));
     AddrSpaceCastInst *AsCast = new AddrSpaceCastInst(I, NewPtrType);
     AsCast->insertAfter(I);
     return AsCast;
@@ -978,7 +982,7 @@ Value *InferAddressSpacesImpl::cloneConstantExprWithNewAddressSpace(
       assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace);
       return Src;
     }
-    assert(isSafeToCastIntToPtrAddrSpace(cast<Operator>(CE), *DL));
+    assert(isSafeToCastIntToPtrAddrSpace(cast<Operator>(CE)));
     return ConstantExpr::getAddrSpaceCast(CE, TargetType);
   }
 



More information about the llvm-commits mailing list