[llvm] [InstCombine] Simplify `(add/sub (sub/add) (sub/add))` irrespective of use-count (PR #105866)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 26 12:31:47 PDT 2024


https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/105866

>From b417163b8a42cfa2f37d044755e994f175fb1146 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Thu, 22 Aug 2024 14:42:05 -0700
Subject: [PATCH 1/5] [InstCombine] Add tests for reassociating `(add/sub
 (sub/add) (sub/add))`; NFC

---
 .../Transforms/InstCombine/fold-add-sub.ll    | 207 ++++++++++++++++++
 1 file changed, 207 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/fold-add-sub.ll

diff --git a/llvm/test/Transforms/InstCombine/fold-add-sub.ll b/llvm/test/Transforms/InstCombine/fold-add-sub.ll
new file mode 100644
index 00000000000000..107572459eb613
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fold-add-sub.ll
@@ -0,0 +1,207 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+declare void @use.i8(i8)
+define i8 @test_add_nsw(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @test_add_nsw(
+; CHECK-NEXT:    [[LHS:%.*]] = add nsw i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = add nsw i8 [[X]], [[Z:%.*]]
+; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
+; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
+; CHECK-NEXT:    [[R:%.*]] = sub nsw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lhs = add nsw i8 %x, %y
+  %rhs = add nsw i8 %x, %z
+  call void @use.i8(i8 %lhs)
+  call void @use.i8(i8 %rhs)
+  %r = sub nsw i8 %lhs, %rhs
+  ret i8 %r
+}
+
+define i8 @test_add_nsw_no_prop(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @test_add_nsw_no_prop(
+; CHECK-NEXT:    [[LHS:%.*]] = add nsw i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = add nuw i8 [[X]], [[Z:%.*]]
+; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
+; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
+; CHECK-NEXT:    [[R:%.*]] = sub nsw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lhs = add nsw i8 %x, %y
+  %rhs = add nuw i8 %x, %z
+  call void @use.i8(i8 %lhs)
+  call void @use.i8(i8 %rhs)
+  %r = sub nsw i8 %lhs, %rhs
+  ret i8 %r
+}
+
+define i8 @test_add(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @test_add(
+; CHECK-NEXT:    [[LHS:%.*]] = add i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = add i8 [[X]], [[Z:%.*]]
+; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
+; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lhs = add i8 %x, %y
+  %rhs = add i8 %x, %z
+  call void @use.i8(i8 %lhs)
+  call void @use.i8(i8 %rhs)
+  %r = sub i8 %lhs, %rhs
+  ret i8 %r
+}
+
+define i8 @test_add_fail(i8 %w, i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @test_add_fail(
+; CHECK-NEXT:    [[LHS:%.*]] = add i8 [[W:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = add i8 [[X:%.*]], [[Z:%.*]]
+; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
+; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lhs = add i8 %w, %y
+  %rhs = add i8 %x, %z
+  call void @use.i8(i8 %lhs)
+  call void @use.i8(i8 %rhs)
+  %r = sub i8 %lhs, %rhs
+  ret i8 %r
+}
+
+define i8 @test_add_nuw(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @test_add_nuw(
+; CHECK-NEXT:    [[LHS:%.*]] = add i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = or disjoint i8 [[X]], [[Z:%.*]]
+; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
+; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
+; CHECK-NEXT:    [[R:%.*]] = sub nuw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lhs = add i8 %x, %y
+  %rhs = or disjoint i8 %x, %z
+  call void @use.i8(i8 %lhs)
+  call void @use.i8(i8 %rhs)
+  %r = sub nuw i8 %lhs, %rhs
+  ret i8 %r
+}
+
+define i8 @test_add_nuw_no_prop(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @test_add_nuw_no_prop(
+; CHECK-NEXT:    [[LHS:%.*]] = add i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = or disjoint i8 [[X]], [[Z:%.*]]
+; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
+; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lhs = add i8 %x, %y
+  %rhs = or disjoint i8 %x, %z
+  call void @use.i8(i8 %lhs)
+  call void @use.i8(i8 %rhs)
+  %r = sub i8 %lhs, %rhs
+  ret i8 %r
+}
+
+define i8 @test_sub_nuw(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @test_sub_nuw(
+; CHECK-NEXT:    [[LHS:%.*]] = sub nuw i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = sub nuw i8 [[Y]], [[Z:%.*]]
+; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
+; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
+; CHECK-NEXT:    [[R:%.*]] = add i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lhs = sub nuw i8 %x, %y
+  %rhs = sub nuw i8 %y, %z
+  call void @use.i8(i8 %lhs)
+  call void @use.i8(i8 %rhs)
+  %r = add i8 %lhs, %rhs
+  ret i8 %r
+}
+
+define i8 @test_sub_nuw_no_prop(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @test_sub_nuw_no_prop(
+; CHECK-NEXT:    [[LHS:%.*]] = sub nuw i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = sub i8 [[Y]], [[Z:%.*]]
+; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
+; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
+; CHECK-NEXT:    [[R:%.*]] = add nuw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lhs = sub nuw i8 %x, %y
+  %rhs = sub i8 %y, %z
+  call void @use.i8(i8 %lhs)
+  call void @use.i8(i8 %rhs)
+  %r = add nuw i8 %lhs, %rhs
+  ret i8 %r
+}
+
+define i8 @test_sub_nsw(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @test_sub_nsw(
+; CHECK-NEXT:    [[LHS:%.*]] = sub nsw i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = sub nsw i8 [[Y]], [[Z:%.*]]
+; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
+; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
+; CHECK-NEXT:    [[R:%.*]] = or disjoint i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lhs = sub nsw i8 %x, %y
+  %rhs = sub nsw i8 %y, %z
+  call void @use.i8(i8 %lhs)
+  call void @use.i8(i8 %rhs)
+  %r = or disjoint i8 %lhs, %rhs
+  ret i8 %r
+}
+
+define i8 @test_sub_nsw_no_prop(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @test_sub_nsw_no_prop(
+; CHECK-NEXT:    [[LHS:%.*]] = sub i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = sub nsw i8 [[Y]], [[Z:%.*]]
+; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
+; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
+; CHECK-NEXT:    [[R:%.*]] = or disjoint i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lhs = sub i8 %x, %y
+  %rhs = sub nsw i8 %y, %z
+  call void @use.i8(i8 %lhs)
+  call void @use.i8(i8 %rhs)
+  %r = or disjoint i8 %lhs, %rhs
+  ret i8 %r
+}
+
+define i8 @test_sub_none(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @test_sub_none(
+; CHECK-NEXT:    [[LHS:%.*]] = sub i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = sub i8 [[Y]], [[Z:%.*]]
+; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
+; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
+; CHECK-NEXT:    [[R:%.*]] = add i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lhs = sub i8 %x, %y
+  %rhs = sub i8 %y, %z
+  call void @use.i8(i8 %lhs)
+  call void @use.i8(i8 %rhs)
+  %r = add i8 %lhs, %rhs
+  ret i8 %r
+}
+
+define i8 @test_sub_none_fail(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @test_sub_none_fail(
+; CHECK-NEXT:    [[LHS:%.*]] = sub i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = sub i8 [[Z:%.*]], [[Y]]
+; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
+; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
+; CHECK-NEXT:    [[R:%.*]] = add i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lhs = sub i8 %x, %y
+  %rhs = sub i8 %z, %y
+  call void @use.i8(i8 %lhs)
+  call void @use.i8(i8 %rhs)
+  %r = add i8 %lhs, %rhs
+  ret i8 %r
+}

>From 73f071283f98d9bf70a43ff56f46c111da25d8ef Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Thu, 22 Aug 2024 14:42:08 -0700
Subject: [PATCH 2/5] [InstCombine] Simplify `(add/sub (sub/add) (sub/add))`
 irrespective of use-count

Added folds:
    - `(add (sub X, Y), (sub Z, X))` -> `(sub Z, Y)`
    - `(sub (add X, Y), (add X, Z))` -> `(sub Y, Z)`

The fold typically is handled in the `Reassociate` pass, but it fails
if the inner `sub`/`add` are multi-use. Less importantly, Reassociate
doesn't propagate flags correctly.

This patch adds the fold explicitly to InstCombine.

Proofs: https://alive2.llvm.org/ce/z/p6JyRP
---
 .../InstCombine/InstCombineAddSub.cpp         | 30 +++++++++++++++++++
 .../InstCombine/InstCombineAndOrXor.cpp       |  3 ++
 .../InstCombine/InstCombineInternal.h         |  3 ++
 .../InstCombine/InstructionCombining.cpp      | 26 ++++++++++++++++
 .../Transforms/InstCombine/fold-add-sub.ll    | 20 ++++++-------
 5 files changed, 72 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index d7758b5fbf1786..5272bc11865f53 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1520,6 +1520,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Instruction *R = combineAddSubWithShlAddSub(Builder, I))
     return R;
 
+  if (Instruction *R = foldAddLike(I))
+    return R;
+
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   Type *Ty = I.getType();
   if (Ty->isIntOrIntVectorTy(1))
@@ -2286,6 +2289,33 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
     }
   }
 
+  {
+    Value *W, *Z;
+    if (match(Op0, m_AddLike(m_Value(W), m_Value(X))) &&
+        match(Op1, m_AddLike(m_Value(Y), m_Value(Z)))) {
+      Instruction *R = nullptr;
+      if (W == Y)
+        R = BinaryOperator::CreateSub(X, Z);
+      if (W == Z)
+        R = BinaryOperator::CreateSub(X, Y);
+      if (X == Y)
+        R = BinaryOperator::CreateSub(W, Z);
+      if (X == Z)
+        R = BinaryOperator::CreateSub(W, Y);
+      if (R) {
+        bool NSW = I.hasNoSignedWrap() &&
+                   match(Op0, m_NSWAddLike(m_Value(), m_Value())) &&
+                   match(Op1, m_NSWAddLike(m_Value(), m_Value()));
+
+        bool NUW = I.hasNoUnsignedWrap() &&
+                   match(Op1, m_NUWAddLike(m_Value(), m_Value()));
+        R->setHasNoSignedWrap(NSW);
+        R->setHasNoUnsignedWrap(NUW);
+        return R;
+      }
+    }
+  }
+
   // (~X) - (~Y) --> Y - X
   {
     // Need to ensure we can consume at least one of the `not` instructions,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 0af06c7e463f80..9425b0b146bedb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3580,6 +3580,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
     return R;
 
+  if (Instruction *R = foldAddLike(I))
+    return R;
+
   Value *X, *Y;
   const APInt *CV;
   if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index b3957b760b4a29..fff787f69c395f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -585,6 +585,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                                FPClassTest DemandedMask, KnownFPClass &Known,
                                unsigned Depth = 0);
 
+  /// Common transforms for add / disjoint or
+  Instruction *foldAddLike(BinaryOperator &I);
+
   /// Canonicalize the position of binops relative to shufflevector.
   Instruction *foldVectorBinop(BinaryOperator &Inst);
   Instruction *foldVectorSelect(SelectInst &Sel);
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 8a96d1d0fb4c90..db49ae536a569f 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2008,6 +2008,32 @@ static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) {
   return true;
 }
 
+Instruction *InstCombinerImpl::foldAddLike(BinaryOperator &I) {
+  Value *LHS = I.getOperand(0);
+  Value *RHS = I.getOperand(1);
+  Value *A, *B, *C, *D;
+  if (match(LHS, m_Sub(m_Value(A), m_Value(B))) &&
+      match(RHS, m_Sub(m_Value(C), m_Value(D)))) {
+    Instruction *R = nullptr;
+    if (A == D)
+      R = BinaryOperator::CreateSub(C, B);
+    if (C == B)
+      R = BinaryOperator::CreateSub(A, D);
+    if (R) {
+      bool NSW = match(&I, m_NSWAddLike(m_Value(), m_Value())) &&
+                 match(LHS, m_NSWSub(m_Value(), m_Value())) &&
+                 match(RHS, m_NSWSub(m_Value(), m_Value()));
+
+      bool NUW = match(LHS, m_NUWSub(m_Value(), m_Value())) &&
+                 match(RHS, m_NUWSub(m_Value(), m_Value()));
+      R->setHasNoSignedWrap(NSW);
+      R->setHasNoUnsignedWrap(NUW);
+      return R;
+    }
+  }
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
   if (!isa<VectorType>(Inst.getType()))
     return nullptr;
diff --git a/llvm/test/Transforms/InstCombine/fold-add-sub.ll b/llvm/test/Transforms/InstCombine/fold-add-sub.ll
index 107572459eb613..bbb7bb67369e7f 100644
--- a/llvm/test/Transforms/InstCombine/fold-add-sub.ll
+++ b/llvm/test/Transforms/InstCombine/fold-add-sub.ll
@@ -8,7 +8,7 @@ define i8 @test_add_nsw(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = add nsw i8 [[X]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = sub nsw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub nsw i8 [[Y]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = add nsw i8 %x, %y
@@ -25,7 +25,7 @@ define i8 @test_add_nsw_no_prop(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = add nuw i8 [[X]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = sub nsw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[Y]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = add nsw i8 %x, %y
@@ -42,7 +42,7 @@ define i8 @test_add(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = add i8 [[X]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = sub i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[Y]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = add i8 %x, %y
@@ -76,7 +76,7 @@ define i8 @test_add_nuw(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = or disjoint i8 [[X]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = sub nuw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub nuw i8 [[Y]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = add i8 %x, %y
@@ -93,7 +93,7 @@ define i8 @test_add_nuw_no_prop(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = or disjoint i8 [[X]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = sub i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[Y]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = add i8 %x, %y
@@ -110,7 +110,7 @@ define i8 @test_sub_nuw(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = sub nuw i8 [[Y]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = add i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub nuw i8 [[X]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = sub nuw i8 %x, %y
@@ -127,7 +127,7 @@ define i8 @test_sub_nuw_no_prop(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = sub i8 [[Y]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = add nuw i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[X]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = sub nuw i8 %x, %y
@@ -144,7 +144,7 @@ define i8 @test_sub_nsw(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = sub nsw i8 [[Y]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = or disjoint i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub nsw i8 [[X]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = sub nsw i8 %x, %y
@@ -161,7 +161,7 @@ define i8 @test_sub_nsw_no_prop(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = sub nsw i8 [[Y]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = or disjoint i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[X]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = sub i8 %x, %y
@@ -178,7 +178,7 @@ define i8 @test_sub_none(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[RHS:%.*]] = sub i8 [[Y]], [[Z:%.*]]
 ; CHECK-NEXT:    call void @use.i8(i8 [[LHS]])
 ; CHECK-NEXT:    call void @use.i8(i8 [[RHS]])
-; CHECK-NEXT:    [[R:%.*]] = add i8 [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[X]], [[Z]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %lhs = sub i8 %x, %y

>From d16b8d4593a8a8a414a3529e7eb3c62060dc2a0d Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 23 Aug 2024 13:26:44 -0700
Subject: [PATCH 3/5] Correctly guard foldAddLike for or

---
 llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 9425b0b146bedb..9c76f818861829 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3580,8 +3580,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
     return R;
 
-  if (Instruction *R = foldAddLike(I))
-    return R;
+  if (cast<PossiblyDisjointInst>(I).isDisjoint())
+    if (Instruction *R = foldAddLike(I))
+      return R;
 
   Value *X, *Y;
   const APInt *CV;

>From f058c173fabb8a21ccd7c56797b0e6d77dd75010 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Sat, 24 Aug 2024 09:38:23 -0700
Subject: [PATCH 4/5] Fixups

---
 .../InstCombine/InstCombineAddSub.cpp         | 12 +++++-----
 .../InstCombine/InstCombineAndOrXor.cpp       |  3 ++-
 .../InstCombine/InstCombineInternal.h         |  2 +-
 .../InstCombine/InstructionCombining.cpp      | 22 +++++++++----------
 4 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 5272bc11865f53..f47f63eb39b77d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1520,10 +1520,10 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Instruction *R = combineAddSubWithShlAddSub(Builder, I))
     return R;
 
-  if (Instruction *R = foldAddLike(I))
-    return R;
-
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  if (Instruction *R =
+          foldAddLike(LHS, RHS, I.hasNoSignedWrap(), I.hasNoUnsignedWrap()))
+    return R;
   Type *Ty = I.getType();
   if (Ty->isIntOrIntVectorTy(1))
     return BinaryOperator::CreateXor(LHS, RHS);
@@ -2296,11 +2296,11 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
       Instruction *R = nullptr;
       if (W == Y)
         R = BinaryOperator::CreateSub(X, Z);
-      if (W == Z)
+      else if (W == Z)
         R = BinaryOperator::CreateSub(X, Y);
-      if (X == Y)
+      else if (X == Y)
         R = BinaryOperator::CreateSub(W, Z);
-      if (X == Z)
+      else if (X == Z)
         R = BinaryOperator::CreateSub(W, Y);
       if (R) {
         bool NSW = I.hasNoSignedWrap() &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 9c76f818861829..08f16f8b1ee269 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3581,7 +3581,8 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
     return R;
 
   if (cast<PossiblyDisjointInst>(I).isDisjoint())
-    if (Instruction *R = foldAddLike(I))
+    if (Instruction *R = foldAddLike(I.getOperand(0), I.getOperand(1),
+                                     /*NSW=*/true, /*NUW=*/true))
       return R;
 
   Value *X, *Y;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index fff787f69c395f..0e52ff2014d154 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -586,7 +586,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                                unsigned Depth = 0);
 
   /// Common transforms for add / disjoint or
-  Instruction *foldAddLike(BinaryOperator &I);
+  Instruction *foldAddLike(Value *LHS, Value *RHS, bool NSW, bool NUW);
 
   /// Canonicalize the position of binops relative to shufflevector.
   Instruction *foldVectorBinop(BinaryOperator &Inst);
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index db49ae536a569f..2861ebcb822143 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2008,26 +2008,24 @@ static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) {
   return true;
 }
 
-Instruction *InstCombinerImpl::foldAddLike(BinaryOperator &I) {
-  Value *LHS = I.getOperand(0);
-  Value *RHS = I.getOperand(1);
+Instruction *InstCombinerImpl::foldAddLike(Value *LHS, Value *RHS, bool NSW,
+                                           bool NUW) {
   Value *A, *B, *C, *D;
   if (match(LHS, m_Sub(m_Value(A), m_Value(B))) &&
       match(RHS, m_Sub(m_Value(C), m_Value(D)))) {
     Instruction *R = nullptr;
     if (A == D)
       R = BinaryOperator::CreateSub(C, B);
-    if (C == B)
+    else if (C == B)
       R = BinaryOperator::CreateSub(A, D);
     if (R) {
-      bool NSW = match(&I, m_NSWAddLike(m_Value(), m_Value())) &&
-                 match(LHS, m_NSWSub(m_Value(), m_Value())) &&
-                 match(RHS, m_NSWSub(m_Value(), m_Value()));
-
-      bool NUW = match(LHS, m_NUWSub(m_Value(), m_Value())) &&
-                 match(RHS, m_NUWSub(m_Value(), m_Value()));
-      R->setHasNoSignedWrap(NSW);
-      R->setHasNoUnsignedWrap(NUW);
+      bool NSWOut = NSW && match(LHS, m_NSWSub(m_Value(), m_Value())) &&
+                    match(RHS, m_NSWSub(m_Value(), m_Value()));
+
+      bool NUWOut = match(LHS, m_NUWSub(m_Value(), m_Value())) &&
+                    match(RHS, m_NUWSub(m_Value(), m_Value()));
+      R->setHasNoSignedWrap(NSWOut);
+      R->setHasNoUnsignedWrap(NUWOut);
       return R;
     }
   }

>From 6d628214170d264f9a03add36bd3d9d9172e5dd4 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Mon, 26 Aug 2024 12:31:18 -0700
Subject: [PATCH 5/5] Make API commutative

---
 .../InstCombine/InstCombineAddSub.cpp         |  7 ++--
 .../InstCombine/InstCombineAndOrXor.cpp       | 12 +++++--
 .../InstCombine/InstCombineInternal.h         |  3 +-
 .../InstCombine/InstructionCombining.cpp      | 32 ++++++++-----------
 4 files changed, 29 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index f47f63eb39b77d..352a77bb11181a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1521,8 +1521,11 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     return R;
 
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
-  if (Instruction *R =
-          foldAddLike(LHS, RHS, I.hasNoSignedWrap(), I.hasNoUnsignedWrap()))
+  if (Instruction *R = foldAddLikeCommutative(LHS, RHS, I.hasNoSignedWrap(),
+                                              I.hasNoUnsignedWrap()))
+    return R;
+  if (Instruction *R = foldAddLikeCommutative(RHS, LHS, I.hasNoSignedWrap(),
+                                              I.hasNoUnsignedWrap()))
     return R;
   Type *Ty = I.getType();
   if (Ty->isIntOrIntVectorTy(1))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 08f16f8b1ee269..4f557532f9f783 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3580,10 +3580,16 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
     return R;
 
-  if (cast<PossiblyDisjointInst>(I).isDisjoint())
-    if (Instruction *R = foldAddLike(I.getOperand(0), I.getOperand(1),
-                                     /*NSW=*/true, /*NUW=*/true))
+  if (cast<PossiblyDisjointInst>(I).isDisjoint()) {
+    if (Instruction *R =
+            foldAddLikeCommutative(I.getOperand(0), I.getOperand(1),
+                                   /*NSW=*/true, /*NUW=*/true))
+      return R;
+    if (Instruction *R =
+            foldAddLikeCommutative(I.getOperand(1), I.getOperand(0),
+                                   /*NSW=*/true, /*NUW=*/true))
       return R;
+  }
 
   Value *X, *Y;
   const APInt *CV;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 0e52ff2014d154..57f27e6a3b7fa5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -586,7 +586,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                                unsigned Depth = 0);
 
   /// Common transforms for add / disjoint or
-  Instruction *foldAddLike(Value *LHS, Value *RHS, bool NSW, bool NUW);
+  Instruction *foldAddLikeCommutative(Value *LHS, Value *RHS, bool NSW,
+                                      bool NUW);
 
   /// Canonicalize the position of binops relative to shufflevector.
   Instruction *foldVectorBinop(BinaryOperator &Inst);
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 2861ebcb822143..3c94d65ba75ecb 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2008,26 +2008,20 @@ static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) {
   return true;
 }
 
-Instruction *InstCombinerImpl::foldAddLike(Value *LHS, Value *RHS, bool NSW,
-                                           bool NUW) {
-  Value *A, *B, *C, *D;
+Instruction *InstCombinerImpl::foldAddLikeCommutative(Value *LHS, Value *RHS,
+                                                      bool NSW, bool NUW) {
+  Value *A, *B, *C;
   if (match(LHS, m_Sub(m_Value(A), m_Value(B))) &&
-      match(RHS, m_Sub(m_Value(C), m_Value(D)))) {
-    Instruction *R = nullptr;
-    if (A == D)
-      R = BinaryOperator::CreateSub(C, B);
-    else if (C == B)
-      R = BinaryOperator::CreateSub(A, D);
-    if (R) {
-      bool NSWOut = NSW && match(LHS, m_NSWSub(m_Value(), m_Value())) &&
-                    match(RHS, m_NSWSub(m_Value(), m_Value()));
-
-      bool NUWOut = match(LHS, m_NUWSub(m_Value(), m_Value())) &&
-                    match(RHS, m_NUWSub(m_Value(), m_Value()));
-      R->setHasNoSignedWrap(NSWOut);
-      R->setHasNoUnsignedWrap(NUWOut);
-      return R;
-    }
+      match(RHS, m_Sub(m_Value(C), m_Specific(A)))) {
+    Instruction *R = BinaryOperator::CreateSub(C, B);
+    bool NSWOut = NSW && match(LHS, m_NSWSub(m_Value(), m_Value())) &&
+                  match(RHS, m_NSWSub(m_Value(), m_Value()));
+
+    bool NUWOut = match(LHS, m_NUWSub(m_Value(), m_Value())) &&
+                  match(RHS, m_NUWSub(m_Value(), m_Value()));
+    R->setHasNoSignedWrap(NSWOut);
+    R->setHasNoUnsignedWrap(NUWOut);
+    return R;
   }
   return nullptr;
 }



More information about the llvm-commits mailing list