[llvm] [AArch64] Combine getActiveLaneMask with vector_extract (PR #81139)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 8 06:39:19 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Momchil Velikov (momchil-velikov)

<details>
<summary>Changes</summary>

Combine getActiveLaneMask with vector_extract into a `whilelo` instruction with a pair of predicate registers.

---
Full diff: https://github.com/llvm/llvm-project/pull/81139.diff


3 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+86-35) 
- (modified) llvm/test/CodeGen/AArch64/active_lane_mask.ll (+1) 
- (added) llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll (+90) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8573939b04389f..4405e8d3f91df2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1813,8 +1813,8 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
 
 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
                                                           EVT OpVT) const {
-  // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
-  if (!Subtarget->hasSVE())
+  // Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
+  if (!Subtarget->hasSVEorSME())
     return true;
 
   // We can only support legal predicate result types. We can use the SVE
@@ -20004,47 +20004,98 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
   return SDValue();
 }
 
-static SDValue performIntrinsicCombine(SDNode *N,
-                                       TargetLowering::DAGCombinerInfo &DCI,
-                                       const AArch64Subtarget *Subtarget) {
+static SDValue tryCombineGetActiveLaneMask(SDNode *N,
+                                           TargetLowering::DAGCombinerInfo &DCI,
+                                           const AArch64Subtarget *Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
-  unsigned IID = getIntrinsicID(N);
-  switch (IID) {
-  default:
-    break;
-  case Intrinsic::get_active_lane_mask: {
-    SDValue Res = SDValue();
-    EVT VT = N->getValueType(0);
-    if (VT.isFixedLengthVector()) {
-      // We can use the SVE whilelo instruction to lower this intrinsic by
-      // creating the appropriate sequence of scalable vector operations and
-      // then extracting a fixed-width subvector from the scalable vector.
+  EVT VT = N->getValueType(0);
+  if (VT.isFixedLengthVector()) {
+    // We can use the SVE whilelo instruction to lower this intrinsic by
+    // creating the appropriate sequence of scalable vector operations and
+    // then extracting a fixed-width subvector from the scalable vector.
+    SDLoc DL(N);
+    SDValue ID =
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
 
-      SDLoc DL(N);
-      SDValue ID =
-          DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
+    EVT WhileVT =
+        EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                         ElementCount::getScalable(VT.getVectorNumElements()));
 
-      EVT WhileVT = EVT::getVectorVT(
-          *DAG.getContext(), MVT::i1,
-          ElementCount::getScalable(VT.getVectorNumElements()));
+    // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
+    EVT PromVT = getPromotedVTForPredicate(WhileVT);
 
-      // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
-      EVT PromVT = getPromotedVTForPredicate(WhileVT);
+    // Get the fixed-width equivalent of PromVT for extraction.
+    EVT ExtVT =
+        EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
+                         VT.getVectorElementCount());
 
-      // Get the fixed-width equivalent of PromVT for extraction.
-      EVT ExtVT =
-          EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
-                           VT.getVectorElementCount());
+    SDValue Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
+                              N->getOperand(1), N->getOperand(2));
+    Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
+    Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
+                      DAG.getConstant(0, DL, MVT::i64));
+    Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
 
-      Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
-                        N->getOperand(1), N->getOperand(2));
-      Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
-      Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
-                        DAG.getConstant(0, DL, MVT::i64));
-      Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
-    }
     return Res;
   }
