[llvm] [AArch64] Combine getActiveLaneMask with vector_extract (PR #81139)

Momchil Velikov via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 1 03:53:17 PST 2024


================
@@ -20004,47 +20004,98 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
   return SDValue();
 }
 
-static SDValue performIntrinsicCombine(SDNode *N,
-                                       TargetLowering::DAGCombinerInfo &DCI,
-                                       const AArch64Subtarget *Subtarget) {
+static SDValue tryCombineGetActiveLaneMask(SDNode *N,
+                                           TargetLowering::DAGCombinerInfo &DCI,
+                                           const AArch64Subtarget *Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
-  unsigned IID = getIntrinsicID(N);
-  switch (IID) {
-  default:
-    break;
-  case Intrinsic::get_active_lane_mask: {
-    SDValue Res = SDValue();
-    EVT VT = N->getValueType(0);
-    if (VT.isFixedLengthVector()) {
-      // We can use the SVE whilelo instruction to lower this intrinsic by
-      // creating the appropriate sequence of scalable vector operations and
-      // then extracting a fixed-width subvector from the scalable vector.
+  EVT VT = N->getValueType(0);
+  if (VT.isFixedLengthVector()) {
+    // We can use the SVE whilelo instruction to lower this intrinsic by
+    // creating the appropriate sequence of scalable vector operations and
+    // then extracting a fixed-width subvector from the scalable vector.
+    SDLoc DL(N);
+    SDValue ID =
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
 
-      SDLoc DL(N);
-      SDValue ID =
-          DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
+    EVT WhileVT =
+        EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                         ElementCount::getScalable(VT.getVectorNumElements()));
 
-      EVT WhileVT = EVT::getVectorVT(
-          *DAG.getContext(), MVT::i1,
-          ElementCount::getScalable(VT.getVectorNumElements()));
+    // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
+    EVT PromVT = getPromotedVTForPredicate(WhileVT);
 
-      // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
-      EVT PromVT = getPromotedVTForPredicate(WhileVT);
+    // Get the fixed-width equivalent of PromVT for extraction.
+    EVT ExtVT =
+        EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
+                         VT.getVectorElementCount());
 
-      // Get the fixed-width equivalent of PromVT for extraction.
-      EVT ExtVT =
-          EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
-                           VT.getVectorElementCount());
+    SDValue Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
+                              N->getOperand(1), N->getOperand(2));
+    Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
+    Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
+                      DAG.getConstant(0, DL, MVT::i64));
+    Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
 
-      Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
-                        N->getOperand(1), N->getOperand(2));
-      Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
-      Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
-                        DAG.getConstant(0, DL, MVT::i64));
-      Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
-    }
     return Res;
   }
+
+  if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
+    return SDValue();
+
+  if (!N->hasNUsesOfValue(2, 0))
+    return SDValue();
+
+  auto It = N->use_begin();
+  SDNode *Lo = *It++;
+  SDNode *Hi = *It;
+
+  const uint64_t HalfSize = VT.getVectorMinNumElements() / 2;
+  uint64_t OffLo, OffHi;
+  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Lo->getOperand(1).getNode(), OffLo) ||
+      (OffLo != 0 && OffLo != HalfSize) ||
+      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Hi->getOperand(1).getNode(), OffHi) ||
+      (OffHi != 0 && OffHi != HalfSize))
+    return SDValue();
+
+  if (OffLo > OffHi) {
+    std::swap(Lo, Hi);
+    std::swap(OffLo, OffHi);
+  }
+
+  if (OffLo != 0 || OffHi != HalfSize)
----------------
momchil-velikov wrote:

Done

https://github.com/llvm/llvm-project/pull/81139


More information about the llvm-commits mailing list