[llvm] [InstCombine] Added optimization for shift add (PR #163502)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 15 20:06:19 PDT 2025


https://github.com/manik-muk updated https://github.com/llvm/llvm-project/pull/163502

>From ff8405b11abd4eae571d0f00333f65d831bbb321 Mon Sep 17 00:00:00 2001
From: Manik Mukherjee <mkmrocks20 at gmail.com>
Date: Wed, 15 Oct 2025 01:40:27 -0400
Subject: [PATCH 1/3] Add InstCombine fold for ashr exact of (shl nsw X, A) +nsw C

---
 .../InstCombine/InstCombineShifts.cpp         |  24 +++
 llvm/test/Transforms/InstCombine/shift-add.ll | 144 ++++++++++++++++++
 2 files changed, 168 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index d457e0c7dd1c4..fc2a0018e725c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1803,6 +1803,30 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
           cast<OverflowingBinaryOperator>(Op0)->hasNoUnsignedWrap());
       return NewAdd;
     }
+
+    // Fold ((X << A) + C) >> B  -->  (X << (A - B)) + (C >> B)
+    // when the shift is exact and the add is nsw.
+    // This transforms patterns like: ((x << 4) + 16) ashr exact 1  -->  (x <<
+    // 3) + 8
+    const APInt *ShlAmt, *AddC;
+    if (I.isExact() &&
+        match(Op0, m_c_NSWAdd(m_NSWShl(m_Value(X), m_APInt(ShlAmt)),
+                              m_APInt(AddC))) &&
+        ShlAmt->uge(ShAmt)) {
+      // Check if C is divisible by (1 << ShAmt)
+      if (AddC->isShiftedMask() || AddC->countTrailingZeros() >= ShAmt ||
+          AddC->ashr(ShAmt).shl(ShAmt) == *AddC) {
+        // X << (A - B)
+        Constant *NewShlAmt = ConstantInt::get(Ty, *ShlAmt - ShAmt);
+        Value *NewShl = Builder.CreateShl(X, NewShlAmt);
+
+        // C >> B
+        Constant *NewAddC = ConstantInt::get(Ty, AddC->ashr(ShAmt));
+
+        // (X << (A - B)) + (C >> B)
+        return BinaryOperator::CreateAdd(NewShl, NewAddC);
+      }
+    }
   }
 
   const SimplifyQuery Q = SQ.getWithInstruction(&I);
diff --git a/llvm/test/Transforms/InstCombine/shift-add.ll b/llvm/test/Transforms/InstCombine/shift-add.ll
index 81cbc2ac23b5f..1d1f219904f74 100644
--- a/llvm/test/Transforms/InstCombine/shift-add.ll
+++ b/llvm/test/Transforms/InstCombine/shift-add.ll
@@ -804,3 +804,147 @@ define <2 x i8> @lshr_fold_or_disjoint_cnt_out_of_bounds(<2 x i8> %x) {
   %r = lshr <2 x i8> <i8 2, i8 3>, %a
   ret <2 x i8> %r
 }
