[llvm] [InstCombine] Added optimization for shift add (PR #163502)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 15 20:06:19 PDT 2025
https://github.com/manik-muk updated https://github.com/llvm/llvm-project/pull/163502
From ff8405b11abd4eae571d0f00333f65d831bbb321 Mon Sep 17 00:00:00 2001
From: Manik Mukherjee <mkmrocks20 at gmail.com>
Date: Wed, 15 Oct 2025 01:40:27 -0400
Subject: [PATCH 1/3] added optimization for shift add
---
.../InstCombine/InstCombineShifts.cpp | 24 +++
llvm/test/Transforms/InstCombine/shift-add.ll | 144 ++++++++++++++++++
2 files changed, 168 insertions(+)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index d457e0c7dd1c4..fc2a0018e725c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1803,6 +1803,30 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
cast<OverflowingBinaryOperator>(Op0)->hasNoUnsignedWrap());
return NewAdd;
}
+
+ // Fold ((X << A) + C) >> B --> (X << (A - B)) + (C >> B)
+ // when the shift is exact and the add is nsw.
+ // This transforms patterns like:
+ //   ((x << 4) + 16) ashr exact 1 --> (x << 3) + 8
+ const APInt *ShlAmt, *AddC;
+ if (I.isExact() &&
+ match(Op0, m_c_NSWAdd(m_NSWShl(m_Value(X), m_APInt(ShlAmt)),
+ m_APInt(AddC))) &&
+ ShlAmt->uge(ShAmt)) {
+ // Check if C is divisible by (1 << ShAmt)
+ if (AddC->countTrailingZeros() >= ShAmt ||
+ AddC->ashr(ShAmt).shl(ShAmt) == *AddC) {
+ // X << (A - B)
+ Constant *NewShlAmt = ConstantInt::get(Ty, *ShlAmt - ShAmt);
+ Value *NewShl = Builder.CreateShl(X, NewShlAmt);
+
+ // C >> B
+ Constant *NewAddC = ConstantInt::get(Ty, AddC->ashr(ShAmt));
+
+ // (X << (A - B)) + (C >> B)
+ return BinaryOperator::CreateAdd(NewShl, NewAddC);
+ }
+ }
}
const SimplifyQuery Q = SQ.getWithInstruction(&I);
diff --git a/llvm/test/Transforms/InstCombine/shift-add.ll b/llvm/test/Transforms/InstCombine/shift-add.ll
index 81cbc2ac23b5f..1d1f219904f74 100644
--- a/llvm/test/Transforms/InstCombine/shift-add.ll
+++ b/llvm/test/Transforms/InstCombine/shift-add.ll
@@ -804,3 +804,147 @@ define <2 x i8> @lshr_fold_or_disjoint_cnt_out_of_bounds(<2 x i8> %x) {
%r = lshr <2 x i8> <i8 2, i8 3>, %a
ret <2 x i8> %r
}
+
+define i32 @ashr_exact_add_shl_fold(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold(
+; CHECK-NEXT: [[V0:%.*]] = shl i32 [[ARG0:%.*]], 3
+; CHECK-NEXT: [[V2:%.*]] = add i32 [[V0]], 8
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add nsw i32 %v0, 16
+ %v2 = ashr exact i32 %v1, 1
+ ret i32 %v2
+}
+
+; Test with larger shift amounts
+define i32 @ashr_exact_add_shl_fold_larger_shift(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_larger_shift(
+; CHECK-NEXT: [[V0:%.*]] = shl i32 [[ARG0:%.*]], 1
+; CHECK-NEXT: [[V2:%.*]] = add i32 [[V0]], 2
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add nsw i32 %v0, 16
+ %v2 = ashr exact i32 %v1, 3
+ ret i32 %v2
+}
+
+; Test with negative constant
+define i32 @ashr_exact_add_shl_fold_negative_const(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_negative_const(
+; CHECK-NEXT: [[V0:%.*]] = shl i32 [[ARG0:%.*]], 2
+; CHECK-NEXT: [[V2:%.*]] = add i32 [[V0]], -4
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add nsw i32 %v0, -16
+ %v2 = ashr exact i32 %v1, 2
+ ret i32 %v2
+}
+
+; Test where shift amount equals shl amount (result is just the constant)
+define i32 @ashr_exact_add_shl_fold_equal_shifts(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_equal_shifts(
+; CHECK-NEXT: [[V2:%.*]] = add i32 [[ARG0:%.*]], 1
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add nsw i32 %v0, 16
+ %v2 = ashr exact i32 %v1, 4
+ ret i32 %v2
+}
+
+; Not exact: the new fold does not apply, but existing folds still simplify this
+define i32 @ashr_add_shl_no_exact(i32 %arg0) {
+; CHECK-LABEL: @ashr_add_shl_no_exact(
+; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[ARG0:%.*]], 3
+; CHECK-NEXT: [[V2:%.*]] = add i32 [[TMP1]], 8
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add nsw i32 %v0, 16
+ %v2 = ashr i32 %v1, 1
+ ret i32 %v2
+}
+
+; Negative test: add is not nsw - should not transform
+define i32 @ashr_exact_add_shl_no_nsw_add(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_no_nsw_add(
+; CHECK-NEXT: [[V0:%.*]] = shl nsw i32 [[ARG0:%.*]], 4
+; CHECK-NEXT: [[V1:%.*]] = add i32 [[V0]], 16
+; CHECK-NEXT: [[V2:%.*]] = ashr exact i32 [[V1]], 1
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add i32 %v0, 16
+ %v2 = ashr exact i32 %v1, 1
+ ret i32 %v2
+}
+
+; Negative test: shl is not nsw - should not transform
+define i32 @ashr_exact_add_shl_no_nsw_shl(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_no_nsw_shl(
+; CHECK-NEXT: [[V0:%.*]] = shl i32 [[ARG0:%.*]], 4
+; CHECK-NEXT: [[V1:%.*]] = add nsw i32 [[V0]], 16
+; CHECK-NEXT: [[V2:%.*]] = ashr exact i32 [[V1]], 1
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl i32 %arg0, 4
+ %v1 = add nsw i32 %v0, 16
+ %v2 = ashr exact i32 %v1, 1
+ ret i32 %v2
+}
+
+; Negative test: constant not divisible by (1 << shift amount), so the new fold does not apply
+define i32 @ashr_exact_add_shl_not_divisible(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_not_divisible(
+; CHECK-NEXT: [[V0:%.*]] = shl nsw i32 [[ARG0:%.*]], 4
+; CHECK-NEXT: [[V1:%.*]] = add nsw i32 [[V0]], 17
+; CHECK-NEXT: ret i32 [[V1]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add nsw i32 %v0, 17
+ %v2 = ashr exact i32 %v1, 1
+ ret i32 %v2
+}
+
+; Negative test: shift amount greater than shl amount
+define i32 @ashr_exact_add_shl_shift_too_large(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_shift_too_large(
+; CHECK-NEXT: [[V0:%.*]] = shl nsw i32 [[ARG0:%.*]], 2
+; CHECK-NEXT: [[V1:%.*]] = add nsw i32 [[V0]], 16
+; CHECK-NEXT: [[V2:%.*]] = ashr exact i32 [[V1]], 4
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 2
+ %v1 = add nsw i32 %v0, 16
+ %v2 = ashr exact i32 %v1, 4
+ ret i32 %v2
+}
+
+; Vector test
+define <2 x i32> @ashr_exact_add_shl_fold_vector(<2 x i32> %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_vector(
+; CHECK-NEXT: [[TMP1:%.*]] = shl <2 x i32> [[ARG0:%.*]], splat (i32 3)
+; CHECK-NEXT: [[V2:%.*]] = add <2 x i32> [[TMP1]], splat (i32 8)
+; CHECK-NEXT: ret <2 x i32> [[V2]]
+;
+ %v0 = shl nsw <2 x i32> %arg0, <i32 4, i32 4>
+ %v1 = add nsw <2 x i32> %v0, <i32 16, i32 16>
+ %v2 = ashr exact <2 x i32> %v1, <i32 1, i32 1>
+ ret <2 x i32> %v2
+}
+
+; Test commutative add (constant on left)
+define i32 @ashr_exact_add_shl_fold_commute(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_commute(
+; CHECK-NEXT: [[V0:%.*]] = shl i32 [[ARG0:%.*]], 3
+; CHECK-NEXT: [[V2:%.*]] = add i32 [[V0]], 8
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add nsw i32 16, %v0
+ %v2 = ashr exact i32 %v1, 1
+ ret i32 %v2
+}
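For anyone who wants to exercise the new fold in isolation, here is a
minimal reproducer (a sketch, not one of the committed tests). With the
patch applied, `opt -passes=instcombine -S` should rewrite the body to a
`shl i32 %x, 3` followed by `add i32 ..., 8`, since (16*x + 16) >> 1 ==
8*x + 8 and the `exact`/`nsw` flags guarantee no bits are dropped:

; repro.ll -- sketch only, not part of the patch
define i32 @repro(i32 %x) {
  %s = shl nsw i32 %x, 4
  %a = add nsw i32 %s, 16
  %r = ashr exact i32 %a, 1
  ret i32 %r
}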
From 3048e0c0d0f0601e053593f4a0a60f84f2b0f016 Mon Sep 17 00:00:00 2001
From: Manik Mukherjee <mkmrocks20 at gmail.com>
Date: Wed, 15 Oct 2025 23:01:31 -0400
Subject: [PATCH 2/3] modified framework based on comments
---
.../InstCombine/InstCombineShifts.cpp | 72 +++++++++++++++++--
1 file changed, 66 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index fc2a0018e725c..28570dab83805 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -610,6 +610,32 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
case Instruction::LShr:
return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI);
+ case Instruction::Add: {
+ // We can fold an add through a right shift when it carries the matching
+ // nowrap flag: lshr requires nuw (no unsigned wrap) and ashr requires
+ // nsw (no signed wrap). Left shifts are not supported here.
+ if (IsLeftShift)
+ return false;
+
+ auto *BO = cast<BinaryOperator>(I);
+
+ // Determine which flag is required based on the shift type
+ bool HasRequiredFlag;
+ if (isa<LShrOperator>(CxtI))
+ HasRequiredFlag = BO->hasNoUnsignedWrap();
+ else if (isa<AShrOperator>(CxtI))
+ HasRequiredFlag = BO->hasNoSignedWrap();
+ else
+ return false;
+
+ if (!HasRequiredFlag)
+ return false;
+
+ // Both operands must be shiftable; pass CxtI through to keep the shift type.
+ return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, CxtI) &&
+ canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, CxtI);
+ }
+
case Instruction::Select: {
SelectInst *SI = cast<SelectInst>(I);
Value *TrueVal = SI->getTrueValue();
@@ -731,6 +757,18 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift,
IC.Builder);
+ case Instruction::Add:
+ // Shift both operands, then perform the add.
+ I->setOperand(
+ 0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
+ I->setOperand(
+ 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
+ // We must clear the nuw/nsw flags because the original values that didn't
+ // overflow might overflow after we shift them.
+ cast<BinaryOperator>(I)->setHasNoUnsignedWrap(false);
+ cast<BinaryOperator>(I)->setHasNoSignedWrap(false);
+ return I;
+
case Instruction::Select:
I->setOperand(
1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
@@ -1635,6 +1673,30 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
return BinaryOperator::CreateLShr(NewShl, Shl1_Op1);
}
}
+
+ // Fold ((X << A) + C) >>u B --> (X << (A - B)) + (C >>u B)
+ // when the shift is exact and the add has nuw.
+ const APInt *ShAmtAPInt, *ShlAmt, *AddC;
+ if (match(Op1, m_APInt(ShAmtAPInt)) && I.isExact() &&
+ match(Op0, m_c_NUWAdd(m_NUWShl(m_Value(X), m_APInt(ShlAmt)),
+ m_APInt(AddC))) &&
+ ShlAmt->uge(*ShAmtAPInt)) {
+ unsigned ShAmt = ShAmtAPInt->getZExtValue();
+ // Check if C is divisible by (1 << ShAmt)
+ if (AddC->countTrailingZeros() >= ShAmt ||
+ AddC->lshr(ShAmt).shl(ShAmt) == *AddC) {
+ // X << (A - B)
+ Constant *NewShlAmt = ConstantInt::get(Ty, *ShlAmt - ShAmt);
+ Value *NewShl = Builder.CreateShl(X, NewShlAmt);
+
+ // C >>u B
+ Constant *NewAddC = ConstantInt::get(Ty, AddC->lshr(ShAmt));
+
+ // (X << (A - B)) + (C >>u B)
+ return BinaryOperator::CreateAdd(NewShl, NewAddC);
+ }
+ }
+
return nullptr;
}
@@ -1804,10 +1866,8 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
return NewAdd;
}
- // Fold ((X << A) + C) >> B --> (X << (A - B)) + (C >> B)
- // when the shift is exact and the add is nsw.
- // This transforms patterns like:
- //   ((x << 4) + 16) ashr exact 1 --> (x << 3) + 8
+ // Fold ((X << A) + C) >>s B --> (X << (A - B)) + (C >>s B)
+ // when the shift is exact and the add has nsw.
const APInt *ShlAmt, *AddC;
if (I.isExact() &&
match(Op0, m_c_NSWAdd(m_NSWShl(m_Value(X), m_APInt(ShlAmt)),
@@ -1820,10 +1880,10 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
Constant *NewShlAmt = ConstantInt::get(Ty, *ShlAmt - ShAmt);
Value *NewShl = Builder.CreateShl(X, NewShlAmt);
- // C >> B
+ // C >>s B
Constant *NewAddC = ConstantInt::get(Ty, AddC->ashr(ShAmt));
- // (X << (A - B)) + (C >> B)
+ // (X << (A - B)) + (C >>s B)
return BinaryOperator::CreateAdd(NewShl, NewAddC);
}
}
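The lshr counterpart added in this revision can be exercised the same
way with `opt -passes=instcombine -S`. A sketch of the expected rewrite
(assuming the fold lands as written above; this is not one of the
committed tests):

; before
define i32 @repro_lshr(i32 %x) {
  %s = shl nuw i32 %x, 4
  %a = add nuw i32 %s, 16
  %r = lshr exact i32 %a, 1
  ret i32 %r
}
; after instcombine: ((x << 4) + 16) >>u 1 --> (x << 3) + 8
;   %s = shl i32 %x, 3
;   %a = add i32 %s, 8
;   ret i32 %a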
From b866050491ce9977553ecd4e35bdc6766ba41045 Mon Sep 17 00:00:00 2001
From: Manik Mukherjee <mkmrocks20 at gmail.com>
Date: Wed, 15 Oct 2025 23:06:07 -0400
Subject: [PATCH 3/3] run clang format
---
llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 28570dab83805..aafdd92f80925 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -632,7 +632,8 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
return false;
// Both operands must be shiftable; pass CxtI through to keep the shift type.
- return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, CxtI) &&
+ return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC,
+ CxtI) &&
canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, CxtI);
}
@@ -1678,8 +1679,8 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
// when the shift is exact and the add has nuw.
const APInt *ShAmtAPInt, *ShlAmt, *AddC;
if (match(Op1, m_APInt(ShAmtAPInt)) && I.isExact() &&
- match(Op0, m_c_NUWAdd(m_NUWShl(m_Value(X), m_APInt(ShlAmt)),
- m_APInt(AddC))) &&
+ match(Op0,
+ m_c_NUWAdd(m_NUWShl(m_Value(X), m_APInt(ShlAmt)), m_APInt(AddC))) &&
ShlAmt->uge(*ShAmtAPInt)) {
unsigned ShAmt = ShAmtAPInt->getZExtValue();
// Check if C is divisible by (1 << ShAmt)
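As a closing note on why the nowrap flags are load-bearing: without nsw
on the shl, the fold is unsound even when the ashr is exact. A concrete
i8 sketch (not a committed test), evaluating with %x = 16:

; %v0 = shl i8 16, 4        ; = 0 in i8 -- the shl wraps
; %v1 = add nsw i8 0, 16    ; = 16 (no wrap for this input)
; %v2 = ashr exact i8 16, 4 ; = 1 (low four bits are zero, so exact holds)
;
; The folded form would compute x + 1 = 17, not 1, so the transform must
; bail out when the shl lacks nsw -- the case that
; @ashr_exact_add_shl_no_nsw_shl guards above.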