[llvm] 5d4a0d5 - [InstCombine] Teach takeLog2 about right shifts, truncation and bitwise-and

David Majnemer via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 27 22:30:59 PDT 2024


Author: David Majnemer
Date: 2024-10-28T05:13:04Z
New Revision: 5d4a0d54b5269bad1410e6db957836fe98634069

URL: https://github.com/llvm/llvm-project/commit/5d4a0d54b5269bad1410e6db957836fe98634069
DIFF: https://github.com/llvm/llvm-project/commit/5d4a0d54b5269bad1410e6db957836fe98634069.diff

LOG: [InstCombine] Teach takeLog2 about right shifts, truncation and bitwise-and

We left some easy opportunities for further simplifications.

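log2(trunc(x)) is simply trunc(log2(x)). This is safe if we know that
trunc is NUW because it means that the truncation didn't drop any set bits.
It is also safe if the caller lets us assume that the value is non-zero
(`AssumeNonZero`): a truncated power of two is either unchanged or zero,
so a non-zero result means the set bit survived the truncation.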

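log2(x >>u y) is simply `log2(x) - y`. This is safe if the shift is
`exact` (no set bits are shifted out) or if the caller lets us assume that
the result is non-zero.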

log2(x & y) is a funny one. It comes up when doing something like:
```
unsigned int f(unsigned int x, unsigned int y) {
  unsigned char a = 1u << x;
  return y / a;
}
```

LLVM would canonicalize this to:
```
  %shl = shl nuw i32 1, %x
  %conv1 = and i32 %shl, 255
  %div = udiv i32 %y, %conv1
```

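In cases like these, we can ignore the mask entirely: the divisor of a
udiv may be assumed to be non-zero, so the mask must have kept the single
set bit of the shift. The division is then equivalent to `y >> x`.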

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/test/Transforms/InstCombine/div.ll
    llvm/test/Transforms/InstCombine/shift.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f4f3644acfe5ea..b9c165da906da4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1427,6 +1427,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
     if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
       return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); });
 
+  // log2(trunc x) -> trunc log2(X)
+  // FIXME: Require one use?
+  if (match(Op, m_Trunc(m_Value(X)))) {
+    auto *TI = cast<TruncInst>(Op);
+    if (AssumeNonZero || TI->hasNoUnsignedWrap())
+      if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
+        return IfFold([&]() {
+          return Builder.CreateTrunc(LogX, Op->getType(), "",
+                                     /*IsNUW=*/TI->hasNoUnsignedWrap());
+        });
+  }
+
   // log2(X << Y) -> log2(X) + Y
   // FIXME: Require one use unless X is 1?
   if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) {
@@ -1437,6 +1449,24 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
         return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
   }
 
+  // log2(X >>u Y) -> log2(X) - Y
+  // FIXME: Require one use?
+  if (match(Op, m_LShr(m_Value(X), m_Value(Y)))) {
+    auto *PEO = cast<PossiblyExactOperator>(Op);
+    if (AssumeNonZero || PEO->isExact())
+      if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
+        return IfFold([&]() { return Builder.CreateSub(LogX, Y); });
+  }
+
+  // log2(X & Y) -> either log2(X) or log2(Y)
+  // This requires `AssumeNonZero` as `X & Y` may be zero when X != Y.
+  if (AssumeNonZero && match(Op, m_And(m_Value(X), m_Value(Y)))) {
+    if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
+      return IfFold([&]() { return LogX; });
+    if (Value *LogY = takeLog2(Builder, Y, Depth, AssumeNonZero, DoFold))
+      return IfFold([&]() { return LogY; });
+  }
+
   // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y)
   // FIXME: Require one use?
   if (SelectInst *SI = dyn_cast<SelectInst>(Op))

diff  --git a/llvm/test/Transforms/InstCombine/div.ll b/llvm/test/Transforms/InstCombine/div.ll
index e8a25ff44d0296..a91c9bfc91c40d 100644
--- a/llvm/test/Transforms/InstCombine/div.ll
+++ b/llvm/test/Transforms/InstCombine/div.ll
@@ -429,9 +429,8 @@ define <2 x i32> @test31(<2 x i32> %x) {
 
 define i32 @test32(i32 %a, i32 %b) {
 ; CHECK-LABEL: @test32(
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 2, [[B:%.*]]
-; CHECK-NEXT:    [[DIV:%.*]] = lshr i32 [[SHL]], 2
-; CHECK-NEXT:    [[DIV2:%.*]] = udiv i32 [[A:%.*]], [[DIV]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[B:%.*]], -1
+; CHECK-NEXT:    [[DIV2:%.*]] = lshr i32 [[A:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[DIV2]]
 ;
   %shl = shl i32 2, %b
@@ -1832,3 +1831,41 @@ define i32 @fold_disjoint_or_over_udiv(i32 %x) {
   %r = udiv i32 %or, 9
   ret i32 %r
 }
+
+define i8 @udiv_trunc_shl(i32 %x) {
+; CHECK-LABEL: @udiv_trunc_shl(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
+; CHECK-NEXT:    [[UDIV1:%.*]] = lshr i8 8, [[TMP1]]
+; CHECK-NEXT:    ret i8 [[UDIV1]]
+;
+  %lshr = shl i32 1, %x
+  %trunc = trunc i32 %lshr to i8
+  %div = udiv i8 8, %trunc
+  ret i8 %div
+}
+
+define i32 @zext_udiv_trunc_lshr(i32 %x) {
+; CHECK-LABEL: @zext_udiv_trunc_lshr(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = sub i8 5, [[TMP1]]
+; CHECK-NEXT:    [[UDIV1:%.*]] = lshr i8 8, [[TMP2]]
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext nneg i8 [[UDIV1]] to i32
+; CHECK-NEXT:    ret i32 [[ZEXT]]
+;
+  %lshr = lshr i32 32, %x
+  %trunc = trunc i32 %lshr to i8
+  %div = udiv i8 8, %trunc
+  %zext = zext i8 %div to i32
+  ret i32 %zext
+}
+
+define i32 @udiv_and_shl(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: @udiv_and_shl(
+; CHECK-NEXT:    [[DIV1:%.*]] = lshr i32 [[C:%.*]], [[A:%.*]]
+; CHECK-NEXT:    ret i32 [[DIV1]]
+;
+  %shl = shl i32 1, %a
+  %and = and i32 %b, %shl
+  %div = udiv i32 %c, %and
+  ret i32 %div
+}

diff  --git a/llvm/test/Transforms/InstCombine/shift.ll b/llvm/test/Transforms/InstCombine/shift.ll
index 558f4ffbfcabe4..69f531e98f045b 100644
--- a/llvm/test/Transforms/InstCombine/shift.ll
+++ b/llvm/test/Transforms/InstCombine/shift.ll
@@ -677,8 +677,8 @@ entry:
 
 define i32 @test42(i32 %a, i32 %b) {
 ; CHECK-LABEL: @test42(
-; CHECK-NEXT:    [[DIV:%.*]] = lshr exact i32 4096, [[B:%.*]]
-; CHECK-NEXT:    [[DIV2:%.*]] = udiv i32 [[A:%.*]], [[DIV]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sub i32 12, [[B:%.*]]
+; CHECK-NEXT:    [[DIV2:%.*]] = lshr i32 [[A:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[DIV2]]
 ;
   %div = lshr i32 4096, %b    ; must be exact otherwise we'd divide by zero
@@ -688,8 +688,8 @@ define i32 @test42(i32 %a, i32 %b) {
 
 define <2 x i32> @test42vec(<2 x i32> %a, <2 x i32> %b) {
 ; CHECK-LABEL: @test42vec(
-; CHECK-NEXT:    [[DIV:%.*]] = lshr exact <2 x i32> <i32 4096, i32 4096>, [[B:%.*]]
-; CHECK-NEXT:    [[DIV2:%.*]] = udiv <2 x i32> [[A:%.*]], [[DIV]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sub <2 x i32> <i32 12, i32 12>, [[B:%.*]]
+; CHECK-NEXT:    [[DIV2:%.*]] = lshr <2 x i32> [[A:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret <2 x i32> [[DIV2]]
 ;
   %div = lshr <2 x i32> <i32 4096, i32 4096>, %b    ; must be exact otherwise we'd divide by zero


        


More information about the llvm-commits mailing list