[llvm] [RISCV] Use masked pseudo peephole for reduction pseudos (PR #71508)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 7 02:27:19 PST 2023


https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/71508

After #71483 we have a way of marking masked pseudos that have an unmasked
equivalent but whose mask affects the result, so the mask must not be folded
unless it is all ones.

This patch uses that marking for the pseudos for vredsum and friends, which in
turn allows us to remove the unmasked patterns, and to remove vmerges entirely
when the mask is known to be all ones.


>From b30e8f6ddb6dec7301693533b97c45bf07dc1519 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Tue, 7 Nov 2023 18:15:17 +0800
Subject: [PATCH] [RISCV] Use masked pseudo peephole for reduction pseudos

After #71483 we have a way of marking masked pseudos that have an unmasked
equivalent but whose mask affects the result, so the mask must not be folded
unless it is all ones.

This patch uses that marking for the pseudos for vredsum and friends, which in
turn allows us to remove the unmasked patterns, and to remove vmerges entirely
when the mask is known to be all ones.
---
 .../Target/RISCV/RISCVInstrInfoVPseudos.td    |  6 +-
 .../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 63 -------------------
 .../RISCV/rvv/rvv-peephole-vmerge-vops.ll     | 14 ++---
 3 files changed, 8 insertions(+), 75 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index 83faa5bbef7931f..01a425298c9da28 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -3213,7 +3213,8 @@ multiclass VPseudoTernaryWithTailPolicy<VReg RetClass,
     defvar mx = MInfo.MX;
     let isCommutable = Commutable in
     def "_" # mx # "_E" # sew : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class, Constraint>;
-    def "_" # mx # "_E" # sew # "_MASK" : VPseudoTernaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>;
+    def "_" # mx # "_E" # sew # "_MASK" : VPseudoTernaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>,
+                                          RISCVMaskedPseudo<MaskIdx=3, MaskAffectsRes=true>;
   }
 }
 
@@ -3232,7 +3233,8 @@ multiclass VPseudoTernaryWithTailPolicyRoundingMode<VReg RetClass,
                                                      Op2Class, Constraint>;
     def "_" # mx # "_E" # sew # "_MASK"
         : VPseudoTernaryMaskPolicyRoundingMode<RetClass, Op1Class,
