[llvm] a6edcea - [InstCombine] Simplify `(add/sub (sub/add) (sub/add))` irrespective of use-count

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 27 11:43:32 PDT 2024


Author: Noah Goldstein
Date: 2024-08-27T11:43:17-07:00
New Revision: a6edcea211a3d415212adb69b544f853351a7627

URL: https://github.com/llvm/llvm-project/commit/a6edcea211a3d415212adb69b544f853351a7627
DIFF: https://github.com/llvm/llvm-project/commit/a6edcea211a3d415212adb69b544f853351a7627.diff

LOG: [InstCombine] Simplify `(add/sub (sub/add) (sub/add))` irrespective of use-count

Added folds:
    - `(add (sub X, Y), (sub Z, X))` -> `(sub Z, Y)`
    - `(sub (add X, Y), (add X, Z))` -> `(sub Y, Z)`

The fold is typically handled in the `Reassociate` pass, but that fails
if the inner `sub`/`add` are multi-use. Less importantly, `Reassociate`
doesn't propagate flags correctly.

This patch adds the fold explicitly to InstCombine.

Proofs: https://alive2.llvm.org/ce/z/p6JyRP

Closes #105866

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
    llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/test/Transforms/InstCombine/fold-add-sub.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index d7758b5fbf1786..e5c3a20e1a6487 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1304,6 +1304,24 @@ static Instruction *foldAddToAshr(BinaryOperator &Add) {
       X, ConstantInt::get(Add.getType(), DivC->exactLogBase2()));
 }
 
+Instruction *InstCombinerImpl::foldAddLikeCommutative(Value *LHS, Value *RHS,
+                                                      bool NSW, bool NUW) {
+  Value *A, *B, *C;
+  if (match(LHS, m_Sub(m_Value(A), m_Value(B))) &&
+      match(RHS, m_Sub(m_Value(C), m_Specific(A)))) {
+    Instruction *R = BinaryOperator::CreateSub(C, B);
+    bool NSWOut = NSW && match(LHS, m_NSWSub(m_Value(), m_Value())) &&
+                  match(RHS, m_NSWSub(m_Value(), m_Value()));
+
+    bool NUWOut = match(LHS, m_NUWSub(m_Value(), m_Value())) &&
+                  match(RHS, m_NUWSub(m_Value(), m_Value()));
+    R->setHasNoSignedWrap(NSWOut);
+    R->setHasNoUnsignedWrap(NUWOut);
+    return R;
+  }
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::
     canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(
         BinaryOperator &I) {
@@ -1521,6 +1539,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     return R;
 
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  if (Instruction *R = foldAddLikeCommutative(LHS, RHS, I.hasNoSignedWrap(),
+                                              I.hasNoUnsignedWrap()))
+    return R;
+  if (Instruction *R = foldAddLikeCommutative(RHS, LHS, I.hasNoSignedWrap(),
+                                              I.hasNoUnsignedWrap()))
+    return R;
   Type *Ty = I.getType();
   if (Ty->isIntOrIntVectorTy(1))
     return BinaryOperator::CreateXor(LHS, RHS);
@@ -2286,6 +2310,33 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
     }
   }
 
