[llvm] Subject: [PATCH] [AArch64ISelLowering] Optimize rounding shift and saturation truncation (PR #74325)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 4 06:41:39 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: None (JohnLee1243)

<details>
<summary>Changes</summary>

This patch makes three changes:
1. A rounding shift such as SHIFT(Add(OpA, 1<<(imm-1)), imm) is simplified to urshr(OpA, imm) or srshr(OpA, imm), depending on whether the shift is logical or arithmetic.
2. A rounding shift followed by a saturating truncation, such as Trunc(min(max(Shift(Add(OpA, 1<<(imm-1)), imm), 0), maxValue)), is simplified to uqrshrn(OpA, imm) or sqrshrun(OpA, imm).
3. A selection pattern for RSHRN is added.
These optimizations are performed in the backend, after legalization.

---

Patch is 41.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74325.diff


5 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp (+11) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+595-1) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+1-1) 
- (added) llvm/test/CodeGen/AArch64/isel_rounding_opt.ll (+327) 
- (modified) llvm/test/CodeGen/AArch64/neon-rshrn.ll (+11-19) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 2f49e9a6b37cc..b559a470ed3f2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -177,6 +177,17 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
   }
 
   bool SelectRoundingVLShr(SDValue N, SDValue &Res1, SDValue &Res2) {
+    if (N.getOpcode() == AArch64ISD::URSHR_I) {
+      EVT VT = N.getValueType();
+      unsigned ShtAmt = N->getConstantOperandVal(1);
+      if (ShtAmt >= VT.getScalarSizeInBits() / 2)
+        return false;
+
+      Res1 = N.getOperand(0);
+      Res2 = CurDAG->getTargetConstant(ShtAmt, SDLoc(N), MVT::i32);
+      return true;
+    }
+
     if (N.getOpcode() != AArch64ISD::VLSHR)
       return false;
     SDValue Op = N->getOperand(0);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b6a16217dfae3..17da58f892baa 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -111,6 +111,28 @@ STATISTIC(NumTailCalls, "Number of tail calls");
 STATISTIC(NumShiftInserts, "Number of vector shift inserts");
 STATISTIC(NumOptimizedImms, "Number of times immediates were optimized");
 
+// Gates the vector SRL/SRA combines that turn
+// "shift(add(x, 1 << (imm - 1)), imm)" into a rounding-shift node.
+static cl::opt<bool>
+    EnableAArch64RoundingOpt("aarch64-optimize-rounding", cl::init(true),
+                             cl::Hidden,
+                             cl::desc("Enable AArch64 rounding optimization"));
+
+// Switch for the rounding + saturating-narrow combine (see the description).
+static cl::opt<bool> EnableAArch64RoundingSatNarrowOpt(
+    "aarch64-optimize-rounding-saturation", cl::init(true), cl::Hidden,
+    cl::desc("Enable AArch64 rounding and saturation narrow optimization"));
+
+// Switch for combining an extract-vector-element with a truncation (see the
+// description).
+static cl::opt<bool> EnableAArch64ExtractVecElementCombine(
+    "aarch64-extract-vector-element-trunc-combine", cl::init(true), cl::Hidden,
+    cl::desc("Allow AArch64 extract vector element combination with "
+             "truncation"));
+
+// Search budget handed to the rounding-constant search when matching a
+// rounding shift.
+static cl::opt<int> RoundingSearchMaxDepth(
+    "aarch64-rounding-search-max-depth", cl::init(4), cl::Hidden,
+    cl::desc("Maximum depth to bfs search rounding value in rounding "
+             "optimization"));
+
 // FIXME: The necessary dtprel relocations don't seem to be supported
 // well in the GNU bfd and gold linkers at the moment. Therefore, by
 // default, for now, fall back to GeneralDynamic code generation.
@@ -995,6 +1017,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
   setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
 
+  setTargetDAGCombine(ISD::TRUNCATE);
+
   setTargetDAGCombine({ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND,
                        ISD::VECTOR_SPLICE, ISD::SIGN_EXTEND_INREG,
                        ISD::CONCAT_VECTORS, ISD::EXTRACT_SUBVECTOR,
@@ -17681,6 +17705,482 @@ static SDValue performANDCombine(SDNode *N,
   return SDValue();
 }
 
+// Breadth-first search through the tree of ISD::ADD nodes rooted at OperandIn
+// for an addend that is a constant splat of RoundingValue.
+//
+// Every addend other than the rounding constant -- the non-ADD leaves visited
+// so far, plus any ADD sub-trees still queued when the constant is found -- is
+// appended to AddOperands so that the caller can re-sum them without the
+// constant. Level bounds how many ADD nodes are expanded.
+//
+// OperandIn must be an ISD::ADD node (its operands 0 and 1 are read).
+// Returns true iff the rounding constant was found; when false is returned,
+// AddOperands may hold partial results and must be ignored.
+static bool searchRoundingValueBFS(
+    uint64_t RoundingValue, const SDValue &OperandIn,
+    SmallVectorImpl<std::pair<SDValue, SDNodeFlags>> &AddOperands, int Level) {
+  // A splat only describes "the same constant in every lane" when its width
+  // equals the lane width.
+  const unsigned EltBits = OperandIn.getValueType().getScalarSizeInBits();
+
+  // True if V is a build-vector whose every lane equals RoundingValue.
+  // isConstantSplat reports the *smallest* repeating bit pattern, so without
+  // the width check a lane constant such as 0x80808080 would be reported as
+  // the 8-bit splat 0x80 and be mistaken for the rounding value 0x80.
+  auto IsRoundingSplat = [&](const SDValue &V) {
+    auto *BV = dyn_cast<BuildVectorSDNode>(V);
+    if (!BV)
+      return false;
+    APInt SplatBits, SplatUndef;
+    unsigned SplatBitSize = 0;
+    bool HasAnyUndefs = false;
+    return BV->isConstantSplat(SplatBits, SplatUndef, SplatBitSize,
+                               HasAnyUndefs) &&
+           SplatBitSize == EltBits && SplatBits == RoundingValue;
+  };
+
+  // The caller re-associates the collected operands into a fresh chain of
+  // adds, so the wrap flags of the original ADD nodes are not known to hold
+  // for the new ones. Record empty flags, which is always conservative.
+  const SDNodeFlags NoFlags;
+
+  // The queue only ever grows; Head indexes the next node to expand, which
+  // avoids repeatedly erasing from the front of the vector.
+  SmallVector<SDValue, 4> WorkList;
+  WorkList.push_back(OperandIn);
+  for (size_t Head = 0; Level > 0 && Head < WorkList.size(); ++Head, --Level) {
+    // Take copies: pushing onto WorkList below may reallocate its storage.
+    const SDValue LHS = WorkList[Head].getOperand(0);
+    const SDValue RHS = WorkList[Head].getOperand(1);
+
+    const bool LHSIsRounding = IsRoundingSplat(LHS);
+    if (LHSIsRounding || IsRoundingSplat(RHS)) {
+      // Keep the sibling of the rounding constant and every sub-tree that has
+      // not been expanded yet; together with the leaves already collected
+      // they make up the sum without the rounding term.
+      AddOperands.emplace_back(LHSIsRounding ? RHS : LHS, NoFlags);
+      for (size_t I = Head + 1, E = WorkList.size(); I != E; ++I)
+        AddOperands.emplace_back(WorkList[I], NoFlags);
+      return true;
+    }
+
+    for (const SDValue &Op : {LHS, RHS}) {
+      if (Op.getOpcode() == ISD::ADD)
+        WorkList.push_back(Op);
+      else
+        AddOperands.emplace_back(Op, NoFlags);
+    }
+  }
+
+  return false;
+}
+
+// Try to match "OpB = SHIFT(Add(OpA, 1 << (imm - 1)), imm)", where the shift
+// amount is an immediate (a scalar constant or a constant splat vector).
+//
+// ShiftOp0 / ShiftOp1 are the value and amount operands of the shift.
+// ShiftAmount receives the decoded immediate (0 if none could be decoded); it
+// is written even when the match fails.
+//
+// On success returns the sum of all addends except the rounding constant,
+// rebuilt as a new chain of ISD::ADD nodes (just OpA when there is a single
+// addend). Returns an empty SDValue when the pattern does not match.
+static SDValue matchShiftRounding(const SDValue &ShiftOp0,
+                                  const SDValue &ShiftOp1, SelectionDAG &DAG,
+                                  int64_t &ShiftAmount) {
+  ShiftAmount = 0;
+  // For illegal type, do nothing. Wait until type is legalized.
+  EVT VT0 = ShiftOp0.getValueType();
+  if (!VT0.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT0))
+    return SDValue();
+
+  // Shift value must be an immediate, either a constant splat vector form or
+  // scalar form.
+  int64_t TempAmount;
+  EVT VT = ShiftOp1.getValueType();
+  if (isa<BuildVectorSDNode>(ShiftOp1) && VT.isVector() &&
+      isVShiftRImm(ShiftOp1, VT, false, TempAmount))
+    ShiftAmount = TempAmount;
+
+  if (auto *C = dyn_cast<ConstantSDNode>(ShiftOp1))
+    ShiftAmount = C->getSExtValue();
+
+  // A shift by 1 is deliberately left alone so that XHADD, which implements
+  // (a + b + 1) >> 1, gets first priority.
+  if (ShiftOp0.getOpcode() != ISD::ADD || ShiftAmount <= 1 || ShiftAmount > 64)
+    return SDValue();
+
+  // Build the constant in 64 bits: shifting a plain `int` 1 is undefined for
+  // amounts of 32 or more, which are reachable with 64-bit lanes.
+  const uint64_t RoundingValue = uint64_t(1) << (ShiftAmount - 1);
+  SmallVector<std::pair<SDValue, SDNodeFlags>, 4> AddOperands;
+
+  // In case expression has pattern "(a + roundingValue + b + c + d) >> shift",
+  // in which rounding value is not the direct input operand of shift,
+  // rounding value should be searched from root to leaves. In this case, the
+  // expression will be matched as "RoundingShift(a + b + c + d, shift)".
+  if (!searchRoundingValueBFS(RoundingValue, ShiftOp0, AddOperands,
+                              RoundingSearchMaxDepth))
+    return SDValue();
+
+  // Fold the remaining addends back into a single value, left to right.
+  SDLoc DL(ShiftOp0);
+  EVT VTAdd = AddOperands[0].first.getValueType();
+  for (size_t i = 1; i < AddOperands.size(); i++) {
+    AddOperands[i].first =
+        DAG.getNode(ISD::ADD, DL, VTAdd, AddOperands[i].first,
+                    AddOperands[i - 1].first, AddOperands[i].second);
+  }
+  return AddOperands.back().first;
+}
+
+// Rewrite "SRX(Add(OpA, 1 << (imm - 1)), imm)" as "XRSHR(OpA, imm)".
+// N is the shift node (SRA or SRL); ShiftRoundingOpc is the rounding-shift
+// opcode to emit for it (SRSHR_I or URSHR_I respectively). Returns an empty
+// SDValue when N does not have the rounding-shift shape.
+//
+// NOTE(review): nothing here checks that the matched add cannot wrap; confirm
+// whether the rounding-shift instructions agree with a wrapping add in that
+// case.
+static SDValue matchXRSHR(SDNode *N, SelectionDAG &DAG,
+                          unsigned ShiftRoundingOpc) {
+  int64_t Imm;
+  SDValue Source =
+      matchShiftRounding(N->getOperand(0), N->getOperand(1), DAG, Imm);
+  if (!Source)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue ImmNode =
+      DAG.getConstant(static_cast<uint64_t>(Imm), DL, MVT::i32);
+  return DAG.getNode(ShiftRoundingOpc, DL, N->getValueType(0), Source,
+                     ImmNode);
+}
+
+// DAG combine for ISD::SRL.
+//  * Vector types: try to form URSHR(OpA, imm) from
+//    "SRL(Add(OpA, 1 << (imm - 1)), imm)".
+//  * i32 / i64: canonicalize (srl (bswap x), W/2) to (rotr (bswap x), W/2)
+//    when the high W/2 bits of x are known to be zero (W = 32 or 64).
+// Returns an empty SDValue when neither rewrite applies.
+static SDValue performSRLCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  SelectionDAG &DAG = DCI.DAG;
+  const EVT VT = N->getValueType(0);
+
+  if (EnableAArch64RoundingOpt && VT.isVector() &&
+      static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasNEON())
+    if (SDValue Rounded = matchXRSHR(N, DAG, AArch64ISD::URSHR_I))
+      return Rounded;
+
+  // Everything below only concerns scalar i32 / i64 shifts of a byte swap.
+  const bool IsI32 = VT == MVT::i32;
+  if (!IsI32 && VT != MVT::i64)
+    return SDValue();
+
+  SDValue Swapped = N->getOperand(0);
+  if (Swapped.getOpcode() != ISD::BSWAP)
+    return SDValue();
+
+  SDValue Amount = N->getOperand(1);
+  auto *AmountC = dyn_cast<ConstantSDNode>(Amount);
+  if (!AmountC)
+    return SDValue();
+
+  // The shift must be by exactly half the width, and the top half of the
+  // value fed into the byte swap must be zero.
+  const unsigned HalfBits = IsI32 ? 16 : 32;
+  if (AmountC->getZExtValue() != HalfBits)
+    return SDValue();
+  if (!DAG.MaskedValueIsZero(Swapped.getOperand(0),
+                             APInt::getHighBitsSet(2 * HalfBits, HalfBits)))
+    return SDValue();
+
+  return DAG.getNode(ISD::ROTR, SDLoc(N), VT, Swapped, Amount);
+}
+
+// DAG combine for ISD::SRA: on vector types with NEON available, try to form
+// SRSHR(OpA, imm) from "SRA(Add(OpA, 1 << (imm - 1)), imm)". Returns an empty
+// SDValue when the combine is disabled or the pattern does not match.
+static SDValue performSRACombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  SelectionDAG &DAG = DCI.DAG;
+
+  if (!EnableAArch64RoundingOpt || !N->getValueType(0).isVector())
+    return SDValue();
+
+  const auto &Subtarget =
+      static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
+  if (!Subtarget.hasNEON())
+    return SDValue();
+
+  return matchXRSHR(N, DAG, AArch64ISD::SRSHR_I);
+}
+
+// Try to match the operand of a truncate against
+//   min(max(Shift(Add(OpA, 1 << (imm - 1)), imm), 0), maxValue)
+// i.e. UMIN(SMAX(shift, 0), splat(MaxScalarValue)), an unsigned saturation of
+// a rounding right shift. The shift may be arithmetic or logical, in generic
+// (SRA / SRL) or AArch64 (VASHR / VLSHR) form.
+//
+// UminOp is the candidate UMIN node, MaxScalarValue the required saturation
+// bound. ShiftAmount is only written once the UMIN/SMAX shape has matched; it
+// then receives the immediate decoded from the shift.
+//
+// Returns the value being shifted with the rounding constant removed, or an
+// empty SDValue when the pattern does not match.
+static SDValue matchTruncSatRounding(const SDValue &UminOp, SelectionDAG &DAG,
+                                     uint64_t MaxScalarValue,
+                                     int64_t &ShiftAmount) {
+  if (UminOp.getOpcode() != ISD::UMIN)
+    return SDValue();
+
+  const int64_t InSize = UminOp.getValueType().getScalarSizeInBits();
+
+  // The upper bound must be MaxScalarValue in every lane. isConstantSplat
+  // reports the smallest repeating bit pattern, so the splat width has to be
+  // compared with the lane width as well: otherwise a lane constant such as
+  // 0x00FF00FF would be seen as the 16-bit splat 0x00FF and accepted as 255.
+  // Comparing the APInt directly also avoids asserting on wide splats.
+  APInt SplatBits, SplatUndef;
+  unsigned SplatBitSize = 0;
+  bool HasAnyUndefs = false;
+  auto *UpperBound = dyn_cast<BuildVectorSDNode>(UminOp.getOperand(1));
+  if (!UpperBound ||
+      !UpperBound->isConstantSplat(SplatBits, SplatUndef, SplatBitSize,
+                                   HasAnyUndefs) ||
+      SplatBitSize != InSize || SplatBits != MaxScalarValue)
+    return SDValue();
+
+  // The lower bound must be zero.
+  SDValue SmaxOp = UminOp.getOperand(0);
+  if (SmaxOp.getOpcode() != ISD::SMAX ||
+      !isNullOrNullSplat(SmaxOp.getOperand(1)))
+    return SDValue();
+
+  SDValue RoundingOp = SmaxOp.getOperand(0);
+  switch (RoundingOp.getOpcode()) {
+  case AArch64ISD::VASHR:
+  case AArch64ISD::VLSHR:
+  case ISD::SRA:
+  case ISD::SRL:
+    break;
+  default:
+    return SDValue();
+  }
+
+  SDValue AddOperand = matchShiftRounding(
+      RoundingOp.getOperand(0), RoundingOp.getOperand(1), DAG, ShiftAmount);
+  // Rounding+Truncation instruction doesn't support shift amount > input data
+  // width/2
+  if (ShiftAmount > InSize / 2)
+    return SDValue();
+  return AddOperand;
+}
+
+// Build a narrowing, saturating, rounding shift as an intrinsic node.
+//
+// ShiftRoundingOpc is the intrinsic ID to emit (the callers pass
+// aarch64_neon_sqrshrun or aarch64_neon_uqrshrn), AddOperand the vector to be
+// shifted and ShiftAmount the immediate. The result has the same number of
+// lanes as AddOperand with each lane half as wide.
+//
+// The helper is only used inside this file, so it is given internal linkage.
+static SDValue generateQrshrnInstruction(unsigned ShiftRoundingOpc,
+                                         SelectionDAG &DAG,
+                                         const SDValue &AddOperand,
+                                         int64_t ShiftAmount) {
+  SDLoc DL(AddOperand);
+  EVT InVT = AddOperand.getValueType();
+  EVT HalvedVT = InVT.changeVectorElementType(
+      InVT.getVectorElementType().getHalfSizedIntegerVT(*DAG.getContext()));
+
+  return DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL, HalvedVT,
+      DAG.getConstant(ShiftRoundingOpc, DL, MVT::i64), AddOperand,
+      DAG.getConstant(static_cast<uint64_t>(ShiftAmount), DL, MVT::i32));
+}
+
+// Attempt to match Trunc(Concat(Trunc(min(max(Shift(Add(OpA, 1<<(imm1-1)),
+// imm1),0),maxValue)),Trunc(min(max(Shift(Add(OpB, 1<<(imm2-1)),
+// imm2),0),maxValue))))
+// The pattern in tree is shown below
+//  OpA     1<<(imm1-1)                         OpB     1<<(imm2-1)
+//    \       /                                    \       /
+//     \     /                                      \     /
+//      Add      imm1                       imm2      Add
+//        \      /                              \      /
+//         \    /                                \    /
+//         Shift     0                    0      Shift
+//            \     /                      \      /
+//             \   /                        \    /
+//              max     maxVal      maxVal   max
+//                \     /             \      /
+//                 \   /               \    /
+//                  min                 min
+//                    \                 /
+//                     \               /
+//                    Trunc         Trunc
+//                        \         /
+//                         \       /
+//                          Concat
+//                             |
+//                             |
+//                           Trunc
+// The pattern will be matched as uqxtn(concat(qrshrn(OpA,imm1),
+// qrshrn(OpB,imm2))) where uqxtn is saturation truncation, qrshrn is sqrshrun
+// or uqrshrn.
+//
+// N is the outer truncate; only fixed-size 64-bit results are handled.
+// Returns an empty SDValue when the pattern does not match.
+static SDValue matchQrshrn2(SDNode *N, SelectionDAG &DAG) {
+  TypeSize OutSizeTyped = N->getValueSizeInBits(0);
+  if (OutSizeTyped.isScalable())
+    return SDValue();
+  uint64_t OutSize = OutSizeTyped;
+
+  // Only a two-way concat is described by the pattern; with more operands the
+  // rewrite below would silently drop the remaining ones.
+  SDValue Concat = N->getOperand(0);
+  if (Concat.getOpcode() != ISD::CONCAT_VECTORS ||
+      Concat.getNumOperands() != 2 || OutSize != 64)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  // All-ones value of the final lane type. Computed through APInt because
+  // "1 << bits" on a plain int is undefined once the lane is 32 bits wide;
+  // the lane width is at most 64 here since the whole result is 64 bits.
+  const uint64_t MaxScalarValue =
+      APInt::getMaxValue(VT.getScalarSizeInBits()).getZExtValue();
+  SDLoc DL(N);
+
+  SDValue TruncOp0 = Concat.getOperand(0);
+  SDValue TruncOp1 = Concat.getOperand(1);
+  if (TruncOp0.getOpcode() != ISD::TRUNCATE ||
+      TruncOp1.getOpcode() != ISD::TRUNCATE)
+    return SDValue();
+
+  int64_t ShiftAmount0 = -1, ShiftAmount1 = -1;
+  SDValue AddOperand0 = matchTruncSatRounding(TruncOp0.getOperand(0), DAG,
+                                              MaxScalarValue, ShiftAmount0);
+  SDValue AddOperand1 = matchTruncSatRounding(TruncOp1.getOperand(0), DAG,
+                                              MaxScalarValue, ShiftAmount1);
+  if (!AddOperand0 || !AddOperand1)
+    return SDValue();
+
+  // Both halves matched, so each inner truncate is known to sit on top of
+  // UMIN -> SMAX -> shift. A logical shift selects uqrshrn, an arithmetic one
+  // sqrshrun.
+  auto NarrowingIntrinsicFor = [](const SDValue &TruncOp) -> unsigned {
+    unsigned ShiftOpc =
+        TruncOp.getOperand(0).getOperand(0).getOperand(0).getOpcode();
+    return (ShiftOpc == AArch64ISD::VLSHR || ShiftOpc == ISD::SRL)
+               ? Intrinsic::aarch64_neon_uqrshrn
+               : Intrinsic::aarch64_neon_sqrshrun;
+  };
+
+  SDValue ResultTrunc0 = generateQrshrnInstruction(
+      NarrowingIntrinsicFor(TruncOp0), DAG, AddOperand0, ShiftAmount0);
+  SDValue ResultTrunc1 = generateQrshrnInstruction(
+      NarrowingIntrinsicFor(TruncOp1), DAG, AddOperand1, ShiftAmount1);
+
+  EVT ConcatVT = ResultTrunc1.getValueType().getDoubleNumVectorElementsVT(
+      *DAG.getContext());
+  SDValue ConcatOp = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT,
+                                 ResultTrunc0, ResultTrunc1);
+  // Notice MaxScalarValue is the final truncated type's max value, so a second
+  // saturation is needed. uqxtn is saturation truncation.
+  return DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL, VT,
+      DAG.getConstant(Intrinsic::aarch64_neon_uqxtn, DL, MVT::i64), ConcatOp);
+}
+
+// Match a saturating, rounding, narrowing shift from
+//   trunc(umin(smax((a + b + (1 << (imm - 1))) >> imm, 0), max))
+// and emit sqrshrun (arithmetic shift) or uqrshrn (logical shift).
+//
+// N is the truncate. Only fixed-size results of 32 to 64 bits are handled.
+// Returns an empty SDValue when the pattern does not match.
+static SDValue matchQrshrn(SDNode *N, SelectionDAG &DAG) {
+  TypeSize OutSizeTyped = N->getValueSizeInBits(0);
+  if (OutSizeTyped.isScalable())
+    return SDValue();
+  uint64_t OutSize = OutSizeTyped;
+  if (OutSize < 32 || OutSize > 64)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  uint64_t OutScalarSize = VT.getScalarSizeInBits();
+  // All-ones value of the result lane type. Computed through APInt because
+  // "1 << bits" on a plain int is undefined once the lane is 32 bits wide;
+  // the lane width is at most 64 here thanks to the size check above.
+  const uint64_t MaxScalarValue =
+      APInt::getMaxValue(OutScalarSize).getZExtValue();
+  SDLoc DL(N);
+
+  int64_t ShiftAmount = -1;
+  SDValue UminOp = N->getOperand(0);
+  SDValue AddOperand =
+      matchTruncSatRounding(UminOp, DAG, MaxScalarValue, ShiftAmount);
+  if (!AddOperand)
+    return SDValue();
+
+  uint64_t UminOpScalarSize = UminOp.getValueType().getScalarSizeInBits();
+  // The match succeeded, so UminOp is UMIN(SMAX(shift, 0), max).
+  unsigned ShiftOpc = UminOp.getOperand(0).getOperand(0).getOpcode();
+
+  unsigned TruncOpc = (ShiftOpc == AArch64ISD::VLSHR || ShiftOpc == ISD::SRL)
+                          ? Intrinsic::aarch64_neon_uqrshrn
+                          : Intrinsic::aarch64_neon_sqrshrun;
+
+  SDValue ResultTrunc =
+      generateQrshrnInstruction(TruncOpc, DAG, AddOperand, ShiftAmount);
+  // For pattern "trunc(trunc(umin(smax((a + b + (1<<(imm-1)))
+  // >>imm,0),max))", where max is the outer truncated type max value, so
+  // another saturation truncation is needed.
+  // Because final truncated type may be illegal, and there is no method to
+  // legalize intrinsic, so add redundant operation `CONCAT_VECTORS` and
+  // `EXTRACT_SUBVECTOR` to automatically legalize the operation.
+  // But generated asm code is not so optimized in this way.
+  // For example, generated asm:
+  // sqrshrun  v1.4h, v0.4s, #6
+  // sqrshrun2 v1.8h, v0.4s, #6
+  // uqxtn   v0.8b, v1.8h
+  // zip1    v0.16b, v0.16b, v0.16b
+  // xtn     v0.8b, v0.8h
+  //
+  // ideal Optimized code:
+  //
+  // sqrshrun  v1.4h, v0.4s, #6
+  // uqxtn   v0.8b, v1.8h
+  // To solve this issue, Function `hasUselessTrunc` is introduced which can
+  // optimize code like above.
+  if (OutScalarSize == UminOpScalarSize / 4) {
+    EVT NarrowVT = ResultTrunc.getValueType();
+    EVT ConcatVT = NarrowVT.getDoubleNumVectorElementsVT(*DAG.getContext());
+    // The upper half is don't-care; it only exists to reach a type that uqxtn
+    // can consume.
+    SDValue ConcatOp = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT,
+                                   ResultTrunc, DAG.getUNDEF(NarrowVT));
+    EVT HalvedVT = ConcatVT.changeVectorElementType(
+        ConcatVT.getVectorElementType().getHalfSizedIntegerVT(
+            *DAG.getContext()));
+
+    ResultTrunc = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, HalvedVT,
+        DAG.getConstant(Intrinsic::aarch64_neon_uqxtn, DL, MVT::i64), ConcatOp);
+
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ResultTrunc,
+                       DAG.getVectorIdxConstant(0, DL));
+  }
+
+  // Only support truncated type size = 1/2 or 1/4 of input type size
+  assert(OutScalarSize == UminOpScalarSize / 2 &&
+         "Invalid Truncation Type Size!");
+  return ResultTrunc;
+}
+
+// Try to match pattern  (truncate (BUILD_VECTOR(a[i],a[i+1],..., x,x,..)))
+static bool hasUselessTrunc(SDValue &N, SDValue &UsefullValue,
+                            unsigned FinalEleNum, uint64_t OldExtractIndex,
+                            int &NewExtractIndex) {
+  if (N.getOpcode() != ISD::TRUNCATE)
+    return false;
+
+  SDValue N0 = N.getOperand(0);
+  if (N0.getOpcode() != ISD::BUILD_VECTOR)
+    return false;
+
+  SmallVector<SDValue, 8> ExtractValues;
+  SmallVector<int, 8> Indices;
+  unsigned ElementNum = N.getValueType().getVectorNumElements();
+
+  // for example, if v8i8 is bitcasted to v2i32, then `TransformedSize` is 4
+  unsigned TransformedSize = 0;
+  if (FinalEleNum != 0)
+    TransformedSize = ElementNum / FinalEleNum;
+  for...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/74325


More information about the llvm-commits mailing list