[llvm] 1693d8e - [AArch64][SelectionDAG] Vector splitting and promotion for histogram intrinsic (#103037)

via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 30 00:54:16 PDT 2024


Author: Max Beck-Jones
Date: 2024-08-30T08:54:12+01:00
New Revision: 1693d8eb9aa94b0e8e2395234e6c63b57a2017b7

URL: https://github.com/llvm/llvm-project/commit/1693d8eb9aa94b0e8e2395234e6c63b57a2017b7
DIFF: https://github.com/llvm/llvm-project/commit/1693d8eb9aa94b0e8e2395234e6c63b57a2017b7.diff

LOG: [AArch64][SelectionDAG] Vector splitting and promotion for histogram intrinsic (#103037)

Adds support for wider-than-legal vector types for the histogram
intrinsic (llvm.experimental.vector.histogram.add) by splitting the
index and mask vectors into low and high halves and emitting two
chained histogram nodes. Also adds integer promotion for the Inc operand.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/sve2-histcnt.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 74e3a898569bea..f5fbc01cd95e96 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1241,6 +1241,11 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
     Action = TLI.getOperationAction(Node->getOpcode(),
                                     Node->getOperand(0).getValueType());
     break;
+  case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
+    Action = TLI.getOperationAction(
+        Node->getOpcode(),
+        cast<MaskedHistogramSDNode>(Node)->getIndex().getValueType());
+    break;
   default:
     if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
       Action = TLI.getCustomOperationAction(*Node);

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 05971152d535c6..a1cb74f43e6050 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2043,6 +2043,9 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::EXPERIMENTAL_VP_SPLICE:
     Res = PromoteIntOp_VP_SPLICE(N, OpNo);
     break;
+  case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
+    Res = PromoteIntOp_VECTOR_HISTOGRAM(N, OpNo);
+    break;
   }
 
   // If the result is null, the sub-method took care of registering results etc.
@@ -2755,6 +2758,14 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo) {
   return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N,
+                                                        unsigned OpNo) {
+  assert(OpNo == 1 && "Unexpected operand for promotion");
+  SmallVector<SDValue, 7> NewOps(N->ops());
+  NewOps[1] = GetPromotedInteger(N->getOperand(1));
+  return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
+}
+
 //===----------------------------------------------------------------------===//
 //  Integer Result Expansion
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 4577346a02d605..f15e3d7fd1eb63 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -425,6 +425,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntOp_PATCHPOINT(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_VP_STRIDED(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
+  SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
 
   void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
   void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
@@ -982,6 +983,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue SplitVecOp_CMP(SDNode *N);
   SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
   SDValue SplitVecOp_VP_CttzElements(SDNode *N);
+  SDValue SplitVecOp_VECTOR_HISTOGRAM(SDNode *N);
 
   //===--------------------------------------------------------------------===//
   // Vector Widening Support: LegalizeVectorTypes.cpp

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 4c6da7c5df6b40..ae994d492d57e6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3355,6 +3355,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
     Res = SplitVecOp_VP_CttzElements(N);
     break;
+  case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
+    Res = SplitVecOp_VECTOR_HISTOGRAM(N);
+    break;
   }
 
   // If the result is null, the sub-method took care of registering results etc.
@@ -4374,6 +4377,28 @@ SDValue DAGTypeLegalizer::SplitVecOp_VP_CttzElements(SDNode *N) {
                        DAG.getNode(ISD::ADD, DL, ResVT, VLo, ResHi));
 }
 
+SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
+  MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(N);
+  SDLoc DL(HG);
+  SDValue Inc = HG->getInc();
+  SDValue Ptr = HG->getBasePtr();
+  SDValue Scale = HG->getScale();
+  SDValue IntID = HG->getIntID();
+  EVT MemVT = HG->getMemoryVT();
+  MachineMemOperand *MMO = HG->getMemOperand();
+  ISD::MemIndexType IndexType = HG->getIndexType();
+
+  SDValue IndexLo, IndexHi, MaskLo, MaskHi;
+  std::tie(IndexLo, IndexHi) = DAG.SplitVector(HG->getIndex(), DL);
+  std::tie(MaskLo, MaskHi) = DAG.SplitVector(HG->getMask(), DL);
+  SDValue OpsLo[] = {HG->getChain(), Inc, MaskLo, Ptr, IndexLo, Scale, IntID};
+  SDValue Lo = DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL,
+                                      OpsLo, MMO, IndexType);
+  SDValue OpsHi[] = {Lo, Inc, MaskHi, Ptr, IndexHi, Scale, IntID};
+  return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, OpsHi,
+                                MMO, IndexType);
+}
+
 //===----------------------------------------------------------------------===//
 //  Result Vector Widening
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a49dfdb28a41eb..3296f63a9b8876 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1790,10 +1790,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
     // Histcnt is SVE2 only
     if (Subtarget->hasSVE2()) {
-      setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
+      setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv4i32,
+                         Custom);
+      setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
                          Custom);
-      setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i8, Custom);
-      setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i16, Custom);
     }
   }
 
@@ -28550,11 +28550,10 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
   assert(CID->getZExtValue() == Intrinsic::experimental_vector_histogram_add &&
          "Unexpected histogram update operation");
 
