[llvm] clastb representation in existing IR, and AArch64 codegen (PR #112738)

Graham Hunter via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 17 09:16:48 PDT 2024


https://github.com/huntergr-arm created https://github.com/llvm/llvm-project/pull/112738

These commits show a possible representation of SVE's `clastb` instruction using existing IR instructions and intrinsics, along with DAGCombines to emit the actual instruction.

Needing 9 IR instructions to represent a single `clastb` makes this pattern feel fragile: passes that run between LoopVectorize and codegen may transform it so that the DAGCombines no longer match. While we can sink the loop-invariant terms back into the right block in CGP, I do wonder whether we want a more direct intrinsic to represent this kind of operation.

Perhaps something like `llvm.vector.extract.last.active(data, mask)`?

This is something we would use to support the CSA vectorization in #106560 for SVE, though we would prefer to use clastb inside the vector loop instead of after it. That patch uses an int max reduction to determine the index instead of the cttz.elts-based approach in this PR, so we have another existing IR option to use if we want.


>From 3f781f1bc6fedb82a1fd19e7a6161adc723d5ab2 Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Tue, 15 Oct 2024 12:17:39 +0000
Subject: [PATCH 1/3] Initial clastb test

---
 llvm/test/CodeGen/AArch64/sve-clastb.ll | 169 ++++++++++++++++++++++++
 1 file changed, 169 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/sve-clastb.ll

diff --git a/llvm/test/CodeGen/AArch64/sve-clastb.ll b/llvm/test/CodeGen/AArch64/sve-clastb.ll
new file mode 100644
index 00000000000000..83b515c39cfa98
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-clastb.ll
@@ -0,0 +1,169 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
+; RUN: llc -mtriple=aarch64-linux-unknown -mattr=+sve -o - < %s | FileCheck %s
+
+define i8 @clastb_i8(<vscale x 16 x i8> %data, <vscale x 16 x i1> %pg, i8 %existing) {
+; CHECK-LABEL: clastb_i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.b
+; CHECK-NEXT:    rdvl x9, #1
+; CHECK-NEXT:    rev p2.b, p0.b
+; CHECK-NEXT:    brkb p1.b, p1/z, p2.b
+; CHECK-NEXT:    cntp x8, p1, p1.b
+; CHECK-NEXT:    mvn w8, w8
+; CHECK-NEXT:    add w8, w8, w9
+; CHECK-NEXT:    whilels p1.b, xzr, x8
+; CHECK-NEXT:    ptest p0, p0.b
+; CHECK-NEXT:    lastb w8, p1, z0.b
+; CHECK-NEXT:    csel w0, w8, w0, ne
+; CHECK-NEXT:    ret
+  %rev.pg = call <vscale x 16 x i1> @llvm.vector.reverse.nxv16i1(<vscale x 16 x i1> %pg)
+  %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv16i1(<vscale x 16 x i1> %rev.pg, i1 false)
+  ; Reversing %pg makes %tz.cnt the number of inactive lanes after the last
+  ; active lane, so %idx = (vscale * 16) - %tz.cnt - 1 indexes that lane.
+  ; When no lane is active, %any.set is false and %existing is returned.
+  %any.set = call i1 @llvm.vector.reduce.or.nxv16i1(<vscale x 16 x i1> %pg)
+  %vscale = call i32 @llvm.vscale.i32()
+  %size = shl i32 %vscale, 4
+  %sub = sub i32 %size, %tz.cnt
+  %idx = sub i32 %sub, 1
+  %extr = extractelement <vscale x 16 x i8> %data, i32 %idx
+  %res = select i1 %any.set, i8 %extr, i8 %existing
+  ret i8 %res
+}
+
+define i16 @clastb_i16(<vscale x 8 x i16> %data, <vscale x 8 x i1> %pg, i16 %existing) {
+; CHECK-LABEL: clastb_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    cnth x9
+; CHECK-NEXT:    rev p2.h, p0.h
+; CHECK-NEXT:    brkb p1.b, p1/z, p2.b
+; CHECK-NEXT:    cntp x8, p1, p1.h
+; CHECK-NEXT:    mvn w8, w8
+; CHECK-NEXT:    add w8, w8, w9
+; CHECK-NEXT:    whilels p1.h, xzr, x8
+; CHECK-NEXT:    lastb w8, p1, z0.h
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    ptest p1, p0.b
+; CHECK-NEXT:    csel w0, w8, w0, ne
+; CHECK-NEXT:    ret
+  %rev.pg = call <vscale x 8 x i1> @llvm.vector.reverse.nxv8i1(<vscale x 8 x i1> %pg)
+  %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv8i1(<vscale x 8 x i1> %rev.pg, i1 false)
+  ; Reversing %pg makes %tz.cnt the number of inactive lanes after the last
+  ; active lane, so %idx = (vscale * 8) - %tz.cnt - 1 indexes that lane.
+  ; When no lane is active, %any.set is false and %existing is returned.
+  %any.set = call i1 @llvm.vector.reduce.or.nxv8i1(<vscale x 8 x i1> %pg)
+  %vscale = call i32 @llvm.vscale.i32()
+  %size = shl i32 %vscale, 3
+  %sub = sub i32 %size, %tz.cnt
+  %idx = sub i32 %sub, 1
+  %extr = extractelement <vscale x 8 x i16> %data, i32 %idx
+  %res = select i1 %any.set, i16 %extr, i16 %existing
+  ret i16 %res
+}
+
+define i32 @clastb_i32(<vscale x 4 x i32> %data, <vscale x 4 x i1> %pg, i32 %existing) {
+; CHECK-LABEL: clastb_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    cntw x9
+; CHECK-NEXT:    rev p2.s, p0.s
+; CHECK-NEXT:    brkb p1.b, p1/z, p2.b
+; CHECK-NEXT:    cntp x8, p1, p1.s
+; CHECK-NEXT:    mvn w8, w8
+; CHECK-NEXT:    add w8, w8, w9
+; CHECK-NEXT:    whilels p1.s, xzr, x8
+; CHECK-NEXT:    lastb w8, p1, z0.s
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    ptest p1, p0.b
+; CHECK-NEXT:    csel w0, w8, w0, ne
+; CHECK-NEXT:    ret
+  %rev.pg = call <vscale x 4 x i1> @llvm.vector.reverse.nxv4i1(<vscale x 4 x i1> %pg)
+  %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv4i1(<vscale x 4 x i1> %rev.pg, i1 false)
+  ; Reversing %pg makes %tz.cnt the number of inactive lanes after the last
+  ; active lane, so %idx = (vscale * 4) - %tz.cnt - 1 indexes that lane.
+  ; When no lane is active, %any.set is false and %existing is returned.
+  %any.set = call i1 @llvm.vector.reduce.or.nxv4i1(<vscale x 4 x i1> %pg)
+  %vscale = call i32 @llvm.vscale.i32()
+  %size = shl i32 %vscale, 2
+  %sub = sub i32 %size, %tz.cnt
+  %idx = sub i32 %sub, 1
+  %extr = extractelement <vscale x 4 x i32> %data, i32 %idx
+  %res = select i1 %any.set, i32 %extr, i32 %existing
+  ret i32 %res
+}
+
+define i64 @clastb_i64(<vscale x 2 x i64> %data, <vscale x 2 x i1> %pg, i64 %existing) {
+; CHECK-LABEL: clastb_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    rev p2.d, p0.d
+; CHECK-NEXT:    brkb p1.b, p1/z, p2.b
+; CHECK-NEXT:    cntp x8, p1, p1.d
+; CHECK-NEXT:    mvn w8, w8
+; CHECK-NEXT:    add w8, w8, w9
+; CHECK-NEXT:    whilels p1.d, xzr, x8
+; CHECK-NEXT:    lastb x8, p1, z0.d
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    ptest p1, p0.b
+; CHECK-NEXT:    csel x0, x8, x0, ne
+; CHECK-NEXT:    ret
+  %rev.pg = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %pg)
+  %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv2i1(<vscale x 2 x i1> %rev.pg, i1 false)
+  ; Reversing %pg makes %tz.cnt the number of inactive lanes after the last
+  ; active lane, so %idx = (vscale * 2) - %tz.cnt - 1 indexes that lane.
+  ; When no lane is active, %any.set is false and %existing is returned.
+  %any.set = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> %pg)
+  %vscale = call i32 @llvm.vscale.i32()
+  %size = shl i32 %vscale, 1
+  %sub = sub i32 %size, %tz.cnt
+  %idx = sub i32 %sub, 1
+  %extr = extractelement <vscale x 2 x i64> %data, i32 %idx
+  %res = select i1 %any.set, i64 %extr, i64 %existing
+  ret i64 %res
+}
+
+define float @clastb_float(<vscale x 4 x float> %data, <vscale x 4 x i1> %pg, float %existing) {
+; CHECK-LABEL: clastb_float:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    cntw x9
+; CHECK-NEXT:    rev p2.s, p0.s
+; CHECK-NEXT:    brkb p1.b, p1/z, p2.b
+; CHECK-NEXT:    cntp x8, p1, p1.s
+; CHECK-NEXT:    mvn w8, w8
+; CHECK-NEXT:    add w8, w8, w9
+; CHECK-NEXT:    whilels p1.s, xzr, x8
+; CHECK-NEXT:    lastb s0, p1, z0.s
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    ptest p1, p0.b
+; CHECK-NEXT:    fcsel s0, s0, s1, ne
+; CHECK-NEXT:    ret
+  %rev.pg = call <vscale x 4 x i1> @llvm.vector.reverse.nxv4i1(<vscale x 4 x i1> %pg)
+  %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv4i1(<vscale x 4 x i1> %rev.pg, i1 false)
+  ; Reversing %pg makes %tz.cnt the number of inactive lanes after the last
+  ; active lane, so %idx = (vscale * 4) - %tz.cnt - 1 indexes that lane.
+  ; When no lane is active, %any.set is false and %existing is returned.
+  %any.set = call i1 @llvm.vector.reduce.or.nxv4i1(<vscale x 4 x i1> %pg)
+  %vscale = call i32 @llvm.vscale.i32()
+  %size = shl i32 %vscale, 2
+  %sub = sub i32 %size, %tz.cnt
+  %idx = sub i32 %sub, 1
+  %extr = extractelement <vscale x 4 x float> %data, i32 %idx
+  %res = select i1 %any.set, float %extr, float %existing
+  ret float %res
+}
+
+define double @clastb_double(<vscale x 2 x double> %data, <vscale x 2 x i1> %pg, double %existing) {
+; CHECK-LABEL: clastb_double:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    rev p2.d, p0.d
+; CHECK-NEXT:    brkb p1.b, p1/z, p2.b
+; CHECK-NEXT:    cntp x8, p1, p1.d
+; CHECK-NEXT:    mvn w8, w8
+; CHECK-NEXT:    add w8, w8, w9
+; CHECK-NEXT:    whilels p1.d, xzr, x8
+; CHECK-NEXT:    lastb d0, p1, z0.d
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    ptest p1, p0.b
+; CHECK-NEXT:    fcsel d0, d0, d1, ne
+; CHECK-NEXT:    ret
+  %rev.pg = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %pg)
+  %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv2i1(<vscale x 2 x i1> %rev.pg, i1 false)
+  ; Reversing %pg makes %tz.cnt the number of inactive lanes after the last
+  ; active lane, so %idx = (vscale * 2) - %tz.cnt - 1 indexes that lane.
+  ; When no lane is active, %any.set is false and %existing is returned.
+  %any.set = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> %pg)
+  %vscale = call i32 @llvm.vscale.i32()
+  %size = shl i32 %vscale, 1
+  %sub = sub i32 %size, %tz.cnt
+  %idx = sub i32 %sub, 1
+  %extr = extractelement <vscale x 2 x double> %data, i32 %idx
+  %res = select i1 %any.set, double %extr, double %existing
+  ret double %res
+}

