[llvm] 670d208 - [AArch64] Implement promotion type legalisation for histogram intrinsic (#101017)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 12 03:09:00 PDT 2024
Author: Max Beck-Jones
Date: 2024-08-12T11:08:57+01:00
New Revision: 670d208ffc156b5b8f01aee7439847b01b18d05d
URL: https://github.com/llvm/llvm-project/commit/670d208ffc156b5b8f01aee7439847b01b18d05d
DIFF: https://github.com/llvm/llvm-project/commit/670d208ffc156b5b8f01aee7439847b01b18d05d.diff
LOG: [AArch64] Implement promotion type legalisation for histogram intrinsic (#101017)
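Currently the histogram intrinsic
(llvm.experimental.vector.histogram.add) only allows i32 and i64 types
for the memory locations to be updated, matching the restrictions of the
histcnt instruction. This patch adds support for the legalisation of
smaller types (i8 and i16) via promotion: the buckets are gathered with an
extending load into wider lanes (32-bit lanes for 4-lane vectors, 64-bit
lanes for 2-lane vectors), updated at that width, and written back with a
truncating scatter.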
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/sve2-histcnt.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 98ec2c7f529ecd..7777aa4b50a370 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1776,9 +1776,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
// Histcnt is SVE2 only
- if (Subtarget->hasSVE2())
+ if (Subtarget->hasSVE2()) {
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
Custom);
+ setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i8, Custom);
+ setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i16, Custom);
+ }
}
@@ -28175,11 +28178,18 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
EVT IncVT = Inc.getValueType();
EVT IndexVT = Index.getValueType();
- EVT MemVT = EVT::getVectorVT(*DAG.getContext(), IncVT,
- IndexVT.getVectorElementCount());
+ LLVMContext &Ctx = *DAG.getContext();
+ ElementCount EC = IndexVT.getVectorElementCount();
+ EVT MemVT = EVT::getVectorVT(Ctx, IncVT, EC);
+ EVT IncExtVT =
+ EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
+ EVT IncSplatVT = EVT::getVectorVT(Ctx, IncExtVT, EC);
+ bool ExtTrunc = IncSplatVT != MemVT;
+
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
- SDValue PassThru = DAG.getSplatVector(MemVT, DL, Zero);
- SDValue IncSplat = DAG.getSplatVector(MemVT, DL, Inc);
+ SDValue PassThru = DAG.getSplatVector(IncSplatVT, DL, Zero);
+ SDValue IncSplat = DAG.getSplatVector(
+ IncSplatVT, DL, DAG.getAnyExtOrTrunc(Inc, DL, IncExtVT));
SDValue Ops[] = {Chain, PassThru, Mask, Ptr, Index, Scale};
MachineMemOperand *MMO = HG->getMemOperand();
@@ -28188,18 +28198,19 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
MMO->getPointerInfo(), MachineMemOperand::MOLoad, MMO->getSize(),
MMO->getAlign(), MMO->getAAInfo());
ISD::MemIndexType IndexType = HG->getIndexType();
- SDValue Gather =
- DAG.getMaskedGather(DAG.getVTList(MemVT, MVT::Other), MemVT, DL, Ops,
- GMMO, IndexType, ISD::NON_EXTLOAD);
+ SDValue Gather = DAG.getMaskedGather(
+ DAG.getVTList(IncSplatVT, MVT::Other), MemVT, DL, Ops, GMMO, IndexType,
+ ExtTrunc ? ISD::EXTLOAD : ISD::NON_EXTLOAD);
SDValue GChain = Gather.getValue(1);
// Perform the histcnt, multiply by inc, add to bucket data.
- SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncVT);
+ SDValue ID =
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncExtVT);
SDValue HistCnt =
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, ID, Mask, Index, Index);
- SDValue Mul = DAG.getNode(ISD::MUL, DL, MemVT, HistCnt, IncSplat);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MemVT, Gather, Mul);
+ SDValue Mul = DAG.getNode(ISD::MUL, DL, IncSplatVT, HistCnt, IncSplat);
+ SDValue Add = DAG.getNode(ISD::ADD, DL, IncSplatVT, Gather, Mul);
// Create an MMO for the scatter, without load|store flags.
MachineMemOperand *SMMO = DAG.getMachineFunction().getMachineMemOperand(
@@ -28208,7 +28219,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
SDValue ScatterOps[] = {GChain, Add, Mask, Ptr, Index, Scale};
SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MemVT, DL,
- ScatterOps, SMMO, IndexType, false);
+ ScatterOps, SMMO, IndexType, ExtTrunc);
return Scatter;
}
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index db164e288abde3..2874e47511e12f 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -50,4 +50,123 @@ define void @histogram_i32_literal_noscale(ptr %base, <vscale x 4 x i32> %indice
ret void
}
+define void @histogram_i32_promote(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i32 %inc) #0 {
+; CHECK-LABEL: histogram_i32_promote:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT: mov z3.d, x1
+; CHECK-NEXT: ld1w { z2.d }, p0/z, [x0, z0.d, lsl #2]
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT: st1w { z1.d }, p0, [x0, z0.d, lsl #2]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i32, ptr %base, <vscale x 2 x i64> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv2p0.i32(<vscale x 2 x ptr> %buckets, i32 %inc, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i16 %inc) #0 {
+; CHECK-LABEL: histogram_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, w1
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 %inc, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i8(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i8 %inc) #0 {
+; CHECK-LABEL: histogram_i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, w1
+; CHECK-NEXT: ld1b { z2.s }, p0/z, [x0, z0.s, sxtw]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1b { z1.s }, p0, [x0, z0.s, sxtw]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i8, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i8(<vscale x 4 x ptr> %buckets, i8 %inc, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i16 %inc) #0 {
+; CHECK-LABEL: histogram_i16_2_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT: mov z3.d, x1
+; CHECK-NEXT: ld1h { z2.d }, p0/z, [x0, z0.d, lsl #1]
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT: st1h { z1.d }, p0, [x0, z0.d, lsl #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 2 x i64> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv2p0.i16(<vscale x 2 x ptr> %buckets, i16 %inc, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i8_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i8 %inc) #0 {
+; CHECK-LABEL: histogram_i8_2_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT: mov z3.d, x1
+; CHECK-NEXT: ld1b { z2.d }, p0/z, [x0, z0.d]
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT: st1b { z1.d }, p0, [x0, z0.d]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i8, ptr %base, <vscale x 2 x i64> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv2p0.i8(<vscale x 2 x ptr> %buckets, i8 %inc, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_literal_1(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_1:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: add z1.s, z2.s, z1.s
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 1, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_literal_2(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_2:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: adr z1.s, [z2.s, z1.s, lsl #1]
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 2, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_literal_3(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_3:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, #3 // =0x3
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 3, <vscale x 4 x i1> %mask)
+ ret void
+}
+
attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }
More information about the llvm-commits mailing list