[llvm] 2116921 - [InstCombine] Fold `select` of `srem` and conditional add

Antonio Frighetto via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 7 17:04:09 PDT 2023


Author: Antonio Frighetto
Date: 2023-08-08T00:02:16Z
New Revision: 211692137af45fcf6d6b14582e004a746f9e5b2e

URL: https://github.com/llvm/llvm-project/commit/211692137af45fcf6d6b14582e004a746f9e5b2e
DIFF: https://github.com/llvm/llvm-project/commit/211692137af45fcf6d6b14582e004a746f9e5b2e.diff

LOG: [InstCombine] Fold `select` of `srem` and conditional add

Simplify a pattern that may show up when computing
the remainder of Euclidean division. In particular,
when the divisor is a power of two and never negative,
the signed remainder can be folded into a bitwise `and`.

Fixes #64305.

Proofs: https://alive2.llvm.org/ce/z/9_KG6c

Differential Revision: https://reviews.llvm.org/D156811

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/test/Transforms/InstCombine/select-divrem.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 661c50062223c4..3b7875dd761bce 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2584,6 +2584,48 @@ static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT,
   return nullptr;
 }
 
+/// Tries to reduce a pattern that arises when calculating the remainder of the
+/// Euclidean division. When the divisor is a power of two and is guaranteed not
+/// to be negative, a signed remainder can be folded with a bitwise and.
+///
+/// (x % n) < 0 ? (x % n) + n : (x % n)
+///    -> x & (n - 1)
+static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC,
+                                       IRBuilderBase &Builder) {
+  Value *CondVal = SI.getCondition();
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+
+  ICmpInst::Predicate Pred;
+  Value *Op, *RemRes, *Remainder;
+  const APInt *C;
+  bool TrueIfSigned = false;
+
+  if (!(match(CondVal, m_ICmp(Pred, m_Value(RemRes), m_APInt(C))) &&
+        IC.isSignBitCheck(Pred, *C, TrueIfSigned)))
+    return nullptr;
+
+  // If the condition is true when the sign bit is clear (e.g. sgt -1, sge 0),
+  // the arms of the select are inverted with respect to the pattern below.
+  if (!TrueIfSigned)
+    std::swap(TrueVal, FalseVal);
+
+  // Match (the add may be commuted, but must use the compared RemRes):
+  // %rem = srem i32 %x, %n
+  // %cnd = icmp slt i32 %rem, 0
+  // %add = add i32 %rem, %n
+  // %sel = select i1 %cnd, i32 %add, i32 %rem
+  if (!(match(TrueVal, m_c_Add(m_Specific(RemRes), m_Value(Remainder))) &&
+        match(RemRes, m_SRem(m_Value(Op), m_Specific(Remainder))) &&
+        IC.isKnownToBeAPowerOfTwo(Remainder, /*OrZero=*/true) &&
+        FalseVal == RemRes))
+    return nullptr;
+
+  Value *Add = Builder.CreateAdd(Remainder,
+                                 Constant::getAllOnesValue(RemRes->getType()));
+  return BinaryOperator::CreateAnd(Op, Add);
+}
+
 static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) {
   FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition());
   if (!FI)
@@ -3430,6 +3472,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   if (Instruction *I = foldSelectExtConst(SI))
     return I;
 
