[llvm] [AArch64] Fold lsl + lsr + orr to rev for half-width shifts (PR #159953)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Sep 20 15:36:21 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-aarch64
Author: AZero13 (AZero13)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/159953.diff
2 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+261)
- (added) llvm/test/CodeGen/AArch64/sve-lsl-lsr-orr-rev-combine.ll (+97)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index cd7f0e719ad0c..183fc763cd2e9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -19964,6 +19964,254 @@ static SDValue performANDORCSELCombine(SDNode *N, SelectionDAG &DAG) {
CSel0.getOperand(1), getCondCode(DAG, CC1), CCmp);
}
+// Fold lsl + lsr + orr to rev for half-width shifts.
+// Pattern: orr(lsl(x, c), lsr(x, c)) -> rev(x), where c == EltSize / 2.
+static SDValue performLSL_LSR_ORRCombine(SDNode *N, SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget) {
+ if (!Subtarget->hasSVE())
+ return SDValue();
+
+ EVT VT = N->getValueType(0);
+ if (!VT.isScalableVector())
+ return SDValue();
+
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
+
+ // Check if one operand is LSL and the other is LSR
+ SDValue LSL, LSR;
+ if (LHS.getOpcode() == ISD::SHL && RHS.getOpcode() == ISD::SRL) {
+ LSL = LHS;
+ LSR = RHS;
+ } else if (LHS.getOpcode() == ISD::SRL && RHS.getOpcode() == ISD::SHL) {
+ LSL = RHS;
+ LSR = LHS;
+ } else {
+ return SDValue();
+ }
+
+ // Check that both shifts operate on the same source
+ SDValue Src = LSL.getOperand(0);
+ if (Src != LSR.getOperand(0))
+ return SDValue();
+
+  // Check that both shifts use the same constant splat amount.
+  APInt LSLImm, LSRImm;
+  if (!ISD::isConstantSplatVector(LSL.getOperand(1).getNode(), LSLImm) ||
+      !ISD::isConstantSplatVector(LSR.getOperand(1).getNode(), LSRImm) ||
+      LSLImm != LSRImm)
+    return SDValue();
+
+  uint64_t ShiftAmt = LSLImm.getZExtValue();
+
+ // Check if shift amount equals half the bitwidth
+ EVT EltVT = VT.getVectorElementType();
+ if (!EltVT.isSimple())
+ return SDValue();
+
+ unsigned EltSize = EltVT.getSizeInBits();
+ if (ShiftAmt != EltSize / 2)
+ return SDValue();
+
+  // Map the element size to the corresponding predicated REV node.
+  unsigned RevOp;
+  switch (EltSize) {
+  case 16:
+    // 16-bit elements, 8-bit shift -> revb.
+    RevOp = AArch64ISD::BSWAP_MERGE_PASSTHRU;
+    break;
+  case 32:
+    // 32-bit elements, 16-bit shift -> revh.
+    RevOp = AArch64ISD::REVH_MERGE_PASSTHRU;
+    break;
+  case 64:
+    // 64-bit elements, 32-bit shift -> revw.
+    RevOp = AArch64ISD::REVW_MERGE_PASSTHRU;
+    break;
+  default:
+    return SDValue();
+  }
+
+ // Create the REV instruction
+ SDLoc DL(N);
+ SDValue Pg = getPredicateForVector(DAG, DL, VT);
+ SDValue Undef = DAG.getUNDEF(VT);
+
+ return DAG.getNode(RevOp, DL, VT, Pg, Src, Undef);
+}
+
+// Fold bswap to the correct rev instruction for scalable vectors.
+// DAGCombiner turns lsl+lsr+orr with an 8-bit shift into BSWAP; for scalable
+// vectors the REV form has to be chosen based on the element size.
+static SDValue performBSWAPCombine(SDNode *N, SelectionDAG &DAG,
+                                   const AArch64Subtarget *Subtarget) {
+  if (!Subtarget->hasSVE())
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  if (!VT.isScalableVector())
+    return SDValue();
+
+ EVT EltVT = VT.getVectorElementType();
+ if (!EltVT.isSimple())
+ return SDValue();
+
+ unsigned EltSize = EltVT.getSizeInBits();
+
+  // REVH is only available for 32-bit and 64-bit elements, not for 16-bit
+  // elements. For 16-bit elements, byte-reversing each element (revb) is
+  // exactly the half-width swap we want.
+  if (EltSize != 16)
+    return SDValue(); // Use the default BSWAP lowering for other sizes.
+
+  // BSWAP_MERGE_PASSTHRU already maps to revb, which is correct for 16-bit
+  // elements, so there is nothing to rewrite here.
+  return SDValue();
+}
+
+// Fold rotl to rev instruction for half-width rotations on scalable vectors
+// Pattern: rotl(x, half_bitwidth) -> rev(x) for scalable vectors
+static SDValue performROTLCombine(SDNode *N, SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget) {
+ if (!Subtarget->hasSVE())
+ return SDValue();
+
+ EVT VT = N->getValueType(0);
+ if (!VT.isScalableVector())
+ return SDValue();
+
+  // The rotation amount must be a constant splat.
+  APInt RotImm;
+  if (!ISD::isConstantSplatVector(N->getOperand(1).getNode(), RotImm))
+    return SDValue();
+
+  uint64_t RotAmt = RotImm.getZExtValue();
+
+ // Check if rotation amount equals half the bitwidth
+ EVT EltVT = VT.getVectorElementType();
+ if (!EltVT.isSimple())
+ return SDValue();
+
+ unsigned EltSize = EltVT.getSizeInBits();
+ if (RotAmt != EltSize / 2)
+ return SDValue();
+
+  // Map the element size to the corresponding predicated REV node.
+  unsigned RevOp;
+  switch (EltSize) {
+  case 16:
+    return SDValue(); // 16-bit case handled by BSWAP.
+  case 32:
+    // 32-bit elements, 16-bit rotation -> revh.
+    RevOp = AArch64ISD::REVH_MERGE_PASSTHRU;
+    break;
+  case 64:
+    // 64-bit elements, 32-bit rotation -> revw.
+    RevOp = AArch64ISD::REVW_MERGE_PASSTHRU;
+    break;
+  default:
+    return SDValue();
+  }
+
+ // Create the REV instruction
+ SDLoc DL(N);
+ SDValue Src = N->getOperand(0);
+ SDValue Pg = getPredicateForVector(DAG, DL, VT);
+ SDValue Undef = DAG.getUNDEF(VT);
+
+ return DAG.getNode(RevOp, DL, VT, Pg, Src, Undef);
+}
+
+// Fold predicated shl + srl + orr to rev for half-width shifts on scalable
+// vectors. Pattern:
+//   orr(SHL_PRED(pg, x, c), SRL_PRED(pg, x, c)) -> rev(x), c == EltSize / 2.
+static SDValue performSVE_SHL_SRL_ORRCombine(SDNode *N, SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget) {
+ if (!Subtarget->hasSVE())
+ return SDValue();
+
+ EVT VT = N->getValueType(0);
+ if (!VT.isScalableVector())
+ return SDValue();
+
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
+
+  // Check if one operand is a predicated SHL and the other a predicated SRL.
+  SDValue SHL, SRL;
+  if (LHS.getOpcode() == AArch64ISD::SHL_PRED &&
+      RHS.getOpcode() == AArch64ISD::SRL_PRED) {
+    SHL = LHS;
+    SRL = RHS;
+  } else if (LHS.getOpcode() == AArch64ISD::SRL_PRED &&
+             RHS.getOpcode() == AArch64ISD::SHL_PRED) {
+    SHL = RHS;
+    SRL = LHS;
+  } else {
+    return SDValue();
+  }
+
+ // Check that both shifts operate on the same predicate and source
+ SDValue SHLPred = SHL.getOperand(0);
+ SDValue SHLSrc = SHL.getOperand(1);
+ SDValue SHLAmt = SHL.getOperand(2);
+
+ SDValue SRLPred = SRL.getOperand(0);
+ SDValue SRLSrc = SRL.getOperand(1);
+ SDValue SRLAmt = SRL.getOperand(2);
+
+ if (SHLPred != SRLPred || SHLSrc != SRLSrc || SHLAmt != SRLAmt)
+ return SDValue();
+
+  // The shift amount must be a constant splat.
+  if (SHLAmt.getOpcode() != ISD::SPLAT_VECTOR ||
+      !isa<ConstantSDNode>(SHLAmt.getOperand(0)))
+    return SDValue();
+
+  uint64_t ShiftAmt = SHLAmt.getConstantOperandVal(0);
+
+ // Check if shift amount equals half the bitwidth
+ EVT EltVT = VT.getVectorElementType();
+ if (!EltVT.isSimple())
+ return SDValue();
+
+ unsigned EltSize = EltVT.getSizeInBits();
+ if (ShiftAmt != EltSize / 2)
+ return SDValue();
+
+  // Map the element size to the corresponding predicated REV node.
+  unsigned RevOp;
+  switch (EltSize) {
+  case 16:
+    return SDValue(); // 16-bit case handled by BSWAP.
+  case 32:
+    // 32-bit elements, 16-bit shift -> revh.
+    RevOp = AArch64ISD::REVH_MERGE_PASSTHRU;
+    break;
+  case 64:
+    // 64-bit elements, 32-bit shift -> revw.
+    RevOp = AArch64ISD::REVW_MERGE_PASSTHRU;
+    break;
+  default:
+    return SDValue();
+  }
+
+ // Create the REV instruction
+ SDLoc DL(N);
+ SDValue Pg = SHLPred;
+ SDValue Src = SHLSrc;
+ SDValue Undef = DAG.getUNDEF(VT);
+
+ return DAG.getNode(RevOp, DL, VT, Pg, Src, Undef);
+}
+
static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget,
const AArch64TargetLowering &TLI) {
@@ -19972,6 +20220,13 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
if (SDValue R = performANDORCSELCombine(N, DAG))
return R;
+ if (SDValue R = performLSL_LSR_ORRCombine(N, DAG, Subtarget))
+ return R;
+
+ // Try the predicated shift combine for SVE
+ if (SDValue R = performSVE_SHL_SRL_ORRCombine(N, DAG, Subtarget))
+ return R;
+
return SDValue();
}
@@ -27592,6 +27847,12 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performFpToIntCombine(N, DAG, DCI, Subtarget);
case ISD::OR:
return performORCombine(N, DCI, Subtarget, *this);
+  case ISD::BSWAP:
+  case AArch64ISD::BSWAP_MERGE_PASSTHRU:
+    return performBSWAPCombine(N, DAG, Subtarget);
+ case ISD::ROTL:
+ return performROTLCombine(N, DAG, Subtarget);
case ISD::AND:
return performANDCombine(N, DCI);
case ISD::FADD:
diff --git a/llvm/test/CodeGen/AArch64/sve-lsl-lsr-orr-rev-combine.ll b/llvm/test/CodeGen/AArch64/sve-lsl-lsr-orr-rev-combine.ll
new file mode 100644
index 0000000000000..8abfbdcc3edcb
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-lsl-lsr-orr-rev-combine.ll
@@ -0,0 +1,97 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck --check-prefixes=CHECK,CHECK-SVE %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -force-streaming < %s | FileCheck --check-prefixes=CHECK,CHECK-SME %s
+
+; Test the optimization that folds lsl + lsr + orr to rev for half-width shifts
+
+; Test case 1: 16-bit elements with 8-bit shift -> revb
+define <vscale x 8 x i16> @lsl_lsr_orr_revb_i16(<vscale x 8 x i16> %x) {
+; CHECK-LABEL: lsl_lsr_orr_revb_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: revb z0.h, p0/m, z0.h
+; CHECK-NEXT: ret
+ %lsl = shl <vscale x 8 x i16> %x, splat(i16 8)
+ %lsr = lshr <vscale x 8 x i16> %x, splat(i16 8)
+ %orr = or <vscale x 8 x i16> %lsl, %lsr
+ ret <vscale x 8 x i16> %orr
+}
+
+; Test case 2: 32-bit elements with 16-bit shift -> revh
+define <vscale x 4 x i32> @lsl_lsr_orr_revh_i32(<vscale x 4 x i32> %x) {
+; CHECK-SVE-LABEL: lsl_lsr_orr_revh_i32:
+; CHECK-SVE: // %bb.0:
+; CHECK-SVE-NEXT: ptrue p0.s
+; CHECK-SVE-NEXT: revh z0.s, p0/m, z0.s
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SME-LABEL: lsl_lsr_orr_revh_i32:
+; CHECK-SME: // %bb.0:
+; CHECK-SME-NEXT: movi v1.2d, #0000000000000000
+; CHECK-SME-NEXT: xar z0.s, z0.s, z1.s, #16
+; CHECK-SME-NEXT: ret
+ %lsl = shl <vscale x 4 x i32> %x, splat(i32 16)
+ %lsr = lshr <vscale x 4 x i32> %x, splat(i32 16)
+ %orr = or <vscale x 4 x i32> %lsl, %lsr
+ ret <vscale x 4 x i32> %orr
+}
+
+; Test case 3: 64-bit elements with 32-bit shift -> revw
+define <vscale x 2 x i64> @lsl_lsr_orr_revw_i64(<vscale x 2 x i64> %x) {
+; CHECK-SVE-LABEL: lsl_lsr_orr_revw_i64:
+; CHECK-SVE: // %bb.0:
+; CHECK-SVE-NEXT: ptrue p0.d
+; CHECK-SVE-NEXT: revw z0.d, p0/m, z0.d
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SME-LABEL: lsl_lsr_orr_revw_i64:
+; CHECK-SME: // %bb.0:
+; CHECK-SME-NEXT: movi v1.2d, #0000000000000000
+; CHECK-SME-NEXT: xar z0.d, z0.d, z1.d, #32
+; CHECK-SME-NEXT: ret
+ %lsl = shl <vscale x 2 x i64> %x, splat(i64 32)
+ %lsr = lshr <vscale x 2 x i64> %x, splat(i64 32)
+ %orr = or <vscale x 2 x i64> %lsl, %lsr
+ ret <vscale x 2 x i64> %orr
+}
+
+; Test case 4: Order doesn't matter - lsr + lsl + orr -> revb
+define <vscale x 8 x i16> @lsr_lsl_orr_revb_i16(<vscale x 8 x i16> %x) {
+; CHECK-LABEL: lsr_lsl_orr_revb_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: revb z0.h, p0/m, z0.h
+; CHECK-NEXT: ret
+ %lsr = lshr <vscale x 8 x i16> %x, splat(i16 8)
+ %lsl = shl <vscale x 8 x i16> %x, splat(i16 8)
+ %orr = or <vscale x 8 x i16> %lsr, %lsl
+ ret <vscale x 8 x i16> %orr
+}
+
+; Test case 5: Non-half-width shift should not be optimized
+define <vscale x 8 x i16> @lsl_lsr_orr_no_opt_i16(<vscale x 8 x i16> %x) {
+; CHECK-LABEL: lsl_lsr_orr_no_opt_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z1.h, z0.h, #4
+; CHECK-NEXT: lsr z0.h, z0.h, #4
+; CHECK-NEXT: orr z0.d, z1.d, z0.d
+; CHECK-NEXT: ret
+ %lsl = shl <vscale x 8 x i16> %x, splat(i16 4)
+ %lsr = lshr <vscale x 8 x i16> %x, splat(i16 4)
+ %orr = or <vscale x 8 x i16> %lsl, %lsr
+ ret <vscale x 8 x i16> %orr
+}
+
+; Test case 6: Different shift amounts should not be optimized
+define <vscale x 8 x i16> @lsl_lsr_orr_different_shifts_i16(<vscale x 8 x i16> %x) {
+; CHECK-LABEL: lsl_lsr_orr_different_shifts_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z1.h, z0.h, #8
+; CHECK-NEXT: lsr z0.h, z0.h, #4
+; CHECK-NEXT: orr z0.d, z1.d, z0.d
+; CHECK-NEXT: ret
+ %lsl = shl <vscale x 8 x i16> %x, splat(i16 8)
+ %lsr = lshr <vscale x 8 x i16> %x, splat(i16 4)
+ %orr = or <vscale x 8 x i16> %lsl, %lsr
+ ret <vscale x 8 x i16> %orr
+}
``````````
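The identity all of these combines rely on can be sanity-checked with a small scalar sketch (plain C++, not part of the patch; the helper names are made up for illustration): OR-ing together a left shift and a right shift of the same lane by half its width is a rotate by half the lane width, and that rotate simply swaps the lane's two halves, which is the per-element operation the SVE `revb`/`revh`/`revw` forms perform for 16-, 32- and 64-bit elements.

```cpp
// Scalar model of the rewrite (illustrative only, not part of the patch).
#include <cassert>
#include <cstdint>

// (x << 16) | (x >> 16) on a 32-bit lane ...
static uint32_t rotHalf32(uint32_t X) { return (X << 16) | (X >> 16); }
// ... is the same as swapping the lane's two halfwords, which is what revh
// does within each 32-bit element.
static uint32_t swapHalfwords(uint32_t X) {
  return ((X & 0x0000FFFFu) << 16) | ((X & 0xFFFF0000u) >> 16);
}

// Likewise for 64-bit lanes and revw.
static uint64_t rotHalf64(uint64_t X) { return (X << 32) | (X >> 32); }
static uint64_t swapWords(uint64_t X) {
  return ((X & 0x00000000FFFFFFFFull) << 32) |
         ((X & 0xFFFFFFFF00000000ull) >> 32);
}

int main() {
  assert(rotHalf32(0x11223344u) == swapHalfwords(0x11223344u)); // 0x33441122
  assert(rotHalf64(0x1122334455667788ull) ==
         swapWords(0x1122334455667788ull)); // 0x5566778811223344
  return 0;
}
```

In the vector case the patch replaces the whole shl/lshr/or (or rotate) subtree with the matching REV*_MERGE_PASSTHRU node over the same source, governed by an all-true predicate (or by the incoming predicate in the SHL_PRED/SRL_PRED form).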
</details>
https://github.com/llvm/llvm-project/pull/159953