[llvm] [DAGCombiner] Add DAG combine for PARTIAL_REDUCE_MLA when no mul op (PR #131326)
Nicholas Guy via llvm-commits
llvm-commits at lists.llvm.org
Fri May 2 06:29:57 PDT 2025
https://github.com/NickGuy-Arm updated https://github.com/llvm/llvm-project/pull/131326
From 8004c2e5a754b6f4295d4e4b4002badecf335977 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Thu, 1 May 2025 15:15:25 +0100
Subject: [PATCH 1/3] [DAGCombiner] Add DAG combine for PARTIAL_REDUCE_MLA when
no mul op
Add a generic DAG combine for ISD::PARTIAL_REDUCE_UMLA/SMLA that converts:
PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into
PARTIAL_REDUCE_UMLA(Acc, UnextOp1, TRUNC(Splat(1)))
and
PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into
PARTIAL_REDUCE_SMLA(Acc, UnextOp1, TRUNC(Splat(1))).
# Conflicts:
# llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
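For reference, the IR shape this combine targets looks like the udot_no_bin_op test updated below: a partial reduction whose input is a plain extend, with no multiply. Under the new partial-reduction lowering this becomes PARTIAL_REDUCE_UMLA(Acc, ZEXT(%a), Splat(1)), which the combine rewrites so the target can select a dot product against a splat-of-one register. A minimal sketch, lifted from that test (the declare line is added here for completeness, and the new-lowering codegen path is assumed to be enabled as in the CHECK-NEWLOWERING RUN line):

declare <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32>, <vscale x 16 x i32>)

define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
  %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %r = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
  ret <vscale x 4 x i32> %r
}

With the combine in place, this now selects "mov z2.b, #1" plus "udot z0.s, z1.b, z2.b" instead of the unpack-and-add chain removed in the test diff below.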
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 54 ++++++++++++++++++-
1 file changed, 53 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ea1435c3934be..0eb9c8e044609 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -618,6 +618,8 @@ namespace {
SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
const TargetLowering &TLI);
+ SDValue foldPartialReduceMLAMulOp(SDNode *N);
+ SDValue foldPartialReduceMLANoMulOp(SDNode *N);
SDValue CombineExtLoad(SDNode *N);
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12612,13 +12614,21 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+ if (SDValue Res = foldPartialReduceMLAMulOp(N))
+ return Res;
+ if (SDValue Res = foldPartialReduceMLANoMulOp(N))
+ return Res;
+ return SDValue();
+}
+
// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
// Splat(1)) into
// PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
// Splat(1)) into
// PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
-SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
@@ -12669,6 +12679,48 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
RHSExtOp);
}
+// Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into
+// PARTIAL_REDUCE_UMLA(Acc, Op, TRUNC(Splat(1)))
+// Makes PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into
+// PARTIAL_REDUCE_SMLA(Acc, Op, TRUNC(Splat(1)))
+SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
+ SDLoc DL(N);
+ SDValue Acc = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ SDValue Op2 = N->getOperand(2);
+
+ APInt ConstantOne;
+ if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
+ !ConstantOne.isOne())
+ return SDValue();
+
+ unsigned Op1Opcode = Op1.getOpcode();
+ if (!ISD::isExtOpcode(Op1Opcode))
+ return SDValue();
+
+ SDValue UnextOp1 = Op1.getOperand(0);
+ EVT UnextOp1VT = UnextOp1.getValueType();
+
+ if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
+ return SDValue();
+
+ SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
+
+ bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+
+ bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+ EVT AccElemVT = Acc.getValueType().getVectorElementType();
+ if (Op1IsSigned != NodeIsSigned &&
+ (Op1.getValueType().getVectorElementType() != AccElemVT ||
+ Op2.getValueType().getVectorElementType() != AccElemVT))
+ return SDValue();
+
+ unsigned NewOpcode =
+ Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
+ TruncOp2);
+}
+
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
auto *SLD = cast<VPStridedLoadSDNode>(N);
EVT EltVT = SLD->getValueType(0).getVectorElementType();
From c0d0029e774455381acac2a7e103cabff9492764 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Fri, 2 May 2025 14:10:45 +0100
Subject: [PATCH 2/3] Update comment
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 0eb9c8e044609..209ba7858e4a5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12679,10 +12679,10 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
RHSExtOp);
}
-// Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into
-// PARTIAL_REDUCE_UMLA(Acc, Op, TRUNC(Splat(1)))
-// Makes PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into
-// PARTIAL_REDUCE_SMLA(Acc, Op, TRUNC(Splat(1)))
+// Makes partial.reduce.umla(acc, zext(op1), splat(1)) into
+// partial.reduce.umla(acc, op, splat(trunc(1)))
+// Makes partial.reduce.smla(acc, sext(op1), splat(1)) into
+// partial.reduce.smla(acc, op, splat(trunc(1)))
SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
From a708b964b10a0b5a36bc8464c5409be79265c630 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Fri, 2 May 2025 14:28:03 +0100
Subject: [PATCH 3/3] Update tests after rebase
---
.../AArch64/sve-partial-reduce-dot-product.ll | 48 ++++---------------
1 file changed, 8 insertions(+), 40 deletions(-)
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 039cac01008b8..3cc223b63d0f7 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -516,16 +516,8 @@ define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
;
; CHECK-NEWLOWERING-LABEL: udot_no_bin_op:
; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT: add z1.s, z4.s, z3.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: udot z0.s, z1.b, z2.b
; CHECK-NEWLOWERING-NEXT: ret
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
@@ -541,16 +533,8 @@ define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
;
; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op:
; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT: add z1.s, z4.s, z3.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: sdot z0.s, z1.b, z2.b
; CHECK-NEWLOWERING-NEXT: ret
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
@@ -566,16 +550,8 @@ define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
;
; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_wide:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z4.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: udot z0.d, z1.h, z2.h
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
@@ -592,16 +568,8 @@ define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
;
; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_wide:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z4.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: sdot z0.d, z1.h, z2.h
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>