[llvm] aba71f3 - [InstCombine] Add constant combines for `(urem/srem (mul X, Y), (mul X, Z))`

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 16 11:02:33 PDT 2023


Author: Noah Goldstein
Date: 2023-03-16T13:01:46-05:00
New Revision: aba71f37d00cf0c2de0b0d0bd24a3467fe8d697f

URL: https://github.com/llvm/llvm-project/commit/aba71f37d00cf0c2de0b0d0bd24a3467fe8d697f
DIFF: https://github.com/llvm/llvm-project/commit/aba71f37d00cf0c2de0b0d0bd24a3467fe8d697f.diff

LOG: [InstCombine] Add constant combines for `(urem/srem (mul X, Y), (mul X, Z))`

We can handle the following cases + some `nsw`/`nuw` flags:

`(srem (mul X, Y), (mul X, Z))`
    [If `srem(Y, Z) == 0`]
        -> 0
            - https://alive2.llvm.org/ce/z/PW4XZ-
    [If `srem(Y, Z) == Y`]
        -> `(mul nuw nsw X, Y)`
            - https://alive2.llvm.org/ce/z/DQe9Ek
        -> `(mul nsw X, Y)`
            - https://alive2.llvm.org/ce/z/Nr_MdH

    [If `Y`/`Z` are constant]
        -> `(mul/shl nuw nsw X, (srem Y, Z))`
            - https://alive2.llvm.org/ce/z/ccTFj2
            - https://alive2.llvm.org/ce/z/i_UQ5A
        -> `(mul/shl nsw X, (srem Y, Z))`
            - https://alive2.llvm.org/ce/z/mQKc63
            - https://alive2.llvm.org/ce/z/uERkKH

`(urem (mul X, Y), (mul X, Z))`
    [If `urem(Y, Z) == 0`]
        -> 0
            - https://alive2.llvm.org/ce/z/LL7UVR
    [If `urem(Y, Z) == Y`]
        -> `(mul nuw nsw X, Y)`
            - https://alive2.llvm.org/ce/z/9Kgs_i
        -> `(mul nuw X, Y)`
            - https://alive2.llvm.org/ce/z/ow9i8u

    [If `Y`/`Z` are constant]
        -> `(mul nuw nsw X, (urem Y, Z))`
            - https://alive2.llvm.org/ce/z/mNnQqJ
            - https://alive2.llvm.org/ce/z/Bj_DR-
            - https://alive2.llvm.org/ce/z/X6ZEtQ
        -> `(mul nuw X, (urem Y, Z))`
            - https://alive2.llvm.org/ce/z/SJYtUV

The rationale for doing all of this in `InstCombine`, rather than handling
the constant `mul` cases in `InstSimplify`, is that we often create a new
instruction, because we are able to deduce more `nsw`/`nuw` flags than
the original instruction had.

Reviewed By: MattDevereau, sdesmalen

Differential Revision: https://reviews.llvm.org/D143014

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/test/Transforms/InstCombine/rem-mul-shl.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 57e90076e1874..5768f71265ccd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1698,6 +1698,63 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
   return nullptr;
 }
 
+// Fold (urem/srem (mul X, Y), (mul X, Z)) where Y and Z are constants.
+static Instruction *simplifyIRemMulShl(BinaryOperator &I,
+                                       InstCombinerImpl &IC) {
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X;
+  const APInt *Y, *Z;
+  if (!(match(Op0, m_Mul(m_Value(X), m_APInt(Y))) &&
+        match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z)))) &&
+      !(match(Op0, m_Mul(m_APInt(Y), m_Value(X))) &&
+        match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z)))))
+    return nullptr;
+
+  bool IsSRem = I.getOpcode() == Instruction::SRem;
+
+  OverflowingBinaryOperator *BO0 = cast<OverflowingBinaryOperator>(Op0);
+  // TODO: We may be able to deduce more about nsw/nuw of BO0/BO1 based on Y >=
+  // Z or Z >= Y.
+  bool BO0HasNSW = BO0->hasNoSignedWrap();
+  bool BO0HasNUW = BO0->hasNoUnsignedWrap();
+  bool BO0NoWrap = IsSRem ? BO0HasNSW : BO0HasNUW;
+
+  APInt RemYZ = IsSRem ? Y->srem(*Z) : Y->urem(*Z);
+  // (rem (mul nuw/nsw X, Y), (mul X, Z))
+  //      if (rem Y, Z) == 0
+  //          -> 0
+  if (RemYZ.isZero() && BO0NoWrap)
+    return IC.replaceInstUsesWith(I, ConstantInt::getNullValue(I.getType()));
+
+  OverflowingBinaryOperator *BO1 = cast<OverflowingBinaryOperator>(Op1);
+  bool BO1HasNSW = BO1->hasNoSignedWrap();
+  bool BO1HasNUW = BO1->hasNoUnsignedWrap();
+  bool BO1NoWrap = IsSRem ? BO1HasNSW : BO1HasNUW;
+  // (rem (mul X, Y), (mul nuw/nsw X, Z))
+  //      if (rem Y, Z) == Y
+  //          -> (mul nuw/nsw X, Y)
+  if (RemYZ == *Y && BO1NoWrap) {
+    BinaryOperator *BO =
+        BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), *Y));
+    // Keep Op0's flags and always set the one matching the rem kind.
+    BO->setHasNoSignedWrap(IsSRem || BO0HasNSW);
+    BO->setHasNoUnsignedWrap(!IsSRem || BO0HasNUW);
+    return BO;
+  }
+
+  // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z))
+  //      if Y u>= Z (unsigned compare, also used for srem)
+  //          -> (mul {nuw} nsw X, (rem Y, Z))
+  if (Y->uge(*Z) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) {
+    BinaryOperator *BO =
+        BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), RemYZ));
+    BO->setHasNoSignedWrap();
+    BO->setHasNoUnsignedWrap(BO0HasNUW);
+    return BO;
+  }
+
+  return nullptr;
+}
+
 /// This function implements the transforms common to both integer remainder
 /// instructions (urem and srem). It is called by the visitors to those integer
 /// remainder instructions.
