[llvm] 453d983 - [InstCombine] Add transforms for `(rem (shl Y, X), (shl Z, X))`
Noah Goldstein via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 6 12:46:56 PDT 2023
Author: Noah Goldstein
Date: 2023-07-06T14:46:34-05:00
New Revision: 453d983d56c4f6407cf6bef52064f74d95448e11
URL: https://github.com/llvm/llvm-project/commit/453d983d56c4f6407cf6bef52064f74d95448e11
DIFF: https://github.com/llvm/llvm-project/commit/453d983d56c4f6407cf6bef52064f74d95448e11.diff
LOG: [InstCombine] Add transforms for `(rem (shl Y, X), (shl Z, X))`
This is just filling in a missing case from D144225.
We treat `(shl Y, X)` and `(shl Z, X)` as `(mul Y, 1 << X)` and `(mul
Z, 1 << X)`, then reuse the same transformations that already exist.
Reviewed By: sdesmalen
Differential Revision: https://reviews.llvm.org/D147108
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
llvm/test/Transforms/InstCombine/rem-mul-shl.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 9b86c93946e8a9..a66c2071cce5bd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1759,14 +1759,20 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
return nullptr;
}
-// Variety of transform for (urem/srem (mul/shl X, Y), (mul/shl X, Z))
+// Variety of transform for:
+// (urem/srem (mul X, Y), (mul X, Z))
+// (urem/srem (shl X, Y), (shl X, Z))
+// (urem/srem (shl Y, X), (shl Z, X))
+// NB: The shift cases are really just extensions of the mul case. We treat
+// shift as Val * (1 << Amt).
static Instruction *simplifyIRemMulShl(BinaryOperator &I,
InstCombinerImpl &IC) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X = nullptr;
APInt Y, Z;
+ bool ShiftByX = false;
// If V is not nullptr, it will be matched using m_Specific.
- auto MatchShiftOrMul = [](Value *Op, Value *&V, APInt &C) -> bool {
+ auto MatchShiftOrMulXC = [](Value *Op, Value *&V, APInt &C) -> bool {
const APInt *Tmp = nullptr;
if ((!V && match(Op, m_Mul(m_Value(V), m_APInt(Tmp)))) ||
(V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp)))))
@@ -1774,11 +1780,34 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I,
else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) ||
(V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp)))))
C = APInt(Tmp->getBitWidth(), 1) << *Tmp;
- return Tmp != nullptr;
+ if (Tmp != nullptr)
+ return true;
+
+ // Reset `V` so we don't start with specific value on next match attempt.
+ V = nullptr;
+ return false;
+ };
+
+ auto MatchShiftCX = [](Value *Op, APInt &C, Value *&V) -> bool {
+ const APInt *Tmp = nullptr;
+ if ((!V && match(Op, m_Shl(m_APInt(Tmp), m_Value(V)))) ||
+ (V && match(Op, m_Shl(m_APInt(Tmp), m_Specific(V))))) {
+ C = *Tmp;
+ return true;
+ }
+
+ // Reset `V` so we don't start with specific value on next match attempt.
+ V = nullptr;
+ return false;
};
- if (!MatchShiftOrMul(Op0, X, Y) || !MatchShiftOrMul(Op1, X, Z))
+ if (MatchShiftOrMulXC(Op0, X, Y) && MatchShiftOrMulXC(Op1, X, Z)) {
+ // pass
+ } else if (MatchShiftCX(Op0, Y, X) && MatchShiftCX(Op1, Z, X)) {
+ ShiftByX = true;
+ } else {
return nullptr;
+ }
bool IsSRem = I.getOpcode() == Instruction::SRem;
@@ -1796,6 +1825,17 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I,
if (RemYZ.isZero() && BO0NoWrap)
return IC.replaceInstUsesWith(I, ConstantInt::getNullValue(I.getType()));
+ // Helper function to emit either (RemSimplificationC << X) or
+ // (RemSimplificationC * X) depending on whether we matched Op0/Op1 as
+ // (shl V, X) or (mul V, X) respectively.
+ auto CreateMulOrShift =
+ [&](const APInt &RemSimplificationC) -> BinaryOperator * {
+ Value *RemSimplification =
+ ConstantInt::get(I.getType(), RemSimplificationC);
+ return ShiftByX ? BinaryOperator::CreateShl(RemSimplification, X)
+ : BinaryOperator::CreateMul(X, RemSimplification);
+ };
+
OverflowingBinaryOperator *BO1 = cast<OverflowingBinaryOperator>(Op1);
bool BO1HasNSW = BO1->hasNoSignedWrap();
bool BO1HasNUW = BO1->hasNoUnsignedWrap();
@@ -1804,8 +1844,7 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I,
// if (rem Y, Z) == Y
// -> (mul nuw/nsw X, Y)
if (RemYZ == Y && BO1NoWrap) {
- BinaryOperator *BO =
- BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), Y));
+ BinaryOperator *BO = CreateMulOrShift(Y);
// Copy any overflow flags from Op0.
BO->setHasNoSignedWrap(IsSRem || BO0HasNSW);
BO->setHasNoUnsignedWrap(!IsSRem || BO0HasNUW);
@@ -1816,8 +1855,7 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I,
// if Y >= Z
// -> (mul {nuw} nsw X, (rem Y, Z))
if (Y.uge(Z) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) {
- BinaryOperator *BO =
- BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), RemYZ));
+ BinaryOperator *BO = CreateMulOrShift(RemYZ);
BO->setHasNoSignedWrap();
BO->setHasNoUnsignedWrap(BO0HasNUW);
return BO;
diff --git a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll
index 52cfd302064f8a..e16da2d684ee50 100644
--- a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll
+++ b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll
@@ -51,10 +51,7 @@ define i8 @urem_XY_XZ_with_CY_rem_CZ_eq_0(i8 %X) {
define i8 @urem_XY_XZ_with_CY_rem_CZ_eq_0_with_shl(i8 %X) {
; CHECK-LABEL: @urem_XY_XZ_with_CY_rem_CZ_eq_0_with_shl(
-; CHECK-NEXT: [[BO0:%.*]] = shl nuw i8 15, [[X:%.*]]
-; CHECK-NEXT: [[BO1:%.*]] = shl i8 5, [[X]]
-; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]]
-; CHECK-NEXT: ret i8 [[R]]
+; CHECK-NEXT: ret i8 0
;
%BO0 = shl nuw i8 15, %X
%BO1 = shl i8 5, %X
@@ -88,9 +85,7 @@ define i8 @urem_XY_XZ_with_CY_lt_CZ(i8 %X) {
define i8 @urem_XY_XZ_with_CY_lt_CZ_with_shl(i8 %X) {
; CHECK-LABEL: @urem_XY_XZ_with_CY_lt_CZ_with_shl(
-; CHECK-NEXT: [[BO0:%.*]] = shl i8 3, [[X:%.*]]
-; CHECK-NEXT: [[BO1:%.*]] = shl nuw i8 12, [[X]]
-; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]]
+; CHECK-NEXT: [[R:%.*]] = shl nuw i8 3, [[X:%.*]]
; CHECK-NEXT: ret i8 [[R]]
;
%BO0 = shl i8 3, %X
@@ -309,9 +304,7 @@ define i8 @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out(i8 %X) {
define <2 x i8> @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out_with_shl(<2 x i8> %X) {
; CHECK-LABEL: @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out_with_shl(
-; CHECK-NEXT: [[BO0:%.*]] = shl nuw <2 x i8> <i8 3, i8 3>, [[X:%.*]]
-; CHECK-NEXT: [[BO1:%.*]] = shl nsw <2 x i8> <i8 15, i8 15>, [[X]]
-; CHECK-NEXT: [[R:%.*]] = srem <2 x i8> [[BO0]], [[BO1]]
+; CHECK-NEXT: [[R:%.*]] = shl nuw nsw <2 x i8> <i8 3, i8 3>, [[X:%.*]]
; CHECK-NEXT: ret <2 x i8> [[R]]
;
%BO0 = shl nuw <2 x i8> <i8 3, i8 3>, %X
More information about the llvm-commits
mailing list