[llvm] [ValueTracking] Extend known bits of `mul` with self and constant (PR #81892)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 15 10:17:13 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
Author: Antonio Frighetto (antoniofrighetto)
<details>
<summary>Changes</summary>
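InstCombine previously missed simplifications of `abs(b * a * a)` when `b` is a known constant, because the sign of the product was not known.
This is addressed in `computeKnownBitsMul` by deriving the sign bit when a value is multiplied by itself and a constant: with `nsw`, `a * a` is non-negative, so `b * a * a` either has the same sign as `b` or is zero.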
Fixes: https://github.com/llvm/llvm-project/issues/78018
Proofs: https://alive2.llvm.org/ce/z/E3uJum
---
Full diff: https://github.com/llvm/llvm-project/pull/81892.diff
2 Files Affected:
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+13)
- (modified) llvm/test/Transforms/InstCombine/mul.ll (+61)
``````````diff
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index cc1d5b74dcfc53..db26af467139c9 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -373,11 +373,24 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
bool isKnownNegative = false;
bool isKnownNonNegative = false;
+ const APInt *C;
+ unsigned BitWidth = Known.getBitWidth();
+
// If the multiplication is known not to overflow, compute the sign bit.
if (NSW) {
if (Op0 == Op1) {
// The product of a number with itself is non-negative.
isKnownNonNegative = true;
+ } else if (((match(Op0, m_Mul(m_Specific(Op1), m_APInt(C))) &&
+ cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap()) ||
+ (match(Op0, m_Shl(m_Specific(Op1), m_APInt(C))) &&
+ C->ult(BitWidth))) &&
+ !C->isZero() && !Known.isZero()) {
+ // The product of a number with itself and a constant depends on the sign
+ // of the constant.
+ KnownBits KnownC = KnownBits::makeConstant(*C);
+ isKnownNonNegative = KnownC.isNonNegative();
+ isKnownNegative = KnownC.isNegative();
} else {
bool isKnownNonNegativeOp1 = Known.isNonNegative();
bool isKnownNonNegativeOp0 = Known2.isNonNegative();
diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll
index e7141d7c25ad21..656ed50fc25fd2 100644
--- a/llvm/test/Transforms/InstCombine/mul.ll
+++ b/llvm/test/Transforms/InstCombine/mul.ll
@@ -2049,3 +2049,64 @@ define i32 @zext_negpow2_use(i8 %x) {
%r = mul i32 %zx, -16777216 ; -1 << 24
ret i32 %r
}
+
+define i1 @self_with_constant_greater_than_zero(i8 %a) {
+; CHECK-LABEL: @self_with_constant_greater_than_zero(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: ret i1 true
+;
+entry:
+ %mul = mul nsw i8 %a, 3
+ %mul1 = mul nsw i8 %mul, %a
+ %cmp = icmp sge i8 %mul1, 0
+ ret i1 %cmp
+}
+
+define i8 @abs_of_self_with_constant(i8 %a) {
+; CHECK-LABEL: @abs_of_self_with_constant(
+; CHECK-NEXT: [[MUL:%.*]] = shl nsw i8 [[A:%.*]], 1
+; CHECK-NEXT: [[MUL1:%.*]] = mul nsw i8 [[MUL]], [[A]]
+; CHECK-NEXT: ret i8 [[MUL1]]
+;
+ %mul = mul nsw i8 %a, 2
+ %mul1 = mul nsw i8 %mul, %a
+ %r = tail call i8 @llvm.abs.i8(i8 %mul1, i1 true)
+ ret i8 %r
+}
+
+define i8 @abs_of_self_with_constant_neg_constant(i8 %a) {
+; CHECK-LABEL: @abs_of_self_with_constant_neg_constant(
+; CHECK-NEXT: [[MUL_NEG:%.*]] = mul i8 [[A:%.*]], 3
+; CHECK-NEXT: [[MUL1_NEG:%.*]] = mul nsw i8 [[MUL_NEG]], [[A]]
+; CHECK-NEXT: ret i8 [[MUL1_NEG]]
+;
+ %mul = mul nsw i8 %a, -3
+ %mul1 = mul nsw i8 %mul, %a
+ %r = tail call i8 @llvm.abs.i8(i8 %mul1, i1 true)
+ ret i8 %r
+}
+
+define i8 @abs_of_self_with_constant_2_inv_ops(i8 %a) {
+; CHECK-LABEL: @abs_of_self_with_constant_2_inv_ops(
+; CHECK-NEXT: [[MUL:%.*]] = shl nsw i8 [[A:%.*]], 2
+; CHECK-NEXT: [[MUL1:%.*]] = mul nsw i8 [[MUL]], [[A]]
+; CHECK-NEXT: ret i8 [[MUL1]]
+;
+ %mul = mul nsw i8 4, %a
+ %mul1 = mul nsw i8 %a, %mul
+ %r = tail call i8 @llvm.abs.i8(i8 %mul1, i1 true)
+ ret i8 %r
+}
+
+define i8 @abs_of_self_with_constant_no_nsw(i8 %a) {
+; CHECK-LABEL: @abs_of_self_with_constant_no_nsw(
+; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[A:%.*]], 3
+; CHECK-NEXT: [[MUL1:%.*]] = mul nsw i8 [[MUL]], [[A]]
+; CHECK-NEXT: [[R:%.*]] = tail call i8 @llvm.abs.i8(i8 [[MUL1]], i1 true)
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %mul = mul i8 %a, 3
+ %mul1 = mul nsw i8 %mul, %a
+ %r = tail call i8 @llvm.abs.i8(i8 %mul1, i1 true)
+ ret i8 %r
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/81892
More information about the llvm-commits mailing list