[llvm] [InstCombine] Handle commuted pattern for `((X s/ C1) << C2) + X` (PR #121737)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 6 02:22:56 PST 2025


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/121737

>From 01055b1a6f565d8d10ba7efa2206b8f0dac84cca Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 6 Jan 2025 00:21:58 +0800
Subject: [PATCH 1/3] [InstCombine] Add pre-commit tests. NFC.

---
 .../Transforms/InstCombine/add-shl-sdiv-to-srem.ll  | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll b/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll
index 84462f9a7f592b..60bfe3f8665b78 100644
--- a/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll
+++ b/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll
@@ -12,6 +12,19 @@ define i8 @add-shl-sdiv-scalar0(i8 %x) {
   ret i8 %rz
 }
 
+define i8 @add-shl-sdiv-scalar0_commuted(i8 %x) {
+; CHECK-LABEL: @add-shl-sdiv-scalar0_commuted(
+; CHECK-NEXT:    [[SD:%.*]] = sdiv i8 [[X:%.*]], -4
+; CHECK-NEXT:    [[SL:%.*]] = shl i8 [[SD]], 2
+; CHECK-NEXT:    [[RZ:%.*]] = add i8 [[X]], [[SL]]
+; CHECK-NEXT:    ret i8 [[RZ]]
+;
+  %sd = sdiv i8 %x, -4
+  %sl = shl i8 %sd, 2
+  %rz = add i8 %x, %sl
+  ret i8 %rz
+}
+
 define i8 @add-shl-sdiv-scalar1(i8 %x) {
 ; CHECK-LABEL: @add-shl-sdiv-scalar1(
 ; CHECK-NEXT:    [[RZ:%.*]] = srem i8 [[X:%.*]], 64

>From 79c303869c61e8d4479aaaa29b8b677c45a9318b Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 6 Jan 2025 00:50:45 +0800
Subject: [PATCH 2/3] [InstCombine] Handle commuted pattern for `((X s/ C1) <<
 C2) + X`

---
 llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp    | 7 ++++---
 llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll | 4 +---
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 7a184a19d7c54a..74d17067de16e5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1625,12 +1625,13 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
 
   // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2
   const APInt *C1, *C2;
-  if (match(LHS, m_Shl(m_SDiv(m_Specific(RHS), m_APInt(C1)), m_APInt(C2)))) {
+  if (match(&I, m_c_Add(m_Shl(m_SDiv(m_Value(A), m_APInt(C1)), m_APInt(C2)),
+                        m_Deferred(A)))) {
     APInt one(C2->getBitWidth(), 1);
     APInt minusC1 = -(*C1);
     if (minusC1 == (one << *C2)) {
-      Constant *NewRHS = ConstantInt::get(RHS->getType(), minusC1);
-      return BinaryOperator::CreateSRem(RHS, NewRHS);
+      Constant *NewRHS = ConstantInt::get(A->getType(), minusC1);
+      return BinaryOperator::CreateSRem(A, NewRHS);
     }
   }
 
diff --git a/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll b/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll
index 60bfe3f8665b78..d4edf12eba6acf 100644
--- a/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll
+++ b/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll
@@ -14,9 +14,7 @@ define i8 @add-shl-sdiv-scalar0(i8 %x) {
 
 define i8 @add-shl-sdiv-scalar0_commuted(i8 %x) {
 ; CHECK-LABEL: @add-shl-sdiv-scalar0_commuted(
-; CHECK-NEXT:    [[SD:%.*]] = sdiv i8 [[X:%.*]], -4
-; CHECK-NEXT:    [[SL:%.*]] = shl i8 [[SD]], 2
-; CHECK-NEXT:    [[RZ:%.*]] = add i8 [[X]], [[SL]]
+; CHECK-NEXT:    [[RZ:%.*]] = srem i8 [[X:%.*]], 4
 ; CHECK-NEXT:    ret i8 [[RZ]]
 ;
   %sd = sdiv i8 %x, -4

>From 31a6f4a6dd81a6fe5ef956069d24826c9ccee129 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 6 Jan 2025 18:22:31 +0800
Subject: [PATCH 3/3] [InstCombine] Move the logic into
 `foldAddLikeCommutative`

---
 .../InstCombine/InstCombineAddSub.cpp         | 25 ++++++++++---------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 74d17067de16e5..dee07b260dcd31 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1326,6 +1326,18 @@ Instruction *InstCombinerImpl::foldAddLikeCommutative(Value *LHS, Value *RHS,
     R->setHasNoUnsignedWrap(NUWOut);
     return R;
   }
+
+  // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2
+  const APInt *C1, *C2;
+  if (match(LHS, m_Shl(m_SDiv(m_Specific(RHS), m_APInt(C1)), m_APInt(C2)))) {
+    APInt one(C2->getBitWidth(), 1);
+    APInt minusC1 = -(*C1);
+    if (minusC1 == (one << *C2)) {
+      Constant *NewRHS = ConstantInt::get(RHS->getType(), minusC1);
+      return BinaryOperator::CreateSRem(RHS, NewRHS);
+    }
+  }
+
   return nullptr;
 }
 
@@ -1623,18 +1635,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1)
   if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V);
 
-  // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2
-  const APInt *C1, *C2;
-  if (match(&I, m_c_Add(m_Shl(m_SDiv(m_Value(A), m_APInt(C1)), m_APInt(C2)),
-                        m_Deferred(A)))) {
-    APInt one(C2->getBitWidth(), 1);
-    APInt minusC1 = -(*C1);
-    if (minusC1 == (one << *C2)) {
-      Constant *NewRHS = ConstantInt::get(A->getType(), minusC1);
-      return BinaryOperator::CreateSRem(A, NewRHS);
-    }
-  }
-
+  const APInt *C1;
   // (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit
   if (match(&I, m_c_Add(m_And(m_Value(A), m_APInt(C1)), m_Deferred(A))) &&
       C1->isPowerOf2() && (ComputeNumSignBits(A) > C1->countl_zero())) {



More information about the llvm-commits mailing list