[llvm] 89172e9 - [X86] combineBitcastToBoolVector - add XOR + Constant handling, match existing BITCASTs and limit recursion depth
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri May 24 08:30:43 PDT 2024
Author: Simon Pilgrim
Date: 2024-05-24T16:30:23+01:00
New Revision: 89172e954c51b16e5556f758e1a91dd385bc5aab
URL: https://github.com/llvm/llvm-project/commit/89172e954c51b16e5556f758e1a91dd385bc5aab
DIFF: https://github.com/llvm/llvm-project/commit/89172e954c51b16e5556f758e1a91dd385bc5aab.diff
LOG: [X86] combineBitcastToBoolVector - add XOR + Constant handling, match existing BITCASTs and limit recursion depth
Add XOR + constant handling to allow us to detect NOT patterns.
If a recursive combineBitcastToBoolVector call finds an existing BITCAST node then use that.
As combineBitcastToBoolVector is recursive, ensure we limit the maximum recursion depth.
Fixes #93000
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/pr93000.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 19c942faf3c30..ca32cfe542330 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -43413,7 +43413,11 @@ static SDValue createMMXBuildVector(BuildVectorSDNode *BV, SelectionDAG &DAG,
// the chain.
static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
SelectionDAG &DAG,
- const X86Subtarget &Subtarget) {
+ const X86Subtarget &Subtarget,
+ unsigned Depth = 0) {
+ if (Depth >= SelectionDAG::MaxRecursionDepth)
+ return SDValue(); // Limit search depth.
+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
unsigned Opc = V.getOpcode();
switch (Opc) {
@@ -43425,14 +43429,22 @@ static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
return DAG.getBitcast(VT, Src);
break;
}
+ case ISD::Constant: {
+ auto *C = cast<ConstantSDNode>(V);
+ if (C->isZero())
+ return DAG.getConstant(0, DL, VT);
+ if (C->isAllOnes())
+ return DAG.getAllOnesConstant(DL, VT);
+ break;
+ }
case ISD::TRUNCATE: {
// If we find a suitable source, a truncated scalar becomes a subvector.
SDValue Src = V.getOperand(0);
EVT NewSrcVT =
EVT::getVectorVT(*DAG.getContext(), MVT::i1, Src.getValueSizeInBits());
if (TLI.isTypeLegal(NewSrcVT))
- if (SDValue N0 =
- combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG, Subtarget))
+ if (SDValue N0 = combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG,
+ Subtarget, Depth + 1))
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, N0,
DAG.getIntPtrConstant(0, DL));
break;
@@ -43444,20 +43456,22 @@ static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
EVT NewSrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
Src.getScalarValueSizeInBits());
if (TLI.isTypeLegal(NewSrcVT))
- if (SDValue N0 =
- combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG, Subtarget))
+ if (SDValue N0 = combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG,
+ Subtarget, Depth + 1))
return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
Opc == ISD::ANY_EXTEND ? DAG.getUNDEF(VT)
: DAG.getConstant(0, DL, VT),
N0, DAG.getIntPtrConstant(0, DL));
break;
}
- case ISD::OR: {
- // If we find suitable sources, we can just move an OR to the vector domain.
- SDValue Src0 = V.getOperand(0);
- SDValue Src1 = V.getOperand(1);
- if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget))
- if (SDValue N1 = combineBitcastToBoolVector(VT, Src1, DL, DAG, Subtarget))
+ case ISD::OR:
+ case ISD::XOR: {
+ // If we find suitable sources, we can just move the op to the vector
+ // domain.
+ if (SDValue N0 = combineBitcastToBoolVector(VT, V.getOperand(0), DL, DAG,
+ Subtarget, Depth + 1))
+ if (SDValue N1 = combineBitcastToBoolVector(VT, V.getOperand(1), DL, DAG,
+ Subtarget, Depth + 1))
return DAG.getNode(Opc, DL, VT, N0, N1);
break;
}
@@ -43469,13 +43483,20 @@ static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
break;
if (auto *Amt = dyn_cast<ConstantSDNode>(V.getOperand(1)))
- if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget))
+ if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget,
+ Depth + 1))
return DAG.getNode(
X86ISD::KSHIFTL, DL, VT, N0,
DAG.getTargetConstant(Amt->getZExtValue(), DL, MVT::i8));
break;
}
}
+
+ // Does the inner bitcast already exist?
+ if (Depth > 0)
+ if (SDNode *Alt = DAG.getNodeIfExists(ISD::BITCAST, DAG.getVTList(VT), {V}))
+ return SDValue(Alt, 0);
+
return SDValue();
}
diff --git a/llvm/test/CodeGen/X86/pr93000.ll b/llvm/test/CodeGen/X86/pr93000.ll
index 97c17f2ec2dc6..0bd5da48847e8 100644
--- a/llvm/test/CodeGen/X86/pr93000.ll
+++ b/llvm/test/CodeGen/X86/pr93000.ll
@@ -10,8 +10,7 @@ define void @PR93000(ptr %a0, ptr %a1, ptr %a2, <32 x i16> %a3) {
; CHECK-NEXT: .LBB0_1: # %Loop
; CHECK-NEXT: # =>This Inner Loop Header: Depth=1
; CHECK-NEXT: kmovd %eax, %k1
-; CHECK-NEXT: notl %eax
-; CHECK-NEXT: kmovd %eax, %k2
+; CHECK-NEXT: knotd %k1, %k2
; CHECK-NEXT: vpblendmw (%rsi), %zmm0, %zmm1 {%k1}
; CHECK-NEXT: vmovdqu16 (%rdx), %zmm1 {%k2}
; CHECK-NEXT: vmovdqu64 %zmm1, (%rsi)
More information about the llvm-commits mailing list