[llvm] [InstCombine] Factorise add/sub and max/min using distributivity (PR #101507)

Jorge Botto via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 1 08:55:06 PDT 2024


https://github.com/jf-botto created https://github.com/llvm/llvm-project/pull/101507

This PR fixes part of [Issue 92433](https://github.com/llvm/llvm-project/issues/92433).
https://alive2.llvm.org/ce/z/ZaXf7L

The Alive2 proof sometimes times out; adding `noundef` to some of the arguments makes it complete more consistently. Because min/max and saturating add propagate undef, I believe these optimisations are also correct when undef values are passed in.

>From d51cdbb6926d436b9faa4a98b1647702e0fd7a88 Mon Sep 17 00:00:00 2001
From: Jorge Botto <jorge.botto.16 at ucl.ac.uk>
Date: Thu, 1 Aug 2024 16:19:50 +0100
Subject: [PATCH 1/2] Precommit test

---
 .../InstCombine/intrinsic-distributive.ll     | 226 ++++++++++++++++++
 1 file changed, 226 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/intrinsic-distributive.ll

diff --git a/llvm/test/Transforms/InstCombine/intrinsic-distributive.ll b/llvm/test/Transforms/InstCombine/intrinsic-distributive.ll
new file mode 100644
index 0000000000000..b874fcde67d1e
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/intrinsic-distributive.ll
@@ -0,0 +1,226 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=instcombine < %s 2>&1 | FileCheck %s
+
+define i32 @umin_of_umax(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umin_of_umax(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[MAX1:%.*]] = call i32 @llvm.umax.i32(i32 [[X]], i32 [[Z]])
+; CHECK-NEXT:    [[MAX2:%.*]] = call i32 @llvm.umax.i32(i32 [[Y]], i32 [[Z]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[MAX1]], i32 [[MAX2]])
+; CHECK-NEXT:    ret i32 [[MIN]]
+;
+  %max1 = call i32 @llvm.umax.i32(i32 %x, i32 %z)
+  %max2 = call i32 @llvm.umax.i32(i32 %y, i32 %z)
+  %min = call i32 @llvm.umin.i32(i32 %max1, i32 %max2)
+  ret i32 %min
+}
+
+define i32 @umin_of_umax_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umin_of_umax_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[MAX1:%.*]] = call i32 @llvm.umax.i32(i32 [[Z]], i32 [[X]])
+; CHECK-NEXT:    [[MAX2:%.*]] = call i32 @llvm.umax.i32(i32 [[Z]], i32 [[Y]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[MAX1]], i32 [[MAX2]])
+; CHECK-NEXT:    ret i32 [[MIN]]
+;
+  %max1 = call i32 @llvm.umax.i32(i32 %z, i32 %x)
+  %max2 = call i32 @llvm.umax.i32(i32 %z, i32 %y)
+  %min = call i32 @llvm.umin.i32(i32 %max1, i32 %max2)
+  ret i32 %min
+}
+
+define i32 @smin_of_smax(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smin_of_smax(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[MAX1:%.*]] = call i32 @llvm.smax.i32(i32 [[X]], i32 [[Z]])
+; CHECK-NEXT:    [[MAX2:%.*]] = call i32 @llvm.smax.i32(i32 [[Y]], i32 [[Z]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[MAX1]], i32 [[MAX2]])
+; CHECK-NEXT:    ret i32 [[MIN]]
+;
+  %max1 = call i32 @llvm.smax.i32(i32 %x, i32 %z)
+  %max2 = call i32 @llvm.smax.i32(i32 %y, i32 %z)
+  %min = call i32 @llvm.smin.i32(i32 %max1, i32 %max2)
+  ret i32 %min
+}
+
+define i32 @smin_of_smax_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smin_of_smax_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[MAX1:%.*]] = call i32 @llvm.smax.i32(i32 [[Z]], i32 [[X]])
+; CHECK-NEXT:    [[MAX2:%.*]] = call i32 @llvm.smax.i32(i32 [[Z]], i32 [[Y]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[MAX1]], i32 [[MAX2]])
+; CHECK-NEXT:    ret i32 [[MIN]]
+;
+  %max1 = call i32 @llvm.smax.i32(i32 %z, i32 %x)
+  %max2 = call i32 @llvm.smax.i32(i32 %z, i32 %y)
+  %min = call i32 @llvm.smin.i32(i32 %max1, i32 %max2)
+  ret i32 %min
+}
+
+define i32 @umax_of_umin(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umax_of_umin(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[MIN1:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Z]])
+; CHECK-NEXT:    [[MIN2:%.*]] = call i32 @llvm.umin.i32(i32 [[Y]], i32 [[Z]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[MIN1]], i32 [[MIN2]])
+; CHECK-NEXT:    ret i32 [[MAX]]
+;
+  %min1 = call i32 @llvm.umin.i32(i32 %x, i32 %z)
+  %min2 = call i32 @llvm.umin.i32(i32 %y, i32 %z)
+  %max = call i32 @llvm.umax.i32(i32 %min1, i32 %min2)
+  ret i32 %max
+}
+
+define i32 @umax_of_umin_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umax_of_umin_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[MIN1:%.*]] = call i32 @llvm.umin.i32(i32 [[Z]], i32 [[X]])
+; CHECK-NEXT:    [[MIN2:%.*]] = call i32 @llvm.umin.i32(i32 [[Z]], i32 [[Y]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[MIN1]], i32 [[MIN2]])
+; CHECK-NEXT:    ret i32 [[MAX]]
+;
+  %min1 = call i32 @llvm.umin.i32(i32 %z, i32 %x)
+  %min2 = call i32 @llvm.umin.i32(i32 %z, i32 %y)
+  %max = call i32 @llvm.umax.i32(i32 %min1, i32 %min2)
+  ret i32 %max
+}
+
+define i32 @smax_of_smin(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smax_of_smin(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[MIN1:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Z]])
+; CHECK-NEXT:    [[MIN2:%.*]] = call i32 @llvm.smin.i32(i32 [[Y]], i32 [[Z]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[MIN1]], i32 [[MIN2]])
+; CHECK-NEXT:    ret i32 [[MAX]]
+;
+  %min1 = call i32 @llvm.smin.i32(i32 %x, i32 %z)
+  %min2 = call i32 @llvm.smin.i32(i32 %y, i32 %z)
+  %max = call i32 @llvm.smax.i32(i32 %min1, i32 %min2)
+  ret i32 %max
+}
+
+define i32 @smax_of_smin_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smax_of_smin_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[MIN1:%.*]] = call i32 @llvm.smin.i32(i32 [[Z]], i32 [[X]])
+; CHECK-NEXT:    [[MIN2:%.*]] = call i32 @llvm.smin.i32(i32 [[Z]], i32 [[Y]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[MIN1]], i32 [[MIN2]])
+; CHECK-NEXT:    ret i32 [[MAX]]
+;
+  %min1 = call i32 @llvm.smin.i32(i32 %z, i32 %x)
+  %min2 = call i32 @llvm.smin.i32(i32 %z, i32 %y)
+  %max = call i32 @llvm.smax.i32(i32 %min1, i32 %min2)
+  ret i32 %max
+}
+
+define i32 @umax_of_uadd_sat(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umax_of_uadd_sat(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X]], i32 [[Z]])
+; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[Y]], i32 [[Z]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    ret i32 [[MAX]]
+;
+  %add1 = call i32 @llvm.uadd.sat.i32(i32 %x, i32 %z)
+  %add2 = call i32 @llvm.uadd.sat.i32(i32 %y, i32 %z)
+  %max = call i32 @llvm.umax.i32(i32 %add1, i32 %add2)
+  ret i32 %max
+}
+
+define i32 @umax_of_uadd_sat_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umax_of_uadd_sat_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[Z]], i32 [[X]])
+; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[Z]], i32 [[Y]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    ret i32 [[MAX]]
+;
+  %add1 = call i32 @llvm.uadd.sat.i32(i32 %z, i32 %x)
+  %add2 = call i32 @llvm.uadd.sat.i32(i32 %z, i32 %y)
+  %max = call i32 @llvm.umax.i32(i32 %add1, i32 %add2)
+  ret i32 %max
+}
+
+define i32 @umin_of_uadd_sat(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umin_of_uadd_sat(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X]], i32 [[Z]])
+; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[Y]], i32 [[Z]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    ret i32 [[MIN]]
+;
+  %add1 = call i32 @llvm.uadd.sat.i32(i32 %x, i32 %z)
+  %add2 = call i32 @llvm.uadd.sat.i32(i32 %y, i32 %z)
+  %min = call i32 @llvm.umin.i32(i32 %add1, i32 %add2)
+  ret i32 %min
+}
+
+define i32 @umin_of_uadd_sat_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umin_of_uadd_sat_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[Z]], i32 [[X]])
+; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[Z]], i32 [[Y]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    ret i32 [[MIN]]
+;
+  %add1 = call i32 @llvm.uadd.sat.i32(i32 %z, i32 %x)
+  %add2 = call i32 @llvm.uadd.sat.i32(i32 %z, i32 %y)
+  %min = call i32 @llvm.umin.i32(i32 %add1, i32 %add2)
+  ret i32 %min
+}
+
+define i32 @smax_of_sadd_sat(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smax_of_sadd_sat(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[X]], i32 [[Z]])
+; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[Y]], i32 [[Z]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    ret i32 [[MAX]]
+;
+  %add1 = call i32 @llvm.sadd.sat.i32(i32 %x, i32 %z)
+  %add2 = call i32 @llvm.sadd.sat.i32(i32 %y, i32 %z)
+  %max = call i32 @llvm.smax.i32(i32 %add1, i32 %add2)
+  ret i32 %max
+}
+
+define i32 @smax_of_sadd_sat_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smax_of_sadd_sat_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[Z]], i32 [[X]])
+; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[Z]], i32 [[Y]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    ret i32 [[MAX]]
+;
+  %add1 = call i32 @llvm.sadd.sat.i32(i32 %z, i32 %x)
+  %add2 = call i32 @llvm.sadd.sat.i32(i32 %z, i32 %y)
+  %max = call i32 @llvm.smax.i32(i32 %add1, i32 %add2)
+  ret i32 %max
+}
+
+define i32 @smin_of_sadd_sat(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smin_of_sadd_sat(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[X]], i32 [[Z]])
+; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[Y]], i32 [[Z]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    ret i32 [[MIN]]
+;
+  %add1 = call i32 @llvm.sadd.sat.i32(i32 %x, i32 %z)
+  %add2 = call i32 @llvm.sadd.sat.i32(i32 %y, i32 %z)
+  %min = call i32 @llvm.smin.i32(i32 %add1, i32 %add2)
+  ret i32 %min
+}
+
+define i32 @smin_of_sadd_sat_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smin_of_sadd_sat_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[Z]], i32 [[X]])
+; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[Z]], i32 [[Y]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    ret i32 [[MIN]]
+;
+  %add1 = call i32 @llvm.sadd.sat.i32(i32 %z, i32 %x)
+  %add2 = call i32 @llvm.sadd.sat.i32(i32 %z, i32 %y)
+  %min = call i32 @llvm.smin.i32(i32 %add1, i32 %add2)
+  ret i32 %min
+}

