[llvm] 11c1827 - [RISCV] Use masked pseudo peephole for reduction pseudos (#71508)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 7 20:46:09 PST 2023


Author: Luke Lau
Date: 2023-11-08T12:46:06+08:00
New Revision: 11c182740a3b57a05a2d8cd2eb398a02165d4a57

URL: https://github.com/llvm/llvm-project/commit/11c182740a3b57a05a2d8cd2eb398a02165d4a57
DIFF: https://github.com/llvm/llvm-project/commit/11c182740a3b57a05a2d8cd2eb398a02165d4a57.diff

LOG: [RISCV] Use masked pseudo peephole for reduction pseudos (#71508)

After #71483 we now have a way of marking masked pseudos as having an
unmasked equivalent, but whose mask shouldn't be folded unless it's all
ones, since the mask would otherwise affect the result.

This patch uses that mechanism to mark the pseudos for vredsum and
friends, which in turn allows us to remove the unmasked patterns and
catch some other forms of vmerge.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
    llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index 83faa5bbef7931f..01a425298c9da28 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -3213,7 +3213,8 @@ multiclass VPseudoTernaryWithTailPolicy<VReg RetClass,
     defvar mx = MInfo.MX;
     let isCommutable = Commutable in
     def "_" # mx # "_E" # sew : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class, Constraint>;
-    def "_" # mx # "_E" # sew # "_MASK" : VPseudoTernaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>;
+    def "_" # mx # "_E" # sew # "_MASK" : VPseudoTernaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>,
+                                          RISCVMaskedPseudo<MaskIdx=3, MaskAffectsRes=true>;
   }
 }
 
@@ -3232,7 +3233,8 @@ multiclass VPseudoTernaryWithTailPolicyRoundingMode<VReg RetClass,
                                                      Op2Class, Constraint>;
     def "_" # mx # "_E" # sew # "_MASK"
         : VPseudoTernaryMaskPolicyRoundingMode<RetClass, Op1Class,
