[llvm] [AArch64][SVE2] Generate urshr rounding shift rights (PR #78374)

Usman Nadeem via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 31 13:28:44 PST 2024


https://github.com/UsmanNadeem updated https://github.com/llvm/llvm-project/pull/78374

From 5387297fd54b3055f9876db7e1e7f4299421672f Mon Sep 17 00:00:00 2001
From: "Nadeem, Usman" <mnadeem at quicinc.com>
Date: Tue, 16 Jan 2024 17:02:01 -0800
Subject: [PATCH 1/4] [AArch64][SVE2] Generate signed/unsigned rounding shift
 rights

The matching code is similar to that for rshrnb, except that the
immediate shift value has a larger range and signed shifts are also
supported. rshrnb now uses the new AArch64ISD node for uniform rounding.
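
As a rough scalar illustration of the pattern being matched (my own sketch,
not part of the patch): a rounding shift right by C adds 2^(C-1) before
shifting, so it rounds half up instead of truncating.

#include <cassert>
#include <cstdint>

// Rounding (round-half-up) shift right by C, i.e. floor((X + 2^(C-1)) / 2^C).
// This is the per-element value an SVE2 rounding shift is expected to produce.
static uint64_t roundingShiftRight(uint64_t X, unsigned C) {
  assert(C >= 1 && C <= 32 && "illustrative range only");
  return (X + (1ULL << (C - 1))) >> C;
}

int main() {
  // The DAG pattern srl(add(X, 1 << (C - 1)), C) computes exactly this,
  // provided the add cannot wrap in the element type.
  assert(roundingShiftRight(100, 6) == 2); // (100 + 32) >> 6
  assert(roundingShiftRight(95, 6) == 1);  // rounds down
  return 0;
}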

Change-Id: Idbb811f318d33c7637371cf7bb00285d20e1771d
---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  81 +++++--
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |   2 +
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  10 +-
 .../AArch64/sve2-intrinsics-combine-rshrnb.ll |  17 +-
 llvm/test/CodeGen/AArch64/sve2-rsh.ll         | 203 ++++++++++++++++++
 5 files changed, 276 insertions(+), 37 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sve2-rsh.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index aa208e397f5d9..895319f3d7115 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2689,6 +2689,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::RSHRNB_I)
     MAKE_CASE(AArch64ISD::CTTZ_ELTS)
     MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
+    MAKE_CASE(AArch64ISD::SRSHR_I_PRED)
+    MAKE_CASE(AArch64ISD::URSHR_I_PRED)
   }
 #undef MAKE_CASE
   return nullptr;
@@ -2973,6 +2975,7 @@ static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
 static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
 static SDValue convertFixedMaskToScalableVector(SDValue Mask,
                                                 SelectionDAG &DAG);
+static SDValue getPredicateForVector(SelectionDAG &DAG, SDLoc &DL, EVT VT);
 static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
                                              EVT VT);
 
@@ -13838,6 +13841,42 @@ SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
   return SDValue();
 }
 
+static SDValue tryLowerToRoundingShiftRightByImm(SDValue Shift,
+                                                 SelectionDAG &DAG) {
+  if (Shift->getOpcode() != ISD::SRL && Shift->getOpcode() != ISD::SRA)
+    return SDValue();
+
+  EVT ResVT = Shift.getValueType();
+  assert(ResVT.isScalableVT());
+
+  auto ShiftOp1 =
+      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Shift->getOperand(1)));
+  if (!ShiftOp1)
+    return SDValue();
+  unsigned ShiftValue = ShiftOp1->getZExtValue();
+
+  if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
+    return SDValue();
+
+  SDValue Add = Shift->getOperand(0);
+  if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
+    return SDValue();
+  auto AddOp1 =
+      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
+  if (!AddOp1)
+    return SDValue();
+  uint64_t AddValue = AddOp1->getZExtValue();
+  if (AddValue != 1ULL << (ShiftValue - 1))
+    return SDValue();
+
+  SDLoc DL(Shift);
+  unsigned Opc = Shift->getOpcode() == ISD::SRA ? AArch64ISD::SRSHR_I_PRED
+                                                : AArch64ISD::URSHR_I_PRED;
+  return DAG.getNode(Opc, DL, ResVT, getPredicateForVector(DAG, DL, ResVT),
+                     Add->getOperand(0),
+                     DAG.getTargetConstant(ShiftValue, DL, MVT::i32));
+}
+
 SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
                                                       SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
@@ -13863,6 +13902,10 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
                        Op.getOperand(0), Op.getOperand(1));
   case ISD::SRA:
   case ISD::SRL:
+    if (VT.isScalableVector() && Subtarget->hasSVE2orSME())
+      if (SDValue RSH = tryLowerToRoundingShiftRightByImm(Op, DAG))
+        return RSH;
+
     if (VT.isScalableVector() ||
         useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
       unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED
@@ -20192,6 +20235,12 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_sve_uqsub_x:
     return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2));
+  case Intrinsic::aarch64_sve_srshr:
+    return DAG.getNode(AArch64ISD::SRSHR_I_PRED, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2), N->getOperand(3));
+  case Intrinsic::aarch64_sve_urshr:
+    return DAG.getNode(AArch64ISD::URSHR_I_PRED, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2), N->getOperand(3));
   case Intrinsic::aarch64_sve_asrd:
     return DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
@@ -20819,12 +20868,13 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
 // a uzp1 or a truncating store.
 static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
                                          const AArch64Subtarget *Subtarget) {
-  EVT VT = Srl->getValueType(0);
+  if (Srl->getOpcode() != AArch64ISD::URSHR_I_PRED)
+    return SDValue();
 
-  if (!VT.isScalableVector() || !Subtarget->hasSVE2() ||
-      Srl->getOpcode() != ISD::SRL)
+  if (!isAllActivePredicate(DAG, Srl.getOperand(0)))
     return SDValue();
 
+  EVT VT = Srl->getValueType(0);
   EVT ResVT;
   if (VT == MVT::nxv8i16)
     ResVT = MVT::nxv16i8;
@@ -20835,29 +20885,14 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
   else
     return SDValue();
 
-  auto SrlOp1 =
-      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Srl->getOperand(1)));
-  if (!SrlOp1)
-    return SDValue();
-  unsigned ShiftValue = SrlOp1->getZExtValue();
-  if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
-    return SDValue();
-
-  SDValue Add = Srl->getOperand(0);
-  if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
-    return SDValue();
-  auto AddOp1 =
-      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
-  if (!AddOp1)
-    return SDValue();
-  uint64_t AddValue = AddOp1->getZExtValue();
-  if (AddValue != 1ULL << (ShiftValue - 1))
+  unsigned ShiftValue =
+      cast<ConstantSDNode>(Srl->getOperand(2))->getZExtValue();
+  if (ShiftValue > ResVT.getScalarSizeInBits())
     return SDValue();
 
   SDLoc DL(Srl);
-  SDValue Rshrnb = DAG.getNode(
-      AArch64ISD::RSHRNB_I, DL, ResVT,
-      {Add->getOperand(0), DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
+  SDValue Rshrnb = DAG.getNode(AArch64ISD::RSHRNB_I, DL, ResVT,
+                               {Srl->getOperand(1), Srl->getOperand(2)});
   return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
 }
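
An aside on the combine above (my reading, not text from the patch): rshrnb is
assumed to perform the rounding shift and narrow each element to half width,
leaving the top half of every wide element zero, which is why it may only be
used when the consumer (a uzp1 or a truncating store) ignores those top bits.
A scalar sketch under that assumption:

#include <cassert>
#include <cstdint>

// Assumed per-lane behaviour of rshrnb on a 16-bit source element: rounding
// shift right by C, keep the low 8 bits, zero the top 8 bits of the wide lane.
static uint16_t rshrnbLane16(uint16_t X, unsigned C) {
  uint32_t Rounded = (uint32_t(X) + (1u << (C - 1))) >> C; // no wrap at 32 bits
  return uint16_t(uint8_t(Rounded));
}

int main() {
  assert(rshrnbLane16(1000, 6) == 16);             // (1000 + 32) >> 6
  assert((rshrnbLane16(0xFFFF, 1) & 0xFF00) == 0); // top half always zero
  return 0;
}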
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 6505931e17e18..b292783886b78 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -215,7 +215,9 @@ enum NodeType : unsigned {
   UQSHL_I,
   SQSHLU_I,
   SRSHR_I,
+  SRSHR_I_PRED,
   URSHR_I,
+  URSHR_I_PRED,
 
   // Vector narrowing shift by immediate (bottom)
   RSHRNB_I,
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index c4d69232c9e30..516ab36464379 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -232,6 +232,8 @@ def SDT_AArch64Arith_Imm : SDTypeProfile<1, 3, [
 ]>;
 
 def AArch64asrd_m1 : SDNode<"AArch64ISD::SRAD_MERGE_OP1", SDT_AArch64Arith_Imm>;
+def AArch64urshri_p : SDNode<"AArch64ISD::URSHR_I_PRED", SDT_AArch64Arith_Imm>;
+def AArch64srshri_p : SDNode<"AArch64ISD::SRSHR_I_PRED", SDT_AArch64Arith_Imm>;
 
 def SDT_AArch64IntExtend : SDTypeProfile<1, 4, [
   SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVT<3, OtherVT>, SDTCisVec<4>,
@@ -3538,8 +3540,8 @@ let Predicates = [HasSVE2orSME] in {
   // SVE2 predicated shifts
   defm SQSHL_ZPmI  : sve_int_bin_pred_shift_imm_left_dup<0b0110, "sqshl",  "SQSHL_ZPZI",  int_aarch64_sve_sqshl>;
   defm UQSHL_ZPmI  : sve_int_bin_pred_shift_imm_left_dup<0b0111, "uqshl",  "UQSHL_ZPZI",  int_aarch64_sve_uqshl>;
-  defm SRSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<   0b1100, "srshr",  "SRSHR_ZPZI",  int_aarch64_sve_srshr>;
-  defm URSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<   0b1101, "urshr",  "URSHR_ZPZI",  int_aarch64_sve_urshr>;
+  defm SRSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<   0b1100, "srshr",  "SRSHR_ZPZI",  AArch64srshri_p>;
+  defm URSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<   0b1101, "urshr",  "URSHR_ZPZI",  AArch64urshri_p>;
   defm SQSHLU_ZPmI : sve_int_bin_pred_shift_imm_left<    0b1111, "sqshlu", "SQSHLU_ZPZI", int_aarch64_sve_sqshlu>;
 
   // SVE2 integer add/subtract long
@@ -3583,8 +3585,8 @@ let Predicates = [HasSVE2orSME] in {
   // SVE2 bitwise shift right and accumulate
   defm SSRA_ZZI  : sve2_int_bin_accum_shift_imm_right<0b00, "ssra",  AArch64ssra>;
   defm USRA_ZZI  : sve2_int_bin_accum_shift_imm_right<0b01, "usra",  AArch64usra>;
-  defm SRSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b10, "srsra", int_aarch64_sve_srsra, int_aarch64_sve_srshr>;
-  defm URSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b11, "ursra", int_aarch64_sve_ursra, int_aarch64_sve_urshr>;
+  defm SRSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b10, "srsra", int_aarch64_sve_srsra, AArch64srshri_p>;
+  defm URSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b11, "ursra", int_aarch64_sve_ursra, AArch64urshri_p>;
 
   // SVE2 complex integer add
   defm CADD_ZZI   : sve2_int_cadd<0b0, "cadd",   int_aarch64_sve_cadd_x>;
diff --git a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
index 0afd11d098a00..58ef846a31723 100644
--- a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
@@ -184,16 +184,14 @@ define void @wide_add_shift_add_rshrnb_d(ptr %dest, i64 %index, <vscale x 4 x i6
 define void @neg_wide_add_shift_add_rshrnb_d(ptr %dest, i64 %index, <vscale x 4 x i64> %arg1){
 ; CHECK-LABEL: neg_wide_add_shift_add_rshrnb_d:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z2.d, #0x800000000000
-; CHECK-NEXT:    ptrue p0.s
-; CHECK-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEXT:    add z1.d, z1.d, z2.d
-; CHECK-NEXT:    lsr z1.d, z1.d, #48
-; CHECK-NEXT:    lsr z0.d, z0.d, #48
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    urshr z1.d, p0/m, z1.d, #48
+; CHECK-NEXT:    urshr z0.d, p0/m, z0.d, #48
 ; CHECK-NEXT:    uzp1 z0.s, z0.s, z1.s
-; CHECK-NEXT:    ld1w { z1.s }, p0/z, [x0, x1, lsl #2]
+; CHECK-NEXT:    ld1w { z1.s }, p1/z, [x0, x1, lsl #2]
 ; CHECK-NEXT:    add z0.s, z1.s, z0.s
-; CHECK-NEXT:    st1w { z0.s }, p0, [x0, x1, lsl #2]
+; CHECK-NEXT:    st1w { z0.s }, p1, [x0, x1, lsl #2]
 ; CHECK-NEXT:    ret
   %1 = add <vscale x 4 x i64> %arg1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 140737488355328, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
   %2 = lshr <vscale x 4 x i64> %1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 48, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
@@ -286,8 +284,7 @@ define void @neg_add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0]
-; CHECK-NEXT:    add z0.d, z0.d, #32 // =0x20
-; CHECK-NEXT:    lsr z0.d, z0.d, #6
+; CHECK-NEXT:    urshr z0.d, p0/m, z0.d, #6
 ; CHECK-NEXT:    st1h { z0.d }, p0, [x1, x2, lsl #1]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 2 x i64>, ptr %ptr, align 2
diff --git a/llvm/test/CodeGen/AArch64/sve2-rsh.ll b/llvm/test/CodeGen/AArch64/sve2-rsh.ll
new file mode 100644
index 0000000000000..2bdfc1931cdc2
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2-rsh.ll
@@ -0,0 +1,203 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64 -mattr=+sve < %s -o - | FileCheck --check-prefixes=CHECK,SVE %s
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 < %s -o - | FileCheck --check-prefixes=CHECK,SVE2 %s
+
+; Wrong add/shift amount. Should be 32 for shift of 6.
+define <vscale x 2 x i64> @neg_urshr_1(<vscale x 2 x i64> %x) {
+; CHECK-LABEL: neg_urshr_1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add z0.d, z0.d, #16 // =0x10
+; CHECK-NEXT:    lsr z0.d, z0.d, #6
+; CHECK-NEXT:    ret
+  %add = add <vscale x 2 x i64> %x, splat (i64 16)
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
+  ret <vscale x 2 x i64> %sh
+}
+
+; Vector Shift.
+define <vscale x 2 x i64> @neg_urshr_2(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y) {
+; CHECK-LABEL: neg_urshr_2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    add z0.d, z0.d, #32 // =0x20
+; CHECK-NEXT:    lsr z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %add = add <vscale x 2 x i64> %x, splat (i64 32)
+  %sh = lshr <vscale x 2 x i64> %add, %y
+  ret <vscale x 2 x i64> %sh
+}
+
+; Vector Add.
+define <vscale x 2 x i64> @neg_urshr_3(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y) {
+; CHECK-LABEL: neg_urshr_3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add z0.d, z0.d, z1.d
+; CHECK-NEXT:    lsr z0.d, z0.d, #6
+; CHECK-NEXT:    ret
+  %add = add <vscale x 2 x i64> %x, %y
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
+  ret <vscale x 2 x i64> %sh
+}
+
+; Add has two uses.
+define <vscale x 2 x i64> @neg_urshr_4(<vscale x 2 x i64> %x) {
+; CHECK-LABEL: neg_urshr_4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    str z8, [sp] // 16-byte Folded Spill
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w30, -8
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NEXT:    add z0.d, z0.d, #32 // =0x20
+; CHECK-NEXT:    lsr z8.d, z0.d, #6
+; CHECK-NEXT:    bl use
+; CHECK-NEXT:    mov z0.d, z8.d
+; CHECK-NEXT:    ldr z8, [sp] // 16-byte Folded Reload
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  %add = add <vscale x 2 x i64> %x, splat (i64 32)
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
+  call void @use(<vscale x 2 x i64> %add)
+  ret <vscale x 2 x i64> %sh
+}
+
+define <vscale x 16 x i8> @urshr_i8(<vscale x 16 x i8> %x) {
+; SVE-LABEL: urshr_i8:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.b, z0.b, #32 // =0x20
+; SVE-NEXT:    lsr z0.b, z0.b, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_i8:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.b
+; SVE2-NEXT:    urshr z0.b, p0/m, z0.b, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 16 x i8> %x, splat (i8 32)
+  %sh = lshr <vscale x 16 x i8> %add, splat (i8 6)
+  ret <vscale x 16 x i8> %sh
+}
+
+define <vscale x 8 x i16> @urshr_i16(<vscale x 8 x i16> %x) {
+; SVE-LABEL: urshr_i16:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.h, z0.h, #32 // =0x20
+; SVE-NEXT:    lsr z0.h, z0.h, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_i16:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.h
+; SVE2-NEXT:    urshr z0.h, p0/m, z0.h, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 8 x i16> %x, splat (i16 32)
+  %sh = lshr <vscale x 8 x i16> %add, splat (i16 6)
+  ret <vscale x 8 x i16> %sh
+}
+
+define <vscale x 4 x i32> @urshr_i32(<vscale x 4 x i32> %x) {
+; SVE-LABEL: urshr_i32:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.s, z0.s, #32 // =0x20
+; SVE-NEXT:    lsr z0.s, z0.s, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_i32:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.s
+; SVE2-NEXT:    urshr z0.s, p0/m, z0.s, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 4 x i32> %x, splat (i32 32)
+  %sh = lshr <vscale x 4 x i32> %add, splat (i32 6)
+  ret <vscale x 4 x i32> %sh
+}
+
+define <vscale x 2 x i64> @urshr_i64(<vscale x 2 x i64> %x) {
+; SVE-LABEL: urshr_i64:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.d, z0.d, #32 // =0x20
+; SVE-NEXT:    lsr z0.d, z0.d, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_i64:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.d
+; SVE2-NEXT:    urshr z0.d, p0/m, z0.d, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 2 x i64> %x, splat (i64 32)
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
+  ret <vscale x 2 x i64> %sh
+}
+
+define <vscale x 16 x i8> @srshr_i8(<vscale x 16 x i8> %x) {
+; SVE-LABEL: srshr_i8:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.b, z0.b, #32 // =0x20
+; SVE-NEXT:    asr z0.b, z0.b, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: srshr_i8:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.b
+; SVE2-NEXT:    srshr z0.b, p0/m, z0.b, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 16 x i8> %x, splat (i8 32)
+  %sh = ashr <vscale x 16 x i8> %add, splat (i8 6)
+  ret <vscale x 16 x i8> %sh
+}
+
+define <vscale x 8 x i16> @srshr_i16(<vscale x 8 x i16> %x) {
+; SVE-LABEL: srshr_i16:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.h, z0.h, #32 // =0x20
+; SVE-NEXT:    asr z0.h, z0.h, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: srshr_i16:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.h
+; SVE2-NEXT:    srshr z0.h, p0/m, z0.h, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 8 x i16> %x, splat (i16 32)
+  %sh = ashr <vscale x 8 x i16> %add, splat (i16 6)
+  ret <vscale x 8 x i16> %sh
+}
+
+define <vscale x 4 x i32> @srshr_i32(<vscale x 4 x i32> %x) {
+; SVE-LABEL: srshr_i32:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.s, z0.s, #32 // =0x20
+; SVE-NEXT:    asr z0.s, z0.s, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: srshr_i32:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.s
+; SVE2-NEXT:    srshr z0.s, p0/m, z0.s, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 4 x i32> %x, splat (i32 32)
+  %sh = ashr <vscale x 4 x i32> %add, splat (i32 6)
+  ret <vscale x 4 x i32> %sh
+}
+
+define <vscale x 2 x i64> @srshr_i64(<vscale x 2 x i64> %x) {
+; SVE-LABEL: srshr_i64:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.d, z0.d, #32 // =0x20
+; SVE-NEXT:    asr z0.d, z0.d, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: srshr_i64:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.d
+; SVE2-NEXT:    srshr z0.d, p0/m, z0.d, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 2 x i64> %x, splat (i64 32)
+  %sh = ashr <vscale x 2 x i64> %add, splat (i64 6)
+  ret <vscale x 2 x i64> %sh
+}
+
+declare void @use(<vscale x 2 x i64>)

From a1c5d95a4b824fdcbbfd3e4e31b40f38b27cbab5 Mon Sep 17 00:00:00 2001
From: "Nadeem, Usman" <mnadeem at quicinc.com>
Date: Fri, 26 Jan 2024 19:06:14 -0800
Subject: [PATCH 2/4] Handle wide operations, account for overflows, remove
 signed shift

Change-Id: I7450629fa43bb3ac1bc40daaa760255eed483c10
---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  79 +++++--
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |   1 -
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |   5 +-
 .../AArch64/sve2-intrinsics-combine-rshrnb.ll |  36 +--
 llvm/test/CodeGen/AArch64/sve2-rsh.ll         | 209 +++++++++++++-----
 5 files changed, 230 insertions(+), 100 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 895319f3d7115..b06828f58ddfc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2689,7 +2689,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::RSHRNB_I)
     MAKE_CASE(AArch64ISD::CTTZ_ELTS)
     MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
-    MAKE_CASE(AArch64ISD::SRSHR_I_PRED)
     MAKE_CASE(AArch64ISD::URSHR_I_PRED)
   }
 #undef MAKE_CASE
@@ -13843,7 +13842,7 @@ SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
 
 static SDValue tryLowerToRoundingShiftRightByImm(SDValue Shift,
                                                  SelectionDAG &DAG) {
-  if (Shift->getOpcode() != ISD::SRL && Shift->getOpcode() != ISD::SRA)
+  if (Shift->getOpcode() != ISD::SRL)
     return SDValue();
 
   EVT ResVT = Shift.getValueType();
@@ -13861,6 +13860,10 @@ static SDValue tryLowerToRoundingShiftRightByImm(SDValue Shift,
   SDValue Add = Shift->getOperand(0);
   if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
     return SDValue();
+
+  if (!Add->getFlags().hasNoUnsignedWrap())
+    return SDValue();
+
   auto AddOp1 =
       dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
   if (!AddOp1)
@@ -13870,10 +13873,8 @@ static SDValue tryLowerToRoundingShiftRightByImm(SDValue Shift,
     return SDValue();
 
   SDLoc DL(Shift);
-  unsigned Opc = Shift->getOpcode() == ISD::SRA ? AArch64ISD::SRSHR_I_PRED
-                                                : AArch64ISD::URSHR_I_PRED;
-  return DAG.getNode(Opc, DL, ResVT, getPredicateForVector(DAG, DL, ResVT),
-                     Add->getOperand(0),
+  return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, ResVT,
+                     getPredicateForVector(DAG, DL, ResVT), Add->getOperand(0),
                      DAG.getTargetConstant(ShiftValue, DL, MVT::i32));
 }
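
The new no-unsigned-wrap requirement is what keeps the fold sound; a small
standalone sketch (illustration only, the numbers are mine) of how a wrapping
add changes the result:

#include <cassert>
#include <cstdint>

int main() {
  // 8-bit example with shift C = 6, rounding constant 32.
  uint8_t X = 250;

  // Rounding shift as the target instruction computes it: no intermediate wrap.
  unsigned Rounded = (unsigned(X) + 32u) >> 6; // (250 + 32) >> 6 == 4

  // What the original IR computes if the i8 add wraps: 250 + 32 == 26 (mod 256).
  unsigned Wrapped = uint8_t(X + 32) >> 6;     // 26 >> 6 == 0

  // Without nuw the two disagree, so the combine must be skipped.
  assert(Rounded == 4 && Wrapped == 0);
  return 0;
}

This is the case the new neg_urshr_5 test below is checking.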
 
@@ -17730,9 +17731,6 @@ static SDValue performReinterpretCastCombine(SDNode *N) {
 
 static SDValue performSVEAndCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
-  if (DCI.isBeforeLegalizeOps())
-    return SDValue();
-
   SelectionDAG &DAG = DCI.DAG;
   SDValue Src = N->getOperand(0);
   unsigned Opc = Src->getOpcode();
@@ -17788,6 +17786,9 @@ static SDValue performSVEAndCombine(SDNode *N,
     return DAG.getNode(Opc, DL, N->getValueType(0), And);
   }
 
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
   // If both sides of AND operations are i1 splat_vectors then
   // we can produce just i1 splat_vector as the result.
   if (isAllActivePredicate(DAG, N->getOperand(0)))
@@ -20235,9 +20236,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_sve_uqsub_x:
     return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2));
-  case Intrinsic::aarch64_sve_srshr:
-    return DAG.getNode(AArch64ISD::SRSHR_I_PRED, SDLoc(N), N->getValueType(0),
-                       N->getOperand(1), N->getOperand(2), N->getOperand(3));
   case Intrinsic::aarch64_sve_urshr:
     return DAG.getNode(AArch64ISD::URSHR_I_PRED, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
@@ -20857,11 +20855,56 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// Try to combine rounding shifts where the operands come from an extend, and
+// the result is truncated and combined into one vector.
+//   uzp1(urshr(uunpklo(X),C), urshr(uunpkhi(X), C)) -> urshr(X, C)
+static SDValue tryCombineExtendRShTrunc(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == AArch64ISD::UZP1 && "Only UZP1 expected.");
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  EVT VT = Op0->getValueType(0);
+  EVT ResVT = N->getValueType(0);
+
+  // Truncating combine?
+  if (ResVT.widenIntegerVectorElementType(*DAG.getContext()) !=
+      VT.getDoubleNumVectorElementsVT(*DAG.getContext()))
+    return SDValue();
+
+  unsigned RshOpc = Op0.getOpcode();
+  if (RshOpc != AArch64ISD::URSHR_I_PRED)
+    return SDValue();
+  if (!isAllActivePredicate(DAG, Op0.getOperand(0)) ||
+      !isAllActivePredicate(DAG, Op1.getOperand(0)))
+    return SDValue();
+
+  // Same op code and imm value?
+  SDValue ShiftValue = Op0.getOperand(2);
+  if (RshOpc != Op1.getOpcode() || ShiftValue != Op1.getOperand(2))
+    return SDValue();
+  // We cannot reduce the type if the shift value is too large for it.
+  if (ShiftValue->getAsZExtVal() > ResVT.getScalarSizeInBits())
+    return SDValue();
+
+  // Same unextended operand value?
+  SDValue Lo = Op0.getOperand(1);
+  SDValue Hi = Op1.getOperand(1);
+  if (Lo.getOpcode() != AArch64ISD::UUNPKLO &&
+      Hi.getOpcode() != AArch64ISD::UUNPKHI)
+    return SDValue();
+  SDValue OrigArg = Lo.getOperand(0);
+  if (OrigArg != Op1.getOperand(1).getOperand(0))
+    return SDValue();
+
+  SDLoc DL(N);
+  return DAG.getNode(RshOpc, DL, ResVT, getPredicateForVector(DAG, DL, ResVT),
+                     OrigArg, Op0.getOperand(2));
+}
+
 // Try to simplify:
-//    t1 = nxv8i16 add(X, 1 << (ShiftValue - 1))
-//    t2 = nxv8i16 srl(t1, ShiftValue)
+//    t1 = nxv8i16 urshr(X, shiftvalue)
 // to
-//    t1 = nxv8i16 rshrnb(X, shiftvalue).
+//    t1 = nxv16i8 rshrnb(X, shiftvalue).
+//    t2 = nxv8i16 bitcast t1
 // rshrnb will zero the top half bits of each element. Therefore, this combine
 // should only be performed when a following instruction with the rshrnb
 // as an operand does not care about the top half of each element. For example,
@@ -20885,8 +20928,7 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
   else
     return SDValue();
 
-  unsigned ShiftValue =
-      cast<ConstantSDNode>(Srl->getOperand(2))->getZExtValue();
+  unsigned ShiftValue = Srl->getOperand(2)->getAsZExtVal();
   if (ShiftValue > ResVT.getScalarSizeInBits())
     return SDValue();
 
@@ -20930,6 +20972,9 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  if (SDValue RSh = tryCombineExtendRShTrunc(N, DAG))
+    return RSh;
+
   if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
     return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
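
The identity behind tryCombineExtendRShTrunc above, stated in scalar terms (my
own summary): a rounding shift right of a zero-extended value, truncated back
to the original width, matches the rounding shift right done directly at the
original width whenever the shift amount fits in the narrow element. A
brute-force check of that assumption:

#include <cassert>
#include <cstdint>

// Rounding shift right computed without intermediate wrap.
static uint32_t urshr(uint32_t X, unsigned C) {
  return (X + (1u << (C - 1))) >> C;
}

int main() {
  // For every i8 input and every shift 1 <= C <= 8, the rounded result still
  // fits in 8 bits, so the uunpk + urshr + uzp1 chain packs the same lanes as
  // a single urshr at the narrower element size.
  for (unsigned C = 1; C <= 8; ++C)
    for (uint32_t X = 0; X <= 0xFF; ++X)
      assert(urshr(X, C) <= 0xFF); // truncating back to i8 loses nothing
  return 0;
}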
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index b292783886b78..d4968691809e5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -215,7 +215,6 @@ enum NodeType : unsigned {
   UQSHL_I,
   SQSHLU_I,
   SRSHR_I,
-  SRSHR_I_PRED,
   URSHR_I,
   URSHR_I_PRED,
 
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 516ab36464379..e83d8e5bde79e 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -233,7 +233,6 @@ def SDT_AArch64Arith_Imm : SDTypeProfile<1, 3, [
 
 def AArch64asrd_m1 : SDNode<"AArch64ISD::SRAD_MERGE_OP1", SDT_AArch64Arith_Imm>;
 def AArch64urshri_p : SDNode<"AArch64ISD::URSHR_I_PRED", SDT_AArch64Arith_Imm>;
-def AArch64srshri_p : SDNode<"AArch64ISD::SRSHR_I_PRED", SDT_AArch64Arith_Imm>;
 
 def SDT_AArch64IntExtend : SDTypeProfile<1, 4, [
   SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVT<3, OtherVT>, SDTCisVec<4>,
@@ -3540,7 +3539,7 @@ let Predicates = [HasSVE2orSME] in {
   // SVE2 predicated shifts
   defm SQSHL_ZPmI  : sve_int_bin_pred_shift_imm_left_dup<0b0110, "sqshl",  "SQSHL_ZPZI",  int_aarch64_sve_sqshl>;
   defm UQSHL_ZPmI  : sve_int_bin_pred_shift_imm_left_dup<0b0111, "uqshl",  "UQSHL_ZPZI",  int_aarch64_sve_uqshl>;
-  defm SRSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<   0b1100, "srshr",  "SRSHR_ZPZI",  AArch64srshri_p>;
+  defm SRSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<   0b1100, "srshr",  "SRSHR_ZPZI",  int_aarch64_sve_srshr>;
   defm URSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<   0b1101, "urshr",  "URSHR_ZPZI",  AArch64urshri_p>;
   defm SQSHLU_ZPmI : sve_int_bin_pred_shift_imm_left<    0b1111, "sqshlu", "SQSHLU_ZPZI", int_aarch64_sve_sqshlu>;
 
@@ -3585,7 +3584,7 @@ let Predicates = [HasSVE2orSME] in {
   // SVE2 bitwise shift right and accumulate
   defm SSRA_ZZI  : sve2_int_bin_accum_shift_imm_right<0b00, "ssra",  AArch64ssra>;
   defm USRA_ZZI  : sve2_int_bin_accum_shift_imm_right<0b01, "usra",  AArch64usra>;
-  defm SRSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b10, "srsra", int_aarch64_sve_srsra, AArch64srshri_p>;
+  defm SRSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b10, "srsra", int_aarch64_sve_srsra, int_aarch64_sve_srshr>;
   defm URSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b11, "ursra", int_aarch64_sve_ursra, AArch64urshri_p>;
 
   // SVE2 complex integer add
diff --git a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
index 58ef846a31723..2795b6bd3fd45 100644
--- a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
@@ -10,7 +10,7 @@ define void @add_lshr_rshrnb_b_6(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 8 x i16>, ptr %ptr, align 2
-  %1 = add <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %1 = add nuw <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
   %4 = getelementptr inbounds i8, ptr %dst, i64 %index
@@ -28,7 +28,7 @@ define void @neg_add_lshr_rshrnb_b_6(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 8 x i16>, ptr %ptr, align 2
-  %1 = add <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 1, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %1 = add nuw <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 1, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
   %4 = getelementptr inbounds i8, ptr %dst, i64 %index
@@ -45,7 +45,7 @@ define void @add_lshr_rshrnb_h_7(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 8 x i16>, ptr %ptr, align 2
-  %1 = add <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 64, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %1 = add nuw <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 64, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 7, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
   %4 = getelementptr inbounds i8, ptr %dst, i64 %index
@@ -62,7 +62,7 @@ define void @add_lshr_rshrn_h_6(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1h { z0.s }, p0, [x1, x2, lsl #1]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 4 x i32>, ptr %ptr, align 2
-  %1 = add <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 32, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %1 = add nuw <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 32, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %2 = lshr <vscale x 4 x i32> %1, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 6, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %3 = trunc <vscale x 4 x i32> %2 to <vscale x 4 x i16>
   %4 = getelementptr inbounds i16, ptr %dst, i64 %index
@@ -79,7 +79,7 @@ define void @add_lshr_rshrnb_h_2(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1h { z0.s }, p0, [x1, x2, lsl #1]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 4 x i32>, ptr %ptr, align 2
-  %1 = add <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 2, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %1 = add nuw <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 2, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %2 = lshr <vscale x 4 x i32> %1, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 2, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %3 = trunc <vscale x 4 x i32> %2 to <vscale x 4 x i16>
   %4 = getelementptr inbounds i16, ptr %dst, i64 %index
@@ -92,7 +92,7 @@ define void @neg_add_lshr_rshrnb_h_0(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ret
   %load = load <vscale x 4 x i32>, ptr %ptr, align 2
-  %1 = add <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 1, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %1 = add nuw <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 1, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %2 = lshr <vscale x 4 x i32> %1, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 -1, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %3 = trunc <vscale x 4 x i32> %2 to <vscale x 4 x i16>
   %4 = getelementptr inbounds i16, ptr %dst, i64 %index
@@ -109,7 +109,7 @@ define void @neg_zero_shift(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1h { z0.s }, p0, [x1, x2, lsl #1]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 4 x i32>, ptr %ptr, align 2
-  %1 = add <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 1, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %1 = add nuw <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 1, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %2 = lshr <vscale x 4 x i32> %1, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 0, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %3 = trunc <vscale x 4 x i32> %2 to <vscale x 4 x i16>
   %4 = getelementptr inbounds i16, ptr %dst, i64 %index
@@ -128,7 +128,7 @@ define void @wide_add_shift_add_rshrnb_b(ptr %dest, i64 %index, <vscale x 16 x i
 ; CHECK-NEXT:    add z0.b, z1.b, z0.b
 ; CHECK-NEXT:    st1b { z0.b }, p0, [x0, x1]
 ; CHECK-NEXT:    ret
-  %1 = add <vscale x 16 x i16> %arg1, shufflevector (<vscale x 16 x i16> insertelement (<vscale x 16 x i16> poison, i16 32, i64 0), <vscale x 16 x i16> poison, <vscale x 16 x i32> zeroinitializer)
+  %1 = add nuw <vscale x 16 x i16> %arg1, shufflevector (<vscale x 16 x i16> insertelement (<vscale x 16 x i16> poison, i16 32, i64 0), <vscale x 16 x i16> poison, <vscale x 16 x i32> zeroinitializer)
   %2 = lshr <vscale x 16 x i16> %1, shufflevector (<vscale x 16 x i16> insertelement (<vscale x 16 x i16> poison, i16 6, i64 0), <vscale x 16 x i16> poison, <vscale x 16 x i32> zeroinitializer)
   %3 = getelementptr inbounds i8, ptr %dest, i64 %index
   %load = load <vscale x 16 x i8>, ptr %3, align 2
@@ -149,7 +149,7 @@ define void @wide_add_shift_add_rshrnb_h(ptr %dest, i64 %index, <vscale x 8 x i3
 ; CHECK-NEXT:    add z0.h, z1.h, z0.h
 ; CHECK-NEXT:    st1h { z0.h }, p0, [x0, x1, lsl #1]
 ; CHECK-NEXT:    ret
-  %1 = add <vscale x 8 x i32> %arg1, shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer)
+  %1 = add nuw <vscale x 8 x i32> %arg1, shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer)
   %2 = lshr <vscale x 8 x i32> %1, shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer)
   %3 = getelementptr inbounds i16, ptr %dest, i64 %index
   %load = load <vscale x 8 x i16>, ptr %3, align 2
@@ -170,7 +170,7 @@ define void @wide_add_shift_add_rshrnb_d(ptr %dest, i64 %index, <vscale x 4 x i6
 ; CHECK-NEXT:    add z0.s, z1.s, z0.s
 ; CHECK-NEXT:    st1w { z0.s }, p0, [x0, x1, lsl #2]
 ; CHECK-NEXT:    ret
-  %1 = add <vscale x 4 x i64> %arg1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 2147483648, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
+  %1 = add nuw <vscale x 4 x i64> %arg1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 2147483648, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
   %2 = lshr <vscale x 4 x i64> %1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 32, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
   %3 = getelementptr inbounds i32, ptr %dest, i64 %index
   %load = load <vscale x 4 x i32>, ptr %3, align 4
@@ -193,7 +193,7 @@ define void @neg_wide_add_shift_add_rshrnb_d(ptr %dest, i64 %index, <vscale x 4
 ; CHECK-NEXT:    add z0.s, z1.s, z0.s
 ; CHECK-NEXT:    st1w { z0.s }, p1, [x0, x1, lsl #2]
 ; CHECK-NEXT:    ret
-  %1 = add <vscale x 4 x i64> %arg1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 140737488355328, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
+  %1 = add nuw <vscale x 4 x i64> %arg1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 140737488355328, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
   %2 = lshr <vscale x 4 x i64> %1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 48, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
   %3 = getelementptr inbounds i32, ptr %dest, i64 %index
   %load = load <vscale x 4 x i32>, ptr %3, align 4
@@ -213,7 +213,7 @@ define void @neg_trunc_lsr_add_op1_not_splat(ptr %ptr, ptr %dst, i64 %index, <vs
 ; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 8 x i16>, ptr %ptr, align 2
-  %1 = add <vscale x 8 x i16> %load, %add_op1
+  %1 = add nuw <vscale x 8 x i16> %load, %add_op1
   %2 = lshr <vscale x 8 x i16> %1, shufflevector (<vscale x 8 x i16> insertelement (<vscale x 8 x i16> poison, i16 6, i64 0), <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer)
   %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
   %4 = getelementptr inbounds i8, ptr %dst, i64 %index
@@ -231,7 +231,7 @@ define void @neg_trunc_lsr_op1_not_splat(ptr %ptr, ptr %dst, i64 %index, <vscale
 ; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 8 x i16>, ptr %ptr, align 2
-  %1 = add <vscale x 8 x i16> %load, shufflevector (<vscale x 8 x i16> insertelement (<vscale x 8 x i16> poison, i16 32, i64 0), <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer)
+  %1 = add nuw <vscale x 8 x i16> %load, shufflevector (<vscale x 8 x i16> insertelement (<vscale x 8 x i16> poison, i16 32, i64 0), <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer)
   %2 = lshr <vscale x 8 x i16> %1, %lshr_op1
   %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
   %4 = getelementptr inbounds i8, ptr %dst, i64 %index
@@ -251,9 +251,9 @@ define void @neg_add_has_two_uses(ptr %ptr, ptr %dst, ptr %dst2, i64 %index){
 ; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x3]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 8 x i16>, ptr %ptr, align 2
-  %1 = add <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %1 = add nuw <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
-  %3 = add <vscale x 8 x i16> %1, %1
+  %3 = add nuw <vscale x 8 x i16> %1, %1
   %4 = getelementptr inbounds i16, ptr %dst2, i64 %index
   %5 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
   %6 = getelementptr inbounds i8, ptr %dst, i64 %index
@@ -271,7 +271,7 @@ define void @add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1w { z0.d }, p0, [x1, x2, lsl #2]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 2 x i64>, ptr %ptr, align 2
-  %1 = add <vscale x 2 x i64> %load, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 32, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
+  %1 = add nuw <vscale x 2 x i64> %load, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 32, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
   %2 = lshr <vscale x 2 x i64> %1, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 6, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
   %3 = trunc <vscale x 2 x i64> %2 to <vscale x 2 x i32>
   %4 = getelementptr inbounds i32, ptr %dst, i64 %index
@@ -288,7 +288,7 @@ define void @neg_add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1h { z0.d }, p0, [x1, x2, lsl #1]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 2 x i64>, ptr %ptr, align 2
-  %1 = add <vscale x 2 x i64> %load, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 32, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
+  %1 = add nuw <vscale x 2 x i64> %load, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 32, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
   %2 = lshr <vscale x 2 x i64> %1, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 6, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
   %3 = trunc <vscale x 2 x i64> %2 to <vscale x 2 x i16>
   %4 = getelementptr inbounds i16, ptr %dst, i64 %index
@@ -304,7 +304,7 @@ define void @masked_store_rshrnb(ptr %ptr, ptr %dst, i64 %index, <vscale x 8 x i
 ; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
 ; CHECK-NEXT:    ret
   %wide.masked.load = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %ptr, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> poison)
-  %1 = add <vscale x 8 x i16> %wide.masked.load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %1 = add nuw <vscale x 8 x i16> %wide.masked.load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
   %4 = getelementptr inbounds i8, ptr %dst, i64 %index
diff --git a/llvm/test/CodeGen/AArch64/sve2-rsh.ll b/llvm/test/CodeGen/AArch64/sve2-rsh.ll
index 2bdfc1931cdc2..0e3bfb90a45f4 100644
--- a/llvm/test/CodeGen/AArch64/sve2-rsh.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-rsh.ll
@@ -9,7 +9,7 @@ define <vscale x 2 x i64> @neg_urshr_1(<vscale x 2 x i64> %x) {
 ; CHECK-NEXT:    add z0.d, z0.d, #16 // =0x10
 ; CHECK-NEXT:    lsr z0.d, z0.d, #6
 ; CHECK-NEXT:    ret
-  %add = add <vscale x 2 x i64> %x, splat (i64 16)
+  %add = add nuw nsw <vscale x 2 x i64> %x, splat (i64 16)
   %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
   ret <vscale x 2 x i64> %sh
 }
@@ -22,7 +22,7 @@ define <vscale x 2 x i64> @neg_urshr_2(<vscale x 2 x i64> %x, <vscale x 2 x i64>
 ; CHECK-NEXT:    add z0.d, z0.d, #32 // =0x20
 ; CHECK-NEXT:    lsr z0.d, p0/m, z0.d, z1.d
 ; CHECK-NEXT:    ret
-  %add = add <vscale x 2 x i64> %x, splat (i64 32)
+  %add = add nuw nsw <vscale x 2 x i64> %x, splat (i64 32)
   %sh = lshr <vscale x 2 x i64> %add, %y
   ret <vscale x 2 x i64> %sh
 }
@@ -34,7 +34,7 @@ define <vscale x 2 x i64> @neg_urshr_3(<vscale x 2 x i64> %x, <vscale x 2 x i64>
 ; CHECK-NEXT:    add z0.d, z0.d, z1.d
 ; CHECK-NEXT:    lsr z0.d, z0.d, #6
 ; CHECK-NEXT:    ret
-  %add = add <vscale x 2 x i64> %x, %y
+  %add = add nuw nsw <vscale x 2 x i64> %x, %y
   %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
   ret <vscale x 2 x i64> %sh
 }
@@ -58,12 +58,24 @@ define <vscale x 2 x i64> @neg_urshr_4(<vscale x 2 x i64> %x) {
 ; CHECK-NEXT:    addvl sp, sp, #1
 ; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
-  %add = add <vscale x 2 x i64> %x, splat (i64 32)
+  %add = add nuw nsw <vscale x 2 x i64> %x, splat (i64 32)
   %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
   call void @use(<vscale x 2 x i64> %add)
   ret <vscale x 2 x i64> %sh
 }
 
+; Add can overflow.
+define <vscale x 2 x i64> @neg_urshr_5(<vscale x 2 x i64> %x) {
+; CHECK-LABEL: neg_urshr_5:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add z0.d, z0.d, #32 // =0x20
+; CHECK-NEXT:    lsr z0.d, z0.d, #6
+; CHECK-NEXT:    ret
+  %add = add <vscale x 2 x i64> %x, splat (i64 32)
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
+  ret <vscale x 2 x i64> %sh
+}
+
 define <vscale x 16 x i8> @urshr_i8(<vscale x 16 x i8> %x) {
 ; SVE-LABEL: urshr_i8:
 ; SVE:       // %bb.0:
@@ -76,11 +88,58 @@ define <vscale x 16 x i8> @urshr_i8(<vscale x 16 x i8> %x) {
 ; SVE2-NEXT:    ptrue p0.b
 ; SVE2-NEXT:    urshr z0.b, p0/m, z0.b, #6
 ; SVE2-NEXT:    ret
-  %add = add <vscale x 16 x i8> %x, splat (i8 32)
+  %add = add nuw nsw <vscale x 16 x i8> %x, splat (i8 32)
   %sh = lshr <vscale x 16 x i8> %add, splat (i8 6)
   ret <vscale x 16 x i8> %sh
 }
 
+define <vscale x 16 x i8> @urshr_8_wide_trunc(<vscale x 16 x i8> %x) {
+; SVE-LABEL: urshr_8_wide_trunc:
+; SVE:       // %bb.0:
+; SVE-NEXT:    uunpkhi z1.h, z0.b
+; SVE-NEXT:    uunpklo z0.h, z0.b
+; SVE-NEXT:    add z0.h, z0.h, #32 // =0x20
+; SVE-NEXT:    add z1.h, z1.h, #32 // =0x20
+; SVE-NEXT:    lsr z1.h, z1.h, #6
+; SVE-NEXT:    lsr z0.h, z0.h, #6
+; SVE-NEXT:    uzp1 z0.b, z0.b, z1.b
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_8_wide_trunc:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.b
+; SVE2-NEXT:    urshr z0.b, p0/m, z0.b, #6
+; SVE2-NEXT:    ret
+  %ext = zext <vscale x 16 x i8> %x to <vscale x 16 x i16>
+  %add = add nuw nsw <vscale x 16 x i16> %ext, splat (i16 32)
+  %sh = lshr <vscale x 16 x i16> %add, splat (i16 6)
+  %sht = trunc <vscale x 16 x i16> %sh to <vscale x 16 x i8>
+  ret <vscale x 16 x i8> %sht
+}
+
+define <vscale x 16 x i8> @urshr_8_wide_trunc_nomerge(<vscale x 16 x i16> %ext) {
+; SVE-LABEL: urshr_8_wide_trunc_nomerge:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.h, z0.h, #256 // =0x100
+; SVE-NEXT:    add z1.h, z1.h, #256 // =0x100
+; SVE-NEXT:    lsr z1.h, z1.h, #9
+; SVE-NEXT:    lsr z0.h, z0.h, #9
+; SVE-NEXT:    uzp1 z0.b, z0.b, z1.b
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_8_wide_trunc_nomerge:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.h
+; SVE2-NEXT:    urshr z1.h, p0/m, z1.h, #9
+; SVE2-NEXT:    urshr z0.h, p0/m, z0.h, #9
+; SVE2-NEXT:    uzp1 z0.b, z0.b, z1.b
+; SVE2-NEXT:    ret
+  %add = add nuw nsw <vscale x 16 x i16> %ext, splat (i16 256)
+  %sh = lshr <vscale x 16 x i16> %add, splat (i16 9)
+  %sht = trunc <vscale x 16 x i16> %sh to <vscale x 16 x i8>
+  ret <vscale x 16 x i8> %sht
+}
+
 define <vscale x 8 x i16> @urshr_i16(<vscale x 8 x i16> %x) {
 ; SVE-LABEL: urshr_i16:
 ; SVE:       // %bb.0:
@@ -93,110 +152,138 @@ define <vscale x 8 x i16> @urshr_i16(<vscale x 8 x i16> %x) {
 ; SVE2-NEXT:    ptrue p0.h
 ; SVE2-NEXT:    urshr z0.h, p0/m, z0.h, #6
 ; SVE2-NEXT:    ret
-  %add = add <vscale x 8 x i16> %x, splat (i16 32)
+  %add = add nuw nsw <vscale x 8 x i16> %x, splat (i16 32)
   %sh = lshr <vscale x 8 x i16> %add, splat (i16 6)
   ret <vscale x 8 x i16> %sh
 }
 
-define <vscale x 4 x i32> @urshr_i32(<vscale x 4 x i32> %x) {
-; SVE-LABEL: urshr_i32:
+define <vscale x 8 x i16> @urshr_16_wide_trunc(<vscale x 8 x i16> %x) {
+; SVE-LABEL: urshr_16_wide_trunc:
 ; SVE:       // %bb.0:
+; SVE-NEXT:    uunpkhi z1.s, z0.h
+; SVE-NEXT:    uunpklo z0.s, z0.h
 ; SVE-NEXT:    add z0.s, z0.s, #32 // =0x20
+; SVE-NEXT:    add z1.s, z1.s, #32 // =0x20
+; SVE-NEXT:    lsr z1.s, z1.s, #6
 ; SVE-NEXT:    lsr z0.s, z0.s, #6
+; SVE-NEXT:    uzp1 z0.h, z0.h, z1.h
 ; SVE-NEXT:    ret
 ;
-; SVE2-LABEL: urshr_i32:
+; SVE2-LABEL: urshr_16_wide_trunc:
 ; SVE2:       // %bb.0:
-; SVE2-NEXT:    ptrue p0.s
-; SVE2-NEXT:    urshr z0.s, p0/m, z0.s, #6
+; SVE2-NEXT:    ptrue p0.h
+; SVE2-NEXT:    urshr z0.h, p0/m, z0.h, #6
 ; SVE2-NEXT:    ret
-  %add = add <vscale x 4 x i32> %x, splat (i32 32)
-  %sh = lshr <vscale x 4 x i32> %add, splat (i32 6)
-  ret <vscale x 4 x i32> %sh
+  %ext = zext <vscale x 8 x i16> %x to <vscale x 8 x i32>
+  %add = add nuw nsw <vscale x 8 x i32> %ext, splat (i32 32)
+  %sh = lshr <vscale x 8 x i32> %add, splat (i32 6)
+  %sht = trunc <vscale x 8 x i32> %sh to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %sht
 }
 
-define <vscale x 2 x i64> @urshr_i64(<vscale x 2 x i64> %x) {
-; SVE-LABEL: urshr_i64:
+define <vscale x 8 x i16> @urshr_16_wide_trunc_nomerge(<vscale x 8 x i32> %ext) {
+; SVE-LABEL: urshr_16_wide_trunc_nomerge:
 ; SVE:       // %bb.0:
-; SVE-NEXT:    add z0.d, z0.d, #32 // =0x20
-; SVE-NEXT:    lsr z0.d, z0.d, #6
+; SVE-NEXT:    mov z2.s, #0x10000
+; SVE-NEXT:    add z0.s, z0.s, z2.s
+; SVE-NEXT:    add z1.s, z1.s, z2.s
+; SVE-NEXT:    lsr z1.s, z1.s, #17
+; SVE-NEXT:    lsr z0.s, z0.s, #17
+; SVE-NEXT:    uzp1 z0.h, z0.h, z1.h
 ; SVE-NEXT:    ret
 ;
-; SVE2-LABEL: urshr_i64:
+; SVE2-LABEL: urshr_16_wide_trunc_nomerge:
 ; SVE2:       // %bb.0:
-; SVE2-NEXT:    ptrue p0.d
-; SVE2-NEXT:    urshr z0.d, p0/m, z0.d, #6
+; SVE2-NEXT:    ptrue p0.s
+; SVE2-NEXT:    urshr z1.s, p0/m, z1.s, #17
+; SVE2-NEXT:    urshr z0.s, p0/m, z0.s, #17
+; SVE2-NEXT:    uzp1 z0.h, z0.h, z1.h
 ; SVE2-NEXT:    ret
-  %add = add <vscale x 2 x i64> %x, splat (i64 32)
-  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
-  ret <vscale x 2 x i64> %sh
+  %add = add nuw nsw <vscale x 8 x i32> %ext, splat (i32 65536)
+  %sh = lshr <vscale x 8 x i32> %add, splat (i32 17)
+  %sht = trunc <vscale x 8 x i32> %sh to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %sht
 }
 
-define <vscale x 16 x i8> @srshr_i8(<vscale x 16 x i8> %x) {
-; SVE-LABEL: srshr_i8:
+define <vscale x 4 x i32> @urshr_i32(<vscale x 4 x i32> %x) {
+; SVE-LABEL: urshr_i32:
 ; SVE:       // %bb.0:
-; SVE-NEXT:    add z0.b, z0.b, #32 // =0x20
-; SVE-NEXT:    asr z0.b, z0.b, #6
+; SVE-NEXT:    add z0.s, z0.s, #32 // =0x20
+; SVE-NEXT:    lsr z0.s, z0.s, #6
 ; SVE-NEXT:    ret
 ;
-; SVE2-LABEL: srshr_i8:
+; SVE2-LABEL: urshr_i32:
 ; SVE2:       // %bb.0:
-; SVE2-NEXT:    ptrue p0.b
-; SVE2-NEXT:    srshr z0.b, p0/m, z0.b, #6
+; SVE2-NEXT:    ptrue p0.s
+; SVE2-NEXT:    urshr z0.s, p0/m, z0.s, #6
 ; SVE2-NEXT:    ret
-  %add = add <vscale x 16 x i8> %x, splat (i8 32)
-  %sh = ashr <vscale x 16 x i8> %add, splat (i8 6)
-  ret <vscale x 16 x i8> %sh
+  %add = add nuw nsw <vscale x 4 x i32> %x, splat (i32 32)
+  %sh = lshr <vscale x 4 x i32> %add, splat (i32 6)
+  ret <vscale x 4 x i32> %sh
 }
 
-define <vscale x 8 x i16> @srshr_i16(<vscale x 8 x i16> %x) {
-; SVE-LABEL: srshr_i16:
+define <vscale x 4 x i32> @urshr_32_wide_trunc(<vscale x 4 x i32> %x) {
+; SVE-LABEL: urshr_32_wide_trunc:
 ; SVE:       // %bb.0:
-; SVE-NEXT:    add z0.h, z0.h, #32 // =0x20
-; SVE-NEXT:    asr z0.h, z0.h, #6
+; SVE-NEXT:    uunpkhi z1.d, z0.s
+; SVE-NEXT:    uunpklo z0.d, z0.s
+; SVE-NEXT:    add z0.d, z0.d, #32 // =0x20
+; SVE-NEXT:    add z1.d, z1.d, #32 // =0x20
+; SVE-NEXT:    lsr z1.d, z1.d, #6
+; SVE-NEXT:    lsr z0.d, z0.d, #6
+; SVE-NEXT:    uzp1 z0.s, z0.s, z1.s
 ; SVE-NEXT:    ret
 ;
-; SVE2-LABEL: srshr_i16:
+; SVE2-LABEL: urshr_32_wide_trunc:
 ; SVE2:       // %bb.0:
-; SVE2-NEXT:    ptrue p0.h
-; SVE2-NEXT:    srshr z0.h, p0/m, z0.h, #6
+; SVE2-NEXT:    ptrue p0.s
+; SVE2-NEXT:    urshr z0.s, p0/m, z0.s, #6
 ; SVE2-NEXT:    ret
-  %add = add <vscale x 8 x i16> %x, splat (i16 32)
-  %sh = ashr <vscale x 8 x i16> %add, splat (i16 6)
-  ret <vscale x 8 x i16> %sh
+  %ext = zext <vscale x 4 x i32> %x to <vscale x 4 x i64>
+  %add = add nuw nsw <vscale x 4 x i64> %ext, splat (i64 32)
+  %sh = lshr <vscale x 4 x i64> %add, splat (i64 6)
+  %sht = trunc <vscale x 4 x i64> %sh to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %sht
 }
 
-define <vscale x 4 x i32> @srshr_i32(<vscale x 4 x i32> %x) {
-; SVE-LABEL: srshr_i32:
+define <vscale x 4 x i32> @urshr_32_wide_trunc_nomerge(<vscale x 4 x i64> %ext) {
+; SVE-LABEL: urshr_32_wide_trunc_nomerge:
 ; SVE:       // %bb.0:
-; SVE-NEXT:    add z0.s, z0.s, #32 // =0x20
-; SVE-NEXT:    asr z0.s, z0.s, #6
+; SVE-NEXT:    mov z2.d, #0x100000000
+; SVE-NEXT:    add z0.d, z0.d, z2.d
+; SVE-NEXT:    add z1.d, z1.d, z2.d
+; SVE-NEXT:    lsr z1.d, z1.d, #33
+; SVE-NEXT:    lsr z0.d, z0.d, #33
+; SVE-NEXT:    uzp1 z0.s, z0.s, z1.s
 ; SVE-NEXT:    ret
 ;
-; SVE2-LABEL: srshr_i32:
+; SVE2-LABEL: urshr_32_wide_trunc_nomerge:
 ; SVE2:       // %bb.0:
-; SVE2-NEXT:    ptrue p0.s
-; SVE2-NEXT:    srshr z0.s, p0/m, z0.s, #6
+; SVE2-NEXT:    ptrue p0.d
+; SVE2-NEXT:    urshr z1.d, p0/m, z1.d, #33
+; SVE2-NEXT:    urshr z0.d, p0/m, z0.d, #33
+; SVE2-NEXT:    uzp1 z0.s, z0.s, z1.s
 ; SVE2-NEXT:    ret
-  %add = add <vscale x 4 x i32> %x, splat (i32 32)
-  %sh = ashr <vscale x 4 x i32> %add, splat (i32 6)
-  ret <vscale x 4 x i32> %sh
+  %add = add nuw nsw <vscale x 4 x i64> %ext, splat (i64 4294967296)
+  %sh = lshr <vscale x 4 x i64> %add, splat (i64 33)
+  %sht = trunc <vscale x 4 x i64> %sh to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %sht
 }
 
-define <vscale x 2 x i64> @srshr_i64(<vscale x 2 x i64> %x) {
-; SVE-LABEL: srshr_i64:
+define <vscale x 2 x i64> @urshr_i64(<vscale x 2 x i64> %x) {
+; SVE-LABEL: urshr_i64:
 ; SVE:       // %bb.0:
 ; SVE-NEXT:    add z0.d, z0.d, #32 // =0x20
-; SVE-NEXT:    asr z0.d, z0.d, #6
+; SVE-NEXT:    lsr z0.d, z0.d, #6
 ; SVE-NEXT:    ret
 ;
-; SVE2-LABEL: srshr_i64:
+; SVE2-LABEL: urshr_i64:
 ; SVE2:       // %bb.0:
 ; SVE2-NEXT:    ptrue p0.d
-; SVE2-NEXT:    srshr z0.d, p0/m, z0.d, #6
+; SVE2-NEXT:    urshr z0.d, p0/m, z0.d, #6
 ; SVE2-NEXT:    ret
-  %add = add <vscale x 2 x i64> %x, splat (i64 32)
-  %sh = ashr <vscale x 2 x i64> %add, splat (i64 6)
+  %add = add nuw nsw <vscale x 2 x i64> %x, splat (i64 32)
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
   ret <vscale x 2 x i64> %sh
 }
 

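The tests above come down to one arithmetic fact: URSHR computes (x + (1 << (sh - 1))) >> sh with the rounding add done at full precision, while the plain add+lshr IR pattern wraps at the element width unless the add carries nuw or the wrapped bits are discarded by a later truncate. A minimal standalone sketch of that difference, in plain C++ with illustrative names (not LLVM code):

#include <cassert>
#include <cstdint>

// Reference semantics of a 32-bit unsigned rounding shift right: the rounding
// add is done in 64 bits, so it can never wrap.
uint32_t urshr32_reference(uint32_t x, unsigned sh) {
  assert(sh >= 1 && sh <= 32);
  uint64_t rounded = (uint64_t)x + (1ull << (sh - 1));
  return (uint32_t)(rounded >> sh);
}

// The IR pattern being matched: the add happens at the element width and may
// wrap unless it is marked nuw.
uint32_t add_lshr_pattern(uint32_t x, unsigned sh) {
  uint32_t rounded = x + (1u << (sh - 1));
  return rounded >> sh;
}

int main() {
  // The two agree while the rounding add cannot wrap...
  assert(urshr32_reference(100u, 6) == add_lshr_pattern(100u, 6));
  // ...and diverge once x + 32 overflows 32 bits, which is why the combine
  // needs nuw (or spare high bits ahead of a truncate) to fire.
  assert(urshr32_reference(0xFFFFFFF0u, 6) != add_lshr_pattern(0xFFFFFFF0u, 6));
  return 0;
}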
>From a63ba9b1143558cd61a9ae01b9d4ae02f1028dff Mon Sep 17 00:00:00 2001
From: "Nadeem, Usman" <mnadeem at quicinc.com>
Date: Mon, 29 Jan 2024 20:57:49 -0800
Subject: [PATCH 3/4] fix overflow handling

Change-Id: Id6dceead02c7473ed5c3635c2b56c7f367315563
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 132 ++++++++++--------
 .../AArch64/sve2-intrinsics-combine-rshrnb.ll |  53 +++----
 2 files changed, 104 insertions(+), 81 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b06828f58ddfc..181b73fa478e5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -13840,42 +13840,49 @@ SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
   return SDValue();
 }
 
-static SDValue tryLowerToRoundingShiftRightByImm(SDValue Shift,
-                                                 SelectionDAG &DAG) {
+// Check if we can lower this SRL to a rounding shift instruction. ResVT is
+// possibly a truncated type; it tells how many bits of the value are to be
+// used.
+static bool canLowerSRLToRoundingShiftForVT(SDValue Shift, EVT ResVT,
+                                            SelectionDAG &DAG,
+                                            unsigned &ShiftValue,
+                                            SDValue &RShOperand) {
   if (Shift->getOpcode() != ISD::SRL)
-    return SDValue();
+    return false;
 
-  EVT ResVT = Shift.getValueType();
-  assert(ResVT.isScalableVT());
+  EVT VT = Shift.getValueType();
+  assert(VT.isScalableVT());
 
   auto ShiftOp1 =
       dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Shift->getOperand(1)));
   if (!ShiftOp1)
-    return SDValue();
-  unsigned ShiftValue = ShiftOp1->getZExtValue();
+    return false;
 
+  ShiftValue = ShiftOp1->getZExtValue();
   if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
-    return SDValue();
+    return false;
 
   SDValue Add = Shift->getOperand(0);
   if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
-    return SDValue();
+    return false;
 
-  if (!Add->getFlags().hasNoUnsignedWrap())
-    return SDValue();
+  assert(ResVT.getScalarSizeInBits() <= VT.getScalarSizeInBits() &&
+         "ResVT must be truncated or same type as the shift.");
+  // Check if an overflow can lead to incorrect results.
+  uint64_t ExtraBits = VT.getScalarSizeInBits() - ResVT.getScalarSizeInBits();
+  if (ShiftValue > ExtraBits && !Add->getFlags().hasNoUnsignedWrap())
+    return false;
 
   auto AddOp1 =
       dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
   if (!AddOp1)
-    return SDValue();
+    return false;
   uint64_t AddValue = AddOp1->getZExtValue();
   if (AddValue != 1ULL << (ShiftValue - 1))
-    return SDValue();
+    return false;
 
-  SDLoc DL(Shift);
-  return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, ResVT,
-                     getPredicateForVector(DAG, DL, ResVT), Add->getOperand(0),
-                     DAG.getTargetConstant(ShiftValue, DL, MVT::i32));
+  RShOperand = Add->getOperand(0);
+  return true;
 }
 
 SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
@@ -13903,9 +13910,14 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
                        Op.getOperand(0), Op.getOperand(1));
   case ISD::SRA:
   case ISD::SRL:
-    if (VT.isScalableVector() && Subtarget->hasSVE2orSME())
-      if (SDValue RSH = tryLowerToRoundingShiftRightByImm(Op, DAG))
-        return RSH;
+    if (VT.isScalableVector() && Subtarget->hasSVE2orSME()) {
+      SDValue RShOperand;
+      unsigned ShiftValue;
+      if (canLowerSRLToRoundingShiftForVT(Op, VT, DAG, ShiftValue, RShOperand))
+        return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, VT,
+                           getPredicateForVector(DAG, DL, VT), RShOperand,
+                           DAG.getTargetConstant(ShiftValue, DL, MVT::i32));
+    }
 
     if (VT.isScalableVector() ||
         useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
@@ -20855,9 +20867,20 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+static bool isHalvingTruncateAndConcatOfLegalIntScalableType(SDNode *N) {
+  if (N->getOpcode() != AArch64ISD::UZP1)
+    return false;
+  SDValue Op0 = N->getOperand(0);
+  EVT SrcVT = Op0->getValueType(0);
+  EVT DstVT = N->getValueType(0);
+  return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv16i8) ||
+         (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv8i16) ||
+         (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv4i32);
+}
+
 // Try to combine rounding shifts where the operands come from an extend, and
 // the result is truncated and combined into one vector.
-//   uzp1(urshr(uunpklo(X),C), urshr(uunpkhi(X), C)) -> urshr(X, C)
+//   uzp1(rshrnb(uunpklo(X),C), rshrnb(uunpkhi(X), C)) -> urshr(X, C)
 static SDValue tryCombineExtendRShTrunc(SDNode *N, SelectionDAG &DAG) {
   assert(N->getOpcode() == AArch64ISD::UZP1 && "Only UZP1 expected.");
   SDValue Op0 = N->getOperand(0);
@@ -20865,59 +20888,46 @@ static SDValue tryCombineExtendRShTrunc(SDNode *N, SelectionDAG &DAG) {
   EVT VT = Op0->getValueType(0);
   EVT ResVT = N->getValueType(0);
 
-  // Truncating combine?
-  if (ResVT.widenIntegerVectorElementType(*DAG.getContext()) !=
-      VT.getDoubleNumVectorElementsVT(*DAG.getContext()))
-    return SDValue();
-
   unsigned RshOpc = Op0.getOpcode();
-  if (RshOpc != AArch64ISD::URSHR_I_PRED)
-    return SDValue();
-  if (!isAllActivePredicate(DAG, Op0.getOperand(0)) ||
-      !isAllActivePredicate(DAG, Op1.getOperand(0)))
+  if (RshOpc != AArch64ISD::RSHRNB_I)
     return SDValue();
 
   // Same op code and imm value?
-  SDValue ShiftValue = Op0.getOperand(2);
-  if (RshOpc != Op1.getOpcode() || ShiftValue != Op1.getOperand(2))
-    return SDValue();
-  // We cannot reduce the type if shift value is too large for type.
-  if (ShiftValue->getAsZExtVal() > ResVT.getScalarSizeInBits())
+  SDValue ShiftValue = Op0.getOperand(1);
+  if (RshOpc != Op1.getOpcode() || ShiftValue != Op1.getOperand(1))
     return SDValue();
 
   // Same unextended operand value?
-  SDValue Lo = Op0.getOperand(1);
-  SDValue Hi = Op1.getOperand(1);
+  SDValue Lo = Op0.getOperand(0);
+  SDValue Hi = Op1.getOperand(0);
   if (Lo.getOpcode() != AArch64ISD::UUNPKLO &&
       Hi.getOpcode() != AArch64ISD::UUNPKHI)
     return SDValue();
   SDValue OrigArg = Lo.getOperand(0);
-  if (OrigArg != Op1.getOperand(1).getOperand(0))
+  if (OrigArg != Hi.getOperand(0))
     return SDValue();
 
   SDLoc DL(N);
-  return DAG.getNode(RshOpc, DL, ResVT, getPredicateForVector(DAG, DL, ResVT),
-                     OrigArg, Op0.getOperand(2));
+  return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, ResVT,
+                     getPredicateForVector(DAG, DL, ResVT), OrigArg,
+                     ShiftValue);
 }
 
 // Try to simplify:
-//    t1 = nxv8i16 urshr(X, shiftvalue)
+//    t1 = nxv8i16 add(X, 1 << (ShiftValue - 1))
+//    t2 = nxv8i16 srl(t1, ShiftValue)
 // to
-//    t1 = nxv16i8 rshrnb(X, shiftvalue).
-//    t2 = nxv8i16 = bitcast t1
+//    t1 = nxv8i16 rshrnb(X, shiftvalue).
 // rshrnb will zero the top half bits of each element. Therefore, this combine
 // should only be performed when a following instruction with the rshrnb
 // as an operand does not care about the top half of each element. For example,
 // a uzp1 or a truncating store.
 static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
                                          const AArch64Subtarget *Subtarget) {
-  if (Srl->getOpcode() != AArch64ISD::URSHR_I_PRED)
-    return SDValue();
-
-  if (!isAllActivePredicate(DAG, Srl.getOperand(0)))
+  EVT VT = Srl->getValueType(0);
+  if (!VT.isScalableVector() || !Subtarget->hasSVE2())
     return SDValue();
 
-  EVT VT = Srl->getValueType(0);
   EVT ResVT;
   if (VT == MVT::nxv8i16)
     ResVT = MVT::nxv16i8;
@@ -20928,13 +20938,14 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
   else
     return SDValue();
 
-  unsigned ShiftValue = Srl->getOperand(2)->getAsZExtVal();
-  if (ShiftValue > ResVT.getScalarSizeInBits())
-    return SDValue();
-
   SDLoc DL(Srl);
-  SDValue Rshrnb = DAG.getNode(AArch64ISD::RSHRNB_I, DL, ResVT,
-                               {Srl->getOperand(1), Srl->getOperand(2)});
+  unsigned ShiftValue;
+  SDValue RShOperand;
+  if (!canLowerSRLToRoundingShiftForVT(Srl, ResVT, DAG, ShiftValue, RShOperand))
+    return SDValue();
+  SDValue Rshrnb = DAG.getNode(
+      AArch64ISD::RSHRNB_I, DL, ResVT,
+      {RShOperand, DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
   return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
 }
 
@@ -20972,8 +20983,8 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
     }
   }
 
-  if (SDValue RSh = tryCombineExtendRShTrunc(N, DAG))
-    return RSh;
+  if (SDValue Urshr = tryCombineExtendRShTrunc(N, DAG))
+    return Urshr;
 
   if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
     return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
@@ -20981,6 +20992,15 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
   if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op1, DAG, Subtarget))
     return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
 
+  // uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y)
+  if (isHalvingTruncateAndConcatOfLegalIntScalableType(N) &&
+      Op0.getOpcode() == ISD::BITCAST && Op1.getOpcode() == ISD::BITCAST) {
+    if (Op0.getOperand(0).getValueType() == Op1.getOperand(0).getValueType()) {
+      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0.getOperand(0),
+                         Op1.getOperand(0));
+    }
+  }
+
   // uzp1(unpklo(uzp1(x, y)), z) => uzp1(x, z)
   if (Op0.getOpcode() == AArch64ISD::UUNPKLO) {
     if (Op0.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
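The overflow check added to canLowerSRLToRoundingShiftForVT above can be read as: a carry out of the W-bit add surfaces at bit (W - ShiftValue) of the shifted value, so it is only visible after truncating to R bits when ShiftValue exceeds ExtraBits = W - R. A small sketch of that rule in plain C++ (illustrative only, not code from the patch):

#include <cassert>

// Whether a wrap of the W-bit rounding add can change the low R bits that
// survive the truncate, for a right shift by Sh.
bool wrapCanChangeTruncatedResult(unsigned W, unsigned R, unsigned Sh) {
  unsigned ExtraBits = W - R;  // bits dropped by the truncate
  return Sh > ExtraBits;       // the carry bit lands inside the kept R bits
}

int main() {
  // i64 shift truncated to i32: a shift by 6 leaves the carry far above the
  // kept bits, so the add may wrap freely and no nuw flag is needed.
  assert(!wrapCanChangeTruncatedResult(64, 32, 6));
  // A shift by 33 pushes the carry into the kept bits, so nuw is required.
  assert(wrapCanChangeTruncatedResult(64, 32, 33));
  // No truncate at all (R == W): any wrap is observable.
  assert(wrapCanChangeTruncatedResult(64, 64, 6));
  return 0;
}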
diff --git a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
index 2795b6bd3fd45..0afd11d098a00 100644
--- a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
@@ -10,7 +10,7 @@ define void @add_lshr_rshrnb_b_6(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 8 x i16>, ptr %ptr, align 2
-  %1 = add nuw <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %1 = add <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
   %4 = getelementptr inbounds i8, ptr %dst, i64 %index
@@ -28,7 +28,7 @@ define void @neg_add_lshr_rshrnb_b_6(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 8 x i16>, ptr %ptr, align 2
-  %1 = add nuw <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 1, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %1 = add <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 1, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
   %4 = getelementptr inbounds i8, ptr %dst, i64 %index
@@ -45,7 +45,7 @@ define void @add_lshr_rshrnb_h_7(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 8 x i16>, ptr %ptr, align 2
-  %1 = add nuw <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 64, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %1 = add <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 64, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 7, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
   %4 = getelementptr inbounds i8, ptr %dst, i64 %index
@@ -62,7 +62,7 @@ define void @add_lshr_rshrn_h_6(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1h { z0.s }, p0, [x1, x2, lsl #1]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 4 x i32>, ptr %ptr, align 2
-  %1 = add nuw <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 32, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %1 = add <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 32, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %2 = lshr <vscale x 4 x i32> %1, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 6, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %3 = trunc <vscale x 4 x i32> %2 to <vscale x 4 x i16>
   %4 = getelementptr inbounds i16, ptr %dst, i64 %index
@@ -79,7 +79,7 @@ define void @add_lshr_rshrnb_h_2(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1h { z0.s }, p0, [x1, x2, lsl #1]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 4 x i32>, ptr %ptr, align 2
-  %1 = add nuw <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 2, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %1 = add <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 2, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %2 = lshr <vscale x 4 x i32> %1, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 2, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %3 = trunc <vscale x 4 x i32> %2 to <vscale x 4 x i16>
   %4 = getelementptr inbounds i16, ptr %dst, i64 %index
@@ -92,7 +92,7 @@ define void @neg_add_lshr_rshrnb_h_0(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ret
   %load = load <vscale x 4 x i32>, ptr %ptr, align 2
-  %1 = add nuw <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 1, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %1 = add <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 1, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %2 = lshr <vscale x 4 x i32> %1, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 -1, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %3 = trunc <vscale x 4 x i32> %2 to <vscale x 4 x i16>
   %4 = getelementptr inbounds i16, ptr %dst, i64 %index
@@ -109,7 +109,7 @@ define void @neg_zero_shift(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1h { z0.s }, p0, [x1, x2, lsl #1]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 4 x i32>, ptr %ptr, align 2
-  %1 = add nuw <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 1, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %1 = add <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 1, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %2 = lshr <vscale x 4 x i32> %1, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 0, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
   %3 = trunc <vscale x 4 x i32> %2 to <vscale x 4 x i16>
   %4 = getelementptr inbounds i16, ptr %dst, i64 %index
@@ -128,7 +128,7 @@ define void @wide_add_shift_add_rshrnb_b(ptr %dest, i64 %index, <vscale x 16 x i
 ; CHECK-NEXT:    add z0.b, z1.b, z0.b
 ; CHECK-NEXT:    st1b { z0.b }, p0, [x0, x1]
 ; CHECK-NEXT:    ret
-  %1 = add nuw <vscale x 16 x i16> %arg1, shufflevector (<vscale x 16 x i16> insertelement (<vscale x 16 x i16> poison, i16 32, i64 0), <vscale x 16 x i16> poison, <vscale x 16 x i32> zeroinitializer)
+  %1 = add <vscale x 16 x i16> %arg1, shufflevector (<vscale x 16 x i16> insertelement (<vscale x 16 x i16> poison, i16 32, i64 0), <vscale x 16 x i16> poison, <vscale x 16 x i32> zeroinitializer)
   %2 = lshr <vscale x 16 x i16> %1, shufflevector (<vscale x 16 x i16> insertelement (<vscale x 16 x i16> poison, i16 6, i64 0), <vscale x 16 x i16> poison, <vscale x 16 x i32> zeroinitializer)
   %3 = getelementptr inbounds i8, ptr %dest, i64 %index
   %load = load <vscale x 16 x i8>, ptr %3, align 2
@@ -149,7 +149,7 @@ define void @wide_add_shift_add_rshrnb_h(ptr %dest, i64 %index, <vscale x 8 x i3
 ; CHECK-NEXT:    add z0.h, z1.h, z0.h
 ; CHECK-NEXT:    st1h { z0.h }, p0, [x0, x1, lsl #1]
 ; CHECK-NEXT:    ret
-  %1 = add nuw <vscale x 8 x i32> %arg1, shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer)
+  %1 = add <vscale x 8 x i32> %arg1, shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer)
   %2 = lshr <vscale x 8 x i32> %1, shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer)
   %3 = getelementptr inbounds i16, ptr %dest, i64 %index
   %load = load <vscale x 8 x i16>, ptr %3, align 2
@@ -170,7 +170,7 @@ define void @wide_add_shift_add_rshrnb_d(ptr %dest, i64 %index, <vscale x 4 x i6
 ; CHECK-NEXT:    add z0.s, z1.s, z0.s
 ; CHECK-NEXT:    st1w { z0.s }, p0, [x0, x1, lsl #2]
 ; CHECK-NEXT:    ret
-  %1 = add nuw <vscale x 4 x i64> %arg1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 2147483648, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
+  %1 = add <vscale x 4 x i64> %arg1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 2147483648, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
   %2 = lshr <vscale x 4 x i64> %1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 32, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
   %3 = getelementptr inbounds i32, ptr %dest, i64 %index
   %load = load <vscale x 4 x i32>, ptr %3, align 4
@@ -184,16 +184,18 @@ define void @wide_add_shift_add_rshrnb_d(ptr %dest, i64 %index, <vscale x 4 x i6
 define void @neg_wide_add_shift_add_rshrnb_d(ptr %dest, i64 %index, <vscale x 4 x i64> %arg1){
 ; CHECK-LABEL: neg_wide_add_shift_add_rshrnb_d:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    ptrue p1.s
-; CHECK-NEXT:    urshr z1.d, p0/m, z1.d, #48
-; CHECK-NEXT:    urshr z0.d, p0/m, z0.d, #48
+; CHECK-NEXT:    mov z2.d, #0x800000000000
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    add z1.d, z1.d, z2.d
+; CHECK-NEXT:    lsr z1.d, z1.d, #48
+; CHECK-NEXT:    lsr z0.d, z0.d, #48
 ; CHECK-NEXT:    uzp1 z0.s, z0.s, z1.s
-; CHECK-NEXT:    ld1w { z1.s }, p1/z, [x0, x1, lsl #2]
+; CHECK-NEXT:    ld1w { z1.s }, p0/z, [x0, x1, lsl #2]
 ; CHECK-NEXT:    add z0.s, z1.s, z0.s
-; CHECK-NEXT:    st1w { z0.s }, p1, [x0, x1, lsl #2]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, x1, lsl #2]
 ; CHECK-NEXT:    ret
-  %1 = add nuw <vscale x 4 x i64> %arg1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 140737488355328, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
+  %1 = add <vscale x 4 x i64> %arg1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 140737488355328, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
   %2 = lshr <vscale x 4 x i64> %1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 48, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
   %3 = getelementptr inbounds i32, ptr %dest, i64 %index
   %load = load <vscale x 4 x i32>, ptr %3, align 4
@@ -213,7 +215,7 @@ define void @neg_trunc_lsr_add_op1_not_splat(ptr %ptr, ptr %dst, i64 %index, <vs
 ; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 8 x i16>, ptr %ptr, align 2
-  %1 = add nuw <vscale x 8 x i16> %load, %add_op1
+  %1 = add <vscale x 8 x i16> %load, %add_op1
   %2 = lshr <vscale x 8 x i16> %1, shufflevector (<vscale x 8 x i16> insertelement (<vscale x 8 x i16> poison, i16 6, i64 0), <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer)
   %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
   %4 = getelementptr inbounds i8, ptr %dst, i64 %index
@@ -231,7 +233,7 @@ define void @neg_trunc_lsr_op1_not_splat(ptr %ptr, ptr %dst, i64 %index, <vscale
 ; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 8 x i16>, ptr %ptr, align 2
-  %1 = add nuw <vscale x 8 x i16> %load, shufflevector (<vscale x 8 x i16> insertelement (<vscale x 8 x i16> poison, i16 32, i64 0), <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer)
+  %1 = add <vscale x 8 x i16> %load, shufflevector (<vscale x 8 x i16> insertelement (<vscale x 8 x i16> poison, i16 32, i64 0), <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer)
   %2 = lshr <vscale x 8 x i16> %1, %lshr_op1
   %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
   %4 = getelementptr inbounds i8, ptr %dst, i64 %index
@@ -251,9 +253,9 @@ define void @neg_add_has_two_uses(ptr %ptr, ptr %dst, ptr %dst2, i64 %index){
 ; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x3]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 8 x i16>, ptr %ptr, align 2
-  %1 = add nuw <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %1 = add <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
-  %3 = add nuw <vscale x 8 x i16> %1, %1
+  %3 = add <vscale x 8 x i16> %1, %1
   %4 = getelementptr inbounds i16, ptr %dst2, i64 %index
   %5 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
   %6 = getelementptr inbounds i8, ptr %dst, i64 %index
@@ -271,7 +273,7 @@ define void @add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK-NEXT:    st1w { z0.d }, p0, [x1, x2, lsl #2]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 2 x i64>, ptr %ptr, align 2
-  %1 = add nuw <vscale x 2 x i64> %load, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 32, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
+  %1 = add <vscale x 2 x i64> %load, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 32, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
   %2 = lshr <vscale x 2 x i64> %1, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 6, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
   %3 = trunc <vscale x 2 x i64> %2 to <vscale x 2 x i32>
   %4 = getelementptr inbounds i32, ptr %dst, i64 %index
@@ -284,11 +286,12 @@ define void @neg_add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0]
-; CHECK-NEXT:    urshr z0.d, p0/m, z0.d, #6
+; CHECK-NEXT:    add z0.d, z0.d, #32 // =0x20
+; CHECK-NEXT:    lsr z0.d, z0.d, #6
 ; CHECK-NEXT:    st1h { z0.d }, p0, [x1, x2, lsl #1]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 2 x i64>, ptr %ptr, align 2
-  %1 = add nuw <vscale x 2 x i64> %load, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 32, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
+  %1 = add <vscale x 2 x i64> %load, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 32, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
   %2 = lshr <vscale x 2 x i64> %1, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 6, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
   %3 = trunc <vscale x 2 x i64> %2 to <vscale x 2 x i16>
   %4 = getelementptr inbounds i16, ptr %dst, i64 %index
@@ -304,7 +307,7 @@ define void @masked_store_rshrnb(ptr %ptr, ptr %dst, i64 %index, <vscale x 8 x i
 ; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
 ; CHECK-NEXT:    ret
   %wide.masked.load = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %ptr, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> poison)
-  %1 = add nuw <vscale x 8 x i16> %wide.masked.load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %1 = add <vscale x 8 x i16> %wide.masked.load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
   %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
   %4 = getelementptr inbounds i8, ptr %dst, i64 %index

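The trySimplifySrlAddToRshrnb change above is only sound because the consumer ignores the top half of each lane: RSHRNB produces the same rounded value as URSHR in the low half but zeroes the high half. A per-lane sketch in plain C++ (illustrative names, not LLVM code):

#include <cassert>
#include <cstdint>

// URSHR keeps the full 16-bit rounded result in the lane; RSHRNB narrows it
// to 8 bits and zeroes the top byte of the lane.
uint16_t urshr16(uint16_t x, unsigned sh) {
  return (uint16_t)(((uint32_t)x + (1u << (sh - 1))) >> sh);
}
uint16_t rshrnb16(uint16_t x, unsigned sh) {
  return (uint16_t)(uint8_t)urshr16(x, sh);
}

int main() {
  uint16_t x = 0x5ABC;
  // The full lanes differ once the rounded value needs more than 8 bits...
  assert(urshr16(x, 6) != rshrnb16(x, 6));
  // ...but a consumer that only reads the low byte of each lane (a truncating
  // st1b of a .h vector, or a uzp1 that keeps the even bytes) cannot tell the
  // two apart, which is the precondition the combine relies on.
  assert((uint8_t)urshr16(x, 6) == (uint8_t)rshrnb16(x, 6));
  return 0;
}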
>From 22e9648fd1884ae6485995432167dca308245a28 Mon Sep 17 00:00:00 2001
From: "Nadeem, Usman" <mnadeem at quicinc.com>
Date: Wed, 31 Jan 2024 13:26:28 -0800
Subject: [PATCH 4/4] address comments

Change-Id: I076f19c947696100ec469c8407b6d235d6444145
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 23 +++++++++--------
 llvm/test/CodeGen/AArch64/sve2-rsh.ll         | 25 ++++++-------------
 2 files changed, 20 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 181b73fa478e5..0a0e5ce73ec13 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20885,7 +20885,6 @@ static SDValue tryCombineExtendRShTrunc(SDNode *N, SelectionDAG &DAG) {
   assert(N->getOpcode() == AArch64ISD::UZP1 && "Only UZP1 expected.");
   SDValue Op0 = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
-  EVT VT = Op0->getValueType(0);
   EVT ResVT = N->getValueType(0);
 
   unsigned RshOpc = Op0.getOpcode();
@@ -20992,15 +20991,6 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
   if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op1, DAG, Subtarget))
     return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
 
-  // uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y)
-  if (isHalvingTruncateAndConcatOfLegalIntScalableType(N) &&
-      Op0.getOpcode() == ISD::BITCAST && Op1.getOpcode() == ISD::BITCAST) {
-    if (Op0.getOperand(0).getValueType() == Op1.getOperand(0).getValueType()) {
-      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0.getOperand(0),
-                         Op1.getOperand(0));
-    }
-  }
-
   // uzp1(unpklo(uzp1(x, y)), z) => uzp1(x, z)
   if (Op0.getOpcode() == AArch64ISD::UUNPKLO) {
     if (Op0.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
@@ -21025,6 +21015,19 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
   if (!IsLittleEndian)
     return SDValue();
 
+  // uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y)
+  // Example:
+  // nxv4i32 = uzp1 bitcast(nxv4i32 x to nxv2i64), bitcast(nxv4i32 y to nxv2i64)
+  // to
+  // nxv4i32 = uzp1 nxv2i64, nxv2i64
+  if (isHalvingTruncateAndConcatOfLegalIntScalableType(N) &&
+      Op0.getOpcode() == ISD::BITCAST && Op1.getOpcode() == ISD::BITCAST) {
+    if (Op0.getOperand(0).getValueType() == Op1.getOperand(0).getValueType()) {
+      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0.getOperand(0),
+                         Op1.getOperand(0));
+    }
+  }
+
   if (ResVT != MVT::v2i32 && ResVT != MVT::v4i16 && ResVT != MVT::v8i8)
     return SDValue();
 
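The uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y) fold, now placed under the little-endian guard above, relies on a bitcast between these integer scalable types being a pure reinterpretation of the register, so the even-element selection done by UZP1 sees the same lanes with or without the bitcasts. A sketch of a 128-bit vector (VL = 128 for simplicity) in plain C++ (illustrative, assumes a little-endian host, not LLVM code):

#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>

using RegS = std::array<uint32_t, 4>;  // .s view of a 128-bit SVE register
using RegD = std::array<uint64_t, 2>;  // .d view of the same register

// UZP1 at .s: even-numbered .s elements of the first operand, then the second.
RegS uzp1_s(const RegS &a, const RegS &b) { return {a[0], a[2], b[0], b[2]}; }

// A bitcast between nxv2i64 and nxv4i32 moves no bytes on little-endian.
RegS bitcast_d_to_s(const RegD &d) {
  RegS s;
  std::memcpy(s.data(), d.data(), sizeof(s));
  return s;
}

int main() {
  RegD x = {0x1111111100000001ull, 0x2222222200000002ull};
  RegD y = {0x3333333300000003ull, 0x4444444400000004ull};
  // The low 32 bits of each .d element sit in the even .s lanes, so uzp1 of
  // the bitcast values is the truncating concat of x and y; adding or removing
  // the bitcasts cannot change the result.
  RegS r = uzp1_s(bitcast_d_to_s(x), bitcast_d_to_s(y));
  assert((r == RegS{1, 2, 3, 4}));
  return 0;
}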
diff --git a/llvm/test/CodeGen/AArch64/sve2-rsh.ll b/llvm/test/CodeGen/AArch64/sve2-rsh.ll
index 0e3bfb90a45f4..516ef3bd581ee 100644
--- a/llvm/test/CodeGen/AArch64/sve2-rsh.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-rsh.ll
@@ -40,27 +40,18 @@ define <vscale x 2 x i64> @neg_urshr_3(<vscale x 2 x i64> %x, <vscale x 2 x i64>
 }
 
 ; Add has two uses.
-define <vscale x 2 x i64> @neg_urshr_4(<vscale x 2 x i64> %x) {
+define <vscale x 2 x i64> @neg_urshr_4(<vscale x 2 x i64> %x, ptr %p) {
 ; CHECK-LABEL: neg_urshr_4:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
-; CHECK-NEXT:    addvl sp, sp, #-1
-; CHECK-NEXT:    str z8, [sp] // 16-byte Folded Spill
-; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
-; CHECK-NEXT:    .cfi_offset w30, -8
-; CHECK-NEXT:    .cfi_offset w29, -16
-; CHECK-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
-; CHECK-NEXT:    add z0.d, z0.d, #32 // =0x20
-; CHECK-NEXT:    lsr z8.d, z0.d, #6
-; CHECK-NEXT:    bl use
-; CHECK-NEXT:    mov z0.d, z8.d
-; CHECK-NEXT:    ldr z8, [sp] // 16-byte Folded Reload
-; CHECK-NEXT:    addvl sp, sp, #1
-; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    mov z1.d, z0.d
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    add z1.d, z1.d, #32 // =0x20
+; CHECK-NEXT:    lsr z0.d, z1.d, #6
+; CHECK-NEXT:    st1d { z1.d }, p0, [x0]
 ; CHECK-NEXT:    ret
   %add = add nuw nsw <vscale x 2 x i64> %x, splat (i64 32)
   %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
-  call void @use(<vscale x 2 x i64> %add)
+  store <vscale x 2 x i64> %add, ptr %p
   ret <vscale x 2 x i64> %sh
 }
 
@@ -286,5 +277,3 @@ define <vscale x 2 x i64> @urshr_i64(<vscale x 2 x i64> %x) {
   %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
   ret <vscale x 2 x i64> %sh
 }
-
-declare void @use(<vscale x 2 x i64>)

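For context, this is the kind of scalar source that tends to produce the add + lshr + trunc pattern targeted here once it is vectorized (a sketch only; whether it actually becomes rshrnb or urshr depends on the cost model and the enabled target features):

#include <cstddef>
#include <cstdint>

// Fixed-point narrowing with round-to-nearest: add half of 1 << 6, shift
// right, and store the narrowed byte. After vectorization this is roughly the
// add + lshr + trunc pattern exercised by the combine-rshrnb tests above.
void narrow_with_rounding(const uint16_t *src, uint8_t *dst, size_t n) {
  for (size_t i = 0; i < n; ++i)
    dst[i] = (uint8_t)((src[i] + 32u) >> 6);
}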

