[llvm] [AArch64] Fold lsl + lsr + orr to rev for half-width shifts (PR #159953)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Sep 20 15:40:11 PDT 2025
github-actions[bot] wrote:
:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff origin/main HEAD --extensions cpp -- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
``````````
:warning:
If you are using a stacked PR workflow, the reproduction command above may
return results for more than one PR in the stack. You can limit the results by
changing `origin/main` to the base branch or commit you want to compare
against, as shown in the example below.
:warning:
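For example, if the stack's parent branch were named `my-stack-base` (a placeholder; substitute your actual base branch or a commit SHA), the comparison could be limited like this:

``````````bash
# "my-stack-base" is a hypothetical branch name; use your stack's real base.
git-clang-format --diff my-stack-base HEAD --extensions cpp -- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
``````````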
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 183fc763c..16ce1b715 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -19965,7 +19965,8 @@ static SDValue performANDORCSELCombine(SDNode *N, SelectionDAG &DAG) {
}
// Fold lsl + lsr + orr to rev for half-width shifts
-// Pattern: orr(lsl(x, shift), lsr(x, shift)) -> rev(x) when shift == half_bitwidth
+// Pattern: orr(lsl(x, shift), lsr(x, shift)) -> rev(x) when shift ==
+// half_bitwidth
static SDValue performLSL_LSR_ORRCombine(SDNode *N, SelectionDAG &DAG,
const AArch64Subtarget *Subtarget) {
if (!Subtarget->hasSVE())
@@ -19977,7 +19978,7 @@ static SDValue performLSL_LSR_ORRCombine(SDNode *N, SelectionDAG &DAG,
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
-
+
// Check if one operand is LSL and the other is LSR
SDValue LSL, LSR;
if (LHS.getOpcode() == ISD::SHL && RHS.getOpcode() == ISD::SRL) {
@@ -19996,7 +19997,7 @@ static SDValue performLSL_LSR_ORRCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
// Check that both shifts have the same constant amount
- if (!isa<ConstantSDNode>(LSL.getOperand(1)) ||
+ if (!isa<ConstantSDNode>(LSL.getOperand(1)) ||
!isa<ConstantSDNode>(LSR.getOperand(1)))
return SDValue();
@@ -20013,24 +20014,28 @@ static SDValue performLSL_LSR_ORRCombine(SDNode *N, SelectionDAG &DAG,
if (ShiftAmt != EltSize / 2)
return SDValue();
- // Determine the appropriate REV instruction based on element size and shift amount
+ // Determine the appropriate REV instruction based on element size and shift
+ // amount
unsigned RevOp;
switch (EltSize) {
case 16:
if (ShiftAmt == 8)
- RevOp = AArch64ISD::BSWAP_MERGE_PASSTHRU; // 16-bit elements, 8-bit shift -> revb
+ RevOp = AArch64ISD::BSWAP_MERGE_PASSTHRU; // 16-bit elements, 8-bit shift
+ // -> revb
else
return SDValue();
break;
case 32:
if (ShiftAmt == 16)
- RevOp = AArch64ISD::REVH_MERGE_PASSTHRU; // 32-bit elements, 16-bit shift -> revh
+ RevOp = AArch64ISD::REVH_MERGE_PASSTHRU; // 32-bit elements, 16-bit shift
+ // -> revh
else
return SDValue();
break;
case 64:
if (ShiftAmt == 32)
- RevOp = AArch64ISD::REVW_MERGE_PASSTHRU; // 64-bit elements, 32-bit shift -> revw
+ RevOp = AArch64ISD::REVW_MERGE_PASSTHRU; // 64-bit elements, 32-bit shift
+ // -> revw
else
return SDValue();
break;
@@ -20042,13 +20047,13 @@ static SDValue performLSL_LSR_ORRCombine(SDNode *N, SelectionDAG &DAG,
SDLoc DL(N);
SDValue Pg = getPredicateForVector(DAG, DL, VT);
SDValue Undef = DAG.getUNDEF(VT);
-
+
return DAG.getNode(RevOp, DL, VT, Pg, Src, Undef);
}
// Fold bswap to correct rev instruction for scalable vectors
-// DAGCombiner converts lsl+lsr+orr with 8-bit shift to BSWAP, but for scalable vectors
-// we need to use the correct REV instruction based on element size
+// DAGCombiner converts lsl+lsr+orr with 8-bit shift to BSWAP, but for scalable
+// vectors we need to use the correct REV instruction based on element size
static SDValue performBSWAPCombine(SDNode *N, SelectionDAG &DAG,
const AArch64Subtarget *Subtarget) {
LLVM_DEBUG(dbgs() << "BSWAP combine called\n");
@@ -20066,12 +20071,13 @@ static SDValue performBSWAPCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
unsigned EltSize = EltVT.getSizeInBits();
-
+
// For scalable vectors with 16-bit elements, BSWAP should use REVB, not REVH
- // REVH is not available for 16-bit elements, only for 32-bit and 64-bit elements
- // For 16-bit elements, REVB (byte reverse) is equivalent to halfword reverse
+ // REVH is not available for 16-bit elements, only for 32-bit and 64-bit
+ // elements. For 16-bit elements, REVB (byte reverse) is equivalent to
+ // halfword reverse
if (EltSize != 16)
- return SDValue(); // Use default BSWAP lowering for other sizes
+ return SDValue(); // Use default BSWAP lowering for other sizes
// The current BSWAP lowering is already correct for 16-bit elements
// BSWAP_MERGE_PASSTHRU maps to REVB which is correct for 16-bit elements
@@ -20095,7 +20101,7 @@ static SDValue performROTLCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
uint64_t RotAmt = N->getConstantOperandVal(1);
-
+
// Check if rotation amount equals half the bitwidth
EVT EltVT = VT.getVectorElementType();
if (!EltVT.isSimple())
@@ -20111,10 +20117,12 @@ static SDValue performROTLCombine(SDNode *N, SelectionDAG &DAG,
case 16:
return SDValue(); // 16-bit case handled by BSWAP
case 32:
- RevOp = AArch64ISD::REVW_MERGE_PASSTHRU; // 32-bit elements, 16-bit rotation -> revw
+ RevOp = AArch64ISD::REVW_MERGE_PASSTHRU; // 32-bit elements, 16-bit rotation
+ // -> revw
break;
case 64:
- RevOp = AArch64ISD::REVD_MERGE_PASSTHRU; // 64-bit elements, 32-bit rotation -> revd
+ RevOp = AArch64ISD::REVD_MERGE_PASSTHRU; // 64-bit elements, 32-bit rotation
+ // -> revd
break;
default:
return SDValue();
@@ -20125,14 +20133,16 @@ static SDValue performROTLCombine(SDNode *N, SelectionDAG &DAG,
SDValue Src = N->getOperand(0);
SDValue Pg = getPredicateForVector(DAG, DL, VT);
SDValue Undef = DAG.getUNDEF(VT);
-
+
return DAG.getNode(RevOp, DL, VT, Pg, Src, Undef);
}
-// Fold predicated shl + srl + orr to rev for half-width shifts on scalable vectors
-// Pattern: orr(AArch64ISD::SHL_PRED(pg, x, shift), AArch64ISD::SRL_PRED(pg, x, shift)) -> rev(x) when shift == half_bitwidth
-static SDValue performSVE_SHL_SRL_ORRCombine(SDNode *N, SelectionDAG &DAG,
- const AArch64Subtarget *Subtarget) {
+// Fold predicated shl + srl + orr to rev for half-width shifts on scalable
+// vectors. Pattern: orr(AArch64ISD::SHL_PRED(pg, x, shift),
+// AArch64ISD::SRL_PRED(pg, x, shift)) -> rev(x) when shift == half_bitwidth
+static SDValue
+performSVE_SHL_SRL_ORRCombine(SDNode *N, SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget) {
if (!Subtarget->hasSVE())
return SDValue();
@@ -20142,13 +20152,15 @@ static SDValue performSVE_SHL_SRL_ORRCombine(SDNode *N, SelectionDAG &DAG,
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
-
+
// Check if one operand is predicated SHL and the other is predicated SRL
SDValue SHL, SRL;
- if (LHS.getOpcode() == AArch64ISD::SHL_PRED && RHS.getOpcode() == AArch64ISD::SRL_PRED) {
+ if (LHS.getOpcode() == AArch64ISD::SHL_PRED &&
+ RHS.getOpcode() == AArch64ISD::SRL_PRED) {
SHL = LHS;
SRL = RHS;
- } else if (LHS.getOpcode() == AArch64ISD::SRL_PRED && RHS.getOpcode() == AArch64ISD::SHL_PRED) {
+ } else if (LHS.getOpcode() == AArch64ISD::SRL_PRED &&
+ RHS.getOpcode() == AArch64ISD::SHL_PRED) {
SHL = RHS;
SRL = LHS;
} else {
@@ -20159,11 +20171,11 @@ static SDValue performSVE_SHL_SRL_ORRCombine(SDNode *N, SelectionDAG &DAG,
SDValue SHLPred = SHL.getOperand(0);
SDValue SHLSrc = SHL.getOperand(1);
SDValue SHLAmt = SHL.getOperand(2);
-
+
SDValue SRLPred = SRL.getOperand(0);
SDValue SRLSrc = SRL.getOperand(1);
SDValue SRLAmt = SRL.getOperand(2);
-
+
if (SHLPred != SRLPred || SHLSrc != SRLSrc || SHLAmt != SRLAmt)
return SDValue();
@@ -20171,7 +20183,8 @@ static SDValue performSVE_SHL_SRL_ORRCombine(SDNode *N, SelectionDAG &DAG,
if (!isa<ConstantSDNode>(SHLAmt->getOperand(0))) // For splat_vector
return SDValue();
- uint64_t ShiftAmt = cast<ConstantSDNode>(SHLAmt->getOperand(0))->getZExtValue();
+ uint64_t ShiftAmt =
+ cast<ConstantSDNode>(SHLAmt->getOperand(0))->getZExtValue();
// Check if shift amount equals half the bitwidth
EVT EltVT = VT.getVectorElementType();
@@ -20182,20 +20195,23 @@ static SDValue performSVE_SHL_SRL_ORRCombine(SDNode *N, SelectionDAG &DAG,
if (ShiftAmt != EltSize / 2)
return SDValue();
- // Determine the appropriate REV instruction based on element size and shift amount
+ // Determine the appropriate REV instruction based on element size and shift
+ // amount
unsigned RevOp;
switch (EltSize) {
case 16:
return SDValue(); // 16-bit case handled by BSWAP
case 32:
if (ShiftAmt == 16)
- RevOp = AArch64ISD::REVH_MERGE_PASSTHRU; // 32-bit elements, 16-bit shift -> revh
+ RevOp = AArch64ISD::REVH_MERGE_PASSTHRU; // 32-bit elements, 16-bit shift
+ // -> revh
else
return SDValue();
break;
case 64:
if (ShiftAmt == 32)
- RevOp = AArch64ISD::REVW_MERGE_PASSTHRU; // 64-bit elements, 32-bit shift -> revw
+ RevOp = AArch64ISD::REVW_MERGE_PASSTHRU; // 64-bit elements, 32-bit shift
+ // -> revw
else
return SDValue();
break;
@@ -20208,7 +20224,7 @@ static SDValue performSVE_SHL_SRL_ORRCombine(SDNode *N, SelectionDAG &DAG,
SDValue Pg = SHLPred;
SDValue Src = SHLSrc;
SDValue Undef = DAG.getUNDEF(VT);
-
+
return DAG.getNode(RevOp, DL, VT, Pg, Src, Undef);
}
@@ -20223,7 +20239,7 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
if (SDValue R = performLSL_LSR_ORRCombine(N, DAG, Subtarget))
return R;
- // Try the predicated shift combine for SVE
+ // Try the predicated shift combine for SVE
if (SDValue R = performSVE_SHL_SRL_ORRCombine(N, DAG, Subtarget))
return R;
``````````
</details>
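For context, the pattern these combines match is the classic half-width rotate. Below is a minimal sketch in ACLE SVE intrinsics of the shape of source code the fold targets (illustrative only, assuming a vector of 32-bit elements; it is not a test case from this PR):

``````````cpp
// Illustrative sketch: orr(lsl(x, 16), lsr(x, 16)) on 32-bit elements.
// The shift amount equals half the element width, so per this PR the
// lsl + lsr + orr sequence can be selected as a single REVH
// (reverse halfwords within each word element).
#include <arm_sve.h>

svuint32_t swap_halfwords(svbool_t pg, svuint32_t x) {
  svuint32_t hi = svlsl_n_u32_z(pg, x, 16); // lsl z, #16
  svuint32_t lo = svlsr_n_u32_z(pg, x, 16); // lsr z, #16
  return svorr_u32_z(pg, hi, lo);           // orr -> candidate for revh
}
``````````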
https://github.com/llvm/llvm-project/pull/159953