[llvm] [DAGCombine] Support (shl %x, constant) in foldPartialReduceMLAMulOp. (PR #160663)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 25 01:53:50 PDT 2025


https://github.com/fhahn created https://github.com/llvm/llvm-project/pull/160663

Support left shifts by a constant splat in foldPartialReduceMLAMulOp by treating (shl %x, %c) as (mul %x, (shl 1, %c)), so the existing mul-based combine also applies to shifts.
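
A left shift by a constant splat amount is the same as a multiply by the corresponding power of two (x << c == x * (1 << c)), so rewriting the splat shift amount into a splat multiplier lets the existing mul-based combine handle it. For example, the test below shifts the sign-extended operand by a splat of 6, i.e. multiplies it by 64:

  define <4 x i32> @partial_reduce_shl_sext_const_rhs6(<16 x i8> %l, <4 x i32> %part) {
    %ext = sext <16 x i8> %l to <16 x i32>
    %shift = shl nsw <16 x i32> %ext, splat (i32 6)
    %red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
    ret <4 x i32> %red
  }

With this patch, AArch64 targets that have the dot-product feature (the CHECK-DOT runs) lower this to a single sdot with a splat of 64 instead of a chain of sshll/add instructions:

  movi v2.16b, #64
  sdot v1.4s, v0.16b, v2.16b
  mov  v0.16b, v1.16b
  ret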

From 35da689f547d8ab8d86a0f9779cd7c5b39bee7a3 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 25 Sep 2025 09:52:45 +0100
Subject: [PATCH] [DAGCombine] Support (shl %x, constant) in
 foldPartialReduceMLAMulOp.

Support shifts in foldPartialReduceMLAMulOp by treating (shl %x, %c) as
(mul %x, (shl 1, %c)).
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 24 +++++-
 .../neon-partial-reduce-dot-product.ll        | 86 +++++++++++++++----
 2 files changed, 88 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a6ba6e518899f..5794ce06a0fa3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12996,13 +12996,31 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue Op1 = N->getOperand(1);
   SDValue Op2 = N->getOperand(2);
 
-  APInt C;
-  if (Op1->getOpcode() != ISD::MUL ||
-      !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
+  unsigned Opc = Op1->getOpcode();
+  if (Opc != ISD::MUL && Opc != ISD::SHL)
     return SDValue();
 
   SDValue LHS = Op1->getOperand(0);
   SDValue RHS = Op1->getOperand(1);
+
+  // Try to treat (shl %a, %c) as (mul %a, (1 << %c)) for constant %c.
+  if (Opc == ISD::SHL) {
+    APInt C;
+    if (!ISD::isConstantSplatVector(RHS.getNode(), C))
+      return SDValue();
+
+    RHS =
+        DAG.getSplatVector(RHS.getValueType(), DL,
+                           DAG.getConstant(APInt(C.getBitWidth(), 1).shl(C), DL,
+                                           RHS.getValueType().getScalarType()));
+    Opc = ISD::MUL;
+  }
+
+  APInt C;
+  if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) ||
+      !C.isOne())
+    return SDValue();
+
   unsigned LHSOpcode = LHS->getOpcode();
   if (!ISD::isExtOpcode(LHSOpcode))
     return SDValue();
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index d60c870003e4d..428750740fc56 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1257,21 +1257,55 @@ entry:
 }
 
 define <4 x i32> @partial_reduce_shl_sext_const_rhs6(<16 x i8> %l, <4 x i32> %part) {
-; CHECK-COMMON-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-NODOT-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-NODOT:       // %bb.0:
+; CHECK-NODOT-NEXT:    sshll v2.8h, v0.8b, #0
+; CHECK-NODOT-NEXT:    sshll2 v0.8h, v0.16b, #0
+; CHECK-NODOT-NEXT:    sshll v3.4s, v0.4h, #6
+; CHECK-NODOT-NEXT:    sshll2 v4.4s, v2.8h, #6
+; CHECK-NODOT-NEXT:    sshll v2.4s, v2.4h, #6
+; CHECK-NODOT-NEXT:    sshll2 v0.4s, v0.8h, #6
+; CHECK-NODOT-NEXT:    add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT:    add v2.4s, v4.4s, v3.4s
+; CHECK-NODOT-NEXT:    add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT:    add v0.4s, v1.4s, v0.4s
+; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-DOT-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-DOT:       // %bb.0:
+; CHECK-DOT-NEXT:    movi v2.16b, #64
+; CHECK-DOT-NEXT:    sdot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-NEXT:    mov v0.16b, v1.16b
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-DOT-I8MM-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-DOT-I8MM:       // %bb.0:
+; CHECK-DOT-I8MM-NEXT:    movi v2.16b, #64
+; CHECK-DOT-I8MM-NEXT:    sdot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-I8MM-NEXT:    mov v0.16b, v1.16b
+; CHECK-DOT-I8MM-NEXT:    ret
+  %ext = sext <16 x i8> %l to <16 x i32>
+  %shift = shl nsw <16 x i32> %ext, splat (i32 6)
+  %red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
+  ret <4 x i32> %red
+}
+
+define <4 x i32> @partial_reduce_shl_sext_const_rhs7(<16 x i8> %l, <4 x i32> %part) {
+; CHECK-COMMON-LABEL: partial_reduce_shl_sext_const_rhs7:
 ; CHECK-COMMON:       // %bb.0:
 ; CHECK-COMMON-NEXT:    sshll v2.8h, v0.8b, #0
 ; CHECK-COMMON-NEXT:    sshll2 v0.8h, v0.16b, #0
-; CHECK-COMMON-NEXT:    sshll v3.4s, v0.4h, #6
-; CHECK-COMMON-NEXT:    sshll2 v4.4s, v2.8h, #6
-; CHECK-COMMON-NEXT:    sshll v2.4s, v2.4h, #6
-; CHECK-COMMON-NEXT:    sshll2 v0.4s, v0.8h, #6
+; CHECK-COMMON-NEXT:    sshll v3.4s, v0.4h, #7
+; CHECK-COMMON-NEXT:    sshll2 v4.4s, v2.8h, #7
+; CHECK-COMMON-NEXT:    sshll v2.4s, v2.4h, #7
+; CHECK-COMMON-NEXT:    sshll2 v0.4s, v0.8h, #7
 ; CHECK-COMMON-NEXT:    add v1.4s, v1.4s, v2.4s
 ; CHECK-COMMON-NEXT:    add v2.4s, v4.4s, v3.4s
 ; CHECK-COMMON-NEXT:    add v1.4s, v1.4s, v2.4s
 ; CHECK-COMMON-NEXT:    add v0.4s, v1.4s, v0.4s
 ; CHECK-COMMON-NEXT:    ret
   %ext = sext <16 x i8> %l to <16 x i32>
-  %shift = shl nsw <16 x i32> %ext, splat (i32 6)
+  %shift = shl nsw <16 x i32> %ext, splat (i32 7)
   %red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
   ret <4 x i32> %red
 }
@@ -1331,19 +1365,33 @@ define <4 x i32> @partial_reduce_shl_sext_non_const_rhs(<16 x i8> %l, <4 x i32>
 }
 
 define <4 x i32> @partial_reduce_shl_zext_const_rhs6(<16 x i8> %l, <4 x i32> %part) {
-; CHECK-COMMON-LABEL: partial_reduce_shl_zext_const_rhs6:
-; CHECK-COMMON:       // %bb.0:
-; CHECK-COMMON-NEXT:    ushll v2.8h, v0.8b, #0
-; CHECK-COMMON-NEXT:    ushll2 v0.8h, v0.16b, #0
-; CHECK-COMMON-NEXT:    ushll v3.4s, v0.4h, #6
-; CHECK-COMMON-NEXT:    ushll2 v4.4s, v2.8h, #6
-; CHECK-COMMON-NEXT:    ushll v2.4s, v2.4h, #6
-; CHECK-COMMON-NEXT:    ushll2 v0.4s, v0.8h, #6
-; CHECK-COMMON-NEXT:    add v1.4s, v1.4s, v2.4s
-; CHECK-COMMON-NEXT:    add v2.4s, v4.4s, v3.4s
-; CHECK-COMMON-NEXT:    add v1.4s, v1.4s, v2.4s
-; CHECK-COMMON-NEXT:    add v0.4s, v1.4s, v0.4s
-; CHECK-COMMON-NEXT:    ret
+; CHECK-NODOT-LABEL: partial_reduce_shl_zext_const_rhs6:
+; CHECK-NODOT:       // %bb.0:
+; CHECK-NODOT-NEXT:    ushll v2.8h, v0.8b, #0
+; CHECK-NODOT-NEXT:    ushll2 v0.8h, v0.16b, #0
+; CHECK-NODOT-NEXT:    ushll v3.4s, v0.4h, #6
+; CHECK-NODOT-NEXT:    ushll2 v4.4s, v2.8h, #6
+; CHECK-NODOT-NEXT:    ushll v2.4s, v2.4h, #6
+; CHECK-NODOT-NEXT:    ushll2 v0.4s, v0.8h, #6
+; CHECK-NODOT-NEXT:    add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT:    add v2.4s, v4.4s, v3.4s
+; CHECK-NODOT-NEXT:    add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT:    add v0.4s, v1.4s, v0.4s
+; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-DOT-LABEL: partial_reduce_shl_zext_const_rhs6:
+; CHECK-DOT:       // %bb.0:
+; CHECK-DOT-NEXT:    movi v2.16b, #64
+; CHECK-DOT-NEXT:    udot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-NEXT:    mov v0.16b, v1.16b
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-DOT-I8MM-LABEL: partial_reduce_shl_zext_const_rhs6:
+; CHECK-DOT-I8MM:       // %bb.0:
+; CHECK-DOT-I8MM-NEXT:    movi v2.16b, #64
+; CHECK-DOT-I8MM-NEXT:    udot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-I8MM-NEXT:    mov v0.16b, v1.16b
+; CHECK-DOT-I8MM-NEXT:    ret
   %ext = zext <16 x i8> %l to <16 x i32>
   %shift = shl nsw <16 x i32> %ext, splat (i32 6)
   %red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
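
The zext variant is analogous: the shift by a splat of 6 is treated as if it had been written as a multiply by a splat of 64, conceptually

  %mul = mul <16 x i32> %ext, splat (i32 64)

which is the splat constant (movi v2.16b, #64) fed into the udot in the CHECK-DOT output above.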


