[llvm-branch-commits] [llvm] 1b209ff - [DAG] Move vselect(icmp_ult, 0, sub(x, y)) -> usubsat(x, y) to DAGCombine (PR40111)

Simon Pilgrim via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Dec 1 06:30:22 PST 2020


Author: Simon Pilgrim
Date: 2020-12-01T14:25:29Z
New Revision: 1b209ff9e3e13492fd56ae6662989ef47510db4d

URL: https://github.com/llvm/llvm-project/commit/1b209ff9e3e13492fd56ae6662989ef47510db4d
DIFF: https://github.com/llvm/llvm-project/commit/1b209ff9e3e13492fd56ae6662989ef47510db4d.diff

LOG: [DAG] Move vselect(icmp_ult, 0, sub(x,y)) -> usubsat(x,y) to DAGCombine (PR40111)

Move the X86 VSELECT->USUBSAT fold to DAGCombiner - there's nothing target-specific about these folds.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/usub_sat_vec.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 7a87521ae344..1684ec90f676 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9743,6 +9743,68 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
         }
       }
     }
+
+    // Match VSELECTs into sub with unsigned saturation.
+    if (hasOperation(ISD::USUBSAT, VT)) {
+      // Check if one of the arms of the VSELECT is a zero vector. If it's on
+      // the left side invert the predicate to simplify logic below.
+      SDValue Other;
+      ISD::CondCode SatCC = CC;
+      if (ISD::isBuildVectorAllZeros(N1.getNode())) {
+        Other = N2;
+        SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
+      } else if (ISD::isBuildVectorAllZeros(N2.getNode())) {
+        Other = N1;
+      }
+
+      if (Other && Other.getNumOperands() == 2 && Other.getOperand(0) == LHS) {
+        SDValue CondLHS = LHS, CondRHS = RHS;
+        SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
+
+        // Look for a general sub with unsigned saturation first.
+        // x >= y ? x-y : 0 --> usubsat x, y
+        // x >  y ? x-y : 0 --> usubsat x, y
+        if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
+            Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
+          return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
+
+        if (auto *OpRHSBV = dyn_cast<BuildVectorSDNode>(OpRHS)) {
+          if (isa<BuildVectorSDNode>(CondRHS)) {
+            // If the RHS is a constant we have to reverse the const
+            // canonicalization.
+            // x > C-1 ? x+-C : 0 --> usubsat x, C
+            auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
+              return (!Op && !Cond) ||
+                     (Op && Cond &&
+                      Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
+            };
+            if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
+                ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUSUBSAT,
+                                          /*AllowUndefs*/ true)) {
+              OpRHS = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
+                                  OpRHS);
+              return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
+            }
+
+            // Another special case: If C was a sign bit, the sub has been
+            // canonicalized into a xor.
+            // FIXME: Would it be better to use computeKnownBits to determine
+            //        whether it's safe to decanonicalize the xor?
+            // x s< 0 ? x^C : 0 --> usubsat x, C
+            if (auto *OpRHSConst = OpRHSBV->getConstantSplatNode()) {
+              if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
+                  ISD::isBuildVectorAllZeros(CondRHS.getNode()) &&
+                  OpRHSConst->getAPIntValue().isSignMask()) {
+                // Note that we have to rebuild the RHS constant here to ensure
+                // we don't rely on particular values of undef lanes.
+                OpRHS = DAG.getConstant(OpRHSConst->getAPIntValue(), DL, VT);
+                return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
+              }
+            }
+          }
+        }
+      }
+    }
   }
 
   if (SimplifySelectOps(N, N1, N2))

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index ce3497934a07..3020354ca499 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -40912,75 +40912,6 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
     }
   }
 
