[llvm] [AArch64] Improve index selection for histograms (PR #111150)

James Chesterman via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 16 03:14:56 PDT 2024


https://github.com/JamesChesterman updated https://github.com/llvm/llvm-project/pull/111150

>From 5db9c2e9165f3fd948d20f8b4a820e0fd93b6eca Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 4 Oct 2024 12:38:25 +0000
Subject: [PATCH 1/6] [AArch64] Improve index selection for histograms

Search for an extend of the index used in a histogram operation,
then truncate that index back to the extend's source type. This
avoids the need to split the instruction in two.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 37 +++++++++-
 llvm/test/CodeGen/AArch64/sve2-histcnt.ll     | 73 +++++++++++++++++++
 2 files changed, 107 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 48e1b96d841efb..545d5b59c64562 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1114,7 +1114,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                        ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT,
                        ISD::VECREDUCE_ADD, ISD::STEP_VECTOR});
 
-  setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER});
+  setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
 
   setTargetDAGCombine(ISD::FP_EXTEND);
 
@@ -24079,12 +24079,42 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
 
 static SDValue performMaskedGatherScatterCombine(
     SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
-  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
-  assert(MGS && "Can only combine gather load or scatter store nodes");
+  MaskedHistogramSDNode *HG;
+  MaskedGatherScatterSDNode *MGS;
+  if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
+    HG = cast<MaskedHistogramSDNode>(N);
+  } else {
+    MGS = cast<MaskedGatherScatterSDNode>(N);
+  }
+  assert((HG || MGS) &&
+         "Can only combine gather load, scatter store or histogram nodes");
 
   if (!DCI.isBeforeLegalize())
     return SDValue();
 
+  if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
+    SDLoc DL(HG);
+    SDValue Index = HG->getIndex();
+    if (ISD::isExtOpcode(Index->getOpcode())) {
+      SDValue Chain = HG->getChain();
+      SDValue Inc = HG->getInc();
+      SDValue Mask = HG->getMask();
+      SDValue BasePtr = HG->getBasePtr();
+      SDValue Scale = HG->getScale();
+      SDValue IntID = HG->getIntID();
+      EVT MemVT = HG->getMemoryVT();
+      MachineMemOperand *MMO = HG->getMemOperand();
+      ISD::MemIndexType IndexType = HG->getIndexType();
+      SDValue ExtOp = Index.getOperand(0);
+      auto SrcType = ExtOp.getValueType();
+      auto TruncatedIndex = DAG.getAnyExtOrTrunc(Index, DL, SrcType);
+      SDValue Ops[] = {Chain, Inc, Mask, BasePtr, TruncatedIndex, Scale, IntID};
+      return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
+                                    MMO, IndexType);
+    }
+    return SDValue();
+  }
+
   SDLoc DL(MGS);
   SDValue Chain = MGS->getChain();
   SDValue Scale = MGS->getScale();
@@ -26277,6 +26307,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performMSTORECombine(N, DCI, DAG, Subtarget);
   case ISD::MGATHER:
   case ISD::MSCATTER:
+  case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return performMaskedGatherScatterCombine(N, DCI, DAG);
   case ISD::FP_EXTEND:
     return performFPExtendCombine(N, DAG, DCI, Subtarget);
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index dd0b9639a8fc2f..42fff1ec7c532f 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -267,5 +267,78 @@ define void @histogram_i16_8_lane(ptr %base, <vscale x 8 x i32> %indices, i16 %i
   ret void
 }
 
+define void @histogram_i32_extend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_extend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+define void @histogram_i32_8_lane_extend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_8_lane_extend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mov z4.s, w1
+; CHECK-NEXT:    ptrue p2.s
+; CHECK-NEXT:    histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT:    st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
+  ret void
+}
+define void @histogram_i32_sextend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0{
+; CHECK-LABEL: histogram_i32_sextend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+define void @histogram_i32_8_lane_sextend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_8_lane_sextend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mov z4.s, w1
+; CHECK-NEXT:    ptrue p2.s
+; CHECK-NEXT:    histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT:    st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = sext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
+  ret void
+}
+
 
 attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }

