[llvm] [DAG] Support saturated truncate (PR #99418)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 6 01:54:34 PDT 2024
================
@@ -14915,6 +14920,181 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
return SDValue();
}
+/// Combine a TRUNCATE_USAT node whose input comes from an FP-to-integer
+/// conversion into a single FP_TO_UINT_SAT node.
+///
+/// The operand is accepted when it is itself an FP_TO_SINT/FP_TO_UINT, or
+/// when it is an SMAX that has such a conversion as one of its operands
+/// (operand 0 is tried first). The rewrite only happens if the target's
+/// shouldConvertFpToSat hook approves FP_TO_UINT_SAT for the FP source type
+/// and the result type. Returns an empty SDValue when nothing was combined.
+SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) {
+  EVT DstVT = N->getValueType(0);
+  SDValue Src = N->getOperand(0);
+
+  // True for either flavour of FP -> integer conversion.
+  auto IsFPToInt = [](SDValue V) {
+    unsigned Opc = V.getOpcode();
+    return Opc == ISD::FP_TO_SINT || Opc == ISD::FP_TO_UINT;
+  };
+
+  // Locate the conversion node, looking through one SMAX if present.
+  // NOTE(review): the other SMAX operand is not inspected here - confirm that
+  // dropping it is intended for every clamp value.
+  SDValue Conv;
+  if (Src.getOpcode() != ISD::SMAX) {
+    if (IsFPToInt(Src))
+      Conv = Src;
+  } else if (IsFPToInt(Src.getOperand(0))) {
+    Conv = Src.getOperand(0);
+  } else if (IsFPToInt(Src.getOperand(1))) {
+    Conv = Src.getOperand(1);
+  }
+
+  if (!Conv)
+    return SDValue();
+
+  // Let the target veto the saturating conversion for this type pair.
+  SDValue FPSrc = Conv.getOperand(0);
+  const auto &TLI = DAG.getTargetLoweringInfo();
+  if (!TLI.shouldConvertFpToSat(ISD::FP_TO_UINT_SAT, FPSrc.getValueType(),
+                                DstVT))
+    return SDValue();
+
+  return DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(Conv), DstVT, FPSrc,
+                     DAG.getValueType(DstVT.getScalarType()));
+}
+
+// Match a min/max node with the given Opcode whose second operand is a
+// constant splat, and return the node's first operand on success (an empty
+// SDValue otherwise).
+//
+// The Signed flag selects how Limit is used:
+//  - Signed == true:  Limit is an input; the splat constant must equal it.
+//  - Signed == false: Limit is an output; it receives the splat constant.
+// Only operand 1 is examined for the constant.
+static SDValue matchMinMax(SDValue V, unsigned Opcode, APInt &Limit,
+                           bool Signed) {
+  if (V.getOpcode() != Opcode)
+    return SDValue();
+
+  const auto *RHS = V.getOperand(1).getNode();
+
+  // Capture mode: hand the splat value back to the caller through Limit.
+  if (!Signed)
+    return ISD::isConstantSplatVector(RHS, Limit) ? V.getOperand(0)
+                                                  : SDValue();
+
+  // Compare mode: the splat value has to be exactly the requested Limit.
+  APInt Splat;
+  if (!ISD::isConstantSplatVector(RHS, Splat) || Splat != Limit)
+    return SDValue();
+  return V.getOperand(0);
+}
+
+/// Recognise an unsigned-saturating truncate written as an explicit clamp:
+///
+///   (truncate (umin x, unsigned_max_of_dest_type)) to dest_type
+///
+/// \p In is the value feeding the truncate and \p VT is the narrower
+/// destination type. Returns x when the pattern matches, otherwise an empty
+/// SDValue.
+static SDValue detectUSatUPattern(SDValue In, EVT VT) {
+  const auto DstBits = VT.getScalarSizeInBits();
+
+  // Saturation with truncation: the input must be wider than the result.
+  assert(In.getValueType().getScalarSizeInBits() > DstBits &&
+         "Unexpected types for truncate operation");
+
+  // The clamp constant must have exactly the low DstBits bits set, i.e. be
+  // the unsigned maximum of the destination element type.
+  APInt Clamp;
+  SDValue X = matchMinMax(In, ISD::UMIN, Clamp, /*Signed*/ false);
+  if (X && Clamp.isMask(DstBits))
+    return X;
+
+  return SDValue();
+}
+
+/// Recognise a signed-saturating truncate written as an explicit clamp, in
+/// either nesting order:
+///
+///   (truncate (smin (smax x, signed_min_of_dest_type),
+///                   signed_max_of_dest_type)) to dest_type
+///   (truncate (smax (smin x, signed_max_of_dest_type),
+///                   signed_min_of_dest_type)) to dest_type
+///
+/// \p In is the value feeding the truncate and \p VT is the narrower
+/// destination type. Returns x when the pattern matches, otherwise an empty
+/// SDValue.
+static SDValue detectSSatSPattern(SDValue In, EVT VT) {
+  unsigned DstBits = VT.getScalarSizeInBits();
+  unsigned SrcBits = In.getScalarValueSizeInBits();
+  assert(SrcBits > DstBits && "Unexpected types for truncate operation");
+
+  // Destination-type signed bounds, sign-extended to the source width so
+  // they can be compared with the clamp constants.
+  APInt Hi = APInt::getSignedMaxValue(DstBits).sext(SrcBits);
+  APInt Lo = APInt::getSignedMinValue(DstBits).sext(SrcBits);
+
+  // Peel an outer clamp off In, then an inner clamp off the result; both
+  // constants must equal the given bounds exactly.
+  auto MatchClamp = [&](unsigned OuterOpc, APInt &OuterBound,
+                        unsigned InnerOpc, APInt &InnerBound) -> SDValue {
+    SDValue Inner = matchMinMax(In, OuterOpc, OuterBound, /*Signed*/ true);
+    if (!Inner)
+      return SDValue();
+    return matchMinMax(Inner, InnerOpc, InnerBound, /*Signed*/ true);
+  };
+
+  // smin(smax(x, Lo), Hi)
+  if (SDValue X = MatchClamp(ISD::SMIN, Hi, ISD::SMAX, Lo))
+    return X;
+
+  // smax(smin(x, Hi), Lo)
+  return MatchClamp(ISD::SMAX, Lo, ISD::SMIN, Hi);
+}
+
+/// Detect patterns of truncation with unsigned saturation:
+static SDValue detectSSatUPattern(SDValue In, EVT VT, SelectionDAG &DAG,
+ const SDLoc &DL) {
+ EVT InVT = In.getValueType();
+
+ // Saturation with truncation. We truncate from InVT to VT.
+ assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
+ "Unexpected types for truncate operation");
+
+ APInt Min, Max;
+ SDValue SMax, SMin, UMin, First;
+ // (truncate (smin (smax (x, Min), Max)) to dest_type),
+ // (truncate (smax (smin (x, Max), Min)) to dest_type)
+ // (truncate (umin (smax (x, Min), Max)) to dest_type)
+ // where Min >= 0, Max is unsigned max of destination type and Min <= Max.
+ if (First = SMax = matchMinMax(In, ISD::SMAX, Min, /*Signed*/ false))
+ SMin = matchMinMax(First, ISD::SMIN, Max, /*Signed*/ false);
+ else if (First = SMin = matchMinMax(In, ISD::SMIN, Max, /*Signed*/ false))
+ SMax = matchMinMax(First, ISD::SMAX, Min, /*Signed*/ false);
+ else if (First = UMin = matchMinMax(In, ISD::UMIN, Max, /*Signed*/ false))
+ SMax = matchMinMax(UMin, ISD::SMAX, Min, /*Signed*/ false);
+
+ if (SMax && Min.isNonNegative() && Max.isMask(VT.getScalarSizeInBits())) {
----------------
ParkHanbum wrote:
@davemgreen how about this?
```
// Proposed helper: if V has the given Opcode and its operand 1 is a constant
// splat that relates to Limit as tested below, return V's operand 0;
// otherwise return an empty SDValue.
static SDValue matchMin(SDValue V, unsigned Opcode, APInt &Limit) {
if (V.getOpcode() == Opcode) {
APInt C;
// NOTE(review): without braces only the `C == Limit` test is nested under
// this splat check.
if (ISD::isConstantSplatVector(V.getOperand(1).getNode(), C))
if (C == Limit)
return V.getOperand(0);
// NOTE(review): this test also runs when the splat check above failed, i.e.
// with C still default-constructed - presumably it was meant to sit inside
// that check; confirm.
// NOTE(review): verify that `>` compiles for APInt; an explicit signed or
// unsigned comparison (sgt/ugt) is presumably required here.
if (C > Limit && Limit == 0)
return V.getOperand(0);
}
return SDValue();
}
```
https://github.com/llvm/llvm-project/pull/99418
More information about the llvm-commits
mailing list