-  // Match VSELECTs into subs with unsigned saturation.
-  if (N->getOpcode() == ISD::VSELECT && Cond.getOpcode() == ISD::SETCC &&
-      // psubus is available in SSE2 for i8 and i16 vectors.
-      Subtarget.hasSSE2() && VT.getVectorNumElements() >= 2 &&
-      isPowerOf2_32(VT.getVectorNumElements()) &&
-      (VT.getVectorElementType() == MVT::i8 ||
-       VT.getVectorElementType() == MVT::i16)) {
-    ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
-
-    // Check if one of the arms of the VSELECT is a zero vector. If it's on the
-    // left side invert the predicate to simplify logic below.
-    SDValue Other;
-    if (ISD::isBuildVectorAllZeros(LHS.getNode())) {
-      Other = RHS;
-      CC = ISD::getSetCCInverse(CC, VT.getVectorElementType());
-    } else if (ISD::isBuildVectorAllZeros(RHS.getNode())) {
-      Other = LHS;
-    }
-
-    if (Other.getNode() && Other->getNumOperands() == 2 &&
-        Other->getOperand(0) == Cond.getOperand(0)) {
-      SDValue OpLHS = Other->getOperand(0), OpRHS = Other->getOperand(1);
-      SDValue CondRHS = Cond->getOperand(1);
-
-      // Look for a general sub with unsigned saturation first.
-      // x >= y ? x-y : 0 --> subus x, y
-      // x >  y ? x-y : 0 --> subus x, y
-      if ((CC == ISD::SETUGE || CC == ISD::SETUGT) &&
-          Other->getOpcode() == ISD::SUB && OpRHS == CondRHS)
-        return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
-
-      if (auto *OpRHSBV = dyn_cast<BuildVectorSDNode>(OpRHS)) {
-        if (isa<BuildVectorSDNode>(CondRHS)) {
-          // If the RHS is a constant we have to reverse the const
-          // canonicalization.
-          // x > C-1 ? x+-C : 0 --> subus x, C
-          auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
-            return (!Op && !Cond) ||
-                   (Op && Cond &&
-                    Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
-          };
-          if (CC == ISD::SETUGT && Other->getOpcode() == ISD::ADD &&
-              ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUSUBSAT,
-                                        /*AllowUndefs*/ true)) {
-            OpRHS = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
-                                OpRHS);
-            return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
-          }
-
-          // Another special case: If C was a sign bit, the sub has been
-          // canonicalized into a xor.
-          // FIXME: Would it be better to use computeKnownBits to determine
-          //        whether it's safe to decanonicalize the xor?
-          // x s< 0 ? x^C : 0 --> subus x, C
-          if (auto *OpRHSConst = OpRHSBV->getConstantSplatNode()) {
-            if (CC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
-                ISD::isBuildVectorAllZeros(CondRHS.getNode()) &&
-                OpRHSConst->getAPIntValue().isSignMask()) {
-              // Note that we have to rebuild the RHS constant here to ensure we
-              // don't rely on particular values of undef lanes.
-              OpRHS = DAG.getConstant(OpRHSConst->getAPIntValue(), DL, VT);
-              return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
-            }
-          }
-        }
-      }
-    }
-  }
-
   // Check if the first operand is all zeros and Cond type is vXi1.
   // If this is an avx512 target we can improve the use of zero masking by
   // swapping the operands and inverting the condition.

diff  --git a/llvm/test/CodeGen/X86/usub_sat_vec.ll b/llvm/test/CodeGen/X86/usub_sat_vec.ll
index 63482cff994c..efd5b0d3895c 100644
--- a/llvm/test/CodeGen/X86/usub_sat_vec.ll
+++ b/llvm/test/CodeGen/X86/usub_sat_vec.ll
@@ -1161,17 +1161,10 @@ define void @PR48223(<32 x i16>* %p0) {
 ; AVX512F-NEXT:    vmovdqa (%rdi), %ymm0
 ; AVX512F-NEXT:    vmovdqa 32(%rdi), %ymm1
 ; AVX512F-NEXT:    vmovdqa {{.*#+}} ymm2 = [64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64]
-; AVX512F-NEXT:    vpmaxuw %ymm2, %ymm1, %ymm3
-; AVX512F-NEXT:    vpcmpeqw %ymm3, %ymm1, %ymm3
-; AVX512F-NEXT:    vpmaxuw %ymm2, %ymm0, %ymm2
-; AVX512F-NEXT:    vpcmpeqw %ymm2, %ymm0, %ymm2
-; AVX512F-NEXT:    vinserti64x4 $1, %ymm3, %zmm2, %zmm2
-; AVX512F-NEXT:    vmovdqa {{.*#+}} ymm3 = [65472,65472,65472,65472,65472,65472,65472,65472,65472,65472,65472,65472,65472,65472,65472,65472]
-; AVX512F-NEXT:    vpaddw %ymm3, %ymm1, %ymm1
-; AVX512F-NEXT:    vpaddw %ymm3, %ymm0, %ymm0
-; AVX512F-NEXT:    vinserti64x4 $1, %ymm1, %zmm0, %zmm0
-; AVX512F-NEXT:    vpandq %zmm0, %zmm2, %zmm0
-; AVX512F-NEXT:    vmovdqa64 %zmm0, (%rdi)
+; AVX512F-NEXT:    vpsubusw %ymm2, %ymm1, %ymm1
+; AVX512F-NEXT:    vpsubusw %ymm2, %ymm0, %ymm0
+; AVX512F-NEXT:    vmovdqa %ymm0, (%rdi)
+; AVX512F-NEXT:    vmovdqa %ymm1, 32(%rdi)
 ; AVX512F-NEXT:    vzeroupper
 ; AVX512F-NEXT:    retq
 ;


        


More information about the llvm-branch-commits mailing list