[llvm] [RISCV] Properly lower multiply-accumulate chains containing a constant (PR #168660)

Ryan Buchner via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 18 23:11:19 PST 2025


https://github.com/bababuck updated https://github.com/llvm/llvm-project/pull/168660

From cf98a3c6fed0496b7732e15de087061357b0a709 Mon Sep 17 00:00:00 2001
From: bababuck <buchner.ryan at gmail.com>
Date: Tue, 18 Nov 2025 13:10:54 -0800
Subject: [PATCH 1/4] [RISCV] Add test for lowering vector multiply add chains

Namely, it tests cases such as the following:
%mul0 = mul %m00, %m01
%mul1 = mul %m10, %m11
%add0 = add %mul0, %constant
%add1 = add %add0, %mul1
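
The interesting property is that the constant addend sits between the two
multiplies. Ideally the constant splat seeds the accumulator of the first
multiply-add so the remaining multiply can still fold into vmacc.vv, which is
what the follow-up lowering change in this series aims for.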
---
 llvm/test/CodeGen/RISCV/vmadd-reassociate.ll | 146 +++++++++++++++++++
 1 file changed, 146 insertions(+)
 create mode 100644 llvm/test/CodeGen/RISCV/vmadd-reassociate.ll

diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
new file mode 100644
index 0000000000000..d161c60c6c7bc
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -0,0 +1,146 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv64 -mattr=+m,+v < %s | FileCheck %s
+
+define i32 @madd_scalar(i32 %m00, i32 %m01, i32 %m10, i32 %m11) nounwind {
+; CHECK-LABEL: madd_scalar:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    mul a0, a0, a1
+; CHECK-NEXT:    mul a1, a2, a3
+; CHECK-NEXT:    add a0, a0, a1
+; CHECK-NEXT:    addiw a0, a0, 32
+; CHECK-NEXT:    ret
+entry:
+  %mul0 = mul nsw i32 %m00, %m01
+  %mul1 = mul nsw i32 %m10, %m11
+  %add0 = add i32 %mul0, 32
+  %add1 = add i32 %add0, %mul1
+  ret i32 %add1
+}
+
+define <8 x i32> @vmadd_non_constant(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11, <8 x i32> %addend) {
+; CHECK-LABEL: vmadd_non_constant:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmadd.vv v8, v10, v16
+; CHECK-NEXT:    vmacc.vv v8, v14, v12
+; CHECK-NEXT:    ret
+entry:
+  %mul0 = mul nsw <8 x i32> %m00, %m01
+  %mul1 = mul nsw <8 x i32> %m10, %m11
+  %add0 = add <8 x i32> %mul0, %addend
+  %add1 = add <8 x i32> %add0, %mul1
+  ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_no_chain(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01) {
+; CHECK-LABEL: vmadd_vscale_no_chain:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.x v10, a0
+; CHECK-NEXT:    vmadd.vv v8, v9, v10
+; CHECK-NEXT:    ret
+entry:
+  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+  %mul = mul nsw <vscale x 1 x i32> %m00, %m01
+  %add = add <vscale x 1 x i32> %mul, splat (i32 32)
+  ret <vscale x 1 x i32> %add
+}
+
+define <8 x i32> @vmadd_fixed_no_chain(<8 x i32> %m00, <8 x i32> %m01) {
+; CHECK-LABEL: vmadd_fixed_no_chain:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmv.v.x v12, a0
+; CHECK-NEXT:    vmadd.vv v8, v10, v12
+; CHECK-NEXT:    ret
+entry:
+  %mul = mul nsw <8 x i32> %m00, %m01
+  %add = add <8 x i32> %mul, splat (i32 32)
+  ret <8 x i32> %add
+}
+
+define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11) {
+; CHECK-LABEL: vmadd_vscale:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vmul.vv v8, v8, v9
+; CHECK-NEXT:    vmacc.vv v8, v11, v10
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vadd.vx v8, v8, a0
+; CHECK-NEXT:    ret
+entry:
+  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+  %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
+  %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
+  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+  %add1 = add <vscale x 1 x i32> %add0, %mul1
+  ret <vscale x 1 x i32> %add1
+}
+
+define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11) {
+; CHECK-LABEL: vmadd_fixed:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmul.vv v8, v8, v10
+; CHECK-NEXT:    vmacc.vv v8, v14, v12
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vadd.vx v8, v8, a0
+; CHECK-NEXT:    ret
+entry:
+  %mul0 = mul nsw <8 x i32> %m00, %m01
+  %mul1 = mul nsw <8 x i32> %m10, %m11
+  %add0 = add <8 x i32> %mul0, splat (i32 32)
+  %add1 = add <8 x i32> %add0, %mul1
+  ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11,
+; CHECK-LABEL: vmadd_vscale_long:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vmul.vv v8, v8, v9
+; CHECK-NEXT:    vmacc.vv v8, v11, v10
+; CHECK-NEXT:    vmacc.vv v8, v13, v12
+; CHECK-NEXT:    vmacc.vv v8, v15, v14
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vadd.vx v8, v8, a0
+; CHECK-NEXT:    ret
+                                             <vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
+entry:
+  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+  %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
+  %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
+  %mul2 = mul nsw <vscale x 1 x i32> %m20, %m21
+  %mul3 = mul nsw <vscale x 1 x i32> %m30, %m31
+  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+  %add1 = add <vscale x 1 x i32> %add0, %mul1
+  %add2 = add <vscale x 1 x i32> %add1, %mul2
+  %add3 = add <vscale x 1 x i32> %add2, %mul3
+  ret <vscale x 1 x i32> %add3
+}
+
+define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11,
+; CHECK-LABEL: vmadd_fixed_long:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmul.vv v8, v8, v10
+; CHECK-NEXT:    vmacc.vv v8, v14, v12
+; CHECK-NEXT:    vmacc.vv v8, v18, v16
+; CHECK-NEXT:    vmacc.vv v8, v22, v20
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vadd.vx v8, v8, a0
+; CHECK-NEXT:    ret
+                                   <8 x i32> %m20, <8 x i32> %m21, <8 x i32> %m30, <8 x i32> %m31) {
+entry:
+  %mul0 = mul nsw <8 x i32> %m00, %m01
+  %mul1 = mul nsw <8 x i32> %m10, %m11
+  %mul2 = mul nsw <8 x i32> %m20, %m21
+  %mul3 = mul nsw <8 x i32> %m30, %m31
+  %add0 = add <8 x i32> %mul0, splat (i32 32)
+  %add1 = add <8 x i32> %add0, %mul1
+  %add2 = add <8 x i32> %add1, %mul2
+  %add3 = add <8 x i32> %add2, %mul3
+  ret <8 x i32> %add3
+}

From 4d5aa5dd815a98d6fe2ec503b64c4e14d3d771bb Mon Sep 17 00:00:00 2001
From: bababuck <buchner.ryan at gmail.com>
Date: Tue, 18 Nov 2025 10:32:48 -0800
Subject: [PATCH 2/4] [RISCV] Properly lower multiply-accumulate chains
 containing a constant

Previously, the following:
  %mul0 = mul nsw <8 x i32> %m00, %m01
  %mul1 = mul nsw <8 x i32> %m10, %m11
  %add0 = add <8 x i32> %mul0, splat (i32 32)
  %add1 = add <8 x i32> %add0, %mul1

lowered to:
  vsetivli zero, 8, e32, m2, ta, ma
  vmul.vv v8, v8, v10
  vmacc.vv v8, v14, v12
  li a0, 32
  vadd.vx v8, v8, a0

After this patch, it lowers to:
  li a0, 32
  vsetivli zero, 8, e32, m2, ta, ma
  vmv.v.x v16, a0
  vmadd.vv v8, v10, v16
  vmacc.vv v8, v14, v12
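
For context (a sketch, not part of the patch): the generic DAG reassociation
being suppressed is (add (add x, C), y) -> (add (add x, y), C); applied to the
IR above it has the following effect:

  ; shape kept by this patch: the constant splat seeds the vmadd accumulator
  %add0 = add <8 x i32> %mul0, splat (i32 32)   ; -> vmv.v.x + vmadd.vv
  %add1 = add <8 x i32> %add0, %mul1            ; -> vmacc.vv

  ; reassociated shape: the constant becomes a trailing vector add
  %sum  = add <8 x i32> %mul0, %mul1            ; -> vmul.vv + vmacc.vv
  %add1 = add <8 x i32> %sum, splat (i32 32)    ; -> li + vadd.vx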
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp  | 14 ++++++++++
 llvm/lib/Target/RISCV/RISCVISelLowering.h    |  5 ++++
 llvm/test/CodeGen/RISCV/vmadd-reassociate.ll | 28 ++++++++++----------
 3 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 921d12757d672..809abbc69ce90 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -25655,3 +25655,17 @@ bool RISCVTargetLowering::shouldFoldMaskToVariableShiftPair(SDValue Y) const {
 
   return VT.getSizeInBits() <= Subtarget.getXLen();
 }
+
+bool RISCVTargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                                              SDValue N1) const {
+  if (!N0.hasOneUse())
+    return false;
+
+  // Avoid reassociating expressions that can be lowered to vector
+  // multiply accumulate (i.e. add (mul x, y), z)
+  if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::MUL &&
+      (N0.getValueType().isVector() && Subtarget.hasStdExtV()))
+    return false;
+
+  return true;
+}
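
(Reviewer note, not part of the patch: since the bail-out is restricted to
vector types, scalar chains keep the existing reassociation; in madd_scalar
from the new test the constant is still pulled to the end, where it folds into
the trailing addiw immediate.)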
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 5cc427c867cfd..f4b3faefb1e95 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -470,6 +470,11 @@ class RISCVTargetLowering : public TargetLowering {
 
   bool shouldFoldMaskToVariableShiftPair(SDValue Y) const override;
 
+  /// Control the following reassociation of operands: (op (op x, c1), y) -> (op
+  /// (op x, y), c1) where N0 is (op x, c1) and N1 is y.
+  bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                           SDValue N1) const override;
+
   /// Match a mask which "spreads" the leading elements of a vector evenly
   /// across the result.  Factor is the spread amount, and Index is the
   /// offset applied.
diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
index d161c60c6c7bc..d7618d1d2bcf7 100644
--- a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -64,11 +64,11 @@ entry:
 define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11) {
 ; CHECK-LABEL: vmadd_vscale:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
-; CHECK-NEXT:    vmul.vv v8, v8, v9
-; CHECK-NEXT:    vmacc.vv v8, v11, v10
 ; CHECK-NEXT:    li a0, 32
-; CHECK-NEXT:    vadd.vx v8, v8, a0
+; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.x v12, a0
+; CHECK-NEXT:    vmadd.vv v8, v9, v12
+; CHECK-NEXT:    vmacc.vv v8, v11, v10
 ; CHECK-NEXT:    ret
 entry:
   %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
@@ -82,11 +82,11 @@ entry:
 define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11) {
 ; CHECK-LABEL: vmadd_fixed:
 ; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
 ; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT:    vmul.vv v8, v8, v10
+; CHECK-NEXT:    vmv.v.x v16, a0
+; CHECK-NEXT:    vmadd.vv v8, v10, v16
 ; CHECK-NEXT:    vmacc.vv v8, v14, v12
-; CHECK-NEXT:    li a0, 32
-; CHECK-NEXT:    vadd.vx v8, v8, a0
 ; CHECK-NEXT:    ret
 entry:
   %mul0 = mul nsw <8 x i32> %m00, %m01
@@ -99,13 +99,13 @@ entry:
 define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11,
 ; CHECK-LABEL: vmadd_vscale_long:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
-; CHECK-NEXT:    vmul.vv v8, v8, v9
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.x v16, a0
+; CHECK-NEXT:    vmadd.vv v8, v9, v16
 ; CHECK-NEXT:    vmacc.vv v8, v11, v10
 ; CHECK-NEXT:    vmacc.vv v8, v13, v12
 ; CHECK-NEXT:    vmacc.vv v8, v15, v14
-; CHECK-NEXT:    li a0, 32
-; CHECK-NEXT:    vadd.vx v8, v8, a0
 ; CHECK-NEXT:    ret
                                              <vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
 entry:
@@ -124,13 +124,13 @@ entry:
 define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11,
 ; CHECK-LABEL: vmadd_fixed_long:
 ; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
 ; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT:    vmul.vv v8, v8, v10
+; CHECK-NEXT:    vmv.v.x v24, a0
+; CHECK-NEXT:    vmadd.vv v8, v10, v24
 ; CHECK-NEXT:    vmacc.vv v8, v14, v12
 ; CHECK-NEXT:    vmacc.vv v8, v18, v16
 ; CHECK-NEXT:    vmacc.vv v8, v22, v20
-; CHECK-NEXT:    li a0, 32
-; CHECK-NEXT:    vadd.vx v8, v8, a0
 ; CHECK-NEXT:    ret
                                    <8 x i32> %m20, <8 x i32> %m21, <8 x i32> %m30, <8 x i32> %m31) {
 entry:

From e6218f2d5d00c1e70901d00d5e7e8ea4ea8b1d0c Mon Sep 17 00:00:00 2001
From: bababuck <buchner.ryan at gmail.com>
Date: Tue, 18 Nov 2025 22:37:36 -0800
Subject: [PATCH 3/4] Check hasVInstructions() rather than hasStdExtV
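
The combine should apply whenever vector multiply-accumulate instructions are
available, not only with the full V extension: hasVInstructions() also covers
the Zve* subsets, whereas hasStdExtV() does not.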

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 809abbc69ce90..e3f9e24555dae 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -25664,7 +25664,7 @@ bool RISCVTargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
   // Avoid reassociating expressions that can be lowered to vector
   // multiply accumulate (i.e. add (mul x, y), z)
   if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::MUL &&
-      (N0.getValueType().isVector() && Subtarget.hasStdExtV()))
+      (N0.getValueType().isVector() && Subtarget.hasVInstructions()))
     return false;
 
   return true;

From 5f3502f325861da4114bcf4abd23e2d5a19429e2 Mon Sep 17 00:00:00 2001
From: bababuck <buchner.ryan at gmail.com>
Date: Tue, 18 Nov 2025 22:43:53 -0800
Subject: [PATCH 4/4] Remove dead instructions from test

---
 llvm/test/CodeGen/RISCV/vmadd-reassociate.ll | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
index d7618d1d2bcf7..e2bcd5c08efd2 100644
--- a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -41,7 +41,6 @@ define <vscale x 1 x i32> @vmadd_vscale_no_chain(<vscale x 1 x i32> %m00, <vscal
 ; CHECK-NEXT:    vmadd.vv v8, v9, v10
 ; CHECK-NEXT:    ret
 entry:
-  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
   %mul = mul nsw <vscale x 1 x i32> %m00, %m01
   %add = add <vscale x 1 x i32> %mul, splat (i32 32)
   ret <vscale x 1 x i32> %add
@@ -71,7 +70,6 @@ define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i
 ; CHECK-NEXT:    vmacc.vv v8, v11, v10
 ; CHECK-NEXT:    ret
 entry:
-  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
   %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
   %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
   %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
@@ -109,7 +107,6 @@ define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x
 ; CHECK-NEXT:    ret
                                              <vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
 entry:
-  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
   %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
   %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
   %mul2 = mul nsw <vscale x 1 x i32> %m20, %m21


