[llvm] 69ba565 - [InstCombine] Handle commuted pattern for `((X s/ C1) << C2) + X` (#121737)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 6 05:43:55 PST 2025
Author: Yingwei Zheng
Date: 2025-01-06T21:43:52+08:00
New Revision: 69ba565734a64bea91062bfd0c5988276b73eb87
URL: https://github.com/llvm/llvm-project/commit/69ba565734a64bea91062bfd0c5988276b73eb87
DIFF: https://github.com/llvm/llvm-project/commit/69ba565734a64bea91062bfd0c5988276b73eb87.diff
LOG: [InstCombine] Handle commuted pattern for `((X s/ C1) << C2) + X` (#121737)
Closes https://github.com/llvm/llvm-project/issues/121700
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 9dc593bdf3058f..73876d00e73a7c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1326,6 +1326,18 @@ Instruction *InstCombinerImpl::foldAddLikeCommutative(Value *LHS, Value *RHS,
R->setHasNoUnsignedWrap(NUWOut);
return R;
}
+
+ // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2
+ const APInt *C1, *C2;
+ if (match(LHS, m_Shl(m_SDiv(m_Specific(RHS), m_APInt(C1)), m_APInt(C2)))) {
+ APInt One(C2->getBitWidth(), 1);
+ APInt MinusC1 = -(*C1);
+ if (MinusC1 == (One << *C2)) {
+ Constant *NewRHS = ConstantInt::get(RHS->getType(), MinusC1);
+ return BinaryOperator::CreateSRem(RHS, NewRHS);
+ }
+ }
+
return nullptr;
}
@@ -1623,17 +1635,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
// X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1)
if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V);
- // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2
- const APInt *C1, *C2;
- if (match(LHS, m_Shl(m_SDiv(m_Specific(RHS), m_APInt(C1)), m_APInt(C2)))) {
- APInt one(C2->getBitWidth(), 1);
- APInt minusC1 = -(*C1);
- if (minusC1 == (one << *C2)) {
- Constant *NewRHS = ConstantInt::get(RHS->getType(), minusC1);
- return BinaryOperator::CreateSRem(RHS, NewRHS);
- }
- }
-
+ const APInt *C1;
// (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit
if (match(&I, m_c_Add(m_And(m_Value(A), m_APInt(C1)), m_Deferred(A))) &&
C1->isPowerOf2() && (ComputeNumSignBits(A) > C1->countl_zero())) {
diff --git a/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll b/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll
index 84462f9a7f592b..d4edf12eba6acf 100644
--- a/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll
+++ b/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll
@@ -12,6 +12,17 @@ define i8 @add-shl-sdiv-scalar0(i8 %x) {
ret i8 %rz
}
+define i8 @add-shl-sdiv-scalar0_commuted(i8 %x) {
+; CHECK-LABEL: @add-shl-sdiv-scalar0_commuted(
+; CHECK-NEXT: [[RZ:%.*]] = srem i8 [[X:%.*]], 4
+; CHECK-NEXT: ret i8 [[RZ]]
+;
+ %sd = sdiv i8 %x, -4
+ %sl = shl i8 %sd, 2
+ %rz = add i8 %x, %sl
+ ret i8 %rz
+}
+
define i8 @add-shl-sdiv-scalar1(i8 %x) {
; CHECK-LABEL: @add-shl-sdiv-scalar1(
; CHECK-NEXT: [[RZ:%.*]] = srem i8 [[X:%.*]], 64
More information about the llvm-commits mailing list