[llvm] [InstCombine] Simplify `(X / C0) * C1 + (X % C0) * C2` to `(X / C0) * (C1 - C2 * C0) + X * C2` (PR #76285)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 19 07:08:53 PDT 2024


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/76285

>From a42ae75dfa40de6296200f183ff2fde0ef95a7d5 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sat, 23 Dec 2023 21:23:15 +0800
Subject: [PATCH 1/2] [InstCombine] Add pre-commit tests from PR76128. NFC.

---
 llvm/test/Transforms/InstCombine/add4.ll | 173 +++++++++++++++++++++++
 1 file changed, 173 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/add4.ll b/llvm/test/Transforms/InstCombine/add4.ll
index 7fc164c8b9a7c9..410f144f51c3b0 100644
--- a/llvm/test/Transforms/InstCombine/add4.ll
+++ b/llvm/test/Transforms/InstCombine/add4.ll
@@ -1,6 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
 
+declare void @use(i32)
+
 define i64 @match_unsigned(i64 %x) {
 ; CHECK-LABEL: @match_unsigned(
 ; CHECK-NEXT:    [[UREM:%.*]] = urem i64 [[X:%.*]], 19136
@@ -127,3 +129,174 @@ define i32 @not_match_overflow(i32 %x) {
   %t4 = add i32 %t, %t3
   ret i32 %t4
 }
+
+; Tests from PR76128.
+define i32 @fold_add_udiv_urem(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 10
+; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_sdiv_srem(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_sdiv_srem(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = sdiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    [[REM:%.*]] = srem i32 [[VAL]], 10
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = sdiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = srem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_to_mul(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_to_mul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 7
+; CHECK-NEXT:    [[MUL1:%.*]] = mul i32 [[DIV]], 21
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 7
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[REM]], 3
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 7
+  %mul1 = mul i32 %div, 21
+  %rem = urem i32 %val, 7
+  %mul2 = mul i32 %rem, 3
+  %add = add i32 %mul1, %mul2
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_to_mul_multiuse(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_to_mul_multiuse(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 7
+; CHECK-NEXT:    [[MUL1:%.*]] = mul i32 [[DIV]], 21
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 7
+; CHECK-NEXT:    call void @use(i32 [[REM]])
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[REM]], 3
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 7
+  %mul1 = mul i32 %div, 21
+  %rem = urem i32 %val, 7
+  call void @use(i32 %rem)
+  %mul2 = mul i32 %rem, 3
+  %add = add i32 %mul1, %mul2
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_commuted(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_commuted(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 10
+; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i32 [[REM]], [[SHL]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, 10
+  %add = add i32 %rem, %shl
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_or_disjoint(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_or_disjoint(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 10
+; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, 10
+  %add = or disjoint i32 %shl, %rem
+  ret i32 %add
+}
+; Negative tests
+define i32 @fold_add_udiv_urem_without_noundef(i32 %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_without_noundef(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 10
+; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_multiuse_mul(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_multiuse_mul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    call void @use(i32 [[SHL]])
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 10
+; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  call void @use(i32 %shl)
+  %rem = urem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_udiv_srem(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_srem(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    [[REM:%.*]] = srem i32 [[VAL]], 10
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = srem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_non_constant(i32 noundef %val, i32 noundef %c) {
+; CHECK-LABEL: @fold_add_udiv_urem_non_constant(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], [[C]]
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, %c
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, %c
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}

>From 3673875b024a253e464bb627ce23444f8a9b3d90 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sat, 23 Dec 2023 21:23:57 +0800
Subject: [PATCH 2/2] [InstCombine] Simplify `(X / C0) * C1 + (X % C0) * C2` to
 `(X / C0) * (C1 - C2 * C0) + X * C2`

---
 .../InstCombine/InstCombineAddSub.cpp         | 29 ++++++++++++++++
 .../InstCombine/InstCombineAndOrXor.cpp       |  4 +++
 llvm/test/Transforms/InstCombine/add4.ll      | 33 +++++++------------
 3 files changed, 44 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 9a16ca96db76b6..13d08b102c2055 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1134,6 +1134,8 @@ static bool MulWillOverflow(APInt &C0, APInt &C1, bool IsSigned) {
 
 // Simplifies X % C0 + (( X / C0 ) % C1) * C0 to X % (C0 * C1), where (C0 * C1)
 // does not overflow.
+// Simplifies (X / C0) * C1 + (X % C0) * C2 to
+// (X / C0) * (C1 - C2 * C0) + X * C2, since X % C0 == X - (X / C0) * C0.
 Value *InstCombinerImpl::SimplifyAddWithRemainder(BinaryOperator &I) {
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   Value *X, *MulOpV;
@@ -1161,6 +1163,33 @@ Value *InstCombinerImpl::SimplifyAddWithRemainder(BinaryOperator &I) {
     }
   }
 
+  // Match I = (X / C0) * C1 + (X % C0) * C2
+  Value *Div, *Rem;
+  APInt C1, C2;
+  if (!LHS->hasOneUse() || !MatchMul(LHS, Div, C1))
+    Div = LHS, C1 = APInt(I.getType()->getScalarSizeInBits(), 1);
+  if (!RHS->hasOneUse() || !MatchMul(RHS, Rem, C2))
+    Rem = RHS, C2 = APInt(I.getType()->getScalarSizeInBits(), 1);
+  if (match(Div, m_IRem(m_Value(), m_Value()))) {
+    std::swap(Div, Rem);
+    std::swap(C1, C2);
+  }
+  Value *DivOpV;
+  APInt DivOpC;
+  if (MatchRem(Rem, X, C0, IsSigned) &&
+      MatchDiv(Div, DivOpV, DivOpC, IsSigned) && X == DivOpV && C0 == DivOpC) {
+    if (!isGuaranteedNotToBeUndef(X, &AC, &I, &DT))
+      return nullptr;
+    APInt NewC = C1 - C2 * C0;
+    if (!NewC.isZero() && !Rem->hasOneUse())
+      return nullptr;
+    Value *MulXC2 = Builder.CreateMul(X, ConstantInt::get(X->getType(), C2));
+    if (NewC.isZero())
+      return MulXC2;
+    return Builder.CreateAdd(
+        Builder.CreateMul(Div, ConstantInt::get(X->getType(), NewC)), MulXC2);
+  }
+
   return nullptr;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index bf7c0074a38f05..14fe536ed05a61 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3958,6 +3958,10 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
                                       /*SimplifyOnly*/ false, *this))
     return BinaryOperator::CreateOr(Op0, V);
 
+  if (cast<PossiblyDisjointInst>(I).isDisjoint())
+    if (Value *V = SimplifyAddWithRemainder(I))
+      return replaceInstUsesWith(I, V);
+
   return nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstCombine/add4.ll b/llvm/test/Transforms/InstCombine/add4.ll
index 410f144f51c3b0..77f7fc7b35cd44 100644
--- a/llvm/test/Transforms/InstCombine/add4.ll
+++ b/llvm/test/Transforms/InstCombine/add4.ll
@@ -135,9 +135,8 @@ define i32 @fold_add_udiv_urem(i32 noundef %val) {
 ; CHECK-LABEL: @fold_add_udiv_urem(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
-; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 10
-; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    [[TMP0:%.*]] = mul nuw i32 [[DIV]], 6
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[TMP0]], [[VAL]]
 ; CHECK-NEXT:    ret i32 [[ADD]]
 ;
 entry:
@@ -151,9 +150,8 @@ define i32 @fold_add_sdiv_srem(i32 noundef %val) {
 ; CHECK-LABEL: @fold_add_sdiv_srem(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[DIV:%.*]] = sdiv i32 [[VAL:%.*]], 10
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
-; CHECK-NEXT:    [[REM:%.*]] = srem i32 [[VAL]], 10
-; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    [[TMP0:%.*]] = mul nsw i32 [[DIV]], 6
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[TMP0]], [[VAL]]
 ; CHECK-NEXT:    ret i32 [[ADD]]
 ;
 entry:
@@ -166,11 +164,7 @@ entry:
 define i32 @fold_add_udiv_urem_to_mul(i32 noundef %val) {
 ; CHECK-LABEL: @fold_add_udiv_urem_to_mul(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 7
-; CHECK-NEXT:    [[MUL1:%.*]] = mul i32 [[DIV]], 21
-; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 7
-; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[REM]], 3
-; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[ADD:%.*]] = mul i32 [[VAL:%.*]], 3
 ; CHECK-NEXT:    ret i32 [[ADD]]
 ;
 entry:
@@ -184,12 +178,9 @@ entry:
 define i32 @fold_add_udiv_urem_to_mul_multiuse(i32 noundef %val) {
 ; CHECK-LABEL: @fold_add_udiv_urem_to_mul_multiuse(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 7
-; CHECK-NEXT:    [[MUL1:%.*]] = mul i32 [[DIV]], 21
-; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 7
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL:%.*]], 7
 ; CHECK-NEXT:    call void @use(i32 [[REM]])
-; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[REM]], 3
-; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[ADD:%.*]] = mul i32 [[VAL]], 3
 ; CHECK-NEXT:    ret i32 [[ADD]]
 ;
 entry:
@@ -205,9 +196,8 @@ define i32 @fold_add_udiv_urem_commuted(i32 noundef %val) {
 ; CHECK-LABEL: @fold_add_udiv_urem_commuted(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
-; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 10
-; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i32 [[REM]], [[SHL]]
+; CHECK-NEXT:    [[TMP0:%.*]] = mul nuw i32 [[DIV]], 6
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[TMP0]], [[VAL]]
 ; CHECK-NEXT:    ret i32 [[ADD]]
 ;
 entry:
@@ -221,9 +211,8 @@ define i32 @fold_add_udiv_urem_or_disjoint(i32 noundef %val) {
 ; CHECK-LABEL: @fold_add_udiv_urem_or_disjoint(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
-; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 10
-; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    [[TMP0:%.*]] = mul nuw i32 [[DIV]], 6
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[TMP0]], [[VAL]]
 ; CHECK-NEXT:    ret i32 [[ADD]]
 ;
 entry:



More information about the llvm-commits mailing list