[llvm] [DAG] Support saturated truncate (PR #99418)

David Green via llvm-commits llvm-commits at lists.llvm.org
Sat Jul 27 09:49:55 PDT 2024


================
@@ -14915,6 +14920,159 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
   return SDValue();
 }
 
+/// Fold a TRUNCATE_USAT of a float-to-integer conversion into a single
+/// FP_TO_UINT_SAT that produces the truncated type directly.
+///
+/// Two operand shapes are accepted:
+///   (truncate_usat (fp_to_uint x))
+///   (truncate_usat (smax (fp_to_sint x), 0))
+///
+/// NOTE(review): this assumes TRUNCATE_USAT interprets its operand as an
+/// unsigned value - confirm against the node's definition. Under that
+/// reading a bare FP_TO_SINT is not equivalent (a negative result would
+/// saturate to the unsigned maximum, whereas FP_TO_UINT_SAT yields 0), and
+/// neither is an SMAX over FP_TO_UINT (a result with the sign bit set would
+/// be clamped to 0, whereas FP_TO_UINT_SAT yields the maximum).
+///
+/// \returns the replacement node, or an empty SDValue if the operand does not
+/// match or the target rejects the conversion via shouldConvertFpToSat.
+SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) {
+  EVT VT = N->getValueType(0);
+  SDValue N0 = N->getOperand(0);
+
+  SDValue FPInstr;
+  if (N0.getOpcode() == ISD::FP_TO_UINT) {
+    FPInstr = N0;
+  } else if (N0.getOpcode() == ISD::SMAX &&
+             N0.getOperand(0).getOpcode() == ISD::FP_TO_SINT &&
+             isNullOrNullSplat(N0.getOperand(1))) {
+    // Only a clamp at zero may be dropped: FP_TO_UINT_SAT already saturates
+    // negative inputs to 0, but it cannot reproduce any other lower bound.
+    FPInstr = N0.getOperand(0);
+  }
+  if (!FPInstr)
+    return SDValue();
+
+  SDValue FPSrc = FPInstr.getOperand(0);
+  if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(
+          ISD::FP_TO_UINT_SAT, FPSrc.getValueType(), VT))
+    return SDValue();
+
+  // The saturation width is the scalar type of the truncated result.
+  return DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT, FPSrc,
+                     DAG.getValueType(VT.getScalarType()));
+}
+
+/// Detect patterns of truncation with unsigned saturation:
+///
+/// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
+///
+/// \p In is the value being truncated and \p VT is the truncated result type;
+/// the scalar width of \p In must be strictly larger than that of \p VT.
+/// Return the source value x to be truncated or SDValue() if the pattern was
+/// not matched.
+///
+static SDValue detectUSatUPattern(SDValue In, EVT VT) {
+  unsigned NumDstBits = VT.getScalarSizeInBits();
+
+  // Saturation with truncation. We truncate from In's type to VT.
+  assert(In.getScalarValueSizeInBits() > NumDstBits &&
+         "Unexpected types for truncate operation");
+
+  if (In.getOpcode() != ISD::UMIN)
+    return SDValue();
+
+  // The clamp limit must be a constant splat equal to UINT32_MAX / UINT16_MAX
+  // / UINT8_MAX according to the element size of the destination type, i.e. a
+  // mask of exactly NumDstBits low bits.
+  APInt Limit;
+  if (!ISD::isConstantSplatVector(In.getOperand(1).getNode(), Limit) ||
+      !Limit.isMask(NumDstBits))
+    return SDValue();
+
+  return In.getOperand(0);
+}
+
+/// Recognise a signed-saturating truncate, i.e. a value clamped to the signed
+/// range of the destination element type before being truncated:
+///
+///   (truncate (smin (smax x, signed_min_of_dest_type),
+///                   signed_max_of_dest_type) to dest_type)
+///   (truncate (smax (smin x, signed_max_of_dest_type),
+///                   signed_min_of_dest_type) to dest_type)
+///
+/// \p In is the value feeding the truncate and \p VT the truncated type. Both
+/// clamp bounds must be constant splats, sign-extended to the source width.
+/// Returns the unclamped value x, or SDValue() when neither form matches.
+static SDValue detectSSatSPattern(SDValue In, EVT VT) {
+  const unsigned DstBits = VT.getScalarSizeInBits();
+  const unsigned SrcBits = In.getScalarValueSizeInBits();
+  assert(SrcBits > DstBits && "Unexpected types for truncate operation");
+
+  // Destination-type bounds expressed at the source width.
+  const APInt UpperBound = APInt::getSignedMaxValue(DstBits).sext(SrcBits);
+  const APInt LowerBound = APInt::getSignedMinValue(DstBits).sext(SrcBits);
+
+  // Strip one "Opc(X, splat(Bound))" layer and yield X; empty on mismatch.
+  auto PeelClamp = [](SDValue V, unsigned Opc, const APInt &Bound) {
+    APInt Splat;
+    bool IsMatch =
+        V.getOpcode() == Opc &&
+        ISD::isConstantSplatVector(V.getOperand(1).getNode(), Splat) &&
+        Splat == Bound;
+    return IsMatch ? V.getOperand(0) : SDValue();
+  };
+
+  // Strip an outer clamp and then an inner clamp; empty if either is absent.
+  auto PeelClampPair = [&](unsigned OuterOpc, const APInt &OuterBound,
+                           unsigned InnerOpc, const APInt &InnerBound) {
+    SDValue Inner = PeelClamp(In, OuterOpc, OuterBound);
+    if (!Inner)
+      return SDValue();
+    return PeelClamp(Inner, InnerOpc, InnerBound);
+  };
+
+  // smin(smax(x, lo), hi)
+  if (SDValue X = PeelClampPair(ISD::SMIN, UpperBound, ISD::SMAX, LowerBound))
+    return X;
+
+  // smax(smin(x, hi), lo)
+  return PeelClampPair(ISD::SMAX, LowerBound, ISD::SMIN, UpperBound);
+}
+
+/// Detect patterns of truncation with unsigned saturation:
+///
+/// (truncate (smin (smax (x, C1), C2)) to dest_type),
+/// where C1 >= 0 and C2 is unsigned max of destination type.
+///
+/// (truncate (smax (smin (x, C2), C1)) to dest_type)
+/// where C1 >= 0, C2 is unsigned max of destination type and C1 <= C2.
+///
+static SDValue detectSSatUPattern(SDValue In, EVT VT, SelectionDAG &DAG,
+                                  const SDLoc &DL) {
+  EVT InVT = In.getValueType();
+
+  // Saturation with truncation. We truncate from InVT to VT.
+  assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
+         "Unexpected types for truncate operation");
+
+  // Match min/max and return limit value as a parameter.
+  auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
+    if (V.getOpcode() == Opcode &&
+        ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
+      return V.getOperand(0);
+    return SDValue();
+  };
+
+  APInt C1, C2;
+  if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2))
+    if (MatchMinMax(SMin, ISD::SMAX, C1))
+      if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
----------------
davemgreen wrote:

I would expect C1==0, to handle the basic case. Then it can check C1 >=s 0 and C2 isMask, and return the "SMax" value like the others.

Then, because the SMAX ensures the value is non-negative, we can match against either SMIN or UMIN; both should be valid.

It looks like this is attempting to handle more cases - they might be good to add in a follow-up once the basics are in.

https://github.com/llvm/llvm-project/pull/99418


More information about the llvm-commits mailing list