[llvm] 39d4dfb - [RISCV] Incorporate scalar addends to extend vector multiply accumulate chains (#168660)

via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 21 09:49:20 PST 2025


Author: Ryan Buchner
Date: 2025-11-21T09:49:15-08:00
New Revision: 39d4dfbe55cbea6ca7d506b8acd8455ed0443bf9

URL: https://github.com/llvm/llvm-project/commit/39d4dfbe55cbea6ca7d506b8acd8455ed0443bf9
DIFF: https://github.com/llvm/llvm-project/commit/39d4dfbe55cbea6ca7d506b8acd8455ed0443bf9.diff

LOG: [RISCV] Incorporate scalar addends to extend vector multiply accumulate chains (#168660)

Previously, the following:
      %mul0 = mul nsw <8 x i32> %m00, %m01
      %mul1 = mul nsw <8 x i32> %m10, %m11
      %add0 = add <8 x i32> %mul0, splat (i32 32)
      %add1 = add <8 x i32> %add0, %mul1

    lowered to:
      vsetivli zero, 8, e32, m2, ta, ma
      vmul.vv v8, v8, v9
      vmacc.vv v8, v11, v10
      li a0, 32
      vadd.vx v8, v8, a0

    After this patch, the same IR now lowers to:
      li a0, 32
      vsetivli zero, 8, e32, m2, ta, ma
      vmv.v.x v12, a0
      vmadd.vv v8, v9, v12
      vmacc.vv v8, v11, v10

Modeled on 0cc981e0 from the AArch64 backend.

C code for the example case (compiled with `clang -O3 -S -mcpu=sifive-x280`):
```
void madd_fail(int a, int b, int * restrict src, int * restrict dst, int loop_bound) {
  for (int i = 0; i < loop_bound; i += 2) {
    dst[i] = src[i] * a + src[i + 1] * b + 32;
  }
}
```
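
For readers who want to check the lowering directly, here is a minimal standalone sketch (the file and function names are illustrative, and the llc flags mirror the RUN line of the added test) that feeds the vector IR from the summary through llc:

```
; repro.ll -- illustrative reproducer, not part of the patch.
; Run with: llc -mtriple=riscv64 -mattr=+m,+v repro.ll -o -
define <8 x i32> @repro(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11) {
  %mul0 = mul nsw <8 x i32> %m00, %m01
  %mul1 = mul nsw <8 x i32> %m10, %m11
  %add0 = add <8 x i32> %mul0, splat (i32 32)
  %add1 = add <8 x i32> %add0, %mul1
  ret <8 x i32> %add1
}
```

On a build that includes this change, the output should contain the vmv.v.x/vmadd.vv/vmacc.vv sequence shown above rather than the trailing vadd.vx.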

Added: 
    llvm/test/CodeGen/RISCV/vmadd-reassociate.ll

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6020fb6ca16ce..dd3225507dde7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -25722,3 +25722,17 @@ bool RISCVTargetLowering::shouldFoldMaskToVariableShiftPair(SDValue Y) const {
 
   return VT.getSizeInBits() <= Subtarget.getXLen();
 }
+
+bool RISCVTargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                                              SDValue N1) const {
+  if (!N0.hasOneUse())
+    return false;
+
+  // Avoid reassociating expressions that can be lowered to vector
+  // multiply accumulate (i.e. add (mul x, y), z)
+  if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::MUL &&
+      (N0.getValueType().isVector() && Subtarget.hasVInstructions()))
+    return false;
+
+  return true;
+}

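For context, here is a sketch of how the guarded pattern maps onto the hook's parameters (the function and value names below are illustrative, not taken from the patch): when the combiner considers reassociating the outer add of the motivating example, N0 is the inner ADD that already carries the splat addend and N1 is the second MUL, so returning false leaves the addend attached to the first multiply, where it folds into vmadd.vv, while the second multiply still folds into vmacc.vv.

```
; Illustrative sketch, not part of the patch: the operands of %add1 as the
; combiner sees them when it queries isReassocProfitable.
;   N0 = %add0 = add (mul %m00, %m01), splat 32   ; ADD holding the addend
;   N1 = %mul1 = mul %m10, %m11                   ; still fusible as vmacc.vv
define <vscale x 1 x i32> @roles(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01,
                                 <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11) {
  %mul0 = mul <vscale x 1 x i32> %m00, %m01
  %mul1 = mul <vscale x 1 x i32> %m10, %m11
  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
  %add1 = add <vscale x 1 x i32> %add0, %mul1   ; the hook returns false here
  ret <vscale x 1 x i32> %add1
}
```
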
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 616664306bcab..9b46936f195e6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -470,6 +470,11 @@ class RISCVTargetLowering : public TargetLowering {
 
   bool shouldFoldMaskToVariableShiftPair(SDValue Y) const override;
 
+  /// Control the following reassociation of operands: (op (op x, c1), y) -> (op
+  /// (op x, y), c1) where N0 is (op x, c1) and N1 is y.
+  bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                           SDValue N1) const override;
+
   /// Match a mask which "spreads" the leading elements of a vector evenly
   /// across the result.  Factor is the spread amount, and Index is the
   /// offset applied.

diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
new file mode 100644
index 0000000000000..9fa0cec0ea339
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -0,0 +1,143 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv64 -mattr=+m,+v < %s | FileCheck %s
+
+define i32 @madd_scalar(i32 %m00, i32 %m01, i32 %m10, i32 %m11) nounwind {
+; CHECK-LABEL: madd_scalar:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    mul a0, a0, a1
+; CHECK-NEXT:    mul a1, a2, a3
+; CHECK-NEXT:    add a0, a0, a1
+; CHECK-NEXT:    addiw a0, a0, 32
+; CHECK-NEXT:    ret
+entry:
+  %mul0 = mul i32 %m00, %m01
+  %mul1 = mul i32 %m10, %m11
+  %add0 = add i32 %mul0, 32
+  %add1 = add i32 %add0, %mul1
+  ret i32 %add1
+}
+
+define <8 x i32> @vmadd_non_constant(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11, <8 x i32> %addend) {
+; CHECK-LABEL: vmadd_non_constant:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmadd.vv v8, v10, v16
+; CHECK-NEXT:    vmacc.vv v8, v14, v12
+; CHECK-NEXT:    ret
+entry:
+  %mul0 = mul <8 x i32> %m00, %m01
+  %mul1 = mul <8 x i32> %m10, %m11
+  %add0 = add <8 x i32> %mul0, %addend
+  %add1 = add <8 x i32> %add0, %mul1
+  ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_no_chain(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01) {
+; CHECK-LABEL: vmadd_vscale_no_chain:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.x v10, a0
+; CHECK-NEXT:    vmadd.vv v8, v9, v10
+; CHECK-NEXT:    ret
+entry:
+  %mul = mul <vscale x 1 x i32> %m00, %m01
+  %add = add <vscale x 1 x i32> %mul, splat (i32 32)
+  ret <vscale x 1 x i32> %add
+}
+
+define <8 x i32> @vmadd_fixed_no_chain(<8 x i32> %m00, <8 x i32> %m01) {
+; CHECK-LABEL: vmadd_fixed_no_chain:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmv.v.x v12, a0
+; CHECK-NEXT:    vmadd.vv v8, v10, v12
+; CHECK-NEXT:    ret
+entry:
+  %mul = mul <8 x i32> %m00, %m01
+  %add = add <8 x i32> %mul, splat (i32 32)
+  ret <8 x i32> %add
+}
+
+define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11) {
+; CHECK-LABEL: vmadd_vscale:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.x v12, a0
+; CHECK-NEXT:    vmadd.vv v8, v9, v12
+; CHECK-NEXT:    vmacc.vv v8, v11, v10
+; CHECK-NEXT:    ret
+entry:
+  %mul0 = mul <vscale x 1 x i32> %m00, %m01
+  %mul1 = mul <vscale x 1 x i32> %m10, %m11
+  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+  %add1 = add <vscale x 1 x i32> %add0, %mul1
+  ret <vscale x 1 x i32> %add1
+}
+
+define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11) {
+; CHECK-LABEL: vmadd_fixed:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmv.v.x v16, a0
+; CHECK-NEXT:    vmadd.vv v8, v10, v16
+; CHECK-NEXT:    vmacc.vv v8, v14, v12
+; CHECK-NEXT:    ret
+entry:
+  %mul0 = mul <8 x i32> %m00, %m01
+  %mul1 = mul <8 x i32> %m10, %m11
+  %add0 = add <8 x i32> %mul0, splat (i32 32)
+  %add1 = add <8 x i32> %add0, %mul1
+  ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11,
+; CHECK-LABEL: vmadd_vscale_long:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.x v16, a0
+; CHECK-NEXT:    vmadd.vv v8, v9, v16
+; CHECK-NEXT:    vmacc.vv v8, v11, v10
+; CHECK-NEXT:    vmacc.vv v8, v13, v12
+; CHECK-NEXT:    vmacc.vv v8, v15, v14
+; CHECK-NEXT:    ret
+                                             <vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
+entry:
+  %mul0 = mul <vscale x 1 x i32> %m00, %m01
+  %mul1 = mul <vscale x 1 x i32> %m10, %m11
+  %mul2 = mul <vscale x 1 x i32> %m20, %m21
+  %mul3 = mul <vscale x 1 x i32> %m30, %m31
+  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+  %add1 = add <vscale x 1 x i32> %add0, %mul1
+  %add2 = add <vscale x 1 x i32> %add1, %mul2
+  %add3 = add <vscale x 1 x i32> %add2, %mul3
+  ret <vscale x 1 x i32> %add3
+}
+
+define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11,
+; CHECK-LABEL: vmadd_fixed_long:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmv.v.x v24, a0
+; CHECK-NEXT:    vmadd.vv v8, v10, v24
+; CHECK-NEXT:    vmacc.vv v8, v14, v12
+; CHECK-NEXT:    vmacc.vv v8, v18, v16
+; CHECK-NEXT:    vmacc.vv v8, v22, v20
+; CHECK-NEXT:    ret
+                                   <8 x i32> %m20, <8 x i32> %m21, <8 x i32> %m30, <8 x i32> %m31) {
+entry:
+  %mul0 = mul <8 x i32> %m00, %m01
+  %mul1 = mul <8 x i32> %m10, %m11
+  %mul2 = mul <8 x i32> %m20, %m21
+  %mul3 = mul <8 x i32> %m30, %m31
+  %add0 = add <8 x i32> %mul0, splat (i32 32)
+  %add1 = add <8 x i32> %add0, %mul1
+  %add2 = add <8 x i32> %add1, %mul2
+  %add3 = add <8 x i32> %add2, %mul3
+  ret <8 x i32> %add3
+}


        

