[llvm] 4276d00 - [InstCombine] add helper function for sub-of-min/max folds; NFC

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 4 15:32:20 PDT 2022


Author: Sanjay Patel
Date: 2022-07-04T17:43:18-04:00
New Revision: 4276d00b125351ebb2420e598332800976809a9e

URL: https://github.com/llvm/llvm-project/commit/4276d00b125351ebb2420e598332800976809a9e
DIFF: https://github.com/llvm/llvm-project/commit/4276d00b125351ebb2420e598332800976809a9e.diff

LOG: [InstCombine] add helper function for sub-of-min/max folds; NFC

The test diffs are cosmetic -- but improvements -- because we
let instcombine handle replacement. Instead of dropping the
old value name, it propagates to the new instruction.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
    llvm/test/Transforms/InstCombine/sub-minmax.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index f4d8b79a5311d..b429280aae6f3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1750,6 +1750,43 @@ Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS,
   return Builder.CreateIntCast(Result, Ty, true);
 }
 
+static Instruction *foldSubOfMinMax(BinaryOperator &I,
+                                    InstCombiner::BuilderTy &Builder) {
+  Value *Op0 = I.getOperand(0);
+  Value *Op1 = I.getOperand(1);
+  Type *Ty = I.getType();
+  auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op1);
+  if (!MinMax)
+    return nullptr;
+  // Op1 is a min/max intrinsic; try each fold below, else return nullptr.
+  // sub(add(X,Y), s/umin(X,Y)) --> s/umax(X,Y)
+  // sub(add(X,Y), s/umax(X,Y)) --> s/umin(X,Y)
+  Value *X = MinMax->getLHS();
+  Value *Y = MinMax->getRHS();
+  if (match(Op0, m_c_Add(m_Specific(X), m_Specific(Y))) &&
+      (Op0->hasOneUse() || Op1->hasOneUse())) { // so one operand becomes dead
+    Intrinsic::ID InvID = getInverseMinMaxIntrinsic(MinMax->getIntrinsicID());
+    Function *F = Intrinsic::getDeclaration(I.getModule(), InvID, Ty);
+    return CallInst::Create(F, {X, Y}); // InstCombine inserts and names it.
+  }
+  // Only umin is handled here; the add and the umin must each be one-use.
+  // sub(add(X,Y),umin(Y,Z)) --> add(X,usub.sat(Y,Z))
+  // sub(add(X,Z),umin(Y,Z)) --> add(X,usub.sat(Z,Y))
+  Value *Z; // X and Y are re-bound by the matchers below.
+  if (match(Op1, m_OneUse(m_UMin(m_Value(Y), m_Value(Z))))) {
+    if (match(Op0, m_OneUse(m_c_Add(m_Specific(Y), m_Value(X))))) {
+      Value *USub = Builder.CreateIntrinsic(Intrinsic::usub_sat, Ty, {Y, Z});
+      return BinaryOperator::CreateAdd(X, USub);
+    }
+    if (match(Op0, m_OneUse(m_c_Add(m_Specific(Z), m_Value(X))))) {
+      Value *USub = Builder.CreateIntrinsic(Intrinsic::usub_sat, Ty, {Z, Y});
+      return BinaryOperator::CreateAdd(X, USub);
+    }
+  }
+  // Neither fold matched.
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
   if (Value *V = simplifySubInst(I.getOperand(0), I.getOperand(1),
                                  I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
@@ -2016,36 +2053,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
     }
   }
 
