[llvm] [AArch64] Implement promotion type legalisation for histogram intrinsic (PR #101017)
Max Beck-Jones via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 29 07:14:58 PDT 2024
https://github.com/DevM-uk created https://github.com/llvm/llvm-project/pull/101017
Currently the histogram intrinsic (llvm.experimental.vector.histogram.add) only allows i32 and i64 types for the memory locations to be updated, matching the restrictions of the histcnt instruction. This patch adds support for the legalisation of smaller types (i8 and i16) via promotion.
From dc87f0c7fe826044bb179a8f3c29cb049e2b41e4 Mon Sep 17 00:00:00 2001
From: Max Beck-Jones <max.beck-jones at arm.com>
Date: Mon, 29 Jul 2024 11:48:36 +0000
Subject: [PATCH] [AArch64] Implement promotion type legalisation for histogram
intrinsic
Currently the histogram intrinsic (llvm.experimental.vector.histogram.add) only allows i32 and i64 types for the memory locations to be updated, matching the restrictions of the histcnt instruction. This patch adds support for the legalisation of smaller types (i8 and i16) via promotion.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 32 +++--
llvm/test/CodeGen/AArch64/sve2-histcnt.ll | 119 ++++++++++++++++++
2 files changed, 141 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1e9da9b819bdd..153d5fe28be7b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1775,9 +1775,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
// Histcnt is SVE2 only
- if (Subtarget->hasSVE2())
+ if (Subtarget->hasSVE2()) {
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
Custom);
+ setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i8, Custom);
+ setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i16, Custom);
+ }
}
@@ -28018,9 +28021,17 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
EVT IndexVT = Index.getValueType();
EVT MemVT = EVT::getVectorVT(*DAG.getContext(), IncVT,
IndexVT.getVectorElementCount());
+ EVT IncExtVT = IndexVT.getVectorElementCount().getKnownMinValue() == 4
+ ? MVT::i32
+ : MVT::i64;
+ EVT IncSplatVT = EVT::getVectorVT(*DAG.getContext(), IncExtVT,
+ IndexVT.getVectorElementCount());
+ bool ExtTrunc = IncSplatVT != MemVT;
+
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
- SDValue PassThru = DAG.getSplatVector(MemVT, DL, Zero);
- SDValue IncSplat = DAG.getSplatVector(MemVT, DL, Inc);
+ SDValue PassThru = DAG.getSplatVector(IncSplatVT, DL, Zero);
+ SDValue IncSplat = DAG.getSplatVector(
+ IncSplatVT, DL, DAG.getAnyExtOrTrunc(Inc, DL, IncExtVT));
SDValue Ops[] = {Chain, PassThru, Mask, Ptr, Index, Scale};
MachineMemOperand *MMO = HG->getMemOperand();
@@ -28029,18 +28040,19 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
MMO->getPointerInfo(), MachineMemOperand::MOLoad, MMO->getSize(),
MMO->getAlign(), MMO->getAAInfo());
ISD::MemIndexType IndexType = HG->getIndexType();
- SDValue Gather =
- DAG.getMaskedGather(DAG.getVTList(MemVT, MVT::Other), MemVT, DL, Ops,
- GMMO, IndexType, ISD::NON_EXTLOAD);
+ SDValue Gather = DAG.getMaskedGather(
+ DAG.getVTList(IncSplatVT, MVT::Other), MemVT, DL, Ops, GMMO, IndexType,
+ ExtTrunc ? ISD::EXTLOAD : ISD::NON_EXTLOAD);
SDValue GChain = Gather.getValue(1);
// Perform the histcnt, multiply by inc, add to bucket data.
- SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncVT);
+ SDValue ID =
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncExtVT);
SDValue HistCnt =
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, ID, Mask, Index, Index);
- SDValue Mul = DAG.getNode(ISD::MUL, DL, MemVT, HistCnt, IncSplat);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MemVT, Gather, Mul);
+ SDValue Mul = DAG.getNode(ISD::MUL, DL, IncSplatVT, HistCnt, IncSplat);
+ SDValue Add = DAG.getNode(ISD::ADD, DL, IncSplatVT, Gather, Mul);
// Create an MMO for the scatter, without load|store flags.
MachineMemOperand *SMMO = DAG.getMachineFunction().getMachineMemOperand(
@@ -28049,7 +28061,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
SDValue ScatterOps[] = {GChain, Add, Mask, Ptr, Index, Scale};
SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MemVT, DL,
- ScatterOps, SMMO, IndexType, false);
+ ScatterOps, SMMO, IndexType, ExtTrunc);
return Scatter;
}
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index db164e288abde..2874e47511e12 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -50,4 +50,123 @@ define void @histogram_i32_literal_noscale(ptr %base, <vscale x 4 x i32> %indice
ret void
}
+define void @histogram_i32_promote(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i32 %inc) #0 {
+; CHECK-LABEL: histogram_i32_promote:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT: mov z3.d, x1
+; CHECK-NEXT: ld1w { z2.d }, p0/z, [x0, z0.d, lsl #2]
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT: st1w { z1.d }, p0, [x0, z0.d, lsl #2]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i32, ptr %base, <vscale x 2 x i64> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv2p0.i32(<vscale x 2 x ptr> %buckets, i32 %inc, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i16 %inc) #0 {
+; CHECK-LABEL: histogram_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, w1
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 %inc, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i8(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i8 %inc) #0 {
+; CHECK-LABEL: histogram_i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, w1
+; CHECK-NEXT: ld1b { z2.s }, p0/z, [x0, z0.s, sxtw]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1b { z1.s }, p0, [x0, z0.s, sxtw]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i8, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i8(<vscale x 4 x ptr> %buckets, i8 %inc, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i16 %inc) #0 {
+; CHECK-LABEL: histogram_i16_2_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT: mov z3.d, x1
+; CHECK-NEXT: ld1h { z2.d }, p0/z, [x0, z0.d, lsl #1]
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT: st1h { z1.d }, p0, [x0, z0.d, lsl #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 2 x i64> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv2p0.i16(<vscale x 2 x ptr> %buckets, i16 %inc, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i8_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i8 %inc) #0 {
+; CHECK-LABEL: histogram_i8_2_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT: mov z3.d, x1
+; CHECK-NEXT: ld1b { z2.d }, p0/z, [x0, z0.d]
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT: st1b { z1.d }, p0, [x0, z0.d]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i8, ptr %base, <vscale x 2 x i64> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv2p0.i8(<vscale x 2 x ptr> %buckets, i8 %inc, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_literal_1(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_1:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: add z1.s, z2.s, z1.s
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 1, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_literal_2(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_2:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: adr z1.s, [z2.s, z1.s, lsl #1]
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 2, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_literal_3(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_3:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, #3 // =0x3
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 3, <vscale x 4 x i1> %mask)
+ ret void
+}
+
attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }
More information about the llvm-commits
mailing list