[llvm] [DAGCombiner] Preserve nuw when converting mul to shl. Use nuw in srl+shl combine. (PR #155043)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 25 14:37:10 PDT 2025


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/155043

>From 6645d2e60169c5d8bc1008ff37d983f99c76b672 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Fri, 22 Aug 2025 15:36:18 -0700
Subject: [PATCH 1/2] [DAGCombiner] Preserve nuw when converting mul to shl.
 Use nuw in srl+shl combine.

If the srl and shl use the same shift amount and the shl has the nuw
flag, the shl shifted out no non-zero bits, so both shifts can be removed.
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  10 +-
 .../RISCV/rvv/vp-vector-interleaved-access.ll | 109 +++++++-----------
 2 files changed, 53 insertions(+), 66 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 15d7e7626942d..de9bcfd410440 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4710,7 +4710,10 @@ template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
     if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
       EVT ShiftVT = getShiftAmountTy(N0.getValueType());
       SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
-      return Matcher.getNode(ISD::SHL, DL, VT, N0, Trunc);
+      SDNodeFlags Flags;
+      Flags.setNoUnsignedWrap(N->getFlags().hasNoUnsignedWrap());
+      // TODO: Preserve the nsw flag if LogBase2 isn't BitWidth - 1.
+      return Matcher.getNode(ISD::SHL, DL, VT, N0, Trunc, Flags);
     }
   }
 
