[llvm] [AggressiveInstcombine] Fold away shift in or reduction chain. (PR #137875)

David Green via llvm-commits llvm-commits at lists.llvm.org
Fri May 9 04:26:35 PDT 2025


https://github.com/davemgreen updated https://github.com/llvm/llvm-project/pull/137875

>From 57ff4e8313a9864bce00c140d27f534e4a2a9625 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Fri, 9 May 2025 12:26:25 +0100
Subject: [PATCH] [AggressiveInstcombine] Fold away shift in or reduction
 chain. #137875

If we have icmp eq or(a, shl(b)), 0 then the shift can be removed so long as it
is nuw or nsw: a no-wrap shl of b is zero exactly when b is zero, so the
comparison still only tests whether any bit is set.
Proof: https://alive2.llvm.org/ce/z/nhrBVX

The same holds for icmp ne, and for longer chains of or's.
---
 .../AggressiveInstCombine.cpp                 | 57 +++++++++++++++++++
 .../AggressiveInstCombine/or-shift-chain.ll   | 26 +++------
 2 files changed, 65 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index af994022a8ec1..8f1a216004b08 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -827,6 +827,62 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
   return true;
 }
 
+/// Rebuild a single-use `or` tree with its nuw/nsw `shl` operands stripped.
+/// Only valid when the result is compared against zero: a no-wrap shift is
+/// zero exactly when its shifted operand is, so "any bit set" is unchanged.
+/// Returns nullptr when V is not such an `or` or nothing could be removed.
+static Value *optimizeShiftInOrChain(Value *V, IRBuilder<> &Builder) {
+  auto *OrI = dyn_cast<Instruction>(V);
+  if (!OrI || OrI->getOpcode() != Instruction::Or || !OrI->hasOneUse())
+    return nullptr;
+
+  // Peel a nuw/nsw shl off an operand, or else recurse into a nested `or`;
+  // the operand is handed back untouched when neither applies.
+  auto Simplify = [&Builder](Value *Operand) -> Value * {
+    Value *Shifted;
+    if (match(Operand, m_CombineOr(m_NSWShl(m_Value(Shifted), m_Value()),
+                                   m_NUWShl(m_Value(Shifted), m_Value()))))
+      return Shifted;
+    if (Value *Inner = optimizeShiftInOrChain(Operand, Builder))
+      return Inner;
+    return Operand;
+  };
+
+  // LHS before RHS: the recursion may emit new instructions via Builder.
+  Value *LHS = Simplify(OrI->getOperand(0));
+  Value *RHS = Simplify(OrI->getOperand(1));
+
+  if (LHS == OrI->getOperand(0) && RHS == OrI->getOperand(1))
+    return nullptr;
+  return Builder.CreateOr(LHS, RHS);
+}
+
+/// Fold `icmp eq/ne (or-chain), 0`: first offer the or-chain to
+/// foldConsecutiveLoads, then try to strip nuw/nsw shl's out of the chain.
+static bool foldICmpOrChain(Instruction &I, const DataLayout &DL,
+                            TargetTransformInfo &TTI, AliasAnalysis &AA,
+                            const DominatorTree &DT) {
+  CmpPredicate Pred;
+  Value *Op0;
+  if (!match(&I, m_ICmp(Pred, m_Value(Op0), m_Zero())) ||
+      !ICmpInst::isEquality(Pred))
+    return false;
+
+  // If the or chain matches a load, combine to that before removing shifts.
+  if (auto *OpI = dyn_cast<Instruction>(Op0))
+    if (OpI->getOpcode() == Instruction::Or &&
+        foldConsecutiveLoads(*OpI, DL, TTI, AA, DT))
+      return true;
+
+  // icmp eq/ne or(shl(a), b), 0 -> icmp eq/ne or(a, b), 0
+  IRBuilder<> Builder(&I);
+  Value *Res = optimizeShiftInOrChain(Op0, Builder);
+  if (!Res)
+    return false;
+  I.replaceAllUsesWith(Builder.CreateICmp(Pred, Res, I.getOperand(1)));
+  return true;
+}
+
 // Calculate GEP Stride and accumulated const ModOffset. Return Stride and
 // ModOffset
 static std::pair<APInt, APInt>
@@ -1253,6 +1309,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
       MadeChange |= tryToRecognizeTableBasedCttz(I);
       MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
       MadeChange |= foldPatternedLoads(I, DL);
+      MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
       // NOTE: This function introduces erasing of the instruction `I`, so it
       // needs to be called at the end of this sequence, otherwise we may make
       // bugs.
