[llvm] [AArch64] Lower alias mask to a whilewr (PR #100769)

David Green via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 1 04:10:46 PDT 2024


================
@@ -13782,8 +13783,128 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
   return ResultSLI;
 }
 
+/// Try to lower the construction of a pointer alias mask to a WHILEWR.
+/// The mask's enabled lanes represent the elements that will not overlap across
+/// one loop iteration. This tries to match:
+/// or (splat (setcc_lt (sub ptrA, ptrB), -(element_size - 1))),
+///    (get_active_lane_mask 0, (div (sub ptrA, ptrB), element_size))
+///
+/// \param Op  The OR node to inspect; its two operands may appear in either
+///            order.
+/// \param DAG The DAG used to query the subtarget and to build the result.
+/// \returns An aarch64_sve_whilewr_{b,h,s,d} intrinsic node whose operands are
+///          the two operands of the pointer subtraction, or an empty SDValue
+///          if the subtarget lacks SVE2 or the pattern does not match.
+SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG) {
+  if (!DAG.getSubtarget<AArch64Subtarget>().hasSVE2())
+    return SDValue();
+  SDValue LaneMask = Op.getOperand(0);
+  SDValue Splat = Op.getOperand(1);
+
+  // OR is commutative: normalise so that Splat names the splat operand.
+  if (Splat.getOpcode() != ISD::SPLAT_VECTOR)
+    std::swap(LaneMask, Splat);
+
+  if (LaneMask.getOpcode() != ISD::INTRINSIC_WO_CHAIN ||
+      LaneMask.getConstantOperandVal(0) != Intrinsic::get_active_lane_mask ||
+      Splat.getOpcode() != ISD::SPLAT_VECTOR)
+    return SDValue();
+
+  SDValue Cmp = Splat.getOperand(0);
+  if (Cmp.getOpcode() != ISD::SETCC)
+    return SDValue();
+
+  CondCodeSDNode *Cond = cast<CondCodeSDNode>(Cmp.getOperand(2));
+
+  auto *ComparatorConst = dyn_cast<ConstantSDNode>(Cmp.getOperand(1));
+  if (!ComparatorConst || Cond->get() != ISD::CondCode::SETLT)
+    return SDValue();
+
+  // The splat compares the difference against -(element_size - 1), so only
+  // constants in [-7, 0] can describe a supported element size. Range-check the
+  // signed 64-bit value before narrowing it: storing its absolute value in an
+  // unsigned would silently truncate large magnitudes (-(2^32 + 7) would be
+  // read as 7) and taking the absolute value of INT64_MIN is undefined.
+  int64_t CompValue = ComparatorConst->getSExtValue();
+  if (CompValue > 0 || CompValue < -7)
+    return SDValue();
+  unsigned EltSize = static_cast<unsigned>(-CompValue) + 1;
+  if (!isPowerOf2_64(EltSize))
+    return SDValue();
+
+  SDValue Diff = Cmp.getOperand(0);
+  if (Diff.getOpcode() != ISD::SUB || Diff.getValueType() != MVT::i64)
+    return SDValue();
+
+  // The lane mask must start at 0 and, for elements wider than a byte, its
+  // upper bound must be an arithmetic shift right (the division).
+  auto *LaneMaskConst = dyn_cast<ConstantSDNode>(LaneMask.getOperand(1));
+  if (!LaneMaskConst || LaneMaskConst->getZExtValue() != 0 ||
+      (EltSize != 1 && LaneMask.getOperand(2).getOpcode() != ISD::SRA))
+    return SDValue();
+
+  // The number of elements that alias is calculated by dividing the positive
+  // difference between the pointers by the element size. An alias mask for i8
+  // elements omits the division because it would just divide by 1.
+  if (EltSize > 1) {
+    SDValue DiffDiv = LaneMask.getOperand(2);
+    auto *DiffDivConst = dyn_cast<ConstantSDNode>(DiffDiv.getOperand(1));
+    if (!DiffDivConst || DiffDivConst->getZExtValue() != Log2_64(EltSize))
+      return SDValue();
+    if (EltSize > 2) {
+      // When masking i32 or i64 elements, the positive value of the
+      // possibly-negative difference comes from a select of the difference if
+      // it's positive, otherwise the difference plus the element size minus
+      // one if it's negative:
+      //   pos_diff = diff < 0 ? (diff + (element_size - 1)) : diff
+      SDValue Select = DiffDiv.getOperand(0);
+      // Make sure the difference is being compared by the select
+      if (Select.getOpcode() != ISD::SELECT_CC || Select.getOperand(3) != Diff)
+        return SDValue();
+      // Make sure it's checking if the difference is less than 0
+      if (auto *SelectConst = dyn_cast<ConstantSDNode>(Select.getOperand(1));
+          !SelectConst || SelectConst->getZExtValue() != 0 ||
+          cast<CondCodeSDNode>(Select.getOperand(4))->get() !=
+              ISD::CondCode::SETLT)
+        return SDValue();
+      // An add creates a positive value from the negative difference
+      SDValue Add = Select.getOperand(2);
+      if (Add.getOpcode() != ISD::ADD || Add.getOperand(0) != Diff)
+        return SDValue();
+      if (auto *AddConst = dyn_cast<ConstantSDNode>(Add.getOperand(1));
+          !AddConst || AddConst->getZExtValue() != EltSize - 1)
+        return SDValue();
+    } else {
+      // When masking i16 elements, this positive value comes from adding the
+      // difference's sign bit to the difference itself. This is equivalent to
+      // the 32 bit and 64 bit case: pos_diff = diff + sign_bit (diff)
+      SDValue Add = DiffDiv.getOperand(0);
+      if (Add.getOpcode() != ISD::ADD || Add.getOperand(0) != Diff)
+        return SDValue();
+      // A logical right shift by 63 extracts the sign bit from the difference
+      SDValue Shift = Add.getOperand(1);
+      if (Shift.getOpcode() != ISD::SRL || Shift.getOperand(0) != Diff)
+        return SDValue();
+      if (auto *ShiftConst = dyn_cast<ConstantSDNode>(Shift.getOperand(1));
+          !ShiftConst || ShiftConst->getZExtValue() != 63)
+        return SDValue();
+    }
+  } else if (LaneMask.getOperand(2) != Diff)
+    return SDValue();
+
+  // The subtraction's operands become the operands of the whilewr intrinsic.
+  SDValue StorePtr = Diff.getOperand(0);
+  SDValue ReadPtr = Diff.getOperand(1);
+
+  // Pick the whilewr variant matching the element size in bytes.
+  unsigned IntrinsicID = 0;
+  switch (EltSize) {
+  case 1:
+    IntrinsicID = Intrinsic::aarch64_sve_whilewr_b;
+    break;
+  case 2:
+    IntrinsicID = Intrinsic::aarch64_sve_whilewr_h;
+    break;
+  case 4:
+    IntrinsicID = Intrinsic::aarch64_sve_whilewr_s;
+    break;
+  case 8:
+    IntrinsicID = Intrinsic::aarch64_sve_whilewr_d;
+    break;
+  default:
+    return SDValue();
+  }
+  SDLoc DL(Op);
+  SDValue ID = DAG.getConstant(IntrinsicID, DL, MVT::i32);
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(), ID,
+                     StorePtr, ReadPtr);
+}
+
 SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
                                              SelectionDAG &DAG) const {
+
+  if (SDValue SV = tryWhileWRFromOR(Op, DAG))
+    return SV;
----------------
davemgreen wrote:

I would personally remove the first (blank) line, but add a blank line between this and the `useSVEForFixedLengthVectorVT` `if` statement below it.

https://github.com/llvm/llvm-project/pull/100769


More information about the llvm-commits mailing list