[llvm] [DAGCombiner] Attempt to fold 'add' nodes to funnel-shift or rotate (PR #125612)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 1 16:32:07 PDT 2025


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/125612

>From dc0e5b7374b09339a8e7d9970ca423c0a1f6c321 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Tue, 4 Feb 2025 17:39:29 +0000
Subject: [PATCH 1/2] [DAGCombiner] Attempt to fold 'add' nodes to funnel-shift
 or rotate

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 86 +++++++++----------
 llvm/test/CodeGen/AMDGPU/rotate-add.ll        | 30 ++-----
 llvm/test/CodeGen/ARM/rotate-add.ll           |  7 +-
 llvm/test/CodeGen/NVPTX/rotate-add.ll         | 42 +++------
 llvm/test/CodeGen/X86/rotate-add.ll           | 55 +++---------
 5 files changed, 77 insertions(+), 143 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index dc5c5f38e3bd8..0d74199bd4b36 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -649,14 +649,15 @@ namespace {
                                bool DemandHighBits = true);
     SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
     SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
-                              SDValue InnerPos, SDValue InnerNeg, bool HasPos,
-                              unsigned PosOpcode, unsigned NegOpcode,
-                              const SDLoc &DL);
+                              SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
+                              bool HasPos, unsigned PosOpcode,
+                              unsigned NegOpcode, const SDLoc &DL);
     SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
-                              SDValue InnerPos, SDValue InnerNeg, bool HasPos,
-                              unsigned PosOpcode, unsigned NegOpcode,
-                              const SDLoc &DL);
-    SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
+                              SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
+                              bool HasPos, unsigned PosOpcode,
+                              unsigned NegOpcode, const SDLoc &DL);
+    SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
+                        bool FromAdd);
     SDValue MatchLoadCombine(SDNode *N);
     SDValue mergeTruncStores(StoreSDNode *N);
     SDValue reduceLoadWidth(SDNode *N);
@@ -2986,6 +2987,9 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
   if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
     return V;
 
+  if (SDValue V = MatchRotate(N0, N1, SDLoc(N), /*FromAdd=*/true))
+    return V;
+
   // Try to match AVGFLOOR fixedwidth pattern
   if (SDValue V = foldAddToAvg(N, DL))
     return V;
@@ -8175,7 +8179,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
       return V;
 
   // See if this is some rotate idiom.
-  if (SDValue Rot = MatchRotate(N0, N1, DL))
+  if (SDValue Rot = MatchRotate(N0, N1, DL, /*FromAdd=*/false))
     return Rot;
 
   if (SDValue Load = MatchLoadCombine(N))
@@ -8364,7 +8368,7 @@ static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
 // The IsRotate flag should be set when the LHS of both shifts is the same.
 // Otherwise if matching a general funnel shift, it should be clear.
 static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
-                           SelectionDAG &DAG, bool IsRotate) {
+                           SelectionDAG &DAG, bool IsRotate, bool FromAdd) {
   const auto &TLI = DAG.getTargetLoweringInfo();
   // If EltSize is a power of 2 then:
   //
@@ -8403,7 +8407,7 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
   // NOTE: We can only do this when matching operations which won't modify the
   // least Log2(EltSize) significant bits and not a general funnel shift.
   unsigned MaskLoBits = 0;
-  if (IsRotate && isPowerOf2_64(EltSize)) {
+  if (IsRotate && !FromAdd && isPowerOf2_64(EltSize)) {
     unsigned Bits = Log2_64(EltSize);
     unsigned NegBits = Neg.getScalarValueSizeInBits();
     if (NegBits >= Bits) {
@@ -8486,9 +8490,9 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
 // Neg with outer conversions stripped away.
 SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
                                        SDValue Neg, SDValue InnerPos,
-                                       SDValue InnerNeg, bool HasPos,
-                                       unsigned PosOpcode, unsigned NegOpcode,
-                                       const SDLoc &DL) {
+                                       SDValue InnerNeg, bool FromAdd,
+                                       bool HasPos, unsigned PosOpcode,
+                                       unsigned NegOpcode, const SDLoc &DL) {
   // fold (or (shl x, (*ext y)),
   //          (srl x, (*ext (sub 32, y)))) ->
   //   (rotl x, y) or (rotr x, (sub 32, y))
@@ -8498,10 +8502,9 @@ SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
   //   (rotr x, y) or (rotl x, (sub 32, y))
   EVT VT = Shifted.getValueType();
   if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
-                     /*IsRotate*/ true)) {
+                     /*IsRotate*/ true, FromAdd))
     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
                        HasPos ? Pos : Neg);
