[llvm] [RISCV] fold trunc_vl (srl_vl (vwaddu X, Y), splat 1) -> vaaddu X, Y (PR #76550)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 28 18:38:41 PST 2023


https://github.com/sun-jacobi created https://github.com/llvm/llvm-project/pull/76550

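This patch uses `vaaddu` (with `vxrm` set to 2, i.e. round-down) for unsigned averaging addition, folding `trunc_vl (srl_vl (vwaddu X, Y), splat 1)` into a single `vaaddu X, Y`.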

## Source code 
```
define <8 x i8> @vaaddu_auto(ptr %x, ptr %y, ptr %z) {
  %xv = load <8 x i8>, ptr %x, align 2
  %yv = load <8 x i8>, ptr %y, align 2
  %xzv = zext <8 x i8> %xv to <8 x i16>
  %yzv = zext <8 x i8> %yv to <8 x i16>
  %add = add nuw nsw <8 x i16> %xzv, %yzv
  %div = lshr <8 x i16> %add, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
  %ret = trunc <8 x i16> %div to <8 x i8>
  ret <8 x i8> %ret 
}
```

## Before this patch 
```
vaaddu_auto: 
        vsetivli        zero, 8, e8, mf2, ta, ma
        vle8.v  v8, (a0)
        vle8.v  v9, (a1)
        vwaddu.vv       v10, v8, v9
        vnsrl.wi        v8, v10, 1
        ret
```
## After this patch
```
vaaddu_auto: 
	vsetivli	zero, 8, e8, mf2, ta, ma
	vle8.v	v8, (a0)
	vle8.v	v9, (a1)
	csrwi	vxrm, 2
	vaaddu.vv	v8, v8, v9
	ret
```

### Note on signed averaging addition

According to the RVV spec, there is also a signed variant of averaging addition called `vaadd`.   
But AFAIU, no matter which rounding mode is used, we cannot achieve the semantics of signed averaging addition through `vaadd`.   
Thus this patch only introduces `vaaddu`.


>From ba41fb8ce0ce997e06ecd19d981bfb8b2596d62c Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Thu, 28 Dec 2023 21:37:31 +0900
Subject: [PATCH 1/2] [RISCV][Isel] fold trunc_vl (srl_vl (vwadd X, Y), splat
 1) -> vaadd X, Y

---
 .../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 33bdc3366aa3e3..bbfeffb384700d 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -2338,6 +2338,37 @@ defm : VPatBinaryVL_VV_VX_VI<riscv_uaddsat_vl, "PseudoVSADDU">;
 defm : VPatBinaryVL_VV_VX<riscv_ssubsat_vl, "PseudoVSSUB">;
 defm : VPatBinaryVL_VV_VX<riscv_usubsat_vl, "PseudoVSSUBU">;
 
+// 12.2 Vector Single-Width Averaging Add and Subtract
+
+// trunc_vl (srl_vl (vwadd X, Y), splat 1) -> vaadd X, Y
+class VPatAverageAddTruncShift<string inst_name,
+                               SDNode vop,
+                               VTypeInfo vti,
+                               VTypeInfo wti>
+  : Pat<
+    (vti.Vector (riscv_trunc_vector_vl
+      (wti.Vector (riscv_srl_vl
+        (wti.Vector (vop
+          (vti.Vector vti.RegClass:$rs1),
+          (vti.Vector vti.RegClass:$rs2),
+          (wti.Vector undef), (vti.Mask V0), VLOpFrag)),
+        (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), 1, VLOpFrag)),
+        (wti.Vector undef), (wti.Mask V0), VLOpFrag)),
+      (wti.Mask V0), VLOpFrag)),
+    (!cast<Instruction>(inst_name#"_VV_"#vti.LMul.MX#"_MASK")
+      (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.RegClass:$rs2,
+      (vti.Mask V0), 0b10, GPR:$vl, vti.Log2SEW, TA_MA)>;
+
+foreach vtiToWti = AllWidenableIntVectors in {
+  defvar vti = vtiToWti.Vti;
+  defvar wti = vtiToWti.Wti;
+  let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+                               GetVTypePredicates<wti>.Predicates) in {
+    def : VPatAverageAddTruncShift<"PseudoVAADDU", riscv_vwaddu_vl, vti, wti>;
+    def : VPatAverageAddTruncShift<"PseudoVAADD", riscv_vwadd_vl, vti, wti>;
+  }
+}
+
 // 13. Vector Floating-Point Instructions
 
 // 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions

>From db7a8e44f1b21203cd9ae7f342495e6d367727c3 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Fri, 29 Dec 2023 11:23:33 +0900
Subject: [PATCH 2/2] [RISCV] remove vaadd folding case

---
 llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index bbfeffb384700d..67fbd6141967a6 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -2340,7 +2340,7 @@ defm : VPatBinaryVL_VV_VX<riscv_usubsat_vl, "PseudoVSSUBU">;
 
 // 12.2 Vector Single-Width Averaging Add and Subtract
 
-// trunc_vl (srl_vl (vwadd X, Y), splat 1) -> vaadd X, Y
+// trunc_vl (srl_vl (vwaddu X, Y), splat 1) -> vaaddu X, Y
 class VPatAverageAddTruncShift<string inst_name,
                                SDNode vop,
                                VTypeInfo vti,
@@ -2363,10 +2363,8 @@ foreach vtiToWti = AllWidenableIntVectors in {
   defvar vti = vtiToWti.Vti;
   defvar wti = vtiToWti.Wti;
   let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
-                               GetVTypePredicates<wti>.Predicates) in {
+                               GetVTypePredicates<wti>.Predicates) in
     def : VPatAverageAddTruncShift<"PseudoVAADDU", riscv_vwaddu_vl, vti, wti>;
-    def : VPatAverageAddTruncShift<"PseudoVAADD", riscv_vwadd_vl, vti, wti>;
-  }
 }
 
 // 13. Vector Floating-Point Instructions



More information about the llvm-commits mailing list