[llvm] Subject: [PATCH] [AArch64ISelLowering] Optimize rounding shift and saturation truncation (PR #74325)

via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 6 07:44:23 PST 2023


https://github.com/JohnLee1243 updated https://github.com/llvm/llvm-project/pull/74325

>From 216ea81be4e864416c9f3fbf191c1b914fa1675d Mon Sep 17 00:00:00 2001
From: JohnLee1243 <lizhuohang911 at 126.com>
Date: Mon, 4 Dec 2023 22:30:56 +0800
Subject: [PATCH] [AArch64ISelLowering] Optimize rounding
 shift and saturation truncation

This patch performs three kinds of instruction simplification.
1. A rounding shift such as SHIFT(Add(OpA, 1<<(imm-1)), imm)
can be simplified to srshr(OpA, imm) or urshr(OpA, imm).
2. A rounding shift followed by saturating truncation, such as
Trunc(min(max(Shift(Add(OpA, 1<<(imm-1)), imm), 0), maxValue)), can be
simplified to uqrshrn(OpA, imm) or sqrshrun(OpA, imm).
3. Add a pattern for RSHRN.
This patch performs these optimizations in the backend after legalization.
---
 .../Target/AArch64/AArch64ISelDAGToDAG.cpp    |  11 +
 .../Target/AArch64/AArch64ISelLowering.cpp    | 597 +++++++++++++++++-
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |   2 +-
 .../test/CodeGen/AArch64/isel_rounding_opt.ll | 327 ++++++++++
 llvm/test/CodeGen/AArch64/neon-rshrn.ll       |  30 +-
 5 files changed, 945 insertions(+), 22 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/isel_rounding_opt.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 2f49e9a6b37cc..b559a470ed3f2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -177,6 +177,17 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
   }
 
   bool SelectRoundingVLShr(SDValue N, SDValue &Res1, SDValue &Res2) {
+    if (N.getOpcode() == AArch64ISD::URSHR_I) {
+      EVT VT = N.getValueType();
+      unsigned ShtAmt = N->getConstantOperandVal(1);
+      if (ShtAmt >= VT.getScalarSizeInBits() / 2)
+        return false;
+
+      Res1 = N.getOperand(0);
+      Res2 = CurDAG->getTargetConstant(ShtAmt, SDLoc(N), MVT::i32);
+      return true;
+    }
+
     if (N.getOpcode() != AArch64ISD::VLSHR)
       return false;
     SDValue Op = N->getOperand(0);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b6a16217dfae3..d0d35017a3365 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -111,6 +111,28 @@ STATISTIC(NumTailCalls, "Number of tail calls");
 STATISTIC(NumShiftInserts, "Number of vector shift inserts");
 STATISTIC(NumOptimizedImms, "Number of times immediates were optimized");
 
+// Gates forming URSHR_I / SRSHR_I from a right shift of a rounding ADD.
+static cl::opt<bool> EnableAArch64RoundingOpt(
+    "aarch64-optimize-rounding", cl::Hidden, cl::init(true),
+    cl::desc("Enable AArch64 rounding optimization"));
+
+// Gates forming uqrshrn / sqrshrun (plus uqxtn) from a rounding shift that
+// feeds a umin(smax(x, 0), max) clamp and a truncate.
+static cl::opt<bool> EnableAArch64RoundingSatNarrowOpt(
+    "aarch64-optimize-rounding-saturation", cl::Hidden, cl::init(true),
+    cl::desc("Enable AArch64 rounding and saturation narrow optimization"));
+
+// Gates folding extract_vector_elt(bitcast(truncate(build_vector(...))))
+// through the truncate (see hasUselessTrunc).
+static cl::opt<bool> EnableAArch64ExtractVecElementCombine(
+    "aarch64-extract-vector-element-trunc-combine", cl::Hidden, cl::init(true),
+    cl::desc("Allow AArch64 extract vector element combination with "
+             "truncation"));
+
+// Upper bound on the number of ADD nodes visited while searching for the
+// rounding constant.
+static cl::opt<int> RoundingSearchMaxDepth(
+    "aarch64-rounding-search-max-depth", cl::Hidden, cl::init(4),
+    cl::desc("Maximum depth to bfs search rounding value in rounding "
+             "optimization"));
+
 // FIXME: The necessary dtprel relocations don't seem to be supported
 // well in the GNU bfd and gold linkers at the moment. Therefore, by
 // default, for now, fall back to GeneralDynamic code generation.
@@ -995,6 +1017,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
   setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
 
+  setTargetDAGCombine(ISD::TRUNCATE);
+
   setTargetDAGCombine({ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND,
                        ISD::VECTOR_SPLICE, ISD::SIGN_EXTEND_INREG,
                        ISD::CONCAT_VECTORS, ISD::EXTRACT_SUBVECTOR,
@@ -17681,6 +17705,482 @@ static SDValue performANDCombine(SDNode *N,
   return SDValue();
 }
 
+// Breadth-first search, starting at the ADD node OperandIn, for an addend
+// that is a constant splat equal to RoundingValue. At most Level ADD nodes
+// are visited.
+//
+// On success returns true and appends to AddOperands every other addend of
+// the visited ADD tree (everything except the rounding constant), so that
+// their sum equals OperandIn - RoundingValue. The recorded flags are always
+// empty: no-wrap flags do not survive the reassociation.
+//
+// The rounding-shift instructions add the rounding constant without
+// wrapping, so the rewrite is only sound if the original additions could not
+// wrap either:
+//  * the root ADD must be nuw for a logical shift (IsSigned == false) or nsw
+//    for an arithmetic shift (IsSigned == true);
+//  * nested ADDs are only looked through when the root and the nested ADD
+//    are nuw (logical) or both nuw and nsw (arithmetic). nsw alone does not
+//    make reassociation safe: for i32, (INT_MIN + 32) + -32 does not
+//    overflow, but the reassociated INT_MIN + -32 does.
+// ADD nodes that fail these checks are treated as opaque addends.
+static bool searchRoundingValueBFS(
+    uint64_t RoundingValue, const SDValue &OperandIn,
+    SmallVectorImpl<std::pair<SDValue, SDNodeFlags>> &AddOperands, int Level,
+    bool IsSigned) {
+  const SDNodeFlags RootFlags = OperandIn->getFlags();
+  if (IsSigned ? !RootFlags.hasNoSignedWrap()
+               : !RootFlags.hasNoUnsignedWrap())
+    return false;
+
+  auto HasReassociationSafeFlags = [IsSigned](const SDNodeFlags &Flags) {
+    return Flags.hasNoUnsignedWrap() && (!IsSigned || Flags.hasNoSignedWrap());
+  };
+  const bool MayReassociate = HasReassociationSafeFlags(RootFlags);
+  auto CanLookThrough = [&](SDValue V) {
+    return MayReassociate && V.getOpcode() == ISD::ADD &&
+           HasReassociationSafeFlags(V->getFlags());
+  };
+
+  auto IsRoundingSplat = [RoundingValue](SDValue V) {
+    auto *BV = dyn_cast<BuildVectorSDNode>(V);
+    APInt SplatBits, SplatUndef;
+    unsigned SplatBitSize = 0;
+    bool HasAnyUndefs = false;
+    return BV &&
+           BV->isConstantSplat(SplatBits, SplatUndef, SplatBitSize,
+                               HasAnyUndefs) &&
+           SplatBits == RoundingValue;
+  };
+
+  // WorkList doubles as the BFS queue; Head indexes the next node to visit.
+  SmallVector<SDValue, 4> WorkList;
+  WorkList.push_back(OperandIn);
+  for (size_t Head = 0; Level > 0 && Head < WorkList.size(); ++Head, --Level) {
+    SDValue Op0 = WorkList[Head].getOperand(0);
+    SDValue Op1 = WorkList[Head].getOperand(1);
+
+    SDValue Other;
+    if (IsRoundingSplat(Op0))
+      Other = Op1;
+    else if (IsRoundingSplat(Op1))
+      Other = Op0;
+
+    if (Other) {
+      // Found it: its sibling and every not-yet-visited ADD are addends.
+      AddOperands.emplace_back(Other, SDNodeFlags());
+      for (size_t I = Head + 1, E = WorkList.size(); I != E; ++I)
+        AddOperands.emplace_back(WorkList[I], SDNodeFlags());
+      return true;
+    }
+
+    for (SDValue Op : {Op0, Op1}) {
+      if (CanLookThrough(Op))
+        WorkList.push_back(Op);
+      else
+        AddOperands.emplace_back(Op, SDNodeFlags());
+    }
+  }
+
+  return false;
+}
+
+// Try to match pattern "OpB = SHIFT(Add(OpA, 1<<(imm-1)), imm)", where the
+// shift amount is an immediate given either as a scalar constant or as a
+// constant splat vector. On success returns OpA (rebuilt as a sum when the
+// rounding constant sits in a nested ADD) and sets ShiftAmount to imm;
+// otherwise returns an empty SDValue. IsSigned must be true for arithmetic
+// shifts and false for logical shifts; it selects the no-wrap flag that makes
+// the rewrite sound (see searchRoundingValueBFS).
+static SDValue matchShiftRounding(const SDValue &ShiftOp0,
+                                  const SDValue &ShiftOp1, SelectionDAG &DAG,
+                                  int64_t &ShiftAmount, bool IsSigned) {
+  ShiftAmount = 0;
+  // For illegal type, do nothing. Wait until type is legalized.
+  EVT VT0 = ShiftOp0.getValueType();
+  if (!VT0.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT0))
+    return SDValue();
+
+  // Shift value must be an immediate, either a constant splat vector form or
+  // scalar form.
+  int64_t TempAmount = 0;
+  EVT VT = ShiftOp1.getValueType();
+  if (isa<BuildVectorSDNode>(ShiftOp1) && VT.isVector() &&
+      isVShiftRImm(ShiftOp1, VT, false, TempAmount))
+    ShiftAmount = TempAmount;
+
+  if (auto *C = dyn_cast<ConstantSDNode>(ShiftOp1))
+    ShiftAmount = C->getSExtValue();
+
+  // For shift value = 1, match XHADD in first priority which accomplishes
+  // (a + b + 1) >> 1.
+  if (ShiftOp0.getOpcode() != ISD::ADD || ShiftAmount <= 1 || ShiftAmount > 64)
+    return SDValue();
+
+  // Build the constant in 64 bits: "1 << N" on a plain int literal is
+  // undefined behaviour once N reaches 31.
+  uint64_t RoundingValue = uint64_t(1) << (ShiftAmount - 1);
+  SmallVector<std::pair<SDValue, SDNodeFlags>, 4> AddOperands;
+  if (!searchRoundingValueBFS(RoundingValue, ShiftOp0, AddOperands,
+                              RoundingSearchMaxDepth, IsSigned))
+    return SDValue();
+
+  // In case expression has pattern "(a + roundingValue + b + c + d) >> shift"
+  // the rounding value is not a direct operand of the shift's ADD, and the
+  // expression is matched as "RoundingShift(a + b + c + d, shift)". The new
+  // ADDs get no wrap flags because the sum has been reassociated.
+  SDLoc DL(ShiftOp0);
+  SDValue Sum = AddOperands.front().first;
+  EVT VTAdd = Sum.getValueType();
+  for (const auto &Addend : drop_begin(AddOperands))
+    Sum = DAG.getNode(ISD::ADD, DL, VTAdd, Addend.first, Sum);
+  return Sum;
+}
+
+// Attempt to form XRSHR(OpA, imm) from "SRX(Add(OpA, 1<<(imm-1)), imm)".
+// Depending on logic and arithmetic shift, 'XRSHR' can be 'SRSHR' or 'URSHR'
+// and 'SRA' or 'SRL' for SRX. ShiftRoundingOpc is AArch64ISD::SRSHR_I (for an
+// arithmetic shift) or AArch64ISD::URSHR_I (for a logical shift).
+static SDValue matchXRSHR(SDNode *N, SelectionDAG &DAG,
+                          unsigned ShiftRoundingOpc) {
+  int64_t ShiftAmount;
+  bool IsSigned = ShiftRoundingOpc == AArch64ISD::SRSHR_I;
+  SDValue AddOperand = matchShiftRounding(N->getOperand(0), N->getOperand(1),
+                                          DAG, ShiftAmount, IsSigned);
+  if (!AddOperand)
+    return SDValue();
+  SDLoc DL(N);
+  return DAG.getNode(
+      ShiftRoundingOpc, DL, N->getValueType(0), AddOperand,
+      DAG.getConstant(static_cast<uint64_t>(ShiftAmount), DL, MVT::i32));
+}
+
+static SDValue performSRLCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  SelectionDAG &DAG = DCI.DAG;
+  EVT VT = N->getValueType(0);
+
+  // Attempt to form URSHR(OpA, imm) from "SRL(Add(OpA, 1<<(imm-1)), imm)"
+  SDValue ResultRounding;
+  if (EnableAArch64RoundingOpt && VT.isVector() &&
+      static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasNEON())
+    ResultRounding = matchXRSHR(N, DAG, AArch64ISD::URSHR_I);
+
+  if (ResultRounding)
+    return ResultRounding;
+
+  if (VT != MVT::i32 && VT != MVT::i64)
+    return SDValue();
+
+  // Canonicalize (srl (bswap i32 x), 16) to (rotr (bswap i32 x), 16), if the
+  // high 16-bits of x are zero. Similarly, canonicalize (srl (bswap i64 x), 32)
+  // to (rotr (bswap i64 x), 32), if the high 32-bits of x are zero.
+  SDValue N0 = N->getOperand(0);
+  if (N0.getOpcode() == ISD::BSWAP) {
+    SDLoc DL(N);
+    SDValue N1 = N->getOperand(1);
+    SDValue N00 = N0.getOperand(0);
+    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(N1)) {
+      uint64_t ShiftAmt = C->getZExtValue();
+      if (VT == MVT::i32 && ShiftAmt == 16 &&
+          DAG.MaskedValueIsZero(N00, APInt::getHighBitsSet(32, 16)))
+        return DAG.getNode(ISD::ROTR, DL, VT, N0, N1);
+      if (VT == MVT::i64 && ShiftAmt == 32 &&
+          DAG.MaskedValueIsZero(N00, APInt::getHighBitsSet(64, 32)))
+        return DAG.getNode(ISD::ROTR, DL, VT, N0, N1);
+    }
+  }
+  return SDValue();
+}
+
+static SDValue performSRACombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  SelectionDAG &DAG = DCI.DAG;
+  EVT VT = N->getValueType(0);
+
+  // Attempt to form SRSHR(OpA, imm) from "SRA(Add(OpA, 1<<(imm-1)), imm)"
+  if (!EnableAArch64RoundingOpt || !VT.isVector() ||
+      !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasNEON())
+    return SDValue();
+  return matchXRSHR(N, DAG, AArch64ISD::SRSHR_I);
+}
+
+// Try to match pattern "Trunc(min(max(Shift(Add(OpA, 1<<(imm-1)), imm), 0),
+// maxValue))". Here min(max(x, 0), maxValue) is an unsigned saturation whose
+// outer UMIN is UminOp, and Shift is a signed or unsigned right shift by an
+// immediate (scalar or vector form). Returns OpA and sets ShiftAmount on
+// success; otherwise returns an empty SDValue.
+static SDValue matchTruncSatRounding(const SDValue &UminOp, SelectionDAG &DAG,
+                                     uint64_t MaxScalarValue,
+                                     int64_t &ShiftAmount) {
+  if (UminOp.getOpcode() != ISD::UMIN)
+    return SDValue();
+
+  APInt SplatBitsRound, SplatUndefRound;
+  unsigned SplatBitSizeRound = 0;
+  bool HasAnyUndefsRound = false;
+  auto *UminOp1Vec = dyn_cast<BuildVectorSDNode>(UminOp.getOperand(1));
+  if (!UminOp1Vec ||
+      !UminOp1Vec->isConstantSplat(SplatBitsRound, SplatUndefRound,
+                                   SplatBitSizeRound, HasAnyUndefsRound) ||
+      SplatBitsRound.getZExtValue() != MaxScalarValue)
+    return SDValue();
+
+  SDValue SmaxOp = UminOp.getOperand(0);
+  if (SmaxOp.getOpcode() != ISD::SMAX ||
+      !isNullOrNullSplat(SmaxOp.getOperand(1)))
+    return SDValue();
+
+  // Arithmetic shifts are later turned into sqrshrun and logical shifts into
+  // uqrshrn; the signedness decides which no-wrap flag the rounding ADD needs.
+  SDValue RoundingOp = SmaxOp.getOperand(0);
+  bool IsSigned;
+  switch (RoundingOp.getOpcode()) {
+  case AArch64ISD::VASHR:
+  case ISD::SRA:
+    IsSigned = true;
+    break;
+  case AArch64ISD::VLSHR:
+  case ISD::SRL:
+    IsSigned = false;
+    break;
+  default:
+    return SDValue();
+  }
+
+  SDValue AddOperand =
+      matchShiftRounding(RoundingOp.getOperand(0), RoundingOp.getOperand(1),
+                         DAG, ShiftAmount, IsSigned);
+  // Rounding+Truncation instruction doesn't support shift amount > input data
+  // width/2
+  int64_t InSize = UminOp.getValueType().getScalarSizeInBits();
+  if (ShiftAmount > InSize / 2)
+    return SDValue();
+  return AddOperand;
+}
+
+// A helper function to build sqrshrun or uqrshrn: emits an INTRINSIC_WO_CHAIN
+// node for intrinsic ID ShiftRoundingOpc that shifts AddOperand right by
+// ShiftAmount and yields a vector whose elements are half as wide as
+// AddOperand's. File-local (static): it only serves the matchers in this file.
+static SDValue generateQrshrnInstruction(unsigned ShiftRoundingOpc,
+                                         SelectionDAG &DAG, SDValue AddOperand,
+                                         int64_t ShiftAmount) {
+  SDLoc DL(AddOperand);
+  EVT InVT = AddOperand.getValueType();
+  EVT HalvedVT = InVT.changeVectorElementType(
+      InVT.getVectorElementType().getHalfSizedIntegerVT(*DAG.getContext()));
+
+  return DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL, HalvedVT,
+      DAG.getConstant(ShiftRoundingOpc, DL, MVT::i64), AddOperand,
+      DAG.getConstant(static_cast<uint64_t>(ShiftAmount), DL, MVT::i32));
+}
+
+// Attempt to match Trunc(Concat(Trunc(min(max(Shift(Add(OpA, 1<<(imm1-1)),
+// imm1), 0), maxValue)), Trunc(min(max(Shift(Add(OpB, 1<<(imm2-1)), imm2), 0),
+// maxValue))))
+// The pattern in tree is shown below
+//  OpA     1<<(imm1-1)                         OpB     1<<(imm2-1)
+//    \       /                                    \       /
+//     \     /                                      \     /
+//      Add      imm1                       imm2      Add
+//        \      /                              \      /
+//         \    /                                \    /
+//         Shift     0                    0      Shift
+//            \     /                      \      /
+//             \   /                        \    /
+//              max     maxVal      maxVal   max
+//                \     /             \      /
+//                 \   /               \    /
+//                  min                 min
+//                    \                 /
+//                     \               /
+//                    Trunc         Trunc
+//                        \         /
+//                         \       /
+//                          Concat
+//                             |
+//                             |
+//                           Trunc
+// The pattern will be matched as uqxtn(concat(qrshrn(OpA,imm1),
+// qrshrn(OpB,imm2))) where uqxtn is saturation truncation, qrshrn is sqrshrun
+// or uqrshrn.
+//
+// Only a two-operand CONCAT_VECTORS of half-width truncations is handled, so
+// that the rebuilt concat has exactly the type of the original one.
+static SDValue matchQrshrn2(SDNode *N, SelectionDAG &DAG) {
+  TypeSize OutSizeTyped = N->getValueSizeInBits(0);
+  if (OutSizeTyped.isScalable() || OutSizeTyped.getFixedValue() != 64)
+    return SDValue();
+
+  SDValue Operand0 = N->getOperand(0);
+  // With more than two concatenated parts the extra parts would be dropped.
+  if (Operand0.getOpcode() != ISD::CONCAT_VECTORS ||
+      Operand0.getNumOperands() != 2)
+    return SDValue();
+
+  SDValue TruncOp0 = Operand0.getOperand(0);
+  SDValue TruncOp1 = Operand0.getOperand(1);
+  if (TruncOp0.getOpcode() != ISD::TRUNCATE ||
+      TruncOp1.getOpcode() != ISD::TRUNCATE)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  // Computed in 64 bits: "(1 << Size) - 1" on an int is undefined for
+  // 32-bit elements.
+  uint64_t MaxScalarValue =
+      maskTrailingOnes<uint64_t>(VT.getScalarSizeInBits());
+
+  int64_t ShiftAmount0 = -1, ShiftAmount1 = -1;
+  SDValue AddOperand0 = matchTruncSatRounding(TruncOp0.getOperand(0), DAG,
+                                              MaxScalarValue, ShiftAmount0);
+  SDValue AddOperand1 = matchTruncSatRounding(TruncOp1.getOperand(0), DAG,
+                                              MaxScalarValue, ShiftAmount1);
+  if (!AddOperand0 || !AddOperand1)
+    return SDValue();
+
+  // qrshrn halves the element width, so it only reproduces TruncOpN's type
+  // when that truncation was itself a halving one.
+  auto IsHalving = [](SDValue TruncOp) {
+    return TruncOp.getScalarValueSizeInBits() * 2 ==
+           TruncOp.getOperand(0).getScalarValueSizeInBits();
+  };
+  if (!IsHalving(TruncOp0) || !IsHalving(TruncOp1))
+    return SDValue();
+
+  // matchTruncSatRounding has checked umin(smax(Shift, 0), max), so the shift
+  // node sits three operands below each TRUNCATE.
+  auto GetQrshrnIntrinsic = [](SDValue TruncOp) -> unsigned {
+    unsigned ShiftOpc =
+        TruncOp.getOperand(0).getOperand(0).getOperand(0).getOpcode();
+    return (ShiftOpc == AArch64ISD::VLSHR || ShiftOpc == ISD::SRL)
+               ? Intrinsic::aarch64_neon_uqrshrn
+               : Intrinsic::aarch64_neon_sqrshrun;
+  };
+
+  SDLoc DL(N);
+  SDValue ResultTrunc0 = generateQrshrnInstruction(
+      GetQrshrnIntrinsic(TruncOp0), DAG, AddOperand0, ShiftAmount0);
+  SDValue ResultTrunc1 = generateQrshrnInstruction(
+      GetQrshrnIntrinsic(TruncOp1), DAG, AddOperand1, ShiftAmount1);
+
+  EVT ConcatVT = ResultTrunc1.getValueType().getDoubleNumVectorElementsVT(
+      *DAG.getContext());
+  SDValue ConcatOp = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT,
+                                 ResultTrunc0, ResultTrunc1);
+  // Notice MaxScalarValue is the final truncated type's max value, so a
+  // second saturation is needed. uqxtn is saturation truncation.
+  return DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL, VT,
+      DAG.getConstant(Intrinsic::aarch64_neon_uqxtn, DL, MVT::i64), ConcatOp);
+}
+
+// To match SQRSHRUN/UQRSHRN from trunc(umin(smax((a + b + (1<<(imm-1))) >> imm,
+// 0), max)). A half-width truncation becomes a single sqrshrun/uqrshrn; a
+// quarter-width truncation additionally needs a uqxtn.
+static SDValue matchQrshrn(SDNode *N, SelectionDAG &DAG) {
+  TypeSize OutSizeTyped = N->getValueSizeInBits(0);
+  if (OutSizeTyped.isScalable())
+    return SDValue();
+  uint64_t OutSize = OutSizeTyped.getFixedValue();
+  if (OutSize < 32 || OutSize > 64)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  SDValue UminOp = N->getOperand(0);
+  uint64_t OutScalarSize = VT.getScalarSizeInBits();
+  uint64_t UminOpScalarSize = UminOp.getValueType().getScalarSizeInBits();
+
+  // sqrshrun/uqrshrn produce elements half as wide as their input.
+  //  * Half-width result: the intrinsic node is returned directly, so VT must
+  //    already be legal (an intrinsic node cannot be legalized later).
+  //  * Quarter-width result: handled by the concat/uqxtn/extract sequence
+  //    below, which needs at least byte-sized final elements.
+  //  * Any other ratio has no matching instruction.
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  bool IsHalf = OutScalarSize * 2 == UminOpScalarSize;
+  bool IsQuarter = OutScalarSize * 4 == UminOpScalarSize;
+  if (IsHalf && !TLI.isTypeLegal(VT))
+    return SDValue();
+  if (!IsHalf && (!IsQuarter || OutScalarSize < 8))
+    return SDValue();
+
+  // Computed in 64 bits: "(1 << Size) - 1" on an int is undefined for
+  // 32-bit elements.
+  int64_t ShiftAmount = -1;
+  SDValue AddOperand = matchTruncSatRounding(
+      UminOp, DAG, maskTrailingOnes<uint64_t>(OutScalarSize), ShiftAmount);
+  if (!AddOperand)
+    return SDValue();
+
+  // matchTruncSatRounding has checked UminOp = umin(smax(Shift, 0), max).
+  unsigned ShiftOpc = UminOp.getOperand(0).getOperand(0).getOpcode();
+  unsigned TruncOpc = (ShiftOpc == AArch64ISD::VLSHR || ShiftOpc == ISD::SRL)
+                          ? Intrinsic::aarch64_neon_uqrshrn
+                          : Intrinsic::aarch64_neon_sqrshrun;
+  SDValue ResultTrunc =
+      generateQrshrnInstruction(TruncOpc, DAG, AddOperand, ShiftAmount);
+  if (IsHalf)
+    return ResultTrunc;
+
+  // For pattern "trunc(trunc(umin(smax((a + b + (1<<(imm-1))) >>imm,0),max))",
+  // max is the outer truncated type's max value, so another saturating
+  // truncation (uqxtn) is needed. Because the final truncated type may be
+  // illegal and an intrinsic cannot be legalized, the uqxtn is applied to
+  // CONCAT_VECTORS(ResultTrunc, undef) and the low half is taken with
+  // EXTRACT_SUBVECTOR. Left alone this produces e.g.
+  //   sqrshrun  v1.4h, v0.4s, #6
+  //   sqrshrun2 v1.8h, v0.4s, #6
+  //   uqxtn     v0.8b, v1.8h
+  //   zip1      v0.16b, v0.16b, v0.16b
+  //   xtn       v0.8b, v0.8h
+  // instead of the ideal
+  //   sqrshrun  v1.4h, v0.4s, #6
+  //   uqxtn     v0.8b, v1.8h
+  // Function `hasUselessTrunc` is introduced to optimize code like above.
+  SDLoc DL(N);
+  EVT NarrowVT = ResultTrunc.getValueType();
+  EVT ConcatVT = NarrowVT.getDoubleNumVectorElementsVT(*DAG.getContext());
+  SDValue ConcatOp = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT,
+                                 ResultTrunc, DAG.getUNDEF(NarrowVT));
+  EVT HalvedVT = ConcatVT.changeVectorElementType(
+      ConcatVT.getVectorElementType().getHalfSizedIntegerVT(
+          *DAG.getContext()));
+  SDValue Saturated = DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL, HalvedVT,
+      DAG.getConstant(Intrinsic::aarch64_neon_uqxtn, DL, MVT::i64), ConcatOp);
+  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Saturated,
+                     DAG.getVectorIdxConstant(0, DL));
+}
+
+// Try to match pattern (truncate (BUILD_VECTOR(a[i], a[i+1], ..., x, x, ...)))
+// as seen through a bitcast to a vector with FinalEleNum elements.
+//
+// Element OldExtractIndex of that bitcast is formed by LanesPerElt =
+// NumElts(N) / FinalEleNum consecutive lanes of N, starting at lane
+// LanesPerElt * OldExtractIndex. If those lanes are truncations of
+// extract_vector_elt(V, k), extract_vector_elt(V, k+1), ... from a single
+// vector V of N's type, with k a multiple of LanesPerElt, the truncate is
+// redundant for that element. On success returns true, sets UsefullValue to
+// V and NewExtractIndex to k / LanesPerElt. Otherwise returns false and
+// leaves both output parameters untouched.
+static bool hasUselessTrunc(SDValue &N, SDValue &UsefullValue,
+                            unsigned FinalEleNum, uint64_t OldExtractIndex,
+                            int &NewExtractIndex) {
+  if (N.getOpcode() != ISD::TRUNCATE)
+    return false;
+
+  SDValue BuildVec = N.getOperand(0);
+  if (BuildVec.getOpcode() != ISD::BUILD_VECTOR)
+    return false;
+
+  // For example, if v8i8 is bitcasted to v2i32, LanesPerElt is 4.
+  unsigned NumElts = N.getValueType().getVectorNumElements();
+  if (FinalEleNum == 0 || NumElts / FinalEleNum == 0)
+    return false;
+  unsigned LanesPerElt = NumElts / FinalEleNum;
+
+  // Only the lanes that make up element OldExtractIndex of the bitcast
+  // result matter.
+  uint64_t FirstLane = LanesPerElt * OldExtractIndex;
+  if (FirstLane + LanesPerElt > NumElts)
+    return false;
+
+  SDValue Source;
+  uint64_t FirstIdx = 0;
+  for (unsigned Lane = 0; Lane != LanesPerElt; ++Lane) {
+    SDValue Elt = BuildVec.getOperand(static_cast<unsigned>(FirstLane + Lane));
+    if (Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+      return false;
+    auto *IdxC = dyn_cast<ConstantSDNode>(Elt.getOperand(1));
+    if (!IdxC)
+      return false;
+    uint64_t Idx = IdxC->getZExtValue();
+    if (Lane == 0) {
+      Source = Elt.getOperand(0);
+      FirstIdx = Idx;
+    } else if (Elt.getOperand(0) != Source || Idx != FirstIdx + Lane) {
+      // Every lane must come from the same vector, in consecutive order.
+      return false;
+    }
+  }
+
+  // The extracted lanes must line up with one whole element of
+  // bitcast(Source).
+  if (Source.getValueType() != N.getValueType() || FirstIdx % LanesPerElt != 0)
+    return false;
+
+  UsefullValue = Source;
+  NewExtractIndex = static_cast<int>(FirstIdx / LanesPerElt);
+  return true;
+}
+
 static SDValue performFADDCombine(SDNode *N,
                                   TargetLowering::DAGCombinerInfo &DCI) {
   SelectionDAG &DAG = DCI.DAG;
@@ -17883,6 +18383,40 @@ performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
     }
   }
 