-  }
 
   return SDValue();
 }
@@ -8514,9 +8517,9 @@ SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
 // TODO: Merge with MatchRotatePosNeg.
 SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
                                        SDValue Neg, SDValue InnerPos,
-                                       SDValue InnerNeg, bool HasPos,
-                                       unsigned PosOpcode, unsigned NegOpcode,
-                                       const SDLoc &DL) {
+                                       SDValue InnerNeg, bool FromAdd,
+                                       bool HasPos, unsigned PosOpcode,
+                                       unsigned NegOpcode, const SDLoc &DL) {
   EVT VT = N0.getValueType();
   unsigned EltBits = VT.getScalarSizeInBits();
 
@@ -8527,10 +8530,10 @@ SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
   // fold (or (shl x0, (*ext (sub 32, y))),
   //          (srl x1, (*ext y))) ->
   //   (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
-  if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1)) {
+  if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1,
+                     FromAdd))
     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
                        HasPos ? Pos : Neg);
-  }
 
   // Matching the shift+xor cases, we can't easily use the xor'd shift amount
   // so for now just use the PosOpcode case if its legal.
@@ -8569,11 +8572,12 @@ SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
   return SDValue();
 }
 
-// MatchRotate - Handle an 'or' of two operands.  If this is one of the many
-// idioms for rotate, and if the target supports rotation instructions, generate
-// a rot[lr]. This also matches funnel shift patterns, similar to rotation but
-// with different shifted sources.
-SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
+// MatchRotate - Handle an 'or' or 'add' of two operands.  If this is one of the
+// many idioms for rotate, and if the target supports rotation instructions,
+// generate a rot[lr]. This also matches funnel shift patterns, similar to
+// rotation but with different shifted sources.
+SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
+                                 bool FromAdd) {
   EVT VT = LHS.getValueType();
 
   // The target must have at least one rotate/funnel flavor.
@@ -8600,9 +8604,9 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
   if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
       LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
     assert(LHS.getValueType() == RHS.getValueType());
-    if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
+    if (SDValue Rot =
+            MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL, FromAdd))
       return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
-    }
   }
 
   // Match "(X shl/srl V1) & V2" where V2 may not be present.
@@ -8782,29 +8786,25 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
   }
 
   if (IsRotate && (HasROTL || HasROTR)) {
-    SDValue TryL =
-        MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0,
-                          RExtOp0, HasROTL, ISD::ROTL, ISD::ROTR, DL);
-    if (TryL)
+    if (SDValue TryL = MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt,
+                                         LExtOp0, RExtOp0, FromAdd, HasROTL,
+                                         ISD::ROTL, ISD::ROTR, DL))
       return TryL;
 
-    SDValue TryR =
-        MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0,
-                          LExtOp0, HasROTR, ISD::ROTR, ISD::ROTL, DL);
-    if (TryR)
+    if (SDValue TryR = MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
+                                         RExtOp0, LExtOp0, FromAdd, HasROTR,
+                                         ISD::ROTR, ISD::ROTL, DL))
       return TryR;
   }
 
-  SDValue TryL =
-      MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt,
-                        LExtOp0, RExtOp0, HasFSHL, ISD::FSHL, ISD::FSHR, DL);
-  if (TryL)
+  if (SDValue TryL = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt,
+                                       RHSShiftAmt, LExtOp0, RExtOp0, FromAdd,
+                                       HasFSHL, ISD::FSHL, ISD::FSHR, DL))
     return TryL;
 
-  SDValue TryR =
-      MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
-                        RExtOp0, LExtOp0, HasFSHR, ISD::FSHR, ISD::FSHL, DL);
-  if (TryR)
+  if (SDValue TryR = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt,
+                                       LHSShiftAmt, RExtOp0, LExtOp0, FromAdd,
+                                       HasFSHR, ISD::FSHR, ISD::FSHL, DL))
     return TryR;
 
   return SDValue();
