[llvm] 88da875 - [AArch64] Combine getActiveLaneMask with vector_extract (#81139)

via llvm-commits llvm-commits at lists.llvm.org
Fri May 10 12:01:49 PDT 2024


Author: Momchil Velikov
Date: 2024-05-10T20:01:42+01:00
New Revision: 88da8756a6f54a8ae441750eaa890c2450b94490

URL: https://github.com/llvm/llvm-project/commit/88da8756a6f54a8ae441750eaa890c2450b94490
DIFF: https://github.com/llvm/llvm-project/commit/88da8756a6f54a8ae441750eaa890c2450b94490.diff

LOG: [AArch64] Combine getActiveLaneMask with vector_extract (#81139)

... into a `whilelo` instruction with a pair of predicate registers.

Added: 
    llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 29b66a72d21df..4f2f170922ec1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20535,6 +20535,66 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
   return SDValue();
 }
 
+static SDValue tryCombineWhileLo(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 const AArch64Subtarget *Subtarget) {
+  if (DCI.isBeforeLegalize())
+    return SDValue();
+
+  if (!Subtarget->hasSVE2p1())
+    return SDValue();
+
+  if (!N->hasNUsesOfValue(2, 0))
+    return SDValue();
+
+  const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
+  if (HalfSize < 2)
+    return SDValue();
+
+  auto It = N->use_begin();
+  SDNode *Lo = *It++;
+  SDNode *Hi = *It;
+
+  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR)
+    return SDValue();
+
+  uint64_t OffLo = Lo->getConstantOperandVal(1);
+  uint64_t OffHi = Hi->getConstantOperandVal(1);
+
+  if (OffLo > OffHi) {
+    std::swap(Lo, Hi);
+    std::swap(OffLo, OffHi);
+  }
+
+  if (OffLo != 0 || OffHi != HalfSize)
+    return SDValue();
+
+  EVT HalfVec = Lo->getValueType(0);
+  if (HalfVec != Hi->getValueType(0) ||
+      HalfVec.getVectorElementCount() != ElementCount::getScalable(HalfSize))
+    return SDValue();
+
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc DL(N);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+  SDValue Idx = N->getOperand(1);
+  SDValue TC = N->getOperand(2);
+  if (Idx.getValueType() != MVT::i64) {
+    Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+    TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+  }
+  auto R =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
+                  {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+
+  DCI.CombineTo(Lo, R.getValue(0));
+  DCI.CombineTo(Hi, R.getValue(1));
+
+  return SDValue(N, 0);
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -20832,6 +20892,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_sve_ptest_last:
     return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
                     AArch64CC::LAST_ACTIVE);
+  case Intrinsic::aarch64_sve_whilelo:
+    return tryCombineWhileLo(N, DCI, Subtarget);
   }
   return SDValue();
 }

