[llvm] 1e36c96 - [InstCombine] Fold ((X << nuw Z) binop nuw Y) >>u Z --> X binop nuw (Y >>u Z) (#88193)

via llvm-commits llvm-commits at lists.llvm.org
Tue May 7 12:18:00 PDT 2024


Author: AtariDreams
Date: 2024-05-07T21:17:56+02:00
New Revision: 1e36c96dc0998e886644d6fc76aa475d88d9645c

URL: https://github.com/llvm/llvm-project/commit/1e36c96dc0998e886644d6fc76aa475d88d9645c
DIFF: https://github.com/llvm/llvm-project/commit/1e36c96dc0998e886644d6fc76aa475d88d9645c.diff

LOG: [InstCombine] Fold ((X << nuw Z) binop nuw Y) >>u Z --> X binop nuw (Y >>u Z) (#88193)

Proofs:
https://alive2.llvm.org/ce/z/N9dRzP
https://alive2.llvm.org/ce/z/Xrpc-Y
https://alive2.llvm.org/ce/z/BagBM6

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
    llvm/test/Transforms/InstCombine/lshr.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 1cb21a1d81af4..8847de3667130 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1259,6 +1259,54 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
       match(Op1, m_SpecificIntAllowPoison(BitWidth - 1)))
     return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty);
 
+  // ((X << nuw Z) sub nuw Y) >>u exact Z --> X sub nuw (Y >>u exact Z)
+  Value *Y;
+  if (I.isExact() &&
+      match(Op0, m_OneUse(m_NUWSub(m_NUWShl(m_Value(X), m_Specific(Op1)),
+                                   m_Value(Y))))) {
+    Value *NewLshr = Builder.CreateLShr(Y, Op1, "", /*isExact=*/true);
+    auto *NewSub = BinaryOperator::CreateNUWSub(X, NewLshr);
+    NewSub->setHasNoSignedWrap(
+        cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap());
+    return NewSub;
+  }
+
+  auto isSuitableBinOpcode = [](Instruction::BinaryOps BinOpcode) {
+    switch (BinOpcode) {
+    default:
+      return false;
+    case Instruction::Add:
+    case Instruction::And:
+    case Instruction::Or:
+    case Instruction::Xor:
+      // And is supported (but never keeps exact); sub is handled separately.
+      return true;
+    }
+  };
+
+  // If the shl is nuw and the binop is nuw (or is a bitwise and/or/xor), then:
+  // ((X << nuw Z) binop nuw Y) >>u Z --> X binop nuw (Y >>u Z)
+  if (match(Op0, m_OneUse(m_c_BinOp(m_NUWShl(m_Value(X), m_Specific(Op1)),
+                                    m_Value(Y))))) {
+    BinaryOperator *Op0OB = cast<BinaryOperator>(Op0);
+    if (isSuitableBinOpcode(Op0OB->getOpcode())) {
+      if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op0);
+          !OBO || OBO->hasNoUnsignedWrap()) {
+        Value *NewLshr = Builder.CreateLShr(
+            Y, Op1, "", I.isExact() && Op0OB->getOpcode() != Instruction::And);
+        auto *NewBinOp = BinaryOperator::Create(Op0OB->getOpcode(), NewLshr, X);
+        if (OBO) {
+          NewBinOp->setHasNoUnsignedWrap(true);
+          NewBinOp->setHasNoSignedWrap(OBO->hasNoSignedWrap());
+        } else if (auto *Disjoint = dyn_cast<PossiblyDisjointInst>(Op0)) {
+          cast<PossiblyDisjointInst>(NewBinOp)->setIsDisjoint(
+              Disjoint->isDisjoint());
+        }
+        return NewBinOp;
+      }
+    }
+  }
+
   if (match(Op1, m_APInt(C))) {
     unsigned ShAmtC = C->getZExtValue();
     auto *II = dyn_cast<IntrinsicInst>(Op0);
@@ -1275,7 +1323,6 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
       return new ZExtInst(Cmp, Ty);
     }
 
