[llvm] a6edcea - [InstCombine] Simplify `(add/sub (sub/add) (sub/add))` irrelivant of use-count
Noah Goldstein via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 27 11:43:32 PDT 2024
Author: Noah Goldstein
Date: 2024-08-27T11:43:17-07:00
New Revision: a6edcea211a3d415212adb69b544f853351a7627
URL: https://github.com/llvm/llvm-project/commit/a6edcea211a3d415212adb69b544f853351a7627
DIFF: https://github.com/llvm/llvm-project/commit/a6edcea211a3d415212adb69b544f853351a7627.diff
LOG: [InstCombine] Simplify `(add/sub (sub/add) (sub/add))` irrelivant of use-count
Added folds:
- `(add (sub X, Y), (sub Z, X))` -> `(sub Z, Y)`
- `(sub (add X, Y), (add X, Z))` -> `(sub Y, Z)`
The fold is typically handled in the `Reassociate` pass, but it fails
if the inner `sub`/`add` are multi-use. Less importantly, Reassociate
doesn't propagate flags correctly.
This patch adds the fold explicitly to InstCombine.
Proofs: https://alive2.llvm.org/ce/z/p6JyRP
Closes #105866
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
llvm/lib/Transforms/InstCombine/InstCombineInternal.h
llvm/test/Transforms/InstCombine/fold-add-sub.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index d7758b5fbf1786..e5c3a20e1a6487 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1304,6 +1304,24 @@ static Instruction *foldAddToAshr(BinaryOperator &Add) {
X, ConstantInt::get(Add.getType(), DivC->exactLogBase2()));
}
+Instruction *InstCombinerImpl::foldAddLikeCommutative(Value *LHS, Value *RHS,
+ bool NSW, bool NUW) { // Fold (A - B) + (C - A) --> C - B; returns nullptr when the operands do not match.
+ Value *A, *B, *C;
+ if (match(LHS, m_Sub(m_Value(A), m_Value(B))) &&
+ match(RHS, m_Sub(m_Value(C), m_Specific(A)))) { // RHS must subtract the same A that LHS subtracts from.
+ Instruction *R = BinaryOperator::CreateSub(C, B); // A cancels, leaving C - B.
+ bool NSWOut = NSW && match(LHS, m_NSWSub(m_Value(), m_Value())) &&
+ match(RHS, m_NSWSub(m_Value(), m_Value())); // nsw is kept only if the outer op (NSW) and both inner subs carry it.
+ // nuw depends on the inner subs alone: A >= B and C >= A (unsigned) give C >= B; the NUW argument is not consulted here.
+ bool NUWOut = match(LHS, m_NUWSub(m_Value(), m_Value())) &&
+ match(RHS, m_NUWSub(m_Value(), m_Value()));
+ R->setHasNoSignedWrap(NSWOut);
+ R->setHasNoUnsignedWrap(NUWOut);
+ return R;
+ }
+ return nullptr;
+}
+
Instruction *InstCombinerImpl::
canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(
BinaryOperator &I) {
@@ -1521,6 +1539,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return R;
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+ if (Instruction *R = foldAddLikeCommutative(LHS, RHS, I.hasNoSignedWrap(),
+ I.hasNoUnsignedWrap()))
+ return R;
+ if (Instruction *R = foldAddLikeCommutative(RHS, LHS, I.hasNoSignedWrap(),
+ I.hasNoUnsignedWrap()))
+ return R;
Type *Ty = I.getType();
if (Ty->isIntOrIntVectorTy(1))
return BinaryOperator::CreateXor(LHS, RHS);
@@ -2286,6 +2310,33 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
}
}
+ {
+ Value *W, *Z;
+ if (match(Op0, m_AddLike(m_Value(W), m_Value(X))) &&
+ match(Op1, m_AddLike(m_Value(Y), m_Value(Z)))) {
+ Instruction *R = nullptr;
+ if (W == Y)
+ R = BinaryOperator::CreateSub(X, Z);
+ else if (W == Z)
+ R = BinaryOperator::CreateSub(X, Y);
+ else if (X == Y)
+ R = BinaryOperator::CreateSub(W, Z);
+ else if (X == Z)
+ R = BinaryOperator::CreateSub(W, Y);
+ if (R) {
+ bool NSW = I.hasNoSignedWrap() &&
+ match(Op0, m_NSWAddLike(m_Value(), m_Value())) &&
+ match(Op1, m_NSWAddLike(m_Value(), m_Value()));
+
+ bool NUW = I.hasNoUnsignedWrap() &&
+ match(Op1, m_NUWAddLike(m_Value(), m_Value()));
+ R->setHasNoSignedWrap(NSW);
+ R->setHasNoUnsignedWrap(NUW);
+ return R;
+ }
+ }
+ }
+
// (~X) - (~Y) --> Y - X
{
// Need to ensure we can consume at least one of the `not` instructions,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 0af06c7e463f80..4f557532f9f783 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3580,6 +3580,17 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
return R;
+ if (cast<PossiblyDisjointInst>(I).isDisjoint()) {
+ if (Instruction *R =
+ foldAddLikeCommutative(I.getOperand(0), I.getOperand(1),
+ /*NSW=*/true, /*NUW=*/true))
+ return R;
+ if (Instruction *R =
+ foldAddLikeCommutative(I.getOperand(1), I.getOperand(0),
+ /*NSW=*/true, /*NUW=*/true))
+ return R;
+ }
+
Value *X, *Y;
const APInt *CV;
if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index b3957b760b4a29..57f27e6a3b7fa5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -585,6 +585,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
FPClassTest DemandedMask, KnownFPClass &Known,
unsigned Depth = 0);
+ /// Common transforms for add / disjoint or
+ Instruction *foldAddLikeCommutative(Value *LHS, Value *RHS, bool NSW,
+ bool NUW);
+
/// Canonicalize the position of binops relative to shufflevector.
Instruction *foldVectorBinop(BinaryOperator &Inst);
Instruction *foldVectorSelect(SelectInst &Sel);
diff --git a/llvm/test/Transforms/InstCombine/fold-add-sub.ll b/llvm/test/Transforms/InstCombine/fold-add-sub.ll
index 107572459eb613..bbb7bb67369e7f 100644
--- a/llvm/test/Transforms/InstCombine/fold-add-sub.ll
+++ b/llvm/test/Transforms/InstCombine/fold-add-sub.ll
@@ -8,7 +8,7 @@ define i8 @test_add_nsw(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = add nsw i8 [[X]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT: [[R:%.*]] = sub nsw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT: [[R:%.*]] = sub nsw i8 [[Y]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = add nsw i8 %x, %y
@@ -25,7 +25,7 @@ define i8 @test_add_nsw_no_prop(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = add nuw i8 [[X]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT: [[R:%.*]] = sub nsw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT: [[R:%.*]] = sub i8 [[Y]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = add nsw i8 %x, %y
@@ -42,7 +42,7 @@ define i8 @test_add(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = add i8 [[X]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT: [[R:%.*]] = sub i8 [[LHS]], [[RHS]]
+; CHECK-NEXT: [[R:%.*]] = sub i8 [[Y]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = add i8 %x, %y
@@ -76,7 +76,7 @@ define i8 @test_add_nuw(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = or disjoint i8 [[X]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT: [[R:%.*]] = sub nuw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT: [[R:%.*]] = sub nuw i8 [[Y]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = add i8 %x, %y
@@ -93,7 +93,7 @@ define i8 @test_add_nuw_no_prop(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = or disjoint i8 [[X]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT: [[R:%.*]] = sub i8 [[LHS]], [[RHS]]
+; CHECK-NEXT: [[R:%.*]] = sub i8 [[Y]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = add i8 %x, %y
@@ -110,7 +110,7 @@ define i8 @test_sub_nuw(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = sub nuw i8 [[Y]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT: [[R:%.*]] = add i8 [[LHS]], [[RHS]]
+; CHECK-NEXT: [[R:%.*]] = sub nuw i8 [[X]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = sub nuw i8 %x, %y
@@ -127,7 +127,7 @@ define i8 @test_sub_nuw_no_prop(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = sub i8 [[Y]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT: [[R:%.*]] = add nuw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT: [[R:%.*]] = sub i8 [[X]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = sub nuw i8 %x, %y
@@ -144,7 +144,7 @@ define i8 @test_sub_nsw(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = sub nsw i8 [[Y]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT: [[R:%.*]] = or disjoint i8 [[LHS]], [[RHS]]
+; CHECK-NEXT: [[R:%.*]] = sub nsw i8 [[X]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = sub nsw i8 %x, %y
@@ -161,7 +161,7 @@ define i8 @test_sub_nsw_no_prop(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = sub nsw i8 [[Y]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT: [[R:%.*]] = or disjoint i8 [[LHS]], [[RHS]]
+; CHECK-NEXT: [[R:%.*]] = sub i8 [[X]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = sub i8 %x, %y
@@ -178,7 +178,7 @@ define i8 @test_sub_none(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = sub i8 [[Y]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT: [[R:%.*]] = add i8 [[LHS]], [[RHS]]
+; CHECK-NEXT: [[R:%.*]] = sub i8 [[X]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = sub i8 %x, %y
More information about the llvm-commits
mailing list