[llvm] [DAG] Support saturated truncate (PR #99418)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 5 23:31:56 PDT 2024


================
@@ -14915,6 +14920,180 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
   return SDValue();
 }
 
+/// Fold TRUNCATE_USAT of an FP->int conversion into FP_TO_UINT_SAT.
+///
+/// Unsigned-saturating truncation clamps its operand to the unsigned maximum
+/// of the narrower result type VT (cf. the umin pattern documented on
+/// detectUSatUPattern). That is what FP_TO_UINT_SAT does, provided the value
+/// being truncated can never be negative. Only two producers are folded:
+///   * fp_to_uint x            -> fp_to_uint_sat x
+///   * smax(fp_to_sint x, 0)   -> fp_to_uint_sat x
+/// A bare fp_to_sint is rejected: e.g. -1.0 converts to all-ones, which the
+/// unsigned saturation turns into the maximum value, whereas fp_to_uint_sat
+/// yields 0. An smax with any bound other than zero is rejected too, because
+/// it changes the lower clamp.
+///
+/// The fold is only performed when the target reports (via
+/// shouldConvertFpToSat) that FP_TO_UINT_SAT from the FP source type to VT is
+/// worthwhile. Returns the replacement node, or an empty SDValue if nothing
+/// was folded.
+SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) {
+  EVT VT = N->getValueType(0);
+  SDValue N0 = N->getOperand(0);
+
+  // Return the FP->int conversion feeding the saturation, or an empty
+  // SDValue if the operand is not one of the sound shapes listed above.
+  auto MatchFPTOINT = [](SDValue Val) -> SDValue {
+    if (Val.getOpcode() == ISD::FP_TO_UINT)
+      return Val;
+    if (Val.getOpcode() == ISD::SMAX) {
+      // smax is commutative; check both operand orders.
+      for (unsigned I = 0; I != 2; ++I) {
+        SDValue Conv = Val.getOperand(I);
+        if (Conv.getOpcode() == ISD::FP_TO_SINT &&
+            isNullOrNullSplat(Val.getOperand(1 - I)))
+          return Conv;
+      }
+    }
+    return SDValue();
+  };
+
+  SDValue FPInstr = MatchFPTOINT(N0);
+  if (!FPInstr)
+    return SDValue();
+
+  // Operand 0 of the conversion is the floating-point source value.
+  EVT FPVT = FPInstr.getOperand(0).getValueType();
+  if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
+                                                        FPVT, VT))
+    return SDValue();
+  // The trailing VTSDNode gives the saturation width: the scalar type of VT.
+  return DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT,
+                     FPInstr.getOperand(0),
+                     DAG.getValueType(VT.getScalarType()));
+}
+
+// Match a min node (SMIN or UMIN) whose operand 1 is a constant splat equal
+// to Limit, and return its operand 0 (the value being clamped).
+//
+// On success C holds the matched splat constant (equal to Limit). On failure
+// an empty SDValue is returned and the contents of C are unspecified.
+// Only operand 1 is inspected, presumably relying on the DAG having moved
+// constants to the RHS of commutative nodes.
+// NOTE(review): SMIN and UMIN are accepted interchangeably. A caller that
+// needs a strictly unsigned clamp (e.g. the umin pattern in
+// detectUSatUPattern) should confirm the opcode is acceptable for it.
+static SDValue matchMin(SDValue V, const APInt &Limit, APInt &C) {
+  if (V.getOpcode() != ISD::SMIN && V.getOpcode() != ISD::UMIN)
+    return SDValue();
+  if (!ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) || C != Limit)
+    return SDValue();
+  return V.getOperand(0);
+}
+
+// Match an SMAX node whose operand 1 is a constant splat exactly equal to
+// Limit, and return its operand 0 (the value being clamped).
+//
+// The match must be exact. Accepting a looser bound is unsound:
+//   * "any C >= Limit (unsigned)" with Limit == 0 accepts every constant, so
+//     smax(x, 5) would pass as a clamp at 0, changing the result for x < 5.
+//   * with a negative Limit such as -128 it accepts smax(x, -100), whose
+//     lower clamp differs from the saturating one.
+// This mirrors the exact-match rule in matchMin.
+//
+// On success C holds the matched splat constant (equal to Limit). On failure
+// an empty SDValue is returned and the contents of C are unspecified.
+static SDValue matchMax(SDValue V, const APInt &Limit, APInt &C) {
+  if (V.getOpcode() != ISD::SMAX)
+    return SDValue();
+  if (!ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) || C != Limit)
+    return SDValue();
+  return V.getOperand(0);
+}
+
+/// Detect patterns of truncation with unsigned saturation:
+///
+/// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
+/// Return the source value x to be truncated or SDValue() if the pattern was
+/// not matched.
+///
+static SDValue detectUSatUPattern(SDValue In, EVT VT) {
+  unsigned NumDstBits = VT.getScalarSizeInBits();
+  unsigned NumSrcBits = In.getScalarValueSizeInBits();
+  // Saturation with truncation. We truncate from InVT to VT.
+  assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
+
+  APInt Max;
+  APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);
+  if (SDValue UMin = matchMin(In, UnsignedMax, Max)) {
----------------
ParkHanbum wrote:

It matched well, but it doesn't matter; I'll revert to the earlier revision.

https://github.com/llvm/llvm-project/pull/99418


More information about the llvm-commits mailing list