[llvm] r321998 - [InstCombine] fold min/max tree with common operand (PR35717)

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 8 07:05:34 PST 2018


Author: spatel
Date: Mon Jan  8 07:05:34 2018
New Revision: 321998

URL: http://llvm.org/viewvc/llvm-project?rev=321998&view=rev
Log:
[InstCombine] fold min/max tree with common operand (PR35717)

There is precedent for factorization transforms in instcombine for FP ops with fast-math.
We also have similar logic in foldSPFofSPF().

It would take more work to add this to reassociate because that's specialized for binops, 
and min/max are not binops (or even single instructions). Also, I don't have evidence that 
larger min/max trees than this exist in real code, but if we find that's true, we might
want to reorganize where/how we do this optimization.

In the motivating example from https://bugs.llvm.org/show_bug.cgi?id=35717 , we have:

int test(int xc, int xm, int xy) {
  int xk;
  if (xc < xm)
    xk = xc < xy ? xc : xy;
  else
    xk = xm < xy ? xm : xy;
  return xk;
}

This patch solves that problem because we recognize more min/max patterns after rL321672.

https://rise4fun.com/Alive/Qjne
https://rise4fun.com/Alive/3yg

Differential Revision: https://reviews.llvm.org/D41603

Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/trunk/test/Transforms/InstCombine/max-of-nots.ll
    llvm/trunk/test/Transforms/InstCombine/minmax-fold.ll

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineSelect.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineSelect.cpp?rev=321998&r1=321997&r2=321998&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineSelect.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineSelect.cpp Mon Jan  8 07:05:34 2018
@@ -1289,6 +1289,63 @@ static Instruction *foldSelectCmpXchg(Se
   return nullptr;
 }
 
+/// Reduce a sequence of min/max with a common operand.
+static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS,
+                                        Value *RHS,
+                                        InstCombiner::BuilderTy &Builder) {
+  assert(SelectPatternResult::isMinOrMax(SPF) && "Expected a min/max");
+  // TODO: Allow FP min/max with nnan/nsz.
+  if (!LHS->getType()->isIntOrIntVectorTy())
+    return nullptr;
+
+  // Match 3 of the same min/max ops. Example: umin(umin(), umin()).
+  Value *A, *B, *C, *D;
+  SelectPatternResult L = matchSelectPattern(LHS, A, B);
+  SelectPatternResult R = matchSelectPattern(RHS, C, D);
+  if (SPF != L.Flavor || L.Flavor != R.Flavor)
+    return nullptr;
+
+  // Look for a common operand. The use checks are different than usual because
+  // a min/max pattern typically has 2 uses of each op: 1 by the cmp and 1 by
+  // the select.
+  Value *MinMaxOp = nullptr;
+  Value *ThirdOp = nullptr;
+  if (LHS->getNumUses() <= 2 && RHS->getNumUses() > 2) {
+    // If the LHS is only used in this chain and the RHS is used outside of it,
+    // reuse the RHS min/max because that will eliminate the LHS.
+    if (D == A || C == A) {
+      // min(min(a, b), min(c, a)) --> min(min(c, a), b)
+      // min(min(a, b), min(a, d)) --> min(min(a, d), b)
+      MinMaxOp = RHS;
+      ThirdOp = B;
+    } else if (D == B || C == B) {
+      // min(min(a, b), min(c, b)) --> min(min(c, b), a)
+      // min(min(a, b), min(b, d)) --> min(min(b, d), a)
+      MinMaxOp = RHS;
+      ThirdOp = A;
+    }
+  } else if (RHS->getNumUses() <= 2) {
+    // Reuse the LHS. This will eliminate the RHS.
+    if (D == A || D == B) {
+      // min(min(a, b), min(c, a)) --> min(min(a, b), c)
+      // min(min(a, b), min(c, b)) --> min(min(a, b), c)
+      MinMaxOp = LHS;
+      ThirdOp = C;
+    } else if (C == A || C == B) {
+      // min(min(a, b), min(a, d)) --> min(min(a, b), d)
+      // min(min(a, b), min(b, d)) --> min(min(a, b), d)
+      MinMaxOp = LHS;
+      ThirdOp = D;
+    }
+  }
+  if (!MinMaxOp || !ThirdOp)
+    return nullptr;
+
+  CmpInst::Predicate P = getCmpPredicateForMinMax(SPF);
+  Value *CmpABC = Builder.CreateICmp(P, MinMaxOp, ThirdOp);
+  return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp);
+}
+
 Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -1563,6 +1620,9 @@ Instruction *InstCombiner::visitSelectIn
         Value *NewSel = Builder.CreateSelect(InvertedCmp, A, B);
         return BinaryOperator::CreateNot(NewSel);
       }