>From 517c129b6fab10551a6afd3d4627abffc9087bc2 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Mon, 7 Oct 2024 09:26:42 +0000
Subject: [PATCH 2/6] Ensure conformity to code formatting rules and improve
 code quality

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 43 ++++++++-----------
 llvm/test/CodeGen/AArch64/sve2-histcnt.ll     |  9 ++--
 2 files changed, 22 insertions(+), 30 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 545d5b59c64562..2cf59cb7e1dfe0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1114,7 +1114,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                        ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT,
                        ISD::VECREDUCE_ADD, ISD::STEP_VECTOR});
 
-  setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
+  setTargetDAGCombine(
+      {ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
 
   setTargetDAGCombine(ISD::FP_EXTEND);
 
@@ -24079,42 +24080,32 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
 
 static SDValue performMaskedGatherScatterCombine(
     SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
-  MaskedHistogramSDNode *HG;
-  MaskedGatherScatterSDNode *MGS;
-  if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
-    HG = cast<MaskedHistogramSDNode>(N);
-  } else {
-    MGS = cast<MaskedGatherScatterSDNode>(N);
-  }
-  assert((HG || MGS) &&
-         "Can only combine gather load, scatter store or histogram nodes");
-
   if (!DCI.isBeforeLegalize())
     return SDValue();
 
   if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
-    SDLoc DL(HG);
+    MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(N);
+    assert(HG &&
+           "Can only combine gather load, scatter store or histogram nodes");
+
     SDValue Index = HG->getIndex();
     if (ISD::isExtOpcode(Index->getOpcode())) {
-      SDValue Chain = HG->getChain();
-      SDValue Inc = HG->getInc();
-      SDValue Mask = HG->getMask();
-      SDValue BasePtr = HG->getBasePtr();
-      SDValue Scale = HG->getScale();
-      SDValue IntID = HG->getIntID();
-      EVT MemVT = HG->getMemoryVT();
-      MachineMemOperand *MMO = HG->getMemOperand();
-      ISD::MemIndexType IndexType = HG->getIndexType();
+      SDLoc DL(HG);
       SDValue ExtOp = Index.getOperand(0);
-      auto SrcType = ExtOp.getValueType();
-      auto TruncatedIndex = DAG.getAnyExtOrTrunc(Index, DL, SrcType);
-      SDValue Ops[] = {Chain, Inc, Mask, BasePtr, TruncatedIndex, Scale, IntID};
-      return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
-                                    MMO, IndexType);
+      SDValue Ops[] = {HG->getChain(),   HG->getInc(), HG->getMask(),
+                       HG->getBasePtr(), ExtOp,        HG->getScale(),
+                       HG->getIntID()};
+      return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other),
+                                    HG->getMemoryVT(), DL, Ops,
+                                    HG->getMemOperand(), HG->getIndexType());
     }
     return SDValue();
   }
 
+  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
+  assert(MGS &&
+         "Can only combine gather load, scatter store or histogram nodes");
+
   SDLoc DL(MGS);
   SDValue Chain = MGS->getChain();
   SDValue Scale = MGS->getScale();
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index 42fff1ec7c532f..7bac4e6d306c4c 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -267,8 +267,8 @@ define void @histogram_i16_8_lane(ptr %base, <vscale x 8 x i32> %indices, i16 %i
   ret void
 }
 
-define void @histogram_i32_extend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
-; CHECK-LABEL: histogram_i32_extend:
+define void @histogram_i32_zextend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_zextend:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
 ; CHECK-NEXT:    mov z3.s, #1 // =0x1
@@ -282,8 +282,9 @@ define void @histogram_i32_extend(ptr %base, <vscale x 4 x i32> %indices, <vscal
   call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
   ret void
 }
-define void @histogram_i32_8_lane_extend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
-; CHECK-LABEL: histogram_i32_8_lane_extend:
+
+define void @histogram_i32_8_lane_zextend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_8_lane_zextend:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    punpklo p1.h, p0.b
 ; CHECK-NEXT:    mov z4.s, w1

