[llvm] [X86] Combine `store + vselect` to `masked_store` (PR #145176)
Abhishek Kaushik via llvm-commits
llvm-commits at lists.llvm.org
Sun Jun 22 22:32:17 PDT 2025
================
@@ -53403,6 +53404,76 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+static SDValue foldToMaskedStore(StoreSDNode *Store, SelectionDAG &DAG,
+ const SDLoc &Dl,
+ const X86Subtarget &Subtarget) {
+ using namespace llvm::SDPatternMatch;
+
+ if (!Subtarget.hasAVX() && !Subtarget.hasAVX2() && !Subtarget.hasAVX512())
+ return SDValue();
+
+ if (!Store->isSimple() || Store->isTruncatingStore())
+ return SDValue();
+
+ SDValue StoredVal = Store->getValue();
+ SDValue StorePtr = Store->getBasePtr();
+ SDValue StoreOffset = Store->getOffset();
+ EVT VT = Store->getMemoryVT();
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+ if (!TLI.isTypeLegal(VT) || !TLI.isOperationLegalOrCustom(ISD::MSTORE, VT))
+ return SDValue();
+
+ SDValue Mask, TrueVec, LoadCh;
+ if (!sd_match(StoredVal,
+ m_VSelect(m_Value(Mask), m_Value(TrueVec),
+ m_Load(m_Value(LoadCh), m_Specific(StorePtr),
+ m_Specific(StoreOffset)))))
+ return SDValue();
+
+ LoadSDNode *Load = cast<LoadSDNode>(StoredVal.getOperand(2));
+ if (!Load || !Load->isSimple())
+ return SDValue();
+
+ auto IsSafeToFold = [](StoreSDNode *Store, LoadSDNode *Load) {
+ std::queue<SDValue> Worklist;
+
+ Worklist.push(Store->getChain());
+
+ while (!Worklist.empty()) {
+ SDValue Chain = Worklist.front();
+ Worklist.pop();
+
+ SDNode *Node = Chain.getNode();
+ if (!Node)
+ return false;
+
+ if (Node == Load)
+ return true;
+
+ if (const auto *MemNode = dyn_cast<MemSDNode>(Node))
+ if (!MemNode->isSimple() || MemNode->writeMem())
+ return false;
+
+ if (Node->getOpcode() == ISD::TokenFactor) {
----------------
abhishek-kaushik22 wrote:
```
; Two independent load -> select -> store sequences, one on %ptr1 and one on
; %ptr2. Each store writes back to the same pointer its select's false operand
; was loaded from. Both loads are issued before either store.
define void @test_masked_store_multiple(<8 x i32> %x, <8 x i32> %y, ptr %ptr1, ptr %ptr2, <8 x i1> %cmp, <8 x i1> %cmp2) {
; Load the current contents of both destinations.
%load = load <8 x i32>, ptr %ptr1, align 32
%load2 = load <8 x i32>, ptr %ptr2, align 32
; Per lane: take the new value (%x / %y) where the mask bit is set, otherwise
; keep the value just loaded from memory.
%sel = select <8 x i1> %cmp, <8 x i32> %x, <8 x i32> %load
%sel2 = select <8 x i1> %cmp2, <8 x i32> %y, <8 x i32> %load2
; Store each blended vector back to the address it was loaded from.
store <8 x i32> %sel, ptr %ptr1, align 32
store <8 x i32> %sel2, ptr %ptr2, align 32
ret void
}
```
This test generates a DAG that contains a `TokenFactor` node; see https://godbolt.org/z/zdTce1xYE
https://github.com/llvm/llvm-project/pull/145176
More information about the llvm-commits
mailing list