[llvm] [AArch64] Improve lowering for scalable masked interleaving stores (PR #156718)

Cullen Rhodes via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 4 02:19:18 PDT 2025


================
@@ -24632,6 +24632,106 @@ static SDValue performSTORECombine(SDNode *N,
   return SDValue();
 }
 
+static bool
+isSequentialConcatOfVectorInterleave(SDNode *N, SmallVectorImpl<SDValue> &Ops) {
+  if (N->getOpcode() != ISD::CONCAT_VECTORS)
+    return false;
+
+  unsigned NumParts = N->getNumOperands();
+
+  // We should be concatenating each sequential result from a
+  // VECTOR_INTERLEAVE.
+  SDNode *InterleaveOp = N->getOperand(0).getNode();
+  if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
+      InterleaveOp->getNumOperands() != NumParts)
+    return false;
+
+  for (unsigned I = 0; I < NumParts; I++) {
+    if (N->getOperand(I) != SDValue(InterleaveOp, I))
+      return false;
+  }
+
+  Ops.append(InterleaveOp->op_begin(), InterleaveOp->op_end());
+  return true;
+}
+
+static SDValue getNarrowMaskForInterleavedOps(SelectionDAG &DAG, SDLoc &DL,
+                                              SDValue WideMask,
+                                              unsigned RequiredNumParts) {
+  if (WideMask->getOpcode() == ISD::CONCAT_VECTORS) {
+    SmallVector<SDValue, 4> MaskInterleaveOps;
+    if (!isSequentialConcatOfVectorInterleave(WideMask.getNode(),
+                                              MaskInterleaveOps))
+      return SDValue();
+
+    if (MaskInterleaveOps.size() != RequiredNumParts)
+      return SDValue();
+
+    // Make sure the inputs to the vector interleave are identical.
+    if (!llvm::all_equal(MaskInterleaveOps))
+      return SDValue();
+
+    return MaskInterleaveOps[0];
+  } else if (WideMask->getOpcode() == ISD::SPLAT_VECTOR) {
+    ElementCount EC = WideMask.getValueType().getVectorElementCount();
+    assert(EC.isKnownMultipleOf(RequiredNumParts) &&
+           "Expected element count divisible by number of parts");
+    EC = EC.divideCoefficientBy(RequiredNumParts);
+    return DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
+                       WideMask->getOperand(0));
+  }
+  return SDValue();
+}
+
+static SDValue
+performStoreInterleaveCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                              SelectionDAG &DAG) {
+  if (!DCI.isBeforeLegalize())
+    return SDValue();
+
+  MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
+  SDValue WideValue = MST->getValue();
+
+  // Bail out if the stored value has an unexpected number of uses, since we'll
+  // have to perform manual interleaving and may as well just use normal masked
+  // stores. Also, discard masked stores that are truncating or indexed.
+  if (!WideValue.hasOneUse() || !ISD::isNormalMaskedStore(MST) ||
+      !MST->getOffset().isUndef())
+    return SDValue();
+
+  SmallVector<SDValue, 4> ValueInterleaveOps;
+  if (!isSequentialConcatOfVectorInterleave(WideValue.getNode(),
+                                            ValueInterleaveOps))
+    return SDValue();
+
+  unsigned NumParts = ValueInterleaveOps.size();
+  if (NumParts != 2 && NumParts != 4)
+    return SDValue();
+
+  // At the moment we're unlikely to see a fixed-width vector interleave as
+  // we usually generate shuffles instead.
+  EVT SubVecTy = ValueInterleaveOps[0].getValueType();
+  if (!SubVecTy.isScalableVT() ||
+      SubVecTy.getSizeInBits().getKnownMinValue() != 128 ||
+      !DAG.getTargetLoweringInfo().isTypeLegal(SubVecTy))
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue NarrowMask =
+      getNarrowMaskForInterleavedOps(DAG, DL, MST->getMask(), NumParts);
+  if (!NarrowMask)
+    return SDValue();
+
+  const Intrinsic::ID IID =
+      NumParts == 2 ? Intrinsic::aarch64_sve_st2 : Intrinsic::aarch64_sve_st4;
+  SDValue Res;
----------------
c-rhodes wrote:

unused

https://github.com/llvm/llvm-project/pull/156718
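
For context, the hunk is cut off at the review marker, so whatever follows `SDValue Res;` is not shown. One way the rest of the combine could emit the structured store is to build the INTRINSIC_VOID node and return it directly, which needs no separate `Res` variable. The sketch below is illustrative only, not the patch's actual code, and assumes the usual SVE st2/st4 operand order (data parts, then predicate, then base pointer):

  // Illustrative sketch, not the patch's actual code: build the operand
  // list for the SVE st2/st4 intrinsic and return the resulting chain.
  SmallVector<SDValue, 8> Ops;
  Ops.push_back(MST->getChain());
  Ops.push_back(DAG.getTargetConstant(IID, DL, MVT::i64));
  Ops.append(ValueInterleaveOps.begin(), ValueInterleaveOps.end());
  Ops.push_back(NarrowMask);
  Ops.push_back(MST->getBasePtr());
  return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops);

In real code a getMemIntrinsicNode call carrying the store's MachineMemOperand would likely be preferable, but the point stands that the new node can be returned straight to the combiner, replacing the masked store's chain.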

