[llvm] 74a5849 - [InstCombine] fold signed absolute diff patterns
Sanjay Patel via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 6 10:53:32 PST 2023
Author: Sanjay Patel
Date: 2023-03-06T13:49:48-05:00
New Revision: 74a58499b7c0995b782a2f851b8a6c44b2bc5361
URL: https://github.com/llvm/llvm-project/commit/74a58499b7c0995b782a2f851b8a6c44b2bc5361
DIFF: https://github.com/llvm/llvm-project/commit/74a58499b7c0995b782a2f851b8a6c44b2bc5361.diff
LOG: [InstCombine] fold signed absolute diff patterns
This overlaps partially with the codegen patch D144789. The transform needs
no-wrap on the subtracts for correctness, and I'm not sure if there's an
unsigned equivalent:
https://alive2.llvm.org/ce/z/ErmQ-9
https://alive2.llvm.org/ce/z/mr-c_A
This is obviously an improvement in IR, and it looks like a codegen win
for all targets and data types that I sampled.
The 'nabs' case is left as a potential follow-up (and seems less likely
to occur in real code).
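For reference, here is a minimal before/after sketch of the fold, reconstructed
from the abs_diff_signed_slt_swap test updated below (instruction names are
illustrative):

  define i32 @abs_diff_signed_slt_swap(i32 %a, i32 %b) {
    %cmp = icmp slt i32 %a, %b
    %sub_ba = sub nsw i32 %b, %a    ; both subtracts must carry a no-wrap flag
    %sub_ab = sub nsw i32 %a, %b
    %cond = select i1 %cmp, i32 %sub_ba, i32 %sub_ab
    ret i32 %cond
  }

  ; after this patch, InstCombine reduces the body to:
  ;   %sub_ab = sub nsw i32 %a, %b
  ;   %cond = call i32 @llvm.abs.i32(i32 %sub_ab, i1 true)
  ;   ret i32 %cond

Without the no-wrap requirement this would be wrong: for i8 with A = 100 and
B = -100, the wrapping A - B is -56, so the select form returns -56 while the
abs form would return 56.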
Differential Revision: https://reviews.llvm.org/D145073
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
llvm/test/Transforms/InstCombine/abs-1.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index ac1f2d5b37745..d1bac8a089388 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -950,6 +950,47 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
return nullptr;
}
+/// Try to match patterns with select and subtract as absolute difference.
+static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
+ InstCombiner::BuilderTy &Builder) {
+ auto *TI = dyn_cast<Instruction>(TVal);
+ auto *FI = dyn_cast<Instruction>(FVal);
+ if (!TI || !FI)
+ return nullptr;
+
+ // Normalize predicate to gt/lt rather than ge/le.
+ ICmpInst::Predicate Pred = Cmp->getStrictPredicate();
+ Value *A = Cmp->getOperand(0);
+ Value *B = Cmp->getOperand(1);
+
+ // Normalize "A - B" as the true value of the select.
+ if (match(FI, m_Sub(m_Specific(A), m_Specific(B)))) {
+ std::swap(FI, TI);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+
+ // With any pair of no-wrap subtracts:
+ // (A > B) ? (A - B) : (B - A) --> abs(A - B)
+ if (Pred == CmpInst::ICMP_SGT &&
+ match(TI, m_Sub(m_Specific(A), m_Specific(B))) &&
+ match(FI, m_Sub(m_Specific(B), m_Specific(A))) &&
+ (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap()) &&
+ (FI->hasNoSignedWrap() || FI->hasNoUnsignedWrap())) {
+ // The remaining subtract is not "nuw" any more.
+ // If there's one use of the subtract (no other use than the use we are
+ // about to replace), then we know that the sub is "nsw" in this context
+ // even if it was only "nuw" before. If there's another use, then we can't
+ // add "nsw" to the existing instruction because it may not be safe in the
+ // other user's context.
+ TI->setHasNoUnsignedWrap(false);
+ if (!TI->hasNoSignedWrap())
+ TI->setHasNoSignedWrap(TI->hasOneUse());
+ return Builder.CreateBinaryIntrinsic(Intrinsic::abs, TI, Builder.getTrue());
+ }
+
+ return nullptr;
+}
+
/// Fold the following code sequence:
/// \code
/// int a = ctlz(x & -x);
@@ -1790,6 +1831,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Value *V = canonicalizeSaturatedAdd(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);
+ if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder))
+ return replaceInstUsesWith(SI, V);
+
return Changed ? &SI : nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/abs-1.ll b/llvm/test/Transforms/InstCombine/abs-1.ll
index b0a1044902bc4..7355c560c820b 100644
--- a/llvm/test/Transforms/InstCombine/abs-1.ll
+++ b/llvm/test/Transforms/InstCombine/abs-1.ll
@@ -679,6 +679,8 @@ define i8 @nabs_extra_use_icmp_sub(i8 %x) {
ret i8 %s
}
+; TODO: negate-of-abs-diff
+
define i32 @nabs_diff_signed_slt(i32 %a, i32 %b) {
; CHECK-LABEL: @nabs_diff_signed_slt(
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
@@ -694,6 +696,8 @@ define i32 @nabs_diff_signed_slt(i32 %a, i32 %b) {
ret i32 %cond
}
+; TODO: negate-of-abs-diff
+
define <2 x i8> @nabs_diff_signed_sle(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: @nabs_diff_signed_sle(
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp sgt <2 x i8> [[A:%.*]], [[B:%.*]]
@@ -711,11 +715,9 @@ define <2 x i8> @nabs_diff_signed_sle(<2 x i8> %a, <2 x i8> %b) {
define i8 @abs_diff_signed_sgt(i8 %a, i8 %b) {
; CHECK-LABEL: @abs_diff_signed_sgt(
-; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw i8 [[B]], [[A]]
-; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A]], [[B]]
+; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: call void @extra_use(i8 [[SUB_AB]])
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]]
+; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i8 [[COND]]
;
%cmp = icmp sgt i8 %a, %b
@@ -728,12 +730,11 @@ define i8 @abs_diff_signed_sgt(i8 %a, i8 %b) {
define i8 @abs_diff_signed_sge(i8 %a, i8 %b) {
; CHECK-LABEL: @abs_diff_signed_sge(
-; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp slt i8 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw i8 [[B]], [[A]]
+; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw i8 [[B:%.*]], [[A:%.*]]
; CHECK-NEXT: call void @extra_use(i8 [[SUB_BA]])
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A]], [[B]]
; CHECK-NEXT: call void @extra_use(i8 [[SUB_AB]])
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP_NOT]], i8 [[SUB_BA]], i8 [[SUB_AB]]
+; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i8 [[COND]]
;
%cmp = icmp sge i8 %a, %b
@@ -745,6 +746,8 @@ define i8 @abs_diff_signed_sge(i8 %a, i8 %b) {
ret i8 %cond
}
+; negative test - need nsw
+
define i32 @abs_diff_signed_slt_no_nsw(i32 %a, i32 %b) {
; CHECK-LABEL: @abs_diff_signed_slt_no_nsw(
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
@@ -760,12 +763,12 @@ define i32 @abs_diff_signed_slt_no_nsw(i32 %a, i32 %b) {
ret i32 %cond
}
+; bonus nuw - it's fine to match the pattern, but nuw can't propagate
+
define i8 @abs_diff_signed_sgt_nsw_nuw(i8 %a, i8 %b) {
; CHECK-LABEL: @abs_diff_signed_sgt_nsw_nuw(
-; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw nsw i8 [[B]], [[A]]
-; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw nsw i8 [[A]], [[B]]
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]]
+; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i8 [[COND]]
;
%cmp = icmp sgt i8 %a, %b
@@ -775,12 +778,12 @@ define i8 @abs_diff_signed_sgt_nsw_nuw(i8 %a, i8 %b) {
ret i8 %cond
}
+; this is absolute diff, but nuw can't propagate and nsw can be set.
+
define i8 @abs_diff_signed_sgt_nuw(i8 %a, i8 %b) {
; CHECK-LABEL: @abs_diff_signed_sgt_nuw(
-; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B]], [[A]]
-; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw i8 [[A]], [[B]]
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]]
+; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i8 [[COND]]
;
%cmp = icmp sgt i8 %a, %b
@@ -790,13 +793,14 @@ define i8 @abs_diff_signed_sgt_nuw(i8 %a, i8 %b) {
ret i8 %cond
}
+; same as above
+
define i8 @abs_diff_signed_sgt_nuw_extra_use1(i8 %a, i8 %b) {
; CHECK-LABEL: @abs_diff_signed_sgt_nuw_extra_use1(
-; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B]], [[A]]
+; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B:%.*]], [[A:%.*]]
; CHECK-NEXT: call void @extra_use(i8 [[SUB_BA]])
-; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw i8 [[A]], [[B]]
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]]
+; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A]], [[B]]
+; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i8 [[COND]]
;
%cmp = icmp sgt i8 %a, %b
@@ -807,13 +811,13 @@ define i8 @abs_diff_signed_sgt_nuw_extra_use1(i8 %a, i8 %b) {
ret i8 %cond
}
+; nuw can't propagate, and the extra use prevents applying nsw
+
define i8 @abs_diff_signed_sgt_nuw_extra_use2(i8 %a, i8 %b) {
; CHECK-LABEL: @abs_diff_signed_sgt_nuw_extra_use2(
-; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B]], [[A]]
-; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw i8 [[A]], [[B]]
+; CHECK-NEXT: [[SUB_AB:%.*]] = sub i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: call void @extra_use(i8 [[SUB_AB]])
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]]
+; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i8 [[COND]]
;
%cmp = icmp sgt i8 %a, %b
@@ -824,14 +828,15 @@ define i8 @abs_diff_signed_sgt_nuw_extra_use2(i8 %a, i8 %b) {
ret i8 %cond
}
+; same as above
+
define i8 @abs_diff_signed_sgt_nuw_extra_use3(i8 %a, i8 %b) {
; CHECK-LABEL: @abs_diff_signed_sgt_nuw_extra_use3(
-; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B]], [[A]]
+; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B:%.*]], [[A:%.*]]
; CHECK-NEXT: call void @extra_use(i8 [[SUB_BA]])
-; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw i8 [[A]], [[B]]
+; CHECK-NEXT: [[SUB_AB:%.*]] = sub i8 [[A]], [[B]]
; CHECK-NEXT: call void @extra_use(i8 [[SUB_AB]])
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]]
+; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i8 [[COND]]
;
%cmp = icmp sgt i8 %a, %b
@@ -843,6 +848,8 @@ define i8 @abs_diff_signed_sgt_nuw_extra_use3(i8 %a, i8 %b) {
ret i8 %cond
}
+; negative test - wrong predicate
+
define i32 @abs_diff_signed_slt_swap_wrong_pred1(i32 %a, i32 %b) {
; CHECK-LABEL: @abs_diff_signed_slt_swap_wrong_pred1(
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[A:%.*]], [[B:%.*]]
@@ -858,6 +865,8 @@ define i32 @abs_diff_signed_slt_swap_wrong_pred1(i32 %a, i32 %b) {
ret i32 %cond
}
+; negative test - wrong predicate
+
define i32 @abs_diff_signed_slt_swap_wrong_pred2(i32 %a, i32 %b) {
; CHECK-LABEL: @abs_diff_signed_slt_swap_wrong_pred2(
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[A:%.*]], [[B:%.*]]
@@ -873,6 +882,8 @@ define i32 @abs_diff_signed_slt_swap_wrong_pred2(i32 %a, i32 %b) {
ret i32 %cond
}
+; negative test - need common operands
+
define i32 @abs_diff_signed_slt_swap_wrong_op(i32 %a, i32 %b, i32 %z) {
; CHECK-LABEL: @abs_diff_signed_slt_swap_wrong_op(
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[A:%.*]], [[B:%.*]]
@@ -890,10 +901,8 @@ define i32 @abs_diff_signed_slt_swap_wrong_op(i32 %a, i32 %b, i32 %z) {
define i32 @abs_diff_signed_slt_swap(i32 %a, i32 %b) {
; CHECK-LABEL: @abs_diff_signed_slt_swap(
-; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw i32 [[B]], [[A]]
-; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i32 [[A]], [[B]]
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[SUB_BA]], i32 [[SUB_AB]]
+; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.abs.i32(i32 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp slt i32 %a, %b
@@ -905,10 +914,8 @@ define i32 @abs_diff_signed_slt_swap(i32 %a, i32 %b) {
define <2 x i8> @abs_diff_signed_sle_swap(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: @abs_diff_signed_sle_swap(
-; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp sgt <2 x i8> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw <2 x i8> [[B]], [[A]]
-; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw <2 x i8> [[A]], [[B]]
-; CHECK-NEXT: [[COND:%.*]] = select <2 x i1> [[CMP_NOT]], <2 x i8> [[SUB_AB]], <2 x i8> [[SUB_BA]]
+; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw <2 x i8> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[COND:%.*]] = call <2 x i8> @llvm.abs.v2i8(<2 x i8> [[SUB_AB]], i1 true)
; CHECK-NEXT: ret <2 x i8> [[COND]]
;
%cmp = icmp sle <2 x i8> %a, %b
@@ -918,6 +925,8 @@ define <2 x i8> @abs_diff_signed_sle_swap(<2 x i8> %a, <2 x i8> %b) {
ret <2 x i8> %cond
}
+; TODO: negate-of-abs-diff
+
define i8 @nabs_diff_signed_sgt_swap(i8 %a, i8 %b) {
; CHECK-LABEL: @nabs_diff_signed_sgt_swap(
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
@@ -935,6 +944,8 @@ define i8 @nabs_diff_signed_sgt_swap(i8 %a, i8 %b) {
ret i8 %cond
}
+; TODO: negate-of-abs-diff, but too many uses?
+
define i8 @nabs_diff_signed_sge_swap(i8 %a, i8 %b) {
; CHECK-LABEL: @nabs_diff_signed_sge_swap(
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp slt i8 [[A:%.*]], [[B:%.*]]
@@ -954,6 +965,8 @@ define i8 @nabs_diff_signed_sge_swap(i8 %a, i8 %b) {
ret i8 %cond
}
+; negative test - need nsw
+
define i32 @abs_diff_signed_slt_no_nsw_swap(i32 %a, i32 %b) {
; CHECK-LABEL: @abs_diff_signed_slt_no_nsw_swap(
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]