[llvm] [LoopIdiomRecognize] Replacing the idiom of counting trailing zeros with the intrinsic cttz (PR #87820)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 5 12:01:06 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Matvey Kalashnikov (foreverjun)

<details>
<summary>Changes</summary>

This patch teaches LoopIdiomRecognize to recognize the following code patterns and to compute the loop's trailing-zero count with the cttz intrinsic:
```
// Returns the number of trailing zero bits of n, or 32 when n == 0.
// The zero case must be handled first: for n == 0 the loop below would
// never terminate, because (0 & 1) == 0 and 0 >> 1 == 0.
int ctz(uint32_t n)
{
    int count = 0;
    if (n == 0)
    {
        return 32;
    }
    // Shift n right until its lowest bit is set, counting the shifts.
    while ((n & 1) == 0)
    {
        count += 1;
        n >>= 1;
    }
    return count;
}

// 64-bit variant: the zero check guards the loop instead of returning early,
// and the shift comes before the increment. Returns 64 for n == 0.
// NOTE: it has the same name as the 32-bit version above, so the two snippets
// cannot be compiled in one C translation unit.
int ctz(uint64_t n)
{
    int count = 0;
    if (n != 0)
    {
        // n is non-zero, so some bit is set and this loop terminates.
        while ((n & 1) == 0)
        {
            n >>= 1;
            count += 1;  
        }
    }
    else
    {
        return 64;
    }
    return count;
}
```

<details>
  <summary>
    This is the final LLVM IR (the loop has been removed and only the cttz call remains):
  </summary>

```
define dso_local signext i32 @<!-- -->ctz(i64 noundef %n) local_unnamed_addr #<!-- -->0 {
entry:
  %cmp.not = icmp eq i64 %n, 0
  br i1 %cmp.not, label %cleanup, label %while.cond.preheader

while.cond.preheader:                             ; preds = %entry
  %and4 = and i64 %n, 1
  %cmp15 = icmp eq i64 %and4, 0
  br i1 %cmp15, label %while.body.preheader, label %cleanup

while.body.preheader:                             ; preds = %while.cond.preheader
  %0 = tail call i64 @<!-- -->llvm.cttz.i64(i64 %n, i1 true), !range !9
  %1 = trunc i64 %0 to i32
  br label %cleanup

cleanup:                                          ; preds = %while.body.preheader, %while.cond.preheader, %entry
  %retval.0 = phi i32 [ 64, %entry ], [ 0, %while.cond.preheader ], [ %1, %while.body.preheader ]
  ret i32 %retval.0
}
```
</details>

<details>
  <summary>
    This is the LLVM IR after LoopSimplifyPass and LoopIdiomRecognizePass, with the cttz intrinsic inserted in while.body.preheader:
  </summary>

```
define dso_local signext i32 @<!-- -->ctz(i64 noundef %n) local_unnamed_addr #<!-- -->0 {
entry:
  %cmp.not = icmp eq i64 %n, 0
  br i1 %cmp.not, label %cleanup, label %while.cond.preheader

while.cond.preheader:                             ; preds = %entry
  %and4 = and i64 %n, 1
  %cmp15 = icmp eq i64 %and4, 0
  br i1 %cmp15, label %while.body.preheader, label %cleanup

while.body.preheader:                             ; preds = %while.cond.preheader
  %0 = call i64 @<!-- -->llvm.cttz.i64(i64 %n, i1 true)
  %1 = trunc i64 %0 to i32
  br label %while.body

while.body:                                       ; preds = %while.body, %while.body.preheader
  %count.07 = phi i32 [ %add, %while.body ], [ 0, %while.body.preheader ]
  %n.addr.06 = phi i64 [ %shr, %while.body ], [ %n, %while.body.preheader ]
  %shr = lshr exact i64 %n.addr.06, 1
  %add = add nuw nsw i32 %count.07, 1
  %2 = and i64 %n.addr.06, 2
  %cmp1 = icmp eq i64 %2, 0
  br i1 %cmp1, label %while.body, label %cleanup.loopexit, !llvm.loop !9

cleanup.loopexit:                                 ; preds = %while.body
  %add.lcssa = phi i32 [ %1, %while.body ]
  br label %cleanup

cleanup:                                          ; preds = %cleanup.loopexit, %while.cond.preheader, %entry
  %retval.0 = phi i32 [ 64, %entry ], [ 0, %while.cond.preheader ], [ %add.lcssa, %cleanup.loopexit ]
  ret i32 %retval.0
}
```
</details>



---
Full diff: https://github.com/llvm/llvm-project/pull/87820.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp (+222-1) 
- (added) llvm/test/Transforms/LoopIdiom/RISCV/cttz.ll (+147) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index c7e25c9f3d2c92..f4d9fd28373883 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -243,6 +243,10 @@ class LoopIdiomRecognize {
   bool recognizeShiftUntilBitTest();
   bool recognizeShiftUntilZero();
 
+  bool recognizeAndInsertCtz();
+  void transformLoopToCtz(BasicBlock *Preheader, Instruction *CntInst,
+                          PHINode *CntPhi, Value *InitX);
+
   /// @}
 };
 } // end anonymous namespace
