[llvm] [CGP] [CodeGenPrepare] Folding `urem` with loop invariant value plus offset (PR #104724)

via llvm-commits llvm-commits at lists.llvm.org
Sun Aug 18 16:16:26 PDT 2024


https://github.com/goldsteinn created https://github.com/llvm/llvm-project/pull/104724

This extends the existing fold:

```
for(i = Start; i < End; ++i)
   Rem = (i nuw+- IncrLoopInvariant) u% RemAmtLoopInvariant;
```
 ->
```
Rem = (Start nuw+- IncrLoopInvariant) % RemAmtLoopInvariant;
for(i = Start; i < End; ++i, ++rem)
   Rem = rem == RemAmtLoopInvariant ? 0 : Rem;
```

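The extension makes the fold also work with a non-zero `IncrLoopInvariant` (previously it was only implemented for an `IncrLoopInvariant` of zero).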

This is a common usage in cases such as:

```
for(i = 0; i < N; ++i)
    if (((i + 1) % X) == 0)
        do_something_occasionally_but_not_first_iter();
```

Alive2 w/ i4/unrolled 6x (needs to be run locally due to timeout):
https://alive2.llvm.org/ce/z/6tgyN3

Exhaustive proof over all uint8_t combinations in C++:
https://godbolt.org/z/WYa561388


>From 7c66607babaa7860c72fe1d144f207fa09804809 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Sun, 18 Aug 2024 16:12:15 -0700
Subject: [PATCH] [CGP] [CodeGenPrepare] Folding `urem` with loop invariant
 value plus offset

This extends the existing fold:

```
for(i = Start; i < End; ++i)
   Rem = (i nuw+- IncrLoopInvariant) u% RemAmtLoopInvariant;
```
 ->
```
Rem = (Start nuw+- IncrLoopInvariant) % RemAmtLoopInvariant;
for(i = Start; i < End; ++i, ++rem)
   Rem = rem == RemAmtLoopInvariant ? 0 : Rem;
```

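The extension makes the fold also work with a non-zero `IncrLoopInvariant` (previously it was only implemented for an `IncrLoopInvariant` of zero).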

This is a common usage in cases such as:

```
for(i = 0; i < N; ++i)
    if (((i + 1) % X) == 0)
        do_something_occasionally_but_not_first_iter();
```

Alive2 w/ i4/unrolled 6x (needs to be run locally due to timeout):
https://alive2.llvm.org/ce/z/6tgyN3

Exhaustive proof over all uint8_t combinations in C++:
https://godbolt.org/z/WYa561388
---
 llvm/lib/CodeGen/CodeGenPrepare.cpp           | 72 ++++++++++++++++---
 .../CodeGenPrepare/X86/fold-loop-of-urem.ll   | 12 ++--
 2 files changed, 72 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 72db1f4adabf76..2ad56031acf992 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -1976,17 +1976,43 @@ static bool foldFCmpToFPClassTest(CmpInst *Cmp, const TargetLowering &TLI,
   return true;
 }
 
-static bool isRemOfLoopIncrementWithLoopInvariant(Instruction *Rem,
-                                                  const LoopInfo *LI,
-                                                  Value *&RemAmtOut,
-                                                  PHINode *&LoopIncrPNOut) {
+static bool isRemOfLoopIncrementWithLoopInvariant(
+    Instruction *Rem, const LoopInfo *LI, Value *&RemAmtOut,
+    std::optional<bool> &AddOrSubOut, Value *&AddOrSubInstOut,
+    Value *&AddOrSubOffsetOut, PHINode *&LoopIncrPNOut) {
   Value *Incr, *RemAmt;
   // NB: If RemAmt is a power of 2 it *should* have been transformed by now.
   if (!match(Rem, m_URem(m_Value(Incr), m_Value(RemAmt))))
     return false;
 
+  std::optional<bool> AddOrSub;
+  Value *AddOrSubOffset;
   // Find out loop increment PHI.
   auto *PN = dyn_cast<PHINode>(Incr);
+  if (PN != nullptr) {
+    AddOrSub = std::nullopt;
+    AddOrSubOffset = nullptr;
+  } else {
+    // Search through a NUW add/sub on top of the loop increment.
+    Value *V0, *V1;
+    if (match(Incr, m_NUWAddLike(m_Value(V0), m_Value(V1))))
+      AddOrSub = true;
+    else if (match(Incr, m_NUWSub(m_Value(V0), m_Value(V1))))
+      AddOrSub = false;
+    else
+      return false;
+
+    AddOrSubInstOut = Incr;
+
+    PN = dyn_cast<PHINode>(V0);
+    if (PN != nullptr) {
+      AddOrSubOffset = V1;
+    } else if (*AddOrSub) {
+      PN = dyn_cast<PHINode>(V1);
+      AddOrSubOffset = V0;
+    }
+  }
+
   if (!PN)
     return false;
 
@@ -2027,6 +2053,8 @@ static bool isRemOfLoopIncrementWithLoopInvariant(Instruction *Rem,
   // Set output variables.
   RemAmtOut = RemAmt;
   LoopIncrPNOut = PN;
+  AddOrSubOut = AddOrSub;
+  AddOrSubOffsetOut = AddOrSubOffset;
 
   return true;
 }
@@ -2041,15 +2069,15 @@ static bool isRemOfLoopIncrementWithLoopInvariant(Instruction *Rem,
 // Rem = (Start nuw+ IncrLoopInvariant) % RemAmtLoopInvariant;
 // for(i = Start; i < End; ++i, ++rem)
 //    Rem = rem == RemAmtLoopInvariant ? 0 : Rem;
-//
-// Currently only implemented for `IncrLoopInvariant` being zero.
 static bool foldURemOfLoopIncrement(Instruction *Rem, const DataLayout *DL,
                                     const LoopInfo *LI,
                                     SmallSet<BasicBlock *, 32> &FreshBBs,
                                     bool IsHuge) {
-  Value *RemAmt;
+  std::optional<bool> AddOrSub;
+  Value *AddOrSubOffset, *RemAmt, *AddOrSubInst;
   PHINode *LoopIncrPN;
-  if (!isRemOfLoopIncrementWithLoopInvariant(Rem, LI, RemAmt, LoopIncrPN))
+  if (!isRemOfLoopIncrementWithLoopInvariant(
+          Rem, LI, RemAmt, AddOrSub, AddOrSubInst, AddOrSubOffset, LoopIncrPN))
     return false;
 
   // Only non-constant remainder as the extra IV is probably not profitable
@@ -2066,7 +2094,33 @@ static bool foldURemOfLoopIncrement(Instruction *Rem, const DataLayout *DL,
     return false;
   Loop *L = LI->getLoopFor(Rem->getParent());
 
+  // If we have add/sub create initial value for remainder.
+  // The logic here is:
+  // (urem (add/sub nuw Start, IncrLoopInvariant), RemAmtLoopInvariant
+  //
+  // Only proceed if the expression simplifies (otherwise we can't fully
+  // optimize out the urem).
   Value *Start = LoopIncrPN->getIncomingValueForBlock(L->getLoopPreheader());
+  if (AddOrSub) {
+    assert(AddOrSubOffset && AddOrSubInst &&
+           "We found an add/sub but missing values");
+    // Without dom-condition/assumption cache we aren't likely to get much out
+    // of a context instruction.
+    const SimplifyQuery Q(*DL);
+    bool NSW = cast<OverflowingBinaryOperator>(AddOrSubInst)->hasNoSignedWrap();
+    if (*AddOrSub)
+      Start = simplifyAddInst(Start, AddOrSubOffset, /*IsNSW=*/NSW,
+                              /*IsNUW=*/true, Q);
+    else
+      Start = simplifySubInst(Start, AddOrSubOffset, /*IsNSW=*/NSW,
+                              /*IsNUW=*/true, Q);
+    if (!Start)
+      return false;
+
+    Start = simplifyURemInst(Start, RemAmt, Q);
+    if (!Start)
+      return false;
+  }
 
   // Create new remainder with induction variable.
   Type *Ty = Rem->getType();
@@ -2093,6 +2147,8 @@ static bool foldURemOfLoopIncrement(Instruction *Rem, const DataLayout *DL,
 
   replaceAllUsesWith(Rem, NewRem, FreshBBs, IsHuge);
   Rem->eraseFromParent();
+  if (AddOrSubInst && AddOrSubInst->use_empty())
+    cast<Instruction>(AddOrSubInst)->eraseFromParent();
   return true;
 }
 
diff --git a/llvm/test/Transforms/CodeGenPrepare/X86/fold-loop-of-urem.ll b/llvm/test/Transforms/CodeGenPrepare/X86/fold-loop-of-urem.ll
index 304ae337ed4197..61bbe08e293b39 100644
--- a/llvm/test/Transforms/CodeGenPrepare/X86/fold-loop-of-urem.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/X86/fold-loop-of-urem.ll
@@ -706,10 +706,12 @@ define void @simple_urem_to_sel_non_zero_start_through_add(i32 %N, i32 %rem_amt_
 ; CHECK:       [[FOR_COND_CLEANUP]]:
 ; CHECK-NEXT:    ret void
 ; CHECK:       [[FOR_BODY]]:
+; CHECK-NEXT:    [[REM:%.*]] = phi i32 [ 7, %[[FOR_BODY_PREHEADER]] ], [ [[TMP3:%.*]], %[[FOR_BODY]] ]
 ; CHECK-NEXT:    [[I_04:%.*]] = phi i32 [ [[INC:%.*]], %[[FOR_BODY]] ], [ 2, %[[FOR_BODY_PREHEADER]] ]
-; CHECK-NEXT:    [[I_WITH_OFF:%.*]] = add nuw i32 [[I_04]], 5
-; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[I_WITH_OFF]], [[REM_AMT]]
 ; CHECK-NEXT:    tail call void @use.i32(i32 [[REM]])
+; CHECK-NEXT:    [[TMP1:%.*]] = add nuw i32 [[REM]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i32 [[TMP1]], [[REM_AMT]]
+; CHECK-NEXT:    [[TMP3]] = select i1 [[TMP2]], i32 0, i32 [[TMP1]]
 ; CHECK-NEXT:    [[INC]] = add nuw i32 [[I_04]], 1
 ; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i32 [[INC]], [[N]]
 ; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label %[[FOR_COND_CLEANUP]], label %[[FOR_BODY]]
@@ -817,10 +819,12 @@ define void @simple_urem_to_sel_non_zero_start_through_sub(i32 %N, i32 %rem_amt,
 ; CHECK:       [[FOR_COND_CLEANUP]]:
 ; CHECK-NEXT:    ret void
 ; CHECK:       [[FOR_BODY]]:
+; CHECK-NEXT:    [[REM:%.*]] = phi i32 [ 0, %[[FOR_BODY_PREHEADER]] ], [ [[TMP3:%.*]], %[[FOR_BODY]] ]
 ; CHECK-NEXT:    [[I_04:%.*]] = phi i32 [ [[INC:%.*]], %[[FOR_BODY]] ], [ [[START]], %[[FOR_BODY_PREHEADER]] ]
-; CHECK-NEXT:    [[I_WITH_OFF:%.*]] = sub nuw i32 [[I_04]], [[START]]
-; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[I_WITH_OFF]], [[REM_AMT]]
 ; CHECK-NEXT:    tail call void @use.i32(i32 [[REM]])
+; CHECK-NEXT:    [[TMP1:%.*]] = add nuw i32 [[REM]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i32 [[TMP1]], [[REM_AMT]]
+; CHECK-NEXT:    [[TMP3]] = select i1 [[TMP2]], i32 0, i32 [[TMP1]]
 ; CHECK-NEXT:    [[INC]] = add nuw i32 [[I_04]], 1
 ; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i32 [[INC]], [[N]]
 ; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label %[[FOR_COND_CLEANUP]], label %[[FOR_BODY]]



More information about the llvm-commits mailing list