@@ -1750,6 +1807,9 @@ Instruction *InstCombinerImpl::commonIRemTransforms(BinaryOperator &I) {
     }
   }
 
+  if (Instruction *R = simplifyIRemMulShl(I, *this))
+    return R;
+
   return nullptr;
 }
 

diff  --git a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll
index 4f915836a1a04..c7cb9c954bb3a 100644
--- a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll
+++ b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll
@@ -31,10 +31,7 @@ define i8 @urem_1_shl(i8 %X, i8 %Y) {
 
 define <vscale x 16 x i8> @urem_XY_XZ_with_CY_rem_CZ_eq_0_scalable(<vscale x 16 x i8> %X) {
 ; CHECK-LABEL: @urem_XY_XZ_with_CY_rem_CZ_eq_0_scalable(
-; CHECK-NEXT:    [[BO0:%.*]] = mul nuw <vscale x 16 x i8> [[X:%.*]], shufflevector (<vscale x 16 x i8> insertelement (<vscale x 16 x i8> poison, i8 15, i64 0), <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
-; CHECK-NEXT:    [[BO1:%.*]] = mul <vscale x 16 x i8> [[X]], shufflevector (<vscale x 16 x i8> insertelement (<vscale x 16 x i8> poison, i8 5, i64 0), <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
-; CHECK-NEXT:    [[R:%.*]] = urem <vscale x 16 x i8> [[BO0]], [[BO1]]
-; CHECK-NEXT:    ret <vscale x 16 x i8> [[R]]
+; CHECK-NEXT:    ret <vscale x 16 x i8> zeroinitializer
 ;
   %BO0 = mul nuw <vscale x 16 x i8> %X, shufflevector(<vscale x 16 x i8> insertelement(<vscale x 16 x i8> poison, i8 15, i64 0) , <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
   %BO1 = mul <vscale x 16 x i8> %X, shufflevector(<vscale x 16 x i8> insertelement(<vscale x 16 x i8> poison, i8 5, i64 0) , <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
@@ -44,10 +41,7 @@ define <vscale x 16 x i8> @urem_XY_XZ_with_CY_rem_CZ_eq_0_scalable(<vscale x 16
 
 define i8 @urem_XY_XZ_with_CY_rem_CZ_eq_0(i8 %X) {
 ; CHECK-LABEL: @urem_XY_XZ_with_CY_rem_CZ_eq_0(
-; CHECK-NEXT:    [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 15
-; CHECK-NEXT:    [[BO1:%.*]] = mul i8 [[X]], 5
-; CHECK-NEXT:    [[R:%.*]] = urem i8 [[BO0]], [[BO1]]
-; CHECK-NEXT:    ret i8 [[R]]
+; CHECK-NEXT:    ret i8 0
 ;
   %BO0 = mul nuw i8 %X, 15
   %BO1 = mul i8 %X, 5
@@ -70,9 +64,7 @@ define i8 @urem_XY_XZ_with_CY_rem_CZ_eq_0_fail_missing_flag(i8 %X) {
 
 define i8 @urem_XY_XZ_with_CY_lt_CZ(i8 %X) {
 ; CHECK-LABEL: @urem_XY_XZ_with_CY_lt_CZ(
-; CHECK-NEXT:    [[BO0:%.*]] = mul i8 [[X:%.*]], 3
-; CHECK-NEXT:    [[BO1:%.*]] = mul nuw i8 [[X]], 12
-; CHECK-NEXT:    [[R:%.*]] = urem i8 [[BO0]], [[BO1]]
+; CHECK-NEXT:    [[R:%.*]] = mul nuw i8 [[X:%.*]], 3
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %BO0 = mul i8 %X, 3
@@ -122,9 +114,7 @@ define i8 @urem_XY_XZ_with_CY_lt_CZ_fail_missing_flag(i8 %X) {
 
 define i8 @urem_XY_XZ_with_CY_gt_CZ(i8 %X) {
 ; CHECK-LABEL: @urem_XY_XZ_with_CY_gt_CZ(
-; CHECK-NEXT:    [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 21
-; CHECK-NEXT:    [[BO1:%.*]] = mul i8 [[X]], 6
-; CHECK-NEXT:    [[R:%.*]] = urem i8 [[BO0]], [[BO1]]
+; CHECK-NEXT:    [[R:%.*]] = mul nuw nsw i8 [[X:%.*]], 3
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %BO0 = mul nuw i8 %X, 21
@@ -242,10 +232,7 @@ define i8 @urem_XY_XZ_with_Y_Z_is_mul_X_RemYZ_fail_missing_flags2(i8 %X, i8 %Y,
 ;; Signed Verions
 define <vscale x 16 x i8> @srem_XY_XZ_with_CY_rem_CZ_eq_0_scalable(<vscale x 16 x i8> %X) {
 ; CHECK-LABEL: @srem_XY_XZ_with_CY_rem_CZ_eq_0_scalable(
-; CHECK-NEXT:    [[BO0:%.*]] = mul nsw <vscale x 16 x i8> [[X:%.*]], shufflevector (<vscale x 16 x i8> insertelement (<vscale x 16 x i8> poison, i8 15, i64 0), <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
-; CHECK-NEXT:    [[BO1:%.*]] = mul <vscale x 16 x i8> [[X]], shufflevector (<vscale x 16 x i8> insertelement (<vscale x 16 x i8> poison, i8 5, i64 0), <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
-; CHECK-NEXT:    [[R:%.*]] = srem <vscale x 16 x i8> [[BO0]], [[BO1]]
-; CHECK-NEXT:    ret <vscale x 16 x i8> [[R]]
+; CHECK-NEXT:    ret <vscale x 16 x i8> zeroinitializer
 ;
   %BO0 = mul nsw <vscale x 16 x i8> %X, shufflevector(<vscale x 16 x i8> insertelement(<vscale x 16 x i8> poison, i8 15, i64 0) , <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
   %BO1 = mul <vscale x 16 x i8> %X, shufflevector(<vscale x 16 x i8> insertelement(<vscale x 16 x i8> poison, i8 5, i64 0) , <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
@@ -255,10 +242,7 @@ define <vscale x 16 x i8> @srem_XY_XZ_with_CY_rem_CZ_eq_0_scalable(<vscale x 16
 
 define i8 @srem_XY_XZ_with_CY_rem_CZ_eq_0(i8 %X) {
 ; CHECK-LABEL: @srem_XY_XZ_with_CY_rem_CZ_eq_0(
-; CHECK-NEXT:    [[BO0:%.*]] = mul nsw i8 [[X:%.*]], 9
-; CHECK-NEXT:    [[BO1:%.*]] = mul i8 [[X]], 3
-; CHECK-NEXT:    [[R:%.*]] = srem i8 [[BO0]], [[BO1]]
-; CHECK-NEXT:    ret i8 [[R]]
+; CHECK-NEXT:    ret i8 0
 ;
   %BO0 = mul nsw i8 %X, 9
   %BO1 = mul i8 %X, 3
@@ -294,9 +278,7 @@ define <2 x i8> @srem_XY_XZ_with_CY_lt_CZ(<2 x i8> %X) {
 
 define i8 @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out(i8 %X) {
 ; CHECK-LABEL: @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out(
-; CHECK-NEXT:    [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 5
-; CHECK-NEXT:    [[BO1:%.*]] = mul nsw i8 [[X]], 15
-; CHECK-NEXT:    [[R:%.*]] = srem i8 [[BO0]], [[BO1]]
+; CHECK-NEXT:    [[R:%.*]] = mul nuw nsw i8 [[X:%.*]], 5
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %BO0 = mul nuw i8 %X, 5
@@ -346,9 +328,7 @@ define i8 @srem_XY_XZ_with_CY_gt_CZ(i8 %X) {
 
 define i8 @srem_XY_XZ_with_CY_gt_CZ_with_nuw_out(i8 %X) {
 ; CHECK-LABEL: @srem_XY_XZ_with_CY_gt_CZ_with_nuw_out(
-; CHECK-NEXT:    [[BO0:%.*]] = mul nuw nsw i8 [[X:%.*]], 10
-; CHECK-NEXT:    [[BO1:%.*]] = mul nsw i8 [[X]], 6
-; CHECK-NEXT:    [[R:%.*]] = srem i8 [[BO0]], [[BO1]]
+; CHECK-NEXT:    [[R:%.*]] = shl nuw nsw i8 [[X:%.*]], 2
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %BO0 = mul nsw nuw i8 %X, 10


        


More information about the llvm-commits mailing list