[llvm] [DAGCombine] Simplify partial_reduce_*mla with constant. (PR #138289)

via llvm-commits llvm-commits at lists.llvm.org
Fri May 2 07:51:30 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Sander de Smalen (sdesmalen-arm)

<details>
<summary>Changes</summary>

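partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-> partial_reduce_*mla(acc, x, C)

The fold applies only when C is representable in the element type of x. That is, truncating C to that width and re-extending it with the same kind of extension as `ext` (zero- or sign-extension) must give back C. The resulting node is partial_reduce_smla for a sign-extension and partial_reduce_umla for a zero-extension.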

---
Full diff: https://github.com/llvm/llvm-project/pull/138289.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+34-20) 
- (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll (+142-1) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ea1435c3934be..345cb4f9fb6ee 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12612,47 +12612,63 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
   return SDValue();
 }
 
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
+// partial_reduce_*mla(acc, mul(zext(a), zext(b)), splat(1))
+// -> partial_reduce_umla(acc, a, b)  (sext/sext: partial_reduce_smla)
+//
+// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, trunc(C)), if C fits x's element type
 SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   SDLoc DL(N);
-
+  auto *Context = DAG.getContext();
   SDValue Acc = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
   SDValue Op2 = N->getOperand(2);
 
-  APInt ConstantOne;
+  APInt C;
   if (Op1->getOpcode() != ISD::MUL ||
-      !ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+      !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
     return SDValue();
 
   SDValue LHS = Op1->getOperand(0);
   SDValue RHS = Op1->getOperand(1);
   unsigned LHSOpcode = LHS->getOpcode();
-  unsigned RHSOpcode = RHS->getOpcode();
-  if (!ISD::isExtOpcode(LHSOpcode) || !ISD::isExtOpcode(RHSOpcode))
+  if (!ISD::isExtOpcode(LHSOpcode))
     return SDValue();
 
   SDValue LHSExtOp = LHS->getOperand(0);
-  SDValue RHSExtOp = RHS->getOperand(0);
   EVT LHSExtOpVT = LHSExtOp.getValueType();
-  if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
-    return SDValue();
 
-  // Only perform the DAG combine if there is custom lowering provided by the
-  // target
-  auto *Context = DAG.getContext();
+  // Only perform these combines if the target supports folding
+  // the extends into the operation.
   if (!TLI.isPartialReduceMLALegalOrCustom(
           TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
   bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
+  unsigned NewOpcode =
+      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
+  // partial_reduce_*mla(acc, mul(zext(x), splat(C)), splat(1))
+  // -> partial_reduce_umla(acc, x, C)
+  if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
+    APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
+    unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
+    if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
+        (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
+      return SDValue();
+
+    return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
+                       DAG.getConstant(CTrunc, DL, LHSExtOpVT));
+  }
+
+  unsigned RHSOpcode = RHS->getOpcode();
+  if (!ISD::isExtOpcode(RHSOpcode))
+    return SDValue();
+
+  SDValue RHSExtOp = RHS->getOperand(0);
+  if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
+    return SDValue();
 
   // For a 2-stage extend the signedness of both of the extends must be the
   // same. This is so the node can be folded into only a signed or unsigned
@@ -12663,8 +12679,6 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
       Op1.getValueType().getVectorElementType() != AccElemVT)
     return SDValue();
 
-  unsigned NewOpcode =
-      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
   return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
                      RHSExtOp);
 }
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 039cac01008b8..5326bccbbc3d5 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1139,7 +1139,6 @@ entry:
   ret <vscale x 2 x i16> %partial.reduce
 }
 
-
 define <vscale x 4 x i64> @partial_reduce_only_split_acc(<vscale x 4 x i64> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
 ; CHECK-LABEL: partial_reduce_only_split_acc:
 ; CHECK:       // %bb.0: // %entry
@@ -1178,3 +1177,145 @@ entry:
   <vscale x 4 x i64> %acc, <vscale x 8 x i64> %mult)
   ret <vscale x 4 x i64> %partial.reduce
 }
+
+define <vscale x 4 x i32> @sdot_imm(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: sdot_imm:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEXT:    sub z0.s, z0.s, z3.s
+; CHECK-NEXT:    sunpklo z3.s, z1.h
+; CHECK-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEXT:    sub z0.s, z0.s, z2.s
+; CHECK-NEXT:    sub z0.s, z0.s, z3.s
+; CHECK-NEXT:    sub z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_imm:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    mov z2.b, #-1 // =0xffffffffffffffff
+; CHECK-NEWLOWERING-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nsw <vscale x 16 x i32> %a.wide, splat(i32 -1) ; no nuw: the unsigned product wraps
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot_imm_does_not_fit(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: sdot_imm_does_not_fit:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEXT:    sunpklo z4.s, z1.h
+; CHECK-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_imm_does_not_fit:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nsw <vscale x 16 x i32> %a.wide, splat(i32 256) ; no nuw: negative %a.wide wraps unsigned
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @udot_imm(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: udot_imm:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uunpklo z3.h, z1.b
+; CHECK-NEXT:    mov z2.s, #255 // =0xff
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEXT:    uunpklo z4.s, z3.h
+; CHECK-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEXT:    mla z0.s, p0/m, z4.s, z2.s
+; CHECK-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    mla z0.s, p0/m, z3.s, z2.s
+; CHECK-NEXT:    mla z0.s, p0/m, z4.s, z2.s
+; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_imm:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    mov z2.b, #-1 // =0xffffffffffffffff
+; CHECK-NEWLOWERING-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 255)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @udot_imm_does_not_fit(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: udot_imm_does_not_fit:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uunpklo z2.h, z1.b
+; CHECK-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_imm_does_not_fit:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 256)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/138289


More information about the llvm-commits mailing list