[llvm] f38365a - [InstCombine] Add support for maximum(a,b) + minimum(a,b) => a + b

Serguei Katkov via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 6 22:38:24 PDT 2023


Author: Serguei Katkov
Date: 2023-04-07T12:38:04+07:00
New Revision: f38365aef436d0f2ae042ad3038c8a6159dafe78

URL: https://github.com/llvm/llvm-project/commit/f38365aef436d0f2ae042ad3038c8a6159dafe78
DIFF: https://github.com/llvm/llvm-project/commit/f38365aef436d0f2ae042ad3038c8a6159dafe78.diff

LOG: [InstCombine] Add support for maximum(a,b) + minimum(a,b) => a + b

Unfortunately, alive2 cannot prove correctness: it fails with a timeout, even for
the half floating-point type.

However, it should be correct. If neither a nor b is NaN, maximum and minimum just
return a and b in some order, and since a + b == b + a the result is the same.
If a or b is NaN, then both maximum and minimum return NaN, and NaN + NaN is NaN;
a + b is also NaN.

In terms of preserving fast-math flags, we cannot preserve ninf, because
minimum(NaN, Infinity) == maximum(NaN, Infinity) == NaN, so
minimum(NaN, Infinity) +ninf maximum(NaN, Infinity) == NaN +ninf NaN == NaN.
However, the transformation rewrites
minimum(NaN, Infinity) +ninf maximum(NaN, Infinity) into NaN +ninf Infinity, which is poison.

However, if the fadd is also marked nnan, ninf can be preserved, because NaN +ninf/nnan NaN is already poison in the original program.

The same optimization is added for
  maximum(a,b) * minimum(a,b) => a * b.
Everything said above for fadd also applies to fmul.

Reviewed By: mkazantsev
Differential Revision: https://reviews.llvm.org/D147299

Added: 
    

Modified: 
    llvm/include/llvm/IR/PatternMatch.h
    llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
    llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/test/Transforms/InstCombine/fadd-maximum-minimum.ll
    llvm/test/Transforms/InstCombine/fmul-maximum-minimum.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 32443b02fa436..e141a96049162 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -2364,6 +2364,14 @@ m_c_MaxOrMin(const LHS &L, const RHS &R) {
                      m_CombineOr(m_c_UMax(L, R), m_c_UMin(L, R)));
 }
 
+template <Intrinsic::ID IntrID, typename T0, typename T1>
+inline match_combine_or<typename m_Intrinsic_Ty<T0, T1>::Ty,
+                        typename m_Intrinsic_Ty<T1, T0>::Ty>
+m_c_Intrinsic(const T0 &Op0, const T1 &Op1) {
+  return m_CombineOr(m_Intrinsic<IntrID>(Op0, Op1),
+                     m_Intrinsic<IntrID>(Op1, Op0));
+}
+
 /// Matches FAdd with LHS and RHS in either order.
 template <typename LHS, typename RHS>
 inline BinaryOp_match<LHS, RHS, Instruction::FAdd, true>

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 49573fd85c2f8..8f6934c6af083 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1811,6 +1811,20 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
       return replaceInstUsesWith(I, V);
   }
 
+  // minimum(X, Y) + maximum(X, Y) => X + Y.
+  if (match(&I,
+            m_c_FAdd(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)),
+                     m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X),
+                                                       m_Deferred(Y))))) {
+    BinaryOperator *Result = BinaryOperator::CreateFAddFMF(X, Y, &I);
+    // We cannot preserve ninf if nnan flag is not set.
+    // If X is NaN and Y is Inf then in original program we had NaN + NaN,
+    // while in optimized version NaN + Inf and this is a poison with ninf flag.
+    if (!Result->hasNoNaNs())
+      Result->setHasNoInfs(false);
+    return Result;
+  }
+
   return nullptr;
 }
 

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 2f61d52a95546..5d4dd2fdac66d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -772,6 +772,20 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
       I.hasNoSignedZeros() && match(Start, m_Zero()))
     return replaceInstUsesWith(I, Start);
 
+  // minimum(X, Y) * maximum(X, Y) => X * Y.
+  if (match(&I,
+            m_c_FMul(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)),
+                     m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X),
+                                                       m_Deferred(Y))))) {
+    BinaryOperator *Result = BinaryOperator::CreateFMulFMF(X, Y, &I);
+    // We cannot preserve ninf if nnan flag is not set.
+    // If X is NaN and Y is Inf then in original program we had NaN * NaN,
+    // while in optimized version NaN * Inf and this is a poison with ninf flag.
+    if (!Result->hasNoNaNs())
+      Result->setHasNoInfs(false);
+    return Result;
+  }
+
   return nullptr;
 }
 

