[llvm] e86a92f - [AArch64][SelectionDAG] Add support for 8to64 partial reduction cases (#138269)

via llvm-commits llvm-commits at lists.llvm.org
Tue May 6 08:55:02 PDT 2025


Author: Nicholas Guy
Date: 2025-05-06T16:54:59+01:00
New Revision: e86a92f947b0ddf624b19b005a23e55823219524

URL: https://github.com/llvm/llvm-project/commit/e86a92f947b0ddf624b19b005a23e55823219524
DIFF: https://github.com/llvm/llvm-project/commit/e86a92f947b0ddf624b19b005a23e55823219524.diff

LOG: [AArch64][SelectionDAG] Add support for 8to64 partial reduction cases (#138269)

---------

Co-authored-by: James Chesterman <james.chesterman at arm.com>

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1c889d67c81e0..16ef5b072a028 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1868,6 +1868,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     // Other pairs will default to 'Expand'.
     setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
     setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+
+    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
   }
 
   // Handle operations that are only available in non-streaming SVE mode.
@@ -7740,6 +7742,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerFLDEXP(Op, DAG);
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return LowerVECTOR_HISTOGRAM(Op, DAG);
+  case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_UMLA:
+    return LowerPARTIAL_REDUCE_MLA(Op, DAG);
   }
 }
 
@@ -29476,6 +29481,40 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
   return Scatter;
 }
 
+/// Custom-lower PARTIAL_REDUCE_[SU]MLA with an nxv2i64 accumulator and nxv16i8
+/// inputs. This pairing cannot be lowered directly to a (u|s)dot, so first
+/// form the dot product in a zero-initialised nxv4i32 value, then widen that
+/// and add it to the accumulator: nxv16i8 -> nxv4i32 -> nxv2i64.
+SDValue
+AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+
+  SDValue Acc = Op.getOperand(0);
+  SDValue LHS = Op.getOperand(1);
+  SDValue RHS = Op.getOperand(2);
+  EVT ResultVT = Op.getValueType();
+  assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
+  // Same opcode, but into a zero nxv4i32 accumulator; Acc is added afterwards.
+  SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
+                                DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
+  // Choose by signedness; SVE2/streaming SVE can chain ADDWB then ADDWT on Acc.
+  bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
+  if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
+    unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
+    unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
+    SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
+    return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
+  }
+  // Otherwise unpack both halves of DotNode to ResultVT and add them to Acc.
+  unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
+  unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
+  auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
+  auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
+  auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
+  return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
+}
+
 SDValue
 AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
                                                     SelectionDAG &DAG) const {

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index d9b535b910b80..9d8d1c22258be 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1181,6 +1181,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerVECTOR_DEINTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;

diff  --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 67be3f58e8a24..5bc9a101b1e44 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1,7 +1,9 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-I8MM
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
-; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING
+; RUN: llc -mtriple=aarch64 -mattr=+sve,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE2
+; RUN: llc -mtriple=aarch64 -mattr=+sme -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
 
 define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: udot:
@@ -196,46 +198,31 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
 ; CHECK-NEXT:    add z1.d, z1.d, z3.d
 ; CHECK-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: udot_8to64:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: udot_8to64:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT:    udot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z2.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z3.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: udot_8to64:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT:    udot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: udot_8to64:
+; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT:    udot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT:    uaddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    uaddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -256,46 +243,31 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
 ; CHECK-NEXT:    add z1.d, z1.d, z3.d
 ; CHECK-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: sdot_8to64:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT:    sdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z2.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z3.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT:    sdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT:    sdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT:    saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>


        


More information about the llvm-commits mailing list