[llvm] 1d14323 - [AArch64][SVE2] Generate urshr rounding shift rights (#78374)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 31 14:04:02 PST 2024


Author: Usman Nadeem
Date: 2024-01-31T14:03:58-08:00
New Revision: 1d1432356e656fcae7b2a3634a2b349334ba3d80

URL: https://github.com/llvm/llvm-project/commit/1d1432356e656fcae7b2a3634a2b349334ba3d80
DIFF: https://github.com/llvm/llvm-project/commit/1d1432356e656fcae7b2a3634a2b349334ba3d80.diff

LOG: [AArch64][SVE2] Generate urshr rounding shift rights (#78374)

Add a new node `AArch64ISD::URSHR_I_PRED`.

`srl(add(X, 1 << (ShiftValue - 1)), ShiftValue)` is transformed to
`urshr`, or to `rshrnb` (as before) if the result is truncated.

`uzp1(rshrnb(uunpklo(X),C), rshrnb(uunpkhi(X), C))` is converted to
`urshr(X, C)` (tested by the wide_trunc tests).

Pattern matching code in `canLowerSRLToRoundingShiftForVT` is taken
from prior code in rshrnb. It returns true if the add has NUW or if the
number of bits used in the return value allow us to not care about the
overflow (tested by rshrnb test cases).

Added: 
    llvm/test/CodeGen/AArch64/sve2-rsh.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 823d181efc4f0..bb19aef978b94 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2690,6 +2690,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::RSHRNB_I)
     MAKE_CASE(AArch64ISD::CTTZ_ELTS)
     MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
+    MAKE_CASE(AArch64ISD::URSHR_I_PRED)
   }
 #undef MAKE_CASE
   return nullptr;
@@ -2974,6 +2975,7 @@ static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
 static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
 static SDValue convertFixedMaskToScalableVector(SDValue Mask,
                                                 SelectionDAG &DAG);
+static SDValue getPredicateForVector(SelectionDAG &DAG, SDLoc &DL, EVT VT);
 static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
                                              EVT VT);
 
@@ -13862,6 +13864,51 @@ SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
   return SDValue();
 }
 
+// Check if we can lower this SRL to a rounding shift instruction, i.e. match
+// srl(add(X, 1 << (C - 1)), C); on success ShiftValue = C and RShOperand = X.
+// ResVT is possibly a truncated type; it tells how many result bits are used.
+static bool canLowerSRLToRoundingShiftForVT(SDValue Shift, EVT ResVT,
+                                            SelectionDAG &DAG,
+                                            unsigned &ShiftValue,
+                                            SDValue &RShOperand) {
+  if (Shift->getOpcode() != ISD::SRL)
+    return false;
+
+  EVT VT = Shift.getValueType();
+  assert(VT.isScalableVT());
+
+  auto ShiftOp1 =
+      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Shift->getOperand(1)));
+  if (!ShiftOp1)
+    return false;
+
+  ShiftValue = ShiftOp1->getZExtValue();
+  if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
+    return false;
+
+  SDValue Add = Shift->getOperand(0);
+  if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
+    return false;
+
+  assert(ResVT.getScalarSizeInBits() <= VT.getScalarSizeInBits() &&
+         "ResVT must be truncated or same type as the shift.");
+  // Without NUW the add may wrap; reject if that could corrupt the used bits.
+  uint64_t ExtraBits = VT.getScalarSizeInBits() - ResVT.getScalarSizeInBits();
+  if (ShiftValue > ExtraBits && !Add->getFlags().hasNoUnsignedWrap())
+    return false;
+
+  auto AddOp1 =
+      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
+  if (!AddOp1)
+    return false;
+  uint64_t AddValue = AddOp1->getZExtValue();
+  if (AddValue != 1ULL << (ShiftValue - 1))
+    return false;
+
+  RShOperand = Add->getOperand(0);
+  return true;
+}
+
 SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
                                                       SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