+
+  if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
+    return SDValue();
+
+  if (!N->hasNUsesOfValue(2, 0))
+    return SDValue();
+
+  auto It = N->use_begin();
+  SDNode *Lo = *It++;
+  SDNode *Hi = *It;
+
+  const uint64_t HalfSize = VT.getVectorMinNumElements() / 2;
+  uint64_t OffLo, OffHi;
+  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Lo->getOperand(1).getNode(), OffLo) ||
+      (OffLo != 0 && OffLo != HalfSize) ||
+      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Hi->getOperand(1).getNode(), OffHi) ||
+      (OffHi != 0 && OffHi != HalfSize))
+    return SDValue();
+
+  if (OffLo > OffHi) {
+    std::swap(Lo, Hi);
+    std::swap(OffLo, OffHi);
+  }
+
+  if (OffLo != 0 || OffHi != HalfSize)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+  SDValue Idx = N->getOperand(1);
+  SDValue TC = N->getOperand(2);
+  if (Idx.getValueType() != MVT::i64) {
+    Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+    TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+  }
+  auto R =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
+                  {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+
+  DCI.CombineTo(Lo, R.getValue(0));
+  DCI.CombineTo(Hi, R.getValue(1));
+
+  return SDValue(N, 0);
+}
+
+static SDValue performIntrinsicCombine(SDNode *N,
+                                       TargetLowering::DAGCombinerInfo &DCI,
+                                       const AArch64Subtarget *Subtarget) {
+  SelectionDAG &DAG = DCI.DAG;
+  unsigned IID = getIntrinsicID(N);
+  switch (IID) {
+  default:
+    break;
+  case Intrinsic::get_active_lane_mask:
+    return tryCombineGetActiveLaneMask(N, DCI, Subtarget);
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/test/CodeGen/AArch64/active_lane_mask.ll b/llvm/test/CodeGen/AArch64/active_lane_mask.ll
index a65c5d66677946..6a509b5f3afcae 100644
--- a/llvm/test/CodeGen/AArch64/active_lane_mask.ll
+++ b/llvm/test/CodeGen/AArch64/active_lane_mask.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme < %s | FileCheck %s
 
 ; == Scalable ==
 
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
new file mode 100644
index 00000000000000..e32dec91d2ff40
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -0,0 +1,90 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mattr=+sve    < %s | FileCheck %s -check-prefix CHECK-SVE
+; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1
+; RUN: llc -mattr=+sme2   < %s | FileCheck %s -check-prefix CHECK-SME2
+target triple = "aarch64-linux"
+
+; Test combining of getActiveLaneMask with a pair of extract_vector operations.
+
+define void @f0(i32 %i, i32 %n, ptr %p0, ptr %p1) #0 {
+; CHECK-SVE-LABEL: f0:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p0.b, w0, w1
+; CHECK-SVE-NEXT:    ptrue p1.h
+; CHECK-SVE-NEXT:    punpklo p2.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p0.h, p0.b
+; CHECK-SVE-NEXT:    and p2.b, p2/z, p2.b, p1.b
+; CHECK-SVE-NEXT:    and p0.b, p0/z, p0.b, p1.b
+; CHECK-SVE-NEXT:    str p2, [x2]
+; CHECK-SVE-NEXT:    str p0, [x3]
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SVE2p1-LABEL: f0:
+; CHECK-SVE2p1:       // %bb.0:
+; CHECK-SVE2p1-NEXT:    mov w8, w1
+; CHECK-SVE2p1-NEXT:    mov w9, w0
+; CHECK-SVE2p1-NEXT:    whilelo { p0.h, p1.h }, x9, x8
+; CHECK-SVE2p1-NEXT:    str p0, [x2]
+; CHECK-SVE2p1-NEXT:    str p1, [x3]
+; CHECK-SVE2p1-NEXT:    ret
+;
+; CHECK-SME2-LABEL: f0:
+; CHECK-SME2:       // %bb.0:
+; CHECK-SME2-NEXT:    mov w8, w1
+; CHECK-SME2-NEXT:    mov w9, w0
+; CHECK-SME2-NEXT:    whilelo { p0.h, p1.h }, x9, x8
+; CHECK-SME2-NEXT:    str p0, [x2]
+; CHECK-SME2-NEXT:    str p1, [x3]
+; CHECK-SME2-NEXT:    ret
+    %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32 %i, i32 %n)
+    %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+    %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+    %pg0 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v0)
+    %pg1 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v1)
+    store <vscale x 16 x i1> %pg0, ptr %p0
+    store <vscale x 16 x i1> %pg1, ptr %p1
+    ret void
+}
+
+define void @f1(i64 %i, i64 %n, ptr %p0, ptr %p1) #0 {
+; CHECK-SVE-LABEL: f1:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p0.b, x0, x1
+; CHECK-SVE-NEXT:    ptrue p1.h
+; CHECK-SVE-NEXT:    punpklo p2.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p0.h, p0.b
+; CHECK-SVE-NEXT:    and p2.b, p2/z, p2.b, p1.b
+; CHECK-SVE-NEXT:    and p0.b, p0/z, p0.b, p1.b
+; CHECK-SVE-NEXT:    str p2, [x2]
+; CHECK-SVE-NEXT:    str p0, [x3]
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SVE2p1-LABEL: f1:
+; CHECK-SVE2p1:       // %bb.0:
+; CHECK-SVE2p1-NEXT:    whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SVE2p1-NEXT:    str p0, [x2]
+; CHECK-SVE2p1-NEXT:    str p1, [x3]
+; CHECK-SVE2p1-NEXT:    ret
+;
+; CHECK-SME2-LABEL: f1:
+; CHECK-SME2:       // %bb.0:
+; CHECK-SME2-NEXT:    whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SME2-NEXT:    str p0, [x2]
+; CHECK-SME2-NEXT:    str p1, [x3]
+; CHECK-SME2-NEXT:    ret
+    %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n)
+    %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+    %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+    %pg0 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v0)
+    %pg1 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v1)
+    store <vscale x 16 x i1> %pg0, ptr %p0
+    store <vscale x 16 x i1> %pg1, ptr %p1
+    ret void
+}
+
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
+declare <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1>, i64)
+declare <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32, i32)
+declare <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64, i64)
+
+attributes #0 = { nounwind }

``````````

</details>


https://github.com/llvm/llvm-project/pull/81139


More information about the llvm-commits mailing list