[llvm] [InstCombine] Infer nuw on mul nsw with non-negative operands (PR #90170)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 25 23:05:00 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Nikita Popov (nikic)

<details>
<summary>Changes</summary>

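If a `mul nsw` has non-negative operands, it's also `nuw`: the product of two non-negative values is non-negative, and `nsw` guarantees it fits in the non-negative signed range, which is a subset of the unsigned range, so the multiplication cannot wrap in the unsigned sense either.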

Proof: https://alive2.llvm.org/ce/z/2Dz9Uu

Fixes https://github.com/llvm/llvm-project/issues/90020.

I originally thought I needed this to make a fold work, but it turned out to be unnecessary, so I no longer have a specific motivation for this change. I figured I'd still submit the patch since I had already written it.

---
Full diff: https://github.com/llvm/llvm-project/pull/90170.diff


6 Files Affected:

- (modified) llvm/include/llvm/Analysis/ValueTracking.h (+2-1) 
- (modified) llvm/include/llvm/Transforms/InstCombine/InstCombiner.h (+4-3) 
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+7-1) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+3-2) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+1-1) 
- (modified) llvm/test/Transforms/InstCombine/mul.ll (+1-1) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 571e44cdac2650..afd18e7e56ba0c 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -860,7 +860,8 @@ enum class OverflowResult {
 };
 
 OverflowResult computeOverflowForUnsignedMul(const Value *LHS, const Value *RHS,
-                                             const SimplifyQuery &SQ);
+                                             const SimplifyQuery &SQ,
+                                             bool IsNSW = false);
 OverflowResult computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
                                            const SimplifyQuery &SQ);
 OverflowResult
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index ea1f4fc3b85dc8..855d1aeddfaee0 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -461,9 +461,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
 
   OverflowResult computeOverflowForUnsignedMul(const Value *LHS,
                                                const Value *RHS,
-                                               const Instruction *CxtI) const {
-    return llvm::computeOverflowForUnsignedMul(LHS, RHS,
-                                               SQ.getWithInstruction(CxtI));
+                                               const Instruction *CxtI,
+                                               bool IsNSW = false) const {
+    return llvm::computeOverflowForUnsignedMul(
+        LHS, RHS, SQ.getWithInstruction(CxtI), IsNSW);
   }
 
   OverflowResult computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index de38eddaa98fef..1b461e7cfd01f0 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -6686,9 +6686,15 @@ llvm::computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
 
 OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
                                                    const Value *RHS,
-                                                   const SimplifyQuery &SQ) {
+                                                   const SimplifyQuery &SQ,
+                                                   bool IsNSW) {
   KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
   KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);
+
+  // mul nsw of two non-negative numbers is also nuw.
+  if (IsNSW && LHSKnown.isNonNegative() && RHSKnown.isNonNegative())
+    return OverflowResult::NeverOverflows;
+
   ConstantRange LHSRange = ConstantRange::fromKnownBits(LHSKnown, false);
   ConstantRange RHSRange = ConstantRange::fromKnownBits(RHSKnown, false);
   return mapOverflowResult(LHSRange.unsignedMulMayOverflow(RHSRange));
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index aafb4cf6ca6a62..db7838bbe3c256 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -354,8 +354,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   }
 
   bool willNotOverflowUnsignedMul(const Value *LHS, const Value *RHS,
-                                  const Instruction &CxtI) const {
-    return computeOverflowForUnsignedMul(LHS, RHS, &CxtI) ==
+                                  const Instruction &CxtI,
+                                  bool IsNSW = false) const {
+    return computeOverflowForUnsignedMul(LHS, RHS, &CxtI, IsNSW) ==
            OverflowResult::NeverOverflows;
   }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 4ed4c36e21e016..ca1b1921404d80 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -530,7 +530,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
     I.setHasNoSignedWrap(true);
   }
 
-  if (!HasNUW && willNotOverflowUnsignedMul(Op0, Op1, I)) {
+  if (!HasNUW && willNotOverflowUnsignedMul(Op0, Op1, I, I.hasNoSignedWrap())) {
     Changed = true;
     I.setHasNoUnsignedWrap(true);
   }
diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll
index 4c1ce10171dd71..4fb3c0b1ad4916 100644
--- a/llvm/test/Transforms/InstCombine/mul.ll
+++ b/llvm/test/Transforms/InstCombine/mul.ll
@@ -2146,7 +2146,7 @@ define i8 @mul_nsw_nonneg(i8 %x, i8 %y) {
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[X_NNEG]])
 ; CHECK-NEXT:    [[Y_NNEG:%.*]] = icmp sgt i8 [[Y:%.*]], -1
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[Y_NNEG]])
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i8 [[X]], [[Y]]
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i8 [[X]], [[Y]]
 ; CHECK-NEXT:    ret i8 [[MUL]]
 ;
   %x.nneg = icmp sge i8 %x, 0

``````````

</details>


https://github.com/llvm/llvm-project/pull/90170


More information about the llvm-commits mailing list