+  if (Instruction *I = foldSelectWithSRem(SI, *this, Builder))
+    return I;
+
   // Fold (select C, (gep Ptr, Idx), Ptr) -> (gep Ptr, (select C, Idx, 0))
   // Fold (select C, Ptr, (gep Ptr, Idx)) -> (gep Ptr, (select C, 0, Idx))
   auto SelectGepWithBase = [&](GetElementPtrInst *Gep, Value *Base,

diff  --git a/llvm/test/Transforms/InstCombine/select-divrem.ll b/llvm/test/Transforms/InstCombine/select-divrem.ll
index 1343191e349d71..a5b56609d60620 100644
--- a/llvm/test/Transforms/InstCombine/select-divrem.ll
+++ b/llvm/test/Transforms/InstCombine/select-divrem.ll
@@ -216,10 +216,7 @@ define i5 @urem_common_dividend_defined_cond(i1 noundef %b, i5 %x, i5 %y, i5 %z)
 
 define i32 @rem_euclid_1(i32 %0) {
 ; CHECK-LABEL: @rem_euclid_1(
-; CHECK-NEXT:    [[REM:%.*]] = srem i32 [[TMP0:%.*]], 8
-; CHECK-NEXT:    [[COND:%.*]] = icmp slt i32 [[REM]], 0
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[REM]], 8
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], i32 [[ADD]], i32 [[REM]]
+; CHECK-NEXT:    [[SEL:%.*]] = and i32 [[TMP0:%.*]], 7
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
   %rem = srem i32 %0, 8
@@ -231,10 +228,7 @@ define i32 @rem_euclid_1(i32 %0) {
 
 define i32 @rem_euclid_2(i32 %0) {
 ; CHECK-LABEL: @rem_euclid_2(
-; CHECK-NEXT:    [[REM:%.*]] = srem i32 [[TMP0:%.*]], 8
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[REM]], 8
-; CHECK-NEXT:    [[COND1:%.*]] = icmp slt i32 [[REM]], 0
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND1]], i32 [[ADD]], i32 [[REM]]
+; CHECK-NEXT:    [[SEL:%.*]] = and i32 [[TMP0:%.*]], 7
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
   %rem = srem i32 %0, 8
@@ -291,10 +285,7 @@ define i32 @rem_euclid_wrong_operands_select(i32 %0) {
 
 define <2 x i32> @rem_euclid_vec(<2 x i32> %0) {
 ; CHECK-LABEL: @rem_euclid_vec(
-; CHECK-NEXT:    [[REM:%.*]] = srem <2 x i32> [[TMP0:%.*]], <i32 8, i32 8>
-; CHECK-NEXT:    [[COND:%.*]] = icmp slt <2 x i32> [[REM]], zeroinitializer
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw <2 x i32> [[REM]], <i32 8, i32 8>
-; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x i32> [[ADD]], <2 x i32> [[REM]]
+; CHECK-NEXT:    [[SEL:%.*]] = and <2 x i32> [[TMP0:%.*]], <i32 7, i32 7>
 ; CHECK-NEXT:    ret <2 x i32> [[SEL]]
 ;
   %rem = srem <2 x i32> %0, <i32 8, i32 8>
@@ -306,10 +297,7 @@ define <2 x i32> @rem_euclid_vec(<2 x i32> %0) {
 
 define i128 @rem_euclid_i128(i128 %0) {
 ; CHECK-LABEL: @rem_euclid_i128(
-; CHECK-NEXT:    [[REM:%.*]] = srem i128 [[TMP0:%.*]], 8
-; CHECK-NEXT:    [[COND:%.*]] = icmp slt i128 [[REM]], 0
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i128 [[REM]], 8
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], i128 [[ADD]], i128 [[REM]]
+; CHECK-NEXT:    [[SEL:%.*]] = and i128 [[TMP0:%.*]], 7
 ; CHECK-NEXT:    ret i128 [[SEL]]
 ;
   %rem = srem i128 %0, 8
@@ -321,11 +309,9 @@ define i128 @rem_euclid_i128(i128 %0) {
 
 define i8 @rem_euclid_non_const_pow2(i8 %0, i8 %1) {
 ; CHECK-LABEL: @rem_euclid_non_const_pow2(
-; CHECK-NEXT:    [[POW2:%.*]] = shl nuw i8 1, [[TMP0:%.*]]
-; CHECK-NEXT:    [[REM:%.*]] = srem i8 [[TMP1:%.*]], [[POW2]]
-; CHECK-NEXT:    [[COND:%.*]] = icmp slt i8 [[REM]], 0
-; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[COND]], i8 [[POW2]], i8 0
-; CHECK-NEXT:    [[SEL:%.*]] = add i8 [[REM]], [[ADD]]
+; CHECK-NEXT:    [[NOTMASK:%.*]] = shl nsw i8 -1, [[TMP0:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = xor i8 [[NOTMASK]], -1
+; CHECK-NEXT:    [[SEL:%.*]] = and i8 [[TMP3]], [[TMP1:%.*]]
 ; CHECK-NEXT:    ret i8 [[SEL]]
 ;
   %pow2 = shl i8 1, %0


        


More information about the llvm-commits mailing list