[llvm] 14eefa5 - [InstCombine] factorize min/max intrinsic ops with common operand (2nd try)

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 12 13:34:16 PDT 2021


Author: Sanjay Patel
Date: 2021-08-12T16:32:07-04:00
New Revision: 14eefa57f2b62bcffb0d7cade2230e7bba9b6d03

URL: https://github.com/llvm/llvm-project/commit/14eefa57f2b62bcffb0d7cade2230e7bba9b6d03
DIFF: https://github.com/llvm/llvm-project/commit/14eefa57f2b62bcffb0d7cade2230e7bba9b6d03.diff

LOG: [InstCombine] factorize min/max intrinsic ops with common operand (2nd try)

This is a re-try of 6de1dbbd09c1, which was reverted because
it missed a null check. An extra test for that failure is added.

Original commit message:
This is an adaptation of D41603 and another step on the way
to canonicalizing to the intrinsic forms of min/max.

See D98152 for status.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/test/Transforms/InstCombine/minmax-intrinsics.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 210652e23377..ab0aa0d40298 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -795,6 +795,64 @@ static Instruction *foldClampRangeOfTwo(IntrinsicInst *II,
   return SelectInst::Create(Cmp, ConstantInt::get(II->getType(), *C0), I1);
 }
 
+/// Reduce a sequence of min/max intrinsics with a common operand.
+static Instruction *factorizeMinMaxTree(IntrinsicInst *II,
+                                        InstCombiner::BuilderTy &Builder) { // NOTE(review): Builder is unused in this function.
+  // Match 3 of the same min/max ops. Example: umin(umin(), umin()).
+  auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
+  auto *RHS = dyn_cast<IntrinsicInst>(II->getArgOperand(1));
+  Intrinsic::ID MinMaxID = II->getIntrinsicID();
+  if (!LHS || !RHS || LHS->getIntrinsicID() !=  MinMaxID ||
+      RHS->getIntrinsicID() != MinMaxID ||
+      (!LHS->hasOneUse() && !RHS->hasOneUse()))
+    return nullptr; // Wrong shape, or neither inner op has a single use.
+
+  Value *A = LHS->getArgOperand(0); // LHS is min/max(A, B).
+  Value *B = LHS->getArgOperand(1);
+  Value *C = RHS->getArgOperand(0); // RHS is min/max(C, D).
+  Value *D = RHS->getArgOperand(1);
+
+  // Look for a common operand.
+  Value *MinMaxOp = nullptr;
+  Value *ThirdOp = nullptr;
+  if (LHS->hasOneUse()) {
+    // The LHS is only used in this chain (the RHS may also be used elsewhere),
+    // so reuse the RHS min/max because that will eliminate the LHS.
+    if (D == A || C == A) {
+      // min(min(a, b), min(c, a)) --> min(min(c, a), b)
+      // min(min(a, b), min(a, d)) --> min(min(a, d), b)
+      MinMaxOp = RHS;
+      ThirdOp = B;
+    } else if (D == B || C == B) {
+      // min(min(a, b), min(c, b)) --> min(min(c, b), a)
+      // min(min(a, b), min(b, d)) --> min(min(b, d), a)
+      MinMaxOp = RHS;
+      ThirdOp = A;
+    }
+  } else {
+    assert(RHS->hasOneUse() && "Expected one-use operand");
+    // Reuse the LHS. This will eliminate the RHS.
+    if (D == A || D == B) {
+      // min(min(a, b), min(c, a)) --> min(min(a, b), c)
+      // min(min(a, b), min(c, b)) --> min(min(a, b), c)
+      MinMaxOp = LHS;
+      ThirdOp = C;
+    } else if (C == A || C == B) {
+      // min(min(a, b), min(a, d)) --> min(min(a, b), d)
+      // min(min(a, b), min(b, d)) --> min(min(a, b), d)
+      MinMaxOp = LHS;
+      ThirdOp = D;
+    }
+  }
+
+  if (!MinMaxOp || !ThirdOp)
+    return nullptr; // No operand is shared between the LHS and the RHS.
+
+  Module *Mod = II->getModule();
+  Function *MinMax = Intrinsic::getDeclaration(Mod, MinMaxID, II->getType());
+  return CallInst::Create(MinMax, { MinMaxOp, ThirdOp });
+}
+
 /// CallInst simplification. This mostly only handles folding of intrinsic
 /// instructions. For normal calls, it allows visitCallBase to do the heavy
 /// lifting.
@@ -1056,6 +1114,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
         if (Instruction *R = FoldOpIntoSelect(*II, Sel))
           return R;
 