+
+define i32 @ashr_exact_add_shl_fold(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold(
+; CHECK-NEXT:    [[V0:%.*]] = shl i32 [[ARG0:%.*]], 3
+; CHECK-NEXT:    [[V2:%.*]] = add i32 [[V0]], 8
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add nsw i32 %v0, 16
+  %v2 = ashr exact i32 %v1, 1
+  ret i32 %v2
+}
+
+; Test with larger shift amounts
+define i32 @ashr_exact_add_shl_fold_larger_shift(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_larger_shift(
+; CHECK-NEXT:    [[V0:%.*]] = shl i32 [[ARG0:%.*]], 1
+; CHECK-NEXT:    [[V2:%.*]] = add i32 [[V0]], 2
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add nsw i32 %v0, 16
+  %v2 = ashr exact i32 %v1, 3
+  ret i32 %v2
+}
+
+; Test with negative constant
+define i32 @ashr_exact_add_shl_fold_negative_const(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_negative_const(
+; CHECK-NEXT:    [[V0:%.*]] = shl i32 [[ARG0:%.*]], 2
+; CHECK-NEXT:    [[V2:%.*]] = add i32 [[V0]], -4
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add nsw i32 %v0, -16
+  %v2 = ashr exact i32 %v1, 2
+  ret i32 %v2
+}
+
+; Test where shift amount equals shl amount (the shl disappears; result is X + (C >> B))
+define i32 @ashr_exact_add_shl_fold_equal_shifts(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_equal_shifts(
+; CHECK-NEXT:    [[V2:%.*]] = add i32 [[ARG0:%.*]], 1
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add nsw i32 %v0, 16
+  %v2 = ashr exact i32 %v1, 4
+  ret i32 %v2
+}
+
+; Not exact: the new fold does not apply, but per the CHECK lines the ashr is still folded by another transform
+define i32 @ashr_add_shl_no_exact(i32 %arg0) {
+; CHECK-LABEL: @ashr_add_shl_no_exact(
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[ARG0:%.*]], 3
+; CHECK-NEXT:    [[V2:%.*]] = add i32 [[TMP1]], 8
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add nsw i32 %v0, 16
+  %v2 = ashr i32 %v1, 1
+  ret i32 %v2
+}
+
+; Negative test: add is not nsw - should not transform
+define i32 @ashr_exact_add_shl_no_nsw_add(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_no_nsw_add(
+; CHECK-NEXT:    [[V0:%.*]] = shl nsw i32 [[ARG0:%.*]], 4
+; CHECK-NEXT:    [[V1:%.*]] = add i32 [[V0]], 16
+; CHECK-NEXT:    [[V2:%.*]] = ashr exact i32 [[V1]], 1
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add i32 %v0, 16
+  %v2 = ashr exact i32 %v1, 1
+  ret i32 %v2
+}
+
+; Negative test: shl is not nsw - should not transform
+define i32 @ashr_exact_add_shl_no_nsw_shl(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_no_nsw_shl(
+; CHECK-NEXT:    [[V0:%.*]] = shl i32 [[ARG0:%.*]], 4
+; CHECK-NEXT:    [[V1:%.*]] = add nsw i32 [[V0]], 16
+; CHECK-NEXT:    [[V2:%.*]] = ashr exact i32 [[V1]], 1
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl i32 %arg0, 4
+  %v1 = add nsw i32 %v0, 16
+  %v2 = ashr exact i32 %v1, 1
+  ret i32 %v2
+}
+
+; Negative test: constant not a multiple of (1 << B); the exact ashr is poison and, per the CHECK lines, is folded to its operand
+define i32 @ashr_exact_add_shl_not_divisible(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_not_divisible(
+; CHECK-NEXT:    [[V0:%.*]] = shl nsw i32 [[ARG0:%.*]], 4
+; CHECK-NEXT:    [[V1:%.*]] = add nsw i32 [[V0]], 17
+; CHECK-NEXT:    ret i32 [[V1]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add nsw i32 %v0, 17
+  %v2 = ashr exact i32 %v1, 1
+  ret i32 %v2
+}
+
+; Negative test: shift amount greater than shl amount
+define i32 @ashr_exact_add_shl_shift_too_large(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_shift_too_large(
+; CHECK-NEXT:    [[V0:%.*]] = shl nsw i32 [[ARG0:%.*]], 2
+; CHECK-NEXT:    [[V1:%.*]] = add nsw i32 [[V0]], 16
+; CHECK-NEXT:    [[V2:%.*]] = ashr exact i32 [[V1]], 4
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 2
+  %v1 = add nsw i32 %v0, 16
+  %v2 = ashr exact i32 %v1, 4
+  ret i32 %v2
+}
+
+; Vector test
+define <2 x i32> @ashr_exact_add_shl_fold_vector(<2 x i32> %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = shl <2 x i32> [[ARG0:%.*]], splat (i32 3)
+; CHECK-NEXT:    [[V2:%.*]] = add <2 x i32> [[TMP1]], splat (i32 8)
+; CHECK-NEXT:    ret <2 x i32> [[V2]]
+;
+  %v0 = shl nsw <2 x i32> %arg0, <i32 4, i32 4>
+  %v1 = add nsw <2 x i32> %v0, <i32 16, i32 16>
+  %v2 = ashr exact <2 x i32> %v1, <i32 1, i32 1>
+  ret <2 x i32> %v2
+}
+
+; Test commutative add (constant on left)
+define i32 @ashr_exact_add_shl_fold_commute(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_commute(
+; CHECK-NEXT:    [[V0:%.*]] = shl i32 [[ARG0:%.*]], 3
+; CHECK-NEXT:    [[V2:%.*]] = add i32 [[V0]], 8
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add nsw i32 16, %v0
+  %v2 = ashr exact i32 %v1, 1
+  ret i32 %v2
+}

>From 3048e0c0d0f0601e053593f4a0a60f84f2b0f016 Mon Sep 17 00:00:00 2001
From: Manik Mukherjee <mkmrocks20 at gmail.com>
Date: Wed, 15 Oct 2025 23:01:31 -0400
Subject: [PATCH 2/3] Address review comments: add lshr/nuw variant and handle add in canEvaluateShifted

---
 .../InstCombine/InstCombineShifts.cpp         | 72 +++++++++++++++++--
 1 file changed, 66 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index fc2a0018e725c..28570dab83805 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -610,6 +610,32 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
   case Instruction::LShr:
     return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI);
 
+  case Instruction::Add: {
+    // We can fold Add through right shifts if it has the appropriate nowrap
+    // flag. For lshr: requires nuw (no unsigned wrap) For ashr: requires nsw
+    // (no signed wrap) We don't support left shift through add.
+    if (IsLeftShift)
+      return false;
+
+    auto *BO = cast<BinaryOperator>(I);
+
+    // Determine which flag is required based on the shift type
+    bool HasRequiredFlag;
+    if (isa<LShrOperator>(CxtI))
+      HasRequiredFlag = BO->hasNoUnsignedWrap();
+    else if (isa<AShrOperator>(CxtI))
+      HasRequiredFlag = BO->hasNoSignedWrap();
+    else
+      return false;
+
+    if (!HasRequiredFlag)
+      return false;
+
+    // Both operands must be shiftable, pass through CxtI to preserve shift type
+    return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, CxtI) &&
+           canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, CxtI);
+  }
+
   case Instruction::Select: {
     SelectInst *SI = cast<SelectInst>(I);
     Value *TrueVal = SI->getTrueValue();
@@ -731,6 +757,18 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
     return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift,
                             IC.Builder);
 
+  case Instruction::Add:
+    // Shift both operands, then perform the add.
+    I->setOperand(
+        0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
+    I->setOperand(
+        1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
+    // We must clear the nuw/nsw flags because the original values that didn't
+    // overflow might overflow after we shift them.
+    cast<BinaryOperator>(I)->setHasNoUnsignedWrap(false);
+    cast<BinaryOperator>(I)->setHasNoSignedWrap(false);
+    return I;
+
   case Instruction::Select:
     I->setOperand(
         1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
@@ -1635,6 +1673,30 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
       return BinaryOperator::CreateLShr(NewShl, Shl1_Op1);
     }
   }
+
+  // Fold ((X << A) + C) >>u B  -->  (X << (A - B)) + (C >>u B)
+  // when the shift is exact and the add has nuw.
+  const APInt *ShAmtAPInt, *ShlAmt, *AddC;
+  if (match(Op1, m_APInt(ShAmtAPInt)) && I.isExact() &&
+      match(Op0, m_c_NUWAdd(m_NUWShl(m_Value(X), m_APInt(ShlAmt)),
+                            m_APInt(AddC))) &&
+      ShlAmt->uge(*ShAmtAPInt)) {
+    unsigned ShAmt = ShAmtAPInt->getZExtValue();
+    // Check if C is divisible by (1 << ShAmt)
+    if (AddC->isShiftedMask() || AddC->countTrailingZeros() >= ShAmt ||
+        AddC->lshr(ShAmt).shl(ShAmt) == *AddC) {
+      // X << (A - B)
+      Constant *NewShlAmt = ConstantInt::get(Ty, *ShlAmt - ShAmt);
+      Value *NewShl = Builder.CreateShl(X, NewShlAmt);
+
+      // C >>u B
+      Constant *NewAddC = ConstantInt::get(Ty, AddC->lshr(ShAmt));
+
+      // (X << (A - B)) + (C >>u B)
+      return BinaryOperator::CreateAdd(NewShl, NewAddC);
+    }
+  }
+
   return nullptr;
 }
 
@@ -1804,10 +1866,8 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
       return NewAdd;
     }
 
-    // Fold ((X << A) + C) >> B  -->  (X << (A - B)) + (C >> B)
-    // when the shift is exact and the add is nsw.
-    // This transforms patterns like: ((x << 4) + 16) ashr exact 1  -->  (x <<
-    // 3) + 8
+    // Fold ((X << A) + C) >>s B  -->  (X << (A - B)) + (C >>s B)
+    // when the shift is exact and the add has nsw.
     const APInt *ShlAmt, *AddC;
     if (I.isExact() &&
         match(Op0, m_c_NSWAdd(m_NSWShl(m_Value(X), m_APInt(ShlAmt)),
@@ -1820,10 +1880,10 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
         Constant *NewShlAmt = ConstantInt::get(Ty, *ShlAmt - ShAmt);
         Value *NewShl = Builder.CreateShl(X, NewShlAmt);
 
-        // C >> B
+        // C >>s B
         Constant *NewAddC = ConstantInt::get(Ty, AddC->ashr(ShAmt));
 
-        // (X << (A - B)) + (C >> B)
+        // (X << (A - B)) + (C >>s B)
         return BinaryOperator::CreateAdd(NewShl, NewAddC);
       }
     }

>From b866050491ce9977553ecd4e35bdc6766ba41045 Mon Sep 17 00:00:00 2001
From: Manik Mukherjee <mkmrocks20 at gmail.com>
Date: Wed, 15 Oct 2025 23:06:07 -0400
Subject: [PATCH 3/3] Run clang-format

---
 llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 28570dab83805..aafdd92f80925 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -632,7 +632,8 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
       return false;
 
     // Both operands must be shiftable, pass through CxtI to preserve shift type
-    return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, CxtI) &&
+    return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC,
+                              CxtI) &&
            canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, CxtI);
   }
 
@@ -1678,8 +1679,8 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
   // when the shift is exact and the add has nuw.
   const APInt *ShAmtAPInt, *ShlAmt, *AddC;
   if (match(Op1, m_APInt(ShAmtAPInt)) && I.isExact() &&
-      match(Op0, m_c_NUWAdd(m_NUWShl(m_Value(X), m_APInt(ShlAmt)),
-                            m_APInt(AddC))) &&
+      match(Op0,
+            m_c_NUWAdd(m_NUWShl(m_Value(X), m_APInt(ShlAmt)), m_APInt(AddC))) &&
       ShlAmt->uge(*ShAmtAPInt)) {
     unsigned ShAmt = ShAmtAPInt->getZExtValue();
     // Check if C is divisible by (1 << ShAmt)



More information about the llvm-commits mailing list