[llvm] [DAG] Support saturated truncate (PR #99418)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 17 18:44:59 PDT 2024
https://github.com/ParkHanbum created https://github.com/llvm/llvm-project/pull/99418
A `truncate` is `saturated` if no additional conversion is required
between the target and return values. If the target is `saturated`
when attempting to truncate a `vector`, there is an opportunity
to optimize it.
Previously, each architecture attempted this optimization on its own,
so there was redundant code.
This patch implements the common logic by adding `ISD::TRUNCATE_[US]SAT`
to indicate a saturated truncate.
>From 66b93322069570310c029954dc17dfb7f7d99ec4 Mon Sep 17 00:00:00 2001
From: hanbeom <kese111 at gmail.com>
Date: Tue, 16 Jul 2024 13:52:29 +0900
Subject: [PATCH 1/3] [DAG] Support saturated truncate
A `truncate` is `saturated` if no additional conversion is required
between the target and return values. If the target is `saturated`
when attempting to truncate a `vector`, there is an opportunity
to optimize it.
Previously, each architecture attempted this optimization on its own,
so there was redundant code.
This patch implements the common logic by adding `ISD::TRUNCATE_[US]SAT`
to indicate a saturated truncate.
Fixes #85903
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 3 +
.../include/llvm/Target/TargetSelectionDAG.td | 2 +
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 132 +++++++++++++++++-
.../SelectionDAG/SelectionDAGDumper.cpp | 2 +
llvm/lib/CodeGen/TargetLoweringBase.cpp | 4 +
5 files changed, 142 insertions(+), 1 deletion(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index e6b10209b4767..0b36e5b40da73 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -804,6 +804,9 @@ enum NodeType {
/// TRUNCATE - Completely drop the high bits.
TRUNCATE,
+ /// TRUNCATE_[SU]SAT - Truncate to a narrower integer type, saturating
+ /// (signed/unsigned) to the destination range instead of dropping high bits.
+ TRUNCATE_SSAT,
+ TRUNCATE_USAT,
/// [SU]INT_TO_FP - These operators convert integers (whose interpreted sign
/// depends on the first letter) to floating point.
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 133c9b113e51b..a5242694c9507 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -471,6 +471,8 @@ def sext : SDNode<"ISD::SIGN_EXTEND", SDTIntExtendOp>;
def zext : SDNode<"ISD::ZERO_EXTEND", SDTIntExtendOp>;
def anyext : SDNode<"ISD::ANY_EXTEND" , SDTIntExtendOp>;
def trunc : SDNode<"ISD::TRUNCATE" , SDTIntTruncOp>;
+def truncssat : SDNode<"ISD::TRUNCATE_SSAT", SDTIntTruncOp>;
+def truncusat : SDNode<"ISD::TRUNCATE_USAT", SDTIntTruncOp>;
def bitconvert : SDNode<"ISD::BITCAST" , SDTUnaryOp>;
def addrspacecast : SDNode<"ISD::ADDRSPACECAST", SDTUnaryOp>;
def freeze : SDNode<"ISD::FREEZE" , SDTFreeze>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 302ad128f4f53..967f313c9885e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -486,6 +486,8 @@ namespace {
SDValue visitSIGN_EXTEND_INREG(SDNode *N);
SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
SDValue visitTRUNCATE(SDNode *N);
+ SDValue visitTRUNCATE_SSAT(SDNode *N);
+ SDValue visitTRUNCATE_USAT(SDNode *N);
SDValue visitBITCAST(SDNode *N);
SDValue visitFREEZE(SDNode *N);
SDValue visitBUILD_PAIR(SDNode *N);
@@ -1907,6 +1909,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::ZERO_EXTEND_VECTOR_INREG:
case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
case ISD::TRUNCATE: return visitTRUNCATE(N);
+ case ISD::TRUNCATE_SSAT: return visitTRUNCATE_SSAT(N);
+ case ISD::TRUNCATE_USAT: return visitTRUNCATE_USAT(N);
case ISD::BITCAST: return visitBITCAST(N);
case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
case ISD::FADD: return visitFADD(N);
@@ -13154,7 +13158,8 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
unsigned CastOpcode = Cast->getOpcode();
assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
- CastOpcode == ISD::FP_ROUND) &&
+ CastOpcode == ISD::TRUNCATE_SSAT ||
+ CastOpcode == ISD::TRUNCATE_USAT || CastOpcode == ISD::FP_ROUND) &&
"Unexpected opcode for vector select narrowing/widening");
// We only do this transform before legal ops because the pattern may be
@@ -14867,6 +14872,119 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
return SDValue();
}
+/// Fold a TRUNCATE_USAT of a float-to-integer conversion into a single
+/// FP_TO_UINT_SAT that produces the narrow type directly.
+///
+/// TRUNCATE_USAT is created from the (truncate (umin x, UMAX)) form, i.e. it
+/// clamps its operand, read as an unsigned integer, to the unsigned range of
+/// VT. Only two operand shapes are equivalent to FP_TO_UINT_SAT:
+///   (truncate_usat (fp_to_uint x))           -> (fp_to_uint_sat x)
+///   (truncate_usat (smax (fp_to_sint x), 0)) -> (fp_to_uint_sat x)
+/// A bare fp_to_sint may be negative, which as an unsigned value saturates to
+/// the maximum rather than to zero, and an smax with a non-zero bound raises
+/// the minimum result above zero, so neither of those is folded.
+///
+/// \returns the replacement node, or an empty SDValue if nothing was folded.
+SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) {
+  EVT VT = N->getValueType(0);
+  SDValue N0 = N->getOperand(0);
+
+  SDValue FPInstr;
+  if (N0.getOpcode() == ISD::FP_TO_UINT) {
+    FPInstr = N0;
+  } else if (N0.getOpcode() == ISD::SMAX &&
+             N0.getOperand(0).getOpcode() == ISD::FP_TO_SINT) {
+    // The lower clamp has to be exactly zero to match FP_TO_UINT_SAT.
+    APInt LowerBound;
+    if (ISD::isConstantSplatVector(N0.getOperand(1).getNode(), LowerBound) &&
+        LowerBound.isZero())
+      FPInstr = N0.getOperand(0);
+  }
+  if (!FPInstr)
+    return SDValue();
+
+  SDValue FPSrc = FPInstr.getOperand(0);
+  if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(
+          ISD::FP_TO_UINT_SAT, FPSrc.getValueType(), VT))
+    return SDValue();
+
+  return DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT, FPSrc,
+                     DAG.getValueType(VT.getScalarType()));
+}
+
+/// Combine hook for TRUNCATE_SSAT nodes. No folds are implemented here; the
+/// function always returns an empty SDValue.
+SDValue DAGCombiner::visitTRUNCATE_SSAT(SDNode *N) {
+  return SDValue();
+}
+
+/// Recognise an unsigned-saturating truncate of \p In to \p VT.
+///
+/// Three operand shapes are accepted (UMAX is the all-ones mask with the
+/// scalar width of VT; every constant must be a constant splat):
+///
+///   1. (umin x, UMAX)
+///        -> returns x.
+///   2. (smin (smax x, C1), UMAX)            with C1 >= 0
+///        -> returns (smax x, C1).
+///   3. (smax (smin x, UMAX), C1)            with C1 >= 0 and C1 <= UMAX
+///        -> returns a newly built (smax x, C1).
+///
+/// Forms 2 and 3 are both equivalent to (umin (smax x, C1), UMAX), which is
+/// why the smax is what gets handed back for truncation.
+///
+/// \returns the value to truncate, or an empty SDValue when nothing matched.
+static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
+                                 const SDLoc &DL) {
+  EVT SrcVT = In.getValueType();
+  const unsigned DstBits = VT.getScalarSizeInBits();
+
+  // We are narrowing from SrcVT to VT.
+  assert(SrcVT.getScalarSizeInBits() > DstBits &&
+         "Unexpected types for truncate operation");
+
+  // If Node is `Opc(X, splat(C))`, store C in Bound and yield X; otherwise
+  // yield an empty SDValue.
+  auto PeelClamp = [](SDValue Node, unsigned Opc, APInt &Bound) -> SDValue {
+    if (Node.getOpcode() != Opc)
+      return SDValue();
+    if (!ISD::isConstantSplatVector(Node.getOperand(1).getNode(), Bound))
+      return SDValue();
+    return Node.getOperand(0);
+  };
+
+  APInt Lo, Hi;
+
+  // Form 1: umin(x, UMAX).
+  if (SDValue X = PeelClamp(In, ISD::UMIN, Hi)) {
+    if (Hi.isMask(DstBits))
+      return X;
+  }
+
+  // Form 2: smin(smax(x, Lo), UMAX); the inner smax is returned as-is.
+  if (SDValue Inner = PeelClamp(In, ISD::SMIN, Hi)) {
+    if (PeelClamp(Inner, ISD::SMAX, Lo)) {
+      if (Lo.isNonNegative() && Hi.isMask(DstBits))
+        return Inner;
+    }
+  }
+
+  // Form 3: smax(smin(x, UMAX), Lo); rebuild the smax directly on x.
+  if (SDValue Inner = PeelClamp(In, ISD::SMAX, Lo)) {
+    if (SDValue X = PeelClamp(Inner, ISD::SMIN, Hi)) {
+      if (Lo.isNonNegative() && Hi.isMask(DstBits) && Hi.uge(Lo))
+        return DAG.getNode(ISD::SMAX, DL, SrcVT, X, In.getOperand(1));
+    }
+  }
+
+  return SDValue();
+}
+
+/// Recognise a signed-saturating truncate of \p In to \p VT.
+///
+/// Matches either nesting order of a clamp to the signed range of the
+/// destination scalar type (both bounds must be constant splats):
+///   (smin (smax x, SIGNED_MIN), SIGNED_MAX)
+///   (smax (smin x, SIGNED_MAX), SIGNED_MIN)
+///
+/// \returns x, the value to truncate, or an empty SDValue when nothing matched.
+static SDValue detectSSatPattern(SDValue In, EVT VT) {
+  const unsigned DstBits = VT.getScalarSizeInBits();
+  const unsigned SrcBits = In.getScalarValueSizeInBits();
+  assert(SrcBits > DstBits && "Unexpected types for truncate operation");
+
+  // Destination-type bounds, sign-extended to the source width so they can be
+  // compared against the splat constants.
+  const APInt Hi = APInt::getSignedMaxValue(DstBits).sext(SrcBits);
+  const APInt Lo = APInt::getSignedMinValue(DstBits).sext(SrcBits);
+
+  // Yield X when Node is `Opc(X, splat(Bound))`, else an empty SDValue.
+  auto PeelClamp = [](SDValue Node, unsigned Opc,
+                      const APInt &Bound) -> SDValue {
+    if (Node.getOpcode() != Opc)
+      return SDValue();
+    APInt Splat;
+    if (!ISD::isConstantSplatVector(Node.getOperand(1).getNode(), Splat))
+      return SDValue();
+    return Splat == Bound ? Node.getOperand(0) : SDValue();
+  };
+
+  // smin on the outside, smax on the inside.
+  if (SDValue Inner = PeelClamp(In, ISD::SMIN, Hi))
+    if (SDValue X = PeelClamp(Inner, ISD::SMAX, Lo))
+      return X;
+
+  // smax on the outside, smin on the inside.
+  if (SDValue Inner = PeelClamp(In, ISD::SMAX, Lo))
+    if (SDValue X = PeelClamp(Inner, ISD::SMIN, Hi))
+      return X;
+
+  return SDValue();
+}
+
SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
@@ -14874,6 +14992,18 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
bool isLE = DAG.getDataLayout().isLittleEndian();
SDLoc DL(N);
+ if (!LegalOperations && N->getOpcode() == ISD::TRUNCATE) {
+ if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_SSAT, SrcVT)) {
+ if (SDValue SSatVal = detectSSatPattern(N0, VT))
+ return DAG.getNode(ISD::TRUNCATE_SSAT, DL, VT, SSatVal);
+ }
+
+ if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_USAT, SrcVT)) {
+ if (SDValue USatVal = detectUSatPattern(N0, VT, DAG, DL))
+ return DAG.getNode(ISD::TRUNCATE_USAT, DL, VT, USatVal);
+ }
+ }
+
// trunc(undef) = undef
if (N0.isUndef())
return DAG.getUNDEF(VT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index cc8de3a217f82..d3ad6c8acf4f1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -380,6 +380,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::SIGN_EXTEND_VECTOR_INREG: return "sign_extend_vector_inreg";
case ISD::ZERO_EXTEND_VECTOR_INREG: return "zero_extend_vector_inreg";
case ISD::TRUNCATE: return "truncate";
+ case ISD::TRUNCATE_SSAT: return "truncate_ssat";
+ case ISD::TRUNCATE_USAT: return "truncate_usat";
case ISD::FP_ROUND: return "fp_round";
case ISD::STRICT_FP_ROUND: return "strict_fp_round";
case ISD::FP_EXTEND: return "fp_extend";
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index bf031c00a2449..3e855d5e450df 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -718,6 +718,10 @@ void TargetLoweringBase::initActions() {
// Absolute difference
setOperationAction({ISD::ABDS, ISD::ABDU}, VT, Expand);
+ // Saturated trunc
+ setOperationAction(ISD::TRUNCATE_SSAT, VT, Expand);
+ setOperationAction(ISD::TRUNCATE_USAT, VT, Expand);
+
// These default to Expand so they will be expanded to CTLZ/CTTZ by default.
setOperationAction({ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
Expand);
>From c80af0333a0c241dfe6d333e711e7b11be20ea08 Mon Sep 17 00:00:00 2001
From: hanbeom <kese111 at gmail.com>
Date: Tue, 16 Jul 2024 14:05:55 +0900
Subject: [PATCH 2/3] [AArch64] Support saturated truncate
Add support for `ISD::TRUNCATE_[US]SAT`.
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 2 ++
llvm/lib/Target/AArch64/AArch64InstrInfo.td | 16 ++++++++++++++++
2 files changed, 18 insertions(+)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index df9b0ae1a632f..504bbaed1c8aa 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1274,6 +1274,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::AVGCEILU, VT, Legal);
setOperationAction(ISD::ABDS, VT, Legal);
setOperationAction(ISD::ABDU, VT, Legal);
+ setOperationAction(ISD::TRUNCATE_SSAT, VT, Legal);
+ setOperationAction(ISD::TRUNCATE_USAT, VT, Legal);
}
// Vector reductions
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index dd11f74882115..322219607407b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -5343,9 +5343,13 @@ def VImm8000: PatLeaf<(AArch64mvni_msl (i32 127), (i32 264))>;
// trunc(umin(X, 255)) -> UQXTRN v8i8
def : Pat<(v8i8 (trunc (umin (v8i16 V128:$Vn), (v8i16 VImmFF)))),
(UQXTNv8i8 V128:$Vn)>;
+def : Pat<(v8i8 (truncusat (v8i16 V128:$Vn))),
+ (UQXTNv8i8 V128:$Vn)>;
// trunc(umin(X, 65535)) -> UQXTRN v4i16
def : Pat<(v4i16 (trunc (umin (v4i32 V128:$Vn), (v4i32 VImmFFFF)))),
(UQXTNv4i16 V128:$Vn)>;
+def : Pat<(v4i16 (truncusat (v4i32 V128:$Vn))),
+ (UQXTNv4i16 V128:$Vn)>;
// trunc(smin(smax(X, -128), 128)) -> SQXTRN
// with reversed min/max
def : Pat<(v8i8 (trunc (smin (smax (v8i16 V128:$Vn), (v8i16 VImm80)),
@@ -5354,6 +5358,8 @@ def : Pat<(v8i8 (trunc (smin (smax (v8i16 V128:$Vn), (v8i16 VImm80)),
def : Pat<(v8i8 (trunc (smax (smin (v8i16 V128:$Vn), (v8i16 VImm7F)),
(v8i16 VImm80)))),
(SQXTNv8i8 V128:$Vn)>;
+def : Pat<(v8i8 (truncssat (v8i16 V128:$Vn))),
+ (SQXTNv8i8 V128:$Vn)>;
// trunc(smin(smax(X, -32768), 32767)) -> SQXTRN
// with reversed min/max
def : Pat<(v4i16 (trunc (smin (smax (v4i32 V128:$Vn), (v4i32 VImm8000)),
@@ -5362,6 +5368,8 @@ def : Pat<(v4i16 (trunc (smin (smax (v4i32 V128:$Vn), (v4i32 VImm8000)),
def : Pat<(v4i16 (trunc (smax (smin (v4i32 V128:$Vn), (v4i32 VImm7FFF)),
(v4i32 VImm8000)))),
(SQXTNv4i16 V128:$Vn)>;
+def : Pat<(v4i16 (truncssat (v4i32 V128:$Vn))),
+ (SQXTNv4i16 V128:$Vn)>;
// concat_vectors(Vd, trunc(smin(smax Vm, -128), 127) ~> SQXTN2(Vd, Vn)
// with reversed min/max
@@ -5375,6 +5383,10 @@ def : Pat<(v16i8 (concat_vectors
(v8i8 (trunc (smax (smin (v8i16 V128:$Vn), (v8i16 VImm7F)),
(v8i16 VImm80)))))),
(SQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+def : Pat<(v16i8 (concat_vectors
+ (v8i8 V64:$Vd),
+ (v8i8 (truncssat (v8i16 V128:$Vn))))),
+ (SQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
// concat_vectors(Vd, trunc(smin(smax Vm, -32768), 32767) ~> SQXTN2(Vd, Vn)
// with reversed min/max
@@ -5388,6 +5400,10 @@ def : Pat<(v8i16 (concat_vectors
(v4i16 (trunc (smax (smin (v4i32 V128:$Vn), (v4i32 VImm7FFF)),
(v4i32 VImm8000)))))),
(SQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+def : Pat<(v8i16 (concat_vectors
+ (v4i16 V64:$Vd),
+ (v4i16 (truncssat (v4i32 V128:$Vn))))),
+ (SQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
// Select BSWAP vector instructions into REV instructions
def : Pat<(v4i16 (bswap (v4i16 V64:$Rn))),
>From 108a26c9a7a771d23a7ada07172dccabf5e0a89c Mon Sep 17 00:00:00 2001
From: hanbeom <kese111 at gmail.com>
Date: Tue, 16 Jul 2024 14:14:40 +0900
Subject: [PATCH 3/3] [RISCV] Support saturated truncate
Add support for `ISD::TRUNCATE_[US]SAT`.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 40 +++++++++++++++------
llvm/lib/Target/RISCV/RISCVISelLowering.h | 2 ++
2 files changed, 32 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 953196a586b6e..3b54416d1b5b2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -853,7 +853,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
// Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
// nodes which truncate by one power of two at a time.
- setOperationAction(ISD::TRUNCATE, VT, Custom);
+ setOperationAction(
+ {ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT, Custom);
// Custom-lower insert/extract operations to simplify patterns.
setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
@@ -1168,7 +1169,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SELECT, VT, Custom);
- setOperationAction(ISD::TRUNCATE, VT, Custom);
+ setOperationAction(
+ {ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT,
+ Custom);
setOperationAction(ISD::BITCAST, VT, Custom);
@@ -1479,8 +1482,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
if ((Subtarget.hasStdExtZbs() && Subtarget.is64Bit()) ||
- Subtarget.hasStdExtV())
+ Subtarget.hasStdExtV()) {
setTargetDAGCombine(ISD::TRUNCATE);
+ setTargetDAGCombine(ISD::TRUNCATE_SSAT);
+ setTargetDAGCombine(ISD::TRUNCATE_USAT);
+ }
if (Subtarget.hasStdExtZbkb())
setTargetDAGCombine(ISD::BITREVERSE);
@@ -6092,7 +6098,7 @@ static bool hasMergeOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
- 130 &&
+ 132 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -6118,7 +6124,7 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
- 130 &&
+ 132 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -6389,6 +6395,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getNode(RISCVISD::BREV8, DL, VT, BSwap);
}
case ISD::TRUNCATE:
+ case ISD::TRUNCATE_SSAT:
+ case ISD::TRUNCATE_USAT:
// Only custom-lower vector truncates
if (!Op.getSimpleValueType().isVector())
return Op;
@@ -8275,11 +8283,15 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
LLVMContext &Context = *DAG.getContext();
const ElementCount Count = ContainerVT.getVectorElementCount();
+ unsigned NewOpc = RISCVISD::TRUNCATE_VECTOR_VL;
+ if (Op.getOpcode() == ISD::TRUNCATE_SSAT)
+ NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
+ else if (Op.getOpcode() == ISD::TRUNCATE_USAT)
+ NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT;
do {
SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
- Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
- Mask, VL);
+ Result = DAG.getNode(NewOpc, DL, ResultVT, Result, Mask, VL);
} while (SrcEltVT != DstEltVT);
if (SrcVT.isFixedLengthVector())
@@ -16512,7 +16524,9 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
// minimum value.
static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
+ assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL ||
+ N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT ||
+ N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT);
MVT VT = N->getSimpleValueType(0);
@@ -16617,9 +16631,11 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
SDValue Val;
unsigned ClipOpc;
- if ((Val = DetectUSatPattern(Src)))
+
+ Val = N->getOperand(0);
+ if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT)
ClipOpc = RISCVISD::VNCLIPU_VL;
- else if ((Val = DetectSSatPattern(Src)))
+ else if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT)
ClipOpc = RISCVISD::VNCLIP_VL;
else
return SDValue();
@@ -16857,6 +16873,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
return SDValue();
case RISCVISD::TRUNCATE_VECTOR_VL:
+ case RISCVISD::TRUNCATE_VECTOR_VL_SSAT:
+ case RISCVISD::TRUNCATE_VECTOR_VL_USAT:
if (SDValue V = combineTruncOfSraSext(N, DAG))
return V;
return combineTruncToVnclip(N, DAG, Subtarget);
@@ -20433,6 +20451,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
NODE_NAME_CASE(READ_VLENB)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
+ NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT)
+ NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT)
NODE_NAME_CASE(VSLIDEUP_VL)
NODE_NAME_CASE(VSLIDE1UP_VL)
NODE_NAME_CASE(VSLIDEDOWN_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 0b0ad9229f0b3..3d582fcdaf64b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -181,6 +181,8 @@ enum NodeType : unsigned {
// Truncates a RVV integer vector by one power-of-two. Carries both an extra
// mask and VL operand.
TRUNCATE_VECTOR_VL,
+ TRUNCATE_VECTOR_VL_SSAT,
+ TRUNCATE_VECTOR_VL_USAT,
// Matches the semantics of vslideup/vslidedown. The first operand is the
// pass-thru operand, the second is the source vector, the third is the XLenVT
// index (either constant or non-constant), the fourth is the mask, the fifth
More information about the llvm-commits
mailing list