[llvm] [AArch64] Fold lsl + lsr + orr to rev for half-width shifts (PR #159953)

via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 20 15:36:21 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: AZero13 (AZero13)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/159953.diff


2 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+261) 
- (added) llvm/test/CodeGen/AArch64/sve-lsl-lsr-orr-rev-combine.ll (+97) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index cd7f0e719ad0c..183fc763cd2e9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -19964,6 +19964,254 @@ static SDValue performANDORCSELCombine(SDNode *N, SelectionDAG &DAG) {
                      CSel0.getOperand(1), getCondCode(DAG, CC1), CCmp);
 }
 
+// Fold lsl + lsr + orr to rev for half-width shifts on scalable vectors.
+// Pattern: or(shl(x, C), srl(x, C)) -> rev(x) when C == EltBits / 2, which is
+// a rotate by half the element width, i.e. a swap of the element's two halves:
+//   nxv8i16: revb .h    nxv4i32: revh .s    nxv2i64: revw .d
+static SDValue performLSL_LSR_ORRCombine(SDNode *N, SelectionDAG &DAG,
+                                         const AArch64Subtarget *Subtarget) {
+  // NOTE(review): streaming SVE could likely use REV* as well; the hasSVE()
+  // gate is kept as-is because the accompanying tests expect XAR for +sme.
+  if (!Subtarget->hasSVE())
+    return SDValue();
+
+  // Only legal, full-width types: this can run before type legalization, and
+  // an unpacked type (e.g. nxv2i32) would yield an illegally-typed REV node.
+  // isTypeLegal() also rejects extended (non-simple) types.
+  EVT VT = N->getValueType(0);
+  if (!VT.isScalableVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT))
+    return SDValue();
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  // OR is commutative: accept the shl and srl in either operand position.
+  SDValue LSL, LSR;
+  if (LHS.getOpcode() == ISD::SHL && RHS.getOpcode() == ISD::SRL) {
+    LSL = LHS;
+    LSR = RHS;
+  } else if (LHS.getOpcode() == ISD::SRL && RHS.getOpcode() == ISD::SHL) {
+    LSL = RHS;
+    LSR = LHS;
+  } else {
+    return SDValue();
+  }
+
+  // Both shifts must read the same source value.
+  SDValue Src = LSL.getOperand(0);
+  if (Src != LSR.getOperand(0))
+    return SDValue();
+
+  // Scalable-vector shift amounts are splats, never plain ConstantSDNodes, so
+  // look through SPLAT_VECTOR / DUP to the scalar constant. Anything else (a
+  // variable or non-uniform amount) is rejected.
+  auto GetSplatAmt = [](SDValue Amt) -> ConstantSDNode * {
+    if (Amt.getOpcode() != ISD::SPLAT_VECTOR &&
+        Amt.getOpcode() != AArch64ISD::DUP)
+      return nullptr;
+    return dyn_cast<ConstantSDNode>(Amt.getOperand(0));
+  };
+  ConstantSDNode *LSLAmt = GetSplatAmt(LSL.getOperand(1));
+  ConstantSDNode *LSRAmt = GetSplatAmt(LSR.getOperand(1));
+  if (!LSLAmt || !LSRAmt)
+    return SDValue();
+
+  // Both shifts must use the same amount, equal to half the element width.
+  // (A splat operand may be wider than the element after promotion; any
+  // value other than exactly EltSize / 2 is conservatively rejected.)
+  uint64_t ShiftAmt = LSLAmt->getZExtValue();
+  unsigned EltSize = VT.getScalarSizeInBits();
+  if (ShiftAmt != LSRAmt->getZExtValue() || ShiftAmt != EltSize / 2)
+    return SDValue();
+
+  // Pick the REV that swaps the two halves of each element.
+  unsigned RevOp;
+  switch (EltSize) {
+  case 16:
+    RevOp = AArch64ISD::BSWAP_MERGE_PASSTHRU; // revb: swap bytes in .h
+    break;
+  case 32:
+    RevOp = AArch64ISD::REVH_MERGE_PASSTHRU; // revh: swap halfwords in .s
+    break;
+  case 64:
+    RevOp = AArch64ISD::REVW_MERGE_PASSTHRU; // revw: swap words in .d
+    break;
+  default:
+    return SDValue(); // e.g. i8: no half-width REV exists.
+  }
+
+  // All-active predicate; the passthru is unused, so leave it undefined.
+  SDLoc DL(N);
+  SDValue Pg = getPredicateForVector(DAG, DL, VT);
+  return DAG.getNode(RevOp, DL, VT, Pg, Src, DAG.getUNDEF(VT));
+}
+
+// BSWAP combine hook for scalable vectors. It intentionally never rewrites
+// the node and always returns SDValue().
+//
+// For 16-bit elements the default lowering is already correct: BSWAP on
+// scalable vectors becomes BSWAP_MERGE_PASSTHRU, which selects to REVB, and
+// reversing the bytes of a halfword is exactly the half-width swap. REVH is
+// not an option there since it only operates on 32- and 64-bit elements.
+// Other element sizes also keep the default lowering.
+//
+// The hook exists only because PerformDAGCombine dispatches ISD::BSWAP and
+// AArch64ISD::BSWAP_MERGE_PASSTHRU to it, so its parameters are unused.
+// TODO(review): remove it together with those dispatch cases. Note that
+// ISD::BSWAP is only routed here if it is registered with
+// setTargetDAGCombine, which this patch does not appear to do.
+static SDValue performBSWAPCombine(SDNode *N, SelectionDAG &DAG,
+                                   const AArch64Subtarget *Subtarget) {
+  return SDValue();
+}
+
+// Fold rotl by half the element width to rev on scalable vectors.
+// Pattern: rotl(x, EltBits / 2) -> rev(x). Rotating by half the width swaps
+// the two halves of each element, which is what these REV forms do:
+//   nxv8i16: revb .h    nxv4i32: revh .s    nxv2i64: revw .d
+// REVW only exists for 64-bit elements and REVD swaps the doublewords of each
+// 128-bit quadword, so neither is the half-swap for 32-/64-bit elements.
+// NOTE(review): ISD::ROTL only reaches this function if it is registered with
+// setTargetDAGCombine; confirm that registration exists.
+static SDValue performROTLCombine(SDNode *N, SelectionDAG &DAG,
+                                  const AArch64Subtarget *Subtarget) {
+  if (!Subtarget->hasSVE())
+    return SDValue();
+
+  // Only legal, full-width types: this can run before type legalization, and
+  // an unpacked type (e.g. nxv2i32) would yield an illegally-typed REV node.
+  EVT VT = N->getValueType(0);
+  if (!VT.isScalableVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT))
+    return SDValue();
+
+  // A scalable-vector rotate amount is a splat, never a plain ConstantSDNode,
+  // so look through SPLAT_VECTOR / DUP to the scalar constant.
+  SDValue Amt = N->getOperand(1);
+  if (Amt.getOpcode() != ISD::SPLAT_VECTOR &&
+      Amt.getOpcode() != AArch64ISD::DUP)
+    return SDValue();
+  auto *RotC = dyn_cast<ConstantSDNode>(Amt.getOperand(0));
+  unsigned EltSize = VT.getScalarSizeInBits();
+  if (!RotC || RotC->getZExtValue() != EltSize / 2)
+    return SDValue();
+
+  // Pick the REV that swaps the two halves of each element.
+  unsigned RevOp;
+  switch (EltSize) {
+  case 16:
+    RevOp = AArch64ISD::BSWAP_MERGE_PASSTHRU; // revb: swap bytes in .h
+    break;
+  case 32:
+    RevOp = AArch64ISD::REVH_MERGE_PASSTHRU; // revh: swap halfwords in .s
+    break;
+  case 64:
+    RevOp = AArch64ISD::REVW_MERGE_PASSTHRU; // revw: swap words in .d
+    break;
+  default:
+    return SDValue(); // e.g. i8: no half-width REV exists.
+  }
+
+  // All-active predicate; the passthru is unused, so leave it undefined.
+  SDLoc DL(N);
+  SDValue Pg = getPredicateForVector(DAG, DL, VT);
+  return DAG.getNode(RevOp, DL, VT, Pg, N->getOperand(0), DAG.getUNDEF(VT));
+}
+
+// Fold predicated shl + srl + orr to rev for half-width shifts on scalable
+// vectors. This matches the lsl/lsr/orr rotate pattern after SHL/SRL have
+// been lowered to their predicated nodes:
+//   or(SHL_PRED(pg, x, C), SRL_PRED(pg, x, C)) -> rev(pg, x), C == EltBits/2
+// Inactive lanes of SHL_PRED/SRL_PRED are undefined, so a REV governed by the
+// same predicate with an undef passthru is an exact replacement.
+static SDValue performSVE_SHL_SRL_ORRCombine(SDNode *N, SelectionDAG &DAG,
+                                             const AArch64Subtarget *Subtarget) {
+  if (!Subtarget->hasSVE())
+    return SDValue();
+
+  // REV nodes need a legal, full-width element type.
+  EVT VT = N->getValueType(0);
+  if (!VT.isScalableVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT))
+    return SDValue();
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  // OR is commutative: accept the two shifts in either operand position.
+  SDValue SHL, SRL;
+  if (LHS.getOpcode() == AArch64ISD::SHL_PRED &&
+      RHS.getOpcode() == AArch64ISD::SRL_PRED) {
+    SHL = LHS;
+    SRL = RHS;
+  } else if (LHS.getOpcode() == AArch64ISD::SRL_PRED &&
+             RHS.getOpcode() == AArch64ISD::SHL_PRED) {
+    SHL = RHS;
+    SRL = LHS;
+  } else {
+    return SDValue();
+  }
+
+  // Operands are (predicate, source, amount); all three must be shared.
+  SDValue Pg = SHL.getOperand(0);
+  SDValue Src = SHL.getOperand(1);
+  SDValue Amt = SHL.getOperand(2);
+  if (Pg != SRL.getOperand(0) || Src != SRL.getOperand(1) ||
+      Amt != SRL.getOperand(2))
+    return SDValue();
+
+  // The amount must be a uniform constant. Only splat nodes carry the per-lane
+  // value in operand 0; on any other node operand 0 may be something else
+  // entirely (or not exist), so reading it could miscompile or assert.
+  if (Amt.getOpcode() != ISD::SPLAT_VECTOR &&
+      Amt.getOpcode() != AArch64ISD::DUP)
+    return SDValue();
+  auto *AmtC = dyn_cast<ConstantSDNode>(Amt.getOperand(0));
+  unsigned EltSize = VT.getScalarSizeInBits();
+  if (!AmtC || AmtC->getZExtValue() != EltSize / 2)
+    return SDValue();
+
+  // Pick the REV that swaps the two halves of each element.
+  unsigned RevOp;
+  switch (EltSize) {
+  case 16:
+    RevOp = AArch64ISD::BSWAP_MERGE_PASSTHRU; // revb: swap bytes in .h
+    break;
+  case 32:
+    RevOp = AArch64ISD::REVH_MERGE_PASSTHRU; // revh: swap halfwords in .s
+    break;
+  case 64:
+    RevOp = AArch64ISD::REVW_MERGE_PASSTHRU; // revw: swap words in .d
+    break;
+  default:
+    return SDValue(); // e.g. i8: no half-width REV exists.
+  }
+
+  return DAG.getNode(RevOp, SDLoc(N), VT, Pg, Src, DAG.getUNDEF(VT));
+}
+
 static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                 const AArch64Subtarget *Subtarget,
                                 const AArch64TargetLowering &TLI) {
@@ -19972,6 +20220,13 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   if (SDValue R = performANDORCSELCombine(N, DAG))
     return R;
 
+  if (SDValue R = performLSL_LSR_ORRCombine(N, DAG, Subtarget))
+    return R;
+
+  // Try the predicated shift combine for SVE 
+  if (SDValue R = performSVE_SHL_SRL_ORRCombine(N, DAG, Subtarget))
+    return R;
+
   return SDValue();
 }
 
