[llvm] [AggressiveInstcombine] Fold away shift in or reduction chain. (PR #137875)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Fri May 9 04:26:35 PDT 2025
https://github.com/davemgreen updated https://github.com/llvm/llvm-project/pull/137875
>From 57ff4e8313a9864bce00c140d27f534e4a2a9625 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Fri, 9 May 2025 12:26:25 +0100
Subject: [PATCH] [AggressiveInstcombine] Fold away shift in or reduction
chain. #137875
If we have `icmp eq (or a, (shl b, s)), 0` then the shift can be removed so long
as it is nuw or nsw: such a shift cannot turn a non-zero value into zero, so the
compare is still testing whether any bits are set.
https://alive2.llvm.org/ce/z/nhrBVX.
The same holds for `icmp ne`, and for longer chains of `or` instructions.
---
.../AggressiveInstCombine.cpp | 57 +++++++++++++++++++
.../AggressiveInstCombine/or-shift-chain.ll | 26 +++------
2 files changed, 65 insertions(+), 18 deletions(-)
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index af994022a8ec1..8f1a216004b08 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -827,6 +827,62 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
return true;
}
+/// Rebuild a one-use or-chain with every nuw/nsw shl replaced by its first
+/// operand. Only equivalent when compared against 0, i.e. are any bits set.
+static Value *optimizeShiftInOrChain(Value *V, IRBuilder<> &Builder) {
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I || I->getOpcode() != Instruction::Or || !I->hasOneUse())
+ return nullptr; // Not a single-use or: nothing to rewrite at this node.
+
+ Value *A;
+
+ // Replace each operand that is a nuw/nsw shl by the value being shifted;
+ // otherwise recurse to look deeper into the chain of or's.
+ Value *Op0 = I->getOperand(0);
+ if (match(Op0, m_CombineOr(m_NSWShl(m_Value(A), m_Value()),
+ m_NUWShl(m_Value(A), m_Value()))))
+ Op0 = A;
+ else if (auto *NOp = optimizeShiftInOrChain(Op0, Builder))
+ Op0 = NOp;
+
+ Value *Op1 = I->getOperand(1);
+ if (match(Op1, m_CombineOr(m_NSWShl(m_Value(A), m_Value()),
+ m_NUWShl(m_Value(A), m_Value()))))
+ Op1 = A;
+ else if (auto *NOp = optimizeShiftInOrChain(Op1, Builder))
+ Op1 = NOp;
+
+ if (Op0 != I->getOperand(0) || Op1 != I->getOperand(1))
+ return Builder.CreateOr(Op0, Op1); // New or built from the updated operands.
+ return nullptr; // Neither operand changed, so report that no rewrite happened.
+}
+
+static bool foldICmpOrChain(Instruction &I, const DataLayout &DL,
+ TargetTransformInfo &TTI, AliasAnalysis &AA,
+ const DominatorTree &DT) {
+ CmpPredicate Pred;
+ Value *Op0;
+ if (!match(&I, m_ICmp(Pred, m_Value(Op0), m_Zero())) ||
+ !ICmpInst::isEquality(Pred))
+ return false; // Only equality (eq/ne) compares against zero are handled.
+
+ // If the chain of or's matches a load, combine to that before attempting to
+ // remove shifts.
+ if (auto OpI = dyn_cast<Instruction>(Op0))
+ if (OpI->getOpcode() == Instruction::Or)
+ if (foldConsecutiveLoads(*OpI, DL, TTI, AA, DT))
+ return true;
+
+ IRBuilder<> Builder(&I); // New instructions are inserted before the icmp.
+ // icmp eq/ne or(shl(a), b), 0 -> icmp eq/ne or(a, b), 0
+ if (auto *Res = optimizeShiftInOrChain(Op0, Builder)) {
+ I.replaceAllUsesWith(Builder.CreateICmp(Pred, Res, I.getOperand(1)));
+ return true; // I is not erased here; only its uses are redirected.
+ }
+
+ return false;
+}
+
// Calculate GEP Stride and accumulated const ModOffset. Return Stride and
// ModOffset
static std::pair<APInt, APInt>
@@ -1253,6 +1309,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
MadeChange |= tryToRecognizeTableBasedCttz(I);
MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
MadeChange |= foldPatternedLoads(I, DL);
+ MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
// NOTE: This function introduces erasing of the instruction `I`, so it
// needs to be called at the end of this sequence, otherwise we may make
// bugs.
diff --git a/llvm/test/Transforms/AggressiveInstCombine/or-shift-chain.ll b/llvm/test/Transforms/AggressiveInstCombine/or-shift-chain.ll
index 6816ccc7bd02b..b50f957c5aefe 100644
--- a/llvm/test/Transforms/AggressiveInstCombine/or-shift-chain.ll
+++ b/llvm/test/Transforms/AggressiveInstCombine/or-shift-chain.ll
@@ -3,8 +3,7 @@
define i1 @remove_shift_nuw_ab(i8 %a, i8 %b, i8 %s) {
; CHECK-LABEL: @remove_shift_nuw_ab(
-; CHECK-NEXT: [[T:%.*]] = shl nuw i8 [[A:%.*]], [[S:%.*]]
-; CHECK-NEXT: [[OR:%.*]] = or i8 [[T]], [[B:%.*]]
+; CHECK-NEXT: [[OR:%.*]] = or i8 [[T:%.*]], [[B:%.*]]
; CHECK-NEXT: [[IC:%.*]] = icmp eq i8 [[OR]], 0
; CHECK-NEXT: ret i1 [[IC]]
;
@@ -16,8 +15,7 @@ define i1 @remove_shift_nuw_ab(i8 %a, i8 %b, i8 %s) {
define i1 @remove_shift_nuw_ba(i8 %a, i8 %b, i8 %s) {
; CHECK-LABEL: @remove_shift_nuw_ba(
-; CHECK-NEXT: [[T:%.*]] = shl nuw i8 [[A:%.*]], [[S:%.*]]
-; CHECK-NEXT: [[OR:%.*]] = or i8 [[B:%.*]], [[T]]
+; CHECK-NEXT: [[OR:%.*]] = or i8 [[B:%.*]], [[T:%.*]]
; CHECK-NEXT: [[IC:%.*]] = icmp eq i8 [[OR]], 0
; CHECK-NEXT: ret i1 [[IC]]
;
@@ -29,8 +27,7 @@ define i1 @remove_shift_nuw_ba(i8 %a, i8 %b, i8 %s) {
define i1 @remove_shift_nsw(i8 %a, i8 %b, i8 %s) {
; CHECK-LABEL: @remove_shift_nsw(
-; CHECK-NEXT: [[T:%.*]] = shl nsw i8 [[A:%.*]], [[S:%.*]]
-; CHECK-NEXT: [[OR:%.*]] = or i8 [[T]], [[B:%.*]]
+; CHECK-NEXT: [[OR:%.*]] = or i8 [[T:%.*]], [[B:%.*]]
; CHECK-NEXT: [[IC:%.*]] = icmp eq i8 [[OR]], 0
; CHECK-NEXT: ret i1 [[IC]]
;
@@ -42,8 +39,7 @@ define i1 @remove_shift_nsw(i8 %a, i8 %b, i8 %s) {
define i1 @remove_shift_nuw_ne(i8 %a, i8 %b, i8 %s) {
; CHECK-LABEL: @remove_shift_nuw_ne(
-; CHECK-NEXT: [[T:%.*]] = shl nuw i8 [[A:%.*]], [[S:%.*]]
-; CHECK-NEXT: [[OR:%.*]] = or i8 [[T]], [[B:%.*]]
+; CHECK-NEXT: [[OR:%.*]] = or i8 [[T:%.*]], [[B:%.*]]
; CHECK-NEXT: [[IC:%.*]] = icmp eq i8 [[OR]], 0
; CHECK-NEXT: ret i1 [[IC]]
;
@@ -55,8 +51,7 @@ define i1 @remove_shift_nuw_ne(i8 %a, i8 %b, i8 %s) {
define i1 @remove_shift_nsw_ne(i8 %a, i8 %b, i8 %s) {
; CHECK-LABEL: @remove_shift_nsw_ne(
-; CHECK-NEXT: [[T:%.*]] = shl nsw i8 [[A:%.*]], [[S:%.*]]
-; CHECK-NEXT: [[OR:%.*]] = or i8 [[T]], [[B:%.*]]
+; CHECK-NEXT: [[OR:%.*]] = or i8 [[T:%.*]], [[B:%.*]]
; CHECK-NEXT: [[IC:%.*]] = icmp eq i8 [[OR]], 0
; CHECK-NEXT: ret i1 [[IC]]
;
@@ -81,9 +76,8 @@ define i1 @remove_shift_wraps(i8 %a, i8 %b, i8 %s) {
define i1 @remove_shift_chain_d(i8 %a, i8 %b, i8 %c, i8 %d, i8 %s) {
; CHECK-LABEL: @remove_shift_chain_d(
-; CHECK-NEXT: [[DT:%.*]] = shl nuw i8 [[D:%.*]], [[S:%.*]]
; CHECK-NEXT: [[OR1:%.*]] = or i8 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[OR2:%.*]] = or i8 [[C:%.*]], [[DT]]
+; CHECK-NEXT: [[OR2:%.*]] = or i8 [[C:%.*]], [[DT:%.*]]
; CHECK-NEXT: [[OR:%.*]] = or i8 [[OR1]], [[OR2]]
; CHECK-NEXT: [[IC:%.*]] = icmp eq i8 [[OR]], 0
; CHECK-NEXT: ret i1 [[IC]]
@@ -98,12 +92,8 @@ define i1 @remove_shift_chain_d(i8 %a, i8 %b, i8 %c, i8 %d, i8 %s) {
define i1 @remove_shift_chain_abcd(i8 %a, i8 %b, i8 %c, i8 %d, i8 %s) {
; CHECK-LABEL: @remove_shift_chain_abcd(
-; CHECK-NEXT: [[AT:%.*]] = shl nuw i8 [[A:%.*]], [[S:%.*]]
-; CHECK-NEXT: [[BT:%.*]] = shl nuw i8 [[B:%.*]], 2
-; CHECK-NEXT: [[CT:%.*]] = shl nuw i8 [[C:%.*]], 1
-; CHECK-NEXT: [[DT:%.*]] = shl nuw i8 [[D:%.*]], [[S]]
-; CHECK-NEXT: [[OR1:%.*]] = or i8 [[AT]], [[BT]]
-; CHECK-NEXT: [[OR2:%.*]] = or i8 [[CT]], [[DT]]
+; CHECK-NEXT: [[OR1:%.*]] = or i8 [[AT:%.*]], [[BT:%.*]]
+; CHECK-NEXT: [[OR2:%.*]] = or i8 [[CT:%.*]], [[DT:%.*]]
; CHECK-NEXT: [[OR:%.*]] = or i8 [[OR1]], [[OR2]]
; CHECK-NEXT: [[IC:%.*]] = icmp eq i8 [[OR]], 0
; CHECK-NEXT: ret i1 [[IC]]
More information about the llvm-commits
mailing list