[llvm] [X86] Generate `kmov` for masking integers (PR #120593)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 6 03:10:45 PST 2025


================
@@ -55447,6 +55447,85 @@ static SDValue truncateAVX512SetCCNoBWI(EVT VT, EVT OpVT, SDValue LHS,
   return SDValue();
 }
 
+// The pattern (setcc (and (broadcast x), (2^n, 2^{n+1}, ...)), (0, 0, ...),
+// eq/ne) is generated when using an integer as a mask. Instead of generating a
+// broadcast + vptest, we can directly move the integer to a mask register.
+static SDValue combineAVX512SetCCToKMOV(EVT VT, SDValue Op0, ISD::CondCode CC,
+                                        const SDLoc &DL, SelectionDAG &DAG,
+                                        const X86Subtarget &Subtarget) {
+  if (CC != ISD::SETNE && CC != ISD::SETEQ)
+    return SDValue();
+
+  if (!Subtarget.hasAVX512())
+    return SDValue();
+
+  if (Op0.getOpcode() != ISD::AND)
+    return SDValue();
+
+  SDValue Broadcast = Op0.getOperand(0);
+  if (Broadcast.getOpcode() != X86ISD::VBROADCAST &&
+      Broadcast.getOpcode() != X86ISD::VBROADCAST_LOAD)
+    return SDValue();
+
+  SDValue Load = Op0.getOperand(1);
+  EVT LoadVT = Load.getSimpleValueType();
+
+  APInt UndefElts;
+  SmallVector<APInt, 32> EltBits;
+  if (!getTargetConstantBitsFromNode(Load, LoadVT.getScalarSizeInBits(),
+                                     UndefElts, EltBits,
+                                     /*AllowWholeUndefs*/ true,
+                                     /*AllowPartialUndefs*/ false) ||
+      UndefElts[0] || !EltBits[0].isPowerOf2())
+    return SDValue();
+
+  // Check if the constant pool contains only powers of 2 starting from some
+  // 2^N. The table may also contain undefs because of widening of vector
+  // operands.
+  unsigned N = EltBits[0].logBase2();
+  unsigned Len = UndefElts.getBitWidth();
+  for (unsigned I = 1; I != Len; ++I) {
+    if (UndefElts[I]) {
+      if (!UndefElts.extractBits(Len - (I + 1), I + 1).isAllOnes())
+        return SDValue();
+      break;
+    }
+
+    if (EltBits[I] != 1ULL << (N + I))
+      return SDValue();
+  }
+
+  SDValue BroadcastOp = Broadcast.getOpcode() == X86ISD::VBROADCAST
+                            ? Broadcast.getOperand(0)
+                            : Broadcast.getOperand(1);
+  MVT BroadcastOpVT = BroadcastOp.getSimpleValueType();
+  SDValue Masked = BroadcastOp;
+  if (N != 0) {
+    unsigned Mask = (1ULL << Len) - 1;
+    SDValue ShiftedValue = DAG.getNode(ISD::SRL, DL, BroadcastOpVT, BroadcastOp,
+                                       DAG.getConstant(N, DL, BroadcastOpVT));
+    Masked = DAG.getNode(ISD::AND, DL, BroadcastOpVT, ShiftedValue,
+                         DAG.getConstant(Mask, DL, BroadcastOpVT));
+  }
+  SDValue Trunc = DAG.getAnyExtOrTrunc(Masked, DL, MVT::i16);
+  SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, MVT::v16i1, Trunc);
+  MVT PtrTy = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
+
+  if (CC == ISD::SETEQ)
+    Bitcast = DAG.getNode(
+        ISD::XOR, DL, MVT::v16i1, Bitcast,
+        DAG.getSplatBuildVector(
+            MVT::v16i1, DL,
+            DAG.getConstant(APInt::getAllOnes(PtrTy.getSizeInBits()), DL,
----------------
RKSimon wrote:

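Use `DAG.getAllOnesConstant` here instead of building the all-ones value manually with `DAG.getConstant(APInt::getAllOnes(PtrTy.getSizeInBits()), DL, ...)`.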

https://github.com/llvm/llvm-project/pull/120593


More information about the llvm-commits mailing list