[llvm] f38365a - [InstCombine] Add support for maximum(a,b) + minimum(a,b) => a + b
Serguei Katkov via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 6 22:38:24 PDT 2023
Author: Serguei Katkov
Date: 2023-04-07T12:38:04+07:00
New Revision: f38365aef436d0f2ae042ad3038c8a6159dafe78
URL: https://github.com/llvm/llvm-project/commit/f38365aef436d0f2ae042ad3038c8a6159dafe78
DIFF: https://github.com/llvm/llvm-project/commit/f38365aef436d0f2ae042ad3038c8a6159dafe78.diff
LOG: [InstCombine] Add support for maximum(a,b) + minimum(a,b) => a + b
Unfortunately, alive2 cannot prove the correctness because it times out, even for
the half floating-point type.
However, it should be correct. If neither a nor b is NaN, maximum and minimum simply
return the two operands (a and b, in some order), and since a + b == b + a the result is the same.
If a or b is NaN, then maximum and minimum are both equal to NaN, and NaN + NaN is NaN;
a + b is also NaN.
In terms of preserving fast flags, we cannot preserve ninf due to
minimum(NaN, Infinity) == maximum(NaN, Infinity) == NaN,
minimum(NaN, Infinity) +ninf maximum(NaN, Infinity) == NaN +ninf NaN = NaN
However, the transformation will change
minimum(NaN, Infinity) + maximum(NaN, Infinity) to NaN +ninf Infinity == poison.
But if the fadd is also marked as nnan, we can preserve ninf, because NaN +ninf/nnan NaN = poison as well.
The same optimization for
maximum(a,b) * minimum(a,b) => a * b
is added.
All said above for fadd is correct for fmul.
Reviewed By: mkazantsev
Differential Revision: https://reviews.llvm.org/D147299
Added:
Modified:
llvm/include/llvm/IR/PatternMatch.h
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
llvm/test/Transforms/InstCombine/fadd-maximum-minimum.ll
llvm/test/Transforms/InstCombine/fmul-maximum-minimum.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 32443b02fa436..e141a96049162 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -2364,6 +2364,14 @@ m_c_MaxOrMin(const LHS &L, const RHS &R) {
m_CombineOr(m_c_UMax(L, R), m_c_UMin(L, R)));
}
+template <Intrinsic::ID IntrID, typename T0, typename T1>
+inline match_combine_or<typename m_Intrinsic_Ty<T0, T1>::Ty,
+ typename m_Intrinsic_Ty<T1, T0>::Ty>
+m_c_Intrinsic(const T0 &Op0, const T1 &Op1) {
+ return m_CombineOr(m_Intrinsic<IntrID>(Op0, Op1),
+ m_Intrinsic<IntrID>(Op1, Op0));
+}
+
/// Matches FAdd with LHS and RHS in either order.
template <typename LHS, typename RHS>
inline BinaryOp_match<LHS, RHS, Instruction::FAdd, true>
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 49573fd85c2f8..8f6934c6af083 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1811,6 +1811,20 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
return replaceInstUsesWith(I, V);
}
+ // minimum(X, Y) + maximum(X, Y) => X + Y.
+ if (match(&I,
+ m_c_FAdd(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)),
+ m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X),
+ m_Deferred(Y))))) {
+ BinaryOperator *Result = BinaryOperator::CreateFAddFMF(X, Y, &I);
+ // We cannot preserve ninf if nnan flag is not set.
+ // If X is NaN and Y is Inf then in original program we had NaN + NaN,
+ // while in optimized version NaN + Inf and this is a poison with ninf flag.
+ if (!Result->hasNoNaNs())
+ Result->setHasNoInfs(false);
+ return Result;
+ }
+
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 2f61d52a95546..5d4dd2fdac66d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -772,6 +772,20 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
I.hasNoSignedZeros() && match(Start, m_Zero()))
return replaceInstUsesWith(I, Start);
+ // minimum(X, Y) * maximum(X, Y) => X * Y.
+ if (match(&I,
+ m_c_FMul(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)),
+ m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X),
+ m_Deferred(Y))))) {
+ BinaryOperator *Result = BinaryOperator::CreateFMulFMF(X, Y, &I);
+ // We cannot preserve ninf if nnan flag is not set.
+ // If X is NaN and Y is Inf then in original program we had NaN * NaN,
+ // while in optimized version NaN * Inf and this is a poison with ninf flag.
+ if (!Result->hasNoNaNs())
+ Result->setHasNoInfs(false);
+ return Result;
+ }
+
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/fadd-maximum-minimum.ll b/llvm/test/Transforms/InstCombine/fadd-maximum-minimum.ll
index 98bbea11dae2b..d690b81cbdb06 100644
--- a/llvm/test/Transforms/InstCombine/fadd-maximum-minimum.ll
+++ b/llvm/test/Transforms/InstCombine/fadd-maximum-minimum.ll
@@ -9,9 +9,7 @@ declare <4 x float> @llvm.maximum.v4f32(<4 x float> %Val0, <4 x float> %Val1)
define float @test(float %a, float %b) {
; CHECK-LABEL: @test(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT: [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT: [[RES:%.*]] = fadd float [[MIN]], [[MAX]]
+; CHECK-NEXT: [[RES:%.*]] = fadd float [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret float [[RES]]
;
entry:
@@ -24,9 +22,7 @@ entry:
define float @test_comm1(float %a, float %b) {
; CHECK-LABEL: @test_comm1(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT: [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT: [[RES:%.*]] = fadd float [[MAX]], [[MIN]]
+; CHECK-NEXT: [[RES:%.*]] = fadd float [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret float [[RES]]
;
entry:
@@ -39,9 +35,7 @@ entry:
define float @test_comm2(float %a, float %b) {
; CHECK-LABEL: @test_comm2(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT: [[MAX:%.*]] = call float @llvm.maximum.f32(float [[B]], float [[A]])
-; CHECK-NEXT: [[RES:%.*]] = fadd float [[MIN]], [[MAX]]
+; CHECK-NEXT: [[RES:%.*]] = fadd float [[B:%.*]], [[A:%.*]]
; CHECK-NEXT: ret float [[RES]]
;
entry:
@@ -54,9 +48,7 @@ entry:
define float @test_comm3(float %a, float %b) {
; CHECK-LABEL: @test_comm3(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT: [[MAX:%.*]] = call float @llvm.maximum.f32(float [[B]], float [[A]])
-; CHECK-NEXT: [[RES:%.*]] = fadd float [[MAX]], [[MIN]]
+; CHECK-NEXT: [[RES:%.*]] = fadd float [[B:%.*]], [[A:%.*]]
; CHECK-NEXT: ret float [[RES]]
;
entry:
@@ -69,9 +61,7 @@ entry:
define <4 x float> @test_vect(<4 x float> %a, <4 x float> %b) {
; CHECK-LABEL: @test_vect(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[MIN:%.*]] = call <4 x float> @llvm.minimum.v4f32(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]])
-; CHECK-NEXT: [[MAX:%.*]] = call <4 x float> @llvm.maximum.v4f32(<4 x float> [[B]], <4 x float> [[A]])
-; CHECK-NEXT: [[RES:%.*]] = fadd <4 x float> [[MIN]], [[MAX]]
+; CHECK-NEXT: [[RES:%.*]] = fadd <4 x float> [[B:%.*]], [[A:%.*]]
; CHECK-NEXT: ret <4 x float> [[RES]]
;
entry:
@@ -84,9 +74,7 @@ entry:
define float @test_flags(float %a, float %b) {
; CHECK-LABEL: @test_flags(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT: [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT: [[RES:%.*]] = fadd fast float [[MIN]], [[MAX]]
+; CHECK-NEXT: [[RES:%.*]] = fadd fast float [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret float [[RES]]
;
entry:
@@ -99,9 +87,7 @@ entry:
define float @test_flags2(float %a, float %b) {
; CHECK-LABEL: @test_flags2(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT: [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT: [[RES:%.*]] = fadd reassoc ninf nsz arcp contract afn float [[MIN]], [[MAX]]
+; CHECK-NEXT: [[RES:%.*]] = fadd reassoc nsz arcp contract afn float [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret float [[RES]]
;
entry:
diff --git a/llvm/test/Transforms/InstCombine/fmul-maximum-minimum.ll b/llvm/test/Transforms/InstCombine/fmul-maximum-minimum.ll
index 42f451d205a6b..c6f8c8a01b5e8 100644
--- a/llvm/test/Transforms/InstCombine/fmul-maximum-minimum.ll
+++ b/llvm/test/Transforms/InstCombine/fmul-maximum-minimum.ll
@@ -9,9 +9,7 @@ declare <4 x float> @llvm.maximum.v4f32(<4 x float> %Val0, <4 x float> %Val1)
define float @test(float %a, float %b) {
; CHECK-LABEL: @test(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT: [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT: [[RES:%.*]] = fmul float [[MIN]], [[MAX]]
+; CHECK-NEXT: [[RES:%.*]] = fmul float [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret float [[RES]]
;
entry:
@@ -24,9 +22,7 @@ entry:
define float @test_comm1(float %a, float %b) {
; CHECK-LABEL: @test_comm1(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT: [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT: [[RES:%.*]] = fmul float [[MAX]], [[MIN]]
+; CHECK-NEXT: [[RES:%.*]] = fmul float [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret float [[RES]]
;
entry:
@@ -39,9 +35,7 @@ entry:
define float @test_comm2(float %a, float %b) {
; CHECK-LABEL: @test_comm2(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT: [[MAX:%.*]] = call float @llvm.maximum.f32(float [[B]], float [[A]])
-; CHECK-NEXT: [[RES:%.*]] = fmul float [[MIN]], [[MAX]]
+; CHECK-NEXT: [[RES:%.*]] = fmul float [[B:%.*]], [[A:%.*]]
; CHECK-NEXT: ret float [[RES]]
;
entry:
@@ -55,9 +49,7 @@ entry:
define float @test_comm3(float %a, float %b) {
; CHECK-LABEL: @test_comm3(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT: [[MAX:%.*]] = call float @llvm.maximum.f32(float [[B]], float [[A]])
-; CHECK-NEXT: [[RES:%.*]] = fmul float [[MAX]], [[MIN]]
+; CHECK-NEXT: [[RES:%.*]] = fmul float [[B:%.*]], [[A:%.*]]
; CHECK-NEXT: ret float [[RES]]
;
entry:
@@ -70,9 +62,7 @@ entry:
define <4 x float> @test_vect(<4 x float> %a, <4 x float> %b) {
; CHECK-LABEL: @test_vect(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[MIN:%.*]] = call <4 x float> @llvm.minimum.v4f32(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]])
-; CHECK-NEXT: [[MAX:%.*]] = call <4 x float> @llvm.maximum.v4f32(<4 x float> [[B]], <4 x float> [[A]])
-; CHECK-NEXT: [[RES:%.*]] = fmul <4 x float> [[MIN]], [[MAX]]
+; CHECK-NEXT: [[RES:%.*]] = fmul <4 x float> [[B:%.*]], [[A:%.*]]
; CHECK-NEXT: ret <4 x float> [[RES]]
;
entry:
@@ -85,9 +75,7 @@ entry:
define float @test_flags(float %a, float %b) {
; CHECK-LABEL: @test_flags(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT: [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT: [[RES:%.*]] = fmul fast float [[MIN]], [[MAX]]
+; CHECK-NEXT: [[RES:%.*]] = fmul fast float [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret float [[RES]]
;
entry:
@@ -100,9 +88,7 @@ entry:
define float @test_flags2(float %a, float %b) {
; CHECK-LABEL: @test_flags2(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT: [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT: [[RES:%.*]] = fmul reassoc ninf nsz arcp contract afn float [[MIN]], [[MAX]]
+; CHECK-NEXT: [[RES:%.*]] = fmul reassoc nsz arcp contract afn float [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret float [[RES]]
;
entry:
More information about the llvm-commits
mailing list