diff --git a/llvm/test/CodeGen/AMDGPU/rotate-add.ll b/llvm/test/CodeGen/AMDGPU/rotate-add.ll
index faf89f41bdf86..53a49c9a21e2c 100644
--- a/llvm/test/CodeGen/AMDGPU/rotate-add.ll
+++ b/llvm/test/CodeGen/AMDGPU/rotate-add.ll
@@ -44,19 +44,15 @@ define i32 @test_rotl_var(i32 %x, i32 %y) {
 ; SI-LABEL: test_rotl_var:
 ; SI:       ; %bb.0:
 ; SI-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; SI-NEXT:    v_lshlrev_b32_e32 v2, v1, v0
 ; SI-NEXT:    v_sub_i32_e32 v1, vcc, 32, v1
-; SI-NEXT:    v_lshrrev_b32_e32 v0, v1, v0
-; SI-NEXT:    v_add_i32_e32 v0, vcc, v2, v0
+; SI-NEXT:    v_alignbit_b32 v0, v0, v0, v1
 ; SI-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; VI-LABEL: test_rotl_var:
 ; VI:       ; %bb.0:
 ; VI-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; VI-NEXT:    v_lshlrev_b32_e32 v2, v1, v0
 ; VI-NEXT:    v_sub_u32_e32 v1, vcc, 32, v1
-; VI-NEXT:    v_lshrrev_b32_e32 v0, v1, v0
-; VI-NEXT:    v_add_u32_e32 v0, vcc, v2, v0
+; VI-NEXT:    v_alignbit_b32 v0, v0, v0, v1
 ; VI-NEXT:    s_setpc_b64 s[30:31]
   %shl = shl i32 %x, %y
   %sub = sub i32 32, %y
@@ -69,19 +65,13 @@ define i32 @test_rotr_var(i32 %x, i32 %y) {
 ; SI-LABEL: test_rotr_var:
 ; SI:       ; %bb.0:
 ; SI-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; SI-NEXT:    v_lshrrev_b32_e32 v2, v1, v0
-; SI-NEXT:    v_sub_i32_e32 v1, vcc, 32, v1
-; SI-NEXT:    v_lshlrev_b32_e32 v0, v1, v0
-; SI-NEXT:    v_add_i32_e32 v0, vcc, v2, v0
+; SI-NEXT:    v_alignbit_b32 v0, v0, v0, v1
 ; SI-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; VI-LABEL: test_rotr_var:
 ; VI:       ; %bb.0:
 ; VI-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; VI-NEXT:    v_lshrrev_b32_e32 v2, v1, v0
-; VI-NEXT:    v_sub_u32_e32 v1, vcc, 32, v1
-; VI-NEXT:    v_lshlrev_b32_e32 v0, v1, v0
-; VI-NEXT:    v_add_u32_e32 v0, vcc, v2, v0
+; VI-NEXT:    v_alignbit_b32 v0, v0, v0, v1
 ; VI-NEXT:    s_setpc_b64 s[30:31]
   %shr = lshr i32 %x, %y
   %sub = sub i32 32, %y
@@ -174,21 +164,13 @@ define i32 @test_fshr_special_case(i32 %x0, i32 %x1, i32 %y) {
 ; SI-LABEL: test_fshr_special_case:
 ; SI:       ; %bb.0:
 ; SI-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; SI-NEXT:    v_lshrrev_b32_e32 v1, v2, v1
-; SI-NEXT:    v_lshlrev_b32_e32 v0, 1, v0
-; SI-NEXT:    v_xor_b32_e32 v2, 31, v2
-; SI-NEXT:    v_lshlrev_b32_e32 v0, v2, v0
-; SI-NEXT:    v_add_i32_e32 v0, vcc, v1, v0
+; SI-NEXT:    v_alignbit_b32 v0, v0, v1, v2
 ; SI-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; VI-LABEL: test_fshr_special_case:
 ; VI:       ; %bb.0:
 ; VI-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; VI-NEXT:    v_lshrrev_b32_e32 v1, v2, v1
-; VI-NEXT:    v_lshlrev_b32_e32 v0, 1, v0
-; VI-NEXT:    v_xor_b32_e32 v2, 31, v2
-; VI-NEXT:    v_lshlrev_b32_e32 v0, v2, v0
-; VI-NEXT:    v_add_u32_e32 v0, vcc, v1, v0
+; VI-NEXT:    v_alignbit_b32 v0, v0, v1, v2
 ; VI-NEXT:    s_setpc_b64 s[30:31]
   %shl = lshr i32 %x1, %y
   %srli = shl i32 %x0, 1
diff --git a/llvm/test/CodeGen/ARM/rotate-add.ll b/llvm/test/CodeGen/ARM/rotate-add.ll
index 9325e8b062dda..fd3055e5e2725 100644
--- a/llvm/test/CodeGen/ARM/rotate-add.ll
+++ b/llvm/test/CodeGen/ARM/rotate-add.ll
@@ -29,9 +29,8 @@ define i32 @test_simple_rotr(i32 %x) {
 define i32 @test_rotl_var(i32 %x, i32 %y) {
 ; CHECK-LABEL: test_rotl_var:
 ; CHECK:       @ %bb.0:
-; CHECK-NEXT:    lsl r2, r0, r1
 ; CHECK-NEXT:    rsb r1, r1, #32
-; CHECK-NEXT:    add r0, r2, r0, lsr r1
+; CHECK-NEXT:    ror r0, r0, r1
 ; CHECK-NEXT:    bx lr
   %shl = shl i32 %x, %y
   %sub = sub i32 32, %y
@@ -43,9 +42,7 @@ define i32 @test_rotl_var(i32 %x, i32 %y) {
 define i32 @test_rotr_var(i32 %x, i32 %y) {
 ; CHECK-LABEL: test_rotr_var:
 ; CHECK:       @ %bb.0:
-; CHECK-NEXT:    lsr r2, r0, r1
-; CHECK-NEXT:    rsb r1, r1, #32
-; CHECK-NEXT:    add r0, r2, r0, lsl r1
+; CHECK-NEXT:    ror r0, r0, r1
 ; CHECK-NEXT:    bx lr
   %shr = lshr i32 %x, %y
   %sub = sub i32 32, %y
diff --git a/llvm/test/CodeGen/NVPTX/rotate-add.ll b/llvm/test/CodeGen/NVPTX/rotate-add.ll
index c79a95958eca2..820e8000a5657 100644
--- a/llvm/test/CodeGen/NVPTX/rotate-add.ll
+++ b/llvm/test/CodeGen/NVPTX/rotate-add.ll
@@ -39,16 +39,13 @@ define i32 @test_simple_rotr(i32 %x) {
 define i32 @test_rotl_var(i32 %x, i32 %y) {
 ; CHECK-LABEL: test_rotl_var(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b32 %r<7>;
+; CHECK-NEXT:    .reg .b32 %r<4>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.u32 %r1, [test_rotl_var_param_0];
 ; CHECK-NEXT:    ld.param.u32 %r2, [test_rotl_var_param_1];
-; CHECK-NEXT:    shl.b32 %r3, %r1, %r2;
-; CHECK-NEXT:    sub.s32 %r4, 32, %r2;
-; CHECK-NEXT:    shr.u32 %r5, %r1, %r4;
-; CHECK-NEXT:    add.s32 %r6, %r3, %r5;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r6;
+; CHECK-NEXT:    shf.l.wrap.b32 %r3, %r1, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
 ; CHECK-NEXT:    ret;
   %shl = shl i32 %x, %y
   %sub = sub i32 32, %y
@@ -60,16 +57,13 @@ define i32 @test_rotl_var(i32 %x, i32 %y) {
 define i32 @test_rotr_var(i32 %x, i32 %y) {
 ; CHECK-LABEL: test_rotr_var(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b32 %r<7>;
+; CHECK-NEXT:    .reg .b32 %r<4>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.u32 %r1, [test_rotr_var_param_0];
 ; CHECK-NEXT:    ld.param.u32 %r2, [test_rotr_var_param_1];
-; CHECK-NEXT:    shr.u32 %r3, %r1, %r2;
-; CHECK-NEXT:    sub.s32 %r4, 32, %r2;
-; CHECK-NEXT:    shl.b32 %r5, %r1, %r4;
-; CHECK-NEXT:    add.s32 %r6, %r3, %r5;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r6;
+; CHECK-NEXT:    shf.r.wrap.b32 %r3, %r1, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
 ; CHECK-NEXT:    ret;
   %shr = lshr i32 %x, %y
   %sub = sub i32 32, %y
@@ -127,18 +121,14 @@ define i32 @test_invalid_rotr_var_and(i32 %x, i32 %y) {
 define i32 @test_fshl_special_case(i32 %x0, i32 %x1, i32 %y) {
 ; CHECK-LABEL: test_fshl_special_case(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b32 %r<9>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.u32 %r1, [test_fshl_special_case_param_0];
-; CHECK-NEXT:    ld.param.u32 %r2, [test_fshl_special_case_param_2];
-; CHECK-NEXT:    shl.b32 %r3, %r1, %r2;
-; CHECK-NEXT:    ld.param.u32 %r4, [test_fshl_special_case_param_1];
-; CHECK-NEXT:    shr.u32 %r5, %r4, 1;
-; CHECK-NEXT:    xor.b32 %r6, %r2, 31;
-; CHECK-NEXT:    shr.u32 %r7, %r5, %r6;
-; CHECK-NEXT:    add.s32 %r8, %r3, %r7;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r8;
+; CHECK-NEXT:    ld.param.u32 %r2, [test_fshl_special_case_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test_fshl_special_case_param_2];
+; CHECK-NEXT:    shf.l.wrap.b32 %r4, %r2, %r1, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r4;
 ; CHECK-NEXT:    ret;
   %shl = shl i32 %x0, %y
   %srli = lshr i32 %x1, 1
@@ -151,18 +141,14 @@ define i32 @test_fshl_special_case(i32 %x0, i32 %x1, i32 %y) {
 define i32 @test_fshr_special_case(i32 %x0, i32 %x1, i32 %y) {
 ; CHECK-LABEL: test_fshr_special_case(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b32 %r<9>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.u32 %r1, [test_fshr_special_case_param_0];
 ; CHECK-NEXT:    ld.param.u32 %r2, [test_fshr_special_case_param_1];
 ; CHECK-NEXT:    ld.param.u32 %r3, [test_fshr_special_case_param_2];
-; CHECK-NEXT:    shr.u32 %r4, %r2, %r3;
-; CHECK-NEXT:    shl.b32 %r5, %r1, 1;
-; CHECK-NEXT:    xor.b32 %r6, %r3, 31;
-; CHECK-NEXT:    shl.b32 %r7, %r5, %r6;
-; CHECK-NEXT:    add.s32 %r8, %r4, %r7;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r8;
+; CHECK-NEXT:    shf.r.wrap.b32 %r4, %r2, %r1, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r4;
 ; CHECK-NEXT:    ret;
   %shl = lshr i32 %x1, %y
   %srli = shl i32 %x0, 1
diff --git a/llvm/test/CodeGen/X86/rotate-add.ll b/llvm/test/CodeGen/X86/rotate-add.ll
index 6e19fc20abf04..c705505bbbf2a 100644
--- a/llvm/test/CodeGen/X86/rotate-add.ll
+++ b/llvm/test/CodeGen/X86/rotate-add.ll
@@ -43,22 +43,15 @@ define i32 @test_rotl_var(i32 %x, i32 %y) {
 ; X86:       # %bb.0:
 ; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %ecx
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
-; X86-NEXT:    movl %eax, %edx
-; X86-NEXT:    shll %cl, %edx
-; X86-NEXT:    negb %cl
-; X86-NEXT:    shrl %cl, %eax
-; X86-NEXT:    addl %edx, %eax
+; X86-NEXT:    roll %cl, %eax
 ; X86-NEXT:    retl
 ;
 ; X64-LABEL: test_rotl_var:
 ; X64:       # %bb.0:
 ; X64-NEXT:    movl %esi, %ecx
 ; X64-NEXT:    movl %edi, %eax
-; X64-NEXT:    shll %cl, %eax
-; X64-NEXT:    negb %cl
 ; X64-NEXT:    # kill: def $cl killed $cl killed $ecx
-; X64-NEXT:    shrl %cl, %edi
-; X64-NEXT:    addl %edi, %eax
+; X64-NEXT:    roll %cl, %eax
 ; X64-NEXT:    retq
   %shl = shl i32 %x, %y
   %sub = sub i32 32, %y
@@ -72,22 +65,15 @@ define i32 @test_rotr_var(i32 %x, i32 %y) {
 ; X86:       # %bb.0:
 ; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %ecx
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
-; X86-NEXT:    movl %eax, %edx
-; X86-NEXT:    shrl %cl, %edx
-; X86-NEXT:    negb %cl
-; X86-NEXT:    shll %cl, %eax
-; X86-NEXT:    addl %edx, %eax
+; X86-NEXT:    rorl %cl, %eax
 ; X86-NEXT:    retl
 ;
 ; X64-LABEL: test_rotr_var:
 ; X64:       # %bb.0:
 ; X64-NEXT:    movl %esi, %ecx
 ; X64-NEXT:    movl %edi, %eax
-; X64-NEXT:    shrl %cl, %eax
-; X64-NEXT:    negb %cl
 ; X64-NEXT:    # kill: def $cl killed $cl killed $ecx
-; X64-NEXT:    shll %cl, %edi
-; X64-NEXT:    addl %edi, %eax
+; X64-NEXT:    rorl %cl, %eax
 ; X64-NEXT:    retq
   %shr = lshr i32 %x, %y
   %sub = sub i32 32, %y
@@ -159,27 +145,18 @@ define i32 @test_invalid_rotr_var_and(i32 %x, i32 %y) {
 define i32 @test_fshl_special_case(i32 %x0, i32 %x1, i32 %y) {
 ; X86-LABEL: test_fshl_special_case:
 ; X86:       # %bb.0:
-; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %ecx
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %edx
-; X86-NEXT:    shll %cl, %edx
-; X86-NEXT:    shrl %eax
-; X86-NEXT:    notb %cl
-; X86-NEXT:    shrl %cl, %eax
-; X86-NEXT:    addl %edx, %eax
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    shldl %cl, %edx, %eax
 ; X86-NEXT:    retl
 ;
 ; X64-LABEL: test_fshl_special_case:
 ; X64:       # %bb.0:
 ; X64-NEXT:    movl %edx, %ecx
-; X64-NEXT:    # kill: def $esi killed $esi def $rsi
-; X64-NEXT:    # kill: def $edi killed $edi def $rdi
-; X64-NEXT:    shll %cl, %edi
-; X64-NEXT:    shrl %esi
-; X64-NEXT:    notb %cl
+; X64-NEXT:    movl %edi, %eax
 ; X64-NEXT:    # kill: def $cl killed $cl killed $ecx
-; X64-NEXT:    shrl %cl, %esi
-; X64-NEXT:    leal (%rsi,%rdi), %eax
+; X64-NEXT:    shldl %cl, %esi, %eax
 ; X64-NEXT:    retq
   %shl = shl i32 %x0, %y
   %srli = lshr i32 %x1, 1
@@ -192,26 +169,18 @@ define i32 @test_fshl_special_case(i32 %x0, i32 %x1, i32 %y) {
 define i32 @test_fshr_special_case(i32 %x0, i32 %x1, i32 %y) {
 ; X86-LABEL: test_fshr_special_case:
 ; X86:       # %bb.0:
-; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %ecx
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %edx
-; X86-NEXT:    shrl %cl, %edx
-; X86-NEXT:    addl %eax, %eax
-; X86-NEXT:    notb %cl
-; X86-NEXT:    shll %cl, %eax
-; X86-NEXT:    addl %edx, %eax
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    shrdl %cl, %edx, %eax
 ; X86-NEXT:    retl
 ;
 ; X64-LABEL: test_fshr_special_case:
 ; X64:       # %bb.0:
 ; X64-NEXT:    movl %edx, %ecx
-; X64-NEXT:    # kill: def $edi killed $edi def $rdi
-; X64-NEXT:    shrl %cl, %esi
-; X64-NEXT:    leal (%rdi,%rdi), %eax
-; X64-NEXT:    notb %cl
+; X64-NEXT:    movl %esi, %eax
 ; X64-NEXT:    # kill: def $cl killed $cl killed $ecx
-; X64-NEXT:    shll %cl, %eax
-; X64-NEXT:    addl %esi, %eax
+; X64-NEXT:    shrdl %cl, %edi, %eax
 ; X64-NEXT:    retq
   %shl = lshr i32 %x1, %y
   %srli = shl i32 %x0, 1

>From e420a8f67e40ed060999a1b13aeecbdebb10dd30 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 19 Feb 2025 21:01:25 +0000
Subject: [PATCH 2/2] address review comment

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 30 +++++++++----------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 0d74199bd4b36..8136f1794775e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -8493,12 +8493,12 @@ SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
                                        SDValue InnerNeg, bool FromAdd,
                                        bool HasPos, unsigned PosOpcode,
                                        unsigned NegOpcode, const SDLoc &DL) {
-  // fold (or (shl x, (*ext y)),
-  //          (srl x, (*ext (sub 32, y)))) ->
+  // fold (or/add (shl x, (*ext y)),
+  //              (srl x, (*ext (sub 32, y)))) ->
   //   (rotl x, y) or (rotr x, (sub 32, y))
   //
-  // fold (or (shl x, (*ext (sub 32, y))),
-  //          (srl x, (*ext y))) ->
+  // fold (or/add (shl x, (*ext (sub 32, y))),
+  //              (srl x, (*ext y))) ->
   //   (rotr x, y) or (rotl x, (sub 32, y))
   EVT VT = Shifted.getValueType();
   if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
@@ -8523,12 +8523,12 @@ SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
   EVT VT = N0.getValueType();
   unsigned EltBits = VT.getScalarSizeInBits();
 
-  // fold (or (shl x0, (*ext y)),
-  //          (srl x1, (*ext (sub 32, y)))) ->
+  // fold (or/add (shl x0, (*ext y)),
+  //              (srl x1, (*ext (sub 32, y)))) ->
   //   (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
   //
-  // fold (or (shl x0, (*ext (sub 32, y))),
-  //          (srl x1, (*ext y))) ->
+  // fold (or/add (shl x0, (*ext (sub 32, y))),
+  //              (srl x1, (*ext y))) ->
   //   (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
   if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1,
                      FromAdd))
@@ -8540,7 +8540,7 @@ SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
   // TODO: When can we use the NegOpcode case?
   if (PosOpcode == ISD::FSHL && isPowerOf2_32(EltBits)) {
     SDValue X;
-    // fold (or (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
+    // fold (or/add (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
     //   -> (fshl x0, x1, y)
     if (sd_match(N1, m_Srl(m_Value(X), m_One())) &&
         sd_match(InnerNeg,
@@ -8549,7 +8549,7 @@ SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
       return DAG.getNode(ISD::FSHL, DL, VT, N0, X, Pos);
     }
 
-    // fold (or (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
+    // fold (or/add (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
     //   -> (fshr x0, x1, y)
     if (sd_match(N0, m_Shl(m_Value(X), m_One())) &&
         sd_match(InnerPos,
@@ -8558,7 +8558,7 @@ SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
       return DAG.getNode(ISD::FSHR, DL, VT, X, N1, Neg);
     }
 
-    // fold (or (shl (add x0, x0), (xor y, 31)), (srl x1, y))
+    // fold (or/add (shl (add x0, x0), (xor y, 31)), (srl x1, y))
     //   -> (fshr x0, x1, y)
     // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
     if (sd_match(N0, m_Add(m_Value(X), m_Deferred(X))) &&
@@ -8740,10 +8740,10 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
     return SDValue(); // Requires funnel shift support.
   }
 
-  // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
-  // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
-  // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
-  // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
+  // fold (or/add (shl x, C1), (srl x, C2)) -> (rotl x, C1)
+  // fold (or/add (shl x, C1), (srl x, C2)) -> (rotr x, C2)
+  // fold (or/add (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
+  // fold (or/add (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
   // iff C1+C2 == EltSizeInBits
   if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
     SDValue Res;



More information about the llvm-commits mailing list