>From 3961757df490d57382960f66c2ecfe8336b57e77 Mon Sep 17 00:00:00 2001
From: Jorge Botto <jorge.botto.16 at ucl.ac.uk>
Date: Thu, 1 Aug 2024 16:33:51 +0100
Subject: [PATCH 2/2] Adding missed optimisation

---
 .../InstCombine/InstCombineCalls.cpp          | 85 +++++++++++++++++++
 .../InstCombine/intrinsic-distributive.ll     | 80 +++++++----------
 2 files changed, 117 insertions(+), 48 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index cc68fd4cf1c1b..3267d27d703e3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1505,6 +1505,88 @@ foldMinimumOverTrailingOrLeadingZeroCount(Value *I0, Value *I1,
       ConstantInt::getTrue(ZeroUndef->getType()));
 }
 
+/// Return whether "X LOp (Y ROp Z)" is always equal to
+/// "(X LOp Y) ROp (X LOp Z)".
+///
+/// Every listed pair holds because "X LOp _" is non-decreasing under the
+/// same (unsigned or signed) order that ROp uses, and a non-decreasing
+/// function commutes with the min and max of that order.
+static bool leftDistributesOverRightIntrinsic(Intrinsic::ID LOp,
+                                              Intrinsic::ID ROp) {
+  switch (LOp) {
+  case Intrinsic::umax:
+    return ROp == Intrinsic::umin;
+  case Intrinsic::smax:
+    return ROp == Intrinsic::smin;
+  case Intrinsic::umin:
+    return ROp == Intrinsic::umax;
+  case Intrinsic::smin:
+    return ROp == Intrinsic::smax;
+  case Intrinsic::uadd_sat:
+    return ROp == Intrinsic::umax || ROp == Intrinsic::umin;
+  case Intrinsic::sadd_sat:
+    return ROp == Intrinsic::smax || ROp == Intrinsic::smin;
+  default:
+    return false;
+  }
+}
+
+/// Try to factorise a common operand out of "(A op' B) op (C op' D)", where
+/// op is II's intrinsic and op' is an intrinsic that distributes over op:
+///   (A op' B) op (A op' D) --> (B op D) op' A
+///   (A op' B) op (C op' B) --> (A op C) op' B
+/// e.g. umin(umax(X, Z), umax(Y, Z)) --> umax(umin(X, Y), Z).
+/// Returns the new top-level call (not yet inserted), or nullptr.
+static Instruction *
+foldCallUsingDistributiveLaws(IntrinsicInst *II,
+                              InstCombiner::BuilderTy &Builder) {
+  Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1);
+  Intrinsic::ID TopLevelOpcode = II->getIntrinsicID();
+
+  // Match IntrinsicInst rather than CallInst: an indirect call or inline
+  // asm has no called Function, so getCalledFunction() would return null.
+  auto *Op0 = dyn_cast<IntrinsicInst>(LHS);
+  auto *Op1 = dyn_cast<IntrinsicInst>(RHS);
+  if (!Op0 || !Op1 || Op0->getIntrinsicID() != Op1->getIntrinsicID())
+    return nullptr;
+
+  Intrinsic::ID InnerOpcode = Op0->getIntrinsicID();
+  if (!leftDistributesOverRightIntrinsic(InnerOpcode, TopLevelOpcode))
+    return nullptr;
+
+  // The fold emits two new calls, so require at least one inner call to
+  // die along with II; otherwise the instruction count would grow.
+  if (!LHS->hasOneUse() && !RHS->hasOneUse())
+    return nullptr;
+
+  Value *A = Op0->getArgOperand(0), *B = Op0->getArgOperand(1);
+  Value *C = Op1->getArgOperand(0), *D = Op1->getArgOperand(1);
+  bool InnerCommutative = Op0->isCommutative();
+
+  // Build "(X op Y) op' Common". Every op' accepted above is commutative,
+  // so putting Common second is valid for both patterns.
+  auto Factorise = [&](Value *X, Value *Y, Value *Common) -> Instruction * {
+    Value *NewTop = Builder.CreateBinaryIntrinsic(TopLevelOpcode, X, Y);
+    Function *F = Intrinsic::getDeclaration(II->getModule(), InnerOpcode,
+                                            II->getType());
+    return CallInst::Create(F, {NewTop, Common});
+  };
+
+  if (A == C || (InnerCommutative && A == D)) {
+    if (A != C)
+      std::swap(C, D);
+    return Factorise(B, D, A);
+  }
+
+  if (InnerCommutative && (B == D || B == C)) {
+    if (B != D)
+      std::swap(C, D);
+    return Factorise(A, C, B);
+  }
+
+  return nullptr;
+}
+
 /// CallInst simplification. This mostly only handles folding of intrinsic
 /// instructions. For normal calls, it allows visitCallBase to do the heavy
 /// lifting.
