[llvm] [X86] LowerSelect - generalize "select icmp(x,0), lhs, rhs" folding patterns [NFC] (PR #107374)

via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 5 02:37:12 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-x86

Author: Simon Pilgrim (RKSimon)

<details>
<summary>Changes</summary>

This patch proposes we add a LowerSELECTWithCmpZero helper, which allows us to fold the compare-with-zero from different condition nodes with minimal duplication.

So far I've only handled the simple no-cmov case for or/xor nodes, but the intention is to convert more folds in future PRs.

NFC preliminary patch for #<!-- -->107272

---
Full diff: https://github.com/llvm/llvm-project/pull/107374.diff


1 File Affected:

- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+52-35) 


``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 451881e1d61415..efca1a303edfac 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -24074,6 +24074,55 @@ static bool isTruncWithZeroHighBitsInput(SDValue V, SelectionDAG &DAG) {
   return DAG.MaskedValueIsZero(VOp0, APInt::getHighBitsSet(InBits,InBits-Bits));
 }
 
+// Lower various (select (icmp CmpVal, 0), LHS, RHS) custom patterns.
+static SDValue LowerSELECTWithCmpZero(SDValue CmpVal, SDValue LHS, SDValue RHS,
+                                      unsigned X86CC, const SDLoc &DL,
+                                      SelectionDAG &DAG,
+                                      const X86Subtarget &Subtarget) {
+  EVT CmpVT = CmpVal.getValueType();
+  EVT VT = LHS.getValueType();
+  if (!CmpVT.isScalarInteger() || !VT.isScalarInteger())
+    return SDValue();
+
+  if (!Subtarget.canUseCMOV() && X86CC == X86::COND_E &&
+      CmpVal.getOpcode() == ISD::AND && isOneConstant(CmpVal.getOperand(1))) {
+    SDValue Src1, Src2;
+    // true if RHS is XOR or OR operator and one of its operands
+    // is equal to LHS
+    // ( a , a op b) || ( b , a op b)
+    auto isOrXorPattern = [&]() {
+      if ((RHS.getOpcode() == ISD::XOR || RHS.getOpcode() == ISD::OR) &&
+          (RHS.getOperand(0) == LHS || RHS.getOperand(1) == LHS)) {
+        Src1 = RHS.getOperand(RHS.getOperand(0) == LHS ? 1 : 0);
+        Src2 = LHS;
+        return true;
+      }
+      return false;
+    };
+
+    if (isOrXorPattern()) {
+      SDValue Neg;
+      unsigned int CmpSz = CmpVT.getSizeInBits();
+      // we need mask of all zeros or ones with same size of the other
+      // operands.
+      if (CmpSz > VT.getSizeInBits())
+        Neg = DAG.getNode(ISD::TRUNCATE, DL, VT, CmpVal);
+      else if (CmpSz < VT.getSizeInBits())
+        Neg = DAG.getNode(
+            ISD::AND, DL, VT,
+            DAG.getNode(ISD::ANY_EXTEND, DL, VT, CmpVal.getOperand(0)),
+            DAG.getConstant(1, DL, VT));
+      else
+        Neg = CmpVal;
+      SDValue Mask = DAG.getNegative(Neg, DL, VT); // -(and (x, 0x1))
+      SDValue And = DAG.getNode(ISD::AND, DL, VT, Mask, Src1); // Mask & z
+      return DAG.getNode(RHS.getOpcode(), DL, VT, And, Src2);  // And Op y
+    }
+  }
+
+  return SDValue();
+}
+
 SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
   bool AddTest = true;
   SDValue Cond  = Op.getOperand(0);
@@ -24218,41 +24267,9 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
                                 DAG.getTargetConstant(X86::COND_B, DL, MVT::i8),
                                 Sub.getValue(1));
       return DAG.getNode(ISD::OR, DL, VT, SBB, Y);
-    } else if (!Subtarget.canUseCMOV() && CondCode == X86::COND_E &&
-               CmpOp0.getOpcode() == ISD::AND &&
-               isOneConstant(CmpOp0.getOperand(1))) {
-      SDValue Src1, Src2;
-      // true if Op2 is XOR or OR operator and one of its operands
-      // is equal to Op1
-      // ( a , a op b) || ( b , a op b)
-      auto isOrXorPattern = [&]() {
-        if ((Op2.getOpcode() == ISD::XOR || Op2.getOpcode() == ISD::OR) &&
-            (Op2.getOperand(0) == Op1 || Op2.getOperand(1) == Op1)) {
-          Src1 =
-              Op2.getOperand(0) == Op1 ? Op2.getOperand(1) : Op2.getOperand(0);
-          Src2 = Op1;
-          return true;
-        }
-        return false;
-      };
-
-      if (isOrXorPattern()) {
-        SDValue Neg;
-        unsigned int CmpSz = CmpOp0.getSimpleValueType().getSizeInBits();
-        // we need mask of all zeros or ones with same size of the other
-        // operands.
-        if (CmpSz > VT.getSizeInBits())
-          Neg = DAG.getNode(ISD::TRUNCATE, DL, VT, CmpOp0);
-        else if (CmpSz < VT.getSizeInBits())
-          Neg = DAG.getNode(ISD::AND, DL, VT,
-              DAG.getNode(ISD::ANY_EXTEND, DL, VT, CmpOp0.getOperand(0)),
-              DAG.getConstant(1, DL, VT));
-        else
-          Neg = CmpOp0;
-        SDValue Mask = DAG.getNegative(Neg, DL, VT); // -(and (x, 0x1))
-        SDValue And = DAG.getNode(ISD::AND, DL, VT, Mask, Src1); // Mask & z
-        return DAG.getNode(Op2.getOpcode(), DL, VT, And, Src2);  // And Op y
-      }
+    } else if (SDValue R = LowerSELECTWithCmpZero(CmpOp0, Op1, Op2, CondCode,
+                                                  DL, DAG, Subtarget)) {
+      return R;
     } else if ((VT == MVT::i32 || VT == MVT::i64) && isNullConstant(Op2) &&
                Cmp.getNode()->hasOneUse() && (CmpOp0 == Op1) &&
                ((CondCode == X86::COND_S) ||                    // smin(x, 0)

``````````

</details>


https://github.com/llvm/llvm-project/pull/107374


More information about the llvm-commits mailing list