[llvm] perf/goldsteinn/improve smax (PR #88170)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 9 10:57:38 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
Author: None (goldsteinn)
<details>
<summary>Changes</summary>
- **[ValueTracking] Expand `isKnown{Negative,Positive}` APIs; NFC**
- **[ValueTracking] Add tests for improving `isKnownNonZero` of `smax`; NFC**
- **[ValueTracking] improve `isKnownNonZero` precision for `smax`**
---
Full diff: https://github.com/llvm/llvm-project/pull/88170.diff
3 Files Affected:
- (modified) llvm/include/llvm/Analysis/ValueTracking.h (+10)
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+47-19)
- (modified) llvm/test/Transforms/InstSimplify/known-non-zero.ll (+11)
``````````diff
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 3970efba18cc8c..7287a8fb122bbb 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -145,11 +145,21 @@ bool isKnownNonNegative(const Value *V, const SimplifyQuery &SQ,
bool isKnownPositive(const Value *V, const SimplifyQuery &SQ,
unsigned Depth = 0);
+/// Returns true if the given value is known to be positive (i.e. non-negative
+/// and non-zero) for the vector elements selected by DemandedElts.
+bool isKnownPositive(const Value *V, const APInt &DemandedElts,
+                     const SimplifyQuery &SQ, unsigned Depth = 0);
+
/// Returns true if the given value is known be negative (i.e. non-positive
/// and non-zero).
bool isKnownNegative(const Value *V, const SimplifyQuery &DL,
unsigned Depth = 0);
+/// Returns true if the given value is known to be negative (i.e. non-positive
+/// and non-zero) for the vector elements selected by DemandedElts.
+bool isKnownNegative(const Value *V, const APInt &DemandedElts,
+                     const SimplifyQuery &SQ, unsigned Depth = 0);
+
/// Return true if the given values are known to be non-equal when defined.
/// Supports scalar integer types only.
bool isKnownNonEqual(const Value *V1, const Value *V2, const DataLayout &DL,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index ca48cfe7738154..187d781c59e072 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -289,21 +289,62 @@ bool llvm::isKnownNonNegative(const Value *V, const SimplifyQuery &SQ,
   return computeKnownBits(V, Depth, SQ).isNonNegative();
 }
 
-bool llvm::isKnownPositive(const Value *V, const SimplifyQuery &SQ,
-                           unsigned Depth) {
-  if (auto *CI = dyn_cast<ConstantInt>(V))
-    return CI->getValue().isStrictlyPositive();
+/// Returns true if \p V is known to be strictly positive for the lanes in
+/// \p DemandedElts. On return, \p Known always holds the known bits that were
+/// derived for \p V, so callers (e.g. the smin/smax case of
+/// isKnownNonZeroFromOperator) can reuse them even when the answer is false.
+static bool isKnownPositive(const Value *V, const APInt &DemandedElts,
+                            KnownBits &Known, const SimplifyQuery &SQ,
+                            unsigned Depth) {
+  if (auto *CI = dyn_cast<ConstantInt>(V)) {
+    // Populate Known on this fast path too; otherwise callers that inspect it
+    // would see no known bits for a non-positive constant.
+    Known = KnownBits::makeConstant(CI->getValue());
+    return CI->getValue().isStrictlyPositive();
+  }
 
   // If `isKnownNonNegative` ever becomes more sophisticated, make sure to keep
   // this updated.
-  KnownBits Known = computeKnownBits(V, Depth, SQ);
+  Known = computeKnownBits(V, DemandedElts, Depth, SQ);
   return Known.isNonNegative() &&
-         (Known.isNonZero() || ::isKnownNonZero(V, Depth, SQ));
+         (Known.isNonZero() || ::isKnownNonZero(V, DemandedElts, Depth, SQ));
+}
+
+bool llvm::isKnownPositive(const Value *V, const APInt &DemandedElts,
+                           const SimplifyQuery &SQ, unsigned Depth) {
+  KnownBits Known(getBitWidth(V->getType(), SQ.DL));
+  return ::isKnownPositive(V, DemandedElts, Known, SQ, Depth);
+}
+
+bool llvm::isKnownPositive(const Value *V, const SimplifyQuery &SQ,
+                           unsigned Depth) {
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+  APInt DemandedElts =
+      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
+  return isKnownPositive(V, DemandedElts, SQ, Depth);
+}
+
+/// Returns true if \p V is known to be negative for the lanes in
+/// \p DemandedElts; \p Known receives the known bits computed for \p V.
+static bool isKnownNegative(const Value *V, const APInt &DemandedElts,
+                            KnownBits &Known, const SimplifyQuery &SQ,
+                            unsigned Depth) {
+  Known = computeKnownBits(V, DemandedElts, Depth, SQ);
+  return Known.isNegative();
+}
+
+bool llvm::isKnownNegative(const Value *V, const APInt &DemandedElts,
+                           const SimplifyQuery &SQ, unsigned Depth) {
+  KnownBits Known(getBitWidth(V->getType(), SQ.DL));
+  return ::isKnownNegative(V, DemandedElts, Known, SQ, Depth);
 }
 
 bool llvm::isKnownNegative(const Value *V, const SimplifyQuery &SQ,
                            unsigned Depth) {
-  return computeKnownBits(V, Depth, SQ).isNegative();
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+  APInt DemandedElts =
+      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
+  return isKnownNegative(V, DemandedElts, SQ, Depth);
 }
 
 static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
@@ -2830,21 +2861,18 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
isKnownNonZero(II->getArgOperand(0), DemandedElts, Depth, Q);
case Intrinsic::smin:
case Intrinsic::smax: {
- auto KnownOpImpliesNonZero = [&](const KnownBits &K) {
- return II->getIntrinsicID() == Intrinsic::smin
- ? K.isNegative()
- : K.isStrictlyPositive();
+ bool AllNonZero = true;
+ auto KnownOpImpliesNonZero = [&](const Value *Op) {
+ KnownBits TmpKnown(getBitWidth(Op->getType(), Q.DL));
+ bool Ret =
+ II->getIntrinsicID() == Intrinsic::smin
+ ? ::isKnownNegative(Op, DemandedElts, TmpKnown, Q, Depth)
+ : ::isKnownPositive(Op, DemandedElts, TmpKnown, Q, Depth);
+ AllNonZero &= TmpKnown.isNonZero();
+ return Ret;
};
- KnownBits XKnown =
- computeKnownBits(II->getArgOperand(0), DemandedElts, Depth, Q);
- if (KnownOpImpliesNonZero(XKnown))
- return true;
- KnownBits YKnown =
- computeKnownBits(II->getArgOperand(1), DemandedElts, Depth, Q);
- if (KnownOpImpliesNonZero(YKnown))
- return true;
-
- if (XKnown.isNonZero() && YKnown.isNonZero())
+ if (KnownOpImpliesNonZero(II->getArgOperand(0)) ||
+ KnownOpImpliesNonZero(II->getArgOperand(1)) || AllNonZero)
return true;
}
[[fallthrough]];
diff --git a/llvm/test/Transforms/InstSimplify/known-non-zero.ll b/llvm/test/Transforms/InstSimplify/known-non-zero.ll
index b647f11af4461d..51f80f62c2f34c 100644
--- a/llvm/test/Transforms/InstSimplify/known-non-zero.ll
+++ b/llvm/test/Transforms/InstSimplify/known-non-zero.ll
@@ -166,3 +166,30 @@ A:
 B:
   ret i1 0
 }
+
+; smax(x, y) is non-zero when one operand is known strictly positive, even if
+; nothing is known about the other operand.
+define i1 @smax_non_zero(i8 %xx, i8 %y) {
+; CHECK-LABEL: @smax_non_zero(
+; CHECK-NEXT:    ret i1 false
+;
+  %x0 = and i8 %xx, 63
+  %x = add i8 %x0, 1
+  %v = call i8 @llvm.smax.i8(i8 %x, i8 %y)
+  %r = icmp eq i8 %v, 0
+  ret i1 %r
+}
+
+; Negative test: %x may be zero, so smax(%x, %y) may be zero as well.
+define i1 @smax_non_zero_fail(i8 %xx, i8 %y) {
+; CHECK-LABEL: @smax_non_zero_fail(
+; CHECK-NEXT:    [[X:%.*]] = and i8 [[XX:%.*]], 63
+; CHECK-NEXT:    [[V:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[Y:%.*]])
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[V]], 0
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %x = and i8 %xx, 63
+  %v = call i8 @llvm.smax.i8(i8 %x, i8 %y)
+  %r = icmp eq i8 %v, 0
+  ret i1 %r
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/88170
More information about the llvm-commits
mailing list