-    Value *X;
     const APInt *C1;
     if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) {
       if (C1->ult(ShAmtC)) {
@@ -1320,7 +1367,6 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
     // ((X << C) + Y) >>u C --> (X + (Y >>u C)) & (-1 >>u C)
     // TODO: Consolidate with the more general transform that starts from shl
     //       (the shifts are in the opposite order).
-    Value *Y;
     if (match(Op0,
               m_OneUse(m_c_Add(m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))),
                                m_Value(Y))))) {

diff  --git a/llvm/test/Transforms/InstCombine/lshr.ll b/llvm/test/Transforms/InstCombine/lshr.ll
index 7d611ba188d6b..563e669f90353 100644
--- a/llvm/test/Transforms/InstCombine/lshr.ll
+++ b/llvm/test/Transforms/InstCombine/lshr.ll
@@ -163,6 +163,17 @@ define <2 x i8> @lshr_exact_splat_vec(<2 x i8> %x) {
   ret <2 x i8> %lshr
 }
 
+define <2 x i8> @lshr_exact_splat_vec_nuw(<2 x i8> %x) {
+; CHECK-LABEL: @lshr_exact_splat_vec_nuw(
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw <2 x i8> [[X:%.*]], <i8 1, i8 1>
+; CHECK-NEXT:    ret <2 x i8> [[LSHR]]
+;
+  %shl = shl nuw <2 x i8> %x, <i8 2, i8 2>
+  %add = add nuw <2 x i8> %shl, <i8 4, i8 4>
+  %lshr = lshr <2 x i8> %add, <i8 2, i8 2>
+  ret <2 x i8> %lshr
+}
+
 define i8 @shl_add(i8 %x, i8 %y) {
 ; CHECK-LABEL: @shl_add(
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i8 [[Y:%.*]], 2
@@ -360,8 +371,222 @@ define <3 x i14> @mul_splat_fold_vec(<3 x i14> %x) {
   ret <3 x i14> %t
 }
 
+define i32 @shl_add_lshr_flag_preservation(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_add_lshr_flag_preservation(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[Y:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw nsw i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %add = add nuw nsw i32 %shl, %y
+  %lshr = lshr exact i32 %add, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_add_lshr(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_add_lshr(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[Y:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %add = add nuw i32 %shl, %y
+  %lshr = lshr i32 %add, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_add_lshr_comm(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_add_lshr_comm(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[Y:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %add = add nuw i32 %y, %shl
+  %lshr = lshr i32 %add, %c
+  ret i32 %lshr
+}
+
 ; Negative test
 
+define i32 @shl_add_lshr_no_nuw(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_add_lshr_no_nuw(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[SHL]], [[Y:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[ADD]], [[C]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %add = add i32 %shl, %y
+  %lshr = lshr i32 %add, %c
+  ret i32 %lshr
+}
+
+; Negative test
+
+define i32 @shl_sub_lshr_not_exact(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_sub_lshr_not_exact(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[SUB:%.*]] = sub nuw i32 [[SHL]], [[Y:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[SUB]], [[C]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %sub = sub nuw i32 %shl, %y
+  %lshr = lshr i32 %sub, %c
+  ret i32 %lshr
+}
+
+; Negative test
+
+define i32 @shl_sub_lshr_no_nuw(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_sub_lshr_no_nuw(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[X:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[SUB:%.*]] = sub nsw i32 [[SHL]], [[Y:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr exact i32 [[SUB]], [[C]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nsw i32 %x, %c
+  %sub = sub nsw i32 %shl, %y
+  %lshr = lshr exact i32 %sub, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_sub_lshr(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_sub_lshr(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[Y:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = sub nuw nsw i32 [[X:%.*]], [[TMP1]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %sub = sub nuw nsw i32 %shl, %y
+  %lshr = lshr exact i32 %sub, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_or_lshr(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_or_lshr(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[Y:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = or i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %or = or i32 %shl, %y
+  %lshr = lshr i32 %or, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_or_disjoint_lshr(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_or_disjoint_lshr(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[Y:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = or disjoint i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %or = or disjoint i32 %shl, %y
+  %lshr = lshr i32 %or, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_or_lshr_comm(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_or_lshr_comm(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[Y:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = or i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %or = or i32 %y, %shl
+  %lshr = lshr i32 %or, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_or_disjoint_lshr_comm(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_or_disjoint_lshr_comm(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[Y:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = or disjoint i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %or = or disjoint i32 %y, %shl
+  %lshr = lshr i32 %or, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_xor_lshr(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_xor_lshr(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[Y:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = xor i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %xor = xor i32 %shl, %y
+  %lshr = lshr i32 %xor, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_xor_lshr_comm(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_xor_lshr_comm(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[Y:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = xor i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %xor = xor i32 %y, %shl
+  %lshr = lshr i32 %xor, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_and_lshr(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_and_lshr(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[Y:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = and i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %and = and i32 %shl, %y
+  %lshr = lshr i32 %and, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_and_lshr_comm(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_and_lshr_comm(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[Y:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = and i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %and = and i32 %y, %shl
+  %lshr = lshr i32 %and, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_lshr_and_exact(i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_lshr_and_exact(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[Y:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[TMP2]]
+;
+  %2 = shl nuw i32 %x, %c
+  %3 = and i32 %2, %y
+  %4 = lshr exact i32 %3, %c
+  ret i32 %4
+}
+
+; Negative test
+
+define i32 @shl_add_lshr_neg(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @shl_add_lshr_neg(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[SHL]], [[Z:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = lshr exact i32 [[ADD]], [[Z]]
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %shl = shl nuw i32 %x, %y
+  %add = add nuw nsw i32 %shl, %z
+  %res = lshr exact i32 %add, %z
+  ret i32 %res
+}
+
 define i32 @mul_splat_fold_wrong_mul_const(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold_wrong_mul_const(
 ; CHECK-NEXT:    [[M:%.*]] = mul nuw i32 [[X:%.*]], 65538
@@ -375,6 +600,21 @@ define i32 @mul_splat_fold_wrong_mul_const(i32 %x) {
 
 ; Negative test
 
+define i32 @shl_add_lshr_multiuse(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @shl_add_lshr_multiuse(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[SHL]], [[Z:%.*]]
+; CHECK-NEXT:    call void @use(i32 [[ADD]])
+; CHECK-NEXT:    [[RES:%.*]] = lshr exact i32 [[ADD]], [[Z]]
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %shl = shl nuw i32 %x, %y
+  %add = add nuw nsw i32 %shl, %z
+  call void @use (i32 %add)
+  %res = lshr exact i32 %add, %z
+  ret i32 %res
+}
+
 define i32 @mul_splat_fold_wrong_lshr_const(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold_wrong_lshr_const(
 ; CHECK-NEXT:    [[M:%.*]] = mul nuw i32 [[X:%.*]], 65537


        


More information about the llvm-commits mailing list