-  EVT IncVT = Inc.getValueType();
   EVT IndexVT = Index.getValueType();
   LLVMContext &Ctx = *DAG.getContext();
   ElementCount EC = IndexVT.getVectorElementCount();
-  EVT MemVT = EVT::getVectorVT(Ctx, IncVT, EC);
+  EVT MemVT = EVT::getVectorVT(Ctx, HG->getMemoryVT(), EC);
   EVT IncExtVT =
       EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
   EVT IncSplatVT = EVT::getVectorVT(Ctx, IncExtVT, EC);

diff  --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index 2874e47511e12f..dd0b9639a8fc2f 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -132,8 +132,10 @@ define void @histogram_i16_literal_1(ptr %base, <vscale x 4 x i32> %indices, <vs
 ; CHECK-LABEL: histogram_i16_literal_1:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
 ; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
-; CHECK-NEXT:    add z1.s, z2.s, z1.s
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
 ; CHECK-NEXT:    st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
 ; CHECK-NEXT:    ret
   %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
@@ -145,8 +147,10 @@ define void @histogram_i16_literal_2(ptr %base, <vscale x 4 x i32> %indices, <vs
 ; CHECK-LABEL: histogram_i16_literal_2:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, #2 // =0x2
 ; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
-; CHECK-NEXT:    adr z1.s, [z2.s, z1.s, lsl #1]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
 ; CHECK-NEXT:    st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
 ; CHECK-NEXT:    ret
   %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
@@ -169,4 +173,99 @@ define void @histogram_i16_literal_3(ptr %base, <vscale x 4 x i32> %indices, <vs
   ret void
 }
 
+define void @histogram_i64_4_lane(<vscale x 4 x ptr> %buckets, i64 %inc, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i64_4_lane:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mov z4.d, x0
+; CHECK-NEXT:    ptrue p2.d
+; CHECK-NEXT:    histcnt z2.d, p1/z, z0.d, z0.d
+; CHECK-NEXT:    ld1d { z3.d }, p1/z, [z0.d]
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mad z2.d, p2/m, z4.d, z3.d
+; CHECK-NEXT:    st1d { z2.d }, p1, [z0.d]
+; CHECK-NEXT:    histcnt z0.d, p0/z, z1.d, z1.d
+; CHECK-NEXT:    ld1d { z2.d }, p0/z, [z1.d]
+; CHECK-NEXT:    mad z0.d, p2/m, z4.d, z2.d
+; CHECK-NEXT:    st1d { z0.d }, p0, [z1.d]
+; CHECK-NEXT:    ret
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i64(<vscale x 4 x ptr> %buckets, i64 %inc, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i64_8_lane(<vscale x 8 x ptr> %buckets, i64 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i64_8_lane:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p2.h, p0.b
+; CHECK-NEXT:    mov z6.d, x0
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    punpklo p3.h, p2.b
+; CHECK-NEXT:    punpkhi p2.h, p2.b
+; CHECK-NEXT:    histcnt z4.d, p3/z, z0.d, z0.d
+; CHECK-NEXT:    ld1d { z5.d }, p3/z, [z0.d]
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mad z4.d, p1/m, z6.d, z5.d
+; CHECK-NEXT:    st1d { z4.d }, p3, [z0.d]
+; CHECK-NEXT:    histcnt z0.d, p2/z, z1.d, z1.d
+; CHECK-NEXT:    ld1d { z4.d }, p2/z, [z1.d]
+; CHECK-NEXT:    mad z0.d, p1/m, z6.d, z4.d
+; CHECK-NEXT:    st1d { z0.d }, p2, [z1.d]
+; CHECK-NEXT:    punpklo p2.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    histcnt z0.d, p2/z, z2.d, z2.d
+; CHECK-NEXT:    ld1d { z1.d }, p2/z, [z2.d]
+; CHECK-NEXT:    mad z0.d, p1/m, z6.d, z1.d
+; CHECK-NEXT:    st1d { z0.d }, p2, [z2.d]
+; CHECK-NEXT:    histcnt z0.d, p0/z, z3.d, z3.d
+; CHECK-NEXT:    ld1d { z1.d }, p0/z, [z3.d]
+; CHECK-NEXT:    mad z0.d, p1/m, z6.d, z1.d
+; CHECK-NEXT:    st1d { z0.d }, p0, [z3.d]
+; CHECK-NEXT:    ret
+  call void @llvm.experimental.vector.histogram.add.nxv8p0.i64(<vscale x 8 x ptr> %buckets, i64 %inc, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i32_8_lane(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_8_lane:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mov z4.s, w1
+; CHECK-NEXT:    ptrue p2.s
+; CHECK-NEXT:    histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT:    st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i32, ptr %base, <vscale x 8 x i32> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i16_8_lane(ptr %base, <vscale x 8 x i32> %indices, i16 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_8_lane:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mov z4.s, w1
+; CHECK-NEXT:    ptrue p2.s
+; CHECK-NEXT:    histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT:    ld1h { z3.s }, p1/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT:    st1h { z2.s }, p1, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x0, z1.s, sxtw #1]
+; CHECK-NEXT:    mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT:    st1h { z0.s }, p0, [x0, z1.s, sxtw #1]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i16, ptr %base, <vscale x 8 x i32> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv8p0.i16(<vscale x 8 x ptr> %buckets, i16 %inc, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+
 attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }


        


More information about the llvm-commits mailing list