@@ -11094,6 +11097,11 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
   if (N0.getOpcode() == ISD::SHL &&
       (N0.getOperand(1) == N1 || N0->hasOneUse()) &&
       TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
+    // If the shift amounts are the same and the shl doesn't shift out any
+    // non-zero bits, we can return the shl input.
+    if (N0.getOperand(1) == N1 && N0->getFlags().hasNoUnsignedWrap())
+      return N0.getOperand(0);
+
     auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
                                            ConstantSDNode *RHS) {
       const APInt &LHSC = LHS->getAPIntValue();
diff --git a/llvm/test/CodeGen/RISCV/rvv/vp-vector-interleaved-access.ll b/llvm/test/CodeGen/RISCV/rvv/vp-vector-interleaved-access.ll
index 2afb72fc71b39..13a836e8a7552 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vp-vector-interleaved-access.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vp-vector-interleaved-access.ll
@@ -5,16 +5,14 @@
 define {<vscale x 2 x i32>, <vscale x 2 x i32>} @load_factor2_v2(ptr %ptr, i32 %evl) {
 ; RV32-LABEL: load_factor2_v2:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    slli a1, a1, 1
-; RV32-NEXT:    srli a1, a1, 1
 ; RV32-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
 ; RV32-NEXT:    vlseg2e32.v v8, (a0)
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: load_factor2_v2:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    slli a1, a1, 33
-; RV64-NEXT:    srli a1, a1, 33
+; RV64-NEXT:    slli a1, a1, 32
+; RV64-NEXT:    srli a1, a1, 32
 ; RV64-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
 ; RV64-NEXT:    vlseg2e32.v v8, (a0)
 ; RV64-NEXT:    ret
@@ -142,16 +140,14 @@ merge:
 define {<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>} @load_factor4_v2(ptr %ptr, i32 %evl) {
 ; RV32-LABEL: load_factor4_v2:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    slli a1, a1, 2
-; RV32-NEXT:    srli a1, a1, 2
 ; RV32-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
 ; RV32-NEXT:    vlseg4e32.v v8, (a0)
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: load_factor4_v2:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    slli a1, a1, 34
-; RV64-NEXT:    srli a1, a1, 34
+; RV64-NEXT:    slli a1, a1, 32
+; RV64-NEXT:    srli a1, a1, 32
 ; RV64-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
 ; RV64-NEXT:    vlseg4e32.v v8, (a0)
 ; RV64-NEXT:    ret
@@ -237,16 +233,14 @@ define {<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2
 define {<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>} @load_factor8_v2(ptr %ptr, i32 %evl) {
 ; RV32-LABEL: load_factor8_v2:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    slli a1, a1, 3
-; RV32-NEXT:    srli a1, a1, 3
 ; RV32-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
 ; RV32-NEXT:    vlseg8e32.v v8, (a0)
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: load_factor8_v2:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    slli a1, a1, 35
-; RV64-NEXT:    srli a1, a1, 35
+; RV64-NEXT:    slli a1, a1, 32
+; RV64-NEXT:    srli a1, a1, 32
 ; RV64-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
 ; RV64-NEXT:    vlseg8e32.v v8, (a0)
 ; RV64-NEXT:    ret
@@ -276,16 +270,14 @@ define {<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2
 define void @store_factor2_v2(<vscale x 1 x i32> %v0, <vscale x 1 x i32> %v1, ptr %ptr, i32 %evl) {
 ; RV32-LABEL: store_factor2_v2:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    slli a1, a1, 1
-; RV32-NEXT:    srli a1, a1, 1
 ; RV32-NEXT:    vsetvli zero, a1, e32, mf2, ta, ma
 ; RV32-NEXT:    vsseg2e32.v v8, (a0)
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: store_factor2_v2:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    slli a1, a1, 33
-; RV64-NEXT:    srli a1, a1, 33
+; RV64-NEXT:    slli a1, a1, 32
+; RV64-NEXT:    srli a1, a1, 32
 ; RV64-NEXT:    vsetvli zero, a1, e32, mf2, ta, ma
 ; RV64-NEXT:    vsseg2e32.v v8, (a0)
 ; RV64-NEXT:    ret
@@ -384,8 +376,6 @@ define void @store_factor7_v2(<vscale x 1 x i32> %v0, <vscale x 1 x i32> %v1, <v
 define void @store_factor8_v2(<vscale x 1 x i32> %v0, <vscale x 1 x i32> %v1, ptr %ptr, i32 %evl) {
 ; RV32-LABEL: store_factor8_v2:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    slli a1, a1, 3
-; RV32-NEXT:    srli a1, a1, 3
 ; RV32-NEXT:    vsetvli zero, a1, e32, mf2, ta, ma
 ; RV32-NEXT:    vmv1r.v v10, v8
 ; RV32-NEXT:    vmv1r.v v11, v9
@@ -398,8 +388,8 @@ define void @store_factor8_v2(<vscale x 1 x i32> %v0, <vscale x 1 x i32> %v1, pt
 ;
 ; RV64-LABEL: store_factor8_v2:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    slli a1, a1, 35
-; RV64-NEXT:    srli a1, a1, 35
+; RV64-NEXT:    slli a1, a1, 32
+; RV64-NEXT:    srli a1, a1, 32
 ; RV64-NEXT:    vsetvli zero, a1, e32, mf2, ta, ma
 ; RV64-NEXT:    vmv1r.v v10, v8
 ; RV64-NEXT:    vmv1r.v v11, v9
@@ -418,16 +408,14 @@ define void @store_factor8_v2(<vscale x 1 x i32> %v0, <vscale x 1 x i32> %v1, pt
 define {<vscale x 2 x i32>, <vscale x 2 x i32>} @masked_load_factor2_v2(<vscale x 2 x i1> %mask, ptr %ptr, i32 %evl) {
 ; RV32-LABEL: masked_load_factor2_v2:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    slli a1, a1, 1
-; RV32-NEXT:    srli a1, a1, 1
 ; RV32-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
 ; RV32-NEXT:    vlseg2e32.v v8, (a0), v0.t
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: masked_load_factor2_v2:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    slli a1, a1, 33
-; RV64-NEXT:    srli a1, a1, 33
+; RV64-NEXT:    slli a1, a1, 32
+; RV64-NEXT:    srli a1, a1, 32
 ; RV64-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
 ; RV64-NEXT:    vlseg2e32.v v8, (a0), v0.t
 ; RV64-NEXT:    ret
@@ -445,16 +433,14 @@ define {<vscale x 2 x i32>, <vscale x 2 x i32>} @masked_load_factor2_v2(<vscale
 define {<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>} @masked_load_factor4_v2(<vscale x 2 x i1> %mask, ptr %ptr, i32 %evl) {
 ; RV32-LABEL: masked_load_factor4_v2:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    slli a1, a1, 2
-; RV32-NEXT:    srli a1, a1, 2
 ; RV32-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
 ; RV32-NEXT:    vlseg4e32.v v8, (a0), v0.t
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: masked_load_factor4_v2:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    slli a1, a1, 34
-; RV64-NEXT:    srli a1, a1, 34
+; RV64-NEXT:    slli a1, a1, 32
+; RV64-NEXT:    srli a1, a1, 32
 ; RV64-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
 ; RV64-NEXT:    vlseg4e32.v v8, (a0), v0.t
 ; RV64-NEXT:    ret
@@ -477,20 +463,17 @@ define {<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2
 define void @masked_store_factor2_v2(<vscale x 1 x i1> %mask, <vscale x 1 x i32> %v0, <vscale x 1 x i32> %v1, ptr %ptr, i32 %evl) {
 ; RV32-LABEL: masked_store_factor2_v2:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    slli a1, a1, 1
-; RV32-NEXT:    vsetivli zero, 1, e8, m1, ta, ma
-; RV32-NEXT:    vmv1r.v v9, v8
-; RV32-NEXT:    srli a1, a1, 1
 ; RV32-NEXT:    vsetvli zero, a1, e32, mf2, ta, ma
+; RV32-NEXT:    vmv1r.v v9, v8
 ; RV32-NEXT:    vsseg2e32.v v8, (a0), v0.t
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: masked_store_factor2_v2:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    slli a1, a1, 33
+; RV64-NEXT:    slli a1, a1, 32
 ; RV64-NEXT:    vsetivli zero, 1, e8, m1, ta, ma
 ; RV64-NEXT:    vmv1r.v v9, v8
-; RV64-NEXT:    srli a1, a1, 33
+; RV64-NEXT:    srli a1, a1, 32
 ; RV64-NEXT:    vsetvli zero, a1, e32, mf2, ta, ma
 ; RV64-NEXT:    vsseg2e32.v v8, (a0), v0.t
 ; RV64-NEXT:    ret
@@ -504,8 +487,6 @@ define void @masked_store_factor2_v2(<vscale x 1 x i1> %mask, <vscale x 1 x i32>
 define void @masked_load_store_factor2_v2_shared_mask(<vscale x 2 x i1> %mask, ptr %ptr, i32 %evl) {
 ; RV32-LABEL: masked_load_store_factor2_v2_shared_mask:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    slli a1, a1, 1
-; RV32-NEXT:    srli a1, a1, 1
 ; RV32-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
 ; RV32-NEXT:    vlseg2e32.v v8, (a0), v0.t
 ; RV32-NEXT:    vsseg2e32.v v8, (a0), v0.t
@@ -513,8 +494,8 @@ define void @masked_load_store_factor2_v2_shared_mask(<vscale x 2 x i1> %mask, p
 ;
 ; RV64-LABEL: masked_load_store_factor2_v2_shared_mask:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    slli a1, a1, 33
-; RV64-NEXT:    srli a1, a1, 33
+; RV64-NEXT:    slli a1, a1, 32
+; RV64-NEXT:    srli a1, a1, 32
 ; RV64-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
 ; RV64-NEXT:    vlseg2e32.v v8, (a0), v0.t
 ; RV64-NEXT:    vsseg2e32.v v8, (a0), v0.t
@@ -537,37 +518,36 @@ define i32 @masked_load_store_factor2_v2_shared_mask_extract(<vscale x 2 x i1> %
 ; RV32-NEXT:    vmv1r.v v8, v0
 ; RV32-NEXT:    slli a2, a1, 1
 ; RV32-NEXT:    vmv.v.i v9, 0
-; RV32-NEXT:    li a1, -1
+; RV32-NEXT:    li a3, -1
 ; RV32-NEXT:    vmerge.vim v10, v9, 1, v0
 ; RV32-NEXT:    vwaddu.vv v11, v10, v10
-; RV32-NEXT:    vwmaccu.vx v11, a1, v10
-; RV32-NEXT:    csrr a1, vlenb
+; RV32-NEXT:    vwmaccu.vx v11, a3, v10
+; RV32-NEXT:    csrr a3, vlenb
 ; RV32-NEXT:    vsetvli zero, a2, e8, mf2, ta, ma
 ; RV32-NEXT:    vmv.v.i v10, 0
-; RV32-NEXT:    srli a1, a1, 2
-; RV32-NEXT:    vsetvli a3, zero, e8, mf4, ta, ma
+; RV32-NEXT:    srli a3, a3, 2
+; RV32-NEXT:    vsetvli a4, zero, e8, mf4, ta, ma
 ; RV32-NEXT:    vmsne.vi v0, v11, 0
-; RV32-NEXT:    vsetvli a3, zero, e8, mf2, ta, ma
-; RV32-NEXT:    vslidedown.vx v11, v11, a1
+; RV32-NEXT:    vsetvli a4, zero, e8, mf2, ta, ma
+; RV32-NEXT:    vslidedown.vx v11, v11, a3
 ; RV32-NEXT:    vsetvli zero, a2, e8, mf2, ta, ma
 ; RV32-NEXT:    vmerge.vim v10, v10, 1, v0
-; RV32-NEXT:    vsetvli a3, zero, e8, mf4, ta, ma
+; RV32-NEXT:    vsetvli a4, zero, e8, mf4, ta, ma
 ; RV32-NEXT:    vmsne.vi v0, v11, 0
 ; RV32-NEXT:    vmerge.vim v9, v9, 1, v0
 ; RV32-NEXT:    vsetvli zero, a2, e8, mf2, ta, ma
-; RV32-NEXT:    vslideup.vx v10, v9, a1
+; RV32-NEXT:    vslideup.vx v10, v9, a3
 ; RV32-NEXT:    vmsne.vi v0, v10, 0
 ; RV32-NEXT:    vle32.v v10, (a0), v0.t
-; RV32-NEXT:    li a1, 32
+; RV32-NEXT:    li a2, 32
 ; RV32-NEXT:    vsetvli a3, zero, e32, m1, ta, ma
-; RV32-NEXT:    vnsrl.wx v13, v10, a1
-; RV32-NEXT:    vmv.x.s a1, v10
+; RV32-NEXT:    vnsrl.wx v13, v10, a2
 ; RV32-NEXT:    vnsrl.wi v12, v10, 0
-; RV32-NEXT:    srli a2, a2, 1
+; RV32-NEXT:    vmv.x.s a2, v10
 ; RV32-NEXT:    vmv1r.v v0, v8
-; RV32-NEXT:    vsetvli zero, a2, e32, m1, ta, ma
+; RV32-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
 ; RV32-NEXT:    vsseg2e32.v v12, (a0), v0.t
-; RV32-NEXT:    mv a0, a1
+; RV32-NEXT:    mv a0, a2
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: masked_load_store_factor2_v2_shared_mask_extract:
@@ -590,20 +570,21 @@ define i32 @masked_load_store_factor2_v2_shared_mask_extract(<vscale x 2 x i1> %
 ; RV64-NEXT:    vmerge.vim v10, v10, 1, v0
 ; RV64-NEXT:    vsetvli a3, zero, e8, mf4, ta, ma
 ; RV64-NEXT:    vmsne.vi v0, v11, 0
-; RV64-NEXT:    slli a3, a1, 33
 ; RV64-NEXT:    vmerge.vim v9, v9, 1, v0
-; RV64-NEXT:    vsetvli a1, zero, e8, mf2, ta, ma
+; RV64-NEXT:    vsetvli a3, zero, e8, mf2, ta, ma
 ; RV64-NEXT:    vslideup.vx v10, v9, a2
+; RV64-NEXT:    slli a2, a1, 33
 ; RV64-NEXT:    vmsne.vi v0, v10, 0
-; RV64-NEXT:    srli a1, a3, 32
-; RV64-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
+; RV64-NEXT:    srli a2, a2, 32
+; RV64-NEXT:    vsetvli zero, a2, e32, m2, ta, ma
 ; RV64-NEXT:    vle32.v v10, (a0), v0.t
-; RV64-NEXT:    li a1, 32
-; RV64-NEXT:    vsetvli a2, zero, e32, m1, ta, ma
-; RV64-NEXT:    vnsrl.wx v13, v10, a1
+; RV64-NEXT:    li a2, 32
+; RV64-NEXT:    slli a3, a1, 32
+; RV64-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
+; RV64-NEXT:    vnsrl.wx v13, v10, a2
 ; RV64-NEXT:    vmv.x.s a1, v10
 ; RV64-NEXT:    vnsrl.wi v12, v10, 0
-; RV64-NEXT:    srli a3, a3, 33
+; RV64-NEXT:    srli a3, a3, 32
 ; RV64-NEXT:    vmv1r.v v0, v8
 ; RV64-NEXT:    vsetvli zero, a3, e32, m1, ta, ma
 ; RV64-NEXT:    vsseg2e32.v v12, (a0), v0.t
@@ -624,8 +605,6 @@ define i32 @masked_load_store_factor2_v2_shared_mask_extract(<vscale x 2 x i1> %
 define void @masked_store_factor4_v2(<vscale x 1 x i1> %mask, <vscale x 1 x i32> %v0, <vscale x 1 x i32> %v1, ptr %ptr, i32 %evl) {
 ; RV32-LABEL: masked_store_factor4_v2:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    slli a1, a1, 2
-; RV32-NEXT:    srli a1, a1, 2
 ; RV32-NEXT:    vsetvli zero, a1, e32, mf2, ta, ma
 ; RV32-NEXT:    vmv1r.v v10, v8
 ; RV32-NEXT:    vmv1r.v v11, v9
@@ -634,8 +613,8 @@ define void @masked_store_factor4_v2(<vscale x 1 x i1> %mask, <vscale x 1 x i32>
 ;
 ; RV64-LABEL: masked_store_factor4_v2:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    slli a1, a1, 34
-; RV64-NEXT:    srli a1, a1, 34
+; RV64-NEXT:    slli a1, a1, 32
+; RV64-NEXT:    srli a1, a1, 32
 ; RV64-NEXT:    vsetvli zero, a1, e32, mf2, ta, ma
 ; RV64-NEXT:    vmv1r.v v10, v8
 ; RV64-NEXT:    vmv1r.v v11, v9

>From 60d23683945ebdb068ced7c6ab93eb58494a5213 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Mon, 25 Aug 2025 14:36:45 -0700
Subject: [PATCH 2/2] fixup! Address review comment

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 68 +++++++++----------
 1 file changed, 34 insertions(+), 34 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index de9bcfd410440..65e06c1bf30e7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11092,43 +11092,43 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
     }
   }
 
-  // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2), MASK) or
-  //                               (and (srl x, (sub c2, c1), MASK)
-  if (N0.getOpcode() == ISD::SHL &&
-      (N0.getOperand(1) == N1 || N0->hasOneUse()) &&
-      TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
-    // If the shift amounts are the same and the shl doesn't shift out any
-    // non-zero bits, we can return the shl input.
+  if (N0.getOpcode() == ISD::SHL) {
+    // fold (srl (shl nuw x, c), c) -> x
     if (N0.getOperand(1) == N1 && N0->getFlags().hasNoUnsignedWrap())
       return N0.getOperand(0);
 
-    auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
-                                           ConstantSDNode *RHS) {
-      const APInt &LHSC = LHS->getAPIntValue();
-      const APInt &RHSC = RHS->getAPIntValue();
-      return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
-             LHSC.getZExtValue() <= RHSC.getZExtValue();
-    };
-    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
-                                  /*AllowUndefs*/ false,
-                                  /*AllowTypeMismatch*/ true)) {
-      SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
-      SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
-      SDValue Mask = DAG.getAllOnesConstant(DL, VT);
-      Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N01);
-      Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, Diff);
-      SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
-      return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
-    }
-    if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
-                                  /*AllowUndefs*/ false,
-                                  /*AllowTypeMismatch*/ true)) {
-      SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
-      SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
-      SDValue Mask = DAG.getAllOnesConstant(DL, VT);
-      Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N1);
-      SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
-      return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
+    // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2)), MASK) or
+    //                               (and (srl x, (sub c2, c1)), MASK)
+    if ((N0.getOperand(1) == N1 || N0->hasOneUse()) &&
+        TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
+      auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
+                                             ConstantSDNode *RHS) {
+        const APInt &LHSC = LHS->getAPIntValue();
+        const APInt &RHSC = RHS->getAPIntValue();
+        return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
+               LHSC.getZExtValue() <= RHSC.getZExtValue();
+      };
+      if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
+                                    /*AllowUndefs*/ false,
+                                    /*AllowTypeMismatch*/ true)) {
+        SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
+        SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
+        SDValue Mask = DAG.getAllOnesConstant(DL, VT);
+        Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N01);
+        Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, Diff);
+        SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
+        return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
+      }
+      if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
+                                    /*AllowUndefs*/ false,
+                                    /*AllowTypeMismatch*/ true)) {
+        SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
+        SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
+        SDValue Mask = DAG.getAllOnesConstant(DL, VT);
+        Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N1);
+        SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
+        return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
+      }
     }
   }
 



More information about the llvm-commits mailing list