[llvm] clastb representation in existing IR, and AArch64 codegen (PR #112738)

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 17 09:17:22 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Graham Hunter (huntergr-arm)

<details>
<summary>Changes</summary>

These commits show a possible representation of SVE's `clastb` instruction using existing IR instructions and intrinsics, along with DAGCombines to emit the actual instruction.

At 9 IR instructions to represent a single `clastb`, this pattern feels fragile: passes that run between LoopVectorize and codegen may transform it so that it no longer matches. While we can sink the loop-invariant terms back into the right block in CGP, I wonder whether we want a more direct intrinsic to represent this kind of operation.

Perhaps something like `llvm.vector.extract.last.active(data, mask)`?

This is something we would use to support the CSA vectorization in #106560 for SVE, though we would prefer to use `clastb` inside the vector loop rather than after it. That patch uses an integer max reduction to determine the index instead of the `cttz.elts`-based approach in this PR, so there is another existing-IR option if we want it.


---
Full diff: https://github.com/llvm/llvm-project/pull/112738.diff


2 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+100) 
- (added) llvm/test/CodeGen/AArch64/sve-clastb.ll (+104) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b5657584016ea6..ccd89af6eb8770 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -19512,6 +19512,83 @@ performLastTrueTestVectorCombine(SDNode *N,
   return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::LAST_ACTIVE);
 }
 
