[llvm] [DAGCombiner] Attempt to fold 'add' nodes to funnel-shift or rotate (PR #125612)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 3 17:12:36 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-llvm-selectiondag
Author: Alex MacLean (AlexMaclean)
Changes:
Almost all of the rotate idioms that are valid for an 'or' are also valid when the two halves are combined with an 'add'. Further, many of these cases are not caught by common-bits tracking, meaning the 'add' is not first converted to a disjoint 'or'.
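For reference, here is a minimal C++ sketch (illustrative only, not taken from the PR) of the kind of source pattern this fold targets: a rotate written with '+' instead of '|'. The two shifted halves share no set bits, so the 'add' computes the same value as the 'or' form and can be selected as a single rotate.

```cpp
#include <cstdint>

// Rotate-left-by-7 idiom, combined with '+' rather than '|'. Because the
// halves have disjoint bits, this equals (x << 7) | (x >> 25) and should
// lower to one rotate instruction (cf. test_rotl in the new test file).
uint32_t rotl7(uint32_t x) {
  return (x << 7) + (x >> 25);
}
```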
---
Full diff: https://github.com/llvm/llvm-project/pull/125612.diff
2 Files Affected:
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+53-43)
- (added) llvm/test/CodeGen/NVPTX/add-rotate.ll (+118)
``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index f4caaf426de6a07..d671a6d555621c1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -662,14 +662,15 @@ namespace {
bool DemandHighBits = true);
SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
- SDValue InnerPos, SDValue InnerNeg, bool HasPos,
- unsigned PosOpcode, unsigned NegOpcode,
- const SDLoc &DL);
+ SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
+ bool HasPos, unsigned PosOpcode,
+ unsigned NegOpcode, const SDLoc &DL);
SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
- SDValue InnerPos, SDValue InnerNeg, bool HasPos,
- unsigned PosOpcode, unsigned NegOpcode,
- const SDLoc &DL);
- SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
+ SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
+ bool HasPos, unsigned PosOpcode,
+ unsigned NegOpcode, const SDLoc &DL);
+ SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
+ bool FromAdd);
SDValue MatchLoadCombine(SDNode *N);
SDValue mergeTruncStores(StoreSDNode *N);
SDValue reduceLoadWidth(SDNode *N);
@@ -2992,6 +2993,9 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
return V;
+  if (SDValue V = MatchRotate(N0, N1, DL, /*FromAdd=*/true))
+ return V;
+
// Try to match AVGFLOOR fixedwidth pattern
if (SDValue V = foldAddToAvg(N, DL))
return V;
@@ -8161,7 +8165,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
return V;
// See if this is some rotate idiom.
- if (SDValue Rot = MatchRotate(N0, N1, DL))
+ if (SDValue Rot = MatchRotate(N0, N1, DL, /*FromAdd=*/false))
return Rot;
if (SDValue Load = MatchLoadCombine(N))
@@ -8350,7 +8354,7 @@ static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
// The IsRotate flag should be set when the LHS of both shifts is the same.
// Otherwise if matching a general funnel shift, it should be clear.
static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
- SelectionDAG &DAG, bool IsRotate) {
+ SelectionDAG &DAG, bool IsRotate, bool FromAdd) {
const auto &TLI = DAG.getTargetLoweringInfo();
// If EltSize is a power of 2 then:
//
@@ -8389,7 +8393,7 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
// NOTE: We can only do this when matching operations which won't modify the
// least Log2(EltSize) significant bits and not a general funnel shift.
unsigned MaskLoBits = 0;
- if (IsRotate && isPowerOf2_64(EltSize)) {
+ if (IsRotate && !FromAdd && isPowerOf2_64(EltSize)) {
unsigned Bits = Log2_64(EltSize);
unsigned NegBits = Neg.getScalarValueSizeInBits();
if (NegBits >= Bits) {
@@ -8472,9 +8476,9 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
// Neg with outer conversions stripped away.
SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
SDValue Neg, SDValue InnerPos,
- SDValue InnerNeg, bool HasPos,
- unsigned PosOpcode, unsigned NegOpcode,
- const SDLoc &DL) {
+ SDValue InnerNeg, bool FromAdd,
+ bool HasPos, unsigned PosOpcode,
+ unsigned NegOpcode, const SDLoc &DL) {
// fold (or (shl x, (*ext y)),
// (srl x, (*ext (sub 32, y)))) ->
// (rotl x, y) or (rotr x, (sub 32, y))
@@ -8484,10 +8488,9 @@ SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
// (rotr x, y) or (rotl x, (sub 32, y))
EVT VT = Shifted.getValueType();
if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
- /*IsRotate*/ true)) {
+ /*IsRotate*/ true, FromAdd))
return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
HasPos ? Pos : Neg);
- }
return SDValue();
}
@@ -8500,9 +8503,9 @@ SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
// TODO: Merge with MatchRotatePosNeg.
SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
SDValue Neg, SDValue InnerPos,
- SDValue InnerNeg, bool HasPos,
- unsigned PosOpcode, unsigned NegOpcode,
- const SDLoc &DL) {
+ SDValue InnerNeg, bool FromAdd,
+ bool HasPos, unsigned PosOpcode,
+ unsigned NegOpcode, const SDLoc &DL) {
EVT VT = N0.getValueType();
unsigned EltBits = VT.getScalarSizeInBits();
@@ -8513,10 +8516,10 @@ SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
// fold (or (shl x0, (*ext (sub 32, y))),
// (srl x1, (*ext y))) ->
// (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
- if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1)) {
+ if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1,
+ FromAdd))
return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
HasPos ? Pos : Neg);
- }
// Matching the shift+xor cases, we can't easily use the xor'd shift amount
// so for now just use the PosOpcode case if its legal.
@@ -8561,11 +8564,12 @@ SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
return SDValue();
}
-// MatchRotate - Handle an 'or' of two operands. If this is one of the many
-// idioms for rotate, and if the target supports rotation instructions, generate
-// a rot[lr]. This also matches funnel shift patterns, similar to rotation but
-// with different shifted sources.
-SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
+// MatchRotate - Handle an 'or' or 'add' of two operands. If this is one of the
+// many idioms for rotate, and if the target supports rotation instructions,
+// generate a rot[lr]. This also matches funnel shift patterns, similar to
+// rotation but with different shifted sources.
+SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
+ bool FromAdd) {
EVT VT = LHS.getValueType();
// The target must have at least one rotate/funnel flavor.
@@ -8592,9 +8596,9 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
assert(LHS.getValueType() == RHS.getValueType());
- if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
+ if (SDValue Rot =
+ MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL, FromAdd))
return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
- }
}
// Match "(X shl/srl V1) & V2" where V2 may not be present.
@@ -8773,30 +8777,36 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
RExtOp0 = RHSShiftAmt.getOperand(0);
}
+  // TODO: If we get here from visitADD(), the right-shift amount must be
+  // non-zero when the pattern includes an AND op, so allow folding to ROTL
+  // only if the amount is a known non-zero constant; likewise for ROTR:
+  //   auto RotateSafe = [FromAdd](const SDValue &ExtOp0) {
+  //     if (!FromAdd || ExtOp0.getOpcode() != ISD::AND)
+  //       return true;
+  //     auto *ExtOp0Const = dyn_cast<ConstantSDNode>(ExtOp0);
+  //     return ExtOp0Const && !ExtOp0Const->isZero();
+  //   };
+
if (IsRotate && (HasROTL || HasROTR)) {
- SDValue TryL =
- MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0,
- RExtOp0, HasROTL, ISD::ROTL, ISD::ROTR, DL);
- if (TryL)
+ if (SDValue TryL = MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt,
+ LExtOp0, RExtOp0, FromAdd, HasROTL,
+ ISD::ROTL, ISD::ROTR, DL))
return TryL;
- SDValue TryR =
- MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0,
- LExtOp0, HasROTR, ISD::ROTR, ISD::ROTL, DL);
- if (TryR)
+ if (SDValue TryR = MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
+ RExtOp0, LExtOp0, FromAdd, HasROTR,
+ ISD::ROTR, ISD::ROTL, DL))
return TryR;
}
- SDValue TryL =
- MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt,
- LExtOp0, RExtOp0, HasFSHL, ISD::FSHL, ISD::FSHR, DL);
- if (TryL)
+ if (SDValue TryL = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt,
+ RHSShiftAmt, LExtOp0, RExtOp0, FromAdd,
+ HasFSHL, ISD::FSHL, ISD::FSHR, DL))
return TryL;
- SDValue TryR =
- MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
- RExtOp0, LExtOp0, HasFSHR, ISD::FSHR, ISD::FSHL, DL);
- if (TryR)
+ if (SDValue TryR = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt,
+ LHSShiftAmt, RExtOp0, LExtOp0, FromAdd,
+ HasFSHR, ISD::FSHR, ISD::FSHL, DL))
return TryR;
return SDValue();
diff --git a/llvm/test/CodeGen/NVPTX/add-rotate.ll b/llvm/test/CodeGen/NVPTX/add-rotate.ll
new file mode 100644
index 000000000000000..f777ac964154343
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/add-rotate.ll
@@ -0,0 +1,118 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_50 | FileCheck %s
+
+target triple = "nvptx64-nvidia-cuda"
+
+define i32 @test_rotl(i32 %x) {
+; CHECK-LABEL: test_rotl(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_rotl_param_0];
+; CHECK-NEXT: shf.l.wrap.b32 %r2, %r1, %r1, 7;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT: ret;
+ %shl = shl i32 %x, 7
+ %shr = lshr i32 %x, 25
+ %add = add i32 %shl, %shr
+ ret i32 %add
+}
+
+define i32 @test_rotr(i32 %x) {
+; CHECK-LABEL: test_rotr(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_rotr_param_0];
+; CHECK-NEXT: shf.l.wrap.b32 %r2, %r1, %r1, 25;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT: ret;
+ %shr = lshr i32 %x, 7
+ %shl = shl i32 %x, 25
+ %add = add i32 %shr, %shl
+ ret i32 %add
+}
+
+define i32 @test_rotl_var(i32 %x, i32 %y) {
+; CHECK-LABEL: test_rotl_var(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_rotl_var_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test_rotl_var_param_1];
+; CHECK-NEXT: shf.l.wrap.b32 %r3, %r1, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %shl = shl i32 %x, %y
+ %sub = sub i32 32, %y
+ %shr = lshr i32 %x, %sub
+ %add = add i32 %shl, %shr
+ ret i32 %add
+}
+
+define i32 @test_rotr_var(i32 %x, i32 %y) {
+; CHECK-LABEL: test_rotr_var(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_rotr_var_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test_rotr_var_param_1];
+; CHECK-NEXT: shf.r.wrap.b32 %r3, %r1, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %shr = lshr i32 %x, %y
+ %sub = sub i32 32, %y
+ %shl = shl i32 %x, %sub
+ %add = add i32 %shr, %shl
+ ret i32 %add
+}
+
+define i32 @test_rotl_var_and(i32 %x, i32 %y) {
+; CHECK-LABEL: test_rotl_var_and(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<8>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_rotl_var_and_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test_rotl_var_and_param_1];
+; CHECK-NEXT: shl.b32 %r3, %r1, %r2;
+; CHECK-NEXT: neg.s32 %r4, %r2;
+; CHECK-NEXT: and.b32 %r5, %r4, 31;
+; CHECK-NEXT: shr.u32 %r6, %r1, %r5;
+; CHECK-NEXT: add.s32 %r7, %r6, %r3;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r7;
+; CHECK-NEXT: ret;
+ %shl = shl i32 %x, %y
+ %sub = sub nsw i32 0, %y
+ %and = and i32 %sub, 31
+ %shr = lshr i32 %x, %and
+ %add = add i32 %shr, %shl
+ ret i32 %add
+}
+
+define i32 @test_rotr_var_and(i32 %x, i32 %y) {
+; CHECK-LABEL: test_rotr_var_and(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<8>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_rotr_var_and_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test_rotr_var_and_param_1];
+; CHECK-NEXT: shr.u32 %r3, %r1, %r2;
+; CHECK-NEXT: neg.s32 %r4, %r2;
+; CHECK-NEXT: and.b32 %r5, %r4, 31;
+; CHECK-NEXT: shl.b32 %r6, %r1, %r5;
+; CHECK-NEXT: add.s32 %r7, %r3, %r6;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r7;
+; CHECK-NEXT: ret;
+ %shr = lshr i32 %x, %y
+ %sub = sub nsw i32 0, %y
+ %and = and i32 %sub, 31
+ %shl = shl i32 %x, %and
+ %add = add i32 %shr, %shl
+ ret i32 %add
+}
``````````
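A note on the masked tests above: when the shift amount is masked with `and ... 31`, an amount of zero is well defined, and both shifts then return `x` unchanged. An 'or' of the two halves still yields `x` (a rotate by zero), but an 'add' yields `2 * x`, so the fold is only safe for 'or'. This is why `matchRotateSub` skips the modulo-`EltSize` mask-stripping when `FromAdd` is set, and why `test_rotl_var_and` and `test_rotr_var_and` check that no rotate is emitted. A minimal C++ sketch of the difference (illustrative only, not from the patch):

```cpp
#include <cassert>
#include <cstdint>

// Rotate idiom with a masked amount: safe when the halves are combined with
// '|', but not with '+', because a masked amount of zero makes both shifts
// the identity.
uint32_t or_form(uint32_t x, uint32_t y) {
  return (x << (y & 31)) | (x >> (-y & 31)); // rotl(x, y & 31) for every y
}

uint32_t add_form(uint32_t x, uint32_t y) {
  return (x << (y & 31)) + (x >> (-y & 31)); // equals 2 * x when y & 31 == 0
}

int main() {
  assert(or_form(0x80000001u, 0) == 0x80000001u);  // rotate by 0 is a no-op
  assert(add_form(0x80000001u, 0) == 0x00000002u); // 2 * x (mod 2^32), not x
}
```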
https://github.com/llvm/llvm-project/pull/125612