[llvm] [InstCombine] Fold `X udiv Y` to `X lshr cttz(Y)` if Y is a power of 2 (PR #121386)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 9 08:04:46 PST 2025


https://github.com/veera-sivarajan updated https://github.com/llvm/llvm-project/pull/121386

>From 35bb10c16be3a07b3892713382bdbee4c8f16433 Mon Sep 17 00:00:00 2001
From: Veera <sveera.2001 at gmail.com>
Date: Tue, 31 Dec 2024 08:14:58 +0000
Subject: [PATCH 1/2] Add tests for udiv by a known power of two

---
 llvm/test/Transforms/InstCombine/div-shift.ll | 97 +++++++++++++++++++
 1 file changed, 97 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/div-shift.ll b/llvm/test/Transforms/InstCombine/div-shift.ll
index 8dd6d4a2e83712..b8f3593f4a33ac 100644
--- a/llvm/test/Transforms/InstCombine/div-shift.ll
+++ b/llvm/test/Transforms/InstCombine/div-shift.ll
@@ -1294,3 +1294,100 @@ entry:
   %div = sdiv i32 %add, %add2
   ret i32 %div
 }
+
+define i8 @udiv_if_power_of_two(i8 %x, i8 %y) {
+; CHECK-LABEL: @udiv_if_power_of_two(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[TMP0:%.*]] = tail call range(i8 0, 9) i8 @llvm.ctpop.i8(i8 [[Y:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i8 [[TMP0]], 1
+; CHECK-NEXT:    br i1 [[TMP1]], label [[BB1:%.*]], label [[BB3:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    [[TMP3:%.*]] = udiv i8 [[X:%.*]], [[Y]]
+; CHECK-NEXT:    br label [[BB3]]
+; CHECK:       bb3:
+; CHECK-NEXT:    [[_0_SROA_0_0:%.*]] = phi i8 [ [[TMP3]], [[BB1]] ], [ 0, [[START:%.*]] ]
+; CHECK-NEXT:    ret i8 [[_0_SROA_0_0]]
+;
+start:
+  %ctpop = tail call i8 @llvm.ctpop.i8(i8 %y)
+  %cmp = icmp eq i8 %ctpop, 1 ; ctpop(%y) == 1 iff %y is a power of 2
+  br i1 %cmp, label %bb1, label %bb3
+
+bb1:
+  %div = udiv i8 %x, %y ; reached only when %y is a power of 2
+  br label %bb3
+
+bb3:
+  %result = phi i8 [ %div, %bb1 ], [ 0, %start ]
+  ret i8 %result
+}
+
+define i8 @udiv_exact_assume_power_of_two(i8 %x, i8 %y) {
+; CHECK-LABEL: @udiv_exact_assume_power_of_two(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[TMP0:%.*]] = tail call range(i8 1, 9) i8 @llvm.ctpop.i8(i8 [[Y:%.*]])
+; CHECK-NEXT:    [[COND:%.*]] = icmp eq i8 [[TMP0]], 1
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[_0:%.*]] = udiv exact i8 [[X:%.*]], [[Y]]
+; CHECK-NEXT:    ret i8 [[_0]]
+;
+start:
+  %ctpop = tail call i8 @llvm.ctpop.i8(i8 %y)
+  %cond = icmp eq i8 %ctpop, 1
+  tail call void @llvm.assume(i1 %cond) ; %y is assumed to be a power of 2
+  %div = udiv exact i8 %x, %y ; exercises preservation of the exact flag
+  ret i8 %div
+}
+
+define i7 @udiv_assume_power_of_two_illegal_type(i7 %x, i7 %y) {
+; CHECK-LABEL: @udiv_assume_power_of_two_illegal_type(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[TMP0:%.*]] = tail call range(i7 1, 8) i7 @llvm.ctpop.i7(i7 [[Y:%.*]])
+; CHECK-NEXT:    [[COND:%.*]] = icmp eq i7 [[TMP0]], 1
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[_0:%.*]] = udiv i7 [[X:%.*]], [[Y]]
+; CHECK-NEXT:    ret i7 [[_0]]
+;
+start:
+  %ctpop = tail call i7 @llvm.ctpop.i7(i7 %y)
+  %cond = icmp eq i7 %ctpop, 1
+  tail call void @llvm.assume(i1 %cond) ; %y is assumed to be a power of 2
+  %div = udiv i7 %x, %y ; i7 is an integer width that is not a multiple of 8
+  ret i7 %div
+}
+
+define i8 @udiv_assume_power_of_two_multiuse(i8 %x, i8 %y) {
+; CHECK-LABEL: @udiv_assume_power_of_two_multiuse(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[TMP0:%.*]] = tail call range(i8 1, 9) i8 @llvm.ctpop.i8(i8 [[Y:%.*]])
+; CHECK-NEXT:    [[COND:%.*]] = icmp eq i8 [[TMP0]], 1
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[_0:%.*]] = udiv i8 [[X:%.*]], [[Y]]
+; CHECK-NEXT:    call void @use(i8 [[_0]])
+; CHECK-NEXT:    ret i8 [[_0]]
+;
+start:
+  %ctpop = tail call i8 @llvm.ctpop.i8(i8 %y)
+  %cond = icmp eq i8 %ctpop, 1
+  tail call void @llvm.assume(i1 %cond) ; %y is assumed to be a power of 2
+  %div = udiv i8 %x, %y
+  call void @use(i8 %div) ; second use of the quotient besides the ret
+  ret i8 %div
+}
+
+define i8 @udiv_power_of_two_negative(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @udiv_power_of_two_negative(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[CTPOP:%.*]] = tail call range(i8 0, 9) i8 @llvm.ctpop.i8(i8 [[Z:%.*]])
+; CHECK-NEXT:    [[COND:%.*]] = icmp eq i8 [[CTPOP]], 1
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[_0:%.*]] = udiv i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i8 [[_0]]
+;
+start:
+  %ctpop = tail call i8 @llvm.ctpop.i8(i8 %z) ; the assumption is about %z
+  %cond = icmp eq i8 %ctpop, 1
+  tail call void @llvm.assume(i1 %cond)
+  %div = udiv i8 %x, %y ; divisor %y is unconstrained, so the udiv must stay
+  ret i8 %div
+}

>From b22040368956fe0f11a644aefb59804104b81870 Mon Sep 17 00:00:00 2001
From: Veera <sveera.2001 at gmail.com>
Date: Thu, 9 Jan 2025 15:24:25 +0000
Subject: [PATCH 2/2] Fold `X udiv Y` to `X lshr cttz(Y)` if Y is a power of 2

---
 .../InstCombine/InstCombineMulDivRem.cpp      | 25 ++++++++++++++-----
 .../IndVarSimplify/rewrite-loop-exit-value.ll |  6 +++--
 llvm/test/Transforms/InstCombine/div-shift.ll | 24 ++++++++++++------
 3 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f85a3c93651353..b8077953ac1acd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1623,14 +1623,27 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
     return Lshr;
   }
 
