[llvm] 0d335f7 - [InstCombine] Handle more commuted cases in matchesSquareSum()

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed May 8 20:36:35 PDT 2024


Author: Nikita Popov
Date: 2024-05-09T12:35:16+09:00
New Revision: 0d335f78e45341db53d9f956adcebbb2d2616c9a

URL: https://github.com/llvm/llvm-project/commit/0d335f78e45341db53d9f956adcebbb2d2616c9a
DIFF: https://github.com/llvm/llvm-project/commit/0d335f78e45341db53d9f956adcebbb2d2616c9a.diff

LOG: [InstCombine] Handle more commuted cases in matchesSquareSum()

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
    llvm/test/Transforms/InstCombine/add.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 51ac77348ed9e..bff09f5676680 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1014,7 +1014,7 @@ static bool matchesSquareSum(BinaryOperator &I, Mul2Rhs M2Rhs, Value *&A,
   // (a * a) + (((a * 2) + b) * b)
   if (match(&I, m_c_BinOp(
                     AddOp, m_OneUse(m_BinOp(MulOp, m_Value(A), m_Deferred(A))),
-                    m_OneUse(m_BinOp(
+                    m_OneUse(m_c_BinOp(
                         MulOp,
                         m_c_BinOp(AddOp, m_BinOp(Mul2Op, m_Deferred(A), M2Rhs),
                                   m_Value(B)),
@@ -1025,16 +1025,16 @@ static bool matchesSquareSum(BinaryOperator &I, Mul2Rhs M2Rhs, Value *&A,
   // +
   // (a * a + b * b) or (b * b + a * a)
   return match(
-      &I,
-      m_c_BinOp(AddOp,
-                m_CombineOr(
-                    m_OneUse(m_BinOp(
-                        Mul2Op, m_BinOp(MulOp, m_Value(A), m_Value(B)), M2Rhs)),
-                    m_OneUse(m_BinOp(MulOp, m_BinOp(Mul2Op, m_Value(A), M2Rhs),
+      &I, m_c_BinOp(
+              AddOp,
+              m_CombineOr(
+                  m_OneUse(m_BinOp(
+                      Mul2Op, m_BinOp(MulOp, m_Value(A), m_Value(B)), M2Rhs)),
+                  m_OneUse(m_c_BinOp(MulOp, m_BinOp(Mul2Op, m_Value(A), M2Rhs),
                                      m_Value(B)))),
-                m_OneUse(m_c_BinOp(
-                    AddOp, m_BinOp(MulOp, m_Deferred(A), m_Deferred(A)),
-                    m_BinOp(MulOp, m_Deferred(B), m_Deferred(B))))));
+              m_OneUse(
+                  m_c_BinOp(AddOp, m_BinOp(MulOp, m_Deferred(A), m_Deferred(A)),
+                            m_BinOp(MulOp, m_Deferred(B), m_Deferred(B))))));
 }
 
 // Fold integer variations of a^2 + 2*a*b + b^2 -> (a + b)^2

diff  --git a/llvm/test/Transforms/InstCombine/add.ll b/llvm/test/Transforms/InstCombine/add.ll
index 42e901ea2d5a3..25087fef68a11 100644
--- a/llvm/test/Transforms/InstCombine/add.ll
+++ b/llvm/test/Transforms/InstCombine/add.ll
@@ -3287,11 +3287,8 @@ define i32 @add_reduce_sqr_sum_flipped(i32 %a, i32 %b) {
 define i32 @add_reduce_sqr_sum_flipped2(i32 %a, i32 %bx) {
 ; CHECK-LABEL: @add_reduce_sqr_sum_flipped2(
 ; CHECK-NEXT:    [[B:%.*]] = xor i32 [[BX:%.*]], 42
-; CHECK-NEXT:    [[A_SQ:%.*]] = mul nsw i32 [[A:%.*]], [[A]]
-; CHECK-NEXT:    [[TWO_A:%.*]] = shl i32 [[A]], 1
-; CHECK-NEXT:    [[TWO_A_PLUS_B:%.*]] = add i32 [[TWO_A]], [[B]]
-; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[B]], [[TWO_A_PLUS_B]]
-; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[MUL]], [[A_SQ]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[B]], [[A:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = mul i32 [[TMP1]], [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[ADD]]
 ;
   %b = xor i32 %bx, 42 ; thwart complexity-based canonicalization
@@ -3350,11 +3347,8 @@ define i32 @add_reduce_sqr_sum_order2_flipped(i32 %a, i32 %b) {
 define i32 @add_reduce_sqr_sum_order2_flipped2(i32 %a, i32 %bx) {
 ; CHECK-LABEL: @add_reduce_sqr_sum_order2_flipped2(
 ; CHECK-NEXT:    [[B:%.*]] = xor i32 [[BX:%.*]], 42
-; CHECK-NEXT:    [[A_SQ:%.*]] = mul nsw i32 [[A:%.*]], [[A]]
-; CHECK-NEXT:    [[TWOA:%.*]] = shl i32 [[A]], 1
-; CHECK-NEXT:    [[TWOAB1:%.*]] = add i32 [[B]], [[TWOA]]
-; CHECK-NEXT:    [[TWOAB_B2:%.*]] = mul i32 [[B]], [[TWOAB1]]
-; CHECK-NEXT:    [[AB2:%.*]] = add i32 [[A_SQ]], [[TWOAB_B2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[B]], [[A:%.*]]
+; CHECK-NEXT:    [[AB2:%.*]] = mul i32 [[TMP1]], [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[AB2]]
 ;
   %b = xor i32 %bx, 42 ; thwart complexity-based canonicalization
@@ -3370,11 +3364,8 @@ define i32 @add_reduce_sqr_sum_order2_flipped2(i32 %a, i32 %bx) {
 define i32 @add_reduce_sqr_sum_order2_flipped3(i32 %a, i32 %bx) {
 ; CHECK-LABEL: @add_reduce_sqr_sum_order2_flipped3(
 ; CHECK-NEXT:    [[B:%.*]] = xor i32 [[BX:%.*]], 42
-; CHECK-NEXT:    [[A_SQ:%.*]] = mul nsw i32 [[A:%.*]], [[A]]
-; CHECK-NEXT:    [[TWOA:%.*]] = shl i32 [[A]], 1
-; CHECK-NEXT:    [[B_SQ1:%.*]] = add i32 [[TWOA]], [[B]]
-; CHECK-NEXT:    [[TWOAB_B2:%.*]] = mul i32 [[B]], [[B_SQ1]]
-; CHECK-NEXT:    [[AB2:%.*]] = add i32 [[A_SQ]], [[TWOAB_B2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[B]], [[A:%.*]]
+; CHECK-NEXT:    [[AB2:%.*]] = mul i32 [[TMP1]], [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[AB2]]
 ;
   %b = xor i32 %bx, 42 ; thwart complexity-based canonicalization
@@ -3570,12 +3561,8 @@ define i32 @add_reduce_sqr_sum_order5_flipped2(i32 %a, i32 %b) {
 define i32 @add_reduce_sqr_sum_order5_flipped3(i32 %ax, i32 %b) {
 ; CHECK-LABEL: @add_reduce_sqr_sum_order5_flipped3(
 ; CHECK-NEXT:    [[A:%.*]] = xor i32 [[AX:%.*]], 42
-; CHECK-NEXT:    [[A_SQ:%.*]] = mul nsw i32 [[A]], [[A]]
-; CHECK-NEXT:    [[TWOB:%.*]] = shl i32 [[B:%.*]], 1
-; CHECK-NEXT:    [[TWOAB:%.*]] = mul i32 [[A]], [[TWOB]]
-; CHECK-NEXT:    [[B_SQ:%.*]] = mul i32 [[B]], [[B]]
-; CHECK-NEXT:    [[A2_B2:%.*]] = add i32 [[A_SQ]], [[B_SQ]]
-; CHECK-NEXT:    [[AB2:%.*]] = add i32 [[TWOAB]], [[A2_B2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[A]], [[B:%.*]]
+; CHECK-NEXT:    [[AB2:%.*]] = mul i32 [[TMP1]], [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[AB2]]
 ;
   %a = xor i32 %ax, 42 ; thwart complexity-based canonicalization


        


More information about the llvm-commits mailing list