[llvm] [AArch64] Improve lowering for scalable masked interleaving stores (PR #156718)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 4 10:06:54 PDT 2025


================
@@ -24632,6 +24632,104 @@ static SDValue performSTORECombine(SDNode *N,
   return SDValue();
 }
 
+static bool
+isSequentialConcatOfVectorInterleave(SDNode *N, SmallVectorImpl<SDValue> &Ops) {
+  if (N->getOpcode() != ISD::CONCAT_VECTORS)
+    return false;
+
+  unsigned NumParts = N->getNumOperands();
+
+  // We should be concatenating each sequential result from a
+  // VECTOR_INTERLEAVE.
+  SDNode *InterleaveOp = N->getOperand(0).getNode();
+  if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
+      InterleaveOp->getNumOperands() != NumParts)
+    return false;
+
+  for (unsigned I = 0; I < NumParts; I++) {
+    if (N->getOperand(I) != SDValue(InterleaveOp, I))
+      return false;
+  }
+
+  Ops.append(InterleaveOp->op_begin(), InterleaveOp->op_end());
+  return true;
+}
+
+static SDValue getNarrowMaskForInterleavedOps(SelectionDAG &DAG, SDLoc &DL,
+                                              SDValue WideMask,
+                                              unsigned RequiredNumParts) {
+  if (WideMask->getOpcode() == ISD::CONCAT_VECTORS) {
+    SmallVector<SDValue, 4> MaskInterleaveOps;
+    if (!isSequentialConcatOfVectorInterleave(WideMask.getNode(),
+                                              MaskInterleaveOps))
+      return SDValue();
+
+    if (MaskInterleaveOps.size() != RequiredNumParts)
+      return SDValue();
+
+    // Make sure the inputs to the vector interleave are identical.
+    if (!llvm::all_equal(MaskInterleaveOps))
+      return SDValue();
+
+    return MaskInterleaveOps[0];
+  } else if (WideMask->getOpcode() == ISD::SPLAT_VECTOR) {
----------------
paulwalker-arm wrote:

The LLVM Coding Standards recommend not using `else` after a `return`.

```suggestion
  }

  if (WideMask->getOpcode() == ISD::SPLAT_VECTOR) {
```

or perhaps even invert the check into an early return:

```suggestion
  }

  if (WideMask->getOpcode() != ISD::SPLAT_VECTOR)
    return SDValue();
```

https://github.com/llvm/llvm-project/pull/156718


More information about the llvm-commits mailing list