[llvm] [DAG] Support saturated truncate (PR #99418)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 25 22:45:20 PDT 2024
https://github.com/ParkHanbum updated https://github.com/llvm/llvm-project/pull/99418
>From e3baff292e25990a1db3b3b8e41b9952674c7441 Mon Sep 17 00:00:00 2001
From: hanbeom <kese111 at gmail.com>
Date: Tue, 16 Jul 2024 13:52:29 +0900
Subject: [PATCH 1/3] [DAG] Support saturated truncate
A `truncate` is saturated when its operand has already been clamped (with
smin/smax/umin) to the value range of the result type, so dropping the high
bits does not change the clamped value. When a vector truncate is saturated,
targets can often lower it to a single saturating-narrow instruction.
Previously, each target matched these patterns with its own optimization, so
the logic was duplicated across targets.
This patch implements the matching once in generic code by adding the
`ISD::TRUNCATE_[SU]SAT_[SU]` nodes to represent saturated truncates.
Fixes #85903
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 7 +
.../include/llvm/Target/TargetSelectionDAG.td | 3 +
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 165 +++++++++++++++++-
.../SelectionDAG/SelectionDAGDumper.cpp | 3 +
llvm/lib/CodeGen/TargetLoweringBase.cpp | 5 +
5 files changed, 182 insertions(+), 1 deletion(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 5b657fb171296..f67c6ee3d55dc 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -814,6 +814,13 @@ enum NodeType {
/// TRUNCATE - Completely drop the high bits.
TRUNCATE,
+ /// TRUNCATE_[SU]SAT_[SU] - Truncate with saturation. The first [SU] is the
+ /// signedness of the input, the second the signedness of the result:
+ /// input values outside the result type's range are clamped to that range
+ /// instead of having their high bits dropped.
+ TRUNCATE_SSAT_S, // saturate signed input to signed result -
+ // truncate(smin(smax(x, signed_min_of_dest), signed_max_of_dest))
+ TRUNCATE_SSAT_U, // saturate signed input to unsigned result -
+ // truncate(smin(smax(x, 0), unsigned_max_of_dest))
+ TRUNCATE_USAT_U, // saturate unsigned input to unsigned result -
+ // truncate(umin(x, unsigned_max_of_dest))
/// [SU]INT_TO_FP - These operators convert integers (whose interpreted sign
/// depends on the first letter) to floating point.
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 46044aab79a83..92d10a94bd81e 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -477,6 +477,9 @@ def sext : SDNode<"ISD::SIGN_EXTEND", SDTIntExtendOp>;
def zext : SDNode<"ISD::ZERO_EXTEND", SDTIntExtendOp>;
def anyext : SDNode<"ISD::ANY_EXTEND" , SDTIntExtendOp>;
def trunc : SDNode<"ISD::TRUNCATE" , SDTIntTruncOp>;
+// Saturating truncates (ISD::TRUNCATE_[SU]SAT_[SU]). The suffixes give the
+// signedness of the input and of the result, respectively.
+def truncssat_s : SDNode<"ISD::TRUNCATE_SSAT_S", SDTIntTruncOp>;
+def truncssat_u : SDNode<"ISD::TRUNCATE_SSAT_U", SDTIntTruncOp>;
+def truncusat_u : SDNode<"ISD::TRUNCATE_USAT_U", SDTIntTruncOp>;
def bitconvert : SDNode<"ISD::BITCAST" , SDTUnaryOp>;
def addrspacecast : SDNode<"ISD::ADDRSPACECAST", SDTUnaryOp>;
def freeze : SDNode<"ISD::FREEZE" , SDTFreeze>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 060e66175d965..8840aa7be2a5d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -486,6 +486,7 @@ namespace {
SDValue visitSIGN_EXTEND_INREG(SDNode *N);
SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
SDValue visitTRUNCATE(SDNode *N);
+ SDValue visitTRUNCATE_USAT(SDNode *N);
SDValue visitBITCAST(SDNode *N);
SDValue visitFREEZE(SDNode *N);
SDValue visitBUILD_PAIR(SDNode *N);
@@ -1908,6 +1909,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::ZERO_EXTEND_VECTOR_INREG:
case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
case ISD::TRUNCATE: return visitTRUNCATE(N);
+ case ISD::TRUNCATE_USAT_U:
+ case ISD::TRUNCATE_SSAT_U: return visitTRUNCATE_USAT(N);
case ISD::BITCAST: return visitBITCAST(N);
case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
case ISD::FADD: return visitFADD(N);
@@ -13203,7 +13206,9 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
unsigned CastOpcode = Cast->getOpcode();
assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
- CastOpcode == ISD::FP_ROUND) &&
+ CastOpcode == ISD::TRUNCATE_SSAT_S ||
+ CastOpcode == ISD::TRUNCATE_SSAT_U ||
+ CastOpcode == ISD::TRUNCATE_USAT_U || CastOpcode == ISD::FP_ROUND) &&
"Unexpected opcode for vector select narrowing/widening");
// We only do this transform before legal ops because the pattern may be
@@ -14915,6 +14920,159 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
return SDValue();
}
+/// Fold a saturating truncate of an FP-to-int conversion into FP_TO_UINT_SAT
+/// when both saturate identically:
+///   truncate_usat_u (fp_to_uint x)              -> fp_to_uint_sat x
+///   truncate_ssat_u (fp_to_sint x)              -> fp_to_uint_sat x
+///   truncate_ssat_u (smax (fp_to_sint x), 0)    -> fp_to_uint_sat x
+/// fp_to_uint_sat clamps negative inputs to 0 and large ones to the unsigned
+/// max. Mixing signedness would break that: truncate_usat_u reads a negative
+/// fp_to_sint result as a huge unsigned value and clamps it to the unsigned
+/// max, and truncate_ssat_u reads a large fp_to_uint result as negative and
+/// clamps it to 0. An smax with a non-zero bound raises the lower limit above
+/// 0, which fp_to_uint_sat cannot express either.
+SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) {
+  EVT VT = N->getValueType(0);
+  SDValue N0 = N->getOperand(0);
+  bool IsSignedInput = N->getOpcode() == ISD::TRUNCATE_SSAT_U;
+
+  SDValue FPInstr = N0;
+  // smax(v, 0) is redundant under a signed->unsigned saturation.
+  if (IsSignedInput && N0.getOpcode() == ISD::SMAX &&
+      isNullOrNullSplat(N0.getOperand(1)))
+    FPInstr = N0.getOperand(0);
+
+  // The conversion's signedness must match how the truncate reads its input.
+  unsigned ExpectedOpc = IsSignedInput ? ISD::FP_TO_SINT : ISD::FP_TO_UINT;
+  if (FPInstr.getOpcode() != ExpectedOpc)
+    return SDValue();
+
+  EVT FPVT = FPInstr.getOperand(0).getValueType();
+  if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
+                                                        FPVT, VT))
+    return SDValue();
+  return DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT,
+                     FPInstr.getOperand(0),
+                     DAG.getValueType(VT.getScalarType()));
+}
+
+/// Detect patterns of truncation with unsigned saturation:
+///
+/// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
+///
+/// Return the source value x to be truncated or SDValue() if the pattern was
+/// not matched.
+static SDValue detectUSatUPattern(SDValue In, EVT VT) {
+  unsigned NumDstBits = VT.getScalarSizeInBits();
+
+  // Saturation with truncation: the source elements must be wider.
+  assert(In.getScalarValueSizeInBits() > NumDstBits &&
+         "Unexpected types for truncate operation");
+
+  // The umin bound must be a splat of the destination element's unsigned
+  // maximum (UINT8_MAX / UINT16_MAX / UINT32_MAX), i.e. a mask of the low
+  // NumDstBits bits.
+  APInt UMinC;
+  if (In.getOpcode() == ISD::UMIN &&
+      ISD::isConstantSplatVector(In.getOperand(1).getNode(), UMinC) &&
+      UMinC.isMask(NumDstBits))
+    return In.getOperand(0);
+
+  return SDValue();
+}
+
+/// Detect patterns of truncation with signed saturation:
+/// (truncate (smin ((smax (x, signed_min_of_dest_type)),
+///                  signed_max_of_dest_type)) to dest_type)
+/// or:
+/// (truncate (smax ((smin (x, signed_max_of_dest_type)),
+///                  signed_min_of_dest_type)) to dest_type).
+/// Both bounds must be exactly the destination type's signed limits,
+/// sign-extended to the source element width.
+/// Return the source value x to be truncated or SDValue() if the pattern was
+/// not matched.
+static SDValue detectSSatSPattern(SDValue In, EVT VT) {
+  unsigned NumDstBits = VT.getScalarSizeInBits();
+  unsigned NumSrcBits = In.getScalarValueSizeInBits();
+  assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
+
+  // Return V's first operand if V is (Opcode _, splat(Limit)).
+  auto MatchMinMax = [](SDValue V, unsigned Opcode,
+                        const APInt &Limit) -> SDValue {
+    APInt C;
+    if (V.getOpcode() == Opcode &&
+        ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
+      return V.getOperand(0);
+    return SDValue();
+  };
+
+  const APInt SignedMax =
+      APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
+  const APInt SignedMin =
+      APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
+
+  // smin(smax(x, SignedMin), SignedMax)
+  if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax))
+    if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin))
+      return SMax;
+
+  // smax(smin(x, SignedMax), SignedMin)
+  if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin))
+    if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax))
+      return SMin;
+
+  return SDValue();
+}
+
+/// Detect patterns of truncation of a signed value with unsigned saturation:
+///
+/// (truncate (smin (smax (x, C1), C2)) to dest_type),
+/// where C1 >= 0 and C2 is unsigned max of destination type.
+///
+/// (truncate (smax (smin (x, C2), C1)) to dest_type)
+/// where C1 >= 0, C2 is unsigned max of destination type and C1 <= C2.
+///
+/// On a match, return (smax x, C1) rather than x: it keeps the lower bound
+/// C1, and saturating it to [0, C2] gives the same result as the original
+/// clamp. Return SDValue() if the pattern was not matched.
+static SDValue detectSSatUPattern(SDValue In, EVT VT, SelectionDAG &DAG,
+                                  const SDLoc &DL) {
+  EVT InVT = In.getValueType();
+  unsigned NumDstBits = VT.getScalarSizeInBits();
+
+  // Saturation with truncation. We truncate from InVT to VT.
+  assert(InVT.getScalarSizeInBits() > NumDstBits &&
+         "Unexpected types for truncate operation");
+
+  // Match (Opcode _, splat C), store C into Limit and return the first operand.
+  auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
+    if (V.getOpcode() == Opcode &&
+        ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
+      return V.getOperand(0);
+    return SDValue();
+  };
+
+  APInt C1, C2;
+  // smin(smax(x, C1), C2): the inner node already is (smax x, C1).
+  if (SDValue InnerSMax = MatchMinMax(In, ISD::SMIN, C2))
+    if (MatchMinMax(InnerSMax, ISD::SMAX, C1))
+      if (C1.isNonNegative() && C2.isMask(NumDstBits))
+        return InnerSMax;
+
+  // smax(smin(x, C2), C1): rebuild (smax x, C1) around the unclamped x.
+  // C1 <= C2 is required here, otherwise the result would be C1 > C2.
+  if (SDValue InnerSMin = MatchMinMax(In, ISD::SMAX, C1))
+    if (SDValue X = MatchMinMax(InnerSMin, ISD::SMIN, C2))
+      if (C1.isNonNegative() && C2.isMask(NumDstBits) && C2.uge(C1))
+        return DAG.getNode(ISD::SMAX, DL, InVT, X, In.getOperand(1));
+
+  return SDValue();
+}
+
+/// Try to turn (truncate Src) into one of the ISD::TRUNCATE_[SU]SAT_[SU]
+/// nodes when Src clamps its value to the range of the result type VT and the
+/// target supports the matching saturating truncate on SrcVT.
+/// Candidates are tried independently: a failed signed->signed match must not
+/// prevent trying the signed->unsigned form (e.g. smin(smax(x, 0), 255)).
+/// \p N is currently unused.
+static SDValue foldToSaturated(SDNode *N, EVT &VT, SDValue &Src, EVT &SrcVT,
+                               SDLoc &DL, const TargetLowering &TLI,
+                               SelectionDAG &DAG) {
+  // The node must be supported on the source type and the target must
+  // consider the result type worthwhile for it.
+  auto AllowedTruncateSat = [&](unsigned Opc) {
+    return TLI.isOperationLegalOrCustom(Opc, SrcVT) &&
+           TLI.isTypeDesirableForOp(Opc, VT);
+  };
+
+  if (Src.getOpcode() == ISD::SMIN || Src.getOpcode() == ISD::SMAX) {
+    if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_S))
+      if (SDValue SSatVal = detectSSatSPattern(Src, VT))
+        return DAG.getNode(ISD::TRUNCATE_SSAT_S, DL, VT, SSatVal);
+    // Signed input clamped to [C1 >= 0, unsigned max]: this must produce
+    // TRUNCATE_SSAT_U. TRUNCATE_SSAT_S would clamp to the signed max instead.
+    if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U))
+      if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
+        return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal);
+  } else if (Src.getOpcode() == ISD::UMIN) {
+    if (AllowedTruncateSat(ISD::TRUNCATE_USAT_U))
+      if (SDValue USatVal = detectUSatUPattern(Src, VT))
+        return DAG.getNode(ISD::TRUNCATE_USAT_U, DL, VT, USatVal);
+  }
+
+  return SDValue();
+}
+
SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
@@ -14930,6 +15088,11 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
if (N0.getOpcode() == ISD::TRUNCATE)
return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
+ // fold (truncate (clamp x)) -> (truncate_[su]sat_[su] x) when the clamp
+ // bounds match the result type and the target supports the node.
+ if (SDValue SaturatedTR = foldToSaturated(N, VT, N0, SrcVT, DL, TLI, DAG)) {
+ return SaturatedTR;
+ }
+
// fold (truncate c1) -> c1
if (SDValue C = DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, VT, {N0}))
return C;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 16fc52caebb75..46e8e54ee4ed7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -380,6 +380,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::SIGN_EXTEND_VECTOR_INREG: return "sign_extend_vector_inreg";
case ISD::ZERO_EXTEND_VECTOR_INREG: return "zero_extend_vector_inreg";
case ISD::TRUNCATE: return "truncate";
+ case ISD::TRUNCATE_SSAT_S: return "truncate_ssat_s";
+ case ISD::TRUNCATE_SSAT_U: return "truncate_ssat_u";
+ case ISD::TRUNCATE_USAT_U: return "truncate_usat_u";
case ISD::FP_ROUND: return "fp_round";
case ISD::STRICT_FP_ROUND: return "strict_fp_round";
case ISD::FP_EXTEND: return "fp_extend";
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index 6ca9955993d24..149b5dabee056 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -753,6 +753,11 @@ void TargetLoweringBase::initActions() {
// Absolute difference
setOperationAction({ISD::ABDS, ISD::ABDU}, VT, Expand);
+ // Saturated trunc
+ setOperationAction(ISD::TRUNCATE_SSAT_S, VT, Expand);
+ setOperationAction(ISD::TRUNCATE_SSAT_U, VT, Expand);
+ setOperationAction(ISD::TRUNCATE_USAT_U, VT, Expand);
+
// These default to Expand so they will be expanded to CTLZ/CTTZ by default.
setOperationAction({ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
Expand);
>From e53f9bde236d0a87b7a00799d6a654ae7cbb9914 Mon Sep 17 00:00:00 2001
From: hanbeom <kese111 at gmail.com>
Date: Tue, 16 Jul 2024 14:05:55 +0900
Subject: [PATCH 2/3] [AArch64] Support saturated truncate
Add support for `ISD::TRUNCATE_[US]SAT`.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 18 ++++++++
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 5 +++
llvm/lib/Target/AArch64/AArch64InstrInfo.td | 44 +++++++++----------
llvm/test/CodeGen/AArch64/qmovn.ll | 12 ++---
4 files changed, 49 insertions(+), 30 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d86e52d49000a..c42dc9d4fc3b2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1410,6 +1410,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
}
}
+ for (MVT VT : {MVT::v8i16, MVT::v4i32}) {
+ setOperationAction(ISD::TRUNCATE_SSAT_S, VT, Custom);
+ setOperationAction(ISD::TRUNCATE_SSAT_U, VT, Custom);
+ setOperationAction(ISD::TRUNCATE_USAT_U, VT, Custom);
+ }
+
if (Subtarget->hasSME()) {
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
}
@@ -28730,6 +28736,18 @@ bool AArch64TargetLowering::hasInlineStackProbe(
MF.getInfo<AArch64FunctionInfo>()->hasStackProbing();
}
+bool AArch64TargetLowering::isTypeDesirableForOp(unsigned Opc, EVT VT) const {
+  // Saturating truncates are only worthwhile for the 64-bit NEON results
+  // that have SQXTN/SQXTUN/UQXTN patterns (v8i8 and v4i16).
+  const bool IsSaturatingTrunc = Opc == ISD::TRUNCATE_SSAT_S ||
+                                 Opc == ISD::TRUNCATE_SSAT_U ||
+                                 Opc == ISD::TRUNCATE_USAT_U;
+  const bool HasNarrowingInsn = VT == MVT::v8i8 || VT == MVT::v4i16;
+  if (IsSaturatingTrunc && HasNarrowingInsn)
+    return true;
+
+  // Everything else keeps the generic answer.
+  return TargetLowering::isTypeDesirableForOp(Opc, VT);
+}
+
#ifndef NDEBUG
void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
switch (N->getOpcode()) {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 81e15185f985d..50e26612ac863 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -743,6 +743,11 @@ class AArch64TargetLowering : public TargetLowering {
bool generateFMAsInMachineCombiner(EVT VT,
CodeGenOptLevel OptLevel) const override;
+ /// Return true if the target has native support for
+ /// the specified value type and it is 'desirable' to use the type for the
+ /// given node type.
+ bool isTypeDesirableForOp(unsigned Opc, EVT VT) const override;
+
const MCPhysReg *getScratchRegisters(CallingConv::ID CC) const override;
ArrayRef<MCPhysReg> getRoundingControlRegisters() const override;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 1053ba9242768..ac42f9cb6eb63 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -5418,64 +5418,60 @@ def VImm7FFF: PatLeaf<(AArch64movi_msl (i32 127), (i32 264))>;
def VImm8000: PatLeaf<(AArch64mvni_msl (i32 127), (i32 264))>;
// trunc(umin(X, 255)) -> UQXTRN v8i8
-def : Pat<(v8i8 (trunc (umin (v8i16 V128:$Vn), (v8i16 VImmFF)))),
+def : Pat<(v8i8 (truncusat_u (v8i16 V128:$Vn))),
(UQXTNv8i8 V128:$Vn)>;
// trunc(umin(X, 65535)) -> UQXTRN v4i16
-def : Pat<(v4i16 (trunc (umin (v4i32 V128:$Vn), (v4i32 VImmFFFF)))),
+def : Pat<(v4i16 (truncusat_u (v4i32 V128:$Vn))),
(UQXTNv4i16 V128:$Vn)>;
// trunc(smin(smax(X, -128), 128)) -> SQXTRN
// with reversed min/max
-def : Pat<(v8i8 (trunc (smin (smax (v8i16 V128:$Vn), (v8i16 VImm80)),
- (v8i16 VImm7F)))),
- (SQXTNv8i8 V128:$Vn)>;
-def : Pat<(v8i8 (trunc (smax (smin (v8i16 V128:$Vn), (v8i16 VImm7F)),
- (v8i16 VImm80)))),
+def : Pat<(v8i8 (truncssat_s (v8i16 V128:$Vn))),
(SQXTNv8i8 V128:$Vn)>;
+// trunc(smin(smax(X, 0), 255)) -> SQXTUN
+def : Pat<(v8i8 (truncssat_u (v8i16 V128:$Vn))),
+ (SQXTUNv8i8 V128:$Vn)>;
// trunc(smin(smax(X, -32768), 32767)) -> SQXTRN
// with reversed min/max
-def : Pat<(v4i16 (trunc (smin (smax (v4i32 V128:$Vn), (v4i32 VImm8000)),
- (v4i32 VImm7FFF)))),
- (SQXTNv4i16 V128:$Vn)>;
-def : Pat<(v4i16 (trunc (smax (smin (v4i32 V128:$Vn), (v4i32 VImm7FFF)),
- (v4i32 VImm8000)))),
+def : Pat<(v4i16 (truncssat_s (v4i32 V128:$Vn))),
(SQXTNv4i16 V128:$Vn)>;
+// trunc(smin(smax(X, 0), 65535)) -> SQXTUN
+def : Pat<(v4i16 (truncssat_u (v4i32 V128:$Vn))),
+ (SQXTUNv4i16 V128:$Vn)>;
// concat_vectors(Vd, trunc(umin(X, 255))) -> UQXTRN(Vd, Vn)
def : Pat<(v16i8 (concat_vectors
(v8i8 V64:$Vd),
- (v8i8 (trunc (umin (v8i16 V128:$Vn), (v8i16 VImmFF)))))),
+ (v8i8 (truncusat_u (v8i16 V128:$Vn))))),
(UQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
// concat_vectors(Vd, trunc(umin(X, 65535))) -> UQXTRN(Vd, Vn)
def : Pat<(v8i16 (concat_vectors
(v4i16 V64:$Vd),
- (v4i16 (trunc (umin (v4i32 V128:$Vn), (v4i32 VImmFFFF)))))),
+ (v4i16 (truncusat_u (v4i32 V128:$Vn))))),
(UQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
// concat_vectors(Vd, trunc(smin(smax Vm, -128), 127) ~> SQXTN2(Vd, Vn)
// with reversed min/max
def : Pat<(v16i8 (concat_vectors
(v8i8 V64:$Vd),
- (v8i8 (trunc (smin (smax (v8i16 V128:$Vn), (v8i16 VImm80)),
- (v8i16 VImm7F)))))),
+ (v8i8 (truncssat_s (v8i16 V128:$Vn))))),
(SQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+// concat_vectors(Vd, trunc(smin(smax Vm, 0), 255) ~> SQXTUN2(Vd, Vn)
def : Pat<(v16i8 (concat_vectors
(v8i8 V64:$Vd),
- (v8i8 (trunc (smax (smin (v8i16 V128:$Vn), (v8i16 VImm7F)),
- (v8i16 VImm80)))))),
- (SQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+ (v8i8 (truncssat_u (v8i16 V128:$Vn))))),
+ (SQXTUNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
// concat_vectors(Vd, trunc(smin(smax Vm, -32768), 32767) ~> SQXTN2(Vd, Vn)
// with reversed min/max
def : Pat<(v8i16 (concat_vectors
(v4i16 V64:$Vd),
- (v4i16 (trunc (smin (smax (v4i32 V128:$Vn), (v4i32 VImm8000)),
- (v4i32 VImm7FFF)))))),
+ (v4i16 (truncssat_s (v4i32 V128:$Vn))))),
(SQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+// concat_vectors(Vd, trunc(smin(smax Vm, 0), 65535) ~> SQXTUN2(Vd, Vn)
def : Pat<(v8i16 (concat_vectors
(v4i16 V64:$Vd),
- (v4i16 (trunc (smax (smin (v4i32 V128:$Vn), (v4i32 VImm7FFF)),
- (v4i32 VImm8000)))))),
- (SQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+ (v4i16 (truncssat_u (v4i32 V128:$Vn))))),
+ (SQXTUNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
// Select BSWAP vector instructions into REV instructions
def : Pat<(v4i16 (bswap (v4i16 V64:$Rn))),
diff --git a/llvm/test/CodeGen/AArch64/qmovn.ll b/llvm/test/CodeGen/AArch64/qmovn.ll
index 35c172adbad3d..0b19a9ff7e3dd 100644
--- a/llvm/test/CodeGen/AArch64/qmovn.ll
+++ b/llvm/test/CodeGen/AArch64/qmovn.ll
@@ -292,15 +292,15 @@ entry:
; Test the (concat_vectors (X), (trunc(umin(smax(Y, 0), 2^n))))) pattern.
+; TODO: %min is in [0, 255], so this truncate is saturated and is now selected
+; as uqxtn2. The preceding smax could also be folded away by using sqxtun2.
define <16 x i8> @us_maxmin_v8i16_to_v16i8(<8 x i8> %x, <8 x i16> %y) {
; CHECK-LABEL: us_maxmin_v8i16_to_v16i8:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: movi v2.2d, #0000000000000000
-; CHECK-NEXT: movi v3.2d, #0xff00ff00ff00ff
; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NEXT: smax v1.8h, v1.8h, v2.8h
-; CHECK-NEXT: smin v1.8h, v1.8h, v3.8h
-; CHECK-NEXT: xtn2 v0.16b, v1.8h
+; CHECK-NEXT: uqxtn2 v0.16b, v1.8h
; CHECK-NEXT: ret
entry:
%max = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %y, <8 x i16> zeroinitializer)
@@ -310,15 +310,15 @@ entry:
ret <16 x i8> %shuffle
}
+; TODO: %min is in [0, 65535], so this truncate is saturated and is now selected
+; as uqxtn2. The preceding smax could also be folded away by using sqxtun2.
define <8 x i16> @us_maxmin_v4i32_to_v8i16(<4 x i16> %x, <4 x i32> %y) {
; CHECK-LABEL: us_maxmin_v4i32_to_v8i16:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: movi v2.2d, #0000000000000000
; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NEXT: smax v1.4s, v1.4s, v2.4s
-; CHECK-NEXT: movi v2.2d, #0x00ffff0000ffff
-; CHECK-NEXT: smin v1.4s, v1.4s, v2.4s
-; CHECK-NEXT: xtn2 v0.8h, v1.4s
+; CHECK-NEXT: uqxtn2 v0.8h, v1.4s
; CHECK-NEXT: ret
entry:
%max = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %y, <4 x i32> zeroinitializer)
>From 5523f7ae23ab40249ab59d467afe8d9cc025b493 Mon Sep 17 00:00:00 2001
From: hanbeom <kese111 at gmail.com>
Date: Tue, 16 Jul 2024 14:14:40 +0900
Subject: [PATCH 3/3] [RISCV] Support saturated truncate
Add support for `ISD::TRUNCATE_[US]SAT`.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 25 ++++++++++++++++-----
1 file changed, 20 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d40d4997d7614..704caeab90bb6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -853,7 +853,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
// Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
// nodes which truncate by one power of two at a time.
- setOperationAction(ISD::TRUNCATE, VT, Custom);
+ setOperationAction({ISD::TRUNCATE, ISD::TRUNCATE_SSAT_S,
+ ISD::TRUNCATE_SSAT_U, ISD::TRUNCATE_USAT_U},
+ VT, Custom);
// Custom-lower insert/extract operations to simplify patterns.
setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
@@ -1168,7 +1170,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SELECT, VT, Custom);
- setOperationAction(ISD::TRUNCATE, VT, Custom);
+ setOperationAction({ISD::TRUNCATE, ISD::TRUNCATE_SSAT_S,
+ ISD::TRUNCATE_SSAT_U, ISD::TRUNCATE_USAT_U},
+ VT, Custom);
setOperationAction(ISD::BITCAST, VT, Custom);
@@ -6395,6 +6399,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getNode(RISCVISD::BREV8, DL, VT, BSwap);
}
case ISD::TRUNCATE:
+ case ISD::TRUNCATE_SSAT_S:
+ case ISD::TRUNCATE_SSAT_U:
+ case ISD::TRUNCATE_USAT_U:
// Only custom-lower vector truncates
if (!Op.getSimpleValueType().isVector())
return Op;
@@ -8234,7 +8241,8 @@ SDValue RISCVTargetLowering::lowerVectorMaskTruncLike(SDValue Op,
SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
SelectionDAG &DAG) const {
- bool IsVPTrunc = Op.getOpcode() == ISD::VP_TRUNCATE;
+ unsigned Opc = Op.getOpcode();
+ bool IsVPTrunc = Opc == ISD::VP_TRUNCATE;
SDLoc DL(Op);
MVT VT = Op.getSimpleValueType();
@@ -8279,11 +8287,18 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
}
+ // Select the node used for each halving step of the truncate.
+ // NOTE(review): TRUNCATE_SSAT_U (signed input) shares the USAT node with
+ // TRUNCATE_USAT_U. If that node is an unsigned clip (presumably vnclipu),
+ // this is only correct when the input is known non-negative. DAGCombiner's
+ // detectSSatUPattern feeds it smax(x, C1 >= 0), which satisfies this.
+ // Confirm no other producer emits TRUNCATE_SSAT_U with negative inputs.
+ unsigned NewOpc;
+ if (Opc == ISD::TRUNCATE_SSAT_S)
+ NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
+ else if (Opc == ISD::TRUNCATE_SSAT_U || Opc == ISD::TRUNCATE_USAT_U)
+ NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT;
+ else
+ NewOpc = RISCVISD::TRUNCATE_VECTOR_VL;
+
do {
SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
MVT ResultVT = ContainerVT.changeVectorElementType(SrcEltVT);
- Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
- Mask, VL);
+ Result = DAG.getNode(NewOpc, DL, ResultVT, Result, Mask, VL);
} while (SrcEltVT != DstEltVT);
if (SrcVT.isFixedLengthVector())
More information about the llvm-commits
mailing list