+  // Rewrite for pattern
+  //   (extract_vector_elt
+  //           (bitcast (truncate (BUILD_VECTOR(a[i],a[i+1],..., x,x,..))), 0)
+  // as
+  //   (extract_vector_elt
+  //           (bitcast (a)), k)
+  //
+  // For example, pattern
+  //
+  // (i32 extract_vector_elt
+  //           (v2i32 bitcast (v8i8 truncate (v8i16
+  //           BUILD_VECTOR(a[4],a[5],a[6],a[7], undef,undef,undef,undef))), 0)
+  //  where the type of a is v8i8
+  //
+  // can be transformed to
+  //   (i32 extract_vector_elt
+  //           (v2i32 bitcast (a)), 1)
+
+  EVT N0VT = N0.getValueType();
+  if (EnableAArch64ExtractVecElementCombine && !N0VT.isScalableVector() &&
+      isNullConstant(N1) && N0.getOpcode() == ISD::BITCAST && N0VT.isVector()) {
+    SDLoc DL(N0);
+    SDValue N00 = N0->getOperand(0);
+
+    SDValue Other;
+    int NewIndex = 0;
+
+    if (hasUselessTrunc(N00, Other, N0VT.getVectorNumElements(), 0, NewIndex))
+      return DAG.getNode(
+          ISD::EXTRACT_VECTOR_ELT, DL, VT,
+          DAG.getNode(ISD::BITCAST, DL, N0VT, Other),
+          DAG.getConstant(static_cast<uint64_t>(NewIndex), DL, MVT::i64));
+  }
+
   return SDValue();
 }
 
