[llvm] [AArch64] Lower alias mask to a whilewr (PR #100769)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 31 06:37:42 PDT 2024


================
@@ -13782,8 +13784,89 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
   return ResultSLI;
 }
 
+/// Try to lower the construction of a pointer alias mask to a WHILEWR.
+/// The mask's enabled lanes represent the elements that will not overlap across
+/// one loop iteration. This tries to match: or (splat (setcc_lt (sub ptrA,
+/// ptrB), -(element_size - 1))),
+///    (get_active_lane_mask 0, (div (sub ptrA, ptrB), element_size))
+SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG) {
+  if (!DAG.getSubtarget<AArch64Subtarget>().hasSVE2())
+    return SDValue();
+  auto LaneMask = Op.getOperand(0);
+  auto Splat = Op.getOperand(1);
+
+  if (LaneMask.getOpcode() != ISD::INTRINSIC_WO_CHAIN ||
+      LaneMask.getConstantOperandVal(0) != Intrinsic::get_active_lane_mask ||
+      Splat.getOpcode() != ISD::SPLAT_VECTOR)
+    return SDValue();
+
+  auto Cmp = Splat.getOperand(0);
+  if (Cmp.getOpcode() != ISD::SETCC)
+    return SDValue();
+
+  CondCodeSDNode *Cond = dyn_cast<CondCodeSDNode>(Cmp.getOperand(2));
+  assert(Cond && "SETCC doesn't have a condition code");
+
+  auto ComparatorConst = dyn_cast<ConstantSDNode>(Cmp.getOperand(1));
+  if (!ComparatorConst || ComparatorConst->getSExtValue() > 0 ||
+      Cond->get() != ISD::CondCode::SETLT)
+    return SDValue();
+  unsigned CompValue = std::abs(ComparatorConst->getSExtValue());
+  unsigned EltSize = CompValue + 1;
+  if (!isPowerOf2_64(EltSize) || EltSize > 64)
+    return SDValue();
+
+  auto Diff = Cmp.getOperand(0);
+  if (Diff.getOpcode() != ISD::SUB || Diff.getValueType() != MVT::i64)
+    return SDValue();
+
+  auto LaneMaskConst = dyn_cast<ConstantSDNode>(LaneMask.getOperand(1));
+  if (!LaneMaskConst || LaneMaskConst->getZExtValue() != 0 ||
+      (EltSize != 1 && LaneMask.getOperand(2).getOpcode() != ISD::SRA))
+    return SDValue();
+
+  // An alias mask for i8 elements omits the division because it would just
+  // divide by 1
+  if (EltSize > 1) {
+    auto DiffDiv = LaneMask.getOperand(2);
+    auto DiffDivConst = dyn_cast<ConstantSDNode>(DiffDiv.getOperand(1));
+    if (!DiffDivConst || DiffDivConst->getZExtValue() != std::log2(EltSize))
----------------
SamTebbs33 wrote:

There is indeed a check of the divide operand missing. It's a bit more involved since `Diff` can be negative, so some extra selects and such are inserted to turn it positive before dividing. The 16-bit element case is a little different from the others and so needs some different matching. I hope I've explained it well enough in the comments in the commit I just pushed to address your suggestion.
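
To illustrate why a negative `Diff` complicates the divide check, here is a minimal scalar sketch. It assumes the extra selects correspond to the usual bias-then-shift expansion of a signed divide by a power of two; the pushed commit may match a different shape. The names `sdivPow2` and `aliasMaskModel` are hypothetical and do not appear in the patch. The mask model follows the doc comment on `tryWhileWRFromOR` and treats the lane-mask comparison as unsigned.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: truncating signed division by 2^Log2Elt, written the
// way it is commonly expanded: select a bias for negative inputs, then shift.
// Assumes >> on a negative int64_t is an arithmetic shift (guaranteed since
// C++20, and the behaviour of mainstream compilers before that).
int64_t sdivPow2(int64_t Diff, unsigned Log2Elt) {
  int64_t Bias = (int64_t(1) << Log2Elt) - 1;
  int64_t Adjusted = Diff < 0 ? Diff + Bias : Diff; // the "select"
  return Adjusted >> Log2Elt;
}

// Hypothetical scalar model of the pattern described in the doc comment:
//   or (splat (Diff < -(EltSize - 1))),
//      (get_active_lane_mask 0, Diff / EltSize)
// EltSize is assumed to be a power of two, as the patch requires.
std::vector<bool> aliasMaskModel(int64_t Diff, unsigned EltSize,
                                 unsigned NumLanes) {
  unsigned Log2Elt = 0;
  while ((1u << Log2Elt) < EltSize)
    ++Log2Elt;
  bool SplatAll = Diff < -int64_t(EltSize - 1);
  // The active-lane-mask comparison is modelled as unsigned: lane I < Limit.
  uint64_t Limit = static_cast<uint64_t>(sdivPow2(Diff, Log2Elt));
  std::vector<bool> Mask(NumLanes);
  for (unsigned I = 0; I < NumLanes; ++I)
    Mask[I] = SplatAll || I < Limit;
  return Mask;
}
```

With this sketch, `sdivPow2(-7, 2)` truncates toward zero, while a bare arithmetic shift of -7 by 2 would round toward negative infinity. That difference is why a single `ISD::SRA` check is not enough to recognise the division.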

https://github.com/llvm/llvm-project/pull/100769