+/// Fold the "extract the last active element" idiom into an SVE LASTB:
+///
+///   extract_vector_elt Data,
+///       (add (xor (cttz_elts (vector_reverse Pg)), -1), (vscale NElts))
+///
+/// The index evaluates to NElts - 1 - cttz.elts(reverse(Pg)), i.e. the
+/// position of the last set lane of Pg, which is the lane LASTB reads.
+/// Optional zero_extend/truncate nodes on the index and its terms are looked
+/// through.
+static SDValue
+performLastActiveExtractEltCombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI,
+                                   const AArch64Subtarget *Subtarget) {
+  SDValue Data = N->getOperand(0);
+  EVT DataVT = Data.getValueType();
+  // LASTB operates on scalable SVE vectors only.
+  if (!DataVT.isScalableVector())
+    return SDValue();
+
+  SDValue Index = N->getOperand(1);
+  // FIXME: Make this more generic. Should be a utility func somewhere?
+  if (Index.getOpcode() == ISD::ZERO_EXTEND)
+    Index = Index.getOperand(0);
+
+  // Looking for an add of an inverted value. ADD is commutative, so accept
+  // the XOR as either operand.
+  if (Index.getOpcode() != ISD::ADD)
+    return SDValue();
+
+  SDValue Invert = Index.getOperand(0);
+  SDValue Size = Index.getOperand(1);
+  if (Invert.getOpcode() != ISD::XOR)
+    std::swap(Invert, Size);
+
+  if (Size.getOpcode() == ISD::TRUNCATE)
+    Size = Size.getOperand(0);
+
+  // Check that we're looking at the size of the overall vector...
+  // FIXME: What about VSL codegen?
+  if (Size.getOpcode() != ISD::VSCALE)
+    return SDValue();
+
+  unsigned NElts = DataVT.getVectorElementCount().getKnownMinValue();
+  if (Size.getConstantOperandVal(0) != NElts)
+    return SDValue();
+
+  // The XOR must be a bitwise NOT. Use isAllOnesConstant because
+  // getConstantOperandAPInt asserts when the operand is not a constant.
+  if (Invert.getOpcode() != ISD::XOR ||
+      !isAllOnesConstant(Invert.getOperand(1)))
+    return SDValue();
+
+  SDValue LZeroes = Invert.getOperand(0);
+  if (LZeroes.getOpcode() == ISD::TRUNCATE)
+    LZeroes = LZeroes.getOperand(0);
+
+  // Check that we're looking at a cttz.elts from a reversed predicate...
+  if (LZeroes.getOpcode() != AArch64ISD::CTTZ_ELTS)
+    return SDValue();
+
+  SDValue Pred = LZeroes.getOperand(0);
+  if (Pred.getOpcode() != ISD::VECTOR_REVERSE)
+    return SDValue();
+
+  // The predicate must have exactly one lane per data element; otherwise the
+  // index is not the last active data lane and LASTB's operands would have
+  // mismatched element counts.
+  SDValue Pg = Pred.getOperand(0);
+  if (Pg.getValueType().getVectorElementCount() !=
+      DataVT.getVectorElementCount())
+    return SDValue();
+
+  // Matched a LASTB pattern.
+  return DCI.DAG.getNode(AArch64ISD::LASTB, SDLoc(N), N->getValueType(0), Pg,
+                         Data);
+}
+
 static SDValue
 performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                const AArch64Subtarget *Subtarget) {
@@ -19520,6 +19597,8 @@ performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
     return Res;
   if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
     return Res;
+  if (SDValue Res = performLastActiveExtractEltCombine(N, DCI, Subtarget))
+    return Res;
 
   SelectionDAG &DAG = DCI.DAG;
   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
@@ -24363,6 +24416,69 @@ static SDValue foldCSELOfCSEL(SDNode *Op, SelectionDAG &DAG) {
   return DAG.getNode(AArch64ISD::CSEL, DL, VT, L, R, CCValue, Cond);
 }
 
+/// Fold a conditional select of a LASTB result into CLASTB_N:
+///
+///   csel (lastb Pg, Data), Fallback, ne, (ptest_any PG, Pg)
+///     -> clastb_n Pg, Fallback, Data
+///
+/// The equivalent "none active" form (operands swapped, eq) is also handled.
+/// PG must be Pg itself or an all-lanes PTRUE of Pg's type, so that the flags
+/// test exactly "some lane of Pg is active". Reinterpret casts on the PTEST
+/// operands are looked through.
+static SDValue foldCSELOfLASTB(SDNode *N, SelectionDAG &DAG) {
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  AArch64CC::CondCode CC =
+      static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));
+  SDValue PTAny = N->getOperand(3);
+
+  // CSEL yields Op0 when CC holds. Canonicalize the "none active" (EQ) form so
+  // that Op0 is always the value chosen when some lane of Pg is active.
+  if (CC == AArch64CC::EQ) {
+    std::swap(Op0, Op1);
+    CC = AArch64CC::NE;
+  }
+
+  if (CC != AArch64CC::NE)
+    return SDValue();
+
+  if (Op0.getOpcode() != AArch64ISD::LASTB)
+    return SDValue();
+
+  if (PTAny.getOpcode() != AArch64ISD::PTEST_ANY)
+    return SDValue();
+
+  // Get the predicate...
+  SDValue LBPred = Op0.getOperand(0);
+
+  // Look through reinterprets...
+  SDValue PTestPG = PTAny.getOperand(0);
+  if (PTestPG.getOpcode() == AArch64ISD::REINTERPRET_CAST)
+    PTestPG = PTestPG.getOperand(0);
+
+  SDValue PTestOp = PTAny.getOperand(1);
+  if (PTestOp.getOpcode() == AArch64ISD::REINTERPRET_CAST)
+    PTestOp = PTestOp.getOperand(0);
+
+  // And compare against the csel cmp.
+  // Make sure the same predicate is used.
+  if (PTestOp != LBPred)
+    return SDValue();
+
+  // Make sure that PG for the test is either the same as the input or an
+  // all-lanes ptrue of the same type. A partial pattern (e.g. VL1) or a ptrue
+  // with a different lane count would not test exactly the lanes of Pg.
+  bool IsAllLanesPTrue =
+      PTestPG.getOpcode() == AArch64ISD::PTRUE &&
+      PTestPG.getValueType() == LBPred.getValueType() &&
+      PTestPG.getConstantOperandVal(0) == AArch64SVEPredPattern::all;
+  if (PTestPG != LBPred && !IsAllLanesPTrue)
+    return SDValue();
+
+  return DAG.getNode(AArch64ISD::CLASTB_N, SDLoc(N), N->getValueType(0),
+                     LBPred, Op1, Op0.getOperand(1));
+}
+
 // Optimize CSEL instructions
 static SDValue performCSELCombine(SDNode *N,
                                   TargetLowering::DAGCombinerInfo &DCI,
@@ -24379,6 +24495,9 @@ static SDValue performCSELCombine(SDNode *N,
   if (SDValue Folded = foldCSELofCTTZ(N, DAG))
 		return Folded;
 
+  if (SDValue CLastB = foldCSELOfLASTB(N, DAG))
+    return CLastB;
+
   return performCONDCombine(N, DCI, DAG, 2, 3);
 }
 
diff --git a/llvm/test/CodeGen/AArch64/sve-clastb.ll b/llvm/test/CodeGen/AArch64/sve-clastb.ll
new file mode 100644
index 00000000000000..e2ff1e478f6b87
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-clastb.ll
@@ -0,0 +1,104 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
+; RUN: llc -mtriple=aarch64-linux-unknown -mattr=+sve -o - < %s | FileCheck %s
+
+define i8 @clastb_i8(<vscale x 16 x i8> %data, <vscale x 16 x i1> %pg, i8 %existing) {
+; CHECK-LABEL: clastb_i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    clastb w0, p0, w0, z0.b
+; CHECK-NEXT:    ret
+  %any.active = call i1 @llvm.vector.reduce.or.nxv16i1(<vscale x 16 x i1> %pg)
+  %vs = call i32 @llvm.vscale.i32()
+  %nelts = shl i32 %vs, 4
+  %pg.rev = call <vscale x 16 x i1> @llvm.vector.reverse.nxv16i1(<vscale x 16 x i1> %pg)
+  %trailing = call i32 @llvm.experimental.cttz.elts.i32.nxv16i1(<vscale x 16 x i1> %pg.rev, i1 false)
+  %remaining = sub i32 %nelts, %trailing
+  %last.idx = sub i32 %remaining, 1 ; index of the last active lane of %pg
+  %last = extractelement <vscale x 16 x i8> %data, i32 %last.idx
+  %res = select i1 %any.active, i8 %last, i8 %existing
+  ret i8 %res
+}
+
+define i16 @clastb_i16(<vscale x 8 x i16> %data, <vscale x 8 x i1> %pg, i16 %existing) {
+; CHECK-LABEL: clastb_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    clastb w0, p0, w0, z0.h
+; CHECK-NEXT:    ret
+  %any.active = call i1 @llvm.vector.reduce.or.nxv8i1(<vscale x 8 x i1> %pg)
+  %vs = call i32 @llvm.vscale.i32()
+  %nelts = shl i32 %vs, 3
+  %pg.rev = call <vscale x 8 x i1> @llvm.vector.reverse.nxv8i1(<vscale x 8 x i1> %pg)
+  %trailing = call i32 @llvm.experimental.cttz.elts.i32.nxv8i1(<vscale x 8 x i1> %pg.rev, i1 false)
+  %remaining = sub i32 %nelts, %trailing
+  %last.idx = sub i32 %remaining, 1 ; index of the last active lane of %pg
+  %last = extractelement <vscale x 8 x i16> %data, i32 %last.idx
+  %res = select i1 %any.active, i16 %last, i16 %existing
+  ret i16 %res
+}
+
+define i32 @clastb_i32(<vscale x 4 x i32> %data, <vscale x 4 x i1> %pg, i32 %existing) {
+; CHECK-LABEL: clastb_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    clastb w0, p0, w0, z0.s
+; CHECK-NEXT:    ret
+  %any.active = call i1 @llvm.vector.reduce.or.nxv4i1(<vscale x 4 x i1> %pg)
+  %vs = call i32 @llvm.vscale.i32()
+  %nelts = shl i32 %vs, 2
+  %pg.rev = call <vscale x 4 x i1> @llvm.vector.reverse.nxv4i1(<vscale x 4 x i1> %pg)
+  %trailing = call i32 @llvm.experimental.cttz.elts.i32.nxv4i1(<vscale x 4 x i1> %pg.rev, i1 false)
+  %remaining = sub i32 %nelts, %trailing
+  %last.idx = sub i32 %remaining, 1 ; index of the last active lane of %pg
+  %last = extractelement <vscale x 4 x i32> %data, i32 %last.idx
+  %res = select i1 %any.active, i32 %last, i32 %existing
+  ret i32 %res
+}
+
+define i64 @clastb_i64(<vscale x 2 x i64> %data, <vscale x 2 x i1> %pg, i64 %existing) {
+; CHECK-LABEL: clastb_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    clastb x0, p0, x0, z0.d
+; CHECK-NEXT:    ret
+  %any.active = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> %pg)
+  %vs = call i32 @llvm.vscale.i32()
+  %nelts = shl i32 %vs, 1
+  %pg.rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %pg)
+  %trailing = call i32 @llvm.experimental.cttz.elts.i32.nxv2i1(<vscale x 2 x i1> %pg.rev, i1 false)
+  %remaining = sub i32 %nelts, %trailing
+  %last.idx = sub i32 %remaining, 1 ; index of the last active lane of %pg
+  %last = extractelement <vscale x 2 x i64> %data, i32 %last.idx
+  %res = select i1 %any.active, i64 %last, i64 %existing
+  ret i64 %res
+}
+
+define float @clastb_float(float %existing, <vscale x 4 x float> %data, <vscale x 4 x i1> %pg) {
+; CHECK-LABEL: clastb_float:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    clastb s0, p0, s0, z1.s
+; CHECK-NEXT:    ret
+  %rev.pg = call <vscale x 4 x i1> @llvm.vector.reverse.nxv4i1(<vscale x 4 x i1> %pg)
+  %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv4i1(<vscale x 4 x i1> %rev.pg, i1 false) ; overload suffix matches the i32 result
+  %any.set = call i1 @llvm.vector.reduce.or.nxv4i1(<vscale x 4 x i1> %pg)
+  %vscale = call i32 @llvm.vscale.i32() ; vscale is an integer; the suffix names the i32 result type
+  %size = shl i32 %vscale, 2
+  %sub = sub i32 %size, %tz.cnt
+  %idx = sub i32 %sub, 1
+  %extr = extractelement <vscale x 4 x float> %data, i32 %idx
+  %res = select i1 %any.set, float %extr, float %existing
+  ret float %res
+}
+
+define double @clastb_double(double %existing, <vscale x 2 x double> %data, <vscale x 2 x i1> %pg) {
+; CHECK-LABEL: clastb_double:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    clastb d0, p0, d0, z1.d
+; CHECK-NEXT:    ret
+  %any.active = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> %pg)
+  %vs = call i32 @llvm.vscale.i32()
+  %nelts = shl i32 %vs, 1
+  %pg.rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %pg)
+  %trailing = call i32 @llvm.experimental.cttz.elts.i32.nxv2i1(<vscale x 2 x i1> %pg.rev, i1 false)
+  %remaining = sub i32 %nelts, %trailing
+  %last.idx = sub i32 %remaining, 1 ; index of the last active lane of %pg
+  %last = extractelement <vscale x 2 x double> %data, i32 %last.idx
+  %res = select i1 %any.active, double %last, double %existing
+  ret double %res
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/112738


More information about the llvm-commits mailing list