@@ -1484,7 +1488,8 @@ bool LoopIdiomRecognize::runOnNoncountableLoop() {
                     << CurLoop->getHeader()->getName() << "\n");
 
   return recognizePopcount() || recognizeAndInsertFFS() ||
-         recognizeShiftUntilBitTest() || recognizeShiftUntilZero();
+         recognizeShiftUntilBitTest() || recognizeShiftUntilZero() ||
+         recognizeAndInsertCtz();
 }
 
 /// Check if the given conditional branch is based on the comparison between
@@ -2868,3 +2873,219 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
   ++NumShiftUntilZero;
   return MadeChange;
 }
+
+// Recognize a single-block loop that counts the trailing zeros of a value,
+// i.e. what "while ((x & 1) == 0) { x >>= 1; ++cnt; }" becomes once the loop
+// is rotated and "((x >> 1) & 1) == 0" has been folded to "(x & 2) == 0":
+//  loop:
+//  %count.010 = phi i32 [ %add, %while.body ], [ 0, %while.body.preheader ]
+//  %n.addr.09 = phi i32 [ %shr, %while.body ], [ %n, %while.body.preheader ]
+//  %add = add nuw nsw i32 %count.010, 1
+//  %shr = ashr exact i32 %n.addr.09, 1
+//  %0 = and i32 %n.addr.09, 2
+//  %cmp1 = icmp eq i32 %0, 0
+//  br i1 %cmp1, label %while.body, label %if.end.loopexit
+//
+// Each iteration increments the counter and tests the bit that the shift moves
+// into bit 0, so the counter leaves the loop as its initial value plus
+// cttz(InitX), provided InitX is non-zero with a clear low bit (the caller
+// must prove both).
+//
+// The caller must also have checked that CurLoop has exactly one block and
+// one backedge, so the header is the latch. On success, InitX is set to x on
+// loop entry, CntInst to the "cnt + 1" increment and CntPhi to the counter
+// phi that CntInst updates.
+static bool detectShiftUntilZeroAndOneIdiom(Loop *CurLoop, Value *&InitX,
+                                            Instruction *&CntInst,
+                                            PHINode *&CntPhi) {
+  Value *VarX = nullptr;
+  CntInst = nullptr;
+  CntPhi = nullptr;
+  BasicBlock *LoopEntry = *(CurLoop->block_begin());
+
+  // Check if the loop-back branch is in desirable form.
+  //  "if (x == 0) goto loop-entry"
+  Value *T = matchCondition(dyn_cast<BranchInst>(LoopEntry->getTerminator()),
+                            LoopEntry, /*JmpOnZero=*/true);
+  if (!T) {
+    LLVM_DEBUG(dbgs() << "Bad condition for branch instruction\n");
+    return false;
+  }
+
+  // The compared value need not be an instruction (e.g. a function argument);
+  // bail out instead of handing a null pointer to match() below.
+  auto *DefX = dyn_cast<Instruction>(T);
+  if (!DefX)
+    return false;
+
+  // The tested value must be "x & 2" (the folded form of "(x >> 1) & 1").
+  if (!match(DefX, m_c_And(PatternMatch::m_Value(VarX),
+                           PatternMatch::m_SpecificInt(2))))
+    return false;
+
+  // x must be a phi of the loop header...
+  auto *PhiX = dyn_cast<PHINode>(VarX);
+  if (!PhiX || PhiX->getParent() != LoopEntry)
+    return false;
+
+  // ...whose value along the backedge is "x >> 1" (logical or arithmetic).
+  // Query the backedge value directly rather than scanning operands 0 and 1:
+  // that neither assumes the phi has exactly two operands nor picks up a
+  // shift that flows in from the preheader.
+  auto *DefXRShift =
+      dyn_cast<Instruction>(PhiX->getIncomingValueForBlock(LoopEntry));
+  if (!DefXRShift || (DefXRShift->getOpcode() != Instruction::AShr &&
+                      DefXRShift->getOpcode() != Instruction::LShr))
+    return false;
+
+  // The shift amount must be exactly one and the shifted value must be x.
+  auto *Shft = dyn_cast<ConstantInt>(DefXRShift->getOperand(1));
+  if (!Shft || !Shft->isOne() || DefXRShift->getOperand(0) != VarX)
+    return false;
+
+  InitX = PhiX->getIncomingValueForBlock(CurLoop->getLoopPreheader());
+
+  // Find the instruction which counts the trailing zeros: cnt.next = cnt + 1.
+  for (Instruction &Inst : llvm::make_range(
+           LoopEntry->getFirstNonPHI()->getIterator(), LoopEntry->end())) {
+    if (Inst.getOpcode() != Instruction::Add)
+      continue;
+
+    ConstantInt *Inc = dyn_cast<ConstantInt>(Inst.getOperand(1));
+    if (!Inc || !Inc->isOne())
+      continue;
+
+    PHINode *Phi = getRecurrenceVar(Inst.getOperand(0), &Inst, LoopEntry);
+    if (!Phi)
+      continue;
+
+    CntInst = &Inst;
+    CntPhi = Phi;
+    break;
+  }
+  if (!CntInst)
+    return false;
+
+  return true;
+}
+
+/// Recognize the CTTZ idiom in a non-countable loop and replace the uses of
+/// the counter's exit value with CTTZ of the value being shifted. The loop
+/// itself is not removed here. Returns true if CTTZ was inserted.
+///
+/// Example source that produces the idiom:
+/// \code
+///   int count_trailing_zeroes(uint32_t n) {
+///     int count = 0;
+///     if (n == 0)
+///       return 32;
+///     while ((n & 1) == 0) {
+///       count += 1;
+///       n >>= 1;
+///     }
+///     return count;
+///   }
+/// \endcode
+bool LoopIdiomRecognize::recognizeAndInsertCtz() {
+  // Give up if the loop has multiple blocks or multiple backedges.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
+    return false;
+
+  // Only the canonical 7-instruction form of the idiom (two phis, add, shift,
+  // and, icmp, br) is accepted. Check this cheap property before any pattern
+  // matching; @llvm.dbg doesn't count as it has no semantic effect.
+  const size_t IdiomCanonicalSize = 7;
+  auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug();
+  uint32_t HeaderSize =
+      std::distance(InstWithoutDebugIt.begin(), InstWithoutDebugIt.end());
+  if (HeaderSize != IdiomCanonicalSize)
+    return false;
+
+  Value *InitX = nullptr;
+  PHINode *CntPhi = nullptr;
+  Instruction *CntInst = nullptr;
+  if (!detectShiftUntilZeroAndOneIdiom(CurLoop, InitX, CntInst, CntPhi))
+    return false;
+
+  // Nothing to rewrite if the counter's final value is not used after the
+  // loop: inserting CTTZ would only add dead code and needlessly invalidate
+  // ScalarEvolution's information about the loop.
+  if (!CntInst->isUsedOutsideOfBlock(CurLoop->getHeader()))
+    return false;
+
+  // The guards live in the blocks leading to the loop preheader.
+  BasicBlock *PH = CurLoop->getLoopPreheader();
+  auto *PreCondBB = PH->getSinglePredecessor();
+  if (!PreCondBB)
+    return false;
+  auto *PreCondBI = dyn_cast<BranchInst>(PreCondBB->getTerminator());
+  if (!PreCondBI)
+    return false;
+
+  // The loop only terminates if InitX has a set bit, and CTTZ is emitted with
+  // "is zero poison" set, so InitX must be proven non-zero. Because the loop
+  // never tests bit 0 of InitX, that bit must also be proven clear. Match
+  // both guards, where n is the initial value:
+  // entry:
+  //   %cmp.not = icmp eq i32 %n, 0
+  //   br i1 %cmp.not, label %cleanup, label %while.cond.preheader
+  //
+  // while.cond.preheader:
+  //   %and5 = and i32 %n, 1
+  //   %cmp16 = icmp eq i32 %and5, 0
+  //   br i1 %cmp16, label %while.body.preheader, label %cleanup
+  Value *PreCond = matchCondition(PreCondBI, PH, /*JmpOnZero=*/true);
+  if (!PreCond)
+    return false;
+
+  Value *InitPredX = nullptr;
+  if (!match(PreCond, m_c_And(PatternMatch::m_Value(InitPredX),
+                              PatternMatch::m_One())) ||
+      InitPredX != InitX)
+    return false;
+
+  auto *PrePreCondBB = PreCondBB->getSinglePredecessor();
+  if (!PrePreCondBB)
+    return false;
+  auto *PrePreCondBI = dyn_cast<BranchInst>(PrePreCondBB->getTerminator());
+  if (!PrePreCondBI)
+    return false;
+  if (matchCondition(PrePreCondBI, PreCondBB) != InitX)
+    return false;
+
+  // All guards hold, so the counter leaves the loop equal to its initial
+  // value plus cttz(InitX).
+  transformLoopToCtz(PH, CntInst, CntPhi, InitX);
+  return true;
+}
+
+/// Insert cttz(InitX) at the end of \p Preheader and redirect every use of
+/// \p CntInst outside the loop to "initial counter value + cttz(InitX)".
+void LoopIdiomRecognize::transformLoopToCtz(BasicBlock *Preheader,
+                                            Instruction *CntInst,
+                                            PHINode *CntPhi, Value *InitX) {
+  const DebugLoc &DL = CntInst->getDebugLoc();
+
+  // Insert before the preheader's terminator. Use it as-is: cast<BranchInst>
+  // would assert if the terminator were of any other kind.
+  IRBuilder<> Builder(Preheader->getTerminator());
+  Builder.SetCurrentDebugLocation(DL);
+  // The caller proved InitX != 0, so a zero input may be treated as poison.
+  Value *Count = createFFSIntrinsic(Builder, InitX, DL,
+                                    /* is zero poison */ true, Intrinsic::cttz);
+  // Convert the result to the counter's type (zext or trunc as needed).
+  Value *NewCount = Builder.CreateZExtOrTrunc(Count, CntInst->getType());
+
+  // Add the counter's initial value unless it is the constant zero.
+  Value *CntInitVal = CntPhi->getIncomingValueForBlock(Preheader);
+  auto *InitConst = dyn_cast<ConstantInt>(CntInitVal);
+  if (!InitConst || !InitConst->isZero())
+    NewCount = Builder.CreateAdd(NewCount, CntInitVal);
+
+  // Only uses outside the single loop block are rewritten; the loop itself is
+  // left intact, so its own uses of CntInst stay valid.
+  BasicBlock *Body = *(CurLoop->block_begin());
+  CntInst->replaceUsesOutsideBlock(NewCount, Body);
+  // Invalidate ScalarEvolution's cached information about this loop.
+  SE->forgetLoop(CurLoop);
+}
diff --git a/llvm/test/Transforms/LoopIdiom/RISCV/cttz.ll b/llvm/test/Transforms/LoopIdiom/RISCV/cttz.ll
new file mode 100644
index 00000000000000..5c32d497829348
--- /dev/null
+++ b/llvm/test/Transforms/LoopIdiom/RISCV/cttz.ll
@@ -0,0 +1,147 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=loop-idiom -mtriple=riscv32 -S < %s | FileCheck %s
+; RUN: opt -passes=loop-idiom -mtriple=riscv64 -S < %s | FileCheck %s
+
+; Test layout copied from the popcnt test.
+
+; Check that this pattern is recognized:
+;int ctz(uint32_t n)
+;{
+;    int count = 0;
+;    if (n == 0)
+;    {
+;        return 32;
+;    }
+;    while ((n & 1) == 0)
+;    {
+;        count += 1;
+;        n >>= 1;
+;    }
+;    return count;
+;}
+
+define signext i32 @count_trailing_zeroes(i32 noundef signext %n) local_unnamed_addr {
+; CHECK-LABEL: @count_trailing_zeroes(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[N:%.*]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[CLEANUP:%.*]], label [[WHILE_COND_PREHEADER:%.*]]
+; CHECK:       while.cond.preheader:
+; CHECK-NEXT:    [[AND4:%.*]] = and i32 [[N]], 1
+; CHECK-NEXT:    [[CMP15:%.*]] = icmp eq i32 [[AND4]], 0
+; CHECK-NEXT:    br i1 [[CMP15]], label [[WHILE_BODY_PREHEADER:%.*]], label [[CLEANUP]]
+; CHECK:       while.body.preheader:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.cttz.i32(i32 [[N]], i1 true)
+; CHECK-NEXT:    br label [[WHILE_BODY:%.*]]
+; CHECK:       while.body:
+; CHECK-NEXT:    [[COUNT_07:%.*]] = phi i32 [ [[ADD:%.*]], [[WHILE_BODY]] ], [ 0, [[WHILE_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[N_ADDR_06:%.*]] = phi i32 [ [[SHR:%.*]], [[WHILE_BODY]] ], [ [[N]], [[WHILE_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[ADD]] = add nuw nsw i32 [[COUNT_07]], 1
+; CHECK-NEXT:    [[SHR]] = lshr i32 [[N_ADDR_06]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[N_ADDR_06]], 2
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i32 [[TMP1]], 0
+; CHECK-NEXT:    br i1 [[CMP1]], label [[WHILE_BODY]], label [[CLEANUP_LOOPEXIT:%.*]]
+; CHECK:       cleanup.loopexit:
+; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP0]], [[WHILE_BODY]] ]
+; CHECK-NEXT:    br label [[CLEANUP]]
+; CHECK:       cleanup:
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = phi i32 [ 32, [[ENTRY:%.*]] ], [ 0, [[WHILE_COND_PREHEADER]] ], [ [[ADD_LCSSA]], [[CLEANUP_LOOPEXIT]] ]
+; CHECK-NEXT:    ret i32 [[RETVAL_0]]
+;
+entry:
+  %cmp = icmp eq i32 %n, 0
+  br i1 %cmp, label %cleanup, label %while.cond.preheader
+
+while.cond.preheader:                             ; preds = %entry
+  %and4 = and i32 %n, 1
+  %cmp15 = icmp eq i32 %and4, 0
+  br i1 %cmp15, label %while.body, label %cleanup
+
+while.body:                                       ; preds = %while.cond.preheader, %while.body
+  %count.07 = phi i32 [ %add, %while.body ], [ 0, %while.cond.preheader ]
+  %n.addr.06 = phi i32 [ %shr, %while.body ], [ %n, %while.cond.preheader ]
+  %add = add nuw nsw i32 %count.07, 1
+  %shr = lshr i32 %n.addr.06, 1
+  %0 = and i32 %n.addr.06, 2
+  %cmp1 = icmp eq i32 %0, 0
+  br i1 %cmp1, label %while.body, label %cleanup
+
+cleanup:                                          ; preds = %while.body, %while.cond.preheader, %entry
+  %retval.0 = phi i32 [ 32, %entry ], [ 0, %while.cond.preheader ], [ %add, %while.body ]
+  ret i32 %retval.0
+}
+
+;int ctz(uint64_t n)
+;{
+;    int count = 0;
+;    if (n != 0)
+;    {
+;        while ((n & 1) == 0)
+;        {
+;            n >>= 1;
+;            count += 1;
+;        }
+;    }
+;    else
+;    {
+;        return 64;
+;    }
+;    return count;
+;}
+; Same idiom on i64: the i64 cttz result is truncated to the i32 counter type.
+define dso_local signext i32 @ctz(i64 noundef %n) local_unnamed_addr {
+; CHECK-LABEL: @ctz(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i64 [[N:%.*]], 0
+; CHECK-NEXT:    br i1 [[CMP_NOT]], label [[CLEANUP:%.*]], label [[WHILE_COND_PREHEADER:%.*]]
+; CHECK:       while.cond.preheader:
+; CHECK-NEXT:    [[AND5:%.*]] = and i64 [[N]], 1
+; CHECK-NEXT:    [[CMP16:%.*]] = icmp eq i64 [[AND5]], 0
+; CHECK-NEXT:    br i1 [[CMP16]], label [[WHILE_BODY_PREHEADER:%.*]], label [[CLEANUP]]
+; CHECK:       while.body.preheader:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.cttz.i64(i64 [[N]], i1 true)
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32
+; CHECK-NEXT:    br label [[WHILE_BODY:%.*]]
+; CHECK:       while.body:
+; CHECK-NEXT:    [[COUNT_08:%.*]] = phi i32 [ [[ADD:%.*]], [[WHILE_BODY]] ], [ 0, [[WHILE_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[N_ADDR_07:%.*]] = phi i64 [ [[SHR:%.*]], [[WHILE_BODY]] ], [ [[N]], [[WHILE_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[SHR]] = lshr i64 [[N_ADDR_07]], 1
+; CHECK-NEXT:    [[ADD]] = add nuw nsw i32 [[COUNT_08]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = and i64 [[N_ADDR_07]], 2
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i64 [[TMP2]], 0
+; CHECK-NEXT:    br i1 [[CMP1]], label [[WHILE_BODY]], label [[CLEANUP_LOOPEXIT:%.*]]
+; CHECK:       cleanup.loopexit:
+; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP1]], [[WHILE_BODY]] ]
+; CHECK-NEXT:    br label [[CLEANUP]]
+; CHECK:       cleanup:
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = phi i32 [ 64, [[ENTRY:%.*]] ], [ 0, [[WHILE_COND_PREHEADER]] ], [ [[ADD_LCSSA]], [[CLEANUP_LOOPEXIT]] ]
+; CHECK-NEXT:    ret i32 [[RETVAL_0]]
+;
+entry:
+  %cmp.not = icmp eq i64 %n, 0
+  br i1 %cmp.not, label %cleanup, label %while.cond.preheader
+
+while.cond.preheader:                             ; preds = %entry
+  %and5 = and i64 %n, 1
+  %cmp16 = icmp eq i64 %and5, 0
+  br i1 %cmp16, label %while.body.preheader, label %cleanup
+
+while.body.preheader:                             ; preds = %while.cond.preheader
+  br label %while.body
+
+while.body:                                       ; preds = %while.body.preheader, %while.body
+  %count.08 = phi i32 [ %add, %while.body ], [ 0, %while.body.preheader ]
+  %n.addr.07 = phi i64 [ %shr, %while.body ], [ %n, %while.body.preheader ]
+  %shr = lshr i64 %n.addr.07, 1
+  %add = add nuw nsw i32 %count.08, 1
+  %0 = and i64 %n.addr.07, 2
+  %cmp1 = icmp eq i64 %0, 0
+  br i1 %cmp1, label %while.body, label %cleanup.loopexit
+
+cleanup.loopexit:                                 ; preds = %while.body
+  %add.lcssa = phi i32 [ %add, %while.body ]
+  br label %cleanup
+
+cleanup:                                          ; preds = %cleanup.loopexit, %while.cond.preheader, %entry
+  %retval.0 = phi i32 [ 64, %entry ], [ 0, %while.cond.preheader ], [ %add.lcssa, %cleanup.loopexit ]
+  ret i32 %retval.0
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/87820


More information about the llvm-commits mailing list