-  // Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away.
-  if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true,
-               /*DoFold*/ false)) {
-    Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0,
-                          /*AssumeNonZero*/ true, /*DoFold*/ true);
+  auto GetShiftableDenom = [&](Value *Denom) -> Value * {
+    // Op0 udiv Denom -> Op0 lshr log2(Denom), if log2() folds away.
+    if (takeLog2(Builder, Denom, /*Depth=*/0, /*AssumeNonZero=*/true,
+                 /*DoFold=*/false))
+      return takeLog2(Builder, Denom, /*Depth=*/0, /*AssumeNonZero=*/true,
+                      /*DoFold=*/true);
+
+    // Op0 udiv Denom -> Op0 lshr cttz(Denom), if Denom is a power of 2.
+    if (isKnownToBeAPowerOfTwo(Denom, /*OrZero=*/true, /*Depth=*/0, &I))
+      // OrZero admits Denom == 0, but a udiv by zero is already UB, so
+      // cttz with is_zero_poison=true is a valid refinement. The extra
+      // cttz call is accepted because a shift is far cheaper than a udiv.
+      return Builder.CreateBinaryIntrinsic(Intrinsic::cttz, Denom,
+                                           Builder.getTrue());
+
+    return nullptr;
+  };
+  // If log2(Op1) is available as a shift amount, the udiv becomes an lshr.
+  if (auto Res = GetShiftableDenom(Op1))
     return replaceInstUsesWith(
         I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
-  }
 
   return nullptr;
 }