+    if (Instruction *NewMinMax = factorizeMinMaxTree(II, Builder))
+       return NewMinMax;
+
     break;
   }
   case Intrinsic::bswap: {

diff  --git a/llvm/test/Transforms/InstCombine/minmax-intrinsics.ll b/llvm/test/Transforms/InstCombine/minmax-intrinsics.ll
index 54e133ded3c9..c21724b46308 100644
--- a/llvm/test/Transforms/InstCombine/minmax-intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/minmax-intrinsics.ll
@@ -907,9 +907,8 @@ define <3 x i1> @umin_ne_zero2(<3 x i8> %a, <3 x i8> %b) {
 
 define i8 @smax(i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @smax(
-; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.smax.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
-; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[Z:%.*]])
-; CHECK-NEXT:    [[M3:%.*]] = call i8 @llvm.smax.i8(i8 [[M1]], i8 [[M2]])
+; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.smax.i8(i8 [[X:%.*]], i8 [[Z:%.*]])
+; CHECK-NEXT:    [[M3:%.*]] = call i8 @llvm.smax.i8(i8 [[M2]], i8 [[Y:%.*]])
 ; CHECK-NEXT:    ret i8 [[M3]]
 ;
   %m1 = call i8 @llvm.smax.i8(i8 %x, i8 %y)
@@ -920,9 +919,8 @@ define i8 @smax(i8 %x, i8 %y, i8 %z) {
 
 define <3 x i8> @smin(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) {
 ; CHECK-LABEL: @smin(
-; CHECK-NEXT:    [[M1:%.*]] = call <3 x i8> @llvm.smin.v3i8(<3 x i8> [[Y:%.*]], <3 x i8> [[X:%.*]])
-; CHECK-NEXT:    [[M2:%.*]] = call <3 x i8> @llvm.smin.v3i8(<3 x i8> [[X]], <3 x i8> [[Z:%.*]])
-; CHECK-NEXT:    [[M3:%.*]] = call <3 x i8> @llvm.smin.v3i8(<3 x i8> [[M1]], <3 x i8> [[M2]])
+; CHECK-NEXT:    [[M2:%.*]] = call <3 x i8> @llvm.smin.v3i8(<3 x i8> [[X:%.*]], <3 x i8> [[Z:%.*]])
+; CHECK-NEXT:    [[M3:%.*]] = call <3 x i8> @llvm.smin.v3i8(<3 x i8> [[M2]], <3 x i8> [[Y:%.*]])
 ; CHECK-NEXT:    ret <3 x i8> [[M3]]
 ;
   %m1 = call <3 x i8> @llvm.smin.v3i8(<3 x i8> %y, <3 x i8> %x)
@@ -935,8 +933,7 @@ define i8 @umax(i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @umax(
 ; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.umax.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
 ; CHECK-NEXT:    call void @use(i8 [[M1]])
-; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.umax.i8(i8 [[Z:%.*]], i8 [[X]])
-; CHECK-NEXT:    [[M3:%.*]] = call i8 @llvm.umax.i8(i8 [[M1]], i8 [[M2]])
+; CHECK-NEXT:    [[M3:%.*]] = call i8 @llvm.umax.i8(i8 [[M1]], i8 [[Z:%.*]])
 ; CHECK-NEXT:    ret i8 [[M3]]
 ;
   %m1 = call i8 @llvm.umax.i8(i8 %x, i8 %y)
@@ -948,10 +945,9 @@ define i8 @umax(i8 %x, i8 %y, i8 %z) {
 
 define i8 @umin(i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @umin(
-; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.umin.i8(i8 [[Y:%.*]], i8 [[X:%.*]])
-; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.umin.i8(i8 [[Z:%.*]], i8 [[X]])
+; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.umin.i8(i8 [[Z:%.*]], i8 [[X:%.*]])
 ; CHECK-NEXT:    call void @use(i8 [[M2]])
-; CHECK-NEXT:    [[M3:%.*]] = call i8 @llvm.umin.i8(i8 [[M1]], i8 [[M2]])
+; CHECK-NEXT:    [[M3:%.*]] = call i8 @llvm.umin.i8(i8 [[M2]], i8 [[Y:%.*]])
 ; CHECK-NEXT:    ret i8 [[M3]]
 ;
   %m1 = call i8 @llvm.umin.i8(i8 %y, i8 %x)
@@ -961,6 +957,8 @@ define i8 @umin(i8 %x, i8 %y, i8 %z) {
   ret i8 %m3
 }
 
+; negative test - too many uses
+
 define i8 @smax_uses(i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @smax_uses(
 ; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.smax.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
@@ -978,6 +976,21 @@ define i8 @smax_uses(i8 %x, i8 %y, i8 %z) {
   ret i8 %m3
 }
 
+; negative test - must have common operand (no fold expected; presumably the missing-null-check case from the reverted first attempt)
+
+define i8 @smax_no_common_op(i8 %x, i8 %y, i8 %z, i8 %w) {
+; CHECK-LABEL: @smax_no_common_op(
+; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.smax.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
+; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.smax.i8(i8 [[W:%.*]], i8 [[Z:%.*]])
+; CHECK-NEXT:    [[M3:%.*]] = call i8 @llvm.smax.i8(i8 [[M1]], i8 [[M2]])
+; CHECK-NEXT:    ret i8 [[M3]]
+;
+  %m1 = call i8 @llvm.smax.i8(i8 %x, i8 %y)
+  %m2 = call i8 @llvm.smax.i8(i8 %w, i8 %z)
+  %m3 = call i8 @llvm.smax.i8(i8 %m1, i8 %m2)
+  ret i8 %m3
+}
+
 define i8 @umax_demand_lshr(i8 %x) {
 ; CHECK-LABEL: @umax_demand_lshr(
 ; CHECK-NEXT:    [[R:%.*]] = lshr i8 [[X:%.*]], 4


        


More information about the llvm-commits mailing list