-                                               Op2Class, Constraint>;
+                                               Op2Class, Constraint>,
+          RISCVMaskedPseudo<MaskIdx=3, MaskAffectsRes=true>;
   }
 }
 

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index d92d3975d12f533..a27719455642a71 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -1381,16 +1381,6 @@ multiclass VPatReductionVL<SDNode vop, string instruction_name, bit is_float> {
   foreach vti = !if(is_float, AllFloatVectors, AllIntegerVectors) in {
     defvar vti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # vti.SEW # "M1");
     let Predicates = GetVTypePredicates<vti>.Predicates in {
-      def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
-                                   (vti.Vector vti.RegClass:$rs1), VR:$rs2,
-                                   (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-          (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-              (vti_m1.Vector VR:$merge),
-              (vti.Vector vti.RegClass:$rs1),
-              (vti_m1.Vector VR:$rs2),
-              GPR:$vl, vti.Log2SEW, (XLenVT timm:$policy))>;
-
       def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
                                    (vti.Vector vti.RegClass:$rs1), VR:$rs2,
                                    (vti.Mask V0), VLOpFrag,
@@ -1408,19 +1398,6 @@ multiclass VPatReductionVL_RM<SDNode vop, string instruction_name, bit is_float>
   foreach vti = !if(is_float, AllFloatVectors, AllIntegerVectors) in {
     defvar vti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # vti.SEW # "M1");
     let Predicates = GetVTypePredicates<vti>.Predicates in {
-      def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
-                                   (vti.Vector vti.RegClass:$rs1), VR:$rs2,
-                                   (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-          (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-              (vti_m1.Vector VR:$merge),
-              (vti.Vector vti.RegClass:$rs1),
-              (vti_m1.Vector VR:$rs2),
-              // Value to indicate no rounding mode change in
-              // RISCVInsertReadWriteCSR
-              FRM_DYN,
-              GPR:$vl, vti.Log2SEW, (XLenVT timm:$policy))>;
-
       def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
                                    (vti.Vector vti.RegClass:$rs1), VR:$rs2,
                                    (vti.Mask V0), VLOpFrag,
@@ -1486,14 +1463,6 @@ multiclass VPatWidenReductionVL<SDNode vop, PatFrags extop, string instruction_n
     defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
-                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
-                                   VR:$rs2, (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-               (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-                  (wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
-                  (wti_m1.Vector VR:$rs2), GPR:$vl, vti.Log2SEW,
-                  (XLenVT timm:$policy))>;
       def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
                                    (wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
                                    VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1513,18 +1482,6 @@ multiclass VPatWidenReductionVL_RM<SDNode vop, PatFrags extop, string instructio
     defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
-                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
-                                   VR:$rs2, (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-               (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-                  (wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
-                  (wti_m1.Vector VR:$rs2),
-                  // Value to indicate no rounding mode change in
-                  // RISCVInsertReadWriteCSR
-                  FRM_DYN,
-                  GPR:$vl, vti.Log2SEW,
-                  (XLenVT timm:$policy))>;
       def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
                                    (wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
                                    VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1548,14 +1505,6 @@ multiclass VPatWidenReductionVL_Ext_VL<SDNode vop, PatFrags extop, string instru
     defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
-                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
-                                   VR:$rs2, (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-               (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-                  (wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
-                  (wti_m1.Vector VR:$rs2), GPR:$vl, vti.Log2SEW,
-                  (XLenVT timm:$policy))>;
       def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
                                    (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
                                    VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1575,18 +1524,6 @@ multiclass VPatWidenReductionVL_Ext_VL_RM<SDNode vop, PatFrags extop, string ins
     defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
-                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
-                                   VR:$rs2, (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-               (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-                  (wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
-                  (wti_m1.Vector VR:$rs2),
-                  // Value to indicate no rounding mode change in
-                  // RISCVInsertReadWriteCSR
-                  FRM_DYN,
-                  GPR:$vl, vti.Log2SEW,
-                  (XLenVT timm:$policy))>;
       def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
                                    (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
                                    VR:$rs2, (vti.Mask V0), VLOpFrag,

diff  --git a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll
index 450ab3cbb0dc369..c639f092444fc43 100644
--- a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll
@@ -1049,11 +1049,8 @@ define <vscale x 2 x float> @vfredusum(<vscale x 2 x float> %passthru, <vscale x
 define <vscale x 2 x i32> @vredsum_allones_mask(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, i64 %vl) {
 ; CHECK-LABEL: vredsum_allones_mask:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, ma
-; CHECK-NEXT:    vmv1r.v v11, v8
-; CHECK-NEXT:    vredsum.vs v11, v9, v10
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, tu, ma
-; CHECK-NEXT:    vmv.v.v v8, v11
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, tu, ma
+; CHECK-NEXT:    vredsum.vs v8, v9, v10
 ; CHECK-NEXT:    ret
   %splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0
   %mask = shufflevector <vscale x 2 x i1> %splat, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
@@ -1070,12 +1067,9 @@ define <vscale x 2 x i32> @vredsum_allones_mask(<vscale x 2 x i32> %passthru, <v
 define <vscale x 2 x float> @vfredusum_allones_mask(<vscale x 2 x float> %passthru, <vscale x 2 x float> %x, <vscale x 2 x float> %y, i64 %vl) {
 ; CHECK-LABEL: vfredusum_allones_mask:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, ma
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, tu, ma
 ; CHECK-NEXT:    fsrmi a0, 0
-; CHECK-NEXT:    vmv1r.v v11, v8
-; CHECK-NEXT:    vfredusum.vs v11, v9, v10
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, tu, ma
-; CHECK-NEXT:    vmv.v.v v8, v11
+; CHECK-NEXT:    vfredusum.vs v8, v9, v10
 ; CHECK-NEXT:    fsrm a0
 ; CHECK-NEXT:    ret
   %splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0


        


More information about the llvm-commits mailing list