diff --git a/llvm/test/Transforms/IndVarSimplify/rewrite-loop-exit-value.ll b/llvm/test/Transforms/IndVarSimplify/rewrite-loop-exit-value.ll
index 1956f454a52bbf..fa47d06d859e97 100644
--- a/llvm/test/Transforms/IndVarSimplify/rewrite-loop-exit-value.ll
+++ b/llvm/test/Transforms/IndVarSimplify/rewrite-loop-exit-value.ll
@@ -218,7 +218,8 @@ define i32 @vscale_slt_with_vp_umin(ptr nocapture %A, i32 %n) mustprogress vscal
 ; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END:%.*]]
 ; CHECK:       for.end:
 ; CHECK-NEXT:    [[TMP0:%.*]] = add nsw i32 [[N]], -1
-; CHECK-NEXT:    [[TMP1:%.*]] = udiv i32 [[TMP0]], [[VF]]
+; CHECK-NEXT:    [[TMP5:%.*]] = call range(i32 2, 33) i32 @llvm.cttz.i32(i32 [[VF]], i1 true)
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], [[TMP5]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = mul i32 [[TMP1]], [[VSCALE]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = shl i32 [[TMP2]], 2
 ; CHECK-NEXT:    [[TMP4:%.*]] = sub i32 [[N]], [[TMP3]]
@@ -270,7 +271,8 @@ define i32 @vscale_slt_with_vp_umin2(ptr nocapture %A, i32 %n) mustprogress vsca
 ; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END:%.*]]
 ; CHECK:       for.end:
 ; CHECK-NEXT:    [[TMP0:%.*]] = add i32 [[N]], -1
