[llvm] [X86] Generate `kmov` for masking integers (PR #120593)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 25 01:02:51 PST 2025
================
@@ -55689,6 +55689,94 @@ static SDValue truncateAVX512SetCCNoBWI(EVT VT, EVT OpVT, SDValue LHS,
return SDValue();
}
+// The pattern (setcc (and (broadcast x), (2^n, 2^{n+1}, ...)), (0, 0, ...),
+// eq/ne) is generated when using an integer as a mask. Instead of generating a
+// broadcast + vptest, we can directly move the integer to a mask register.
+static SDValue combineAVX512SetCCToKMOV(EVT VT, SDValue Op0, ISD::CondCode CC,
+ const SDLoc &DL, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ if (CC != ISD::SETNE && CC != ISD::SETEQ)
+ return SDValue();
+
+ if (!Subtarget.hasAVX512())
+ return SDValue();
+
+ if (Op0.getOpcode() != ISD::AND)
+ return SDValue();
+
+ SDValue Broadcast = Op0.getOperand(0);
+ if (Broadcast.getOpcode() != X86ISD::VBROADCAST &&
+ Broadcast.getOpcode() != X86ISD::VBROADCAST_LOAD)
+ return SDValue();
+
+ SDValue Load = Op0.getOperand(1);
+ EVT LoadVT = Load.getSimpleValueType();
+
+ APInt UndefElts;
+ SmallVector<APInt, 32> EltBits;
+ if (!getTargetConstantBitsFromNode(Load, LoadVT.getScalarSizeInBits(),
+ UndefElts, EltBits,
+ /*AllowWholeUndefs*/ true,
+ /*AllowPartialUndefs*/ false) ||
+ UndefElts[0] || !EltBits[0].isPowerOf2() || UndefElts.getBitWidth() > 16)
+ return SDValue();
+
+ // Check if the constant pool contains only powers of 2 starting from some
+ // 2^N. The table may also contain undefs because of widening of vector
+ // operands.
+ unsigned N = EltBits[0].logBase2();
+ unsigned Len = UndefElts.getBitWidth();
+ for (unsigned I = 1; I != Len; ++I) {
+ if (UndefElts[I]) {
+ if (!UndefElts.extractBits(Len - (I + 1), I + 1).isAllOnes())
+ return SDValue();
+ break;
+ }
+
+ if (EltBits[I].getBitWidth() <= N + I || !EltBits[I].isOneBitSet(N + I))
+ return SDValue();
+ }
+
+ MVT BroadcastOpVT = Broadcast.getSimpleValueType().getVectorElementType();
+ SDValue BroadcastOp;
+ if (Broadcast.getOpcode() != X86ISD::VBROADCAST) {
+ BroadcastOp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, BroadcastOpVT,
+ Broadcast, DAG.getVectorIdxConstant(0, DL));
+ } else {
+ BroadcastOp = Broadcast.getOperand(0);
+ if (BroadcastOp.getValueType().isVector())
+ return SDValue();
+ }
+
+ SDValue Masked = BroadcastOp;
+ if (N != 0) {
+ APInt Mask = APInt::getLowBitsSet(BroadcastOpVT.getSizeInBits(), Len);
+ SDValue ShiftedValue = DAG.getNode(ISD::SRL, DL, BroadcastOpVT, BroadcastOp,
+ DAG.getConstant(N, DL, BroadcastOpVT));
+ Masked = DAG.getNode(ISD::AND, DL, BroadcastOpVT, ShiftedValue,
+ DAG.getConstant(Mask, DL, BroadcastOpVT));
+ }
+ // We can't extract more than 16 bits using this pattern, because 2^{17} will
+ // not fit in an i16 and a vXi32 where X > 16 is more than 512 bits.
+ SDValue Trunc = DAG.getAnyExtOrTrunc(Masked, DL, MVT::i16);
+ SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, MVT::v16i1, Trunc);
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ const DataLayout &DataLayout = DAG.getDataLayout();
+ MVT PtrTy = TLI.getPointerTy(DataLayout);
+
+ if (CC == ISD::SETEQ)
+ Bitcast =
+ DAG.getNode(ISD::XOR, DL, MVT::v16i1, Bitcast,
----------------
RKSimon wrote:
Can't this be handled by DAG.getNOT()?
https://github.com/llvm/llvm-project/pull/120593
More information about the llvm-commits mailing list