[llvm] [RISCV] fold trunc_vl (srl_vl (vwaddu X, Y), splat 1) -> vaaddu X, Y (PR #76550)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 28 18:39:10 PST 2023
llvmbot wrote:
@llvm/pr-subscribers-backend-risc-v
Author: Chia (sun-jacobi)
<details>
<summary>Changes</summary>
This patch folds a widening unsigned add, followed by a logical right shift by one and a truncate, into a single `vaaddu` (unsigned averaging add).
## Source code
```
define <8 x i8> @vaaddu_auto(ptr %x, ptr %y, ptr %z) {
%xv = load <8 x i8>, ptr %x, align 2
%yv = load <8 x i8>, ptr %y, align 2
%xzv = zext <8 x i8> %xv to <8 x i16>
%yzv = zext <8 x i8> %yv to <8 x i16>
%add = add nuw nsw <8 x i16> %xzv, %yzv
%div = lshr <8 x i16> %add, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
%ret = trunc <8 x i16> %div to <8 x i8>
ret <8 x i8> %ret
}
```
## Before this patch
```
vaaddu_auto:
vsetivli zero, 8, e8, mf2, ta, ma
vle8.v v8, (a0)
vle8.v v9, (a1)
vwaddu.vv v10, v8, v9
vnsrl.wi v8, v10, 1
ret
```
## After this patch
```
vaaddu_auto:
vsetivli zero, 8, e8, mf2, ta, ma
vle8.v v8, (a0)
vle8.v v9, (a1)
csrwi vxrm, 2
vaaddu.vv v8, v8, v9
ret
```
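The `csrwi vxrm, 2` above selects the round-down (`rdn`) fixed-point rounding mode, which is what makes `vaaddu` agree with the `lshr` in the source IR. As a rough sanity check, the following Python sketch models a single 8-bit lane. It is a scalar model based on the `roundoff_unsigned` definition in the RVV spec, not LLVM or hardware code, and the function names are made up for illustration.

```python
# Scalar model of one 8-bit lane; a sketch, not LLVM or hardware code.
# All function names here are hypothetical and exist only for illustration.

def widen_add_shift_trunc(x: int, y: int) -> int:
    """trunc(lshr(zext(x) + zext(y), 1)) on i8 inputs, mirroring the IR."""
    wide = (x + y) & 0xFFFF        # zext to i16, then add (cannot overflow)
    return (wide >> 1) & 0xFF      # lshr by 1, then trunc to i8

def vaaddu_lane(x: int, y: int, vxrm: int) -> int:
    """roundoff_unsigned(x + y, 1) under the rounding mode in vxrm."""
    v = x + y                      # the sum is formed without overflow
    bit0, bit1 = v & 1, (v >> 1) & 1
    if vxrm == 0:                  # rnu: round to nearest, ties up
        r = bit0
    elif vxrm == 1:                # rne: round to nearest, ties to even
        r = bit0 & bit1
    elif vxrm == 2:                # rdn: round down (truncate)
        r = 0
    else:                          # rod: round to odd
        r = bit0 & (bit1 ^ 1)
    return ((v >> 1) + r) & 0xFF

def mismatches(vxrm: int) -> int:
    """Count i8 input pairs where the two models disagree."""
    return sum(widen_add_shift_trunc(x, y) != vaaddu_lane(x, y, vxrm)
               for x in range(256) for y in range(256))
```

Under this model `mismatches(2)` is zero, while the other three modes disagree on some inputs (for example `x = 1, y = 0` under `rnu`), which is consistent with the pattern hard-coding `0b10`.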
## Note on signed averaging addition
According to the RVV spec, there is also a signed averaging add, `vaadd`.
However, AFAIU, no rounding mode lets `vaadd` match the semantics of signed averaging addition.
Thus this patch only introduces `vaaddu`.
---
Full diff: https://github.com/llvm/llvm-project/pull/76550.diff
1 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (+29)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 33bdc3366aa3e3..67fbd6141967a6 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -2338,6 +2338,35 @@ defm : VPatBinaryVL_VV_VX_VI<riscv_uaddsat_vl, "PseudoVSADDU">;
defm : VPatBinaryVL_VV_VX<riscv_ssubsat_vl, "PseudoVSSUB">;
defm : VPatBinaryVL_VV_VX<riscv_usubsat_vl, "PseudoVSSUBU">;
+// 12.2 Vector Single-Width Averaging Add and Subtract
+
+// trunc_vl (srl_vl (vwaddu X, Y), splat 1) -> vaaddu X, Y
+class VPatAverageAddTruncShift<string inst_name,
+ SDNode vop,
+ VTypeInfo vti,
+ VTypeInfo wti>
+ : Pat<
+ (vti.Vector (riscv_trunc_vector_vl
+ (wti.Vector (riscv_srl_vl
+ (wti.Vector (vop
+ (vti.Vector vti.RegClass:$rs1),
+ (vti.Vector vti.RegClass:$rs2),
+ (wti.Vector undef), (vti.Mask V0), VLOpFrag)),
+ (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), 1, VLOpFrag)),
+ (wti.Vector undef), (wti.Mask V0), VLOpFrag)),
+ (wti.Mask V0), VLOpFrag)),
+ (!cast<Instruction>(inst_name#"_VV_"#vti.LMul.MX#"_MASK")
+ (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.RegClass:$rs2,
+ (vti.Mask V0), 0b10, GPR:$vl, vti.Log2SEW, TA_MA)>;
+
+foreach vtiToWti = AllWidenableIntVectors in {
+ defvar vti = vtiToWti.Vti;
+ defvar wti = vtiToWti.Wti;
+ let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+ GetVTypePredicates<wti>.Predicates) in
+ def : VPatAverageAddTruncShift<"PseudoVAADDU", riscv_vwaddu_vl, vti, wti>;
+}
+
// 13. Vector Floating-Point Instructions
// 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions
``````````
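To make the shape of the fold easier to follow, here is a toy Python sketch of the rewrite on a tuple-based expression tree. It only mirrors the structure of the TableGen pattern. The node encoding and the `fold_avg` name are assumptions made up for this example, and the mask, VL and passthru operands that the real pattern also matches are left out.

```python
# Toy DAG nodes as nested tuples: (opcode, operand, ...).
# The encoding and the name fold_avg are hypothetical, for illustration only;
# the real pattern also constrains mask, VL and passthru operands.

def fold_avg(node):
    """Rewrite trunc_vl(srl_vl(vwaddu(X, Y), splat 1)) to vaaddu(X, Y).

    Any node that does not have exactly this shape is returned unchanged.
    """
    if not (isinstance(node, tuple) and len(node) == 2 and node[0] == "trunc_vl"):
        return node
    shift = node[1]
    if not (isinstance(shift, tuple) and len(shift) == 3 and shift[0] == "srl_vl"):
        return node
    _, wide, amount = shift
    if amount != ("splat", 1):     # only a shift by exactly one is an average
        return node
    if not (isinstance(wide, tuple) and len(wide) == 3 and wide[0] == "vwaddu"):
        return node
    _, x, y = wide
    # "rdn" stands in for the 0b10 rounding-mode operand in the pattern.
    return ("vaaddu", x, y, "rdn")
```

For instance, `fold_avg(("trunc_vl", ("srl_vl", ("vwaddu", "X", "Y"), ("splat", 1))))` yields `("vaaddu", "X", "Y", "rdn")`, while the same tree with `("splat", 2)` is left untouched.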
</details>
https://github.com/llvm/llvm-project/pull/76550