[llvm] added optimization for shift add (PR #163502)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 14 22:41:46 PDT 2025


https://github.com/manik-muk created https://github.com/llvm/llvm-project/pull/163502

Addresses #163115

>From ff8405b11abd4eae571d0f00333f65d831bbb321 Mon Sep 17 00:00:00 2001
From: Manik Mukherjee <mkmrocks20 at gmail.com>
Date: Wed, 15 Oct 2025 01:40:27 -0400
Subject: [PATCH] added optimization for shift add

---
 .../InstCombine/InstCombineShifts.cpp         |  24 +++
 llvm/test/Transforms/InstCombine/shift-add.ll | 144 ++++++++++++++++++
 2 files changed, 168 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index d457e0c7dd1c4..fc2a0018e725c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1803,6 +1803,30 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
           cast<OverflowingBinaryOperator>(Op0)->hasNoUnsignedWrap());
       return NewAdd;
     }
+
+    // Fold ((X << A) + C) >>s B  -->  (X << (A - B)) + (C >>s B)
+    // when the ashr is exact, the shl and the add are both nsw, and A >= B.
+    // For example:
+    //   ((x << 4) + 16) ashr exact 1  -->  (x << 3) + 8
+    const APInt *ShlAmt, *AddC;
+    if (I.isExact() &&
+        match(Op0, m_c_NSWAdd(m_NSWShl(m_Value(X), m_APInt(ShlAmt)),
+                              m_APInt(AddC))) &&
+        ShlAmt->uge(ShAmt)) {
+      // Check if C is divisible by (1 << ShAmt)
+      if (AddC->isShiftedMask() || AddC->countTrailingZeros() >= ShAmt ||
+          AddC->ashr(ShAmt).shl(ShAmt) == *AddC) {
+        // X << (A - B)
+        Constant *NewShlAmt = ConstantInt::get(Ty, *ShlAmt - ShAmt);
+        Value *NewShl = Builder.CreateShl(X, NewShlAmt);
+
+        // C >> B
+        Constant *NewAddC = ConstantInt::get(Ty, AddC->ashr(ShAmt));
+
+        // (X << (A - B)) + (C >> B)
+        return BinaryOperator::CreateAdd(NewShl, NewAddC);
+      }
+    }
   }
 
   const SimplifyQuery Q = SQ.getWithInstruction(&I);
diff --git a/llvm/test/Transforms/InstCombine/shift-add.ll b/llvm/test/Transforms/InstCombine/shift-add.ll
index 81cbc2ac23b5f..1d1f219904f74 100644
--- a/llvm/test/Transforms/InstCombine/shift-add.ll
+++ b/llvm/test/Transforms/InstCombine/shift-add.ll
@@ -804,3 +804,147 @@ define <2 x i8> @lshr_fold_or_disjoint_cnt_out_of_bounds(<2 x i8> %x) {
   %r = lshr <2 x i8> <i8 2, i8 3>, %a
   ret <2 x i8> %r
 }
+
+define i32 @ashr_exact_add_shl_fold(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold(
+; CHECK-NEXT:    [[V0:%.*]] = shl i32 [[ARG0:%.*]], 3
+; CHECK-NEXT:    [[V2:%.*]] = add i32 [[V0]], 8
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add nsw i32 %v0, 16
+  %v2 = ashr exact i32 %v1, 1
+  ret i32 %v2
+}
+
+; Test with larger shift amounts
+define i32 @ashr_exact_add_shl_fold_larger_shift(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_larger_shift(
+; CHECK-NEXT:    [[V0:%.*]] = shl i32 [[ARG0:%.*]], 1
+; CHECK-NEXT:    [[V2:%.*]] = add i32 [[V0]], 2
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add nsw i32 %v0, 16
+  %v2 = ashr exact i32 %v1, 3
+  ret i32 %v2
+}
+
+; Test with negative constant
+define i32 @ashr_exact_add_shl_fold_negative_const(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_negative_const(
+; CHECK-NEXT:    [[V0:%.*]] = shl i32 [[ARG0:%.*]], 2
+; CHECK-NEXT:    [[V2:%.*]] = add i32 [[V0]], -4
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add nsw i32 %v0, -16
+  %v2 = ashr exact i32 %v1, 2
+  ret i32 %v2
+}
+
+; Test where shift amount equals shl amount (the shl folds away, leaving x plus the shifted constant)
+define i32 @ashr_exact_add_shl_fold_equal_shifts(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_equal_shifts(
+; CHECK-NEXT:    [[V2:%.*]] = add i32 [[ARG0:%.*]], 1
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add nsw i32 %v0, 16
+  %v2 = ashr exact i32 %v1, 4
+  ret i32 %v2
+}
+
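+; Not exact: the new exact-only fold does not apply, yet the expected output below is still the folded shl/add form (presumably produced by another fold)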
+define i32 @ashr_add_shl_no_exact(i32 %arg0) {
+; CHECK-LABEL: @ashr_add_shl_no_exact(
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[ARG0:%.*]], 3
+; CHECK-NEXT:    [[V2:%.*]] = add i32 [[TMP1]], 8
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add nsw i32 %v0, 16
+  %v2 = ashr i32 %v1, 1
+  ret i32 %v2
+}
+
+; Negative test: add is not nsw - should not transform
+define i32 @ashr_exact_add_shl_no_nsw_add(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_no_nsw_add(
+; CHECK-NEXT:    [[V0:%.*]] = shl nsw i32 [[ARG0:%.*]], 4
+; CHECK-NEXT:    [[V1:%.*]] = add i32 [[V0]], 16
+; CHECK-NEXT:    [[V2:%.*]] = ashr exact i32 [[V1]], 1
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add i32 %v0, 16
+  %v2 = ashr exact i32 %v1, 1
+  ret i32 %v2
+}
+
+; Negative test: shl is not nsw - should not transform
+define i32 @ashr_exact_add_shl_no_nsw_shl(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_no_nsw_shl(
+; CHECK-NEXT:    [[V0:%.*]] = shl i32 [[ARG0:%.*]], 4
+; CHECK-NEXT:    [[V1:%.*]] = add nsw i32 [[V0]], 16
+; CHECK-NEXT:    [[V2:%.*]] = ashr exact i32 [[V1]], 1
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl i32 %arg0, 4
+  %v1 = add nsw i32 %v0, 16
+  %v2 = ashr exact i32 %v1, 1
+  ret i32 %v2
+}
+
+; Negative test: constant not divisible by (1 << shift amount)
+define i32 @ashr_exact_add_shl_not_divisible(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_not_divisible(
+; CHECK-NEXT:    [[V0:%.*]] = shl nsw i32 [[ARG0:%.*]], 4
+; CHECK-NEXT:    [[V1:%.*]] = add nsw i32 [[V0]], 17
+; CHECK-NEXT:    ret i32 [[V1]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add nsw i32 %v0, 17
+  %v2 = ashr exact i32 %v1, 1
+  ret i32 %v2
+}
+
+; Negative test: shift amount greater than shl amount
+define i32 @ashr_exact_add_shl_shift_too_large(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_shift_too_large(
+; CHECK-NEXT:    [[V0:%.*]] = shl nsw i32 [[ARG0:%.*]], 2
+; CHECK-NEXT:    [[V1:%.*]] = add nsw i32 [[V0]], 16
+; CHECK-NEXT:    [[V2:%.*]] = ashr exact i32 [[V1]], 4
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 2
+  %v1 = add nsw i32 %v0, 16
+  %v2 = ashr exact i32 %v1, 4
+  ret i32 %v2
+}
+
+; Vector test
+define <2 x i32> @ashr_exact_add_shl_fold_vector(<2 x i32> %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = shl <2 x i32> [[ARG0:%.*]], splat (i32 3)
+; CHECK-NEXT:    [[V2:%.*]] = add <2 x i32> [[TMP1]], splat (i32 8)
+; CHECK-NEXT:    ret <2 x i32> [[V2]]
+;
+  %v0 = shl nsw <2 x i32> %arg0, <i32 4, i32 4>
+  %v1 = add nsw <2 x i32> %v0, <i32 16, i32 16>
+  %v2 = ashr exact <2 x i32> %v1, <i32 1, i32 1>
+  ret <2 x i32> %v2
+}
+
+; Test commutative add (constant on left)
+define i32 @ashr_exact_add_shl_fold_commute(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_commute(
+; CHECK-NEXT:    [[V0:%.*]] = shl i32 [[ARG0:%.*]], 3
+; CHECK-NEXT:    [[V2:%.*]] = add i32 [[V0]], 8
+; CHECK-NEXT:    ret i32 [[V2]]
+;
+  %v0 = shl nsw i32 %arg0, 4
+  %v1 = add nsw i32 16, %v0
+  %v2 = ashr exact i32 %v1, 1
+  ret i32 %v2
+}



More information about the llvm-commits mailing list