[llvm] 3ecf731 - [InstCombine] Reduce absolute diff from min+max+sub
Jun Zhang via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 9 16:01:45 PST 2023
Author: Jun Zhang
Date: 2023-03-10T08:00:40+08:00
New Revision: 3ecf731376cb4661d85c10932bd7bd7c11998ad8
URL: https://github.com/llvm/llvm-project/commit/3ecf731376cb4661d85c10932bd7bd7c11998ad8
DIFF: https://github.com/llvm/llvm-project/commit/3ecf731376cb4661d85c10932bd7bd7c11998ad8.diff
LOG: [InstCombine] Reduce absolute diff from min+max+sub
This patch implements the fold: smax(a,b) nsw/nuw - smin(a,b) --> abs(a nsw - b)
Alive2: https://alive2.llvm.org/ce/z/4yLp7D
Fixes: https://github.com/llvm/llvm-project/issues/61228
Signed-off-by: Jun Zhang <jun at junz.org>
Differential Revision: https://reviews.llvm.org/D145540
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
llvm/test/Transforms/InstCombine/sub-minmax.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 0d8fa5c5a2c5..2c2b767b2ced 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2406,6 +2406,21 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
return replaceInstUsesWith(I, Mul);
}
+ // smax(X,Y) nsw/nuw - smin(X,Y) --> abs(X nsw - Y)
+ // nsw bounds smax - smin by INT_MAX directly; nuw can only hold when X and Y
+ // have the same sign. Either way X - Y cannot overflow and is never INT_MIN,
+ // so the new sub is nsw and abs is emitted with is_int_min_poison = true.
+ if (match(Op0, m_OneUse(m_c_SMax(m_Value(X), m_Value(Y)))) &&
+ match(Op1, m_OneUse(m_c_SMin(m_Specific(X), m_Specific(Y))))) {
+ if (I.hasNoUnsignedWrap() || I.hasNoSignedWrap()) {
+ Value *Sub =
+ Builder.CreateSub(X, Y, "sub", /*HasNUW=*/false, /*HasNSW=*/true);
+ Value *Call =
+ Builder.CreateBinaryIntrinsic(Intrinsic::abs, Sub, Builder.getTrue());
+ return replaceInstUsesWith(I, Call);
+ }
+ }
+
return TryToNarrowDeduceFlags();
}
diff --git a/llvm/test/Transforms/InstCombine/sub-minmax.ll b/llvm/test/Transforms/InstCombine/sub-minmax.ll
index 0e0eff0fd58f..c9ce165c3898 100644
--- a/llvm/test/Transforms/InstCombine/sub-minmax.ll
+++ b/llvm/test/Transforms/InstCombine/sub-minmax.ll
@@ -1001,9 +1001,8 @@ define i8 @sub_smin0_sub_nsw_commute(i8 %x, i8 %y) {
define i8 @sub_max_min_nsw(i8 %a, i8 %b) {
; CHECK-LABEL: define {{[^@]+}}@sub_max_min_nsw
; CHECK-SAME: (i8 [[A:%.*]], i8 [[B:%.*]]) {
-; CHECK-NEXT: [[MIN:%.*]] = call i8 @llvm.smin.i8(i8 [[A]], i8 [[B]])
-; CHECK-NEXT: [[MAX:%.*]] = call i8 @llvm.smax.i8(i8 [[A]], i8 [[B]])
-; CHECK-NEXT: [[AB:%.*]] = sub nsw i8 [[MAX]], [[MIN]]
+; CHECK-NEXT: [[SUB:%.*]] = sub nsw i8 [[A]], [[B]]
+; CHECK-NEXT: [[AB:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB]], i1 true)
; CHECK-NEXT: ret i8 [[AB]]
;
%min = call i8 @llvm.smin.i8(i8 %a, i8 %b)
@@ -1015,9 +1014,8 @@ define i8 @sub_max_min_nsw(i8 %a, i8 %b) {
define i8 @sub_max_min_nuw(i8 %a, i8 %b) {
; CHECK-LABEL: define {{[^@]+}}@sub_max_min_nuw
; CHECK-SAME: (i8 [[A:%.*]], i8 [[B:%.*]]) {
-; CHECK-NEXT: [[MIN:%.*]] = call i8 @llvm.smin.i8(i8 [[A]], i8 [[B]])
-; CHECK-NEXT: [[MAX:%.*]] = call i8 @llvm.smax.i8(i8 [[A]], i8 [[B]])
-; CHECK-NEXT: [[AB:%.*]] = sub nuw i8 [[MAX]], [[MIN]]
+; CHECK-NEXT: [[SUB:%.*]] = sub nsw i8 [[A]], [[B]]
+; CHECK-NEXT: [[AB:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB]], i1 true)
; CHECK-NEXT: ret i8 [[AB]]
;
%min = call i8 @llvm.smin.i8(i8 %a, i8 %b)
@@ -1026,12 +1024,37 @@ define i8 @sub_max_min_nuw(i8 %a, i8 %b) {
ret i8 %ab
}
+define i8 @sub_max_min_nsw_commute(i8 %a, i8 %b) {
+; CHECK-LABEL: define {{[^@]+}}@sub_max_min_nsw_commute
+; CHECK-SAME: (i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT: [[SUB:%.*]] = sub nsw i8 [[A]], [[B]]
+; CHECK-NEXT: [[AB:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB]], i1 true)
+; CHECK-NEXT: ret i8 [[AB]]
+;
+ %min = call i8 @llvm.smin.i8(i8 %b, i8 %a)
+ %max = call i8 @llvm.smax.i8(i8 %a, i8 %b)
+ %ab = sub nsw i8 %max, %min
+ ret i8 %ab
+}
+
+define i8 @sub_max_min_nuw_commute(i8 %a, i8 %b) {
+; CHECK-LABEL: define {{[^@]+}}@sub_max_min_nuw_commute
+; CHECK-SAME: (i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT: [[SUB:%.*]] = sub nsw i8 [[A]], [[B]]
+; CHECK-NEXT: [[AB:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB]], i1 true)
+; CHECK-NEXT: ret i8 [[AB]]
+;
+ %min = call i8 @llvm.smin.i8(i8 %b, i8 %a)
+ %max = call i8 @llvm.smax.i8(i8 %a, i8 %b)
+ %ab = sub nuw i8 %max, %min
+ ret i8 %ab
+}
+
define <2 x i8> @sub_max_min_vec_nsw(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: define {{[^@]+}}@sub_max_min_vec_nsw
; CHECK-SAME: (<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]]) {
-; CHECK-NEXT: [[MIN:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[A]], <2 x i8> [[B]])
-; CHECK-NEXT: [[MAX:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[A]], <2 x i8> [[B]])
-; CHECK-NEXT: [[AB:%.*]] = sub nsw <2 x i8> [[MAX]], [[MIN]]
+; CHECK-NEXT: [[SUB:%.*]] = sub nsw <2 x i8> [[A]], [[B]]
+; CHECK-NEXT: [[AB:%.*]] = call <2 x i8> @llvm.abs.v2i8(<2 x i8> [[SUB]], i1 true)
; CHECK-NEXT: ret <2 x i8> [[AB]]
;
%min = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %a, <2 x i8> %b)
@@ -1043,16 +1066,81 @@ define <2 x i8> @sub_max_min_vec_nsw(<2 x i8> %a, <2 x i8> %b) {
define <2 x i8> @sub_max_min_vec_nuw(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: define {{[^@]+}}@sub_max_min_vec_nuw
; CHECK-SAME: (<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]]) {
+; CHECK-NEXT: [[SUB:%.*]] = sub nsw <2 x i8> [[A]], [[B]]
+; CHECK-NEXT: [[AB:%.*]] = call <2 x i8> @llvm.abs.v2i8(<2 x i8> [[SUB]], i1 true)
+; CHECK-NEXT: ret <2 x i8> [[AB]]
+;
+ %min = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %a, <2 x i8> %b)
+ %max = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %a, <2 x i8> %b)
+ %ab = sub nuw <2 x i8> %max, %min
+ ret <2 x i8> %ab
+}
+
+define <2 x i8> @sub_max_min_vec_nsw_commute(<2 x i8> %a, <2 x i8> %b) {
+; CHECK-LABEL: define {{[^@]+}}@sub_max_min_vec_nsw_commute
+; CHECK-SAME: (<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]]) {
+; CHECK-NEXT: [[SUB:%.*]] = sub nsw <2 x i8> [[A]], [[B]]
+; CHECK-NEXT: [[AB:%.*]] = call <2 x i8> @llvm.abs.v2i8(<2 x i8> [[SUB]], i1 true)
+; CHECK-NEXT: ret <2 x i8> [[AB]]
+;
+ %min = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %b, <2 x i8> %a)
+ %max = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %a, <2 x i8> %b)
+ %ab = sub nsw <2 x i8> %max, %min
+ ret <2 x i8> %ab
+}
+
+define <2 x i8> @sub_max_min_vec_nuw_commute(<2 x i8> %a, <2 x i8> %b) {
+; CHECK-LABEL: define {{[^@]+}}@sub_max_min_vec_nuw_commute
+; CHECK-SAME: (<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]]) {
+; CHECK-NEXT: [[SUB:%.*]] = sub nsw <2 x i8> [[A]], [[B]]
+; CHECK-NEXT: [[AB:%.*]] = call <2 x i8> @llvm.abs.v2i8(<2 x i8> [[SUB]], i1 true)
+; CHECK-NEXT: ret <2 x i8> [[AB]]
+;
+ %min = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %b, <2 x i8> %a)
+ %max = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %a, <2 x i8> %b)
+ %ab = sub nuw <2 x i8> %max, %min
+ ret <2 x i8> %ab
+}
+
+; negative tests - smin/smax have extra uses (the fold requires one use of each)
+
+define i8 @sub_max_min_multi_use(i8 %a, i8 %b) {
+; CHECK-LABEL: define {{[^@]+}}@sub_max_min_multi_use
+; CHECK-SAME: (i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT: [[MIN:%.*]] = call i8 @llvm.smin.i8(i8 [[A]], i8 [[B]])
+; CHECK-NEXT: call void @use8(i8 [[MIN]])
+; CHECK-NEXT: [[MAX:%.*]] = call i8 @llvm.smax.i8(i8 [[A]], i8 [[B]])
+; CHECK-NEXT: call void @use8(i8 [[MAX]])
+; CHECK-NEXT: [[AB:%.*]] = sub nsw i8 [[MAX]], [[MIN]]
+; CHECK-NEXT: ret i8 [[AB]]
+;
+ %min = call i8 @llvm.smin.i8(i8 %a, i8 %b)
+ call void @use8(i8 %min)
+ %max = call i8 @llvm.smax.i8(i8 %a, i8 %b)
+ call void @use8(i8 %max)
+ %ab = sub nsw i8 %max, %min
+ ret i8 %ab
+}
+
+define <2 x i8> @sub_max_min_vec_multi_use(<2 x i8> %a, <2 x i8> %b) {
+; CHECK-LABEL: define {{[^@]+}}@sub_max_min_vec_multi_use
+; CHECK-SAME: (<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]]) {
; CHECK-NEXT: [[MIN:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[A]], <2 x i8> [[B]])
+; CHECK-NEXT: call void @use8v2(<2 x i8> [[MIN]])
; CHECK-NEXT: [[MAX:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[A]], <2 x i8> [[B]])
-; CHECK-NEXT: [[AB:%.*]] = sub nuw <2 x i8> [[MAX]], [[MIN]]
+; CHECK-NEXT: call void @use8v2(<2 x i8> [[MAX]])
+; CHECK-NEXT: [[AB:%.*]] = sub nsw <2 x i8> [[MAX]], [[MIN]]
; CHECK-NEXT: ret <2 x i8> [[AB]]
;
%min = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %a, <2 x i8> %b)
+ call void @use8v2(<2 x i8> %min)
%max = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %a, <2 x i8> %b)
- %ab = sub nuw <2 x i8> %max, %min
+ call void @use8v2(<2 x i8> %max)
+ %ab = sub nsw <2 x i8> %max, %min
ret <2 x i8> %ab
}
declare void @use8(i8)
declare void @use32(i32 %u)
+
+declare void @use8v2(<2 x i8>)
More information about the llvm-commits mailing list