[llvm] 669d844 - [X86] Convert truncsat clamping patterns to use SDPatternMatch. NFC.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 9 08:12:51 PDT 2024


Author: Simon Pilgrim
Date: 2024-08-09T16:11:36+01:00
New Revision: 669d8448a308dcf0d7b26fc4168177e54d4ccf82

URL: https://github.com/llvm/llvm-project/commit/669d8448a308dcf0d7b26fc4168177e54d4ccf82
DIFF: https://github.com/llvm/llvm-project/commit/669d8448a308dcf0d7b26fc4168177e54d4ccf82.diff

LOG: [X86] Convert truncsat clamping patterns to use SDPatternMatch. NFC.

Inspired by #99418 (which hopefully we can replace this code with at some point)

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 63690487632f88..5c0bdce443a798 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -50995,38 +50995,31 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
 ///   pattern was not matched.
 static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
                                  const SDLoc &DL) {
+  using namespace llvm::SDPatternMatch;
   EVT InVT = In.getValueType();
 
   // Saturation with truncation. We truncate from InVT to VT.
   assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
          "Unexpected types for truncate operation");
 
-  // Match min/max and return limit value as a parameter.
-  auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
-    if (V.getOpcode() == Opcode &&
-        ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
-      return V.getOperand(0);
-    return SDValue();
-  };
-
   APInt C1, C2;
-  if (SDValue UMin = MatchMinMax(In, ISD::UMIN, C2))
-    // C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
-    // the element size of the destination type.
-    if (C2.isMask(VT.getScalarSizeInBits()))
-      return UMin;
+  SDValue UMin, SMin, SMax;
 
-  if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2))
-    if (MatchMinMax(SMin, ISD::SMAX, C1))
-      if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
-        return SMin;
+  // C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according to
+  // the element size of the destination type.
+  if (sd_match(In, m_UMin(m_Value(UMin), m_ConstInt(C2))) &&
+      C2.isMask(VT.getScalarSizeInBits()))
+    return UMin;
 
-  if (SDValue SMax = MatchMinMax(In, ISD::SMAX, C1))
-    if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, C2))
-      if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) &&
-          C2.uge(C1)) {
-        return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1));
-      }
+  if (sd_match(In, m_SMin(m_Value(SMin), m_ConstInt(C2))) &&
+      sd_match(SMin, m_SMax(m_Value(SMax), m_ConstInt(C1))) &&
+      C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
+    return SMin;
+
+  if (sd_match(In, m_SMax(m_Value(SMax), m_ConstInt(C1))) &&
+      sd_match(SMax, m_SMin(m_Value(SMin), m_ConstInt(C2))) &&
+      C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) && C2.uge(C1))
+    return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1));
 
   return SDValue();
 }
@@ -51041,35 +51034,28 @@ static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
 /// Return the source value to be truncated or SDValue() if the pattern was not
 /// matched.
 static SDValue detectSSatPattern(SDValue In, EVT VT, bool MatchPackUS = false) {
+  using namespace llvm::SDPatternMatch;
   unsigned NumDstBits = VT.getScalarSizeInBits();
   unsigned NumSrcBits = In.getScalarValueSizeInBits();
   assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
 
-  auto MatchMinMax = [](SDValue V, unsigned Opcode,
-                        const APInt &Limit) -> SDValue {
-    APInt C;
-    if (V.getOpcode() == Opcode &&
-        ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
-      return V.getOperand(0);
-    return SDValue();
-  };
-
   APInt SignedMax, SignedMin;
   if (MatchPackUS) {
     SignedMax = APInt::getAllOnes(NumDstBits).zext(NumSrcBits);
-    SignedMin = APInt(NumSrcBits, 0);
+    SignedMin = APInt::getZero(NumSrcBits);
   } else {
     SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
     SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
   }
 
-  if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax))
-    if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin))
-      return SMax;
+  SDValue SMin, SMax;
+  if (sd_match(In, m_SMin(m_Value(SMin), m_SpecificInt(SignedMax))) &&
+      sd_match(SMin, m_SMax(m_Value(SMax), m_SpecificInt(SignedMin))))
+    return SMax;
 
-  if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin))
-    if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax))
-      return SMin;
+  if (sd_match(In, m_SMax(m_Value(SMax), m_SpecificInt(SignedMin))) &&
+      sd_match(SMax, m_SMin(m_Value(SMin), m_SpecificInt(SignedMax))))
+    return SMin;
 
   return SDValue();
 }


        


More information about the llvm-commits mailing list