[llvm] [InstCombine] Fold (X << Y) / (X << Z) -> (1 << Y) / (1 << Z) (PR #68863)

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 12 01:35:15 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: XChy (XChy)

<details>
<summary>Changes</summary>

Resolves #68857.
[Alive2](https://alive2.llvm.org/ce/z/B8DF-d) proof.
The `sdiv` case also appears to be a refinement, so the signed case is handled too.

---
Full diff: https://github.com/llvm/llvm-project/pull/68863.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+20) 
- (modified) llvm/test/Transforms/InstCombine/div-shift.ll (+120) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 560c87b6efa7038..0247bbac32b64a0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -980,6 +980,26 @@ static Instruction *foldIDivShl(BinaryOperator &I,
       Ret = BinaryOperator::CreateSDiv(X, Y);
   }
 
+  // If X << Y and X << Z do not overflow (and Ret is not already set):
+  // (X << Y) / (X << Z) -> (1 << Y) >> Z, i.e. 2^(Y-Z) if Y >= Z, else 0.
+  // Both shifts keep X's sign, so this holds for sdiv too; a signed
+  // (1 << Y) / (1 << Z) would be wrong as 1 << (BitWidth - 1) is negative.
+  // Skip X == 1: the operands are already (1 << Y) and (1 << Z).
+  if (!Ret && match(Op0, m_Shl(m_Value(X), m_Value(Y))) &&
+      match(Op1, m_Shl(m_Specific(X), m_Value(Z))) && !match(X, m_One()) &&
+      (Op0->hasOneUse() || Op1->hasOneUse())) {
+    auto *Shl0 = cast<OverflowingBinaryOperator>(Op0);
+    auto *Shl1 = cast<OverflowingBinaryOperator>(Op1);
+    if (IsSigned ? (Shl0->hasNoSignedWrap() && Shl1->hasNoSignedWrap())
+                 : (Shl0->hasNoUnsignedWrap() && Shl1->hasNoUnsignedWrap())) {
+      // 1 << Y cannot wrap unsigned for any in-range shift amount.
+      Constant *One = ConstantInt::get(X->getType(), 1);
+      Value *Dividend =
+          Builder.CreateShl(One, Y, "common.shl", /*HasNUW=*/true);
+      Ret = BinaryOperator::CreateLShr(Dividend, Z);
+    }
+  }
+
   if (!Ret)
     return nullptr;
 
diff --git a/llvm/test/Transforms/InstCombine/div-shift.ll b/llvm/test/Transforms/InstCombine/div-shift.ll
index 76c5328dc8499e0..9dac529bb80731b 100644
--- a/llvm/test/Transforms/InstCombine/div-shift.ll
+++ b/llvm/test/Transforms/InstCombine/div-shift.ll
@@ -2,6 +2,7 @@
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
 
 declare void @use(i8)
+declare void @use32(i32)
 
 declare i8 @llvm.umin.i8(i8, i8)
 declare i8 @llvm.umax.i8(i8, i8)
@@ -1025,3 +1026,122 @@ define i8 @udiv_shl_no_overflow(i8 %x, i8 %y) {
   %mul = udiv i8 %x, %min
   ret i8 %mul
 }
+
+define i32 @sdiv_shl_pair_const(i32 %a) {
+; CHECK-LABEL: @sdiv_shl_pair_const(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret i32 2
+;
+entry:
+  %lhs = shl nsw i32 %a, 2
+  %rhs = shl nsw i32 %a, 1
+  %div = sdiv i32 %lhs, %rhs
+  ret i32 %div
+}
+
+define i32 @udiv_shl_pair_const(i32 %a) {
+; CHECK-LABEL: @udiv_shl_pair_const(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret i32 2
+;
+entry:
+  %lhs = shl nuw i32 %a, 2
+  %rhs = shl nuw i32 %a, 1
+  %div = udiv i32 %lhs, %rhs
+  ret i32 %div
+}
+
+; Must become lshr, not sdiv: for %a = -1, %x = 31, %y = 30 the result is 2, but (1 << 31) sdiv (1 << 30) is -2.
+define i32 @sdiv_shl_pair(i32 %a, i32 %x, i32 %y) {
+; CHECK-LABEL: @sdiv_shl_pair(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COMMON_SHL:%.*]] = shl nuw i32 1, [[X:%.*]]
+; CHECK-NEXT:    [[DIV:%.*]] = lshr i32 [[COMMON_SHL]], [[Y:%.*]]
+; CHECK-NEXT:    ret i32 [[DIV]]
+;
+entry:
+  %lhs = shl nsw i32 %a, %x
+  %rhs = shl nsw i32 %a, %y
+  %div = sdiv i32 %lhs, %rhs
+  ret i32 %div
+}
+
+define i32 @udiv_shl_pair(i32 %a, i32 %x, i32 %y) {
+; CHECK-LABEL: @udiv_shl_pair(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COMMON_SHL:%.*]] = shl nuw i32 1, [[X:%.*]]
+; CHECK-NEXT:    [[DIV:%.*]] = lshr i32 [[COMMON_SHL]], [[Y:%.*]]
+; CHECK-NEXT:    ret i32 [[DIV]]
+;
+entry:
+  %lhs = shl nuw nsw i32 %a, %x
+  %rhs = shl nuw nsw i32 %a, %y
+  %div = udiv i32 %lhs, %rhs
+  ret i32 %div
+}
+
+define i32 @sdiv_shl_pair_overflow_fail(i32 %a, i32 %x, i32 %y) {
+; CHECK-LABEL: @sdiv_shl_pair_overflow_fail(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[LHS:%.*]] = shl i32 [[A:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = shl nsw i32 [[A]], [[Y:%.*]]
+; CHECK-NEXT:    [[DIV:%.*]] = sdiv i32 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i32 [[DIV]]
+;
+entry:
+  %lhs = shl i32 %a, %x
+  %rhs = shl nsw i32 %a, %y
+  %div = sdiv i32 %lhs, %rhs
+  ret i32 %div
+}
+
+define i32 @sdiv_shl_pair_nuw_fail(i32 %a, i32 %x, i32 %y) {
+; CHECK-LABEL: @sdiv_shl_pair_nuw_fail(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[LHS:%.*]] = shl nuw i32 [[A:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = shl nsw i32 [[A]], [[Y:%.*]]
+; CHECK-NEXT:    [[DIV:%.*]] = sdiv i32 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i32 [[DIV]]
+;
+entry:
+  %lhs = shl nuw i32 %a, %x
+  %rhs = shl nsw i32 %a, %y
+  %div = sdiv i32 %lhs, %rhs
+  ret i32 %div
+}
+
+define i32 @udiv_shl_pair_multi_use(i32 %a, i32 %x, i32 %y) {
+; CHECK-LABEL: @udiv_shl_pair_multi_use(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[LHS:%.*]] = shl nuw i32 [[A:%.*]], [[X:%.*]]
+; CHECK-NEXT:    call void @use32(i32 [[LHS]])
+; CHECK-NEXT:    [[COMMON_SHL:%.*]] = shl nuw i32 1, [[X]]
+; CHECK-NEXT:    [[DIV:%.*]] = lshr i32 [[COMMON_SHL]], [[Y:%.*]]
+; CHECK-NEXT:    ret i32 [[DIV]]
+;
+entry:
+  %lhs = shl nuw i32 %a, %x
+  call void @use32(i32 %lhs)
+  %rhs = shl nuw i32 %a, %y
+  %div = udiv i32 %lhs, %rhs
+  ret i32 %div
+}
+
+define i32 @udiv_shl_pair_multi_use_fail(i32 %a, i32 %x, i32 %y) {
+; CHECK-LABEL: @udiv_shl_pair_multi_use_fail(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[LHS:%.*]] = shl nuw i32 [[A:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[RHS:%.*]] = shl nuw i32 [[A]], [[Y:%.*]]
+; CHECK-NEXT:    call void @use32(i32 [[LHS]])
+; CHECK-NEXT:    call void @use32(i32 [[RHS]])
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[LHS]], [[RHS]]
+; CHECK-NEXT:    ret i32 [[DIV]]
+;
+entry:
+  %lhs = shl nuw i32 %a, %x
+  %rhs = shl nuw i32 %a, %y
+  call void @use32(i32 %lhs)
+  call void @use32(i32 %rhs)
+  %div = udiv i32 %lhs, %rhs
+  ret i32 %div
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/68863


More information about the llvm-commits mailing list