[llvm] 0d335f7 - [InstCombine] Handle more commuted cases in matchesSquareSum()
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Wed May 8 20:36:35 PDT 2024
Author: Nikita Popov
Date: 2024-05-09T12:35:16+09:00
New Revision: 0d335f78e45341db53d9f956adcebbb2d2616c9a
URL: https://github.com/llvm/llvm-project/commit/0d335f78e45341db53d9f956adcebbb2d2616c9a
DIFF: https://github.com/llvm/llvm-project/commit/0d335f78e45341db53d9f956adcebbb2d2616c9a.diff
LOG: [InstCombine] Handle more commuted cases in matchesSquareSum()
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
llvm/test/Transforms/InstCombine/add.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 51ac77348ed9e..bff09f5676680 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1014,7 +1014,7 @@ static bool matchesSquareSum(BinaryOperator &I, Mul2Rhs M2Rhs, Value *&A,
// (a * a) + (((a * 2) + b) * b)
if (match(&I, m_c_BinOp(
AddOp, m_OneUse(m_BinOp(MulOp, m_Value(A), m_Deferred(A))),
- m_OneUse(m_BinOp(
+ m_OneUse(m_c_BinOp(
MulOp,
m_c_BinOp(AddOp, m_BinOp(Mul2Op, m_Deferred(A), M2Rhs),
m_Value(B)),
@@ -1025,16 +1025,16 @@ static bool matchesSquareSum(BinaryOperator &I, Mul2Rhs M2Rhs, Value *&A,
// +
// (a * a + b * b) or (b * b + a * a)
return match(
- &I,
- m_c_BinOp(AddOp,
- m_CombineOr(
- m_OneUse(m_BinOp(
- Mul2Op, m_BinOp(MulOp, m_Value(A), m_Value(B)), M2Rhs)),
- m_OneUse(m_BinOp(MulOp, m_BinOp(Mul2Op, m_Value(A), M2Rhs),
+ &I, m_c_BinOp(
+ AddOp,
+ m_CombineOr(
+ m_OneUse(m_BinOp(
+ Mul2Op, m_BinOp(MulOp, m_Value(A), m_Value(B)), M2Rhs)),
+ m_OneUse(m_c_BinOp(MulOp, m_BinOp(Mul2Op, m_Value(A), M2Rhs),
m_Value(B)))),
- m_OneUse(m_c_BinOp(
- AddOp, m_BinOp(MulOp, m_Deferred(A), m_Deferred(A)),
- m_BinOp(MulOp, m_Deferred(B), m_Deferred(B))))));
+ m_OneUse(
+ m_c_BinOp(AddOp, m_BinOp(MulOp, m_Deferred(A), m_Deferred(A)),
+ m_BinOp(MulOp, m_Deferred(B), m_Deferred(B))))));
}
// Fold integer variations of a^2 + 2*a*b + b^2 -> (a + b)^2
diff --git a/llvm/test/Transforms/InstCombine/add.ll b/llvm/test/Transforms/InstCombine/add.ll
index 42e901ea2d5a3..25087fef68a11 100644
--- a/llvm/test/Transforms/InstCombine/add.ll
+++ b/llvm/test/Transforms/InstCombine/add.ll
@@ -3287,11 +3287,8 @@ define i32 @add_reduce_sqr_sum_flipped(i32 %a, i32 %b) {
define i32 @add_reduce_sqr_sum_flipped2(i32 %a, i32 %bx) {
; CHECK-LABEL: @add_reduce_sqr_sum_flipped2(
; CHECK-NEXT: [[B:%.*]] = xor i32 [[BX:%.*]], 42
-; CHECK-NEXT: [[A_SQ:%.*]] = mul nsw i32 [[A:%.*]], [[A]]
-; CHECK-NEXT: [[TWO_A:%.*]] = shl i32 [[A]], 1
-; CHECK-NEXT: [[TWO_A_PLUS_B:%.*]] = add i32 [[TWO_A]], [[B]]
-; CHECK-NEXT: [[MUL:%.*]] = mul i32 [[B]], [[TWO_A_PLUS_B]]
-; CHECK-NEXT: [[ADD:%.*]] = add i32 [[MUL]], [[A_SQ]]
+; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[B]], [[A:%.*]]
+; CHECK-NEXT: [[ADD:%.*]] = mul i32 [[TMP1]], [[TMP1]]
; CHECK-NEXT: ret i32 [[ADD]]
;
%b = xor i32 %bx, 42 ; thwart complexity-based canonicalization
@@ -3350,11 +3347,8 @@ define i32 @add_reduce_sqr_sum_order2_flipped(i32 %a, i32 %b) {
define i32 @add_reduce_sqr_sum_order2_flipped2(i32 %a, i32 %bx) {
; CHECK-LABEL: @add_reduce_sqr_sum_order2_flipped2(
; CHECK-NEXT: [[B:%.*]] = xor i32 [[BX:%.*]], 42
-; CHECK-NEXT: [[A_SQ:%.*]] = mul nsw i32 [[A:%.*]], [[A]]
-; CHECK-NEXT: [[TWOA:%.*]] = shl i32 [[A]], 1
-; CHECK-NEXT: [[TWOAB1:%.*]] = add i32 [[B]], [[TWOA]]
-; CHECK-NEXT: [[TWOAB_B2:%.*]] = mul i32 [[B]], [[TWOAB1]]
-; CHECK-NEXT: [[AB2:%.*]] = add i32 [[A_SQ]], [[TWOAB_B2]]
+; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[B]], [[A:%.*]]
+; CHECK-NEXT: [[AB2:%.*]] = mul i32 [[TMP1]], [[TMP1]]
; CHECK-NEXT: ret i32 [[AB2]]
;
%b = xor i32 %bx, 42 ; thwart complexity-based canonicalization
@@ -3370,11 +3364,8 @@ define i32 @add_reduce_sqr_sum_order2_flipped2(i32 %a, i32 %bx) {
define i32 @add_reduce_sqr_sum_order2_flipped3(i32 %a, i32 %bx) {
; CHECK-LABEL: @add_reduce_sqr_sum_order2_flipped3(
; CHECK-NEXT: [[B:%.*]] = xor i32 [[BX:%.*]], 42
-; CHECK-NEXT: [[A_SQ:%.*]] = mul nsw i32 [[A:%.*]], [[A]]
-; CHECK-NEXT: [[TWOA:%.*]] = shl i32 [[A]], 1
-; CHECK-NEXT: [[B_SQ1:%.*]] = add i32 [[TWOA]], [[B]]
-; CHECK-NEXT: [[TWOAB_B2:%.*]] = mul i32 [[B]], [[B_SQ1]]
-; CHECK-NEXT: [[AB2:%.*]] = add i32 [[A_SQ]], [[TWOAB_B2]]
+; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[B]], [[A:%.*]]
+; CHECK-NEXT: [[AB2:%.*]] = mul i32 [[TMP1]], [[TMP1]]
; CHECK-NEXT: ret i32 [[AB2]]
;
%b = xor i32 %bx, 42 ; thwart complexity-based canonicalization
@@ -3570,12 +3561,8 @@ define i32 @add_reduce_sqr_sum_order5_flipped2(i32 %a, i32 %b) {
define i32 @add_reduce_sqr_sum_order5_flipped3(i32 %ax, i32 %b) {
; CHECK-LABEL: @add_reduce_sqr_sum_order5_flipped3(
; CHECK-NEXT: [[A:%.*]] = xor i32 [[AX:%.*]], 42
-; CHECK-NEXT: [[A_SQ:%.*]] = mul nsw i32 [[A]], [[A]]
-; CHECK-NEXT: [[TWOB:%.*]] = shl i32 [[B:%.*]], 1
-; CHECK-NEXT: [[TWOAB:%.*]] = mul i32 [[A]], [[TWOB]]
-; CHECK-NEXT: [[B_SQ:%.*]] = mul i32 [[B]], [[B]]
-; CHECK-NEXT: [[A2_B2:%.*]] = add i32 [[A_SQ]], [[B_SQ]]
-; CHECK-NEXT: [[AB2:%.*]] = add i32 [[TWOAB]], [[A2_B2]]
+; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[A]], [[B:%.*]]
+; CHECK-NEXT: [[AB2:%.*]] = mul i32 [[TMP1]], [[TMP1]]
; CHECK-NEXT: ret i32 [[AB2]]
;
%a = xor i32 %ax, 42 ; thwart complexity-based canonicalization
More information about the llvm-commits mailing list