@@ -13887,6 +13934,15 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
                        Op.getOperand(0), Op.getOperand(1));
   case ISD::SRA:
   case ISD::SRL:
+    if (VT.isScalableVector() && Subtarget->hasSVE2orSME()) {
+      SDValue RShOperand;
+      unsigned ShiftValue;
+      if (canLowerSRLToRoundingShiftForVT(Op, VT, DAG, ShiftValue, RShOperand))
+        return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, VT,
+                           getPredicateForVector(DAG, DL, VT), RShOperand,
+                           DAG.getTargetConstant(ShiftValue, DL, MVT::i32));
+    }
+
     if (VT.isScalableVector() ||
         useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
       unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED
@@ -17711,9 +17767,6 @@ static SDValue performReinterpretCastCombine(SDNode *N) {
 
 static SDValue performSVEAndCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
-  if (DCI.isBeforeLegalizeOps())
-    return SDValue();
-
   SelectionDAG &DAG = DCI.DAG;
   SDValue Src = N->getOperand(0);
   unsigned Opc = Src->getOpcode();
@@ -17769,6 +17822,9 @@ static SDValue performSVEAndCombine(SDNode *N,
     return DAG.getNode(Opc, DL, N->getValueType(0), And);
   }
 
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
   // If both sides of AND operations are i1 splat_vectors then
   // we can produce just i1 splat_vector as the result.
   if (isAllActivePredicate(DAG, N->getOperand(0)))
@@ -20216,6 +20272,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_sve_uqsub_x:
     return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2));
+  case Intrinsic::aarch64_sve_urshr:
+    return DAG.getNode(AArch64ISD::URSHR_I_PRED, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2), N->getOperand(3));
   case Intrinsic::aarch64_sve_asrd:
     return DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
@@ -20832,6 +20891,51 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// Is N a UZP1 taking nxv8i16/nxv4i32/nxv2i64 to the half-element-width type?
+static bool isHalvingTruncateAndConcatOfLegalIntScalableType(SDNode *N) {
+  if (N->getOpcode() != AArch64ISD::UZP1)
+    return false;
+  const EVT NarrowVT = N->getValueType(0);
+  const EVT WideVT = N->getOperand(0)->getValueType(0);
+  return (WideVT == MVT::nxv8i16 && NarrowVT == MVT::nxv16i8) ||
+         (WideVT == MVT::nxv4i32 && NarrowVT == MVT::nxv8i16) ||
+         (WideVT == MVT::nxv2i64 && NarrowVT == MVT::nxv4i32);
+}
+
+// Try to combine rounding shifts whose operands come from an extend and whose
+// results are truncated and concatenated (an empty SDValue means no match):
+//   uzp1(rshrnb(uunpklo(X), C), rshrnb(uunpkhi(X), C)) -> urshr(X, C)
+static SDValue tryCombineExtendRShTrunc(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == AArch64ISD::UZP1 && "Only UZP1 expected.");
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  EVT ResVT = N->getValueType(0);
+
+  unsigned RshOpc = Op0.getOpcode();
+  if (RshOpc != AArch64ISD::RSHRNB_I)
+    return SDValue();
+
+  // Same op code and imm value?
+  SDValue ShiftValue = Op0.getOperand(1);
+  if (RshOpc != Op1.getOpcode() || ShiftValue != Op1.getOperand(1))
+    return SDValue();
+
+  // Both halves must be unpacked from the same vector: low half, then high.
+  SDValue Lo = Op0.getOperand(0);
+  SDValue Hi = Op1.getOperand(0);
+  if (Lo.getOpcode() != AArch64ISD::UUNPKLO ||
+      Hi.getOpcode() != AArch64ISD::UUNPKHI)
+    return SDValue();
+  SDValue OrigArg = Lo.getOperand(0);
+  if (OrigArg != Hi.getOperand(0))
+    return SDValue();
+
+  SDLoc DL(N);
+  return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, ResVT,
+                     getPredicateForVector(DAG, DL, ResVT), OrigArg,
+                     ShiftValue);
+}
+
 // Try to simplify:
 //    t1 = nxv8i16 add(X, 1 << (ShiftValue - 1))
 //    t2 = nxv8i16 srl(t1, ShiftValue)
