[llvm] a8ed244 - [DAGCombiner] Add DAG combine for PARTIAL_REDUCE_MLA when no mul op (#131326)

via llvm-commits llvm-commits at lists.llvm.org
Tue May 6 08:54:42 PDT 2025


Author: Nicholas Guy
Date: 2025-05-06T16:54:39+01:00
New Revision: a8ed244178b90876570b3e0bcf643f027ed83b8a

URL: https://github.com/llvm/llvm-project/commit/a8ed244178b90876570b3e0bcf643f027ed83b8a
DIFF: https://github.com/llvm/llvm-project/commit/a8ed244178b90876570b3e0bcf643f027ed83b8a.diff

LOG: [DAGCombiner] Add DAG combine for PARTIAL_REDUCE_MLA when no mul op (#131326)

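Add a generic DAG combine for ISD::PARTIAL_REDUCE_UMLA/SMLA nodes whose
input is an extend with no multiply (i.e. the multiplier is a splat of 1).
The combine converts:
PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into
PARTIAL_REDUCE_UMLA(Acc, UnextOp1, TRUNC(Splat(1))) and
PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into
PARTIAL_REDUCE_SMLA(Acc, UnextOp1, TRUNC(Splat(1))).

With the new lowering on AArch64 SVE, the updated tests show such a
reduction now selecting a single udot/sdot against a splat of 1 instead of
a chain of unpack and add instructions.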

---------

Co-authored-by: James Chesterman <james.chesterman at arm.com>

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b0150fb6367a5..09c6218b3dfd9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -618,6 +618,8 @@ namespace {
     SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
     SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
                                  const TargetLowering &TLI);
+    SDValue foldPartialReduceMLAMulOp(SDNode *N);
+    SDValue foldPartialReduceAdd(SDNode *N);
 
     SDValue CombineExtLoad(SDNode *N);
     SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12601,12 +12603,20 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+  if (SDValue Res = foldPartialReduceMLAMulOp(N))
+    return Res;
+  if (SDValue Res = foldPartialReduceAdd(N))
+    return Res;
+  return SDValue();
+}
+
 // partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
 // -> partial_reduce_*mla(acc, a, b)
 //
 // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
 // -> partial_reduce_*mla(acc, x, C)
-SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDLoc DL(N);
   auto *Context = DAG.getContext();
   SDValue Acc = N->getOperand(0);
@@ -12672,6 +12682,43 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
                      RHSExtOp);
 }
 
+// partial.reduce.umla(acc, zext(op), splat(1))
+// -> partial.reduce.umla(acc, op, splat(trunc(1)))
+// partial.reduce.smla(acc, sext(op), splat(1))
+// -> partial.reduce.smla(acc, op, splat(trunc(1)))
+SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
+  SDLoc DL(N);
+  SDValue Acc = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Op2 = N->getOperand(2);
+
+  APInt ConstantOne;
+  if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
+      !ConstantOne.isOne())
+    return SDValue();
+
+  unsigned Op1Opcode = Op1.getOpcode();
+  if (!ISD::isExtOpcode(Op1Opcode))
+    return SDValue();
+
+  SDValue UnextOp1 = Op1.getOperand(0);
+  EVT UnextOp1VT = UnextOp1.getValueType();
+  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
+    return SDValue();
+
+  bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+  bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+  EVT AccElemVT = Acc.getValueType().getVectorElementType();
+  if (Op1IsSigned != NodeIsSigned &&
+      Op1.getValueType().getVectorElementType() != AccElemVT)
+    return SDValue();
+
+  unsigned NewOpcode =
+      Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
+                     DAG.getConstant(1, DL, UnextOp1VT));
+}
+
 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
   auto *SLD = cast<VPStridedLoadSDNode>(N);
   EVT EltVT = SLD->getValueType(0).getVectorElementType();

diff  --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 5326bccbbc3d5..67be3f58e8a24 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -516,16 +516,8 @@ define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_no_bin_op:
 ; CHECK-NEWLOWERING:       // %bb.0:
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT:    add z1.s, z4.s, z3.s
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    mov z2.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT:    udot z0.s, z1.b, z2.b
 ; CHECK-NEWLOWERING-NEXT:    ret
   %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
@@ -541,16 +533,8 @@ define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op:
 ; CHECK-NEWLOWERING:       // %bb.0:
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT:    add z1.s, z4.s, z3.s
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    mov z2.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT:    sdot z0.s, z1.b, z2.b
 ; CHECK-NEWLOWERING-NEXT:    ret
   %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
@@ -566,16 +550,8 @@ define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_wide:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z4.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    mov z2.h, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT:    udot z0.d, z1.h, z2.h
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
@@ -592,16 +568,8 @@ define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_wide:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z4.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    mov z2.h, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT:    sdot z0.d, z1.h, z2.h
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>


        


More information about the llvm-commits mailing list