[clang] Add combines/simplifications for `llvm.ptrmask` (PR #67166)

via cfe-commits cfe-commits at lists.llvm.org
Fri Sep 22 12:48:25 PDT 2023


https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/67166

>From 3982b40f82c6cb91c9bec407cea6b7a4cdc4220e Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 22 Sep 2023 08:21:21 -0500
Subject: [PATCH 1/9] [InstSimplify] Add tests for simplify `llvm.ptrmask`; NFC

Differential Revision: https://reviews.llvm.org/D156632
---
 llvm/test/Transforms/InstSimplify/ptrmask.ll | 143 +++++++++++++++++++
 1 file changed, 143 insertions(+)
 create mode 100644 llvm/test/Transforms/InstSimplify/ptrmask.ll

diff --git a/llvm/test/Transforms/InstSimplify/ptrmask.ll b/llvm/test/Transforms/InstSimplify/ptrmask.ll
new file mode 100644
index 000000000000000..0af7186978ae5bc
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/ptrmask.ll
@@ -0,0 +1,143 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -passes=instsimplify -S | FileCheck %s
+
+declare ptr @llvm.ptrmask.p0.i32(ptr, i32)
+declare ptr @llvm.ptrmask.p0.i64(ptr, i64)
+
+define ptr @ptrmask_simplify_poison_mask(ptr %p) {
+; CHECK-LABEL: define ptr @ptrmask_simplify_poison_mask
+; CHECK-SAME: (ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 poison)
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %r = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 poison)
+  ret ptr %r
+}
+
+define ptr @ptrmask_simplify_undef_mask(ptr %p) {
+; CHECK-LABEL: define ptr @ptrmask_simplify_undef_mask
+; CHECK-SAME: (ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 undef)
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %r = call ptr @llvm.ptrmask.p0.i32(ptr %p, i32 undef)
+  ret ptr %r
+}
+
+define ptr @ptrmask_simplify_0_mask(ptr %p) {
+; CHECK-LABEL: define ptr @ptrmask_simplify_0_mask
+; CHECK-SAME: (ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 0)
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %r = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 0)
+  ret ptr %r
+}
+
+define ptr @ptrmask_simplify_1s_mask(ptr %p) {
+; CHECK-LABEL: define ptr @ptrmask_simplify_1s_mask
+; CHECK-SAME: (ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -1)
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %r = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -1)
+  ret ptr %r
+}
+
+define ptr @ptrmask_simplify_1s_mask_i32_fail(ptr %p) {
+; CHECK-LABEL: define ptr @ptrmask_simplify_1s_mask_i32_fail
+; CHECK-SAME: (ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -1)
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %r = call ptr @llvm.ptrmask.p0.i32(ptr %p, i32 -1)
+  ret ptr %r
+}
+
+define ptr @ptrmask_simplify_poison_ptr(i64 %m) {
+; CHECK-LABEL: define ptr @ptrmask_simplify_poison_ptr
+; CHECK-SAME: (i64 [[M:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr poison, i64 [[M]])
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %r = call ptr @llvm.ptrmask.p0.i64(ptr poison, i64 %m)
+  ret ptr %r
+}
+
+define ptr @ptrmask_simplify_undef_ptr(i32 %m) {
+; CHECK-LABEL: define ptr @ptrmask_simplify_undef_ptr
+; CHECK-SAME: (i32 [[M:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr undef, i32 [[M]])
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %r = call ptr @llvm.ptrmask.p0.i32(ptr undef, i32 %m)
+  ret ptr %r
+}
+
+define ptr @ptrmask_simplify_null_ptr(i64 %m) {
+; CHECK-LABEL: define ptr @ptrmask_simplify_null_ptr
+; CHECK-SAME: (i64 [[M:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr null, i64 [[M]])
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %r = call ptr @llvm.ptrmask.p0.i64(ptr null, i64 %m)
+  ret ptr %r
+}
+
+define ptr @ptrmask_simplify_ptrmask(ptr %p) {
+; CHECK-LABEL: define ptr @ptrmask_simplify_ptrmask
+; CHECK-SAME: (ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[M:%.*]] = ptrtoint ptr [[P]] to i64
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 [[M]])
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %m = ptrtoint ptr %p to i64
+  %r = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %m)
+  ret ptr %r
+}
+
+define ptr @ptrmask_simplify_ptrmask_i32_fail(ptr %p) {
+; CHECK-LABEL: define ptr @ptrmask_simplify_ptrmask_i32_fail
+; CHECK-SAME: (ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[M:%.*]] = ptrtoint ptr [[P]] to i32
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 [[M]])
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %m = ptrtoint ptr %p to i32
+  %r = call ptr @llvm.ptrmask.p0.i32(ptr %p, i32 %m)
+  ret ptr %r
+}
+
+define ptr @ptrmask_simplify_aligned_unused(ptr align 64 %p) {
+; CHECK-LABEL: define ptr @ptrmask_simplify_aligned_unused
+; CHECK-SAME: (ptr align 64 [[P:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -64)
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %r = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -64)
+  ret ptr %r
+}
+
+define ptr @ptrmask_simplify_aligned_unused_i32_fail(ptr align 64 %p) {
+; CHECK-LABEL: define ptr @ptrmask_simplify_aligned_unused_i32_fail
+; CHECK-SAME: (ptr align 64 [[P:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -64)
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %r = call ptr @llvm.ptrmask.p0.i32(ptr %p, i32 -64)
+  ret ptr %r
+}
+
+define ptr @ptrmask_simplify_known_unused(ptr %p) {
+; CHECK-LABEL: define ptr @ptrmask_simplify_known_unused
+; CHECK-SAME: (ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[PM0:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -64)
+; CHECK-NEXT:    [[PGEP:%.*]] = getelementptr i8, ptr [[PM0]], i64 32
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PGEP]], i64 -32)
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %pm0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -64)
+  %pgep = getelementptr i8, ptr %pm0, i64 32
+  %r = call ptr @llvm.ptrmask.p0.i64(ptr %pgep, i64 -32)
+  ret ptr %r
+}

>From dc6533899cbafcb43d4227620faaaba4d9ac1059 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 22 Sep 2023 08:21:27 -0500
Subject: [PATCH 2/9] [InstSimplify] Add some basic simplifications for
 `llvm.ptrmask`

Mostly the same as `and`. We also have a check for a useless
`llvm.ptrmask` if the pointer is already known to be aligned.

Differential Revision: https://reviews.llvm.org/D156633
---
 llvm/lib/Analysis/InstructionSimplify.cpp     | 42 +++++++++++++++++++
 .../InferAddressSpaces/AMDGPU/ptrmask.ll      |  5 +--
 .../test/Transforms/InstCombine/align-addr.ll |  6 +--
 llvm/test/Transforms/InstSimplify/ptrmask.ll  | 22 ++++------
 4 files changed, 53 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index d8aa614cae53b10..1049845d97844a7 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -6397,6 +6397,48 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
       return Constant::getNullValue(ReturnType);
     break;
   }
+  case Intrinsic::ptrmask: {
+    // Fail loudly in case this is ever changed.
+    // TODO: If vector types are supported the logic that checks if the mask is
+    // useless should be updated to use generic constants.
+    assert(!Op0->getType()->isVectorTy() && !Op1->getType()->isVectorTy() &&
+           "These simplifications where written at a time when ptrmask did not "
+           "support vector types and may not work for vectors");
+
+    // NOTE: We can't apply these simplifications based on the value of Op1
+    // because we need to preserve provenance.
+    if (isa<PoisonValue>(Op0))
+      return Op0;
+
+    if (Q.isUndefValue(Op0))
+      return Constant::getNullValue(Op0->getType());
+
+    if (match(Op0, m_Zero()))
+      return Constant::getNullValue(Op0->getType());
+
+    if (Op1->getType()->getScalarSizeInBits() ==
+        Q.DL.getPointerTypeSizeInBits(Op0->getType())) {
+      if (match(Op1, m_PtrToInt(m_Specific(Op0))))
+        return Op0;
+
+      // TODO: We may have attributes assosiated with the return value of the
+      // llvm.ptrmask intrinsic that will be lost when we just return the
+      // operand. We should try to preserve them.
+      if (match(Op1, m_AllOnes()))
+        return Op0;
+
+      const APInt *C;
+      if (match(Op1, m_APInt(C))) {
+        KnownBits PtrKnown =
+            computeKnownBits(Op0, Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
+        // See if we only masking off bits we know are already zero due to
+        // alignment.
+        if ((*C | PtrKnown.Zero).isAllOnes())
+          return Op0;
+      }
+    }
+    break;
+  }
   case Intrinsic::smax:
   case Intrinsic::smin:
   case Intrinsic::umax:
diff --git a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
index c9db0656c6b7dd4..561a5ff35ba1082 100644
--- a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
+++ b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
@@ -320,8 +320,7 @@ define i8 @ptrmask_cast_local_to_flat_const_mask_fffffffffffffffe(ptr addrspace(
 
 define i8 @ptrmask_cast_local_to_flat_const_mask_ffffffffffffffff(ptr addrspace(3) %src.ptr) {
 ; CHECK-LABEL: @ptrmask_cast_local_to_flat_const_mask_ffffffffffffffff(
-; CHECK-NEXT:    [[TMP1:%.*]] = call ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3) [[SRC_PTR:%.*]], i32 -1)
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[TMP1]], align 1
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[SRC_PTR:%.*]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %cast = addrspacecast ptr addrspace(3) %src.ptr to ptr
@@ -333,7 +332,7 @@ define i8 @ptrmask_cast_local_to_flat_const_mask_ffffffffffffffff(ptr addrspace(
 ; Make sure non-constant masks can also be handled.
 define i8 @ptrmask_cast_local_to_flat_load_range_mask(ptr addrspace(3) %src.ptr, ptr addrspace(1) %mask.ptr) {
 ; CHECK-LABEL: @ptrmask_cast_local_to_flat_load_range_mask(
-; CHECK-NEXT:    [[LOAD_MASK:%.*]] = load i64, ptr addrspace(1) [[MASK_PTR:%.*]], align 8, !range !0
+; CHECK-NEXT:    [[LOAD_MASK:%.*]] = load i64, ptr addrspace(1) [[MASK_PTR:%.*]], align 8, !range [[RNG0:![0-9]+]]
 ; CHECK-NEXT:    [[TMP1:%.*]] = trunc i64 [[LOAD_MASK]] to i32
 ; CHECK-NEXT:    [[TMP2:%.*]] = call ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3) [[SRC_PTR:%.*]], i32 [[TMP1]])
 ; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[TMP2]], align 1
diff --git a/llvm/test/Transforms/InstCombine/align-addr.ll b/llvm/test/Transforms/InstCombine/align-addr.ll
index 23f620310d7c26c..f1164ad72d07fba 100644
--- a/llvm/test/Transforms/InstCombine/align-addr.ll
+++ b/llvm/test/Transforms/InstCombine/align-addr.ll
@@ -183,8 +183,7 @@ define <16 x i8> @ptrmask_align8_ptr_align1(ptr align 1 %ptr) {
 ; TODO: Should be able to drop the ptrmask
 define <16 x i8> @ptrmask_align8_ptr_align8(ptr align 8 %ptr) {
 ; CHECK-LABEL: @ptrmask_align8_ptr_align8(
-; CHECK-NEXT:    [[ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -8)
-; CHECK-NEXT:    [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
+; CHECK-NEXT:    [[LOAD:%.*]] = load <16 x i8>, ptr [[PTR:%.*]], align 1
 ; CHECK-NEXT:    ret <16 x i8> [[LOAD]]
 ;
   %aligned = call ptr @llvm.ptrmask.p0.i64(ptr %ptr, i64 -8)
@@ -196,8 +195,7 @@ define <16 x i8> @ptrmask_align8_ptr_align8(ptr align 8 %ptr) {
 ; TODO: Should be able to drop the ptrmask
 define <16 x i8> @ptrmask_align8_ptr_align16(ptr align 16 %ptr) {
 ; CHECK-LABEL: @ptrmask_align8_ptr_align16(
-; CHECK-NEXT:    [[ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -8)
-; CHECK-NEXT:    [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
+; CHECK-NEXT:    [[LOAD:%.*]] = load <16 x i8>, ptr [[PTR:%.*]], align 1
 ; CHECK-NEXT:    ret <16 x i8> [[LOAD]]
 ;
   %aligned = call ptr @llvm.ptrmask.p0.i64(ptr %ptr, i64 -8)
diff --git a/llvm/test/Transforms/InstSimplify/ptrmask.ll b/llvm/test/Transforms/InstSimplify/ptrmask.ll
index 0af7186978ae5bc..6c219e8587f0a0d 100644
--- a/llvm/test/Transforms/InstSimplify/ptrmask.ll
+++ b/llvm/test/Transforms/InstSimplify/ptrmask.ll
@@ -37,8 +37,7 @@ define ptr @ptrmask_simplify_0_mask(ptr %p) {
 define ptr @ptrmask_simplify_1s_mask(ptr %p) {
 ; CHECK-LABEL: define ptr @ptrmask_simplify_1s_mask
 ; CHECK-SAME: (ptr [[P:%.*]]) {
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -1)
-; CHECK-NEXT:    ret ptr [[R]]
+; CHECK-NEXT:    ret ptr [[P]]
 ;
   %r = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -1)
   ret ptr %r
@@ -57,8 +56,7 @@ define ptr @ptrmask_simplify_1s_mask_i32_fail(ptr %p) {
 define ptr @ptrmask_simplify_poison_ptr(i64 %m) {
 ; CHECK-LABEL: define ptr @ptrmask_simplify_poison_ptr
 ; CHECK-SAME: (i64 [[M:%.*]]) {
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr poison, i64 [[M]])
-; CHECK-NEXT:    ret ptr [[R]]
+; CHECK-NEXT:    ret ptr poison
 ;
   %r = call ptr @llvm.ptrmask.p0.i64(ptr poison, i64 %m)
   ret ptr %r
@@ -67,8 +65,7 @@ define ptr @ptrmask_simplify_poison_ptr(i64 %m) {
 define ptr @ptrmask_simplify_undef_ptr(i32 %m) {
 ; CHECK-LABEL: define ptr @ptrmask_simplify_undef_ptr
 ; CHECK-SAME: (i32 [[M:%.*]]) {
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr undef, i32 [[M]])
-; CHECK-NEXT:    ret ptr [[R]]
+; CHECK-NEXT:    ret ptr null
 ;
   %r = call ptr @llvm.ptrmask.p0.i32(ptr undef, i32 %m)
   ret ptr %r
@@ -77,8 +74,7 @@ define ptr @ptrmask_simplify_undef_ptr(i32 %m) {
 define ptr @ptrmask_simplify_null_ptr(i64 %m) {
 ; CHECK-LABEL: define ptr @ptrmask_simplify_null_ptr
 ; CHECK-SAME: (i64 [[M:%.*]]) {
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr null, i64 [[M]])
-; CHECK-NEXT:    ret ptr [[R]]
+; CHECK-NEXT:    ret ptr null
 ;
   %r = call ptr @llvm.ptrmask.p0.i64(ptr null, i64 %m)
   ret ptr %r
@@ -87,9 +83,7 @@ define ptr @ptrmask_simplify_null_ptr(i64 %m) {
 define ptr @ptrmask_simplify_ptrmask(ptr %p) {
 ; CHECK-LABEL: define ptr @ptrmask_simplify_ptrmask
 ; CHECK-SAME: (ptr [[P:%.*]]) {
-; CHECK-NEXT:    [[M:%.*]] = ptrtoint ptr [[P]] to i64
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 [[M]])
-; CHECK-NEXT:    ret ptr [[R]]
+; CHECK-NEXT:    ret ptr [[P]]
 ;
   %m = ptrtoint ptr %p to i64
   %r = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %m)
@@ -111,8 +105,7 @@ define ptr @ptrmask_simplify_ptrmask_i32_fail(ptr %p) {
 define ptr @ptrmask_simplify_aligned_unused(ptr align 64 %p) {
 ; CHECK-LABEL: define ptr @ptrmask_simplify_aligned_unused
 ; CHECK-SAME: (ptr align 64 [[P:%.*]]) {
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -64)
-; CHECK-NEXT:    ret ptr [[R]]
+; CHECK-NEXT:    ret ptr [[P]]
 ;
   %r = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -64)
   ret ptr %r
@@ -133,8 +126,7 @@ define ptr @ptrmask_simplify_known_unused(ptr %p) {
 ; CHECK-SAME: (ptr [[P:%.*]]) {
 ; CHECK-NEXT:    [[PM0:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -64)
 ; CHECK-NEXT:    [[PGEP:%.*]] = getelementptr i8, ptr [[PM0]], i64 32
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PGEP]], i64 -32)
-; CHECK-NEXT:    ret ptr [[R]]
+; CHECK-NEXT:    ret ptr [[PGEP]]
 ;
   %pm0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -64)
   %pgep = getelementptr i8, ptr %pm0, i64 32

>From a467064b6dcfe57e7348ea4f575d8219bcbaea45 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 22 Sep 2023 08:21:32 -0500
Subject: [PATCH 3/9] [InstCombine] Add tests for combining `llvm.ptrmask`; NFC

Differential Revision: https://reviews.llvm.org/D156634
---
 .../InstCombine/consecutive-ptrmask.ll        |  38 +++++
 llvm/test/Transforms/InstCombine/ptrmask.ll   | 154 ++++++++++++++++++
 2 files changed, 192 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/ptrmask.ll

diff --git a/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll b/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
index 904c758b99306f4..adf044da92d2bdf 100644
--- a/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
+++ b/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
@@ -65,3 +65,41 @@ define ptr @fold_2x_fail_type_mismatch2(ptr %p, i64 %m0, i32 %m1) {
   %p1 = call ptr @llvm.ptrmask.p0.i32(ptr %p0, i32 %m1)
   ret ptr %p1
 }
+
+
+define ptr @fold_2x_type_mismatch_const0(ptr %p, i32 %m1) {
+; CHECK-LABEL: define ptr @fold_2x_type_mismatch_const0
+; CHECK-SAME: (ptr [[P:%.*]], i32 [[M1:%.*]]) {
+; CHECK-NEXT:    [[P0:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -128)
+; CHECK-NEXT:    [[P1:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr [[P0]], i32 [[M1]])
+; CHECK-NEXT:    ret ptr [[P1]]
+;
+  %p0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -128)
+  %p1 = call ptr @llvm.ptrmask.p0.i32(ptr %p0, i32 %m1)
+  ret ptr %p1
+}
+
+define ptr @fold_2x_type_mismatch_const1(ptr %p, i64 %m0) {
+; CHECK-LABEL: define ptr @fold_2x_type_mismatch_const1
+; CHECK-SAME: (ptr [[P:%.*]], i64 [[M0:%.*]]) {
+; CHECK-NEXT:    [[P0:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 [[M0]])
+; CHECK-NEXT:    [[P1:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr [[P0]], i32 -2)
+; CHECK-NEXT:    ret ptr [[P1]]
+;
+  %p0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %m0)
+  %p1 = call ptr @llvm.ptrmask.p0.i32(ptr %p0, i32 -2)
+  ret ptr %p1
+}
+
+
+define ptr @fold_2x_type_mismatch_const2(ptr %p) {
+; CHECK-LABEL: define ptr @fold_2x_type_mismatch_const2
+; CHECK-SAME: (ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[P0:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -4)
+; CHECK-NEXT:    [[P1:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 -31)
+; CHECK-NEXT:    ret ptr [[P1]]
+;
+  %p0 = call ptr @llvm.ptrmask.p0.i32(ptr %p, i32 -4)
+  %p1 = call ptr @llvm.ptrmask.p0.i64(ptr %p0, i64 -31)
+  ret ptr %p1
+}
diff --git a/llvm/test/Transforms/InstCombine/ptrmask.ll b/llvm/test/Transforms/InstCombine/ptrmask.ll
new file mode 100644
index 000000000000000..b4ea72b8fb8b65d
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/ptrmask.ll
@@ -0,0 +1,154 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+declare ptr @llvm.ptrmask.p0.i32(ptr, i32)
+declare ptr @llvm.ptrmask.p0.i64(ptr, i64)
+
+define ptr @ptrmask_combine_consecutive_preserve_attrs(ptr %p0, i64 %m1) {
+; CHECK-LABEL: define ptr @ptrmask_combine_consecutive_preserve_attrs
+; CHECK-SAME: (ptr [[P0:%.*]], i64 [[M1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i64 [[M1]], 224
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[TMP1]])
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %pm0 = call ptr @llvm.ptrmask.p0.i64(ptr %p0, i64 224)
+  %r = call noalias ptr @llvm.ptrmask.p0.i64(ptr %pm0, i64 %m1)
+  ret ptr %r
+}
+
+define ptr @ptrmask_combine_consecutive_preserve_attrs_fail(ptr %p0, i64 %m0) {
+; CHECK-LABEL: define ptr @ptrmask_combine_consecutive_preserve_attrs_fail
+; CHECK-SAME: (ptr [[P0:%.*]], i64 [[M0:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i64 [[M0]], 193
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[TMP1]])
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %pm0 = call noalias ptr @llvm.ptrmask.p0.i64(ptr %p0, i64 %m0)
+  %r = call ptr @llvm.ptrmask.p0.i64(ptr %pm0, i64 193)
+  ret ptr %r
+}
+
+define ptr @ptrmask_combine_consecutive_preserve_attrs_todo0(ptr %p0) {
+; CHECK-LABEL: define ptr @ptrmask_combine_consecutive_preserve_attrs_todo0
+; CHECK-SAME: (ptr [[P0:%.*]]) {
+; CHECK-NEXT:    [[PM0:%.*]] = call noalias ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 224)
+; CHECK-NEXT:    ret ptr [[PM0]]
+;
+  %pm0 = call noalias ptr @llvm.ptrmask.p0.i64(ptr %p0, i64 224)
+  %r = call ptr @llvm.ptrmask.p0.i64(ptr %pm0, i64 224)
+  ret ptr %r
+}
+
+define ptr @ptrmask_combine_consecutive_preserve_attrs_todo1(ptr %p0) {
+; CHECK-LABEL: define ptr @ptrmask_combine_consecutive_preserve_attrs_todo1
+; CHECK-SAME: (ptr [[P0:%.*]]) {
+; CHECK-NEXT:    [[PM0:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 224)
+; CHECK-NEXT:    ret ptr [[PM0]]
+;
+  %pm0 = call ptr @llvm.ptrmask.p0.i64(ptr %p0, i64 224)
+  %r = call noalias ptr @llvm.ptrmask.p0.i64(ptr %pm0, i64 224)
+  ret ptr %r
+}
+
+define ptr @ptrmask_combine_consecutive_preserve_attrs_todo2(ptr %p0) {
+; CHECK-LABEL: define ptr @ptrmask_combine_consecutive_preserve_attrs_todo2
+; CHECK-SAME: (ptr [[P0:%.*]]) {
+; CHECK-NEXT:    [[PM0:%.*]] = call noalias ptr @llvm.ptrmask.p0.i32(ptr [[P0]], i32 224)
+; CHECK-NEXT:    ret ptr [[PM0]]
+;
+  %pm0 = call noalias ptr @llvm.ptrmask.p0.i32(ptr %p0, i32 224)
+  %r = call ptr @llvm.ptrmask.p0.i64(ptr %pm0, i64 224)
+  ret ptr %r
+}
+
+define ptr @ptrmask_combine_add_nonnull(ptr %p) {
+; CHECK-LABEL: define ptr @ptrmask_combine_add_nonnull
+; CHECK-SAME: (ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[PM0:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -64)
+; CHECK-NEXT:    [[PGEP:%.*]] = getelementptr i8, ptr [[PM0]], i64 33
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PGEP]], i64 -16)
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %pm0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -64)
+  %pgep = getelementptr i8, ptr %pm0, i64 33
+  %r = call ptr @llvm.ptrmask.p0.i64(ptr %pgep, i64 -16)
+  ret ptr %r
+}
+
+define ptr @ptrmask_combine_add_alignment(ptr %p) {
+; CHECK-LABEL: define ptr @ptrmask_combine_add_alignment
+; CHECK-SAME: (ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -64)
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %r = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -64)
+  ret ptr %r
+}
+
+define ptr @ptrmask_combine_add_alignment2(ptr align 32 %p) {
+; CHECK-LABEL: define ptr @ptrmask_combine_add_alignment2
+; CHECK-SAME: (ptr align 32 [[P:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -64)
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %r = call ptr @llvm.ptrmask.p0.i32(ptr %p, i32 -64)
+  ret ptr %r
+}
+
+define ptr @ptrmask_combine_improve_alignment(ptr %p) {
+; CHECK-LABEL: define ptr @ptrmask_combine_improve_alignment
+; CHECK-SAME: (ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call align 32 ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -64)
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %r = call align 32 ptr @llvm.ptrmask.p0.i32(ptr %p, i32 -64)
+  ret ptr %r
+}
+
+define ptr @ptrmask_combine_improve_alignment_fail(ptr %p) {
+; CHECK-LABEL: define ptr @ptrmask_combine_improve_alignment_fail
+; CHECK-SAME: (ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call align 128 ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -64)
+; CHECK-NEXT:    ret ptr [[R]]
+;
+  %r = call align 128 ptr @llvm.ptrmask.p0.i32(ptr %p, i32 -64)
+  ret ptr %r
+}
+
+define i64 @ptrtoint_of_ptrmask(ptr %p, i64 %m) {
+; CHECK-LABEL: define i64 @ptrtoint_of_ptrmask
+; CHECK-SAME: (ptr [[P:%.*]], i64 [[M:%.*]]) {
+; CHECK-NEXT:    [[PM:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 [[M]])
+; CHECK-NEXT:    [[R:%.*]] = ptrtoint ptr [[PM]] to i64
+; CHECK-NEXT:    ret i64 [[R]]
+;
+  %pm = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %m)
+  %r = ptrtoint ptr %pm to i64
+  ret i64 %r
+}
+
+; This succeeds because (ptrtoint i32) gets folded to (trunc i32 (ptrtoint i64))
+define i32 @ptrtoint_of_ptrmask2(ptr %p, i64 %m) {
+; CHECK-LABEL: define i32 @ptrtoint_of_ptrmask2
+; CHECK-SAME: (ptr [[P:%.*]], i64 [[M:%.*]]) {
+; CHECK-NEXT:    [[PM:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 [[M]])
+; CHECK-NEXT:    [[TMP1:%.*]] = ptrtoint ptr [[PM]] to i64
+; CHECK-NEXT:    [[R:%.*]] = trunc i64 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %pm = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %m)
+  %r = ptrtoint ptr %pm to i32
+  ret i32 %r
+}
+
+define i64 @ptrtoint_of_ptrmask_fail(ptr %p, i32 %m) {
+; CHECK-LABEL: define i64 @ptrtoint_of_ptrmask_fail
+; CHECK-SAME: (ptr [[P:%.*]], i32 [[M:%.*]]) {
+; CHECK-NEXT:    [[PM:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 [[M]])
+; CHECK-NEXT:    [[R:%.*]] = ptrtoint ptr [[PM]] to i64
+; CHECK-NEXT:    ret i64 [[R]]
+;
+  %pm = call ptr @llvm.ptrmask.p0.i32(ptr %p, i32 %m)
+  %r = ptrtoint ptr %pm to i64
+  ret i64 %r
+}

>From 451b524c3bc425be84685c687076908c3c0ea37b Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 22 Sep 2023 08:21:38 -0500
Subject: [PATCH 4/9] [InstCombine] Make combines on `llvm.ptrmask` fail loudly
 if we have vec types; NFC

Currently vec types are unsupported by `llvm.ptrmask`. Should this
change, it's probably better to fail loudly rather than hope the code
ends up working properly for vec types.

Differential Revision: https://reviews.llvm.org/D156635
---
 llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index c6100f24b0507de..4b9b401f3ab070d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1964,6 +1964,15 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     break;
   }
   case Intrinsic::ptrmask: {
+    Value *Op0 = II->getArgOperand(0);
+    Value *Op1 = II->getArgOperand(1);
+    // Fail loudly in case this is ever changed.
+    // TODO: If vector types are supported the merging of (ptrmask (ptrmask))
+    // need to ensure we don't merge a vectype with non-vec type.
+    assert(!Op0->getType()->isVectorTy() && !Op1->getType()->isVectorTy() &&
+           "These combines where written at a time when ptrmask did not "
+           "support vector types and may not work for vectors");
+
     Value *InnerPtr, *InnerMask;
     if (match(II->getArgOperand(0),
               m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(InnerPtr),

>From 5a82c3b7a4643c0a31b942974dea4479a94af0a8 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 22 Sep 2023 08:21:43 -0500
Subject: [PATCH 5/9] [InstCombine] Deduce `align` and `nonnull` return
 attributes for `llvm.ptrmask`

We can deduce the former based on the mask / incoming pointer
alignment.  We can set the latter if we know the result is non-zero
(this is essentially just caching our analysis result).

Differential Revision: https://reviews.llvm.org/D156636
---
 clang/test/CodeGen/arm64_32-vaarg.c           |  2 +-
 .../InstCombine/InstCombineCalls.cpp          | 22 +++++++++++++++++++
 .../test/Transforms/InstCombine/align-addr.ll | 12 +++++-----
 .../InstCombine/consecutive-ptrmask.ll        | 10 ++++-----
 llvm/test/Transforms/InstCombine/ptrmask.ll   | 18 +++++++--------
 5 files changed, 43 insertions(+), 21 deletions(-)

diff --git a/clang/test/CodeGen/arm64_32-vaarg.c b/clang/test/CodeGen/arm64_32-vaarg.c
index 9fbcf88ecfdcc33..3f1f4443436da15 100644
--- a/clang/test/CodeGen/arm64_32-vaarg.c
+++ b/clang/test/CodeGen/arm64_32-vaarg.c
@@ -29,7 +29,7 @@ long long test_longlong(OneLongLong input, va_list *mylist) {
   // CHECK-LABEL: define{{.*}} i64 @test_longlong(i64 %input
   // CHECK: [[STARTPTR:%.*]] = load ptr, ptr %mylist
   // CHECK: [[ALIGN_TMP:%.+]] = getelementptr inbounds i8, ptr [[STARTPTR]], i32 7
-  // CHECK: [[ALIGNED_ADDR:%.+]] = tail call ptr @llvm.ptrmask.p0.i32(ptr nonnull [[ALIGN_TMP]], i32 -8)
+  // CHECK: [[ALIGNED_ADDR:%.+]] = tail call align 8 ptr @llvm.ptrmask.p0.i32(ptr nonnull [[ALIGN_TMP]], i32 -8)
   // CHECK: [[NEXT:%.*]] = getelementptr inbounds i8, ptr [[ALIGNED_ADDR]], i32 8
   // CHECK: store ptr [[NEXT]], ptr %mylist
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 4b9b401f3ab070d..58b69bfa8e49b4f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1985,6 +1985,28 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
                                     {InnerPtr, NewMask}));
       }
     }
+    bool Changed = false;
+    KnownBits Known = computeKnownBits(II, /*Depth*/ 0, II);
+    // See if we can deduce non-null.
+    if (!CI.hasRetAttr(Attribute::NonNull) &&
+        (Known.isNonZero() ||
+         isKnownNonZero(II, DL, /*Depth*/ 0, &AC, II, &DT))) {
+      CI.addRetAttr(Attribute::NonNull);
+      Changed = true;
+    }
+
+    // Known bits will capture if we had alignment information assosiated with
+    // the pointer argument.
+    if (Known.countMinTrailingZeros() > Log2(CI.getRetAlign().valueOrOne())) {
+      if (CI.hasRetAttr(Attribute::Alignment))
+        CI.removeRetAttr(Attribute::Alignment);
+      CI.addRetAttr(
+          Attribute::get(CI.getContext(), Attribute::Alignment,
+                         uint64_t(1) << Known.countMinTrailingZeros()));
+      Changed = true;
+    }
+    if (Changed)
+      return &CI;
     break;
   }
   case Intrinsic::uadd_with_overflow:
diff --git a/llvm/test/Transforms/InstCombine/align-addr.ll b/llvm/test/Transforms/InstCombine/align-addr.ll
index f1164ad72d07fba..dab113ee547f1e5 100644
--- a/llvm/test/Transforms/InstCombine/align-addr.ll
+++ b/llvm/test/Transforms/InstCombine/align-addr.ll
@@ -134,7 +134,7 @@ define <16 x i8> @ptrmask_align_unknown_ptr_align1(ptr align 1 %ptr, i64 %mask)
 
 define <16 x i8> @ptrmask_align_unknown_ptr_align8(ptr align 8 %ptr, i64 %mask) {
 ; CHECK-LABEL: @ptrmask_align_unknown_ptr_align8(
-; CHECK-NEXT:    [[ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 [[MASK:%.*]])
+; CHECK-NEXT:    [[ALIGNED:%.*]] = call align 8 ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 [[MASK:%.*]])
 ; CHECK-NEXT:    [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
 ; CHECK-NEXT:    ret <16 x i8> [[LOAD]]
 ;
@@ -146,7 +146,7 @@ define <16 x i8> @ptrmask_align_unknown_ptr_align8(ptr align 8 %ptr, i64 %mask)
 ; Increase load align from 1 to 2
 define <16 x i8> @ptrmask_align2_ptr_align1(ptr align 1 %ptr) {
 ; CHECK-LABEL: @ptrmask_align2_ptr_align1(
-; CHECK-NEXT:    [[ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -2)
+; CHECK-NEXT:    [[ALIGNED:%.*]] = call align 2 ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -2)
 ; CHECK-NEXT:    [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
 ; CHECK-NEXT:    ret <16 x i8> [[LOAD]]
 ;
@@ -158,7 +158,7 @@ define <16 x i8> @ptrmask_align2_ptr_align1(ptr align 1 %ptr) {
 ; Increase load align from 1 to 4
 define <16 x i8> @ptrmask_align4_ptr_align1(ptr align 1 %ptr) {
 ; CHECK-LABEL: @ptrmask_align4_ptr_align1(
-; CHECK-NEXT:    [[ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -4)
+; CHECK-NEXT:    [[ALIGNED:%.*]] = call align 4 ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -4)
 ; CHECK-NEXT:    [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
 ; CHECK-NEXT:    ret <16 x i8> [[LOAD]]
 ;
@@ -170,7 +170,7 @@ define <16 x i8> @ptrmask_align4_ptr_align1(ptr align 1 %ptr) {
 ; Increase load align from 1 to 8
 define <16 x i8> @ptrmask_align8_ptr_align1(ptr align 1 %ptr) {
 ; CHECK-LABEL: @ptrmask_align8_ptr_align1(
-; CHECK-NEXT:    [[ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -8)
+; CHECK-NEXT:    [[ALIGNED:%.*]] = call align 8 ptr @llvm.ptrmask.p0.i64(ptr [[PTR:%.*]], i64 -8)
 ; CHECK-NEXT:    [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
 ; CHECK-NEXT:    ret <16 x i8> [[LOAD]]
 ;
@@ -207,7 +207,7 @@ define <16 x i8> @ptrmask_align8_ptr_align16(ptr align 16 %ptr) {
 ; than the pointer size.
 define <16 x i8> @ptrmask_align8_ptr_align1_smallmask(ptr align 1 %ptr) {
 ; CHECK-LABEL: @ptrmask_align8_ptr_align1_smallmask(
-; CHECK-NEXT:    [[ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr [[PTR:%.*]], i32 -8)
+; CHECK-NEXT:    [[ALIGNED:%.*]] = call align 8 ptr @llvm.ptrmask.p0.i32(ptr [[PTR:%.*]], i32 -8)
 ; CHECK-NEXT:    [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
 ; CHECK-NEXT:    ret <16 x i8> [[LOAD]]
 ;
@@ -220,7 +220,7 @@ define <16 x i8> @ptrmask_align8_ptr_align1_smallmask(ptr align 1 %ptr) {
 ; than the pointer size.
 define <16 x i8> @ptrmask_align8_ptr_align1_bigmask(ptr align 1 %ptr) {
 ; CHECK-LABEL: @ptrmask_align8_ptr_align1_bigmask(
-; CHECK-NEXT:    [[ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i128(ptr [[PTR:%.*]], i128 -8)
+; CHECK-NEXT:    [[ALIGNED:%.*]] = call align 8 ptr @llvm.ptrmask.p0.i128(ptr [[PTR:%.*]], i128 -8)
 ; CHECK-NEXT:    [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
 ; CHECK-NEXT:    ret <16 x i8> [[LOAD]]
 ;
diff --git a/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll b/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
index adf044da92d2bdf..a4f75774baad7f9 100644
--- a/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
+++ b/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
@@ -70,8 +70,8 @@ define ptr @fold_2x_fail_type_mismatch2(ptr %p, i64 %m0, i32 %m1) {
 define ptr @fold_2x_type_mismatch_const0(ptr %p, i32 %m1) {
 ; CHECK-LABEL: define ptr @fold_2x_type_mismatch_const0
 ; CHECK-SAME: (ptr [[P:%.*]], i32 [[M1:%.*]]) {
-; CHECK-NEXT:    [[P0:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -128)
-; CHECK-NEXT:    [[P1:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr [[P0]], i32 [[M1]])
+; CHECK-NEXT:    [[P0:%.*]] = call align 128 ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -128)
+; CHECK-NEXT:    [[P1:%.*]] = call align 128 ptr @llvm.ptrmask.p0.i32(ptr [[P0]], i32 [[M1]])
 ; CHECK-NEXT:    ret ptr [[P1]]
 ;
   %p0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -128)
@@ -83,7 +83,7 @@ define ptr @fold_2x_type_mismatch_const1(ptr %p, i64 %m0) {
 ; CHECK-LABEL: define ptr @fold_2x_type_mismatch_const1
 ; CHECK-SAME: (ptr [[P:%.*]], i64 [[M0:%.*]]) {
 ; CHECK-NEXT:    [[P0:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 [[M0]])
-; CHECK-NEXT:    [[P1:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr [[P0]], i32 -2)
+; CHECK-NEXT:    [[P1:%.*]] = call align 2 ptr @llvm.ptrmask.p0.i32(ptr [[P0]], i32 -2)
 ; CHECK-NEXT:    ret ptr [[P1]]
 ;
   %p0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %m0)
@@ -95,8 +95,8 @@ define ptr @fold_2x_type_mismatch_const1(ptr %p, i64 %m0) {
 define ptr @fold_2x_type_mismatch_const2(ptr %p) {
 ; CHECK-LABEL: define ptr @fold_2x_type_mismatch_const2
 ; CHECK-SAME: (ptr [[P:%.*]]) {
-; CHECK-NEXT:    [[P0:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -4)
-; CHECK-NEXT:    [[P1:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 -31)
+; CHECK-NEXT:    [[P0:%.*]] = call align 4 ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -4)
+; CHECK-NEXT:    [[P1:%.*]] = call align 32 ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 -31)
 ; CHECK-NEXT:    ret ptr [[P1]]
 ;
   %p0 = call ptr @llvm.ptrmask.p0.i32(ptr %p, i32 -4)
diff --git a/llvm/test/Transforms/InstCombine/ptrmask.ll b/llvm/test/Transforms/InstCombine/ptrmask.ll
index b4ea72b8fb8b65d..2fd055650a863b4 100644
--- a/llvm/test/Transforms/InstCombine/ptrmask.ll
+++ b/llvm/test/Transforms/InstCombine/ptrmask.ll
@@ -8,7 +8,7 @@ define ptr @ptrmask_combine_consecutive_preserve_attrs(ptr %p0, i64 %m1) {
 ; CHECK-LABEL: define ptr @ptrmask_combine_consecutive_preserve_attrs
 ; CHECK-SAME: (ptr [[P0:%.*]], i64 [[M1:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i64 [[M1]], 224
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[TMP1]])
+; CHECK-NEXT:    [[R:%.*]] = call align 32 ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[TMP1]])
 ; CHECK-NEXT:    ret ptr [[R]]
 ;
   %pm0 = call ptr @llvm.ptrmask.p0.i64(ptr %p0, i64 224)
@@ -31,7 +31,7 @@ define ptr @ptrmask_combine_consecutive_preserve_attrs_fail(ptr %p0, i64 %m0) {
 define ptr @ptrmask_combine_consecutive_preserve_attrs_todo0(ptr %p0) {
 ; CHECK-LABEL: define ptr @ptrmask_combine_consecutive_preserve_attrs_todo0
 ; CHECK-SAME: (ptr [[P0:%.*]]) {
-; CHECK-NEXT:    [[PM0:%.*]] = call noalias ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 224)
+; CHECK-NEXT:    [[PM0:%.*]] = call noalias align 32 ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 224)
 ; CHECK-NEXT:    ret ptr [[PM0]]
 ;
   %pm0 = call noalias ptr @llvm.ptrmask.p0.i64(ptr %p0, i64 224)
@@ -42,7 +42,7 @@ define ptr @ptrmask_combine_consecutive_preserve_attrs_todo0(ptr %p0) {
 define ptr @ptrmask_combine_consecutive_preserve_attrs_todo1(ptr %p0) {
 ; CHECK-LABEL: define ptr @ptrmask_combine_consecutive_preserve_attrs_todo1
 ; CHECK-SAME: (ptr [[P0:%.*]]) {
-; CHECK-NEXT:    [[PM0:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 224)
+; CHECK-NEXT:    [[PM0:%.*]] = call align 32 ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 224)
 ; CHECK-NEXT:    ret ptr [[PM0]]
 ;
   %pm0 = call ptr @llvm.ptrmask.p0.i64(ptr %p0, i64 224)
@@ -53,7 +53,7 @@ define ptr @ptrmask_combine_consecutive_preserve_attrs_todo1(ptr %p0) {
 define ptr @ptrmask_combine_consecutive_preserve_attrs_todo2(ptr %p0) {
 ; CHECK-LABEL: define ptr @ptrmask_combine_consecutive_preserve_attrs_todo2
 ; CHECK-SAME: (ptr [[P0:%.*]]) {
-; CHECK-NEXT:    [[PM0:%.*]] = call noalias ptr @llvm.ptrmask.p0.i32(ptr [[P0]], i32 224)
+; CHECK-NEXT:    [[PM0:%.*]] = call noalias align 32 ptr @llvm.ptrmask.p0.i32(ptr [[P0]], i32 224)
 ; CHECK-NEXT:    ret ptr [[PM0]]
 ;
   %pm0 = call noalias ptr @llvm.ptrmask.p0.i32(ptr %p0, i32 224)
@@ -64,9 +64,9 @@ define ptr @ptrmask_combine_consecutive_preserve_attrs_todo2(ptr %p0) {
 define ptr @ptrmask_combine_add_nonnull(ptr %p) {
 ; CHECK-LABEL: define ptr @ptrmask_combine_add_nonnull
 ; CHECK-SAME: (ptr [[P:%.*]]) {
-; CHECK-NEXT:    [[PM0:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -64)
+; CHECK-NEXT:    [[PM0:%.*]] = call align 64 ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -64)
 ; CHECK-NEXT:    [[PGEP:%.*]] = getelementptr i8, ptr [[PM0]], i64 33
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PGEP]], i64 -16)
+; CHECK-NEXT:    [[R:%.*]] = call nonnull align 32 ptr @llvm.ptrmask.p0.i64(ptr [[PGEP]], i64 -16)
 ; CHECK-NEXT:    ret ptr [[R]]
 ;
   %pm0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -64)
@@ -78,7 +78,7 @@ define ptr @ptrmask_combine_add_nonnull(ptr %p) {
 define ptr @ptrmask_combine_add_alignment(ptr %p) {
 ; CHECK-LABEL: define ptr @ptrmask_combine_add_alignment
 ; CHECK-SAME: (ptr [[P:%.*]]) {
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -64)
+; CHECK-NEXT:    [[R:%.*]] = call align 64 ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -64)
 ; CHECK-NEXT:    ret ptr [[R]]
 ;
   %r = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -64)
@@ -88,7 +88,7 @@ define ptr @ptrmask_combine_add_alignment(ptr %p) {
 define ptr @ptrmask_combine_add_alignment2(ptr align 32 %p) {
 ; CHECK-LABEL: define ptr @ptrmask_combine_add_alignment2
 ; CHECK-SAME: (ptr align 32 [[P:%.*]]) {
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -64)
+; CHECK-NEXT:    [[R:%.*]] = call align 64 ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -64)
 ; CHECK-NEXT:    ret ptr [[R]]
 ;
   %r = call ptr @llvm.ptrmask.p0.i32(ptr %p, i32 -64)
@@ -98,7 +98,7 @@ define ptr @ptrmask_combine_add_alignment2(ptr align 32 %p) {
 define ptr @ptrmask_combine_improve_alignment(ptr %p) {
 ; CHECK-LABEL: define ptr @ptrmask_combine_improve_alignment
 ; CHECK-SAME: (ptr [[P:%.*]]) {
-; CHECK-NEXT:    [[R:%.*]] = call align 32 ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -64)
+; CHECK-NEXT:    [[R:%.*]] = call align 64 ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -64)
 ; CHECK-NEXT:    ret ptr [[R]]
 ;
   %r = call align 32 ptr @llvm.ptrmask.p0.i32(ptr %p, i32 -64)

>From 6d8e8c2d404d8891716e9026776ae72b207bcbb3 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 22 Sep 2023 08:21:48 -0500
Subject: [PATCH 6/9] [InstCombine] Implement `SimplifyDemandedBits` for
 `llvm.ptrmask`

Logic basically copies 'and' but we can't return a constant if the
result == `rhs` (mask) so that case is skipped.

Differential Revision: https://reviews.llvm.org/D156637
---
 .../InstCombine/InstCombineCalls.cpp          | 14 ++--
 .../InstCombine/InstCombineInternal.h         |  1 +
 .../InstCombineSimplifyDemanded.cpp           | 69 +++++++++++++++++--
 .../test/Transforms/InstCombine/align-addr.ll |  2 +-
 .../InstCombine/consecutive-ptrmask.ll        |  2 +-
 llvm/test/Transforms/InstCombine/ptrmask.ll   |  2 +-
 6 files changed, 78 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 58b69bfa8e49b4f..5c1de2af2282299 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1964,6 +1964,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     break;
   }
   case Intrinsic::ptrmask: {
+    KnownBits Known(DL.getPointerTypeSizeInBits(II->getType()));
+    if (SimplifyDemandedInstructionBits(*II, Known))
+      return II;
+
     Value *Op0 = II->getArgOperand(0);
     Value *Op1 = II->getArgOperand(1);
     // Fail loudly in case this is ever changed.
@@ -1985,8 +1989,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
                                     {InnerPtr, NewMask}));
       }
     }
+
     bool Changed = false;
-    KnownBits Known = computeKnownBits(II, /*Depth*/ 0, II);
     // See if we can deduce non-null.
     if (!CI.hasRetAttr(Attribute::NonNull) &&
         (Known.isNonZero() ||
@@ -2000,9 +2004,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     if (Known.countMinTrailingZeros() > Log2(CI.getRetAlign().valueOrOne())) {
       if (CI.hasRetAttr(Attribute::Alignment))
         CI.removeRetAttr(Attribute::Alignment);
-      CI.addRetAttr(
-          Attribute::get(CI.getContext(), Attribute::Alignment,
-                         uint64_t(1) << Known.countMinTrailingZeros()));
+      CI.addRetAttr(Attribute::get(
+          CI.getContext(), Attribute::Alignment,
+          uint64_t(1) << (Known.isZero() ? (Known.getBitWidth() - 1)
+                                         : Known.countMinTrailingZeros())));
+
       Changed = true;
     }
     if (Changed)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 40c24d87bfec508..0b03aa48fa0c0cc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -526,6 +526,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   /// Tries to simplify operands to an integer instruction based on its
   /// demanded bits.
   bool SimplifyDemandedInstructionBits(Instruction &Inst);
+  bool SimplifyDemandedInstructionBits(Instruction &Inst, KnownBits &Known);
 
   Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
                                     APInt &UndefElts, unsigned Depth = 0,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index be005e61a8d2d89..7902eb0c6c829cc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -48,15 +48,20 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
   return true;
 }
 
+/// Returns the bitwidth of the given scalar or pointer type. For vector types,
+/// returns the element type's bitwidth.
+static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
+  if (unsigned BitWidth = Ty->getScalarSizeInBits())
+    return BitWidth;
 
+  return DL.getPointerTypeSizeInBits(Ty);
+}
 
 /// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
 /// the instruction has any properties that allow us to simplify its operands.
-bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
-  unsigned BitWidth = Inst.getType()->getScalarSizeInBits();
-  KnownBits Known(BitWidth);
-  APInt DemandedMask(APInt::getAllOnes(BitWidth));
-
+bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst,
+                                                       KnownBits &Known) {
+  APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth()));
   Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known,
                                      0, &Inst);
   if (!V) return false;
@@ -65,6 +70,13 @@ bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
   return true;
 }
 
+/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
+/// the instruction has any properties that allow us to simplify its operands.
+bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
+  KnownBits Known(getBitWidth(Inst.getType(), DL));
+  return SimplifyDemandedInstructionBits(Inst, Known);
+}
+
 /// This form of SimplifyDemandedBits simplifies the specified instruction
 /// operand if possible, updating it in place. It returns true if it made any
 /// change and false otherwise.
@@ -898,6 +910,53 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
         }
         break;
       }
+      case Intrinsic::ptrmask: {
+        // Fail loudly in case this is ever changed.
+        // Likely not much needs to be changed here to support vector types.
+        assert(!I->getOperand(0)->getType()->isVectorTy() &&
+               !I->getOperand(1)->getType()->isVectorTy() &&
+               "These simplifications where written at a time when ptrmask did "
+               "not support vector types and may not work for vectors");
+
+        unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits();
+        RHSKnown = KnownBits(MaskWidth);
+        // If either the LHS or the RHS are Zero, the result is zero.
+        if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1) ||
+            SimplifyDemandedBits(
+                I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth),
+                RHSKnown, Depth + 1))
+          return I;
+
+        RHSKnown = RHSKnown.zextOrTrunc(BitWidth);
+        assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
+        assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
+
+        Known = LHSKnown & RHSKnown;
+        KnownBitsComputed = DemandedMask.isAllOnes();
+
+        // If the client is only demanding bits we know to be zero, return
+        // `llvm.ptrmask(p, 0)`. We can't return `null` here due to pointer
+        // provenance, but making the mask zero will be easily optimizable in
+        // the backend.
+        if (DemandedMask.isSubsetOf(Known.Zero))
+          return replaceOperand(
+              *I, 1, Constant::getNullValue(I->getOperand(1)->getType()));
+
+        // Mask in demanded space does nothing.
+        // TODO: We may have attributes associated with the return value of the
+        // llvm.ptrmask intrinsic that will be lost when we just return the
+        // operand. We should try to preserve them.
+        if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
+          return I->getOperand(0);
+
+        // If the RHS is a constant, see if we can simplify it.
+        if (ShrinkDemandedConstant(
+                I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth)))
+          return I;
+
+        break;
+      }
+
       case Intrinsic::fshr:
       case Intrinsic::fshl: {
         const APInt *SA;
diff --git a/llvm/test/Transforms/InstCombine/align-addr.ll b/llvm/test/Transforms/InstCombine/align-addr.ll
index dab113ee547f1e5..f3b33f013a9b5d4 100644
--- a/llvm/test/Transforms/InstCombine/align-addr.ll
+++ b/llvm/test/Transforms/InstCombine/align-addr.ll
@@ -220,7 +220,7 @@ define <16 x i8> @ptrmask_align8_ptr_align1_smallmask(ptr align 1 %ptr) {
 ; than the pointer size.
 define <16 x i8> @ptrmask_align8_ptr_align1_bigmask(ptr align 1 %ptr) {
 ; CHECK-LABEL: @ptrmask_align8_ptr_align1_bigmask(
-; CHECK-NEXT:    [[ALIGNED:%.*]] = call align 8 ptr @llvm.ptrmask.p0.i128(ptr [[PTR:%.*]], i128 -8)
+; CHECK-NEXT:    [[ALIGNED:%.*]] = call align 8 ptr @llvm.ptrmask.p0.i128(ptr [[PTR:%.*]], i128 18446744073709551608)
 ; CHECK-NEXT:    [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 1
 ; CHECK-NEXT:    ret <16 x i8> [[LOAD]]
 ;
diff --git a/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll b/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
index a4f75774baad7f9..c5ee06c3902fae1 100644
--- a/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
+++ b/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
@@ -96,7 +96,7 @@ define ptr @fold_2x_type_mismatch_const2(ptr %p) {
 ; CHECK-LABEL: define ptr @fold_2x_type_mismatch_const2
 ; CHECK-SAME: (ptr [[P:%.*]]) {
 ; CHECK-NEXT:    [[P0:%.*]] = call align 4 ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -4)
-; CHECK-NEXT:    [[P1:%.*]] = call align 32 ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 -31)
+; CHECK-NEXT:    [[P1:%.*]] = call align 32 ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 4294967264)
 ; CHECK-NEXT:    ret ptr [[P1]]
 ;
   %p0 = call ptr @llvm.ptrmask.p0.i32(ptr %p, i32 -4)
diff --git a/llvm/test/Transforms/InstCombine/ptrmask.ll b/llvm/test/Transforms/InstCombine/ptrmask.ll
index 2fd055650a863b4..91fbd373f57b372 100644
--- a/llvm/test/Transforms/InstCombine/ptrmask.ll
+++ b/llvm/test/Transforms/InstCombine/ptrmask.ll
@@ -66,7 +66,7 @@ define ptr @ptrmask_combine_add_nonnull(ptr %p) {
 ; CHECK-SAME: (ptr [[P:%.*]]) {
 ; CHECK-NEXT:    [[PM0:%.*]] = call align 64 ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -64)
 ; CHECK-NEXT:    [[PGEP:%.*]] = getelementptr i8, ptr [[PM0]], i64 33
-; CHECK-NEXT:    [[R:%.*]] = call nonnull align 32 ptr @llvm.ptrmask.p0.i64(ptr [[PGEP]], i64 -16)
+; CHECK-NEXT:    [[R:%.*]] = call nonnull align 32 ptr @llvm.ptrmask.p0.i64(ptr [[PGEP]], i64 -32)
 ; CHECK-NEXT:    ret ptr [[R]]
 ;
   %pm0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -64)

>From 8229697553f424b075dcfd5e56704ec99c9939a2 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 22 Sep 2023 08:21:53 -0500
Subject: [PATCH 7/9] [InstCombine] Preserve return attributes when merging
 `llvm.ptrmask`

If we have associated attributes, i.e. `([ret_attrs] (ptrmask (ptrmask
p0, m0), m1))` we should preserve `[ret_attrs]` when combining the two
`llvm.ptrmask`s.

Differential Revision: https://reviews.llvm.org/D156638
---
 .../InstCombine/InstCombineCalls.cpp          | 23 +++++++++++--------
 llvm/test/Transforms/InstCombine/ptrmask.ll   |  2 +-
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 5c1de2af2282299..8fa61e3804933f2 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1978,19 +1978,22 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
            "support vector types and may not work for vectors");
 
     Value *InnerPtr, *InnerMask;
-    if (match(II->getArgOperand(0),
-              m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(InnerPtr),
-                                                       m_Value(InnerMask))))) {
-      if (II->getArgOperand(1)->getType() == InnerMask->getType()) {
-        Value *NewMask = Builder.CreateAnd(II->getArgOperand(1), InnerMask);
-        return replaceInstUsesWith(
-            *II,
-            Builder.CreateIntrinsic(InnerPtr->getType(), Intrinsic::ptrmask,
-                                    {InnerPtr, NewMask}));
+    bool Changed = false;
+    // Combine:
+    // (ptrmask (ptrmask p, A), B)
+    //    -> (ptrmask p, (and A, B))
+    if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(
+                       m_Value(InnerPtr), m_Value(InnerMask))))) {
+      if (InnerMask->getType() == Op1->getType()) {
+        // TODO: If InnerMask == Op1, we could copy attributes from inner
+        // callsite -> outer callsite.
+        Value *NewMask = Builder.CreateAnd(Op1, InnerMask);
+        replaceOperand(CI, 0, InnerPtr);
+        replaceOperand(CI, 1, NewMask);
+        Changed = true;
       }
     }
 
-    bool Changed = false;
     // See if we can deduce non-null.
     if (!CI.hasRetAttr(Attribute::NonNull) &&
         (Known.isNonZero() ||
diff --git a/llvm/test/Transforms/InstCombine/ptrmask.ll b/llvm/test/Transforms/InstCombine/ptrmask.ll
index 91fbd373f57b372..7b4b534b5b2f920 100644
--- a/llvm/test/Transforms/InstCombine/ptrmask.ll
+++ b/llvm/test/Transforms/InstCombine/ptrmask.ll
@@ -8,7 +8,7 @@ define ptr @ptrmask_combine_consecutive_preserve_attrs(ptr %p0, i64 %m1) {
 ; CHECK-LABEL: define ptr @ptrmask_combine_consecutive_preserve_attrs
 ; CHECK-SAME: (ptr [[P0:%.*]], i64 [[M1:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i64 [[M1]], 224
-; CHECK-NEXT:    [[R:%.*]] = call align 32 ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[TMP1]])
+; CHECK-NEXT:    [[R:%.*]] = call noalias align 32 ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[TMP1]])
 ; CHECK-NEXT:    ret ptr [[R]]
 ;
   %pm0 = call ptr @llvm.ptrmask.p0.i64(ptr %p0, i64 224)

>From 00db488c1884a108df26e2580593043b13ccc04e Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 22 Sep 2023 08:21:59 -0500
Subject: [PATCH 8/9] [InstCombine] Merge consecutive `llvm.ptrmask` with
 different mask types if a mask is constant.

We can `zext` / `trunc` a constant mask for free, so if one of the
masks is constant, we can proceed even if they are not the same type.

Differential Revision: https://reviews.llvm.org/D156639
---
 .../InstCombine/InstCombineCalls.cpp          | 20 +++++++++++++++++++
 .../InstCombine/consecutive-ptrmask.ll        | 11 +++++-----
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 8fa61e3804933f2..e0a8b30b34a8ff0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1984,10 +1984,30 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     //    -> (ptrmask p, (and A, B))
     if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(
                        m_Value(InnerPtr), m_Value(InnerMask))))) {
+      // See if combining the two masks is free.
+      bool OkayToMerge = InnerMask->getType() == Op1->getType();
+      bool NeedsNew = false;
+      if (!OkayToMerge) {
+        if (match(InnerMask, m_ImmConstant())) {
+          InnerMask = Builder.CreateZExtOrTrunc(InnerMask, Op1->getType());
+          OkayToMerge = true;
+        } else if (match(Op1, m_ImmConstant())) {
+          Op1 = Builder.CreateZExtOrTrunc(Op1, InnerMask->getType());
+          OkayToMerge = true;
+          // Need to create a new one here, as the intrinsic id needs to change.
+          NeedsNew = true;
+        }
+      }
       if (InnerMask->getType() == Op1->getType()) {
         // TODO: If InnerMask == Op1, we could copy attributes from inner
         // callsite -> outer callsite.
         Value *NewMask = Builder.CreateAnd(Op1, InnerMask);
+        if (NeedsNew)
+          return replaceInstUsesWith(
+              *II,
+              Builder.CreateIntrinsic(InnerPtr->getType(), Intrinsic::ptrmask,
+                                      {InnerPtr, NewMask}));
+
         replaceOperand(CI, 0, InnerPtr);
         replaceOperand(CI, 1, NewMask);
         Changed = true;
diff --git a/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll b/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
index c5ee06c3902fae1..0723fe2fd4239c4 100644
--- a/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
+++ b/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
@@ -70,8 +70,8 @@ define ptr @fold_2x_fail_type_mismatch2(ptr %p, i64 %m0, i32 %m1) {
 define ptr @fold_2x_type_mismatch_const0(ptr %p, i32 %m1) {
 ; CHECK-LABEL: define ptr @fold_2x_type_mismatch_const0
 ; CHECK-SAME: (ptr [[P:%.*]], i32 [[M1:%.*]]) {
-; CHECK-NEXT:    [[P0:%.*]] = call align 128 ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -128)
-; CHECK-NEXT:    [[P1:%.*]] = call align 128 ptr @llvm.ptrmask.p0.i32(ptr [[P0]], i32 [[M1]])
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[M1]], -128
+; CHECK-NEXT:    [[P1:%.*]] = call align 128 ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 [[TMP1]])
 ; CHECK-NEXT:    ret ptr [[P1]]
 ;
   %p0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -128)
@@ -82,8 +82,8 @@ define ptr @fold_2x_type_mismatch_const0(ptr %p, i32 %m1) {
 define ptr @fold_2x_type_mismatch_const1(ptr %p, i64 %m0) {
 ; CHECK-LABEL: define ptr @fold_2x_type_mismatch_const1
 ; CHECK-SAME: (ptr [[P:%.*]], i64 [[M0:%.*]]) {
-; CHECK-NEXT:    [[P0:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 [[M0]])
-; CHECK-NEXT:    [[P1:%.*]] = call align 2 ptr @llvm.ptrmask.p0.i32(ptr [[P0]], i32 -2)
+; CHECK-NEXT:    [[TMP1:%.*]] = and i64 [[M0]], 4294967294
+; CHECK-NEXT:    [[P1:%.*]] = call align 2 ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 [[TMP1]])
 ; CHECK-NEXT:    ret ptr [[P1]]
 ;
   %p0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %m0)
@@ -95,8 +95,7 @@ define ptr @fold_2x_type_mismatch_const1(ptr %p, i64 %m0) {
 define ptr @fold_2x_type_mismatch_const2(ptr %p) {
 ; CHECK-LABEL: define ptr @fold_2x_type_mismatch_const2
 ; CHECK-SAME: (ptr [[P:%.*]]) {
-; CHECK-NEXT:    [[P0:%.*]] = call align 4 ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -4)
-; CHECK-NEXT:    [[P1:%.*]] = call align 32 ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 4294967264)
+; CHECK-NEXT:    [[P1:%.*]] = call align 32 ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 4294967264)
 ; CHECK-NEXT:    ret ptr [[P1]]
 ;
   %p0 = call ptr @llvm.ptrmask.p0.i32(ptr %p, i32 -4)

>From a6f3bc1d086e953d752504c79ab80d2f51330b8d Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 22 Sep 2023 08:22:01 -0500
Subject: [PATCH 9/9] [InstCombine] Fold `(ptrtoint (ptrmask p0, m0))` -> `(and
 (ptrtoint p0), m0)`

`and` is generally better supported, so if we have a `ptrmask` anyway we
might as well use `and`.

Differential Revision: https://reviews.llvm.org/D156640
---
 llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp |  9 +++++++++
 llvm/test/Transforms/InstCombine/ptrmask.ll          | 10 +++++-----
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 8ff61ab36da307b..508609670b9ac88 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1955,6 +1955,15 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) {
     return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false);
   }
 
+  // (ptrtoint (ptrmask P, M))
+  //    -> (and (ptrtoint P), M)
+  // This is generally beneficial as `and` is better supported than `ptrmask`.
+  Value *Ptr, *Mask;
+  if (match(SrcOp, m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(Ptr),
+                                                            m_Value(Mask)))) &&
+      Mask->getType() == Ty)
+    return BinaryOperator::CreateAnd(Builder.CreatePtrToInt(Ptr, Ty), Mask);
+
   if (auto *GEP = dyn_cast<GetElementPtrInst>(SrcOp)) {
     // Fold ptrtoint(gep null, x) to multiply + constant if the GEP has one use.
     // While this can increase the number of instructions it doesn't actually
diff --git a/llvm/test/Transforms/InstCombine/ptrmask.ll b/llvm/test/Transforms/InstCombine/ptrmask.ll
index 7b4b534b5b2f920..f4839f546f82746 100644
--- a/llvm/test/Transforms/InstCombine/ptrmask.ll
+++ b/llvm/test/Transforms/InstCombine/ptrmask.ll
@@ -118,8 +118,8 @@ define ptr @ptrmask_combine_improve_alignment_fail(ptr %p) {
 define i64 @ptrtoint_of_ptrmask(ptr %p, i64 %m) {
 ; CHECK-LABEL: define i64 @ptrtoint_of_ptrmask
 ; CHECK-SAME: (ptr [[P:%.*]], i64 [[M:%.*]]) {
-; CHECK-NEXT:    [[PM:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 [[M]])
-; CHECK-NEXT:    [[R:%.*]] = ptrtoint ptr [[PM]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = ptrtoint ptr [[P]] to i64
+; CHECK-NEXT:    [[R:%.*]] = and i64 [[TMP1]], [[M]]
 ; CHECK-NEXT:    ret i64 [[R]]
 ;
   %pm = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %m)
@@ -131,9 +131,9 @@ define i64 @ptrtoint_of_ptrmask(ptr %p, i64 %m) {
 define i32 @ptrtoint_of_ptrmask2(ptr %p, i64 %m) {
 ; CHECK-LABEL: define i32 @ptrtoint_of_ptrmask2
 ; CHECK-SAME: (ptr [[P:%.*]], i64 [[M:%.*]]) {
-; CHECK-NEXT:    [[PM:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 [[M]])
-; CHECK-NEXT:    [[TMP1:%.*]] = ptrtoint ptr [[PM]] to i64
-; CHECK-NEXT:    [[R:%.*]] = trunc i64 [[TMP1]] to i32
+; CHECK-NEXT:    [[TMP1:%.*]] = ptrtoint ptr [[P]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = and i64 [[TMP1]], [[M]]
+; CHECK-NEXT:    [[R:%.*]] = trunc i64 [[TMP2]] to i32
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %pm = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %m)



More information about the cfe-commits mailing list