[llvm] [AArch64][SelectionDAG] Add support for 8to64 partial reduction cases (PR #138269)
via llvm-commits
llvm-commits at lists.llvm.org
Fri May 2 06:08:57 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Nicholas Guy (NickGuy-Arm)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/138269.diff
3 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+28)
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+1)
- (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll (+12-74)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0126b97c9fb9a..ca63473ab085b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1867,6 +1867,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// Other pairs will default to 'Expand'.
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
}
// Handle operations that are only available in non-streaming SVE mode.
@@ -7767,6 +7769,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerFLDEXP(Op, DAG);
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
return LowerVECTOR_HISTOGRAM(Op, DAG);
+ case ISD::PARTIAL_REDUCE_SMLA:
+ case ISD::PARTIAL_REDUCE_UMLA:
+ return LowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}
@@ -29509,6 +29514,29 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
return Scatter;
}
+/// Custom lowering for PARTIAL_REDUCE_[SU]MLA nodes that accumulate nxv16i8
+/// inputs into an nxv2i64 accumulator (the pair registered as Custom in the
+/// constructor).
+///
+/// There is no single 8-bit to 64-bit dot product instruction, so the
+/// reduction is done in two stages:
+///   1. An 8-bit to 32-bit partial reduction (selected as [su]dot) into a
+///      zeroed nxv4i32 accumulator.
+///   2. The nxv4i32 partial sums are widened to nxv2i64 with [su]unpk{lo,hi},
+///      summed, and added into the original accumulator.
+SDValue
+AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+
+  SDValue Acc = Op.getOperand(0);
+  SDValue LHS = Op.getOperand(1);
+  SDValue RHS = Op.getOperand(2);
+  EVT ResultVT = Op.getValueType();
+
+  assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8 &&
+         "Unexpected types for custom PARTIAL_REDUCE_MLA lowering");
+
+  // Stage 1: i8 -> i32 partial reduction (Legal for nxv4i32/nxv16i8).
+  const EVT DotVT = MVT::nxv4i32;
+  SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
+                                DAG.getConstant(0, DL, DotVT), LHS, RHS);
+
+  // Stage 2: widen the i32 partial sums to i64. A signed dot product can
+  // produce negative i32 sums, which must be sign-extended; zero-extending
+  // them would corrupt the upper 32 bits of every i64 lane.
+  bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+  unsigned UnpkLoOpc = IsSigned ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO;
+  unsigned UnpkHiOpc = IsSigned ? AArch64ISD::SUNPKHI : AArch64ISD::UUNPKHI;
+  SDValue Lo = DAG.getNode(UnpkLoOpc, DL, ResultVT, DotNode);
+  SDValue Hi = DAG.getNode(UnpkHiOpc, DL, ResultVT, DotNode);
+  SDValue Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
+  return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
+}
+
SDValue
AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index d9b535b910b80..9d8d1c22258be 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1181,6 +1181,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerVECTOR_DEINTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 039cac01008b8..4d5996ee955ce 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -198,43 +198,12 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
;
; CHECK-NEWLOWERING-LABEL: udot_8to64:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT: ptrue p0.d
-; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z3.d
+; CHECK-NEWLOWERING-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z4.s
+; CHECK-NEWLOWERING-NEXT: add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -258,43 +227,12 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
;
; CHECK-NEWLOWERING-LABEL: sdot_8to64:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT: ptrue p0.d
-; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z3.d
+; CHECK-NEWLOWERING-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z4.s
+; CHECK-NEWLOWERING-NEXT: add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
``````````
</details>
https://github.com/llvm/llvm-project/pull/138269
More information about the llvm-commits
mailing list