[llvm] [DAGCombine] Support (shl %x, constant) in foldPartialReduceMLAMulOp. (PR #160663)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 25 01:54:26 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Florian Hahn (fhahn)
<details>
<summary>Changes</summary>
Support shifts in foldPartialReduceMLAMulOp by treating (shl %x, %c) as (mul %x, (shl 1, %c)) when %c is a constant splat, allowing the existing MUL-based partial-reduction combine (e.g. lowering to sdot/udot on AArch64) to fire for shift-by-constant patterns as well.
---
Full diff: https://github.com/llvm/llvm-project/pull/160663.diff
2 Files Affected:
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+21-3)
- (modified) llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll (+67-19)
``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a6ba6e518899f..5794ce06a0fa3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12996,13 +12996,31 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
- APInt C;
- if (Op1->getOpcode() != ISD::MUL ||
- !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
+ unsigned Opc = Op1->getOpcode();
+ if (Opc != ISD::MUL && Opc != ISD::SHL)
return SDValue();
SDValue LHS = Op1->getOperand(0);
SDValue RHS = Op1->getOperand(1);
+
+ // Try to treat (shl %a, %c) as (mul %a, (1 << %c)) for constant %c.
+ if (Opc == ISD::SHL) {
+ APInt C;
+ if (!ISD::isConstantSplatVector(RHS.getNode(), C))
+ return SDValue();
+
+ RHS =
+ DAG.getSplatVector(RHS.getValueType(), DL,
+ DAG.getConstant(APInt(C.getBitWidth(), 1).shl(C), DL,
+ RHS.getValueType().getScalarType()));
+ Opc = ISD::MUL;
+ }
+
+ APInt C;
+ if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) ||
+ !C.isOne())
+ return SDValue();
+
unsigned LHSOpcode = LHS->getOpcode();
if (!ISD::isExtOpcode(LHSOpcode))
return SDValue();
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index d60c870003e4d..428750740fc56 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1257,21 +1257,55 @@ entry:
}
define <4 x i32> @partial_reduce_shl_sext_const_rhs6(<16 x i8> %l, <4 x i32> %part) {
-; CHECK-COMMON-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-NODOT-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: sshll v2.8h, v0.8b, #0
+; CHECK-NODOT-NEXT: sshll2 v0.8h, v0.16b, #0
+; CHECK-NODOT-NEXT: sshll v3.4s, v0.4h, #6
+; CHECK-NODOT-NEXT: sshll2 v4.4s, v2.8h, #6
+; CHECK-NODOT-NEXT: sshll v2.4s, v2.4h, #6
+; CHECK-NODOT-NEXT: sshll2 v0.4s, v0.8h, #6
+; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT: add v2.4s, v4.4s, v3.4s
+; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT: add v0.4s, v1.4s, v0.4s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-DOT-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v2.16b, #64
+; CHECK-DOT-NEXT: sdot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-DOT-I8MM-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-DOT-I8MM: // %bb.0:
+; CHECK-DOT-I8MM-NEXT: movi v2.16b, #64
+; CHECK-DOT-I8MM-NEXT: sdot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-I8MM-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-I8MM-NEXT: ret
+ %ext = sext <16 x i8> %l to <16 x i32>
+ %shift = shl nsw <16 x i32> %ext, splat (i32 6)
+ %red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
+ ret <4 x i32> %red
+}
+
+define <4 x i32> @partial_reduce_shl_sext_const_rhs7(<16 x i8> %l, <4 x i32> %part) {
+; CHECK-COMMON-LABEL: partial_reduce_shl_sext_const_rhs7:
; CHECK-COMMON: // %bb.0:
; CHECK-COMMON-NEXT: sshll v2.8h, v0.8b, #0
; CHECK-COMMON-NEXT: sshll2 v0.8h, v0.16b, #0
-; CHECK-COMMON-NEXT: sshll v3.4s, v0.4h, #6
-; CHECK-COMMON-NEXT: sshll2 v4.4s, v2.8h, #6
-; CHECK-COMMON-NEXT: sshll v2.4s, v2.4h, #6
-; CHECK-COMMON-NEXT: sshll2 v0.4s, v0.8h, #6
+; CHECK-COMMON-NEXT: sshll v3.4s, v0.4h, #7
+; CHECK-COMMON-NEXT: sshll2 v4.4s, v2.8h, #7
+; CHECK-COMMON-NEXT: sshll v2.4s, v2.4h, #7
+; CHECK-COMMON-NEXT: sshll2 v0.4s, v0.8h, #7
; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
; CHECK-COMMON-NEXT: add v2.4s, v4.4s, v3.4s
; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
; CHECK-COMMON-NEXT: add v0.4s, v1.4s, v0.4s
; CHECK-COMMON-NEXT: ret
%ext = sext <16 x i8> %l to <16 x i32>
- %shift = shl nsw <16 x i32> %ext, splat (i32 6)
+ %shift = shl nsw <16 x i32> %ext, splat (i32 7)
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
ret <4 x i32> %red
}
@@ -1331,19 +1365,33 @@ define <4 x i32> @partial_reduce_shl_sext_non_const_rhs(<16 x i8> %l, <4 x i32>
}
define <4 x i32> @partial_reduce_shl_zext_const_rhs6(<16 x i8> %l, <4 x i32> %part) {
-; CHECK-COMMON-LABEL: partial_reduce_shl_zext_const_rhs6:
-; CHECK-COMMON: // %bb.0:
-; CHECK-COMMON-NEXT: ushll v2.8h, v0.8b, #0
-; CHECK-COMMON-NEXT: ushll2 v0.8h, v0.16b, #0
-; CHECK-COMMON-NEXT: ushll v3.4s, v0.4h, #6
-; CHECK-COMMON-NEXT: ushll2 v4.4s, v2.8h, #6
-; CHECK-COMMON-NEXT: ushll v2.4s, v2.4h, #6
-; CHECK-COMMON-NEXT: ushll2 v0.4s, v0.8h, #6
-; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
-; CHECK-COMMON-NEXT: add v2.4s, v4.4s, v3.4s
-; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
-; CHECK-COMMON-NEXT: add v0.4s, v1.4s, v0.4s
-; CHECK-COMMON-NEXT: ret
+; CHECK-NODOT-LABEL: partial_reduce_shl_zext_const_rhs6:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: ushll v2.8h, v0.8b, #0
+; CHECK-NODOT-NEXT: ushll2 v0.8h, v0.16b, #0
+; CHECK-NODOT-NEXT: ushll v3.4s, v0.4h, #6
+; CHECK-NODOT-NEXT: ushll2 v4.4s, v2.8h, #6
+; CHECK-NODOT-NEXT: ushll v2.4s, v2.4h, #6
+; CHECK-NODOT-NEXT: ushll2 v0.4s, v0.8h, #6
+; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT: add v2.4s, v4.4s, v3.4s
+; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT: add v0.4s, v1.4s, v0.4s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-DOT-LABEL: partial_reduce_shl_zext_const_rhs6:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v2.16b, #64
+; CHECK-DOT-NEXT: udot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-DOT-I8MM-LABEL: partial_reduce_shl_zext_const_rhs6:
+; CHECK-DOT-I8MM: // %bb.0:
+; CHECK-DOT-I8MM-NEXT: movi v2.16b, #64
+; CHECK-DOT-I8MM-NEXT: udot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-I8MM-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-I8MM-NEXT: ret
%ext = zext <16 x i8> %l to <16 x i32>
%shift = shl nsw <16 x i32> %ext, splat (i32 6)
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
``````````
</details>
https://github.com/llvm/llvm-project/pull/160663
More information about the llvm-commits
mailing list