diff  --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
new file mode 100644
index 0000000000000..2d84a69f3144e
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -0,0 +1,156 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mattr=+sve    < %s | FileCheck %s -check-prefix CHECK-SVE
+; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1
+target triple = "aarch64-linux"
+
+; Test combining of llvm.get.active.lane.mask with a pair of llvm.vector.extract operations.
+
+define void @test_2x8bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p1.b, w0, w1
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-LABEL: test_2x8bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE2p1:       // %bb.0:
+; CHECK-SVE2p1-NEXT:    mov w8, w1
+; CHECK-SVE2p1-NEXT:    mov w9, w0
+; CHECK-SVE2p1-NEXT:    whilelo { p0.h, p1.h }, x9, x8
+; CHECK-SVE2p1-NEXT:    b use
+    %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32 %i, i32 %n)
+    %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+    %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+    tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
+    ret void
+}
+
+define void @test_2x8bit_mask_with_64bit_index_and_trip_count(i64 %i, i64 %n) #0 {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_64bit_index_and_trip_count:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p1.b, x0, x1
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-LABEL: test_2x8bit_mask_with_64bit_index_and_trip_count:
+; CHECK-SVE2p1:       // %bb.0:
+; CHECK-SVE2p1-NEXT:    whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SVE2p1-NEXT:    b use
+    %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n)
+    %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+    %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+    tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
+    ret void
+}
+
+define void @test_edge_case_2x1bit_mask(i64 %i, i64 %n) #0 {
+; CHECK-SVE-LABEL: test_edge_case_2x1bit_mask:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p1.d, x0, x1
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-LABEL: test_edge_case_2x1bit_mask:
+; CHECK-SVE2p1:       // %bb.0:
+; CHECK-SVE2p1-NEXT:    whilelo p1.d, x0, x1
+; CHECK-SVE2p1-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE2p1-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE2p1-NEXT:    b use
+    %r = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 %i, i64 %n)
+    %v0 = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1.i64(<vscale x 2 x i1> %r, i64 0)
+    %v1 = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1.i64(<vscale x 2 x i1> %r, i64 1)
+    tail call void @use(<vscale x 1 x i1> %v0, <vscale x 1 x i1> %v1)
+    ret void
+}
+
+define void @test_boring_case_2x2bit_mask(i64 %i, i64 %n) #0 {
+; CHECK-SVE-LABEL: test_boring_case_2x2bit_mask:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p1.s, x0, x1
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-LABEL: test_boring_case_2x2bit_mask:
+; CHECK-SVE2p1:       // %bb.0:
+; CHECK-SVE2p1-NEXT:    whilelo { p0.d, p1.d }, x0, x1
+; CHECK-SVE2p1-NEXT:    b use
+    %r = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 %i, i64 %n)
+    %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 0)
+    %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 2)
+    tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1)
+    ret void
+}
+
+; Negative test for when the two extracts are not exactly the low and high halves of the source vector.
+define void @test_partial_extract(i64 %i, i64 %n) #0 {
+; CHECK-SVE-LABEL: test_partial_extract:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE-NEXT:    punpklo p1.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p2.h, p0.b
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpklo p1.h, p2.b
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-LABEL: test_partial_extract:
+; CHECK-SVE2p1:       // %bb.0:
+; CHECK-SVE2p1-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE2p1-NEXT:    punpklo p1.h, p0.b
+; CHECK-SVE2p1-NEXT:    punpkhi p2.h, p0.b
+; CHECK-SVE2p1-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE2p1-NEXT:    punpklo p1.h, p2.b
+; CHECK-SVE2p1-NEXT:    b use
+    %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+    %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
+    %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
+    tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1)
+    ret void
+}
+
+; Negative test for when extracting a fixed-length vector.
+define void @test_fixed_extract(i64 %i, i64 %n) #0 {
+; CHECK-SVE-LABEL: test_fixed_extract:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE-NEXT:    cset w8, mi
+; CHECK-SVE-NEXT:    mov z0.h, p0/z, #1 // =0x1
+; CHECK-SVE-NEXT:    umov w9, v0.h[4]
+; CHECK-SVE-NEXT:    umov w10, v0.h[1]
+; CHECK-SVE-NEXT:    umov w11, v0.h[5]
+; CHECK-SVE-NEXT:    fmov s0, w8
+; CHECK-SVE-NEXT:    fmov s1, w9
+; CHECK-SVE-NEXT:    mov v0.s[1], w10
+; CHECK-SVE-NEXT:    // kill: def $d0 killed $d0 killed $q0
+; CHECK-SVE-NEXT:    mov v1.s[1], w11
+; CHECK-SVE-NEXT:    // kill: def $d1 killed $d1 killed $q1
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-LABEL: test_fixed_extract:
+; CHECK-SVE2p1:       // %bb.0:
+; CHECK-SVE2p1-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE2p1-NEXT:    cset w8, mi
+; CHECK-SVE2p1-NEXT:    mov z0.h, p0/z, #1 // =0x1
+; CHECK-SVE2p1-NEXT:    umov w9, v0.h[4]
+; CHECK-SVE2p1-NEXT:    umov w10, v0.h[1]
+; CHECK-SVE2p1-NEXT:    umov w11, v0.h[5]
+; CHECK-SVE2p1-NEXT:    fmov s0, w8
+; CHECK-SVE2p1-NEXT:    fmov s1, w9
+; CHECK-SVE2p1-NEXT:    mov v0.s[1], w10
+; CHECK-SVE2p1-NEXT:    // kill: def $d0 killed $d0 killed $q0
+; CHECK-SVE2p1-NEXT:    mov v1.s[1], w11
+; CHECK-SVE2p1-NEXT:    // kill: def $d1 killed $d1 killed $q1
+; CHECK-SVE2p1-NEXT:    b use
+    %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+    %v0 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
+    %v1 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
+    tail call void @use(<2 x i1> %v0, <2 x i1> %v1)
+    ret void
+}
+
+declare void @use(...)
+
+attributes #0 = { nounwind }


        


More information about the llvm-commits mailing list