@@ -20844,9 +20948,7 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
 static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
                                          const AArch64Subtarget *Subtarget) {
   EVT VT = Srl->getValueType(0);
-
-  if (!VT.isScalableVector() || !Subtarget->hasSVE2() ||
-      Srl->getOpcode() != ISD::SRL)
+  if (!VT.isScalableVector() || !Subtarget->hasSVE2())
     return SDValue();
 
   EVT ResVT;
@@ -20859,29 +20961,14 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
   else
     return SDValue();
 
-  auto SrlOp1 =
-      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Srl->getOperand(1)));
-  if (!SrlOp1)
-    return SDValue();
-  unsigned ShiftValue = SrlOp1->getZExtValue();
-  if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
-    return SDValue();
-
-  SDValue Add = Srl->getOperand(0);
-  if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
-    return SDValue();
-  auto AddOp1 =
-      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
-  if (!AddOp1)
-    return SDValue();
-  uint64_t AddValue = AddOp1->getZExtValue();
-  if (AddValue != 1ULL << (ShiftValue - 1))
-    return SDValue();
-
   SDLoc DL(Srl);
+  unsigned ShiftValue;
+  SDValue RShOperand;
+  if (!canLowerSRLToRoundingShiftForVT(Srl, ResVT, DAG, ShiftValue, RShOperand))
+    return SDValue();
   SDValue Rshrnb = DAG.getNode(
       AArch64ISD::RSHRNB_I, DL, ResVT,
-      {Add->getOperand(0), DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
+      {RShOperand, DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
   return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
 }
 
@@ -20919,6 +21006,9 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  if (SDValue Urshr = tryCombineExtendRShTrunc(N, DAG))
+    return Urshr;
+
   if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
     return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
 
@@ -20949,6 +21039,19 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
   if (!IsLittleEndian)
     return SDValue();
 
+  // uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y)
+  // Example:
+  // nxv4i32 = uzp1 bitcast(nxv4i32 x to nxv2i64), bitcast(nxv4i32 y to nxv2i64)
+  // to
+  // nxv4i32 = uzp1 nxv4i32 x, nxv4i32 y
+  if (isHalvingTruncateAndConcatOfLegalIntScalableType(N) &&
+      Op0.getOpcode() == ISD::BITCAST && Op1.getOpcode() == ISD::BITCAST) {
+    if (Op0.getOperand(0).getValueType() == Op1.getOperand(0).getValueType()) {
+      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0.getOperand(0),
+                         Op1.getOperand(0));
+    }
+  }
+
   if (ResVT != MVT::v2i32 && ResVT != MVT::v4i16 && ResVT != MVT::v8i8)
     return SDValue();
 

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 541a810fb5cba..436b21fd13463 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -218,6 +218,7 @@ enum NodeType : unsigned {
   SQSHLU_I,
   SRSHR_I,
   URSHR_I,
+  URSHR_I_PRED,
 
   // Vector narrowing shift by immediate (bottom)
   RSHRNB_I,

diff  --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index c4d69232c9e30..e83d8e5bde79e 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -232,6 +232,7 @@ def SDT_AArch64Arith_Imm : SDTypeProfile<1, 3, [
 ]>;
 
 def AArch64asrd_m1 : SDNode<"AArch64ISD::SRAD_MERGE_OP1", SDT_AArch64Arith_Imm>;
+def AArch64urshri_p : SDNode<"AArch64ISD::URSHR_I_PRED", SDT_AArch64Arith_Imm>;
 
 def SDT_AArch64IntExtend : SDTypeProfile<1, 4, [
   SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVT<3, OtherVT>, SDTCisVec<4>,
@@ -3539,7 +3540,7 @@ let Predicates = [HasSVE2orSME] in {
   defm SQSHL_ZPmI  : sve_int_bin_pred_shift_imm_left_dup<0b0110, "sqshl",  "SQSHL_ZPZI",  int_aarch64_sve_sqshl>;
   defm UQSHL_ZPmI  : sve_int_bin_pred_shift_imm_left_dup<0b0111, "uqshl",  "UQSHL_ZPZI",  int_aarch64_sve_uqshl>;
   defm SRSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<   0b1100, "srshr",  "SRSHR_ZPZI",  int_aarch64_sve_srshr>;
-  defm URSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<   0b1101, "urshr",  "URSHR_ZPZI",  int_aarch64_sve_urshr>;
+  defm URSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<   0b1101, "urshr",  "URSHR_ZPZI",  AArch64urshri_p>;
   defm SQSHLU_ZPmI : sve_int_bin_pred_shift_imm_left<    0b1111, "sqshlu", "SQSHLU_ZPZI", int_aarch64_sve_sqshlu>;
 
   // SVE2 integer add/subtract long
@@ -3584,7 +3585,7 @@ let Predicates = [HasSVE2orSME] in {
   defm SSRA_ZZI  : sve2_int_bin_accum_shift_imm_right<0b00, "ssra",  AArch64ssra>;
   defm USRA_ZZI  : sve2_int_bin_accum_shift_imm_right<0b01, "usra",  AArch64usra>;
   defm SRSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b10, "srsra", int_aarch64_sve_srsra, int_aarch64_sve_srshr>;
-  defm URSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b11, "ursra", int_aarch64_sve_ursra, int_aarch64_sve_urshr>;
+  defm URSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b11, "ursra", int_aarch64_sve_ursra, AArch64urshri_p>;
 
   // SVE2 complex integer add
   defm CADD_ZZI   : sve2_int_cadd<0b0, "cadd",   int_aarch64_sve_cadd_x>;

diff  --git a/llvm/test/CodeGen/AArch64/sve2-rsh.ll b/llvm/test/CodeGen/AArch64/sve2-rsh.ll
new file mode 100644
index 0000000000000..516ef3bd581ee
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2-rsh.ll
@@ -0,0 +1,279 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64 -mattr=+sve < %s -o - | FileCheck --check-prefixes=CHECK,SVE %s
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 < %s -o - | FileCheck --check-prefixes=CHECK,SVE2 %s
+
+; Negative test: the addend is 16, but a shift by 6 needs 1 << 5 = 32 to round.
+define <vscale x 2 x i64> @neg_urshr_1(<vscale x 2 x i64> %x) {
+; CHECK-LABEL: neg_urshr_1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add z0.d, z0.d, #16 // =0x10
+; CHECK-NEXT:    lsr z0.d, z0.d, #6
+; CHECK-NEXT:    ret
+  %add = add nuw nsw <vscale x 2 x i64> %x, splat (i64 16)
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
+  ret <vscale x 2 x i64> %sh
+}
+
+; Negative test: the shift amount is a variable vector, not a splat constant.
+define <vscale x 2 x i64> @neg_urshr_2(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y) {
+; CHECK-LABEL: neg_urshr_2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    add z0.d, z0.d, #32 // =0x20
+; CHECK-NEXT:    lsr z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %add = add nuw nsw <vscale x 2 x i64> %x, splat (i64 32)
+  %sh = lshr <vscale x 2 x i64> %add, %y
+  ret <vscale x 2 x i64> %sh
+}
+
+; Negative test: the addend is a variable vector, not a splat constant.
+define <vscale x 2 x i64> @neg_urshr_3(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y) {
+; CHECK-LABEL: neg_urshr_3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add z0.d, z0.d, z1.d
+; CHECK-NEXT:    lsr z0.d, z0.d, #6
+; CHECK-NEXT:    ret
+  %add = add nuw nsw <vscale x 2 x i64> %x, %y
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
+  ret <vscale x 2 x i64> %sh
+}
+
+; Negative test: the add has a second use (the store), so it is not folded.
+define <vscale x 2 x i64> @neg_urshr_4(<vscale x 2 x i64> %x, ptr %p) {
+; CHECK-LABEL: neg_urshr_4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z1.d, z0.d
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    add z1.d, z1.d, #32 // =0x20
+; CHECK-NEXT:    lsr z0.d, z1.d, #6
+; CHECK-NEXT:    st1d { z1.d }, p0, [x0]
+; CHECK-NEXT:    ret
+  %add = add nuw nsw <vscale x 2 x i64> %x, splat (i64 32)
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
+  store <vscale x 2 x i64> %add, ptr %p
+  ret <vscale x 2 x i64> %sh
+}
+
+; Negative test: the add has no nuw flag, so it may wrap and is not matched.
+define <vscale x 2 x i64> @neg_urshr_5(<vscale x 2 x i64> %x) {
+; CHECK-LABEL: neg_urshr_5:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add z0.d, z0.d, #32 // =0x20
+; CHECK-NEXT:    lsr z0.d, z0.d, #6
+; CHECK-NEXT:    ret
+  %add = add <vscale x 2 x i64> %x, splat (i64 32)
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
+  ret <vscale x 2 x i64> %sh
+}
+
+define <vscale x 16 x i8> @urshr_i8(<vscale x 16 x i8> %x) {
+; SVE-LABEL: urshr_i8:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.b, z0.b, #32 // =0x20
+; SVE-NEXT:    lsr z0.b, z0.b, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_i8:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.b
+; SVE2-NEXT:    urshr z0.b, p0/m, z0.b, #6
+; SVE2-NEXT:    ret
+  %add = add nuw nsw <vscale x 16 x i8> %x, splat (i8 32)
+  %sh = lshr <vscale x 16 x i8> %add, splat (i8 6)
+  ret <vscale x 16 x i8> %sh
+}
+
+define <vscale x 16 x i8> @urshr_8_wide_trunc(<vscale x 16 x i8> %x) {
+; SVE-LABEL: urshr_8_wide_trunc:
+; SVE:       // %bb.0:
+; SVE-NEXT:    uunpkhi z1.h, z0.b
+; SVE-NEXT:    uunpklo z0.h, z0.b
+; SVE-NEXT:    add z0.h, z0.h, #32 // =0x20
+; SVE-NEXT:    add z1.h, z1.h, #32 // =0x20
+; SVE-NEXT:    lsr z1.h, z1.h, #6
+; SVE-NEXT:    lsr z0.h, z0.h, #6
+; SVE-NEXT:    uzp1 z0.b, z0.b, z1.b
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_8_wide_trunc:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.b
+; SVE2-NEXT:    urshr z0.b, p0/m, z0.b, #6
+; SVE2-NEXT:    ret
+  %ext = zext <vscale x 16 x i8> %x to <vscale x 16 x i16>
+  %add = add nuw nsw <vscale x 16 x i16> %ext, splat (i16 32)
+  %sh = lshr <vscale x 16 x i16> %add, splat (i16 6)
+  %sht = trunc <vscale x 16 x i16> %sh to <vscale x 16 x i8>
+  ret <vscale x 16 x i8> %sht
+}
+
+define <vscale x 16 x i8> @urshr_8_wide_trunc_nomerge(<vscale x 16 x i16> %ext) {
+; SVE-LABEL: urshr_8_wide_trunc_nomerge:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.h, z0.h, #256 // =0x100
+; SVE-NEXT:    add z1.h, z1.h, #256 // =0x100
+; SVE-NEXT:    lsr z1.h, z1.h, #9
+; SVE-NEXT:    lsr z0.h, z0.h, #9
+; SVE-NEXT:    uzp1 z0.b, z0.b, z1.b
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_8_wide_trunc_nomerge:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.h
+; SVE2-NEXT:    urshr z1.h, p0/m, z1.h, #9
+; SVE2-NEXT:    urshr z0.h, p0/m, z0.h, #9
+; SVE2-NEXT:    uzp1 z0.b, z0.b, z1.b
+; SVE2-NEXT:    ret
+  %add = add nuw nsw <vscale x 16 x i16> %ext, splat (i16 256)
+  %sh = lshr <vscale x 16 x i16> %add, splat (i16 9)
+  %sht = trunc <vscale x 16 x i16> %sh to <vscale x 16 x i8>
+  ret <vscale x 16 x i8> %sht
+}
+
+define <vscale x 8 x i16> @urshr_i16(<vscale x 8 x i16> %x) {
+; SVE-LABEL: urshr_i16:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.h, z0.h, #32 // =0x20
+; SVE-NEXT:    lsr z0.h, z0.h, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_i16:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.h
+; SVE2-NEXT:    urshr z0.h, p0/m, z0.h, #6
+; SVE2-NEXT:    ret
+  %add = add nuw nsw <vscale x 8 x i16> %x, splat (i16 32)
+  %sh = lshr <vscale x 8 x i16> %add, splat (i16 6)
+  ret <vscale x 8 x i16> %sh
+}
+
+define <vscale x 8 x i16> @urshr_16_wide_trunc(<vscale x 8 x i16> %x) {
+; SVE-LABEL: urshr_16_wide_trunc:
+; SVE:       // %bb.0:
+; SVE-NEXT:    uunpkhi z1.s, z0.h
+; SVE-NEXT:    uunpklo z0.s, z0.h
+; SVE-NEXT:    add z0.s, z0.s, #32 // =0x20
+; SVE-NEXT:    add z1.s, z1.s, #32 // =0x20
+; SVE-NEXT:    lsr z1.s, z1.s, #6
+; SVE-NEXT:    lsr z0.s, z0.s, #6
+; SVE-NEXT:    uzp1 z0.h, z0.h, z1.h
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_16_wide_trunc:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.h
+; SVE2-NEXT:    urshr z0.h, p0/m, z0.h, #6
+; SVE2-NEXT:    ret
+  %ext = zext <vscale x 8 x i16> %x to <vscale x 8 x i32>
+  %add = add nuw nsw <vscale x 8 x i32> %ext, splat (i32 32)
+  %sh = lshr <vscale x 8 x i32> %add, splat (i32 6)
+  %sht = trunc <vscale x 8 x i32> %sh to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %sht
+}
+
+define <vscale x 8 x i16> @urshr_16_wide_trunc_nomerge(<vscale x 8 x i32> %ext) {
+; SVE-LABEL: urshr_16_wide_trunc_nomerge:
+; SVE:       // %bb.0:
+; SVE-NEXT:    mov z2.s, #0x10000
+; SVE-NEXT:    add z0.s, z0.s, z2.s
+; SVE-NEXT:    add z1.s, z1.s, z2.s
+; SVE-NEXT:    lsr z1.s, z1.s, #17
+; SVE-NEXT:    lsr z0.s, z0.s, #17
+; SVE-NEXT:    uzp1 z0.h, z0.h, z1.h
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_16_wide_trunc_nomerge:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.s
+; SVE2-NEXT:    urshr z1.s, p0/m, z1.s, #17
+; SVE2-NEXT:    urshr z0.s, p0/m, z0.s, #17
+; SVE2-NEXT:    uzp1 z0.h, z0.h, z1.h
+; SVE2-NEXT:    ret
+  %add = add nuw nsw <vscale x 8 x i32> %ext, splat (i32 65536)
+  %sh = lshr <vscale x 8 x i32> %add, splat (i32 17)
+  %sht = trunc <vscale x 8 x i32> %sh to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %sht
+}
+
+define <vscale x 4 x i32> @urshr_i32(<vscale x 4 x i32> %x) {
+; SVE-LABEL: urshr_i32:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.s, z0.s, #32 // =0x20
+; SVE-NEXT:    lsr z0.s, z0.s, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_i32:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.s
+; SVE2-NEXT:    urshr z0.s, p0/m, z0.s, #6
+; SVE2-NEXT:    ret
+  %add = add nuw nsw <vscale x 4 x i32> %x, splat (i32 32)
+  %sh = lshr <vscale x 4 x i32> %add, splat (i32 6)
+  ret <vscale x 4 x i32> %sh
+}
+
+define <vscale x 4 x i32> @urshr_32_wide_trunc(<vscale x 4 x i32> %x) {
+; SVE-LABEL: urshr_32_wide_trunc:
+; SVE:       // %bb.0:
+; SVE-NEXT:    uunpkhi z1.d, z0.s
+; SVE-NEXT:    uunpklo z0.d, z0.s
+; SVE-NEXT:    add z0.d, z0.d, #32 // =0x20
+; SVE-NEXT:    add z1.d, z1.d, #32 // =0x20
+; SVE-NEXT:    lsr z1.d, z1.d, #6
+; SVE-NEXT:    lsr z0.d, z0.d, #6
+; SVE-NEXT:    uzp1 z0.s, z0.s, z1.s
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_32_wide_trunc:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.s
+; SVE2-NEXT:    urshr z0.s, p0/m, z0.s, #6
+; SVE2-NEXT:    ret
+  %ext = zext <vscale x 4 x i32> %x to <vscale x 4 x i64>
+  %add = add nuw nsw <vscale x 4 x i64> %ext, splat (i64 32)
+  %sh = lshr <vscale x 4 x i64> %add, splat (i64 6)
+  %sht = trunc <vscale x 4 x i64> %sh to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %sht
+}
+
+define <vscale x 4 x i32> @urshr_32_wide_trunc_nomerge(<vscale x 4 x i64> %ext) {
+; SVE-LABEL: urshr_32_wide_trunc_nomerge:
+; SVE:       // %bb.0:
+; SVE-NEXT:    mov z2.d, #0x100000000
+; SVE-NEXT:    add z0.d, z0.d, z2.d
+; SVE-NEXT:    add z1.d, z1.d, z2.d
+; SVE-NEXT:    lsr z1.d, z1.d, #33
+; SVE-NEXT:    lsr z0.d, z0.d, #33
+; SVE-NEXT:    uzp1 z0.s, z0.s, z1.s
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_32_wide_trunc_nomerge:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.d
+; SVE2-NEXT:    urshr z1.d, p0/m, z1.d, #33
+; SVE2-NEXT:    urshr z0.d, p0/m, z0.d, #33
+; SVE2-NEXT:    uzp1 z0.s, z0.s, z1.s
+; SVE2-NEXT:    ret
+  %add = add nuw nsw <vscale x 4 x i64> %ext, splat (i64 4294967296)
+  %sh = lshr <vscale x 4 x i64> %add, splat (i64 33)
+  %sht = trunc <vscale x 4 x i64> %sh to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %sht
+}
+
+define <vscale x 2 x i64> @urshr_i64(<vscale x 2 x i64> %x) {
+; SVE-LABEL: urshr_i64:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.d, z0.d, #32 // =0x20
+; SVE-NEXT:    lsr z0.d, z0.d, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_i64:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.d
+; SVE2-NEXT:    urshr z0.d, p0/m, z0.d, #6
+; SVE2-NEXT:    ret
+  %add = add nuw nsw <vscale x 2 x i64> %x, splat (i64 32)
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
+  ret <vscale x 2 x i64> %sh
+}


        


More information about the llvm-commits mailing list