[llvm] faf8407 - [InstSimplify] Extend handling of fp min/max.

Serguei Katkov via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 26 20:46:10 PDT 2023


Author: Serguei Katkov
Date: 2023-04-27T10:45:27+07:00
New Revision: faf8407aecd15125261787bc9b9b4d448174b5d4

URL: https://github.com/llvm/llvm-project/commit/faf8407aecd15125261787bc9b9b4d448174b5d4
DIFF: https://github.com/llvm/llvm-project/commit/faf8407aecd15125261787bc9b9b4d448174b5d4.diff

LOG: [InstSimplify] Extend handling of fp min/max.

Add support for cases like
  m(m(X,Y),m'(X,Y)) => m(X,Y)
where m is one of maxnum, minnum, maximum, minimum and
m' is either m or the inverse of m.

alive2 correctness check:
maxnum(maxnum,maxnum) https://alive2.llvm.org/ce/z/kSyAzo
maxnum(maxnum,minnum) https://alive2.llvm.org/ce/z/Vra8j2
minnum(minnum,minnum) https://alive2.llvm.org/ce/z/B6h-hW
minnum(minnum,maxnum) https://alive2.llvm.org/ce/z/rG2u_b
maximum(maximum,maximum) https://alive2.llvm.org/ce/z/N2nevY
maximum(maximum,minimum) https://alive2.llvm.org/ce/z/23RFcP
minimum(minimum,minimum) https://alive2.llvm.org/ce/z/spHZ-U
minimum(minimum,maximum) https://alive2.llvm.org/ce/z/Aa-VE8

Reviewed By: dantrushin, RKSimon
Differential Revision: https://reviews.llvm.org/D147137

Added: 
    

Modified: 
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/test/Transforms/InstSimplify/fminmax-folds.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index d4639f6b88818..9a2b030aba5c7 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -6119,6 +6119,51 @@ static Value *foldMinMaxSharedOp(Intrinsic::ID IID, Value *Op0, Value *Op1) {
   return nullptr;
 }
 
+/// Given a min/max intrinsic, see if it can be removed based on having an
+/// operand that is another min/max intrinsic with shared operand(s). The caller
+/// is expected to swap the operand arguments to handle commutation.
+static Value *foldMinimumMaximumSharedOp(Intrinsic::ID IID, Value *Op0,
+                                         Value *Op1) {
+  assert((IID == Intrinsic::maxnum || IID == Intrinsic::minnum ||
+          IID == Intrinsic::maximum || IID == Intrinsic::minimum) &&
+         "Unsupported intrinsic");
+
+  auto *M0 = dyn_cast<IntrinsicInst>(Op0);
+  // If Op0 is not the same intrinsic as IID, do not process.
+  // This differs from the integer min/max handling: we do not fold a case like
+  // max(min(X,Y),min(X,Y)) => min(X,Y); that one can be handled by GVN.
+  if (!M0 || M0->getIntrinsicID() != IID)
+    return nullptr;
+  Value *X0 = M0->getOperand(0);
+  Value *Y0 = M0->getOperand(1);
+  // Simple case, m(m(X,Y), X) => m(X, Y)
+  //              m(m(X,Y), Y) => m(X, Y)
+  // For minimum/maximum, X or Y is NaN => m(X,Y) == NaN and outer m is NaN.
+  // For minnum/maxnum, X is NaN => m(NaN, Y) == Y, m(Y, NaN) == m(Y, Y) == Y.
+  // For minnum/maxnum, Y is NaN => m(X, NaN) == X, m(X, X) == m(X, NaN) == X.
+  if (X0 == Op1 || Y0 == Op1)
+    return M0;
+
+  // m' must be m or its inverse. Invert IID (known to be a min/max) rather
+  // than M1's ID: getInverseMinMaxIntrinsic is unreachable for other IDs.
+  auto *M1 = dyn_cast<IntrinsicInst>(Op1);
+  if (!M1 || (M1->getIntrinsicID() != IID &&
+              M1->getIntrinsicID() != getInverseMinMaxIntrinsic(IID)))
+    return nullptr;
+  Value *X1 = M1->getOperand(0);
+  Value *Y1 = M1->getOperand(1);
+  // We have m(m(X,Y),m'(X,Y)), taking into account that m' is commutative.
+  // Since m' is m or the inverse of m => m(m(X,Y),m'(X,Y)) == m(X,Y).
+  // For minimum/maximum, X is NaN => m(NaN,Y) == m'(NaN, Y) == NaN.
+  // For minimum/maximum, Y is NaN => m(X,NaN) == m'(X, NaN) == NaN.
+  // For minnum/maxnum, X is NaN => m(NaN,Y) == m'(NaN, Y) == Y.
+  // For minnum/maxnum, Y is NaN => m(X,NaN) == m'(X, NaN) == X.
+  if ((X0 == X1 && Y0 == Y1) || (X0 == Y1 && Y0 == X1))
+    return M0;
+
+  return nullptr;
+}
+
 static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
                                       const SimplifyQuery &Q) {
   Intrinsic::ID IID = F->getIntrinsicID();
@@ -6358,14 +6403,10 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
 
     // Min/max of the same operation with common operand:
     // m(m(X, Y)), X --> m(X, Y) (4 commuted variants)
-    if (auto *M0 = dyn_cast<IntrinsicInst>(Op0))
-      if (M0->getIntrinsicID() == IID &&
-          (M0->getOperand(0) == Op1 || M0->getOperand(1) == Op1))
-        return Op0;
-    if (auto *M1 = dyn_cast<IntrinsicInst>(Op1))
-      if (M1->getIntrinsicID() == IID &&
-          (M1->getOperand(0) == Op0 || M1->getOperand(1) == Op0))
-        return Op1;
+    if (Value *V = foldMinimumMaximumSharedOp(IID, Op0, Op1))
+      return V;
+    if (Value *V = foldMinimumMaximumSharedOp(IID, Op1, Op0))
+      return V;
 
     break;
   }

diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index f987353697ce0..7ceca4a9ba936 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -7735,6 +7735,12 @@ Intrinsic::ID llvm::getInverseMinMaxIntrinsic(Intrinsic::ID MinMaxID) {
   case Intrinsic::smin: return Intrinsic::smax;
   case Intrinsic::umax: return Intrinsic::umin;
   case Intrinsic::umin: return Intrinsic::umax;
+  // Note that the next four intrinsics may produce the same result for the
+  // original and the inverse even if X != Y, since NaN is handled specially.
+  case Intrinsic::maximum: return Intrinsic::minimum;
+  case Intrinsic::minimum: return Intrinsic::maximum;
+  case Intrinsic::maxnum: return Intrinsic::minnum;
+  case Intrinsic::minnum: return Intrinsic::maxnum;
   default: llvm_unreachable("Unexpected intrinsic");
   }
 }

diff  --git a/llvm/test/Transforms/InstSimplify/fminmax-folds.ll b/llvm/test/Transforms/InstSimplify/fminmax-folds.ll
index a5785275afff4..a8a9e96a652fa 100644
--- a/llvm/test/Transforms/InstSimplify/fminmax-folds.ll
+++ b/llvm/test/Transforms/InstSimplify/fminmax-folds.ll
@@ -1206,9 +1206,7 @@ define float @maximum_inf_commute(float %x) {
 define float @maximum_maximum_minimum(float %x, float %y) {
 ; CHECK-LABEL: @maximum_maximum_minimum(
 ; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maximum.f32(float [[X:%.*]], float [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minimum.f32(float [[X]], float [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call float @llvm.maximum.f32(float [[MAX]], float [[MIN]])
-; CHECK-NEXT:    ret float [[VAL]]
+; CHECK-NEXT:    ret float [[MAX]]
 ;
   %max = call float @llvm.maximum.f32(float %x, float %y)
   %min = call float @llvm.minimum.f32(float %x, float %y)
@@ -1219,9 +1217,7 @@ define float @maximum_maximum_minimum(float %x, float %y) {
 define double @maximum_minimum_maximum(double %x, double %y) {
 ; CHECK-LABEL: @maximum_minimum_maximum(
 ; CHECK-NEXT:    [[MAX:%.*]] = call double @llvm.maximum.f64(double [[X:%.*]], double [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call double @llvm.minimum.f64(double [[X]], double [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call double @llvm.maximum.f64(double [[MIN]], double [[MAX]])
-; CHECK-NEXT:    ret double [[VAL]]
+; CHECK-NEXT:    ret double [[MAX]]
 ;
   %max = call double @llvm.maximum.f64(double %x, double %y)
   %min = call double @llvm.minimum.f64(double %x, double %y)
@@ -1245,9 +1241,7 @@ define float @maximum_minimum_minimum(float %x, float %y) {
 define half @maximum_maximum_maximum(half %x, half %y) {
 ; CHECK-LABEL: @maximum_maximum_maximum(
 ; CHECK-NEXT:    [[MAX1:%.*]] = call half @llvm.maximum.f16(half [[X:%.*]], half [[Y:%.*]])
-; CHECK-NEXT:    [[MAX2:%.*]] = call half @llvm.maximum.f16(half [[X]], half [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call half @llvm.maximum.f16(half [[MAX1]], half [[MAX2]])
-; CHECK-NEXT:    ret half [[VAL]]
+; CHECK-NEXT:    ret half [[MAX1]]
 ;
   %max1 = call half @llvm.maximum.f16(half %x, half %y)
   %max2 = call half @llvm.maximum.f16(half %x, half %y)
@@ -1257,10 +1251,8 @@ define half @maximum_maximum_maximum(half %x, half %y) {
 
 define <2 x float> @minimum_maximum_minimum(<2 x float> %x, <2 x float> %y) {
 ; CHECK-LABEL: @minimum_maximum_minimum(
-; CHECK-NEXT:    [[MAX:%.*]] = call <2 x float> @llvm.maximum.v2f32(<2 x float> [[X:%.*]], <2 x float> [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call <2 x float> @llvm.minimum.v2f32(<2 x float> [[X]], <2 x float> [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call <2 x float> @llvm.minimum.v2f32(<2 x float> [[MAX]], <2 x float> [[MIN]])
-; CHECK-NEXT:    ret <2 x float> [[VAL]]
+; CHECK-NEXT:    [[MIN:%.*]] = call <2 x float> @llvm.minimum.v2f32(<2 x float> [[X:%.*]], <2 x float> [[Y:%.*]])
+; CHECK-NEXT:    ret <2 x float> [[MIN]]
 ;
   %max = call <2 x float> @llvm.maximum.v2f32(<2 x float> %x, <2 x float> %y)
   %min = call <2 x float> @llvm.minimum.v2f32(<2 x float> %x, <2 x float> %y)
@@ -1270,10 +1262,8 @@ define <2 x float> @minimum_maximum_minimum(<2 x float> %x, <2 x float> %y) {
 
 define <2 x double> @minimum_minimum_maximum(<2 x double> %x, <2 x double> %y) {
 ; CHECK-LABEL: @minimum_minimum_maximum(
-; CHECK-NEXT:    [[MAX:%.*]] = call <2 x double> @llvm.maximum.v2f64(<2 x double> [[X:%.*]], <2 x double> [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call <2 x double> @llvm.minimum.v2f64(<2 x double> [[X]], <2 x double> [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call <2 x double> @llvm.minimum.v2f64(<2 x double> [[MIN]], <2 x double> [[MAX]])
-; CHECK-NEXT:    ret <2 x double> [[VAL]]
+; CHECK-NEXT:    [[MIN:%.*]] = call <2 x double> @llvm.minimum.v2f64(<2 x double> [[X:%.*]], <2 x double> [[Y:%.*]])
+; CHECK-NEXT:    ret <2 x double> [[MIN]]
 ;
   %max = call <2 x double> @llvm.maximum.v2f64(<2 x double> %x, <2 x double> %y)
   %min = call <2 x double> @llvm.minimum.v2f64(<2 x double> %x, <2 x double> %y)
@@ -1297,9 +1287,7 @@ define float @minimum_maximum_maximum(float %x, float %y) {
 define float @minimum_minimum_minimum(float %x, float %y) {
 ; CHECK-LABEL: @minimum_minimum_minimum(
 ; CHECK-NEXT:    [[MIN1:%.*]] = call float @llvm.minimum.f32(float [[X:%.*]], float [[Y:%.*]])
-; CHECK-NEXT:    [[MIN2:%.*]] = call float @llvm.minimum.f32(float [[X]], float [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call float @llvm.minimum.f32(float [[MIN1]], float [[MIN2]])
-; CHECK-NEXT:    ret float [[VAL]]
+; CHECK-NEXT:    ret float [[MIN1]]
 ;
   %min1 = call float @llvm.minimum.f32(float %x, float %y)
   %min2 = call float @llvm.minimum.f32(float %x, float %y)
@@ -1310,9 +1298,7 @@ define float @minimum_minimum_minimum(float %x, float %y) {
 define double @maxnum_maxnum_minnum(double %x, double %y) {
 ; CHECK-LABEL: @maxnum_maxnum_minnum(
 ; CHECK-NEXT:    [[MAX:%.*]] = call double @llvm.maxnum.f64(double [[X:%.*]], double [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call double @llvm.minnum.f64(double [[X]], double [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call double @llvm.maxnum.f64(double [[MAX]], double [[MIN]])
-; CHECK-NEXT:    ret double [[VAL]]
+; CHECK-NEXT:    ret double [[MAX]]
 ;
   %max = call double @llvm.maxnum.f64(double %x, double %y)
   %min = call double @llvm.minnum.f64(double %x, double %y)
@@ -1323,9 +1309,7 @@ define double @maxnum_maxnum_minnum(double %x, double %y) {
 define <2 x float> @maxnum_minnum_maxnum(<2 x float> %x, <2 x float> %y) {
 ; CHECK-LABEL: @maxnum_minnum_maxnum(
 ; CHECK-NEXT:    [[MAX:%.*]] = call <2 x float> @llvm.maxnum.v2f32(<2 x float> [[X:%.*]], <2 x float> [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call <2 x float> @llvm.minnum.v2f32(<2 x float> [[X]], <2 x float> [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call <2 x float> @llvm.maxnum.v2f32(<2 x float> [[MIN]], <2 x float> [[MAX]])
-; CHECK-NEXT:    ret <2 x float> [[VAL]]
+; CHECK-NEXT:    ret <2 x float> [[MAX]]
 ;
   %max = call <2 x float> @llvm.maxnum.v2f32(<2 x float> %x, <2 x float> %y)
   %min = call <2 x float> @llvm.minnum.v2f32(<2 x float> %x, <2 x float> %y)
@@ -1349,9 +1333,7 @@ define <2 x double> @maxnum_minnum_minmum(<2 x double> %x, <2 x double> %y) {
 define float @maxnum_maxnum_maxnum(float %x, float %y) {
 ; CHECK-LABEL: @maxnum_maxnum_maxnum(
 ; CHECK-NEXT:    [[MAX1:%.*]] = call float @llvm.maxnum.f32(float [[X:%.*]], float [[Y:%.*]])
-; CHECK-NEXT:    [[MAX2:%.*]] = call float @llvm.maxnum.f32(float [[X]], float [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call float @llvm.maxnum.f32(float [[MAX1]], float [[MAX2]])
-; CHECK-NEXT:    ret float [[VAL]]
+; CHECK-NEXT:    ret float [[MAX1]]
 ;
   %max1 = call float @llvm.maxnum.f32(float %x, float %y)
   %max2 = call float @llvm.maxnum.f32(float %x, float %y)
@@ -1361,10 +1343,8 @@ define float @maxnum_maxnum_maxnum(float %x, float %y) {
 
 define double @minnum_maxnum_minnum(double %x, double %y) {
 ; CHECK-LABEL: @minnum_maxnum_minnum(
-; CHECK-NEXT:    [[MAX:%.*]] = call double @llvm.maxnum.f64(double [[X:%.*]], double [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call double @llvm.minnum.f64(double [[X]], double [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call double @llvm.minnum.f64(double [[MAX]], double [[MIN]])
-; CHECK-NEXT:    ret double [[VAL]]
+; CHECK-NEXT:    [[MIN:%.*]] = call double @llvm.minnum.f64(double [[X:%.*]], double [[Y:%.*]])
+; CHECK-NEXT:    ret double [[MIN]]
 ;
   %max = call double @llvm.maxnum.f64(double %x, double %y)
   %min = call double @llvm.minnum.f64(double %x, double %y)
@@ -1374,10 +1354,8 @@ define double @minnum_maxnum_minnum(double %x, double %y) {
 
 define float @minnum_minnum_maxnum(float %x, float %y) {
 ; CHECK-LABEL: @minnum_minnum_maxnum(
-; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maxnum.f32(float [[X:%.*]], float [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minnum.f32(float [[X]], float [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call float @llvm.minnum.f32(float [[MIN]], float [[MAX]])
-; CHECK-NEXT:    ret float [[VAL]]
+; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minnum.f32(float [[X:%.*]], float [[Y:%.*]])
+; CHECK-NEXT:    ret float [[MIN]]
 ;
   %max = call float @llvm.maxnum.f32(float %x, float %y)
   %min = call float @llvm.minnum.f32(float %x, float %y)
@@ -1401,9 +1379,7 @@ define <2 x float> @minnum_maxnum_maxnum(<2 x float> %x, <2 x float> %y) {
 define <2 x double> @minnum_minnum_minmum(<2 x double> %x, <2 x double> %y) {
 ; CHECK-LABEL: @minnum_minnum_minmum(
 ; CHECK-NEXT:    [[MIN1:%.*]] = call <2 x double> @llvm.minnum.v2f64(<2 x double> [[X:%.*]], <2 x double> [[Y:%.*]])
-; CHECK-NEXT:    [[MIN2:%.*]] = call <2 x double> @llvm.minnum.v2f64(<2 x double> [[X]], <2 x double> [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call <2 x double> @llvm.minnum.v2f64(<2 x double> [[MIN1]], <2 x double> [[MIN2]])
-; CHECK-NEXT:    ret <2 x double> [[VAL]]
+; CHECK-NEXT:    ret <2 x double> [[MIN1]]
 ;
   %min1 = call <2 x double> @llvm.minnum.v2f64(<2 x double> %x, <2 x double> %y)
   %min2 = call <2 x double> @llvm.minnum.v2f64(<2 x double> %x, <2 x double> %y)


        


More information about the llvm-commits mailing list