[llvm] [LoopIdiomRecognize] Replacing the idiom of counting trailing zeros with the intrinsic cttz (PR #87820)
Matvey Kalashnikov via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 5 12:00:17 PDT 2024
https://github.com/foreverjun created https://github.com/llvm/llvm-project/pull/87820
This patch teaches LoopIdiomRecognize to recognize the following trailing-zero-counting loops and to replace the count they compute with a call to the cttz intrinsic:
```
// Count the trailing zero bits of a 32-bit value; 32 is returned for n == 0.
int ctz(uint32_t n)
{
int count = 0;
// Zero has no set bit, so the loop below would never terminate for it.
if (n == 0)
{
return 32;
}
// Shift right until the lowest bit is set; each shift is one trailing zero.
while ((n & 1) == 0)
{
count += 1;
n >>= 1;
}
return count;
}
// Same idiom for a 64-bit value, with the zero check written as an if/else
// around the loop; 64 is returned for n == 0.
int ctz(uint64_t n)
{
int count = 0;
if (n != 0)
{
// n has at least one set bit here, so the loop terminates.
while ((n & 1) == 0)
{
n >>= 1;
count += 1;
}
}
else
{
return 64;
}
return count;
}
```
<details>
<summary>
This is the final LLVM IR (the loop is gone and only the cttz call remains in while.body.preheader):
</summary>
```
define dso_local signext i32 @ctz(i64 noundef %n) local_unnamed_addr #0 {
entry:
%cmp.not = icmp eq i64 %n, 0
br i1 %cmp.not, label %cleanup, label %while.cond.preheader
while.cond.preheader: ; preds = %entry
%and4 = and i64 %n, 1
%cmp15 = icmp eq i64 %and4, 0
br i1 %cmp15, label %while.body.preheader, label %cleanup
while.body.preheader: ; preds = %while.cond.preheader
%0 = tail call i64 @llvm.cttz.i64(i64 %n, i1 true), !range !9
%1 = trunc i64 %0 to i32
br label %cleanup
cleanup: ; preds = %while.body.preheader, %while.cond.preheader, %entry
%retval.0 = phi i32 [ 64, %entry ], [ 0, %while.cond.preheader ], [ %1, %while.body.preheader ]
ret i32 %retval.0
}
```
</details>
<details>
<summary>
This is the LLVM IR after LoopSimplifyPass and LoopIdiomRecognizePass, with the cttz intrinsic inserted in while.body.preheader:
</summary>
```
define dso_local signext i32 @ctz(i64 noundef %n) local_unnamed_addr #0 {
entry:
%cmp.not = icmp eq i64 %n, 0
br i1 %cmp.not, label %cleanup, label %while.cond.preheader
while.cond.preheader: ; preds = %entry
%and4 = and i64 %n, 1
%cmp15 = icmp eq i64 %and4, 0
br i1 %cmp15, label %while.body.preheader, label %cleanup
while.body.preheader: ; preds = %while.cond.preheader
%0 = call i64 @llvm.cttz.i64(i64 %n, i1 true)
%1 = trunc i64 %0 to i32
br label %while.body
while.body: ; preds = %while.body, %while.body.preheader
%count.07 = phi i32 [ %add, %while.body ], [ 0, %while.body.preheader ]
%n.addr.06 = phi i64 [ %shr, %while.body ], [ %n, %while.body.preheader ]
%shr = lshr exact i64 %n.addr.06, 1
%add = add nuw nsw i32 %count.07, 1
%2 = and i64 %n.addr.06, 2
%cmp1 = icmp eq i64 %2, 0
br i1 %cmp1, label %while.body, label %cleanup.loopexit, !llvm.loop !9
cleanup.loopexit: ; preds = %while.body
%add.lcssa = phi i32 [ %1, %while.body ]
br label %cleanup
cleanup: ; preds = %cleanup.loopexit, %while.cond.preheader, %entry
%retval.0 = phi i32 [ 64, %entry ], [ 0, %while.cond.preheader ], [ %add.lcssa, %cleanup.loopexit ]
ret i32 %retval.0
}
```
</details>
>From 6f391b16808db9f4a87f151a3a67a106afaf4bc6 Mon Sep 17 00:00:00 2001
From: matvey <kalashnikov.matvey.m at gmail.com>
Date: Fri, 5 Apr 2024 20:49:12 +0300
Subject: [PATCH] Added trailing zeros counting pattern recognition.
---
.../Transforms/Scalar/LoopIdiomRecognize.cpp | 223 +++++++++++++++++-
llvm/test/Transforms/LoopIdiom/RISCV/cttz.ll | 147 ++++++++++++
2 files changed, 369 insertions(+), 1 deletion(-)
create mode 100644 llvm/test/Transforms/LoopIdiom/RISCV/cttz.ll
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index c7e25c9f3d2c92..f4d9fd28373883 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -243,6 +243,10 @@ class LoopIdiomRecognize {
bool recognizeShiftUntilBitTest();
bool recognizeShiftUntilZero();
+  /// Trailing-zero counting idiom: recognition, and insertion of cttz.
+  bool recognizeAndInsertCtz();
+  void transformLoopToCtz(BasicBlock *Preheader, Instruction *CntInst,
+                          PHINode *CntPhi, Value *InitX);
/// @}
};
} // end anonymous namespace
@@ -1484,7 +1488,8 @@ bool LoopIdiomRecognize::runOnNoncountableLoop() {
<< CurLoop->getHeader()->getName() << "\n");
return recognizePopcount() || recognizeAndInsertFFS() ||
- recognizeShiftUntilBitTest() || recognizeShiftUntilZero();
+ recognizeShiftUntilBitTest() || recognizeShiftUntilZero() ||
+ recognizeAndInsertCtz();
}
/// Check if the given conditional branch is based on the comparison between
@@ -2868,3 +2873,219 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
++NumShiftUntilZero;
return MadeChange;
}
+
+// Recognize a single-block loop that counts the trailing zeros of a value:
+// loop:
+//   %count.010 = phi i32 [ %add, %while.body ], [ 0, %while.body.preheader ]
+//   %n.addr.09 = phi i32 [ %shr, %while.body ], [ %n, %while.body.preheader ]
+//   %add = add nuw nsw i32 %count.010, 1
+//   %shr = ashr exact i32 %n.addr.09, 1
+//   %0 = and i32 %n.addr.09, 2
+//   %cmp1 = icmp eq i32 %0, 0
+//   br i1 %cmp1, label %while.body, label %if.end.loopexit
+//
+// Only the first block of CurLoop is inspected; the caller is expected to have
+// checked that the loop consists of that single block.
+//
+// On success, returns true and sets:
+//   InitX   - the incoming value of the shifted variable from the preheader,
+//   CntInst - the "cnt + 1" instruction,
+//   CntPhi  - the PHI node that carries the counter around the loop.
+// On failure, returns false with CntInst and CntPhi set to null.
+static bool detectShiftUntilZeroAndOneIdiom(Loop *CurLoop, Value *&InitX,
+                                            Instruction *&CntInst,
+                                            PHINode *&CntPhi) {
+  CntInst = nullptr;
+  CntPhi = nullptr;
+  BasicBlock *LoopEntry = *(CurLoop->block_begin());
+
+  // Check if the loop-back branch is in desirable form.
+  // "if (x == 0) goto loop-entry"
+  Value *T = matchCondition(dyn_cast<BranchInst>(LoopEntry->getTerminator()),
+                            LoopEntry, true);
+  // The matched value is a Value* and need not be an Instruction; dyn_cast
+  // then yields null, which must not be handed to the pattern matcher.
+  Instruction *DefX = T ? dyn_cast<Instruction>(T) : nullptr;
+  if (!DefX) {
+    LLVM_DEBUG(dbgs() << "Bad condition for branch instruction\n");
+    return false;
+  }
+
+  // The operand is compared against the mask 2 because we are looking for
+  // "x & 2", which earlier passes produce from "(x >> 1) & 1".
+  Value *VarX = nullptr;
+  if (!match(DefX, m_c_And(PatternMatch::m_Value(VarX),
+                           PatternMatch::m_SpecificInt(2))))
+    return false;
+
+  // VarX must be a PHI node of the loop block.
+  auto *PhiX = dyn_cast<PHINode>(VarX);
+  if (!PhiX || PhiX->getParent() != LoopEntry)
+    return false;
+
+  // The value fed back into the PHI must be a right shift. Look only at the
+  // back-edge value (the single loop block is its own latch): also scanning
+  // the entry value would reject loops whose initial value happens to be the
+  // result of a shift, and would assume the PHI has exactly two operands.
+  auto *DefXRShift =
+      dyn_cast<Instruction>(PhiX->getIncomingValueForBlock(LoopEntry));
+  if (!DefXRShift || (DefXRShift->getOpcode() != Instruction::AShr &&
+                      DefXRShift->getOpcode() != Instruction::LShr))
+    return false;
+
+  // The shift must be "x >> 1" applied to the PHI itself.
+  auto *Shft = dyn_cast<ConstantInt>(DefXRShift->getOperand(1));
+  if (!Shft || !Shft->isOne())
+    return false;
+  if (DefXRShift->getOperand(0) != VarX)
+    return false;
+
+  // Without a preheader there is no unique initial value to report.
+  BasicBlock *Preheader = CurLoop->getLoopPreheader();
+  if (!Preheader)
+    return false;
+  InitX = PhiX->getIncomingValueForBlock(Preheader);
+
+  // Find the instruction which counts the trailing zeros: cnt.next = cnt + 1.
+  for (Instruction &Inst : llvm::make_range(
+           LoopEntry->getFirstNonPHI()->getIterator(), LoopEntry->end())) {
+    if (Inst.getOpcode() != Instruction::Add)
+      continue;
+
+    ConstantInt *Inc = dyn_cast<ConstantInt>(Inst.getOperand(1));
+    if (!Inc || !Inc->isOne())
+      continue;
+
+    PHINode *Phi = getRecurrenceVar(Inst.getOperand(0), &Inst, LoopEntry);
+    if (!Phi)
+      continue;
+
+    CntInst = &Inst;
+    CntPhi = Phi;
+    break;
+  }
+
+  // Succeed only if a counter increment was found.
+  return CntInst != nullptr;
+}
+
+/// Recognize the count-trailing-zeros idiom in a non-countable loop and make
+/// the users of the count outside the loop read CTTZ of the initial value
+/// instead. The loop itself is left in place. If CTTZ was inserted, returns
+/// true; otherwise, returns false.
+// int count_trailing_zeroes(uint32_t n) {
+//   int count = 0;
+//   if (n == 0)
+//     return 32;
+//   while ((n & 1) == 0) {
+//     count += 1;
+//     n >>= 1;
+//   }
+//   return count;
+// }
+bool LoopIdiomRecognize::recognizeAndInsertCtz() {
+  // Give up if the loop has multiple blocks or multiple backedges.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
+    return false;
+
+  Value *InitX = nullptr;
+  PHINode *CntPhi = nullptr;
+  Instruction *CntInst = nullptr;
+  if (!detectShiftUntilZeroAndOneIdiom(CurLoop, InitX, CntInst, CntPhi))
+    return false;
+
+  // Only uses of the count outside the loop get rewritten; without any (e.g.
+  // the loop was already transformed), CTTZ would be a dead call.
+  BasicBlock *Body = CurLoop->getHeader();
+  bool CountUsedOutsideLoop = false;
+  for (const User *U : CntInst->users()) {
+    const auto *UserInst = dyn_cast<Instruction>(U);
+    if (!UserInst || UserInst->getParent() != Body) {
+      CountUsedOutsideLoop = true;
+      break;
+    }
+  }
+  if (!CountUsedOutsideLoop)
+    return false;
+
+  BasicBlock *PH = CurLoop->getLoopPreheader();
+  auto *PreCondBB = PH->getSinglePredecessor();
+  if (!PreCondBB)
+    return false;
+  auto *PreCondBI = dyn_cast<BranchInst>(PreCondBB->getTerminator());
+  if (!PreCondBI)
+    return false;
+
+  // The loop must be guarded by two checks on the initial value:
+  // "init != 0" and "(init & 1) == 0". The first one is essential: for a zero
+  // input the original loop never terminates, whereas CTTZ would produce a
+  // result (and the call is emitted with "zero is poison" set).
+  // Match this shape, where n is the initial value:
+  // entry:
+  //   %cmp.not = icmp eq i32 %n, 0
+  //   br i1 %cmp.not, label %cleanup, label %while.cond.preheader
+  // while.cond.preheader:
+  //   %and5 = and i32 %n, 1
+  //   %cmp16 = icmp eq i32 %and5, 0
+  //   br i1 %cmp16, label %while.body.preheader, label %cleanup
+  Value *PreCond = matchCondition(PreCondBI, PH, true);
+  if (!PreCond)
+    return false;
+
+  Value *InitPredX = nullptr;
+  if (!match(PreCond, m_c_And(PatternMatch::m_Value(InitPredX),
+                              PatternMatch::m_One())) ||
+      InitPredX != InitX)
+    return false;
+  auto *PrePreCondBB = PreCondBB->getSinglePredecessor();
+  if (!PrePreCondBB)
+    return false;
+  auto *PrePreCondBI = dyn_cast<BranchInst>(PrePreCondBB->getTerminator());
+  if (!PrePreCondBI)
+    return false;
+  if (matchCondition(PrePreCondBI, PreCondBB) != InitX)
+    return false;
+
+  // Only the canonical 7-instruction form of the loop block is transformed;
+  // @llvm.dbg intrinsics are not counted, as they have no semantic effect.
+  const size_t IdiomCanonicalSize = 7;
+  auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug();
+  uint32_t HeaderSize =
+      std::distance(InstWithoutDebugIt.begin(), InstWithoutDebugIt.end());
+  if (HeaderSize != IdiomCanonicalSize)
+    return false;
+
+  transformLoopToCtz(PH, CntInst, CntPhi, InitX);
+  return true;
+}
+
+/// Emit "cttz(InitX)" at the end of \p Preheader and make every user of
+/// \p CntInst outside the loop body read that value (plus the counter's initial
+/// value) instead. The loop itself is left in place.
+void LoopIdiomRecognize::transformLoopToCtz(BasicBlock *Preheader,
+                                            Instruction *CntInst,
+                                            PHINode *CntPhi, Value *InitX) {
+  const DebugLoc &Loc = CntInst->getDebugLoc();
+
+  // New instructions go right before the preheader's terminating branch.
+  IRBuilder<> Builder(cast<BranchInst>(Preheader->getTerminator()));
+  Builder.SetCurrentDebugLocation(Loc);
+
+  // Convert the cttz result to the counter's type; the two may differ in width.
+  Value *TrailingZeros = Builder.CreateZExtOrTrunc(
+      createFFSIntrinsic(Builder, InitX, Loc, /* is zero poison */ true,
+                         Intrinsic::cttz),
+      CntInst->getType());
+
+  // A counter that starts at a value other than the constant zero contributes
+  // that starting value to the final count.
+  Value *CntStart = CntPhi->getIncomingValueForBlock(Preheader);
+  auto *StartConst = dyn_cast<ConstantInt>(CntStart);
+  const bool StartsAtZero = StartConst && StartConst->isZero();
+  if (!StartsAtZero)
+    TrailingZeros = Builder.CreateAdd(TrailingZeros, CntStart);
+
+  // Redirect every use of the counter outside the loop body to the new value.
+  CntInst->replaceUsesOutsideBlock(TrailingZeros, *(CurLoop->block_begin()));
+  SE->forgetLoop(CurLoop);
+}
\ No newline at end of file
diff --git a/llvm/test/Transforms/LoopIdiom/RISCV/cttz.ll b/llvm/test/Transforms/LoopIdiom/RISCV/cttz.ll
new file mode 100644
index 00000000000000..5c32d497829348
--- /dev/null
+++ b/llvm/test/Transforms/LoopIdiom/RISCV/cttz.ll
@@ -0,0 +1,147 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=loop-idiom -mtriple=riscv32 -S < %s | FileCheck %s
+; RUN: opt -passes=loop-idiom -mtriple=riscv64 -S < %s | FileCheck %s
+
+; Copied from popcnt test.
+
+; Check that the loop of this C function is recognized:
+;int ctz(uint32_t n)
+;{
+; int count = 0;
+; if (n == 0)
+; {
+; return 32;
+; }
+; while ((n & 1) == 0)
+; {
+; count += 1;
+; n >>= 1;
+; }
+; return count;
+;}
+; i32 value, i32 counter: the cttz result is used directly, without a cast.
+define signext i32 @count_trailing_zeroes(i32 noundef signext %n) local_unnamed_addr #0 {
+; CHECK-LABEL: @count_trailing_zeroes(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[N:%.*]], 0
+; CHECK-NEXT: br i1 [[CMP]], label [[CLEANUP:%.*]], label [[WHILE_COND_PREHEADER:%.*]]
+; CHECK: while.cond.preheader:
+; CHECK-NEXT: [[AND4:%.*]] = and i32 [[N]], 1
+; CHECK-NEXT: [[CMP15:%.*]] = icmp eq i32 [[AND4]], 0
+; CHECK-NEXT: br i1 [[CMP15]], label [[WHILE_BODY_PREHEADER:%.*]], label [[CLEANUP]]
+; CHECK: while.body.preheader:
+; CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.cttz.i32(i32 [[N]], i1 true)
+; CHECK-NEXT: br label [[WHILE_BODY:%.*]]
+; CHECK: while.body:
+; CHECK-NEXT: [[COUNT_07:%.*]] = phi i32 [ [[ADD:%.*]], [[WHILE_BODY]] ], [ 0, [[WHILE_BODY_PREHEADER]] ]
+; CHECK-NEXT: [[N_ADDR_06:%.*]] = phi i32 [ [[SHR:%.*]], [[WHILE_BODY]] ], [ [[N]], [[WHILE_BODY_PREHEADER]] ]
+; CHECK-NEXT: [[ADD]] = add nuw nsw i32 [[COUNT_07]], 1
+; CHECK-NEXT: [[SHR]] = lshr i32 [[N_ADDR_06]], 1
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[N_ADDR_06]], 2
+; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[TMP1]], 0
+; CHECK-NEXT: br i1 [[CMP1]], label [[WHILE_BODY]], label [[CLEANUP_LOOPEXIT:%.*]]
+; CHECK: cleanup.loopexit:
+; CHECK-NEXT: [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP0]], [[WHILE_BODY]] ]
+; CHECK-NEXT: br label [[CLEANUP]]
+; CHECK: cleanup:
+; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ 32, [[ENTRY:%.*]] ], [ 0, [[WHILE_COND_PREHEADER]] ], [ [[ADD_LCSSA]], [[CLEANUP_LOOPEXIT]] ]
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+entry:
+ %cmp = icmp eq i32 %n, 0
+ br i1 %cmp, label %cleanup, label %while.cond.preheader
+
+while.cond.preheader: ; preds = %entry
+ %and4 = and i32 %n, 1
+ %cmp15 = icmp eq i32 %and4, 0
+ br i1 %cmp15, label %while.body, label %cleanup
+
+while.body: ; preds = %while.cond.preheader, %while.body
+ %count.07 = phi i32 [ %add, %while.body ], [ 0, %while.cond.preheader ]
+ %n.addr.06 = phi i32 [ %shr, %while.body ], [ %n, %while.cond.preheader ]
+ %add = add nuw nsw i32 %count.07, 1
+ %shr = lshr i32 %n.addr.06, 1
+ %0 = and i32 %n.addr.06, 2
+ %cmp1 = icmp eq i32 %0, 0
+ br i1 %cmp1, label %while.body, label %cleanup
+
+cleanup: ; preds = %while.body, %while.cond.preheader, %entry
+ %retval.0 = phi i32 [ 32, %entry ], [ 0, %while.cond.preheader ], [ %add, %while.body ]
+ ret i32 %retval.0
+}
+
+;int ctz(uint64_t n)
+;{
+; int count = 0;
+; if (n != 0)
+; {
+; while ((n & 1) == 0)
+; {
+; n >>= 1;
+; count += 1;
+; }
+; }
+; else
+; {
+; return 64;
+; }
+; return count;
+;}
+; i64 value with an i32 counter: the cttz result is truncated to i32.
+define dso_local signext i32 @ctz(i64 noundef %n) local_unnamed_addr {
+; CHECK-LABEL: @ctz(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i64 [[N:%.*]], 0
+; CHECK-NEXT: br i1 [[CMP_NOT]], label [[CLEANUP:%.*]], label [[WHILE_COND_PREHEADER:%.*]]
+; CHECK: while.cond.preheader:
+; CHECK-NEXT: [[AND5:%.*]] = and i64 [[N]], 1
+; CHECK-NEXT: [[CMP16:%.*]] = icmp eq i64 [[AND5]], 0
+; CHECK-NEXT: br i1 [[CMP16]], label [[WHILE_BODY_PREHEADER:%.*]], label [[CLEANUP]]
+; CHECK: while.body.preheader:
+; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.cttz.i64(i64 [[N]], i1 true)
+; CHECK-NEXT: [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32
+; CHECK-NEXT: br label [[WHILE_BODY:%.*]]
+; CHECK: while.body:
+; CHECK-NEXT: [[COUNT_08:%.*]] = phi i32 [ [[ADD:%.*]], [[WHILE_BODY]] ], [ 0, [[WHILE_BODY_PREHEADER]] ]
+; CHECK-NEXT: [[N_ADDR_07:%.*]] = phi i64 [ [[SHR:%.*]], [[WHILE_BODY]] ], [ [[N]], [[WHILE_BODY_PREHEADER]] ]
+; CHECK-NEXT: [[SHR]] = lshr i64 [[N_ADDR_07]], 1
+; CHECK-NEXT: [[ADD]] = add nuw nsw i32 [[COUNT_08]], 1
+; CHECK-NEXT: [[TMP2:%.*]] = and i64 [[N_ADDR_07]], 2
+; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i64 [[TMP2]], 0
+; CHECK-NEXT: br i1 [[CMP1]], label [[WHILE_BODY]], label [[CLEANUP_LOOPEXIT:%.*]]
+; CHECK: cleanup.loopexit:
+; CHECK-NEXT: [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP1]], [[WHILE_BODY]] ]
+; CHECK-NEXT: br label [[CLEANUP]]
+; CHECK: cleanup:
+; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ 64, [[ENTRY:%.*]] ], [ 0, [[WHILE_COND_PREHEADER]] ], [ [[ADD_LCSSA]], [[CLEANUP_LOOPEXIT]] ]
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+entry:
+ %cmp.not = icmp eq i64 %n, 0
+ br i1 %cmp.not, label %cleanup, label %while.cond.preheader
+
+while.cond.preheader: ; preds = %entry
+ %and5 = and i64 %n, 1
+ %cmp16 = icmp eq i64 %and5, 0
+ br i1 %cmp16, label %while.body.preheader, label %cleanup
+
+while.body.preheader: ; preds = %while.cond.preheader
+ br label %while.body
+
+while.body: ; preds = %while.body.preheader, %while.body
+ %count.08 = phi i32 [ %add, %while.body ], [ 0, %while.body.preheader ]
+ %n.addr.07 = phi i64 [ %shr, %while.body ], [ %n, %while.body.preheader ]
+ %shr = lshr i64 %n.addr.07, 1
+ %add = add nuw nsw i32 %count.08, 1
+ %0 = and i64 %n.addr.07, 2
+ %cmp1 = icmp eq i64 %0, 0
+ br i1 %cmp1, label %while.body, label %cleanup.loopexit
+
+cleanup.loopexit: ; preds = %while.body
+ %add.lcssa = phi i32 [ %add, %while.body ]
+ br label %cleanup
+
+cleanup: ; preds = %cleanup.loopexit, %while.cond.preheader, %entry
+ %retval.0 = phi i32 [ 64, %entry ], [ 0, %while.cond.preheader ], [ %add.lcssa, %cleanup.loopexit ]
+ ret i32 %retval.0
+}
More information about the llvm-commits
mailing list