@@ -27592,6 +27847,12 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performFpToIntCombine(N, DAG, DCI, Subtarget);
   case ISD::OR:
     return performORCombine(N, DCI, Subtarget, *this);
+  case ISD::BSWAP:
+    return performBSWAPCombine(N, DAG, Subtarget);
+  case AArch64ISD::BSWAP_MERGE_PASSTHRU:
+    return performBSWAPCombine(N, DAG, Subtarget);
+  case ISD::ROTL:
+    return performROTLCombine(N, DAG, Subtarget);
   case ISD::AND:
     return performANDCombine(N, DCI);
   case ISD::FADD:
diff --git a/llvm/test/CodeGen/AArch64/sve-lsl-lsr-orr-rev-combine.ll b/llvm/test/CodeGen/AArch64/sve-lsl-lsr-orr-rev-combine.ll
new file mode 100644
index 0000000000000..8abfbdcc3edcb
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-lsl-lsr-orr-rev-combine.ll
@@ -0,0 +1,97 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck --check-prefixes=CHECK,CHECK-SVE %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -force-streaming < %s | FileCheck --check-prefixes=CHECK,CHECK-SME %s
+
+; Test the optimization that folds lsl + lsr + orr to rev for half-width shifts
+
+; Test case 1: 16-bit elements with 8-bit shift -> revb (byte swap per halfword)
+define <vscale x 8 x i16> @lsl_lsr_orr_revb_i16(<vscale x 8 x i16> %x) {
+; CHECK-LABEL: lsl_lsr_orr_revb_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    revb z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %lsl = shl <vscale x 8 x i16> %x, splat(i16 8)
+  %lsr = lshr <vscale x 8 x i16> %x, splat(i16 8)
+  %orr = or <vscale x 8 x i16> %lsl, %lsr
+  ret <vscale x 8 x i16> %orr
+}
+
+; Test case 2: 32-bit elements with 16-bit shift -> revh (xar when streaming)
+define <vscale x 4 x i32> @lsl_lsr_orr_revh_i32(<vscale x 4 x i32> %x) {
+; CHECK-SVE-LABEL: lsl_lsr_orr_revh_i32:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    ptrue p0.s
+; CHECK-SVE-NEXT:    revh z0.s, p0/m, z0.s
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SME-LABEL: lsl_lsr_orr_revh_i32:
+; CHECK-SME:       // %bb.0:
+; CHECK-SME-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-SME-NEXT:    xar z0.s, z0.s, z1.s, #16
+; CHECK-SME-NEXT:    ret
+  %lsl = shl <vscale x 4 x i32> %x, splat(i32 16)
+  %lsr = lshr <vscale x 4 x i32> %x, splat(i32 16)
+  %orr = or <vscale x 4 x i32> %lsl, %lsr
+  ret <vscale x 4 x i32> %orr
+}
+
+; Test case 3: 64-bit elements with 32-bit shift -> revw (xar when streaming)
+define <vscale x 2 x i64> @lsl_lsr_orr_revw_i64(<vscale x 2 x i64> %x) {
+; CHECK-SVE-LABEL: lsl_lsr_orr_revw_i64:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    ptrue p0.d
+; CHECK-SVE-NEXT:    revw z0.d, p0/m, z0.d
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SME-LABEL: lsl_lsr_orr_revw_i64:
+; CHECK-SME:       // %bb.0:
+; CHECK-SME-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-SME-NEXT:    xar z0.d, z0.d, z1.d, #32
+; CHECK-SME-NEXT:    ret
+  %lsl = shl <vscale x 2 x i64> %x, splat(i64 32)
+  %lsr = lshr <vscale x 2 x i64> %x, splat(i64 32)
+  %orr = or <vscale x 2 x i64> %lsl, %lsr
+  ret <vscale x 2 x i64> %orr
+}
+
+; Test case 4: operand order doesn't matter - lsr + lsl + orr -> revb
+define <vscale x 8 x i16> @lsr_lsl_orr_revb_i16(<vscale x 8 x i16> %x) {
+; CHECK-LABEL: lsr_lsl_orr_revb_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    revb z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %lsr = lshr <vscale x 8 x i16> %x, splat(i16 8)
+  %lsl = shl <vscale x 8 x i16> %x, splat(i16 8)
+  %orr = or <vscale x 8 x i16> %lsr, %lsl
+  ret <vscale x 8 x i16> %orr
+}
+
+; Test case 5: a 4-bit shift is not half of 16 bits, so no rev may be formed
+define <vscale x 8 x i16> @lsl_lsr_orr_no_opt_i16(<vscale x 8 x i16> %x) {
+; CHECK-LABEL: lsl_lsr_orr_no_opt_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z1.h, z0.h, #4
+; CHECK-NEXT:    lsr z0.h, z0.h, #4
+; CHECK-NEXT:    orr z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+  %hi = shl <vscale x 8 x i16> %x, splat(i16 4)
+  %lo = lshr <vscale x 8 x i16> %x, splat(i16 4)
+  %res = or <vscale x 8 x i16> %hi, %lo
+  ret <vscale x 8 x i16> %res
+}
+
+; Test case 6: mismatched shift amounts (8 vs 4) must not be folded to rev
+define <vscale x 8 x i16> @lsl_lsr_orr_different_shifts_i16(<vscale x 8 x i16> %x) {
+; CHECK-LABEL: lsl_lsr_orr_different_shifts_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z1.h, z0.h, #8
+; CHECK-NEXT:    lsr z0.h, z0.h, #4
+; CHECK-NEXT:    orr z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+  %hi = shl <vscale x 8 x i16> %x, splat(i16 8)
+  %lo = lshr <vscale x 8 x i16> %x, splat(i16 4)
+  %res = or <vscale x 8 x i16> %hi, %lo
+  ret <vscale x 8 x i16> %res
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/159953


More information about the llvm-commits mailing list