[llvm] [AArch64] Implement promotion type legalisation for histogram intrinsic (PR #101017)
Max Beck-Jones via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 31 05:45:37 PDT 2024
https://github.com/DevM-uk updated https://github.com/llvm/llvm-project/pull/101017
>From dc87f0c7fe826044bb179a8f3c29cb049e2b41e4 Mon Sep 17 00:00:00 2001
From: Max Beck-Jones <max.beck-jones at arm.com>
Date: Mon, 29 Jul 2024 11:48:36 +0000
Subject: [PATCH 1/2] [AArch64] Implement promotion type legalisation for
histogram intrinsic
Currently, the histogram intrinsic (llvm.experimental.vector.histogram.add) only allows i32 and i64 types for the memory locations to be updated, matching the restrictions of the histcnt instruction. This patch adds support for legalising the smaller types i8 and i16 via promotion.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 32 +++--
llvm/test/CodeGen/AArch64/sve2-histcnt.ll | 119 ++++++++++++++++++
2 files changed, 141 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1e9da9b819bdd..153d5fe28be7b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1775,9 +1775,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
// Histcnt is SVE2 only
- if (Subtarget->hasSVE2())
+ if (Subtarget->hasSVE2()) {
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
Custom);
+ setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i8, Custom);
+ setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i16, Custom);
+ }
}
@@ -28018,9 +28021,17 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
EVT IndexVT = Index.getValueType();
EVT MemVT = EVT::getVectorVT(*DAG.getContext(), IncVT,
IndexVT.getVectorElementCount());
+ EVT IncExtVT = IndexVT.getVectorElementCount().getKnownMinValue() == 4
+ ? MVT::i32
+ : MVT::i64;
+ EVT IncSplatVT = EVT::getVectorVT(*DAG.getContext(), IncExtVT,
+ IndexVT.getVectorElementCount());
+ bool ExtTrunc = IncSplatVT != MemVT;
+
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
- SDValue PassThru = DAG.getSplatVector(MemVT, DL, Zero);
- SDValue IncSplat = DAG.getSplatVector(MemVT, DL, Inc);
+ SDValue PassThru = DAG.getSplatVector(IncSplatVT, DL, Zero);
+ SDValue IncSplat = DAG.getSplatVector(
+ IncSplatVT, DL, DAG.getAnyExtOrTrunc(Inc, DL, IncExtVT));
SDValue Ops[] = {Chain, PassThru, Mask, Ptr, Index, Scale};
MachineMemOperand *MMO = HG->getMemOperand();
@@ -28029,18 +28040,19 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
MMO->getPointerInfo(), MachineMemOperand::MOLoad, MMO->getSize(),
MMO->getAlign(), MMO->getAAInfo());
ISD::MemIndexType IndexType = HG->getIndexType();
- SDValue Gather =
- DAG.getMaskedGather(DAG.getVTList(MemVT, MVT::Other), MemVT, DL, Ops,
- GMMO, IndexType, ISD::NON_EXTLOAD);
+ SDValue Gather = DAG.getMaskedGather(
+ DAG.getVTList(IncSplatVT, MVT::Other), MemVT, DL, Ops, GMMO, IndexType,
+ ExtTrunc ? ISD::EXTLOAD : ISD::NON_EXTLOAD);
SDValue GChain = Gather.getValue(1);
// Perform the histcnt, multiply by inc, add to bucket data.
- SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncVT);
+ SDValue ID =
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncExtVT);
SDValue HistCnt =
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, ID, Mask, Index, Index);
- SDValue Mul = DAG.getNode(ISD::MUL, DL, MemVT, HistCnt, IncSplat);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MemVT, Gather, Mul);
+ SDValue Mul = DAG.getNode(ISD::MUL, DL, IncSplatVT, HistCnt, IncSplat);
+ SDValue Add = DAG.getNode(ISD::ADD, DL, IncSplatVT, Gather, Mul);
// Create an MMO for the scatter, without load|store flags.
MachineMemOperand *SMMO = DAG.getMachineFunction().getMachineMemOperand(
@@ -28049,7 +28061,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
SDValue ScatterOps[] = {GChain, Add, Mask, Ptr, Index, Scale};
SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MemVT, DL,
- ScatterOps, SMMO, IndexType, false);
+ ScatterOps, SMMO, IndexType, ExtTrunc);
return Scatter;
}
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index db164e288abde..2874e47511e12 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -50,4 +50,123 @@ define void @histogram_i32_literal_noscale(ptr %base, <vscale x 4 x i32> %indice
ret void
}
+define void @histogram_i32_promote(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i32 %inc) #0 {
+; CHECK-LABEL: histogram_i32_promote:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT: mov z3.d, x1
+; CHECK-NEXT: ld1w { z2.d }, p0/z, [x0, z0.d, lsl #2]
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT: st1w { z1.d }, p0, [x0, z0.d, lsl #2]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i32, ptr %base, <vscale x 2 x i64> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv2p0.i32(<vscale x 2 x ptr> %buckets, i32 %inc, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i16 %inc) #0 {
+; CHECK-LABEL: histogram_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, w1
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 %inc, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i8(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i8 %inc) #0 {
+; CHECK-LABEL: histogram_i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, w1
+; CHECK-NEXT: ld1b { z2.s }, p0/z, [x0, z0.s, sxtw]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1b { z1.s }, p0, [x0, z0.s, sxtw]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i8, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i8(<vscale x 4 x ptr> %buckets, i8 %inc, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i16 %inc) #0 {
+; CHECK-LABEL: histogram_i16_2_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT: mov z3.d, x1
+; CHECK-NEXT: ld1h { z2.d }, p0/z, [x0, z0.d, lsl #1]
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT: st1h { z1.d }, p0, [x0, z0.d, lsl #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 2 x i64> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv2p0.i16(<vscale x 2 x ptr> %buckets, i16 %inc, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i8_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i8 %inc) #0 {
+; CHECK-LABEL: histogram_i8_2_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT: mov z3.d, x1
+; CHECK-NEXT: ld1b { z2.d }, p0/z, [x0, z0.d]
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT: st1b { z1.d }, p0, [x0, z0.d]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i8, ptr %base, <vscale x 2 x i64> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv2p0.i8(<vscale x 2 x ptr> %buckets, i8 %inc, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_literal_1(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_1:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: add z1.s, z2.s, z1.s
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 1, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_literal_2(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_2:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: adr z1.s, [z2.s, z1.s, lsl #1]
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 2, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_literal_3(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_3:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, #3 // =0x3
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 3, <vscale x 4 x i1> %mask)
+ ret void
+}
+
attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }
>From bd9c2c36013f235537a30e3c0b477a114cee18d7 Mon Sep 17 00:00:00 2001
From: Max Beck-Jones <max.beck-jones at arm.com>
Date: Wed, 31 Jul 2024 11:10:51 +0000
Subject: [PATCH 2/2] fixup: Neaten VTs generation
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 13 ++++++-------
1 file changed, 6 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 153d5fe28be7b..d30d548e4126e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -28019,13 +28019,12 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
EVT IncVT = Inc.getValueType();
EVT IndexVT = Index.getValueType();
- EVT MemVT = EVT::getVectorVT(*DAG.getContext(), IncVT,
- IndexVT.getVectorElementCount());
- EVT IncExtVT = IndexVT.getVectorElementCount().getKnownMinValue() == 4
- ? MVT::i32
- : MVT::i64;
- EVT IncSplatVT = EVT::getVectorVT(*DAG.getContext(), IncExtVT,
- IndexVT.getVectorElementCount());
+ LLVMContext &Ctx = *DAG.getContext();
+ ElementCount EC = IndexVT.getVectorElementCount();
+ EVT MemVT = EVT::getVectorVT(Ctx, IncVT, EC);
+ EVT IncExtVT =
+ EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
+ EVT IncSplatVT = EVT::getVectorVT(Ctx, IncExtVT, EC);
bool ExtTrunc = IncSplatVT != MemVT;
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
More information about the llvm-commits
mailing list