>From 9af9286560174be63f86343947f28229dee3c443 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Mon, 7 Oct 2024 16:07:18 +0000
Subject: [PATCH 3/6] Code quality improvements

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2cf59cb7e1dfe0..dfc89787e8e477 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24085,11 +24085,11 @@ static SDValue performMaskedGatherScatterCombine(
 
   if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
     MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(N);
-    assert(HG &&
-           "Can only combine gather load, scatter store or histogram nodes");
 
     SDValue Index = HG->getIndex();
-    if (ISD::isExtOpcode(Index->getOpcode())) {
+    if (!ISD::isExtOpcode(Index->getOpcode())) {
+      return SDValue();
+    } else {
       SDLoc DL(HG);
       SDValue ExtOp = Index.getOperand(0);
       SDValue Ops[] = {HG->getChain(),   HG->getInc(), HG->getMask(),
@@ -24099,12 +24099,9 @@ static SDValue performMaskedGatherScatterCombine(
                                     HG->getMemoryVT(), DL, Ops,
                                     HG->getMemOperand(), HG->getIndexType());
     }
-    return SDValue();
   }
 
   MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
-  assert(MGS &&
-         "Can only combine gather load, scatter store or histogram nodes");
 
   SDLoc DL(MGS);
   SDValue Chain = MGS->getChain();

>From a28165f83eac427e672692ef6c08a90ad46393d7 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Tue, 8 Oct 2024 09:22:30 +0000
Subject: [PATCH 4/6] Small change to conform to LLVM coding standards

---
 .../lib/Target/AArch64/AArch64ISelLowering.cpp | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index dfc89787e8e477..94389843de8961 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24089,16 +24089,16 @@ static SDValue performMaskedGatherScatterCombine(
     SDValue Index = HG->getIndex();
     if (!ISD::isExtOpcode(Index->getOpcode())) {
       return SDValue();
-    } else {
-      SDLoc DL(HG);
-      SDValue ExtOp = Index.getOperand(0);
-      SDValue Ops[] = {HG->getChain(),   HG->getInc(), HG->getMask(),
-                       HG->getBasePtr(), ExtOp,        HG->getScale(),
-                       HG->getIntID()};
-      return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other),
-                                    HG->getMemoryVT(), DL, Ops,
-                                    HG->getMemOperand(), HG->getIndexType());
     }
+    SDLoc DL(HG);
+    SDValue ExtOp = Index.getOperand(0);
+    SDValue Ops[] = {HG->getChain(),   HG->getInc(), HG->getMask(),
+                      HG->getBasePtr(), ExtOp,        HG->getScale(),
+                      HG->getIntID()};
+    return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other),
+                                  HG->getMemoryVT(), DL, Ops,
+                                  HG->getMemOperand(), HG->getIndexType());
+    
   }
 
   MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);

>From 7c61336bbcb18f0c2bbda6979df13f3a81ef270d Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Tue, 8 Oct 2024 10:27:02 +0000
Subject: [PATCH 5/6] Small change as previous changes made code not conform to
 formatting standards.

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 94389843de8961..55c4a7cde2aca6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24093,12 +24093,11 @@ static SDValue performMaskedGatherScatterCombine(
     SDLoc DL(HG);
     SDValue ExtOp = Index.getOperand(0);
     SDValue Ops[] = {HG->getChain(),   HG->getInc(), HG->getMask(),
-                      HG->getBasePtr(), ExtOp,        HG->getScale(),
-                      HG->getIntID()};
-    return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other),
-                                  HG->getMemoryVT(), DL, Ops,
-                                  HG->getMemOperand(), HG->getIndexType());
-    
+                     HG->getBasePtr(), ExtOp,        HG->getScale(),
+                     HG->getIntID()};
+    return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), HG->getMemoryVT(),
+                                  DL, Ops, HG->getMemOperand(),
+                                  HG->getIndexType());
   }
 
   MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);

>From 9430fcdfc86e72f3a353ee7c6027db74e7f45326 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Wed, 16 Oct 2024 10:13:12 +0000
Subject: [PATCH 6/6] Major change to the patch. Shares more code with Masked
 Gather and Scatter. Now also removes the instruction when its mask is zero.

---
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h |  10 +-
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  33 +++
 .../Target/AArch64/AArch64ISelLowering.cpp    |  39 ++--
 llvm/test/CodeGen/AArch64/sve2-histcnt.ll     | 210 +++++++++++++++---
 4 files changed, 234 insertions(+), 58 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 6067b3b29ea181..e23c8c0f175d7e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -2935,8 +2935,8 @@ class MaskedGatherScatterSDNode : public MemSDNode {
   const SDValue &getScale()   const { return getOperand(5); }
 
   static bool classof(const SDNode *N) {
-    return N->getOpcode() == ISD::MGATHER ||
-           N->getOpcode() == ISD::MSCATTER;
+    return N->getOpcode() == ISD::MGATHER || N->getOpcode() == ISD::MSCATTER ||
+           N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM;
   }
 };
 
