[llvm] 2cb6b06 - [InstCombine] Add constant combines for `(urem/srem (shl X, Y), (shl X, Z))`

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 6 12:46:52 PDT 2023


Author: Noah Goldstein
Date: 2023-07-06T14:46:34-05:00
New Revision: 2cb6b06c8930f5254e8fd393b81e3c96884840d3

URL: https://github.com/llvm/llvm-project/commit/2cb6b06c8930f5254e8fd393b81e3c96884840d3
DIFF: https://github.com/llvm/llvm-project/commit/2cb6b06c8930f5254e8fd393b81e3c96884840d3.diff

LOG: [InstCombine] Add constant combines for `(urem/srem (shl X, Y), (shl X, Z))`

Forked from D142901 to deduce more `nsw`/`nuw` flags for the output
`shl`.

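We can handle the following cases, along with some `nsw`/`nuw` flags.
Each operand may be either a `mul` or a `shl` of `X` by a constant;
`(shl X, C)` is treated as `(mul X, 1 << C)`:

    (rem (mul nuw/nsw X, Y), (mul X, Z))
        if (rem Y, Z) == 0
            -> 0
    (rem (mul X, Y), (mul nuw/nsw X, Z))
        if (rem Y, Z) == Y
            -> (mul nuw/nsw X, Y)
    (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z))
        if Y >= Z
            -> (mul {nuw} nsw X, (rem Y, Z))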

The rationale for doing this all in `InstCombine` rather than handling
the constant `shl` cases in `InstSimplify` is that we often create a new
instruction, because we are able to deduce more `nsw`/`nuw` flags than
the original instruction had.

Differential Revision: https://reviews.llvm.org/D144225

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/test/Transforms/InstCombine/rem-mul-shl.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f6912f9a3cb849..9b86c93946e8a9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1762,12 +1762,22 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
 // Variety of transform for (urem/srem (mul/shl X, Y), (mul/shl X, Z))
 static Instruction *simplifyIRemMulShl(BinaryOperator &I,
                                        InstCombinerImpl &IC) {
-  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X;
-  const APInt *Y, *Z;
-  if (!(match(Op0, m_Mul(m_Value(X), m_APInt(Y))) &&
-        match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z)))) &&
-      !(match(Op0, m_Mul(m_APInt(Y), m_Value(X))) &&
-        match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z)))))
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X = nullptr;
+  APInt Y, Z;
+
+  // If V is not nullptr, it will be matched using m_Specific.
+  auto MatchShiftOrMul = [](Value *Op, Value *&V, APInt &C) -> bool {
+    const APInt *Tmp = nullptr;
+    if ((!V && match(Op, m_Mul(m_Value(V), m_APInt(Tmp)))) ||
+        (V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp)))))
+      C = *Tmp;
+    else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) ||
+             (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp)))))
+      C = APInt(Tmp->getBitWidth(), 1) << *Tmp;
+    return Tmp != nullptr;
+  };
+
+  if (!MatchShiftOrMul(Op0, X, Y) || !MatchShiftOrMul(Op1, X, Z))
     return nullptr;
 
   bool IsSRem = I.getOpcode() == Instruction::SRem;
@@ -1779,7 +1789,7 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I,
   bool BO0HasNUW = BO0->hasNoUnsignedWrap();
   bool BO0NoWrap = IsSRem ? BO0HasNSW : BO0HasNUW;
 
-  APInt RemYZ = IsSRem ? Y->srem(*Z) : Y->urem(*Z);
+  APInt RemYZ = IsSRem ? Y.srem(Z) : Y.urem(Z);
   // (rem (mul nuw/nsw X, Y), (mul X, Z))
   //      if (rem Y, Z) == 0
   //          -> 0
@@ -1793,9 +1803,9 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I,
   // (rem (mul X, Y), (mul nuw/nsw X, Z))
   //      if (rem Y, Z) == Y
   //          -> (mul nuw/nsw X, Y)
-  if (RemYZ == *Y && BO1NoWrap) {
+  if (RemYZ == Y && BO1NoWrap) {
     BinaryOperator *BO =
-        BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), *Y));
+        BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), Y));
     // Copy any overflow flags from Op0.
     BO->setHasNoSignedWrap(IsSRem || BO0HasNSW);
     BO->setHasNoUnsignedWrap(!IsSRem || BO0HasNUW);
@@ -1805,7 +1815,7 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I,
   // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z))
   //      if Y >= Z
   //          -> (mul {nuw} nsw X, (rem Y, Z))
-  if (Y->uge(*Z) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) {
+  if (Y.uge(Z) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) {
     BinaryOperator *BO =
         BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), RemYZ));
     BO->setHasNoSignedWrap();

diff  --git a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll
index c7cb9c954bb3a8..b6b81aa5981395 100644
--- a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll
+++ b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll
@@ -75,9 +75,7 @@ define i8 @urem_XY_XZ_with_CY_lt_CZ(i8 %X) {
 
 define <2 x i8> @urem_XY_XZ_with_CY_lt_CZ_with_nsw_out(<2 x i8> %X) {
 ; CHECK-LABEL: @urem_XY_XZ_with_CY_lt_CZ_with_nsw_out(
-; CHECK-NEXT:    [[BO0:%.*]] = shl nsw <2 x i8> [[X:%.*]], <i8 2, i8 2>
-; CHECK-NEXT:    [[BO1:%.*]] = mul nuw <2 x i8> [[X]], <i8 12, i8 12>
-; CHECK-NEXT:    [[R:%.*]] = urem <2 x i8> [[BO0]], [[BO1]]
+; CHECK-NEXT:    [[R:%.*]] = shl nuw nsw <2 x i8> [[X:%.*]], <i8 2, i8 2>
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %BO0 = shl nsw <2 x i8> %X, <i8 2, i8 2>
@@ -88,9 +86,7 @@ define <2 x i8> @urem_XY_XZ_with_CY_lt_CZ_with_nsw_out(<2 x i8> %X) {
 
 define i8 @urem_XY_XZ_with_CY_lt_CZ_no_nsw_out(i8 %X) {
 ; CHECK-LABEL: @urem_XY_XZ_with_CY_lt_CZ_no_nsw_out(
-; CHECK-NEXT:    [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 3
-; CHECK-NEXT:    [[BO1:%.*]] = shl nuw nsw i8 [[X]], 3
-; CHECK-NEXT:    [[R:%.*]] = urem i8 [[BO0]], [[BO1]]
+; CHECK-NEXT:    [[R:%.*]] = mul nuw i8 [[X:%.*]], 3
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %BO0 = mul nuw i8 %X, 3
@@ -265,9 +261,7 @@ define i8 @srem_XY_XZ_with_CY_rem_CZ_eq_0_fail_missing_flag(i8 %X) {
 
 define <2 x i8> @srem_XY_XZ_with_CY_lt_CZ(<2 x i8> %X) {
 ; CHECK-LABEL: @srem_XY_XZ_with_CY_lt_CZ(
-; CHECK-NEXT:    [[BO0:%.*]] = shl <2 x i8> [[X:%.*]], <i8 3, i8 3>
-; CHECK-NEXT:    [[BO1:%.*]] = mul nsw <2 x i8> [[X]], <i8 15, i8 15>
-; CHECK-NEXT:    [[R:%.*]] = srem <2 x i8> [[BO0]], [[BO1]]
+; CHECK-NEXT:    [[R:%.*]] = shl nsw <2 x i8> [[X:%.*]], <i8 3, i8 3>
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %BO0 = shl <2 x i8> %X, <i8 3, i8 3>
@@ -289,9 +283,7 @@ define i8 @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out(i8 %X) {
 
 define i8 @srem_XY_XZ_with_CY_lt_CZ_no_nsw_out(i8 %X) {
 ; CHECK-LABEL: @srem_XY_XZ_with_CY_lt_CZ_no_nsw_out(
-; CHECK-NEXT:    [[BO0:%.*]] = mul nsw i8 [[X:%.*]], 5
-; CHECK-NEXT:    [[BO1:%.*]] = shl nuw nsw i8 [[X]], 4
-; CHECK-NEXT:    [[R:%.*]] = srem i8 [[BO0]], [[BO1]]
+; CHECK-NEXT:    [[R:%.*]] = mul nsw i8 [[X:%.*]], 5
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %BO0 = mul nsw i8 %X, 5
@@ -315,9 +307,7 @@ define i8 @srem_XY_XZ_with_CY_lt_CZ_fail_missing_flag(i8 %X) {
 
 define i8 @srem_XY_XZ_with_CY_gt_CZ(i8 %X) {
 ; CHECK-LABEL: @srem_XY_XZ_with_CY_gt_CZ(
-; CHECK-NEXT:    [[BO0:%.*]] = shl nsw i8 [[X:%.*]], 3
-; CHECK-NEXT:    [[BO1:%.*]] = mul nsw i8 [[X]], 6
-; CHECK-NEXT:    [[R:%.*]] = srem i8 [[BO0]], [[BO1]]
+; CHECK-NEXT:    [[R:%.*]] = shl nsw i8 [[X:%.*]], 1
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %BO0 = shl nsw i8 %X, 3
@@ -339,9 +329,7 @@ define i8 @srem_XY_XZ_with_CY_gt_CZ_with_nuw_out(i8 %X) {
 
 define <2 x i8> @srem_XY_XZ_with_CY_gt_CZ_no_nuw_out(<2 x i8> %X) {
 ; CHECK-LABEL: @srem_XY_XZ_with_CY_gt_CZ_no_nuw_out(
-; CHECK-NEXT:    [[BO0:%.*]] = mul nsw <2 x i8> [[X:%.*]], <i8 10, i8 10>
-; CHECK-NEXT:    [[BO1:%.*]] = shl nuw nsw <2 x i8> [[X]], <i8 3, i8 3>
-; CHECK-NEXT:    [[R:%.*]] = srem <2 x i8> [[BO0]], [[BO1]]
+; CHECK-NEXT:    [[R:%.*]] = shl nsw <2 x i8> [[X:%.*]], <i8 1, i8 1>
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %BO0 = mul nsw <2 x i8> %X, <i8 10, i8 10>


        


More information about the llvm-commits mailing list