+  {
+    Value *W, *Z;
+    if (match(Op0, m_AddLike(m_Value(W), m_Value(X))) &&
+        match(Op1, m_AddLike(m_Value(Y), m_Value(Z)))) {
+      Instruction *R = nullptr;
+      if (W == Y)
+        R = BinaryOperator::CreateSub(X, Z);
+      else if (W == Z)
+        R = BinaryOperator::CreateSub(X, Y);
+      else if (X == Y)
+        R = BinaryOperator::CreateSub(W, Z);
+      else if (X == Z)
+        R = BinaryOperator::CreateSub(W, Y);
+      if (R) {
+        bool NSW = I.hasNoSignedWrap() &&
+                   match(Op0, m_NSWAddLike(m_Value(), m_Value())) &&
+                   match(Op1, m_NSWAddLike(m_Value(), m_Value()));
+
+        bool NUW = I.hasNoUnsignedWrap() &&
+                   match(Op1, m_NUWAddLike(m_Value(), m_Value()));
+        R->setHasNoSignedWrap(NSW);
+        R->setHasNoUnsignedWrap(NUW);
+        return R;
+      }
+    }
+  }
+
   // (~X) - (~Y) --> Y - X
   {
     // Need to ensure we can consume at least one of the `not` instructions,

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 0af06c7e463f80..4f557532f9f783 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3580,6 +3580,17 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
     return R;
 
+  if (cast<PossiblyDisjointInst>(I).isDisjoint()) {
+    if (Instruction *R =
+            foldAddLikeCommutative(I.getOperand(0), I.getOperand(1),
+                                   /*NSW=*/true, /*NUW=*/true))
+      return R;
+    if (Instruction *R =
+            foldAddLikeCommutative(I.getOperand(1), I.getOperand(0),
+                                   /*NSW=*/true, /*NUW=*/true))
+      return R;
+  }
+
   Value *X, *Y;
   const APInt *CV;
   if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) &&

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index b3957b760b4a29..57f27e6a3b7fa5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -585,6 +585,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                                FPClassTest DemandedMask, KnownFPClass &Known,
                                unsigned Depth = 0);
 
+  /// Common transforms for add / disjoint or
+  Instruction *foldAddLikeCommutative(Value *LHS, Value *RHS, bool NSW,
+                                      bool NUW);
+
   /// Canonicalize the position of binops relative to shufflevector.
   Instruction *foldVectorBinop(BinaryOperator &Inst);
   Instruction *foldVectorSelect(SelectInst &Sel);

diff  --git a/llvm/test/Transforms/InstCombine/fold-add-sub.ll b/llvm/test/Transforms/InstCombine/fold-add-sub.ll
index 107572459eb613..bbb7bb67369e7f 100644
--- a/llvm/test/Transforms/InstCombine/fold-add-sub.ll
+++ b/llvm/test/Transforms/InstCombine/fold-add-sub.ll
@@ -8,7 +8,7 @@ define i8 @test_add_nsw(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = add nsw i8 [[X]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = sub nsw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub nsw i8 [[Y]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = add nsw i8 %x, %y
@@ -25,7 +25,7 @@ define i8 @test_add_nsw_no_prop(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = add nuw i8 [[X]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = sub nsw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[Y]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = add nsw i8 %x, %y
@@ -42,7 +42,7 @@ define i8 @test_add(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = add i8 [[X]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = sub i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[Y]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = add i8 %x, %y
@@ -76,7 +76,7 @@ define i8 @test_add_nuw(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = or disjoint i8 [[X]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = sub nuw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub nuw i8 [[Y]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = add i8 %x, %y
@@ -93,7 +93,7 @@ define i8 @test_add_nuw_no_prop(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = or disjoint i8 [[X]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = sub i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[Y]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = add i8 %x, %y
@@ -110,7 +110,7 @@ define i8 @test_sub_nuw(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = sub nuw i8 [[Y]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = add i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub nuw i8 [[X]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = sub nuw i8 %x, %y
@@ -127,7 +127,7 @@ define i8 @test_sub_nuw_no_prop(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = sub i8 [[Y]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = add nuw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[X]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = sub nuw i8 %x, %y
@@ -144,7 +144,7 @@ define i8 @test_sub_nsw(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = sub nsw i8 [[Y]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = or disjoint i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub nsw i8 [[X]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = sub nsw i8 %x, %y
@@ -161,7 +161,7 @@ define i8 @test_sub_nsw_no_prop(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = sub nsw i8 [[Y]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = or disjoint i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[X]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = sub i8 %x, %y
@@ -178,7 +178,7 @@ define i8 @test_sub_none(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = sub i8 [[Y]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = add i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[X]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = sub i8 %x, %y


        


More information about the llvm-commits mailing list