@@ -2991,15 +2991,15 @@ class MaskedScatterSDNode : public MaskedGatherScatterSDNode {
   }
 };
 
-class MaskedHistogramSDNode : public MemSDNode {
+class MaskedHistogramSDNode : public MaskedGatherScatterSDNode {
 public:
   friend class SelectionDAG;
 
   MaskedHistogramSDNode(unsigned Order, const DebugLoc &DL, SDVTList VTs,
                         EVT MemVT, MachineMemOperand *MMO,
                         ISD::MemIndexType IndexType)
-      : MemSDNode(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, Order, DL, VTs, MemVT,
-                  MMO) {
+      : MaskedGatherScatterSDNode(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, Order, DL,
+                                  VTs, MemVT, MMO, IndexType) {
     LSBaseSDNodeBits.AddressingMode = IndexType;
   }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index c6f6fc25080541..e8551a891d626d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -549,6 +549,7 @@ namespace {
     SDValue visitMSTORE(SDNode *N);
     SDValue visitMGATHER(SDNode *N);
     SDValue visitMSCATTER(SDNode *N);
+    SDValue visitMHISTOGRAM(SDNode *N);
     SDValue visitVPGATHER(SDNode *N);
     SDValue visitVPSCATTER(SDNode *N);
     SDValue visitVP_STRIDED_LOAD(SDNode *N);
@@ -1972,6 +1973,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::MLOAD:              return visitMLOAD(N);
   case ISD::MSCATTER:           return visitMSCATTER(N);
   case ISD::MSTORE:             return visitMSTORE(N);
+  case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
   case ISD::VECTOR_COMPRESS:    return visitVECTOR_COMPRESS(N);
   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
@@ -12353,6 +12355,37 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
+  MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(N);
+  SDValue Chain = HG->getChain();
+  SDValue Inc = HG->getInc();
+  SDValue Mask = HG->getMask();
+  SDValue BasePtr = HG->getBasePtr();
+  SDValue Index = HG->getIndex();
+  SDLoc DL(HG);
+
+  EVT MemVT = HG->getMemoryVT();
+  MachineMemOperand *MMO = HG->getMemOperand();
+  ISD::MemIndexType IndexType = HG->getIndexType();
+
+  if (ISD::isConstantSplatVectorAllZeros(Mask.getNode())) {
+    return Chain;
+  }
+  SDValue Ops[] = {Chain,          Inc,           Mask, BasePtr, Index,
+                   HG->getScale(), HG->getIntID()};
+  if (refineUniformBase(BasePtr, Index, HG->isIndexScaled(), DAG, DL)) {
+    return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
+                                  MMO, IndexType);
+  }
+  EVT DataVT = Index.getValueType();
+  DataVT.changeVectorElementType(Inc.getValueType());
+  if (refineIndexType(Index, IndexType, DataVT, DAG)) {
+    return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
+                                  MMO, IndexType);
+  }
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
   auto *SLD = cast<VPStridedLoadSDNode>(N);
   EVT EltVT = SLD->getValueType(0).getVectorElementType();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 55c4a7cde2aca6..3420d2ef07bcc6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24082,24 +24082,6 @@ static SDValue performMaskedGatherScatterCombine(
     SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
   if (!DCI.isBeforeLegalize())
     return SDValue();
-
-  if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
-    MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(N);
-
-    SDValue Index = HG->getIndex();
-    if (!ISD::isExtOpcode(Index->getOpcode())) {
-      return SDValue();
-    }
-    SDLoc DL(HG);
-    SDValue ExtOp = Index.getOperand(0);
-    SDValue Ops[] = {HG->getChain(),   HG->getInc(), HG->getMask(),
-                     HG->getBasePtr(), ExtOp,        HG->getScale(),
-                     HG->getIntID()};
-    return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), HG->getMemoryVT(),
-                                  DL, Ops, HG->getMemOperand(),
-                                  HG->getIndexType());
-  }
-
   MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
 
   SDLoc DL(MGS);
