[llvm] [AArch64] Implement vector splitting for histogram intrinsic (PR #103037)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 13 03:27:07 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Max Beck-Jones (DevM-uk)
<details>
<summary>Changes</summary>
Adds support for wider-than-legal vector types for the histogram intrinsic (llvm.experimental.vector.histogram.add) by splitting the vector.
---
Full diff: https://github.com/llvm/llvm-project/pull/103037.diff
2 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+39)
- (modified) llvm/test/CodeGen/AArch64/sve2-histcnt.ll (+95)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7777aa4b50a370..7c9b34b272f17d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1128,6 +1128,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::SCALAR_TO_VECTOR);
+ setTargetDAGCombine(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM);
+
// In case of strict alignment, avoid an excessive number of byte wide stores.
MaxStoresPerMemsetOptSize = 8;
MaxStoresPerMemset =
@@ -25434,6 +25436,41 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
return NVCAST;
}
+static SDValue performHistogramCombine(SDNode *N, // N: the EXPERIMENTAL_VECTOR_HISTOGRAM node dispatched from PerformDAGCombine.
+ TargetLowering::DAGCombinerInfo &DCI,
+ SelectionDAG &DAG) { // Returns a replacement value, or an empty SDValue when the node is left unchanged.
+ if (!DCI.isBeforeLegalize()) // Bail out unless DCI reports that we are before legalization.
+ return SDValue();
+
+ MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(N);
+ SDLoc DL(HG);
+ SDValue Chain = HG->getChain();
+ SDValue Inc = HG->getInc();
+ SDValue Mask = HG->getMask();
+ SDValue Ptr = HG->getBasePtr();
+ SDValue Index = HG->getIndex();
+ SDValue Scale = HG->getScale();
+ SDValue IntID = HG->getIntID();
+ EVT MemVT = HG->getMemoryVT();
+ EVT IndexVT = Index.getValueType();
+ MachineMemOperand *MMO = HG->getMemOperand();
+ ISD::MemIndexType IndexType = HG->getIndexType();
+
+ if (IndexVT == MVT::nxv4i32 || IndexVT == MVT::nxv2i64) // These index types are left as-is. NOTE(review): every other type, not only wider ones, falls through to the split below - confirm that is intended.
+ return SDValue();
+
+ // Split vectors which are too wide: halve the index and the mask, then emit one histogram node per half.
+ SDValue IndexLo, IndexHi, MaskLo, MaskHi;
+ std::tie(IndexLo, IndexHi) = DAG.SplitVector(Index, DL);
+ std::tie(MaskLo, MaskHi) = DAG.SplitVector(Mask, DL);
+ SDValue HistogramOpsLo[] = {Chain, Inc, MaskLo, Ptr, IndexLo, Scale, IntID}; // Low half uses the original node's chain.
+ SDValue HChain = DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL,
+ HistogramOpsLo, MMO, IndexType);
+ SDValue HistogramOpsHi[] = {HChain, Inc, MaskHi, Ptr, IndexHi, Scale, IntID}; // High half takes the low half's result as its chain operand.
+ return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, // The high-half node is the combine result. NOTE(review): both halves reuse MemVT and MMO unchanged - confirm.
+ HistogramOpsHi, MMO, IndexType);
+}
+
SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
@@ -25778,6 +25815,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performCTLZCombine(N, DAG, Subtarget);
case ISD::SCALAR_TO_VECTOR:
return performScalarToVectorCombine(N, DCI, DAG);
+ case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
+ return performHistogramCombine(N, DCI, DAG);
}
return SDValue();
}
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index 2874e47511e12f..56d5eb13ab12e3 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -169,4 +169,99 @@ define void @histogram_i16_literal_3(ptr %base, <vscale x 4 x i32> %indices, <vs
ret void
}
+define void @histogram_i64_4_lane(<vscale x 4 x ptr> %buckets, i64 %inc, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i64_4_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: mov z4.d, x0
+; CHECK-NEXT: ptrue p2.d
+; CHECK-NEXT: histcnt z2.d, p1/z, z0.d, z0.d
+; CHECK-NEXT: ld1d { z3.d }, p1/z, [z0.d]
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: mad z2.d, p2/m, z4.d, z3.d
+; CHECK-NEXT: st1d { z2.d }, p1, [z0.d]
+; CHECK-NEXT: histcnt z0.d, p0/z, z1.d, z1.d
+; CHECK-NEXT: ld1d { z2.d }, p0/z, [z1.d]
+; CHECK-NEXT: mad z0.d, p2/m, z4.d, z2.d
+; CHECK-NEXT: st1d { z0.d }, p0, [z1.d]
+; CHECK-NEXT: ret
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i64(<vscale x 4 x ptr> %buckets, i64 %inc, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i64_8_lane(<vscale x 8 x ptr> %buckets, i64 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i64_8_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: punpklo p2.h, p0.b
+; CHECK-NEXT: mov z6.d, x0
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: punpklo p3.h, p2.b
+; CHECK-NEXT: punpkhi p2.h, p2.b
+; CHECK-NEXT: histcnt z4.d, p3/z, z0.d, z0.d
+; CHECK-NEXT: ld1d { z5.d }, p3/z, [z0.d]
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: mad z4.d, p1/m, z6.d, z5.d
+; CHECK-NEXT: st1d { z4.d }, p3, [z0.d]
+; CHECK-NEXT: histcnt z0.d, p2/z, z1.d, z1.d
+; CHECK-NEXT: ld1d { z4.d }, p2/z, [z1.d]
+; CHECK-NEXT: mad z0.d, p1/m, z6.d, z4.d
+; CHECK-NEXT: st1d { z0.d }, p2, [z1.d]
+; CHECK-NEXT: punpklo p2.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: histcnt z0.d, p2/z, z2.d, z2.d
+; CHECK-NEXT: ld1d { z1.d }, p2/z, [z2.d]
+; CHECK-NEXT: mad z0.d, p1/m, z6.d, z1.d
+; CHECK-NEXT: st1d { z0.d }, p2, [z2.d]
+; CHECK-NEXT: histcnt z0.d, p0/z, z3.d, z3.d
+; CHECK-NEXT: ld1d { z1.d }, p0/z, [z3.d]
+; CHECK-NEXT: mad z0.d, p1/m, z6.d, z1.d
+; CHECK-NEXT: st1d { z0.d }, p0, [z3.d]
+; CHECK-NEXT: ret
+ call void @llvm.experimental.vector.histogram.add.nxv8p0.i64(<vscale x 8 x ptr> %buckets, i64 %inc, <vscale x 8 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i32_8_lane(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_8_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: mov z4.s, w1
+; CHECK-NEXT: ptrue p2.s
+; CHECK-NEXT: histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT: ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT: st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
+; CHECK-NEXT: mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT: st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i32, ptr %base, <vscale x 8 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_8_lane(ptr %base, <vscale x 8 x i32> %indices, i16 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_8_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: mov z4.s, w1
+; CHECK-NEXT: ptrue p2.s
+; CHECK-NEXT: histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT: ld1h { z3.s }, p1/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT: st1h { z2.s }, p1, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z1.s, sxtw #1]
+; CHECK-NEXT: mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT: st1h { z0.s }, p0, [x0, z1.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 8 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv8p0.i16(<vscale x 8 x ptr> %buckets, i16 %inc, <vscale x 8 x i1> %mask)
+ ret void
+}
+
+
attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }
``````````
</details>
https://github.com/llvm/llvm-project/pull/103037
More information about the llvm-commits
mailing list