@@ -1929,6 +2011,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       }
     }
 
+    if (Instruction *I = foldCallUsingDistributiveLaws(II, Builder))
+      return I;
+
     break;
   }
   case Intrinsic::bitreverse: {
diff --git a/llvm/test/Transforms/InstCombine/intrinsic-distributive.ll b/llvm/test/Transforms/InstCombine/intrinsic-distributive.ll
index b874fcde67d1e..10f4d8bbc7a0d 100644
--- a/llvm/test/Transforms/InstCombine/intrinsic-distributive.ll
+++ b/llvm/test/Transforms/InstCombine/intrinsic-distributive.ll
@@ -4,9 +4,8 @@
 define i32 @umin_of_umax(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @umin_of_umax(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[MAX1:%.*]] = call i32 @llvm.umax.i32(i32 [[X]], i32 [[Z]])
-; CHECK-NEXT:    [[MAX2:%.*]] = call i32 @llvm.umax.i32(i32 [[Y]], i32 [[Z]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[MAX1]], i32 [[MAX2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umax.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MIN]]
 ;
   %max1 = call i32 @llvm.umax.i32(i32 %x, i32 %z)
@@ -18,9 +17,8 @@ define i32 @umin_of_umax(i32 %x, i32 %y, i32 %z) {
 define i32 @umin_of_umax_comm(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @umin_of_umax_comm(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[MAX1:%.*]] = call i32 @llvm.umax.i32(i32 [[Z]], i32 [[X]])
-; CHECK-NEXT:    [[MAX2:%.*]] = call i32 @llvm.umax.i32(i32 [[Z]], i32 [[Y]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[MAX1]], i32 [[MAX2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umax.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MIN]]
 ;
   %max1 = call i32 @llvm.umax.i32(i32 %z, i32 %x)
@@ -32,9 +30,8 @@ define i32 @umin_of_umax_comm(i32 %x, i32 %y, i32 %z) {
 define i32 @smin_of_smax(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @smin_of_smax(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[MAX1:%.*]] = call i32 @llvm.smax.i32(i32 [[X]], i32 [[Z]])
-; CHECK-NEXT:    [[MAX2:%.*]] = call i32 @llvm.smax.i32(i32 [[Y]], i32 [[Z]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[MAX1]], i32 [[MAX2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MIN]]
 ;
   %max1 = call i32 @llvm.smax.i32(i32 %x, i32 %z)
@@ -46,9 +43,8 @@ define i32 @smin_of_smax(i32 %x, i32 %y, i32 %z) {
 define i32 @smin_of_smax_comm(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @smin_of_smax_comm(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[MAX1:%.*]] = call i32 @llvm.smax.i32(i32 [[Z]], i32 [[X]])
-; CHECK-NEXT:    [[MAX2:%.*]] = call i32 @llvm.smax.i32(i32 [[Z]], i32 [[Y]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[MAX1]], i32 [[MAX2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MIN]]
 ;
   %max1 = call i32 @llvm.smax.i32(i32 %z, i32 %x)
@@ -60,9 +56,8 @@ define i32 @smin_of_smax_comm(i32 %x, i32 %y, i32 %z) {
 define i32 @umax_of_umin(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @umax_of_umin(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[MIN1:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Z]])
-; CHECK-NEXT:    [[MIN2:%.*]] = call i32 @llvm.umin.i32(i32 [[Y]], i32 [[Z]])
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[MIN1]], i32 [[MIN2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.umax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umin.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MAX]]
 ;
   %min1 = call i32 @llvm.umin.i32(i32 %x, i32 %z)
@@ -74,9 +69,8 @@ define i32 @umax_of_umin(i32 %x, i32 %y, i32 %z) {
 define i32 @umax_of_umin_comm(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @umax_of_umin_comm(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[MIN1:%.*]] = call i32 @llvm.umin.i32(i32 [[Z]], i32 [[X]])
-; CHECK-NEXT:    [[MIN2:%.*]] = call i32 @llvm.umin.i32(i32 [[Z]], i32 [[Y]])
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[MIN1]], i32 [[MIN2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.umax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umin.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MAX]]
 ;
   %min1 = call i32 @llvm.umin.i32(i32 %z, i32 %x)
@@ -88,9 +82,8 @@ define i32 @umax_of_umin_comm(i32 %x, i32 %y, i32 %z) {
 define i32 @smax_of_smin(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @smax_of_smin(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[MIN1:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Z]])
-; CHECK-NEXT:    [[MIN2:%.*]] = call i32 @llvm.smin.i32(i32 [[Y]], i32 [[Z]])
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[MIN1]], i32 [[MIN2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.smax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MAX]]
 ;
   %min1 = call i32 @llvm.smin.i32(i32 %x, i32 %z)
@@ -102,9 +95,8 @@ define i32 @smax_of_smin(i32 %x, i32 %y, i32 %z) {
 define i32 @smax_of_smin_comm(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @smax_of_smin_comm(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[MIN1:%.*]] = call i32 @llvm.smin.i32(i32 [[Z]], i32 [[X]])
-; CHECK-NEXT:    [[MIN2:%.*]] = call i32 @llvm.smin.i32(i32 [[Z]], i32 [[Y]])
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[MIN1]], i32 [[MIN2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.smax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MAX]]
 ;
   %min1 = call i32 @llvm.smin.i32(i32 %z, i32 %x)
@@ -116,9 +108,8 @@ define i32 @smax_of_smin_comm(i32 %x, i32 %y, i32 %z) {
 define i32 @umax_of_uadd_sat(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @umax_of_uadd_sat(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X]], i32 [[Z]])
-; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[Y]], i32 [[Z]])
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.umax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MAX]]
 ;
   %add1 = call i32 @llvm.uadd.sat.i32(i32 %x, i32 %z)
@@ -130,9 +121,8 @@ define i32 @umax_of_uadd_sat(i32 %x, i32 %y, i32 %z) {
 define i32 @umax_of_uadd_sat_comm(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @umax_of_uadd_sat_comm(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[Z]], i32 [[X]])
-; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[Z]], i32 [[Y]])
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.umax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MAX]]
 ;
   %add1 = call i32 @llvm.uadd.sat.i32(i32 %z, i32 %x)
@@ -144,9 +134,8 @@ define i32 @umax_of_uadd_sat_comm(i32 %x, i32 %y, i32 %z) {
 define i32 @umin_of_uadd_sat(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @umin_of_uadd_sat(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X]], i32 [[Z]])
-; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[Y]], i32 [[Z]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MIN]]
 ;
   %add1 = call i32 @llvm.uadd.sat.i32(i32 %x, i32 %z)
@@ -158,9 +147,8 @@ define i32 @umin_of_uadd_sat(i32 %x, i32 %y, i32 %z) {
 define i32 @umin_of_uadd_sat_comm(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @umin_of_uadd_sat_comm(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[Z]], i32 [[X]])
-; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[Z]], i32 [[Y]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MIN]]
 ;
   %add1 = call i32 @llvm.uadd.sat.i32(i32 %z, i32 %x)
@@ -172,9 +160,8 @@ define i32 @umin_of_uadd_sat_comm(i32 %x, i32 %y, i32 %z) {
 define i32 @smax_of_sadd_sat(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @smax_of_sadd_sat(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[X]], i32 [[Z]])
-; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[Y]], i32 [[Z]])
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.smax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MAX]]
 ;
   %add1 = call i32 @llvm.sadd.sat.i32(i32 %x, i32 %z)
@@ -186,9 +173,8 @@ define i32 @smax_of_sadd_sat(i32 %x, i32 %y, i32 %z) {
 define i32 @smax_of_sadd_sat_comm(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @smax_of_sadd_sat_comm(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[Z]], i32 [[X]])
-; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[Z]], i32 [[Y]])
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.smax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MAX]]
 ;
   %add1 = call i32 @llvm.sadd.sat.i32(i32 %z, i32 %x)
@@ -200,9 +186,8 @@ define i32 @smax_of_sadd_sat_comm(i32 %x, i32 %y, i32 %z) {
 define i32 @smin_of_sadd_sat(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @smin_of_sadd_sat(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[X]], i32 [[Z]])
-; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[Y]], i32 [[Z]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MIN]]
 ;
   %add1 = call i32 @llvm.sadd.sat.i32(i32 %x, i32 %z)
@@ -214,9 +199,8 @@ define i32 @smin_of_sadd_sat(i32 %x, i32 %y, i32 %z) {
 define i32 @smin_of_sadd_sat_comm(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i32 @smin_of_sadd_sat_comm(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[ADD1:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[Z]], i32 [[X]])
-; CHECK-NEXT:    [[ADD2:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[Z]], i32 [[Y]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[ADD1]], i32 [[ADD2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
 ; CHECK-NEXT:    ret i32 [[MIN]]
 ;
   %add1 = call i32 @llvm.sadd.sat.i32(i32 %z, i32 %x)



More information about the llvm-commits mailing list