-  if (auto *II = dyn_cast<MinMaxIntrinsic>(Op1)) {
-    {
-      // sub(add(X,Y), s/umin(X,Y)) --> s/umax(X,Y)
-      // sub(add(X,Y), s/umax(X,Y)) --> s/umin(X,Y)
-      Value *X = II->getLHS();
-      Value *Y = II->getRHS();
-      if (match(Op0, m_c_Add(m_Specific(X), m_Specific(Y))) &&
-          (Op0->hasOneUse() || Op1->hasOneUse())) {
-        Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
-        Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, Y);
-        return replaceInstUsesWith(I, InvMaxMin);
-      }
-    }
-
-    {
-      // sub(add(X,Y),umin(Y,Z)) --> add(X,usub.sat(Y,Z))
-      // sub(add(X,Z),umin(Y,Z)) --> add(X,usub.sat(Z,Y))
-      Value *X, *Y, *Z;
-      if (match(Op1, m_OneUse(m_UMin(m_Value(Y), m_Value(Z))))) {
-        if (match(Op0, m_OneUse(m_c_Add(m_Specific(Y), m_Value(X)))))
-          return BinaryOperator::CreateAdd(
-              X, Builder.CreateIntrinsic(Intrinsic::usub_sat, I.getType(),
-                                         {Y, Z}));
-        if (match(Op0, m_OneUse(m_c_Add(m_Specific(Z), m_Value(X)))))
-          return BinaryOperator::CreateAdd(
-              X, Builder.CreateIntrinsic(Intrinsic::usub_sat, I.getType(),
-                                         {Z, Y}));
-      }
-    }
-  }
+  if (Instruction *R = foldSubOfMinMax(I, Builder))
+    return R;
 
   {
     // If we have a subtraction between some value and a select between

diff  --git a/llvm/test/Transforms/InstCombine/sub-minmax.ll b/llvm/test/Transforms/InstCombine/sub-minmax.ll
index 286e9d69444d6..fa762987c6b99 100644
--- a/llvm/test/Transforms/InstCombine/sub-minmax.ll
+++ b/llvm/test/Transforms/InstCombine/sub-minmax.ll
@@ -669,8 +669,8 @@ define i8 @umin_sub_op0_use(i8 %x, i8 %y) {
 define i8 @diff_add_smin(i8 %x, i8 %y) {
 ; CHECK-LABEL: define {{[^@]+}}@diff_add_smin
 ; CHECK-SAME: (i8 [[X:%.*]], i8 [[Y:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[Y]])
-; CHECK-NEXT:    ret i8 [[TMP1]]
+; CHECK-NEXT:    [[S:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[Y]])
+; CHECK-NEXT:    ret i8 [[S]]
 ;
   %a = add i8 %x, %y
   %m = call i8 @llvm.smin.i8(i8 %x, i8 %y)
@@ -681,8 +681,8 @@ define i8 @diff_add_smin(i8 %x, i8 %y) {
 define i8 @diff_add_smax(i8 %x, i8 %y) {
 ; CHECK-LABEL: define {{[^@]+}}@diff_add_smax
 ; CHECK-SAME: (i8 [[X:%.*]], i8 [[Y:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[Y]], i8 [[X]])
-; CHECK-NEXT:    ret i8 [[TMP1]]
+; CHECK-NEXT:    [[S:%.*]] = call i8 @llvm.smin.i8(i8 [[Y]], i8 [[X]])
+; CHECK-NEXT:    ret i8 [[S]]
 ;
   %a = add i8 %x, %y
   %m = call i8 @llvm.smax.i8(i8 %y, i8 %x)
@@ -693,8 +693,8 @@ define i8 @diff_add_smax(i8 %x, i8 %y) {
 define i8 @diff_add_umin(i8 %x, i8 %y) {
 ; CHECK-LABEL: define {{[^@]+}}@diff_add_umin
 ; CHECK-SAME: (i8 [[X:%.*]], i8 [[Y:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[X]], i8 [[Y]])
-; CHECK-NEXT:    ret i8 [[TMP1]]
+; CHECK-NEXT:    [[S:%.*]] = call i8 @llvm.umax.i8(i8 [[X]], i8 [[Y]])
+; CHECK-NEXT:    ret i8 [[S]]
 ;
   %a = add i8 %x, %y
   %m = call i8 @llvm.umin.i8(i8 %x, i8 %y)
@@ -705,8 +705,8 @@ define i8 @diff_add_umin(i8 %x, i8 %y) {
 define i8 @diff_add_umax(i8 %x, i8 %y) {
 ; CHECK-LABEL: define {{[^@]+}}@diff_add_umax
 ; CHECK-SAME: (i8 [[X:%.*]], i8 [[Y:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.umin.i8(i8 [[Y]], i8 [[X]])
-; CHECK-NEXT:    ret i8 [[TMP1]]
+; CHECK-NEXT:    [[S:%.*]] = call i8 @llvm.umin.i8(i8 [[Y]], i8 [[X]])
+; CHECK-NEXT:    ret i8 [[S]]
 ;
   %a = add i8 %x, %y
   %m = call i8 @llvm.umax.i8(i8 %y, i8 %x)
@@ -718,9 +718,9 @@ define i8 @diff_add_smin_use(i8 %x, i8 %y) {
 ; CHECK-LABEL: define {{[^@]+}}@diff_add_smin_use
 ; CHECK-SAME: (i8 [[X:%.*]], i8 [[Y:%.*]]) {
 ; CHECK-NEXT:    [[M:%.*]] = call i8 @llvm.smin.i8(i8 [[X]], i8 [[Y]])
-; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[Y]])
+; CHECK-NEXT:    [[S:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[Y]])
 ; CHECK-NEXT:    call void @use8(i8 [[M]])
-; CHECK-NEXT:    ret i8 [[TMP1]]
+; CHECK-NEXT:    ret i8 [[S]]
 ;
   %a = add i8 %x, %y
   %m = call i8 @llvm.smin.i8(i8 %x, i8 %y)
@@ -733,9 +733,9 @@ define i8 @diff_add_use_smax(i8 %x, i8 %y) {
 ; CHECK-LABEL: define {{[^@]+}}@diff_add_use_smax
 ; CHECK-SAME: (i8 [[X:%.*]], i8 [[Y:%.*]]) {
 ; CHECK-NEXT:    [[A:%.*]] = add i8 [[X]], [[Y]]
-; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[Y]], i8 [[X]])
+; CHECK-NEXT:    [[S:%.*]] = call i8 @llvm.smin.i8(i8 [[Y]], i8 [[X]])
 ; CHECK-NEXT:    call void @use8(i8 [[A]])
-; CHECK-NEXT:    ret i8 [[TMP1]]
+; CHECK-NEXT:    ret i8 [[S]]
 ;
   %a = add i8 %x, %y
   %m = call i8 @llvm.smax.i8(i8 %y, i8 %x)


        


More information about the llvm-commits mailing list