[llvm-branch-commits] [RISCV][CodeGen] Combine vwaddu+vabd(u) to vwabdacc(u) (PR #180162)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Feb 6 02:31:13 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Pengcheng Wang (wangpc-pp)
<details>
<summary>Changes</summary>
Note that `vwabda(u)` (the `vwabdacc(u)` of the title) is only supported for a source SEW of 8 or 16; other SEWs are reserved encodings.
---
Full diff: https://github.com/llvm/llvm-project/pull/180162.diff
4 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+44)
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (+14-2)
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td (+21-1)
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll (+14-10)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d46cb575c54c5..171fc391a7aa8 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18770,6 +18770,48 @@ static SDValue combineVWADDSUBWSelect(SDNode *N, SelectionDAG &DAG) {
N->getFlags());
}
+// vwaddu C (vabd A B) -> vwabda(A B C)
+// vwaddu C (vabdu A B) -> vwabdau(A B C)
+static SDValue performVWABDACombine(SDNode *N, SelectionDAG &DAG,
+                                    const RISCVSubtarget &Subtarget) {
+  if (!Subtarget.hasStdExtZvabd())
+    return SDValue();
+
+  // vwabda(u) needs a source SEW of 8/16; VT is already the widened type.
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  MVT SrcEltVT = Op0.getSimpleValueType().getVectorElementType();
+  if (SrcEltVT != MVT::i8 && SrcEltVT != MVT::i16)
+    return SDValue();
+
+  SDValue Passthru = N->getOperand(2);
+  SDValue Mask = N->getOperand(3);
+  SDValue VL = N->getOperand(4);
+  if (!Passthru.isUndef())
+    return SDValue();
+
+  // Only fold an otherwise-dead ABD computed under the same mask and VL with
+  // no passthru; otherwise the fused node would compute different lanes.
+  auto IsFoldableABD = [&](SDValue Op) {
+    return (Op.getOpcode() == RISCVISD::ABDS_VL ||
+            Op.getOpcode() == RISCVISD::ABDU_VL) &&
+           Op.hasOneUse() && Op.getOperand(2).isUndef() &&
+           Op.getOperand(3) == Mask && Op.getOperand(4) == VL;
+  };
+  bool DiffIsOp0 = IsFoldableABD(Op0);
+  if (!DiffIsOp0 && !IsFoldableABD(Op1))
+    return SDValue();
+  SDValue Diff = DiffIsOp0 ? Op0 : Op1;
+  SDValue Acc = DiffIsOp0 ? Op1 : Op0;
+  SDLoc DL(N);
+  MVT VT = N->getSimpleValueType(0);
+  Acc = DAG.getNode(RISCVISD::VZEXT_VL, DL, VT, Acc, Mask, VL);
+  unsigned Opc = Diff.getOpcode() == RISCVISD::ABDS_VL ? RISCVISD::VWABDA_VL
+                                                       : RISCVISD::VWABDAU_VL;
+  return DAG.getNode(Opc, DL, VT, Diff.getOperand(0), Diff.getOperand(1), Acc,
+                     Mask, VL);
+}
+
static SDValue performVWADDSUBW_VLCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
@@ -21681,6 +21723,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
return V;
return combineToVWMACC(N, DAG, Subtarget);
+ case RISCVISD::VWADDU_VL:
+ return performVWABDACombine(N, DAG, Subtarget);
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
case RISCVISD::VWSUB_W_VL:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 46dd45876a384..d1bcaffdeac5b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -1750,8 +1750,9 @@ multiclass VPatMultiplyAddVL_VV_VX<SDNode op, string instruction_name> {
}
}
-multiclass VPatWidenMultiplyAddVL_VV_VX<SDNode vwmacc_op, string instr_name> {
- foreach vtiTowti = AllWidenableIntVectors in {
+multiclass VPatWidenMultiplyAddVL_VV<SDNode vwmacc_op, string instr_name,
+ list<VTypeInfoToWide> vtilist = AllWidenableIntVectors> {
+ foreach vtiTowti = vtilist in {
defvar vti = vtiTowti.Vti;
defvar wti = vtiTowti.Wti;
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
@@ -1763,6 +1764,17 @@ multiclass VPatWidenMultiplyAddVL_VV_VX<SDNode vwmacc_op, string instr_name> {
(!cast<Instruction>(instr_name#"_VV_"#vti.LMul.MX#"_MASK")
wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+ }
+ }
+}
+
+multiclass VPatWidenMultiplyAddVL_VV_VX<SDNode vwmacc_op, string instr_name>
+ : VPatWidenMultiplyAddVL_VV<vwmacc_op, instr_name> {
+ foreach vtiTowti = AllWidenableIntVectors in {
+ defvar vti = vtiTowti.Vti;
+ defvar wti = vtiTowti.Wti;
+ let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+ GetVTypePredicates<wti>.Predicates) in {
def : Pat<(vwmacc_op (SplatPat XLenVT:$rs1),
(vti.Vector vti.RegClass:$rs2),
(wti.Vector wti.RegClass:$rd),
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td
index 139372b70e590..46261d83711cc 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td
@@ -29,7 +29,6 @@ let Predicates = [HasStdExtZvabd] in {
//===----------------------------------------------------------------------===//
// Pseudos
//===----------------------------------------------------------------------===//
-
multiclass PseudoVABS {
foreach m = MxList in {
defvar mx = m.MX;
@@ -44,10 +43,23 @@ multiclass PseudoVABS {
}
}
+// Pseudos for the widening absolute-difference-accumulate instructions:
+// one VPseudoTernaryW_VV instance per LMUL in MxListW, marked commutable in
+// its two source operands.
+// NOTE(review): scheduling reuses the widening multiply-add classes
+// (WriteVIWMulAddV / ReadVIWMulAddV); confirm no dedicated class is needed.
+multiclass VPseudoVWABD_VV {
+ foreach m = MxListW in {
+ defvar mx = m.MX;
+ defm "" : VPseudoTernaryW_VV<m, Commutable = 1>,
+ SchedTernary<"WriteVIWMulAddV", "ReadVIWMulAddV",
+ "ReadVIWMulAddV", "ReadVIWMulAddV", mx>;
+ }
+}
+
let Predicates = [HasStdExtZvabd] in {
defm PseudoVABS : PseudoVABS;
defm PseudoVABD : VPseudoVALU_VV<Commutable = 1>;
defm PseudoVABDU : VPseudoVALU_VV<Commutable = 1>;
+ // NOTE(review): vwabda(u) is matched like vwmacc (an element-wise
+ // accumulate), not a reduction; confirm why IsRVVWideningReduction is
+ // needed here and that it has no reduction-specific side effects.
+ let IsRVVWideningReduction = 1 in {
+ defm PseudoVWABDA : VPseudoVWABD_VV;
+ defm PseudoVWABDAU : VPseudoVWABD_VV;
+ } // IsRVVWideningReduction = 1
} // Predicates = [HasStdExtZvabd]
//===----------------------------------------------------------------------===//
@@ -57,12 +69,17 @@ let HasPassthruOp = true, HasMaskOp = true in {
def riscv_abs_vl : RVSDNode<"ABS_VL", SDT_RISCVIntUnOp_VL>;
def riscv_abds_vl : RVSDNode<"ABDS_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
def riscv_abdu_vl : RVSDNode<"ABDU_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
+// Widening absolute-difference-accumulate nodes. The DAG combine builds them
+// as (rs1, rs2, acc, mask, vl) and they are selected through
+// VPatWidenMultiplyAddVL_VV.
+// NOTE(review): no separate passthru operand is passed (the accumulator
+// plays that role) -- confirm HasPassthruOp = true is intended; also note the
+// rvv_ prefix differs from the neighbouring riscv_* VL nodes.
+def rvv_vwabda_vl : RVSDNode<"VWABDA_VL", SDT_RISCVVWIntTernOp_VL, [SDNPCommutative]>;
+def rvv_vwabdau_vl : RVSDNode<"VWABDAU_VL", SDT_RISCVVWIntTernOp_VL, [SDNPCommutative]>;
} // let HasPassthruOp = true, HasMaskOp = true
// These instructions are defined for SEW=8 and SEW=16, otherwise the instruction
// encoding is reserved.
defvar ABDIntVectors = !filter(vti, AllIntegerVectors, !or(!eq(vti.SEW, 8),
!eq(vti.SEW, 16)));
+// (narrow, wide) type pairs from AllWidenableIntVectors whose narrow SEW is
+// 8 or 16 -- the only source SEWs for which vwabda(u) has a valid encoding.
+defvar ABDAIntVectors = !filter(vtiTowti, AllWidenableIntVectors,
+ !or(!eq(vtiTowti.Vti.SEW, 8),
+ !eq(vtiTowti.Vti.SEW, 16)));
let Predicates = [HasStdExtZvabd] in {
defm : VPatBinarySDNode_VV<abds, "PseudoVABD", ABDIntVectors>;
@@ -79,4 +96,7 @@ foreach vti = AllIntegerVectors in {
}
defm : VPatUnaryVL_V<riscv_abs_vl, "PseudoVABS", HasStdExtZvabd>;
+
+defm : VPatWidenMultiplyAddVL_VV<rvv_vwabda_vl, "PseudoVWABDA", ABDAIntVectors>;
+defm : VPatWidenMultiplyAddVL_VV<rvv_vwabdau_vl, "PseudoVWABDAU", ABDAIntVectors>;
} // Predicates = [HasStdExtZvabd]
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
index 9f6c34cb052ff..dcb8b31c682b3 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
@@ -199,16 +199,18 @@ define signext i32 @sad_2block_16xi8_as_i32(ptr %a, ptr %b, i32 signext %stridea
; ZVABD-NEXT: vle8.v v15, (a1)
; ZVABD-NEXT: add a0, a0, a2
; ZVABD-NEXT: add a1, a1, a3
+; ZVABD-NEXT: vle8.v v16, (a0)
+; ZVABD-NEXT: vle8.v v17, (a1)
; ZVABD-NEXT: vabdu.vv v8, v8, v9
-; ZVABD-NEXT: vle8.v v9, (a0)
-; ZVABD-NEXT: vabdu.vv v10, v10, v11
-; ZVABD-NEXT: vle8.v v11, (a1)
-; ZVABD-NEXT: vwaddu.vv v12, v10, v8
+; ZVABD-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; ZVABD-NEXT: vzext.vf2 v12, v8
+; ZVABD-NEXT: vsetvli zero, zero, e8, m1, ta, ma
+; ZVABD-NEXT: vwabdau.vv v12, v10, v11
; ZVABD-NEXT: vabdu.vv v8, v14, v15
; ZVABD-NEXT: vsetvli zero, zero, e16, m2, ta, ma
; ZVABD-NEXT: vzext.vf2 v14, v8
; ZVABD-NEXT: vsetvli zero, zero, e8, m1, ta, ma
-; ZVABD-NEXT: vabdu.vv v16, v9, v11
+; ZVABD-NEXT: vabdu.vv v16, v16, v17
; ZVABD-NEXT: vsetvli zero, zero, e16, m2, ta, ma
; ZVABD-NEXT: vwaddu.vv v8, v14, v12
; ZVABD-NEXT: vzext.vf2 v12, v16
@@ -320,16 +322,18 @@ define signext i32 @sadu_2block_16xi8_as_i32(ptr %a, ptr %b, i32 signext %stride
; ZVABD-NEXT: vle8.v v15, (a1)
; ZVABD-NEXT: add a0, a0, a2
; ZVABD-NEXT: add a1, a1, a3
+; ZVABD-NEXT: vle8.v v16, (a0)
+; ZVABD-NEXT: vle8.v v17, (a1)
; ZVABD-NEXT: vabd.vv v8, v8, v9
-; ZVABD-NEXT: vle8.v v9, (a0)
-; ZVABD-NEXT: vabd.vv v10, v10, v11
-; ZVABD-NEXT: vle8.v v11, (a1)
-; ZVABD-NEXT: vwaddu.vv v12, v10, v8
+; ZVABD-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; ZVABD-NEXT: vzext.vf2 v12, v8
+; ZVABD-NEXT: vsetvli zero, zero, e8, m1, ta, ma
+; ZVABD-NEXT: vwabda.vv v12, v10, v11
; ZVABD-NEXT: vabd.vv v8, v14, v15
; ZVABD-NEXT: vsetvli zero, zero, e16, m2, ta, ma
; ZVABD-NEXT: vzext.vf2 v14, v8
; ZVABD-NEXT: vsetvli zero, zero, e8, m1, ta, ma
-; ZVABD-NEXT: vabd.vv v16, v9, v11
+; ZVABD-NEXT: vabd.vv v16, v16, v17
; ZVABD-NEXT: vsetvli zero, zero, e16, m2, ta, ma
; ZVABD-NEXT: vwaddu.vv v8, v14, v12
; ZVABD-NEXT: vzext.vf2 v12, v16
``````````
</details>
https://github.com/llvm/llvm-project/pull/180162
More information about the llvm-branch-commits
mailing list