[llvm] [X86] Shrink width of masked loads/stores (PR #105451)

Shengchen Kan via llvm-commits llvm-commits at lists.llvm.org
Sun Sep 1 18:39:09 PDT 2024


================
@@ -51536,20 +51536,112 @@ combineMaskedLoadConstantMask(MaskedLoadSDNode *ML, SelectionDAG &DAG,
   return DCI.CombineTo(ML, Blend, NewML.getValue(1), true);
 }
 
+static bool tryShrinkMaskedOperation(SelectionDAG &DAG, const SDLoc &DL,
+                                     SDValue Mask, EVT OrigVT,
+                                     SDValue *ValInOut, EVT *NewVTOut,
+                                     SDValue *NewMaskOut) {
+  // Ensure we have a reasonable input type.
+  // Also ensure the input is wider than xmm; otherwise it's not
+  // profitable to try to shrink.
+  if (!OrigVT.isSimple() ||
+      !(OrigVT.is256BitVector() || OrigVT.is512BitVector()))
+    return false;
+
+  SmallVector<SDValue> OrigMask;
+  APInt DemandedElts = getDemandedEltsForMaskedOp(
+      Mask, OrigVT.getVectorNumElements(), &OrigMask);
+  if (DemandedElts.isAllOnes() || DemandedElts.isZero())
+    return false;
+
+  unsigned OrigNumElts = OrigVT.getVectorNumElements();
+  // Potential TODO: It might be profitable to extract not just the "lower"
+  // sub-vector.
+  unsigned ReqElts =
+      DemandedElts.getBitWidth() - DemandedElts.countLeadingZeros();
+  // We can't shrink to a smaller vector category in a meaningful way.
+  if (ReqElts > OrigNumElts / 2U)
+    return false;
+
+  // At most shrink to xmm.
+  unsigned NewNumElts =
+      std::max(128U / OrigVT.getScalarSizeInBits(), PowerOf2Ceil(ReqElts));
----------------
KanRobert wrote:

Line 51559 uses the number of non-zero bits as `ReqElts`. I think this assumes the mask is a vector of booleans?
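
To make the arithmetic in the hunk concrete, here is a minimal standalone sketch. It is not the patch's code: a `uint64_t` stands in for `APInt`, and the helper names `activeElts`, `powerOf2Ceil`, and `shrunkNumElts` are hypothetical. It assumes bit `i` of the demanded mask corresponds to vector element `i`, as the quoted code suggests.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for getBitWidth() - countLeadingZeros():
// index of the highest set bit plus one, or 0 if no bit is set.
unsigned activeElts(uint64_t DemandedElts, unsigned NumElts) {
  assert(NumElts >= 1 && NumElts <= 64 && "sketch supports up to 64 elements");
  unsigned Active = 0;
  for (unsigned I = 0; I < NumElts; ++I)
    if (DemandedElts & (uint64_t{1} << I))
      Active = I + 1;
  return Active;
}

// Smallest power of two >= V (0 for V == 0).
unsigned powerOf2Ceil(unsigned V) {
  if (V == 0)
    return 0;
  unsigned P = 1;
  while (P < V)
    P <<= 1;
  return P;
}

// Returns the shrunk element count, or 0 when the sketch would not shrink.
unsigned shrunkNumElts(uint64_t DemandedElts, unsigned NumElts,
                       unsigned ScalarBits) {
  assert(ScalarBits > 0 && "scalar width must be non-zero");
  // Only 256-bit and 512-bit vectors are considered.
  unsigned TotalBits = NumElts * ScalarBits;
  if (TotalBits != 256 && TotalBits != 512)
    return 0;
  // Bail out when every element or no element is demanded.
  uint64_t AllOnes = NumElts == 64 ? ~uint64_t{0}
                                   : (uint64_t{1} << NumElts) - 1;
  DemandedElts &= AllOnes;
  if (DemandedElts == AllOnes || DemandedElts == 0)
    return 0;
  // Bail out when the demanded prefix exceeds half of the vector.
  unsigned ReqElts = activeElts(DemandedElts, NumElts);
  if (ReqElts > NumElts / 2)
    return 0;
  // Never shrink below 128 bits (xmm).
  return std::max(128u / ScalarBits, powerOf2Ceil(ReqElts));
}
```

Under these assumptions, `shrunkNumElts(0x03, 8, 32)` gives 4 (v8i32 to v4i32), while `shrunkNumElts(0x81, 8, 32)` gives 0 because the highest demanded element is index 7. In this sketch `ReqElts` tracks the highest demanded index plus one rather than a population count, so a sparse mask with a high bit set does not shrink.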

https://github.com/llvm/llvm-project/pull/105451
