[llvm] edb9e9a - [InstCombine] Implement `SimplifyDemandedBits` for `llvm.ptrmask`

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 1 21:50:54 PDT 2023


Author: Noah Goldstein
Date: 2023-11-01T23:50:35-05:00
New Revision: edb9e9a5fb3cae861511f9a11e6d000e8a82fe91

URL: https://github.com/llvm/llvm-project/commit/edb9e9a5fb3cae861511f9a11e6d000e8a82fe91
DIFF: https://github.com/llvm/llvm-project/commit/edb9e9a5fb3cae861511f9a11e6d000e8a82fe91.diff

LOG: [InstCombine] Implement `SimplifyDemandedBits` for `llvm.ptrmask`

The logic essentially mirrors that of 'and', but we can't return a constant when
the result equals the RHS (the mask), so that case is skipped.

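For illustration (mirroring the ptrmask_is_useless0 test updated below), the kind
of rewrite this enables: the low two bits of %p0 are already known to be zero, so
they are not demanded from the mask and the 'and' on the mask can be dropped:

    ; Before
    %m0 = and i64 %m, -4
    %i0 = and i64 %i, -4
    %p0 = inttoptr i64 %i0 to ptr
    %r  = call ptr @llvm.ptrmask.p0.i64(ptr %p0, i64 %m0)

    ; After
    %i0 = and i64 %i, -4
    %p0 = inttoptr i64 %i0 to ptr
    %r  = call ptr @llvm.ptrmask.p0.i64(ptr %p0, i64 %m)
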
Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
    llvm/test/Transforms/InstCombine/ptrmask.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 7e585e9166247cd..5c08ab190eba476 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1962,6 +1962,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     break;
   }
   case Intrinsic::ptrmask: {
+    unsigned BitWidth = DL.getPointerTypeSizeInBits(II->getType());
+    KnownBits Known(BitWidth);
+    if (SimplifyDemandedInstructionBits(*II, Known))
+      return II;
+
     Value *InnerPtr, *InnerMask;
     if (match(II->getArgOperand(0),
               m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(InnerPtr),

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 01c89ea06f2d9df..34b10220ec88aba 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -544,6 +544,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   /// Tries to simplify operands to an integer instruction based on its
   /// demanded bits.
   bool SimplifyDemandedInstructionBits(Instruction &Inst);
+  bool SimplifyDemandedInstructionBits(Instruction &Inst, KnownBits &Known);
 
   Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
                                     APInt &UndefElts, unsigned Depth = 0,

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index cd6b017874e8d6c..fa6fe9d30abd1b8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -48,15 +48,20 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
   return true;
 }
 
+/// Returns the bitwidth of the given scalar or pointer type. For vector types,
+/// returns the element type's bitwidth.
+static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
+  if (unsigned BitWidth = Ty->getScalarSizeInBits())
+    return BitWidth;
 
+  return DL.getPointerTypeSizeInBits(Ty);
+}
 
 /// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
 /// the instruction has any properties that allow us to simplify its operands.
-bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
-  unsigned BitWidth = Inst.getType()->getScalarSizeInBits();
-  KnownBits Known(BitWidth);
-  APInt DemandedMask(APInt::getAllOnes(BitWidth));
-
+bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst,
+                                                       KnownBits &Known) {
+  APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth()));
   Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known,
                                      0, &Inst);
   if (!V) return false;
@@ -65,6 +70,13 @@ bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
   return true;
 }
 
+/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
+/// the instruction has any properties that allow us to simplify its operands.
+bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
+  KnownBits Known(getBitWidth(Inst.getType(), DL));
+  return SimplifyDemandedInstructionBits(Inst, Known);
+}
+
 /// This form of SimplifyDemandedBits simplifies the specified instruction
 /// operand if possible, updating it in place. It returns true if it made any
 /// change and false otherwise.
@@ -143,7 +155,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     return SimplifyMultipleUseDemandedBits(I, DemandedMask, Known, Depth, CxtI);
 
   KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth);
-
   // If this is the root being simplified, allow it to have multiple uses,
   // just set the DemandedMask to all bits so that we can try to simplify the
   // operands.  This allows visitTruncInst (for example) to simplify the
@@ -893,6 +904,48 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
         }
         break;
       }
+      case Intrinsic::ptrmask: {
+        unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits();
+        RHSKnown = KnownBits(MaskWidth);
+        // If either the LHS or the RHS are Zero, the result is zero.
+        if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1) ||
+            SimplifyDemandedBits(
+                I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth),
+                RHSKnown, Depth + 1))
+          return I;
+
+        // TODO: Should be 1-extend
+        RHSKnown = RHSKnown.anyextOrTrunc(BitWidth);
+        assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
+        assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
+
+        Known = LHSKnown & RHSKnown;
+        KnownBitsComputed = true;
+
+        // If the client is only demanding bits we know to be zero, return
+        // `llvm.ptrmask(p, 0)`. We can't return `null` here due to pointer
+        // provenance, but making the mask zero will be easily optimizable in
+        // the backend.
+        if (DemandedMask.isSubsetOf(Known.Zero) &&
+            !match(I->getOperand(1), m_Zero()))
+          return replaceOperand(
+              *I, 1, Constant::getNullValue(I->getOperand(1)->getType()));
+
+        // Mask in demanded space does nothing.
+        // NOTE: We may have attributes associated with the return value of the
+        // llvm.ptrmask intrinsic that will be lost when we just return the
+        // operand. We should try to preserve them.
+        if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
+          return I->getOperand(0);
+
+        // If the RHS is a constant, see if we can simplify it.
+        if (ShrinkDemandedConstant(
+                I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth)))
+          return I;
+
+        break;
+      }
+
       case Intrinsic::fshr:
       case Intrinsic::fshl: {
         const APInt *SA;
@@ -978,8 +1031,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
   }
 
   // If the client is only demanding bits that we know, return the known
-  // constant.
-  if (DemandedMask.isSubsetOf(Known.Zero|Known.One))
+  // constant. We can't directly simplify pointers as a constant because of
+  // pointer provenance.
+  // TODO: We could return `(inttoptr const)` for pointers.
+  if (!V->getType()->isPointerTy() && DemandedMask.isSubsetOf(Known.Zero | Known.One))
     return Constant::getIntegerValue(VTy, Known.One);
   return nullptr;
 }

diff --git a/llvm/test/Transforms/InstCombine/ptrmask.ll b/llvm/test/Transforms/InstCombine/ptrmask.ll
index 222f6e9ae9c1cc4..b036b8bcf14cddc 100644
--- a/llvm/test/Transforms/InstCombine/ptrmask.ll
+++ b/llvm/test/Transforms/InstCombine/ptrmask.ll
@@ -82,7 +82,7 @@ define ptr @ptrmask_combine_add_nonnull(ptr %p) {
 ; CHECK-SAME: (ptr [[P:%.*]]) {
 ; CHECK-NEXT:    [[PM0:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -64)
 ; CHECK-NEXT:    [[PGEP:%.*]] = getelementptr i8, ptr [[PM0]], i64 33
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PGEP]], i64 -16)
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PGEP]], i64 -32)
 ; CHECK-NEXT:    ret ptr [[R]]
 ;
   %pm0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -64)
@@ -230,7 +230,7 @@ define <2 x i32> @ptrtoint_of_ptrmask_vec_fail(<2 x ptr addrspace(1) > %p, <2 x
 define ptr addrspace(1) @ptrmask_is_null(ptr addrspace(1) align 32 %p) {
 ; CHECK-LABEL: define ptr addrspace(1) @ptrmask_is_null
 ; CHECK-SAME: (ptr addrspace(1) align 32 [[P:%.*]]) {
-; CHECK-NEXT:    [[R:%.*]] = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) [[P]], i32 31)
+; CHECK-NEXT:    [[R:%.*]] = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) [[P]], i32 0)
 ; CHECK-NEXT:    ret ptr addrspace(1) [[R]]
 ;
   %r = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) %p, i32 31)
@@ -250,7 +250,7 @@ define <2 x ptr addrspace(1) > @ptrmask_is_null_vec(<2 x ptr addrspace(1) > alig
 define ptr addrspace(1) @ptrmask_is_null_fail(ptr addrspace(1) align 16 %p) {
 ; CHECK-LABEL: define ptr addrspace(1) @ptrmask_is_null_fail
 ; CHECK-SAME: (ptr addrspace(1) align 16 [[P:%.*]]) {
-; CHECK-NEXT:    [[R:%.*]] = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) [[P]], i32 31)
+; CHECK-NEXT:    [[R:%.*]] = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) [[P]], i32 16)
 ; CHECK-NEXT:    ret ptr addrspace(1) [[R]]
 ;
   %r = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) %p, i32 31)
@@ -290,10 +290,9 @@ define ptr addrspace(1) @ptrmask_maintain_provenance_i32(ptr addrspace(1) %p0) {
 define ptr @ptrmask_is_useless0(i64 %i, i64 %m) {
 ; CHECK-LABEL: define ptr @ptrmask_is_useless0
 ; CHECK-SAME: (i64 [[I:%.*]], i64 [[M:%.*]]) {
-; CHECK-NEXT:    [[M0:%.*]] = and i64 [[M]], -4
 ; CHECK-NEXT:    [[I0:%.*]] = and i64 [[I]], -4
 ; CHECK-NEXT:    [[P0:%.*]] = inttoptr i64 [[I0]] to ptr
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M0]])
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M]])
 ; CHECK-NEXT:    ret ptr [[R]]
 ;
   %m0 = and i64 %m, -4
@@ -306,10 +305,9 @@ define ptr @ptrmask_is_useless0(i64 %i, i64 %m) {
 define ptr @ptrmask_is_useless1(i64 %i, i64 %m) {
 ; CHECK-LABEL: define ptr @ptrmask_is_useless1
 ; CHECK-SAME: (i64 [[I:%.*]], i64 [[M:%.*]]) {
-; CHECK-NEXT:    [[M0:%.*]] = and i64 [[M]], -4
 ; CHECK-NEXT:    [[I0:%.*]] = and i64 [[I]], -8
 ; CHECK-NEXT:    [[P0:%.*]] = inttoptr i64 [[I0]] to ptr
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M0]])
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M]])
 ; CHECK-NEXT:    ret ptr [[R]]
 ;
   %m0 = and i64 %m, -4
@@ -322,10 +320,9 @@ define ptr @ptrmask_is_useless1(i64 %i, i64 %m) {
 define ptr @ptrmask_is_useless2(i64 %i, i64 %m) {
 ; CHECK-LABEL: define ptr @ptrmask_is_useless2
 ; CHECK-SAME: (i64 [[I:%.*]], i64 [[M:%.*]]) {
-; CHECK-NEXT:    [[M0:%.*]] = and i64 [[M]], 127
 ; CHECK-NEXT:    [[I0:%.*]] = and i64 [[I]], 31
 ; CHECK-NEXT:    [[P0:%.*]] = inttoptr i64 [[I0]] to ptr
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M0]])
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M]])
 ; CHECK-NEXT:    ret ptr [[R]]
 ;
   %m0 = and i64 %m, 127
@@ -338,10 +335,9 @@ define ptr @ptrmask_is_useless2(i64 %i, i64 %m) {
 define ptr @ptrmask_is_useless3(i64 %i, i64 %m) {
 ; CHECK-LABEL: define ptr @ptrmask_is_useless3
 ; CHECK-SAME: (i64 [[I:%.*]], i64 [[M:%.*]]) {
-; CHECK-NEXT:    [[M0:%.*]] = and i64 [[M]], 127
 ; CHECK-NEXT:    [[I0:%.*]] = and i64 [[I]], 127
 ; CHECK-NEXT:    [[P0:%.*]] = inttoptr i64 [[I0]] to ptr
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M0]])
+; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M]])
 ; CHECK-NEXT:    ret ptr [[R]]
 ;
   %m0 = and i64 %m, 127
@@ -354,11 +350,9 @@ define ptr @ptrmask_is_useless3(i64 %i, i64 %m) {
 define ptr @ptrmask_is_useless4(i64 %i, i64 %m) {
 ; CHECK-LABEL: define ptr @ptrmask_is_useless4
 ; CHECK-SAME: (i64 [[I:%.*]], i64 [[M:%.*]]) {
-; CHECK-NEXT:    [[M0:%.*]] = or i64 [[M]], -4
 ; CHECK-NEXT:    [[I0:%.*]] = and i64 [[I]], -4
 ; CHECK-NEXT:    [[P0:%.*]] = inttoptr i64 [[I0]] to ptr
-; CHECK-NEXT:    [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M0]])
-; CHECK-NEXT:    ret ptr [[R]]
+; CHECK-NEXT:    ret ptr [[P0]]
 ;
   %m0 = or i64 %m, -4
   %i0 = and i64 %i, -4
@@ -370,10 +364,9 @@ define ptr @ptrmask_is_useless4(i64 %i, i64 %m) {
 define <2 x ptr> @ptrmask_is_useless_vec(<2 x i64> %i, <2 x i64> %m) {
 ; CHECK-LABEL: define <2 x ptr> @ptrmask_is_useless_vec
 ; CHECK-SAME: (<2 x i64> [[I:%.*]], <2 x i64> [[M:%.*]]) {
-; CHECK-NEXT:    [[M0:%.*]] = and <2 x i64> [[M]], <i64 127, i64 127>
 ; CHECK-NEXT:    [[I0:%.*]] = and <2 x i64> [[I]], <i64 31, i64 31>
 ; CHECK-NEXT:    [[P0:%.*]] = inttoptr <2 x i64> [[I0]] to <2 x ptr>
-; CHECK-NEXT:    [[R:%.*]] = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> [[P0]], <2 x i64> [[M0]])
+; CHECK-NEXT:    [[R:%.*]] = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> [[P0]], <2 x i64> [[M]])
 ; CHECK-NEXT:    ret <2 x ptr> [[R]]
 ;
   %m0 = and <2 x i64> %m, <i64 127, i64 127>
@@ -386,10 +379,9 @@ define <2 x ptr> @ptrmask_is_useless_vec(<2 x i64> %i, <2 x i64> %m) {
 define <2 x ptr> @ptrmask_is_useless_vec_todo(<2 x i64> %i, <2 x i64> %m) {
 ; CHECK-LABEL: define <2 x ptr> @ptrmask_is_useless_vec_todo
 ; CHECK-SAME: (<2 x i64> [[I:%.*]], <2 x i64> [[M:%.*]]) {
-; CHECK-NEXT:    [[M0:%.*]] = and <2 x i64> [[M]], <i64 127, i64 127>
 ; CHECK-NEXT:    [[I0:%.*]] = and <2 x i64> [[I]], <i64 31, i64 127>
 ; CHECK-NEXT:    [[P0:%.*]] = inttoptr <2 x i64> [[I0]] to <2 x ptr>
-; CHECK-NEXT:    [[R:%.*]] = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> [[P0]], <2 x i64> [[M0]])
+; CHECK-NEXT:    [[R:%.*]] = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> [[P0]], <2 x i64> [[M]])
 ; CHECK-NEXT:    ret <2 x ptr> [[R]]
 ;
   %m0 = and <2 x i64> %m, <i64 127, i64 127>
