[llvm] [RISCV] fold trunc_vl (srl_vl (vwaddu X, Y), splat 1) -> vaaddu X, Y (PR #76550)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 28 18:39:10 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Chia (sun-jacobi)

<details>
<summary>Changes</summary>

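This patch selects `vaaddu` for unsigned averaging addition, i.e. the zero-extend / add / shift-right-by-1 / truncate idiom shown below, which the backend currently lowers to `vwaddu` + `vnsrl`.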

## Source code 
```
define <8 x i8> @vaaddu_auto(ptr %x, ptr %y, ptr %z) {
  %xv = load <8 x i8>, ptr %x, align 2
  %yv = load <8 x i8>, ptr %y, align 2
  %xzv = zext <8 x i8> %xv to <8 x i16>
  %yzv = zext <8 x i8> %yv to <8 x i16>
  %add = add nuw nsw <8 x i16> %xzv, %yzv
  %div = lshr <8 x i16> %add, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
  %ret = trunc <8 x i16> %div to <8 x i8>
  ret <8 x i8> %ret 
}
```

## Before this patch 
```
vaaddu_auto: 
        vsetivli        zero, 8, e8, mf2, ta, ma
        vle8.v  v8, (a0)
        vle8.v  v9, (a1)
        vwaddu.vv       v10, v8, v9
        vnsrl.wi        v8, v10, 1
        ret
```
## After this patch
```
vaaddu_auto: 
	vsetivli	zero, 8, e8, mf2, ta, ma
	vle8.v	v8, (a0)
	vle8.v	v9, (a1)
	csrwi	vxrm, 2
	vaaddu.vv	v8, v8, v9
	ret
```

### Note on signed averaging addition

According to the RVV specification, there is also a signed averaging-add instruction, `vaadd`.
However, as far as I understand, no rounding mode lets `vaadd` reproduce the semantics of signed averaging addition.
Therefore, this patch only introduces `vaaddu`.


---
Full diff: https://github.com/llvm/llvm-project/pull/76550.diff


1 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (+29) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 33bdc3366aa3e3..67fbd6141967a6 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -2338,6 +2338,35 @@ defm : VPatBinaryVL_VV_VX_VI<riscv_uaddsat_vl, "PseudoVSADDU">;
 defm : VPatBinaryVL_VV_VX<riscv_ssubsat_vl, "PseudoVSSUB">;
 defm : VPatBinaryVL_VV_VX<riscv_usubsat_vl, "PseudoVSSUBU">;
 
+// 12.2 Vector Single-Width Averaging Add and Subtract
+
+// trunc_vl (srl_vl (vwaddu X, Y), splat 1) -> vaaddu X, Y
+class VPatAverageAddTruncShift<string inst_name,
+                               SDNode vop,
+                               VTypeInfo vti,
+                               VTypeInfo wti>
+  : Pat<
+    (vti.Vector (riscv_trunc_vector_vl
+      (wti.Vector (riscv_srl_vl
+        (wti.Vector (vop
+          (vti.Vector vti.RegClass:$rs1),
+          (vti.Vector vti.RegClass:$rs2),
+          (wti.Vector undef), (vti.Mask V0), VLOpFrag)),
+        (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), 1, VLOpFrag)),
+        (wti.Vector undef), (wti.Mask V0), VLOpFrag)),
+      (wti.Mask V0), VLOpFrag)),
+    (!cast<Instruction>(inst_name#"_VV_"#vti.LMul.MX#"_MASK")
+      (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.RegClass:$rs2,
+      (vti.Mask V0), 0b10, GPR:$vl, vti.Log2SEW, TA_MA)>;
+
+foreach vtiToWti = AllWidenableIntVectors in {
+  defvar vti = vtiToWti.Vti;
+  defvar wti = vtiToWti.Wti;
+  let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+                               GetVTypePredicates<wti>.Predicates) in
+    def : VPatAverageAddTruncShift<"PseudoVAADDU", riscv_vwaddu_vl, vti, wti>;
+}
+
 // 13. Vector Floating-Point Instructions
 
 // 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions

``````````

</details>


https://github.com/llvm/llvm-project/pull/76550


More information about the llvm-commits mailing list