-; CHECK-NEXT:    [[TMP1:%.*]] = udiv i32 [[TMP0]], [[VF]]
+; CHECK-NEXT:    [[TMP5:%.*]] = call range(i32 2, 33) i32 @llvm.cttz.i32(i32 [[VF]], i1 true)
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], [[TMP5]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = mul i32 [[TMP1]], [[VSCALE]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = shl i32 [[TMP2]], 2
 ; CHECK-NEXT:    [[TMP4:%.*]] = sub i32 [[N]], [[TMP3]]
diff --git a/llvm/test/Transforms/InstCombine/div-shift.ll b/llvm/test/Transforms/InstCombine/div-shift.ll
index b8f3593f4a33ac..af83f37011ba01 100644
--- a/llvm/test/Transforms/InstCombine/div-shift.ll
+++ b/llvm/test/Transforms/InstCombine/div-shift.ll
@@ -148,7 +148,8 @@ define i8 @udiv_umin_extra_use(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[Z2:%.*]] = shl nuw i8 1, [[Z:%.*]]
 ; CHECK-NEXT:    [[M:%.*]] = call i8 @llvm.umin.i8(i8 [[Y2]], i8 [[Z2]])
 ; CHECK-NEXT:    call void @use(i8 [[M]])
-; CHECK-NEXT:    [[D:%.*]] = udiv i8 [[X:%.*]], [[M]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[M]], i1 true)
+; CHECK-NEXT:    [[D:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret i8 [[D]]
 ;
   %y2 = shl i8 1, %y
@@ -165,7 +166,8 @@ define i8 @udiv_smin(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[Y2:%.*]] = shl nuw i8 1, [[Y:%.*]]
 ; CHECK-NEXT:    [[Z2:%.*]] = shl nuw i8 1, [[Z:%.*]]
 ; CHECK-NEXT:    [[M:%.*]] = call i8 @llvm.smin.i8(i8 [[Y2]], i8 [[Z2]])
-; CHECK-NEXT:    [[D:%.*]] = udiv i8 [[X:%.*]], [[M]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[M]], i1 true)
+; CHECK-NEXT:    [[D:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret i8 [[D]]
 ;
   %y2 = shl i8 1, %y
@@ -181,7 +183,8 @@ define i8 @udiv_smax(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:    [[Y2:%.*]] = shl nuw i8 1, [[Y:%.*]]
 ; CHECK-NEXT:    [[Z2:%.*]] = shl nuw i8 1, [[Z:%.*]]
 ; CHECK-NEXT:    [[M:%.*]] = call i8 @llvm.smax.i8(i8 [[Y2]], i8 [[Z2]])
-; CHECK-NEXT:    [[D:%.*]] = udiv i8 [[X:%.*]], [[M]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[M]], i1 true)
+; CHECK-NEXT:    [[D:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret i8 [[D]]
 ;
   %y2 = shl i8 1, %y
@@ -1006,7 +1009,8 @@ define i8 @udiv_fail_shl_overflow(i8 %x, i8 %y) {
 ; CHECK-LABEL: @udiv_fail_shl_overflow(
 ; CHECK-NEXT:    [[SHL:%.*]] = shl i8 2, [[Y:%.*]]
 ; CHECK-NEXT:    [[MIN:%.*]] = call i8 @llvm.umax.i8(i8 [[SHL]], i8 1)
-; CHECK-NEXT:    [[MUL:%.*]] = udiv i8 [[X:%.*]], [[MIN]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[MIN]], i1 true)
+; CHECK-NEXT:    [[MUL:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret i8 [[MUL]]
 ;
   %shl = shl i8 2, %y
@@ -1302,7 +1306,8 @@ define i8 @udiv_if_power_of_two(i8 %x, i8 %y) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i8 [[TMP0]], 1
 ; CHECK-NEXT:    br i1 [[TMP1]], label [[BB1:%.*]], label [[BB3:%.*]]
 ; CHECK:       bb1:
-; CHECK-NEXT:    [[TMP3:%.*]] = udiv i8 [[X:%.*]], [[Y]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[Y]], i1 true)
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i8 [[X:%.*]], [[TMP2]]
 ; CHECK-NEXT:    br label [[BB3]]
 ; CHECK:       bb3:
 ; CHECK-NEXT:    [[_0_SROA_0_0:%.*]] = phi i8 [ [[TMP3]], [[BB1]] ], [ 0, [[START:%.*]] ]
@@ -1328,7 +1333,8 @@ define i8 @udiv_exact_assume_power_of_two(i8 %x, i8 %y) {
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call range(i8 1, 9) i8 @llvm.ctpop.i8(i8 [[Y:%.*]])
 ; CHECK-NEXT:    [[COND:%.*]] = icmp eq i8 [[TMP0]], 1
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[COND]])
-; CHECK-NEXT:    [[_0:%.*]] = udiv exact i8 [[X:%.*]], [[Y]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[Y]], i1 true)
+; CHECK-NEXT:    [[_0:%.*]] = lshr exact i8 [[X:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret i8 [[_0]]
 ;
 start:
@@ -1345,7 +1351,8 @@ define i7 @udiv_assume_power_of_two_illegal_type(i7 %x, i7 %y) {
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call range(i7 1, 8) i7 @llvm.ctpop.i7(i7 [[Y:%.*]])
 ; CHECK-NEXT:    [[COND:%.*]] = icmp eq i7 [[TMP0]], 1
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[COND]])
-; CHECK-NEXT:    [[_0:%.*]] = udiv i7 [[X:%.*]], [[Y]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call range(i7 0, 8) i7 @llvm.cttz.i7(i7 [[Y]], i1 true)
+; CHECK-NEXT:    [[_0:%.*]] = lshr i7 [[X:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret i7 [[_0]]
 ;
 start:
@@ -1362,7 +1369,8 @@ define i8 @udiv_assume_power_of_two_multiuse(i8 %x, i8 %y) {
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call range(i8 1, 9) i8 @llvm.ctpop.i8(i8 [[Y:%.*]])
 ; CHECK-NEXT:    [[COND:%.*]] = icmp eq i8 [[TMP0]], 1
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[COND]])
-; CHECK-NEXT:    [[_0:%.*]] = udiv i8 [[X:%.*]], [[Y]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[Y]], i1 true)
+; CHECK-NEXT:    [[_0:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
 ; CHECK-NEXT:    call void @use(i8 [[_0]])
 ; CHECK-NEXT:    ret i8 [[_0]]
 ;



More information about the llvm-commits mailing list