-                                               Op2Class, Constraint>;
+                                               Op2Class, Constraint>,
+          RISCVMaskedPseudo<MaskIdx=3, MaskAffectsRes=true>;
   }
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index d92d3975d12f533..a27719455642a71 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -1381,16 +1381,6 @@ multiclass VPatReductionVL<SDNode vop, string instruction_name, bit is_float> {
   foreach vti = !if(is_float, AllFloatVectors, AllIntegerVectors) in {
     defvar vti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # vti.SEW # "M1");
     let Predicates = GetVTypePredicates<vti>.Predicates in {
-      def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
-                                   (vti.Vector vti.RegClass:$rs1), VR:$rs2,
-                                   (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-          (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-              (vti_m1.Vector VR:$merge),
-              (vti.Vector vti.RegClass:$rs1),
-              (vti_m1.Vector VR:$rs2),
-              GPR:$vl, vti.Log2SEW, (XLenVT timm:$policy))>;
-
       def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
                                    (vti.Vector vti.RegClass:$rs1), VR:$rs2,
                                    (vti.Mask V0), VLOpFrag,
@@ -1408,19 +1398,6 @@ multiclass VPatReductionVL_RM<SDNode vop, string instruction_name, bit is_float>
   foreach vti = !if(is_float, AllFloatVectors, AllIntegerVectors) in {
     defvar vti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # vti.SEW # "M1");
     let Predicates = GetVTypePredicates<vti>.Predicates in {
-      def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
-                                   (vti.Vector vti.RegClass:$rs1), VR:$rs2,
-                                   (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-          (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-              (vti_m1.Vector VR:$merge),
-              (vti.Vector vti.RegClass:$rs1),
-              (vti_m1.Vector VR:$rs2),
-              // Value to indicate no rounding mode change in
-              // RISCVInsertReadWriteCSR
-              FRM_DYN,
-              GPR:$vl, vti.Log2SEW, (XLenVT timm:$policy))>;
-
       def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
                                    (vti.Vector vti.RegClass:$rs1), VR:$rs2,
                                    (vti.Mask V0), VLOpFrag,
@@ -1486,14 +1463,6 @@ multiclass VPatWidenReductionVL<SDNode vop, PatFrags extop, string instruction_n
     defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
-                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
-                                   VR:$rs2, (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-               (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-                  (wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
-                  (wti_m1.Vector VR:$rs2), GPR:$vl, vti.Log2SEW,
-                  (XLenVT timm:$policy))>;
       def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
                                    (wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
                                    VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1513,18 +1482,6 @@ multiclass VPatWidenReductionVL_RM<SDNode vop, PatFrags extop, string instructio
     defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
-                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
-                                   VR:$rs2, (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-               (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-                  (wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
-                  (wti_m1.Vector VR:$rs2),
-                  // Value to indicate no rounding mode change in
-                  // RISCVInsertReadWriteCSR
-                  FRM_DYN,
-                  GPR:$vl, vti.Log2SEW,
-                  (XLenVT timm:$policy))>;
       def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
                                    (wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
                                    VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1548,14 +1505,6 @@ multiclass VPatWidenReductionVL_Ext_VL<SDNode vop, PatFrags extop, string instru
     defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
-                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
-                                   VR:$rs2, (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-               (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-                  (wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
-                  (wti_m1.Vector VR:$rs2), GPR:$vl, vti.Log2SEW,
-                  (XLenVT timm:$policy))>;
       def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
                                    (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
                                    VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1575,18 +1524,6 @@ multiclass VPatWidenReductionVL_Ext_VL_RM<SDNode vop, PatFrags extop, string ins
     defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
-                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
-                                   VR:$rs2, (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-               (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-                  (wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
-                  (wti_m1.Vector VR:$rs2),
-                  // Value to indicate no rounding mode change in
-                  // RISCVInsertReadWriteCSR
-                  FRM_DYN,
-                  GPR:$vl, vti.Log2SEW,
-                  (XLenVT timm:$policy))>;
       def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
                                    (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
                                    VR:$rs2, (vti.Mask V0), VLOpFrag,
diff --git a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll
index 450ab3cbb0dc369..c639f092444fc43 100644
--- a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll
@@ -1049,11 +1049,8 @@ define <vscale x 2 x float> @vfredusum(<vscale x 2 x float> %passthru, <vscale x
 define <vscale x 2 x i32> @vredsum_allones_mask(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, i64 %vl) {
 ; CHECK-LABEL: vredsum_allones_mask:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, ma
-; CHECK-NEXT:    vmv1r.v v11, v8
-; CHECK-NEXT:    vredsum.vs v11, v9, v10
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, tu, ma
-; CHECK-NEXT:    vmv.v.v v8, v11
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, tu, ma
+; CHECK-NEXT:    vredsum.vs v8, v9, v10
 ; CHECK-NEXT:    ret
   %splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0
   %mask = shufflevector <vscale x 2 x i1> %splat, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
@@ -1070,12 +1067,9 @@ define <vscale x 2 x i32> @vredsum_allones_mask(<vscale x 2 x i32> %passthru, <v
 define <vscale x 2 x float> @vfredusum_allones_mask(<vscale x 2 x float> %passthru, <vscale x 2 x float> %x, <vscale x 2 x float> %y, i64 %vl) {
 ; CHECK-LABEL: vfredusum_allones_mask:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, ma
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, tu, ma
 ; CHECK-NEXT:    fsrmi a0, 0
-; CHECK-NEXT:    vmv1r.v v11, v8
-; CHECK-NEXT:    vfredusum.vs v11, v9, v10
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, tu, ma
-; CHECK-NEXT:    vmv.v.v v8, v11
+; CHECK-NEXT:    vfredusum.vs v8, v9, v10
 ; CHECK-NEXT:    fsrm a0
 ; CHECK-NEXT:    ret
   %splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0



More information about the llvm-commits mailing list