[llvm] 4276d00 - [InstCombine] add helper function for sub-of-min/max folds; NFC
Sanjay Patel via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 4 15:32:20 PDT 2022
Author: Sanjay Patel
Date: 2022-07-04T17:43:18-04:00
New Revision: 4276d00b125351ebb2420e598332800976809a9e
URL: https://github.com/llvm/llvm-project/commit/4276d00b125351ebb2420e598332800976809a9e
DIFF: https://github.com/llvm/llvm-project/commit/4276d00b125351ebb2420e598332800976809a9e.diff
LOG: [InstCombine] add helper function for sub-of-min/max folds; NFC
The test diffs are cosmetic -- but improvements -- because the
s/umin <-> s/umax fold now returns a new instruction and lets
instcombine handle the replacement. Instead of dropping the
old value name, it now propagates to the new instruction.
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
llvm/test/Transforms/InstCombine/sub-minmax.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index f4d8b79a5311d..b429280aae6f3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1750,6 +1750,43 @@ Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS,
return Builder.CreateIntCast(Result, Ty, true);
}
+static Instruction *foldSubOfMinMax(BinaryOperator &I,
+ InstCombiner::BuilderTy &Builder) {
+ Value *Op0 = I.getOperand(0);
+ Value *Op1 = I.getOperand(1);
+ Type *Ty = I.getType();
+ auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op1);
+ if (!MinMax)
+ return nullptr;
+
+ // sub(add(X,Y), s/umin(X,Y)) --> s/umax(X,Y)
+ // sub(add(X,Y), s/umax(X,Y)) --> s/umin(X,Y)
+ Value *X = MinMax->getLHS();
+ Value *Y = MinMax->getRHS();
+ if (match(Op0, m_c_Add(m_Specific(X), m_Specific(Y))) &&
+ (Op0->hasOneUse() || Op1->hasOneUse())) {
+ Intrinsic::ID InvID = getInverseMinMaxIntrinsic(MinMax->getIntrinsicID());
+ Function *F = Intrinsic::getDeclaration(I.getModule(), InvID, Ty);
+ return CallInst::Create(F, {X, Y});
+ }
+
+ // sub(add(X,Y),umin(Y,Z)) --> add(X,usub.sat(Y,Z))
+ // sub(add(X,Z),umin(Y,Z)) --> add(X,usub.sat(Z,Y))
+ Value *Z;
+ if (match(Op1, m_OneUse(m_UMin(m_Value(Y), m_Value(Z))))) {
+ if (match(Op0, m_OneUse(m_c_Add(m_Specific(Y), m_Value(X))))) {
+ Value *USub = Builder.CreateIntrinsic(Intrinsic::usub_sat, Ty, {Y, Z});
+ return BinaryOperator::CreateAdd(X, USub);
+ }
+ if (match(Op0, m_OneUse(m_c_Add(m_Specific(Z), m_Value(X))))) {
+ Value *USub = Builder.CreateIntrinsic(Intrinsic::usub_sat, Ty, {Z, Y});
+ return BinaryOperator::CreateAdd(X, USub);
+ }
+ }
+
+ return nullptr;
+}
+
Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
if (Value *V = simplifySubInst(I.getOperand(0), I.getOperand(1),
I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
@@ -2016,36 +2053,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
}
}
- if (auto *II = dyn_cast<MinMaxIntrinsic>(Op1)) {
- {
- // sub(add(X,Y), s/umin(X,Y)) --> s/umax(X,Y)
- // sub(add(X,Y), s/umax(X,Y)) --> s/umin(X,Y)
- Value *X = II->getLHS();
- Value *Y = II->getRHS();
- if (match(Op0, m_c_Add(m_Specific(X), m_Specific(Y))) &&
- (Op0->hasOneUse() || Op1->hasOneUse())) {
- Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
- Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, Y);
- return replaceInstUsesWith(I, InvMaxMin);
- }
- }
-
- {
- // sub(add(X,Y),umin(Y,Z)) --> add(X,usub.sat(Y,Z))
- // sub(add(X,Z),umin(Y,Z)) --> add(X,usub.sat(Z,Y))
- Value *X, *Y, *Z;
- if (match(Op1, m_OneUse(m_UMin(m_Value(Y), m_Value(Z))))) {
- if (match(Op0, m_OneUse(m_c_Add(m_Specific(Y), m_Value(X)))))
- return BinaryOperator::CreateAdd(
- X, Builder.CreateIntrinsic(Intrinsic::usub_sat, I.getType(),
- {Y, Z}));
- if (match(Op0, m_OneUse(m_c_Add(m_Specific(Z), m_Value(X)))))
- return BinaryOperator::CreateAdd(
- X, Builder.CreateIntrinsic(Intrinsic::usub_sat, I.getType(),
- {Z, Y}));
- }
- }
- }
+ if (Instruction *R = foldSubOfMinMax(I, Builder))
+ return R;
{
// If we have a subtraction between some value and a select between
diff --git a/llvm/test/Transforms/InstCombine/sub-minmax.ll b/llvm/test/Transforms/InstCombine/sub-minmax.ll
index 286e9d69444d6..fa762987c6b99 100644
--- a/llvm/test/Transforms/InstCombine/sub-minmax.ll
+++ b/llvm/test/Transforms/InstCombine/sub-minmax.ll
@@ -669,8 +669,8 @@ define i8 @umin_sub_op0_use(i8 %x, i8 %y) {
define i8 @diff_add_smin(i8 %x, i8 %y) {
; CHECK-LABEL: define {{[^@]+}}@diff_add_smin
; CHECK-SAME: (i8 [[X:%.*]], i8 [[Y:%.*]]) {
-; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[Y]])
-; CHECK-NEXT: ret i8 [[TMP1]]
+; CHECK-NEXT: [[S:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[Y]])
+; CHECK-NEXT: ret i8 [[S]]
;
%a = add i8 %x, %y
%m = call i8 @llvm.smin.i8(i8 %x, i8 %y)
@@ -681,8 +681,8 @@ define i8 @diff_add_smin(i8 %x, i8 %y) {
define i8 @diff_add_smax(i8 %x, i8 %y) {
; CHECK-LABEL: define {{[^@]+}}@diff_add_smax
; CHECK-SAME: (i8 [[X:%.*]], i8 [[Y:%.*]]) {
-; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[Y]], i8 [[X]])
-; CHECK-NEXT: ret i8 [[TMP1]]
+; CHECK-NEXT: [[S:%.*]] = call i8 @llvm.smin.i8(i8 [[Y]], i8 [[X]])
+; CHECK-NEXT: ret i8 [[S]]
;
%a = add i8 %x, %y
%m = call i8 @llvm.smax.i8(i8 %y, i8 %x)
@@ -693,8 +693,8 @@ define i8 @diff_add_smax(i8 %x, i8 %y) {
define i8 @diff_add_umin(i8 %x, i8 %y) {
; CHECK-LABEL: define {{[^@]+}}@diff_add_umin
; CHECK-SAME: (i8 [[X:%.*]], i8 [[Y:%.*]]) {
-; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[X]], i8 [[Y]])
-; CHECK-NEXT: ret i8 [[TMP1]]
+; CHECK-NEXT: [[S:%.*]] = call i8 @llvm.umax.i8(i8 [[X]], i8 [[Y]])
+; CHECK-NEXT: ret i8 [[S]]
;
%a = add i8 %x, %y
%m = call i8 @llvm.umin.i8(i8 %x, i8 %y)
@@ -705,8 +705,8 @@ define i8 @diff_add_umin(i8 %x, i8 %y) {
define i8 @diff_add_umax(i8 %x, i8 %y) {
; CHECK-LABEL: define {{[^@]+}}@diff_add_umax
; CHECK-SAME: (i8 [[X:%.*]], i8 [[Y:%.*]]) {
-; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umin.i8(i8 [[Y]], i8 [[X]])
-; CHECK-NEXT: ret i8 [[TMP1]]
+; CHECK-NEXT: [[S:%.*]] = call i8 @llvm.umin.i8(i8 [[Y]], i8 [[X]])
+; CHECK-NEXT: ret i8 [[S]]
;
%a = add i8 %x, %y
%m = call i8 @llvm.umax.i8(i8 %y, i8 %x)
@@ -718,9 +718,9 @@ define i8 @diff_add_smin_use(i8 %x, i8 %y) {
; CHECK-LABEL: define {{[^@]+}}@diff_add_smin_use
; CHECK-SAME: (i8 [[X:%.*]], i8 [[Y:%.*]]) {
; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.smin.i8(i8 [[X]], i8 [[Y]])
-; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[Y]])
+; CHECK-NEXT: [[S:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[Y]])
; CHECK-NEXT: call void @use8(i8 [[M]])
-; CHECK-NEXT: ret i8 [[TMP1]]
+; CHECK-NEXT: ret i8 [[S]]
;
%a = add i8 %x, %y
%m = call i8 @llvm.smin.i8(i8 %x, i8 %y)
@@ -733,9 +733,9 @@ define i8 @diff_add_use_smax(i8 %x, i8 %y) {
; CHECK-LABEL: define {{[^@]+}}@diff_add_use_smax
; CHECK-SAME: (i8 [[X:%.*]], i8 [[Y:%.*]]) {
; CHECK-NEXT: [[A:%.*]] = add i8 [[X]], [[Y]]
-; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[Y]], i8 [[X]])
+; CHECK-NEXT: [[S:%.*]] = call i8 @llvm.smin.i8(i8 [[Y]], i8 [[X]])
; CHECK-NEXT: call void @use8(i8 [[A]])
-; CHECK-NEXT: ret i8 [[TMP1]]
+; CHECK-NEXT: ret i8 [[S]]
;
%a = add i8 %x, %y
%m = call i8 @llvm.smax.i8(i8 %y, i8 %x)
More information about the llvm-commits mailing list