[llvm] [AArch64] Support lowering smaller than legal LOOP_DEP_MASKs to whilewr/rw (PR #171982)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 12 03:03:54 PST 2025


================
@@ -5439,55 +5439,43 @@ static MVT getSVEContainerType(EVT ContentTy);
 SDValue
 AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op,
                                                  SelectionDAG &DAG) const {
+  assert((Subtarget->hasSVE2() ||
+          (Subtarget->hasSME() && Subtarget->isStreaming())) &&
+         "Lowering loop_dependence_raw_mask or loop_dependence_war_mask "
+         "requires SVE or SME");
+
   SDLoc DL(Op);
   EVT VT = Op.getValueType();
-  SDValue EltSize = Op.getOperand(2);
-  switch (EltSize->getAsZExtVal()) {
-  case 1:
-    if (VT != MVT::v16i8 && VT != MVT::nxv16i1)
-      return SDValue();
-    break;
-  case 2:
-    if (VT != MVT::v8i8 && VT != MVT::nxv8i1)
-      return SDValue();
-    break;
-  case 4:
-    if (VT != MVT::v4i16 && VT != MVT::nxv4i1)
-      return SDValue();
-    break;
-  case 8:
-    if (VT != MVT::v2i32 && VT != MVT::nxv2i1)
-      return SDValue();
-    break;
-  default:
-    // Other element sizes are incompatible with whilewr/rw, so expand instead
-    return SDValue();
-  }
+  unsigned LaneOffset = Op.getConstantOperandVal(3);
+  unsigned NumElements = VT.getVectorMinNumElements();
+  uint64_t EltSizeInBytes = Op.getConstantOperandVal(2);
 
-  SDValue LaneOffset = Op.getOperand(3);
-  if (LaneOffset->getAsZExtVal())
+  // whilewr/rw only handle 1/2/4/8-byte elements and a lane offset of 0.
+  if (LaneOffset != 0 || !is_contained({1u, 2u, 4u, 8u}, EltSizeInBytes))
     return SDValue();
 
-  SDValue PtrA = Op.getOperand(0);
-  SDValue PtrB = Op.getOperand(1);
+  EVT EltVT = MVT::getIntegerVT(EltSizeInBytes * 8);
+  EVT PredVT = getPackedSVEVectorVT(EltVT).changeElementType(MVT::i1);
 
-  if (VT.isScalableVT())
-    return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, EltSize, LaneOffset);
+  // Legal whilewr/rw (lowered by tablegen matcher).
+  if (PredVT == VT)
+    return Op;
 
-  // We can use the SVE whilewr/whilerw instruction to lower this
-  // intrinsic by creating the appropriate sequence of scalable vector
-  // operations and then extracting a fixed-width subvector from the scalable
-  // vector. Scalable vector variants are already legal.
-  EVT ContainerVT =
-      EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
-                       VT.getVectorNumElements(), true);
-  EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
+  // Expand if this mask needs splitting (this will produce a whilelo).
+  if (NumElements > PredVT.getVectorMinNumElements())
+    return SDValue();
 
   SDValue Mask =
-      DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, EltSize, LaneOffset);
-  SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
-  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
-                     DAG.getVectorIdxConstant(0, DL));
+      DAG.getNode(Op.getOpcode(), DL, PredVT, to_vector(Op->op_values()));
+
+  if (VT.isFixedLengthVector()) {
----------------
MacDue wrote:

Sure :+1: 

https://github.com/llvm/llvm-project/pull/171982
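For context when reading the truncated hunk: the review anchor cuts the new code off at the `if (VT.isFixedLengthVector()) {` line. Going by the fixed-length handling this hunk removes (sign-extend the predicate, then extract a fixed-width subvector), a plausible sketch of the function's tail looks like the following. This is illustrative only, not the PR's verbatim code; in particular, narrowing the predicate with EXTRACT_SUBVECTOR before converting it is an assumption.

  // Sketch (assumed, not verbatim): narrow the whilewr/rw predicate to the
  // requested lane count, then sign-extend and extract for fixed-length
  // results.
  EVT NarrowPredVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, NumElements,
                                      /*IsScalable=*/true);
  if (NarrowPredVT != PredVT)
    Mask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowPredVT, Mask,
                       DAG.getVectorIdxConstant(0, DL));

  if (!VT.isFixedLengthVector())
    return Mask;

  // Mirrors the removed fixed-length path: sign-extend the scalable i1 mask
  // into a scalable integer vector, then take the fixed-width low subvector.
  EVT ContainerVT = EVT::getVectorVT(*DAG.getContext(),
                                     VT.getVectorElementType(), NumElements,
                                     /*IsScalable=*/true);
  SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
                     DAG.getVectorIdxConstant(0, DL));

The sign-extend/extract dance is needed because SVE predicate types have no direct fixed-length equivalent, so the i1 mask is first materialized as a scalable integer vector before the fixed-width part is taken from lane 0.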

