[llvm] 39d4dfb - [RISCV] Incorporate scalar addends to extend vector multiply accumulate chains (#168660)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 21 09:49:20 PST 2025
Author: Ryan Buchner
Date: 2025-11-21T09:49:15-08:00
New Revision: 39d4dfbe55cbea6ca7d506b8acd8455ed0443bf9
URL: https://github.com/llvm/llvm-project/commit/39d4dfbe55cbea6ca7d506b8acd8455ed0443bf9
DIFF: https://github.com/llvm/llvm-project/commit/39d4dfbe55cbea6ca7d506b8acd8455ed0443bf9.diff
LOG: [RISCV] Incorporate scalar addends to extend vector multiply accumulate chains (#168660)
Previously, the following:
%mul0 = mul nsw <8 x i32> %m00, %m01
%mul1 = mul nsw <8 x i32> %m10, %m11
%add0 = add <8 x i32> %mul0, splat (i32 32)
%add1 = add <8 x i32> %add0, %mul1
lowered to:
vsetivli zero, 8, e32, m2, ta, ma
vmul.vv v8, v8, v9
vmacc.vv v8, v11, v10
li a0, 32
vadd.vx v8, v8, a0
After this patch, it now lowers to:
li a0, 32
vsetivli zero, 8, e32, m2, ta, ma
vmv.v.x v12, a0
vmadd.vv v8, v9, v12
vmacc.vv v8, v11, v10
Modeled on 0cc981e0 from the AArch64 backend.
C-code for the example case (`clang -O3 -S -mcpu=sifive-x280`):
```
int madd_fail(int a, int b, int * restrict src, int * restrict dst, int loop_bound) {
for (int i = 0; i < loop_bound; i += 2) {
dst[i] = src[i] * a + src[i + 1] * b + 32;
}
}
```
Added:
llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6020fb6ca16ce..dd3225507dde7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -25722,3 +25722,17 @@ bool RISCVTargetLowering::shouldFoldMaskToVariableShiftPair(SDValue Y) const {
return VT.getSizeInBits() <= Subtarget.getXLen();
}
+
+bool RISCVTargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+ SDValue N1) const {
+ if (!N0.hasOneUse())
+ return false;
+
+ // Leave (add (add x, c1), (mul y, z)) alone for vector types with V enabled:
+ // the outer add can then be lowered to a vector multiply-accumulate.
+ if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::MUL &&
+ (N0.getValueType().isVector() && Subtarget.hasVInstructions()))
+ return false;
+
+ return true;
+}
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 616664306bcab..9b46936f195e6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -470,6 +470,11 @@ class RISCVTargetLowering : public TargetLowering {
bool shouldFoldMaskToVariableShiftPair(SDValue Y) const override;
+ /// Control the following reassociation of operands: (op (op x, c1), y) -> (op
+ /// (op x, y), c1) where N0 is (op x, c1) and N1 is y.
+ bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+ SDValue N1) const override;
+
/// Match a mask which "spreads" the leading elements of a vector evenly
/// across the result. Factor is the spread amount, and Index is the
/// offset applied.
diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
new file mode 100644
index 0000000000000..9fa0cec0ea339
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -0,0 +1,143 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv64 -mattr=+m,+v < %s | FileCheck %s
+
+define i32 @madd_scalar(i32 %m00, i32 %m01, i32 %m10, i32 %m11) nounwind {
+; CHECK-LABEL: madd_scalar:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: mul a0, a0, a1
+; CHECK-NEXT: mul a1, a2, a3
+; CHECK-NEXT: add a0, a0, a1
+; CHECK-NEXT: addiw a0, a0, 32
+; CHECK-NEXT: ret
+entry:
+ %mul0 = mul i32 %m00, %m01
+ %mul1 = mul i32 %m10, %m11
+ %add0 = add i32 %mul0, 32
+ %add1 = add i32 %add0, %mul1
+ ret i32 %add1
+}
+
+define <8 x i32> @vmadd_non_constant(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11, <8 x i32> %addend) {
+; CHECK-LABEL: vmadd_non_constant:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmadd.vv v8, v10, v16
+; CHECK-NEXT: vmacc.vv v8, v14, v12
+; CHECK-NEXT: ret
+entry:
+ %mul0 = mul <8 x i32> %m00, %m01
+ %mul1 = mul <8 x i32> %m10, %m11
+ %add0 = add <8 x i32> %mul0, %addend
+ %add1 = add <8 x i32> %add0, %mul1
+ ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_no_chain(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01) {
+; CHECK-LABEL: vmadd_vscale_no_chain:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmv.v.x v10, a0
+; CHECK-NEXT: vmadd.vv v8, v9, v10
+; CHECK-NEXT: ret
+entry:
+ %mul = mul <vscale x 1 x i32> %m00, %m01
+ %add = add <vscale x 1 x i32> %mul, splat (i32 32)
+ ret <vscale x 1 x i32> %add
+}
+
+define <8 x i32> @vmadd_fixed_no_chain(<8 x i32> %m00, <8 x i32> %m01) {
+; CHECK-LABEL: vmadd_fixed_no_chain:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmv.v.x v12, a0
+; CHECK-NEXT: vmadd.vv v8, v10, v12
+; CHECK-NEXT: ret
+entry:
+ %mul = mul <8 x i32> %m00, %m01
+ %add = add <8 x i32> %mul, splat (i32 32)
+ ret <8 x i32> %add
+}
+
+define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11) {
+; CHECK-LABEL: vmadd_vscale:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmv.v.x v12, a0
+; CHECK-NEXT: vmadd.vv v8, v9, v12
+; CHECK-NEXT: vmacc.vv v8, v11, v10
+; CHECK-NEXT: ret
+entry:
+ %mul0 = mul <vscale x 1 x i32> %m00, %m01
+ %mul1 = mul <vscale x 1 x i32> %m10, %m11
+ %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+ %add1 = add <vscale x 1 x i32> %add0, %mul1
+ ret <vscale x 1 x i32> %add1
+}
+
+define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11) {
+; CHECK-LABEL: vmadd_fixed:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmv.v.x v16, a0
+; CHECK-NEXT: vmadd.vv v8, v10, v16
+; CHECK-NEXT: vmacc.vv v8, v14, v12
+; CHECK-NEXT: ret
+entry:
+ %mul0 = mul <8 x i32> %m00, %m01
+ %mul1 = mul <8 x i32> %m10, %m11
+ %add0 = add <8 x i32> %mul0, splat (i32 32)
+ %add1 = add <8 x i32> %add0, %mul1
+ ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11,
+; CHECK-LABEL: vmadd_vscale_long:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmv.v.x v16, a0
+; CHECK-NEXT: vmadd.vv v8, v9, v16
+; CHECK-NEXT: vmacc.vv v8, v11, v10
+; CHECK-NEXT: vmacc.vv v8, v13, v12
+; CHECK-NEXT: vmacc.vv v8, v15, v14
+; CHECK-NEXT: ret
+ <vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
+entry:
+ %mul0 = mul <vscale x 1 x i32> %m00, %m01
+ %mul1 = mul <vscale x 1 x i32> %m10, %m11
+ %mul2 = mul <vscale x 1 x i32> %m20, %m21
+ %mul3 = mul <vscale x 1 x i32> %m30, %m31
+ %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+ %add1 = add <vscale x 1 x i32> %add0, %mul1
+ %add2 = add <vscale x 1 x i32> %add1, %mul2
+ %add3 = add <vscale x 1 x i32> %add2, %mul3
+ ret <vscale x 1 x i32> %add3
+}
+
+define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11,
+; CHECK-LABEL: vmadd_fixed_long:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmv.v.x v24, a0
+; CHECK-NEXT: vmadd.vv v8, v10, v24
+; CHECK-NEXT: vmacc.vv v8, v14, v12
+; CHECK-NEXT: vmacc.vv v8, v18, v16
+; CHECK-NEXT: vmacc.vv v8, v22, v20
+; CHECK-NEXT: ret
+ <8 x i32> %m20, <8 x i32> %m21, <8 x i32> %m30, <8 x i32> %m31) {
+entry:
+ %mul0 = mul <8 x i32> %m00, %m01
+ %mul1 = mul <8 x i32> %m10, %m11
+ %mul2 = mul <8 x i32> %m20, %m21
+ %mul3 = mul <8 x i32> %m30, %m31
+ %add0 = add <8 x i32> %mul0, splat (i32 32)
+ %add1 = add <8 x i32> %add0, %mul1
+ %add2 = add <8 x i32> %add1, %mul2
+ %add3 = add <8 x i32> %add2, %mul3
+ ret <8 x i32> %add3
+}
More information about the llvm-commits
mailing list