[llvm] [AArch64] Implement promotion type legalisation for histogram intrinsic (PR #101017)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 29 07:15:32 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Max Beck-Jones (DevM-uk)

<details>
<summary>Changes</summary>

Currently, the histogram intrinsic (llvm.experimental.vector.histogram.add) only allows i32 and i64 element types for the memory locations being updated, matching the restrictions of the histcnt instruction. This patch adds support for the legalisation of smaller types (i8 and i16) via promotion.

---
Full diff: https://github.com/llvm/llvm-project/pull/101017.diff


2 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+22-10) 
- (modified) llvm/test/CodeGen/AArch64/sve2-histcnt.ll (+119) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1e9da9b819bdd..153d5fe28be7b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1775,9 +1775,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
 
     // Histcnt is SVE2 only
-    if (Subtarget->hasSVE2())
+    if (Subtarget->hasSVE2()) {
       setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
                          Custom);
+      setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i8, Custom);
+      setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i16, Custom);
+    }
   }
 
 
@@ -28018,9 +28021,17 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
   EVT IndexVT = Index.getValueType();
   EVT MemVT = EVT::getVectorVT(*DAG.getContext(), IncVT,
                                IndexVT.getVectorElementCount());
+  EVT IncExtVT = IndexVT.getVectorElementCount().getKnownMinValue() == 4
+                     ? MVT::i32
+                     : MVT::i64;
+  EVT IncSplatVT = EVT::getVectorVT(*DAG.getContext(), IncExtVT,
+                                    IndexVT.getVectorElementCount());
+  bool ExtTrunc = IncSplatVT != MemVT;
+
   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
-  SDValue PassThru = DAG.getSplatVector(MemVT, DL, Zero);
-  SDValue IncSplat = DAG.getSplatVector(MemVT, DL, Inc);
+  SDValue PassThru = DAG.getSplatVector(IncSplatVT, DL, Zero);
+  SDValue IncSplat = DAG.getSplatVector(
+      IncSplatVT, DL, DAG.getAnyExtOrTrunc(Inc, DL, IncExtVT));
   SDValue Ops[] = {Chain, PassThru, Mask, Ptr, Index, Scale};
 
   MachineMemOperand *MMO = HG->getMemOperand();
@@ -28029,18 +28040,19 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
       MMO->getPointerInfo(), MachineMemOperand::MOLoad, MMO->getSize(),
       MMO->getAlign(), MMO->getAAInfo());
   ISD::MemIndexType IndexType = HG->getIndexType();
-  SDValue Gather =
-      DAG.getMaskedGather(DAG.getVTList(MemVT, MVT::Other), MemVT, DL, Ops,
-                          GMMO, IndexType, ISD::NON_EXTLOAD);
+  SDValue Gather = DAG.getMaskedGather(
+      DAG.getVTList(IncSplatVT, MVT::Other), MemVT, DL, Ops, GMMO, IndexType,
+      ExtTrunc ? ISD::EXTLOAD : ISD::NON_EXTLOAD);
 
   SDValue GChain = Gather.getValue(1);
 
   // Perform the histcnt, multiply by inc, add to bucket data.
-  SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncVT);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncExtVT);
   SDValue HistCnt =
       DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, ID, Mask, Index, Index);
-  SDValue Mul = DAG.getNode(ISD::MUL, DL, MemVT, HistCnt, IncSplat);
-  SDValue Add = DAG.getNode(ISD::ADD, DL, MemVT, Gather, Mul);
+  SDValue Mul = DAG.getNode(ISD::MUL, DL, IncSplatVT, HistCnt, IncSplat);
+  SDValue Add = DAG.getNode(ISD::ADD, DL, IncSplatVT, Gather, Mul);
 
   // Create an MMO for the scatter, without load|store flags.
   MachineMemOperand *SMMO = DAG.getMachineFunction().getMachineMemOperand(
@@ -28049,7 +28061,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
 
   SDValue ScatterOps[] = {GChain, Add, Mask, Ptr, Index, Scale};
   SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MemVT, DL,
-                                         ScatterOps, SMMO, IndexType, false);
+                                         ScatterOps, SMMO, IndexType, ExtTrunc);
   return Scatter;
 }
 
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index db164e288abde..2874e47511e12 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -50,4 +50,123 @@ define void @histogram_i32_literal_noscale(ptr %base, <vscale x 4 x i32> %indice
   ret void
 }
 
+define void @histogram_i32_promote(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i32 %inc) #0 {
+; CHECK-LABEL: histogram_i32_promote:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT:    mov z3.d, x1
+; CHECK-NEXT:    ld1w { z2.d }, p0/z, [x0, z0.d, lsl #2]
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT:    st1w { z1.d }, p0, [x0, z0.d, lsl #2]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i32, ptr %base, <vscale x 2 x i64> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv2p0.i32(<vscale x 2 x ptr> %buckets, i32 %inc, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i16(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i16 %inc) #0 {
+; CHECK-LABEL: histogram_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, w1
+; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 %inc, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i8(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i8 %inc) #0 {
+; CHECK-LABEL: histogram_i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, w1
+; CHECK-NEXT:    ld1b { z2.s }, p0/z, [x0, z0.s, sxtw]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1b { z1.s }, p0, [x0, z0.s, sxtw]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i8, ptr %base, <vscale x 4 x i32> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i8(<vscale x 4 x ptr> %buckets, i8 %inc, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i16_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i16 %inc) #0 {
+; CHECK-LABEL: histogram_i16_2_lane:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT:    mov z3.d, x1
+; CHECK-NEXT:    ld1h { z2.d }, p0/z, [x0, z0.d, lsl #1]
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT:    st1h { z1.d }, p0, [x0, z0.d, lsl #1]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i16, ptr %base, <vscale x 2 x i64> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv2p0.i16(<vscale x 2 x ptr> %buckets, i16 %inc, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i8_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i8 %inc) #0 {
+; CHECK-LABEL: histogram_i8_2_lane:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT:    mov z3.d, x1
+; CHECK-NEXT:    ld1b { z2.d }, p0/z, [x0, z0.d]
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT:    st1b { z1.d }, p0, [x0, z0.d]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i8, ptr %base, <vscale x 2 x i64> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv2p0.i8(<vscale x 2 x ptr> %buckets, i8 %inc, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i16_literal_1(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    add z1.s, z2.s, z1.s
+; CHECK-NEXT:    st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i16_literal_2(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    adr z1.s, [z2.s, z1.s, lsl #1]
+; CHECK-NEXT:    st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 2, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i16_literal_3(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, #3 // =0x3
+; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 3, <vscale x 4 x i1> %mask)
+  ret void
+}
+
 attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }

``````````

</details>


https://github.com/llvm/llvm-project/pull/101017


More information about the llvm-commits mailing list