[llvm] 945eeb2 - [InstCombine] Simplify `(X / C0) * C1 + (X % C0) * C2` to `(X / C0) * (C1 - C2 * C0) + X * C2` (#76285)

via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 24 02:01:53 PDT 2024


Author: Yingwei Zheng
Date: 2024-04-24T17:01:49+08:00
New Revision: 945eeb2d92758ef907ef3aeb3251fadc64b731b3

URL: https://github.com/llvm/llvm-project/commit/945eeb2d92758ef907ef3aeb3251fadc64b731b3
DIFF: https://github.com/llvm/llvm-project/commit/945eeb2d92758ef907ef3aeb3251fadc64b731b3.diff

LOG: [InstCombine] Simplify `(X / C0) * C1 + (X % C0) * C2` to `(X / C0) * (C1 - C2 * C0) + X * C2` (#76285)

Since `DivRemPairsPass` runs after `ReassociatePass` in the optimization
pipeline, I decided to do this simplification in `InstCombine`.

Alive2: https://alive2.llvm.org/ce/z/Jgsiqf
Fixes #76128.
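
For reference, the rewrite follows from the identity X = (X / C0) * C0 + (X % C0):
(X % C0) * C2 = X * C2 - (X / C0) * (C0 * C2), so the sum collapses to
(X / C0) * (C1 - C2 * C0) + X * C2. A minimal standalone sanity check, not part
of the patch (the 8-bit width and the constants C0 = 10, C1 = 16, C2 = 1 are
arbitrary choices mirroring the tests below):

    // Exhaustively verify (X / C0) * C1 + (X % C0) * C2 ==
    // (X / C0) * (C1 - C2 * C0) + X * C2 for all 8-bit unsigned X,
    // truncating each side to 8 bits to mirror wrapping integer arithmetic.
    #include <cassert>
    #include <cstdint>

    int main() {
      const unsigned C0 = 10, C1 = 16, C2 = 1; // arbitrary example constants
      for (unsigned V = 0; V < 256; ++V) {
        uint8_t X = static_cast<uint8_t>(V);
        uint8_t Before = static_cast<uint8_t>((X / C0) * C1 + (X % C0) * C2);
        uint8_t After = static_cast<uint8_t>((X / C0) * (C1 - C2 * C0) + X * C2);
        assert(Before == After);
      }
      return 0;
    }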

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
    llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
    llvm/test/Transforms/InstCombine/add4.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index fc284bc61cce97..88b7e496897e1f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1134,6 +1134,8 @@ static bool MulWillOverflow(APInt &C0, APInt &C1, bool IsSigned) {
 
 // Simplifies X % C0 + (( X / C0 ) % C1) * C0 to X % (C0 * C1), where (C0 * C1)
 // does not overflow.
+// Simplifies (X / C0) * C1 + (X % C0) * C2 to
+// (X / C0) * (C1 - C2 * C0) + X * C2
 Value *InstCombinerImpl::SimplifyAddWithRemainder(BinaryOperator &I) {
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   Value *X, *MulOpV;
@@ -1161,6 +1163,33 @@ Value *InstCombinerImpl::SimplifyAddWithRemainder(BinaryOperator &I) {
     }
   }
 
+  // Match I = (X / C0) * C1 + (X % C0) * C2
+  Value *Div, *Rem;
+  APInt C1, C2;
+  if (!LHS->hasOneUse() || !MatchMul(LHS, Div, C1))
+    Div = LHS, C1 = APInt(I.getType()->getScalarSizeInBits(), 1);
+  if (!RHS->hasOneUse() || !MatchMul(RHS, Rem, C2))
+    Rem = RHS, C2 = APInt(I.getType()->getScalarSizeInBits(), 1);
+  if (match(Div, m_IRem(m_Value(), m_Value()))) {
+    std::swap(Div, Rem);
+    std::swap(C1, C2);
+  }
+  Value *DivOpV;
+  APInt DivOpC;
+  if (MatchRem(Rem, X, C0, IsSigned) &&
+      MatchDiv(Div, DivOpV, DivOpC, IsSigned) && X == DivOpV && C0 == DivOpC) {
+    APInt NewC = C1 - C2 * C0;
+    if (!NewC.isZero() && !Rem->hasOneUse())
+      return nullptr;
+    if (!isGuaranteedNotToBeUndef(X, &AC, &I, &DT))
+      return nullptr;
+    Value *MulXC2 = Builder.CreateMul(X, ConstantInt::get(X->getType(), C2));
+    if (NewC.isZero())
+      return MulXC2;
+    return Builder.CreateAdd(
+        Builder.CreateMul(Div, ConstantInt::get(X->getType(), NewC)), MulXC2);
+  }
+
   return nullptr;
 }
 

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index e1923a3441790a..8ec1ed7529c1cb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3958,6 +3958,10 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
                                       /*SimplifyOnly*/ false, *this))
     return BinaryOperator::CreateOr(Op0, V);
 
+  if (cast<PossiblyDisjointInst>(I).isDisjoint())
+    if (Value *V = SimplifyAddWithRemainder(I))
+      return replaceInstUsesWith(I, V);
+
   return nullptr;
 }
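
The visitOr hook fires only when the `or` carries the `disjoint` flag, meaning the
operands are known to share no set bits; in that case the `or` computes the same
value as an `add`, so the add-with-remainder fold applies unchanged (see the
@fold_add_udiv_urem_or_disjoint test below). A minimal illustration of that
equivalence with arbitrary example values, independent of the LLVM APIs:

    // When two operands have no set bits in common, bitwise OR and addition
    // produce the same value, which is why `or disjoint` can reuse a fold
    // written for `add`.
    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t Shl = 3u << 4; // "high" part: low four bits are known zero
      uint32_t Rem = 7u;      // "low" part: fits entirely in the low four bits
      assert((Shl & Rem) == 0u);        // operands are disjoint
      assert((Shl | Rem) == Shl + Rem); // so or and add agree
      return 0;
    }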
 

diff --git a/llvm/test/Transforms/InstCombine/add4.ll b/llvm/test/Transforms/InstCombine/add4.ll
index 7fc164c8b9a7c9..77f7fc7b35cd44 100644
--- a/llvm/test/Transforms/InstCombine/add4.ll
+++ b/llvm/test/Transforms/InstCombine/add4.ll
@@ -1,6 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
 
+declare void @use(i32)
+
 define i64 @match_unsigned(i64 %x) {
 ; CHECK-LABEL: @match_unsigned(
 ; CHECK-NEXT:    [[UREM:%.*]] = urem i64 [[X:%.*]], 19136
@@ -127,3 +129,163 @@ define i32 @not_match_overflow(i32 %x) {
   %t4 = add i32 %t, %t3
   ret i32 %t4
 }
+
+; Tests from PR76128.
+define i32 @fold_add_udiv_urem(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[TMP0:%.*]] = mul nuw i32 [[DIV]], 6
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[TMP0]], [[VAL]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_sdiv_srem(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_sdiv_srem(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = sdiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[TMP0:%.*]] = mul nsw i32 [[DIV]], 6
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[TMP0]], [[VAL]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = sdiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = srem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_to_mul(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_to_mul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ADD:%.*]] = mul i32 [[VAL:%.*]], 3
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 7
+  %mul1 = mul i32 %div, 21
+  %rem = urem i32 %val, 7
+  %mul2 = mul i32 %rem, 3
+  %add = add i32 %mul1, %mul2
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_to_mul_multiuse(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_to_mul_multiuse(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL:%.*]], 7
+; CHECK-NEXT:    call void @use(i32 [[REM]])
+; CHECK-NEXT:    [[ADD:%.*]] = mul i32 [[VAL]], 3
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 7
+  %mul1 = mul i32 %div, 21
+  %rem = urem i32 %val, 7
+  call void @use(i32 %rem)
+  %mul2 = mul i32 %rem, 3
+  %add = add i32 %mul1, %mul2
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_commuted(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_commuted(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[TMP0:%.*]] = mul nuw i32 [[DIV]], 6
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[TMP0]], [[VAL]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, 10
+  %add = add i32 %rem, %shl
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_or_disjoint(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_or_disjoint(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[TMP0:%.*]] = mul nuw i32 [[DIV]], 6
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[TMP0]], [[VAL]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, 10
+  %add = or disjoint i32 %shl, %rem
+  ret i32 %add
+}
+; Negative tests
+define i32 @fold_add_udiv_urem_without_noundef(i32 %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_without_noundef(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 10
+; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_multiuse_mul(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_multiuse_mul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    call void @use(i32 [[SHL]])
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 10
+; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  call void @use(i32 %shl)
+  %rem = urem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_udiv_srem(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_srem(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    [[REM:%.*]] = srem i32 [[VAL]], 10
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = srem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_non_constant(i32 noundef %val, i32 noundef %c) {
+; CHECK-LABEL: @fold_add_udiv_urem_non_constant(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], [[C]]
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, %c
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, %c
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}


        