+
+      if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder))
+        return I;
     }
 
     if (SPF) {

Modified: llvm/trunk/test/Transforms/InstCombine/max-of-nots.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/max-of-nots.ll?rev=321998&r1=321997&r2=321998&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/max-of-nots.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/max-of-nots.ll Mon Jan  8 07:05:34 2018
@@ -84,12 +84,10 @@ define i8 @umin_not_2_extra_use(i8 %x, i
 
 define i8 @umin3_not(i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @umin3_not(
-; CHECK-NEXT:    [[CMPYX:%.*]] = icmp ult i8 %y, %x
 ; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i8 %x, %z
 ; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[TMP1]], i8 %x, i8 %z
-; CHECK-NEXT:    [[TMP3:%.*]] = icmp ugt i8 %y, %z
-; CHECK-NEXT:    [[TMP4:%.*]] = select i1 [[TMP3]], i8 %y, i8 %z
-; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[CMPYX]], i8 [[TMP2]], i8 [[TMP4]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ugt i8 [[TMP2]], %y
+; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[TMP3]], i8 [[TMP2]], i8 %y
 ; CHECK-NEXT:    [[R:%.*]] = xor i8 [[R:%.*]].v, -1
 ; CHECK-NEXT:    ret i8 [[R]]
 ;

Modified: llvm/trunk/test/Transforms/InstCombine/minmax-fold.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/minmax-fold.ll?rev=321998&r1=321997&r2=321998&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/minmax-fold.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/minmax-fold.ll Mon Jan  8 07:05:34 2018
@@ -754,10 +754,8 @@ define i32 @common_factor_smin(i32 %a, i
 ; CHECK-LABEL: @common_factor_smin(
 ; CHECK-NEXT:    [[CMP_AB:%.*]] = icmp slt i32 %a, %b
 ; CHECK-NEXT:    [[MIN_AB:%.*]] = select i1 [[CMP_AB]], i32 %a, i32 %b
-; CHECK-NEXT:    [[CMP_BC:%.*]] = icmp slt i32 %b, %c
-; CHECK-NEXT:    [[MIN_BC:%.*]] = select i1 [[CMP_BC]], i32 %b, i32 %c
-; CHECK-NEXT:    [[CMP_AB_BC:%.*]] = icmp slt i32 [[MIN_AB]], [[MIN_BC]]
-; CHECK-NEXT:    [[MIN_ABC:%.*]] = select i1 [[CMP_AB_BC]], i32 [[MIN_AB]], i32 [[MIN_BC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i32 [[MIN_AB]], %c
+; CHECK-NEXT:    [[MIN_ABC:%.*]] = select i1 [[TMP1]], i32 [[MIN_AB]], i32 %c
 ; CHECK-NEXT:    ret i32 [[MIN_ABC]]
 ;
   %cmp_ab = icmp slt i32 %a, %b
@@ -775,10 +773,8 @@ define <2 x i32> @common_factor_smax(<2
 ; CHECK-LABEL: @common_factor_smax(
 ; CHECK-NEXT:    [[CMP_AB:%.*]] = icmp sgt <2 x i32> %a, %b
 ; CHECK-NEXT:    [[MAX_AB:%.*]] = select <2 x i1> [[CMP_AB]], <2 x i32> %a, <2 x i32> %b
-; CHECK-NEXT:    [[CMP_CB:%.*]] = icmp sgt <2 x i32> %c, %b
-; CHECK-NEXT:    [[MAX_CB:%.*]] = select <2 x i1> [[CMP_CB]], <2 x i32> %c, <2 x i32> %b
-; CHECK-NEXT:    [[CMP_AB_CB:%.*]] = icmp sgt <2 x i32> [[MAX_AB]], [[MAX_CB]]
-; CHECK-NEXT:    [[MAX_ABC:%.*]] = select <2 x i1> [[CMP_AB_CB]], <2 x i32> [[MAX_AB]], <2 x i32> [[MAX_CB]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp sgt <2 x i32> [[MAX_AB]], %c
+; CHECK-NEXT:    [[MAX_ABC:%.*]] = select <2 x i1> [[TMP1]], <2 x i32> [[MAX_AB]], <2 x i32> %c
 ; CHECK-NEXT:    ret <2 x i32> [[MAX_ABC]]
 ;
   %cmp_ab = icmp sgt <2 x i32> %a, %b
@@ -796,10 +792,8 @@ define <2 x i32> @common_factor_umin(<2
 ; CHECK-LABEL: @common_factor_umin(
 ; CHECK-NEXT:    [[CMP_BC:%.*]] = icmp ult <2 x i32> %b, %c
 ; CHECK-NEXT:    [[MIN_BC:%.*]] = select <2 x i1> [[CMP_BC]], <2 x i32> %b, <2 x i32> %c
-; CHECK-NEXT:    [[CMP_AB:%.*]] = icmp ult <2 x i32> %a, %b
-; CHECK-NEXT:    [[MIN_AB:%.*]] = select <2 x i1> [[CMP_AB]], <2 x i32> %a, <2 x i32> %b
-; CHECK-NEXT:    [[CMP_BC_AB:%.*]] = icmp ult <2 x i32> [[MIN_BC]], [[MIN_AB]]
-; CHECK-NEXT:    [[MIN_ABC:%.*]] = select <2 x i1> [[CMP_BC_AB]], <2 x i32> [[MIN_BC]], <2 x i32> [[MIN_AB]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult <2 x i32> [[MIN_BC]], %a
+; CHECK-NEXT:    [[MIN_ABC:%.*]] = select <2 x i1> [[TMP1]], <2 x i32> [[MIN_BC]], <2 x i32> %a
 ; CHECK-NEXT:    ret <2 x i32> [[MIN_ABC]]
 ;
   %cmp_bc = icmp ult <2 x i32> %b, %c
@@ -817,10 +811,8 @@ define i32 @common_factor_umax(i32 %a, i
 ; CHECK-LABEL: @common_factor_umax(
 ; CHECK-NEXT:    [[CMP_BC:%.*]] = icmp ugt i32 %b, %c
 ; CHECK-NEXT:    [[MAX_BC:%.*]] = select i1 [[CMP_BC]], i32 %b, i32 %c
-; CHECK-NEXT:    [[CMP_BA:%.*]] = icmp ugt i32 %b, %a
-; CHECK-NEXT:    [[MAX_BA:%.*]] = select i1 [[CMP_BA]], i32 %b, i32 %a
-; CHECK-NEXT:    [[CMP_BC_BA:%.*]] = icmp ugt i32 [[MAX_BC]], [[MAX_BA]]
-; CHECK-NEXT:    [[MAX_ABC:%.*]] = select i1 [[CMP_BC_BA]], i32 [[MAX_BC]], i32 [[MAX_BA]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i32 [[MAX_BC]], %a
+; CHECK-NEXT:    [[MAX_ABC:%.*]] = select i1 [[TMP1]], i32 [[MAX_BC]], i32 %a
 ; CHECK-NEXT:    ret i32 [[MAX_ABC]]
 ;
   %cmp_bc = icmp ugt i32 %b, %c
@@ -838,10 +830,8 @@ define i32 @common_factor_umax_extra_use
 ; CHECK-LABEL: @common_factor_umax_extra_use_lhs(
 ; CHECK-NEXT:    [[CMP_BC:%.*]] = icmp ugt i32 %b, %c
 ; CHECK-NEXT:    [[MAX_BC:%.*]] = select i1 [[CMP_BC]], i32 %b, i32 %c
-; CHECK-NEXT:    [[CMP_BA:%.*]] = icmp ugt i32 %b, %a
-; CHECK-NEXT:    [[MAX_BA:%.*]] = select i1 [[CMP_BA]], i32 %b, i32 %a
-; CHECK-NEXT:    [[CMP_BC_BA:%.*]] = icmp ugt i32 [[MAX_BC]], [[MAX_BA]]
-; CHECK-NEXT:    [[MAX_ABC:%.*]] = select i1 [[CMP_BC_BA]], i32 [[MAX_BC]], i32 [[MAX_BA]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i32 [[MAX_BC]], %a
+; CHECK-NEXT:    [[MAX_ABC:%.*]] = select i1 [[TMP1]], i32 [[MAX_BC]], i32 %a
 ; CHECK-NEXT:    call void @extra_use(i32 [[MAX_BC]])
 ; CHECK-NEXT:    ret i32 [[MAX_ABC]]
 ;
@@ -857,12 +847,10 @@ define i32 @common_factor_umax_extra_use
 
 define i32 @common_factor_umax_extra_use_rhs(i32 %a, i32 %b, i32 %c) {
 ; CHECK-LABEL: @common_factor_umax_extra_use_rhs(
-; CHECK-NEXT:    [[CMP_BC:%.*]] = icmp ugt i32 %b, %c
-; CHECK-NEXT:    [[MAX_BC:%.*]] = select i1 [[CMP_BC]], i32 %b, i32 %c
 ; CHECK-NEXT:    [[CMP_BA:%.*]] = icmp ugt i32 %b, %a
 ; CHECK-NEXT:    [[MAX_BA:%.*]] = select i1 [[CMP_BA]], i32 %b, i32 %a
-; CHECK-NEXT:    [[CMP_BC_BA:%.*]] = icmp ugt i32 [[MAX_BC]], [[MAX_BA]]
-; CHECK-NEXT:    [[MAX_ABC:%.*]] = select i1 [[CMP_BC_BA]], i32 [[MAX_BC]], i32 [[MAX_BA]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i32 [[MAX_BA]], %c
+; CHECK-NEXT:    [[MAX_ABC:%.*]] = select i1 [[TMP1]], i32 [[MAX_BA]], i32 %c
 ; CHECK-NEXT:    call void @extra_use(i32 [[MAX_BA]])
 ; CHECK-NEXT:    ret i32 [[MAX_ABC]]
 ;




More information about the llvm-commits mailing list