[llvm] 89172e9 - [X86] combineBitcastToBoolVector - add XOR + Constant handling, match existing BITCASTs and limit recursion depth

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri May 24 08:30:43 PDT 2024


Author: Simon Pilgrim
Date: 2024-05-24T16:30:23+01:00
New Revision: 89172e954c51b16e5556f758e1a91dd385bc5aab

URL: https://github.com/llvm/llvm-project/commit/89172e954c51b16e5556f758e1a91dd385bc5aab
DIFF: https://github.com/llvm/llvm-project/commit/89172e954c51b16e5556f758e1a91dd385bc5aab.diff

LOG: [X86] combineBitcastToBoolVector - add XOR + Constant handling, match existing BITCASTs and limit recursion depth

Add XOR and Constant handling so that NOT patterns (an XOR with an all-ones constant) can be detected.

If a recursive combineBitcastToBoolVector call cannot fold its operand, but a BITCAST of that operand to the requested type already exists, then reuse that node.

As combineBitcastToBoolVector is recursive, ensure we limit the maximum recursion depth.

Fixes #93000

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/pr93000.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 19c942faf3c30..ca32cfe542330 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -43413,7 +43413,11 @@ static SDValue createMMXBuildVector(BuildVectorSDNode *BV, SelectionDAG &DAG,
 // the chain.
 static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
                                           SelectionDAG &DAG,
-                                          const X86Subtarget &Subtarget) {
+                                          const X86Subtarget &Subtarget,
+                                          unsigned Depth = 0) {
+  if (Depth >= SelectionDAG::MaxRecursionDepth)
+    return SDValue(); // Limit search depth.
+
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   unsigned Opc = V.getOpcode();
   switch (Opc) {
@@ -43425,14 +43429,22 @@ static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
       return DAG.getBitcast(VT, Src);
     break;
   }
+  case ISD::Constant: {
+    auto *C = cast<ConstantSDNode>(V);
+    if (C->isZero())
+      return DAG.getConstant(0, DL, VT);
+    if (C->isAllOnes())
+      return DAG.getAllOnesConstant(DL, VT);
+    break;
+  }
   case ISD::TRUNCATE: {
     // If we find a suitable source, a truncated scalar becomes a subvector.
     SDValue Src = V.getOperand(0);
     EVT NewSrcVT =
         EVT::getVectorVT(*DAG.getContext(), MVT::i1, Src.getValueSizeInBits());
     if (TLI.isTypeLegal(NewSrcVT))
-      if (SDValue N0 =
-              combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG, Subtarget))
+      if (SDValue N0 = combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG,
+                                                  Subtarget, Depth + 1))
         return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, N0,
                            DAG.getIntPtrConstant(0, DL));
     break;
@@ -43444,20 +43456,22 @@ static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
     EVT NewSrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
                                     Src.getScalarValueSizeInBits());
     if (TLI.isTypeLegal(NewSrcVT))
-      if (SDValue N0 =
-              combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG, Subtarget))
+      if (SDValue N0 = combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG,
+                                                  Subtarget, Depth + 1))
         return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
                            Opc == ISD::ANY_EXTEND ? DAG.getUNDEF(VT)
                                                   : DAG.getConstant(0, DL, VT),
                            N0, DAG.getIntPtrConstant(0, DL));
     break;
   }
-  case ISD::OR: {
-    // If we find suitable sources, we can just move an OR to the vector domain.
-    SDValue Src0 = V.getOperand(0);
-    SDValue Src1 = V.getOperand(1);
-    if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget))
-      if (SDValue N1 = combineBitcastToBoolVector(VT, Src1, DL, DAG, Subtarget))
+  case ISD::OR:
+  case ISD::XOR: {
+    // If we find suitable sources, we can just move the op to the vector
+    // domain.
+    if (SDValue N0 = combineBitcastToBoolVector(VT, V.getOperand(0), DL, DAG,
+                                                Subtarget, Depth + 1))
+      if (SDValue N1 = combineBitcastToBoolVector(VT, V.getOperand(1), DL, DAG,
+                                                  Subtarget, Depth + 1))
         return DAG.getNode(Opc, DL, VT, N0, N1);
     break;
   }
@@ -43469,13 +43483,20 @@ static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
       break;
 
     if (auto *Amt = dyn_cast<ConstantSDNode>(V.getOperand(1)))
-      if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget))
+      if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget,
+                                                  Depth + 1))
         return DAG.getNode(
             X86ISD::KSHIFTL, DL, VT, N0,
             DAG.getTargetConstant(Amt->getZExtValue(), DL, MVT::i8));
     break;
   }
   }
+
+  // Does the inner bitcast already exist?
+  if (Depth > 0)
+    if (SDNode *Alt = DAG.getNodeIfExists(ISD::BITCAST, DAG.getVTList(VT), {V}))
+      return SDValue(Alt, 0);
+
   return SDValue();
 }
 

diff  --git a/llvm/test/CodeGen/X86/pr93000.ll b/llvm/test/CodeGen/X86/pr93000.ll
index 97c17f2ec2dc6..0bd5da48847e8 100644
--- a/llvm/test/CodeGen/X86/pr93000.ll
+++ b/llvm/test/CodeGen/X86/pr93000.ll
@@ -10,8 +10,7 @@ define void @PR93000(ptr %a0, ptr %a1, ptr %a2, <32 x i16> %a3) {
 ; CHECK-NEXT:  .LBB0_1: # %Loop
 ; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    kmovd %eax, %k1
-; CHECK-NEXT:    notl %eax
-; CHECK-NEXT:    kmovd %eax, %k2
+; CHECK-NEXT:    knotd %k1, %k2
 ; CHECK-NEXT:    vpblendmw (%rsi), %zmm0, %zmm1 {%k1}
 ; CHECK-NEXT:    vmovdqu16 (%rdx), %zmm1 {%k2}
 ; CHECK-NEXT:    vmovdqu64 %zmm1, (%rsi)


        


More information about the llvm-commits mailing list