[llvm] edb9e9a - [InstCombine] Implement `SimplifyDemandedBits` for `llvm.ptrmask`
Noah Goldstein via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 1 21:50:54 PDT 2023
Author: Noah Goldstein
Date: 2023-11-01T23:50:35-05:00
New Revision: edb9e9a5fb3cae861511f9a11e6d000e8a82fe91
URL: https://github.com/llvm/llvm-project/commit/edb9e9a5fb3cae861511f9a11e6d000e8a82fe91
DIFF: https://github.com/llvm/llvm-project/commit/edb9e9a5fb3cae861511f9a11e6d000e8a82fe91.diff
LOG: [InstCombine] Implement `SimplifyDemandedBits` for `llvm.ptrmask`
The logic largely mirrors the 'and' case, but since the result is a pointer we
can't return a constant when the result would equal the RHS (the mask), so that
case is skipped.
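
For instance, with this change a mask that only demands bits already known to
be zero in the pointer is shrunk to zero instead of folding the whole call away
(we can't return null because of pointer provenance). An illustrative IR
sketch, adapted from the ptrmask_is_null test updated below (the function name
@example is made up):

    declare ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1), i32)

    define ptr addrspace(1) @example(ptr addrspace(1) align 32 %p) {
      ; Before: the i32 31 mask only covers bits known zero from the align 32.
      %r = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) %p, i32 31)
      ; After this patch the mask operand is shrunk to i32 0 rather than the
      ; call folding to null:
      ;   %r = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) %p, i32 0)
      ret ptr addrspace(1) %r
    }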
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
llvm/lib/Transforms/InstCombine/InstCombineInternal.h
llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
llvm/test/Transforms/InstCombine/ptrmask.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 7e585e9166247cd..5c08ab190eba476 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1962,6 +1962,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
break;
}
case Intrinsic::ptrmask: {
+ unsigned BitWidth = DL.getPointerTypeSizeInBits(II->getType());
+ KnownBits Known(BitWidth);
+ if (SimplifyDemandedInstructionBits(*II, Known))
+ return II;
+
Value *InnerPtr, *InnerMask;
if (match(II->getArgOperand(0),
m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(InnerPtr),
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 01c89ea06f2d9df..34b10220ec88aba 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -544,6 +544,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
/// Tries to simplify operands to an integer instruction based on its
/// demanded bits.
bool SimplifyDemandedInstructionBits(Instruction &Inst);
+ bool SimplifyDemandedInstructionBits(Instruction &Inst, KnownBits &Known);
Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
APInt &UndefElts, unsigned Depth = 0,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index cd6b017874e8d6c..fa6fe9d30abd1b8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -48,15 +48,20 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
return true;
}
+/// Returns the bitwidth of the given scalar or pointer type. For vector types,
+/// returns the element type's bitwidth.
+static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
+ if (unsigned BitWidth = Ty->getScalarSizeInBits())
+ return BitWidth;
+ return DL.getPointerTypeSizeInBits(Ty);
+}
/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
/// the instruction has any properties that allow us to simplify its operands.
-bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
- unsigned BitWidth = Inst.getType()->getScalarSizeInBits();
- KnownBits Known(BitWidth);
- APInt DemandedMask(APInt::getAllOnes(BitWidth));
-
+bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst,
+ KnownBits &Known) {
+ APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth()));
Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known,
0, &Inst);
if (!V) return false;
@@ -65,6 +70,13 @@ bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
return true;
}
+/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
+/// the instruction has any properties that allow us to simplify its operands.
+bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
+ KnownBits Known(getBitWidth(Inst.getType(), DL));
+ return SimplifyDemandedInstructionBits(Inst, Known);
+}
+
/// This form of SimplifyDemandedBits simplifies the specified instruction
/// operand if possible, updating it in place. It returns true if it made any
/// change and false otherwise.
@@ -143,7 +155,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return SimplifyMultipleUseDemandedBits(I, DemandedMask, Known, Depth, CxtI);
KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth);
-
// If this is the root being simplified, allow it to have multiple uses,
// just set the DemandedMask to all bits so that we can try to simplify the
// operands. This allows visitTruncInst (for example) to simplify the
@@ -893,6 +904,48 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
break;
}
+ case Intrinsic::ptrmask: {
+ unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits();
+ RHSKnown = KnownBits(MaskWidth);
+ // If either the LHS or the RHS are Zero, the result is zero.
+ if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1) ||
+ SimplifyDemandedBits(
+ I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth),
+ RHSKnown, Depth + 1))
+ return I;
+
+ // TODO: Should be 1-extend
+ RHSKnown = RHSKnown.anyextOrTrunc(BitWidth);
+ assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
+ assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
+
+ Known = LHSKnown & RHSKnown;
+ KnownBitsComputed = true;
+
+ // If the client is only demanding bits we know to be zero, return
+ // `llvm.ptrmask(p, 0)`. We can't return `null` here due to pointer
+ // provenance, but making the mask zero will be easily optimizable in
+ // the backend.
+ if (DemandedMask.isSubsetOf(Known.Zero) &&
+ !match(I->getOperand(1), m_Zero()))
+ return replaceOperand(
+ *I, 1, Constant::getNullValue(I->getOperand(1)->getType()));
+
+ // Mask in demanded space does nothing.
+ // NOTE: We may have attributes associated with the return value of the
+ // llvm.ptrmask intrinsic that will be lost when we just return the
+ // operand. We should try to preserve them.
+ if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
+ return I->getOperand(0);
+
+ // If the RHS is a constant, see if we can simplify it.
+ if (ShrinkDemandedConstant(
+ I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth)))
+ return I;
+
+ break;
+ }
+
case Intrinsic::fshr:
case Intrinsic::fshl: {
const APInt *SA;
@@ -978,8 +1031,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
// If the client is only demanding bits that we know, return the known
- // constant.
- if (DemandedMask.isSubsetOf(Known.Zero|Known.One))
+ // constant. We can't directly simplify pointers as a constant because of
+ // pointer provenance.
+ // TODO: We could return `(inttoptr const)` for pointers.
+ if (!V->getType()->isPointerTy() && DemandedMask.isSubsetOf(Known.Zero | Known.One))
return Constant::getIntegerValue(VTy, Known.One);
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/ptrmask.ll b/llvm/test/Transforms/InstCombine/ptrmask.ll
index 222f6e9ae9c1cc4..b036b8bcf14cddc 100644
--- a/llvm/test/Transforms/InstCombine/ptrmask.ll
+++ b/llvm/test/Transforms/InstCombine/ptrmask.ll
@@ -82,7 +82,7 @@ define ptr @ptrmask_combine_add_nonnull(ptr %p) {
; CHECK-SAME: (ptr [[P:%.*]]) {
; CHECK-NEXT: [[PM0:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -64)
; CHECK-NEXT: [[PGEP:%.*]] = getelementptr i8, ptr [[PM0]], i64 33
-; CHECK-NEXT: [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PGEP]], i64 -16)
+; CHECK-NEXT: [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[PGEP]], i64 -32)
; CHECK-NEXT: ret ptr [[R]]
;
%pm0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -64)
@@ -230,7 +230,7 @@ define <2 x i32> @ptrtoint_of_ptrmask_vec_fail(<2 x ptr addrspace(1) > %p, <2 x
define ptr addrspace(1) @ptrmask_is_null(ptr addrspace(1) align 32 %p) {
; CHECK-LABEL: define ptr addrspace(1) @ptrmask_is_null
; CHECK-SAME: (ptr addrspace(1) align 32 [[P:%.*]]) {
-; CHECK-NEXT: [[R:%.*]] = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) [[P]], i32 31)
+; CHECK-NEXT: [[R:%.*]] = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) [[P]], i32 0)
; CHECK-NEXT: ret ptr addrspace(1) [[R]]
;
%r = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) %p, i32 31)
@@ -250,7 +250,7 @@ define <2 x ptr addrspace(1) > @ptrmask_is_null_vec(<2 x ptr addrspace(1) > alig
define ptr addrspace(1) @ptrmask_is_null_fail(ptr addrspace(1) align 16 %p) {
; CHECK-LABEL: define ptr addrspace(1) @ptrmask_is_null_fail
; CHECK-SAME: (ptr addrspace(1) align 16 [[P:%.*]]) {
-; CHECK-NEXT: [[R:%.*]] = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) [[P]], i32 31)
+; CHECK-NEXT: [[R:%.*]] = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) [[P]], i32 16)
; CHECK-NEXT: ret ptr addrspace(1) [[R]]
;
%r = call ptr addrspace(1) @llvm.ptrmask.p1.i32(ptr addrspace(1) %p, i32 31)
@@ -290,10 +290,9 @@ define ptr addrspace(1) @ptrmask_maintain_provenance_i32(ptr addrspace(1) %p0) {
define ptr @ptrmask_is_useless0(i64 %i, i64 %m) {
; CHECK-LABEL: define ptr @ptrmask_is_useless0
; CHECK-SAME: (i64 [[I:%.*]], i64 [[M:%.*]]) {
-; CHECK-NEXT: [[M0:%.*]] = and i64 [[M]], -4
; CHECK-NEXT: [[I0:%.*]] = and i64 [[I]], -4
; CHECK-NEXT: [[P0:%.*]] = inttoptr i64 [[I0]] to ptr
-; CHECK-NEXT: [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M0]])
+; CHECK-NEXT: [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M]])
; CHECK-NEXT: ret ptr [[R]]
;
%m0 = and i64 %m, -4
@@ -306,10 +305,9 @@ define ptr @ptrmask_is_useless0(i64 %i, i64 %m) {
define ptr @ptrmask_is_useless1(i64 %i, i64 %m) {
; CHECK-LABEL: define ptr @ptrmask_is_useless1
; CHECK-SAME: (i64 [[I:%.*]], i64 [[M:%.*]]) {
-; CHECK-NEXT: [[M0:%.*]] = and i64 [[M]], -4
; CHECK-NEXT: [[I0:%.*]] = and i64 [[I]], -8
; CHECK-NEXT: [[P0:%.*]] = inttoptr i64 [[I0]] to ptr
-; CHECK-NEXT: [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M0]])
+; CHECK-NEXT: [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M]])
; CHECK-NEXT: ret ptr [[R]]
;
%m0 = and i64 %m, -4
@@ -322,10 +320,9 @@ define ptr @ptrmask_is_useless1(i64 %i, i64 %m) {
define ptr @ptrmask_is_useless2(i64 %i, i64 %m) {
; CHECK-LABEL: define ptr @ptrmask_is_useless2
; CHECK-SAME: (i64 [[I:%.*]], i64 [[M:%.*]]) {
-; CHECK-NEXT: [[M0:%.*]] = and i64 [[M]], 127
; CHECK-NEXT: [[I0:%.*]] = and i64 [[I]], 31
; CHECK-NEXT: [[P0:%.*]] = inttoptr i64 [[I0]] to ptr
-; CHECK-NEXT: [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M0]])
+; CHECK-NEXT: [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M]])
; CHECK-NEXT: ret ptr [[R]]
;
%m0 = and i64 %m, 127
@@ -338,10 +335,9 @@ define ptr @ptrmask_is_useless2(i64 %i, i64 %m) {
define ptr @ptrmask_is_useless3(i64 %i, i64 %m) {
; CHECK-LABEL: define ptr @ptrmask_is_useless3
; CHECK-SAME: (i64 [[I:%.*]], i64 [[M:%.*]]) {
-; CHECK-NEXT: [[M0:%.*]] = and i64 [[M]], 127
; CHECK-NEXT: [[I0:%.*]] = and i64 [[I]], 127
; CHECK-NEXT: [[P0:%.*]] = inttoptr i64 [[I0]] to ptr
-; CHECK-NEXT: [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M0]])
+; CHECK-NEXT: [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M]])
; CHECK-NEXT: ret ptr [[R]]
;
%m0 = and i64 %m, 127
@@ -354,11 +350,9 @@ define ptr @ptrmask_is_useless3(i64 %i, i64 %m) {
define ptr @ptrmask_is_useless4(i64 %i, i64 %m) {
; CHECK-LABEL: define ptr @ptrmask_is_useless4
; CHECK-SAME: (i64 [[I:%.*]], i64 [[M:%.*]]) {
-; CHECK-NEXT: [[M0:%.*]] = or i64 [[M]], -4
; CHECK-NEXT: [[I0:%.*]] = and i64 [[I]], -4
; CHECK-NEXT: [[P0:%.*]] = inttoptr i64 [[I0]] to ptr
-; CHECK-NEXT: [[R:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 [[M0]])
-; CHECK-NEXT: ret ptr [[R]]
+; CHECK-NEXT: ret ptr [[P0]]
;
%m0 = or i64 %m, -4
%i0 = and i64 %i, -4
@@ -370,10 +364,9 @@ define ptr @ptrmask_is_useless4(i64 %i, i64 %m) {
define <2 x ptr> @ptrmask_is_useless_vec(<2 x i64> %i, <2 x i64> %m) {
; CHECK-LABEL: define <2 x ptr> @ptrmask_is_useless_vec
; CHECK-SAME: (<2 x i64> [[I:%.*]], <2 x i64> [[M:%.*]]) {
-; CHECK-NEXT: [[M0:%.*]] = and <2 x i64> [[M]], <i64 127, i64 127>
; CHECK-NEXT: [[I0:%.*]] = and <2 x i64> [[I]], <i64 31, i64 31>
; CHECK-NEXT: [[P0:%.*]] = inttoptr <2 x i64> [[I0]] to <2 x ptr>
-; CHECK-NEXT: [[R:%.*]] = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> [[P0]], <2 x i64> [[M0]])
+; CHECK-NEXT: [[R:%.*]] = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> [[P0]], <2 x i64> [[M]])
; CHECK-NEXT: ret <2 x ptr> [[R]]
;
%m0 = and <2 x i64> %m, <i64 127, i64 127>
@@ -386,10 +379,9 @@ define <2 x ptr> @ptrmask_is_useless_vec(<2 x i64> %i, <2 x i64> %m) {
define <2 x ptr> @ptrmask_is_useless_vec_todo(<2 x i64> %i, <2 x i64> %m) {
; CHECK-LABEL: define <2 x ptr> @ptrmask_is_useless_vec_todo
; CHECK-SAME: (<2 x i64> [[I:%.*]], <2 x i64> [[M:%.*]]) {
-; CHECK-NEXT: [[M0:%.*]] = and <2 x i64> [[M]], <i64 127, i64 127>
; CHECK-NEXT: [[I0:%.*]] = and <2 x i64> [[I]], <i64 31, i64 127>
; CHECK-NEXT: [[P0:%.*]] = inttoptr <2 x i64> [[I0]] to <2 x ptr>
-; CHECK-NEXT: [[R:%.*]] = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> [[P0]], <2 x i64> [[M0]])
+; CHECK-NEXT: [[R:%.*]] = call <2 x ptr> @llvm.ptrmask.v2p0.v2i64(<2 x ptr> [[P0]], <2 x i64> [[M]])
; CHECK-NEXT: ret <2 x ptr> [[R]]
;
%m0 = and <2 x i64> %m, <i64 127, i64 127>