[llvm] 453d983 - [InstCombine] Add transforms for `(rem (shl Y, X), (shl Z, X))`
    Noah Goldstein via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Thu Jul  6 12:46:56 PDT 2023
    
    
  
Author: Noah Goldstein
Date: 2023-07-06T14:46:34-05:00
New Revision: 453d983d56c4f6407cf6bef52064f74d95448e11
URL: https://github.com/llvm/llvm-project/commit/453d983d56c4f6407cf6bef52064f74d95448e11
DIFF: https://github.com/llvm/llvm-project/commit/453d983d56c4f6407cf6bef52064f74d95448e11.diff
LOG: [InstCombine] Add transforms for `(rem (shl Y, X), (shl Z, X))`
This is just filling in a missing case from D144225.
We treat `(shl Y, X)` and `(shl Z, X)` as `(mul Y, 1 << X)` and `(mul
Z, 1 << X)` respectively, then reuse the same transformations that already exist.
Reviewed By: sdesmalen
Differential Revision: https://reviews.llvm.org/D147108
Added: 
    
Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/test/Transforms/InstCombine/rem-mul-shl.ll
Removed: 
    
################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 9b86c93946e8a9..a66c2071cce5bd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1759,14 +1759,20 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
   return nullptr;
 }
 
-// Variety of transform for (urem/srem (mul/shl X, Y), (mul/shl X, Z))
+// Variety of transform for:
+//  (urem/srem (mul X, Y), (mul X, Z))
+//  (urem/srem (shl X, Y), (shl X, Z))
+//  (urem/srem (shl Y, X), (shl Z, X))
+// NB: The shift cases are really just extensions of the mul case. We treat
+// shift as Val * (1 << Amt).
 static Instruction *simplifyIRemMulShl(BinaryOperator &I,
                                        InstCombinerImpl &IC) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X = nullptr;
   APInt Y, Z;
+  bool ShiftByX = false;
 
   // If V is not nullptr, it will be matched using m_Specific.
-  auto MatchShiftOrMul = [](Value *Op, Value *&V, APInt &C) -> bool {
+  auto MatchShiftOrMulXC = [](Value *Op, Value *&V, APInt &C) -> bool {
     const APInt *Tmp = nullptr;
     if ((!V && match(Op, m_Mul(m_Value(V), m_APInt(Tmp)))) ||
         (V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp)))))
@@ -1774,11 +1780,34 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I,
     else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) ||
              (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp)))))
       C = APInt(Tmp->getBitWidth(), 1) << *Tmp;
-    return Tmp != nullptr;
+    if (Tmp != nullptr)
+      return true;
+
+    // Reset `V` so we don't start with specific value on next match attempt.
+    V = nullptr;
+    return false;
+  };
+
+  auto MatchShiftCX = [](Value *Op, APInt &C, Value *&V) -> bool {
+    const APInt *Tmp = nullptr;
+    if ((!V && match(Op, m_Shl(m_APInt(Tmp), m_Value(V)))) ||
+        (V && match(Op, m_Shl(m_APInt(Tmp), m_Specific(V))))) {
+      C = *Tmp;
+      return true;
+    }
+
+    // Reset `V` so we don't start with specific value on next match attempt.
+    V = nullptr;
+    return false;
   };
 
-  if (!MatchShiftOrMul(Op0, X, Y) || !MatchShiftOrMul(Op1, X, Z))
+  if (MatchShiftOrMulXC(Op0, X, Y) && MatchShiftOrMulXC(Op1, X, Z)) {
+    // pass
+  } else if (MatchShiftCX(Op0, Y, X) && MatchShiftCX(Op1, Z, X)) {
+    ShiftByX = true;
+  } else {
     return nullptr;
+  }
 
   bool IsSRem = I.getOpcode() == Instruction::SRem;
 
@@ -1796,6 +1825,17 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I,
   if (RemYZ.isZero() && BO0NoWrap)
     return IC.replaceInstUsesWith(I, ConstantInt::getNullValue(I.getType()));
 
