[llvm] [AArch64] Improve lowering for scalable masked deinterleaving loads (PR #154338)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 29 06:55:46 PDT 2025


================
@@ -27015,6 +27016,120 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   return NVCAST;
 }
 
+static SDValue performVectorDeinterleaveCombine(
+    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
+  unsigned NumParts = N->getNumOperands();
+  if (NumParts != 2 && NumParts != 4)
+    return SDValue();
+
+  EVT SubVecTy = N->getValueType(0);
+
+  // At the moment we're unlikely to see a fixed-width vector deinterleave, as
+  // we usually generate shuffles instead.
+  unsigned MinNumElements = SubVecTy.getVectorMinNumElements();
+  if (!SubVecTy.isScalableVT() ||
+      SubVecTy.getSizeInBits().getKnownMinValue() != 128 || MinNumElements == 1)
+    return SDValue();
+
+  // Make sure each input operand is the correct extract_subvector of the same
+  // wider vector.
+  SDValue Op0 = N->getOperand(0);
+  for (unsigned I = 0; I < NumParts; I++) {
+    SDValue OpI = N->getOperand(I);
+    if (OpI->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+        OpI->getOperand(0) != Op0->getOperand(0))
+      return SDValue();
+    auto *Idx = cast<ConstantSDNode>(OpI->getOperand(1));
+    if (Idx->getZExtValue() != (I * MinNumElements))
+      return SDValue();
+  }
+
+  // Normal loads are already handled by the InterleavedAccessPass, so we
+  // don't expect to see them here. Bail out if the masked load has an
+  // unexpected number of uses, since we want to avoid ending up with both
+  // deinterleaving loads and normal loads in the same block. Also, discard
+  // masked loads that are extending, indexed, have an unexpected offset, or
+  // have an unsupported passthru value until we find a valid use case.
+  auto *MaskedLoad = dyn_cast<MaskedLoadSDNode>(Op0->getOperand(0));
+  if (!MaskedLoad || !MaskedLoad->hasNUsesOfValue(NumParts, 0) ||
+      MaskedLoad->getExtensionType() != ISD::NON_EXTLOAD ||
+      MaskedLoad->getAddressingMode() != ISD::UNINDEXED ||
+      !MaskedLoad->getOffset().isUndef() ||
+      (!MaskedLoad->getPassThru()->isUndef() &&
+       !isZerosVector(MaskedLoad->getPassThru().getNode())))
+    return SDValue();
+
+  // Now prove that the mask is an interleave of identical masks.
+  SDValue Mask = MaskedLoad->getMask();
+  if (Mask->getOpcode() != ISD::SPLAT_VECTOR &&
+      Mask->getOpcode() != ISD::CONCAT_VECTORS)
+    return SDValue();
+
+  SDValue NarrowMask;
+  SDLoc DL(N);
+  if (Mask->getOpcode() == ISD::CONCAT_VECTORS) {
+    if (Mask->getNumOperands() != NumParts)
+      return SDValue();
+
+    // We should be concatenating, in order, the results of a single
+    // VECTOR_INTERLEAVE node.
+    SDValue InterleaveOp = Mask->getOperand(0);
+    if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
+        InterleaveOp->getNumOperands() != NumParts)
+      return SDValue();
+
+    for (unsigned I = 0; I < NumParts; I++) {
+      SDValue ConcatOp = Mask->getOperand(I);
+      if (ConcatOp.getResNo() != I ||
+          ConcatOp.getNode() != InterleaveOp.getNode())
+        return SDValue();
+    }
+
+    // Make sure the inputs to the vector interleave are identical.
+    for (unsigned I = 1; I < NumParts; I++) {
+      if (InterleaveOp->getOperand(I) != InterleaveOp->getOperand(0))
+        return SDValue();
+    }
+
+    NarrowMask = InterleaveOp->getOperand(0);
+  } else { // ISD::SPLAT_VECTOR
+    auto *SplatVal = dyn_cast<ConstantSDNode>(Mask->getOperand(0));
+    if (!SplatVal || SplatVal->getZExtValue() != 1)
+      return SDValue();
----------------
paulwalker-arm wrote:

I would expect the all-false case to be caught by instcombine, so there would be no need for the code generator to worry about it.
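
For reference, instcombine already folds a masked load whose mask is constant
all-false into its passthru operand, so a splat-of-false mask should never
reach this DAG combine. A minimal sketch of that fold (illustrative IR; the
names and types are assumptions, not taken from the patch):

    ; Before instcombine: all-false mask, zero passthru.
    %v = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32.p0(
             ptr %p, i32 4, <vscale x 4 x i1> zeroinitializer,
             <vscale x 4 x i32> zeroinitializer)
    ; After instcombine: the call is removed and %v is replaced by the
    ; passthru value (zeroinitializer here), so the combine only needs to
    ; accept the all-true splat.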

https://github.com/llvm/llvm-project/pull/154338
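
For context, the pattern the new combine matches corresponds to IR along the
lines of the sketch below: a vector deinterleave applied to a masked load
whose mask was built by interleaving one predicate with itself (by the time
it reaches the DAG this appears as a VECTOR_DEINTERLEAVE of extract_subvectors
of the load, which is what the checks above look for). The types, names and
alignment in the sketch are assumptions, not taken from the patch:

    %mask = call <vscale x 32 x i1> @llvm.vector.interleave2.nxv32i1(
                <vscale x 16 x i1> %m, <vscale x 16 x i1> %m)
    %wide = call <vscale x 32 x i8> @llvm.masked.load.nxv32i8.p0(
                ptr %p, i32 1, <vscale x 32 x i1> %mask,
                <vscale x 32 x i8> zeroinitializer)
    %parts = call { <vscale x 16 x i8>, <vscale x 16 x i8> }
                 @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %wide)

Proving the mask is an interleave of identical predicates is what lets the
whole sequence be treated as a single masked deinterleaving load.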


More information about the llvm-commits mailing list