@@ -24110,8 +24092,9 @@ static SDValue performMaskedGatherScatterCombine(
   SDValue BasePtr = MGS->getBasePtr();
   ISD::MemIndexType IndexType = MGS->getIndexType();
 
-  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, DAG))
+  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, DAG)) {
     return SDValue();
+  }
 
   // Here we catch such cases early and change MGATHER's IndexType to allow
   // the use of an Index that's more legalisation friendly.
@@ -24122,12 +24105,18 @@ static SDValue performMaskedGatherScatterCombine(
         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
         Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
   }
-  auto *MSC = cast<MaskedScatterSDNode>(MGS);
-  SDValue Data = MSC->getValue();
-  SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
-  return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL,
-                              Ops, MSC->getMemOperand(), IndexType,
-                              MSC->isTruncatingStore());
+  if (auto *MSC = dyn_cast<MaskedScatterSDNode>(MGS)) {
+    SDValue Data = MSC->getValue();
+    SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
+    return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
+                                DL, Ops, MSC->getMemOperand(), IndexType,
+                                MSC->isTruncatingStore());
+  }
+  auto *HG = cast<MaskedHistogramSDNode>(MGS);
+  SDValue Ops[] = {Chain, HG->getInc(), Mask,          BasePtr,
+                   Index, Scale,        HG->getIntID()};
+  return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), HG->getMemoryVT(),
+                                DL, Ops, HG->getMemOperand(), IndexType);
 }
 
 /// Target-specific DAG combine function for NEON load/store intrinsics
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index 7bac4e6d306c4c..06cd65620d1c9e 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -267,8 +267,56 @@ define void @histogram_i16_8_lane(ptr %base, <vscale x 8 x i32> %indices, i16 %i
   ret void
 }
 
-define void @histogram_i32_zextend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
-; CHECK-LABEL: histogram_i32_zextend:
+define void @histogram_i8_zext(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i8 %inc) #0{
+; CHECK-LABEL: histogram_i8_zext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, w1
+; CHECK-NEXT:    ld1b { z2.s }, p0/z, [x0, z0.s, uxtw]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1b { z1.s }, p0, [x0, z0.s, uxtw]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i8, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i8(<vscale x 4 x ptr> %buckets, i8 %inc, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i16_zext(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i16 %inc) #0{
+; CHECK-LABEL: histogram_i16_zext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, w1
+; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x0, z0.s, uxtw #1]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1h { z1.s }, p0, [x0, z0.s, uxtw #1]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i16, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 %inc, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i32_zext(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_zext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i32_sext(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_sext:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
 ; CHECK-NEXT:    mov z3.s, #1 // =0x1
@@ -277,53 +325,142 @@ define void @histogram_i32_zextend(ptr %base, <vscale x 4 x i32> %indices, <vsca
 ; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
 ; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
 ; CHECK-NEXT:    ret
-  %extended = zext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %extended = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
   %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
   call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
   ret void
 }
 
-define void @histogram_i32_8_lane_zextend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
-; CHECK-LABEL: histogram_i32_8_lane_zextend:
+define void @histogram_zext_from_i8_to_i64(ptr %base, <vscale x 4 x i8> %indices, <vscale x 4 x i1> %mask) #0{
+; CHECK-LABEL: histogram_zext_from_i8_to_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    punpklo p1.h, p0.b
-; CHECK-NEXT:    mov z4.s, w1
-; CHECK-NEXT:    ptrue p2.s
-; CHECK-NEXT:    histcnt z2.s, p1/z, z0.s, z0.s
-; CHECK-NEXT:    ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
-; CHECK-NEXT:    punpkhi p0.h, p0.b
-; CHECK-NEXT:    mad z2.s, p2/m, z4.s, z3.s
-; CHECK-NEXT:    st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
-; CHECK-NEXT:    histcnt z0.s, p0/z, z1.s, z1.s
-; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
-; CHECK-NEXT:    mad z0.s, p2/m, z4.s, z2.s
-; CHECK-NEXT:    st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    and z0.s, z0.s, #0xff
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, uxtw #2]
 ; CHECK-NEXT:    ret
-  %extended = zext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
-  %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
-  call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
+  %extended = zext <vscale x 4 x i8> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
   ret void
 }
