[llvm] [AArch64] Lower alias mask to a whilewr (PR #100769)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 1 04:10:46 PDT 2024
================
@@ -13782,8 +13783,128 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
return ResultSLI;
}
+/// Try to lower the construction of a pointer alias mask to a WHILEWR.
+/// The mask's enabled lanes represent the elements that will not overlap across
+/// one loop iteration. This tries to match:
+/// or (splat (setcc_lt (sub ptrA, ptrB), -(element_size - 1))),
+/// (get_active_lane_mask 0, (div (sub ptrA, ptrB), element_size))
+SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG) {
+ if (!DAG.getSubtarget<AArch64Subtarget>().hasSVE2())
+ return SDValue();
+ SDValue LaneMask = Op.getOperand(0);
+ SDValue Splat = Op.getOperand(1);
+
+ if (Splat.getOpcode() != ISD::SPLAT_VECTOR)
+ std::swap(LaneMask, Splat);
+
+ if (LaneMask.getOpcode() != ISD::INTRINSIC_WO_CHAIN ||
+ LaneMask.getConstantOperandVal(0) != Intrinsic::get_active_lane_mask ||
+ Splat.getOpcode() != ISD::SPLAT_VECTOR)
+ return SDValue();
+
+ SDValue Cmp = Splat.getOperand(0);
+ if (Cmp.getOpcode() != ISD::SETCC)
+ return SDValue();
+
+ CondCodeSDNode *Cond = cast<CondCodeSDNode>(Cmp.getOperand(2));
+
+ auto ComparatorConst = dyn_cast<ConstantSDNode>(Cmp.getOperand(1));
+ if (!ComparatorConst || ComparatorConst->getSExtValue() > 0 ||
+ Cond->get() != ISD::CondCode::SETLT)
+ return SDValue();
+ unsigned CompValue = std::abs(ComparatorConst->getSExtValue());
+ unsigned EltSize = CompValue + 1;
+ if (!isPowerOf2_64(EltSize) || EltSize > 8)
+ return SDValue();
+
+ SDValue Diff = Cmp.getOperand(0);
+ if (Diff.getOpcode() != ISD::SUB || Diff.getValueType() != MVT::i64)
+ return SDValue();
+
+ auto LaneMaskConst = dyn_cast<ConstantSDNode>(LaneMask.getOperand(1));
+ if (!LaneMaskConst || LaneMaskConst->getZExtValue() != 0 ||
+ (EltSize != 1 && LaneMask.getOperand(2).getOpcode() != ISD::SRA))
+ return SDValue();
+
+ // The number of elements that alias is calculated by dividing the positive
+ // difference between the pointers by the element size. An alias mask for i8
+ // elements omits the division because it would just divide by 1
+ if (EltSize > 1) {
+ SDValue DiffDiv = LaneMask.getOperand(2);
+ auto DiffDivConst = dyn_cast<ConstantSDNode>(DiffDiv.getOperand(1));
+ if (!DiffDivConst || DiffDivConst->getZExtValue() != Log2_64(EltSize))
+ return SDValue();
+ if (EltSize > 2) {
+ // When masking i32 or i64 elements, the positive value of the
+ // possibly-negative difference comes from a select of the difference if
+ // it's positive, otherwise the difference plus the element size if it's
+ // negative: pos_diff = diff < 0 ? (diff + 7) : diff
+ SDValue Select = DiffDiv.getOperand(0);
+ // Make sure the difference is being compared by the select
+ if (Select.getOpcode() != ISD::SELECT_CC || Select.getOperand(3) != Diff)
+ return SDValue();
+ // Make sure it's checking if the difference is less than 0
+ if (auto *SelectConst = dyn_cast<ConstantSDNode>(Select.getOperand(1));
+ !SelectConst || SelectConst->getZExtValue() != 0 ||
----------------
davemgreen wrote:
isNullConstant(Select.getOperand(1))
https://github.com/llvm/llvm-project/pull/100769
More information about the llvm-commits mailing list