[llvm] [InferAddressSpaces] Handle unconverted ptrmask (PR #140802)

Robert Imschweiler via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 26 04:48:58 PDT 2025


https://github.com/ro-i updated https://github.com/llvm/llvm-project/pull/140802

From 93be28d273c2c0f7e97e68e615eca42361a5278f Mon Sep 17 00:00:00 2001
From: Robert Imschweiler <robert.imschweiler at amd.com>
Date: Tue, 20 May 2025 15:55:36 -0500
Subject: [PATCH 01/12] [InferAddressSpaces] Handle unconverted ptrmask

In case a ptrmask cannot be converted to the new address space due to an
unknown mask value, this needs to be detected and an addrspacecast has to
be inserted so that future uses of the unconverted return value of
ptrmask stay valid. Otherwise, users of this value will become invalid by
receiving a nullptr as an operand.

This LLVM defect was identified via the AMD Fuzzing project.
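
For reference, a minimal IR reproducer of the failure mode (a sketch
mirroring the test added below; function and value names are illustrative):

  ; Without the fix, rewriting %gep into addrspace(3) left it with a
  ; nullptr operand, because the ptrmask with the unknown %mask cannot
  ; be rewritten into the new address space.
  define i8 @repro(ptr addrspace(3) %src.ptr, i64 %mask, i64 %idx) {
    %cast = addrspacecast ptr addrspace(3) %src.ptr to ptr
    %masked = call ptr @llvm.ptrmask.p0.i64(ptr %cast, i64 %mask)
    %gep = getelementptr i8, ptr %masked, i64 %idx
    %load = load i8, ptr %gep
    ret i8 %load
  }
  declare ptr @llvm.ptrmask.p0.i64(ptr, i64)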
---
 .../Transforms/Scalar/InferAddressSpaces.cpp   | 15 ++++++++++++++-
 .../InferAddressSpaces/AMDGPU/ptrmask.ll       | 18 ++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 66836ef05d5db..acc4368b0772b 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -1350,7 +1350,20 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
 
     unsigned OperandNo = PoisonUse->getOperandNo();
     assert(isa<PoisonValue>(NewV->getOperand(OperandNo)));
-    NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(PoisonUse->get()));
+    WeakTrackingVH NewOp = ValueWithNewAddrSpace.lookup(PoisonUse->get());
+    if (NewOp) {
+      NewV->setOperand(OperandNo, NewOp);
+    } else {
+      // Something went wrong while converting the instruction defining the new
+      // operand value.  -> Replace the poison value with the previous operand
+      // value combined with an addrspace case.
+      Value *PoisonOp = NewV->getOperand(OperandNo);
+      Value *OldOp = V->getOperand(OperandNo);
+      Value *AddrSpaceCast =
+          new AddrSpaceCastInst(OldOp, PoisonOp->getType(), "",
+                                cast<Instruction>(NewV)->getIterator());
+      NewV->setOperand(OperandNo, AddrSpaceCast);
+    }
   }
 
   SmallVector<Instruction *, 16> DeadInstructions;
diff --git a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
index 6ef926f935830..1c1d1df79520d 100644
--- a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
+++ b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
@@ -343,6 +343,24 @@ define i8 @ptrmask_cast_local_to_flat_load_range_mask(ptr addrspace(3) %src.ptr,
   ret i8 %load
 }
 
+; Non-const masks with no known range should not prevent other ptr-manipulating
+; instructions (such as gep) from being converted.
+define i8 @ptrmask_cast_local_to_flat_unknown_mask(ptr addrspace(3) %src.ptr, i64 %mask, i64 %idx) {
+; CHECK-LABEL: @ptrmask_cast_local_to_flat_unknown_mask(
+; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(3) [[SRC_PTR:%.*]] to ptr
+; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[CAST]], i64 [[MASK:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[MASKED]] to ptr addrspace(3)
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP1]], i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[GEP]], align 1
+; CHECK-NEXT:    ret i8 [[LOAD]]
+;
+  %cast = addrspacecast ptr addrspace(3) %src.ptr to ptr
+  %masked = call ptr @llvm.ptrmask.p0.i64(ptr %cast, i64 %mask)
+  %gep = getelementptr i8, ptr %masked, i64 %idx
+  %load = load i8, ptr %gep
+  ret i8 %load
+}
+
 declare ptr @llvm.ptrmask.p0.i64(ptr, i64) #0
 declare ptr addrspace(5) @llvm.ptrmask.p5.i32(ptr addrspace(5), i32) #0
 declare ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3), i32) #0

From 6f2064e85ea32455d11c5b2df360b017ebb1e883 Mon Sep 17 00:00:00 2001
From: Robert Imschweiler <robert.imschweiler at amd.com>
Date: Wed, 21 May 2025 03:42:35 -0500
Subject: [PATCH 02/12] implement feedback

---
 .../Transforms/Scalar/InferAddressSpaces.cpp   |  2 +-
 .../InferAddressSpaces/AMDGPU/ptrmask.ll       | 18 ++++++++++++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index acc4368b0772b..5e870e7bbd7c5 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -1356,7 +1356,7 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
     } else {
       // Something went wrong while converting the instruction defining the new
       // operand value.  -> Replace the poison value with the previous operand
-      // value combined with an addrspace case.
+      // value combined with an addrspacecast.
       Value *PoisonOp = NewV->getOperand(OperandNo);
       Value *OldOp = V->getOperand(OperandNo);
       Value *AddrSpaceCast =
diff --git a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
index 1c1d1df79520d..ccc7df89dcd8e 100644
--- a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
+++ b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
@@ -361,7 +361,25 @@ define i8 @ptrmask_cast_local_to_flat_unknown_mask(ptr addrspace(3) %src.ptr, i6
   ret i8 %load
 }
 
+define i8 @ptrmask_vector_cast_local_to_flat_unknown_mask(<2 x ptr addrspace(3)> %src.ptr, <2 x i64> %mask, i64 %ptridx, i64 %idx) {
+; CHECK-LABEL: @ptrmask_vector_cast_local_to_flat_unknown_mask(
+; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast <2 x ptr addrspace(3)> [[SRC_PTR:%.*]] to <2 x ptr>
+; CHECK-NEXT:    [[MASKED:%.*]] = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> [[CAST]], <2 x i64> [[MASK:%.*]])
+; CHECK-NEXT:    [[PTR:%.*]] = extractelement <2 x ptr> [[MASKED]], i64 [[PTRIDX:%.*]]
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr [[GEP]], align 1
+; CHECK-NEXT:    ret i8 [[LOAD]]
+;
+  %cast = addrspacecast <2 x ptr addrspace(3)> %src.ptr to <2 x ptr>
+  %masked = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> %cast, <2 x i64> %mask)
+  %ptr = extractelement <2 x ptr> %masked, i64 %ptridx
+  %gep = getelementptr i8, ptr %ptr, i64 %idx
+  %load = load i8, ptr %gep
+  ret i8 %load
+}
+
 declare ptr @llvm.ptrmask.p0.i64(ptr, i64) #0
+declare <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr>, <2 x i64>) #0
 declare ptr addrspace(5) @llvm.ptrmask.p5.i32(ptr addrspace(5), i32) #0
 declare ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3), i32) #0
 declare ptr addrspace(1) @llvm.ptrmask.p1.i64(ptr addrspace(1), i64) #0

From 9a11446da0ab234f6fa39a1d948bd6a8c7facfaa Mon Sep 17 00:00:00 2001
From: Robert Imschweiler <robert.imschweiler at amd.com>
Date: Wed, 21 May 2025 08:39:38 -0500
Subject: [PATCH 03/12] implement alternative approach

---
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 34 +++++-------
 .../InferAddressSpaces/AMDGPU/ptrmask.ll      | 52 ++++++++++++++-----
 2 files changed, 53 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 5e870e7bbd7c5..10962895176b9 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -658,9 +658,6 @@ static Value *operandWithNewAddressSpaceOrCreatePoison(
 // Note that we do not necessarily clone `I`, e.g., if it is an addrspacecast
 // from a pointer whose type already matches. Therefore, this function returns a
 // Value* instead of an Instruction*.
-//
-// This may also return nullptr in the case the instruction could not be
-// rewritten.
 Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
     Instruction *I, unsigned NewAddrSpace,
     const ValueToValueMapTy &ValueWithNewAddrSpace,
@@ -681,17 +678,24 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
     // Technically the intrinsic ID is a pointer typed argument, so specially
     // handle calls early.
     assert(II->getIntrinsicID() == Intrinsic::ptrmask);
+    const Use &PtrArgUse = II->getArgOperandUse(0);
     Value *NewPtr = operandWithNewAddressSpaceOrCreatePoison(
-        II->getArgOperandUse(0), NewAddrSpace, ValueWithNewAddrSpace,
-        PredicatedAS, PoisonUsesToFix);
+        PtrArgUse, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS,
+        PoisonUsesToFix);
     Value *Rewrite =
         TTI->rewriteIntrinsicWithAddressSpace(II, II->getArgOperand(0), NewPtr);
     if (Rewrite) {
       assert(Rewrite != II && "cannot modify this pointer operation in place");
       return Rewrite;
     }
-
-    return nullptr;
+    // Leave the ptrmask as-is and insert an addrspacecast after it.
+    Instruction *AddrSpaceCast = new AddrSpaceCastInst(II, NewPtr->getType());
+    AddrSpaceCast->insertAfter(II->getIterator());
+    AddrSpaceCast->setDebugLoc(II->getDebugLoc());
+    // If we generated a poison operand for the ptr argument, remove it.
+    if (!PoisonUsesToFix->empty() && PoisonUsesToFix->back() == &PtrArgUse)
+      PoisonUsesToFix->pop_back();
+    return AddrSpaceCast;
   }
 
   unsigned AS = TTI->getAssumedAddrSpace(I);
@@ -1351,19 +1355,9 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
     unsigned OperandNo = PoisonUse->getOperandNo();
     assert(isa<PoisonValue>(NewV->getOperand(OperandNo)));
     WeakTrackingVH NewOp = ValueWithNewAddrSpace.lookup(PoisonUse->get());
-    if (NewOp) {
-      NewV->setOperand(OperandNo, NewOp);
-    } else {
-      // Something went wrong while converting the instruction defining the new
-      // operand value.  -> Replace the poison value with the previous operand
-      // value combined with an addrspacecast.
-      Value *PoisonOp = NewV->getOperand(OperandNo);
-      Value *OldOp = V->getOperand(OperandNo);
-      Value *AddrSpaceCast =
-          new AddrSpaceCastInst(OldOp, PoisonOp->getType(), "",
-                                cast<Instruction>(NewV)->getIterator());
-      NewV->setOperand(OperandNo, AddrSpaceCast);
-    }
+    assert(NewOp &&
+           "poison replacements in ValueWithNewAddrSpace shouldn't be null");
+    NewV->setOperand(OperandNo, NewOp);
   }
 
   SmallVector<Instruction *, 16> DeadInstructions;
diff --git a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
index ccc7df89dcd8e..bc89d49b3cd2b 100644
--- a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
+++ b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
@@ -3,9 +3,10 @@
 
 define i8 @ptrmask_cast_local_to_flat(ptr addrspace(3) %src.ptr, i64 %mask) {
 ; CHECK-LABEL: @ptrmask_cast_local_to_flat(
-; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(3) [[SRC_PTR:%.*]] to ptr
+; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(3) [[SRC_PTR1:%.*]] to ptr
 ; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[CAST]], i64 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr [[MASKED]], align 1
+; CHECK-NEXT:    [[SRC_PTR:%.*]] = addrspacecast ptr [[MASKED]] to ptr addrspace(3)
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[SRC_PTR]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %cast = addrspacecast ptr addrspace(3) %src.ptr to ptr
@@ -16,9 +17,10 @@ define i8 @ptrmask_cast_local_to_flat(ptr addrspace(3) %src.ptr, i64 %mask) {
 
 define i8 @ptrmask_cast_private_to_flat(ptr addrspace(5) %src.ptr, i64 %mask) {
 ; CHECK-LABEL: @ptrmask_cast_private_to_flat(
-; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(5) [[SRC_PTR:%.*]] to ptr
+; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(5) [[SRC_PTR1:%.*]] to ptr
 ; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[CAST]], i64 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr [[MASKED]], align 1
+; CHECK-NEXT:    [[SRC_PTR:%.*]] = addrspacecast ptr [[MASKED]] to ptr addrspace(5)
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(5) [[SRC_PTR]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %cast = addrspacecast ptr addrspace(5) %src.ptr to ptr
@@ -29,9 +31,10 @@ define i8 @ptrmask_cast_private_to_flat(ptr addrspace(5) %src.ptr, i64 %mask) {
 
 define i8 @ptrmask_cast_region_to_flat(ptr addrspace(2) %src.ptr, i64 %mask) {
 ; CHECK-LABEL: @ptrmask_cast_region_to_flat(
-; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(2) [[SRC_PTR:%.*]] to ptr
+; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(2) [[SRC_PTR1:%.*]] to ptr
 ; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[CAST]], i64 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr [[MASKED]], align 1
+; CHECK-NEXT:    [[SRC_PTR:%.*]] = addrspacecast ptr [[MASKED]] to ptr addrspace(2)
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(2) [[SRC_PTR]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %cast = addrspacecast ptr addrspace(2) %src.ptr to ptr
@@ -109,7 +112,8 @@ define i8 @ptrmask_cast_flat_to_global(ptr %ptr, i64 %mask) {
 define i8 @ptrmask_cast_local_to_flat_global(i64 %mask) {
 ; CHECK-LABEL: @ptrmask_cast_local_to_flat_global(
 ; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr addrspacecast (ptr addrspace(3) @lds0 to ptr), i64 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr [[MASKED]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[MASKED]] to ptr addrspace(3)
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[TMP1]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %masked = call ptr @llvm.ptrmask.p0.i64(ptr addrspacecast (ptr addrspace(3) @lds0 to ptr), i64 %mask)
@@ -150,7 +154,8 @@ define i8 @multi_ptrmask_cast_local_to_flat(ptr addrspace(3) %src.ptr, i64 %mask
 ; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(3) [[SRC_PTR:%.*]] to ptr
 ; CHECK-NEXT:    [[LOAD0:%.*]] = load i8, ptr addrspace(3) [[SRC_PTR]], align 1
 ; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[CAST]], i64 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD1:%.*]] = load i8, ptr [[MASKED]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[MASKED]] to ptr addrspace(3)
+; CHECK-NEXT:    [[LOAD1:%.*]] = load i8, ptr addrspace(3) [[TMP1]], align 1
 ; CHECK-NEXT:    [[ADD:%.*]] = add i8 [[LOAD0]], [[LOAD1]]
 ; CHECK-NEXT:    ret i8 [[ADD]]
 ;
@@ -167,7 +172,8 @@ define i8 @multi_ptrmask_cast_region_to_flat(ptr addrspace(2) %src.ptr, i64 %mas
 ; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(2) [[SRC_PTR:%.*]] to ptr
 ; CHECK-NEXT:    [[LOAD0:%.*]] = load i8, ptr addrspace(2) [[SRC_PTR]], align 1
 ; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[CAST]], i64 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD1:%.*]] = load i8, ptr [[MASKED]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[MASKED]] to ptr addrspace(2)
+; CHECK-NEXT:    [[LOAD1:%.*]] = load i8, ptr addrspace(2) [[TMP1]], align 1
 ; CHECK-NEXT:    [[ADD:%.*]] = add i8 [[LOAD0]], [[LOAD1]]
 ; CHECK-NEXT:    ret i8 [[ADD]]
 ;
@@ -182,9 +188,10 @@ define i8 @multi_ptrmask_cast_region_to_flat(ptr addrspace(2) %src.ptr, i64 %mas
 ; Do not fold this since it clears a single high bit.
 define i8 @ptrmask_cast_local_to_flat_const_mask_fffffffeffffffff(ptr addrspace(3) %src.ptr) {
 ; CHECK-LABEL: @ptrmask_cast_local_to_flat_const_mask_fffffffeffffffff(
-; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(3) [[SRC_PTR:%.*]] to ptr
+; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(3) [[SRC_PTR1:%.*]] to ptr
 ; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[CAST]], i64 -4294967297)
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr [[MASKED]], align 1
+; CHECK-NEXT:    [[SRC_PTR:%.*]] = addrspacecast ptr [[MASKED]] to ptr addrspace(3)
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[SRC_PTR]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %cast = addrspacecast ptr addrspace(3) %src.ptr to ptr
@@ -196,9 +203,10 @@ define i8 @ptrmask_cast_local_to_flat_const_mask_fffffffeffffffff(ptr addrspace(
 ; Do not fold this since it clears a single high bit.
 define i8 @ptrmask_cast_local_to_flat_const_mask_7fffffffffffffff(ptr addrspace(3) %src.ptr) {
 ; CHECK-LABEL: @ptrmask_cast_local_to_flat_const_mask_7fffffffffffffff(
-; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(3) [[SRC_PTR:%.*]] to ptr
+; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(3) [[SRC_PTR1:%.*]] to ptr
 ; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[CAST]], i64 9223372036854775807)
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr [[MASKED]], align 1
+; CHECK-NEXT:    [[SRC_PTR:%.*]] = addrspacecast ptr [[MASKED]] to ptr addrspace(3)
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[SRC_PTR]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %cast = addrspacecast ptr addrspace(3) %src.ptr to ptr
@@ -378,6 +386,24 @@ define i8 @ptrmask_vector_cast_local_to_flat_unknown_mask(<2 x ptr addrspace(3)>
   ret i8 %load
 }
 
+define i8 @interleaved_ptrmask_cast_local_to_flat_unknown_mask(ptr addrspace(3) %src.ptr, i64 %mask, i64 %idx) {
+; CHECK-LABEL: @interleaved_ptrmask_cast_local_to_flat_unknown_mask(
+; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(3) [[GEP1:%.*]] to ptr
+; CHECK-NEXT:    store i64 [[MASK:%.*]], ptr addrspace(3) [[GEP1]], align 8
+; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[CAST]], i64 [[MASK]])
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[MASKED]] to ptr addrspace(3)
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP1]], i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[GEP]], align 1
+; CHECK-NEXT:    ret i8 [[LOAD]]
+;
+  %cast = addrspacecast ptr addrspace(3) %src.ptr to ptr
+  store i64 %mask, ptr %cast
+  %masked = call ptr @llvm.ptrmask.p0.i64(ptr %cast, i64 %mask)
+  %gep = getelementptr i8, ptr %masked, i64 %idx
+  %load = load i8, ptr %gep
+  ret i8 %load
+}
+
 declare ptr @llvm.ptrmask.p0.i64(ptr, i64) #0
 declare <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr>, <2 x i64>) #0
 declare ptr addrspace(5) @llvm.ptrmask.p5.i32(ptr addrspace(5), i32) #0
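
As the updated CHECK lines show, the net effect of this approach on the
simple local-to-flat case is the following (a sketch; %back is an
illustrative name):

  ; before infer-address-spaces
  %cast = addrspacecast ptr addrspace(3) %src.ptr to ptr
  %masked = call ptr @llvm.ptrmask.p0.i64(ptr %cast, i64 %mask)
  %load = load i8, ptr %masked

  ; after: the ptrmask stays in the flat address space, and the cast
  ; back to addrspace(3) lets its users (here, the load) be rewritten
  %cast = addrspacecast ptr addrspace(3) %src.ptr to ptr
  %masked = call ptr @llvm.ptrmask.p0.i64(ptr %cast, i64 %mask)
  %back = addrspacecast ptr %masked to ptr addrspace(3)
  %load = load i8, ptr addrspace(3) %back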

From 298b74e0e21efe3f2b75a00552f910e36256fe34 Mon Sep 17 00:00:00 2001
From: Robert Imschweiler <robert.imschweiler at amd.com>
Date: Thu, 22 May 2025 09:34:33 -0500
Subject: [PATCH 04/12] pull ptrmask handling completely into
 InferAddressSpaces

---
 .../AMDGPU/AMDGPUTargetTransformInfo.cpp      | 35 -----------
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 59 ++++++++++++++-----
 2 files changed, 43 insertions(+), 51 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
index 58bfc0b80b24f..faeee5ebae5ff 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
@@ -1067,41 +1067,6 @@ Value *GCNTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
       ConstantInt::getTrue(Ctx) : ConstantInt::getFalse(Ctx);
     return NewVal;
   }
-  case Intrinsic::ptrmask: {
-    unsigned OldAS = OldV->getType()->getPointerAddressSpace();
-    unsigned NewAS = NewV->getType()->getPointerAddressSpace();
-    Value *MaskOp = II->getArgOperand(1);
-    Type *MaskTy = MaskOp->getType();
-
-    bool DoTruncate = false;
-
-    const GCNTargetMachine &TM =
-        static_cast<const GCNTargetMachine &>(getTLI()->getTargetMachine());
-    if (!TM.isNoopAddrSpaceCast(OldAS, NewAS)) {
-      // All valid 64-bit to 32-bit casts work by chopping off the high
-      // bits. Any masking only clearing the low bits will also apply in the new
-      // address space.
-      if (DL.getPointerSizeInBits(OldAS) != 64 ||
-          DL.getPointerSizeInBits(NewAS) != 32)
-        return nullptr;
-
-      // TODO: Do we need to thread more context in here?
-      KnownBits Known = computeKnownBits(MaskOp, DL, nullptr, II);
-      if (Known.countMinLeadingOnes() < 32)
-        return nullptr;
-
-      DoTruncate = true;
-    }
-
-    IRBuilder<> B(II);
-    if (DoTruncate) {
-      MaskTy = B.getInt32Ty();
-      MaskOp = B.CreateTrunc(MaskOp, MaskTy);
-    }
-
-    return B.CreateIntrinsic(Intrinsic::ptrmask, {NewV->getType(), MaskTy},
-                             {NewV, MaskOp});
-  }
   case Intrinsic::amdgcn_flat_atomic_fmax_num:
   case Intrinsic::amdgcn_flat_atomic_fmin_num: {
     Type *DestTy = II->getType();
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 10962895176b9..69975924b0f9e 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -678,24 +678,51 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
     // Technically the intrinsic ID is a pointer typed argument, so specially
     // handle calls early.
     assert(II->getIntrinsicID() == Intrinsic::ptrmask);
-    const Use &PtrArgUse = II->getArgOperandUse(0);
+    const Use &PtrOpUse = II->getArgOperandUse(0);
+    unsigned OldAddrSpace = PtrOpUse.get()->getType()->getPointerAddressSpace();
+    Value *MaskOp = II->getArgOperand(1);
+    Type *MaskTy = MaskOp->getType();
+
+    bool DoTruncate = false;
+    bool DoNotConvert = false;
+
+    if (!TTI->isNoopAddrSpaceCast(OldAddrSpace, NewAddrSpace)) {
+      // All valid 64-bit to 32-bit casts work by chopping off the high
+      // bits. Any masking only clearing the low bits will also apply in the new
+      // address space.
+      if (DL->getPointerSizeInBits(OldAddrSpace) != 64 ||
+          DL->getPointerSizeInBits(NewAddrSpace) != 32) {
+        DoNotConvert = true;
+      } else {
+        // TODO: Do we need to thread more context in here?
+        KnownBits Known = computeKnownBits(MaskOp, *DL, 0, nullptr, II);
+        if (Known.countMinLeadingOnes() < 32)
+          DoNotConvert = true;
+        else
+          DoTruncate = true;
+      }
+    }
+    if (DoNotConvert) {
+      // Leave the ptrmask as-is and insert an addrspacecast after it.
+      std::optional<BasicBlock::iterator> InsertPoint =
+          II->getInsertionPointAfterDef();
+      assert(InsertPoint && "insertion after ptrmask should be possible");
+      Instruction *AddrSpaceCast =
+          new AddrSpaceCastInst(II, NewPtrType, "", *InsertPoint);
+      AddrSpaceCast->setDebugLoc(II->getDebugLoc());
+      return AddrSpaceCast;
+    }
+
+    IRBuilder<> B(II);
+    if (DoTruncate) {
+      MaskTy = B.getInt32Ty();
+      MaskOp = B.CreateTrunc(MaskOp, MaskTy);
+    }
     Value *NewPtr = operandWithNewAddressSpaceOrCreatePoison(
-        PtrArgUse, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS,
+        PtrOpUse, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS,
         PoisonUsesToFix);
-    Value *Rewrite =
-        TTI->rewriteIntrinsicWithAddressSpace(II, II->getArgOperand(0), NewPtr);
-    if (Rewrite) {
-      assert(Rewrite != II && "cannot modify this pointer operation in place");
-      return Rewrite;
-    }
-    // Leave the ptrmask as-is and insert an addrspacecast after it.
-    Instruction *AddrSpaceCast = new AddrSpaceCastInst(II, NewPtr->getType());
-    AddrSpaceCast->insertAfter(II->getIterator());
-    AddrSpaceCast->setDebugLoc(II->getDebugLoc());
-    // If we generated a poison operand for the ptr argument, remove it.
-    if (!PoisonUsesToFix->empty() && PoisonUsesToFix->back() == &PtrArgUse)
-      PoisonUsesToFix->pop_back();
-    return AddrSpaceCast;
+    return B.CreateIntrinsic(Intrinsic::ptrmask, {NewPtr->getType(), MaskTy},
+                             {NewPtr, MaskOp});
   }
 
   unsigned AS = TTI->getAssumedAddrSpace(I);
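
The truncation path carried over from the AMDGPU TTI hook is the one the
constant-mask tests exercise; for a 64-bit mask whose high half is all
ones (values illustrative):

  ; 0xffffffff00000001 has 32 known leading ones, so after the 64-bit
  ; flat pointer is shrunk back to a 32-bit local one, only the low 32
  ; mask bits matter
  %masked = call ptr @llvm.ptrmask.p0.i64(ptr %cast, i64 -4294967295)
  ; becomes
  %masked = call ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3) %src.ptr, i32 1)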

From 2de20906cc442f1c9036ab7a9d64a0bb6e925e00 Mon Sep 17 00:00:00 2001
From: Robert Imschweiler <robert.imschweiler at amd.com>
Date: Thu, 22 May 2025 15:51:50 -0500
Subject: [PATCH 05/12] move to helper function; check for integral addrspace;
 move addrspace check to TTI

---
 .../llvm/Analysis/TargetTransformInfo.h       |  5 +
 .../llvm/Analysis/TargetTransformInfoImpl.h   | 29 ++++++
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  5 +
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 97 ++++++++++---------
 4 files changed, 91 insertions(+), 45 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 8f4ce80ada5ed..f8e8b316702b0 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -495,6 +495,11 @@ class TargetTransformInfo {
 
   LLVM_ABI bool isNoopAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
 
+  std::optional<uint64_t> getAddrSpaceCastMaskWidth(unsigned FromAS,
+                                                    unsigned ToAS,
+                                                    Value *MaskOP,
+                                                    Instruction *I) const;
+
   /// Return true if globals in this address space can have initializers other
   /// than `undef`.
   LLVM_ABI bool
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index a80b4c5179bad..0f560b26498b5 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -16,6 +16,7 @@
 
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
@@ -151,6 +152,34 @@ class TargetTransformInfoImplBase {
   }
 
   virtual bool isNoopAddrSpaceCast(unsigned, unsigned) const { return false; }
+
+  // Assuming that the cast between the two given addrspaces is not a noop,
+  // calculate the width of the given mask value so that it can be applied to
+  // the destination addrspace. In case it cannot be applied since the cast
+  // between the two addrspaces is invalid or the mask value is larger than the
+  // resulting addrspace bit-width, return an empty optional.
+  //
+  // Note that this currently expects the addrspaces to be integral. In case one
+  // of them isn't, an empty optional is returned.
+  virtual std::optional<uint64_t>
+  getAddrSpaceCastMaskWidth(unsigned FromAS, unsigned ToAS, Value *MaskOp,
+                            Instruction *I) const {
+    if (DL.isNonIntegralAddressSpace(FromAS) ||
+        DL.isNonIntegralAddressSpace(ToAS))
+      return std::nullopt;
+    // All valid 64-bit to 32-bit casts work by chopping off the high
+    // bits. Any masking only clearing the low bits will also apply in the new
+    // address space.
+    if (DL.getPointerSizeInBits(FromAS) != 64 ||
+        DL.getPointerSizeInBits(ToAS) != 32)
+      return std::nullopt;
+    // TODO: Do we need to thread more context in here?
+    KnownBits Known = computeKnownBits(MaskOp, DL, 0, nullptr, I);
+    if (Known.countMinLeadingOnes() < 32)
+      return std::nullopt;
+    return 32;
+  }
+
   virtual bool
   canHaveNonUndefGlobalInitializerInAddressSpace(unsigned AS) const {
     return AS == 0;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 2d053e55bdfa9..c03c87b61853a 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -329,6 +329,11 @@ bool TargetTransformInfo::isNoopAddrSpaceCast(unsigned FromAS,
   return TTIImpl->isNoopAddrSpaceCast(FromAS, ToAS);
 }
 
+std::optional<uint64_t> TargetTransformInfo::getAddrSpaceCastMaskWidth(
+    unsigned FromAS, unsigned ToAS, Value *MaskOP, Instruction *I) const {
+  return TTIImpl->getAddrSpaceCastMaskWidth(FromAS, ToAS, MaskOP, I);
+}
+
 bool TargetTransformInfo::canHaveNonUndefGlobalInitializerInAddressSpace(
     unsigned AS) const {
   return TTIImpl->canHaveNonUndefGlobalInitializerInAddressSpace(AS);
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 69975924b0f9e..074ad66493ed5 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -206,6 +206,12 @@ class InferAddressSpacesImpl {
 
   bool isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const;
 
+  Value *clonePtrMaskWithNewAddressSpace(
+      IntrinsicInst *I, unsigned NewAddrSpace,
+      const ValueToValueMapTy &ValueWithNewAddrSpace,
+      const PredicatedAddrSpaceMapTy &PredicatedAS,
+      SmallVectorImpl<const Use *> *PoisonUsesToFix) const;
+
   Value *cloneInstructionWithNewAddressSpace(
       Instruction *I, unsigned NewAddrSpace,
       const ValueToValueMapTy &ValueWithNewAddrSpace,
@@ -649,6 +655,50 @@ static Value *operandWithNewAddressSpaceOrCreatePoison(
   return PoisonValue::get(NewPtrTy);
 }
 
+// A helper function for cloneInstructionWithNewAddressSpace. Handles the
+// conversion of a ptrmask intrinsic instruction.
+Value *InferAddressSpacesImpl::clonePtrMaskWithNewAddressSpace(
+    IntrinsicInst *I, unsigned NewAddrSpace,
+    const ValueToValueMapTy &ValueWithNewAddrSpace,
+    const PredicatedAddrSpaceMapTy &PredicatedAS,
+    SmallVectorImpl<const Use *> *PoisonUsesToFix) const {
+  const Use &PtrOpUse = I->getArgOperandUse(0);
+  unsigned OldAddrSpace = PtrOpUse.get()->getType()->getPointerAddressSpace();
+  Value *MaskOp = I->getArgOperand(1);
+  Type *MaskTy = MaskOp->getType();
+
+  std::optional<uint64_t> TruncateToWidth;
+
+  if (!TTI->isNoopAddrSpaceCast(OldAddrSpace, NewAddrSpace)) {
+    // Get the mask width that is applicable to the new addrspace.
+    TruncateToWidth =
+        TTI->getAddrSpaceCastMaskWidth(OldAddrSpace, NewAddrSpace, MaskOp, I);
+    // If there is no such mask, leave the ptrmask as-is and insert an
+    // addrspacecast after it.
+    if (!TruncateToWidth) {
+      std::optional<BasicBlock::iterator> InsertPoint =
+          I->getInsertionPointAfterDef();
+      assert(InsertPoint && "insertion after ptrmask should be possible");
+      Type *NewPtrType = getPtrOrVecOfPtrsWithNewAS(I->getType(), NewAddrSpace);
+      Instruction *AddrSpaceCast =
+          new AddrSpaceCastInst(I, NewPtrType, "", *InsertPoint);
+      AddrSpaceCast->setDebugLoc(I->getDebugLoc());
+      return AddrSpaceCast;
+    }
+  }
+
+  IRBuilder<> B(I);
+  if (TruncateToWidth) {
+    MaskTy = B.getIntNTy(*TruncateToWidth);
+    MaskOp = B.CreateTrunc(MaskOp, MaskTy);
+  }
+  Value *NewPtr = operandWithNewAddressSpaceOrCreatePoison(
+      PtrOpUse, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS,
+      PoisonUsesToFix);
+  return B.CreateIntrinsic(Intrinsic::ptrmask, {NewPtr->getType(), MaskTy},
+                           {NewPtr, MaskOp});
+}
+
 // Returns a clone of `I` with its operands converted to those specified in
 // ValueWithNewAddrSpace. Due to potential cycles in the data flow graph, an
 // operand whose address space needs to be modified might not exist in
@@ -678,51 +728,8 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
     // Technically the intrinsic ID is a pointer typed argument, so specially
     // handle calls early.
     assert(II->getIntrinsicID() == Intrinsic::ptrmask);
-    const Use &PtrOpUse = II->getArgOperandUse(0);
-    unsigned OldAddrSpace = PtrOpUse.get()->getType()->getPointerAddressSpace();
-    Value *MaskOp = II->getArgOperand(1);
-    Type *MaskTy = MaskOp->getType();
-
-    bool DoTruncate = false;
-    bool DoNotConvert = false;
-
-    if (!TTI->isNoopAddrSpaceCast(OldAddrSpace, NewAddrSpace)) {
-      // All valid 64-bit to 32-bit casts work by chopping off the high
-      // bits. Any masking only clearing the low bits will also apply in the new
-      // address space.
-      if (DL->getPointerSizeInBits(OldAddrSpace) != 64 ||
-          DL->getPointerSizeInBits(NewAddrSpace) != 32) {
-        DoNotConvert = true;
-      } else {
-        // TODO: Do we need to thread more context in here?
-        KnownBits Known = computeKnownBits(MaskOp, *DL, 0, nullptr, II);
-        if (Known.countMinLeadingOnes() < 32)
-          DoNotConvert = true;
-        else
-          DoTruncate = true;
-      }
-    }
-    if (DoNotConvert) {
-      // Leave the ptrmask as-is and insert an addrspacecast after it.
-      std::optional<BasicBlock::iterator> InsertPoint =
-          II->getInsertionPointAfterDef();
-      assert(InsertPoint && "insertion after ptrmask should be possible");
-      Instruction *AddrSpaceCast =
-          new AddrSpaceCastInst(II, NewPtrType, "", *InsertPoint);
-      AddrSpaceCast->setDebugLoc(II->getDebugLoc());
-      return AddrSpaceCast;
-    }
-
-    IRBuilder<> B(II);
-    if (DoTruncate) {
-      MaskTy = B.getInt32Ty();
-      MaskOp = B.CreateTrunc(MaskOp, MaskTy);
-    }
-    Value *NewPtr = operandWithNewAddressSpaceOrCreatePoison(
-        PtrOpUse, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS,
-        PoisonUsesToFix);
-    return B.CreateIntrinsic(Intrinsic::ptrmask, {NewPtr->getType(), MaskTy},
-                             {NewPtr, MaskOp});
+    return clonePtrMaskWithNewAddressSpace(
+        II, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS, PoisonUsesToFix);
   }
 
   unsigned AS = TTI->getAssumedAddrSpace(I);

From af027cbee6a0f13909fb75b84a713ae1286b1b20 Mon Sep 17 00:00:00 2001
From: Robert Imschweiler <robert.imschweiler at amd.com>
Date: Tue, 10 Jun 2025 06:43:50 -0500
Subject: [PATCH 06/12] fix after rebase

---
 llvm/include/llvm/Analysis/TargetTransformInfoImpl.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 0f560b26498b5..40f4b50e9790d 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -174,7 +174,7 @@ class TargetTransformInfoImplBase {
         DL.getPointerSizeInBits(ToAS) != 32)
       return std::nullopt;
     // TODO: Do we need to thread more context in here?
-    KnownBits Known = computeKnownBits(MaskOp, DL, 0, nullptr, I);
+    KnownBits Known = computeKnownBits(MaskOp, DL, nullptr, I);
     if (Known.countMinLeadingOnes() < 32)
       return std::nullopt;
     return 32;

From 1e71a8fd6808e88d394d00d8e92f3a088732e9a4 Mon Sep 17 00:00:00 2001
From: Robert Imschweiler <robert.imschweiler at amd.com>
Date: Mon, 21 Jul 2025 08:41:10 -0500
Subject: [PATCH 07/12] implement feedback

---
 .../llvm/Analysis/TargetTransformInfo.h       |  6 ++--
 .../llvm/Analysis/TargetTransformInfoImpl.h   | 36 ++++++++-----------
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  7 ++--
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 17 +++++----
 4 files changed, 31 insertions(+), 35 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index f8e8b316702b0..0b202272b1f4b 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -495,10 +495,8 @@ class TargetTransformInfo {
 
   LLVM_ABI bool isNoopAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
 
-  std::optional<uint64_t> getAddrSpaceCastMaskWidth(unsigned FromAS,
-                                                    unsigned ToAS,
-                                                    Value *MaskOP,
-                                                    Instruction *I) const;
+  LLVM_ABI KnownBits computeKnownBitsAddrSpaceCast(unsigned FromAS,
+                                                   unsigned ToAS) const;
 
   /// Return true if globals in this address space can have initializers other
   /// than `undef`.
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 40f4b50e9790d..9d7e94f7d138c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -154,30 +154,22 @@ class TargetTransformInfoImplBase {
   virtual bool isNoopAddrSpaceCast(unsigned, unsigned) const { return false; }
 
   // Assuming that the cast between the two given addrspaces is not a noop,
-  // calculate the width of the given mask value so that it can be applied to
-  // the destination addrspace. In case it cannot be applied since the cast
-  // between the two addrspaces is invalid or the mask value is larger than the
-  // resulting addrspace bit-width, return an empty optional.
-  //
-  // Note that this currently expects the addrspaces to be integral. In case one
-  // of them isn't, an empty optional is returned.
-  virtual std::optional<uint64_t>
-  getAddrSpaceCastMaskWidth(unsigned FromAS, unsigned ToAS, Value *MaskOp,
-                            Instruction *I) const {
+  // calculate the known bits of the resulting ptr in the destination addrspace.
+  // The default implementation returns 0 known bits in case either one of the
+  // addrspaces is not integral or the bit size of the source addrspace is
+  // smaller than the bit size of the destination addrspace.
+  virtual KnownBits computeKnownBitsAddrSpaceCast(unsigned FromAS,
+                                                  unsigned ToAS) const {
     if (DL.isNonIntegralAddressSpace(FromAS) ||
         DL.isNonIntegralAddressSpace(ToAS))
-      return std::nullopt;
-    // All valid 64-bit to 32-bit casts work by chopping off the high
-    // bits. Any masking only clearing the low bits will also apply in the new
-    // address space.
-    if (DL.getPointerSizeInBits(FromAS) != 64 ||
-        DL.getPointerSizeInBits(ToAS) != 32)
-      return std::nullopt;
-    // TODO: Do we need to thread more context in here?
-    KnownBits Known = computeKnownBits(MaskOp, DL, nullptr, I);
-    if (Known.countMinLeadingOnes() < 32)
-      return std::nullopt;
-    return 32;
+      return 0;
+    unsigned FromASBitSize = DL.getPointerSizeInBits(FromAS);
+    unsigned ToASBitSize = DL.getPointerSizeInBits(ToAS);
+    // Per default, we assume that all valid "larger" (e.g. 64-bit) to "smaller"
+    // (e.g. 32-bit) casts work by chopping off the high bits.
+    if (FromASBitSize < ToASBitSize)
+      return 0;
+    return ToASBitSize;
   }
 
   virtual bool
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index c03c87b61853a..e0d5c9a69c942 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -329,9 +329,10 @@ bool TargetTransformInfo::isNoopAddrSpaceCast(unsigned FromAS,
   return TTIImpl->isNoopAddrSpaceCast(FromAS, ToAS);
 }
 
-std::optional<uint64_t> TargetTransformInfo::getAddrSpaceCastMaskWidth(
-    unsigned FromAS, unsigned ToAS, Value *MaskOP, Instruction *I) const {
-  return TTIImpl->getAddrSpaceCastMaskWidth(FromAS, ToAS, MaskOP, I);
+KnownBits
+TargetTransformInfo::computeKnownBitsAddrSpaceCast(unsigned FromAS,
+                                                   unsigned ToAS) const {
+  return TTIImpl->computeKnownBitsAddrSpaceCast(FromAS, ToAS);
 }
 
 bool TargetTransformInfo::canHaveNonUndefGlobalInitializerInAddressSpace(
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 074ad66493ed5..5525e64560d9b 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -667,15 +667,20 @@ Value *InferAddressSpacesImpl::clonePtrMaskWithNewAddressSpace(
   Value *MaskOp = I->getArgOperand(1);
   Type *MaskTy = MaskOp->getType();
 
-  std::optional<uint64_t> TruncateToWidth;
+  std::optional<KnownBits> KnownPtrBits;
 
   if (!TTI->isNoopAddrSpaceCast(OldAddrSpace, NewAddrSpace)) {
+    KnownPtrBits =
+        TTI->computeKnownBitsAddrSpaceCast(OldAddrSpace, NewAddrSpace);
     // Get the mask width that is applicable to the new addrspace.
-    TruncateToWidth =
-        TTI->getAddrSpaceCastMaskWidth(OldAddrSpace, NewAddrSpace, MaskOp, I);
     // If there is no such mask, leave the ptrmask as-is and insert an
     // addrspacecast after it.
-    if (!TruncateToWidth) {
+    KnownBits KnownMaskBits = computeKnownBits(MaskOp, *DL, nullptr, I);
+    // Any masking only clearing the low bits will also apply in the new address
+    // space. (To check this: compute the number of mask bits that might be zero
+    // and compare it with the number of ptr bits that might be one.)
+    if (KnownMaskBits.getBitWidth() - KnownMaskBits.countMinLeadingOnes() >
+        KnownPtrBits->countMaxActiveBits()) {
       std::optional<BasicBlock::iterator> InsertPoint =
           I->getInsertionPointAfterDef();
       assert(InsertPoint && "insertion after ptrmask should be possible");
@@ -688,8 +693,8 @@ Value *InferAddressSpacesImpl::clonePtrMaskWithNewAddressSpace(
   }
 
   IRBuilder<> B(I);
-  if (TruncateToWidth) {
-    MaskTy = B.getIntNTy(*TruncateToWidth);
+  if (KnownPtrBits) {
+    MaskTy = B.getIntNTy(KnownPtrBits->getBitWidth());
     MaskOp = B.CreateTrunc(MaskOp, MaskTy);
   }
   Value *NewPtr = operandWithNewAddressSpaceOrCreatePoison(
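
This keeps the earlier behaviour for the "clears a single high bit" tests:
0xfffffffeffffffff has only 31 known leading ones, so 64 - 31 = 33 mask bits
might be zero, which exceeds the 32 bits a local pointer can carry, and the
rewrite is blocked (sketch based on the existing test):

  ; ptrmask kept as-is; users are reached via an addrspacecast back to
  ; addrspace(3)
  %masked = call ptr @llvm.ptrmask.p0.i64(ptr %cast, i64 -4294967297)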

From 71aa7e111b7f02b90914610431f48ee2659bcab0 Mon Sep 17 00:00:00 2001
From: Robert Imschweiler <robert.imschweiler at amd.com>
Date: Thu, 4 Sep 2025 17:24:46 -0500
Subject: [PATCH 08/12] implement feedback

---
 .../llvm/Analysis/TargetTransformInfo.h       |  9 ++-
 .../llvm/Analysis/TargetTransformInfoImpl.h   | 70 +++++++++++++++----
 llvm/lib/Analysis/TargetTransformInfo.cpp     | 12 +++-
 .../Transforms/Scalar/InferAddressSpaces.cpp  | 46 +++++++-----
 .../InferAddressSpaces/AMDGPU/ptrmask.ll      | 52 ++++++++++++++
 5 files changed, 153 insertions(+), 36 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 0b202272b1f4b..6e11828ed27a9 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -495,8 +495,13 @@ class TargetTransformInfo {
 
   LLVM_ABI bool isNoopAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
 
-  LLVM_ABI KnownBits computeKnownBitsAddrSpaceCast(unsigned FromAS,
-                                                   unsigned ToAS) const;
+  LLVM_ABI std::optional<std::pair<KnownBits, KnownBits>>
+  computeKnownBitsAddrSpaceCast(unsigned FromAS, unsigned ToAS,
+                                const Value &PtrOp) const;
+
+  LLVM_ABI std::optional<KnownBits>
+  computeKnownBitsAddrSpaceCast(unsigned FromAS, unsigned ToAS,
+                                const KnownBits &FromPtrBits) const;
 
   /// Return true if globals in this address space can have initializers other
   /// than `undef`.
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 9d7e94f7d138c..865f22fe1f9d2 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -153,23 +153,63 @@ class TargetTransformInfoImplBase {
 
   virtual bool isNoopAddrSpaceCast(unsigned, unsigned) const { return false; }
 
-  // Assuming that the cast between the two given addrspaces is not a noop,
-  // calculate the known bits of the resulting ptr in the destination addrspace.
-  // The default implementation returns 0 known bits in case either one of the
-  // addrspaces is not integral or the bit size of the source addrspace is
-  // smaller than the bit size of the destination addrspace.
-  virtual KnownBits computeKnownBitsAddrSpaceCast(unsigned FromAS,
-                                                  unsigned ToAS) const {
-    if (DL.isNonIntegralAddressSpace(FromAS) ||
-        DL.isNonIntegralAddressSpace(ToAS))
-      return 0;
-    unsigned FromASBitSize = DL.getPointerSizeInBits(FromAS);
+  // Given an address space cast of the given pointer value, calculate the known
+  // bits of the source pointer in the source addrspace and the destination
+  // pointer in the destination addrspace.
+  // The default implementation returns an empty optional in case one of the
+  // addrspaces is not integral.
+  virtual std::optional<std::pair<KnownBits, KnownBits>>
+  computeKnownBitsAddrSpaceCast(unsigned FromAS, unsigned ToAS,
+                                const Value &PtrOp) const {
+    if (DL.isNonIntegralAddressSpace(FromAS))
+      return std::nullopt;
+
+    KnownBits FromPtrBits;
+    if (const AddrSpaceCastInst *CastI = dyn_cast<AddrSpaceCastInst>(&PtrOp)) {
+      std::optional<std::pair<KnownBits, KnownBits>> KB =
+          computeKnownBitsAddrSpaceCast(CastI->getSrcAddressSpace(),
+                                        CastI->getDestAddressSpace(),
+                                        *CastI->getPointerOperand());
+      if (!KB)
+        return std::nullopt;
+      FromPtrBits = KB->second;
+    } else if (isa<ConstantPointerNull>(PtrOp) && !FromAS) {
+      // For addrspace 0, we know that a null pointer has the value 0.
+      FromPtrBits = KnownBits::makeConstant(
+          APInt::getZero(DL.getPointerSizeInBits(FromAS)));
+    } else {
+      FromPtrBits = computeKnownBits(&PtrOp, DL, nullptr);
+    }
+
+    std::optional<KnownBits> ToPtrBits =
+        computeKnownBitsAddrSpaceCast(FromAS, ToAS, FromPtrBits);
+    if (!ToPtrBits)
+      return std::nullopt;
+
+    return std::pair(FromPtrBits, *ToPtrBits);
+  }
+
+  // Given an address space cast, calculate the known bits of the resulting ptr
+  // in the destination addrspace using the known bits of the source pointer in
+  // the source addrspace.
+  // The default implementation returns an empty optional in case the source
+  // addrspace is not an integral addrspace.
+  virtual std::optional<KnownBits>
+  computeKnownBitsAddrSpaceCast(unsigned FromAS, unsigned ToAS,
+                                const KnownBits &FromPtrBits) const {
+    if (DL.isNonIntegralAddressSpace(ToAS))
+      return std::nullopt;
+
     unsigned ToASBitSize = DL.getPointerSizeInBits(ToAS);
-    // Per default, we assume that all valid "larger" (e.g. 64-bit) to "smaller"
+    // By default, we assume that all valid "larger" (e.g. 64-bit) to "smaller"
     // (e.g. 32-bit) casts work by chopping off the high bits.
-    if (FromASBitSize < ToASBitSize)
-      return 0;
-    return ToASBitSize;
+    if (FromPtrBits.getBitWidth() >= ToASBitSize)
+      return FromPtrBits.trunc(ToASBitSize);
+    // By default, we do not assume that null results in null again, except for
+    // addrspace 0.
+    if (!FromAS && FromPtrBits.isZero())
+      return FromPtrBits.zext(ToASBitSize);
+    return FromPtrBits.anyext(ToASBitSize);
   }
 
   virtual bool
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index e0d5c9a69c942..88b7062e2b009 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -329,10 +329,16 @@ bool TargetTransformInfo::isNoopAddrSpaceCast(unsigned FromAS,
   return TTIImpl->isNoopAddrSpaceCast(FromAS, ToAS);
 }
 
-KnownBits
+std::optional<std::pair<KnownBits, KnownBits>>
 TargetTransformInfo::computeKnownBitsAddrSpaceCast(unsigned FromAS,
-                                                   unsigned ToAS) const {
-  return TTIImpl->computeKnownBitsAddrSpaceCast(FromAS, ToAS);
+                                                   unsigned ToAS,
+                                                   const Value &PtrOp) const {
+  return TTIImpl->computeKnownBitsAddrSpaceCast(FromAS, ToAS, PtrOp);
+}
+
+std::optional<KnownBits> TargetTransformInfo::computeKnownBitsAddrSpaceCast(
+    unsigned FromAS, unsigned ToAS, const KnownBits &FromPtrBits) const {
+  return TTIImpl->computeKnownBitsAddrSpaceCast(FromAS, ToAS, FromPtrBits);
 }
 
 bool TargetTransformInfo::canHaveNonUndefGlobalInitializerInAddressSpace(
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 5525e64560d9b..4875e444e22ab 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -663,24 +663,38 @@ Value *InferAddressSpacesImpl::clonePtrMaskWithNewAddressSpace(
     const PredicatedAddrSpaceMapTy &PredicatedAS,
     SmallVectorImpl<const Use *> *PoisonUsesToFix) const {
   const Use &PtrOpUse = I->getArgOperandUse(0);
-  unsigned OldAddrSpace = PtrOpUse.get()->getType()->getPointerAddressSpace();
+  unsigned OldAddrSpace = PtrOpUse->getType()->getPointerAddressSpace();
   Value *MaskOp = I->getArgOperand(1);
   Type *MaskTy = MaskOp->getType();
 
-  std::optional<KnownBits> KnownPtrBits;
-
+  std::optional<KnownBits> OldPtrBits;
+  std::optional<KnownBits> NewPtrBits;
   if (!TTI->isNoopAddrSpaceCast(OldAddrSpace, NewAddrSpace)) {
-    KnownPtrBits =
-        TTI->computeKnownBitsAddrSpaceCast(OldAddrSpace, NewAddrSpace);
-    // Get the mask width that is applicable to the new addrspace.
-    // If there is no such mask, leave the ptrmask as-is and insert an
-    // addrspacecast after it.
-    KnownBits KnownMaskBits = computeKnownBits(MaskOp, *DL, nullptr, I);
-    // Any masking only clearing the low bits will also apply in the new address
-    // space. (To check this: compute the number of mask bits that might be zero
-    // and compare it with the number of ptr bits that might be one.)
-    if (KnownMaskBits.getBitWidth() - KnownMaskBits.countMinLeadingOnes() >
-        KnownPtrBits->countMaxActiveBits()) {
+    if (std::optional<std::pair<KnownBits, KnownBits>> KB =
+            TTI->computeKnownBitsAddrSpaceCast(OldAddrSpace, NewAddrSpace,
+                                               *PtrOpUse.get())) {
+      OldPtrBits = KB->first;
+      NewPtrBits = KB->second;
+    }
+  }
+
+  // If the pointers in both addrspaces have a bitwise representation and if the
+  // representation of the new pointer is smaller (fewer bits) than the old one,
+  // check if the mask is applicable to the ptr in the new addrspace. Any
+  // masking only clearing the low bits will also apply in the new addrspace
+  // Note: checking if the mask clears high bits is not sufficient as those
+  // might have already been 0 in the old ptr.
+  if (NewPtrBits && OldPtrBits->getBitWidth() > NewPtrBits->getBitWidth()) {
+    KnownBits MaskBits = computeKnownBits(MaskOp, *DL, nullptr, I);
+    // Set all unknown bits of the old ptr to 1, so that we are conservative in
+    // checking which bits are cleared by the mask.
+    OldPtrBits->One |= ~OldPtrBits->Zero;
+    // Check which bits are cleared by the mask in the old ptr.
+    KnownBits ClearedBits = KnownBits::sub(*OldPtrBits, *OldPtrBits & MaskBits);
+
+    // If the mask isn't applicable to the new ptr, leave the ptrmask as-is and
+    // insert an addrspacecast after it.
+    if (ClearedBits.countMaxActiveBits() > NewPtrBits->countMaxActiveBits()) {
       std::optional<BasicBlock::iterator> InsertPoint =
           I->getInsertionPointAfterDef();
       assert(InsertPoint && "insertion after ptrmask should be possible");
@@ -693,8 +707,8 @@ Value *InferAddressSpacesImpl::clonePtrMaskWithNewAddressSpace(
   }
 
   IRBuilder<> B(I);
-  if (KnownPtrBits) {
-    MaskTy = B.getIntNTy(KnownPtrBits->getBitWidth());
+  if (NewPtrBits) {
+    MaskTy = B.getIntNTy(NewPtrBits->getBitWidth());
     MaskOp = B.CreateTrunc(MaskOp, MaskTy);
   }
   Value *NewPtr = operandWithNewAddressSpaceOrCreatePoison(
diff --git a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
index bc89d49b3cd2b..27cd9ad97cd5f 100644
--- a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
+++ b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
@@ -15,6 +15,33 @@ define i8 @ptrmask_cast_local_to_flat(ptr addrspace(3) %src.ptr, i64 %mask) {
   ret i8 %load
 }
 
+; Casting null does not necessarily result in null again ...
+define i8 @ptrmask_cast_local_null_to_flat(i64 %mask) {
+; CHECK-LABEL: @ptrmask_cast_local_null_to_flat(
+; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr addrspacecast (ptr addrspace(3) null to ptr), i64 [[MASK:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[MASKED]] to ptr addrspace(3)
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[TMP1]], align 1
+; CHECK-NEXT:    ret i8 [[LOAD]]
+;
+  %cast = addrspacecast ptr addrspace(3) null to ptr
+  %masked = call ptr @llvm.ptrmask.p0.i64(ptr %cast, i64 %mask)
+  %load = load i8, ptr %masked
+  ret i8 %load
+}
+
+; ... exception: addrspace 0
+define i8 @ptrmask_cast_flat_null_to_local(i32 %mask) {
+; CHECK-LABEL: @ptrmask_cast_flat_null_to_local(
+; CHECK-NEXT:    [[MASKED:%.*]] = call ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3) addrspacecast (ptr null to ptr addrspace(3)), i32 [[MASK:%.*]])
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[MASKED]], align 1
+; CHECK-NEXT:    ret i8 [[LOAD]]
+;
+  %cast = addrspacecast ptr null to ptr addrspace(3)
+  %masked = call ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3) %cast, i32 %mask)
+  %load = load i8, ptr addrspace(3) %masked
+  ret i8 %load
+}
+
 define i8 @ptrmask_cast_private_to_flat(ptr addrspace(5) %src.ptr, i64 %mask) {
 ; CHECK-LABEL: @ptrmask_cast_private_to_flat(
 ; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(5) [[SRC_PTR1:%.*]] to ptr
@@ -215,6 +242,20 @@ define i8 @ptrmask_cast_local_to_flat_const_mask_7fffffffffffffff(ptr addrspace(
   ret i8 %load
 }
 
+; Do not fold: casting null does not necessarily result in null again
+define i8 @ptrmask_cast_local_null_to_flat_const_mask_7fffffffffffffff() {
+; CHECK-LABEL: @ptrmask_cast_local_null_to_flat_const_mask_7fffffffffffffff(
+; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr addrspacecast (ptr addrspace(3) null to ptr), i64 9223372036854775807)
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[MASKED]] to ptr addrspace(3)
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[TMP1]], align 1
+; CHECK-NEXT:    ret i8 [[LOAD]]
+;
+  %cast = addrspacecast ptr addrspace(3) null to ptr
+  %masked = call ptr @llvm.ptrmask.p0.i64(ptr %cast, i64 9223372036854775807)
+  %load = load i8, ptr %masked
+  ret i8 %load
+}
+
 define i8 @ptrmask_cast_local_to_flat_const_mask_ffffffff00000000(ptr addrspace(3) %src.ptr) {
 ; CHECK-LABEL: @ptrmask_cast_local_to_flat_const_mask_ffffffff00000000(
 ; CHECK-NEXT:    [[TMP1:%.*]] = call ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3) [[SRC_PTR:%.*]], i32 0)
@@ -227,6 +268,17 @@ define i8 @ptrmask_cast_local_to_flat_const_mask_ffffffff00000000(ptr addrspace(
   ret i8 %load
 }
 
+define i8 @ptrmask_cast_local_null_to_flat_const_mask_ffffffff00000000() {
+; CHECK-LABEL: @ptrmask_cast_local_null_to_flat_const_mask_ffffffff00000000(
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) null, align 1
+; CHECK-NEXT:    ret i8 [[LOAD]]
+;
+  %cast = addrspacecast ptr addrspace(3) null to ptr
+  %masked = call ptr @llvm.ptrmask.p0.i64(ptr %cast, i64 -4294967296)
+  %load = load i8, ptr %masked
+  ret i8 %load
+}
+
 define i8 @ptrmask_cast_local_to_flat_const_mask_ffffffff80000000(ptr addrspace(3) %src.ptr) {
 ; CHECK-LABEL: @ptrmask_cast_local_to_flat_const_mask_ffffffff80000000(
 ; CHECK-NEXT:    [[TMP1:%.*]] = call ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3) [[SRC_PTR:%.*]], i32 -2147483648)
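
Condensing the null-pointer tests above (illustrative): local null cast to
flat is not known to be null again, so an unknown mask still blocks the
rewrite, while a constant mask with 32 known leading ones is truncated and
applied in addrspace(3), where it then folds away:

  %cast = addrspacecast ptr addrspace(3) null to ptr

  ; unknown mask: ptrmask kept flat, users rewritten via a cast back
  %m0 = call ptr @llvm.ptrmask.p0.i64(ptr %cast, i64 %mask)

  ; mask 0xffffffff00000000: truncated to i32 0 and applied to the local
  ; null pointer, so the load folds to ptr addrspace(3) null
  %m1 = call ptr @llvm.ptrmask.p0.i64(ptr %cast, i64 -4294967296)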

From b1a66bf32a8f1536485c483c9f20b124ad23dd64 Mon Sep 17 00:00:00 2001
From: Robert Imschweiler <robert.imschweiler at amd.com>
Date: Fri, 5 Sep 2025 02:36:21 -0500
Subject: [PATCH 09/12] implement feedback

---
 .../llvm/Analysis/TargetTransformInfo.h       |  7 ++-
 .../llvm/Analysis/TargetTransformInfoImpl.h   | 46 +++++++------------
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  4 +-
 .../InferAddressSpaces/AMDGPU/ptrmask.ll      |  3 +-
 4 files changed, 24 insertions(+), 36 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 6e11828ed27a9..fcf3563a898a5 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -495,13 +495,12 @@ class TargetTransformInfo {
 
   LLVM_ABI bool isNoopAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
 
-  LLVM_ABI std::optional<std::pair<KnownBits, KnownBits>>
+  LLVM_ABI std::pair<KnownBits, KnownBits>
   computeKnownBitsAddrSpaceCast(unsigned FromAS, unsigned ToAS,
                                 const Value &PtrOp) const;
 
-  LLVM_ABI std::optional<KnownBits>
-  computeKnownBitsAddrSpaceCast(unsigned FromAS, unsigned ToAS,
-                                const KnownBits &FromPtrBits) const;
+  LLVM_ABI KnownBits computeKnownBitsAddrSpaceCast(
+      unsigned FromAS, unsigned ToAS, const KnownBits &FromPtrBits) const;
 
   /// Return true if globals in this address space can have initializers other
   /// than `undef`.
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 865f22fe1f9d2..74d1db6bdfccf 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -156,24 +156,20 @@ class TargetTransformInfoImplBase {
   // Given an address space cast of the given pointer value, calculate the known
   // bits of the source pointer in the source addrspace and the destination
   // pointer in the destination addrspace.
-  // The default implementation returns an empty optional in case one of the
-  // addrspaces is not integral.
-  virtual std::optional<std::pair<KnownBits, KnownBits>>
+  virtual std::pair<KnownBits, KnownBits>
   computeKnownBitsAddrSpaceCast(unsigned FromAS, unsigned ToAS,
                                 const Value &PtrOp) const {
     if (DL.isNonIntegralAddressSpace(FromAS))
-      return std::nullopt;
+      return std::pair(KnownBits(DL.getPointerSizeInBits(FromAS)),
+                       KnownBits(DL.getPointerSizeInBits(ToAS)));
 
     KnownBits FromPtrBits;
     if (const AddrSpaceCastInst *CastI = dyn_cast<AddrSpaceCastInst>(&PtrOp)) {
-      std::optional<std::pair<KnownBits, KnownBits>> KB =
-          computeKnownBitsAddrSpaceCast(CastI->getSrcAddressSpace(),
-                                        CastI->getDestAddressSpace(),
-                                        *CastI->getPointerOperand());
-      if (!KB)
-        return std::nullopt;
-      FromPtrBits = KB->second;
-    } else if (isa<ConstantPointerNull>(PtrOp) && !FromAS) {
+      std::pair<KnownBits, KnownBits> KB = computeKnownBitsAddrSpaceCast(
+          CastI->getSrcAddressSpace(), CastI->getDestAddressSpace(),
+          *CastI->getPointerOperand());
+      FromPtrBits = KB.second;
+    } else if (isa<ConstantPointerNull>(PtrOp) && FromAS == 0) {
       // For addrspace 0, we know that a null pointer has the value 0.
       FromPtrBits = KnownBits::makeConstant(
           APInt::getZero(DL.getPointerSizeInBits(FromAS)));
@@ -181,35 +177,27 @@ class TargetTransformInfoImplBase {
       FromPtrBits = computeKnownBits(&PtrOp, DL, nullptr);
     }
 
-    std::optional<KnownBits> ToPtrBits =
+    KnownBits ToPtrBits =
         computeKnownBitsAddrSpaceCast(FromAS, ToAS, FromPtrBits);
-    if (!ToPtrBits)
-      return std::nullopt;
 
-    return std::pair(FromPtrBits, *ToPtrBits);
+    return std::pair(FromPtrBits, ToPtrBits);
   }
 
   // Given an address space cast, calculate the known bits of the resulting ptr
   // in the destination addrspace using the known bits of the source pointer in
   // the source addrspace.
-  // The default implementation returns an empty optional in case the source
-  // addrspace is not an integral addrspace.
-  virtual std::optional<KnownBits>
+  virtual KnownBits
   computeKnownBitsAddrSpaceCast(unsigned FromAS, unsigned ToAS,
                                 const KnownBits &FromPtrBits) const {
-    if (DL.isNonIntegralAddressSpace(ToAS))
-      return std::nullopt;
-
     unsigned ToASBitSize = DL.getPointerSizeInBits(ToAS);
+
+    if (DL.isNonIntegralAddressSpace(FromAS))
+      return KnownBits(ToASBitSize);
+
     // By default, we assume that all valid "larger" (e.g. 64-bit) to "smaller"
     // (e.g. 32-bit) casts work by chopping off the high bits.
-    if (FromPtrBits.getBitWidth() >= ToASBitSize)
-      return FromPtrBits.trunc(ToASBitSize);
-    // By default, we do not assume that null results in null again, except for
-    // addrspace 0.
-    if (!FromAS && FromPtrBits.isZero())
-      return FromPtrBits.zext(ToASBitSize);
-    return FromPtrBits.anyext(ToASBitSize);
+    // By default, we do not assume that null results in null again.
+    return FromPtrBits.anyextOrTrunc(ToASBitSize);
   }
 
   virtual bool
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 88b7062e2b009..91db1ff824ef6 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -329,14 +329,14 @@ bool TargetTransformInfo::isNoopAddrSpaceCast(unsigned FromAS,
   return TTIImpl->isNoopAddrSpaceCast(FromAS, ToAS);
 }
 
-std::optional<std::pair<KnownBits, KnownBits>>
+std::pair<KnownBits, KnownBits>
 TargetTransformInfo::computeKnownBitsAddrSpaceCast(unsigned FromAS,
                                                    unsigned ToAS,
                                                    const Value &PtrOp) const {
   return TTIImpl->computeKnownBitsAddrSpaceCast(FromAS, ToAS, PtrOp);
 }
 
-std::optional<KnownBits> TargetTransformInfo::computeKnownBitsAddrSpaceCast(
+KnownBits TargetTransformInfo::computeKnownBitsAddrSpaceCast(
     unsigned FromAS, unsigned ToAS, const KnownBits &FromPtrBits) const {
   return TTIImpl->computeKnownBitsAddrSpaceCast(FromAS, ToAS, FromPtrBits);
 }
diff --git a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
index 27cd9ad97cd5f..5e0716c1e308f 100644
--- a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
+++ b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
@@ -29,7 +29,8 @@ define i8 @ptrmask_cast_local_null_to_flat(i64 %mask) {
   ret i8 %load
 }
 
-; ... exception: addrspace 0
+; ... exception: addrspace 0 cast to a smaller addrspace (by default we assume
+; that casting to a smaller addrspace truncates)
 define i8 @ptrmask_cast_flat_null_to_local(i32 %mask) {
 ; CHECK-LABEL: @ptrmask_cast_flat_null_to_local(
 ; CHECK-NEXT:    [[MASKED:%.*]] = call ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3) addrspacecast (ptr null to ptr addrspace(3)), i32 [[MASK:%.*]])

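To make the collapsed logic concrete, a standalone C++ sketch of the default behavior after this change (illustrative only; the helper name and the NonIntegral flag are assumptions, not the TTI interface):

    // Illustrative-only mirror of the default implementation above.
    #include "llvm/Support/KnownBits.h"

    using namespace llvm;

    static KnownBits castKnownBits(const KnownBits &FromPtrBits,
                                   unsigned ToASBitSize, bool NonIntegral) {
      // Non-integral source addrspace: nothing is known about the result.
      if (NonIntegral)
        return KnownBits(ToASBitSize);
      // Narrowing truncates (drops high bits); widening any-extends, i.e.
      // the new high bits become unknown, so null need not stay null.
      return FromPtrBits.anyextOrTrunc(ToASBitSize);
    }
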
>From ba0e502fbc6a3af26739efc82c68d2b3ada98c53 Mon Sep 17 00:00:00 2001
From: Robert Imschweiler <robert.imschweiler at amd.com>
Date: Fri, 5 Sep 2025 21:02:00 +0200
Subject: [PATCH 10/12] Update
 llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp

Co-authored-by: Shilei Tian <i at tianshilei.me>
---
 llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 4875e444e22ab..707cf4ba51f0e 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -685,7 +685,7 @@ Value *InferAddressSpacesImpl::clonePtrMaskWithNewAddressSpace(
   // Note: checking if the mask clears high bits is not sufficient as those
   // might have already been 0 in the old ptr.
   if (NewPtrBits && OldPtrBits->getBitWidth() > NewPtrBits->getBitWidth()) {
-    KnownBits MaskBits = computeKnownBits(MaskOp, *DL, nullptr, I);
+    KnownBits MaskBits = computeKnownBits(MaskOp, *DL, /*AssumptionCache=*/nullptr, I);
     // Set all unknown bits of the old ptr to 1, so that we are conservative in
     // checking which bits are cleared by the mask.
     OldPtrBits->One |= ~OldPtrBits->Zero;

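The hunk above sits inside a conservative bit-level validity check; the general shape of such a check can be sketched as follows (standalone C++ under assumptions: the predicate name and the exact condition are illustrative, not the pass's code):

    // Hypothetical sketch: treat unknown old-pointer bits as possibly 1,
    // then require the mask to keep every high bit that may be set.
    #include "llvm/ADT/APInt.h"
    #include "llvm/Support/KnownBits.h"

    using namespace llvm;

    static bool maskKeepsPossiblyOneHighBits(KnownBits OldPtrBits,
                                             const KnownBits &MaskBits,
                                             unsigned NewBitWidth) {
      unsigned OldBW = OldPtrBits.getBitWidth();
      // Conservative: any unknown old-pointer bit might be 1.
      OldPtrBits.One |= ~OldPtrBits.Zero;
      APInt High = APInt::getHighBitsSet(OldBW, OldBW - NewBitWidth);
      // Every possibly-one high bit must be known to survive the mask.
      return (OldPtrBits.One & High).isSubsetOf(MaskBits.One);
    }
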
>From 8f14063731b38f82724a4c5d2b79b3fe65a767ec Mon Sep 17 00:00:00 2001
From: Robert Imschweiler <robert.imschweiler at amd.com>
Date: Fri, 5 Sep 2025 14:07:16 -0500
Subject: [PATCH 11/12] fix formatting

---
 llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 707cf4ba51f0e..b2a306a0bfe04 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -685,7 +685,8 @@ Value *InferAddressSpacesImpl::clonePtrMaskWithNewAddressSpace(
   // Note: checking if the mask clears high bits is not sufficient as those
   // might have already been 0 in the old ptr.
   if (NewPtrBits && OldPtrBits->getBitWidth() > NewPtrBits->getBitWidth()) {
-    KnownBits MaskBits = computeKnownBits(MaskOp, *DL, /*AssumptionCache=*/nullptr, I);
+    KnownBits MaskBits =
+        computeKnownBits(MaskOp, *DL, /*AssumptionCache=*/nullptr, I);
     // Set all unknown bits of the old ptr to 1, so that we are conservative in
     // checking which bits are cleared by the mask.
     OldPtrBits->One |= ~OldPtrBits->Zero;

>From 8ab73686a371dde99125b5c3cdb22aac7777ebb7 Mon Sep 17 00:00:00 2001
From: Robert Imschweiler <robert.imschweiler at amd.com>
Date: Fri, 26 Sep 2025 06:46:21 -0500
Subject: [PATCH 12/12] implement feedback

---
 .../llvm/Analysis/TargetTransformInfo.h       |   9 +-
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  20 ++-
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   5 +-
 .../Transforms/Scalar/InferAddressSpaces.cpp  |   9 +-
 .../InferAddressSpaces/AMDGPU/ptrmask.ll      | 141 +++++++++++++++---
 5 files changed, 142 insertions(+), 42 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index fcf3563a898a5..bfec188306180 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -495,10 +495,15 @@ class TargetTransformInfo {
 
   LLVM_ABI bool isNoopAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
 
+  // Given an address space cast of the given pointer value, calculate the known
+  // bits of the source pointer in the source addrspace and the destination
+  // pointer in the destination addrspace.
   LLVM_ABI std::pair<KnownBits, KnownBits>
-  computeKnownBitsAddrSpaceCast(unsigned FromAS, unsigned ToAS,
-                                const Value &PtrOp) const;
+  computeKnownBitsAddrSpaceCast(unsigned ToAS, const Value &PtrOp) const;
 
+  // Given an address space cast, calculate the known bits of the resulting ptr
+  // in the destination addrspace using the known bits of the source pointer in
+  // the source addrspace.
   LLVM_ABI KnownBits computeKnownBitsAddrSpaceCast(
       unsigned FromAS, unsigned ToAS, const KnownBits &FromPtrBits) const;
 
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 74d1db6bdfccf..f6c8e5755cfdb 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -153,12 +153,13 @@ class TargetTransformInfoImplBase {
 
   virtual bool isNoopAddrSpaceCast(unsigned, unsigned) const { return false; }
 
-  // Given an address space cast of the given pointer value, calculate the known
-  // bits of the source pointer in the source addrspace and the destination
-  // pointer in the destination addrspace.
   virtual std::pair<KnownBits, KnownBits>
-  computeKnownBitsAddrSpaceCast(unsigned FromAS, unsigned ToAS,
-                                const Value &PtrOp) const {
+  computeKnownBitsAddrSpaceCast(unsigned ToAS, const Value &PtrOp) const {
+    const Type *PtrTy = PtrOp.getType();
+    assert(PtrTy->isPtrOrPtrVectorTy() &&
+           "expected pointer or pointer vector type");
+    unsigned FromAS = PtrTy->getPointerAddressSpace();
+
     if (DL.isNonIntegralAddressSpace(FromAS))
       return std::pair(KnownBits(DL.getPointerSizeInBits(FromAS)),
                        KnownBits(DL.getPointerSizeInBits(ToAS)));
@@ -166,10 +167,10 @@ class TargetTransformInfoImplBase {
     KnownBits FromPtrBits;
     if (const AddrSpaceCastInst *CastI = dyn_cast<AddrSpaceCastInst>(&PtrOp)) {
       std::pair<KnownBits, KnownBits> KB = computeKnownBitsAddrSpaceCast(
-          CastI->getSrcAddressSpace(), CastI->getDestAddressSpace(),
-          *CastI->getPointerOperand());
+          CastI->getDestAddressSpace(), *CastI->getPointerOperand());
       FromPtrBits = KB.second;
-    } else if (isa<ConstantPointerNull>(PtrOp) && FromAS == 0) {
+    } else if (const Constant *PtrC = dyn_cast<Constant>(&PtrOp);
+               FromAS == 0 && PtrC && PtrC->isNullValue()) {
       // For addrspace 0, we know that a null pointer has the value 0.
       FromPtrBits = KnownBits::makeConstant(
           APInt::getZero(DL.getPointerSizeInBits(FromAS)));
@@ -183,9 +184,6 @@ class TargetTransformInfoImplBase {
     return std::pair(FromPtrBits, ToPtrBits);
   }
 
-  // Given an address space cast, calculate the known bits of the resulting ptr
-  // in the destination addrspace using the known bits of the source pointer in
-  // the source addrspace.
   virtual KnownBits
   computeKnownBitsAddrSpaceCast(unsigned FromAS, unsigned ToAS,
                                 const KnownBits &FromPtrBits) const {
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 91db1ff824ef6..4ad55d5cd4c44 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -330,10 +330,9 @@ bool TargetTransformInfo::isNoopAddrSpaceCast(unsigned FromAS,
 }
 
 std::pair<KnownBits, KnownBits>
-TargetTransformInfo::computeKnownBitsAddrSpaceCast(unsigned FromAS,
-                                                   unsigned ToAS,
+TargetTransformInfo::computeKnownBitsAddrSpaceCast(unsigned ToAS,
                                                    const Value &PtrOp) const {
-  return TTIImpl->computeKnownBitsAddrSpaceCast(FromAS, ToAS, PtrOp);
+  return TTIImpl->computeKnownBitsAddrSpaceCast(ToAS, PtrOp);
 }
 
 KnownBits TargetTransformInfo::computeKnownBitsAddrSpaceCast(
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index b2a306a0bfe04..5c0e1ecb274a7 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -671,8 +671,7 @@ Value *InferAddressSpacesImpl::clonePtrMaskWithNewAddressSpace(
   std::optional<KnownBits> NewPtrBits;
   if (!TTI->isNoopAddrSpaceCast(OldAddrSpace, NewAddrSpace)) {
     if (std::optional<std::pair<KnownBits, KnownBits>> KB =
-            TTI->computeKnownBitsAddrSpaceCast(OldAddrSpace, NewAddrSpace,
-                                               *PtrOpUse.get())) {
+            TTI->computeKnownBitsAddrSpaceCast(NewAddrSpace, *PtrOpUse.get())) {
       OldPtrBits = KB->first;
       NewPtrBits = KB->second;
     }
@@ -709,7 +708,11 @@ Value *InferAddressSpacesImpl::clonePtrMaskWithNewAddressSpace(
 
   IRBuilder<> B(I);
   if (NewPtrBits) {
-    MaskTy = B.getIntNTy(NewPtrBits->getBitWidth());
+    if (MaskTy->isVectorTy())
+      MaskTy = VectorType::get(B.getIntNTy(NewPtrBits->getBitWidth()),
+                               cast<VectorType>(MaskTy));
+    else
+      MaskTy = B.getIntNTy(NewPtrBits->getBitWidth());
     MaskOp = B.CreateTrunc(MaskOp, MaskTy);
   }
   Value *NewPtr = operandWithNewAddressSpaceOrCreatePoison(
diff --git a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
index 5e0716c1e308f..644b8cdd0d2e4 100644
--- a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
+++ b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
@@ -15,7 +15,20 @@ define i8 @ptrmask_cast_local_to_flat(ptr addrspace(3) %src.ptr, i64 %mask) {
   ret i8 %load
 }
 
-; Casting null does not necessarily result in null again ...
+define <3 x ptr addrspace(3)> @ptrmask_vector_cast_local_to_flat(<3 x ptr addrspace(3)> %src.ptr, <3 x i64> %mask) {
+; CHECK-LABEL: @ptrmask_vector_cast_local_to_flat(
+; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast <3 x ptr addrspace(3)> [[SRC_PTR:%.*]] to <3 x ptr>
+; CHECK-NEXT:    [[MASKED:%.*]] = call <3 x ptr> @llvm.ptrmask.v3p0.v3i64(<3 x ptr> [[CAST]], <3 x i64> [[MASK:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast <3 x ptr> [[MASKED]] to <3 x ptr addrspace(3)>
+; CHECK-NEXT:    ret <3 x ptr addrspace(3)> [[TMP1]]
+;
+  %cast = addrspacecast <3 x ptr addrspace(3)> %src.ptr to <3 x ptr>
+  %masked = call <3 x ptr> @llvm.ptrmask.v3p0.v3i64(<3 x ptr> %cast, <3 x i64> %mask)
+  %cast2 = addrspacecast <3 x ptr> %masked to <3 x ptr addrspace(3)>
+  ret <3 x ptr addrspace(3)> %cast2
+}
+
+; Casting null does not necessarily result in null again.
 define i8 @ptrmask_cast_local_null_to_flat(i64 %mask) {
 ; CHECK-LABEL: @ptrmask_cast_local_null_to_flat(
 ; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr addrspacecast (ptr addrspace(3) null to ptr), i64 [[MASK:%.*]])
@@ -23,24 +36,20 @@ define i8 @ptrmask_cast_local_null_to_flat(i64 %mask) {
 ; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[TMP1]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
-  %cast = addrspacecast ptr addrspace(3) null to ptr
-  %masked = call ptr @llvm.ptrmask.p0.i64(ptr %cast, i64 %mask)
+  %masked = call ptr @llvm.ptrmask.p0.i64(ptr addrspacecast (ptr addrspace(3) null to ptr), i64 %mask)
   %load = load i8, ptr %masked
   ret i8 %load
 }
 
-; ... exception: addrspace 0 cast to a smaller addrspace (by default we assume
-; that casting to a smaller addrspace truncates)
-define i8 @ptrmask_cast_flat_null_to_local(i32 %mask) {
-; CHECK-LABEL: @ptrmask_cast_flat_null_to_local(
-; CHECK-NEXT:    [[MASKED:%.*]] = call ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3) addrspacecast (ptr null to ptr addrspace(3)), i32 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[MASKED]], align 1
-; CHECK-NEXT:    ret i8 [[LOAD]]
+define <3 x ptr addrspace(3)> @ptrmask_vector_cast_local_null_to_flat(<3 x i64> %mask) {
+; CHECK-LABEL: @ptrmask_vector_cast_local_null_to_flat(
+; CHECK-NEXT:    [[MASKED:%.*]] = call <3 x ptr> @llvm.ptrmask.v3p0.v3i64(<3 x ptr> <ptr addrspacecast (ptr addrspace(3) null to ptr), ptr addrspacecast (ptr addrspace(3) null to ptr), ptr addrspacecast (ptr addrspace(3) null to ptr)>, <3 x i64> [[MASK:%.*]])
+; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast <3 x ptr> [[MASKED]] to <3 x ptr addrspace(3)>
+; CHECK-NEXT:    ret <3 x ptr addrspace(3)> [[CAST]]
 ;
-  %cast = addrspacecast ptr null to ptr addrspace(3)
-  %masked = call ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3) %cast, i32 %mask)
-  %load = load i8, ptr addrspace(3) %masked
-  ret i8 %load
+  %masked = call <3 x ptr> @llvm.ptrmask.v3p0.v3i64(<3 x ptr> addrspacecast (<3 x ptr addrspace(3)> <ptr addrspace(3) null, ptr addrspace(3) null, ptr addrspace(3) null> to <3 x ptr>), <3 x i64> %mask)
+  %cast = addrspacecast <3 x ptr> %masked to <3 x ptr addrspace(3)>
+  ret <3 x ptr addrspace(3)> %cast
 }
 
 define i8 @ptrmask_cast_private_to_flat(ptr addrspace(5) %src.ptr, i64 %mask) {
@@ -108,6 +117,46 @@ define i8 @ptrmask_cast_flat_to_local(ptr %ptr, i64 %mask) {
   ret i8 %load
 }
 
+define <3 x ptr addrspace(3)> @ptrmask_vector_cast_flat_to_local(<3 x ptr> %src.ptr, <3 x i64> %mask) {
+; CHECK-LABEL: @ptrmask_vector_cast_flat_to_local(
+; CHECK-NEXT:    [[SRC_PTR:%.*]] = call <3 x ptr> @llvm.ptrmask.v3p0.v3i64(<3 x ptr> [[SRC_PTR1:%.*]], <3 x i64> [[MASK:%.*]])
+; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast <3 x ptr> [[SRC_PTR]] to <3 x ptr addrspace(3)>
+; CHECK-NEXT:    ret <3 x ptr addrspace(3)> [[CAST]]
+;
+  %masked = call <3 x ptr> @llvm.ptrmask.v3p0.v3i64(<3 x ptr> %src.ptr, <3 x i64> %mask)
+  %cast = addrspacecast <3 x ptr> %masked to <3 x ptr addrspace(3)>
+  ret <3 x ptr addrspace(3)> %cast
+}
+
+; Casting null *does* result in null again if addrspace 0 is cast to a
+; smaller addrspace (by default we assume that casting to a smaller
+; addrspace truncates).
+define i8 @ptrmask_cast_flat_null_to_local(i64 %mask) {
+; CHECK-LABEL: @ptrmask_cast_flat_null_to_local(
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) addrspacecast (ptr null to ptr addrspace(3)), align 1
+; CHECK-NEXT:    ret i8 [[LOAD]]
+;
+  %masked = call ptr @llvm.ptrmask.p0.i64(ptr null, i64 %mask)
+  %cast = addrspacecast ptr %masked to ptr addrspace(3)
+  %load = load i8, ptr addrspace(3) %cast
+  ret i8 %load
+}
+
+define i8 @ptrmask_vector_cast_flat_null_to_local(<3 x i64> %mask, i32 %ptridx, i32 %idx) {
+; CHECK-LABEL: @ptrmask_vector_cast_flat_null_to_local(
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr addrspace(3) addrspacecast (ptr null to ptr addrspace(3)), i32 [[IDX:%.*]]
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[GEP]], align 1
+; CHECK-NEXT:    ret i8 [[LOAD]]
+;
+  %masked = call <3 x ptr> @llvm.ptrmask.v3p0.v3i64(<3 x ptr> <ptr null, ptr null, ptr null>, <3 x i64> %mask)
+  %cast = addrspacecast <3 x ptr> %masked to <3 x ptr addrspace(3)>
+  %ptr = extractelement <3 x ptr addrspace(3)> %cast, i32 %ptridx
+  %gep = getelementptr i8, ptr addrspace(3) %ptr, i32 %idx
+  %load = load i8, ptr addrspace(3) %gep
+  ret i8 %load
+}
+
 define i8 @ptrmask_cast_flat_to_private(ptr %ptr, i64 %mask) {
 ; CHECK-LABEL: @ptrmask_cast_flat_to_private(
 ; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 [[MASK:%.*]])
@@ -228,6 +277,19 @@ define i8 @ptrmask_cast_local_to_flat_const_mask_fffffffeffffffff(ptr addrspace(
   ret i8 %load
 }
 
+define <3 x ptr addrspace(3)> @ptrmask_vector_cast_local_to_flat_const_mask_fffffffeffffffff(<3 x ptr addrspace(3)> %src.ptr) {
+; CHECK-LABEL: @ptrmask_vector_cast_local_to_flat_const_mask_fffffffeffffffff(
+; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast <3 x ptr addrspace(3)> [[SRC_PTR:%.*]] to <3 x ptr>
+; CHECK-NEXT:    [[MASKED:%.*]] = call <3 x ptr> @llvm.ptrmask.v3p0.v3i64(<3 x ptr> [[CAST]], <3 x i64> splat (i64 -4294967297))
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast <3 x ptr> [[MASKED]] to <3 x ptr addrspace(3)>
+; CHECK-NEXT:    ret <3 x ptr addrspace(3)> [[TMP1]]
+;
+  %cast = addrspacecast <3 x ptr addrspace(3)> %src.ptr to <3 x ptr>
+  %masked = call <3 x ptr> @llvm.ptrmask.v3p0.v3i64(<3 x ptr> %cast, <3 x i64> <i64 -4294967297, i64 -4294967297, i64 -4294967297>)
+  %cast2 = addrspacecast <3 x ptr> %masked to <3 x ptr addrspace(3)>
+  ret <3 x ptr addrspace(3)> %cast2
+}
+
 ; Do not fold this since it clears a single high bit.
 define i8 @ptrmask_cast_local_to_flat_const_mask_7fffffffffffffff(ptr addrspace(3) %src.ptr) {
 ; CHECK-LABEL: @ptrmask_cast_local_to_flat_const_mask_7fffffffffffffff(
@@ -269,6 +331,17 @@ define i8 @ptrmask_cast_local_to_flat_const_mask_ffffffff00000000(ptr addrspace(
   ret i8 %load
 }
 
+define <3 x ptr addrspace(3)> @ptrmask_vector_cast_local_to_flat_const_mask_ffffffff00000000(<3 x ptr addrspace(3)> %src.ptr) {
+; CHECK-LABEL: @ptrmask_vector_cast_local_to_flat_const_mask_ffffffff00000000(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <3 x ptr addrspace(3)> @llvm.ptrmask.v3p3.v3i32(<3 x ptr addrspace(3)> [[SRC_PTR:%.*]], <3 x i32> zeroinitializer)
+; CHECK-NEXT:    ret <3 x ptr addrspace(3)> [[TMP1]]
+;
+  %cast = addrspacecast <3 x ptr addrspace(3)> %src.ptr to <3 x ptr>
+  %masked = call <3 x ptr> @llvm.ptrmask.v3p0.v3i64(<3 x ptr> %cast, <3 x i64> <i64 -4294967296, i64 -4294967296, i64 -4294967296>)
+  %cast2 = addrspacecast <3 x ptr> %masked to <3 x ptr addrspace(3)>
+  ret <3 x ptr addrspace(3)> %cast2
+}
+
 define i8 @ptrmask_cast_local_null_to_flat_const_mask_ffffffff00000000() {
 ; CHECK-LABEL: @ptrmask_cast_local_null_to_flat_const_mask_ffffffff00000000(
 ; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) null, align 1
@@ -305,6 +378,17 @@ define i8 @ptrmask_cast_local_to_flat_const_mask_ffffffffffff0000(ptr addrspace(
   ret i8 %load
 }
 
+define <3 x ptr addrspace(3)> @ptrmask_vector_cast_local_to_flat_const_mask_ffffffffffff0000(<3 x ptr addrspace(3)> %src.ptr) {
+; CHECK-LABEL: @ptrmask_vector_cast_local_to_flat_const_mask_ffffffffffff0000(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <3 x ptr addrspace(3)> @llvm.ptrmask.v3p3.v3i32(<3 x ptr addrspace(3)> [[SRC_PTR:%.*]], <3 x i32> splat (i32 -65536))
+; CHECK-NEXT:    ret <3 x ptr addrspace(3)> [[TMP1]]
+;
+  %cast = addrspacecast <3 x ptr addrspace(3)> %src.ptr to <3 x ptr>
+  %masked = call <3 x ptr> @llvm.ptrmask.v3p0.v3i64(<3 x ptr> %cast, <3 x i64> <i64 -65536, i64 -65536, i64 -65536>)
+  %cast2 = addrspacecast <3 x ptr> %masked to <3 x ptr addrspace(3)>
+  ret <3 x ptr addrspace(3)> %cast2
+}
+
 define i8 @ptrmask_cast_local_to_flat_const_mask_ffffffffffffff00(ptr addrspace(3) %src.ptr) {
 ; CHECK-LABEL: @ptrmask_cast_local_to_flat_const_mask_ffffffffffffff00(
 ; CHECK-NEXT:    [[TMP1:%.*]] = call ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3) [[SRC_PTR:%.*]], i32 -256)
@@ -404,6 +488,20 @@ define i8 @ptrmask_cast_local_to_flat_load_range_mask(ptr addrspace(3) %src.ptr,
   ret i8 %load
 }
 
+define <2 x ptr addrspace(3)> @ptrmask_vector_cast_local_to_flat_load_range_mask(<2 x ptr addrspace(3)> %src.ptr, ptr addrspace(1) %mask.ptr) {
+; CHECK-LABEL: @ptrmask_vector_cast_local_to_flat_load_range_mask(
+; CHECK-NEXT:    [[LOAD_MASK:%.*]] = load <2 x i64>, ptr addrspace(1) [[MASK_PTR:%.*]], align 16, !range [[RNG0]]
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i64> [[LOAD_MASK]] to <2 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = call <2 x ptr addrspace(3)> @llvm.ptrmask.v2p3.v2i32(<2 x ptr addrspace(3)> [[SRC_PTR:%.*]], <2 x i32> [[TMP1]])
+; CHECK-NEXT:    ret <2 x ptr addrspace(3)> [[TMP2]]
+;
+  %load.mask = load <2 x i64>, ptr addrspace(1) %mask.ptr, align 16, !range !0
+  %cast = addrspacecast <2 x ptr addrspace(3)> %src.ptr to <2 x ptr>
+  %masked = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> %cast, <2 x i64> %load.mask)
+  %cast2 = addrspacecast <2 x ptr> %masked to <2 x ptr addrspace(3)>
+  ret <2 x ptr addrspace(3)> %cast2
+}
+
 ; Non-const masks with no known range should not prevent other ptr-manipulating
 ; instructions (such as gep) from being converted.
 define i8 @ptrmask_cast_local_to_flat_unknown_mask(ptr addrspace(3) %src.ptr, i64 %mask, i64 %idx) {
@@ -422,21 +520,17 @@ define i8 @ptrmask_cast_local_to_flat_unknown_mask(ptr addrspace(3) %src.ptr, i6
   ret i8 %load
 }
 
-define i8 @ptrmask_vector_cast_local_to_flat_unknown_mask(<2 x ptr addrspace(3)> %src.ptr, <2 x i64> %mask, i64 %ptridx, i64 %idx) {
+define <2 x ptr addrspace(3)> @ptrmask_vector_cast_local_to_flat_unknown_mask(<2 x ptr addrspace(3)> %src.ptr, <2 x i64> %mask) {
 ; CHECK-LABEL: @ptrmask_vector_cast_local_to_flat_unknown_mask(
 ; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast <2 x ptr addrspace(3)> [[SRC_PTR:%.*]] to <2 x ptr>
 ; CHECK-NEXT:    [[MASKED:%.*]] = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> [[CAST]], <2 x i64> [[MASK:%.*]])
-; CHECK-NEXT:    [[PTR:%.*]] = extractelement <2 x ptr> [[MASKED]], i64 [[PTRIDX:%.*]]
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[IDX:%.*]]
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr [[GEP]], align 1
-; CHECK-NEXT:    ret i8 [[LOAD]]
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast <2 x ptr> [[MASKED]] to <2 x ptr addrspace(3)>
+; CHECK-NEXT:    ret <2 x ptr addrspace(3)> [[TMP1]]
 ;
   %cast = addrspacecast <2 x ptr addrspace(3)> %src.ptr to <2 x ptr>
   %masked = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> %cast, <2 x i64> %mask)
-  %ptr = extractelement <2 x ptr> %masked, i64 %ptridx
-  %gep = getelementptr i8, ptr %ptr, i64 %idx
-  %load = load i8, ptr %gep
-  ret i8 %load
+  %cast2 = addrspacecast <2 x ptr> %masked to <2 x ptr addrspace(3)>
+  ret <2 x ptr addrspace(3)> %cast2
 }
 
 define i8 @interleaved_ptrmask_cast_local_to_flat_unknown_mask(ptr addrspace(3) %src.ptr, i64 %mask, i64 %idx) {
@@ -459,6 +553,7 @@ define i8 @interleaved_ptrmask_cast_local_to_flat_unknown_mask(ptr addrspace(3)
 
 declare ptr @llvm.ptrmask.p0.i64(ptr, i64) #0
 declare <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr>, <2 x i64>) #0
+declare <3 x ptr> @llvm.ptrmask.v3p0.v3i64(<3 x ptr>, <3 x i64>) #0
 declare ptr addrspace(5) @llvm.ptrmask.p5.i32(ptr addrspace(5), i32) #0
 declare ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3), i32) #0
 declare ptr addrspace(1) @llvm.ptrmask.p1.i64(ptr addrspace(1), i64) #0

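As a closing illustration, the vector-aware mask type selection added in this last patch can be summarized by a standalone sketch (the helper name is hypothetical; the patch inlines the equivalent logic):

    // Hypothetical helper: pick the truncated mask type for scalar and
    // vector ptrmask operands, preserving the element count for vectors.
    #include "llvm/IR/DerivedTypes.h"
    #include "llvm/IR/IRBuilder.h"

    using namespace llvm;

    static Type *truncatedMaskType(IRBuilder<> &B, Type *MaskTy,
                                   unsigned NewPtrBits) {
      Type *EltTy = B.getIntNTy(NewPtrBits);
      if (auto *VecTy = dyn_cast<VectorType>(MaskTy))
        return VectorType::get(EltTy, VecTy->getElementCount());
      return EltTy;
    }
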


More information about the llvm-commits mailing list