[llvm] [InstCombine] Fold `X udiv Y` to `X lshr cttz(Y)` if Y is a power of 2 (PR #121386)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 31 03:38:50 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Veera (veera-sivarajan)
<details>
<summary>Changes</summary>
Fixes #<!-- -->115767
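This PR folds `X udiv Y` to `X lshr cttz(Y)` when Y is known to be a power of two (or zero), since a shift is much cheaper than a division.
Allowing zero is sound: `udiv` by zero is undefined behavior, so using `cttz` with `is_zero_poison=true` is a valid refinement.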
Proof: https://alive2.llvm.org/ce/z/qHmLta
---
Full diff: https://github.com/llvm/llvm-project/pull/121386.diff
3 Files Affected:
- (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+10)
- (modified) llvm/test/Transforms/IndVarSimplify/rewrite-loop-exit-value.ll (+4-2)
- (modified) llvm/test/Transforms/InstCombine/div-shift.ll (+105-4)
``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f85a3c93651353..00779fe5fa2ee1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1632,6 +1632,16 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
}
+ // Op0 udiv Op1 -> Op0 lshr cttz(Op1), if Op1 is a power of 2.
+ if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, /*Depth*/ 0, &I)) {
+ // This will increase instruction count but it's okay
+ // since bitwise operations are substantially faster than
+ // division.
+ auto *Cttz =
+ Builder.CreateBinaryIntrinsic(Intrinsic::cttz, Op1, Builder.getTrue());
+ return BinaryOperator::CreateLShr(Op0, Cttz);
+ }
+
return nullptr;
}
diff --git a/llvm/test/Transforms/IndVarSimplify/rewrite-loop-exit-value.ll b/llvm/test/Transforms/IndVarSimplify/rewrite-loop-exit-value.ll
index 1956f454a52bbf..fa47d06d859e97 100644
--- a/llvm/test/Transforms/IndVarSimplify/rewrite-loop-exit-value.ll
+++ b/llvm/test/Transforms/IndVarSimplify/rewrite-loop-exit-value.ll
@@ -218,7 +218,8 @@ define i32 @vscale_slt_with_vp_umin(ptr nocapture %A, i32 %n) mustprogress vscal
; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END:%.*]]
; CHECK: for.end:
; CHECK-NEXT: [[TMP0:%.*]] = add nsw i32 [[N]], -1
-; CHECK-NEXT: [[TMP1:%.*]] = udiv i32 [[TMP0]], [[VF]]
+; CHECK-NEXT: [[TMP5:%.*]] = call range(i32 2, 33) i32 @llvm.cttz.i32(i32 [[VF]], i1 true)
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], [[TMP5]]
; CHECK-NEXT: [[TMP2:%.*]] = mul i32 [[TMP1]], [[VSCALE]]
; CHECK-NEXT: [[TMP3:%.*]] = shl i32 [[TMP2]], 2
; CHECK-NEXT: [[TMP4:%.*]] = sub i32 [[N]], [[TMP3]]
@@ -270,7 +271,8 @@ define i32 @vscale_slt_with_vp_umin2(ptr nocapture %A, i32 %n) mustprogress vsca
; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END:%.*]]
; CHECK: for.end:
; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[N]], -1
-; CHECK-NEXT: [[TMP1:%.*]] = udiv i32 [[TMP0]], [[VF]]
+; CHECK-NEXT: [[TMP5:%.*]] = call range(i32 2, 33) i32 @llvm.cttz.i32(i32 [[VF]], i1 true)
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], [[TMP5]]
; CHECK-NEXT: [[TMP2:%.*]] = mul i32 [[TMP1]], [[VSCALE]]
; CHECK-NEXT: [[TMP3:%.*]] = shl i32 [[TMP2]], 2
; CHECK-NEXT: [[TMP4:%.*]] = sub i32 [[N]], [[TMP3]]
diff --git a/llvm/test/Transforms/InstCombine/div-shift.ll b/llvm/test/Transforms/InstCombine/div-shift.ll
index 8dd6d4a2e83712..005daed087c169 100644
--- a/llvm/test/Transforms/InstCombine/div-shift.ll
+++ b/llvm/test/Transforms/InstCombine/div-shift.ll
@@ -148,7 +148,8 @@ define i8 @udiv_umin_extra_use(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[Z2:%.*]] = shl nuw i8 1, [[Z:%.*]]
; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.umin.i8(i8 [[Y2]], i8 [[Z2]])
; CHECK-NEXT: call void @use(i8 [[M]])
-; CHECK-NEXT: [[D:%.*]] = udiv i8 [[X:%.*]], [[M]]
+; CHECK-NEXT: [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[M]], i1 true)
+; CHECK-NEXT: [[D:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
; CHECK-NEXT: ret i8 [[D]]
;
%y2 = shl i8 1, %y
@@ -165,7 +166,8 @@ define i8 @udiv_smin(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[Y2:%.*]] = shl nuw i8 1, [[Y:%.*]]
; CHECK-NEXT: [[Z2:%.*]] = shl nuw i8 1, [[Z:%.*]]
; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.smin.i8(i8 [[Y2]], i8 [[Z2]])
-; CHECK-NEXT: [[D:%.*]] = udiv i8 [[X:%.*]], [[M]]
+; CHECK-NEXT: [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[M]], i1 true)
+; CHECK-NEXT: [[D:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
; CHECK-NEXT: ret i8 [[D]]
;
%y2 = shl i8 1, %y
@@ -181,7 +183,8 @@ define i8 @udiv_smax(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[Y2:%.*]] = shl nuw i8 1, [[Y:%.*]]
; CHECK-NEXT: [[Z2:%.*]] = shl nuw i8 1, [[Z:%.*]]
; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.smax.i8(i8 [[Y2]], i8 [[Z2]])
-; CHECK-NEXT: [[D:%.*]] = udiv i8 [[X:%.*]], [[M]]
+; CHECK-NEXT: [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[M]], i1 true)
+; CHECK-NEXT: [[D:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
; CHECK-NEXT: ret i8 [[D]]
;
%y2 = shl i8 1, %y
@@ -1006,7 +1009,8 @@ define i8 @udiv_fail_shl_overflow(i8 %x, i8 %y) {
; CHECK-LABEL: @udiv_fail_shl_overflow(
; CHECK-NEXT: [[SHL:%.*]] = shl i8 2, [[Y:%.*]]
; CHECK-NEXT: [[MIN:%.*]] = call i8 @llvm.umax.i8(i8 [[SHL]], i8 1)
-; CHECK-NEXT: [[MUL:%.*]] = udiv i8 [[X:%.*]], [[MIN]]
+; CHECK-NEXT: [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[MIN]], i1 true)
+; CHECK-NEXT: [[MUL:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
; CHECK-NEXT: ret i8 [[MUL]]
;
%shl = shl i8 2, %y
@@ -1294,3 +1298,100 @@ entry:
%div = sdiv i32 %add, %add2
ret i32 %div
}
+
+define i8 @udiv_if_power_of_two(i8 %x, i8 %y) {
+; CHECK-LABEL: @udiv_if_power_of_two(
+; CHECK-NEXT: start:
+; CHECK-NEXT: [[TMP0:%.*]] = tail call range(i8 0, 9) i8 @llvm.ctpop.i8(i8 [[Y:%.*]])
+; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 [[TMP0]], 1
+; CHECK-NEXT: br i1 [[TMP1]], label [[BB1:%.*]], label [[BB3:%.*]]
+; CHECK: bb1:
+; CHECK-NEXT: [[TMP2:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[Y]], i1 true)
+; CHECK-NEXT: [[TMP3:%.*]] = lshr i8 [[X:%.*]], [[TMP2]]
+; CHECK-NEXT: br label [[BB3]]
+; CHECK: bb3:
+; CHECK-NEXT: [[_0_SROA_0_0:%.*]] = phi i8 [ [[TMP3]], [[BB1]] ], [ 0, [[START:%.*]] ]
+; CHECK-NEXT: ret i8 [[_0_SROA_0_0]]
+;
+start:
+ %0 = tail call i8 @llvm.ctpop.i8(i8 %y)
+ %1 = icmp eq i8 %0, 1
+ br i1 %1, label %bb1, label %bb3
+
+bb1:
+ %2 = udiv i8 %x, %y
+ br label %bb3
+
+bb3:
+ %_0.sroa.0.0 = phi i8 [ %2, %bb1 ], [ 0, %start ]
+ ret i8 %_0.sroa.0.0
+}
+
+define i8 @udiv_exact_assume_power_of_two(i8 %x, i8 %y) {
+; CHECK-LABEL: @udiv_exact_assume_power_of_two(
+; CHECK-NEXT: start:
+; CHECK-NEXT: [[TMP0:%.*]] = tail call range(i8 1, 9) i8 @llvm.ctpop.i8(i8 [[Y:%.*]])
+; CHECK-NEXT: [[COND:%.*]] = icmp eq i8 [[TMP0]], 1
+; CHECK-NEXT: tail call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT: [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[Y]], i1 true)
+; CHECK-NEXT: [[_0:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
+; CHECK-NEXT: ret i8 [[_0]]
+;
+start:
+ %0 = tail call i8 @llvm.ctpop.i8(i8 %y)
+ %cond = icmp eq i8 %0, 1
+ tail call void @llvm.assume(i1 %cond)
+ %_0 = udiv exact i8 %x, %y
+ ret i8 %_0
+}
+
+define i7 @udiv_assume_power_of_two_illegal_type(i7 %x, i7 %y) {
+; CHECK-LABEL: @udiv_assume_power_of_two_illegal_type(
+; CHECK-NEXT: start:
+; CHECK-NEXT: [[TMP0:%.*]] = tail call range(i7 1, 8) i7 @llvm.ctpop.i7(i7 [[Y:%.*]])
+; CHECK-NEXT: [[COND:%.*]] = icmp eq i7 [[TMP0]], 1
+; CHECK-NEXT: tail call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT: [[TMP1:%.*]] = call range(i7 0, 8) i7 @llvm.cttz.i7(i7 [[Y]], i1 true)
+; CHECK-NEXT: [[_0:%.*]] = lshr i7 [[X:%.*]], [[TMP1]]
+; CHECK-NEXT: ret i7 [[_0]]
+;
+start:
+ %0 = tail call i7 @llvm.ctpop.i7(i7 %y)
+ %cond = icmp eq i7 %0, 1
+ tail call void @llvm.assume(i1 %cond)
+ %_0 = udiv i7 %x, %y
+ ret i7 %_0
+}
+
+define i8 @udiv_assume_power_of_two_multiuse(i8 %x, i8 %y) {
+; CHECK-LABEL: @udiv_assume_power_of_two_multiuse(
+; CHECK-NEXT: start:
+; CHECK-NEXT: [[TMP0:%.*]] = tail call range(i8 1, 9) i8 @llvm.ctpop.i8(i8 [[Y:%.*]])
+; CHECK-NEXT: [[COND:%.*]] = icmp eq i8 [[TMP0]], 1
+; CHECK-NEXT: tail call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT: [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[Y]], i1 true)
+; CHECK-NEXT: [[_0:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
+; CHECK-NEXT: call void @use(i8 [[_0]])
+; CHECK-NEXT: ret i8 [[_0]]
+;
+start:
+ %0 = tail call i8 @llvm.ctpop.i8(i8 %y)
+ %cond = icmp eq i8 %0, 1
+ tail call void @llvm.assume(i1 %cond)
+ %_0 = udiv i8 %x, %y
+ call void @use(i8 %_0)
+ ret i8 %_0
+}
+
+define i8 @udiv_power_of_two_negative(i8 %x, i8 %y) {
+; CHECK-LABEL: @udiv_power_of_two_negative(
+; CHECK-NEXT: start:
+; CHECK-NEXT: [[_0:%.*]] = udiv i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: ret i8 [[_0]]
+;
+start:
+ %0 = tail call i8 @llvm.ctpop.i8(i8 %y)
+ %cond = icmp eq i8 %0, 1
+ %_0 = udiv i8 %x, %y
+ ret i8 %_0
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/121386
More information about the llvm-commits
mailing list