diff  --git a/llvm/test/Transforms/InstCombine/fadd-maximum-minimum.ll b/llvm/test/Transforms/InstCombine/fadd-maximum-minimum.ll
index 98bbea11dae2b..d690b81cbdb06 100644
--- a/llvm/test/Transforms/InstCombine/fadd-maximum-minimum.ll
+++ b/llvm/test/Transforms/InstCombine/fadd-maximum-minimum.ll
@@ -9,9 +9,7 @@ declare <4 x float>    @llvm.maximum.v4f32(<4 x float> %Val0, <4 x float> %Val1)
 define float @test(float %a, float %b) {
 ; CHECK-LABEL: @test(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = fadd float [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = fadd float [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret float [[RES]]
 ;
 entry:
@@ -24,9 +22,7 @@ entry:
 define float @test_comm1(float %a, float %b) {
 ; CHECK-LABEL: @test_comm1(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = fadd float [[MAX]], [[MIN]]
+; CHECK-NEXT:    [[RES:%.*]] = fadd float [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret float [[RES]]
 ;
 entry:
@@ -39,9 +35,7 @@ entry:
 define float @test_comm2(float %a, float %b) {
 ; CHECK-LABEL: @test_comm2(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maximum.f32(float [[B]], float [[A]])
-; CHECK-NEXT:    [[RES:%.*]] = fadd float [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = fadd float [[B:%.*]], [[A:%.*]]
 ; CHECK-NEXT:    ret float [[RES]]
 ;
 entry:
@@ -54,9 +48,7 @@ entry:
 define float @test_comm3(float %a, float %b) {
 ; CHECK-LABEL: @test_comm3(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maximum.f32(float [[B]], float [[A]])
-; CHECK-NEXT:    [[RES:%.*]] = fadd float [[MAX]], [[MIN]]
+; CHECK-NEXT:    [[RES:%.*]] = fadd float [[B:%.*]], [[A:%.*]]
 ; CHECK-NEXT:    ret float [[RES]]
 ;
 entry:
@@ -69,9 +61,7 @@ entry:
 define <4 x float> @test_vect(<4 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: @test_vect(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MIN:%.*]] = call <4 x float> @llvm.minimum.v4f32(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]])
-; CHECK-NEXT:    [[MAX:%.*]] = call <4 x float> @llvm.maximum.v4f32(<4 x float> [[B]], <4 x float> [[A]])
-; CHECK-NEXT:    [[RES:%.*]] = fadd <4 x float> [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = fadd <4 x float> [[B:%.*]], [[A:%.*]]
 ; CHECK-NEXT:    ret <4 x float> [[RES]]
 ;
 entry:
@@ -84,9 +74,7 @@ entry:
 define float @test_flags(float %a, float %b) {
 ; CHECK-LABEL: @test_flags(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = fadd fast float [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = fadd fast float [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret float [[RES]]
 ;
 entry:
@@ -99,9 +87,7 @@ entry:
 define float @test_flags2(float %a, float %b) {
 ; CHECK-LABEL: @test_flags2(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = fadd reassoc ninf nsz arcp contract afn float [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = fadd reassoc nsz arcp contract afn float [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret float [[RES]]
 ;
 entry:

diff  --git a/llvm/test/Transforms/InstCombine/fmul-maximum-minimum.ll b/llvm/test/Transforms/InstCombine/fmul-maximum-minimum.ll
index 42f451d205a6b..c6f8c8a01b5e8 100644
--- a/llvm/test/Transforms/InstCombine/fmul-maximum-minimum.ll
+++ b/llvm/test/Transforms/InstCombine/fmul-maximum-minimum.ll
@@ -9,9 +9,7 @@ declare <4 x float>    @llvm.maximum.v4f32(<4 x float> %Val0, <4 x float> %Val1)
 define float @test(float %a, float %b) {
 ; CHECK-LABEL: @test(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = fmul float [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = fmul float [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret float [[RES]]
 ;
 entry:
@@ -24,9 +22,7 @@ entry:
 define float @test_comm1(float %a, float %b) {
 ; CHECK-LABEL: @test_comm1(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = fmul float [[MAX]], [[MIN]]
+; CHECK-NEXT:    [[RES:%.*]] = fmul float [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret float [[RES]]
 ;
 entry:
@@ -39,9 +35,7 @@ entry:
 define float @test_comm2(float %a, float %b) {
 ; CHECK-LABEL: @test_comm2(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maximum.f32(float [[B]], float [[A]])
-; CHECK-NEXT:    [[RES:%.*]] = fmul float [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = fmul float [[B:%.*]], [[A:%.*]]
 ; CHECK-NEXT:    ret float [[RES]]
 ;
 entry:
@@ -55,9 +49,7 @@ entry:
 define float @test_comm3(float %a, float %b) {
 ; CHECK-LABEL: @test_comm3(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maximum.f32(float [[B]], float [[A]])
-; CHECK-NEXT:    [[RES:%.*]] = fmul float [[MAX]], [[MIN]]
+; CHECK-NEXT:    [[RES:%.*]] = fmul float [[B:%.*]], [[A:%.*]]
 ; CHECK-NEXT:    ret float [[RES]]
 ;
 entry:
@@ -70,9 +62,7 @@ entry:
 define <4 x float> @test_vect(<4 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: @test_vect(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MIN:%.*]] = call <4 x float> @llvm.minimum.v4f32(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]])
-; CHECK-NEXT:    [[MAX:%.*]] = call <4 x float> @llvm.maximum.v4f32(<4 x float> [[B]], <4 x float> [[A]])
-; CHECK-NEXT:    [[RES:%.*]] = fmul <4 x float> [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = fmul <4 x float> [[B:%.*]], [[A:%.*]]
 ; CHECK-NEXT:    ret <4 x float> [[RES]]
 ;
 entry:
@@ -85,9 +75,7 @@ entry:
 define float @test_flags(float %a, float %b) {
 ; CHECK-LABEL: @test_flags(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = fmul fast float [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = fmul fast float [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret float [[RES]]
 ;
 entry:
@@ -100,9 +88,7 @@ entry:
 define float @test_flags2(float %a, float %b) {
 ; CHECK-LABEL: @test_flags2(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minimum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maximum.f32(float [[A]], float [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = fmul reassoc ninf nsz arcp contract afn float [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = fmul reassoc nsz arcp contract afn float [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret float [[RES]]
 ;
 entry:


        


More information about the llvm-commits mailing list