>From 30222c618711b1a46cf9697085bbed804b1a2684 Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Tue, 15 Oct 2024 12:17:57 +0000
Subject: [PATCH 2/3] DAGCombine for lastb

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 53 ++++++++++++++++
 llvm/test/CodeGen/AArch64/sve-clastb.ll       | 60 ++-----------------
 2 files changed, 59 insertions(+), 54 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b5657584016ea6..ad15512de54aa4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -19512,6 +19512,57 @@ performLastTrueTestVectorCombine(SDNode *N,
   return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::LAST_ACTIVE);
 }
 
+/// Fold an extract of the element at the last active lane of a predicate
+/// into AArch64ISD::LASTB.
+///
+/// Matches an EXTRACT_VECTOR_ELT whose index is
+///   (add (xor (cttz_elts (vector_reverse Pg)), -1), (vscale NElts))
+/// with optional zero_extend/truncate nodes around the terms. Because
+/// (xor X, -1) == -X - 1, the index is NElts * vscale - cttz - 1, i.e. the
+/// position of the last active lane of Pg.
+/// Returns the new LASTB node, or an empty SDValue if nothing matched.
+static SDValue
+performLastActiveExtractEltCombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI,
+                                   const AArch64Subtarget *Subtarget) {
+  SDValue Vec = N->getOperand(0);
+  EVT VecVT = Vec.getValueType();
+  // The index is built from vscale, so it can only describe the full length
+  // of a scalable vector.
+  if (!VecVT.isScalableVector())
+    return SDValue();
+  ElementCount VecEC = VecVT.getVectorElementCount();
+
+  SDValue Index = N->getOperand(1);
+  // FIXME: Make this more generic. Should be a utility func somewhere?
+  if (Index.getOpcode() == ISD::ZERO_EXTEND)
+    Index = Index.getOperand(0);
+
+  // Looking for an add of an inverted value.
+  if (Index.getOpcode() != ISD::ADD)
+    return SDValue();
+
+  SDValue Size = Index.getOperand(1);
+  if (Size.getOpcode() == ISD::TRUNCATE)
+    Size = Size.getOperand(0);
+
+  // Check that we're looking at the size of the overall vector...
+  // FIXME: What about VSL codegen?
+  if (Size.getOpcode() != ISD::VSCALE ||
+      Size.getConstantOperandVal(0) != VecEC.getKnownMinValue())
+    return SDValue();
+
+  // Match (xor X, -1). isAllOnesConstant also rejects a non-constant RHS,
+  // which getConstantOperandAPInt would feed to cast<ConstantSDNode>.
+  SDValue Invert = Index.getOperand(0);
+  if (Invert.getOpcode() != ISD::XOR ||
+      !isAllOnesConstant(Invert.getOperand(1)))
+    return SDValue();
+
+  SDValue LZeroes = Invert.getOperand(0);
+  if (LZeroes.getOpcode() == ISD::TRUNCATE)
+    LZeroes = LZeroes.getOperand(0);
+
+  // Check that we're looking at a cttz.elts from a reversed predicate...
+  if (LZeroes.getOpcode() != AArch64ISD::CTTZ_ELTS)
+    return SDValue();
+
+  SDValue Pred = LZeroes.getOperand(0);
+  if (Pred.getOpcode() != ISD::VECTOR_REVERSE)
+    return SDValue();
+
+  // The predicate must have exactly one lane per data element. Otherwise the
+  // trailing-zero count is not measured in Vec's lanes, and LASTB would be
+  // built with mismatched predicate/data types.
+  SDValue Mask = Pred.getOperand(0);
+  if (Mask.getValueType().getVectorElementCount() != VecEC)
+    return SDValue();
+
+  // Matched a LASTB pattern.
+  return DCI.DAG.getNode(AArch64ISD::LASTB, SDLoc(N), N->getValueType(0),
+                         Mask, Vec);
+}
+
 static SDValue
 performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                const AArch64Subtarget *Subtarget) {
@@ -19520,6 +19571,8 @@ performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
     return Res;
   if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
     return Res;
+  if (SDValue Res = performLastActiveExtractEltCombine(N, DCI, Subtarget))
+    return Res;
 
   SelectionDAG &DAG = DCI.DAG;
   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
diff --git a/llvm/test/CodeGen/AArch64/sve-clastb.ll b/llvm/test/CodeGen/AArch64/sve-clastb.ll
index 83b515c39cfa98..cb09780c697c2e 100644
--- a/llvm/test/CodeGen/AArch64/sve-clastb.ll
+++ b/llvm/test/CodeGen/AArch64/sve-clastb.ll
@@ -4,16 +4,8 @@
 define i8 @clastb_i8(<vscale x 16 x i8> %data, <vscale x 16 x i1> %pg, i8 %existing) {
 ; CHECK-LABEL: clastb_i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.b
-; CHECK-NEXT:    rdvl x9, #1
-; CHECK-NEXT:    rev p2.b, p0.b
-; CHECK-NEXT:    brkb p1.b, p1/z, p2.b
-; CHECK-NEXT:    cntp x8, p1, p1.b
-; CHECK-NEXT:    mvn w8, w8
-; CHECK-NEXT:    add w8, w8, w9
-; CHECK-NEXT:    whilels p1.b, xzr, x8
+; CHECK-NEXT:    lastb w8, p0, z0.b
 ; CHECK-NEXT:    ptest p0, p0.b
-; CHECK-NEXT:    lastb w8, p1, z0.b
 ; CHECK-NEXT:    csel w0, w8, w0, ne
 ; CHECK-NEXT:    ret
   %rev.pg = call <vscale x 16 x i1> @llvm.vector.reverse.nxv16i1(<vscale x 16 x i1> %pg)
@@ -31,15 +23,7 @@ define i8 @clastb_i8(<vscale x 16 x i8> %data, <vscale x 16 x i1> %pg, i8 %exist
 define i16 @clastb_i16(<vscale x 8 x i16> %data, <vscale x 8 x i1> %pg, i16 %existing) {
 ; CHECK-LABEL: clastb_i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.h
-; CHECK-NEXT:    cnth x9
-; CHECK-NEXT:    rev p2.h, p0.h
-; CHECK-NEXT:    brkb p1.b, p1/z, p2.b
-; CHECK-NEXT:    cntp x8, p1, p1.h
-; CHECK-NEXT:    mvn w8, w8
-; CHECK-NEXT:    add w8, w8, w9
-; CHECK-NEXT:    whilels p1.h, xzr, x8
-; CHECK-NEXT:    lastb w8, p1, z0.h
+; CHECK-NEXT:    lastb w8, p0, z0.h
 ; CHECK-NEXT:    ptrue p1.h
 ; CHECK-NEXT:    ptest p1, p0.b
 ; CHECK-NEXT:    csel w0, w8, w0, ne
@@ -59,15 +43,7 @@ define i16 @clastb_i16(<vscale x 8 x i16> %data, <vscale x 8 x i1> %pg, i16 %exi
 define i32 @clastb_i32(<vscale x 4 x i32> %data, <vscale x 4 x i1> %pg, i32 %existing) {
 ; CHECK-LABEL: clastb_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.s
-; CHECK-NEXT:    cntw x9
-; CHECK-NEXT:    rev p2.s, p0.s
-; CHECK-NEXT:    brkb p1.b, p1/z, p2.b
-; CHECK-NEXT:    cntp x8, p1, p1.s
-; CHECK-NEXT:    mvn w8, w8
-; CHECK-NEXT:    add w8, w8, w9
-; CHECK-NEXT:    whilels p1.s, xzr, x8
-; CHECK-NEXT:    lastb w8, p1, z0.s
+; CHECK-NEXT:    lastb w8, p0, z0.s
 ; CHECK-NEXT:    ptrue p1.s
 ; CHECK-NEXT:    ptest p1, p0.b
 ; CHECK-NEXT:    csel w0, w8, w0, ne
@@ -87,15 +63,7 @@ define i32 @clastb_i32(<vscale x 4 x i32> %data, <vscale x 4 x i1> %pg, i32 %exi
 define i64 @clastb_i64(<vscale x 2 x i64> %data, <vscale x 2 x i1> %pg, i64 %existing) {
 ; CHECK-LABEL: clastb_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.d
-; CHECK-NEXT:    cntd x9
-; CHECK-NEXT:    rev p2.d, p0.d
-; CHECK-NEXT:    brkb p1.b, p1/z, p2.b
-; CHECK-NEXT:    cntp x8, p1, p1.d
-; CHECK-NEXT:    mvn w8, w8
-; CHECK-NEXT:    add w8, w8, w9
-; CHECK-NEXT:    whilels p1.d, xzr, x8
-; CHECK-NEXT:    lastb x8, p1, z0.d
+; CHECK-NEXT:    lastb x8, p0, z0.d
 ; CHECK-NEXT:    ptrue p1.d
 ; CHECK-NEXT:    ptest p1, p0.b
 ; CHECK-NEXT:    csel x0, x8, x0, ne
@@ -115,15 +83,7 @@ define i64 @clastb_i64(<vscale x 2 x i64> %data, <vscale x 2 x i1> %pg, i64 %exi
 define float @clastb_float(<vscale x 4 x float> %data, <vscale x 4 x i1> %pg, float %existing) {
 ; CHECK-LABEL: clastb_float:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.s
-; CHECK-NEXT:    cntw x9
-; CHECK-NEXT:    rev p2.s, p0.s
-; CHECK-NEXT:    brkb p1.b, p1/z, p2.b
-; CHECK-NEXT:    cntp x8, p1, p1.s
-; CHECK-NEXT:    mvn w8, w8
-; CHECK-NEXT:    add w8, w8, w9
-; CHECK-NEXT:    whilels p1.s, xzr, x8
-; CHECK-NEXT:    lastb s0, p1, z0.s
+; CHECK-NEXT:    lastb s0, p0, z0.s
 ; CHECK-NEXT:    ptrue p1.s
 ; CHECK-NEXT:    ptest p1, p0.b
 ; CHECK-NEXT:    fcsel s0, s0, s1, ne
@@ -143,15 +103,7 @@ define float @clastb_float(<vscale x 4 x float> %data, <vscale x 4 x i1> %pg, fl
 define double @clastb_double(<vscale x 2 x double> %data, <vscale x 2 x i1> %pg, double %existing) {
 ; CHECK-LABEL: clastb_double:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.d
-; CHECK-NEXT:    cntd x9
-; CHECK-NEXT:    rev p2.d, p0.d
-; CHECK-NEXT:    brkb p1.b, p1/z, p2.b
-; CHECK-NEXT:    cntp x8, p1, p1.d
-; CHECK-NEXT:    mvn w8, w8
-; CHECK-NEXT:    add w8, w8, w9
-; CHECK-NEXT:    whilels p1.d, xzr, x8
-; CHECK-NEXT:    lastb d0, p1, z0.d
+; CHECK-NEXT:    lastb d0, p0, z0.d
 ; CHECK-NEXT:    ptrue p1.d
 ; CHECK-NEXT:    ptest p1, p0.b
 ; CHECK-NEXT:    fcsel d0, d0, d1, ne

>From b80e292d1e08b63abdadd645431649c1912c293e Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Tue, 15 Oct 2024 14:26:42 +0000
Subject: [PATCH 3/3] DAGCombine lastb + csel into clastb

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 47 +++++++++++++++++++
 llvm/test/CodeGen/AArch64/sve-clastb.ll       | 33 ++++---------
 2 files changed, 55 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ad15512de54aa4..ccd89af6eb8770 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24416,6 +24416,50 @@ static SDValue foldCSELOfCSEL(SDNode *Op, SelectionDAG &DAG) {
   return DAG.getNode(AArch64ISD::CSEL, DL, VT, L, R, CCValue, Cond);
 }
 
+/// Fold (csel (lastb Pg, Vec), Fallback, ne, (ptest_any PTestPg, Pg)) into
+/// (clastb_n Pg, Fallback, Vec).
+///
+/// LASTB yields the element at the last active lane of Pg. The CSEL keeps it
+/// only when the PTEST reports an active lane in Pg, and otherwise yields
+/// Fallback, which is exactly what CLASTB_N computes. This is only valid when
+/// the PTEST's governing predicate covers every lane of Pg: either Pg itself
+/// or an all-lanes PTRUE of the same predicate type.
+/// Returns the new node, or an empty SDValue if the pattern does not match.
+static SDValue foldCSELOfLASTB(SDNode *N, SelectionDAG &DAG) {
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  AArch64CC::CondCode CC =
+      static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));
+  SDValue PTAny = N->getOperand(3);
+
+  // FIXME: Handle the inverse (LASTB as the false operand with EQ)?
+  if (CC != AArch64CC::NE || Op0.getOpcode() != AArch64ISD::LASTB)
+    return SDValue();
+
+  if (PTAny.getOpcode() != AArch64ISD::PTEST_ANY)
+    return SDValue();
+
+  // Predicate LASTB uses to select the last active lane.
+  SDValue LBPred = Op0.getOperand(0);
+
+  // The PTEST operands may be wrapped in reinterpret casts; look through
+  // them to compare with the original predicates.
+  SDValue PTestPG = PTAny.getOperand(0);
+  if (PTestPG.getOpcode() == AArch64ISD::REINTERPRET_CAST)
+    PTestPG = PTestPG.getOperand(0);
+
+  SDValue PTestOp = PTAny.getOperand(1);
+  if (PTestOp.getOpcode() == AArch64ISD::REINTERPRET_CAST)
+    PTestOp = PTestOp.getOperand(0);
+
+  // The flags must come from testing the same predicate LASTB uses.
+  if (PTestOp != LBPred)
+    return SDValue();
+
+  // The governing predicate must cover every lane of LBPred. A PTRUE with a
+  // restricted pattern (e.g. VL1), or of a different predicate type, only
+  // tests a subset of the lanes, whereas CLASTB considers all of them.
+  bool IsAllActivePTrue =
+      PTestPG.getOpcode() == AArch64ISD::PTRUE &&
+      PTestPG.getValueType() == LBPred.getValueType() &&
+      PTestPG.getConstantOperandVal(0) == AArch64SVEPredPattern::all;
+  if (PTestPG != LBPred && !IsAllActivePTrue)
+    return SDValue();
+
+  return DAG.getNode(AArch64ISD::CLASTB_N, SDLoc(N), N->getValueType(0),
+                     LBPred, Op1, Op0.getOperand(1));
+}
+
 // Optimize CSEL instructions
 static SDValue performCSELCombine(SDNode *N,
                                   TargetLowering::DAGCombinerInfo &DCI,
@@ -24432,6 +24476,9 @@ static SDValue performCSELCombine(SDNode *N,
   if (SDValue Folded = foldCSELofCTTZ(N, DAG))
 		return Folded;
 
+  if (SDValue CLastB = foldCSELOfLASTB(N, DAG))
+    return CLastB;
+
   return performCONDCombine(N, DCI, DAG, 2, 3);
 }
 
diff --git a/llvm/test/CodeGen/AArch64/sve-clastb.ll b/llvm/test/CodeGen/AArch64/sve-clastb.ll
index cb09780c697c2e..e2ff1e478f6b87 100644
--- a/llvm/test/CodeGen/AArch64/sve-clastb.ll
+++ b/llvm/test/CodeGen/AArch64/sve-clastb.ll
@@ -4,9 +4,7 @@
 define i8 @clastb_i8(<vscale x 16 x i8> %data, <vscale x 16 x i1> %pg, i8 %existing) {
 ; CHECK-LABEL: clastb_i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    lastb w8, p0, z0.b
-; CHECK-NEXT:    ptest p0, p0.b
-; CHECK-NEXT:    csel w0, w8, w0, ne
+; CHECK-NEXT:    clastb w0, p0, w0, z0.b
 ; CHECK-NEXT:    ret
   %rev.pg = call <vscale x 16 x i1> @llvm.vector.reverse.nxv16i1(<vscale x 16 x i1> %pg)
   %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv16i1(<vscale x 16 x i1> %rev.pg, i1 false)
@@ -23,10 +21,7 @@ define i8 @clastb_i8(<vscale x 16 x i8> %data, <vscale x 16 x i1> %pg, i8 %exist
 define i16 @clastb_i16(<vscale x 8 x i16> %data, <vscale x 8 x i1> %pg, i16 %existing) {
 ; CHECK-LABEL: clastb_i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    lastb w8, p0, z0.h
-; CHECK-NEXT:    ptrue p1.h
-; CHECK-NEXT:    ptest p1, p0.b
-; CHECK-NEXT:    csel w0, w8, w0, ne
+; CHECK-NEXT:    clastb w0, p0, w0, z0.h
 ; CHECK-NEXT:    ret
   %rev.pg = call <vscale x 8 x i1> @llvm.vector.reverse.nxv8i1(<vscale x 8 x i1> %pg)
   %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv8i1(<vscale x 8 x i1> %rev.pg, i1 false)
@@ -43,10 +38,7 @@ define i16 @clastb_i16(<vscale x 8 x i16> %data, <vscale x 8 x i1> %pg, i16 %exi
 define i32 @clastb_i32(<vscale x 4 x i32> %data, <vscale x 4 x i1> %pg, i32 %existing) {
 ; CHECK-LABEL: clastb_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    lastb w8, p0, z0.s
-; CHECK-NEXT:    ptrue p1.s
-; CHECK-NEXT:    ptest p1, p0.b
-; CHECK-NEXT:    csel w0, w8, w0, ne
+; CHECK-NEXT:    clastb w0, p0, w0, z0.s
 ; CHECK-NEXT:    ret
   %rev.pg = call <vscale x 4 x i1> @llvm.vector.reverse.nxv4i1(<vscale x 4 x i1> %pg)
   %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv4i1(<vscale x 4 x i1> %rev.pg, i1 false)
@@ -63,10 +55,7 @@ define i32 @clastb_i32(<vscale x 4 x i32> %data, <vscale x 4 x i1> %pg, i32 %exi
 define i64 @clastb_i64(<vscale x 2 x i64> %data, <vscale x 2 x i1> %pg, i64 %existing) {
 ; CHECK-LABEL: clastb_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    lastb x8, p0, z0.d
-; CHECK-NEXT:    ptrue p1.d
-; CHECK-NEXT:    ptest p1, p0.b
-; CHECK-NEXT:    csel x0, x8, x0, ne
+; CHECK-NEXT:    clastb x0, p0, x0, z0.d
 ; CHECK-NEXT:    ret
   %rev.pg = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %pg)
   %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv2i1(<vscale x 2 x i1> %rev.pg, i1 false)
@@ -80,13 +69,10 @@ define i64 @clastb_i64(<vscale x 2 x i64> %data, <vscale x 2 x i1> %pg, i64 %exi
   ret i64 %res
 }
 
-define float @clastb_float(<vscale x 4 x float> %data, <vscale x 4 x i1> %pg, float %existing) {
+define float @clastb_float(float %existing, <vscale x 4 x float> %data, <vscale x 4 x i1> %pg) {
 ; CHECK-LABEL: clastb_float:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    lastb s0, p0, z0.s
-; CHECK-NEXT:    ptrue p1.s
-; CHECK-NEXT:    ptest p1, p0.b
-; CHECK-NEXT:    fcsel s0, s0, s1, ne
+; CHECK-NEXT:    clastb s0, p0, s0, z1.s
 ; CHECK-NEXT:    ret
   %rev.pg = call <vscale x 4 x i1> @llvm.vector.reverse.nxv4i1(<vscale x 4 x i1> %pg)
   %tz.cnt = call i32 @llvm.experimental.cttz.elts.float.nxv4i1(<vscale x 4 x i1> %rev.pg, i1 false)
@@ -100,13 +86,10 @@ define float @clastb_float(<vscale x 4 x float> %data, <vscale x 4 x i1> %pg, fl
   ret float %res
 }
 
-define double @clastb_double(<vscale x 2 x double> %data, <vscale x 2 x i1> %pg, double %existing) {
+define double @clastb_double(double %existing, <vscale x 2 x double> %data, <vscale x 2 x i1> %pg) {
 ; CHECK-LABEL: clastb_double:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    lastb d0, p0, z0.d
-; CHECK-NEXT:    ptrue p1.d
-; CHECK-NEXT:    ptest p1, p0.b
-; CHECK-NEXT:    fcsel d0, d0, d1, ne
+; CHECK-NEXT:    clastb d0, p0, d0, z1.d
 ; CHECK-NEXT:    ret
   %rev.pg = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %pg)
   %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv2i1(<vscale x 2 x i1> %rev.pg, i1 false)



More information about the llvm-commits mailing list