[llvm] [DAGCombine] Simplify partial_reduce_*mla with constant. (PR #138289)
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Fri May 2 12:32:14 PDT 2025
https://github.com/sdesmalen-arm updated https://github.com/llvm/llvm-project/pull/138289
>From 6c9e8ec8928c60c80584d251d9c5d22cafeb100b Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 2 May 2025 14:05:29 +0000
Subject: [PATCH] [DAGCombine] Simplify partial_reduce_*mla with constant.
partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-> partial_reduce_*mla(acc, x, C)
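For example (mirroring the new sdot_imm test added below), with x sign-extended
from nxv16i8 to nxv16i32 and C = -1, which is representable in the narrow i8
element type:
  partial_reduce_smla(acc, mul(sext(x), splat(i32 -1)), splat(i32 1))
    -> partial_reduce_smla(acc, x, splat(i8 -1))
On AArch64 SVE this lowers to a single sdot against a splat of -1. The fold is
only applied when C survives truncation to the narrow element type followed by
the same kind of extension as the mul operand; otherwise (e.g. C = 256 with an
i8 element type, as in the *_does_not_fit tests) the node is left alone.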
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 54 ++++---
.../AArch64/sve-partial-reduce-dot-product.ll | 143 +++++++++++++++++-
2 files changed, 176 insertions(+), 21 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ea1435c3934be..de9b979fe3072 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12612,47 +12612,63 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
return SDValue();
}
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
+// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
+// -> partial_reduce_*mla(acc, a, b)
+//
+// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, C)
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
SDLoc DL(N);
-
+ auto *Context = DAG.getContext();
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
- APInt ConstantOne;
+ APInt C;
if (Op1->getOpcode() != ISD::MUL ||
- !ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
- !ConstantOne.isOne())
+ !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
return SDValue();
SDValue LHS = Op1->getOperand(0);
SDValue RHS = Op1->getOperand(1);
unsigned LHSOpcode = LHS->getOpcode();
- unsigned RHSOpcode = RHS->getOpcode();
- if (!ISD::isExtOpcode(LHSOpcode) || !ISD::isExtOpcode(RHSOpcode))
+ if (!ISD::isExtOpcode(LHSOpcode))
return SDValue();
SDValue LHSExtOp = LHS->getOperand(0);
- SDValue RHSExtOp = RHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();
- if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
- return SDValue();
- // Only perform the DAG combine if there is custom lowering provided by the
- // target
- auto *Context = DAG.getContext();
+ // Only perform these combines if the target supports folding
+ // the extends into the operation.
if (!TLI.isPartialReduceMLALegalOrCustom(
TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
+ unsigned NewOpcode =
+ ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
+ // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
+ // -> partial_reduce_*mla(acc, x, C)
+ if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
+ APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
+ unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
+ if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
+ (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
+ return SDValue();
+
+ return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
+ DAG.getConstant(CTrunc, DL, LHSExtOpVT));
+ }
+
+ unsigned RHSOpcode = RHS->getOpcode();
+ if (!ISD::isExtOpcode(RHSOpcode))
+ return SDValue();
+
+ SDValue RHSExtOp = RHS->getOperand(0);
+ if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
+ return SDValue();
// For a 2-stage extend the signedness of both of the extends must be the
// same. This is so the node can be folded into only a signed or unsigned
@@ -12663,8 +12679,6 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
Op1.getValueType().getVectorElementType() != AccElemVT)
return SDValue();
- unsigned NewOpcode =
- ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
RHSExtOp);
}
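The heart of the new fold is the check that the wide splat constant survives a
round trip through the narrow element type. Below is a minimal standalone
sketch of that check using llvm::APInt; the helper name constantFitsNarrowType
is invented here purely for illustration and is not part of the patch (compile
and link against LLVMSupport):

  #include "llvm/ADT/APInt.h"
  #include <cassert>

  using llvm::APInt;

  // True if the wide constant C can be truncated to NarrowBits bits and then
  // re-extended (sign- or zero-extended, matching the extend on the mul
  // operand) back to its original width without changing its value.
  static bool constantFitsNarrowType(const APInt &C, unsigned NarrowBits,
                                     bool ExtIsSigned) {
    APInt CTrunc = C.trunc(NarrowBits);
    unsigned WideBits = C.getBitWidth();
    return ExtIsSigned ? CTrunc.sext(WideBits) == C
                       : CTrunc.zext(WideBits) == C;
  }

  int main() {
    // splat(i32 -1) with a sign-extended i8 operand: folds (sdot_imm).
    assert(constantFitsNarrowType(APInt(32, -1, /*isSigned=*/true), 8, true));
    // splat(i32 255) with a zero-extended i8 operand: folds (udot_imm).
    assert(constantFitsNarrowType(APInt(32, 255), 8, false));
    // splat(i32 256) does not fit in i8 either way (*_does_not_fit tests).
    assert(!constantFitsNarrowType(APInt(32, 256), 8, true));
    assert(!constantFitsNarrowType(APInt(32, 256), 8, false));
    return 0;
  }

This mirrors the CTrunc.zext/sext comparison performed in
visitPARTIAL_REDUCE_MLA before the constant is rebuilt in the narrow vector
type with DAG.getConstant(CTrunc, DL, LHSExtOpVT).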
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 039cac01008b8..5326bccbbc3d5 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1139,7 +1139,6 @@ entry:
ret <vscale x 2 x i16> %partial.reduce
}
-
define <vscale x 4 x i64> @partial_reduce_only_split_acc(<vscale x 4 x i64> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: partial_reduce_only_split_acc:
; CHECK: // %bb.0: // %entry
@@ -1178,3 +1177,145 @@ entry:
<vscale x 4 x i64> %acc, <vscale x 8 x i64> %mult)
ret <vscale x 4 x i64> %partial.reduce
}
+
+define <vscale x 4 x i32> @sdot_imm(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: sdot_imm:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sunpklo z2.h, z1.b
+; CHECK-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEXT: sunpklo z3.s, z2.h
+; CHECK-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEXT: sub z0.s, z0.s, z3.s
+; CHECK-NEXT: sunpklo z3.s, z1.h
+; CHECK-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEXT: sub z0.s, z0.s, z2.s
+; CHECK-NEXT: sub z0.s, z0.s, z3.s
+; CHECK-NEXT: sub z0.s, z0.s, z1.s
+; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_imm:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: mov z2.b, #-1 // =0xffffffffffffffff
+; CHECK-NEWLOWERING-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEWLOWERING-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 -1)
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot_imm_does_not_fit(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: sdot_imm_does_not_fit:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sunpklo z2.h, z1.b
+; CHECK-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEXT: sunpklo z3.s, z2.h
+; CHECK-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEXT: sunpklo z4.s, z1.h
+; CHECK-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEXT: lsl z4.s, z4.s, #8
+; CHECK-NEXT: lsl z2.s, z2.s, #8
+; CHECK-NEXT: lsl z3.s, z3.s, #8
+; CHECK-NEXT: lsl z1.s, z1.s, #8
+; CHECK-NEXT: add z0.s, z0.s, z3.s
+; CHECK-NEXT: add z2.s, z2.s, z4.s
+; CHECK-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEXT: add z0.s, z0.s, z1.s
+; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_imm_does_not_fit:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: lsl z4.s, z4.s, #8
+; CHECK-NEWLOWERING-NEXT: lsl z2.s, z2.s, #8
+; CHECK-NEWLOWERING-NEXT: lsl z3.s, z3.s, #8
+; CHECK-NEWLOWERING-NEXT: lsl z1.s, z1.s, #8
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT: add z2.s, z2.s, z4.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 256)
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @udot_imm(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: udot_imm:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: uunpklo z3.h, z1.b
+; CHECK-NEXT: mov z2.s, #255 // =0xff
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NEXT: uunpklo z4.s, z3.h
+; CHECK-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NEXT: mla z0.s, p0/m, z4.s, z2.s
+; CHECK-NEXT: uunpklo z4.s, z1.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: mla z0.s, p0/m, z3.s, z2.s
+; CHECK-NEXT: mla z0.s, p0/m, z4.s, z2.s
+; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_imm:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: mov z2.b, #-1 // =0xffffffffffffffff
+; CHECK-NEWLOWERING-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-NEWLOWERING-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 255)
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @udot_imm_does_not_fit(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: udot_imm_does_not_fit:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: uunpklo z2.h, z1.b
+; CHECK-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NEXT: uunpklo z3.s, z2.h
+; CHECK-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEXT: uunpklo z4.s, z1.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: lsl z4.s, z4.s, #8
+; CHECK-NEXT: lsl z2.s, z2.s, #8
+; CHECK-NEXT: lsl z3.s, z3.s, #8
+; CHECK-NEXT: lsl z1.s, z1.s, #8
+; CHECK-NEXT: add z0.s, z0.s, z3.s
+; CHECK-NEXT: add z2.s, z2.s, z4.s
+; CHECK-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEXT: add z0.s, z0.s, z1.s
+; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_imm_does_not_fit:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: lsl z4.s, z4.s, #8
+; CHECK-NEWLOWERING-NEXT: lsl z2.s, z2.s, #8
+; CHECK-NEWLOWERING-NEXT: lsl z3.s, z3.s, #8
+; CHECK-NEWLOWERING-NEXT: lsl z1.s, z1.s, #8
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT: add z2.s, z2.s, z4.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 256)
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}