[llvm] [RISCV] Properly lower multiply-accumulate chains containing a constant (PR #168660)
Ryan Buchner via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 18 23:11:19 PST 2025
https://github.com/bababuck updated https://github.com/llvm/llvm-project/pull/168660
>From cf98a3c6fed0496b7732e15de087061357b0a709 Mon Sep 17 00:00:00 2001
From: bababuck <buchner.ryan at gmail.com>
Date: Tue, 18 Nov 2025 13:10:54 -0800
Subject: [PATCH 1/4] [RISCV] Add test for lowering vector multiply add chains
Namely, it tests cases such as the following:
%mul1 = mul %m00, %m01
%mul0 = mul %m10, %m11
%add0 = add %mul0, %constant
%add1 = add %add0, %mul1
---
llvm/test/CodeGen/RISCV/vmadd-reassociate.ll | 146 +++++++++++++++++++
1 file changed, 146 insertions(+)
create mode 100644 llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
new file mode 100644
index 0000000000000..d161c60c6c7bc
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -0,0 +1,146 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv64 -mattr=+m,+v < %s | FileCheck %s
+
+define i32 @madd_scalar(i32 %m00, i32 %m01, i32 %m10, i32 %m11) nounwind {
+; CHECK-LABEL: madd_scalar:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: mul a0, a0, a1
+; CHECK-NEXT: mul a1, a2, a3
+; CHECK-NEXT: add a0, a0, a1
+; CHECK-NEXT: addiw a0, a0, 32
+; CHECK-NEXT: ret
+entry:
+ %mul0 = mul nsw i32 %m00, %m01
+ %mul1 = mul nsw i32 %m10, %m11
+ %add0 = add i32 %mul0, 32
+ %add1 = add i32 %add0, %mul1
+ ret i32 %add1
+}
+
+define <8 x i32> @vmadd_non_constant(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11, <8 x i32> %addend) {
+; CHECK-LABEL: vmadd_non_constant:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmadd.vv v8, v10, v16
+; CHECK-NEXT: vmacc.vv v8, v14, v12
+; CHECK-NEXT: ret
+entry:
+ %mul0 = mul nsw <8 x i32> %m00, %m01
+ %mul1 = mul nsw <8 x i32> %m10, %m11
+ %add0 = add <8 x i32> %mul0, %addend
+ %add1 = add <8 x i32> %add0, %mul1
+ ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_no_chain(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01) {
+; CHECK-LABEL: vmadd_vscale_no_chain:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmv.v.x v10, a0
+; CHECK-NEXT: vmadd.vv v8, v9, v10
+; CHECK-NEXT: ret
+entry:
+ %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+ %mul = mul nsw <vscale x 1 x i32> %m00, %m01
+ %add = add <vscale x 1 x i32> %mul, splat (i32 32)
+ ret <vscale x 1 x i32> %add
+}
+
+define <8 x i32> @vmadd_fixed_no_chain(<8 x i32> %m00, <8 x i32> %m01) {
+; CHECK-LABEL: vmadd_fixed_no_chain:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmv.v.x v12, a0
+; CHECK-NEXT: vmadd.vv v8, v10, v12
+; CHECK-NEXT: ret
+entry:
+ %mul = mul nsw <8 x i32> %m00, %m01
+ %add = add <8 x i32> %mul, splat (i32 32)
+ ret <8 x i32> %add
+}
+
+define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11) {
+; CHECK-LABEL: vmadd_vscale:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmul.vv v8, v8, v9
+; CHECK-NEXT: vmacc.vv v8, v11, v10
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vadd.vx v8, v8, a0
+; CHECK-NEXT: ret
+entry:
+ %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+ %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
+ %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
+ %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+ %add1 = add <vscale x 1 x i32> %add0, %mul1
+ ret <vscale x 1 x i32> %add1
+}
+
+define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11) {
+; CHECK-LABEL: vmadd_fixed:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmul.vv v8, v8, v10
+; CHECK-NEXT: vmacc.vv v8, v14, v12
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vadd.vx v8, v8, a0
+; CHECK-NEXT: ret
+entry:
+ %mul0 = mul nsw <8 x i32> %m00, %m01
+ %mul1 = mul nsw <8 x i32> %m10, %m11
+ %add0 = add <8 x i32> %mul0, splat (i32 32)
+ %add1 = add <8 x i32> %add0, %mul1
+ ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11,
+; CHECK-LABEL: vmadd_vscale_long:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmul.vv v8, v8, v9
+; CHECK-NEXT: vmacc.vv v8, v11, v10
+; CHECK-NEXT: vmacc.vv v8, v13, v12
+; CHECK-NEXT: vmacc.vv v8, v15, v14
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vadd.vx v8, v8, a0
+; CHECK-NEXT: ret
+ <vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
+entry:
+ %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+ %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
+ %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
+ %mul2 = mul nsw <vscale x 1 x i32> %m20, %m21
+ %mul3 = mul nsw <vscale x 1 x i32> %m30, %m31
+ %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+ %add1 = add <vscale x 1 x i32> %add0, %mul1
+ %add2 = add <vscale x 1 x i32> %add1, %mul2
+ %add3 = add <vscale x 1 x i32> %add2, %mul3
+ ret <vscale x 1 x i32> %add3
+}
+
+define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11,
+; CHECK-LABEL: vmadd_fixed_long:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmul.vv v8, v8, v10
+; CHECK-NEXT: vmacc.vv v8, v14, v12
+; CHECK-NEXT: vmacc.vv v8, v18, v16
+; CHECK-NEXT: vmacc.vv v8, v22, v20
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vadd.vx v8, v8, a0
+; CHECK-NEXT: ret
+ <8 x i32> %m20, <8 x i32> %m21, <8 x i32> %m30, <8 x i32> %m31) {
+entry:
+ %mul0 = mul nsw <8 x i32> %m00, %m01
+ %mul1 = mul nsw <8 x i32> %m10, %m11
+ %mul2 = mul nsw <8 x i32> %m20, %m21
+ %mul3 = mul nsw <8 x i32> %m30, %m31
+ %add0 = add <8 x i32> %mul0, splat (i32 32)
+ %add1 = add <8 x i32> %add0, %mul1
+ %add2 = add <8 x i32> %add1, %mul2
+ %add3 = add <8 x i32> %add2, %mul3
+ ret <8 x i32> %add3
+}
>From 4d5aa5dd815a98d6fe2ec503b64c4e14d3d771bb Mon Sep 17 00:00:00 2001
From: bababuck <buchner.ryan at gmail.com>
Date: Tue, 18 Nov 2025 10:32:48 -0800
Subject: [PATCH 2/4] [RISCV] Properly lower multiply-accumulate chains
containing a constant
Previously, the following:
%mul0 = mul nsw <8 x i32> %m00, %m01
%mul1 = mul nsw <8 x i32> %m10, %m11
%add0 = add <8 x i32> %mul0, splat (i32 32)
%add1 = add <8 x i32> %add0, %mul1
lowered to:
vsetivli zero, 8, e32, m2, ta, ma
vmul.vv v8, v8, v10
vmacc.vv v8, v14, v12
li a0, 32
vadd.vx v8, v8, a0
After this patch, it now lowers to:
li a0, 32
vsetivli zero, 8, e32, m2, ta, ma
vmv.v.x v16, a0
vmadd.vv v8, v10, v16
vmacc.vv v8, v14, v12
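For context, this is driven by the existing TargetLowering::isReassocProfitable hook: the generic DAG combiner only pulls the constant outward, rewriting (add (add X, C1), Y) into (add (add X, Y), C1), when the target reports the reassociation as profitable. The sketch below is a simplified illustration of that interaction, not the actual DAGCombiner code; the function name and structure are assumed for illustration only. With the RISC-V override in this patch returning false when Y is a vector multiply, the constant stays on the inner add, so both adds can be selected as vmadd.vv/vmacc.vv.
// Simplified sketch (not the actual DAGCombiner code) of the constant-outward
// reassociation that is guarded by TargetLowering::isReassocProfitable.
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLowering.h"
using namespace llvm;

static SDValue reassociateConstantOutward(const TargetLowering &TLI,
                                          SelectionDAG &DAG, const SDLoc &DL,
                                          EVT VT, SDValue N0, SDValue N1) {
  // N0 is expected to be (add X, C1); N1 is Y.
  if (N0.getOpcode() != ISD::ADD || !N0.hasOneUse() ||
      !DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)))
    return SDValue();
  // Ask the target first; the RISC-V override declines when this would
  // split a multiply-accumulate chain (i.e. when N1 is a vector multiply).
  if (!TLI.isReassocProfitable(DAG, N0, N1))
    return SDValue();
  // (add (add X, C1), Y) -> (add (add X, Y), C1)
  SDValue Inner = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1);
  return DAG.getNode(ISD::ADD, DL, VT, Inner, N0.getOperand(1));
}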
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 14 ++++++++++
llvm/lib/Target/RISCV/RISCVISelLowering.h | 5 ++++
llvm/test/CodeGen/RISCV/vmadd-reassociate.ll | 28 ++++++++++----------
3 files changed, 33 insertions(+), 14 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 921d12757d672..809abbc69ce90 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -25655,3 +25655,17 @@ bool RISCVTargetLowering::shouldFoldMaskToVariableShiftPair(SDValue Y) const {
return VT.getSizeInBits() <= Subtarget.getXLen();
}
+
+bool RISCVTargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+ SDValue N1) const {
+ if (!N0.hasOneUse())
+ return false;
+
+ // Avoid reassociating expressions that can be lowered to vector
+ // multiply accumulate (i.e. add (mul x, y), z)
+ if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::MUL &&
+ (N0.getValueType().isVector() && Subtarget.hasStdExtV()))
+ return false;
+
+ return true;
+}
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 5cc427c867cfd..f4b3faefb1e95 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -470,6 +470,11 @@ class RISCVTargetLowering : public TargetLowering {
bool shouldFoldMaskToVariableShiftPair(SDValue Y) const override;
+ /// Control the following reassociation of operands: (op (op x, c1), y) -> (op
+ /// (op x, y), c1) where N0 is (op x, c1) and N1 is y.
+ bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+ SDValue N1) const override;
+
/// Match a mask which "spreads" the leading elements of a vector evenly
/// across the result. Factor is the spread amount, and Index is the
/// offset applied.
diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
index d161c60c6c7bc..d7618d1d2bcf7 100644
--- a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -64,11 +64,11 @@ entry:
define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11) {
; CHECK-LABEL: vmadd_vscale:
; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
-; CHECK-NEXT: vmul.vv v8, v8, v9
-; CHECK-NEXT: vmacc.vv v8, v11, v10
; CHECK-NEXT: li a0, 32
-; CHECK-NEXT: vadd.vx v8, v8, a0
+; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmv.v.x v12, a0
+; CHECK-NEXT: vmadd.vv v8, v9, v12
+; CHECK-NEXT: vmacc.vv v8, v11, v10
; CHECK-NEXT: ret
entry:
%vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
@@ -82,11 +82,11 @@ entry:
define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11) {
; CHECK-LABEL: vmadd_fixed:
; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT: vmul.vv v8, v8, v10
+; CHECK-NEXT: vmv.v.x v16, a0
+; CHECK-NEXT: vmadd.vv v8, v10, v16
; CHECK-NEXT: vmacc.vv v8, v14, v12
-; CHECK-NEXT: li a0, 32
-; CHECK-NEXT: vadd.vx v8, v8, a0
; CHECK-NEXT: ret
entry:
%mul0 = mul nsw <8 x i32> %m00, %m01
@@ -99,13 +99,13 @@ entry:
define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11,
; CHECK-LABEL: vmadd_vscale_long:
; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
-; CHECK-NEXT: vmul.vv v8, v8, v9
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmv.v.x v16, a0
+; CHECK-NEXT: vmadd.vv v8, v9, v16
; CHECK-NEXT: vmacc.vv v8, v11, v10
; CHECK-NEXT: vmacc.vv v8, v13, v12
; CHECK-NEXT: vmacc.vv v8, v15, v14
-; CHECK-NEXT: li a0, 32
-; CHECK-NEXT: vadd.vx v8, v8, a0
; CHECK-NEXT: ret
<vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
entry:
@@ -124,13 +124,13 @@ entry:
define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11,
; CHECK-LABEL: vmadd_fixed_long:
; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT: vmul.vv v8, v8, v10
+; CHECK-NEXT: vmv.v.x v24, a0
+; CHECK-NEXT: vmadd.vv v8, v10, v24
; CHECK-NEXT: vmacc.vv v8, v14, v12
; CHECK-NEXT: vmacc.vv v8, v18, v16
; CHECK-NEXT: vmacc.vv v8, v22, v20
-; CHECK-NEXT: li a0, 32
-; CHECK-NEXT: vadd.vx v8, v8, a0
; CHECK-NEXT: ret
<8 x i32> %m20, <8 x i32> %m21, <8 x i32> %m30, <8 x i32> %m31) {
entry:
>From e6218f2d5d00c1e70901d00d5e7e8ea4ea8b1d0c Mon Sep 17 00:00:00 2001
From: bababuck <buchner.ryan at gmail.com>
Date: Tue, 18 Nov 2025 22:37:36 -0800
Subject: [PATCH 3/4] Check hasVInstructions() rather than hasStdExtV
hasVInstructions() is also true for the Zve* vector subsets, whereas hasStdExtV() only matches the full V extension.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 809abbc69ce90..e3f9e24555dae 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -25664,7 +25664,7 @@ bool RISCVTargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
// Avoid reassociating expressions that can be lowered to vector
// multiply accumulate (i.e. add (mul x, y), z)
if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::MUL &&
- (N0.getValueType().isVector() && Subtarget.hasStdExtV()))
+ (N0.getValueType().isVector() && Subtarget.hasVInstructions()))
return false;
return true;
>From 5f3502f325861da4114bcf4abd23e2d5a19429e2 Mon Sep 17 00:00:00 2001
From: bababuck <buchner.ryan at gmail.com>
Date: Tue, 18 Nov 2025 22:43:53 -0800
Subject: [PATCH 4/4] Remove dead instructions from test
---
llvm/test/CodeGen/RISCV/vmadd-reassociate.ll | 3 ---
1 file changed, 3 deletions(-)
diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
index d7618d1d2bcf7..e2bcd5c08efd2 100644
--- a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -41,7 +41,6 @@ define <vscale x 1 x i32> @vmadd_vscale_no_chain(<vscale x 1 x i32> %m00, <vscal
; CHECK-NEXT: vmadd.vv v8, v9, v10
; CHECK-NEXT: ret
entry:
- %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
%mul = mul nsw <vscale x 1 x i32> %m00, %m01
%add = add <vscale x 1 x i32> %mul, splat (i32 32)
ret <vscale x 1 x i32> %add
@@ -71,7 +70,6 @@ define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i
; CHECK-NEXT: vmacc.vv v8, v11, v10
; CHECK-NEXT: ret
entry:
- %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
%mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
%mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
%add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
@@ -109,7 +107,6 @@ define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x
; CHECK-NEXT: ret
<vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
entry:
- %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
%mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
%mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
%mul2 = mul nsw <vscale x 1 x i32> %m20, %m21