@@ -18047,6 +18581,13 @@ static SDValue performConcatVectorsCombine(SDNode *N,
   }
 
   auto IsRSHRN = [](SDValue Shr) {
+    if (Shr.getOpcode() == AArch64ISD::URSHR_I) {
+      EVT VT = Shr.getValueType();
+      unsigned ShtAmt = Shr->getConstantOperandVal(1);
+      if (ShtAmt >= VT.getScalarSizeInBits() / 2)
+        return false;
+      return true;
+    }
     if (Shr.getOpcode() != AArch64ISD::VLSHR)
       return false;
     SDValue Op = Shr.getOperand(0);
@@ -18077,6 +18618,17 @@ static SDValue performConcatVectorsCombine(SDNode *N,
       ((IsRSHRN(N1) &&
         N0.getConstantOperandVal(1) == N1.getConstantOperandVal(1)) ||
        N1.isUndef())) {
+    if (N0.getOpcode() == AArch64ISD::URSHR_I) {
+      SDValue X = N0.getOperand(0);
+      SDValue Y =
+          N1.isUndef() ? DAG.getUNDEF(X.getValueType()) : N1.getOperand(0);
+      EVT BVT =
+          X.getValueType().getDoubleNumVectorElementsVT(*DCI.DAG.getContext());
+      SDValue CC = DAG.getNode(ISD::CONCAT_VECTORS, dl, BVT, X, Y);
+      SDValue Shr =
+          DAG.getNode(AArch64ISD::URSHR_I, dl, BVT, CC, N0.getOperand(1));
+      return Shr;
+    }
     SDValue X = N0.getOperand(0).getOperand(0);
     SDValue Y = N1.isUndef() ? DAG.getUNDEF(X.getValueType())
                              : N1.getOperand(0).getOperand(0);
@@ -18832,9 +19384,28 @@ static SDValue performBuildVectorCombine(SDNode *N,
   return SDValue();
 }
 
-static SDValue performTruncateCombine(SDNode *N,
-                                      SelectionDAG &DAG) {
+static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG) {
   EVT VT = N->getValueType(0);
+  bool HasNeon =
+      static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasNEON();
+
+  if (EnableAArch64RoundingSatNarrowOpt && VT.isVector() && HasNeon) {
+    // Attempt to match Trunc(Concat(Trunc(min(max(Shift(Add(OpA, 1<<(imm1-1),
+    // imm1)),0),maxValue)),Trunc(min(max(Shift(Add(OpB, 1<<(imm2-1),
+    // imm2)),0),maxValue)))) and can be simplified as
+    // uqxtn(concat(qrshrn(OpA,imm1), qrshrn(OpB,imm2))
+    // where qrshrn can be sqrshrun or uqrshrn
+    SDValue ResultTrunc = matchQrshrn2(N, DAG);
+    if (ResultTrunc)
+      return ResultTrunc;
+
+    // To match QRSHRN from trunc(umin(smax((a + b + (1<<(imm-1))) >>imm,
+    // max), 0))
+    ResultTrunc = matchQrshrn(N, DAG);
+    if (ResultTrunc)
+      return ResultTrunc;
+  }
+
   SDValue N0 = N->getOperand(0);
   if (VT.isFixedLengthVector() && VT.is64BitVector() && N0.hasOneUse() &&
       N0.getOpcode() == AArch64ISD::DUP) {
@@ -20854,6 +21425,24 @@ static SDValue performVectorShiftCombine(SDNode *N,
   if (TLI.SimplifyDemandedBits(Op, DemandedMask, DCI))
     return SDValue(N, 0);
 
+  SelectionDAG &DAG = DCI.DAG;
+  EVT VT = N->getValueType(0);
+
+  SDValue ResultRounding;
+
+  // Attempt to form URSHR(OpA, imm) from "SRL(Add(OpA, 1<<(imm-1), imm)"
+  //          or     SRSHR(OpA, imm) from "SRA(Add(OpA, 1<<(imm-1), imm)"
+  if (EnableAArch64RoundingOpt && VT.isVector() &&
+      static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasNEON()) {
+    unsigned ShiftRoundingOpc = (N->getOpcode() == AArch64ISD::VASHR)
+                                    ? AArch64ISD::SRSHR_I
+                                    : AArch64ISD::URSHR_I;
+    ResultRounding = matchXRSHR(N, DAG, ShiftRoundingOpc);
+  }
+
+  if (ResultRounding)
+    return ResultRounding;
+
   return SDValue();
 }
 
@@ -23662,6 +24251,10 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performORCombine(N, DCI, Subtarget, *this);
   case ISD::AND:
     return performANDCombine(N, DCI);
+  case ISD::SRL:
+    return performSRLCombine(N, DCI);
+  case ISD::SRA:
+    return performSRACombine(N, DCI);
   case ISD::FADD:
     return performFADDCombine(N, DCI);
   case ISD::INTRINSIC_WO_CHAIN:
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 44b0337fe7879..1180b7bb9a3e9 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -817,7 +817,7 @@ def AArch64saddlp   : PatFrags<(ops node:$src),
 def AArch64faddp     : PatFrags<(ops node:$Rn, node:$Rm),
                                 [(AArch64addp_n node:$Rn, node:$Rm),
                                  (int_aarch64_neon_faddp node:$Rn, node:$Rm)]>;
-def AArch64roundingvlshr : ComplexPattern<vAny, 2, "SelectRoundingVLShr", [AArch64vlshr]>;
+// Root nodes that SelectRoundingVLShr may turn into the (source, shift amount)
+// operands used by AArch64rshrn: AArch64vlshr (further checked by the selector)
+// and AArch64urshri whose shift amount is less than half the element width.
+def AArch64roundingvlshr : ComplexPattern<vAny, 2, "SelectRoundingVLShr", [AArch64vlshr, AArch64urshri]>;
 def AArch64rshrn : PatFrags<(ops node:$LHS, node:$RHS),
                             [(trunc (AArch64roundingvlshr node:$LHS, node:$RHS)),
                              (int_aarch64_neon_rshrn node:$LHS, node:$RHS)]>;
diff --git a/llvm/test/CodeGen/AArch64/isel_rounding_opt.ll b/llvm/test/CodeGen/AArch64/isel_rounding_opt.ll
new file mode 100644
index 0000000000000..d745ad215209c
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/isel_rounding_opt.ll
@@ -0,0 +1,327 @@
+; RUN: llc -o - %s -mtriple=aarch64-none-linux-gnu | FileCheck %s
+; RUN: llc -o - %s -mtriple=aarch64-none-linux-gnu -aarch64-optimize-rounding=false -aarch64-optimize-rounding-saturation=false -aarch64-extract-vector-element-trunc-combine=false | FileCheck %s --check-prefix=NOROUNDING
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define void @test_srshr(i8* nocapture %dst, i8* nocapture readonly %pix1) {
+; CHECK-LABEL: test_srshr:
+; CHECK:    ldr
+; NOROUNDING: ldr
+; CHECK:    srshr
+; NOROUNDING-NOT: srshr
+; CHECK-NOT:   sshr
+; CHECK:    str
+; NOROUNDING: ret
+; CHECK:    ret
+entry:                     ; (x + (1 << 5)) ashr 6 models a rounding arithmetic shift right by 6
+  %0 = bitcast i8* %pix1 to <4 x i32>*
+  %1 = load <4 x i32>, <4 x i32>* %0, align 1
+  %2 = add nuw nsw <4 x i32> %1, <i32 32, i32 32, i32 32, i32 32>
+  %3 = ashr <4 x i32> %2, <i32 6, i32 6, i32 6, i32 6>
+  %4 = bitcast i8* %dst to <4 x i32>*
+  store <4 x i32> %3, <4 x i32>* %4, align 1
+  ret void
+}
+
+define void @test_srshr2(i8* nocapture %dst, i8* nocapture readonly %pix1) {
+; CHECK-LABEL: test_srshr2:
+; CHECK:    ldr
+; NOROUNDING: ldr
+; CHECK:    sqrshrun
+; CHECK:    uqxtn
+; NOROUNDING-NOT: sqrshrun
+; NOROUNDING-NOT: uqxtn
+; CHECK-NOT:   sshr
+; CHECK:    str
+; NOROUNDING: ret
+; CHECK:    ret
+entry:                     ; trunc(clamp((x + 32) ashr 6, 0, 255)) to i8
+  %src.ptr = bitcast i8* %pix1 to <4 x i32>*
+  %src = load <4 x i32>, <4 x i32>* %src.ptr, align 1
+  %biased = add nuw nsw <4 x i32> %src, <i32 32, i32 32, i32 32, i32 32>
+  %shifted = ashr <4 x i32> %biased, <i32 6, i32 6, i32 6, i32 6>
+  %gt.zero = icmp sgt <4 x i32> %shifted, zeroinitializer
+  %lo.clamped = select <4 x i1> %gt.zero, <4 x i32> %shifted, <4 x i32> zeroinitializer
+  %lt.max = icmp ult <4 x i32> %lo.clamped, <i32 255, i32 255, i32 255, i32 255>
+  %clamped = select <4 x i1> %lt.max, <4 x i32> %lo.clamped, <4 x i32> <i32 255, i32 255, i32 255, i32 255>
+  %narrow = trunc <4 x i32> %clamped to <4 x i8>
+  %dst.ptr = bitcast i8* %dst to <4 x i8>*
+  store <4 x i8> %narrow, <4 x i8>* %dst.ptr, align 1
+  ret void
+}
+
+define void @test_srshr3(i8* nocapture %dst, i8* nocapture readonly %pix1, i8* nocapture readonly %pix2) {
+; CHECK-LABEL: test_srshr3:
+; CHECK:    ldr
+; NOROUNDING: ldr
+; CHECK:    add
+; CHECK:    srshr
+; NOROUNDING-NOT: srshr
+; CHECK-NOT:   sshr
+; CHECK:    str
+; NOROUNDING: ret
+; CHECK:    ret
+entry:                     ; ((a + 32) + b) ashr 6: the bias is folded in before a second add
+  %lhs.ptr = bitcast i8* %pix1 to <4 x i32>*
+  %lhs = load <4 x i32>, <4 x i32>* %lhs.ptr, align 1
+  %rhs.ptr = bitcast i8* %pix2 to <4 x i32>*
+  %rhs = load <4 x i32>, <4 x i32>* %rhs.ptr, align 1
+  %lhs.biased = add nuw nsw <4 x i32> %lhs, <i32 32, i32 32, i32 32, i32 32>
+  %sum = add nuw nsw <4 x i32> %lhs.biased, %rhs
+  %shifted = ashr <4 x i32> %sum, <i32 6, i32 6, i32 6, i32 6>
+  %dst.ptr = bitcast i8* %dst to <4 x i32>*
+  store <4 x i32> %shifted, <4 x i32>* %dst.ptr, align 1
+  ret void
+}
+
+
+define void @test_sqrshrun(i8* nocapture %dst, i8* nocapture readonly %pix1) {
+; CHECK-LABEL: test_sqrshrun:
+; CHECK:    ldr
+; NOROUNDING: ldr
+; CHECK:    sqrshrun
+; NOROUNDING-NOT: sqrshrun
+; CHECK-NOT:   sshr
+; CHECK-NOT:   smax
+; CHECK:    str
+; NOROUNDING: ret
+; CHECK:    ret
+entry:                     ; trunc(clamp((x + 32) ashr 6, 0, 65535)) to i16
+  %src.ptr = bitcast i8* %pix1 to <4 x i32>*
+  %src = load <4 x i32>, <4 x i32>* %src.ptr, align 1
+  %biased = add nuw nsw <4 x i32> %src, <i32 32, i32 32, i32 32, i32 32>
+  %shifted = ashr <4 x i32> %biased, <i32 6, i32 6, i32 6, i32 6>
+  %gt.zero = icmp sgt <4 x i32> %shifted, zeroinitializer
+  %lo.clamped = select <4 x i1> %gt.zero, <4 x i32> %shifted, <4 x i32> zeroinitializer
+  %lt.max = icmp ult <4 x i32> %lo.clamped, <i32 65535, i32 65535, i32 65535, i32 65535>
+  %clamped = select <4 x i1> %lt.max, <4 x i32> %lo.clamped, <4 x i32> <i32 65535, i32 65535, i32 65535, i32 65535>
+  %narrow = trunc <4 x i32> %clamped to <4 x i16>
+  %dst.ptr = bitcast i8* %dst to <4 x i16>*
+  store <4 x i16> %narrow, <4 x i16>* %dst.ptr, align 1
+  ret void
+}
+
+
+
+define void @test_urshr(i8* nocapture %dst, i8* nocapture readonly %pix1) {
+; CHECK-LABEL: test_urshr:
+; CHECK:    ldr
+; NOROUNDING: ldr
+; CHECK:    urshr
+; NOROUNDING-NOT: urshr
+; CHECK-NOT:   ushr
+; CHECK:    str
+; NOROUNDING: ret
+; CHECK:    ret
+entry:                     ; (x + (1 << 5)) lshr 6 models a rounding logical shift right by 6
+  %src.ptr = bitcast i8* %pix1 to <4 x i32>*
+  %src = load <4 x i32>, <4 x i32>* %src.ptr, align 1
+  %biased = add nuw nsw <4 x i32> %src, <i32 32, i32 32, i32 32, i32 32>
+  %shifted = lshr <4 x i32> %biased, <i32 6, i32 6, i32 6, i32 6>
+  %dst.ptr = bitcast i8* %dst to <4 x i32>*
+  store <4 x i32> %shifted, <4 x i32>* %dst.ptr, align 1
+  ret void
+}
+
+
+define void @test_sqrshrun2(i8* nocapture %dst, i8* nocapture readonly %pix1) {
+; CHECK-LABEL: test_sqrshrun2:
+; CHECK:    ldp
+; NOROUNDING: ldp
+; CHECK:    sqrshrun
+; CHECK:    sqrshrun2
+; NOROUNDING-NOT: sqrshrun
+; NOROUNDING-NOT: sqrshrun2
+; CHECK-NOT:   sshr
+; CHECK:    str
+; NOROUNDING: ret
+; CHECK:    ret
+entry:                     ; 8-lane trunc(clamp((x + 32) ashr 6, 0, 255)); the ashr fallback would be sshr
+  %0 = bitcast i8* %pix1 to <8 x i32>*
+  %1 = load <8 x i32>, <8 x i32>* %0, align 1
+  %2 = add nuw nsw <8 x i32> %1, <i32 32, i32 32, i32 32, i32 32, i32 32, i32 32, i32 32, i32 32>
+  %3 = ashr <8 x i32> %2, <i32 6, i32 6, i32 6, i32 6, i32 6, i32 6, i32 6, i32 6>
+  %4 = icmp sgt <8 x i32> %3, zeroinitializer
+  %5 = select <8 x i1> %4, <8 x i32> %3, <8 x i32> zeroinitializer
+  %6 = icmp ult <8 x i32> %5, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %7 = select <8 x i1> %6, <8 x i32> %5, <8 x i32> <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %8 = trunc <8 x i32> %7 to <8 x i8>
+  %9 = bitcast i8* %dst to <8 x i8>*
+  store <8 x i8> %8, <8 x i8>* %9, align 1
+  ret void
+}
+
+define void @test_uqrshrn(i8* nocapture %dst, i8* nocapture readonly %pix1) {
+; CHECK-LABEL: test_uqrshrn:
+; CHECK:    ldr
+; NOROUNDING: ldr
+; CHECK:    uqrshrn
+; NOROUNDING-NOT: uqrshrn
+; CHECK-NOT:   ushr
+; CHECK:    str
+; NOROUNDING: ret
+; CHECK:    ret
+entry:                     ; trunc(clamp((x + 32) lshr 6, 0, 65535)) to i16
+  %src.ptr = bitcast i8* %pix1 to <4 x i32>*
+  %src = load <4 x i32>, <4 x i32>* %src.ptr, align 1
+  %biased = add nuw nsw <4 x i32> %src, <i32 32, i32 32, i32 32, i32 32>
+  %shifted = lshr <4 x i32> %biased, <i32 6, i32 6, i32 6, i32 6>
+  %gt.zero = icmp sgt <4 x i32> %shifted, zeroinitializer
+  %lo.clamped = select <4 x i1> %gt.zero, <4 x i32> %shifted, <4 x i32> zeroinitializer
+  %lt.max = icmp ult <4 x i32> %lo.clamped, <i32 65535, i32 65535, i32 65535, i32 65535>
+  %clamped = select <4 x i1> %lt.max, <4 x i32> %lo.clamped, <4 x i32> <i32 65535, i32 65535, i32 65535, i32 65535>
+  %narrow = trunc <4 x i32> %clamped to <4 x i16>
+  %dst.ptr = bitcast i8* %dst to <4 x i16>*
+  store <4 x i16> %narrow, <4 x i16>* %dst.ptr, align 1
+  ret void
+}
+
+define void @test_srshr_long_shift(i8* nocapture %dst, i8* nocapture readonly %pix1) {
+; CHECK-LABEL: test_srshr_long_shift:
+; CHECK:    ldr
+; NOROUNDING: ldr
+; CHECK:    srshr
+; NOROUNDING-NOT: srshr
+; CHECK-NOT:   sshr
+; CHECK:    str
+; NOROUNDING: ret
+; CHECK:    ret
+entry:                     ; (x + (1 << 21)) ashr 22: a shift wider than half the i32 lane
+  %0 = bitcast i8* %pix1 to <4 x i32>*
+  %1 = load <4 x i32>, <4 x i32>* %0, align 1
+  %2 = add nuw nsw <4 x i32> %1, <i32 2097152, i32 2097152, i32 2097152, i32 2097152>
+  %3 = ashr <4 x i32> %2, <i32 22, i32 22, i32 22, i32 22>
+  %4 = bitcast i8* %dst to <4 x i32>*
+  store <4 x i32> %3, <4 x i32>* %4, align 1
+  ret void
+}
+
+define void @test_urshr_long_shift(i8* nocapture %dst, i8* nocapture readonly %pix1) {
+; CHECK-LABEL: test_urshr_long_shift:
+; CHECK:    ldr
+; NOROUNDING: ldr
+; CHECK:    urshr
+; NOROUNDING-NOT: urshr
+; CHECK-NOT:   ushr
+; CHECK:    str
+; NOROUNDING: ret
+; CHECK:    ret
+entry:                     ; (x + (1 << 21)) lshr 22: a shift wider than half the i32 lane
+  %src.ptr = bitcast i8* %pix1 to <4 x i32>*
+  %src = load <4 x i32>, <4 x i32>* %src.ptr, align 1
+  %biased = add nuw nsw <4 x i32> %src, <i32 2097152, i32 2097152, i32 2097152, i32 2097152>
+  %shifted = lshr <4 x i32> %biased, <i32 22, i32 22, i32 22, i32 22>
+  %dst.ptr = bitcast i8* %dst to <4 x i32>*
+  store <4 x i32> %shifted, <4 x i32>* %dst.ptr, align 1
+  ret void
+}
+
+; Negative test: the rounding+truncation instructions don't support a shift amount > input element width / 2
+define void @test_srshr2_long_shift(i8* nocapture %dst, i8* nocapture readonly %pix1) {
+; CHECK-LABEL: test_srshr2_long_shift:
+; CHECK:    ldr
+; NOROUNDING: ldr
+; CHECK:   srshr
+; CHECK-NOT:    sqrshrun
+; CHECK-NOT:    uqxtn
+; NOROUNDING-NOT: sqrshrun
+; NOROUNDING-NOT: uqxtn
+; CHECK:    str
+; NOROUNDING: ret
+; CHECK:    ret
+entry:                     ; trunc(clamp((x + (1 << 21)) ashr 22, 0, 255)) to i8
+  %src.ptr = bitcast i8* %pix1 to <4 x i32>*
+  %src = load <4 x i32>, <4 x i32>* %src.ptr, align 1
+  %biased = add nuw nsw <4 x i32> %src, <i32 2097152, i32 2097152, i32 2097152, i32 2097152>
+  %shifted = ashr <4 x i32> %biased, <i32 22, i32 22, i32 22, i32 22>
+  %gt.zero = icmp sgt <4 x i32> %shifted, zeroinitializer
+  %lo.clamped = select <4 x i1> %gt.zero, <4 x i32> %shifted, <4 x i32> zeroinitializer
+  %lt.max = icmp ult <4 x i32> %lo.clamped, <i32 255, i32 255, i32 255, i32 255>
+  %clamped = select <4 x i1> %lt.max, <4 x i32> %lo.clamped, <4 x i32> <i32 255, i32 255, i32 255, i32 255>
+  %narrow = trunc <4 x i32> %clamped to <4 x i8>
+  %dst.ptr = bitcast i8* %dst to <4 x i8>*
+  store <4 x i8> %narrow, <4 x i8>* %dst.ptr, align 1
+  ret void
+}
+
+
+; Negative test: the rounding+truncation instructions don't support a shift amount > input element width / 2
+define void @test_sqrshrun_long_shift(i8* nocapture %dst, i8* nocapture readonly %pix1) {
+; CHECK-LABEL: test_sqrshrun_long_shift:
+; CHECK:    ldr
+; NOROUNDING: ldr
+; CHECK:   srshr
+; CHECK-NOT:    sqrshrun
+; NOROUNDING-NOT: sqrshrun
+; CHECK:   smax
+; CHECK:    str
+; NOROUNDING: ret
+; CHECK:    ret
+entry:                     ; trunc(clamp((x + (1 << 21)) ashr 22, 0, 65535)) to i16
+  %src.ptr = bitcast i8* %pix1 to <4 x i32>*
+  %src = load <4 x i32>, <4 x i32>* %src.ptr, align 1
+  %biased = add nuw nsw <4 x i32> %src, <i32 2097152, i32 2097152, i32 2097152, i32 2097152>
+  %shifted = ashr <4 x i32> %biased, <i32 22, i32 22, i32 22, i32 22>
+  %gt.zero = icmp sgt <4 x i32> %shifted, zeroinitializer
+  %lo.clamped = select <4 x i1> %gt.zero, <4 x i32> %shifted, <4 x i32> zeroinitializer
+  %lt.max = icmp ult <4 x i32> %lo.clamped, <i32 65535, i32 65535, i32 65535, i32 65535>
+  %clamped = select <4 x i1> %lt.max, <4 x i32> %lo.clamped, <4 x i32> <i32 65535, i32 65535, i32 65535, i32 65535>
+  %narrow = trunc <4 x i32> %clamped to <4 x i16>
+  %dst.ptr = bitcast i8* %dst to <4 x i16>*
+  store <4 x i16> %narrow, <4 x i16>* %dst.ptr, align 1
+  ret void
+}
+
+; Negative test: the rounding+truncation instructions don't support a shift amount > input element width / 2
+define void @test_sqrshrun2_long_shift(i8* nocapture %dst, i8* nocapture readonly %pix1) {
+; CHECK-LABEL: test_sqrshrun2_long_shift:
+; CHECK:    ldp
+; CHECK:   srshr
+; NOROUNDING: ldp
+; CHECK-NOT:    sqrshrun
+; CHECK-NOT:    sqrshrun2
+; NOROUNDING-NOT: sqrshrun
+; NOROUNDING-NOT: sqrshrun2
+; CHECK:    str
+; NOROUNDING: ret
+; CHECK:    ret
+entry:                     ; 8-lane trunc(clamp((x + (1 << 21)) ashr 22, 0, 255)) to i8
+  %src.ptr = bitcast i8* %pix1 to <8 x i32>*
+  %src = load <8 x i32>, <8 x i32>* %src.ptr, align 1
+  %biased = add nuw nsw <8 x i32> %src, <i32 2097152, i32 2097152, i32 2097152, i32 2097152, i32 2097152, i32 2097152, i32 2097152, i32 2097152>
+  %shifted = ashr <8 x i32> %biased, <i32 22, i32 22, i32 22, i32 22, i32 22, i32 22, i32 22, i32 22>
+  %gt.zero = icmp sgt <8 x i32> %shifted, zeroinitializer
+  %lo.clamped = select <8 x i1> %gt.zero, <8 x i32> %shifted, <8 x i32> zeroinitializer
+  %lt.max = icmp ult <8 x i32> %lo.clamped, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %clamped = select <8 x i1> %lt.max, <8 x i32> %lo.clamped, <8 x i32> <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %narrow = trunc <8 x i32> %clamped to <8 x i8>
+  %dst.ptr = bitcast i8* %dst to <8 x i8>*
+  store <8 x i8> %narrow, <8 x i8>* %dst.ptr, align 1
+  ret void
+}
+
+; Negative test: the rounding+truncation instructions don't support a shift amount > input element width / 2
+define void @test_uqrshrn_long_shift(i8* nocapture %dst, i8* nocapture readonly %pix1) {
+; CHECK-LABEL: test_uqrshrn_long_shift:
+; CHECK:    ldr
+; NOROUNDING: ldr
+; CHECK:   urshr
+; CHECK-NOT:    uqrshrn
+; NOROUNDING-NOT: uqrshrn
+; CHECK:    str
+; NOROUNDING: ret
+; CHECK:    ret
+entry:                     ; trunc(clamp((x + (1 << 21)) lshr 22, 0, 65535)) to i16
+  %src.ptr = bitcast i8* %pix1 to <4 x i32>*
+  %src = load <4 x i32>, <4 x i32>* %src.ptr, align 1
+  %biased = add nuw nsw <4 x i32> %src, <i32 2097152, i32 2097152, i32 2097152, i32 2097152>
+  %shifted = lshr <4 x i32> %biased, <i32 22, i32 22, i32 22, i32 22>
+  %gt.zero = icmp sgt <4 x i32> %shifted, zeroinitializer
+  %lo.clamped = select <4 x i1> %gt.zero, <4 x i32> %shifted, <4 x i32> zeroinitializer
+  %lt.max = icmp ult <4 x i32> %lo.clamped, <i32 65535, i32 65535, i32 65535, i32 65535>
+  %clamped = select <4 x i1> %lt.max, <4 x i32> %lo.clamped, <4 x i32> <i32 65535, i32 65535, i32 65535, i32 65535>
+  %narrow = trunc <4 x i32> %clamped to <4 x i16>
+  %dst.ptr = bitcast i8* %dst to <4 x i16>*
+  store <4 x i16> %narrow, <4 x i16>* %dst.ptr, align 1
+  ret void
+}
diff --git a/llvm/test/CodeGen/AArch64/neon-rshrn.ll b/llvm/test/CodeGen/AArch64/neon-rshrn.ll
index 8d47f4afb355f..c6f71cbf63dca 100644
--- a/llvm/test/CodeGen/AArch64/neon-rshrn.ll
+++ b/llvm/test/CodeGen/AArch64/neon-rshrn.ll
@@ -95,9 +95,9 @@ entry:
 define <16 x i8> @rshrn_v16i16_8(<16 x i16> %a) {
 ; CHECK-LABEL: rshrn_v16i16_8:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    movi v2.2d, #0000000000000000
-; CHECK-NEXT:    raddhn v0.8b, v0.8h, v2.8h
-; CHECK-NEXT:    raddhn2 v0.16b, v1.8h, v2.8h
+; CHECK-NEXT:    urshr v1.8h, v1.8h, #8
+; CHECK-NEXT:    urshr v0.8h, v0.8h, #8
+; CHECK-NEXT:    uzp1 v0.16b, v0.16b, v1.16b
 ; CHECK-NEXT:    ret
 entry:
   %b = add <16 x i16> %a, <i16 128, i16 128, i16 128, i16 128, i16 128, i16 128, i16 128, i16 128, i16 128, i16 128, i16 128, i16 128, i16 128, i16 128, i16 128, i16 128>
@@ -109,11 +109,8 @@ entry:
 define <16 x i8> @rshrn_v16i16_9(<16 x i16> %a) {
 ; CHECK-LABEL: rshrn_v16i16_9:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    movi v2.8h, #1, lsl #8
-; CHECK-NEXT:    add v0.8h, v0.8h, v2.8h
-; CHECK-NEXT:    add v1.8h, v1.8h, v2.8h
-; CHECK-NEXT:    ushr v1.8h, v1.8h, #9
-; CHECK-NEXT:    ushr v0.8h, v0.8h, #9
+; CHECK-NEXT:    urshr v1.8h, v1.8h, #9
+; CHECK-NEXT:    urshr v0.8h, v0.8h, #9
 ; CHECK-NEXT:    uzp1 v0.16b, v0.16b, v1.16b
 ; CHECK-NEXT:    ret
 entry:
@@ -321,9 +318,9 @@ entry:
 define <8 x i16> @rshrn_v8i32_16(<8 x i32> %a) {
 ; CHECK-LABEL: rshrn_v8i32_16:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    movi v2.2d, #0000000000000000
-; CHECK-NEXT:    raddhn v0.4h, v0.4s, v2.4s
-; CHECK-NEXT:    raddhn2 v0.8h, v1.4s, v2.4s
+; CHECK-NEXT:    urshr v1.4s, v1.4s, #16
+; CHECK-NEXT:    urshr v0.4s, v0.4s, #16
+; CHECK-NEXT:    uzp1 v0.8h, v0.8h, v1.8h
 ; CHECK-NEXT:    ret
 entry:
   %b = add <8 x i32> %a, <i32 32768, i32 32768, i32 32768, i32 32768, i32 32768, i32 32768, i32 32768, i32 32768>
@@ -335,11 +332,8 @@ entry:
 define <8 x i16> @rshrn_v8i32_17(<8 x i32> %a) {
 ; CHECK-LABEL: rshrn_v8i32_17:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    movi v2.4s, #1, lsl #16
-; CHECK-NEXT:    add v0.4s, v0.4s, v2.4s
-; CHECK-NEXT:    add v1.4s, v1.4s, v2.4s
-; CHECK-NEXT:    ushr v1.4s, v1.4s, #17
-; CHECK-NEXT:    ushr v0.4s, v0.4s, #17
+; CHECK-NEXT:    urshr v1.4s, v1.4s, #17
+; CHECK-NEXT:    urshr v0.4s, v0.4s, #17
 ; CHECK-NEXT:    uzp1 v0.8h, v0.8h, v1.8h
 ; CHECK-NEXT:    ret
 entry:
@@ -868,9 +862,7 @@ entry:
 define void @rshrn_v2i32_4(<2 x i32> %a, ptr %p) {
 ; CHECK-LABEL: rshrn_v2i32_4:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    movi v1.2s, #8
-; CHECK-NEXT:    add v0.2s, v0.2s, v1.2s
-; CHECK-NEXT:    ushr v0.2s, v0.2s, #4
+; CHECK-NEXT:    urshr v0.2s, v0.2s, #4
 ; CHECK-NEXT:    mov w8, v0.s[1]
 ; CHECK-NEXT:    fmov w9, s0
 ; CHECK-NEXT:    strh w9, [x0]



More information about the llvm-commits mailing list