[llvm] [InstCombine] Fold ((X << nuw Z) binop nuw Y) >>u Z --> X binop nuw (Y >>u Z) (PR #88193)

via llvm-commits llvm-commits at lists.llvm.org
Mon May 6 17:38:09 PDT 2024


https://github.com/AtariDreams updated https://github.com/llvm/llvm-project/pull/88193

>From e6b64d12e0bac273cf08d29579b1c7c2edccd7d1 Mon Sep 17 00:00:00 2001
From: Rose <gfunni234 at gmail.com>
Date: Sun, 5 May 2024 21:43:20 -0400
Subject: [PATCH 1/2] [InstCombine] Pre-commit tests (NFC)

---
 llvm/test/Transforms/InstCombine/lshr.ll | 146 +++++++++++++++++++++++
 1 file changed, 146 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/lshr.ll b/llvm/test/Transforms/InstCombine/lshr.ll
index 7d611ba188d6b4..d320f4dab77801 100644
--- a/llvm/test/Transforms/InstCombine/lshr.ll
+++ b/llvm/test/Transforms/InstCombine/lshr.ll
@@ -163,6 +163,18 @@ define <2 x i8> @lshr_exact_splat_vec(<2 x i8> %x) {
   ret <2 x i8> %lshr
 }
 
+define <2 x i8> @lshr_exact_splat_vec_nuw(<2 x i8> %x) { ; nuw variant of @lshr_exact_splat_vec: nuw implies x <= 62, so the 'and 63' in the current output is redundant
+; CHECK-LABEL: @lshr_exact_splat_vec_nuw(
+; CHECK-NEXT:    [[TMP1:%.*]] = add <2 x i8> [[X:%.*]], <i8 1, i8 1>
+; CHECK-NEXT:    [[LSHR:%.*]] = and <2 x i8> [[TMP1]], <i8 63, i8 63>
+; CHECK-NEXT:    ret <2 x i8> [[LSHR]]
+;
+  %shl = shl nuw <2 x i8> %x, <i8 2, i8 2>
+  %add = add nuw <2 x i8> %shl, <i8 4, i8 4>
+  %lshr = lshr <2 x i8> %add, <i8 2, i8 2> ; NOTE(review): no exact flag here, despite the test name
+  ret <2 x i8> %lshr
+}
+
 define i8 @shl_add(i8 %x, i8 %y) {
 ; CHECK-LABEL: @shl_add(
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i8 [[Y:%.*]], 2
@@ -360,8 +372,127 @@ define <3 x i14> @mul_splat_fold_vec(<3 x i14> %x) {
   ret <3 x i14> %t
 }
 
+define i32 @shl_add_lshr_flag_preservation(i32 %x, i32 %c, i32 %y) { ; base pattern plus nsw on the add and exact on the lshr, to observe which flags a fold keeps
+; CHECK-LABEL: @shl_add_lshr_flag_preservation(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[SHL]], [[Y:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr exact i32 [[ADD]], [[C]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %add = add nuw nsw i32 %shl, %y
+  %lshr = lshr exact i32 %add, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_add_lshr(i32 %x, i32 %c, i32 %y) { ; base pattern: ((x << nuw c) + nuw y) >>u c
+; CHECK-LABEL: @shl_add_lshr(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw i32 [[SHL]], [[Y:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[ADD]], [[C]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %add = add nuw i32 %shl, %y
+  %lshr = lshr i32 %add, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_add_lshr_comm(i32 %x, i32 %c, i32 %y) { ; commuted: the shl is the second operand of the add
+; CHECK-LABEL: @shl_add_lshr_comm(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw i32 [[SHL]], [[Y:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[ADD]], [[C]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %add = add nuw i32 %y, %shl
+  %lshr = lshr i32 %add, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_sub_lshr(i32 %x, i32 %c, i32 %y) { ; NOTE(review): must not fold to x - (y >>u c) without exact: x=1,c=1,y=1 gives (2-1)>>1 = 0 but 1-(1>>1) = 1
+; CHECK-LABEL: @shl_sub_lshr(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[SUB:%.*]] = sub nuw i32 [[SHL]], [[Y:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[SUB]], [[C]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %sub = sub nuw i32 %shl, %y
+  %lshr = lshr i32 %sub, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_or_lshr(i32 %x, i32 %c, i32 %y) { ; no flag needed on the or: lshr distributes over or, and (x << nuw c) >>u c == x
+; CHECK-LABEL: @shl_or_lshr(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[OR:%.*]] = or i32 [[SHL]], [[Y:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[OR]], [[C]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %or = or i32 %shl, %y
+  %lshr = lshr i32 %or, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_or_disjoint_lshr(i32 %x, i32 %c, i32 %y) { ; same pattern as @shl_or_lshr, with a disjoint or
+; CHECK-LABEL: @shl_or_disjoint_lshr(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[OR:%.*]] = or disjoint i32 [[SHL]], [[Y:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[OR]], [[C]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %or = or disjoint i32 %shl, %y
+  %lshr = lshr i32 %or, %c
+  ret i32 %lshr
+}
+
+define i32 @shl_xor_lshr(i32 %x, i32 %c, i32 %y) { ; no flag needed on the xor: lshr distributes over xor
+; CHECK-LABEL: @shl_xor_lshr(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[SHL]], [[Y:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[XOR]], [[C]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %xor = xor i32 %shl, %y
+  %lshr = lshr i32 %xor, %c
+  ret i32 %lshr
+}
+
 ; Negative test
 
+define i32 @shl_and_lshr(i32 %x, i32 %c, i32 %y) { ; NOTE(review): lshr also distributes over and, so x & (y >>u c) would be legal; "negative" only pins the fold's intended scope
+; CHECK-LABEL: @shl_and_lshr(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[AND:%.*]] = and i32 [[SHL]], [[Y:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[AND]], [[C]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %shl = shl nuw i32 %x, %c
+  %and = and i32 %shl, %y
+  %lshr = lshr i32 %and, %c
+  ret i32 %lshr
+}
+
+; Negative test
+
+define i32 @shl_add_lshr_neg(i32 %x, i32 %y, i32 %z) { ; negative: the shl amount (%y) differs from the lshr amount (%z)
+; CHECK-LABEL: @shl_add_lshr_neg(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[SHL]], [[Z:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = lshr exact i32 [[ADD]], [[Z]]
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %shl = shl nuw i32 %x, %y
+  %add = add nuw nsw i32 %shl, %z
+  %res = lshr exact i32 %add, %z
+  ret i32 %res
+}
+
 define i32 @mul_splat_fold_wrong_mul_const(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold_wrong_mul_const(
 ; CHECK-NEXT:    [[M:%.*]] = mul nuw i32 [[X:%.*]], 65538
@@ -375,6 +506,21 @@ define i32 @mul_splat_fold_wrong_mul_const(i32 %x) {
 
 ; Negative test
 
+define i32 @shl_add_lshr_multiuse(i32 %x, i32 %c, i32 %y) { ; shift amounts match, so only the extra use of %add blocks the fold
+; CHECK-LABEL: @shl_add_lshr_multiuse(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[SHL]], [[Y:%.*]]
+; CHECK-NEXT:    call void @use(i32 [[ADD]])
+; CHECK-NEXT:    [[RES:%.*]] = lshr exact i32 [[ADD]], [[C]]
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %shl = shl nuw i32 %x, %c
+  %add = add nuw nsw i32 %shl, %y
+  call void @use (i32 %add)
+  %res = lshr exact i32 %add, %c
+  ret i32 %res
+}
+
 define i32 @mul_splat_fold_wrong_lshr_const(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold_wrong_lshr_const(
 ; CHECK-NEXT:    [[M:%.*]] = mul nuw i32 [[X:%.*]], 65537

>From f4acce6e26dd39a48945499fb1b2d8646e122a92 Mon Sep 17 00:00:00 2001
From: Rose <gfunni234 at gmail.com>
Date: Mon, 6 May 2024 20:36:28 -0400
Subject: [PATCH 2/2] [InstCombine] Fold ((X << nuw Z) binop nuw Y) >>u Z --> X
 binop nuw (Y >>u Z)

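Note: the code change in this patch implements
lshr (mul nuw X, 2^N+1), N --> add nuw X, (lshr X, N), plus the analogous
nsw form; the shl/binop tests pre-committed in patch 1 are not changed by it.

Proofs: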
https://alive2.llvm.org/ce/z/N9dRzP
https://alive2.llvm.org/ce/z/Nc2VMk
---
 .../InstCombine/InstCombineShifts.cpp         | 37 +++++++++++++++----
 llvm/test/Transforms/InstCombine/lshr.ll      |  4 +-
 2 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 1cb21a1d81af4b..cced403ce0b244 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1411,13 +1411,24 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
 
     const APInt *MulC;
     if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC)))) {
-      // Look for a "splat" mul pattern - it replicates bits across each half of
-      // a value, so a right shift is just a mask of the low bits:
-      // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1
-      // TODO: Generalize to allow more than just half-width shifts?
-      if (BitWidth > 2 && ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() &&
-          MulC->logBase2() == ShAmtC)
-        return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2));
+      if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+          MulC->logBase2() == ShAmtC) {
+        // Look for a "splat" mul pattern - it replicates bits across each half
+        // of a value, so a right shift is just a mask of the low bits:
+        // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1
+        if (ShAmtC * 2 == BitWidth)
+          return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2));
+
+        // lshr (mul nuw (X, 2^N + 1)), N -> add nuw (X, lshr(X, N))
+        if (Op0->hasOneUse()) {
+          auto *NewAdd = BinaryOperator::CreateNUWAdd(
+              X, Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "",
+                                    I.isExact()));
+          NewAdd->setHasNoSignedWrap(
+              cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap());
+          return NewAdd;
+        }
+      }
 
       // The one-use check is not strictly necessary, but codegen may not be
       // able to invert the transform and perf may suffer with an extra mul
@@ -1437,6 +1448,16 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
       }
     }
 
+    // lshr (mul nsw (X, 2^N + 1)), N -> add nsw (X, lshr(X, N))
+    if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_APInt(MulC))))) {
+      if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+          MulC->logBase2() == ShAmtC) {
+        return BinaryOperator::CreateNSWAdd(
+            X, Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "",
+                                  I.isExact()));
+      }
+    }
+
     // Try to narrow bswap.
     // In the case where the shift amount equals the bitwidth difference, the
     // shift is eliminated.
@@ -1681,4 +1702,4 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
   }
 
   return nullptr;
-}
+}
\ No newline at end of file
diff --git a/llvm/test/Transforms/InstCombine/lshr.ll b/llvm/test/Transforms/InstCombine/lshr.ll
index d320f4dab77801..052e67410a563f 100644
--- a/llvm/test/Transforms/InstCombine/lshr.ll
+++ b/llvm/test/Transforms/InstCombine/lshr.ll
@@ -536,8 +536,8 @@ define i32 @mul_splat_fold_wrong_lshr_const(i32 %x) {
 
 define i32 @mul_splat_fold_no_nuw(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold_no_nuw(
-; CHECK-NEXT:    [[M:%.*]] = mul nsw i32 [[X:%.*]], 65537
-; CHECK-NEXT:    [[T:%.*]] = lshr i32 [[M]], 16
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[X:%.*]], 16
+; CHECK-NEXT:    [[T:%.*]] = add nsw i32 [[TMP1]], [[X]]
 ; CHECK-NEXT:    ret i32 [[T]]
 ;
   %m = mul nsw i32 %x, 65537



More information about the llvm-commits mailing list