+  // Helper function to emit either (RemSimplificationC << X) or
+  // (RemSimplificationC * X) depending on whether we matched Op0/Op1 as
+  // (shl V, X) or (mul V, X) respectively.
+  auto CreateMulOrShift =
+      [&](const APInt &RemSimplificationC) -> BinaryOperator * {
+    Value *RemSimplification =
+        ConstantInt::get(I.getType(), RemSimplificationC);
+    return ShiftByX ? BinaryOperator::CreateShl(RemSimplification, X)
+                    : BinaryOperator::CreateMul(X, RemSimplification);
+  };
+
   OverflowingBinaryOperator *BO1 = cast<OverflowingBinaryOperator>(Op1);
   bool BO1HasNSW = BO1->hasNoSignedWrap();
   bool BO1HasNUW = BO1->hasNoUnsignedWrap();
@@ -1804,8 +1844,7 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I,
   //      if (rem Y, Z) == Y
   //          -> (mul nuw/nsw X, Y)
   if (RemYZ == Y && BO1NoWrap) {
-    BinaryOperator *BO =
-        BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), Y));
+    BinaryOperator *BO = CreateMulOrShift(Y);
     // Copy any overflow flags from Op0.
     BO->setHasNoSignedWrap(IsSRem || BO0HasNSW);
     BO->setHasNoUnsignedWrap(!IsSRem || BO0HasNUW);
@@ -1816,8 +1855,7 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I,
   //      if Y >= Z
   //          -> (mul {nuw} nsw X, (rem Y, Z))
   if (Y.uge(Z) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) {
-    BinaryOperator *BO =
-        BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), RemYZ));
+    BinaryOperator *BO = CreateMulOrShift(RemYZ);
     BO->setHasNoSignedWrap();
     BO->setHasNoUnsignedWrap(BO0HasNUW);
     return BO;
diff  --git a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll
index 52cfd302064f8a..e16da2d684ee50 100644
--- a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll
+++ b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll
@@ -51,10 +51,7 @@ define i8 @urem_XY_XZ_with_CY_rem_CZ_eq_0(i8 %X) {
 
 define i8 @urem_XY_XZ_with_CY_rem_CZ_eq_0_with_shl(i8 %X) {
 ; CHECK-LABEL: @urem_XY_XZ_with_CY_rem_CZ_eq_0_with_shl(
-; CHECK-NEXT:    [[BO0:%.*]] = shl nuw i8 15, [[X:%.*]]
-; CHECK-NEXT:    [[BO1:%.*]] = shl i8 5, [[X]]
-; CHECK-NEXT:    [[R:%.*]] = urem i8 [[BO0]], [[BO1]]
-; CHECK-NEXT:    ret i8 [[R]]
+; CHECK-NEXT:    ret i8 0
 ;
   %BO0 = shl nuw i8 15, %X
   %BO1 = shl i8 5, %X
@@ -88,9 +85,7 @@ define i8 @urem_XY_XZ_with_CY_lt_CZ(i8 %X) {
 
 define i8 @urem_XY_XZ_with_CY_lt_CZ_with_shl(i8 %X) {
 ; CHECK-LABEL: @urem_XY_XZ_with_CY_lt_CZ_with_shl(
-; CHECK-NEXT:    [[BO0:%.*]] = shl i8 3, [[X:%.*]]
-; CHECK-NEXT:    [[BO1:%.*]] = shl nuw i8 12, [[X]]
-; CHECK-NEXT:    [[R:%.*]] = urem i8 [[BO0]], [[BO1]]
+; CHECK-NEXT:    [[R:%.*]] = shl nuw i8 3, [[X:%.*]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %BO0 = shl i8 3, %X
@@ -309,9 +304,7 @@ define i8 @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out(i8 %X) {
 
 define <2 x i8> @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out_with_shl(<2 x i8> %X) {
 ; CHECK-LABEL: @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out_with_shl(
-; CHECK-NEXT:    [[BO0:%.*]] = shl nuw <2 x i8> <i8 3, i8 3>, [[X:%.*]]
-; CHECK-NEXT:    [[BO1:%.*]] = shl nsw <2 x i8> <i8 15, i8 15>, [[X]]
-; CHECK-NEXT:    [[R:%.*]] = srem <2 x i8> [[BO0]], [[BO1]]
+; CHECK-NEXT:    [[R:%.*]] = shl nuw nsw <2 x i8> <i8 3, i8 3>, [[X:%.*]]
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %BO0 = shl nuw <2 x i8> <i8 3, i8 3>, %X
        
    
    
More information about the llvm-commits
mailing list