-define void @histogram_i32_sextend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0{
-; CHECK-LABEL: histogram_i32_sextend:
+
+define void @histogram_zext_from_i16_to_i64(ptr %base, <vscale x 4 x i16> %indices, <vscale x 4 x i1> %mask) #0{
+; CHECK-LABEL: histogram_zext_from_i16_to_i64:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    and z0.s, z0.s, #0xffff
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ptrue p1.s
 ; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i16> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_sext_from_i16_to_i64(ptr %base, <vscale x 4 x i16> %indices, <vscale x 4 x i1> %mask) #0{
+; CHECK-LABEL: histogram_sext_from_i16_to_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.s
 ; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    sxth z0.s, p1/m, z0.s
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
 ; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
-; CHECK-NEXT:    ptrue p1.s
 ; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
 ; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
 ; CHECK-NEXT:    ret
-  %extended = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %extended = sext <vscale x 4 x i16> %indices to <vscale x 4 x i64>
   %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
   call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
   ret void
 }
-define void @histogram_i32_8_lane_sextend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
-; CHECK-LABEL: histogram_i32_8_lane_sextend:
+
+define void @histogram_zext_from_i8_to_i32(ptr %base, <vscale x 4 x i8> %indices, <vscale x 4 x i1> %mask) #0{
+; CHECK-LABEL: histogram_zext_from_i8_to_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    and z0.s, z0.s, #0xff
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i8> %indices to <vscale x 4 x i32>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i32> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_zext_from_i16_to_i32(ptr %base, <vscale x 4 x i16> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_zext_from_i16_to_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    and z0.s, z0.s, #0xffff
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i16> %indices to <vscale x 4 x i32>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i32> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_2_lane_zext(ptr %base, <vscale x 2 x i32> %indices, <vscale x 2 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_2_lane_zext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z1.d, z0.d
+; CHECK-NEXT:    mov z3.d, #1 // =0x1
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    ld1w { z2.d }, p0/z, [x0, z0.d, uxtw #2]
+; CHECK-NEXT:    and z1.d, z1.d, #0xffffffff
+; CHECK-NEXT:    histcnt z1.d, p0/z, z1.d, z1.d
+; CHECK-NEXT:    mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT:    st1w { z1.d }, p0, [x0, z0.d, uxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 2 x i32> %indices to <vscale x 2 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 2 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv2p0.i32(<vscale x 2 x ptr> %buckets, i32 1, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+define void @histogram_8_lane_zext(ptr %base, <vscale x 8 x i32> %indices, <vscale x 8 x i1> %mask) #0{
+; CHECK-LABEL: histogram_8_lane_zext:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    punpklo p1.h, p0.b
-; CHECK-NEXT:    mov z4.s, w1
+; CHECK-NEXT:    mov z4.s, #1 // =0x1
+; CHECK-NEXT:    ptrue p2.s
+; CHECK-NEXT:    histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z3.s }, p1/z, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT:    st1w { z2.s }, p1, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z1.s, uxtw #2]
+; CHECK-NEXT:    mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, z1.s, uxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 1, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+define void @histogram_8_lane_sext(ptr %base, <vscale x 8 x i32> %indices, <vscale x 8 x i1> %mask) #0{
+; CHECK-LABEL: histogram_8_lane_sext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mov z4.s, #1 // =0x1
 ; CHECK-NEXT:    ptrue p2.s
 ; CHECK-NEXT:    histcnt z2.s, p1/z, z0.s, z0.s
 ; CHECK-NEXT:    ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
@@ -337,9 +474,26 @@ define void @histogram_i32_8_lane_sextend(ptr %base, <vscale x 8 x i32> %indices
 ; CHECK-NEXT:    ret
   %extended = sext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
   %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
-  call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
+  call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 1, <vscale x 8 x i1> %mask)
   ret void
 }
 
+define void @histogram_zero_mask(<vscale x 2 x ptr> %buckets, i64 %inc, <vscale x 2 x i1> %mask) #0{
+; CHECK-LABEL: histogram_zero_mask:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  call void @llvm.experimental.vector.histogram.add.nxv2p0.i64(<vscale x 2 x ptr> %buckets, i64 %inc, <vscale x 2 x i1> zeroinitializer)
+  ret void
+}
+
+define void @histogram_sext_zero_mask(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0{
+; CHECK-LABEL: histogram_sext_zero_mask:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %extended = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> zeroinitializer)
+  ret void
+}
 
 attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }



More information about the llvm-commits mailing list