[llvm] [InstCombine] Remove shl if we only demand known signbits of shift source (PR #79014)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 22 09:20:26 PST 2024


https://github.com/ParkHanbum created https://github.com/llvm/llvm-project/pull/79014


This patch resolves the TODO written in commit
https://github.com/llvm/llvm-project/commit/5909c678831f3a5c1669f6906f777d4ec4532fa1

>From 8c3412097cbc71a4813caf61f2ecbf9933c9d0f4 Mon Sep 17 00:00:00 2001
From: Hanbum Park <kese111 at gmail.com>
Date: Tue, 23 Jan 2024 01:38:18 +0900
Subject: [PATCH 1/2] [InstCombine] Add test for removing shl if we only demand
 known signbits of shift

---
 .../test/Transforms/InstCombine/shl-demand.ll | 121 ++++++++++++++++++
 1 file changed, 121 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/shl-demand.ll b/llvm/test/Transforms/InstCombine/shl-demand.ll
index 85752890b4b80da..024d31c32c8d621 100644
--- a/llvm/test/Transforms/InstCombine/shl-demand.ll
+++ b/llvm/test/Transforms/InstCombine/shl-demand.ll
@@ -1,6 +1,127 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt -passes=instcombine -S < %s | FileCheck %s
 
+
+; If we only want bits that already match the signbit then we don't need to shift.
+; https://alive2.llvm.org/ce/z/WJBPVt
+define i32 @src_srem_shl_demand_max_signbit(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_signbit(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 2
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 30
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -2147483648
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 2           ; srem  = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSD
+  %shl = shl i32 %srem, 30          ; shl   = SD000000000000000000000000000000
+  %mask = and i32 %shl, -2147483648 ; mask  = 10000000000000000000000000000000
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_min_signbit(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_min_signbit(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 1073741823
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -2147483648
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 1073741823  ; srem  = SSDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
+  %shl = shl i32 %srem, 1           ; shl   = SDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD0
+  %mask = and i32 %shl, -2147483648 ; mask  = 10000000000000000000000000000000
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_max_mask(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_mask(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 2
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -4
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 2           ; srem = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSD
+  %shl = shl i32 %srem, 1           ; shl  = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSD0
+  %mask = and i32 %shl, -4          ; mask = 11111111111111111111111111111100
+  ret i32 %mask
+}
+
+; Negative test - mask demands non-signbit from shift source
+define i32 @src_srem_shl_demand_max_signbit_mask_hit_first_demand(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_signbit_mask_hit_first_demand(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 4
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 29
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -1073741824
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 4           ; srem  = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSDD
+  %shl = shl i32 %srem, 29          ; shl   = SDD00000000000000000000000000000
+  %mask = and i32 %shl, -1073741824 ; mask  = 11000000000000000000000000000000
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_min_signbit_mask_hit_last_demand(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_min_signbit_mask_hit_last_demand(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 536870912
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -1073741822
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 536870912   ; srem  = SSSDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
+  %shl = shl i32 %srem, 1           ; shl   = SSDDDDDDDDDDDDDDDDDDDDDDDDDDDDD0
+  %mask = and i32 %shl, -1073741822 ; mask  = 11000000000000000000000000000010
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_eliminate_signbit(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_eliminate_signbit(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 1073741824
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], 2
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 1073741824  ; srem  = SSDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
+  %shl = shl i32 %srem, 1           ; shl   = DDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD0
+  %mask = and i32 %shl, 2           ; mask  = 00000000000000000000000000000010
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_max_mask_hit_demand(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_mask_hit_demand(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 4
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -4
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 4           ; srem = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSDD
+  %shl= shl i32 %srem, 1            ; shl  = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSDD0
+  %mask = and i32 %shl, -4          ; mask = 11111111111111111111111111111100
+  ret i32 %mask
+}
+
+define <2 x i32> @src_srem_shl_mask_vector(<2 x i32> %a0) {
+; CHECK-LABEL: @src_srem_shl_mask_vector(
+; CHECK-NEXT:    [[SREM:%.*]] = srem <2 x i32> [[A0:%.*]], <i32 4, i32 4>
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw <2 x i32> [[SREM]], <i32 29, i32 29>
+; CHECK-NEXT:    [[MASK:%.*]] = and <2 x i32> [[SHL]], <i32 -1073741824, i32 -1073741824>
+; CHECK-NEXT:    ret <2 x i32> [[MASK]]
+;
+  %srem = srem <2 x i32> %a0, <i32 4, i32 4>
+  %shl = shl <2 x i32> %srem, <i32 29, i32 29>
+  %mask = and <2 x i32> %shl, <i32 -1073741824, i32 -1073741824>
+  ret <2 x i32> %mask
+}
+
+define <2 x i32> @src_srem_shl_mask_vector_nonconstant(<2 x i32> %a0, <2 x i32> %a1) {
+; CHECK-LABEL: @src_srem_shl_mask_vector_nonconstant(
+; CHECK-NEXT:    [[SREM:%.*]] = srem <2 x i32> [[A0:%.*]], <i32 4, i32 4>
+; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i32> [[SREM]], [[A1:%.*]]
+; CHECK-NEXT:    [[MASK:%.*]] = and <2 x i32> [[SHL]], <i32 -1073741824, i32 -1073741824>
+; CHECK-NEXT:    ret <2 x i32> [[MASK]]
+;
+  %srem = srem <2 x i32> %a0, <i32 4, i32 4>
+  %shl = shl <2 x i32> %srem, %a1
+  %mask = and <2 x i32> %shl, <i32 -1073741824, i32 -1073741824>
+  ret <2 x i32> %mask
+}
+
 define i16 @sext_shl_trunc_same_size(i16 %x, i32 %y) {
 ; CHECK-LABEL: @sext_shl_trunc_same_size(
 ; CHECK-NEXT:    [[CONV1:%.*]] = zext i16 [[X:%.*]] to i32

>From 011f822b9bc6b945b5228f5759d35939477ed6f8 Mon Sep 17 00:00:00 2001
From: Hanbum Park <kese111 at gmail.com>
Date: Tue, 23 Jan 2024 01:42:50 +0900
Subject: [PATCH 2/2] [InstCombine] Remove shl if we only demand known signbits
 of shift source

This patch resolves the TODO written in commit
5909c678831f3a5c1669f6906f777d4ec4532fa1

proof: https://alive2.llvm.org/ce/z/WJBPVt
---
 .../InstCombineSimplifyDemanded.cpp           | 40 +++++++++++--------
 .../test/Transforms/InstCombine/shl-demand.ll |  9 ++---
 2 files changed, 26 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index a8a5f9831e15e3a..6ef00bab5307ec4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -640,25 +640,31 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
                                                     DemandedMask, Known))
             return R;
 
-      // TODO: If we only want bits that already match the signbit then we don't
+      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth - 1);
+      // If we only want bits that already match the signbit then we don't
       // need to shift.
+      if (DemandedMask.countr_zero() >= ShiftAmt) {
+        unsigned NumLowDemandedBits = BitWidth - DemandedMask.countr_zero();
+        unsigned SignBits =
+            ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
+        if (SignBits > ShiftAmt && SignBits - ShiftAmt >= NumLowDemandedBits)
+          return I->getOperand(0);
 
-      // If we can pre-shift a right-shifted constant to the left without
-      // losing any high bits amd we don't demand the low bits, then eliminate
-      // the left-shift:
-      // (C >> X) << LeftShiftAmtC --> (C << RightShiftAmtC) >> X
-      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
-      Value *X;
-      Constant *C;
-      if (DemandedMask.countr_zero() >= ShiftAmt &&
-          match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
-        Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
-        Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C,
-                                                      LeftShiftAmtC, DL);
-        if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC, LeftShiftAmtC,
-                                         DL) == C) {
-          Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
-          return InsertNewInstWith(Lshr, I->getIterator());
+        // If we can pre-shift a right-shifted constant to the left without
+        // losing any high bits and we don't demand the low bits, then eliminate
+        // the left-shift:
+        // (C >> X) << LeftShiftAmtC --> (C << RightShiftAmtC) >> X
+        Value *X;
+        Constant *C;
+        if (match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
+          Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
+          Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C,
+                                                        LeftShiftAmtC, DL);
+          if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC,
+                                           LeftShiftAmtC, DL) == C) {
+            Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
+            return InsertNewInstWith(Lshr, I->getIterator());
+          }
         }
       }
 
diff --git a/llvm/test/Transforms/InstCombine/shl-demand.ll b/llvm/test/Transforms/InstCombine/shl-demand.ll
index 024d31c32c8d621..26175ebbe153588 100644
--- a/llvm/test/Transforms/InstCombine/shl-demand.ll
+++ b/llvm/test/Transforms/InstCombine/shl-demand.ll
@@ -7,8 +7,7 @@
 define i32 @src_srem_shl_demand_max_signbit(i32 %a0) {
 ; CHECK-LABEL: @src_srem_shl_demand_max_signbit(
 ; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 2
-; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 30
-; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -2147483648
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SREM]], -2147483648
 ; CHECK-NEXT:    ret i32 [[MASK]]
 ;
   %srem = srem i32 %a0, 2           ; srem  = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSD
@@ -20,8 +19,7 @@ define i32 @src_srem_shl_demand_max_signbit(i32 %a0) {
 define i32 @src_srem_shl_demand_min_signbit(i32 %a0) {
 ; CHECK-LABEL: @src_srem_shl_demand_min_signbit(
 ; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 1073741823
-; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
-; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -2147483648
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SREM]], -2147483648
 ; CHECK-NEXT:    ret i32 [[MASK]]
 ;
   %srem = srem i32 %a0, 1073741823  ; srem  = SSDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
@@ -33,8 +31,7 @@ define i32 @src_srem_shl_demand_min_signbit(i32 %a0) {
 define i32 @src_srem_shl_demand_max_mask(i32 %a0) {
 ; CHECK-LABEL: @src_srem_shl_demand_max_mask(
 ; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 2
-; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
-; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -4
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SREM]], -4
 ; CHECK-NEXT:    ret i32 [[MASK]]
 ;
   %srem = srem i32 %a0, 2           ; srem = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSD



More information about the llvm-commits mailing list