[llvm] [AArch64] Improve index selection for histograms (PR #111150)

James Chesterman via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 4 06:24:17 PDT 2024


https://github.com/JamesChesterman created https://github.com/llvm/llvm-project/pull/111150

Search for extends of the index used in a histogram operation, then truncate the index back to its original type. This avoids the need to split the instruction in two.
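For context, here is a sketch (not part of the patch; names are illustrative) of the kind of source loop that produces this pattern. Once the vectorizer emits @llvm.experimental.vector.histogram.add for the bucket update, the 32-bit indices are extended to 64 bits for the GEP that computes the bucket pointers, and that extend is what the combine folds away:

  // Illustrative only: a histogramming loop. When vectorized for SVE2,
  // the bucket update becomes @llvm.experimental.vector.histogram.add,
  // and the i32 indices get extended to i64 to form the bucket addresses.
  void histogram(int *buckets, const unsigned *indices, int n) {
    for (int i = 0; i < n; ++i)
      buckets[indices[i]] += 1;
  }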

From 5db9c2e9165f3fd948d20f8b4a820e0fd93b6eca Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 4 Oct 2024 12:38:25 +0000
Subject: [PATCH] [AArch64] Improve index selection for histograms

Search for extends of the index used in a histogram operation,
then truncate the index back to its original type. This avoids
the need to split the instruction in two.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 37 +++++++++-
 llvm/test/CodeGen/AArch64/sve2-histcnt.ll     | 73 +++++++++++++++++++
 2 files changed, 107 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 48e1b96d841efb..545d5b59c64562 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1114,7 +1114,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                        ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT,
                        ISD::VECREDUCE_ADD, ISD::STEP_VECTOR});
 
-  setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER});
+  setTargetDAGCombine(
+      {ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
 
   setTargetDAGCombine(ISD::FP_EXTEND);
 
@@ -24079,12 +24079,42 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
 
 static SDValue performMaskedGatherScatterCombine(
     SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
-  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
-  assert(MGS && "Can only combine gather load or scatter store nodes");
+  MaskedHistogramSDNode *HG = nullptr;
+  MaskedGatherScatterSDNode *MGS = nullptr;
+  if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM)
+    HG = cast<MaskedHistogramSDNode>(N);
+  else
+    MGS = cast<MaskedGatherScatterSDNode>(N);
+  assert((HG || MGS) &&
+         "Can only combine gather load, scatter store or histogram nodes");
 
   if (!DCI.isBeforeLegalize())
     return SDValue();
 
+  if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
+    SDLoc DL(HG);
+    SDValue Index = HG->getIndex();
+    if (ISD::isExtOpcode(Index->getOpcode())) {
+      SDValue Chain = HG->getChain();
+      SDValue Inc = HG->getInc();
+      SDValue Mask = HG->getMask();
+      SDValue BasePtr = HG->getBasePtr();
+      SDValue Scale = HG->getScale();
+      SDValue IntID = HG->getIntID();
+      EVT MemVT = HG->getMemoryVT();
+      MachineMemOperand *MMO = HG->getMemOperand();
+      ISD::MemIndexType IndexType = HG->getIndexType();
+      SDValue ExtOp = Index.getOperand(0);
+      EVT SrcType = ExtOp.getValueType();
+      SDValue TruncatedIndex = DAG.getAnyExtOrTrunc(Index, DL, SrcType);
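+      // SrcType is narrower than the index type, so this emits a truncate;
+      // trunc(ext(x)) folds back to x when the node is built.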
+      SDValue Ops[] = {Chain, Inc, Mask, BasePtr, TruncatedIndex, Scale, IntID};
+      return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
+                                    MMO, IndexType);
+    }
+    return SDValue();
+  }
+
   SDLoc DL(MGS);
   SDValue Chain = MGS->getChain();
   SDValue Scale = MGS->getScale();
@@ -26277,6 +26307,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performMSTORECombine(N, DCI, DAG, Subtarget);
   case ISD::MGATHER:
   case ISD::MSCATTER:
+  case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return performMaskedGatherScatterCombine(N, DCI, DAG);
   case ISD::FP_EXTEND:
     return performFPExtendCombine(N, DAG, DCI, Subtarget);
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index dd0b9639a8fc2f..42fff1ec7c532f 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -267,5 +267,78 @@ define void @histogram_i16_8_lane(ptr %base, <vscale x 8 x i32> %indices, i16 %i
   ret void
 }
 
+define void @histogram_i32_zextend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_zextend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i32_8_lane_zextend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_8_lane_zextend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mov z4.s, w1
+; CHECK-NEXT:    ptrue p2.s
+; CHECK-NEXT:    histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT:    st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i32_sextend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_sextend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i32_8_lane_sextend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_8_lane_sextend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mov z4.s, w1
+; CHECK-NEXT:    ptrue p2.s
+; CHECK-NEXT:    histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT:    st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = sext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
+  ret void
+}
 
 attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }
