[llvm] Subject: [PATCH] [AArch64ISelLowering] Optimize rounding shift and saturation truncation (PR #74325)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 4 06:41:39 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: None (JohnLee1243)
<details>
<summary>Changes</summary>
This patch performs two kinds of instruction simplification and adds one pattern:
1. A rounding shift such as Shift(Add(OpA, 1<<(imm-1)), imm) can be simplified to urshr(OpA, imm) or srshr(OpA, imm).
2. A rounding shift followed by a saturating truncation, such as Trunc(min(max(Shift(Add(OpA, 1<<(imm-1)), imm), 0), maxValue)), can be simplified to uqrshrn(OpA, imm) or sqrshrun(OpA, imm).
3. Add a pattern for RSHRN. These optimizations are performed in the backend after legalization.
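For illustration, here is a scalar C++ sketch of the two patterns (placeholder code, not taken from the patch; the 8-bit clamp bound is an arbitrary choice):

```c++
// Illustrative only: scalar forms of the patterns the combines target.
#include <algorithm>
#include <cstdint>

// Pattern 1: Shift(Add(A, 1 << (Imm - 1)), Imm) -- a rounding right
// shift, a candidate for urshr (logical) or srshr (arithmetic).
uint32_t roundingShift(uint32_t A, unsigned Imm) {
  return (A + (1u << (Imm - 1))) >> Imm;
}

// Pattern 2: Trunc(min(max(Shift(Add(A, 1 << (Imm - 1)), Imm), 0), 255))
// -- a rounding shift plus unsigned saturating truncation, a candidate
// for a single uqrshrn/sqrshrun per vector lane.
uint8_t roundingShiftSatTrunc(int32_t A, unsigned Imm) {
  int32_t Shifted = (A + (1 << (Imm - 1))) >> Imm;
  return static_cast<uint8_t>(std::clamp(Shifted, 0, 255));
}
```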
---
Patch is 41.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74325.diff
5 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp (+11)
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+595-1)
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+1-1)
- (added) llvm/test/CodeGen/AArch64/isel_rounding_opt.ll (+327)
- (modified) llvm/test/CodeGen/AArch64/neon-rshrn.ll (+11-19)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 2f49e9a6b37cc..b559a470ed3f2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -177,6 +177,17 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
}
bool SelectRoundingVLShr(SDValue N, SDValue &Res1, SDValue &Res2) {
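+ // An URSHR_I node already incorporates the rounding add, so its operand
+ // and shift amount can feed an RSHRN pattern directly, provided the shift
+ // amount is less than half the element width.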
+ if (N.getOpcode() == AArch64ISD::URSHR_I) {
+ EVT VT = N.getValueType();
+ unsigned ShtAmt = N->getConstantOperandVal(1);
+ if (ShtAmt >= VT.getScalarSizeInBits() / 2)
+ return false;
+
+ Res1 = N.getOperand(0);
+ Res2 = CurDAG->getTargetConstant(ShtAmt, SDLoc(N), MVT::i32);
+ return true;
+ }
+
if (N.getOpcode() != AArch64ISD::VLSHR)
return false;
SDValue Op = N->getOperand(0);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b6a16217dfae3..17da58f892baa 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -111,6 +111,28 @@ STATISTIC(NumTailCalls, "Number of tail calls");
STATISTIC(NumShiftInserts, "Number of vector shift inserts");
STATISTIC(NumOptimizedImms, "Number of times immediates were optimized");
+static cl::opt<bool> EnableAArch64RoundingOpt(
+ "aarch64-optimize-rounding", cl::Hidden,
+ cl::desc("Enable AArch64 rounding optimization"),
+ cl::init(true));
+
+static cl::opt<bool> EnableAArch64RoundingSatNarrowOpt(
+ "aarch64-optimize-rounding-saturation", cl::Hidden,
+ cl::desc("Enable AArch64 rounding and saturation narrow optimization"),
+ cl::init(true));
+
+static cl::opt<bool> EnableAArch64ExtractVecElementCombine(
+ "aarch64-extract-vector-element-trunc-combine", cl::Hidden,
+ cl::desc("Allow AArch64 extract vector element combination with "
+ "truncation"),
+ cl::init(true));
+
+static cl::opt<int> RoundingSearchMaxDepth(
+ "aarch64-rounding-search-max-depth", cl::Hidden,
+ cl::desc("Maximum depth to bfs search rounding value in rounding "
+ "optimization"),
+ cl::init(4));
+
// FIXME: The necessary dtprel relocations don't seem to be supported
// well in the GNU bfd and gold linkers at the moment. Therefore, by
// default, for now, fall back to GeneralDynamic code generation.
@@ -995,6 +1017,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
+ setTargetDAGCombine(ISD::TRUNCATE);
+
setTargetDAGCombine({ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND,
ISD::VECTOR_SPLICE, ISD::SIGN_EXTEND_INREG,
ISD::CONCAT_VECTORS, ISD::EXTRACT_SUBVECTOR,
@@ -17681,6 +17705,482 @@ static SDValue performANDCombine(SDNode *N,
return SDValue();
}
+// BFS over the ADD tree for an operand equal to the rounding value.
+static bool searchRoundingValueBFS(
+ uint64_t RoundingValue, const SDValue &OperandIn,
+ SmallVectorImpl<std::pair<SDValue, SDNodeFlags>> &AddOperands, int Level) {
+ SmallVector<SDValue, 4> WorkList;
+ WorkList.emplace_back(OperandIn);
+ while (Level > 0 && !WorkList.empty()) {
+ auto Operand = WorkList.front();
+ WorkList.erase(WorkList.begin());
+ Level--;
+ SDValue Operand0 = Operand.getOperand(0);
+ SDValue Operand1 = Operand.getOperand(1);
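+ // Check whether either operand of this ADD is a constant splat equal to
+ // the rounding value; if so, the other operand is the value being rounded.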
+ BuildVectorSDNode *AddOp0 = dyn_cast<BuildVectorSDNode>(Operand0);
+ BuildVectorSDNode *AddOp1 = dyn_cast<BuildVectorSDNode>(Operand1);
+ auto foundRounding = [&](BuildVectorSDNode *AddOp, SDValue &OtherOperand) {
+ APInt SplatBitsAdd, SplatUndefAdd;
+ unsigned SplatBitSizeAdd = 0;
+ bool HasAnyUndefsAdd = false;
+ if (AddOp &&
+ AddOp->isConstantSplat(SplatBitsAdd, SplatUndefAdd, SplatBitSizeAdd,
+ HasAnyUndefsAdd) &&
+ (SplatBitsAdd == RoundingValue)) {
+ AddOperands.emplace_back(OtherOperand,
+ OtherOperand.getNode()->getFlags());
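+ // The rounding constant has been found; move the remaining worklist
+ // entries into AddOperands so the caller can re-sum them.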
+ while (!WorkList.empty()) {
+ SDValue TempVal = WorkList.front();
+ WorkList.erase(WorkList.begin());
+ AddOperands.emplace_back(TempVal, TempVal.getNode()->getFlags());
+ }
+ return true;
+ }
+ return false;
+ };
+ if (foundRounding(AddOp0, Operand1))
+ return true;
+ if (foundRounding(AddOp1, Operand0))
+ return true;
+ if (Operand0.getOpcode() == ISD::ADD)
+ WorkList.emplace_back(Operand0);
+ else
+ AddOperands.emplace_back(Operand0, Operand.getNode()->getFlags());
+ if (Operand1.getOpcode() == ISD::ADD)
+ WorkList.emplace_back(Operand1);
+ else
+ AddOperands.emplace_back(Operand1, Operand.getNode()->getFlags());
+ }
+
+ return false;
+}
+
+// Try to match the pattern "OpB = Shift(Add(OpA, 1<<(imm-1)), imm)", where
+// the shift amount must be an immediate.
+static SDValue matchShiftRounding(const SDValue &ShiftOp0,
+ const SDValue &ShiftOp1, SelectionDAG &DAG,
+ int64_t &ShiftAmount) {
+ ShiftAmount = 0;
+ // For illegal type, do nothing. Wait until type is legalized.
+ EVT VT0 = ShiftOp0.getValueType();
+ if (!VT0.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT0))
+ return SDValue();
+
+ BuildVectorSDNode *ShiftOperand1 = dyn_cast<BuildVectorSDNode>(ShiftOp1);
+
+ // Shift value must be an immediate, either a constant splat vector form or
+ // scalar form.
+ int64_t TempAmount;
+ EVT VT = ShiftOp1.getValueType();
+ if (ShiftOperand1 && VT.isVector() &&
+ isVShiftRImm(ShiftOp1, VT, false, TempAmount))
+ ShiftAmount = TempAmount;
+
+ if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(ShiftOp1))
+ ShiftAmount = C->getSExtValue();
+
+ // For a shift amount of 1, XHADD is matched with higher priority since it
+ // computes (a + b + 1) >> 1, so only handle shift amounts greater than 1.
+ if (ShiftOp0.getOpcode() == ISD::ADD && ShiftAmount > 1 &&
+ ShiftAmount <= 64) {
+ uint64_t RoundingValue = 1ULL << (ShiftAmount - 1);
+ SmallVector<std::pair<SDValue, SDNodeFlags>, 4> AddOperands;
+
+ // If the expression has the form "(a + roundingValue + b + c + d) >> shift",
+ // the rounding value is not a direct operand of the shift, so it must be
+ // searched for from the root towards the leaves. In that case the
+ // expression is matched as "RoundingShift(a + b + c + d, shift)".
+ if (searchRoundingValueBFS(RoundingValue, ShiftOp0, AddOperands,
+ RoundingSearchMaxDepth)) {
+ SDLoc DL(ShiftOp0);
+ EVT VTAdd = AddOperands[0].first.getValueType();
+
+ for (size_t i = 1; i < AddOperands.size(); i++) {
+ AddOperands[i].first =
+ DAG.getNode(ISD::ADD, DL, VTAdd, AddOperands[i].first,
+ AddOperands[i - 1].first, AddOperands[i].second);
+ }
+ return AddOperands.back().first;
+ }
+ }
+
+ return SDValue();
+}
+
+// Attempt to form XRSHR(OpA, imm) from "SRX(Add(OpA, 1<<(imm-1)), imm)".
+// Depending on whether the shift is logical or arithmetic, XRSHR is URSHR or
+// SRSHR and SRX is SRL or SRA.
+static SDValue matchXRSHR(SDNode *N, SelectionDAG &DAG,
+ unsigned ShiftRoundingOpc) {
+ EVT VT = N->getValueType(0);
+ int64_t ShiftAmount;
+ SDValue AddOperand =
+ matchShiftRounding(N->getOperand(0), N->getOperand(1), DAG, ShiftAmount);
+ if (!AddOperand)
+ return SDValue();
+ SDLoc DL(N);
+ SDValue ResultRounding = DAG.getNode(
+ ShiftRoundingOpc, DL, VT, AddOperand,
+ DAG.getConstant(static_cast<uint64_t>(ShiftAmount), DL, MVT::i32));
+ return ResultRounding;
+}
+
+static SDValue performSRLCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ SelectionDAG &DAG = DCI.DAG;
+ EVT VT = N->getValueType(0);
+
+ // Attempt to form URSHR(OpA, imm) from "SRL(Add(OpA, 1<<(imm-1)), imm)".
+ SDValue ResultRounding;
+ if (EnableAArch64RoundingOpt && VT.isVector() &&
+ DAG.getSubtarget<AArch64Subtarget>().hasNEON())
+ ResultRounding = matchXRSHR(N, DAG, AArch64ISD::URSHR_I);
+
+ if (ResultRounding)
+ return ResultRounding;
+
+ if (VT != MVT::i32 && VT != MVT::i64)
+ return SDValue();
+
+ // Canonicalize (srl (bswap i32 x), 16) to (rotr (bswap i32 x), 16), if the
+ // high 16-bits of x are zero. Similarly, canonicalize (srl (bswap i64 x), 32)
+ // to (rotr (bswap i64 x), 32), if the high 32-bits of x are zero.
+ SDValue N0 = N->getOperand(0);
+ if (N0.getOpcode() == ISD::BSWAP) {
+ SDLoc DL(N);
+ SDValue N1 = N->getOperand(1);
+ SDValue N00 = N0.getOperand(0);
+ if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(N1)) {
+ uint64_t ShiftAmt = C->getZExtValue();
+ if (VT == MVT::i32 && ShiftAmt == 16 &&
+ DAG.MaskedValueIsZero(N00, APInt::getHighBitsSet(32, 16)))
+ return DAG.getNode(ISD::ROTR, DL, VT, N0, N1);
+ if (VT == MVT::i64 && ShiftAmt == 32 &&
+ DAG.MaskedValueIsZero(N00, APInt::getHighBitsSet(64, 32)))
+ return DAG.getNode(ISD::ROTR, DL, VT, N0, N1);
+ }
+ }
+ return SDValue();
+}
+
+static SDValue performSRACombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ SelectionDAG &DAG = DCI.DAG;
+ EVT VT = N->getValueType(0);
+
+ // Attempt to form SRSHR(OpA, imm) from "SRA(Add(OpA, 1<<(imm-1)), imm)".
+ SDValue ResultRounding;
+ if (EnableAArch64RoundingOpt && VT.isVector() &&
+ DAG.getSubtarget<AArch64Subtarget>().hasNEON())
+ ResultRounding = matchXRSHR(N, DAG, AArch64ISD::SRSHR_I);
+
+ return ResultRounding;
+}
+
+// Try to match the pattern
+// "Trunc(min(max(Shift(Add(OpA, 1<<(imm-1)), imm), 0), maxValue))".
+// Here min(max(x, 0), maxValue) is an unsigned saturation and the shift is a
+// logical or arithmetic right shift, in scalar or vector form.
+static SDValue matchTruncSatRounding(const SDValue &UminOp, SelectionDAG &DAG,
+ uint64_t MaxScalarValue,
+ int64_t &ShiftAmount) {
+ if (UminOp.getOpcode() != ISD::UMIN)
+ return SDValue();
+
+ APInt SplatBitsRound, SplatUndefRound;
+ unsigned SplatBitSizeRound = 0;
+ bool HasAnyUndefsRound = false;
+ uint64_t UminOp1VecVal = 0;
+
+ BuildVectorSDNode *UminOp1Vec =
+ dyn_cast<BuildVectorSDNode>(UminOp.getOperand(1));
+
+ if (UminOp1Vec &&
+ UminOp1Vec->isConstantSplat(SplatBitsRound, SplatUndefRound,
+ SplatBitSizeRound, HasAnyUndefsRound))
+ UminOp1VecVal = SplatBitsRound.getZExtValue();
+
+ if (UminOp1VecVal != MaxScalarValue)
+ return SDValue();
+
+ SDValue SmaxOp = UminOp.getOperand(0);
+ if (SmaxOp.getOpcode() != ISD::SMAX ||
+ !isNullOrNullSplat(SmaxOp.getOperand(1)))
+ return SDValue();
+
+ SDValue RoundingOp = SmaxOp.getOperand(0);
+ unsigned RoundingOpCode = RoundingOp.getOpcode();
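+ // Both scalar (SRA/SRL) and vector (VASHR/VLSHR) right shifts are
+ // accepted here; the shift kind selects the narrowing intrinsic later.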
+ if (RoundingOpCode == AArch64ISD::VASHR ||
+ RoundingOpCode == AArch64ISD::VLSHR || RoundingOpCode == ISD::SRA ||
+ RoundingOpCode == ISD::SRL) {
+ SDValue AddOperand = matchShiftRounding(
+ RoundingOp.getOperand(0), RoundingOp.getOperand(1), DAG, ShiftAmount);
+ // The rounding narrowing instructions do not support shift amounts larger
+ // than half the input element width.
+ int64_t InSize = UminOp.getValueType().getScalarSizeInBits();
+ if (ShiftAmount > InSize / 2)
+ return SDValue();
+ return AddOperand;
+ }
+
+ return SDValue();
+}
+
+// A helper to build an sqrshrun or uqrshrn intrinsic node.
+static SDValue generateQrshrnInstruction(unsigned ShiftRoundingOpc,
+ SelectionDAG &DAG, SDValue &AddOperand,
+ int64_t ShiftAmount) {
+ SDLoc DL(AddOperand);
+ EVT InVT = AddOperand.getValueType();
+ EVT HalvedVT = InVT.changeVectorElementType(
+ InVT.getVectorElementType().getHalfSizedIntegerVT(*DAG.getContext()));
+
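+ // Emit the narrowing intrinsic as INTRINSIC_WO_CHAIN: operand 0 is the
+ // intrinsic ID, followed by the wide source vector and the shift amount.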
+ SDValue ResultRounding = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, HalvedVT,
+ DAG.getConstant(ShiftRoundingOpc, DL, MVT::i64), AddOperand,
+ DAG.getConstant(static_cast<uint64_t>(ShiftAmount), DL, MVT::i32));
+ return ResultRounding;
+}
+
+// Attempt to match
+// Trunc(Concat(Trunc(min(max(Shift(Add(OpA, 1<<(imm1-1)), imm1), 0), maxValue)),
+// Trunc(min(max(Shift(Add(OpB, 1<<(imm2-1)), imm2), 0), maxValue))))
+// The pattern in tree is shown below
+// OpA 1<<(imm1-1) OpB 1<<(imm2-1)
+// \ / \ /
+// \ / \ /
+// Add imm1 imm2 Add
+// \ / \ /
+// \ / \ /
+// Shift 0 0 Shift
+// \ / \ /
+// \ / \ /
+// max maxVal maxVal max
+// \ / \ /
+// \ / \ /
+// min min
+// \ /
+// \ /
+// Trunc Trunc
+// \ /
+// \ /
+// Concat
+// |
+// |
+// Trunc
+// The pattern is matched as uqxtn(concat(qrshrn(OpA, imm1), qrshrn(OpB, imm2)))
+// where uqxtn is a saturating truncation and qrshrn is sqrshrun or uqrshrn.
+static SDValue matchQrshrn2(SDNode *N, SelectionDAG &DAG) {
+ TypeSize OutSizeTyped = N->getValueSizeInBits(0);
+ if (OutSizeTyped.isScalable())
+ return SDValue();
+ uint64_t OutSize = OutSizeTyped.getFixedValue();
+ SDValue Operand0 = N->getOperand(0);
+ if (Operand0.getOpcode() != ISD::CONCAT_VECTORS || OutSize != 64)
+ return SDValue();
+
+ EVT VT = N->getValueType(0);
+ uint64_t OutScalarSize = VT.getScalarSizeInBits();
+ uint64_t MaxScalarValue = (1ULL << OutScalarSize) - 1;
+ SDLoc DL(N);
+ SDValue TruncOp0 = Operand0.getOperand(0);
+ SDValue TruncOp1 = Operand0.getOperand(1);
+ if (TruncOp0.getOpcode() != ISD::TRUNCATE ||
+ TruncOp1.getOpcode() != ISD::TRUNCATE)
+ return SDValue();
+
+ int64_t ShiftAmount0 = -1, ShiftAmount1 = -1;
+ SDValue AddOperand0 = matchTruncSatRounding(TruncOp0.getOperand(0), DAG,
+ MaxScalarValue, ShiftAmount0);
+ SDValue AddOperand1 = matchTruncSatRounding(TruncOp1.getOperand(0), DAG,
+ MaxScalarValue, ShiftAmount1);
+ if (!AddOperand0 || !AddOperand1)
+ return SDValue();
+
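+ // Walk TRUNCATE -> UMIN -> SMAX down to the underlying shift node so its
+ // opcode can select between the unsigned and signed narrowing intrinsics.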
+ auto getShiftVal = [&](SDValue &TruncOp) {
+ if (SDValue Operand0 = TruncOp.getOperand(0))
+ if (SDValue Operand00 = Operand0.getOperand(0))
+ return Operand00.getOperand(0);
+
+ return SDValue();
+ };
+
+ SDValue Shift0Val = getShiftVal(TruncOp0);
+ SDValue Shift1Val = getShiftVal(TruncOp1);
+ if (!Shift0Val || !Shift1Val)
+ return SDValue();
+
+ unsigned Shift0Opc = Shift0Val.getOpcode();
+ unsigned Shift1Opc = Shift1Val.getOpcode();
+ unsigned TruncOpc0 = (Shift0Opc == AArch64ISD::VLSHR || Shift0Opc == ISD::SRL)
+ ? Intrinsic::aarch64_neon_uqrshrn
+ : Intrinsic::aarch64_neon_sqrshrun;
+ unsigned TruncOpc1 = (Shift1Opc == AArch64ISD::VLSHR || Shift1Opc == ISD::SRL)
+ ? Intrinsic::aarch64_neon_uqrshrn
+ : Intrinsic::aarch64_neon_sqrshrun;
+ SDValue ResultTrunc0 =
+ generateQrshrnInstruction(TruncOpc0, DAG, AddOperand0, ShiftAmount0);
+ SDValue ResultTrunc1 =
+ generateQrshrnInstruction(TruncOpc1, DAG, AddOperand1, ShiftAmount1);
+
+ EVT ConcatVT = ResultTrunc1.getValueType().getDoubleNumVectorElementsVT(
+ *DAG.getContext());
+ SDValue ConcatOp = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT,
+ ResultTrunc0, ResultTrunc1);
+ // Note that MaxScalarValue is the max value of the final truncated type, so
+ // a second saturation is needed; uqxtn performs the saturating truncation.
+ SDValue ResultOp = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, VT,
+ DAG.getConstant(Intrinsic::aarch64_neon_uqxtn, DL, MVT::i64), ConcatOp);
+
+ return ResultOp;
+}
+
+// Try to match SQRSHRUN/UQRSHRN from
+// trunc(umin(smax((a + b + (1<<(imm-1))) >> imm, 0), max)).
+static SDValue matchQrshrn(SDNode *N, SelectionDAG &DAG) {
+ int64_t ShiftAmount = -1;
+ SDValue AddOperand;
+ TypeSize OutSizeTyped = N->getValueSizeInBits(0);
+ if (OutSizeTyped.isScalable())
+ return SDValue();
+ uint64_t OutSize = OutSizeTyped.getFixedValue();
+ EVT VT = N->getValueType(0);
+
+ uint64_t OutScalarSize = VT.getScalarSizeInBits();
+ uint64_t MaxScalarValue = (1ULL << OutScalarSize) - 1;
+ SDLoc DL(N);
+
+ if (OutSize <= 64 && OutSize >= 32)
+ AddOperand = matchTruncSatRounding(N->getOperand(0), DAG, MaxScalarValue,
+ ShiftAmount);
+
+ if (!AddOperand)
+ return SDValue();
+
+ uint64_t UminOpScalarSize =
+ N->getOperand(0).getValueType().getScalarSizeInBits();
+ unsigned ShiftOpc = N->getOperand(0).getOperand(0).getOperand(0).getOpcode();
+
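+ // A logical source shift selects uqrshrn; an arithmetic one selects
+ // sqrshrun (signed input, unsigned saturating result).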
+ unsigned TruncOpc = (ShiftOpc == AArch64ISD::VLSHR || ShiftOpc == ISD::SRL)
+ ? Intrinsic::aarch64_neon_uqrshrn
+ : Intrinsic::aarch64_neon_sqrshrun;
+
+ SDValue ResultTrunc =
+ generateQrshrnInstruction(TruncOpc, DAG, AddOperand, ShiftAmount);
+ // For the pattern "trunc(trunc(umin(smax((a + b + (1<<(imm-1)))
+ // >> imm, 0), max)))", where max is the max value of the outer truncated
+ // type, another saturating truncation is needed.
+ // Because the final truncated type may be illegal and there is no way to
+ // legalize an intrinsic, redundant `CONCAT_VECTORS` and `EXTRACT_SUBVECTOR`
+ // operations are added so that the node is legalized automatically.
+ // However, the generated asm is not fully optimized this way.
+ // For example, the generated asm:
+ // sqrshrun v1.4h, v0.4s, #6
+ // sqrshrun2 v1.8h, v0.4s, #6
+ // uqxtn v0.8b, v1.8h
+ // zip1 v0.16b, v0.16b, v0.16b
+ // xtn v0.8b, v0.8h
+ //
+ // whereas the ideal optimized code is:
+ //
+ // sqrshrun v1.4h, v0.4s, #6
+ // uqxtn v0.8b, v1.8h
+ // To solve this, the function `hasUselessTrunc` below removes the
+ // redundant truncation in code like the above.
+ if (ResultTrunc && OutScalarSize == UminOpScalarSize / 4) {
+ EVT ConcatVT = ResultTrunc.getValueType().getDoubleNumVectorElementsVT(
+ *DAG.getContext());
+ SDValue ConcatOp =
+ DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, ResultTrunc,
+ DAG.getUNDEF(ResultTrunc.getValueType()));
+ EVT HalvedVT = ConcatVT.changeVectorElementType(
+ ConcatVT.getVectorElementType().getHalfSizedIntegerVT(
+ *DAG.getContext()));
+
+ ResultTrunc = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, HalvedVT,
+ DAG.getConstant(Intrinsic::aarch64_neon_uqxtn, DL, MVT::i64), ConcatOp);
+
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ResultTrunc,
+ DAG.getVectorIdxConstant(0, DL));
+ }
+
+ // Only truncation to 1/2 or 1/4 of the input element size is supported.
+ if (ResultTrunc)
+ assert(OutScalarSize == UminOpScalarSize / 2 &&
+ "Invalid Truncation Type Size!");
+ return ResultTrunc;
+}
+
+// Try to match pattern (truncate (BUILD_VECTOR(a[i],a[i+1],..., x,x,..)))
+static bool hasUselessTrunc(SDValue &N, SDValue &UsefullValue,
+ unsigned FinalEleNum, uint64_t OldExtractIndex,
+ int &NewExtractIndex) {
+ if (N.getOpcode() != ISD::TRUNCATE)
+ return false;
+
+ SDValue N0 = N.getOperand(0);
+ if (N0.getOpcode() != ISD::BUILD_VECTOR)
+ return false;
+
+ SmallVector<SDValue, 8> ExtractValues;
+ SmallVector<int, 8> Indices;
+ unsigned ElementNum = N.getValueType().getVectorNumElements();
+
+ // For example, if v8i8 is bitcast to v2i32, then `TransformedSize` is 4.
+ unsigned TransformedSize = 0;
+ if (FinalEleNum != 0)
+ TransformedSize = ElementNum / FinalEleNum;
+ for...
[truncated]
``````````
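For context, here is a hedged sketch of loop code whose vectorized form should benefit (hypothetical, not one of the patch's test cases; the shift amount 6 matches the asm example in the comments above):

```c++
#include <cstdint>

// Hypothetical example: per lane, the add/shift/clamp/truncate sequence
// below should now be selected as a single sqrshrun (uqrshrn if the
// shift were logical) once the loop is vectorized.
void narrowRow(const int32_t *Src, uint16_t *Dst, int N) {
  for (int I = 0; I < N; ++I) {
    int32_t V = (Src[I] + (1 << 5)) >> 6;     // rounding shift, imm = 6
    V = V < 0 ? 0 : (V > 65535 ? 65535 : V);  // unsigned saturation to u16
    Dst[I] = static_cast<uint16_t>(V);        // truncation
  }
}
```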
</details>
https://github.com/llvm/llvm-project/pull/74325