[llvm-branch-commits] [llvm] [AArch64] Split large loop dependence masks (PR #153187)

Benjamin Maxwell via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Aug 20 08:22:23 PDT 2025


================
@@ -5248,49 +5248,94 @@ AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op,
                                                  SelectionDAG &DAG) const {
   SDLoc DL(Op);
   uint64_t EltSize = Op.getConstantOperandVal(2);
-  EVT VT = Op.getValueType();
+  EVT FullVT = Op.getValueType();
+  unsigned NumElements = FullVT.getVectorMinNumElements();
+  unsigned NumSplits = 0;
+  EVT EltVT;
   switch (EltSize) {
   case 1:
-    if (VT != MVT::v16i8 && VT != MVT::nxv16i1)
-      return SDValue();
+    EltVT = MVT::i8;
     break;
   case 2:
-    if (VT != MVT::v8i8 && VT != MVT::nxv8i1)
-      return SDValue();
+    if (NumElements >= 16)
+      NumSplits = NumElements / 16;
+    EltVT = MVT::i16;
     break;
   case 4:
-    if (VT != MVT::v4i16 && VT != MVT::nxv4i1)
-      return SDValue();
+    if (NumElements >= 8)
+      NumSplits = NumElements / 8;
+    EltVT = MVT::i32;
     break;
   case 8:
-    if (VT != MVT::v2i32 && VT != MVT::nxv2i1)
-      return SDValue();
+    if (NumElements >= 4)
+      NumSplits = NumElements / 4;
+    EltVT = MVT::i64;
     break;
   default:
     // Other element sizes are incompatible with whilewr/rw, so expand instead
     return SDValue();
   }
 
-  SDValue PtrA = Op.getOperand(0);
-  SDValue PtrB = Op.getOperand(1);
+  auto LowerToWhile = [&](EVT VT, unsigned AddrScale) {
+    SDValue PtrA = Op.getOperand(0);
+    SDValue PtrB = Op.getOperand(1);
 
-  if (VT.isScalableVT())
-    return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, Op.getOperand(2));
+    EVT StoreVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
+                                   VT.getVectorMinNumElements(), false);
+    if (AddrScale > 0) {
+      unsigned Offset = StoreVT.getStoreSizeInBits() / 8 * AddrScale;
+      SDValue Addend;
 
-  // We can use the SVE whilewr/whilerw instruction to lower this
-  // intrinsic by creating the appropriate sequence of scalable vector
-  // operations and then extracting a fixed-width subvector from the scalable
-  // vector. Scalable vector variants are already legal.
-  EVT ContainerVT =
-      EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
-                       VT.getVectorNumElements(), true);
-  EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
+      if (VT.isScalableVT())
+        Addend = DAG.getVScale(DL, MVT::i64, APInt(64, Offset));
+      else
+        Addend = DAG.getConstant(Offset, DL, MVT::i64);
 
-  SDValue Mask =
-      DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, Op.getOperand(2));
-  SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
-  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
-                     DAG.getVectorIdxConstant(0, DL));
+      PtrA = DAG.getNode(ISD::ADD, DL, MVT::i64, PtrA, Addend);
+      PtrB = DAG.getNode(ISD::ADD, DL, MVT::i64, PtrB, Addend);
+    }
+
+    if (VT.isScalableVT())
+      return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, Op.getOperand(2));
+
+    // We can use the SVE whilewr/whilerw instruction to lower this
+    // intrinsic by creating the appropriate sequence of scalable vector
+    // operations and then extracting a fixed-width subvector from the scalable
+    // vector. Scalable vector variants are already legal.
+    EVT ContainerVT =
+        EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
+                         VT.getVectorNumElements(), true);
+    EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
+
+    SDValue Mask =
+        DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, Op.getOperand(2));
+    SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
+                       DAG.getVectorIdxConstant(0, DL));
+  };
+
+  if (NumSplits == 0)
+    return LowerToWhile(FullVT, 0);
+
+  SDValue FullVec = DAG.getUNDEF(FullVT);
+
+  unsigned NumElementsPerSplit = NumElements / (2 * NumSplits);
+  EVT PartVT =
+      EVT::getVectorVT(*DAG.getContext(), FullVT.getVectorElementType(),
+                       NumElementsPerSplit, FullVT.isScalableVT());
+  for (unsigned Split = 0, InsertIdx = 0; Split < NumSplits;
+       Split++, InsertIdx += 2) {
+    SDValue Low = LowerToWhile(PartVT, InsertIdx);
+    SDValue High = LowerToWhile(PartVT, InsertIdx + 1);
+    unsigned InsertIdxLow = InsertIdx * NumElementsPerSplit;
+    unsigned InsertIdxHigh = (InsertIdx + 1) * NumElementsPerSplit;
+    SDValue Insert =
+        DAG.getNode(ISD::INSERT_SUBVECTOR, DL, FullVT, FullVec, Low,
+                    DAG.getVectorIdxConstant(InsertIdxLow, DL));
+    FullVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, FullVT, Insert, High,
+                          DAG.getVectorIdxConstant(InsertIdxHigh, DL));
+  }
----------------
MacDue wrote:

Following on from my first suggestion (which adds `NumWhiles`), this can be simplified to:
```cpp
  if (NumWhiles <= 1)
    return LowerToWhile(FullVT, 0);

  unsigned NumElementsPerSplit = NumElements / NumWhiles;
  EVT PartVT =
      EVT::getVectorVT(*DAG.getContext(), FullVT.getVectorElementType(),
                       NumElementsPerSplit, FullVT.isScalableVT());
  SDValue FullVec = DAG.getUNDEF(FullVT);
  for (unsigned I = 0; I < NumWhiles; I++) {
    SDValue While = LowerToWhile(PartVT, I);
    unsigned InsertIdx = I * NumElementsPerSplit;
    FullVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, FullVT, FullVec, While,
                          DAG.getVectorIdxConstant(InsertIdx, DL));
  }
```

https://github.com/llvm/llvm-project/pull/153187


More information about the llvm-branch-commits mailing list