[llvm] [AArch64] Improve index selection for histograms (PR #111150)
James Chesterman via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 8 03:27:55 PDT 2024
https://github.com/JamesChesterman updated https://github.com/llvm/llvm-project/pull/111150
>From 5db9c2e9165f3fd948d20f8b4a820e0fd93b6eca Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 4 Oct 2024 12:38:25 +0000
Subject: [PATCH 1/5] [AArch64] Improve index selection for histograms
Search for extends of the index used in a histogram operation, then
truncate the index back to its pre-extend type. This avoids the need
to split the instruction in two.
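
In outline, the change does the following (a minimal sketch of the
combine, using the SelectionDAG accessors that appear in the patch
below; the full implementation follows):

    // If the histogram node's index operand is an extend
    // (sext/zext/anyext), rebuild the node with the narrower
    // pre-extend value so the operation does not have to be split.
    SDValue Index = HG->getIndex();
    if (ISD::isExtOpcode(Index->getOpcode())) {
      SDValue NarrowIndex = Index.getOperand(0); // value before the extend
      // ... recreate the MaskedHistogramSDNode using NarrowIndex ...
    }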
---
.../Target/AArch64/AArch64ISelLowering.cpp | 37 +++++++++-
llvm/test/CodeGen/AArch64/sve2-histcnt.ll | 73 +++++++++++++++++++
2 files changed, 107 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 48e1b96d841efb..545d5b59c64562 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1114,7 +1114,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT,
ISD::VECREDUCE_ADD, ISD::STEP_VECTOR});
- setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER});
+ setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
setTargetDAGCombine(ISD::FP_EXTEND);
@@ -24079,12 +24079,42 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
static SDValue performMaskedGatherScatterCombine(
SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
- MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
- assert(MGS && "Can only combine gather load or scatter store nodes");
+ MaskedHistogramSDNode *HG;
+ MaskedGatherScatterSDNode *MGS;
+ if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
+ HG = cast<MaskedHistogramSDNode>(N);
+ } else {
+ MGS = cast<MaskedGatherScatterSDNode>(N);
+ }
+ assert((HG || MGS) &&
+ "Can only combine gather load, scatter store or histogram nodes");
if (!DCI.isBeforeLegalize())
return SDValue();
+ if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
+ SDLoc DL(HG);
+ SDValue Index = HG->getIndex();
+ if (ISD::isExtOpcode(Index->getOpcode())) {
+ SDValue Chain = HG->getChain();
+ SDValue Inc = HG->getInc();
+ SDValue Mask = HG->getMask();
+ SDValue BasePtr = HG->getBasePtr();
+ SDValue Scale = HG->getScale();
+ SDValue IntID = HG->getIntID();
+ EVT MemVT = HG->getMemoryVT();
+ MachineMemOperand *MMO = HG->getMemOperand();
+ ISD::MemIndexType IndexType = HG->getIndexType();
+ SDValue ExtOp = Index.getOperand(0);
+ auto SrcType = ExtOp.getValueType();
+ auto TruncatedIndex = DAG.getAnyExtOrTrunc(Index, DL, SrcType);
+ SDValue Ops[] = {Chain, Inc, Mask, BasePtr, TruncatedIndex, Scale, IntID};
+ return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
+ MMO, IndexType);
+ }
+ return SDValue();
+ }
+
SDLoc DL(MGS);
SDValue Chain = MGS->getChain();
SDValue Scale = MGS->getScale();
@@ -26277,6 +26307,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performMSTORECombine(N, DCI, DAG, Subtarget);
case ISD::MGATHER:
case ISD::MSCATTER:
+ case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
return performMaskedGatherScatterCombine(N, DCI, DAG);
case ISD::FP_EXTEND:
return performFPExtendCombine(N, DAG, DCI, Subtarget);
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index dd0b9639a8fc2f..42fff1ec7c532f 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -267,5 +267,78 @@ define void @histogram_i16_8_lane(ptr %base, <vscale x 8 x i32> %indices, i16 %i
ret void
}
+define void @histogram_i32_extend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_extend:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, #1 // =0x1
+; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: ret
+ %extended = zext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+ %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+ ret void
+}
+define void @histogram_i32_8_lane_extend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_8_lane_extend:
+; CHECK: // %bb.0:
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: mov z4.s, w1
+; CHECK-NEXT: ptrue p2.s
+; CHECK-NEXT: histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT: ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT: st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
+; CHECK-NEXT: mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT: st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT: ret
+ %extended = zext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
+ %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
+ call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
+ ret void
+}
+define void @histogram_i32_sextend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0{
+; CHECK-LABEL: histogram_i32_sextend:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, #1 // =0x1
+; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: ret
+ %extended = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+ %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+ ret void
+}
+define void @histogram_i32_8_lane_sextend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_8_lane_sextend:
+; CHECK: // %bb.0:
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: mov z4.s, w1
+; CHECK-NEXT: ptrue p2.s
+; CHECK-NEXT: histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT: ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT: st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
+; CHECK-NEXT: mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT: st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT: ret
+ %extended = sext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
+ %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
+ call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
+ ret void
+}
+
attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }
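
Both the zext and sext tests above take the new path because
ISD::isExtOpcode matches all three integer extend opcodes. For
reference, a paraphrase of its definition (from
llvm/include/llvm/CodeGen/ISDOpcodes.h):

    // True for any of the three integer-extend node kinds.
    inline bool isExtOpcode(unsigned Opcode) {
      return Opcode == ISD::ANY_EXTEND || Opcode == ISD::ZERO_EXTEND ||
             Opcode == ISD::SIGN_EXTEND;
    }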
>From 517c129b6fab10551a6afd3d4627abffc9087bc2 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Mon, 7 Oct 2024 09:26:42 +0000
Subject: [PATCH 2/5] Ensure conformity to code formatting rules and improve
code quality
---
.../Target/AArch64/AArch64ISelLowering.cpp | 43 ++++++++-----------
llvm/test/CodeGen/AArch64/sve2-histcnt.ll | 9 ++--
2 files changed, 22 insertions(+), 30 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 545d5b59c64562..2cf59cb7e1dfe0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1114,7 +1114,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT,
ISD::VECREDUCE_ADD, ISD::STEP_VECTOR});
- setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
+ setTargetDAGCombine(
+ {ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
setTargetDAGCombine(ISD::FP_EXTEND);
@@ -24079,42 +24080,32 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
static SDValue performMaskedGatherScatterCombine(
SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
- MaskedHistogramSDNode *HG;
- MaskedGatherScatterSDNode *MGS;
- if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
- HG = cast<MaskedHistogramSDNode>(N);
- } else {
- MGS = cast<MaskedGatherScatterSDNode>(N);
- }
- assert((HG || MGS) &&
- "Can only combine gather load, scatter store or histogram nodes");
-
if (!DCI.isBeforeLegalize())
return SDValue();
if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
- SDLoc DL(HG);
+ MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(N);
+ assert(HG &&
+ "Can only combine gather load, scatter store or histogram nodes");
+
SDValue Index = HG->getIndex();
if (ISD::isExtOpcode(Index->getOpcode())) {
- SDValue Chain = HG->getChain();
- SDValue Inc = HG->getInc();
- SDValue Mask = HG->getMask();
- SDValue BasePtr = HG->getBasePtr();
- SDValue Scale = HG->getScale();
- SDValue IntID = HG->getIntID();
- EVT MemVT = HG->getMemoryVT();
- MachineMemOperand *MMO = HG->getMemOperand();
- ISD::MemIndexType IndexType = HG->getIndexType();
+ SDLoc DL(HG);
SDValue ExtOp = Index.getOperand(0);
- auto SrcType = ExtOp.getValueType();
- auto TruncatedIndex = DAG.getAnyExtOrTrunc(Index, DL, SrcType);
- SDValue Ops[] = {Chain, Inc, Mask, BasePtr, TruncatedIndex, Scale, IntID};
- return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
- MMO, IndexType);
+ SDValue Ops[] = {HG->getChain(), HG->getInc(), HG->getMask(),
+ HG->getBasePtr(), ExtOp, HG->getScale(),
+ HG->getIntID()};
+ return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other),
+ HG->getMemoryVT(), DL, Ops,
+ HG->getMemOperand(), HG->getIndexType());
}
return SDValue();
}
+ MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
+ assert(MGS &&
+ "Can only combine gather load, scatter store or histogram nodes");
+
SDLoc DL(MGS);
SDValue Chain = MGS->getChain();
SDValue Scale = MGS->getScale();
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index 42fff1ec7c532f..7bac4e6d306c4c 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -267,8 +267,8 @@ define void @histogram_i16_8_lane(ptr %base, <vscale x 8 x i32> %indices, i16 %i
ret void
}
-define void @histogram_i32_extend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
-; CHECK-LABEL: histogram_i32_extend:
+define void @histogram_i32_zextend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_zextend:
; CHECK: // %bb.0:
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
; CHECK-NEXT: mov z3.s, #1 // =0x1
@@ -282,8 +282,9 @@ define void @histogram_i32_extend(ptr %base, <vscale x 4 x i32> %indices, <vscal
call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
ret void
}
-define void @histogram_i32_8_lane_extend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
-; CHECK-LABEL: histogram_i32_8_lane_extend:
+
+define void @histogram_i32_8_lane_zextend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_8_lane_zextend:
; CHECK: // %bb.0:
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: mov z4.s, w1
>From 9af9286560174be63f86343947f28229dee3c443 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Mon, 7 Oct 2024 16:07:18 +0000
Subject: [PATCH 3/5] Code quality improvements
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 9 +++------
1 file changed, 3 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2cf59cb7e1dfe0..dfc89787e8e477 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24085,11 +24085,11 @@ static SDValue performMaskedGatherScatterCombine(
if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(N);
- assert(HG &&
- "Can only combine gather load, scatter store or histogram nodes");
SDValue Index = HG->getIndex();
- if (ISD::isExtOpcode(Index->getOpcode())) {
+ if (!ISD::isExtOpcode(Index->getOpcode())) {
+ return SDValue();
+ } else {
SDLoc DL(HG);
SDValue ExtOp = Index.getOperand(0);
SDValue Ops[] = {HG->getChain(), HG->getInc(), HG->getMask(),
@@ -24099,12 +24099,9 @@ static SDValue performMaskedGatherScatterCombine(
HG->getMemoryVT(), DL, Ops,
HG->getMemOperand(), HG->getIndexType());
}
- return SDValue();
}
MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
- assert(MGS &&
- "Can only combine gather load, scatter store or histogram nodes");
SDLoc DL(MGS);
SDValue Chain = MGS->getChain();
>From a28165f83eac427e672692ef6c08a90ad46393d7 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Tue, 8 Oct 2024 09:22:30 +0000
Subject: [PATCH 4/5] Small change to conform to LLVM coding standards
---
.../lib/Target/AArch64/AArch64ISelLowering.cpp | 18 +++++++++---------
1 file changed, 9 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index dfc89787e8e477..94389843de8961 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24089,16 +24089,16 @@ static SDValue performMaskedGatherScatterCombine(
SDValue Index = HG->getIndex();
if (!ISD::isExtOpcode(Index->getOpcode())) {
return SDValue();
- } else {
- SDLoc DL(HG);
- SDValue ExtOp = Index.getOperand(0);
- SDValue Ops[] = {HG->getChain(), HG->getInc(), HG->getMask(),
- HG->getBasePtr(), ExtOp, HG->getScale(),
- HG->getIntID()};
- return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other),
- HG->getMemoryVT(), DL, Ops,
- HG->getMemOperand(), HG->getIndexType());
}
+ SDLoc DL(HG);
+ SDValue ExtOp = Index.getOperand(0);
+ SDValue Ops[] = {HG->getChain(), HG->getInc(), HG->getMask(),
+ HG->getBasePtr(), ExtOp, HG->getScale(),
+ HG->getIntID()};
+ return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other),
+ HG->getMemoryVT(), DL, Ops,
+ HG->getMemOperand(), HG->getIndexType());
+
}
MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
>From 7c61336bbcb18f0c2bbda6979df13f3a81ef270d Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Tue, 8 Oct 2024 10:27:02 +0000
Subject: [PATCH 5/5] Small change, as previous changes left the code
 non-conformant with formatting standards.
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 11 +++++------
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 94389843de8961..55c4a7cde2aca6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24093,12 +24093,11 @@ static SDValue performMaskedGatherScatterCombine(
SDLoc DL(HG);
SDValue ExtOp = Index.getOperand(0);
SDValue Ops[] = {HG->getChain(), HG->getInc(), HG->getMask(),
- HG->getBasePtr(), ExtOp, HG->getScale(),
- HG->getIntID()};
- return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other),
- HG->getMemoryVT(), DL, Ops,
- HG->getMemOperand(), HG->getIndexType());
-
+ HG->getBasePtr(), ExtOp, HG->getScale(),
+ HG->getIntID()};
+ return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), HG->getMemoryVT(),
+ DL, Ops, HG->getMemOperand(),
+ HG->getIndexType());
}
MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);