[llvm] [RISCV] fold trunc_vl (srl_vl (vwaddu X, Y), splat 1) -> vaaddu X, Y (PR #76550)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 28 18:38:41 PST 2023
https://github.com/sun-jacobi created https://github.com/llvm/llvm-project/pull/76550
This patch selects `vaaddu` (vector averaging unsigned add) for the unsigned averaging-add pattern `trunc(lshr(zext(x) + zext(y), 1))`.
## Source code
```
define <8 x i8> @vaaddu_auto(ptr %x, ptr %y, ptr %z) {
%xv = load <8 x i8>, ptr %x, align 2
%yv = load <8 x i8>, ptr %y, align 2
%xzv = zext <8 x i8> %xv to <8 x i16>
%yzv = zext <8 x i8> %yv to <8 x i16>
%add = add nuw nsw <8 x i16> %xzv, %yzv
%div = lshr <8 x i16> %add, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
%ret = trunc <8 x i16> %div to <8 x i8>
ret <8 x i8> %ret
}
```
## Before this patch
```
vaaddu_auto:
vsetivli zero, 8, e8, mf2, ta, ma
vle8.v v8, (a0)
vle8.v v9, (a1)
vwaddu.vv v10, v8, v9
vnsrl.wi v8, v10, 1
ret
```
## After this patch
```
vaaddu_auto:
vsetivli zero, 8, e8, mf2, ta, ma
vle8.v v8, (a0)
vle8.v v9, (a1)
csrwi vxrm, 2
vaaddu.vv v8, v8, v9
ret
```
## Note on signed averaging addition
The RVV spec also defines a signed averaging-add instruction, `vaadd`.
However, as far as I understand, `vaadd` cannot reproduce the semantics of signed averaging addition under any rounding mode.
Therefore this patch only introduces `vaaddu`.
>From ba41fb8ce0ce997e06ecd19d981bfb8b2596d62c Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Thu, 28 Dec 2023 21:37:31 +0900
Subject: [PATCH 1/2] [RISCV][Isel] fold trunc_vl (srl_vl (vwadd X, Y), splat
1) -> vaadd X, Y
---
.../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 31 +++++++++++++++++++
1 file changed, 31 insertions(+)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 33bdc3366aa3e3..bbfeffb384700d 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -2338,6 +2338,37 @@ defm : VPatBinaryVL_VV_VX_VI<riscv_uaddsat_vl, "PseudoVSADDU">;
defm : VPatBinaryVL_VV_VX<riscv_ssubsat_vl, "PseudoVSSUB">;
defm : VPatBinaryVL_VV_VX<riscv_usubsat_vl, "PseudoVSSUBU">;
+// 12.2 Vector Single-Width Averaging Add and Subtract
+
+// trunc_vl (srl_vl (vwadd X, Y), splat 1) -> vaadd X, Y
+class VPatAverageAddTruncShift<string inst_name,
+ SDNode vop,
+ VTypeInfo vti,
+ VTypeInfo wti>
+ : Pat<
+ (vti.Vector (riscv_trunc_vector_vl
+ (wti.Vector (riscv_srl_vl
+ (wti.Vector (vop
+ (vti.Vector vti.RegClass:$rs1),
+ (vti.Vector vti.RegClass:$rs2),
+ (wti.Vector undef), (vti.Mask V0), VLOpFrag)),
+ (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), 1, VLOpFrag)),
+ (wti.Vector undef), (wti.Mask V0), VLOpFrag)),
+ (wti.Mask V0), VLOpFrag)),
+ (!cast<Instruction>(inst_name#"_VV_"#vti.LMul.MX#"_MASK")
+ (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.RegClass:$rs2,
+ (vti.Mask V0), 0b10, GPR:$vl, vti.Log2SEW, TA_MA)>;
+
+foreach vtiToWti = AllWidenableIntVectors in {
+ defvar vti = vtiToWti.Vti;
+ defvar wti = vtiToWti.Wti;
+ let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+ GetVTypePredicates<wti>.Predicates) in {
+ def : VPatAverageAddTruncShift<"PseudoVAADDU", riscv_vwaddu_vl, vti, wti>;
+ def : VPatAverageAddTruncShift<"PseudoVAADD", riscv_vwadd_vl, vti, wti>;
+ }
+}
+
// 13. Vector Floating-Point Instructions
// 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions
>From db7a8e44f1b21203cd9ae7f342495e6d367727c3 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Fri, 29 Dec 2023 11:23:33 +0900
Subject: [PATCH 2/2] [RISCV] remove vaadd folding case
---
llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index bbfeffb384700d..67fbd6141967a6 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -2340,7 +2340,7 @@ defm : VPatBinaryVL_VV_VX<riscv_usubsat_vl, "PseudoVSSUBU">;
// 12.2 Vector Single-Width Averaging Add and Subtract
-// trunc_vl (srl_vl (vwadd X, Y), splat 1) -> vaadd X, Y
+// trunc_vl (srl_vl (vwaddu X, Y), splat 1) -> vaaddu X, Y
class VPatAverageAddTruncShift<string inst_name,
SDNode vop,
VTypeInfo vti,
@@ -2363,10 +2363,8 @@ foreach vtiToWti = AllWidenableIntVectors in {
defvar vti = vtiToWti.Vti;
defvar wti = vtiToWti.Wti;
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
- GetVTypePredicates<wti>.Predicates) in {
+ GetVTypePredicates<wti>.Predicates) in
def : VPatAverageAddTruncShift<"PseudoVAADDU", riscv_vwaddu_vl, vti, wti>;
- def : VPatAverageAddTruncShift<"PseudoVAADD", riscv_vwadd_vl, vti, wti>;
- }
}
// 13. Vector Floating-Point Instructions
More information about the llvm-commits
mailing list