diff --git a/llvm/test/Transforms/AggressiveInstCombine/or-shift-chain.ll b/llvm/test/Transforms/AggressiveInstCombine/or-shift-chain.ll
index 6816ccc7bd02b..b50f957c5aefe 100644
--- a/llvm/test/Transforms/AggressiveInstCombine/or-shift-chain.ll
+++ b/llvm/test/Transforms/AggressiveInstCombine/or-shift-chain.ll
@@ -3,8 +3,7 @@
 
 define i1 @remove_shift_nuw_ab(i8 %a, i8 %b, i8 %s) {
 ; CHECK-LABEL: @remove_shift_nuw_ab(
-; CHECK-NEXT:    [[T:%.*]] = shl nuw i8 [[A:%.*]], [[S:%.*]]
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[T]], [[B:%.*]]
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[T:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    [[IC:%.*]] = icmp eq i8 [[OR]], 0
 ; CHECK-NEXT:    ret i1 [[IC]]
 ;
@@ -16,8 +15,7 @@ define i1 @remove_shift_nuw_ab(i8 %a, i8 %b, i8 %s) {
 
 define i1 @remove_shift_nuw_ba(i8 %a, i8 %b, i8 %s) {
 ; CHECK-LABEL: @remove_shift_nuw_ba(
-; CHECK-NEXT:    [[T:%.*]] = shl nuw i8 [[A:%.*]], [[S:%.*]]
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[B:%.*]], [[T]]
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[B:%.*]], [[T:%.*]]
 ; CHECK-NEXT:    [[IC:%.*]] = icmp eq i8 [[OR]], 0
 ; CHECK-NEXT:    ret i1 [[IC]]
 ;
@@ -29,8 +27,7 @@ define i1 @remove_shift_nuw_ba(i8 %a, i8 %b, i8 %s) {
 
 define i1 @remove_shift_nsw(i8 %a, i8 %b, i8 %s) {
 ; CHECK-LABEL: @remove_shift_nsw(
-; CHECK-NEXT:    [[T:%.*]] = shl nsw i8 [[A:%.*]], [[S:%.*]]
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[T]], [[B:%.*]]
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[T:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    [[IC:%.*]] = icmp eq i8 [[OR]], 0
 ; CHECK-NEXT:    ret i1 [[IC]]
 ;
@@ -42,8 +39,7 @@ define i1 @remove_shift_nsw(i8 %a, i8 %b, i8 %s) {
 
 define i1 @remove_shift_nuw_ne(i8 %a, i8 %b, i8 %s) {
 ; CHECK-LABEL: @remove_shift_nuw_ne(
-; CHECK-NEXT:    [[T:%.*]] = shl nuw i8 [[A:%.*]], [[S:%.*]]
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[T]], [[B:%.*]]
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[T:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    [[IC:%.*]] = icmp eq i8 [[OR]], 0
 ; CHECK-NEXT:    ret i1 [[IC]]
 ;
@@ -55,8 +51,7 @@ define i1 @remove_shift_nuw_ne(i8 %a, i8 %b, i8 %s) {
 
 define i1 @remove_shift_nsw_ne(i8 %a, i8 %b, i8 %s) {
 ; CHECK-LABEL: @remove_shift_nsw_ne(
-; CHECK-NEXT:    [[T:%.*]] = shl nsw i8 [[A:%.*]], [[S:%.*]]
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[T]], [[B:%.*]]
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[T:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    [[IC:%.*]] = icmp eq i8 [[OR]], 0
 ; CHECK-NEXT:    ret i1 [[IC]]
 ;
@@ -81,9 +76,8 @@ define i1 @remove_shift_wraps(i8 %a, i8 %b, i8 %s) {
 
 define i1 @remove_shift_chain_d(i8 %a, i8 %b, i8 %c, i8 %d, i8 %s) {
 ; CHECK-LABEL: @remove_shift_chain_d(
-; CHECK-NEXT:    [[DT:%.*]] = shl nuw i8 [[D:%.*]], [[S:%.*]]
 ; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[C:%.*]], [[DT]]
+; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[C:%.*]], [[DT:%.*]]
 ; CHECK-NEXT:    [[OR:%.*]] = or i8 [[OR1]], [[OR2]]
 ; CHECK-NEXT:    [[IC:%.*]] = icmp eq i8 [[OR]], 0
 ; CHECK-NEXT:    ret i1 [[IC]]
@@ -98,12 +92,8 @@ define i1 @remove_shift_chain_d(i8 %a, i8 %b, i8 %c, i8 %d, i8 %s) {
 
 define i1 @remove_shift_chain_abcd(i8 %a, i8 %b, i8 %c, i8 %d, i8 %s) {
 ; CHECK-LABEL: @remove_shift_chain_abcd(
-; CHECK-NEXT:    [[AT:%.*]] = shl nuw i8 [[A:%.*]], [[S:%.*]]
-; CHECK-NEXT:    [[BT:%.*]] = shl nuw i8 [[B:%.*]], 2
-; CHECK-NEXT:    [[CT:%.*]] = shl nuw i8 [[C:%.*]], 1
-; CHECK-NEXT:    [[DT:%.*]] = shl nuw i8 [[D:%.*]], [[S]]
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[AT]], [[BT]]
-; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[CT]], [[DT]]
+; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[AT:%.*]], [[BT:%.*]]
+; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[CT:%.*]], [[DT:%.*]]
 ; CHECK-NEXT:    [[OR:%.*]] = or i8 [[OR1]], [[OR2]]
 ; CHECK-NEXT:    [[IC:%.*]] = icmp eq i8 [[OR]], 0
 ; CHECK-NEXT:    ret i1 [[IC]]



More information about the llvm-commits mailing list