[llvm] [DAG] Combine `store + vselect` to `masked_store` (PR #145176)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 30 04:14:46 PDT 2025


================
@@ -22515,12 +22515,62 @@ SDValue DAGCombiner::visitATOMIC_STORE(SDNode *N) {
   return SDValue();
 }
 
+/// Try to turn a store of a vector select into a masked store.
+///
+/// Two shapes are recognised, where the load reads from the same base pointer
+/// and offset that the store writes to:
+///   store (vselect Mask, X, (load Ptr)), Ptr  -> masked store of X under Mask
+///   store (vselect Mask, (load Ptr), X), Ptr  -> masked store of X under !Mask
+///
+/// \param Store the store node being combined.
+/// \param DAG   the DAG used to query target lowering and build new nodes.
+/// \param Dl    debug location attached to any node created here.
+/// \returns the new masked store, or an empty SDValue if the pattern does not
+///          match or one of the legality checks below fails.
+static SDValue foldToMaskedStore(StoreSDNode *Store, SelectionDAG &DAG,
+                                 const SDLoc &Dl) {
+  if (!Store->isSimple() || !ISD::isNormalStore(Store))
+    return SDValue();
+
+  SDValue StoredVal = Store->getValue();
+  SDValue StorePtr = Store->getBasePtr();
+  SDValue StoreOffset = Store->getOffset();
+  EVT VT = Store->getMemoryVT();
+  unsigned AddrSpace = Store->getAddressSpace();
+  Align Alignment = Store->getAlign();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  if (!TLI.isOperationLegalOrCustom(ISD::MSTORE, VT) ||
+      !TLI.allowsMisalignedMemoryAccesses(VT, AddrSpace, Alignment))
+    return SDValue();
+
+  // LoadPos records which vselect operand is the load: 2 means the load is
+  // the false operand, 1 means it is the true operand (mask must be inverted).
+  SDValue Mask, OtherVec, LoadCh;
+  unsigned LoadPos;
+  if (sd_match(StoredVal,
+               m_VSelect(m_Value(Mask), m_Value(OtherVec),
+                         m_Load(m_Value(LoadCh), m_Specific(StorePtr),
+                                m_Specific(StoreOffset))))) {
+    LoadPos = 2;
+  } else if (sd_match(StoredVal,
+                      m_VSelect(m_Value(Mask),
+                                m_Load(m_Value(LoadCh), m_Specific(StorePtr),
+                                       m_Specific(StoreOffset)),
+                                m_Value(OtherVec)))) {
+    LoadPos = 1;
+  } else {
+    return SDValue();
+  }
+
+  // The load is matched only by pointer and offset, so also require that it
+  // reads from the same address space the store writes to.
+  auto *Load = cast<LoadSDNode>(StoredVal.getOperand(LoadPos));
+  if (!Load->isSimple() || !ISD::isNormalLoad(Load) ||
+      Load->getAddressSpace() != AddrSpace)
+    return SDValue();
+
+  if (!Store->getChain().reachesChainWithoutSideEffects(LoadCh))
+    return SDValue();
+
+  // Invert the mask only once every check has passed, so that a rejected
+  // combine does not leave a freshly created NOT node behind.
+  if (LoadPos == 1)
+    Mask = DAG.getNOT(Dl, Mask, Mask.getValueType());
+
+  return DAG.getMaskedStore(Store->getChain(), Dl, OtherVec, StorePtr,
+                            StoreOffset, Mask, VT, Store->getMemOperand(),
+                            Store->getAddressingMode());
+}
+
 SDValue DAGCombiner::visitSTORE(SDNode *N) {
   StoreSDNode *ST  = cast<StoreSDNode>(N);
   SDValue Chain = ST->getChain();
   SDValue Value = ST->getValue();
   SDValue Ptr   = ST->getBasePtr();
 
+  if (SDValue MaskedStore = foldToMaskedStore(ST, DAG, SDLoc(N)))
+    return MaskedStore;
----------------
arsenm wrote:

Can you move this further down in the function? This isn't going to be one of the more basic, common scenarios.

https://github.com/llvm/llvm-project/pull/145176


More information about the llvm-commits mailing list